src/app/endpoints/rlsapi_v1.py (44 changes: 26 additions & 18 deletions)
@@ -5,12 +5,14 @@
 """
 
 import logging
-from typing import Annotated, Any, cast
+from typing import Annotated, Any
 
 from fastapi import APIRouter, Depends, HTTPException
 from llama_stack_client import APIConnectionError  # type: ignore
-from llama_stack_client.types import UserMessage  # type: ignore
-from llama_stack_client.types.alpha.agents.turn import Turn
+from llama_stack_client.types.chat.completion_create_params import (
+    MessageOpenAISystemMessageParam,
+    MessageOpenAIUserMessageParam,
+)
 
 import constants
 from authentication import get_auth_dependency
@@ -27,9 +29,7 @@
 )
 from models.rlsapi.requests import RlsapiV1InferRequest
 from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
-from utils.endpoints import get_temp_agent
 from utils.suid import get_suid
-from utils.types import content_to_str
 
 logger = logging.getLogger(__name__)
 router = APIRouter(tags=["rlsapi-v1"])
@@ -82,8 +82,7 @@ def _get_default_model_id() -> str:
 async def retrieve_simple_response(question: str) -> str:
     """Retrieve a simple response from the LLM for a stateless query.
 
-    Creates a temporary agent, sends a single turn with the user's question,
-    and returns the LLM response text. No conversation persistence or tools.
+    Uses the chat completions API directly for simple stateless inference.
 
     Args:
         question: The combined user input (question + context).
@@ -100,24 +99,33 @@ async def retrieve_simple_response(question: str) -> str:
 
     logger.debug("Using model %s for rlsapi v1 inference", model_id)
 
-    agent, session_id, _ = await get_temp_agent(
-        client, model_id, constants.DEFAULT_SYSTEM_PROMPT
-    )
+    sys_msg: MessageOpenAISystemMessageParam = {
+        "role": "system",
+        "content": constants.DEFAULT_SYSTEM_PROMPT,
+    }
+    user_msg: MessageOpenAIUserMessageParam = {
+        "role": "user",
+        "content": question,
+    }
+
+    response = await client.chat.completions.create(
+        model=model_id,
+        messages=[sys_msg, user_msg],
+    )
 
-    response = await agent.create_turn(
-        messages=[UserMessage(role="user", content=question).model_dump()],
-        session_id=session_id,
-        stream=False,
-    )
-    response = cast(Turn, response)
+    if not response.choices:
+        return ""
 
-    if getattr(response, "output_message", None) is None:
+    choice = response.choices[0]
+    message = getattr(choice, "message", None)
+    if message is None:
         return ""
 
-    if getattr(response.output_message, "content", None) is None:
+    content = getattr(message, "content", None)
+    if content is None:
         return ""
 
-    return content_to_str(response.output_message.content)
+    return str(content)
 
 
 @router.post("/infer", responses=infer_responses)
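For readers skimming the interleaved hunks above, here is a minimal, self-contained sketch of the defensive response unpacking the new code introduces. The dataclasses are hypothetical stand-ins for the llama-stack-client completion types, not the real classes; only the extraction logic mirrors the diff.

# Illustrative sketch: the Fake* dataclasses are stand-ins, not llama-stack-client types.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeMessage:
    content: Optional[str] = None


@dataclass
class FakeChoice:
    message: Optional[FakeMessage] = None


@dataclass
class FakeCompletion:
    choices: list = field(default_factory=list)


def extract_text(response: FakeCompletion) -> str:
    """Return the first choice's message content, or "" when any layer is missing."""
    if not response.choices:
        return ""
    message = getattr(response.choices[0], "message", None)
    if message is None:
        return ""
    content = getattr(message, "content", None)
    if content is None:
        return ""
    return str(content)


print(extract_text(FakeCompletion(choices=[FakeChoice(FakeMessage("hello"))])))  # -> hello
print(extract_text(FakeCompletion()))  # -> ""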
tests/unit/app/endpoints/test_rlsapi_v1.py (112 changes: 49 additions & 63 deletions)
@@ -3,9 +3,15 @@
 # pylint: disable=protected-access
 # pylint: disable=unused-argument
 
+from typing import Any
+
 import pytest
 from fastapi import HTTPException, status
 from llama_stack_client import APIConnectionError
+from llama_stack_client.types.chat.completion_create_response import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChoice,
+)
 from pydantic import ValidationError
 from pytest_mock import MockerFixture
 
@@ -25,13 +31,36 @@
     RlsapiV1Terminal,
 )
 from models.rlsapi.responses import RlsapiV1InferResponse
-from tests.unit.conftest import AgentFixtures
 from tests.unit.utils.auth_helpers import mock_authorization_resolvers
 from utils.suid import check_suid
 
 MOCK_AUTH: AuthTuple = ("test_user_id", "test_user", True, "test_token")
 
 
+def _setup_chat_completions_mock(mocker: MockerFixture, create_behavior: Any) -> None:
+    """Set up chat.completions mock with custom create() behavior.
+
+    Args:
+        mocker: The pytest-mock fixture for creating mocks.
+        create_behavior: The AsyncMock or side_effect for completions.create().
+    """
+    mock_completions = mocker.Mock()
+    mock_completions.create = create_behavior
+
+    mock_chat = mocker.Mock()
+    mock_chat.completions = mock_completions
+
+    mock_client = mocker.Mock()
+    mock_client.chat = mock_chat
+
+    mock_client_holder = mocker.Mock()
+    mock_client_holder.get_client.return_value = mock_client
+    mocker.patch(
+        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
+        return_value=mock_client_holder,
+    )
+
+
 @pytest.fixture(name="mock_configuration")
 def mock_configuration_fixture(
     mocker: MockerFixture, minimal_config: AppConfig
@@ -44,79 +73,35 @@ def mock_configuration_fixture(
 
 
 @pytest.fixture(name="mock_llm_response")
-def mock_llm_response_fixture(
-    mocker: MockerFixture, prepare_agent_mocks: AgentFixtures
-) -> None:
-    """Mock the LLM integration for successful responses."""
-    mock_client, mock_agent = prepare_agent_mocks
-
-    # Create mock output message with content
-    mock_output_message = mocker.Mock()
-    mock_output_message.content = "This is a test LLM response."
+def mock_llm_response_fixture(mocker: MockerFixture) -> None:
+    """Mock the LLM integration for successful responses via chat.completions."""
+    mock_message = mocker.Mock()
+    mock_message.content = "This is a test LLM response."
 
-    # Create mock turn response
-    mock_turn = mocker.Mock()
-    mock_turn.output_message = mock_output_message
-    mock_turn.steps = []
+    mock_choice = mocker.Mock(spec=OpenAIChatCompletionChoice)
+    mock_choice.message = mock_message
 
-    # Use AsyncMock for async method
-    mock_agent.create_turn = mocker.AsyncMock(return_value=mock_turn)
+    mock_response = mocker.Mock(spec=OpenAIChatCompletion)
+    mock_response.choices = [mock_choice]
 
-    # Mock get_temp_agent to return our mock agent
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.get_temp_agent",
-        return_value=(mock_agent, "test_session_id", None),
-    )
-
-    # Mock the client holder
-    mock_client_holder = mocker.Mock()
-    mock_client_holder.get_client.return_value = mock_client
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
-        return_value=mock_client_holder,
-    )
+    _setup_chat_completions_mock(mocker, mocker.AsyncMock(return_value=mock_response))
 
 
 @pytest.fixture(name="mock_empty_llm_response")
-def mock_empty_llm_response_fixture(
-    mocker: MockerFixture, prepare_agent_mocks: AgentFixtures
-) -> None:
-    """Mock the LLM integration for empty responses (output_message=None)."""
-    mock_client, mock_agent = prepare_agent_mocks
-
-    # Create mock turn response with no output
-    mock_turn = mocker.Mock()
-    mock_turn.output_message = None
-    mock_turn.steps = []
-
-    # Use AsyncMock for async method
-    mock_agent.create_turn = mocker.AsyncMock(return_value=mock_turn)
-
-    # Mock get_temp_agent to return our mock agent
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.get_temp_agent",
-        return_value=(mock_agent, "test_session_id", None),
-    )
-
-    # Mock the client holder
-    mock_client_holder = mocker.Mock()
-    mock_client_holder.get_client.return_value = mock_client
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
-        return_value=mock_client_holder,
-    )
+def mock_empty_llm_response_fixture(mocker: MockerFixture) -> None:
+    """Mock chat.completions to return empty choices list."""
+    mock_response = mocker.Mock(spec=OpenAIChatCompletion)
+    mock_response.choices = []
+
+    _setup_chat_completions_mock(mocker, mocker.AsyncMock(return_value=mock_response))
 
 
 @pytest.fixture(name="mock_api_connection_error")
 def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
-    """Mock AsyncLlamaStackClientHolder to raise APIConnectionError."""
-    mock_client_holder = mocker.Mock()
-    mock_client_holder.get_client.side_effect = APIConnectionError(
-        request=mocker.Mock()
-    )
-    mocker.patch(
-        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
-        return_value=mock_client_holder,
+    """Mock chat.completions.create() to raise APIConnectionError."""
+    _setup_chat_completions_mock(
+        mocker,
+        mocker.AsyncMock(side_effect=APIConnectionError(request=mocker.Mock())),
     )
 
 
@@ -212,6 +197,7 @@ async def test_infer_minimal_request(
 
     assert isinstance(response, RlsapiV1InferResponse)
    assert response.data.text == "This is a test LLM response."
+    assert response.data.request_id is not None
     assert check_suid(response.data.request_id)
 
 
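As a usage illustration (not part of this PR), the new helper and fixtures could drive a direct test of the rewritten function roughly as sketched below. The test name is invented, the asyncio marker assumes the suite runs async tests via pytest-asyncio, the module's existing imports (pytest, MockerFixture) are assumed, and _get_default_model_id is patched on the assumption that it is what resolves the model inside retrieve_simple_response.

# Illustrative sketch only; names and markers are assumptions, not repo facts.
@pytest.mark.asyncio
async def test_retrieve_simple_response_sketch(
    mocker: MockerFixture, mock_configuration: None, mock_llm_response: None
) -> None:
    """Call the rewritten helper directly against the mocked chat.completions client."""
    from app.endpoints.rlsapi_v1 import retrieve_simple_response

    # Pin the model id so the sketch does not depend on the minimal config contents
    # (assumption: _get_default_model_id is otherwise resolved from configuration).
    mocker.patch(
        "app.endpoints.rlsapi_v1._get_default_model_id", return_value="test-model"
    )

    result = await retrieve_simple_response("What is the answer?")
    assert result == "This is a test LLM response."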