201 changes: 170 additions & 31 deletions grain/_src/python/dataset/transformations/prefetch.py
@@ -19,6 +19,7 @@
from collections.abc import Iterator, Sequence
import copy
import functools
import math
from multiprocessing import queues
import queue
import threading
@@ -69,6 +70,10 @@ def _initialize_prefetch_stats(
)


def _is_batch_iter_pushdown_experiment_enabled() -> bool:
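"""Returns whether the batch iterator pushdown experiment is enabled."""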
return False


@dataset_stats.trace_input_pipeline_prefetch
def _getitem(
stats: dataset_stats.Stats, parent: dataset.MapDataset[T], index: int
@@ -77,6 +82,16 @@ def _getitem(
return stats.record_bytes_consumed(parent[index])


@dataset_stats.trace_input_pipeline_prefetch
def _getitems(
stats: dataset_stats.Stats,
parent: dataset.MapDataset[T],
indices: list[int],
) -> list[T]:
"""Helper to record the memory usage of the elements before prefetching."""
return [stats.record_bytes_consumed(x) for x in parent._getitems(indices)] # pylint: disable=protected-access
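
# Hypothetical illustration (not part of this change): given
# ds = dataset.MapDataset.range(10), _getitems(stats, ds, [2, 3, 4]) returns
# [2, 3, 4] while recording the bytes consumed for each element.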


@typing.runtime_checkable
class SupportsInPlaceSlicing(Protocol):
"""Datasets that support mutation by setting the processed data slice."""
@@ -114,6 +129,10 @@ def __str__(self) -> str:
)

def __iter__(self) -> dataset.DatasetIterator[T]:
if _is_batch_iter_pushdown_experiment_enabled():
return _BatchedPrefetchDatasetIterator(
self._parent, self._read_options, self._allow_nones
)
return PrefetchDatasetIterator(
self._parent, self._read_options, self._allow_nones
)
@@ -141,14 +160,15 @@ def __init__(
self._read_options = read_options
self._next_returned_index = 0
self._next_buffered_index = 0
# Buffer of (future, batch_size) tuples for prefetched elements.
self._buffer = collections.deque()
self._lock = threading.Lock()
self._prefetch_buffer_size = (
self._target_buffer_size = (
read_options.prefetch_buffer_size if read_options.num_threads > 0 else 0
)
self._num_threads = read_options.num_threads
self._allow_nones = allow_nones
if self._prefetch_buffer_size > 0:
if self._target_buffer_size > 0:
self._executor = futures.ThreadPoolExecutor(
self._num_threads, thread_name_prefix="grain-prefetch"
)
@@ -194,25 +214,27 @@ def __next__(self) -> T:
if self._next_returned_index == self._dataset_length:
break
with self._lock, timer:
if self._prefetch_buffer_size > 0:
if self._target_buffer_size > 0:
if not self._buffer:
# Fill the buffer on the first iteration.
self._fill_buffer()
element = self._buffer.popleft()
future, _ = self._buffer.popleft()
# Prefetch elements until the buffer is full again.
self._fill_buffer()
element = element.result()
element = future.result()
else:
# If the prefetch buffer size was decreased, we still want to consume the
# already-prefetched elements.
if self._buffer:
element = self._buffer.popleft().result()
future, _ = self._buffer.popleft()
element = future.result()
else:
element = self._stats.record_bytes_consumed(
self._map_parent[self._next_returned_index]
)
self._next_buffered_index += 1
self._next_returned_index += 1

return_element = self._allow_nones or element is not None
self._threshold_checker.check(return_element)
if return_element:
@@ -224,23 +246,26 @@
def get_state(self):
return {"next_index": self._next_returned_index}

def _set_state_helper(self, state):
self._next_returned_index = state["next_index"]
self._next_buffered_index = self._next_returned_index
if (
self._next_returned_index < 0
or self._next_returned_index > self._dataset_length
):
raise IndexError(
f"Checkpoint `next_index` {self._next_returned_index} is out of"
f" range for dataset of length {self._dataset_length}."
)
if self._target_buffer_size > 0:
# Cancel all pending futures in the buffer.
while self._buffer:
future, _ = self._buffer.popleft()
future.cancel()
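# Dropping the queued futures here means a restored iterator never returns
# elements that were prefetched for the old position.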

def set_state(self, state):
with self._lock:
self._next_returned_index = state["next_index"]
self._next_buffered_index = self._next_returned_index
if (
self._next_returned_index < 0
or self._next_returned_index > self._dataset_length
):
raise IndexError(
f"Checkpoint `next_index` {self._next_returned_index} is out of"
f" range for dataset of length {self._dataset_length}."
)
if self._prefetch_buffer_size > 0:
# Cancel all pending futures in the buffer.
while self._buffer:
future = self._buffer.popleft()
future.cancel()
self._set_state_helper(state)

def _get_next_index(self) -> int:
return self._next_returned_index
@@ -254,12 +279,12 @@ def __str__(self) -> str:
f" allow_nones={self._allow_nones})"
)

def set_prefetch_buffer_size(self, buffer_size: int):
self._prefetch_buffer_size = buffer_size
def set_target_buffer_size(self, buffer_size: int):
self._target_buffer_size = buffer_size
# The executor is created in the constructor only if the prefetch buffer
# size is greater than 0. If the user changes the prefetch buffer size, we
# need to create or destroy the executor accordingly.
if self._prefetch_buffer_size > 0 and not hasattr(self, "_executor"):
if self._target_buffer_size > 0 and not hasattr(self, "_executor"):
if self._num_threads == 0:
raise ValueError(
"num_threads must be greater than 0 when prefetch buffer size is"
@@ -268,7 +293,7 @@ def set_prefetch_buffer_size(self, buffer_size: int):
self._executor = futures.ThreadPoolExecutor(
self._num_threads, thread_name_prefix="grain-prefetch"
)
elif self._prefetch_buffer_size == 0 and hasattr(self, "_executor"):
elif self._target_buffer_size == 0 and hasattr(self, "_executor"):
self._executor.shutdown()
delattr(self, "_executor")

@@ -292,21 +317,22 @@ def set_num_threads(self, num_threads: int) -> None:

def _fill_buffer(self):
while (
len(self._buffer) < self._prefetch_buffer_size
len(self._buffer) < self._target_buffer_size
and self._next_buffered_index < self._dataset_length
):
# Note that we trigger creation of `_stats` in this (single) thread; this is
# important because the stats initialization is not thread-safe.
self._buffer.append(
self._buffer.append((
self._executor.submit(
functools.partial(_getitem, self._stats, self._map_parent),
self._next_buffered_index,
)
)
),
1, # batch_size = 1 when batch pushdown is not used.
))
self._next_buffered_index += 1

def start_prefetch(self):
if self._prefetch_buffer_size > 0:
if self._target_buffer_size > 0:
self._fill_buffer()

def close(self) -> None:
@@ -319,10 +345,122 @@ def close(self) -> None:
self._executor.shutdown(wait=False)
# Cancel all pending futures in the buffer.
while self._buffer:
future = self._buffer.popleft()
future, _ = self._buffer.popleft()
future.cancel()


class _BatchedPrefetchDatasetIterator(PrefetchDatasetIterator[T]):
"""Iterator that performs prefetching in batches using a thread pool."""

def __init__(
self,
parent: dataset.MapDataset[T],
read_options: grain_options.ReadOptions,
allow_nones: bool,
):
super().__init__(parent, read_options, allow_nones)
# The number of elements to prefetch in each batch.
self._batch_pushdown_size = (
int(math.ceil(self._target_buffer_size / self._num_threads))
if self._target_buffer_size > 0
else 1
)
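# For example, under the assumed defaults ReadOptions(num_threads=16,
# prefetch_buffer_size=500), each batch holds ceil(500 / 16) = 32 elements.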
# Queue of elements from the most recently completed prefetch batch.
self._current_batch = collections.deque()
# Total count of elements across all pending futures in _buffer.
self._total_buffered_count = 0

@dataset_stats.record_next_duration_if_output
@dataset_stats.trace_input_pipeline_next(
stage_category=dataset_stats.IPL_CAT_PREFETCH
)
def __next__(self) -> T:
self._assert_not_closed()
# The time recorded here is the time spent in the prefetch node to return an
# element, including the time spent in the parent node.
timer = dataset_stats.Timer()
# We loop here to skip None elements when self._allow_nones is False (the
# underlying dataset may be sparse); otherwise Nones are returned too.
while True:
with self._lock, timer:
if not self._current_batch:
if self._next_returned_index == self._dataset_length:
break
if self._target_buffer_size > 0:
if not self._buffer:
# Fill the buffer on the first iteration.
self._fill_buffer()
future, batch_size = self._buffer.popleft()
self._total_buffered_count -= batch_size
# Prefetch elements until the buffer is full again.
self._fill_buffer()
batch = future.result()
self._current_batch.extend(batch)
else:
# If the prefetch buffer size was decreased, we still want to
# consume the already-prefetched elements.
if self._buffer:
future, batch_size = self._buffer.popleft()
batch = future.result()
self._total_buffered_count -= batch_size
self._current_batch.extend(batch)
else:
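# No prefetching configured: fall back to fetching a single element
# synchronously from the parent.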
element = self._stats.record_bytes_consumed(
self._map_parent[self._next_returned_index]
)
self._next_buffered_index += 1
self._current_batch.append(element)

element = self._current_batch.popleft()
self._next_returned_index += 1

return_element = self._allow_nones or element is not None
self._threshold_checker.check(return_element)
if return_element:
with self._stats.record_self_time(offset_ns=timer.value()):
element = self._stats.record_bytes_produced(element)
return self._stats.record_output_spec(element)
raise StopIteration

def _set_state_helper(self, state):
super()._set_state_helper(state)
self._current_batch.clear()
self._total_buffered_count = 0

def _fill_buffer(self):
while (
self._total_buffered_count < self._target_buffer_size
and self._next_buffered_index < self._dataset_length
):
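# Each submitted batch is capped by the pushdown size, the elements left in
# the dataset, and the remaining buffer budget.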
batch_size = min(
self._batch_pushdown_size,
self._dataset_length - self._next_buffered_index,
self._target_buffer_size - self._total_buffered_count,
)
indices = list(
range(
self._next_buffered_index, self._next_buffered_index + batch_size
)
)
# Note that we trigger creation of `_stats` in this (single) thread; this is
# important because the stats initialization is not thread-safe.
self._buffer.append((
self._executor.submit(
functools.partial(_getitems, self._stats, self._map_parent),
indices,
),
batch_size,
))
self._next_buffered_index += batch_size
self._total_buffered_count += batch_size

def __str__(self) -> str:
return (
f"_BatchedPrefetchDatasetIterator(read_options={self._read_options},"
f" allow_nones={self._allow_nones})"
)


def _set_slice_iter_dataset(
ds: dataset.IterDataset,
sl: slice,
@@ -729,6 +867,7 @@ def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool:
it,
(
PrefetchDatasetIterator,
_BatchedPrefetchDatasetIterator,
ThreadPrefetchDatasetIterator,
interleave.InterleaveDatasetIterator,
),
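For reviewers, a minimal sketch of how the new batched iterator could be exercised and checkpointed. Since `_is_batch_iter_pushdown_experiment_enabled()` is stubbed to return `False` in this change, the sketch constructs `_BatchedPrefetchDatasetIterator` directly; the import paths, `MapDataset.range`, and the `ReadOptions` fields are assumptions based on the existing Grain API rather than part of this diff.

```python
from grain._src.python import options as grain_options
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import prefetch

# Square each element of a small range dataset.
ds = dataset.MapDataset.range(100).map(lambda x: x * x)
opts = grain_options.ReadOptions(num_threads=4, prefetch_buffer_size=32)

# The experiment flag is stubbed off, so construct the batched iterator
# directly rather than relying on the gated __iter__ path above.
it = prefetch._BatchedPrefetchDatasetIterator(ds, opts, allow_nones=False)
head = [next(it) for _ in range(10)]
assert head == [x * x for x in range(10)]

# set_state() cancels pending batches and clears the partially consumed one,
# so iteration resumes exactly at the checkpointed `next_index`.
state = it.get_state()
_ = next(it)  # consume index 10
it.set_state(state)
assert next(it) == 10 * 10  # index 10 is served again after the restore
it.close()
```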