From 83b4fac390345f1542eb820d608377cc07f59cdd Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Tue, 5 Aug 2025 14:53:16 +0300 Subject: [PATCH] feat: Convert lightspeed-core to async architecture - Migrate all sync endpoints to async handlers (query, conversations, models, etc.) - Replace sync LlamaStackClient with AsyncLlamaStackClient throughout - Implement proper agent initialization with await agent.initialize() - Centralize test fixtures in conftest.py and fix agent mocking - Remove legacy sync client infrastructure This resolves blocking behavior in all endpoints except streaming_query which was already async, enabling proper concurrent request handling. Signed-off-by: Eran Cohen --- scripts/generate_openapi_schema.py | 6 +- src/app/endpoints/authorized.py | 9 +- src/app/endpoints/conversations.py | 16 +- src/app/endpoints/models.py | 8 +- src/app/endpoints/query.py | 85 +-- src/app/endpoints/streaming_query.py | 53 +- src/client.py | 39 +- src/lightspeed_stack.py | 4 +- src/metrics/utils.py | 8 +- src/models/config.py | 2 +- src/utils/common.py | 42 +- src/utils/endpoints.py | 68 ++- tests/unit/app/endpoints/test_authorized.py | 79 +-- .../unit/app/endpoints/test_conversations.py | 137 +++-- tests/unit/app/endpoints/test_models.py | 45 +- tests/unit/app/endpoints/test_query.py | 544 +++--------------- .../app/endpoints/test_streaming_query.py | 444 +------------- tests/unit/conftest.py | 28 + tests/unit/metrics/test_utis.py | 9 +- tests/unit/test_client.py | 69 +-- tests/unit/utils/test_common.py | 18 +- tests/unit/utils/test_endpoints.py | 448 +++++++++++++++ 22 files changed, 843 insertions(+), 1318 deletions(-) create mode 100644 tests/unit/conftest.py diff --git a/scripts/generate_openapi_schema.py b/scripts/generate_openapi_schema.py index 11681ceb7..a616c7b22 100644 --- a/scripts/generate_openapi_schema.py +++ b/scripts/generate_openapi_schema.py @@ -10,13 +10,15 @@ # it is needed to read proper configuration in order to start the app to generate schema from configuration import configuration -from client import LlamaStackClientHolder +from client import AsyncLlamaStackClientHolder cfg_file = "lightspeed-stack.yaml" configuration.load_configuration(cfg_file) # Llama Stack client needs to be loaded before REST API is fully initialized -LlamaStackClientHolder().load(configuration.configuration.llama_stack) +import asyncio # noqa: E402 + +asyncio.run(AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack)) from app.main import app # noqa: E402 pylint: disable=C0413 diff --git a/src/app/endpoints/authorized.py b/src/app/endpoints/authorized.py index c434ed2f5..ce7bf20de 100644 --- a/src/app/endpoints/authorized.py +++ b/src/app/endpoints/authorized.py @@ -1,10 +1,9 @@ """Handler for REST API call to authorized endpoint.""" -import asyncio import logging from typing import Any -from fastapi import APIRouter, Request +from fastapi import APIRouter, Depends from auth import get_auth_dependency from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse @@ -31,8 +30,10 @@ @router.post("/authorized", responses=authorized_responses) -def authorized_endpoint_handler(_request: Request) -> AuthorizedResponse: +async def authorized_endpoint_handler( + auth: Any = Depends(auth_dependency), +) -> AuthorizedResponse: """Handle request to the /authorized endpoint.""" # Ignore the user token, we should not return it in the response - user_id, user_name, _ = asyncio.run(auth_dependency(_request)) + user_id, user_name, _ = auth return AuthorizedResponse(user_id=user_id, username=user_name) diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index 6032d01d6..59df19e17 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, HTTPException, status, Depends -from client import LlamaStackClientHolder +from client import AsyncLlamaStackClientHolder from configuration import configuration from models.responses import ConversationResponse, ConversationDeleteResponse from auth import get_auth_dependency @@ -110,7 +110,7 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]: @router.get("/conversations/{conversation_id}", responses=conversation_responses) -def get_conversation_endpoint_handler( +async def get_conversation_endpoint_handler( conversation_id: str, _auth: Any = Depends(auth_dependency), ) -> ConversationResponse: @@ -132,9 +132,9 @@ def get_conversation_endpoint_handler( logger.info("Retrieving conversation %s", conversation_id) try: - client = LlamaStackClientHolder().get_client() + client = AsyncLlamaStackClientHolder().get_client() - session_data = client.agents.session.list(agent_id=agent_id).data[0] + session_data = (await client.agents.session.list(agent_id=agent_id)).data[0] logger.info("Successfully retrieved conversation %s", conversation_id) @@ -179,7 +179,7 @@ def get_conversation_endpoint_handler( @router.delete( "/conversations/{conversation_id}", responses=conversation_delete_responses ) -def delete_conversation_endpoint_handler( +async def delete_conversation_endpoint_handler( conversation_id: str, _auth: Any = Depends(auth_dependency), ) -> ConversationDeleteResponse: @@ -201,10 +201,12 @@ def delete_conversation_endpoint_handler( try: # Get Llama Stack client - client = LlamaStackClientHolder().get_client() + client = AsyncLlamaStackClientHolder().get_client() # Delete session using the conversation_id as session_id # In this implementation, conversation_id and session_id are the same - client.agents.session.delete(agent_id=agent_id, session_id=conversation_id) + await client.agents.session.delete( + agent_id=agent_id, session_id=conversation_id + ) logger.info("Successfully deleted conversation %s", conversation_id) diff --git a/src/app/endpoints/models.py b/src/app/endpoints/models.py index eb9840970..d583f3907 100644 --- a/src/app/endpoints/models.py +++ b/src/app/endpoints/models.py @@ -6,7 +6,7 @@ from llama_stack_client import APIConnectionError from fastapi import APIRouter, HTTPException, Request, status -from client import LlamaStackClientHolder +from client import AsyncLlamaStackClientHolder from configuration import configuration from models.responses import ModelsResponse from utils.endpoints import check_configuration_loaded @@ -43,7 +43,7 @@ @router.get("/models", responses=models_responses) -def models_endpoint_handler(_request: Request) -> ModelsResponse: +async def models_endpoint_handler(_request: Request) -> ModelsResponse: """Handle requests to the /models endpoint.""" check_configuration_loaded(configuration) @@ -52,9 +52,9 @@ def models_endpoint_handler(_request: Request) -> ModelsResponse: try: # try to get Llama Stack client - client = LlamaStackClientHolder().get_client() + client = AsyncLlamaStackClientHolder().get_client() # retrieve models - models = client.models.list() + models = await client.models.list() m = [dict(m) for m in models] return ModelsResponse(models=m) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 3f7ef4701..44127a7a0 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -1,6 +1,5 @@ """Handler for REST API call to provide answer to query.""" -from contextlib import suppress from datetime import datetime, UTC import json import logging @@ -8,9 +7,8 @@ from pathlib import Path from typing import Annotated, Any -from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import APIConnectionError -from llama_stack_client import LlamaStackClient # type: ignore +from llama_stack_client import AsyncLlamaStackClient # type: ignore from llama_stack_client.types import UserMessage, Shield # type: ignore from llama_stack_client.types.agents.turn_create_params import ( ToolgroupAgentToolGroupWithArgs, @@ -20,18 +18,17 @@ from fastapi import APIRouter, HTTPException, status, Depends -from client import LlamaStackClientHolder +from auth import get_auth_dependency +from auth.interface import AuthTuple +from client import AsyncLlamaStackClientHolder from configuration import configuration import metrics from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from models.requests import QueryRequest, Attachment import constants -from auth import get_auth_dependency -from auth.interface import AuthTuple -from utils.endpoints import check_configuration_loaded, get_system_prompt +from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.suid import get_suid -from utils.types import GraniteToolParser logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["query"]) @@ -68,53 +65,8 @@ def is_transcripts_enabled() -> bool: return configuration.user_data_collection_configuration.transcripts_enabled -def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments - client: LlamaStackClient, - model_id: str, - system_prompt: str, - available_input_shields: list[str], - available_output_shields: list[str], - conversation_id: str | None, - no_tools: bool = False, -) -> tuple[Agent, str, str]: - """Get existing agent or create a new one with session persistence.""" - existing_agent_id = None - if conversation_id: - with suppress(ValueError): - existing_agent_id = client.agents.retrieve( - agent_id=conversation_id - ).agent_id - - logger.debug("Creating new agent") - # TODO(lucasagomes): move to ReActAgent - agent = Agent( - client, - model=model_id, - instructions=system_prompt, - input_shields=available_input_shields if available_input_shields else [], - output_shields=available_output_shields if available_output_shields else [], - tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), - enable_session_persistence=True, - ) - - agent.initialize() - - if existing_agent_id and conversation_id: - orphan_agent_id = agent.agent_id - agent.agent_id = conversation_id - client.agents.delete(agent_id=orphan_agent_id) - sessions_response = client.agents.session.list(agent_id=conversation_id) - logger.info("session response: %s", sessions_response) - session_id = str(sessions_response.data[0]["session_id"]) - else: - conversation_id = agent.agent_id - session_id = agent.create_session(get_suid()) - - return agent, conversation_id, session_id - - @router.post("/query", responses=query_response) -def query_endpoint_handler( +async def query_endpoint_handler( query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), @@ -129,11 +81,11 @@ def query_endpoint_handler( try: # try to get Llama Stack client - client = LlamaStackClientHolder().get_client() + client = AsyncLlamaStackClientHolder().get_client() model_id, provider_id = select_model_and_provider_id( - client.models.list(), query_request + await client.models.list(), query_request ) - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( client, model_id, query_request, @@ -253,8 +205,8 @@ def is_input_shield(shield: Shield) -> bool: return _is_inout_shield(shield) or not is_output_shield(shield) -def retrieve_response( # pylint: disable=too-many-locals - client: LlamaStackClient, +async def retrieve_response( # pylint: disable=too-many-locals + client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest, token: str, @@ -262,10 +214,12 @@ def retrieve_response( # pylint: disable=too-many-locals ) -> tuple[str, str]: """Retrieve response from LLMs and agents.""" available_input_shields = [ - shield.identifier for shield in filter(is_input_shield, client.shields.list()) + shield.identifier + for shield in filter(is_input_shield, await client.shields.list()) ] available_output_shields = [ - shield.identifier for shield in filter(is_output_shield, client.shields.list()) + shield.identifier + for shield in filter(is_output_shield, await client.shields.list()) ] if not available_input_shields and not available_output_shields: logger.info("No available shields. Disabling safety") @@ -284,7 +238,7 @@ def retrieve_response( # pylint: disable=too-many-locals if query_request.attachments: validate_attachments_metadata(query_request.attachments) - agent, conversation_id, session_id = get_agent( + agent, conversation_id, session_id = await get_agent( client, model_id, system_prompt, @@ -294,6 +248,7 @@ def retrieve_response( # pylint: disable=too-many-locals query_request.no_tools or False, ) + logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id) # bypass tools and MCP servers if no_tools is True if query_request.no_tools: mcp_headers = {} @@ -318,7 +273,9 @@ def retrieve_response( # pylint: disable=too-many-locals ), } - vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] + vector_db_ids = [ + vector_db.identifier for vector_db in await client.vector_dbs.list() + ] toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ mcp_server.name for mcp_server in configuration.mcp_servers ] @@ -326,7 +283,7 @@ def retrieve_response( # pylint: disable=too-many-locals if not toolgroups: toolgroups = None - response = agent.create_turn( + response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], session_id=session_id, documents=query_request.get_documents(), diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 1da0f65df..009c2f017 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1,14 +1,12 @@ """Handler for REST API call to provide answer to streaming query.""" import ast -from contextlib import suppress import json import re import logging from typing import Annotated, Any, AsyncIterator, Iterator from llama_stack_client import APIConnectionError -from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore from llama_stack_client import AsyncLlamaStackClient # type: ignore from llama_stack_client.types import UserMessage # type: ignore @@ -25,10 +23,8 @@ from configuration import configuration import metrics from models.requests import QueryRequest -from utils.endpoints import check_configuration_loaded, get_system_prompt +from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups -from utils.suid import get_suid -from utils.types import GraniteToolParser from app.endpoints.query import ( get_rag_toolgroups, @@ -45,50 +41,6 @@ auth_dependency = get_auth_dependency() -# # pylint: disable=R0913,R0917 -async def get_agent( - client: AsyncLlamaStackClient, - model_id: str, - system_prompt: str, - available_input_shields: list[str], - available_output_shields: list[str], - conversation_id: str | None, - no_tools: bool = False, -) -> tuple[AsyncAgent, str, str]: - """Get existing agent or create a new one with session persistence.""" - existing_agent_id = None - if conversation_id: - with suppress(ValueError): - agent_response = await client.agents.retrieve(agent_id=conversation_id) - existing_agent_id = agent_response.agent_id - - logger.debug("Creating new agent") - agent = AsyncAgent( - client, # type: ignore[arg-type] - model=model_id, - instructions=system_prompt, - input_shields=available_input_shields if available_input_shields else [], - output_shields=available_output_shields if available_output_shields else [], - tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), - enable_session_persistence=True, - ) - - await agent.initialize() - - if existing_agent_id and conversation_id: - orphan_agent_id = agent.agent_id - agent._agent_id = conversation_id # type: ignore[assignment] # pylint: disable=protected-access - await client.agents.delete(agent_id=orphan_agent_id) - sessions_response = await client.agents.session.list(agent_id=conversation_id) - logger.info("session response: %s", sessions_response) - session_id = str(sessions_response.data[0]["session_id"]) - else: - conversation_id = agent.agent_id - session_id = await agent.create_session(get_suid()) - - return agent, conversation_id, session_id - - METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n") @@ -556,8 +508,7 @@ async def retrieve_response( query_request.no_tools or False, ) - logger.debug("Session ID: %s", conversation_id) - + logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id) # bypass tools and MCP servers if no_tools is True if query_request.no_tools: mcp_headers = {} diff --git a/src/client.py b/src/client.py index 04ddef416..2dcbfdc12 100644 --- a/src/client.py +++ b/src/client.py @@ -6,9 +6,8 @@ from llama_stack import ( AsyncLlamaStackAsLibraryClient, # type: ignore - LlamaStackAsLibraryClient, # type: ignore ) -from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient # type: ignore +from llama_stack_client import AsyncLlamaStackClient # type: ignore from models.config import LlamaStackConfiguration from utils.types import Singleton @@ -16,42 +15,6 @@ logger = logging.getLogger(__name__) -class LlamaStackClientHolder(metaclass=Singleton): - """Container for an initialised LlamaStackClient.""" - - _lsc: Optional[LlamaStackClient] = None - - def load(self, llama_stack_config: LlamaStackConfiguration) -> None: - """Retrieve Llama stack client according to configuration.""" - if llama_stack_config.use_as_library_client is True: - if llama_stack_config.library_client_config_path is not None: - logger.info("Using Llama stack as library client") - client = LlamaStackAsLibraryClient( - llama_stack_config.library_client_config_path - ) - client.initialize() - self._lsc = client - else: - msg = "Configuration problem: library_client_config_path option is not set" - logger.error(msg) - # tisnik: use custom exception there - with cause etc. - raise ValueError(msg) - - else: - logger.info("Using Llama stack running as a service") - self._lsc = LlamaStackClient( - base_url=llama_stack_config.url, api_key=llama_stack_config.api_key - ) - - def get_client(self) -> LlamaStackClient: - """Return an initialised LlamaStackClient.""" - if not self._lsc: - raise RuntimeError( - "LlamaStackClient has not been initialised. Ensure 'load(..)' has been called." - ) - return self._lsc - - class AsyncLlamaStackClientHolder(metaclass=Singleton): """Container for an initialised AsyncLlamaStackClient.""" diff --git a/src/lightspeed_stack.py b/src/lightspeed_stack.py index 4479e308c..bbe004d62 100644 --- a/src/lightspeed_stack.py +++ b/src/lightspeed_stack.py @@ -12,7 +12,7 @@ from runners.uvicorn import start_uvicorn from runners.data_collector import start_data_collector from configuration import configuration -from client import LlamaStackClientHolder, AsyncLlamaStackClientHolder +from client import AsyncLlamaStackClientHolder FORMAT = "%(message)s" logging.basicConfig( @@ -69,8 +69,6 @@ def main() -> None: logger.info( "Llama stack configuration: %s", configuration.llama_stack_configuration ) - logger.info("Creating LlamaStackClient") - LlamaStackClientHolder().load(configuration.configuration.llama_stack) logger.info("Creating AsyncLlamaStackClient") asyncio.run( AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack) diff --git a/src/metrics/utils.py b/src/metrics/utils.py index cece371c1..aceddd829 100644 --- a/src/metrics/utils.py +++ b/src/metrics/utils.py @@ -1,7 +1,7 @@ """Utility functions for metrics handling.""" from configuration import configuration -from client import LlamaStackClientHolder, AsyncLlamaStackClientHolder +from client import AsyncLlamaStackClientHolder from log import get_logger import metrics from utils.common import run_once_async @@ -13,11 +13,7 @@ async def setup_model_metrics() -> None: """Perform setup of all metrics related to LLM model and provider.""" logger.info("Setting up model metrics") - model_list = [] - if configuration.llama_stack_configuration.use_as_library_client: - model_list = await AsyncLlamaStackClientHolder().get_client().models.list() - else: - model_list = LlamaStackClientHolder().get_client().models.list() + model_list = await AsyncLlamaStackClientHolder().get_client().models.list() models = [ model diff --git a/src/models/config.py b/src/models/config.py index 20af7ad34..a6f3cc593 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -92,7 +92,7 @@ def check_llama_stack_model(self) -> Self: if self.library_client_config_path is None: # pylint: disable=line-too-long raise ValueError( - "LLama stack library client mode is enabled but a configuration file path is not specified" # noqa: C0301 + "LLama stack library client mode is enabled but a configuration file path is not specified" # noqa: E501 ) # the configuration file must exists and be regular readable file checks.file_check( diff --git a/src/utils/common.py b/src/utils/common.py index e4a35a57a..1ae261b68 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -5,12 +5,10 @@ from typing import Any, Callable, List, cast from logging import Logger -from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack import ( - AsyncLlamaStackAsLibraryClient, -) +from llama_stack_client import AsyncLlamaStackClient +from llama_stack import AsyncLlamaStackAsLibraryClient -from client import LlamaStackClientHolder, AsyncLlamaStackClientHolder +from client import AsyncLlamaStackClientHolder from models.config import Configuration, ModelContextProtocolServer @@ -31,9 +29,9 @@ async def register_mcp_servers_async( await client.initialize() await _register_mcp_toolgroups_async(client, configuration.mcp_servers, logger) else: - # Service client - use sync interface - client = LlamaStackClientHolder().get_client() - _register_mcp_toolgroups_sync(client, configuration.mcp_servers, logger) + # Service client - also use async interface + client = AsyncLlamaStackClientHolder().get_client() + await _register_mcp_toolgroups_async(client, configuration.mcp_servers, logger) async def _register_mcp_toolgroups_async( @@ -64,34 +62,6 @@ async def _register_mcp_toolgroups_async( logger.debug("MCP server %s registered successfully", mcp.name) -def _register_mcp_toolgroups_sync( - client: LlamaStackClient, - mcp_servers: List[ModelContextProtocolServer], - logger: Logger, -) -> None: - """Sync logic for registering MCP toolgroups.""" - # Get registered tool groups - registered_toolgroups = client.toolgroups.list() - registered_toolgroups_ids = [ - tool_group.provider_resource_id for tool_group in registered_toolgroups - ] - logger.debug("Registered toolgroups: %s", registered_toolgroups_ids) - - # Register toolgroups for MCP servers if not already registered - for mcp in mcp_servers: - if mcp.name not in registered_toolgroups_ids: - logger.debug("Registering MCP server: %s, %s", mcp.name, mcp.url) - - registration_params = { - "toolgroup_id": mcp.name, - "provider_id": mcp.provider_id, - "mcp_endpoint": {"uri": mcp.url}, - } - - client.toolgroups.register(**registration_params) - logger.debug("MCP server %s registered successfully", mcp.name) - - def run_once_async(func: Callable) -> Callable: """Decorate an async function to run only once.""" task = None diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index c9a8f91c9..19f816de0 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -1,26 +1,35 @@ """Utility functions for endpoint handlers.""" +from contextlib import suppress +import logging from fastapi import HTTPException, status +from llama_stack_client._client import AsyncLlamaStackClient +from llama_stack_client.lib.agents.agent import AsyncAgent import constants from models.requests import QueryRequest from configuration import AppConfig +from utils.suid import get_suid +from utils.types import GraniteToolParser -def check_configuration_loaded(configuration: AppConfig) -> None: +logger = logging.getLogger("utils.endpoints") + + +def check_configuration_loaded(config: AppConfig) -> None: """Check that configuration is loaded and raise exception when it is not.""" - if configuration is None: + if config is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"response": "Configuration is not loaded"}, ) -def get_system_prompt(query_request: QueryRequest, configuration: AppConfig) -> str: +def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str: """Get the system prompt: the provided one, configured one, or default one.""" system_prompt_disabled = ( - configuration.customization is not None - and configuration.customization.disable_query_system_prompt + config.customization is not None + and config.customization.disable_query_system_prompt ) if system_prompt_disabled and query_request.system_prompt: raise HTTPException( @@ -41,10 +50,53 @@ def get_system_prompt(query_request: QueryRequest, configuration: AppConfig) -> return query_request.system_prompt if ( - configuration.customization is not None - and configuration.customization.system_prompt is not None + config.customization is not None + and config.customization.system_prompt is not None ): - return configuration.customization.system_prompt + return config.customization.system_prompt # default system prompt has the lowest precedence return constants.DEFAULT_SYSTEM_PROMPT + + +# # pylint: disable=R0913,R0917 +async def get_agent( + client: AsyncLlamaStackClient, + model_id: str, + system_prompt: str, + available_input_shields: list[str], + available_output_shields: list[str], + conversation_id: str | None, + no_tools: bool = False, +) -> tuple[AsyncAgent, str, str]: + """Get existing agent or create a new one with session persistence.""" + existing_agent_id = None + if conversation_id: + with suppress(ValueError): + agent_response = await client.agents.retrieve(agent_id=conversation_id) + existing_agent_id = agent_response.agent_id + + logger.debug("Creating new agent") + agent = AsyncAgent( + client, # type: ignore[arg-type] + model=model_id, + instructions=system_prompt, + input_shields=available_input_shields if available_input_shields else [], + output_shields=available_output_shields if available_output_shields else [], + tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), + enable_session_persistence=True, + ) + await agent.initialize() + + if existing_agent_id and conversation_id: + orphan_agent_id = agent.agent_id + agent._agent_id = conversation_id # type: ignore[assignment] # pylint: disable=protected-access + await client.agents.delete(agent_id=orphan_agent_id) + sessions_response = await client.agents.session.list(agent_id=conversation_id) + logger.info("session response: %s", sessions_response) + session_id = str(sessions_response.data[0]["session_id"]) + else: + conversation_id = agent.agent_id + session_id = await agent.create_session(get_suid()) + + return agent, conversation_id, session_id diff --git a/tests/unit/app/endpoints/test_authorized.py b/tests/unit/app/endpoints/test_authorized.py index 0fb22afed..a1d4144d9 100644 --- a/tests/unit/app/endpoints/test_authorized.py +++ b/tests/unit/app/endpoints/test_authorized.py @@ -1,30 +1,19 @@ """Unit tests for the /authorized REST API endpoint.""" -from unittest.mock import AsyncMock - import pytest -from fastapi import Request, HTTPException +from fastapi import HTTPException +from starlette.datastructures import Headers from app.endpoints.authorized import authorized_endpoint_handler +from auth.utils import extract_user_token + +MOCK_AUTH = ("test-id", "test-user", "token") -def test_authorized_endpoint(mocker): +@pytest.mark.asyncio +async def test_authorized_endpoint(): """Test the authorized endpoint handler.""" - # Mock the auth dependency to return a user ID and username - auth_dependency_mock = AsyncMock() - auth_dependency_mock.return_value = ("test-id", "test-user", None) - mocker.patch( - "app.endpoints.authorized.auth_dependency", side_effect=auth_dependency_mock - ) - - request = Request( - scope={ - "type": "http", - "query_string": b"", - } - ) - - response = authorized_endpoint_handler(request) + response = await authorized_endpoint_handler(auth=MOCK_AUTH) assert response.model_dump() == { "user_id": "test-id", @@ -32,25 +21,41 @@ def test_authorized_endpoint(mocker): } -def test_authorized_unauthorized(mocker): - """Test the authorized endpoint handler with a custom user ID.""" - auth_dependency_mock = AsyncMock() - auth_dependency_mock.side_effect = HTTPException( - status_code=403, detail="User is not authorized" - ) - mocker.patch( - "app.endpoints.authorized.auth_dependency", side_effect=auth_dependency_mock - ) +@pytest.mark.asyncio +async def test_authorized_unauthorized(): + """Test the authorized endpoint handler behavior under unauthorized conditions. + + Note: In real scenarios, FastAPI's dependency injection would prevent the handler + from being called if auth fails. This test simulates what would happen if somehow + invalid auth data reached the handler. + """ + # Test scenario 1: None auth data (complete auth failure) + with pytest.raises(TypeError): + # This would occur if auth dependency somehow returned None + await authorized_endpoint_handler(auth=None) - request = Request( - scope={ - "type": "http", - "query_string": b"", - } - ) + # Test scenario 2: Invalid auth tuple structure + with pytest.raises(ValueError): + # This would occur if auth dependency returned malformed data + await authorized_endpoint_handler(auth=("incomplete-auth-data",)) + +@pytest.mark.asyncio +async def test_authorized_dependency_unauthorized(): + """Test that auth dependency raises HTTPException with 403 for unauthorized access.""" + # Test the auth utility function that would be called by auth dependencies + # This simulates the unauthorized scenario that would prevent the handler from being called + + # Test case 1: No Authorization header (400 error from extract_user_token) + headers_no_auth = Headers({}) with pytest.raises(HTTPException) as exc_info: - authorized_endpoint_handler(request) + extract_user_token(headers_no_auth) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "No Authorization header found" - assert exc_info.value.status_code == 403 - assert exc_info.value.detail == "User is not authorized" + # Test case 2: Invalid Authorization header format (400 error from extract_user_token) + headers_invalid_auth = Headers({"Authorization": "InvalidFormat"}) + with pytest.raises(HTTPException) as exc_info: + extract_user_token(headers_invalid_auth) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "No token found in Authorization header" diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index d1df17a09..8b8954475 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -114,7 +114,8 @@ def expected_chat_history_fixture(): class TestSimplifySessionData: """Test cases for the simplify_session_data function.""" - def test_simplify_session_data_with_model_dump( + @pytest.mark.asyncio + async def test_simplify_session_data_with_model_dump( self, mock_session_data, expected_chat_history ): """Test simplify_session_data with session data.""" @@ -122,7 +123,8 @@ def test_simplify_session_data_with_model_dump( assert result == expected_chat_history - def test_simplify_session_data_empty_turns(self): + @pytest.mark.asyncio + async def test_simplify_session_data_empty_turns(self): """Test simplify_session_data with empty turns.""" session_data = { "session_id": VALID_CONVERSATION_ID, @@ -134,7 +136,8 @@ def test_simplify_session_data_empty_turns(self): assert not result - def test_simplify_session_data_filters_unwanted_fields(self): + @pytest.mark.asyncio + async def test_simplify_session_data_filters_unwanted_fields(self): """Test that simplify_session_data properly filters out unwanted fields.""" session_data = { "session_id": VALID_CONVERSATION_ID, @@ -181,88 +184,103 @@ def test_simplify_session_data_filters_unwanted_fields(self): class TestGetConversationEndpoint: """Test cases for the GET /conversations/{conversation_id} endpoint.""" - def test_configuration_not_loaded(self, mocker): + @pytest.mark.asyncio + async def test_configuration_not_loaded(self, mocker): """Test the endpoint when configuration is not loaded.""" mocker.patch("app.endpoints.conversations.configuration", None) with pytest.raises(HTTPException) as exc_info: - get_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) + await get_conversation_endpoint_handler( + VALID_CONVERSATION_ID, _auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Configuration is not loaded" in exc_info.value.detail["response"] - def test_invalid_conversation_id_format(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_invalid_conversation_id_format(self, mocker, setup_configuration): """Test the endpoint with an invalid conversation ID format.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=False) with pytest.raises(HTTPException) as exc_info: - get_conversation_endpoint_handler(INVALID_CONVERSATION_ID, _auth=MOCK_AUTH) + await get_conversation_endpoint_handler( + INVALID_CONVERSATION_ID, _auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST assert "Invalid conversation ID format" in exc_info.value.detail["response"] assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_llama_stack_connection_error(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_llama_stack_connection_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack connection fails.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Mock LlamaStackClientHolder to raise APIConnectionError - mock_client = mocker.Mock() + # Mock AsyncLlamaStackClientHolder to raise APIConnectionError + mock_client = mocker.AsyncMock() mock_client.agents.session.list.side_effect = APIConnectionError(request=None) mock_client_holder = mocker.patch( - "app.endpoints.conversations.LlamaStackClientHolder" + "app.endpoints.conversations.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client # simulate situation when it is not possible to connect to Llama Stack with pytest.raises(HTTPException) as exc_info: - get_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) + await get_conversation_endpoint_handler( + VALID_CONVERSATION_ID, _auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] - def test_llama_stack_not_found_error(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_llama_stack_not_found_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack returns NotFoundError.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Mock LlamaStackClientHolder to raise NotFoundError - mock_client = mocker.Mock() + # Mock AsyncLlamaStackClientHolder to raise NotFoundError + mock_client = mocker.AsyncMock() mock_client.agents.session.list.side_effect = NotFoundError( message="Session not found", response=mocker.Mock(request=None), body=None ) mock_client_holder = mocker.patch( - "app.endpoints.conversations.LlamaStackClientHolder" + "app.endpoints.conversations.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client with pytest.raises(HTTPException) as exc_info: - get_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) + await get_conversation_endpoint_handler( + VALID_CONVERSATION_ID, _auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND assert "Conversation not found" in exc_info.value.detail["response"] assert "could not be retrieved" in exc_info.value.detail["cause"] assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_session_retrieve_exception(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_session_retrieve_exception(self, mocker, setup_configuration): """Test the endpoint when session retrieval raises an exception.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Mock LlamaStackClientHolder to raise a general exception + # Mock AsyncLlamaStackClientHolder to raise a general exception mock_client = mocker.Mock() - mock_client.agents.session.retrieve.side_effect = Exception( - "Failed to get session" + mock_client.agents.session.retrieve.side_effect = HTTPException( + status_code=500, detail="Failed to get session" ) mock_client_holder = mocker.patch( - "app.endpoints.conversations.LlamaStackClientHolder" + "app.endpoints.conversations.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client with pytest.raises(HTTPException) as exc_info: - get_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) + await get_conversation_endpoint_handler( + VALID_CONVERSATION_ID, _auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unknown error" in exc_info.value.detail["response"] @@ -270,7 +288,8 @@ def test_session_retrieve_exception(self, mocker, setup_configuration): "Unknown error while getting conversation" in exc_info.value.detail["cause"] ) - def test_successful_conversation_retrieval( + @pytest.mark.asyncio + async def test_successful_conversation_retrieval( self, mocker, setup_configuration, mock_session_data, expected_chat_history ): """Test successful conversation retrieval with simplified response structure.""" @@ -281,17 +300,17 @@ def test_successful_conversation_retrieval( mock_session_obj = mocker.Mock() mock_session_obj.model_dump.return_value = mock_session_data - # Mock LlamaStackClientHolder - mock_client = mocker.Mock() + # Mock AsyncLlamaStackClientHolder + mock_client = mocker.AsyncMock() mock_client.agents.session.list.return_value = mocker.Mock( data=[mock_session_data] ) mock_client_holder = mocker.patch( - "app.endpoints.conversations.LlamaStackClientHolder" + "app.endpoints.conversations.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client - response = get_conversation_endpoint_handler( + response = await get_conversation_endpoint_handler( VALID_CONVERSATION_ID, _auth=MOCK_AUTH ) @@ -306,23 +325,27 @@ def test_successful_conversation_retrieval( class TestDeleteConversationEndpoint: """Test cases for the DELETE /conversations/{conversation_id} endpoint.""" - def test_configuration_not_loaded(self, mocker): + @pytest.mark.asyncio + async def test_configuration_not_loaded(self, mocker): """Test the endpoint when configuration is not loaded.""" mocker.patch("app.endpoints.conversations.configuration", None) with pytest.raises(HTTPException) as exc_info: - delete_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) + await delete_conversation_endpoint_handler( + VALID_CONVERSATION_ID, _auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Configuration is not loaded" in exc_info.value.detail["response"] - def test_invalid_conversation_id_format(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_invalid_conversation_id_format(self, mocker, setup_configuration): """Test the endpoint with an invalid conversation ID format.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=False) with pytest.raises(HTTPException) as exc_info: - delete_conversation_endpoint_handler( + await delete_conversation_endpoint_handler( INVALID_CONVERSATION_ID, _auth=MOCK_AUTH ) @@ -330,65 +353,74 @@ def test_invalid_conversation_id_format(self, mocker, setup_configuration): assert "Invalid conversation ID format" in exc_info.value.detail["response"] assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_llama_stack_connection_error(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_llama_stack_connection_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack connection fails.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Mock LlamaStackClientHolder to raise APIConnectionError - mock_client = mocker.Mock() + # Mock AsyncLlamaStackClientHolder to raise APIConnectionError + mock_client = mocker.AsyncMock() mock_client.agents.session.delete.side_effect = APIConnectionError(request=None) mock_client_holder = mocker.patch( - "app.endpoints.conversations.LlamaStackClientHolder" + "app.endpoints.conversations.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client with pytest.raises(HTTPException) as exc_info: - delete_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) + await delete_conversation_endpoint_handler( + VALID_CONVERSATION_ID, _auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] - def test_llama_stack_not_found_error(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_llama_stack_not_found_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack returns NotFoundError.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Mock LlamaStackClientHolder to raise NotFoundError - mock_client = mocker.Mock() + # Mock AsyncLlamaStackClientHolder to raise NotFoundError + mock_client = mocker.AsyncMock() mock_client.agents.session.delete.side_effect = NotFoundError( message="Session not found", response=mocker.Mock(request=None), body=None ) mock_client_holder = mocker.patch( - "app.endpoints.conversations.LlamaStackClientHolder" + "app.endpoints.conversations.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client with pytest.raises(HTTPException) as exc_info: - delete_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) + await delete_conversation_endpoint_handler( + VALID_CONVERSATION_ID, _auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND assert "Conversation not found" in exc_info.value.detail["response"] assert "could not be deleted" in exc_info.value.detail["cause"] assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_session_deletion_exception(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_session_deletion_exception(self, mocker, setup_configuration): """Test the endpoint when session deletion raises an exception.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Mock LlamaStackClientHolder to raise a general exception - mock_client = mocker.Mock() + # Mock AsyncLlamaStackClientHolder to raise a general exception + mock_client = mocker.AsyncMock() mock_client.agents.session.delete.side_effect = Exception( "Session deletion failed" ) mock_client_holder = mocker.patch( - "app.endpoints.conversations.LlamaStackClientHolder" + "app.endpoints.conversations.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client with pytest.raises(HTTPException) as exc_info: - delete_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) + await delete_conversation_endpoint_handler( + VALID_CONVERSATION_ID, _auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unknown error" in exc_info.value.detail["response"] @@ -397,20 +429,21 @@ def test_session_deletion_exception(self, mocker, setup_configuration): in exc_info.value.detail["cause"] ) - def test_successful_conversation_deletion(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_successful_conversation_deletion(self, mocker, setup_configuration): """Test successful conversation deletion.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Mock LlamaStackClientHolder - mock_client = mocker.Mock() + # Mock AsyncLlamaStackClientHolder + mock_client = mocker.AsyncMock() mock_client.agents.session.delete.return_value = None # Successful deletion mock_client_holder = mocker.patch( - "app.endpoints.conversations.LlamaStackClientHolder" + "app.endpoints.conversations.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client - response = delete_conversation_endpoint_handler( + response = await delete_conversation_endpoint_handler( VALID_CONVERSATION_ID, _auth=MOCK_AUTH ) diff --git a/tests/unit/app/endpoints/test_models.py b/tests/unit/app/endpoints/test_models.py index f11da379f..ca58c4b3e 100644 --- a/tests/unit/app/endpoints/test_models.py +++ b/tests/unit/app/endpoints/test_models.py @@ -10,7 +10,8 @@ from configuration import AppConfig -def test_models_endpoint_handler_configuration_not_loaded(mocker): +@pytest.mark.asyncio +async def test_models_endpoint_handler_configuration_not_loaded(mocker): """Test the models endpoint handler if configuration is not loaded.""" # simulate state when no configuration is loaded mocker.patch( @@ -27,12 +28,13 @@ def test_models_endpoint_handler_configuration_not_loaded(mocker): ) with pytest.raises(HTTPException) as e: - models_endpoint_handler(request) - assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + await models_endpoint_handler(request) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Configuration is not loaded" -def test_models_endpoint_handler_improper_llama_stack_configuration(mocker): +@pytest.mark.asyncio +async def test_models_endpoint_handler_improper_llama_stack_configuration(mocker): """Test the models endpoint handler if Llama Stack configuration is not proper.""" # configuration for tests config_dict = { @@ -71,12 +73,13 @@ def test_models_endpoint_handler_improper_llama_stack_configuration(mocker): } ) with pytest.raises(HTTPException) as e: - models_endpoint_handler(request) - assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + await models_endpoint_handler(request) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "LLama stack is not configured" -def test_models_endpoint_handler_configuration_loaded(): +@pytest.mark.asyncio +async def test_models_endpoint_handler_configuration_loaded(): """Test the models endpoint handler if configuration is loaded.""" # configuration for tests config_dict = { @@ -110,12 +113,13 @@ def test_models_endpoint_handler_configuration_loaded(): ) with pytest.raises(HTTPException) as e: - models_endpoint_handler(request) - assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + await models_endpoint_handler(request) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Unable to connect to Llama Stack" -def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): +@pytest.mark.asyncio +async def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): """Test the models endpoint handler if configuration is loaded.""" # configuration for tests config_dict = { @@ -142,9 +146,9 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): cfg.init_from_dict(config_dict) # Mock the LlamaStack client - mock_client = mocker.Mock() + mock_client = mocker.AsyncMock() mock_client.models.list.return_value = [] - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client mock_config = mocker.Mock() mocker.patch("app.endpoints.models.configuration", mock_config) @@ -155,11 +159,12 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): "headers": [(b"authorization", b"Bearer invalid-token")], } ) - response = models_endpoint_handler(request) + response = await models_endpoint_handler(request) assert response is not None -def test_models_endpoint_llama_stack_connection_error(mocker): +@pytest.mark.asyncio +async def test_models_endpoint_llama_stack_connection_error(mocker): """Test the model endpoint when LlamaStack connection fails.""" # configuration for tests config_dict = { @@ -183,11 +188,13 @@ def test_models_endpoint_llama_stack_connection_error(mocker): "customization": None, } - # mock LlamaStackClientHolder to raise APIConnectionError + # mock AsyncLlamaStackClientHolder to raise APIConnectionError # when models.list() method is called - mock_client = mocker.Mock() + mock_client = mocker.AsyncMock() mock_client.models.list.side_effect = APIConnectionError(request=None) - mock_client_holder = mocker.patch("app.endpoints.models.LlamaStackClientHolder") + mock_client_holder = mocker.patch( + "app.endpoints.models.AsyncLlamaStackClientHolder" + ) mock_client_holder.return_value.get_client.return_value = mock_client cfg = AppConfig() @@ -201,6 +208,6 @@ def test_models_endpoint_llama_stack_connection_error(mocker): ) with pytest.raises(HTTPException) as e: - models_endpoint_handler(request) - assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + await models_endpoint_handler(request) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Unable to connect to Llama Stack" diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 1d8fabdcd..5530f56ca 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -19,7 +19,6 @@ construct_transcripts_path, store_transcript, get_rag_toolgroups, - get_agent, ) from models.requests import QueryRequest, Attachment @@ -57,16 +56,8 @@ def setup_configuration_fixture(): return cfg -@pytest.fixture(autouse=True, name="prepare_agent_mocks") -def prepare_agent_mocks_fixture(mocker): - """Fixture that yields mock agent when called.""" - mock_client = mocker.Mock() - mock_agent = mocker.Mock() - mock_agent.create_turn.return_value.steps = [] - yield mock_client, mock_agent - - -def test_query_endpoint_handler_configuration_not_loaded(mocker): +@pytest.mark.asyncio +async def test_query_endpoint_handler_configuration_not_loaded(mocker): """Test the query endpoint handler if configuration is not loaded.""" # simulate state when no configuration is loaded mocker.patch( @@ -77,9 +68,9 @@ def test_query_endpoint_handler_configuration_not_loaded(mocker): request = None with pytest.raises(HTTPException) as e: - query_endpoint_handler(request, auth=["test-user", "", "token"]) - assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert e.detail["response"] == "Configuration is not loaded" + await query_endpoint_handler(request, auth=["test-user", "", "token"]) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert e.value.detail["response"] == "Configuration is not loaded" def test_is_transcripts_enabled(setup_configuration, mocker): @@ -103,11 +94,11 @@ def test_is_transcripts_disabled(setup_configuration, mocker): assert is_transcripts_enabled() is False, "Transcripts should be disabled" -def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): +async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): """Test the query endpoint handler.""" mock_metric = mocker.patch("metrics.llm_calls_total") - mock_client = mocker.Mock() - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_client = mocker.AsyncMock() + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client mock_client.models.list.return_value = [ mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), @@ -140,7 +131,7 @@ def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): query_request = QueryRequest(query=query) - response = query_endpoint_handler(query_request, auth=MOCK_AUTH) + response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) # Assert the response is as expected assert response.response == llm_response @@ -166,14 +157,16 @@ def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): mock_transcript.assert_not_called() -def test_query_endpoint_handler_transcript_storage_disabled(mocker): +@pytest.mark.asyncio +async def test_query_endpoint_handler_transcript_storage_disabled(mocker): """Test the query endpoint handler with transcript storage disabled.""" - _test_query_endpoint_handler(mocker, store_transcript_to_file=False) + await _test_query_endpoint_handler(mocker, store_transcript_to_file=False) -def test_query_endpoint_handler_store_transcript(mocker): +@pytest.mark.asyncio +async def test_query_endpoint_handler_store_transcript(mocker): """Test the query endpoint handler with transcript storage enabled.""" - _test_query_endpoint_handler(mocker, store_transcript_to_file=True) + await _test_query_endpoint_handler(mocker, store_transcript_to_file=True) def test_select_model_and_provider_id_from_request(mocker): @@ -362,7 +355,8 @@ def test_validate_attachments_metadata_invalid_content_type(): ) -def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total") mock_client, mock_agent = prepare_agent_mocks @@ -385,7 +379,7 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -402,7 +396,8 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): ) -def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" @@ -422,7 +417,7 @@ def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -437,7 +432,8 @@ def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): ) -def test_retrieve_response_one_available_shield(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_one_available_shield(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" class MockShield: @@ -470,7 +466,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -485,7 +481,8 @@ def __repr__(self): ) -def test_retrieve_response_two_available_shields(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_two_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" class MockShield: @@ -521,7 +518,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -536,7 +533,8 @@ def __repr__(self): ) -def test_retrieve_response_four_available_shields(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_four_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" class MockShield: @@ -574,7 +572,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -601,7 +599,8 @@ def __repr__(self): ) -def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" @@ -629,7 +628,7 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -649,7 +648,8 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): ) -def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" @@ -682,7 +682,7 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -706,7 +706,8 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): ) -def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers configured.""" mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" @@ -736,7 +737,7 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -779,7 +780,10 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): ) -def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_with_mcp_servers_empty_token( + prepare_agent_mocks, mocker +): """Test the retrieve_response function with MCP servers and empty access token.""" mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" @@ -802,7 +806,7 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc model_id = "fake_model_id" access_token = "" # Empty token - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -830,7 +834,8 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc ) -def test_retrieve_response_with_mcp_servers_and_mcp_headers( +@pytest.mark.asyncio +async def test_retrieve_response_with_mcp_servers_and_mcp_headers( prepare_agent_mocks, mocker ): """Test the retrieve_response function with MCP servers configured.""" @@ -872,7 +877,7 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers( }, } - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, @@ -924,7 +929,8 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers( ) -def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total") mock_client, mock_agent = prepare_agent_mocks @@ -952,7 +958,7 @@ def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): query_request = QueryRequest(query="What is OpenStack?") - _, conversation_id = retrieve_response( + _, conversation_id = await retrieve_response( mock_client, "fake_model_id", query_request, "test_token" ) @@ -1058,7 +1064,8 @@ def test_get_rag_toolgroups(): assert result[0]["args"]["vector_db_ids"] == vector_db_ids -def test_query_endpoint_handler_on_connection_error(mocker): +@pytest.mark.asyncio +async def test_query_endpoint_handler_on_connection_error(mocker): """Test the query endpoint handler.""" mock_metric = mocker.patch("metrics.llm_calls_failures_total") @@ -1070,320 +1077,32 @@ def test_query_endpoint_handler_on_connection_error(mocker): query_request = QueryRequest(query="What is OpenStack?") # simulate situation when it is not possible to connect to Llama Stack - mock_get_client = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_get_client = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_get_client.side_effect = APIConnectionError(request=query_request) with pytest.raises(HTTPException) as exc_info: - query_endpoint_handler(query_request, auth=MOCK_AUTH) + await query_endpoint_handler(query_request, auth=MOCK_AUTH) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unable to connect to Llama Stack" in str(exc_info.value.detail) mock_metric.inc.assert_called_once() -def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): - """Test get_agent function when agent exists in llama stack.""" - mock_client, mock_agent = prepare_agent_mocks - mock_client.agents.session.list.return_value = mocker.Mock( - data=[{"session_id": "test_session_id"}] - ) - - # Set up cache with existing agent - conversation_id = "test_conversation_id" - - # Mock Agent class - mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) - - result_agent, result_conversation_id, result_session_id = get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=conversation_id, - ) - - # Assert the same agent is returned - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert conversation_id == result_agent.agent_id - assert result_session_id == "test_session_id" - - -def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function when conversation_id is provided.""" - mock_client, mock_agent = prepare_agent_mocks - mock_client.agents.retrieve.side_effect = ValueError( - "fake not finding existing agent" - ) - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.query.configuration", setup_configuration) - conversation_id = "non_existent_conversation_id" - # Call function with conversation_id - result_agent, result_conversation_id, result_session_id = get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=conversation_id, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert conversation_id != result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with correct parameters - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, mocker): - """Test get_agent function when conversation_id is None.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.query.configuration", setup_configuration) - - # Call function with None conversation_id - result_agent, result_conversation_id, result_session_id = get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with correct parameters - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocker): - """Test get_agent function with empty shields list.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.query.configuration", setup_configuration) - - # Call function with empty shields list - result_agent, result_conversation_id, result_session_id = get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=[], - available_output_shields=[], - conversation_id=None, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with empty shields - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=[], - output_shields=[], - tool_parser=None, - enable_session_persistence=True, - ) - - -def test_get_agent_multiple_mcp_servers( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function with multiple MCP servers.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") - - # Mock configuration with multiple MCP servers - mock_mcp_server1 = mocker.Mock() - mock_mcp_server1.name = "mcp_server_1" - mock_mcp_server2 = mocker.Mock() - mock_mcp_server2.name = "mcp_server_2" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server1, mock_mcp_server2], - ) - mocker.patch("app.endpoints.query.configuration", setup_configuration) - - # Call function - result_agent, result_conversation_id, result_session_id = get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1", "shield2"], - available_output_shields=["output_shield3", "output_shield4"], - conversation_id=None, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with tools from both MCP servers - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1", "shield2"], - output_shields=["output_shield3", "output_shield4"], - tool_parser=None, - enable_session_persistence=True, - ) - - -def test_get_agent_session_persistence_enabled( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function ensures session persistence is enabled.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.query.configuration", setup_configuration) - - # Call function - get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - ) - - # Verify Agent was created with session persistence enabled - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): +@pytest.mark.asyncio +async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): """Test that auth tuple is correctly unpacked in query endpoint handler.""" # Mock dependencies mock_config = mocker.Mock() mock_config.llama_stack_configuration = mocker.Mock() mocker.patch("app.endpoints.query.configuration", mock_config) - mock_client = mocker.Mock() + mock_client = mocker.AsyncMock() mock_client.models.list.return_value = [ mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") ] - mocker.patch("client.LlamaStackClientHolder.get_client", return_value=mock_client) + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) mock_retrieve_response = mocker.patch( "app.endpoints.query.retrieve_response", @@ -1396,7 +1115,7 @@ def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) - _ = query_endpoint_handler( + _ = await query_endpoint_handler( QueryRequest(query="test query"), auth=("user123", "username", "auth_token_123"), mcp_headers=None, @@ -1405,10 +1124,11 @@ def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): assert mock_retrieve_response.call_args[0][3] == "auth_token_123" -def test_query_endpoint_handler_no_tools_true(mocker): +@pytest.mark.asyncio +async def test_query_endpoint_handler_no_tools_true(mocker): """Test the query endpoint handler with no_tools=True.""" - mock_client = mocker.Mock() - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_client = mocker.AsyncMock() + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client mock_client.models.list.return_value = [ mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), @@ -1434,17 +1154,18 @@ def test_query_endpoint_handler_no_tools_true(mocker): query_request = QueryRequest(query=query, no_tools=True) - response = query_endpoint_handler(query_request, auth=MOCK_AUTH) + response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) # Assert the response is as expected assert response.response == llm_response assert response.conversation_id == conversation_id -def test_query_endpoint_handler_no_tools_false(mocker): +@pytest.mark.asyncio +async def test_query_endpoint_handler_no_tools_false(mocker): """Test the query endpoint handler with no_tools=False (default behavior).""" - mock_client = mocker.Mock() - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_client = mocker.AsyncMock() + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client mock_client.models.list.return_value = [ mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), @@ -1470,14 +1191,17 @@ def test_query_endpoint_handler_no_tools_false(mocker): query_request = QueryRequest(query=query, no_tools=False) - response = query_endpoint_handler(query_request, auth=MOCK_AUTH) + response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) # Assert the response is as expected assert response.response == llm_response assert response.conversation_id == conversation_id -def test_retrieve_response_no_tools_bypasses_mcp_and_rag(prepare_agent_mocks, mocker): +@pytest.mark.asyncio +async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( + prepare_agent_mocks, mocker +): """Test that retrieve_response bypasses MCP servers and RAG when no_tools=True.""" mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" @@ -1504,7 +1228,7 @@ def test_retrieve_response_no_tools_bypasses_mcp_and_rag(prepare_agent_mocks, mo model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1524,7 +1248,8 @@ def test_retrieve_response_no_tools_bypasses_mcp_and_rag(prepare_agent_mocks, mo ) -def test_retrieve_response_no_tools_false_preserves_functionality( +@pytest.mark.asyncio +async def test_retrieve_response_no_tools_false_preserves_functionality( prepare_agent_mocks, mocker ): """Test that retrieve_response preserves normal functionality when no_tools=False.""" @@ -1553,7 +1278,7 @@ def test_retrieve_response_no_tools_false_preserves_functionality( model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1583,117 +1308,6 @@ def test_retrieve_response_no_tools_false_preserves_functionality( ) -def test_get_agent_no_tools_no_parser(setup_configuration, prepare_agent_mocks, mocker): - """Test get_agent function sets tool_parser=None when no_tools=True.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.query.configuration", setup_configuration) - - # Call function with no_tools=True - result_agent, result_conversation_id, result_session_id = get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - no_tools=True, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with tool_parser=None - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -def test_get_agent_no_tools_false_preserves_parser( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function preserves tool_parser when no_tools=False.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") - - # Mock GraniteToolParser - mock_parser = mocker.Mock() - mock_granite_parser = mocker.patch("app.endpoints.query.GraniteToolParser") - mock_granite_parser.get_parser.return_value = mock_parser - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.query.configuration", setup_configuration) - - # Call function with no_tools=False - result_agent, result_conversation_id, result_session_id = get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - no_tools=False, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with the proper tool_parser - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=mock_parser, - enable_session_persistence=True, - ) - - def test_no_tools_parameter_backward_compatibility(): """Test that default behavior is unchanged when no_tools parameter is not specified.""" # This test ensures that existing code that doesn't specify no_tools continues to work diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 15713e4d6..534c95e1b 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -40,7 +40,6 @@ streaming_query_endpoint_handler, retrieve_response, stream_build_event, - get_agent, ) from models.requests import QueryRequest, Attachment @@ -106,14 +105,6 @@ def setup_configuration_fixture(): return cfg -@pytest.fixture(autouse=True, name="prepare_agent_mocks") -def prepare_agent_mocks_fixture(mocker): - """Preparation for mock for the LLM agent.""" - mock_client = mocker.AsyncMock() - mock_agent = mocker.AsyncMock() - yield mock_client, mock_agent - - @pytest.mark.asyncio async def test_streaming_query_endpoint_handler_configuration_not_loaded(mocker): """Test the streaming query endpoint handler if configuration is not loaded.""" @@ -149,7 +140,7 @@ async def test_streaming_query_endpoint_on_connection_error(mocker): # simulate situation when it is not possible to connect to Llama Stack mock_client = mocker.AsyncMock() mock_client.models.side_effect = APIConnectionError(request=query_request) - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client mock_async_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_async_lsc.return_value = mock_client @@ -1236,318 +1227,6 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): ) -@pytest.mark.asyncio -async def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): - """Test get_agent function when agent exists in llama stack.""" - - mock_client, mock_agent = prepare_agent_mocks - - conversation_id = "test_conversation_id" - - # Mock Agent class - mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent) - - result_agent, result_conversation_id, _ = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=conversation_id, - ) - - # Assert the same agent is returned - assert result_agent == mock_agent - assert result_conversation_id == conversation_id - assert conversation_id == mock_agent._agent_id # pylint: disable=protected-access - - -@pytest.mark.asyncio -async def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function when conversation_id is provided but agent not in llama stack.""" - - mock_client, mock_agent = prepare_agent_mocks - mock_client.agents.retrieve.side_effect = ValueError( - "fake not finding existing agent" - ) - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch( - "app.endpoints.streaming_query.get_suid", return_value="new_session_id" - ) - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - - # Call function with conversation_id - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id="non_existent_conversation_id", - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert "non_existent_conversation_id" != result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with correct parameters - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_no_conversation_id( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function when conversation_id is None.""" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch( - "app.endpoints.streaming_query.get_suid", return_value="new_session_id" - ) - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - - # Call function with None conversation_id - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with correct parameters - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_empty_shields( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function with empty shields list.""" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch( - "app.endpoints.streaming_query.get_suid", return_value="new_session_id" - ) - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - - # Call function with empty shields list - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=[], - available_output_shields=[], - conversation_id=None, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with empty shields - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=[], - output_shields=[], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_multiple_mcp_servers( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function with multiple MCP servers.""" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch( - "app.endpoints.streaming_query.get_suid", return_value="new_session_id" - ) - - # Mock configuration with multiple MCP servers - mock_mcp_server1 = mocker.Mock() - mock_mcp_server1.name = "mcp_server_1" - mock_mcp_server2 = mocker.Mock() - mock_mcp_server2.name = "mcp_server_2" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server1, mock_mcp_server2], - ) - mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - - # Call function - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1", "shield2"], - available_output_shields=["output_shield3", "output_shield4"], - conversation_id=None, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with tools from both MCP servers - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1", "shield2"], - output_shields=["output_shield3", "output_shield4"], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_session_persistence_enabled( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function ensures session persistence is enabled.""" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch( - "app.endpoints.streaming_query.get_suid", return_value="new_session_id" - ) - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - - # Call function - await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - ) - - # Verify Agent was created with session persistence enabled - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - @pytest.mark.asyncio async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): """Test that auth tuple is correctly unpacked in streaming query endpoint handler.""" @@ -1777,124 +1456,3 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( stream=True, toolgroups=expected_toolgroups, ) - - -@pytest.mark.asyncio -async def test_get_agent_no_tools_no_parser( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function sets tool_parser=None when no_tools=True.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch( - "app.endpoints.streaming_query.get_suid", return_value="new_session_id" - ) - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - - # Call function with no_tools=True - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - no_tools=True, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with tool_parser=None - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_no_tools_false_preserves_parser( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function preserves tool_parser when no_tools=False.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch( - "app.endpoints.streaming_query.get_suid", return_value="new_session_id" - ) - - # Mock GraniteToolParser - mock_parser = mocker.Mock() - mock_granite_parser = mocker.patch( - "app.endpoints.streaming_query.GraniteToolParser" - ) - mock_granite_parser.get_parser.return_value = mock_parser - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - - # Call function with no_tools=False - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - no_tools=False, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with the proper tool_parser - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=mock_parser, - enable_session_persistence=True, - ) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..73f4f5eae --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,28 @@ +"""Shared pytest fixtures for unit tests.""" + +from __future__ import annotations + +import pytest + + +@pytest.fixture(name="prepare_agent_mocks", scope="function") +def prepare_agent_mocks_fixture(mocker): + """Prepare for mock for the LLM agent. + + Provides common mocks for AsyncLlamaStackClient and AsyncAgent + with proper agent_id setup to avoid initialization errors. + + Returns: + tuple: (mock_client, mock_agent) + """ + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + + # Set up agent_id property to avoid "Agent ID not initialized" error + mock_agent._agent_id = "test_agent_id" # pylint: disable=protected-access + mock_agent.agent_id = "test_agent_id" + + # Set up create_turn mock structure for query endpoints that need it + mock_agent.create_turn.return_value.steps = [] + + yield mock_client, mock_agent diff --git a/tests/unit/metrics/test_utis.py b/tests/unit/metrics/test_utis.py index 295241cd6..b434a0ccb 100644 --- a/tests/unit/metrics/test_utis.py +++ b/tests/unit/metrics/test_utis.py @@ -7,7 +7,14 @@ async def test_setup_model_metrics(mocker): """Test the setup_model_metrics function.""" # Mock the LlamaStackAsLibraryClient - mock_client = mocker.patch("client.LlamaStackClientHolder.get_client").return_value + mock_client = mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client" + ).return_value + # Make sure the client is an AsyncMock for async methods + mock_client = mocker.AsyncMock() + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) mocker.patch( "metrics.utils.configuration.inference.default_provider", "default_provider", diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 980ac494b..5405092fe 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -2,22 +2,10 @@ import pytest -from client import LlamaStackClientHolder, AsyncLlamaStackClientHolder +from client import AsyncLlamaStackClientHolder from models.config import LlamaStackConfiguration -def test_client_get_client_method() -> None: - """Test how get_client method works for uninitialized client.""" - - client = LlamaStackClientHolder() - - with pytest.raises( - RuntimeError, - match="LlamaStackClient has not been initialised. Ensure 'load\\(..\\)' has been called.", - ): - client.get_client() - - def test_async_client_get_client_method() -> None: """Test how get_client method works for uninitialized client.""" client = AsyncLlamaStackClientHolder() @@ -32,61 +20,6 @@ def test_async_client_get_client_method() -> None: client.get_client() -def test_get_llama_stack_library_client() -> None: - """Test if Llama Stack can be initialized in library client mode.""" - cfg = LlamaStackConfiguration( - url=None, - api_key=None, - use_as_library_client=True, - library_client_config_path="./tests/configuration/minimal-stack.yaml", - ) - client = LlamaStackClientHolder() - client.load(cfg) - assert client is not None - - ls_client = client.get_client() - assert ls_client is not None - assert not ls_client.is_closed() - ls_client.close() - assert ls_client.is_closed() - - -def test_get_llama_stack_remote_client() -> None: - """Test if Llama Stack can be initialized in remove client (server) mode.""" - cfg = LlamaStackConfiguration( - url="http://localhost:8321", - api_key=None, - use_as_library_client=False, - library_client_config_path="./tests/configuration/minimal-stack.yaml", - ) - client = LlamaStackClientHolder() - client.load(cfg) - assert client is not None - - ls_client = client.get_client() - assert ls_client is not None - assert not ls_client.is_closed() - ls_client.close() - assert ls_client.is_closed() - - -def test_get_llama_stack_wrong_configuration() -> None: - """Test if configuration is checked before Llama Stack is initialized.""" - cfg = LlamaStackConfiguration( - url=None, - api_key=None, - use_as_library_client=True, - library_client_config_path="./tests/configuration/minimal-stack.yaml", - ) - cfg.library_client_config_path = None - with pytest.raises( - Exception, - match="Configuration problem: library_client_config_path option is not set", - ): - client = LlamaStackClientHolder() - client.load(cfg) - - @pytest.mark.asyncio async def test_get_async_llama_stack_library_client() -> None: """Test the initialization of asynchronous Llama Stack client in library mode.""" diff --git a/tests/unit/utils/test_common.py b/tests/unit/utils/test_common.py index f321ee46c..dd9652c0d 100644 --- a/tests/unit/utils/test_common.py +++ b/tests/unit/utils/test_common.py @@ -24,7 +24,7 @@ async def test_register_mcp_servers_empty_list(mocker): mock_logger = Mock(spec=Logger) # Mock the LlamaStack client (shouldn't be called since no MCP servers) - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") # Create configuration with empty MCP servers config = Configuration( @@ -55,8 +55,8 @@ async def test_register_mcp_servers_single_server_not_registered(mocker): mock_logger = Mock(spec=Logger) # Mock the LlamaStack client - mock_client = Mock() - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_client = AsyncMock() + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client mock_tool = Mock() mock_tool.provider_resource_id = "existing-server" @@ -100,11 +100,11 @@ async def test_register_mcp_servers_single_server_already_registered(mocker): mock_logger = Mock(spec=Logger) # Mock the LlamaStack client - mock_client = Mock() + mock_client = AsyncMock() mock_tool = Mock() mock_tool.provider_resource_id = "existing-server" mock_client.toolgroups.list.return_value = [mock_tool] - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client # Create configuration with MCP server that matches existing toolgroup @@ -138,8 +138,8 @@ async def test_register_mcp_servers_multiple_servers_mixed_registration(mocker): mock_logger = Mock(spec=Logger) # Mock the LlamaStack client - mock_client = Mock() - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_client = AsyncMock() + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client mock_tool1 = Mock() mock_tool1.provider_resource_id = "existing-server" @@ -200,10 +200,10 @@ async def test_register_mcp_servers_with_custom_provider(mocker): mock_logger = Mock(spec=Logger) # Mock the LlamaStack client - mock_client = Mock() + mock_client = AsyncMock() mock_client.toolgroups.list.return_value = [] mock_client.toolgroups.register.return_value = None - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client # Create configuration with MCP server using custom provider diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py index eb8311b5b..e3b02d6ad 100644 --- a/tests/unit/utils/test_endpoints.py +++ b/tests/unit/utils/test_endpoints.py @@ -10,6 +10,7 @@ from models.requests import QueryRequest from utils import endpoints +from utils.endpoints import get_agent CONFIGURED_SYSTEM_PROMPT = "This is a configured system prompt" @@ -80,6 +81,34 @@ def query_request_with_system_prompt_fixture(): return QueryRequest(query="query", system_prompt="System prompt defined in query") +@pytest.fixture(name="setup_configuration") +def setup_configuration_fixture(): + """Set up configuration for tests.""" + test_config_dict = { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "transcripts_enabled": False, + }, + "mcp_servers": [], + } + cfg = AppConfig() + cfg.init_from_dict(test_config_dict) + return cfg + + def test_get_default_system_prompt( config_without_system_prompt, query_request_without_system_prompt ): @@ -143,3 +172,422 @@ def test_get_system_prompt_with_disable_query_system_prompt_and_non_system_promp config_with_custom_system_prompt_and_disable_query_system_prompt, ) assert system_prompt == CONFIGURED_SYSTEM_PROMPT + + +@pytest.mark.asyncio +async def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): + """Test get_agent function when agent exists in llama stack.""" + mock_client, mock_agent = prepare_agent_mocks + conversation_id = "test_conversation_id" + + # Mock existing agent retrieval + mock_agent_response = mocker.Mock() + mock_agent_response.agent_id = conversation_id + mock_client.agents.retrieve.return_value = mock_agent_response + + mock_client.agents.session.list.return_value = mocker.Mock( + data=[{"session_id": "test_session_id"}] + ) + + # Mock Agent class + mocker.patch("utils.endpoints.AsyncAgent", return_value=mock_agent) + + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=conversation_id, + ) + + # Assert the same agent is returned and conversation_id is preserved + assert result_agent == mock_agent + assert result_conversation_id == conversation_id + assert result_session_id == "test_session_id" + + +@pytest.mark.asyncio +async def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function when conversation_id is provided.""" + mock_client, mock_agent = prepare_agent_mocks + mock_client.agents.retrieve.side_effect = ValueError( + "fake not finding existing agent" + ) + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + conversation_id = "non_existent_conversation_id" + # Call function with conversation_id + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=conversation_id, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert conversation_id != result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with correct parameters + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_conversation_id( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function when conversation_id is None.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with None conversation_id + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with correct parameters + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_empty_shields( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function with empty shields list.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with empty shields list + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=[], + available_output_shields=[], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with empty shields + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=[], + output_shields=[], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_multiple_mcp_servers( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function with multiple MCP servers.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") + + # Mock configuration with multiple MCP servers + mock_mcp_server1 = mocker.Mock() + mock_mcp_server1.name = "mcp_server_1" + mock_mcp_server2 = mocker.Mock() + mock_mcp_server2.name = "mcp_server_2" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server1, mock_mcp_server2], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1", "shield2"], + available_output_shields=["output_shield3", "output_shield4"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with tools from both MCP servers + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1", "shield2"], + output_shields=["output_shield3", "output_shield4"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_session_persistence_enabled( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function ensures session persistence is enabled.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function + await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + ) + + # Verify Agent was created with session persistence enabled + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_tools_no_parser( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function sets tool_parser=None when no_tools=True.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with no_tools=True + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=True, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with tool_parser=None + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_tools_false_preserves_parser( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function preserves tool_parser when no_tools=False.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") + + # Mock GraniteToolParser + mock_parser = mocker.Mock() + mock_granite_parser = mocker.patch("utils.endpoints.GraniteToolParser") + mock_granite_parser.get_parser.return_value = mock_parser + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with no_tools=False + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=False, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with the proper tool_parser + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=mock_parser, + enable_session_persistence=True, + )