From 539c630aefb0f2de05963dc5ffd0fbee70931494 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 12 Jan 2026 16:55:21 -0500 Subject: [PATCH 1/4] store tool calls and tool summary in cache entry Signed-off-by: Stephanie --- src/app/endpoints/query.py | 5 +- src/cache/postgres_cache.py | 83 ++++++++++++++++++++- src/cache/sqlite_cache.py | 82 +++++++++++++++++++- src/models/cache_entry.py | 5 ++ src/utils/endpoints.py | 2 + tests/unit/cache/test_postgres_cache.py | 99 ++++++++++++++++++++++++- tests/unit/cache/test_sqlite_cache.py | 93 +++++++++++++++++++++++ 7 files changed, 358 insertions(+), 11 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 952c0a0bd..fbb0d8d6b 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -363,7 +363,7 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 truncated=False, # TODO(lucasagomes): implement truncation as part of quota work attachments=query_request.attachments or [], ) - + logger.info("Persisting conversation details...") persist_user_conversation_details( user_id=user_id, @@ -374,7 +374,6 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 ) completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") - cache_entry = CacheEntry( query=query_request.query, response=summary.llm_response, @@ -383,6 +382,8 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 started_at=started_at, completed_at=completed_at, referenced_documents=referenced_documents if referenced_documents else None, + tool_calls=summary.tool_calls if summary.tool_calls else None, + tool_results=summary.tool_results if summary.tool_results else None, ) consume_tokens( diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index a83000221..2f98af237 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -9,6 +9,7 @@ from models.cache_entry import CacheEntry from models.config import PostgreSQLDatabaseConfiguration from models.responses import ConversationData, ReferencedDocument +from utils.types import ToolCallSummary, ToolResultSummary from log import get_logger from utils.connection_decorator import connection @@ -32,7 +33,9 @@ class PostgresCache(Cache): response | text | | provider | text | | model | text | | - referenced_documents | jsonb | | + referenced_documents | jsonb | | + tool_calls | jsonb | | + tool_results | jsonb | | Indexes: "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at) "timestamps" btree (created_at) @@ -55,6 +58,8 @@ class PostgresCache(Cache): provider text, model text, referenced_documents jsonb, + tool_calls jsonb, + tool_results jsonb, PRIMARY KEY(user_id, conversation_id, created_at) ); """ @@ -75,7 +80,8 @@ class PostgresCache(Cache): """ SELECT_CONVERSATION_HISTORY_STATEMENT = """ - SELECT query, response, provider, model, started_at, completed_at, referenced_documents + SELECT query, response, provider, model, started_at, completed_at, + referenced_documents, tool_calls, tool_results FROM cache WHERE user_id=%s AND conversation_id=%s ORDER BY created_at @@ -83,8 +89,9 @@ class PostgresCache(Cache): INSERT_CONVERSATION_HISTORY_STATEMENT = """ INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at, - query, response, provider, model, referenced_documents) - VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s) + query, response, provider, model, referenced_documents, + tool_calls, tool_results) + VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s, %s, %s) """ QUERY_CACHE_SIZE = """ @@ -260,6 +267,40 @@ def get( conversation_id, e, ) + + # Parse tool_calls back into ToolCallSummary objects + tool_calls_data = conversation_entry[7] + tool_calls_obj = None + if tool_calls_data: + try: + tool_calls_obj = [ + ToolCallSummary.model_validate(tc) for tc in tool_calls_data + ] + except (ValueError, TypeError) as e: + logger.warning( + "Failed to deserialize tool_calls for " + "conversation %s: %s", + conversation_id, + e, + ) + + # Parse tool_results back into ToolResultSummary objects + tool_results_data = conversation_entry[8] + tool_results_obj = None + if tool_results_data: + try: + tool_results_obj = [ + ToolResultSummary.model_validate(tr) + for tr in tool_results_data + ] + except (ValueError, TypeError) as e: + logger.warning( + "Failed to deserialize tool_results for " + "conversation %s: %s", + conversation_id, + e, + ) + cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -268,6 +309,8 @@ def get( started_at=conversation_entry[4], completed_at=conversation_entry[5], referenced_documents=docs_obj, + tool_calls=tool_calls_obj, + tool_results=tool_results_obj, ) result.append(cache_entry) @@ -311,6 +354,36 @@ def insert_or_append( e, ) + tool_calls_json = None + if cache_entry.tool_calls: + try: + tool_calls_as_dicts = [ + tc.model_dump(mode="json") for tc in cache_entry.tool_calls + ] + tool_calls_json = json.dumps(tool_calls_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize tool_calls for " + "conversation %s: %s", + conversation_id, + e, + ) + + tool_results_json = None + if cache_entry.tool_results: + try: + tool_results_as_dicts = [ + tr.model_dump(mode="json") for tr in cache_entry.tool_results + ] + tool_results_json = json.dumps(tool_results_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize tool_results for " + "conversation %s: %s", + conversation_id, + e, + ) + # the whole operation is run in one transaction with self.connection.cursor() as cursor: cursor.execute( @@ -325,6 +398,8 @@ def insert_or_append( cache_entry.provider, cache_entry.model, referenced_documents_json, + tool_calls_json, + tool_results_json, ), ) diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index bf70355bd..4d686d178 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -10,6 +10,7 @@ from models.cache_entry import CacheEntry from models.config import SQLiteDatabaseConfiguration from models.responses import ConversationData, ReferencedDocument +from utils.types import ToolCallSummary, ToolResultSummary from log import get_logger from utils.connection_decorator import connection @@ -34,6 +35,8 @@ class SQLiteCache(Cache): provider | text | | model | text | | referenced_documents | text | | + tool_calls | text | | + tool_results | text | | Indexes: "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at) "cache_key_key" UNIQUE CONSTRAINT, btree (key) @@ -54,6 +57,8 @@ class SQLiteCache(Cache): provider text, model text, referenced_documents text, + tool_calls text, + tool_results text, PRIMARY KEY(user_id, conversation_id, created_at) ); """ @@ -74,7 +79,8 @@ class SQLiteCache(Cache): """ SELECT_CONVERSATION_HISTORY_STATEMENT = """ - SELECT query, response, provider, model, started_at, completed_at, referenced_documents + SELECT query, response, provider, model, started_at, completed_at, + referenced_documents, tool_calls, tool_results FROM cache WHERE user_id=? AND conversation_id=? ORDER BY created_at @@ -82,8 +88,9 @@ class SQLiteCache(Cache): INSERT_CONVERSATION_HISTORY_STATEMENT = """ INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at, - query, response, provider, model, referenced_documents) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + query, response, provider, model, referenced_documents, + tool_calls, tool_results) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """ QUERY_CACHE_SIZE = """ @@ -228,6 +235,41 @@ def get( conversation_id, e, ) + + # Parse tool_calls back into ToolCallSummary objects + tool_calls_json_str = conversation_entry[7] + tool_calls_obj = None + if tool_calls_json_str: + try: + tool_calls_data = json.loads(tool_calls_json_str) + tool_calls_obj = [ + ToolCallSummary.model_validate(tc) for tc in tool_calls_data + ] + except (json.JSONDecodeError, ValueError) as e: + logger.warning( + "Failed to deserialize tool_calls for " + "conversation %s: %s", + conversation_id, + e, + ) + + # Parse tool_results back into ToolResultSummary objects + tool_results_json_str = conversation_entry[8] + tool_results_obj = None + if tool_results_json_str: + try: + tool_results_data = json.loads(tool_results_json_str) + tool_results_obj = [ + ToolResultSummary.model_validate(tr) for tr in tool_results_data + ] + except (json.JSONDecodeError, ValueError) as e: + logger.warning( + "Failed to deserialize tool_results for " + "conversation %s: %s", + conversation_id, + e, + ) + cache_entry = CacheEntry( query=conversation_entry[0], response=conversation_entry[1], @@ -236,6 +278,8 @@ def get( started_at=conversation_entry[4], completed_at=conversation_entry[5], referenced_documents=docs_obj, + tool_calls=tool_calls_obj, + tool_results=tool_results_obj, ) result.append(cache_entry) @@ -281,6 +325,36 @@ def insert_or_append( e, ) + tool_calls_json = None + if cache_entry.tool_calls: + try: + tool_calls_as_dicts = [ + tc.model_dump(mode="json") for tc in cache_entry.tool_calls + ] + tool_calls_json = json.dumps(tool_calls_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize tool_calls for " + "conversation %s: %s", + conversation_id, + e, + ) + + tool_results_json = None + if cache_entry.tool_results: + try: + tool_results_as_dicts = [ + tr.model_dump(mode="json") for tr in cache_entry.tool_results + ] + tool_results_json = json.dumps(tool_results_as_dicts) + except (TypeError, ValueError) as e: + logger.warning( + "Failed to serialize tool_results for " + "conversation %s: %s", + conversation_id, + e, + ) + cursor.execute( self.INSERT_CONVERSATION_HISTORY_STATEMENT, ( @@ -294,6 +368,8 @@ def insert_or_append( cache_entry.provider, cache_entry.model, referenced_documents_json, + tool_calls_json, + tool_results_json, ), ) diff --git a/src/models/cache_entry.py b/src/models/cache_entry.py index 637de38a9..e00069ce4 100644 --- a/src/models/cache_entry.py +++ b/src/models/cache_entry.py @@ -3,6 +3,7 @@ from typing import Optional from pydantic import BaseModel from models.responses import ReferencedDocument +from utils.types import ToolCallSummary, ToolResultSummary class CacheEntry(BaseModel): @@ -14,6 +15,8 @@ class CacheEntry(BaseModel): provider: Provider identification model: Model identification referenced_documents: List of documents referenced by the response + tool_calls: List of tool calls made during response generation + tool_results: List of tool results from tool calls """ query: str @@ -23,3 +26,5 @@ class CacheEntry(BaseModel): started_at: str completed_at: str referenced_documents: Optional[list[ReferencedDocument]] = None + tool_calls: Optional[list[ToolCallSummary]] = None + tool_results: Optional[list[ToolResultSummary]] = None diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 2fb51871c..0db9d5034 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -806,6 +806,8 @@ async def cleanup_after_streaming( started_at=started_at, completed_at=completed_at, referenced_documents=referenced_documents if referenced_documents else None, + tool_calls=summary.tool_calls if summary.tool_calls else None, + tool_results=summary.tool_results if summary.tool_results else None, ) store_conversation_into_cache( diff --git a/tests/unit/cache/test_postgres_cache.py b/tests/unit/cache/test_postgres_cache.py index 5503936c7..4c6ef1141 100644 --- a/tests/unit/cache/test_postgres_cache.py +++ b/tests/unit/cache/test_postgres_cache.py @@ -14,6 +14,7 @@ from models.cache_entry import CacheEntry from models.responses import ConversationData, ReferencedDocument from utils import suid +from utils.types import ToolCallSummary, ToolResultSummary from cache.cache_error import CacheError from cache.postgres_cache import PostgresCache @@ -597,7 +598,8 @@ def test_insert_and_get_with_referenced_documents( ] assert insert_calls, "INSERT call not found" sql_params = insert_calls[-1][0][1] - inserted_json_str = sql_params[-1] + # referenced_documents is now at index -3 (before tool_calls and tool_results) + inserted_json_str = sql_params[-3] assert json.loads(inserted_json_str) == [ {"doc_url": "http://example.com/", "doc_title": "Test Doc"} @@ -612,6 +614,8 @@ def test_insert_and_get_with_referenced_documents( "start_time", "end_time", [{"doc_url": "http://example.com/", "doc_title": "Test Doc"}], + None, # tool_calls + None, # tool_results ) mock_cursor.fetchall.return_value = [db_return_value] @@ -648,7 +652,10 @@ def test_insert_and_get_without_referenced_documents( ] assert insert_calls, "INSERT call not found" sql_params = insert_calls[-1][0][1] - assert sql_params[-1] is None + # Last 3 params are referenced_documents, tool_calls, tool_results - all should be None + assert sql_params[-3] is None # referenced_documents + assert sql_params[-2] is None # tool_calls + assert sql_params[-1] is None # tool_results # Simulate the database returning a row with None db_return_value = ( @@ -659,6 +666,8 @@ def test_insert_and_get_without_referenced_documents( entry_without_docs.started_at, entry_without_docs.completed_at, None, # referenced_documents is None in the DB + None, # tool_calls + None, # tool_results ) mock_cursor.fetchall.return_value = [db_return_value] @@ -669,6 +678,8 @@ def test_insert_and_get_without_referenced_documents( assert len(retrieved_entries) == 1 assert retrieved_entries[0] == entry_without_docs assert retrieved_entries[0].referenced_documents is None + assert retrieved_entries[0].tool_calls is None + assert retrieved_entries[0].tool_results is None def test_initialize_cache_with_custom_namespace( @@ -710,3 +721,87 @@ def test_connect_to_cache_with_too_long_namespace( # should fail due to invalid namespace containing spaces with pytest.raises(ValueError, match="Invalid namespace: too long namespace"): PostgresCache(postgres_cache_config_fixture_too_long_namespace) + + +def test_insert_and_get_with_tool_calls_and_results( + postgres_cache_config_fixture: PostgreSQLDatabaseConfiguration, + mocker: MockerFixture, +) -> None: + """Test that a CacheEntry with tool_calls and tool_results is stored and retrieved correctly.""" + # prevent real connection to PG instance + mock_connect = mocker.patch("psycopg2.connect") + cache = PostgresCache(postgres_cache_config_fixture) + + mock_connection = mock_connect.return_value + mock_cursor = mock_connection.cursor.return_value.__enter__.return_value + + # Create tool_calls and tool_results + tool_calls = [ + ToolCallSummary(id="call_1", name="test_tool", args={"param": "value"}, type="tool_call") + ] + tool_results = [ + ToolResultSummary( + id="call_1", status="success", content="result data", type="tool_result", round=1 + ) + ] + entry_with_tools = CacheEntry( + query="user message", + response="AI message", + provider="foo", + model="bar", + started_at="start_time", + completed_at="end_time", + tool_calls=tool_calls, + tool_results=tool_results, + ) + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_tools) + + # Find the INSERT INTO cache(...) call + insert_calls = [ + c + for c in mock_cursor.execute.call_args_list + if isinstance(c[0][0], str) and "INSERT INTO cache(" in c[0][0] + ] + assert insert_calls, "INSERT call not found" + sql_params = insert_calls[-1][0][1] + + # Verify tool_calls JSON + tool_calls_json = sql_params[-2] + assert json.loads(tool_calls_json) == [ + {"id": "call_1", "name": "test_tool", "args": {"param": "value"}, "type": "tool_call"} + ] + + # Verify tool_results JSON + tool_results_json = sql_params[-1] + assert json.loads(tool_results_json) == [ + {"id": "call_1", "status": "success", "content": "result data", "type": "tool_result", "round": 1} + ] + + # Simulate the database returning that data + db_return_value = ( + "user message", + "AI message", + "foo", + "bar", + "start_time", + "end_time", + None, # referenced_documents + [{"id": "call_1", "name": "test_tool", "args": {"param": "value"}, "type": "tool_call"}], + [{"id": "call_1", "status": "success", "content": "result data", "type": "tool_result", "round": 1}], + ) + mock_cursor.fetchall.return_value = [db_return_value] + + # Call the get method + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_with_tools + assert retrieved_entries[0].tool_calls is not None + assert len(retrieved_entries[0].tool_calls) == 1 + assert retrieved_entries[0].tool_calls[0].name == "test_tool" + assert retrieved_entries[0].tool_results is not None + assert len(retrieved_entries[0].tool_results) == 1 + assert retrieved_entries[0].tool_results[0].status == "success" diff --git a/tests/unit/cache/test_sqlite_cache.py b/tests/unit/cache/test_sqlite_cache.py index 9e8f5626f..96d420f25 100644 --- a/tests/unit/cache/test_sqlite_cache.py +++ b/tests/unit/cache/test_sqlite_cache.py @@ -13,6 +13,7 @@ from models.cache_entry import CacheEntry from models.responses import ConversationData, ReferencedDocument from utils import suid +from utils.types import ToolCallSummary, ToolResultSummary from cache.cache_error import CacheError from cache.sqlite_cache import SQLiteCache @@ -470,3 +471,95 @@ def test_insert_and_get_without_referenced_documents(tmpdir: Path) -> None: assert len(retrieved_entries) == 1 assert retrieved_entries[0] == entry_without_docs assert retrieved_entries[0].referenced_documents is None + assert retrieved_entries[0].tool_calls is None + assert retrieved_entries[0].tool_results is None + + +def test_insert_and_get_with_tool_calls_and_results(tmpdir: Path) -> None: + """ + Test that a CacheEntry with tool_calls and tool_results is correctly + serialized, stored, and retrieved. + """ + cache = create_cache(tmpdir) + + # Create tool_calls and tool_results + tool_calls = [ + ToolCallSummary(id="call_1", name="test_tool", args={"param": "value"}, type="tool_call") + ] + tool_results = [ + ToolResultSummary( + id="call_1", status="success", content="result data", type="tool_result", round=1 + ) + ] + entry_with_tools = CacheEntry( + query="user message", + response="AI message", + provider="foo", + model="bar", + started_at="start_time", + completed_at="end_time", + tool_calls=tool_calls, + tool_results=tool_results, + ) + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_tools) + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_with_tools + assert retrieved_entries[0].tool_calls is not None + assert len(retrieved_entries[0].tool_calls) == 1 + assert retrieved_entries[0].tool_calls[0].name == "test_tool" + assert retrieved_entries[0].tool_calls[0].args == {"param": "value"} + assert retrieved_entries[0].tool_results is not None + assert len(retrieved_entries[0].tool_results) == 1 + assert retrieved_entries[0].tool_results[0].status == "success" + assert retrieved_entries[0].tool_results[0].content == "result data" + + +def test_insert_and_get_with_all_fields(tmpdir: Path) -> None: + """ + Test that a CacheEntry with all fields (referenced_documents, tool_calls, + tool_results) is correctly serialized, stored, and retrieved. + """ + cache = create_cache(tmpdir) + + # Create all fields + docs = [ + ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com")) + ] + tool_calls = [ + ToolCallSummary(id="call_1", name="test_tool", args={"key": "value"}, type="tool_call") + ] + tool_results = [ + ToolResultSummary( + id="call_1", status="success", content="output", type="tool_result", round=1 + ) + ] + entry_with_all = CacheEntry( + query="user query", + response="AI response", + provider="provider", + model="model", + started_at="start", + completed_at="end", + referenced_documents=docs, + tool_calls=tool_calls, + tool_results=tool_results, + ) + + # Call the insert method + cache.insert_or_append(USER_ID_1, CONVERSATION_ID_1, entry_with_all) + retrieved_entries = cache.get(USER_ID_1, CONVERSATION_ID_1) + + # Assert that the retrieved entry matches the original + assert len(retrieved_entries) == 1 + assert retrieved_entries[0] == entry_with_all + assert retrieved_entries[0].referenced_documents is not None + assert retrieved_entries[0].referenced_documents[0].doc_title == "Test Doc" + assert retrieved_entries[0].tool_calls is not None + assert retrieved_entries[0].tool_calls[0].name == "test_tool" + assert retrieved_entries[0].tool_results is not None + assert retrieved_entries[0].tool_results[0].status == "success" From 7789ae1bb6bc233f1fadc31093e855b1ee8be59a Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 12 Jan 2026 17:03:43 -0500 Subject: [PATCH 2/4] fix format Signed-off-by: Stephanie --- src/app/endpoints/query.py | 2 +- src/cache/postgres_cache.py | 6 ++-- src/cache/sqlite_cache.py | 12 +++---- tests/unit/cache/test_postgres_cache.py | 44 +++++++++++++++++++++---- tests/unit/cache/test_sqlite_cache.py | 14 ++++++-- 5 files changed, 56 insertions(+), 22 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index fbb0d8d6b..5fb1278d9 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -363,7 +363,7 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 truncated=False, # TODO(lucasagomes): implement truncation as part of quota work attachments=query_request.attachments or [], ) - + logger.info("Persisting conversation details...") persist_user_conversation_details( user_id=user_id, diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index 2f98af237..c910a3413 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -363,8 +363,7 @@ def insert_or_append( tool_calls_json = json.dumps(tool_calls_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_calls for " - "conversation %s: %s", + "Failed to serialize tool_calls for " "conversation %s: %s", conversation_id, e, ) @@ -378,8 +377,7 @@ def insert_or_append( tool_results_json = json.dumps(tool_results_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_results for " - "conversation %s: %s", + "Failed to serialize tool_results for " "conversation %s: %s", conversation_id, e, ) diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index 4d686d178..7a172cc3f 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -247,8 +247,7 @@ def get( ] except (json.JSONDecodeError, ValueError) as e: logger.warning( - "Failed to deserialize tool_calls for " - "conversation %s: %s", + "Failed to deserialize tool_calls for " "conversation %s: %s", conversation_id, e, ) @@ -264,8 +263,7 @@ def get( ] except (json.JSONDecodeError, ValueError) as e: logger.warning( - "Failed to deserialize tool_results for " - "conversation %s: %s", + "Failed to deserialize tool_results for " "conversation %s: %s", conversation_id, e, ) @@ -334,8 +332,7 @@ def insert_or_append( tool_calls_json = json.dumps(tool_calls_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_calls for " - "conversation %s: %s", + "Failed to serialize tool_calls for " "conversation %s: %s", conversation_id, e, ) @@ -349,8 +346,7 @@ def insert_or_append( tool_results_json = json.dumps(tool_results_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_results for " - "conversation %s: %s", + "Failed to serialize tool_results for " "conversation %s: %s", conversation_id, e, ) diff --git a/tests/unit/cache/test_postgres_cache.py b/tests/unit/cache/test_postgres_cache.py index 4c6ef1141..b720de614 100644 --- a/tests/unit/cache/test_postgres_cache.py +++ b/tests/unit/cache/test_postgres_cache.py @@ -737,11 +737,17 @@ def test_insert_and_get_with_tool_calls_and_results( # Create tool_calls and tool_results tool_calls = [ - ToolCallSummary(id="call_1", name="test_tool", args={"param": "value"}, type="tool_call") + ToolCallSummary( + id="call_1", name="test_tool", args={"param": "value"}, type="tool_call" + ) ] tool_results = [ ToolResultSummary( - id="call_1", status="success", content="result data", type="tool_result", round=1 + id="call_1", + status="success", + content="result data", + type="tool_result", + round=1, ) ] entry_with_tools = CacheEntry( @@ -770,13 +776,24 @@ def test_insert_and_get_with_tool_calls_and_results( # Verify tool_calls JSON tool_calls_json = sql_params[-2] assert json.loads(tool_calls_json) == [ - {"id": "call_1", "name": "test_tool", "args": {"param": "value"}, "type": "tool_call"} + { + "id": "call_1", + "name": "test_tool", + "args": {"param": "value"}, + "type": "tool_call", + } ] # Verify tool_results JSON tool_results_json = sql_params[-1] assert json.loads(tool_results_json) == [ - {"id": "call_1", "status": "success", "content": "result data", "type": "tool_result", "round": 1} + { + "id": "call_1", + "status": "success", + "content": "result data", + "type": "tool_result", + "round": 1, + } ] # Simulate the database returning that data @@ -788,8 +805,23 @@ def test_insert_and_get_with_tool_calls_and_results( "start_time", "end_time", None, # referenced_documents - [{"id": "call_1", "name": "test_tool", "args": {"param": "value"}, "type": "tool_call"}], - [{"id": "call_1", "status": "success", "content": "result data", "type": "tool_result", "round": 1}], + [ + { + "id": "call_1", + "name": "test_tool", + "args": {"param": "value"}, + "type": "tool_call", + } + ], + [ + { + "id": "call_1", + "status": "success", + "content": "result data", + "type": "tool_result", + "round": 1, + } + ], ) mock_cursor.fetchall.return_value = [db_return_value] diff --git a/tests/unit/cache/test_sqlite_cache.py b/tests/unit/cache/test_sqlite_cache.py index 96d420f25..a62195db7 100644 --- a/tests/unit/cache/test_sqlite_cache.py +++ b/tests/unit/cache/test_sqlite_cache.py @@ -484,11 +484,17 @@ def test_insert_and_get_with_tool_calls_and_results(tmpdir: Path) -> None: # Create tool_calls and tool_results tool_calls = [ - ToolCallSummary(id="call_1", name="test_tool", args={"param": "value"}, type="tool_call") + ToolCallSummary( + id="call_1", name="test_tool", args={"param": "value"}, type="tool_call" + ) ] tool_results = [ ToolResultSummary( - id="call_1", status="success", content="result data", type="tool_result", round=1 + id="call_1", + status="success", + content="result data", + type="tool_result", + round=1, ) ] entry_with_tools = CacheEntry( @@ -531,7 +537,9 @@ def test_insert_and_get_with_all_fields(tmpdir: Path) -> None: ReferencedDocument(doc_title="Test Doc", doc_url=AnyUrl("http://example.com")) ] tool_calls = [ - ToolCallSummary(id="call_1", name="test_tool", args={"key": "value"}, type="tool_call") + ToolCallSummary( + id="call_1", name="test_tool", args={"key": "value"}, type="tool_call" + ) ] tool_results = [ ToolResultSummary( From ba3e2a65593a152f88a089c5d433b9132451c974 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 12 Jan 2026 17:14:49 -0500 Subject: [PATCH 3/4] fix pylint Signed-off-by: Stephanie --- src/cache/postgres_cache.py | 10 ++++++---- src/cache/sqlite_cache.py | 16 ++++++++++------ 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index c910a3413..61aace081 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -9,9 +9,9 @@ from models.cache_entry import CacheEntry from models.config import PostgreSQLDatabaseConfiguration from models.responses import ConversationData, ReferencedDocument +from utils.connection_decorator import connection from utils.types import ToolCallSummary, ToolResultSummary from log import get_logger -from utils.connection_decorator import connection logger = get_logger("cache.postgres_cache") @@ -227,7 +227,7 @@ def initialize_cache(self, namespace: str) -> None: self.connection.commit() @connection - def get( + def get( # pylint: disable=R0914 self, user_id: str, conversation_id: str, skip_user_id_check: bool = False ) -> list[CacheEntry]: """Get the value associated with the given key. @@ -363,7 +363,8 @@ def insert_or_append( tool_calls_json = json.dumps(tool_calls_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_calls for " "conversation %s: %s", + "Failed to serialize tool_calls for " + "conversation %s: %s", conversation_id, e, ) @@ -377,7 +378,8 @@ def insert_or_append( tool_results_json = json.dumps(tool_results_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_results for " "conversation %s: %s", + "Failed to serialize tool_results for " + "conversation %s: %s", conversation_id, e, ) diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index 7a172cc3f..78a9e1ce7 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -10,9 +10,9 @@ from models.cache_entry import CacheEntry from models.config import SQLiteDatabaseConfiguration from models.responses import ConversationData, ReferencedDocument +from utils.connection_decorator import connection from utils.types import ToolCallSummary, ToolResultSummary from log import get_logger -from utils.connection_decorator import connection logger = get_logger("cache.sqlite_cache") @@ -194,7 +194,7 @@ def initialize_cache(self) -> None: self.connection.commit() @connection - def get( + def get( # pylint: disable=R0914 self, user_id: str, conversation_id: str, skip_user_id_check: bool = False ) -> list[CacheEntry]: """Get the value associated with the given key. @@ -247,7 +247,8 @@ def get( ] except (json.JSONDecodeError, ValueError) as e: logger.warning( - "Failed to deserialize tool_calls for " "conversation %s: %s", + "Failed to deserialize tool_calls for " + "conversation %s: %s", conversation_id, e, ) @@ -263,7 +264,8 @@ def get( ] except (json.JSONDecodeError, ValueError) as e: logger.warning( - "Failed to deserialize tool_results for " "conversation %s: %s", + "Failed to deserialize tool_results for " + "conversation %s: %s", conversation_id, e, ) @@ -332,7 +334,8 @@ def insert_or_append( tool_calls_json = json.dumps(tool_calls_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_calls for " "conversation %s: %s", + "Failed to serialize tool_calls for " + "conversation %s: %s", conversation_id, e, ) @@ -346,7 +349,8 @@ def insert_or_append( tool_results_json = json.dumps(tool_results_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_results for " "conversation %s: %s", + "Failed to serialize tool_results for " + "conversation %s: %s", conversation_id, e, ) From 3899c70f5ed5a833da16eb722de15f067ccfccc8 Mon Sep 17 00:00:00 2001 From: Stephanie Date: Mon, 12 Jan 2026 17:20:49 -0500 Subject: [PATCH 4/4] fix test failure Signed-off-by: Stephanie --- src/cache/postgres_cache.py | 6 ++---- src/cache/sqlite_cache.py | 12 ++++-------- tests/unit/utils/test_endpoints.py | 8 ++++++-- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/cache/postgres_cache.py b/src/cache/postgres_cache.py index 61aace081..5e72cf3a1 100644 --- a/src/cache/postgres_cache.py +++ b/src/cache/postgres_cache.py @@ -363,8 +363,7 @@ def insert_or_append( tool_calls_json = json.dumps(tool_calls_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_calls for " - "conversation %s: %s", + "Failed to serialize tool_calls for conversation %s: %s", conversation_id, e, ) @@ -378,8 +377,7 @@ def insert_or_append( tool_results_json = json.dumps(tool_results_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_results for " - "conversation %s: %s", + "Failed to serialize tool_results for conversation %s: %s", conversation_id, e, ) diff --git a/src/cache/sqlite_cache.py b/src/cache/sqlite_cache.py index 78a9e1ce7..ba48cb77f 100644 --- a/src/cache/sqlite_cache.py +++ b/src/cache/sqlite_cache.py @@ -247,8 +247,7 @@ def get( # pylint: disable=R0914 ] except (json.JSONDecodeError, ValueError) as e: logger.warning( - "Failed to deserialize tool_calls for " - "conversation %s: %s", + "Failed to deserialize tool_calls for conversation %s: %s", conversation_id, e, ) @@ -264,8 +263,7 @@ def get( # pylint: disable=R0914 ] except (json.JSONDecodeError, ValueError) as e: logger.warning( - "Failed to deserialize tool_results for " - "conversation %s: %s", + "Failed to deserialize tool_results for conversation %s: %s", conversation_id, e, ) @@ -334,8 +332,7 @@ def insert_or_append( tool_calls_json = json.dumps(tool_calls_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_calls for " - "conversation %s: %s", + "Failed to serialize tool_calls for conversation %s: %s", conversation_id, e, ) @@ -349,8 +346,7 @@ def insert_or_append( tool_results_json = json.dumps(tool_results_as_dicts) except (TypeError, ValueError) as e: logger.warning( - "Failed to serialize tool_results for " - "conversation %s: %s", + "Failed to serialize tool_results for conversation %s: %s", conversation_id, e, ) diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py index c0641685d..b13436744 100644 --- a/tests/unit/utils/test_endpoints.py +++ b/tests/unit/utils/test_endpoints.py @@ -1069,7 +1069,9 @@ async def test_cleanup_after_streaming_generate_topic_summary_default_true( provider_id="test_provider", llama_stack_model_id="test_llama_model", query_request=query_request, - summary=mocker.Mock(llm_response="test response", tool_calls=[]), + summary=mocker.Mock( + llm_response="test response", tool_calls=[], tool_results=[] + ), metadata_map={}, started_at="2024-01-01T00:00:00Z", client=mock_client, @@ -1121,7 +1123,9 @@ async def test_cleanup_after_streaming_generate_topic_summary_explicit_false( provider_id="test_provider", llama_stack_model_id="test_llama_model", query_request=query_request, - summary=mocker.Mock(llm_response="test response", tool_calls=[]), + summary=mocker.Mock( + llm_response="test response", tool_calls=[], tool_results=[] + ), metadata_map={}, started_at="2024-01-01T00:00:00Z", client=mock_client,