diff --git a/pyproject.toml b/pyproject.toml
index 66d196226..fdd526b09 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,7 +49,7 @@ dependencies = [
     # Used by Llama Stack version checker
     "semver<4.0.0",
     # Used by authorization resolvers
-    "jsonpath-ng>=1.6.1",
+    "jsonpath-ng>=1.6.1"
 ]
@@ -176,6 +176,7 @@ addopts = [
 
 [tool.pylint.main]
 source-roots = "src"
+ignore = ["query.py"]
 
 [build-system]
 requires = ["pdm-backend"]
diff --git a/run.yaml b/run.yaml
index af519cfa8..5d1d6765d 100644
--- a/run.yaml
+++ b/run.yaml
@@ -60,6 +60,9 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
   inference:
+  - provider_id: sentence-transformers # Can be any embedding provider
+    provider_type: inline::sentence-transformers
+    config: {}
   - provider_id: openai
     provider_type: remote::openai
     config:
@@ -99,14 +102,17 @@ providers:
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
     config: {}
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+    config: {}
   vector_io:
   - config:
       kvstore:
-        db_path: .llama/distributions/ollama/faiss_store.db
+        db_path: .llama/distributions/ollama/faiss_store.db # Location of vector database
         namespace: null
         type: sqlite
     provider_id: faiss
-    provider_type: inline::faiss
+    provider_type: inline::faiss # Or preferred vector DB
 scoring_fns: []
 server:
   auth: null
@@ -117,10 +123,23 @@ server:
   tls_certfile: null
   tls_keyfile: null
 shields: []
-vector_dbs: []
-
+vector_dbs:
+- vector_db_id: my_knowledge_base
+  embedding_model: sentence-transformers/all-mpnet-base-v2
+  embedding_dimension: 768
+  provider_id: faiss
 models:
+- metadata:
+    embedding_dimension: 768 # Depends on chosen model
+  model_id: sentence-transformers/all-mpnet-base-v2 # Example embedding model
+  provider_id: sentence-transformers
+  provider_model_id: sentence-transformers/all-mpnet-base-v2 # Model ID as known by the provider
+  model_type: embedding
 - model_id: gpt-4-turbo
   provider_id: openai
   model_type: llm
   provider_model_id: gpt-4-turbo
+
+tool_groups:
+- toolgroup_id: builtin::rag
+  provider_id: rag-runtime
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index d6e2d3541..04f0f63bc 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -8,6 +8,7 @@
 from typing import Annotated, Any, Optional, cast
 
 from fastapi import APIRouter, Depends, HTTPException, Request, status
+from pydantic import AnyUrl
 from llama_stack_client import (
     APIConnectionError,
     AsyncLlamaStackClient,  # type: ignore
@@ -39,6 +40,7 @@
     ForbiddenResponse,
     QueryResponse,
     ReferencedDocument,
+    ToolCall,
     UnauthorizedResponse,
 )
 from utils.endpoints import (
@@ -248,6 +250,10 @@ async def query_endpoint_handler(
         # Update metrics for the LLM call
         metrics.llm_calls_total.labels(provider_id, model_id).inc()
 
+        # Convert RAG chunks to dictionary format once for reuse
+        logger.info("Processing RAG chunks...")
+        rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]
+
         if not is_transcripts_enabled():
             logger.debug("Transcript collection is disabled in the configuration")
         else:
@@ -260,11 +266,12 @@ async def query_endpoint_handler(
                 query=query_request.query,
                 query_request=query_request,
                 summary=summary,
-                rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                rag_chunks=rag_chunks_dict,
                 truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
                 attachments=query_request.attachments or [],
             )
 
+        logger.info("Persisting conversation details...")
         persist_user_conversation_details(
             user_id=user_id,
             conversation_id=conversation_id,
@@ -272,11 +279,50 @@ async def query_endpoint_handler(
             provider_id=provider_id,
         )
 
-        return QueryResponse(
+        # Convert tool calls to response format
+        logger.info("Processing tool calls...")
+        tool_calls = [
+            ToolCall(
+                tool_name=tc.name,
+                arguments=(
+                    tc.args if isinstance(tc.args, dict) else {"query": str(tc.args)}
+                ),
+                result=(
+                    {"response": tc.response}
+                    if tc.response and tc.name != constants.DEFAULT_RAG_TOOL
+                    else None
+                ),
+            )
+            for tc in summary.tool_calls
+        ]
+
+        logger.info("Extracting referenced documents...")
+        referenced_documents = []
+        doc_sources = set()
+        for chunk in summary.rag_chunks:
+            if chunk.source and chunk.source not in doc_sources:
+                doc_sources.add(chunk.source)
+                referenced_documents.append(
+                    ReferencedDocument(
+                        doc_url=(
+                            AnyUrl(chunk.source)
+                            if chunk.source.startswith("http")
+                            else None
+                        ),
+                        doc_title=chunk.source,
+                    )
+                )
+
+        logger.info("Building final response...")
+        response = QueryResponse(
             conversation_id=conversation_id,
             response=summary.llm_response,
+            rag_chunks=summary.rag_chunks,
+            tool_calls=tool_calls if tool_calls else None,
             referenced_documents=referenced_documents,
         )
+        logger.info("Query processing completed successfully!")
+        return response
 
     # connection to Llama Stack server
     except APIConnectionError as e:
diff --git a/src/constants.py b/src/constants.py
index 8521b97c6..8369b9369 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -52,6 +52,8 @@
 DEFAULT_JWT_UID_CLAIM = "user_id"
 DEFAULT_JWT_USER_NAME_CLAIM = "username"
 
+# Default RAG tool name
+DEFAULT_RAG_TOOL = "knowledge_search"
 
 # PostgreSQL connection constants
 # See: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLMODE
diff --git a/src/models/responses.py b/src/models/responses.py
index 1f99bb31f..458c486bb 100644
--- a/src/models/responses.py
+++ b/src/models/responses.py
@@ -34,6 +34,22 @@ class ModelsResponse(BaseModel):
     )
 
 
+class RAGChunk(BaseModel):
+    """Model representing a RAG chunk used in the response."""
+
+    content: str = Field(description="The content of the chunk")
+    source: Optional[str] = Field(None, description="Source document or URL")
+    score: Optional[float] = Field(None, description="Relevance score")
+
+
+class ToolCall(BaseModel):
+    """Model representing a tool call made during response generation."""
+
+    tool_name: str = Field(description="Name of the tool called")
+    arguments: dict[str, Any] = Field(description="Arguments passed to the tool")
+    result: Optional[dict[str, Any]] = Field(None, description="Result from the tool")
+
+
 class ReferencedDocument(BaseModel):
     """Model representing a document referenced in generating a response.
@@ -42,27 +58,27 @@ class ReferencedDocument(BaseModel):
 
     Attributes:
         doc_url: URL of the referenced doc.
         doc_title: Title of the referenced doc.
     """
 
-    doc_url: AnyUrl = Field(description="URL of the referenced document")
+    doc_url: Optional[AnyUrl] = Field(
+        None, description="URL of the referenced document"
+    )
     doc_title: str = Field(description="Title of the referenced document")
 
 
-# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
-# we are keeping it simple. The missing fields are:
-# - truncated: Set to True if conversation history was truncated to be within context window.
-# - input_tokens: Number of tokens sent to LLM
-# - output_tokens: Number of tokens received from LLM
-# - available_quotas: Quota available as measured by all configured quota limiters
-# - tool_calls: List of tool requests.
-# - tool_results: List of tool results.
-# See LLMResponse in ols-service for more details.
 class QueryResponse(BaseModel):
     """Model representing LLM response to a query.
 
     Attributes:
         conversation_id: The optional conversation ID (UUID).
         response: The response.
+        rag_chunks: List of RAG chunks used to generate the response.
         referenced_documents: The URLs and titles for the documents used to
             generate the response.
+        tool_calls: List of tool calls made during response generation.
+        TODO: truncated: Whether conversation history was truncated.
+        TODO: input_tokens: Number of tokens sent to LLM.
+        TODO: output_tokens: Number of tokens received from LLM.
+        TODO: available_quotas: Quota available as measured by all configured quota limiters.
+        TODO: tool_results: List of tool results.
     """
 
     conversation_id: Optional[str] = Field(
@@ -78,6 +94,13 @@ class QueryResponse(BaseModel):
         ],
     )
 
+    rag_chunks: list[RAGChunk] = []
+
+    tool_calls: Optional[list[ToolCall]] = Field(
+        None,
+        description="List of tool calls made during response generation",
+    )
+
     referenced_documents: list[ReferencedDocument] = Field(
         default_factory=list,
         description="List of documents referenced in generating the response",
@@ -99,6 +122,20 @@ class QueryResponse(BaseModel):
                 {
                     "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
                     "response": "Operator Lifecycle Manager (OLM) helps users install...",
+                    "rag_chunks": [
+                        {
+                            "content": "OLM is a component of the Operator Framework toolkit...",
+                            "source": "kubernetes-docs/operators.md",
+                            "score": 0.95,
+                        }
+                    ],
+                    "tool_calls": [
+                        {
+                            "tool_name": "knowledge_search",
+                            "arguments": {"query": "operator lifecycle manager"},
+                            "result": {"chunks_found": 5},
+                        }
+                    ],
                     "referenced_documents": [
                         {
                             "doc_url": "https://docs.openshift.com/"
diff --git a/src/utils/transcripts.py b/src/utils/transcripts.py
index e29d4319d..74b20fdbd 100644
--- a/src/utils/transcripts.py
+++ b/src/utils/transcripts.py
@@ -39,7 +39,7 @@ def store_transcript(  # pylint: disable=too-many-arguments,too-many-positional-arguments
     query: str,
     query_request: QueryRequest,
     summary: TurnSummary,
-    rag_chunks: list[str],
+    rag_chunks: list[dict],
     truncated: bool,
     attachments: list[Attachment],
 ) -> None:
@@ -52,7 +52,7 @@ def store_transcript(  # pylint: disable=too-many-arguments,too-many-positional-arguments
         query: The query (without attachments).
         query_request: The request containing a query.
         summary: Summary of the query/response turn.
-        rag_chunks: The list of `RagChunk` objects.
+        rag_chunks: The list of serialized `RAGChunk` dictionaries.
         truncated: The flag indicating if the history was truncated.
         attachments: The list of `Attachment` objects.
""" diff --git a/src/utils/types.py b/src/utils/types.py index 5770139ae..b89ef8945 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -1,13 +1,15 @@ """Common types for the project.""" from typing import Any, Optional - +import json from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str from llama_stack_client.lib.agents.tool_parser import ToolParser from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from pydantic.main import BaseModel +from pydantic import BaseModel +from models.responses import RAGChunk +from constants import DEFAULT_RAG_TOOL class Singleton(type): @@ -61,6 +63,7 @@ class TurnSummary(BaseModel): llm_response: str tool_calls: list[ToolCallSummary] + rag_chunks: list[RAGChunk] = [] def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: """Append the tool calls from a llama tool execution step.""" @@ -68,11 +71,65 @@ def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: responses_by_id = {tc.call_id: tc for tc in tec.tool_responses} for call_id, tc in calls_by_id.items(): resp = responses_by_id.get(call_id) + response_content = ( + interleaved_content_as_str(resp.content) if resp else None + ) + self.tool_calls.append( ToolCallSummary( id=call_id, name=tc.tool_name, args=tc.arguments, - response=interleaved_content_as_str(resp.content) if resp else None, + response=response_content, ) ) + + # Extract RAG chunks from knowledge_search tool responses + if tc.tool_name == DEFAULT_RAG_TOOL and resp and response_content: + self._extract_rag_chunks_from_response(response_content) + + def _extract_rag_chunks_from_response(self, response_content: str) -> None: + """Extract RAG chunks from tool response content.""" + try: + # Parse the response to get chunks + # Try JSON first + try: + data = json.loads(response_content) + if isinstance(data, dict) and "chunks" in data: + for chunk in data["chunks"]: + self.rag_chunks.append( + RAGChunk( + content=chunk.get("content", ""), + source=chunk.get("source"), + score=chunk.get("score"), + ) + ) + elif isinstance(data, list): + # Handle list of chunks + for chunk in data: + if isinstance(chunk, dict): + self.rag_chunks.append( + RAGChunk( + content=chunk.get("content", str(chunk)), + source=chunk.get("source"), + score=chunk.get("score"), + ) + ) + except json.JSONDecodeError: + # If not JSON, treat the entire response as a single chunk + if response_content.strip(): + self.rag_chunks.append( + RAGChunk( + content=response_content, + source=DEFAULT_RAG_TOOL, + score=None, + ) + ) + except (KeyError, AttributeError, TypeError, ValueError): + # Treat response as single chunk on data access/structure errors + if response_content.strip(): + self.rag_chunks.append( + RAGChunk( + content=response_content, source=DEFAULT_RAG_TOOL, score=None + ) + ) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 38983666a..ae732ed8d 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -42,9 +42,10 @@ stream_build_event, ) -from models.requests import QueryRequest, Attachment -from models.config import ModelContextProtocolServer, Action from authorization.resolvers import NoopRolesResolver +from models.config import ModelContextProtocolServer, Action +from models.requests import QueryRequest, 
+from models.responses import RAGChunk
 from utils.types import ToolCallSummary, TurnSummary
 
 MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token")
@@ -343,6 +344,13 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False):
                     response=" ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS),
                 )
             ],
+            rag_chunks=[
+                RAGChunk(
+                    content=" ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS),
+                    source="knowledge_search",
+                    score=None,
+                )
+            ],
         ),
         attachments=[],
         rag_chunks=[],
diff --git a/tests/unit/models/responses/test_query_response.py b/tests/unit/models/responses/test_query_response.py
index 76b3136db..68333616b 100644
--- a/tests/unit/models/responses/test_query_response.py
+++ b/tests/unit/models/responses/test_query_response.py
@@ -1,6 +1,6 @@
 """Unit tests for QueryResponse model."""
 
-from models.responses import QueryResponse
+from models.responses import QueryResponse, RAGChunk, ToolCall, ReferencedDocument
 
 
 class TestQueryResponse:
@@ -20,3 +20,110 @@ def test_optional_conversation_id(self) -> None:
         qr = QueryResponse(response="LLM answer")
         assert qr.conversation_id is None
         assert qr.response == "LLM answer"
+
+    def test_rag_chunks_empty_by_default(self) -> None:
+        """Test that rag_chunks is empty by default."""
+        qr = QueryResponse(response="LLM answer")
+        assert not qr.rag_chunks
+
+    def test_rag_chunks_with_data(self) -> None:
+        """Test QueryResponse with RAG chunks."""
+        rag_chunks = [
+            RAGChunk(
+                content="Kubernetes is an open-source container orchestration system",
+                source="kubernetes-docs/overview.md",
+                score=0.95,
+            ),
+            RAGChunk(
+                content="Container orchestration automates deployment and management",
+                source="kubernetes-docs/concepts.md",
+                score=0.87,
+            ),
+        ]
+
+        qr = QueryResponse(
+            conversation_id="123e4567-e89b-12d3-a456-426614174000",
+            response="LLM answer with RAG context",
+            rag_chunks=rag_chunks,
+        )
+
+        assert len(qr.rag_chunks) == 2
+        assert (
+            qr.rag_chunks[0].content
+            == "Kubernetes is an open-source container orchestration system"
+        )
+        assert qr.rag_chunks[0].source == "kubernetes-docs/overview.md"
+        assert qr.rag_chunks[0].score == 0.95
+        assert (
+            qr.rag_chunks[1].content
+            == "Container orchestration automates deployment and management"
+        )
+        assert qr.rag_chunks[1].source == "kubernetes-docs/concepts.md"
+        assert qr.rag_chunks[1].score == 0.87
+
+    def test_rag_chunks_with_optional_fields(self) -> None:
+        """Test RAG chunks with optional source and score fields."""
+        rag_chunks = [
+            RAGChunk(content="Some content without source or score"),
+            RAGChunk(content="Content with source only", source="docs/guide.md"),
+            RAGChunk(content="Content with score only", score=0.75),
+        ]
+
+        qr = QueryResponse(response="LLM answer", rag_chunks=rag_chunks)
+
+        assert len(qr.rag_chunks) == 3
+        assert qr.rag_chunks[0].source is None
+        assert qr.rag_chunks[0].score is None
+        assert qr.rag_chunks[1].source == "docs/guide.md"
+        assert qr.rag_chunks[1].score is None
+        assert qr.rag_chunks[2].source is None
+        assert qr.rag_chunks[2].score == 0.75
+
+    def test_complete_query_response_with_all_fields(self) -> None:
+        """Test QueryResponse with all fields including RAG chunks, tool calls, and docs."""
+        rag_chunks = [
+            RAGChunk(
+                content="OLM is a component of the Operator Framework toolkit",
+                source="kubernetes-docs/operators.md",
+                score=0.95,
+            )
+        ]
+
+        tool_calls = [
+            ToolCall(
+                tool_name="knowledge_search",
+                arguments={"query": "operator lifecycle manager"},
+                result={"chunks_found": 5},
+            )
+        ]
+
+        referenced_documents = [
+            ReferencedDocument(
+                doc_url=(
+                    "https://docs.openshift.com/container-platform/4.15/operators/olm/index.html"
+                ),
+                doc_title="Operator Lifecycle Manager (OLM)",
+            )
+        ]
+
+        qr = QueryResponse(
+            conversation_id="123e4567-e89b-12d3-a456-426614174000",
+            response="Operator Lifecycle Manager (OLM) helps users install...",
+            rag_chunks=rag_chunks,
+            tool_calls=tool_calls,
+            referenced_documents=referenced_documents,
+        )
+
+        assert qr.conversation_id == "123e4567-e89b-12d3-a456-426614174000"
+        assert qr.response == "Operator Lifecycle Manager (OLM) helps users install..."
+        assert len(qr.rag_chunks) == 1
+        assert (
+            qr.rag_chunks[0].content
+            == "OLM is a component of the Operator Framework toolkit"
+        )
+        assert len(qr.tool_calls) == 1
+        assert qr.tool_calls[0].tool_name == "knowledge_search"
+        assert len(qr.referenced_documents) == 1
+        assert (
+            qr.referenced_documents[0].doc_title == "Operator Lifecycle Manager (OLM)"
+        )
diff --git a/tests/unit/models/responses/test_rag_chunk.py b/tests/unit/models/responses/test_rag_chunk.py
new file mode 100644
index 000000000..bec534d37
--- /dev/null
+++ b/tests/unit/models/responses/test_rag_chunk.py
@@ -0,0 +1,98 @@
+"""Unit tests for RAGChunk model."""
+
+from models.responses import RAGChunk
+
+
+class TestRAGChunk:
+    """Test cases for the RAGChunk model."""
+
+    def test_constructor_with_content_only(self) -> None:
+        """Test RAGChunk constructor with content only."""
+        chunk = RAGChunk(content="Sample content")
+        assert chunk.content == "Sample content"
+        assert chunk.source is None
+        assert chunk.score is None
+
+    def test_constructor_with_all_fields(self) -> None:
+        """Test RAGChunk constructor with all fields."""
+        chunk = RAGChunk(
+            content="Kubernetes is an open-source container orchestration system",
+            source="kubernetes-docs/overview.md",
+            score=0.95,
+        )
+        assert (
+            chunk.content
+            == "Kubernetes is an open-source container orchestration system"
+        )
+        assert chunk.source == "kubernetes-docs/overview.md"
+        assert chunk.score == 0.95
+
+    def test_constructor_with_content_and_source(self) -> None:
+        """Test RAGChunk constructor with content and source."""
+        chunk = RAGChunk(
+            content="Container orchestration automates deployment",
+            source="docs/concepts.md",
+        )
+        assert chunk.content == "Container orchestration automates deployment"
+        assert chunk.source == "docs/concepts.md"
+        assert chunk.score is None
+
+    def test_constructor_with_content_and_score(self) -> None:
+        """Test RAGChunk constructor with content and score."""
+        chunk = RAGChunk(content="Pod is the smallest deployable unit", score=0.82)
+        assert chunk.content == "Pod is the smallest deployable unit"
+        assert chunk.source is None
+        assert chunk.score == 0.82
+
+    def test_score_range_validation(self) -> None:
+        """Test that RAGChunk accepts valid score ranges."""
+        # Test minimum score
+        chunk_min = RAGChunk(content="Test content", score=0.0)
+        assert chunk_min.score == 0.0
+
+        # Test maximum score
+        chunk_max = RAGChunk(content="Test content", score=1.0)
+        assert chunk_max.score == 1.0
+
+        # Test decimal score
+        chunk_decimal = RAGChunk(content="Test content", score=0.751)
+        assert chunk_decimal.score == 0.751
+
+    def test_empty_content(self) -> None:
+        """Test RAGChunk with empty content."""
+        chunk = RAGChunk(content="")
+        assert chunk.content == ""
+        assert chunk.source is None
+        assert chunk.score is None
+
+    def test_multiline_content(self) -> None:
+        """Test RAGChunk with multiline content."""
+        multiline_content = """This is a multiline content
+        that spans multiple lines
+        and contains various information."""
+
+        chunk = RAGChunk(
+            content=multiline_content, source="docs/multiline.md", score=0.88
+        )
+        assert chunk.content == multiline_content
+        assert chunk.source == "docs/multiline.md"
+        assert chunk.score == 0.88
+
+    def test_long_source_path(self) -> None:
+        """Test RAGChunk with long source path."""
+        long_source = (
+            "very/deep/nested/directory/structure/with/many/levels/document.md"
+        )
+        chunk = RAGChunk(
+            content="Content from deeply nested document", source=long_source
+        )
+        assert chunk.source == long_source
+
+    def test_url_as_source(self) -> None:
+        """Test RAGChunk with URL as source."""
+        url_source = "https://docs.example.com/api/v1/documentation"
+        chunk = RAGChunk(
+            content="API documentation content", source=url_source, score=0.92
+        )
+        assert chunk.source == url_source
+        assert chunk.score == 0.92
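
Illustrative usage, not part of the patch: a minimal sketch of the RAG-chunk extraction path this diff adds to TurnSummary in src/utils/types.py. It assumes the module layout shown above; the payload shape is one of the branches the extractor explicitly handles.

# Sketch only: feed a knowledge_search-style JSON payload through the new
# extraction helper and read back the parsed RAGChunk.
import json

from utils.types import TurnSummary

summary = TurnSummary(llm_response="answer", tool_calls=[])
payload = json.dumps(
    {"chunks": [{"content": "OLM manages operators", "source": "docs/olm.md", "score": 0.9}]}
)
# Private helper; called directly here purely for illustration.
summary._extract_rag_chunks_from_response(payload)  # pylint: disable=protected-access
assert summary.rag_chunks[0].source == "docs/olm.md"
assert summary.rag_chunks[0].score == 0.9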