diff --git a/docs/openapi.json b/docs/openapi.json
index 96b4725c8..cff61b422 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1052,6 +1052,9 @@
         },
         "inference": {
           "$ref": "#/components/schemas/InferenceConfiguration"
+        },
+        "question_validation": {
+          "$ref": "#/components/schemas/QuestionValidationConfiguration"
         }
       },
       "additionalProperties": false,
@@ -1286,6 +1289,14 @@
           "type": "object",
           "title": "Prompts",
           "default": {}
+        },
+        "query_responses": {
+          "additionalProperties": {
+            "type": "string"
+          },
+          "type": "object",
+          "title": "Query Responses",
+          "default": {}
         }
       },
       "type": "object",
@@ -2279,6 +2290,19 @@
         }
       ]
     },
+    "QuestionValidationConfiguration": {
+      "properties": {
+        "question_validation_enabled": {
+          "type": "boolean",
+          "title": "Question Validation Enabled",
+          "default": false
+        }
+      },
+      "additionalProperties": false,
+      "type": "object",
+      "title": "QuestionValidationConfiguration",
+      "description": "Question validation configuration."
+    },
     "ReadinessResponse": {
       "properties": {
         "ready": {
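For reference, the only new wire-level setting is a single boolean. A configuration document that enables it would validate against the schema added above as follows (sketch only; sibling sections such as "inference" are unchanged):

    {
      "question_validation": {
        "question_validation_enabled": true
      }
    }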
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 43c3eb603..99bc2ac09 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -5,8 +5,8 @@
 import logging
 from typing import Annotated, Any, cast
 
-from llama_stack_client import APIConnectionError
-from llama_stack_client import AsyncLlamaStackClient  # type: ignore
+
+from llama_stack_client import APIConnectionError, AsyncLlamaStackClient  # type: ignore
 from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
 from llama_stack_client.types import UserMessage, Shield  # type: ignore
 from llama_stack_client.types.agents.turn import Turn
@@ -31,9 +31,11 @@
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest, Attachment
 from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
+from utils.agent import get_agent, get_temp_agent
 from utils.endpoints import (
     check_configuration_loaded,
-    get_agent,
+    get_invalid_query_response,
+    get_validation_system_prompt,
     get_system_prompt,
     validate_conversation_ownership,
     validate_model_provider_override,
@@ -215,7 +217,7 @@ async def query_endpoint_handler(
                 user_conversation=user_conversation, query_request=query_request
             ),
         )
-        summary, conversation_id = await retrieve_response(
+        summary, conversation_id, query_is_valid = await retrieve_response(
             client,
             llama_stack_model_id,
             query_request,
@@ -234,7 +236,7 @@ async def query_endpoint_handler(
                 conversation_id=conversation_id,
                 model_id=model_id,
                 provider_id=provider_id,
-                query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
+                query_is_valid=query_is_valid,
                 query=query_request.query,
                 query_request=query_request,
                 summary=summary,
@@ -391,6 +393,37 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
+async def validate_question(
+    question: str, client: AsyncLlamaStackClient, model_id: str
+) -> tuple[bool, str]:
+    """Validate a question using a one-word verdict from the model.
+
+    Args:
+        question: The question to be validated.
+        client: The AsyncLlamaStackClient to use for the request.
+        model_id: The ID of the model to use.
+
+    Returns:
+        tuple[bool, str]: whether the question was deemed valid, and the validation conversation ID
+    """
+    validation_system_prompt = get_validation_system_prompt(configuration)
+    agent, session_id, conversation_id = await get_temp_agent(
+        client, model_id, validation_system_prompt
+    )
+    response = await agent.create_turn(
+        messages=[UserMessage(role="user", content=question)],
+        session_id=session_id,
+        stream=False,
+        toolgroups=None,
+    )
+    response = cast(Turn, response)
+    return (
+        constants.SUBJECT_REJECTED
+        not in interleaved_content_as_str(response.output_message.content),
+        conversation_id,
+    )
+
+
 async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches,too-many-arguments
     client: AsyncLlamaStackClient,
     model_id: str,
@@ -399,7 +432,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
     mcp_headers: dict[str, dict[str, str]] | None = None,
     *,
     provider_id: str = "",
-) -> tuple[TurnSummary, str]:
+) -> tuple[TurnSummary, str, bool]:
     """
     Retrieve response from LLMs and agents.
 
@@ -462,6 +495,23 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
     )
     logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)
 
+
+    # Validate the question if question validation is enabled
+    if configuration.question_validation.question_validation_enabled:
+        question_is_valid, _ = await validate_question(
+            query_request.query, client, model_id
+        )
+
+        if not question_is_valid:
+            return (
+                TurnSummary(
+                    llm_response=get_invalid_query_response(configuration),
+                    tool_calls=[],
+                ),
+                conversation_id,
+                False,
+            )
+
     # bypass tools and MCP servers if no_tools is True
     if query_request.no_tools:
         mcp_headers = {}
@@ -535,7 +585,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
             "Response lacks output_message.content (conversation_id=%s)",
             conversation_id,
         )
-    return summary, conversation_id
+    return summary, conversation_id, True
 
 
 def validate_attachments_metadata(attachments: list[Attachment]) -> None:
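The validation round trip added above is easier to read outside the diff. A minimal sketch of how the handler-side pieces compose (illustrative only; client, llama_stack_model_id, and query_request are the variables already in scope in the endpoint handlers):

    # Hypothetical usage of the helpers introduced in this patch.
    query_is_valid, validation_conversation_id = await validate_question(
        query_request.query, client, llama_stack_model_id
    )
    if not query_is_valid:
        # Short-circuit: return the configured rejection text instead of
        # running a turn on the main agent.
        answer = get_invalid_query_response(configuration)

Note that validity is decided by a substring check: any model output containing SUBJECT_REJECTED ("REJECTED") marks the query as invalid, so the validation system prompt should keep the model's verdict to a single word.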
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 60d9d4d6e..053388d4f 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -30,7 +30,12 @@
 from models.config import Action
 from models.requests import QueryRequest
 from models.database.conversations import UserConversation
-from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
+from utils.agent import get_agent
+from utils.endpoints import (
+    check_configuration_loaded,
+    get_system_prompt,
+    get_invalid_query_response,
+)
 from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
 from utils.transcripts import store_transcript
 from utils.types import TurnSummary
@@ -38,6 +43,7 @@
 
 from app.endpoints.query import (
     get_rag_toolgroups,
+    validate_question,
     is_input_shield,
     is_output_shield,
     is_transcripts_enabled,
@@ -587,6 +593,40 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
                 user_conversation=user_conversation, query_request=query_request
             ),
         )
+
+        # Check question validation before getting response
+        query_is_valid = True
+        if configuration.question_validation.question_validation_enabled:
+            query_is_valid, temp_agent_conversation_id = await validate_question(
+                query_request.query, client, llama_stack_model_id
+            )
+            if not query_is_valid:
+                response = get_invalid_query_response(configuration)
+                if not is_transcripts_enabled():
+                    logger.debug(
+                        "Transcript collection is disabled in the configuration"
+                    )
+                else:
+                    summary = TurnSummary(
+                        llm_response=response,
+                        tool_calls=[],
+                    )
+                    store_transcript(
+                        user_id=user_id,
+                        conversation_id=query_request.conversation_id
+                        or temp_agent_conversation_id,
+                        model_id=model_id,
+                        provider_id=provider_id,
+                        query_is_valid=query_is_valid,
+                        query=query_request.query,
+                        query_request=query_request,
+                        summary=summary,
+                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                        truncated=False,  # TODO(lucasagomes): implement truncation as part
+                        # of quota work
+                        attachments=query_request.attachments or [],
+                    )
+                return StreamingResponse(response)
         response, conversation_id = await retrieve_response(
             client,
             llama_stack_model_id,
             query_request,
@@ -598,6 +638,7 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
 
     async def response_generator(
         turn_response: AsyncIterator[AgentTurnResponseStreamChunk],
+        query_is_valid: bool,
     ) -> AsyncIterator[str]:
         """
         Generate SSE formatted streaming response.
@@ -649,7 +690,7 @@ async def response_generator(
                     conversation_id=conversation_id,
                     model_id=model_id,
                     provider_id=provider_id,
-                    query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
+                    query_is_valid=query_is_valid,
                     query=query_request.query,
                     query_request=query_request,
                     summary=summary,
@@ -669,7 +710,7 @@ async def response_generator(
         # Update metrics for the LLM call
         metrics.llm_calls_total.labels(provider_id, model_id).inc()
 
-        return StreamingResponse(response_generator(response))
+        return StreamingResponse(response_generator(response, query_is_valid))
     # connection to Llama Stack server
     except APIConnectionError as e:
         # Update metrics for the LLM call failure
@@ -750,6 +791,7 @@ async def retrieve_response(
     )
     logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)
 
+
     # bypass tools and MCP servers if no_tools is True
     if query_request.no_tools:
         mcp_headers = {}
diff --git a/src/configuration.py b/src/configuration.py
index 8aa50a0a8..82b5f0d4c 100644
--- a/src/configuration.py
+++ b/src/configuration.py
@@ -19,6 +19,7 @@
     AuthenticationConfiguration,
     InferenceConfiguration,
     DatabaseConfiguration,
+    QuestionValidationConfiguration,
 )
 
 
@@ -131,5 +132,12 @@ def database_configuration(self) -> DatabaseConfiguration:
             raise LogicError("logic error: configuration is not loaded")
         return self._configuration.database
 
+    @property
+    def question_validation(self) -> QuestionValidationConfiguration:
+        """Return question validation configuration."""
+        if self._configuration is None:
+            raise LogicError("logic error: configuration is not loaded")
+        return self._configuration.question_validation
+
 
 configuration: AppConfig = AppConfig()
diff --git a/src/constants.py b/src/constants.py
index e79ebcebb..594333e59 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -28,6 +28,19 @@
 # configuration file nor in the query request
 DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant"
 
+
+# Query validation
+SUBJECT_REJECTED = "REJECTED"
+SUBJECT_ALLOWED = "ALLOWED"
+DEFAULT_VALIDATION_SYSTEM_PROMPT = (
+    "You are a helpful assistant that validates questions. You will be given a "
+    "question and must determine whether it is valid. Respond with "
+    f"'{SUBJECT_REJECTED}' if the question is not valid and "
+    f"'{SUBJECT_ALLOWED}' if it is valid."
+)
+DEFAULT_INVALID_QUERY_RESPONSE = "Invalid query, please try again."
+ + # Authentication constants DEFAULT_VIRTUAL_PATH = "/ls-access" DEFAULT_USER_NAME = "lightspeed-user" diff --git a/src/models/config.py b/src/models/config.py index bc7538e91..9f093e410 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -232,6 +232,12 @@ def check_storage_location_is_set_when_needed(self) -> Self: return self +class QuestionValidationConfiguration(ConfigurationBase): + """Question validation configuration.""" + + question_validation_enabled: bool = False + + class JsonPathOperator(str, Enum): """Supported operators for JSONPath evaluation.""" @@ -421,6 +427,7 @@ class CustomProfile: path: str prompts: dict[str, str] = Field(default={}, init=False) + query_responses: dict[str, str] = Field(default={}, init=False) def __post_init__(self) -> None: """Validate and load profile.""" @@ -432,11 +439,18 @@ def _validate_and_process(self) -> None: profile_module = checks.import_python_module("profile", self.path) if profile_module is not None and checks.is_valid_profile(profile_module): self.prompts = profile_module.PROFILE_CONFIG.get("system_prompts", {}) + self.query_responses = profile_module.PROFILE_CONFIG.get( + "query_responses", {} + ) def get_prompts(self) -> dict[str, str]: """Retrieve prompt attribute.""" return self.prompts + def get_query_responses(self) -> dict[str, str]: + """Retrieve query responses attribute.""" + return self.query_responses + class Customization(ConfigurationBase): """Service customization.""" @@ -495,6 +509,9 @@ class Configuration(ConfigurationBase): authorization: Optional[AuthorizationConfiguration] = None customization: Optional[Customization] = None inference: InferenceConfiguration = Field(default_factory=InferenceConfiguration) + question_validation: QuestionValidationConfiguration = Field( + default_factory=QuestionValidationConfiguration + ) def dump(self, filename: str = "configuration.json") -> None: """Dump actual configuration into JSON file.""" diff --git a/src/utils/agent.py b/src/utils/agent.py new file mode 100644 index 000000000..a58690977 --- /dev/null +++ b/src/utils/agent.py @@ -0,0 +1,101 @@ +"""Utility functions for agent management.""" + +from contextlib import suppress +import logging + +from fastapi import HTTPException, status +from llama_stack_client._client import AsyncLlamaStackClient +from llama_stack_client.lib.agents.agent import AsyncAgent + +from utils.suid import get_suid +from utils.types import GraniteToolParser + + +logger = logging.getLogger("utils.agent") + + +# pylint: disable=R0913,R0917 +async def get_agent( + client: AsyncLlamaStackClient, + model_id: str, + system_prompt: str, + available_input_shields: list[str], + available_output_shields: list[str], + conversation_id: str | None, + no_tools: bool = False, +) -> tuple[AsyncAgent, str, str]: + """Get existing agent or create a new one with session persistence.""" + existing_agent_id = None + if conversation_id: + with suppress(ValueError): + agent_response = await client.agents.retrieve(agent_id=conversation_id) + existing_agent_id = agent_response.agent_id + + logger.debug("Creating new agent") + agent = AsyncAgent( + client, # type: ignore[arg-type] + model=model_id, + instructions=system_prompt, + input_shields=available_input_shields if available_input_shields else [], + output_shields=available_output_shields if available_output_shields else [], + tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), + enable_session_persistence=True, + ) + await agent.initialize() + + if existing_agent_id and conversation_id: 
+        orphan_agent_id = agent.agent_id
+        agent._agent_id = conversation_id  # type: ignore[assignment]  # pylint: disable=protected-access
+        await client.agents.delete(agent_id=orphan_agent_id)
+        sessions_response = await client.agents.session.list(agent_id=conversation_id)
+        logger.info("session response: %s", sessions_response)
+        try:
+            session_id = str(sessions_response.data[0]["session_id"])
+        except IndexError as e:
+            logger.error("No sessions found for conversation %s", conversation_id)
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail={
+                    "response": "Conversation not found",
+                    "cause": f"Conversation {conversation_id} could not be retrieved.",
+                },
+            ) from e
+    else:
+        conversation_id = agent.agent_id
+        session_id = await agent.create_session(get_suid())
+
+    return agent, conversation_id, session_id
+
+
+async def get_temp_agent(
+    client: AsyncLlamaStackClient,
+    model_id: str,
+    system_prompt: str,
+) -> tuple[AsyncAgent, str, str]:
+    """Create a temporary agent with a new agent_id and session_id.
+
+    This function creates a new agent without persistence, shields, or tools.
+    Useful for temporary operations or one-off queries, such as validating a question or generating a summary.
+
+    Args:
+        client: The AsyncLlamaStackClient to use for the request.
+        model_id: The ID of the model to use.
+        system_prompt: The system prompt/instructions for the agent.
+
+    Returns:
+        tuple[AsyncAgent, str, str]: A tuple containing the agent, the session ID, and the conversation ID.
+    """
+    logger.debug("Creating temporary agent")
+    agent = AsyncAgent(
+        client,  # type: ignore[arg-type]
+        model=model_id,
+        instructions=system_prompt,
+        enable_session_persistence=False,  # Temporary agent doesn't need persistence
+    )
+    await agent.initialize()
+
+    # The agent ID doubles as the conversation ID; create a fresh session
+    conversation_id = agent.agent_id
+    session_id = await agent.create_session(get_suid())
+
+    return agent, session_id, conversation_id
diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py
index e17a76d06..c45fe5e61 100644
--- a/src/utils/endpoints.py
+++ b/src/utils/endpoints.py
@@ -1,10 +1,7 @@
 """Utility functions for endpoint handlers."""
 
-from contextlib import suppress
 import logging
 
 from fastapi import HTTPException, status
-from llama_stack_client._client import AsyncLlamaStackClient
-from llama_stack_client.lib.agents.agent import AsyncAgent
 
 import constants
 from models.requests import QueryRequest
@@ -12,8 +9,6 @@
 from models.config import Action
 from app.database import get_session
 from configuration import AppConfig
-from utils.suid import get_suid
-from utils.types import GraniteToolParser
 
 logger = logging.getLogger("utils.endpoints")
@@ -68,6 +63,34 @@ def check_configuration_loaded(config: AppConfig) -> None:
         )
 
+
+def get_validation_system_prompt(config: AppConfig) -> str:
+    """Get the validation system prompt."""
+    # profile takes precedence for setting prompt
+    if (
+        config.customization is not None
+        and config.customization.custom_profile is not None
+    ):
+        prompt = config.customization.custom_profile.get_prompts().get("validation")
+        if prompt:
+            return prompt
+
+    return constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
+
+
+def get_invalid_query_response(config: AppConfig) -> str:
+    """Get the invalid query response."""
+    if (
+        config.customization is not None
+        and config.customization.custom_profile is not None
+    ):
+        response = config.customization.custom_profile.get_query_responses().get(
+            "invalid_resp"
+        )
+        if response:
+            return response
+    return constants.DEFAULT_INVALID_QUERY_RESPONSE
+
+
 def 
get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str: """Get the system prompt: the provided one, configured one, or default one.""" system_prompt_disabled = ( @@ -132,56 +155,3 @@ def validate_model_provider_override( ) }, ) - - -# # pylint: disable=R0913,R0917 -async def get_agent( - client: AsyncLlamaStackClient, - model_id: str, - system_prompt: str, - available_input_shields: list[str], - available_output_shields: list[str], - conversation_id: str | None, - no_tools: bool = False, -) -> tuple[AsyncAgent, str, str]: - """Get existing agent or create a new one with session persistence.""" - existing_agent_id = None - if conversation_id: - with suppress(ValueError): - agent_response = await client.agents.retrieve(agent_id=conversation_id) - existing_agent_id = agent_response.agent_id - - logger.debug("Creating new agent") - agent = AsyncAgent( - client, # type: ignore[arg-type] - model=model_id, - instructions=system_prompt, - input_shields=available_input_shields if available_input_shields else [], - output_shields=available_output_shields if available_output_shields else [], - tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), - enable_session_persistence=True, - ) - await agent.initialize() - - if existing_agent_id and conversation_id: - orphan_agent_id = agent.agent_id - agent._agent_id = conversation_id # type: ignore[assignment] # pylint: disable=protected-access - await client.agents.delete(agent_id=orphan_agent_id) - sessions_response = await client.agents.session.list(agent_id=conversation_id) - logger.info("session response: %s", sessions_response) - try: - session_id = str(sessions_response.data[0]["session_id"]) - except IndexError as e: - logger.error("No sessions found for conversation %s", conversation_id) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={ - "response": "Conversation not found", - "cause": f"Conversation {conversation_id} could not be retrieved.", - }, - ) from e - else: - conversation_id = agent.agent_id - session_id = await agent.create_session(get_suid()) - - return agent, conversation_id, session_id diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 3b3d64f3f..074e04379 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -152,6 +152,8 @@ async def _test_query_endpoint_handler( mock_config.user_data_collection_configuration.transcripts_enabled = ( store_transcript_to_file ) + # Mock question validation configuration to be disabled + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) summary = TurnSummary( @@ -170,7 +172,7 @@ async def _test_query_endpoint_handler( mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id), + return_value=(summary, conversation_id, True), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -447,6 +449,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -458,7 +461,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - response, 
_ = await retrieve_response( + response, _, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -479,6 +482,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -490,7 +494,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo model_id = "fake_model_id" access_token = "test_token" - response, _ = await retrieve_response( + response, _, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -512,6 +516,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -523,7 +528,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -551,6 +556,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -562,7 +568,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -601,6 +607,7 @@ def __repr__(self): # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -612,7 +619,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -654,6 +661,7 @@ def __repr__(self): # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -665,7 +673,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -709,6 +717,7 @@ def __repr__(self): # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + 
mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.query.get_agent", @@ -720,7 +729,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -758,6 +767,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) attachments = [ @@ -777,7 +787,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -808,6 +818,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) attachments = [ @@ -832,7 +843,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -877,6 +888,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.query.get_agent", @@ -888,7 +900,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -947,6 +959,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.query.get_agent", @@ -958,7 +971,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( model_id = "fake_model_id" access_token = "" # Empty token - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1009,6 +1022,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.query.get_agent", @@ -1030,7 +1044,7 @@ async def 
test_retrieve_response_with_mcp_servers_and_mcp_headers( }, } - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, @@ -1106,6 +1120,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -1115,7 +1130,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): query_request = QueryRequest(query="What is OpenStack?") - _, conversation_id = await retrieve_response( + _, conversation_id, _ = await retrieve_response( mock_client, "fake_model_id", query_request, "test_token" ) @@ -1177,6 +1192,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_requ # Mock dependencies mock_config = mocker.Mock() mock_config.llama_stack_configuration = mocker.Mock() + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_client = mocker.AsyncMock() @@ -1200,7 +1216,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_requ ) mock_retrieve_response = mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, "test_conversation_id"), + return_value=(summary, "test_conversation_id", True), ) mocker.patch( @@ -1251,7 +1267,7 @@ async def test_query_endpoint_handler_no_tools_true(mocker, dummy_request): mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id), + return_value=(summary, conversation_id, True), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1302,7 +1318,7 @@ async def test_query_endpoint_handler_no_tools_false(mocker, dummy_request): mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id), + return_value=(summary, conversation_id, True), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1343,6 +1359,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -1354,7 +1371,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1394,6 +1411,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -1405,7 +1423,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1586,3 +1604,129 @@ async def 
test_query_endpoint_rejects_model_provider_override_without_permission ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert exc_info.value.detail["response"] == expected_msg + + +@pytest.mark.asyncio +async def test_query_endpoint_with_question_validation_invalid_query( + setup_configuration, dummy_request, mocker +): + """Test query endpoint with question validation enabled and invalid query.""" + # Mock metrics + mock_metrics(mocker) + + # Mock database operations + mock_database_operations(mocker) + + # Setup configuration with question validation enabled + setup_configuration.question_validation.question_validation_enabled = True + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Mock the validation agent response (invalid query) + mock_validation_agent = mocker.AsyncMock() + mock_validation_agent.agent_id = "validation_agent_id" + mock_validation_agent.create_session.return_value = "validation_session_id" + + # Mock the validation response that contains SUBJECT_REJECTED + mock_validation_turn = mocker.Mock() + mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}] + mock_validation_agent.create_turn.return_value = mock_validation_turn + + # Mock the main agent (should not be called for invalid queries) + mock_agent = mocker.AsyncMock() + mock_agent.agent_id = "conversation_id" + mock_agent.create_session.return_value = "session_id" + + # Mock the client + mock_client = mocker.AsyncMock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") + ] + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + + # Mock the retrieve_response function to return invalid query response + summary = TurnSummary( + llm_response="Invalid query response", + tool_calls=[], + ) + conversation_id = "fake_conversation_id" + + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=(summary, conversation_id, False), # query_is_valid=False + ) + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), + ) + mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + + query_request = QueryRequest(query="Invalid question about unrelated topic") + + response = await query_endpoint_handler( + request=dummy_request, query_request=query_request, auth=MOCK_AUTH + ) + + # Verify the response contains the invalid query response + assert response.conversation_id == conversation_id + assert response.response == "Invalid query response" + + +@pytest.mark.asyncio +async def test_query_endpoint_with_question_validation_valid_query( + setup_configuration, dummy_request, mocker +): + """Test query endpoint with question validation enabled and valid query.""" + # Mock metrics + mock_metrics(mocker) + + # Mock database operations + mock_database_operations(mocker) + + # Setup configuration with question validation enabled + setup_configuration.question_validation.question_validation_enabled = True + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Mock the client + mock_client = mocker.AsyncMock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") + ] + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + + # Mock the retrieve_response function to return valid query response + summary = 
TurnSummary( + llm_response="Valid LLM response about Backstage and Kubernetes", + tool_calls=[ + ToolCallSummary( + id="123", + name="test-tool", + args="testing", + response="tool response", + ) + ], + ) + conversation_id = "fake_conversation_id" + + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=(summary, conversation_id, True), # query_is_valid=True + ) + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), + ) + mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + + query_request = QueryRequest(query="Valid question about Backstage and Kubernetes") + + response = await query_endpoint_handler( + request=dummy_request, query_request=query_request, auth=MOCK_AUTH + ) + + # Verify the response contains the normal LLM response + assert response.conversation_id == conversation_id + assert response.response == "Valid LLM response about Backstage and Kubernetes" diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 38983666a..595894b52 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1283,6 +1283,7 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): # Mock dependencies mock_config = mocker.Mock() mock_config.llama_stack_configuration = mocker.Mock() + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_client = mocker.AsyncMock() @@ -1337,6 +1338,7 @@ async def test_streaming_query_endpoint_handler_no_tools_true(mocker): mock_config = mocker.Mock() mock_config.user_data_collection_configuration.transcripts_disabled = True + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.streaming_query.configuration", mock_config) # Mock the streaming response @@ -1384,6 +1386,7 @@ async def test_streaming_query_endpoint_handler_no_tools_false(mocker): mock_config = mocker.Mock() mock_config.user_data_collection_configuration.transcripts_disabled = True + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.streaming_query.configuration", mock_config) # Mock the streaming response @@ -1584,3 +1587,211 @@ async def test_streaming_query_endpoint_rejects_model_provider_override_without_ ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert exc_info.value.detail["response"] == expected_msg + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_with_question_validation_invalid_query( + setup_configuration, mocker +): + """Test streaming query endpoint with question validation enabled and invalid query.""" + # Mock metrics + mock_metrics(mocker) + + # Mock database operations + mock_database_operations(mocker) + + # Setup configuration with question validation enabled + setup_configuration.question_validation.question_validation_enabled = True + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Mock the client + mock_client = mocker.AsyncMock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") + ] + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + + # Mock the validation agent response (invalid query) + mock_validation_agent = mocker.AsyncMock() + 
mock_validation_agent.agent_id = "validation_agent_id" + mock_validation_agent.create_session.return_value = "validation_session_id" + + # Mock the validation response that contains SUBJECT_REJECTED + mock_validation_turn = mocker.Mock() + mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}] + mock_validation_agent.create_turn.return_value = mock_validation_turn + + # Mock the validation functions + mocker.patch( + "app.endpoints.streaming_query.validate_question", + return_value=(False, "validation_agent_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.get_invalid_query_response", + return_value="Invalid query response", + ) + mocker.patch( + "app.endpoints.streaming_query.interleaved_content_as_str", + return_value="REJECTED", + ) + + # Mock other dependencies + mocker.patch("app.endpoints.streaming_query.validate_model_provider_override") + mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") + mocker.patch( + "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False + ) + mocker.patch( + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), + ) + + query_request = QueryRequest(query="Invalid question about unrelated topic") + + request = Request( + scope={ + "type": "http", + } + ) + request.state.authorized_actions = set(Action) + + response = await streaming_query_endpoint_handler( + request=request, query_request=query_request, auth=MOCK_AUTH + ) + + # Verify the response is a StreamingResponse + assert isinstance(response, StreamingResponse) + + # Collect the streaming response content + streaming_content = [] + async for chunk in response.body_iterator: + streaming_content.append(chunk) + + # Convert to string for assertions + full_content = "".join(streaming_content) + + # Verify the response contains the invalid query response + assert "Invalid query response" in full_content + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_with_question_validation_valid_query( + setup_configuration, mocker +): + """Test streaming query endpoint with question validation enabled and valid query.""" + # Mock metrics + mock_metrics(mocker) + + # Mock database operations + mock_database_operations(mocker) + + # Setup configuration with question validation enabled + setup_configuration.question_validation.question_validation_enabled = True + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Mock the client + mock_client = mocker.AsyncMock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") + ] + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + + # Mock the validation functions to return valid query + mocker.patch( + "app.endpoints.streaming_query.validate_question", + return_value=(True, "validation_agent_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.interleaved_content_as_str", + return_value="Valid question about Backstage", + ) + + # Mock the retrieve_response function to return valid streaming response + mock_streaming_response = mocker.AsyncMock() + mock_streaming_response.__aiter__.return_value = [ + AgentTurnResponseStreamChunk( + event=TurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( + event_type="step_progress", + step_type="inference", + delta=TextDelta( + text="Valid LLM response about Backstage and Kubernetes", + type="text", + ), + 
step_id="s1", + ) + ) + ), + AgentTurnResponseStreamChunk( + event=TurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + event_type="turn_complete", + turn=Turn( + turn_id="t1", + input_messages=[], + output_message=CompletionMessage( + role="assistant", + content=[ + TextContentItem( + text="Valid LLM response about Backstage and Kubernetes", + type="text", + ) + ], + stop_reason="end_of_turn", + ), + session_id="test_session_id", + started_at=datetime.now(), + steps=[], + completed_at=datetime.now(), + output_attachments=[], + ), + ) + ) + ), + ] + + mocker.patch( + "app.endpoints.streaming_query.retrieve_response", + return_value=(mock_streaming_response, "test_conversation_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False + ) + + query_request = QueryRequest(query="Valid question about Backstage and Kubernetes") + + request = Request( + scope={ + "type": "http", + } + ) + request.state.authorized_actions = set(Action) + + response = await streaming_query_endpoint_handler( + request=request, query_request=query_request, auth=MOCK_AUTH + ) + + # Verify the response is a StreamingResponse + assert isinstance(response, StreamingResponse) + + # Collect the streaming response content + streaming_content = [] + async for chunk in response.body_iterator: + streaming_content.append(chunk) + + # Convert to string for assertions + full_content = "".join(streaming_content) + + # Verify the response contains the normal LLM response + assert "Valid LLM response about Backstage and Kubernetes" in full_content + assert '"event": "start"' in full_content + assert '"event": "token"' in full_content + assert '"event": "end"' in full_content diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 303b17998..a97e2f884 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -88,6 +88,7 @@ def test_dump_configuration(tmp_path) -> None: assert "customization" in content assert "inference" in content assert "database" in content + assert "question_validation" in content # check the whole deserialized JSON file content assert content == { @@ -163,6 +164,7 @@ def test_dump_configuration(tmp_path) -> None: }, }, "authorization": None, + "question_validation": {"question_validation_enabled": False}, } diff --git a/tests/unit/utils/test_agent.py b/tests/unit/utils/test_agent.py new file mode 100644 index 000000000..b81893b7d --- /dev/null +++ b/tests/unit/utils/test_agent.py @@ -0,0 +1,538 @@ +"""Unit tests for agent utility functions.""" + +import pytest + +from configuration import AppConfig + +from utils.agent import get_agent, get_temp_agent + + +@pytest.fixture(name="setup_configuration") +def setup_configuration_fixture(): + """Set up configuration for tests.""" + test_config_dict = { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "transcripts_enabled": False, + }, + "mcp_servers": [], + } + cfg = AppConfig() + cfg.init_from_dict(test_config_dict) + return cfg + + +@pytest.mark.asyncio +async def 
test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): + """Test get_agent function when agent exists in llama stack.""" + mock_client, mock_agent = prepare_agent_mocks + conversation_id = "test_conversation_id" + + # Mock existing agent retrieval + mock_agent_response = mocker.Mock() + mock_agent_response.agent_id = conversation_id + mock_client.agents.retrieve.return_value = mock_agent_response + + mock_client.agents.session.list.return_value = mocker.Mock( + data=[{"session_id": "test_session_id"}] + ) + + # Mock Agent class + mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=conversation_id, + ) + + # Assert the same agent is returned and conversation_id is preserved + assert result_agent == mock_agent + assert result_conversation_id == conversation_id + assert result_session_id == "test_session_id" + + +@pytest.mark.asyncio +async def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function when conversation_id is provided.""" + mock_client, mock_agent = prepare_agent_mocks + mock_client.agents.retrieve.side_effect = ValueError( + "fake not finding existing agent" + ) + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + conversation_id = "non_existent_conversation_id" + # Call function with conversation_id + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=conversation_id, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert conversation_id != result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with correct parameters + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_conversation_id( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function when conversation_id is None.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + 
type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with None conversation_id + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with correct parameters + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_empty_shields( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function with empty shields list.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with empty shields list + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=[], + available_output_shields=[], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with empty shields + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=[], + output_shields=[], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_multiple_mcp_servers( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function with multiple MCP servers.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration with multiple MCP servers + mock_mcp_server1 = mocker.Mock() + mock_mcp_server1.name = "mcp_server_1" + mock_mcp_server2 = mocker.Mock() + mock_mcp_server2.name = "mcp_server_2" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server1, mock_mcp_server2], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + 
system_prompt="test_prompt", + available_input_shields=["shield1", "shield2"], + available_output_shields=["output_shield3", "output_shield4"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with tools from both MCP servers + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1", "shield2"], + output_shields=["output_shield3", "output_shield4"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_session_persistence_enabled( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function ensures session persistence is enabled.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function + await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + ) + + # Verify Agent was created with session persistence enabled + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_tools_no_parser( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function sets tool_parser=None when no_tools=True.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with no_tools=True + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=True, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with tool_parser=None + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + 
output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_tools_false_preserves_parser( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function preserves tool_parser when no_tools=False.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock GraniteToolParser + mock_parser = mocker.Mock() + mock_granite_parser = mocker.patch("utils.agent.GraniteToolParser") + mock_granite_parser.get_parser.return_value = mock_parser + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with no_tools=False + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=False, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with the proper tool_parser + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=mock_parser, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_temp_agent_basic_functionality(prepare_agent_mocks, mocker): + """Test get_temp_agent function creates agent with correct parameters.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "temp_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="temp_session_id") + + # Call function + result_agent, result_session_id, result_conversation_id = await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Assert agent, session_id, and conversation_id are created and returned + assert result_agent == mock_agent + assert result_session_id == "temp_session_id" + assert result_conversation_id == mock_agent.agent_id + + # Verify Agent was created with correct parameters for temporary agent + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + enable_session_persistence=False, # Key difference: no persistence + ) + + # Verify agent was initialized and session was created + mock_agent.initialize.assert_called_once() + mock_agent.create_session.assert_called_once_with("temp_session_id") + + +@pytest.mark.asyncio +async def test_get_temp_agent_returns_valid_ids(prepare_agent_mocks, mocker): + """Test get_temp_agent function returns valid agent_id and session_id.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.agent_id = "generated_agent_id" + 
+    mock_agent.create_session.return_value = "generated_session_id"
+
+    # Mock Agent class
+    mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent)
+
+    # Mock get_suid
+    mocker.patch("utils.agent.get_suid", return_value="generated_session_id")
+
+    # Call function
+    result_agent, result_session_id, result_conversation_id = await get_temp_agent(
+        client=mock_client,
+        model_id="test_model",
+        system_prompt="test_prompt",
+    )
+
+    # Assert all three values are returned and are not None
+    assert result_agent is not None
+    assert result_session_id is not None
+    assert result_conversation_id is not None
+
+    # Assert they are strings
+    assert isinstance(result_session_id, str)
+    assert isinstance(result_conversation_id, str)
+
+    # Assert conversation_id matches agent_id
+    assert result_conversation_id == result_agent.agent_id
+
+
+@pytest.mark.asyncio
+async def test_get_temp_agent_no_persistence(prepare_agent_mocks, mocker):
+    """Test get_temp_agent function creates agent without session persistence."""
+    mock_client, mock_agent = prepare_agent_mocks
+    mock_agent.create_session.return_value = "temp_session_id"
+
+    # Mock Agent class
+    mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent)
+
+    # Mock get_suid
+    mocker.patch("utils.agent.get_suid", return_value="temp_session_id")
+
+    # Call function
+    await get_temp_agent(
+        client=mock_client,
+        model_id="test_model",
+        system_prompt="test_prompt",
+    )
+
+    # Verify Agent was created with session persistence disabled
+    mock_agent_class.assert_called_once_with(
+        mock_client,
+        model="test_model",
+        instructions="test_prompt",
+        enable_session_persistence=False,
+    )
diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py
index 04701ac48..c12693e80 100644
--- a/tests/unit/utils/test_endpoints.py
+++ b/tests/unit/utils/test_endpoints.py
@@ -10,7 +10,6 @@
 from models.requests import QueryRequest
 from models.config import Action
 from utils import endpoints
-from utils.endpoints import get_agent
 
 from tests.unit import config_dict
 
@@ -119,34 +118,6 @@ def query_request_with_system_prompt_fixture():
     return QueryRequest(query="query", system_prompt="System prompt defined in query")
 
 
-@pytest.fixture(name="setup_configuration")
-def setup_configuration_fixture():
-    """Set up configuration for tests."""
-    test_config_dict = {
-        "name": "test",
-        "service": {
-            "host": "localhost",
-            "port": 8080,
-            "auth_enabled": False,
-            "workers": 1,
-            "color_log": True,
-            "access_log": True,
-        },
-        "llama_stack": {
-            "api_key": "test-key",
-            "url": "http://test.com:1234",
-            "use_as_library_client": False,
-        },
-        "user_data_collection": {
-            "transcripts_enabled": False,
-        },
-        "mcp_servers": [],
-    }
-    cfg = AppConfig()
-    cfg.init_from_dict(test_config_dict)
-    return cfg
-
-
 def test_get_default_system_prompt(
     config_without_system_prompt, query_request_without_system_prompt
 ):
@@ -238,442 +209,147 @@ def test_get_profile_prompt_with_enabled_query_system_prompt(
     assert system_prompt == query_request_with_system_prompt.system_prompt
 
 
-@pytest.mark.asyncio
-async def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker):
-    """Test get_agent function when agent exists in llama stack."""
-    mock_client, mock_agent = prepare_agent_mocks
-    conversation_id = "test_conversation_id"
-
-    # Mock existing agent retrieval
-    mock_agent_response = mocker.Mock()
-    mock_agent_response.agent_id = conversation_id
-    mock_client.agents.retrieve.return_value = mock_agent_response
-
-    mock_client.agents.session.list.return_value = mocker.Mock(
-        data=[{"session_id": "test_session_id"}]
-    )
-
-    # Mock Agent class
-    mocker.patch("utils.endpoints.AsyncAgent", return_value=mock_agent)
-
-    result_agent, result_conversation_id, result_session_id = await get_agent(
-        client=mock_client,
-        model_id="test_model",
-        system_prompt="test_prompt",
-        available_input_shields=["shield1"],
-        available_output_shields=["output_shield2"],
-        conversation_id=conversation_id,
-    )
-
-    # Assert the same agent is returned and conversation_id is preserved
-    assert result_agent == mock_agent
-    assert result_conversation_id == conversation_id
-    assert result_session_id == "test_session_id"
-
-
-@pytest.mark.asyncio
-async def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack(
-    setup_configuration, prepare_agent_mocks, mocker
-):
-    """Test get_agent function when conversation_id is provided."""
-    mock_client, mock_agent = prepare_agent_mocks
-    mock_client.agents.retrieve.side_effect = ValueError(
-        "fake not finding existing agent"
-    )
-    mock_agent.create_session.return_value = "new_session_id"
+def test_validate_model_provider_override_allowed_with_action():
+    """Ensure no exception when caller has MODEL_OVERRIDE and request includes model/provider."""
+    query_request = QueryRequest(query="q", model="m", provider="p")
+    authorized_actions = {Action.MODEL_OVERRIDE}
+    endpoints.validate_model_provider_override(query_request, authorized_actions)
 
-    # Mock Agent class
-    mock_agent_class = mocker.patch(
-        "utils.endpoints.AsyncAgent", return_value=mock_agent
-    )
 
-    # Mock get_suid
-    mocker.patch("utils.endpoints.get_suid", return_value="new_session_id")
-
-    # Mock configuration
-    mock_mcp_server = mocker.Mock()
-    mock_mcp_server.name = "mcp_server_1"
-    mocker.patch.object(
-        type(setup_configuration),
-        "mcp_servers",
-        new_callable=mocker.PropertyMock,
-        return_value=[mock_mcp_server],
-    )
-    mocker.patch("configuration.configuration", setup_configuration)
-    conversation_id = "non_existent_conversation_id"
-    # Call function with conversation_id
-    result_agent, result_conversation_id, result_session_id = await get_agent(
-        client=mock_client,
-        model_id="test_model",
-        system_prompt="test_prompt",
-        available_input_shields=["shield1"],
-        available_output_shields=["output_shield2"],
-        conversation_id=conversation_id,
-    )
+def test_validate_model_provider_override_rejected_without_action():
+    """Ensure HTTP 403 when request includes model/provider and caller lacks permission."""
+    query_request = QueryRequest(query="q", model="m", provider="p")
+    authorized_actions: set[Action] = set()
+    with pytest.raises(HTTPException) as exc_info:
+        endpoints.validate_model_provider_override(query_request, authorized_actions)
+    assert exc_info.value.status_code == 403
 
-    # Assert new agent is created
-    assert result_agent == mock_agent
-    assert result_conversation_id == result_agent.agent_id
-    assert conversation_id != result_agent.agent_id
-    assert result_session_id == "new_session_id"
-
-    # Verify Agent was created with correct parameters
-    mock_agent_class.assert_called_once_with(
-        mock_client,
-        model="test_model",
-        instructions="test_prompt",
-        input_shields=["shield1"],
-        output_shields=["output_shield2"],
-        tool_parser=None,
-        enable_session_persistence=True,
-    )
 
+def test_validate_model_provider_override_no_override_without_action():
+    """No exception when request does not include model/provider regardless of permission."""
+    query_request = QueryRequest(query="q")
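+    # Should not raise: the request does not attempt a model/provider override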
+    endpoints.validate_model_provider_override(query_request, set())
 
 
-@pytest.mark.asyncio
-async def test_get_agent_no_conversation_id(
-    setup_configuration, prepare_agent_mocks, mocker
-):
-    """Test get_agent function when conversation_id is None."""
-    mock_client, mock_agent = prepare_agent_mocks
-    mock_agent.create_session.return_value = "new_session_id"
-
-    # Mock Agent class
-    mock_agent_class = mocker.patch(
-        "utils.endpoints.AsyncAgent", return_value=mock_agent
-    )
+# Tests for get_validation_system_prompt
 
-    # Mock get_suid
-    mocker.patch("utils.endpoints.get_suid", return_value="new_session_id")
-
-    # Mock configuration
-    mock_mcp_server = mocker.Mock()
-    mock_mcp_server.name = "mcp_server_1"
-    mocker.patch.object(
-        type(setup_configuration),
-        "mcp_servers",
-        new_callable=mocker.PropertyMock,
-        return_value=[mock_mcp_server],
-    )
-    mocker.patch("configuration.configuration", setup_configuration)
-
-    # Call function with None conversation_id
-    result_agent, result_conversation_id, result_session_id = await get_agent(
-        client=mock_client,
-        model_id="test_model",
-        system_prompt="test_prompt",
-        available_input_shields=["shield1"],
-        available_output_shields=["output_shield2"],
-        conversation_id=None,
-    )
-
-    # Assert new agent is created
-    assert result_agent == mock_agent
-    assert result_conversation_id == result_agent.agent_id
-    assert result_session_id == "new_session_id"
-
-    # Verify Agent was created with correct parameters
-    mock_agent_class.assert_called_once_with(
-        mock_client,
-        model="test_model",
-        instructions="test_prompt",
-        input_shields=["shield1"],
-        output_shields=["output_shield2"],
-        tool_parser=None,
-        enable_session_persistence=True,
+def test_get_default_validation_system_prompt(config_without_system_prompt):
+    """Test that default validation system prompt is returned when no custom profile is provided."""
+    validation_prompt = endpoints.get_validation_system_prompt(
+        config_without_system_prompt
     )
+    assert validation_prompt == constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
 
 
-@pytest.mark.asyncio
-async def test_get_agent_empty_shields(
-    setup_configuration, prepare_agent_mocks, mocker
-):
-    """Test get_agent function with empty shields list."""
-    mock_client, mock_agent = prepare_agent_mocks
-    mock_agent.create_session.return_value = "new_session_id"
-
-    # Mock Agent class
-    mock_agent_class = mocker.patch(
-        "utils.endpoints.AsyncAgent", return_value=mock_agent
-    )
+def test_get_validation_system_prompt_with_custom_profile():
+    """Test that validation system prompt from custom profile is returned when available."""
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
 
-    # Mock get_suid
-    mocker.patch("utils.endpoints.get_suid", return_value="new_session_id")
-
-    # Mock configuration
-    mock_mcp_server = mocker.Mock()
-    mock_mcp_server.name = "mcp_server_1"
-    mocker.patch.object(
-        type(setup_configuration),
-        "mcp_servers",
-        new_callable=mocker.PropertyMock,
-        return_value=[mock_mcp_server],
-    )
-    mocker.patch("configuration.configuration", setup_configuration)
-
-    # Call function with empty shields list
-    result_agent, result_conversation_id, result_session_id = await get_agent(
-        client=mock_client,
-        model_id="test_model",
-        system_prompt="test_prompt",
-        available_input_shields=[],
-        available_output_shields=[],
-        conversation_id=None,
-    )
+    validation_prompt = endpoints.get_validation_system_prompt(cfg)
 
-    # Assert new agent is created
-    assert result_agent == mock_agent
-    assert result_conversation_id == result_agent.agent_id
-    assert result_session_id == "new_session_id"
-
-    # Verify Agent was created with empty shields
-    mock_agent_class.assert_called_once_with(
-        mock_client,
-        model="test_model",
-        instructions="test_prompt",
-        input_shields=[],
-        output_shields=[],
-        tool_parser=None,
-        enable_session_persistence=True,
-    )
+    # Get the expected prompt from the test profile
+    custom_profile = CustomProfile(path="tests/profiles/test/profile.py")
+    expected_prompt = custom_profile.get_prompts().get("validation")
 
+    assert validation_prompt == expected_prompt
 
-@pytest.mark.asyncio
-async def test_get_agent_multiple_mcp_servers(
-    setup_configuration, prepare_agent_mocks, mocker
-):
-    """Test get_agent function with multiple MCP servers."""
-    mock_client, mock_agent = prepare_agent_mocks
-    mock_agent.create_session.return_value = "new_session_id"
 
-    # Mock Agent class
-    mock_agent_class = mocker.patch(
-        "utils.endpoints.AsyncAgent", return_value=mock_agent
-    )
+def test_get_validation_system_prompt_with_custom_profile_no_validation_prompt():
+    """Test that default validation system prompt is returned when custom profile has no validation prompt."""
+    # Create a test profile that doesn't have a validation prompt
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
 
-    # Mock get_suid
-    mocker.patch("utils.endpoints.get_suid", return_value="new_session_id")
-
-    # Mock configuration with multiple MCP servers
-    mock_mcp_server1 = mocker.Mock()
-    mock_mcp_server1.name = "mcp_server_1"
-    mock_mcp_server2 = mocker.Mock()
-    mock_mcp_server2.name = "mcp_server_2"
-    mocker.patch.object(
-        type(setup_configuration),
-        "mcp_servers",
-        new_callable=mocker.PropertyMock,
-        return_value=[mock_mcp_server1, mock_mcp_server2],
-    )
-    mocker.patch("configuration.configuration", setup_configuration)
-
-    # Call function
-    result_agent, result_conversation_id, result_session_id = await get_agent(
-        client=mock_client,
-        model_id="test_model",
-        system_prompt="test_prompt",
-        available_input_shields=["shield1", "shield2"],
-        available_output_shields=["output_shield3", "output_shield4"],
-        conversation_id=None,
-    )
+    # Manually set the prompts to not include validation
+    cfg.customization.custom_profile.prompts = {"default": "test prompt"}
 
-    # Assert new agent is created
-    assert result_agent == mock_agent
-    assert result_conversation_id == result_agent.agent_id
-    assert result_session_id == "new_session_id"
-
-    # Verify Agent was created with tools from both MCP servers
-    mock_agent_class.assert_called_once_with(
-        mock_client,
-        model="test_model",
-        instructions="test_prompt",
-        input_shields=["shield1", "shield2"],
-        output_shields=["output_shield3", "output_shield4"],
-        tool_parser=None,
-        enable_session_persistence=True,
-    )
+    validation_prompt = endpoints.get_validation_system_prompt(cfg)
+    assert validation_prompt == constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
 
 
-@pytest.mark.asyncio
-async def test_get_agent_session_persistence_enabled(
-    setup_configuration, prepare_agent_mocks, mocker
-):
-    """Test get_agent function ensures session persistence is enabled."""
-    mock_client, mock_agent = prepare_agent_mocks
+def test_get_validation_system_prompt_with_custom_profile_empty_validation_prompt():
+    """Test that default validation system prompt is returned when custom profile has an empty validation prompt."""
+    # Create a test profile that has an empty validation prompt
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
 
-    # Mock Agent class
-    mock_agent_class = mocker.patch(
-        "utils.endpoints.AsyncAgent", return_value=mock_agent
-    )
+    # Manually set the prompts to have an empty validation prompt
+    cfg.customization.custom_profile.prompts = {"validation": ""}
 
-    # Mock get_suid
-    mocker.patch("utils.endpoints.get_suid", return_value="new_session_id")
-
-    # Mock configuration
-    mock_mcp_server = mocker.Mock()
-    mock_mcp_server.name = "mcp_server_1"
-    mocker.patch.object(
-        type(setup_configuration),
-        "mcp_servers",
-        new_callable=mocker.PropertyMock,
-        return_value=[mock_mcp_server],
-    )
-    mocker.patch("configuration.configuration", setup_configuration)
-
-    # Call function
-    await get_agent(
-        client=mock_client,
-        model_id="test_model",
-        system_prompt="test_prompt",
-        available_input_shields=["shield1"],
-        available_output_shields=["output_shield2"],
-        conversation_id=None,
-    )
+    validation_prompt = endpoints.get_validation_system_prompt(cfg)
+    assert validation_prompt == constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
 
-    # Verify Agent was created with session persistence enabled
-    mock_agent_class.assert_called_once_with(
-        mock_client,
-        model="test_model",
-        instructions="test_prompt",
-        input_shields=["shield1"],
-        output_shields=["output_shield2"],
-        tool_parser=None,
-        enable_session_persistence=True,
-    )
 
+# Tests for get_invalid_query_response
 
-@pytest.mark.asyncio
-async def test_get_agent_no_tools_no_parser(
-    setup_configuration, prepare_agent_mocks, mocker
-):
-    """Test get_agent function sets tool_parser=None when no_tools=True."""
-    mock_client, mock_agent = prepare_agent_mocks
-    mock_agent.create_session.return_value = "new_session_id"
 
-    # Mock Agent class
-    mock_agent_class = mocker.patch(
-        "utils.endpoints.AsyncAgent", return_value=mock_agent
+def test_get_default_invalid_query_response(config_without_system_prompt):
+    """Test that default invalid query response is returned when no custom profile is provided."""
+    invalid_response = endpoints.get_invalid_query_response(
+        config_without_system_prompt
     )
+    assert invalid_response == constants.DEFAULT_INVALID_QUERY_RESPONSE
 
-    # Mock get_suid
-    mocker.patch("utils.endpoints.get_suid", return_value="new_session_id")
-
-    # Mock configuration
-    mock_mcp_server = mocker.Mock()
-    mock_mcp_server.name = "mcp_server_1"
-    mocker.patch.object(
-        type(setup_configuration),
-        "mcp_servers",
-        new_callable=mocker.PropertyMock,
-        return_value=[mock_mcp_server],
-    )
-    mocker.patch("configuration.configuration", setup_configuration)
-
-    # Call function with no_tools=True
-    result_agent, result_conversation_id, result_session_id = await get_agent(
-        client=mock_client,
-        model_id="test_model",
-        system_prompt="test_prompt",
-        available_input_shields=["shield1"],
-        available_output_shields=["output_shield2"],
-        conversation_id=None,
-        no_tools=True,
-    )
-
-    # Assert new agent is created
-    assert result_agent == mock_agent
-    assert result_conversation_id == result_agent.agent_id
-    assert result_session_id == "new_session_id"
-
-    # Verify Agent was created with tool_parser=None
-    mock_agent_class.assert_called_once_with(
-        mock_client,
-        model="test_model",
-        instructions="test_prompt",
-        input_shields=["shield1"],
-        output_shields=["output_shield2"],
-        tool_parser=None,
-        enable_session_persistence=True,
-    )
 
+def test_get_invalid_query_response_with_custom_profile():
+    """Test that invalid query response from custom profile is returned when available."""
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
 
+    invalid_response = endpoints.get_invalid_query_response(cfg)
 
-@pytest.mark.asyncio
-async def test_get_agent_no_tools_false_preserves_parser(
-    setup_configuration, prepare_agent_mocks, mocker
-):
-    """Test get_agent function preserves tool_parser when no_tools=False."""
-    mock_client, mock_agent = prepare_agent_mocks
-    mock_agent.create_session.return_value = "new_session_id"
 
+    # Get the expected response from the test profile
+    custom_profile = CustomProfile(path="tests/profiles/test/profile.py")
+    expected_response = custom_profile.get_query_responses().get("invalid_resp")
 
-    # Mock Agent class
-    mock_agent_class = mocker.patch(
-        "utils.endpoints.AsyncAgent", return_value=mock_agent
-    )
 
+    assert invalid_response == expected_response
 
-    # Mock get_suid
-    mocker.patch("utils.endpoints.get_suid", return_value="new_session_id")
-
-    # Mock GraniteToolParser
-    mock_parser = mocker.Mock()
-    mock_granite_parser = mocker.patch("utils.endpoints.GraniteToolParser")
-    mock_granite_parser.get_parser.return_value = mock_parser
-
-    # Mock configuration
-    mock_mcp_server = mocker.Mock()
-    mock_mcp_server.name = "mcp_server_1"
-    mocker.patch.object(
-        type(setup_configuration),
-        "mcp_servers",
-        new_callable=mocker.PropertyMock,
-        return_value=[mock_mcp_server],
-    )
-    mocker.patch("configuration.configuration", setup_configuration)
-
-    # Call function with no_tools=False
-    result_agent, result_conversation_id, result_session_id = await get_agent(
-        client=mock_client,
-        model_id="test_model",
-        system_prompt="test_prompt",
-        available_input_shields=["shield1"],
-        available_output_shields=["output_shield2"],
-        conversation_id=None,
-        no_tools=False,
-    )
 
+def test_get_invalid_query_response_with_custom_profile_no_invalid_resp():
+    """Test that default invalid query response is returned when custom profile has no invalid_resp."""
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
 
-    # Assert new agent is created
-    assert result_agent == mock_agent
-    assert result_conversation_id == result_agent.agent_id
-    assert result_session_id == "new_session_id"
-
-    # Verify Agent was created with the proper tool_parser
-    mock_agent_class.assert_called_once_with(
-        mock_client,
-        model="test_model",
-        instructions="test_prompt",
-        input_shields=["shield1"],
-        output_shields=["output_shield2"],
-        tool_parser=mock_parser,
-        enable_session_persistence=True,
-    )
 
+    # Manually set the query_responses to not include invalid_resp
+    cfg.customization.custom_profile.query_responses = {}
 
-def test_validate_model_provider_override_allowed_with_action():
-    """Ensure no exception when caller has MODEL_OVERRIDE and request includes model/provider."""
-    query_request = QueryRequest(query="q", model="m", provider="p")
-    authorized_actions = {Action.MODEL_OVERRIDE}
-    endpoints.validate_model_provider_override(query_request, authorized_actions)
 
+    invalid_response = endpoints.get_invalid_query_response(cfg)
+    assert invalid_response == constants.DEFAULT_INVALID_QUERY_RESPONSE
 
-def test_validate_model_provider_override_rejected_without_action():
"""Ensure HTTP 403 when request includes model/provider and caller lacks permission.""" - query_request = QueryRequest(query="q", model="m", provider="p") - authorized_actions: set[Action] = set() - with pytest.raises(HTTPException) as exc_info: - endpoints.validate_model_provider_override(query_request, authorized_actions) - assert exc_info.value.status_code == 403 +def test_get_invalid_query_response_with_custom_profile_empty_invalid_resp(): + """Test that default invalid query response is returned when custom profile has empty invalid_resp.""" + test_config = config_dict.copy() + test_config["customization"] = { + "profile_path": "tests/profiles/test/profile.py", + } + cfg = AppConfig() + cfg.init_from_dict(test_config) + # Manually set the query_responses to have empty invalid_resp + cfg.customization.custom_profile.query_responses = {"invalid_resp": ""} -def test_validate_model_provider_override_no_override_without_action(): - """No exception when request does not include model/provider regardless of permission.""" - query_request = QueryRequest(query="q") - endpoints.validate_model_provider_override(query_request, set()) + invalid_response = endpoints.get_invalid_query_response(cfg) + assert invalid_response == constants.DEFAULT_INVALID_QUERY_RESPONSE