From 4d1d5df9534192e8f502c4bc80699453c3ce058a Mon Sep 17 00:00:00 2001 From: Adam Novak Date: Tue, 3 Feb 2026 16:50:25 -0500 Subject: [PATCH 1/6] Consistently use flock and not the sometimes-non-interacting fcntl locks --- src/toil/server/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/toil/server/utils.py b/src/toil/server/utils.py index faec772c49..e57289dfc8 100644 --- a/src/toil/server/utils.py +++ b/src/toil/server/utils.py @@ -117,7 +117,7 @@ def safe_read_file(file: str) -> str | None: try: # acquire a shared lock on the state file, which is blocking until we can lock it - fcntl.lockf(file_obj.fileno(), fcntl.LOCK_SH) + fcntl.flock(file_obj.fileno(), fcntl.LOCK_SH) try: return file_obj.read() From beea0564b6deb1dfdfe90fe078ceae65981b82da Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 3 Feb 2026 17:19:10 -0500 Subject: [PATCH 2/6] Add interleaving tests for safe_read_file and safe_write_file Add deterministic tests that verify the locking protocol is correct by mocking fcntl.flock and file operations to control thread execution order. Tests verify: - Reader blocked while writer holds exclusive lock - Writer blocked while reader holds shared lock - Multiple readers can hold shared locks simultaneously - Writers serialize (cannot hold exclusive locks concurrently) - Reader never sees partial write content Also add BOTS.md with development environment notes for AI assistants. --- BOTS.md | 27 + CLAUDE.md | 1 + src/toil/test/server/safeFileTest.py | 703 +++++++++++++++++++++++++++ 3 files changed, 731 insertions(+) create mode 100644 BOTS.md create mode 120000 CLAUDE.md create mode 100644 src/toil/test/server/safeFileTest.py diff --git a/BOTS.md b/BOTS.md new file mode 100644 index 0000000000..1063452863 --- /dev/null +++ b/BOTS.md @@ -0,0 +1,27 @@ +# Notes for AI Assistants + +## Development Environment + +The Python virtual environment is likely located at `./venv`, but this is just a guess - the user may have it elsewhere or may have already activated a virtualenv. If commands like `python` or `pytest` work directly, the environment is probably already active. + +If you need to use the venv explicitly: + +```bash +./venv/bin/python -m pytest src/toil/test/path/to/test.py -v +./venv/bin/python -c "import toil; print(toil.__version__)" +``` + +## Running Tests + +Tests use pytest. Example commands: + +```bash +# Run a specific test file +./venv/bin/python -m pytest src/toil/test/server/safeFileTest.py -v + +# Run a specific test +./venv/bin/python -m pytest src/toil/test/server/safeFileTest.py::TestSafeFileInterleaving::test_reader_blocked_while_writer_holds_lock -v + +# Run tests with a keyword filter +./venv/bin/python -m pytest src/toil/test -k "safe" -v +``` diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 0000000000..1a1007d91a --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +BOTS.md \ No newline at end of file diff --git a/src/toil/test/server/safeFileTest.py b/src/toil/test/server/safeFileTest.py new file mode 100644 index 0000000000..f88a81cf31 --- /dev/null +++ b/src/toil/test/server/safeFileTest.py @@ -0,0 +1,703 @@ +# Copyright (C) 2015-2026 Regents of the University of California +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for safe_read_file and safe_write_file that verify correct locking +behavior by testing specific interleavings. + +Approach: Mock fcntl.flock and file I/O operations to inject synchronization +barriers, allowing deterministic control over thread execution order. This +tests that the locking protocol is correct - the code acquires the right +locks at the right times. +""" + +import fcntl +import os +import threading +import time +from collections.abc import Generator +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest + +# Time to wait between operations in interleaving tests. Must be long enough +# for threads to reach their synchronization points even on busy CI systems. +TICK_SECONDS = 2.0 + + +def _server_available() -> bool: + """Check if the server extra is installed.""" + try: + import connexion # noqa: F401 + + return True + except ImportError: + return False + + +needs_server = pytest.mark.skipif( + not _server_available(), + reason="Install Toil with the 'server' extra to include this test.", +) + + +class SimulatedLock: + """ + Simulates flock semantics using threading primitives. + + Supports shared (LOCK_SH) and exclusive (LOCK_EX) locks with proper + blocking behavior: + - Multiple shared locks can be held simultaneously + - Exclusive lock blocks until all shared locks are released + - Shared locks block while exclusive lock is held + """ + + def __init__(self) -> None: + self._condition = threading.Condition() + self._shared_count = 0 + self._exclusive_held = False + self._exclusive_holder: int | None = None + + def acquire_shared(self) -> None: + """Acquire a shared lock (blocks if exclusive lock held).""" + with self._condition: + while self._exclusive_held: + self._condition.wait() + self._shared_count += 1 + + def acquire_exclusive(self) -> None: + """Acquire an exclusive lock (blocks if any lock held).""" + thread_id = threading.current_thread().ident + with self._condition: + while self._exclusive_held or self._shared_count > 0: + self._condition.wait() + self._exclusive_held = True + self._exclusive_holder = thread_id + + def release(self) -> None: + """Release whatever lock this thread holds.""" + thread_id = threading.current_thread().ident + with self._condition: + if self._exclusive_held and self._exclusive_holder == thread_id: + self._exclusive_held = False + self._exclusive_holder = None + elif self._shared_count > 0: + self._shared_count -= 1 + self._condition.notify_all() + + @property + def has_exclusive(self) -> bool: + return self._exclusive_held + + @property + def shared_count(self) -> int: + return self._shared_count + + +class LockManager: + """Manages simulated locks for multiple files.""" + + def __init__(self) -> None: + self._locks: dict[str, SimulatedLock] = {} + self._fd_to_path: dict[int, str] = {} + self._global_lock = threading.Lock() + + # Barriers that tests can use to pause at specific points + self.before_acquire: dict[str, threading.Event] = {} + self.after_acquire: dict[str, threading.Event] = {} + self.before_release: dict[str, threading.Event] = {} + + def register_fd(self, fd: int, path: str) -> None: + """Associate a file descriptor with a path.""" + with self._global_lock: + real_path = os.path.realpath(path) + self._fd_to_path[fd] = real_path + if real_path not in self._locks: + self._locks[real_path] = SimulatedLock() + + def get_lock(self, fd: int) -> SimulatedLock | None: + """Get the lock for a file descriptor.""" + with self._global_lock: + path = self._fd_to_path.get(fd) + if path: + return self._locks.get(path) + return None + + def get_lock_for_path(self, path: str) -> SimulatedLock | None: + """Get the lock for a path.""" + real_path = os.path.realpath(path) + with self._global_lock: + return self._locks.get(real_path) + + def flock(self, fd: int, operation: int) -> None: + """Simulated flock that respects barriers.""" + lock = self.get_lock(fd) + if lock is None: + return + + thread_name = threading.current_thread().name + + if operation == fcntl.LOCK_UN: + barrier = self.before_release.get(thread_name) + if barrier: + barrier.wait() + lock.release() + + elif operation == fcntl.LOCK_SH: + barrier = self.before_acquire.get(thread_name) + if barrier: + barrier.wait() + lock.acquire_shared() + barrier = self.after_acquire.get(thread_name) + if barrier: + barrier.wait() + + elif operation == fcntl.LOCK_EX: + barrier = self.before_acquire.get(thread_name) + if barrier: + barrier.wait() + lock.acquire_exclusive() + barrier = self.after_acquire.get(thread_name) + if barrier: + barrier.wait() + + +class FileOperationTracker: + """ + Tracks and controls file operations with barriers. + + Wraps file objects to inject synchronization points during read/write. + """ + + def __init__(self, lock_manager: LockManager) -> None: + self.lock_manager = lock_manager + self.during_write: dict[str, threading.Event] = {} + self.during_read: dict[str, threading.Event] = {} + + def wrap_file(self, file_obj: Any, path: str) -> Any: + """Wrap a file object to track operations.""" + try: + self.lock_manager.register_fd(file_obj.fileno(), path) + except (OSError, ValueError): + pass + + original_write = file_obj.write + original_read = file_obj.read + tracker = self + + def tracked_write(data: str) -> int: + thread_name = threading.current_thread().name + barrier = tracker.during_write.get(thread_name) + if barrier: + barrier.wait() + return original_write(data) + + def tracked_read(size: int = -1) -> str: + thread_name = threading.current_thread().name + barrier = tracker.during_read.get(thread_name) + if barrier: + barrier.wait() + return original_read(size) + + file_obj.write = tracked_write + file_obj.read = tracked_read + return file_obj + + +@needs_server +@pytest.mark.timeout(30) +class TestSafeFileInterleaving: + """ + Tests that verify locking correctness through deterministic interleavings. + + Each test explicitly controls thread execution order using barriers to + verify that locks are held and respected at the right times. + """ + + @pytest.fixture(autouse=True) + def setup_managers(self, tmp_path: Path) -> Generator[None]: + """Set up lock manager and file tracker for each test.""" + self.test_file = tmp_path / "test_file" + self.lock_manager = LockManager() + self.file_tracker = FileOperationTracker(self.lock_manager) + yield + + def _create_patches(self) -> tuple[Any, Any]: + """Create patch contexts for flock and open.""" + original_open = open + lock_manager = self.lock_manager + file_tracker = self.file_tracker + + def patched_open(path: Any, mode: str = "r", *args: Any, **kwargs: Any) -> Any: + f = original_open(path, mode, *args, **kwargs) + return file_tracker.wrap_file(f, str(path)) + + def patched_flock(fd: int, operation: int) -> None: + lock_manager.flock(fd, operation) + + return ( + patch("builtins.open", patched_open), + patch("toil.server.utils.fcntl.flock", patched_flock), + ) + + def test_reader_blocked_while_writer_holds_lock(self) -> None: + """ + Verify that a reader cannot proceed while a writer holds the + exclusive lock. + + Sequence: + 1. Writer acquires exclusive lock + 2. Writer is paused (holding lock) + 3. Reader tries to acquire shared lock + 4. Verify reader is blocked (cannot proceed) + 5. Release writer + 6. Verify reader proceeds and sees written content + """ + from toil.server.utils import safe_read_file, safe_write_file + + # Create file with initial content + self.test_file.write_text("original") + + # Barrier to pause writer after acquiring lock + writer_after_lock = threading.Event() + self.lock_manager.after_acquire["writer"] = writer_after_lock + + # Barrier for write operation (in case it's needed) + writer_write_barrier = threading.Event() + self.file_tracker.during_write["writer"] = writer_write_barrier + + reader_completed = threading.Event() + results: dict[str, Any] = {"writer_done": False, "reader_result": None} + errors: list[Exception] = [] + + def writer() -> None: + try: + safe_write_file(str(self.test_file), "updated") + results["writer_done"] = True + except Exception as e: + errors.append(e) + + def reader() -> None: + try: + results["reader_result"] = safe_read_file(str(self.test_file)) + reader_completed.set() + except Exception as e: + errors.append(e) + + patches = self._create_patches() + with patches[0], patches[1]: + # Start writer - will block on after_acquire barrier + t_writer = threading.Thread(target=writer, name="writer") + t_writer.start() + + # Wait for writer to acquire lock and hit barrier + time.sleep(TICK_SECONDS) + + # Verify writer has the lock + lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + assert lock is not None + assert lock.has_exclusive, "Writer should hold exclusive lock" + + # Start reader - should block on lock acquisition + t_reader = threading.Thread(target=reader, name="reader") + t_reader.start() + + # Give reader time to try to acquire lock + time.sleep(TICK_SECONDS) + + # Reader should NOT have completed (blocked on lock) + assert not reader_completed.is_set(), ( + "Reader should be blocked while writer holds exclusive lock" + ) + assert results["reader_result"] is None + + # Release writer barriers + writer_after_lock.set() + writer_write_barrier.set() + + # Wait for both to complete + t_writer.join(timeout=TICK_SECONDS * 5) + t_reader.join(timeout=TICK_SECONDS * 5) + + assert errors == [] + assert results["writer_done"] + # Reader should see updated content (writer finished first) + assert results["reader_result"] == "updated" + + def test_writer_blocked_while_reader_holds_lock(self) -> None: + """ + Verify that a writer cannot proceed while a reader holds a + shared lock. + + Sequence: + 1. Reader acquires shared lock + 2. Reader is paused (holding lock) + 3. Writer tries to acquire exclusive lock + 4. Verify writer is blocked + 5. Release reader + 6. Verify writer proceeds + """ + from toil.server.utils import safe_read_file, safe_write_file + + # Create file with initial content + self.test_file.write_text("original") + + # Barrier to pause reader after acquiring lock + reader_after_lock = threading.Event() + self.lock_manager.after_acquire["reader"] = reader_after_lock + + writer_completed = threading.Event() + results: dict[str, Any] = {"reader_result": None, "writer_done": False} + errors: list[Exception] = [] + + def reader() -> None: + try: + results["reader_result"] = safe_read_file(str(self.test_file)) + except Exception as e: + errors.append(e) + + def writer() -> None: + try: + safe_write_file(str(self.test_file), "updated") + results["writer_done"] = True + writer_completed.set() + except Exception as e: + errors.append(e) + + patches = self._create_patches() + with patches[0], patches[1]: + # Start reader - will block on after_acquire barrier + t_reader = threading.Thread(target=reader, name="reader") + t_reader.start() + + # Wait for reader to acquire lock and hit barrier + time.sleep(TICK_SECONDS) + + # Verify reader has the lock + lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + assert lock is not None + assert lock.shared_count == 1, "Reader should hold shared lock" + + # Start writer - should block on lock acquisition + t_writer = threading.Thread(target=writer, name="writer") + t_writer.start() + + # Give writer time to try to acquire lock + time.sleep(TICK_SECONDS) + + # Writer should NOT have completed (blocked on lock) + assert not writer_completed.is_set(), ( + "Writer should be blocked while reader holds shared lock" + ) + assert not results["writer_done"] + + # Release reader barrier + reader_after_lock.set() + + # Wait for both to complete + t_reader.join(timeout=TICK_SECONDS * 5) + t_writer.join(timeout=TICK_SECONDS * 5) + + assert errors == [] + assert results["reader_result"] == "original" + assert results["writer_done"] + + def test_multiple_readers_not_blocked(self) -> None: + """ + Verify that multiple readers can hold shared locks simultaneously. + + Sequence: + 1. Reader1 acquires shared lock, pauses + 2. Reader2 acquires shared lock (should succeed immediately) + 3. Both readers hold locks simultaneously + 4. Release both, verify both complete + """ + from toil.server.utils import safe_read_file + + # Create file + self.test_file.write_text("content") + + # Barriers to pause both readers after acquiring locks + reader1_after_lock = threading.Event() + reader2_after_lock = threading.Event() + self.lock_manager.after_acquire["reader1"] = reader1_after_lock + self.lock_manager.after_acquire["reader2"] = reader2_after_lock + + results: dict[str, str | None] = {"reader1": None, "reader2": None} + errors: list[Exception] = [] + + def reader1() -> None: + try: + results["reader1"] = safe_read_file(str(self.test_file)) + except Exception as e: + errors.append(e) + + def reader2() -> None: + try: + results["reader2"] = safe_read_file(str(self.test_file)) + except Exception as e: + errors.append(e) + + patches = self._create_patches() + with patches[0], patches[1]: + # Start reader1 + t_reader1 = threading.Thread(target=reader1, name="reader1") + t_reader1.start() + + # Wait for reader1 to acquire lock + time.sleep(TICK_SECONDS) + + lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + assert lock is not None + assert lock.shared_count == 1 + + # Start reader2 - should also acquire shared lock (not blocked) + t_reader2 = threading.Thread(target=reader2, name="reader2") + t_reader2.start() + + # Give reader2 time to acquire lock + time.sleep(TICK_SECONDS) + + # Both should hold shared locks simultaneously + assert lock.shared_count == 2, ( + "Both readers should hold shared locks simultaneously" + ) + + # Release both barriers + reader1_after_lock.set() + reader2_after_lock.set() + + t_reader1.join(timeout=TICK_SECONDS * 5) + t_reader2.join(timeout=TICK_SECONDS * 5) + + assert errors == [] + assert results["reader1"] == "content" + assert results["reader2"] == "content" + + def test_writers_serialize(self) -> None: + """ + Verify that two writers cannot hold exclusive locks simultaneously. + + Sequence: + 1. Writer1 acquires exclusive lock, pauses + 2. Writer2 tries to acquire exclusive lock + 3. Verify writer2 is blocked + 4. Release writer1 + 5. Verify writer2 proceeds + """ + from toil.server.utils import safe_write_file + + # Create file + self.test_file.write_text("original") + + # Barrier to pause writer1 after acquiring lock + writer1_after_lock = threading.Event() + self.lock_manager.after_acquire["writer1"] = writer1_after_lock + + writer2_completed = threading.Event() + results: dict[str, bool] = {"writer1_done": False, "writer2_done": False} + errors: list[Exception] = [] + + def writer1() -> None: + try: + safe_write_file(str(self.test_file), "from_writer1") + results["writer1_done"] = True + except Exception as e: + errors.append(e) + + def writer2() -> None: + try: + safe_write_file(str(self.test_file), "from_writer2") + results["writer2_done"] = True + writer2_completed.set() + except Exception as e: + errors.append(e) + + patches = self._create_patches() + with patches[0], patches[1]: + # Start writer1 + t_writer1 = threading.Thread(target=writer1, name="writer1") + t_writer1.start() + + # Wait for writer1 to acquire lock + time.sleep(TICK_SECONDS) + + lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + assert lock is not None + assert lock.has_exclusive, "Writer1 should hold exclusive lock" + + # Start writer2 - should block + t_writer2 = threading.Thread(target=writer2, name="writer2") + t_writer2.start() + + # Give writer2 time to try to acquire + time.sleep(TICK_SECONDS) + + # Writer2 should be blocked + assert not writer2_completed.is_set(), ( + "Writer2 should be blocked while writer1 holds exclusive lock" + ) + assert not results["writer2_done"] + + # Writer1 still holds exclusive lock + assert lock.has_exclusive + + # Release writer1 + writer1_after_lock.set() + + # Wait for both + t_writer1.join(timeout=TICK_SECONDS * 5) + t_writer2.join(timeout=TICK_SECONDS * 5) + + assert errors == [] + assert results["writer1_done"] + assert results["writer2_done"] + + def test_reader_never_sees_partial_write(self) -> None: + """ + Verify that a reader cannot see partial write content. + + The lock ensures the reader either sees the old content or the + complete new content, never content mid-write. + + Sequence: + 1. File has "AAAA" + 2. Writer acquires lock, pauses before writing + 3. Reader tries to read - should block + 4. Writer completes writing "BBBB" + 5. Writer releases lock + 6. Reader acquires lock, reads complete "BBBB" + """ + from toil.server.utils import safe_read_file, safe_write_file + + # Create file with initial content + self.test_file.write_text("AAAA") + + # Pause writer after acquiring lock (before any write) + writer_after_lock = threading.Event() + self.lock_manager.after_acquire["writer"] = writer_after_lock + + results: dict[str, Any] = {"reader_result": None} + errors: list[Exception] = [] + + def writer() -> None: + try: + safe_write_file(str(self.test_file), "BBBB") + except Exception as e: + errors.append(e) + + def reader() -> None: + try: + results["reader_result"] = safe_read_file(str(self.test_file)) + except Exception as e: + errors.append(e) + + patches = self._create_patches() + with patches[0], patches[1]: + # Start writer + t_writer = threading.Thread(target=writer, name="writer") + t_writer.start() + + # Wait for writer to hold lock + time.sleep(TICK_SECONDS) + + # Start reader while writer holds lock + t_reader = threading.Thread(target=reader, name="reader") + t_reader.start() + + # Give reader time to block + time.sleep(TICK_SECONDS) + + # Reader should still be waiting (not completed) + assert results["reader_result"] is None + + # Release writer to complete + writer_after_lock.set() + + t_writer.join(timeout=TICK_SECONDS * 5) + t_reader.join(timeout=TICK_SECONDS * 5) + + assert errors == [] + # Reader must see complete new content, never partial or mixed + assert results["reader_result"] == "BBBB", ( + "Reader should see complete write, not partial" + ) + + def test_writer_paused_mid_write_blocks_reader(self) -> None: + """ + Verify that a reader is blocked even when writer is paused during + the actual write operation (not just after lock acquisition). + + This tests that the lock is held throughout the entire write. + """ + from toil.server.utils import safe_read_file, safe_write_file + + # Create file with initial content + self.test_file.write_text("original") + + # Pause writer during the write operation itself + writer_during_write = threading.Event() + self.file_tracker.during_write["writer"] = writer_during_write + + reader_completed = threading.Event() + results: dict[str, Any] = {"reader_result": None} + errors: list[Exception] = [] + + def writer() -> None: + try: + safe_write_file(str(self.test_file), "updated") + except Exception as e: + errors.append(e) + + def reader() -> None: + try: + results["reader_result"] = safe_read_file(str(self.test_file)) + reader_completed.set() + except Exception as e: + errors.append(e) + + patches = self._create_patches() + with patches[0], patches[1]: + # Start writer - will block during write + t_writer = threading.Thread(target=writer, name="writer") + t_writer.start() + + # Wait for writer to reach the write barrier + time.sleep(TICK_SECONDS) + + # Writer should hold exclusive lock while paused mid-write + lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + assert lock is not None + assert lock.has_exclusive, "Writer should hold lock during write" + + # Start reader - should block + t_reader = threading.Thread(target=reader, name="reader") + t_reader.start() + + time.sleep(TICK_SECONDS) + + # Reader should NOT have completed + assert not reader_completed.is_set(), ( + "Reader should be blocked while writer is mid-write" + ) + + # Release writer to finish + writer_during_write.set() + + t_writer.join(timeout=TICK_SECONDS * 5) + t_reader.join(timeout=TICK_SECONDS * 5) + + assert errors == [] + assert results["reader_result"] == "updated" From 814422d70aa5380a21617d6d4c4572f69ac5fc93 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 3 Feb 2026 17:23:03 -0500 Subject: [PATCH 3/6] Replace sleeps with event-based synchronization in interleaving tests Use Checkpoint class with arrive_and_wait/wait_for_arrival pattern instead of time.sleep() calls. Tests now run in ~0.7s instead of ~24s and are more deterministic. --- src/toil/test/server/safeFileTest.py | 333 ++++++++++++--------------- 1 file changed, 145 insertions(+), 188 deletions(-) diff --git a/src/toil/test/server/safeFileTest.py b/src/toil/test/server/safeFileTest.py index f88a81cf31..22349de8c4 100644 --- a/src/toil/test/server/safeFileTest.py +++ b/src/toil/test/server/safeFileTest.py @@ -16,15 +16,16 @@ behavior by testing specific interleavings. Approach: Mock fcntl.flock and file I/O operations to inject synchronization -barriers, allowing deterministic control over thread execution order. This +checkpoints, allowing deterministic control over thread execution order. This tests that the locking protocol is correct - the code acquires the right locks at the right times. + +The tests use no sleeps - all synchronization is done via events/conditions. """ import fcntl import os import threading -import time from collections.abc import Generator from pathlib import Path from typing import Any @@ -32,9 +33,10 @@ import pytest -# Time to wait between operations in interleaving tests. Must be long enough -# for threads to reach their synchronization points even on busy CI systems. -TICK_SECONDS = 2.0 +# Timeout for waiting on synchronization events. Should be long enough to +# never trigger in normal operation, but short enough to fail fast if +# something deadlocks. +SYNC_TIMEOUT = 10.0 def _server_available() -> bool: @@ -53,6 +55,51 @@ def _server_available() -> bool: ) +class Checkpoint: + """ + A synchronization point that allows a test to pause a thread and know + when the thread has arrived. + + Usage: + checkpoint = Checkpoint() + + # In worker thread: + checkpoint.arrive_and_wait() # signals arrival, then blocks + + # In test: + checkpoint.wait_for_arrival() # blocks until thread arrives + # ... check state ... + checkpoint.release() # allows thread to proceed + """ + + def __init__(self) -> None: + self._arrived = threading.Event() + self._released = threading.Event() + + def arrive_and_wait(self, timeout: float = SYNC_TIMEOUT) -> bool: + """ + Signal that we've arrived at the checkpoint, then wait for release. + Returns True if released, False if timed out. + """ + self._arrived.set() + return self._released.wait(timeout=timeout) + + def wait_for_arrival(self, timeout: float = SYNC_TIMEOUT) -> bool: + """ + Wait for a thread to arrive at this checkpoint. + Returns True if arrived, False if timed out. + """ + return self._arrived.wait(timeout=timeout) + + def release(self) -> None: + """Release the thread waiting at this checkpoint.""" + self._released.set() + + def has_arrived(self) -> bool: + """Check if a thread has arrived (non-blocking).""" + return self._arrived.is_set() + + class SimulatedLock: """ Simulates flock semantics using threading primitives. @@ -114,10 +161,9 @@ def __init__(self) -> None: self._fd_to_path: dict[int, str] = {} self._global_lock = threading.Lock() - # Barriers that tests can use to pause at specific points - self.before_acquire: dict[str, threading.Event] = {} - self.after_acquire: dict[str, threading.Event] = {} - self.before_release: dict[str, threading.Event] = {} + # Checkpoints that tests can use to pause at specific points. + # Keyed by thread name. + self.after_acquire: dict[str, Checkpoint] = {} def register_fd(self, fd: int, path: str) -> None: """Associate a file descriptor with a path.""" @@ -142,7 +188,7 @@ def get_lock_for_path(self, path: str) -> SimulatedLock | None: return self._locks.get(real_path) def flock(self, fd: int, operation: int) -> None: - """Simulated flock that respects barriers.""" + """Simulated flock that respects checkpoints.""" lock = self.get_lock(fd) if lock is None: return @@ -150,41 +196,33 @@ def flock(self, fd: int, operation: int) -> None: thread_name = threading.current_thread().name if operation == fcntl.LOCK_UN: - barrier = self.before_release.get(thread_name) - if barrier: - barrier.wait() lock.release() elif operation == fcntl.LOCK_SH: - barrier = self.before_acquire.get(thread_name) - if barrier: - barrier.wait() lock.acquire_shared() - barrier = self.after_acquire.get(thread_name) - if barrier: - barrier.wait() + checkpoint = self.after_acquire.get(thread_name) + if checkpoint: + checkpoint.arrive_and_wait() elif operation == fcntl.LOCK_EX: - barrier = self.before_acquire.get(thread_name) - if barrier: - barrier.wait() lock.acquire_exclusive() - barrier = self.after_acquire.get(thread_name) - if barrier: - barrier.wait() + checkpoint = self.after_acquire.get(thread_name) + if checkpoint: + checkpoint.arrive_and_wait() class FileOperationTracker: """ - Tracks and controls file operations with barriers. + Tracks and controls file operations with checkpoints. Wraps file objects to inject synchronization points during read/write. """ def __init__(self, lock_manager: LockManager) -> None: self.lock_manager = lock_manager - self.during_write: dict[str, threading.Event] = {} - self.during_read: dict[str, threading.Event] = {} + # Checkpoints keyed by thread name + self.during_write: dict[str, Checkpoint] = {} + self.during_read: dict[str, Checkpoint] = {} def wrap_file(self, file_obj: Any, path: str) -> Any: """Wrap a file object to track operations.""" @@ -199,16 +237,16 @@ def wrap_file(self, file_obj: Any, path: str) -> Any: def tracked_write(data: str) -> int: thread_name = threading.current_thread().name - barrier = tracker.during_write.get(thread_name) - if barrier: - barrier.wait() + checkpoint = tracker.during_write.get(thread_name) + if checkpoint: + checkpoint.arrive_and_wait() return original_write(data) def tracked_read(size: int = -1) -> str: thread_name = threading.current_thread().name - barrier = tracker.during_read.get(thread_name) - if barrier: - barrier.wait() + checkpoint = tracker.during_read.get(thread_name) + if checkpoint: + checkpoint.arrive_and_wait() return original_read(size) file_obj.write = tracked_write @@ -222,8 +260,9 @@ class TestSafeFileInterleaving: """ Tests that verify locking correctness through deterministic interleavings. - Each test explicitly controls thread execution order using barriers to - verify that locks are held and respected at the right times. + Each test explicitly controls thread execution order using checkpoints to + verify that locks are held and respected at the right times. No sleeps + are used - all synchronization is event-based. """ @pytest.fixture(autouse=True) @@ -258,25 +297,20 @@ def test_reader_blocked_while_writer_holds_lock(self) -> None: exclusive lock. Sequence: - 1. Writer acquires exclusive lock - 2. Writer is paused (holding lock) - 3. Reader tries to acquire shared lock - 4. Verify reader is blocked (cannot proceed) - 5. Release writer - 6. Verify reader proceeds and sees written content + 1. Writer acquires exclusive lock, arrives at checkpoint + 2. Test verifies writer has lock + 3. Reader tries to acquire shared lock (will block on simulated lock) + 4. Test verifies reader is blocked + 5. Test releases writer checkpoint + 6. Both complete, reader sees written content """ from toil.server.utils import safe_read_file, safe_write_file - # Create file with initial content self.test_file.write_text("original") - # Barrier to pause writer after acquiring lock - writer_after_lock = threading.Event() - self.lock_manager.after_acquire["writer"] = writer_after_lock - - # Barrier for write operation (in case it's needed) - writer_write_barrier = threading.Event() - self.file_tracker.during_write["writer"] = writer_write_barrier + # Checkpoint to pause writer after acquiring lock + writer_checkpoint = Checkpoint() + self.lock_manager.after_acquire["writer"] = writer_checkpoint reader_completed = threading.Event() results: dict[str, Any] = {"writer_done": False, "reader_result": None} @@ -298,65 +332,49 @@ def reader() -> None: patches = self._create_patches() with patches[0], patches[1]: - # Start writer - will block on after_acquire barrier t_writer = threading.Thread(target=writer, name="writer") t_writer.start() - # Wait for writer to acquire lock and hit barrier - time.sleep(TICK_SECONDS) + # Wait for writer to acquire lock and hit checkpoint + assert writer_checkpoint.wait_for_arrival(), "Writer didn't reach checkpoint" # Verify writer has the lock lock = self.lock_manager.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.has_exclusive, "Writer should hold exclusive lock" - # Start reader - should block on lock acquisition + # Start reader - will block on simulated lock (not checkpoint) t_reader = threading.Thread(target=reader, name="reader") t_reader.start() - # Give reader time to try to acquire lock - time.sleep(TICK_SECONDS) - - # Reader should NOT have completed (blocked on lock) - assert not reader_completed.is_set(), ( + # Reader should NOT complete while writer holds lock. + # Give it a moment to try, then check it's still blocked. + assert not reader_completed.wait(timeout=0.1), ( "Reader should be blocked while writer holds exclusive lock" ) assert results["reader_result"] is None - # Release writer barriers - writer_after_lock.set() - writer_write_barrier.set() + # Release writer + writer_checkpoint.release() - # Wait for both to complete - t_writer.join(timeout=TICK_SECONDS * 5) - t_reader.join(timeout=TICK_SECONDS * 5) + t_writer.join(timeout=SYNC_TIMEOUT) + t_reader.join(timeout=SYNC_TIMEOUT) assert errors == [] assert results["writer_done"] - # Reader should see updated content (writer finished first) assert results["reader_result"] == "updated" def test_writer_blocked_while_reader_holds_lock(self) -> None: """ Verify that a writer cannot proceed while a reader holds a shared lock. - - Sequence: - 1. Reader acquires shared lock - 2. Reader is paused (holding lock) - 3. Writer tries to acquire exclusive lock - 4. Verify writer is blocked - 5. Release reader - 6. Verify writer proceeds """ from toil.server.utils import safe_read_file, safe_write_file - # Create file with initial content self.test_file.write_text("original") - # Barrier to pause reader after acquiring lock - reader_after_lock = threading.Event() - self.lock_manager.after_acquire["reader"] = reader_after_lock + reader_checkpoint = Checkpoint() + self.lock_manager.after_acquire["reader"] = reader_checkpoint writer_completed = threading.Event() results: dict[str, Any] = {"reader_result": None, "writer_done": False} @@ -378,37 +396,28 @@ def writer() -> None: patches = self._create_patches() with patches[0], patches[1]: - # Start reader - will block on after_acquire barrier t_reader = threading.Thread(target=reader, name="reader") t_reader.start() - # Wait for reader to acquire lock and hit barrier - time.sleep(TICK_SECONDS) + assert reader_checkpoint.wait_for_arrival(), "Reader didn't reach checkpoint" - # Verify reader has the lock lock = self.lock_manager.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.shared_count == 1, "Reader should hold shared lock" - # Start writer - should block on lock acquisition t_writer = threading.Thread(target=writer, name="writer") t_writer.start() - # Give writer time to try to acquire lock - time.sleep(TICK_SECONDS) - - # Writer should NOT have completed (blocked on lock) - assert not writer_completed.is_set(), ( + # Writer should be blocked + assert not writer_completed.wait(timeout=0.1), ( "Writer should be blocked while reader holds shared lock" ) assert not results["writer_done"] - # Release reader barrier - reader_after_lock.set() + reader_checkpoint.release() - # Wait for both to complete - t_reader.join(timeout=TICK_SECONDS * 5) - t_writer.join(timeout=TICK_SECONDS * 5) + t_reader.join(timeout=SYNC_TIMEOUT) + t_writer.join(timeout=SYNC_TIMEOUT) assert errors == [] assert results["reader_result"] == "original" @@ -417,23 +426,15 @@ def writer() -> None: def test_multiple_readers_not_blocked(self) -> None: """ Verify that multiple readers can hold shared locks simultaneously. - - Sequence: - 1. Reader1 acquires shared lock, pauses - 2. Reader2 acquires shared lock (should succeed immediately) - 3. Both readers hold locks simultaneously - 4. Release both, verify both complete """ from toil.server.utils import safe_read_file - # Create file self.test_file.write_text("content") - # Barriers to pause both readers after acquiring locks - reader1_after_lock = threading.Event() - reader2_after_lock = threading.Event() - self.lock_manager.after_acquire["reader1"] = reader1_after_lock - self.lock_manager.after_acquire["reader2"] = reader2_after_lock + reader1_checkpoint = Checkpoint() + reader2_checkpoint = Checkpoint() + self.lock_manager.after_acquire["reader1"] = reader1_checkpoint + self.lock_manager.after_acquire["reader2"] = reader2_checkpoint results: dict[str, str | None] = {"reader1": None, "reader2": None} errors: list[Exception] = [] @@ -452,35 +453,34 @@ def reader2() -> None: patches = self._create_patches() with patches[0], patches[1]: - # Start reader1 t_reader1 = threading.Thread(target=reader1, name="reader1") t_reader1.start() - # Wait for reader1 to acquire lock - time.sleep(TICK_SECONDS) + assert reader1_checkpoint.wait_for_arrival(), "Reader1 didn't reach checkpoint" lock = self.lock_manager.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.shared_count == 1 - # Start reader2 - should also acquire shared lock (not blocked) + # Reader2 should also acquire shared lock (not blocked by reader1) t_reader2 = threading.Thread(target=reader2, name="reader2") t_reader2.start() - # Give reader2 time to acquire lock - time.sleep(TICK_SECONDS) + # Reader2 should reach its checkpoint (proving it got the lock) + assert reader2_checkpoint.wait_for_arrival(), ( + "Reader2 should acquire shared lock while reader1 holds one" + ) # Both should hold shared locks simultaneously assert lock.shared_count == 2, ( "Both readers should hold shared locks simultaneously" ) - # Release both barriers - reader1_after_lock.set() - reader2_after_lock.set() + reader1_checkpoint.release() + reader2_checkpoint.release() - t_reader1.join(timeout=TICK_SECONDS * 5) - t_reader2.join(timeout=TICK_SECONDS * 5) + t_reader1.join(timeout=SYNC_TIMEOUT) + t_reader2.join(timeout=SYNC_TIMEOUT) assert errors == [] assert results["reader1"] == "content" @@ -489,22 +489,13 @@ def reader2() -> None: def test_writers_serialize(self) -> None: """ Verify that two writers cannot hold exclusive locks simultaneously. - - Sequence: - 1. Writer1 acquires exclusive lock, pauses - 2. Writer2 tries to acquire exclusive lock - 3. Verify writer2 is blocked - 4. Release writer1 - 5. Verify writer2 proceeds """ from toil.server.utils import safe_write_file - # Create file self.test_file.write_text("original") - # Barrier to pause writer1 after acquiring lock - writer1_after_lock = threading.Event() - self.lock_manager.after_acquire["writer1"] = writer1_after_lock + writer1_checkpoint = Checkpoint() + self.lock_manager.after_acquire["writer1"] = writer1_checkpoint writer2_completed = threading.Event() results: dict[str, bool] = {"writer1_done": False, "writer2_done": False} @@ -527,39 +518,29 @@ def writer2() -> None: patches = self._create_patches() with patches[0], patches[1]: - # Start writer1 t_writer1 = threading.Thread(target=writer1, name="writer1") t_writer1.start() - # Wait for writer1 to acquire lock - time.sleep(TICK_SECONDS) + assert writer1_checkpoint.wait_for_arrival(), "Writer1 didn't reach checkpoint" lock = self.lock_manager.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.has_exclusive, "Writer1 should hold exclusive lock" - # Start writer2 - should block t_writer2 = threading.Thread(target=writer2, name="writer2") t_writer2.start() - # Give writer2 time to try to acquire - time.sleep(TICK_SECONDS) - # Writer2 should be blocked - assert not writer2_completed.is_set(), ( + assert not writer2_completed.wait(timeout=0.1), ( "Writer2 should be blocked while writer1 holds exclusive lock" ) assert not results["writer2_done"] + assert lock.has_exclusive # Still held by writer1 - # Writer1 still holds exclusive lock - assert lock.has_exclusive + writer1_checkpoint.release() - # Release writer1 - writer1_after_lock.set() - - # Wait for both - t_writer1.join(timeout=TICK_SECONDS * 5) - t_writer2.join(timeout=TICK_SECONDS * 5) + t_writer1.join(timeout=SYNC_TIMEOUT) + t_writer2.join(timeout=SYNC_TIMEOUT) assert errors == [] assert results["writer1_done"] @@ -568,28 +549,17 @@ def writer2() -> None: def test_reader_never_sees_partial_write(self) -> None: """ Verify that a reader cannot see partial write content. - - The lock ensures the reader either sees the old content or the - complete new content, never content mid-write. - - Sequence: - 1. File has "AAAA" - 2. Writer acquires lock, pauses before writing - 3. Reader tries to read - should block - 4. Writer completes writing "BBBB" - 5. Writer releases lock - 6. Reader acquires lock, reads complete "BBBB" + Reader either sees old content or complete new content. """ from toil.server.utils import safe_read_file, safe_write_file - # Create file with initial content self.test_file.write_text("AAAA") - # Pause writer after acquiring lock (before any write) - writer_after_lock = threading.Event() - self.lock_manager.after_acquire["writer"] = writer_after_lock + writer_checkpoint = Checkpoint() + self.lock_manager.after_acquire["writer"] = writer_checkpoint results: dict[str, Any] = {"reader_result": None} + reader_completed = threading.Event() errors: list[Exception] = [] def writer() -> None: @@ -601,36 +571,32 @@ def writer() -> None: def reader() -> None: try: results["reader_result"] = safe_read_file(str(self.test_file)) + reader_completed.set() except Exception as e: errors.append(e) patches = self._create_patches() with patches[0], patches[1]: - # Start writer t_writer = threading.Thread(target=writer, name="writer") t_writer.start() - # Wait for writer to hold lock - time.sleep(TICK_SECONDS) + assert writer_checkpoint.wait_for_arrival(), "Writer didn't reach checkpoint" - # Start reader while writer holds lock t_reader = threading.Thread(target=reader, name="reader") t_reader.start() - # Give reader time to block - time.sleep(TICK_SECONDS) - - # Reader should still be waiting (not completed) + # Reader should be blocked + assert not reader_completed.wait(timeout=0.1), ( + "Reader should be blocked while writer holds lock" + ) assert results["reader_result"] is None - # Release writer to complete - writer_after_lock.set() + writer_checkpoint.release() - t_writer.join(timeout=TICK_SECONDS * 5) - t_reader.join(timeout=TICK_SECONDS * 5) + t_writer.join(timeout=SYNC_TIMEOUT) + t_reader.join(timeout=SYNC_TIMEOUT) assert errors == [] - # Reader must see complete new content, never partial or mixed assert results["reader_result"] == "BBBB", ( "Reader should see complete write, not partial" ) @@ -639,17 +605,14 @@ def test_writer_paused_mid_write_blocks_reader(self) -> None: """ Verify that a reader is blocked even when writer is paused during the actual write operation (not just after lock acquisition). - - This tests that the lock is held throughout the entire write. """ from toil.server.utils import safe_read_file, safe_write_file - # Create file with initial content self.test_file.write_text("original") # Pause writer during the write operation itself - writer_during_write = threading.Event() - self.file_tracker.during_write["writer"] = writer_during_write + write_checkpoint = Checkpoint() + self.file_tracker.during_write["writer"] = write_checkpoint reader_completed = threading.Event() results: dict[str, Any] = {"reader_result": None} @@ -670,34 +633,28 @@ def reader() -> None: patches = self._create_patches() with patches[0], patches[1]: - # Start writer - will block during write t_writer = threading.Thread(target=writer, name="writer") t_writer.start() - # Wait for writer to reach the write barrier - time.sleep(TICK_SECONDS) + assert write_checkpoint.wait_for_arrival(), "Writer didn't reach write checkpoint" # Writer should hold exclusive lock while paused mid-write lock = self.lock_manager.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.has_exclusive, "Writer should hold lock during write" - # Start reader - should block t_reader = threading.Thread(target=reader, name="reader") t_reader.start() - time.sleep(TICK_SECONDS) - - # Reader should NOT have completed - assert not reader_completed.is_set(), ( + # Reader should be blocked + assert not reader_completed.wait(timeout=0.1), ( "Reader should be blocked while writer is mid-write" ) - # Release writer to finish - writer_during_write.set() + write_checkpoint.release() - t_writer.join(timeout=TICK_SECONDS * 5) - t_reader.join(timeout=TICK_SECONDS * 5) + t_writer.join(timeout=SYNC_TIMEOUT) + t_reader.join(timeout=SYNC_TIMEOUT) assert errors == [] assert results["reader_result"] == "updated" From 4caabdb8d05d187f1330a6c1db908ea063004a32 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 3 Feb 2026 17:31:04 -0500 Subject: [PATCH 4/6] Import needs_server from toil.test instead of redefining --- src/toil/test/server/safeFileTest.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/toil/test/server/safeFileTest.py b/src/toil/test/server/safeFileTest.py index 22349de8c4..e6430816a4 100644 --- a/src/toil/test/server/safeFileTest.py +++ b/src/toil/test/server/safeFileTest.py @@ -33,28 +33,14 @@ import pytest +from toil.test import needs_server + # Timeout for waiting on synchronization events. Should be long enough to # never trigger in normal operation, but short enough to fail fast if # something deadlocks. SYNC_TIMEOUT = 10.0 -def _server_available() -> bool: - """Check if the server extra is installed.""" - try: - import connexion # noqa: F401 - - return True - except ImportError: - return False - - -needs_server = pytest.mark.skipif( - not _server_available(), - reason="Install Toil with the 'server' extra to include this test.", -) - - class Checkpoint: """ A synchronization point that allows a test to pause a thread and know From 5740f2bc7ccd039af2c5c624015ed91e8b912ac4 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 3 Feb 2026 17:39:10 -0500 Subject: [PATCH 5/6] Remove unnecessary needs_server decorator, add TODOs for known issues --- src/toil/test/server/safeFileTest.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/toil/test/server/safeFileTest.py b/src/toil/test/server/safeFileTest.py index e6430816a4..5ec4eb6cd8 100644 --- a/src/toil/test/server/safeFileTest.py +++ b/src/toil/test/server/safeFileTest.py @@ -33,8 +33,6 @@ import pytest -from toil.test import needs_server - # Timeout for waiting on synchronization events. Should be long enough to # never trigger in normal operation, but short enough to fail fast if # something deadlocks. @@ -86,6 +84,9 @@ def has_arrived(self) -> bool: return self._arrived.is_set() +# TODO: SimulatedLock doesn't track which threads hold shared locks. +# release() will decrement count even if called by a thread that doesn't hold +# a lock. Should use a set to track shared lock holders. class SimulatedLock: """ Simulates flock semantics using threading primitives. @@ -240,7 +241,10 @@ def tracked_read(size: int = -1) -> str: return file_obj -@needs_server +# TODO: Add tests for AtomicFileCreate path (concurrent new file creation). +# TODO: The 0.1s timeout waits to verify blocking are effectively sleeps; +# consider adding a checkpoint before lock acquisition to prove thread is +# actually blocked waiting for the lock. @pytest.mark.timeout(30) class TestSafeFileInterleaving: """ From cf11f27a8878ca75b20add23a9d5c8e8fd4d289b Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 4 Feb 2026 15:55:19 -0500 Subject: [PATCH 6/6] Refactor safe file interleaving tests per code review Simplify test infrastructure by creating a unified Checkpointer base class with three implementations (LockCheckpointer, ReadCheckpointer, WriteCheckpointer). Each checkpointer provides its own install() context manager that patches the necessary functions, composing naturally via ExitStack. - Remove LockManager, FileOperationTracker, and _create_patches() - Add Checkpointer ABC with add()/get() and abstract install() method - LockCheckpointer handles fd registration and flock simulation - ReadCheckpointer/WriteCheckpointer hook into file read/write ops - Consolidate patching into single patched_io() context manager - Replace test_reader_never_sees_partial_write with test_reader_paused_mid_read_blocks_writer for better coverage --- src/toil/test/server/safeFileTest.py | 326 ++++++++++++++++----------- 1 file changed, 195 insertions(+), 131 deletions(-) diff --git a/src/toil/test/server/safeFileTest.py b/src/toil/test/server/safeFileTest.py index 5ec4eb6cd8..d5afa53f21 100644 --- a/src/toil/test/server/safeFileTest.py +++ b/src/toil/test/server/safeFileTest.py @@ -23,10 +23,13 @@ The tests use no sleeps - all synchronization is done via events/conditions. """ +import builtins import fcntl import os import threading +from abc import ABC, abstractmethod from collections.abc import Generator +from contextlib import ExitStack, contextmanager from pathlib import Path from typing import Any from unittest.mock import patch @@ -140,19 +143,43 @@ def shared_count(self) -> int: return self._shared_count -class LockManager: - """Manages simulated locks for multiple files.""" +class Checkpointer(ABC): + """Base class for hooking checkpoints into operations.""" def __init__(self) -> None: + self._checkpoints: dict[str, Checkpoint] = {} + + def add(self, thread_name: str, checkpoint: Checkpoint) -> None: + """Register a checkpoint for the given thread.""" + self._checkpoints[thread_name] = checkpoint + + def get(self, thread_name: str) -> Checkpoint | None: + """Get checkpoint for thread, if any.""" + return self._checkpoints.get(thread_name) + + @abstractmethod + @contextmanager + def install(self) -> Generator[None, None, None]: + """ + Install patches for this checkpointer. + + Each checkpointer provides its own context manager that patches the + necessary functions. Multiple checkpointers compose by each capturing + the current (possibly already-patched) functions. + """ + ... + + +class LockCheckpointer(Checkpointer): + """Checkpointer that pauses after flock acquisition.""" + + def __init__(self) -> None: + super().__init__() self._locks: dict[str, SimulatedLock] = {} self._fd_to_path: dict[int, str] = {} self._global_lock = threading.Lock() - # Checkpoints that tests can use to pause at specific points. - # Keyed by thread name. - self.after_acquire: dict[str, Checkpoint] = {} - - def register_fd(self, fd: int, path: str) -> None: + def _register_fd(self, fd: int, path: str) -> None: """Associate a file descriptor with a path.""" with self._global_lock: real_path = os.path.realpath(path) @@ -160,7 +187,7 @@ def register_fd(self, fd: int, path: str) -> None: if real_path not in self._locks: self._locks[real_path] = SimulatedLock() - def get_lock(self, fd: int) -> SimulatedLock | None: + def _get_lock(self, fd: int) -> SimulatedLock | None: """Get the lock for a file descriptor.""" with self._global_lock: path = self._fd_to_path.get(fd) @@ -169,76 +196,110 @@ def get_lock(self, fd: int) -> SimulatedLock | None: return None def get_lock_for_path(self, path: str) -> SimulatedLock | None: - """Get the lock for a path.""" + """Get the lock for a path (for test assertions).""" real_path = os.path.realpath(path) with self._global_lock: return self._locks.get(real_path) - def flock(self, fd: int, operation: int) -> None: - """Simulated flock that respects checkpoints.""" - lock = self.get_lock(fd) - if lock is None: - return + @contextmanager + def install(self) -> Generator[None, None, None]: + """Patch open to register fds, and patch flock with lock simulation.""" + original_open = builtins.open + checkpointer = self - thread_name = threading.current_thread().name + def patched_open( + path: Any, mode: str = "r", *args: Any, **kwargs: Any + ) -> Any: + f = original_open(path, mode, *args, **kwargs) + try: + checkpointer._register_fd(f.fileno(), str(path)) + except (OSError, ValueError): + pass + return f - if operation == fcntl.LOCK_UN: - lock.release() + def patched_flock(fd: int, operation: int) -> None: + lock = checkpointer._get_lock(fd) + if lock is None: + return - elif operation == fcntl.LOCK_SH: - lock.acquire_shared() - checkpoint = self.after_acquire.get(thread_name) - if checkpoint: - checkpoint.arrive_and_wait() + thread_name = threading.current_thread().name - elif operation == fcntl.LOCK_EX: - lock.acquire_exclusive() - checkpoint = self.after_acquire.get(thread_name) - if checkpoint: - checkpoint.arrive_and_wait() + if operation == fcntl.LOCK_UN: + lock.release() + elif operation == fcntl.LOCK_SH: + lock.acquire_shared() + checkpoint = checkpointer.get(thread_name) + if checkpoint: + checkpoint.arrive_and_wait() + elif operation == fcntl.LOCK_EX: + lock.acquire_exclusive() + checkpoint = checkpointer.get(thread_name) + if checkpoint: + checkpoint.arrive_and_wait() + + with ( + patch("builtins.open", patched_open), + patch("toil.server.utils.fcntl.flock", patched_flock), + ): + yield -class FileOperationTracker: - """ - Tracks and controls file operations with checkpoints. +class ReadCheckpointer(Checkpointer): + """Checkpointer that pauses during file read.""" - Wraps file objects to inject synchronization points during read/write. - """ + @contextmanager + def install(self) -> Generator[None, None, None]: + """Patch open to wrap read operations with checkpoint hooks.""" + original_open = builtins.open + checkpointer = self - def __init__(self, lock_manager: LockManager) -> None: - self.lock_manager = lock_manager - # Checkpoints keyed by thread name - self.during_write: dict[str, Checkpoint] = {} - self.during_read: dict[str, Checkpoint] = {} + def patched_open( + path: Any, mode: str = "r", *args: Any, **kwargs: Any + ) -> Any: + f = original_open(path, mode, *args, **kwargs) + original_read = f.read - def wrap_file(self, file_obj: Any, path: str) -> Any: - """Wrap a file object to track operations.""" - try: - self.lock_manager.register_fd(file_obj.fileno(), path) - except (OSError, ValueError): - pass + def tracked_read(size: int = -1) -> str: + thread_name = threading.current_thread().name + checkpoint = checkpointer.get(thread_name) + if checkpoint: + checkpoint.arrive_and_wait() + return original_read(size) - original_write = file_obj.write - original_read = file_obj.read - tracker = self + f.read = tracked_read + return f - def tracked_write(data: str) -> int: - thread_name = threading.current_thread().name - checkpoint = tracker.during_write.get(thread_name) - if checkpoint: - checkpoint.arrive_and_wait() - return original_write(data) + with patch("builtins.open", patched_open): + yield - def tracked_read(size: int = -1) -> str: - thread_name = threading.current_thread().name - checkpoint = tracker.during_read.get(thread_name) - if checkpoint: - checkpoint.arrive_and_wait() - return original_read(size) - file_obj.write = tracked_write - file_obj.read = tracked_read - return file_obj +class WriteCheckpointer(Checkpointer): + """Checkpointer that pauses during file write.""" + + @contextmanager + def install(self) -> Generator[None, None, None]: + """Patch open to wrap write operations with checkpoint hooks.""" + original_open = builtins.open + checkpointer = self + + def patched_open( + path: Any, mode: str = "r", *args: Any, **kwargs: Any + ) -> Any: + f = original_open(path, mode, *args, **kwargs) + original_write = f.write + + def tracked_write(data: str) -> int: + thread_name = threading.current_thread().name + checkpoint = checkpointer.get(thread_name) + if checkpoint: + checkpoint.arrive_and_wait() + return original_write(data) + + f.write = tracked_write + return f + + with patch("builtins.open", patched_open): + yield # TODO: Add tests for AtomicFileCreate path (concurrent new file creation). @@ -256,30 +317,25 @@ class TestSafeFileInterleaving: """ @pytest.fixture(autouse=True) - def setup_managers(self, tmp_path: Path) -> Generator[None]: - """Set up lock manager and file tracker for each test.""" + def setup_test_file(self, tmp_path: Path) -> Generator[None]: + """Set up test file path for each test.""" self.test_file = tmp_path / "test_file" - self.lock_manager = LockManager() - self.file_tracker = FileOperationTracker(self.lock_manager) yield - def _create_patches(self) -> tuple[Any, Any]: - """Create patch contexts for flock and open.""" - original_open = open - lock_manager = self.lock_manager - file_tracker = self.file_tracker - - def patched_open(path: Any, mode: str = "r", *args: Any, **kwargs: Any) -> Any: - f = original_open(path, mode, *args, **kwargs) - return file_tracker.wrap_file(f, str(path)) - - def patched_flock(fd: int, operation: int) -> None: - lock_manager.flock(fd, operation) + @contextmanager + def patched_io( + self, *checkpointers: Checkpointer + ) -> Generator[None, None, None]: + """ + Context manager that installs all checkpointer patches. - return ( - patch("builtins.open", patched_open), - patch("toil.server.utils.fcntl.flock", patched_flock), - ) + Uses ExitStack to compose the context managers from each checkpointer. + Patches compose naturally since each captures the current open. + """ + with ExitStack() as stack: + for checkpointer in checkpointers: + stack.enter_context(checkpointer.install()) + yield def test_reader_blocked_while_writer_holds_lock(self) -> None: """ @@ -298,9 +354,9 @@ def test_reader_blocked_while_writer_holds_lock(self) -> None: self.test_file.write_text("original") - # Checkpoint to pause writer after acquiring lock writer_checkpoint = Checkpoint() - self.lock_manager.after_acquire["writer"] = writer_checkpoint + locker = LockCheckpointer() + locker.add("writer", writer_checkpoint) reader_completed = threading.Event() results: dict[str, Any] = {"writer_done": False, "reader_result": None} @@ -320,8 +376,7 @@ def reader() -> None: except Exception as e: errors.append(e) - patches = self._create_patches() - with patches[0], patches[1]: + with self.patched_io(locker): t_writer = threading.Thread(target=writer, name="writer") t_writer.start() @@ -329,7 +384,7 @@ def reader() -> None: assert writer_checkpoint.wait_for_arrival(), "Writer didn't reach checkpoint" # Verify writer has the lock - lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + lock = locker.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.has_exclusive, "Writer should hold exclusive lock" @@ -364,7 +419,8 @@ def test_writer_blocked_while_reader_holds_lock(self) -> None: self.test_file.write_text("original") reader_checkpoint = Checkpoint() - self.lock_manager.after_acquire["reader"] = reader_checkpoint + locker = LockCheckpointer() + locker.add("reader", reader_checkpoint) writer_completed = threading.Event() results: dict[str, Any] = {"reader_result": None, "writer_done": False} @@ -384,14 +440,13 @@ def writer() -> None: except Exception as e: errors.append(e) - patches = self._create_patches() - with patches[0], patches[1]: + with self.patched_io(locker): t_reader = threading.Thread(target=reader, name="reader") t_reader.start() assert reader_checkpoint.wait_for_arrival(), "Reader didn't reach checkpoint" - lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + lock = locker.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.shared_count == 1, "Reader should hold shared lock" @@ -423,8 +478,9 @@ def test_multiple_readers_not_blocked(self) -> None: reader1_checkpoint = Checkpoint() reader2_checkpoint = Checkpoint() - self.lock_manager.after_acquire["reader1"] = reader1_checkpoint - self.lock_manager.after_acquire["reader2"] = reader2_checkpoint + locker = LockCheckpointer() + locker.add("reader1", reader1_checkpoint) + locker.add("reader2", reader2_checkpoint) results: dict[str, str | None] = {"reader1": None, "reader2": None} errors: list[Exception] = [] @@ -441,14 +497,13 @@ def reader2() -> None: except Exception as e: errors.append(e) - patches = self._create_patches() - with patches[0], patches[1]: + with self.patched_io(locker): t_reader1 = threading.Thread(target=reader1, name="reader1") t_reader1.start() assert reader1_checkpoint.wait_for_arrival(), "Reader1 didn't reach checkpoint" - lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + lock = locker.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.shared_count == 1 @@ -485,7 +540,8 @@ def test_writers_serialize(self) -> None: self.test_file.write_text("original") writer1_checkpoint = Checkpoint() - self.lock_manager.after_acquire["writer1"] = writer1_checkpoint + locker = LockCheckpointer() + locker.add("writer1", writer1_checkpoint) writer2_completed = threading.Event() results: dict[str, bool] = {"writer1_done": False, "writer2_done": False} @@ -506,14 +562,13 @@ def writer2() -> None: except Exception as e: errors.append(e) - patches = self._create_patches() - with patches[0], patches[1]: + with self.patched_io(locker): t_writer1 = threading.Thread(target=writer1, name="writer1") t_writer1.start() assert writer1_checkpoint.wait_for_arrival(), "Writer1 didn't reach checkpoint" - lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + lock = locker.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.has_exclusive, "Writer1 should hold exclusive lock" @@ -536,60 +591,66 @@ def writer2() -> None: assert results["writer1_done"] assert results["writer2_done"] - def test_reader_never_sees_partial_write(self) -> None: + def test_reader_paused_mid_read_blocks_writer(self) -> None: """ - Verify that a reader cannot see partial write content. - Reader either sees old content or complete new content. + Verify that a writer is blocked even when reader is paused during + the actual read operation (not just after lock acquisition). """ from toil.server.utils import safe_read_file, safe_write_file - self.test_file.write_text("AAAA") + self.test_file.write_text("original") - writer_checkpoint = Checkpoint() - self.lock_manager.after_acquire["writer"] = writer_checkpoint + # Pause reader during the read operation itself + read_checkpoint = Checkpoint() + read_hooker = ReadCheckpointer() + read_hooker.add("reader", read_checkpoint) + + # Also need lock checkpointer for fd registration and lock simulation + locker = LockCheckpointer() + writer_completed = threading.Event() results: dict[str, Any] = {"reader_result": None} - reader_completed = threading.Event() errors: list[Exception] = [] - def writer() -> None: + def reader() -> None: try: - safe_write_file(str(self.test_file), "BBBB") + results["reader_result"] = safe_read_file(str(self.test_file)) except Exception as e: errors.append(e) - def reader() -> None: + def writer() -> None: try: - results["reader_result"] = safe_read_file(str(self.test_file)) - reader_completed.set() + safe_write_file(str(self.test_file), "updated") + writer_completed.set() except Exception as e: errors.append(e) - patches = self._create_patches() - with patches[0], patches[1]: - t_writer = threading.Thread(target=writer, name="writer") - t_writer.start() - - assert writer_checkpoint.wait_for_arrival(), "Writer didn't reach checkpoint" - + with self.patched_io(locker, read_hooker): t_reader = threading.Thread(target=reader, name="reader") t_reader.start() - # Reader should be blocked - assert not reader_completed.wait(timeout=0.1), ( - "Reader should be blocked while writer holds lock" + assert read_checkpoint.wait_for_arrival(), "Reader didn't reach read checkpoint" + + # Reader should hold shared lock while paused mid-read + lock = locker.get_lock_for_path(str(self.test_file)) + assert lock is not None + assert lock.shared_count == 1, "Reader should hold lock during read" + + t_writer = threading.Thread(target=writer, name="writer") + t_writer.start() + + # Writer should be blocked + assert not writer_completed.wait(timeout=0.1), ( + "Writer should be blocked while reader is mid-read" ) - assert results["reader_result"] is None - writer_checkpoint.release() + read_checkpoint.release() - t_writer.join(timeout=SYNC_TIMEOUT) t_reader.join(timeout=SYNC_TIMEOUT) + t_writer.join(timeout=SYNC_TIMEOUT) assert errors == [] - assert results["reader_result"] == "BBBB", ( - "Reader should see complete write, not partial" - ) + assert results["reader_result"] == "original" def test_writer_paused_mid_write_blocks_reader(self) -> None: """ @@ -602,7 +663,11 @@ def test_writer_paused_mid_write_blocks_reader(self) -> None: # Pause writer during the write operation itself write_checkpoint = Checkpoint() - self.file_tracker.during_write["writer"] = write_checkpoint + write_hooker = WriteCheckpointer() + write_hooker.add("writer", write_checkpoint) + + # Also need lock checkpointer for fd registration and lock simulation + locker = LockCheckpointer() reader_completed = threading.Event() results: dict[str, Any] = {"reader_result": None} @@ -621,15 +686,14 @@ def reader() -> None: except Exception as e: errors.append(e) - patches = self._create_patches() - with patches[0], patches[1]: + with self.patched_io(locker, write_hooker): t_writer = threading.Thread(target=writer, name="writer") t_writer.start() assert write_checkpoint.wait_for_arrival(), "Writer didn't reach write checkpoint" # Writer should hold exclusive lock while paused mid-write - lock = self.lock_manager.get_lock_for_path(str(self.test_file)) + lock = locker.get_lock_for_path(str(self.test_file)) assert lock is not None assert lock.has_exclusive, "Writer should hold lock during write"