From 2d1567aea3d10e3a087813f4fd8aec513a06851a Mon Sep 17 00:00:00 2001 From: are-ces <195810094+are-ces@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:46:20 +0200 Subject: [PATCH] LCORE-347: Implement referenced documents support on /query --- src/app/endpoints/query.py | 127 +++++++++++++--- src/models/responses.py | 39 ++++- tests/unit/app/endpoints/test_query.py | 191 ++++++++++++++++++++----- 3 files changed, 301 insertions(+), 56 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 41143621e..48594d215 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -1,36 +1,46 @@ """Handler for REST API call to provide answer to query.""" -from datetime import datetime, UTC +import ast import json import logging -from typing import Annotated, Any, cast - -from llama_stack_client import APIConnectionError -from llama_stack_client import AsyncLlamaStackClient # type: ignore +import re +from datetime import UTC, datetime +from typing import Annotated, Any, Optional, cast + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from llama_stack_client import ( + APIConnectionError, + AsyncLlamaStackClient, # type: ignore +) from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str -from llama_stack_client.types import UserMessage, Shield # type: ignore +from llama_stack_client.types import Shield, UserMessage # type: ignore from llama_stack_client.types.agents.turn import Turn from llama_stack_client.types.agents.turn_create_params import ( - ToolgroupAgentToolGroupWithArgs, Toolgroup, + ToolgroupAgentToolGroupWithArgs, ) from llama_stack_client.types.model_list_response import ModelListResponse +from llama_stack_client.types.shared.interleaved_content_item import TextContentItem +from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from fastapi import APIRouter, HTTPException, Request, status, Depends - +import constants +import metrics +from 
app.database import get_session from authentication import get_auth_dependency from authentication.interface import AuthTuple +from authorization.middleware import authorize from client import AsyncLlamaStackClientHolder from configuration import configuration -from app.database import get_session -import metrics from metrics.utils import update_llm_token_count_from_turn -import constants -from authorization.middleware import authorize from models.config import Action from models.database.conversations import UserConversation -from models.requests import QueryRequest, Attachment -from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse +from models.requests import Attachment, QueryRequest +from models.responses import ( + ForbiddenResponse, + QueryResponse, + ReferencedDocument, + UnauthorizedResponse, +) from utils.endpoints import ( check_configuration_loaded, get_agent, @@ -38,7 +48,7 @@ validate_conversation_ownership, validate_model_provider_override, ) -from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups +from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency from utils.transcripts import store_transcript from utils.types import TurnSummary @@ -50,6 +60,13 @@ 200: { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", "response": "LLM answer", + "referenced_documents": [ + { + "doc_url": "https://docs.openshift.com/" + "container-platform/4.15/operators/olm/index.html", + "doc_title": "Operator Lifecycle Manager (OLM)", + } + ], }, 400: { "description": "Missing or invalid credentials provided by client", @@ -220,7 +237,7 @@ async def query_endpoint_handler( user_conversation=user_conversation, query_request=query_request ), ) - summary, conversation_id = await retrieve_response( + summary, conversation_id, referenced_documents = await retrieve_response( client, llama_stack_model_id, query_request, @@ -258,6 +275,7 @@ async def query_endpoint_handler( 
return QueryResponse( conversation_id=conversation_id, response=summary.llm_response, + referenced_documents=referenced_documents, ) # connection to Llama Stack server @@ -396,6 +414,70 @@ def is_input_shield(shield: Shield) -> bool: return _is_inout_shield(shield) or not is_output_shield(shield)
+def parse_metadata_from_text_item(
+    text_item: TextContentItem,
+) -> Optional[ReferencedDocument]:
+    """
+    Parse a single TextContentItem and return its first referenced document.
+
+    Args:
+        text_item (TextContentItem): The TextContentItem containing metadata.
+
+    Returns:
+        Optional[ReferencedDocument]: Built from the first "Metadata: {...}"
+        block with both 'docs_url' and 'title'; None if no block qualifies.
+    """
+    if not isinstance(text_item, TextContentItem):
+        return None
+
+    pattern = r"Metadata:\s*({.*?})(?:\n|$)"
+    for block in re.findall(pattern, text_item.text, re.DOTALL):
+        try:
+            data = ast.literal_eval(block)
+            # A non-dict literal (e.g. a set) has no .get(); treat it as malformed.
+            if not isinstance(data, dict):
+                raise ValueError("metadata block is not a dict")
+            url = data.get("docs_url")
+            title = data.get("title")
+            if url and title:
+                return ReferencedDocument(doc_url=url, doc_title=title)
+            logger.debug("Invalid metadata block (missing url or title): %s", block)
+        except (ValueError, SyntaxError, TypeError) as e:  # TypeError: unhashable key
+            logger.debug("Failed to parse metadata block: %s | Error: %s", block, e)
+    return None
+
+
+def parse_referenced_documents(response: Turn) -> list[ReferencedDocument]:
+    """
+    Parse referenced documents from Turn.
+
+    Iterate through the steps of a response and collect all referenced
+    documents from rag tool responses.
+
+    Args:
+        response(Turn): The response object from the agent turn.
+
+    Returns:
+        list[ReferencedDocument]: A list of ReferencedDocument, each with 'doc_url' and 'doc_title'
+        representing all referenced documents found in the response.
+    """
+    docs: list[ReferencedDocument] = []
+    for step in response.steps:
+        if not isinstance(step, ToolExecutionStep):
+            continue
+        for tool_response in step.tool_responses:
+            # TODO(are-ces): use constant instead
+            if tool_response.tool_name != "knowledge_search":
+                continue
+            for text_item in tool_response.content:
+                if not isinstance(text_item, TextContentItem):
+                    continue
+                doc = parse_metadata_from_text_item(text_item)
+                if doc:
+                    docs.append(doc)
+    return docs
+
+
async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches,too-many-arguments client: AsyncLlamaStackClient, model_id: str, @@ -404,7 +486,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche mcp_headers: dict[str, dict[str, str]] | None = None, *, provider_id: str = "", -) -> tuple[TurnSummary, str]: +) -> tuple[TurnSummary, str, list[ReferencedDocument]]: """ Retrieve response from LLMs and agents. @@ -428,8 +510,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing. Returns: - tuple[TurnSummary, str]: A tuple containing a summary of the LLM or agent's response content - and the conversation ID. + tuple[TurnSummary, str, list[ReferencedDocument]]: A tuple containing + a summary of the LLM or agent's response + content, the conversation ID and the list of parsed referenced documents.
""" available_input_shields = [ shield.identifier @@ -522,6 +605,8 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche tool_calls=[], ) + referenced_documents = parse_referenced_documents(response) + # Update token count metrics for the LLM call model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id update_llm_token_count_from_turn(response, model_label, provider_id, system_prompt) @@ -540,7 +625,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche "Response lacks output_message.content (conversation_id=%s)", conversation_id, ) - return summary, conversation_id + return (summary, conversation_id, referenced_documents) def validate_attachments_metadata(attachments: list[Attachment]) -> None: diff --git a/src/models/responses.py b/src/models/responses.py index 489f532b7..1f99bb31f 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -2,7 +2,7 @@ from typing import Any, Optional -from pydantic import BaseModel, Field +from pydantic import AnyUrl, BaseModel, Field class ModelsResponse(BaseModel): @@ -34,10 +34,21 @@ class ModelsResponse(BaseModel): )
+class ReferencedDocument(BaseModel):
+    """Model representing a document referenced in generating a response.
+
+    Attributes:
+        doc_url: URL of the referenced document (declared as AnyUrl).
+        doc_title: Title of the referenced document.
+    """
+
+    doc_url: AnyUrl = Field(description="URL of the referenced document")
+
+    doc_title: str = Field(description="Title of the referenced document")
+
+
# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now # we are keeping it simple. The missing fields are: -# - referenced_documents: The optional URLs and titles for the documents used -# to generate the response. # - truncated: Set to True if conversation history was truncated to be within context window.
# - input_tokens: Number of tokens sent to LLM # - output_tokens: Number of tokens received from LLM @@ -51,6 +62,7 @@ class QueryResponse(BaseModel): Attributes: conversation_id: The optional conversation ID (UUID). response: The response. + referenced_documents: The URLs and titles for the documents used to generate the response. """ conversation_id: Optional[str] = Field( @@ -66,6 +78,20 @@ class QueryResponse(BaseModel): ], ) + referenced_documents: list[ReferencedDocument] = Field( + default_factory=list, + description="List of documents referenced in generating the response", + examples=[ + [ + { + "doc_url": "https://docs.openshift.com/" + "container-platform/4.15/operators/olm/index.html", + "doc_title": "Operator Lifecycle Manager (OLM)", + } + ] + ], + ) + # provides examples for /docs endpoint model_config = { "json_schema_extra": { @@ -73,6 +99,13 @@ class QueryResponse(BaseModel): { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", "response": "Operator Lifecycle Manager (OLM) helps users install...", + "referenced_documents": [ + { + "doc_url": "https://docs.openshift.com/" + "container-platform/4.15/operators/olm/index.html", + "doc_title": "Operator Lifecycle Manager (OLM)", + } + ], } ] } diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 3b3d64f3f..c15c12253 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -5,31 +5,38 @@ # pylint: disable=too-many-lines import json -from fastapi import HTTPException, status, Request -import pytest +import pytest +from fastapi import HTTPException, Request, status from llama_stack_client import APIConnectionError from llama_stack_client.types import UserMessage # type: ignore -from llama_stack_client.types.shared.interleaved_content import ( - TextContentItem, -) +from llama_stack_client.types.agents.turn import Turn +from llama_stack_client.types.shared.interleaved_content_item import TextContentItem +from 
llama_stack_client.types.tool_execution_step import ToolExecutionStep +from llama_stack_client.types.tool_response import ToolResponse +from pydantic import AnyUrl -from configuration import AppConfig from app.endpoints.query import ( + evaluate_model_hints, + get_rag_toolgroups, + is_transcripts_enabled, + parse_metadata_from_text_item, + parse_referenced_documents, query_endpoint_handler, - select_model_and_provider_id, retrieve_response, + select_model_and_provider_id, validate_attachments_metadata, - is_transcripts_enabled, - get_rag_toolgroups, - evaluate_model_hints, ) - -from models.requests import QueryRequest, Attachment +from authorization.resolvers import NoopRolesResolver +from configuration import AppConfig from models.config import Action, ModelContextProtocolServer from models.database.conversations import UserConversation +from models.requests import Attachment, QueryRequest +from models.responses import ReferencedDocument +from tests.unit.app.endpoints.test_streaming_query import ( + SAMPLE_KNOWLEDGE_SEARCH_RESULTS, +) from utils.types import ToolCallSummary, TurnSummary -from authorization.resolvers import NoopRolesResolver MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") @@ -167,10 +174,11 @@ async def _test_query_endpoint_handler( ) conversation_id = "fake_conversation_id" query = "What is OpenStack?" 
+ referenced_documents = [] mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id), + return_value=(summary, conversation_id, referenced_documents), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -458,7 +466,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - response, _ = await retrieve_response( + response, _, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -490,7 +498,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo model_id = "fake_model_id" access_token = "test_token" - response, _ = await retrieve_response( + response, _, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -523,7 +531,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -562,7 +570,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -612,7 +620,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -665,7 +673,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, 
access_token ) @@ -720,7 +728,7 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -777,7 +785,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -832,7 +840,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -856,6 +864,123 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke )
+def test_parse_metadata_from_text_item_valid(mocker):
+    """Test parsing metadata from a TextContentItem."""
+    text = """
+    Some text...
+    Metadata: {"docs_url": "https://redhat.com", "title": "Example Doc"}
+    """
+    mock_item = mocker.Mock(spec=TextContentItem)
+    mock_item.text = text
+
+    doc = parse_metadata_from_text_item(mock_item)
+
+    assert isinstance(doc, ReferencedDocument)
+    assert doc.doc_url == AnyUrl("https://redhat.com")
+    assert doc.doc_title == "Example Doc"
+
+
+def test_parse_metadata_from_text_item_missing_title(mocker):
+    """Test parsing metadata from a TextContentItem with missing title."""
+    mock_item = mocker.Mock(spec=TextContentItem)
+    mock_item.text = """Metadata: {"docs_url": "https://redhat.com"}"""
+    doc = parse_metadata_from_text_item(mock_item)
+    assert doc is None
+
+
+def test_parse_metadata_from_text_item_missing_url(mocker):
+    """Test parsing metadata from a TextContentItem with missing url."""
+    mock_item = mocker.Mock(spec=TextContentItem)
+    mock_item.text = """Metadata: {"title": "Example Doc"}"""
+    doc = parse_metadata_from_text_item(mock_item)
+    assert doc is None
+
+
+def test_parse_metadata_from_text_item_malformed_url(mocker):
+    """Test parsing metadata from a TextContentItem with malformed url."""
+    mock_item = mocker.Mock(spec=TextContentItem)
+    mock_item.text = (
+        """Metadata: {"docs_url": "not a valid url", "title": "Example Doc"}"""
+    )
+    doc = parse_metadata_from_text_item(mock_item)
+    assert doc is None
+
+
+def test_parse_referenced_documents_single_doc(mocker):
+    """Test parsing metadata from a Turn containing a single doc."""
+    text_item = mocker.Mock(spec=TextContentItem)
+    text_item.text = (
+        """Metadata: {"docs_url": "https://redhat.com", "title": "Example Doc"}"""
+    )
+
+    tool_response = mocker.Mock(spec=ToolResponse)
+    tool_response.tool_name = "knowledge_search"
+    tool_response.content = [text_item]
+
+    step = mocker.Mock(spec=ToolExecutionStep)
+    step.tool_responses = [tool_response]
+
+    response = mocker.Mock(spec=Turn)
+    response.steps = [step]
+
+    docs = parse_referenced_documents(response)
+
+    assert len(docs) == 1
+    assert docs[0].doc_url == AnyUrl("https://redhat.com")
+    assert docs[0].doc_title == "Example Doc"
+
+
+def test_parse_referenced_documents_multiple_docs(mocker):
+    """Test parsing metadata from a Turn containing multiple docs."""
+    # Build a real ToolResponse holding one TextContentItem per sample
+    # result, so the parser's isinstance checks run against genuine client
+    # types rather than mocks.
+    tool_response = ToolResponse(
+        call_id="c1",
+        tool_name="knowledge_search",
+        content=[
+            TextContentItem(text=s, type="text")
+            for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS
+        ],
+    )
+
+    step = mocker.Mock(spec=ToolExecutionStep)
+    step.tool_responses = [tool_response]
+
+    response = mocker.Mock(spec=Turn)
+    response.steps = [step]
+
+    docs = parse_referenced_documents(response)
+
+    assert len(docs) == 2
+    assert docs[0].doc_url == AnyUrl("https://example.com/doc1")
+    assert docs[0].doc_title == "Doc1"
+    assert docs[1].doc_url == AnyUrl("https://example.com/doc2")
+    assert docs[1].doc_title == "Doc2"
+
+
+def test_parse_referenced_documents_ignores_other_tools(mocker):
+    """Test parsing metadata from a Turn with the wrong tool name."""
+    text_item = mocker.Mock(spec=TextContentItem)
+    text_item.text = (
+        """Metadata: {"docs_url": "https://redhat.com", "title": "Example Doc"}"""
+    )
+
+    tool_response = mocker.Mock(spec=ToolResponse)
+    tool_response.tool_name = "not rag tool"
+    tool_response.content = [text_item]
+
+    step = mocker.Mock(spec=ToolExecutionStep)
+    step.tool_responses = [tool_response]
+
+    response = mocker.Mock()
+    response.steps = [step]
+
+    docs = parse_referenced_documents(response)
+
+    assert not docs
+
+
@pytest.mark.asyncio async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers configured.""" @@ -888,7 +1013,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await
retrieve_response( mock_client, model_id, query_request, access_token ) @@ -958,7 +1083,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( model_id = "fake_model_id" access_token = "" # Empty token - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1030,7 +1155,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( }, } - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, @@ -1115,7 +1240,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): query_request = QueryRequest(query="What is OpenStack?") - _, conversation_id = await retrieve_response( + _, conversation_id, _ = await retrieve_response( mock_client, "fake_model_id", query_request, "test_token" ) @@ -1200,7 +1325,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_requ ) mock_retrieve_response = mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, "test_conversation_id"), + return_value=(summary, "test_conversation_id", []), ) mocker.patch( @@ -1248,10 +1373,11 @@ async def test_query_endpoint_handler_no_tools_true(mocker, dummy_request): ) conversation_id = "fake_conversation_id" query = "What is OpenStack?" + referenced_documents = [] mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id), + return_value=(summary, conversation_id, referenced_documents), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1299,10 +1425,11 @@ async def test_query_endpoint_handler_no_tools_false(mocker, dummy_request): ) conversation_id = "fake_conversation_id" query = "What is OpenStack?" 
+ referenced_documents = [] mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id), + return_value=(summary, conversation_id, referenced_documents), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1354,7 +1481,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token ) @@ -1405,7 +1532,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( model_id = "fake_model_id" access_token = "test_token" - summary, conversation_id = await retrieve_response( + summary, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, access_token )