18 changes: 1 addition & 17 deletions src/app/endpoints/query.py
@@ -22,7 +22,6 @@
from llama_stack_client.types.model_list_response import ModelListResponse
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
from pydantic import AnyUrl

import constants
import metrics
@@ -360,22 +359,7 @@ async def query_endpoint_handler(  # pylint: disable=R0914
for tc in summary.tool_calls
]

logger.info("Extracting referenced documents...")
referenced_docs = []
doc_sources = set()
for chunk in summary.rag_chunks:
if chunk.source and chunk.source not in doc_sources:
doc_sources.add(chunk.source)
referenced_docs.append(
ReferencedDocument(
doc_url=(
AnyUrl(chunk.source)
if chunk.source.startswith("http")
else None
),
doc_title=chunk.source,
)
)
logger.info("Using referenced documents from response...")

logger.info("Building final response...")
response = QueryResponse(
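The inline extraction deleted above is superseded by the shared helpers introduced in `src/utils/endpoints.py` later in this diff. A minimal sketch of how the handler can now delegate, assuming `summary` is the same `TurnSummary` used above (the exact call site in query.py is not shown here):

```python
# Sketch only: delegate to the shared helper added in src/utils/endpoints.py
# below, instead of building ReferencedDocument objects inline.
from utils.endpoints import create_referenced_documents_from_chunks

referenced_docs = create_referenced_documents_from_chunks(summary.rag_chunks)
```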
37 changes: 24 additions & 13 deletions src/app/endpoints/streaming_query.py
@@ -47,6 +47,8 @@
from models.responses import ForbiddenResponse, UnauthorizedResponse
from utils.endpoints import (
check_configuration_loaded,
create_referenced_documents_with_metadata,
create_rag_chunks_dict,
get_agent,
get_system_prompt,
store_conversation_into_cache,
@@ -142,7 +144,7 @@ def stream_start_event(conversation_id: str) -> str:
)


def stream_end_event(metadata_map: dict) -> str:
def stream_end_event(metadata_map: dict, summary: TurnSummary) -> str:
"""
Yield the end of the data stream.

@@ -158,20 +160,27 @@ def stream_end_event(metadata_map: dict) -> str:
str: A Server-Sent Events (SSE) formatted string
representing the end of the data stream.
"""
# Process RAG chunks using utility function
rag_chunks = create_rag_chunks_dict(summary)

# Extract referenced documents using utility function
referenced_docs = create_referenced_documents_with_metadata(summary, metadata_map)

# Convert ReferencedDocument objects to dictionaries for JSON serialization
referenced_docs_dict = [
{
"doc_url": str(doc.doc_url) if doc.doc_url else None,
"doc_title": doc.doc_title,
}
for doc in referenced_docs
]

return format_stream_data(
{
"event": "end",
"data": {
"referenced_documents": [
{
"doc_url": v["docs_url"],
"doc_title": v["title"],
}
for v in filter(
lambda v: ("docs_url" in v) and ("title" in v),
metadata_map.values(),
)
],
"rag_chunks": rag_chunks,
"referenced_documents": referenced_docs_dict,
"truncated": None, # TODO(jboos): implement truncated
"input_tokens": 0, # TODO(jboos): implement input tokens
"output_tokens": 0, # TODO(jboos): implement output tokens
@@ -667,6 +676,8 @@ async def response_generator(
yield stream_start_event(conversation_id)

async for chunk in turn_response:
if chunk.event is None:
continue
p = chunk.event.payload
if p.event_type == "turn_complete":
summary.llm_response = interleaved_content_as_str(
@@ -687,7 +698,7 @@ async def response_generator(
chunk_id += 1
yield event

yield stream_end_event(metadata_map)
yield stream_end_event(metadata_map, summary)

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
@@ -701,7 +712,7 @@ async def response_generator(
query=query_request.query,
query_request=query_request,
summary=summary,
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
rag_chunks=create_rag_chunks_dict(summary),
truncated=False, # TODO(lucasagomes): implement truncation as part
# of quota work
attachments=query_request.attachments or [],
217 changes: 217 additions & 0 deletions src/utils/endpoints.py
@@ -1,18 +1,22 @@
"""Utility functions for endpoint handlers."""

from contextlib import suppress
from typing import Any
from fastapi import HTTPException, status
from llama_stack_client._client import AsyncLlamaStackClient
from llama_stack_client.lib.agents.agent import AsyncAgent
from pydantic import AnyUrl, ValidationError

import constants
from models.cache_entry import CacheEntry
from models.requests import QueryRequest
from models.responses import ReferencedDocument
from models.database.conversations import UserConversation
from models.config import Action
from app.database import get_session
from configuration import AppConfig
from utils.suid import get_suid
from utils.types import TurnSummary
from utils.types import GraniteToolParser


@@ -340,3 +344,216 @@ async def get_temp_agent(
session_id = await agent.create_session(get_suid())

return agent, session_id, conversation_id


def create_rag_chunks_dict(summary: TurnSummary) -> list[dict[str, Any]]:
"""
Create a dictionary representation of RAG chunks for the streaming response.

Args:
summary: TurnSummary containing RAG chunks

Returns:
List of dictionaries with content, source, and score
"""
return [
{"content": chunk.content, "source": chunk.source, "score": chunk.score}
for chunk in summary.rag_chunks
]
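
A quick sketch of the resulting shape, using stand-in types for illustration only (the real chunk type is whatever `TurnSummary.rag_chunks` carries):

```python
from dataclasses import dataclass

@dataclass
class FakeChunk:  # stand-in for the chunk type on TurnSummary.rag_chunks
    content: str
    source: str
    score: float | None

@dataclass
class FakeSummary:  # stand-in exposing only the attribute used above
    rag_chunks: list

create_rag_chunks_dict(FakeSummary([FakeChunk("Install via pip ...", "docs/install.md", 0.87)]))
# -> [{"content": "Install via pip ...", "source": "docs/install.md", "score": 0.87}]
```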


def _process_http_source(
src: str, doc_urls: set[str]
) -> tuple[AnyUrl | None, str] | None:
"""Process HTTP source and return (doc_url, doc_title) tuple."""
if src not in doc_urls:
doc_urls.add(src)
try:
validated_url = AnyUrl(src)
except ValidationError:
logger.warning("Invalid URL in chunk source: %s", src)
validated_url = None

doc_title = src.rsplit("/", 1)[-1] or src
return (validated_url, doc_title)
return None


def _process_document_id(
src: str,
doc_ids: set[str],
doc_urls: set[str],
metas_by_id: dict[str, dict[str, Any]],
metadata_map: dict[str, Any] | None,
) -> tuple[AnyUrl | None, str] | None:
"""Process document ID and return (doc_url, doc_title) tuple."""
if src in doc_ids:
return None
doc_ids.add(src)

meta = metas_by_id.get(src, {}) if metadata_map else {}
doc_url = meta.get("docs_url")
title = meta.get("title")
# Guard against non-string values coming from metadata
if not isinstance(doc_url, (str, type(None))):
doc_url = None
if not isinstance(title, (str, type(None))):
title = None

if doc_url:
if doc_url in doc_urls:
return None
doc_urls.add(doc_url)

try:
validated_doc_url = None
if doc_url and doc_url.startswith("http"):
validated_doc_url = AnyUrl(doc_url)
except ValidationError:
logger.warning("Invalid URL in metadata: %s", doc_url)
validated_doc_url = None

doc_title = title or (doc_url.rsplit("/", 1)[-1] if doc_url else src)
return (validated_doc_url, doc_title)


def _add_additional_metadata_docs(
doc_urls: set[str],
metas_by_id: dict[str, dict[str, Any]],
) -> list[tuple[AnyUrl | None, str]]:
"""Add additional referenced documents from metadata_map."""
additional_entries: list[tuple[AnyUrl | None, str]] = []
for meta in metas_by_id.values():
doc_url = meta.get("docs_url")
title = meta.get("title") # Note: must be "title", not "Title"
# Guard against non-string values coming from metadata
if not isinstance(doc_url, (str, type(None))):
doc_url = None
if not isinstance(title, (str, type(None))):
title = None
if doc_url and doc_url not in doc_urls and title is not None:
doc_urls.add(doc_url)
try:
validated_url = None
if doc_url.startswith("http"):
validated_url = AnyUrl(doc_url)
except ValidationError:
logger.warning("Invalid URL in metadata_map: %s", doc_url)
validated_url = None

additional_entries.append((validated_url, title))
return additional_entries


def _process_rag_chunks_for_documents(
rag_chunks: list,
metadata_map: dict[str, Any] | None = None,
) -> list[tuple[AnyUrl | None, str]]:
"""
Process RAG chunks and return a list of (doc_url, doc_title) tuples.

This is the core logic shared between both return formats.
"""
doc_urls: set[str] = set()
doc_ids: set[str] = set()

# Process metadata_map if provided
metas_by_id: dict[str, dict[str, Any]] = {}
if metadata_map:
metas_by_id = {k: v for k, v in metadata_map.items() if isinstance(v, dict)}

document_entries: list[tuple[AnyUrl | None, str]] = []

for chunk in rag_chunks:
src = chunk.source
if not src or src == constants.DEFAULT_RAG_TOOL:
continue

if src.startswith("http"):
entry = _process_http_source(src, doc_urls)
if entry:
document_entries.append(entry)
else:
entry = _process_document_id(
src, doc_ids, doc_urls, metas_by_id, metadata_map
)
if entry:
document_entries.append(entry)

# Add any additional referenced documents from metadata_map not already present
if metadata_map:
additional_entries = _add_additional_metadata_docs(doc_urls, metas_by_id)
document_entries.extend(additional_entries)

return document_entries
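
A hedged walk-through of the dedup rules above, with invented sources: an http(s) source becomes `(AnyUrl, last path segment)`, a bare document ID is enriched from `metadata_map`, and repeats and the default RAG tool name are dropped:

```python
# Invented example data; chunk objects only need a .source attribute here.
from types import SimpleNamespace

chunks = [
    SimpleNamespace(source="https://example.com/docs/install"),
    SimpleNamespace(source="doc-123"),
    SimpleNamespace(source="doc-123"),                    # duplicate ID: skipped
    SimpleNamespace(source=constants.DEFAULT_RAG_TOOL),   # tool name: skipped
]
metadata_map = {"doc-123": {"docs_url": "https://example.com/guide", "title": "Guide"}}

_process_rag_chunks_for_documents(chunks, metadata_map)
# -> [(AnyUrl("https://example.com/docs/install"), "install"),
#     (AnyUrl("https://example.com/guide"), "Guide")]
```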


def create_referenced_documents(
rag_chunks: list,
metadata_map: dict[str, Any] | None = None,
return_dict_format: bool = False,
) -> list[ReferencedDocument] | list[dict[str, str | None]]:
"""
Create referenced documents from RAG chunks with optional metadata enrichment.

This unified function processes RAG chunks and creates referenced documents with
optional metadata enrichment, deduplication, and proper URL handling. It can return
either ReferencedDocument objects (for query endpoint) or dictionaries (for streaming).

Args:
rag_chunks: List of RAG chunks with source information
metadata_map: Optional mapping containing metadata about referenced documents
return_dict_format: If True, returns list of dicts; if False, returns list of
ReferencedDocument objects

Returns:
List of ReferencedDocument objects or dictionaries with doc_url and doc_title
"""
document_entries = _process_rag_chunks_for_documents(rag_chunks, metadata_map)

if return_dict_format:
return [
{
"doc_url": str(doc_url) if doc_url else None,
"doc_title": doc_title,
}
for doc_url, doc_title in document_entries
]
return [
ReferencedDocument(doc_url=doc_url, doc_title=doc_title)
for doc_url, doc_title in document_entries
]


# Backward compatibility functions
def create_referenced_documents_with_metadata(
summary: TurnSummary, metadata_map: dict[str, Any]
) -> list[ReferencedDocument]:
"""
Create referenced documents from RAG chunks with metadata enrichment for streaming.

This function now returns ReferencedDocument objects for consistency with the query endpoint.
"""
document_entries = _process_rag_chunks_for_documents(
summary.rag_chunks, metadata_map
)
return [
ReferencedDocument(doc_url=doc_url, doc_title=doc_title)
for doc_url, doc_title in document_entries
]


def create_referenced_documents_from_chunks(
rag_chunks: list,
) -> list[ReferencedDocument]:
"""
Create referenced documents from RAG chunks for query endpoint.

This is a backward compatibility wrapper around the unified
create_referenced_documents function.
"""
document_entries = _process_rag_chunks_for_documents(rag_chunks)
return [
ReferencedDocument(doc_url=doc_url, doc_title=doc_title)
for doc_url, doc_title in document_entries
]
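
Taken together, the intended call sites look roughly like this (a sketch; `summary` and `metadata_map` come from the endpoints' turn handling):

```python
# Hedged usage sketch of the three public helpers added in this file.
refs_for_query = create_referenced_documents_from_chunks(summary.rag_chunks)
refs_for_streaming = create_referenced_documents_with_metadata(summary, metadata_map)
rag_chunks_payload = create_rag_chunks_dict(summary)
```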