From e971aabc314bee36a79a86199613723933788be6 Mon Sep 17 00:00:00 2001 From: are-ces <195810094+are-ces@users.noreply.github.com> Date: Wed, 1 Oct 2025 11:01:36 +0200 Subject: [PATCH] Removed global vars to make compatible with uvicorn workers > 1 --- src/app/endpoints/authorized.py | 7 ++-- src/app/endpoints/config.py | 8 ++--- src/app/endpoints/conversations.py | 18 +++++----- src/app/endpoints/conversations_v2.py | 15 ++++---- src/app/endpoints/feedback.py | 16 ++++----- src/app/endpoints/health.py | 14 ++++---- src/app/endpoints/info.py | 11 +++--- src/app/endpoints/metrics.py | 13 ++++--- src/app/endpoints/models.py | 7 ++-- src/app/endpoints/query.py | 3 +- src/app/endpoints/root.py | 8 ++--- src/app/endpoints/streaming_query.py | 3 +- src/app/main.py | 52 ++++++++++++++++++--------- src/authentication/__init__.py | 19 +++++++--- src/lightspeed_stack.py | 20 ++++------- 15 files changed, 109 insertions(+), 105 deletions(-) diff --git a/src/app/endpoints/authorized.py b/src/app/endpoints/authorized.py index 8bf2a2e50..b108b2057 100644 --- a/src/app/endpoints/authorized.py +++ b/src/app/endpoints/authorized.py @@ -5,13 +5,12 @@ from fastapi import APIRouter, Depends -from authentication.interface import AuthTuple from authentication import get_auth_dependency -from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse +from authentication.interface import AuthTuple +from models.responses import AuthorizedResponse, ForbiddenResponse, UnauthorizedResponse logger = logging.getLogger(__name__) router = APIRouter(tags=["authorized"]) -auth_dependency = get_auth_dependency() authorized_responses: dict[int | str, dict[str, Any]] = { @@ -38,7 +37,7 @@ @router.post("/authorized", responses=authorized_responses) async def authorized_endpoint_handler( - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], ) -> AuthorizedResponse: """ Handle request to the /authorized endpoint. 
diff --git a/src/app/endpoints/config.py b/src/app/endpoints/config.py index 99c1104d5..a68f8707b 100644 --- a/src/app/endpoints/config.py +++ b/src/app/endpoints/config.py @@ -3,10 +3,10 @@ import logging from typing import Annotated, Any -from fastapi import APIRouter, Request, Depends +from fastapi import APIRouter, Depends, Request -from authentication.interface import AuthTuple from authentication import get_auth_dependency +from authentication.interface import AuthTuple from authorization.middleware import authorize from configuration import configuration from models.config import Action, Configuration @@ -15,8 +15,6 @@ logger = logging.getLogger(__name__) router = APIRouter(tags=["config"]) -auth_dependency = get_auth_dependency() - get_config_responses: dict[int | str, dict[str, Any]] = { 200: { @@ -63,7 +61,7 @@ @router.get("/config", responses=get_config_responses) @authorize(Action.GET_CONFIG) async def config_endpoint_handler( - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], request: Request, ) -> Configuration: """ diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index bfb951296..1e1f4ed7a 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -3,22 +3,21 @@ import logging from typing import Any +from fastapi import APIRouter, Depends, HTTPException, Request, status from llama_stack_client import APIConnectionError, NotFoundError -from fastapi import APIRouter, HTTPException, Request, status, Depends - -from client import AsyncLlamaStackClientHolder -from configuration import configuration from app.database import get_session from authentication import get_auth_dependency from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder +from configuration import configuration from models.config import Action from models.database.conversations import UserConversation from 
models.responses import ( - ConversationResponse, ConversationDeleteResponse, - ConversationsListResponse, ConversationDetails, + ConversationResponse, + ConversationsListResponse, UnauthorizedResponse, ) from utils.endpoints import ( @@ -30,7 +29,6 @@ logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["conversations"]) -auth_dependency = get_auth_dependency() conversation_responses: dict[int | str, dict[str, Any]] = { 200: { @@ -180,7 +178,7 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]: @authorize(Action.LIST_CONVERSATIONS) async def get_conversations_list_endpoint_handler( request: Request, - auth: Any = Depends(auth_dependency), + auth: Any = Depends(get_auth_dependency()), ) -> ConversationsListResponse: """Handle request to retrieve all conversations for the authenticated user.""" check_configuration_loaded(configuration) @@ -242,7 +240,7 @@ async def get_conversations_list_endpoint_handler( async def get_conversation_endpoint_handler( request: Request, conversation_id: str, - auth: Any = Depends(auth_dependency), + auth: Any = Depends(get_auth_dependency()), ) -> ConversationResponse: """ Handle request to retrieve a conversation by ID. @@ -370,7 +368,7 @@ async def get_conversation_endpoint_handler( async def delete_conversation_endpoint_handler( request: Request, conversation_id: str, - auth: Any = Depends(auth_dependency), + auth: Any = Depends(get_auth_dependency()), ) -> ConversationDeleteResponse: """ Handle request to delete a conversation by ID. 
diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py index 5033b5e5f..e8697e78b 100644 --- a/src/app/endpoints/conversations_v2.py +++ b/src/app/endpoints/conversations_v2.py @@ -3,17 +3,17 @@ import logging from typing import Any -from fastapi import APIRouter, Request, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status -from configuration import configuration from authentication import get_auth_dependency from authorization.middleware import authorize +from configuration import configuration from models.cache_entry import CacheEntry from models.config import Action from models.responses import ( - ConversationsListResponseV2, - ConversationResponse, ConversationDeleteResponse, + ConversationResponse, + ConversationsListResponseV2, UnauthorizedResponse, ) from utils.endpoints import check_configuration_loaded @@ -21,7 +21,6 @@ logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["conversations_v2"]) -auth_dependency = get_auth_dependency() conversation_responses: dict[int | str, dict[str, Any]] = { @@ -93,7 +92,7 @@ @authorize(Action.LIST_CONVERSATIONS) async def get_conversations_list_endpoint_handler( request: Request, # pylint: disable=unused-argument - auth: Any = Depends(auth_dependency), + auth: Any = Depends(get_auth_dependency()), ) -> ConversationsListResponseV2: """Handle request to retrieve all conversations for the authenticated user.""" check_configuration_loaded(configuration) @@ -123,7 +122,7 @@ async def get_conversations_list_endpoint_handler( async def get_conversation_endpoint_handler( request: Request, # pylint: disable=unused-argument conversation_id: str, - auth: Any = Depends(auth_dependency), + auth: Any = Depends(get_auth_dependency()), ) -> ConversationResponse: """Handle request to retrieve a conversation by ID.""" check_configuration_loaded(configuration) @@ -159,7 +158,7 @@ async def get_conversation_endpoint_handler( 
async def delete_conversation_endpoint_handler( request: Request, # pylint: disable=unused-argument conversation_id: str, - auth: Any = Depends(auth_dependency), + auth: Any = Depends(get_auth_dependency()), ) -> ConversationDeleteResponse: """Handle request to delete a conversation by ID.""" check_configuration_loaded(configuration) diff --git a/src/app/endpoints/feedback.py b/src/app/endpoints/feedback.py index ccf5d5aca..2756bdb81 100644 --- a/src/app/endpoints/feedback.py +++ b/src/app/endpoints/feedback.py @@ -1,12 +1,13 @@ """Handler for REST API endpoint for user feedback.""" +import json import logging import threading -from typing import Annotated, Any +from datetime import UTC, datetime from pathlib import Path -import json -from datetime import datetime, UTC -from fastapi import APIRouter, HTTPException, Depends, Request, status +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, HTTPException, Request, status from authentication import get_auth_dependency from authentication.interface import AuthTuple @@ -18,15 +19,14 @@ ErrorResponse, FeedbackResponse, FeedbackStatusUpdateResponse, + ForbiddenResponse, StatusResponse, UnauthorizedResponse, - ForbiddenResponse, ) from utils.suid import get_suid logger = logging.getLogger(__name__) router = APIRouter(prefix="/feedback", tags=["feedback"]) -auth_dependency = get_auth_dependency() feedback_status_lock = threading.Lock() # Response for the feedback endpoint @@ -87,7 +87,7 @@ async def assert_feedback_enabled(_request: Request) -> None: @authorize(Action.FEEDBACK) async def feedback_endpoint_handler( feedback_request: FeedbackRequest, - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], _ensure_feedback_enabled: Any = Depends(assert_feedback_enabled), ) -> FeedbackResponse: """Handle feedback requests. 
@@ -183,7 +183,7 @@ def feedback_status() -> StatusResponse: @authorize(Action.ADMIN) async def update_feedback_status( feedback_update_request: FeedbackStatusUpdateRequest, - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], ) -> FeedbackStatusUpdateResponse: """ Handle feedback status update requests. diff --git a/src/app/endpoints/health.py b/src/app/endpoints/health.py index ca09646d5..23e89921d 100644 --- a/src/app/endpoints/health.py +++ b/src/app/endpoints/health.py @@ -8,25 +8,23 @@ import logging from typing import Annotated, Any +from fastapi import APIRouter, Depends, Response, status from llama_stack.providers.datatypes import HealthStatus -from fastapi import APIRouter, status, Response, Depends -from client import AsyncLlamaStackClientHolder -from authentication.interface import AuthTuple from authentication import get_auth_dependency +from authentication.interface import AuthTuple from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder from models.config import Action from models.responses import ( LivenessResponse, - ReadinessResponse, ProviderHealthStatus, + ReadinessResponse, ) logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["health"]) -auth_dependency = get_auth_dependency() - async def get_providers_health_statuses() -> list[ProviderHealthStatus]: """ @@ -80,7 +78,7 @@ async def get_providers_health_statuses() -> list[ProviderHealthStatus]: @router.get("/readiness", responses=get_readiness_responses) @authorize(Action.INFO) async def readiness_probe_get_method( - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], response: Response, ) -> ReadinessResponse: """ @@ -126,7 +124,7 @@ async def readiness_probe_get_method( @router.get("/liveness", responses=get_liveness_responses) @authorize(Action.INFO) async def liveness_probe_get_method( - auth: 
Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], ) -> LivenessResponse: """ Return the liveness status of the service. diff --git a/src/app/endpoints/info.py b/src/app/endpoints/info.py index dfcb62023..1bd74dc2d 100644 --- a/src/app/endpoints/info.py +++ b/src/app/endpoints/info.py @@ -3,15 +3,14 @@ import logging from typing import Annotated, Any -from fastapi import APIRouter, HTTPException, Request, status -from fastapi import Depends +from fastapi import APIRouter, Depends, HTTPException, Request, status from llama_stack_client import APIConnectionError -from authentication.interface import AuthTuple from authentication import get_auth_dependency +from authentication.interface import AuthTuple from authorization.middleware import authorize -from configuration import configuration from client import AsyncLlamaStackClientHolder +from configuration import configuration from models.config import Action from models.responses import InfoResponse from version import __version__ @@ -19,8 +18,6 @@ logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["info"]) -auth_dependency = get_auth_dependency() - get_info_responses: dict[int | str, dict[str, Any]] = { 200: { @@ -40,7 +37,7 @@ @router.get("/info", responses=get_info_responses) @authorize(Action.INFO) async def info_endpoint_handler( - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], request: Request, ) -> InfoResponse: """ diff --git a/src/app/endpoints/metrics.py b/src/app/endpoints/metrics.py index 9bc938f63..5a8c90d0d 100644 --- a/src/app/endpoints/metrics.py +++ b/src/app/endpoints/metrics.py @@ -1,28 +1,27 @@ """Handler for REST API call to provide metrics.""" from typing import Annotated + +from fastapi import APIRouter, Depends, Request from fastapi.responses import PlainTextResponse -from fastapi import APIRouter, Request, Depends from prometheus_client 
import ( - generate_latest, CONTENT_TYPE_LATEST, + generate_latest, ) -from authentication.interface import AuthTuple from authentication import get_auth_dependency +from authentication.interface import AuthTuple from authorization.middleware import authorize -from models.config import Action from metrics.utils import setup_model_metrics +from models.config import Action router = APIRouter(tags=["metrics"]) -auth_dependency = get_auth_dependency() - @router.get("/metrics", response_class=PlainTextResponse) @authorize(Action.GET_METRICS) async def metrics_endpoint_handler( - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], request: Request, ) -> PlainTextResponse: """ diff --git a/src/app/endpoints/models.py b/src/app/endpoints/models.py index afecf3438..a54749cb0 100644 --- a/src/app/endpoints/models.py +++ b/src/app/endpoints/models.py @@ -9,9 +9,9 @@ from authentication import get_auth_dependency from authentication.interface import AuthTuple +from authorization.middleware import authorize from client import AsyncLlamaStackClientHolder from configuration import configuration -from authorization.middleware import authorize from models.config import Action from models.responses import ModelsResponse from utils.endpoints import check_configuration_loaded @@ -20,9 +20,6 @@ router = APIRouter(tags=["models"]) -auth_dependency = get_auth_dependency() - - models_responses: dict[int | str, dict[str, Any]] = { 200: { "models": [ @@ -54,7 +51,7 @@ @authorize(Action.GET_MODELS) async def models_endpoint_handler( request: Request, - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], ) -> ModelsResponse: """ Handle requests to the /models endpoint. 
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 15a215444..38ce42117 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -57,7 +57,6 @@ logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["query"]) -auth_dependency = get_auth_dependency() query_response: dict[int | str, dict[str, Any]] = { 200: { @@ -174,7 +173,7 @@ def evaluate_model_hints( async def query_endpoint_handler( request: Request, query_request: QueryRequest, - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), ) -> QueryResponse: """ diff --git a/src/app/endpoints/root.py b/src/app/endpoints/root.py index 3bdf01df6..4485d2d6f 100644 --- a/src/app/endpoints/root.py +++ b/src/app/endpoints/root.py @@ -3,19 +3,17 @@ import logging from typing import Annotated -from fastapi import APIRouter, Request, Depends +from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse -from authentication.interface import AuthTuple from authentication import get_auth_dependency +from authentication.interface import AuthTuple from authorization.middleware import authorize from models.config import Action logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["root"]) -auth_dependency = get_auth_dependency() - INDEX_PAGE = """ @@ -780,7 +778,7 @@ @router.get("/", response_class=HTMLResponse) @authorize(Action.INFO) async def root_endpoint_handler( - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], request: Request, ) -> HTMLResponse: """ diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index f48646ac7..8104b453e 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -56,7 +56,6 @@ logger = 
logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["streaming_query"]) -auth_dependency = get_auth_dependency() streaming_query_responses: dict[int | str, dict[str, Any]] = { 200: { @@ -570,7 +569,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals request: Request, query_request: QueryRequest, - auth: Annotated[AuthTuple, Depends(auth_dependency)], + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), ) -> StreamingResponse: """ diff --git a/src/app/main.py b/src/app/main.py index 3b9830fdb..cbf45e81b 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -1,26 +1,57 @@ """Definition of FastAPI based web service.""" -from typing import Callable, Awaitable +import os +from contextlib import asynccontextmanager +from typing import AsyncIterator, Awaitable, Callable from fastapi import FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from starlette.routing import Mount, Route, WebSocketRoute +import metrics +import version from app import routers -from app.database import initialize_database, create_tables +from app.database import create_tables, initialize_database +from client import AsyncLlamaStackClientHolder from configuration import configuration from log import get_logger -import metrics from utils.common import register_mcp_servers_async -import version +from utils.llama_stack_version import check_llama_stack_version logger = get_logger(__name__) logger.info("Initializing app") + service_name = configuration.configuration.name +# running on FastAPI startup +@asynccontextmanager +async def lifespan(_app: FastAPI) -> AsyncIterator[None]: + """ + Initialize app resources. + + FastAPI lifespan context: initializes configuration, Llama client, MCP servers, + logger, and database before serving requests. 
+ """ + configuration.load_configuration(os.environ["LIGHTSPEED_STACK_CONFIG_PATH"]) + await AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack) + client = AsyncLlamaStackClientHolder().get_client() + # check if the Llama Stack version is supported by the service + await check_llama_stack_version(client) + + logger.info("Registering MCP servers") + await register_mcp_servers_async(logger, configuration.configuration) + get_logger("app.endpoints.handlers") + logger.info("App startup complete") + + initialize_database() + create_tables() + + yield + + app = FastAPI( title=f"{service_name} service - OpenAPI", summary=f"{service_name} service API specification.", @@ -38,6 +69,7 @@ servers=[ {"url": "http://localhost:8080/", "description": "Locally running service"} ], + lifespan=lifespan, ) cors = configuration.service_configuration.cors @@ -84,15 +116,3 @@ async def rest_api_metrics( for route in app.routes if isinstance(route, (Mount, Route, WebSocketRoute)) ] - - -@app.on_event("startup") -async def startup_event() -> None: - """Perform logger setup on service startup.""" - logger.info("Registering MCP servers") - await register_mcp_servers_async(logger, configuration.configuration) - get_logger("app.endpoints.handlers") - logger.info("App startup complete") - - initialize_database() - create_tables() diff --git a/src/authentication/__init__.py b/src/authentication/__init__.py index ec0a171d0..6b71b3b47 100644 --- a/src/authentication/__init__.py +++ b/src/authentication/__init__.py @@ -1,12 +1,12 @@ """This package contains authentication code and modules.""" import logging +import os -from authentication.interface import AuthInterface -from authentication import noop, noop_with_token, k8s, jwk_token -from configuration import configuration import constants - +from authentication import jwk_token, k8s, noop, noop_with_token +from authentication.interface import AuthInterface +from configuration import LogicError, configuration logger = 
logging.getLogger(__name__) @@ -15,7 +15,16 @@ def get_auth_dependency( virtual_path: str = constants.DEFAULT_VIRTUAL_PATH, ) -> AuthInterface: """Select the configured authentication dependency interface.""" - module = configuration.authentication_configuration.module + try: + module = configuration.authentication_configuration.module + except LogicError: + # Only load once if not already loaded + config_path = os.getenv( + "LIGHTSPEED_STACK_CONFIG_PATH", + "tests/configuration/lightspeed-stack.yaml", + ) + configuration.load_configuration(config_path) + module = configuration.authentication_configuration.module logger.debug( "Initializing authentication dependency: module='%s', virtual_path='%s'", diff --git a/src/lightspeed_stack.py b/src/lightspeed_stack.py index 7aedcfe10..a8747ddda 100644 --- a/src/lightspeed_stack.py +++ b/src/lightspeed_stack.py @@ -4,15 +4,14 @@ main() function. """ -from argparse import ArgumentParser -import asyncio import logging +import os +from argparse import ArgumentParser + from rich.logging import RichHandler -from runners.uvicorn import start_uvicorn from configuration import configuration -from client import AsyncLlamaStackClientHolder -from utils.llama_stack_version import check_llama_stack_version +from runners.uvicorn import start_uvicorn FORMAT = "%(message)s" logging.basicConfig( @@ -75,14 +74,9 @@ def main() -> None: raise SystemExit(1) from e return - logger.info("Creating AsyncLlamaStackClient") - asyncio.run( - AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack) - ) - client = AsyncLlamaStackClientHolder().get_client() - - # check if the Llama Stack version is supported by the service - asyncio.run(check_llama_stack_version(client)) + # Store config path in env so each uvicorn worker can load it + # (step is needed because process context isn’t shared). 
+ os.environ["LIGHTSPEED_STACK_CONFIG_PATH"] = args.config_file # if every previous steps don't fail, start the service on specified port start_uvicorn(configuration.service_configuration)