diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index 0e761c60e..6032d01d6 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -18,8 +18,6 @@ router = APIRouter(tags=["conversations"]) auth_dependency = get_auth_dependency() -conversation_id_to_agent_id: dict[str, str] = {} - conversation_responses: dict[int | str, dict[str, Any]] = { 200: { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", @@ -69,21 +67,20 @@ } -def simplify_session_data(session_data: Any) -> list[dict[str, Any]]: +def simplify_session_data(session_data: dict) -> list[dict[str, Any]]: """Simplify session data to include only essential conversation information. Args: - session_data: The full session data from llama-stack + session_data: The full session data dict from llama-stack Returns: Simplified session data with only input_messages and output_message per turn """ - session_dict = session_data.model_dump() # Create simplified structure chat_history = [] # Extract only essential data from each turn - for turn in session_dict.get("turns", []): + for turn in session_data.get("turns", []): # Clean up input messages cleaned_messages = [] for msg in turn.get("input_messages", []): @@ -131,25 +128,13 @@ def get_conversation_endpoint_handler( }, ) - agent_id = conversation_id_to_agent_id.get(conversation_id) - if not agent_id: - logger.error("Agent ID not found for conversation %s", conversation_id) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={ - "response": "conversation ID not found", - "cause": f"conversation ID {conversation_id} not found!", - }, - ) - + agent_id = conversation_id logger.info("Retrieving conversation %s", conversation_id) try: client = LlamaStackClientHolder().get_client() - session_data = client.agents.session.retrieve( - agent_id=agent_id, session_id=conversation_id - ) + session_data = client.agents.session.list(agent_id=agent_id).data[0] logger.info("Successfully retrieved conversation %s", conversation_id) @@ -211,16 +196,7 @@ def delete_conversation_endpoint_handler( "cause": f"Conversation ID {conversation_id} is not a valid UUID", }, ) - agent_id = conversation_id_to_agent_id.get(conversation_id) - if not agent_id: - logger.error("Agent ID not found for conversation %s", conversation_id) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={ - "response": "conversation ID not found", - "cause": f"conversation ID {conversation_id} not found!", - }, - ) + agent_id = conversation_id logger.info("Deleting conversation %s", conversation_id) try: diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 8701aa6bd..774907d58 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -1,5 +1,6 @@ """Handler for REST API call to provide answer to query.""" +from contextlib import suppress from datetime import datetime, UTC import json import logging @@ -7,8 +8,6 @@ from pathlib import Path from typing import Any -from cachetools import TTLCache # type: ignore - from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import APIConnectionError from llama_stack_client import LlamaStackClient # type: ignore @@ -23,7 +22,6 @@ from client import LlamaStackClientHolder from configuration import configuration -from app.endpoints.conversations import conversation_id_to_agent_id import metrics from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from models.requests import QueryRequest, Attachment @@ -39,9 +37,6 @@ router = APIRouter(tags=["query"]) auth_dependency = get_auth_dependency() -# Global agent registry to persist agents across requests -_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600) - query_response: dict[int | str, dict[str, Any]] = { 200: { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", @@ -81,16 +76,14 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen available_output_shields: list[str], conversation_id: str | None, no_tools: bool = False, -) -> tuple[Agent, str]: +) -> tuple[Agent, str, str]: """Get existing agent or create a new one with session persistence.""" - if conversation_id is not None: - agent = _agent_cache.get(conversation_id) - if agent: - logger.debug( - "Reusing existing agent with conversation_id: %s", conversation_id - ) - return agent, conversation_id - logger.debug("No existing agent found for conversation_id: %s", conversation_id) + existing_agent_id = None + if conversation_id: + with suppress(ValueError): + existing_agent_id = client.agents.retrieve( + agent_id=conversation_id + ).agent_id logger.debug("Creating new agent") # TODO(lucasagomes): move to ReActAgent @@ -103,12 +96,18 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), enable_session_persistence=True, ) - conversation_id = agent.create_session(get_suid()) - logger.debug("Created new agent and conversation_id: %s", conversation_id) - _agent_cache[conversation_id] = agent - conversation_id_to_agent_id[conversation_id] = agent.agent_id + if existing_agent_id and conversation_id: + orphan_agent_id = agent.agent_id + agent.agent_id = conversation_id + client.agents.delete(agent_id=orphan_agent_id) + sessions_response = client.agents.session.list(agent_id=conversation_id) + logger.info("session response: %s", sessions_response) + session_id = str(sessions_response.data[0]["session_id"]) + else: + conversation_id = agent.agent_id + session_id = agent.create_session(get_suid()) - return agent, conversation_id + return agent, conversation_id, session_id @router.post("/query", responses=query_response) @@ -282,7 +281,7 @@ def retrieve_response( # pylint: disable=too-many-locals if query_request.attachments: validate_attachments_metadata(query_request.attachments) - agent, conversation_id = get_agent( + agent, conversation_id, session_id = get_agent( client, model_id, system_prompt, @@ -326,7 +325,7 @@ def retrieve_response( # pylint: disable=too-many-locals response = agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], - session_id=conversation_id, + session_id=session_id, documents=query_request.get_documents(), stream=False, toolgroups=toolgroups, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index d27388e22..e4663327c 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1,13 +1,12 @@ """Handler for REST API call to provide answer to streaming query.""" import ast +from contextlib import suppress import json import re import logging from typing import Any, AsyncIterator, Iterator -from cachetools import TTLCache # type: ignore - from llama_stack_client import APIConnectionError from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore from llama_stack_client import AsyncLlamaStackClient # type: ignore @@ -31,7 +30,6 @@ from utils.suid import get_suid from utils.types import GraniteToolParser -from app.endpoints.conversations import conversation_id_to_agent_id from app.endpoints.query import ( get_rag_toolgroups, is_input_shield, @@ -46,9 +44,6 @@ router = APIRouter(tags=["streaming_query"]) auth_dependency = get_auth_dependency() -# Global agent registry to persist agents across requests -_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600) - # # pylint: disable=R0913,R0917 async def get_agent( @@ -59,16 +54,13 @@ async def get_agent( available_output_shields: list[str], conversation_id: str | None, no_tools: bool = False, -) -> tuple[AsyncAgent, str]: +) -> tuple[AsyncAgent, str, str]: """Get existing agent or create a new one with session persistence.""" - if conversation_id is not None: - agent = _agent_cache.get(conversation_id) - if agent: - logger.debug( - "Reusing existing agent with conversation_id: %s", conversation_id - ) - return agent, conversation_id - logger.debug("No existing agent found for conversation_id: %s", conversation_id) + existing_agent_id = None + if conversation_id: + with suppress(ValueError): + agent_response = await client.agents.retrieve(agent_id=conversation_id) + existing_agent_id = agent_response.agent_id logger.debug("Creating new agent") agent = AsyncAgent( @@ -80,11 +72,19 @@ async def get_agent( tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id), enable_session_persistence=True, ) - conversation_id = await agent.create_session(get_suid()) - logger.debug("Created new agent and conversation_id: %s", conversation_id) - _agent_cache[conversation_id] = agent - conversation_id_to_agent_id[conversation_id] = agent.agent_id - return agent, conversation_id + + if existing_agent_id and conversation_id: + orphan_agent_id = agent.agent_id + agent._agent_id = conversation_id # type: ignore[assignment] # pylint: disable=protected-access + await client.agents.delete(agent_id=orphan_agent_id) + sessions_response = await client.agents.session.list(agent_id=conversation_id) + logger.info("session response: %s", sessions_response) + session_id = str(sessions_response.data[0]["session_id"]) + else: + conversation_id = agent.agent_id + session_id = await agent.create_session(get_suid()) + + return agent, conversation_id, session_id METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n") @@ -526,7 +526,7 @@ async def retrieve_response( if query_request.attachments: validate_attachments_metadata(query_request.attachments) - agent, conversation_id = await get_agent( + agent, conversation_id, session_id = await get_agent( client, model_id, system_prompt, @@ -576,7 +576,7 @@ async def retrieve_response( logger.debug("Session ID: %s", conversation_id) response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], - session_id=conversation_id, + session_id=session_id, documents=query_request.get_documents(), stream=True, toolgroups=toolgroups, diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 0f5456b70..d1df17a09 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -7,7 +7,6 @@ from app.endpoints.conversations import ( get_conversation_endpoint_handler, delete_conversation_endpoint_handler, - conversation_id_to_agent_id, simplify_session_data, ) from models.responses import ConversationResponse, ConversationDeleteResponse @@ -15,7 +14,6 @@ MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") VALID_CONVERSATION_ID = "123e4567-e89b-12d3-a456-426614174000" -VALID_AGENT_ID = "agent_123" INVALID_CONVERSATION_ID = "invalid-id" @@ -48,16 +46,6 @@ def setup_configuration_fixture(): return cfg -@pytest.fixture(autouse=True) -def setup_conversation_mapping(): - """Set up and clean up the conversation ID to agent ID mapping.""" - # Clear the mapping before each test - conversation_id_to_agent_id.clear() - yield - # Clean up after each test - conversation_id_to_agent_id.clear() - - @pytest.fixture(name="mock_session_data") def mock_session_data_fixture(): """Create mock session data for testing.""" @@ -127,19 +115,14 @@ class TestSimplifySessionData: """Test cases for the simplify_session_data function.""" def test_simplify_session_data_with_model_dump( - self, mock_session_data, expected_chat_history, mocker + self, mock_session_data, expected_chat_history ): - """Test simplify_session_data with session data that has model_dump method.""" - # Create a mock object with model_dump method - mock_session_obj = mocker.Mock() - mock_session_obj.model_dump.return_value = mock_session_data - - result = simplify_session_data(mock_session_obj) + """Test simplify_session_data with session data.""" + result = simplify_session_data(mock_session_data) assert result == expected_chat_history - mock_session_obj.model_dump.assert_called_once() - def test_simplify_session_data_empty_turns(self, mocker): + def test_simplify_session_data_empty_turns(self): """Test simplify_session_data with empty turns.""" session_data = { "session_id": VALID_CONVERSATION_ID, @@ -147,14 +130,11 @@ def test_simplify_session_data_empty_turns(self, mocker): "turns": [], } - mock_session_obj = mocker.Mock() - mock_session_obj.model_dump.return_value = session_data - - result = simplify_session_data(mock_session_obj) + result = simplify_session_data(session_data) assert not result - def test_simplify_session_data_filters_unwanted_fields(self, mocker): + def test_simplify_session_data_filters_unwanted_fields(self): """Test that simplify_session_data properly filters out unwanted fields.""" session_data = { "session_id": VALID_CONVERSATION_ID, @@ -182,10 +162,7 @@ def test_simplify_session_data_filters_unwanted_fields(self, mocker): ], } - mock_session_obj = mocker.Mock() - mock_session_obj.model_dump.return_value = session_data - - result = simplify_session_data(mock_session_obj) + result = simplify_session_data(session_data) expected = [ { @@ -226,31 +203,14 @@ def test_invalid_conversation_id_format(self, mocker, setup_configuration): assert "Invalid conversation ID format" in exc_info.value.detail["response"] assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_conversation_not_found_in_mapping(self, mocker, setup_configuration): - """Test the endpoint when conversation ID is not in the mapping.""" - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - - with pytest.raises(HTTPException) as exc_info: - get_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) - - assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND - assert "conversation ID not found" in exc_info.value.detail["response"] - assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_llama_stack_connection_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack connection fails.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise APIConnectionError mock_client = mocker.Mock() - mock_client.agents.session.retrieve.side_effect = APIConnectionError( - request=None - ) + mock_client.agents.session.list.side_effect = APIConnectionError(request=None) mock_client_holder = mocker.patch( "app.endpoints.conversations.LlamaStackClientHolder" ) @@ -268,12 +228,9 @@ def test_llama_stack_not_found_error(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise NotFoundError mock_client = mocker.Mock() - mock_client.agents.session.retrieve.side_effect = NotFoundError( + mock_client.agents.session.list.side_effect = NotFoundError( message="Session not found", response=mocker.Mock(request=None), body=None ) mock_client_holder = mocker.patch( @@ -294,9 +251,6 @@ def test_session_retrieve_exception(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise a general exception mock_client = mocker.Mock() mock_client.agents.session.retrieve.side_effect = Exception( @@ -323,16 +277,15 @@ def test_successful_conversation_retrieval( mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock session data with model_dump method mock_session_obj = mocker.Mock() mock_session_obj.model_dump.return_value = mock_session_data # Mock LlamaStackClientHolder mock_client = mocker.Mock() - mock_client.agents.session.retrieve.return_value = mock_session_obj + mock_client.agents.session.list.return_value = mocker.Mock( + data=[mock_session_data] + ) mock_client_holder = mocker.patch( "app.endpoints.conversations.LlamaStackClientHolder" ) @@ -345,8 +298,8 @@ def test_successful_conversation_retrieval( assert isinstance(response, ConversationResponse) assert response.conversation_id == VALID_CONVERSATION_ID assert response.chat_history == expected_chat_history - mock_client.agents.session.retrieve.assert_called_once_with( - agent_id=VALID_AGENT_ID, session_id=VALID_CONVERSATION_ID + mock_client.agents.session.list.assert_called_once_with( + agent_id=VALID_CONVERSATION_ID ) @@ -377,26 +330,11 @@ def test_invalid_conversation_id_format(self, mocker, setup_configuration): assert "Invalid conversation ID format" in exc_info.value.detail["response"] assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_conversation_not_found_in_mapping(self, mocker, setup_configuration): - """Test the endpoint when conversation ID is not in the mapping.""" - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - - with pytest.raises(HTTPException) as exc_info: - delete_conversation_endpoint_handler(VALID_CONVERSATION_ID, _auth=MOCK_AUTH) - - assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND - assert "conversation ID not found" in exc_info.value.detail["response"] - assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] - def test_llama_stack_connection_error(self, mocker, setup_configuration): """Test the endpoint when LlamaStack connection fails.""" mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise APIConnectionError mock_client = mocker.Mock() mock_client.agents.session.delete.side_effect = APIConnectionError(request=None) @@ -416,9 +354,6 @@ def test_llama_stack_not_found_error(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise NotFoundError mock_client = mocker.Mock() mock_client.agents.session.delete.side_effect = NotFoundError( @@ -442,9 +377,6 @@ def test_session_deletion_exception(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder to raise a general exception mock_client = mocker.Mock() mock_client.agents.session.delete.side_effect = Exception( @@ -470,9 +402,6 @@ def test_successful_conversation_deletion(self, mocker, setup_configuration): mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - # Set up conversation mapping - conversation_id_to_agent_id[VALID_CONVERSATION_ID] = VALID_AGENT_ID - # Mock LlamaStackClientHolder mock_client = mocker.Mock() mock_client.agents.session.delete.return_value = None # Successful deletion @@ -490,5 +419,5 @@ def test_successful_conversation_deletion(self, mocker, setup_configuration): assert response.success is True assert response.response == "Conversation deleted successfully" mock_client.agents.session.delete.assert_called_once_with( - agent_id=VALID_AGENT_ID, session_id=VALID_CONVERSATION_ID + agent_id=VALID_CONVERSATION_ID, session_id=VALID_CONVERSATION_ID ) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 03e32563d..8443b12b6 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -20,7 +20,6 @@ store_transcript, get_rag_toolgroups, get_agent, - _agent_cache, ) from models.requests import QueryRequest, Attachment @@ -65,8 +64,6 @@ def prepare_agent_mocks_fixture(mocker): mock_agent = mocker.Mock() mock_agent.create_turn.return_value.steps = [] yield mock_client, mock_agent - # cleanup agent cache after tests - _agent_cache.clear() def test_query_endpoint_handler_configuration_not_loaded(mocker): @@ -380,7 +377,8 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -394,7 +392,7 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): # Assert that the metric for validation errors is NOT incremented mock_metric.inc.assert_not_called() assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -416,7 +414,8 @@ def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -428,7 +427,7 @@ def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -463,7 +462,8 @@ def __repr__(self): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -475,7 +475,7 @@ def __repr__(self): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -513,7 +513,8 @@ def __repr__(self): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -525,7 +526,7 @@ def __repr__(self): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -565,7 +566,8 @@ def __repr__(self): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -577,7 +579,7 @@ def __repr__(self): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -619,7 +621,8 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): ), ] mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) @@ -631,7 +634,7 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -671,7 +674,8 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): ), ] mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) @@ -683,7 +687,7 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -724,7 +728,8 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -736,7 +741,7 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -789,7 +794,8 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -801,7 +807,7 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -848,7 +854,8 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers( mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -874,7 +881,7 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers( ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -939,7 +946,8 @@ def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -951,7 +959,7 @@ def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): # Assert that the metric for validation errors is incremented mock_metric.inc.assert_called_once() - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -1073,15 +1081,20 @@ def test_query_endpoint_handler_on_connection_error(mocker): mock_metric.inc.assert_called_once() -def test_get_agent_cache_hit(prepare_agent_mocks): - """Test get_agent function when agent exists in cache.""" +def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): + """Test get_agent function when agent exists in llama stack.""" mock_client, mock_agent = prepare_agent_mocks + mock_client.agents.session.list.return_value = mocker.Mock( + data=[{"session_id": "test_session_id"}] + ) # Set up cache with existing agent conversation_id = "test_conversation_id" - _agent_cache[conversation_id] = mock_agent - result_agent, result_conversation_id = get_agent( + # Mock Agent class + mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1090,16 +1103,21 @@ def test_get_agent_cache_hit(prepare_agent_mocks): conversation_id=conversation_id, ) - # Assert cached agent is returned + # Assert the same agent is returned assert result_agent == mock_agent - assert result_conversation_id == conversation_id + assert result_conversation_id == result_agent.agent_id + assert conversation_id == result_agent.agent_id + assert result_session_id == "test_session_id" -def test_get_agent_cache_miss_with_conversation_id( +def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( setup_configuration, prepare_agent_mocks, mocker ): - """Test get_agent function when conversation_id is provided but agent not in cache.""" + """Test get_agent function when conversation_id is provided.""" mock_client, mock_agent = prepare_agent_mocks + mock_client.agents.retrieve.side_effect = ValueError( + "fake not finding existing agent" + ) mock_agent.create_session.return_value = "new_session_id" # Mock Agent class @@ -1120,20 +1138,22 @@ def test_get_agent_cache_miss_with_conversation_id( return_value=[mock_mcp_server], ) mocker.patch("app.endpoints.query.configuration", setup_configuration) - - # Call function with conversation_id but no cached agent - result_agent, result_conversation_id = get_agent( + conversation_id = "non_existent_conversation_id" + # Call function with conversation_id + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", available_input_shields=["shield1"], available_output_shields=["output_shield2"], - conversation_id="non_existent_conversation_id", + conversation_id=conversation_id, ) # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert conversation_id != result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters mock_agent_class.assert_called_once_with( @@ -1146,9 +1166,6 @@ def test_get_agent_cache_miss_with_conversation_id( enable_session_persistence=True, ) - # Verify agent was stored in cache - assert _agent_cache["new_session_id"] == mock_agent - def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, mocker): """Test get_agent function when conversation_id is None.""" @@ -1175,7 +1192,7 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function with None conversation_id - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1186,7 +1203,8 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters mock_agent_class.assert_called_once_with( @@ -1199,9 +1217,6 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, enable_session_persistence=True, ) - # Verify agent was stored in cache - assert _agent_cache["new_session_id"] == mock_agent - def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocker): """Test get_agent function with empty shields list.""" @@ -1228,7 +1243,7 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function with empty shields list - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1239,7 +1254,8 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with empty shields mock_agent_class.assert_called_once_with( @@ -1282,7 +1298,7 @@ def test_get_agent_multiple_mcp_servers( mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1293,7 +1309,8 @@ def test_get_agent_multiple_mcp_servers( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with tools from both MCP servers mock_agent_class.assert_called_once_with( @@ -1479,7 +1496,8 @@ def test_retrieve_response_no_tools_bypasses_mcp_and_rag(prepare_agent_mocks, mo mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", no_tools=True) @@ -1491,7 +1509,7 @@ def test_retrieve_response_no_tools_bypasses_mcp_and_rag(prepare_agent_mocks, mo ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers is empty (no MCP headers) assert mock_agent.extra_headers == {} @@ -1527,7 +1545,8 @@ def test_retrieve_response_no_tools_false_preserves_functionality( mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) mocker.patch( - "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + "app.endpoints.query.get_agent", + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", no_tools=False) @@ -1539,7 +1558,7 @@ def test_retrieve_response_no_tools_false_preserves_functionality( ) assert response == "LLM answer" - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers contains MCP headers expected_extra_headers = { @@ -1589,7 +1608,7 @@ def test_get_agent_no_tools_no_parser(setup_configuration, prepare_agent_mocks, mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function with no_tools=True - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1601,7 +1620,8 @@ def test_get_agent_no_tools_no_parser(setup_configuration, prepare_agent_mocks, # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with tool_parser=None mock_agent_class.assert_called_once_with( @@ -1647,7 +1667,7 @@ def test_get_agent_no_tools_false_preserves_parser( mocker.patch("app.endpoints.query.configuration", setup_configuration) # Call function with no_tools=False - result_agent, result_conversation_id = get_agent( + result_agent, result_conversation_id, result_session_id = get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1659,7 +1679,8 @@ def test_get_agent_no_tools_false_preserves_parser( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with the proper tool_parser mock_agent_class.assert_called_once_with( diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 8ff286adc..0251e5146 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -41,7 +41,6 @@ retrieve_response, stream_build_event, get_agent, - _agent_cache, ) from models.requests import QueryRequest, Attachment @@ -113,8 +112,6 @@ def prepare_agent_mocks_fixture(mocker): mock_client = mocker.AsyncMock() mock_agent = mocker.AsyncMock() yield mock_client, mock_agent - # cleanup agent cache after tests - _agent_cache.clear() @pytest.mark.asyncio @@ -317,7 +314,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -334,7 +331,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=get_rag_toolgroups(["VectorDB-1"]), @@ -354,7 +351,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -371,7 +368,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -404,7 +401,7 @@ def __repr__(self): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -419,7 +416,7 @@ def __repr__(self): assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -455,7 +452,7 @@ def __repr__(self): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -470,7 +467,7 @@ def __repr__(self): assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -508,7 +505,7 @@ def __repr__(self): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -535,7 +532,7 @@ def __repr__(self): mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -563,7 +560,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker ] mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) @@ -578,7 +575,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", stream=True, # Should be True for streaming endpoint documents=[ { @@ -616,7 +613,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke ] mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) @@ -631,7 +628,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", stream=True, # Should be True for streaming endpoint documents=[ { @@ -1017,7 +1014,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -1060,7 +1057,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): # Check that create_turn was called with the correct parameters mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, toolgroups=[mcp_server.name for mcp_server in mcp_servers], @@ -1085,7 +1082,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -1119,7 +1116,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token( # Check that create_turn was called with the correct parameters mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, toolgroups=[mcp_server.name for mcp_server in mcp_servers], @@ -1150,7 +1147,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mock_get_agent = mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "test_conversation_id"), + return_value=(mock_agent, "test_conversation_id", "test_session_id"), ) query_request = QueryRequest(query="What is OpenStack?") @@ -1206,7 +1203,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): # Check that create_turn was called with the correct parameters mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_conversation_id", + session_id="test_session_id", documents=[], stream=True, toolgroups=[mcp_server.name for mcp_server in mcp_servers], @@ -1214,16 +1211,17 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): @pytest.mark.asyncio -async def test_get_agent_cache_hit(prepare_agent_mocks): - """Test get_agent function when agent exists in cache.""" +async def test_get_agent_with_conversation_id(prepare_agent_mocks, mocker): + """Test get_agent function when agent exists in llama stack.""" mock_client, mock_agent = prepare_agent_mocks - # Set up cache with existing agent conversation_id = "test_conversation_id" - _agent_cache[conversation_id] = mock_agent - result_agent, result_conversation_id = await get_agent( + # Mock Agent class + mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent) + + result_agent, result_conversation_id, _ = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1232,18 +1230,22 @@ async def test_get_agent_cache_hit(prepare_agent_mocks): conversation_id=conversation_id, ) - # Assert cached agent is returned + # Assert the same agent is returned assert result_agent == mock_agent assert result_conversation_id == conversation_id + assert conversation_id == mock_agent._agent_id # pylint: disable=protected-access @pytest.mark.asyncio -async def test_get_agent_cache_miss_with_conversation_id( +async def test_get_agent_with_conversation_id_and_no_agent_in_llama_stack( setup_configuration, prepare_agent_mocks, mocker ): - """Test get_agent function when conversation_id is provided but agent not in cache.""" + """Test get_agent function when conversation_id is provided but agent not in llama stack.""" mock_client, mock_agent = prepare_agent_mocks + mock_client.agents.retrieve.side_effect = ValueError( + "fake not finding existing agent" + ) mock_agent.create_session.return_value = "new_session_id" # Mock Agent class @@ -1267,8 +1269,8 @@ async def test_get_agent_cache_miss_with_conversation_id( ) mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) - # Call function with conversation_id but no cached agent - result_agent, result_conversation_id = await get_agent( + # Call function with conversation_id + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1279,7 +1281,9 @@ async def test_get_agent_cache_miss_with_conversation_id( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert "non_existent_conversation_id" != result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters mock_agent_class.assert_called_once_with( @@ -1292,9 +1296,6 @@ async def test_get_agent_cache_miss_with_conversation_id( enable_session_persistence=True, ) - # Verify agent was stored in cache - assert _agent_cache["new_session_id"] == mock_agent - @pytest.mark.asyncio async def test_get_agent_no_conversation_id( @@ -1327,7 +1328,7 @@ async def test_get_agent_no_conversation_id( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function with None conversation_id - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1338,7 +1339,8 @@ async def test_get_agent_no_conversation_id( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with correct parameters mock_agent_class.assert_called_once_with( @@ -1351,9 +1353,6 @@ async def test_get_agent_no_conversation_id( enable_session_persistence=True, ) - # Verify agent was stored in cache - assert _agent_cache["new_session_id"] == mock_agent - @pytest.mark.asyncio async def test_get_agent_empty_shields( @@ -1386,7 +1385,7 @@ async def test_get_agent_empty_shields( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function with empty shields list - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1397,7 +1396,8 @@ async def test_get_agent_empty_shields( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with empty shields mock_agent_class.assert_called_once_with( @@ -1444,7 +1444,7 @@ async def test_get_agent_multiple_mcp_servers( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1455,7 +1455,8 @@ async def test_get_agent_multiple_mcp_servers( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with tools from both MCP servers mock_agent_class.assert_called_once_with( @@ -1556,7 +1557,7 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): "app.endpoints.streaming_query.retrieve_user_id", return_value="user123" ) - _ = await streaming_query_endpoint_handler( + await streaming_query_endpoint_handler( None, QueryRequest(query="test query"), auth=("user123", "username", "auth_token_123"), @@ -1669,7 +1670,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "fake_session_id"), + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", no_tools=True) @@ -1681,7 +1682,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( ) assert response is not None - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers is empty (no MCP headers) assert mock_agent.extra_headers == {} @@ -1719,7 +1720,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "fake_session_id"), + return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), ) query_request = QueryRequest(query="What is OpenStack?", no_tools=False) @@ -1731,7 +1732,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( ) assert response is not None - assert conversation_id == "fake_session_id" + assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers contains MCP headers expected_extra_headers = { @@ -1785,7 +1786,7 @@ async def test_get_agent_no_tools_no_parser( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function with no_tools=True - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1797,7 +1798,8 @@ async def test_get_agent_no_tools_no_parser( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with tool_parser=None mock_agent_class.assert_called_once_with( @@ -1848,7 +1850,7 @@ async def test_get_agent_no_tools_false_preserves_parser( mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) # Call function with no_tools=False - result_agent, result_conversation_id = await get_agent( + result_agent, result_conversation_id, result_session_id = await get_agent( client=mock_client, model_id="test_model", system_prompt="test_prompt", @@ -1860,7 +1862,8 @@ async def test_get_agent_no_tools_false_preserves_parser( # Assert new agent is created assert result_agent == mock_agent - assert result_conversation_id == "new_session_id" + assert result_conversation_id == result_agent.agent_id + assert result_session_id == "new_session_id" # Verify Agent was created with the proper tool_parser mock_agent_class.assert_called_once_with(