diff --git a/docs/openapi.json b/docs/openapi.json index 911cf83b9..c9cf4fbc7 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -744,7 +744,9 @@ }, "conversations": [ { - "conversation_id": "123e4567-e89b-12d3-a456-426614174000" + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "topic_summary": "This is a topic summary", + "last_message_timestamp": 1704067200.0 } ] } @@ -1419,6 +1421,37 @@ "title": "ConversationCacheConfiguration", "description": "Conversation cache configuration." }, + "ConversationData": { + "properties": { + "conversation_id": { + "type": "string", + "title": "Conversation Id" + }, + "topic_summary": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Topic Summary" + }, + "last_message_timestamp": { + "type": "number", + "title": "Last Message Timestamp" + } + }, + "type": "object", + "required": [ + "conversation_id", + "topic_summary", + "last_message_timestamp" + ], + "title": "ConversationData", + "description": "Model representing conversation data returned by cache list operations.\n\nAttributes:\n conversation_id: The conversation ID\n topic_summary: The topic summary for the conversation (can be None)\n last_message_timestamp: The timestamp of the last message in the conversation" + }, "ConversationDeleteResponse": { "properties": { "conversation_id": { @@ -1536,6 +1569,21 @@ "openai", "gemini" ] + }, + "topic_summary": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Topic Summary", + "description": "Topic summary for the conversation", + "examples": [ + "Openshift Microservices Deployment Strategies" + ] } }, "type": "object", @@ -1543,7 +1591,7 @@ "conversation_id" ], "title": "ConversationDetails", - "description": "Model representing the details of a user conversation.\n\nAttributes:\n conversation_id: The conversation ID (UUID).\n created_at: When the conversation was created.\n last_message_at: When the last message was sent.\n message_count: Number of user messages in the conversation.\n last_used_model: The last model used for the conversation.\n last_used_provider: The provider of the last used model.\n\nExample:\n ```python\n conversation = ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\"\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n )\n ```" + "description": "Model representing the details of a user conversation.\n\nAttributes:\n conversation_id: The conversation ID (UUID).\n created_at: When the conversation was created.\n last_message_at: When the last message was sent.\n message_count: Number of user messages in the conversation.\n last_used_model: The last model used for the conversation.\n last_used_provider: The provider of the last used model.\n topic_summary: The topic summary for the conversation.\n\nExample:\n ```python\n conversation = ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\"\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n topic_summary=\"Openshift Microservices Deployment Strategies\",\n )\n ```" }, "ConversationResponse": { "properties": { @@ -1604,7 +1652,7 @@ "conversations" ], "title": "ConversationsListResponse", - "description": "Model representing a response for listing conversations of a 
user.\n\nAttributes:\n conversations: List of conversation details associated with the user.\n\nExample:\n ```python\n conversations_list = ConversationsListResponse(\n conversations=[\n ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\",\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n ),\n ConversationDetails(\n conversation_id=\"456e7890-e12b-34d5-a678-901234567890\"\n created_at=\"2024-01-01T01:00:00Z\",\n message_count=2,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n )\n ]\n )\n ```", + "description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation details associated with the user.\n\nExample:\n ```python\n conversations_list = ConversationsListResponse(\n conversations=[\n ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\",\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n topic_summary=\"Openshift Microservices Deployment Strategies\",\n ),\n ConversationDetails(\n conversation_id=\"456e7890-e12b-34d5-a678-901234567890\"\n created_at=\"2024-01-01T01:00:00Z\",\n message_count=2,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n topic_summary=\"RHDH Purpose Summary\",\n )\n ]\n )\n ```", "examples": [ { "conversations": [ @@ -1614,14 +1662,16 @@ "last_message_at": "2024-01-01T00:05:00Z", "last_used_model": "gemini/gemini-2.0-flash", "last_used_provider": "gemini", - "message_count": 5 + "message_count": 5, + "topic_summary": "Openshift Microservices Deployment Strategies" }, { "conversation_id": "456e7890-e12b-34d5-a678-901234567890", "created_at": "2024-01-01T01:00:00Z", "last_used_model": "gemini/gemini-2.5-flash", "last_used_provider": "gemini", - "message_count": 2 + "message_count": 2, + "topic_summary": "RHDH Purpose Summary" } ] } @@ -1631,7 +1681,7 @@ "properties": { "conversations": { "items": { - "type": "string" + "$ref": "#/components/schemas/ConversationData" }, "type": "array", "title": "Conversations" @@ -1642,7 +1692,7 @@ "conversations" ], "title": "ConversationsListResponseV2", - "description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation IDs associated with the user." + "description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation data associated with the user." 
}, "CustomProfile": { "properties": { diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index bfb951296..214e171ac 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -214,6 +214,7 @@ async def get_conversations_list_endpoint_handler( message_count=conv.message_count, last_used_model=conv.last_used_model, last_used_provider=conv.last_used_provider, + topic_summary=conv.topic_summary, ) for conv in user_conversations ] diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py index 5033b5e5f..16932162b 100644 --- a/src/app/endpoints/conversations_v2.py +++ b/src/app/endpoints/conversations_v2.py @@ -83,6 +83,8 @@ "conversations": [ { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "topic_summary": "This is a topic summary", + "last_message_timestamp": 1704067200.0, } ] } @@ -102,6 +104,8 @@ async def get_conversations_list_endpoint_handler( logger.info("Retrieving conversations for user %s", user_id) + skip_userid_check = auth[2] + if configuration.conversation_cache is None: logger.warning("Converastion cache is not configured") raise HTTPException( @@ -112,7 +116,7 @@ }, ) - conversations = configuration.conversation_cache.list(user_id, False) + conversations = configuration.conversation_cache.list(user_id, skip_userid_check) logger.info("Conversations for user %s: %s", user_id, len(conversations)) return ConversationsListResponseV2(conversations=conversations) @@ -132,6 +136,8 @@ async def get_conversation_endpoint_handler( user_id = auth[0] logger.info("Retrieving conversation %s for user %s", conversation_id, user_id) + skip_userid_check = auth[2] + if configuration.conversation_cache is None: logger.warning("Converastion cache is not configured") raise HTTPException( @@ -144,7 +150,9 @@ check_conversation_existence(user_id, conversation_id) - conversation = configuration.conversation_cache.get(user_id, conversation_id, False) + conversation = configuration.conversation_cache.get( + user_id, conversation_id, skip_userid_check + ) chat_history = [transform_chat_message(entry) for entry in conversation] return ConversationResponse( @@ -168,6 +176,8 @@ async def delete_conversation_endpoint_handler( user_id = auth[0] logger.info("Deleting conversation %s for user %s", conversation_id, user_id) + skip_userid_check = auth[2] + if configuration.conversation_cache is None: logger.warning("Converastion cache is not configured") raise HTTPException( @@ -181,7 +191,9 @@ check_conversation_existence(user_id, conversation_id) logger.info("Deleting conversation %s for user %s", conversation_id, user_id) - deleted = configuration.conversation_cache.delete(user_id, conversation_id, False) + deleted = configuration.conversation_cache.delete( + user_id, conversation_id, skip_userid_check + ) if deleted: return ConversationDeleteResponse( @@ -215,7 +227,8 @@ def check_conversation_existence(user_id: str, conversation_id: str) -> None: if configuration.conversation_cache is None: return conversations = configuration.conversation_cache.list(user_id, False) - if conversation_id not in conversations: + conversation_ids = [conv.conversation_id for conv in conversations] + if conversation_id not in conversation_ids: logger.error("No conversation found for conversation ID %s", conversation_id) raise HTTPException( 
status_code=status.HTTP_404_NOT_FOUND, diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 15a215444..3f75a6cd4 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -46,6 +46,8 @@ from utils.endpoints import ( check_configuration_loaded, get_agent, + get_topic_summary_system_prompt, + get_temp_agent, get_system_prompt, store_conversation_into_cache, validate_conversation_ownership, @@ -98,7 +100,11 @@ def is_transcripts_enabled() -> bool: def persist_user_conversation_details( - user_id: str, conversation_id: str, model: str, provider_id: str + user_id: str, + conversation_id: str, + model: str, + provider_id: str, + topic_summary: Optional[str], ) -> None: """Associate conversation to user in the database.""" with get_session() as session: @@ -112,6 +118,7 @@ user_id=user_id, last_used_model=model, last_used_provider=provider_id, + topic_summary=topic_summary, message_count=1, ) session.add(conversation) @@ -169,9 +176,42 @@ def evaluate_model_hints( return model_id, provider_id +async def get_topic_summary( + question: str, client: AsyncLlamaStackClient, model_id: str +) -> str: + """Get a topic summary for a question. + + Args: + question: The question to summarize. + client: The AsyncLlamaStackClient to use for the request. + model_id: The ID of the model to use. + Returns: + str: The topic summary for the question. + """ + topic_summary_system_prompt = get_topic_summary_system_prompt(configuration) + agent, session_id, _ = await get_temp_agent( + client, model_id, topic_summary_system_prompt + ) + response = await agent.create_turn( + messages=[UserMessage(role="user", content=question)], + session_id=session_id, + stream=False, + toolgroups=None, + ) + response = cast(Turn, response) + return ( + interleaved_content_as_str(response.output_message.content) + if ( + getattr(response, "output_message", None) is not None + and getattr(response.output_message, "content", None) is not None + ) + else "" + ) + + @router.post("/query", responses=query_response) @authorize(Action.QUERY) -async def query_endpoint_handler( +async def query_endpoint_handler( # pylint: disable=R0914 request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], @@ -200,7 +240,7 @@ # log Llama Stack configuration logger.info("Llama stack config: %s", configuration.llama_stack_configuration) - user_id, _, _, token = auth + user_id, _, _skip_userid_check, token = auth user_conversation: UserConversation | None = None if query_request.conversation_id: @@ -251,6 +291,16 @@ # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() + # Get the initial topic summary for the conversation + topic_summary = None + with get_session() as session: + existing_conversation = ( + session.query(UserConversation).filter_by(id=conversation_id).first() + ) + if not existing_conversation: + topic_summary = await get_topic_summary( + query_request.query, client, model_id + ) # Convert RAG chunks to dictionary format once for reuse logger.info("Processing RAG chunks...") rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks] @@ -278,6 +328,7 @@ conversation_id=conversation_id, model=model_id, provider_id=provider_id, + topic_summary=topic_summary, ) store_conversation_into_cache( @@ -288,6 +339,8 @@ model_id, 
query_request.query, summary.llm_response, + _skip_userid_check, + topic_summary, ) # Convert tool calls to response format diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index f48646ac7..bf4d8635c 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -20,17 +20,18 @@ from llama_stack_client.types.shared import ToolCall from llama_stack_client.types.shared.interleaved_content_item import TextContentItem -import metrics +from app.database import get_session from app.endpoints.query import ( - evaluate_model_hints, get_rag_toolgroups, is_input_shield, is_output_shield, is_transcripts_enabled, - persist_user_conversation_details, select_model_and_provider_id, validate_attachments_metadata, validate_conversation_ownership, + persist_user_conversation_details, + evaluate_model_hints, + get_topic_summary, ) from authentication import get_auth_dependency from authentication.interface import AuthTuple @@ -38,6 +39,7 @@ from client import AsyncLlamaStackClientHolder from configuration import configuration from constants import DEFAULT_RAG_TOOL +import metrics from metrics.utils import update_llm_token_count_from_turn from models.config import Action from models.database.conversations import UserConversation @@ -567,7 +569,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: @router.post("/streaming_query", responses=streaming_query_responses) @authorize(Action.STREAMING_QUERY) -async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals +async def streaming_query_endpoint_handler( # pylint: disable=R0915,R0914 request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], @@ -705,6 +707,19 @@ async def response_generator( attachments=query_request.attachments or [], ) + # Get the initial topic summary for the conversation + topic_summary = None + with get_session() as session: + existing_conversation = ( + session.query(UserConversation) + .filter_by(id=conversation_id) + .first() + ) + if not existing_conversation: + topic_summary = await get_topic_summary( + query_request.query, client, model_id + ) + store_conversation_into_cache( configuration, user_id, @@ -713,14 +728,17 @@ async def response_generator( model_id, query_request.query, summary.llm_response, + _skip_userid_check, + topic_summary, ) - persist_user_conversation_details( - user_id=user_id, - conversation_id=conversation_id, - model=model_id, - provider_id=provider_id, - ) + persist_user_conversation_details( + user_id=user_id, + conversation_id=conversation_id, + model=model_id, + provider_id=provider_id, + topic_summary=topic_summary, + ) # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() diff --git a/src/cache/cache.py b/src/cache/cache.py index 98b087a41..4cdab6307 100644 --- a/src/cache/cache.py +++ b/src/cache/cache.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from utils.suid import check_suid @@ -90,7 +90,7 @@ def delete( """ @abstractmethod - def list(self, user_id: str, skip_user_id_check: bool) -> list[str]: + def list(self, user_id: str, skip_user_id_check: bool) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -98,7 +98,25 @@ def list(self, user_id: str, skip_user_id_check: bool) -> list[str]: skip_user_id_check: Skip user_id suid check. 
Returns: - A list of conversation ids from the cache + A list of ConversationData objects containing conversation_id, topic_summary, and + last_message_timestamp + """ + + @abstractmethod + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool, + ) -> None: + """Abstract method to store topic summary in the cache. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. """ @abstractmethod diff --git a/src/cache/in_memory_cache.py b/src/cache/in_memory_cache.py index 7c29bd2a3..1b6b4123f 100644 --- a/src/cache/in_memory_cache.py +++ b/src/cache/in_memory_cache.py @@ -1,7 +1,7 @@ """In-memory cache implementation.""" from cache.cache import Cache -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from models.config import InMemoryCacheConfig from log import get_logger from utils.connection_decorator import connection @@ -85,7 +85,9 @@ def delete( return True @connection - def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: + def list( + self, user_id: str, skip_user_id_check: bool = False + ) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -99,6 +101,25 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: super()._check_user_id(user_id, skip_user_id_check) return [] + @connection + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool = False, + ) -> None: + """Set the topic summary for the given conversation. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. + """ + # just check if user_id and conversation_id are UUIDs + super().construct_key(user_id, conversation_id, skip_user_id_check) + def ready(self) -> bool: """Check if the cache is ready. diff --git a/src/cache/noop_cache.py b/src/cache/noop_cache.py index 5c24271fa..fcd20f368 100644 --- a/src/cache/noop_cache.py +++ b/src/cache/noop_cache.py @@ -1,7 +1,7 @@ """No-operation cache implementation.""" from cache.cache import Cache -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from log import get_logger from utils.connection_decorator import connection @@ -83,7 +83,9 @@ def delete( return True @connection - def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: + def list( + self, user_id: str, skip_user_id_check: bool = False + ) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -97,6 +99,25 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: super()._check_user_id(user_id, skip_user_id_check) return [] + @connection + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool = False, + ) -> None: + """Set the topic summary for the given conversation. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. 
+ """ + # just check if user_id and conversation_id are UUIDs + super().construct_key(user_id, conversation_id, skip_user_id_check) + def ready(self) -> bool: """Check if the cache is ready. diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index ae591b84f..a8e5d0338 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -4,7 +4,7 @@ from cache.cache import Cache from cache.cache_error import CacheError -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from models.config import PostgreSQLDatabaseConfiguration from log import get_logger from utils.connection_decorator import connection @@ -46,6 +46,16 @@ class PostgresCache(Cache): ); """ + CREATE_CONVERSATIONS_TABLE = """ + CREATE TABLE IF NOT EXISTS conversations ( + user_id text NOT NULL, + conversation_id text NOT NULL, + topic_summary text, + last_message_timestamp timestamp NOT NULL, + PRIMARY KEY(user_id, conversation_id) + ); + """ + CREATE_INDEX = """ CREATE INDEX IF NOT EXISTS timestamps ON cache (created_at) @@ -73,13 +83,31 @@ class PostgresCache(Cache): """ LIST_CONVERSATIONS_STATEMENT = """ - SELECT conversation_id, max(created_at) AS created_at - FROM cache + SELECT conversation_id, topic_summary, EXTRACT(EPOCH FROM last_message_timestamp) as last_message_timestamp + FROM conversations WHERE user_id=%s - GROUP BY conversation_id - ORDER BY created_at DESC + ORDER BY last_message_timestamp DESC """ + INSERT_OR_UPDATE_TOPIC_SUMMARY_STATEMENT = """ + INSERT INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (user_id, conversation_id) + DO UPDATE SET topic_summary = EXCLUDED.topic_summary, last_message_timestamp = EXCLUDED.last_message_timestamp + """ + + DELETE_CONVERSATION_STATEMENT = """ + DELETE FROM conversations + WHERE user_id=%s AND conversation_id=%s + """ + + UPSERT_CONVERSATION_STATEMENT = """ + INSERT INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (user_id, conversation_id) + DO UPDATE SET last_message_timestamp = EXCLUDED.last_message_timestamp + """ + def __init__(self, config: PostgreSQLDatabaseConfiguration) -> None: """Create a new instance of PostgreSQL cache.""" self.postgres_config = config @@ -143,6 +171,9 @@ def initialize_cache(self) -> None: logger.info("Initializing table for cache") cursor.execute(PostgresCache.CREATE_CACHE_TABLE) + logger.info("Initializing table for conversations") + cursor.execute(PostgresCache.CREATE_CONVERSATIONS_TABLE) + logger.info("Initializing index for cache") cursor.execute(PostgresCache.CREATE_INDEX) @@ -220,6 +251,12 @@ def insert_or_append( cache_entry.model, ), ) + + # Update or insert conversation record with last_message_timestamp + cursor.execute( + PostgresCache.UPSERT_CONVERSATION_STATEMENT, + (user_id, conversation_id, None), + ) # commit is implicit at this point except psycopg2.DatabaseError as e: logger.error("PostgresCache.insert_or_append: %s", e) @@ -251,13 +288,22 @@ def delete( (user_id, conversation_id), ) deleted = cursor.rowcount + + # Also delete conversation record for this conversation + cursor.execute( + PostgresCache.DELETE_CONVERSATION_STATEMENT, + (user_id, conversation_id), + ) + return deleted > 0 except psycopg2.DatabaseError as e: logger.error("PostgresCache.delete: %s", e) raise CacheError("PostgresCache.delete", e) from e @connection - def list(self, user_id: str, 
skip_user_id_check: bool = False) -> list[str]: + def list( + self, user_id: str, skip_user_id_check: bool = False + ) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -265,7 +311,8 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: skip_user_id_check: Skip user_id suid check. Returns: - A list of conversation ids from the cache + A list of ConversationData objects containing conversation_id, topic_summary, and + last_message_timestamp """ if self.connection is None: @@ -276,7 +323,46 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: cursor.execute(self.LIST_CONVERSATIONS_STATEMENT, (user_id,)) conversations = cursor.fetchall() - return [conversation[0] for conversation in conversations] + result = [] + for conversation in conversations: + conversation_data = ConversationData( + conversation_id=conversation[0], + topic_summary=conversation[1], + last_message_timestamp=float(conversation[2]), + ) + result.append(conversation_data) + + return result + + @connection + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool = False, + ) -> None: + """Set the topic summary for the given conversation. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. + """ + if self.connection is None: + logger.error("Cache is disconnected") + raise CacheError("set_topic_summary: cache is disconnected") + + try: + with self.connection.cursor() as cursor: + cursor.execute( + self.INSERT_OR_UPDATE_TOPIC_SUMMARY_STATEMENT, + (user_id, conversation_id, topic_summary), + ) + except psycopg2.DatabaseError as e: + logger.error("PostgresCache.set_topic_summary: %s", e) + raise CacheError("PostgresCache.set_topic_summary", e) from e def ready(self) -> bool: """Check if the cache is ready. diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index e8ebd7679..a39f8ade2 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -6,7 +6,7 @@ from cache.cache import Cache from cache.cache_error import CacheError -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from models.config import SQLiteDatabaseConfiguration from log import get_logger from utils.connection_decorator import connection @@ -50,6 +50,16 @@ class SQLiteCache(Cache): ); """ + CREATE_CONVERSATIONS_TABLE = """ + CREATE TABLE IF NOT EXISTS conversations ( + user_id text NOT NULL, + conversation_id text NOT NULL, + topic_summary text, + last_message_timestamp int NOT NULL, + PRIMARY KEY(user_id, conversation_id) + ); + """ + CREATE_INDEX = """ CREATE INDEX IF NOT EXISTS timestamps ON cache (created_at) @@ -77,12 +87,29 @@ class SQLiteCache(Cache): """ LIST_CONVERSATIONS_STATEMENT = """ - SELECT DISTINCT conversation_id - FROM cache + SELECT conversation_id, topic_summary, last_message_timestamp + FROM conversations WHERE user_id=? - ORDER BY created_at DESC + ORDER BY last_message_timestamp DESC """ + INSERT_OR_UPDATE_TOPIC_SUMMARY_STATEMENT = """ + INSERT OR REPLACE INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) + VALUES (?, ?, ?, ?) + """ + + DELETE_CONVERSATION_STATEMENT = """ + DELETE FROM conversations + WHERE user_id=? AND conversation_id=? 
+ """ + + UPSERT_CONVERSATION_STATEMENT = """ + INSERT INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) + VALUES (?, ?, ?, ?) + ON CONFLICT (user_id, conversation_id) + DO UPDATE SET last_message_timestamp = excluded.last_message_timestamp + """ + def __init__(self, config: SQLiteDatabaseConfiguration) -> None: """Create a new instance of SQLite cache.""" self.sqlite_config = config @@ -141,6 +168,9 @@ def initialize_cache(self) -> None: logger.info("Initializing table for cache") cursor.execute(SQLiteCache.CREATE_CACHE_TABLE) + logger.info("Initializing table for conversations") + cursor.execute(SQLiteCache.CREATE_CONVERSATIONS_TABLE) + logger.info("Initializing index for cache") cursor.execute(SQLiteCache.CREATE_INDEX) @@ -206,18 +236,26 @@ def insert_or_append( raise CacheError("insert_or_append: cache is disconnected") cursor = self.connection.cursor() + current_time = time() cursor.execute( self.INSERT_CONVERSATION_HISTORY_STATEMENT, ( user_id, conversation_id, - time(), + current_time, cache_entry.query, cache_entry.response, cache_entry.provider, cache_entry.model, ), ) + + # Update or insert conversation record with last_message_timestamp + cursor.execute( + self.UPSERT_CONVERSATION_STATEMENT, + (user_id, conversation_id, None, current_time), + ) + cursor.close() self.connection.commit() @@ -246,12 +284,21 @@ def delete( (user_id, conversation_id), ) deleted = cursor.rowcount > 0 + + # Also delete conversation record for this conversation + cursor.execute( + self.DELETE_CONVERSATION_STATEMENT, + (user_id, conversation_id), + ) + cursor.close() self.connection.commit() return deleted @connection - def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: + def list( + self, user_id: str, skip_user_id_check: bool = False + ) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -259,7 +306,8 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: skip_user_id_check: Skip user_id suid check. Returns: - A list of conversation ids from the cache + A list of ConversationData objects containing conversation_id, + topic_summary, and last_message_timestamp """ if self.connection is None: @@ -271,7 +319,44 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: conversations = cursor.fetchall() cursor.close() - return [conversation[0] for conversation in conversations] + result = [] + for conversation in conversations: + conversation_data = ConversationData( + conversation_id=conversation[0], + topic_summary=conversation[1], + last_message_timestamp=conversation[2], + ) + result.append(conversation_data) + + return result + + @connection + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool = False, + ) -> None: + """Set the topic summary for the given conversation. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. + """ + if self.connection is None: + logger.error("Cache is disconnected") + raise CacheError("set_topic_summary: cache is disconnected") + + cursor = self.connection.cursor() + cursor.execute( + self.INSERT_OR_UPDATE_TOPIC_SUMMARY_STATEMENT, + (user_id, conversation_id, topic_summary, time()), + ) + cursor.close() + self.connection.commit() def ready(self) -> bool: """Check if the cache is ready. 
diff --git a/src/constants.py b/src/constants.py index 8369b9369..6b8b33e34 100644 --- a/src/constants.py +++ b/src/constants.py @@ -28,6 +28,67 @@ # configuration file nor in the query request DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant" +# Default topic summary system prompt used only when no other topic summary system +# prompt is specified in configuration file +DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT = """ +Instructions: +- You are a topic summarizer +- Your job is to extract a precise topic summary from user input + +For Input Analysis: +- Scan entire user message +- Identify core subject matter +- Distill essence into concise descriptor +- Prioritize key concepts +- Eliminate extraneous details + +For Output Constraints: +- Maximum 5 words +- Capitalize only significant words (e.g., nouns, verbs, adjectives, adverbs). +- Do not use all uppercase - capitalize only the first letter of significant words +- Exclude articles and prepositions (e.g., "a," "the," "of," "on," "in") +- Exclude all punctuation marks (e.g., . , : ; ! ? "") +- Retain original abbreviations. Do not expand an abbreviation if its specific meaning in the context is unknown or ambiguous. +- Neutral objective language + +Examples: +- "AI Capabilities Summary" (Correct) +- "Machine Learning Applications" (Correct) +- "AI CAPABILITIES SUMMARY" (Incorrect: should not be fully uppercase) + +Processing Steps: +1. Analyze semantic structure +2. Identify primary topic +3. Remove contextual noise +4. Condense to essential meaning +5. Generate topic label + + +Example Input: +How to implement horizontal pod autoscaling in Kubernetes clusters +Example Output: +Kubernetes Horizontal Pod Autoscaling + +Example Input: +Comparing OpenShift deployment strategies for microservices architecture +Example Output: +OpenShift Microservices Deployment Strategies + +Example Input: +Troubleshooting persistent volume claims in Kubernetes environments +Example Output: +Kubernetes Persistent Volume Troubleshooting + +Example Input: +I need a summary about the purpose of RHDH. +Example Output: +RHDH Purpose Summary + +Input: +{query} +Output: +""" + # Authentication constants DEFAULT_VIRTUAL_PATH = "/ls-access" DEFAULT_USER_NAME = "lightspeed-user" diff --git a/src/models/cache_entry.py b/src/models/cache_entry.py index 810bad711..f87445bef 100644 --- a/src/models/cache_entry.py +++ b/src/models/cache_entry.py @@ -17,3 +17,17 @@ class CacheEntry(BaseModel): response: str provider: str model: str + + +class ConversationData(BaseModel): + """Model representing conversation data returned by cache list operations. 
+ + Attributes: + conversation_id: The conversation ID + topic_summary: The topic summary for the conversation (can be None) + last_message_timestamp: The timestamp of the last message in the conversation + """ + + conversation_id: str + topic_summary: str | None + last_message_timestamp: float diff --git a/src/models/database/conversations.py b/src/models/database/conversations.py index 1cce8a64d..fd720b418 100644 --- a/src/models/database/conversations.py +++ b/src/models/database/conversations.py @@ -34,3 +34,5 @@ class UserConversation(Base): # pylint: disable=too-few-public-methods # The number of user messages in the conversation message_count: Mapped[int] = mapped_column(default=0) + + topic_summary: Mapped[str] = mapped_column(default="") diff --git a/src/models/responses.py b/src/models/responses.py index cab7c3ccd..f44b79ea7 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -4,6 +4,8 @@ from pydantic import AnyUrl, BaseModel, Field +from models.cache_entry import ConversationData + class ModelsResponse(BaseModel): """Model representing a response to models request.""" @@ -556,6 +558,7 @@ class ConversationDetails(BaseModel): message_count: Number of user messages in the conversation. last_used_model: The last model used for the conversation. last_used_provider: The provider of the last used model. + topic_summary: The topic summary for the conversation. Example: ```python @@ -566,6 +569,7 @@ class ConversationDetails(BaseModel): message_count=5, last_used_model="gemini/gemini-2.0-flash", last_used_provider="gemini", + topic_summary="Openshift Microservices Deployment Strategies", ) ``` """ @@ -606,6 +610,12 @@ class ConversationDetails(BaseModel): examples=["openai", "gemini"], ) + topic_summary: Optional[str] = Field( + None, + description="Topic summary for the conversation", + examples=["Openshift Microservices Deployment Strategies"], + ) + class ConversationsListResponse(BaseModel): """Model representing a response for listing conversations of a user. @@ -624,6 +634,7 @@ class ConversationsListResponse(BaseModel): message_count=5, last_used_model="gemini/gemini-2.0-flash", last_used_provider="gemini", + topic_summary="Openshift Microservices Deployment Strategies", ), ConversationDetails( conversation_id="456e7890-e12b-34d5-a678-901234567890" @@ -631,6 +642,7 @@ class ConversationsListResponse(BaseModel): message_count=2, last_used_model="gemini/gemini-2.0-flash", last_used_provider="gemini", + topic_summary="RHDH Purpose Summary", ) ] ) @@ -652,6 +664,7 @@ class ConversationsListResponse(BaseModel): "message_count": 5, "last_used_model": "gemini/gemini-2.0-flash", "last_used_provider": "gemini", + "topic_summary": "Openshift Microservices Deployment Strategies", }, { "conversation_id": "456e7890-e12b-34d5-a678-901234567890", @@ -659,6 +672,7 @@ class ConversationsListResponse(BaseModel): "message_count": 2, "last_used_model": "gemini/gemini-2.5-flash", "last_used_provider": "gemini", + "topic_summary": "RHDH Purpose Summary", }, ] } @@ -671,10 +685,10 @@ class ConversationsListResponseV2(BaseModel): """Model representing a response for listing conversations of a user. Attributes: - conversations: List of conversation IDs associated with the user. + conversations: List of conversation data associated with the user. 
""" - conversations: list[str] + conversations: list[ConversationData] class ErrorResponse(BaseModel): diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index a34dcb447..0b0c15102 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -143,6 +143,20 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str: return constants.DEFAULT_SYSTEM_PROMPT +def get_topic_summary_system_prompt(config: AppConfig) -> str: + """Get the topic summary system prompt.""" + # profile takes precedence for setting prompt + if ( + config.customization is not None + and config.customization.custom_profile is not None + ): + prompt = config.customization.custom_profile.get_prompts().get("topic_summary") + if prompt: + return prompt + + return constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT + + def validate_model_provider_override( query_request: QueryRequest, authorized_actions: set[Action] | frozenset[Action] ) -> None: @@ -175,6 +189,8 @@ def store_conversation_into_cache( model_id: str, query: str, response: str, + _skip_userid_check: bool, + topic_summary: str | None, ) -> None: """Store one part of conversation into conversation history cache.""" if config.conversation_cache_configuration.type is not None: @@ -188,7 +204,13 @@ def store_conversation_into_cache( provider=provider_id, model=model_id, ) - cache.insert_or_append(user_id, conversation_id, cache_entry, False) + cache.insert_or_append( + user_id, conversation_id, cache_entry, _skip_userid_check + ) + if topic_summary and len(topic_summary) > 0: + cache.set_topic_summary( + user_id, conversation_id, topic_summary, _skip_userid_check + ) # # pylint: disable=R0913,R0917 @@ -285,3 +307,36 @@ async def get_agent( logger.debug("New session ID: %s", session_id) return agent, conversation_id, session_id + + +async def get_temp_agent( + client: AsyncLlamaStackClient, + model_id: str, + system_prompt: str, +) -> tuple[AsyncAgent, str, str]: + """Create a temporary agent with new agent_id and session_id. + + This function creates a new agent without persistence, shields, or tools. + Useful for temporary operations or one-off queries, such as validating a + question or generating a summary. + Args: + client: The AsyncLlamaStackClient to use for the request. + model_id: The ID of the model to use. + system_prompt: The system prompt/instructions for the agent. + Returns: + tuple[AsyncAgent, str]: A tuple containing the agent and session_id. 
+ """ + logger.debug("Creating temporary agent") + agent = AsyncAgent( + client, # type: ignore[arg-type] + model=model_id, + instructions=system_prompt, + enable_session_persistence=False, # Temporary agent doesn't need persistence + ) + await agent.initialize() + + # Generate new IDs for the temporary agent + conversation_id = agent.agent_id + session_id = await agent.create_session(get_suid()) + + return agent, session_id, conversation_id diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 7c372da13..89548782a 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -48,6 +48,7 @@ def create_mock_conversation( message_count, last_used_model, last_used_provider, + topic_summary=None, ): # pylint: disable=too-many-arguments,too-many-positional-arguments """Helper function to create a mock conversation object with all required attributes.""" mock_conversation = mocker.Mock() @@ -59,6 +60,7 @@ def create_mock_conversation( mock_conversation.message_count = message_count mock_conversation.last_used_model = last_used_model mock_conversation.last_used_provider = last_used_provider + mock_conversation.topic_summary = topic_summary return mock_conversation @@ -627,6 +629,7 @@ async def test_successful_conversations_list_retrieval( 5, "gemini/gemini-2.0-flash", "gemini", + "OpenStack deployment strategies", ), create_mock_conversation( mocker, @@ -636,6 +639,7 @@ async def test_successful_conversations_list_retrieval( 2, "gemini/gemini-2.5-flash", "gemini", + "Kubernetes troubleshooting", ), ] mock_database_session(mocker, mock_conversations) @@ -646,14 +650,26 @@ async def test_successful_conversations_list_retrieval( assert isinstance(response, ConversationsListResponse) assert len(response.conversations) == 2 - assert ( - response.conversations[0].conversation_id - == "123e4567-e89b-12d3-a456-426614174000" - ) - assert ( - response.conversations[1].conversation_id - == "456e7890-e12b-34d5-a678-901234567890" - ) + + # Test first conversation + conv1 = response.conversations[0] + assert conv1.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert conv1.created_at == "2024-01-01T00:00:00Z" + assert conv1.last_message_at == "2024-01-01T00:05:00Z" + assert conv1.message_count == 5 + assert conv1.last_used_model == "gemini/gemini-2.0-flash" + assert conv1.last_used_provider == "gemini" + assert conv1.topic_summary == "OpenStack deployment strategies" + + # Test second conversation + conv2 = response.conversations[1] + assert conv2.conversation_id == "456e7890-e12b-34d5-a678-901234567890" + assert conv2.created_at == "2024-01-01T01:00:00Z" + assert conv2.last_message_at == "2024-01-01T01:02:00Z" + assert conv2.message_count == 2 + assert conv2.last_used_model == "gemini/gemini-2.5-flash" + assert conv2.last_used_provider == "gemini" + assert conv2.topic_summary == "Kubernetes troubleshooting" @pytest.mark.asyncio async def test_empty_conversations_list( @@ -691,7 +707,177 @@ async def test_database_exception(self, mocker, setup_configuration, dummy_reque assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unknown error" in exc_info.value.detail["response"] - assert ( - "Unknown error while getting conversations for user" - in exc_info.value.detail["cause"] + + @pytest.mark.asyncio + async def test_conversations_list_with_none_topic_summary( + self, mocker, setup_configuration, dummy_request + ): + """Test conversations list when 
topic_summary is None.""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session with conversation having None topic_summary + mock_conversations = [ + create_mock_conversation( + mocker, + "123e4567-e89b-12d3-a456-426614174000", + "2024-01-01T00:00:00Z", + "2024-01-01T00:05:00Z", + 5, + "gemini/gemini-2.0-flash", + "gemini", + None, # topic_summary is None + ), + ] + mock_database_session(mocker, mock_conversations) + + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 1 + + conv = response.conversations[0] + assert conv.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert conv.topic_summary is None + + @pytest.mark.asyncio + async def test_conversations_list_with_mixed_topic_summaries( + self, mocker, setup_configuration, dummy_request + ): + """Test conversations list with mixed topic_summary values (some None, some not).""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session with mixed topic_summary values + mock_conversations = [ + create_mock_conversation( + mocker, + "123e4567-e89b-12d3-a456-426614174000", + "2024-01-01T00:00:00Z", + "2024-01-01T00:05:00Z", + 5, + "gemini/gemini-2.0-flash", + "gemini", + "OpenStack deployment strategies", # Has topic_summary + ), + create_mock_conversation( + mocker, + "456e7890-e12b-34d5-a678-901234567890", + "2024-01-01T01:00:00Z", + "2024-01-01T01:02:00Z", + 2, + "gemini/gemini-2.5-flash", + "gemini", + None, # No topic_summary + ), + create_mock_conversation( + mocker, + "789e0123-e45b-67d8-a901-234567890123", + "2024-01-01T02:00:00Z", + "2024-01-01T02:03:00Z", + 3, + "openai/gpt-4", + "openai", + "Machine learning model training", # Has topic_summary + ), + ] + mock_database_session(mocker, mock_conversations) + + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 3 + + # Test first conversation (with topic_summary) + conv1 = response.conversations[0] + assert conv1.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert conv1.topic_summary == "OpenStack deployment strategies" + + # Test second conversation (without topic_summary) + conv2 = response.conversations[1] + assert conv2.conversation_id == "456e7890-e12b-34d5-a678-901234567890" + assert conv2.topic_summary is None + + # Test third conversation (with topic_summary) + conv3 = response.conversations[2] + assert conv3.conversation_id == "789e0123-e45b-67d8-a901-234567890123" + assert conv3.topic_summary == "Machine learning model training" + + @pytest.mark.asyncio + async def test_conversations_list_with_empty_topic_summary( + self, mocker, setup_configuration, dummy_request + ): + """Test conversations list when topic_summary is an empty string.""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session with conversation having empty topic_summary + mock_conversations = [ + create_mock_conversation( + mocker, + "123e4567-e89b-12d3-a456-426614174000", + "2024-01-01T00:00:00Z", + "2024-01-01T00:05:00Z", + 5, + "gemini/gemini-2.0-flash", + "gemini", + "", # Empty topic_summary + ), + ] + 
mock_database_session(mocker, mock_conversations) + + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request ) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 1 + + conv = response.conversations[0] + assert conv.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert conv.topic_summary == "" + + @pytest.mark.asyncio + async def test_conversations_list_topic_summary_field_presence( + self, mocker, setup_configuration, dummy_request + ): + """Test that topic_summary field is always present in ConversationDetails objects.""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session with conversations + mock_conversations = [ + create_mock_conversation( + mocker, + "123e4567-e89b-12d3-a456-426614174000", + "2024-01-01T00:00:00Z", + "2024-01-01T00:05:00Z", + 5, + "gemini/gemini-2.0-flash", + "gemini", + "Test topic summary", + ), + ] + mock_database_session(mocker, mock_conversations) + + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 1 + + conv = response.conversations[0] + + # Verify that topic_summary field exists and is accessible + assert hasattr(conv, "topic_summary") + assert conv.topic_summary == "Test topic summary" + + # Verify that the field is properly serialized (if needed for API responses) + conv_dict = conv.model_dump() + assert "topic_summary" in conv_dict + assert conv_dict["topic_summary"] == "Test topic summary" diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 5e1f6363a..300fa7768 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -18,6 +18,7 @@ from app.endpoints.query import ( evaluate_model_hints, + get_topic_summary, get_rag_toolgroups, is_transcripts_enabled, parse_metadata_from_text_item, @@ -75,6 +76,13 @@ def mock_database_operations(mocker): ) mocker.patch("app.endpoints.query.persist_user_conversation_details") + # Mock the database session and query + mock_session = mocker.Mock() + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.__enter__ = mocker.Mock(return_value=mock_session) + mock_session.__exit__ = mocker.Mock(return_value=None) + mocker.patch("app.endpoints.query.get_session", return_value=mock_session) + @pytest.fixture(name="setup_configuration") def setup_configuration_fixture(): @@ -199,6 +207,11 @@ async def _test_query_endpoint_handler( ) mock_transcript = mocker.patch("app.endpoints.query.store_transcript") + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Test topic summary" + ) + # Mock database operations mock_database_operations(mocker) @@ -1394,6 +1407,10 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_requ return_value=("test_model", "test_model", "test_provider"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Test topic summary" + ) # Mock database operations mock_database_operations(mocker) @@ -1445,6 +1462,10 @@ async def test_query_endpoint_handler_no_tools_true(mocker, dummy_request): return_value=("fake_model_id", 
"fake_model_id", "fake_provider_id"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Test topic summary" + ) # Mock database operations mock_database_operations(mocker) @@ -1497,6 +1518,10 @@ async def test_query_endpoint_handler_no_tools_false(mocker, dummy_request): return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Test topic summary" + ) # Mock database operations mock_database_operations(mocker) @@ -1782,3 +1807,323 @@ async def test_query_endpoint_rejects_model_provider_override_without_permission ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert exc_info.value.detail["response"] == expected_msg + + +@pytest.mark.asyncio +async def test_get_topic_summary_successful_response(mocker): + """Test get_topic_summary with successful response from agent.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message.content = "This is a topic summary about OpenStack" + + # Mock the get_temp_agent function + mock_get_temp_agent = mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mocker.patch( + "app.endpoints.query.interleaved_content_as_str", + return_value="This is a topic summary about OpenStack", + ) + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="You are a topic summarizer", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "This is a topic summary about OpenStack" + + # Verify get_temp_agent was called with correct parameters + mock_get_temp_agent.assert_called_once_with( + mock_client, "test_model", "You are a topic summarizer" + ) + + # Verify create_turn was called with correct parameters + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="session_123", + stream=False, + toolgroups=None, + ) + + +@pytest.mark.asyncio +async def test_get_topic_summary_empty_response(mocker): + """Test get_topic_summary with empty response from agent.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message = None + + # Mock the get_temp_agent function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="You are a topic summarizer", + ) + + # Mock the configuration + mock_config = 
mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "" + + +@pytest.mark.asyncio +async def test_get_topic_summary_none_content(mocker): + """Test get_topic_summary with None content in response.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message.content = None + + # Mock the get_temp_agent function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="You are a topic summarizer", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "" + + +@pytest.mark.asyncio +async def test_get_topic_summary_with_interleaved_content(mocker): + """Test get_topic_summary with interleaved content response.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_content = [TextContentItem(text="Topic summary", type="text")] + mock_response.output_message.content = mock_content + + # Mock the get_temp_agent function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mock_interleaved_content_as_str = mocker.patch( + "app.endpoints.query.interleaved_content_as_str", return_value="Topic summary" + ) + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="You are a topic summarizer", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "Topic summary" + + # Verify interleaved_content_as_str was called with the content + mock_interleaved_content_as_str.assert_called_once_with(mock_content) + + +@pytest.mark.asyncio +async def test_get_topic_summary_system_prompt_retrieval(mocker): + """Test that get_topic_summary properly retrieves and uses the system prompt.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message.content = "Topic summary" + + # Mock the get_temp_agent function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mocker.patch( + "app.endpoints.query.interleaved_content_as_str", 
return_value="Topic summary" + ) + + # Mock the get_topic_summary_system_prompt function + mock_get_topic_summary_system_prompt = mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="Custom topic summarizer prompt", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "Topic summary" + + # Verify get_topic_summary_system_prompt was called with configuration + mock_get_topic_summary_system_prompt.assert_called_once_with(mock_config) + + +@pytest.mark.asyncio +async def test_get_topic_summary_agent_creation_parameters(mocker): + """Test that get_topic_summary creates agent with correct parameters.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message.content = "Topic summary" + + # Mock the get_temp_agent function + mock_get_temp_agent = mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mocker.patch( + "app.endpoints.query.interleaved_content_as_str", return_value="Topic summary" + ) + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="Custom system prompt", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="Test question?", client=mock_client, model_id="custom_model" + ) + + # Assertions + assert result == "Topic summary" + + # Verify get_temp_agent was called with correct parameters + mock_get_temp_agent.assert_called_once_with( + mock_client, "custom_model", "Custom system prompt" + ) + + +@pytest.mark.asyncio +async def test_get_topic_summary_create_turn_parameters(mocker): + """Test that get_topic_summary calls create_turn with correct parameters.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message.content = "Topic summary" + + # Mock the get_temp_agent function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "test_session", "test_conversation"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mocker.patch( + "app.endpoints.query.interleaved_content_as_str", return_value="Topic summary" + ) + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="Custom system prompt", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is the meaning of life?", + client=mock_client, + model_id="test_model", + ) + + # Assertions + assert result == "Topic summary" + + # Verify create_turn was called with correct parameters + mock_agent.create_turn.assert_called_once_with( + 
+        messages=[UserMessage(role="user", content="What is the meaning of life?")],
+        session_id="test_session",
+        stream=False,
+        toolgroups=None,
+    )
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index 773ab66a5..52a387073 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -64,6 +64,13 @@ def mock_database_operations(mocker):
     )
     mocker.patch("app.endpoints.streaming_query.persist_user_conversation_details")
+
+    # Mock the database session and query
+    mock_session = mocker.Mock()
+    mock_session.query.return_value.filter_by.return_value.first.return_value = None
+    mock_session.__enter__ = mocker.Mock(return_value=mock_session)
+    mock_session.__exit__ = mocker.Mock(return_value=None)
+    mocker.patch("app.endpoints.streaming_query.get_session", return_value=mock_session)
 
 
 def mock_metrics(mocker):
     """Helper function to mock metrics operations for streaming query endpoints."""
@@ -292,6 +299,12 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
     )
     mock_transcript = mocker.patch("app.endpoints.streaming_query.store_transcript")
