diff --git a/examples/slice_reader.py b/examples/slice_reader.py new file mode 100644 index 0000000..5224210 --- /dev/null +++ b/examples/slice_reader.py @@ -0,0 +1,25 @@ +import libcachesim as lcs +import logging +logging.basicConfig(level=logging.DEBUG) + + +URI = "s3://cache-datasets/cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst" +reader = lcs.TraceReader( + trace = URI, + trace_type = lcs.TraceType.ORACLE_GENERAL_TRACE, + reader_init_params = lcs.ReaderInitParam(ignore_obj_size=False) +) + +for req in reader[:3]: + print(req.obj_id, req.obj_size) + +for req in reader[1:4]: + print(req.obj_id, req.obj_size) + +reader.reset() +read_n_req = 4 +for req in reader: + if read_n_req <= 0: + break + print(req.obj_id, req.obj_size) + read_n_req -= 1 \ No newline at end of file diff --git a/libcachesim/cache.py b/libcachesim/cache.py index 89ba182..7102c4f 100644 --- a/libcachesim/cache.py +++ b/libcachesim/cache.py @@ -81,6 +81,9 @@ def get_occupied_byte(self) -> int: def get_n_obj(self) -> int: return self._cache.get_n_obj() + + def set_cache_size(self, new_size: int) -> None: + self._cache.set_cache_size(new_size) def print_cache(self) -> str: return self._cache.print_cache() diff --git a/libcachesim/synthetic_reader.py b/libcachesim/synthetic_reader.py index b2e4d10..e746dc0 100644 --- a/libcachesim/synthetic_reader.py +++ b/libcachesim/synthetic_reader.py @@ -13,6 +13,28 @@ from .protocols import ReaderProtocol +class SyntheticReaderSliceIterator: + """Iterator for sliced SyntheticReader.""" + + def __init__(self, reader: "SyntheticReader", start: int, stop: int, step: int): + self.reader = reader + self.start = start + self.stop = stop + self.step = step + self.current = start + + def __iter__(self) -> Iterator[Request]: + return self + + def __next__(self) -> Request: + if self.current >= self.stop: + raise StopIteration + + req = self.reader[self.current] + self.current += self.step + return req + + class SyntheticReader(ReaderProtocol): 
"""Efficient synthetic request generator supporting multiple distributions""" @@ -206,19 +228,29 @@ def __next__(self) -> Request: return self.read_one_req() - def __getitem__(self, index: int) -> Request: - """Support index access""" - if index < 0 or index >= self.num_of_req: - raise IndexError("Index out of range") + def __getitem__(self, key: Union[int, slice]) -> Union[Request, SyntheticReaderSliceIterator]: + """Support index and slice access""" + if isinstance(key, slice): + # Handle slice + start, stop, step = key.indices(self.num_of_req) + return SyntheticReaderSliceIterator(self, start, stop, step) + elif isinstance(key, int): + # Handle single index + if key < 0: + key += self.num_of_req + if key < 0 or key >= self.num_of_req: + raise IndexError("Index out of range") - req = Request() - obj_id = self.obj_ids[index] - req.obj_id = obj_id - req.obj_size = self.obj_size - req.clock_time = index * self.time_span // self.num_of_req - req.op = ReqOp.OP_READ - req.valid = True - return req + req = Request() + obj_id = self.obj_ids[key] + req.obj_id = obj_id + req.obj_size = self.obj_size + req.clock_time = key * self.time_span // self.num_of_req + req.op = ReqOp.OP_READ + req.valid = True + return req + else: + raise TypeError("SyntheticReader indices must be integers or slices") def _gen_zipf(m: int, alpha: float, n: int, start: int = 0) -> np.ndarray: diff --git a/libcachesim/trace_reader.py b/libcachesim/trace_reader.py index 53af0e4..b6aa873 100644 --- a/libcachesim/trace_reader.py +++ b/libcachesim/trace_reader.py @@ -1,14 +1,15 @@ """Wrapper of Reader with S3 support.""" +from __future__ import annotations import logging -from typing import overload, Union, Optional +from typing import overload, Union, Optional, Any from collections.abc import Iterator from urllib.parse import urlparse from .protocols import ReaderProtocol from .libcachesim_python import ( TraceType, - SamplerType, + TraceFormat, Request, ReaderInitParam, Reader, @@ -21,6 +22,78 @@ logger 
= logging.getLogger(__name__) +class TraceReaderSliceIterator: + """Iterator for sliced TraceReader.""" + + def __init__(self, reader: "TraceReader", start: int, stop: int, step: int): + # Clone the reader to avoid side effects on the original + self.reader = reader.clone() + self.start = start + self.stop = stop + self.step = step + self.current = start + + # Initialize position: reset and skip to start position once + self.reader.reset() + if start > 0: + self._skip_to_start_position(start) + + def __iter__(self) -> Iterator[Request]: + return self + + def __next__(self) -> Request: + if self.current >= self.stop: + raise StopIteration + + # Read the current request + try: + req = self.reader.read_one_req() + except RuntimeError: + raise StopIteration + + # Advance to next position based on step + if self.step > 1: + self._skip_requests(self.step - 1) + + self.current += self.step + return req + + def _skip_to_start_position(self, position: int) -> None: + """Skip to the start position efficiently.""" + if not self.reader._reader.is_zstd_file: + # Try using skip_n_req for non-zstd files + skipped = self.reader.skip_n_req(position) + if skipped != position: + # If we couldn't skip the expected number, simulate the rest + remaining = position - skipped + self._simulate_skip(remaining) + else: + # For zstd files, always simulate + self._simulate_skip(position) + + def _skip_requests(self, n: int) -> None: + """Skip n requests efficiently.""" + if not self.reader._reader.is_zstd_file: + # Try using skip_n_req for non-zstd files + skipped = self.reader.skip_n_req(n) + if skipped != n: + # If we couldn't skip all, we're likely at EOF + self.current = self.stop # Mark as done + else: + # For zstd files, simulate + self._simulate_skip(n) + + def _simulate_skip(self, n: int) -> None: + """Simulate skip by reading requests one by one.""" + for _ in range(n): + try: + self.reader.read_one_req() + except RuntimeError: + # If we can't read more, we're at EOF + self.current = 
self.stop # Mark as done + break + + class TraceReader(ReaderProtocol): _reader: Reader @@ -302,10 +375,51 @@ def __next__(self) -> Request: raise StopIteration return req - def __getitem__(self, index: int) -> Request: - if index < 0 or index >= self._reader.get_num_of_req(): - raise IndexError("Index out of range") - self._reader.reset() - self._reader.skip_n_req(index) - req = Request() - return self._reader.read_one_req(req) + def __getitem__(self, key: Union[int, slice]) -> Union[Request, TraceReaderSliceIterator]: + if isinstance(key, slice): + # Handle slice + total_len = self._reader.get_num_of_req() + start, stop, step = key.indices(total_len) + return TraceReaderSliceIterator(self, start, stop, step) + elif isinstance(key, int): + # Handle single index + total_len = self._reader.get_num_of_req() + if key < 0: + key += total_len + if key < 0 or key >= total_len: + raise IndexError("Index out of range") + + self._reader.reset() + + # Try to skip to the target position + if key > 0: + if not self._reader.is_zstd_file: + # For non-zstd files, try skip_n_req and check return value + skipped = self._reader.skip_n_req(key) + if skipped != key: + # If we couldn't skip the expected number, simulate the rest + remaining = key - skipped + self._simulate_skip_single(remaining) + else: + # For zstd files, always simulate + self._simulate_skip_single(key) + + # Read the target request + req = Request() + ret = self._reader.read_one_req(req) + if ret != 0: + raise IndexError(f"Cannot read request at index {key}") + return req + else: + raise TypeError("TraceReader indices must be integers or slices") + + def _simulate_skip_single(self, n: int) -> None: + """Simulate skip by reading requests one by one for single index access.""" + for i in range(n): + req = Request() + ret = self._reader.read_one_req(req) + if ret != 0: + raise IndexError(f"Cannot skip to position, reached EOF at {i}") + + # Note: Removed old inefficient methods _can_use_skip_n_req and 
_simulate_skip_and_read_single + # The new implementation is more efficient and handles skip_n_req return values properly diff --git a/src/export_cache.cpp b/src/export_cache.cpp index d7ec858..e34ddf5 100644 --- a/src/export_cache.cpp +++ b/src/export_cache.cpp @@ -352,6 +352,10 @@ void export_cache(py::module& m) { .def("get_occupied_byte", [](cache_t& self) { return self.get_occupied_byte(&self); }) .def("get_n_obj", [](cache_t& self) { return self.get_n_obj(&self); }) + .def( + "set_cache_size", + [](cache_t& self, uint64_t new_size) { self.cache_size = new_size; }, + "new_size"_a) .def("print_cache", [](cache_t& self) { // Capture stdout to return as string std::ostringstream captured_output; diff --git a/src/export_reader.cpp b/src/export_reader.cpp index eff5b31..8361ff8 100644 --- a/src/export_reader.cpp +++ b/src/export_reader.cpp @@ -98,6 +98,13 @@ void export_reader(py::module& m) { .value("UNKNOWN_TRACE", trace_type_e::UNKNOWN_TRACE) .export_values(); + // Trace format enumeration + py::enum_<trace_format_e>(m, "TraceFormat") + .value("BINARY_TRACE_FORMAT", trace_format_e::BINARY_TRACE_FORMAT) + .value("TXT_TRACE_FORMAT", trace_format_e::TXT_TRACE_FORMAT) + .value("INVALID_TRACE_FORMAT", trace_format_e::INVALID_TRACE_FORMAT) + .export_values(); + py::enum_<read_direction>(m, "ReadDirection") .value("READ_FORWARD", read_direction::READ_FORWARD) .value("READ_BACKWARD", read_direction::READ_BACKWARD) @@ -302,11 +309,9 @@ void export_reader(py::module& m) { .def( "skip_n_req", [](reader_t& self, int n) { - int ret = skip_n_req(&self, n); - if (ret != 0) { - throw std::runtime_error("Failed to skip requests"); - } - return ret; + int count = skip_n_req(&self, n); + // Return the actual number of requests skipped + return count; }, "n"_a) .def("read_one_req_above",