From 9e871ebfd62dde6ad3dadd44dd1bda1c29ad912c Mon Sep 17 00:00:00 2001 From: AhmedGoudaa Date: Tue, 23 Dec 2025 16:34:13 +0400 Subject: [PATCH 1/4] asyncio support --- CHANGELOG.md | 14 + README.md | 22 +- docs/benchmarking-and-profiling.md | 61 +- docs/benchmarking.md | 260 ------- pyproject.toml | 5 +- src/advanced_caching/__init__.py | 13 +- src/advanced_caching/_decorator_common.py | 66 ++ src/advanced_caching/_schedulers.py | 74 ++ src/advanced_caching/decorators.py | 827 ++++++++++------------ src/advanced_caching/storage.py | 3 +- tests/benchmark.py | 308 +++----- tests/compare_benchmarks.py | 248 ------- tests/test_correctness.py | 424 +++++------ tests/test_integration_redis.py | 163 +++-- tests/test_sync_support.py | 63 ++ uv.lock | 27 +- 16 files changed, 1022 insertions(+), 1556 deletions(-) delete mode 100644 docs/benchmarking.md create mode 100644 src/advanced_caching/_decorator_common.py create mode 100644 src/advanced_caching/_schedulers.py delete mode 100644 tests/compare_benchmarks.py create mode 100644 tests/test_sync_support.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ff85e63..f4388c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,20 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.2.0] - 2025-12-23 + +### Changed +- **Major Architecture Overhaul**: The library is now fully async-native. + - `TTLCache`, `SWRCache`, and `BGCache` now support `async def` functions natively using `await`. + - Synchronous functions are still supported via intelligent inspection, maintaining backward compatibility. +- **Unified Scheduling**: `SWRCache` (in sync mode) and `BGCache` now use `APScheduler` (`SharedScheduler` and `SharedAsyncScheduler`) for all background tasks, replacing ad-hoc threading. +- **Testing**: Integration tests rewritten to use `pytest-asyncio` with `mode="auto"`. + +### Added +- `AsyncTTLCache`, `AsyncStaleWhileRevalidateCache`, `AsyncBackgroundCache` classes (aliased to `TTLCache`, `SWRCache`, `BGCache`). +- `SharedAsyncScheduler` for managing async background jobs. +- `pytest-asyncio` configuration in `pyproject.toml`. 
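+
+All `Async*` classes and their short aliases name the same objects, so either
+import style works; a quick sanity check (illustrative):
+
+```python
+from advanced_caching import (
+    TTLCache, AsyncTTLCache,
+    SWRCache, AsyncStaleWhileRevalidateCache,
+    BGCache, AsyncBackgroundCache,
+)
+
+# Aliases, not subclasses: the short names are bound to the Async* classes.
+assert TTLCache is AsyncTTLCache
+assert SWRCache is AsyncStaleWhileRevalidateCache
+assert BGCache is AsyncBackgroundCache
+```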
+ ## [0.1.6] - 2025-12-15 ### Changed diff --git a/README.md b/README.md index a92c51b..726d8b9 100644 --- a/README.md +++ b/README.md @@ -44,23 +44,35 @@ uv pip install "advanced-caching[redis]" # Redis support ```python from advanced_caching import TTLCache, SWRCache, BGCache +# Sync function @TTLCache.cached("user:{}", ttl=300) def get_user(user_id: int) -> dict: return db.fetch(user_id) +# Async function (works natively) +@TTLCache.cached("user:{}", ttl=300) +async def get_user_async(user_id: int) -> dict: + return await db.fetch(user_id) + +# Stale-While-Revalidate (Sync) @SWRCache.cached("product:{}", ttl=60, stale_ttl=30) def get_product(product_id: int) -> dict: return api.fetch_product(product_id) -# Background refresh +# Stale-While-Revalidate (Async) +@SWRCache.cached("async:product:{}", ttl=60, stale_ttl=30) +async def get_product_async(product_id: int) -> dict: + return await api.fetch_product(product_id) + +# Background refresh (Sync) @BGCache.register_loader("inventory", interval_seconds=300) def load_inventory() -> list[dict]: return warehouse_api.get_all_items() -# Async works too -@TTLCache.cached("user:{}", ttl=300) -async def get_user_async(user_id: int) -> dict: - return await db.fetch(user_id) +# Background refresh (Async) +@BGCache.register_loader("inventory_async", interval_seconds=300) +async def load_inventory_async() -> list[dict]: + return await warehouse_api.get_all_items() ``` --- diff --git a/docs/benchmarking-and-profiling.md b/docs/benchmarking-and-profiling.md index fe58e2a..5e14f6e 100644 --- a/docs/benchmarking-and-profiling.md +++ b/docs/benchmarking-and-profiling.md @@ -2,10 +2,9 @@ This repo includes a small, reproducible benchmark harness and a profiler-friendly workload script. -- Benchmark runner: `tests/benchmark.py` +- Benchmark suite: `tests/benchmark.py` - Profiler workload: `tests/profile_decorators.py` - Benchmark log (append-only JSON-lines): `benchmarks.log` -- Run comparison helper: `tests/compare_benchmarks.py` ## 1) Benchmarking (step-by-step) @@ -17,14 +16,14 @@ This repo uses `uv`. From the repo root: uv sync ``` -### Step 1 — Run the default benchmark +### Step 1 — Run the benchmark suite ```bash uv run python tests/benchmark.py ``` What you get: -- A printed table for **cold** (always miss), **hot** (always hit), and **mixed** (hits + misses). +- Printed tables for **hot cache hits** (comparing TTLCache, SWRCache, BGCache). - A new JSON entry appended to `benchmarks.log` with the config + median/mean/stdev per strategy. ### Step 2 — Tune benchmark parameters (optional) @@ -35,60 +34,42 @@ What you get: - `BENCH_WORK_MS` (default `5.0`) — simulated I/O latency (sleep) - `BENCH_WARMUP` (default `10`) - `BENCH_RUNS` (default `300`) -- `BENCH_MIXED_KEY_SPACE` (default `100`) -- `BENCH_MIXED_RUNS` (default `500`) Examples: ```bash -BENCH_RUNS=1000 BENCH_MIXED_RUNS=2000 uv run python tests/benchmark.py -``` - -```bash -# Focus on decorator overhead (no artificial sleep) -BENCH_WORK_MS=0 BENCH_RUNS=200000 BENCH_MIXED_RUNS=300000 uv run python tests/benchmark.py +BENCH_RUNS=1000 uv run python tests/benchmark.py ``` ### Step 3 — Compare two runs -There are two ways to select runs: - -- Relative: `last` / `last-N` -- Explicit: integer indices (0-based; negatives allowed) - -List run indices quickly: +The benchmark appends JSON lines to `benchmarks.log`. 
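+Each run is a single JSON object per line. A sketch of the shape as a parsed
+Python dict (values invented for illustration; `append_json_log` in
+`tests/benchmark.py` is the authoritative source for the exact fields):
+
+```python
+run = {
+    "ts": "2025-12-23T16:30:00",
+    "status": "ok",
+    "error": None,
+    "config": {"seed": 12345, "work_ms": 5.0, "warmup": 10, "runs": 300},
+    "sections": {
+        "hot_hits": [
+            {"label": "TTLCache", "notes": "hot hit", "runs": 300,
+             "median_ms": 0.0012, "mean_ms": 0.0013, "stdev_ms": 0.0002},
+        ]
+    },
+}
+```
+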
A quick helper to list runs: ```bash uv run python - <<'PY' import json from pathlib import Path runs=[] +if not Path('benchmarks.log').exists(): + print("No benchmarks.log found") + exit(0) for line in Path('benchmarks.log').read_text(encoding='utf-8', errors='replace').splitlines(): - line=line.strip() - if not line.startswith('{'): - continue - try: - obj=json.loads(line) - except Exception: - continue - if isinstance(obj,dict) and 'results' in obj: - runs.append(obj) + line=line.strip() + if not line.startswith('{'): + continue + try: + obj=json.loads(line) + except Exception: + continue + if isinstance(obj,dict) and 'sections' in obj: + runs.append(obj) print('count',len(runs)) for i,r in enumerate(runs): - print(i,r.get('ts')) + print(i,r.get('ts')) PY ``` -Compare (example: index 2 vs index 11): - -```bash -uv run python tests/compare_benchmarks.py --a 2 --b 11 -``` - -What to look at: -- **Hot TTL/SWR** medians: these are the pure “cache-hit overhead” numbers. -- **Mixed** medians: reflect a real-ish distribution; watch for regressions here. -- Ignore small (<5–10%) deltas unless they repeat across multiple clean runs. +To compare two indices (e.g., 2 vs 11), load the JSON objects in a notebook or script and diff the `sections` (hot medians for TTL/SWR/BG are the most sensitive to overhead changes). ### Step 4 — Make results stable (recommended practice) @@ -163,6 +144,10 @@ PROFILE_N=5000000 \ - `SWRCache` hot: overhead of key generation + `get_entry()` + freshness checks. - `BGCache` hot: overhead of key lookup + `get()` + return. +- **Async results (important)** + - Async medians include the cost of creating/awaiting a coroutine and event-loop scheduling. + - For AsyncBG/AsyncSWR, compare against the `async_baseline` row (plain `await` with no cache) to estimate *cache-specific* overhead. + - **Mixed path** - A high mean + low median typically indicates occasional slow misses/refreshes. diff --git a/docs/benchmarking.md b/docs/benchmarking.md deleted file mode 100644 index d6c5f7f..0000000 --- a/docs/benchmarking.md +++ /dev/null @@ -1,260 +0,0 @@ -# Benchmarking Guide - -This guide explains how to run benchmarks and compare performance between different versions of advanced-caching. - -## Running Benchmarks - -### Full Benchmark Suite - -Run all benchmarks and save results to `benchmarks.log`: - -```bash -uv run python tests/benchmark.py -``` - -The script will run multiple benchmark scenarios: -- **Cold Cache**: Initial cache misses with storage overhead -- **Hot Cache**: Repeated cache hits (best-case performance) -- **Varying Keys**: Realistic mixed workload with 100+ different keys - -Results are saved in JSON format to `benchmarks.log` for later comparison. 
- -## Comparing Benchmark Results - -### View Baseline vs Current Performance - -Compare the last two benchmark runs: - -```bash -uv run python tests/compare_benchmarks.py -``` - -### Compare Specific Runs - -Compare specific runs using selectors: - -```bash -# Compare second-to-last run vs latest -uv run python tests/compare_benchmarks.py --a last-1 --b last - -# Compare run index 0 vs run index 2 -uv run python tests/compare_benchmarks.py --a 0 --b 2 - -# Compare run at index -3 (third from last) vs latest -uv run python tests/compare_benchmarks.py --a -3 --b last -``` - -### Custom Log File - -If benchmarks are saved to a different file: - -```bash -uv run python tests/compare_benchmarks.py --log my_benchmarks.log -``` - -## Understanding the Output - -### Example Comparison Report - -``` -==================================================================================================== -BENCHMARK COMPARISON REPORT -==================================================================================================== - -Run A (baseline): 2025-12-12T10:30:00 - Config: {'runs': 1000, 'warmup': 100} - -Run B (current): 2025-12-12T10:45:30 - Config: {'runs': 1000, 'warmup': 100} - ----------------------------------------------------------------------------------------------------- - -📊 COLD CACHE ----------------------------------------------------------------------------------------------------- -Strategy A (ms) B (ms) Change % Status ----------------------------------------------------------------------------------------------------- -TTLCache 15.3250 14.8900 -0.4350 -2.84% ✓ FASTER (2.84%) -SWRCache 18.5100 18.2300 -0.2800 -1.51% ✓ SAME -No Cache (baseline) 13.1000 13.0900 -0.0100 -0.08% ✓ SAME - -📊 HOT CACHE ----------------------------------------------------------------------------------------------------- -Strategy A (ms) B (ms) Change % Status ----------------------------------------------------------------------------------------------------- -TTLCache 0.0012 0.0011 -0.0001 -8.33% ✓ FASTER (8.33%) -SWRCache 0.0014 0.0015 +0.0001 7.14% ✗ SLOWER (7.14%) -BGCache 0.0003 0.0003 +0.0000 0.00% ✓ SAME - -==================================================================================================== -SUMMARY -==================================================================================================== - -✓ 3 IMPROVEMENT(S): - • TTLCache in cold cache → 2.84% faster - • TTLCache in hot cache → 8.33% faster - • SWRCache in cold cache → 1.51% faster - -✗ 1 REGRESSION(S): - • SWRCache in hot cache → 7.14% slower - ----------------------------------------------------------------------------------------------------- - -🎯 VERDICT: ✓ OVERALL IMPROVEMENT (avg +5.89% faster) - -==================================================================================================== -``` - -### Key Sections Explained - -#### 1. Header -- **Run A (baseline)**: Previous benchmark results with timestamp -- **Run B (current)**: Latest benchmark results for comparison -- Shows configuration parameters used for both runs - -#### 2. 
Per-Section Comparison -Groups results by benchmark scenario: -- **Strategy**: Name of the caching approach (e.g., TTLCache, SWRCache) -- **A (ms)**: Median time in baseline run (milliseconds) -- **B (ms)**: Median time in current run -- **Change**: Absolute difference (B - A) in milliseconds -- **%**: Percentage change relative to baseline -- **Status**: Visual indicator with emoji: - - ✓ FASTER: Performance improved - - ✓ SAME: Within 2% threshold (no significant change) - - ✗ SLOWER: Performance regressed - -#### 3. Summary -- **IMPROVEMENTS**: Strategies that got faster - - Shows top 5 improvements - - Sorted by percentage gain (largest first) -- **REGRESSIONS**: Strategies that got slower - - Shows top 5 regressions - - Sorted by percentage loss (largest first) - -#### 4. Verdict -Overall performance assessment: -- **STABLE**: No significant changes (< 2% difference) -- **IMPROVEMENT**: More improvements than regressions - - Shows average speedup percentage -- **REGRESSION**: More regressions than improvements - - Shows average slowdown percentage - -## Interpreting Results - -### Performance Thresholds - -- **< 2%**: Considered **same** (normal measurement noise) -- **2-5%**: **Notable** change (worth investigating) -- **> 5%**: **Significant** change (code optimization or regression) - -### What to Look For - -✓ **Good signs**: -- Hot cache times remain stable (< 5% change) -- Cold cache shows improvements (refactoring benefits) -- No regressions in any benchmark - -✗ **Warning signs**: -- Hot cache performance degrades (potential code path regression) -- New dependencies add overhead to all runs -- Asymmetric changes (cache hits slow, misses fast) - -## Benchmark Scenarios - -### 1. Cold Cache -**What it tests**: Cache miss handling and data storage overhead - -Measures: -- Function execution time -- Cache miss detection -- Storage backend write performance -- Different cache backends side-by-side - -Use this to: -- Verify caching decorators don't add significant overhead -- Detect regression in cache backends -- Compare storage implementation performance - -### 2. Hot Cache -**What it tests**: Pure cache hit speed (best case) - -Measures: -- Cache lookup time -- Data deserialization -- Decorator wrapper overhead -- Hit on same key repeated 1000+ times - -Use this to: -- Ensure caching is providing speed benefit -- Detect memory/performance issues -- Compare backend performance under load - -### 3. Varying Keys -**What it tests**: Mixed realistic workload - -Measures: -- Performance with 100+ unique keys -- Mix of hits and misses -- Cache eviction/aging behavior -- Real-world usage patterns - -Use this to: -- Understand performance with realistic data -- Detect memory issues under load -- Test cache aging and refresh behavior - -## Workflow Tips - -### Before Making Code Changes - -Save baseline benchmarks: - -```bash -uv run python tests/benchmark.py -git add benchmarks.log -git commit -m "baseline: benchmark before optimization" -``` - -### After Code Changes - -Run new benchmarks: - -```bash -uv run python tests/benchmark.py -uv run python tests/compare_benchmarks.py -``` - -Review the comparison report and decide: -- ✓ Changes are good → commit -- ✗ Regression detected → revert or optimize further - -### Benchmarking in CI/CD - -GitHub Actions runs benchmarks on every push to `main`: - -```yaml -- name: Run benchmarks - run: uv run python tests/benchmark.py -``` - -Results are stored as artifacts for later comparison. 
- -## Troubleshooting - -### "Need at least 2 JSON runs" - -You need at least 2 benchmark results to compare. Run benchmarks twice: - -```bash -uv run python tests/benchmark.py # Creates first run -uv run python tests/benchmark.py # Creates second run -uv run python tests/compare_benchmarks.py # Now you can compare -``` - -### "Unsupported selector" - -Valid selectors are: -- `last` - most recent run -- `last-N` - N runs ago (e.g., `last-1`, `last-5`) -- `0`, `1`, `2` - absolute index from start -- `-1`, `-2` - absolute index from end diff --git a/pyproject.toml b/pyproject.toml index 5c4ea20..0d82cd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "advanced-caching" -version = "0.1.6" +version = "0.2.0" description = "Production-ready composable caching with TTL, SWR, and background refresh patterns for Python." readme = "README.md" requires-python = ">=3.10" @@ -48,6 +48,7 @@ Issues = "https://github.com/agkloop/advanced_caching/issues" [dependency-groups] dev = [ "pytest>=8.2", + "pytest-asyncio>=1.3.0", "pytest-cov>=4.0", "ruff>=0.14.8", "scalene>=1.5.55", @@ -58,3 +59,5 @@ dev = [ testpaths = ["tests"] python_files = ["test_*.py"] addopts = "-v" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" diff --git a/src/advanced_caching/__init__.py b/src/advanced_caching/__init__.py index aaacbba..4382b27 100644 --- a/src/advanced_caching/__init__.py +++ b/src/advanced_caching/__init__.py @@ -4,7 +4,7 @@ Expose storage backends, decorators, and scheduler utilities under `advanced_caching`. """ -__version__ = "0.1.6" +__version__ = "0.2.0" from .storage import ( InMemCache, @@ -18,14 +18,14 @@ ) from .decorators import ( TTLCache, + AsyncTTLCache, SWRCache, - StaleWhileRevalidateCache, - BackgroundCache, + AsyncStaleWhileRevalidateCache, BGCache, + AsyncBackgroundCache, ) __all__ = [ - "__version__", "InMemCache", "RedisCache", "HybridCache", @@ -35,8 +35,9 @@ "PickleSerializer", "JsonSerializer", "TTLCache", + "AsyncTTLCache", "SWRCache", - "StaleWhileRevalidateCache", - "BackgroundCache", + "AsyncStaleWhileRevalidateCache", "BGCache", + "AsyncBackgroundCache", ] diff --git a/src/advanced_caching/_decorator_common.py b/src/advanced_caching/_decorator_common.py new file mode 100644 index 0000000..47fc2a4 --- /dev/null +++ b/src/advanced_caching/_decorator_common.py @@ -0,0 +1,66 @@ +"""Internal helpers shared by caching decorators. + +This module is intentionally *not* part of the public API. + +Goals: +- Eliminate repeated cache-backend normalization patterns. +- Keep decorator hot paths small by binding frequently-used attributes once. +- Centralize wrapper metadata used by tests/debugging (`__wrapped__`, `_cache`, etc.). +""" + +from __future__ import annotations +from typing import Callable, TypeVar + +from .storage import CacheStorage, InMemCache + +T = TypeVar("T") + + +def normalize_cache_factory( + cache: CacheStorage | Callable[[], CacheStorage] | None, + *, + default_factory: Callable[[], CacheStorage] = InMemCache, +) -> Callable[[], CacheStorage]: + """Normalize a cache backend parameter into a no-arg factory. + + Accepted forms: + - None: use default_factory + - Callable[[], CacheStorage]: use as-is + - CacheStorage instance: wrap into a factory that returns the instance + + This keeps decorator code paths small and consistent. 
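+
+    Example (sketch; ``my_cache`` stands for any ``CacheStorage`` instance):
+
+        normalize_cache_factory(None)()          # a fresh InMemCache
+        normalize_cache_factory(InMemCache)()    # a fresh InMemCache
+        normalize_cache_factory(my_cache)()      # my_cache itself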
+ """ + + if cache is None: + return default_factory + if callable(cache): + return cache # type: ignore[return-value] + + cache_instance = cache + + def factory() -> CacheStorage: + return cache_instance + + return factory + + +def attach_wrapper_metadata( + wrapper: Callable[..., T], + func: Callable[..., T], + *, + cache_obj: CacheStorage, + cache_key: str | None = None, +) -> None: + """Attach metadata fields used for debugging/tests. + + Notes: + - We intentionally avoid functools.wraps() here to keep decoration overhead + minimal and to preserve existing behavior. + """ + + wrapper.__wrapped__ = func # type: ignore[attr-defined] + wrapper.__name__ = func.__name__ # type: ignore[attr-defined] + wrapper.__doc__ = func.__doc__ # type: ignore[attr-defined] + wrapper._cache = cache_obj # type: ignore[attr-defined] + if cache_key is not None: + wrapper._cache_key = cache_key # type: ignore[attr-defined] diff --git a/src/advanced_caching/_schedulers.py b/src/advanced_caching/_schedulers.py new file mode 100644 index 0000000..873f08f --- /dev/null +++ b/src/advanced_caching/_schedulers.py @@ -0,0 +1,74 @@ +"""Internal APScheduler singletons used by BGCache decorators. + +These are internal to allow decorators.py to stay focused on caching semantics. +""" + +from __future__ import annotations + +import threading +from typing import ClassVar + +from apscheduler.schedulers.background import BackgroundScheduler +from apscheduler.schedulers.asyncio import AsyncIOScheduler + + +class SharedScheduler: + """Singleton `BackgroundScheduler` for sync BGCache jobs.""" + + _scheduler: ClassVar[BackgroundScheduler | None] = None + _lock: ClassVar[threading.RLock] = threading.RLock() + _started: ClassVar[bool] = False + + @classmethod + def get_scheduler(cls) -> BackgroundScheduler: + with cls._lock: + if cls._scheduler is None: + cls._scheduler = BackgroundScheduler(daemon=True) + assert cls._scheduler is not None + return cls._scheduler + + @classmethod + def start(cls) -> None: + with cls._lock: + if not cls._started: + cls.get_scheduler().start() + cls._started = True + + @classmethod + def shutdown(cls, wait: bool = True) -> None: + with cls._lock: + if cls._started and cls._scheduler is not None: + cls._scheduler.shutdown(wait=wait) + cls._started = False + cls._scheduler = None + + +class SharedAsyncScheduler: + """Singleton `AsyncIOScheduler` for AsyncBGCache jobs.""" + + _scheduler: ClassVar[AsyncIOScheduler | None] = None + _lock: ClassVar[threading.RLock] = threading.RLock() + _started: ClassVar[bool] = False + + @classmethod + def get_scheduler(cls) -> AsyncIOScheduler: + with cls._lock: + if cls._scheduler is None: + cls._scheduler = AsyncIOScheduler() + assert cls._scheduler is not None + return cls._scheduler + + @classmethod + def ensure_started(cls) -> None: + with cls._lock: + if not cls._started: + cls.get_scheduler().start() + cls._started = True + + @classmethod + def shutdown(cls, wait: bool = True) -> None: + with cls._lock: + if cls._started and cls._scheduler is not None: + cls._scheduler.shutdown(wait=wait) + cls._started = False + cls._scheduler = None diff --git a/src/advanced_caching/decorators.py b/src/advanced_caching/decorators.py index bb044e6..e5d6437 100644 --- a/src/advanced_caching/decorators.py +++ b/src/advanced_caching/decorators.py @@ -13,15 +13,16 @@ import atexit import logging import os -import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Callable, TypeVar, ClassVar +from datetime import datetime, timedelta +from 
typing import Callable, TypeVar -from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.triggers.interval import IntervalTrigger -from .storage import InMemCache, CacheEntry, CacheStorage +from ._decorator_common import attach_wrapper_metadata, normalize_cache_factory +from ._schedulers import SharedAsyncScheduler, SharedScheduler +from .storage import CacheEntry, CacheStorage, InMemCache T = TypeVar("T") @@ -29,28 +30,59 @@ logger = logging.getLogger(__name__) -_SWR_EXECUTOR: ThreadPoolExecutor | None = None -_SWR_EXECUTOR_LOCK = threading.Lock() +# Helper to normalize cache key builders for all decorators. +def _create_key_fn(key: str | Callable[..., str]) -> Callable[..., str]: + if callable(key): + return key # type: ignore[assignment] + template = key + if "{" not in template: -def _get_swr_executor() -> ThreadPoolExecutor: - global _SWR_EXECUTOR - if _SWR_EXECUTOR is None: - with _SWR_EXECUTOR_LOCK: - if _SWR_EXECUTOR is None: - max_workers = min(32, (os.cpu_count() or 1) * 4) - _SWR_EXECUTOR = ThreadPoolExecutor( - max_workers=max_workers, thread_name_prefix="advanced_caching_swr" - ) + def key_fn(*args, **kwargs) -> str: + return template + + return key_fn + + if ( + template.count("{}") == 1 + and template.count("{") == 1 + and template.count("}") == 1 + ): + prefix, suffix = template.split("{}", 1) + + def key_fn(*args, **kwargs) -> str: + if args: + return prefix + str(args[0]) + suffix + if kwargs: + if len(kwargs) == 1: + return prefix + str(next(iter(kwargs.values()))) + suffix + return template + return template + + return key_fn - def _shutdown() -> None: + def key_fn(*args, **kwargs) -> str: + if args: + try: + return template.format(args[0]) + except Exception: + try: + return template.format(*args) + except Exception: + return template + if kwargs: + try: + return template.format(**kwargs) + except Exception: + if len(kwargs) == 1: try: - _SWR_EXECUTOR.shutdown(wait=False, cancel_futures=True) # type: ignore[union-attr] + return template.format(next(iter(kwargs.values()))) except Exception: - pass + return template + return template + return template - atexit.register(_shutdown) - return _SWR_EXECUTOR + return key_fn # ============================================================================ @@ -58,10 +90,11 @@ def _shutdown() -> None: # ============================================================================ -class SimpleTTLCache: +class AsyncTTLCache: """ Simple TTL cache decorator (singleton pattern). Each decorated function gets its own cache instance. + Supports both sync and async functions (preserves sync/async nature). Key templates (high-performance, simple): - Positional placeholder: "user:{}" → first positional arg @@ -70,16 +103,8 @@ class SimpleTTLCache: Examples: @TTLCache.cached("user:{}", ttl=60) - def get_user(user_id): - return db.fetch_user(user_id) - - @TTLCache.cached("user:{user_id}", ttl=60) - def get_user(*, user_id): - return db.fetch_user(user_id) - - @TTLCache.cached(key=lambda *, lang="en": f"i18n:{lang}", ttl=60) - def load_i18n(lang: str = "en"): - ... 
+ async def get_user(user_id): + return await db.fetch_user(user_id) """ @classmethod @@ -96,121 +121,58 @@ def cached( key: Cache key template (e.g., "user:{}") or generator function ttl: Time-to-live in seconds cache: Optional cache backend (defaults to InMemCache) + """ + key_fn = _create_key_fn(key) + cache_factory = normalize_cache_factory(cache, default_factory=InMemCache) - Example: - @TTLCache.cached("user:{}", ttl=300) - def get_user(user_id): - return db.fetch_user(user_id) + def decorator(func: Callable[..., T]) -> Callable[..., T]: + cache_obj = cache_factory() + cache_get_entry = cache_obj.get_entry + cache_set = cache_obj.set + now_fn = time.time - # With key function - @TTLCache.cached(key=lambda x: f"calc:{x}", ttl=60) - def calculate(x): - return x * 2 - """ - # Each decorated function gets its own cache instance - cache_factory: Callable[[], CacheStorage] - if cache is None: - cache_factory = InMemCache - elif callable(cache): - cache_factory = cache # type: ignore[assignment] - else: - cache_instance = cache - - def cache_factory() -> CacheStorage: - return cache_instance - - function_cache: CacheStorage | None = None - cache_lock = threading.Lock() - - def get_cache() -> CacheStorage: - nonlocal function_cache - if function_cache is None: - with cache_lock: - if function_cache is None: - function_cache = cache_factory() - return function_cache - - # Precompute key builder to reduce per-call branching - if callable(key): - key_fn: Callable[..., str] = key # type: ignore[assignment] - else: - template = key - - # Fast path for common templates like "prefix:{}" (single positional placeholder). - if "{" not in template: - - def key_fn(*args, **kwargs) -> str: - return template + if asyncio.iscoroutinefunction(func): - elif ( - template.count("{}") == 1 - and template.count("{") == 1 - and template.count("}") == 1 - ): - prefix, suffix = template.split("{}", 1) - - def key_fn(*args, **kwargs) -> str: - if args: - return prefix + str(args[0]) + suffix - if kwargs: - if len(kwargs) == 1: - return prefix + str(next(iter(kwargs.values()))) + suffix - return template - return template + async def async_wrapper(*args, **kwargs) -> T: + if ttl <= 0: + return await func(*args, **kwargs) - else: + cache_key = key_fn(*args, **kwargs) + entry = cache_get_entry(cache_key) + if entry is not None: + if now_fn() < entry.fresh_until: + return entry.value - def key_fn(*args, **kwargs) -> str: - if args: - try: - return template.format(args[0]) - except Exception: - return template - if kwargs: - try: - return template.format(**kwargs) - except Exception: - # Attempt single-kwarg positional fallback - if len(kwargs) == 1: - try: - return template.format(next(iter(kwargs.values()))) - except Exception: - return template - return template - return template + result = await func(*args, **kwargs) + cache_set(cache_key, result, ttl) + return result - def decorator(func: Callable[..., T]) -> Callable[..., T]: - def wrapper(*args, **kwargs) -> T: - # If ttl is 0 or negative, disable caching and call through + attach_wrapper_metadata(async_wrapper, func, cache_obj=cache_obj) + return async_wrapper # type: ignore + + # Sync wrapper + def sync_wrapper(*args, **kwargs) -> T: if ttl <= 0: return func(*args, **kwargs) - cache_key = key_fn(*args, **kwargs) - - cache_obj = get_cache() - # Try cache first - cached_value = cache_obj.get(cache_key) - if cached_value is not None: - return cached_value + cache_key = key_fn(*args, **kwargs) + entry = cache_get_entry(cache_key) + if entry is not None: + if 
now_fn() < entry.fresh_until: + return entry.value - # Cache miss - call function result = func(*args, **kwargs) - cache_obj.set(cache_key, result, ttl) + cache_set(cache_key, result, ttl) return result - # Store cache reference for testing/debugging - wrapper.__wrapped__ = func # type: ignore - wrapper.__name__ = func.__name__ # type: ignore - wrapper.__doc__ = func.__doc__ # type: ignore - wrapper._cache = get_cache() # type: ignore - - return wrapper + attach_wrapper_metadata(sync_wrapper, func, cache_obj=cache_obj) + return sync_wrapper return decorator # Alias for easier import -TTLCache = SimpleTTLCache +TTLCache = AsyncTTLCache # ============================================================================ @@ -218,20 +180,10 @@ def wrapper(*args, **kwargs) -> T: # ============================================================================ -class StaleWhileRevalidateCache: +class AsyncStaleWhileRevalidateCache: """ - SWR cache with background refresh - composable with any cache backend. - Serves stale data while refreshing in background (non-blocking). - - Example: - @SWRCache.cached("product:{}", ttl=60, stale_ttl=30) - def get_product(product_id: int): - return db.fetch_product(product_id) - - # With Redis - @SWRCache.cached("product:{}", ttl=60, stale_ttl=30, cache=redis_cache) - def get_product(product_id: int): - return db.fetch_product(product_id) + SWR (Stale-While-Revalidate) cache decorator. + Supports both sync and async functions. """ @classmethod @@ -243,215 +195,159 @@ def cached( cache: CacheStorage | Callable[[], CacheStorage] | None = None, enable_lock: bool = True, ) -> Callable[[Callable[..., T]], Callable[..., T]]: - """ - SWR cache decorator. + key_fn = _create_key_fn(key) + cache_factory = normalize_cache_factory(cache, default_factory=InMemCache) - Args: - key: Cache key template or generator function. - ttl: Fresh TTL in seconds. - stale_ttl: Additional time to serve stale data while refreshing. - cache: Optional cache backend (InMemCache, RedisCache, etc.). - enable_lock: Whether to use locking to prevent thundering herd. - - Example: - @SWRCache.cached("user:{}", ttl=60, stale_ttl=30) - def get_user(user_id: int): - return db.query("SELECT * FROM users WHERE id = ?", user_id) - - # With Redis - @SWRCache.cached("user:{}", ttl=60, stale_ttl=30, cache=redis_cache) - def get_user(user_id: int): - return db.query("SELECT * FROM users WHERE id = ?", user_id) - """ - # Each decorated function gets its own cache instance - cache_factory: Callable[[], CacheStorage] - if cache is None: - cache_factory = InMemCache - elif callable(cache): - cache_factory = cache # type: ignore[assignment] - else: - cache_instance = cache - - def cache_factory() -> CacheStorage: - return cache_instance - - function_cache: CacheStorage | None = None - cache_lock = threading.Lock() - - def get_cache() -> CacheStorage: - nonlocal function_cache - if function_cache is None: - with cache_lock: - if function_cache is None: - function_cache = cache_factory() - return function_cache - - # Precompute key builder to reduce per-call branching - if callable(key): - key_fn: Callable[..., str] = key # type: ignore[assignment] - else: - template = key - - # Fast path for common templates like "prefix:{}" (single positional placeholder). 
- if "{" not in template: - - def key_fn(*args, **kwargs) -> str: - return template + def decorator(func: Callable[..., T]) -> Callable[..., T]: + cache_obj = cache_factory() + get_entry = cache_obj.get_entry + set_entry = cache_obj.set_entry + set_if_not_exists = cache_obj.set_if_not_exists + now_fn = time.time + + if asyncio.iscoroutinefunction(func): + create_task = asyncio.create_task + + async def async_wrapper(*args, **kwargs) -> T: + if ttl <= 0: + return await func(*args, **kwargs) + cache_key = key_fn(*args, **kwargs) + now = now_fn() + entry = get_entry(cache_key) + + if entry is None: + result = await func(*args, **kwargs) + created_at = now_fn() + set_entry( + cache_key, + CacheEntry( + value=result, + fresh_until=created_at + ttl, + created_at=created_at, + ), + ) + return result - elif ( - template.count("{}") == 1 - and template.count("{") == 1 - and template.count("}") == 1 - ): - prefix, suffix = template.split("{}", 1) - - def key_fn(*args, **kwargs) -> str: - if args: - return prefix + str(args[0]) + suffix - if kwargs: - if len(kwargs) == 1: - return prefix + str(next(iter(kwargs.values()))) + suffix - return template - return template + if now < entry.fresh_until: + return entry.value - else: + if (now - entry.created_at) > (ttl + stale_ttl): + result = await func(*args, **kwargs) + created_at = now_fn() + set_entry( + cache_key, + CacheEntry( + value=result, + fresh_until=created_at + ttl, + created_at=created_at, + ), + ) + return result - def key_fn(*args, **kwargs) -> str: - if args: - try: - return template.format(args[0]) - except Exception: - return template - if kwargs: + if enable_lock: + lock_key = f"{cache_key}:refresh_lock" + if not set_if_not_exists(lock_key, "1", stale_ttl or 10): + return entry.value + + async def refresh_job() -> None: try: - return template.format(**kwargs) + new_value = await func(*args, **kwargs) + refreshed_at = now_fn() + set_entry( + cache_key, + CacheEntry( + value=new_value, + fresh_until=refreshed_at + ttl, + created_at=refreshed_at, + ), + ) except Exception: - if len(kwargs) == 1: - try: - return template.format(next(iter(kwargs.values()))) - except Exception: - return template - return template - return template + logger.exception( + "Async SWR background refresh failed for key %r", + cache_key, + ) - def decorator(func: Callable[..., T]) -> Callable[..., T]: - def wrapper(*args, **kwargs) -> T: - # If ttl is 0 or negative, disable caching and SWR behavior + create_task(refresh_job()) + return entry.value + + attach_wrapper_metadata(async_wrapper, func, cache_obj=cache_obj) + return async_wrapper # type: ignore + + # Sync wrapper + def sync_wrapper(*args, **kwargs) -> T: if ttl <= 0: return func(*args, **kwargs) cache_key = key_fn(*args, **kwargs) - - cache_obj = get_cache() - now = time.time() - - # Try to get from cache - entry = cache_obj.get_entry(cache_key) + now = now_fn() + entry = get_entry(cache_key) if entry is None: - # Cache miss - fetch now result = func(*args, **kwargs) - cache_entry = CacheEntry( - value=result, fresh_until=now + ttl, created_at=now + created_at = now_fn() + set_entry( + cache_key, + CacheEntry( + value=result, + fresh_until=created_at + ttl, + created_at=created_at, + ), ) - cache_obj.set_entry(cache_key, cache_entry) return result if now < entry.fresh_until: return entry.value - age = now - entry.created_at - if age > (ttl + stale_ttl): - # Too stale, fetch now + if (now - entry.created_at) > (ttl + stale_ttl): result = func(*args, **kwargs) - cache_entry = CacheEntry( - value=result, 
fresh_until=now + ttl, created_at=now + created_at = now_fn() + set_entry( + cache_key, + CacheEntry( + value=result, + fresh_until=created_at + ttl, + created_at=created_at, + ), ) - cache_obj.set_entry(cache_key, cache_entry) return result - # Stale but within grace period - return stale and refresh in background - # Try to acquire refresh lock - lock_key = f"{cache_key}:refresh_lock" if enable_lock: - acquired = cache_obj.set_if_not_exists( - lock_key, "1", stale_ttl or 10 - ) - if not acquired: + lock_key = f"{cache_key}:refresh_lock" + if not set_if_not_exists(lock_key, "1", stale_ttl or 10): return entry.value - # Refresh in background thread - def refresh_job(): + def refresh_job() -> None: try: new_value = func(*args, **kwargs) - now = time.time() - cache_entry = CacheEntry( - value=new_value, fresh_until=now + ttl, created_at=now + refreshed_at = now_fn() + set_entry( + cache_key, + CacheEntry( + value=new_value, + fresh_until=refreshed_at + ttl, + created_at=refreshed_at, + ), ) - cache_obj.set_entry(cache_key, cache_entry) except Exception: - # Log background refresh failures but never raise logger.exception( - "SWR background refresh failed for key %r", cache_key + "Sync SWR background refresh failed for key %r", cache_key ) - # Use a shared executor to avoid per-refresh thread creation overhead. - _get_swr_executor().submit(refresh_job) - + # Run refresh in background using SharedScheduler + scheduler = SharedScheduler.get_scheduler() + SharedScheduler.start() + scheduler.add_job(refresh_job) return entry.value - wrapper.__wrapped__ = func # type: ignore - wrapper.__name__ = func.__name__ # type: ignore - wrapper.__doc__ = func.__doc__ # type: ignore - wrapper._cache = get_cache() # type: ignore - return wrapper + attach_wrapper_metadata(sync_wrapper, func, cache_obj=cache_obj) + return sync_wrapper return decorator -# Alias for shorter usage -SWRCache = StaleWhileRevalidateCache - - -# ============================================================================ -# Shared Scheduler - Singleton for all background jobs -# ============================================================================ - - -class _SharedScheduler: - """ - Shared BackgroundScheduler instance - singleton for all background jobs. - Ensures only one scheduler runs for all registered loaders. - """ - - _scheduler: ClassVar[BackgroundScheduler | None] = None - _lock: ClassVar[threading.RLock] = threading.RLock() - _started: ClassVar[bool] = False - - @classmethod - def get_scheduler(cls) -> BackgroundScheduler: - """Get or create the shared background scheduler instance.""" - with cls._lock: - if cls._scheduler is None: - cls._scheduler = BackgroundScheduler(daemon=True) - assert cls._scheduler is not None # Type narrowing for IDE - return cls._scheduler +SWRCache = AsyncStaleWhileRevalidateCache - @classmethod - def start(cls) -> None: - """Start the shared background scheduler.""" - with cls._lock: - if not cls._started: - cls.get_scheduler().start() - cls._started = True - @classmethod - def shutdown(cls, wait: bool = True) -> None: - """Stop the shared background scheduler.""" - with cls._lock: - if cls._started and cls._scheduler is not None: - cls._scheduler.shutdown(wait=wait) - cls._started = False - cls._scheduler = None +# Schedulers are implemented as internal singletons in `advanced_caching._schedulers`. 
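+
+# Usage sketch for the SWR semantics above (`quote_api` is a stand-in, not
+# part of this library):
+#
+#     @SWRCache.cached("price:{}", ttl=60, stale_ttl=30)
+#     async def get_price(symbol: str) -> float:
+#         return await quote_api.fetch(symbol)
+#
+# Within `ttl` seconds a hit returns the cached value directly. Between `ttl`
+# and `ttl + stale_ttl` the stale value is returned immediately and a single
+# background refresh runs (with `enable_lock=True`, the `:refresh_lock`
+# sentinel key suppresses a thundering herd). Past `ttl + stale_ttl` the call
+# blocks and refetches synchronously.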
# ============================================================================ @@ -459,38 +355,13 @@ def shutdown(cls, wait: bool = True) -> None: # ============================================================================ -class BackgroundCache: - """ - Background cache with BackgroundScheduler for periodic data loading. - All instances share ONE BackgroundScheduler, but each has its own cache storage. - Works with both sync and async functions. - - Args (public API, unified naming): - key (str): Unique cache key for the loader. - interval_seconds (int): Refresh interval. - ttl (int | None): TTL for cached value (defaults to 2 * interval_seconds). - - Example: - # Async function - @BGCache.register_loader(key="categories", interval_seconds=300) - async def load_categories(): - return await db.query("SELECT * FROM categories") - - # Sync function - @BGCache.register_loader(key="config", interval_seconds=300) - def load_config(): - return {"key": "value"} - - # With custom cache backend - @BGCache.register_loader(key="products", interval_seconds=300, cache=redis_cache) - def load_products(): - return fetch_products_from_db() - """ +class AsyncBackgroundCache: + """Background cache loader that uses APScheduler (AsyncIOScheduler for async, BackgroundScheduler for sync).""" @classmethod def shutdown(cls, wait: bool = True) -> None: - """Stop the shared BackgroundScheduler.""" - _SharedScheduler.shutdown(wait) + SharedAsyncScheduler.shutdown(wait) + SharedScheduler.shutdown(wait) @classmethod def register_loader( @@ -502,21 +373,7 @@ def register_loader( on_error: Callable[[Exception], None] | None = None, cache: CacheStorage | Callable[[], CacheStorage] | None = None, ) -> Callable[[Callable[[], T]], Callable[[], T]]: - """Register a background data loader. - - Args: - key: Unique cache key to store the loaded data. - interval_seconds: How often to refresh the data (in seconds). - ttl: Cache TTL (defaults to 2 * interval_seconds if None). - run_immediately: Whether to load data immediately on registration. - on_error: Optional error handler callback. - cache: Optional cache backend (InMemCache, RedisCache, etc.). - - Returns: - Decorated function that returns cached data (sync or async). - """ cache_key = key - # If interval_seconds <= 0 or ttl == 0, disable background scheduling and caching. 
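+        # Parameter normalization (behavior unchanged from 0.1.x): a
+        # non-positive interval disables scheduling entirely, and a missing
+        # ttl defaults to 2 * interval_seconds, the documented default.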
if interval_seconds <= 0: interval_seconds = 0 if ttl is None and interval_seconds > 0: @@ -524,164 +381,196 @@ def register_loader( if ttl is None: ttl = 0 - # Create a dedicated cache instance for this loader - cache_factory: Callable[[], CacheStorage] - if cache is None: - cache_factory = InMemCache - elif callable(cache): - cache_factory = cache # type: ignore[assignment] - else: - cache_instance = cache - - def cache_factory() -> CacheStorage: - return cache_instance - - loader_cache: CacheStorage | None = None - cache_init_lock = threading.Lock() - - def get_cache() -> CacheStorage: - nonlocal loader_cache - if loader_cache is None: - with cache_init_lock: - if loader_cache is None: - loader_cache = cache_factory() - return loader_cache + cache_factory = normalize_cache_factory(cache, default_factory=InMemCache) + cache_obj = cache_factory() + cache_get = cache_obj.get + cache_set = cache_obj.set def decorator(loader_func: Callable[[], T]) -> Callable[[], T]: - # Detect if function is async - is_async = asyncio.iscoroutinefunction(loader_func) - # Single-flight lock to avoid duplicate initial loads under concurrency - loader_lock = asyncio.Lock() if is_async else threading.Lock() + if asyncio.iscoroutinefunction(loader_func): + loader_lock: asyncio.Lock | None = None + initial_load_done = False + initial_load_task: asyncio.Task[None] | None = None - # If no scheduling/caching is desired, just wrap the function and call through - if interval_seconds <= 0 or ttl <= 0: - if is_async: + if interval_seconds <= 0 or ttl <= 0: async def async_wrapper() -> T: return await loader_func() - async_wrapper.__wrapped__ = loader_func # type: ignore - async_wrapper.__name__ = loader_func.__name__ # type: ignore - async_wrapper.__doc__ = loader_func.__doc__ # type: ignore - async_wrapper._cache = loader_cache # type: ignore - async_wrapper._cache_key = cache_key # type: ignore - + attach_wrapper_metadata( + async_wrapper, + loader_func, + cache_obj=cache_obj, + cache_key=cache_key, + ) return async_wrapper # type: ignore - else: - - def sync_wrapper() -> T: - return loader_func() - sync_wrapper.__wrapped__ = loader_func # type: ignore - sync_wrapper.__name__ = loader_func.__name__ # type: ignore - sync_wrapper.__doc__ = loader_func.__doc__ # type: ignore - sync_wrapper._cache = loader_cache # type: ignore - sync_wrapper._cache_key = cache_key # type: ignore + async def refresh_job() -> None: + try: + data = await loader_func() + cache_set(cache_key, data, ttl) + except Exception as e: + if on_error: + try: + on_error(e) + except Exception: + logger.exception( + "Async BGCache error handler failed for key %r", + cache_key, + ) + else: + logger.exception( + "Async BGCache refresh job failed for key %r", cache_key + ) - return sync_wrapper # type: ignore + next_run_time: datetime | None = None - # Create wrapper that loads and caches - def refresh_job(): - """Job that runs periodically to refresh the cache.""" - try: - cache_obj = get_cache() - if is_async: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + if run_immediately: + if cache_get(cache_key) is None: try: - data = loop.run_until_complete(loader_func()) - finally: - loop.close() - else: - data = loader_func() + loop = asyncio.get_running_loop() + except RuntimeError: + asyncio.run(refresh_job()) + initial_load_done = True + next_run_time = datetime.now() + timedelta( + seconds=interval_seconds * 2 + ) + else: + initial_load_task = loop.create_task(refresh_job()) + next_run_time = datetime.now() + timedelta( + 
seconds=interval_seconds * 2 + ) + + scheduler = SharedAsyncScheduler.get_scheduler() + SharedAsyncScheduler.ensure_started() + scheduler.add_job( + refresh_job, + trigger=IntervalTrigger(seconds=interval_seconds), + id=cache_key, + replace_existing=True, + next_run_time=next_run_time, + ) + + async def async_wrapper() -> T: + nonlocal loader_lock, initial_load_done, initial_load_task + + value = cache_get(cache_key) + if value is not None: + return value + + # Miss path: serialize initial load / fallback loads. + # We create the asyncio.Lock lazily to avoid requiring a running + # loop at decoration/import time. + if loader_lock is None: + loader_lock = asyncio.Lock() + async with loader_lock: + value = cache_get(cache_key) + if value is not None: + return value + + # If we scheduled an initial refresh task, wait for it once. + if not initial_load_done: + if initial_load_task is not None: + await initial_load_task + elif not run_immediately: + await refresh_job() + initial_load_done = True - cache_obj.set(cache_key, data, ttl) + value = cache_get(cache_key) + if value is not None: + return value + result = await loader_func() + cache_set(cache_key, result, ttl) + return result + + attach_wrapper_metadata( + async_wrapper, loader_func, cache_obj=cache_obj, cache_key=cache_key + ) + return async_wrapper # type: ignore + + # Sync wrapper + from threading import Lock + + sync_lock = Lock() + sync_initial_load_done = False + + if interval_seconds <= 0 or ttl <= 0: + + def sync_wrapper() -> T: + return loader_func() + + attach_wrapper_metadata( + sync_wrapper, loader_func, cache_obj=cache_obj, cache_key=cache_key + ) + return sync_wrapper + + def sync_refresh_job() -> None: + try: + data = loader_func() + cache_set(cache_key, data, ttl) except Exception as e: - # User-provided error handler gets first chance if on_error: try: on_error(e) except Exception: - # Avoid user handler breaking the scheduler logger.exception( - "BGCache error handler failed for key %r", cache_key + "Sync BGCache error handler failed for key %r", + cache_key, ) else: - # Log uncaught loader errors for visibility logger.exception( - "BGCache refresh job failed for key %r", cache_key + "Sync BGCache refresh job failed for key %r", cache_key ) - # Get shared scheduler - scheduler = _SharedScheduler.get_scheduler() + next_run_time_sync: datetime | None = None - # Run immediately if requested (but only if cache is empty) if run_immediately: - cache_obj = get_cache() - if cache_obj.get(cache_key) is None: - refresh_job() + if cache_get(cache_key) is None: + sync_refresh_job() + sync_initial_load_done = True + next_run_time_sync = datetime.now() + timedelta( + seconds=interval_seconds * 2 + ) - # Schedule periodic refresh - scheduler.add_job( - refresh_job, + scheduler_sync = SharedScheduler.get_scheduler() + SharedScheduler.start() + scheduler_sync.add_job( + sync_refresh_job, trigger=IntervalTrigger(seconds=interval_seconds), id=cache_key, replace_existing=True, + next_run_time=next_run_time_sync, ) - # Start scheduler if not already started - _SharedScheduler.start() - - # Return a wrapper that gets from cache - if is_async: + def sync_wrapper_fn() -> T: + nonlocal sync_initial_load_done + value = cache_get(cache_key) + if value is not None: + return value - async def async_wrapper() -> T: - """Get cached data or call loader if not available.""" - cache_obj = get_cache() - value = cache_obj.get(cache_key) + with sync_lock: + value = cache_get(cache_key) if value is not None: return value - async with loader_lock: # type: 
ignore[arg-type] - value = cache_obj.get(cache_key) - if value is not None: - return value - result = await loader_func() - cache_obj.set(cache_key, result, ttl) - return result - - async_wrapper.__wrapped__ = loader_func # type: ignore - async_wrapper.__name__ = loader_func.__name__ # type: ignore - async_wrapper.__doc__ = loader_func.__doc__ # type: ignore - async_wrapper._cache = get_cache() # type: ignore - async_wrapper._cache_key = cache_key # type: ignore - return async_wrapper # type: ignore - else: + if not sync_initial_load_done: + if not run_immediately: + sync_refresh_job() + sync_initial_load_done = True - def sync_wrapper() -> T: - """Get cached data or call loader if not available.""" - cache_obj = get_cache() - value = cache_obj.get(cache_key) + value = cache_get(cache_key) if value is not None: return value - with loader_lock: # type: ignore[arg-type] - value = cache_obj.get(cache_key) - if value is not None: - return value - result = loader_func() - cache_obj.set(cache_key, result, ttl) - return result - - sync_wrapper.__wrapped__ = loader_func # type: ignore - sync_wrapper.__name__ = loader_func.__name__ # type: ignore - sync_wrapper.__doc__ = loader_func.__doc__ # type: ignore - sync_wrapper._cache = get_cache() # type: ignore - sync_wrapper._cache_key = cache_key # type: ignore + result = loader_func() + cache_set(cache_key, result, ttl) + return result - return sync_wrapper # type: ignore + attach_wrapper_metadata( + sync_wrapper_fn, loader_func, cache_obj=cache_obj, cache_key=cache_key + ) + return sync_wrapper_fn return decorator -# Alias for shorter usage -BGCache = BackgroundCache +BGCache = AsyncBackgroundCache diff --git a/src/advanced_caching/storage.py b/src/advanced_caching/storage.py index 9313bee..876975c 100644 --- a/src/advanced_caching/storage.py +++ b/src/advanced_caching/storage.py @@ -201,8 +201,7 @@ def get(self, key: str) -> Any | None: if entry is None: return None - now = time.time() - if not entry.is_fresh(now): + if time.time() >= entry.fresh_until: del self._data[key] return None diff --git a/tests/benchmark.py b/tests/benchmark.py index e12f764..8965631 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -1,5 +1,10 @@ +""" +Benchmarks for advanced_caching (Async-only architecture). 
+""" + from __future__ import annotations +import asyncio import json import os import random @@ -9,24 +14,14 @@ from datetime import datetime from pathlib import Path from statistics import mean, median, stdev -from typing import Callable, Iterable - -_REPO_ROOT = Path(__file__).resolve().parents[1] -_SRC_DIR = _REPO_ROOT / "src" -if _SRC_DIR.exists() and str(_SRC_DIR) not in sys.path: - sys.path.insert(0, str(_SRC_DIR)) +from typing import Dict, List from advanced_caching import BGCache, SWRCache, TTLCache -@dataclass(frozen=True) -class Config: - seed: int = 12345 - work_ms: float = 5.0 - warmup: int = 10 - runs: int = 300 - mixed_key_space: int = 100 - mixed_runs: int = 500 +# --------------------------------------------------------------------------- +# Config + helpers +# --------------------------------------------------------------------------- def _env_int(name: str, default: int) -> int: @@ -43,6 +38,16 @@ def _env_float(name: str, default: float) -> float: return float(raw) +@dataclass(frozen=True) +class Config: + seed: int = 12345 + work_ms: float = 5.0 + warmup: int = 10 + runs: int = 300 + mixed_key_space: int = 100 + mixed_runs: int = 500 + + CFG = Config( seed=_env_int("BENCH_SEED", 12345), work_ms=_env_float("BENCH_WORK_MS", 5.0), @@ -64,39 +69,36 @@ class Stats: stdev_ms: float -def io_bound_call(user_id: int) -> dict: - """Simulate a typical small I/O call (db/API).""" - time.sleep(CFG.work_ms / 1000.0) +async def async_io_bound_call(user_id: int) -> dict: + await asyncio.sleep(CFG.work_ms / 1000.0) return {"id": user_id, "name": f"User{user_id}"} -def _timed(fn: Callable[[], object], warmup: int, runs: int) -> list[float]: +async def _timed_async(fn, warmup: int, runs: int) -> List[float]: for _ in range(warmup): - fn() - - times: list[float] = [] + await fn() + out: List[float] = [] for _ in range(runs): t0 = time.perf_counter() - fn() - times.append((time.perf_counter() - t0) * 1000.0) - return times + await fn() + out.append((time.perf_counter() - t0) * 1000.0) + return out -def bench( - label: str, fn: Callable[[], object], *, notes: str, warmup: int, runs: int +def stats_from_samples( + label: str, notes: str, runs: int, samples: List[float] ) -> Stats: - times = _timed(fn, warmup=warmup, runs=runs) return Stats( - label=label, - notes=notes, - runs=runs, - median_ms=median(times), - mean_ms=mean(times), - stdev_ms=(stdev(times) if len(times) > 1 else 0.0), + label, + notes, + runs, + median(samples), + mean(samples), + stdev(samples) if len(samples) > 1 else 0.0, ) -def print_table(title: str, rows: list[Stats]) -> None: +def print_table(title: str, rows: List[Stats]) -> None: print("\n" + title) print("-" * len(title)) print( @@ -108,156 +110,8 @@ def print_table(title: str, rows: list[Stats]) -> None: ) -def keys_unique(n: int) -> Iterable[int]: - for i in range(1, n + 1): - yield i - - -def keys_mixed(n: int, key_space: int) -> list[int]: - return [RNG.randint(1, key_space) for _ in range(n)] - - -def scenario_cold() -> list[Stats]: - """Always-miss: new key every call.""" - cold_keys = iter(keys_unique(CFG.runs + CFG.warmup)) - baseline = bench( - "baseline", - lambda: io_bound_call(next(cold_keys)), - notes="no cache", - warmup=CFG.warmup, - runs=CFG.runs, - ) - - ttl_counter = iter(keys_unique(CFG.runs + CFG.warmup)) - - @TTLCache.cached("user:{}", ttl=60) - def ttl_fn(user_id: int) -> dict: - return io_bound_call(user_id) - - ttl = bench( - "TTLCache", - lambda: ttl_fn(next(ttl_counter)), - notes="miss + store", - warmup=CFG.warmup, - runs=CFG.runs, - ) - 
- swr_counter = iter(keys_unique(CFG.runs + CFG.warmup)) - - @SWRCache.cached("user:{}", ttl=60, stale_ttl=30) - def swr_fn(user_id: int) -> dict: - return io_bound_call(user_id) - - swr = bench( - "SWRCache", - lambda: swr_fn(next(swr_counter)), - notes="miss + store", - warmup=CFG.warmup, - runs=CFG.runs, - ) - - return [baseline, ttl, swr] - - -def scenario_hot() -> list[Stats]: - """Always-hit: same key every call.""" - baseline = bench( - "baseline", - lambda: io_bound_call(1), - notes="no cache", - warmup=max(2, CFG.warmup // 2), - runs=max(50, CFG.runs), - ) - - @TTLCache.cached("user:{}", ttl=60) - def ttl_fn(user_id: int) -> dict: - return io_bound_call(user_id) - - ttl_fn(1) - ttl = bench( - "TTLCache", - lambda: ttl_fn(1), - notes="hit", - warmup=CFG.warmup, - runs=CFG.runs, - ) - - @SWRCache.cached("user:{}", ttl=60, stale_ttl=30) - def swr_fn(user_id: int) -> dict: - return io_bound_call(user_id) - - swr_fn(1) - swr = bench( - "SWRCache", - lambda: swr_fn(1), - notes="fresh hit", - warmup=CFG.warmup, - runs=CFG.runs, - ) - - @BGCache.register_loader("bench_user", interval_seconds=60, run_immediately=True) - def bg_user() -> dict: - return io_bound_call(1) - - time.sleep(0.05) - bg = bench( - "BGCache", - bg_user, - notes="preloaded", - warmup=CFG.warmup, - runs=CFG.runs, - ) - - return [baseline, ttl, swr, bg] - - -def scenario_mixed() -> list[Stats]: - """Fixed key space: mix of hits/misses.""" - keys = keys_mixed(CFG.mixed_runs + CFG.warmup, CFG.mixed_key_space) - it = iter(keys) - baseline = bench( - "baseline", - lambda: io_bound_call(next(it)), - notes=f"no cache (key_space={CFG.mixed_key_space})", - warmup=CFG.warmup, - runs=CFG.mixed_runs, - ) - - keys = keys_mixed(CFG.mixed_runs + CFG.warmup, CFG.mixed_key_space) - it = iter(keys) - - @TTLCache.cached("user:{}", ttl=60) - def ttl_fn(user_id: int) -> dict: - return io_bound_call(user_id) - - ttl = bench( - "TTLCache", - lambda: ttl_fn(next(it)), - notes=f"mixed (key_space={CFG.mixed_key_space})", - warmup=CFG.warmup, - runs=CFG.mixed_runs, - ) - - keys = keys_mixed(CFG.mixed_runs + CFG.warmup, CFG.mixed_key_space) - it = iter(keys) - - @SWRCache.cached("user:{}", ttl=60, stale_ttl=30) - def swr_fn(user_id: int) -> dict: - return io_bound_call(user_id) - - swr = bench( - "SWRCache", - lambda: swr_fn(next(it)), - notes=f"mixed (key_space={CFG.mixed_key_space})", - warmup=CFG.warmup, - runs=CFG.mixed_runs, - ) - - return [baseline, ttl, swr] - - def append_json_log( - status: str, error: str | None, sections: dict[str, list[Stats]] + status: str, error: str | None, sections: Dict[str, List[Stats]] ) -> None: payload = { "ts": datetime.now().isoformat(timespec="seconds"), @@ -288,7 +142,6 @@ def append_json_log( for name, rows in sections.items() }, } - try: log_path = Path(__file__).resolve().parent.parent / "benchmarks.log" log_path.parent.mkdir(parents=True, exist_ok=True) @@ -298,26 +151,80 @@ def append_json_log( pass -def main() -> None: - status = "ok" - error: str | None = None - sections: dict[str, list[Stats]] = {} +def shutdown_schedulers() -> None: + try: + BGCache.shutdown(wait=False) + except Exception: + pass - print("advanced_caching benchmark (minimal)") - print( - f"work_ms={CFG.work_ms} seed={CFG.seed} warmup={CFG.warmup} runs={CFG.runs} mixed_runs={CFG.mixed_runs}" + +# --------------------------------------------------------------------------- +# Scenarios +# --------------------------------------------------------------------------- + + +async def scenario_hot_hits() -> List[Stats]: + """Benchmark hot 
cache hits for all strategies.""" + + # 1. TTLCache + @TTLCache.cached("bench:ttl:{}", ttl=60) + async def ttl_fn(user_id: int) -> dict: + return await async_io_bound_call(user_id) + + # Prime cache + await ttl_fn(1) + + ttl_samples = await _timed_async( + lambda: ttl_fn(1), warmup=CFG.warmup, runs=CFG.runs ) + ttl_stats = stats_from_samples("TTLCache", "hot hit", CFG.runs, ttl_samples) - try: - sections["cold"] = scenario_cold() - print_table("Cold (always miss)", sections["cold"]) + # 2. SWRCache + @SWRCache.cached("bench:swr:{}", ttl=60, stale_ttl=30) + async def swr_fn(user_id: int) -> dict: + return await async_io_bound_call(user_id) + + # Prime cache + await swr_fn(1) + + swr_samples = await _timed_async( + lambda: swr_fn(1), warmup=CFG.warmup, runs=CFG.runs + ) + swr_stats = stats_from_samples("SWRCache", "hot hit", CFG.runs, swr_samples) + + # 3. BGCache + @BGCache.register_loader("bench:bg", interval_seconds=60, run_immediately=True) + async def bg_loader() -> dict: + return await async_io_bound_call(1) - sections["hot"] = scenario_hot() - print_table("Hot (always hit)", sections["hot"]) + # Wait for load + await asyncio.sleep(0.05) + + bg_samples = await _timed_async( + lambda: bg_loader(), warmup=CFG.warmup, runs=CFG.runs + ) + bg_stats = stats_from_samples("BGCache", "preloaded", CFG.runs, bg_samples) + + return [ttl_stats, swr_stats, bg_stats] + + +async def run_benchmarks() -> Dict[str, List[Stats]]: + return { + "hot_hits": await scenario_hot_hits(), + } - sections["mixed"] = scenario_mixed() - print_table("Mixed (hits + misses)", sections["mixed"]) +def main() -> None: + status = "ok" + error = None + sections: Dict[str, List[Stats]] = {} + + print("advanced_caching benchmark (Async-only)") + print(f"work_ms={CFG.work_ms} seed={CFG.seed} warmup={CFG.warmup} runs={CFG.runs}") + + try: + sections = asyncio.run(run_benchmarks()) + print_table("Hot Cache Hits", sections["hot_hits"]) except KeyboardInterrupt: status = "interrupted" error = "KeyboardInterrupt" @@ -327,10 +234,7 @@ def main() -> None: error = f"{type(e).__name__}: {e}" raise finally: - try: - BGCache.shutdown(wait=False) - except Exception: - pass + shutdown_schedulers() append_json_log(status=status, error=error, sections=sections) diff --git a/tests/compare_benchmarks.py b/tests/compare_benchmarks.py deleted file mode 100644 index 2d41df3..0000000 --- a/tests/compare_benchmarks.py +++ /dev/null @@ -1,248 +0,0 @@ -from __future__ import annotations - -import argparse -import json -from pathlib import Path -from typing import Any - - -def _load_json_runs(log_path: Path) -> list[dict[str, Any]]: - runs: list[dict[str, Any]] = [] - for line in log_path.read_text(encoding="utf-8", errors="replace").splitlines(): - line = line.strip() - if not line.startswith("{"): - continue - try: - obj = json.loads(line) - except Exception: - continue - if ( - isinstance(obj, dict) - and "results" in obj - and isinstance(obj["results"], dict) - ): - runs.append(obj) - return runs - - -def _parse_selector(spec: str) -> int: - """Return a list index from a selector. - - Supported: - - "last" => -1 - - "last-N" => -(N+1) - - integer (0-based): "0", "2" ... - - negative integer: "-1", "-2" ... - """ - if spec == "last": - return -1 - if spec.startswith("last-"): - n = int(spec.split("-", 1)[1]) - return -(n + 1) - try: - return int(spec) - except ValueError as e: - raise ValueError( - f"Unsupported selector: {spec!r}. Use 'last', 'last-N', or an integer index." 
- ) from e - - -def _median_map(run: dict[str, Any]) -> dict[tuple[str, str], float]: - out: dict[tuple[str, str], float] = {} - results = run.get("results", {}) - for section, rows in results.items(): - if not isinstance(rows, list): - continue - for row in rows: - if not isinstance(row, dict): - continue - label = str(row.get("label", "")) - med = row.get("median_ms") - if not label or not isinstance(med, (int, float)): - continue - out[(str(section), label)] = float(med) - return out - - -def _print_compare(a: dict[str, Any], b: dict[str, Any]) -> None: - a_ts = a.get("ts", "?") - b_ts = b.get("ts", "?") - a_cfg = a.get("config", {}) - b_cfg = b.get("config", {}) - - # Header - print("\n" + "=" * 100) - print("BENCHMARK COMPARISON REPORT") - print("=" * 100 + "\n") - - print(f"Run A (baseline): {a_ts}") - print(f" Config: {a_cfg}") - print() - print(f"Run B (current): {b_ts}") - print(f" Config: {b_cfg}") - print("\n" + "-" * 100 + "\n") - - a_m = _median_map(a) - b_m = _median_map(b) - keys = sorted(set(a_m) | set(b_m)) - - # Group by section - sections = {} - for section, label in keys: - if section not in sections: - sections[section] = [] - sections[section].append(label) - - # Calculate summary statistics - improvements = [] - regressions = [] - - for section in sorted(sections.keys()): - print(f"\n📊 {section.upper()}") - print("-" * 100) - print( - f"{'Strategy':<25} {'A (ms)':>12} {'B (ms)':>12} {'Change':>12} {'%':>8} {'Status':>12}" - ) - print("-" * 100) - - for label in sorted(sections[section]): - a_med = a_m.get((section, label)) - b_med = b_m.get((section, label)) - if a_med is None or b_med is None: - continue - - delta = b_med - a_med - pct = (delta / a_med * 100.0) if a_med > 0 else 0.0 - - # Determine status - if abs(pct) < 2: - status = "✓ SAME" - elif pct < 0: - status = f"✓ FASTER ({abs(pct):.1f}%)" - improvements.append((section, label, abs(pct))) - else: - status = f"✗ SLOWER ({pct:.1f}%)" - regressions.append((section, label, pct)) - - delta_str = f"{delta:+.4f}" - print( - f"{label:<25} {a_med:>12.4f} {b_med:>12.4f} {delta_str:>12} {pct:>7.1f}% {status:>12}" - ) - - # Summary section - print("\n" + "=" * 100) - print("SUMMARY") - print("=" * 100) - - if improvements: - print(f"\n✓ {len(improvements)} IMPROVEMENT(S):") - for section, label, pct in sorted(improvements, key=lambda x: -x[2])[:5]: - print(f" • {label:<30} in {section:<15} → {pct:>6.2f}% faster") - else: - print("\n✓ No improvements detected") - - if regressions: - print(f"\n✗ {len(regressions)} REGRESSION(S):") - for section, label, pct in sorted(regressions, key=lambda x: -x[2])[:5]: - print(f" • {label:<30} in {section:<15} → {pct:>6.2f}% slower") - else: - print("\n✓ No regressions detected") - - # Overall verdict - print("\n" + "-" * 100) - total_changes = len(improvements) + len(regressions) - if total_changes == 0: - verdict = "✓ PERFORMANCE STABLE (no significant changes)" - elif len(improvements) > len(regressions): - avg_improvement = sum(x[2] for x in improvements) / len(improvements) - verdict = f"✓ OVERALL IMPROVEMENT (avg +{avg_improvement:.2f}% faster)" - else: - avg_regression = sum(x[2] for x in regressions) / len(regressions) - verdict = f"✗ OVERALL REGRESSION (avg +{avg_regression:.2f}% slower)" - - # Detailed analysis by section - print("\n" + "=" * 100) - print("DETAILED ANALYSIS BY SCENARIO") - print("=" * 100) - - for section in sorted(sections.keys()): - section_improvements = [x for x in improvements if x[0] == section] - section_regressions = [x for x in regressions if x[0] 
== section] - - if not section_improvements and not section_regressions: - continue - - print(f"\n🔍 {section.upper()}") - - if section_improvements: - avg_improvement = sum(x[2] for x in section_improvements) / len( - section_improvements - ) - print(f" ✓ Average improvement: {avg_improvement:.2f}%") - - if section_regressions: - avg_regression = sum(x[2] for x in section_regressions) / len( - section_regressions - ) - print(f" ✗ Average regression: {avg_regression:.2f}%") - - if not section_improvements and section_regressions: - print(f" ⚠️ Watch: Only regressions detected in this scenario") - - # Recommendations - print("\n" + "=" * 100) - print("RECOMMENDATIONS") - print("=" * 100) - - if total_changes == 0: - print("\n✓ No action needed. Performance is stable.") - recommendation = "continue with current changes" - elif len(regressions) > 0 and sum(x[2] for x in regressions) / len(regressions) > 5: - print("\n⚠️ SIGNIFICANT REGRESSIONS DETECTED") - print(" Consider:") - print(" • Profiling the affected code paths") - print(" • Reviewing recent changes for optimization issues") - print(" • Checking for new dependencies or imports") - recommendation = "investigate and optimize" - elif len(improvements) > 0: - print("\n✓ Performance improvements detected!") - print(" Recommendation: Merge and deploy") - recommendation = "good to merge" - else: - print("\n✓ No significant regressions detected") - recommendation = "safe to merge" - - print("\n" + "=" * 100) - print(f"\n📋 STATUS: {recommendation.upper()}\n") - print("=" * 100 + "\n") - - -def main() -> None: - p = argparse.ArgumentParser( - description="Compare two JSON benchmark runs in benchmarks.log" - ) - p.add_argument("--log", default="benchmarks.log", help="Path to benchmarks.log") - p.add_argument( - "--a", - default="last-1", - help="Run selector: last, last-N, or integer index (0-based; negatives allowed)", - ) - p.add_argument( - "--b", - default="last", - help="Run selector: last, last-N, or integer index (0-based; negatives allowed)", - ) - args = p.parse_args() - - log_path = Path(args.log) - runs = _load_json_runs(log_path) - if len(runs) < 2: - raise SystemExit(f"Need at least 2 JSON runs in {log_path}") - - a = runs[_parse_selector(args.a)] - b = runs[_parse_selector(args.b)] - _print_compare(a, b) - - -if __name__ == "__main__": - main() diff --git a/tests/test_correctness.py b/tests/test_correctness.py index e1af30d..9736350 100644 --- a/tests/test_correctness.py +++ b/tests/test_correctness.py @@ -1,10 +1,11 @@ """ -Fast and reliable unit tests for caching decorators. +Fast and reliable unit tests for caching decorators (Async-first). Tests TTLCache, SWRCache, and BGCache functionality. 
""" -import concurrent.futures +import asyncio import pytest +import pytest_asyncio import time from advanced_caching import ( @@ -17,223 +18,229 @@ ) -@pytest.fixture(autouse=True) -def cleanup(): +@pytest_asyncio.fixture +async def cleanup(): """Clean up scheduler between tests.""" yield try: BGCache.shutdown(wait=False) except: pass - time.sleep(0.05) + await asyncio.sleep(0.05) +@pytest.mark.asyncio +@pytest.mark.usefixtures("cleanup") class TestTTLCache: """TTLCache decorator tests.""" - def test_basic_caching(self): + async def test_basic_caching(self): """Test basic TTL caching with function calls.""" call_count = {"count": 0} @TTLCache.cached("user:{}", ttl=60) - def get_user(user_id): + async def get_user(user_id): call_count["count"] += 1 return {"id": user_id, "name": f"User{user_id}"} # First call - cache miss - result1 = get_user(1) + result1 = await get_user(1) assert result1 == {"id": 1, "name": "User1"} assert call_count["count"] == 1 # Second call - cache hit - result2 = get_user(1) + result2 = await get_user(1) assert result2 == {"id": 1, "name": "User1"} assert call_count["count"] == 1 # Not incremented # Different key - cache miss - result3 = get_user(2) + result3 = await get_user(2) assert result3 == {"id": 2, "name": "User2"} assert call_count["count"] == 2 - def test_ttl_expiration(self): + async def test_ttl_expiration(self): """Test that cache expires after TTL.""" call_count = {"count": 0} - @TTLCache.cached("data:{}", ttl=0.5) - def get_data(key): + @TTLCache.cached("data:{}", ttl=0.2) + async def get_data(key): call_count["count"] += 1 return {"key": key, "count": call_count["count"]} # First call - result1 = get_data("test") + result1 = await get_data("test") assert result1["count"] == 1 assert call_count["count"] == 1 # Cache should still be valid - result2 = get_data("test") + result2 = await get_data("test") assert result2["count"] == 1 assert call_count["count"] == 1 # Wait for expiration - time.sleep(0.6) + await asyncio.sleep(0.3) # Cache should be expired, function called again - result3 = get_data("test") + result3 = await get_data("test") assert result3["count"] == 2 assert call_count["count"] == 2 - def test_custom_cache_backend(self): + async def test_custom_cache_backend(self): """Test TTLCache with custom backend.""" custom_cache = InMemCache() @TTLCache.cached("item:{}", ttl=60, cache=custom_cache) - def get_item(item_id): + async def get_item(item_id): return {"id": item_id} - result = get_item(123) + result = await get_item(123) assert result == {"id": 123} # Verify in custom cache assert custom_cache.exists("item:123") - def test_callable_key_function(self): + async def test_callable_key_function(self): """Test TTLCache with callable key function.""" @TTLCache.cached(key=lambda user_id: f"user:{user_id}", ttl=60) - def get_user(user_id): + async def get_user(user_id): return {"id": user_id} - result = get_user(42) + result = await get_user(42) assert result == {"id": 42} - def test_isolated_caches(self): + async def test_isolated_caches(self): """Test that each TTL cached function has its own cache.""" @TTLCache.cached("user:{}", ttl=60) - def get_user(user_id): + async def get_user(user_id): return {"type": "user", "id": user_id} @TTLCache.cached("product:{}", ttl=60) - def get_product(product_id): + async def get_product(product_id): return {"type": "product", "id": product_id} # Each should have its own cache assert get_user._cache is not get_product._cache # Both should work - assert get_user(1)["type"] == "user" - assert 
get_product(1)["type"] == "product" + assert (await get_user(1))["type"] == "user" + assert (await get_product(1))["type"] == "product" +@pytest.mark.asyncio class TestSWRCache: """SWRCache (Stale-While-Revalidate) tests.""" - def test_fresh_cache_hit(self): + async def test_fresh_cache_hit(self): """Test SWR with fresh cache returns immediately.""" call_count = {"count": 0} @SWRCache.cached("user:{}", ttl=60, stale_ttl=30) - def get_user(user_id): + async def get_user(user_id): call_count["count"] += 1 return {"id": user_id, "count": call_count["count"]} # First call - cache miss - result1 = get_user(1) + result1 = await get_user(1) assert result1["count"] == 1 assert call_count["count"] == 1 # Second call - should hit fresh cache - result2 = get_user(1) + result2 = await get_user(1) assert result2["count"] == 1 # Same cached value assert call_count["count"] == 1 # Function not called again - def test_stale_with_background_refresh(self): + async def test_stale_with_background_refresh(self): """Test SWR serves stale data while refreshing in background.""" call_count = {"count": 0} - @SWRCache.cached("data:{}", ttl=0.3, stale_ttl=0.5) - def get_data(key): + @SWRCache.cached("data:{}", ttl=0.2, stale_ttl=0.5) + async def get_data(key): call_count["count"] += 1 return {"key": key, "count": call_count["count"]} # First call - result1 = get_data("test") + result1 = await get_data("test") assert result1["count"] == 1 assert call_count["count"] == 1 # Wait for data to become stale but within grace period - time.sleep(0.4) + await asyncio.sleep(0.3) # Should return stale value and refresh in background - result2 = get_data("test") + result2 = await get_data("test") assert result2["count"] == 1 # Still getting stale data - # Background refresh may or may not have completed yet # Wait for background refresh to complete - time.sleep(0.2) + await asyncio.sleep(0.2) # Now should have fresh data - result3 = get_data("test") + result3 = await get_data("test") assert result3["count"] >= 2 # Should be refreshed - def test_too_stale_refetch(self): + async def test_too_stale_refetch(self): """Test SWR refetches when too stale.""" call_count = {"count": 0} - @SWRCache.cached("data:{}", ttl=0.2, stale_ttl=0.2) - def get_data(key): + @SWRCache.cached("data:{}", ttl=0.1, stale_ttl=0.1) + async def get_data(key): call_count["count"] += 1 return {"key": key, "count": call_count["count"]} # First call - result1 = get_data("test") + result1 = await get_data("test") assert result1["count"] == 1 # Wait until beyond TTL + stale_ttl - time.sleep(0.5) + await asyncio.sleep(0.3) # Should refetch immediately (not within grace period) - result2 = get_data("test") + result2 = await get_data("test") assert result2["count"] == 2 # Refetched assert call_count["count"] == 2 - def test_custom_cache_backend(self): + async def test_custom_cache_backend(self): """Test SWRCache with custom backend.""" custom_cache = InMemCache() @SWRCache.cached("item:{}", ttl=60, stale_ttl=30, cache=custom_cache) - def get_item(item_id): + async def get_item(item_id): return {"id": item_id} - result = get_item(123) + result = await get_item(123) assert result == {"id": 123} +@pytest.mark.asyncio +@pytest.mark.usefixtures("cleanup") class TestBGCache: """BGCache (Background Scheduler) tests.""" - def test_sync_loader_immediate(self): - """Test sync loader with immediate execution.""" + async def test_async_loader_immediate(self): + """Test async loader with immediate execution.""" call_count = {"count": 0} - @BGCache.register_loader("sync_test", 
interval_seconds=10, run_immediately=True) - def load_data(): + @BGCache.register_loader( + "async_test", interval_seconds=10, run_immediately=True + ) + async def load_data(): call_count["count"] += 1 return {"value": call_count["count"]} - time.sleep(0.1) # Wait for initial load + await asyncio.sleep(0.1) # Wait for initial load # First call should return cached data - result = load_data() + result = await load_data() assert result == {"value": 1} assert call_count["count"] == 1 # Second call should still use cache - result2 = load_data() + result2 = await load_data() assert result2 == {"value": 1} assert call_count["count"] == 1 # Not called again - def test_sync_loader_no_immediate(self): + async def test_sync_loader_no_immediate(self): """Test sync loader without immediate execution.""" call_count = {"count": 0} @@ -244,48 +251,48 @@ def load_data(): call_count["count"] += 1 return {"value": call_count["count"]} - time.sleep(0.1) + await asyncio.sleep(0.1) # Should not have been called yet assert call_count["count"] == 0 # First call will execute the function since cache is empty - result = load_data() + result = await load_data() assert result == {"value": 1} assert call_count["count"] == 1 - def test_custom_cache_backend(self): + async def test_custom_cache_backend(self): """Test BGCache using custom cache backend.""" custom_cache = InMemCache() @BGCache.register_loader( "custom", interval_seconds=10, run_immediately=True, cache=custom_cache ) - def load_data(): + async def load_data(): return {"custom": True} - time.sleep(0.1) + await asyncio.sleep(0.1) # Verify data is in custom cache cached_value = custom_cache.get("custom") assert cached_value == {"custom": True} # Call function - result = load_data() + result = await load_data() assert result == {"custom": True} - def test_isolated_cache_instances(self): + async def test_isolated_cache_instances(self): """Test that each loader has its own cache.""" @BGCache.register_loader("loader1", interval_seconds=10, run_immediately=True) - def load1(): + async def load1(): return {"id": 1} @BGCache.register_loader("loader2", interval_seconds=10, run_immediately=True) - def load2(): + async def load2(): return {"id": 2} - time.sleep(0.1) + await asyncio.sleep(0.1) # Each should have its own cache assert load1._cache is not load2._cache @@ -293,10 +300,10 @@ def load2(): assert load2._cache_key == "loader2" # Each should have correct data - assert load1() == {"id": 1} - assert load2() == {"id": 2} + assert (await load1()) == {"id": 1} + assert (await load2()) == {"id": 2} - def test_error_handling(self): + async def test_error_handling(self): """Test error handler is called on failure.""" errors = [] @@ -309,111 +316,60 @@ def error_handler(e): run_immediately=True, on_error=error_handler, ) - def load_data(): + async def load_data(): raise ValueError("Test error") - time.sleep(0.1) + await asyncio.sleep(0.1) # Error should have been captured assert len(errors) == 1 assert isinstance(errors[0], ValueError) assert str(errors[0]) == "Test error" - def test_periodic_refresh(self): + async def test_periodic_refresh(self): """Test that data refreshes periodically.""" call_count = {"count": 0} - @BGCache.register_loader("periodic", interval_seconds=0.5, run_immediately=True) - def load_data(): + @BGCache.register_loader("periodic", interval_seconds=0.2, run_immediately=True) + async def load_data(): call_count["count"] += 1 return {"value": call_count["count"]} # Wait for initial load - time.sleep(0.1) + await asyncio.sleep(0.1) assert 
call_count["count"] == 1 # Wait for one refresh - time.sleep(0.6) + await asyncio.sleep(0.3) assert call_count["count"] >= 2 # Get updated data - result = load_data() + result = await load_data() assert result["value"] >= 2 - def test_multiple_loaders(self): + async def test_multiple_loaders(self): """Test multiple loaders can coexist.""" @BGCache.register_loader("loader_a", interval_seconds=10, run_immediately=True) - def load_a(): + async def load_a(): return {"name": "a"} @BGCache.register_loader("loader_b", interval_seconds=10, run_immediately=True) - def load_b(): + async def load_b(): return {"name": "b"} @BGCache.register_loader("loader_c", interval_seconds=10, run_immediately=True) - def load_c(): + async def load_c(): return {"name": "c"} - time.sleep(0.15) + await asyncio.sleep(0.15) # All should work independently - assert load_a()["name"] == "a" - assert load_b()["name"] == "b" - assert load_c()["name"] == "c" - - def test_concurrent_access_is_thread_safe(self): - """Concurrent callers should read cached data without duplicate loads.""" - call_count = {"count": 0} - - @BGCache.register_loader( - "concurrent_loader", interval_seconds=60, run_immediately=True - ) - def load_data(): - # Simulate work to surface races if present - time.sleep(0.05) - call_count["count"] += 1 - return {"value": call_count["count"]} - - # Wait for initial load triggered by run_immediately - time.sleep(0.1) - - def call_loader(_: int): - return load_data() - - with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor: - results = list(executor.map(call_loader, range(24))) - - # All callers should see the cached value produced by the first load - assert all(r == {"value": 1} for r in results) - assert call_count["count"] == 1 - - def test_concurrent_initial_load_when_no_immediate(self): - """When run_immediately=False, first concurrent callers should single-flight load.""" - call_count = {"count": 0} + assert (await load_a())["name"] == "a" + assert (await load_b())["name"] == "b" + assert (await load_c())["name"] == "c" - @BGCache.register_loader( - "concurrent_no_immediate", - interval_seconds=30, - run_immediately=False, - ttl=30, - ) - def load_data(): - time.sleep(0.05) - call_count["count"] += 1 - return {"value": call_count["count"]} - - def call_loader(_: int): - return load_data() - - with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor: - results = list(executor.map(call_loader, range(24))) - - # Only one load should have happened, all callers get cached value - assert all(r == {"value": 1} for r in results) - assert call_count["count"] == 1 - - def test_lambda_cache_factory(self): + async def test_lambda_cache_factory(self): """Test BGCache with lambda returning HybridCache.""" call_count = {"count": 0} @@ -425,17 +381,17 @@ def test_lambda_cache_factory(self): l1_cache=InMemCache(), l2_cache=InMemCache(), l1_ttl=60 ), ) - def get_test_data() -> dict[str, str]: + async def get_test_data() -> dict[str, str]: call_count["count"] += 1 return {"key": "value", "count": str(call_count["count"])} # First call should hit the cache (run_immediately=True loaded it) - result1 = get_test_data() + result1 = await get_test_data() assert result1 == {"key": "value", "count": "1"} assert call_count["count"] == 1 # Second call should return cached value - result2 = get_test_data() + result2 = await get_test_data() assert result2 == {"key": "value", "count": "1"} assert call_count["count"] == 1 # No additional call @@ -444,56 +400,26 @@ def get_test_data() -> dict[str, str]: 
assert get_test_data._cache is not None assert isinstance(get_test_data._cache, HybridCache) - def test_lambda_cache_nested_dict_access(self): - """Test nested dict access pattern with lambda cache factory.""" - - @BGCache.register_loader( - "nested_dict_map", - interval_seconds=3600, - run_immediately=True, - cache=lambda: HybridCache( - l1_cache=InMemCache(), l2_cache=InMemCache(), l1_ttl=3600 - ), - ) - def get_mapping() -> dict[str, dict]: - return { - "color": {"en": "Color", "fr": "Couleur"}, - "size": {"en": "Size", "fr": "Taille"}, - } - - # Test the exact pattern that could fail if lambda not instantiated - name = get_mapping().get("color", {}).get("en") - assert name == "Color" - - name = get_mapping().get("size", {}).get("fr") - assert name == "Taille" - - name = get_mapping().get("missing", {}).get("en") - assert name is None - - # Verify cache is properly instantiated (not a lambda) - cache_obj = get_mapping._cache - assert isinstance(cache_obj, HybridCache) - assert not callable(cache_obj) or hasattr(cache_obj, "get") - +@pytest.mark.asyncio +@pytest.mark.usefixtures("cleanup") class TestCachePerformance: """Performance and speed tests.""" - def test_cache_hit_speed(self): + async def test_cache_hit_speed(self): """Test that cache hits are fast.""" @BGCache.register_loader("perf_test", interval_seconds=10, run_immediately=True) - def load_data(): - time.sleep(0.01) # Simulate slow operation + async def load_data(): + await asyncio.sleep(0.01) # Simulate slow operation return {"data": "value"} - time.sleep(0.05) # Wait for initial load + await asyncio.sleep(0.05) # Wait for initial load # Measure cache hit time start = time.perf_counter() for _ in range(1000): - result = load_data() + result = await load_data() duration = time.perf_counter() - start # Should be very fast (<1ms per call on average) @@ -501,55 +427,56 @@ def load_data(): assert avg_time < 0.001, f"Cache hit too slow: {avg_time * 1000:.3f}ms" assert result == {"data": "value"} - def test_ttl_cache_hit_speed(self): + async def test_ttl_cache_hit_speed(self): """Test TTLCache hit speed.""" @TTLCache.cached("item:{}", ttl=60) - def get_item(item_id): - time.sleep(0.001) # Simulate work + async def get_item(item_id): + await asyncio.sleep(0.001) # Simulate work return {"id": item_id} # Prime cache - get_item(1) + await get_item(1) # Measure cache hits start = time.perf_counter() for _ in range(1000): - get_item(1) + await get_item(1) duration = time.perf_counter() - start avg_time = duration / 1000 assert avg_time < 0.0005, f"TTL cache hit too slow: {avg_time * 1000:.3f}ms" +@pytest.mark.asyncio class TestKeyTemplates: """Key template behavior for TTLCache and SWRCache.""" - def test_ttl_positional_template(self): + async def test_ttl_positional_template(self): calls = {"n": 0} @TTLCache.cached("user:{}", ttl=60) - def get_user(user_id: int): + async def get_user(user_id: int): calls["n"] += 1 return {"id": user_id} - assert get_user(1) == {"id": 1} - assert get_user(1) == {"id": 1} + assert (await get_user(1)) == {"id": 1} + assert (await get_user(1)) == {"id": 1} assert calls["n"] == 1 # cache hit by positional template - def test_ttl_named_template(self): + async def test_ttl_named_template(self): calls = {"n": 0} @TTLCache.cached("user:{user_id}", ttl=60) - def get_user(*, user_id: int): + async def get_user(*, user_id: int): calls["n"] += 1 return {"id": user_id} - assert get_user(user_id=2) == {"id": 2} - assert get_user(user_id=2) == {"id": 2} + assert (await get_user(user_id=2)) == {"id": 2} + assert (await 
get_user(user_id=2)) == {"id": 2} assert calls["n"] == 1 - def test_swr_default_arg_with_key_function(self): + async def test_swr_default_arg_with_key_function(self): calls = {"n": 0} @SWRCache.cached( @@ -557,58 +484,58 @@ def test_swr_default_arg_with_key_function(self): ttl=5, stale_ttl=10, ) - def load_all(lang: str = "en") -> dict: + async def load_all(lang: str = "en") -> dict: calls["n"] += 1 return {"hello": f"Hello in {lang}"} # Default arg used (no args provided) - r1 = load_all() - r2 = load_all(lang="en") - r3 = load_all() + r1 = await load_all() + r2 = await load_all(lang="en") + r3 = await load_all() assert r1 == {"hello": "Hello in en"} assert r2 == {"hello": "Hello in en"} assert r3 == {"hello": "Hello in en"} assert calls["n"] == 1 # all share the same cache key - def test_swr_named_template_with_kwargs(self): + async def test_swr_named_template_with_kwargs(self): calls = {"n": 0} @SWRCache.cached("i18n:{lang}", ttl=5, stale_ttl=10) - def load_i18n(*, lang: str = "en") -> dict: + async def load_i18n(*, lang: str = "en") -> dict: calls["n"] += 1 return {"hello": f"Hello in {lang}"} - r1 = load_i18n(lang="en") - r2 = load_i18n(lang="en") + r1 = await load_i18n(lang="en") + r2 = await load_i18n(lang="en") assert r1 == {"hello": "Hello in en"} assert r2 == {"hello": "Hello in en"} assert calls["n"] == 1 - def test_swr_positional_template_with_args(self): + async def test_swr_positional_template_with_args(self): calls = {"n": 0} @SWRCache.cached("i18n:{}", ttl=5, stale_ttl=10) - def load_i18n(lang: str) -> dict: + async def load_i18n(lang: str) -> dict: calls["n"] += 1 return {"hello": f"Hello in {lang}"} - r1 = load_i18n("en") - r2 = load_i18n("en") + r1 = await load_i18n("en") + r2 = await load_i18n("en") assert r1 == {"hello": "Hello in en"} assert r2 == {"hello": "Hello in en"} assert calls["n"] == 1 - def test_swr_named_template_with_extra_kwargs(self): + async def test_swr_named_template_with_extra_kwargs(self): calls = {"n": 0} @SWRCache.cached("i18n:{lang}", ttl=5, stale_ttl=10) - def load_i18n(lang: str, region: str | None = None) -> dict: + async def load_i18n(lang: str, region: str | None = None) -> dict: calls["n"] += 1 suffix = f"-{region}" if region else "" return {"hello": f"Hello in {lang}{suffix}"} - r1 = load_i18n(lang="en", region="US") - r2 = load_i18n(lang="en", region="US") + r1 = await load_i18n(lang="en", region="US") + r2 = await load_i18n(lang="en", region="US") assert r1 == {"hello": "Hello in en-US"} assert r2 == {"hello": "Hello in en-US"} assert calls["n"] == 1 @@ -663,61 +590,62 @@ def test_hybridcache_basic_flow(self): assert cache.get("x") is None +@pytest.mark.asyncio class TestDecoratorKeyEdgeCases: """Exercise edge key-generation paths for decorators.""" - def test_ttl_key_without_placeholders(self): + async def test_ttl_key_without_placeholders(self): calls = {"n": 0} @TTLCache.cached("static-key", ttl=60) - def f(user_id: int): + async def f(user_id: int): calls["n"] += 1 return user_id - assert f(1) == 1 - assert f(2) == 1 # same key, result from first call + assert (await f(1)) == 1 + assert (await f(2)) == 1 # same key, result from first call assert calls["n"] == 1 - def test_swr_key_without_args_or_kwargs(self): + async def test_swr_key_without_args_or_kwargs(self): calls = {"n": 0} @SWRCache.cached("static", ttl=1, stale_ttl=1) - def f() -> int: + async def f() -> int: calls["n"] += 1 return calls["n"] # First call: miss - assert f() == 1 + assert (await f()) == 1 # Immediate second call: hit - assert f() == 1 + assert (await f()) == 1 
assert calls["n"] == 1 - def test_swr_key_template_single_kwarg_positional_fallback(self): + async def test_swr_key_template_single_kwarg_positional_fallback(self): calls = {"n": 0} # Template with positional placeholder but only kwarg passed @SWRCache.cached("foo:{}", ttl=1, stale_ttl=1) - def f(*, x: int) -> int: + async def f(*, x: int) -> int: calls["n"] += 1 return x - assert f(x=1) == 1 - assert f(x=1) == 1 + assert (await f(x=1)) == 1 + assert (await f(x=1)) == 1 assert calls["n"] == 1 - def test_swr_invalid_format_falls_back_to_raw_key(self): + async def test_swr_invalid_format_falls_back_to_raw_key(self): calls = {"n": 0} # Template expects named field that is never provided; we only pass kwargs @SWRCache.cached("foo:{missing}", ttl=1, stale_ttl=1) - def f(*, x: int) -> int: + async def f(*, x: int) -> int: calls["n"] += 1 return x # First call populates cache with raw key "foo:{missing}" after format failure - assert f(x=1) == 1 + assert (await f(x=1)) == 1 # Second call uses same raw key and returns cached value despite different arg - assert f(x=2) == 1 + assert (await f(x=2)) == 1 assert calls["n"] == 1 @@ -842,65 +770,61 @@ def test_l2_ttl_with_zero_ttl_in_set(self): assert cache.get("key4") == "value4" +@pytest.mark.asyncio +@pytest.mark.usefixtures("cleanup") class TestNoCachingWhenZero: """Ensure ttl/interval_seconds == 0 disables caching/background behavior.""" - def test_ttlcache_ttl_zero_disables_caching(self): + async def test_ttlcache_ttl_zero_disables_caching(self): calls = {"n": 0} @TTLCache.cached("user:{}", ttl=0) - def get_user(user_id: int) -> int: + async def get_user(user_id: int) -> int: calls["n"] += 1 return calls["n"] # Each call should invoke the function (no caching) - assert get_user(1) == 1 - assert get_user(1) == 2 - assert get_user(1) == 3 + assert (await get_user(1)) == 1 + assert (await get_user(1)) == 2 + assert (await get_user(1)) == 3 assert calls["n"] == 3 - def test_swrcache_ttl_zero_disables_caching(self): + async def test_swrcache_ttl_zero_disables_caching(self): calls = {"n": 0} @SWRCache.cached("data:{}", ttl=0, stale_ttl=10) - def get_data(key: str) -> int: + async def get_data(key: str) -> int: calls["n"] += 1 return calls["n"] # Each call should invoke the function (no SWR behavior) - assert get_data("k") == 1 - assert get_data("k") == 2 - assert get_data("k") == 3 + assert (await get_data("k")) == 1 + assert (await get_data("k")) == 2 + assert (await get_data("k")) == 3 assert calls["n"] == 3 - def test_bgcache_interval_zero_disables_background_and_cache(self): + async def test_bgcache_interval_zero_disables_background_and_cache(self): calls = {"n": 0} @BGCache.register_loader(key="no_bg", interval_seconds=0, ttl=None) - def load_data() -> int: + async def load_data() -> int: calls["n"] += 1 return calls["n"] # No background scheduler, no caching: each call increments - assert load_data() == 1 - assert load_data() == 2 - assert load_data() == 3 + assert (await load_data()) == 1 + assert (await load_data()) == 2 + assert (await load_data()) == 3 assert calls["n"] == 3 - def test_bgcache_ttl_zero_disables_background_and_cache(self): + async def test_bgcache_ttl_zero_disables_background_and_cache(self): calls = {"n": 0} @BGCache.register_loader(key="no_bg_ttl", interval_seconds=10, ttl=0) - def load_data() -> int: + async def load_data() -> int: calls["n"] += 1 return calls["n"] # Because ttl == 0, wrapper should bypass cache and scheduler - assert load_data() == 1 - assert load_data() == 2 - assert load_data() == 3 - assert calls["n"] == 
3 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) + assert (await load_data()) == 1 + assert (await load_data()) == 2 diff --git a/tests/test_integration_redis.py b/tests/test_integration_redis.py index f3c6788..b9461de 100644 --- a/tests/test_integration_redis.py +++ b/tests/test_integration_redis.py @@ -6,6 +6,7 @@ import pickle import pytest import time +import asyncio from typing import Any try: @@ -28,6 +29,12 @@ ) +@pytest.fixture(autouse=True) +async def reset_scheduler(): + yield + BGCache.shutdown(wait=False) + + @pytest.fixture(scope="module") def redis_container(): """Fixture to start a Redis container for the entire test module.""" @@ -179,122 +186,125 @@ def loads(data: bytes) -> Any: assert loaded.value == entry.value +@pytest.mark.asyncio class TestTTLCacheWithRedis: """Test TTLCache decorator with Redis backend.""" - def test_ttlcache_redis_basic(self, redis_client): + async def test_ttlcache_redis_basic(self, redis_client): """Test TTLCache with Redis backend.""" calls = {"n": 0} cache = RedisCache(redis_client, prefix="ttl:") @TTLCache.cached("user:{}", ttl=60, cache=cache) - def get_user(user_id: int): + async def get_user(user_id: int): calls["n"] += 1 return {"id": user_id, "name": f"User{user_id}"} - result1 = get_user(1) + result1 = await get_user(1) assert result1 == {"id": 1, "name": "User1"} assert calls["n"] == 1 - result2 = get_user(1) + result2 = await get_user(1) assert result2 == {"id": 1, "name": "User1"} assert calls["n"] == 1 - result3 = get_user(2) + result3 = await get_user(2) assert result3 == {"id": 2, "name": "User2"} assert calls["n"] == 2 - def test_ttlcache_redis_expiration(self, redis_client): + async def test_ttlcache_redis_expiration(self, redis_client): """Test TTLCache with Redis respects TTL.""" calls = {"n": 0} cache = RedisCache(redis_client, prefix="ttl:") @TTLCache.cached("data:{}", ttl=1, cache=cache) - def get_data(key: str): + async def get_data(key: str): calls["n"] += 1 return f"data_{key}" - result1 = get_data("test") + result1 = await get_data("test") assert result1 == "data_test" assert calls["n"] == 1 - result2 = get_data("test") + result2 = await get_data("test") assert calls["n"] == 1 - time.sleep(1.1) + await asyncio.sleep(1.1) - result3 = get_data("test") + result3 = await get_data("test") assert result3 == "data_test" assert calls["n"] == 2 - def test_ttlcache_redis_named_template(self, redis_client): + async def test_ttlcache_redis_named_template(self, redis_client): """Test TTLCache with Redis using named key template.""" calls = {"n": 0} cache = RedisCache(redis_client, prefix="ttl:") @TTLCache.cached("product:{product_id}", ttl=60, cache=cache) - def get_product(*, product_id: int): + async def get_product(*, product_id: int): calls["n"] += 1 return {"id": product_id, "name": f"Product{product_id}"} - result1 = get_product(product_id=100) + result1 = await get_product(product_id=100) assert result1 == {"id": 100, "name": "Product100"} assert calls["n"] == 1 - result2 = get_product(product_id=100) + result2 = await get_product(product_id=100) assert calls["n"] == 1 +@pytest.mark.asyncio class TestSWRCacheWithRedis: """Test SWRCache with Redis backend.""" - def test_swrcache_redis_basic(self, redis_client): + async def test_swrcache_redis_basic(self, redis_client): """Test SWRCache with Redis backend.""" calls = {"n": 0} cache = RedisCache(redis_client, prefix="swr:") @SWRCache.cached("product:{}", ttl=1, stale_ttl=1, cache=cache) - def get_product(product_id: int): + async def get_product(product_id: int): 
calls["n"] += 1 return {"id": product_id, "count": calls["n"]} - result1 = get_product(1) + result1 = await get_product(1) assert result1["count"] == 1 assert calls["n"] == 1 - result2 = get_product(1) + result2 = await get_product(1) assert result2["count"] == 1 assert calls["n"] == 1 - def test_swrcache_redis_stale_serve(self, redis_client): + async def test_swrcache_redis_stale_serve(self, redis_client): """Test SWRCache serves stale data while refreshing.""" calls = {"n": 0} cache = RedisCache(redis_client, prefix="swr:") @SWRCache.cached("data:{}", ttl=0.3, stale_ttl=0.5, cache=cache) - def get_data(key: str): + async def get_data(key: str): calls["n"] += 1 return {"key": key, "count": calls["n"]} - result1 = get_data("test") + result1 = await get_data("test") assert result1["count"] == 1 - time.sleep(0.4) + await asyncio.sleep(0.4) - result2 = get_data("test") + result2 = await get_data("test") assert result2["count"] == 1 # Give background refresh enough time (Redis + thread scheduling) - time.sleep(0.35) + await asyncio.sleep(0.35) - result3 = get_data("test") + result3 = await get_data("test") assert result3["count"] >= 2 +@pytest.mark.asyncio class TestBGCacheWithRedis: """Test BGCache with Redis backend.""" - def test_bgcache_redis_sync_loader(self, redis_client): + async def test_bgcache_redis_sync_loader(self, redis_client): """Test BGCache with sync loader and Redis backend.""" calls = {"n": 0} cache = RedisCache(redis_client, prefix="bg:") @@ -305,21 +315,21 @@ def test_bgcache_redis_sync_loader(self, redis_client): run_immediately=True, cache=cache, ) - def load_inventory(): + async def load_inventory(): calls["n"] += 1 return {"items": [f"item_{i}" for i in range(3)]} - time.sleep(0.1) + await asyncio.sleep(0.1) - result = load_inventory() + result = await load_inventory() assert result == {"items": ["item_0", "item_1", "item_2"]} assert calls["n"] == 1 - result2 = load_inventory() + result2 = await load_inventory() assert result2 == {"items": ["item_0", "item_1", "item_2"]} assert calls["n"] == 1 - def test_bgcache_redis_with_error_handler(self, redis_client): + async def test_bgcache_redis_with_error_handler(self, redis_client): """Test BGCache error handling with Redis.""" errors = [] cache = RedisCache(redis_client, prefix="bg:") @@ -334,10 +344,10 @@ def on_error(exc): on_error=on_error, cache=cache, ) - def failing_loader(): + async def failing_loader(): raise ValueError("Simulated failure") - time.sleep(0.1) + await asyncio.sleep(0.1) assert len(errors) == 1 assert isinstance(errors[0], ValueError) @@ -378,7 +388,8 @@ def test_hybridcache_l1_miss_l2_hit(self, redis_client): assert l1.get("key") == "value_from_l2" - def test_hybridcache_with_ttlcache(self, redis_client): + @pytest.mark.asyncio + async def test_hybridcache_with_ttlcache(self, redis_client): """Test TTLCache using HybridCache backend.""" l2 = RedisCache(redis_client, prefix="hybrid_ttl:") cache = HybridCache( @@ -390,15 +401,15 @@ def test_hybridcache_with_ttlcache(self, redis_client): calls = {"n": 0} @TTLCache.cached("user:{}", ttl=60, cache=cache) - def get_user(user_id: int): + async def get_user(user_id: int): calls["n"] += 1 return {"id": user_id} - result1 = get_user(1) + result1 = await get_user(1) assert result1 == {"id": 1} assert calls["n"] == 1 - result2 = get_user(1) + result2 = await get_user(1) assert result2 == {"id": 1} assert calls["n"] == 1 @@ -496,7 +507,8 @@ def test_hybridcache_l2_ttl_shorter_than_requested(self, redis_client): # TTL should be approximately l2_ttl (3 seconds), 
allow some margin assert 2 <= redis_ttl <= 4 - def test_hybridcache_with_bgcache_and_l2_ttl(self, redis_client): + @pytest.mark.asyncio + async def test_hybridcache_with_bgcache_and_l2_ttl(self, redis_client): """Test BGCache with HybridCache using l2_ttl.""" l2 = RedisCache(redis_client, prefix="hybrid_bg:") cache = HybridCache(l1_cache=InMemCache(), l2_cache=l2, l1_ttl=10, l2_ttl=60) @@ -509,18 +521,18 @@ def test_hybridcache_with_bgcache_and_l2_ttl(self, redis_client): run_immediately=True, cache=cache, ) - def load_config(): + async def load_config(): calls["n"] += 1 return {"setting": "value", "count": calls["n"]} - time.sleep(0.1) + await asyncio.sleep(0.1) - result = load_config() + result = await load_config() assert result["count"] == 1 assert calls["n"] == 1 # Verify it's cached - result2 = load_config() + result2 = await load_config() assert result2["count"] == 1 assert calls["n"] == 1 @@ -670,10 +682,11 @@ def test_hybridcache_very_short_l2_ttl(self, redis_client): assert cache.get("short_ttl") == "value" +@pytest.mark.asyncio class TestCacheRehydration: """Test that decorators can retrieve existing data from Redis without re-executing functions.""" - def test_ttlcache_rehydrates_from_redis(self, redis_client): + async def test_ttlcache_rehydrates_from_redis(self, redis_client): """Test TTLCache retrieves existing Redis data without executing function.""" # Pre-populate Redis test_data = {"result": "from_redis"} @@ -690,22 +703,22 @@ def test_ttlcache_rehydrates_from_redis(self, redis_client): l1_ttl=60, ), ) - def compute(x): + async def compute(x): nonlocal call_count call_count += 1 return {"result": f"computed_{x}"} # First call should retrieve from Redis without executing function - result = compute(42) + result = await compute(42) assert result == test_data assert call_count == 0, "Function should not execute when data exists in Redis" # Second call should hit L1 cache - result = compute(42) + result = await compute(42) assert result == test_data assert call_count == 0 - def test_swrcache_rehydrates_from_redis(self, redis_client): + async def test_swrcache_rehydrates_from_redis(self, redis_client): """Test SWRCache retrieves existing Redis data without executing function.""" # Pre-populate Redis with CacheEntry now = time.time() @@ -727,22 +740,22 @@ def test_swrcache_rehydrates_from_redis(self, redis_client): l1_ttl=60, ), ) - def fetch(x): + async def fetch(x): nonlocal call_count call_count += 1 return {"result": f"fetched_{x}"} # First call should retrieve from Redis without executing function - result = fetch(99) + result = await fetch(99) assert result == {"result": "from_redis"} assert call_count == 0, "Function should not execute when data exists in Redis" # Second call should hit L1 cache - result = fetch(99) + result = await fetch(99) assert result == {"result": "from_redis"} assert call_count == 0 - def test_bgcache_rehydrates_from_redis(self, redis_client): + async def test_bgcache_rehydrates_from_redis(self, redis_client): """Test BGCache retrieves existing Redis data without executing function on init.""" # Pre-populate Redis test_data = {"users": ["Alice", "Bob", "Charlie"]} @@ -760,7 +773,7 @@ def test_bgcache_rehydrates_from_redis(self, redis_client): l1_ttl=60, ), ) - def load_users(): + async def load_users(): nonlocal call_count call_count += 1 return {"users": ["New1", "New2"]} @@ -769,13 +782,13 @@ def load_users(): assert call_count == 0, "Function should not execute when data exists in Redis" # First call should hit L1 cache - result = load_users() 
+ result = await load_users() assert result == test_data assert call_count == 0 BGCache.shutdown(wait=False) - def test_ttlcache_executes_on_cache_miss(self, redis_client): + async def test_ttlcache_executes_on_cache_miss(self, redis_client): """Test TTLCache executes function when Redis is empty.""" redis_client.flushdb() @@ -790,22 +803,22 @@ def test_ttlcache_executes_on_cache_miss(self, redis_client): l1_ttl=60, ), ) - def compute(x): + async def compute(x): nonlocal call_count call_count += 1 return {"result": f"computed_{x}"} # First call should execute function (cache miss) - result = compute(42) + result = await compute(42) assert result == {"result": "computed_42"} assert call_count == 1 # Second call should hit L1 cache - result = compute(42) + result = await compute(42) assert result == {"result": "computed_42"} assert call_count == 1 - def test_swrcache_executes_on_cache_miss(self, redis_client): + async def test_swrcache_executes_on_cache_miss(self, redis_client): """Test SWRCache executes function when Redis is empty.""" redis_client.flushdb() @@ -821,22 +834,22 @@ def test_swrcache_executes_on_cache_miss(self, redis_client): l1_ttl=60, ), ) - def fetch(x): + async def fetch(x): nonlocal call_count call_count += 1 return {"result": f"fetched_{x}"} # First call should execute function (cache miss) - result = fetch(99) + result = await fetch(99) assert result == {"result": "fetched_99"} assert call_count == 1 # Second call should hit L1 cache - result = fetch(99) + result = await fetch(99) assert result == {"result": "fetched_99"} assert call_count == 1 - def test_bgcache_executes_on_cache_miss(self, redis_client): + async def test_bgcache_executes_on_cache_miss(self, redis_client): """Test BGCache executes function on init when Redis is empty.""" redis_client.flushdb() @@ -852,22 +865,23 @@ def test_bgcache_executes_on_cache_miss(self, redis_client): l1_ttl=60, ), ) - def load_data(): + async def load_data(): nonlocal call_count call_count += 1 return {"data": "fresh_load"} # Function should execute during init (cache miss) + await asyncio.sleep(0.1) assert call_count == 1 # First call should hit L1 cache - result = load_data() + result = await load_data() assert result == {"data": "fresh_load"} assert call_count == 1 BGCache.shutdown(wait=False) - def test_ttlcache_different_args_separate_entries(self, redis_client): + async def test_ttlcache_different_args_separate_entries(self, redis_client): """Test TTLCache creates separate cache entries for different arguments.""" # Pre-populate Redis with data for arg=10 test_data = {"result": "from_redis_10"} @@ -884,23 +898,23 @@ def test_ttlcache_different_args_separate_entries(self, redis_client): l1_ttl=60, ), ) - def compute(x): + async def compute(x): nonlocal call_count call_count += 1 return {"result": f"computed_{x}"} # Call with arg=10 should get from Redis - result = compute(10) + result = await compute(10) assert result == test_data assert call_count == 0 # Call with arg=20 should execute function (no Redis data) - result = compute(20) + result = await compute(20) assert result == {"result": "computed_20"} assert call_count == 1 # Call with arg=10 again should get from L1 - result = compute(10) + result = await compute(10) assert result == test_data assert call_count == 1 @@ -925,19 +939,20 @@ def test_redis_cache_hit_performance(self, redis_client): assert avg_time_ms < 20, f"Redis cache hit too slow: {avg_time_ms:.3f}ms" assert result == {"data": "test"} - def test_ttlcache_with_redis_performance(self, redis_client): + 
@pytest.mark.asyncio + async def test_ttlcache_with_redis_performance(self, redis_client): """Test TTLCache performance with Redis backend.""" cache = RedisCache(redis_client, prefix="perf_ttl:") @TTLCache.cached("item:{}", ttl=60, cache=cache) - def get_item(item_id: int): + async def get_item(item_id: int): return {"id": item_id} - get_item(1) + await get_item(1) start = time.perf_counter() for _ in range(1000): - get_item(1) + await get_item(1) duration = time.perf_counter() - start avg_time_ms = (duration / 1000) * 1000 diff --git a/tests/test_sync_support.py b/tests/test_sync_support.py new file mode 100644 index 0000000..e552c20 --- /dev/null +++ b/tests/test_sync_support.py @@ -0,0 +1,63 @@ +import asyncio +import time +import pytest +from advanced_caching import TTLCache, SWRCache, BGCache + + +def test_ttl_sync_remains_sync(): + @TTLCache.cached("ttl_sync", ttl=60) + def sync_fn(x): + return x + 1 + + assert not asyncio.iscoroutinefunction(sync_fn) + assert sync_fn(1) == 2 + + +@pytest.mark.asyncio +async def test_ttl_async_remains_async(): + @TTLCache.cached("ttl_async", ttl=60) + async def async_fn(x): + return x + 1 + + assert asyncio.iscoroutinefunction(async_fn) + assert await async_fn(1) == 2 + + +def test_swr_sync_remains_sync(): + @SWRCache.cached("swr_sync", ttl=60) + def sync_fn(x): + return x + 1 + + assert not asyncio.iscoroutinefunction(sync_fn) + assert sync_fn(1) == 2 + + +@pytest.mark.asyncio +async def test_swr_async_remains_async(): + @SWRCache.cached("swr_async", ttl=60) + async def async_fn(x): + return x + 1 + + assert asyncio.iscoroutinefunction(async_fn) + assert await async_fn(1) == 2 + + +def test_bg_sync_remains_sync(): + @BGCache.register_loader("bg_sync", interval_seconds=60) + def sync_loader(): + return 42 + + assert not asyncio.iscoroutinefunction(sync_loader) + assert sync_loader() == 42 + BGCache.shutdown() + + +@pytest.mark.asyncio +async def test_bg_async_remains_async(): + @BGCache.register_loader("bg_async", interval_seconds=60) + async def async_loader(): + return 42 + + assert asyncio.iscoroutinefunction(async_loader) + assert await async_loader() == 42 + BGCache.shutdown() diff --git a/uv.lock b/uv.lock index bfafc6e..5641f7e 100644 --- a/uv.lock +++ b/uv.lock @@ -8,7 +8,7 @@ resolution-markers = [ [[package]] name = "advanced-caching" -version = "0.1.6" +version = "0.2.0" source = { editable = "." 
} dependencies = [ { name = "apscheduler" }, @@ -27,6 +27,7 @@ redis = [ [package.dev-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "ruff" }, { name = "scalene" }, @@ -46,6 +47,7 @@ provides-extras = ["redis", "dev"] [package.metadata.requires-dev] dev = [ { name = "pytest", specifier = ">=8.2" }, + { name = "pytest-asyncio", specifier = ">=1.3.0" }, { name = "pytest-cov", specifier = ">=4.0" }, { name = "ruff", specifier = ">=0.14.8" }, { name = "scalene", specifier = ">=1.5.55" }, @@ -82,6 +84,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313 }, +] + [[package]] name = "certifi" version = "2025.11.12" @@ -907,6 +918,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801 }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075 }, +] + [[package]] name = "pytest-cov" version = "7.0.0" From ac502003009e26c88394689779f5f0b5422c6a63 Mon Sep 17 00:00:00 2001 From: AhmedGoudaa Date: Tue, 23 Dec 2025 16:37:20 +0400 Subject: [PATCH 2/4] asyncio support --- tests/test_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_correctness.py b/tests/test_correctness.py index 9736350..a3428d2 100644 --- a/tests/test_correctness.py +++ b/tests/test_correctness.py @@ -257,7 +257,7 @@ def load_data(): assert call_count["count"] == 0 # First call will execute the function since cache is empty - result = await load_data() + result = load_data() assert result == {"value": 1} assert call_count["count"] == 1 From fdc8d55f60d25f1318e6bab9c171ed5749dbc2f7 Mon Sep 17 00:00:00 2001 From: AhmedGoudaa Date: Tue, 23 Dec 2025 16:43:21 +0400 Subject: [PATCH 3/4] deterministic tests --- 
tests/test_correctness.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/test_correctness.py b/tests/test_correctness.py index a3428d2..35b53f7 100644 --- a/tests/test_correctness.py +++ b/tests/test_correctness.py @@ -306,9 +306,11 @@ async def load2(): async def test_error_handling(self): """Test error handler is called on failure.""" errors = [] + error_event = asyncio.Event() def error_handler(e): errors.append(e) + error_event.set() @BGCache.register_loader( "error_test", @@ -319,7 +321,10 @@ def error_handler(e): async def load_data(): raise ValueError("Test error") - await asyncio.sleep(0.1) + try: + await asyncio.wait_for(error_event.wait(), timeout=1.0) + except asyncio.TimeoutError: + pytest.fail("Error handler was not called within 1s") # Error should have been captured assert len(errors) == 1 @@ -329,19 +334,31 @@ async def load_data(): async def test_periodic_refresh(self): """Test that data refreshes periodically.""" call_count = {"count": 0} + load_event = asyncio.Event() - @BGCache.register_loader("periodic", interval_seconds=0.2, run_immediately=True) + @BGCache.register_loader("periodic", interval_seconds=0.1, run_immediately=True) async def load_data(): call_count["count"] += 1 + load_event.set() return {"value": call_count["count"]} # Wait for initial load - await asyncio.sleep(0.1) - assert call_count["count"] == 1 + try: + await asyncio.wait_for(load_event.wait(), timeout=1.0) + except asyncio.TimeoutError: + pytest.fail("Initial load did not complete within 1s") + + assert call_count["count"] >= 1 + initial_count = call_count["count"] # Wait for one refresh - await asyncio.sleep(0.3) - assert call_count["count"] >= 2 + load_event.clear() + try: + await asyncio.wait_for(load_event.wait(), timeout=1.0) + except asyncio.TimeoutError: + pytest.fail("Periodic refresh did not occur within 1s") + + assert call_count["count"] > initial_count # Get updated data result = await load_data() From 1c3cfe4b0f0144489c18115fb065a498d44dcc87 Mon Sep 17 00:00:00 2001 From: AhmedGoudaa Date: Tue, 23 Dec 2025 16:48:04 +0400 Subject: [PATCH 4/4] comments --- src/advanced_caching/decorators.py | 3 --- tests/test_integration_redis.py | 4 ++-- tests/test_sync_support.py | 1 - 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/advanced_caching/decorators.py b/src/advanced_caching/decorators.py index e5d6437..ec2efeb 100644 --- a/src/advanced_caching/decorators.py +++ b/src/advanced_caching/decorators.py @@ -10,11 +10,8 @@ from __future__ import annotations import asyncio -import atexit import logging -import os import time -from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta from typing import Callable, TypeVar diff --git a/tests/test_integration_redis.py b/tests/test_integration_redis.py index b9461de..3603a28 100644 --- a/tests/test_integration_redis.py +++ b/tests/test_integration_redis.py @@ -30,7 +30,7 @@ @pytest.fixture(autouse=True) -async def reset_scheduler(): +def reset_scheduler(): yield BGCache.shutdown(wait=False) @@ -249,7 +249,7 @@ async def get_product(*, product_id: int): assert result1 == {"id": 100, "name": "Product100"} assert calls["n"] == 1 - result2 = await get_product(product_id=100) + await get_product(product_id=100) assert calls["n"] == 1 diff --git a/tests/test_sync_support.py b/tests/test_sync_support.py index e552c20..f8c6a03 100644 --- a/tests/test_sync_support.py +++ b/tests/test_sync_support.py @@ -1,5 +1,4 @@ import asyncio -import time import pytest 
from advanced_caching import TTLCache, SWRCache, BGCache