diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 809227a..ccadae3 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11", "3.12" ] + python-version: [ "3.10", "3.11", "3.12", "3.13" ] steps: - uses: actions/checkout@v3 @@ -27,9 +27,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest - pip install -r requirements-tests.txt - pip install . + pip install .[anyio,tests] - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/README.md b/README.md index f953029..5b553c2 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,18 @@ because [python decided not to support RLock in asyncio](https://discuss.python. their [argument](https://discuss.python.org/t/asyncio-rlock-reentrant-locks-for-async-python/21509/2) being that every extra bit of functionality adds to maintenance cost. -Install with +Install normally for asyncio support: ```bash -pip install fair-async-rlock +pip install fair-async-rlock ``` +or with AnyIO support: + +```bash +pip install fair-async-rlock[anyio] +``` + ## About Fair Reentrant Lock for AsyncIO A reentrant lock (or recursive lock) is a particular type of lock that can be "locked" multiple times by the same task @@ -111,4 +117,5 @@ with `asyncio.Lock`. ### Change Log 27 Jan, 2024 - 1.0.7 released. Fixed a bug that allowed another task to get the lock before a waiter got its turn on the -event loop. \ No newline at end of file +event loop. +17 Mar, 2025 - 2.0.0 released. Added optional AnyIO support; removed support for Python < 3.10. 
\ No newline at end of file diff --git a/fair_async_rlock/__init__.py b/fair_async_rlock/__init__.py deleted file mode 100644 index dee7f22..0000000 --- a/fair_async_rlock/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from fair_async_rlock.fair_async_rlock import * diff --git a/pyproject.toml b/pyproject.toml index b5a3c46..4ff2cbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,29 @@ +# pyproject.toml + [build-system] -requires = [ - "setuptools>=42", - "wheel" +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "fair_async_rlock" +version = "2.0.0" +description = "A fair async RLock for Python" +readme = "README.md" +requires-python = ">=3.10" +license = { text = "Apache Software License" } +authors = [{ name = "Joshua G. Albert", email = "albert@strw.leidenuniv.nl" }] +keywords = ["async", "fair", "reentrant", "lock", "concurrency"] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent" ] -build-backend = "setuptools.build_meta" \ No newline at end of file +urls = { "Homepage" = "https://github.com/joshuaalbert/FairAsyncRLock" } +dynamic = ["dependencies", "optional-dependencies"] + +[tool.setuptools] +include-package-data = true + + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/requirements-anyio.txt b/requirements-anyio.txt new file mode 100644 index 0000000..4575a89 --- /dev/null +++ b/requirements-anyio.txt @@ -0,0 +1 @@ +anyio>=4.5 \ No newline at end of file diff --git a/requirements-tests.txt b/requirements-tests.txt index 2a32ed4..a727deb 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -1,2 +1,5 @@ -pytest<8.0.0 -pytest-asyncio \ No newline at end of file +flake8 +pytest +pytest-asyncio +trio==0.25.* +anyio>=4.5 \ No newline at end of file diff --git a/fair_async_rlock/tests/__init__.py b/requirements.txt similarity index 100% rename from 
fair_async_rlock/tests/__init__.py rename to requirements.txt diff --git a/setup.py b/setup.py index 76192c3..0fd064a 100755 --- a/setup.py +++ b/setup.py @@ -1,31 +1,19 @@ #!/usr/bin/env python -from setuptools import find_packages from setuptools import setup -with open("README.md", "r") as fh: - long_description = fh.read() +install_requires = [] -setup(name='fair_async_rlock', - version='1.0.7', - description='A well-tested implementation of a fair asynchronous RLock for concurrent programming.', - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/Joshuaalbert/FairAsyncRLock", - author='Joshua G. Albert', - author_email='albert@strw.leidenuniv.nl', - setup_requires=[], - install_requires=[], - tests_require=[ - 'pytest', - 'pytest-asyncio' - ], - package_dir={'': './'}, - packages=find_packages('./'), - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - ], - python_requires='>=3.7', - ) + +def load_requirements(file_name): + with open(file_name, "r") as file: + return [line.strip() for line in file if line.strip() and not line.startswith("#")] + + +setup( + install_requires=load_requirements("requirements.txt"), + extras_require={ + "tests": load_requirements("requirements-tests.txt"), + 'anyio': load_requirements("requirements-anyio.txt") + } +) diff --git a/src/fair_async_rlock/__init__.py b/src/fair_async_rlock/__init__.py new file mode 100644 index 0000000..ab47c83 --- /dev/null +++ b/src/fair_async_rlock/__init__.py @@ -0,0 +1,6 @@ +from fair_async_rlock.asyncio_fair_async_rlock import * + +try: + from fair_async_rlock.anyio_fair_async_rlock import * +except ImportError: + pass diff --git a/src/fair_async_rlock/anyio_fair_async_rlock.py b/src/fair_async_rlock/anyio_fair_async_rlock.py new file mode 100644 index 0000000..48686f2 --- /dev/null +++ b/src/fair_async_rlock/anyio_fair_async_rlock.py @@ 
-0,0 +1,29 @@ +from typing import TypeVar + +import anyio + +from fair_async_rlock.base_fair_async_rlock import BaseFairAsyncRLock + +__all__ = [ + 'AnyIOFairAsyncRLock' +] +TaskType = TypeVar('TaskType') +EventType = TypeVar('EventType') + + +class AnyIOFairAsyncRLock(BaseFairAsyncRLock[anyio.TaskInfo, anyio.Event]): + """ + A fair reentrant lock for async programming. Fair means that it respects the order of acquisition. + """ + + def _get_current_task(self) -> anyio.TaskInfo: + return anyio.get_current_task() + + def _get_cancelled_exc_class(self) -> type[BaseException]: + return anyio.get_cancelled_exc_class() + + def _get_wake_event(self) -> anyio.Event: + return anyio.Event() + + async def _checkpoint(self) -> None: + await anyio.lowlevel.checkpoint() diff --git a/src/fair_async_rlock/asyncio_fair_async_rlock.py b/src/fair_async_rlock/asyncio_fair_async_rlock.py new file mode 100644 index 0000000..0de9aa3 --- /dev/null +++ b/src/fair_async_rlock/asyncio_fair_async_rlock.py @@ -0,0 +1,28 @@ +import asyncio +from typing import Any, TypeVar + +from fair_async_rlock.base_fair_async_rlock import BaseFairAsyncRLock + +__all__ = [ + 'FairAsyncRLock' +] +TaskType = TypeVar('TaskType') +EventType = TypeVar('EventType') + + +class FairAsyncRLock(BaseFairAsyncRLock[asyncio.Task[Any], asyncio.Event]): + """ + A fair reentrant lock for async programming. Fair means that it respects the order of acquisition. 
+ """ + + def _get_current_task(self) -> asyncio.Task[Any] | None: + return asyncio.current_task() + + def _get_cancelled_exc_class(self) -> type[BaseException]: + return asyncio.CancelledError + + def _get_wake_event(self) -> asyncio.Event: + return asyncio.Event() + + async def _checkpoint(self) -> None: + await asyncio.sleep(0) diff --git a/fair_async_rlock/fair_async_rlock.py b/src/fair_async_rlock/base_fair_async_rlock.py similarity index 53% rename from fair_async_rlock/fair_async_rlock.py rename to src/fair_async_rlock/base_fair_async_rlock.py index 931ca5c..300d15e 100644 --- a/fair_async_rlock/fair_async_rlock.py +++ b/src/fair_async_rlock/base_fair_async_rlock.py @@ -1,47 +1,85 @@ -import asyncio +from abc import ABC, abstractmethod from collections import deque -__all__ = [ - 'FairAsyncRLock' -] +from typing import Any, Generic, TypeVar, Protocol +class _Event(Protocol): + def set(self) -> None: ... + async def wait(self) -> Any: ... -class FairAsyncRLock: + +TaskType = TypeVar("TaskType") +EventType = TypeVar("EventType", bound=_Event) + + +class AbstractFairAsyncRLock(ABC, Generic[TaskType, EventType]): + @abstractmethod + def _get_current_task(self) -> TaskType | None: + ... + + @abstractmethod + def _get_cancelled_exc_class(self) -> type[BaseException]: + ... + + @abstractmethod + def _get_wake_event(self) -> EventType: + ... + + @abstractmethod + async def _checkpoint(self) -> None: + ... + + +class BaseFairAsyncRLock(AbstractFairAsyncRLock[TaskType, EventType]): """ A fair reentrant lock for async programming. Fair means that it respects the order of acquisition. 
""" def __init__(self): - self._owner: asyncio.Task | None = None - self._count = 0 - self._owner_transfer = False - self._queue = deque() + self._owner: TaskType | None = None + self._count: int = 0 + self._owner_transfer: bool = False + self._queue: deque[EventType] = deque() - def is_owner(self, task=None): + def is_owner(self, task: TaskType | None = None) -> bool: if task is None: - task = asyncio.current_task() + task = self._get_current_task() return self._owner == task def locked(self) -> bool: return self._owner is not None - async def acquire(self): + async def acquire(self) -> None: """Acquire the lock.""" - me = asyncio.current_task() + me = self._get_current_task() # If the lock is reentrant, acquire it immediately if self.is_owner(task=me): self._count += 1 + try: + await self._checkpoint() + except self._get_cancelled_exc_class(): + # Cancelled, while reentrant, so release the lock + self._owner_transfer = False + self._owner = me + self._count = 1 + self._current_task_release() + raise return # If the lock is free (and ownership not in midst of transfer), acquire it immediately if self._count == 0 and not self._owner_transfer: self._owner = me self._count = 1 + try: + await self._checkpoint() + except self._get_cancelled_exc_class(): + self._current_task_release() + raise return # Create an event for this task, to notify when it's ready for acquire - event = asyncio.Event() + event = self._get_wake_event() self._queue.append(event) # Wait for the lock to be free, then acquire @@ -50,17 +88,18 @@ async def acquire(self): self._owner_transfer = False self._owner = me self._count = 1 - except asyncio.CancelledError: + except self._get_cancelled_exc_class(): try: # if in queue, then cancelled before release self._queue.remove(event) - except ValueError: # otherwise, release happened, this was next, and we simulate passing on + except ValueError: + # otherwise, release happened, this was next, and we simulate passing on self._owner_transfer = False 
self._owner = me self._count = 1 self._current_task_release() raise - def _current_task_release(self): + def _current_task_release(self) -> None: self._count -= 1 if self._count == 0: self._owner = None @@ -73,7 +112,7 @@ def _current_task_release(self): def release(self): """Release the lock""" - me = asyncio.current_task() + me = self._get_current_task() if self._owner is None: raise RuntimeError(f"Cannot release un-acquired lock. {me} tried to release.") diff --git a/src/fair_async_rlock/py.typed b/src/fair_async_rlock/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/fair_async_rlock/tests/__init__.py b/src/fair_async_rlock/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fair_async_rlock/tests/test_anyio_fair_async_rlock.py b/src/fair_async_rlock/tests/test_anyio_fair_async_rlock.py new file mode 100644 index 0000000..378dee9 --- /dev/null +++ b/src/fair_async_rlock/tests/test_anyio_fair_async_rlock.py @@ -0,0 +1,496 @@ +import asyncio +import multiprocessing +import random +from contextlib import suppress +from functools import wraps +from typing import Any, Awaitable, Callable, NoReturn, Union + +import anyio +import anyio.lowlevel +import pytest + +from fair_async_rlock import AnyIOFairAsyncRLock + +pytestmark: pytest.MarkDecorator = pytest.mark.anyio + +CoRo = Callable[..., Awaitable[Any]] + +SMALL_DELAY = 0.04 # Just enough for python to reliably execute a few lines of code + + +class DummyError(Exception): + pass + + +def with_timeout(t: float) -> Callable[[CoRo], CoRo]: + def wrapper(corofunc: CoRo) -> CoRo: + @wraps(corofunc) + async def run(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + with anyio.move_on_after(t) as scope: + await corofunc(*args, **kwargs) + if scope.cancelled_caught: + pytest.fail("Test timeout.") + + return run + + return wrapper + + +def repeat(n: int) -> Callable[[CoRo], CoRo]: + def wrapper(corofunc: CoRo) -> CoRo: + @wraps(corofunc) + async def run(*args: Any, **kwargs: 
Any) -> Any: # noqa: ANN401 + for _ in range(n): + await corofunc(*args, **kwargs) + + return run + + return wrapper + + +@repeat(10) +@with_timeout(1) +async def test_reentrant() -> None: + """Test that the lock can be acquired multiple times by the same task.""" + lock = AnyIOFairAsyncRLock() + async with lock and lock: + assert True + + +@repeat(10) +@with_timeout(1) +async def test_exclusion() -> None: + """Test that the lock prevents multiple tasks from acquiring it at the same time.""" + lock = AnyIOFairAsyncRLock() + got_in = anyio.Event() + + async def inner() -> None: + async with lock: + got_in.set() # Never reached: Shouldn't happen + + async with lock, anyio.create_task_group() as tg: + tg.start_soon(inner) + await anyio.sleep(SMALL_DELAY) + assert not got_in.is_set() + tg.cancel_scope.cancel() + + +@repeat(10) +@with_timeout(1) +async def test_fairness() -> None: + """Test that the lock is acquired in the order it is requested.""" + lock = AnyIOFairAsyncRLock() + order: list[int] = [] + + async def worker(n: int) -> None: + async with lock: + await anyio.sleep(SMALL_DELAY) + order.append(n) + + async with anyio.create_task_group() as tg: + for i in range(5): + tg.start_soon(worker, i) + await anyio.lowlevel.checkpoint() + + assert order == list(range(5)) + assert not lock.locked() + + +@repeat(10) +@with_timeout(1) +async def test_unowned_release() -> None: + """Test that releasing an un-acquired lock is handled gracefully.""" + async with anyio.create_task_group() as tg: + lock = AnyIOFairAsyncRLock() + + with pytest.raises(RuntimeError, match="Cannot release un-acquired lock."): + lock.release() + + async def worker() -> None: + with pytest.raises(RuntimeError, match="Cannot release un-acquired lock."): + lock.release() + + tg.start_soon(worker) + + +@with_timeout(1) +async def test_stress_1() -> None: + """Test that the lock can be acquired and released by multiple tasks rapidly.""" + lock = AnyIOFairAsyncRLock() + num_tasks = 100 + iterations = 100 
+ + async def worker() -> None: + for _ in range(iterations): + async with lock: + pass + + async with anyio.create_task_group() as tg: + for _ in range(num_tasks): + tg.start_soon(worker) + + assert not lock.locked() + + +@with_timeout(1) +async def test_stress_2() -> None: + """Test that the lock can be acquired and released by multiple tasks rapidly.""" + lock = AnyIOFairAsyncRLock() + num_tasks = 100 + + alive_tasks: int = 0 + async with anyio.create_task_group() as tg: + + async def worker() -> None: + nonlocal alive_tasks + alive_tasks += 1 + with anyio.CancelScope() as scope: + while not scope.cancel_called: + async with lock: + n: int = random.randint(0, 2) # noqa: S311 + if n == 0: # Create a new task 1/3 times. + tg.start_soon(worker) + else: # Cancel a task 2/3 times. + scope.cancel() + alive_tasks -= 1 + + for _ in range(num_tasks): + tg.start_soon(worker) + + assert alive_tasks == 0 + assert not lock.locked() + + +@repeat(10) +@with_timeout(1) +async def test_lock_status_checks() -> None: + """Test that the lock status checks work as expected.""" + lock = AnyIOFairAsyncRLock() + assert not lock.is_owner() + async with lock: + assert lock.is_owner() + + +@repeat(10) +@with_timeout(1) +async def test_nested_lock_acquisition() -> None: + """Test that lock ownership is correctly tracked.""" + lock1 = AnyIOFairAsyncRLock() + lock2 = AnyIOFairAsyncRLock() + + lock1_acquired = anyio.Event() + worker_task: Union[anyio.TaskInfo, None] = None + + async def worker() -> None: + nonlocal worker_task + worker_task = anyio.get_current_task() + async with lock1: + lock1_acquired.set() + await anyio.sleep(SMALL_DELAY) + + async def control_task() -> None: + nonlocal worker_task + async with anyio.create_task_group() as tg: + tg.start_soon(worker) + await lock1_acquired.wait() + assert lock1.is_owner(task=worker_task) + assert not lock2.is_owner() + assert worker_task != anyio.get_current_task() + async with lock2: + assert lock1.is_owner(task=worker_task) + assert 
lock2.is_owner() + + await control_task() + + +@repeat(10) +@with_timeout(1) +async def test_lock_released_on_exception() -> None: + """Test that the lock is released when an exception is raised.""" + lock = AnyIOFairAsyncRLock() + with suppress(Exception): + async with lock: + raise DummyError + + assert not lock.locked() + + +@repeat(10) +@with_timeout(1) +async def test_release_foreign_lock() -> None: + """Test that releasing a lock acquired by another task is handled gracefully.""" + lock = AnyIOFairAsyncRLock() + lock_acquired = anyio.Event() + + async def task1() -> None: + async with lock: + lock_acquired.set() + await anyio.sleep(SMALL_DELAY) + + async def task2() -> None: + await lock_acquired.wait() + with pytest.raises(RuntimeError, match="Cannot release foreign lock."): + lock.release() + + async with anyio.create_task_group() as tg: + tg.start_soon(task1) + await lock_acquired.wait() + tg.start_soon(task2) + + assert not lock.locked() + + +@repeat(10) +@with_timeout(1) +async def test_acquire_exception_handling() -> None: + """Test that if an exception is raised by current owner during lock acquisition, the lock is still handed over.""" + lock = AnyIOFairAsyncRLock() + lock_acquired = anyio.Event() + success_flag = anyio.Event() + + async def failing_task() -> NoReturn: + try: + await lock.acquire() + lock_acquired.set() + await anyio.sleep(SMALL_DELAY) + raise DummyError # noqa: TRY301 + except DummyError: + lock.release() + + async def succeeding_task() -> None: + await lock.acquire() + success_flag.set() + lock.release() + + async with anyio.create_task_group() as tg: + tg.start_soon(failing_task) + await lock_acquired.wait() + tg.start_soon(succeeding_task) + + assert not lock.locked() + assert success_flag.is_set() + + +@repeat(10) +@with_timeout(1) +async def test_lock_cancellation_during_acquisition() -> None: + """Test that if cancellation is raised during lock acquisition, the lock is not acquired.""" + lock = AnyIOFairAsyncRLock() + t1_ac = 
anyio.Event() + t2_ac = anyio.Event() + t2_started = anyio.Event() + + async def task1() -> None: + async with lock: + t1_ac.set() + await anyio.sleep(100) + + async def task2() -> None: + await t1_ac.wait() + t2_started.set() + async with lock: + t2_ac.set() # Never reached: Shouldn't happen + + async with anyio.create_task_group() as tg: + tg.start_soon(task1) + tg.start_soon(task2) + await t2_started.wait() + tg.cancel_scope.cancel() + + assert t2_started.is_set() + assert not t2_ac.is_set() + assert not lock.locked() + + +@repeat(10) +@with_timeout(1) +async def test_lock_cancellation_after_acquisition() -> None: + """Test that if cancellation is raised after lock acquisition, the lock is still released.""" + lock = AnyIOFairAsyncRLock() + lock_acquired = anyio.Event() + cancellation_event = anyio.Event() + + async def task_to_cancel() -> None: + async with lock: + lock_acquired.set() + try: + await anyio.sleep(SMALL_DELAY) + except anyio.get_cancelled_exc_class(): + cancellation_event.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(task_to_cancel) + await lock_acquired.wait() + tg.cancel_scope.cancel() + + await cancellation_event.wait() + + assert not lock.locked() + + +def test_non_cooperative_cancel(): + def run_coroutine(q): + async def run_test(): + lock = AnyIOFairAsyncRLock() + + async def while_loop(): + while True: + await lock.acquire() + # await anyio.sleep(0.) # Uncommenting makes the task cooperative + + async with anyio.create_task_group() as tg: + tg.start_soon(while_loop) + await anyio.lowlevel.checkpoint() + tg.cancel_scope.cancel() + assert lock._count == 0 # The lock should be released after the task is cancelled + + try: + asyncio.run(run_test()) + except BaseException as e: + q.put(e) + + # Because cancellation is cooperative if we try to use wait_for to test for timeout it will just run forever, + # because control never yielded. await keyword does not yield control. 
+ # Thus, we must wrap into a process to test for hanging. + q = multiprocessing.Queue() + proc = multiprocessing.Process(target=run_coroutine, args=(q,)) + proc.start() + # Wait for the process to finish, with a timeout. + proc.join(timeout=2) + + # If the process is still alive, it means the coroutine did not yield and cancel. + try: + assert not proc.is_alive(), "Test did not terminate as expected." + if not q.empty(): + raise q.get() + finally: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=1) + + +def test_non_cooperative_cancel_reentrant(): + def run_coroutine(q): + async def run_test(): + event = anyio.Event() + lock = AnyIOFairAsyncRLock() + + async def while_loop(): + idx = 0 + while True: + await lock.acquire() + # await anyio.sleep(0.) # Uncommenting makes the task cooperative + idx += 1 + if idx == 2: + event.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(while_loop) + await anyio.lowlevel.checkpoint() + await event.wait() + tg.cancel_scope.cancel() + assert lock._owner == None + assert lock._count == 0 # The lock should be released after the task is cancelled + + try: + asyncio.run(run_test()) + except BaseException as e: + q.put(e) + + # Because cancellation is cooperative if we try to use wait_for to test for timeout it will just run forever, + # because control never yielded. await keyword does not yield control. + # Thus, we must wrap into a process to test for hanging. + q = multiprocessing.Queue() + proc = multiprocessing.Process(target=run_coroutine, args=(q,)) + proc.start() + # Wait for the process to finish, with a timeout. + proc.join(timeout=2) + + # If the process is still alive, it means the coroutine did not yield and cancel. + try: + assert not proc.is_alive(), "Test did not terminate as expected." 
+ if not q.empty(): + raise q.get() + finally: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=1) + + +def test_non_cooperative_cancel_reentrant_nested(): + num_acquires = 2 + + def run_coroutine(q): + async def run_test(): + lock = AnyIOFairAsyncRLock() + + async def while_loop(num_children): + idx = 0 + while True: + if idx < num_acquires: + await lock.acquire() + # await asyncio.sleep(0) # Uncommenting makes the task cooperative + idx += 1 + continue + else: + if num_children > 0: + async with anyio.create_task_group() as tg: + tg.start_soon(while_loop, num_children - 1) + await anyio.lowlevel.checkpoint() + + async with anyio.create_task_group() as tg: + tg.start_soon(while_loop, 3) + await asyncio.sleep(1) # Give the task a chance to run + tg.cancel_scope.cancel() + + assert lock._owner == None + assert lock._count == 0 # The lock should be released after the task is cancelled + + try: + asyncio.run(run_test()) + except BaseException as e: + q.put(e) + + # Because cancellation is cooperative if we try to use wait_for to test for timeout it will just run forever, + # because control never yielded. await keyword does not yield control. + # Thus, we must wrap into a process to test for hanging. + q = multiprocessing.Queue() + proc = multiprocessing.Process(target=run_coroutine, args=(q,)) + proc.start() + # Wait for the process to finish, with a timeout. + proc.join(timeout=2) + + # If the process is still alive, it means the coroutine did not yield and cancel. + try: + assert not proc.is_alive(), "Test did not terminate as expected." 
+ if not q.empty(): + raise q.get() + finally: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=1) + + +@pytest.mark.anyio +async def test_anyio_checkpoints(): + lock = AnyIOFairAsyncRLock() + + async def acquirer(): + async with lock: + pass + + async def neighbor(): + if not lock.locked(): + await anyio.sleep(0) + + assert lock.locked() + + # check for scheduling + async with anyio.create_task_group() as tg: + tg.start_soon(acquirer) + tg.start_soon(neighbor) + + # check for cancellation + with anyio.move_on_after(0): + with pytest.raises(anyio.get_cancelled_exc_class()): + await acquirer() diff --git a/fair_async_rlock/tests/test_fair_async_rlock.py b/src/fair_async_rlock/tests/test_fair_async_rlock.py similarity index 72% rename from fair_async_rlock/tests/test_fair_async_rlock.py rename to src/fair_async_rlock/tests/test_fair_async_rlock.py index 89a220c..7dbe225 100644 --- a/fair_async_rlock/tests/test_fair_async_rlock.py +++ b/src/fair_async_rlock/tests/test_fair_async_rlock.py @@ -1,4 +1,5 @@ import asyncio +import multiprocessing import random from time import monotonic_ns, perf_counter @@ -157,11 +158,13 @@ async def test_nested_lock_acquisition(): lock2 = FairAsyncRLock() lock1_acquired = asyncio.Event() + lock2_acquired = asyncio.Event() async def worker(): async with lock1: lock1_acquired.set() # Signal that lock1 has been acquired await asyncio.sleep(0) # Yield control while holding lock1 + await lock2_acquired.wait() # At this point, lock1 is released async def control_task(): @@ -170,6 +173,7 @@ async def control_task(): assert lock1.is_owner(task=task) # worker task should own lock1 async with lock2: # Acquire lock2 assert lock1.is_owner(task=task) # worker task should still own lock1 + lock2_acquired.set() # Signal that lock2 has been acquired await task # Await completion of worker task after lock2 is released await control_task() @@ -440,19 +444,20 @@ async def task_to_cancel(): @pytest.mark.asyncio async def 
test_lock_cancellation_after_acquisition(): lock = FairAsyncRLock() - cancellation_event = asyncio.Event() + acquire_event = asyncio.Event() async def task_to_cancel(): async with lock: # acquire the lock - try: - await asyncio.sleep(1) # simulate some work - except asyncio.CancelledError: - cancellation_event.set() + await asyncio.sleep(0) # simulate some work + acquire_event.set() + await asyncio.sleep(10) # hold the lock for a while task = asyncio.create_task(task_to_cancel()) await asyncio.sleep(0) # yield control to let the task start + await acquire_event.wait() task.cancel() - await cancellation_event.wait() # wait for the task to handle the cancellation + with pytest.raises(asyncio.CancelledError): + await task assert lock._owner is None # lock should not be owned by any task @@ -653,12 +658,261 @@ async def task3(): await asyncio.gather(t1, t2, t3) + @pytest.mark.asyncio def test_locked(): lock = FairAsyncRLock() assert not lock.locked() + async def task(): async with lock: assert lock.locked() + asyncio.run(task()) - assert not lock.locked() \ No newline at end of file + assert not lock.locked() + + +def test_non_cooperative_cancel(): + def run_coroutine(q): + async def run_test(): + lock = FairAsyncRLock() + + async def while_loop(): + idx = 0 + while True: + await lock.acquire() + # await asyncio.sleep(0) # Uncommenting makes the task + + t = asyncio.create_task(while_loop()) + await asyncio.sleep(0) # Give the task a chance to run + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t # Cancellation should raise CancelledError if cooperative + assert lock._count == 0 # The lock should be released after the task is cancelled + + try: + asyncio.run(run_test()) + except BaseException as e: + q.put(e) + + # Because cancellation is cooperative if we try to use wait_for to test for timeout it will just run forever, + # because control never yielded. await keyword does not yield control. + # Thus, we must wrap into a process to test for hanging. 
+ q = multiprocessing.Queue() + proc = multiprocessing.Process(target=run_coroutine, args=(q,)) + proc.start() + # Wait for the process to finish, with a timeout. + proc.join(timeout=2) + + # If the process is still alive, it means the coroutine did not yield and cancel. + try: + assert not proc.is_alive(), "Test did not terminate as expected." + if not q.empty(): + raise q.get() + finally: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=1) + + +def test_non_cooperative_cancel_reentrant(): + def run_coroutine(q): + async def run_test(): + event = asyncio.Event() + lock = FairAsyncRLock() + + async def while_loop(): + idx = 0 + while True: + await lock.acquire() + # await asyncio.sleep(0) # Uncommenting makes the task cooperative + idx += 1 + if idx >= 2: + event.set() + + t = asyncio.create_task(while_loop()) + await asyncio.sleep(0) # Give the task a chance to run + await event.wait() # Wait for the task to acquire the lock twice + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t # Cancellation should raise CancelledError if cooperative + assert lock._owner == None + assert lock._count == 0 # The lock should be released after the task is cancelled + + try: + asyncio.run(run_test()) + except BaseException as e: + q.put(e) + + # Because cancellation is cooperative if we try to use wait_for to test for timeout it will just run forever, + # because control never yielded. await keyword does not yield control. + # Thus, we must wrap into a process to test for hanging. + q = multiprocessing.Queue() + proc = multiprocessing.Process(target=run_coroutine, args=(q,)) + proc.start() + # Wait for the process to finish, with a timeout. + proc.join(timeout=2) + + # If the process is still alive, it means the coroutine did not yield and cancel. + try: + assert not proc.is_alive(), "Test did not terminate as expected." 
+ if not q.empty(): + raise q.get() + finally: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=1) + + +def test_non_cooperative_cancel_reentrant_nested(): + num_acquires = 2 + + def run_coroutine(q): + async def run_test(): + lock = FairAsyncRLock() + + async def while_loop(num_children): + idx = 0 + while True: + if idx < num_acquires: + await lock.acquire() + # await asyncio.sleep(0) # Uncommenting makes the task cooperative + idx += 1 + continue + else: + if num_children > 0: + task = asyncio.create_task(while_loop(num_children - 1)) + await task + + + t = asyncio.create_task(while_loop(num_children=3)) + await asyncio.sleep(1) # Give the task a chance to run + + t.cancel() + with pytest.raises(asyncio.CancelledError): + # Non-cooperative cancellation would timeout, and cancel never happens. + await t # Cancellation should raise CancelledError if cooperative + + assert lock._owner == None + assert lock._count == 0 # The lock should be released after the task is cancelled + + try: + asyncio.run(run_test()) + except BaseException as e: + q.put(e) + + # Because cancellation is cooperative if we try to use wait_for to test for timeout it will just run forever, + # because control never yielded. await keyword does not yield control. + # Thus, we must wrap into a process to test for hanging. + q = multiprocessing.Queue() + proc = multiprocessing.Process(target=run_coroutine, args=(q,)) + proc.start() + # Wait for the process to finish, with a timeout. + proc.join(timeout=2) + + # If the process is still alive, it means the coroutine did not yield and cancel. + try: + assert not proc.is_alive(), "Test did not terminate as expected." 
+ if not q.empty(): + raise q.get() + finally: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=1) + + +@pytest.mark.asyncio +async def test_chained_lock_count(): + lock = FairAsyncRLock() + + async def run_inner(): + async with lock: + assert lock._count == 1 + await asyncio.get_running_loop().create_future() + + async def run_outer(): + async with lock: + owner = lock._owner + assert lock._count == 1 + task = asyncio.create_task(run_inner()) + await asyncio.sleep(0.1) + try: + await task + except asyncio.CancelledError: + assert lock._count == 1 + assert lock._owner == owner + + task = asyncio.create_task(run_outer()) + await asyncio.sleep(1) + task.cancel() + try: + await task + except asyncio.CancelledError: + assert lock._count == 0 + assert lock._owner == None + + + +@pytest.mark.asyncio +async def test_chained_lock_count_reentrant(): + lock = FairAsyncRLock() + + async def c(): + async with lock: + assert lock._count == 3 + await asyncio.get_running_loop().create_future() + + async def b(): + async with lock: + assert lock._count == 2 + try: + await c() + except asyncio.CancelledError: + assert lock._count == 2 + raise + + async def a(): + async with lock: + assert lock._count == 1 + try: + await b() + except asyncio.CancelledError: + assert lock._count == 1 + raise + + task = asyncio.create_task(a()) + await asyncio.sleep(1) + task.cancel() + try: + await task + except asyncio.CancelledError: + assert lock._count == 0 + assert lock._owner == None + +@pytest.mark.asyncio +async def test_chained_lock_count_reentrant(): + lock = FairAsyncRLock() + + async def b(): + assert lock._count == 1 + assert lock._owner is not None + + asyncio.current_task().cancel() + + with pytest.raises(asyncio.CancelledError): + await lock.acquire() + + assert lock._count == 1 + assert lock._owner is not None + + async def a(): + assert lock._count == 0 + assert lock._owner is None + + async with lock: + await b() + + assert lock._count == 0 + assert lock._owner is None 
+ + await a() \ No newline at end of file