Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions returns/primitives/reawaitable.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from collections.abc import Awaitable, Callable, Generator
from typing import NewType, TypeVar, cast, final
from functools import wraps
from typing import NewType, ParamSpec, TypeVar, cast, final

_ValueType = TypeVar('_ValueType')
_FunctionCoroType = TypeVar('_FunctionCoroType', bound=Callable[..., Awaitable])
_AwaitableT = TypeVar('_AwaitableT', bound=Awaitable)
_Ps = ParamSpec('_Ps')

_Sentinel = NewType('_Sentinel', object)
_sentinel: _Sentinel = cast(_Sentinel, object())
Expand Down Expand Up @@ -104,7 +106,9 @@ async def _awaitable(self) -> _ValueType:
return self._cache # type: ignore


def reawaitable(coro: _FunctionCoroType) -> _FunctionCoroType:
def reawaitable(
coro: Callable[_Ps, _AwaitableT],
) -> Callable[_Ps, _AwaitableT]:
"""
Allows to decorate coroutine functions to be awaitable multiple times.

Expand All @@ -124,6 +128,12 @@ def reawaitable(coro: _FunctionCoroType) -> _FunctionCoroType:
>>> assert anyio.run(main) == 3

"""
return lambda *args, **kwargs: ReAwaitable( # type: ignore
coro(*args, **kwargs),
)

@wraps(coro)
def decorator(
*args: _Ps.args,
**kwargs: _Ps.kwargs,
) -> _AwaitableT:
return ReAwaitable(coro(*args, **kwargs)) # type: ignore[return-value]

return decorator