diff --git a/fair_async_rlock/fair_async_rlock.py b/fair_async_rlock/fair_async_rlock.py index 931ca5c..7a3845f 100644 --- a/fair_async_rlock/fair_async_rlock.py +++ b/fair_async_rlock/fair_async_rlock.py @@ -1,5 +1,10 @@ +from __future__ import annotations import asyncio from collections import deque +from typing import Optional, TYPE_CHECKING, Any + +if TYPE_CHECKING: + from types import TracebackType __all__ = [ 'FairAsyncRLock' @@ -11,21 +16,29 @@ class FairAsyncRLock: A fair reentrant lock for async programming. Fair means that it respects the order of acquisition. """ - def __init__(self): + __slots__ = ("_owner", "_count", "_owner_transfer", "_queue", "_loop") + + def __init__(self) -> None: self._owner: asyncio.Task | None = None self._count = 0 self._owner_transfer = False - self._queue = deque() - - def is_owner(self, task=None): - if task is None: - task = asyncio.current_task() - return self._owner == task + self._queue: deque[asyncio.Future[None]] = deque() + self._loop: Optional[asyncio.AbstractEventLoop] = None + + @property + def loop(self) -> asyncio.AbstractEventLoop: + if not self._loop: + self._loop = asyncio.get_event_loop() + return self._loop + + def is_owner(self, task:Optional[asyncio.Task[Any]] = None) -> bool: + return self._owner == (task or asyncio.current_task()) def locked(self) -> bool: + """determines if the lock is being currently held or not""" return self._owner is not None - async def acquire(self): + async def acquire(self) -> None: """Acquire the lock.""" me = asyncio.current_task() @@ -41,18 +54,18 @@ async def acquire(self): return # Create an event for this task, to notify when it's ready for acquire - event = asyncio.Event() - self._queue.append(event) + fut = self.loop.create_future() + self._queue.append(fut) # Wait for the lock to be free, then acquire try: - await event.wait() + await fut self._owner_transfer = False self._owner = me self._count = 1 except asyncio.CancelledError: try: # if in queue, then cancelled before release - self._queue.remove(event) + self._queue.remove(fut) except ValueError: # otherwise, release happened, this was next, and we simulate passing on self._owner_transfer = False self._owner = me @@ -60,18 +73,17 @@ async def acquire(self): self._current_task_release() raise - def _current_task_release(self): + def _current_task_release(self) -> None: self._count -= 1 if self._count == 0: self._owner = None if self._queue: # Wake up the next task in the queue - event = self._queue.popleft() - event.set() + self._queue.popleft().set_result(None) # Setting this here prevents another task getting lock until owner transfer. self._owner_transfer = True - def release(self): + def release(self) -> None: """Release the lock""" me = asyncio.current_task() @@ -87,5 +99,10 @@ async def __aenter__(self): await self.acquire() return self - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc:Optional[BaseException], + tb:Optional[TracebackType] + ) -> None: self.release() diff --git a/fair_async_rlock/py.typed b/fair_async_rlock/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/fair_async_rlock/tests/test_fair_async_rlock.py b/fair_async_rlock/tests/test_fair_async_rlock.py index 89a220c..5388756 100644 --- a/fair_async_rlock/tests/test_fair_async_rlock.py +++ b/fair_async_rlock/tests/test_fair_async_rlock.py @@ -653,7 +653,6 @@ async def task3(): await asyncio.gather(t1, t2, t3) -@pytest.mark.asyncio def test_locked(): lock = FairAsyncRLock() assert not lock.locked()