Conversation


@yangcao77 yangcao77 commented Sep 15, 2025

Description

This PR implements the question validation feature, part of feature parity with road-core.
It is disabled by default, but can be enabled by setting the following in the config:

question_validation:
  question_validation_enabled: true

A temporary agent is used to validate the question, using the validation prompt and the query's model and provider pair.

The question validation result is saved in the transcript.

Verified for both query & streaming_query (screenshots attached).
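The gate described above can be sketched in simplified form. This is a hypothetical illustration, not the PR's code: the real implementation spins up a temporary llama-stack agent with the validation system prompt, which is stubbed out here, and the constant values are assumed.

```python
# Hypothetical sketch of the validation gate; the temp-agent call is stubbed.
SUBJECT_ALLOWED = "ALLOWED"      # assumed value
SUBJECT_REJECTED = "REJECTED"    # assumed value
DEFAULT_INVALID_QUERY_RESPONSE = "I can only answer questions on supported topics."


def run_validator_agent(question: str) -> str:
    # Stand-in for the temporary agent: a real implementation sends the
    # validation system prompt plus the question to the query's model.
    return SUBJECT_REJECTED if "weather" in question.lower() else SUBJECT_ALLOWED


def validate_question(question: str) -> bool:
    # The PR treats any reply that does not contain SUBJECT_REJECTED as valid.
    return SUBJECT_REJECTED not in run_validator_agent(question)


def handle_query(question: str, validation_enabled: bool) -> str:
    # When validation is enabled and the question is rejected, short-circuit
    # with the configured invalid-query response instead of querying the LLM.
    if validation_enabled and not validate_question(question):
        return DEFAULT_INVALID_QUERY_RESPONSE
    return f"answer({question})"
```

With validation disabled (the default), every question flows straight through to the model.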

Type of change

  • Refactor
  • New feature
  • Bug fix
  • CVE fix
  • Optimization
  • Documentation Update
  • Configuration Update
  • Bump-up service version
  • Bump-up dependent library
  • Bump-up library or tool used for development (does not change the final image)
  • CI configuration change
  • Konflux configuration change
  • Unit tests improvement
  • Integration tests improvement
  • End to end tests improvement

Related Tickets & Documents

  • Related Issue #
  • Closes #

Checklist before requesting a review

  • I have performed a self-review of my code.
  • PR has passed all pre-merge test jobs.
  • If it is a core feature, I have added thorough tests.

Testing

  • Please provide detailed steps to perform tests related to this code change.
  • How were the fix/results from this change verified? Please provide relevant screenshots or results.

Summary by CodeRabbit

  • New Features

    • Optional question validation for queries, with a standardized invalid-query message when rejected.
    • Toggle validation via new configuration settings; behavior also applied to streaming responses.
    • Support custom validation prompts and custom invalid-query messages through profiles.
    • Transcripts now record whether a query was considered valid.
  • Documentation

    • OpenAPI updated with QuestionValidationConfiguration and new fields on Configuration and CustomProfile (including query_responses) to reflect the new validation and customization options.

Signed-off-by: Stephanie <yangcao@redhat.com>

coderabbitai bot commented Sep 15, 2025

Walkthrough

Adds question validation capability. Introduces configuration schema and constants, utilities for validation prompts/responses, and a new utils.agent module. Integrates validation into query and streaming_query endpoints, propagates a query_is_valid flag, and updates transcript storage. OpenAPI, config models, and tests are extended accordingly. get_agent moved to utils.agent.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| OpenAPI schemas (docs/openapi.json) | Adds QuestionValidationConfiguration schema with question_validation_enabled. Adds question_validation to Configuration. Adds query_responses to CustomProfile. |
| Config models and accessors (src/models/config.py, src/configuration.py, tests/unit/models/config/test_dump_configuration.py) | Adds QuestionValidationConfiguration; extends Configuration with question_validation; adds CustomProfile.query_responses and getter; exposes AppConfig.question_validation; updates dump expectations. |
| Constants for validation (src/constants.py) | Adds SUBJECT_REJECTED, SUBJECT_ALLOWED, DEFAULT_VALIDATION_SYSTEM_PROMPT, DEFAULT_INVALID_QUERY_RESPONSE. |
| Agent utility extraction (src/utils/agent.py, tests/unit/utils/test_agent.py) | New module with get_agent and get_temp_agent helpers; comprehensive unit tests covering creation, reuse, sessions, tools, and temp agents. |
| Endpoint utilities adjustments (src/utils/endpoints.py, tests/unit/utils/test_endpoints.py) | Removes get_agent; adds get_validation_system_prompt and get_invalid_query_response; updates tests for auth and prompt/response resolution. |
| Query endpoint validation integration (src/app/endpoints/query.py, tests/unit/app/endpoints/test_query.py) | Adds validate_question; retrieve_response now returns (TurnSummary, conversation_id, bool); integrates validation and invalid-query short-circuit; transcripts include query_is_valid; tests updated/added. |
| Streaming query validation integration (src/app/endpoints/streaming_query.py, tests/unit/app/endpoints/test_streaming_query.py) | Adds pre-query validation and invalid-query handling; propagates query_is_valid through streaming; updates transcript storage; clarifies no_tools path; adds tests for valid/invalid flows. |

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor U as User
  participant API as Query/Streaming Endpoint
  participant VC as Temp Validator Agent
  participant AG as Agent (persistent)
  participant LS as Llama Stack
  participant TS as Transcript Store

  U->>API: Submit query
  API->>API: Check config.question_validation_enabled
  alt Validation enabled
    API->>VC: Create temp agent (validation prompt)
    API->>VC: Ask to validate query
    VC-->>API: REJECTED/ALLOWED + validation convo_id
    alt REJECTED
      API->>U: Invalid query response
      API->>TS: Store transcript (query_is_valid=false)
    else ALLOWED
      API->>AG: get_agent (reuse/create)
      API->>LS: Retrieve/Stream response
      LS-->>API: TurnSummary/Stream
      API->>TS: Store transcript (query_is_valid=true)
      API-->>U: Return response/stream
    end
  else Validation disabled
    API->>AG: get_agent (reuse/create)
    API->>LS: Retrieve/Stream response
    LS-->>API: TurnSummary/Stream
    API->>TS: Store transcript (query_is_valid=true)
    API-->>U: Return response/stream
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • tisnik
  • manstis

Poem

I twitch my ears at questions asked,
A tiny judge with velvet task.
REJECTED? I stamp—politely so.
ALLOWED? Hop on, the answers flow!
New prompts, new paths, transcripts tidy—
A bun approves, both calm and spry. 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The title "[RHDHPAI-1097]Question validation" is concise, directly related to the primary change (adding question validation), and clear enough for a teammate scanning PR history; it avoids unnecessary detail and correctly highlights the main intent. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%. |



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (22)
src/utils/agent.py (4)

75-87: Fix docstring: incorrect return description and Pylint line-too-long.

Docstring claims two return values and triggers the linter error (line > 100 chars). Clarify returns and wrap lines.

Apply this diff:

-async def get_temp_agent(
+async def get_temp_agent(
     client: AsyncLlamaStackClient,
     model_id: str,
     system_prompt: str,
 ) -> tuple[AsyncAgent, str, str]:
-    """Create a temporary agent with new agent_id and session_id.
-
-    This function creates a new agent without persistence, shields, or tools.
-    Useful for temporary operations or one-off queries, such as validating a question or generating a summary.
+    """Create a temporary agent without persistence, shields, or tools.
+
+    Useful for one-off operations (e.g., validate a question, generate a
+    summary).
 
     Args:
         client: The AsyncLlamaStackClient to use for the request.
         model_id: The ID of the model to use.
         system_prompt: The system prompt/instructions for the agent.
 
     Returns:
-        tuple[AsyncAgent, str]: A tuple containing the agent and session_id.
+        (agent, session_id, conversation_id).
     """

27-28: Document tuple order in get_agent docstring to prevent future drift.

Make the return contract explicit.

Apply this diff:

-    """Get existing agent or create a new one with session persistence."""
+    """Get an agent with session persistence.
+
+    Returns:
+        (agent, session_id, conversation_id).
+    """

Also applies to: 67-67


39-43: no_tools only disables tool parser; should it also disable shields?

If “no_tools” implies a minimal agent, consider also suppressing input/output shields; otherwise rename to reflect behavior.

Candidate change:

-        input_shields=available_input_shields if available_input_shields else [],
-        output_shields=available_output_shields if available_output_shields else [],
+        input_shields=[] if no_tools else (available_input_shields or []),
+        output_shields=[] if no_tools else (available_output_shields or []),

51-51: Reduce log level for session listing.

Session payloads can be noisy; prefer debug to avoid info-level noise.

Apply this diff:

-        logger.info("session response: %s", sessions_response)
+        logger.debug("session response: %s", sessions_response)
src/constants.py (1)

35-41: Tighten the default validation prompt for deterministic outputs.

Ask the validator to respond with exactly one token and nothing else to reduce parsing ambiguity.

Apply this diff:

 DEFAULT_VALIDATION_SYSTEM_PROMPT = (
-    "You are a helpful assistant that validates questions. You will be given a "
-    "question and you will need to validate if it is valid or not. You will "
-    f"return '{SUBJECT_REJECTED}' if the question is not valid and "
-    f"'{SUBJECT_ALLOWED}' if it is valid."
+    "You validate user questions. Decide if the question is acceptable.\n"
+    f"Reply with exactly one token: '{SUBJECT_ALLOWED}' or '{SUBJECT_REJECTED}'.\n"
+    "Do not include explanations, punctuation, or extra text."
 )
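With a single-token contract like this, the endpoint could also parse the reply with an exact match instead of a substring check, which protects against phrasings such as "NOT REJECTED". This is a sketch with assumed constant values (the review does not show the actual strings), not the PR's parsing logic:

```python
# Assumed values; the PR defines SUBJECT_ALLOWED / SUBJECT_REJECTED in
# src/constants.py but their exact strings are not shown in this review.
SUBJECT_ALLOWED = "ALLOWED"
SUBJECT_REJECTED = "REJECTED"


def parse_validator_reply(reply: str) -> bool:
    """Strict single-token parse: anything other than the allowed token rejects."""
    # Trim whitespace and stray quotes/periods the model might emit.
    token = reply.strip().strip("'\".").upper()
    return token == SUBJECT_ALLOWED
```

An exact match fails closed: a verbose reply that merely mentions the token is treated as invalid rather than silently allowed.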
tests/unit/models/config/test_dump_configuration.py (1)

91-91: Dump includes question_validation section (LGTM).

Assertions cover presence and default False value.

Consider a follow-up test that sets question_validation_enabled=True and verifies serialization.

Also applies to: 167-168

src/models/config.py (1)

430-453: Avoid mutable default dicts in dataclass fields.

Using Field(default={}) shares the same dict across instances. Switch to default_factory for both prompts and query_responses. Update prompts too for consistency.

Apply this diff:

-    prompts: dict[str, str] = Field(default={}, init=False)
-    query_responses: dict[str, str] = Field(default={}, init=False)
+    prompts: dict[str, str] = Field(default_factory=dict, init=False)
+    query_responses: dict[str, str] = Field(default_factory=dict, init=False)

If desired, also return shallow copies to discourage mutation:

-        return self.prompts
+        return dict(self.prompts)
-        return self.query_responses
+        return dict(self.query_responses)
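The pitfall behind this suggestion can be illustrated with plain Python, independent of the project's Pydantic models: a single default object is shared by every instance, while a default factory builds a fresh one each time.

```python
from dataclasses import dataclass, field


@dataclass
class Profile:
    # default_factory gives each instance its own dict.
    prompts: dict = field(default_factory=dict)


class SharedProfile:
    # A single class-level dict is shared across all instances,
    # analogous to Field(default={}) reusing one object.
    prompts: dict = {}


a, b = Profile(), Profile()
a.prompts["validation"] = "custom"
# b is unaffected: each Profile got its own dict.

c, d = SharedProfile(), SharedProfile()
c.prompts["validation"] = "custom"
# d now also sees "validation": both read the same shared dict.
```

The same reasoning is why the review recommends `default_factory=dict` for both `prompts` and `query_responses`.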
src/utils/endpoints.py (3)

66-77: Handle whitespace-only custom validation prompt

Return default when the custom profile’s "validation" prompt is empty or whitespace-only.

Apply:

 def get_validation_system_prompt(config: AppConfig) -> str:
     """Get the validation system prompt."""
     # profile takes precedence for setting prompt
     if (
         config.customization is not None
         and config.customization.custom_profile is not None
     ):
-        prompt = config.customization.custom_profile.get_prompts().get("validation")
-        if prompt:
-            return prompt
+        prompt = config.customization.custom_profile.get_prompts().get("validation")
+        if prompt and prompt.strip():
+            return prompt.strip()

     return constants.DEFAULT_VALIDATION_SYSTEM_PROMPT

80-91: Treat whitespace-only invalid response as unset

Mirror the guard above for "invalid_resp".

Apply:

 def get_invalid_query_response(config: AppConfig) -> str:
     """Get the invalid query response."""
     if (
         config.customization is not None
         and config.customization.custom_profile is not None
     ):
-        prompt = config.customization.custom_profile.get_query_responses().get(
-            "invalid_resp"
-        )
-        if prompt:
-            return prompt
+        prompt = config.customization.custom_profile.get_query_responses().get(
+            "invalid_resp"
+        )
+        if prompt and prompt.strip():
+            return prompt.strip()
     return constants.DEFAULT_INVALID_QUERY_RESPONSE
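Both guards follow the same shape; a small stand-alone helper (hypothetical, not part of the PR) captures the pattern once:

```python
from typing import Optional


def resolve_with_default(custom: Optional[str], default: str) -> str:
    # Treat None, empty, and whitespace-only custom values as unset
    # and fall back to the built-in default.
    if custom and custom.strip():
        return custom.strip()
    return default
```

Centralizing the check would keep `get_validation_system_prompt` and `get_invalid_query_response` from drifting apart.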

113-116: Comment nit: wrong option name

The comment mentions disable_system_prompt but the flag is disable_query_system_prompt.

Apply:

-        # makes sense here - if the configuration wants precedence, it can
-        # disable query system prompt altogether with disable_system_prompt.
+        # makes sense here - if the configuration wants precedence, it can
+        # disable query system prompt altogether with disable_query_system_prompt.
tests/unit/utils/test_endpoints.py (5)

264-264: Fix Pylint: line too long

Shorten docstring to satisfy 100-char limit.

Apply:

-def test_get_validation_system_prompt_with_custom_profile_no_validation_prompt():
-    """Test that default validation system prompt is returned when custom profile has no validation prompt."""
+def test_get_validation_system_prompt_with_custom_profile_no_validation_prompt():
+    """Default validation prompt when profile lacks 'validation' key."""

281-281: Fix Pylint: line too long

Shorter docstring.

Apply:

-def test_get_validation_system_prompt_with_custom_profile_empty_validation_prompt():
-    """Test that default validation system prompt is returned when custom profile has empty validation prompt."""
+def test_get_validation_system_prompt_with_custom_profile_empty_validation_prompt():
+    """Default validation prompt when profile sets 'validation' to empty."""

327-327: Fix Pylint: line too long

Shorter docstring.

Apply:

-def test_get_invalid_query_response_with_custom_profile_no_invalid_resp():
-    """Test that default invalid query response is returned when custom profile has no invalid_resp."""
+def test_get_invalid_query_response_with_custom_profile_no_invalid_resp():
+    """Default invalid response when profile lacks 'invalid_resp'."""

343-343: Fix Pylint: line too long

Shorter docstring.

Apply:

-def test_get_invalid_query_response_with_custom_profile_empty_invalid_resp():
-    """Test that default invalid query response is returned when custom profile has empty invalid_resp."""
+def test_get_invalid_query_response_with_custom_profile_empty_invalid_resp():
+    """Default invalid response when 'invalid_resp' is empty."""

235-256: Consider adding whitespace-only coverage

Add cases where profile returns " " for "validation" and "invalid_resp" to assert fallback to defaults.

docs/openapi.json (2)

1055-1058: Document the new Configuration.question_validation field

Add a brief description for discoverability.

Apply:

         "question_validation": {
-            "$ref": "#/components/schemas/QuestionValidationConfiguration"
+            "$ref": "#/components/schemas/QuestionValidationConfiguration",
+            "description": "Feature toggle and settings for pre-query validation."
         }

2293-2305: Clarify property semantics

Describe what enabling validation does to client behavior. No wire-format changes.

Apply:

     "QuestionValidationConfiguration": {
       "properties": {
         "question_validation_enabled": {
           "type": "boolean",
           "title": "Question Validation Enabled",
-          "default": false
+          "default": false,
+          "description": "When true, incoming questions are vetted by a lightweight validator before invoking the main model. Invalid questions return a configured message and are recorded in transcripts."
         }
       },
tests/unit/app/endpoints/test_streaming_query.py (1)

1785-1797: Also assert SSE content-type for valid-path streaming

Add a quick check to ensure we always return event-stream.

Apply:

@@
     # Verify the response is a StreamingResponse
     assert isinstance(response, StreamingResponse)
+    assert response.media_type == "text/event-stream"
tests/unit/app/endpoints/test_query.py (1)

173-176: Docstring drift in retrieve_response (prod code)

retrieve_response now returns a triple (summary, conversation_id, is_valid) but its docstring (src/app/endpoints/query.py) still advertises a pair. Update the docstring to avoid confusion.

Suggested change (in src/app/endpoints/query.py):

-    Returns:
-        tuple[TurnSummary, str]: A tuple containing a summary of the LLM or agent's response content
-        and the conversation ID.
+    Returns:
+        tuple[TurnSummary, str, bool]: (summary, conversation_id, query_is_valid).
src/app/endpoints/query.py (1)

396-408: Fix docstring: function returns (bool, str), not “one-word response”.

Docstring is misleading and omits the conversation_id in the return. Clarify behavior.

-async def validate_question(
-    question: str, client: AsyncLlamaStackClient, model_id: str
-) -> tuple[bool, str]:
-    """Validate a question and provides a one-word response.
+async def validate_question(
+    question: str, client: AsyncLlamaStackClient, model_id: str
+) -> tuple[bool, str]:
+    """Validate a question using a temporary agent and return the result.
@@
-    Returns:
-        bool: True if the question was deemed valid, False otherwise
+    Returns:
+        tuple[bool, str]: (is_valid, validation_conversation_id)
src/app/endpoints/streaming_query.py (1)

713-713: Set explicit SSE media type for the valid streaming path.

Helps clients and proxies treat the stream correctly.

-        return StreamingResponse(response_generator(response, query_is_valid))
+        return StreamingResponse(
+            response_generator(response, query_is_valid),
+            media_type="text/event-stream",
+        )
tests/unit/utils/test_agent.py (1)

127-176: Reduce duplication via parametrization.

Several tests differ only in shields or MCP servers. Consider pytest.mark.parametrize to shrink boilerplate and speed execution.

Also applies to: 179-229, 231-283

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b14d91c and 1110d88.

📒 Files selected for processing (13)
  • docs/openapi.json (3 hunks)
  • src/app/endpoints/query.py (8 hunks)
  • src/app/endpoints/streaming_query.py (6 hunks)
  • src/configuration.py (2 hunks)
  • src/constants.py (1 hunks)
  • src/models/config.py (4 hunks)
  • src/utils/agent.py (1 hunks)
  • src/utils/endpoints.py (1 hunks)
  • tests/unit/app/endpoints/test_query.py (37 hunks)
  • tests/unit/app/endpoints/test_streaming_query.py (4 hunks)
  • tests/unit/models/config/test_dump_configuration.py (2 hunks)
  • tests/unit/utils/test_agent.py (1 hunks)
  • tests/unit/utils/test_endpoints.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-06T06:02:21.060Z
Learnt from: eranco74
PR: lightspeed-core/lightspeed-stack#348
File: src/utils/endpoints.py:91-94
Timestamp: 2025-08-06T06:02:21.060Z
Learning: The direct assignment to `agent._agent_id` in `src/utils/endpoints.py` is a necessary workaround for the missing agent rehydration feature in the LLS client SDK. This allows preserving conversation IDs when handling existing agents.

Applied to files:

  • src/utils/agent.py
🧬 Code graph analysis (10)
src/configuration.py (1)
src/models/config.py (1)
  • QuestionValidationConfiguration (235-238)
tests/unit/utils/test_endpoints.py (4)
src/models/requests.py (1)
  • QueryRequest (72-222)
src/models/config.py (4)
  • Action (311-350)
  • CustomProfile (425-452)
  • get_prompts (446-448)
  • get_query_responses (450-452)
src/utils/endpoints.py (3)
  • validate_model_provider_override (137-157)
  • get_validation_system_prompt (66-77)
  • get_invalid_query_response (80-91)
src/configuration.py (3)
  • AppConfig (33-140)
  • init_from_dict (56-58)
  • customization (115-119)
tests/unit/app/endpoints/test_streaming_query.py (4)
src/configuration.py (1)
  • question_validation (136-140)
tests/unit/app/endpoints/test_query.py (2)
  • mock_metrics (50-55)
  • mock_database_operations (58-63)
src/models/requests.py (1)
  • QueryRequest (72-222)
src/app/endpoints/streaming_query.py (1)
  • streaming_query_endpoint_handler (530-725)
src/utils/agent.py (2)
src/utils/suid.py (1)
  • get_suid (6-12)
src/utils/types.py (2)
  • GraniteToolParser (26-40)
  • get_parser (36-40)
tests/unit/app/endpoints/test_query.py (4)
src/configuration.py (1)
  • question_validation (136-140)
src/app/endpoints/query.py (2)
  • retrieve_response (427-588)
  • query_endpoint_handler (156-271)
tests/unit/app/endpoints/test_conversations.py (1)
  • dummy_request (30-40)
tests/unit/app/endpoints/test_streaming_query.py (2)
  • mock_metrics (62-67)
  • mock_database_operations (53-59)
src/app/endpoints/query.py (4)
src/utils/agent.py (2)
  • get_agent (18-67)
  • get_temp_agent (70-101)
src/utils/endpoints.py (3)
  • check_configuration_loaded (57-63)
  • get_invalid_query_response (80-91)
  • get_validation_system_prompt (66-77)
src/configuration.py (2)
  • configuration (61-65)
  • question_validation (136-140)
src/utils/types.py (1)
  • TurnSummary (59-78)
src/utils/endpoints.py (2)
src/models/config.py (3)
  • config (132-138)
  • get_prompts (446-448)
  • get_query_responses (450-452)
src/configuration.py (2)
  • AppConfig (33-140)
  • customization (115-119)
src/models/config.py (1)
src/configuration.py (1)
  • question_validation (136-140)
src/app/endpoints/streaming_query.py (6)
src/utils/agent.py (1)
  • get_agent (18-67)
src/utils/endpoints.py (2)
  • get_system_prompt (94-134)
  • get_invalid_query_response (80-91)
src/utils/transcripts.py (1)
  • store_transcript (33-86)
src/utils/types.py (1)
  • TurnSummary (59-78)
src/app/endpoints/query.py (3)
  • get_rag_toolgroups (623-650)
  • validate_question (396-424)
  • is_transcripts_enabled (73-79)
src/configuration.py (2)
  • configuration (61-65)
  • question_validation (136-140)
tests/unit/utils/test_agent.py (2)
src/configuration.py (3)
  • configuration (61-65)
  • AppConfig (33-140)
  • init_from_dict (56-58)
src/utils/agent.py (2)
  • get_agent (18-67)
  • get_temp_agent (70-101)
🪛 GitHub Actions: Python linter
tests/unit/utils/test_endpoints.py

[error] 264-264: Pylint: Line too long (110/100) in tests/unit/utils/test_endpoints.py:264. Command: 'uv run pylint src tests'


[error] 281-281: Pylint: Line too long (113/100) in tests/unit/utils/test_endpoints.py:281. Command: 'uv run pylint src tests'


[error] 327-327: Pylint: Line too long (103/100) in tests/unit/utils/test_endpoints.py:327. Command: 'uv run pylint src tests'


[error] 343-343: Pylint: Line too long (106/100) in tests/unit/utils/test_endpoints.py:343. Command: 'uv run pylint src tests'

src/utils/agent.py

[error] 78-78: Pylint: Line too long (110/100) in src/utils/agent.py:78. Command: 'uv run pylint src tests'

src/app/endpoints/streaming_query.py

[error] 530-530: Pylint: Too many statements (55/50) in src/app/endpoints/streaming_query.py:530. Command: 'uv run pylint src tests'

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-pr
  • GitHub Check: e2e_tests
🔇 Additional comments (12)
src/utils/agent.py (2)

46-49: Agent ID rebind workaround is acceptable.

Directly setting agent._agent_id to preserve conversation continuity matches prior guidance and keeps transcripts coherent.

Please confirm this remains required with current Llama Stack SDK; if rehydration lands, we can remove the create/delete swap.


30-33: Catch the client's NotFoundError instead of suppressing ValueError

File: src/utils/agent.py lines 30–33

        with suppress(ValueError):
            agent_response = await client.agents.retrieve(agent_id=conversation_id)
            existing_agent_id = agent_response.agent_id

suppress(ValueError) only hides ValueError; client.agents.retrieve will surface client-specific errors (e.g., llama_stack_client.NotFoundError or other API errors) when an agent is missing — catch the client's NotFoundError (or the client's base API error) explicitly (import from llama_stack_client) or broaden the handled exception after confirming the client contract.

src/configuration.py (1)

22-23: Expose question_validation in AppConfig (LGTM).

Import and property wiring look correct and consistent with the rest of the accessors.

Also applies to: 135-141

src/constants.py (1)

41-42: Default invalid-query response (LGTM).

Reasonable default; can be overridden via config/profile.

src/models/config.py (2)

235-239: QuestionValidationConfiguration addition (LGTM).

Sane default off; integrates cleanly with top-level Configuration.


512-515: Configuration.question_validation wiring (LGTM).

Field with default_factory ensures presence in serialized config.

docs/openapi.json (1)

1293-1300: LGTM: CustomProfile.query_responses exposure

Schema aligns with code (get_query_responses), enabling configuration-driven invalid responses.

src/app/endpoints/query.py (2)

499-514: Validation short‑circuit in retrieve_response: LGTM.

Early return with a friendly invalid‑query message and transcript flagging is correct.


220-227: Resolved — callers already handle the 3‑tuple return

  • src/app/endpoints/query.py:220 unpacks (summary, conversation_id, query_is_valid).
  • src/app/endpoints/streaming_query.py calls its local retrieve_response (returns (response, conversation_id)) and is unaffected.
  • tests/unit/app/endpoints/test_query.py use 3‑tuple unpacking; streaming tests use 2‑tuple.
tests/unit/utils/test_agent.py (2)

10-36: Config fixture setup: LGTM.

Good minimal AppConfig for isolating agent tests.


385-441: Parser path coverage: LGTM.

Solid check that GraniteToolParser is honored when no_tools=False.

src/app/endpoints/streaming_query.py (1)

530-530: Extract helpers to lower statements in streaming_query_endpoint_handler

Extract pre-validation and the "invalid transcript" SSE path into top-level helpers so the handler stays under pylint's statements limit.

File: src/app/endpoints/streaming_query.py

Add outside the handler:

def build_invalid_query_sse(conversation_id: str, text: str):
    return iter([
        stream_start_event(conversation_id),
        format_stream_data({"event": "turn_complete", "data": {"id": 0, "token": text}}),
        stream_end_event({}),
    ])

Inside the handler, replace the invalid-path return with:

return StreamingResponse(build_invalid_query_sse(conv_id, invalid_text), media_type="text/event-stream")

Verify locally (corrected command):

python -m pylint src/app/endpoints/streaming_query.py -j 0

Comment on lines +419 to +424
    response = cast(Turn, response)
    return (
        constants.SUBJECT_REJECTED
        not in interleaved_content_as_str(response.output_message.content),
        conversation_id,
    )

⚠️ Potential issue

Harden against missing output_message/content to avoid AttributeError.

Some turns may lack output_message.content; this will crash. Guard before parsing.

-    response = cast(Turn, response)
-    return (
-        constants.SUBJECT_REJECTED
-        not in interleaved_content_as_str(response.output_message.content),
-        conversation_id,
-    )
+    response = cast(Turn, response)
+    output_content = ""
+    if getattr(response, "output_message", None) and getattr(response.output_message, "content", None) is not None:
+        output_content = interleaved_content_as_str(response.output_message.content)
+    is_valid = constants.SUBJECT_REJECTED not in output_content
+    return (is_valid, conversation_id)

Comment on lines +597 to +629
        # Check question validation before getting response
        query_is_valid = True
        if configuration.question_validation.question_validation_enabled:
            query_is_valid, temp_agent_conversation_id = await validate_question(
                query_request.query, client, llama_stack_model_id
            )
            if not query_is_valid:
                response = get_invalid_query_response(configuration)
                if not is_transcripts_enabled():
                    logger.debug(
                        "Transcript collection is disabled in the configuration"
                    )
                else:
                    summary = TurnSummary(
                        llm_response=response,
                        tool_calls=[],
                    )
                    store_transcript(
                        user_id=user_id,
                        conversation_id=query_request.conversation_id
                        or temp_agent_conversation_id,
                        model_id=model_id,
                        provider_id=provider_id,
                        query_is_valid=query_is_valid,
                        query=query_request.query,
                        query_request=query_request,
                        summary=summary,
                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
                        truncated=False,  # TODO(lucasagomes): implement truncation
                        # as part of quota work
                        attachments=query_request.attachments or [],
                    )
                return StreamingResponse(response)
⚠️ Potential issue

Invalid streaming path returns plain text, breaking SSE clients.

For invalid queries you return StreamingResponse(str), not SSE events. Emit start/turn_complete/end events to keep the contract.
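As a rough illustration of the SSE framing this contract expects (the event names and payload shapes here are assumptions inferred from the snippets in this review, not the project's actual helpers):

```python
import json


def format_stream_data(payload: dict) -> str:
    """Frame one JSON payload as a Server-Sent Events message."""
    return f"data: {json.dumps(payload)}\n\n"


def invalid_query_events(conversation_id: str, message: str):
    """Yield a minimal start / turn_complete / end sequence for a rejected query."""
    yield format_stream_data({"event": "start", "data": {"conversation_id": conversation_id}})
    yield format_stream_data({"event": "turn_complete", "data": {"id": 0, "token": message}})
    yield format_stream_data({"event": "end", "data": {}})


body = "".join(invalid_query_events("conv-123", "Invalid query response"))
print(body)
```

Each yielded string is one complete SSE message (`data: ...` terminated by a blank line), which is what keeps streaming clients parsing instead of choking on a bare string body.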

-        if configuration.question_validation.question_validation_enabled:
-            query_is_valid, temp_agent_conversation_id = await validate_question(
-                query_request.query, client, llama_stack_model_id
-            )
-            if not query_is_valid:
-                response = get_invalid_query_response(configuration)
-                if not is_transcripts_enabled():
-                    logger.debug(
-                        "Transcript collection is disabled in the configuration"
-                    )
-                else:
-                    summary = TurnSummary(
-                        llm_response=response,
-                        tool_calls=[],
-                    )
-                    store_transcript(
-                        user_id=user_id,
-                        conversation_id=query_request.conversation_id
-                        or temp_agent_conversation_id,
-                        model_id=model_id,
-                        provider_id=provider_id,
-                        query_is_valid=query_is_valid,
-                        query=query_request.query,
-                        query_request=query_request,
-                        summary=summary,
-                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
-                        truncated=False,  # TODO(lucasagomes): implement truncation as part
-                        # of quota work
-                        attachments=query_request.attachments or [],
-                    )
-                return StreamingResponse(response)
+        if configuration.question_validation.question_validation_enabled:
+            query_is_valid, temp_agent_conversation_id = await validate_question(
+                query_request.query, client, llama_stack_model_id
+            )
+            if not query_is_valid:
+                invalid_text = get_invalid_query_response(configuration)
+                if is_transcripts_enabled():
+                    summary = TurnSummary(llm_response=invalid_text, tool_calls=[])
+                    store_transcript(
+                        user_id=user_id,
+                        conversation_id=query_request.conversation_id or temp_agent_conversation_id,
+                        model_id=model_id,
+                        provider_id=provider_id,
+                        query_is_valid=query_is_valid,
+                        query=query_request.query,
+                        query_request=query_request,
+                        summary=summary,
+                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                        truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
+                        attachments=query_request.attachments or [],
+                    )
+                # Return minimal SSE stream for invalid query
+                conv_id = query_request.conversation_id or temp_agent_conversation_id
+                return StreamingResponse(
+                    iter(
+                        [
+                            stream_start_event(conv_id),
+                            format_stream_data({"event": "turn_complete", "data": {"id": 0, "token": invalid_text}}),
+                            stream_end_event({}),
+                        ]
+                    ),
+                    media_type="text/event-stream",
+                )
Suggested change
-        # Check question validation before getting response
-        query_is_valid = True
-        if configuration.question_validation.question_validation_enabled:
-            query_is_valid, temp_agent_conversation_id = await validate_question(
-                query_request.query, client, llama_stack_model_id
-            )
-            if not query_is_valid:
-                response = get_invalid_query_response(configuration)
-                if not is_transcripts_enabled():
-                    logger.debug(
-                        "Transcript collection is disabled in the configuration"
-                    )
-                else:
-                    summary = TurnSummary(
-                        llm_response=response,
-                        tool_calls=[],
-                    )
-                    store_transcript(
-                        user_id=user_id,
-                        conversation_id=query_request.conversation_id
-                        or temp_agent_conversation_id,
-                        model_id=model_id,
-                        provider_id=provider_id,
-                        query_is_valid=query_is_valid,
-                        query=query_request.query,
-                        query_request=query_request,
-                        summary=summary,
-                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
-                        truncated=False,  # TODO(lucasagomes): implement truncation as part
-                        # of quota work
-                        attachments=query_request.attachments or [],
-                    )
-                return StreamingResponse(response)
+        # Check question validation before getting response
+        query_is_valid = True
+        if configuration.question_validation.question_validation_enabled:
+            query_is_valid, temp_agent_conversation_id = await validate_question(
+                query_request.query, client, llama_stack_model_id
+            )
+            if not query_is_valid:
+                invalid_text = get_invalid_query_response(configuration)
+                if is_transcripts_enabled():
+                    summary = TurnSummary(llm_response=invalid_text, tool_calls=[])
+                    store_transcript(
+                        user_id=user_id,
+                        conversation_id=query_request.conversation_id or temp_agent_conversation_id,
+                        model_id=model_id,
+                        provider_id=provider_id,
+                        query_is_valid=query_is_valid,
+                        query=query_request.query,
+                        query_request=query_request,
+                        summary=summary,
+                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                        truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
+                        attachments=query_request.attachments or [],
+                    )
+                # Return minimal SSE stream for invalid query
+                conv_id = query_request.conversation_id or temp_agent_conversation_id
+                return StreamingResponse(
+                    iter(
+                        [
+                            stream_start_event(conv_id),
+                            format_stream_data({"event": "turn_complete", "data": {"id": 0, "token": invalid_text}}),
+                            stream_end_event({}),
+                        ]
+                    ),
+                    media_type="text/event-stream",
+                )

conversation_id = agent.agent_id
session_id = await agent.create_session(get_suid())

return agent, conversation_id, session_id
💡 Verification agent

🧩 Analysis chain

Standardize return tuple order between get_agent and get_temp_agent.

get_agent returns (agent, conversation_id, session_id) while get_temp_agent returns (agent, session_id, conversation_id). This will break unpacking at call sites. Standardize on (agent, session_id, conversation_id) for both.

Apply this diff:

-    return agent, conversation_id, session_id
+    return agent, session_id, conversation_id

Run to find unpacking patterns and usages:

Also applies to: 101-101


🏁 Script executed:

#!/bin/bash
# Expect unpacking or indexing to match (agent, session_id, conversation_id)
rg -n -C2 -P '\bget_(temp_)?agent\s*\('
rg -n -C2 -P '\(\s*\w+\s*,\s*\w+\s*,\s*\w+\s*\)\s*=\s*await\s+get_(temp_)?agent'

Length of output: 5262


Standardize return tuple order — make get_temp_agent match get_agent (agent, conversation_id, session_id)

get_agent returns (agent, conversation_id, session_id) and is unpacked that way across the codebase/tests; get_temp_agent returns (agent, session_id, conversation_id) causing inconsistent unpacking. Do NOT change get_agent without updating callers/tests. Action: change get_temp_agent to return (agent, conversation_id, session_id) and update its callers/tests accordingly.

Examples of affected locations: src/app/endpoints/query.py (lines 410, 487), src/app/endpoints/streaming_query.py (line 783), tests/unit/utils/test_agent.py (lines 56, 100, 153, 205, 259, 358, 416, 456, 494, 526).

🤖 Prompt for AI Agents
In src/utils/agent.py around line 67, get_temp_agent currently returns (agent,
session_id, conversation_id) which is inconsistent with get_agent; change the
return tuple to (agent, conversation_id, session_id) to match get_agent, then
update all callers and tests that unpack the old order—specifically adjust
usages in src/app/endpoints/query.py (lines ~410, ~487),
src/app/endpoints/streaming_query.py (~783) and the listed tests in
tests/unit/utils/test_agent.py (lines ~56, 100, 153, 205, 259, 358, 416, 456,
494, 526) to expect and unpack (agent, conversation_id, session_id).
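The hazard with mismatched tuple orders is that unpacking succeeds but silently swaps values; a toy illustration (function bodies and return values here are hypothetical stand-ins):

```python
def get_agent():
    # returns (agent, conversation_id, session_id)
    return "agent", "conv-1", "sess-1"


def get_temp_agent():
    # returns (agent, session_id, conversation_id) -- inconsistent order
    return "agent", "sess-2", "conv-2"


# A caller that assumes get_agent's order silently swaps the two IDs:
agent, conversation_id, session_id = get_temp_agent()
print(conversation_id)  # "sess-2" -- a session id where a conversation id was expected
```

No exception is raised, so only a review of call sites (or tests asserting the IDs) catches the swap.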

Comment on lines +1592 to +1678
@pytest.mark.asyncio
async def test_streaming_query_endpoint_with_question_validation_invalid_query(
    setup_configuration, mocker
):
    """Test streaming query endpoint with question validation enabled and invalid query."""
    # Mock metrics
    mock_metrics(mocker)

    # Mock database operations
    mock_database_operations(mocker)

    # Setup configuration with question validation enabled
    setup_configuration.question_validation.question_validation_enabled = True
    mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration)

    # Mock the client
    mock_client = mocker.AsyncMock()
    mock_client.models.list.return_value = [
        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1")
    ]
    mocker.patch(
        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
    )

    # Mock the validation agent response (invalid query)
    mock_validation_agent = mocker.AsyncMock()
    mock_validation_agent.agent_id = "validation_agent_id"
    mock_validation_agent.create_session.return_value = "validation_session_id"

    # Mock the validation response that contains SUBJECT_REJECTED
    mock_validation_turn = mocker.Mock()
    mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}]
    mock_validation_agent.create_turn.return_value = mock_validation_turn

    # Mock the validation functions
    mocker.patch(
        "app.endpoints.streaming_query.validate_question",
        return_value=(False, "validation_agent_id"),
    )
    mocker.patch(
        "app.endpoints.streaming_query.get_invalid_query_response",
        return_value="Invalid query response",
    )
    mocker.patch(
        "app.endpoints.streaming_query.interleaved_content_as_str",
        return_value="REJECTED",
    )

    # Mock other dependencies
    mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
    mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")
    mocker.patch(
        "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
    )
    mocker.patch(
        "app.endpoints.streaming_query.select_model_and_provider_id",
        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
    )

    query_request = QueryRequest(query="Invalid question about unrelated topic")

    request = Request(
        scope={
            "type": "http",
        }
    )
    request.state.authorized_actions = set(Action)

    response = await streaming_query_endpoint_handler(
        request=request, query_request=query_request, auth=MOCK_AUTH
    )

    # Verify the response is a StreamingResponse
    assert isinstance(response, StreamingResponse)

    # Collect the streaming response content
    streaming_content = []
    async for chunk in response.body_iterator:
        streaming_content.append(chunk)

    # Convert to string for assertions
    full_content = "".join(streaming_content)

    # Verify the response contains the invalid query response
    assert "Invalid query response" in full_content


💡 Verification agent

🧩 Analysis chain

Streaming invalid path should emit proper SSE with event-stream content-type

Current streaming endpoint returns StreamingResponse with a raw string for invalid queries, which is not SSE and likely sets the wrong media type. Strengthen this test and fix the handler to always emit SSE.

Apply to the test to catch regressions:

@@
     response = await streaming_query_endpoint_handler(
         request=request, query_request=query_request, auth=MOCK_AUTH
     )
 
     # Verify the response is a StreamingResponse
     assert isinstance(response, StreamingResponse)
+    # Must be SSE
+    assert response.media_type == "text/event-stream"
@@
-    # Verify the response contains the invalid query response
+    # Verify the response contains the invalid query response
     assert "Invalid query response" in full_content
+    # And is wrapped as SSE tokens
+    assert "data: " in full_content

Proposed handler fix (in app/endpoints/streaming_query.py) to ensure SSE for invalid path:

@@ async def streaming_query_endpoint_handler(...):
-        if configuration.question_validation.question_validation_enabled:
+        if configuration.question_validation.question_validation_enabled:
             query_is_valid, temp_agent_conversation_id = await validate_question(
                 query_request.query, client, llama_stack_model_id
             )
             if not query_is_valid:
-                response = get_invalid_query_response(configuration)
-                if not is_transcripts_enabled():
-                    logger.debug(
-                        "Transcript collection is disabled in the configuration"
-                    )
-                else:
-                    summary = TurnSummary(
-                        llm_response=response,
-                        tool_calls=[],
-                    )
-                    store_transcript(
-                        user_id=user_id,
-                        conversation_id=query_request.conversation_id
-                        or temp_agent_conversation_id,
-                        model_id=model_id,
-                        provider_id=provider_id,
-                        query_is_valid=query_is_valid,
-                        query=query_request.query,
-                        query_request=query_request,
-                        summary=summary,
-                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
-                        truncated=False,  # TODO(lucasagomes): implement truncation as part
-                        # of quota work
-                        attachments=query_request.attachments or [],
-                    )
-                return StreamingResponse(response)
+                invalid_msg = get_invalid_query_response(configuration)
+                async def invalid_sse() -> AsyncIterator[str]:
+                    yield stream_start_event(query_request.conversation_id or temp_agent_conversation_id or "")
+                    # send one token event with the invalid message
+                    yield f'data: {json.dumps({"event": "token", "token": invalid_msg, "role": "system"})}\n\n'
+                    yield stream_end_event({})
+                if is_transcripts_enabled():
+                    store_transcript(
+                        user_id=user_id,
+                        conversation_id=query_request.conversation_id or temp_agent_conversation_id,
+                        model_id=model_id,
+                        provider_id=provider_id,
+                        query_is_valid=False,
+                        query=query_request.query,
+                        query_request=query_request,
+                        summary=TurnSummary(llm_response=invalid_msg, tool_calls=[]),
+                        rag_chunks=[],
+                        truncated=False,
+                        attachments=query_request.attachments or [],
+                    )
+                return StreamingResponse(invalid_sse(), media_type="text/event-stream")

🏁 Script executed:

#!/bin/bash
set -euo pipefail
echo "== Search for handler definition and related symbols =="
rg -n --hidden --no-ignore-vcs "async def streaming_query_endpoint_handler" -S || true
rg -n --hidden --no-ignore-vcs "def streaming_query_endpoint_handler" -S || true
rg -n --hidden --no-ignore-vcs "streaming_query_endpoint_handler" -S || true

echo
echo "== Locate app/endpoints/streaming_query.py =="
FILE=$(rg --files | rg "app/endpoints/streaming_query.py" || true)
if [ -n "$FILE" ]; then
  echo "Found: $FILE"
  echo "---- head (1-300) ----"
  sed -n '1,300p' "$FILE"
  echo "---- tail (301-800) ----"
  sed -n '301,800p' "$FILE" || true
else
  echo "app/endpoints/streaming_query.py not found by exact path; searching for files mentioning 'streaming_query'"
  rg -n --hidden --no-ignore-vcs "streaming_query" -S -g '!**/node_modules/**' || true
fi

echo
echo "== Search for helpers and patterns referenced in review comment =="
rg -n --hidden --no-ignore-vcs "get_invalid_query_response" -S -C3 || true
rg -n --hidden --no-ignore-vcs "validate_question" -S -C3 || true
rg -n --hidden --no-ignore-vcs "is_transcripts_enabled" -S -C3 || true
rg -n --hidden --no-ignore-vcs "store_transcript" -S -C3 || true
rg -n --hidden --no-ignore-vcs "stream_start_event|stream_end_event" -S -C3 || true
rg -n --hidden --no-ignore-vcs "StreamingResponse" -S -C3 || true
rg -n --hidden --no-ignore-vcs "text/event-stream" -S -C3 || true

echo
echo "== Show any StreamingResponse returns inside streaming_query module(s) =="
rg -n --hidden --no-ignore-vcs "return\s+StreamingResponse" -S -C5 || true

echo
echo "== Done =="

Length of output: 78021


Streaming invalid path should emit proper SSE with event-stream content-type

Handler currently returns a raw string via StreamingResponse for invalid-question validation; emit SSE and set media_type="text/event-stream". Strengthen test to catch regressions. Also set media_type for the normal streaming return.

Apply to the test to catch regressions:

@@
     response = await streaming_query_endpoint_handler(
         request=request, query_request=query_request, auth=MOCK_AUTH
     )
 
     # Verify the response is a StreamingResponse
     assert isinstance(response, StreamingResponse)
+    # Must be SSE
+    assert response.media_type == "text/event-stream"
@@
-    # Verify the response contains the invalid query response
+    # Verify the response contains the invalid query response
     assert "Invalid query response" in full_content
+    # And is wrapped as SSE tokens
+    assert "data: " in full_content

Proposed handler fix (in src/app/endpoints/streaming_query.py) to ensure SSE for invalid path:

@@ async def streaming_query_endpoint_handler(...):
-        if configuration.question_validation.question_validation_enabled:
+        if configuration.question_validation.question_validation_enabled:
             query_is_valid, temp_agent_conversation_id = await validate_question(
                 query_request.query, client, llama_stack_model_id
             )
             if not query_is_valid:
-                response = get_invalid_query_response(configuration)
-                if not is_transcripts_enabled():
-                    logger.debug(
-                        "Transcript collection is disabled in the configuration"
-                    )
-                else:
-                    summary = TurnSummary(
-                        llm_response=response,
-                        tool_calls=[],
-                    )
-                    store_transcript(
-                        user_id=user_id,
-                        conversation_id=query_request.conversation_id
-                        or temp_agent_conversation_id,
-                        model_id=model_id,
-                        provider_id=provider_id,
-                        query_is_valid=query_is_valid,
-                        query=query_request.query,
-                        query_request=query_request,
-                        summary=summary,
-                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
-                        truncated=False,  # TODO(lucasagomes): implement truncation as part
-                        # of quota work
-                        attachments=query_request.attachments or [],
-                    )
-                return StreamingResponse(response)
+                invalid_msg = get_invalid_query_response(configuration)
+                async def invalid_sse() -> AsyncIterator[str]:
+                    yield stream_start_event(query_request.conversation_id or temp_agent_conversation_id or "")
+                    # send one token event with the invalid message
+                    yield f'data: {json.dumps({"event": "token", "token": invalid_msg, "role": "system"})}\n\n'
+                    yield stream_end_event({})
+                if is_transcripts_enabled():
+                    store_transcript(
+                        user_id=user_id,
+                        conversation_id=query_request.conversation_id or temp_agent_conversation_id,
+                        model_id=model_id,
+                        provider_id=provider_id,
+                        query_is_valid=False,
+                        query=query_request.query,
+                        query_request=query_request,
+                        summary=TurnSummary(llm_response=invalid_msg, tool_calls=[]),
+                        rag_chunks=[],
+                        truncated=False,
+                        attachments=query_request.attachments or [],
+                    )
+                return StreamingResponse(invalid_sse(), media_type="text/event-stream")

Also change the normal successful stream return to explicitly set media_type, e.g.:

-        return StreamingResponse(response_generator(response, query_is_valid))
+        return StreamingResponse(response_generator(response, query_is_valid), media_type="text/event-stream")
Suggested change
@pytest.mark.asyncio
async def test_streaming_query_endpoint_with_question_validation_invalid_query(
    setup_configuration, mocker
):
    """Test streaming query endpoint with question validation enabled and invalid query."""
    # Mock metrics
    mock_metrics(mocker)
    # Mock database operations
    mock_database_operations(mocker)
    # Setup configuration with question validation enabled
    setup_configuration.question_validation.question_validation_enabled = True
    mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration)
    # Mock the client
    mock_client = mocker.AsyncMock()
    mock_client.models.list.return_value = [
        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1")
    ]
    mocker.patch(
        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
    )
    # Mock the validation agent response (invalid query)
    mock_validation_agent = mocker.AsyncMock()
    mock_validation_agent.agent_id = "validation_agent_id"
    mock_validation_agent.create_session.return_value = "validation_session_id"
    # Mock the validation response that contains SUBJECT_REJECTED
    mock_validation_turn = mocker.Mock()
    mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}]
    mock_validation_agent.create_turn.return_value = mock_validation_turn
    # Mock the validation functions
    mocker.patch(
        "app.endpoints.streaming_query.validate_question",
        return_value=(False, "validation_agent_id"),
    )
    mocker.patch(
        "app.endpoints.streaming_query.get_invalid_query_response",
        return_value="Invalid query response",
    )
    mocker.patch(
        "app.endpoints.streaming_query.interleaved_content_as_str",
        return_value="REJECTED",
    )
    # Mock other dependencies
    mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
    mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")
    mocker.patch(
        "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
    )
    mocker.patch(
        "app.endpoints.streaming_query.select_model_and_provider_id",
        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
    )
    query_request = QueryRequest(query="Invalid question about unrelated topic")
    request = Request(
        scope={
            "type": "http",
        }
    )
    request.state.authorized_actions = set(Action)
    response = await streaming_query_endpoint_handler(
        request=request, query_request=query_request, auth=MOCK_AUTH
    )
    # Verify the response is a StreamingResponse
    assert isinstance(response, StreamingResponse)
+    # Must be SSE
+    assert response.media_type == "text/event-stream"
    # Collect the streaming response content
    streaming_content = []
    async for chunk in response.body_iterator:
        streaming_content.append(chunk)
    # Convert to string for assertions
    full_content = "".join(streaming_content)
    # Verify the response contains the invalid query response
    assert "Invalid query response" in full_content
+    # And is wrapped as SSE tokens
+    assert "data: " in full_content
🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_streaming_query.py around lines 1592-1678, the
test verifies invalid-question validation returns a StreamingResponse but
doesn't assert SSE content-type; update both the handler and test: change the
streaming_query handler to return StreamingResponse(...,
media_type="text/event-stream") for the invalid-question branch (and ensure the
normal successful-stream branch also sets media_type="text/event-stream"), then
update this test to assert response.media_type == "text/event-stream" (and keep
the existing body iteration assertions) so regressions will be caught.

@tisnik tisnik left a comment

just question: won't it be possible to use LS provider for question validation? I dunno, having this in the core might be fine, OTOH LS prefers this functionality to be deferred to providers.

Pls see https://github.com/lightspeed-core/lightspeed-providers/tree/main/lightspeed_stack_providers/providers/inline/safety/lightspeed_question_validity

@yangcao77
Copy link
Contributor Author

just question: won't it be possible to use LS provider for question validation? I dunno, having this in the core might be fine, OTOH LS prefers this functionality to be deferred to providers.

Pls see https://github.com/lightspeed-core/lightspeed-providers/tree/main/lightspeed_stack_providers/providers/inline/safety/lightspeed_question_validity

@tisnik thanks for this information. I have tried the safety shield and it works for us.
I opened a PR, lightspeed-core/lightspeed-providers#34; the imports need to be updated to make the shields work.

The only concerns are that the model_id needs to be preset and hard-coded, and that with the safety shield set at the llama-stack server level, question validation cannot easily be disabled via the runtime configuration.

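For contrast, the in-core feature this PR adds stays toggleable at runtime through the service configuration (as shown in the PR description):

```yaml
question_validation:
  question_validation_enabled: true
```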