[RHDHPAI-1097] Question validation #541

Conversation
Signed-off-by: Stephanie <yangcao@redhat.com>
…to question-validation
Walkthrough

Adds question validation capability. Introduces configuration schema and constants, utilities for validation prompts/responses, and a new utils.agent module. Integrates validation into query and streaming_query endpoints, propagates a query_is_valid flag, and updates transcript storage. OpenAPI, config models, and tests are extended accordingly. get_agent moved to utils.agent.

Changes
Sequence Diagram(s)

sequenceDiagram
autonumber
actor U as User
participant API as Query/Streaming Endpoint
participant VC as Temp Validator Agent
participant AG as Agent (persistent)
participant LS as Llama Stack
participant TS as Transcript Store
U->>API: Submit query
API->>API: Check config.question_validation_enabled
alt Validation enabled
API->>VC: Create temp agent (validation prompt)
API->>VC: Ask to validate query
VC-->>API: REJECTED/ALLOWED + validation convo_id
alt REJECTED
API->>U: Invalid query response
API->>TS: Store transcript (query_is_valid=false)
else ALLOWED
API->>AG: get_agent (reuse/create)
API->>LS: Retrieve/Stream response
LS-->>API: TurnSummary/Stream
API->>TS: Store transcript (query_is_valid=true)
API-->>U: Return response/stream
end
else Validation disabled
API->>AG: get_agent (reuse/create)
API->>LS: Retrieve/Stream response
LS-->>API: TurnSummary/Stream
API->>TS: Store transcript (query_is_valid=true)
API-->>U: Return response/stream
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Actionable comments posted: 4
🧹 Nitpick comments (22)
src/utils/agent.py (4)
75-87: Fix docstring: incorrect return description and Pylint line-too-long.

Docstring claims two return values and triggers the linter error (line > 100 chars). Clarify returns and wrap lines.
Apply this diff:
```diff
 async def get_temp_agent(
     client: AsyncLlamaStackClient,
     model_id: str,
     system_prompt: str,
 ) -> tuple[AsyncAgent, str, str]:
-    """Create a temporary agent with new agent_id and session_id.
-
-    This function creates a new agent without persistence, shields, or tools.
-    Useful for temporary operations or one-off queries, such as validating a question or generating a summary.
+    """Create a temporary agent without persistence, shields, or tools.
+
+    Useful for one-off operations (e.g., validate a question, generate a
+    summary).

     Args:
         client: The AsyncLlamaStackClient to use for the request.
         model_id: The ID of the model to use.
         system_prompt: The system prompt/instructions for the agent.

     Returns:
-        tuple[AsyncAgent, str]: A tuple containing the agent and session_id.
+        (agent, session_id, conversation_id).
     """
```
27-28: Document tuple order in get_agent docstring to prevent future drift.

Make the return contract explicit.
Apply this diff:
```diff
-    """Get existing agent or create a new one with session persistence."""
+    """Get an agent with session persistence.
+
+    Returns:
+        (agent, session_id, conversation_id).
+    """
```

Also applies to: 67-67
39-43: no_tools only disables tool parser; should it also disable shields?

If "no_tools" implies a minimal agent, consider also suppressing input/output shields; otherwise rename to reflect behavior.
Candidate change:
```diff
-    input_shields=available_input_shields if available_input_shields else [],
-    output_shields=available_output_shields if available_output_shields else [],
+    input_shields=[] if no_tools else (available_input_shields or []),
+    output_shields=[] if no_tools else (available_output_shields or []),
```
51-51: Reduce log level for session listing.

Session payloads can be noisy; prefer debug to avoid info-level noise.
Apply this diff:
```diff
-    logger.info("session response: %s", sessions_response)
+    logger.debug("session response: %s", sessions_response)
```

src/constants.py (1)
35-41: Tighten the default validation prompt for deterministic outputs.

Ask the validator to respond with exactly one token and nothing else to reduce parsing ambiguity.
Apply this diff:
```diff
 DEFAULT_VALIDATION_SYSTEM_PROMPT = (
-    "You are a helpful assistant that validates questions. You will be given a "
-    "question and you will need to validate if it is valid or not. You will "
-    f"return '{SUBJECT_REJECTED}' if the question is not valid and "
-    f"'{SUBJECT_ALLOWED}' if it is valid."
+    "You validate user questions. Decide if the question is acceptable.\n"
+    f"Reply with exactly one token: '{SUBJECT_ALLOWED}' or '{SUBJECT_REJECTED}'.\n"
+    "Do not include explanations, punctuation, or extra text."
 )
```

tests/unit/models/config/test_dump_configuration.py (1)
91-91: Dump includes question_validation section (LGTM).

Assertions cover presence and default False value.
Consider a follow-up test that sets question_validation_enabled=True and verifies serialization.
Also applies to: 167-168
src/models/config.py (1)
430-453: Avoid mutable default dicts in dataclass fields.

Using Field(default={}) shares the same dict across instances. Switch to default_factory for both prompts and query_responses. Update prompts too for consistency.
Apply this diff:
```diff
-    prompts: dict[str, str] = Field(default={}, init=False)
-    query_responses: dict[str, str] = Field(default={}, init=False)
+    prompts: dict[str, str] = Field(default_factory=dict, init=False)
+    query_responses: dict[str, str] = Field(default_factory=dict, init=False)
```

If desired, also return shallow copies to discourage mutation:
```diff
-        return self.prompts
+        return dict(self.prompts)
-        return self.query_responses
+        return dict(self.query_responses)
```

src/utils/endpoints.py (3)
66-77: Handle whitespace-only custom validation prompt

Return default when the custom profile's "validation" prompt is empty or whitespace-only.
Apply:
```diff
 def get_validation_system_prompt(config: AppConfig) -> str:
     """Get the validation system prompt."""
     # profile takes precedence for setting prompt
     if (
         config.customization is not None
         and config.customization.custom_profile is not None
     ):
         prompt = config.customization.custom_profile.get_prompts().get("validation")
-        if prompt:
-            return prompt
+        if prompt and prompt.strip():
+            return prompt.strip()
     return constants.DEFAULT_VALIDATION_SYSTEM_PROMPT
```
80-91: Treat whitespace-only invalid response as unset

Mirror the guard above for "invalid_resp".
Apply:
```diff
 def get_invalid_query_response(config: AppConfig) -> str:
     """Get the invalid query response."""
     if (
         config.customization is not None
         and config.customization.custom_profile is not None
     ):
         prompt = config.customization.custom_profile.get_query_responses().get(
             "invalid_resp"
         )
-        if prompt:
-            return prompt
+        if prompt and prompt.strip():
+            return prompt.strip()
     return constants.DEFAULT_INVALID_QUERY_RESPONSE
```
113-116: Comment nit: wrong option name

The comment mentions disable_system_prompt but the flag is disable_query_system_prompt.
Apply:
```diff
     # makes sense here - if the configuration wants precedence, it can
-    # disable query system prompt altogether with disable_system_prompt.
+    # disable query system prompt altogether with disable_query_system_prompt.
```

tests/unit/utils/test_endpoints.py (5)
264-264: Fix Pylint: line too long

Shorten docstring to satisfy 100-char limit.
Apply:
```diff
 def test_get_validation_system_prompt_with_custom_profile_no_validation_prompt():
-    """Test that default validation system prompt is returned when custom profile has no validation prompt."""
+    """Default validation prompt when profile lacks 'validation' key."""
```
281-281: Fix Pylint: line too long

Shorter docstring.
Apply:
```diff
 def test_get_validation_system_prompt_with_custom_profile_empty_validation_prompt():
-    """Test that default validation system prompt is returned when custom profile has empty validation prompt."""
+    """Default validation prompt when profile sets 'validation' to empty."""
```
327-327: Fix Pylint: line too long

Shorter docstring.
Apply:
```diff
 def test_get_invalid_query_response_with_custom_profile_no_invalid_resp():
-    """Test that default invalid query response is returned when custom profile has no invalid_resp."""
+    """Default invalid response when profile lacks 'invalid_resp'."""
```
343-343: Fix Pylint: line too long

Shorter docstring.
Apply:
```diff
 def test_get_invalid_query_response_with_custom_profile_empty_invalid_resp():
-    """Test that default invalid query response is returned when custom profile has empty invalid_resp."""
+    """Default invalid response when 'invalid_resp' is empty."""
```
235-256: Consider adding whitespace-only coverage

Add cases where profile returns " " for "validation" and "invalid_resp" to assert fallback to defaults.
docs/openapi.json (2)
1055-1058: Document the new Configuration.question_validation field

Add a brief description for discoverability.
Apply:
```diff
 "question_validation": {
-    "$ref": "#/components/schemas/QuestionValidationConfiguration"
+    "$ref": "#/components/schemas/QuestionValidationConfiguration",
+    "description": "Feature toggle and settings for pre-query validation."
 }
```
2293-2305: Clarify property semantics

Describe what enabling validation does to client behavior. No wire-format changes.
Apply:
```diff
 "QuestionValidationConfiguration": {
     "properties": {
         "question_validation_enabled": {
             "type": "boolean",
             "title": "Question Validation Enabled",
-            "default": false
+            "default": false,
+            "description": "When true, incoming questions are vetted by a lightweight validator before invoking the main model. Invalid questions return a configured message and are recorded in transcripts."
         }
     },
```
1785-1797: Also assert SSE content-type for valid-path streaming

Add a quick check to ensure we always return event-stream.
Apply:
```diff
@@
     # Verify the response is a StreamingResponse
     assert isinstance(response, StreamingResponse)
+    assert response.media_type == "text/event-stream"
```

tests/unit/app/endpoints/test_query.py (1)
173-176: Docstring drift in retrieve_response (prod code)

retrieve_response now returns a triple (summary, conversation_id, is_valid) but its docstring (src/app/endpoints/query.py) still advertises a pair. Update the docstring to avoid confusion.
Suggested change (in src/app/endpoints/query.py):
```diff
     Returns:
-        tuple[TurnSummary, str]: A tuple containing a summary of the LLM or agent's response content
-        and the conversation ID.
+        tuple[TurnSummary, str, bool]: (summary, conversation_id, query_is_valid).
```

src/app/endpoints/query.py (1)
396-408: Fix docstring: function returns (bool, str), not "one-word response".

Docstring is misleading and omits the conversation_id in the return. Clarify behavior.
```diff
 async def validate_question(
     question: str, client: AsyncLlamaStackClient, model_id: str
 ) -> tuple[bool, str]:
-    """Validate a question and provides a one-word response.
+    """Validate a question using a temporary agent and return the result.
@@
     Returns:
-        bool: True if the question was deemed valid, False otherwise
+        tuple[bool, str]: (is_valid, validation_conversation_id)
```

src/app/endpoints/streaming_query.py (1)
713-713: Set explicit SSE media type for the valid streaming path.

Helps clients and proxies treat the stream correctly.
```diff
-    return StreamingResponse(response_generator(response, query_is_valid))
+    return StreamingResponse(
+        response_generator(response, query_is_valid),
+        media_type="text/event-stream",
+    )
```

tests/unit/utils/test_agent.py (1)
127-176: Reduce duplication via parametrization.

Several tests differ only in shields or MCP servers. Consider pytest.mark.parametrize to shrink boilerplate and speed execution.
Also applies to: 179-229, 231-283
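As a sketch of that suggestion (the shield names below are invented for illustration, not taken from the repo's fixtures):

```python
import pytest


@pytest.mark.parametrize(
    "input_shields, output_shields",
    [
        ([], []),  # no shields configured
        (["guard_in"], []),  # hypothetical input-only shield
        (["guard_in"], ["guard_out"]),  # hypothetical shields on both sides
    ],
)
def test_shield_combinations(input_shields, output_shields):
    # Stand-in assertion: a real test would build the agent with these
    # shields and check the resulting configuration.
    assert isinstance(input_shields, list)
    assert isinstance(output_shields, list)
```

Each parameter tuple becomes its own test case, so the near-duplicate test bodies collapse into one function.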
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
docs/openapi.json (3 hunks)
src/app/endpoints/query.py (8 hunks)
src/app/endpoints/streaming_query.py (6 hunks)
src/configuration.py (2 hunks)
src/constants.py (1 hunks)
src/models/config.py (4 hunks)
src/utils/agent.py (1 hunks)
src/utils/endpoints.py (1 hunks)
tests/unit/app/endpoints/test_query.py (37 hunks)
tests/unit/app/endpoints/test_streaming_query.py (4 hunks)
tests/unit/models/config/test_dump_configuration.py (2 hunks)
tests/unit/utils/test_agent.py (1 hunks)
tests/unit/utils/test_endpoints.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-06T06:02:21.060Z
Learnt from: eranco74
PR: lightspeed-core/lightspeed-stack#348
File: src/utils/endpoints.py:91-94
Timestamp: 2025-08-06T06:02:21.060Z
Learning: The direct assignment to `agent._agent_id` in `src/utils/endpoints.py` is a necessary workaround for the missing agent rehydration feature in the LLS client SDK. This allows preserving conversation IDs when handling existing agents.
Applied to files:
src/utils/agent.py
🧬 Code graph analysis (10)
src/configuration.py (1)
  src/models/config.py (1)
    QuestionValidationConfiguration (235-238)
tests/unit/utils/test_endpoints.py (4)
  src/models/requests.py (1)
    QueryRequest (72-222)
  src/models/config.py (4)
    Action (311-350), CustomProfile (425-452), get_prompts (446-448), get_query_responses (450-452)
  src/utils/endpoints.py (3)
    validate_model_provider_override (137-157), get_validation_system_prompt (66-77), get_invalid_query_response (80-91)
  src/configuration.py (3)
    AppConfig (33-140), init_from_dict (56-58), customization (115-119)
tests/unit/app/endpoints/test_streaming_query.py (4)
  src/configuration.py (1)
    question_validation (136-140)
  tests/unit/app/endpoints/test_query.py (2)
    mock_metrics (50-55), mock_database_operations (58-63)
  src/models/requests.py (1)
    QueryRequest (72-222)
  src/app/endpoints/streaming_query.py (1)
    streaming_query_endpoint_handler (530-725)
src/utils/agent.py (2)
  src/utils/suid.py (1)
    get_suid (6-12)
  src/utils/types.py (2)
    GraniteToolParser (26-40), get_parser (36-40)
tests/unit/app/endpoints/test_query.py (4)
  src/configuration.py (1)
    question_validation (136-140)
  src/app/endpoints/query.py (2)
    retrieve_response (427-588), query_endpoint_handler (156-271)
  tests/unit/app/endpoints/test_conversations.py (1)
    dummy_request (30-40)
  tests/unit/app/endpoints/test_streaming_query.py (2)
    mock_metrics (62-67), mock_database_operations (53-59)
src/app/endpoints/query.py (4)
  src/utils/agent.py (2)
    get_agent (18-67), get_temp_agent (70-101)
  src/utils/endpoints.py (3)
    check_configuration_loaded (57-63), get_invalid_query_response (80-91), get_validation_system_prompt (66-77)
  src/configuration.py (2)
    configuration (61-65), question_validation (136-140)
  src/utils/types.py (1)
    TurnSummary (59-78)
src/utils/endpoints.py (2)
  src/models/config.py (3)
    config (132-138), get_prompts (446-448), get_query_responses (450-452)
  src/configuration.py (2)
    AppConfig (33-140), customization (115-119)
src/models/config.py (1)
  src/configuration.py (1)
    question_validation (136-140)
src/app/endpoints/streaming_query.py (6)
  src/utils/agent.py (1)
    get_agent (18-67)
  src/utils/endpoints.py (2)
    get_system_prompt (94-134), get_invalid_query_response (80-91)
  src/utils/transcripts.py (1)
    store_transcript (33-86)
  src/utils/types.py (1)
    TurnSummary (59-78)
  src/app/endpoints/query.py (3)
    get_rag_toolgroups (623-650), validate_question (396-424), is_transcripts_enabled (73-79)
  src/configuration.py (2)
    configuration (61-65), question_validation (136-140)
tests/unit/utils/test_agent.py (2)
  src/configuration.py (3)
    configuration (61-65), AppConfig (33-140), init_from_dict (56-58)
  src/utils/agent.py (2)
    get_agent (18-67), get_temp_agent (70-101)
🪛 GitHub Actions: Python linter
tests/unit/utils/test_endpoints.py
[error] 264-264: Pylint: Line too long (110/100) in tests/unit/utils/test_endpoints.py:264. Command: 'uv run pylint src tests'
[error] 281-281: Pylint: Line too long (113/100) in tests/unit/utils/test_endpoints.py:281. Command: 'uv run pylint src tests'
[error] 327-327: Pylint: Line too long (103/100) in tests/unit/utils/test_endpoints.py:327. Command: 'uv run pylint src tests'
[error] 343-343: Pylint: Line too long (106/100) in tests/unit/utils/test_endpoints.py:343. Command: 'uv run pylint src tests'
src/utils/agent.py
[error] 78-78: Pylint: Line too long (110/100) in src/utils/agent.py:78. Command: 'uv run pylint src tests'
src/app/endpoints/streaming_query.py
[error] 530-530: Pylint: Too many statements (55/50) in src/app/endpoints/streaming_query.py:530. Command: 'uv run pylint src tests'
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-pr
- GitHub Check: e2e_tests
🔇 Additional comments (12)
src/utils/agent.py (2)
46-49: Agent ID rebind workaround is acceptable.

Directly setting agent._agent_id to preserve conversation continuity matches prior guidance and keeps transcripts coherent.
Please confirm this remains required with current Llama Stack SDK; if rehydration lands, we can remove the create/delete swap.
30-33: Catch the client's NotFoundError instead of suppressing ValueError

File: src/utils/agent.py lines 30–33
```python
with suppress(ValueError):
    agent_response = await client.agents.retrieve(agent_id=conversation_id)
    existing_agent_id = agent_response.agent_id
```

suppress(ValueError) only hides ValueError; client.agents.retrieve will surface client-specific errors (e.g., llama_stack_client.NotFoundError or other API errors) when an agent is missing — catch the client's NotFoundError (or the client's base API error) explicitly (import from llama_stack_client) or broaden the handled exception after confirming the client contract.
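The distinction is easy to demonstrate with contextlib.suppress. In this sketch, NotFoundError is a stand-in class, since the real exception type would come from llama_stack_client:

```python
from contextlib import suppress


class NotFoundError(Exception):
    """Stand-in for the client SDK's not-found exception."""


def retrieve(agent_id: str) -> str:
    # Simulates the client: a missing agent raises NotFoundError, not
    # ValueError, so suppress(ValueError) would let it propagate.
    raise NotFoundError(agent_id)


existing_agent_id = None
with suppress(NotFoundError):
    existing_agent_id = retrieve("missing-agent")
# Execution continues here; existing_agent_id is still None.
```

Suppressing the exact exception the client actually raises keeps the "agent not found, fall through to creation" flow working without masking unrelated errors.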
src/configuration.py (1)
22-23: Expose question_validation in AppConfig (LGTM).

Import and property wiring look correct and consistent with the rest of the accessors.
Also applies to: 135-141
src/constants.py (1)
41-42: Default invalid-query response (LGTM).

Reasonable default; can be overridden via config/profile.
src/models/config.py (2)
235-239: QuestionValidationConfiguration addition (LGTM).

Sane default off; integrates cleanly with top-level Configuration.
512-515: Configuration.question_validation wiring (LGTM).

Field with default_factory ensures presence in serialized config.
docs/openapi.json (1)
1293-1300: LGTM: CustomProfile.query_responses exposure

Schema aligns with code (get_query_responses), enabling configuration-driven invalid responses.
src/app/endpoints/query.py (2)
499-514: Validation short‑circuit in retrieve_response: LGTM.

Early return with a friendly invalid‑query message and transcript flagging is correct.
220-227: Resolved — callers already handle the 3‑tuple return
- src/app/endpoints/query.py:220 unpacks (summary, conversation_id, query_is_valid).
- src/app/endpoints/streaming_query.py calls its local retrieve_response (returns (response, conversation_id)) and is unaffected.
- tests/unit/app/endpoints/test_query.py use 3‑tuple unpacking; streaming tests use 2‑tuple.
tests/unit/utils/test_agent.py (2)
10-36: Config fixture setup: LGTM.

Good minimal AppConfig for isolating agent tests.
385-441: Parser path coverage: LGTM.

Solid check that GraniteToolParser is honored when no_tools=False.
src/app/endpoints/streaming_query.py (1)
530-530: Extract helpers to lower statements in streaming_query_endpoint_handler

Extract pre-validation and the "invalid transcript" SSE path into top-level helpers so the handler stays under pylint's statements limit.
File: src/app/endpoints/streaming_query.py
Add outside the handler:
```python
def build_invalid_query_sse(conversation_id: str, text: str):
    return iter(
        [
            stream_start_event(conversation_id),
            format_stream_data(
                {"event": "turn_complete", "data": {"id": 0, "token": text}}
            ),
            stream_end_event({}),
        ]
    )
```

Inside the handler, replace the invalid-path return with:
```python
return StreamingResponse(
    build_invalid_query_sse(conv_id, invalid_text),
    media_type="text/event-stream",
)
```

Verify locally (corrected command):
python -m pylint src/app/endpoints/streaming_query.py -j 0
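The SSE framing that build_invalid_query_sse relies on is easy to verify in isolation. This standalone sketch assumes, as format_stream_data appears to, that each event is JSON-encoded into a `data:` line terminated by a blank line:

```python
import json


def sse_frame(payload: dict) -> str:
    """Encode one Server-Sent Events frame: a 'data:' line plus a blank line."""
    return f"data: {json.dumps(payload)}\n\n"


frame = sse_frame({"event": "turn_complete", "data": {"id": 0, "token": "Invalid query."}})
```

Clients parsing text/event-stream split on the blank line between frames, which is why the trailing `\n\n` matters as much as the `media_type` header.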
```python
response = cast(Turn, response)
return (
    constants.SUBJECT_REJECTED
    not in interleaved_content_as_str(response.output_message.content),
    conversation_id,
)
```
Harden against missing output_message/content to avoid AttributeError.
Some turns may lack output_message.content; this will crash. Guard before parsing.
- response = cast(Turn, response)
- return (
- constants.SUBJECT_REJECTED
- not in interleaved_content_as_str(response.output_message.content),
- conversation_id,
- )
+ response = cast(Turn, response)
+ output_content = ""
+ if getattr(response, "output_message", None) and getattr(response.output_message, "content", None) is not None:
+ output_content = interleaved_content_as_str(response.output_message.content)
+ is_valid = constants.SUBJECT_REJECTED not in output_content
+ return (is_valid, conversation_id)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
response = cast(Turn, response)
output_content = ""
if getattr(response, "output_message", None) and getattr(response.output_message, "content", None) is not None:
    output_content = interleaved_content_as_str(response.output_message.content)
is_valid = constants.SUBJECT_REJECTED not in output_content
return (is_valid, conversation_id)
```
🤖 Prompt for AI Agents
In src/app/endpoints/query.py around lines 419 to 424, the code assumes
response.output_message.content always exists which can raise AttributeError for
some turns; update the logic to safely handle missing output_message or missing
content by checking that response and response.output_message are not None and
that response.output_message.content exists (or use getattr with a default empty
list/string) before calling interleaved_content_as_str, and treat missing
content as an empty string so the boolean expression becomes False without
throwing; return the same tuple (subject_rejected_flag, conversation_id) using
the guarded value.
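The guarded logic boils down to a substring test. Here is a minimal sketch with an assumed constant value (the real constants live in src/constants.py):

```python
SUBJECT_REJECTED = "REJECTED"  # assumed value, for illustration only


def is_query_valid(validator_reply: str) -> bool:
    # Mirrors the endpoint: any reply containing the rejection token
    # marks the query invalid; an empty reply counts as valid.
    return SUBJECT_REJECTED not in validator_reply
```

Note the substring check is deliberately permissive: a reply of 'ALLOWED' passes, but so would any chatter that merely omits the token, which is why the earlier nitpick asks the validator prompt to demand exactly one token.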
```python
# Check question validation before getting response
query_is_valid = True
if configuration.question_validation.question_validation_enabled:
    query_is_valid, temp_agent_conversation_id = await validate_question(
        query_request.query, client, llama_stack_model_id
    )
    if not query_is_valid:
        response = get_invalid_query_response(configuration)
        if not is_transcripts_enabled():
            logger.debug(
                "Transcript collection is disabled in the configuration"
            )
        else:
            summary = TurnSummary(
                llm_response=response,
                tool_calls=[],
            )
            store_transcript(
                user_id=user_id,
                conversation_id=query_request.conversation_id
                or temp_agent_conversation_id,
                model_id=model_id,
                provider_id=provider_id,
                query_is_valid=query_is_valid,
                query=query_request.query,
                query_request=query_request,
                summary=summary,
                rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
                truncated=False,  # TODO(lucasagomes): implement truncation as part
                # of quota work
                attachments=query_request.attachments or [],
            )
        return StreamingResponse(response)
```
Invalid streaming path returns plain text, breaking SSE clients.
For invalid queries you return StreamingResponse(str), not SSE events. Emit start/turn_complete/end events to keep the contract.
- if configuration.question_validation.question_validation_enabled:
- query_is_valid, temp_agent_conversation_id = await validate_question(
- query_request.query, client, llama_stack_model_id
- )
- if not query_is_valid:
- response = get_invalid_query_response(configuration)
- if not is_transcripts_enabled():
- logger.debug(
- "Transcript collection is disabled in the configuration"
- )
- else:
- summary = TurnSummary(
- llm_response=response,
- tool_calls=[],
- )
- store_transcript(
- user_id=user_id,
- conversation_id=query_request.conversation_id
- or temp_agent_conversation_id,
- model_id=model_id,
- provider_id=provider_id,
- query_is_valid=query_is_valid,
- query=query_request.query,
- query_request=query_request,
- summary=summary,
- rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
- truncated=False, # TODO(lucasagomes): implement truncation as part
- # of quota work
- attachments=query_request.attachments or [],
- )
- return StreamingResponse(response)
+ if configuration.question_validation.question_validation_enabled:
+ query_is_valid, temp_agent_conversation_id = await validate_question(
+ query_request.query, client, llama_stack_model_id
+ )
+ if not query_is_valid:
+ invalid_text = get_invalid_query_response(configuration)
+ if is_transcripts_enabled():
+ summary = TurnSummary(llm_response=invalid_text, tool_calls=[])
+ store_transcript(
+ user_id=user_id,
+ conversation_id=query_request.conversation_id or temp_agent_conversation_id,
+ model_id=model_id,
+ provider_id=provider_id,
+ query_is_valid=query_is_valid,
+ query=query_request.query,
+ query_request=query_request,
+ summary=summary,
+ rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
+ truncated=False, # TODO(lucasagomes): implement truncation as part of quota work
+ attachments=query_request.attachments or [],
+ )
+ # Return minimal SSE stream for invalid query
+ conv_id = query_request.conversation_id or temp_agent_conversation_id
+ return StreamingResponse(
+ iter(
+ [
+ stream_start_event(conv_id),
+ format_stream_data({"event": "turn_complete", "data": {"id": 0, "token": invalid_text}}),
+ stream_end_event({}),
+ ]
+ ),
+ media_type="text/event-stream",
+ )

📝 Committable suggestion
```python
# Check question validation before getting response
query_is_valid = True
if configuration.question_validation.question_validation_enabled:
    query_is_valid, temp_agent_conversation_id = await validate_question(
        query_request.query, client, llama_stack_model_id
    )
    if not query_is_valid:
        invalid_text = get_invalid_query_response(configuration)
        if is_transcripts_enabled():
            summary = TurnSummary(llm_response=invalid_text, tool_calls=[])
            store_transcript(
                user_id=user_id,
                conversation_id=query_request.conversation_id or temp_agent_conversation_id,
                model_id=model_id,
                provider_id=provider_id,
                query_is_valid=query_is_valid,
                query=query_request.query,
                query_request=query_request,
                summary=summary,
                rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
                truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
                attachments=query_request.attachments or [],
            )
        # Return minimal SSE stream for invalid query
        conv_id = query_request.conversation_id or temp_agent_conversation_id
        return StreamingResponse(
            iter(
                [
                    stream_start_event(conv_id),
                    format_stream_data(
                        {"event": "turn_complete", "data": {"id": 0, "token": invalid_text}}
                    ),
                    stream_end_event({}),
                ]
            ),
            media_type="text/event-stream",
        )
```
```python
conversation_id = agent.agent_id
session_id = await agent.create_session(get_suid())

return agent, conversation_id, session_id
```
💡 Verification agent
🧩 Analysis chain
Standardize return tuple order between get_agent and get_temp_agent.
get_agent returns (agent, conversation_id, session_id) while get_temp_agent returns (agent, session_id, conversation_id). This will break unpacking at call sites. Standardize on (agent, session_id, conversation_id) for both.
Apply this diff:

```diff
-    return agent, conversation_id, session_id
+    return agent, session_id, conversation_id
```

Run to find unpacking patterns and usages:
Also applies to: 101-101
🏁 Script executed:

```bash
#!/bin/bash
# Expect unpacking or indexing to match (agent, session_id, conversation_id)
rg -n -C2 -P '\bget_(temp_)?agent\s*\('
rg -n -C2 -P '\(\s*\w+\s*,\s*\w+\s*,\s*\w+\s*\)\s*=\s*await\s+get_(temp_)?agent'
```

Length of output: 5262
Standardize return tuple order — make get_temp_agent match get_agent (agent, conversation_id, session_id)
get_agent returns (agent, conversation_id, session_id) and is unpacked that way across the codebase/tests; get_temp_agent returns (agent, session_id, conversation_id) causing inconsistent unpacking. Do NOT change get_agent without updating callers/tests. Action: change get_temp_agent to return (agent, conversation_id, session_id) and update its callers/tests accordingly.
Examples of affected locations: src/app/endpoints/query.py (lines 410, 487), src/app/endpoints/streaming_query.py (line 783), tests/unit/utils/test_agent.py (lines 56, 100, 153, 205, 259, 358, 416, 456, 494, 526).
🤖 Prompt for AI Agents
In src/utils/agent.py around line 67, get_temp_agent currently returns (agent,
session_id, conversation_id) which is inconsistent with get_agent; change the
return tuple to (agent, conversation_id, session_id) to match get_agent, then
update all callers and tests that unpack the old order—specifically adjust
usages in src/app/endpoints/query.py (lines ~410, ~487),
src/app/endpoints/streaming_query.py (~783) and the listed tests in
tests/unit/utils/test_agent.py (lines ~56, 100, 153, 205, 259, 358, 416, 456,
494, 526) to expect and unpack (agent, conversation_id, session_id).
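Once both helpers agree on the same order, every call site unpacks identically. A toy sketch with hypothetical stand-ins (not the real `get_temp_agent` signature) illustrating the standardized `(agent, conversation_id, session_id)` order:

```python
import asyncio


class FakeAgent:
    """Hypothetical stand-in for the llama-stack agent wrapper."""

    agent_id = "conv-123"

    async def create_session(self, name: str) -> str:
        return f"sess-{name}"


async def get_temp_agent():
    # Standardized order: (agent, conversation_id, session_id),
    # matching get_agent so call sites unpack the same way.
    agent = FakeAgent()
    conversation_id = agent.agent_id
    session_id = await agent.create_session("validation")
    return agent, conversation_id, session_id


async def main():
    agent, conversation_id, session_id = await get_temp_agent()
    return conversation_id, session_id


conv_id, sess_id = asyncio.run(main())
```

With a single agreed order, a mismatched unpacking becomes a visible bug at the call site instead of silently swapped IDs.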
```python
@pytest.mark.asyncio
async def test_streaming_query_endpoint_with_question_validation_invalid_query(
    setup_configuration, mocker
):
    """Test streaming query endpoint with question validation enabled and invalid query."""
    # Mock metrics
    mock_metrics(mocker)

    # Mock database operations
    mock_database_operations(mocker)

    # Setup configuration with question validation enabled
    setup_configuration.question_validation.question_validation_enabled = True
    mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration)

    # Mock the client
    mock_client = mocker.AsyncMock()
    mock_client.models.list.return_value = [
        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1")
    ]
    mocker.patch(
        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
    )

    # Mock the validation agent response (invalid query)
    mock_validation_agent = mocker.AsyncMock()
    mock_validation_agent.agent_id = "validation_agent_id"
    mock_validation_agent.create_session.return_value = "validation_session_id"

    # Mock the validation response that contains SUBJECT_REJECTED
    mock_validation_turn = mocker.Mock()
    mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}]
    mock_validation_agent.create_turn.return_value = mock_validation_turn

    # Mock the validation functions
    mocker.patch(
        "app.endpoints.streaming_query.validate_question",
        return_value=(False, "validation_agent_id"),
    )
    mocker.patch(
        "app.endpoints.streaming_query.get_invalid_query_response",
        return_value="Invalid query response",
    )
    mocker.patch(
        "app.endpoints.streaming_query.interleaved_content_as_str",
        return_value="REJECTED",
    )

    # Mock other dependencies
    mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
    mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")
    mocker.patch(
        "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
    )
    mocker.patch(
        "app.endpoints.streaming_query.select_model_and_provider_id",
        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
    )

    query_request = QueryRequest(query="Invalid question about unrelated topic")

    request = Request(
        scope={
            "type": "http",
        }
    )
    request.state.authorized_actions = set(Action)

    response = await streaming_query_endpoint_handler(
        request=request, query_request=query_request, auth=MOCK_AUTH
    )

    # Verify the response is a StreamingResponse
    assert isinstance(response, StreamingResponse)

    # Collect the streaming response content
    streaming_content = []
    async for chunk in response.body_iterator:
        streaming_content.append(chunk)

    # Convert to string for assertions
    full_content = "".join(streaming_content)

    # Verify the response contains the invalid query response
    assert "Invalid query response" in full_content
```
💡 Verification agent
🧩 Analysis chain
Streaming invalid path should emit proper SSE with event-stream content-type

The streaming endpoint currently returns a StreamingResponse wrapping a raw string for invalid queries, which is not SSE and likely sets the wrong media type. Strengthen this test and fix the handler to always emit SSE.

Apply to the test to catch regressions:

```diff
@@
         response = await streaming_query_endpoint_handler(
             request=request, query_request=query_request, auth=MOCK_AUTH
         )
         # Verify the response is a StreamingResponse
         assert isinstance(response, StreamingResponse)
+        # Must be SSE
+        assert response.media_type == "text/event-stream"
@@
         # Verify the response contains the invalid query response
         assert "Invalid query response" in full_content
+        # And is wrapped as SSE tokens
+        assert "data: " in full_content
```

Proposed handler fix (in app/endpoints/streaming_query.py) to ensure SSE for invalid path:
```diff
@@ async def streaming_query_endpoint_handler(...):
     if configuration.question_validation.question_validation_enabled:
         query_is_valid, temp_agent_conversation_id = await validate_question(
             query_request.query, client, llama_stack_model_id
         )
         if not query_is_valid:
-            response = get_invalid_query_response(configuration)
-            if not is_transcripts_enabled():
-                logger.debug(
-                    "Transcript collection is disabled in the configuration"
-                )
-            else:
-                summary = TurnSummary(
-                    llm_response=response,
-                    tool_calls=[],
-                )
-                store_transcript(
-                    user_id=user_id,
-                    conversation_id=query_request.conversation_id
-                    or temp_agent_conversation_id,
-                    model_id=model_id,
-                    provider_id=provider_id,
-                    query_is_valid=query_is_valid,
-                    query=query_request.query,
-                    query_request=query_request,
-                    summary=summary,
-                    rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
-                    truncated=False,  # TODO(lucasagomes): implement truncation as part
-                    # of quota work
-                    attachments=query_request.attachments or [],
-                )
-            return StreamingResponse(response)
+            invalid_msg = get_invalid_query_response(configuration)
+
+            async def invalid_sse() -> AsyncIterator[str]:
+                yield stream_start_event(
+                    query_request.conversation_id or temp_agent_conversation_id or ""
+                )
+                # send one token event with the invalid message
+                yield f'data: {json.dumps({"event": "token", "token": invalid_msg, "role": "system"})}\n\n'
+                yield stream_end_event({})
+
+            if is_transcripts_enabled():
+                store_transcript(
+                    user_id=user_id,
+                    conversation_id=query_request.conversation_id or temp_agent_conversation_id,
+                    model_id=model_id,
+                    provider_id=provider_id,
+                    query_is_valid=False,
+                    query=query_request.query,
+                    query_request=query_request,
+                    summary=TurnSummary(llm_response=invalid_msg, tool_calls=[]),
+                    rag_chunks=[],
+                    truncated=False,
+                    attachments=query_request.attachments or [],
+                )
+            return StreamingResponse(invalid_sse(), media_type="text/event-stream")
```

🏁 Script executed:
```bash
#!/bin/bash
set -euo pipefail
echo "== Search for handler definition and related symbols =="
rg -n --hidden --no-ignore-vcs "async def streaming_query_endpoint_handler" -S || true
rg -n --hidden --no-ignore-vcs "def streaming_query_endpoint_handler" -S || true
rg -n --hidden --no-ignore-vcs "streaming_query_endpoint_handler" -S || true
echo
echo "== Locate app/endpoints/streaming_query.py =="
FILE=$(rg --files | rg "app/endpoints/streaming_query.py" || true)
if [ -n "$FILE" ]; then
  echo "Found: $FILE"
  echo "---- head (1-300) ----"
  sed -n '1,300p' "$FILE"
  echo "---- tail (301-800) ----"
  sed -n '301,800p' "$FILE" || true
else
  echo "app/endpoints/streaming_query.py not found by exact path; searching for files mentioning 'streaming_query'"
  rg -n --hidden --no-ignore-vcs "streaming_query" -S -g '!**/node_modules/**' || true
fi
echo
echo "== Search for helpers and patterns referenced in review comment =="
rg -n --hidden --no-ignore-vcs "get_invalid_query_response" -S -C3 || true
rg -n --hidden --no-ignore-vcs "validate_question" -S -C3 || true
rg -n --hidden --no-ignore-vcs "is_transcripts_enabled" -S -C3 || true
rg -n --hidden --no-ignore-vcs "store_transcript" -S -C3 || true
rg -n --hidden --no-ignore-vcs "stream_start_event|stream_end_event" -S -C3 || true
rg -n --hidden --no-ignore-vcs "StreamingResponse" -S -C3 || true
rg -n --hidden --no-ignore-vcs "text/event-stream" -S -C3 || true
echo
echo "== Show any StreamingResponse returns inside streaming_query module(s) =="
rg -n --hidden --no-ignore-vcs "return\s+StreamingResponse" -S -C5 || true
echo
echo "== Done =="
```

Length of output: 78021
Streaming invalid path should emit proper SSE with event-stream content-type
Handler currently returns a raw string via StreamingResponse for invalid-question validation; emit SSE and set media_type="text/event-stream". Strengthen test to catch regressions. Also set media_type for the normal streaming return.
+ return StreamingResponse(invalid_sse(), media_type="text/event-stream")Also change the normal successful stream return to explicitly set media_type, e.g.:
- return StreamingResponse(response_generator(response, query_is_valid))
+ return StreamingResponse(response_generator(response, query_is_valid), media_type="text/event-stream")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
@pytest.mark.asyncio
async def test_streaming_query_endpoint_with_question_validation_invalid_query(
    setup_configuration, mocker
):
    """Test streaming query endpoint with question validation enabled and invalid query."""
    # Mock metrics
    mock_metrics(mocker)

    # Mock database operations
    mock_database_operations(mocker)

    # Setup configuration with question validation enabled
    setup_configuration.question_validation.question_validation_enabled = True
    mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration)

    # Mock the client
    mock_client = mocker.AsyncMock()
    mock_client.models.list.return_value = [
        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1")
    ]
    mocker.patch(
        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
    )

    # Mock the validation agent response (invalid query)
    mock_validation_agent = mocker.AsyncMock()
    mock_validation_agent.agent_id = "validation_agent_id"
    mock_validation_agent.create_session.return_value = "validation_session_id"

    # Mock the validation response that contains SUBJECT_REJECTED
    mock_validation_turn = mocker.Mock()
    mock_validation_turn.output_message.content = [{"type": "text", "text": "REJECTED"}]
    mock_validation_agent.create_turn.return_value = mock_validation_turn

    # Mock the validation functions
    mocker.patch(
        "app.endpoints.streaming_query.validate_question",
        return_value=(False, "validation_agent_id"),
    )
    mocker.patch(
        "app.endpoints.streaming_query.get_invalid_query_response",
        return_value="Invalid query response",
    )
    mocker.patch(
        "app.endpoints.streaming_query.interleaved_content_as_str",
        return_value="REJECTED",
    )

    # Mock other dependencies
    mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
    mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")
    mocker.patch(
        "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
    )
    mocker.patch(
        "app.endpoints.streaming_query.select_model_and_provider_id",
        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
    )

    query_request = QueryRequest(query="Invalid question about unrelated topic")

    request = Request(
        scope={
            "type": "http",
        }
    )
    request.state.authorized_actions = set(Action)

    response = await streaming_query_endpoint_handler(
        request=request, query_request=query_request, auth=MOCK_AUTH
    )

    # Verify the response is a StreamingResponse
    assert isinstance(response, StreamingResponse)
    # Must be SSE
    assert response.media_type == "text/event-stream"

    # Collect the streaming response content
    streaming_content = []
    async for chunk in response.body_iterator:
        streaming_content.append(chunk)

    # Convert to string for assertions
    full_content = "".join(streaming_content)

    # Verify the response contains the invalid query response
    assert "Invalid query response" in full_content
    # And is wrapped as SSE tokens
    assert "data: " in full_content
```
🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_streaming_query.py around lines 1592-1678, the
test verifies invalid-question validation returns a StreamingResponse but
doesn't assert SSE content-type; update both the handler and test: change the
streaming_query handler to return StreamingResponse(...,
media_type="text/event-stream") for the invalid-question branch (and ensure the
normal successful-stream branch also sets media_type="text/event-stream"), then
update this test to assert response.media_type == "text/event-stream" (and keep
the existing body iteration assertions) so regressions will be caught.
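For reference, the wire format the review asks for is plain Server-Sent Events: each frame is a `data: <json>` line terminated by a blank line. A minimal sketch of what an SSE formatter like the `format_stream_data` helper used above is assumed to produce (assumed shape, not necessarily the project's actual implementation):

```python
import json


def format_stream_data(payload: dict) -> str:
    # One SSE frame: "data: <json>" followed by a blank line
    # that delimits the event for the client parser.
    return f"data: {json.dumps(payload)}\n\n"


frame = format_stream_data(
    {"event": "turn_complete", "data": {"id": 0, "token": "Invalid query response"}}
)
```

A client reading the stream splits on blank lines and strips the `data: ` prefix before JSON-decoding each event, which is why the raw-string return breaks consumers.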
tisnik left a comment
Just a question: wouldn't it be possible to use an LS provider for question validation? I dunno, having this in the core might be fine; OTOH, LS prefers this functionality to be deferred to providers.
@tisnik thanks for this information. I have tried the safety shield and it works for us; the only concern is that the model_id needs to be preset and hard-coded.
Description

This PR implements the question validation feature, as part of feature parity with road-core.
It is disabled by default, but can be enabled by setting the following in the config:
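The exact YAML snippet isn't reproduced here; judging from the Python config path `configuration.question_validation.question_validation_enabled` used in the code, it would presumably look something like:

```yaml
# Hypothetical sketch inferred from the config field names, not the documented schema
question_validation:
  question_validation_enabled: true
```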
A temp agent is used to validate the question, using the validation prompt and the query's model/provider pair.
The question validation result is saved in the transcript.
Verified both query & streaming_query:

Type of change
Related Tickets & Documents
Checklist before requesting a review
Testing
Summary by CodeRabbit
New Features
Documentation