Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 70 additions & 17 deletions fastapi/dependencies/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,35 @@
from asyncio import iscoroutinefunction


def _unwrap_partial(obj: Any) -> Any:
while isinstance(obj, partial):
obj = obj.func
return obj


def _iter_wrapped_callables(obj: Any):
"""
Yield the callable and any objects in its ``__wrapped__`` chain.

This intentionally does *not* use ``inspect.unwrap()``, because that would only
return the innermost wrapped object, and FastAPI needs to understand whether
*any* layer is async/generator based (e.g. an async wrapper created with
``functools.wraps`` around a sync function).
"""
seen: set[int] = set()
while True:
obj = _unwrap_partial(obj)
obj_id = id(obj)
if obj_id in seen:
return
seen.add(obj_id)
yield obj
wrapped = getattr(obj, "__wrapped__", None)
if wrapped is None:
return
obj = wrapped


@dataclass
class SecurityRequirement:
security_scheme: SecurityBase
Expand Down Expand Up @@ -79,33 +108,57 @@ def _uses_scopes(self) -> bool:
def _unwrapped_call(self) -> Any:
if self.call is None:
return self.call # pragma: no cover
unwrapped = inspect.unwrap(self.call)
if isinstance(unwrapped, partial):
unwrapped = unwrapped.func
return unwrapped
# NOTE: We intentionally do not unwrap the ``__wrapped__`` chain (as created
# by ``functools.wraps``). Execution mode (sync vs async) must be detected
# based on the actual callable that will run (the outer wrapper), and in
# some cases also based on any wrapped functions (e.g. a sync wrapper
# returning a coroutine from an async wrapped function).
return _unwrap_partial(self.call)

@cached_property
def is_gen_callable(self) -> bool:
if inspect.isgeneratorfunction(self._unwrapped_call):
return True
dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)
# If there's any coroutine layer, it must be treated as async instead of a
# generator dependency.
if self.is_coroutine_callable:
return False
call = self._unwrapped_call
if inspect.isroutine(call):
return any(inspect.isgeneratorfunction(c) for c in _iter_wrapped_callables(call))
dunder_call = getattr(call, "__call__", None) # noqa: B004
if dunder_call is None:
return False
dunder_call = _unwrap_partial(dunder_call)
return any(inspect.isgeneratorfunction(c) for c in _iter_wrapped_callables(dunder_call))

@cached_property
def is_async_gen_callable(self) -> bool:
if inspect.isasyncgenfunction(self._unwrapped_call):
return True
dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)
# If there's any coroutine layer, it must be treated as async instead of an
# async-generator dependency.
if self.is_coroutine_callable:
return False
call = self._unwrapped_call
if inspect.isroutine(call):
return any(inspect.isasyncgenfunction(c) for c in _iter_wrapped_callables(call))
dunder_call = getattr(call, "__call__", None) # noqa: B004
if dunder_call is None:
return False
dunder_call = _unwrap_partial(dunder_call)
return any(inspect.isasyncgenfunction(c) for c in _iter_wrapped_callables(dunder_call))

@cached_property
def is_coroutine_callable(self) -> bool:
if inspect.isroutine(self._unwrapped_call):
return iscoroutinefunction(self._unwrapped_call)
if inspect.isclass(self._unwrapped_call):
# If either the original callable OR any wrapper in a ``__wrapped__`` chain
# is async, we should treat this callable as async.
call = self._unwrapped_call
if inspect.isroutine(call):
return any(iscoroutinefunction(c) for c in _iter_wrapped_callables(call))
if inspect.isclass(call):
return False
dunder_call = getattr(call, "__call__", None) # noqa: B004
if dunder_call is None:
return False
dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004
return iscoroutinefunction(dunder_call)
dunder_call = _unwrap_partial(dunder_call)
return any(iscoroutinefunction(c) for c in _iter_wrapped_callables(dunder_call))

@cached_property
def computed_scope(self) -> Union[str, None]:
Expand Down