+    # Mock get_topic_summary function
+    mocker.patch(
+        "app.endpoints.streaming_query.get_topic_summary",
+        return_value="Test topic summary",
+    )
+
     mock_database_operations(mocker)
 
     query_request = QueryRequest(query=query)
@@ -1364,6 +1377,11 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker):
     mocker.patch(
         "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
     )
+    # Mock get_topic_summary function
+    mocker.patch(
+        "app.endpoints.streaming_query.get_topic_summary",
+        return_value="Test topic summary",
+    )
     mock_database_operations(mocker)
 
     request = Request(
@@ -1410,6 +1428,11 @@ async def test_streaming_query_endpoint_handler_no_tools_true(mocker):
     mocker.patch(
         "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
     )
+    # Mock get_topic_summary function
+    mocker.patch(
+        "app.endpoints.streaming_query.get_topic_summary",
+        return_value="Test topic summary",
+    )
 
     # Mock database operations
     mock_database_operations(mocker)
@@ -1457,6 +1480,11 @@ async def test_streaming_query_endpoint_handler_no_tools_false(mocker):
     mocker.patch(
         "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
     )
+    # Mock get_topic_summary function
+    mocker.patch(
+        "app.endpoints.streaming_query.get_topic_summary",
+        return_value="Test topic summary",
+    )
 
     # Mock database operations
     mock_database_operations(mocker)
diff --git a/tests/unit/cache/test_postgres_cache.py b/tests/unit/cache/test_postgres_cache.py
index 61998de70..51e06448c 100644
--- a/tests/unit/cache/test_postgres_cache.py
+++ b/tests/unit/cache/test_postgres_cache.py
@@ -7,7 +7,7 @@
 from cache.cache_error import CacheError
 from cache.postgres_cache import PostgresCache
 from models.config import PostgreSQLDatabaseConfiguration
-from models.cache_entry import CacheEntry
+from models.cache_entry import CacheEntry, ConversationData
 from utils import suid
@@ -301,3 +301,67 @@
     # should not fail
     lst = cache.list(USER_ID_1, False)
     assert not lst
+    assert isinstance(lst, list)
+
+
+def test_topic_summary_operations(postgres_cache_config_fixture, mocker):
+    """Test topic summary set operations and retrieval via list."""
+    # prevent real connection to PG instance
+    mock_connect = mocker.patch("psycopg2.connect")
+    cache = PostgresCache(postgres_cache_config_fixture)
+
+    mock_connection = mock_connect.return_value
+    mock_cursor = mock_connection.cursor.return_value.__enter__.return_value
+
+    # Mock fetchall to return conversation data
+    mock_cursor.fetchall.return_value = [
+        (
+            CONVERSATION_ID_1,
+            "This conversation is about machine learning and AI",
+            1234567890.0,
+        )
+    ]
+
+    # Set a topic summary
+    test_summary = "This conversation is about machine learning and AI"
+    cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, test_summary, False)
+
+    # Retrieve the topic summary via list
+    conversations = cache.list(USER_ID_1, False)
+    assert len(conversations) == 1
+    assert conversations[0].topic_summary == test_summary
+    assert isinstance(conversations[0], ConversationData)
+
+
+def test_topic_summary_after_conversation_delete(postgres_cache_config_fixture, mocker):
+    """Test that conversation delete succeeds after a topic summary has been set."""
+    # prevent real connection to PG instance
+    mock_connect = mocker.patch("psycopg2.connect")
+    cache = PostgresCache(postgres_cache_config_fixture)
+
+    mock_connection = mock_connect.return_value
+    mock_cursor = mock_connection.cursor.return_value.__enter__.return_value
+
+    # Mock the delete operation to return 1 (deleted)
+    mock_cursor.rowcount = 1
+
+    # Add some cache entries and a topic summary
+    cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, cache_entry_1, False)
+    cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test summary", False)
+
+    # Delete the conversation
+    deleted = cache.delete(USER_ID_1, CONVERSATION_ID_1, False)
+    assert deleted is True
+
+
+def test_topic_summary_when_disconnected(postgres_cache_config_fixture, mocker):
+    """Test topic summary operations when cache is disconnected."""
+    # prevent real connection to PG instance
+    mocker.patch("psycopg2.connect")
+    cache = PostgresCache(postgres_cache_config_fixture)
+
+    cache.connection = None
+    cache.connect = lambda: None
+
+    with pytest.raises(CacheError, match="cache is disconnected"):
+        cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test", False)
diff --git a/tests/unit/cache/test_sqlite_cache.py b/tests/unit/cache/test_sqlite_cache.py
index 381b009ab..32cc9a46a 100644
--- a/tests/unit/cache/test_sqlite_cache.py
+++ b/tests/unit/cache/test_sqlite_cache.py
@@ -7,7 +7,7 @@
 import pytest
 
 from models.config import SQLiteDatabaseConfiguration
