Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ async def _register_mcp_toolgroups_async(
This function performs network calls against the provided async client and does not
catch exceptions raised by those calls — any exceptions from the client (e.g., RPC
or HTTP errors) will propagate to the caller.

Parameters:
client (AsyncLlamaStackClient): The LlamaStack async client used to
query and register toolgroups.
mcp_servers (List[ModelContextProtocolServer]): MCP server descriptors
to ensure are registered.
logger (Logger): Logger used for debug messages about registration
progress.
"""
# Get registered tools
registered_toolgroups = await client.toolgroups.list()
Expand Down Expand Up @@ -101,6 +109,10 @@ def run_once_async(func: Callable) -> Callable:
Later invocations return/await the same Task, receiving the same result or
propagated exception. Requires an active running event loop when the
wrapped function is first called.

Returns:
        Any: The result produced by the wrapped coroutine; any exception it
        raises propagates to all callers.
"""
task = None

Expand All @@ -114,6 +126,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
Subsequent calls return the same awaited task result. Exceptions raised
by the task propagate to callers. Requires an active running event loop
when first called.

Returns:
The awaited result of the wrapped coroutine.
"""
nonlocal task
if task is None:
Expand Down
113 changes: 105 additions & 8 deletions src/utils/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,17 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:


def get_topic_summary_system_prompt(config: AppConfig) -> str:
"""Get the topic summary system prompt."""
"""
Get the topic summary system prompt.

Parameters:
config (AppConfig): Application configuration from which to read
customization/profile settings.

Returns:
str: The topic summary system prompt from the active custom profile if
set, otherwise the default prompt.
"""
# profile takes precedence for setting prompt
if (
config.customization is not None
Expand All @@ -223,8 +233,9 @@ def validate_model_provider_override(
) -> None:
"""Validate whether model/provider overrides are allowed by RBAC.

Raises HTTP 403 if the request includes model or provider and the caller
lacks Action.MODEL_OVERRIDE permission.
Raises:
HTTPException: HTTP 403 if the request includes model or provider and
the caller lacks Action.MODEL_OVERRIDE permission.
"""
if (query_request.model is not None or query_request.provider is not None) and (
Action.MODEL_OVERRIDE not in authorized_actions
Expand All @@ -242,7 +253,24 @@ def store_conversation_into_cache(
_skip_userid_check: bool,
topic_summary: str | None,
) -> None:
"""Store one part of conversation into conversation history cache."""
"""
Store one part of conversation into conversation history cache.

If a conversation cache type is configured but the cache instance is not
initialized, the function logs a warning and returns without persisting
anything.

Parameters:
config (AppConfig): Application configuration that may contain
conversation cache settings and instance.
user_id (str): Owner identifier used as the cache key.
conversation_id (str): Conversation identifier used as the cache key.
cache_entry (CacheEntry): Entry to insert or append to the conversation history.
_skip_userid_check (bool): When true, bypasses enforcing that the cache
operation must match the user id.
topic_summary (str | None): Optional topic summary to store alongside
the conversation; ignored if None or empty.
"""
if config.conversation_cache_configuration.type is not None:
cache = config.conversation_cache
if cache is None:
Expand Down Expand Up @@ -366,10 +394,12 @@ async def get_temp_agent(
This function creates a new agent without persistence, shields, or tools.
Useful for temporary operations or one-off queries, such as validating a
question or generating a summary.
Args:

Parameters:
client: The AsyncLlamaStackClient to use for the request.
model_id: The ID of the model to use.
system_prompt: The system prompt/instructions for the agent.

Returns:
tuple[AsyncAgent, str]: A tuple containing the agent and session_id.
"""
Expand Down Expand Up @@ -412,7 +442,23 @@ def create_rag_chunks_dict(summary: TurnSummary) -> list[dict[str, Any]]:
def _process_http_source(
src: str, doc_urls: set[str]
) -> tuple[AnyUrl | None, str] | None:
"""Process HTTP source and return (doc_url, doc_title) tuple."""
"""
Process HTTP source and return (doc_url, doc_title) tuple.

Parameters:
src (str): The source URL string to process.
doc_urls (set[str]): Set of already-seen source strings; the function
will add `src` to this set when it is new.

Returns:
tuple[AnyUrl | None, str] | None: A tuple (validated_url, doc_title)
when `src` was not previously seen:
- `validated_url`: an `AnyUrl` instance if `src` is a valid URL, or
`None` if validation failed.
- `doc_title`: the last path segment of the URL or `src` if no path
segment is present.
Returns `None` if `src` was already present in `doc_urls`.
"""
if src not in doc_urls:
doc_urls.add(src)
try:
Expand All @@ -433,7 +479,29 @@ def _process_document_id(
metas_by_id: dict[str, dict[str, Any]],
metadata_map: dict[str, Any] | None,
) -> tuple[AnyUrl | None, str] | None:
"""Process document ID and return (doc_url, doc_title) tuple."""
"""
Process document ID and return (doc_url, doc_title) tuple.

Parameters:
src (str): Document identifier to process.
doc_ids (set[str]): Set of already-seen document IDs; the function adds `src` to this set.
doc_urls (set[str]): Set of already-seen document URLs; the function
adds discovered URLs to this set to avoid duplicates.
metas_by_id (dict[str, dict[str, Any]]): Mapping of document IDs to
metadata dicts that may
contain `docs_url` and
`title`.
metadata_map (dict[str, Any] | None): If provided (truthy), indicates
metadata is available and enables
metadata lookup; when falsy,
metadata lookup is skipped.

Returns:
tuple[AnyUrl | None, str] | None: `(validated_url, doc_title)` where
`validated_url` is a validated `AnyUrl` or `None` and `doc_title` is
the chosen title string; returns `None` if the `src` or its URL was
already processed.
"""
if src in doc_ids:
return None
doc_ids.add(src)
Expand Down Expand Up @@ -500,6 +568,17 @@ def _process_rag_chunks_for_documents(
Process RAG chunks and return a list of (doc_url, doc_title) tuples.

This is the core logic shared between both return formats.

Parameters:
rag_chunks (list): Iterable of RAG chunk objects; each chunk must
provide a `source` attribute (e.g., an HTTP URL or a document ID).
metadata_map (dict[str, Any] | None): Optional mapping of document IDs
to metadata dictionaries used to resolve titles and document URLs.

Returns:
list[tuple[AnyUrl | None, str]]: Ordered list of tuples where the first
element is a validated URL object or `None` (if no URL is available)
and the second element is the document title.
"""
doc_urls: set[str] = set()
doc_ids: set[str] = set()
Expand Down Expand Up @@ -547,7 +626,7 @@ def create_referenced_documents(
optional metadata enrichment, deduplication, and proper URL handling. It can return
either ReferencedDocument objects (for query endpoint) or dictionaries (for streaming).

Args:
Parameters:
rag_chunks: List of RAG chunks with source information
metadata_map: Optional mapping containing metadata about referenced documents
return_dict_format: If True, returns list of dicts; if False, returns list of
Expand Down Expand Up @@ -580,6 +659,16 @@ def create_referenced_documents_with_metadata(
Create referenced documents from RAG chunks with metadata enrichment for streaming.

This function now returns ReferencedDocument objects for consistency with the query endpoint.

Parameters:
summary (TurnSummary): Summary object containing `rag_chunks` to be processed.
metadata_map (dict[str, Any]): Metadata keyed by document id used to
derive or enrich document `doc_url` and `doc_title`.

Returns:
list[ReferencedDocument]: ReferencedDocument objects with `doc_url` and
`doc_title` populated; `doc_url` may be `None` if no valid URL could be
determined.
"""
document_entries = _process_rag_chunks_for_documents(
summary.rag_chunks, metadata_map
Expand All @@ -598,6 +687,14 @@ def create_referenced_documents_from_chunks(

This is a backward compatibility wrapper around the unified
create_referenced_documents function.

Parameters:
rag_chunks (list): List of RAG chunk entries containing source and metadata information.

Returns:
list[ReferencedDocument]: ReferencedDocument instances created from the
chunks; each contains `doc_url` (validated URL or `None`) and
`doc_title`.
"""
document_entries = _process_rag_chunks_for_documents(rag_chunks)
return [
Expand Down
4 changes: 4 additions & 0 deletions src/utils/llama_stack_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ async def check_llama_stack_version(
and maximal supported versions. Raises
InvalidLlamaStackVersionException if the detected version is
outside the supported range.

Raises:
InvalidLlamaStackVersionException: If the detected version is outside
the supported range or cannot be parsed.
"""
version_info = await client.inspect.version()
compare_versions(
Expand Down
4 changes: 2 additions & 2 deletions src/utils/mcp_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def mcp_headers_dependency(request: Request) -> dict[str, dict[str, str]]:

    MCP headers is a JSON dictionary of MCP URL paths and their respective headers

Args:
Parameters:
request (Request): The FastAPI request object.

Returns:
Expand All @@ -32,7 +32,7 @@ def extract_mcp_headers(request: Request) -> dict[str, dict[str, str]]:
If the header is missing, contains invalid JSON, or the decoded
value is not a dictionary, an empty dictionary is returned.

Args:
Parameters:
request: The FastAPI request object

Returns:
Expand Down
4 changes: 2 additions & 2 deletions src/utils/quota.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def consume_tokens(
) -> None:
"""Consume tokens from cluster and/or user quotas.

Args:
Parameters:
quota_limiters: List of quota limiter instances to consume tokens from.
user_id: Identifier of the user consuming tokens.
input_tokens: Number of input tokens to consume.
Expand All @@ -40,7 +40,7 @@ def consume_tokens(
def check_tokens_available(quota_limiters: list[QuotaLimiter], user_id: str) -> None:
"""Check if tokens are available for user.

Args:
Parameters:
quota_limiters: List of quota limiter instances to check.
user_id: Identifier of the user to check quota for.

Expand Down
2 changes: 1 addition & 1 deletion src/utils/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def extract_text_from_response_output_item(output_item: Any) -> str:
formats including string content, content arrays with text parts, and refusal
messages.

Args:
Parameters:
output_item: A Responses API output item (typically from response.output array).
Expected to have attributes like type, role, and content.

Expand Down
8 changes: 4 additions & 4 deletions src/utils/shields.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]:
"""
Discover and return available shield identifiers.

Args:
Parameters:
client: The Llama Stack client to query for available shields.

Returns:
List of shield identifiers that are available.
list[str]: List of available shield identifiers; empty if no shields are available.
"""
available_shields = [shield.identifier for shield in await client.shields.list()]
if not available_shields:
Expand All @@ -36,11 +36,11 @@ def detect_shield_violations(output_items: list[Any]) -> bool:
attributes. If a refusal is found, increments the validation error
metric and logs a warning.

Args:
Parameters:
output_items: List of output items from the LLM response to check.

Returns:
True if a shield violation was detected, False otherwise.
bool: True if a shield violation was detected, False otherwise.
"""
for output_item in output_items:
item_type = getattr(output_item, "type", None)
Expand Down
16 changes: 11 additions & 5 deletions src/utils/token_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ class TokenCounter:
llm_calls: int = 0

def __str__(self) -> str:
"""Textual representation of TokenCounter instance."""
"""
Return a human-readable summary of the token usage stored in this TokenCounter.

Returns:
            str: A formatted string containing `input_tokens`,
            `output_tokens`, `input_tokens_counted`, and `llm_calls`.
"""
return (
f"{self.__class__.__name__}: "
+ f"input_tokens: {self.input_tokens} "
Expand All @@ -47,9 +53,9 @@ def extract_token_usage_from_turn(turn: Turn, system_prompt: str = "") -> TokenC
This function uses the same tokenizer and logic as the metrics system
to ensure consistency between API responses and Prometheus metrics.

Args:
turn: The turn object containing token usage information
system_prompt: The system prompt used for the turn
Parameters:
turn (Turn): The turn object containing token usage information
system_prompt (str): The system prompt used for the turn

Returns:
TokenCounter: Token usage information
Expand Down Expand Up @@ -102,7 +108,7 @@ def extract_and_update_token_metrics(
This function combines the token counting logic with the metrics system
to ensure both API responses and Prometheus metrics are updated consistently.

Args:
Parameters:
turn: The turn object containing token usage information
model: The model identifier for metrics labeling
provider: The provider identifier for metrics labeling
Expand Down
16 changes: 12 additions & 4 deletions src/utils/tool_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ def format_tool_response(tool_dict: dict[str, Any]) -> dict[str, Any]:
tool_dict: Raw tool dictionary from Llama Stack

Returns:
Formatted tool dictionary with only required fields
dict[str, Any]: Formatted tool dictionary containing the following keys:
- identifier: tool identifier string (defaults to "").
- description: cleaned or original description string.
- parameters: list of parameter definitions (defaults to empty list).
- provider_id: provider identifier string (defaults to "").
- toolgroup_id: tool group identifier string (defaults to "").
- server_source: server source string (defaults to "").
- type: tool type string (defaults to "").
"""
# Clean up description if it contains structured metadata
description = tool_dict.get("description", "")
Expand Down Expand Up @@ -116,10 +123,11 @@ def format_tools_list(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Format a list of tools with structured description parsing.

Args:
tools: List of raw tool dictionaries
Parameters:
        tools (list[dict[str, Any]]): List of raw tool dictionaries

Returns:
List of formatted tool dictionaries
list[dict[str, Any]]: Formatted tool dictionaries with normalized
fields and cleaned descriptions.
"""
return [format_tool_response(tool) for tool in tools]
Loading
Loading