diff --git a/pyproject.toml b/pyproject.toml index eae1134db..82fb2d9cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "openai==1.99.9", "sqlalchemy>=2.0.42", "semver<4.0.0", + "jsonpath-ng>=1.6.1", ] diff --git a/src/app/endpoints/authorized.py b/src/app/endpoints/authorized.py index 9ada28f74..8fa029ee6 100644 --- a/src/app/endpoints/authorized.py +++ b/src/app/endpoints/authorized.py @@ -1,10 +1,11 @@ """Handler for REST API call to authorized endpoint.""" import logging -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Depends +from auth.interface import AuthTuple from auth import get_auth_dependency from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse @@ -31,7 +32,7 @@ @router.post("/authorized", responses=authorized_responses) async def authorized_endpoint_handler( - auth: Any = Depends(auth_dependency), + auth: Annotated[AuthTuple, Depends(auth_dependency)], ) -> AuthorizedResponse: """ Handle request to the /authorized endpoint. diff --git a/src/app/endpoints/config.py b/src/app/endpoints/config.py index fec854294..a9d2daac1 100644 --- a/src/app/endpoints/config.py +++ b/src/app/endpoints/config.py @@ -1,17 +1,22 @@ """Handler for REST API call to retrieve service configuration.""" import logging -from typing import Any +from typing import Annotated, Any -from fastapi import APIRouter, Request +from fastapi import APIRouter, Request, Depends -from models.config import Configuration +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize from configuration import configuration +from models.config import Action, Configuration from utils.endpoints import check_configuration_loaded logger = logging.getLogger(__name__) router = APIRouter(tags=["config"]) +auth_dependency = get_auth_dependency() + get_config_responses: dict[int | str, dict[str, Any]] = { 200: { @@ -56,7 +61,11 @@ @router.get("/config", responses=get_config_responses) -def config_endpoint_handler(_request: Request) -> Configuration: +@authorize(Action.GET_CONFIG) +async def config_endpoint_handler( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + request: Request, +) -> Configuration: """ Handle requests to the /config endpoint. @@ -66,6 +75,12 @@ def config_endpoint_handler(_request: Request) -> Configuration: Returns: Configuration: The loaded service configuration object. """ + # Used only for authorization + _ = auth + + # Nothing interesting in the request + _ = request + # ensure that configuration is loaded check_configuration_loaded(configuration) diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py index bbb7825d1..ba8440855 100644 --- a/src/app/endpoints/conversations.py +++ b/src/app/endpoints/conversations.py @@ -5,19 +5,21 @@ from llama_stack_client import APIConnectionError, NotFoundError -from fastapi import APIRouter, HTTPException, status, Depends +from fastapi import APIRouter, HTTPException, Request, status, Depends from client import AsyncLlamaStackClientHolder from configuration import configuration +from app.database import get_session +from auth import get_auth_dependency +from authorization.middleware import authorize +from models.config import Action +from models.database.conversations import UserConversation from models.responses import ( ConversationResponse, ConversationDeleteResponse, ConversationsListResponse, ConversationDetails, ) -from models.database.conversations import UserConversation -from auth import get_auth_dependency -from app.database import get_session from utils.endpoints import check_configuration_loaded, validate_conversation_ownership from utils.suid import check_suid @@ -146,7 +148,9 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]: @router.get("/conversations", responses=conversations_list_responses) -def get_conversations_list_endpoint_handler( +@authorize(Action.LIST_CONVERSATIONS) +async def get_conversations_list_endpoint_handler( + request: Request, auth: Any = Depends(auth_dependency), ) -> ConversationsListResponse: """Handle request to retrieve all conversations for the authenticated user.""" @@ -158,11 +162,16 @@ def get_conversations_list_endpoint_handler( with get_session() as session: try: - # Get all conversations for this user - user_conversations = ( - session.query(UserConversation).filter_by(user_id=user_id).all() + query = session.query(UserConversation) + + filtered_query = ( + query + if Action.LIST_OTHERS_CONVERSATIONS in request.state.authorized_actions + else query.filter_by(user_id=user_id) ) + user_conversations = filtered_query.all() + # Return conversation summaries with metadata conversations = [ ConversationDetails( @@ -200,7 +209,9 @@ def get_conversations_list_endpoint_handler( @router.get("/conversations/{conversation_id}", responses=conversation_responses) +@authorize(Action.GET_CONVERSATION) async def get_conversation_endpoint_handler( + request: Request, conversation_id: str, auth: Any = Depends(auth_dependency), ) -> ConversationResponse: @@ -239,6 +250,9 @@ async def get_conversation_endpoint_handler( validate_conversation_ownership( user_id=user_id, conversation_id=conversation_id, + others_allowed=( + Action.READ_OTHERS_CONVERSATIONS in request.state.authorized_actions + ), ) agent_id = conversation_id @@ -309,7 +323,9 @@ async def get_conversation_endpoint_handler( @router.delete( "/conversations/{conversation_id}", responses=conversation_delete_responses ) +@authorize(Action.DELETE_CONVERSATION) async def delete_conversation_endpoint_handler( + request: Request, conversation_id: str, auth: Any = Depends(auth_dependency), ) -> ConversationDeleteResponse: @@ -342,6 +358,9 @@ async def delete_conversation_endpoint_handler( validate_conversation_ownership( user_id=user_id, conversation_id=conversation_id, + others_allowed=( + Action.DELETE_OTHERS_CONVERSATIONS in request.state.authorized_actions + ), ) agent_id = conversation_id diff --git a/src/app/endpoints/feedback.py b/src/app/endpoints/feedback.py index 5d82267ee..39d2a2692 100644 --- a/src/app/endpoints/feedback.py +++ b/src/app/endpoints/feedback.py @@ -5,11 +5,14 @@ from pathlib import Path import json from datetime import datetime, UTC -from fastapi import APIRouter, Request, HTTPException, Depends, status +from fastapi import APIRouter, HTTPException, Depends, Request, status from auth import get_auth_dependency from auth.interface import AuthTuple +from authorization.middleware import authorize from configuration import configuration +from models.config import Action +from models.requests import FeedbackRequest from models.responses import ( ErrorResponse, FeedbackResponse, @@ -17,7 +20,6 @@ UnauthorizedResponse, ForbiddenResponse, ) -from models.requests import FeedbackRequest from utils.suid import get_suid logger = logging.getLogger(__name__) @@ -79,7 +81,8 @@ async def assert_feedback_enabled(_request: Request) -> None: @router.post("", responses=feedback_response) -def feedback_endpoint_handler( +@authorize(Action.FEEDBACK) +async def feedback_endpoint_handler( feedback_request: FeedbackRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], _ensure_feedback_enabled: Any = Depends(assert_feedback_enabled), diff --git a/src/app/endpoints/health.py b/src/app/endpoints/health.py index f1ade175f..f62c9a013 100644 --- a/src/app/endpoints/health.py +++ b/src/app/endpoints/health.py @@ -6,12 +6,16 @@ """ import logging -from typing import Any +from typing import Annotated, Any from llama_stack.providers.datatypes import HealthStatus -from fastapi import APIRouter, status, Response +from fastapi import APIRouter, status, Response, Depends from client import AsyncLlamaStackClientHolder +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize +from models.config import Action from models.responses import ( LivenessResponse, ReadinessResponse, @@ -21,6 +25,8 @@ logger = logging.getLogger(__name__) router = APIRouter(tags=["health"]) +auth_dependency = get_auth_dependency() + async def get_providers_health_statuses() -> list[ProviderHealthStatus]: """ @@ -72,7 +78,11 @@ async def get_providers_health_statuses() -> list[ProviderHealthStatus]: @router.get("/readiness", responses=get_readiness_responses) -async def readiness_probe_get_method(response: Response) -> ReadinessResponse: +@authorize(Action.INFO) +async def readiness_probe_get_method( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + response: Response, +) -> ReadinessResponse: """ Handle the readiness probe endpoint, returning service readiness. @@ -80,6 +90,9 @@ async def readiness_probe_get_method(response: Response) -> ReadinessResponse: and details of unhealthy providers; otherwise, indicates the service is ready. """ + # Used only for authorization + _ = auth + provider_statuses = await get_providers_health_statuses() # Check if any provider is unhealthy (not counting not_implemented as unhealthy) @@ -112,11 +125,17 @@ async def readiness_probe_get_method(response: Response) -> ReadinessResponse: @router.get("/liveness", responses=get_liveness_responses) -def liveness_probe_get_method() -> LivenessResponse: +@authorize(Action.INFO) +async def liveness_probe_get_method( + auth: Annotated[AuthTuple, Depends(auth_dependency)], +) -> LivenessResponse: """ Return the liveness status of the service. Returns: LivenessResponse: Indicates that the service is alive. """ + # Used only for authorization + _ = auth + return LivenessResponse(alive=True) diff --git a/src/app/endpoints/info.py b/src/app/endpoints/info.py index d0c5f1098..ecc124297 100644 --- a/src/app/endpoints/info.py +++ b/src/app/endpoints/info.py @@ -1,17 +1,24 @@ """Handler for REST API call to provide info.""" import logging -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Request +from fastapi import Depends +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize from configuration import configuration -from version import __version__ +from models.config import Action from models.responses import InfoResponse +from version import __version__ logger = logging.getLogger(__name__) router = APIRouter(tags=["info"]) +auth_dependency = get_auth_dependency() + get_info_responses: dict[int | str, dict[str, Any]] = { 200: { @@ -22,7 +29,11 @@ @router.get("/info", responses=get_info_responses) -def info_endpoint_handler(_request: Request) -> InfoResponse: +@authorize(Action.INFO) +async def info_endpoint_handler( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + request: Request, +) -> InfoResponse: """ Handle request to the /info endpoint. @@ -32,4 +43,10 @@ def info_endpoint_handler(_request: Request) -> InfoResponse: Returns: InfoResponse: An object containing the service's name and version. """ + # Used only for authorization + _ = auth + + # Nothing interesting in the request + _ = request + return InfoResponse(name=configuration.configuration.name, version=__version__) diff --git a/src/app/endpoints/metrics.py b/src/app/endpoints/metrics.py index 10b986e6e..e45e6c668 100644 --- a/src/app/endpoints/metrics.py +++ b/src/app/endpoints/metrics.py @@ -1,19 +1,30 @@ """Handler for REST API call to provide metrics.""" +from typing import Annotated from fastapi.responses import PlainTextResponse -from fastapi import APIRouter, Request +from fastapi import APIRouter, Request, Depends from prometheus_client import ( generate_latest, CONTENT_TYPE_LATEST, ) +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize +from models.config import Action from metrics.utils import setup_model_metrics router = APIRouter(tags=["metrics"]) +auth_dependency = get_auth_dependency() + @router.get("/metrics", response_class=PlainTextResponse) -async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse: +@authorize(Action.GET_METRICS) +async def metrics_endpoint_handler( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + request: Request, +) -> PlainTextResponse: """ Handle request to the /metrics endpoint. @@ -24,6 +35,12 @@ async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse: set up, then responds with the current metrics snapshot in Prometheus format. """ + # Used only for authorization + _ = auth + + # Nothing interesting in the request + _ = request + # Setup the model metrics if not already done. This is a one-time setup # and will not be run again on subsequent calls to this endpoint await setup_model_metrics() diff --git a/src/app/endpoints/models.py b/src/app/endpoints/models.py index feac9e4cd..cdd85b930 100644 --- a/src/app/endpoints/models.py +++ b/src/app/endpoints/models.py @@ -1,13 +1,18 @@ """Handler for REST API call to list available models.""" import logging -from typing import Any +from typing import Annotated, Any -from llama_stack_client import APIConnectionError from fastapi import APIRouter, HTTPException, Request, status +from fastapi.params import Depends +from llama_stack_client import APIConnectionError +from auth import get_auth_dependency +from auth.interface import AuthTuple from client import AsyncLlamaStackClientHolder from configuration import configuration +from authorization.middleware import authorize +from models.config import Action from models.responses import ModelsResponse from utils.endpoints import check_configuration_loaded @@ -15,6 +20,9 @@ router = APIRouter(tags=["models"]) +auth_dependency = get_auth_dependency() + + models_responses: dict[int | str, dict[str, Any]] = { 200: { "models": [ @@ -43,7 +51,11 @@ @router.get("/models", responses=models_responses) -async def models_endpoint_handler(_request: Request) -> ModelsResponse: +@authorize(Action.GET_MODELS) +async def models_endpoint_handler( + request: Request, + auth: Annotated[AuthTuple, Depends(auth_dependency)], +) -> ModelsResponse: """ Handle requests to the /models endpoint. @@ -57,6 +69,12 @@ async def models_endpoint_handler(_request: Request) -> ModelsResponse: Returns: ModelsResponse: An object containing the list of available models. """ + # Used only by the middleware + _ = auth + + # Nothing interesting in the request + _ = request + check_configuration_loaded(configuration) llama_stack_configuration = configuration.llama_stack_configuration diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index eae3ae287..7c5343b7f 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -16,7 +16,7 @@ ) from llama_stack_client.types.model_list_response import ModelListResponse -from fastapi import APIRouter, HTTPException, status, Depends +from fastapi import APIRouter, HTTPException, Request, status, Depends from auth import get_auth_dependency from auth.interface import AuthTuple @@ -24,10 +24,12 @@ from configuration import configuration from app.database import get_session import metrics +import constants +from authorization.middleware import authorize +from models.config import Action from models.database.conversations import UserConversation -from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from models.requests import QueryRequest, Attachment -import constants +from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from utils.endpoints import ( check_configuration_loaded, get_agent, @@ -148,7 +150,9 @@ def evaluate_model_hints( @router.post("/query", responses=query_response) +@authorize(Action.QUERY) async def query_endpoint_handler( + request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), @@ -178,7 +182,11 @@ async def query_endpoint_handler( user_conversation: UserConversation | None = None if query_request.conversation_id: user_conversation = validate_conversation_ownership( - user_id=user_id, conversation_id=query_request.conversation_id + user_id=user_id, + conversation_id=query_request.conversation_id, + others_allowed=( + Action.QUERY_OTHERS_CONVERSATIONS in request.state.authorized_actions + ), ) if user_conversation is None: diff --git a/src/app/endpoints/root.py b/src/app/endpoints/root.py index e34ff0d6e..b02234042 100644 --- a/src/app/endpoints/root.py +++ b/src/app/endpoints/root.py @@ -1,14 +1,21 @@ """Handler for the / endpoint.""" import logging +from typing import Annotated -from fastapi import APIRouter, Request +from fastapi import APIRouter, Request, Depends from fastapi.responses import HTMLResponse -logger = logging.getLogger("app.endpoints.handlers") +from auth.interface import AuthTuple +from auth import get_auth_dependency +from authorization.middleware import authorize +from models.config import Action +logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["root"]) +auth_dependency = get_auth_dependency() + INDEX_PAGE = """ @@ -771,7 +778,17 @@ @router.get("/", response_class=HTMLResponse) -def root_endpoint_handler(_request: Request) -> HTMLResponse: +@authorize(Action.INFO) +async def root_endpoint_handler( + auth: Annotated[AuthTuple, Depends(auth_dependency)], + request: Request, +) -> HTMLResponse: """Handle request to the / endpoint.""" + # Used only for authorization + _ = auth + + # Nothing interesting in the request + _ = request + logger.info("Serving index page") return HTMLResponse(INDEX_PAGE) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 69d29ebd1..a7fe6788e 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -22,9 +22,11 @@ from auth import get_auth_dependency from auth.interface import AuthTuple +from authorization.middleware import authorize from client import AsyncLlamaStackClientHolder from configuration import configuration import metrics +from models.config import Action from models.requests import QueryRequest from models.database.conversations import UserConversation from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt @@ -509,8 +511,9 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: @router.post("/streaming_query") +@authorize(Action.STREAMING_QUERY) async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals - _request: Request, + request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(auth_dependency)], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), @@ -533,6 +536,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals HTTPException: Returns HTTP 500 if unable to connect to the Llama Stack server. """ + # Nothing interesting in the request + _ = request + check_configuration_loaded(configuration) llama_stack_config = configuration.llama_stack_configuration diff --git a/src/auth/jwk_token.py b/src/auth/jwk_token.py index 9fa493e6d..ad68abaed 100644 --- a/src/auth/jwk_token.py +++ b/src/auth/jwk_token.py @@ -1,6 +1,7 @@ """Manage authentication flow for FastAPI endpoints with JWK based JWT auth.""" import logging +import json from asyncio import Lock from typing import Any, Callable @@ -188,4 +189,4 @@ async def __call__(self, request: Request) -> tuple[str, str, str]: logger.info("Successfully authenticated user %s (ID: %s)", username, user_id) - return user_id, username, user_token + return user_id, username, json.dumps(claims) diff --git a/src/authorization/__init__.py b/src/authorization/__init__.py new file mode 100644 index 000000000..fff3fce71 --- /dev/null +++ b/src/authorization/__init__.py @@ -0,0 +1 @@ +"""Authorization module for role-based access control.""" diff --git a/src/authorization/middleware.py b/src/authorization/middleware.py new file mode 100644 index 000000000..a919c8789 --- /dev/null +++ b/src/authorization/middleware.py @@ -0,0 +1,115 @@ +"""Authorization middleware and decorators.""" + +import logging +from functools import wraps, lru_cache +from typing import Any, Callable, Tuple +from fastapi import HTTPException, status + +from authorization.resolvers import ( + AccessResolver, + GenericAccessResolver, + JwtRolesResolver, + NoopAccessResolver, + NoopRolesResolver, + RolesResolver, +) +from models.config import Action +from configuration import configuration +import constants + +logger = logging.getLogger(__name__) + + +@lru_cache(maxsize=1) +def get_authorization_resolvers() -> Tuple[RolesResolver, AccessResolver]: + """Get authorization resolvers from configuration (cached).""" + authorization_cfg = configuration.authorization_configuration + authentication_config = configuration.authentication_configuration + + match authentication_config.module: + case ( + constants.AUTH_MOD_NOOP + | constants.AUTH_MOD_K8S + | constants.AUTH_MOD_NOOP_WITH_TOKEN + ): + return ( + NoopRolesResolver(), + NoopAccessResolver(), + ) + case constants.AUTH_MOD_JWK_TOKEN: + jwt_role_rules_unset = ( + len( + authentication_config.jwk_configuration.jwt_configuration.role_rules + ) + ) == 0 + + authz_access_rules_unset = len(authorization_cfg.access_rules) == 0 + + if jwt_role_rules_unset or authz_access_rules_unset: + return NoopRolesResolver(), NoopAccessResolver() + + return ( + JwtRolesResolver( + role_rules=( + authentication_config.jwk_configuration.jwt_configuration.role_rules + ) + ), + GenericAccessResolver(authorization_cfg.access_rules), + ) + + case _: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error", + ) + + +async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -> None: + """Perform authorization check - common logic for all decorators.""" + role_resolver, access_resolver = get_authorization_resolvers() + + try: + auth = kwargs["auth"] + except KeyError as exc: + logger.error( + "Authorization only allowed on endpoints that accept " + "'auth: Any = Depends(get_auth_dependency())'" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error", + ) from exc + + # Everyone gets the everyone (aka *) role + everyone_roles = {"*"} + + user_roles = await role_resolver.resolve_roles(auth) | everyone_roles + + if not access_resolver.check_access(action, user_roles): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Insufficient permissions for action: {action}", + ) + + authorized_actions = access_resolver.get_actions(user_roles) + + try: + request = kwargs["request"] + request.state.authorized_actions = authorized_actions + except KeyError: + # This endpoint doesn't seem care about the authorized actions, so no need to set it + pass + + +def authorize(action: Action) -> Callable: + """Check authorization for an endpoint (async version).""" + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + await _perform_authorization_check(action, kwargs) + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/src/authorization/resolvers.py b/src/authorization/resolvers.py new file mode 100644 index 000000000..cdba3f4f6 --- /dev/null +++ b/src/authorization/resolvers.py @@ -0,0 +1,186 @@ +"""Authorization resolvers for role evaluation and access control.""" + +from abc import ABC, abstractmethod +import json +import logging +from typing import Any + +from jsonpath_ng import parse + +from auth.interface import AuthTuple +from models.config import JwtRoleRule, AccessRule, JsonPathOperator, Action + +logger = logging.getLogger(__name__) + + +UserRoles = set[str] + + +class RoleResolutionError(Exception): + """Custom exception for role resolution errors.""" + + +class RolesResolver(ABC): # pylint: disable=too-few-public-methods + """Base class for all role resolution strategies.""" + + @abstractmethod + async def resolve_roles(self, auth: AuthTuple) -> UserRoles: + """Given an auth tuple, return the list of user roles.""" + + +class NoopRolesResolver(RolesResolver): # pylint: disable=too-few-public-methods + """No-op roles resolver that does not perform any role resolution.""" + + async def resolve_roles(self, auth: AuthTuple) -> UserRoles: + """Return an empty list of roles.""" + _ = auth # Unused + return set() + + +class JwtRolesResolver(RolesResolver): # pylint: disable=too-few-public-methods + """Processes JWT claims with the given JSONPath rules to get roles.""" + + def __init__(self, role_rules: list[JwtRoleRule]): + """Initialize the resolver with rules.""" + self.role_rules = role_rules + + async def resolve_roles(self, auth: AuthTuple) -> UserRoles: + """Extract roles from JWT claims using configured rules.""" + jwt_claims = self._get_claims(auth) + return { + role + for rule in self.role_rules + for role in self.evaluate_role_rules(rule, jwt_claims) + } + + @staticmethod + def evaluate_role_rules(rule: JwtRoleRule, jwt_claims: dict[str, Any]) -> UserRoles: + """Get roles from a JWT role rule if it matches the claims.""" + return ( + set(rule.roles) + if JwtRolesResolver._evaluate_operator( + rule.negate, + [match.value for match in parse(rule.jsonpath).find(jwt_claims)], + rule.operator, + rule.value, + ) + else set() + ) + + @staticmethod + def _get_claims(auth: AuthTuple) -> dict[str, Any]: + """Get the JWT claims from the auth tuple.""" + _, _, token = auth + jwt_claims = json.loads(token) + + if not jwt_claims: + raise RoleResolutionError( + "Invalid authentication token: no JWT claims found" + ) + + return jwt_claims + + @staticmethod + def _evaluate_operator( + negate: bool, match: Any, operator: JsonPathOperator, value: Any + ) -> bool: # pylint: disable=too-many-branches + """Evaluate an operator against a match and value.""" + result = False + match operator: + case JsonPathOperator.EQUALS: + result = match == value + case JsonPathOperator.CONTAINS: + result = value in match + case JsonPathOperator.IN: + result = match in value + + if negate: + result = not result + + return result + + +class AccessResolver(ABC): # pylint: disable=too-few-public-methods + """Base class for all access resolution strategies.""" + + @abstractmethod + def check_access(self, action: Action, user_roles: UserRoles) -> bool: + """Check if the user has access to the specified action based on their roles.""" + + @abstractmethod + def get_actions(self, user_roles: UserRoles) -> set[Action]: + """Get the actions that the user can perform based on their roles.""" + + +class NoopAccessResolver(AccessResolver): # pylint: disable=too-few-public-methods + """No-op access resolver that does not perform any access checks.""" + + def check_access(self, action: Action, user_roles: UserRoles) -> bool: + """Return True always, indicating access is granted.""" + _ = action # We're noop, it doesn't matter, everyone is allowed + _ = user_roles # We're noop, it doesn't matter, everyone is allowed + return True + + def get_actions(self, user_roles: UserRoles) -> set[Action]: + """Return an empty set of actions, indicating no specific actions are allowed.""" + _ = user_roles # We're noop, it doesn't matter, everyone is allowed + return set(Action) - {Action.ADMIN} + + +class GenericAccessResolver(AccessResolver): # pylint: disable=too-few-public-methods + """Generic role-based access resolver, should apply with most authentication methods. + + This resolver simply checks if a list of roles allow a user to perform a specific + action. The special action ADMIN will grant the user the ability to perform any action, + """ + + def __init__(self, access_rules: list[AccessRule]): + """Initialize the access resolver with access rules.""" + for rule in access_rules: + # Since this is nonsensical, it might be a mistake, so hard fail + if Action.ADMIN in rule.actions and len(rule.actions) > 1: + raise ValueError( + "Access rule with 'admin' action cannot have other actions" + ) + + self.access_rules = access_rules + + # Build a lookup table for access rules + self._access_lookup: dict[str, set[Action]] = {} + for rule in access_rules: + if rule.role not in self._access_lookup: + self._access_lookup[rule.role] = set() + self._access_lookup[rule.role].update(rule.actions) + + def check_access(self, action: Action, user_roles: UserRoles) -> bool: + """Check if the user has access to the specified action based on their roles.""" + if action != Action.ADMIN and self.check_access(Action.ADMIN, user_roles): + # Recurse to check if the roles allow the user to perform the admin action, + # if they do, then we allow any action + return True + + for role in user_roles: + if role in self._access_lookup and action in self._access_lookup[role]: + logger.debug( + "Access granted: role '%s' can perform action '%s'", role, action + ) + return True + + logger.debug( + "Access denied: roles %s cannot perform action '%s'", user_roles, action + ) + return False + + def get_actions(self, user_roles: UserRoles) -> set[Action]: + """Get the actions that the user can perform based on their roles.""" + actions = { + action + for role in user_roles + for action in self._access_lookup.get(role, set()) + } + + # If the user is allowed the admin action, they can perform any action + if Action.ADMIN in actions: + return set(Action) - {Action.ADMIN} + + return actions diff --git a/src/configuration.py b/src/configuration.py index 573520d20..8c3184541 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -9,6 +9,7 @@ import yaml from models.config import ( + AuthorizationConfiguration, Configuration, Customization, LlamaStackConfiguration, @@ -104,6 +105,18 @@ def authentication_configuration(self) -> AuthenticationConfiguration: return self._configuration.authentication + @property + def authorization_configuration(self) -> AuthorizationConfiguration: + """Return authorization configuration or default no-op configuration.""" + assert ( + self._configuration is not None + ), "logic error: configuration is not loaded" + + if self._configuration.authorization is None: + return AuthorizationConfiguration() + + return self._configuration.authorization + @property def customization(self) -> Optional[Customization]: """Return customization configuration.""" diff --git a/src/models/config.py b/src/models/config.py index 5419ec600..d9e280f4c 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1,9 +1,19 @@ """Model with service configuration.""" from pathlib import Path -from typing import Optional - -from pydantic import BaseModel, model_validator, FilePath, AnyHttpUrl, PositiveInt +from typing import Optional, Any +from enum import Enum + +import jsonpath_ng +from jsonpath_ng.exceptions import JSONPathError +from pydantic import ( + BaseModel, + Field, + model_validator, + FilePath, + AnyHttpUrl, + PositiveInt, +) from typing_extensions import Self, Literal import constants @@ -210,11 +220,115 @@ def check_storage_location_is_set_when_needed(self) -> Self: return self +class JsonPathOperator(str, Enum): + """Supported operators for JSONPath evaluation.""" + + EQUALS = "equals" + CONTAINS = "contains" + IN = "in" + + +class JwtRoleRule(BaseModel): + """Rule for extracting roles from JWT claims.""" + + jsonpath: str # JSONPath expression to evaluate against the JWT payload + operator: JsonPathOperator # Comparison operator + negate: bool = False # If True, negate the rule + value: Any # Value to compare against + roles: list[str] # Roles to assign if rule matches + + @model_validator(mode="after") + def check_jsonpath(self) -> Self: + """Verify that the JSONPath expression is valid.""" + try: + jsonpath_ng.parse(self.jsonpath) + return self + except JSONPathError as e: + raise ValueError( + f"Invalid JSONPath expression: {self.jsonpath}: {e}" + ) from e + + @model_validator(mode="after") + def check_roles(self) -> Self: + """Ensure that at least one role is specified.""" + if not self.roles: + raise ValueError("At least one role must be specified in the rule") + + if len(self.roles) != len(set(self.roles)): + raise ValueError("Roles must be unique in the rule") + + if any(role == "*" for role in self.roles): + raise ValueError( + "The wildcard '*' role is not allowed in role rules, " + "everyone automatically gets this role" + ) + + return self + + +class Action(str, Enum): + """Available actions in the system.""" + + # Special action to allow unrestricted access to all actions + ADMIN = "admin" + + # List the conversations of other users + LIST_OTHERS_CONVERSATIONS = "list_other_conversations" + + # Read the contents of conversations of other users + READ_OTHERS_CONVERSATIONS = "read_other_conversations" + + # Continue the conversations of other users + QUERY_OTHERS_CONVERSATIONS = "query_other_conversations" + + # Delete the conversations of other users + DELETE_OTHERS_CONVERSATIONS = "delete_other_conversations" + + # Access the query endpoint + QUERY = "query" + + # Access the streaming query endpoint + STREAMING_QUERY = "streaming_query" + + # Access the conversation endpoint + GET_CONVERSATION = "get_conversation" + + # List own conversations + LIST_CONVERSATIONS = "list_conversations" + + # Access the conversation delete endpoint + DELETE_CONVERSATION = "delete_conversation" + FEEDBACK = "feedback" + GET_MODELS = "get_models" + GET_METRICS = "get_metrics" + GET_CONFIG = "get_config" + + INFO = "info" + + +class AccessRule(BaseModel): + """Rule defining what actions a role can perform.""" + + role: str # Role name + actions: list[Action] # Allowed actions for this role + + +class AuthorizationConfiguration(BaseModel): + """Authorization configuration.""" + + access_rules: list[AccessRule] = Field( + default_factory=list + ) # Rules for role-based access control + + class JwtConfiguration(BaseModel): """JWT configuration.""" user_id_claim: str = constants.DEFAULT_JWT_UID_CLAIM username_claim: str = constants.DEFAULT_JWT_USER_NAME_CLAIM + role_rules: list[JwtRoleRule] = Field( + default_factory=list + ) # Rules for extracting roles from JWT claims class JwkConfiguration(BaseModel): @@ -310,6 +424,7 @@ class Configuration(BaseModel): database: DatabaseConfiguration = DatabaseConfiguration() mcp_servers: list[ModelContextProtocolServer] = [] authentication: AuthenticationConfiguration = AuthenticationConfiguration() + authorization: Optional[AuthorizationConfiguration] = None customization: Optional[Customization] = None inference: InferenceConfiguration = InferenceConfiguration() diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 993dc4e83..34c0147de 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -19,15 +19,25 @@ def validate_conversation_ownership( - user_id: str, conversation_id: str + user_id: str, conversation_id: str, others_allowed: bool = False ) -> UserConversation | None: - """Validate that the conversation belongs to the user.""" + """Validate that the conversation belongs to the user. + + Validates that the conversation with the given ID belongs to the user with the given ID. + If `others_allowed` is True, it allows conversations that do not belong to the user, + which is useful for admin access. + """ with get_session() as session: - conversation = ( - session.query(UserConversation) - .filter_by(id=conversation_id, user_id=user_id) - .first() + conversation_query = session.query(UserConversation) + + filtered_conversation_query = ( + conversation_query.filter_by(id=conversation_id) + if others_allowed + else conversation_query.filter_by(id=conversation_id, user_id=user_id) ) + + conversation: UserConversation | None = filtered_conversation_query.first() + return conversation diff --git a/tests/unit/app/endpoints/test_config.py b/tests/unit/app/endpoints/test_config.py index 48ea9603d..eb3f4993d 100644 --- a/tests/unit/app/endpoints/test_config.py +++ b/tests/unit/app/endpoints/test_config.py @@ -5,10 +5,14 @@ from fastapi import HTTPException, Request, status from app.endpoints.config import config_endpoint_handler from configuration import AppConfig +from tests.unit.utils.auth_helpers import mock_authorization_resolvers -def test_config_endpoint_handler_configuration_not_loaded(mocker): +@pytest.mark.asyncio +async def test_config_endpoint_handler_configuration_not_loaded(mocker): """Test the config endpoint handler.""" + mock_authorization_resolvers(mocker) + mocker.patch( "app.endpoints.config.configuration._configuration", new=None, @@ -20,14 +24,19 @@ def test_config_endpoint_handler_configuration_not_loaded(mocker): "type": "http", } ) - with pytest.raises(HTTPException) as e: - config_endpoint_handler(request) - assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert e.detail["response"] == "Configuration is not loaded" + auth = ("test_user", "token", {}) + with pytest.raises(HTTPException) as exc_info: + await config_endpoint_handler(auth=auth, request=request) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert exc_info.value.detail["response"] == "Configuration is not loaded" -def test_config_endpoint_handler_configuration_loaded(): +@pytest.mark.asyncio +async def test_config_endpoint_handler_configuration_loaded(mocker): """Test the config endpoint handler.""" + mock_authorization_resolvers(mocker) + config_dict = { "name": "foo", "service": { @@ -49,15 +58,21 @@ def test_config_endpoint_handler_configuration_loaded(): "authentication": { "module": "noop", }, + "authorization": {"access_rules": []}, "customization": None, } cfg = AppConfig() cfg.init_from_dict(config_dict) + + # Mock configuration + mocker.patch("app.endpoints.config.configuration", cfg) + request = Request( scope={ "type": "http", } ) - response = config_endpoint_handler(request) + auth = ("test_user", "token", {}) + response = await config_endpoint_handler(auth=auth, request=request) assert response is not None assert response == cfg.configuration diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index aed6fdc5b..f4fcd4c80 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -1,7 +1,9 @@ +# pylint: disable=redefined-outer-name + """Unit tests for the /conversations REST API endpoints.""" +from fastapi import HTTPException, status, Request import pytest -from fastapi import HTTPException, status from llama_stack_client import APIConnectionError, NotFoundError from app.endpoints.conversations import ( @@ -10,18 +12,34 @@ get_conversations_list_endpoint_handler, simplify_session_data, ) +from models.config import Action from models.responses import ( ConversationResponse, ConversationDeleteResponse, ConversationsListResponse, ) from configuration import AppConfig +from tests.unit.utils.auth_helpers import mock_authorization_resolvers MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") VALID_CONVERSATION_ID = "123e4567-e89b-12d3-a456-426614174000" INVALID_CONVERSATION_ID = "invalid-id" +@pytest.fixture +def dummy_request() -> Request: + """Mock request object for testing.""" + request = Request( + scope={ + "type": "http", + } + ) + + request.state.authorized_actions = set(Action) + + return request + + def create_mock_conversation( mocker, conversation_id, @@ -48,9 +66,11 @@ def mock_database_session(mocker, query_result=None): """Helper function to mock get_session with proper context manager support.""" mock_session = mocker.Mock() if query_result is not None: - mock_session.query.return_value.filter_by.return_value.all.return_value = ( - query_result - ) + # Mock both the filtered and unfiltered query paths + mock_query = mocker.Mock() + mock_query.all.return_value = query_result + mock_query.filter_by.return_value.all.return_value = query_result + mock_session.query.return_value = mock_query # Mock get_session to return a context manager mock_session_context = mocker.MagicMock() @@ -230,27 +250,35 @@ class TestGetConversationEndpoint: """Test cases for the GET /conversations/{conversation_id} endpoint.""" @pytest.mark.asyncio - async def test_configuration_not_loaded(self, mocker): + async def test_configuration_not_loaded(self, mocker, dummy_request): """Test the endpoint when configuration is not loaded.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", None) with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Configuration is not loaded" in exc_info.value.detail["response"] @pytest.mark.asyncio - async def test_invalid_conversation_id_format(self, mocker, setup_configuration): + async def test_invalid_conversation_id_format( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint with an invalid conversation ID format.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=False) with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - INVALID_CONVERSATION_ID, auth=MOCK_AUTH + conversation_id=INVALID_CONVERSATION_ID, + auth=MOCK_AUTH, + request=dummy_request, ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST @@ -258,8 +286,11 @@ async def test_invalid_conversation_id_format(self, mocker, setup_configuration) assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] @pytest.mark.asyncio - async def test_llama_stack_connection_error(self, mocker, setup_configuration): + async def test_llama_stack_connection_error( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when LlamaStack connection fails.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -275,15 +306,20 @@ async def test_llama_stack_connection_error(self, mocker, setup_configuration): # simulate situation when it is not possible to connect to Llama Stack with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] @pytest.mark.asyncio - async def test_llama_stack_not_found_error(self, mocker, setup_configuration): + async def test_llama_stack_not_found_error( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when LlamaStack returns NotFoundError.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -300,7 +336,9 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -309,8 +347,11 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] @pytest.mark.asyncio - async def test_session_retrieve_exception(self, mocker, setup_configuration): + async def test_session_retrieve_exception( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when session retrieval raises an exception.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -325,7 +366,9 @@ async def test_session_retrieve_exception(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -336,9 +379,15 @@ async def test_session_retrieve_exception(self, mocker, setup_configuration): @pytest.mark.asyncio async def test_successful_conversation_retrieval( - self, mocker, setup_configuration, mock_session_data, expected_chat_history - ): + self, + mocker, + setup_configuration, + mock_session_data, + expected_chat_history, + dummy_request, + ): # pylint: disable=too-many-arguments,too-many-positional-arguments """Test successful conversation retrieval with simplified response structure.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -360,7 +409,7 @@ async def test_successful_conversation_retrieval( mock_client_holder.return_value.get_client.return_value = mock_client response = await get_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, conversation_id=VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert isinstance(response, ConversationResponse) @@ -375,27 +424,35 @@ class TestDeleteConversationEndpoint: """Test cases for the DELETE /conversations/{conversation_id} endpoint.""" @pytest.mark.asyncio - async def test_configuration_not_loaded(self, mocker): + async def test_configuration_not_loaded(self, mocker, dummy_request): """Test the endpoint when configuration is not loaded.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", None) with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Configuration is not loaded" in exc_info.value.detail["response"] @pytest.mark.asyncio - async def test_invalid_conversation_id_format(self, mocker, setup_configuration): + async def test_invalid_conversation_id_format( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint with an invalid conversation ID format.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=False) with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - INVALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=INVALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST @@ -403,8 +460,11 @@ async def test_invalid_conversation_id_format(self, mocker, setup_configuration) assert INVALID_CONVERSATION_ID in exc_info.value.detail["cause"] @pytest.mark.asyncio - async def test_llama_stack_connection_error(self, mocker, setup_configuration): + async def test_llama_stack_connection_error( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when LlamaStack connection fails.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -419,15 +479,20 @@ async def test_llama_stack_connection_error(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] @pytest.mark.asyncio - async def test_llama_stack_not_found_error(self, mocker, setup_configuration): + async def test_llama_stack_not_found_error( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when LlamaStack returns NotFoundError.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -444,7 +509,9 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -453,8 +520,11 @@ async def test_llama_stack_not_found_error(self, mocker, setup_configuration): assert VALID_CONVERSATION_ID in exc_info.value.detail["cause"] @pytest.mark.asyncio - async def test_session_deletion_exception(self, mocker, setup_configuration): + async def test_session_deletion_exception( + self, mocker, setup_configuration, dummy_request + ): """Test the endpoint when session deletion raises an exception.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -471,7 +541,9 @@ async def test_session_deletion_exception(self, mocker, setup_configuration): with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -482,8 +554,11 @@ async def test_session_deletion_exception(self, mocker, setup_configuration): ) @pytest.mark.asyncio - async def test_successful_conversation_deletion(self, mocker, setup_configuration): + async def test_successful_conversation_deletion( + self, mocker, setup_configuration, dummy_request + ): """Test successful conversation deletion.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch("app.endpoints.conversations.validate_conversation_ownership") @@ -501,7 +576,7 @@ async def test_successful_conversation_deletion(self, mocker, setup_configuratio mock_client_holder.return_value.get_client.return_value = mock_client response = await delete_conversation_endpoint_handler( - VALID_CONVERSATION_ID, auth=MOCK_AUTH + request=dummy_request, conversation_id=VALID_CONVERSATION_ID, auth=MOCK_AUTH ) assert isinstance(response, ConversationDeleteResponse) @@ -517,18 +592,26 @@ async def test_successful_conversation_deletion(self, mocker, setup_configuratio class TestGetConversationsListEndpoint: """Test cases for the GET /conversations endpoint.""" - def test_configuration_not_loaded(self, mocker): + @pytest.mark.asyncio + async def test_configuration_not_loaded(self, mocker, dummy_request): """Test the endpoint when configuration is not loaded.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", None) with pytest.raises(HTTPException) as exc_info: - get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Configuration is not loaded" in exc_info.value.detail["response"] - def test_successful_conversations_list_retrieval(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_successful_conversations_list_retrieval( + self, mocker, setup_configuration, dummy_request + ): """Test successful retrieval of conversations list.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) # Mock database session and query results @@ -554,7 +637,9 @@ def test_successful_conversations_list_retrieval(self, mocker, setup_configurati ] mock_database_session(mocker, mock_conversations) - response = get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) assert isinstance(response, ConversationsListResponse) assert len(response.conversations) == 2 @@ -567,21 +652,29 @@ def test_successful_conversations_list_retrieval(self, mocker, setup_configurati == "456e7890-e12b-34d5-a678-901234567890" ) - def test_empty_conversations_list(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_empty_conversations_list( + self, mocker, setup_configuration, dummy_request + ): """Test when user has no conversations.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) # Mock database session with no results mock_database_session(mocker, []) - response = get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + response = await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) assert isinstance(response, ConversationsListResponse) assert len(response.conversations) == 0 assert response.conversations == [] - def test_database_exception(self, mocker, setup_configuration): + @pytest.mark.asyncio + async def test_database_exception(self, mocker, setup_configuration, dummy_request): """Test when database query raises an exception.""" + mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations.configuration", setup_configuration) # Mock database session to raise exception @@ -589,7 +682,9 @@ def test_database_exception(self, mocker, setup_configuration): mock_session.query.side_effect = Exception("Database error") with pytest.raises(HTTPException) as exc_info: - get_conversations_list_endpoint_handler(auth=MOCK_AUTH) + await get_conversations_list_endpoint_handler( + auth=MOCK_AUTH, request=dummy_request + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unknown error" in exc_info.value.detail["response"] diff --git a/tests/unit/app/endpoints/test_feedback.py b/tests/unit/app/endpoints/test_feedback.py index 337552c05..4a155ef6d 100644 --- a/tests/unit/app/endpoints/test_feedback.py +++ b/tests/unit/app/endpoints/test_feedback.py @@ -11,6 +11,7 @@ store_feedback, feedback_status, ) +from tests.unit.utils.auth_helpers import mock_authorization_resolvers def test_is_feedback_enabled(): @@ -62,9 +63,12 @@ async def test_assert_feedback_enabled(mocker): ], ids=["no_categories", "with_negative_categories"], ) -def test_feedback_endpoint_handler(mocker, feedback_request_data): +@pytest.mark.asyncio +async def test_feedback_endpoint_handler(mocker, feedback_request_data): """Test that feedback_endpoint_handler processes feedback for different payloads.""" + mock_authorization_resolvers(mocker) + # Mock the dependencies mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None) mocker.patch("app.endpoints.feedback.store_feedback", return_value=None) @@ -74,7 +78,7 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data): feedback_request.model_dump.return_value = feedback_request_data # Call the endpoint handler - result = feedback_endpoint_handler( + result = await feedback_endpoint_handler( feedback_request=feedback_request, _ensure_feedback_enabled=assert_feedback_enabled, auth=("test_user_id", "test_username", "test_token"), @@ -84,8 +88,11 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data): assert result.response == "feedback received" -def test_feedback_endpoint_handler_error(mocker): +@pytest.mark.asyncio +async def test_feedback_endpoint_handler_error(mocker): """Test that feedback_endpoint_handler raises an HTTPException on error.""" + mock_authorization_resolvers(mocker) + # Mock the dependencies mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None) mocker.patch( @@ -98,7 +105,7 @@ def test_feedback_endpoint_handler_error(mocker): # Call the endpoint handler and assert it raises an exception with pytest.raises(HTTPException) as exc_info: - feedback_endpoint_handler( + await feedback_endpoint_handler( feedback_request=feedback_request, _ensure_feedback_enabled=assert_feedback_enabled, auth=("test_user_id", "test_username", "test_token"), diff --git a/tests/unit/app/endpoints/test_health.py b/tests/unit/app/endpoints/test_health.py index 02714ad26..6e435adf6 100644 --- a/tests/unit/app/endpoints/test_health.py +++ b/tests/unit/app/endpoints/test_health.py @@ -2,18 +2,22 @@ from unittest.mock import Mock +import pytest from llama_stack.providers.datatypes import HealthStatus - from app.endpoints.health import ( readiness_probe_get_method, liveness_probe_get_method, get_providers_health_statuses, ) from models.responses import ProviderHealthStatus, ReadinessResponse +from tests.unit.utils.auth_helpers import mock_authorization_resolvers +@pytest.mark.asyncio async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): """Test the readiness endpoint handler fails when providers are unhealthy.""" + mock_authorization_resolvers(mocker) + # Mock get_providers_health_statuses to return an unhealthy provider mock_get_providers_health_statuses = mocker.patch( "app.endpoints.health.get_providers_health_statuses" @@ -26,10 +30,11 @@ async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): ) ] - # Mock the Response object + # Mock the Response object and auth mock_response = Mock() + auth = ("test_user", "token", {}) - response = await readiness_probe_get_method(mock_response) + response = await readiness_probe_get_method(auth=auth, response=mock_response) assert response.ready is False assert "test_provider" in response.reason @@ -37,8 +42,11 @@ async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): assert mock_response.status_code == 503 +@pytest.mark.asyncio async def test_readiness_probe_success_when_all_providers_healthy(mocker): """Test the readiness endpoint handler succeeds when all providers are healthy.""" + mock_authorization_resolvers(mocker) + # Mock get_providers_health_statuses to return healthy providers mock_get_providers_health_statuses = mocker.patch( "app.endpoints.health.get_providers_health_statuses" @@ -56,10 +64,11 @@ async def test_readiness_probe_success_when_all_providers_healthy(mocker): ), ] - # Mock the Response object + # Mock the Response object and auth mock_response = Mock() + auth = ("test_user", "token", {}) - response = await readiness_probe_get_method(mock_response) + response = await readiness_probe_get_method(auth=auth, response=mock_response) assert response is not None assert isinstance(response, ReadinessResponse) assert response.ready is True @@ -68,9 +77,13 @@ async def test_readiness_probe_success_when_all_providers_healthy(mocker): assert len(response.providers) == 0 -def test_liveness_probe(): +@pytest.mark.asyncio +async def test_liveness_probe(mocker): """Test the liveness endpoint handler.""" - response = liveness_probe_get_method() + mock_authorization_resolvers(mocker) + + auth = ("test_user", "token", {}) + response = await liveness_probe_get_method(auth=auth) assert response is not None assert response.alive is True diff --git a/tests/unit/app/endpoints/test_info.py b/tests/unit/app/endpoints/test_info.py index 4e8f30ca5..837c58ded 100644 --- a/tests/unit/app/endpoints/test_info.py +++ b/tests/unit/app/endpoints/test_info.py @@ -1,12 +1,15 @@ """Unit tests for the /info REST API endpoint.""" +import pytest from fastapi import Request from app.endpoints.info import info_endpoint_handler from configuration import AppConfig +from tests.unit.utils.auth_helpers import mock_authorization_resolvers -def test_info_endpoint(): +@pytest.mark.asyncio +async def test_info_endpoint(mocker): """Test the info endpoint handler.""" config_dict = { "name": "foo", @@ -27,15 +30,24 @@ def test_info_endpoint(): "feedback_enabled": False, }, "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } cfg = AppConfig() cfg.init_from_dict(config_dict) + + # Mock configuration + mocker.patch("configuration.configuration", cfg) + + mock_authorization_resolvers(mocker) + request = Request( scope={ "type": "http", } ) - response = info_endpoint_handler(request) + auth = ("test_user", "token", {}) + response = await info_endpoint_handler(auth=auth, request=request) assert response is not None assert response.name is not None assert response.version is not None diff --git a/tests/unit/app/endpoints/test_metrics.py b/tests/unit/app/endpoints/test_metrics.py index 8d0742151..95db10b64 100644 --- a/tests/unit/app/endpoints/test_metrics.py +++ b/tests/unit/app/endpoints/test_metrics.py @@ -1,12 +1,17 @@ """Unit tests for the /metrics REST API endpoint.""" +import pytest from fastapi import Request from app.endpoints.metrics import metrics_endpoint_handler +from tests.unit.utils.auth_helpers import mock_authorization_resolvers +@pytest.mark.asyncio async def test_metrics_endpoint(mocker): """Test the metrics endpoint handler.""" + mock_authorization_resolvers(mocker) + mock_setup_metrics = mocker.patch( "app.endpoints.metrics.setup_model_metrics", return_value=None ) @@ -15,7 +20,8 @@ async def test_metrics_endpoint(mocker): "type": "http", } ) - response = await metrics_endpoint_handler(request) + auth = ("test_user", "token", {}) + response = await metrics_endpoint_handler(auth=auth, request=request) assert response is not None assert response.status_code == 200 assert "text/plain" in response.headers["Content-Type"] diff --git a/tests/unit/app/endpoints/test_models.py b/tests/unit/app/endpoints/test_models.py index ca58c4b3e..a00c9b37b 100644 --- a/tests/unit/app/endpoints/test_models.py +++ b/tests/unit/app/endpoints/test_models.py @@ -8,11 +8,14 @@ from app.endpoints.models import models_endpoint_handler from configuration import AppConfig +from tests.unit.utils.auth_helpers import mock_authorization_resolvers @pytest.mark.asyncio async def test_models_endpoint_handler_configuration_not_loaded(mocker): """Test the models endpoint handler if configuration is not loaded.""" + mock_authorization_resolvers(mocker) + # simulate state when no configuration is loaded mocker.patch( "app.endpoints.models.configuration", @@ -26,9 +29,10 @@ async def test_models_endpoint_handler_configuration_not_loaded(mocker): "headers": [(b"authorization", b"Bearer invalid-token")], } ) + auth = ("user_id", "user_name", "token") with pytest.raises(HTTPException) as e: - await models_endpoint_handler(request) + await models_endpoint_handler(request=request, auth=auth) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Configuration is not loaded" @@ -36,6 +40,8 @@ async def test_models_endpoint_handler_configuration_not_loaded(mocker): @pytest.mark.asyncio async def test_models_endpoint_handler_improper_llama_stack_configuration(mocker): """Test the models endpoint handler if Llama Stack configuration is not proper.""" + mock_authorization_resolvers(mocker) + # configuration for tests config_dict = { "name": "test", @@ -57,6 +63,8 @@ async def test_models_endpoint_handler_improper_llama_stack_configuration(mocker }, "mcp_servers": [], "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -72,15 +80,18 @@ async def test_models_endpoint_handler_improper_llama_stack_configuration(mocker "headers": [(b"authorization", b"Bearer invalid-token")], } ) + auth = ("test_user", "token", {}) with pytest.raises(HTTPException) as e: - await models_endpoint_handler(request) + await models_endpoint_handler(request=request, auth=auth) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "LLama stack is not configured" @pytest.mark.asyncio -async def test_models_endpoint_handler_configuration_loaded(): +async def test_models_endpoint_handler_configuration_loaded(mocker): """Test the models endpoint handler if configuration is loaded.""" + mock_authorization_resolvers(mocker) + # configuration for tests config_dict = { "name": "foo", @@ -101,6 +112,8 @@ async def test_models_endpoint_handler_configuration_loaded(): "feedback_enabled": False, }, "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -111,9 +124,10 @@ async def test_models_endpoint_handler_configuration_loaded(): "headers": [(b"authorization", b"Bearer invalid-token")], } ) + auth = ("test_user", "token", {}) with pytest.raises(HTTPException) as e: - await models_endpoint_handler(request) + await models_endpoint_handler(request=request, auth=auth) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Unable to connect to Llama Stack" @@ -121,6 +135,8 @@ async def test_models_endpoint_handler_configuration_loaded(): @pytest.mark.asyncio async def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): """Test the models endpoint handler if configuration is loaded.""" + mock_authorization_resolvers(mocker) + # configuration for tests config_dict = { "name": "foo", @@ -141,6 +157,8 @@ async def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): "feedback_enabled": False, }, "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -159,13 +177,16 @@ async def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker): "headers": [(b"authorization", b"Bearer invalid-token")], } ) - response = await models_endpoint_handler(request) + auth = ("test_user", "token", {}) + response = await models_endpoint_handler(request=request, auth=auth) assert response is not None @pytest.mark.asyncio async def test_models_endpoint_llama_stack_connection_error(mocker): """Test the model endpoint when LlamaStack connection fails.""" + mock_authorization_resolvers(mocker) + # configuration for tests config_dict = { "name": "foo", @@ -186,6 +207,8 @@ async def test_models_endpoint_llama_stack_connection_error(mocker): "feedback_enabled": False, }, "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, } # mock AsyncLlamaStackClientHolder to raise APIConnectionError @@ -206,8 +229,9 @@ async def test_models_endpoint_llama_stack_connection_error(mocker): "headers": [(b"authorization", b"Bearer invalid-token")], } ) + auth = ("test_user", "token", {}) with pytest.raises(HTTPException) as e: - await models_endpoint_handler(request) + await models_endpoint_handler(request=request, auth=auth) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Unable to connect to Llama Stack" diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 3356f2b27..b12101b47 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -1,9 +1,11 @@ +# pylint: disable=redefined-outer-name + """Unit tests for the /query REST API endpoint.""" # pylint: disable=too-many-lines import json -from fastapi import HTTPException, status +from fastapi import HTTPException, status, Request import pytest from llama_stack_client import APIConnectionError @@ -24,13 +26,26 @@ ) from models.requests import QueryRequest, Attachment -from models.config import ModelContextProtocolServer +from models.config import Action, ModelContextProtocolServer from models.database.conversations import UserConversation from utils.types import ToolCallSummary, TurnSummary MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") +@pytest.fixture +def dummy_request() -> Request: + """Dummy request fixture for testing.""" + req = Request( + scope={ + "type": "http", + } + ) + + req.state.authorized_actions = set(Action) + return req + + def mock_database_operations(mocker): """Helper function to mock database operations for query endpoints.""" mocker.patch( @@ -69,7 +84,7 @@ def setup_configuration_fixture(): @pytest.mark.asyncio -async def test_query_endpoint_handler_configuration_not_loaded(mocker): +async def test_query_endpoint_handler_configuration_not_loaded(mocker, dummy_request): """Test the query endpoint handler if configuration is not loaded.""" # simulate state when no configuration is loaded mocker.patch( @@ -81,7 +96,11 @@ async def test_query_endpoint_handler_configuration_not_loaded(mocker): query = "What is OpenStack?" query_request = QueryRequest(query=query) with pytest.raises(HTTPException) as e: - await query_endpoint_handler(query_request, auth=["test-user", "", "token"]) + await query_endpoint_handler( + query_request=query_request, + request=dummy_request, + auth=["test-user", "", "token"], + ) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.value.detail["response"] == "Configuration is not loaded" @@ -107,7 +126,9 @@ def test_is_transcripts_disabled(setup_configuration, mocker): assert is_transcripts_enabled() is False, "Transcripts should be disabled" -async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): +async def _test_query_endpoint_handler( + mocker, dummy_request: Request, store_transcript_to_file=False +): """Test the query endpoint handler.""" mock_metric = mocker.patch("metrics.llm_calls_total") mock_client = mocker.AsyncMock() @@ -157,7 +178,9 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): query_request = QueryRequest(query=query) - response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) + response = await query_endpoint_handler( + request=dummy_request, query_request=query_request, auth=MOCK_AUTH + ) # Assert the response is as expected assert response.response == summary.llm_response @@ -186,15 +209,21 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): @pytest.mark.asyncio -async def test_query_endpoint_handler_transcript_storage_disabled(mocker): +async def test_query_endpoint_handler_transcript_storage_disabled( + mocker, dummy_request +): """Test the query endpoint handler with transcript storage disabled.""" - await _test_query_endpoint_handler(mocker, store_transcript_to_file=False) + await _test_query_endpoint_handler( + mocker, dummy_request, store_transcript_to_file=False + ) @pytest.mark.asyncio -async def test_query_endpoint_handler_store_transcript(mocker): +async def test_query_endpoint_handler_store_transcript(mocker, dummy_request): """Test the query endpoint handler with transcript storage enabled.""" - await _test_query_endpoint_handler(mocker, store_transcript_to_file=True) + await _test_query_endpoint_handler( + mocker, dummy_request, store_transcript_to_file=True + ) def test_select_model_and_provider_id_from_request(mocker): @@ -1095,7 +1124,7 @@ def test_get_rag_toolgroups(): @pytest.mark.asyncio -async def test_query_endpoint_handler_on_connection_error(mocker): +async def test_query_endpoint_handler_on_connection_error(mocker, dummy_request): """Test the query endpoint handler.""" mock_metric = mocker.patch("metrics.llm_calls_failures_total") @@ -1111,7 +1140,9 @@ async def test_query_endpoint_handler_on_connection_error(mocker): mock_get_client.side_effect = APIConnectionError(request=query_request) with pytest.raises(HTTPException) as exc_info: - await query_endpoint_handler(query_request, auth=MOCK_AUTH) + await query_endpoint_handler( + query_request=query_request, request=dummy_request, auth=MOCK_AUTH + ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unable to connect to Llama Stack" in str(exc_info.value.detail) @@ -1119,7 +1150,7 @@ async def test_query_endpoint_handler_on_connection_error(mocker): @pytest.mark.asyncio -async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): +async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_request): """Test that auth tuple is correctly unpacked in query endpoint handler.""" # Mock dependencies mock_config = mocker.Mock() @@ -1159,7 +1190,8 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): mock_database_operations(mocker) _ = await query_endpoint_handler( - QueryRequest(query="test query"), + request=dummy_request, + query_request=QueryRequest(query="test query"), auth=("user123", "username", "auth_token_123"), mcp_headers=None, ) @@ -1168,7 +1200,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): @pytest.mark.asyncio -async def test_query_endpoint_handler_no_tools_true(mocker): +async def test_query_endpoint_handler_no_tools_true(mocker, dummy_request): """Test the query endpoint handler with no_tools=True.""" mock_client = mocker.AsyncMock() mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") @@ -1209,7 +1241,9 @@ async def test_query_endpoint_handler_no_tools_true(mocker): query_request = QueryRequest(query=query, no_tools=True) - response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) + response = await query_endpoint_handler( + request=dummy_request, query_request=query_request, auth=MOCK_AUTH + ) # Assert the response is as expected assert response.response == summary.llm_response @@ -1217,7 +1251,7 @@ async def test_query_endpoint_handler_no_tools_true(mocker): @pytest.mark.asyncio -async def test_query_endpoint_handler_no_tools_false(mocker): +async def test_query_endpoint_handler_no_tools_false(mocker, dummy_request): """Test the query endpoint handler with no_tools=False (default behavior).""" mock_client = mocker.AsyncMock() mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") @@ -1258,7 +1292,9 @@ async def test_query_endpoint_handler_no_tools_false(mocker): query_request = QueryRequest(query=query, no_tools=False) - response = await query_endpoint_handler(query_request, auth=MOCK_AUTH) + response = await query_endpoint_handler( + request=dummy_request, query_request=query_request, auth=MOCK_AUTH + ) # Assert the response is as expected assert response.response == summary.llm_response diff --git a/tests/unit/app/endpoints/test_root.py b/tests/unit/app/endpoints/test_root.py index c2e889dfe..a78856efd 100644 --- a/tests/unit/app/endpoints/test_root.py +++ b/tests/unit/app/endpoints/test_root.py @@ -1,16 +1,22 @@ """Unit tests for the / endpoint handler.""" +import pytest from fastapi import Request from app.endpoints.root import root_endpoint_handler +from tests.unit.utils.auth_helpers import mock_authorization_resolvers -def test_root_endpoint(): +@pytest.mark.asyncio +async def test_root_endpoint(mocker): """Test the root endpoint handler.""" + mock_authorization_resolvers(mocker) + + auth = ("test_user", "token", {}) request = Request( scope={ "type": "http", } ) - response = root_endpoint_handler(request) + response = await root_endpoint_handler(auth=auth, request=request) assert response is not None diff --git a/tests/unit/auth/test_jwk_token.py b/tests/unit/auth/test_jwk_token.py index 7f59bc870..5a579b93d 100644 --- a/tests/unit/auth/test_jwk_token.py +++ b/tests/unit/auth/test_jwk_token.py @@ -3,6 +3,8 @@ """Unit tests for functions defined in auth/jwk_token.py""" import time +import json +import base64 import pytest from fastapi import HTTPException, Request @@ -171,12 +173,19 @@ def set_auth_header(request: Request, token: str): request.scope["headers"] = new_headers +def get_claims(token: str): + """Extract claims from a JWT token without validating it.""" + payload = token.split(".")[1] + padded = payload + "=" * (-len(payload) % 4) + return json.loads(base64.urlsafe_b64decode(padded)) + + def ensure_test_user_id_and_name(auth_tuple, token): """Utility to ensure that the values in the auth tuple match the test values.""" - user_id, username, tuple_token = auth_tuple + user_id, username, token_claims = auth_tuple assert user_id == TEST_USER_ID assert username == TEST_USER_NAME - assert tuple_token == token + assert json.loads(token_claims) == get_claims(token) async def test_valid( diff --git a/tests/unit/authorization/__init__.py b/tests/unit/authorization/__init__.py new file mode 100644 index 000000000..96552c229 --- /dev/null +++ b/tests/unit/authorization/__init__.py @@ -0,0 +1 @@ +"""Unit tests for authorization module.""" diff --git a/tests/unit/authorization/test_resolvers.py b/tests/unit/authorization/test_resolvers.py new file mode 100644 index 000000000..7e02d4071 --- /dev/null +++ b/tests/unit/authorization/test_resolvers.py @@ -0,0 +1,101 @@ +"""Unit tests for the authorization resolvers.""" + +from authorization.resolvers import JwtRolesResolver, GenericAccessResolver +from models.config import JwtRoleRule, AccessRule, JsonPathOperator, Action + + +class TestJwtRolesResolver: + """Test cases for JwtRolesResolver.""" + + async def test_resolve_roles_redhat_employee(self): + """Test role extraction for RedHat employee JWT.""" + role_rules = [ + JwtRoleRule( + jsonpath="$.realm_access.roles[*]", + operator=JsonPathOperator.CONTAINS, + value="redhat:employees", + roles=["employee"], + ) + ] + jwt_resolver = JwtRolesResolver(role_rules) + + jwt_claims = { + "exp": 1754489339, + "iat": 1754488439, + "sub": "f:123:employee@redhat.com", + "realm_access": { + "roles": [ + "uma_authorization", + "redhat:employees", + "default-roles-redhat", + ] + }, + } + + # Mock auth tuple with JWT claims as third element + auth = ("user", "token", str(jwt_claims).replace("'", '"')) + roles = await jwt_resolver.resolve_roles(auth) + assert "employee" in roles + + async def test_resolve_roles_no_match(self): + """Test role extraction when no rules match.""" + role_rules = [ + JwtRoleRule( + jsonpath="$.realm_access.roles[*]", + operator=JsonPathOperator.CONTAINS, + value="redhat:employees", + roles=["employee"], + ) + ] + jwt_resolver = JwtRolesResolver(role_rules) + + jwt_claims = { + "exp": 1754489339, + "iat": 1754488439, + "sub": "f:123:user@example.com", + "realm_access": {"roles": ["uma_authorization", "default-roles-example"]}, + } + + # Mock auth tuple with JWT claims as third element + auth = ("user", "token", str(jwt_claims).replace("'", '"')) + roles = await jwt_resolver.resolve_roles(auth) + assert len(roles) == 0 + + +class TestGenericAccessResolver: + """Test cases for GenericAccessResolver.""" + + async def test_check_access_with_valid_role(self): + """Test access check with valid role.""" + access_rules = [ + AccessRule(role="employee", actions=[Action.QUERY, Action.GET_MODELS]) + ] + resolver = GenericAccessResolver(access_rules) + + # Test access granted + has_access = resolver.check_access(Action.QUERY, {"employee"}) + assert has_access is True + + # Test access denied + has_access = resolver.check_access(Action.FEEDBACK, frozenset(["employee"])) + assert has_access is False + + async def test_check_access_with_invalid_role(self): + """Test access check with invalid role.""" + access_rules = [ + AccessRule(role="employee", actions=[Action.QUERY, Action.GET_MODELS]) + ] + resolver = GenericAccessResolver(access_rules) + + has_access = resolver.check_access(Action.QUERY, {"visitor"}) + assert has_access is False + + async def test_check_access_with_no_roles(self): + """Test access check with no roles.""" + access_rules = [ + AccessRule(role="employee", actions=[Action.QUERY, Action.GET_MODELS]) + ] + resolver = GenericAccessResolver(access_rules) + + has_access = resolver.check_access(Action.QUERY, set()) + assert has_access is False diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py index b10fa1424..ba44cf9c0 100644 --- a/tests/unit/models/test_config.py +++ b/tests/unit/models/test_config.py @@ -552,6 +552,7 @@ def test_dump_configuration(tmp_path) -> None: assert "user_data_collection" in content assert "mcp_servers" in content assert "authentication" in content + assert "authorization" in content assert "customization" in content assert "inference" in content assert "database" in content @@ -619,6 +620,7 @@ def test_dump_configuration(tmp_path) -> None: "sqlite": {"db_path": "/tmp/lightspeed-stack.db"}, "postgres": None, }, + "authorization": None, } diff --git a/tests/unit/utils/auth_helpers.py b/tests/unit/utils/auth_helpers.py new file mode 100644 index 000000000..fe07950de --- /dev/null +++ b/tests/unit/utils/auth_helpers.py @@ -0,0 +1,27 @@ +"""Helper functions for mocking authorization in tests.""" + +from typing import Any +from unittest.mock import AsyncMock, Mock + +from models.config import Action + + +def mock_authorization_resolvers(mocker: Any) -> None: + """Mock authorization resolvers to allow all access. + + This function mocks the authorization middleware to bypass authorization + checks in tests by creating mock resolvers that always grant access. + + Args: + mocker: The pytest-mock mocker fixture + """ + mock_resolvers = mocker.patch( + "authorization.middleware.get_authorization_resolvers" + ) + mock_role_resolver = AsyncMock() + mock_access_resolver = AsyncMock() + mock_role_resolver.resolve_roles.return_value = set() + mock_access_resolver.check_access.return_value = True + # get_actions should be synchronous, not async + mock_access_resolver.get_actions = Mock(return_value=set(Action)) + mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver) diff --git a/uv.lock b/uv.lock index 64b55961c..e24709428 100644 --- a/uv.lock +++ b/uv.lock @@ -1119,6 +1119,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/4f/1195bbac8e0c2acc5f740661631d8d750dc38d4a32b23ee5df3cde6f4e0d/joblib-1.5.1-py3-none-any.whl", hash = "sha256:4719a31f054c7d766948dcd83e9613686b27114f190f717cec7eaa2084f8a74a", size = 307746, upload-time = "2025-05-23T12:04:35.124Z" }, ] +[[package]] +name = "jsonpath-ng" +version = "1.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ply" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/86/08646239a313f895186ff0a4573452038eed8c86f54380b3ebac34d32fb2/jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c", size = 37838, upload-time = "2024-10-11T15:41:42.404Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/5a/73ecb3d82f8615f32ccdadeb9356726d6cae3a4bbc840b437ceb95708063/jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6", size = 30105, upload-time = "2024-11-20T17:58:30.418Z" }, +] + [[package]] name = "jsonschema" version = "4.25.1" @@ -1249,6 +1261,7 @@ dependencies = [ { name = "cachetools" }, { name = "email-validator" }, { name = "fastapi" }, + { name = "jsonpath-ng" }, { name = "kubernetes" }, { name = "llama-stack" }, { name = "llama-stack-client" }, @@ -1327,6 +1340,7 @@ requires-dist = [ { name = "cachetools", specifier = ">=6.1.0" }, { name = "email-validator", specifier = ">=2.2.0" }, { name = "fastapi", specifier = ">=0.115.12" }, + { name = "jsonpath-ng", specifier = ">=1.6.1" }, { name = "kubernetes", specifier = ">=30.1.0" }, { name = "llama-stack", specifier = "==0.2.17" }, { name = "llama-stack-client", specifier = "==0.2.17" }, @@ -2233,6 +2247,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "ply" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/69/882ee5c9d017149285cab114ebeab373308ef0f874fcdac9beb90e0ac4da/ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3", size = 159130, upload-time = "2018-02-15T19:01:31.097Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/58/35da89ee790598a0700ea49b2a66594140f44dec458c07e8e3d4979137fc/ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce", size = 49567, upload-time = "2018-02-15T19:01:27.172Z" }, +] + [[package]] name = "polyleven" version = "0.9.0"