diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index d2041de6..b4858038 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -19,6 +19,7 @@ from collections.abc import Iterator, Sequence import copy import functools +import math from multiprocessing import queues import queue import threading @@ -69,6 +70,10 @@ def _initialize_prefetch_stats( ) +def _is_batch_iter_pushdown_experiment_enabled() -> bool: + return False + + @dataset_stats.trace_input_pipeline_prefetch def _getitem( stats: dataset_stats.Stats, parent: dataset.MapDataset[T], index: int @@ -77,6 +82,16 @@ def _getitem( return stats.record_bytes_consumed(parent[index]) +@dataset_stats.trace_input_pipeline_prefetch +def _getitems( + stats: dataset_stats.Stats, + parent: dataset.MapDataset[T], + indices: list[int], +) -> list[T]: + """Helper to record the memory usage of the elements before prefetching.""" + return [stats.record_bytes_consumed(x) for x in parent._getitems(indices)] # pylint: disable=protected-access + + @typing.runtime_checkable class SupportsInPlaceSlicing(Protocol): """Datasets that support mutation by setting the processed data slice.""" @@ -114,6 +129,10 @@ def __str__(self) -> str: ) def __iter__(self) -> dataset.DatasetIterator[T]: + if _is_batch_iter_pushdown_experiment_enabled(): + return _BatchedPrefetchDatasetIterator( + self._parent, self._read_options, self._allow_nones + ) return PrefetchDatasetIterator( self._parent, self._read_options, self._allow_nones ) @@ -141,14 +160,15 @@ def __init__( self._read_options = read_options self._next_returned_index = 0 self._next_buffered_index = 0 + # Buffer of (future, batch_size) tuples for prefetched elements. self._buffer = collections.deque() self._lock = threading.Lock() - self._prefetch_buffer_size = ( + self._target_buffer_size = ( read_options.prefetch_buffer_size if read_options.num_threads > 0 else 0 ) self._num_threads = read_options.num_threads self._allow_nones = allow_nones - if self._prefetch_buffer_size > 0: + if self._target_buffer_size > 0: self._executor = futures.ThreadPoolExecutor( self._num_threads, thread_name_prefix="grain-prefetch" ) @@ -194,25 +214,27 @@ def __next__(self) -> T: if self._next_returned_index == self._dataset_length: break with self._lock, timer: - if self._prefetch_buffer_size > 0: + if self._target_buffer_size > 0: if not self._buffer: # Fill the buffer on the first iteration. self._fill_buffer() - element = self._buffer.popleft() + future, _ = self._buffer.popleft() # Prefetch elements until the buffer is full again. self._fill_buffer() - element = element.result() + element = future.result() else: # In case prefetch buffer size was decreased, we still want to consume # the already prefetched elements. if self._buffer: - element = self._buffer.popleft().result() + future, _ = self._buffer.popleft() + element = future.result() else: element = self._stats.record_bytes_consumed( self._map_parent[self._next_returned_index] ) self._next_buffered_index += 1 self._next_returned_index += 1 + return_element = self._allow_nones or element is not None self._threshold_checker.check(return_element) if return_element: @@ -224,23 +246,26 @@ def __next__(self) -> T: def get_state(self): return {"next_index": self._next_returned_index} + def _set_state_helper(self, state): + self._next_returned_index = state["next_index"] + self._next_buffered_index = self._next_returned_index + if ( + self._next_returned_index < 0 + or self._next_returned_index > self._dataset_length + ): + raise IndexError( + f"Checkpoint `next_index` {self._next_returned_index} is out of" + f" range for dataset of length {self._dataset_length}." + ) + if self._target_buffer_size > 0: + # Cancel all pending futures in the buffer. + while self._buffer: + future, _ = self._buffer.popleft() + future.cancel() + def set_state(self, state): with self._lock: - self._next_returned_index = state["next_index"] - self._next_buffered_index = self._next_returned_index - if ( - self._next_returned_index < 0 - or self._next_returned_index > self._dataset_length - ): - raise IndexError( - f"Checkpoint `next_index` {self._next_returned_index} is out of" - f" range for dataset of length {self._dataset_length}." - ) - if self._prefetch_buffer_size > 0: - # Cancel all pending futures in the buffer. - while self._buffer: - future = self._buffer.popleft() - future.cancel() + self._set_state_helper(state) def _get_next_index(self) -> int: return self._next_returned_index @@ -254,12 +279,12 @@ def __str__(self) -> str: f" allow_nones={self._allow_nones})" ) - def set_prefetch_buffer_size(self, buffer_size: int): - self._prefetch_buffer_size = buffer_size + def set_target_buffer_size(self, buffer_size: int): + self._target_buffer_size = buffer_size # The executor is created in the constructor only if the prefetch buffer # size is greater than 0. If the user changes the prefetch buffer size, we # need to create or destroy the executor accordingly. - if self._prefetch_buffer_size > 0 and not hasattr(self, "_executor"): + if self._target_buffer_size > 0 and not hasattr(self, "_executor"): if self._num_threads == 0: raise ValueError( "num_threads must be greater than 0 when prefetch buffer size is" @@ -268,7 +293,7 @@ def set_prefetch_buffer_size(self, buffer_size: int): self._executor = futures.ThreadPoolExecutor( self._num_threads, thread_name_prefix="grain-prefetch" ) - elif self._prefetch_buffer_size == 0 and hasattr(self, "_executor"): + elif self._target_buffer_size == 0 and hasattr(self, "_executor"): self._executor.shutdown() delattr(self, "_executor") @@ -292,21 +317,22 @@ def set_num_threads(self, num_threads: int) -> None: def _fill_buffer(self): while ( - len(self._buffer) < self._prefetch_buffer_size + len(self._buffer) < self._target_buffer_size and self._next_buffered_index < self._dataset_length ): # Note that we trigger creation of `_stats` in this (single) thread, it is # important because the stats initialization is not thread-safe. - self._buffer.append( + self._buffer.append(( self._executor.submit( functools.partial(_getitem, self._stats, self._map_parent), self._next_buffered_index, - ) - ) + ), + 1, # batch_size = 1 when batch pushdown is not used. + )) self._next_buffered_index += 1 def start_prefetch(self): - if self._prefetch_buffer_size > 0: + if self._target_buffer_size > 0: self._fill_buffer() def close(self) -> None: @@ -319,10 +345,122 @@ def close(self) -> None: self._executor.shutdown(wait=False) # Cancel all pending futures in the buffer. while self._buffer: - future = self._buffer.popleft() + future, _ = self._buffer.popleft() future.cancel() +class _BatchedPrefetchDatasetIterator(PrefetchDatasetIterator[T]): + """Iterator that performs prefetching in batches using a thread pool.""" + + def __init__( + self, + parent: dataset.MapDataset[T], + read_options: grain_options.ReadOptions, + allow_nones: bool, + ): + super().__init__(parent, read_options, allow_nones) + # The number of elements to prefetch in each batch. + self._batch_pushdown_size = ( + int(math.ceil(self._target_buffer_size / self._num_threads)) + if self._target_buffer_size > 0 + else 1 + ) + # Queue of elements from the most recently completed prefetch batch. + self._current_batch = collections.deque() + # Total count of elements across all pending futures in _buffer. + self._total_buffered_count = 0 + + @dataset_stats.record_next_duration_if_output + @dataset_stats.trace_input_pipeline_next( + stage_category=dataset_stats.IPL_CAT_PREFETCH + ) + def __next__(self) -> T: + self._assert_not_closed() + # The time recorded here is the time spent in prefetch node to return an + # element, including the time spent in parent node. + timer = dataset_stats.Timer() + # We loop here to skip all None elements (in case the underlying dataset + # is sparse), if self._allow_nones = False, else we return Nones too. + while True: + with self._lock, timer: + if not self._current_batch: + if self._next_returned_index == self._dataset_length: + break + if self._target_buffer_size > 0: + if not self._buffer: + # Fill the buffer on the first iteration. + self._fill_buffer() + future, batch_size = self._buffer.popleft() + self._total_buffered_count -= batch_size + # Prefetch elements until the buffer is full again. + self._fill_buffer() + batch = future.result() + self._current_batch.extend(batch) + else: + # In case prefetch buffer size was decreased, we still want to + # consume the already prefetched elements. + if self._buffer: + future, batch_size = self._buffer.popleft() + batch = future.result() + self._total_buffered_count -= batch_size + self._current_batch.extend(batch) + else: + element = self._stats.record_bytes_consumed( + self._map_parent[self._next_returned_index] + ) + self._next_buffered_index += 1 + self._current_batch.append(element) + + element = self._current_batch.popleft() + self._next_returned_index += 1 + + return_element = self._allow_nones or element is not None + self._threshold_checker.check(return_element) + if return_element: + with self._stats.record_self_time(offset_ns=timer.value()): + element = self._stats.record_bytes_produced(element) + return self._stats.record_output_spec(element) + raise StopIteration + + def _set_state_helper(self, state): + super()._set_state_helper(state) + self._current_batch.clear() + self._total_buffered_count = 0 + + def _fill_buffer(self): + while ( + self._total_buffered_count < self._target_buffer_size + and self._next_buffered_index < self._dataset_length + ): + batch_size = min( + self._batch_pushdown_size, + self._dataset_length - self._next_buffered_index, + self._target_buffer_size - self._total_buffered_count, + ) + indices = list( + range( + self._next_buffered_index, self._next_buffered_index + batch_size + ) + ) + # Note that we trigger creation of `_stats` in this (single) thread, it is + # important because the stats initialization is not thread-safe. + self._buffer.append(( + self._executor.submit( + functools.partial(_getitems, self._stats, self._map_parent), + indices, + ), + batch_size, + )) + self._next_buffered_index += batch_size + self._total_buffered_count += batch_size + + def __str__(self) -> str: + return ( + f"_BatchedPrefetchDatasetIterator(read_options={self._read_options}," + f" allow_nones={self._allow_nones})" + ) + + def _set_slice_iter_dataset( ds: dataset.IterDataset, sl: slice, @@ -729,6 +867,7 @@ def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool: it, ( PrefetchDatasetIterator, + _BatchedPrefetchDatasetIterator, ThreadPrefetchDatasetIterator, interleave.InterleaveDatasetIterator, ), diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 0855b0b4..0a6b34cd 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -148,7 +148,7 @@ def test_prefetch_does_not_buffer_unnecessary_elements(self): _ = [next(ds_iter) for _ in range(5)] self.assertEmpty(ds_iter._buffer) # iterated through all elements - def test_set_prefetch_buffer_size_0_to_positive(self): + def test_set_target_buffer_size_0_to_positive(self): prefetch_lazy_iter_ds = prefetch.PrefetchIterDataset( self.range_ds, read_options=options.ReadOptions(prefetch_buffer_size=0) ) @@ -158,19 +158,19 @@ def test_set_prefetch_buffer_size_0_to_positive(self): # With prefetch_buffer_size=0, executor is not created. self.assertFalse(hasattr(ds_iter, '_executor')) - self.assertEqual(ds_iter._prefetch_buffer_size, 0) + self.assertEqual(ds_iter._target_buffer_size, 0) self.assertEqual(next(ds_iter), 0) # Setting prefetch_buffer_size to 2. - ds_iter.set_prefetch_buffer_size(2) - self.assertEqual(ds_iter._prefetch_buffer_size, 2) + ds_iter.set_target_buffer_size(2) + self.assertEqual(ds_iter._target_buffer_size, 2) self.assertEqual(next(ds_iter), 1) self.assertTrue(hasattr(ds_iter, '_executor')) self.assertLen(ds_iter._buffer, 2) self.assertEqual(next(ds_iter), 2) self.assertLen(ds_iter._buffer, 2) - def test_set_prefetch_buffer_size_positive_to_0(self): + def test_set_target_buffer_size_positive_to_0(self): prefetch_lazy_iter_ds = prefetch.PrefetchIterDataset( self.range_ds, read_options=options.ReadOptions(prefetch_buffer_size=2) ) @@ -178,13 +178,13 @@ def test_set_prefetch_buffer_size_positive_to_0(self): self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) - self.assertEqual(ds_iter._prefetch_buffer_size, 2) + self.assertEqual(ds_iter._target_buffer_size, 2) self.assertEqual(next(ds_iter), 0) self.assertLen(ds_iter._buffer, 2) # Setting prefetch_buffer_size to 0. - ds_iter.set_prefetch_buffer_size(0) - self.assertEqual(ds_iter._prefetch_buffer_size, 0) + ds_iter.set_target_buffer_size(0) + self.assertEqual(ds_iter._target_buffer_size, 0) # Should consume buffer first. self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 1) @@ -194,7 +194,7 @@ def test_set_prefetch_buffer_size_positive_to_0(self): self.assertEqual(next(ds_iter), 3) self.assertEmpty(ds_iter._buffer) - def test_set_prefetch_buffer_size_increase(self): + def test_set_target_buffer_size_increase(self): prefetch_lazy_iter_ds = prefetch.PrefetchIterDataset( self.range_ds, read_options=options.ReadOptions(prefetch_buffer_size=1) ) @@ -202,19 +202,19 @@ def test_set_prefetch_buffer_size_increase(self): self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) - self.assertEqual(ds_iter._prefetch_buffer_size, 1) + self.assertEqual(ds_iter._target_buffer_size, 1) self.assertEqual(next(ds_iter), 0) self.assertLen(ds_iter._buffer, 1) # Setting prefetch_buffer_size to 2. - ds_iter.set_prefetch_buffer_size(2) - self.assertEqual(ds_iter._prefetch_buffer_size, 2) + ds_iter.set_target_buffer_size(2) + self.assertEqual(ds_iter._target_buffer_size, 2) self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 2) self.assertEqual(next(ds_iter), 2) self.assertLen(ds_iter._buffer, 2) - def test_set_prefetch_buffer_size_decrease(self): + def test_set_target_buffer_size_decrease(self): prefetch_lazy_iter_ds = prefetch.PrefetchIterDataset( self.range_ds, read_options=options.ReadOptions(prefetch_buffer_size=2) ) @@ -222,13 +222,13 @@ def test_set_prefetch_buffer_size_decrease(self): self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) - self.assertEqual(ds_iter._prefetch_buffer_size, 2) + self.assertEqual(ds_iter._target_buffer_size, 2) self.assertEqual(next(ds_iter), 0) self.assertLen(ds_iter._buffer, 2) # Setting prefetch_buffer_size to 1. - ds_iter.set_prefetch_buffer_size(1) - self.assertEqual(ds_iter._prefetch_buffer_size, 1) + ds_iter.set_target_buffer_size(1) + self.assertEqual(ds_iter._target_buffer_size, 1) self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 1) self.assertEqual(next(ds_iter), 2)