diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py
index ffb7227a..898cb054 100644
--- a/src/app/endpoints/rlsapi_v1.py
+++ b/src/app/endpoints/rlsapi_v1.py
@@ -5,12 +5,14 @@
 """
 
 import logging
-from typing import Annotated, Any, cast
+from typing import Annotated, Any
 
 from fastapi import APIRouter, Depends, HTTPException
 from llama_stack_client import APIConnectionError  # type: ignore
-from llama_stack_client.types import UserMessage  # type: ignore
-from llama_stack_client.types.alpha.agents.turn import Turn
+from llama_stack_client.types.chat.completion_create_params import (
+    MessageOpenAISystemMessageParam,
+    MessageOpenAIUserMessageParam,
+)
 
 import constants
 from authentication import get_auth_dependency
@@ -27,9 +29,7 @@
 )
 from models.rlsapi.requests import RlsapiV1InferRequest
 from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
-from utils.endpoints import get_temp_agent
 from utils.suid import get_suid
-from utils.types import content_to_str
 
 logger = logging.getLogger(__name__)
 router = APIRouter(tags=["rlsapi-v1"])
@@ -82,8 +82,7 @@ def _get_default_model_id() -> str:
 async def retrieve_simple_response(question: str) -> str:
     """Retrieve a simple response from the LLM for a stateless query.
 
-    Creates a temporary agent, sends a single turn with the user's question,
-    and returns the LLM response text. No conversation persistence or tools.
+    Uses direct chat completion API for simple stateless inference.
 
     Args:
         question: The combined user input (question + context).
@@ -100,24 +99,33 @@ async def retrieve_simple_response(question: str) -> str:
 
     logger.debug("Using model %s for rlsapi v1 inference", model_id)
 
-    agent, session_id, _ = await get_temp_agent(
-        client, model_id, constants.DEFAULT_SYSTEM_PROMPT
+    sys_msg: MessageOpenAISystemMessageParam = {
+        "role": "system",
+        "content": constants.DEFAULT_SYSTEM_PROMPT,
+    }
+    user_msg: MessageOpenAIUserMessageParam = {
+        "role": "user",
+        "content": question,
+    }
+
+    response = await client.chat.completions.create(
+        model=model_id,
+        messages=[sys_msg, user_msg],
     )
 
-    response = await agent.create_turn(
-        messages=[UserMessage(role="user", content=question).model_dump()],
-        session_id=session_id,
-        stream=False,
-    )
-    response = cast(Turn, response)
+    if not response.choices:
+        return ""
 
-    if getattr(response, "output_message", None) is None:
+    choice = response.choices[0]
+    message = getattr(choice, "message", None)
+    if message is None:
         return ""
 
-    if getattr(response.output_message, "content", None) is None:
+    content = getattr(message, "content", None)
+    if content is None:
         return ""
 
-    return content_to_str(response.output_message.content)
+    return str(content)
 
 
 @router.post("/infer", responses=infer_responses)
diff --git a/tests/unit/app/endpoints/test_rlsapi_v1.py b/tests/unit/app/endpoints/test_rlsapi_v1.py
index 39e66c45..ee6300f1 100644
--- a/tests/unit/app/endpoints/test_rlsapi_v1.py
+++ b/tests/unit/app/endpoints/test_rlsapi_v1.py
@@ -3,9 +3,15 @@
 # pylint: disable=protected-access
 # pylint: disable=unused-argument
 
+from typing import Any
+
 import pytest
 from fastapi import HTTPException, status
 from llama_stack_client import APIConnectionError
+from llama_stack_client.types.chat.completion_create_response import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChoice,
+)
 from pydantic import ValidationError
 from pytest_mock import MockerFixture
 
@@ -25,13 +31,36 @@
     RlsapiV1Terminal,
 )
 from models.rlsapi.responses import RlsapiV1InferResponse
-from tests.unit.conftest import AgentFixtures
 from tests.unit.utils.auth_helpers import mock_authorization_resolvers
 from utils.suid import check_suid
 
 MOCK_AUTH: AuthTuple = ("test_user_id", "test_user", True, "test_token")
 
 
+def _setup_chat_completions_mock(mocker: MockerFixture, create_behavior: Any) -> None:
+    """Set up chat.completions mock with custom create() behavior.
+
+    Args:
+        mocker: The pytest-mock fixture for creating mocks.
+        create_behavior: The AsyncMock or side_effect for completions.create().
+    """
+    mock_completions = mocker.Mock()
+    mock_completions.create = create_behavior
+
+    mock_chat = mocker.Mock()
+    mock_chat.completions = mock_completions
+
+    mock_client = mocker.Mock()
+    mock_client.chat = mock_chat
+
+    mock_client_holder = mocker.Mock()
+    mock_client_holder.get_client.return_value = mock_client
+    mocker.patch(
+        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
+        return_value=mock_client_holder,
+    )
+
+
 @pytest.fixture(name="mock_configuration")
 def mock_configuration_fixture(
     mocker: MockerFixture, minimal_config: AppConfig
@@ -44,79 +73,35 @@ def mock_configuration_fixture(
 
 
 @pytest.fixture(name="mock_llm_response")
-def mock_llm_response_fixture(
-    mocker: MockerFixture, prepare_agent_mocks: AgentFixtures
-) -> None:
-    """Mock the LLM integration for successful responses."""
-    mock_client, mock_agent = prepare_agent_mocks
-
-    # Create mock output message with content
-    mock_output_message = mocker.Mock()
-    mock_output_message.content = "This is a test LLM response."
+def mock_llm_response_fixture(mocker: MockerFixture) -> None:
+    """Mock the LLM integration for successful responses via chat.completions."""
+    mock_message = mocker.Mock()
+    mock_message.content = "This is a test LLM response."
 
-    # Create mock turn response
-    mock_turn = mocker.Mock()
-    mock_turn.output_message = mock_output_message
-    mock_turn.steps = []
+    mock_choice = mocker.Mock(spec=OpenAIChatCompletionChoice)
+    mock_choice.message = mock_message
 
-    # Use AsyncMock for async method
-    mock_agent.create_turn = mocker.AsyncMock(return_value=mock_turn)
+    mock_response = mocker.Mock(spec=OpenAIChatCompletion)
+    mock_response.choices = [mock_choice]
 
-    # Mock get_temp_agent to return our mock agent
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.get_temp_agent",
-        return_value=(mock_agent, "test_session_id", None),
-    )
-
-    # Mock the client holder
-    mock_client_holder = mocker.Mock()
-    mock_client_holder.get_client.return_value = mock_client
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
-        return_value=mock_client_holder,
-    )
+    _setup_chat_completions_mock(mocker, mocker.AsyncMock(return_value=mock_response))
 
 
 @pytest.fixture(name="mock_empty_llm_response")
-def mock_empty_llm_response_fixture(
-    mocker: MockerFixture, prepare_agent_mocks: AgentFixtures
-) -> None:
-    """Mock the LLM integration for empty responses (output_message=None)."""
-    mock_client, mock_agent = prepare_agent_mocks
+def mock_empty_llm_response_fixture(mocker: MockerFixture) -> None:
+    """Mock chat.completions to return empty choices list."""
+    mock_response = mocker.Mock(spec=OpenAIChatCompletion)
+    mock_response.choices = []
 
-    # Create mock turn response with no output
-    mock_turn = mocker.Mock()
-    mock_turn.output_message = None
-    mock_turn.steps = []
-
-    # Use AsyncMock for async method
-    mock_agent.create_turn = mocker.AsyncMock(return_value=mock_turn)
-
-    # Mock get_temp_agent to return our mock agent
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.get_temp_agent",
-        return_value=(mock_agent, "test_session_id", None),
-    )
-
-    # Mock the client holder
-    mock_client_holder = mocker.Mock()
-    mock_client_holder.get_client.return_value = mock_client
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
-        return_value=mock_client_holder,
-    )
+    _setup_chat_completions_mock(mocker, mocker.AsyncMock(return_value=mock_response))
 
 
 @pytest.fixture(name="mock_api_connection_error")
 def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
-    """Mock AsyncLlamaStackClientHolder to raise APIConnectionError."""
-    mock_client_holder = mocker.Mock()
-    mock_client_holder.get_client.side_effect = APIConnectionError(
-        request=mocker.Mock()
-    )
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
-        return_value=mock_client_holder,
+    """Mock chat.completions.create() to raise APIConnectionError."""
+    _setup_chat_completions_mock(
+        mocker,
+        mocker.AsyncMock(side_effect=APIConnectionError(request=mocker.Mock())),
    )
 
 
@@ -212,6 +197,7 @@ async def test_infer_minimal_request(
 
     assert isinstance(response, RlsapiV1InferResponse)
     assert response.data.text == "This is a test LLM response."
+    assert response.data.request_id is not None
    assert check_suid(response.data.request_id)
 
 
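Illustrative sketch, not part of the patch above: a direct unit test of retrieve_simple_response() through the new chat-completions path, reusing the _setup_chat_completions_mock helper and the mock_configuration fixture from the test module in this diff. The @pytest.mark.asyncio marker, the test name, the question string, and the assumption that mock_configuration supplies a usable default model are hypothetical additions, not the author's code.

# Sketch only; assumes it lives in tests/unit/app/endpoints/test_rlsapi_v1.py,
# where _setup_chat_completions_mock and the mock_configuration fixture are defined.
import pytest
from pytest_mock import MockerFixture

from app.endpoints.rlsapi_v1 import retrieve_simple_response


@pytest.mark.asyncio  # assumption: async tests run via pytest-asyncio in this repo
async def test_retrieve_simple_response_returns_content(
    mocker: MockerFixture, mock_configuration: None
) -> None:
    """The chat-completions path returns the mocked message content verbatim."""
    mock_message = mocker.Mock()
    mock_message.content = "canned answer"

    mock_choice = mocker.Mock()
    mock_choice.message = mock_message

    mock_response = mocker.Mock()
    mock_response.choices = [mock_choice]

    # The helper patches AsyncLlamaStackClientHolder so the endpoint module sees
    # a client whose chat.completions.create() resolves to mock_response.
    _setup_chat_completions_mock(
        mocker, mocker.AsyncMock(return_value=mock_response)
    )

    result = await retrieve_simple_response("hypothetical question text")
    assert result == "canned answer"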