24 changes: 24 additions & 0 deletions docs/openapi.json
@@ -1052,6 +1052,9 @@
},
"inference": {
"$ref": "#/components/schemas/InferenceConfiguration"
},
"question_validation": {
"$ref": "#/components/schemas/QuestionValidationConfiguration"
}
},
"additionalProperties": false,
@@ -1286,6 +1289,14 @@
"type": "object",
"title": "Prompts",
"default": {}
},
"query_responses": {
"additionalProperties": {
"type": "string"
},
"type": "object",
"title": "Query Responses",
"default": {}
}
},
"type": "object",
@@ -2279,6 +2290,19 @@
}
]
},
"QuestionValidationConfiguration": {
"properties": {
"question_validation_enabled": {
"type": "boolean",
"title": "Question Validation Enabled",
"default": false
}
},
"additionalProperties": false,
"type": "object",
"title": "QuestionValidationConfiguration",
"description": "Question validation configuration."
},
"ReadinessResponse": {
"properties": {
"ready": {
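The schema addition is backward compatible: `question_validation` is optional and defaults to disabled. A minimal sketch of enabling it in the service configuration (assuming the usual YAML config file; only the key names are taken from this diff):

```yaml
# Hypothetical excerpt of a lightspeed-stack service configuration.
# Key names mirror QuestionValidationConfiguration; the surrounding
# file layout is an assumption.
question_validation:
  question_validation_enabled: true
```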
64 changes: 57 additions & 7 deletions src/app/endpoints/query.py
@@ -5,8 +5,8 @@
import logging
from typing import Annotated, Any, cast

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore

from llama_stack_client import APIConnectionError, AsyncLlamaStackClient # type: ignore
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
from llama_stack_client.types import UserMessage, Shield # type: ignore
from llama_stack_client.types.agents.turn import Turn
@@ -31,9 +31,11 @@
from models.database.conversations import UserConversation
from models.requests import QueryRequest, Attachment
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from utils.agent import get_agent, get_temp_agent
from utils.endpoints import (
check_configuration_loaded,
get_agent,
get_invalid_query_response,
get_validation_system_prompt,
get_system_prompt,
validate_conversation_ownership,
validate_model_provider_override,
@@ -215,7 +217,7 @@ async def query_endpoint_handler(
user_conversation=user_conversation, query_request=query_request
),
)
summary, conversation_id = await retrieve_response(
summary, conversation_id, query_is_valid = await retrieve_response(
client,
llama_stack_model_id,
query_request,
@@ -234,7 +236,7 @@
conversation_id=conversation_id,
model_id=model_id,
provider_id=provider_id,
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query_is_valid=query_is_valid,
query=query_request.query,
query_request=query_request,
summary=summary,
@@ -391,6 +393,37 @@ def is_input_shield(shield: Shield) -> bool:
return _is_inout_shield(shield) or not is_output_shield(shield)


async def validate_question(
question: str, client: AsyncLlamaStackClient, model_id: str
) -> tuple[bool, str]:
"""Validate a question and provides a one-word response.

Args:
question: The question to be validated.
client: The AsyncLlamaStackClient to use for the request.
model_id: The ID of the model to use.

Returns:
tuple[bool, str]: True if the question was deemed valid (False otherwise), and the conversation ID of the validation session.
"""
validation_system_prompt = get_validation_system_prompt(configuration)
agent, session_id, conversation_id = await get_temp_agent(
client, model_id, validation_system_prompt
)
response = await agent.create_turn(
messages=[UserMessage(role="user", content=question)],
session_id=session_id,
stream=False,
toolgroups=None,
)
response = cast(Turn, response)
return (
constants.SUBJECT_REJECTED
not in interleaved_content_as_str(response.output_message.content),
conversation_id,
)
Comment on lines +419 to +424
⚠️ Potential issue

Harden against missing output_message/content to avoid AttributeError.

Some turns may lack output_message.content; this will crash. Guard before parsing.

-    response = cast(Turn, response)
-    return (
-        constants.SUBJECT_REJECTED
-        not in interleaved_content_as_str(response.output_message.content),
-        conversation_id,
-    )
+    response = cast(Turn, response)
+    output_content = ""
+    if getattr(response, "output_message", None) and getattr(response.output_message, "content", None) is not None:
+        output_content = interleaved_content_as_str(response.output_message.content)
+    is_valid = constants.SUBJECT_REJECTED not in output_content
+    return (is_valid, conversation_id)
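A standalone sketch of the guarded semantics (hypothetical helper and stub objects, not from this PR): a turn with no `output_message` degrades to an empty string, which contains no rejection marker and therefore counts as valid.

```python
# Minimal, self-contained illustration of the guard suggested above.
# interleaved_content_as_str is stubbed with str here; the real helper
# comes from llama_stack_client.lib.agents.event_logger.
from types import SimpleNamespace

SUBJECT_REJECTED = "REJECTED"  # mirrors src/constants.py
interleaved_content_as_str = str  # stand-in for the llama-stack helper


def extract_output_content(response) -> str:
    """Return the turn's output text, or '' when message/content is missing."""
    message = getattr(response, "output_message", None)
    if message is not None and getattr(message, "content", None) is not None:
        return interleaved_content_as_str(message.content)
    return ""


turn_with_text = SimpleNamespace(output_message=SimpleNamespace(content="REJECTED"))
turn_without_message = SimpleNamespace(output_message=None)

assert SUBJECT_REJECTED in extract_output_content(turn_with_text)
assert extract_output_content(turn_without_message) == ""  # no AttributeError
```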



async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches,too-many-arguments
client: AsyncLlamaStackClient,
model_id: str,
Expand All @@ -399,7 +432,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
mcp_headers: dict[str, dict[str, str]] | None = None,
*,
provider_id: str = "",
) -> tuple[TurnSummary, str]:
) -> tuple[TurnSummary, str, bool]:
"""
Retrieve response from LLMs and agents.

@@ -462,6 +495,23 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
)

logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)

# Validate the question if question validation is enabled
if configuration.question_validation.question_validation_enabled:
question_is_valid, _ = await validate_question(
query_request.query, client, model_id
)

if not question_is_valid:
return (
TurnSummary(
llm_response=get_invalid_query_response(configuration),
tool_calls=[],
),
conversation_id,
False,
)

# bypass tools and MCP servers if no_tools is True
if query_request.no_tools:
mcp_headers = {}
@@ -535,7 +585,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
"Response lacks output_message.content (conversation_id=%s)",
conversation_id,
)
return summary, conversation_id
return summary, conversation_id, True


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
48 changes: 45 additions & 3 deletions src/app/endpoints/streaming_query.py
@@ -30,14 +30,20 @@
from models.config import Action
from models.requests import QueryRequest
from models.database.conversations import UserConversation
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
from utils.agent import get_agent
from utils.endpoints import (
check_configuration_loaded,
get_system_prompt,
get_invalid_query_response,
)
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.transcripts import store_transcript
from utils.types import TurnSummary
from utils.endpoints import validate_model_provider_override

from app.endpoints.query import (
get_rag_toolgroups,
validate_question,
is_input_shield,
is_output_shield,
is_transcripts_enabled,
@@ -587,6 +593,40 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
user_conversation=user_conversation, query_request=query_request
),
)

# Check question validation before getting response
query_is_valid = True
if configuration.question_validation.question_validation_enabled:
query_is_valid, temp_agent_conversation_id = await validate_question(
query_request.query, client, llama_stack_model_id
)
if not query_is_valid:
response = get_invalid_query_response(configuration)
if not is_transcripts_enabled():
logger.debug(
"Transcript collection is disabled in the configuration"
)
else:
summary = TurnSummary(
llm_response=response,
tool_calls=[],
)
store_transcript(
user_id=user_id,
conversation_id=query_request.conversation_id
or temp_agent_conversation_id,
model_id=model_id,
provider_id=provider_id,
query_is_valid=query_is_valid,
query=query_request.query,
query_request=query_request,
summary=summary,
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
truncated=False, # TODO(lucasagomes): implement truncation as part
# of quota work
attachments=query_request.attachments or [],
)
return StreamingResponse(response)
Comment on lines +597 to +629
⚠️ Potential issue

Invalid streaming path returns plain text, breaking SSE clients.

For invalid queries you return StreamingResponse(str), not SSE events. Emit start/turn_complete/end events to keep the contract.

-        if configuration.question_validation.question_validation_enabled:
-            query_is_valid, temp_agent_conversation_id = await validate_question(
-                query_request.query, client, llama_stack_model_id
-            )
-            if not query_is_valid:
-                response = get_invalid_query_response(configuration)
-                if not is_transcripts_enabled():
-                    logger.debug(
-                        "Transcript collection is disabled in the configuration"
-                    )
-                else:
-                    summary = TurnSummary(
-                        llm_response=response,
-                        tool_calls=[],
-                    )
-                    store_transcript(
-                        user_id=user_id,
-                        conversation_id=query_request.conversation_id
-                        or temp_agent_conversation_id,
-                        model_id=model_id,
-                        provider_id=provider_id,
-                        query_is_valid=query_is_valid,
-                        query=query_request.query,
-                        query_request=query_request,
-                        summary=summary,
-                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
-                        truncated=False,  # TODO(lucasagomes): implement truncation as part
-                        # of quota work
-                        attachments=query_request.attachments or [],
-                    )
-                return StreamingResponse(response)
+        if configuration.question_validation.question_validation_enabled:
+            query_is_valid, temp_agent_conversation_id = await validate_question(
+                query_request.query, client, llama_stack_model_id
+            )
+            if not query_is_valid:
+                invalid_text = get_invalid_query_response(configuration)
+                if is_transcripts_enabled():
+                    summary = TurnSummary(llm_response=invalid_text, tool_calls=[])
+                    store_transcript(
+                        user_id=user_id,
+                        conversation_id=query_request.conversation_id or temp_agent_conversation_id,
+                        model_id=model_id,
+                        provider_id=provider_id,
+                        query_is_valid=query_is_valid,
+                        query=query_request.query,
+                        query_request=query_request,
+                        summary=summary,
+                        rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                        truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
+                        attachments=query_request.attachments or [],
+                    )
+                # Return minimal SSE stream for invalid query
+                conv_id = query_request.conversation_id or temp_agent_conversation_id
+                return StreamingResponse(
+                    iter(
+                        [
+                            stream_start_event(conv_id),
+                            format_stream_data({"event": "turn_complete", "data": {"id": 0, "token": invalid_text}}),
+                            stream_end_event({}),
+                        ]
+                    ),
+                    media_type="text/event-stream",
+                )
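For context, the fallback stream above carries exactly three frames: start, one turn_complete token, and end. A rough sketch of the resulting wire format (the payload shapes of `stream_start_event` and `stream_end_event` are not visible in this diff, so the JSON bodies below are assumptions):

```python
# Hypothetical rendering of the three-frame SSE stream for an invalid query.
import json


def format_stream_data(payload: dict) -> str:
    # Assumed behavior: one JSON payload per standard SSE "data:" frame.
    return f"data: {json.dumps(payload)}\n\n"


frames = [
    format_stream_data({"event": "start", "data": {"conversation_id": "1234"}}),
    format_stream_data(
        {"event": "turn_complete",
         "data": {"id": 0, "token": "Invalid query, please try again."}}
    ),
    format_stream_data({"event": "end", "data": {}}),
]
print("".join(frames))
```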

response, conversation_id = await retrieve_response(
client,
llama_stack_model_id,
@@ -598,6 +638,7 @@

async def response_generator(
turn_response: AsyncIterator[AgentTurnResponseStreamChunk],
query_is_valid: bool,
) -> AsyncIterator[str]:
"""
Generate SSE formatted streaming response.
@@ -649,7 +690,7 @@ async def response_generator(
conversation_id=conversation_id,
model_id=model_id,
provider_id=provider_id,
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query_is_valid=query_is_valid,
query=query_request.query,
query_request=query_request,
summary=summary,
@@ -669,7 +710,7 @@
# Update metrics for the LLM call
metrics.llm_calls_total.labels(provider_id, model_id).inc()

return StreamingResponse(response_generator(response))
return StreamingResponse(response_generator(response, query_is_valid))
# connection to Llama Stack server
except APIConnectionError as e:
# Update metrics for the LLM call failure
@@ -750,6 +791,7 @@ async def retrieve_response(
)

logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)

# bypass tools and MCP servers if no_tools is True
if query_request.no_tools:
mcp_headers = {}
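A hedged client-side sketch for consuming the stream (the endpoint path, port, and request body are assumptions, not taken from this PR); each SSE frame arrives as one `data:` line carrying a JSON event:

```python
# Hypothetical consumer of the streaming endpoint.
import httpx

with httpx.stream(
    "POST",
    "http://localhost:8080/v1/streaming_query",
    json={"query": "What is Kubernetes?"},
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):])  # one JSON event per frame
```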
8 changes: 8 additions & 0 deletions src/configuration.py
@@ -19,6 +19,7 @@
AuthenticationConfiguration,
InferenceConfiguration,
DatabaseConfiguration,
QuestionValidationConfiguration,
)


@@ -131,5 +132,12 @@ def database_configuration(self) -> DatabaseConfiguration:
raise LogicError("logic error: configuration is not loaded")
return self._configuration.database

@property
def question_validation(self) -> QuestionValidationConfiguration:
"""Return question validation configuration."""
if self._configuration is None:
raise LogicError("logic error: configuration is not loaded")
return self._configuration.question_validation


configuration: AppConfig = AppConfig()
13 changes: 13 additions & 0 deletions src/constants.py
@@ -28,6 +28,19 @@
# configuration file nor in the query request
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant"


# Query validation
SUBJECT_REJECTED = "REJECTED"
SUBJECT_ALLOWED = "ALLOWED"
DEFAULT_VALIDATION_SYSTEM_PROMPT = (
"You are a helpful assistant that validates questions. You will be given a "
"question and you will need to validate if it is valid or not. You will "
f"return '{SUBJECT_REJECTED}' if the question is not valid and "
f"'{SUBJECT_ALLOWED}' if it is valid."
)
DEFAULT_INVALID_QUERY_RESPONSE = "Invalid query, please try again."


# Authentication constants
DEFAULT_VIRTUAL_PATH = "/ls-access"
DEFAULT_USER_NAME = "lightspeed-user"
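The new constants define a containment-based protocol: the validator model is prompted to reply with a single word, and the caller only checks whether the rejection marker appears anywhere in the reply. A small standalone illustration:

```python
# Mirrors the check in validate_question (src/app/endpoints/query.py).
SUBJECT_REJECTED = "REJECTED"


def question_is_valid(validator_reply: str) -> bool:
    """A question passes unless the reply contains the rejection marker."""
    return SUBJECT_REJECTED not in validator_reply


assert question_is_valid("ALLOWED")
assert not question_is_valid("REJECTED")
# Note: substring matching means a reply like "NOT REJECTED" also fails;
# exact matching would be stricter if that ever matters.
```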
17 changes: 17 additions & 0 deletions src/models/config.py
@@ -232,6 +232,12 @@ def check_storage_location_is_set_when_needed(self) -> Self:
return self


class QuestionValidationConfiguration(ConfigurationBase):
"""Question validation configuration."""

question_validation_enabled: bool = False


class JsonPathOperator(str, Enum):
"""Supported operators for JSONPath evaluation."""

@@ -421,6 +427,7 @@ class CustomProfile:

path: str
prompts: dict[str, str] = Field(default={}, init=False)
query_responses: dict[str, str] = Field(default={}, init=False)

def __post_init__(self) -> None:
"""Validate and load profile."""
@@ -432,11 +439,18 @@ def _validate_and_process(self) -> None:
profile_module = checks.import_python_module("profile", self.path)
if profile_module is not None and checks.is_valid_profile(profile_module):
self.prompts = profile_module.PROFILE_CONFIG.get("system_prompts", {})
self.query_responses = profile_module.PROFILE_CONFIG.get(
"query_responses", {}
)

def get_prompts(self) -> dict[str, str]:
"""Retrieve prompt attribute."""
return self.prompts

def get_query_responses(self) -> dict[str, str]:
"""Retrieve query responses attribute."""
return self.query_responses


class Customization(ConfigurationBase):
"""Service customization."""
@@ -495,6 +509,9 @@ class Configuration(ConfigurationBase):
authorization: Optional[AuthorizationConfiguration] = None
customization: Optional[Customization] = None
inference: InferenceConfiguration = Field(default_factory=InferenceConfiguration)
question_validation: QuestionValidationConfiguration = Field(
default_factory=QuestionValidationConfiguration
)

def dump(self, filename: str = "configuration.json") -> None:
"""Dump actual configuration into JSON file."""
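`CustomProfile` now loads `query_responses` the same way it loads prompts: from the module-level `PROFILE_CONFIG` dict in the profile file. A rough sketch of such a module (only the top-level `system_prompts` and `query_responses` keys come from this diff; the nested keys are hypothetical):

```python
# profile.py -- hypothetical custom profile module referenced by
# CustomProfile.path. Top-level keys match what _validate_and_process()
# reads; nested keys are illustrative assumptions.
PROFILE_CONFIG: dict[str, dict[str, str]] = {
    "system_prompts": {
        "default": "You are a helpful assistant",
    },
    "query_responses": {
        "invalid_query": "Invalid query, please try again.",
    },
}
```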