diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 952c0a0b..5fb1278d 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -374,7 +374,6 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 ) completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") - cache_entry = CacheEntry( query=query_request.query, response=summary.llm_response, @@ -383,6 +382,8 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 started_at=started_at, completed_at=completed_at, referenced_documents=referenced_documents if referenced_documents else None, + tool_calls=summary.tool_calls if summary.tool_calls else None, + tool_results=summary.tool_results if summary.tool_results else None, ) consume_tokens( diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index a8300022..5e72cf3a 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -9,8 +9,9 @@ from models.cache_entry import CacheEntry from models.config import PostgreSQLDatabaseConfiguration from models.responses import ConversationData, ReferencedDocument -from log import get_logger from utils.connection_decorator import connection +from utils.types import ToolCallSummary, ToolResultSummary +from log import get_logger logger = get_logger("cache.postgres_cache") @@ -32,7 +33,9 @@ class PostgresCache(Cache): response | text | | provider | text | | model | text | | - referenced_documents | jsonb | | + referenced_documents | jsonb | | + tool_calls | jsonb | | + tool_results | jsonb | | Indexes: "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at) "timestamps" btree (created_at) @@ -55,6 +58,8 @@ class PostgresCache(Cache): provider text, model text, referenced_documents jsonb, + tool_calls jsonb, + tool_results jsonb, PRIMARY KEY(user_id, conversation_id, created_at) ); """ @@ -75,7 +80,8 @@ class PostgresCache(Cache): """ SELECT_CONVERSATION_HISTORY_STATEMENT = """ - SELECT query, response, provider, model, started_at, completed_at, referenced_documents + SELECT query, response, provider, model, started_at, completed_at, + referenced_documents, tool_calls, tool_results FROM cache WHERE user_id=%s AND conversation_id=%s ORDER BY created_at @@ -83,8 +89,9 @@ class PostgresCache(Cache): INSERT_CONVERSATION_HISTORY_STATEMENT = """ INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at, - query, response, provider, model, referenced_documents) - VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s) + query, response, provider, model, referenced_documents, + tool_calls, tool_results) + VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s, %s, %s) """ QUERY_CACHE_SIZE = """ @@ -220,7 +227,7 @@ def initialize_cache(self, namespace: str) -> None: self.connection.commit() @connection - def get( + def get( # pylint: disable=R0914 self, user_id: str, conversation_id: str, skip_user_id_check: bool = False ) -> list[CacheEntry]: """Get the value associated with the given key. @@ -260,6 +267,40 @@ def get( conversation_id, e, ) + + # Parse tool_calls back into ToolCallSummary objects + tool_calls_data = conversation_entry[7] + tool_calls_obj = None + if tool_calls_data: + try: + tool_calls_obj = [ + ToolCallSummary.model_validate(tc) for tc in tool_calls_data + ] + except (ValueError, TypeError) as e: + logger.warning( + "Failed to deserialize tool_calls for " + "conversation %s: %s", + conversation_id, + e, + ) + + # Parse tool_results back into ToolResultSummary objects + tool_results_data = conversation_entry[8] + tool_results_obj = None + if tool_results_data: + try: + tool_results_obj = [ + ToolResultSummary.model_validate(tr) + for tr in tool_results_data + ] + except (ValueError, TypeError) as e: + logger.warning( + "Failed to deserialize tool_results for " + "conversation %s: %s", + conversation_id, + e, + ) + cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -268,6 +309,8 @@ def get( started_at=conversation_entry[4], completed_at=conversation_entry[5], referenced_documents=docs_obj, + tool_calls=tool_calls_obj, + tool_results=tool_results_obj, ) result.append(cache_entry) @@ -311,6 +354,34 @@ def insert_or_append( e, ) + tool_calls_json = None + if cache_entry.tool_calls: + try: + tool_calls_as_dicts = [ + tc.model_dump(mode="json") for tc in cache_entry.tool_calls + ] + tool_calls_json = json.dumps(tool_calls_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize tool_calls for conversation %s: %s", + conversation_id, + e, + ) + + tool_results_json = None + if cache_entry.tool_results: + try: + tool_results_as_dicts = [ + tr.model_dump(mode="json") for tr in cache_entry.tool_results + ] + tool_results_json = json.dumps(tool_results_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize tool_results for conversation %s: %s", + conversation_id, + e, + ) + # the whole operation is run in one transaction with self.connection.cursor() as cursor: cursor.execute( @@ -325,6 +396,8 @@ def insert_or_append( cache_entry.provider, cache_entry.model, referenced_documents_json, + tool_calls_json, + tool_results_json, ), ) diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index bf70355b..ba48cb77 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -10,8 +10,9 @@ from models.cache_entry import CacheEntry from models.config import SQLiteDatabaseConfiguration from models.responses import ConversationData, ReferencedDocument -from log import get_logger from utils.connection_decorator import connection +from utils.types import ToolCallSummary, ToolResultSummary +from log import get_logger logger = get_logger("cache.sqlite_cache") @@ -34,6 +35,8 @@ class SQLiteCache(Cache): provider | text | | model | text | | referenced_documents | text | | + tool_calls | text | | + tool_results | text | | Indexes: "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at) "cache_key_key" UNIQUE CONSTRAINT, btree (key) @@ -54,6 +57,8 @@ class SQLiteCache(Cache): provider text, model text, referenced_documents text, + tool_calls text, + tool_results text, PRIMARY KEY(user_id, conversation_id, created_at) ); """ @@ -74,7 +79,8 @@ class SQLiteCache(Cache): """ SELECT_CONVERSATION_HISTORY_STATEMENT = """ - SELECT query, response, provider, model, started_at, completed_at, referenced_documents + SELECT query, response, provider, model, started_at, completed_at, + referenced_documents, tool_calls, tool_results FROM cache WHERE user_id=? AND conversation_id=? ORDER BY created_at @@ -82,8 +88,9 @@ class SQLiteCache(Cache): INSERT_CONVERSATION_HISTORY_STATEMENT = """ INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at, - query, response, provider, model, referenced_documents) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + query, response, provider, model, referenced_documents, + tool_calls, tool_results) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """ QUERY_CACHE_SIZE = """ @@ -187,7 +194,7 @@ def initialize_cache(self) -> None: self.connection.commit() @connection - def get( + def get( # pylint: disable=R0914 self, user_id: str, conversation_id: str, skip_user_id_check: bool = False ) -> list[CacheEntry]: """Get the value associated with the given key. @@ -228,6 +235,39 @@ def get( conversation_id, e, ) + + # Parse tool_calls back into ToolCallSummary objects + tool_calls_json_str = conversation_entry[7] + tool_calls_obj = None + if tool_calls_json_str: + try: + tool_calls_data = json.loads(tool_calls_json_str) + tool_calls_obj = [ + ToolCallSummary.model_validate(tc) for tc in tool_calls_data + ] + except (json.JSONDecodeError, ValueError) as e: + logger.warning( + "Failed to deserialize tool_calls for conversation %s: %s", + conversation_id, + e, + ) + + # Parse tool_results back into ToolResultSummary objects + tool_results_json_str = conversation_entry[8] + tool_results_obj = None + if tool_results_json_str: + try: + tool_results_data = json.loads(tool_results_json_str) + tool_results_obj = [ + ToolResultSummary.model_validate(tr) for tr in tool_results_data + ] + except (json.JSONDecodeError, ValueError) as e: + logger.warning( + "Failed to deserialize tool_results for conversation %s: %s", + conversation_id, + e, + ) + cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -236,6 +276,8 @@ def get( started_at=conversation_entry[4], completed_at=conversation_entry[5], referenced_documents=docs_obj, + tool_calls=tool_calls_obj, + tool_results=tool_results_obj, ) result.append(cache_entry) @@ -281,6 +323,34 @@ def insert_or_append( e, ) + tool_calls_json = None + if cache_entry.tool_calls: + try: + tool_calls_as_dicts = [ + tc.model_dump(mode="json") for tc in cache_entry.tool_calls + ] + tool_calls_json = json.dumps(tool_calls_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize tool_calls for conversation %s: %s", + conversation_id, + e, + ) + + tool_results_json = None + if cache_entry.tool_results: + try: + tool_results_as_dicts = [ + tr.model_dump(mode="json") for tr in cache_entry.tool_results + ] + tool_results_json = json.dumps(tool_results_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize tool_results for conversation %s: %s", + conversation_id, + e, + ) + cursor.execute( self.INSERT_CONVERSATION_HISTORY_STATEMENT, ( @@ -294,6 +364,8 @@ def insert_or_append( cache_entry.provider, cache_entry.model, referenced_documents_json, + tool_calls_json, + tool_results_json, ), ) diff --git a/src/models/cache_entry.py b/src/models/cache_entry.py index 637de38a..e00069ce 100644 --- a/src/models/cache_entry.py +++ b/src/models/cache_entry.py @@ -3,6 +3,7 @@ from typing import Optional from pydantic import BaseModel from models.responses import ReferencedDocument +from utils.types import ToolCallSummary, ToolResultSummary class CacheEntry(BaseModel): @@ -14,6 +15,8 @@ class CacheEntry(BaseModel): provider: Provider identification model: Model identification referenced_documents: List of documents referenced by the response + tool_calls: List of tool calls made during response generation + tool_results: List of tool results from tool calls """ query: str @@ -23,3 +26,5 @@ class CacheEntry(BaseModel): started_at: str completed_at: str referenced_documents: Optional[list[ReferencedDocument]] = None + tool_calls: Optional[list[ToolCallSummary]] = None + tool_results: Optional[list[ToolResultSummary]] = None diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 2fb51871..0db9d503 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -806,6 +806,8 @@ async def cleanup_after_streaming( started_at=started_at, completed_at=completed_at, referenced_documents=referenced_documents if referenced_documents else None, + tool_calls=summary.tool_calls if summary.tool_calls else None, + tool_results=summary.tool_results if summary.tool_results else None, ) store_conversation_into_cache( diff --git a/tests/unit/cache/test_postgres_cache.py b/tests/unit/cache/test_postgres_cache.py index 5503936c..b720de61 100644 --- a/tests/unit/cache/test_postgres_cache.py +++ b/tests/unit/cache/test_postgres_cache.py @@ -14,6 +14,7 @@ from models.cache_entry import CacheEntry from models.responses import ConversationData, ReferencedDocument from utils import suid +from utils.types import ToolCallSummary, ToolResultSummary from cache.cache_error import CacheError from cache.postgres_cache import PostgresCache @@ -597,7 +598,8 @@ def test_insert_and_get_with_referenced_documents( ] assert insert_calls, "INSERT call not found" sql_params = insert_calls[-1][0][1] - inserted_json_str = sql_params[-1] + # referenced_documents is now at index -3 (before tool_calls and tool_results) + inserted_json_str = sql_params[-3] assert json.loads(inserted_json_str) == [ {"doc_url": "http://example.com/", "doc_title": "Test Doc"} @@ -612,6 +614,8 @@ def test_insert_and_get_with_referenced_documents( "start_time", "end_time", [{"doc_url": "http://example.com/", "doc_title": "Test Doc"}], + None, # tool_calls + None, # tool_results ) mock_cursor.fetchall.return_value = [db_return_value] @@ -648,7 +652,10 @@ def test_insert_and_get_without_referenced_documents( ] assert insert_calls, "INSERT call not found" sql_params = insert_calls[-1][0][1] - assert sql_params[-1] is None + # Last 3 params are referenced_documents, tool_calls, tool_results - all should be None + assert sql_params[-3] is None # referenced_documents + assert sql_params[-2] is None # tool_calls + assert sql_params[-1] is None # tool_results # Simulate the database returning a row with None db_return_value = ( @@ -659,6 +666,8 @@ def test_insert_and_get_without_referenced_documents( entry_without_docs.started_at, entry_without_docs.completed_at, None, # referenced_documents is None in the DB + None, # tool_calls + None, # tool_results ) mock_cursor.fetchall.return_value = [db_return_value] @@ -669,6 +678,8 @@ def test_insert_and_get_without_referenced_documents( assert len(retrieved_entries) == 1 assert retrieved_entries[0] == entry_without_docs assert retrieved_entries[0].referenced_documents is None + assert retrieved_entries[0].tool_calls is None + assert retrieved_entries[0].tool_results is None def test_initialize_cache_with_custom_namespace( @@ -710,3 +721,119 @@ def test_connect_to_cache_with_too_long_namespace( # should fail due to invalid namespace containing spaces with pytest.raises(ValueError, match="Invalid namespace: too long namespace"): PostgresCache(postgres_cache_config_fixture_too_long_namespace) + + +def test_insert_and_get_with_tool_calls_and_results( + postgres_cache_config_fixture: PostgreSQLDatabaseConfiguration, + mocker: MockerFixture, +) -> None: + """Test that a CacheEntry with tool_calls and tool_results is stored and retrieved correctly.""" + # prevent real connection to PG instance + mock_connect = mocker.patch("psycopg2.connect") + cache = PostgresCache(postgres_cache_config_fixture) + + mock_connection = mock_connect.return_value + mock_cursor = mock_connection.cursor.return_value.__enter__.return_value + + # Create tool_calls and tool_results + tool_calls = [ + ToolCallSummary( + id="call_1", name="test_tool", args={"param": "value"}, type="tool_call" + ) + ] + tool_results = [ + ToolResultSummary( + id="call_1", + status="success", + content="result data", + type="tool_result", + round=1, + ) + ] + entry_with_tools = CacheEntry( + query="user message", + response="AI message", + provider="foo", + model="bar", + started_at="start_time", + completed_at="end_time", + tool_calls=tool_calls, + tool_results=tool_results, + ) + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_tools) + + # Find the INSERT INTO cache(...) call + insert_calls = [ + c + for c in mock_cursor.execute.call_args_list + if isinstance(c[0][0], str) and "INSERT INTO cache(" in c[0][0] + ] + assert insert_calls, "INSERT call not found" + sql_params = insert_calls[-1][0][1] + + # Verify tool_calls JSON + tool_calls_json = sql_params[-2] + assert json.loads(tool_calls_json) == [ + { + "id": "call_1", + "name": "test_tool", + "args": {"param": "value"}, + "type": "tool_call", + } + ] + + # Verify tool_results JSON + tool_results_json = sql_params[-1] + assert json.loads(tool_results_json) == [ + { + "id": "call_1", + "status": "success", + "content": "result data", + "type": "tool_result", + "round": 1, + } + ] + + # Simulate the database returning that data + db_return_value = ( + "user message", + "AI message", + "foo", + "bar", + "start_time", + "end_time", + None, # referenced_documents + [ + { + "id": "call_1", + "name": "test_tool", + "args": {"param": "value"}, + "type": "tool_call", + } + ], + [ + { + "id": "call_1", + "status": "success", + "content": "result data", + "type": "tool_result", + "round": 1, + } + ], + ) + mock_cursor.fetchall.return_value = [db_return_value] + + # Call the get method + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_with_tools + assert retrieved_entries[0].tool_calls is not None + assert len(retrieved_entries[0].tool_calls) == 1 + assert retrieved_entries[0].tool_calls[0].name == "test_tool" + assert retrieved_entries[0].tool_results is not None + assert len(retrieved_entries[0].tool_results) == 1 + assert retrieved_entries[0].tool_results[0].status == "success" diff --git a/tests/unit/cache/test_sqlite_cache.py b/tests/unit/cache/test_sqlite_cache.py index 9e8f5626..a62195db 100644 --- a/tests/unit/cache/test_sqlite_cache.py +++ b/tests/unit/cache/test_sqlite_cache.py @@ -13,6 +13,7 @@ from models.cache_entry import CacheEntry from models.responses import ConversationData, ReferencedDocument from utils import suid +from utils.types import ToolCallSummary, ToolResultSummary from cache.cache_error import CacheError from cache.sqlite_cache import SQLiteCache @@ -470,3 +471,103 @@ def test_insert_and_get_without_referenced_documents(tmpdir: Path) -> None: assert len(retrieved_entries) == 1 assert retrieved_entries[0] == entry_without_docs assert retrieved_entries[0].referenced_documents is None + assert retrieved_entries[0].tool_calls is None + assert retrieved_entries[0].tool_results is None + + +def test_insert_and_get_with_tool_calls_and_results(tmpdir: Path) -> None: + """ + Test that a CacheEntry with tool_calls and tool_results is correctly + serialized, stored, and retrieved. + """ + cache = create_cache(tmpdir) + + # Create tool_calls and tool_results + tool_calls = [ + ToolCallSummary( + id="call_1", name="test_tool", args={"param": "value"}, type="tool_call" + ) + ] + tool_results = [ + ToolResultSummary( + id="call_1", + status="success", + content="result data", + type="tool_result", + round=1, + ) + ] + entry_with_tools = CacheEntry( + query="user message", + response="AI message", + provider="foo", + model="bar", + started_at="start_time", + completed_at="end_time", + tool_calls=tool_calls, + tool_results=tool_results, + ) + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_tools) + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_with_tools + assert retrieved_entries[0].tool_calls is not None + assert len(retrieved_entries[0].tool_calls) == 1 + assert retrieved_entries[0].tool_calls[0].name == "test_tool" + assert retrieved_entries[0].tool_calls[0].args == {"param": "value"} + assert retrieved_entries[0].tool_results is not None + assert len(retrieved_entries[0].tool_results) == 1 + assert retrieved_entries[0].tool_results[0].status == "success" + assert retrieved_entries[0].tool_results[0].content == "result data" + + +def test_insert_and_get_with_all_fields(tmpdir: Path) -> None: + """ + Test that a CacheEntry with all fields (referenced_documents, tool_calls, + tool_results) is correctly serialized, stored, and retrieved. + """ + cache = create_cache(tmpdir) + + # Create all fields + docs = [ + ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com")) + ] + tool_calls = [ + ToolCallSummary( + id="call_1", name="test_tool", args={"key": "value"}, type="tool_call" + ) + ] + tool_results = [ + ToolResultSummary( + id="call_1", status="success", content="output", type="tool_result", round=1 + ) + ] + entry_with_all = CacheEntry( + query="user query", + response="AI response", + provider="provider", + model="model", + started_at="start", + completed_at="end", + referenced_documents=docs, + tool_calls=tool_calls, + tool_results=tool_results, + ) + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_all) + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_with_all + assert retrieved_entries[0].referenced_documents is not None + assert retrieved_entries[0].referenced_documents[0].doc_title == "Test Doc" + assert retrieved_entries[0].tool_calls is not None + assert retrieved_entries[0].tool_calls[0].name == "test_tool" + assert retrieved_entries[0].tool_results is not None + assert retrieved_entries[0].tool_results[0].status == "success" diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py index c0641685..b1343674 100644 --- a/tests/unit/utils/test_endpoints.py +++ b/tests/unit/utils/test_endpoints.py @@ -1069,7 +1069,9 @@ async def test_cleanup_after_streaming_generate_topic_summary_default_true( provider_id="test_provider", llama_stack_model_id="test_llama_model", query_request=query_request, - summary=mocker.Mock(llm_response="test response", tool_calls=[]), + summary=mocker.Mock( + llm_response="test response", tool_calls=[], tool_results=[] + ), metadata_map={}, started_at="2024-01-01T00:00:00Z", client=mock_client, @@ -1121,7 +1123,9 @@ async def test_cleanup_after_streaming_generate_topic_summary_explicit_false( provider_id="test_provider", llama_stack_model_id="test_llama_model", query_request=query_request, - summary=mocker.Mock(llm_response="test response", tool_calls=[]), + summary=mocker.Mock( + llm_response="test response", tool_calls=[], tool_results=[] + ), metadata_map={}, started_at="2024-01-01T00:00:00Z", client=mock_client,