From ef94176582cd016c888faa768732e135ee8e25a8 Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Tue, 16 Sep 2025 12:41:51 -0400 Subject: [PATCH 01/10] Add RAG chunks in query response Signed-off-by: Anxhela Coba --- pyproject.toml | 17 +++++++- run.yaml | 29 +++++++++++--- src/app/endpoints/query.py | 48 +++++++++++++++++++++- src/models/responses.py | 81 ++++++++++++++++++++++++++++++++------ src/utils/types.py | 69 ++++++++++++++++++++++++++++++-- uv.lock | 33 +++++++++++++++- 6 files changed, 253 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 93c4d1f60..14e56f964 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,11 +45,26 @@ dependencies = [ "email-validator>=2.2.0", "openai==1.99.9", # Used by database interface - "sqlalchemy>=2.0.42", + "sqlalchemy>=2.0.41", # Used by Llama Stack version checker "semver<4.0.0", # Used by authorization resolvers "jsonpath-ng>=1.6.1", + "opentelemetry-sdk>=1.34.0", + "opentelemetry-exporter-otlp>=1.34.0", + "opentelemetry-instrumentation>=0.55b0", + "aiosqlite>=0.21.0", + "litellm>=1.72.1", + "blobfile>=3.0.0", + "datasets>=3.6.0", + "faiss-cpu>=1.11.0", + "mcp>=1.9.4", + "autoevals>=0.0.129", + "psutil>=7.0.0", + "torch>=2.7.1", + "peft>=0.15.2", + "trl>=0.18.2", + "sentence-transformers>=5.1.0", ] diff --git a/run.yaml b/run.yaml index af519cfa8..212802004 100644 --- a/run.yaml +++ b/run.yaml @@ -60,6 +60,9 @@ providers: provider_id: meta-reference provider_type: inline::meta-reference inference: + - provider_id: sentence-transformers # Can be any embedding provider + provider_type: inline::sentence-transformers + config: {} - provider_id: openai provider_type: remote::openai config: @@ -99,14 +102,17 @@ providers: - provider_id: model-context-protocol provider_type: remote::model-context-protocol config: {} + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} vector_io: - config: kvstore: - db_path: .llama/distributions/ollama/faiss_store.db + db_path: /path/to/your/vector/store.db namespace: null type: sqlite - provider_id: faiss - provider_type: inline::faiss + provider_id: my_vector_db + provider_type: inline::faiss # Or prefered vector DB scoring_fns: [] server: auth: null @@ -117,10 +123,23 @@ server: tls_certfile: null tls_keyfile: null shields: [] -vector_dbs: [] - +vector_dbs: + - vector_db_id: my_knowledge_base + embedding_model: sentence-transformers/all-mpnet-base-v2 + embedding_dimension: 768 + provider_id: my_vector_db models: + - metadata: + embedding_dimension: 768 # Depends on chosen model + model_id: sentence-transformers/all-mpnet-base-v2 # Example model + provider_id: sentence-transformers + provider_model_id: path/to/model + model_type: embedding - model_id: gpt-4-turbo provider_id: openai model_type: llm provider_model_id: gpt-4-turbo + +tool_groups: + - toolgroup_id: builtin::rag + provider_id: rag-runtime \ No newline at end of file diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 43c3eb603..b42f2ac00 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -30,7 +30,7 @@ from models.config import Action from models.database.conversations import UserConversation from models.requests import QueryRequest, Attachment -from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse +from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse, RAGChunk, ReferencedDocument, ToolCall from utils.endpoints import ( check_configuration_loaded, get_agent, @@ -243,6 +243,7 @@ async def 
query_endpoint_handler( attachments=query_request.attachments or [], ) + logger.info("Persisting conversation details...") persist_user_conversation_details( user_id=user_id, conversation_id=conversation_id, @@ -250,10 +251,53 @@ async def query_endpoint_handler( provider_id=provider_id, ) - return QueryResponse( + # Convert tool calls and RAG chunks to response format + logger.info("Processing tool calls...") + tool_calls = [ + ToolCall( + tool_name=tc.name, + arguments=tc.args if isinstance(tc.args, dict) else {"query": str(tc.args)}, + result={"response": tc.response} if tc.response else None + ) + for tc in summary.tool_calls + ] + + + logger.info("Processing RAG chunks...") + rag_chunks = [ + RAGChunk( + content=chunk.content, + source=chunk.source, + score=chunk.score + ) + for chunk in summary.rag_chunks + ] + + # Extract referenced documents from RAG chunks + logger.info("Extracting referenced documents...") + referenced_docs = [] + doc_sources = set() + for chunk in summary.rag_chunks: + if chunk.source and chunk.source not in doc_sources: + doc_sources.add(chunk.source) + referenced_docs.append( + ReferencedDocument( + url=chunk.source if chunk.source.startswith("http") else None, + title=chunk.source, + chunk_count=sum(1 for c in summary.rag_chunks if c.source == chunk.source) + ) + ) + + logger.info("Building final response...") + response = QueryResponse( conversation_id=conversation_id, response=summary.llm_response, + rag_chunks=rag_chunks if rag_chunks else None, + referenced_documents=referenced_docs if referenced_docs else None, + tool_calls=tool_calls if tool_calls else None, ) + logger.info("Query processing completed successfully!") + return response # connection to Llama Stack server except APIConnectionError as e: diff --git a/src/models/responses.py b/src/models/responses.py index 29b5e5776..8d67d1b83 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -1,6 +1,6 @@ """Models for REST API responses.""" -from typing import Any, Optional +from typing import Any, Optional, List from pydantic import BaseModel, Field @@ -34,23 +34,45 @@ class ModelsResponse(BaseModel): ) -# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now -# we are keeping it simple. The missing fields are: -# - referenced_documents: The optional URLs and titles for the documents used -# to generate the response. -# - truncated: Set to True if conversation history was truncated to be within context window. -# - input_tokens: Number of tokens sent to LLM -# - output_tokens: Number of tokens received from LLM -# - available_quotas: Quota available as measured by all configured quota limiters -# - tool_calls: List of tool requests. -# - tool_results: List of tool results. -# See LLMResponse in ols-service for more details. 
+class RAGChunk(BaseModel): + """Model representing a RAG chunk used in the response.""" + + content: str = Field(description="The content of the chunk") + source: Optional[str] = Field(None, description="Source document or URL") + score: Optional[float] = Field(None, description="Relevance score") + + +class ReferencedDocument(BaseModel): + """Model representing a document referenced in the response.""" + + url: Optional[str] = Field(None, description="URL of the document") + title: Optional[str] = Field(None, description="Title of the document") + chunk_count: Optional[int] = Field(None, description="Number of chunks from this document") + + +class ToolCall(BaseModel): + """Model representing a tool call made during response generation.""" + + tool_name: str = Field(description="Name of the tool called") + arguments: dict[str, Any] = Field(description="Arguments passed to the tool") + result: Optional[dict[str, Any]] = Field(None, description="Result from the tool") + + class QueryResponse(BaseModel): """Model representing LLM response to a query. Attributes: conversation_id: The optional conversation ID (UUID). response: The response. + rag_chunks: List of RAG chunks used to generate the response. + referenced_documents: List of documents referenced in the response. + tool_calls: List of tool calls made during response generation. + TODO: truncated: Whether conversation history was truncated. + TODO: input_tokens: Number of tokens sent to LLM. + TODO: output_tokens: Number of tokens received from LLM. + TODO: available_quotas: Quota available as measured by all configured quota limiters + TODO: tool_results: List of tool results. + """ conversation_id: Optional[str] = Field( @@ -66,6 +88,20 @@ class QueryResponse(BaseModel): ], ) + rag_chunks: Optional[List[RAGChunk]] = Field( + None, + description="List of RAG chunks used to generate the response", + ) + + referenced_documents: Optional[List[ReferencedDocument]] = Field( + None, + description="List of documents referenced in the response", + ) + + tool_calls: Optional[List[ToolCall]] = Field( + None, + description="List of tool calls made during response generation", + ) # provides examples for /docs endpoint model_config = { "json_schema_extra": { @@ -73,6 +109,27 @@ class QueryResponse(BaseModel): { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", "response": "Operator Lifecycle Manager (OLM) helps users install...", + "rag_chunks": [ + { + "content": "OLM is a component of the Operator Framework toolkit...", + "source": "kubernetes-docs/operators.md", + "score": 0.95 + } + ], + "referenced_documents": [ + { + "url": "https://kubernetes.io/docs/concepts/extend-kubernetes/operator/", + "title": "Operator Pattern", + "chunk_count": 2 + } + ], + "tool_calls": [ + { + "tool_name": "knowledge_search", + "arguments": {"query": "operator lifecycle manager"}, + "result": {"chunks_found": 5} + } + ], } ] } diff --git a/src/utils/types.py b/src/utils/types.py index 5770139ae..ccdc1a31b 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -1,7 +1,7 @@ """Common types for the project.""" -from typing import Any, Optional - +from typing import Any, Optional, List +import json from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str from llama_stack_client.lib.agents.tool_parser import ToolParser from llama_stack_client.types.shared.completion_message import CompletionMessage @@ -56,11 +56,20 @@ class ToolCallSummary(BaseModel): response: str | None +class RAGChunkData(BaseModel): + """RAG chunk data 
extracted from tool responses.""" + + content: str + source: Optional[str] = None + score: Optional[float] = None + + class TurnSummary(BaseModel): """Summary of a turn in llama stack.""" llm_response: str tool_calls: list[ToolCallSummary] + rag_chunks: List[RAGChunkData] = [] def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: """Append the tool calls from a llama tool execution step.""" @@ -68,11 +77,65 @@ def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: responses_by_id = {tc.call_id: tc for tc in tec.tool_responses} for call_id, tc in calls_by_id.items(): resp = responses_by_id.get(call_id) + response_content = interleaved_content_as_str(resp.content) if resp else None + self.tool_calls.append( ToolCallSummary( id=call_id, name=tc.tool_name, args=tc.arguments, - response=interleaved_content_as_str(resp.content) if resp else None, + response=response_content, ) ) + + # Extract RAG chunks from knowledge_search tool responses + if tc.tool_name == "knowledge_search" and resp and response_content: + self._extract_rag_chunks_from_response(response_content) + + def _extract_rag_chunks_from_response(self, response_content: str) -> None: + """Extract RAG chunks from tool response content.""" + try: + # Parse the response to get chunks + # Try JSON first + try: + data = json.loads(response_content) + if isinstance(data, dict) and "chunks" in data: + for chunk in data["chunks"]: + self.rag_chunks.append( + RAGChunkData( + content=chunk.get("content", ""), + source=chunk.get("source"), + score=chunk.get("score") + ) + ) + elif isinstance(data, list): + # Handle list of chunks + for chunk in data: + if isinstance(chunk, dict): + self.rag_chunks.append( + RAGChunkData( + content=chunk.get("content", str(chunk)), + source=chunk.get("source"), + score=chunk.get("score") + ) + ) + except json.JSONDecodeError: + # If not JSON, treat the entire response as a single chunk + if response_content.strip(): + self.rag_chunks.append( + RAGChunkData( + content=response_content, + source="knowledge_search", + score=None + ) + ) + except Exception: + # Treat response as single chunk + if response_content.strip(): + self.rag_chunks.append( + RAGChunkData( + content=response_content, + source="knowledge_search", + score=None + ) + ) diff --git a/uv.lock b/uv.lock index fdacc5ff4..0dd25c613 100644 --- a/uv.lock +++ b/uv.lock @@ -1280,20 +1280,36 @@ name = "lightspeed-stack" source = { editable = "." 
} dependencies = [ { name = "aiohttp" }, + { name = "aiosqlite" }, { name = "authlib" }, + { name = "autoevals" }, + { name = "blobfile" }, { name = "cachetools" }, + { name = "datasets" }, { name = "email-validator" }, + { name = "faiss-cpu" }, { name = "fastapi" }, { name = "jsonpath-ng" }, { name = "kubernetes" }, + { name = "litellm" }, { name = "llama-stack" }, { name = "llama-stack-client" }, + { name = "mcp" }, { name = "openai" }, + { name = "opentelemetry-exporter-otlp" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-sdk" }, + { name = "peft" }, { name = "prometheus-client" }, + { name = "psutil" }, { name = "rich" }, { name = "semver" }, + { name = "sentence-transformers" }, { name = "sqlalchemy" }, { name = "starlette" }, + { name = "torch", version = "2.7.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.7.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform != 'darwin'" }, + { name = "trl" }, { name = "uvicorn" }, ] @@ -1360,20 +1376,35 @@ llslibdev = [ [package.metadata] requires-dist = [ { name = "aiohttp", specifier = ">=3.12.14" }, + { name = "aiosqlite", specifier = ">=0.21.0" }, { name = "authlib", specifier = ">=1.6.0" }, + { name = "autoevals", specifier = ">=0.0.129" }, + { name = "blobfile", specifier = ">=3.0.0" }, { name = "cachetools", specifier = ">=6.1.0" }, + { name = "datasets", specifier = ">=3.6.0" }, { name = "email-validator", specifier = ">=2.2.0" }, + { name = "faiss-cpu", specifier = ">=1.11.0" }, { name = "fastapi", specifier = ">=0.115.12" }, { name = "jsonpath-ng", specifier = ">=1.6.1" }, { name = "kubernetes", specifier = ">=30.1.0" }, + { name = "litellm", specifier = ">=1.72.1" }, { name = "llama-stack", specifier = "==0.2.19" }, { name = "llama-stack-client", specifier = "==0.2.19" }, + { name = "mcp", specifier = ">=1.9.4" }, { name = "openai", specifier = "==1.99.9" }, + { name = "opentelemetry-exporter-otlp", specifier = ">=1.34.0" }, + { name = "opentelemetry-instrumentation", specifier = ">=0.55b0" }, + { name = "opentelemetry-sdk", specifier = ">=1.34.0" }, + { name = "peft", specifier = ">=0.15.2" }, { name = "prometheus-client", specifier = ">=0.22.1" }, + { name = "psutil", specifier = ">=7.0.0" }, { name = "rich", specifier = ">=14.0.0" }, { name = "semver", specifier = "<4.0.0" }, - { name = "sqlalchemy", specifier = ">=2.0.42" }, + { name = "sentence-transformers", specifier = ">=5.1.0" }, + { name = "sqlalchemy", specifier = ">=2.0.41" }, { name = "starlette", specifier = ">=0.47.1" }, + { name = "torch", specifier = ">=2.7.1", index = "https://download.pytorch.org/whl/cpu" }, + { name = "trl", specifier = ">=0.18.2" }, { name = "uvicorn", specifier = ">=0.34.3" }, ] From 2d798ec77613d8e418f11a717acb8cf668f06b3e Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Tue, 16 Sep 2025 14:00:26 -0400 Subject: [PATCH 02/10] default values Signed-off-by: Anxhela Coba --- run.yaml | 12 ++++++------ src/utils/types.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/run.yaml b/run.yaml index 212802004..98dc92d50 100644 --- a/run.yaml +++ b/run.yaml @@ -108,11 +108,11 @@ providers: vector_io: - config: kvstore: - db_path: /path/to/your/vector/store.db + db_path: .llama/distributions/ollama/faiss_store.db # Location of vector database namespace: null type: sqlite - provider_id: my_vector_db - provider_type: inline::faiss # Or prefered vector DB + provider_id: faiss + 
provider_type: inline::faiss # Or preferred vector DB scoring_fns: [] server: auth: null @@ -131,9 +131,9 @@ vector_dbs: models: - metadata: embedding_dimension: 768 # Depends on chosen model - model_id: sentence-transformers/all-mpnet-base-v2 # Example model + model_id: sentence-transformers/all-mpnet-base-v2 # Example embedding model provider_id: sentence-transformers - provider_model_id: path/to/model + provider_model_id: sentence-transformers/all-mpnet-base-v2 # Location of embedding model model_type: embedding - model_id: gpt-4-turbo provider_id: openai @@ -142,4 +142,4 @@ models: tool_groups: - toolgroup_id: builtin::rag - provider_id: rag-runtime \ No newline at end of file + provider_id: rag-runtime diff --git a/src/utils/types.py b/src/utils/types.py index ccdc1a31b..a65a8df31 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -7,7 +7,7 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from pydantic.main import BaseModel +from pydantic.main import BaseModel, Field class Singleton(type): @@ -69,7 +69,7 @@ class TurnSummary(BaseModel): llm_response: str tool_calls: list[ToolCallSummary] - rag_chunks: List[RAGChunkData] = [] + rag_chunks: List[RAGChunkData] = Field(default_factory=list) def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: """Append the tool calls from a llama tool execution step.""" From 927754c1eab27ab0adf9b1799204cf6eab944f86 Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Tue, 16 Sep 2025 14:17:25 -0400 Subject: [PATCH 03/10] add test for RAGChunks Signed-off-by: Anxhela Coba --- pyproject.toml | 1 + src/utils/types.py | 2 +- tests/unit/app/endpoints/test_streaming_query.py | 9 ++++++++- uv.lock | 2 ++ 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 14e56f964..0a6fa266d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ dependencies = [ "peft>=0.15.2", "trl>=0.18.2", "sentence-transformers>=5.1.0", + "greenlet>=3.2.4", ] diff --git a/src/utils/types.py b/src/utils/types.py index a65a8df31..37f9a2ec1 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -7,7 +7,7 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from pydantic.main import BaseModel, Field +from pydantic import BaseModel, Field class Singleton(type): diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 38983666a..d15b23d3f 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -45,7 +45,7 @@ from models.requests import QueryRequest, Attachment from models.config import ModelContextProtocolServer, Action from authorization.resolvers import NoopRolesResolver -from utils.types import ToolCallSummary, TurnSummary +from utils.types import ToolCallSummary, TurnSummary, RAGChunkData MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") @@ -343,6 +343,13 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) response=" ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS), ) ], + rag_chunks=[ + RAGChunkData( + content=" ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS), + source="knowledge_search", + score=None, + ) + 
], ), attachments=[], rag_chunks=[], diff --git a/uv.lock b/uv.lock index 0dd25c613..faaa40f3f 100644 --- a/uv.lock +++ b/uv.lock @@ -1289,6 +1289,7 @@ dependencies = [ { name = "email-validator" }, { name = "faiss-cpu" }, { name = "fastapi" }, + { name = "greenlet" }, { name = "jsonpath-ng" }, { name = "kubernetes" }, { name = "litellm" }, @@ -1385,6 +1386,7 @@ requires-dist = [ { name = "email-validator", specifier = ">=2.2.0" }, { name = "faiss-cpu", specifier = ">=1.11.0" }, { name = "fastapi", specifier = ">=0.115.12" }, + { name = "greenlet", specifier = ">=3.2.4" }, { name = "jsonpath-ng", specifier = ">=1.6.1" }, { name = "kubernetes", specifier = ">=30.1.0" }, { name = "litellm", specifier = ">=1.72.1" }, From 92d28fef24e787a2c26c5868b66161f829da5eec Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Thu, 18 Sep 2025 15:04:53 -0400 Subject: [PATCH 04/10] pr review comments Signed-off-by: Anxhela Coba --- src/app/endpoints/query.py | 23 +++++++----------- src/constants.py | 2 ++ src/models/responses.py | 11 ++++----- src/utils/transcripts.py | 4 ++-- src/utils/types.py | 24 ++++++++----------- .../app/endpoints/test_streaming_query.py | 5 ++-- 6 files changed, 29 insertions(+), 40 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index b42f2ac00..905324ba9 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -226,6 +226,10 @@ async def query_endpoint_handler( # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() + # Convert RAG chunks to dictionary format once for reuse + logger.info("Processing RAG chunks...") + rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks] + if not is_transcripts_enabled(): logger.debug("Transcript collection is disabled in the configuration") else: @@ -238,7 +242,7 @@ async def query_endpoint_handler( query=query_request.query, query_request=query_request, summary=summary, - rag_chunks=[], # TODO(lucasagomes): implement rag_chunks + rag_chunks=rag_chunks_dict, truncated=False, # TODO(lucasagomes): implement truncation as part of quota work attachments=query_request.attachments or [], ) @@ -251,29 +255,18 @@ async def query_endpoint_handler( provider_id=provider_id, ) - # Convert tool calls and RAG chunks to response format + # Convert tool calls to response format logger.info("Processing tool calls...") tool_calls = [ ToolCall( tool_name=tc.name, arguments=tc.args if isinstance(tc.args, dict) else {"query": str(tc.args)}, - result={"response": tc.response} if tc.response else None + result={"response": tc.response} if tc.response and tc.name != constants.DEFAULT_RAG_TOOL else None ) for tc in summary.tool_calls ] - logger.info("Processing RAG chunks...") - rag_chunks = [ - RAGChunk( - content=chunk.content, - source=chunk.source, - score=chunk.score - ) - for chunk in summary.rag_chunks - ] - - # Extract referenced documents from RAG chunks logger.info("Extracting referenced documents...") referenced_docs = [] doc_sources = set() @@ -292,7 +285,7 @@ async def query_endpoint_handler( response = QueryResponse( conversation_id=conversation_id, response=summary.llm_response, - rag_chunks=rag_chunks if rag_chunks else None, + rag_chunks=summary.rag_chunks if summary.rag_chunks else [], referenced_documents=referenced_docs if referenced_docs else None, tool_calls=tool_calls if tool_calls else None, ) diff --git a/src/constants.py b/src/constants.py index e79ebcebb..8499ef937 100644 --- a/src/constants.py +++ b/src/constants.py @@ -52,6 +52,8 @@ 
DEFAULT_JWT_UID_CLAIM = "user_id" DEFAULT_JWT_USER_NAME_CLAIM = "username" +# default RAG tool value +DEFAULT_RAG_TOOL = "knowledge_search" # PostgreSQL connection constants # See: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLMODE diff --git a/src/models/responses.py b/src/models/responses.py index 8d67d1b83..ce63d84d6 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -1,6 +1,6 @@ """Models for REST API responses.""" -from typing import Any, Optional, List +from typing import Any, Optional from pydantic import BaseModel, Field @@ -88,17 +88,14 @@ class QueryResponse(BaseModel): ], ) - rag_chunks: Optional[List[RAGChunk]] = Field( - None, - description="List of RAG chunks used to generate the response", - ) + rag_chunks: list[RAGChunk] = [] - referenced_documents: Optional[List[ReferencedDocument]] = Field( + referenced_documents: Optional[list[ReferencedDocument]] = Field( None, description="List of documents referenced in the response", ) - tool_calls: Optional[List[ToolCall]] = Field( + tool_calls: Optional[list[ToolCall]] = Field( None, description="List of tool calls made during response generation", ) diff --git a/src/utils/transcripts.py b/src/utils/transcripts.py index e29d4319d..74b20fdbd 100644 --- a/src/utils/transcripts.py +++ b/src/utils/transcripts.py @@ -39,7 +39,7 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional- query: str, query_request: QueryRequest, summary: TurnSummary, - rag_chunks: list[str], + rag_chunks: list[dict], truncated: bool, attachments: list[Attachment], ) -> None: @@ -52,7 +52,7 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional- query: The query (without attachments). query_request: The request containing a query. summary: Summary of the query/response turn. - rag_chunks: The list of `RagChunk` objects. + rag_chunks: The list of serialized `RAGChunk` dictionaries. truncated: The flag indicating if the history was truncated. attachments: The list of `Attachment` objects. 
""" diff --git a/src/utils/types.py b/src/utils/types.py index 37f9a2ec1..4aef06e8c 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -1,13 +1,15 @@ """Common types for the project.""" -from typing import Any, Optional, List +from typing import Any, Optional import json from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str from llama_stack_client.lib.agents.tool_parser import ToolParser from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from pydantic import BaseModel, Field +from constants import DEFAULT_RAG_TOOL +from pydantic import BaseModel +from models.responses import RAGChunk, Field class Singleton(type): @@ -56,12 +58,6 @@ class ToolCallSummary(BaseModel): response: str | None -class RAGChunkData(BaseModel): - """RAG chunk data extracted from tool responses.""" - - content: str - source: Optional[str] = None - score: Optional[float] = None class TurnSummary(BaseModel): @@ -69,7 +65,7 @@ class TurnSummary(BaseModel): llm_response: str tool_calls: list[ToolCallSummary] - rag_chunks: List[RAGChunkData] = Field(default_factory=list) + rag_chunks: list[RAGChunk] = [] def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: """Append the tool calls from a llama tool execution step.""" @@ -89,7 +85,7 @@ def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: ) # Extract RAG chunks from knowledge_search tool responses - if tc.tool_name == "knowledge_search" and resp and response_content: + if tc.tool_name == DEFAULT_RAG_TOOL and resp and response_content: self._extract_rag_chunks_from_response(response_content) def _extract_rag_chunks_from_response(self, response_content: str) -> None: @@ -102,7 +98,7 @@ def _extract_rag_chunks_from_response(self, response_content: str) -> None: if isinstance(data, dict) and "chunks" in data: for chunk in data["chunks"]: self.rag_chunks.append( - RAGChunkData( + RAGChunk( content=chunk.get("content", ""), source=chunk.get("source"), score=chunk.get("score") @@ -113,7 +109,7 @@ def _extract_rag_chunks_from_response(self, response_content: str) -> None: for chunk in data: if isinstance(chunk, dict): self.rag_chunks.append( - RAGChunkData( + RAGChunk( content=chunk.get("content", str(chunk)), source=chunk.get("source"), score=chunk.get("score") @@ -123,7 +119,7 @@ def _extract_rag_chunks_from_response(self, response_content: str) -> None: # If not JSON, treat the entire response as a single chunk if response_content.strip(): self.rag_chunks.append( - RAGChunkData( + RAGChunk( content=response_content, source="knowledge_search", score=None @@ -133,7 +129,7 @@ def _extract_rag_chunks_from_response(self, response_content: str) -> None: # Treat response as single chunk if response_content.strip(): self.rag_chunks.append( - RAGChunkData( + RAGChunk( content=response_content, source="knowledge_search", score=None diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index d15b23d3f..06792bcd1 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -45,7 +45,8 @@ from models.requests import QueryRequest, Attachment from models.config import ModelContextProtocolServer, Action from authorization.resolvers import NoopRolesResolver -from utils.types import ToolCallSummary, TurnSummary, RAGChunkData +from utils.types 
import ToolCallSummary, TurnSummary
+from models.responses import RAGChunk
 
 MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token")
 
@@ -344,7 +345,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
             )
         ],
         rag_chunks=[
-            RAGChunkData(
+            RAGChunk(
                 content=" ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS),
                 source="knowledge_search",
                 score=None,

From ffd014f7394760f4aec4786e63b2df8e40585cab Mon Sep 17 00:00:00 2001
From: Anxhela Coba
Date: Thu, 18 Sep 2025 17:11:24 -0400
Subject: [PATCH 05/10] remove library dependencies

Signed-off-by: Anxhela Coba
---
 pyproject.toml | 20 ++------------------
 1 file changed, 2 insertions(+), 18 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0a6fa266d..eebf48660 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,27 +45,11 @@ dependencies = [
     "email-validator>=2.2.0",
     "openai==1.99.9",
     # Used by database interface
-    "sqlalchemy>=2.0.41",
+    "sqlalchemy>=2.0.42",
     # Used by Llama Stack version checker
     "semver<4.0.0",
     # Used by authorization resolvers
-    "jsonpath-ng>=1.6.1",
-    "opentelemetry-sdk>=1.34.0",
-    "opentelemetry-exporter-otlp>=1.34.0",
-    "opentelemetry-instrumentation>=0.55b0",
-    "aiosqlite>=0.21.0",
-    "litellm>=1.72.1",
-    "blobfile>=3.0.0",
-    "datasets>=3.6.0",
-    "faiss-cpu>=1.11.0",
-    "mcp>=1.9.4",
-    "autoevals>=0.0.129",
-    "psutil>=7.0.0",
-    "torch>=2.7.1",
-    "peft>=0.15.2",
-    "trl>=0.18.2",
-    "sentence-transformers>=5.1.0",
-    "greenlet>=3.2.4",
+    "jsonpath-ng>=1.6.1"
 ]

From 98f8141194c2f91a525cfc93cfca7a54db07efd4 Mon Sep 17 00:00:00 2001
From: Anxhela Coba
Date: Fri, 19 Sep 2025 16:11:17 -0400
Subject: [PATCH 06/10] variables and uv lock

Signed-off-by: Anxhela Coba
---
 run.yaml           |  2 +-
 src/utils/types.py |  4 ++--
 uv.lock            | 35 +----------------------------------
 3 files changed, 4 insertions(+), 37 deletions(-)

diff --git a/run.yaml b/run.yaml
index 98dc92d50..5d1d6765d 100644
--- a/run.yaml
+++ b/run.yaml
@@ -127,7 +127,7 @@ vector_dbs:
   - vector_db_id: my_knowledge_base
     embedding_model: sentence-transformers/all-mpnet-base-v2
     embedding_dimension: 768
-    provider_id: my_vector_db
+    provider_id: faiss
 models:
   - metadata:
       embedding_dimension: 768 # Depends on chosen model
diff --git a/src/utils/types.py b/src/utils/types.py
index 4aef06e8c..cecf96976 100644
--- a/src/utils/types.py
+++ b/src/utils/types.py
@@ -121,7 +121,7 @@ def _extract_rag_chunks_from_response(self, response_content: str) -> None:
                 self.rag_chunks.append(
                     RAGChunk(
                         content=response_content,
-                        source="knowledge_search",
+                        source=DEFAULT_RAG_TOOL,
                         score=None
                     )
                 )
@@ -131,7 +131,7 @@ def _extract_rag_chunks_from_response(self, response_content: str) -> None:
                 self.rag_chunks.append(
                     RAGChunk(
                         content=response_content,
-                        source="knowledge_search",
+                        source=DEFAULT_RAG_TOOL,
                         score=None
                     )
                 )
diff --git a/uv.lock b/uv.lock
index faaa40f3f..fdacc5ff4 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1280,37 +1280,20 @@ name = "lightspeed-stack"
 source = { editable = ".
} dependencies = [ { name = "aiohttp" }, - { name = "aiosqlite" }, { name = "authlib" }, - { name = "autoevals" }, - { name = "blobfile" }, { name = "cachetools" }, - { name = "datasets" }, { name = "email-validator" }, - { name = "faiss-cpu" }, { name = "fastapi" }, - { name = "greenlet" }, { name = "jsonpath-ng" }, { name = "kubernetes" }, - { name = "litellm" }, { name = "llama-stack" }, { name = "llama-stack-client" }, - { name = "mcp" }, { name = "openai" }, - { name = "opentelemetry-exporter-otlp" }, - { name = "opentelemetry-instrumentation" }, - { name = "opentelemetry-sdk" }, - { name = "peft" }, { name = "prometheus-client" }, - { name = "psutil" }, { name = "rich" }, { name = "semver" }, - { name = "sentence-transformers" }, { name = "sqlalchemy" }, { name = "starlette" }, - { name = "torch", version = "2.7.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, - { name = "torch", version = "2.7.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform != 'darwin'" }, - { name = "trl" }, { name = "uvicorn" }, ] @@ -1377,36 +1360,20 @@ llslibdev = [ [package.metadata] requires-dist = [ { name = "aiohttp", specifier = ">=3.12.14" }, - { name = "aiosqlite", specifier = ">=0.21.0" }, { name = "authlib", specifier = ">=1.6.0" }, - { name = "autoevals", specifier = ">=0.0.129" }, - { name = "blobfile", specifier = ">=3.0.0" }, { name = "cachetools", specifier = ">=6.1.0" }, - { name = "datasets", specifier = ">=3.6.0" }, { name = "email-validator", specifier = ">=2.2.0" }, - { name = "faiss-cpu", specifier = ">=1.11.0" }, { name = "fastapi", specifier = ">=0.115.12" }, - { name = "greenlet", specifier = ">=3.2.4" }, { name = "jsonpath-ng", specifier = ">=1.6.1" }, { name = "kubernetes", specifier = ">=30.1.0" }, - { name = "litellm", specifier = ">=1.72.1" }, { name = "llama-stack", specifier = "==0.2.19" }, { name = "llama-stack-client", specifier = "==0.2.19" }, - { name = "mcp", specifier = ">=1.9.4" }, { name = "openai", specifier = "==1.99.9" }, - { name = "opentelemetry-exporter-otlp", specifier = ">=1.34.0" }, - { name = "opentelemetry-instrumentation", specifier = ">=0.55b0" }, - { name = "opentelemetry-sdk", specifier = ">=1.34.0" }, - { name = "peft", specifier = ">=0.15.2" }, { name = "prometheus-client", specifier = ">=0.22.1" }, - { name = "psutil", specifier = ">=7.0.0" }, { name = "rich", specifier = ">=14.0.0" }, { name = "semver", specifier = "<4.0.0" }, - { name = "sentence-transformers", specifier = ">=5.1.0" }, - { name = "sqlalchemy", specifier = ">=2.0.41" }, + { name = "sqlalchemy", specifier = ">=2.0.42" }, { name = "starlette", specifier = ">=0.47.1" }, - { name = "torch", specifier = ">=2.7.1", index = "https://download.pytorch.org/whl/cpu" }, - { name = "trl", specifier = ">=0.18.2" }, { name = "uvicorn", specifier = ">=0.34.3" }, ] From 99080a8617a34e80ee0809d3d88365f12a320ecb Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Mon, 22 Sep 2025 09:55:57 -0400 Subject: [PATCH 07/10] lint Signed-off-by: Anxhela Coba --- src/app/endpoints/query.py | 23 ++++++++++---- src/models/responses.py | 20 ++++++++----- src/utils/types.py | 30 +++++++++---------- .../app/endpoints/test_streaming_query.py | 6 ++-- 4 files changed, 47 insertions(+), 32 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 905324ba9..7053ab7b1 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -30,7 +30,13 @@ from models.config import Action from 
models.database.conversations import UserConversation from models.requests import QueryRequest, Attachment -from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse, RAGChunk, ReferencedDocument, ToolCall +from models.responses import ( + QueryResponse, + UnauthorizedResponse, + ForbiddenResponse, + ReferencedDocument, + ToolCall, +) from utils.endpoints import ( check_configuration_loaded, get_agent, @@ -260,13 +266,18 @@ async def query_endpoint_handler( tool_calls = [ ToolCall( tool_name=tc.name, - arguments=tc.args if isinstance(tc.args, dict) else {"query": str(tc.args)}, - result={"response": tc.response} if tc.response and tc.name != constants.DEFAULT_RAG_TOOL else None + arguments=( + tc.args if isinstance(tc.args, dict) else {"query": str(tc.args)} + ), + result=( + {"response": tc.response} + if tc.response and tc.name != constants.DEFAULT_RAG_TOOL + else None + ), ) for tc in summary.tool_calls ] - logger.info("Extracting referenced documents...") referenced_docs = [] doc_sources = set() @@ -277,7 +288,9 @@ async def query_endpoint_handler( ReferencedDocument( url=chunk.source if chunk.source.startswith("http") else None, title=chunk.source, - chunk_count=sum(1 for c in summary.rag_chunks if c.source == chunk.source) + chunk_count=sum( + 1 for c in summary.rag_chunks if c.source == chunk.source + ), ) ) diff --git a/src/models/responses.py b/src/models/responses.py index ce63d84d6..bac6301b6 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -36,7 +36,7 @@ class ModelsResponse(BaseModel): class RAGChunk(BaseModel): """Model representing a RAG chunk used in the response.""" - + content: str = Field(description="The content of the chunk") source: Optional[str] = Field(None, description="Source document or URL") score: Optional[float] = Field(None, description="Relevance score") @@ -44,15 +44,17 @@ class RAGChunk(BaseModel): class ReferencedDocument(BaseModel): """Model representing a document referenced in the response.""" - + url: Optional[str] = Field(None, description="URL of the document") title: Optional[str] = Field(None, description="Title of the document") - chunk_count: Optional[int] = Field(None, description="Number of chunks from this document") + chunk_count: Optional[int] = Field( + None, description="Number of chunks from this document" + ) class ToolCall(BaseModel): """Model representing a tool call made during response generation.""" - + tool_name: str = Field(description="Name of the tool called") arguments: dict[str, Any] = Field(description="Arguments passed to the tool") result: Optional[dict[str, Any]] = Field(None, description="Result from the tool") @@ -110,21 +112,23 @@ class QueryResponse(BaseModel): { "content": "OLM is a component of the Operator Framework toolkit...", "source": "kubernetes-docs/operators.md", - "score": 0.95 + "score": 0.95, } ], "referenced_documents": [ { - "url": "https://kubernetes.io/docs/concepts/extend-kubernetes/operator/", + "url": ( + "https://kubernetes.io/docs/concepts/extend-kubernetes/operator/" + ), "title": "Operator Pattern", - "chunk_count": 2 + "chunk_count": 2, } ], "tool_calls": [ { "tool_name": "knowledge_search", "arguments": {"query": "operator lifecycle manager"}, - "result": {"chunks_found": 5} + "result": {"chunks_found": 5}, } ], } diff --git a/src/utils/types.py b/src/utils/types.py index cecf96976..b89ef8945 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -7,9 +7,9 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage 
from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from constants import DEFAULT_RAG_TOOL from pydantic import BaseModel -from models.responses import RAGChunk, Field +from models.responses import RAGChunk +from constants import DEFAULT_RAG_TOOL class Singleton(type): @@ -58,8 +58,6 @@ class ToolCallSummary(BaseModel): response: str | None - - class TurnSummary(BaseModel): """Summary of a turn in llama stack.""" @@ -73,8 +71,10 @@ def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: responses_by_id = {tc.call_id: tc for tc in tec.tool_responses} for call_id, tc in calls_by_id.items(): resp = responses_by_id.get(call_id) - response_content = interleaved_content_as_str(resp.content) if resp else None - + response_content = ( + interleaved_content_as_str(resp.content) if resp else None + ) + self.tool_calls.append( ToolCallSummary( id=call_id, @@ -83,11 +83,11 @@ def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: response=response_content, ) ) - + # Extract RAG chunks from knowledge_search tool responses if tc.tool_name == DEFAULT_RAG_TOOL and resp and response_content: self._extract_rag_chunks_from_response(response_content) - + def _extract_rag_chunks_from_response(self, response_content: str) -> None: """Extract RAG chunks from tool response content.""" try: @@ -101,7 +101,7 @@ def _extract_rag_chunks_from_response(self, response_content: str) -> None: RAGChunk( content=chunk.get("content", ""), source=chunk.get("source"), - score=chunk.get("score") + score=chunk.get("score"), ) ) elif isinstance(data, list): @@ -112,7 +112,7 @@ def _extract_rag_chunks_from_response(self, response_content: str) -> None: RAGChunk( content=chunk.get("content", str(chunk)), source=chunk.get("source"), - score=chunk.get("score") + score=chunk.get("score"), ) ) except json.JSONDecodeError: @@ -122,16 +122,14 @@ def _extract_rag_chunks_from_response(self, response_content: str) -> None: RAGChunk( content=response_content, source=DEFAULT_RAG_TOOL, - score=None + score=None, ) ) - except Exception: - # Treat response as single chunk + except (KeyError, AttributeError, TypeError, ValueError): + # Treat response as single chunk on data access/structure errors if response_content.strip(): self.rag_chunks.append( RAGChunk( - content=response_content, - source=DEFAULT_RAG_TOOL, - score=None + content=response_content, source=DEFAULT_RAG_TOOL, score=None ) ) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 06792bcd1..ae732ed8d 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -42,11 +42,11 @@ stream_build_event, ) -from models.requests import QueryRequest, Attachment -from models.config import ModelContextProtocolServer, Action from authorization.resolvers import NoopRolesResolver -from utils.types import ToolCallSummary, TurnSummary +from models.config import ModelContextProtocolServer, Action +from models.requests import QueryRequest, Attachment from models.responses import RAGChunk +from utils.types import ToolCallSummary, TurnSummary MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") From 0dd1f1d7597066cdba62f83b7c4ec21c3c3bca18 Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Mon, 22 Sep 2025 10:25:36 -0400 Subject: [PATCH 08/10] disable lint on query.py Signed-off-by: Anxhela Coba --- pyproject.toml | 1 + 1 file changed, 1 
insertion(+) diff --git a/pyproject.toml b/pyproject.toml index eebf48660..eeb5fd631 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,6 +176,7 @@ addopts = [ [tool.pylint.main] source-roots = "src" +ignore = ["query.py"] [build-system] requires = ["pdm-backend"] From 0871751409bcc97cd9df139a936832337b162f60 Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Tue, 23 Sep 2025 16:23:34 -0400 Subject: [PATCH 09/10] lint Signed-off-by: Anxhela Coba --- src/models/responses.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/models/responses.py b/src/models/responses.py index e19d8e10a..458c486bb 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -41,6 +41,7 @@ class RAGChunk(BaseModel): source: Optional[str] = Field(None, description="Source document or URL") score: Optional[float] = Field(None, description="Relevance score") + class ToolCall(BaseModel): """Model representing a tool call made during response generation.""" @@ -57,12 +58,13 @@ class ReferencedDocument(BaseModel): doc_title: Title of the referenced doc. """ - doc_url: AnyUrl = Field( + doc_url: Optional[AnyUrl] = Field( None, description="URL of the referenced document" ) doc_title: str = Field(description="Title of the referenced document") + class QueryResponse(BaseModel): """Model representing LLM response to a query. @@ -98,7 +100,7 @@ class QueryResponse(BaseModel): None, description="List of tool calls made during response generation", ) - + referenced_documents: list[ReferencedDocument] = Field( default_factory=list, description="List of documents referenced in generating the response", @@ -145,7 +147,6 @@ class QueryResponse(BaseModel): ] } } - class InfoResponse(BaseModel): From 930cd74073869c12930dd662d1cdd94eb963c89b Mon Sep 17 00:00:00 2001 From: Anxhela Coba Date: Wed, 24 Sep 2025 11:45:10 -0400 Subject: [PATCH 10/10] unit tests Signed-off-by: Anxhela Coba --- .../models/responses/test_query_response.py | 109 +++++++++++++++++- tests/unit/models/responses/test_rag_chunk.py | 98 ++++++++++++++++ 2 files changed, 206 insertions(+), 1 deletion(-) create mode 100644 tests/unit/models/responses/test_rag_chunk.py diff --git a/tests/unit/models/responses/test_query_response.py b/tests/unit/models/responses/test_query_response.py index 76b3136db..68333616b 100644 --- a/tests/unit/models/responses/test_query_response.py +++ b/tests/unit/models/responses/test_query_response.py @@ -1,6 +1,6 @@ """Unit tests for QueryResponse model.""" -from models.responses import QueryResponse +from models.responses import QueryResponse, RAGChunk, ToolCall, ReferencedDocument class TestQueryResponse: @@ -20,3 +20,110 @@ def test_optional_conversation_id(self) -> None: qr = QueryResponse(response="LLM answer") assert qr.conversation_id is None assert qr.response == "LLM answer" + + def test_rag_chunks_empty_by_default(self) -> None: + """Test that rag_chunks is empty by default.""" + qr = QueryResponse(response="LLM answer") + assert not qr.rag_chunks + + def test_rag_chunks_with_data(self) -> None: + """Test QueryResponse with RAG chunks.""" + rag_chunks = [ + RAGChunk( + content="Kubernetes is an open-source container orchestration system", + source="kubernetes-docs/overview.md", + score=0.95, + ), + RAGChunk( + content="Container orchestration automates deployment and management", + source="kubernetes-docs/concepts.md", + score=0.87, + ), + ] + + qr = QueryResponse( + conversation_id="123e4567-e89b-12d3-a456-426614174000", + response="LLM answer with RAG context", + rag_chunks=rag_chunks, 
+ ) + + assert len(qr.rag_chunks) == 2 + assert ( + qr.rag_chunks[0].content + == "Kubernetes is an open-source container orchestration system" + ) + assert qr.rag_chunks[0].source == "kubernetes-docs/overview.md" + assert qr.rag_chunks[0].score == 0.95 + assert ( + qr.rag_chunks[1].content + == "Container orchestration automates deployment and management" + ) + assert qr.rag_chunks[1].source == "kubernetes-docs/concepts.md" + assert qr.rag_chunks[1].score == 0.87 + + def test_rag_chunks_with_optional_fields(self) -> None: + """Test RAG chunks with optional source and score fields.""" + rag_chunks = [ + RAGChunk(content="Some content without source or score"), + RAGChunk(content="Content with source only", source="docs/guide.md"), + RAGChunk(content="Content with score only", score=0.75), + ] + + qr = QueryResponse(response="LLM answer", rag_chunks=rag_chunks) + + assert len(qr.rag_chunks) == 3 + assert qr.rag_chunks[0].source is None + assert qr.rag_chunks[0].score is None + assert qr.rag_chunks[1].source == "docs/guide.md" + assert qr.rag_chunks[1].score is None + assert qr.rag_chunks[2].source is None + assert qr.rag_chunks[2].score == 0.75 + + def test_complete_query_response_with_all_fields(self) -> None: + """Test QueryResponse with all fields including RAG chunks, tool calls, and docs.""" + rag_chunks = [ + RAGChunk( + content="OLM is a component of the Operator Framework toolkit", + source="kubernetes-docs/operators.md", + score=0.95, + ) + ] + + tool_calls = [ + ToolCall( + tool_name="knowledge_search", + arguments={"query": "operator lifecycle manager"}, + result={"chunks_found": 5}, + ) + ] + + referenced_documents = [ + ReferencedDocument( + doc_url=( + "https://docs.openshift.com/container-platform/4.15/operators/olm/index.html" + ), + doc_title="Operator Lifecycle Manager (OLM)", + ) + ] + + qr = QueryResponse( + conversation_id="123e4567-e89b-12d3-a456-426614174000", + response="Operator Lifecycle Manager (OLM) helps users install...", + rag_chunks=rag_chunks, + tool_calls=tool_calls, + referenced_documents=referenced_documents, + ) + + assert qr.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert qr.response == "Operator Lifecycle Manager (OLM) helps users install..." 
+ assert len(qr.rag_chunks) == 1 + assert ( + qr.rag_chunks[0].content + == "OLM is a component of the Operator Framework toolkit" + ) + assert len(qr.tool_calls) == 1 + assert qr.tool_calls[0].tool_name == "knowledge_search" + assert len(qr.referenced_documents) == 1 + assert ( + qr.referenced_documents[0].doc_title == "Operator Lifecycle Manager (OLM)" + ) diff --git a/tests/unit/models/responses/test_rag_chunk.py b/tests/unit/models/responses/test_rag_chunk.py new file mode 100644 index 000000000..bec534d37 --- /dev/null +++ b/tests/unit/models/responses/test_rag_chunk.py @@ -0,0 +1,98 @@ +"""Unit tests for RAGChunk model.""" + +from models.responses import RAGChunk + + +class TestRAGChunk: + """Test cases for the RAGChunk model.""" + + def test_constructor_with_content_only(self) -> None: + """Test RAGChunk constructor with content only.""" + chunk = RAGChunk(content="Sample content") + assert chunk.content == "Sample content" + assert chunk.source is None + assert chunk.score is None + + def test_constructor_with_all_fields(self) -> None: + """Test RAGChunk constructor with all fields.""" + chunk = RAGChunk( + content="Kubernetes is an open-source container orchestration system", + source="kubernetes-docs/overview.md", + score=0.95, + ) + assert ( + chunk.content + == "Kubernetes is an open-source container orchestration system" + ) + assert chunk.source == "kubernetes-docs/overview.md" + assert chunk.score == 0.95 + + def test_constructor_with_content_and_source(self) -> None: + """Test RAGChunk constructor with content and source.""" + chunk = RAGChunk( + content="Container orchestration automates deployment", + source="docs/concepts.md", + ) + assert chunk.content == "Container orchestration automates deployment" + assert chunk.source == "docs/concepts.md" + assert chunk.score is None + + def test_constructor_with_content_and_score(self) -> None: + """Test RAGChunk constructor with content and score.""" + chunk = RAGChunk(content="Pod is the smallest deployable unit", score=0.82) + assert chunk.content == "Pod is the smallest deployable unit" + assert chunk.source is None + assert chunk.score == 0.82 + + def test_score_range_validation(self) -> None: + """Test that RAGChunk accepts valid score ranges.""" + # Test minimum score + chunk_min = RAGChunk(content="Test content", score=0.0) + assert chunk_min.score == 0.0 + + # Test maximum score + chunk_max = RAGChunk(content="Test content", score=1.0) + assert chunk_max.score == 1.0 + + # Test decimal score + chunk_decimal = RAGChunk(content="Test content", score=0.751) + assert chunk_decimal.score == 0.751 + + def test_empty_content(self) -> None: + """Test RAGChunk with empty content.""" + chunk = RAGChunk(content="") + assert chunk.content == "" + assert chunk.source is None + assert chunk.score is None + + def test_multiline_content(self) -> None: + """Test RAGChunk with multiline content.""" + multiline_content = """This is a multiline content + that spans multiple lines + and contains various information.""" + + chunk = RAGChunk( + content=multiline_content, source="docs/multiline.md", score=0.88 + ) + assert chunk.content == multiline_content + assert chunk.source == "docs/multiline.md" + assert chunk.score == 0.88 + + def test_long_source_path(self) -> None: + """Test RAGChunk with long source path.""" + long_source = ( + "very/deep/nested/directory/structure/with/many/levels/document.md" + ) + chunk = RAGChunk( + content="Content from deeply nested document", source=long_source + ) + assert chunk.source == long_source + + def 
test_url_as_source(self) -> None: + """Test RAGChunk with URL as source.""" + url_source = "https://docs.example.com/api/v1/documentation" + chunk = RAGChunk( + content="API documentation content", source=url_source, score=0.92 + ) + assert chunk.source == url_source + assert chunk.score == 0.92
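

For reference, a minimal sketch of how a client might consume the new response
fields added by this series. The host, port, and endpoint path below are
assumptions for illustration; they are not specified anywhere in these patches.

    import requests  # assumed HTTP client; any client works

    # Assumed server address and query route; adjust to your deployment.
    resp = requests.post(
        "http://localhost:8080/v1/query",
        json={"query": "What is the Operator Lifecycle Manager?"},
        timeout=60,
    )
    resp.raise_for_status()
    body = resp.json()

    print(body["response"])

    # rag_chunks defaults to [] when no knowledge_search tool call was made.
    for chunk in body.get("rag_chunks", []):
        print(f"{chunk.get('source')} (score={chunk.get('score')}): {chunk['content'][:80]}")

    # referenced_documents expose doc_url/doc_title as of PATCH 09.
    for doc in body.get("referenced_documents", []):
        print(f"see: {doc.get('doc_title')} -> {doc.get('doc_url')}")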