diff --git a/docker-compose-library.yaml b/docker-compose-library.yaml
index 4733d5d6..cbdd2de5 100644
--- a/docker-compose-library.yaml
+++ b/docker-compose-library.yaml
@@ -18,7 +18,7 @@ services:
       - TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY:-}
       # OpenAI
       - OPENAI_API_KEY=${OPENAI_API_KEY}
-      - E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL:-gpt-4-turbo}
+      - E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL:-gpt-4o-mini}
       # Azure
       - AZURE_API_KEY=${AZURE_API_KEY:-}
       # RHAIIS
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 3b00c381..8122011a 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -16,7 +16,7 @@ services:
       - TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY:-}
       # OpenAI
       - OPENAI_API_KEY=${OPENAI_API_KEY}
-      - E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL}
+      - E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL:-gpt-4o-mini}
       # Azure
       - AZURE_API_KEY=${AZURE_API_KEY}
       # RHAIIS
diff --git a/run.yaml b/run.yaml
index 7787c93d..4deba67b 100644
--- a/run.yaml
+++ b/run.yaml
@@ -131,8 +131,15 @@ storage:
       namespace: prompts
       backend: kv_default
 registered_resources:
-  models: []
-  shields: []
+  models:
+  - model_id: gpt-4o-mini
+    provider_id: openai
+    model_type: llm
+    provider_model_id: gpt-4o-mini
+  shields:
+  - shield_id: llama-guard
+    provider_id: llama-guard
+    provider_shield_id: openai/gpt-4o-mini
   vector_dbs: []
   datasets: []
   scoring_fns: []
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index a41ba9fb..952c0a0b 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -8,11 +8,11 @@
 from typing import Annotated, Any, Optional, cast
 
 from fastapi import APIRouter, Depends, HTTPException, Request
-from litellm.exceptions import RateLimitError
 from llama_stack_client import (
     APIConnectionError,
     APIStatusError,
-    AsyncLlamaStackClient,  # type: ignore
+    AsyncLlamaStackClient,
+    RateLimitError,  # type: ignore
 )
 from llama_stack_client.types import Shield, UserMessage  # type: ignore
 from llama_stack_client.types.alpha.agents.turn import Turn
diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py
index 6fce4537..dd4eef8f 100644
--- a/src/app/endpoints/query_v2.py
+++ b/src/app/endpoints/query_v2.py
@@ -10,7 +10,7 @@
 from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObject,
 )
-from llama_stack_client import AsyncLlamaStackClient  # type: ignore
+from llama_stack_client import AsyncLlamaStackClient
 
 import metrics
 from app.endpoints.query import (
@@ -42,7 +42,10 @@
 )
 from utils.mcp_headers import mcp_headers_dependency
 from utils.responses import extract_text_from_response_output_item
-from utils.shields import detect_shield_violations, get_available_shields
+from utils.shields import (
+    append_turn_to_conversation,
+    run_shield_moderation,
+)
 from utils.suid import normalize_conversation_id, to_llama_stack_conversation_id
 from utils.token_counter import TokenCounter
 from utils.types import RAGChunk, ToolCallSummary, ToolResultSummary, TurnSummary
@@ -322,9 +325,6 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
         and the conversation ID, the list of parsed referenced documents,
         and token usage information.
""" - # List available shields for Responses API - available_shields = await get_available_shields(client) - # use system prompt from request or default one system_prompt = get_system_prompt(query_request, configuration) logger.debug("Using system prompt: %s", system_prompt) @@ -370,6 +370,26 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche conversation_id, ) + # Run shield moderation before calling LLM + moderation_result = await run_shield_moderation(client, input_text) + if moderation_result.blocked: + violation_message = moderation_result.message or "" + await append_turn_to_conversation( + client, llama_stack_conv_id, input_text, violation_message + ) + summary = TurnSummary( + llm_response=violation_message, + tool_calls=[], + tool_results=[], + rag_chunks=[], + ) + return ( + summary, + normalize_conversation_id(conversation_id), + [], + TokenCounter(), + ) + # Create OpenAI response using responses API create_kwargs: dict[str, Any] = { "input": input_text, @@ -381,10 +401,6 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche "conversation": llama_stack_conv_id, } - # Add shields to extra_body if available - if available_shields: - create_kwargs["extra_body"] = {"guardrails": available_shields} - response = await client.responses.create(**create_kwargs) response = cast(OpenAIResponseObject, response) logger.info("Response: %s", response) @@ -410,9 +426,6 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche if tool_result: tool_results.append(tool_result) - # Check for shield violations across all output items - detect_shield_violations(response.output) - logger.info( "Response processing complete - Tool calls: %d, Response length: %d chars", len(tool_calls), diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index b9477b7a..8e526615 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -11,10 +11,10 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import StreamingResponse -from litellm.exceptions import RateLimitError from llama_stack_client import ( APIConnectionError, - AsyncLlamaStackClient, # type: ignore + AsyncLlamaStackClient, + RateLimitError, # type: ignore ) from llama_stack_client.types import UserMessage # type: ignore from llama_stack_client.types.alpha.agents.agent_turn_response_stream_chunk import ( diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py index 266d3fbf..45e63b88 100644 --- a/src/app/endpoints/streaming_query_v2.py +++ b/src/app/endpoints/streaming_query_v2.py @@ -6,8 +6,14 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import StreamingResponse from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseContentPartOutputText, + OpenAIResponseMessage, OpenAIResponseObject, OpenAIResponseObjectStream, + OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseContentPartAdded, + OpenAIResponseObjectStreamResponseOutputTextDelta, + OpenAIResponseOutputMessageContentOutputText, ) from llama_stack_client import AsyncLlamaStackClient @@ -53,7 +59,10 @@ from utils.quota import consume_tokens, get_available_quotas from utils.suid import normalize_conversation_id, to_llama_stack_conversation_id from utils.mcp_headers import mcp_headers_dependency -from utils.shields import detect_shield_violations, get_available_shields +from utils.shields import ( + 
+    append_turn_to_conversation,
+    run_shield_moderation,
+)
 from utils.token_counter import TokenCounter
 from utils.transcripts import store_transcript
 from utils.types import ToolCallSummary, TurnSummary
@@ -234,12 +243,6 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
                 # Capture the response object for token usage extraction
                 latest_response_object = getattr(chunk, "response", None)
 
-                # Check for shield violations in the completed response
-                if latest_response_object:
-                    output = getattr(latest_response_object, "output", None)
-                    if output is not None:
-                        detect_shield_violations(output)
-
                 if not emitted_turn_complete:
                     final_message = summary.llm_response or "".join(text_parts)
                     if not final_message:
@@ -394,9 +397,6 @@ async def retrieve_response(  # pylint: disable=too-many-locals
         tuple: A tuple containing the streaming response object
         and the conversation ID.
     """
-    # List available shields for Responses API
-    available_shields = await get_available_shields(client)
-
    # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
     logger.debug("Using system prompt: %s", system_prompt)
@@ -441,6 +441,18 @@ async def retrieve_response(  # pylint: disable=too-many-locals
             conversation_id,
         )
 
+    # Run shield moderation before calling LLM
+    moderation_result = await run_shield_moderation(client, input_text)
+    if moderation_result.blocked:
+        violation_message = moderation_result.message or ""
+        await append_turn_to_conversation(
+            client, llama_stack_conv_id, input_text, violation_message
+        )
+        return (
+            create_violation_stream(violation_message, moderation_result.shield_model),
+            normalize_conversation_id(conversation_id),
+        )
+
     create_params: dict[str, Any] = {
         "input": input_text,
         "model": model_id,
@@ -451,14 +463,58 @@ async def retrieve_response(  # pylint: disable=too-many-locals
         "conversation": llama_stack_conv_id,
     }
 
-    # Add shields to extra_body if available
-    if available_shields:
-        create_params["extra_body"] = {"guardrails": available_shields}
-
     response = await client.responses.create(**create_params)
     response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)
 
