diff --git a/docs/source/api.rst b/docs/source/api.rst index 710daa9..dcd19d5 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -52,7 +52,6 @@ Private API ---------------------- .. autofunction:: flask_utils.decorators._is_optional -.. autofunction:: flask_utils.decorators._make_optional .. autofunction:: flask_utils.decorators._is_allow_empty .. autofunction:: flask_utils.decorators._check_type diff --git a/flask_utils/__init__.py b/flask_utils/__init__.py index 06b02a6..753197e 100644 --- a/flask_utils/__init__.py +++ b/flask_utils/__init__.py @@ -1,5 +1,5 @@ # Increment versions here according to SemVer -__version__ = "0.7.1" +__version__ = "1.0.0" from flask_utils.utils import is_it_true from flask_utils.errors import GoneError diff --git a/flask_utils/decorators.py b/flask_utils/decorators.py index 7abf242..b7423d7 100644 --- a/flask_utils/decorators.py +++ b/flask_utils/decorators.py @@ -1,5 +1,6 @@ +import inspect +import warnings from typing import Any -from typing import Dict from typing import Type from typing import Union from typing import Callable @@ -18,8 +19,6 @@ from flask_utils.errors import BadRequestError -VALIDATE_PARAMS_MAX_DEPTH = 4 - def _handle_bad_request( use_error_handlers: bool, @@ -61,41 +60,13 @@ def _is_optional(type_hint: Type) -> bool: # type: ignore return get_origin(type_hint) is Union and type(None) in get_args(type_hint) -def _make_optional(type_hint: Type) -> Type: # type: ignore - """Wrap type hint with :data:`~typing.Optional` if it's not already. - - :param type_hint: Type hint to wrap. - :type type_hint: Type - - :return: Type hint wrapped with :data:`~typing.Optional`. - :rtype: Type - - :Example: - - .. code-block:: python - - from typing import Optional - from flask_utils.decorators import _make_optional - - _make_optional(str) # Optional[str] - _make_optional(Optional[str]) # Optional[str] - - .. versionadded:: 0.2.0 - """ - if not _is_optional(type_hint): - return Optional[type_hint] # type: ignore - return type_hint - - -def _is_allow_empty(value: Any, type_hint: Type, allow_empty: bool) -> bool: # type: ignore +def _is_allow_empty(value: Any, type_hint: Type) -> bool: # type: ignore """Determine if the value is considered empty and whether it's allowed. :param value: Value to check. :type value: Any :param type_hint: Type hint to check against. :type type_hint: Type - :param allow_empty: Whether to allow empty values. - :type allow_empty: bool :return: True if the value is empty and allowed, False otherwise. :rtype: bool @@ -117,22 +88,20 @@ def _is_allow_empty(value: Any, type_hint: Type, allow_empty: bool) -> bool: # .. versionadded:: 0.2.0 """ - if value in [None, "", [], {}]: - # Check if type is explicitly Optional or allow_empty is True - if _is_optional(type_hint) or allow_empty: + if not value: + # Check if type is explicitly Optional + if _is_optional(type_hint): return True return False -def _check_type(value: Any, expected_type: Type, allow_empty: bool = False, curr_depth: int = 0) -> bool: # type: ignore +def _check_type(value: Any, expected_type: Type, curr_depth: int = 0) -> bool: # type: ignore """Check if the value matches the expected type, recursively if necessary. :param value: Value to check. :type value: Any :param expected_type: Expected type. :type expected_type: Type - :param allow_empty: Whether to allow empty values. - :type allow_empty: bool :param curr_depth: Current depth of the recursive check. :type curr_depth: int @@ -169,10 +138,12 @@ def _check_type(value: Any, expected_type: Type, allow_empty: bool = False, curr .. versionadded:: 0.2.0 """ + max_depth = current_app.config.get("VALIDATE_PARAMS_MAX_DEPTH", 4) - if curr_depth >= VALIDATE_PARAMS_MAX_DEPTH: + if curr_depth >= max_depth: + warnings.warn(f"Maximum depth of {max_depth} reached.", SyntaxWarning, stacklevel=2) return True - if expected_type is Any or _is_allow_empty(value, expected_type, allow_empty): # type: ignore + if expected_type is Any or _is_allow_empty(value, expected_type): # type: ignore return True if isinstance(value, bool): @@ -186,11 +157,9 @@ def _check_type(value: Any, expected_type: Type, allow_empty: bool = False, curr args = get_args(expected_type) if origin is Union: - return any(_check_type(value, arg, allow_empty, (curr_depth + 1)) for arg in args) + return any(_check_type(value, arg, (curr_depth + 1)) for arg in args) elif origin is list: - return isinstance(value, list) and all( - _check_type(item, args[0], allow_empty, (curr_depth + 1)) for item in value - ) + return isinstance(value, list) and all(_check_type(item, args[0], (curr_depth + 1)) for item in value) elif origin is dict: key_type, val_type = args if not isinstance(value, dict): @@ -198,29 +167,20 @@ def _check_type(value: Any, expected_type: Type, allow_empty: bool = False, curr for k, v in value.items(): if not isinstance(k, key_type): return False - if not _check_type(v, val_type, allow_empty, (curr_depth + 1)): + if not _check_type(v, val_type, (curr_depth + 1)): return False return True else: return isinstance(value, expected_type) -def validate_params( - parameters: Dict[Any, Any], - allow_empty: bool = False, -) -> Callable: # type: ignore +def validate_params() -> Callable: # type: ignore """ Decorator to validate request JSON body parameters. This decorator ensures that the JSON body of a request matches the specified parameter types and includes all required parameters. - :param parameters: Dictionary of parameters to validate. The keys are parameter names - and the values are the expected types. - :type parameters: Dict[Any, Any] - :param allow_empty: Allow empty values for parameters. Defaults to False. - :type allow_empty: bool - :raises BadRequestError: If the JSON body is malformed, the Content-Type header is missing or incorrect, required parameters are missing, or parameters are of the wrong type. @@ -232,21 +192,13 @@ def validate_params( from flask import Flask, request from typing import List, Dict from flask_utils.decorators import validate_params - from flask_utils.errors.badrequest import BadRequestError + from flask_utils.errors import BadRequestError app = Flask(__name__) @app.route("/example", methods=["POST"]) - @validate_params( - { - "name": str, - "age": int, - "is_student": bool, - "courses": List[str], - "grades": Dict[str, int], - } - ) - def example(): + @validate_params() + def example(name: str, age: int, is_student: bool, courses: List[str], grades: Dict[str, int]): \""" This route expects a JSON body with the following: - name: str @@ -255,8 +207,8 @@ def example(): - courses: list of str - grades: dict with str keys and int values \""" - data = request.get_json() - return data + # Use the data in your route + ... .. tip:: You can use any of the following types: @@ -270,6 +222,37 @@ def example(): * Optional * Union + .. warning:: + If a parameter exists both in the route parameters and in the JSON body, + the value from the JSON body will override the route parameter. A warning + is issued when this occurs. + + :Example: + + .. code-block:: python + + from flask import Flask, request + from typing import List, Dict + from flask_utils.decorators import validate_params + from flask_utils.errors import BadRequestError + + app = Flask(__name__) + + @app.route("/users/", methods=["POST"]) + @validate_params() + def create_user(user_id: int): + print(f"User ID: {user_id}") + return "User created" + + ... + + requests.post("/users/123", json={"user_id": 456}) + # Output: User ID: 456 + + .. versionchanged:: 1.0.0 + The decorator doesn't take any parameters anymore, + it loads the types and parameters from the function signature as well as the Flask route's slug parameters. + .. versionchanged:: 0.7.0 The decorator will now use the custom error handlers if ``register_error_handlers`` has been set to ``True`` when initializing the :class:`~flask_utils.extension.FlaskUtils` extension. @@ -296,33 +279,67 @@ def wrapper(*args, **kwargs): # type: ignore "or the JSON body is missing.", original_exception=e, ) - - if not data: - return _handle_bad_request(use_error_handlers, "Missing json body.") - if not isinstance(data, dict): - return _handle_bad_request(use_error_handlers, "JSON body must be a dict") + return _handle_bad_request( + use_error_handlers, + "JSON body must be a dict", + original_exception=BadRequestError("JSON body must be a dict"), + ) - for key, type_hint in parameters.items(): - if not _is_optional(type_hint) and key not in data: + signature = inspect.signature(fn) + parameters = signature.parameters + # Extract the parameter names and annotations + expected_params = {} + for name, param in parameters.items(): + if param.annotation != inspect.Parameter.empty: + expected_params[name] = param.annotation + else: + warnings.warn(f"Parameter {name} has no type annotation.", SyntaxWarning, stacklevel=2) + expected_params[name] = Any + + request_data = request.view_args # Flask route parameters + for key in data: + if key in request_data: + warnings.warn( + f"Parameter {key} is defined in both the route and the JSON body. " + f"The JSON body will override the route parameter.", + SyntaxWarning, + stacklevel=2, + ) + request_data.update(data or {}) + + for key, type_hint in expected_params.items(): + # TODO: Handle deeply nested types + if key not in request_data and not _is_optional(type_hint): return _handle_bad_request( - use_error_handlers, f"Missing key: {key}", f"Expected keys are: {list(parameters.keys())}" + use_error_handlers, f"Missing key: {key}", f"Expected keys are: {list(expected_params.keys())}" ) - for key in data: - if key not in parameters: + for key in request_data: + if key not in expected_params: return _handle_bad_request( - use_error_handlers, f"Unexpected key: {key}.", f"Expected keys are: {list(parameters.keys())}" + use_error_handlers, + f"Unexpected key: {key}.", + f"Expected keys are: {list(expected_params.keys())}", ) - for key in data: - if key in parameters and not _check_type(data[key], parameters[key], allow_empty): + for key, value in request_data.items(): + if key in expected_params and not _check_type(value, expected_params[key]): return _handle_bad_request( use_error_handlers, f"Wrong type for key {key}.", - f"It should be {getattr(parameters[key], '__name__', str(parameters[key]))}", + f"It should be {getattr(expected_params[key], '__name__', str(expected_params[key]))}", ) + provided_values = {} + for key in expected_params: + if not _is_optional(expected_params[key]): + provided_values[key] = request_data[key] + else: + provided_values[key] = request_data.get(key, None) + + kwargs.update(provided_values) + return fn(*args, **kwargs) return wrapper diff --git a/flask_utils/extension.py b/flask_utils/extension.py index 3ed1610..5a7b2ee 100644 --- a/flask_utils/extension.py +++ b/flask_utils/extension.py @@ -33,6 +33,26 @@ class FlaskUtils(object): fu = FlaskUtils() fu.init_app(app) + .. versionchanged:: 1.0.0 + The :func:`~flask_utils.decorators.validate_params` decorator will now use the ``VALIDATE_PARAMS_MAX_DEPTH`` + config variable to determine the maximum depth of the validation for dictionaries. + + :Example: + + .. code-block:: python + + from flask import Flask + from flask_utils import FlaskUtils + + app = Flask(__name__) + fu = FlaskUtils(app) + app.config["VALIDATE_PARAMS_MAX_DEPTH"] = 3 + + The `VALIDATE_PARAMS_MAX_DEPTH` configuration determines the maximum depth of nested dictionary validation + when using the `validate_params` decorator. This allows fine-tuning of validation behavior for + complex nested structures. + + .. versionadded:: 0.5.0 """ @@ -107,3 +127,5 @@ def init_app(self, app: Flask, register_error_handlers: bool = True) -> None: self.has_error_handlers_registered = True app.extensions["flask_utils"] = self + # Default depth of 4 allows for moderately nested structures while not slowing down too much the validation + app.config.setdefault("VALIDATE_PARAMS_MAX_DEPTH", 4) diff --git a/tests/test_validate_params.py b/tests/test_validate_params.py index dce0693..2b4b7ba 100644 --- a/tests/test_validate_params.py +++ b/tests/test_validate_params.py @@ -1,3 +1,4 @@ +import warnings from typing import Any from typing import Dict from typing import List @@ -5,6 +6,8 @@ from typing import Optional import pytest +from flask import Flask +from flask import jsonify from flask_utils import validate_params @@ -13,8 +16,8 @@ class TestBadFormat: @pytest.fixture(autouse=True) def setup_routes(self, flask_client): @flask_client.post("/bad-format") - @validate_params({"name": str}) - def bad_format(): + @validate_params() + def bad_format(name: str): return "OK", 200 def test_malformed_body(self, client): @@ -35,11 +38,14 @@ def test_bad_content_type(self, client): ) def test_missing_body(self, client): - response = client.post("/bad-format", json={}) + response = client.post("/bad-format") assert response.status_code == 400 error_dict = response.get_json()["error"] - assert error_dict["message"] == "Missing json body." + assert ( + error_dict["message"] + == "The Content-Type header is missing or is not set to application/json, or the JSON body is missing." + ) def test_body_not_dict(self, client): response = client.post("/bad-format", json=["not", "a", "dict"]) @@ -53,17 +59,8 @@ class TestDefaultTypes: @pytest.fixture(autouse=True) def setup_routes(self, flask_client): @flask_client.post("/default-types") - @validate_params( - { - "name": str, - "age": int, - "is_active": bool, - "weight": float, - "hobbies": list, - "address": dict, - } - ) - def default_types(): + @validate_params() + def default_types(name: str, age: int, is_active: bool, weight: float, hobbies: list, address: dict): return "OK", 200 def test_valid_request(self, client): @@ -244,8 +241,8 @@ class TestTupleUnion: @pytest.fixture(autouse=True) def setup_routes(self, flask_client): @flask_client.post("/tuple-union") - @validate_params({"name": (str, int)}) - def union(): + @validate_params() + def union(name: (str, int)): return "OK", 200 def test_valid_request(self, client): @@ -267,8 +264,8 @@ class TestUnion: @pytest.fixture(autouse=True) def setup_routes(self, flask_client): @flask_client.post("/union") - @validate_params({"name": Union[str, int]}) - def union(): + @validate_params() + def union(name: Union[str, int]): return "OK", 200 def test_valid_request(self, client): @@ -290,8 +287,8 @@ class TestOptional: @pytest.fixture(autouse=True) def setup_routes(self, flask_client): @flask_client.post("/optional") - @validate_params({"name": str, "age": Optional[int]}) - def optional(): + @validate_params() + def optional(name: str, age: Optional[int]): return "OK", 200 def test_valid_request(self, client): @@ -313,8 +310,8 @@ class TestList: @pytest.fixture(autouse=True) def setup_routes(self, flask_client): @flask_client.post("/list") - @validate_params({"name": List[str]}) - def list(): + @validate_params() + def list(name: List[str]): return "OK", 200 def test_valid_request(self, client): @@ -333,8 +330,8 @@ class TestDict: @pytest.fixture(autouse=True) def setup_routes(self, flask_client): @flask_client.post("/dict") - @validate_params({"name": Dict[str, int]}) - def dict_route(): + @validate_params() + def dict_route(name: Dict[str, int]): return "OK", 200 def test_valid_request(self, client): @@ -353,8 +350,8 @@ class TestAny: @pytest.fixture(autouse=True) def setup_routes(self, flask_client): @flask_client.post("/any") - @validate_params({"name": Any}) - def any_route(): + @validate_params() + def any_route(name: Any): return "OK", 200 def test_valid_request(self, client): @@ -381,16 +378,10 @@ class TestMixAndMatch: @pytest.fixture(autouse=True) def setup_routes(self, flask_client): @flask_client.post("/mix-and-match") - @validate_params( - { - "name": Union[str, int], - "age": Optional[int], - "hobbies": List[str], - "address": Dict[str, int], - "is_active": Any, - } - ) - def mix_and_match(): + @validate_params() + def mix_and_match( + name: Union[str, int], age: Optional[int], hobbies: List[str], address: Dict[str, int], is_active: Any + ): return "OK", 200 def test_valid_request(self, client): @@ -437,3 +428,191 @@ def test_unexpected_key(self, client): error_dict = response.get_json()["error"] assert error_dict["message"] == "Unexpected key: unexpected_key." + + +class TestValidateParamsWithoutErrorHandlers: + @pytest.fixture(scope="function") + def setup_routes(self): + app = Flask(__name__) + app.testing = True + + @app.route("/example", methods=["POST", "GET"]) + @validate_params() + def example(name: str): + return "OK", 200 + + yield app + + @pytest.fixture(autouse=True) + def client(self, setup_routes): + yield setup_routes.test_client() + + def test_missing_content_type(self, client): + response = client.get("/example") + assert response.status_code == 400 + assert ( + response.json["error"] + == "The Content-Type header is missing or is not set to application/json, or the JSON body is missing." + ) + assert "success" not in response.json + assert "code" not in response.json + assert not isinstance(response.json["error"], dict) + + def test_malformed_json_body(self, client): + response = client.post("/example", data="not a json", headers={"Content-Type": "application/json"}) + assert response.status_code == 400 + assert response.json["error"] == "The Json Body is malformed." + assert "success" not in response.json + assert "code" not in response.json + assert not isinstance(response.json["error"], dict) + + def test_json_body_not_dict(self, client): + response = client.post("/example", json=["not", "a", "dict"]) + assert response.status_code == 400 + assert response.json["error"] == "JSON body must be a dict" + assert "success" not in response.json + assert "code" not in response.json + assert not isinstance(response.json["error"], dict) + + def test_missing_key(self, client, setup_routes): + @setup_routes.route("/example2", methods=["POST"]) + @validate_params() + def example2(name: str, age: int): + return "OK", 200 + + response = client.post("/example2", json={"name": "John"}) + assert response.status_code == 400 + assert response.json["error"] == "Missing key: age" + assert "success" not in response.json + assert "code" not in response.json + assert not isinstance(response.json["error"], dict) + + def test_unexpected_key(self, client): + response = client.post("/example", json={"name": "John", "extra": "value"}) + assert response.status_code == 400 + assert response.json["error"] == "Unexpected key: extra." + assert "success" not in response.json + assert "code" not in response.json + assert not isinstance(response.json["error"], dict) + + def test_wrong_type(self, client): + response = client.post("/example", json={"name": 123}) + assert response.status_code == 400 + assert response.json["error"] == "Wrong type for key name." + assert "success" not in response.json + assert "code" not in response.json + assert not isinstance(response.json["error"], dict) + + +class TestAnnotationWarnings: + @pytest.fixture(autouse=True) + def setup_routes(self, flask_client): + @flask_client.post("/example/") + @validate_params() + def example(user_id: int, name): + return "OK", 200 + + def test_no_type_annotation(self, client): + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + # Trigger a warning. + response = client.post("/example/1", json={"name": "John"}) + + assert response.status_code == 200 + assert len(w) == 1 + assert issubclass(w[-1].category, SyntaxWarning) + assert "Parameter name has no type annotation." in str(w[-1].message) + + def test_duplicate_keys(self, client): + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + # Trigger a warning. + response = client.post("/example/1", json={"name": "John", "user_id": 1}) + + assert response.status_code == 200 + assert len(w) == 2 + assert issubclass(w[-1].category, SyntaxWarning) + assert ( + "Parameter user_id is defined in both the route and the JSON body. " + "The JSON body will override the route parameter." in str(w[-1].message) + ) + + +class TestMaxDepth: + @pytest.fixture(autouse=True) + def setup_routes(self, flask_client): + @flask_client.post("/example") + @validate_params() + def example(user_info: Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, str]]]]]]): + return "OK", 200 + + def test_max_depth(self, client): + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + # Trigger a warning. + response = client.post( + "/example", + json={ + "user_info": { + "name": {"age": {"is_active": {"weight": {"hobbies": {"address": {"city": "New York City"}}}}}} + } + }, + ) + + assert response.status_code == 200 + assert len(w) == 1 + assert issubclass(w[-1].category, SyntaxWarning) + assert "Maximum depth of 4 reached." in str(w[-1].message) + + def test_change_max_depth(self, client, flask_client): + flask_client.config["VALIDATE_PARAMS_MAX_DEPTH"] = 1 + + with warnings.catch_warnings(record=True) as w: + client.post( + "/example", + json={ + "user_info": { + "name": {"age": {"is_active": {"weight": {"hobbies": {"address": {"city": "New York City"}}}}}} + } + }, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, SyntaxWarning) + assert "Maximum depth of 1 reached." in str(w[-1].message) + + +class TestJSONOverridesRouteParams: + @pytest.fixture(autouse=True) + def setup_routes(self, flask_client): + @flask_client.post("/users/") + @validate_params() + def create_user(user_id: int): + return f"{user_id}" + + def test_valid_request(self, client): + with warnings.catch_warnings(record=True): + response = client.post("/users/123", json={"user_id": 456}) + assert response.status_code == 200 + assert response.text == "456" + + +class TestEmptyValues: + @pytest.fixture(autouse=True) + def setup_routes(self, flask_client): + @flask_client.route("/empty", methods=["POST"]) + @validate_params() + def empty_route(name: Optional[str]): + return jsonify({"name": name}) + + def test_empty_value_optional(self, client): + response = client.post("/empty", json={}) + assert response.status_code == 200 + assert response.get_json() == {"name": None} + + # Testing with 'name' as empty string + response = client.post("/empty", json={"name": ""}) + assert response.status_code == 200 + assert response.get_json() == {"name": ""} diff --git a/tests/test_validate_params_with_error_handlers.py b/tests/test_validate_params_with_error_handlers.py deleted file mode 100644 index ad1d48d..0000000 --- a/tests/test_validate_params_with_error_handlers.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest -from flask import Flask - -from flask_utils import validate_params - - -class TestValidateParamsWithoutErrorHandlers: - @pytest.fixture(scope="function") - def setup_routes(self): - app = Flask(__name__) - app.testing = True - - @app.route("/example", methods=["POST", "GET"]) - @validate_params({"name": str}) - def example(): - return "OK", 200 - - yield app - - @pytest.fixture(autouse=True) - def client(self, setup_routes): - yield setup_routes.test_client() - - def test_missing_content_type(self, client): - response = client.get("/example") - assert response.status_code == 400 - assert ( - response.json["error"] - == "The Content-Type header is missing or is not set to application/json, or the JSON body is missing." - ) - assert "success" not in response.json - assert "code" not in response.json - assert not isinstance(response.json["error"], dict) - - def test_malformed_json_body(self, client): - response = client.post("/example", data="not a json", headers={"Content-Type": "application/json"}) - assert response.status_code == 400 - assert response.json["error"] == "The Json Body is malformed." - assert "success" not in response.json - assert "code" not in response.json - assert not isinstance(response.json["error"], dict) - - def test_json_body_not_dict(self, client): - response = client.post("/example", json=["not", "a", "dict"]) - assert response.status_code == 400 - assert response.json["error"] == "JSON body must be a dict" - assert "success" not in response.json - assert "code" not in response.json - assert not isinstance(response.json["error"], dict) - - def test_missing_key(self, client, setup_routes): - @setup_routes.route("/example2", methods=["POST"]) - @validate_params({"name": str, "age": int}) - def example2(): - return "OK", 200 - - response = client.post("/example2", json={"name": "John"}) - assert response.status_code == 400 - assert response.json["error"] == "Missing key: age" - assert "success" not in response.json - assert "code" not in response.json - assert not isinstance(response.json["error"], dict) - - def test_unexpected_key(self, client): - response = client.post("/example", json={"name": "John", "extra": "value"}) - assert response.status_code == 400 - assert response.json["error"] == "Unexpected key: extra." - assert "success" not in response.json - assert "code" not in response.json - assert not isinstance(response.json["error"], dict) - - def test_wrong_type(self, client): - response = client.post("/example", json={"name": 123}) - assert response.status_code == 400 - assert response.json["error"] == "Wrong type for key name." - assert "success" not in response.json - assert "code" not in response.json - assert not isinstance(response.json["error"], dict)