diff --git a/returns/primitives/reawaitable.py b/returns/primitives/reawaitable.py index fba8505d3..4e87d4717 100644 --- a/returns/primitives/reawaitable.py +++ b/returns/primitives/reawaitable.py @@ -1,8 +1,10 @@ from collections.abc import Awaitable, Callable, Generator -from typing import NewType, TypeVar, cast, final +from functools import wraps +from typing import NewType, ParamSpec, TypeVar, cast, final _ValueType = TypeVar('_ValueType') -_FunctionCoroType = TypeVar('_FunctionCoroType', bound=Callable[..., Awaitable]) +_AwaitableT = TypeVar('_AwaitableT', bound=Awaitable) +_Ps = ParamSpec('_Ps') _Sentinel = NewType('_Sentinel', object) _sentinel: _Sentinel = cast(_Sentinel, object()) @@ -104,7 +106,9 @@ async def _awaitable(self) -> _ValueType: return self._cache # type: ignore -def reawaitable(coro: _FunctionCoroType) -> _FunctionCoroType: +def reawaitable( + coro: Callable[_Ps, _AwaitableT], +) -> Callable[_Ps, _AwaitableT]: """ Allows to decorate coroutine functions to be awaitable multiple times. @@ -124,6 +128,12 @@ def reawaitable(coro: _FunctionCoroType) -> _FunctionCoroType: >>> assert anyio.run(main) == 3 """ - return lambda *args, **kwargs: ReAwaitable( # type: ignore - coro(*args, **kwargs), - ) + + @wraps(coro) + def decorator( + *args: _Ps.args, + **kwargs: _Ps.kwargs, + ) -> _AwaitableT: + return ReAwaitable(coro(*args, **kwargs)) # type: ignore[return-value] + + return decorator