-    # async for chunk in response_stream:
-    #     logger.error("Chunk: %s", chunk.model_dump_json())
-    # Return the normalized conversation_id
-    # The response_generator will emit it in the start event
     return response_stream, normalize_conversation_id(conversation_id)
+
+
+async def create_violation_stream(
+    message: str,
+    shield_model: str | None = None,
+) -> AsyncIterator[OpenAIResponseObjectStream]:
+    """Generate a minimal streaming response for cases where input is blocked by a shield.
+
+    This yields only the essential streaming events to indicate that the input was rejected.
+    Dummy item identifiers are used solely for protocol compliance and are not used later.
+ """ + response_id = "resp_shield_violation" + + # Content part added (triggers empty initial token) + yield OpenAIResponseObjectStreamResponseContentPartAdded( + content_index=0, + response_id=response_id, + item_id="msg_shield_violation_1", + output_index=0, + part=OpenAIResponseContentPartOutputText(text=""), + sequence_number=0, + ) + + # Text delta + yield OpenAIResponseObjectStreamResponseOutputTextDelta( + content_index=1, + delta=message, + item_id="msg_shield_violation_2", + output_index=1, + sequence_number=1, + ) + + # Completed response + yield OpenAIResponseObjectStreamResponseCompleted( + response=OpenAIResponseObject( + id=response_id, + created_at=0, # not used + model=shield_model or "shield", + output=[ + OpenAIResponseMessage( + id="msg_shield_violation_3", + content=[ + OpenAIResponseOutputMessageContentOutputText(text=message) + ], + role="assistant", + status="completed", + ) + ], + status="completed", + ) + ) diff --git a/src/utils/shields.py b/src/utils/shields.py index 38730fb7..ca99671c 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -1,14 +1,20 @@ """Utility functions for working with Llama Stack shields.""" import logging -from typing import Any +from typing import Any, cast -from llama_stack_client import AsyncLlamaStackClient +from fastapi import HTTPException +from llama_stack_client import AsyncLlamaStackClient, BadRequestError +from llama_stack_client.types import CreateResponse import metrics +from models.responses import NotFoundResponse +from utils.types import ShieldModerationResult logger = logging.getLogger(__name__) +DEFAULT_VIOLATION_MESSAGE = "I cannot process this request due to policy restrictions." + async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]: """ @@ -52,3 +58,100 @@ def detect_shield_violations(output_items: list[Any]) -> bool: logger.warning("Shield violation detected: %s", refusal) return True return False + + +async def run_shield_moderation( + client: AsyncLlamaStackClient, + input_text: str, +) -> ShieldModerationResult: + """ + Run shield moderation on input text. + + Iterates through all configured shields and runs moderation checks. + Raises HTTPException if shield model is not found. + + Parameters: + client: The Llama Stack client. + input_text: The text to moderate. + + Returns: + ShieldModerationResult: Result indicating if content was blocked and the message. + + Raises: + HTTPException: If shield's provider_resource_id is not configured or model not found. 
+ """ + available_models = {model.identifier for model in await client.models.list()} + + for shield in await client.shields.list(): + if ( + not shield.provider_resource_id + or shield.provider_resource_id not in available_models + ): + response = NotFoundResponse( + resource="Shield model", resource_id=shield.provider_resource_id or "" + ) + raise HTTPException(**response.model_dump()) + + try: + moderation = await client.moderations.create( + input=input_text, model=shield.provider_resource_id + ) + moderation_result = cast(CreateResponse, moderation) + + if moderation_result.results and moderation_result.results[0].flagged: + result = moderation_result.results[0] + metrics.llm_calls_validation_errors_total.inc() + logger.warning( + "Shield '%s' flagged content: categories=%s", + shield.identifier, + result.categories, + ) + violation_message = result.user_message or DEFAULT_VIOLATION_MESSAGE + return ShieldModerationResult( + blocked=True, + message=violation_message, + shield_model=shield.provider_resource_id, + ) + + # Known Llama Stack bug: BadRequestError is raised when violation is present + # in the shield LLM response but has wrong format that cannot be parsed. + except BadRequestError: + logger.warning( + "Shield '%s' returned BadRequestError, treating as blocked", + shield.identifier, + ) + metrics.llm_calls_validation_errors_total.inc() + return ShieldModerationResult( + blocked=True, + message=DEFAULT_VIOLATION_MESSAGE, + shield_model=shield.provider_resource_id, + ) + + return ShieldModerationResult(blocked=False) + + +async def append_turn_to_conversation( + client: AsyncLlamaStackClient, + conversation_id: str, + user_message: str, + assistant_message: str, +) -> None: + """ + Append a user/assistant turn to a conversation after shield violation. + + Used to record the conversation turn when a shield blocks the request, + storing both the user's original message and the violation response. + + Parameters: + client: The Llama Stack client. + conversation_id: The Llama Stack conversation ID. + user_message: The user's input message. + assistant_message: The shield violation response message. 
+ """ + await client.conversations.items.create( + conversation_id, + items=[ + {"type": "message", "role": "user", "content": user_message}, + {"type": "message", "role": "assistant", "content": assistant_message}, + ], + ) diff --git a/src/utils/types.py b/src/utils/types.py index 070869da..dd5685e5 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -100,6 +100,14 @@ def get_parser(model_id: str) -> Optional[ToolParser]: return None +class ShieldModerationResult(BaseModel): + """Result of shield moderation check.""" + + blocked: bool + message: str | None = None + shield_model: str | None = None + + class ToolCallSummary(BaseModel): """Model representing a tool call made during response generation (for tool_calls list).""" diff --git a/tests/e2e/configs/run-azure.yaml b/tests/e2e/configs/run-azure.yaml index a8aa02de..08004a1d 100644 --- a/tests/e2e/configs/run-azure.yaml +++ b/tests/e2e/configs/run-azure.yaml @@ -80,6 +80,10 @@ providers: api_key: ${env.AZURE_API_KEY} api_base: https://ols-test.openai.azure.com/ api_version: 2024-02-15-preview + - provider_id: openai + provider_type: remote::openai + config: + api_key: ${env.OPENAI_API_KEY} - config: {} provider_id: sentence-transformers provider_type: inline::sentence-transformers @@ -138,7 +142,10 @@ registered_resources: provider_id: azure model_type: llm provider_model_id: gpt-4o-mini - shields: [] + shields: + - shield_id: llama-guard + provider_id: llama-guard + provider_shield_id: openai/gpt-4o-mini vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/tests/e2e/configs/run-ci.yaml b/tests/e2e/configs/run-ci.yaml index 06b57faa..49971d2b 100644 --- a/tests/e2e/configs/run-ci.yaml +++ b/tests/e2e/configs/run-ci.yaml @@ -132,17 +132,20 @@ storage: backend: kv_default registered_resources: models: + - model_id: gpt-4o-mini + provider_id: openai + model_type: llm + provider_model_id: gpt-4o-mini - model_id: sentence-transformers/all-mpnet-base-v2 model_type: embedding provider_id: sentence-transformers provider_model_id: sentence-transformers/all-mpnet-base-v2 metadata: embedding_dimension: 768 - - model_id: gpt-4o-mini - provider_id: openai - model_type: llm - provider_model_id: gpt-4o-mini - shields: [] + shields: + - shield_id: llama-guard + provider_id: llama-guard + provider_shield_id: openai/gpt-4o-mini vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/tests/e2e/configs/run-rhaiis.yaml b/tests/e2e/configs/run-rhaiis.yaml index 7ec33263..b828e89e 100644 --- a/tests/e2e/configs/run-rhaiis.yaml +++ b/tests/e2e/configs/run-rhaiis.yaml @@ -24,6 +24,10 @@ providers: - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} + - provider_id: openai + provider_type: remote::openai + config: + api_key: ${env.OPENAI_API_KEY} vector_io: - provider_id: documentation_faiss provider_type: inline::faiss @@ -130,14 +134,10 @@ registered_resources: provider_id: vllm model_type: llm provider_model_id: ${env.RHAIIS_MODEL} - shields: - shield_id: llama-guard - provider_id: ${env.SAFETY_MODEL:+llama-guard} - provider_shield_id: ${env.SAFETY_MODEL:=} - - shield_id: code-scanner - provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner} - provider_shield_id: ${env.CODE_SCANNER_MODEL:=} + provider_id: llama-guard + provider_shield_id: openai/gpt-4o-mini datasets: [] scoring_fns: [] benchmarks: [] diff --git a/tests/e2e/configs/run-rhelai.yaml b/tests/e2e/configs/run-rhelai.yaml index 2d9ac373..8327f293 100644 --- a/tests/e2e/configs/run-rhelai.yaml +++ b/tests/e2e/configs/run-rhelai.yaml @@ -24,6 +24,10 
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
+  - provider_id: openai
+    provider_type: remote::openai
+    config:
+      api_key: ${env.OPENAI_API_KEY}
   vector_io:
   - provider_id: documentation_faiss
     provider_type: inline::faiss
@@ -130,14 +134,10 @@ registered_resources:
     provider_id: vllm
     model_type: llm
     provider_model_id: ${env.RHEL_AI_MODEL}
-
   shields:
   - shield_id: llama-guard
-    provider_id: ${env.SAFETY_MODEL:+llama-guard}
-    provider_shield_id: ${env.SAFETY_MODEL:=}
-  - shield_id: code-scanner
-    provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
-    provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
+    provider_id: llama-guard
+    provider_shield_id: openai/gpt-4o-mini
   datasets: []
   scoring_fns: []
   benchmarks: []
diff --git a/tests/e2e/configs/run-vertexai.yaml b/tests/e2e/configs/run-vertexai.yaml
index 37e083b8..af6bbe2a 100644
--- a/tests/e2e/configs/run-vertexai.yaml
+++ b/tests/e2e/configs/run-vertexai.yaml
@@ -79,6 +79,10 @@ providers:
     config:
       project: ${env.VERTEX_AI_PROJECT}
       location: ${env.VERTEX_AI_LOCATION}
+  - provider_id: openai
+    provider_type: remote::openai
+    config:
+      api_key: ${env.OPENAI_API_KEY}
   - config: {}
     provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
@@ -133,7 +137,10 @@ storage:
       backend: kv_default
 registered_resources:
   models: []
-  shields: []
+  shields:
+  - shield_id: llama-guard
+    provider_id: llama-guard
+    provider_shield_id: openai/gpt-4o-mini
   vector_dbs: []
   datasets: []
   scoring_fns: []
diff --git a/tests/e2e/features/steps/info.py b/tests/e2e/features/steps/info.py
index e2d1ff64..59212668 100644
--- a/tests/e2e/features/steps/info.py
+++ b/tests/e2e/features/steps/info.py
@@ -106,6 +106,7 @@ def check_shield_structure(context: Context) -> None:
     assert found_shield is not None, "No shield found in response"
 
     expected_model = context.default_model
+    expected_provider = context.default_provider
 
     # Validate structure and values
     assert found_shield["type"] == "shield", "type should be 'shield'"
@@ -113,11 +114,14 @@ def check_shield_structure(context: Context) -> None:
     assert (
         found_shield["provider_id"] == "llama-guard"
     ), "provider_id should be 'llama-guard'"
     assert (
-        found_shield["provider_resource_id"] == expected_model
-    ), f"provider_resource_id should be '{expected_model}', but is '{found_shield["provider_resource_id"]}'"
+        found_shield["provider_resource_id"] == f"{expected_provider}/{expected_model}"
+    ), (
+        f"provider_resource_id should be '{expected_provider}/{expected_model}', "
+        f"but is '{found_shield['provider_resource_id']}'"
+    )
     assert (
-        found_shield["identifier"] == "llama-guard-shield"
-    ), f"identifier should be 'llama-guard-shield', but is '{found_shield["identifier"]}'"
+        found_shield["identifier"] == "llama-guard"
+    ), f"identifier should be 'llama-guard', but is '{found_shield["identifier"]}'"
 
 
 @then("The response contains {count:d} tools listed for provider {provider_name}")
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index e8f2dd59..98981f34 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -9,8 +9,8 @@
 import pytest
 from fastapi import HTTPException, Request, status
-from litellm.exceptions import RateLimitError
-from llama_stack_client import APIConnectionError
+import httpx
+from llama_stack_client import APIConnectionError, RateLimitError
 from llama_stack_client.types import UserMessage  # type: ignore
 from llama_stack_client.types.alpha.agents.turn import Turn
 from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
@@ -2415,13 +2415,16 @@ async def test_query_endpoint_quota_exceeded(
     query_request = QueryRequest(
         query="What is OpenStack?",
         provider="openai",
-        model="gpt-4-turbo",
+        model="gpt-4o-mini",
     )  # type: ignore
 
     mock_client = mocker.AsyncMock()
     mock_client.models.list = mocker.AsyncMock(return_value=[])
     mock_agent = mocker.AsyncMock()
+    mock_response = httpx.Response(429, request=httpx.Request("POST", "http://test"))
     mock_agent.create_turn.side_effect = RateLimitError(
-        model="gpt-4-turbo", llm_provider="openai", message=""
+        "Rate limit exceeded for model gpt-4o-mini",
+        response=mock_response,
+        body=None,
     )
     mocker.patch(
         "app.endpoints.query.get_agent",
@@ -2429,7 +2432,7 @@ async def test_query_endpoint_quota_exceeded(
     )
     mocker.patch(
         "app.endpoints.query.select_model_and_provider_id",
-        return_value=("openai/gpt-4-turbo", "gpt-4-turbo", "openai"),
+        return_value=("openai/gpt-4o-mini", "gpt-4o-mini", "openai"),
     )
     mocker.patch("app.endpoints.query.validate_model_provider_override")
     mocker.patch(
@@ -2450,8 +2453,8 @@ async def test_query_endpoint_quota_exceeded(
     assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
     detail = exc_info.value.detail
     assert isinstance(detail, dict)
-    assert detail["response"] == "The model quota has been exceeded"  # type: ignore
-    assert "gpt-4-turbo" in detail["cause"]  # type: ignore
+    assert detail["response"] == "The quota has been exceeded"  # type: ignore
+    assert "gpt-4o-mini" in detail["cause"]  # type: ignore
 
 
 async def test_query_endpoint_generate_topic_summary_default_true(
diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py
index ed1b61bf..77c04dba 100644
--- a/tests/unit/app/endpoints/test_query_v2.py
+++ b/tests/unit/app/endpoints/test_query_v2.py
@@ -5,8 +5,8 @@
 import pytest
 from fastapi import HTTPException, Request, status
-from litellm.exceptions import RateLimitError
-from llama_stack_client import APIConnectionError
+import httpx
+from llama_stack_client import APIConnectionError, RateLimitError
 from pytest_mock import MockerFixture
 
 from app.endpoints.query_v2 import (
@@ -17,6 +17,7 @@
 )
 from models.config import ModelContextProtocolServer
 from models.requests import Attachment, QueryRequest
+from utils.types import ShieldModerationResult
 
 # User ID must be proper UUID
 MOCK_AUTH = (
@@ -130,8 +131,9 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker: MockerFixture)
     mock_vector_stores = mocker.Mock()
     mock_vector_stores.data = []
     mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
-    # Mock shields.list
+    # Mock shields.list and models.list for run_shield_moderation
     mock_client.shields.list = mocker.AsyncMock(return_value=[])
+    mock_client.models.list = mocker.AsyncMock(return_value=[])
 
     # Ensure system prompt resolution does not require real config
     mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
@@ -172,8 +174,9 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(  # pylint: disable=to
     mock_vector_stores = mocker.Mock()
     mock_vector_stores.data = [mocker.Mock(id="dbA")]
     mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
-    # Mock shields.list
+    # Mock shields.list and models.list for run_shield_moderation
     mock_client.shields.list = mocker.AsyncMock(return_value=[])
+    mock_client.models.list = mocker.AsyncMock(return_value=[])
 
     mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
return_value="PROMPT") mock_cfg = mocker.Mock() @@ -249,8 +252,9 @@ async def test_retrieve_response_parses_output_and_tool_calls( mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) - # Mock shields.list + # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -300,8 +304,9 @@ async def test_retrieve_response_with_usage_info(mocker: MockerFixture) -> None: mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) - # Mock shields.list + # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -343,8 +348,9 @@ async def test_retrieve_response_with_usage_dict(mocker: MockerFixture) -> None: mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) - # Mock shields.list + # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -386,8 +392,9 @@ async def test_retrieve_response_with_empty_usage_dict(mocker: MockerFixture) -> mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) - # Mock shields.list + # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -420,8 +427,9 @@ async def test_retrieve_response_validates_attachments(mocker: MockerFixture) -> mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) - # Mock shields.list + # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) @@ -558,13 +566,16 @@ async def test_query_endpoint_quota_exceeded( query_request = QueryRequest( query="What is OpenStack?", provider="openai", - model="gpt-4-turbo", + model="gpt-4o-mini", attachments=[], ) # type: ignore mock_client = mocker.AsyncMock() mock_client.models.list = mocker.AsyncMock(return_value=[]) + mock_response = httpx.Response(429, request=httpx.Request("POST", "http://test")) 
     mock_client.responses.create.side_effect = RateLimitError(
-        model="gpt-4-turbo", llm_provider="openai", message=""
+        "Rate limit exceeded for model gpt-4o-mini",
+        response=mock_response,
+        body=None,
     )
     # Mock conversation creation (needed for query_v2)
     mock_conversation = mocker.Mock()
@@ -572,7 +583,7 @@ async def test_query_endpoint_quota_exceeded(
     mock_client.conversations.create = mocker.AsyncMock(return_value=mock_conversation)
     mocker.patch(
         "app.endpoints.query.select_model_and_provider_id",
-        return_value=("openai/gpt-4-turbo", "gpt-4-turbo", "openai"),
+        return_value=("openai/gpt-4o-mini", "gpt-4o-mini", "openai"),
     )
     mocker.patch("app.endpoints.query.validate_model_provider_override")
     mocker.patch(
@@ -582,7 +593,10 @@ async def test_query_endpoint_quota_exceeded(
     mocker.patch("app.endpoints.query.check_tokens_available")
     mocker.patch("app.endpoints.query.get_session")
     mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)
-    mocker.patch("app.endpoints.query_v2.get_available_shields", return_value=[])
+    mocker.patch(
+        "app.endpoints.query_v2.run_shield_moderation",
+        return_value=ShieldModerationResult(blocked=False),
+    )
     mocker.patch(
         "app.endpoints.query_v2.prepare_tools_for_responses_api", return_value=None
     )
@@ -594,21 +608,34 @@ async def test_query_endpoint_quota_exceeded(
     assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
     detail = exc_info.value.detail
     assert isinstance(detail, dict)
-    assert detail["response"] == "The model quota has been exceeded"  # type: ignore
-    assert "gpt-4-turbo" in detail["cause"]  # type: ignore
+    assert detail["response"] == "The quota has been exceeded"  # type: ignore
+    assert "gpt-4o-mini" in detail["cause"]  # type: ignore
 
 
 @pytest.mark.asyncio
 async def test_retrieve_response_with_shields_available(mocker: MockerFixture) -> None:
-    """Test that shields are listed and passed to responses API when available."""
+    """Test that shield moderation runs and passes when content is safe."""
     mock_client = mocker.Mock()
 
-    # Mock shields.list to return available shields
-    shield1 = mocker.Mock()
-    shield1.identifier = "shield-1"
-    shield2 = mocker.Mock()
-    shield2.identifier = "shield-2"
-    mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2])
+    # Create mock shield with provider_resource_id
+    mock_shield = mocker.Mock()
+    mock_shield.identifier = "content-safety-shield"
+    mock_shield.provider_resource_id = "moderation-model"
+    mock_client.shields.list = mocker.AsyncMock(return_value=[mock_shield])
+
+    # Create mock model matching the shield's provider_resource_id
+    mock_model = mocker.Mock()
+    mock_model.identifier = "moderation-model"
+    mock_client.models.list = mocker.AsyncMock(return_value=[mock_model])
+
+    # Mock moderations.create to return safe (not flagged) content
+    mock_moderation_result = mocker.Mock()
+    mock_moderation_result.flagged = False
+    mock_moderation_response = mocker.Mock()
+    mock_moderation_response.results = [mock_moderation_result]
+    mock_client.moderations.create = mocker.AsyncMock(
+        return_value=mock_moderation_response
+    )
 
     output_item = mocker.Mock()
     output_item.type = "message"
@@ -640,22 +667,24 @@ async def test_retrieve_response_with_shields_available(mocker: MockerFixture) -
     assert conv_id == "abc123def456"  # Normalized (without conv_ prefix)
     assert summary.llm_response == "Safe response"
 
-    # Verify that shields were passed in extra_body
-    kwargs = mock_client.responses.create.call_args.kwargs
-    assert "extra_body" in kwargs
-    assert "guardrails" in kwargs["extra_body"]
kwargs["extra_body"] - assert kwargs["extra_body"]["guardrails"] == ["shield-1", "shield-2"] + # Verify that moderation was called with the user's query + mock_client.moderations.create.assert_called_once_with( + input="hello", model="moderation-model" + ) + # Verify that responses.create was called (moderation passed) + mock_client.responses.create.assert_called_once() @pytest.mark.asyncio async def test_retrieve_response_with_no_shields_available( mocker: MockerFixture, ) -> None: - """Test that no extra_body is added when no shields are available.""" + """Test that LLM is called when no shields are configured.""" mock_client = mocker.Mock() - # Mock shields.list to return no shields + # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) output_item = mocker.Mock() output_item.type = "message" @@ -687,40 +716,22 @@ async def test_retrieve_response_with_no_shields_available( assert conv_id == "abc123def456" # Normalized (without conv_ prefix) assert summary.llm_response == "Response without shields" - # Verify that no extra_body was added - kwargs = mock_client.responses.create.call_args.kwargs - assert "extra_body" not in kwargs + # Verify that responses.create was called + mock_client.responses.create.assert_called_once() @pytest.mark.asyncio async def test_retrieve_response_detects_shield_violation( mocker: MockerFixture, ) -> None: - """Test that shield violations are detected and metrics are incremented.""" + """Test that shield moderation blocks content and returns early.""" mock_client = mocker.Mock() - # Mock shields.list to return available shields - shield1 = mocker.Mock() - shield1.identifier = "safety-shield" - mock_client.shields.list = mocker.AsyncMock(return_value=[shield1]) - - # Create output with shield violation (refusal) - output_item = mocker.Mock() - output_item.type = "message" - output_item.role = "assistant" - output_item.content = "I cannot help with that request" - output_item.refusal = "Content violates safety policy" - - response_obj = mocker.Mock() - response_obj.id = "resp-violation" - response_obj.output = [output_item] - response_obj.usage = None - - mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) # Mock conversations.create for new conversation creation mock_conversation = mocker.Mock() mock_conversation.id = "conv_abc123def456" mock_client.conversations.create = mocker.AsyncMock(return_value=mock_conversation) + mock_client.conversations.items.create = mocker.AsyncMock(return_value=None) mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) @@ -728,8 +739,13 @@ async def test_retrieve_response_detects_shield_violation( mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) - # Mock the validation error metric - validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total") + # Mock run_shield_moderation to return blocked + mocker.patch( + "app.endpoints.query_v2.run_shield_moderation", + return_value=ShieldModerationResult( + blocked=True, message="Content violates safety policy" + ), + ) qr = QueryRequest(query="dangerous query") summary, conv_id, _referenced_docs, _token_usage = await retrieve_response( @@ -737,61 +753,10 @@ async def test_retrieve_response_detects_shield_violation( ) 
assert conv_id == "abc123def456" # Normalized (without conv_ prefix) - assert summary.llm_response == "I cannot help with that request" - - # Verify that the validation error metric was incremented - validation_metric.inc.assert_called_once() - + assert summary.llm_response == "Content violates safety policy" -@pytest.mark.asyncio -async def test_retrieve_response_no_violation_with_shields( - mocker: MockerFixture, -) -> None: - """Test that no metric is incremented when there's no shield violation.""" - mock_client = mocker.Mock() - - # Mock shields.list to return available shields - shield1 = mocker.Mock() - shield1.identifier = "safety-shield" - mock_client.shields.list = mocker.AsyncMock(return_value=[shield1]) - - # Create output without shield violation - output_item = mocker.Mock() - output_item.type = "message" - output_item.role = "assistant" - output_item.content = "Safe response" - output_item.refusal = None # No violation - - response_obj = mocker.Mock() - response_obj.id = "resp-safe" - response_obj.output = [output_item] - response_obj.usage = None - - mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) - # Mock conversations.create for new conversation creation - mock_conversation = mocker.Mock() - mock_conversation.id = "conv_abc123def456" - mock_client.conversations.create = mocker.AsyncMock(return_value=mock_conversation) - mock_vector_stores = mocker.Mock() - mock_vector_stores.data = [] - mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) - - mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") - mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) - - # Mock the validation error metric - validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total") - - qr = QueryRequest(query="safe query") - summary, conv_id, _referenced_docs, _token_usage = await retrieve_response( - mock_client, "model-safe", qr, token="tkn", provider_id="test-provider" - ) - - assert conv_id == "abc123def456" # Normalized (without conv_ prefix) - assert summary.llm_response == "Safe response" - - # Verify that the validation error metric was NOT incremented - validation_metric.inc.assert_not_called() + # Verify that responses.create was NOT called (blocked by moderation) + mock_client.responses.create.assert_not_called() def _create_message_output_with_citations(mocker: MockerFixture) -> Any: @@ -870,7 +835,9 @@ async def test_retrieve_response_parses_referenced_documents( mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores) + # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index a684dad9..304c43cc 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -8,8 +8,8 @@ import pytest from fastapi import HTTPException, Request, status from fastapi.responses import StreamingResponse -from litellm.exceptions import RateLimitError -from llama_stack_client import APIConnectionError +import httpx +from llama_stack_client 
 from llama_stack_client.types import UserMessage  # type: ignore
 from llama_stack_client.types.alpha.agents.turn import Turn
 from llama_stack_client.types.alpha.shield_call_step import ShieldCallStep
