diff --git a/docs/openapi.json b/docs/openapi.json
index d036fa3f2..1d32805be 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1787,6 +1787,9 @@
         },
         "type": "array",
         "title": "Byok Rag"
+      },
+      "quota_handlers": {
+        "$ref": "#/components/schemas/QuotaHandlersConfiguration"
       }
     },
     "additionalProperties": false,
@@ -3590,6 +3593,103 @@
       }
     ]
   },
+  "QuotaHandlersConfiguration": {
+    "properties": {
+      "sqlite": {
+        "anyOf": [
+          {
+            "$ref": "#/components/schemas/SQLiteDatabaseConfiguration"
+          },
+          {
+            "type": "null"
+          }
+        ]
+      },
+      "postgres": {
+        "anyOf": [
+          {
+            "$ref": "#/components/schemas/PostgreSQLDatabaseConfiguration"
+          },
+          {
+            "type": "null"
+          }
+        ]
+      },
+      "limiters": {
+        "items": {
+          "$ref": "#/components/schemas/QuotaLimiterConfiguration"
+        },
+        "type": "array",
+        "title": "Limiters"
+      },
+      "scheduler": {
+        "$ref": "#/components/schemas/QuotaSchedulerConfiguration"
+      },
+      "enable_token_history": {
+        "type": "boolean",
+        "title": "Enable Token History",
+        "default": false
+      }
+    },
+    "additionalProperties": false,
+    "type": "object",
+    "title": "QuotaHandlersConfiguration",
+    "description": "Quota limiter configuration."
+  },
+  "QuotaLimiterConfiguration": {
+    "properties": {
+      "type": {
+        "type": "string",
+        "enum": [
+          "user_limiter",
+          "cluster_limiter"
+        ],
+        "title": "Type"
+      },
+      "name": {
+        "type": "string",
+        "title": "Name"
+      },
+      "initial_quota": {
+        "type": "integer",
+        "minimum": 0.0,
+        "title": "Initial Quota"
+      },
+      "quota_increase": {
+        "type": "integer",
+        "minimum": 0.0,
+        "title": "Quota Increase"
+      },
+      "period": {
+        "type": "string",
+        "title": "Period"
+      }
+    },
+    "additionalProperties": false,
+    "type": "object",
+    "required": [
+      "type",
+      "name",
+      "initial_quota",
+      "quota_increase",
+      "period"
+    ],
+    "title": "QuotaLimiterConfiguration",
+    "description": "Configuration for one quota limiter."
+  },
+  "QuotaSchedulerConfiguration": {
+    "properties": {
+      "period": {
+        "type": "integer",
+        "exclusiveMinimum": 0.0,
+        "title": "Period",
+        "default": 1
+      }
+    },
+    "type": "object",
+    "title": "QuotaSchedulerConfiguration",
+    "description": "Quota scheduler configuration."
+  },
   "RAGChunk": {
     "properties": {
       "content": {
@@ -3691,15 +3791,19 @@
         "description": "URL of the referenced document"
       },
       "doc_title": {
-        "type": "string",
+        "anyOf": [
+          {
+            "type": "string"
+          },
+          {
+            "type": "null"
+          }
+        ],
         "title": "Doc Title",
         "description": "Title of the referenced document"
       }
     },
     "type": "object",
-    "required": [
-      "doc_title"
-    ],
     "title": "ReferencedDocument",
     "description": "Model representing a document referenced in generating a response.\n\nAttributes:\n    doc_url: Url to the referenced doc.\n    doc_title: Title of the referenced doc."
   },
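For orientation, the three new quota schemas correspond to Pydantic models roughly like the sketch below. This is a hedged reconstruction from the generated JSON Schema alone — field names and constraints come from the diff, but the class layout and module are assumptions, not the repository's actual source.

```python
from typing import Literal

from pydantic import BaseModel, Field


class QuotaLimiterConfiguration(BaseModel):
    """One quota limiter, reconstructed from the schema above."""

    type: Literal["user_limiter", "cluster_limiter"]
    name: str
    initial_quota: int = Field(ge=0)       # "minimum": 0
    quota_increase: int = Field(ge=0)      # "minimum": 0
    period: str


class QuotaSchedulerConfiguration(BaseModel):
    """Quota scheduler configuration."""

    period: int = Field(default=1, gt=0)   # "exclusiveMinimum": 0


# Example payload that validates against QuotaLimiterConfiguration:
limiter = QuotaLimiterConfiguration(
    type="user_limiter",
    name="default",
    initial_quota=1000,
    quota_increase=100,
    period="30 minutes",
)
print(limiter.model_dump())
```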
diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py
index 7ce0ac7ca..edf74ed54 100644
--- a/src/app/endpoints/conversations_v2.py
+++ b/src/app/endpoints/conversations_v2.py
@@ -314,13 +314,19 @@ def check_conversation_existence(user_id: str, conversation_id: str) -> None:
 
 def transform_chat_message(entry: CacheEntry) -> dict[str, Any]:
     """Transform the message read from cache into format used by response payload."""
+    user_message = {"content": entry.query, "type": "user"}
+    assistant_message: dict[str, Any] = {"content": entry.response, "type": "assistant"}
+
+    # If referenced_documents exist on the entry, add them to the assistant message
+    if entry.referenced_documents is not None:
+        assistant_message["referenced_documents"] = [
+            doc.model_dump(mode="json") for doc in entry.referenced_documents
+        ]
+
     return {
         "provider": entry.provider,
         "model": entry.model,
-        "messages": [
-            {"content": entry.query, "type": "user"},
-            {"content": entry.response, "type": "assistant"},
-        ],
+        "messages": [user_message, assistant_message],
         "started_at": entry.started_at,
         "completed_at": entry.completed_at,
     }
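The `mode="json"` argument matters here: it coerces non-primitive field types such as `AnyUrl` to plain strings, so the resulting message dict is directly JSON-serializable. A minimal standalone illustration (the `Doc` model is a stand-in, not the project's `ReferencedDocument`):

```python
from pydantic import AnyUrl, BaseModel


class Doc(BaseModel):  # stand-in for models.responses.ReferencedDocument
    doc_url: AnyUrl | None = None
    doc_title: str | None = None


doc = Doc(doc_url="https://example.com/guide", doc_title="Guide")
print(doc.model_dump(mode="json"))  # {'doc_url': 'https://example.com/guide', 'doc_title': 'Guide'}
print(type(doc.model_dump()["doc_url"]))  # a URL object, not a str -- not JSON-safe as-is
```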
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 98be63c42..c02656eb8 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -31,6 +31,7 @@
 from authorization.middleware import authorize
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
+from models.cache_entry import CacheEntry
 from models.config import Action
 from models.database.conversations import UserConversation
 from models.requests import Attachment, QueryRequest
@@ -331,16 +332,22 @@ async def query_endpoint_handler(  # pylint: disable=R0914
     )
 
     completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
+
+    cache_entry = CacheEntry(
+        query=query_request.query,
+        response=summary.llm_response,
+        provider=provider_id,
+        model=model_id,
+        started_at=started_at,
+        completed_at=completed_at,
+        referenced_documents=referenced_documents if referenced_documents else None,
+    )
+
     store_conversation_into_cache(
         configuration,
         user_id,
         conversation_id,
-        provider_id,
-        model_id,
-        query_request.query,
-        summary.llm_response,
-        started_at,
-        completed_at,
+        cache_entry,
         _skip_userid_check,
         topic_summary,
     )
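The `referenced_documents if referenced_documents else None` expression normalizes an empty list to `None`, so the cache only stores the field when there is actually something to store. A quick self-contained illustration of the idiom:

```python
def normalize(docs: list[str] | None) -> list[str] | None:
    # Empty lists and None both collapse to None; non-empty lists pass through.
    return docs if docs else None


assert normalize([]) is None
assert normalize(None) is None
assert normalize(["doc-1"]) == ["doc-1"]
```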
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index d903469a7..7f7caa79b 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -43,12 +43,14 @@
 from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT
 import metrics
 from metrics.utils import update_llm_token_count_from_turn
+from models.cache_entry import CacheEntry
 from models.config import Action
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
 from models.responses import ForbiddenResponse, UnauthorizedResponse
 from utils.endpoints import (
     check_configuration_loaded,
+    create_referenced_documents_with_metadata,
     create_rag_chunks_dict,
     get_agent,
     get_system_prompt,
@@ -863,16 +865,28 @@ async def response_generator(
             )
 
             completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
+
+            referenced_documents = create_referenced_documents_with_metadata(
+                summary, metadata_map
+            )
+
+            cache_entry = CacheEntry(
+                query=query_request.query,
+                response=summary.llm_response,
+                provider=provider_id,
+                model=model_id,
+                started_at=started_at,
+                completed_at=completed_at,
+                referenced_documents=(
+                    referenced_documents if referenced_documents else None
+                ),
+            )
+
             store_conversation_into_cache(
                 configuration,
                 user_id,
                 conversation_id,
-                provider_id,
-                model_id,
-                query_request.query,
-                summary.llm_response,
-                started_at,
-                completed_at,
+                cache_entry,
                 _skip_userid_check,
                 topic_summary,
             )
diff --git a/src/cache/cache.py b/src/cache/cache.py
index 4cdab6307..8a6e6fc07 100644
--- a/src/cache/cache.py
+++ b/src/cache/cache.py
@@ -2,7 +2,8 @@
 
 from abc import ABC, abstractmethod
 
-from models.cache_entry import CacheEntry, ConversationData
+from models.cache_entry import CacheEntry
+from models.responses import ConversationData
 from utils.suid import check_suid
diff --git a/src/cache/in_memory_cache.py b/src/cache/in_memory_cache.py
index 1b6b4123f..388585b5a 100644
--- a/src/cache/in_memory_cache.py
+++ b/src/cache/in_memory_cache.py
@@ -1,8 +1,9 @@
 """In-memory cache implementation."""
 
 from cache.cache import Cache
-from models.cache_entry import CacheEntry, ConversationData
+from models.cache_entry import CacheEntry
 from models.config import InMemoryCacheConfig
+from models.responses import ConversationData
 from log import get_logger
 from utils.connection_decorator import connection
diff --git a/src/cache/noop_cache.py b/src/cache/noop_cache.py
index fcd20f368..40b1bd144 100644
--- a/src/cache/noop_cache.py
+++ b/src/cache/noop_cache.py
@@ -1,7 +1,8 @@
 """No-operation cache implementation."""
 
 from cache.cache import Cache
-from models.cache_entry import CacheEntry, ConversationData
+from models.cache_entry import CacheEntry
+from models.responses import ConversationData
 from log import get_logger
 from utils.connection_decorator import connection
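Moving `ConversationData` out of `models.cache_entry` into `models.responses` avoids a circular import once `CacheEntry` itself starts importing `ReferencedDocument` from `models.responses`. The interface these backends implement looks roughly like the sketch below; method names and argument order are taken from the call sites in this diff, and the exact signatures in `src/cache/cache.py` may differ.

```python
from abc import ABC, abstractmethod

from pydantic import BaseModel


class CacheEntry(BaseModel):  # trimmed stand-in for models.cache_entry.CacheEntry
    query: str
    response: str


class Cache(ABC):
    """Hedged sketch of the backend interface the cache modules implement."""

    @abstractmethod
    def get(
        self, user_id: str, conversation_id: str, skip_user_id_check: bool = False
    ) -> list[CacheEntry]:
        """Return the stored history for one conversation."""

    @abstractmethod
    def insert_or_append(
        self,
        user_id: str,
        conversation_id: str,
        cache_entry: CacheEntry,
        skip_user_id_check: bool = False,
    ) -> None:
        """Append one entry to a conversation, creating it if needed."""
```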
diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py
index 0881b66c6..0774ef1c8 100644
--- a/src/cache/postgres_cache.py
+++ b/src/cache/postgres_cache.py
@@ -1,11 +1,13 @@
 """PostgreSQL cache implementation."""
 
+import json
 import psycopg2
 
 from cache.cache import Cache
 from cache.cache_error import CacheError
-from models.cache_entry import CacheEntry, ConversationData
+from models.cache_entry import CacheEntry
 from models.config import PostgreSQLDatabaseConfiguration
+from models.responses import ConversationData, ReferencedDocument
 from log import get_logger
 from utils.connection_decorator import connection
@@ -18,17 +20,18 @@ class PostgresCache(Cache):
     The cache itself lives stored in following table:
 
     ```
-         Column      |              Type              | Nullable |
-    -----------------+--------------------------------+----------+
-     user_id         | text                           | not null |
-     conversation_id | text                           | not null |
-     created_at      | timestamp without time zone    | not null |
-     started_at      | text                           |          |
-     completed_at    | text                           |          |
-     query           | text                           |          |
-     response        | text                           |          |
-     provider        | text                           |          |
-     model           | text                           |          |
+          Column           |              Type              | Nullable |
+    -----------------------+--------------------------------+----------+
+     user_id               | text                           | not null |
+     conversation_id       | text                           | not null |
+     created_at            | timestamp without time zone    | not null |
+     started_at            | text                           |          |
+     completed_at          | text                           |          |
+     query                 | text                           |          |
+     response              | text                           |          |
+     provider              | text                           |          |
+     model                 | text                           |          |
+     referenced_documents  | jsonb                          |          |
     Indexes:
         "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
         "timestamps" btree (created_at)
     ```
@@ -37,15 +40,16 @@ class PostgresCache(Cache):
 
     CREATE_CACHE_TABLE = """
         CREATE TABLE IF NOT EXISTS cache (
-            user_id         text NOT NULL,
-            conversation_id text NOT NULL,
-            created_at      timestamp NOT NULL,
-            started_at      text,
-            completed_at    text,
-            query           text,
-            response        text,
-            provider        text,
-            model           text,
+            user_id               text NOT NULL,
+            conversation_id       text NOT NULL,
+            created_at            timestamp NOT NULL,
+            started_at            text,
+            completed_at          text,
+            query                 text,
+            response              text,
+            provider              text,
+            model                 text,
+            referenced_documents  jsonb,
             PRIMARY KEY(user_id, conversation_id, created_at)
         );
         """
@@ -66,7 +70,7 @@ class PostgresCache(Cache):
 
     SELECT_CONVERSATION_HISTORY_STATEMENT = """
-        SELECT query, response, provider, model, started_at, completed_at
+        SELECT query, response, provider, model, started_at, completed_at, referenced_documents
         FROM cache
         WHERE user_id=%s AND conversation_id=%s
         ORDER BY created_at
@@ -74,8 +78,8 @@
 
     INSERT_CONVERSATION_HISTORY_STATEMENT = """
         INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
-                          query, response, provider, model)
-        VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s)
+                          query, response, provider, model, referenced_documents)
+        VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s)
         """
 
     QUERY_CACHE_SIZE = """
@@ -211,6 +215,21 @@ def get(
         result = []
 
         for conversation_entry in conversation_entries:
+            # Parse referenced_documents back into ReferencedDocument objects
+            docs_data = conversation_entry[6]
+            docs_obj = None
+            if docs_data:
+                try:
+                    docs_obj = [
+                        ReferencedDocument.model_validate(doc) for doc in docs_data
+                    ]
+                except (ValueError, TypeError) as e:
+                    logger.warning(
+                        "Failed to deserialize referenced_documents for "
+                        "conversation %s: %s",
+                        conversation_id,
+                        e,
+                    )
             cache_entry = CacheEntry(
                 query=conversation_entry[0],
                 response=conversation_entry[1],
@@ -218,6 +237,7 @@ def get(
                 model=conversation_entry[3],
                 started_at=conversation_entry[4],
                 completed_at=conversation_entry[5],
+                referenced_documents=docs_obj,
             )
             result.append(cache_entry)
@@ -245,6 +265,22 @@
             raise CacheError("insert_or_append: cache is disconnected")
 
         try:
+            referenced_documents_json = None
+            if cache_entry.referenced_documents:
+                try:
+                    docs_as_dicts = [
+                        doc.model_dump(mode="json")
+                        for doc in cache_entry.referenced_documents
+                    ]
+                    referenced_documents_json = json.dumps(docs_as_dicts)
+                except (TypeError, ValueError) as e:
+                    logger.warning(
+                        "Failed to serialize referenced_documents for "
+                        "conversation %s: %s",
+                        conversation_id,
+                        e,
+                    )
+
             # the whole operation is run in one transaction
             with self.connection.cursor() as cursor:
                 cursor.execute(
@@ -258,6 +294,7 @@
                         cache_entry.response,
                         cache_entry.provider,
                         cache_entry.model,
+                        referenced_documents_json,
                     ),
                 )
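Note the asymmetry between the two directions here: on insert the documents are serialized with `json.dumps()` into the `jsonb` column, but on read no `json.loads()` is needed, because psycopg2 deserializes `json`/`jsonb` columns into Python lists and dicts automatically. A hedged sketch of that round trip against a throwaway table (the DSN is a placeholder, not this project's configuration):

```python
import json

import psycopg2

# Placeholder DSN -- adjust for a real PostgreSQL instance before running.
conn = psycopg2.connect("dbname=test user=test password=test host=localhost")
with conn, conn.cursor() as cur:
    cur.execute("CREATE TEMP TABLE demo (docs jsonb)")
    docs = [{"doc_url": "https://example.com/", "doc_title": "Example"}]
    cur.execute("INSERT INTO demo (docs) VALUES (%s)", (json.dumps(docs),))
    cur.execute("SELECT docs FROM demo")
    (row,) = cur.fetchone()
    assert row == docs  # already a Python list of dicts; no json.loads() needed
```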
diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py
index b8a91fac9..bf70355bd 100644
--- a/src/cache/sqlite_cache.py
+++ b/src/cache/sqlite_cache.py
@@ -3,11 +3,13 @@
 
 from time import time
 import sqlite3
+import json
 
 from cache.cache import Cache
 from cache.cache_error import CacheError
-from models.cache_entry import CacheEntry, ConversationData
+from models.cache_entry import CacheEntry
 from models.config import SQLiteDatabaseConfiguration
+from models.responses import ConversationData, ReferencedDocument
 from log import get_logger
 from utils.connection_decorator import connection
@@ -20,17 +22,18 @@ class SQLiteCache(Cache):
     The cache itself is stored in following table:
 
     ```
-         Column      |            Type             | Nullable |
-    -----------------+-----------------------------+----------+
-     user_id         | text                        | not null |
-     conversation_id | text                        | not null |
-     created_at      | int                         | not null |
-     started_at      | text                        |          |
-     completed_at    | text                        |          |
-     query           | text                        |          |
-     response        | text                        |          |
-     provider        | text                        |          |
-     model           | text                        |          |
+          Column           |            Type             | Nullable |
+    -----------------------+-----------------------------+----------+
+     user_id               | text                        | not null |
+     conversation_id       | text                        | not null |
+     created_at            | int                         | not null |
+     started_at            | text                        |          |
+     completed_at          | text                        |          |
+     query                 | text                        |          |
+     response              | text                        |          |
+     provider              | text                        |          |
+     model                 | text                        |          |
+     referenced_documents  | text                        |          |
     Indexes:
         "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
         "cache_key_key" UNIQUE CONSTRAINT, btree (key)
     ```
@@ -41,15 +44,16 @@ class SQLiteCache(Cache):
 
     CREATE_CACHE_TABLE = """
         CREATE TABLE IF NOT EXISTS cache (
-            user_id         text NOT NULL,
-            conversation_id text NOT NULL,
-            created_at      int NOT NULL,
-            started_at      text,
-            completed_at    text,
-            query           text,
-            response        text,
-            provider        text,
-            model           text,
+            user_id               text NOT NULL,
+            conversation_id       text NOT NULL,
+            created_at            int NOT NULL,
+            started_at            text,
+            completed_at          text,
+            query                 text,
+            response              text,
+            provider              text,
+            model                 text,
+            referenced_documents  text,
             PRIMARY KEY(user_id, conversation_id, created_at)
         );
         """
@@ -70,7 +74,7 @@ class SQLiteCache(Cache):
 
     SELECT_CONVERSATION_HISTORY_STATEMENT = """
-        SELECT query, response, provider, model, started_at, completed_at
+        SELECT query, response, provider, model, started_at, completed_at, referenced_documents
         FROM cache
         WHERE user_id=? AND conversation_id=?
        ORDER BY created_at
@@ -78,8 +82,8 @@
 
     INSERT_CONVERSATION_HISTORY_STATEMENT = """
         INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
-                          query, response, provider, model)
-        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                          query, response, provider, model, referenced_documents)
+        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
         """
 
     QUERY_CACHE_SIZE = """
@@ -209,6 +213,21 @@ def get(
         result = []
 
         for conversation_entry in conversation_entries:
+            docs_json_str = conversation_entry[6]
+            docs_obj = None
+            if docs_json_str:
+                try:
+                    docs_data = json.loads(docs_json_str)
+                    docs_obj = [
+                        ReferencedDocument.model_validate(doc) for doc in docs_data
+                    ]
+                except (json.JSONDecodeError, ValueError) as e:
+                    logger.warning(
+                        "Failed to deserialize referenced_documents for "
+                        "conversation %s: %s",
+                        conversation_id,
+                        e,
+                    )
             cache_entry = CacheEntry(
                 query=conversation_entry[0],
                 response=conversation_entry[1],
@@ -216,6 +235,7 @@ def get(
                 model=conversation_entry[3],
                 started_at=conversation_entry[4],
                 completed_at=conversation_entry[5],
+                referenced_documents=docs_obj,
             )
             result.append(cache_entry)
@@ -244,6 +264,23 @@
         cursor = self.connection.cursor()
         current_time = time()
+
+        referenced_documents_json = None
+        if cache_entry.referenced_documents:
+            try:
+                docs_as_dicts = [
+                    doc.model_dump(mode="json")
+                    for doc in cache_entry.referenced_documents
+                ]
+                referenced_documents_json = json.dumps(docs_as_dicts)
+            except (TypeError, ValueError) as e:
+                logger.warning(
+                    "Failed to serialize referenced_documents for "
+                    "conversation %s: %s",
+                    conversation_id,
+                    e,
+                )
+
         cursor.execute(
             self.INSERT_CONVERSATION_HISTORY_STATEMENT,
             (
@@ -256,6 +293,7 @@
                 cache_entry.response,
                 cache_entry.provider,
                 cache_entry.model,
+                referenced_documents_json,
             ),
         )
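SQLite has no native JSON column type, so here the documents land in a plain `text` column and the read path must `json.loads()` the string back, unlike the PostgreSQL backend. A self-contained round trip with the standard library:

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE demo (docs text)")
docs = [{"doc_url": "https://example.com/", "doc_title": "Example"}]
conn.execute("INSERT INTO demo (docs) VALUES (?)", (json.dumps(docs),))
(stored,) = conn.execute("SELECT docs FROM demo").fetchone()
assert isinstance(stored, str)     # comes back as a raw string...
assert json.loads(stored) == docs  # ...so the caller must deserialize it
```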
"mock_username", False, "mock_token") @@ -59,6 +59,68 @@ def test_transform_message() -> None: assert message2["content"] == "response" +class TestTransformChatMessage: + """Test cases for the transform_chat_message utility function.""" + + def test_transform_message_without_documents(self) -> None: + """Test the transformation when no referenced_documents are present.""" + entry = CacheEntry( + query="query", + response="response", + provider="provider", + model="model", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + # referenced_documents is None by default + ) + transformed = transform_chat_message(entry) + + assistant_message = transformed["messages"][1] + + # Assert that the key is NOT present when the list is None + assert "referenced_documents" not in assistant_message + + def test_transform_message_with_referenced_documents(self) -> None: + """Test the transformation when referenced_documents are present.""" + docs = [ReferencedDocument(doc_title="Test Doc", doc_url="http://example.com")] + entry = CacheEntry( + query="query", + response="response", + provider="provider", + model="model", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + referenced_documents=docs, + ) + + transformed = transform_chat_message(entry) + assistant_message = transformed["messages"][1] + + assert "referenced_documents" in assistant_message + ref_docs = assistant_message["referenced_documents"] + assert len(ref_docs) == 1 + assert ref_docs[0]["doc_title"] == "Test Doc" + assert str(ref_docs[0]["doc_url"]) == "http://example.com/" + + def test_transform_message_with_empty_referenced_documents(self) -> None: + """Test the transformation when referenced_documents is an empty list.""" + entry = CacheEntry( + query="query", + response="response", + provider="provider", + model="model", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + referenced_documents=[], # Explicitly empty + ) + + transformed = transform_chat_message(entry) + assistant_message = transformed["messages"][1] + + assert "referenced_documents" in assistant_message + assert assistant_message["referenced_documents"] == [] + + @pytest.fixture def mock_configuration(): """Mock configuration with conversation cache.""" diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index b0d89351d..51943d15d 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -30,6 +30,7 @@ ) from authorization.resolvers import NoopRolesResolver from configuration import AppConfig +from models.cache_entry import CacheEntry from models.config import Action, ModelContextProtocolServer from models.database.conversations import UserConversation from models.requests import Attachment, QueryRequest @@ -166,6 +167,7 @@ def test_is_transcripts_disabled(setup_configuration, mocker) -> None: assert is_transcripts_enabled() is False, "Transcripts should be disabled" +# pylint: disable=too-many-locals async def _test_query_endpoint_handler( mocker, dummy_request: Request, store_transcript_to_file=False ) -> None: @@ -184,6 +186,17 @@ async def _test_query_endpoint_handler( ) mocker.patch("app.endpoints.query.configuration", mock_config) + mock_store_in_cache = mocker.patch( + "app.endpoints.query.store_conversation_into_cache" + ) + + # Create mock referenced documents to simulate a successful RAG response + mock_referenced_documents = [ + ReferencedDocument( + doc_title="Test Doc 1", 
diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py
index e9246b714..de2e9bec7 100644
--- a/src/utils/endpoints.py
+++ b/src/utils/endpoints.py
@@ -232,12 +232,7 @@ def store_conversation_into_cache(
     config: AppConfig,
     user_id: str,
     conversation_id: str,
-    provider_id: str,
-    model_id: str,
-    query: str,
-    response: str,
-    started_at: str,
-    completed_at: str,
+    cache_entry: CacheEntry,
     _skip_userid_check: bool,
     topic_summary: str | None,
 ) -> None:
@@ -247,14 +242,6 @@ def store_conversation_into_cache(
         if cache is None:
             logger.warning("Conversation cache configured but not initialized")
             return
-        cache_entry = CacheEntry(
-            query=query,
-            response=response,
-            provider=provider_id,
-            model=model_id,
-            started_at=started_at,
-            completed_at=completed_at,
-        )
         cache.insert_or_append(
             user_id, conversation_id, cache_entry, _skip_userid_check
         )
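With the refactored signature, callers build the `CacheEntry` themselves and pass a single object instead of six positional strings, which is exactly what both endpoints above now do. A runnable sketch of the new call shape, using trimmed stand-ins for the project's types (the stub below is illustrative, not the real implementation):

```python
from pydantic import BaseModel


class CacheEntry(BaseModel):  # trimmed stand-in for models.cache_entry.CacheEntry
    query: str
    response: str
    provider: str
    model: str
    started_at: str
    completed_at: str


def store_conversation_into_cache(  # stub with the new, narrower signature
    config: object,
    user_id: str,
    conversation_id: str,
    cache_entry: CacheEntry,
    _skip_userid_check: bool,
    topic_summary: str | None,
) -> None:
    print(f"caching {cache_entry.query!r} for {user_id}/{conversation_id}")


store_conversation_into_cache(
    None,  # AppConfig in the real code
    "user-1",
    "conv-1",
    CacheEntry(
        query="What is OpenStack?",
        response="OpenStack is ...",
        provider="p",
        model="m",
        started_at="2024-01-01T00:00:00Z",
        completed_at="2024-01-01T00:00:05Z",
    ),
    False,  # _skip_userid_check
    None,   # topic_summary
)
```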
diff --git a/tests/unit/app/endpoints/test_conversations_v2.py b/tests/unit/app/endpoints/test_conversations_v2.py
index 0448c2206..8467ed7ae 100644
--- a/tests/unit/app/endpoints/test_conversations_v2.py
+++ b/tests/unit/app/endpoints/test_conversations_v2.py
@@ -14,7 +14,7 @@
 )
 from models.cache_entry import CacheEntry
 from models.requests import ConversationUpdateRequest
-from models.responses import ConversationUpdateResponse
+from models.responses import ConversationUpdateResponse, ReferencedDocument
 from tests.unit.utils.auth_helpers import mock_authorization_resolvers
 
 MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token")
@@ -59,6 +59,68 @@ def test_transform_message() -> None:
     assert message2["content"] == "response"
 
 
+class TestTransformChatMessage:
+    """Test cases for the transform_chat_message utility function."""
+
+    def test_transform_message_without_documents(self) -> None:
+        """Test the transformation when no referenced_documents are present."""
+        entry = CacheEntry(
+            query="query",
+            response="response",
+            provider="provider",
+            model="model",
+            started_at="2024-01-01T00:00:00Z",
+            completed_at="2024-01-01T00:00:05Z",
+            # referenced_documents is None by default
+        )
+        transformed = transform_chat_message(entry)
+
+        assistant_message = transformed["messages"][1]
+
+        # Assert that the key is NOT present when the list is None
+        assert "referenced_documents" not in assistant_message
+
+    def test_transform_message_with_referenced_documents(self) -> None:
+        """Test the transformation when referenced_documents are present."""
+        docs = [ReferencedDocument(doc_title="Test Doc", doc_url="http://example.com")]
+        entry = CacheEntry(
+            query="query",
+            response="response",
+            provider="provider",
+            model="model",
+            started_at="2024-01-01T00:00:00Z",
+            completed_at="2024-01-01T00:00:05Z",
+            referenced_documents=docs,
+        )
+
+        transformed = transform_chat_message(entry)
+        assistant_message = transformed["messages"][1]
+
+        assert "referenced_documents" in assistant_message
+        ref_docs = assistant_message["referenced_documents"]
+        assert len(ref_docs) == 1
+        assert ref_docs[0]["doc_title"] == "Test Doc"
+        assert str(ref_docs[0]["doc_url"]) == "http://example.com/"
+
+    def test_transform_message_with_empty_referenced_documents(self) -> None:
+        """Test the transformation when referenced_documents is an empty list."""
+        entry = CacheEntry(
+            query="query",
+            response="response",
+            provider="provider",
+            model="model",
+            started_at="2024-01-01T00:00:00Z",
+            completed_at="2024-01-01T00:00:05Z",
+            referenced_documents=[],  # Explicitly empty
+        )
+
+        transformed = transform_chat_message(entry)
+        assistant_message = transformed["messages"][1]
+
+        assert "referenced_documents" in assistant_message
+        assert assistant_message["referenced_documents"] == []
+
+
 @pytest.fixture
 def mock_configuration():
     """Mock configuration with conversation cache."""
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index b0d89351d..51943d15d 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -30,6 +30,7 @@
 )
 from authorization.resolvers import NoopRolesResolver
 from configuration import AppConfig
+from models.cache_entry import CacheEntry
 from models.config import Action, ModelContextProtocolServer
 from models.database.conversations import UserConversation
 from models.requests import Attachment, QueryRequest
@@ -166,6 +167,7 @@ def test_is_transcripts_disabled(setup_configuration, mocker) -> None:
     assert is_transcripts_enabled() is False, "Transcripts should be disabled"
 
 
+# pylint: disable=too-many-locals
 async def _test_query_endpoint_handler(
     mocker, dummy_request: Request, store_transcript_to_file=False
 ) -> None:
@@ -184,6 +186,17 @@ async def _test_query_endpoint_handler(
     )
     mocker.patch("app.endpoints.query.configuration", mock_config)
 
+    mock_store_in_cache = mocker.patch(
+        "app.endpoints.query.store_conversation_into_cache"
+    )
+
+    # Create mock referenced documents to simulate a successful RAG response
+    mock_referenced_documents = [
+        ReferencedDocument(
+            doc_title="Test Doc 1", doc_url=AnyUrl("http://example.com/1")
+        )
+    ]
+
     summary = TurnSummary(
         llm_response="LLM answer",
         tool_calls=[
@@ -197,11 +210,15 @@ async def _test_query_endpoint_handler(
     )
     conversation_id = "00000000-0000-0000-0000-000000000000"
     query = "What is OpenStack?"
-    referenced_documents: list[ReferencedDocument] = []
 
     mocker.patch(
         "app.endpoints.query.retrieve_response",
-        return_value=(summary, conversation_id, referenced_documents, TokenCounter()),
+        return_value=(
+            summary,
+            conversation_id,
+            mock_referenced_documents,
+            TokenCounter(),
+        ),
     )
     mocker.patch(
         "app.endpoints.query.select_model_and_provider_id",
@@ -231,6 +248,19 @@ async def _test_query_endpoint_handler(
     assert response.response == summary.llm_response
     assert response.conversation_id == conversation_id
 
+    # Assert that mock was called and get the arguments
+    mock_store_in_cache.assert_called_once()
+    call_args = mock_store_in_cache.call_args[0]
+    # Extract CacheEntry object from the call arguments,
+    # it's the 4th argument from the func signature
+    cached_entry = call_args[3]
+
+    assert isinstance(cached_entry, CacheEntry)
+    assert cached_entry.response == "LLM answer"
+    assert cached_entry.referenced_documents is not None
+    assert len(cached_entry.referenced_documents) == 1
+    assert cached_entry.referenced_documents[0].doc_title == "Test Doc 1"
+
     # Note: metrics are now handled inside extract_and_update_token_metrics() which is mocked
 
     # Assert the store_transcript function is called if transcripts are enabled
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index 3bd12816a..8664678ce 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -51,6 +51,7 @@
 from authorization.resolvers import NoopRolesResolver
 
 from constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT
+from models.cache_entry import CacheEntry
 from models.config import ModelContextProtocolServer, Action
 from models.requests import QueryRequest, Attachment
 from models.responses import RAGChunk
@@ -208,6 +209,7 @@ async def test_streaming_query_endpoint_on_connection_error(mocker):
     assert response.media_type == "text/event-stream"
 
 
+# pylint: disable=too-many-locals
 async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False):
     """Test the streaming query endpoint handler."""
     mock_client = mocker.AsyncMock()
@@ -283,6 +285,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False):
                 role="assistant",
                 content=[TextContentItem(text="LLM answer", type="text")],
                 stop_reason="end_of_turn",
+                tool_calls=[],
             ),
             session_id="test_session_id",
             started_at=datetime.now(),
@@ -295,6 +298,9 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False):
         ),
     ]
 
+    mock_store_in_cache = mocker.patch(
+        "app.endpoints.streaming_query.store_conversation_into_cache"
+    )
     query = "What is OpenStack?"
 
     mocker.patch(
         "app.endpoints.streaming_query.retrieve_response",
@@ -356,6 +362,23 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False):
     assert len(referenced_documents) == 2
     assert referenced_documents[1]["doc_title"] == "Doc2"
 
+    # Assert that mock was called and get the arguments
+    mock_store_in_cache.assert_called_once()
+    call_args = mock_store_in_cache.call_args[0]
+    # Extract CacheEntry object from the call arguments,
+    # it's the 4th argument from the func signature
+    cached_entry = call_args[3]
+
+    # Assert that the CacheEntry was constructed correctly
+    assert isinstance(cached_entry, CacheEntry)
+    assert cached_entry.response == "LLM answer"
+    assert cached_entry.referenced_documents is not None
+    assert len(cached_entry.referenced_documents) == 2
+    assert cached_entry.referenced_documents[0].doc_title == "Doc1"
+    assert (
+        str(cached_entry.referenced_documents[1].doc_url) == "https://example.com/doc2"
+    )
+
     # Assert the store_transcript function is called if transcripts are enabled
     if store_transcript:
         mock_transcript.assert_called_once_with(
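The `call_args[0]` indexing used in these assertions is plain `unittest.mock` behavior: index 0 is the tuple of positional arguments, so the fourth positional argument (`[3]`) is the `CacheEntry`. A minimal standalone version of the pattern:

```python
from unittest.mock import MagicMock

store = MagicMock()
store("config", "user-1", "conv-1", {"entry": 1}, False, None)

positional_args = store.call_args[0]       # tuple of positional arguments
assert positional_args[3] == {"entry": 1}  # 4th positional arg, as in the tests above
```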
+ """ + cache = create_cache(tmpdir) + + # Create a CacheEntry with referenced documents + docs = [ReferencedDocument(doc_title="Test Doc", doc_url="http://example.com")] + entry_with_docs = CacheEntry( + query="user message", + response="AI message", + provider="foo", + model="bar", + started_at="start_time", + completed_at="end_time", + referenced_documents=docs, + ) + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_docs) + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_with_docs + assert retrieved_entries[0].referenced_documents is not None + assert retrieved_entries[0].referenced_documents[0].doc_title == "Test Doc" + + +def test_insert_and_get_without_referenced_documents(tmpdir): + """ + Test that a CacheEntry without referenced_documents is correctly + stored and retrieved with its referenced_documents attribute as None. + """ + cache = create_cache(tmpdir) + + # Use CacheEntry without referenced_documents + entry_without_docs = cache_entry_1 + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_without_docs) + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_without_docs + assert retrieved_entries[0].referenced_documents is None