diff --git a/agentlightning/client.py b/agentlightning/client.py index 0848ce380..b232170be 100644 --- a/agentlightning/client.py +++ b/agentlightning/client.py @@ -22,6 +22,8 @@ logger = logging.getLogger(__name__) +DEFAULT_CACHE_CAPACITY = 100 + class AgentLightningClient: """Client wrapper for the legacy version-aware Agent Lightning server. @@ -64,7 +66,10 @@ def __init__(self, endpoint: str, poll_interval: float = 5.0, timeout: float = 1 self.task_count = 0 self.poll_interval = poll_interval self.timeout = timeout - self._resource_cache: Dict[str, ResourcesUpdate] = {} # TODO: mechanism to evict cache + + from agentlightning.utils.cache import LRUCache + + self._resource_cache: Dict[str, ResourcesUpdate] = LRUCache(capacity=DEFAULT_CACHE_CAPACITY) self._default_headers = {"X-AgentLightning-Client": "true"} async def _request_json_async(self, url: str) -> Optional[Dict[str, Any]]: diff --git a/agentlightning/server.py b/agentlightning/server.py index b1343fcb1..ae5320cd8 100644 --- a/agentlightning/server.py +++ b/agentlightning/server.py @@ -32,6 +32,8 @@ logger = logging.getLogger(__name__) +DEFAULT_CACHE_CAPACITY = 100 + class ServerDataStore: """Async-safe container for in-memory server state. @@ -50,8 +52,10 @@ def __init__(self): self._processing_tasks: Dict[str, Task] = {} # Currently processing tasks self._completed_rollouts: Dict[str, RolloutLegacy] = {} + from agentlightning.utils.cache import LRUCache + # Store for versioned resources - self._resource_versions: Dict[str, NamedResources] = {} + self._resource_versions: Dict[str, NamedResources] = LRUCache(capacity=DEFAULT_CACHE_CAPACITY) self._latest_resources_id: Optional[str] = None # Locks for thread-safe access diff --git a/agentlightning/utils/cache.py b/agentlightning/utils/cache.py new file mode 100644 index 000000000..92151cdf5 --- /dev/null +++ b/agentlightning/utils/cache.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections import OrderedDict +from typing import Generic, Optional, TypeVar + +K = TypeVar("K") +V = TypeVar("V") + + +class LRUCache(OrderedDict[K, V]): + """A simple LRU (Least Recently Used) cache implementation using OrderedDict. + + This cache has a fixed capacity. When the cache is full, adding a new item + discards the least recently used item. + + Accessing an item (via `__getitem__` or `get`) moves it to the end, marking + it as recently used. + """ + + def __init__(self, capacity: int, *args, **kwargs): + self.capacity = capacity + super().__init__(*args, **kwargs) + + def __getitem__(self, key: K) -> V: + value = super().__getitem__(key) + self.move_to_end(key) + return value + + def __setitem__(self, key: K, value: V): + if key in self: + self.move_to_end(key) + super().__setitem__(key, value) + if len(self) > self.capacity: + self.popitem(last=False) + + def get(self, key: K, default: Optional[V] = None) -> Optional[V]: + if key in self: + value = super().__getitem__(key) + self.move_to_end(key) + return value + return default diff --git a/tests/test_cache_fix.py b/tests/test_cache_fix.py new file mode 100644 index 000000000..5da731c06 --- /dev/null +++ b/tests/test_cache_fix.py @@ -0,0 +1,48 @@ + +import pytest +from agentlightning.utils.cache import LRUCache + +def test_lru_cache_capacity(): + cache = LRUCache(capacity=2) + cache["a"] = 1 + cache["b"] = 2 + assert len(cache) == 2 + + cache["c"] = 3 + assert len(cache) == 2 + assert "a" not in cache + assert "b" in cache + assert "c" in cache + +def test_lru_cache_access_order(): + cache = LRUCache(capacity=2) + cache["a"] = 1 + cache["b"] = 2 + + # Access 'a' to make it most recently used + _ = cache["a"] + + # Add 'c', should evict 'b' (least recently used) + cache["c"] = 3 + + assert "b" not in cache + assert "a" in cache + assert "c" in cache + +def test_integration_client_import(): + try: + from agentlightning.client import AgentLightningClient + client = AgentLightningClient(endpoint="http://localhost:8000") + assert isinstance(client._resource_cache, LRUCache) + assert client._resource_cache.capacity == 100 + except ImportError: + pytest.fail("Could not import AgentLightningClient or verify LRUCache integration") + +def test_integration_server_import(): + try: + from agentlightning.server import ServerDataStore + store = ServerDataStore() + assert isinstance(store._resource_versions, LRUCache) + assert store._resource_versions.capacity == 100 + except ImportError: + pytest.fail("Could not import ServerDataStore or verify LRUCache integration")