-from models.cache_entry import CacheEntry
+from models.cache_entry import CacheEntry, ConversationData
 from utils import suid
 from cache.cache_error import CacheError
@@ -188,6 +188,7 @@ def test_list_operation_when_connected(tmpdir):
     # should not fail
     lst = cache.list(USER_ID_1, False)
     assert not lst
+    assert isinstance(lst, list)
 
 
 def test_ready_method(tmpdir):
@@ -255,3 +256,94 @@ def test_multiple_ids(tmpdir):
     lst = cache.get(USER_ID_2, CONVERSATION_ID_2, False)
     assert lst[0] == cache_entry_1
     assert lst[1] == cache_entry_2
+
+
+def test_list_with_conversations(tmpdir):
+    """Test the list() method with actual conversations."""
+    cache = create_cache(tmpdir)
+
+    # Add some conversations
+    cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, cache_entry_1, False)
+    cache.insert_or_append(USER_ID_1, CONVERSATION_ID_2, cache_entry_2, False)
+
+    # Set topic summaries
+    cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "First conversation", False)
+    cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_2, "Second conversation", False)
+
+    # Test list functionality
+    conversations = cache.list(USER_ID_1, False)
+    assert len(conversations) == 2
+    assert all(isinstance(conv, ConversationData) for conv in conversations)
+
+    # Check that conversations are ordered by last_message_timestamp DESC
+    assert (
+        conversations[0].last_message_timestamp
+        >= conversations[1].last_message_timestamp
+    )
+
+    # Check conversation IDs
+    conv_ids = [conv.conversation_id for conv in conversations]
+    assert CONVERSATION_ID_1 in conv_ids
+    assert CONVERSATION_ID_2 in conv_ids
+
+
+def test_topic_summary_operations(tmpdir):
+    """Test topic summary set operations and retrieval via list."""
+    cache = create_cache(tmpdir)
+
+    # Add a conversation
+    cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, cache_entry_1, False)
+
+    # Set a topic summary
+    test_summary = "This conversation is about machine learning and AI"
+    cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, test_summary, False)
+
+    # Retrieve the topic summary via list
+    conversations = cache.list(USER_ID_1, False)
+    assert len(conversations) == 1
+    assert conversations[0].topic_summary == test_summary
+
+    # Update the topic summary
+    updated_summary = "This conversation is about deep learning and neural networks"
+    cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, updated_summary, False)
+
+    # Verify the update via list
+    conversations = cache.list(USER_ID_1, False)
+    assert len(conversations) == 1
+    assert conversations[0].topic_summary == updated_summary
+
+
+def test_topic_summary_after_conversation_delete(tmpdir):
+    """Test that topic summary is deleted when conversation is deleted."""
+    cache = create_cache(tmpdir)
+
+    # Add some cache entries and a topic summary
+    cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, cache_entry_1, False)
+    cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test summary", False)
+
+    # Verify both exist
+    entries = cache.get(USER_ID_1, CONVERSATION_ID_1, False)
+    assert len(entries) == 1
+    conversations = cache.list(USER_ID_1, False)
+    assert len(conversations) == 1
+    assert conversations[0].topic_summary == "Test summary"
+
+    # Delete the conversation
+    deleted = cache.delete(USER_ID_1, CONVERSATION_ID_1, False)
+    assert deleted is True
+
+    # Verify both are deleted
+    entries = cache.get(USER_ID_1, CONVERSATION_ID_1, False)
+    assert len(entries) == 0
+    conversations = cache.list(USER_ID_1, False)
+    assert len(conversations) == 0
+
+
+def test_topic_summary_when_disconnected(tmpdir):
+    """Test topic summary operations when cache is disconnected."""
+    cache = create_cache(tmpdir)
+    cache.connection = None
+    cache.connect = lambda: None
+
+    with pytest.raises(CacheError, match="cache is disconnected"):
+        cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test", False)
diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py
index 04701ac48..bed970a56 100644
--- a/tests/unit/utils/test_endpoints.py
+++ b/tests/unit/utils/test_endpoints.py
@@ -10,7 +10,7 @@
 from models.requests import QueryRequest
 from models.config import Action
 from utils import endpoints
