Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions openjudge/models/openai_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""OpenAI Client."""

import copy
import os
from typing import Any, AsyncGenerator, Callable, Dict, Literal, Type
Expand Down Expand Up @@ -27,6 +28,16 @@ def _format_audio_data_for_qwen_omni(messages: list[dict | ChatMessage]) -> list
"""
format_data = []
for msg in messages:
msg_dict = msg.to_dict() if isinstance(msg, ChatMessage) else msg
if isinstance(msg_dict.get("content"), list):
for block in msg_dict["content"]:
if (
isinstance(block, dict)
and "input_audio" in block
and isinstance(block["input_audio"].get("data"), str)
):
if not block["input_audio"]["data"].startswith("http"):
block["input_audio"]["data"] = "data:;base64," + block["input_audio"]["data"]
try:
msg_copy = copy.deepcopy(msg)
msg_dict = msg_copy.to_dict() if isinstance(msg_copy, ChatMessage) else msg_copy
Expand Down Expand Up @@ -58,6 +69,8 @@ def __init__(
reasoning_effort: Literal["low", "medium", "high"] | None = None,
organization: str | None = None,
client_args: Dict[str, Any] | None = None,
max_retries: int | None = None,
timeout: float | None = None,
**kwargs: Any,
) -> None:
"""Initialize the openai client.
Expand Down Expand Up @@ -94,6 +107,11 @@ def __init__(

if organization:
client_args["organization"] = organization
if max_retries is not None:
client_args["max_retries"] = max_retries

if timeout is not None:
client_args["timeout"] = timeout

self.client = AsyncOpenAI(**client_args)

Expand Down Expand Up @@ -289,7 +307,10 @@ def _handle_non_streaming_response(
parsed_response.parsed.update(callback_result)
except Exception as e:
# Log the exception but don't fail the entire operation
logger.warning(f"Callback function raised an exception: {type(e).__name__}: {e}", exc_info=True)
logger.warning(
f"Callback function raised an exception: {type(e).__name__}: {e}",
exc_info=True,
)

return parsed_response

Expand Down Expand Up @@ -386,6 +407,9 @@ async def _handle_streaming_response(
final_response.parsed = final_response.parsed or {}
final_response.parsed.update(callback_result)
except Exception as e:
logger.warning(f"Callback function raised an exception: {type(e).__name__}: {e}", exc_info=True)
logger.warning(
f"Callback function raised an exception: {type(e).__name__}: {e}",
exc_info=True,
)

yield final_response
49 changes: 49 additions & 0 deletions tests/models/test_openai_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Unit tests for OpenAIChatModel."""

import asyncio
from unittest.mock import AsyncMock, patch

Expand Down Expand Up @@ -298,6 +299,54 @@ def test_qwen_omni_audio_formatting(self):
"data:;base64,",
)

@pytest.mark.parametrize(
"init_kwargs, expected_in_call, not_expected_in_call",
[
(
{"max_retries": 5, "timeout": 120.0},
{"max_retries": 5, "timeout": 120.0},
[],
),
({}, {}, ["max_retries", "timeout"]),
({"max_retries": None, "timeout": None}, {}, ["max_retries", "timeout"]),
({"max_retries": 3}, {"max_retries": 3}, ["timeout"]),
({"timeout": 30.0}, {"timeout": 30.0}, ["max_retries"]),
],
ids=[
"with_retries_and_timeout",
"defaults",
"with_none_values",
"with_retries_only",
"with_timeout_only",
],
)
@patch("openjudge.models.openai_chat_model.AsyncOpenAI")
def test_client_initialization_with_retries_and_timeout(
self,
mock_async_openai,
init_kwargs,
expected_in_call,
not_expected_in_call,
):
"""Test that max_retries and timeout parameters are passed to AsyncOpenAI client correctly."""
OpenAIChatModel(
model="gpt-4",
api_key="test-key",
**init_kwargs,
)

mock_async_openai.assert_called_once()
call_kwargs = mock_async_openai.call_args[1]

assert call_kwargs["api_key"] == "test-key"

for key, value in expected_in_call.items():
assert key in call_kwargs
assert call_kwargs[key] == value

for key in not_expected_in_call:
assert key not in call_kwargs

Comment on lines 323 to 346
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These five tests for max_retries and timeout have very similar structures and contain duplicated code. They can be consolidated into a single parameterized test using pytest.mark.parametrize. This will make the tests more concise, easier to read, and simpler to maintain or extend in the future.

    @pytest.mark.parametrize(
        "init_kwargs, expected_in_call, not_expected_in_call",
        [
            (
                {"max_retries": 5, "timeout": 120.0},
                {"max_retries": 5, "timeout": 120.0},
                [],
            ),
            ({}, {}, ["max_retries", "timeout"]),
            ({"max_retries": None, "timeout": None}, {}, ["max_retries", "timeout"]),
            ({"max_retries": 3}, {"max_retries": 3}, ["timeout"]),
            ({"timeout": 30.0}, {"timeout": 30.0}, ["max_retries"]),
        ],
        ids=[
            "with_retries_and_timeout",
            "defaults",
            "with_none_values",
            "with_retries_only",
            "with_timeout_only",
        ],
    )
    @patch("openjudge.models.openai_chat_model.AsyncOpenAI")
    def test_client_initialization_with_retries_and_timeout(
        self,
        mock_async_openai,
        init_kwargs,
        expected_in_call,
        not_expected_in_call,
    ):
        """Test that max_retries and timeout parameters are passed to AsyncOpenAI client correctly."""
        OpenAIChatModel(
            model="gpt-4",
            api_key="test-key",
            **init_kwargs,
        )

        mock_async_openai.assert_called_once()
        call_kwargs = mock_async_openai.call_args[1]

        assert call_kwargs["api_key"] == "test-key"

        for key, value in expected_in_call.items():
            assert key in call_kwargs
            assert call_kwargs[key] == value

        for key in not_expected_in_call:
            assert key not in call_kwargs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


if __name__ == "__main__":
pytest.main([__file__])