diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py
index e6187bfa..6679b8f5 100644
--- a/src/app/endpoints/a2a.py
+++ b/src/app/endpoints/a2a.py
@@ -5,7 +5,7 @@
 import logging
 import uuid
 from datetime import datetime, timezone
-from typing import Annotated, Any, AsyncIterator, MutableMapping
+from typing import Annotated, Any, AsyncIterator, MutableMapping, Optional

 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from llama_stack.apis.agents.openai_responses import (
@@ -65,8 +65,8 @@
 # Task store and context store are created lazily based on configuration.
 # For multi-worker deployments, configure 'a2a_state' with 'sqlite' or 'postgres'
 # to share state across workers.
-_TASK_STORE: TaskStore | None = None
-_CONTEXT_STORE: A2AContextStore | None = None
+_TASK_STORE: Optional[TaskStore] = None
+_CONTEXT_STORE: Optional[A2AContextStore] = None


 async def _get_task_store() -> TaskStore:
@@ -120,7 +120,7 @@ class TaskResultAggregator:
     def __init__(self) -> None:
         """Initialize the task result aggregator with default state."""
         self._task_state: TaskState = TaskState.working
-        self._task_status_message: Message | None = None
+        self._task_status_message: Optional[Message] = None

     def process_event(
         self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Any
@@ -169,7 +169,7 @@ def task_state(self) -> TaskState:
         return self._task_state

     @property
-    def task_status_message(self) -> Message | None:
+    def task_status_message(self) -> Optional[Message]:
         """Return the current task status message."""
         return self._task_status_message

@@ -185,7 +185,7 @@ class A2AAgentExecutor(AgentExecutor):
     """

     def __init__(
-        self, auth_token: str, mcp_headers: dict[str, dict[str, str]] | None = None
+        self, auth_token: str, mcp_headers: Optional[dict[str, dict[str, str]]] = None
     ):
         """Initialize the A2A agent executor.

@@ -413,7 +413,7 @@ async def _convert_stream_to_events(  # pylint: disable=too-many-branches,too-ma
     stream: AsyncIterator[OpenAIResponseObjectStream],
     task_id: str,
     context_id: str,
-    conversation_id: str | None,
+    conversation_id: Optional[str],
 ) -> AsyncIterator[Any]:
     """Convert Responses API stream chunks to A2A events.
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 952c0a0b..5a9af396 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -156,12 +156,12 @@ def persist_user_conversation_details(


 def evaluate_model_hints(
-    user_conversation: UserConversation | None,
+    user_conversation: Optional[UserConversation],
     query_request: QueryRequest,
-) -> tuple[str | None, str | None]:
+) -> tuple[Optional[str], Optional[str]]:
     """Evaluate model hints from user conversation."""
-    model_id: str | None = query_request.model
-    provider_id: str | None = query_request.provider
+    model_id: Optional[str] = query_request.model
+    provider_id: Optional[str] = query_request.provider

     if user_conversation is not None:
         if query_request.model is not None:
@@ -271,7 +271,7 @@ async def query_endpoint_handler_base(  # pylint: disable=R0914
     user_id, _, _skip_userid_check, token = auth
     started_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")

-    user_conversation: UserConversation | None = None
+    user_conversation: Optional[UserConversation] = None
     if query_request.conversation_id:
         logger.debug(
             "Conversation ID specified in query: %s", query_request.conversation_id
@@ -483,7 +483,7 @@ async def query_endpoint_handler(


 def select_model_and_provider_id(
-    models: ModelListResponse, model_id: str | None, provider_id: str | None
+    models: ModelListResponse, model_id: Optional[str], provider_id: Optional[str]
 ) -> tuple[str, str, str]:
     """
     Select the model ID and provider ID based on the request or available models.
@@ -663,7 +663,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
     model_id: str,
     query_request: QueryRequest,
     token: str,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
     *,
     provider_id: str = "",
 ) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]:
@@ -859,7 +859,7 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:


 def get_rag_toolgroups(
     vector_db_ids: list[str],
-) -> list[Toolgroup] | None:
+) -> Optional[list[Toolgroup]]:
     """
     Return a list of RAG Tool groups if the given vector DB list is not empty.
@@ -870,7 +870,7 @@ def get_rag_toolgroups(
         vector_db_ids (list[str]): List of vector database identifiers
             to include in the toolgroup.

     Returns:
-        list[Toolgroup] | None: A list with a single RAG toolgroup if
+        Optional[list[Toolgroup]]: A list with a single RAG toolgroup if
             vector_db_ids is non-empty; otherwise, None.
     """
     return (
diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py
index dd4eef8f..69422c42 100644
--- a/src/app/endpoints/query_v2.py
+++ b/src/app/endpoints/query_v2.py
@@ -4,7 +4,7 @@

 import json
 import logging
-from typing import Annotated, Any, cast
+from typing import Annotated, Any, Optional, cast

 from fastapi import APIRouter, Depends, Request
 from llama_stack.apis.agents.openai_responses import (
@@ -74,7 +74,7 @@
 def _build_tool_call_summary(  # pylint: disable=too-many-return-statements,too-many-branches
     output_item: Any,
-) -> tuple[ToolCallSummary | None, ToolResultSummary | None]:
+) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]:
     """Translate applicable Responses API tool outputs into ``ToolCallSummary`` records.

     The OpenAI ``response.output`` array may contain any ``OpenAIResponseOutput`` variant:
@@ -110,7 +110,7 @@ def _build_tool_call_summary(  # pylint: disable=too-many-return-statements,too-
             "status": getattr(output_item, "status", None),
         }
         results = getattr(output_item, "results", None)
-        response_payload: Any | None = None
+        response_payload: Optional[Any] = None
         if results is not None:
             # Store only the essential result metadata to avoid large payloads
             response_payload = {
@@ -294,7 +294,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
     model_id: str,
     query_request: QueryRequest,
     token: str,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
     *,
     provider_id: str = "",
 ) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]:
@@ -505,7 +505,7 @@ def parse_referenced_documents_from_responses_api(
     """
     documents: list[ReferencedDocument] = []
     # Use a set to track unique documents by (doc_url, doc_title) tuple
-    seen_docs: set[tuple[str | None, str | None]] = set()
+    seen_docs: set[tuple[Optional[str], Optional[str]]] = set()

     if not response.output:
         return documents
@@ -535,7 +535,7 @@
             # If we have at least a filename or url
             if filename or doc_url:
-                # Treat empty string as None for URL to satisfy AnyUrl | None
+                # Treat empty string as None for URL to satisfy Optional[AnyUrl]
                 final_url = doc_url if doc_url else None
                 if (final_url, filename) not in seen_docs:
                     documents.append(
@@ -692,7 +692,7 @@ def _increment_llm_call_metric(provider: str, model: str) -> None:
         logger.warning("Failed to update LLM call metric: %s", e)


-def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:
+def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]:
     """
     Convert vector store IDs to tools format for Responses API.
@@ -700,7 +700,7 @@ def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:
         vector_store_ids: List of vector store identifiers

     Returns:
-        list[dict[str, Any]] | None: List containing file_search tool configuration,
+        Optional[list[dict[str, Any]]]: List containing file_search tool configuration,
         or None if no vector stores provided
     """
     if not vector_store_ids:
@@ -717,8 +717,8 @@ def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:

 def get_mcp_tools(
     mcp_servers: list,
-    token: str | None = None,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    token: Optional[str] = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
 ) -> list[dict[str, Any]]:
     """
     Convert MCP servers to tools format for Responses API.
@@ -762,8 +762,8 @@ async def prepare_tools_for_responses_api(
     query_request: QueryRequest,
     token: str,
     config: AppConfig,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
-) -> list[dict[str, Any]] | None:
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
+) -> Optional[list[dict[str, Any]]]:
     """
     Prepare tools for Responses API including RAG and MCP tools.
@@ -778,7 +778,7 @@ async def prepare_tools_for_responses_api(
         mcp_headers: Per-request headers for MCP servers

     Returns:
-        list[dict[str, Any]] | None: List of tool configurations for the
+        Optional[list[dict[str, Any]]]: List of tool configurations for the
         Responses API, or None if no_tools is True or no tools are available
     """
     if query_request.no_tools:
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 8e526615..9bc5b7a7 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -7,7 +7,15 @@
 import uuid
 from collections.abc import Callable
 from datetime import UTC, datetime
-from typing import Annotated, Any, AsyncGenerator, AsyncIterator, Iterator, cast
+from typing import (
+    Annotated,
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Iterator,
+    Optional,
+    cast,
+)

 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import StreamingResponse
@@ -231,7 +239,7 @@ def stream_build_event(
     chunk_id: int,
     metadata_map: dict,
     media_type: str = MEDIA_TYPE_JSON,
-    conversation_id: str | None = None,
+    conversation_id: Optional[str] = None,
 ) -> Iterator[str]:
     """Build a streaming event from a chunk response.

@@ -384,7 +392,7 @@ async def stream_http_error(error: AbstractErrorResponse) -> AsyncGenerator[str,
 def _handle_turn_start_event(
     _chunk_id: int,
     media_type: str = MEDIA_TYPE_JSON,
-    conversation_id: str | None = None,
+    conversation_id: Optional[str] = None,
 ) -> Iterator[str]:
     """
     Yield turn start event.
@@ -734,7 +742,7 @@ async def response_generator(
     # Send start event at the beginning of the stream
     yield stream_start_event(context.conversation_id)

-    latest_turn: Any | None = None
+    latest_turn: Optional[Any] = None

     async for chunk in turn_response:
         if chunk.event is None:
@@ -850,7 +858,7 @@ async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-loc

     user_id, _user_name, _skip_userid_check, token = auth

-    user_conversation: UserConversation | None = None
+    user_conversation: Optional[UserConversation] = None
     if query_request.conversation_id:
         user_conversation = validate_conversation_ownership(
             user_id=user_id, conversation_id=query_request.conversation_id
@@ -1001,7 +1009,7 @@ async def retrieve_response(
     model_id: str,
     query_request: QueryRequest,
     token: str,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
 ) -> tuple[AsyncIterator[AgentTurnResponseStreamChunk], str]:
     """
     Retrieve response from LLMs and agents.
diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py
index 45e63b88..847bcacb 100644
--- a/src/app/endpoints/streaming_query_v2.py
+++ b/src/app/endpoints/streaming_query_v2.py
@@ -1,7 +1,7 @@
 """Streaming query handler using Responses API (v2)."""

 import logging
-from typing import Annotated, Any, AsyncIterator, cast
+from typing import Annotated, Any, AsyncIterator, Optional, cast

 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import StreamingResponse
@@ -138,7 +138,7 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
     start_event_emitted = False

     # Track the latest response object from response.completed event
-    latest_response_object: Any | None = None
+    latest_response_object: Optional[Any] = None

     logger.debug("Starting streaming response (Responses API) processing")

@@ -372,7 +372,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals
     model_id: str,
     query_request: QueryRequest,
     token: str,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
 ) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]:
     """
     Retrieve response from LLMs and agents.
@@ -471,7 +471,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals


 async def create_violation_stream(
     message: str,
-    shield_model: str | None = None,
+    shield_model: Optional[str] = None,
 ) -> AsyncIterator[OpenAIResponseObjectStream]:
     """Generate a minimal streaming response for cases where input is blocked by a shield.
diff --git a/src/utils/types.py b/src/utils/types.py
index a915285b..06f8d9e6 100644
--- a/src/utils/types.py
+++ b/src/utils/types.py
@@ -104,8 +104,8 @@ class ShieldModerationResult(BaseModel):
     """Result of shield moderation check."""

     blocked: bool
-    message: str | None = None
-    shield_model: str | None = None
+    message: Optional[str] = None
+    shield_model: Optional[str] = None


 class ToolCallSummary(BaseModel):
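
Note on the annotation style used throughout this patch (not part of the diff itself): swapping the PEP 604 spelling X | None for typing.Optional[X] is, presumably, about supporting interpreters older than Python 3.10, where evaluating X | None in an annotation at definition time raises TypeError unless "from __future__ import annotations" is in effect. A minimal sketch of the equivalence, using a hypothetical helper name:

from typing import Optional


def lookup_title(doc_url: Optional[str] = None) -> Optional[str]:
    """Hypothetical helper: Optional[str] means 'str or None'."""
    # Optional[str] is shorthand for Union[str, None] and evaluates fine on Python 3.9.
    if doc_url is None:
        return None
    # Return the last path segment of the URL as a title.
    return doc_url.rsplit("/", maxsplit=1)[-1]


# Equivalent PEP 604 spelling, valid at runtime only on Python 3.10+
# (or with "from __future__ import annotations"):
#     def lookup_title(doc_url: str | None = None) -> str | None: ...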