@@ -2075,7 +2075,7 @@ async def test_query_endpoint_quota_exceeded(mocker: MockerFixture) -> None:
     query_request = QueryRequest(
         query="What is OpenStack?",
         provider="openai",
-        model="gpt-4-turbo",
+        model="gpt-4o-mini",
     )  # type: ignore
     request = Request(scope={"type": "http"})
     request.state.authorized_actions = set()
@@ -2084,8 +2084,11 @@ async def test_query_endpoint_quota_exceeded(mocker: MockerFixture) -> None:
     mock_client.shields.list = mocker.AsyncMock(return_value=[])
     mock_client.vector_stores.list = mocker.AsyncMock(return_value=mocker.Mock(data=[]))
     mock_agent = mocker.AsyncMock()
+    mock_response = httpx.Response(429, request=httpx.Request("POST", "http://test"))
     mock_agent.create_turn.side_effect = RateLimitError(
-        model="gpt-4-turbo", llm_provider="openai", message=""
+        "Rate limit exceeded for model gpt-4o-mini",
+        response=mock_response,
+        body=None,
     )
     mocker.patch(
         "app.endpoints.streaming_query.get_agent",
@@ -2093,7 +2096,7 @@ async def test_query_endpoint_quota_exceeded(mocker: MockerFixture) -> None:
     )
     mocker.patch(
         "app.endpoints.streaming_query.select_model_and_provider_id",
-        return_value=("openai/gpt-4-turbo", "gpt-4-turbo", "openai"),
+        return_value=("openai/gpt-4o-mini", "gpt-4o-mini", "openai"),
     )
     mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
     mocker.patch(
@@ -2136,8 +2139,8 @@ async def test_query_endpoint_quota_exceeded(mocker: MockerFixture) -> None:
     content_str = content.decode()
     # The error is formatted as SSE: data: {"event":"error","response":"...","cause":"..."}\n\n
     # Check for the error message in the content
-    assert "The model quota has been exceeded" in content_str
-    assert "gpt-4-turbo" in content_str
+    assert "The quota has been exceeded" in content_str
+    assert "gpt-4o-mini" in content_str
 
 
 # ============================================================================
diff --git a/tests/unit/app/endpoints/test_streaming_query_v2.py b/tests/unit/app/endpoints/test_streaming_query_v2.py
index 92c1ecfe..a9220ab8 100644
--- a/tests/unit/app/endpoints/test_streaming_query_v2.py
+++ b/tests/unit/app/endpoints/test_streaming_query_v2.py
@@ -1,14 +1,14 @@
 # pylint: disable=redefined-outer-name,import-error, too-many-function-args
 """Unit tests for the /streaming_query (v2) endpoint using Responses API."""
 
-from types import SimpleNamespace
 from typing import Any, AsyncIterator
+from unittest.mock import Mock
 
 import pytest
 from fastapi import Request, status
 from fastapi.responses import StreamingResponse
-from litellm.exceptions import RateLimitError
-from llama_stack_client import APIConnectionError
+import httpx
+from llama_stack_client import APIConnectionError, RateLimitError
 from pytest_mock import MockerFixture
 
 from app.endpoints.streaming_query_v2 import (
@@ -17,6 +17,7 @@
 )
 from models.config import Action, ModelContextProtocolServer
 from models.requests import QueryRequest
+from utils.types import ShieldModerationResult
 
 
 @pytest.fixture
@@ -49,8 +50,9 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(
     mock_conversation = mocker.Mock()
     mock_conversation.id = "conv_abc123def456"
     mock_client.conversations.create = mocker.AsyncMock(return_value=mock_conversation)
-    # Mock shields.list
+    # Mock shields.list and models.list for run_shield_moderation
     mock_client.shields.list = mocker.AsyncMock(return_value=[])
+    mock_client.models.list = mocker.AsyncMock(return_value=[])
 
     mocker.patch(
         "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT"
@@ -85,8 +87,9 @@ async def test_retrieve_response_no_tools_passes_none(mocker: MockerFixture) ->
     mock_conversation = mocker.Mock()
     mock_conversation.id = "conv_abc123def456"
     mock_client.conversations.create = mocker.AsyncMock(return_value=mock_conversation)
-    # Mock shields.list
+    # Mock shields.list and models.list for run_shield_moderation
     mock_client.shields.list = mocker.AsyncMock(return_value=[])
+    mock_client.models.list = mocker.AsyncMock(return_value=[])
 
     mocker.patch(
         "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT"
@@ -147,11 +150,11 @@ async def test_streaming_query_endpoint_handler_v2_success_yields_events(
     )
 
     # Build a fake async stream of chunks
-    async def fake_stream() -> AsyncIterator[SimpleNamespace]:
+    async def fake_stream() -> AsyncIterator[Mock]:
        """
         Produce a fake asynchronous stream of response events used for testing streaming endpoints.
 
-        Yields SimpleNamespace objects that emulate event frames from a
+        Yields Mock objects that emulate event frames from a
         streaming responses API, including:
         - a "response.created" event with a conversation id,
         - content and text delta events ("response.content_part.added",
@@ -162,36 +165,29 @@ async def fake_stream() -> AsyncIterator[SimpleNamespace]:
        - a final "response.output_text.done" event and a "response.completed" event.
 
         Returns:
-            AsyncIterator[SimpleNamespace]: An async iterator that yields
-            event-like SimpleNamespace objects representing the streamed
+            AsyncIterator[Mock]: An async iterator that yields
+            event-like Mock objects representing the streamed
             response frames; the final yielded response contains an
             `output` attribute (an empty list) to allow shield violation detection in tests.
""" - yield SimpleNamespace( - type="response.created", response=SimpleNamespace(id="conv-xyz") - ) - yield SimpleNamespace(type="response.content_part.added") - yield SimpleNamespace(type="response.output_text.delta", delta="Hello ") - yield SimpleNamespace(type="response.output_text.delta", delta="world") - yield SimpleNamespace( - type="response.output_item.added", - item=SimpleNamespace( - type="function_call", id="item1", name="search", call_id="call1" - ), - ) - yield SimpleNamespace( - type="response.function_call_arguments.delta", delta='{"q":"x"}' - ) - yield SimpleNamespace( + yield Mock(type="response.created", response=Mock(id="conv-xyz")) + yield Mock(type="response.content_part.added") + yield Mock(type="response.output_text.delta", delta="Hello ") + yield Mock(type="response.output_text.delta", delta="world") + item_mock = Mock(type="function_call", id="item1", call_id="call1") + item_mock.name = "search" # 'name' is a special Mock param, set explicitly + yield Mock(type="response.output_item.added", item=item_mock) + yield Mock(type="response.function_call_arguments.delta", delta='{"q":"x"}') + yield Mock( type="response.function_call_arguments.done", item_id="item1", arguments='{"q":"x"}', ) - yield SimpleNamespace(type="response.output_text.done", text="Hello world") + yield Mock(type="response.output_text.done", text="Hello world") # Include a response object with output attribute for shield violation detection - mock_response = SimpleNamespace(output=[]) - yield SimpleNamespace(type="response.completed", response=mock_response) + mock_response = Mock(output=[]) + yield Mock(type="response.completed", response=mock_response) mocker.patch( "app.endpoints.streaming_query_v2.retrieve_response", @@ -278,15 +274,28 @@ def _raise(*_a: Any, **_k: Any) -> None: @pytest.mark.asyncio async def test_retrieve_response_with_shields_available(mocker: MockerFixture) -> None: - """Test that shields are listed and passed to streaming responses API.""" + """Test that shield moderation runs and passes when content is safe.""" mock_client = mocker.Mock() - # Mock shields.list to return available shields - shield1 = mocker.Mock() - shield1.identifier = "shield-1" - shield2 = mocker.Mock() - shield2.identifier = "shield-2" - mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2]) + # Create mock shield with provider_resource_id + mock_shield = mocker.Mock() + mock_shield.identifier = "content-safety-shield" + mock_shield.provider_resource_id = "moderation-model" + mock_client.shields.list = mocker.AsyncMock(return_value=[mock_shield]) + + # Create mock model matching the shield's provider_resource_id + mock_model = mocker.Mock() + mock_model.identifier = "moderation-model" + mock_client.models.list = mocker.AsyncMock(return_value=[mock_model]) + + # Mock moderations.create to return safe (not flagged) content + mock_moderation_result = mocker.Mock() + mock_moderation_result.flagged = False + mock_moderation_response = mocker.Mock() + mock_moderation_response.results = [mock_moderation_result] + mock_client.moderations.create = mocker.AsyncMock( + return_value=mock_moderation_response + ) mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] @@ -307,22 +316,24 @@ async def test_retrieve_response_with_shields_available(mocker: MockerFixture) - qr = QueryRequest(query="hello") await retrieve_response(mock_client, "model-shields", qr, token="tok") - # Verify that shields were passed in extra_body - kwargs = mock_client.responses.create.call_args.kwargs - assert 
"extra_body" in kwargs - assert "guardrails" in kwargs["extra_body"] - assert kwargs["extra_body"]["guardrails"] == ["shield-1", "shield-2"] + # Verify that moderation was called with the user's query + mock_client.moderations.create.assert_called_once_with( + input="hello", model="moderation-model" + ) + # Verify that responses.create was called (moderation passed) + mock_client.responses.create.assert_called_once() @pytest.mark.asyncio async def test_retrieve_response_with_no_shields_available( mocker: MockerFixture, ) -> None: - """Test that no extra_body is added when no shields are available.""" + """Test that LLM is called when no shields are configured.""" mock_client = mocker.Mock() - # Mock shields.list to return no shields + # Mock shields.list and models.list for run_shield_moderation mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) mock_vector_stores = mocker.Mock() mock_vector_stores.data = [] @@ -343,16 +354,15 @@ async def test_retrieve_response_with_no_shields_available( qr = QueryRequest(query="hello") await retrieve_response(mock_client, "model-no-shields", qr, token="tok") - # Verify that no extra_body was added - kwargs = mock_client.responses.create.call_args.kwargs - assert "extra_body" not in kwargs + # Verify that responses.create was called + mock_client.responses.create.assert_called_once() @pytest.mark.asyncio -async def test_streaming_response_detects_shield_violation( +async def test_streaming_response_blocked_by_shield_moderation( mocker: MockerFixture, dummy_request: Request ) -> None: - """Test that shield violations in streaming responses are detected and metrics incremented.""" + """Test that when shield moderation blocks, a violation stream is returned.""" # Skip real config checks mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") @@ -391,44 +401,31 @@ async def test_streaming_response_detects_shield_violation( mocker.AsyncMock(return_value=None), ) - # Mock the validation error metric - validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total") - - # Build a fake async stream with shield violation - async def fake_stream_with_violation() -> AsyncIterator[SimpleNamespace]: - """ - Produce an async iterator of SimpleNamespace events that simulates a streaming response. 
-
-        Yields:
-            AsyncIterator[SimpleNamespace]: Sequence of event objects in order:
-            type="response.created" with a `response.id`
-            type="response.output_text.delta" with a `delta` fragment
-            type="response.output_text.done" with a `text` final chunk
-            type="response.completed" whose `response.output` contains a
-            message object with a `refusal` field indicating a safety
-            policy violation
-        """
-        yield SimpleNamespace(
-            type="response.created", response=SimpleNamespace(id="conv-violation")
+    # Build a fake async stream for violation response
+    async def fake_violation_stream() -> AsyncIterator[Mock]:
+        """Produce an async iterator simulating a shield violation response."""
+        yield Mock(
+            type="response.content_part.added",
+            response_id="resp_shield",
+            item_id="msg_shield",
+        )
+        yield Mock(
+            type="response.output_text.delta", delta="Content violates safety policy"
         )
-        yield SimpleNamespace(type="response.output_text.delta", delta="I cannot ")
-        yield SimpleNamespace(type="response.output_text.done", text="I cannot help")
-        # Response completed with refusal in output
-        violation_item = SimpleNamespace(
+        violation_item = Mock(
             type="message",
             role="assistant",
-            refusal="Content violates safety policy",
+            content="Content violates safety policy",
+            refusal=None,
         )
-        response_with_violation = SimpleNamespace(
-            id="conv-violation", output=[violation_item]
-        )
-        yield SimpleNamespace(
-            type="response.completed", response=response_with_violation
+        yield Mock(
+            type="response.completed",
+            response=Mock(id="resp_shield", output=[violation_item]),
         )
 
     mocker.patch(
         "app.endpoints.streaming_query_v2.retrieve_response",
-        return_value=(fake_stream_with_violation(), ""),
+        return_value=(fake_violation_stream(), "conv123"),
     )
     mocker.patch("metrics.llm_calls_total")
@@ -448,8 +445,9 @@ async def fake_stream_with_violation() -> AsyncIterator[SimpleNamespace]:
         s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk)
         events.append(s)
 
-    # Verify that the validation error metric was incremented
-    validation_metric.inc.assert_called_once()
+    # Verify that the stream contains the violation message
+    all_events = "".join(events)
+    assert "Content violates safety policy" in all_events
 
 
 @pytest.mark.asyncio
@@ -499,7 +497,7 @@ async def test_streaming_response_no_shield_violation(
     validation_metric = mocker.patch("metrics.llm_calls_validation_errors_total")
 
     # Build a fake async stream without violation
-    async def fake_stream_without_violation() -> AsyncIterator[SimpleNamespace]:
+    async def fake_stream_without_violation() -> AsyncIterator[Mock]:
         """
         Produce a deterministic sequence of streaming response events that end with a message.
@@ -511,22 +509,16 @@ async def fake_stream_without_violation() -> AsyncIterator[SimpleNamespace]:
         message where `refusal` is `None`.
 
         Returns:
-            An iterator yielding SimpleNamespace objects representing the
+            An iterator yielding Mock objects representing the
             streaming events of a successful response with no refusal.
""" - yield SimpleNamespace( - type="response.created", response=SimpleNamespace(id="conv-safe") - ) - yield SimpleNamespace(type="response.output_text.delta", delta="Safe ") - yield SimpleNamespace(type="response.output_text.done", text="Safe response") + yield Mock(type="response.created", response=Mock(id="conv-safe")) + yield Mock(type="response.output_text.delta", delta="Safe ") + yield Mock(type="response.output_text.done", text="Safe response") # Response completed without refusal - safe_item = SimpleNamespace( - type="message", - role="assistant", - refusal=None, # No violation - ) - response_safe = SimpleNamespace(id="conv-safe", output=[safe_item]) - yield SimpleNamespace(type="response.completed", response=response_safe) + safe_item = Mock(type="message", role="assistant", refusal=None) + response_safe = Mock(id="conv-safe", output=[safe_item]) + yield Mock(type="response.completed", response=response_safe) mocker.patch( "app.endpoints.streaming_query_v2.retrieve_response", @@ -563,8 +555,11 @@ async def test_streaming_query_endpoint_handler_v2_quota_exceeded( mock_client = mocker.Mock() mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()]) + mock_response = httpx.Response(429, request=httpx.Request("POST", "http://test")) mock_client.responses.create.side_effect = RateLimitError( - model="gpt-4-turbo", llm_provider="openai", message="" + "Rate limit exceeded for model gpt-4o-mini", + response=mock_response, + body=None, ) # Mock conversation creation (needed for query_v2) mock_conversation = mocker.Mock() @@ -572,6 +567,7 @@ async def test_streaming_query_endpoint_handler_v2_quota_exceeded( mock_client.conversations.create = mocker.AsyncMock(return_value=mock_conversation) mock_client.vector_stores.list = mocker.AsyncMock(return_value=mocker.Mock(data=[])) mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) mocker.patch( "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client @@ -582,11 +578,12 @@ async def test_streaming_query_endpoint_handler_v2_quota_exceeded( ) mocker.patch( "app.endpoints.streaming_query.select_model_and_provider_id", - return_value=("openai/gpt-4-turbo", "gpt-4-turbo", "openai"), + return_value=("openai/gpt-4o-mini", "gpt-4o-mini", "openai"), ) mocker.patch("app.endpoints.streaming_query.validate_model_provider_override") mocker.patch( - "app.endpoints.streaming_query_v2.get_available_shields", return_value=[] + "app.endpoints.streaming_query_v2.run_shield_moderation", + return_value=ShieldModerationResult(blocked=False), ) mocker.patch( "app.endpoints.streaming_query_v2.prepare_tools_for_responses_api", @@ -628,5 +625,5 @@ async def test_streaming_query_endpoint_handler_v2_quota_exceeded( content_str = content.decode() # The error is formatted as SSE: data: {"event":"error","response":"...","cause":"..."}\n\n # Check for the error message in the content - assert "The model quota has been exceeded" in content_str - assert "gpt-4-turbo" in content_str + assert "The quota has been exceeded" in content_str + assert "gpt-4o-mini" in content_str diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py new file mode 100644 index 00000000..cf238a42 --- /dev/null +++ b/tests/unit/utils/test_shields.py @@ -0,0 +1,344 @@ +"""Unit tests for utils/shields.py functions.""" + +import httpx +import pytest +from fastapi import HTTPException, status +from llama_stack_client import BadRequestError +from pytest_mock import MockerFixture + +from 
utils.shields import ( + DEFAULT_VIOLATION_MESSAGE, + append_turn_to_conversation, + detect_shield_violations, + get_available_shields, + run_shield_moderation, +) + + +class TestGetAvailableShields: + """Tests for get_available_shields function.""" + + @pytest.mark.asyncio + async def test_returns_shield_identifiers(self, mocker: MockerFixture) -> None: + """Test that get_available_shields returns list of shield identifiers.""" + mock_client = mocker.Mock() + shield1 = mocker.Mock() + shield1.identifier = "shield-1" + shield2 = mocker.Mock() + shield2.identifier = "shield-2" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2]) + + result = await get_available_shields(mock_client) + + assert result == ["shield-1", "shield-2"] + mock_client.shields.list.assert_called_once() + + @pytest.mark.asyncio + async def test_returns_empty_list_when_no_shields( + self, mocker: MockerFixture + ) -> None: + """Test that get_available_shields returns empty list when no shields available.""" + mock_client = mocker.Mock() + mock_client.shields.list = mocker.AsyncMock(return_value=[]) + + result = await get_available_shields(mock_client) + + assert result == [] + + +class TestDetectShieldViolations: + """Tests for detect_shield_violations function.""" + + def test_detects_violation_when_refusal_present( + self, mocker: MockerFixture + ) -> None: + """Test that detect_shield_violations returns True when refusal is present.""" + mock_metric = mocker.patch( + "utils.shields.metrics.llm_calls_validation_errors_total" + ) + + output_item = mocker.Mock(type="message", refusal="Content blocked") + output_items = [output_item] + + result = detect_shield_violations(output_items) + + assert result is True + mock_metric.inc.assert_called_once() + + def test_returns_false_when_no_violation(self, mocker: MockerFixture) -> None: + """Test that detect_shield_violations returns False when no refusal.""" + mock_metric = mocker.patch( + "utils.shields.metrics.llm_calls_validation_errors_total" + ) + + output_item = mocker.Mock(type="message", refusal=None) + output_items = [output_item] + + result = detect_shield_violations(output_items) + + assert result is False + mock_metric.inc.assert_not_called() + + def test_returns_false_for_non_message_items(self, mocker: MockerFixture) -> None: + """Test that detect_shield_violations ignores non-message items.""" + mock_metric = mocker.patch( + "utils.shields.metrics.llm_calls_validation_errors_total" + ) + + output_item = mocker.Mock(type="tool_call", refusal="Content blocked") + output_items = [output_item] + + result = detect_shield_violations(output_items) + + assert result is False + mock_metric.inc.assert_not_called() + + def test_returns_false_for_empty_list(self, mocker: MockerFixture) -> None: + """Test that detect_shield_violations returns False for empty list.""" + mock_metric = mocker.patch( + "utils.shields.metrics.llm_calls_validation_errors_total" + ) + + result = detect_shield_violations([]) + + assert result is False + mock_metric.inc.assert_not_called() + + +class TestRunShieldModeration: + """Tests for run_shield_moderation function.""" + + @pytest.mark.asyncio + async def test_returns_not_blocked_when_no_shields( + self, mocker: MockerFixture + ) -> None: + """Test that run_shield_moderation returns not blocked when no shields.""" + mock_client = mocker.Mock() + mock_client.shields.list = mocker.AsyncMock(return_value=[]) + mock_client.models.list = mocker.AsyncMock(return_value=[]) + + result = await run_shield_moderation(mock_client, 
"test input") + + assert result.blocked is False + assert result.shield_model is None + + @pytest.mark.asyncio + async def test_returns_not_blocked_when_moderation_passes( + self, mocker: MockerFixture + ) -> None: + """Test that run_shield_moderation returns not blocked when content is safe.""" + mock_client = mocker.Mock() + + # Setup shield + shield = mocker.Mock() + shield.identifier = "test-shield" + shield.provider_resource_id = "moderation-model" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + # Setup model + model = mocker.Mock() + model.identifier = "moderation-model" + mock_client.models.list = mocker.AsyncMock(return_value=[model]) + + # Setup moderation result (not flagged) + moderation_result = mocker.Mock() + moderation_result.results = [mocker.Mock(flagged=False)] + mock_client.moderations.create = mocker.AsyncMock( + return_value=moderation_result + ) + + result = await run_shield_moderation(mock_client, "safe input") + + assert result.blocked is False + assert result.shield_model is None + mock_client.moderations.create.assert_called_once_with( + input="safe input", model="moderation-model" + ) + + @pytest.mark.asyncio + async def test_returns_blocked_when_content_flagged( + self, mocker: MockerFixture + ) -> None: + """Test that run_shield_moderation returns blocked when content is flagged.""" + mock_metric = mocker.patch( + "utils.shields.metrics.llm_calls_validation_errors_total" + ) + mock_client = mocker.Mock() + + # Setup shield + shield = mocker.Mock() + shield.identifier = "test-shield" + shield.provider_resource_id = "moderation-model" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + # Setup model + model = mocker.Mock() + model.identifier = "moderation-model" + mock_client.models.list = mocker.AsyncMock(return_value=[model]) + + # Setup moderation result (flagged) + flagged_result = mocker.Mock() + flagged_result.flagged = True + flagged_result.categories = ["violence"] + flagged_result.user_message = "Content blocked for violence" + moderation_result = mocker.Mock() + moderation_result.results = [flagged_result] + mock_client.moderations.create = mocker.AsyncMock( + return_value=moderation_result + ) + + result = await run_shield_moderation(mock_client, "violent content") + + assert result.blocked is True + assert result.message == "Content blocked for violence" + assert result.shield_model == "moderation-model" + mock_metric.inc.assert_called_once() + + @pytest.mark.asyncio + async def test_returns_blocked_with_default_message_when_no_user_message( + self, mocker: MockerFixture + ) -> None: + """Test that run_shield_moderation uses default message when user_message is None.""" + mocker.patch("utils.shields.metrics.llm_calls_validation_errors_total") + mock_client = mocker.Mock() + + # Setup shield + shield = mocker.Mock() + shield.identifier = "test-shield" + shield.provider_resource_id = "moderation-model" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + # Setup model + model = mocker.Mock() + model.identifier = "moderation-model" + mock_client.models.list = mocker.AsyncMock(return_value=[model]) + + # Setup moderation result (flagged, no user_message) + flagged_result = mocker.Mock() + flagged_result.flagged = True + flagged_result.categories = ["spam"] + flagged_result.user_message = None + moderation_result = mocker.Mock() + moderation_result.results = [flagged_result] + mock_client.moderations.create = mocker.AsyncMock( + return_value=moderation_result + ) + + result = await 
run_shield_moderation(mock_client, "spam content") + + assert result.blocked is True + assert result.message == DEFAULT_VIOLATION_MESSAGE + assert result.shield_model == "moderation-model" + + @pytest.mark.asyncio + async def test_raises_http_exception_when_shield_model_not_found( + self, mocker: MockerFixture + ) -> None: + """Test that run_shield_moderation raises HTTPException when shield model not in models.""" + mock_client = mocker.Mock() + + # Setup shield with provider_resource_id + shield = mocker.Mock() + shield.identifier = "test-shield" + shield.provider_resource_id = "missing-model" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + # Setup models (doesn't include the shield's model) + model = mocker.Mock() + model.identifier = "other-model" + mock_client.models.list = mocker.AsyncMock(return_value=[model]) + + with pytest.raises(HTTPException) as exc_info: + await run_shield_moderation(mock_client, "test input") + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert "missing-model" in exc_info.value.detail["cause"] # type: ignore + + @pytest.mark.asyncio + async def test_raises_http_exception_when_shield_has_no_provider_resource_id( + self, mocker: MockerFixture + ) -> None: + """Test that run_shield_moderation raises HTTPException when no provider_resource_id.""" + mock_client = mocker.Mock() + + # Setup shield without provider_resource_id + shield = mocker.Mock() + shield.identifier = "test-shield" + shield.provider_resource_id = None + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + mock_client.models.list = mocker.AsyncMock(return_value=[]) + + with pytest.raises(HTTPException) as exc_info: + await run_shield_moderation(mock_client, "test input") + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + + @pytest.mark.asyncio + async def test_returns_blocked_on_bad_request_error( + self, mocker: MockerFixture + ) -> None: + """Test that run_shield_moderation returns blocked when BadRequestError is raised.""" + mock_metric = mocker.patch( + "utils.shields.metrics.llm_calls_validation_errors_total" + ) + mock_client = mocker.Mock() + + # Setup shield + shield = mocker.Mock() + shield.identifier = "test-shield" + shield.provider_resource_id = "moderation-model" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + # Setup model + model = mocker.Mock() + model.identifier = "moderation-model" + mock_client.models.list = mocker.AsyncMock(return_value=[model]) + + # Setup moderation to raise BadRequestError + mock_response = httpx.Response( + 400, request=httpx.Request("POST", "http://test") + ) + mock_client.moderations.create = mocker.AsyncMock( + side_effect=BadRequestError( + "Bad request", response=mock_response, body=None + ) + ) + + result = await run_shield_moderation(mock_client, "test input") + + assert result.blocked is True + assert result.message == DEFAULT_VIOLATION_MESSAGE + assert result.shield_model == "moderation-model" + mock_metric.inc.assert_called_once() + + +class TestAppendTurnToConversation: # pylint: disable=too-few-public-methods + """Tests for append_turn_to_conversation function.""" + + @pytest.mark.asyncio + async def test_appends_user_and_assistant_messages( + self, mocker: MockerFixture + ) -> None: + """Test that append_turn_to_conversation creates conversation items correctly.""" + mock_client = mocker.Mock() + mock_client.conversations.items.create = mocker.AsyncMock(return_value=None) + + await append_turn_to_conversation( + mock_client, + 
conversation_id="conv-123", + user_message="Hello", + assistant_message="I cannot help with that", + ) + + mock_client.conversations.items.create.assert_called_once_with( + "conv-123", + items=[ + {"type": "message", "role": "user", "content": "Hello"}, + { + "type": "message", + "role": "assistant", + "content": "I cannot help with that", + }, + ], + )