Merged
3 changes: 2 additions & 1 deletion src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
)

completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")

cache_entry = CacheEntry(
query=query_request.query,
response=summary.llm_response,
Expand All @@ -383,6 +382,8 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
started_at=started_at,
completed_at=completed_at,
referenced_documents=referenced_documents if referenced_documents else None,
tool_calls=summary.tool_calls if summary.tool_calls else None,
tool_results=summary.tool_results if summary.tool_results else None,
)

consume_tokens(
Expand Down
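Both write paths (this handler and the streaming cleanup in src/utils/endpoints.py below) guard the new fields the same way: empty tool activity collapses to None, so the cache backends skip serialization entirely and the new columns stay NULL instead of holding "[]". A minimal sketch of the idiom (the helper name is illustrative, not part of the PR):

    def or_none(items):
        # An empty list is falsy, so `items if items else None`
        # maps both [] and None to None.
        return items if items else None

    assert or_none([]) is None
    assert or_none(None) is None
    assert or_none(["tool_call"]) == ["tool_call"]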
src/cache/postgres_cache.py (79 additions, 6 deletions)

@@ -9,8 +9,9 @@
 from models.cache_entry import CacheEntry
 from models.config import PostgreSQLDatabaseConfiguration
 from models.responses import ConversationData, ReferencedDocument
-from log import get_logger
 from utils.connection_decorator import connection
+from utils.types import ToolCallSummary, ToolResultSummary
+from log import get_logger
 
 logger = get_logger("cache.postgres_cache")
 
@@ -32,7 +33,9 @@ class PostgresCache(Cache):
         response | text | |
         provider | text | |
         model | text | |
-        referenced_documents | jsonb | |
+        referenced_documents | jsonb | |
+        tool_calls | jsonb | |
+        tool_results | jsonb | |
     Indexes:
         "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
         "timestamps" btree (created_at)
@@ -55,6 +58,8 @@ class PostgresCache(Cache):
             provider text,
             model text,
             referenced_documents jsonb,
+            tool_calls jsonb,
+            tool_results jsonb,
             PRIMARY KEY(user_id, conversation_id, created_at)
         );
     """
@@ -75,16 +80,18 @@ class PostgresCache(Cache):
     """
 
     SELECT_CONVERSATION_HISTORY_STATEMENT = """
-        SELECT query, response, provider, model, started_at, completed_at, referenced_documents
+        SELECT query, response, provider, model, started_at, completed_at,
+            referenced_documents, tool_calls, tool_results
         FROM cache
         WHERE user_id=%s AND conversation_id=%s
         ORDER BY created_at
     """
 
     INSERT_CONVERSATION_HISTORY_STATEMENT = """
         INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
-            query, response, provider, model, referenced_documents)
-        VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s)
+            query, response, provider, model, referenced_documents,
+            tool_calls, tool_results)
+        VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s, %s, %s)
     """
 
     QUERY_CACHE_SIZE = """
@@ -220,7 +227,7 @@ def initialize_cache(self, namespace: str) -> None:
         self.connection.commit()
 
     @connection
-    def get(
+    def get(  # pylint: disable=R0914
         self, user_id: str, conversation_id: str, skip_user_id_check: bool = False
     ) -> list[CacheEntry]:
         """Get the value associated with the given key.
@@ -260,6 +267,40 @@ def get(
                             conversation_id,
                             e,
                         )
+
+                # Parse tool_calls back into ToolCallSummary objects
+                tool_calls_data = conversation_entry[7]
+                tool_calls_obj = None
+                if tool_calls_data:
+                    try:
+                        tool_calls_obj = [
+                            ToolCallSummary.model_validate(tc) for tc in tool_calls_data
+                        ]
+                    except (ValueError, TypeError) as e:
+                        logger.warning(
+                            "Failed to deserialize tool_calls for "
+                            "conversation %s: %s",
+                            conversation_id,
+                            e,
+                        )
+
+                # Parse tool_results back into ToolResultSummary objects
+                tool_results_data = conversation_entry[8]
+                tool_results_obj = None
+                if tool_results_data:
+                    try:
+                        tool_results_obj = [
+                            ToolResultSummary.model_validate(tr)
+                            for tr in tool_results_data
+                        ]
+                    except (ValueError, TypeError) as e:
+                        logger.warning(
+                            "Failed to deserialize tool_results for "
+                            "conversation %s: %s",
+                            conversation_id,
+                            e,
+                        )
+
                 cache_entry = CacheEntry(
                     query=conversation_entry[0],
                     response=conversation_entry[1],
@@ -268,6 +309,8 @@ def get(
                     started_at=conversation_entry[4],
                     completed_at=conversation_entry[5],
                     referenced_documents=docs_obj,
+                    tool_calls=tool_calls_obj,
+                    tool_results=tool_results_obj,
                 )
                 result.append(cache_entry)
 
@@ -311,6 +354,34 @@ def insert_or_append(
                     conversation_id,
                     e,
                )
+        tool_calls_json = None
+        if cache_entry.tool_calls:
+            try:
+                tool_calls_as_dicts = [
+                    tc.model_dump(mode="json") for tc in cache_entry.tool_calls
+                ]
+                tool_calls_json = json.dumps(tool_calls_as_dicts)
+            except (TypeError, ValueError) as e:
+                logger.warning(
+                    "Failed to serialize tool_calls for conversation %s: %s",
+                    conversation_id,
+                    e,
+                )
+
+        tool_results_json = None
+        if cache_entry.tool_results:
+            try:
+                tool_results_as_dicts = [
+                    tr.model_dump(mode="json") for tr in cache_entry.tool_results
+                ]
+                tool_results_json = json.dumps(tool_results_as_dicts)
+            except (TypeError, ValueError) as e:
+                logger.warning(
+                    "Failed to serialize tool_results for conversation %s: %s",
+                    conversation_id,
+                    e,
+                )
+
         # the whole operation is run in one transaction
         with self.connection.cursor() as cursor:
             cursor.execute(
@@ -325,6 +396,8 @@ def insert_or_append(
                     cache_entry.provider,
                     cache_entry.model,
                     referenced_documents_json,
+                    tool_calls_json,
+                    tool_results_json,
                 ),
             )
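The write path serializes the Pydantic models to plain dicts with model_dump(mode="json") and binds the JSON text to a jsonb column; the read path calls ToolCallSummary.model_validate directly on conversation_entry[7] with no json.loads, because the Postgres driver decodes jsonb values back to Python lists and dicts on fetch. A self-contained sketch of that round-trip, using a stand-in model since ToolCallSummary's real fields live in utils.types and are not shown in this diff:

    import json
    from pydantic import BaseModel

    class ToolCallStandIn(BaseModel):
        # Illustrative fields only; not the real ToolCallSummary schema.
        name: str
        arguments: dict

    calls = [ToolCallStandIn(name="search", arguments={"q": "docs"})]

    # Write path: models -> plain dicts -> JSON text bound to the jsonb column.
    payload = json.dumps([c.model_dump(mode="json") for c in calls])

    # Read path: the driver hands jsonb back as Python objects; json.loads
    # here stands in for that decoding step.
    restored = [ToolCallStandIn.model_validate(d) for d in json.loads(payload)]
    assert restored == calls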
src/cache/sqlite_cache.py (77 additions, 5 deletions)

@@ -10,8 +10,9 @@
 from models.cache_entry import CacheEntry
 from models.config import SQLiteDatabaseConfiguration
 from models.responses import ConversationData, ReferencedDocument
-from log import get_logger
 from utils.connection_decorator import connection
+from utils.types import ToolCallSummary, ToolResultSummary
+from log import get_logger
 
 logger = get_logger("cache.sqlite_cache")
 
@@ -34,6 +35,8 @@ class SQLiteCache(Cache):
         provider | text | |
         model | text | |
         referenced_documents | text | |
+        tool_calls | text | |
+        tool_results | text | |
     Indexes:
         "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
         "cache_key_key" UNIQUE CONSTRAINT, btree (key)
@@ -54,6 +57,8 @@ class SQLiteCache(Cache):
             provider text,
             model text,
             referenced_documents text,
+            tool_calls text,
+            tool_results text,
             PRIMARY KEY(user_id, conversation_id, created_at)
         );
     """
@@ -74,16 +79,18 @@ class SQLiteCache(Cache):
     """
 
     SELECT_CONVERSATION_HISTORY_STATEMENT = """
-        SELECT query, response, provider, model, started_at, completed_at, referenced_documents
+        SELECT query, response, provider, model, started_at, completed_at,
+            referenced_documents, tool_calls, tool_results
        FROM cache
         WHERE user_id=? AND conversation_id=?
         ORDER BY created_at
     """
 
     INSERT_CONVERSATION_HISTORY_STATEMENT = """
         INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
-            query, response, provider, model, referenced_documents)
-        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+            query, response, provider, model, referenced_documents,
+            tool_calls, tool_results)
+        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
     """
 
     QUERY_CACHE_SIZE = """
@@ -187,7 +194,7 @@ def initialize_cache(self) -> None:
         self.connection.commit()
 
     @connection
-    def get(
+    def get(  # pylint: disable=R0914
         self, user_id: str, conversation_id: str, skip_user_id_check: bool = False
     ) -> list[CacheEntry]:
         """Get the value associated with the given key.
@@ -228,6 +235,39 @@ def get(
                             conversation_id,
                             e,
                         )
+
+                # Parse tool_calls back into ToolCallSummary objects
+                tool_calls_json_str = conversation_entry[7]
+                tool_calls_obj = None
+                if tool_calls_json_str:
+                    try:
+                        tool_calls_data = json.loads(tool_calls_json_str)
+                        tool_calls_obj = [
+                            ToolCallSummary.model_validate(tc) for tc in tool_calls_data
+                        ]
+                    except (json.JSONDecodeError, ValueError) as e:
+                        logger.warning(
+                            "Failed to deserialize tool_calls for conversation %s: %s",
+                            conversation_id,
+                            e,
+                        )
+
+                # Parse tool_results back into ToolResultSummary objects
+                tool_results_json_str = conversation_entry[8]
+                tool_results_obj = None
+                if tool_results_json_str:
+                    try:
+                        tool_results_data = json.loads(tool_results_json_str)
+                        tool_results_obj = [
+                            ToolResultSummary.model_validate(tr) for tr in tool_results_data
+                        ]
+                    except (json.JSONDecodeError, ValueError) as e:
+                        logger.warning(
+                            "Failed to deserialize tool_results for conversation %s: %s",
+                            conversation_id,
+                            e,
+                        )
+
                 cache_entry = CacheEntry(
                     query=conversation_entry[0],
                     response=conversation_entry[1],
@@ -236,6 +276,8 @@ def get(
                     started_at=conversation_entry[4],
                     completed_at=conversation_entry[5],
                     referenced_documents=docs_obj,
+                    tool_calls=tool_calls_obj,
+                    tool_results=tool_results_obj,
                 )
                 result.append(cache_entry)
 
@@ -281,6 +323,34 @@ def insert_or_append(
                     e,
                 )
 
+        tool_calls_json = None
+        if cache_entry.tool_calls:
+            try:
+                tool_calls_as_dicts = [
+                    tc.model_dump(mode="json") for tc in cache_entry.tool_calls
+                ]
+                tool_calls_json = json.dumps(tool_calls_as_dicts)
+            except (TypeError, ValueError) as e:
+                logger.warning(
+                    "Failed to serialize tool_calls for conversation %s: %s",
+                    conversation_id,
+                    e,
+                )
+
+        tool_results_json = None
+        if cache_entry.tool_results:
+            try:
+                tool_results_as_dicts = [
+                    tr.model_dump(mode="json") for tr in cache_entry.tool_results
+                ]
+                tool_results_json = json.dumps(tool_results_as_dicts)
+            except (TypeError, ValueError) as e:
+                logger.warning(
+                    "Failed to serialize tool_results for conversation %s: %s",
+                    conversation_id,
+                    e,
+                )
+
         cursor.execute(
             self.INSERT_CONVERSATION_HISTORY_STATEMENT,
             (
@@ -294,6 +364,8 @@ def insert_or_append(
                cache_entry.provider,
                cache_entry.model,
                referenced_documents_json,
+               tool_calls_json,
+               tool_results_json,
            ),
        )
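The SQLite variant differs from the Postgres one in exactly one spot: the columns are plain text, so get() must json.loads the stored string itself before validating (and accordingly catches json.JSONDecodeError rather than TypeError). A quick end-to-end check of that text round-trip against an in-memory database, with the schema trimmed to the one relevant column:

    import json
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE cache (tool_calls text)")
    conn.execute(
        "INSERT INTO cache VALUES (?)",
        (json.dumps([{"name": "search"}]),),
    )
    row = conn.execute("SELECT tool_calls FROM cache").fetchone()
    # A text column hands back a string, hence the explicit json.loads
    # that the jsonb-backed Postgres path does not need.
    restored = json.loads(row[0]) if row[0] else None
    print(restored)  # [{'name': 'search'}]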
src/models/cache_entry.py (5 additions, 0 deletions)

@@ -3,6 +3,7 @@
 from typing import Optional
 from pydantic import BaseModel
 from models.responses import ReferencedDocument
+from utils.types import ToolCallSummary, ToolResultSummary
 
 
 class CacheEntry(BaseModel):
@@ -14,6 +15,8 @@ class CacheEntry(BaseModel):
         provider: Provider identification
         model: Model identification
         referenced_documents: List of documents referenced by the response
+        tool_calls: List of tool calls made during response generation
+        tool_results: List of tool results from tool calls
     """
 
     query: str
@@ -23,3 +26,5 @@ class CacheEntry(BaseModel):
     started_at: str
     completed_at: str
     referenced_documents: Optional[list[ReferencedDocument]] = None
+    tool_calls: Optional[list[ToolCallSummary]] = None
+    tool_results: Optional[list[ToolResultSummary]] = None
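Because the new fields are Optional with a None default, rows written before this change still validate; an entry built without them simply reports None. A sketch (field values are placeholders, and the required response/provider/model fields elided from this diff are assumed to be plain strings):

    entry = CacheEntry(
        query="What is Lightspeed?",
        response="An answer.",
        provider="some-provider",
        model="some-model",
        started_at="2024-01-01T00:00:00Z",
        completed_at="2024-01-01T00:00:05Z",
    )
    assert entry.tool_calls is None
    assert entry.tool_results is None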
src/utils/endpoints.py (2 additions, 0 deletions)

@@ -806,6 +806,8 @@ async def cleanup_after_streaming(
         started_at=started_at,
         completed_at=completed_at,
         referenced_documents=referenced_documents if referenced_documents else None,
+        tool_calls=summary.tool_calls if summary.tool_calls else None,
+        tool_results=summary.tool_results if summary.tool_results else None,
    )
 
     store_conversation_into_cache(
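With this change the streaming path stores the same tool activity as the non-streaming query handler. One observation for a possible follow-up: the guarded model_dump/json.dumps block now appears four times across the two cache backends; a shared helper along these lines (hypothetical, not part of the PR) would collapse the duplication:

    import json
    import logging

    logger = logging.getLogger(__name__)

    def dump_summaries(items, label, conversation_id):
        # Serialize a list of Pydantic models to a JSON string, or return
        # None when the list is empty or serialization fails.
        if not items:
            return None
        try:
            return json.dumps([i.model_dump(mode="json") for i in items])
        except (TypeError, ValueError) as e:
            logger.warning(
                "Failed to serialize %s for conversation %s: %s",
                label, conversation_id, e,
            )
            return None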