diff --git a/openjudge/models/openai_chat_model.py b/openjudge/models/openai_chat_model.py index 1d5294520..4a9bc97f4 100644 --- a/openjudge/models/openai_chat_model.py +++ b/openjudge/models/openai_chat_model.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """OpenAI Client.""" + import copy import os from typing import Any, AsyncGenerator, Callable, Dict, Literal, Type @@ -27,6 +28,16 @@ def _format_audio_data_for_qwen_omni(messages: list[dict | ChatMessage]) -> list """ format_data = [] for msg in messages: + msg_dict = msg.to_dict() if isinstance(msg, ChatMessage) else msg + if isinstance(msg_dict.get("content"), list): + for block in msg_dict["content"]: + if ( + isinstance(block, dict) + and "input_audio" in block + and isinstance(block["input_audio"].get("data"), str) + ): + if not block["input_audio"]["data"].startswith("http"): + block["input_audio"]["data"] = "data:;base64," + block["input_audio"]["data"] try: msg_copy = copy.deepcopy(msg) msg_dict = msg_copy.to_dict() if isinstance(msg_copy, ChatMessage) else msg_copy @@ -58,6 +69,8 @@ def __init__( reasoning_effort: Literal["low", "medium", "high"] | None = None, organization: str | None = None, client_args: Dict[str, Any] | None = None, + max_retries: int | None = None, + timeout: float | None = None, **kwargs: Any, ) -> None: """Initialize the openai client. @@ -94,6 +107,11 @@ def __init__( if organization: client_args["organization"] = organization + if max_retries is not None: + client_args["max_retries"] = max_retries + + if timeout is not None: + client_args["timeout"] = timeout self.client = AsyncOpenAI(**client_args) @@ -289,7 +307,10 @@ def _handle_non_streaming_response( parsed_response.parsed.update(callback_result) except Exception as e: # Log the exception but don't fail the entire operation - logger.warning(f"Callback function raised an exception: {type(e).__name__}: {e}", exc_info=True) + logger.warning( + f"Callback function raised an exception: {type(e).__name__}: {e}", + exc_info=True, + ) return parsed_response @@ -386,6 +407,9 @@ async def _handle_streaming_response( final_response.parsed = final_response.parsed or {} final_response.parsed.update(callback_result) except Exception as e: - logger.warning(f"Callback function raised an exception: {type(e).__name__}: {e}", exc_info=True) + logger.warning( + f"Callback function raised an exception: {type(e).__name__}: {e}", + exc_info=True, + ) yield final_response diff --git a/tests/models/test_openai_chat_model.py b/tests/models/test_openai_chat_model.py index c3073f3be..237f15247 100644 --- a/tests/models/test_openai_chat_model.py +++ b/tests/models/test_openai_chat_model.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Unit tests for OpenAIChatModel.""" + import asyncio from unittest.mock import AsyncMock, patch @@ -298,6 +299,54 @@ def test_qwen_omni_audio_formatting(self): "data:;base64,", ) + @pytest.mark.parametrize( + "init_kwargs, expected_in_call, not_expected_in_call", + [ + ( + {"max_retries": 5, "timeout": 120.0}, + {"max_retries": 5, "timeout": 120.0}, + [], + ), + ({}, {}, ["max_retries", "timeout"]), + ({"max_retries": None, "timeout": None}, {}, ["max_retries", "timeout"]), + ({"max_retries": 3}, {"max_retries": 3}, ["timeout"]), + ({"timeout": 30.0}, {"timeout": 30.0}, ["max_retries"]), + ], + ids=[ + "with_retries_and_timeout", + "defaults", + "with_none_values", + "with_retries_only", + "with_timeout_only", + ], + ) + @patch("openjudge.models.openai_chat_model.AsyncOpenAI") + def test_client_initialization_with_retries_and_timeout( + self, + mock_async_openai, + init_kwargs, + expected_in_call, + not_expected_in_call, + ): + """Test that max_retries and timeout parameters are passed to AsyncOpenAI client correctly.""" + OpenAIChatModel( + model="gpt-4", + api_key="test-key", + **init_kwargs, + ) + + mock_async_openai.assert_called_once() + call_kwargs = mock_async_openai.call_args[1] + + assert call_kwargs["api_key"] == "test-key" + + for key, value in expected_in_call.items(): + assert key in call_kwargs + assert call_kwargs[key] == value + + for key in not_expected_in_call: + assert key not in call_kwargs + if __name__ == "__main__": pytest.main([__file__])