14 changes: 7 additions & 7 deletions src/app/endpoints/a2a.py
@@ -5,7 +5,7 @@
import logging
import uuid
from datetime import datetime, timezone
-from typing import Annotated, Any, AsyncIterator, MutableMapping
+from typing import Annotated, Any, AsyncIterator, MutableMapping, Optional

from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_stack.apis.agents.openai_responses import (
@@ -65,8 +65,8 @@
# Task store and context store are created lazily based on configuration.
# For multi-worker deployments, configure 'a2a_state' with 'sqlite' or 'postgres'
# to share state across workers.
-_TASK_STORE: TaskStore | None = None
-_CONTEXT_STORE: A2AContextStore | None = None
+_TASK_STORE: Optional[TaskStore] = None
+_CONTEXT_STORE: Optional[A2AContextStore] = None


async def _get_task_store() -> TaskStore:
@@ -120,7 +120,7 @@ class TaskResultAggregator:
def __init__(self) -> None:
"""Initialize the task result aggregator with default state."""
self._task_state: TaskState = TaskState.working
-self._task_status_message: Message | None = None
+self._task_status_message: Optional[Message] = None

def process_event(
self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Any
@@ -169,7 +169,7 @@ def task_state(self) -> TaskState:
return self._task_state

@property
-def task_status_message(self) -> Message | None:
+def task_status_message(self) -> Optional[Message]:
"""Return the current task status message."""
return self._task_status_message

@@ -185,7 +185,7 @@ class A2AAgentExecutor(AgentExecutor):
"""

def __init__(
-self, auth_token: str, mcp_headers: dict[str, dict[str, str]] | None = None
+self, auth_token: str, mcp_headers: Optional[dict[str, dict[str, str]]] = None
):
"""Initialize the A2A agent executor.

@@ -413,7 +413,7 @@ async def _convert_stream_to_events( # pylint: disable=too-many-branches,too-ma
stream: AsyncIterator[OpenAIResponseObjectStream],
task_id: str,
context_id: str,
-conversation_id: str | None,
+conversation_id: Optional[str],
) -> AsyncIterator[Any]:
"""Convert Responses API stream chunks to A2A events.

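Note on the a2a.py change above: the body of `_get_task_store` is collapsed in this diff, so only the `Optional[...]` module globals are visible. Below is a minimal sketch of the lazy-initialization pattern the comment describes; the placeholder store class and construction logic are assumptions for illustration, not the repository's actual code, which builds the store from the 'a2a_state' configuration.

```python
from typing import Optional


class InMemoryTaskStore:
    """Placeholder for the A2A task store; the real TaskStore type comes from the A2A SDK."""


_TASK_STORE: Optional[InMemoryTaskStore] = None


async def _get_task_store() -> InMemoryTaskStore:
    """Create the task store on first use so configuration is read lazily."""
    global _TASK_STORE  # pylint: disable=global-statement
    if _TASK_STORE is None:
        # A multi-worker deployment would construct a shared sqlite/postgres
        # backed store here instead of an in-memory one.
        _TASK_STORE = InMemoryTaskStore()
    return _TASK_STORE
```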
18 changes: 9 additions & 9 deletions src/app/endpoints/query.py
@@ -156,12 +156,12 @@ def persist_user_conversation_details(


def evaluate_model_hints(
-user_conversation: UserConversation | None,
+user_conversation: Optional[UserConversation],
query_request: QueryRequest,
-) -> tuple[str | None, str | None]:
+) -> tuple[Optional[str], Optional[str]]:
"""Evaluate model hints from user conversation."""
-model_id: str | None = query_request.model
-provider_id: str | None = query_request.provider
+model_id: Optional[str] = query_request.model
+provider_id: Optional[str] = query_request.provider

if user_conversation is not None:
if query_request.model is not None:
@@ -271,7 +271,7 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
user_id, _, _skip_userid_check, token = auth

started_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
-user_conversation: UserConversation | None = None
+user_conversation: Optional[UserConversation] = None
if query_request.conversation_id:
logger.debug(
"Conversation ID specified in query: %s", query_request.conversation_id
@@ -483,7 +483,7 @@ async def query_endpoint_handler(


def select_model_and_provider_id(
-models: ModelListResponse, model_id: str | None, provider_id: str | None
+models: ModelListResponse, model_id: Optional[str], provider_id: Optional[str]
) -> tuple[str, str, str]:
"""
Select the model ID and provider ID based on the request or available models.
@@ -663,7 +663,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
model_id: str,
query_request: QueryRequest,
token: str,
-mcp_headers: dict[str, dict[str, str]] | None = None,
+mcp_headers: Optional[dict[str, dict[str, str]]] = None,
*,
provider_id: str = "",
) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]:
@@ -859,7 +859,7 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:

def get_rag_toolgroups(
vector_db_ids: list[str],
-) -> list[Toolgroup] | None:
+) -> Optional[list[Toolgroup]]:
"""
Return a list of RAG Tool groups if the given vector DB list is not empty.

@@ -870,7 +870,7 @@ def get_rag_toolgroups(
vector_db_ids (list[str]): List of vector database identifiers to include in the toolgroup.

Returns:
-list[Toolgroup] | None: A list with a single RAG toolgroup if
+Optional[list[Toolgroup]]: A list with a single RAG toolgroup if
vector_db_ids is non-empty; otherwise, None.
"""
return (
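The query.py hunks show only the signature and first assignments of `evaluate_model_hints`; the rest of the function is collapsed. As a rough sketch of how such a helper can resolve model/provider hints — explicit request values first, stored conversation values as a fallback — consider the following. The `StoredConversation` field names and the fallback order are assumptions for illustration, not the repository's actual logic.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StoredConversation:
    """Stand-in for UserConversation; these field names are assumed."""

    last_used_model: Optional[str] = None
    last_used_provider: Optional[str] = None


def resolve_model_hints(
    user_conversation: Optional[StoredConversation],
    requested_model: Optional[str],
    requested_provider: Optional[str],
) -> tuple[Optional[str], Optional[str]]:
    """Prefer explicit request values, otherwise fall back to the stored conversation."""
    model_id: Optional[str] = requested_model
    provider_id: Optional[str] = requested_provider
    if user_conversation is not None:
        if model_id is None:
            model_id = user_conversation.last_used_model
        if provider_id is None:
            provider_id = user_conversation.last_used_provider
    return model_id, provider_id
```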
26 changes: 13 additions & 13 deletions src/app/endpoints/query_v2.py
@@ -4,7 +4,7 @@

import json
import logging
-from typing import Annotated, Any, cast
+from typing import Annotated, Any, Optional, cast

from fastapi import APIRouter, Depends, Request
from llama_stack.apis.agents.openai_responses import (
@@ -74,7 +74,7 @@

def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-many-branches
output_item: Any,
-) -> tuple[ToolCallSummary | None, ToolResultSummary | None]:
+) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]:
"""Translate applicable Responses API tool outputs into ``ToolCallSummary`` records.

The OpenAI ``response.output`` array may contain any ``OpenAIResponseOutput`` variant:
@@ -110,7 +110,7 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-
"status": getattr(output_item, "status", None),
}
results = getattr(output_item, "results", None)
-response_payload: Any | None = None
+response_payload: Optional[Any] = None
if results is not None:
# Store only the essential result metadata to avoid large payloads
response_payload = {
@@ -294,7 +294,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
model_id: str,
query_request: QueryRequest,
token: str,
-mcp_headers: dict[str, dict[str, str]] | None = None,
+mcp_headers: Optional[dict[str, dict[str, str]]] = None,
*,
provider_id: str = "",
) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]:
@@ -505,7 +505,7 @@ def parse_referenced_documents_from_responses_api(
"""
documents: list[ReferencedDocument] = []
# Use a set to track unique documents by (doc_url, doc_title) tuple
-seen_docs: set[tuple[str | None, str | None]] = set()
+seen_docs: set[tuple[Optional[str], Optional[str]]] = set()

if not response.output:
return documents
@@ -535,7 +535,7 @@

# If we have at least a filename or url
if filename or doc_url:
-# Treat empty string as None for URL to satisfy AnyUrl | None
+# Treat empty string as None for URL to satisfy Optional[AnyUrl]
final_url = doc_url if doc_url else None
if (final_url, filename) not in seen_docs:
documents.append(
@@ -692,15 +692,15 @@ def _increment_llm_call_metric(provider: str, model: str) -> None:
logger.warning("Failed to update LLM call metric: %s", e)


-def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:
+def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]:
"""
Convert vector store IDs to tools format for Responses API.

Args:
vector_store_ids: List of vector store identifiers

Returns:
-list[dict[str, Any]] | None: List containing file_search tool configuration,
+Optional[list[dict[str, Any]]]: List containing file_search tool configuration,
or None if no vector stores provided
"""
if not vector_store_ids:
@@ -717,8 +717,8 @@ def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:

def get_mcp_tools(
mcp_servers: list,
-token: str | None = None,
-mcp_headers: dict[str, dict[str, str]] | None = None,
+token: Optional[str] = None,
+mcp_headers: Optional[dict[str, dict[str, str]]] = None,
) -> list[dict[str, Any]]:
"""
Convert MCP servers to tools format for Responses API.
@@ -762,8 +762,8 @@ async def prepare_tools_for_responses_api(
query_request: QueryRequest,
token: str,
config: AppConfig,
-mcp_headers: dict[str, dict[str, str]] | None = None,
-) -> list[dict[str, Any]] | None:
+mcp_headers: Optional[dict[str, dict[str, str]]] = None,
+) -> Optional[list[dict[str, Any]]]:
"""
Prepare tools for Responses API including RAG and MCP tools.

@@ -778,7 +778,7 @@ async def prepare_tools_for_responses_api(
mcp_headers: Per-request headers for MCP servers

Returns:
-list[dict[str, Any]] | None: List of tool configurations for the
+Optional[list[dict[str, Any]]]: List of tool configurations for the
Responses API, or None if no_tools is True or no tools are available
"""
if query_request.no_tools:
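For context on the `get_rag_tools` change above: the function body is mostly collapsed, but its docstring says it returns a single file_search tool configuration or None. A minimal sketch under that reading follows; the exact dictionary keys are an assumption based on the Responses API's file_search tool shape, not copied from the repository.

```python
from typing import Any, Optional


def get_rag_tools_sketch(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]:
    """Return a one-element list with a file_search tool, or None when there is nothing to search."""
    if not vector_store_ids:
        return None
    return [
        {
            "type": "file_search",
            "vector_store_ids": vector_store_ids,
        }
    ]
```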
20 changes: 14 additions & 6 deletions src/app/endpoints/streaming_query.py
@@ -7,7 +7,15 @@
import uuid
from collections.abc import Callable
from datetime import UTC, datetime
-from typing import Annotated, Any, AsyncGenerator, AsyncIterator, Iterator, cast
+from typing import (
+    Annotated,
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Iterator,
+    Optional,
+    cast,
+)

from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
@@ -231,7 +239,7 @@ def stream_build_event(
chunk_id: int,
metadata_map: dict,
media_type: str = MEDIA_TYPE_JSON,
-conversation_id: str | None = None,
+conversation_id: Optional[str] = None,
) -> Iterator[str]:
"""Build a streaming event from a chunk response.

@@ -384,7 +392,7 @@ async def stream_http_error(error: AbstractErrorResponse) -> AsyncGenerator[str,
def _handle_turn_start_event(
_chunk_id: int,
media_type: str = MEDIA_TYPE_JSON,
-conversation_id: str | None = None,
+conversation_id: Optional[str] = None,
) -> Iterator[str]:
"""
Yield turn start event.
@@ -734,7 +742,7 @@ async def response_generator(
# Send start event at the beginning of the stream
yield stream_start_event(context.conversation_id)

-latest_turn: Any | None = None
+latest_turn: Optional[Any] = None

async for chunk in turn_response:
if chunk.event is None:
@@ -850,7 +858,7 @@ async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-loc

user_id, _user_name, _skip_userid_check, token = auth

-user_conversation: UserConversation | None = None
+user_conversation: Optional[UserConversation] = None
if query_request.conversation_id:
user_conversation = validate_conversation_ownership(
user_id=user_id, conversation_id=query_request.conversation_id
@@ -1001,7 +1009,7 @@ async def retrieve_response(
model_id: str,
query_request: QueryRequest,
token: str,
-mcp_headers: dict[str, dict[str, str]] | None = None,
+mcp_headers: Optional[dict[str, dict[str, str]]] = None,
) -> tuple[AsyncIterator[AgentTurnResponseStreamChunk], str]:
"""
Retrieve response from LLMs and agents.
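The streaming_query.py hunks show only fragments of `response_generator` (the start event, the `latest_turn` tracker, and the per-chunk loop). A simplified, self-contained sketch of that shape is below; the SSE payloads and inline formatting are assumptions, since the real code builds its events through `stream_start_event` and `stream_build_event`.

```python
import json
from typing import Any, AsyncIterator, Optional


async def response_generator_sketch(
    turn_response: AsyncIterator[Any],
    conversation_id: Optional[str] = None,
) -> AsyncIterator[str]:
    """Yield SSE-style events while remembering the latest turn seen in the stream."""
    yield f"data: {json.dumps({'event': 'start', 'conversation_id': conversation_id})}\n\n"
    latest_turn: Optional[Any] = None
    async for chunk in turn_response:
        event = getattr(chunk, "event", None)
        if event is None:
            # Skip malformed chunks, mirroring the `chunk.event is None` guard above.
            continue
        latest_turn = event
        yield f"data: {json.dumps({'event': 'token'})}\n\n"
    yield f"data: {json.dumps({'event': 'end', 'had_turn': latest_turn is not None})}\n\n"
```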
8 changes: 4 additions & 4 deletions src/app/endpoints/streaming_query_v2.py
@@ -1,7 +1,7 @@
"""Streaming query handler using Responses API (v2)."""

import logging
-from typing import Annotated, Any, AsyncIterator, cast
+from typing import Annotated, Any, AsyncIterator, Optional, cast

from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
@@ -138,7 +138,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
start_event_emitted = False

# Track the latest response object from response.completed event
-latest_response_object: Any | None = None
+latest_response_object: Optional[Any] = None

logger.debug("Starting streaming response (Responses API) processing")

@@ -372,7 +372,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
model_id: str,
query_request: QueryRequest,
token: str,
-mcp_headers: dict[str, dict[str, str]] | None = None,
+mcp_headers: Optional[dict[str, dict[str, str]]] = None,
) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]:
"""
Retrieve response from LLMs and agents.
@@ -471,7 +471,7 @@ async def retrieve_response( # pylint: disable=too-many-locals

async def create_violation_stream(
message: str,
-shield_model: str | None = None,
+shield_model: Optional[str] = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Generate a minimal streaming response for cases where input is blocked by a shield.

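`create_violation_stream` above appears only by its signature; its docstring says it produces a minimal stream when a shield blocks the input. A toy version of that idea follows — it yields plain dicts rather than `OpenAIResponseObjectStream` chunks, and the event shape is an assumption for illustration only.

```python
from typing import Any, AsyncIterator, Optional


async def create_violation_stream_sketch(
    message: str,
    shield_model: Optional[str] = None,
) -> AsyncIterator[dict[str, Any]]:
    """Yield a single refusal-style event for input that a shield has blocked."""
    yield {
        "type": "response.completed",
        "refusal": message,
        "shield_model": shield_model,
    }
```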
4 changes: 2 additions & 2 deletions src/utils/types.py
@@ -104,8 +104,8 @@ class ShieldModerationResult(BaseModel):
"""Result of shield moderation check."""

blocked: bool
-message: str | None = None
-shield_model: str | None = None
+message: Optional[str] = None
+shield_model: Optional[str] = None


class ToolCallSummary(BaseModel):
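The types.py hunk is the simplest illustration of the whole change set: `Optional[X]` and `X | None` describe the same type to static checkers, but the `|` union syntax is only valid at runtime on Python 3.10+, which matters for Pydantic models whose annotations are evaluated when the class is created. The snippet below mirrors the updated `ShieldModerationResult`; the usage lines are illustrative and assume Pydantic is installed.

```python
from typing import Optional

from pydantic import BaseModel


class ShieldModerationResult(BaseModel):
    """Mirror of the updated model in src/utils/types.py."""

    blocked: bool
    message: Optional[str] = None
    shield_model: Optional[str] = None


# Only `blocked` is required; the Optional fields default to None.
result = ShieldModerationResult(blocked=True, message="Prompt rejected by shield")
print(result)  # blocked=True message='Prompt rejected by shield' shield_model=None
```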