From 83e9e2313c5fbf5ae04566202b3b9d3c6329b9ea Mon Sep 17 00:00:00 2001 From: Omer Tuchfeld Date: Wed, 6 Aug 2025 17:12:51 +0200 Subject: [PATCH] Role-based authorization layer This commit adds a second authorization layer that evaluates what authenticated users are allowed to do, separate from whether they can authenticate at all. The authorization system sits between endpoint handlers and the existing auth dependency injection. # Architecture The system uses two resolver interfaces that can be swapped based on auth module (although `AccessResolver` shouldn't change much across different auth modules): - `RolesResolver`: given an `AuthTuple`, determine what roles the user has - `AccessResolver`: given roles and an action, determine if access is allowed For JWT auth, roles are extracted by applying user-provided rules containing JSONPath expressions to JWT claims . Each role rule specifies a JSONPath, operator (equals/contains/in), a value, and the resulting roles to apply if the rule matches the user claims. This lets you map claim values like `department: "engineering"` to roles like "developer". Roles are arbitrary except for the special `*` role which applies to everyone automatically. Access rules then map roles to permitted actions. # Implementation The `@authorize(action)` decorator wraps endpoint functions and performs the check: 1. Extract `AuthTuple` from endpoint dependencies 2. Resolve user roles from auth credentials using the `RolesResolver` 3. Check if those roles permit the requested action using the `AccessResolver` 4. Raise 403 if denied, continue if allowed All endpoints now declare what action they perform (`READ_CONVERSATIONS`, `FEEDBACK`, etc.) through the `@authorize` decorator. The middleware automatically handles the permission check before the endpoint runs. The middleware also populates the `request: Request` `state` property with the `authorized_actions` set, which contains all actions the user is allowed to perform based on their roles. Endpoints can then inspect this property to dynamically adjust their behavior based on what the user is allowed to do, for special actions such as listing others' conversations (as opposed to listing only the user's own conversations) or deleting conversations. These behaviors which are more "complicated" than just whether the endpoint is accessible or not, as they depend on the actual endpoint logic. # Backwards compatibility For non-JWT auth modules, we will default to no-op resolvers that allow all access, maintaining current behavior. # Technical notes - All endpoints should accept `auth: Any = Depends(get_auth_dependency())` to use the authorization system. The parameter must be named `auth` to be recognized by the middleware. # Config Example configuration: ```yaml authentication: module: jwk-token jwk_config: url: ${SSO_BASE_URL}/protocol/openid-connect/certs jwt_configuration: user_id_claim: ${USER_ID_CLAIM} username_claim: ${USERNAME_CLAIM} role_rules: - jsonpath: "$.realm_access.roles[*]" operator: "contains" value: "redhat:employees" roles: ["redhat_employee"] - jsonpath: "$.realm_access.roles[*]" operator: "contains" value: "candlepin_system_access_view_edit_all" roles: ["read_only_admin"] authorization: access_rules: - role: "admin" actions: - query_other_conversations - delete_other_conversations - role: "read_only_admin" actions: - list_other_conversations - read_other_conversations - role: "redhat_employee" actions: - get_models - get_config - role: "*" actions: - query - streaming_query - get_conversation - list_conversations - delete_conversation - feedback - get_metrics - info ``` --- pyproject.toml | 1 + src/app/endpoints/authorized.py | 5 +- src/app/endpoints/config.py | 23 ++- src/app/endpoints/conversations.py | 35 +++- src/app/endpoints/feedback.py | 9 +- src/app/endpoints/health.py | 27 ++- src/app/endpoints/info.py | 23 ++- src/app/endpoints/metrics.py | 21 +- src/app/endpoints/models.py | 24 ++- src/app/endpoints/query.py | 16 +- src/app/endpoints/root.py | 23 ++- src/app/endpoints/streaming_query.py | 8 +- src/auth/jwk_token.py | 3 +- src/authorization/__init__.py | 1 + src/authorization/middleware.py | 115 +++++++++++ src/authorization/resolvers.py | 186 ++++++++++++++++++ src/configuration.py | 13 ++ src/models/config.py | 121 +++++++++++- src/utils/endpoints.py | 22 ++- tests/unit/app/endpoints/test_config.py | 29 ++- .../unit/app/endpoints/test_conversations.py | 169 ++++++++++++---- tests/unit/app/endpoints/test_feedback.py | 15 +- tests/unit/app/endpoints/test_health.py | 27 ++- tests/unit/app/endpoints/test_info.py | 16 +- tests/unit/app/endpoints/test_metrics.py | 8 +- tests/unit/app/endpoints/test_models.py | 36 +++- tests/unit/app/endpoints/test_query.py | 72 +++++-- tests/unit/app/endpoints/test_root.py | 10 +- tests/unit/auth/test_jwk_token.py | 13 +- tests/unit/authorization/__init__.py | 1 + tests/unit/authorization/test_resolvers.py | 101 ++++++++++ tests/unit/models/test_config.py | 2 + tests/unit/utils/auth_helpers.py | 27 +++ uv.lock | 23 +++ 34 files changed, 1092 insertions(+), 133 deletions(-) create mode 100644 src/authorization/__init__.py create mode 100644 src/authorization/middleware.py create mode 100644 src/authorization/resolvers.py create mode 100644 tests/unit/authorization/__init__.py create mode 100644 tests/unit/authorization/test_resolvers.py create mode 100644 tests/unit/utils/auth_helpers.py diff --git a/pyproject.toml b/pyproject.toml index eae1134db..82fb2d9cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "openai==1.99.9", "sqlalchemy>=2.0.42", "semver<4.0.0", + "jsonpath-ng>=1.6.1", ] diff --git a/src/app/endpoints/authorized.py b/src/app/endpoints/authorized.py index 9ada28f74..8fa029ee6 100644 --- a/src/app/endpoints/authorized.py +++ b/src/app/endpoints/authorized.py @@ -1,10 +1,11 @@ """Handler for REST API call to authorized endpoint.""" import logging -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Depends +from auth.interface import AuthTuple from auth import get_auth_dependency from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse @@ -31,7 +32,7 @@ @router.post("/authorized", responses=authorized_responses) async def authorized_endpoint_handler( - auth: Any = Depends(auth_dependency), + auth: Annotated[AuthTuple, Depends(auth_dependency)], ) -> AuthorizedResponse: """ Handle request to the /authorized endpoint. diff --git a/src/app/endpoints/config.py b/src/app/endpoints/config.py index fec854294..a9d2daac1 100644 --- a/src/app/endpoints/config.py +++ b/src/app/endpoints/config.py @@ -1,17 +1,22 @@ """Handler for REST API call to retrieve service configuration.""" import logging -from typing import Any +from typing import Annotated, Any -from fastapi import APIRouter, Request +from fastapi import APIRouter, Request, Depends -from models.config import Configuration +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize from configuration import configuration +from models.config import Action, Configuration from utils.endpoints import check_configuration_loaded logger = logging.getLogger(__name__) router = APIRouter(tags=["config"]) +auth_dependency = get_auth_dependency() + get_config_responses: dict[int | str, dict[str, Any]] = { 200: { @@ -56,7 +61,11 @@ @router.get("/config", responses=get_config_responses) -def config_endpoint_handler(_request: Request) -> Configuration: +@authorize(Action.GET_CONFIG) +async def config_endpoint_handler( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + request: Request, +) -> Configuration: """ Handle requests to the /config endpoint. @@ -66,6 +75,12 @@ def config_endpoint_handler(_request: Request) -> Configuration: Returns: Configuration: The loaded service configuration object. """ + # Used only for authorization + _ = auth + + # Nothing interesting in the request + _ = request + # ensure that configuration is loaded check_configuration_loaded(configuration) diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index bbb7825d1..ba8440855 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -5,19 +5,21 @@ from llama_stack_client import APIConnectionError, NotFoundError -from fastapi import APIRouter, HTTPException, status, Depends +from fastapi import APIRouter, HTTPException, Request, status, Depends from client import AsyncLlamaStackClientHolder from configuration import configuration +from app.database import get_session +from auth import get_auth_dependency +from authorization.middleware import authorize +from models.config import Action +from models.database.conversations import UserConversation from models.responses import ( ConversationResponse, ConversationDeleteResponse, ConversationsListResponse, ConversationDetails, ) -from models.database.conversations import UserConversation -from auth import get_auth_dependency -from app.database import get_session from utils.endpoints import check_configuration_loaded, validate_conversation_ownership from utils.suid import check_suid @@ -146,7 +148,9 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]: @router.get("/conversations", responses=conversations_list_responses) -def get_conversations_list_endpoint_handler( +@authorize(Action.LIST_CONVERSATIONS) +async def get_conversations_list_endpoint_handler( + request: Request, auth: Any = Depends(auth_dependency), ) -> ConversationsListResponse: """Handle request to retrieve all conversations for the authenticated user.""" @@ -158,11 +162,16 @@ def get_conversations_list_endpoint_handler( with get_session() as session: try: - # Get all conversations for this user - user_conversations = ( - session.query(UserConversation).filter_by(user_id=user_id).all() + query = session.query(UserConversation) + + filtered_query = ( + query + if Action.LIST_OTHERS_CONVERSATIONS in request.state.authorized_actions + else query.filter_by(user_id=user_id) ) + user_conversations = filtered_query.all() + # Return conversation summaries with metadata conversations = [ ConversationDetails( @@ -200,7 +209,9 @@ def get_conversations_list_endpoint_handler( @router.get("/conversations/{conversation_id}", responses=conversation_responses) +@authorize(Action.GET_CONVERSATION) async def get_conversation_endpoint_handler( + request: Request, conversation_id: str, auth: Any = Depends(auth_dependency), ) -> ConversationResponse: @@ -239,6 +250,9 @@ async def get_conversation_endpoint_handler( validate_conversation_ownership( user_id=user_id, conversation_id=conversation_id, + others_allowed=( + Action.READ_OTHERS_CONVERSATIONS in request.state.authorized_actions + ), ) agent_id = conversation_id @@ -309,7 +323,9 @@ async def get_conversation_endpoint_handler( @router.delete( "/conversations/{conversation_id}", responses=conversation_delete_responses ) +@authorize(Action.DELETE_CONVERSATION) async def delete_conversation_endpoint_handler( + request: Request, conversation_id: str, auth: Any = Depends(auth_dependency), ) -> ConversationDeleteResponse: @@ -342,6 +358,9 @@ async def delete_conversation_endpoint_handler( validate_conversation_ownership( user_id=user_id, conversation_id=conversation_id, + others_allowed=( + Action.DELETE_OTHERS_CONVERSATIONS in request.state.authorized_actions + ), ) agent_id = conversation_id diff --git a/src/app/endpoints/feedback.py b/src/app/endpoints/feedback.py index 5d82267ee..39d2a2692 100644 --- a/src/app/endpoints/feedback.py +++ b/src/app/endpoints/feedback.py @@ -5,11 +5,14 @@ from pathlib import Path import json from datetime import datetime, UTC -from fastapi import APIRouter, Request, HTTPException, Depends, status +from fastapi import APIRouter, HTTPException, Depends, Request, status from auth import get_auth_dependency from auth.interface import AuthTuple +from authorization.middleware import authorize from configuration import configuration +from models.config import Action +from models.requests import FeedbackRequest from models.responses import ( ErrorResponse, FeedbackResponse, @@ -17,7 +20,6 @@ UnauthorizedResponse, ForbiddenResponse, ) -from models.requests import FeedbackRequest from utils.suid import get_suid logger = logging.getLogger(__name__) @@ -79,7 +81,8 @@ async def assert_feedback_enabled(_request: Request) -> None: @router.post("", responses=feedback_response) -def feedback_endpoint_handler( +@authorize(Action.FEEDBACK) +async def feedback_endpoint_handler( feedback_request: FeedbackRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], _ensure_feedback_enabled: Any = Depends(assert_feedback_enabled), diff --git a/src/app/endpoints/health.py b/src/app/endpoints/health.py index f1ade175f..f62c9a013 100644 --- a/src/app/endpoints/health.py +++ b/src/app/endpoints/health.py @@ -6,12 +6,16 @@ """ import logging -from typing import Any +from typing import Annotated, Any from llama_stack.providers.datatypes import HealthStatus -from fastapi import APIRouter, status, Response +from fastapi import APIRouter, status, Response, Depends from client import AsyncLlamaStackClientHolder +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize +from models.config import Action from models.responses import ( LivenessResponse, ReadinessResponse, @@ -21,6 +25,8 @@ logger = logging.getLogger(__name__) router = APIRouter(tags=["health"]) +auth_dependency = get_auth_dependency() + async def get_providers_health_statuses() -> list[ProviderHealthStatus]: """ @@ -72,7 +78,11 @@ async def get_providers_health_statuses() -> list[ProviderHealthStatus]: @router.get("/readiness", responses=get_readiness_responses) -async def readiness_probe_get_method(response: Response) -> ReadinessResponse: +@authorize(Action.INFO) +async def readiness_probe_get_method( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + response: Response, +) -> ReadinessResponse: """ Handle the readiness probe endpoint, returning service readiness. @@ -80,6 +90,9 @@ async def readiness_probe_get_method(response: Response) -> ReadinessResponse: and details of unhealthy providers; otherwise, indicates the service is ready. """ + # Used only for authorization + _ = auth + provider_statuses = await get_providers_health_statuses() # Check if any provider is unhealthy (not counting not_implemented as unhealthy) @@ -112,11 +125,17 @@ async def readiness_probe_get_method(response: Response) -> ReadinessResponse: @router.get("/liveness", responses=get_liveness_responses) -def liveness_probe_get_method() -> LivenessResponse: +@authorize(Action.INFO) +async def liveness_probe_get_method( + auth: Annotated[AuthTuple, Depends(auth_dependency)], +) -> LivenessResponse: """ Return the liveness status of the service. Returns: LivenessResponse: Indicates that the service is alive. """ + # Used only for authorization + _ = auth + return LivenessResponse(alive=True) diff --git a/src/app/endpoints/info.py b/src/app/endpoints/info.py index d0c5f1098..ecc124297 100644 --- a/src/app/endpoints/info.py +++ b/src/app/endpoints/info.py @@ -1,17 +1,24 @@ """Handler for REST API call to provide info.""" import logging -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Request +from fastapi import Depends +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize from configuration import configuration -from version import __version__ +from models.config import Action from models.responses import InfoResponse +from version import __version__ logger = logging.getLogger(__name__) router = APIRouter(tags=["info"]) +auth_dependency = get_auth_dependency() + get_info_responses: dict[int | str, dict[str, Any]] = { 200: { @@ -22,7 +29,11 @@ @router.get("/info", responses=get_info_responses) -def info_endpoint_handler(_request: Request) -> InfoResponse: +@authorize(Action.INFO) +async def info_endpoint_handler( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + request: Request, +) -> InfoResponse: """ Handle request to the /info endpoint. @@ -32,4 +43,10 @@ def info_endpoint_handler(_request: Request) -> InfoResponse: Returns: InfoResponse: An object containing the service's name and version. """ + # Used only for authorization + _ = auth + + # Nothing interesting in the request + _ = request + return InfoResponse(name=configuration.configuration.name, version=__version__) diff --git a/src/app/endpoints/metrics.py b/src/app/endpoints/metrics.py index 10b986e6e..e45e6c668 100644 --- a/src/app/endpoints/metrics.py +++ b/src/app/endpoints/metrics.py @@ -1,19 +1,30 @@ """Handler for REST API call to provide metrics.""" +from typing import Annotated from fastapi.responses import PlainTextResponse -from fastapi import APIRouter, Request +from fastapi import APIRouter, Request, Depends from prometheus_client import ( generate_latest, CONTENT_TYPE_LATEST, ) +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize +from models.config import Action from metrics.utils import setup_model_metrics router = APIRouter(tags=["metrics"]) +auth_dependency = get_auth_dependency() + @router.get("/metrics", response_class=PlainTextResponse) -async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse: +@authorize(Action.GET_METRICS) +async def metrics_endpoint_handler( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + request: Request, +) -> PlainTextResponse: """ Handle request to the /metrics endpoint. @@ -24,6 +35,12 @@ async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse: set up, then responds with the current metrics snapshot in Prometheus format. """ + # Used only for authorization + _ = auth + + # Nothing interesting in the request + _ = request + # Setup the model metrics if not already done. This is a one-time setup # and will not be run again on subsequent calls to this endpoint await setup_model_metrics() diff --git a/src/app/endpoints/models.py b/src/app/endpoints/models.py index feac9e4cd..cdd85b930 100644 --- a/src/app/endpoints/models.py +++ b/src/app/endpoints/models.py @@ -1,13 +1,18 @@ """Handler for REST API call to list available models.""" import logging -from typing import Any +from typing import Annotated, Any -from llama_stack_client import APIConnectionError from fastapi import APIRouter, HTTPException, Request, status +from fastapi.params import Depends +from llama_stack_client import APIConnectionError +from auth import get_auth_dependency +from auth.interface import AuthTuple from client import AsyncLlamaStackClientHolder from configuration import configuration +from authorization.middleware import authorize +from models.config import Action from models.responses import ModelsResponse from utils.endpoints import check_configuration_loaded @@ -15,6 +20,9 @@ router = APIRouter(tags=["models"]) +auth_dependency = get_auth_dependency() + + models_responses: dict[int | str, dict[str, Any]] = { 200: { "models": [ @@ -43,7 +51,11 @@ @router.get("/models", responses=models_responses) -async def models_endpoint_handler(_request: Request) -> ModelsResponse: +@authorize(Action.GET_MODELS) +async def models_endpoint_handler( + request: Request, + auth: Annotated[AuthTuple, Depends(auth_dependency)], +) -> ModelsResponse: """ Handle requests to the /models endpoint. @@ -57,6 +69,12 @@ async def models_endpoint_handler(_request: Request) -> ModelsResponse: Returns: ModelsResponse: An object containing the list of available models. """ + # Used only by the middleware + _ = auth + + # Nothing interesting in the request + _ = request + check_configuration_loaded(configuration) llama_stack_configuration = configuration.llama_stack_configuration diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index eae3ae287..7c5343b7f 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -16,7 +16,7 @@ ) from llama_stack_client.types.model_list_response import ModelListResponse -from fastapi import APIRouter, HTTPException, status, Depends +from fastapi import APIRouter, HTTPException, Request, status, Depends from auth import get_auth_dependency from auth.interface import AuthTuple @@ -24,10 +24,12 @@ from configuration import configuration from app.database import get_session import metrics +import constants +from authorization.middleware import authorize +from models.config import Action from models.database.conversations import UserConversation -from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from models.requests import QueryRequest, Attachment -import constants +from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from utils.endpoints import ( check_configuration_loaded, get_agent, @@ -148,7 +150,9 @@ def evaluate_model_hints( @router.post("/query", responses=query_response) +@authorize(Action.QUERY) async def query_endpoint_handler( + request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), @@ -178,7 +182,11 @@ async def query_endpoint_handler( user_conversation: UserConversation | None = None if query_request.conversation_id: user_conversation = validate_conversation_ownership( - user_id=user_id, conversation_id=query_request.conversation_id + user_id=user_id, + conversation_id=query_request.conversation_id, + others_allowed=( + Action.QUERY_OTHERS_CONVERSATIONS in request.state.authorized_actions + ), ) if user_conversation is None: diff --git a/src/app/endpoints/root.py b/src/app/endpoints/root.py index e34ff0d6e..b02234042 100644 --- a/src/app/endpoints/root.py +++ b/src/app/endpoints/root.py @@ -1,14 +1,21 @@ """Handler for the / endpoint.""" import logging +from typing import Annotated -from fastapi import APIRouter, Request +from fastapi import APIRouter, Request, Depends from fastapi.responses import HTMLResponse -logger = logging.getLogger("app.endpoints.handlers") +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize +from models.config import Action +logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["root"]) +auth_dependency = get_auth_dependency() + INDEX_PAGE = """ @@ -771,7 +778,17 @@ @router.get("/", response_class=HTMLResponse) -def root_endpoint_handler(_request: Request) -> HTMLResponse: +@authorize(Action.INFO) +async def root_endpoint_handler( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + request: Request, +) -> HTMLResponse: """Handle request to the / endpoint.""" + # Used only for authorization + _ = auth + + # Nothing interesting in the request + _ = request + logger.info("Serving index page") return HTMLResponse(INDEX_PAGE) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 69d29ebd1..a7fe6788e 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -22,9 +22,11 @@ from auth import get_auth_dependency from auth.interface import AuthTuple +from authorization.middleware import authorize from client import AsyncLlamaStackClientHolder from configuration import configuration import metrics +from models.config import Action from models.requests import QueryRequest from models.database.conversations import UserConversation from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt @@ -509,8 +511,9 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: @router.post("/streaming_query") +@authorize(Action.STREAMING_QUERY) async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals - _request: Request, + request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), @@ -533,6 +536,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals HTTPException: Returns HTTP 500 if unable to connect to the Llama Stack server. """ + # Nothing interesting in the request + _ = request + check_configuration_loaded(configuration) llama_stack_config = configuration.llama_stack_configuration diff --git a/src/auth/jwk_token.py b/src/auth/jwk_token.py index 9fa493e6d..ad68abaed 100644 --- a/src/auth/jwk_token.py +++ b/src/auth/jwk_token.py @@ -1,6 +1,7 @@ """Manage authentication flow for FastAPI endpoints with JWK based JWT auth.""" import logging +import json from asyncio import Lock from typing import Any, Callable @@ -188,4 +189,4 @@ async def __call__(self, request: Request) -> tuple[str, str, str]: logger.info("Successfully authenticated user %s (ID: %s)", username, user_id) - return user_id, username, user_token + return user_id, username, json.dumps(claims) diff --git a/src/authorization/__init__.py b/src/authorization/__init__.py new file mode 100644 index 000000000..fff3fce71 --- /dev/null +++ b/src/authorization/__init__.py @@ -0,0 +1 @@ +"""Authorization module for role-based access control.""" diff --git a/src/authorization/middleware.py b/src/authorization/middleware.py new file mode 100644 index 000000000..a919c8789 --- /dev/null +++ b/src/authorization/middleware.py @@ -0,0 +1,115 @@ +"""Authorization middleware and decorators.""" + +import logging +from functools import wraps, lru_cache +from typing import Any, Callable, Tuple +from fastapi import HTTPException, status + +from authorization.resolvers import ( + AccessResolver, + GenericAccessResolver, + JwtRolesResolver, + NoopAccessResolver, + NoopRolesResolver, + RolesResolver, +) +from models.config import Action +from configuration import configuration +import constants + +logger = logging.getLogger(__name__) + + +@lru_cache(maxsize=1) +def get_authorization_resolvers() -> Tuple[RolesResolver, AccessResolver]: + """Get authorization resolvers from configuration (cached).""" + authorization_cfg = configuration.authorization_configuration + authentication_config = configuration.authentication_configuration + + match authentication_config.module: + case ( + constants.AUTH_MOD_NOOP + | constants.AUTH_MOD_K8S + | constants.AUTH_MOD_NOOP_WITH_TOKEN + ): + return ( + NoopRolesResolver(), + NoopAccessResolver(), + ) + case constants.AUTH_MOD_JWK_TOKEN: + jwt_role_rules_unset = ( + len( + authentication_config.jwk_configuration.jwt_configuration.role_rules + ) + ) == 0 + + authz_access_rules_unset = len(authorization_cfg.access_rules) == 0 + + if jwt_role_rules_unset or authz_access_rules_unset: + return NoopRolesResolver(), NoopAccessResolver() + + return ( + JwtRolesResolver( + role_rules=( + authentication_config.jwk_configuration.jwt_configuration.role_rules + ) + ), + GenericAccessResolver(authorization_cfg.access_rules), + ) + + case _: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error", + ) + + +async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -> None: + """Perform authorization check - common logic for all decorators.""" + role_resolver, access_resolver = get_authorization_resolvers() + + try: + auth = kwargs["auth"] + except KeyError as exc: + logger.error( + "Authorization only allowed on endpoints that accept " + "'auth: Any = Depends(get_auth_dependency())'" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error", + ) from exc + + # Everyone gets the everyone (aka *) role + everyone_roles = {"*"} + + user_roles = await role_resolver.resolve_roles(auth) | everyone_roles + + if not access_resolver.check_access(action, user_roles): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Insufficient permissions for action: {action}", + ) + + authorized_actions = access_resolver.get_actions(user_roles) + + try: + request = kwargs["request"] + request.state.authorized_actions = authorized_actions + except KeyError: + # This endpoint doesn't seem care about the authorized actions, so no need to set it + pass + + +def authorize(action: Action) -> Callable: + """Check authorization for an endpoint (async version).""" + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + await _perform_authorization_check(action, kwargs) + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/src/authorization/resolvers.py b/src/authorization/resolvers.py new file mode 100644 index 000000000..cdba3f4f6 --- /dev/null +++ b/src/authorization/resolvers.py @@ -0,0 +1,186 @@ +"""Authorization resolvers for role evaluation and access control.""" + +from abc import ABC, abstractmethod +import json +import logging +from typing import Any + +from jsonpath_ng import parse + +from auth.interface import AuthTuple +from models.config import JwtRoleRule, AccessRule, JsonPathOperator, Action + +logger = logging.getLogger(__name__) + + +UserRoles = set[str] + + +class RoleResolutionError(Exception): + """Custom exception for role resolution errors.""" + + +class RolesResolver(ABC): # pylint: disable=too-few-public-methods + """Base class for all role resolution strategies.""" + + @abstractmethod + async def resolve_roles(self, auth: AuthTuple) -> UserRoles: + """Given an auth tuple, return the list of user roles.""" + + +class NoopRolesResolver(RolesResolver): # pylint: disable=too-few-public-methods + """No-op roles resolver that does not perform any role resolution.""" + + async def resolve_roles(self, auth: AuthTuple) -> UserRoles: + """Return an empty list of roles.""" + _ = auth # Unused + return set() + + +class JwtRolesResolver(RolesResolver): # pylint: disable=too-few-public-methods + """Processes JWT claims with the given JSONPath rules to get roles.""" + + def __init__(self, role_rules: list[JwtRoleRule]): + """Initialize the resolver with rules.""" + self.role_rules = role_rules + + async def resolve_roles(self, auth: AuthTuple) -> UserRoles: + """Extract roles from JWT claims using configured rules.""" + jwt_claims = self._get_claims(auth) + return { + role + for rule in self.role_rules + for role in self.evaluate_role_rules(rule, jwt_claims) + } + + @staticmethod + def evaluate_role_rules(rule: JwtRoleRule, jwt_claims: dict[str, Any]) -> UserRoles: + """Get roles from a JWT role rule if it matches the claims.""" + return ( + set(rule.roles) + if JwtRolesResolver._evaluate_operator( + rule.negate, + [match.value for match in parse(rule.jsonpath).find(jwt_claims)], + rule.operator, + rule.value, + ) + else set() + ) + + @staticmethod + def _get_claims(auth: AuthTuple) -> dict[str, Any]: + """Get the JWT claims from the auth tuple.""" + _, _, token = auth + jwt_claims = json.loads(token) + + if not jwt_claims: + raise RoleResolutionError( + "Invalid authentication token: no JWT claims found" + ) + + return jwt_claims + + @staticmethod + def _evaluate_operator( + negate: bool, match: Any, operator: JsonPathOperator, value: Any + ) -> bool: # pylint: disable=too-many-branches + """Evaluate an operator against a match and value.""" + result = False + match operator: + case JsonPathOperator.EQUALS: + result = match == value + case JsonPathOperator.CONTAINS: + result = value in match + case JsonPathOperator.IN: + result = match in value + + if negate: + result = not result + + return result + + +class AccessResolver(ABC): # pylint: disable=too-few-public-methods + """Base class for all access resolution strategies.""" + + @abstractmethod + def check_access(self, action: Action, user_roles: UserRoles) -> bool: + """Check if the user has access to the specified action based on their roles.""" + + @abstractmethod + def get_actions(self, user_roles: UserRoles) -> set[Action]: + """Get the actions that the user can perform based on their roles.""" + + +class NoopAccessResolver(AccessResolver): # pylint: disable=too-few-public-methods + """No-op access resolver that does not perform any access checks.""" + + def check_access(self, action: Action, user_roles: UserRoles) -> bool: + """Return True always, indicating access is granted.""" + _ = action # We're noop, it doesn't matter, everyone is allowed + _ = user_roles # We're noop, it doesn't matter, everyone is allowed + return True + + def get_actions(self, user_roles: UserRoles) -> set[Action]: + """Return an empty set of actions, indicating no specific actions are allowed.""" + _ = user_roles # We're noop, it doesn't matter, everyone is allowed + return set(Action) - {Action.ADMIN} + + +class GenericAccessResolver(AccessResolver): # pylint: disable=too-few-public-methods + """Generic role-based access resolver, should apply with most authentication methods. + + This resolver simply checks if a list of roles allow a user to perform a specific + action. The special action ADMIN will grant the user the ability to perform any action, + """ + + def __init__(self, access_rules: list[AccessRule]): + """Initialize the access resolver with access rules.""" + for rule in access_rules: + # Since this is nonsensical, it might be a mistake, so hard fail + if Action.ADMIN in rule.actions and len(rule.actions) > 1: + raise ValueError( + "Access rule with 'admin' action cannot have other actions" + ) + + self.access_rules = access_rules + + # Build a lookup table for access rules + self._access_lookup: dict[str, set[Action]] = {} + for rule in access_rules: + if rule.role not in self._access_lookup: + self._access_lookup[rule.role] = set() + self._access_lookup[rule.role].update(rule.actions) + + def check_access(self, action: Action, user_roles: UserRoles) -> bool: + """Check if the user has access to the specified action based on their roles.""" + if action != Action.ADMIN and self.check_access(Action.ADMIN, user_roles): + # Recurse to check if the roles allow the user to perform the admin action, + # if they do, then we allow any action + return True + + for role in user_roles: + if role in self._access_lookup and action in self._access_lookup[role]: + logger.debug( + "Access granted: role '%s' can perform action '%s'", role, action + ) + return True + + logger.debug( + "Access denied: roles %s cannot perform action '%s'", user_roles, action + ) + return False + + def get_actions(self, user_roles: UserRoles) -> set[Action]: + """Get the actions that the user can perform based on their roles.""" + actions = { + action + for role in user_roles + for action in self._access_lookup.get(role, set()) + } + + # If the user is allowed the admin action, they can perform any action + if Action.ADMIN in actions: + return set(Action) - {Action.ADMIN} + + return actions diff --git a/src/configuration.py b/src/configuration.py index 573520d20..8c3184541 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -9,6 +9,7 @@ import yaml from models.config import ( + AuthorizationConfiguration, Configuration, Customization, LlamaStackConfiguration, @@ -104,6 +105,18 @@ def authentication_configuration(self) -> AuthenticationConfiguration: return self._configuration.authentication + @property + def authorization_configuration(self) -> AuthorizationConfiguration: + """Return authorization configuration or default no-op configuration.""" + assert ( + self._configuration is not None + ), "logic error: configuration is not loaded" + + if self._configuration.authorization is None: + return AuthorizationConfiguration() + + return self._configuration.authorization + @property def customization(self) -> Optional[Customization]: """Return customization configuration.""" diff --git a/src/models/config.py b/src/models/config.py index 5419ec600..d9e280f4c 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1,9 +1,19 @@ """Model with service configuration.""" from pathlib import Path -from typing import Optional - -from pydantic import BaseModel, model_validator, FilePath, AnyHttpUrl, PositiveInt +from typing import Optional, Any +from enum import Enum + +import jsonpath_ng +from jsonpath_ng.exceptions import JSONPathError +from pydantic import ( + BaseModel, + Field, + model_validator, + FilePath, + AnyHttpUrl, + PositiveInt, +) from typing_extensions import Self, Literal import constants @@ -210,11 +220,115 @@ def check_storage_location_is_set_when_needed(self) -> Self: return self +class JsonPathOperator(str, Enum): + """Supported operators for JSONPath evaluation.""" + + EQUALS = "equals" + CONTAINS = "contains" + IN = "in" + + +class JwtRoleRule(BaseModel): + """Rule for extracting roles from JWT claims.""" + + jsonpath: str # JSONPath expression to evaluate against the JWT payload + operator: JsonPathOperator # Comparison operator + negate: bool = False # If True, negate the rule + value: Any # Value to compare against + roles: list[str] # Roles to assign if rule matches + + @model_validator(mode="after") + def check_jsonpath(self) -> Self: + """Verify that the JSONPath expression is valid.""" + try: + jsonpath_ng.parse(self.jsonpath) + return self + except JSONPathError as e: + raise ValueError( + f"Invalid JSONPath expression: {self.jsonpath}: {e}" + ) from e + + @model_validator(mode="after") + def check_roles(self) -> Self: + """Ensure that at least one role is specified.""" + if not self.roles: + raise ValueError("At least one role must be specified in the rule") + + if len(self.roles) != len(set(self.roles)): + raise ValueError("Roles must be unique in the rule") + + if any(role == "*" for role in self.roles): + raise ValueError( + "The wildcard '*' role is not allowed in role rules, " + "everyone automatically gets this role" + ) + + return self + + +class Action(str, Enum): + """Available actions in the system.""" + + # Special action to allow unrestricted access to all actions + ADMIN = "admin" + + # List the conversations of other users + LIST_OTHERS_CONVERSATIONS = "list_other_conversations" + + # Read the contents of conversations of other users + READ_OTHERS_CONVERSATIONS = "read_other_conversations" + + # Continue the conversations of other users + QUERY_OTHERS_CONVERSATIONS = "query_other_conversations" + + # Delete the conversations of other users + DELETE_OTHERS_CONVERSATIONS = "delete_other_conversations" + + # Access the query endpoint + QUERY = "query" + + # Access the streaming query endpoint + STREAMING_QUERY = "streaming_query" + + # Access the conversation endpoint + GET_CONVERSATION = "get_conversation" + + # List own conversations + LIST_CONVERSATIONS = "list_conversations" + + # Access the conversation delete endpoint + DELETE_CONVERSATION = "delete_conversation" + FEEDBACK = "feedback" + GET_MODELS = "get_models" + GET_METRICS = "get_metrics" + GET_CONFIG = "get_config" + + INFO = "info" + + +class AccessRule(BaseModel): + """Rule defining what actions a role can perform.""" + + role: str # Role name + actions: list[Action] # Allowed actions for this role + + +class AuthorizationConfiguration(BaseModel): + """Authorization configuration.""" + + access_rules: list[AccessRule] = Field( + default_factory=list + ) # Rules for role-based access control + + class JwtConfiguration(BaseModel): """JWT configuration.""" user_id_claim: str = constants.DEFAULT_JWT_UID_CLAIM username_claim: str = constants.DEFAULT_JWT_USER_NAME_CLAIM + role_rules: list[JwtRoleRule] = Field( + default_factory=list + ) # Rules for extracting roles from JWT claims class JwkConfiguration(BaseModel): @@ -310,6 +424,7 @@ class Configuration(BaseModel): database: DatabaseConfiguration = DatabaseConfiguration() mcp_servers: list[ModelContextProtocolServer] = [] authentication: AuthenticationConfiguration = AuthenticationConfiguration() + authorization: Optional[AuthorizationConfiguration] = None customization: Optional[Customization] = None inference: InferenceConfiguration = InferenceConfiguration() diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 993dc4e83..34c0147de 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -19,15 +19,25 @@ def validate_conversation_ownership( - user_id: str, conversation_id: str + user_id: str, conversation_id: str, others_allowed: bool = False ) -> UserConversation | None: - """Validate that the conversation belongs to the user.""" + """Validate that the conversation belongs to the user. + + Validates that the conversation with the given ID belongs to the user with the given ID. + If `others_allowed` is True, it allows conversations that do not belong to the user, + which is useful for admin access. + """ with get_session() as session: - conversation = ( - session.query(UserConversation) - .filter_by(id=conversation_id, user_id=user_id) - .first() + conversation_query = session.query(UserConversation) + + filtered_conversation_query = ( + conversation_query.filter_by(id=conversation_id) + if others_allowed + else conversation_query.filter_by(id=conversation_id, user_id=user_id) ) + + conversation: UserConversation | None = filtered_conversation_query.first() + return conversation diff --git a/tests/unit/app/endpoints/test_config.py b/tests/unit/app/endpoints/test_config.py index 48ea9603d..eb3f4993d 100644 --- a/tests/unit/app/endpoints/test_config.py +++ b/tests/unit/app/endpoints/test_config.py @@ -5,10 +5,14 @@ from fastapi import HTTPException, Request, status from app.endpoints.config import config_endpoint_handler from configuration import AppConfig +from tests.unit.utils.auth_helpers import mock_authorization_resolvers -def test_config_endpoint_handler_configuration_not_loaded(mocker): +@pytest.mark.asyncio +async def test_config_endpoint_handler_configuration_not_loaded(mocker): """Test the config endpoint handler.""" + mock_authorization_resolvers(mocker) + mocker.patch( "app.endpoints.config.configuration._configuration", new=None, @@ -20,14 +24,19 @@ def test_config_endpoint_handler_configuration_not_loaded(mocker): "type": "http", } ) - with pytest.raises(HTTPException) as e: - config_endpoint_handler(request) - assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert e.detail["response"] == "Configuration is not loaded" + auth = ("test_user", "token", {}) + with pytest.raises(HTTPException) as exc_info: + await config_endpoint_handler(auth=auth, request=request) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert exc_info.value.detail["response"] == "Configuration is not loaded" -def test_config_endpoint_handler_configuration_loaded(): +@pytest.mark.asyncio +async def test_config_endpoint_handler_configuration_loaded(mocker): """Test the config endpoint handler.""" + mock_authorization_resolvers(mocker) + config_dict = { "name": "foo", "service": { @@ -49,15 +58,21 @@ def test_config_endpoint_handler_configuration_loaded(): "authentication": { "module": "noop", }, + "authorization": {"access_rules": []}, "customization": None, } cfg = AppConfig() cfg.init_from_dict(config_dict) + + # Mock configuration + mocker.patch("app.endpoints.config.configuration", cfg) + request = Request( scope={ "type": "http", } ) - response = config_endpoint_handler(request) + auth = ("test_user", "token", {}) + response = await config_endpoint_handler(auth=auth, request=request) assert response is not None assert response == cfg.configuration diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index aed6fdc5b..f4fcd4c80 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -1,7 +1,9 @@ +# pylint: disable=redefined-outer-name + """Unit tests for the /conversations REST API endpoints.""" +from fastapi import HTTPException, status, Request import pytest -from fastapi import HTTPException, status from llama_stack_client import APIConnectionError, NotFoundError from app.endpoints.conversations import ( @@ -10,18 +12,34 @@ get_conversations_list_endpoint_handler, simplify_session_data, ) +from models.config import Action from models.responses import ( ConversationResponse, ConversationDeleteResponse, ConversationsListResponse, ) from configuration import AppConfig +from tests.unit.utils.auth_helpers import mock_authorization_resolvers MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") VALID_CONVERSATION_ID = "123e4567-e89b-12d3-a456-426614174000" INVALID_CONVERSATION_ID = "invalid-id" +@pytest.fixture +def dummy_request() -> Request: + """Mock request object for testing.""" + request = Request( + scope={ + "type": "http", + } + ) + + request.state.authorized_actions = set(Action) + + return request + + def create_mock_conversation( mocker, conversation_id, @@ -48,9 +66,11 @@ def mock_database_session(mocker, query_result=None): """Helper function to mock get_session with proper context manager support.""" mock_session = mocker.Mock() if query_result is not None: - mock_session.query.return_value.filter_by.return_value.all.return_value = ( - query_result - ) + # Mock both the filtered and unfiltered query paths + mock_query = mocker.Mock() + mock_query.all.return_value = query_result + mock_query.filter_by.return_value.all.return_value = query_result + mock_session.query.return_value = mock_query # Mock get_session to return a context manager mock_session_context = mocker.MagicMock() @@ -230,27 +250,35 @@ class TestGetConversationEndpoint: """Test cases for the GET /conversations/{conversation_id} endpoint.""" @pytest.mark.asyncio - async def test_configuration_not_loaded(self, mocker): + async def test_configuration_not_loaded(self, mocker, dummy_request): """Test the endpoint when configuration is not loaded.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", None) with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Configuration is not loaded" in exc_info.value.detail["response"] @pytest.mark.asyncio - async def test_invalid_conversation_id_format(self, mocker, setup_configuration): + async def test_invalid_conversation_id_format( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint with an invalid conversation ID format.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=False) with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - INVALID_CONVERSATION_ID, auth=MOCK_AUTH + conversation_id=INVALID_CONVERSATION_ID, + auth=MOCK_AUTH, + request=dummy_request, ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST @@ -258,8 +286,11 @@ async def test_invalid_conversation_id_format(self, mocker, setup_configuration) assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] @pytest.mark.asyncio - async def test_llama_stack_connection_error(self, mocker, setup_configuration): + async def test_llama_stack_connection_error( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when LlamaStack connection fails.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -275,15 +306,20 @@ async def test_llama_stack_connection_error(self, mocker, setup_configuration): # simulate situation when it is not possible to connect to Llama Stack with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] @pytest.mark.asyncio - async def test_llama_stack_not_found_error(self, mocker, setup_configuration): + async def test_llama_stack_not_found_error( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when LlamaStack returns NotFoundError.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -300,7 +336,9 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -309,8 +347,11 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] @pytest.mark.asyncio - async def test_session_retrieve_exception(self, mocker, setup_configuration): + async def test_session_retrieve_exception( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when session retrieval raises an exception.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -325,7 +366,9 @@ async def test_session_retrieve_exception(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -336,9 +379,15 @@ async def test_session_retrieve_exception(self, mocker, setup_configuration): @pytest.mark.asyncio async def test_successful_conversation_retrieval( - self, mocker, setup_configuration, mock_session_data, expected_chat_history - ): + self, + mocker, + setup_configuration, + mock_session_data, + expected_chat_history, + dummy_request, + ): # pylint: disable=too-many-arguments,too-many-positional-arguments """Test successful conversation retrieval with simplified response structure.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -360,7 +409,7 @@ async def test_successful_conversation_retrieval( mock_client_holder.return_value.get_client.return_value = mock_client response = await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, conversation_id=VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert isinstance(response, ConversationResponse) @@ -375,27 +424,35 @@ class TestDeleteConversationEndpoint: """Test cases for the DELETE /conversations/{conversation_id} endpoint.""" @pytest.mark.asyncio - async def test_configuration_not_loaded(self, mocker): + async def test_configuration_not_loaded(self, mocker, dummy_request): """Test the endpoint when configuration is not loaded.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", None) with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Configuration is not loaded" in exc_info.value.detail["response"] @pytest.mark.asyncio - async def test_invalid_conversation_id_format(self, mocker, setup_configuration): + async def test_invalid_conversation_id_format( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint with an invalid conversation ID format.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=False) with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - INVALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=INVALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST @@ -403,8 +460,11 @@ async def test_invalid_conversation_id_format(self, mocker, setup_configuration) assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] @pytest.mark.asyncio - async def test_llama_stack_connection_error(self, mocker, setup_configuration): + async def test_llama_stack_connection_error( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when LlamaStack connection fails.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -419,15 +479,20 @@ async def test_llama_stack_connection_error(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] @pytest.mark.asyncio - async def test_llama_stack_not_found_error(self, mocker, setup_configuration): + async def test_llama_stack_not_found_error( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when LlamaStack returns NotFoundError.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -444,7 +509,9 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -453,8 +520,11 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] @pytest.mark.asyncio - async def test_session_deletion_exception(self, mocker, setup_configuration): + async def test_session_deletion_exception( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when session deletion raises an exception.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -471,7 +541,9 @@ async def test_session_deletion_exception(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -482,8 +554,11 @@ async def test_session_deletion_exception(self, mocker, setup_configuration): ) @pytest.mark.asyncio - async def test_successful_conversation_deletion(self, mocker, setup_configuration): + async def test_successful_conversation_deletion( + self, mocker, setup_configuration, dummy_request + ): """Test successful conversation deletion.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -501,7 +576,7 @@ async def test_successful_conversation_deletion(self, mocker, setup_configuratio mock_client_holder.return_value.get_client.return_value = mock_client response = await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, conversation_id=VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert isinstance(response, ConversationDeleteResponse) @@ -517,18 +592,26 @@ async def test_successful_conversation_deletion(self, mocker, setup_configuratio class TestGetConversationsListEndpoint: """Test cases for the GET /conversations endpoint.""" - def test_configuration_not_loaded(self, mocker): + @pytest.mark.asyncio + async def test_configuration_not_loaded(self, mocker, dummy_request): """Test the endpoint when configuration is not loaded.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", None) with pytest.raises(HTTPException) as exc_info: - get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Configuration is not loaded" in exc_info.value.detail["response"] - def test_successful_conversations_list_retrieval(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_successful_conversations_list_retrieval( + self, mocker, setup_configuration, dummy_request + ): """Test successful retrieval of conversations list.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) # Mock database session and query results @@ -554,7 +637,9 @@ def test_successful_conversations_list_retrieval(self, mocker, setup_configurati ] mock_database_session(mocker, mock_conversations) - response = get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) assert isinstance(response, ConversationsListResponse) assert len(response.conversations) == 2 @@ -567,21 +652,29 @@ def test_successful_conversations_list_retrieval(self, mocker, setup_configurati == "456e7890-e12b-34d5-a678-901234567890" ) - def test_empty_conversations_list(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_empty_conversations_list( + self, mocker, setup_configuration, dummy_request + ): """Test when user has no conversations.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) # Mock database session with no results mock_database_session(mocker, []) - response = get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) assert isinstance(response, ConversationsListResponse) assert len(response.conversations) == 0 assert response.conversations == [] - def test_database_exception(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_database_exception(self, mocker, setup_configuration, dummy_request): """Test when database query raises an exception.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) # Mock database session to raise exception @@ -589,7 +682,9 @@ def test_database_exception(self, mocker, setup_configuration): mock_session.query.side_effect = Exception("Database error") with pytest.raises(HTTPException) as exc_info: - get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unknown error" in exc_info.value.detail["response"] diff --git a/tests/unit/app/endpoints/test_feedback.py b/tests/unit/app/endpoints/test_feedback.py index 337552c05..4a155ef6d 100644 --- a/tests/unit/app/endpoints/test_feedback.py +++ b/tests/unit/app/endpoints/test_feedback.py @@ -11,6 +11,7 @@ store_feedback, feedback_status, ) +from tests.unit.utils.auth_helpers import mock_authorization_resolvers def test_is_feedback_enabled(): @@ -62,9 +63,12 @@ async def test_assert_feedback_enabled(mocker): ], ids=["no_categories", "with_negative_categories"], ) -def test_feedback_endpoint_handler(mocker, feedback_request_data): +@pytest.mark.asyncio +async def test_feedback_endpoint_handler(mocker, feedback_request_data): """Test that feedback_endpoint_handler processes feedback for different payloads.""" + mock_authorization_resolvers(mocker) + # Mock the dependencies mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None) mocker.patch("app.endpoints.feedback.store_feedback", return_value=None) @@ -74,7 +78,7 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data): feedback_request.model_dump.return_value = feedback_request_data # Call the endpoint handler - result = feedback_endpoint_handler( + result = await feedback_endpoint_handler( feedback_request=feedback_request, _ensure_feedback_enabled=assert_feedback_enabled, auth=("test_user_id", "test_username", "test_token"), @@ -84,8 +88,11 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data): assert result.response == "feedback received" -def test_feedback_endpoint_handler_error(mocker): +@pytest.mark.asyncio +async def test_feedback_endpoint_handler_error(mocker): """Test that feedback_endpoint_handler raises an HTTPException on error.""" + mock_authorization_resolvers(mocker) + # Mock the dependencies mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None) mocker.patch( @@ -98,7 +105,7 @@ def test_feedback_endpoint_handler_error(mocker): # Call the endpoint handler and assert it raises an exception with pytest.raises(HTTPException) as exc_info: - feedback_endpoint_handler( + await feedback_endpoint_handler( feedback_request=feedback_request, _ensure_feedback_enabled=assert_feedback_enabled, auth=("test_user_id", "test_username", "test_token"), diff --git a/tests/unit/app/endpoints/test_health.py b/tests/unit/app/endpoints/test_health.py index 02714ad26..6e435adf6 100644 --- a/tests/unit/app/endpoints/test_health.py +++ b/tests/unit/app/endpoints/test_health.py @@ -2,18 +2,22 @@ from unittest.mock import Mock +import pytest from llama_stack.providers.datatypes import HealthStatus - from app.endpoints.health import ( readiness_probe_get_method, liveness_probe_get_method, get_providers_health_statuses, ) from models.responses import ProviderHealthStatus, ReadinessResponse +from tests.unit.utils.auth_helpers import mock_authorization_resolvers +@pytest.mark.asyncio async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): """Test the readiness endpoint handler fails when providers are unhealthy.""" + mock_authorization_resolvers(mocker) + # Mock get_providers_health_statuses to return an unhealthy provider mock_get_providers_health_statuses = mocker.patch( "app.endpoints.health.get_providers_health_statuses" @@ -26,10 +30,11 @@ async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): ) ] - # Mock the Response object + # Mock the Response object and auth mock_response = Mock() + auth = ("test_user", "token", {}) - response = await readiness_probe_get_method(mock_response) + response = await readiness_probe_get_method(auth=auth, response=mock_response) assert response.ready is False assert "test_provider" in response.reason @@ -37,8 +42,11 @@ async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): assert mock_response.status_code == 503 +@pytest.mark.asyncio async def test_readiness_probe_success_when_all_providers_healthy(mocker): """Test the readiness endpoint handler succeeds when all providers are healthy.""" + mock_authorization_resolvers(mocker) + # Mock get_providers_health_statuses to return healthy providers mock_get_providers_health_statuses = mocker.patch( "app.endpoints.health.get_providers_health_statuses" @@ -56,10 +64,11 @@ async def test_readiness_probe_success_when_all_providers_healthy(mocker): ), ] - # Mock the Response object + # Mock the Response object and auth mock_response = Mock() + auth = ("test_user", "token", {}) - response = await readiness_probe_get_method(mock_response) + response = await readiness_probe_get_method(auth=auth, response=mock_response) assert response is not None assert isinstance(response, ReadinessResponse) assert response.ready is True @@ -68,9 +77,13 @@ async def test_readiness_probe_success_when_all_providers_healthy(mocker): assert len(response.providers) == 0 -def test_liveness_probe(): +@pytest.mark.asyncio +async def test_liveness_probe(mocker): """Test the liveness endpoint handler.""" - response = liveness_probe_get_method() + mock_authorization_resolvers(mocker) + + auth = ("test_user", "token", {}) + response = await liveness_probe_get_method(auth=auth) assert response is not None assert response.alive is True diff --git a/tests/unit/app/endpoints/test_info.py b/tests/unit/app/endpoints/test_info.py index 4e8f30ca5..837c58ded 100644 --- a/tests/unit/app/endpoints/test_info.py +++ b/tests/unit/app/endpoints/test_info.py @@ -1,12 +1,15 @@ """Unit tests for the /info REST API endpoint.""" +import pytest from fastapi import Request from app.endpoints.info import info_endpoint_handler from configuration import AppConfig +from tests.unit.utils.auth_helpers import mock_authorization_resolvers -def test_info_endpoint(): +@pytest.mark.asyncio +async def test_info_endpoint(mocker): """Test the info endpoint handler.""" config_dict = { "name": "foo", @@ -27,15 +30,24 @@ def test_info_endpoint(): "feedback_enabled": False, }, "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } cfg = AppConfig() cfg.init_from_dict(config_dict) + + # Mock configuration + mocker.patch("configuration.configuration", cfg) + + mock_authorization_resolvers(mocker) + request = Request( scope={ "type": "http", } ) - response = info_endpoint_handler(request) + auth = ("test_user", "token", {}) + response = await info_endpoint_handler(auth=auth, request=request) assert response is not None assert response.name is not None assert response.version is not None diff --git a/tests/unit/app/endpoints/test_metrics.py b/tests/unit/app/endpoints/test_metrics.py index 8d0742151..95db10b64 100644 --- a/tests/unit/app/endpoints/test_metrics.py +++ b/tests/unit/app/endpoints/test_metrics.py @@ -1,12 +1,17 @@ """Unit tests for the /metrics REST API endpoint.""" +import pytest from fastapi import Request from app.endpoints.metrics import metrics_endpoint_handler +from tests.unit.utils.auth_helpers import mock_authorization_resolvers +@pytest.mark.asyncio async def test_metrics_endpoint(mocker): """Test the metrics endpoint handler.""" + mock_authorization_resolvers(mocker) + mock_setup_metrics = mocker.patch( "app.endpoints.metrics.setup_model_metrics", return_value=None ) @@ -15,7 +20,8 @@ async def test_metrics_endpoint(mocker): "type": "http", } ) - response = await metrics_endpoint_handler(request) + auth = ("test_user", "token", {}) + response = await metrics_endpoint_handler(auth=auth, request=request) assert response is not None assert response.status_code == 200 assert "text/plain" in response.headers["Content-Type"] diff --git a/tests/unit/app/endpoints/test_models.py b/tests/unit/app/endpoints/test_models.py index ca58c4b3e..a00c9b37b 100644 --- a/tests/unit/app/endpoints/test_models.py +++ b/tests/unit/app/endpoints/test_models.py @@ -8,11 +8,14 @@ from app.endpoints.models import models_endpoint_handler from configuration import AppConfig +from tests.unit.utils.auth_helpers import mock_authorization_resolvers @pytest.mark.asyncio async def test_models_endpoint_handler_configuration_not_loaded(mocker): """Test the models endpoint handler if configuration is not loaded.""" + mock_authorization_resolvers(mocker) + # simulate state when no configuration is loaded mocker.patch( "app.endpoints.models.configuration", @@ -26,9 +29,10 @@ async def test_models_endpoint_handler_configuration_not_loaded(mocker): "headers": [(b"authorization", b"Bearer invalid-token")], } ) + auth = ("user_id", "user_name", "token") with pytest.raises(HTTPException) as e: - await models_endpoint_handler(request) + await models_endpoint_handler(request=request, auth=auth) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Configuration is not loaded" @@ -36,6 +40,8 @@ async def test_models_endpoint_handler_configuration_not_loaded(mocker): @pytest.mark.asyncio async def test_models_endpoint_handler_improper_llama_stack_configuration(mocker): """Test the models endpoint handler if Llama Stack configuration is not proper.""" + mock_authorization_resolvers(mocker) + # configuration for tests config_dict = { "name": "test", @@ -57,6 +63,8 @@ async def test_models_endpoint_handler_improper_llama_stack_configuration(mocker }, "mcp_servers": [], "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -72,15 +80,18 @@ async def test_models_endpoint_handler_improper_llama_stack_configuration(mocker "headers": [(b"authorization", b"Bearer invalid-token")], } ) + auth = ("test_user", "token", {}) with pytest.raises(HTTPException) as e: - await models_endpoint_handler(request) + await models_endpoint_handler(request=request, auth=auth) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "LLama stack is not configured" @pytest.mark.asyncio -async def test_models_endpoint_handler_configuration_loaded(): +async def test_models_endpoint_handler_configuration_loaded(mocker): """Test the models endpoint handler if configuration is loaded.""" + mock_authorization_resolvers(mocker) + # configuration for tests config_dict = { "name": "foo", @@ -101,6 +112,8 @@ async def test_models_endpoint_handler_configuration_loaded(): "feedback_enabled": False, }, "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -111,9 +124,10 @@ async def test_models_endpoint_handler_configuration_loaded(): "headers": [(b"authorization", b"Bearer invalid-token")], } ) + auth = ("test_user", "token", {}) with pytest.raises(HTTPException) as e: - await models_endpoint_handler(request) + await models_endpoint_handler(request=request, auth=auth) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Unable to connect to Llama Stack" @@ -121,6 +135,8 @@ async def test_models_endpoint_handler_configuration_loaded(): @pytest.mark.asyncio async def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): """Test the models endpoint handler if configuration is loaded.""" + mock_authorization_resolvers(mocker) + # configuration for tests config_dict = { "name": "foo", @@ -141,6 +157,8 @@ async def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): "feedback_enabled": False, }, "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -159,13 +177,16 @@ async def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): "headers": [(b"authorization", b"Bearer invalid-token")], } ) - response = await models_endpoint_handler(request) + auth = ("test_user", "token", {}) + response = await models_endpoint_handler(request=request, auth=auth) assert response is not None @pytest.mark.asyncio async def test_models_endpoint_llama_stack_connection_error(mocker): """Test the model endpoint when LlamaStack connection fails.""" + mock_authorization_resolvers(mocker) + # configuration for tests config_dict = { "name": "foo", @@ -186,6 +207,8 @@ async def test_models_endpoint_llama_stack_connection_error(mocker): "feedback_enabled": False, }, "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } # mock AsyncLlamaStackClientHolder to raise APIConnectionError @@ -206,8 +229,9 @@ async def test_models_endpoint_llama_stack_connection_error(mocker): "headers": [(b"authorization", b"Bearer invalid-token")], } ) + auth = ("test_user", "token", {}) with pytest.raises(HTTPException) as e: - await models_endpoint_handler(request) + await models_endpoint_handler(request=request, auth=auth) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Unable to connect to Llama Stack" diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 3356f2b27..b12101b47 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -1,9 +1,11 @@ +# pylint: disable=redefined-outer-name + """Unit tests for the /query REST API endpoint.""" # pylint: disable=too-many-lines import json -from fastapi import HTTPException, status +from fastapi import HTTPException, status, Request import pytest from llama_stack_client import APIConnectionError @@ -24,13 +26,26 @@ ) from models.requests import QueryRequest, Attachment -from models.config import ModelContextProtocolServer +from models.config import Action, ModelContextProtocolServer from models.database.conversations import UserConversation from utils.types import ToolCallSummary, TurnSummary MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") +@pytest.fixture +def dummy_request() -> Request: + """Dummy request fixture for testing.""" + req = Request( + scope={ + "type": "http", + } + ) + + req.state.authorized_actions = set(Action) + return req + + def mock_database_operations(mocker): """Helper function to mock database operations for query endpoints.""" mocker.patch( @@ -69,7 +84,7 @@ def setup_configuration_fixture(): @pytest.mark.asyncio -async def test_query_endpoint_handler_configuration_not_loaded(mocker): +async def test_query_endpoint_handler_configuration_not_loaded(mocker, dummy_request): """Test the query endpoint handler if configuration is not loaded.""" # simulate state when no configuration is loaded mocker.patch( @@ -81,7 +96,11 @@ async def test_query_endpoint_handler_configuration_not_loaded(mocker): query = "What is OpenStack?" query_request = QueryRequest(query=query) with pytest.raises(HTTPException) as e: - await query_endpoint_handler(query_request, auth=["test-user", "", "token"]) + await query_endpoint_handler( + query_request=query_request, + request=dummy_request, + auth=["test-user", "", "token"], + ) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.value.detail["response"] == "Configuration is not loaded" @@ -107,7 +126,9 @@ def test_is_transcripts_disabled(setup_configuration, mocker): assert is_transcripts_enabled() is False, "Transcripts should be disabled" -async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): +async def _test_query_endpoint_handler( + mocker, dummy_request: Request, store_transcript_to_file=False +): """Test the query endpoint handler.""" mock_metric = mocker.patch("metrics.llm_calls_total") mock_client = mocker.AsyncMock() @@ -157,7 +178,9 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): query_request = QueryRequest(query=query) - response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) + response = await query_endpoint_handler( + request=dummy_request, query_request=query_request, auth=MOCK_AUTH + ) # Assert the response is as expected assert response.response == summary.llm_response @@ -186,15 +209,21 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): @pytest.mark.asyncio -async def test_query_endpoint_handler_transcript_storage_disabled(mocker): +async def test_query_endpoint_handler_transcript_storage_disabled( + mocker, dummy_request +): """Test the query endpoint handler with transcript storage disabled.""" - await _test_query_endpoint_handler(mocker, store_transcript_to_file=False) + await _test_query_endpoint_handler( + mocker, dummy_request, store_transcript_to_file=False + ) @pytest.mark.asyncio -async def test_query_endpoint_handler_store_transcript(mocker): +async def test_query_endpoint_handler_store_transcript(mocker, dummy_request): """Test the query endpoint handler with transcript storage enabled.""" - await _test_query_endpoint_handler(mocker, store_transcript_to_file=True) + await _test_query_endpoint_handler( + mocker, dummy_request, store_transcript_to_file=True + ) def test_select_model_and_provider_id_from_request(mocker): @@ -1095,7 +1124,7 @@ def test_get_rag_toolgroups(): @pytest.mark.asyncio -async def test_query_endpoint_handler_on_connection_error(mocker): +async def test_query_endpoint_handler_on_connection_error(mocker, dummy_request): """Test the query endpoint handler.""" mock_metric = mocker.patch("metrics.llm_calls_failures_total") @@ -1111,7 +1140,9 @@ async def test_query_endpoint_handler_on_connection_error(mocker): mock_get_client.side_effect = APIConnectionError(request=query_request) with pytest.raises(HTTPException) as exc_info: - await query_endpoint_handler(query_request, auth=MOCK_AUTH) + await query_endpoint_handler( + query_request=query_request, request=dummy_request, auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unable to connect to Llama Stack" in str(exc_info.value.detail) @@ -1119,7 +1150,7 @@ async def test_query_endpoint_handler_on_connection_error(mocker): @pytest.mark.asyncio -async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): +async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_request): """Test that auth tuple is correctly unpacked in query endpoint handler.""" # Mock dependencies mock_config = mocker.Mock() @@ -1159,7 +1190,8 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): mock_database_operations(mocker) _ = await query_endpoint_handler( - QueryRequest(query="test query"), + request=dummy_request, + query_request=QueryRequest(query="test query"), auth=("user123", "username", "auth_token_123"), mcp_headers=None, ) @@ -1168,7 +1200,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): @pytest.mark.asyncio -async def test_query_endpoint_handler_no_tools_true(mocker): +async def test_query_endpoint_handler_no_tools_true(mocker, dummy_request): """Test the query endpoint handler with no_tools=True.""" mock_client = mocker.AsyncMock() mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") @@ -1209,7 +1241,9 @@ async def test_query_endpoint_handler_no_tools_true(mocker): query_request = QueryRequest(query=query, no_tools=True) - response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) + response = await query_endpoint_handler( + request=dummy_request, query_request=query_request, auth=MOCK_AUTH + ) # Assert the response is as expected assert response.response == summary.llm_response @@ -1217,7 +1251,7 @@ async def test_query_endpoint_handler_no_tools_true(mocker): @pytest.mark.asyncio -async def test_query_endpoint_handler_no_tools_false(mocker): +async def test_query_endpoint_handler_no_tools_false(mocker, dummy_request): """Test the query endpoint handler with no_tools=False (default behavior).""" mock_client = mocker.AsyncMock() mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") @@ -1258,7 +1292,9 @@ async def test_query_endpoint_handler_no_tools_false(mocker): query_request = QueryRequest(query=query, no_tools=False) - response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) + response = await query_endpoint_handler( + request=dummy_request, query_request=query_request, auth=MOCK_AUTH + ) # Assert the response is as expected assert response.response == summary.llm_response diff --git a/tests/unit/app/endpoints/test_root.py b/tests/unit/app/endpoints/test_root.py index c2e889dfe..a78856efd 100644 --- a/tests/unit/app/endpoints/test_root.py +++ b/tests/unit/app/endpoints/test_root.py @@ -1,16 +1,22 @@ """Unit tests for the / endpoint handler.""" +import pytest from fastapi import Request from app.endpoints.root import root_endpoint_handler +from tests.unit.utils.auth_helpers import mock_authorization_resolvers -def test_root_endpoint(): +@pytest.mark.asyncio +async def test_root_endpoint(mocker): """Test the root endpoint handler.""" + mock_authorization_resolvers(mocker) + + auth = ("test_user", "token", {}) request = Request( scope={ "type": "http", } ) - response = root_endpoint_handler(request) + response = await root_endpoint_handler(auth=auth, request=request) assert response is not None diff --git a/tests/unit/auth/test_jwk_token.py b/tests/unit/auth/test_jwk_token.py index 7f59bc870..5a579b93d 100644 --- a/tests/unit/auth/test_jwk_token.py +++ b/tests/unit/auth/test_jwk_token.py @@ -3,6 +3,8 @@ """Unit tests for functions defined in auth/jwk_token.py""" import time +import json +import base64 import pytest from fastapi import HTTPException, Request @@ -171,12 +173,19 @@ def set_auth_header(request: Request, token: str): request.scope["headers"] = new_headers +def get_claims(token: str): + """Extract claims from a JWT token without validating it.""" + payload = token.split(".")[1] + padded = payload + "=" * (-len(payload) % 4) + return json.loads(base64.urlsafe_b64decode(padded)) + + def ensure_test_user_id_and_name(auth_tuple, token): """Utility to ensure that the values in the auth tuple match the test values.""" - user_id, username, tuple_token = auth_tuple + user_id, username, token_claims = auth_tuple assert user_id == TEST_USER_ID assert username == TEST_USER_NAME - assert tuple_token == token + assert json.loads(token_claims) == get_claims(token) async def test_valid( diff --git a/tests/unit/authorization/__init__.py b/tests/unit/authorization/__init__.py new file mode 100644 index 000000000..96552c229 --- /dev/null +++ b/tests/unit/authorization/__init__.py @@ -0,0 +1 @@ +"""Unit tests for authorization module.""" diff --git a/tests/unit/authorization/test_resolvers.py b/tests/unit/authorization/test_resolvers.py new file mode 100644 index 000000000..7e02d4071 --- /dev/null +++ b/tests/unit/authorization/test_resolvers.py @@ -0,0 +1,101 @@ +"""Unit tests for the authorization resolvers.""" + +from authorization.resolvers import JwtRolesResolver, GenericAccessResolver +from models.config import JwtRoleRule, AccessRule, JsonPathOperator, Action + + +class TestJwtRolesResolver: + """Test cases for JwtRolesResolver.""" + + async def test_resolve_roles_redhat_employee(self): + """Test role extraction for RedHat employee JWT.""" + role_rules = [ + JwtRoleRule( + jsonpath="$.realm_access.roles[*]", + operator=JsonPathOperator.CONTAINS, + value="redhat:employees", + roles=["employee"], + ) + ] + jwt_resolver = JwtRolesResolver(role_rules) + + jwt_claims = { + "exp": 1754489339, + "iat": 1754488439, + "sub": "f:123:employee@redhat.com", + "realm_access": { + "roles": [ + "uma_authorization", + "redhat:employees", + "default-roles-redhat", + ] + }, + } + + # Mock auth tuple with JWT claims as third element + auth = ("user", "token", str(jwt_claims).replace("'", '"')) + roles = await jwt_resolver.resolve_roles(auth) + assert "employee" in roles + + async def test_resolve_roles_no_match(self): + """Test role extraction when no rules match.""" + role_rules = [ + JwtRoleRule( + jsonpath="$.realm_access.roles[*]", + operator=JsonPathOperator.CONTAINS, + value="redhat:employees", + roles=["employee"], + ) + ] + jwt_resolver = JwtRolesResolver(role_rules) + + jwt_claims = { + "exp": 1754489339, + "iat": 1754488439, + "sub": "f:123:user@example.com", + "realm_access": {"roles": ["uma_authorization", "default-roles-example"]}, + } + + # Mock auth tuple with JWT claims as third element + auth = ("user", "token", str(jwt_claims).replace("'", '"')) + roles = await jwt_resolver.resolve_roles(auth) + assert len(roles) == 0 + + +class TestGenericAccessResolver: + """Test cases for GenericAccessResolver.""" + + async def test_check_access_with_valid_role(self): + """Test access check with valid role.""" + access_rules = [ + AccessRule(role="employee", actions=[Action.QUERY, Action.GET_MODELS]) + ] + resolver = GenericAccessResolver(access_rules) + + # Test access granted + has_access = resolver.check_access(Action.QUERY, {"employee"}) + assert has_access is True + + # Test access denied + has_access = resolver.check_access(Action.FEEDBACK, frozenset(["employee"])) + assert has_access is False + + async def test_check_access_with_invalid_role(self): + """Test access check with invalid role.""" + access_rules = [ + AccessRule(role="employee", actions=[Action.QUERY, Action.GET_MODELS]) + ] + resolver = GenericAccessResolver(access_rules) + + has_access = resolver.check_access(Action.QUERY, {"visitor"}) + assert has_access is False + + async def test_check_access_with_no_roles(self): + """Test access check with no roles.""" + access_rules = [ + AccessRule(role="employee", actions=[Action.QUERY, Action.GET_MODELS]) + ] + resolver = GenericAccessResolver(access_rules) + + has_access = resolver.check_access(Action.QUERY, set()) + assert has_access is False diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py index b10fa1424..ba44cf9c0 100644 --- a/tests/unit/models/test_config.py +++ b/tests/unit/models/test_config.py @@ -552,6 +552,7 @@ def test_dump_configuration(tmp_path) -> None: assert "user_data_collection" in content assert "mcp_servers" in content assert "authentication" in content + assert "authorization" in content assert "customization" in content assert "inference" in content assert "database" in content @@ -619,6 +620,7 @@ def test_dump_configuration(tmp_path) -> None: "sqlite": {"db_path": "/tmp/lightspeed-stack.db"}, "postgres": None, }, + "authorization": None, } diff --git a/tests/unit/utils/auth_helpers.py b/tests/unit/utils/auth_helpers.py new file mode 100644 index 000000000..fe07950de --- /dev/null +++ b/tests/unit/utils/auth_helpers.py @@ -0,0 +1,27 @@ +"""Helper functions for mocking authorization in tests.""" + +from typing import Any +from unittest.mock import AsyncMock, Mock + +from models.config import Action + + +def mock_authorization_resolvers(mocker: Any) -> None: + """Mock authorization resolvers to allow all access. + + This function mocks the authorization middleware to bypass authorization + checks in tests by creating mock resolvers that always grant access. + + Args: + mocker: The pytest-mock mocker fixture + """ + mock_resolvers = mocker.patch( + "authorization.middleware.get_authorization_resolvers" + ) + mock_role_resolver = AsyncMock() + mock_access_resolver = AsyncMock() + mock_role_resolver.resolve_roles.return_value = set() + mock_access_resolver.check_access.return_value = True + # get_actions should be synchronous, not async + mock_access_resolver.get_actions = Mock(return_value=set(Action)) + mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver) diff --git a/uv.lock b/uv.lock index 64b55961c..e24709428 100644 --- a/uv.lock +++ b/uv.lock @@ -1119,6 +1119,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/4f/1195bbac8e0c2acc5f740661631d8d750dc38d4a32b23ee5df3cde6f4e0d/joblib-1.5.1-py3-none-any.whl", hash = "sha256:4719a31f054c7d766948dcd83e9613686b27114f190f717cec7eaa2084f8a74a", size = 307746, upload-time = "2025-05-23T12:04:35.124Z" }, ] +[[package]] +name = "jsonpath-ng" +version = "1.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ply" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/86/08646239a313f895186ff0a4573452038eed8c86f54380b3ebac34d32fb2/jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c", size = 37838, upload-time = "2024-10-11T15:41:42.404Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/5a/73ecb3d82f8615f32ccdadeb9356726d6cae3a4bbc840b437ceb95708063/jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6", size = 30105, upload-time = "2024-11-20T17:58:30.418Z" }, +] + [[package]] name = "jsonschema" version = "4.25.1" @@ -1249,6 +1261,7 @@ dependencies = [ { name = "cachetools" }, { name = "email-validator" }, { name = "fastapi" }, + { name = "jsonpath-ng" }, { name = "kubernetes" }, { name = "llama-stack" }, { name = "llama-stack-client" }, @@ -1327,6 +1340,7 @@ requires-dist = [ { name = "cachetools", specifier = ">=6.1.0" }, { name = "email-validator", specifier = ">=2.2.0" }, { name = "fastapi", specifier = ">=0.115.12" }, + { name = "jsonpath-ng", specifier = ">=1.6.1" }, { name = "kubernetes", specifier = ">=30.1.0" }, { name = "llama-stack", specifier = "==0.2.17" }, { name = "llama-stack-client", specifier = "==0.2.17" }, @@ -2233,6 +2247,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "ply" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/69/882ee5c9d017149285cab114ebeab373308ef0f874fcdac9beb90e0ac4da/ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3", size = 159130, upload-time = "2018-02-15T19:01:31.097Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/58/35da89ee790598a0700ea49b2a66594140f44dec458c07e8e3d4979137fc/ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce", size = 49567, upload-time = "2018-02-15T19:01:27.172Z" }, +] + [[package]] name = "polyleven" version = "0.9.0"