From 4ac5e515a0a5e2c250716fd65f354e3334239ab9 Mon Sep 17 00:00:00 2001
From: Stephanie
Date: Fri, 19 Sep 2025 11:04:25 -0400
Subject: [PATCH 1/9] update list conversations API to include topic_summary

Signed-off-by: Stephanie
---
 src/app/endpoints/conversations.py           |   1 +
 src/app/endpoints/query.py                   |  55 ++-
 src/app/endpoints/streaming_query.py         |  16 +-
 src/constants.py                             |   4 +
 src/models/database/conversations.py         |   2 +
 src/models/responses.py                      |  12 +
 src/utils/endpoints.py                       |  45 +++
 .../unit/app/endpoints/test_conversations.py | 208 ++++++++++-
 tests/unit/app/endpoints/test_query.py       | 345 ++++++++++++++++++
 .../app/endpoints/test_streaming_query.py    |  28 ++
 tests/unit/utils/test_endpoints.py           | 162 +++++++-
 11 files changed, 863 insertions(+), 15 deletions(-)

diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py
index 68f3e485c..f20512549 100644
--- a/src/app/endpoints/conversations.py
+++ b/src/app/endpoints/conversations.py
@@ -189,6 +189,7 @@ async def get_conversations_list_endpoint_handler(
                 message_count=conv.message_count,
                 last_used_model=conv.last_used_model,
                 last_used_provider=conv.last_used_provider,
+                topic_summary=conv.topic_summary,
             )
             for conv in user_conversations
         ]
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 43c3eb603..8698a3b84 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -3,7 +3,7 @@
 from datetime import datetime, UTC
 import json
 import logging
-from typing import Annotated, Any, cast
+from typing import Annotated, Any, cast, Optional
 
 from llama_stack_client import APIConnectionError
 from llama_stack_client import AsyncLlamaStackClient  # type: ignore
@@ -34,6 +34,8 @@
 from utils.endpoints import (
     check_configuration_loaded,
     get_agent,
+    get_topic_summary_system_prompt,
+    get_temp_agent,
     get_system_prompt,
     validate_conversation_ownership,
     validate_model_provider_override,
@@ -78,7 +80,11 @@ def is_transcripts_enabled() -> bool:
 
 
 def persist_user_conversation_details(
-    user_id: str, conversation_id: str, model: str, provider_id: str
+    user_id: str,
+    conversation_id: str,
+    model: str,
+    provider_id: str,
+    topic_summary: Optional[str],
 ) -> None:
     """Associate conversation to user in the database."""
     with get_session() as session:
@@ -92,6 +98,7 @@ def persist_user_conversation_details(
                 user_id=user_id,
                 last_used_model=model,
                 last_used_provider=provider_id,
+                topic_summary=topic_summary,
                 message_count=1,
             )
             session.add(conversation)
@@ -149,6 +156,38 @@ def evaluate_model_hints(
     return model_id, provider_id
 
 
+async def get_topic_summary(
+    question: str, client: AsyncLlamaStackClient, model_id: str
+) -> str:
+    """Get a topic summary for a question.
+    Args:
+        question: The question to summarize.
+        client: The AsyncLlamaStackClient to use for the request.
+        model_id: The ID of the model to use.
+    Returns:
+        str: The topic summary for the question.
+ """ + topic_summary_system_prompt = get_topic_summary_system_prompt(configuration) + agent, session_id, conversation_id = await get_temp_agent( + client, model_id, topic_summary_system_prompt + ) + response = await agent.create_turn( + messages=[UserMessage(role="user", content=question)], + session_id=session_id, + stream=False, + toolgroups=None, + ) + response = cast(Turn, response) + return ( + interleaved_content_as_str(response.output_message.content) + if ( + getattr(response, "output_message", None) is not None + and getattr(response.output_message, "content", None) is not None + ) + else "" + ) + + @router.post("/query", responses=query_response) @authorize(Action.QUERY) async def query_endpoint_handler( @@ -226,6 +265,17 @@ async def query_endpoint_handler( # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() + # Get the initial topic summary for the conversation + topic_summary = None + with get_session() as session: + existing_conversation = ( + session.query(UserConversation).filter_by(id=conversation_id).first() + ) + if not existing_conversation: + topic_summary = await get_topic_summary( + query_request.query, client, model_id + ) + if not is_transcripts_enabled(): logger.debug("Transcript collection is disabled in the configuration") else: @@ -248,6 +298,7 @@ async def query_endpoint_handler( conversation_id=conversation_id, model=model_id, provider_id=provider_id, + topic_summary=topic_summary, ) return QueryResponse( diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 60d9d4d6e..468a252ed 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -19,7 +19,7 @@ from fastapi import APIRouter, HTTPException, Request, Depends, status from fastapi.responses import StreamingResponse - +from app.database import get_session from authentication import get_auth_dependency from authentication.interface import AuthTuple from authorization.middleware import authorize @@ -46,6 +46,7 @@ validate_conversation_ownership, persist_user_conversation_details, evaluate_model_hints, + get_topic_summary, ) logger = logging.getLogger("app.endpoints.handlers") @@ -659,11 +660,24 @@ async def response_generator( attachments=query_request.attachments or [], ) + # Get the initial topic summary for the conversation + + topic_summary = None + with get_session() as session: + existing_conversation = ( + session.query(UserConversation).filter_by(id=conversation_id).first() + ) + if not existing_conversation: + topic_summary = await get_topic_summary( + query_request.query, client, model_id + ) + persist_user_conversation_details( user_id=user_id, conversation_id=conversation_id, model=model_id, provider_id=provider_id, + topic_summary=topic_summary, ) # Update metrics for the LLM call diff --git a/src/constants.py b/src/constants.py index e79ebcebb..1422b8911 100644 --- a/src/constants.py +++ b/src/constants.py @@ -28,6 +28,10 @@ # configuration file nor in the query request DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant" +# Default topic summary system prompt used only when no other topic summary system prompt is specified in +# configuration file nor in the query request +DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT = "You are a topic summarizer" + # Authentication constants DEFAULT_VIRTUAL_PATH = "/ls-access" DEFAULT_USER_NAME = "lightspeed-user" diff --git a/src/models/database/conversations.py b/src/models/database/conversations.py index 1cce8a64d..fd720b418 100644 --- 
a/src/models/database/conversations.py +++ b/src/models/database/conversations.py @@ -34,3 +34,5 @@ class UserConversation(Base): # pylint: disable=too-few-public-methods # The number of user messages in the conversation message_count: Mapped[int] = mapped_column(default=0) + + topic_summary: Mapped[str] = mapped_column(default="") diff --git a/src/models/responses.py b/src/models/responses.py index 29b5e5776..1aee4584a 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -482,6 +482,7 @@ class ConversationDetails(BaseModel): message_count: Number of user messages in the conversation. last_used_model: The last model used for the conversation. last_used_provider: The provider of the last used model. + topic_summary: The topic summary for the conversation. Example: ```python @@ -492,6 +493,7 @@ class ConversationDetails(BaseModel): message_count=5, last_used_model="gemini/gemini-2.0-flash", last_used_provider="gemini", + topic_summary="Openshift Microservices Deployment Strategies", ) ``` """ @@ -532,6 +534,12 @@ class ConversationDetails(BaseModel): examples=["openai", "gemini"], ) + topic_summary: Optional[str] = Field( + None, + description="Topic summary for the conversation", + examples=["Openshift Microservices Deployment Strategies"], + ) + class ConversationsListResponse(BaseModel): """Model representing a response for listing conversations of a user. @@ -550,6 +558,7 @@ class ConversationsListResponse(BaseModel): message_count=5, last_used_model="gemini/gemini-2.0-flash", last_used_provider="gemini", + topic_summary="Openshift Microservices Deployment Strategies", ), ConversationDetails( conversation_id="456e7890-e12b-34d5-a678-901234567890" @@ -557,6 +566,7 @@ class ConversationsListResponse(BaseModel): message_count=2, last_used_model="gemini/gemini-2.0-flash", last_used_provider="gemini", + topic_summary="RHDH Purpose Summary", ) ] ) @@ -578,6 +588,7 @@ class ConversationsListResponse(BaseModel): "message_count": 5, "last_used_model": "gemini/gemini-2.0-flash", "last_used_provider": "gemini", + "topic_summary": "Openshift Microservices Deployment Strategies", }, { "conversation_id": "456e7890-e12b-34d5-a678-901234567890", @@ -585,6 +596,7 @@ class ConversationsListResponse(BaseModel): "message_count": 2, "last_used_model": "gemini/gemini-2.5-flash", "last_used_provider": "gemini", + "topic_summary": "RHDH Purpose Summary", }, ] } diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index e17a76d06..0c98f9f12 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -111,6 +111,20 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str: return constants.DEFAULT_SYSTEM_PROMPT +def get_topic_summary_system_prompt(config: AppConfig) -> str: + """Get the topic summary system prompt.""" + # profile takes precedence for setting prompt + if ( + config.customization is not None + and config.customization.custom_profile is not None + ): + prompt = config.customization.custom_profile.get_prompts().get("topic_summary") + if prompt: + return prompt + + return constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT + + def validate_model_provider_override( query_request: QueryRequest, authorized_actions: set[Action] | frozenset[Action] ) -> None: @@ -185,3 +199,34 @@ async def get_agent( session_id = await agent.create_session(get_suid()) return agent, conversation_id, session_id + + +async def get_temp_agent( + client: AsyncLlamaStackClient, + model_id: str, + system_prompt: str, +) -> tuple[AsyncAgent, str, str]: + """Create a temporary agent 
with new agent_id and session_id. + This function creates a new agent without persistence, shields, or tools. + Useful for temporary operations or one-off queries, such as validating a question or generating a summary. + Args: + client: The AsyncLlamaStackClient to use for the request. + model_id: The ID of the model to use. + system_prompt: The system prompt/instructions for the agent. + Returns: + tuple[AsyncAgent, str]: A tuple containing the agent and session_id. + """ + logger.debug("Creating temporary agent") + agent = AsyncAgent( + client, # type: ignore[arg-type] + model=model_id, + instructions=system_prompt, + enable_session_persistence=False, # Temporary agent doesn't need persistence + ) + await agent.initialize() + + # Generate new IDs for the temporary agent + conversation_id = agent.agent_id + session_id = await agent.create_session(get_suid()) + + return agent, session_id, conversation_id diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 7c372da13..89548782a 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -48,6 +48,7 @@ def create_mock_conversation( message_count, last_used_model, last_used_provider, + topic_summary=None, ): # pylint: disable=too-many-arguments,too-many-positional-arguments """Helper function to create a mock conversation object with all required attributes.""" mock_conversation = mocker.Mock() @@ -59,6 +60,7 @@ def create_mock_conversation( mock_conversation.message_count = message_count mock_conversation.last_used_model = last_used_model mock_conversation.last_used_provider = last_used_provider + mock_conversation.topic_summary = topic_summary return mock_conversation @@ -627,6 +629,7 @@ async def test_successful_conversations_list_retrieval( 5, "gemini/gemini-2.0-flash", "gemini", + "OpenStack deployment strategies", ), create_mock_conversation( mocker, @@ -636,6 +639,7 @@ async def test_successful_conversations_list_retrieval( 2, "gemini/gemini-2.5-flash", "gemini", + "Kubernetes troubleshooting", ), ] mock_database_session(mocker, mock_conversations) @@ -646,14 +650,26 @@ async def test_successful_conversations_list_retrieval( assert isinstance(response, ConversationsListResponse) assert len(response.conversations) == 2 - assert ( - response.conversations[0].conversation_id - == "123e4567-e89b-12d3-a456-426614174000" - ) - assert ( - response.conversations[1].conversation_id - == "456e7890-e12b-34d5-a678-901234567890" - ) + + # Test first conversation + conv1 = response.conversations[0] + assert conv1.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert conv1.created_at == "2024-01-01T00:00:00Z" + assert conv1.last_message_at == "2024-01-01T00:05:00Z" + assert conv1.message_count == 5 + assert conv1.last_used_model == "gemini/gemini-2.0-flash" + assert conv1.last_used_provider == "gemini" + assert conv1.topic_summary == "OpenStack deployment strategies" + + # Test second conversation + conv2 = response.conversations[1] + assert conv2.conversation_id == "456e7890-e12b-34d5-a678-901234567890" + assert conv2.created_at == "2024-01-01T01:00:00Z" + assert conv2.last_message_at == "2024-01-01T01:02:00Z" + assert conv2.message_count == 2 + assert conv2.last_used_model == "gemini/gemini-2.5-flash" + assert conv2.last_used_provider == "gemini" + assert conv2.topic_summary == "Kubernetes troubleshooting" @pytest.mark.asyncio async def test_empty_conversations_list( @@ -691,7 +707,177 @@ async def 
test_database_exception(self, mocker, setup_configuration, dummy_reque assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unknown error" in exc_info.value.detail["response"] - assert ( - "Unknown error while getting conversations for user" - in exc_info.value.detail["cause"] + + @pytest.mark.asyncio + async def test_conversations_list_with_none_topic_summary( + self, mocker, setup_configuration, dummy_request + ): + """Test conversations list when topic_summary is None.""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session with conversation having None topic_summary + mock_conversations = [ + create_mock_conversation( + mocker, + "123e4567-e89b-12d3-a456-426614174000", + "2024-01-01T00:00:00Z", + "2024-01-01T00:05:00Z", + 5, + "gemini/gemini-2.0-flash", + "gemini", + None, # topic_summary is None + ), + ] + mock_database_session(mocker, mock_conversations) + + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 1 + + conv = response.conversations[0] + assert conv.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert conv.topic_summary is None + + @pytest.mark.asyncio + async def test_conversations_list_with_mixed_topic_summaries( + self, mocker, setup_configuration, dummy_request + ): + """Test conversations list with mixed topic_summary values (some None, some not).""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session with mixed topic_summary values + mock_conversations = [ + create_mock_conversation( + mocker, + "123e4567-e89b-12d3-a456-426614174000", + "2024-01-01T00:00:00Z", + "2024-01-01T00:05:00Z", + 5, + "gemini/gemini-2.0-flash", + "gemini", + "OpenStack deployment strategies", # Has topic_summary + ), + create_mock_conversation( + mocker, + "456e7890-e12b-34d5-a678-901234567890", + "2024-01-01T01:00:00Z", + "2024-01-01T01:02:00Z", + 2, + "gemini/gemini-2.5-flash", + "gemini", + None, # No topic_summary + ), + create_mock_conversation( + mocker, + "789e0123-e45b-67d8-a901-234567890123", + "2024-01-01T02:00:00Z", + "2024-01-01T02:03:00Z", + 3, + "openai/gpt-4", + "openai", + "Machine learning model training", # Has topic_summary + ), + ] + mock_database_session(mocker, mock_conversations) + + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 3 + + # Test first conversation (with topic_summary) + conv1 = response.conversations[0] + assert conv1.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert conv1.topic_summary == "OpenStack deployment strategies" + + # Test second conversation (without topic_summary) + conv2 = response.conversations[1] + assert conv2.conversation_id == "456e7890-e12b-34d5-a678-901234567890" + assert conv2.topic_summary is None + + # Test third conversation (with topic_summary) + conv3 = response.conversations[2] + assert conv3.conversation_id == "789e0123-e45b-67d8-a901-234567890123" + assert conv3.topic_summary == "Machine learning model training" + + @pytest.mark.asyncio + async def test_conversations_list_with_empty_topic_summary( + self, mocker, setup_configuration, dummy_request + ): + """Test conversations list 
when topic_summary is an empty string.""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session with conversation having empty topic_summary + mock_conversations = [ + create_mock_conversation( + mocker, + "123e4567-e89b-12d3-a456-426614174000", + "2024-01-01T00:00:00Z", + "2024-01-01T00:05:00Z", + 5, + "gemini/gemini-2.0-flash", + "gemini", + "", # Empty topic_summary + ), + ] + mock_database_session(mocker, mock_conversations) + + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request ) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 1 + + conv = response.conversations[0] + assert conv.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert conv.topic_summary == "" + + @pytest.mark.asyncio + async def test_conversations_list_topic_summary_field_presence( + self, mocker, setup_configuration, dummy_request + ): + """Test that topic_summary field is always present in ConversationDetails objects.""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + + # Mock database session with conversations + mock_conversations = [ + create_mock_conversation( + mocker, + "123e4567-e89b-12d3-a456-426614174000", + "2024-01-01T00:00:00Z", + "2024-01-01T00:05:00Z", + 5, + "gemini/gemini-2.0-flash", + "gemini", + "Test topic summary", + ), + ] + mock_database_session(mocker, mock_conversations) + + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) + + assert isinstance(response, ConversationsListResponse) + assert len(response.conversations) == 1 + + conv = response.conversations[0] + + # Verify that topic_summary field exists and is accessible + assert hasattr(conv, "topic_summary") + assert conv.topic_summary == "Test topic summary" + + # Verify that the field is properly serialized (if needed for API responses) + conv_dict = conv.model_dump() + assert "topic_summary" in conv_dict + assert conv_dict["topic_summary"] == "Test topic summary" diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 3b3d64f3f..7034d64c5 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -23,6 +23,7 @@ is_transcripts_enabled, get_rag_toolgroups, evaluate_model_hints, + get_topic_summary, ) from models.requests import QueryRequest, Attachment @@ -62,6 +63,13 @@ def mock_database_operations(mocker): ) mocker.patch("app.endpoints.query.persist_user_conversation_details") + # Mock the database session and query + mock_session = mocker.Mock() + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.__enter__ = mocker.Mock(return_value=mock_session) + mock_session.__exit__ = mocker.Mock(return_value=None) + mocker.patch("app.endpoints.query.get_session", return_value=mock_session) + @pytest.fixture(name="setup_configuration") def setup_configuration_fixture(): @@ -182,6 +190,11 @@ async def _test_query_endpoint_handler( ) mock_transcript = mocker.patch("app.endpoints.query.store_transcript") + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Test topic summary" + ) + # Mock database operations mock_database_operations(mocker) @@ -1208,6 +1221,10 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_requ 
return_value=("test_model", "test_model", "test_provider"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Test topic summary" + ) # Mock database operations mock_database_operations(mocker) @@ -1258,6 +1275,10 @@ async def test_query_endpoint_handler_no_tools_true(mocker, dummy_request): return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Test topic summary" + ) # Mock database operations mock_database_operations(mocker) @@ -1309,6 +1330,10 @@ async def test_query_endpoint_handler_no_tools_false(mocker, dummy_request): return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Test topic summary" + ) # Mock database operations mock_database_operations(mocker) @@ -1586,3 +1611,323 @@ async def test_query_endpoint_rejects_model_provider_override_without_permission ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert exc_info.value.detail["response"] == expected_msg + + +@pytest.mark.asyncio +async def test_get_topic_summary_successful_response(mocker): + """Test get_topic_summary with successful response from agent.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message.content = "This is a topic summary about OpenStack" + + # Mock the get_temp_agent function + mock_get_temp_agent = mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mocker.patch( + "app.endpoints.query.interleaved_content_as_str", + return_value="This is a topic summary about OpenStack", + ) + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="You are a topic summarizer", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "This is a topic summary about OpenStack" + + # Verify get_temp_agent was called with correct parameters + mock_get_temp_agent.assert_called_once_with( + mock_client, "test_model", "You are a topic summarizer" + ) + + # Verify create_turn was called with correct parameters + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="session_123", + stream=False, + toolgroups=None, + ) + + +@pytest.mark.asyncio +async def test_get_topic_summary_empty_response(mocker): + """Test get_topic_summary with empty response from agent.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message = None + + # Mock the get_temp_agent 
function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="You are a topic summarizer", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "" + + +@pytest.mark.asyncio +async def test_get_topic_summary_none_content(mocker): + """Test get_topic_summary with None content in response.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message.content = None + + # Mock the get_temp_agent function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="You are a topic summarizer", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "" + + +@pytest.mark.asyncio +async def test_get_topic_summary_with_interleaved_content(mocker): + """Test get_topic_summary with interleaved content response.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_content = [TextContentItem(text="Topic summary", type="text")] + mock_response.output_message.content = mock_content + + # Mock the get_temp_agent function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mock_interleaved_content_as_str = mocker.patch( + "app.endpoints.query.interleaved_content_as_str", return_value="Topic summary" + ) + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="You are a topic summarizer", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "Topic summary" + + # Verify interleaved_content_as_str was called with the content + mock_interleaved_content_as_str.assert_called_once_with(mock_content) + + +@pytest.mark.asyncio +async def test_get_topic_summary_system_prompt_retrieval(mocker): + """Test that get_topic_summary properly retrieves and uses the system prompt.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = 
mocker.Mock() + mock_response.output_message.content = "Topic summary" + + # Mock the get_temp_agent function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mocker.patch( + "app.endpoints.query.interleaved_content_as_str", return_value="Topic summary" + ) + + # Mock the get_topic_summary_system_prompt function + mock_get_topic_summary_system_prompt = mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="Custom topic summarizer prompt", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is OpenStack?", client=mock_client, model_id="test_model" + ) + + # Assertions + assert result == "Topic summary" + + # Verify get_topic_summary_system_prompt was called with configuration + mock_get_topic_summary_system_prompt.assert_called_once_with(mock_config) + + +@pytest.mark.asyncio +async def test_get_topic_summary_agent_creation_parameters(mocker): + """Test that get_topic_summary creates agent with correct parameters.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message.content = "Topic summary" + + # Mock the get_temp_agent function + mock_get_temp_agent = mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "session_123", "conversation_456"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mocker.patch( + "app.endpoints.query.interleaved_content_as_str", return_value="Topic summary" + ) + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="Custom system prompt", + ) + + # Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="Test question?", client=mock_client, model_id="custom_model" + ) + + # Assertions + assert result == "Topic summary" + + # Verify get_temp_agent was called with correct parameters + mock_get_temp_agent.assert_called_once_with( + mock_client, "custom_model", "Custom system prompt" + ) + + +@pytest.mark.asyncio +async def test_get_topic_summary_create_turn_parameters(mocker): + """Test that get_topic_summary calls create_turn with correct parameters.""" + # Mock the dependencies + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + mock_response = mocker.Mock() + mock_response.output_message.content = "Topic summary" + + # Mock the get_temp_agent function + mocker.patch( + "app.endpoints.query.get_temp_agent", + return_value=(mock_agent, "test_session", "test_conversation"), + ) + + # Mock the agent's create_turn method + mock_agent.create_turn.return_value = mock_response + + # Mock the interleaved_content_as_str function + mocker.patch( + "app.endpoints.query.interleaved_content_as_str", return_value="Topic summary" + ) + + # Mock the get_topic_summary_system_prompt function + mocker.patch( + "app.endpoints.query.get_topic_summary_system_prompt", + return_value="Custom system prompt", + ) + + # 
Mock the configuration + mock_config = mocker.Mock() + mocker.patch("app.endpoints.query.configuration", mock_config) + + # Call the function + result = await get_topic_summary( + question="What is the meaning of life?", + client=mock_client, + model_id="test_model", + ) + + # Assertions + assert result == "Topic summary" + + # Verify create_turn was called with correct parameters + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(role="user", content="What is the meaning of life?")], + session_id="test_session", + stream=False, + toolgroups=None, + ) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 38983666a..3e2ecf580 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -58,6 +58,13 @@ def mock_database_operations(mocker): ) mocker.patch("app.endpoints.streaming_query.persist_user_conversation_details") + # Mock the database session and query + mock_session = mocker.Mock() + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.__enter__ = mocker.Mock(return_value=mock_session) + mock_session.__exit__ = mocker.Mock(return_value=None) + mocker.patch("app.endpoints.streaming_query.get_session", return_value=mock_session) + def mock_metrics(mocker): """Helper function to mock metrics operations for streaming query endpoints.""" @@ -283,6 +290,12 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) ) mock_transcript = mocker.patch("app.endpoints.streaming_query.store_transcript") + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.streaming_query.get_topic_summary", + return_value="Test topic summary", + ) + mock_database_operations(mocker) query_request = QueryRequest(query=query) @@ -1308,6 +1321,11 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): mocker.patch( "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False ) + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.streaming_query.get_topic_summary", + return_value="Test topic summary", + ) mock_database_operations(mocker) request = Request( @@ -1354,6 +1372,11 @@ async def test_streaming_query_endpoint_handler_no_tools_true(mocker): mocker.patch( "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False ) + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.streaming_query.get_topic_summary", + return_value="Test topic summary", + ) # Mock database operations mock_database_operations(mocker) @@ -1401,6 +1424,11 @@ async def test_streaming_query_endpoint_handler_no_tools_false(mocker): mocker.patch( "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False ) + # Mock get_topic_summary function + mocker.patch( + "app.endpoints.streaming_query.get_topic_summary", + return_value="Test topic summary", + ) # Mock database operations mock_database_operations(mocker) diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py index 04701ac48..47ba3f7b8 100644 --- a/tests/unit/utils/test_endpoints.py +++ b/tests/unit/utils/test_endpoints.py @@ -10,7 +10,7 @@ from models.requests import QueryRequest from models.config import Action from utils import endpoints -from utils.endpoints import get_agent +from utils.endpoints import get_agent, get_temp_agent from tests.unit import config_dict @@ -657,6 +657,108 @@ async def 
test_get_agent_no_tools_false_preserves_parser( ) +@pytest.mark.asyncio +async def test_get_temp_agent_basic_functionality(prepare_agent_mocks, mocker): + """Test get_temp_agent function creates agent with correct parameters.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "temp_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="temp_session_id") + + # Call function + result_agent, result_session_id, result_conversation_id = await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Assert agent, session_id, and conversation_id are created and returned + assert result_agent == mock_agent + assert result_session_id == "temp_session_id" + assert result_conversation_id == mock_agent.agent_id + + # Verify Agent was created with correct parameters for temporary agent + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + enable_session_persistence=False, # Key difference: no persistence + ) + + # Verify agent was initialized and session was created + mock_agent.initialize.assert_called_once() + mock_agent.create_session.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_temp_agent_returns_valid_ids(prepare_agent_mocks, mocker): + """Test get_temp_agent function returns valid agent_id and session_id.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.agent_id = "generated_agent_id" + mock_agent.create_session.return_value = "generated_session_id" + + # Mock Agent class + mocker.patch("utils.endpoints.AsyncAgent", return_value=mock_agent) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="generated_session_id") + + # Call function + result_agent, result_session_id, result_conversation_id = await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Assert all three values are returned and are not None/empty + assert result_agent is not None + assert result_session_id is not None + assert result_conversation_id is not None + + # Assert they are strings + assert isinstance(result_session_id, str) + assert isinstance(result_conversation_id, str) + + # Assert conversation_id matches agent_id + assert result_conversation_id == result_agent.agent_id + + +@pytest.mark.asyncio +async def test_get_temp_agent_no_persistence(prepare_agent_mocks, mocker): + """Test get_temp_agent function creates agent without session persistence.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "temp_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "utils.endpoints.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("utils.endpoints.get_suid", return_value="temp_session_id") + + # Call function + await get_temp_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + ) + + # Verify Agent was created with session persistence disabled + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + enable_session_persistence=False, + ) + + def test_validate_model_provider_override_allowed_with_action(): """Ensure no exception when caller has MODEL_OVERRIDE and request includes model/provider.""" query_request = QueryRequest(query="q", model="m", 
provider="p") @@ -677,3 +779,61 @@ def test_validate_model_provider_override_no_override_without_action(): """No exception when request does not include model/provider regardless of permission.""" query_request = QueryRequest(query="q") endpoints.validate_model_provider_override(query_request, set()) + + +def test_get_topic_summary_system_prompt_default(setup_configuration): + """Test that default topic summary system prompt is returned when no custom profile is configured.""" + topic_summary_prompt = endpoints.get_topic_summary_system_prompt( + setup_configuration + ) + assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT + + +def test_get_topic_summary_system_prompt_with_custom_profile(): + """Test that custom profile topic summary prompt is returned when available.""" + test_config = config_dict.copy() + test_config["customization"] = { + "profile_path": "tests/profiles/test/profile.py", + } + cfg = AppConfig() + cfg.init_from_dict(test_config) + + # Mock the custom profile to return a topic_summary prompt + custom_profile = CustomProfile(path="tests/profiles/test/profile.py") + prompts = custom_profile.get_prompts() + + topic_summary_prompt = endpoints.get_topic_summary_system_prompt(cfg) + assert topic_summary_prompt == prompts.get("topic_summary") + + +def test_get_topic_summary_system_prompt_with_custom_profile_no_topic_summary(mocker): + """Test that default topic summary prompt is returned when custom profile has no topic_summary prompt.""" + test_config = config_dict.copy() + test_config["customization"] = { + "profile_path": "tests/profiles/test/profile.py", + } + cfg = AppConfig() + cfg.init_from_dict(test_config) + + # Mock the custom profile to return None for topic_summary prompt + mock_profile = mocker.Mock() + mock_profile.get_prompts.return_value = { + "default": "some prompt" + } # No topic_summary key + + # Patch the custom_profile property to return our mock + mocker.patch.object(cfg.customization, "custom_profile", mock_profile) + + topic_summary_prompt = endpoints.get_topic_summary_system_prompt(cfg) + assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT + + +def test_get_topic_summary_system_prompt_no_customization(): + """Test that default topic summary prompt is returned when customization is None.""" + test_config = config_dict.copy() + test_config["customization"] = None + cfg = AppConfig() + cfg.init_from_dict(test_config) + + topic_summary_prompt = endpoints.get_topic_summary_system_prompt(cfg) + assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT From 52f53d21880f4e655dfcde54c9b28fd0e6f65b2b Mon Sep 17 00:00:00 2001 From: Stephanie Date: Fri, 19 Sep 2025 11:22:56 -0400 Subject: [PATCH 2/9] generate docs Signed-off-by: Stephanie --- docs/openapi.json | 98 +++++++++++++++++++++++++++++++++++++++++++++-- src/constants.py | 61 ++++++++++++++++++++++++++++- 2 files changed, 153 insertions(+), 6 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index dc7dfbb63..4185a1557 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1155,6 +1155,9 @@ }, "inference": { "$ref": "#/components/schemas/InferenceConfiguration" + }, + "conversation_cache": { + "$ref": "#/components/schemas/ConversationCache" } }, "additionalProperties": false, @@ -1168,6 +1171,60 @@ "title": "Configuration", "description": "Global service configuration." 
     },
+    "ConversationCache": {
+      "properties": {
+        "type": {
+          "anyOf": [
+            {
+              "type": "string",
+              "enum": [
+                "memory",
+                "sqlite",
+                "postgres"
+              ]
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "title": "Type"
+        },
+        "memory": {
+          "anyOf": [
+            {
+              "$ref": "#/components/schemas/InMemoryCacheConfig"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        },
+        "sqlite": {
+          "anyOf": [
+            {
+              "$ref": "#/components/schemas/SQLiteDatabaseConfiguration"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        },
+        "postgres": {
+          "anyOf": [
+            {
+              "$ref": "#/components/schemas/PostgreSQLDatabaseConfiguration"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        }
+      },
+      "additionalProperties": false,
+      "type": "object",
+      "title": "ConversationCache",
+      "description": "Conversation cache configuration."
+    },
     "ConversationDeleteResponse": {
       "properties": {
         "conversation_id": {
@@ -1285,6 +1342,21 @@
             "openai",
             "gemini"
           ]
+        },
+        "topic_summary": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "title": "Topic Summary",
+          "description": "Topic summary for the conversation",
+          "examples": [
+            "Openshift Microservices Deployment Strategies"
+          ]
         }
       },
       "type": "object",
      "required": [
        "conversation_id"
      ],
      "title": "ConversationDetails",
-      "description": "Model representing the details of a user conversation.\n\nAttributes:\n conversation_id: The conversation ID (UUID).\n created_at: When the conversation was created.\n last_message_at: When the last message was sent.\n message_count: Number of user messages in the conversation.\n last_used_model: The last model used for the conversation.\n last_used_provider: The provider of the last used model.\n\nExample:\n ```python\n conversation = ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\"\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n )\n ```"
+      "description": "Model representing the details of a user conversation.\n\nAttributes:\n conversation_id: The conversation ID (UUID).\n created_at: When the conversation was created.\n last_message_at: When the last message was sent.\n message_count: Number of user messages in the conversation.\n last_used_model: The last model used for the conversation.\n last_used_provider: The provider of the last used model.\n topic_summary: The topic summary for the conversation.\n\nExample:\n ```python\n conversation = ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\"\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n topic_summary=\"Openshift Microservices Deployment Strategies\",\n )\n ```"
     },
     "ConversationResponse": {
       "properties": {
         "conversation_id": {
@@ -1353,7 +1425,7 @@
       "conversations"
     ],
     "title": "ConversationsListResponse",
-    "description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation details associated with the user.\n\nExample:\n ```python\n conversations_list = ConversationsListResponse(\n conversations=[\n ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\",\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n ),\n ConversationDetails(\n conversation_id=\"456e7890-e12b-34d5-a678-901234567890\"\n created_at=\"2024-01-01T01:00:00Z\",\n message_count=2,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n )\n ]\n )\n ```",
+    "description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation details associated with the user.\n\nExample:\n ```python\n conversations_list = ConversationsListResponse(\n conversations=[\n ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\",\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n topic_summary=\"Openshift Microservices Deployment Strategies\",\n ),\n ConversationDetails(\n conversation_id=\"456e7890-e12b-34d5-a678-901234567890\"\n created_at=\"2024-01-01T01:00:00Z\",\n message_count=2,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n topic_summary=\"RHDH Purpose Summary\",\n )\n ]\n )\n ```",
     "examples": [
       {
         "conversations": [
@@ -1363,14 +1435,16 @@
             "last_message_at": "2024-01-01T00:05:00Z",
             "last_used_model": "gemini/gemini-2.0-flash",
             "last_used_provider": "gemini",
-            "message_count": 5
+            "message_count": 5,
+            "topic_summary": "Openshift Microservices Deployment Strategies"
           },
           {
             "conversation_id": "456e7890-e12b-34d5-a678-901234567890",
             "created_at": "2024-01-01T01:00:00Z",
             "last_used_model": "gemini/gemini-2.5-flash",
             "last_used_provider": "gemini",
-            "message_count": 2
+            "message_count": 2,
+            "topic_summary": "RHDH Purpose Summary"
           }
         ]
       }
@@ -1751,6 +1825,22 @@
       "type": "object",
       "title": "HTTPValidationError"
     },
+    "InMemoryCacheConfig": {
+      "properties": {
+        "max_entries": {
+          "type": "integer",
+          "exclusiveMinimum": 0.0,
+          "title": "Max Entries"
+        }
+      },
+      "additionalProperties": false,
+      "type": "object",
+      "required": [
+        "max_entries"
+      ],
+      "title": "InMemoryCacheConfig",
+      "description": "In-memory cache configuration."
+    },
     "InferenceConfiguration": {
       "properties": {
         "default_model": {
diff --git a/src/constants.py b/src/constants.py
index f73270166..26a05213e 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -29,8 +29,65 @@
 DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant"
 
 # Default topic summary system prompt used only when no other topic summary system prompt is specified in
-# configuration file nor in the query request
-DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT = "You are a topic summarizer"
+# configuration file
+DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT = """
+Instructions:
+- You are a topic summarizer
+- Your job is to extract precise topic summary from user input
+
+For Input Analysis:
+- Scan entire user message
+- Identify core subject matter
+- Distill essence into concise descriptor
+- Prioritize key concepts
+- Eliminate extraneous details
+
+For Output Constraints:
+- Maximum 5 words
+- Capitalize only significant words (e.g., nouns, verbs, adjectives, adverbs).
+- Do not use all uppercase - capitalize only the first letter of significant words
+- Exclude articles and prepositions (e.g., "a," "the," "of," "on," "in")
+- Exclude all punctuation and interpunction marks (e.g., . , : ; ! ? "")
+- Retain original abbreviations. Do not expand an abbreviation if its specific meaning in the context is unknown or ambiguous.
+- Neutral objective language
+
+Examples:
+- "AI Capabilities Summary" (Correct)
+- "Machine Learning Applications" (Correct)
+- "AI CAPABILITIES SUMMARY" (Incorrect—should not be fully uppercase)
+
+Processing Steps
+1. Analyze semantic structure
+2. Identify primary topic
+3. Remove contextual noise
+4. Condense to essential meaning
+5. Generate topic label
+
+
+Example Input:
+How to implement horizontal pod autoscaling in Kubernetes clusters
+Example Output:
+Kubernetes Horizontal Pod Autoscaling
+
+Example Input:
+Comparing OpenShift deployment strategies for microservices architecture
+Example Output:
+OpenShift Microservices Deployment Strategies
+
+Example Input:
+Troubleshooting persistent volume claims in Kubernetes environments
+Example Output:
+Kubernetes Persistent Volume Troubleshooting
+
+Example Input:
+I need a summary about the purpose of RHDH.
+Example Output:
+RHDH Purpose Summary
+
+Input:
+{query}
+Output:
+"""
 
 # Authentication constants
 DEFAULT_VIRTUAL_PATH = "/ls-access"

From f3da43eeacb91f04f117b055f811ef6140e3142f Mon Sep 17 00:00:00 2001
From: Stephanie
Date: Fri, 19 Sep 2025 11:37:54 -0400
Subject: [PATCH 3/9] fix format

Signed-off-by: Stephanie
---
 src/app/endpoints/query.py           |  3 ++-
 src/app/endpoints/streaming_query.py | 26 +++++++++++++-------------
 src/constants.py                     |  4 ++--
 src/utils/endpoints.py               |  4 +++-
 tests/unit/utils/test_endpoints.py   |  8 ++++++--
 5 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 1915051c0..34dca2966 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -160,6 +160,7 @@ async def get_topic_summary(
     question: str, client: AsyncLlamaStackClient, model_id: str
 ) -> str:
     """Get a topic summary for a question.
+
     Args:
         question: The question to summarize.
         client: The AsyncLlamaStackClient to use for the request.
@@ -168,7 +169,7 @@ async def get_topic_summary(
         str: The topic summary for the question.
     """
     topic_summary_system_prompt = get_topic_summary_system_prompt(configuration)
-    agent, session_id, conversation_id = await get_temp_agent(
+    agent, session_id, _ = await get_temp_agent(
         client, model_id, topic_summary_system_prompt
     )
     response = await agent.create_turn(
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 985882ef7..24f7291c7 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -19,7 +19,20 @@
 
 from fastapi import APIRouter, HTTPException, Request, Depends, status
 from fastapi.responses import StreamingResponse
+
 from app.database import get_session
+from app.endpoints.query import (
+    get_rag_toolgroups,
+    is_input_shield,
+    is_output_shield,
+    is_transcripts_enabled,
+    select_model_and_provider_id,
+    validate_attachments_metadata,
+    validate_conversation_ownership,
+    persist_user_conversation_details,
+    evaluate_model_hints,
+    get_topic_summary,
+)
 from authentication import get_auth_dependency
 from authentication.interface import AuthTuple
 from authorization.middleware import authorize
@@ -37,19 +50,6 @@
 from utils.types import TurnSummary
 from utils.endpoints import validate_model_provider_override
 
-from app.endpoints.query import (
-    get_rag_toolgroups,
-    is_input_shield,
-    is_output_shield,
-    is_transcripts_enabled,
-    select_model_and_provider_id,
-    validate_attachments_metadata,
-    validate_conversation_ownership,
-    persist_user_conversation_details,
-    evaluate_model_hints,
-    get_topic_summary,
-)
-
 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["streaming_query"])
 auth_dependency = get_auth_dependency()
diff --git a/src/constants.py b/src/constants.py
index 26a05213e..086047d90 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -28,8 +28,8 @@
 # configuration file nor in the query request
 DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant"
 
-# Default topic summary system prompt used only when no other topic summary system prompt is specified in
-# configuration file
+# Default topic summary system prompt used only when no other topic summary system
+# prompt is specified in configuration file
 DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT = """
 Instructions:
 - You are a topic summarizer
diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py
index c7ffaa445..18c181eba 100644
--- a/src/utils/endpoints.py
+++ b/src/utils/endpoints.py
@@ -212,8 +212,10 @@ async def get_temp_agent(
     system_prompt: str,
 ) -> tuple[AsyncAgent, str, str]:
     """Create a temporary agent with new agent_id and session_id.
+
     This function creates a new agent without persistence, shields, or tools.
-    Useful for temporary operations or one-off queries, such as validating a question or generating a summary.
+    Useful for temporary operations or one-off queries, such as validating a
+    question or generating a summary.
     Args:
         client: The AsyncLlamaStackClient to use for the request.
         model_id: The ID of the model to use.
diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py
index 47ba3f7b8..bed970a56 100644
--- a/tests/unit/utils/test_endpoints.py
+++ b/tests/unit/utils/test_endpoints.py
@@ -782,7 +782,9 @@ def test_validate_model_provider_override_no_override_without_action():
 
 
 def test_get_topic_summary_system_prompt_default(setup_configuration):
-    """Test that default topic summary system prompt is returned when no custom profile is configured."""
+    """Test that default topic summary system prompt is returned when no custom
+    profile is configured.
+    """
     topic_summary_prompt = endpoints.get_topic_summary_system_prompt(
         setup_configuration
     )
@@ -807,7 +809,9 @@ def test_get_topic_summary_system_prompt_with_custom_profile():
 
 
 def test_get_topic_summary_system_prompt_with_custom_profile_no_topic_summary(mocker):
-    """Test that default topic summary prompt is returned when custom profile has no topic_summary prompt."""
+    """Test that default topic summary prompt is returned when custom profile has
+    no topic_summary prompt.
+ """ test_config = config_dict.copy() test_config["customization"] = { "profile_path": "tests/profiles/test/profile.py", From 5c5d85f232b8b8ac292b0e2f9f562955ef9e704c Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 22 Sep 2025 11:25:14 -0400 Subject: [PATCH 4/9] ignore pylint error Signed-off-by: Stephanie --- src/app/endpoints/query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 34dca2966..999f197f1 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -191,7 +191,7 @@ async def get_topic_summary( @router.post("/query", responses=query_response) @authorize(Action.QUERY) -async def query_endpoint_handler( +async def query_endpoint_handler( # pylint: disable=R0914 request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], From ed95e0150cd06a2af87c642e9b7160cec306be91 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 29 Sep 2025 14:34:11 -0400 Subject: [PATCH 5/9] fix skip_userid_check Signed-off-by: Stephanie --- src/app/endpoints/conversations_v2.py | 12 +++++++++--- src/app/endpoints/query.py | 3 ++- src/app/endpoints/streaming_query.py | 1 + src/utils/endpoints.py | 3 ++- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py index 5033b5e5f..8ec34869d 100644 --- a/src/app/endpoints/conversations_v2.py +++ b/src/app/endpoints/conversations_v2.py @@ -102,6 +102,8 @@ async def get_conversations_list_endpoint_handler( logger.info("Retrieving conversations for user %s", user_id) + skip_userid_check = auth[2] + if configuration.conversation_cache is None: logger.warning("Converastion cache is not configured") raise HTTPException( @@ -112,7 +114,7 @@ async def get_conversations_list_endpoint_handler( }, ) - conversations = configuration.conversation_cache.list(user_id, False) + conversations = configuration.conversation_cache.list(user_id, skip_userid_check) logger.info("Conversations for user %s: %s", user_id, len(conversations)) return ConversationsListResponseV2(conversations=conversations) @@ -132,6 +134,8 @@ async def get_conversation_endpoint_handler( user_id = auth[0] logger.info("Retrieving conversation %s for user %s", conversation_id, user_id) + skip_userid_check = auth[2] + if configuration.conversation_cache is None: logger.warning("Converastion cache is not configured") raise HTTPException( @@ -144,7 +148,7 @@ async def get_conversation_endpoint_handler( check_conversation_existence(user_id, conversation_id) - conversation = configuration.conversation_cache.get(user_id, conversation_id, False) + conversation = configuration.conversation_cache.get(user_id, conversation_id, skip_userid_check) chat_history = [transform_chat_message(entry) for entry in conversation] return ConversationResponse( @@ -168,6 +172,8 @@ async def delete_conversation_endpoint_handler( user_id = auth[0] logger.info("Deleting conversation %s for user %s", conversation_id, user_id) + skip_userid_check = auth[2] + if configuration.conversation_cache is None: logger.warning("Converastion cache is not configured") raise HTTPException( @@ -181,7 +187,7 @@ async def delete_conversation_endpoint_handler( check_conversation_existence(user_id, conversation_id) logger.info("Deleting conversation %s for user %s", conversation_id, user_id) - deleted = configuration.conversation_cache.delete(user_id, conversation_id, False) + deleted = configuration.conversation_cache.delete(user_id, 
conversation_id, skip_userid_check) if deleted: return ConversationDeleteResponse( diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index b123e4010..75e6801cc 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -240,7 +240,7 @@ async def query_endpoint_handler( # pylint: disable=R0914 # log Llama Stack configuration logger.info("Llama stack config: %s", configuration.llama_stack_configuration) - user_id, _, _, token = auth + user_id, _, _skip_userid_check, token = auth user_conversation: UserConversation | None = None if query_request.conversation_id: @@ -339,6 +339,7 @@ async def query_endpoint_handler( # pylint: disable=R0914 model_id, query_request.query, summary.llm_response, + _skip_userid_check ) # Convert tool calls to response format diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 629c18a72..beb5ee3b5 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -720,6 +720,7 @@ async def response_generator( model_id, query_request.query, summary.llm_response, + _skip_userid_check ) # Get the initial topic summary for the conversation diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 7d04349c5..1f524962f 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -189,6 +189,7 @@ def store_conversation_into_cache( model_id: str, query: str, response: str, + _skip_userid_check: bool ) -> None: """Store one part of conversation into conversation history cache.""" if config.conversation_cache_configuration.type is not None: @@ -202,7 +203,7 @@ def store_conversation_into_cache( provider=provider_id, model=model_id, ) - cache.insert_or_append(user_id, conversation_id, cache_entry, False) + cache.insert_or_append(user_id, conversation_id, cache_entry, _skip_userid_check) # # pylint: disable=R0913,R0917 From c9769e879084d50f90217e117a07f1075543f4f3 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 29 Sep 2025 15:29:17 -0400 Subject: [PATCH 6/9] add topic summary Signed-off-by: Stephanie --- docs/openapi.json | 39 ++++++++- src/app/endpoints/conversations_v2.py | 13 ++- src/app/endpoints/query.py | 3 +- src/app/endpoints/streaming_query.py | 44 +++++------ src/cache/cache.py | 23 +++++- src/cache/in_memory_cache.py | 25 +++++- src/cache/noop_cache.py | 25 +++++- src/cache/postgres_cache.py | 101 ++++++++++++++++++++++-- src/cache/sqlite_cache.py | 98 +++++++++++++++++++++-- src/models/cache_entry.py | 14 ++++ src/models/responses.py | 6 +- src/utils/endpoints.py | 11 ++- tests/unit/cache/test_postgres_cache.py | 66 +++++++++++++++- tests/unit/cache/test_sqlite_cache.py | 94 +++++++++++++++++++++- 14 files changed, 503 insertions(+), 59 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index a33374c01..8ac9be564 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -744,7 +744,9 @@ }, "conversations": [ { - "conversation_id": "123e4567-e89b-12d3-a456-426614174000" + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "topic_summary": "This is a topic summary", + "last_message_timestamp": "2024-01-01T00:00:00Z" } ] } @@ -1419,6 +1421,37 @@ "title": "ConversationCacheConfiguration", "description": "Conversation cache configuration." 
}, + "ConversationData": { + "properties": { + "conversation_id": { + "type": "string", + "title": "Conversation Id" + }, + "topic_summary": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Topic Summary" + }, + "last_message_timestamp": { + "type": "number", + "title": "Last Message Timestamp" + } + }, + "type": "object", + "required": [ + "conversation_id", + "topic_summary", + "last_message_timestamp" + ], + "title": "ConversationData", + "description": "Model representing conversation data returned by cache list operations.\n\nAttributes:\n conversation_id: The conversation ID\n topic_summary: The topic summary for the conversation (can be None)\n last_message_timestamp: The timestamp of the last message in the conversation" + }, "ConversationDeleteResponse": { "properties": { "conversation_id": { @@ -1648,7 +1681,7 @@ "properties": { "conversations": { "items": { - "type": "string" + "$ref": "#/components/schemas/ConversationData" }, "type": "array", "title": "Conversations" @@ -1659,7 +1692,7 @@ "conversations" ], "title": "ConversationsListResponseV2", - "description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation IDs associated with the user." + "description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation data associated with the user." }, "CustomProfile": { "properties": { diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py index 8ec34869d..16932162b 100644 --- a/src/app/endpoints/conversations_v2.py +++ b/src/app/endpoints/conversations_v2.py @@ -83,6 +83,8 @@ "conversations": [ { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "topic_summary": "This is a topic summary", + "last_message_timestamp": "2024-01-01T00:00:00Z", } ] } @@ -148,7 +150,9 @@ async def get_conversation_endpoint_handler( check_conversation_existence(user_id, conversation_id) - conversation = configuration.conversation_cache.get(user_id, conversation_id, skip_userid_check) + conversation = configuration.conversation_cache.get( + user_id, conversation_id, skip_userid_check + ) chat_history = [transform_chat_message(entry) for entry in conversation] return ConversationResponse( @@ -187,7 +191,9 @@ async def delete_conversation_endpoint_handler( check_conversation_existence(user_id, conversation_id) logger.info("Deleting conversation %s for user %s", conversation_id, user_id) - deleted = configuration.conversation_cache.delete(user_id, conversation_id, skip_userid_check) + deleted = configuration.conversation_cache.delete( + user_id, conversation_id, skip_userid_check + ) if deleted: return ConversationDeleteResponse( @@ -221,7 +227,8 @@ def check_conversation_existence(user_id: str, conversation_id: str) -> None: if configuration.conversation_cache is None: return conversations = configuration.conversation_cache.list(user_id, False) - if conversation_id not in conversations: + conversation_ids = [conv.conversation_id for conv in conversations] + if conversation_id not in conversation_ids: logger.error("No conversation found for conversation ID %s", conversation_id) raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 75e6801cc..3f75a6cd4 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -339,7 +339,8 @@ async def query_endpoint_handler( # pylint: disable=R0914 model_id, 
query_request.query, summary.llm_response, - _skip_userid_check + _skip_userid_check, + topic_summary, ) # Convert tool calls to response format diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index beb5ee3b5..0bd868a5d 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -20,13 +20,10 @@ from llama_stack_client.types.shared import ToolCall from llama_stack_client.types.shared.interleaved_content_item import TextContentItem -from fastapi import APIRouter, HTTPException, Request, Depends, status -from fastapi.responses import StreamingResponse from app.database import get_session import metrics from app.endpoints.query import ( - evaluate_model_hints, get_rag_toolgroups, is_input_shield, is_output_shield, @@ -59,7 +56,6 @@ from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency from utils.transcripts import store_transcript from utils.types import TurnSummary -from utils.endpoints import validate_model_provider_override logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["streaming_query"]) @@ -712,6 +708,19 @@ async def response_generator( attachments=query_request.attachments or [], ) + # Get the initial topic summary for the conversation + topic_summary = None + with get_session() as session: + existing_conversation = ( + session.query(UserConversation) + .filter_by(id=conversation_id) + .first() + ) + if not existing_conversation: + topic_summary = await get_topic_summary( + query_request.query, client, model_id + ) + store_conversation_into_cache( configuration, user_id, @@ -720,28 +729,17 @@ async def response_generator( model_id, query_request.query, summary.llm_response, - _skip_userid_check + _skip_userid_check, + topic_summary, ) - # Get the initial topic summary for the conversation - topic_summary = None - with get_session() as session: - existing_conversation = ( - session.query(UserConversation).filter_by(id=conversation_id).first() + persist_user_conversation_details( + user_id=user_id, + conversation_id=conversation_id, + model=model_id, + provider_id=provider_id, + topic_summary=topic_summary, ) - if not existing_conversation: - topic_summary = await get_topic_summary( - query_request.query, client, model_id - ) - - - persist_user_conversation_details( - user_id=user_id, - conversation_id=conversation_id, - model=model_id, - provider_id=provider_id, - topic_summary=topic_summary, - ) # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() diff --git a/src/cache/cache.py b/src/cache/cache.py index 98b087a41..263cb322d 100644 --- a/src/cache/cache.py +++ b/src/cache/cache.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from utils.suid import check_suid @@ -90,7 +90,7 @@ def delete( """ @abstractmethod - def list(self, user_id: str, skip_user_id_check: bool) -> list[str]: + def list(self, user_id: str, skip_user_id_check: bool) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -98,7 +98,24 @@ def list(self, user_id: str, skip_user_id_check: bool) -> list[str]: skip_user_id_check: Skip user_id suid check. 
Returns: - A list of conversation ids from the cache + A list of ConversationData objects containing conversation_id, topic_summary, and last_message_timestamp + """ + + @abstractmethod + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool, + ) -> None: + """Abstract method to store topic summary in the cache. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. """ @abstractmethod diff --git a/src/cache/in_memory_cache.py b/src/cache/in_memory_cache.py index 7c29bd2a3..1b6b4123f 100644 --- a/src/cache/in_memory_cache.py +++ b/src/cache/in_memory_cache.py @@ -1,7 +1,7 @@ """In-memory cache implementation.""" from cache.cache import Cache -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from models.config import InMemoryCacheConfig from log import get_logger from utils.connection_decorator import connection @@ -85,7 +85,9 @@ def delete( return True @connection - def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: + def list( + self, user_id: str, skip_user_id_check: bool = False + ) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -99,6 +101,25 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: super()._check_user_id(user_id, skip_user_id_check) return [] + @connection + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool = False, + ) -> None: + """Set the topic summary for the given conversation. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. + """ + # just check if user_id and conversation_id are UUIDs + super().construct_key(user_id, conversation_id, skip_user_id_check) + def ready(self) -> bool: """Check if the cache is ready. diff --git a/src/cache/noop_cache.py b/src/cache/noop_cache.py index 5c24271fa..fcd20f368 100644 --- a/src/cache/noop_cache.py +++ b/src/cache/noop_cache.py @@ -1,7 +1,7 @@ """No-operation cache implementation.""" from cache.cache import Cache -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from log import get_logger from utils.connection_decorator import connection @@ -83,7 +83,9 @@ def delete( return True @connection - def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: + def list( + self, user_id: str, skip_user_id_check: bool = False + ) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -97,6 +99,25 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: super()._check_user_id(user_id, skip_user_id_check) return [] + @connection + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool = False, + ) -> None: + """Set the topic summary for the given conversation. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. 
+ """ + # just check if user_id and conversation_id are UUIDs + super().construct_key(user_id, conversation_id, skip_user_id_check) + def ready(self) -> bool: """Check if the cache is ready. diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index ae591b84f..3d9fb17a4 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -4,7 +4,7 @@ from cache.cache import Cache from cache.cache_error import CacheError -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from models.config import PostgreSQLDatabaseConfiguration from log import get_logger from utils.connection_decorator import connection @@ -46,6 +46,16 @@ class PostgresCache(Cache): ); """ + CREATE_CONVERSATIONS_TABLE = """ + CREATE TABLE IF NOT EXISTS conversations ( + user_id text NOT NULL, + conversation_id text NOT NULL, + topic_summary text, + last_message_timestamp timestamp NOT NULL, + PRIMARY KEY(user_id, conversation_id) + ); + """ + CREATE_INDEX = """ CREATE INDEX IF NOT EXISTS timestamps ON cache (created_at) @@ -73,13 +83,31 @@ class PostgresCache(Cache): """ LIST_CONVERSATIONS_STATEMENT = """ - SELECT conversation_id, max(created_at) AS created_at - FROM cache + SELECT conversation_id, topic_summary, EXTRACT(EPOCH FROM last_message_timestamp) as last_message_timestamp + FROM conversations WHERE user_id=%s - GROUP BY conversation_id - ORDER BY created_at DESC + ORDER BY last_message_timestamp DESC """ + INSERT_OR_UPDATE_TOPIC_SUMMARY_STATEMENT = """ + INSERT INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (user_id, conversation_id) + DO UPDATE SET topic_summary = EXCLUDED.topic_summary, last_message_timestamp = EXCLUDED.last_message_timestamp + """ + + DELETE_CONVERSATION_STATEMENT = """ + DELETE FROM conversations + WHERE user_id=%s AND conversation_id=%s + """ + + UPSERT_CONVERSATION_STATEMENT = """ + INSERT INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (user_id, conversation_id) + DO UPDATE SET last_message_timestamp = EXCLUDED.last_message_timestamp + """ + def __init__(self, config: PostgreSQLDatabaseConfiguration) -> None: """Create a new instance of PostgreSQL cache.""" self.postgres_config = config @@ -143,6 +171,9 @@ def initialize_cache(self) -> None: logger.info("Initializing table for cache") cursor.execute(PostgresCache.CREATE_CACHE_TABLE) + logger.info("Initializing table for conversations") + cursor.execute(PostgresCache.CREATE_CONVERSATIONS_TABLE) + logger.info("Initializing index for cache") cursor.execute(PostgresCache.CREATE_INDEX) @@ -220,6 +251,12 @@ def insert_or_append( cache_entry.model, ), ) + + # Update or insert conversation record with last_message_timestamp + cursor.execute( + PostgresCache.UPSERT_CONVERSATION_STATEMENT, + (user_id, conversation_id, None), + ) # commit is implicit at this point except psycopg2.DatabaseError as e: logger.error("PostgresCache.insert_or_append: %s", e) @@ -251,13 +288,22 @@ def delete( (user_id, conversation_id), ) deleted = cursor.rowcount + + # Also delete conversation record for this conversation + cursor.execute( + PostgresCache.DELETE_CONVERSATION_STATEMENT, + (user_id, conversation_id), + ) + return deleted > 0 except psycopg2.DatabaseError as e: logger.error("PostgresCache.delete: %s", e) raise CacheError("PostgresCache.delete", e) from e @connection - def list(self, user_id: str, 
skip_user_id_check: bool = False) -> list[str]: + def list( + self, user_id: str, skip_user_id_check: bool = False + ) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -265,7 +311,7 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: skip_user_id_check: Skip user_id suid check. Returns: - A list of conversation ids from the cache + A list of ConversationData objects containing conversation_id, topic_summary, and last_message_timestamp """ if self.connection is None: @@ -276,7 +322,46 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: cursor.execute(self.LIST_CONVERSATIONS_STATEMENT, (user_id,)) conversations = cursor.fetchall() - return [conversation[0] for conversation in conversations] + result = [] + for conversation in conversations: + conversation_data = ConversationData( + conversation_id=conversation[0], + topic_summary=conversation[1], + last_message_timestamp=float(conversation[2]), + ) + result.append(conversation_data) + + return result + + @connection + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool = False, + ) -> None: + """Set the topic summary for the given conversation. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. + """ + if self.connection is None: + logger.error("Cache is disconnected") + raise CacheError("set_topic_summary: cache is disconnected") + + try: + with self.connection.cursor() as cursor: + cursor.execute( + self.INSERT_OR_UPDATE_TOPIC_SUMMARY_STATEMENT, + (user_id, conversation_id, topic_summary), + ) + except psycopg2.DatabaseError as e: + logger.error("PostgresCache.set_topic_summary: %s", e) + raise CacheError("PostgresCache.set_topic_summary", e) from e def ready(self) -> bool: """Check if the cache is ready. diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index e8ebd7679..330c2d7bd 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -6,7 +6,7 @@ from cache.cache import Cache from cache.cache_error import CacheError -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from models.config import SQLiteDatabaseConfiguration from log import get_logger from utils.connection_decorator import connection @@ -50,6 +50,16 @@ class SQLiteCache(Cache): ); """ + CREATE_CONVERSATIONS_TABLE = """ + CREATE TABLE IF NOT EXISTS conversations ( + user_id text NOT NULL, + conversation_id text NOT NULL, + topic_summary text, + last_message_timestamp int NOT NULL, + PRIMARY KEY(user_id, conversation_id) + ); + """ + CREATE_INDEX = """ CREATE INDEX IF NOT EXISTS timestamps ON cache (created_at) @@ -77,12 +87,27 @@ class SQLiteCache(Cache): """ LIST_CONVERSATIONS_STATEMENT = """ - SELECT DISTINCT conversation_id - FROM cache + SELECT conversation_id, topic_summary, last_message_timestamp + FROM conversations WHERE user_id=? - ORDER BY created_at DESC + ORDER BY last_message_timestamp DESC """ + INSERT_OR_UPDATE_TOPIC_SUMMARY_STATEMENT = """ + INSERT OR REPLACE INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) + VALUES (?, ?, ?, ?) + """ + + DELETE_CONVERSATION_STATEMENT = """ + DELETE FROM conversations + WHERE user_id=? AND conversation_id=? 
+ """ + + UPSERT_CONVERSATION_STATEMENT = """ + INSERT OR REPLACE INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) + VALUES (?, ?, ?, ?) + """ + def __init__(self, config: SQLiteDatabaseConfiguration) -> None: """Create a new instance of SQLite cache.""" self.sqlite_config = config @@ -141,6 +166,9 @@ def initialize_cache(self) -> None: logger.info("Initializing table for cache") cursor.execute(SQLiteCache.CREATE_CACHE_TABLE) + logger.info("Initializing table for conversations") + cursor.execute(SQLiteCache.CREATE_CONVERSATIONS_TABLE) + logger.info("Initializing index for cache") cursor.execute(SQLiteCache.CREATE_INDEX) @@ -206,18 +234,26 @@ def insert_or_append( raise CacheError("insert_or_append: cache is disconnected") cursor = self.connection.cursor() + current_time = time() cursor.execute( self.INSERT_CONVERSATION_HISTORY_STATEMENT, ( user_id, conversation_id, - time(), + current_time, cache_entry.query, cache_entry.response, cache_entry.provider, cache_entry.model, ), ) + + # Update or insert conversation record with last_message_timestamp + cursor.execute( + self.UPSERT_CONVERSATION_STATEMENT, + (user_id, conversation_id, None, current_time), + ) + cursor.close() self.connection.commit() @@ -246,12 +282,21 @@ def delete( (user_id, conversation_id), ) deleted = cursor.rowcount > 0 + + # Also delete conversation record for this conversation + cursor.execute( + self.DELETE_CONVERSATION_STATEMENT, + (user_id, conversation_id), + ) + cursor.close() self.connection.commit() return deleted @connection - def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: + def list( + self, user_id: str, skip_user_id_check: bool = False + ) -> list[ConversationData]: """List all conversations for a given user_id. Args: @@ -259,7 +304,7 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: skip_user_id_check: Skip user_id suid check. Returns: - A list of conversation ids from the cache + A list of ConversationData objects containing conversation_id, topic_summary, and last_message_timestamp """ if self.connection is None: @@ -271,7 +316,44 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]: conversations = cursor.fetchall() cursor.close() - return [conversation[0] for conversation in conversations] + result = [] + for conversation in conversations: + conversation_data = ConversationData( + conversation_id=conversation[0], + topic_summary=conversation[1], + last_message_timestamp=conversation[2], + ) + result.append(conversation_data) + + return result + + @connection + def set_topic_summary( + self, + user_id: str, + conversation_id: str, + topic_summary: str, + skip_user_id_check: bool = False, + ) -> None: + """Set the topic summary for the given conversation. + + Args: + user_id: User identification. + conversation_id: Conversation ID unique for given user. + topic_summary: The topic summary to store. + skip_user_id_check: Skip user_id suid check. + """ + if self.connection is None: + logger.error("Cache is disconnected") + raise CacheError("set_topic_summary: cache is disconnected") + + cursor = self.connection.cursor() + cursor.execute( + self.INSERT_OR_UPDATE_TOPIC_SUMMARY_STATEMENT, + (user_id, conversation_id, topic_summary, time()), + ) + cursor.close() + self.connection.commit() def ready(self) -> bool: """Check if the cache is ready. 
diff --git a/src/models/cache_entry.py b/src/models/cache_entry.py index 810bad711..f87445bef 100644 --- a/src/models/cache_entry.py +++ b/src/models/cache_entry.py @@ -17,3 +17,17 @@ class CacheEntry(BaseModel): response: str provider: str model: str + + +class ConversationData(BaseModel): + """Model representing conversation data returned by cache list operations. + + Attributes: + conversation_id: The conversation ID + topic_summary: The topic summary for the conversation (can be None) + last_message_timestamp: The timestamp of the last message in the conversation + """ + + conversation_id: str + topic_summary: str | None + last_message_timestamp: float diff --git a/src/models/responses.py b/src/models/responses.py index 9c865b281..f44b79ea7 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -4,6 +4,8 @@ from pydantic import AnyUrl, BaseModel, Field +from models.cache_entry import ConversationData + class ModelsResponse(BaseModel): """Model representing a response to models request.""" @@ -683,10 +685,10 @@ class ConversationsListResponseV2(BaseModel): """Model representing a response for listing conversations of a user. Attributes: - conversations: List of conversation IDs associated with the user. + conversations: List of conversation data associated with the user. """ - conversations: list[str] + conversations: list[ConversationData] class ErrorResponse(BaseModel): diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 1f524962f..fce20d59e 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -189,7 +189,8 @@ def store_conversation_into_cache( model_id: str, query: str, response: str, - _skip_userid_check: bool + _skip_userid_check: bool, + topic_summary: str | None, ) -> None: """Store one part of conversation into conversation history cache.""" if config.conversation_cache_configuration.type is not None: @@ -203,7 +204,13 @@ def store_conversation_into_cache( provider=provider_id, model=model_id, ) - cache.insert_or_append(user_id, conversation_id, cache_entry, _skip_userid_check) + cache.insert_or_append( + user_id, conversation_id, cache_entry, _skip_userid_check + ) + if topic_summary: + cache.set_topic_summary( + user_id, conversation_id, topic_summary, _skip_userid_check + ) # # pylint: disable=R0913,R0917 diff --git a/tests/unit/cache/test_postgres_cache.py b/tests/unit/cache/test_postgres_cache.py index 61998de70..51e06448c 100644 --- a/tests/unit/cache/test_postgres_cache.py +++ b/tests/unit/cache/test_postgres_cache.py @@ -7,7 +7,7 @@ from cache.cache_error import CacheError from cache.postgres_cache import PostgresCache from models.config import PostgreSQLDatabaseConfiguration -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from utils import suid @@ -301,3 +301,67 @@ def test_list_operation_when_connected(postgres_cache_config_fixture, mocker): # should not fail lst = cache.list(USER_ID_1, False) assert not lst + assert isinstance(lst, list) + + +def test_topic_summary_operations(postgres_cache_config_fixture, mocker): + """Test topic summary set operations and retrieval via list.""" + # prevent real connection to PG instance + mock_connect = mocker.patch("psycopg2.connect") + cache = PostgresCache(postgres_cache_config_fixture) + + mock_connection = mock_connect.return_value + mock_cursor = mock_connection.cursor.return_value.__enter__.return_value + + # Mock fetchall to return conversation data + mock_cursor.fetchall.return_value = [ + ( + CONVERSATION_ID_1, + "This 
conversation is about machine learning and AI", + 1234567890.0, + ) + ] + + # Set a topic summary + test_summary = "This conversation is about machine learning and AI" + cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, test_summary, False) + + # Retrieve the topic summary via list + conversations = cache.list(USER_ID_1, False) + assert len(conversations) == 1 + assert conversations[0].topic_summary == test_summary + assert isinstance(conversations[0], ConversationData) + + +def test_topic_summary_after_conversation_delete(postgres_cache_config_fixture, mocker): + """Test that topic summary is deleted when conversation is deleted.""" + # prevent real connection to PG instance + mock_connect = mocker.patch("psycopg2.connect") + cache = PostgresCache(postgres_cache_config_fixture) + + mock_connection = mock_connect.return_value + mock_cursor = mock_connection.cursor.return_value.__enter__.return_value + + # Mock the delete operation to return 1 (deleted) + mock_cursor.rowcount = 1 + + # Add some cache entries and a topic summary + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, cache_entry_1, False) + cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test summary", False) + + # Delete the conversation + deleted = cache.delete(USER_ID_1, CONVERSATION_ID_1, False) + assert deleted is True + + +def test_topic_summary_when_disconnected(postgres_cache_config_fixture, mocker): + """Test topic summary operations when cache is disconnected.""" + # prevent real connection to PG instance + mocker.patch("psycopg2.connect") + cache = PostgresCache(postgres_cache_config_fixture) + + cache.connection = None + cache.connect = lambda: None + + with pytest.raises(CacheError, match="cache is disconnected"): + cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test", False) diff --git a/tests/unit/cache/test_sqlite_cache.py b/tests/unit/cache/test_sqlite_cache.py index 381b009ab..32cc9a46a 100644 --- a/tests/unit/cache/test_sqlite_cache.py +++ b/tests/unit/cache/test_sqlite_cache.py @@ -7,7 +7,7 @@ import pytest from models.config import SQLiteDatabaseConfiguration -from models.cache_entry import CacheEntry +from models.cache_entry import CacheEntry, ConversationData from utils import suid from cache.cache_error import CacheError @@ -188,6 +188,7 @@ def test_list_operation_when_connected(tmpdir): # should not fail lst = cache.list(USER_ID_1, False) assert not lst + assert isinstance(lst, list) def test_ready_method(tmpdir): @@ -255,3 +256,94 @@ def test_multiple_ids(tmpdir): lst = cache.get(USER_ID_2, CONVERSATION_ID_2, False) assert lst[0] == cache_entry_1 assert lst[1] == cache_entry_2 + + +def test_list_with_conversations(tmpdir): + """Test the list() method with actual conversations.""" + cache = create_cache(tmpdir) + + # Add some conversations + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, cache_entry_1, False) + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_2, cache_entry_2, False) + + # Set topic summaries + cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "First conversation", False) + cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_2, "Second conversation", False) + + # Test list functionality + conversations = cache.list(USER_ID_1, False) + assert len(conversations) == 2 + assert all(isinstance(conv, ConversationData) for conv in conversations) + + # Check that conversations are ordered by last_message_timestamp DESC + assert ( + conversations[0].last_message_timestamp + >= conversations[1].last_message_timestamp + ) + + # Check conversation IDs + conv_ids = 
[conv.conversation_id for conv in conversations] + assert CONVERSATION_ID_1 in conv_ids + assert CONVERSATION_ID_2 in conv_ids + + +def test_topic_summary_operations(tmpdir): + """Test topic summary set operations and retrieval via list.""" + cache = create_cache(tmpdir) + + # Add a conversation + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, cache_entry_1, False) + + # Set a topic summary + test_summary = "This conversation is about machine learning and AI" + cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, test_summary, False) + + # Retrieve the topic summary via list + conversations = cache.list(USER_ID_1, False) + assert len(conversations) == 1 + assert conversations[0].topic_summary == test_summary + + # Update the topic summary + updated_summary = "This conversation is about deep learning and neural networks" + cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, updated_summary, False) + + # Verify the update via list + conversations = cache.list(USER_ID_1, False) + assert len(conversations) == 1 + assert conversations[0].topic_summary == updated_summary + + +def test_topic_summary_after_conversation_delete(tmpdir): + """Test that topic summary is deleted when conversation is deleted.""" + cache = create_cache(tmpdir) + + # Add some cache entries and a topic summary + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, cache_entry_1, False) + cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test summary", False) + + # Verify both exist + entries = cache.get(USER_ID_1, CONVERSATION_ID_1, False) + assert len(entries) == 1 + conversations = cache.list(USER_ID_1, False) + assert len(conversations) == 1 + assert conversations[0].topic_summary == "Test summary" + + # Delete the conversation + deleted = cache.delete(USER_ID_1, CONVERSATION_ID_1, False) + assert deleted is True + + # Verify both are deleted + entries = cache.get(USER_ID_1, CONVERSATION_ID_1, False) + assert len(entries) == 0 + conversations = cache.list(USER_ID_1, False) + assert len(conversations) == 0 + + +def test_topic_summary_when_disconnected(tmpdir): + """Test topic summary operations when cache is disconnected.""" + cache = create_cache(tmpdir) + cache.connection = None + cache.connect = lambda: None + + with pytest.raises(CacheError, match="cache is disconnected"): + cache.set_topic_summary(USER_ID_1, CONVERSATION_ID_1, "Test", False) From c91471aa473fdee047918e1365e45c87017fbeca Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 29 Sep 2025 15:39:00 -0400 Subject: [PATCH 7/9] fix pylint Signed-off-by: Stephanie --- src/app/endpoints/streaming_query.py | 5 ++--- src/cache/cache.py | 3 ++- src/cache/postgres_cache.py | 3 ++- src/cache/sqlite_cache.py | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 0bd868a5d..bf4d8635c 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -20,9 +20,7 @@ from llama_stack_client.types.shared import ToolCall from llama_stack_client.types.shared.interleaved_content_item import TextContentItem - from app.database import get_session -import metrics from app.endpoints.query import ( get_rag_toolgroups, is_input_shield, @@ -41,6 +39,7 @@ from client import AsyncLlamaStackClientHolder from configuration import configuration from constants import DEFAULT_RAG_TOOL +import metrics from metrics.utils import update_llm_token_count_from_turn from models.config import Action from models.database.conversations import UserConversation @@ 
-570,7 +569,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: @router.post("/streaming_query", responses=streaming_query_responses) @authorize(Action.STREAMING_QUERY) -async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals +async def streaming_query_endpoint_handler( # pylint: disable=R0915,R0914 request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], diff --git a/src/cache/cache.py b/src/cache/cache.py index 263cb322d..4cdab6307 100644 --- a/src/cache/cache.py +++ b/src/cache/cache.py @@ -98,7 +98,8 @@ def list(self, user_id: str, skip_user_id_check: bool) -> list[ConversationData] skip_user_id_check: Skip user_id suid check. Returns: - A list of ConversationData objects containing conversation_id, topic_summary, and last_message_timestamp + A list of ConversationData objects containing conversation_id, topic_summary, and + last_message_timestamp """ @abstractmethod diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index 3d9fb17a4..a8e5d0338 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -311,7 +311,8 @@ def list( skip_user_id_check: Skip user_id suid check. Returns: - A list of ConversationData objects containing conversation_id, topic_summary, and last_message_timestamp + A list of ConversationData objects containing conversation_id, topic_summary, and + last_message_timestamp """ if self.connection is None: diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index 330c2d7bd..88934f45c 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -304,7 +304,8 @@ def list( skip_user_id_check: Skip user_id suid check. Returns: - A list of ConversationData objects containing conversation_id, topic_summary, and last_message_timestamp + A list of ConversationData objects containing conversation_id, + topic_summary, and last_message_timestamp """ if self.connection is None: From a44bf7bb30166e38d0d4451f2b455bd1b3c8a900 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 29 Sep 2025 18:06:39 -0400 Subject: [PATCH 8/9] fix follow-up queries issue on topic-summary Signed-off-by: Stephanie --- src/cache/sqlite_cache.py | 4 +++- src/utils/endpoints.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index 88934f45c..a39f8ade2 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -104,8 +104,10 @@ class SQLiteCache(Cache): """ UPSERT_CONVERSATION_STATEMENT = """ - INSERT OR REPLACE INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) + INSERT INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp) VALUES (?, ?, ?, ?) 
+ ON CONFLICT (user_id, conversation_id) + DO UPDATE SET last_message_timestamp = excluded.last_message_timestamp """ def __init__(self, config: SQLiteDatabaseConfiguration) -> None: diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index fce20d59e..0b0c15102 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -207,7 +207,7 @@ def store_conversation_into_cache( cache.insert_or_append( user_id, conversation_id, cache_entry, _skip_userid_check ) - if topic_summary: + if topic_summary and len(topic_summary) > 0: cache.set_topic_summary( user_id, conversation_id, topic_summary, _skip_userid_check ) From 17c651889f78b3f662f8b92d402b57a6b299ad15 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Tue, 30 Sep 2025 11:40:58 -0400 Subject: [PATCH 9/9] rebase Signed-off-by: Stephanie --- docs/openapi.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/openapi.json b/docs/openapi.json index 8ac9be564..c9cf4fbc7 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -28,7 +28,7 @@ "root" ], "summary": "Root Endpoint Handler", - "description": "Handle GET requests to the root (\"/\") endpoint and returns the static HTML index page.\n\nReturns:\n HTMLResponse: The HTML content of the index page, including a heading,\n embedded image with the service icon, and links to the API documentation\n via Swagger UI and ReDoc.", + "description": "Handle request to the / endpoint.", "operationId": "root_endpoint_handler__get", "responses": { "200": {
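
The follow-up-query fix in PATCH 8 hinges on SQLite upsert semantics: INSERT OR REPLACE resolves a primary-key conflict by deleting the existing row and inserting a fresh one, so every follow-up message re-wrote the conversations row with topic_summary=NULL, while the ON CONFLICT ... DO UPDATE form touches only the timestamp. The regression can be reproduced outside the service with the stand-alone sketch below (standard library only; the table shape and both statements are copied from the sqlite_cache.py diffs above, and an SQLite build with upsert support, 3.24+, is assumed):

import sqlite3
from time import time

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE conversations (
        user_id text NOT NULL,
        conversation_id text NOT NULL,
        topic_summary text,
        last_message_timestamp int NOT NULL,
        PRIMARY KEY(user_id, conversation_id)
    )
    """
)

UPSERT = """
    INSERT INTO conversations(user_id, conversation_id, topic_summary, last_message_timestamp)
    VALUES (?, ?, ?, ?)
    ON CONFLICT (user_id, conversation_id)
    DO UPDATE SET last_message_timestamp = excluded.last_message_timestamp
"""

# First message of a new conversation: row created, no summary yet.
conn.execute(UPSERT, ("user-1", "conv-1", None, time()))

# Topic summary generated once for the new conversation
# (mirrors INSERT_OR_UPDATE_TOPIC_SUMMARY_STATEMENT).
conn.execute(
    "INSERT OR REPLACE INTO conversations VALUES (?, ?, ?, ?)",
    ("user-1", "conv-1", "Deploying on OpenShift", time()),
)

# Follow-up message: with the pre-PATCH-8 INSERT OR REPLACE upsert this
# re-inserted the row with topic_summary=NULL; with ON CONFLICT ... DO UPDATE
# only the timestamp moves and the summary survives.
conn.execute(UPSERT, ("user-1", "conv-1", None, time()))

summary = conn.execute(
    "SELECT topic_summary FROM conversations WHERE conversation_id = ?",
    ("conv-1",),
).fetchone()[0]
assert summary == "Deploying on OpenShift"

The matching guard in store_conversation_into_cache() only calls set_topic_summary() for a non-empty summary, so follow-up turns, which generate no new summary, never reach the summary write path at all.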