From 6b33af69e467fe57cd3d0284953ba5ee262b85be Mon Sep 17 00:00:00 2001
From: Stephanie
Date: Mon, 8 Sep 2025 16:45:00 -0400
Subject: [PATCH 1/7] question validation change

Signed-off-by: Stephanie
---
 lightspeed-stack.yaml                |   5 +-
 src/app/endpoints/query.py           |  42 ++++++++--
 src/app/endpoints/streaming_query.py |   3 +-
 src/configuration.py                 |   8 ++
 src/constants.py                     |  15 ++++
 src/models/config.py                 |   7 ++
 src/utils/agent.py                   | 100 ++++++++++++++++++++++++
 src/utils/endpoints.py               | 110 +++++++++++++--------------
 8 files changed, 226 insertions(+), 64 deletions(-)
 create mode 100644 src/utils/agent.py

diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml
index 9ac7f63a1..0948f4e2c 100644
--- a/lightspeed-stack.yaml
+++ b/lightspeed-stack.yaml
@@ -13,7 +13,7 @@ llama_stack:
   # Alternative for "as library use"
   # use_as_library_client: true
   # library_client_config_path:
-  url: http://llama-stack:8321
+  url: http://localhost:8321
   api_key: xyzzy
 user_data_collection:
   feedback_enabled: true
@@ -23,3 +23,6 @@ user_data_collection:

 authentication:
   module: "noop"
+
+question_validation:
+  question_validation_enabled: true
\ No newline at end of file
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 5705ba269..023e52ccb 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -5,8 +5,8 @@
 import logging
 from typing import Annotated, Any, cast

-from llama_stack_client import APIConnectionError
-from llama_stack_client import AsyncLlamaStackClient  # type: ignore
+
+from llama_stack_client import APIConnectionError, AsyncLlamaStackClient  # type: ignore
 from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
 from llama_stack_client.types import UserMessage, Shield  # type: ignore
 from llama_stack_client.types.agents.turn import Turn
@@ -31,9 +31,11 @@
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest, Attachment
 from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
+from utils.agent import get_agent, get_temp_agent
 from utils.endpoints import (
     check_configuration_loaded,
-    get_agent,
+    get_invalid_query_response,
+    get_validation_system_prompt,
     get_system_prompt,
     validate_conversation_ownership,
     validate_model_provider_override,
@@ -215,7 +217,7 @@ async def query_endpoint_handler(
             user_conversation=user_conversation, query_request=query_request
         ),
     )
-    summary, conversation_id = await retrieve_response(
+    summary, conversation_id, query_is_valid = await retrieve_response(
         client,
         llama_stack_model_id,
         query_request,
@@ -391,6 +393,28 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)

+async def validate_question(question: str, client: AsyncLlamaStackClient, model_id: str) -> bool:
+    """Validate a question and provide a one-word response.
+
+    Args:
+        question: The question to be validated.
+        client: The AsyncLlamaStackClient to use for the request.
+        model_id: The ID of the model to use.
+ + Returns: + bool: True if the question was deemed valid, False otherwise + """ + validation_system_prompt = get_validation_system_prompt() + agent, session_id = await get_temp_agent(client, model_id, validation_system_prompt) + response = await agent.create_turn( + messages=[UserMessage(role="user", content=question)], + session_id=session_id, + stream=False, + toolgroups=None, + ) + response = cast(Turn, response) + return constants.SUBJECT_REJECTED not in interleaved_content_as_str(response.output_message.content) + async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches,too-many-arguments client: AsyncLlamaStackClient, model_id: str, @@ -399,7 +423,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche mcp_headers: dict[str, dict[str, str]] | None = None, *, provider_id: str = "", -) -> tuple[TurnSummary, str]: +) -> tuple[TurnSummary, str, bool]: """ Retrieve response from LLMs and agents. @@ -496,6 +520,12 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche if not toolgroups: toolgroups = None + if configuration.question_validation.question_validation_enabled and (not await validate_question(query_request.query, client, model_id)): + return TurnSummary( + llm_response=get_invalid_query_response(), + tool_calls=[], + ), conversation_id, False + response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], session_id=session_id, @@ -535,7 +565,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche "Response lacks output_message.content (conversation_id=%s)", conversation_id, ) - return summary, conversation_id + return summary, conversation_id, True def validate_attachments_metadata(attachments: list[Attachment]) -> None: diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index d6007e962..5069b2346 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -30,7 +30,8 @@ from models.config import Action from models.requests import QueryRequest from models.database.conversations import UserConversation -from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt +from utils.agent import get_agent +from utils.endpoints import check_configuration_loaded, get_system_prompt from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.transcripts import store_transcript from utils.types import TurnSummary diff --git a/src/configuration.py b/src/configuration.py index f70fe8205..a537fbf83 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -19,6 +19,7 @@ AuthenticationConfiguration, InferenceConfiguration, DatabaseConfiguration, + QuestionValidationConfiguration, ) @@ -134,5 +135,12 @@ def database_configuration(self) -> DatabaseConfiguration: raise LogicError("logic error: configuration is not loaded") return self._configuration.database + @property + def question_validation(self) -> QuestionValidationConfiguration: + """Return question validation configuration.""" + if self._configuration is None: + raise LogicError("logic error: configuration is not loaded") + return self._configuration.question_validation + configuration: AppConfig = AppConfig() diff --git a/src/constants.py b/src/constants.py index f5982b44e..5fce0e19e 100644 --- a/src/constants.py +++ b/src/constants.py @@ -28,6 +28,21 @@ # configuration file nor in the query request DEFAULT_SYSTEM_PROMPT = "You are a helpful 
assistant" + +# Query validation +SUBJECT_REJECTED = "REJECTED" +SUBJECT_ALLOWED = "ALLOWED" +DEFAULT_VALIDATION_SYSTEM_PROMPT = ( + "You are a helpful assistant that validates questions. You will be given a " + "question and you will need to validate if it is valid or not. You will " + f"return '{SUBJECT_REJECTED}' if the question is not valid and " + f"'{SUBJECT_ALLOWED}' if it is valid." +) +DEFAULT_INVALID_QUERY_RESPONSE = ( + "Invalid query, please try again." +) + + # Authentication constants DEFAULT_VIRTUAL_PATH = "/ls-access" DEFAULT_USER_NAME = "lightspeed-user" diff --git a/src/models/config.py b/src/models/config.py index c4efa404c..325a931e3 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -230,6 +230,12 @@ def check_storage_location_is_set_when_needed(self) -> Self: return self +class QuestionValidationConfiguration(ConfigurationBase): + """Question validation configuration.""" + + question_validation_enabled: bool = False + + class JsonPathOperator(str, Enum): """Supported operators for JSONPath evaluation.""" @@ -464,6 +470,7 @@ class Configuration(ConfigurationBase): authorization: Optional[AuthorizationConfiguration] = None customization: Optional[Customization] = None inference: InferenceConfiguration = InferenceConfiguration() + question_validation: QuestionValidationConfiguration = QuestionValidationConfiguration() def dump(self, filename: str = "configuration.json") -> None: """Dump actual configuration into JSON file.""" diff --git a/src/utils/agent.py b/src/utils/agent.py new file mode 100644 index 000000000..ddbce84cc --- /dev/null +++ b/src/utils/agent.py @@ -0,0 +1,100 @@ +"""Utility functions for agent management.""" + +from contextlib import suppress +import logging + +from fastapi import HTTPException, status +from llama_stack_client._client import AsyncLlamaStackClient +from llama_stack_client.lib.agents.agent import AsyncAgent + +from utils.suid import get_suid +from utils.types import GraniteToolParser + + +logger = logging.getLogger("utils.agent") + + +# pylint: disable=R0913,R0917 +async def get_agent( + client: AsyncLlamaStackClient, + model_id: str, + system_prompt: str, + available_input_shields: list[str], + available_output_shields: list[str], + conversation_id: str | None, + no_tools: bool = False, +) -> tuple[AsyncAgent, str, str]: + """Get existing agent or create a new one with session persistence.""" + existing_agent_id = None + if conversation_id: + with suppress(ValueError): + agent_response = await client.agents.retrieve(agent_id=conversation_id) + existing_agent_id = agent_response.agent_id + + logger.debug("Creating new agent") + agent = AsyncAgent( + client, # type: ignore[arg-type] + model=model_id, + instructions=system_prompt, + input_shields=available_input_shields if available_input_shields else [], + output_shields=available_output_shields if available_output_shields else [], + tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), + enable_session_persistence=True, + ) + await agent.initialize() + + if existing_agent_id and conversation_id: + orphan_agent_id = agent.agent_id + agent._agent_id = conversation_id # type: ignore[assignment] # pylint: disable=protected-access + await client.agents.delete(agent_id=orphan_agent_id) + sessions_response = await client.agents.session.list(agent_id=conversation_id) + logger.info("session response: %s", sessions_response) + try: + session_id = str(sessions_response.data[0]["session_id"]) + except IndexError as e: + logger.error("No sessions found for 
conversation %s", conversation_id) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "response": "Conversation not found", + "cause": f"Conversation {conversation_id} could not be retrieved.", + }, + ) from e + else: + conversation_id = agent.agent_id + session_id = await agent.create_session(get_suid()) + + return agent, conversation_id, session_id + + +async def get_temp_agent( + client: AsyncLlamaStackClient, + model_id: str, + system_prompt: str, +) -> tuple[AsyncAgent, str]: + """Create a temporary agent with new agent_id and session_id. + + This function creates a new agent without persistence, shields, or tools. + Useful for temporary operations or one-off queries, such as validating a question or generating a summary. + + Args: + client: The AsyncLlamaStackClient to use for the request. + model_id: The ID of the model to use. + system_prompt: The system prompt/instructions for the agent. + + Returns: + tuple[AsyncAgent, str]: A tuple containing the agent and session_id. + """ + logger.debug("Creating temporary agent") + agent = AsyncAgent( + client, # type: ignore[arg-type] + model=model_id, + instructions=system_prompt, + enable_session_persistence=False, # Temporary agent doesn't need persistence + ) + await agent.initialize() + + # Generate new IDs for the temporary agent + session_id = await agent.create_session(get_suid()) + + return agent, session_id diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index cd543d828..8ab916d21 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -1,10 +1,7 @@ """Utility functions for endpoint handlers.""" -from contextlib import suppress import logging from fastapi import HTTPException, status -from llama_stack_client._client import AsyncLlamaStackClient -from llama_stack_client.lib.agents.agent import AsyncAgent import constants from models.requests import QueryRequest @@ -12,8 +9,6 @@ from models.config import Action from app.database import get_session from configuration import AppConfig -from utils.suid import get_suid -from utils.types import GraniteToolParser logger = logging.getLogger("utils.endpoints") @@ -67,6 +62,60 @@ def check_configuration_loaded(config: AppConfig) -> None: detail={"response": "Configuration is not loaded"}, ) +QUESTION_VALIDATOR_PROMPT_TEMPLATE = f""" +Instructions: +- You are a question classifying tool +- You are an expert in Backstage, Red Hat Developer Hub (RHDH), Kubernetes, Openshift, CI/CD and GitOps Pipelines +- Your job is to determine if a user's question is related to Backstage or Red Hat Developer Hub (RHDH) technologies, \ + including integrations, plugins, catalog exploration, service creation, or workflow automation. +- If a question appears to be related to Backstage, RHDH, Kubernetes, Openshift, or any of their features, answer with the word {constants.SUBJECT_ALLOWED} +- If a question is not related to Backstage, RHDH, Kubernetes, Openshift, or their features, answer with the word {constants.SUBJECT_REJECTED} +- Do not explain your answer, just provide the one-word response + + +Example Question: +Why is the sky blue? +Example Response: +{constants.SUBJECT_REJECTED} + +Example Question: +Can you help configure my cluster to automatically scale? +Example Response: +{constants.SUBJECT_ALLOWED} + +Example Question: +How do I create import an existing software template in Backstage? +Example Response: +{constants.SUBJECT_ALLOWED} + +Example Question: +How do I accomplish $task in RHDH? 
+Example Response:
+{constants.SUBJECT_ALLOWED}
+
+Example Question:
+How do I explore a component in the RHDH catalog?
+Example Response:
+{constants.SUBJECT_ALLOWED}
+
+Example Question:
+How can I integrate GitOps into my pipeline?
+Example Response:
+{constants.SUBJECT_ALLOWED}
+
+Question:
+{{query}}
+Response:
+"""
+
+def get_validation_system_prompt() -> str:
+    """Get the validation system prompt."""
+    # return constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
+    return QUESTION_VALIDATOR_PROMPT_TEMPLATE
+
+def get_invalid_query_response() -> str:
+    """Get the invalid query response."""
+    return constants.DEFAULT_INVALID_QUERY_RESPONSE

 def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
     """Get the system prompt: the provided one, configured one, or default one."""
@@ -125,54 +174,3 @@ def validate_model_provider_override(
     )


-# # pylint: disable=R0913,R0917
-async def get_agent(
-    client: AsyncLlamaStackClient,
-    model_id: str,
-    system_prompt: str,
-    available_input_shields: list[str],
-    available_output_shields: list[str],
-    conversation_id: str | None,
-    no_tools: bool = False,
-) -> tuple[AsyncAgent, str, str]:
-    """Get existing agent or create a new one with session persistence."""
-    existing_agent_id = None
-    if conversation_id:
-        with suppress(ValueError):
-            agent_response = await client.agents.retrieve(agent_id=conversation_id)
-            existing_agent_id = agent_response.agent_id
-
-    logger.debug("Creating new agent")
-    agent = AsyncAgent(
-        client,  # type: ignore[arg-type]
-        model=model_id,
-        instructions=system_prompt,
-        input_shields=available_input_shields if available_input_shields else [],
-        output_shields=available_output_shields if available_output_shields else [],
-        tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
-        enable_session_persistence=True,
-    )
-    await agent.initialize()
-
-    if existing_agent_id and conversation_id:
-        orphan_agent_id = agent.agent_id
-        agent._agent_id = conversation_id  # type: ignore[assignment]  # pylint: disable=protected-access
-        await client.agents.delete(agent_id=orphan_agent_id)
-        sessions_response = await client.agents.session.list(agent_id=conversation_id)
-        logger.info("session response: %s", sessions_response)
-        try:
-            session_id = str(sessions_response.data[0]["session_id"])
-        except IndexError as e:
-            logger.error("No sessions found for conversation %s", conversation_id)
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND,
-                detail={
-                    "response": "Conversation not found",
-                    "cause": f"Conversation {conversation_id} could not be retrieved.",
-                },
-            ) from e
-    else:
-        conversation_id = agent.agent_id
-        session_id = await agent.create_session(get_suid())
-
-    return agent, conversation_id, session_id

From 230a7557b5bf2f83ae1208ddd930a596d5691f63 Mon Sep 17 00:00:00 2001
From: Stephanie
Date: Tue, 9 Sep 2025 12:31:53 -0400
Subject: [PATCH 2/7] update streaming endpoint

Signed-off-by: Stephanie
---
 src/app/endpoints/query.py           | 25 ++++----
 src/app/endpoints/streaming_query.py | 86 ++++++++++++++++++++--------
 src/utils/agent.py                   |  5 +-
 3 files changed, 81 insertions(+), 35 deletions(-)

diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 023e52ccb..99e5ff6fc 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -236,7 +236,7 @@ async def query_endpoint_handler(
             conversation_id=conversation_id,
             model_id=model_id,
             provider_id=provider_id,
-            query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
+            query_is_valid=query_is_valid,
            query=query_request.query,
             query_request=query_request,
             summary=summary,
@@ -393,7 +393,7 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)

-async def validate_question(question: str, client: AsyncLlamaStackClient, model_id: str) -> bool:
+async def validate_question(question: str, client: AsyncLlamaStackClient, model_id: str) -> tuple[bool, str]:
     """Validate a question and provide a one-word response.

     Args:
         question: The question to be validated.
@@ -405,7 +405,7 @@ async def validate_question(question: str, client: AsyncLlamaStackClient, model_
         bool: True if the question was deemed valid, False otherwise
     """
     validation_system_prompt = get_validation_system_prompt()
-    agent, session_id = await get_temp_agent(client, model_id, validation_system_prompt)
+    agent, session_id, conversation_id = await get_temp_agent(client, model_id, validation_system_prompt)
     response = await agent.create_turn(
         messages=[UserMessage(role="user", content=question)],
         session_id=session_id,
@@ -413,7 +413,7 @@ async def validate_question(question: str, client: AsyncLlamaStackClient, model_
         toolgroups=None,
     )
     response = cast(Turn, response)
-    return constants.SUBJECT_REJECTED not in interleaved_content_as_str(response.output_message.content)
+    return constants.SUBJECT_REJECTED not in interleaved_content_as_str(response.output_message.content), conversation_id

 async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches,too-many-arguments
     client: AsyncLlamaStackClient,
     model_id: str,
@@ -486,6 +486,17 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
     )

     logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)
+
+    # Validate the question if question validation is enabled
+    if configuration.question_validation.question_validation_enabled:
+        question_is_valid, _ = await validate_question(query_request.query, client, model_id)
+
+        if not question_is_valid:
+            return TurnSummary(
+                llm_response=get_invalid_query_response(),
+                tool_calls=[],
+            ), conversation_id, False
+
     # bypass tools and MCP servers if no_tools is True
     if query_request.no_tools:
         mcp_headers = {}
@@ -520,12 +531,6 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
     if not toolgroups:
         toolgroups = None

-    if configuration.question_validation.question_validation_enabled and (not await validate_question(query_request.query, client, model_id)):
-        return TurnSummary(
-            llm_response=get_invalid_query_response(),
-            tool_calls=[],
-        ), conversation_id, False
-
     response = await agent.create_turn(
         messages=[UserMessage(role="user", content=query_request.query)],
         session_id=session_id,
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 5069b2346..8bbc1ca44 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -31,7 +31,7 @@ from models.config import Action
 from models.requests import QueryRequest
 from models.database.conversations import UserConversation
 from utils.agent import get_agent
-from utils.endpoints import check_configuration_loaded, get_system_prompt
+from utils.endpoints import check_configuration_loaded, get_system_prompt, get_invalid_query_response
 from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
 from utils.transcripts import store_transcript
 from utils.types import TurnSummary
@@ -39,6 +39,7 @@

 from app.endpoints.query import (
     get_rag_toolgroups,
+    validate_question,
     is_input_shield,
     is_output_shield,
     is_transcripts_enabled,
@@ -588,6 +589,35 @@ 
async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
             user_conversation=user_conversation, query_request=query_request
         ),
     )
+
+    # Check question validation before getting response
+    query_is_valid = True
+    if configuration.question_validation.question_validation_enabled:
+        query_is_valid, temp_agent_conversation_id = await validate_question(query_request.query, client, llama_stack_model_id)
+        if not query_is_valid:
+            response = get_invalid_query_response()
+            if not is_transcripts_enabled():
+                logger.debug("Transcript collection is disabled in the configuration")
+            else:
+                summary = TurnSummary(
+                    llm_response=response,
+                    tool_calls=[],
+                )
+                store_transcript(
+                    user_id=user_id,
+                    conversation_id=query_request.conversation_id or temp_agent_conversation_id,
+                    model_id=model_id,
+                    provider_id=provider_id,
+                    query_is_valid=query_is_valid,
+                    query=query_request.query,
+                    query_request=query_request,
+                    summary=summary,
+                    rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                    truncated=False,  # TODO(lucasagomes): implement truncation as part
+                    # of quota work
+                    attachments=query_request.attachments or [],
+                )
+            return StreamingResponse(iter([response]))
     response, conversation_id = await retrieve_response(
         client,
         llama_stack_model_id,
@@ -599,6 +629,7 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals

     async def response_generator(
         turn_response: AsyncIterator[AgentTurnResponseStreamChunk],
+        query_is_valid: bool,
     ) -> AsyncIterator[str]:
         """
         Generate SSE formatted streaming response.
@@ -618,27 +649,34 @@ async def response_generator(

         # Send start event
         yield stream_start_event(conversation_id)
-
-        async for chunk in turn_response:
-            p = chunk.event.payload
-            if p.event_type == "turn_complete":
-                summary.llm_response = interleaved_content_as_str(
-                    p.turn.output_message.content
-                )
-                system_prompt = get_system_prompt(query_request, configuration)
-                try:
-                    update_llm_token_count_from_turn(
-                        p.turn, model_id, provider_id, system_prompt
-                    )
-                except Exception:  # pylint: disable=broad-except
-                    logger.exception("Failed to update token usage metrics")
-            elif p.event_type == "step_complete":
-                if p.step_details.step_type == "tool_execution":
-                    summary.append_tool_calls_from_llama(p.step_details)
-
-        for event in stream_build_event(chunk, chunk_id, metadata_map):
-            chunk_id += 1
-            yield event
+
+        if not query_is_valid:
+            # Generate SSE events for invalid query
+            yield format_stream_data({
+                "event": "token",
+                "data": {"id": 0, "token": get_invalid_query_response()}
+            })
+        else:
+            async for chunk in turn_response:
+                p = chunk.event.payload
+                if p.event_type == "turn_complete":
+                    summary.llm_response = interleaved_content_as_str(
+                        p.turn.output_message.content
+                    )
+                    system_prompt = get_system_prompt(query_request, configuration)
+                    try:
+                        update_llm_token_count_from_turn(
+                            p.turn, model_id, provider_id, system_prompt
+                        )
+                    except Exception:  # pylint: disable=broad-except
+                        logger.exception("Failed to update token usage metrics")
+                elif p.event_type == "step_complete":
+                    if p.step_details.step_type == "tool_execution":
+                        summary.append_tool_calls_from_llama(p.step_details)
+
+                for event in stream_build_event(chunk, chunk_id, metadata_map):
+                    chunk_id += 1
+                    yield event

         yield stream_end_event(metadata_map)

@@ -650,7 +688,7 @@ async def response_generator(
             conversation_id=conversation_id,
             model_id=model_id,
             provider_id=provider_id,
-            query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
+            query_is_valid=query_is_valid,
query=query_request.query, query_request=query_request, summary=summary, @@ -670,7 +708,7 @@ async def response_generator( # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() - return StreamingResponse(response_generator(response)) + return StreamingResponse(response_generator(response, query_is_valid)) # connection to Llama Stack server except APIConnectionError as e: # Update metrics for the LLM call failure @@ -751,6 +789,8 @@ async def retrieve_response( ) logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id) + + # bypass tools and MCP servers if no_tools is True if query_request.no_tools: mcp_headers = {} diff --git a/src/utils/agent.py b/src/utils/agent.py index ddbce84cc..8780c1ee8 100644 --- a/src/utils/agent.py +++ b/src/utils/agent.py @@ -71,7 +71,7 @@ async def get_temp_agent( client: AsyncLlamaStackClient, model_id: str, system_prompt: str, -) -> tuple[AsyncAgent, str]: +) -> tuple[AsyncAgent, str, str]: """Create a temporary agent with new agent_id and session_id. This function creates a new agent without persistence, shields, or tools. @@ -95,6 +95,7 @@ async def get_temp_agent( await agent.initialize() # Generate new IDs for the temporary agent + conversation_id = agent.agent_id session_id = await agent.create_session(get_suid()) - return agent, session_id + return agent, session_id, conversation_id From 008fa923915ebdf6f5b770f86e3cbd764bdef102 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Wed, 10 Sep 2025 17:38:58 -0400 Subject: [PATCH 3/7] fix existing unit tests Signed-off-by: Stephanie --- tests/unit/app/endpoints/test_query.py | 56 ++- .../app/endpoints/test_streaming_query.py | 3 + .../models/config/test_dump_configuration.py | 4 + tests/unit/utils/test_agent.py | 455 ++++++++++++++++++ tests/unit/utils/test_endpoints.py | 444 ----------------- 5 files changed, 499 insertions(+), 463 deletions(-) create mode 100644 tests/unit/utils/test_agent.py diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 3b3d64f3f..8786cc2c2 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -152,6 +152,8 @@ async def _test_query_endpoint_handler( mock_config.user_data_collection_configuration.transcripts_enabled = ( store_transcript_to_file ) + # Mock question validation configuration to be disabled + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) summary = TurnSummary( @@ -170,7 +172,7 @@ async def _test_query_endpoint_handler( mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id), + return_value=(summary, conversation_id, True), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -447,6 +449,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -458,7 +461,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - response, _ = await retrieve_response( + response, _, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -479,6 +482,7 @@ async def 
test_retrieve_response_message_without_content(prepare_agent_mocks, mo # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -490,7 +494,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo model_id = "fake_model_id" access_token = "test_token" - response, _ = await retrieve_response( + response, _, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -512,6 +516,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -523,7 +528,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -551,6 +556,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -562,7 +568,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -601,6 +607,7 @@ def __repr__(self): # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -612,7 +619,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -654,6 +661,7 @@ def __repr__(self): # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -665,7 +673,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -709,6 +717,7 @@ def __repr__(self): # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.query.get_agent", @@ 
-720,7 +729,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -758,6 +767,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) attachments = [ @@ -777,7 +787,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -808,6 +818,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) attachments = [ @@ -832,7 +843,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -877,6 +888,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.query.get_agent", @@ -888,7 +900,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -947,6 +959,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.query.get_agent", @@ -958,7 +971,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( model_id = "fake_model_id" access_token = "" # Empty token - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1009,6 +1022,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.query.get_agent", @@ -1030,7 +1044,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( }, } - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, @@ -1106,6 
+1120,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): # Mock configuration with empty MCP servers mock_config = mocker.Mock() mock_config.mcp_servers = [] + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -1115,7 +1130,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): query_request = QueryRequest(query="What is OpenStack?") - _, conversation_id = await retrieve_response( + _, conversation_id, _ = await retrieve_response( mock_client, "fake_model_id", query_request, "test_token" ) @@ -1177,6 +1192,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_requ # Mock dependencies mock_config = mocker.Mock() mock_config.llama_stack_configuration = mocker.Mock() + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mock_client = mocker.AsyncMock() @@ -1200,7 +1216,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_requ ) mock_retrieve_response = mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, "test_conversation_id"), + return_value=(summary, "test_conversation_id", True), ) mocker.patch( @@ -1251,7 +1267,7 @@ async def test_query_endpoint_handler_no_tools_true(mocker, dummy_request): mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id), + return_value=(summary, conversation_id, True), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1302,7 +1318,7 @@ async def test_query_endpoint_handler_no_tools_false(mocker, dummy_request): mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id), + return_value=(summary, conversation_id, True), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1343,6 +1359,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -1354,7 +1371,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1394,6 +1411,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( ] mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( "app.endpoints.query.get_agent", @@ -1405,7 +1423,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 38983666a..9d896d184 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ 
-1283,6 +1283,7 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): # Mock dependencies mock_config = mocker.Mock() mock_config.llama_stack_configuration = mocker.Mock() + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_client = mocker.AsyncMock() @@ -1337,6 +1338,7 @@ async def test_streaming_query_endpoint_handler_no_tools_true(mocker): mock_config = mocker.Mock() mock_config.user_data_collection_configuration.transcripts_disabled = True + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.streaming_query.configuration", mock_config) # Mock the streaming response @@ -1384,6 +1386,7 @@ async def test_streaming_query_endpoint_handler_no_tools_false(mocker): mock_config = mocker.Mock() mock_config.user_data_collection_configuration.transcripts_disabled = True + mock_config.question_validation.question_validation_enabled = False mocker.patch("app.endpoints.streaming_query.configuration", mock_config) # Mock the streaming response diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 303b17998..7aad93045 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -88,6 +88,7 @@ def test_dump_configuration(tmp_path) -> None: assert "customization" in content assert "inference" in content assert "database" in content + assert "question_validation" in content # check the whole deserialized JSON file content assert content == { @@ -163,6 +164,9 @@ def test_dump_configuration(tmp_path) -> None: }, }, "authorization": None, + "question_validation": { + "question_validation_enabled": False + }, } diff --git a/tests/unit/utils/test_agent.py b/tests/unit/utils/test_agent.py new file mode 100644 index 000000000..3d1a9e44d --- /dev/null +++ b/tests/unit/utils/test_agent.py @@ -0,0 +1,455 @@ +"""Unit tests for agent utility functions.""" + +import pytest + +from configuration import AppConfig +from tests.unit import config_dict + +from utils.agent import get_agent + + +@pytest.fixture(name="setup_configuration") +def setup_configuration_fixture(): + """Set up configuration for tests.""" + test_config_dict = { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "transcripts_enabled": False, + }, + "mcp_servers": [], + } + cfg = AppConfig() + cfg.init_from_dict(test_config_dict) + return cfg + + +@pytest.mark.asyncio +async def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): + """Test get_agent function when agent exists in llama stack.""" + mock_client, mock_agent = prepare_agent_mocks + conversation_id = "test_conversation_id" + + # Mock existing agent retrieval + mock_agent_response = mocker.Mock() + mock_agent_response.agent_id = conversation_id + mock_client.agents.retrieve.return_value = mock_agent_response + + mock_client.agents.session.list.return_value = mocker.Mock( + data=[{"session_id": "test_session_id"}] + ) + + # Mock Agent class + mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", 
+ system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=conversation_id, + ) + + # Assert the same agent is returned and conversation_id is preserved + assert result_agent == mock_agent + assert result_conversation_id == conversation_id + assert result_session_id == "test_session_id" + + +@pytest.mark.asyncio +async def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function when conversation_id is provided.""" + mock_client, mock_agent = prepare_agent_mocks + mock_client.agents.retrieve.side_effect = ValueError( + "fake not finding existing agent" + ) + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.agent.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + conversation_id = "non_existent_conversation_id" + # Call function with conversation_id + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=conversation_id, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert conversation_id != result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with correct parameters + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_conversation_id( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function when conversation_id is None.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.agent.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with None conversation_id + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with correct parameters + 
mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_empty_shields( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function with empty shields list.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.agent.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with empty shields list + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=[], + available_output_shields=[], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with empty shields + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=[], + output_shields=[], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_multiple_mcp_servers( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function with multiple MCP servers.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.agent.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration with multiple MCP servers + mock_mcp_server1 = mocker.Mock() + mock_mcp_server1.name = "mcp_server_1" + mock_mcp_server2 = mocker.Mock() + mock_mcp_server2.name = "mcp_server_2" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server1, mock_mcp_server2], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1", "shield2"], + available_output_shields=["output_shield3", "output_shield4"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with tools from both MCP servers + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1", "shield2"], + output_shields=["output_shield3", "output_shield4"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def 
test_get_agent_session_persistence_enabled( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function ensures session persistence is enabled.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.agent.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function + await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + ) + + # Verify Agent was created with session persistence enabled + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_tools_no_parser( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function sets tool_parser=None when no_tools=True.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.agent.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with no_tools=True + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=True, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with tool_parser=None + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=None, + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_no_tools_false_preserves_parser( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function preserves tool_parser when no_tools=False.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.agent.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="new_session_id") + + # Mock GraniteToolParser + mock_parser = mocker.Mock() + mock_granite_parser = 
mocker.patch("utils.agent.GraniteToolParser") + mock_granite_parser.get_parser.return_value = mock_parser + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("configuration.configuration", setup_configuration) + + # Call function with no_tools=False + result_agent, result_conversation_id, result_session_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_input_shields=["shield1"], + available_output_shields=["output_shield2"], + conversation_id=None, + no_tools=False, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" + + # Verify Agent was created with the proper tool_parser + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + output_shields=["output_shield2"], + tool_parser=mock_parser, + enable_session_persistence=True, + ) diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py index d9a496d07..5feeabd9d 100644 --- a/tests/unit/utils/test_endpoints.py +++ b/tests/unit/utils/test_endpoints.py @@ -11,7 +11,6 @@ from models.requests import QueryRequest from models.config import Action from utils import endpoints -from utils.endpoints import get_agent CONFIGURED_SYSTEM_PROMPT = "This is a configured system prompt" @@ -82,32 +81,6 @@ def query_request_with_system_prompt_fixture(): return QueryRequest(query="query", system_prompt="System prompt defined in query") -@pytest.fixture(name="setup_configuration") -def setup_configuration_fixture(): - """Set up configuration for tests.""" - test_config_dict = { - "name": "test", - "service": { - "host": "localhost", - "port": 8080, - "auth_enabled": False, - "workers": 1, - "color_log": True, - "access_log": True, - }, - "llama_stack": { - "api_key": "test-key", - "url": "http://test.com:1234", - "use_as_library_client": False, - }, - "user_data_collection": { - "transcripts_enabled": False, - }, - "mcp_servers": [], - } - cfg = AppConfig() - cfg.init_from_dict(test_config_dict) - return cfg def test_get_default_system_prompt( @@ -175,423 +148,6 @@ def test_get_system_prompt_with_disable_query_system_prompt_and_non_system_promp assert system_prompt == CONFIGURED_SYSTEM_PROMPT -@pytest.mark.asyncio -async def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): - """Test get_agent function when agent exists in llama stack.""" - mock_client, mock_agent = prepare_agent_mocks - conversation_id = "test_conversation_id" - - # Mock existing agent retrieval - mock_agent_response = mocker.Mock() - mock_agent_response.agent_id = conversation_id - mock_client.agents.retrieve.return_value = mock_agent_response - - mock_client.agents.session.list.return_value = mocker.Mock( - data=[{"session_id": "test_session_id"}] - ) - - # Mock Agent class - mocker.patch("utils.endpoints.AsyncAgent", return_value=mock_agent) - - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=conversation_id, - ) - - # Assert the same agent is returned and conversation_id 
is preserved - assert result_agent == mock_agent - assert result_conversation_id == conversation_id - assert result_session_id == "test_session_id" - - -@pytest.mark.asyncio -async def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function when conversation_id is provided.""" - mock_client, mock_agent = prepare_agent_mocks - mock_client.agents.retrieve.side_effect = ValueError( - "fake not finding existing agent" - ) - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "utils.endpoints.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("configuration.configuration", setup_configuration) - conversation_id = "non_existent_conversation_id" - # Call function with conversation_id - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=conversation_id, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert conversation_id != result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with correct parameters - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_no_conversation_id( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function when conversation_id is None.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "utils.endpoints.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("configuration.configuration", setup_configuration) - - # Call function with None conversation_id - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with correct parameters - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - 
enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_empty_shields( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function with empty shields list.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "utils.endpoints.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("configuration.configuration", setup_configuration) - - # Call function with empty shields list - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=[], - available_output_shields=[], - conversation_id=None, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with empty shields - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=[], - output_shields=[], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_multiple_mcp_servers( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function with multiple MCP servers.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "utils.endpoints.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") - - # Mock configuration with multiple MCP servers - mock_mcp_server1 = mocker.Mock() - mock_mcp_server1.name = "mcp_server_1" - mock_mcp_server2 = mocker.Mock() - mock_mcp_server2.name = "mcp_server_2" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server1, mock_mcp_server2], - ) - mocker.patch("configuration.configuration", setup_configuration) - - # Call function - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1", "shield2"], - available_output_shields=["output_shield3", "output_shield4"], - conversation_id=None, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with tools from both MCP servers - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1", "shield2"], - output_shields=["output_shield3", "output_shield4"], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_session_persistence_enabled( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function ensures session persistence is enabled.""" - mock_client, 
mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "utils.endpoints.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("configuration.configuration", setup_configuration) - - # Call function - await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - ) - - # Verify Agent was created with session persistence enabled - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_no_tools_no_parser( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function sets tool_parser=None when no_tools=True.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "utils.endpoints.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") - - # Mock configuration - mock_mcp_server = mocker.Mock() - mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("configuration.configuration", setup_configuration) - - # Call function with no_tools=True - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - no_tools=True, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with tool_parser=None - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=None, - enable_session_persistence=True, - ) - - -@pytest.mark.asyncio -async def test_get_agent_no_tools_false_preserves_parser( - setup_configuration, prepare_agent_mocks, mocker -): - """Test get_agent function preserves tool_parser when no_tools=False.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_session.return_value = "new_session_id" - - # Mock Agent class - mock_agent_class = mocker.patch( - "utils.endpoints.AsyncAgent", return_value=mock_agent - ) - - # Mock get_suid - mocker.patch("utils.endpoints.get_suid", return_value="new_session_id") - - # Mock GraniteToolParser - mock_parser = mocker.Mock() - mock_granite_parser = mocker.patch("utils.endpoints.GraniteToolParser") - mock_granite_parser.get_parser.return_value = mock_parser - - # Mock configuration - mock_mcp_server = mocker.Mock() - 
mock_mcp_server.name = "mcp_server_1" - mocker.patch.object( - type(setup_configuration), - "mcp_servers", - new_callable=mocker.PropertyMock, - return_value=[mock_mcp_server], - ) - mocker.patch("configuration.configuration", setup_configuration) - - # Call function with no_tools=False - result_agent, result_conversation_id, result_session_id = await get_agent( - client=mock_client, - model_id="test_model", - system_prompt="test_prompt", - available_input_shields=["shield1"], - available_output_shields=["output_shield2"], - conversation_id=None, - no_tools=False, - ) - - # Assert new agent is created - assert result_agent == mock_agent - assert result_conversation_id == result_agent.agent_id - assert result_session_id == "new_session_id" - - # Verify Agent was created with the proper tool_parser - mock_agent_class.assert_called_once_with( - mock_client, - model="test_model", - instructions="test_prompt", - input_shields=["shield1"], - output_shields=["output_shield2"], - tool_parser=mock_parser, - enable_session_persistence=True, - ) def test_validate_model_provider_override_allowed_with_action(): From fdcaf43c9d3eaff5dc9b475305197b5012baadb8 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Thu, 11 Sep 2025 17:38:48 -0400 Subject: [PATCH 4/7] add query tests Signed-off-by: Stephanie --- tests/unit/app/endpoints/test_query.py | 126 ++++++++++++ .../app/endpoints/test_streaming_query.py | 181 ++++++++++++++++++ tests/unit/utils/test_agent.py | 104 +++++++++- 3 files changed, 410 insertions(+), 1 deletion(-) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 8786cc2c2..f5f419617 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -1604,3 +1604,129 @@ async def test_query_endpoint_rejects_model_provider_override_without_permission ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert exc_info.value.detail["response"] == expected_msg + + +@pytest.mark.asyncio +async def test_query_endpoint_with_question_validation_invalid_query( + setup_configuration, dummy_request, mocker +): + """Test query endpoint with question validation enabled and invalid query.""" + # Mock metrics + mock_metrics(mocker) + + # Mock database operations + mock_database_operations(mocker) + + # Setup configuration with question validation enabled + setup_configuration.question_validation.question_validation_enabled = True + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Mock the validation agent response (invalid query) + mock_validation_agent = mocker.AsyncMock() + mock_validation_agent.agent_id = "validation_agent_id" + mock_validation_agent.create_session.return_value = "validation_session_id" + + # Mock the validation response that contains SUBJECT_REJECTED + mock_validation_turn = mocker.Mock() + mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}] + mock_validation_agent.create_turn.return_value = mock_validation_turn + + # Mock the main agent (should not be called for invalid queries) + mock_agent = mocker.AsyncMock() + mock_agent.agent_id = "conversation_id" + mock_agent.create_session.return_value = "session_id" + + # Mock the client + mock_client = mocker.AsyncMock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") + ] + mocker.patch("client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client) + + # Mock the retrieve_response function to return invalid query 
response + summary = TurnSummary( + llm_response="Invalid query response", + tool_calls=[], + ) + conversation_id = "fake_conversation_id" + + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=(summary, conversation_id, False), # query_is_valid=False + ) + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), + ) + mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + + query_request = QueryRequest(query="Invalid question about unrelated topic") + + response = await query_endpoint_handler( + request=dummy_request, + query_request=query_request, + auth=MOCK_AUTH + ) + + # Verify the response contains the invalid query response + assert response.conversation_id == conversation_id + assert response.response == "Invalid query response" + + +@pytest.mark.asyncio +async def test_query_endpoint_with_question_validation_valid_query( + setup_configuration, dummy_request, mocker +): + """Test query endpoint with question validation enabled and valid query.""" + # Mock metrics + mock_metrics(mocker) + + # Mock database operations + mock_database_operations(mocker) + + # Setup configuration with question validation enabled + setup_configuration.question_validation.question_validation_enabled = True + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Mock the client + mock_client = mocker.AsyncMock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") + ] + mocker.patch("client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client) + + # Mock the retrieve_response function to return valid query response + summary = TurnSummary( + llm_response="Valid LLM response about Backstage and Kubernetes", + tool_calls=[ + ToolCallSummary( + id="123", + name="test-tool", + args="testing", + response="tool response", + ) + ], + ) + conversation_id = "fake_conversation_id" + + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=(summary, conversation_id, True), # query_is_valid=True + ) + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), + ) + mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + + query_request = QueryRequest(query="Valid question about Backstage and Kubernetes") + + response = await query_endpoint_handler( + request=dummy_request, + query_request=query_request, + auth=MOCK_AUTH + ) + + # Verify the response contains the normal LLM response + assert response.conversation_id == conversation_id + assert response.response == "Valid LLM response about Backstage and Kubernetes" diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 9d896d184..173694b06 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1587,3 +1587,184 @@ async def test_streaming_query_endpoint_rejects_model_provider_override_without_ ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert exc_info.value.detail["response"] == expected_msg + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_with_question_validation_invalid_query( + setup_configuration, mocker +): + """Test streaming query endpoint with question validation enabled and invalid query.""" + # Mock metrics + mock_metrics(mocker) + + # Mock 
database operations + mock_database_operations(mocker) + + # Setup configuration with question validation enabled + setup_configuration.question_validation.question_validation_enabled = True + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Mock the client + mock_client = mocker.AsyncMock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") + ] + mocker.patch("client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client) + + # Mock the validation agent response (invalid query) + mock_validation_agent = mocker.AsyncMock() + mock_validation_agent.agent_id = "validation_agent_id" + mock_validation_agent.create_session.return_value = "validation_session_id" + + # Mock the validation response that contains SUBJECT_REJECTED + mock_validation_turn = mocker.Mock() + mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}] + mock_validation_agent.create_turn.return_value = mock_validation_turn + + # Mock the validation functions + mocker.patch("app.endpoints.streaming_query.validate_question", return_value=(False, "validation_agent_id")) + mocker.patch("app.endpoints.streaming_query.get_invalid_query_response", return_value="Invalid query response") + mocker.patch("app.endpoints.streaming_query.interleaved_content_as_str", return_value="REJECTED") + + # Mock other dependencies + mocker.patch("app.endpoints.streaming_query.validate_model_provider_override") + mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") + mocker.patch("app.endpoints.streaming_query.is_transcripts_enabled", return_value=False) + mocker.patch( + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), + ) + + query_request = QueryRequest(query="Invalid question about unrelated topic") + + request = Request( + scope={ + "type": "http", + } + ) + request.state.authorized_actions = set(Action) + + response = await streaming_query_endpoint_handler( + request=request, + query_request=query_request, + auth=MOCK_AUTH + ) + + # Verify the response is a StreamingResponse + assert isinstance(response, StreamingResponse) + + # Collect the streaming response content + streaming_content = [] + async for chunk in response.body_iterator: + streaming_content.append(chunk) + + # Convert to string for assertions + full_content = "".join(streaming_content) + + # Verify the response contains the invalid query response + assert "Invalid query response" in full_content + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_with_question_validation_valid_query( + setup_configuration, mocker +): + """Test streaming query endpoint with question validation enabled and valid query.""" + # Mock metrics + mock_metrics(mocker) + + # Mock database operations + mock_database_operations(mocker) + + # Setup configuration with question validation enabled + setup_configuration.question_validation.question_validation_enabled = True + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Mock the client + mock_client = mocker.AsyncMock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") + ] + mocker.patch("client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client) + + # Mock the validation functions to return valid query + mocker.patch("app.endpoints.streaming_query.validate_question", 
return_value=(True, "validation_agent_id")) + mocker.patch("app.endpoints.streaming_query.interleaved_content_as_str", return_value="Valid question about Backstage") + + # Mock the retrieve_response function to return valid streaming response + mock_streaming_response = mocker.AsyncMock() + mock_streaming_response.__aiter__.return_value = [ + AgentTurnResponseStreamChunk( + event=TurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( + event_type="step_progress", + step_type="inference", + delta=TextDelta(text="Valid LLM response about Backstage and Kubernetes", type="text"), + step_id="s1", + ) + ) + ), + AgentTurnResponseStreamChunk( + event=TurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + event_type="turn_complete", + turn=Turn( + turn_id="t1", + input_messages=[], + output_message=CompletionMessage( + role="assistant", + content=[TextContentItem(text="Valid LLM response about Backstage and Kubernetes", type="text")], + stop_reason="end_of_turn", + ), + session_id="test_session_id", + started_at=datetime.now(), + steps=[], + completed_at=datetime.now(), + output_attachments=[], + ), + ) + ) + ), + ] + + mocker.patch( + "app.endpoints.streaming_query.retrieve_response", + return_value=(mock_streaming_response, "test_conversation_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), + ) + mocker.patch("app.endpoints.streaming_query.is_transcripts_enabled", return_value=False) + + query_request = QueryRequest(query="Valid question about Backstage and Kubernetes") + + request = Request( + scope={ + "type": "http", + } + ) + request.state.authorized_actions = set(Action) + + response = await streaming_query_endpoint_handler( + request=request, + query_request=query_request, + auth=MOCK_AUTH + ) + + # Verify the response is a StreamingResponse + assert isinstance(response, StreamingResponse) + + # Collect the streaming response content + streaming_content = [] + async for chunk in response.body_iterator: + streaming_content.append(chunk) + + # Convert to string for assertions + full_content = "".join(streaming_content) + + # Verify the response contains the normal LLM response + assert "Valid LLM response about Backstage and Kubernetes" in full_content + assert '"event": "start"' in full_content + assert '"event": "token"' in full_content + assert '"event": "end"' in full_content diff --git a/tests/unit/utils/test_agent.py b/tests/unit/utils/test_agent.py index 3d1a9e44d..f8d89bcc8 100644 --- a/tests/unit/utils/test_agent.py +++ b/tests/unit/utils/test_agent.py @@ -5,7 +5,7 @@ from configuration import AppConfig from tests.unit import config_dict -from utils.agent import get_agent +from utils.agent import get_agent, get_temp_agent @pytest.fixture(name="setup_configuration") @@ -453,3 +453,105 @@ async def test_get_agent_no_tools_false_preserves_parser( tool_parser=mock_parser, enable_session_persistence=True, ) + + +@pytest.mark.asyncio +async def test_get_temp_agent_basic_functionality(prepare_agent_mocks, mocker): + """Test get_temp_agent function creates agent with correct parameters.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "temp_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.agent.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="temp_session_id") + + # Call function + result_agent, result_session_id, 
result_conversation_id = await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Assert agent, session_id, and conversation_id are created and returned + assert result_agent == mock_agent + assert result_session_id == "temp_session_id" + assert result_conversation_id == mock_agent.agent_id + + # Verify Agent was created with correct parameters for temporary agent + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + enable_session_persistence=False, # Key difference: no persistence + ) + + # Verify agent was initialized and session was created + mock_agent.initialize.assert_called_once() + mock_agent.create_session.assert_called_once_with("temp_session_id") + + +@pytest.mark.asyncio +async def test_get_temp_agent_returns_valid_ids(prepare_agent_mocks, mocker): + """Test get_temp_agent function returns valid agent_id and session_id.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.agent_id = "generated_agent_id" + mock_agent.create_session.return_value = "generated_session_id" + + # Mock Agent class + mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="generated_session_id") + + # Call function + result_agent, result_session_id, result_conversation_id = await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Assert all three values are returned and are not None/empty + assert result_agent is not None + assert result_session_id is not None + assert result_conversation_id is not None + + # Assert they are strings + assert isinstance(result_session_id, str) + assert isinstance(result_conversation_id, str) + + # Assert conversation_id matches agent_id + assert result_conversation_id == result_agent.agent_id + + +@pytest.mark.asyncio +async def test_get_temp_agent_no_persistence(prepare_agent_mocks, mocker): + """Test get_temp_agent function creates agent without session persistence.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "temp_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.agent.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.agent.get_suid", return_value="temp_session_id") + + # Call function + await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Verify Agent was created with session persistence disabled + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + enable_session_persistence=False, + ) From 67a24061d95226c9b5b546ae619ada47f78ced6e Mon Sep 17 00:00:00 2001 From: Stephanie Date: Fri, 12 Sep 2025 15:57:37 -0400 Subject: [PATCH 5/7] fix format Signed-off-by: Stephanie --- src/app/endpoints/query.py | 35 ++++-- src/app/endpoints/streaming_query.py | 32 +++-- src/constants.py | 4 +- src/models/config.py | 4 +- src/utils/agent.py | 10 +- src/utils/endpoints.py | 8 +- tests/unit/app/endpoints/test_query.py | 54 ++++---- .../app/endpoints/test_streaming_query.py | 115 +++++++++++------- .../models/config/test_dump_configuration.py | 4 +- tests/unit/utils/test_agent.py | 41 ++----- tests/unit/utils/test_endpoints.py | 4 - 11 files changed, 170 insertions(+), 141 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 99e5ff6fc..a7085d360 100644 --- 
a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -217,7 +217,7 @@ async def query_endpoint_handler(
             user_conversation=user_conversation, query_request=query_request
         ),
     )
-    summary, conversation_id, query_is_valid= await retrieve_response(
+    summary, conversation_id, query_is_valid = await retrieve_response(
         client,
         llama_stack_model_id,
         query_request,
@@ -393,7 +393,9 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
-async def validate_question(question: str, client: AsyncLlamaStackClient, model_id: str) -> tuple[bool, str]:
+async def validate_question(
+    question: str, client: AsyncLlamaStackClient, model_id: str
+) -> tuple[bool, str]:
     """Validate a question and provide a one-word response.
 
     Args:
@@ -405,7 +407,9 @@ async def validate_question(question: str, client: AsyncLlamaStackClient, model_
         tuple[bool, str]: True if the question was deemed valid (False otherwise) and the validation agent's conversation ID.
     """
     validation_system_prompt = get_validation_system_prompt()
-    agent, session_id, conversation_id = await get_temp_agent(client, model_id, validation_system_prompt)
+    agent, session_id, conversation_id = await get_temp_agent(
+        client, model_id, validation_system_prompt
+    )
     response = await agent.create_turn(
         messages=[UserMessage(role="user", content=question)],
         session_id=session_id,
@@ -413,7 +417,12 @@ async def validate_question(question: str, client: AsyncLlamaStackClient, model_
         toolgroups=None,
     )
     response = cast(Turn, response)
-    return constants.SUBJECT_REJECTED not in interleaved_content_as_str(response.output_message.content), conversation_id
+    return (
+        constants.SUBJECT_REJECTED
+        not in interleaved_content_as_str(response.output_message.content),
+        conversation_id,
+    )
+
 
 async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches,too-many-arguments
     client: AsyncLlamaStackClient,
     model_id: str,
@@ -489,13 +498,19 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
 
     # Validate the question if question validation is enabled
     if configuration.question_validation.question_validation_enabled:
-        question_is_valid, _ = await validate_question(query_request.query, client, model_id)
-
+        question_is_valid, _ = await validate_question(
+            query_request.query, client, model_id
+        )
+
         if not question_is_valid:
-            return TurnSummary(
-                llm_response=get_invalid_query_response(),
-                tool_calls=[],
-            ), conversation_id, False
+            return (
+                TurnSummary(
+                    llm_response=get_invalid_query_response(),
+                    tool_calls=[],
+                ),
+                conversation_id,
+                False,
+            )
 
     # bypass tools and MCP servers if no_tools is True
     if query_request.no_tools:
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 8bbc1ca44..664625160 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -31,7 +31,11 @@
 from models.requests import QueryRequest
 from models.database.conversations import UserConversation
 from utils.agent import get_agent
-from utils.endpoints import check_configuration_loaded, get_system_prompt, get_invalid_query_response
+from utils.endpoints import (
+    check_configuration_loaded,
+    get_system_prompt,
+    get_invalid_query_response,
+)
 from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
 from utils.transcripts import store_transcript
 from utils.types import TurnSummary
@@ -589,15 +593,19 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
             user_conversation=user_conversation, query_request=query_request
         ),
     )
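
# Illustrative sketch only, not part of these patches: how the non-streaming
# handler consumes the new three-value contract of retrieve_response() from the
# query.py hunks above. The auth-token argument is passed as an empty
# placeholder and the QueryResponse construction is schematic (the tests above
# only confirm its conversation_id and response attributes exist).
from app.endpoints.query import retrieve_response
from models.requests import QueryRequest
from models.responses import QueryResponse


async def handle_query_sketch(client, llama_stack_model_id: str, query_request: QueryRequest):
    summary, conversation_id, query_is_valid = await retrieve_response(
        client, llama_stack_model_id, query_request, ""  # "" stands in for the auth token
    )
    if not query_is_valid:
        # For rejected questions the summary already carries the configured
        # invalid-query response; no real agent turn was created.
        return QueryResponse(conversation_id=conversation_id, response=summary.llm_response)
    # ...persistence, transcript storage, and metrics continue as before...
    return QueryResponse(conversation_id=conversation_id, response=summary.llm_response)
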
- + # Check question validation before getting response query_is_valid = True if configuration.question_validation.question_validation_enabled: - query_is_valid,temp_agent_conversation_id = await validate_question(query_request.query, client, llama_stack_model_id) + query_is_valid, temp_agent_conversation_id = await validate_question( + query_request.query, client, llama_stack_model_id + ) if not query_is_valid: response = get_invalid_query_response() if not is_transcripts_enabled(): - logger.debug("Transcript collection is disabled in the configuration") + logger.debug( + "Transcript collection is disabled in the configuration" + ) else: summary = TurnSummary( llm_response=response, @@ -605,7 +613,8 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals ) store_transcript( user_id=user_id, - conversation_id = query_request.conversation_id or temp_agent_conversation_id, + conversation_id=query_request.conversation_id + or temp_agent_conversation_id, model_id=model_id, provider_id=provider_id, query_is_valid=query_is_valid, @@ -649,13 +658,15 @@ async def response_generator( # Send start event yield stream_start_event(conversation_id) - + if not query_is_valid: # Generate SSE events for invalid query - yield format_stream_data({ - "event": "token", - "data": {"id": 0, "token": get_invalid_query_response()} - }) + yield format_stream_data( + { + "event": "token", + "data": {"id": 0, "token": get_invalid_query_response()}, + } + ) else: async for chunk in turn_response: p = chunk.event.payload @@ -790,7 +801,6 @@ async def retrieve_response( logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id) - # bypass tools and MCP servers if no_tools is True if query_request.no_tools: mcp_headers = {} diff --git a/src/constants.py b/src/constants.py index 907bf0f1d..594333e59 100644 --- a/src/constants.py +++ b/src/constants.py @@ -38,9 +38,7 @@ f"return '{SUBJECT_REJECTED}' if the question is not valid and " f"'{SUBJECT_ALLOWED}' if it is valid." ) -DEFAULT_INVALID_QUERY_RESPONSE = ( - "Invalid query, please try again." -) +DEFAULT_INVALID_QUERY_RESPONSE = "Invalid query, please try again." # Authentication constants diff --git a/src/models/config.py b/src/models/config.py index 773dac1f3..c37275d2b 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -472,7 +472,9 @@ class Configuration(ConfigurationBase): authorization: Optional[AuthorizationConfiguration] = None customization: Optional[Customization] = None inference: InferenceConfiguration = Field(default_factory=InferenceConfiguration) - question_validation: QuestionValidationConfiguration = Field(default_factory=QuestionValidationConfiguration) + question_validation: QuestionValidationConfiguration = Field( + default_factory=QuestionValidationConfiguration + ) def dump(self, filename: str = "configuration.json") -> None: """Dump actual configuration into JSON file.""" diff --git a/src/utils/agent.py b/src/utils/agent.py index 8780c1ee8..a58690977 100644 --- a/src/utils/agent.py +++ b/src/utils/agent.py @@ -73,15 +73,15 @@ async def get_temp_agent( system_prompt: str, ) -> tuple[AsyncAgent, str, str]: """Create a temporary agent with new agent_id and session_id. - + This function creates a new agent without persistence, shields, or tools. Useful for temporary operations or one-off queries, such as validating a question or generating a summary. - + Args: client: The AsyncLlamaStackClient to use for the request. model_id: The ID of the model to use. 
system_prompt: The system prompt/instructions for the agent.
-
+
     Returns:
         tuple[AsyncAgent, str, str]: A tuple containing the agent, session_id, and conversation_id.
     """
@@ -93,9 +93,9 @@ async def get_temp_agent(
         enable_session_persistence=False,  # Temporary agent doesn't need persistence
     )
     await agent.initialize()
-
+
     # Generate new IDs for the temporary agent
     conversation_id = agent.agent_id
     session_id = await agent.create_session(get_suid())
-
+
     return agent, session_id, conversation_id
diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py
index 8ab916d21..9ca510da3 100644
--- a/src/utils/endpoints.py
+++ b/src/utils/endpoints.py
@@ -62,6 +62,7 @@ def check_configuration_loaded(config: AppConfig) -> None:
         detail={"response": "Configuration is not loaded"},
     )
 
+
 QUESTION_VALIDATOR_PROMPT_TEMPLATE = f"""
 Instructions:
 - You are a question classifying tool
@@ -108,15 +109,18 @@ def check_configuration_loaded(config: AppConfig) -> None:
 Response:
 """
 
+
 def get_validation_system_prompt() -> str:
     """Get the validation system prompt."""
-    #return constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
+    # return constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
     return QUESTION_VALIDATOR_PROMPT_TEMPLATE
 
+
 def get_invalid_query_response() -> str:
     """Get the invalid query response."""
     return constants.DEFAULT_INVALID_QUERY_RESPONSE
 
+
 def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
     """Get the system prompt: the provided one, configured one, or default one."""
     system_prompt_disabled = (
@@ -172,5 +176,3 @@ def validate_model_provider_override(
             )
         },
     )
-
-
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index f5f419617..074e04379 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -1613,43 +1613,45 @@ async def test_query_endpoint_with_question_validation_invalid_query(
     """Test query endpoint with question validation enabled and invalid query."""
     # Mock metrics
     mock_metrics(mocker)
-
+
     # Mock database operations
     mock_database_operations(mocker)
-
+
     # Setup configuration with question validation enabled
     setup_configuration.question_validation.question_validation_enabled = True
     mocker.patch("app.endpoints.query.configuration", setup_configuration)
-
+
     # Mock the validation agent response (invalid query)
     mock_validation_agent = mocker.AsyncMock()
     mock_validation_agent.agent_id = "validation_agent_id"
     mock_validation_agent.create_session.return_value = "validation_session_id"
-
+
     # Mock the validation response that contains SUBJECT_REJECTED
     mock_validation_turn = mocker.Mock()
     mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}]
     mock_validation_agent.create_turn.return_value = mock_validation_turn
-
+
     # Mock the main agent (should not be called for invalid queries)
     mock_agent = mocker.AsyncMock()
     mock_agent.agent_id = "conversation_id"
     mock_agent.create_session.return_value = "session_id"
-
+
     # Mock the client
     mock_client = mocker.AsyncMock()
     mock_client.models.list.return_value = [
         mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1")
     ]
-    mocker.patch("client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client)
-
+    mocker.patch(
+        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
+    )
+
     # Mock the retrieve_response function to return invalid query response
     summary = TurnSummary(
         llm_response="Invalid query response",
         tool_calls=[],
     )
     conversation_id = "fake_conversation_id"
-
+
     mocker.patch(
"app.endpoints.query.retrieve_response", return_value=(summary, conversation_id, False), # query_is_valid=False @@ -1659,15 +1661,13 @@ async def test_query_endpoint_with_question_validation_invalid_query( return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) - + query_request = QueryRequest(query="Invalid question about unrelated topic") - + response = await query_endpoint_handler( - request=dummy_request, - query_request=query_request, - auth=MOCK_AUTH + request=dummy_request, query_request=query_request, auth=MOCK_AUTH ) - + # Verify the response contains the invalid query response assert response.conversation_id == conversation_id assert response.response == "Invalid query response" @@ -1680,21 +1680,23 @@ async def test_query_endpoint_with_question_validation_valid_query( """Test query endpoint with question validation enabled and valid query.""" # Mock metrics mock_metrics(mocker) - + # Mock database operations mock_database_operations(mocker) - + # Setup configuration with question validation enabled setup_configuration.question_validation.question_validation_enabled = True mocker.patch("app.endpoints.query.configuration", setup_configuration) - + # Mock the client mock_client = mocker.AsyncMock() mock_client.models.list.return_value = [ mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") ] - mocker.patch("client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client) - + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + # Mock the retrieve_response function to return valid query response summary = TurnSummary( llm_response="Valid LLM response about Backstage and Kubernetes", @@ -1708,7 +1710,7 @@ async def test_query_endpoint_with_question_validation_valid_query( ], ) conversation_id = "fake_conversation_id" - + mocker.patch( "app.endpoints.query.retrieve_response", return_value=(summary, conversation_id, True), # query_is_valid=True @@ -1718,15 +1720,13 @@ async def test_query_endpoint_with_question_validation_valid_query( return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) - + query_request = QueryRequest(query="Valid question about Backstage and Kubernetes") - + response = await query_endpoint_handler( - request=dummy_request, - query_request=query_request, - auth=MOCK_AUTH + request=dummy_request, query_request=query_request, auth=MOCK_AUTH ) - + # Verify the response contains the normal LLM response assert response.conversation_id == conversation_id assert response.response == "Valid LLM response about Backstage and Kubernetes" diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 173694b06..595894b52 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1596,71 +1596,82 @@ async def test_streaming_query_endpoint_with_question_validation_invalid_query( """Test streaming query endpoint with question validation enabled and invalid query.""" # Mock metrics mock_metrics(mocker) - + # Mock database operations mock_database_operations(mocker) - + # Setup configuration with question validation enabled setup_configuration.question_validation.question_validation_enabled = True mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - + # Mock the client 
mock_client = mocker.AsyncMock() mock_client.models.list.return_value = [ mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") ] - mocker.patch("client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client) - + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + # Mock the validation agent response (invalid query) mock_validation_agent = mocker.AsyncMock() mock_validation_agent.agent_id = "validation_agent_id" mock_validation_agent.create_session.return_value = "validation_session_id" - + # Mock the validation response that contains SUBJECT_REJECTED mock_validation_turn = mocker.Mock() mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}] mock_validation_agent.create_turn.return_value = mock_validation_turn - + # Mock the validation functions - mocker.patch("app.endpoints.streaming_query.validate_question", return_value=(False, "validation_agent_id")) - mocker.patch("app.endpoints.streaming_query.get_invalid_query_response", return_value="Invalid query response") - mocker.patch("app.endpoints.streaming_query.interleaved_content_as_str", return_value="REJECTED") - + mocker.patch( + "app.endpoints.streaming_query.validate_question", + return_value=(False, "validation_agent_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.get_invalid_query_response", + return_value="Invalid query response", + ) + mocker.patch( + "app.endpoints.streaming_query.interleaved_content_as_str", + return_value="REJECTED", + ) + # Mock other dependencies mocker.patch("app.endpoints.streaming_query.validate_model_provider_override") mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") - mocker.patch("app.endpoints.streaming_query.is_transcripts_enabled", return_value=False) + mocker.patch( + "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False + ) mocker.patch( "app.endpoints.streaming_query.select_model_and_provider_id", return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), ) - + query_request = QueryRequest(query="Invalid question about unrelated topic") - + request = Request( scope={ "type": "http", } ) request.state.authorized_actions = set(Action) - + response = await streaming_query_endpoint_handler( - request=request, - query_request=query_request, - auth=MOCK_AUTH + request=request, query_request=query_request, auth=MOCK_AUTH ) - + # Verify the response is a StreamingResponse assert isinstance(response, StreamingResponse) - + # Collect the streaming response content streaming_content = [] async for chunk in response.body_iterator: streaming_content.append(chunk) - + # Convert to string for assertions full_content = "".join(streaming_content) - + # Verify the response contains the invalid query response assert "Invalid query response" in full_content @@ -1672,25 +1683,33 @@ async def test_streaming_query_endpoint_with_question_validation_valid_query( """Test streaming query endpoint with question validation enabled and valid query.""" # Mock metrics mock_metrics(mocker) - + # Mock database operations mock_database_operations(mocker) - + # Setup configuration with question validation enabled setup_configuration.question_validation.question_validation_enabled = True mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - + # Mock the client mock_client = mocker.AsyncMock() mock_client.models.list.return_value = [ mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") ] - 
mocker.patch("client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client) - + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + # Mock the validation functions to return valid query - mocker.patch("app.endpoints.streaming_query.validate_question", return_value=(True, "validation_agent_id")) - mocker.patch("app.endpoints.streaming_query.interleaved_content_as_str", return_value="Valid question about Backstage") - + mocker.patch( + "app.endpoints.streaming_query.validate_question", + return_value=(True, "validation_agent_id"), + ) + mocker.patch( + "app.endpoints.streaming_query.interleaved_content_as_str", + return_value="Valid question about Backstage", + ) + # Mock the retrieve_response function to return valid streaming response mock_streaming_response = mocker.AsyncMock() mock_streaming_response.__aiter__.return_value = [ @@ -1699,7 +1718,10 @@ async def test_streaming_query_endpoint_with_question_validation_valid_query( payload=AgentTurnResponseStepProgressPayload( event_type="step_progress", step_type="inference", - delta=TextDelta(text="Valid LLM response about Backstage and Kubernetes", type="text"), + delta=TextDelta( + text="Valid LLM response about Backstage and Kubernetes", + type="text", + ), step_id="s1", ) ) @@ -1713,7 +1735,12 @@ async def test_streaming_query_endpoint_with_question_validation_valid_query( input_messages=[], output_message=CompletionMessage( role="assistant", - content=[TextContentItem(text="Valid LLM response about Backstage and Kubernetes", type="text")], + content=[ + TextContentItem( + text="Valid LLM response about Backstage and Kubernetes", + type="text", + ) + ], stop_reason="end_of_turn", ), session_id="test_session_id", @@ -1726,7 +1753,7 @@ async def test_streaming_query_endpoint_with_question_validation_valid_query( ) ), ] - + mocker.patch( "app.endpoints.streaming_query.retrieve_response", return_value=(mock_streaming_response, "test_conversation_id"), @@ -1735,34 +1762,34 @@ async def test_streaming_query_endpoint_with_question_validation_valid_query( "app.endpoints.streaming_query.select_model_and_provider_id", return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), ) - mocker.patch("app.endpoints.streaming_query.is_transcripts_enabled", return_value=False) - + mocker.patch( + "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False + ) + query_request = QueryRequest(query="Valid question about Backstage and Kubernetes") - + request = Request( scope={ "type": "http", } ) request.state.authorized_actions = set(Action) - + response = await streaming_query_endpoint_handler( - request=request, - query_request=query_request, - auth=MOCK_AUTH + request=request, query_request=query_request, auth=MOCK_AUTH ) - + # Verify the response is a StreamingResponse assert isinstance(response, StreamingResponse) - + # Collect the streaming response content streaming_content = [] async for chunk in response.body_iterator: streaming_content.append(chunk) - + # Convert to string for assertions full_content = "".join(streaming_content) - + # Verify the response contains the normal LLM response assert "Valid LLM response about Backstage and Kubernetes" in full_content assert '"event": "start"' in full_content diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 7aad93045..a97e2f884 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ 
-164,9 +164,7 @@ def test_dump_configuration(tmp_path) -> None: }, }, "authorization": None, - "question_validation": { - "question_validation_enabled": False - }, + "question_validation": {"question_validation_enabled": False}, } diff --git a/tests/unit/utils/test_agent.py b/tests/unit/utils/test_agent.py index f8d89bcc8..b81893b7d 100644 --- a/tests/unit/utils/test_agent.py +++ b/tests/unit/utils/test_agent.py @@ -3,7 +3,6 @@ import pytest from configuration import AppConfig -from tests.unit import config_dict from utils.agent import get_agent, get_temp_agent @@ -81,9 +80,7 @@ async def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( mock_agent.create_session.return_value = "new_session_id" # Mock Agent class - mock_agent_class = mocker.patch( - "utils.agent.AsyncAgent", return_value=mock_agent - ) + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) # Mock get_suid mocker.patch("utils.agent.get_suid", return_value="new_session_id") @@ -136,9 +133,7 @@ async def test_get_agent_no_conversation_id( mock_agent.create_session.return_value = "new_session_id" # Mock Agent class - mock_agent_class = mocker.patch( - "utils.agent.AsyncAgent", return_value=mock_agent - ) + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) # Mock get_suid mocker.patch("utils.agent.get_suid", return_value="new_session_id") @@ -190,9 +185,7 @@ async def test_get_agent_empty_shields( mock_agent.create_session.return_value = "new_session_id" # Mock Agent class - mock_agent_class = mocker.patch( - "utils.agent.AsyncAgent", return_value=mock_agent - ) + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) # Mock get_suid mocker.patch("utils.agent.get_suid", return_value="new_session_id") @@ -244,9 +237,7 @@ async def test_get_agent_multiple_mcp_servers( mock_agent.create_session.return_value = "new_session_id" # Mock Agent class - mock_agent_class = mocker.patch( - "utils.agent.AsyncAgent", return_value=mock_agent - ) + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) # Mock get_suid mocker.patch("utils.agent.get_suid", return_value="new_session_id") @@ -300,9 +291,7 @@ async def test_get_agent_session_persistence_enabled( mock_agent.create_session.return_value = "new_session_id" # Mock Agent class - mock_agent_class = mocker.patch( - "utils.agent.AsyncAgent", return_value=mock_agent - ) + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) # Mock get_suid mocker.patch("utils.agent.get_suid", return_value="new_session_id") @@ -349,9 +338,7 @@ async def test_get_agent_no_tools_no_parser( mock_agent.create_session.return_value = "new_session_id" # Mock Agent class - mock_agent_class = mocker.patch( - "utils.agent.AsyncAgent", return_value=mock_agent - ) + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) # Mock get_suid mocker.patch("utils.agent.get_suid", return_value="new_session_id") @@ -404,9 +391,7 @@ async def test_get_agent_no_tools_false_preserves_parser( mock_agent.create_session.return_value = "new_session_id" # Mock Agent class - mock_agent_class = mocker.patch( - "utils.agent.AsyncAgent", return_value=mock_agent - ) + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) # Mock get_suid mocker.patch("utils.agent.get_suid", return_value="new_session_id") @@ -462,9 +447,7 @@ async def test_get_temp_agent_basic_functionality(prepare_agent_mocks, mocker): 
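
# Condensed, illustrative round trip of the validator that the tests below
# exercise -- a sketch, not code from the patches. All names are taken from the
# query.py and utils/agent.py diffs above; get_validation_system_prompt() is
# the zero-argument form this patch uses, and the temporary agent is discarded
# after the single classification turn.
from typing import cast

from llama_stack_client import AsyncLlamaStackClient
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
from llama_stack_client.types import UserMessage
from llama_stack_client.types.agents.turn import Turn

import constants
from utils.agent import get_temp_agent
from utils.endpoints import get_validation_system_prompt


async def classify_question_sketch(client: AsyncLlamaStackClient, question: str) -> bool:
    """Return True when the validator model does not answer with SUBJECT_REJECTED."""
    agent, session_id, _conversation_id = await get_temp_agent(
        client, "test_model", get_validation_system_prompt()  # model id is a placeholder
    )
    turn = cast(
        Turn,
        await agent.create_turn(
            messages=[UserMessage(role="user", content=question)],
            session_id=session_id,
            stream=False,
            toolgroups=None,
        ),
    )
    # One-word verdict: any occurrence of SUBJECT_REJECTED marks the question
    # as out of scope.
    return constants.SUBJECT_REJECTED not in interleaved_content_as_str(
        turn.output_message.content
    )
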
mock_agent.create_session.return_value = "temp_session_id" # Mock Agent class - mock_agent_class = mocker.patch( - "utils.agent.AsyncAgent", return_value=mock_agent - ) + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) # Mock get_suid mocker.patch("utils.agent.get_suid", return_value="temp_session_id") @@ -518,11 +501,11 @@ async def test_get_temp_agent_returns_valid_ids(prepare_agent_mocks, mocker): assert result_agent is not None assert result_session_id is not None assert result_conversation_id is not None - + # Assert they are strings assert isinstance(result_session_id, str) assert isinstance(result_conversation_id, str) - + # Assert conversation_id matches agent_id assert result_conversation_id == result_agent.agent_id @@ -534,9 +517,7 @@ async def test_get_temp_agent_no_persistence(prepare_agent_mocks, mocker): mock_agent.create_session.return_value = "temp_session_id" # Mock Agent class - mock_agent_class = mocker.patch( - "utils.agent.AsyncAgent", return_value=mock_agent - ) + mock_agent_class = mocker.patch("utils.agent.AsyncAgent", return_value=mock_agent) # Mock get_suid mocker.patch("utils.agent.get_suid", return_value="temp_session_id") diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py index 5feeabd9d..96870cf61 100644 --- a/tests/unit/utils/test_endpoints.py +++ b/tests/unit/utils/test_endpoints.py @@ -81,8 +81,6 @@ def query_request_with_system_prompt_fixture(): return QueryRequest(query="query", system_prompt="System prompt defined in query") - - def test_get_default_system_prompt( config_without_system_prompt, query_request_without_system_prompt ): @@ -148,8 +146,6 @@ def test_get_system_prompt_with_disable_query_system_prompt_and_non_system_promp assert system_prompt == CONFIGURED_SYSTEM_PROMPT - - def test_validate_model_provider_override_allowed_with_action(): """Ensure no exception when caller has MODEL_OVERRIDE and request includes model/provider.""" query_request = QueryRequest(query="q", model="m", provider="p") From dabe81600b88acba3b72c538574c80244cbc81f0 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 15 Sep 2025 12:53:01 -0400 Subject: [PATCH 6/7] update util functions with profile customization and with tests Signed-off-by: Stephanie --- lightspeed-stack.yaml | 5 +- src/app/endpoints/query.py | 4 +- src/app/endpoints/streaming_query.py | 49 +++++------ src/models/config.py | 8 ++ src/utils/endpoints.py | 72 +++++---------- tests/unit/utils/test_endpoints.py | 125 ++++++++++++++++++++++++++- 6 files changed, 176 insertions(+), 87 deletions(-) diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml index 0948f4e2c..9ac7f63a1 100644 --- a/lightspeed-stack.yaml +++ b/lightspeed-stack.yaml @@ -13,7 +13,7 @@ llama_stack: # Alternative for "as library use" # use_as_library_client: true # library_client_config_path: - url: http://localhost:8321 + url: http://llama-stack:8321 api_key: xyzzy user_data_collection: feedback_enabled: true @@ -23,6 +23,3 @@ user_data_collection: authentication: module: "noop" - -question_validation: - question_validation_enabled: true \ No newline at end of file diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 035558a46..99bc2ac09 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -406,7 +406,7 @@ async def validate_question( Returns: bool: True if the question was deemed valid, False otherwise """ - validation_system_prompt = get_validation_system_prompt() + validation_system_prompt = 
get_validation_system_prompt(configuration) agent, session_id, conversation_id = await get_temp_agent( client, model_id, validation_system_prompt ) @@ -505,7 +505,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche if not question_is_valid: return ( TurnSummary( - llm_response=get_invalid_query_response(), + llm_response=get_invalid_query_response(configuration), tool_calls=[], ), conversation_id, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index c31d516a5..053388d4f 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -601,7 +601,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals query_request.query, client, llama_stack_model_id ) if not query_is_valid: - response = get_invalid_query_response() + response = get_invalid_query_response(configuration) if not is_transcripts_enabled(): logger.debug( "Transcript collection is disabled in the configuration" @@ -659,35 +659,26 @@ async def response_generator( # Send start event yield stream_start_event(conversation_id) - if not query_is_valid: - # Generate SSE events for invalid query - yield format_stream_data( - { - "event": "token", - "data": {"id": 0, "token": get_invalid_query_response()}, - } - ) - else: - async for chunk in turn_response: - p = chunk.event.payload - if p.event_type == "turn_complete": - summary.llm_response = interleaved_content_as_str( - p.turn.output_message.content + async for chunk in turn_response: + p = chunk.event.payload + if p.event_type == "turn_complete": + summary.llm_response = interleaved_content_as_str( + p.turn.output_message.content + ) + system_prompt = get_system_prompt(query_request, configuration) + try: + update_llm_token_count_from_turn( + p.turn, model_id, provider_id, system_prompt ) - system_prompt = get_system_prompt(query_request, configuration) - try: - update_llm_token_count_from_turn( - p.turn, model_id, provider_id, system_prompt - ) - except Exception: # pylint: disable=broad-except - logger.exception("Failed to update token usage metrics") - elif p.event_type == "step_complete": - if p.step_details.step_type == "tool_execution": - summary.append_tool_calls_from_llama(p.step_details) - - for event in stream_build_event(chunk, chunk_id, metadata_map): - chunk_id += 1 - yield event + except Exception: # pylint: disable=broad-except + logger.exception("Failed to update token usage metrics") + elif p.event_type == "step_complete": + if p.step_details.step_type == "tool_execution": + summary.append_tool_calls_from_llama(p.step_details) + + for event in stream_build_event(chunk, chunk_id, metadata_map): + chunk_id += 1 + yield event yield stream_end_event(metadata_map) diff --git a/src/models/config.py b/src/models/config.py index e4c9ad2d1..9f093e410 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -427,6 +427,7 @@ class CustomProfile: path: str prompts: dict[str, str] = Field(default={}, init=False) + query_responses: dict[str, str] = Field(default={}, init=False) def __post_init__(self) -> None: """Validate and load profile.""" @@ -438,11 +439,18 @@ def _validate_and_process(self) -> None: profile_module = checks.import_python_module("profile", self.path) if profile_module is not None and checks.is_valid_profile(profile_module): self.prompts = profile_module.PROFILE_CONFIG.get("system_prompts", {}) + self.query_responses = profile_module.PROFILE_CONFIG.get( + "query_responses", {} + ) def get_prompts(self) -> dict[str, str]: 
"""Retrieve prompt attribute.""" return self.prompts + def get_query_responses(self) -> dict[str, str]: + """Retrieve query responses attribute.""" + return self.query_responses + class Customization(ConfigurationBase): """Service customization.""" diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 9ec146856..c45fe5e61 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -63,61 +63,31 @@ def check_configuration_loaded(config: AppConfig) -> None: ) -QUESTION_VALIDATOR_PROMPT_TEMPLATE = f""" -Instructions: -- You are a question classifying tool -- You are an expert in Backstage, Red Hat Developer Hub (RHDH), Kubernetes, Openshift, CI/CD and GitOps Pipelines -- Your job is to determine if a user's question is related to Backstage or Red Hat Developer Hub (RHDH) technologies, \ - including integrations, plugins, catalog exploration, service creation, or workflow automation. -- If a question appears to be related to Backstage, RHDH, Kubernetes, Openshift, or any of their features, answer with the word {constants.SUBJECT_ALLOWED} -- If a question is not related to Backstage, RHDH, Kubernetes, Openshift, or their features, answer with the word {constants.SUBJECT_REJECTED} -- Do not explain your answer, just provide the one-word response - - -Example Question: -Why is the sky blue? -Example Response: -{constants.SUBJECT_REJECTED} - -Example Question: -Can you help configure my cluster to automatically scale? -Example Response: -{constants.SUBJECT_ALLOWED} - -Example Question: -How do I create import an existing software template in Backstage? -Example Response: -{constants.SUBJECT_ALLOWED} - -Example Question: -How do I accomplish $task in RHDH? -Example Response: -{constants.SUBJECT_ALLOWED} - -Example Question: -How do I explore a component in RHDH catalog? -Example Response: -{constants.SUBJECT_ALLOWED} - -Example Question: -How can I integrate GitOps into my pipeline? 
-Example Response:
-{constants.SUBJECT_ALLOWED}
-
-Question:
-{{query}}
-Response:
-"""
-
-def get_validation_system_prompt() -> str:
+def get_validation_system_prompt(config: AppConfig) -> str:
     """Get the validation system prompt."""
-    # return constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
-    return QUESTION_VALIDATOR_PROMPT_TEMPLATE
+    # profile takes precedence for setting prompt
+    if (
+        config.customization is not None
+        and config.customization.custom_profile is not None
+    ):
+        prompt = config.customization.custom_profile.get_prompts().get("validation")
+        if prompt:
+            return prompt
+
+    return constants.DEFAULT_VALIDATION_SYSTEM_PROMPT


-def get_invalid_query_response() -> str:
+def get_invalid_query_response(config: AppConfig) -> str:
     """Get the invalid query response."""
+    if (
+        config.customization is not None
+        and config.customization.custom_profile is not None
+    ):
+        prompt = config.customization.custom_profile.get_query_responses().get(
+            "invalid_resp"
+        )
+        if prompt:
+            return prompt
     return constants.DEFAULT_INVALID_QUERY_RESPONSE
diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py
index ba6a86807..c12693e80 100644
--- a/tests/unit/utils/test_endpoints.py
+++ b/tests/unit/utils/test_endpoints.py
@@ -209,7 +209,6 @@ def test_get_profile_prompt_with_enabled_query_system_prompt(
     assert system_prompt == query_request_with_system_prompt.system_prompt


-
 def test_validate_model_provider_override_allowed_with_action():
     """Ensure no exception when caller has MODEL_OVERRIDE and request includes model/provider."""
     query_request = QueryRequest(query="q", model="m", provider="p")
@@ -230,3 +229,127 @@ def test_validate_model_provider_override_no_override_without_action():
     """No exception when request does not include model/provider regardless of permission."""
     query_request = QueryRequest(query="q")
     endpoints.validate_model_provider_override(query_request, set())
+
+
+# Tests for get_validation_system_prompt
+
+
+def test_get_default_validation_system_prompt(config_without_system_prompt):
+    """Test that default validation system prompt is returned when no custom profile is provided."""
+    validation_prompt = endpoints.get_validation_system_prompt(
+        config_without_system_prompt
+    )
+    assert validation_prompt == constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
+
+
+def test_get_validation_system_prompt_with_custom_profile():
+    """Test that validation system prompt from custom profile is returned when available."""
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
+
+    validation_prompt = endpoints.get_validation_system_prompt(cfg)
+
+    # Get the expected prompt from the test profile
+    custom_profile = CustomProfile(path="tests/profiles/test/profile.py")
+    expected_prompt = custom_profile.get_prompts().get("validation")
+
+    assert validation_prompt == expected_prompt
+
+
+def test_get_validation_system_prompt_with_custom_profile_no_validation_prompt():
+    """Test that default validation system prompt is returned when custom profile has no validation prompt."""
+    # Create a test profile that doesn't have validation prompt
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
+
+    # Manually set the prompts to not include validation
+    cfg.customization.custom_profile.prompts = {"default": "test prompt"}
+
+    validation_prompt = endpoints.get_validation_system_prompt(cfg)
+    assert validation_prompt == constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
+
+
+def test_get_validation_system_prompt_with_custom_profile_empty_validation_prompt():
+    """Test that default validation system prompt is returned when custom profile has empty validation prompt."""
+    # Create a test profile that has empty validation prompt
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
+
+    # Manually set the prompts to have empty validation prompt
+    cfg.customization.custom_profile.prompts = {"validation": ""}
+
+    validation_prompt = endpoints.get_validation_system_prompt(cfg)
+    assert validation_prompt == constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
+
+
+# Tests for get_invalid_query_response
+
+
+def test_get_default_invalid_query_response(config_without_system_prompt):
+    """Test that default invalid query response is returned when no custom profile is provided."""
+    invalid_response = endpoints.get_invalid_query_response(
+        config_without_system_prompt
+    )
+    assert invalid_response == constants.DEFAULT_INVALID_QUERY_RESPONSE
+
+
+def test_get_invalid_query_response_with_custom_profile():
+    """Test that invalid query response from custom profile is returned when available."""
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
+
+    invalid_response = endpoints.get_invalid_query_response(cfg)
+
+    # Get the expected response from the test profile
+    custom_profile = CustomProfile(path="tests/profiles/test/profile.py")
+    expected_response = custom_profile.get_query_responses().get("invalid_resp")
+
+    assert invalid_response == expected_response
+
+
+def test_get_invalid_query_response_with_custom_profile_no_invalid_resp():
+    """Test that default invalid query response is returned when custom profile has no invalid_resp."""
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
+
+    # Manually set the query_responses to not include invalid_resp
+    cfg.customization.custom_profile.query_responses = {}
+
+    invalid_response = endpoints.get_invalid_query_response(cfg)
+    assert invalid_response == constants.DEFAULT_INVALID_QUERY_RESPONSE
+
+
+def test_get_invalid_query_response_with_custom_profile_empty_invalid_resp():
+    """Test that default invalid query response is returned when custom profile has empty invalid_resp."""
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
+
+    # Manually set the query_responses to have empty invalid_resp
+    cfg.customization.custom_profile.query_responses = {"invalid_resp": ""}
+
+    invalid_response = endpoints.get_invalid_query_response(cfg)
+    assert invalid_response == constants.DEFAULT_INVALID_QUERY_RESPONSE

From 1110d88bb2f6741b8b443c0bffb838a355c9a2dd Mon Sep 17 00:00:00 2001
From: Stephanie
Date: Mon, 15 Sep 2025 12:53:41 -0400
Subject: [PATCH 7/7] generate openapi json

Signed-off-by: Stephanie
---
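Reviewer note (this comment zone is ignored by git am): the "query_responses"
and "QuestionValidationConfiguration" entries in the regenerated schema below
mirror the config-model changes earlier in this series. For reference, a
minimal sketch of the PROFILE_CONFIG shape that CustomProfile reads; the key
names ("system_prompts"/"validation", "query_responses"/"invalid_resp") come
from this series, while the strings are purely illustrative:

    # Illustrative custom profile module; the shipped test profile may differ.
    PROFILE_CONFIG = {
        "system_prompts": {
            # consumed by get_validation_system_prompt()
            "validation": "Answer ALLOWED or REJECTED for each question.",
        },
        "query_responses": {
            # returned verbatim by get_invalid_query_response()
            "invalid_resp": "I can only answer questions about supported topics.",
        },
    }
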
 docs/openapi.json | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/docs/openapi.json b/docs/openapi.json
index 96b4725c8..cff61b422 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1052,6 +1052,9 @@
                 },
                 "inference": {
                     "$ref": "#/components/schemas/InferenceConfiguration"
+                },
+                "question_validation": {
+                    "$ref": "#/components/schemas/QuestionValidationConfiguration"
                 }
             },
             "additionalProperties": false,
@@ -1286,6 +1289,14 @@
                     "type": "object",
                     "title": "Prompts",
                     "default": {}
+                },
+                "query_responses": {
+                    "additionalProperties": {
+                        "type": "string"
+                    },
+                    "type": "object",
+                    "title": "Query Responses",
+                    "default": {}
                 }
             },
             "type": "object",
@@ -2279,6 +2290,19 @@
                     }
                 ]
             },
+            "QuestionValidationConfiguration": {
+                "properties": {
+                    "question_validation_enabled": {
+                        "type": "boolean",
+                        "title": "Question Validation Enabled",
+                        "default": false
+                    }
+                },
+                "additionalProperties": false,
+                "type": "object",
+                "title": "QuestionValidationConfiguration",
+                "description": "Question validation configuration."
+            },
             "ReadinessResponse": {
                 "properties": {
                     "ready": {