From 088cc5260acf34d753da49c89cc3afe805c6e5ee Mon Sep 17 00:00:00 2001 From: Maysun J Faisal Date: Wed, 8 Oct 2025 17:53:52 -0400 Subject: [PATCH 1/4] First draft of referenced_documents caching Signed-off-by: Maysun J Faisal --- src/app/endpoints/conversations_v2.py | 18 +++++++++--- src/app/endpoints/query.py | 25 +++++++++++++---- src/app/endpoints/streaming_query.py | 29 ++++++++++++++----- src/cache/cache.py | 3 +- src/cache/in_memory_cache.py | 3 +- src/cache/noop_cache.py | 3 +- src/cache/postgres_cache.py | 39 +++++++++++++++++--------- src/cache/sqlite_cache.py | 40 ++++++++++++++++++--------- src/models/cache_entry.py | 24 ++++++---------- src/models/responses.py | 16 ++++++++++- src/utils/endpoints.py | 15 +--------- 11 files changed, 139 insertions(+), 76 deletions(-) diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py index 7ce0ac7ca..d311ce54d 100644 --- a/src/app/endpoints/conversations_v2.py +++ b/src/app/endpoints/conversations_v2.py @@ -314,13 +314,23 @@ def check_conversation_existence(user_id: str, conversation_id: str) -> None: def transform_chat_message(entry: CacheEntry) -> dict[str, Any]: """Transform the message read from cache into format used by response payload.""" + user_message = { + "content": entry.query, + "type": "user" + } + assistant_message: dict[str, Any] = { + "content": entry.response, + "type": "assistant" + } + + # Check for additional_kwargs and add it to the assistant message if it exists + if entry.additional_kwargs: + assistant_message["additional_kwargs"] = entry.additional_kwargs.model_dump() + return { "provider": entry.provider, "model": entry.model, - "messages": [ - {"content": entry.query, "type": "user"}, - {"content": entry.response, "type": "assistant"}, - ], + "messages": [user_message, assistant_message], "started_at": entry.started_at, "completed_at": entry.completed_at, } diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 98be63c42..d92db9fc3 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -22,6 +22,7 @@ from llama_stack_client.types.model_list_response import ModelListResponse from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from llama_stack_client.types.tool_execution_step import ToolExecutionStep +from pydantic import AnyUrl import constants import metrics @@ -31,6 +32,7 @@ from authorization.middleware import authorize from client import AsyncLlamaStackClientHolder from configuration import configuration +from models.cache_entry import CacheEntry, AdditionalKwargs from models.config import Action from models.database.conversations import UserConversation from models.requests import Attachment, QueryRequest @@ -331,16 +333,27 @@ async def query_endpoint_handler( # pylint: disable=R0914 ) completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") + + additional_kwargs_obj = None + if referenced_documents: + additional_kwargs_obj = AdditionalKwargs( + referenced_documents=referenced_documents + ) + cache_entry = CacheEntry( + query=query_request.query, + response=summary.llm_response, + provider=provider_id, + model=model_id, + started_at=started_at, + completed_at=completed_at, + additional_kwargs=additional_kwargs_obj + ) + store_conversation_into_cache( configuration, user_id, conversation_id, - provider_id, - model_id, - query_request.query, - summary.llm_response, - started_at, - completed_at, + cache_entry, _skip_userid_check, topic_summary, ) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index d903469a7..bc1dd2c58 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -21,6 +21,7 @@ ) from llama_stack_client.types.shared import ToolCall from llama_stack_client.types.shared.interleaved_content_item import TextContentItem +from pydantic import AnyUrl from app.database import get_session from app.endpoints.query import ( @@ -43,10 +44,11 @@ from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT import metrics from metrics.utils import update_llm_token_count_from_turn +from models.cache_entry import CacheEntry, AdditionalKwargs from models.config import Action from models.database.conversations import UserConversation from models.requests import QueryRequest -from models.responses import ForbiddenResponse, UnauthorizedResponse +from models.responses import ForbiddenResponse, UnauthorizedResponse, ReferencedDocument from utils.endpoints import ( check_configuration_loaded, create_rag_chunks_dict, @@ -863,16 +865,29 @@ async def response_generator( ) completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") + + referenced_documents = create_referenced_documents_with_metadata(summary, metadata_map) + + additional_kwargs_obj = None + if referenced_documents: + additional_kwargs_obj = AdditionalKwargs( + referenced_documents=referenced_documents + ) + cache_entry = CacheEntry( + query=query_request.query, + response=summary.llm_response, + provider=provider_id, + model=model_id, + started_at=started_at, + completed_at=completed_at, + additional_kwargs=additional_kwargs_obj + ) + store_conversation_into_cache( configuration, user_id, conversation_id, - provider_id, - model_id, - query_request.query, - summary.llm_response, - started_at, - completed_at, + cache_entry, _skip_userid_check, topic_summary, ) diff --git a/src/cache/cache.py b/src/cache/cache.py index 4cdab6307..8a6e6fc07 100644 --- a/src/cache/cache.py +++ b/src/cache/cache.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod -from models.cache_entry import CacheEntry, ConversationData +from models.cache_entry import CacheEntry +from models.responses import ConversationData from utils.suid import check_suid diff --git a/src/cache/in_memory_cache.py b/src/cache/in_memory_cache.py index 1b6b4123f..388585b5a 100644 --- a/src/cache/in_memory_cache.py +++ b/src/cache/in_memory_cache.py @@ -1,8 +1,9 @@ """In-memory cache implementation.""" from cache.cache import Cache -from models.cache_entry import CacheEntry, ConversationData +from models.cache_entry import CacheEntry from models.config import InMemoryCacheConfig +from models.responses import ConversationData from log import get_logger from utils.connection_decorator import connection diff --git a/src/cache/noop_cache.py b/src/cache/noop_cache.py index fcd20f368..40b1bd144 100644 --- a/src/cache/noop_cache.py +++ b/src/cache/noop_cache.py @@ -1,7 +1,8 @@ """No-operation cache implementation.""" from cache.cache import Cache -from models.cache_entry import CacheEntry, ConversationData +from models.cache_entry import CacheEntry +from models.responses import ConversationData from log import get_logger from utils.connection_decorator import connection diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index 0881b66c6..e25de3522 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -4,8 +4,9 @@ from cache.cache import Cache from cache.cache_error import CacheError -from models.cache_entry import CacheEntry, ConversationData +from models.cache_entry import CacheEntry, AdditionalKwargs from models.config import PostgreSQLDatabaseConfiguration +from models.responses import ConversationData from log import get_logger from utils.connection_decorator import connection @@ -37,15 +38,16 @@ class PostgresCache(Cache): CREATE_CACHE_TABLE = """ CREATE TABLE IF NOT EXISTS cache ( - user_id text NOT NULL, - conversation_id text NOT NULL, - created_at timestamp NOT NULL, - started_at text, - completed_at text, - query text, - response text, - provider text, - model text, + user_id text NOT NULL, + conversation_id text NOT NULL, + created_at timestamp NOT NULL, + started_at text, + completed_at text, + query text, + response text, + provider text, + model text, + additional_kwargs jsonb, PRIMARY KEY(user_id, conversation_id, created_at) ); """ @@ -66,7 +68,7 @@ class PostgresCache(Cache): """ SELECT_CONVERSATION_HISTORY_STATEMENT = """ - SELECT query, response, provider, model, started_at, completed_at + SELECT query, response, provider, model, started_at, completed_at, additional_kwargs FROM cache WHERE user_id=%s AND conversation_id=%s ORDER BY created_at @@ -74,8 +76,8 @@ class PostgresCache(Cache): INSERT_CONVERSATION_HISTORY_STATEMENT = """ INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at, - query, response, provider, model) - VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s) + query, response, provider, model, additional_kwargs) + VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s) """ QUERY_CACHE_SIZE = """ @@ -211,6 +213,11 @@ def get( result = [] for conversation_entry in conversation_entries: + # Parse it back into an LLMResponse object + additional_kwargs_data = conversation_entry[6] + additional_kwargs_obj = None + if additional_kwargs_data: + additional_kwargs_obj = AdditionalKwargs.model_validate(additional_kwargs_data) cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -218,6 +225,7 @@ def get( model=conversation_entry[3], started_at=conversation_entry[4], completed_at=conversation_entry[5], + additional_kwargs=additional_kwargs_obj, ) result.append(cache_entry) @@ -245,6 +253,10 @@ def insert_or_append( raise CacheError("insert_or_append: cache is disconnected") try: + additional_kwargs_json = None + if cache_entry.additional_kwargs: + # Use exclude_none=True to keep JSON clean + additional_kwargs_json = cache_entry.additional_kwargs.model_dump_json(exclude_none=True) # the whole operation is run in one transaction with self.connection.cursor() as cursor: cursor.execute( @@ -258,6 +270,7 @@ def insert_or_append( cache_entry.response, cache_entry.provider, cache_entry.model, + additional_kwargs_json, ), ) diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index b8a91fac9..40af1eff5 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -3,11 +3,13 @@ from time import time import sqlite3 +import json from cache.cache import Cache from cache.cache_error import CacheError -from models.cache_entry import CacheEntry, ConversationData +from models.cache_entry import CacheEntry, AdditionalKwargs from models.config import SQLiteDatabaseConfiguration +from models.responses import ConversationData from log import get_logger from utils.connection_decorator import connection @@ -41,15 +43,16 @@ class SQLiteCache(Cache): CREATE_CACHE_TABLE = """ CREATE TABLE IF NOT EXISTS cache ( - user_id text NOT NULL, - conversation_id text NOT NULL, - created_at int NOT NULL, - started_at text, - completed_at text, - query text, - response text, - provider text, - model text, + user_id text NOT NULL, + conversation_id text NOT NULL, + created_at int NOT NULL, + started_at text, + completed_at text, + query text, + response text, + provider text, + model text, + additional_kwargs text, PRIMARY KEY(user_id, conversation_id, created_at) ); """ @@ -70,7 +73,7 @@ class SQLiteCache(Cache): """ SELECT_CONVERSATION_HISTORY_STATEMENT = """ - SELECT query, response, provider, model, started_at, completed_at + SELECT query, response, provider, model, started_at, completed_at, additional_kwargs FROM cache WHERE user_id=? AND conversation_id=? ORDER BY created_at @@ -78,8 +81,8 @@ class SQLiteCache(Cache): INSERT_CONVERSATION_HISTORY_STATEMENT = """ INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at, - query, response, provider, model) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + query, response, provider, model, additional_kwargs) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """ QUERY_CACHE_SIZE = """ @@ -209,6 +212,10 @@ def get( result = [] for conversation_entry in conversation_entries: + additional_kwargs_json = conversation_entry[6] + additional_kwargs_obj = None + if additional_kwargs_json: + additional_kwargs_obj = AdditionalKwargs.model_validate_json(additional_kwargs_json) cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -216,6 +223,7 @@ def get( model=conversation_entry[3], started_at=conversation_entry[4], completed_at=conversation_entry[5], + additional_kwargs=additional_kwargs_obj, ) result.append(cache_entry) @@ -244,6 +252,11 @@ def insert_or_append( cursor = self.connection.cursor() current_time = time() + + additional_kwargs_json = None + if cache_entry.additional_kwargs: + additional_kwargs_json = cache_entry.additional_kwargs.model_dump_json(exclude_none=True) + cursor.execute( self.INSERT_CONVERSATION_HISTORY_STATEMENT, ( @@ -256,6 +269,7 @@ def insert_or_append( cache_entry.response, cache_entry.provider, cache_entry.model, + additional_kwargs_json, ), ) diff --git a/src/models/cache_entry.py b/src/models/cache_entry.py index 9f3119f10..3b47cb2a8 100644 --- a/src/models/cache_entry.py +++ b/src/models/cache_entry.py @@ -1,6 +1,12 @@ """Model for conversation history cache entry.""" -from pydantic import BaseModel +from pydantic import BaseModel, Field +from typing import List +from models.responses import ReferencedDocument + +class AdditionalKwargs(BaseModel): + """A structured model for the 'additional_kwargs' dictionary.""" + referenced_documents: List[ReferencedDocument] = Field(default_factory=list) class CacheEntry(BaseModel): @@ -11,6 +17,7 @@ class CacheEntry(BaseModel): response: The response string provider: Provider identification model: Model identification + additional_kwargs: additional property to store data like referenced documents """ query: str @@ -19,17 +26,4 @@ class CacheEntry(BaseModel): model: str started_at: str completed_at: str - - -class ConversationData(BaseModel): - """Model representing conversation data returned by cache list operations. - - Attributes: - conversation_id: The conversation ID - topic_summary: The topic summary for the conversation (can be None) - last_message_timestamp: The timestamp of the last message in the conversation - """ - - conversation_id: str - topic_summary: str | None - last_message_timestamp: float + additional_kwargs: AdditionalKwargs | None = None diff --git a/src/models/responses.py b/src/models/responses.py index fdf324d76..90d824ce8 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -162,6 +162,18 @@ class ToolCall(BaseModel): arguments: dict[str, Any] = Field(description="Arguments passed to the tool") result: Optional[dict[str, Any]] = Field(None, description="Result from the tool") +class ConversationData(BaseModel): + """Model representing conversation data returned by cache list operations. + + Attributes: + conversation_id: The conversation ID + topic_summary: The topic summary for the conversation (can be None) + last_message_timestamp: The timestamp of the last message in the conversation + """ + + conversation_id: str + topic_summary: str | None + last_message_timestamp: float class ReferencedDocument(BaseModel): """Model representing a document referenced in generating a response. @@ -175,7 +187,9 @@ class ReferencedDocument(BaseModel): None, description="URL of the referenced document" ) - doc_title: str = Field(description="Title of the referenced document") + doc_title: str | None = Field( + None, description="Title of the referenced document" + ) class QueryResponse(BaseModel): diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index e9246b714..de2e9bec7 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -232,12 +232,7 @@ def store_conversation_into_cache( config: AppConfig, user_id: str, conversation_id: str, - provider_id: str, - model_id: str, - query: str, - response: str, - started_at: str, - completed_at: str, + cache_entry: CacheEntry, _skip_userid_check: bool, topic_summary: str | None, ) -> None: @@ -247,14 +242,6 @@ def store_conversation_into_cache( if cache is None: logger.warning("Conversation cache configured but not initialized") return - cache_entry = CacheEntry( - query=query, - response=response, - provider=provider_id, - model=model_id, - started_at=started_at, - completed_at=completed_at, - ) cache.insert_or_append( user_id, conversation_id, cache_entry, _skip_userid_check ) From 88c6edfa6dc45b1dabee245978c7233d5d1e2e36 Mon Sep 17 00:00:00 2001 From: Maysun J Faisal Date: Thu, 9 Oct 2025 18:12:57 -0400 Subject: [PATCH 2/4] Add Unit tests Signed-off-by: Maysun J Faisal --- docs/openapi.json | 51 +++++++++- .../app/endpoints/test_conversations_v2.py | 37 +++++++- tests/unit/app/endpoints/test_query.py | 23 ++++- .../app/endpoints/test_streaming_query.py | 17 ++++ tests/unit/cache/test_postgres_cache.py | 95 ++++++++++++++++++- tests/unit/cache/test_sqlite_cache.py | 54 ++++++++++- 6 files changed, 267 insertions(+), 10 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index d036fa3f2..45f77ed59 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -3691,15 +3691,19 @@ "description": "URL of the referenced document" }, "doc_title": { - "type": "string", + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], "title": "Doc Title", "description": "Title of the referenced document" } }, "type": "object", - "required": [ - "doc_title" - ], "title": "ReferencedDocument", "description": "Model representing a document referenced in generating a response.\n\nAttributes:\n doc_url: Url to the referenced doc.\n doc_title: Title of the referenced doc." }, @@ -3969,6 +3973,45 @@ "title": "ToolsResponse", "description": "Model representing a response to tools request." }, + "UnauthorizedResponse": { + "properties": { + "tools": { + "items": { + "additionalProperties": true, + "type": "object" + }, + "type": "array", + "title": "Tools", + "description": "List of tools available from all configured MCP servers and built-in toolgroups", + "examples": [ + [ + { + "description": "Read contents of a file from the filesystem", + "identifier": "filesystem_read", + "parameters": [ + { + "description": "Path to the file to read", + "name": "path", + "parameter_type": "string", + "required": true + } + ], + "provider_id": "model-context-protocol", + "server_source": "http://localhost:3000", + "toolgroup_id": "filesystem-tools", + "type": "tool" + } + ] + ] + } + }, + "type": "object", + "required": [ + "tools" + ], + "title": "ToolsResponse", + "description": "Model representing a response to tools request." + }, "UnauthorizedResponse": { "properties": { "detail": { diff --git a/tests/unit/app/endpoints/test_conversations_v2.py b/tests/unit/app/endpoints/test_conversations_v2.py index 0448c2206..6982a6197 100644 --- a/tests/unit/app/endpoints/test_conversations_v2.py +++ b/tests/unit/app/endpoints/test_conversations_v2.py @@ -3,6 +3,7 @@ """Unit tests for the /conversations REST API endpoints.""" from unittest.mock import Mock +from pydantic import AnyUrl import pytest from fastapi import HTTPException, status @@ -12,9 +13,9 @@ check_valid_conversation_id, check_conversation_existence, ) -from models.cache_entry import CacheEntry +from models.cache_entry import AdditionalKwargs, CacheEntry from models.requests import ConversationUpdateRequest -from models.responses import ConversationUpdateResponse +from models.responses import ConversationUpdateResponse, ReferencedDocument from tests.unit.utils.auth_helpers import mock_authorization_resolvers MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") @@ -59,6 +60,38 @@ def test_transform_message() -> None: assert message2["content"] == "response" +def test_transform_message_with_additional_kwargs() -> None: + """Test the transform_chat_message function when additional_kwargs are present.""" + # CacheEntry with referenced documents + docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] + kwargs_obj = AdditionalKwargs(referenced_documents=docs) + + entry = CacheEntry( + query="query", + response="response", + provider="provider", + model="model", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + additional_kwargs=kwargs_obj + ) + + transformed = transform_chat_message(entry) + assert transformed is not None + + assistant_message = transformed["messages"][1] + + # Check that the assistant message contains the additional_kwargs field + assert "additional_kwargs" in assistant_message + + # Check the content of the referenced documents + kwargs = assistant_message["additional_kwargs"] + assert "referenced_documents" in kwargs + assert len(kwargs["referenced_documents"]) == 1 + assert kwargs["referenced_documents"][0]["doc_title"] == "Test Doc" + assert str(kwargs["referenced_documents"][0]["doc_url"]) == "http://example.com/" + + @pytest.fixture def mock_configuration(): """Mock configuration with conversation cache.""" diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index b0d89351d..9da5a3ebb 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -30,6 +30,7 @@ ) from authorization.resolvers import NoopRolesResolver from configuration import AppConfig +from models.cache_entry import CacheEntry from models.config import Action, ModelContextProtocolServer from models.database.conversations import UserConversation from models.requests import Attachment, QueryRequest @@ -183,6 +184,13 @@ async def _test_query_endpoint_handler( store_transcript_to_file ) mocker.patch("app.endpoints.query.configuration", mock_config) + + mock_store_in_cache = mocker.patch("app.endpoints.query.store_conversation_into_cache") + + # Create mock referenced documents to simulate a successful RAG response + mock_referenced_documents = [ + ReferencedDocument(doc_title="Test Doc 1", doc_url=AnyUrl("http://example.com/1")) + ] summary = TurnSummary( llm_response="LLM answer", @@ -197,11 +205,10 @@ async def _test_query_endpoint_handler( ) conversation_id = "00000000-0000-0000-0000-000000000000" query = "What is OpenStack?" - referenced_documents: list[ReferencedDocument] = [] mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id, referenced_documents, TokenCounter()), + return_value=(summary, conversation_id, mock_referenced_documents, TokenCounter()), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -230,6 +237,18 @@ async def _test_query_endpoint_handler( # Assert the response is as expected assert response.response == summary.llm_response assert response.conversation_id == conversation_id + + # Assert that mock was called and get the arguments + mock_store_in_cache.assert_called_once() + call_args = mock_store_in_cache.call_args[0] + # Extract CacheEntry object from the call arguments, it's the 4th argument from the func signature + cached_entry = call_args[3] + + assert isinstance(cached_entry, CacheEntry) + assert cached_entry.response == "LLM answer" + assert cached_entry.additional_kwargs is not None + assert len(cached_entry.additional_kwargs.referenced_documents) == 1 + assert cached_entry.additional_kwargs.referenced_documents[0].doc_title == "Test Doc 1" # Note: metrics are now handled inside extract_and_update_token_metrics() which is mocked diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 3bd12816a..501b67877 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -51,6 +51,7 @@ from authorization.resolvers import NoopRolesResolver from constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT +from models.cache_entry import CacheEntry from models.config import ModelContextProtocolServer, Action from models.requests import QueryRequest, Attachment from models.responses import RAGChunk @@ -295,6 +296,8 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) ), ] + mock_store_in_cache = mocker.patch("app.endpoints.streaming_query.store_conversation_into_cache") + query = "What is OpenStack?" mocker.patch( "app.endpoints.streaming_query.retrieve_response", @@ -355,6 +358,20 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) referenced_documents = d["data"]["referenced_documents"] assert len(referenced_documents) == 2 assert referenced_documents[1]["doc_title"] == "Doc2" + + # Assert that mock was called and get the arguments + mock_store_in_cache.assert_called_once() + call_args = mock_store_in_cache.call_args[0] + # Extract CacheEntry object from the call arguments, it's the 4th argument from the func signature + cached_entry = call_args[3] + + # Assert that the CacheEntry was constructed correctly + assert isinstance(cached_entry, CacheEntry) + assert cached_entry.response == "LLM answer" + assert cached_entry.additional_kwargs is not None + assert len(cached_entry.additional_kwargs.referenced_documents) == 2 + assert cached_entry.additional_kwargs.referenced_documents[0].doc_title == "Doc1" + assert str(cached_entry.additional_kwargs.referenced_documents[1].doc_url) == "https://example.com/doc2" # Assert the store_transcript function is called if transcripts are enabled if store_transcript: diff --git a/tests/unit/cache/test_postgres_cache.py b/tests/unit/cache/test_postgres_cache.py index 32646997d..fd6780e0c 100644 --- a/tests/unit/cache/test_postgres_cache.py +++ b/tests/unit/cache/test_postgres_cache.py @@ -1,13 +1,18 @@ """Unit tests for PostgreSQL cache implementation.""" +import json + import pytest import psycopg2 +from pydantic import AnyUrl + from cache.cache_error import CacheError from cache.postgres_cache import PostgresCache from models.config import PostgreSQLDatabaseConfiguration -from models.cache_entry import CacheEntry, ConversationData +from models.cache_entry import AdditionalKwargs, CacheEntry, ReferencedDocument +from models.responses import ConversationData from utils import suid @@ -375,3 +380,91 @@ def test_topic_summary_when_disconnected(postgres_cache_config_fixture, mocker): with pytest.raises(CacheError, match="cache is disconnected"): cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test", False) + + +def test_insert_and_get_with_additional_kwargs(postgres_cache_config_fixture, mocker): + """Test that a CacheEntry with additional_kwargs is stored and retrieved correctly.""" + # prevent real connection to PG instance + mock_connect = mocker.patch("psycopg2.connect") + cache = PostgresCache(postgres_cache_config_fixture) + + mock_connection = mock_connect.return_value + mock_cursor = mock_connection.cursor.return_value.__enter__.return_value + + # Create a CacheEntry with referenced documents + docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] + kwargs_obj = AdditionalKwargs(referenced_documents=docs) + entry_with_kwargs = CacheEntry( + query="user message", + response="AI message", + provider="foo", model="bar", + started_at="start_time", completed_at="end_time", + additional_kwargs=kwargs_obj + ) + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_kwargs) + + + insert_call = mock_cursor.execute.call_args_list[1] + sql_params = insert_call[0][1] + inserted_json_str = sql_params[-1] + + assert json.loads(inserted_json_str) == { + "referenced_documents": [{"doc_title": "Test Doc", "doc_url": "http://example.com/"}] + } + + # Simulate the database returning that data + db_return_value = ( + "user message", "AI message", "foo", "bar", "start_time", "end_time", + {"referenced_documents": [{"doc_title": "Test Doc", "doc_url": "http://example.com/"}]} + ) + mock_cursor.fetchall.return_value = [db_return_value] + + # Call the get method + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_with_kwargs + assert retrieved_entries[0].additional_kwargs.referenced_documents[0].doc_title == "Test Doc" + + +def test_insert_and_get_without_additional_kwargs(postgres_cache_config_fixture, mocker): + """Test that a CacheEntry with no additional_kwargs is handled correctly.""" + # Arrange + mock_connect = mocker.patch("psycopg2.connect") + cache = PostgresCache(postgres_cache_config_fixture) + + mock_connection = mock_connect.return_value + mock_cursor = mock_connection.cursor.return_value.__enter__.return_value + + # Use CacheEntry without additional_kwargs + entry_without_kwargs = cache_entry_2 + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_without_kwargs) + + insert_call = mock_cursor.execute.call_args_list[1] + sql_params = insert_call[0][1] + assert sql_params[-1] is None + + # 4. Simulate the database returning a row with None + db_return_value = ( + entry_without_kwargs.query, + entry_without_kwargs.response, + entry_without_kwargs.provider, + entry_without_kwargs.model, + entry_without_kwargs.started_at, + entry_without_kwargs.completed_at, + None # additional_kwargs is None in the DB + ) + mock_cursor.fetchall.return_value = [db_return_value] + + # Call the get method + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_without_kwargs + assert retrieved_entries[0].additional_kwargs is None diff --git a/tests/unit/cache/test_sqlite_cache.py b/tests/unit/cache/test_sqlite_cache.py index 6c9fd0fab..698ff4b36 100644 --- a/tests/unit/cache/test_sqlite_cache.py +++ b/tests/unit/cache/test_sqlite_cache.py @@ -6,8 +6,11 @@ import pytest +from pydantic import AnyUrl + from models.config import SQLiteDatabaseConfiguration -from models.cache_entry import CacheEntry, ConversationData +from models.cache_entry import AdditionalKwargs, CacheEntry, ReferencedDocument +from models.responses import ConversationData from utils import suid from cache.cache_error import CacheError @@ -357,3 +360,52 @@ def test_topic_summary_when_disconnected(tmpdir): with pytest.raises(CacheError, match="cache is disconnected"): cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test", False) + + +def test_insert_and_get_with_additional_kwargs(tmpdir): + """ + Test that a CacheEntry with additional_kwargs is correctly + serialized, stored, and retrieved. + """ + cache = create_cache(tmpdir) + + # Create a CacheEntry with referenced documents + docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] + kwargs_obj = AdditionalKwargs(referenced_documents=docs) + entry_with_kwargs = CacheEntry( + query="user message", + response="AI message", + provider="foo", model="bar", + started_at="start_time", completed_at="end_time", + additional_kwargs=kwargs_obj + ) + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_kwargs) + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_with_kwargs + assert retrieved_entries[0].additional_kwargs is not None + assert retrieved_entries[0].additional_kwargs.referenced_documents[0].doc_title == "Test Doc" + + +def test_insert_and_get_without_additional_kwargs(tmpdir): + """ + Test that a CacheEntry without additional_kwargs is correctly + stored and retrieved with its additional_kwargs attribute as None. + """ + cache = create_cache(tmpdir) + + # Use CacheEntry without additional_kwargs + entry_without_kwargs = cache_entry_1 + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_without_kwargs) + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_without_kwargs + assert retrieved_entries[0].additional_kwargs is None \ No newline at end of file From ef0b934720a71bbe1a728b96826f164c7243a286 Mon Sep 17 00:00:00 2001 From: Maysun J Faisal Date: Fri, 17 Oct 2025 00:29:20 -0400 Subject: [PATCH 3/4] Flatten referenced_documents in /v2/conversations response Signed-off-by: Maysun J Faisal --- src/app/endpoints/conversations_v2.py | 8 +- src/app/endpoints/query.py | 9 +- src/app/endpoints/streaming_query.py | 10 +-- src/cache/postgres_cache.py | 73 +++++++-------- src/cache/sqlite_cache.py | 71 ++++++++------- src/models/cache_entry.py | 6 +- src/models/responses.py | 1 - .../app/endpoints/test_conversations_v2.py | 90 +++++++++++++------ tests/unit/app/endpoints/test_query.py | 6 +- .../app/endpoints/test_streaming_query.py | 9 +- tests/unit/cache/test_postgres_cache.py | 54 ++++++----- tests/unit/cache/test_sqlite_cache.py | 35 ++++---- 12 files changed, 198 insertions(+), 174 deletions(-) diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py index d311ce54d..36265c9f7 100644 --- a/src/app/endpoints/conversations_v2.py +++ b/src/app/endpoints/conversations_v2.py @@ -323,9 +323,11 @@ def transform_chat_message(entry: CacheEntry) -> dict[str, Any]: "type": "assistant" } - # Check for additional_kwargs and add it to the assistant message if it exists - if entry.additional_kwargs: - assistant_message["additional_kwargs"] = entry.additional_kwargs.model_dump() + # If referenced_documents exist on the entry, add them to the assistant message + if entry.referenced_documents is not None: + assistant_message["referenced_documents"] = [ + doc.model_dump(mode='json') for doc in entry.referenced_documents + ] return { "provider": entry.provider, diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index d92db9fc3..5255e9610 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -32,7 +32,7 @@ from authorization.middleware import authorize from client import AsyncLlamaStackClientHolder from configuration import configuration -from models.cache_entry import CacheEntry, AdditionalKwargs +from models.cache_entry import CacheEntry from models.config import Action from models.database.conversations import UserConversation from models.requests import Attachment, QueryRequest @@ -334,11 +334,6 @@ async def query_endpoint_handler( # pylint: disable=R0914 completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") - additional_kwargs_obj = None - if referenced_documents: - additional_kwargs_obj = AdditionalKwargs( - referenced_documents=referenced_documents - ) cache_entry = CacheEntry( query=query_request.query, response=summary.llm_response, @@ -346,7 +341,7 @@ async def query_endpoint_handler( # pylint: disable=R0914 model=model_id, started_at=started_at, completed_at=completed_at, - additional_kwargs=additional_kwargs_obj + referenced_documents=referenced_documents if referenced_documents else None ) store_conversation_into_cache( diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index bc1dd2c58..076884427 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -44,13 +44,14 @@ from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT import metrics from metrics.utils import update_llm_token_count_from_turn -from models.cache_entry import CacheEntry, AdditionalKwargs +from models.cache_entry import CacheEntry from models.config import Action from models.database.conversations import UserConversation from models.requests import QueryRequest from models.responses import ForbiddenResponse, UnauthorizedResponse, ReferencedDocument from utils.endpoints import ( check_configuration_loaded, + create_referenced_documents_with_metadata, create_rag_chunks_dict, get_agent, get_system_prompt, @@ -868,11 +869,6 @@ async def response_generator( referenced_documents = create_referenced_documents_with_metadata(summary, metadata_map) - additional_kwargs_obj = None - if referenced_documents: - additional_kwargs_obj = AdditionalKwargs( - referenced_documents=referenced_documents - ) cache_entry = CacheEntry( query=query_request.query, response=summary.llm_response, @@ -880,7 +876,7 @@ async def response_generator( model=model_id, started_at=started_at, completed_at=completed_at, - additional_kwargs=additional_kwargs_obj + referenced_documents=referenced_documents if referenced_documents else None ) store_conversation_into_cache( diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index e25de3522..73b7d47fc 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -1,12 +1,13 @@ """PostgreSQL cache implementation.""" +import json import psycopg2 from cache.cache import Cache from cache.cache_error import CacheError -from models.cache_entry import CacheEntry, AdditionalKwargs +from models.cache_entry import CacheEntry from models.config import PostgreSQLDatabaseConfiguration -from models.responses import ConversationData +from models.responses import ConversationData, ReferencedDocument from log import get_logger from utils.connection_decorator import connection @@ -19,17 +20,18 @@ class PostgresCache(Cache): The cache itself lives stored in following table: ``` - Column | Type | Nullable | - -----------------+--------------------------------+----------+ - user_id | text | not null | - conversation_id | text | not null | - created_at | timestamp without time zone | not null | - started_at | text | | - completed_at | text | | - query | text | | - response | text | | - provider | text | | - model | text | | + Column | Type | Nullable | + -----------------------+--------------------------------+----------+ + user_id | text | not null | + conversation_id | text | not null | + created_at | timestamp without time zone | not null | + started_at | text | | + completed_at | text | | + query | text | | + response | text | | + provider | text | | + model | text | | + referenced_documents | jsonb | | Indexes: "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at) "timestamps" btree (created_at) @@ -38,16 +40,16 @@ class PostgresCache(Cache): CREATE_CACHE_TABLE = """ CREATE TABLE IF NOT EXISTS cache ( - user_id text NOT NULL, - conversation_id text NOT NULL, - created_at timestamp NOT NULL, - started_at text, - completed_at text, - query text, - response text, - provider text, - model text, - additional_kwargs jsonb, + user_id text NOT NULL, + conversation_id text NOT NULL, + created_at timestamp NOT NULL, + started_at text, + completed_at text, + query text, + response text, + provider text, + model text, + referenced_documents jsonb, PRIMARY KEY(user_id, conversation_id, created_at) ); """ @@ -68,7 +70,7 @@ class PostgresCache(Cache): """ SELECT_CONVERSATION_HISTORY_STATEMENT = """ - SELECT query, response, provider, model, started_at, completed_at, additional_kwargs + SELECT query, response, provider, model, started_at, completed_at, referenced_documents FROM cache WHERE user_id=%s AND conversation_id=%s ORDER BY created_at @@ -76,7 +78,7 @@ class PostgresCache(Cache): INSERT_CONVERSATION_HISTORY_STATEMENT = """ INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at, - query, response, provider, model, additional_kwargs) + query, response, provider, model, referenced_documents) VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s) """ @@ -214,10 +216,10 @@ def get( result = [] for conversation_entry in conversation_entries: # Parse it back into an LLMResponse object - additional_kwargs_data = conversation_entry[6] - additional_kwargs_obj = None - if additional_kwargs_data: - additional_kwargs_obj = AdditionalKwargs.model_validate(additional_kwargs_data) + docs_data = conversation_entry[6] + docs_obj = None + if docs_data: + docs_obj = [ReferencedDocument.model_validate(doc) for doc in docs_data] cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -225,7 +227,7 @@ def get( model=conversation_entry[3], started_at=conversation_entry[4], completed_at=conversation_entry[5], - additional_kwargs=additional_kwargs_obj, + referenced_documents=docs_obj, ) result.append(cache_entry) @@ -253,10 +255,11 @@ def insert_or_append( raise CacheError("insert_or_append: cache is disconnected") try: - additional_kwargs_json = None - if cache_entry.additional_kwargs: - # Use exclude_none=True to keep JSON clean - additional_kwargs_json = cache_entry.additional_kwargs.model_dump_json(exclude_none=True) + referenced_documents_json = None + if cache_entry.referenced_documents: + docs_as_dicts = [doc.model_dump(mode='json') for doc in cache_entry.referenced_documents] + referenced_documents_json = json.dumps(docs_as_dicts) + # the whole operation is run in one transaction with self.connection.cursor() as cursor: cursor.execute( @@ -270,7 +273,7 @@ def insert_or_append( cache_entry.response, cache_entry.provider, cache_entry.model, - additional_kwargs_json, + referenced_documents_json, ), ) diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index 40af1eff5..7002f4ec0 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -7,9 +7,9 @@ from cache.cache import Cache from cache.cache_error import CacheError -from models.cache_entry import CacheEntry, AdditionalKwargs +from models.cache_entry import CacheEntry from models.config import SQLiteDatabaseConfiguration -from models.responses import ConversationData +from models.responses import ConversationData, ReferencedDocument from log import get_logger from utils.connection_decorator import connection @@ -22,17 +22,18 @@ class SQLiteCache(Cache): The cache itself is stored in following table: ``` - Column | Type | Nullable | - -----------------+-----------------------------+----------+ - user_id | text | not null | - conversation_id | text | not null | - created_at | int | not null | - started_at | text | | - completed_at | text | | - query | text | | - response | text | | - provider | text | | - model | text | | + Column | Type | Nullable | + -----------------------+-----------------------------+----------+ + user_id | text | not null | + conversation_id | text | not null | + created_at | int | not null | + started_at | text | | + completed_at | text | | + query | text | | + response | text | | + provider | text | | + model | text | | + referenced_documents | text | | Indexes: "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at) "cache_key_key" UNIQUE CONSTRAINT, btree (key) @@ -43,16 +44,16 @@ class SQLiteCache(Cache): CREATE_CACHE_TABLE = """ CREATE TABLE IF NOT EXISTS cache ( - user_id text NOT NULL, - conversation_id text NOT NULL, - created_at int NOT NULL, - started_at text, - completed_at text, - query text, - response text, - provider text, - model text, - additional_kwargs text, + user_id text NOT NULL, + conversation_id text NOT NULL, + created_at int NOT NULL, + started_at text, + completed_at text, + query text, + response text, + provider text, + model text, + referenced_documents text, PRIMARY KEY(user_id, conversation_id, created_at) ); """ @@ -73,7 +74,7 @@ class SQLiteCache(Cache): """ SELECT_CONVERSATION_HISTORY_STATEMENT = """ - SELECT query, response, provider, model, started_at, completed_at, additional_kwargs + SELECT query, response, provider, model, started_at, completed_at, referenced_documents FROM cache WHERE user_id=? AND conversation_id=? ORDER BY created_at @@ -81,7 +82,7 @@ class SQLiteCache(Cache): INSERT_CONVERSATION_HISTORY_STATEMENT = """ INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at, - query, response, provider, model, additional_kwargs) + query, response, provider, model, referenced_documents) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """ @@ -212,10 +213,11 @@ def get( result = [] for conversation_entry in conversation_entries: - additional_kwargs_json = conversation_entry[6] - additional_kwargs_obj = None - if additional_kwargs_json: - additional_kwargs_obj = AdditionalKwargs.model_validate_json(additional_kwargs_json) + docs_json_str = conversation_entry[6] + docs_obj = None + if docs_json_str: + docs_data = json.loads(docs_json_str) + docs_obj = [ReferencedDocument.model_validate(doc) for doc in docs_data] cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -223,7 +225,7 @@ def get( model=conversation_entry[3], started_at=conversation_entry[4], completed_at=conversation_entry[5], - additional_kwargs=additional_kwargs_obj, + referenced_documents=docs_obj, ) result.append(cache_entry) @@ -253,9 +255,10 @@ def insert_or_append( cursor = self.connection.cursor() current_time = time() - additional_kwargs_json = None - if cache_entry.additional_kwargs: - additional_kwargs_json = cache_entry.additional_kwargs.model_dump_json(exclude_none=True) + referenced_documents_json = None + if cache_entry.referenced_documents: + docs_as_dicts = [doc.model_dump(mode='json') for doc in cache_entry.referenced_documents] + referenced_documents_json = json.dumps(docs_as_dicts) cursor.execute( self.INSERT_CONVERSATION_HISTORY_STATEMENT, @@ -269,7 +272,7 @@ def insert_or_append( cache_entry.response, cache_entry.provider, cache_entry.model, - additional_kwargs_json, + referenced_documents_json, ), ) diff --git a/src/models/cache_entry.py b/src/models/cache_entry.py index 3b47cb2a8..86184894a 100644 --- a/src/models/cache_entry.py +++ b/src/models/cache_entry.py @@ -4,10 +4,6 @@ from typing import List from models.responses import ReferencedDocument -class AdditionalKwargs(BaseModel): - """A structured model for the 'additional_kwargs' dictionary.""" - referenced_documents: List[ReferencedDocument] = Field(default_factory=list) - class CacheEntry(BaseModel): """Model representing a cache entry. @@ -26,4 +22,4 @@ class CacheEntry(BaseModel): model: str started_at: str completed_at: str - additional_kwargs: AdditionalKwargs | None = None + referenced_documents: List[ReferencedDocument] | None = None diff --git a/src/models/responses.py b/src/models/responses.py index 90d824ce8..662439bcf 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -7,7 +7,6 @@ from pydantic import AnyUrl, BaseModel, Field from llama_stack_client.types import ProviderInfo -from models.cache_entry import ConversationData class ModelsResponse(BaseModel): diff --git a/tests/unit/app/endpoints/test_conversations_v2.py b/tests/unit/app/endpoints/test_conversations_v2.py index 6982a6197..b5a975700 100644 --- a/tests/unit/app/endpoints/test_conversations_v2.py +++ b/tests/unit/app/endpoints/test_conversations_v2.py @@ -13,7 +13,7 @@ check_valid_conversation_id, check_conversation_existence, ) -from models.cache_entry import AdditionalKwargs, CacheEntry +from models.cache_entry import CacheEntry from models.requests import ConversationUpdateRequest from models.responses import ConversationUpdateResponse, ReferencedDocument from tests.unit.utils.auth_helpers import mock_authorization_resolvers @@ -60,36 +60,68 @@ def test_transform_message() -> None: assert message2["content"] == "response" -def test_transform_message_with_additional_kwargs() -> None: - """Test the transform_chat_message function when additional_kwargs are present.""" - # CacheEntry with referenced documents - docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] - kwargs_obj = AdditionalKwargs(referenced_documents=docs) - - entry = CacheEntry( - query="query", - response="response", - provider="provider", - model="model", - started_at="2024-01-01T00:00:00Z", - completed_at="2024-01-01T00:00:05Z", - additional_kwargs=kwargs_obj - ) +class TestTransformChatMessage: + """Test cases for the transform_chat_message utility function.""" - transformed = transform_chat_message(entry) - assert transformed is not None + def test_transform_message_without_documents(self) -> None: + """Test the transformation when no referenced_documents are present.""" + entry = CacheEntry( + query="query", + response="response", + provider="provider", + model="model", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + # referenced_documents is None by default + ) + transformed = transform_chat_message(entry) + + assistant_message = transformed["messages"][1] + + # Assert that the key is NOT present when the list is None + assert "referenced_documents" not in assistant_message + + def test_transform_message_with_referenced_documents(self) -> None: + """Test the transformation when referenced_documents are present.""" + docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] + + entry = CacheEntry( + query="query", + response="response", + provider="provider", + model="model", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + referenced_documents=docs + ) + + transformed = transform_chat_message(entry) + assistant_message = transformed["messages"][1] + + assert "referenced_documents" in assistant_message + + ref_docs = assistant_message["referenced_documents"] + assert len(ref_docs) == 1 + assert ref_docs[0]["doc_title"] == "Test Doc" + assert str(ref_docs[0]["doc_url"]) == "http://example.com/" + + def test_transform_message_with_empty_referenced_documents(self) -> None: + """Test the transformation when referenced_documents is an empty list.""" + entry = CacheEntry( + query="query", + response="response", + provider="provider", + model="model", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + referenced_documents=[] # Explicitly empty + ) + + transformed = transform_chat_message(entry) + assistant_message = transformed["messages"][1] - assistant_message = transformed["messages"][1] - - # Check that the assistant message contains the additional_kwargs field - assert "additional_kwargs" in assistant_message - - # Check the content of the referenced documents - kwargs = assistant_message["additional_kwargs"] - assert "referenced_documents" in kwargs - assert len(kwargs["referenced_documents"]) == 1 - assert kwargs["referenced_documents"][0]["doc_title"] == "Test Doc" - assert str(kwargs["referenced_documents"][0]["doc_url"]) == "http://example.com/" + assert "referenced_documents" in assistant_message + assert assistant_message["referenced_documents"] == [] @pytest.fixture diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 9da5a3ebb..0cc942b43 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -246,9 +246,9 @@ async def _test_query_endpoint_handler( assert isinstance(cached_entry, CacheEntry) assert cached_entry.response == "LLM answer" - assert cached_entry.additional_kwargs is not None - assert len(cached_entry.additional_kwargs.referenced_documents) == 1 - assert cached_entry.additional_kwargs.referenced_documents[0].doc_title == "Test Doc 1" + assert cached_entry.referenced_documents is not None + assert len(cached_entry.referenced_documents) == 1 + assert cached_entry.referenced_documents[0].doc_title == "Test Doc 1" # Note: metrics are now handled inside extract_and_update_token_metrics() which is mocked diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 501b67877..068395b07 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -284,6 +284,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) role="assistant", content=[TextContentItem(text="LLM answer", type="text")], stop_reason="end_of_turn", + tool_calls=[], ), session_id="test_session_id", started_at=datetime.now(), @@ -368,10 +369,10 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) # Assert that the CacheEntry was constructed correctly assert isinstance(cached_entry, CacheEntry) assert cached_entry.response == "LLM answer" - assert cached_entry.additional_kwargs is not None - assert len(cached_entry.additional_kwargs.referenced_documents) == 2 - assert cached_entry.additional_kwargs.referenced_documents[0].doc_title == "Doc1" - assert str(cached_entry.additional_kwargs.referenced_documents[1].doc_url) == "https://example.com/doc2" + assert cached_entry.referenced_documents is not None + assert len(cached_entry.referenced_documents) == 2 + assert cached_entry.referenced_documents[0].doc_title == "Doc1" + assert str(cached_entry.referenced_documents[1].doc_url) == "https://example.com/doc2" # Assert the store_transcript function is called if transcripts are enabled if store_transcript: diff --git a/tests/unit/cache/test_postgres_cache.py b/tests/unit/cache/test_postgres_cache.py index fd6780e0c..9120dfa9b 100644 --- a/tests/unit/cache/test_postgres_cache.py +++ b/tests/unit/cache/test_postgres_cache.py @@ -11,7 +11,7 @@ from cache.cache_error import CacheError from cache.postgres_cache import PostgresCache from models.config import PostgreSQLDatabaseConfiguration -from models.cache_entry import AdditionalKwargs, CacheEntry, ReferencedDocument +from models.cache_entry import CacheEntry, ReferencedDocument from models.responses import ConversationData from utils import suid @@ -382,8 +382,8 @@ def test_topic_summary_when_disconnected(postgres_cache_config_fixture, mocker): cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test", False) -def test_insert_and_get_with_additional_kwargs(postgres_cache_config_fixture, mocker): - """Test that a CacheEntry with additional_kwargs is stored and retrieved correctly.""" +def test_insert_and_get_with_referenced_documents(postgres_cache_config_fixture, mocker): + """Test that a CacheEntry with referenced_documents is stored and retrieved correctly.""" # prevent real connection to PG instance mock_connect = mocker.patch("psycopg2.connect") cache = PostgresCache(postgres_cache_config_fixture) @@ -393,31 +393,30 @@ def test_insert_and_get_with_additional_kwargs(postgres_cache_config_fixture, mo # Create a CacheEntry with referenced documents docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] - kwargs_obj = AdditionalKwargs(referenced_documents=docs) - entry_with_kwargs = CacheEntry( + entry_with_docs = CacheEntry( query="user message", response="AI message", provider="foo", model="bar", started_at="start_time", completed_at="end_time", - additional_kwargs=kwargs_obj + referenced_documents=docs ) # Call the insert method - cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_kwargs) + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_docs) insert_call = mock_cursor.execute.call_args_list[1] sql_params = insert_call[0][1] inserted_json_str = sql_params[-1] - assert json.loads(inserted_json_str) == { - "referenced_documents": [{"doc_title": "Test Doc", "doc_url": "http://example.com/"}] - } + assert json.loads(inserted_json_str) == [ + {"doc_url": "http://example.com/", "doc_title": "Test Doc"} + ] # Simulate the database returning that data db_return_value = ( "user message", "AI message", "foo", "bar", "start_time", "end_time", - {"referenced_documents": [{"doc_title": "Test Doc", "doc_url": "http://example.com/"}]} + [{"doc_url": "http://example.com/", "doc_title": "Test Doc"}] ) mock_cursor.fetchall.return_value = [db_return_value] @@ -426,24 +425,23 @@ def test_insert_and_get_with_additional_kwargs(postgres_cache_config_fixture, mo # Assert that the retrieved entry matches the original assert len(retrieved_entries) == 1 - assert retrieved_entries[0] == entry_with_kwargs - assert retrieved_entries[0].additional_kwargs.referenced_documents[0].doc_title == "Test Doc" + assert retrieved_entries[0] == entry_with_docs + assert retrieved_entries[0].referenced_documents[0].doc_title == "Test Doc" -def test_insert_and_get_without_additional_kwargs(postgres_cache_config_fixture, mocker): - """Test that a CacheEntry with no additional_kwargs is handled correctly.""" - # Arrange +def test_insert_and_get_without_referenced_documents(postgres_cache_config_fixture, mocker): + """Test that a CacheEntry with no referenced_documents is handled correctly.""" mock_connect = mocker.patch("psycopg2.connect") cache = PostgresCache(postgres_cache_config_fixture) mock_connection = mock_connect.return_value mock_cursor = mock_connection.cursor.return_value.__enter__.return_value - # Use CacheEntry without additional_kwargs - entry_without_kwargs = cache_entry_2 + # Use CacheEntry without referenced_documents + entry_without_docs = cache_entry_2 # Call the insert method - cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_without_kwargs) + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_without_docs) insert_call = mock_cursor.execute.call_args_list[1] sql_params = insert_call[0][1] @@ -451,13 +449,13 @@ def test_insert_and_get_without_additional_kwargs(postgres_cache_config_fixture, # 4. Simulate the database returning a row with None db_return_value = ( - entry_without_kwargs.query, - entry_without_kwargs.response, - entry_without_kwargs.provider, - entry_without_kwargs.model, - entry_without_kwargs.started_at, - entry_without_kwargs.completed_at, - None # additional_kwargs is None in the DB + entry_without_docs.query, + entry_without_docs.response, + entry_without_docs.provider, + entry_without_docs.model, + entry_without_docs.started_at, + entry_without_docs.completed_at, + None # referenced_documents is None in the DB ) mock_cursor.fetchall.return_value = [db_return_value] @@ -466,5 +464,5 @@ def test_insert_and_get_without_additional_kwargs(postgres_cache_config_fixture, # Assert that the retrieved entry matches the original assert len(retrieved_entries) == 1 - assert retrieved_entries[0] == entry_without_kwargs - assert retrieved_entries[0].additional_kwargs is None + assert retrieved_entries[0] == entry_without_docs + assert retrieved_entries[0].referenced_documents is None diff --git a/tests/unit/cache/test_sqlite_cache.py b/tests/unit/cache/test_sqlite_cache.py index 698ff4b36..cbb825247 100644 --- a/tests/unit/cache/test_sqlite_cache.py +++ b/tests/unit/cache/test_sqlite_cache.py @@ -9,7 +9,7 @@ from pydantic import AnyUrl from models.config import SQLiteDatabaseConfiguration -from models.cache_entry import AdditionalKwargs, CacheEntry, ReferencedDocument +from models.cache_entry import CacheEntry, ReferencedDocument from models.responses import ConversationData from utils import suid @@ -362,50 +362,49 @@ def test_topic_summary_when_disconnected(tmpdir): cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test", False) -def test_insert_and_get_with_additional_kwargs(tmpdir): +def test_insert_and_get_with_referenced_documents(tmpdir): """ - Test that a CacheEntry with additional_kwargs is correctly + Test that a CacheEntry with referenced_documents is correctly serialized, stored, and retrieved. """ cache = create_cache(tmpdir) # Create a CacheEntry with referenced documents docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] - kwargs_obj = AdditionalKwargs(referenced_documents=docs) - entry_with_kwargs = CacheEntry( + entry_with_docs = CacheEntry( query="user message", response="AI message", provider="foo", model="bar", started_at="start_time", completed_at="end_time", - additional_kwargs=kwargs_obj + referenced_documents=docs ) # Call the insert method - cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_kwargs) + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_docs) retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) # Assert that the retrieved entry matches the original assert len(retrieved_entries) == 1 - assert retrieved_entries[0] == entry_with_kwargs - assert retrieved_entries[0].additional_kwargs is not None - assert retrieved_entries[0].additional_kwargs.referenced_documents[0].doc_title == "Test Doc" + assert retrieved_entries[0] == entry_with_docs + assert retrieved_entries[0].referenced_documents is not None + assert retrieved_entries[0].referenced_documents[0].doc_title == "Test Doc" -def test_insert_and_get_without_additional_kwargs(tmpdir): +def test_insert_and_get_without_referenced_documents(tmpdir): """ - Test that a CacheEntry without additional_kwargs is correctly - stored and retrieved with its additional_kwargs attribute as None. + Test that a CacheEntry without referenced_documents is correctly + stored and retrieved with its referenced_documents attribute as None. """ cache = create_cache(tmpdir) - # Use CacheEntry without additional_kwargs - entry_without_kwargs = cache_entry_1 + # Use CacheEntry without referenced_documents + entry_without_docs = cache_entry_1 # Call the insert method - cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_without_kwargs) + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_without_docs) retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) # Assert that the retrieved entry matches the original assert len(retrieved_entries) == 1 - assert retrieved_entries[0] == entry_without_kwargs - assert retrieved_entries[0].additional_kwargs is None \ No newline at end of file + assert retrieved_entries[0] == entry_without_docs + assert retrieved_entries[0].referenced_documents is None \ No newline at end of file From 54c5c5a97935e4aa8e3e5a4abe8177ecf188baf9 Mon Sep 17 00:00:00 2001 From: Maysun J Faisal Date: Mon, 20 Oct 2025 17:13:43 -0400 Subject: [PATCH 4/4] Address linter issues Signed-off-by: Maysun J Faisal --- docs/openapi.json | 139 +++++++++++++----- src/app/endpoints/conversations_v2.py | 12 +- src/app/endpoints/query.py | 5 +- src/app/endpoints/streaming_query.py | 13 +- src/cache/postgres_cache.py | 29 +++- src/cache/sqlite_cache.py | 29 +++- src/models/cache_entry.py | 7 +- src/models/responses.py | 6 +- .../app/endpoints/test_conversations_v2.py | 13 +- tests/unit/app/endpoints/test_query.py | 23 ++- .../app/endpoints/test_streaming_query.py | 17 ++- tests/unit/cache/test_postgres_cache.py | 59 +++++--- tests/unit/cache/test_sqlite_cache.py | 20 +-- 13 files changed, 251 insertions(+), 121 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 45f77ed59..1d32805be 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1787,6 +1787,9 @@ }, "type": "array", "title": "Byok Rag" + }, + "quota_handlers": { + "$ref": "#/components/schemas/QuotaHandlersConfiguration" } }, "additionalProperties": false, @@ -3590,6 +3593,103 @@ } ] }, + "QuotaHandlersConfiguration": { + "properties": { + "sqlite": { + "anyOf": [ + { + "$ref": "#/components/schemas/SQLiteDatabaseConfiguration" + }, + { + "type": "null" + } + ] + }, + "postgres": { + "anyOf": [ + { + "$ref": "#/components/schemas/PostgreSQLDatabaseConfiguration" + }, + { + "type": "null" + } + ] + }, + "limiters": { + "items": { + "$ref": "#/components/schemas/QuotaLimiterConfiguration" + }, + "type": "array", + "title": "Limiters" + }, + "scheduler": { + "$ref": "#/components/schemas/QuotaSchedulerConfiguration" + }, + "enable_token_history": { + "type": "boolean", + "title": "Enable Token History", + "default": false + } + }, + "additionalProperties": false, + "type": "object", + "title": "QuotaHandlersConfiguration", + "description": "Quota limiter configuration." + }, + "QuotaLimiterConfiguration": { + "properties": { + "type": { + "type": "string", + "enum": [ + "user_limiter", + "cluster_limiter" + ], + "title": "Type" + }, + "name": { + "type": "string", + "title": "Name" + }, + "initial_quota": { + "type": "integer", + "minimum": 0.0, + "title": "Initial Quota" + }, + "quota_increase": { + "type": "integer", + "minimum": 0.0, + "title": "Quota Increase" + }, + "period": { + "type": "string", + "title": "Period" + } + }, + "additionalProperties": false, + "type": "object", + "required": [ + "type", + "name", + "initial_quota", + "quota_increase", + "period" + ], + "title": "QuotaLimiterConfiguration", + "description": "Configuration for one quota limiter." + }, + "QuotaSchedulerConfiguration": { + "properties": { + "period": { + "type": "integer", + "exclusiveMinimum": 0.0, + "title": "Period", + "default": 1 + } + }, + "type": "object", + "title": "QuotaSchedulerConfiguration", + "description": "Quota scheduler configuration." + }, "RAGChunk": { "properties": { "content": { @@ -3973,45 +4073,6 @@ "title": "ToolsResponse", "description": "Model representing a response to tools request." }, - "UnauthorizedResponse": { - "properties": { - "tools": { - "items": { - "additionalProperties": true, - "type": "object" - }, - "type": "array", - "title": "Tools", - "description": "List of tools available from all configured MCP servers and built-in toolgroups", - "examples": [ - [ - { - "description": "Read contents of a file from the filesystem", - "identifier": "filesystem_read", - "parameters": [ - { - "description": "Path to the file to read", - "name": "path", - "parameter_type": "string", - "required": true - } - ], - "provider_id": "model-context-protocol", - "server_source": "http://localhost:3000", - "toolgroup_id": "filesystem-tools", - "type": "tool" - } - ] - ] - } - }, - "type": "object", - "required": [ - "tools" - ], - "title": "ToolsResponse", - "description": "Model representing a response to tools request." - }, "UnauthorizedResponse": { "properties": { "detail": { diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py index 36265c9f7..edf74ed54 100644 --- a/src/app/endpoints/conversations_v2.py +++ b/src/app/endpoints/conversations_v2.py @@ -314,19 +314,13 @@ def check_conversation_existence(user_id: str, conversation_id: str) -> None: def transform_chat_message(entry: CacheEntry) -> dict[str, Any]: """Transform the message read from cache into format used by response payload.""" - user_message = { - "content": entry.query, - "type": "user" - } - assistant_message: dict[str, Any] = { - "content": entry.response, - "type": "assistant" - } + user_message = {"content": entry.query, "type": "user"} + assistant_message: dict[str, Any] = {"content": entry.response, "type": "assistant"} # If referenced_documents exist on the entry, add them to the assistant message if entry.referenced_documents is not None: assistant_message["referenced_documents"] = [ - doc.model_dump(mode='json') for doc in entry.referenced_documents + doc.model_dump(mode="json") for doc in entry.referenced_documents ] return { diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 5255e9610..c02656eb8 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -22,7 +22,6 @@ from llama_stack_client.types.model_list_response import ModelListResponse from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from pydantic import AnyUrl import constants import metrics @@ -341,9 +340,9 @@ async def query_endpoint_handler( # pylint: disable=R0914 model=model_id, started_at=started_at, completed_at=completed_at, - referenced_documents=referenced_documents if referenced_documents else None + referenced_documents=referenced_documents if referenced_documents else None, ) - + store_conversation_into_cache( configuration, user_id, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 076884427..7f7caa79b 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -21,7 +21,6 @@ ) from llama_stack_client.types.shared import ToolCall from llama_stack_client.types.shared.interleaved_content_item import TextContentItem -from pydantic import AnyUrl from app.database import get_session from app.endpoints.query import ( @@ -48,7 +47,7 @@ from models.config import Action from models.database.conversations import UserConversation from models.requests import QueryRequest -from models.responses import ForbiddenResponse, UnauthorizedResponse, ReferencedDocument +from models.responses import ForbiddenResponse, UnauthorizedResponse from utils.endpoints import ( check_configuration_loaded, create_referenced_documents_with_metadata, @@ -867,7 +866,9 @@ async def response_generator( completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") - referenced_documents = create_referenced_documents_with_metadata(summary, metadata_map) + referenced_documents = create_referenced_documents_with_metadata( + summary, metadata_map + ) cache_entry = CacheEntry( query=query_request.query, @@ -876,9 +877,11 @@ async def response_generator( model=model_id, started_at=started_at, completed_at=completed_at, - referenced_documents=referenced_documents if referenced_documents else None + referenced_documents=( + referenced_documents if referenced_documents else None + ), ) - + store_conversation_into_cache( configuration, user_id, diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index 73b7d47fc..0774ef1c8 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -215,11 +215,21 @@ def get( result = [] for conversation_entry in conversation_entries: - # Parse it back into an LLMResponse object + # Parse referenced_documents back into ReferencedDocument objects docs_data = conversation_entry[6] docs_obj = None if docs_data: - docs_obj = [ReferencedDocument.model_validate(doc) for doc in docs_data] + try: + docs_obj = [ + ReferencedDocument.model_validate(doc) for doc in docs_data + ] + except (ValueError, TypeError) as e: + logger.warning( + "Failed to deserialize referenced_documents for " + "conversation %s: %s", + conversation_id, + e, + ) cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -257,8 +267,19 @@ def insert_or_append( try: referenced_documents_json = None if cache_entry.referenced_documents: - docs_as_dicts = [doc.model_dump(mode='json') for doc in cache_entry.referenced_documents] - referenced_documents_json = json.dumps(docs_as_dicts) + try: + docs_as_dicts = [ + doc.model_dump(mode="json") + for doc in cache_entry.referenced_documents + ] + referenced_documents_json = json.dumps(docs_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize referenced_documents for " + "conversation %s: %s", + conversation_id, + e, + ) # the whole operation is run in one transaction with self.connection.cursor() as cursor: diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index 7002f4ec0..bf70355bd 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -216,8 +216,18 @@ def get( docs_json_str = conversation_entry[6] docs_obj = None if docs_json_str: - docs_data = json.loads(docs_json_str) - docs_obj = [ReferencedDocument.model_validate(doc) for doc in docs_data] + try: + docs_data = json.loads(docs_json_str) + docs_obj = [ + ReferencedDocument.model_validate(doc) for doc in docs_data + ] + except (json.JSONDecodeError, ValueError) as e: + logger.warning( + "Failed to deserialize referenced_documents for " + "conversation %s: %s", + conversation_id, + e, + ) cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -257,8 +267,19 @@ def insert_or_append( referenced_documents_json = None if cache_entry.referenced_documents: - docs_as_dicts = [doc.model_dump(mode='json') for doc in cache_entry.referenced_documents] - referenced_documents_json = json.dumps(docs_as_dicts) + try: + docs_as_dicts = [ + doc.model_dump(mode="json") + for doc in cache_entry.referenced_documents + ] + referenced_documents_json = json.dumps(docs_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize referenced_documents for " + "conversation %s: %s", + conversation_id, + e, + ) cursor.execute( self.INSERT_CONVERSATION_HISTORY_STATEMENT, diff --git a/src/models/cache_entry.py b/src/models/cache_entry.py index 86184894a..116372bbb 100644 --- a/src/models/cache_entry.py +++ b/src/models/cache_entry.py @@ -1,7 +1,6 @@ """Model for conversation history cache entry.""" -from pydantic import BaseModel, Field -from typing import List +from pydantic import BaseModel from models.responses import ReferencedDocument @@ -13,7 +12,7 @@ class CacheEntry(BaseModel): response: The response string provider: Provider identification model: Model identification - additional_kwargs: additional property to store data like referenced documents + referenced_documents: List of documents referenced by the response """ query: str @@ -22,4 +21,4 @@ class CacheEntry(BaseModel): model: str started_at: str completed_at: str - referenced_documents: List[ReferencedDocument] | None = None + referenced_documents: list[ReferencedDocument] | None = None diff --git a/src/models/responses.py b/src/models/responses.py index 662439bcf..ce5297454 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -161,6 +161,7 @@ class ToolCall(BaseModel): arguments: dict[str, Any] = Field(description="Arguments passed to the tool") result: Optional[dict[str, Any]] = Field(None, description="Result from the tool") + class ConversationData(BaseModel): """Model representing conversation data returned by cache list operations. @@ -174,6 +175,7 @@ class ConversationData(BaseModel): topic_summary: str | None last_message_timestamp: float + class ReferencedDocument(BaseModel): """Model representing a document referenced in generating a response. @@ -186,9 +188,7 @@ class ReferencedDocument(BaseModel): None, description="URL of the referenced document" ) - doc_title: str | None = Field( - None, description="Title of the referenced document" - ) + doc_title: str | None = Field(None, description="Title of the referenced document") class QueryResponse(BaseModel): diff --git a/tests/unit/app/endpoints/test_conversations_v2.py b/tests/unit/app/endpoints/test_conversations_v2.py index b5a975700..8467ed7ae 100644 --- a/tests/unit/app/endpoints/test_conversations_v2.py +++ b/tests/unit/app/endpoints/test_conversations_v2.py @@ -3,7 +3,6 @@ """Unit tests for the /conversations REST API endpoints.""" from unittest.mock import Mock -from pydantic import AnyUrl import pytest from fastapi import HTTPException, status @@ -75,7 +74,7 @@ def test_transform_message_without_documents(self) -> None: # referenced_documents is None by default ) transformed = transform_chat_message(entry) - + assistant_message = transformed["messages"][1] # Assert that the key is NOT present when the list is None @@ -83,8 +82,7 @@ def test_transform_message_without_documents(self) -> None: def test_transform_message_with_referenced_documents(self) -> None: """Test the transformation when referenced_documents are present.""" - docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] - + docs = [ReferencedDocument(doc_title="Test Doc", doc_url="http://example.com")] entry = CacheEntry( query="query", response="response", @@ -92,14 +90,13 @@ def test_transform_message_with_referenced_documents(self) -> None: model="model", started_at="2024-01-01T00:00:00Z", completed_at="2024-01-01T00:00:05Z", - referenced_documents=docs + referenced_documents=docs, ) transformed = transform_chat_message(entry) assistant_message = transformed["messages"][1] - + assert "referenced_documents" in assistant_message - ref_docs = assistant_message["referenced_documents"] assert len(ref_docs) == 1 assert ref_docs[0]["doc_title"] == "Test Doc" @@ -114,7 +111,7 @@ def test_transform_message_with_empty_referenced_documents(self) -> None: model="model", started_at="2024-01-01T00:00:00Z", completed_at="2024-01-01T00:00:05Z", - referenced_documents=[] # Explicitly empty + referenced_documents=[], # Explicitly empty ) transformed = transform_chat_message(entry) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 0cc942b43..51943d15d 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -167,6 +167,7 @@ def test_is_transcripts_disabled(setup_configuration, mocker) -> None: assert is_transcripts_enabled() is False, "Transcripts should be disabled" +# pylint: disable=too-many-locals async def _test_query_endpoint_handler( mocker, dummy_request: Request, store_transcript_to_file=False ) -> None: @@ -184,12 +185,16 @@ async def _test_query_endpoint_handler( store_transcript_to_file ) mocker.patch("app.endpoints.query.configuration", mock_config) - - mock_store_in_cache = mocker.patch("app.endpoints.query.store_conversation_into_cache") + + mock_store_in_cache = mocker.patch( + "app.endpoints.query.store_conversation_into_cache" + ) # Create mock referenced documents to simulate a successful RAG response mock_referenced_documents = [ - ReferencedDocument(doc_title="Test Doc 1", doc_url=AnyUrl("http://example.com/1")) + ReferencedDocument( + doc_title="Test Doc 1", doc_url=AnyUrl("http://example.com/1") + ) ] summary = TurnSummary( @@ -208,7 +213,12 @@ async def _test_query_endpoint_handler( mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id, mock_referenced_documents, TokenCounter()), + return_value=( + summary, + conversation_id, + mock_referenced_documents, + TokenCounter(), + ), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -237,11 +247,12 @@ async def _test_query_endpoint_handler( # Assert the response is as expected assert response.response == summary.llm_response assert response.conversation_id == conversation_id - + # Assert that mock was called and get the arguments mock_store_in_cache.assert_called_once() call_args = mock_store_in_cache.call_args[0] - # Extract CacheEntry object from the call arguments, it's the 4th argument from the func signature + # Extract CacheEntry object from the call arguments, + # it's the 4th argument from the func signature cached_entry = call_args[3] assert isinstance(cached_entry, CacheEntry) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 068395b07..8664678ce 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -209,6 +209,7 @@ async def test_streaming_query_endpoint_on_connection_error(mocker): assert response.media_type == "text/event-stream" +# pylint: disable=too-many-locals async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False): """Test the streaming query endpoint handler.""" mock_client = mocker.AsyncMock() @@ -297,8 +298,9 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) ), ] - mock_store_in_cache = mocker.patch("app.endpoints.streaming_query.store_conversation_into_cache") - + mock_store_in_cache = mocker.patch( + "app.endpoints.streaming_query.store_conversation_into_cache" + ) query = "What is OpenStack?" mocker.patch( "app.endpoints.streaming_query.retrieve_response", @@ -359,20 +361,23 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) referenced_documents = d["data"]["referenced_documents"] assert len(referenced_documents) == 2 assert referenced_documents[1]["doc_title"] == "Doc2" - + # Assert that mock was called and get the arguments mock_store_in_cache.assert_called_once() call_args = mock_store_in_cache.call_args[0] - # Extract CacheEntry object from the call arguments, it's the 4th argument from the func signature + # Extract CacheEntry object from the call arguments, + # it's the 4th argument from the func signature cached_entry = call_args[3] - + # Assert that the CacheEntry was constructed correctly assert isinstance(cached_entry, CacheEntry) assert cached_entry.response == "LLM answer" assert cached_entry.referenced_documents is not None assert len(cached_entry.referenced_documents) == 2 assert cached_entry.referenced_documents[0].doc_title == "Doc1" - assert str(cached_entry.referenced_documents[1].doc_url) == "https://example.com/doc2" + assert ( + str(cached_entry.referenced_documents[1].doc_url) == "https://example.com/doc2" + ) # Assert the store_transcript function is called if transcripts are enabled if store_transcript: diff --git a/tests/unit/cache/test_postgres_cache.py b/tests/unit/cache/test_postgres_cache.py index 9120dfa9b..18023f006 100644 --- a/tests/unit/cache/test_postgres_cache.py +++ b/tests/unit/cache/test_postgres_cache.py @@ -6,13 +6,11 @@ import psycopg2 -from pydantic import AnyUrl - from cache.cache_error import CacheError from cache.postgres_cache import PostgresCache from models.config import PostgreSQLDatabaseConfiguration -from models.cache_entry import CacheEntry, ReferencedDocument -from models.responses import ConversationData +from models.cache_entry import CacheEntry +from models.responses import ConversationData, ReferencedDocument from utils import suid @@ -382,7 +380,9 @@ def test_topic_summary_when_disconnected(postgres_cache_config_fixture, mocker): cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test", False) -def test_insert_and_get_with_referenced_documents(postgres_cache_config_fixture, mocker): +def test_insert_and_get_with_referenced_documents( + postgres_cache_config_fixture, mocker +): """Test that a CacheEntry with referenced_documents is stored and retrieved correctly.""" # prevent real connection to PG instance mock_connect = mocker.patch("psycopg2.connect") @@ -392,31 +392,43 @@ def test_insert_and_get_with_referenced_documents(postgres_cache_config_fixture, mock_cursor = mock_connection.cursor.return_value.__enter__.return_value # Create a CacheEntry with referenced documents - docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] + docs = [ReferencedDocument(doc_title="Test Doc", doc_url="http://example.com/")] entry_with_docs = CacheEntry( query="user message", response="AI message", - provider="foo", model="bar", - started_at="start_time", completed_at="end_time", - referenced_documents=docs + provider="foo", + model="bar", + started_at="start_time", + completed_at="end_time", + referenced_documents=docs, ) # Call the insert method cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_docs) - - insert_call = mock_cursor.execute.call_args_list[1] - sql_params = insert_call[0][1] + # Find the INSERT INTO cache(...) call + insert_calls = [ + c + for c in mock_cursor.execute.call_args_list + if isinstance(c[0][0], str) and "INSERT INTO cache(" in c[0][0] + ] + assert insert_calls, "INSERT call not found" + sql_params = insert_calls[-1][0][1] inserted_json_str = sql_params[-1] - + assert json.loads(inserted_json_str) == [ {"doc_url": "http://example.com/", "doc_title": "Test Doc"} ] # Simulate the database returning that data db_return_value = ( - "user message", "AI message", "foo", "bar", "start_time", "end_time", - [{"doc_url": "http://example.com/", "doc_title": "Test Doc"}] + "user message", + "AI message", + "foo", + "bar", + "start_time", + "end_time", + [{"doc_url": "http://example.com/", "doc_title": "Test Doc"}], ) mock_cursor.fetchall.return_value = [db_return_value] @@ -429,7 +441,9 @@ def test_insert_and_get_with_referenced_documents(postgres_cache_config_fixture, assert retrieved_entries[0].referenced_documents[0].doc_title == "Test Doc" -def test_insert_and_get_without_referenced_documents(postgres_cache_config_fixture, mocker): +def test_insert_and_get_without_referenced_documents( + postgres_cache_config_fixture, mocker +): """Test that a CacheEntry with no referenced_documents is handled correctly.""" mock_connect = mocker.patch("psycopg2.connect") cache = PostgresCache(postgres_cache_config_fixture) @@ -443,11 +457,16 @@ def test_insert_and_get_without_referenced_documents(postgres_cache_config_fixtu # Call the insert method cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_without_docs) - insert_call = mock_cursor.execute.call_args_list[1] - sql_params = insert_call[0][1] + insert_calls = [ + c + for c in mock_cursor.execute.call_args_list + if isinstance(c[0][0], str) and "INSERT INTO cache(" in c[0][0] + ] + assert insert_calls, "INSERT call not found" + sql_params = insert_calls[-1][0][1] assert sql_params[-1] is None - # 4. Simulate the database returning a row with None + # Simulate the database returning a row with None db_return_value = ( entry_without_docs.query, entry_without_docs.response, @@ -455,7 +474,7 @@ def test_insert_and_get_without_referenced_documents(postgres_cache_config_fixtu entry_without_docs.model, entry_without_docs.started_at, entry_without_docs.completed_at, - None # referenced_documents is None in the DB + None, # referenced_documents is None in the DB ) mock_cursor.fetchall.return_value = [db_return_value] diff --git a/tests/unit/cache/test_sqlite_cache.py b/tests/unit/cache/test_sqlite_cache.py index cbb825247..1f77ee9f1 100644 --- a/tests/unit/cache/test_sqlite_cache.py +++ b/tests/unit/cache/test_sqlite_cache.py @@ -6,11 +6,9 @@ import pytest -from pydantic import AnyUrl - from models.config import SQLiteDatabaseConfiguration -from models.cache_entry import CacheEntry, ReferencedDocument -from models.responses import ConversationData +from models.cache_entry import CacheEntry +from models.responses import ConversationData, ReferencedDocument from utils import suid from cache.cache_error import CacheError @@ -370,13 +368,15 @@ def test_insert_and_get_with_referenced_documents(tmpdir): cache = create_cache(tmpdir) # Create a CacheEntry with referenced documents - docs = [ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com"))] + docs = [ReferencedDocument(doc_title="Test Doc", doc_url="http://example.com")] entry_with_docs = CacheEntry( query="user message", response="AI message", - provider="foo", model="bar", - started_at="start_time", completed_at="end_time", - referenced_documents=docs + provider="foo", + model="bar", + started_at="start_time", + completed_at="end_time", + referenced_documents=docs, ) # Call the insert method @@ -396,7 +396,7 @@ def test_insert_and_get_without_referenced_documents(tmpdir): stored and retrieved with its referenced_documents attribute as None. """ cache = create_cache(tmpdir) - + # Use CacheEntry without referenced_documents entry_without_docs = cache_entry_1 @@ -407,4 +407,4 @@ def test_insert_and_get_without_referenced_documents(tmpdir): # Assert that the retrieved entry matches the original assert len(retrieved_entries) == 1 assert retrieved_entries[0] == entry_without_docs - assert retrieved_entries[0].referenced_documents is None \ No newline at end of file + assert retrieved_entries[0].referenced_documents is None