From 3371c974c1cab6e2c9f08caa5e39985614238586 Mon Sep 17 00:00:00 2001 From: Zoltan Szabo Date: Mon, 1 Sep 2025 13:18:05 +0200 Subject: [PATCH] allow disabling query model and provider --- README.md | 4 ++ docs/openapi.json | 3 +- src/app/endpoints/query.py | 4 ++ src/app/endpoints/streaming_query.py | 4 ++ src/authorization/middleware.py | 23 ++++--- src/models/config.py | 2 + src/utils/endpoints.py | 24 ++++++++ tests/unit/app/endpoints/test_query.py | 56 +++++++++++++++++ .../app/endpoints/test_streaming_query.py | 61 ++++++++++++++++++- tests/unit/utils/test_endpoints.py | 23 +++++++ 10 files changed, 194 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 9409317ff..308fb778e 100644 --- a/README.md +++ b/README.md @@ -369,6 +369,10 @@ customization: disable_query_system_prompt: true ``` +### Control model/provider overrides via authorization + +Clients may specify `model` and `provider` in `/v1/query` and `/v1/streaming_query` requests to override the configured defaults. This override is permitted only for callers granted the `MODEL_OVERRIDE` action via the authorization rules; requests that include `model` or `provider` without this permission are rejected with HTTP 403. + ## Safety Shields A single Llama Stack configuration file can include multiple safety shields, which are utilized in agent diff --git a/docs/openapi.json b/docs/openapi.json index 83bcbf51c..874181120 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -786,7 +786,8 @@ "get_models", "get_metrics", "get_config", - "info" + "info", + "model_override" ], "title": "Action", "description": "Available actions in the system." 
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 85bd986ef..243a94e4e 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -35,6 +35,7 @@ get_agent, get_system_prompt, validate_conversation_ownership, + validate_model_provider_override, ) from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.transcripts import store_transcript @@ -174,6 +175,9 @@ async def query_endpoint_handler( """ check_configuration_loaded(configuration) + # Enforce RBAC: optionally disallow overriding model/provider in requests + validate_model_provider_override(query_request, request.state.authorized_actions) + # log Llama Stack configuration, but without sensitive information llama_stack_config = configuration.llama_stack_configuration.model_copy() llama_stack_config.api_key = "********" diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 4aafc6f98..7f2a418b4 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -33,6 +33,7 @@ from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.transcripts import store_transcript from utils.types import TurnSummary +from utils.endpoints import validate_model_provider_override from app.endpoints.query import ( get_rag_toolgroups, @@ -548,6 +549,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals check_configuration_loaded(configuration) + # Enforce RBAC: optionally disallow overriding model/provider in requests + validate_model_provider_override(query_request, request.state.authorized_actions) + # log Llama Stack configuration, but without sensitive information llama_stack_config = configuration.llama_stack_configuration.model_copy() llama_stack_config.api_key = "********" diff --git a/src/authorization/middleware.py b/src/authorization/middleware.py index a919c8789..6d03b8d67 100644 --- 
a/src/authorization/middleware.py +++ b/src/authorization/middleware.py @@ -4,6 +4,7 @@ from functools import wraps, lru_cache from typing import Any, Callable, Tuple from fastapi import HTTPException, status +from starlette.requests import Request from authorization.resolvers import ( AccessResolver, @@ -64,7 +65,9 @@ def get_authorization_resolvers() -> Tuple[RolesResolver, AccessResolver]: ) -async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -> None: +async def _perform_authorization_check( + action: Action, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> None: """Perform authorization check - common logic for all decorators.""" role_resolver, access_resolver = get_authorization_resolvers() @@ -93,12 +96,16 @@ async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) - authorized_actions = access_resolver.get_actions(user_roles) - try: - request = kwargs["request"] - request.state.authorized_actions = authorized_actions - except KeyError: - # This endpoint doesn't seem care about the authorized actions, so no need to set it - pass + req: Request | None = None + if "request" in kwargs and isinstance(kwargs["request"], Request): + req = kwargs["request"] + else: + for arg in args: + if isinstance(arg, Request): + req = arg + break + if req is not None: + req.state.authorized_actions = authorized_actions def authorize(action: Action) -> Callable: @@ -107,7 +114,7 @@ def authorize(action: Action) -> Callable: def decorator(func: Callable) -> Callable: @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: - await _perform_authorization_check(action, kwargs) + await _perform_authorization_check(action, args, kwargs) return await func(*args, **kwargs) return wrapper diff --git a/src/models/config.py b/src/models/config.py index 9166eb4bc..14da03280 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -311,6 +311,8 @@ class Action(str, Enum): GET_CONFIG = "get_config" INFO = "info" + # Allow 
overriding model/provider via request + MODEL_OVERRIDE = "model_override" class AccessRule(ConfigurationBase): diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 34c0147de..1c23384b8 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -9,6 +9,7 @@ import constants from models.requests import QueryRequest from models.database.conversations import UserConversation +from models.config import Action from app.database import get_session from configuration import AppConfig from utils.suid import get_suid @@ -84,6 +85,29 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str: return constants.DEFAULT_SYSTEM_PROMPT +def validate_model_provider_override( + query_request: QueryRequest, authorized_actions: set[Action] | frozenset[Action] +) -> None: + """Validate whether model/provider overrides are allowed by RBAC. + + Raises HTTP 403 if the request includes model or provider and the caller + lacks Action.MODEL_OVERRIDE permission. + """ + if (query_request.model is not None or query_request.provider is not None) and ( + Action.MODEL_OVERRIDE not in authorized_actions + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "response": ( + "This instance does not permit overriding model/provider in the query request " + "(missing permission: MODEL_OVERRIDE). Please remove the model and provider " + "fields from your request." 
+ ) + }, + ) + + # # pylint: disable=R0913,R0917 async def get_agent( client: AsyncLlamaStackClient, diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index b12101b47..d480963f2 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -29,6 +29,7 @@ from models.config import Action, ModelContextProtocolServer from models.database.conversations import UserConversation from utils.types import ToolCallSummary, TurnSummary +from authorization.resolvers import NoopRolesResolver MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") @@ -1507,3 +1508,58 @@ def test_evaluate_model_hints( assert provider_id == expected_provider assert model_id == expected_model + + +@pytest.mark.asyncio +async def test_query_endpoint_rejects_model_provider_override_without_permission( + mocker, dummy_request +): + """Assert 403 and message when request includes model/provider without MODEL_OVERRIDE.""" + # Patch endpoint configuration (no need to set customization) + cfg = AppConfig() + cfg.init_from_dict( + { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {"transcripts_enabled": False}, + "mcp_servers": [], + } + ) + mocker.patch("app.endpoints.query.configuration", cfg) + + # Patch authorization to exclude MODEL_OVERRIDE from authorized actions + access_resolver = mocker.Mock() + access_resolver.check_access.return_value = True + access_resolver.get_actions.return_value = set(Action) - {Action.MODEL_OVERRIDE} + mocker.patch( + "authorization.middleware.get_authorization_resolvers", + return_value=(NoopRolesResolver(), access_resolver), + ) + + # Build a request that tries to override model/provider + query_request = QueryRequest(query="What?", 
model="m", provider="p") + + with pytest.raises(HTTPException) as exc_info: + await query_endpoint_handler( + request=dummy_request, query_request=query_request, auth=MOCK_AUTH + ) + + expected_msg = ( + "This instance does not permit overriding model/provider in the query request " + "(missing permission: MODEL_OVERRIDE). Please remove the model and provider " + "fields from your request." + ) + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN + assert exc_info.value.detail["response"] == expected_msg diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 794e5c183..5faed197c 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -43,7 +43,8 @@ ) from models.requests import QueryRequest, Attachment -from models.config import ModelContextProtocolServer +from models.config import ModelContextProtocolServer, Action +from authorization.resolvers import NoopRolesResolver from utils.types import ToolCallSummary, TurnSummary MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") @@ -1515,3 +1516,61 @@ async def test_retrieve_response_no_tools_false_preserves_functionality( stream=True, toolgroups=expected_toolgroups, ) + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_rejects_model_provider_override_without_permission( + mocker, +): + """Assert 403 when request includes model/provider without MODEL_OVERRIDE.""" + cfg = AppConfig() + cfg.init_from_dict( + { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {"transcripts_enabled": False}, + "mcp_servers": [], + } + ) + mocker.patch("app.endpoints.streaming_query.configuration", cfg) + + # Patch authorization to 
exclude MODEL_OVERRIDE from authorized actions + access_resolver = mocker.Mock() + access_resolver.check_access.return_value = True + access_resolver.get_actions.return_value = set(Action) - {Action.MODEL_OVERRIDE} + mocker.patch( + "authorization.middleware.get_authorization_resolvers", + return_value=(NoopRolesResolver(), access_resolver), + ) + + # Build a query request that tries to override model/provider + query_request = QueryRequest(query="What?", model="m", provider="p") + + request = Request( + scope={ + "type": "http", + } + ) + + with pytest.raises(HTTPException) as exc_info: + await streaming_query_endpoint_handler(request, query_request, auth=MOCK_AUTH) + + expected_msg = ( + "This instance does not permit overriding model/provider in the query request " + "(missing permission: MODEL_OVERRIDE). Please remove the model and provider " + "fields from your request." + ) + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN + assert exc_info.value.detail["response"] == expected_msg diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py index e3b02d6ad..d9a496d07 100644 --- a/tests/unit/utils/test_endpoints.py +++ b/tests/unit/utils/test_endpoints.py @@ -9,6 +9,7 @@ from tests.unit import config_dict from models.requests import QueryRequest +from models.config import Action from utils import endpoints from utils.endpoints import get_agent @@ -591,3 +592,25 @@ async def test_get_agent_no_tools_false_preserves_parser( tool_parser=mock_parser, enable_session_persistence=True, ) + + +def test_validate_model_provider_override_allowed_with_action(): + """Ensure no exception when caller has MODEL_OVERRIDE and request includes model/provider.""" + query_request = QueryRequest(query="q", model="m", provider="p") + authorized_actions = {Action.MODEL_OVERRIDE} + endpoints.validate_model_provider_override(query_request, authorized_actions) + + +def test_validate_model_provider_override_rejected_without_action(): + """Ensure HTTP 
403 when request includes model/provider and caller lacks permission.""" + query_request = QueryRequest(query="q", model="m", provider="p") + authorized_actions: set[Action] = set() + with pytest.raises(HTTPException) as exc_info: + endpoints.validate_model_provider_override(query_request, authorized_actions) + assert exc_info.value.status_code == 403 + + +def test_validate_model_provider_override_no_override_without_action(): + """No exception when request does not include model/provider regardless of permission.""" + query_request = QueryRequest(query="q") + endpoints.validate_model_provider_override(query_request, set())