From 906e60a73d9a17af01a83c9747780f45244551ba Mon Sep 17 00:00:00 2001 From: Ben Keith Date: Thu, 14 Aug 2025 14:57:31 -0400 Subject: [PATCH] Add tool calls to stored transcripts Also moved all of the transcript handling to its own module as it grew a bit with this. --- src/app/endpoints/query.py | 146 ++++--------- src/app/endpoints/streaming_query.py | 37 ++-- src/utils/transcripts.py | 86 ++++++++ src/utils/types.py | 43 +++- tests/unit/app/endpoints/test_query.py | 204 ++++++++---------- .../app/endpoints/test_streaming_query.py | 15 +- tests/unit/utils/test_transcripts.py | 127 +++++++++++ 7 files changed, 413 insertions(+), 245 deletions(-) create mode 100644 src/utils/transcripts.py create mode 100644 tests/unit/utils/test_transcripts.py diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index d53d3e83d..eae3ae287 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -3,13 +3,13 @@ from datetime import datetime, UTC import json import logging -import os -from pathlib import Path -from typing import Annotated, Any +from typing import Annotated, Any, cast from llama_stack_client import APIConnectionError from llama_stack_client import AsyncLlamaStackClient # type: ignore +from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str from llama_stack_client.types import UserMessage, Shield # type: ignore +from llama_stack_client.types.agents.turn import Turn from llama_stack_client.types.agents.turn_create_params import ( ToolgroupAgentToolGroupWithArgs, Toolgroup, @@ -35,7 +35,8 @@ validate_conversation_ownership, ) from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups -from utils.suid import get_suid +from utils.transcripts import store_transcript +from utils.types import TurnSummary logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["query"]) @@ -203,7 +204,7 @@ async def query_endpoint_handler( user_conversation=user_conversation, query_request=query_request ), ) - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( client, llama_stack_model_id, query_request, @@ -224,7 +225,7 @@ async def query_endpoint_handler( query_is_valid=True, # TODO(lucasagomes): implement as part of query validation query=query_request.query, query_request=query_request, - response=response, + summary=summary, rag_chunks=[], # TODO(lucasagomes): implement rag_chunks truncated=False, # TODO(lucasagomes): implement truncation as part of quota work attachments=query_request.attachments or [], @@ -237,7 +238,10 @@ async def query_endpoint_handler( provider_id=provider_id, ) - return QueryResponse(conversation_id=conversation_id, response=response) + return QueryResponse( + conversation_id=conversation_id, + response=summary.llm_response, + ) # connection to Llama Stack server except APIConnectionError as e: @@ -381,7 +385,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche query_request: QueryRequest, token: str, mcp_headers: dict[str, dict[str, str]] | None = None, -) -> tuple[str, str]: +) -> tuple[TurnSummary, str]: """ Retrieve response from LLMs and agents. @@ -404,7 +408,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing. Returns: - tuple[str, str]: A tuple containing the LLM or agent's response content + tuple[TurnSummary, str]: A tuple containing a summary of the LLM or agent's response content and the conversation ID. """ available_input_shields = [ @@ -484,27 +488,35 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche stream=False, toolgroups=toolgroups, ) + response = cast(Turn, response) + + summary = TurnSummary( + llm_response=( + interleaved_content_as_str(response.output_message.content) + if ( + getattr(response, "output_message", None) is not None + and getattr(response.output_message, "content", None) is not None + ) + else "" + ), + tool_calls=[], + ) # Check for validation errors in the response - steps = getattr(response, "steps", []) + steps = response.steps or [] for step in steps: if step.step_type == "shield_call" and step.violation: # Metric for LLM validation errors metrics.llm_calls_validation_errors_total.inc() - break - - output_message = getattr(response, "output_message", None) - if output_message is not None: - content = getattr(output_message, "content", None) - if content is not None: - return str(content), conversation_id - - # fallback - logger.warning( - "Response lacks output_message.content (conversation_id=%s)", - conversation_id, - ) - return "", conversation_id + if step.step_type == "tool_execution": + summary.append_tool_calls_from_llama(step) + + if not summary.llm_response: + logger.warning( + "Response lacks output_message.content (conversation_id=%s)", + conversation_id, + ) + return summary, conversation_id def validate_attachments_metadata(attachments: list[Attachment]) -> None: @@ -539,92 +551,6 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None: ) -def construct_transcripts_path(user_id: str, conversation_id: str) -> Path: - """ - Construct path to transcripts. - - Constructs a sanitized filesystem path for storing transcripts - based on the user ID and conversation ID. - - Parameters: - user_id (str): The user identifier, which will be normalized and sanitized. - conversation_id (str): The conversation identifier, which will be normalized and sanitized. - - Returns: - Path: The constructed path for storing transcripts for the specified user and conversation. - """ - # these two normalizations are required by Snyk as it detects - # this Path sanitization pattern - uid = os.path.normpath("/" + user_id).lstrip("/") - cid = os.path.normpath("/" + conversation_id).lstrip("/") - file_path = ( - configuration.user_data_collection_configuration.transcripts_storage or "" - ) - return Path(file_path, uid, cid) - - -def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-arguments - user_id: str, - conversation_id: str, - model_id: str, - provider_id: str | None, - query_is_valid: bool, - query: str, - query_request: QueryRequest, - response: str, - rag_chunks: list[str], - truncated: bool, - attachments: list[Attachment], -) -> None: - """ - Store transcript in the local filesystem. - - Constructs a sanitized filesystem path for storing transcripts - based on the user ID and conversation ID. - - Returns: - Path: The constructed path for storing transcripts for the specified user and conversation. - - Args: - user_id: The user ID (UUID). - conversation_id: The conversation ID (UUID). - query_is_valid: The result of the query validation. - query: The query (without attachments). - query_request: The request containing a query. - response: The response to store. - rag_chunks: The list of `RagChunk` objects. - truncated: The flag indicating if the history was truncated. - attachments: The list of `Attachment` objects. - """ - transcripts_path = construct_transcripts_path(user_id, conversation_id) - transcripts_path.mkdir(parents=True, exist_ok=True) - - data_to_store = { - "metadata": { - "provider": provider_id, - "model": model_id, - "query_provider": query_request.provider, - "query_model": query_request.model, - "user_id": user_id, - "conversation_id": conversation_id, - "timestamp": datetime.now(UTC).isoformat(), - }, - "redacted_query": query, - "query_is_valid": query_is_valid, - "llm_response": response, - "rag_chunks": rag_chunks, - "truncated": truncated, - "attachments": [attachment.model_dump() for attachment in attachments], - } - - # stores feedback in a file under unique uuid - transcript_file_path = transcripts_path / f"{get_suid()}.json" - with open(transcript_file_path, "w", encoding="utf-8") as transcript_file: - json.dump(data_to_store, transcript_file) - - logger.info("Transcript successfully stored at: %s", transcript_file_path) - - def get_rag_toolgroups( vector_db_ids: list[str], ) -> list[Toolgroup] | None: diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 329a72302..69d29ebd1 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -4,13 +4,16 @@ import json import re import logging -from typing import Annotated, Any, AsyncIterator, Iterator +from typing import Annotated, Any, AsyncIterator, Iterator, cast from llama_stack_client import APIConnectionError from llama_stack_client import AsyncLlamaStackClient # type: ignore from llama_stack_client.types import UserMessage # type: ignore from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str +from llama_stack_client.types.agents.agent_turn_response_stream_chunk import ( + AgentTurnResponseStreamChunk, +) from llama_stack_client.types.shared import ToolCall from llama_stack_client.types.shared.interleaved_content_item import TextContentItem @@ -26,13 +29,14 @@ from models.database.conversations import UserConversation from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups +from utils.transcripts import store_transcript +from utils.types import TurnSummary from app.endpoints.query import ( get_rag_toolgroups, is_input_shield, is_output_shield, is_transcripts_enabled, - store_transcript, select_model_and_provider_id, validate_attachments_metadata, validate_conversation_ownership, @@ -574,7 +578,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals ) metadata_map: dict[str, dict[str, Any]] = {} - async def response_generator(turn_response: Any) -> AsyncIterator[str]: + async def response_generator( + turn_response: AsyncIterator[AgentTurnResponseStreamChunk], + ) -> AsyncIterator[str]: """ Generate SSE formatted streaming response. @@ -587,20 +593,24 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: complete response for transcript storage if enabled. """ chunk_id = 0 - complete_response = "No response from the model" + summary = TurnSummary( + llm_response="No response from the model", tool_calls=[] + ) # Send start event yield stream_start_event(conversation_id) async for chunk in turn_response: + p = chunk.event.payload + if p.event_type == "turn_complete": + summary.llm_response = interleaved_content_as_str( + p.turn.output_message.content + ) + elif p.event_type == "step_complete": + if p.step_details.step_type == "tool_execution": + summary.append_tool_calls_from_llama(p.step_details) + for event in stream_build_event(chunk, chunk_id, metadata_map): - if ( - json.loads(event.replace("data: ", ""))["event"] - == "turn_complete" - ): - complete_response = json.loads(event.replace("data: ", ""))[ - "data" - ]["token"] chunk_id += 1 yield event @@ -617,7 +627,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: query_is_valid=True, # TODO(lucasagomes): implement as part of query validation query=query_request.query, query_request=query_request, - response=complete_response, + summary=summary, rag_chunks=[], # TODO(lucasagomes): implement rag_chunks truncated=False, # TODO(lucasagomes): implement truncation as part # of quota work @@ -655,7 +665,7 @@ async def retrieve_response( query_request: QueryRequest, token: str, mcp_headers: dict[str, dict[str, str]] | None = None, -) -> tuple[Any, str]: +) -> tuple[AsyncIterator[AgentTurnResponseStreamChunk], str]: """ Retrieve response from LLMs and agents. @@ -758,5 +768,6 @@ async def retrieve_response( stream=True, toolgroups=toolgroups, ) + response = cast(AsyncIterator[AgentTurnResponseStreamChunk], response) return response, conversation_id diff --git a/src/utils/transcripts.py b/src/utils/transcripts.py new file mode 100644 index 000000000..e29d4319d --- /dev/null +++ b/src/utils/transcripts.py @@ -0,0 +1,86 @@ +"""Transcript handling. + +Transcripts are a log of individual query/response pairs that get +stored on disk for later analysis +""" + +from datetime import UTC, datetime +import json +import logging +import os +from pathlib import Path + +from configuration import configuration +from models.requests import Attachment, QueryRequest +from utils.suid import get_suid +from utils.types import TurnSummary + +logger = logging.getLogger("utils.transcripts") + + +def construct_transcripts_path(user_id: str, conversation_id: str) -> Path: + """Construct path to transcripts.""" + # these two normalizations are required by Snyk as it detects + # this Path sanitization pattern + uid = os.path.normpath("/" + user_id).lstrip("/") + cid = os.path.normpath("/" + conversation_id).lstrip("/") + file_path = ( + configuration.user_data_collection_configuration.transcripts_storage or "" + ) + return Path(file_path, uid, cid) + + +def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals + user_id: str, + conversation_id: str, + model_id: str, + provider_id: str | None, + query_is_valid: bool, + query: str, + query_request: QueryRequest, + summary: TurnSummary, + rag_chunks: list[str], + truncated: bool, + attachments: list[Attachment], +) -> None: + """Store transcript in the local filesystem. + + Args: + user_id: The user ID (UUID). + conversation_id: The conversation ID (UUID). + query_is_valid: The result of the query validation. + query: The query (without attachments). + query_request: The request containing a query. + summary: Summary of the query/response turn. + rag_chunks: The list of `RagChunk` objects. + truncated: The flag indicating if the history was truncated. + attachments: The list of `Attachment` objects. + """ + transcripts_path = construct_transcripts_path(user_id, conversation_id) + transcripts_path.mkdir(parents=True, exist_ok=True) + + data_to_store = { + "metadata": { + "provider": provider_id, + "model": model_id, + "query_provider": query_request.provider, + "query_model": query_request.model, + "user_id": user_id, + "conversation_id": conversation_id, + "timestamp": datetime.now(UTC).isoformat(), + }, + "redacted_query": query, + "query_is_valid": query_is_valid, + "llm_response": summary.llm_response, + "rag_chunks": rag_chunks, + "truncated": truncated, + "attachments": [attachment.model_dump() for attachment in attachments], + "tool_calls": [tc.model_dump() for tc in summary.tool_calls], + } + + # stores feedback in a file under unique uuid + transcript_file_path = transcripts_path / f"{get_suid()}.json" + with open(transcript_file_path, "w", encoding="utf-8") as transcript_file: + json.dump(data_to_store, transcript_file) + + logger.info("Transcript successfully stored at: %s", transcript_file_path) diff --git a/src/utils/types.py b/src/utils/types.py index 52bbb8fe6..5770139ae 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -1,10 +1,13 @@ """Common types for the project.""" -from typing import Optional +from typing import Any, Optional +from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str from llama_stack_client.lib.agents.tool_parser import ToolParser from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared.tool_call import ToolCall +from llama_stack_client.types.tool_execution_step import ToolExecutionStep +from pydantic.main import BaseModel class Singleton(type): @@ -35,3 +38,41 @@ def get_parser(model_id: str) -> Optional[ToolParser]: if model_id and model_id.lower().startswith("granite"): return GraniteToolParser() return None + + +class ToolCallSummary(BaseModel): + """Represents a tool call for data collection. + + Use our own tool call model to keep things consistent across llama + upgrades or if we used something besides llama in the future. + """ + + # ID of the call itself + id: str + # Name of the tool used + name: str + # Arguments to the tool call + args: str | dict[Any, Any] + response: str | None + + +class TurnSummary(BaseModel): + """Summary of a turn in llama stack.""" + + llm_response: str + tool_calls: list[ToolCallSummary] + + def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: + """Append the tool calls from a llama tool execution step.""" + calls_by_id = {tc.call_id: tc for tc in tec.tool_calls} + responses_by_id = {tc.call_id: tc for tc in tec.tool_responses} + for call_id, tc in calls_by_id.items(): + resp = responses_by_id.get(call_id) + self.tool_calls.append( + ToolCallSummary( + id=call_id, + name=tc.tool_name, + args=tc.arguments, + response=interleaved_content_as_str(resp.content) if resp else None, + ) + ) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 9f3d065fc..3356f2b27 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -8,6 +8,9 @@ from llama_stack_client import APIConnectionError from llama_stack_client.types import UserMessage # type: ignore +from llama_stack_client.types.shared.interleaved_content import ( + TextContentItem, +) from configuration import AppConfig from app.endpoints.query import ( @@ -16,8 +19,6 @@ retrieve_response, validate_attachments_metadata, is_transcripts_enabled, - construct_transcripts_path, - store_transcript, get_rag_toolgroups, evaluate_model_hints, ) @@ -25,6 +26,7 @@ from models.requests import QueryRequest, Attachment from models.config import ModelContextProtocolServer from models.database.conversations import UserConversation +from utils.types import ToolCallSummary, TurnSummary MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") @@ -122,13 +124,23 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): ) mocker.patch("app.endpoints.query.configuration", mock_config) - llm_response = "LLM answer" + summary = TurnSummary( + llm_response="LLM answer", + tool_calls=[ + ToolCallSummary( + id="123", + name="test-tool", + args="testing", + response="tool response", + ) + ], + ) conversation_id = "fake_conversation_id" query = "What is OpenStack?" mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(llm_response, conversation_id), + return_value=(summary, conversation_id), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -148,7 +160,7 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) # Assert the response is as expected - assert response.response == llm_response + assert response.response == summary.llm_response assert response.conversation_id == conversation_id # Assert the metric for successful LLM calls is incremented @@ -164,7 +176,7 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): query_is_valid=True, query=query, query_request=query_request, - response=llm_response, + summary=summary, attachments=[], rag_chunks=[], truncated=False, @@ -412,7 +424,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker ) # fallback mechanism: check that the response is empty - assert response == "" + assert response.llm_response == "" @pytest.mark.asyncio @@ -443,7 +455,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo ) # fallback mechanism: check that the response is empty - assert response == "" + assert response.llm_response == "" @pytest.mark.asyncio @@ -470,13 +482,13 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) # Assert that the metric for validation errors is NOT incremented mock_metric.inc.assert_not_called() - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], @@ -508,11 +520,11 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], @@ -557,11 +569,11 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], @@ -609,11 +621,11 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], @@ -663,11 +675,11 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters @@ -719,11 +731,11 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], @@ -773,11 +785,11 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], @@ -828,11 +840,11 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters @@ -897,11 +909,11 @@ async def test_retrieve_response_with_mcp_servers_empty_token( model_id = "fake_model_id" access_token = "" # Empty token - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters @@ -968,7 +980,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( }, } - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, @@ -976,7 +988,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers( mcp_headers=mcp_headers, ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" # Verify get_agent was called with the correct parameters @@ -1033,6 +1045,9 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): ), ] mock_agent.create_turn.return_value.steps = steps + mock_agent.create_turn.return_value.output_message.content = TextContentItem( + text="LLM answer", type="text" + ) mock_client.shields.list.return_value = [] mock_vector_db = mocker.Mock() mock_vector_db.identifier = "VectorDB-1" @@ -1066,86 +1081,6 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker): ) -def test_construct_transcripts_path(setup_configuration, mocker): - """Test the construct_transcripts_path function.""" - # Update configuration for this test - setup_configuration.user_data_collection_configuration.transcripts_storage = ( - "/tmp/transcripts" - ) - mocker.patch("app.endpoints.query.configuration", setup_configuration) - - user_id = "user123" - conversation_id = "123e4567-e89b-12d3-a456-426614174000" - - path = construct_transcripts_path(user_id, conversation_id) - - assert ( - str(path) == "/tmp/transcripts/user123/123e4567-e89b-12d3-a456-426614174000" - ), "Path should be constructed correctly" - - -def test_store_transcript(mocker): - """Test the store_transcript function.""" - - mocker.patch("builtins.open", mocker.mock_open()) - mocker.patch( - "app.endpoints.query.construct_transcripts_path", - return_value=mocker.MagicMock(), - ) - - # Mock the JSON to assert the data is stored correctly - mock_json = mocker.patch("app.endpoints.query.json") - - # Mock parameters - user_id = "user123" - conversation_id = "123e4567-e89b-12d3-a456-426614174000" - query = "What is OpenStack?" - model = "fake-model" - provider = "fake-provider" - query_request = QueryRequest(query=query, model=model, provider=provider) - response = "LLM answer" - query_is_valid = True - rag_chunks = [] - truncated = False - attachments = [] - - store_transcript( - user_id, - conversation_id, - model, - provider, - query_is_valid, - query, - query_request, - response, - rag_chunks, - truncated, - attachments, - ) - - # Assert that the transcript was stored correctly - mock_json.dump.assert_called_once_with( - { - "metadata": { - "provider": "fake-provider", - "model": "fake-model", - "query_provider": query_request.provider, - "query_model": query_request.model, - "user_id": user_id, - "conversation_id": conversation_id, - "timestamp": mocker.ANY, - }, - "redacted_query": query, - "query_is_valid": query_is_valid, - "llm_response": response, - "rag_chunks": rag_chunks, - "truncated": truncated, - "attachments": attachments, - }, - mocker.ANY, - ) - - def test_get_rag_toolgroups(): """Test get_rag_toolgroups function.""" vector_db_ids = [] @@ -1199,9 +1134,20 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client ) + summary = TurnSummary( + llm_response="LLM answer", + tool_calls=[ + ToolCallSummary( + id="123", + name="test-tool", + args="testing", + response="tool response", + ) + ], + ) mock_retrieve_response = mocker.patch( "app.endpoints.query.retrieve_response", - return_value=("test response", "test_conversation_id"), + return_value=(summary, "test_conversation_id"), ) mocker.patch( @@ -1235,13 +1181,23 @@ async def test_query_endpoint_handler_no_tools_true(mocker): mock_config.user_data_collection_configuration.transcripts_disabled = True mocker.patch("app.endpoints.query.configuration", mock_config) - llm_response = "LLM answer without tools" + summary = TurnSummary( + llm_response="LLM answer", + tool_calls=[ + ToolCallSummary( + id="123", + name="test-tool", + args="testing", + response="tool response", + ) + ], + ) conversation_id = "fake_conversation_id" query = "What is OpenStack?" mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(llm_response, conversation_id), + return_value=(summary, conversation_id), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1256,7 +1212,7 @@ async def test_query_endpoint_handler_no_tools_true(mocker): response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) # Assert the response is as expected - assert response.response == llm_response + assert response.response == summary.llm_response assert response.conversation_id == conversation_id @@ -1274,13 +1230,23 @@ async def test_query_endpoint_handler_no_tools_false(mocker): mock_config.user_data_collection_configuration.transcripts_disabled = True mocker.patch("app.endpoints.query.configuration", mock_config) - llm_response = "LLM answer with tools" + summary = TurnSummary( + llm_response="LLM answer", + tool_calls=[ + ToolCallSummary( + id="123", + name="test-tool", + args="testing", + response="tool response", + ) + ], + ) conversation_id = "fake_conversation_id" query = "What is OpenStack?" mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(llm_response, conversation_id), + return_value=(summary, conversation_id), ) mocker.patch( "app.endpoints.query.select_model_and_provider_id", @@ -1295,7 +1261,7 @@ async def test_query_endpoint_handler_no_tools_false(mocker): response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) # Assert the response is as expected - assert response.response == llm_response + assert response.response == summary.llm_response assert response.conversation_id == conversation_id @@ -1329,11 +1295,11 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers is empty (no MCP headers) @@ -1379,11 +1345,11 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = await retrieve_response( + summary, conversation_id = await retrieve_response( mock_client, model_id, query_request, access_token ) - assert response == "LLM answer" + assert summary.llm_response == "LLM answer" assert conversation_id == "fake_conversation_id" # Verify that agent.extra_headers contains MCP headers diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 8e03aa9c3..8953ca216 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -44,6 +44,7 @@ from models.requests import QueryRequest, Attachment from models.config import ModelContextProtocolServer +from utils.types import ToolCallSummary, TurnSummary MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") @@ -218,7 +219,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) step_type="tool_execution", tool_responses=[ ToolResponse( - call_id="c1", + call_id="t1", tool_name="knowledge_search", content=[ TextContentItem(text=s, type="text") @@ -323,7 +324,17 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) query_is_valid=True, query=query, query_request=query_request, - response="LLM answer", + summary=TurnSummary( + llm_response="LLM answer", + tool_calls=[ + ToolCallSummary( + id="t1", + name="knowledge_search", + args={}, + response=" ".join(SAMPLE_KNOWLEDGE_SEARCH_RESULTS), + ) + ], + ), attachments=[], rag_chunks=[], truncated=False, diff --git a/tests/unit/utils/test_transcripts.py b/tests/unit/utils/test_transcripts.py new file mode 100644 index 000000000..b30b430a2 --- /dev/null +++ b/tests/unit/utils/test_transcripts.py @@ -0,0 +1,127 @@ +"""Unit tests for functions defined in utils.transcripts module.""" + +from configuration import AppConfig +from models.requests import QueryRequest + +from utils.transcripts import ( + construct_transcripts_path, + store_transcript, +) +from utils.types import ToolCallSummary, TurnSummary + + +def test_construct_transcripts_path(mocker): + """Test the construct_transcripts_path function.""" + + config_dict = { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "transcripts_storage": "/tmp/transcripts", + }, + } + cfg = AppConfig() + cfg.init_from_dict(config_dict) + # Update configuration for this test + mocker.patch("utils.transcripts.configuration", cfg) + + user_id = "user123" + conversation_id = "123e4567-e89b-12d3-a456-426614174000" + + path = construct_transcripts_path(user_id, conversation_id) + + assert ( + str(path) == "/tmp/transcripts/user123/123e4567-e89b-12d3-a456-426614174000" + ), "Path should be constructed correctly" + + +def test_store_transcript(mocker): + """Test the store_transcript function.""" + + mocker.patch("builtins.open", mocker.mock_open()) + mocker.patch( + "utils.transcripts.construct_transcripts_path", + return_value=mocker.MagicMock(), + ) + + # Mock the JSON to assert the data is stored correctly + mock_json = mocker.patch("utils.transcripts.json") + + # Mock parameters + user_id = "user123" + conversation_id = "123e4567-e89b-12d3-a456-426614174000" + query = "What is OpenStack?" + model = "fake-model" + provider = "fake-provider" + query_request = QueryRequest(query=query, model=model, provider=provider) + summary = TurnSummary( + llm_response="LLM answer", + tool_calls=[ + ToolCallSummary( + id="123", + name="test-tool", + args="testing", + response="tool response", + ) + ], + ) + query_is_valid = True + rag_chunks = [] + truncated = False + attachments = [] + + store_transcript( + user_id, + conversation_id, + model, + provider, + query_is_valid, + query, + query_request, + summary, + rag_chunks, + truncated, + attachments, + ) + + # Assert that the transcript was stored correctly + mock_json.dump.assert_called_once_with( + { + "metadata": { + "provider": "fake-provider", + "model": "fake-model", + "query_provider": query_request.provider, + "query_model": query_request.model, + "user_id": user_id, + "conversation_id": conversation_id, + "timestamp": mocker.ANY, + }, + "redacted_query": query, + "query_is_valid": query_is_valid, + "llm_response": summary.llm_response, + "rag_chunks": rag_chunks, + "truncated": truncated, + "attachments": attachments, + "tool_calls": [ + { + "id": "123", + "name": "test-tool", + "args": "testing", + "response": "tool response", + } + ], + }, + mocker.ANY, + )