diff --git a/src/labthings_fastapi/exceptions.py b/src/labthings_fastapi/exceptions.py index 42a1709..0cbcf41 100644 --- a/src/labthings_fastapi/exceptions.py +++ b/src/labthings_fastapi/exceptions.py @@ -147,6 +147,17 @@ class NoBlobManagerError(RuntimeError): """ +class NoUrlForContextError(RuntimeError): + """Raised if URLFor is serialised without a url_for context variable being set. + + This usually indicates that URLFor is being serialised somewhere other than in + an HTTP response, + for example in test code or in a background task. In these cases, you should + set up the url_for context variable manually, for example using the + `.testing.use_dummy_url_for` context manager. + """ + + class UnsupportedConstraintError(ValueError): """A constraint argument is not supported. diff --git a/src/labthings_fastapi/middleware/__init__.py b/src/labthings_fastapi/middleware/__init__.py new file mode 100644 index 0000000..160876d --- /dev/null +++ b/src/labthings_fastapi/middleware/__init__.py @@ -0,0 +1 @@ +"""Middleware for use with LabThings.""" diff --git a/src/labthings_fastapi/middleware/url_for.py b/src/labthings_fastapi/middleware/url_for.py new file mode 100644 index 0000000..09ab2c7 --- /dev/null +++ b/src/labthings_fastapi/middleware/url_for.py @@ -0,0 +1,207 @@ +r"""Middleware to make url_for available as a context variable. + +This module is intended mostly for internal use within LabThings. The short +summary is that, if you need to refer to other endpoints in the LabThings +server, you should not return hard-coded URLs, but instead use a `URLFor` +object. This will be converted to a URL when it's serialised by FastAPI, using +the correct ``url_for`` function for the current request. + +Under the hood, this module defines a `url_for` function that performs the +conversion. This function may only be run in certain places in the code, as +it relies on a context variable. As a rule of thumb, it's OK to call +`url_for` from a serializer of a `pydantic` model, but you should not call +it from within an Action or Property. + +There are several places in LabThings where we need to be able to include URLs +to other endpoints in the LabThings server, most notably in the output of +Actions. For example, if an Action outputs a `.Blob`\ , the URL to download +that `.Blob` would need to be generated. + +Actions are particularly complicated, as they are often invoked by one HTTP +request, and polled by subsequent requests. In order to ensure that the URL +we generate is consistent with the URL being requested, we should always use +the ``url_for`` method from the HTTP request we are responding to. This means +it is, in general, not a great idea to generate URLs within an Action and hold +on to them as strings. While it will work most of the time, it would be better +to store the endpoint name, and only convert it to a URL when the action's +output is serialised by FastAPI. + +This module includes a `.ContextVar` for the ``url_for`` function, and provides +a middleware function that sets the context variable for every request, and a +custom type that works with `pydantic` to convert endpoint names to URLs at +serialisation time. +""" + +from collections.abc import Awaitable, Callable, Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any +from typing_extensions import Self +from fastapi import Request, Response +from pydantic import GetCoreSchemaHandler +from pydantic.networks import AnyUrl +from pydantic_core import core_schema +from starlette.datastructures import URL + +from labthings_fastapi.exceptions import NoUrlForContextError + +url_for_ctx: ContextVar[Callable[..., URL]] = ContextVar("url_for_ctx") +"""Context variable storing the url_for function for the current request.""" + + +@contextmanager +def set_url_for_context( + url_for_function: Callable[..., URL], +) -> Iterator[None]: + """Set the url_for context variable for the duration of the context. + + :param url_for_function: The url_for function to set in the context variable. + """ + token = url_for_ctx.set(url_for_function) + try: + yield + finally: + url_for_ctx.reset(token) + + +def dummy_url_for(endpoint: str, **params: Any) -> URL: + r"""Generate a fake URL as a placeholder for a real ``url_for`` function. + + This is intended for use in test code. + + :param endpoint: The name of the endpoint. + :param \**params: The path parameters. + :return: A fake URL. + """ + param_str = "&".join(f"{k}={v}" for k, v in params.items()) + return URL(f"urlfor://{endpoint}/?{param_str}") + + +def url_for(endpoint_name: str, **params: Any) -> URL: + r"""Get a URL for the given endpoint name and path parameters. + + This function uses the ``url_for`` function stored in a context variable + to convert endpoint names and parameters to URLs. It is intended to have + the same signature as `fastapi.Request.url_for`\ . + + This function will raise a `NoUrlForContextError` if there is no + ``url_for`` function in the context variable. This will be the case if + the function is called outside of a request handler. As a rule, this + function should not be called from within Actions or Properties. + + `URLFor` is provided as a safe way to return URLs: it ensures that the + URL is only generated at serialisation time, when there is a valid + ``url_for`` function in the context. This also means the URL is always + correct for the request being handled. + + :param endpoint_name: The name of the endpoint to generate a URL for. + :param \**params: The path parameters to use in the URL. + :return: The generated URL. + :raises NoUrlForContextError: if there is no url_for function in the context. + """ + try: + url_for_func = url_for_ctx.get() + except LookupError as err: + raise NoUrlForContextError("No url_for context available.") from err + return url_for_func(endpoint_name, **params) + + +async def url_for_middleware( + request: Request, call_next: Callable[[Request], Awaitable[Response]] +) -> Response: + """Middleware to set the url_for context variable for each request. + + This middleware retrieves the ``url_for`` function from the incoming + request, and sets it in the context variable for the duration of the + request. + + :param request: The incoming FastAPI request. + :param call_next: The next middleware or endpoint handler to call. + :return: The response from the next handler. + """ + token = url_for_ctx.set(request.url_for) + try: + response = await call_next(request) + finally: + url_for_ctx.reset(token) + return response + + +class URLFor: + """A pydantic-compatible type that converts endpoint names to URLs. + + This class is intended to be used as a field type in `pydantic` models + or as a return type from actions or properties. It does not convert + endpoint names to URLs immediately, but instead stores the endpoint name + and parameters, and only generates the URL when it is serialised by + FastAPI. + + It is safe to *create* a `URLFor` instance anywhere, but converting it + to a string (i.e. generating the URL) requires a valid `url_for` function + and should generally be left for FastAPI. + + Fields or return values annotated as `.URLFor` will only accept a `.URLFor` + instance, but will be serialised to JSON as a string, and will show up in + the JSONSchema as a string. + + Validating a string, i.e. converting a string to a `.URLFor` instance, is + not supported, and will raise a `TypeError`. + """ + + def __init__(self, endpoint_name: str, **params: Any) -> None: + r"""Create a URLFor instance. + + :param endpoint_name: The name of the endpoint to generate a URL for. + :param \**params: The path parameters to use in the URL. + """ + self.endpoint_name = endpoint_name + self.params = params + + def __str__(self) -> str: + """Convert the URLFor instance to a URL string. + + :return: The generated URL as a string. + """ + url = url_for(self.endpoint_name, **self.params) + return str(url) + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + """Get the pydantic core schema for the URLFor type. + + This magic method allows `pydantic` to serialise URLFor + instances, and generate a JSONSchema for them. Currently, + URLFor instances may not be validated from strings, and + attempting to do so will raise an error. + + The "core schema" we generate describes the field as a + string, and serialises it by calling ``str(obj)`` which in + turn calls our ``__str__`` method to generate the URL. + + :param source: The source type being converted. + :param handler: The pydantic core schema handler. + :return: The pydantic core schema for the URLFor type. + """ + return core_schema.no_info_wrap_validator_function( + cls._validate, + AnyUrl.__get_pydantic_core_schema__(AnyUrl, handler), + serialization=core_schema.to_string_ser_schema( # codespell:ignore ser + when_used="always" + ), + ) + + @classmethod + def _validate(cls, value: Any, handler: Callable[[Any], Self]) -> Self: + """Validate and convert a value to a URLFor instance. + + :param value: The value to validate. + :param handler: The handler to convert the value if needed. + :return: The validated URLFor instance. + :raises TypeError: if the value is not a URLFor instance. + """ + if isinstance(value, cls): + return value + else: + raise TypeError("URLFor instances may not be created from strings.") diff --git a/src/labthings_fastapi/testing.py b/src/labthings_fastapi/testing.py index 57f99bf..02f51ca 100644 --- a/src/labthings_fastapi/testing.py +++ b/src/labthings_fastapi/testing.py @@ -1,7 +1,9 @@ """Test harnesses to help with writitng tests for things..""" from __future__ import annotations +from collections.abc import Iterator from concurrent.futures import Future +from contextlib import contextmanager from typing import ( TYPE_CHECKING, Any, @@ -18,6 +20,7 @@ from .utilities import class_attributes from .thing_slots import ThingSlot from .thing_server_interface import ThingServerInterface +from .middleware.url_for import set_url_for_context, dummy_url_for if TYPE_CHECKING: from .thing import Thing @@ -217,3 +220,10 @@ def _mock_slots(thing: Thing) -> None: for _attr_name, attr in class_attributes(thing): if isinstance(attr, ThingSlot): attr.connect(thing, mocks, ...) + + +@contextmanager +def use_dummy_url_for() -> Iterator[None]: + """Use the dummy URL for function in the context variable.""" + with set_url_for_context(dummy_url_for): + yield diff --git a/tests/test_middleware_url_for.py b/tests/test_middleware_url_for.py new file mode 100644 index 0000000..e8f0dc8 --- /dev/null +++ b/tests/test_middleware_url_for.py @@ -0,0 +1,157 @@ +"""Test the URLFor class and associated supporting code.""" + +import threading +import pytest +from pydantic import BaseModel +from pydantic_core import PydanticSerializationError +from fastapi import FastAPI +from starlette.testclient import TestClient + +from labthings_fastapi.middleware import url_for +from labthings_fastapi.middleware.url_for import URLFor, url_for_middleware +from labthings_fastapi.testing import use_dummy_url_for +from labthings_fastapi.exceptions import NoUrlForContextError + + +class ModelWithURL(BaseModel): + """A model containing a URLFor field.""" + + url: URLFor + + +def test_url_for(): + """Check that the `url_for` function uses the context var as expected.""" + with pytest.raises(NoUrlForContextError): + url_for.url_for("my_endpoint", id=123) + with use_dummy_url_for(): + assert url_for.url_for("my_endpoint", id=123) == "urlfor://my_endpoint/?id=123" + + +def test_string_conversion(mocker): + """Test that URLFor can be converted to a string.""" + url_for_spy = mocker.spy(url_for, "url_for") + u = URLFor("my_endpoint", id=123) + with pytest.raises(NoUrlForContextError): + _ = str(u) + with use_dummy_url_for(): + assert str(u) == "urlfor://my_endpoint/?id=123" + assert url_for_spy.call_count == 2 + + +def test_serialisation(mocker): + """Test that URLFor is serialised by calling str() on it.""" + u = URLFor("my_endpoint", id=123) + m = ModelWithURL(url=u) + + # Check that serialisation fails without a url_for context + # and that it tries to call `url_for` + with pytest.raises(NoUrlForContextError) as excinfo: + _ = m.model_dump() + assert "url_for" in [frame.name for frame in excinfo.traceback] + with pytest.raises(PydanticSerializationError, match="NoUrlForContextError"): + _ = m.model_dump_json() + with use_dummy_url_for(): + assert m.model_dump()["url"] == "urlfor://my_endpoint/?id=123" + + +def test_validation(): + """Test that URLFor validation works as expected.""" + # URLFor is a custom type, so the initialiser works normally + u = URLFor("my_endpoint", id=123) + + # Initialising with an instance should leave it unchanged + m = ModelWithURL(url=u) + assert m.url is u + + # Trying to initialise with anything else should raise an error + with pytest.raises(TypeError): + _ = ModelWithURL(url="https://example.com") + with pytest.raises(TypeError): + _ = ModelWithURL(url="endpoint_name") + with pytest.raises(TypeError): + _ = ModelWithURL(url=None) + + +def test_middleware(): + """Check the middleware function works as expected.""" + app = FastAPI() + app.middleware("http")(url_for_middleware) + + class Model(BaseModel): + url: URLFor + + @app.get("/test-endpoint/{item_id}/", name="test-endpoint") + async def test_endpoint(item_id: int) -> URLFor: + """An async endpoint that returns a URLFor instance.""" + return URLFor("test-endpoint", item_id=item_id) + + @app.get("/sync-endpoint/{item_id}/") + def sync_endpoint(item_id: int) -> URLFor: + """A sync endpoint that returns a URLFor instance.""" + return URLFor("test-endpoint", item_id=item_id) + + @app.get("/model-endpoint/{item_id}/") + async def model_endpoint(item_id: int) -> Model: + """An async endpoint that returns a model containing a URLFor.""" + return Model(url=URLFor("test-endpoint", item_id=item_id)) + + @app.get("/direct-async-endpoint/{item_id}/") + async def direct_async_endpoint(item_id: int) -> str: + """An async endpoint that calls `url_for` directly.""" + return str(url_for.url_for("test-endpoint", item_id=item_id)) + + @app.get("/direct_sync-endpoint/{item_id}/") + def direct_sync_endpoint(item_id: int) -> str: + """A sync endpoint that calls `url_for` directly.""" + return str(url_for.url_for("test-endpoint", item_id=item_id)) + + def assert_url_for_fails(item_id: int): + with pytest.raises(NoUrlForContextError): + _ = url_for.url_for("test-endpoint", item_id=item_id) + + def append_from_thread(item_id: int, output: list) -> None: + output.append(URLFor("test-endpoint", item_id=item_id)) + + @app.get("/assert_fails_in_thread/{item_id}/") + async def assert_fails_in_thread(item_id: int) -> bool: + t = threading.Thread(target=assert_url_for_fails, args=(item_id,)) + t.start() + t.join() + return True + + @app.get("/return_from_thread/{item_id}/") + async def return_from_thread(item_id: int) -> URLFor: + output = [] + append_from_thread(item_id, output) + return output[0] + + URL = "http://testserver/test-endpoint/42/" + + with TestClient(app) as client: + response = client.get("/test-endpoint/42/") + assert response.status_code == 200 + assert response.json() == URL + + response = client.get("/sync-endpoint/42/") + assert response.status_code == 200 + assert response.json() == URL + + response = client.get("/model-endpoint/42/") + assert response.status_code == 200 + assert response.json() == {"url": URL} + + response = client.get("/direct-async-endpoint/42/") + assert response.status_code == 200 + assert response.json() == URL + + response = client.get("/direct_sync-endpoint/42/") + assert response.status_code == 200 + assert response.json() == URL + + response = client.get("/assert_fails_in_thread/42/") + assert response.status_code == 200 + assert response.json() is True + + response = client.get("/return_from_thread/42/") + assert response.status_code == 200 + assert response.json() == URL