diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 3f75a6cd4..f14bafa4a 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -22,7 +22,6 @@ from llama_stack_client.types.model_list_response import ModelListResponse from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from pydantic import AnyUrl import constants import metrics @@ -360,22 +359,7 @@ async def query_endpoint_handler( # pylint: disable=R0914 for tc in summary.tool_calls ] - logger.info("Extracting referenced documents...") - referenced_docs = [] - doc_sources = set() - for chunk in summary.rag_chunks: - if chunk.source and chunk.source not in doc_sources: - doc_sources.add(chunk.source) - referenced_docs.append( - ReferencedDocument( - doc_url=( - AnyUrl(chunk.source) - if chunk.source.startswith("http") - else None - ), - doc_title=chunk.source, - ) - ) + logger.info("Using referenced documents from response...") logger.info("Building final response...") response = QueryResponse( diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index bf4d8635c..d268b02f3 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -47,6 +47,8 @@ from models.responses import ForbiddenResponse, UnauthorizedResponse from utils.endpoints import ( check_configuration_loaded, + create_referenced_documents_with_metadata, + create_rag_chunks_dict, get_agent, get_system_prompt, store_conversation_into_cache, @@ -142,7 +144,7 @@ def stream_start_event(conversation_id: str) -> str: ) -def stream_end_event(metadata_map: dict) -> str: +def stream_end_event(metadata_map: dict, summary: TurnSummary) -> str: """ Yield the end of the data stream. @@ -158,20 +160,27 @@ def stream_end_event(metadata_map: dict) -> str: str: A Server-Sent Events (SSE) formatted string representing the end of the data stream. 
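+
+    Example end event (illustrative; remaining fields omitted):
+
+        data: {"event": "end", "data": {"rag_chunks": [...],
+            "referenced_documents": [{"doc_url": ..., "doc_title": ...}], ...}}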
""" + # Process RAG chunks using utility function + rag_chunks = create_rag_chunks_dict(summary) + + # Extract referenced documents using utility function + referenced_docs = create_referenced_documents_with_metadata(summary, metadata_map) + + # Convert ReferencedDocument objects to dictionaries for JSON serialization + referenced_docs_dict = [ + { + "doc_url": str(doc.doc_url) if doc.doc_url else None, + "doc_title": doc.doc_title, + } + for doc in referenced_docs + ] + return format_stream_data( { "event": "end", "data": { - "referenced_documents": [ - { - "doc_url": v["docs_url"], - "doc_title": v["title"], - } - for v in filter( - lambda v: ("docs_url" in v) and ("title" in v), - metadata_map.values(), - ) - ], + "rag_chunks": rag_chunks, + "referenced_documents": referenced_docs_dict, "truncated": None, # TODO(jboos): implement truncated "input_tokens": 0, # TODO(jboos): implement input tokens "output_tokens": 0, # TODO(jboos): implement output tokens @@ -667,6 +676,8 @@ async def response_generator( yield stream_start_event(conversation_id) async for chunk in turn_response: + if chunk.event is None: + continue p = chunk.event.payload if p.event_type == "turn_complete": summary.llm_response = interleaved_content_as_str( @@ -687,7 +698,7 @@ async def response_generator( chunk_id += 1 yield event - yield stream_end_event(metadata_map) + yield stream_end_event(metadata_map, summary) if not is_transcripts_enabled(): logger.debug("Transcript collection is disabled in the configuration") @@ -701,7 +712,7 @@ async def response_generator( query=query_request.query, query_request=query_request, summary=summary, - rag_chunks=[], # TODO(lucasagomes): implement rag_chunks + rag_chunks=create_rag_chunks_dict(summary), truncated=False, # TODO(lucasagomes): implement truncation as part # of quota work attachments=query_request.attachments or [], diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 0b0c15102..de7e9cec6 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -1,18 +1,22 @@ """Utility functions for endpoint handlers.""" from contextlib import suppress +from typing import Any from fastapi import HTTPException, status from llama_stack_client._client import AsyncLlamaStackClient from llama_stack_client.lib.agents.agent import AsyncAgent +from pydantic import AnyUrl, ValidationError import constants from models.cache_entry import CacheEntry from models.requests import QueryRequest +from models.responses import ReferencedDocument from models.database.conversations import UserConversation from models.config import Action from app.database import get_session from configuration import AppConfig from utils.suid import get_suid +from utils.types import TurnSummary from utils.types import GraniteToolParser @@ -340,3 +344,216 @@ async def get_temp_agent( session_id = await agent.create_session(get_suid()) return agent, session_id, conversation_id + + +def create_rag_chunks_dict(summary: TurnSummary) -> list[dict[str, Any]]: + """ + Create dictionary representation of RAG chunks for streaming response. 
+
+    Args:
+        summary: TurnSummary containing RAG chunks
+
+    Returns:
+        List of dictionaries with content, source, and score
+    """
+    return [
+        {"content": chunk.content, "source": chunk.source, "score": chunk.score}
+        for chunk in summary.rag_chunks
+    ]
+
+
+def _process_http_source(
+    src: str, doc_urls: set[str]
+) -> tuple[AnyUrl | None, str] | None:
+    """Process an HTTP source into a (doc_url, doc_title) tuple.
+
+    Returns None if the source URL has already been seen.
+    """
+    if src not in doc_urls:
+        doc_urls.add(src)
+        try:
+            validated_url = AnyUrl(src)
+        except ValidationError:
+            logger.warning("Invalid URL in chunk source: %s", src)
+            validated_url = None
+
+        doc_title = src.rsplit("/", 1)[-1] or src
+        return (validated_url, doc_title)
+    return None
+
+
+def _process_document_id(
+    src: str,
+    doc_ids: set[str],
+    doc_urls: set[str],
+    metas_by_id: dict[str, dict[str, Any]],
+    metadata_map: dict[str, Any] | None,
+) -> tuple[AnyUrl | None, str] | None:
+    """Process a document ID into a (doc_url, doc_title) tuple.
+
+    Returns None for duplicate document IDs or duplicate metadata URLs.
+    """
+    if src in doc_ids:
+        return None
+    doc_ids.add(src)
+
+    meta = metas_by_id.get(src, {}) if metadata_map else {}
+    doc_url = meta.get("docs_url")
+    title = meta.get("title")
+    # Type check to ensure we have the right types
+    if not isinstance(doc_url, (str, type(None))):
+        doc_url = None
+    if not isinstance(title, (str, type(None))):
+        title = None
+
+    if doc_url:
+        if doc_url in doc_urls:
+            return None
+        doc_urls.add(doc_url)
+
+    try:
+        validated_doc_url = None
+        if doc_url and doc_url.startswith("http"):
+            validated_doc_url = AnyUrl(doc_url)
+    except ValidationError:
+        logger.warning("Invalid URL in metadata: %s", doc_url)
+        validated_doc_url = None
+
+    doc_title = title or (doc_url.rsplit("/", 1)[-1] if doc_url else src)
+    return (validated_doc_url, doc_title)
+
+
+def _add_additional_metadata_docs(
+    doc_urls: set[str],
+    metas_by_id: dict[str, dict[str, Any]],
+) -> list[tuple[AnyUrl | None, str]]:
+    """Add additional referenced documents from metadata_map."""
+    additional_entries: list[tuple[AnyUrl | None, str]] = []
+    for meta in metas_by_id.values():
+        doc_url = meta.get("docs_url")
+        title = meta.get("title")  # Note: must be "title", not "Title"
+        # Type check to ensure we have the right types
+        if not isinstance(doc_url, (str, type(None))):
+            doc_url = None
+        if not isinstance(title, (str, type(None))):
+            title = None
+        if doc_url and doc_url not in doc_urls and title is not None:
+            doc_urls.add(doc_url)
+            try:
+                validated_url = None
+                if doc_url.startswith("http"):
+                    validated_url = AnyUrl(doc_url)
+            except ValidationError:
+                logger.warning("Invalid URL in metadata_map: %s", doc_url)
+                validated_url = None
+
+            additional_entries.append((validated_url, title))
+    return additional_entries
+
+
+def _process_rag_chunks_for_documents(
+    rag_chunks: list,
+    metadata_map: dict[str, Any] | None = None,
+) -> list[tuple[AnyUrl | None, str]]:
+    """
+    Process RAG chunks and return a list of (doc_url, doc_title) tuples.
+
+    This is the core logic shared between both return formats.
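+
+    Args:
+        rag_chunks: List of RAG chunks with source information
+        metadata_map: Optional mapping from document ID to metadata
+            (expects lowercase "docs_url" and "title" keys)
+
+    Returns:
+        List of (doc_url, doc_title) tuples, deduplicated by URL and document ID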
+ """ + doc_urls: set[str] = set() + doc_ids: set[str] = set() + + # Process metadata_map if provided + metas_by_id: dict[str, dict[str, Any]] = {} + if metadata_map: + metas_by_id = {k: v for k, v in metadata_map.items() if isinstance(v, dict)} + + document_entries: list[tuple[AnyUrl | None, str]] = [] + + for chunk in rag_chunks: + src = chunk.source + if not src or src == constants.DEFAULT_RAG_TOOL: + continue + + if src.startswith("http"): + entry = _process_http_source(src, doc_urls) + if entry: + document_entries.append(entry) + else: + entry = _process_document_id( + src, doc_ids, doc_urls, metas_by_id, metadata_map + ) + if entry: + document_entries.append(entry) + + # Add any additional referenced documents from metadata_map not already present + if metadata_map: + additional_entries = _add_additional_metadata_docs(doc_urls, metas_by_id) + document_entries.extend(additional_entries) + + return document_entries + + +def create_referenced_documents( + rag_chunks: list, + metadata_map: dict[str, Any] | None = None, + return_dict_format: bool = False, +) -> list[ReferencedDocument] | list[dict[str, str | None]]: + """ + Create referenced documents from RAG chunks with optional metadata enrichment. + + This unified function processes RAG chunks and creates referenced documents with + optional metadata enrichment, deduplication, and proper URL handling. It can return + either ReferencedDocument objects (for query endpoint) or dictionaries (for streaming). + + Args: + rag_chunks: List of RAG chunks with source information + metadata_map: Optional mapping containing metadata about referenced documents + return_dict_format: If True, returns list of dicts; if False, returns list of + ReferencedDocument objects + + Returns: + List of ReferencedDocument objects or dictionaries with doc_url and doc_title + """ + document_entries = _process_rag_chunks_for_documents(rag_chunks, metadata_map) + + if return_dict_format: + return [ + { + "doc_url": str(doc_url) if doc_url else None, + "doc_title": doc_title, + } + for doc_url, doc_title in document_entries + ] + return [ + ReferencedDocument(doc_url=doc_url, doc_title=doc_title) + for doc_url, doc_title in document_entries + ] + + +# Backward compatibility functions +def create_referenced_documents_with_metadata( + summary: TurnSummary, metadata_map: dict[str, Any] +) -> list[ReferencedDocument]: + """ + Create referenced documents from RAG chunks with metadata enrichment for streaming. + + This function now returns ReferencedDocument objects for consistency with the query endpoint. + """ + document_entries = _process_rag_chunks_for_documents( + summary.rag_chunks, metadata_map + ) + return [ + ReferencedDocument(doc_url=doc_url, doc_title=doc_title) + for doc_url, doc_title in document_entries + ] + + +def create_referenced_documents_from_chunks( + rag_chunks: list, +) -> list[ReferencedDocument]: + """ + Create referenced documents from RAG chunks for query endpoint. + + This is a backward compatibility wrapper around the unified + create_referenced_documents function. 
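+
+    Example (illustrative): a chunk with source "https://example.com/doc1"
+    becomes a ReferencedDocument with doc_url "https://example.com/doc1"
+    and doc_title "doc1".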
+    """
+    document_entries = _process_rag_chunks_for_documents(rag_chunks)
+    return [
+        ReferencedDocument(doc_url=doc_url, doc_title=doc_title)
+        for doc_url, doc_title in document_entries
+    ]
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index 52a387073..eda89d496 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -374,7 +374,13 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
             ],
         ),
         attachments=[],
-        rag_chunks=[],
+        rag_chunks=[
+            {
+                "content": " ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS),
+                "source": "knowledge_search",
+                "score": None,
+            }
+        ],
         truncated=False,
     )
 else:
@@ -1668,3 +1674,44 @@ async def test_streaming_query_endpoint_rejects_model_provider_override_without_
     )
     assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
     assert exc_info.value.detail["response"] == expected_msg
+
+
+@pytest.mark.asyncio
+async def test_streaming_query_handles_none_event(mocker):
+    """Test that streaming query handles chunks with None events gracefully."""
+    mock_metrics(mocker)
+    # Mock the client
+    mock_client = mocker.AsyncMock()
+    mock_async_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")
+    mock_async_lsc.return_value = mock_client
+    mock_client.models.list.return_value = [
+        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
+    ]
+    # Create a mock chunk with None event
+    mock_chunk = mocker.Mock()
+    mock_chunk.event = None
+    # Create mock streaming response with None event chunk
+    mock_streaming_response = mocker.AsyncMock()
+    mock_streaming_response.__aiter__.return_value = [mock_chunk]
+    # Mock the retrieve_response to return our mock streaming response
+    mocker.patch(
+        "app.endpoints.streaming_query.retrieve_response",
+        return_value=(mock_streaming_response, "00000000-0000-0000-0000-000000000000"),
+    )
+    # Mock other dependencies
+    mocker.patch(
+        "app.endpoints.streaming_query.select_model_and_provider_id",
+        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query.is_transcripts_enabled",
+        return_value=False,
+    )
+    mock_database_operations(mocker)
+    query_request = QueryRequest(query="test query")
+    request = Request(scope={"type": "http"})
+    # Building the response should not raise an exception
+    response = await streaming_query_endpoint_handler(
+        request, query_request, auth=MOCK_AUTH
+    )
+    assert isinstance(response, StreamingResponse)
+    # Drain the stream so the None-event chunk is actually processed;
+    # the generator must skip it without raising
+    async for _ in response.body_iterator:
+        pass
diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py
index bed970a56..18f46b2dd 100644
--- a/tests/unit/utils/test_endpoints.py
+++ b/tests/unit/utils/test_endpoints.py
@@ -3,6 +3,7 @@ import os

 import pytest
 from fastapi import HTTPException
+from pydantic import AnyUrl

 import constants
 from configuration import AppConfig
@@ -841,3 +842,107 @@ def test_get_topic_summary_system_prompt_no_customization():
     topic_summary_prompt = endpoints.get_topic_summary_system_prompt(cfg)

     assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT
+
+
+# Tests for unified create_referenced_documents function
+class TestCreateReferencedDocuments:
+    """Test cases for the unified create_referenced_documents function."""
+
+    def test_create_referenced_documents_empty_chunks(self):
+        """Test that an empty chunks list returns an empty result."""
+        result = endpoints.create_referenced_documents([])
+        assert not result
+
+    def test_create_referenced_documents_http_urls_referenced_document_format(self):
+        """Test that HTTP URL sources produce ReferencedDocument objects."""
+
+        mock_chunk1 = type("MockChunk", (), {"source": "https://example.com/doc1"})()
+        mock_chunk2 = type("MockChunk", (), {"source": "https://example.com/doc2"})()
+
+        result = endpoints.create_referenced_documents([mock_chunk1, mock_chunk2])
+
+        assert len(result) == 2
+        assert result[0].doc_url == AnyUrl("https://example.com/doc1")
+        assert result[0].doc_title == "doc1"
+        assert result[1].doc_url == AnyUrl("https://example.com/doc2")
+        assert result[1].doc_title == "doc2"
+
+    def test_create_referenced_documents_document_ids_with_metadata(self):
+        """Test document IDs with metadata enrichment."""
+
+        mock_chunk1 = type("MockChunk", (), {"source": "doc_id_1"})()
+        mock_chunk2 = type("MockChunk", (), {"source": "doc_id_2"})()
+
+        metadata_map = {
+            "doc_id_1": {"docs_url": "https://example.com/doc1", "title": "Document 1"},
+            "doc_id_2": {"docs_url": "https://example.com/doc2", "title": "Document 2"},
+        }
+
+        result = endpoints.create_referenced_documents(
+            [mock_chunk1, mock_chunk2], metadata_map
+        )
+
+        assert len(result) == 2
+        assert result[0].doc_url == AnyUrl("https://example.com/doc1")
+        assert result[0].doc_title == "Document 1"
+        assert result[1].doc_url == AnyUrl("https://example.com/doc2")
+        assert result[1].doc_title == "Document 2"
+
+    def test_create_referenced_documents_skips_tool_names(self):
+        """Test that tool names like 'knowledge_search' are skipped."""
+
+        mock_chunk1 = type("MockChunk", (), {"source": "knowledge_search"})()
+        mock_chunk2 = type("MockChunk", (), {"source": "https://example.com/doc1"})()
+
+        result = endpoints.create_referenced_documents([mock_chunk1, mock_chunk2])
+
+        assert len(result) == 1
+        assert result[0].doc_url == AnyUrl("https://example.com/doc1")
+        assert result[0].doc_title == "doc1"
+
+    def test_create_referenced_documents_skips_empty_sources(self):
+        """Test that chunks with empty or None sources are skipped."""
+
+        mock_chunk1 = type("MockChunk", (), {"source": None})()
+        mock_chunk2 = type("MockChunk", (), {"source": ""})()
+        mock_chunk3 = type("MockChunk", (), {"source": "https://example.com/doc1"})()
+
+        result = endpoints.create_referenced_documents(
+            [mock_chunk1, mock_chunk2, mock_chunk3]
+        )
+
+        assert len(result) == 1
+        assert result[0].doc_url == AnyUrl("https://example.com/doc1")
+        assert result[0].doc_title == "doc1"
+
+    def test_create_referenced_documents_deduplication(self):
+        """Test that duplicate sources are deduplicated."""
+
+        mock_chunk1 = type("MockChunk", (), {"source": "https://example.com/doc1"})()
+        mock_chunk2 = type(
+            "MockChunk", (), {"source": "https://example.com/doc1"}
+        )()  # Duplicate
+        mock_chunk3 = type("MockChunk", (), {"source": "doc_id_1"})()
+        mock_chunk4 = type("MockChunk", (), {"source": "doc_id_1"})()  # Duplicate
+
+        result = endpoints.create_referenced_documents(
+            [mock_chunk1, mock_chunk2, mock_chunk3, mock_chunk4]
+        )
+
+        assert len(result) == 2
+        assert result[0].doc_url == AnyUrl("https://example.com/doc1")
+        assert result[1].doc_title == "doc_id_1"
+
+    def test_create_referenced_documents_invalid_urls(self):
+        """Test that a non-HTTP source yields no doc_url and uses the source as title."""
+
+        mock_chunk1 = type("MockChunk", (), {"source": "not-a-valid-url"})()
+        mock_chunk2 = type("MockChunk", (), {"source": "https://example.com/doc1"})()
+
+        result = endpoints.create_referenced_documents([mock_chunk1, mock_chunk2])
+
+        assert len(result) == 2
+        assert result[0].doc_url is None
+        assert result[0].doc_title == "not-a-valid-url"
+        assert result[1].doc_url == AnyUrl("https://example.com/doc1")
+        assert result[1].doc_title == "doc1"