-from utils.endpoints import get_agent
+from utils.endpoints import get_agent, get_temp_agent
 from tests.unit import config_dict
@@ -657,6 +657,108 @@ async def test_get_agent_no_tools_false_preserves_parser(
     )
 
 
+@pytest.mark.asyncio
+async def test_get_temp_agent_basic_functionality(prepare_agent_mocks, mocker):
+    """Test get_temp_agent function creates agent with correct parameters."""
+    mock_client, mock_agent = prepare_agent_mocks
+    mock_agent.create_session.return_value = "temp_session_id"
+
+    # Mock the AsyncAgent class
+    mock_agent_class = mocker.patch(
"utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="temp_session_id") + + # Call function + result_agent, result_session_id, result_conversation_id = await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Assert agent, session_id, and conversation_id are created and returned + assert result_agent == mock_agent + assert result_session_id == "temp_session_id" + assert result_conversation_id == mock_agent.agent_id + + # Verify Agent was created with correct parameters for temporary agent + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + enable_session_persistence=False, # Key difference: no persistence + ) + + # Verify agent was initialized and session was created + mock_agent.initialize.assert_called_once() + mock_agent.create_session.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_temp_agent_returns_valid_ids(prepare_agent_mocks, mocker): + """Test get_temp_agent function returns valid agent_id and session_id.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.agent_id = "generated_agent_id" + mock_agent.create_session.return_value = "generated_session_id" + + # Mock Agent class + mocker.patch("utils.endpoints.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="generated_session_id") + + # Call function + result_agent, result_session_id, result_conversation_id = await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Assert all three values are returned and are not None/empty + assert result_agent is not None + assert result_session_id is not None + assert result_conversation_id is not None + + # Assert they are strings + assert isinstance(result_session_id, str) + assert isinstance(result_conversation_id, str) + + # Assert conversation_id matches agent_id + assert result_conversation_id == result_agent.agent_id + + +@pytest.mark.asyncio +async def test_get_temp_agent_no_persistence(prepare_agent_mocks, mocker): + """Test get_temp_agent function creates agent without session persistence.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "temp_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="temp_session_id") + + # Call function + await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Verify Agent was created with session persistence disabled + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + enable_session_persistence=False, + ) + + def test_validate_model_provider_override_allowed_with_action(): """Ensure no exception when caller has MODEL_OVERRIDE and request includes model/provider.""" query_request = QueryRequest(query="q", model="m", provider="p") @@ -677,3 +779,65 @@ def test_validate_model_provider_override_no_override_without_action(): """No exception when request does not include model/provider regardless of permission.""" query_request = QueryRequest(query="q") endpoints.validate_model_provider_override(query_request, set()) + + +def test_get_topic_summary_system_prompt_default(setup_configuration): + """Test that default 
+    profile is configured.
+    """
+    topic_summary_prompt = endpoints.get_topic_summary_system_prompt(
+        setup_configuration
+    )
+    assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT
+
+
+def test_get_topic_summary_system_prompt_with_custom_profile():
+    """Test that custom profile topic summary prompt is returned when available."""
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
+
+    # Load the custom profile to get the expected topic_summary prompt
+    custom_profile = CustomProfile(path="tests/profiles/test/profile.py")
+    prompts = custom_profile.get_prompts()
+
+    topic_summary_prompt = endpoints.get_topic_summary_system_prompt(cfg)
+    assert topic_summary_prompt == prompts.get("topic_summary")
+
+
+def test_get_topic_summary_system_prompt_with_custom_profile_no_topic_summary(mocker):
+    """Test that default topic summary prompt is returned when custom profile has
+    no topic_summary prompt.
+    """
+    test_config = config_dict.copy()
+    test_config["customization"] = {
+        "profile_path": "tests/profiles/test/profile.py",
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
+
+    # Mock the custom profile to return no topic_summary prompt
+    mock_profile = mocker.Mock()
+    mock_profile.get_prompts.return_value = {
+        "default": "some prompt"
+    }  # No topic_summary key
+
+    # Patch the custom_profile property to return our mock
+    mocker.patch.object(cfg.customization, "custom_profile", mock_profile)
+
+    topic_summary_prompt = endpoints.get_topic_summary_system_prompt(cfg)
+    assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT
+
+
+def test_get_topic_summary_system_prompt_no_customization():
+    """Test that default topic summary prompt is returned when customization is None."""
+    test_config = config_dict.copy()
+    test_config["customization"] = None
+    cfg = AppConfig()
+    cfg.init_from_dict(test_config)
+
+    topic_summary_prompt = endpoints.get_topic_summary_system_prompt(cfg)
+    assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT