From 53a54413f9e2c8ff62075c0a96cccd63111348d2 Mon Sep 17 00:00:00 2001 From: Chris Sibbitt Date: Mon, 7 Jul 2025 21:52:58 -0400 Subject: [PATCH 1/7] Token counting * Returns token usage data in query and streaming query responses * llama-stack currently doesn't include usage in Agent API * We estimate locally with an embedding model * Accuracy will vary, but it should be nearly 100% for GPT models. --- docs/openapi.json | 24 ++- lightspeed-stack.yaml | 1 + pyproject.toml | 1 + src/app/endpoints/config.py | 1 + src/app/endpoints/query.py | 49 ++++- src/app/endpoints/streaming_query.py | 59 +++++- src/models/config.py | 1 + src/models/responses.py | 6 +- src/utils/token_counter.py | 177 ++++++++++++++++++ tests/unit/app/endpoints/test_config.py | 1 + tests/unit/app/endpoints/test_info.py | 1 + tests/unit/app/endpoints/test_models.py | 3 + tests/unit/app/endpoints/test_query.py | 51 ++++- .../app/endpoints/test_streaming_query.py | 90 ++++++++- tests/unit/models/test_config.py | 4 + tests/unit/test_configuration.py | 4 + tests/unit/utils/test_token_counter.py | 67 +++++++ uv.lock | 2 + 18 files changed, 502 insertions(+), 40 deletions(-) create mode 100644 src/utils/token_counter.py create mode 100644 tests/unit/utils/test_token_counter.py diff --git a/docs/openapi.json b/docs/openapi.json index 9413565ad..108f4619d 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -843,6 +843,28 @@ "response": { "type": "string", "title": "Response" + }, + "input_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Input Tokens" + }, + "output_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Output Tokens" } }, "type": "object", @@ -850,7 +872,7 @@ "response" ], "title": "QueryResponse", - "description": "Model representing LLM response to a query.\n\nAttributes:\n conversation_id: The optional conversation ID (UUID).\n response: The response.", + "description": "Model representing LLM response to a query.\n\nAttributes:\n conversation_id: The optional conversation ID (UUID).\n response: The response.\n input_tokens: Number of tokens sent to LLM.\n output_tokens: Number of tokens received from LLM.", "examples": [ { "conversation_id": "123e4567-e89b-12d3-a456-426614174000", diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml index e729b5ee9..34806f58e 100644 --- a/lightspeed-stack.yaml +++ b/lightspeed-stack.yaml @@ -22,3 +22,4 @@ user_data_collection: transcripts_storage: "/tmp/data/transcripts" authentication: module: "noop" +default_estimation_tokenizer: "cl100k_base" diff --git a/pyproject.toml b/pyproject.toml index 26b79039d..fe92a1b9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "llama-stack>=0.2.13", "rich>=14.0.0", "cachetools>=6.1.0", + "tiktoken>=0.6.0", ] [tool.pyright] diff --git a/src/app/endpoints/config.py b/src/app/endpoints/config.py index 8cedc74d2..c84d4519b 100644 --- a/src/app/endpoints/config.py +++ b/src/app/endpoints/config.py @@ -46,6 +46,7 @@ {"name": "server2", "provider_id": "provider2", "url": "http://url.com:2"}, {"name": "server3", "provider_id": "provider3", "url": "http://url.com:3"}, ], + "default_estimation_tokenizer": "cl100k_base", }, 503: { "detail": { diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 3727c9cb9..22d91a229 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -31,6 +31,7 @@ from utils.endpoints import check_configuration_loaded, get_system_prompt from utils.mcp_headers 
import mcp_headers_dependency from utils.suid import get_suid +from utils.token_counter import get_token_counter logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["query"]) @@ -106,7 +107,7 @@ def query_endpoint_handler( # try to get Llama Stack client client = LlamaStackClientHolder().get_client() model_id = select_model_id(client.models.list(), query_request) - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( client, model_id, query_request, @@ -129,7 +130,12 @@ def query_endpoint_handler( attachments=query_request.attachments or [], ) - return QueryResponse(conversation_id=conversation_id, response=response) + return QueryResponse( + conversation_id=conversation_id, + response=response, + input_tokens=token_usage["input_tokens"], + output_tokens=token_usage["output_tokens"], + ) # connection to Llama Stack server except APIConnectionError as e: @@ -187,13 +193,21 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s return model_id +def _build_toolgroups(client: LlamaStackClient) -> list[Toolgroup] | None: + """Build toolgroups from vector DBs and MCP servers.""" + vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] + return (get_rag_toolgroups(vector_db_ids) or []) + [ + mcp_server.name for mcp_server in configuration.mcp_servers + ] + + def retrieve_response( client: LlamaStackClient, model_id: str, query_request: QueryRequest, token: str, mcp_headers: dict[str, dict[str, str]] | None = None, -) -> tuple[str, str]: +) -> tuple[str, str, dict[str, int]]: """Retrieve response from LLMs and agents.""" available_shields = [shield.identifier for shield in client.shields.list()] if not available_shields: @@ -235,19 +249,36 @@ def retrieve_response( ), } - vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - mcp_server.name for mcp_server in configuration.mcp_servers - ] response = agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], session_id=conversation_id, documents=query_request.get_documents(), stream=False, - toolgroups=toolgroups or None, + toolgroups=_build_toolgroups(client) or None, ) - return str(response.output_message.content), conversation_id # type: ignore[union-attr] + response_content = str(response.output_message.content) # type: ignore[union-attr] + + # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it + # try: + # token_usage = { + # "input_tokens": response.usage.get("prompt_tokens", 0), + # "output_tokens": response.usage.get("completion_tokens", 0), + # } + # except AttributeError: + # Estimate token usage + try: + token_usage = get_token_counter(model_id).count_turn_tokens( + system_prompt, query_request.query, response_content + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Failed to estimate token usage: %s", e) + token_usage = { + "input_tokens": 0, + "output_tokens": 0, + } + + return response_content, conversation_id, token_usage def validate_attachments_metadata(attachments: list[Attachment]) -> None: diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 5caff43ad..6840fc50a 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -24,6 +24,7 @@ from utils.common import retrieve_user_id from utils.mcp_headers import mcp_headers_dependency 
from utils.suid import get_suid +from utils.token_counter import get_token_counter from app.endpoints.query import ( @@ -94,8 +95,13 @@ def stream_start_event(conversation_id: str) -> str: ) -def stream_end_event(metadata_map: dict) -> str: - """Yield the end of the data stream.""" +def stream_end_event(metadata_map: dict, metrics_map: dict[str, int]) -> str: + """Yield the end of the data stream. + + Args: + metadata_map: Dictionary containing metadata about referenced documents + metrics_map: Dictionary containing metrics like 'input_tokens' and 'output_tokens' + """ return format_stream_data( { "event": "end", @@ -111,8 +117,8 @@ def stream_end_event(metadata_map: dict) -> str: ) ], "truncated": None, # TODO(jboos): implement truncated - "input_tokens": 0, # TODO(jboos): implement input tokens - "output_tokens": 0, # TODO(jboos): implement output tokens + "input_tokens": metrics_map.get("input_tokens", 0), + "output_tokens": metrics_map.get("output_tokens", 0), }, "available_quotas": {}, # TODO(jboos): implement available quotas } @@ -199,7 +205,7 @@ async def streaming_query_endpoint_handler( # try to get Llama Stack client client = AsyncLlamaStackClientHolder().get_client() model_id = select_model_id(await client.models.list(), query_request) - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( client, model_id, query_request, @@ -224,7 +230,25 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: chunk_id += 1 yield event - yield stream_end_event(metadata_map) + # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate + # try: + # output_tokens = response.usage.get("completion_tokens", 0) + # except AttributeError: + # Estimate output tokens from complete response + try: + output_tokens = get_token_counter(model_id).count_tokens( + complete_response + ) + logger.debug("Estimated output tokens: %s", output_tokens) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Failed to estimate output tokens: %s", e) + output_tokens = 0 + + metrics_map = { + "input_tokens": token_usage["input_tokens"], + "output_tokens": output_tokens, + } + yield stream_end_event(metadata_map, metrics_map) if not is_transcripts_enabled(): logger.debug("Transcript collection is disabled in the configuration") @@ -261,7 +285,7 @@ async def retrieve_response( query_request: QueryRequest, token: str, mcp_headers: dict[str, dict[str, str]] | None = None, -) -> tuple[Any, str]: +) -> tuple[Any, str, dict[str, int]]: """Retrieve response from LLMs and agents.""" available_shields = [shield.identifier for shield in await client.shields.list()] if not available_shields: @@ -318,4 +342,23 @@ async def retrieve_response( toolgroups=toolgroups or None, ) - return response, conversation_id + # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it + # try: + # token_usage = { + # "input_tokens": response.usage.get("prompt_tokens", 0), + # "output_tokens": 0, # Will be calculated during streaming + # } + # except AttributeError: + # # Estimate input tokens (Output will be calculated during streaming) + try: + token_usage = get_token_counter(model_id).count_turn_tokens( + system_prompt, query_request.query + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Failed to estimate token usage: %s", e) + token_usage = { + "input_tokens": 0, + "output_tokens": 0, + } + + return response, conversation_id, token_usage 
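For readers skimming the diff: the estimation path added above (TokenCounter.__init__ plus count_tokens, called from both the query and streaming endpoints) reduces to two tiktoken calls. A minimal standalone sketch of that logic — assuming only the tiktoken API already used in this patch and the same cl100k_base fallback; the helper name is illustrative, not part of the patch:

    import tiktoken

    def estimate_tokens(model_id: str, text: str, fallback: str = "cl100k_base") -> int:
        """Estimate the token count of `text` for `model_id`, mirroring TokenCounter."""
        try:
            # Known GPT-family model names resolve to their exact encoding.
            encoder = tiktoken.encoding_for_model(model_id)
        except KeyError:
            # Unknown models (e.g. llama variants) fall back to the default encoding.
            encoder = tiktoken.get_encoding(fallback)
        return len(encoder.encode(text))

    # estimate_tokens("gpt-4", "What is OpenStack?") uses GPT-4's own encoding;
    # estimate_tokens("llama3.2:1b", "What is OpenStack?") falls back to cl100k_base.

This fallback is what makes the commit message's accuracy caveat model-dependent: GPT models are counted with their exact encoding, while other models are approximated with the configured default tokenizer.
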
diff --git a/src/models/config.py b/src/models/config.py index 3bef0ccf4..94c6bd524 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -154,6 +154,7 @@ class Configuration(BaseModel): AuthenticationConfiguration() ) customization: Optional[Customization] = None + default_estimation_tokenizer: str = "cl100k_base" def dump(self, filename: str = "configuration.json") -> None: """Dump actual configuration into JSON file.""" diff --git a/src/models/responses.py b/src/models/responses.py index 92c366c9b..8664d8365 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -16,8 +16,6 @@ class ModelsResponse(BaseModel): # - referenced_documents: The optional URLs and titles for the documents used # to generate the response. # - truncated: Set to True if conversation history was truncated to be within context window. -# - input_tokens: Number of tokens sent to LLM -# - output_tokens: Number of tokens received from LLM # - available_quotas: Quota available as measured by all configured quota limiters # - tool_calls: List of tool requests. # - tool_results: List of tool results. @@ -28,10 +26,14 @@ class QueryResponse(BaseModel): Attributes: conversation_id: The optional conversation ID (UUID). response: The response. + input_tokens: Number of tokens sent to LLM. + output_tokens: Number of tokens received from LLM. """ conversation_id: Optional[str] = None response: str + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None # provides examples for /docs endpoint model_config = { diff --git a/src/utils/token_counter.py b/src/utils/token_counter.py new file mode 100644 index 000000000..aa437ac6d --- /dev/null +++ b/src/utils/token_counter.py @@ -0,0 +1,177 @@ +"""Token counting utilities using tiktoken. + +This module provides utilities for counting tokens in text and conversation messages +using the tiktoken library. It supports automatic model-specific encoding detection +with fallback to a default tokenizer. +""" + +from functools import lru_cache +import logging +from typing import Sequence + +from llama_stack_client.types import ( + UserMessage, + SystemMessage, + ToolResponseMessage, + CompletionMessage, +) +import tiktoken + +from configuration import configuration + +logger = logging.getLogger(__name__) + + +class TokenCounter: + """A utility class for counting tokens in text and conversation messages. + + This class provides methods to count tokens in plain text and structured + conversation messages. It automatically handles model-specific tokenization + using tiktoken, with fallback to a default tokenizer if the model is not + recognized. + + Attributes: + _encoder: The tiktoken encoding object used for tokenization + """ + + def __init__(self, model_id: str): + """Initialize the TokenCounter with a specific model. + + Args: + model_id: The identifier of the model to use for tokenization. + This is used to determine the appropriate tiktoken encoding. + + Note: + If the model_id is not recognized by tiktoken, the system will + fall back to the default estimation tokenizer specified in the + configuration. 
+ """ + self._encoder = None + + try: + # Use tiktoken's encoding_for_model function which handles GPT models automatically + self._encoder = tiktoken.encoding_for_model(model_id) + logger.debug("Initialized tiktoken encoding for model: %s", model_id) + except KeyError as e: + fallback_encoding = configuration.configuration.default_estimation_tokenizer + logger.warning( + "Failed to get encoding for model %s: %s, using %s", + model_id, + e, + fallback_encoding, + ) + self._encoder = tiktoken.get_encoding(fallback_encoding) + + def count_tokens(self, text: str) -> int: + """Count the number of tokens in a given text string. + + Args: + text: The text string to count tokens for. + + Returns: + The number of tokens in the text. Returns 0 if text is empty or None. + """ + if not text or not self._encoder: + return 0 + return len(self._encoder.encode(text)) + + def count_turn_tokens( + self, system_prompt: str, query: str, response: str = "" + ) -> dict[str, int]: + """Count tokens for a complete conversation turn. + + This method estimates token usage for a typical conversation turn, + including system prompt, user query, and optional response. It accounts + for message formatting overhead and conversation structure. + + Args: + system_prompt: The system prompt message content. + query: The user's query message content. + response: The assistant's response message content (optional). + + Returns: + A dictionary containing: + - 'input_tokens': Total tokens in the input messages (system + query) + - 'output_tokens': Total tokens in the response message + """ + # Estimate token usage + input_messages: list[SystemMessage | UserMessage] = [] + if system_prompt: + input_messages.append( + SystemMessage(role="system", content=str(system_prompt)) + ) + input_messages.append(UserMessage(role="user", content=query)) + + input_tokens = self.count_message_tokens(input_messages) + output_tokens = self.count_tokens(response) + + logger.debug("Estimated tokens in/out: %d / %d", input_tokens, output_tokens) + + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + } + + def count_message_tokens( + self, + messages: Sequence[ + UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage + ], + ) -> int: + """Count tokens for a list of conversation messages. + + This method counts tokens for structured conversation messages, including + the message content and formatting overhead for roles and conversation + structure. + + Args: + messages: A list of message objects (e.g., SystemMessage, UserMessage) + + Returns: + The total number of tokens across all messages, including formatting overhead. + """ + total_tokens = 0 + + for message in messages: + total_tokens += self.count_tokens(str(message.content)) + # Add role overhead + role_formatting_overhead = 4 + total_tokens += role_formatting_overhead + + # Add conversation formatting overhead + if messages: + total_tokens += self._get_conversation_overhead(len(messages)) + + return total_tokens + + def _get_conversation_overhead(self, message_count: int) -> int: + """Calculate the token overhead for conversation formatting. + + This method estimates the additional tokens needed for conversation + structure, including start/end tokens and message separators. + + Args: + message_count: The number of messages in the conversation. + + Returns: + The estimated token overhead for conversation formatting. 
+ """ + base_overhead = 3 # Start of conversation + separator_overhead = max(0, (message_count - 1) * 1) # Between messages + return base_overhead + separator_overhead + + +@lru_cache(maxsize=8) +def get_token_counter(model_id: str) -> TokenCounter: + """Get a cached TokenCounter instance for the specified model. + + This function provides a cached TokenCounter instance to avoid repeated + initialization of the same model's tokenizer. + + Args: + model_id: The identifier of the model to get a token counter for. + + Returns: + A TokenCounter instance configured for the specified model. + """ + return TokenCounter(model_id) diff --git a/tests/unit/app/endpoints/test_config.py b/tests/unit/app/endpoints/test_config.py index 4a2cd3c1c..341f4020f 100644 --- a/tests/unit/app/endpoints/test_config.py +++ b/tests/unit/app/endpoints/test_config.py @@ -44,6 +44,7 @@ def test_config_endpoint_handler_configuration_loaded(mocker): "module": "noop", }, "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) diff --git a/tests/unit/app/endpoints/test_info.py b/tests/unit/app/endpoints/test_info.py index 7dc1d46bc..f5410226a 100644 --- a/tests/unit/app/endpoints/test_info.py +++ b/tests/unit/app/endpoints/test_info.py @@ -23,6 +23,7 @@ def test_info_endpoint(mocker): "feedback_disabled": True, }, "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) diff --git a/tests/unit/app/endpoints/test_models.py b/tests/unit/app/endpoints/test_models.py index f9bcc51ba..571eb35c3 100644 --- a/tests/unit/app/endpoints/test_models.py +++ b/tests/unit/app/endpoints/test_models.py @@ -46,6 +46,7 @@ def test_models_endpoint_handler_improper_llama_stack_configuration(mocker): }, "mcp_servers": [], "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -84,6 +85,7 @@ def test_models_endpoint_handler_configuration_loaded(mocker): "feedback_disabled": True, }, "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -117,6 +119,7 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): "feedback_disabled": True, }, "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index c1a8e774b..58484db7d 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -44,6 +44,7 @@ def setup_configuration(): }, "mcp_servers": [], "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -118,7 +119,11 @@ def _test_query_endpoint_handler(mocker, store_transcript=False): mocker.patch( "app.endpoints.query.retrieve_response", - return_value=(llm_response, conversation_id), + return_value=( + llm_response, + conversation_id, + {"input_tokens": 10, "output_tokens": 20}, + ), ) mocker.patch("app.endpoints.query.select_model_id", return_value="fake_model_id") mocker.patch( @@ -295,6 +300,7 @@ def test_validate_attachments_metadata_invalid_content_type(): def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = 
mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_vector_db = mocker.Mock() @@ -313,12 +319,14 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -331,6 +339,8 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -347,12 +357,14 @@ def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -370,6 +382,8 @@ def __init__(self, identifier): self.identifier = identifier mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [MockShield("shield1")] mock_client.vector_dbs.list.return_value = [] @@ -386,12 +400,14 @@ def __init__(self, identifier): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -409,6 +425,7 @@ def __init__(self, identifier): self.identifier = identifier mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [ MockShield("shield1"), @@ -428,12 +445,14 @@ def __init__(self, identifier): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" 
assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -446,6 +465,7 @@ def __init__(self, identifier): def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -470,12 +490,14 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -493,6 +515,7 @@ def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -522,12 +545,14 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(content="What is OpenStack?", role="user")], session_id="fake_session_id", @@ -549,6 +574,7 @@ def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers configured.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -575,12 +601,14 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 # Verify get_agent was called with the correct parameters 
mock_get_agent.assert_called_once_with( @@ -619,6 +647,7 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers and empty access token.""" mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -638,12 +667,14 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc model_id = "fake_model_id" access_token = "" # Empty token - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, access_token ) assert response == "LLM answer" assert conversation_id == "fake_session_id" + assert token_usage["input_tokens"] > 0 + assert token_usage["output_tokens"] > 0 # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -698,7 +729,7 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): "https://git.example.com/mcp": {"Authorization": "Bearer test_token_123"}, } - response, conversation_id = retrieve_response( + response, conversation_id, token_usage = retrieve_response( mock_client, model_id, query_request, diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 23ccb19e8..8b2e493d9 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -141,6 +141,8 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) # Mock the streaming response from LLama Stack mock_streaming_response = mocker.AsyncMock() + # Currently usage is not returned by the API, we simulate by using del to prevent pytest from returning a Mock + del mock_streaming_response.usage mock_streaming_response.__aiter__.return_value = [ mocker.Mock( event=mocker.Mock( @@ -181,7 +183,11 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) query = "What is OpenStack?" 
mocker.patch( "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "test_conversation_id"), + return_value=( + mock_streaming_response, + "test_conversation_id", + {"input_tokens": 10, "output_tokens": 20}, + ), ) mocker.patch( "app.endpoints.streaming_query.select_model_id", return_value="fake_model_id" @@ -263,8 +269,15 @@ async def test_streaming_query_endpoint_handler_store_transcript(mocker): async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_vector_db = mocker.Mock() mock_vector_db.identifier = "VectorDB-1" mock_client.vector_dbs.list.return_value = [mock_vector_db] @@ -282,13 +295,14 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) # For streaming, the response should be the streaming object and conversation_id should be returned assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -301,8 +315,15 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -318,13 +339,14 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) # For streaming, the response should be the streaming object and conversation_id should be returned assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -345,8 +367,15 @@ def identifier(self): return self.identifier mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [MockShield("shield1")] + mock_client.models.list.return_value = [ + mocker.Mock( + 
identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -362,12 +391,13 @@ def identifier(self): model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -388,11 +418,18 @@ def identifier(self): return self.identifier mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [ MockShield("shield1"), MockShield("shield2"), ] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -408,12 +445,13 @@ def identifier(self): model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -426,8 +464,15 @@ def identifier(self): async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -451,12 +496,13 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -474,8 +520,15 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + 
mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with empty MCP servers @@ -504,12 +557,13 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke model_id = "fake_model_id" token = "test_token" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 mock_agent.create_turn.assert_called_once_with( messages=[UserMessage(role="user", content="What is OpenStack?")], session_id="test_conversation_id", @@ -606,8 +660,15 @@ def test_stream_build_event_returns_none(mocker): async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers configured.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with MCP servers @@ -633,12 +694,13 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): model_id = "fake_model_id" access_token = "test_token_123" - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, access_token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -679,8 +741,15 @@ async def test_retrieve_response_with_mcp_servers_empty_token( ): """Test the retrieve_response function with MCP servers and empty access token.""" mock_client, mock_agent = prepare_agent_mocks + + mock_agent.create_turn.return_value = mocker.Mock(spec=["output_message"]) mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client.shields.list.return_value = [] + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="fake_model_id", model_type="llm", provider_id="test_provider" + ), + ] mock_client.vector_dbs.list.return_value = [] # Mock configuration with MCP servers @@ -699,12 +768,13 @@ async def test_retrieve_response_with_mcp_servers_empty_token( model_id = "fake_model_id" access_token = "" # Empty token - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, model_id, query_request, access_token ) assert response is not None assert conversation_id == "test_conversation_id" + assert token_usage["input_tokens"] > 0 # Verify get_agent was called with the correct parameters mock_get_agent.assert_called_once_with( @@ -766,7 +836,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): "https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"}, } - response, conversation_id = await retrieve_response( + response, conversation_id, token_usage = await retrieve_response( mock_client, 
model_id, query_request, diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py index 254ae7838..0d76ccf14 100644 --- a/tests/unit/models/test_config.py +++ b/tests/unit/models/test_config.py @@ -321,6 +321,7 @@ def test_dump_configuration(tmp_path) -> None: ), mcp_servers=[], customization=None, + default_estimation_tokenizer="cl100k_base", ) assert cfg is not None dump_file = tmp_path / "test.json" @@ -375,6 +376,7 @@ def test_dump_configuration(tmp_path) -> None: "k8s_cluster_api": None, }, "customization": None, + "default_estimation_tokenizer": "cl100k_base", } @@ -449,6 +451,7 @@ def test_dump_configuration_with_one_mcp_server(tmp_path) -> None: "k8s_cluster_api": None, }, "customization": None, + "default_estimation_tokenizer": "cl100k_base", } @@ -541,6 +544,7 @@ def test_dump_configuration_with_more_mcp_servers(tmp_path) -> None: "k8s_cluster_api": None, }, "customization": None, + "default_estimation_tokenizer": "cl100k_base", } diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py index 97831f23e..4d872c16a 100644 --- a/tests/unit/test_configuration.py +++ b/tests/unit/test_configuration.py @@ -50,6 +50,7 @@ def test_init_from_dict() -> None: }, "mcp_servers": [], "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -112,6 +113,7 @@ def test_init_from_dict_with_mcp_servers() -> None: }, ], "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -219,6 +221,7 @@ def test_mcp_servers_property_empty() -> None: }, "mcp_servers": [], "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -255,6 +258,7 @@ def test_mcp_servers_property_with_servers() -> None: }, ], "customization": None, + "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) diff --git a/tests/unit/utils/test_token_counter.py b/tests/unit/utils/test_token_counter.py new file mode 100644 index 000000000..846d5f255 --- /dev/null +++ b/tests/unit/utils/test_token_counter.py @@ -0,0 +1,67 @@ +"""Unit tests for token counter utilities.""" + +from utils.token_counter import TokenCounter +from llama_stack_client.types import UserMessage, CompletionMessage +from configuration import AppConfig + + +config_dict = { + "name": "foo", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "xyzzy", + "url": "http://x.y.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "feedback_disabled": True, + }, + "default_estimation_tokenizer": "cl100k_base", +} + + +class TestTokenCounter: + """Test cases for TokenCounter class.""" + + def setup_class(self): + cfg = AppConfig() + cfg.init_from_dict(config_dict) + + def test_count_tokens_empty_string(self): + """Test counting tokens for empty message list.""" + counter = TokenCounter("gpt-4") + assert counter.count_tokens("") == 0 + + def test_count_tokens_simple(self): + counter = TokenCounter("llama3.2:1b") + assert counter.count_tokens("Hello World!") == 3 + + def test_count_message_tokens_simple(self): + """Test counting tokens for simple messages.""" + counter = TokenCounter("llama3.2:1b") + + messages = [ + UserMessage(role="user", content="Hello"), + CompletionMessage( + role="assistant", content="Hi there", 
stop_reason="end_of_turn" + ), + ] + + result = counter.count_message_tokens(messages) + + # 3 tokens worth of content + 4 role overhead per message + 4 conversation overhead + expected = 3 + (4 * 2) + 4 + assert result == expected + + def test_count_message_tokens_empty_messages(self): + """Test counting tokens for empty message list.""" + counter = TokenCounter("llama3.2:1b") + result = counter.count_message_tokens([]) + assert result == 0 diff --git a/uv.lock b/uv.lock index 8a03720d8..4f9602fc8 100644 --- a/uv.lock +++ b/uv.lock @@ -688,6 +688,7 @@ dependencies = [ { name = "kubernetes" }, { name = "llama-stack" }, { name = "rich" }, + { name = "tiktoken" }, { name = "uvicorn" }, ] @@ -715,6 +716,7 @@ requires-dist = [ { name = "kubernetes", specifier = ">=30.1.0" }, { name = "llama-stack", specifier = ">=0.2.13" }, { name = "rich", specifier = ">=14.0.0" }, + { name = "tiktoken", specifier = ">=0.6.0" }, { name = "uvicorn", specifier = ">=0.34.3" }, ] From 4c1bd955746797cae0c964dcace65ef805cfce4c Mon Sep 17 00:00:00 2001 From: Chris Sibbitt Date: Wed, 9 Jul 2025 15:29:30 -0400 Subject: [PATCH 2/7] raise minimum version of tiktoken --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fe92a1b9a..bf7fb9c35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "llama-stack>=0.2.13", "rich>=14.0.0", "cachetools>=6.1.0", - "tiktoken>=0.6.0", + "tiktoken>=0.9.0,<1.0.0", ] [tool.pyright] From 67269cb6bc1e11c9587f743b1fb848eed5cd37de Mon Sep 17 00:00:00 2001 From: Chris Sibbitt Date: Tue, 15 Jul 2025 11:30:32 -0400 Subject: [PATCH 3/7] Fixes from review * Token counts consider full conversation history * Token counts not defined as optional * Added default_estimation_tokenizer to constants and customization --- docs/openapi.json | 5 ++ lightspeed-stack.yaml | 1 - src/app/endpoints/config.py | 1 - src/app/endpoints/query.py | 5 +- src/app/endpoints/streaming_query.py | 31 ++++++---- src/constants.py | 2 + src/models/config.py | 2 +- src/models/responses.py | 4 +- src/utils/token_counter.py | 81 ++++++++++++++++++++++--- tests/unit/app/endpoints/test_config.py | 1 - tests/unit/app/endpoints/test_info.py | 1 - tests/unit/app/endpoints/test_models.py | 3 - tests/unit/app/endpoints/test_query.py | 1 - tests/unit/models/test_config.py | 4 -- tests/unit/test_configuration.py | 4 -- tests/unit/utils/test_token_counter.py | 25 +++++++- uv.lock | 2 +- 17 files changed, 131 insertions(+), 42 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 87ec07033..79104465b 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -559,6 +559,11 @@ } ], "title": "System Prompt" + }, + "default_estimation_tokenizer": { + "type": "string", + "title": "Default Estimation Tokenizer", + "description": "The default tokenizer to use for estimating token usage when the model is not supported by tiktoken." 
} }, "type": "object", diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml index 34806f58e..e729b5ee9 100644 --- a/lightspeed-stack.yaml +++ b/lightspeed-stack.yaml @@ -22,4 +22,3 @@ user_data_collection: transcripts_storage: "/tmp/data/transcripts" authentication: module: "noop" -default_estimation_tokenizer: "cl100k_base" diff --git a/src/app/endpoints/config.py b/src/app/endpoints/config.py index c84d4519b..8cedc74d2 100644 --- a/src/app/endpoints/config.py +++ b/src/app/endpoints/config.py @@ -46,7 +46,6 @@ {"name": "server2", "provider_id": "provider2", "url": "http://url.com:2"}, {"name": "server3", "provider_id": "provider3", "url": "http://url.com:3"}, ], - "default_estimation_tokenizer": "cl100k_base", }, 503: { "detail": { diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 22d91a229..ca1e743ab 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -268,8 +268,9 @@ def retrieve_response( # except AttributeError: # Estimate token usage try: - token_usage = get_token_counter(model_id).count_turn_tokens( - system_prompt, query_request.query, response_content + token_counter = get_token_counter(model_id) + token_usage = token_counter.count_conversation_turn_tokens( + conversation_id, system_prompt, query_request.query, response_content ) except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Failed to estimate token usage: %s", e) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 0699bf3dc..ef974cdbc 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -9,10 +9,12 @@ from llama_stack_client import APIConnectionError from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore +from llama_stack_client.types.agents.turn_create_params import Toolgroup from llama_stack_client import AsyncLlamaStackClient # type: ignore from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from llama_stack_client.types import UserMessage # type: ignore + from fastapi import APIRouter, HTTPException, Request, Depends, status from fastapi.responses import StreamingResponse @@ -237,9 +239,8 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: # except AttributeError: # Estimate output tokens from complete response try: - output_tokens = get_token_counter(model_id).count_tokens( - complete_response - ) + token_counter = get_token_counter(model_id) + output_tokens = token_counter.count_tokens(complete_response) logger.debug("Estimated output tokens: %s", output_tokens) except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Failed to estimate output tokens: %s", e) @@ -280,6 +281,16 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: ) from e +async def _build_toolgroups(client: AsyncLlamaStackClient) -> list[Toolgroup] | None: + """Build toolgroups from vector DBs and MCP servers.""" + vector_db_ids = [ + vector_db.identifier for vector_db in await client.vector_dbs.list() + ] + return (get_rag_toolgroups(vector_db_ids) or []) + [ + mcp_server.name for mcp_server in configuration.mcp_servers + ] + + async def retrieve_response( client: AsyncLlamaStackClient, model_id: str, @@ -329,18 +340,13 @@ async def retrieve_response( } logger.debug("Session ID: %s", conversation_id) - vector_db_ids = [ - vector_db.identifier for vector_db in await client.vector_dbs.list() - ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - 
mcp_server.name for mcp_server in configuration.mcp_servers - ] + response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], session_id=conversation_id, documents=query_request.get_documents(), stream=True, - toolgroups=toolgroups or None, + toolgroups=await _build_toolgroups(client) or None, ) # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it @@ -352,8 +358,9 @@ async def retrieve_response( # except AttributeError: # # Estimate input tokens (Output will be calculated during streaming) try: - token_usage = get_token_counter(model_id).count_turn_tokens( - system_prompt, query_request.query + token_counter = get_token_counter(model_id) + token_usage = token_counter.count_conversation_turn_tokens( + conversation_id, system_prompt, query_request.query ) except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Failed to estimate token usage: %s", e) diff --git a/src/constants.py b/src/constants.py index 9407a4ee9..b02bd50c6 100644 --- a/src/constants.py +++ b/src/constants.py @@ -42,3 +42,5 @@ } ) DEFAULT_AUTHENTICATION_MODULE = AUTH_MOD_NOOP +# Default tokenizer for estimating token usage +DEFAULT_ESTIMATION_TOKENIZER = "cl100k_base" diff --git a/src/models/config.py b/src/models/config.py index 94c6bd524..7f93cc414 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -130,6 +130,7 @@ class Customization(BaseModel): system_prompt_path: Optional[FilePath] = None system_prompt: Optional[str] = None + default_estimation_tokenizer: str = constants.DEFAULT_ESTIMATION_TOKENIZER @model_validator(mode="after") def check_authentication_model(self) -> Self: @@ -154,7 +155,6 @@ class Configuration(BaseModel): AuthenticationConfiguration() ) customization: Optional[Customization] = None - default_estimation_tokenizer: str = "cl100k_base" def dump(self, filename: str = "configuration.json") -> None: """Dump actual configuration into JSON file.""" diff --git a/src/models/responses.py b/src/models/responses.py index 8664d8365..b1938c34d 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -32,8 +32,8 @@ class QueryResponse(BaseModel): conversation_id: Optional[str] = None response: str - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None + input_tokens: int = 0 + output_tokens: int = 0 # provides examples for /docs endpoint model_config = { diff --git a/src/utils/token_counter.py b/src/utils/token_counter.py index aa437ac6d..6b7b4e39c 100644 --- a/src/utils/token_counter.py +++ b/src/utils/token_counter.py @@ -2,13 +2,16 @@ This module provides utilities for counting tokens in text and conversation messages using the tiktoken library. It supports automatic model-specific encoding detection -with fallback to a default tokenizer. +with fallback to a default tokenizer, and includes conversation-level token tracking +for Agent conversations. 
""" from functools import lru_cache import logging from typing import Sequence +from cachetools import TTLCache # type: ignore + from llama_stack_client.types import ( UserMessage, SystemMessage, @@ -17,10 +20,14 @@ ) import tiktoken -from configuration import configuration +from configuration import configuration, AppConfig +from constants import DEFAULT_ESTIMATION_TOKENIZER logger = logging.getLogger(__name__) +# Class-level cache to track cumulative input tokens for each conversation +_conversation_cache: TTLCache[str, int] = TTLCache(maxsize=1000, ttl=3600) + class TokenCounter: """A utility class for counting tokens in text and conversation messages. @@ -28,7 +35,7 @@ class TokenCounter: This class provides methods to count tokens in plain text and structured conversation messages. It automatically handles model-specific tokenization using tiktoken, with fallback to a default tokenizer if the model is not - recognized. + recognized. It also tracks cumulative input tokens for Agent conversations. Attributes: _encoder: The tiktoken encoding object used for tokenization @@ -53,7 +60,7 @@ def __init__(self, model_id: str): self._encoder = tiktoken.encoding_for_model(model_id) logger.debug("Initialized tiktoken encoding for model: %s", model_id) except KeyError as e: - fallback_encoding = configuration.configuration.default_estimation_tokenizer + fallback_encoding = get_default_estimation_tokenizer(configuration) logger.warning( "Failed to get encoding for model %s: %s, using %s", model_id, @@ -112,6 +119,54 @@ def count_turn_tokens( "output_tokens": output_tokens, } + def count_conversation_turn_tokens( + self, conversation_id: str, system_prompt: str, query: str, response: str = "" + ) -> dict[str, int]: + """Count tokens for a conversation turn with cumulative tracking. + + This method estimates token usage for a conversation turn and tracks + cumulative input tokens across the conversation. It accounts for the + fact that Agent conversations include the entire message history in + each turn. + + Args: + conversation_id: The conversation ID to track tokens for. + system_prompt: The system prompt message content. + query: The user's query message content. + response: The assistant's response message content (optional). 
+ + Returns: + A dictionary containing: + - 'input_tokens': Cumulative input tokens for the conversation + - 'output_tokens': Total tokens in the response message + """ + # Get the current turn's token usage + turn_token_usage = self.count_turn_tokens(system_prompt, query, response) + + # Get cumulative input tokens for this conversation + cumulative_input_tokens = _conversation_cache.get(conversation_id, 0) + + # Add this turn's input tokens to the cumulative total + new_cumulative_input_tokens = ( + cumulative_input_tokens + turn_token_usage["input_tokens"] + ) + _conversation_cache[conversation_id] = new_cumulative_input_tokens + + # TODO(csibbitt) - Add counting for MCP and RAG content + + logger.debug( + "Token usage for conversation %s: turn input=%d, cumulative input=%d, output=%d", + conversation_id, + turn_token_usage["input_tokens"], + new_cumulative_input_tokens, + turn_token_usage["output_tokens"], + ) + + return { + "input_tokens": new_cumulative_input_tokens, + "output_tokens": turn_token_usage["output_tokens"], + } + def count_message_tokens( self, messages: Sequence[ @@ -134,7 +189,7 @@ def count_message_tokens( for message in messages: total_tokens += self.count_tokens(str(message.content)) - # Add role overhead + # Add role overhead (varies by model, 4 is typical for OpenAI models) role_formatting_overhead = 4 total_tokens += role_formatting_overhead @@ -156,8 +211,8 @@ def _get_conversation_overhead(self, message_count: int) -> int: Returns: The estimated token overhead for conversation formatting. """ - base_overhead = 3 # Start of conversation - separator_overhead = max(0, (message_count - 1) * 1) # Between messages + base_overhead = 3 # Start of conversation tokens (based on OpenAI chat format) + separator_overhead = max(0, (message_count - 1) * 1) # Message separator tokens return base_overhead + separator_overhead @@ -175,3 +230,15 @@ def get_token_counter(model_id: str) -> TokenCounter: A TokenCounter instance configured for the specified model. 
""" return TokenCounter(model_id) + + +def get_default_estimation_tokenizer(config: AppConfig) -> str: + """Get the default estimation tokenizer.""" + if ( + config.customization is not None + and config.customization.default_estimation_tokenizer is not None + ): + return config.customization.default_estimation_tokenizer + + # default system prompt has the lowest precedence + return DEFAULT_ESTIMATION_TOKENIZER diff --git a/tests/unit/app/endpoints/test_config.py b/tests/unit/app/endpoints/test_config.py index 341f4020f..4a2cd3c1c 100644 --- a/tests/unit/app/endpoints/test_config.py +++ b/tests/unit/app/endpoints/test_config.py @@ -44,7 +44,6 @@ def test_config_endpoint_handler_configuration_loaded(mocker): "module": "noop", }, "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) diff --git a/tests/unit/app/endpoints/test_info.py b/tests/unit/app/endpoints/test_info.py index f5410226a..7dc1d46bc 100644 --- a/tests/unit/app/endpoints/test_info.py +++ b/tests/unit/app/endpoints/test_info.py @@ -23,7 +23,6 @@ def test_info_endpoint(mocker): "feedback_disabled": True, }, "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) diff --git a/tests/unit/app/endpoints/test_models.py b/tests/unit/app/endpoints/test_models.py index 571eb35c3..f9bcc51ba 100644 --- a/tests/unit/app/endpoints/test_models.py +++ b/tests/unit/app/endpoints/test_models.py @@ -46,7 +46,6 @@ def test_models_endpoint_handler_improper_llama_stack_configuration(mocker): }, "mcp_servers": [], "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -85,7 +84,6 @@ def test_models_endpoint_handler_configuration_loaded(mocker): "feedback_disabled": True, }, "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -119,7 +117,6 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): "feedback_disabled": True, }, "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 58484db7d..eb59cf82f 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -44,7 +44,6 @@ def setup_configuration(): }, "mcp_servers": [], "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py index 0d76ccf14..254ae7838 100644 --- a/tests/unit/models/test_config.py +++ b/tests/unit/models/test_config.py @@ -321,7 +321,6 @@ def test_dump_configuration(tmp_path) -> None: ), mcp_servers=[], customization=None, - default_estimation_tokenizer="cl100k_base", ) assert cfg is not None dump_file = tmp_path / "test.json" @@ -376,7 +375,6 @@ def test_dump_configuration(tmp_path) -> None: "k8s_cluster_api": None, }, "customization": None, - "default_estimation_tokenizer": "cl100k_base", } @@ -451,7 +449,6 @@ def test_dump_configuration_with_one_mcp_server(tmp_path) -> None: "k8s_cluster_api": None, }, "customization": None, - "default_estimation_tokenizer": "cl100k_base", } @@ -544,7 +541,6 @@ def test_dump_configuration_with_more_mcp_servers(tmp_path) -> None: "k8s_cluster_api": None, }, "customization": None, - 
"default_estimation_tokenizer": "cl100k_base", } diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py index 4d872c16a..97831f23e 100644 --- a/tests/unit/test_configuration.py +++ b/tests/unit/test_configuration.py @@ -50,7 +50,6 @@ def test_init_from_dict() -> None: }, "mcp_servers": [], "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -113,7 +112,6 @@ def test_init_from_dict_with_mcp_servers() -> None: }, ], "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -221,7 +219,6 @@ def test_mcp_servers_property_empty() -> None: }, "mcp_servers": [], "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -258,7 +255,6 @@ def test_mcp_servers_property_with_servers() -> None: }, ], "customization": None, - "default_estimation_tokenizer": "cl100k_base", } cfg = AppConfig() cfg.init_from_dict(config_dict) diff --git a/tests/unit/utils/test_token_counter.py b/tests/unit/utils/test_token_counter.py index 846d5f255..c839ee0e5 100644 --- a/tests/unit/utils/test_token_counter.py +++ b/tests/unit/utils/test_token_counter.py @@ -23,7 +23,6 @@ "user_data_collection": { "feedback_disabled": True, }, - "default_estimation_tokenizer": "cl100k_base", } @@ -65,3 +64,27 @@ def test_count_message_tokens_empty_messages(self): counter = TokenCounter("llama3.2:1b") result = counter.count_message_tokens([]) assert result == 0 + + def test_count_conversation_turn_tokens(self): + """Test cumulative token tracking across conversation turns.""" + counter = TokenCounter("llama3.2:1b") + + # First conversation should accumulate tokens + result1 = counter.count_conversation_turn_tokens( + "conv1", "System", "Hello", "Hi" + ) + assert result1["input_tokens"] == 14 + result2 = counter.count_conversation_turn_tokens( + "conv1", "System", "How are you?", "Good" + ) + assert result2["input_tokens"] == 31 + result3 = counter.count_conversation_turn_tokens( + "conv1", "System", "Fantastic", "Yup" + ) + assert result3["input_tokens"] == 45 + + # Second conversation should be independent of the first + result4 = counter.count_conversation_turn_tokens( + "conv2", "System", "Hello", "Hi" + ) + assert result4["input_tokens"] == 14 diff --git a/uv.lock b/uv.lock index 4f9602fc8..24495e5c6 100644 --- a/uv.lock +++ b/uv.lock @@ -716,7 +716,7 @@ requires-dist = [ { name = "kubernetes", specifier = ">=30.1.0" }, { name = "llama-stack", specifier = ">=0.2.13" }, { name = "rich", specifier = ">=14.0.0" }, - { name = "tiktoken", specifier = ">=0.6.0" }, + { name = "tiktoken", specifier = ">=0.9.0,<1.0.0" }, { name = "uvicorn", specifier = ">=0.34.3" }, ] From 4ef7798273e81d1ce40fd9d4b6e925dc2e7e780f Mon Sep 17 00:00:00 2001 From: Chris Sibbitt Date: Thu, 17 Jul 2025 10:33:58 -0400 Subject: [PATCH 4/7] Count tokens in attachments --- src/app/endpoints/query.py | 2 +- src/app/endpoints/streaming_query.py | 2 +- src/utils/token_counter.py | 15 +++-- tests/unit/app/endpoints/test_query.py | 2 +- .../app/endpoints/test_streaming_query.py | 2 +- tests/unit/utils/test_token_counter.py | 67 +++++++++++++++++-- 6 files changed, 78 insertions(+), 12 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 9b6e1765d..46a0b5334 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -286,7 +286,7 @@ def retrieve_response( try: token_counter = 
get_token_counter(model_id) token_usage = token_counter.count_conversation_turn_tokens( - conversation_id, system_prompt, query_request.query, response_content + conversation_id, system_prompt, query_request, response_content ) except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Failed to estimate token usage: %s", e) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 62d95fbb0..1f49cce07 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -367,7 +367,7 @@ async def retrieve_response( try: token_counter = get_token_counter(model_id) token_usage = token_counter.count_conversation_turn_tokens( - conversation_id, system_prompt, query_request.query + conversation_id, system_prompt, query_request ) except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Failed to estimate token usage: %s", e) diff --git a/src/utils/token_counter.py b/src/utils/token_counter.py index 6b7b4e39c..f162238da 100644 --- a/src/utils/token_counter.py +++ b/src/utils/token_counter.py @@ -18,6 +18,7 @@ ToolResponseMessage, CompletionMessage, ) +from models.requests import QueryRequest import tiktoken from configuration import configuration, AppConfig @@ -83,7 +84,7 @@ def count_tokens(self, text: str) -> int: return len(self._encoder.encode(text)) def count_turn_tokens( - self, system_prompt: str, query: str, response: str = "" + self, system_prompt: str, query_request: QueryRequest, response: str = "" ) -> dict[str, int]: """Count tokens for a complete conversation turn. @@ -107,7 +108,13 @@ def count_turn_tokens( input_messages.append( SystemMessage(role="system", content=str(system_prompt)) ) - input_messages.append(UserMessage(role="user", content=query)) + input_messages.append(UserMessage(role="user", content=query_request.query)) + + if query_request.attachments: + for attachment in query_request.attachments: + input_messages.append( + UserMessage(role="user", content=attachment.content) + ) input_tokens = self.count_message_tokens(input_messages) output_tokens = self.count_tokens(response) @@ -120,7 +127,7 @@ def count_turn_tokens( } def count_conversation_turn_tokens( - self, conversation_id: str, system_prompt: str, query: str, response: str = "" + self, conversation_id: str, system_prompt: str, query_request: QueryRequest, response: str = "" ) -> dict[str, int]: """Count tokens for a conversation turn with cumulative tracking. 
@@ -141,7 +148,7 @@ def count_conversation_turn_tokens( - 'output_tokens': Total tokens in the response message """ # Get the current turn's token usage - turn_token_usage = self.count_turn_tokens(system_prompt, query, response) + turn_token_usage = self.count_turn_tokens(system_prompt, query_request, response) # Get cumulative input tokens for this conversation cumulative_input_tokens = _conversation_cache.get(conversation_id, 0) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 14e5e2147..70adb7cd9 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -1204,7 +1204,7 @@ def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): mock_retrieve_response = mocker.patch( "app.endpoints.query.retrieve_response", - return_value=("test response", "test_conversation_id"), + return_value=("test response", "test_conversation_id", {"input_tokens": 10, "output_tokens": 20}), ) mocker.patch("app.endpoints.query.select_model_id", return_value="test_model") diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index ffe1ba9f1..5ca72d42b 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1224,7 +1224,7 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): mock_streaming_response.__aiter__.return_value = iter([]) mock_retrieve_response = mocker.patch( "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "test_conversation_id"), + return_value=(mock_streaming_response, "test_conversation_id", {"input_tokens": 10, "output_tokens": 20}), ) mocker.patch( diff --git a/tests/unit/utils/test_token_counter.py b/tests/unit/utils/test_token_counter.py index c839ee0e5..507accff7 100644 --- a/tests/unit/utils/test_token_counter.py +++ b/tests/unit/utils/test_token_counter.py @@ -2,6 +2,7 @@ from utils.token_counter import TokenCounter from llama_stack_client.types import UserMessage, CompletionMessage +from models.requests import QueryRequest, Attachment from configuration import AppConfig @@ -71,20 +72,78 @@ def test_count_conversation_turn_tokens(self): # First conversation should accumulate tokens result1 = counter.count_conversation_turn_tokens( - "conv1", "System", "Hello", "Hi" + "conv1", "System", QueryRequest(query="Hello"), "Hi" ) assert result1["input_tokens"] == 14 result2 = counter.count_conversation_turn_tokens( - "conv1", "System", "How are you?", "Good" + "conv1", "System", QueryRequest(query="How are you?"), "Good" ) assert result2["input_tokens"] == 31 result3 = counter.count_conversation_turn_tokens( - "conv1", "System", "Fantastic", "Yup" + "conv1", "System", QueryRequest(query="Fantastic"), "Yup" ) assert result3["input_tokens"] == 45 # Second conversation should be independent of the first result4 = counter.count_conversation_turn_tokens( - "conv2", "System", "Hello", "Hi" + "conv2", "System", QueryRequest(query="Hello"), "Hi" ) assert result4["input_tokens"] == 14 + + def test_count_conversation_turn_tokens_with_attachments(self): + """Test conversation turn token counting with 2 attachments.""" + counter = TokenCounter("llama3.2:1b") + + # Create 2 attachments + attachments = [ + Attachment( + attachment_type="log", + content_type="text/plain", + content="This is a log file with some error messages", + ), + Attachment( + attachment_type="configuration", + content_type="application/yaml", + 
content="kind: Pod\nmetadata:\n name: test-pod\nspec:\n containers:\n - name: app", + ), + ] + + query_request = QueryRequest( + query="Analyze these files for me", + attachments=attachments + ) + + # Test the conversation turn with attachments + result = counter.count_conversation_turn_tokens( + "conv_with_attachments", + "System prompt", + query_request, + "Analysis complete" + ) + + # Verify that the result contains the expected structure + assert "input_tokens" in result + assert "output_tokens" in result + + # The input tokens should include: + # - System message tokens + # - User query tokens + # - 2 attachment content tokens + # - Role formatting overhead for each message + # - Conversation formatting overhead + assert result["input_tokens"] > 0 + + # Output tokens should be the response content + assert result["output_tokens"] > 0 + + # Verify that attachments increase the token count compared to no attachments + query_request_no_attachments = QueryRequest(query="Analyze these files for me") + result_no_attachments = counter.count_conversation_turn_tokens( + "conv_no_attachments", + "System prompt", + query_request_no_attachments, + "Analysis complete" + ) + + # The version with attachments should have more input tokens + assert result["input_tokens"] > result_no_attachments["input_tokens"] From cd5bb035ff39d0d76de2c14022ef35f7d4ac6335 Mon Sep 17 00:00:00 2001 From: Chris Sibbitt Date: Thu, 17 Jul 2025 11:12:26 -0400 Subject: [PATCH 5/7] Lint picking --- src/utils/token_counter.py | 13 ++++++++++--- tests/unit/app/endpoints/test_query.py | 8 ++++++-- .../unit/app/endpoints/test_streaming_query.py | 11 ++++++++--- tests/unit/utils/test_token_counter.py | 18 +++++++++--------- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/src/utils/token_counter.py b/src/utils/token_counter.py index f162238da..b488bf81d 100644 --- a/src/utils/token_counter.py +++ b/src/utils/token_counter.py @@ -11,6 +11,8 @@ from typing import Sequence from cachetools import TTLCache # type: ignore +import tiktoken + from llama_stack_client.types import ( UserMessage, @@ -19,7 +21,6 @@ CompletionMessage, ) from models.requests import QueryRequest -import tiktoken from configuration import configuration, AppConfig from constants import DEFAULT_ESTIMATION_TOKENIZER @@ -127,7 +128,11 @@ def count_turn_tokens( } def count_conversation_turn_tokens( - self, conversation_id: str, system_prompt: str, query_request: QueryRequest, response: str = "" + self, + conversation_id: str, + system_prompt: str, + query_request: QueryRequest, + response: str = "", ) -> dict[str, int]: """Count tokens for a conversation turn with cumulative tracking. 
@@ -148,7 +153,9 @@ def count_conversation_turn_tokens( - 'output_tokens': Total tokens in the response message """ # Get the current turn's token usage - turn_token_usage = self.count_turn_tokens(system_prompt, query_request, response) + turn_token_usage = self.count_turn_tokens( + system_prompt, query_request, response + ) # Get cumulative input tokens for this conversation cumulative_input_tokens = _conversation_cache.get(conversation_id, 0) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 70adb7cd9..525ead1b6 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -760,7 +760,7 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): }, } - response, conversation_id, token_usage = retrieve_response( + response, conversation_id, _ = retrieve_response( mock_client, model_id, query_request, @@ -1204,7 +1204,11 @@ def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): mock_retrieve_response = mocker.patch( "app.endpoints.query.retrieve_response", - return_value=("test response", "test_conversation_id", {"input_tokens": 10, "output_tokens": 20}), + return_value=( + "test response", + "test_conversation_id", + {"input_tokens": 10, "output_tokens": 20}, + ), ) mocker.patch("app.endpoints.query.select_model_id", return_value="test_model") diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 5ca72d42b..67c0dd9f1 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -151,7 +151,8 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) # Mock the streaming response from LLama Stack mock_streaming_response = mocker.AsyncMock() - # Currently usage is not returned by the API, we simulate by using del to prevent pytest from returning a Mock + # Currently usage is not returned by the API + # we simulate by using del to prevent pytest from returning a Mock del mock_streaming_response.usage mock_streaming_response.__aiter__.return_value = [ mocker.Mock( @@ -862,7 +863,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker): }, } - response, conversation_id, token_usage = await retrieve_response( + response, conversation_id, _ = await retrieve_response( mock_client, model_id, query_request, @@ -1224,7 +1225,11 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): mock_streaming_response.__aiter__.return_value = iter([]) mock_retrieve_response = mocker.patch( "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "test_conversation_id", {"input_tokens": 10, "output_tokens": 20}), + return_value=( + mock_streaming_response, + "test_conversation_id", + {"input_tokens": 10, "output_tokens": 20}, + ), ) mocker.patch( diff --git a/tests/unit/utils/test_token_counter.py b/tests/unit/utils/test_token_counter.py index 507accff7..ce68e0f18 100644 --- a/tests/unit/utils/test_token_counter.py +++ b/tests/unit/utils/test_token_counter.py @@ -1,7 +1,8 @@ """Unit tests for token counter utilities.""" -from utils.token_counter import TokenCounter from llama_stack_client.types import UserMessage, CompletionMessage + +from utils.token_counter import TokenCounter from models.requests import QueryRequest, Attachment from configuration import AppConfig @@ -31,6 +32,7 @@ class TestTokenCounter: """Test cases for TokenCounter class.""" def setup_class(self): + 
"""Setup the test class.""" cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -40,6 +42,7 @@ def test_count_tokens_empty_string(self): assert counter.count_tokens("") == 0 def test_count_tokens_simple(self): + """Test counting tokens for a simple message.""" counter = TokenCounter("llama3.2:1b") assert counter.count_tokens("Hello World!") == 3 @@ -104,21 +107,18 @@ def test_count_conversation_turn_tokens_with_attachments(self): Attachment( attachment_type="configuration", content_type="application/yaml", - content="kind: Pod\nmetadata:\n name: test-pod\nspec:\n containers:\n - name: app", + content="kind: Pod\nmetadata:\n name: test-pod\nspec:\n" + + " containers:\n - name: app\n image: nginx:latest", ), ] query_request = QueryRequest( - query="Analyze these files for me", - attachments=attachments + query="Analyze these files for me", attachments=attachments ) # Test the conversation turn with attachments result = counter.count_conversation_turn_tokens( - "conv_with_attachments", - "System prompt", - query_request, - "Analysis complete" + "conv_with_attachments", "System prompt", query_request, "Analysis complete" ) # Verify that the result contains the expected structure @@ -142,7 +142,7 @@ def test_count_conversation_turn_tokens_with_attachments(self): "conv_no_attachments", "System prompt", query_request_no_attachments, - "Analysis complete" + "Analysis complete", ) # The version with attachments should have more input tokens From 3c5a3e1d1c21834b978102bd66a2ed4d7e9a56c4 Mon Sep 17 00:00:00 2001 From: Chris Sibbitt Date: Thu, 17 Jul 2025 13:43:28 -0400 Subject: [PATCH 6/7] Fix attachments counting to be closer to how they are handled in llama-stack --- src/utils/token_counter.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/utils/token_counter.py b/src/utils/token_counter.py index b488bf81d..e11d03c40 100644 --- a/src/utils/token_counter.py +++ b/src/utils/token_counter.py @@ -109,13 +109,11 @@ def count_turn_tokens( input_messages.append( SystemMessage(role="system", content=str(system_prompt)) ) - input_messages.append(UserMessage(role="user", content=query_request.query)) - + input_content = query_request.query if query_request.attachments: for attachment in query_request.attachments: - input_messages.append( - UserMessage(role="user", content=attachment.content) - ) + input_content += "\n" + attachment.content + input_messages.append(UserMessage(role="user", content=input_content)) input_tokens = self.count_message_tokens(input_messages) output_tokens = self.count_tokens(response) From 58182975ccd5a5ed900419a3306294fc7bcafec1 Mon Sep 17 00:00:00 2001 From: Chris Sibbitt Date: Thu, 17 Jul 2025 13:48:09 -0400 Subject: [PATCH 7/7] More optimistic version pin for tiktoken --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 91a558245..287924185 100644 --- a/uv.lock +++ b/uv.lock @@ -869,7 +869,7 @@ requires-dist = [ { name = "kubernetes", specifier = ">=30.1.0" }, { name = "llama-stack", specifier = ">=0.2.13" }, { name = "rich", specifier = ">=14.0.0" }, - { name = "tiktoken", specifier = ">=0.9.0,<1.0.0" }, + { name = "tiktoken", specifier = ">=0.9.0" }, { name = "uvicorn", specifier = ">=0.34.3" }, ]