diff --git a/docs/openapi.json b/docs/openapi.json index 4d2c15a6a..7eb67fffd 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -671,6 +671,11 @@ } ], "title": "System Prompt" + }, + "default_estimation_tokenizer": { + "type": "string", + "title": "Default Estimation Tokenizer", + "description": "The default tokenizer to use for estimating token usage when the model is not supported by tiktoken." } }, "type": "object", @@ -1132,6 +1137,28 @@ "response": { "type": "string", "title": "Response" + }, + "input_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Input Tokens" + }, + "output_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Output Tokens" } }, "type": "object", @@ -1139,7 +1166,7 @@ "response" ], "title": "QueryResponse", - "description": "Model representing LLM response to a query.\n\nAttributes:\n conversation_id: The optional conversation ID (UUID).\n response: The response.", + "description": "Model representing LLM response to a query.\n\nAttributes:\n conversation_id: The optional conversation ID (UUID).\n response: The response.\n input_tokens: Number of tokens sent to LLM.\n output_tokens: Number of tokens received from LLM.", "examples": [ { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", diff --git a/pyproject.toml b/pyproject.toml index 3bfd84d11..d5ff4e0b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "llama-stack>=0.2.13", "rich>=14.0.0", "cachetools>=6.1.0", + "tiktoken>=0.9.0,<1.0.0", ] [tool.pyright] diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 3aa833f4e..46a0b5334 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -32,6 +32,7 @@ from utils.endpoints import check_configuration_loaded, get_system_prompt from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.suid import get_suid +from utils.token_counter import get_token_counter from utils.types import GraniteToolParser logger = logging.getLogger("app.endpoints.handlers") @@ -121,7 +122,7 @@ def query_endpoint_handler( # try to get Llama Stack client client = LlamaStackClientHolder().get_client() model_id = select_model_id(client.models.list(), query_request) - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( client, model_id, query_request, @@ -144,7 +145,12 @@ def query_endpoint_handler( attachments=query_request.attachments or [], ) - return QueryResponse(conversation_id=conversation_id, response=response) + return QueryResponse( + conversation_id=conversation_id, + response=response, + input_tokens=token_usage["input_tokens"], + output_tokens=token_usage["output_tokens"], + ) # connection to Llama Stack server except APIConnectionError as e: @@ -202,13 +208,21 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s return model_id +def _build_toolgroups(client: LlamaStackClient) -> list[Toolgroup] | None: + """Build toolgroups from vector DBs and MCP servers.""" + vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] + return (get_rag_toolgroups(vector_db_ids) or []) + [ + mcp_server.name for mcp_server in configuration.mcp_servers + ] + + def retrieve_response( client: LlamaStackClient, model_id: str, query_request: QueryRequest, token: str, mcp_headers: dict[str, dict[str, str]] | None = None, -) -> tuple[str, str]: +) -> tuple[str, str, dict[str, int]]: 
"""Retrieve response from LLMs and agents.""" available_shields = [shield.identifier for shield in client.shields.list()] if not available_shields: @@ -251,19 +265,37 @@ def retrieve_response( ), } - vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - mcp_server.name for mcp_server in configuration.mcp_servers - ] response = agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], session_id=conversation_id, documents=query_request.get_documents(), stream=False, - toolgroups=toolgroups or None, + toolgroups=_build_toolgroups(client) or None, ) - return str(response.output_message.content), conversation_id # type: ignore[union-attr] + response_content = str(response.output_message.content) # type: ignore[union-attr] + + # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it + # try: + # token_usage = { + # "input_tokens": response.usage.get("prompt_tokens", 0), + # "output_tokens": response.usage.get("completion_tokens", 0), + # } + # except AttributeError: + # Estimate token usage + try: + token_counter = get_token_counter(model_id) + token_usage = token_counter.count_conversation_turn_tokens( + conversation_id, system_prompt, query_request, response_content + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Failed to estimate token usage: %s", e) + token_usage = { + "input_tokens": 0, + "output_tokens": 0, + } + + return response_content, conversation_id, token_usage def validate_attachments_metadata(attachments: list[Attachment]) -> None: diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 2e2092e17..1f49cce07 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -9,10 +9,12 @@ from llama_stack_client import APIConnectionError from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore +from llama_stack_client.types.agents.turn_create_params import Toolgroup from llama_stack_client import AsyncLlamaStackClient # type: ignore from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from llama_stack_client.types import UserMessage # type: ignore + from fastapi import APIRouter, HTTPException, Request, Depends, status from fastapi.responses import StreamingResponse @@ -24,6 +26,7 @@ from utils.common import retrieve_user_id from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.suid import get_suid +from utils.token_counter import get_token_counter from utils.types import GraniteToolParser from app.endpoints.conversations import conversation_id_to_agent_id @@ -97,8 +100,13 @@ def stream_start_event(conversation_id: str) -> str: ) -def stream_end_event(metadata_map: dict) -> str: - """Yield the end of the data stream.""" +def stream_end_event(metadata_map: dict, metrics_map: dict[str, int]) -> str: + """Yield the end of the data stream. 
+ + Args: + metadata_map: Dictionary containing metadata about referenced documents + metrics_map: Dictionary containing metrics like 'input_tokens' and 'output_tokens' + """ return format_stream_data( { "event": "end", @@ -114,8 +122,8 @@ def stream_end_event(metadata_map: dict) -> str: ) ], "truncated": None, # TODO(jboos): implement truncated - "input_tokens": 0, # TODO(jboos): implement input tokens - "output_tokens": 0, # TODO(jboos): implement output tokens + "input_tokens": metrics_map.get("input_tokens", 0), + "output_tokens": metrics_map.get("output_tokens", 0), }, "available_quotas": {}, # TODO(jboos): implement available quotas } @@ -204,7 +212,7 @@ async def streaming_query_endpoint_handler( # try to get Llama Stack client client = AsyncLlamaStackClientHolder().get_client() model_id = select_model_id(await client.models.list(), query_request) - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( client, model_id, query_request, @@ -229,7 +237,24 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: chunk_id += 1 yield event - yield stream_end_event(metadata_map) + # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate + # try: + # output_tokens = response.usage.get("completion_tokens", 0) + # except AttributeError: + # Estimate output tokens from complete response + try: + token_counter = get_token_counter(model_id) + output_tokens = token_counter.count_tokens(complete_response) + logger.debug("Estimated output tokens: %s", output_tokens) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Failed to estimate output tokens: %s", e) + output_tokens = 0 + + metrics_map = { + "input_tokens": token_usage["input_tokens"], + "output_tokens": output_tokens, + } + yield stream_end_event(metadata_map, metrics_map) if not is_transcripts_enabled(): logger.debug("Transcript collection is disabled in the configuration") @@ -260,13 +285,23 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: ) from e +async def _build_toolgroups(client: AsyncLlamaStackClient) -> list[Toolgroup] | None: + """Build toolgroups from vector DBs and MCP servers.""" + vector_db_ids = [ + vector_db.identifier for vector_db in await client.vector_dbs.list() + ] + return (get_rag_toolgroups(vector_db_ids) or []) + [ + mcp_server.name for mcp_server in configuration.mcp_servers + ] + + async def retrieve_response( client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest, token: str, mcp_headers: dict[str, dict[str, str]] | None = None, -) -> tuple[Any, str]: +) -> tuple[Any, str, dict[str, int]]: """Retrieve response from LLMs and agents.""" available_shields = [shield.identifier for shield in await client.shields.list()] if not available_shields: @@ -312,18 +347,33 @@ async def retrieve_response( } logger.debug("Session ID: %s", conversation_id) - vector_db_ids = [ - vector_db.identifier for vector_db in await client.vector_dbs.list() - ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - mcp_server.name for mcp_server in configuration.mcp_servers - ] + response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], session_id=conversation_id, documents=query_request.get_documents(), stream=True, - toolgroups=toolgroups or None, + toolgroups=await _build_toolgroups(client) or None, ) - return response, conversation_id + # Currently (2025-07-08) the usage is not returned by the API, 
so we need to estimate it + # try: + # token_usage = { + # "input_tokens": response.usage.get("prompt_tokens", 0), + # "output_tokens": 0, # Will be calculated during streaming + # } + # except AttributeError: + # # Estimate input tokens (Output will be calculated during streaming) + try: + token_counter = get_token_counter(model_id) + token_usage = token_counter.count_conversation_turn_tokens( + conversation_id, system_prompt, query_request + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Failed to estimate token usage: %s", e) + token_usage = { + "input_tokens": 0, + "output_tokens": 0, + } + + return response, conversation_id, token_usage diff --git a/src/constants.py b/src/constants.py index d4c5b02a4..433344f72 100644 --- a/src/constants.py +++ b/src/constants.py @@ -42,6 +42,8 @@ } ) DEFAULT_AUTHENTICATION_MODULE = AUTH_MOD_NOOP +# Default tokenizer for estimating token usage +DEFAULT_ESTIMATION_TOKENIZER = "cl100k_base" # Data collector constants DATA_COLLECTOR_COLLECTION_INTERVAL = 7200 # 2 hours in seconds diff --git a/src/models/config.py b/src/models/config.py index 8e2d36e36..24b9786ad 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -157,6 +157,7 @@ class Customization(BaseModel): disable_query_system_prompt: bool = False system_prompt_path: Optional[FilePath] = None system_prompt: Optional[str] = None + default_estimation_tokenizer: str = constants.DEFAULT_ESTIMATION_TOKENIZER @model_validator(mode="after") def check_customization_model(self) -> Self: diff --git a/src/models/responses.py b/src/models/responses.py index 76270739d..7bf654b99 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -16,8 +16,6 @@ class ModelsResponse(BaseModel): # - referenced_documents: The optional URLs and titles for the documents used # to generate the response. # - truncated: Set to True if conversation history was truncated to be within context window. -# - input_tokens: Number of tokens sent to LLM -# - output_tokens: Number of tokens received from LLM # - available_quotas: Quota available as measured by all configured quota limiters # - tool_calls: List of tool requests. # - tool_results: List of tool results. @@ -28,10 +26,14 @@ class QueryResponse(BaseModel): Attributes: conversation_id: The optional conversation ID (UUID). response: The response. + input_tokens: Number of tokens sent to LLM. + output_tokens: Number of tokens received from LLM. """ conversation_id: Optional[str] = None response: str + input_tokens: int = 0 + output_tokens: int = 0 # provides examples for /docs endpoint model_config = { diff --git a/src/utils/token_counter.py b/src/utils/token_counter.py new file mode 100644 index 000000000..e11d03c40 --- /dev/null +++ b/src/utils/token_counter.py @@ -0,0 +1,256 @@ +"""Token counting utilities using tiktoken. + +This module provides utilities for counting tokens in text and conversation messages +using the tiktoken library. It supports automatic model-specific encoding detection +with fallback to a default tokenizer, and includes conversation-level token tracking +for Agent conversations. 
+""" + +from functools import lru_cache +import logging +from typing import Sequence + +from cachetools import TTLCache # type: ignore +import tiktoken + + +from llama_stack_client.types import ( + UserMessage, + SystemMessage, + ToolResponseMessage, + CompletionMessage, +) +from models.requests import QueryRequest + +from configuration import configuration, AppConfig +from constants import DEFAULT_ESTIMATION_TOKENIZER + +logger = logging.getLogger(__name__) + +# Class-level cache to track cumulative input tokens for each conversation +_conversation_cache: TTLCache[str, int] = TTLCache(maxsize=1000, ttl=3600) + + +class TokenCounter: + """A utility class for counting tokens in text and conversation messages. + + This class provides methods to count tokens in plain text and structured + conversation messages. It automatically handles model-specific tokenization + using tiktoken, with fallback to a default tokenizer if the model is not + recognized. It also tracks cumulative input tokens for Agent conversations. + + Attributes: + _encoder: The tiktoken encoding object used for tokenization + """ + + def __init__(self, model_id: str): + """Initialize the TokenCounter with a specific model. + + Args: + model_id: The identifier of the model to use for tokenization. + This is used to determine the appropriate tiktoken encoding. + + Note: + If the model_id is not recognized by tiktoken, the system will + fall back to the default estimation tokenizer specified in the + configuration. + """ + self._encoder = None + + try: + # Use tiktoken's encoding_for_model function which handles GPT models automatically + self._encoder = tiktoken.encoding_for_model(model_id) + logger.debug("Initialized tiktoken encoding for model: %s", model_id) + except KeyError as e: + fallback_encoding = get_default_estimation_tokenizer(configuration) + logger.warning( + "Failed to get encoding for model %s: %s, using %s", + model_id, + e, + fallback_encoding, + ) + self._encoder = tiktoken.get_encoding(fallback_encoding) + + def count_tokens(self, text: str) -> int: + """Count the number of tokens in a given text string. + + Args: + text: The text string to count tokens for. + + Returns: + The number of tokens in the text. Returns 0 if text is empty or None. + """ + if not text or not self._encoder: + return 0 + return len(self._encoder.encode(text)) + + def count_turn_tokens( + self, system_prompt: str, query_request: QueryRequest, response: str = "" + ) -> dict[str, int]: + """Count tokens for a complete conversation turn. + + This method estimates token usage for a typical conversation turn, + including system prompt, user query, and optional response. It accounts + for message formatting overhead and conversation structure. + + Args: + system_prompt: The system prompt message content. + query: The user's query message content. + response: The assistant's response message content (optional). 
+ + Returns: + A dictionary containing: + - 'input_tokens': Total tokens in the input messages (system + query) + - 'output_tokens': Total tokens in the response message + """ + # Estimate token usage + input_messages: list[SystemMessage | UserMessage] = [] + if system_prompt: + input_messages.append( + SystemMessage(role="system", content=str(system_prompt)) + ) + input_content = query_request.query + if query_request.attachments: + for attachment in query_request.attachments: + input_content += "\n" + attachment.content + input_messages.append(UserMessage(role="user", content=input_content)) + + input_tokens = self.count_message_tokens(input_messages) + output_tokens = self.count_tokens(response) + + logger.debug("Estimated tokens in/out: %d / %d", input_tokens, output_tokens) + + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + } + + def count_conversation_turn_tokens( + self, + conversation_id: str, + system_prompt: str, + query_request: QueryRequest, + response: str = "", + ) -> dict[str, int]: + """Count tokens for a conversation turn with cumulative tracking. + + This method estimates token usage for a conversation turn and tracks + cumulative input tokens across the conversation. It accounts for the + fact that Agent conversations include the entire message history in + each turn. + + Args: + conversation_id: The conversation ID to track tokens for. + system_prompt: The system prompt message content. + query_request: The user's query request, including any attachments. + response: The assistant's response message content (optional). + + Returns: + A dictionary containing: + - 'input_tokens': Cumulative input tokens for the conversation + - 'output_tokens': Total tokens in the response message + """ + # Get the current turn's token usage + turn_token_usage = self.count_turn_tokens( + system_prompt, query_request, response + ) + + # Get cumulative input tokens for this conversation + cumulative_input_tokens = _conversation_cache.get(conversation_id, 0) + + # Add this turn's input tokens to the cumulative total + new_cumulative_input_tokens = ( + cumulative_input_tokens + turn_token_usage["input_tokens"] + ) + _conversation_cache[conversation_id] = new_cumulative_input_tokens + + # TODO(csibbitt) - Add counting for MCP and RAG content + + logger.debug( + "Token usage for conversation %s: turn input=%d, cumulative input=%d, output=%d", + conversation_id, + turn_token_usage["input_tokens"], + new_cumulative_input_tokens, + turn_token_usage["output_tokens"], + ) + + return { + "input_tokens": new_cumulative_input_tokens, + "output_tokens": turn_token_usage["output_tokens"], + } + + def count_message_tokens( + self, + messages: Sequence[ + UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage + ], + ) -> int: + """Count tokens for a list of conversation messages. + + This method counts tokens for structured conversation messages, including + the message content and formatting overhead for roles and conversation + structure. + + Args: + messages: A list of message objects (e.g., SystemMessage, UserMessage) + + Returns: + The total number of tokens across all messages, including formatting overhead.
+ """ + total_tokens = 0 + + for message in messages: + total_tokens += self.count_tokens(str(message.content)) + # Add role overhead (varies by model, 4 is typical for OpenAI models) + role_formatting_overhead = 4 + total_tokens += role_formatting_overhead + + # Add conversation formatting overhead + if messages: + total_tokens += self._get_conversation_overhead(len(messages)) + + return total_tokens + + def _get_conversation_overhead(self, message_count: int) -> int: + """Calculate the token overhead for conversation formatting. + + This method estimates the additional tokens needed for conversation + structure, including start/end tokens and message separators. + + Args: + message_count: The number of messages in the conversation. + + Returns: + The estimated token overhead for conversation formatting. + """ + base_overhead = 3 # Start of conversation tokens (based on OpenAI chat format) + separator_overhead = max(0, (message_count - 1) * 1) # Message separator tokens + return base_overhead + separator_overhead + + +@lru_cache(maxsize=8) +def get_token_counter(model_id: str) -> TokenCounter: + """Get a cached TokenCounter instance for the specified model. + + This function provides a cached TokenCounter instance to avoid repeated + initialization of the same model's tokenizer. + + Args: + model_id: The identifier of the model to get a token counter for. + + Returns: + A TokenCounter instance configured for the specified model. + """ + return TokenCounter(model_id) + + +def get_default_estimation_tokenizer(config: AppConfig) -> str: + """Get the default estimation tokenizer.""" + if ( + config.customization is not None + and config.customization.default_estimation_tokenizer is not None + ): + return config.customization.default_estimation_tokenizer + + # default system prompt has the lowest precedence + return DEFAULT_ESTIMATION_TOKENIZER diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 5c31b5b62..525ead1b6 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -127,7 +127,11 @@ def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(llm_response, conversation_id), + return_value=( + llm_response, + conversation_id, + {"input_tokens": 10, "output_tokens": 20}, + ), ) mocker.patch("app.endpoints.query.select_model_id", return_value="fake_model_id") mocker.patch( @@ -305,6 +309,7 @@ def test_validate_attachments_metadata_invalid_content_type(): def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_vector_db = mocker.Mock() @@ -323,12 +328,14 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], 
session_id="fake_session_id", @@ -341,6 +348,8 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -357,12 +366,14 @@ def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -388,6 +399,8 @@ def __repr__(self): return "MockShield" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [MockShield("shield1")] mock_client.vector_dbs.list.return_value = [] @@ -404,12 +417,14 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -435,6 +450,7 @@ def __repr__(self): return "MockShield" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [ MockShield("shield1"), @@ -454,12 +470,14 @@ def __repr__(self): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -472,6 +490,7 @@ def __repr__(self): def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -496,12 +515,14 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): model_id = "fake_model_id" 
access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -519,6 +540,7 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -548,12 +570,14 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -575,6 +599,7 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers configured.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -601,12 +626,14 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -645,6 +672,7 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers and empty access token.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -664,12 +692,14 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc model_id = "fake_model_id" access_token = "" # Empty token - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( 
mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -730,7 +760,7 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): }, } - response, conversation_id = retrieve_response( + response, conversation_id, _ = retrieve_response( mock_client, model_id, query_request, @@ -1174,7 +1204,11 @@ def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): mock_retrieve_response = mocker.patch( "app.endpoints.query.retrieve_response", - return_value=("test response", "test_conversation_id"), + return_value=( + "test response", + "test_conversation_id", + {"input_tokens": 10, "output_tokens": 20}, + ), ) mocker.patch("app.endpoints.query.select_model_id", return_value="test_model") diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index fd32b7b1f..67c0dd9f1 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -151,6 +151,9 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) # Mock the streaming response from LLama Stack mock_streaming_response = mocker.AsyncMock() + # Currently usage is not returned by the API + # we simulate by using del to prevent pytest from returning a Mock + del mock_streaming_response.usage mock_streaming_response.__aiter__.return_value = [ mocker.Mock( event=mocker.Mock( @@ -191,7 +194,11 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) query = "What is OpenStack?" mocker.patch( "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "test_conversation_id"), + return_value=( + mock_streaming_response, + "test_conversation_id", + {"input_tokens": 10, "output_tokens": 20}, + ), ) mocker.patch( "app.endpoints.streaming_query.select_model_id", return_value="fake_model_id" @@ -271,8 +278,15 @@ async def test_streaming_query_endpoint_handler_store_transcript(mocker): async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_vector_db = mocker.Mock() mock_vector_db.identifier = "VectorDB-1" mock_client.vector_dbs.list.return_value = [mock_vector_db] @@ -290,7 +304,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) @@ -298,6 +312,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker # conversation_id should be returned assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], 
session_id="test_conversation_id", @@ -310,8 +325,15 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -327,7 +349,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) @@ -335,6 +357,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke # conversation_id should be returned assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -360,8 +383,15 @@ def __repr__(self): return "MockShield" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [MockShield("shield1")] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -377,12 +407,13 @@ def __repr__(self): model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -408,11 +439,18 @@ def __repr__(self): return "MockShield" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [ MockShield("shield1"), MockShield("shield2"), ] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -428,12 +466,13 @@ def __repr__(self): model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 
mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -446,8 +485,15 @@ def __repr__(self): async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -471,12 +517,13 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -494,8 +541,15 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -524,12 +578,13 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -626,8 +681,15 @@ def test_stream_build_event_returns_none(mocker): async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers configured.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with MCP servers @@ -653,12 +715,13 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - response, conversation_id = 
await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, access_token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -699,8 +762,15 @@ async def test_retrieve_response_with_mcp_servers_empty_token( ): """Test the retrieve_response function with MCP servers and empty access token.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with MCP servers @@ -719,12 +789,13 @@ async def test_retrieve_response_with_mcp_servers_empty_token( model_id = "fake_model_id" access_token = "" # Empty token - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, access_token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -792,7 +863,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): }, } - response, conversation_id = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, @@ -1154,7 +1225,11 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): mock_streaming_response.__aiter__.return_value = iter([]) mock_retrieve_response = mocker.patch( "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "test_conversation_id"), + return_value=( + mock_streaming_response, + "test_conversation_id", + {"input_tokens": 10, "output_tokens": 20}, + ), ) mocker.patch( diff --git a/tests/unit/utils/test_token_counter.py b/tests/unit/utils/test_token_counter.py new file mode 100644 index 000000000..ce68e0f18 --- /dev/null +++ b/tests/unit/utils/test_token_counter.py @@ -0,0 +1,149 @@ +"""Unit tests for token counter utilities.""" + +from llama_stack_client.types import UserMessage, CompletionMessage + +from utils.token_counter import TokenCounter +from models.requests import QueryRequest, Attachment +from configuration import AppConfig + + +config_dict = { + "name": "foo", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "xyzzy", + "url": "http://x.y.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "feedback_disabled": True, + }, +} + + +class TestTokenCounter: + """Test cases for TokenCounter class.""" + + def setup_class(self): + """Set up the test class.""" + cfg = AppConfig() + cfg.init_from_dict(config_dict) + + def test_count_tokens_empty_string(self): + """Test counting tokens for an empty string.""" + counter = TokenCounter("gpt-4") + assert counter.count_tokens("") == 0 + + def test_count_tokens_simple(self): + """Test counting tokens for a simple message.""" + counter = 
TokenCounter("llama3.2:1b") + assert counter.count_tokens("Hello World!") == 3 + + def test_count_message_tokens_simple(self): + """Test counting tokens for simple messages.""" + counter = TokenCounter("llama3.2:1b") + + messages = [ + UserMessage(role="user", content="Hello"), + CompletionMessage( + role="assistant", content="Hi there", stop_reason="end_of_turn" + ), + ] + + result = counter.count_message_tokens(messages) + + # 3 tokens worth of content + 4 role overhead per message + 4 conversation overhead + expected = 3 + (4 * 2) + 4 + assert result == expected + + def test_count_message_tokens_empty_messages(self): + """Test counting tokens for empty message list.""" + counter = TokenCounter("llama3.2:1b") + result = counter.count_message_tokens([]) + assert result == 0 + + def test_count_conversation_turn_tokens(self): + """Test cumulative token tracking across conversation turns.""" + counter = TokenCounter("llama3.2:1b") + + # First conversation should accumulate tokens + result1 = counter.count_conversation_turn_tokens( + "conv1", "System", QueryRequest(query="Hello"), "Hi" + ) + assert result1["input_tokens"] == 14 + result2 = counter.count_conversation_turn_tokens( + "conv1", "System", QueryRequest(query="How are you?"), "Good" + ) + assert result2["input_tokens"] == 31 + result3 = counter.count_conversation_turn_tokens( + "conv1", "System", QueryRequest(query="Fantastic"), "Yup" + ) + assert result3["input_tokens"] == 45 + + # Second conversation should be independent of the first + result4 = counter.count_conversation_turn_tokens( + "conv2", "System", QueryRequest(query="Hello"), "Hi" + ) + assert result4["input_tokens"] == 14 + + def test_count_conversation_turn_tokens_with_attachments(self): + """Test conversation turn token counting with 2 attachments.""" + counter = TokenCounter("llama3.2:1b") + + # Create 2 attachments + attachments = [ + Attachment( + attachment_type="log", + content_type="text/plain", + content="This is a log file with some error messages", + ), + Attachment( + attachment_type="configuration", + content_type="application/yaml", + content="kind: Pod\nmetadata:\n name: test-pod\nspec:\n" + + " containers:\n - name: app\n image: nginx:latest", + ), + ] + + query_request = QueryRequest( + query="Analyze these files for me", attachments=attachments + ) + + # Test the conversation turn with attachments + result = counter.count_conversation_turn_tokens( + "conv_with_attachments", "System prompt", query_request, "Analysis complete" + ) + + # Verify that the result contains the expected structure + assert "input_tokens" in result + assert "output_tokens" in result + + # The input tokens should include: + # - System message tokens + # - User query tokens + # - 2 attachment content tokens + # - Role formatting overhead for each message + # - Conversation formatting overhead + assert result["input_tokens"] > 0 + + # Output tokens should be the response content + assert result["output_tokens"] > 0 + + # Verify that attachments increase the token count compared to no attachments + query_request_no_attachments = QueryRequest(query="Analyze these files for me") + result_no_attachments = counter.count_conversation_turn_tokens( + "conv_no_attachments", + "System prompt", + query_request_no_attachments, + "Analysis complete", + ) + + # The version with attachments should have more input tokens + assert result["input_tokens"] > result_no_attachments["input_tokens"] diff --git a/uv.lock b/uv.lock index bedf9c894..287924185 100644 --- a/uv.lock +++ b/uv.lock @@ -836,6 +836,7 
@@ dependencies = [ { name = "kubernetes" }, { name = "llama-stack" }, { name = "rich" }, + { name = "tiktoken" }, { name = "uvicorn" }, ] @@ -868,6 +869,7 @@ requires-dist = [ { name = "kubernetes", specifier = ">=30.1.0" }, { name = "llama-stack", specifier = ">=0.2.13" }, { name = "rich", specifier = ">=14.0.0" }, + { name = "tiktoken", specifier = ">=0.9.0,<1.0.0" }, { name = "uvicorn", specifier = ">=0.34.3" }, ]
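The estimation flow introduced above can be exercised on its own through the new utils.token_counter helpers. A minimal sketch, assuming the package is importable and using illustrative values; the real endpoints pass the selected model_id, the resolved system prompt, and the incoming QueryRequest:

from models.requests import QueryRequest
from utils.token_counter import get_token_counter

# "gpt-4" is known to tiktoken; unrecognized model ids fall back to the
# configured default_estimation_tokenizer (cl100k_base by default).
counter = get_token_counter("gpt-4")
usage = counter.count_conversation_turn_tokens(
    "123e4567-e89b-12d3-a456-426614174000",  # conversation_id (illustrative)
    "You are a helpful assistant.",  # system prompt (illustrative)
    QueryRequest(query="What is OpenStack?"),
    "OpenStack is an open-source cloud platform.",  # response text (illustrative)
)
print(usage)  # {"input_tokens": <estimate>, "output_tokens": <estimate>}

Because input tokens are accumulated per conversation_id in a TTL cache, a second call with the same conversation_id reports a larger input_tokens value, mirroring how the Agent resends the conversation history on each turn.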