From bd69ab4c2241cc9fe5a924780fbbbeb6c7767069 Mon Sep 17 00:00:00 2001 From: Pierre Marcenac Date: Fri, 23 Jan 2026 08:44:39 -0800 Subject: [PATCH] Rollback changes in multiprocessing. PiperOrigin-RevId: 860124914 --- grain/_src/python/BUILD | 43 + grain/_src/python/data_loader.py | 84 +- grain/_src/python/data_loader_test.py | 76 ++ grain/_src/python/dataset/BUILD | 1 + grain/_src/python/dataset/dataset.py | 9 +- grain/_src/python/dataset/dataset_test.py | 34 + .../_src/python/dataset/transformations/BUILD | 4 - .../dataset/transformations/interleave.py | 8 +- .../dataset/transformations/prefetch.py | 432 ++++++++- .../dataset/transformations/prefetch_test.py | 580 ++++++++++++ grain/_src/python/grain_pool.py | 838 ++++++++++++++++++ grain/_src/python/grain_pool_test.py | 474 ++++++++++ 12 files changed, 2532 insertions(+), 51 deletions(-) create mode 100644 grain/_src/python/grain_pool.py create mode 100644 grain/_src/python/grain_pool_test.py diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index 3fe798268..62388508e 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -209,6 +209,49 @@ py_test( ], ) +py_library( + name = "grain_pool", + srcs = ["grain_pool.py"], + srcs_version = "PY3", + target_compatible_with = select({ + "@platforms//os:windows": ["@platforms//:incompatible"], + "//conditions:default": [], + }), + deps = [ + ":grain_logging", + ":multiprocessing_common", + ":options", + ":record", + ":shared_memory_array", + "//grain/_src/core:config", + "//grain/_src/core:monitoring", + "//grain/_src/core:parallel", + "//grain/_src/core:tree_lib", + "@abseil-py//absl/flags", + "@abseil-py//absl/logging", + "@pypi//cloudpickle:pkg", + ], +) + +py_test( + name = "grain_pool_test", + srcs = ["grain_pool_test.py"], + shard_count = 20, + srcs_version = "PY3", + tags = ["not_run:arm"], + deps = [ + ":data_sources", + ":grain_pool", + ":options", + ":record", + "//grain/_src/core:config", + "//grain/_src/core:monitoring", + "@abseil-py//absl/flags", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + ], +) + py_library( name = "load", srcs = ["load.py"], diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py index bee44d937..fd911aacb 100644 --- a/grain/_src/python/data_loader.py +++ b/grain/_src/python/data_loader.py @@ -24,11 +24,13 @@ import sys from typing import Any, Awaitable, Optional, Sequence, TypeVar +from absl import logging from etils import epath from grain._src.core import monitoring as grain_monitoring from grain._src.core import sharding from grain._src.core import transforms from grain._src.core import tree_lib +import multiprocessing as mp from grain._src.python import operations as ops from grain._src.python import options from grain._src.python import record @@ -37,6 +39,8 @@ from grain._src.python.dataset import dataset from grain._src.python.dataset.transformations import batch as batch_ds from grain._src.python.dataset.transformations import flatmap +from grain._src.python.dataset.transformations import prefetch +from grain._src.python.operations import BatchOperation from grain._src.python.operations import Operation from grain._src.python.samplers import Sampler from grain._src.python.shared_memory_array import SharedMemoryArray @@ -256,15 +260,14 @@ def get_state(self): # structure and switch to native dataset checkpointing. This class can be # removed afterwards. 
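    # The block below converts the per-worker "workers_state" reported by the
    # multiprocess prefetch iterator into the DataLoader's legacy
    # `last_seen_indices` checkpoint format (one global record index per
    # worker), so previously written checkpoints keep the same structure.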
dataset_state = self._parent.get_state() - if "iterators_in_use_states" not in dataset_state: - next_index_in_cycle = 0 - workers_state = [dataset_state] - else: - next_index_in_cycle = dataset_state["next_index_in_cycle"] - workers_state = dataset_state["iterators_in_use_states"] - - last_worker_index = next_index_in_cycle - 1 - worker_count = len(workers_state) + if "workers_state" not in dataset_state: + dataset_state = { + "workers_state": {"0": dataset_state}, + "last_worker_index": -1, + } + workers_state = dataset_state["workers_state"] + last_worker_index = dataset_state["last_worker_index"] + worker_count = len(dataset_state["workers_state"]) shard_index = self._shard_options.shard_index if self._shard_options else 0 shard_count = self._shard_options.shard_count if self._shard_options else 1 @@ -275,7 +278,7 @@ def get_state(self): str(i): ( local_offset + i * shard_count - + workers_state[i]["next_index"] * global_worker_count + + workers_state[str(i)]["next_index"] * global_worker_count ) for i in range(worker_count) } @@ -308,28 +311,25 @@ def set_state(self, state): self._parent.set_state(dataset_state) return - iterators_in_use_indices = list(range(worker_count)) - iterators_in_use_states = [ - { + iterations_to_skip = {str(i): 0 for i in range(worker_count)} + workers_state = { + str(i): { "next_index": ( ( - last_seen_indices[str(worker_index)] + last_seen_indices[str(i)] + global_worker_count - shard_index - - worker_index * shard_count + - i * shard_count ) // global_worker_count ) } - for worker_index in range(worker_count) - ] - + for i in range(worker_count) + } dataset_state = { - "next_index_in_cycle": last_worker_index + 1 % worker_count, - "next_index_in_datasets": worker_count, - "iterators_in_use_indices": iterators_in_use_indices, - "iterators_in_use_states": iterators_in_use_states, - "exhausted": [False] * worker_count, + "workers_state": workers_state, + "iterations_to_skip": iterations_to_skip, + "last_worker_index": last_worker_index, } self._parent.set_state(dataset_state) @@ -389,17 +389,31 @@ def __init__( f"Current worker_buffer_size is {worker_buffer_size}." ) - operations = list(operations) - for i in range(len(operations)): - op = operations[i] - if type(op) is ops.BatchOperation: # pylint: disable=unidiomatic-typecheck - operations[i] = transforms.Batch( - batch_size=op.batch_size, - drop_remainder=op.drop_remainder, - batch_fn=op.batch_fn, + worker_count = _determine_worker_count(worker_count) + if worker_count > 0: + + # Shared memory should be enabled iff worker_count > 0. + # This replaces Batch Transform with a BatchOperation in operations list + # if shared memory is enabled. 
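      # Only a trailing `transforms.Batch` is converted here; if the last
      # operation is already a `BatchOperation`, shared memory is enabled on
      # it directly, and otherwise a `CopyNumPyArrayToSharedMemory` map is
      # appended to the operations list.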
+ if operations and isinstance( + (last_op := operations[-1]), transforms.Batch + ): + logging.info("Creating BatchOperation to enable SharedMemoryArray.") + batch_operation = BatchOperation( + batch_size=last_op.batch_size, + drop_remainder=last_op.drop_remainder, + batch_fn=last_op.batch_fn, ) + batch_operation.disable_deprecation_message() + operations = list(operations) + operations[-1] = batch_operation - worker_count = _determine_worker_count(worker_count) + if operations and isinstance(operations[-1], BatchOperation): + logging.info("Enabling SharedMemoryArray for BatchOperation.") + operations[-1]._enable_shared_memory() + else: + logging.info("Adding CopyNumPyArrayToSharedMemory Map.") + operations = list(operations) + [CopyNumPyArrayToSharedMemory()] self._data_source = data_source self._sampler = sampler @@ -457,7 +471,11 @@ def _create_dataset(self) -> dataset.IterDataset: ds = _apply_transform_to_dataset(operation, ds) ds = ds.map(lambda r: r.data) if self.multiprocessing_options.num_workers > 0: - ds = ds.mp_prefetch(self.multiprocessing_options) + ds = prefetch.MultiprocessPrefetchIterDataset( + ds, + self.multiprocessing_options, + always_report_worker_state=True, + ) if not self._use_native_dataset_checkpointing: ds = _DataLoaderStateIterDataset( ds, diff --git a/grain/_src/python/data_loader_test.py b/grain/_src/python/data_loader_test.py index e945c6e06..22b9a54ba 100644 --- a/grain/_src/python/data_loader_test.py +++ b/grain/_src/python/data_loader_test.py @@ -705,6 +705,82 @@ def test_batch_transform_mapped_to_batch_operation(self): actual = list(data_loader) np.testing.assert_equal(actual, expected) + @mock.patch.object(data_loader_lib, "CopyNumPyArrayToSharedMemory") + def test_shared_memory_for_batch_operation( + self, mock_copy_numpy_array_to_shared_memory + ): + range_data_source = RangeDataSource(start=0, stop=8, step=1) + sampler = samplers.SequentialSampler( + num_records=len(range_data_source), shard_options=sharding.NoSharding() + ) + + operations = [ + PlusOne(), + FilterEven(), + ] + + batch_operation = mock.MagicMock(BatchOperation(batch_size=2)) + + data_loader = data_loader_lib.DataLoader( + data_source=range_data_source, + sampler=sampler, + operations=operations, + worker_count=0, + read_options=self.read_options, + ) + batch_operation._enable_shared_memory.assert_not_called() + self.assertTrue( + data_loader._operations[-1], mock_copy_numpy_array_to_shared_memory + ) + + data_loader = data_loader_lib.DataLoader( + data_source=range_data_source, + sampler=sampler, + operations=operations + [batch_operation], + worker_count=2, + read_options=self.read_options, + ) + batch_operation._enable_shared_memory.assert_called_once() + self.assertTrue(data_loader._operations[-1], batch_operation) + + @mock.patch.object(BatchOperation, "_enable_shared_memory", autospec=True) + def test_shared_memory_for_batch_transform(self, mock_enable_shared_memory): + range_data_source = RangeDataSource(start=0, stop=8, step=1) + sampler = samplers.SequentialSampler( + num_records=len(range_data_source), shard_options=sharding.NoSharding() + ) + operations = [ + PlusOne(), + FilterEven(), + ] + + data_loader = data_loader_lib.DataLoader( + data_source=range_data_source, + sampler=sampler, + operations=operations, + worker_count=2, + read_options=self.read_options, + ) + mock_enable_shared_memory.assert_not_called() + self.assertIsInstance( + data_loader._operations[-1], + data_loader_lib.CopyNumPyArrayToSharedMemory, + ) + + batch_transform = transforms.Batch(batch_size=2) + + 
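    # With worker_count > 0 and a trailing Batch transform, the DataLoader is
    # expected to swap in a shared-memory-enabled BatchOperation; the
    # assertions below check exactly that conversion.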
data_loader = data_loader_lib.DataLoader( + data_source=range_data_source, + sampler=sampler, + operations=operations + [batch_transform], + worker_count=2, + read_options=self.read_options, + ) + mock_enable_shared_memory.assert_called_once_with( + data_loader._operations[-1] + ) + self.assertIsInstance(data_loader._operations[-1], BatchOperation) + def test_data_loader_with_batch_fn(self): # Map transforms elements to be [1, 2, 3, 4, 5, 6, 7, 8] # Filter keeps only even elements [2, 4, 6, 8] diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index 8e3e81ad4..cea30707e 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -53,6 +53,7 @@ py_library( "//grain/_src/core:transforms", "//grain/_src/core:tree_lib", "//grain/_src/python:grain_logging", + "//grain/_src/python:grain_pool", "//grain/_src/python:multiprocessing_common", "//grain/_src/python:options", "//grain/_src/python:shared_memory_array", diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index 2e77d65a1..210e0c594 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -1353,14 +1353,13 @@ def mp_prefetch( """ options = options or grain_options.MultiprocessingOptions(num_workers=10) - # Loaded lazily due to a circular dependency (dataset <-> process_prefetch). + # Loaded lazily due to a circular dependency (dataset <-> prefetch). # pylint: disable=g-import-not-at-top - from grain._src.python.dataset.transformations import process_prefetch + from grain._src.python.dataset.transformations import prefetch # pylint: enable=g-import-not-at-top - return process_prefetch.multiprocess_prefetch( + return prefetch.MultiprocessPrefetchIterDataset( self, - num_workers=options.num_workers, - buffer_size=options.per_worker_buffer_size, + multiprocessing_options=options, worker_init_fn=worker_init_fn, sequential_slice=sequential_slice, ) diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py index 54411432f..88e67ccc0 100644 --- a/grain/_src/python/dataset/dataset_test.py +++ b/grain/_src/python/dataset/dataset_test.py @@ -1393,6 +1393,40 @@ def test_execution_summary_with_no_logging(self): log_value = "Grain Dataset Execution Summary" self.assertNotIn(log_value, "".join(logs.output)) + @flagsaver.flagsaver(grain_py_debug_mode=True) + @mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05) + def test_execution_summary_with_mp_prefetch(self): + def worker_init_fn_wrapper(worker_index, worker_count): + del worker_index, worker_count + dataset_stats._REPORTING_PERIOD_SEC = 0.05 + + ds = dataset.MapDataset.range(10000).map(MapAddingOne()) + ds = ds.to_iter_dataset() + ds = ds.mp_prefetch( + options.MultiprocessingOptions(num_workers=1), + worker_init_fn=worker_init_fn_wrapper, + ) + it = ds.__iter__() + _ = list(it) + all_nodes_present = False + while not all_nodes_present: + time.sleep(1) + all_nodes_present = True + summary = dataset.get_execution_summary(it) + node_names = {node.name for node in summary.nodes.values()} + all_nodes_present = all_nodes_present and any( + "RangeMapDataset" in name for name in node_names + ) + all_nodes_present = all_nodes_present and any( + "MapMapDataset" in name for name in node_names + ) + all_nodes_present = all_nodes_present and any( + "PrefetchDatasetIterator" in name for name in node_names + ) + all_nodes_present = all_nodes_present and any( + "MultiprocessPrefetchDatasetIterator" in name for name in node_names + ) + 
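For reference, the restored `mp_prefetch` path exercised by the test above can be driven on its own roughly as in the following minimal, standalone sketch (not part of this change; `init_worker` is an illustrative name):

from grain._src.python import options
from grain._src.python.dataset import dataset


def init_worker(worker_index: int, worker_count: int) -> None:
  # Runs once in every child process before it starts producing elements.
  print(f"worker {worker_index} of {worker_count} starting")


if __name__ == "__main__":
  ds = dataset.MapDataset.range(16).map(lambda x: x + 1).to_iter_dataset()
  ds = ds.mp_prefetch(
      options.MultiprocessingOptions(num_workers=2, per_worker_buffer_size=4),
      worker_init_fn=init_worker,
  )
  # Two workers produce interleaved slices of the range; together they yield
  # every element exactly once.
  assert sorted(ds) == list(range(1, 17))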
class GetElementSpecTest(parameterized.TestCase): diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD index 21ddb6a47..ee15b9230 100644 --- a/grain/_src/python/dataset/transformations/BUILD +++ b/grain/_src/python/dataset/transformations/BUILD @@ -371,10 +371,6 @@ py_test( srcs = ["process_prefetch_test.py"], shard_count = 50, srcs_version = "PY3", - target_compatible_with = select({ - "@platforms//os:windows": ["@platforms//:incompatible"], - "//conditions:default": [], - }), deps = [ "//grain/_src/core:transforms", "//grain/_src/python:options", diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 3539899bb..d69e21143 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -28,7 +28,7 @@ T = TypeVar("T") -class InterleaveDatasetIterator(dataset.DatasetIterator[T]): +class _InterleaveDatasetIterator(dataset.DatasetIterator[T]): """Iterates over the interleaved datasets.""" def __init__( @@ -282,7 +282,7 @@ def __str__(self) -> str: def _add_prefetch_and_make_iterator( ds: dataset.IterDataset[T] | dataset.MapDataset[T], - interleave_iterator: weakref.ref[InterleaveDatasetIterator[T]], + interleave_iterator: weakref.ref[_InterleaveDatasetIterator[T]], start_prefetch: bool, ) -> dataset.DatasetIterator[T]: """Adds prefetching to an IterDataset and returns an iterator. @@ -383,8 +383,8 @@ def __init__( self._make_iter_buffer_size = make_iter_buffer_size self._iter_buffer_size = iter_buffer_size - def __iter__(self) -> dataset.DatasetIterator[T]: - return InterleaveDatasetIterator( + def __iter__(self) -> _InterleaveDatasetIterator[T]: + return _InterleaveDatasetIterator( self._datasets, cycle_length=self._cycle_length, num_make_iter_threads=self._num_make_iter_threads, diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index f31b26f3b..9316419e3 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -16,24 +16,35 @@ from __future__ import annotations import collections -from collections.abc import Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence +import contextlib import copy import functools +import math from multiprocessing import queues +from multiprocessing import synchronize import queue +import sys import threading +import time import typing -from typing import Any, Optional, Protocol, TypeVar +from typing import Any, Generic, Optional, Protocol, TypeVar +import cloudpickle from concurrent import futures from grain._src.core import monitoring as grain_monitoring +from grain._src.core import tree_lib +import multiprocessing as mp +from grain._src.python import grain_pool from grain._src.python import options as grain_options +from grain._src.python import shared_memory_array from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset import stats as dataset_stats from grain._src.python.dataset.transformations import filter as filter_dataset from grain._src.python.dataset.transformations import interleave from grain._src.python.dataset.transformations import source +import numpy as np T = TypeVar("T") @@ -323,6 +334,131 @@ def close(self) -> None: future.cancel() +def _iterator_with_context( + iterator: 
contextlib.AbstractContextManager[Iterator[T]], +) -> Iterator[T]: + with iterator as it: + yield from it + + +def _validate_no_double_prefetch( + parent: dataset.MapDataset | dataset.IterDataset, +) -> None: + """Checks that there are no multiple levels of parallelization.""" + to_check: list[dataset.MapDataset | dataset.IterDataset] = [parent] + while to_check: + ds = to_check.pop(0) + if isinstance(ds, MultiprocessPrefetchIterDataset): + raise ValueError( + "Nesting multiprocessing or multithreading is not allowed." + ) + to_check.extend(ds.parents) + + +class MultiprocessPrefetchIterDataset(dataset.IterDataset[T]): + """Uses a pool of processes to prefetch elements ahead of time. + + It usually makes sense to add this transformation in the end of the pipeline + since it will execute the parent IterDataset in multiple processes. + """ + + def __init__( + self, + parent: dataset.IterDataset[T], + multiprocessing_options: grain_options.MultiprocessingOptions, + worker_init_fn: Callable[[int, int], None] | None = None, + sequential_slice: bool = False, + always_report_worker_state: bool = False, + ): + if multiprocessing_options.num_workers < 0: + raise ValueError( + "`num_workers` must be greater than or equal to 0, got " + f"{multiprocessing_options.num_workers}." + ) + super().__init__(parent) + self._multiprocessing_options = multiprocessing_options + self._worker_init_fn = worker_init_fn + self._sequential_slice = sequential_slice + _validate_no_double_prefetch(self._parent) + self._always_report_worker_state = always_report_worker_state + + def __str__(self) -> str: + return ( + "MultiprocessPrefetchIterDataset(" + f"multiprocessing_options={self._multiprocessing_options})" + ) + + def __iter__(self) -> dataset.DatasetIterator[T]: + if self._multiprocessing_options.num_workers == 0: + return self._parent.__iter__() + return _MultiprocessPrefetchDatasetIterator( + self._parent, + self._multiprocessing_options, + self._worker_init_fn, + self._sequential_slice, + self._always_report_worker_state, + ) + + @property + def _element_spec(self) -> Any: + return dataset.get_element_spec(self._parent) + + +# Keys in `MultiprocessPrefetchDatasetIterator` checkpoints. +_WORKERS_STATE = "workers_state" +_ITERATIONS_TO_SKIP = "iterations_to_skip" +_LAST_WORKER_INDEX = "last_worker_index" + +# Minimal interval (in seconds) between consecutive state recordings in worker +# processes of `MultiprocessPrefetchDatasetIterator`. We record the state +# periodically to reduce the overhead of sending the state from workers. +# Note that this is also an approximate upper bound on how long it is going to +# take to recover from a checkpointed state. Larger values will decrease the +# overhead of sending the updated state but will also make recovery from a +# checkpoint longer on average. 
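# For orientation, a checkpoint produced by this iterator roughly has the
# following shape (values are illustrative): per-worker iterator states under
# "workers_state", the number of elements each worker has produced since its
# state was last recorded under "iterations_to_skip" (these are skipped
# forward again on restore), and the index of the worker that produced the
# most recently returned element.
#
#   {
#       "workers_state": {"0": {...}, "1": {...}},
#       "iterations_to_skip": {"0": 2, "1": 0},
#       "last_worker_index": 1,
#   }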
+_RECORD_STATE_INTERVAL_S = 3 + + +def _copy_leaf_to_shm(leaf: Any, min_size: int = 0) -> Any: + """Copies `leaf` to shared memory if it's a big enough numpy array.""" + if isinstance(leaf, shared_memory_array.SharedMemoryArray): + return leaf.metadata + if ( + not isinstance(leaf, np.ndarray) + or leaf.dtype.hasobject + or not leaf.flags.c_contiguous + or math.prod(leaf.shape) == 0 + or leaf.nbytes < min_size + ): + return leaf + + shared_memory_arr = shared_memory_array.SharedMemoryArray( + leaf.shape, leaf.dtype + ) + np.copyto(shared_memory_arr, leaf, casting="no") + return shared_memory_arr.metadata + + +def _copy_struct_to_shm(struct: Any, min_size: int = 0) -> Any: + """Copies leaf ndarrays of the structure to shared memory.""" + return tree_lib.map_structure( + functools.partial(_copy_leaf_to_shm, min_size=min_size), struct + ) + + +def _open_leaf_from_shm(leaf: Any) -> Any: + """Recovers `leaf` from shared memory if it's a numpy array metadata.""" + if isinstance(leaf, shared_memory_array.SharedMemoryArrayMetadata): + leaf = shared_memory_array.SharedMemoryArray.from_metadata(leaf) + leaf.unlink_on_del() + return leaf + + +def _open_struct_from_shm(struct: Any) -> Any: + """Recovers leaf ndarrays of the structure from shared memory.""" + return tree_lib.map_structure(_open_leaf_from_shm, struct) + + def _set_slice_iter_dataset( ds: dataset.IterDataset, sl: slice, @@ -379,6 +515,123 @@ def _set_slice_map_dataset( _set_slice_iter_dataset(parent, sl, sequential_slice) +def _check_picklable( + ds: dataset.IterDataset | dataset.MapDataset, +): + """Detects the first unpickle-able dataset in post-order. + + Args: + ds: IterDataset or MapDataset to check whether it is picklable. + + NOTE: This function's time complexity is O(n^2) where n is the number of + Grain dataset operations because `cloudpickle.dumps(ds)` will trigger + pickling into all the datasets. If this naive O(n^2) algorithm takes too + much time, we could consider doing copying `ds`, delete its parents and then + do `cloudpickle.dumps(new_ds)` to reduce the time complexity to O(n). + """ + + # Traverses the graph in post-order to find the first unpickle-able subtree + for parent in ds.parents: + _check_picklable(parent) + + try: + cloudpickle.dumps(ds) + except Exception as e: # pylint: disable=broad-exception-caught + if sys.version_info >= (3, 11): + e.add_note( + f"Dataset: {ds} cannot be pickled!" + ) + raise e + + +class GetElementProducerFn(grain_pool.GetElementProducerFn, Generic[T]): + """Implements `GetElementProducerFn` for `grain_pool.MultiProcessIterator`. + + This class implements `GetElementProducerFn` with `serialize` being overridden + to generate better error messages if user-provided dataset is not pickle-able. 
+ """ + + def __init__( + self, + state: dict[str, dict[str, Any] | int], + ds: dataset.IterDataset[T], + sequential_slice: bool = False, + always_report_worker_state: bool = False, + ): + self._state = state + self._ds = ds + self._sequential_slice = sequential_slice + self._always_report_worker_state = always_report_worker_state + + def __call__( + self, + *, + worker_index: int, + worker_count: int, + start_profiling_event: synchronize.Event | None = None, + stop_profiling_event: synchronize.Event | None = None, + stats_out_queue: queues.Queue | None = None, # pylint: disable=g-bare-generic + ) -> Iterator[tuple[T, Optional[dict[str, Any]]]]: + if worker_count > 1: + _set_slice_iter_dataset( + self._ds, + slice(worker_index, None, worker_count), + self._sequential_slice, + ) + # Prevent OutputDatasetIterator injection in worker processes. + # The injection should only happen in the main process iterator, + # which wraps the _MultiprocessPrefetchDatasetIterator. + it = self._ds.__iter__() + it._ctx.mp_context = base.MultiprocessingContext( + process_index=worker_index, process_count=worker_count + ) + min_shm_size = it._ctx.dataset_options.min_shm_size + # Recover from the last recorded state for the given worker. + worker_state = self._state[_WORKERS_STATE][str(worker_index)] + if worker_state is not None: + it.set_state(worker_state) + # Set the stats queue in worker process to send stats to the main process. + it._stats._config.stats_out_queue = stats_out_queue # pytype: disable=attribute-error + # Skip the required number of iterations after the last recorded state. + for _ in range(self._state[_ITERATIONS_TO_SKIP][str(worker_index)]): + _ = next(it) + last_recorded_state_time = time.time() + for element in it: + now = time.time() + element = _copy_struct_to_shm(element, min_size=min_shm_size) + # If the node is prefetch, we already record the bytes produced in it's + # __next__ method. + if not it._stats._config.is_prefetch: + it._stats.record_bytes_produced(element) + if ( + self._always_report_worker_state + or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S + ): + last_recorded_state_time = now + yield (element, it.get_state()) # pytype: disable=attribute-error + else: + yield (element, None) + + def serialize(self) -> bytes: + """Overrides the default implementation to generate better error messages.""" + + try: + return cloudpickle.dumps(self) + except Exception as e: # pylint: disable=broad-except + # Calls `_check_picklable` to generate useful pickle errors + # + # Note: No need to check `self._state` because it should not generate + # unpicklable errors and it is controlled by us, not from user's code + # in most cases. Except for the case when users try to implement their own + # `MapDataset` and `IterDataset` with custom pickle-ing logic that + # contains unpickle-able objects. 
+ _check_picklable(self._ds) + + # If somehow we cannot find the dataset that is causing the pickle + # issues, just raise the original error + raise e + + def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions: result = base.DatasetOptions() to_visit = [ds] @@ -390,6 +643,175 @@ def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions: return result +class _MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]): + """Iterator that performs prefetching using a multiprocessing pool.""" + + def __init__( + self, + parent: dataset.IterDataset[T], + multiprocessing_options: grain_options.MultiprocessingOptions, + worker_init_fn: Callable[[int, int], None] | None = None, + sequential_slice: bool = False, + always_report_worker_state: bool = False, + ): + super().__init__() + self._iter_parent = parent + # Since the parent iterator is going to be created in each subprocess, and + # the options are propagated during iterator creation, we need to manually + # propagate them. + self._ctx.dataset_options = _get_dataset_options(parent) + self._multiprocessing_options = multiprocessing_options + self._worker_init_fn = worker_init_fn + self._sequential_slice = sequential_slice + # The underlying iterator producing elements and workers state. + self._iterator = None + # Raw reference to the underlying iterator that can be used to determine the + # last worker index. + self._raw_iterator = None + # Create initial state. We record state of each worker periodically together + # with the number of iterations without the recorded state and index of the + # last worker. + iterations_to_skip: dict[str, int] = { + str(i): 0 for i in range(multiprocessing_options.num_workers) + } + workers_state: dict[str, Any] = { + str(i): None for i in range(multiprocessing_options.num_workers) + } + self._stats_in_queues = tuple( + mp.get_context("spawn").Queue(maxsize=5) + for _ in range(multiprocessing_options.num_workers) + ) + self._start_profiling_event = mp.get_context("spawn").Event() + self._stop_profiling_event = mp.get_context("spawn").Event() + + self._state: dict[str, dict[str, Any] | int] = { + _WORKERS_STATE: workers_state, + _ITERATIONS_TO_SKIP: iterations_to_skip, + _LAST_WORKER_INDEX: -1, + } + + self._always_report_worker_state = always_report_worker_state + + def _initialize_stats( + self, execution_tracking_mode: base.ExecutionTrackingMode + ): + self._stats = _initialize_prefetch_stats( + self, + execution_tracking_mode, + parent_stats=[], + stats_in_queues=self._stats_in_queues, + ) + return self._stats + + @functools.cached_property + def _stats(self): + return self._initialize_stats( + self._ctx.dataset_options.execution_tracking_mode + ) + + def __iter__(self) -> dataset.DatasetIterator[T]: + return self + + @dataset_stats.record_next_duration_if_output + @dataset_stats.trace_input_pipeline_next( + stage_category=dataset_stats.IPL_CAT_PREFETCH + ) + def __next__(self) -> T: + self._assert_not_closed() + self._ensure_iterator_initialized() + # The time recorded here is the time spent in prefetch node to return an + # element, including the time spent in parent node. 
+ timer = dataset_stats.Timer() + result, state = next(self._iterator) + with self._stats.record_self_time(offset_ns=timer.value()): + worker_index = self._raw_iterator.get_last_worker_index() # pytype: disable=attribute-error + + # pytype: disable=annotation-type-mismatch + iterations_to_skip: dict[str, Any] = self._state[_ITERATIONS_TO_SKIP] + worker_state: dict[str, Any] = self._state[_WORKERS_STATE] + # pytype: enable=annotation-type-mismatch + + self._state[_LAST_WORKER_INDEX] = worker_index + worker_index_str = str(worker_index) + if state is None: + iterations_to_skip[worker_index_str] += 1 + else: + iterations_to_skip[worker_index_str] = 0 + worker_state[worker_index_str] = state + result = self._stats.record_bytes_produced(result) + return _open_struct_from_shm(result) + + def start_prefetch(self) -> None: + """Prefetches elements from the iterator. + + This will run background processes for prefetching. To make sure to clean up + the resources, it should be followed by at least one `next` call. + """ + self._ensure_iterator_initialized() + + def set_state(self, state: dict[str, dict[str, Any] | int]) -> None: + self._state = state + self._raw_iterator = None + self._iterator = None + + def get_state(self) -> dict[str, Any]: + result = copy.deepcopy(self._state) + workers_state: dict[str, Any] = result[_WORKERS_STATE] # pytype: disable=annotation-type-mismatch + parent_state = None + for worker_index, worker_state in workers_state.items(): + # Create initial state from the parent iterator. This is to make sure the + # spec of the produced iterator does not change. + if worker_state is None: + parent_state = parent_state or self._iter_parent.__iter__().get_state() + workers_state[worker_index] = copy.deepcopy(parent_state) + return result + + def _ensure_iterator_initialized(self) -> None: + if self._iterator is None: + self._raw_iterator = self._create_iterator_context() + self._raw_iterator.start_prefetch() + self._iterator = _iterator_with_context(self._raw_iterator) + + def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]: + """Creates a `MultiProcessIterator`.""" + # Apply the latest options to the subprocess dataset. We delay this until + # starting subprocesses because child iterators may update them. + ds = dataset.WithOptionsIterDataset( + self._iter_parent, self._ctx.dataset_options + ) + get_element_producer_fn = GetElementProducerFn( + self._state, + ds, + self._sequential_slice, + self._always_report_worker_state, + ) + + return grain_pool.MultiProcessIterator( + get_element_producer_fn, + self._multiprocessing_options, + (self._state[_LAST_WORKER_INDEX] + 1) + % self._multiprocessing_options.num_workers, + self._worker_init_fn, + self._start_profiling_event, + self._stop_profiling_event, + self._stats_in_queues, + ) + + def __str__(self) -> str: + return ( + "MultiprocessPrefetchDatasetIterator(" + f"multiprocessing_options={self._multiprocessing_options})" + ) + + def close(self) -> None: + """Shuts down the prefetching threads and multiprocessing pool.""" + if self._closed: + return + self._closed = True + if self._raw_iterator is not None: + self._raw_iterator.stop_prefetch() + + class ThreadPrefetchIterDataset(dataset.IterDataset[T]): """Iterable dataset that uses a synchronized queue for prefetching. @@ -565,8 +987,6 @@ def close(self): """Stops the iterator. 
No further calls to the iterator are expected.""" self._closed = True self._stop_prefetch() - if isinstance(self._maybe_nonnative_parent, dataset.DatasetIterator): - self._maybe_nonnative_parent.close() def _clear_buffer(self): while True: @@ -692,6 +1112,8 @@ def multithread_prefetch( if num_threads == 0: return ds + _validate_no_double_prefetch(ds) + shards = [] for i in range(num_threads): worker_ds = copy.deepcopy(ds) @@ -720,6 +1142,6 @@ def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool: ( PrefetchDatasetIterator, ThreadPrefetchDatasetIterator, - interleave.InterleaveDatasetIterator, + _MultiprocessPrefetchDatasetIterator, ), ) diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index a530b1691..5b0e5a076 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from concurrent import futures import dataclasses +import logging as std_logging import platform import sys import threading @@ -19,6 +21,7 @@ from typing import TypeVar, cast from unittest import mock +from absl import logging from absl.testing import absltest from absl.testing import parameterized from grain._src.core import transforms @@ -442,6 +445,550 @@ def test_set_next_index(self): self.assertEqual(next(ds_iter), i) +class MultiprocessPrefetchIterDatasetTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + ds = dataset.MapDataset.range(20) + ds = prefetch.PrefetchIterDataset(ds, read_options=options.ReadOptions()) + self.iter_ds = ds.filter(FilterKeepingOddElementsOnly()) + + @parameterized.named_parameters( + dict( + testcase_name='0_workers', + num_workers=0, + per_worker_buffer_size=1, + ), + dict( + testcase_name='1_worker', + num_workers=1, + per_worker_buffer_size=1, + ), + dict( + testcase_name='1_worker_large_buffer', + num_workers=1, + per_worker_buffer_size=20, + ), + dict( + testcase_name='10_workers', + num_workers=10, + per_worker_buffer_size=1, + ), + dict( + testcase_name='10_workers_large_buffer', + num_workers=10, + per_worker_buffer_size=20, + ), + ) + def test_prefetch_data(self, num_workers: int, per_worker_buffer_size: int): + prefetch_lazy_iter_ds = prefetch.MultiprocessPrefetchIterDataset( + self.iter_ds, + options.MultiprocessingOptions(num_workers, per_worker_buffer_size), + ) + actual = list(prefetch_lazy_iter_ds) + expected = list(range(1, 20, 2)) + self.assertSequenceEqual(actual, expected) + + def test_prefetch_size_zero_data(self): + ds = dataset.MapDataset.source( + [np.zeros(shape=(0,), dtype=np.int64)] + ).repeat(3) + iter_ds = ds.to_iter_dataset() + prefetch_lazy_iter_ds = prefetch.MultiprocessPrefetchIterDataset( + iter_ds, + options.MultiprocessingOptions(num_workers=1), + ) + actual = list(prefetch_lazy_iter_ds) + expected = [np.zeros(shape=(0,), dtype=np.int64)] * 3 + self.assertLen(actual, 3) + self.assertLen(expected, 3) + for i in range(3): + np.testing.assert_array_equal(actual[i], expected[i]) + + @parameterized.product( + ( + dict( + num_workers=0, + record_state_interval=prefetch._RECORD_STATE_INTERVAL_S, + ), + dict( + num_workers=1, + record_state_interval=prefetch._RECORD_STATE_INTERVAL_S, + ), + dict( + num_workers=10, + record_state_interval=prefetch._RECORD_STATE_INTERVAL_S, + ), + dict( + 
num_workers=10, + record_state_interval=0, + ), + ), + step_index=[0, 3, 8], + ) + def test_checkpoint( + self, num_workers: int, record_state_interval: int, step_index: int + ): + with mock.patch.object( + prefetch, '_RECORD_STATE_INTERVAL_S', record_state_interval + ): + ds = prefetch.MultiprocessPrefetchIterDataset( + self.iter_ds, + options.MultiprocessingOptions(num_workers), + ) + ds_iter = ds.__iter__() + + max_steps = 10 + values_without_interruption = [] + checkpoints = [] + for _ in range(max_steps): + checkpoints.append(ds_iter.get_state()) + values_without_interruption.append(next(ds_iter)) + + ds_iter.set_state(checkpoints[step_index]) + for i in range(step_index, max_steps): + value = next(ds_iter) + self.assertEqual(value, values_without_interruption[i]) + + def test_set_state_twice(self): + with mock.patch.object(prefetch, '_RECORD_STATE_INTERVAL_S', 0): + ds = prefetch.MultiprocessPrefetchIterDataset( + self.iter_ds, + options.MultiprocessingOptions(2), + ) + ds_iter = ds.__iter__() + + max_steps = 10 + values_without_interruption = [] + checkpoints = [] + for _ in range(max_steps): + checkpoints.append(ds_iter.get_state()) + values_without_interruption.append(next(ds_iter)) + + for starting_step in [0, 3, 8]: + ds_iter.set_state(checkpoints[starting_step]) + for i in range(starting_step, max_steps): + value = next(ds_iter) + self.assertEqual(value, values_without_interruption[i]) + + def test_fails_with_negative_num_workers(self): + with self.assertRaisesRegex( + ValueError, '`num_workers` must be greater than or equal to 0' + ): + prefetch.MultiprocessPrefetchIterDataset( + self.iter_ds, + options.MultiprocessingOptions(num_workers=-1), + ) + + def test_fails_with_multiple_prefetches(self): + ds = prefetch.MultiprocessPrefetchIterDataset( + self.iter_ds, + options.MultiprocessingOptions(num_workers=10), + ) + with self.assertRaisesRegex( + ValueError, + 'Nesting multiprocessing or multithreading is not allowed.', + ): + _ = prefetch.MultiprocessPrefetchIterDataset( + ds, + options.MultiprocessingOptions(num_workers=1), + ) + + def test_works_with_iter_source_single_worker(self): + # Even though a pure IterDataset cannot be sliced, we should still be able + # to multiprocess-prefetch it with a single worker, since that doesn't + # require any slicing. + ds = prefetch.MultiprocessPrefetchIterDataset( + RepeatedIntSourceIterDataset().map(lambda x: x + 1), + options.MultiprocessingOptions(num_workers=1), + ) + ds_iter = iter(ds) + self.assertEqual(next(ds_iter), 2) + + def test_fails_with_iter_source_multiple_workers(self): + ds = prefetch.MultiprocessPrefetchIterDataset( + RepeatedIntSourceIterDataset().map(lambda x: x + 1), + options.MultiprocessingOptions(num_workers=2), + ) + ds_iter = iter(ds) + + with self.assertRaisesRegex( + Exception, + 'Cannot slice `IterDataset` source.', + ): + next(ds_iter) + + def test_propagates_transform_error(self): + error_msg = 'I shall fail!' 
+ + def failing_transform(element): + del element + raise ValueError(error_msg) + + ds = prefetch.MultiprocessPrefetchIterDataset( + self.iter_ds.map(failing_transform), + options.MultiprocessingOptions(num_workers=1), + ) + with self.assertRaisesRegex(Exception, error_msg): + list(ds) + + def test_reports_worker_crash(self): + def failing_transform(element): + del element + sys.exit(123) + + ds = prefetch.MultiprocessPrefetchIterDataset( + self.iter_ds.map(failing_transform), + options.MultiprocessingOptions(num_workers=1), + ) + with self.assertRaisesRegex( + RuntimeError, 'was terminated unexpectedly with exit code 123' + ): + list(ds) + + def test_reports_unpicklable_transform(self): + class UnpicklableObject: + + def __getstate__(self): + raise ValueError('UnpicklableObject is not picklable') + + local_state = UnpicklableObject() + + ds = prefetch.MultiprocessPrefetchIterDataset( + self.iter_ds.map(lambda _: 1 if local_state is None else 2), + options.MultiprocessingOptions(num_workers=1), + ) + with self.assertRaisesRegex( + ValueError, 'UnpicklableObject is not picklable' + ) as context_manager: + list(ds) + + if sys.version_info >= (3, 11): + self.assertRegex( + ''.join(context_manager.exception.__notes__), + r'Dataset: MapIterDataset.* cannot be pickled!', + ) + + def test_reports_first_unpicklable_dataset_when_with_multiple_parents(self): + class UnpicklableObject: + + def __getstate__(self): + raise ValueError('UnpicklableObject is not picklable') + + local_unpicklable_obj = UnpicklableObject() + + class LeftTransform(transforms.Map): + + def map(self, x): + return x if local_unpicklable_obj else x + + class RightTransform(transforms.Map): + + def map(self, x): + return x if local_unpicklable_obj else x + + ds_left = dataset.MapDataset.range(0, 10) + ds_left = ds_left.map(LeftTransform()) + ds_right = dataset.MapDataset.range(10, 20) + ds_right = ds_right.map(RightTransform()) + + ds = dataset.MapDataset.mix([ds_left, ds_right], [1.0, 1.0]) + + iter_ds = ds.to_iter_dataset( + read_options=options.ReadOptions(prefetch_buffer_size=0) + ) + iter_ds = iter_ds.mp_prefetch() + + with self.assertRaisesRegex( + ValueError, + r'UnpicklableObject is not picklable', + ) as context_manager: + list(iter_ds) + + if sys.version_info >= (3, 11): + self.assertRegex( + ''.join(context_manager.exception.__notes__), + r'Dataset: MapMapDataset\(transform=LeftTransform\) cannot be' + r' pickled!', + ) + + def test_reports_unpicklable_issue_when_only_one_parent_unpicklable(self): + class UnpicklableObject: + + def __getstate__(self): + raise ValueError('UnpicklableObject is not picklable') + + class PickleableTransform(transforms.Map): + + def map(self, x): + return x + + local_unpicklable_obj = UnpicklableObject() + + class RightTransform(transforms.Map): + + def map(self, x): + return x if local_unpicklable_obj else x + + ds_left = dataset.MapDataset.range(0, 10) + ds_left = ds_left.map(PickleableTransform()) + ds_right = dataset.MapDataset.range(10, 20) + ds_right = ds_right.map(RightTransform()) + + ds = dataset.MapDataset.mix([ds_left, ds_right], [1.0, 1.0]) + + iter_ds = ds.to_iter_dataset( + read_options=options.ReadOptions(prefetch_buffer_size=0) + ) + iter_ds = iter_ds.mp_prefetch() + + with self.assertRaisesRegex( + ValueError, 'UnpicklableObject is not picklable' + ) as context_manager: + list(iter_ds) + + if sys.version_info >= (3, 11): + self.assertRegex( + ''.join(context_manager.exception.__notes__), + r'Dataset: MapMapDataset\(transform=RightTransform\) cannot be' + r' pickled!', + ) + 
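  # The pickling tests above reflect the contract of
  # MultiprocessPrefetchIterDataset: the whole parent dataset is serialized
  # with cloudpickle and rebuilt inside each worker process, so transforms
  # must not capture unpicklable state from the enclosing scope. A minimal
  # sketch of the recommended shape (`AddOne` is an illustrative name):
  #
  #   class AddOne(transforms.Map):
  #
  #     def map(self, x):
  #       return x + 1
  #
  #   ds = dataset.MapDataset.range(8).map(AddOne()).to_iter_dataset()
  #   ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=2))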
+ @parameterized.product( + start_prefetch_calls=[0, 1, 10], + num_workers=[6], + per_worker_buffer_size=[1, 20], + ) + def test_start_prefetch( + self, + start_prefetch_calls: int, + num_workers: int, + per_worker_buffer_size: int, + ): + + class _SleepTransform(transforms.Map): + + def map(self, features): + time.sleep(1) + return features + + ds = dataset.MapDataset.range(10) + ds = ds.map(_SleepTransform()) + ds = prefetch.PrefetchIterDataset(ds, read_options=options.ReadOptions()) + ds = prefetch.MultiprocessPrefetchIterDataset( + ds, + options.MultiprocessingOptions(num_workers, per_worker_buffer_size), + ) + + it = ds.__iter__() + for _ in range(start_prefetch_calls): + it.start_prefetch() + + # Waits for prefetching. + start_time = time.time() + while time.time() - start_time < 30: + time.sleep(2) + + # Measures time to read from the dataset. + start_time = time.time() + self.assertSequenceEqual(list(it), list(range(10))) + + time_to_fetch = time.time() - start_time + logging.info('Reading dataset took %.2f seconds.', time_to_fetch) + # Note that we can't reliably assert the upper bound on the time it takes + # read the dataset elements since worker startup time can vary a lot. + if not start_prefetch_calls: + self.assertGreater(time_to_fetch, 1) + + @parameterized.parameters(0, 0.5, 30) + def test_prefetch_but_no_read(self, sleep_s): + ds = dataset.MapDataset.source([1, 2, 3]).repeat() + ds = ds.filter(lambda x: x > 3) + ds = ds.to_iter_dataset() + ds = ds.mp_prefetch() + it = ds.__iter__() + it.start_prefetch() + time.sleep(sleep_s) + del it + + def test_prefetch_with_random_map(self): + ds = dataset.MapDataset.source([0]).repeat(100).to_iter_dataset() + ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42) + ds = prefetch.MultiprocessPrefetchIterDataset( + ds, + options.MultiprocessingOptions(num_workers=5), + ) + # Make sure that sliced datasets on workers are seeded differently and thus + # produce different random elements. + elements = list(ds) + distinct_elements = set(elements) + self.assertLen(distinct_elements, len(elements)) + + def test_concurrent_start_prefetch(self): + num_iters = 10 # Can't set this much higher without Forge OOMing. 
+ + def make_iter(i): + ds = dataset.MapDataset.source([i]) + ds = ds.to_iter_dataset() + ds = ds.mp_prefetch(options=options.MultiprocessingOptions(num_workers=1)) + return ds.__iter__() + + iters = [make_iter(i) for i in range(num_iters)] + with futures.ThreadPoolExecutor(max_workers=num_iters) as executor: + for it in iters: + executor.submit(it.start_prefetch) + for it in iters: + _ = next(it) + + def test_options_before_prefetch(self): + ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000) + ds = ds.to_iter_dataset() + ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) + ds = dataset.WithOptionsIterDataset(ds, ds_options) + ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1)) + ds = ds.filter(lambda x: x > 2) + with self.assertRaises(Exception): + list(ds) + + def test_multiprocess_prefetch_with_sequential_slice(self): + ds = dataset.MapDataset.source(range(10)).to_iter_dataset() + ds = prefetch.MultiprocessPrefetchIterDataset( + ds, + options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + sequential_slice=True, + ) + self.assertEqual(list(ds), [0, 4, 7, 1, 5, 8, 2, 6, 9, 3]) + + def test_multiprocess_prefetch_with_default_slice_non_sequential(self): + ds = dataset.MapDataset.source(range(10)).to_iter_dataset() + ds_sequential_off = prefetch.MultiprocessPrefetchIterDataset( + ds, + options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + sequential_slice=False, + ) + ds_sequential_default = prefetch.MultiprocessPrefetchIterDataset( + ds, + options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + ) + elements_sequential_off = list(ds_sequential_off) + elements_sequential_default = list(ds_sequential_default) + self.assertEqual( + elements_sequential_off, + elements_sequential_default, + ) + self.assertEqual( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + elements_sequential_default, + ) + + def test_multiprocess_prefetch_sequential_slice_order_from_source(self): + ds = dataset.MapDataset.source(range(10)).to_iter_dataset() + ds_sequential_on = prefetch.MultiprocessPrefetchIterDataset( + ds, + options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + sequential_slice=True, + ) + elements_sequential_on = list(ds_sequential_on) + self.assertEqual([0, 4, 7, 1, 5, 8, 2, 6, 9, 3], elements_sequential_on) + + def test_multiprocess_prefetch_sequential_slice_order_from_range(self): + ds_range = dataset.MapDataset.range(10).to_iter_dataset() + ds_range_sequential_on = prefetch.MultiprocessPrefetchIterDataset( + ds_range, + options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + sequential_slice=True, + ) + elements_range_sequential_on = list(ds_range_sequential_on) + self.assertEqual( + [0, 4, 7, 1, 5, 8, 2, 6, 9, 3], + elements_range_sequential_on, + ) + + def test_multiprocess_prefetch_sequential_slice_order_from_range_slice(self): + ds_range = dataset.MapDataset.range( + start=2, stop=21, step=3 + ).to_iter_dataset() + ds_range_sequential_on = prefetch.MultiprocessPrefetchIterDataset( + ds_range, + options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + sequential_slice=True, + ) + elements_range_sequential_on = list(ds_range_sequential_on) + self.assertEqual( + [2, 11, 17, 5, 14, 20, 8], + elements_range_sequential_on, + ) + + def test_multiprocess_prefetch_sequential_slice_order_same(self): + ds_source = dataset.MapDataset.source(range(10)).to_iter_dataset() + ds_range = dataset.MapDataset.range(10).to_iter_dataset() + ds_source_mp = 
prefetch.MultiprocessPrefetchIterDataset( + ds_source, + options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + sequential_slice=True, + ) + ds_range_mp = prefetch.MultiprocessPrefetchIterDataset( + ds_range, + options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + sequential_slice=True, + ) + elements_source = list(ds_source_mp) + elements_range = list(ds_range_mp) + self.assertEqual(elements_source, elements_range) + + def test_options_after_prefetch(self): + ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000) + ds = ds.filter(lambda x: x > 2) + ds = ds.to_iter_dataset() + ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1)) + ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) + ds = dataset.WithOptionsIterDataset(ds, ds_options) + with self.assertRaises(Exception): + list(ds) + + def test_worker_init_fn(self): + def set_worker_index_and_count(worker_index: int, worker_count: int): + log_formatter = std_logging.Formatter( + f'[Worker {worker_index} out of {worker_count}] %(message)s' + ) + logging.get_absl_handler().setFormatter(log_formatter) + + def map_fn(x): + # absl logging from workers is not propagated to the main process in unit + # tests. Therefore, we manually pass the formatted log message. + record = logging.get_absl_logger().makeRecord( + 'grain', + logging.INFO, + 'grain_pool_test', + 123, + f'processing element {x}', + (), + None, + ) + return logging.get_absl_handler().format(record) + + ds = dataset.MapDataset.range(2).map(map_fn) + ds = ds.to_iter_dataset() + ds = ds.mp_prefetch( + options.MultiprocessingOptions(num_workers=2), + worker_init_fn=set_worker_index_and_count, + ) + self.assertEqual( + list(ds), + [ + '[Worker 0 out of 2] processing element 0', + '[Worker 1 out of 2] processing element 1', + ], + ) + + def test_element_spec(self): + ds = dataset.MapDataset.range(2).to_iter_dataset() + ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=2)) + spec = dataset.get_element_spec(ds) + self.assertEqual(spec.dtype, np.int64) + self.assertEqual(spec.shape, ()) + + class ThreadPrefetchIterDatasetTest(parameterized.TestCase): def setUp(self): @@ -866,6 +1413,24 @@ def test_set_state_on_fresh_iterator(self): value = next(ds_iter) self.assertEqual(value, values_without_interruption[i]) + def test_get_state_doesnt_start_prefetch(self): + event = threading.Event() + + def f(x): + event.set() + return x + + ds = dataset.MapDataset.source([1, 2, 3]).map(f).to_iter_dataset() + ds = prefetch.multithread_prefetch( + ds, + num_threads=2, + buffer_size=10, + ) + it = ds.__iter__() + it.get_state() + time.sleep(1) + self.assertFalse(event.is_set()) + def test_does_not_hang_after_stop_iteration(self): ds = dataset.MapDataset.source([1, 2, 3]).repeat(100).to_iter_dataset() ds = prefetch.multithread_prefetch( @@ -876,6 +1441,21 @@ def test_does_not_hang_after_stop_iteration(self): it = ds.__iter__() it.start_prefetch() + def test_fails_with_multiprocess_prefetch_parent(self): + ds = prefetch.MultiprocessPrefetchIterDataset( + self.ds, + options.MultiprocessingOptions(num_workers=2), + ) + with self.assertRaisesRegex( + ValueError, + 'Nesting multiprocessing or multithreading is not allowed.', + ): + _ = prefetch.multithread_prefetch( + ds, + num_threads=1, + buffer_size=1, + ) + def test_mp_context_is_set_correctly(self): num_workers = 4 ds = dataset.MapDataset.range(20).to_iter_dataset() diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py new file mode 100644 index 
000000000..840b38a90 --- /dev/null +++ b/grain/_src/python/grain_pool.py @@ -0,0 +1,838 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module provides a way to distribute processing across multiple workers. + +In the context of Grain we use the term "process" similar to JAX, where usually +each machine runs one Python process (identified by `jax.process_index()`). +In Grain each "process" can create additional Python child processes that we +call "workers". + +GrainPool manages a set of Python processes. It's similar to +`multiprocessing.Pool` but optimises communication between the processes to +enable high throughput data pipelines. +The GrainPool works as follows: +* Parent process launches a set of "num_workers" child processes. +* Each child process produces elements by reading data and transforming it. The + resulting elements are added to a queue (each child process has its queue). +* Parent process reads data from the children queues in a strict round-robin + fashion. + +Shutdown logic considerations: +* Child processes are launched as Daemon processes. In case of (unexpected) + parent termination, child processes will be terminated by OS. +* System uses a multiprocessing event ("termination_event") for termination. + Parent and child processes continuously check if the "termination_event" and + if set, they break from what they are doing. +* We never block indefinitely when calling get() or put() on a queue. This + ensures parent and child processes continue to check the termination_event. + +MultiProcessIterator wraps GrainPool adding lifecycle management, checkpointing +support and multithreaded elements read. +""" + +from __future__ import annotations + +from collections.abc import Iterator +import cProfile +import dataclasses +from multiprocessing import context +from multiprocessing import pool +from multiprocessing import queues +from multiprocessing import synchronize +import pstats +import queue +import sys +import threading +import traceback +from typing import Any, Callable, Protocol, Type, TypeVar, Union, runtime_checkable + +from absl import flags +from absl import logging +import cloudpickle +from grain._src.core import monitoring as grain_monitoring +from grain._src.core import parallel +from grain._src.core import tree_lib +from grain._src.core.config import config +import multiprocessing as mp +from grain._src.python import grain_logging +from grain._src.python import multiprocessing_common +from grain._src.python import record +from grain._src.python import shared_memory_array +from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member + +# pylint: disable=g-bare-generic +# Generic queues.Queue used without subscripting in the file. + +T = TypeVar("T") + +# Maximum number of threads for starting and stopping processes. 
+_PROCESS_MANAGEMENT_MAX_THREADS = 64 +_PROCESS_JOIN_TIMEOUT = 25 +_QUEUE_WAIT_TIMEOUT = 1 +# Input queues contain small structures (record metadata), thus they are safe +# to have a big size. +_INPUT_QUEUE_MAX_SIZE = 10000 + + +@dataclasses.dataclass +class _ProcessingComplete: + """Indicates child process finished processing.""" + + +_PROCESSING_COMPLETE = _ProcessingComplete() + + +@dataclasses.dataclass(slots=True, frozen=True) +class GrainPoolElement: + """Wrapper for output records emited by Grain Pool.""" + + record: Any + worker_index: Any + + +@dataclasses.dataclass(slots=True, frozen=True) +class RemoteWorkerError: + """Grain worker exception that can be pickled and sent over a queue.""" + error_cls: Type[Exception] + error: str + worker_index: int + + @property + def original_error(self) -> Exception: + msg = ( + f"Grain worker {self.worker_index} failed with the following" + f" error:\n\n{self.error}" + ) + # Custom exception classes can have different c'tor arguments. + try: + return self.error_cls(msg) + except Exception: # pylint: disable=broad-except + return RuntimeError(msg) + + +def _print_profile(preamble: str, profile: cProfile.Profile): + """Prints output of cProfile, sorted by cumulative time.""" + print(preamble) + stats = pstats.Stats(profile).sort_stats(pstats.SortKey.CUMULATIVE) + stats.print_stats() + + +@runtime_checkable +class GetElementProducerFn(Protocol[T]): + """A callable class able to generate elements with serialization support.""" + + def __call__( + self, + *, + worker_index: int, + worker_count: int, + start_profiling_event: synchronize.Event | None = None, + stop_profiling_event: synchronize.Event | None = None, + stats_out_queue: queues.Queue | None = None, + ) -> Iterator[T]: + """Returns a generator of elements.""" + + def serialize(self) -> bytes: + """Serializes itself and the result will be used by `deserialize`. + + If a class inherits from this class, it should make sure `deserialize` + is compatible with this `serialize` function. + i.e. `GetElementProducerFn.deserialize(obj.serialize())` should return the + same object as `obj: GetElementProducerFn`. + + Returns: + a serialized string of myself. + """ + return cloudpickle.dumps(self) + + @classmethod + def deserialize(cls, serialized: bytes) -> GetElementProducerFn[T]: + """Deserializes the result from `serialize`.""" + del cls + + obj = cloudpickle.loads(serialized) + if not isinstance(obj, GetElementProducerFn): + raise ValueError( + "`serialized` should be deserialized into `GetElementProducerFn`." 
+ ) + + return obj + + +def parse_debug_flags(debug_flags: dict[str, Any]): + """Parses debug flags.""" + + flags.FLAGS["grain_py_debug_mode"].present = True + flags.FLAGS["grain_py_dataset_visualization_output_dir"].present = True + config.update("py_debug_mode", debug_flags["grain_py_debug_mode"]) + config.update( + "py_dataset_visualization_output_dir", + debug_flags["grain_py_dataset_visualization_output_dir"], + ) + + +def _initialize_and_get_element_producer( + args_queue: queues.Queue, + *, + debug_flags: dict[str, Any], + worker_index: int, + worker_count: int, + start_profiling_event: synchronize.Event, + stop_profiling_event: synchronize.Event, + stats_out_queue: queues.Queue, +) -> Iterator[Any]: + """Unpickles the element producer from the args queue and closes the queue.""" + ( + serialized_flag_parse_fn, + serialized_init_fns, + serialized_element_producer_fn, + ) = args_queue.get() + flag_parse_fn: Callable[[Any], None] = cloudpickle.loads( + serialized_flag_parse_fn + ) + flag_parse_fn(debug_flags) + init_fns: list[Callable[[int, int], None]] = cloudpickle.loads( + serialized_init_fns + ) + for init_fn in init_fns: + init_fn(worker_index, worker_count) + element_producer_fn: GetElementProducerFn[Any] = ( + GetElementProducerFn.deserialize(serialized_element_producer_fn) + ) + + element_producer = element_producer_fn( + worker_index=worker_index, + worker_count=worker_count, + start_profiling_event=start_profiling_event, + stop_profiling_event=stop_profiling_event, + stats_out_queue=stats_out_queue, + ) + # args_queue has only a single argument and thus can be safely closed. + args_queue.close() + return element_producer + + +def _worker_loop( + *, + args_queue: queues.Queue, + errors_queue: queues.Queue, + output_queue: queues.Queue, + termination_event: synchronize.Event, + start_profiling_event: synchronize.Event, + stop_profiling_event: synchronize.Event, + worker_index: int, + worker_count: int, + enable_profiling: bool, + debug_flags: dict[str, Any], + stats_out_queue: queues.Queue, +): + """Code to be run on each child process.""" + out_of_elements = False + try: + worker_index_suffix = "" if worker_count == 1 else f" {worker_index}" + grain_logging.set_process_identifier_prefix( + f"PyGrain Worker{worker_index_suffix}" + ) + logging.info("Starting work.") + element_producer = _initialize_and_get_element_producer( + args_queue, + debug_flags=debug_flags, + worker_index=worker_index, + worker_count=worker_count, + start_profiling_event=start_profiling_event, + stop_profiling_event=stop_profiling_event, + stats_out_queue=stats_out_queue, + ) + profiling_enabled = enable_profiling and worker_index == 0 + if profiling_enabled: + profile = cProfile.Profile() + profile.enable() + # If termination event is set, we terminate and discard remaining elements. + while not termination_event.is_set(): + try: + next_element = next(element_producer) + if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + next_element, output_queue, termination_event.is_set + ): + # We failed to put the element into the output queue because the + # termination event was set. The element may contain a shared memory + # block reference that has to be cleaned up. 
+ _unlink_shm_in_structure(next_element) + except StopIteration: + out_of_elements = True + multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + _ProcessingComplete(), output_queue, termination_event.is_set + ) + break + if profiling_enabled: + profile.disable() + _print_profile(f"PROFILE OF PROCESS WITH IDX {worker_index}.", profile) + + except Exception as e: # pylint: disable=broad-except + logging.exception( + "Error occurred in child process with worker_index: %i", worker_index + ) + remote_error = RemoteWorkerError( + error_cls=e.__class__, + error="".join( + traceback.format_exception(e.__class__, e, e.__traceback__) + ), + worker_index=worker_index, + ) + try: + errors_queue.put(remote_error, timeout=_QUEUE_WAIT_TIMEOUT) + except queue.Full: + logging.error("Couldn't send exception from child process. Queue full!") + + logging.info( + "Setting termination event in process with worker_index: %i", + worker_index, + ) + termination_event.set() + + if termination_event.is_set(): + if not out_of_elements: + # Since the termination event is set the consumer will not get any more + # elements from the output queue. The elements may contain reference to + # shared memory blocks that have to be cleaned up. + while not output_queue.empty(): + _unlink_shm_in_structure(output_queue.get_nowait()) + # When adding elements to the queue, element is put in a buffer and a + # background thread flushes the elements through the pipe. The process that + # writes to the queue joins that thread automatically on exit. We call + # cancel_join_thread when system terminates to prevent deadlocks. + output_queue.cancel_join_thread() + output_queue.close() + logging.info("Process %i exiting.", worker_index) + + +def _unlink_shm_if_metadata(obj: Any): + if isinstance(obj, shared_memory_array.SharedMemoryArrayMetadata): + obj.close_and_unlink_shm() + + +def _unlink_shm_in_structure(structure: Any): + if isinstance(structure, record.Record): + _unlink_shm_in_structure(structure.data) + else: + tree_lib.map_structure(_unlink_shm_if_metadata, structure) + + +class GrainPool(Iterator[T]): + """Pool to parallelize processing of Grain pipelines among a set of processes.""" + + def __init__( + self, + ctx: context.BaseContext, + *, + get_element_producer_fn: GetElementProducerFn[T], + worker_index_to_start_reading: int = 0, + termination_event: threading.Event | None = None, + start_profiling_event: synchronize.Event | None = None, + stop_profiling_event: synchronize.Event | None = None, + options: MultiprocessingOptions, + worker_init_fn: Callable[[int, int], None] | None = None, + stats_in_queues: tuple[queues.Queue, ...] | None = None, + ): + """Initialise a Grain Pool. + + Args: + ctx: Context to make multiprocessing primitives work. + get_element_producer_fn: Callable that returns an iterator over the + elements given the process index and process count. + worker_index_to_start_reading: index of worker to start reading output + batches from (needed for checkpointing support). + termination_event: Setting this event will terminate the pool. Otherwise, + the pool will terminate when either one of the workers failed or when + all workers are done processing data. GrainPool will not set this event. + start_profiling_event: Event to start prism profiling. + stop_profiling_event: Event to stop prism profiling. + options: Options for multiprocessing. See MultiprocessingOptions. + worker_init_fn: Function to run in each worker process before the element + producer. 
The function takes two arguments: the current worker index and
+        the total worker count.
+      stats_in_queues: Queues to propagate execution summaries from child
+        processes to the parent.
+    """
+    self.num_processes = options.num_workers
+    logging.info("Grain pool will use %i processes.", self.num_processes)
+    self.worker_args_queues = []
+    self.worker_output_queues = []
+    self.processes = []
+    # Reader termination should always result in worker termination. However,
+    # worker termination should not shut down the reader: workers terminate
+    # when they finish processing data, but the reader may still need to read
+    # the remaining output from the shared queues. That is why we use two
+    # separate events.
+    self._reader_termination_event = termination_event or threading.Event()
+    self._workers_termination_event = ctx.Event()
+    self._worker_init_fn = worker_init_fn
+    self.completed_processes = set()
+    # Queue to propagate errors from child processes to the parent. Note that
+    # this queue is shared by all child processes.
+    self.worker_error_queue = ctx.Queue(self.num_processes)
+    self.stats_in_queues = stats_in_queues
+
+    try:
+      get_element_producer_fn = get_element_producer_fn.serialize()
+    except Exception as e:
+      if sys.version_info >= (3, 11):
+        e.add_note(
+            "\nCould not serialize transformation function passed to Grain "
+            "workers. This likely means that your data source, sampler or one "
+            "of your transformations cannot be serialized. Please make sure "
+            "that the objects work with cloudpickle.dumps()."
+        )
+      raise e
+
+    for worker_index in range(self.num_processes):
+      worker_args_queue = ctx.Queue(1)
+      worker_output_queue = ctx.Queue(options.per_worker_buffer_size)
+      process_kwargs = dict(
+          args_queue=worker_args_queue,
+          errors_queue=self.worker_error_queue,
+          output_queue=worker_output_queue,
+          stats_out_queue=(
+              self.stats_in_queues[worker_index]
+              if self.stats_in_queues
+              else None
+          ),
+          termination_event=self._workers_termination_event,
+          start_profiling_event=start_profiling_event,
+          stop_profiling_event=stop_profiling_event,
+          worker_index=worker_index,
+          worker_count=options.num_workers,
+          enable_profiling=options.enable_profiling,
+          debug_flags=dict(
+              grain_py_debug_mode=config.get_or_default("py_debug_mode"),
+              grain_py_dataset_visualization_output_dir=(
+                  config.get_or_default("py_dataset_visualization_output_dir")
+              ),
+          ),
+      )
+      # The process kwargs must all be picklable and will be unpickled before
+      # absl.app.run() is called. We send arguments via a queue to ensure that
+      # they are unpickled after absl.app.run() was called in the child
+      # processes.
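+      # cloudpickle (rather than the standard pickle module) is used so that
+      # closures and locally defined classes can also be shipped to the
+      # workers.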
+ worker_init_fns = [self._worker_init_fn] if self._worker_init_fn else [] + parse_debug_flags_fn = parse_debug_flags + worker_init_fns = cloudpickle.dumps(worker_init_fns) + parse_debug_flags_fn = cloudpickle.dumps(parse_debug_flags_fn) + worker_args_queue.put( + (parse_debug_flags_fn, worker_init_fns, get_element_producer_fn) + ) + process = ctx.Process( # pytype: disable=attribute-error # re-none + target=_worker_loop, kwargs=process_kwargs, daemon=True + ) + self.worker_args_queues.append(worker_args_queue) + self.worker_output_queues.append(worker_output_queue) + self.processes.append(process) + + logging.info("Grain pool will start child processes.") + parallel.run_in_parallel( + function=lambda child_process: child_process.start(), + list_of_kwargs_to_function=[ + {"child_process": p} for p in self.processes + ], + num_workers=min(_PROCESS_MANAGEMENT_MAX_THREADS, self.num_processes), + ) + logging.info("Grain pool started all child processes.") + self._next_worker_index = worker_index_to_start_reading + + def __iter__(self) -> GrainPool: + return self + + def _process_failed(self, worker_index: int) -> bool: + exit_code = self.processes[worker_index].exitcode + return exit_code is not None and exit_code != 0 + + def _processing_completed(self) -> bool: + return all(p.exitcode == 0 for p in self.processes) + + def _update_next_worker_index(self) -> None: + self._next_worker_index = (self._next_worker_index + 1) % self.num_processes + + def __next__(self) -> GrainPoolElement: + processing_failed = False + while ( + not self._workers_termination_event.is_set() + and len(self.completed_processes) < self.num_processes + ): + # If the reader was shut down, e.g. due to iterator deletion, we should + # shut down the workers. + if self._reader_termination_event.is_set(): + self._shutdown() + # Since the reader is shut down it doesn't matter what we return here. + # We should not raise an exception because it is common to iterate over + # infinite datasets and delete the iterator before processing is + # complete. + return GrainPoolElement( + "Grain worker pool reader was terminated, shutting down workers.", + -1, + ) + if self._next_worker_index in self.completed_processes: + self._update_next_worker_index() + continue + try: + element_worker_index = self._next_worker_index + element = self.worker_output_queues[self._next_worker_index].get( + timeout=_QUEUE_WAIT_TIMEOUT + ) + logging.debug("Read element from process: %s", self._next_worker_index) + if element == _PROCESSING_COMPLETE: + logging.info( + "Processing complete for process with worker_index %i", + self._next_worker_index, + ) + self.completed_processes.add(self._next_worker_index) + self._update_next_worker_index() + else: + self._update_next_worker_index() + return GrainPoolElement(element, element_worker_index) + except queue.Empty: + logging.debug("Got no element from process %s", self._next_worker_index) + if self._process_failed(self._next_worker_index): + processing_failed = True + logging.info( + "Process with idx %i Failed (Exitcode: %s).", + self._next_worker_index, + self.processes[self._next_worker_index].exitcode, + ) + break + + if processing_failed or self._workers_termination_event.is_set(): + logging.error("Processing Failed. Shutting down.") + self._shutdown() + + try: + remote_error = self.worker_error_queue.get(timeout=_QUEUE_WAIT_TIMEOUT) + raise remote_error.original_error + except queue.Empty: + # Worker did not report any error. 
This means that either an exception + # was raised outside of the worker loop (e.g. during flag parsing) or + # the worker process was forcefully terminated. Unfortunately, there is + # no debugging info available in the main process at this point apart + # from the exit code. The crash logs, however, should've been produced. + raise RuntimeError( + f"Grain worker process {self._next_worker_index} was terminated" + " unexpectedly with exit code " + f"{self.processes[self._next_worker_index].exitcode}. Search the " + "logs above for the source of the crash." + ) from None + + # Processing successfully completed. + raise StopIteration + + def __del__(self): + self._shutdown() + + def __enter__(self) -> GrainPool: + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + logging.info("Grain pool is exiting.") + self._shutdown() + + def _shutdown(self) -> None: + """Gracefully shutdown the multiprocessing system.""" + logging.info("Shutting down multiprocessing system.") + try: + self._workers_termination_event.set() + # There is a chance that shutdown was triggered before the worker + # processes fully initialized and read from the arg queues. The arg + # queues will block the main process until their elements are flushed + # through the pipes, which will never happen since the workers were shut + # down. Here we avoid blocking the main process, see + # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue.cancel_join_thread + for q in self.worker_args_queues: + q.cancel_join_thread() + q.close() + # Not joining here will cause the children to be zombie after they finish. + # Need to join or call active_children. + for process in self.processes: + process.join(timeout=_PROCESS_JOIN_TIMEOUT) + finally: + for process in self.processes: + # In case all our attempts to terminate the system fails, we forcefully + # kill the child processes. 
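+        # Process.kill() sends SIGKILL on POSIX, so the child gets no chance
+        # to run any cleanup handlers.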
+        if process.is_alive():
+          logging.info("Killing worker process with pid %i", process.pid)
+          process.kill()
+
+
+@dataclasses.dataclass(slots=True, frozen=True)
+class _ReaderQueueElement:
+  """Element to be added to the reader queue."""
+
+  async_result: pool.AsyncResult[Any]
+  # Index of the worker producing the element, in [0, worker_count).
+  worker_index: int
+
+
+@dataclasses.dataclass(frozen=True)
+class _GrainPoolProcessingComplete:
+  """Indicates processing of grain pool is complete."""
+
+
+_GRAIN_POOL_PROCESSING_COMPLETE = _GrainPoolProcessingComplete()
+_QueueElement = Union[
+    _ReaderQueueElement, _GrainPoolProcessingComplete, Exception
+]
+
+
+def _open_shared_memory_for_leaf(element: Any) -> Any:
+  if isinstance(element, shared_memory_array.SharedMemoryArrayMetadata):
+    element = shared_memory_array.SharedMemoryArray.from_metadata(element)
+    element.unlink_on_del()
+  return element
+
+
+def _open_shared_memory_for_structure(structure: Any) -> Any:
+  if isinstance(structure, record.Record):
+    structure.data = tree_lib.map_structure(
+        _open_shared_memory_for_leaf, structure.data
+    )
+    return structure
+  return tree_lib.map_structure(_open_shared_memory_for_leaf, structure)
+
+
+def _process_elements_in_grain_pool(
+    *,
+    get_element_producer_fn: GetElementProducerFn,
+    multiprocessing_options: MultiprocessingOptions,
+    reader_queue: queue.Queue[_QueueElement],
+    thread_pool: pool.ThreadPool,
+    termination_event: threading.Event,
+    start_profiling_event: synchronize.Event | None,
+    stop_profiling_event: synchronize.Event | None,
+    worker_index_to_start_reading: int,
+    worker_init_fn: Callable[[int, int], None] | None,
+    stats_in_queues: tuple[queues.Queue, ...] | None,
+) -> None:
+  """Processes elements in grain worker pool asynchronously."""
+
+  def read_thread_should_stop():
+    return termination_event.is_set() or not threading.main_thread().is_alive()
+
+  ctx = mp.get_context("spawn")
+
+  try:
+    with GrainPool(
+        ctx=ctx,
+        get_element_producer_fn=get_element_producer_fn,
+        worker_index_to_start_reading=worker_index_to_start_reading,
+        termination_event=termination_event,
+        start_profiling_event=start_profiling_event,
+        stop_profiling_event=stop_profiling_event,
+        options=multiprocessing_options,
+        worker_init_fn=worker_init_fn,
+        stats_in_queues=stats_in_queues,
+    ) as g_pool:
+      for element in g_pool:
+        if read_thread_should_stop():
+          break
+        # Note: We use a thread pool for opening the shared memory because
+        # in some cases the calls to `shm_open` can actually become the
+        # bottleneck for a single thread.
+        async_result = thread_pool.apply_async(
+            _open_shared_memory_for_structure,
+            args=(element.record,),
+        )
+        multiprocessing_common.add_element_to_queue(
+            _ReaderQueueElement(
+                async_result,
+                element.worker_index,
+            ),
+            reader_queue,
+            read_thread_should_stop,
+        )
+  # This exception could arise from user-provided code. Propagate it to
+  # the main thread to re-raise it as is.
+  except Exception as e:  # pylint: disable=broad-except
+    multiprocessing_common.add_element_to_queue(
+        e, reader_queue, read_thread_should_stop
+    )
+    return
+  multiprocessing_common.add_element_to_queue(
+      _GrainPoolProcessingComplete(),
+      reader_queue,
+      read_thread_should_stop,
+  )
+
+
+class MultiProcessIteratorInvalidStateError(Exception):
+  """Raised when the iterator is in an invalid state and can't be iterated on."""
+
+
+class MultiProcessIterator(Iterator[T]):
+  """Runs iterators returned by `get_element_producer_fn` in child processes.
+ + Note: MultiProcessIterator implements the Context Manager protocol to clean + resources. As such, it must be used within a "with" statement. + + Wraps `GrainPool` adding lifecycle management, multithreaded elements read and + recording the last worker index useful for checkpointing. + """ + + def __init__( + self, + get_element_producer_fn: GetElementProducerFn, + multiprocessing_options: MultiprocessingOptions, + worker_index_to_start_reading: int, + worker_init_fn: Callable[[int, int], None] | None = None, + start_profiling_event: synchronize.Event | None = None, + stop_profiling_event: synchronize.Event | None = None, + stats_in_queues: tuple[queues.Queue, ...] | None = None, + ): + """Initializes MultiProcessIterator. + + Args: + get_element_producer_fn: factory making record iterators for each child + process. + multiprocessing_options: options for distributing the record iterators. + worker_index_to_start_reading: Index of the next worker to read from. This + is useful for recovering from a checkpoint. + worker_init_fn: Function to run in each worker process before the element + producer. The function takes two arguments: the current worker index and + the total worker count. + start_profiling_event: Event to start prism profiling. + stop_profiling_event: Event to stop prism profiling. + stats_in_queues: Queues to send execution summaries from worker processes + to the main process. + """ + self._get_element_producer_fn = get_element_producer_fn + self._multiprocessing_options = multiprocessing_options + self._last_worker_index = worker_index_to_start_reading - 1 + self._worker_init_fn = worker_init_fn + self._reader_queue = None + self._reader_thread_pool = None + self._termination_event = None + self._reader_thread = None + self._stats_in_queues = stats_in_queues + self._start_profiling_event = start_profiling_event + self._stop_profiling_event = stop_profiling_event + + def __del__(self): + if self._reader_thread: + logging.info("Destroying multiprocess iterator.") + self.stop_prefetch() + + def start_prefetch(self) -> None: + """Starts the prefetching threads.""" + + if self._reader_thread: + return + + max_buffered_elements = ( + self._multiprocessing_options.num_workers + * self._multiprocessing_options.per_worker_buffer_size + ) + self._reader_queue = queue.Queue(maxsize=max_buffered_elements) + self._reader_thread_pool = pool.ThreadPool(max_buffered_elements) + self._termination_event = threading.Event() + self._reader_thread = threading.Thread( + target=_process_elements_in_grain_pool, + kwargs=dict( + get_element_producer_fn=self._get_element_producer_fn, + multiprocessing_options=self._multiprocessing_options, + reader_queue=self._reader_queue, + thread_pool=self._reader_thread_pool, + termination_event=self._termination_event, + start_profiling_event=self._start_profiling_event, + stop_profiling_event=self._stop_profiling_event, + worker_index_to_start_reading=self._last_worker_index + 1, + worker_init_fn=self._worker_init_fn, + stats_in_queues=self._stats_in_queues, + ), + ) + self._reader_thread.start() + shared_memory_array.SharedMemoryArray.enable_async_del( + self._multiprocessing_options.num_workers + ) + + def stop_prefetch(self) -> None: + """Cleans up prefetching threads.""" + + if not self._reader_thread: + return + + # pytype: disable=attribute-error + self._termination_event.set() + self._reader_thread_pool.close() + self._reader_thread.join() + self._reader_thread_pool.join() + # pytype: enable=attribute-error + self._termination_event = None + 
self._reader_thread_pool = None + self._reader_thread = None + self._reader_queue = None + + def __enter__(self): + self.start_prefetch() + return self + + def __exit__(self, exc_type, exc_value, tb): + self.stop_prefetch() + + def _can_iterate(self): + """Checks whether the object is in a state where it can be iterated on.""" + return ( + self._reader_queue is not None + and self._termination_event is not None + and self._reader_thread_pool is not None + and self._reader_thread is not None + ) + + def __iter__(self): + if not self._can_iterate(): + raise MultiProcessIteratorInvalidStateError( + "MultiProcessIterator is in an invalid state. Note that" + " MultiProcessIterator should be used with a 'with' statement." + ) + return self + + def get_last_worker_index(self): + return self._last_worker_index + + def __next__(self): + if not self._can_iterate(): + raise MultiProcessIteratorInvalidStateError( + "MultiProcessIterator is in an invalid state. Note that" + " MultiProcessIterator should be used with a 'with' statement." + ) + element = multiprocessing_common.get_element_from_queue( + self._reader_queue, self._termination_event.is_set # pytype: disable=attribute-error + ) + if isinstance(element, Exception): + raise element + if ( + element == _GRAIN_POOL_PROCESSING_COMPLETE + or element == multiprocessing_common.SYSTEM_TERMINATED + ): + raise StopIteration + + if not isinstance(element, _ReaderQueueElement): + raise ValueError( + f"Got invalid element type from GrainPool: {type(element)}" + ) + + result = multiprocessing_common.get_async_result( + element.async_result, self._termination_event.is_set + ) + if isinstance(result, multiprocessing_common._SystemTerminated): # pylint: disable=protected-access + raise StopIteration + self._last_worker_index = element.worker_index + return result + +# pylint: enable=g-bare-generic diff --git a/grain/_src/python/grain_pool_test.py b/grain/_src/python/grain_pool_test.py new file mode 100644 index 000000000..93a39f61a --- /dev/null +++ b/grain/_src/python/grain_pool_test.py @@ -0,0 +1,474 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Iterator +import multiprocessing +import os +import platform +import signal +import sys +from typing import Any +from absl import flags +from absl.testing import absltest +from absl.testing import parameterized +from grain._src.core import config +from grain._src.core import monitoring as grain_monitoring +import multiprocessing as mp +from grain._src.python import data_sources +from grain._src.python import grain_pool as gp +from grain._src.python import record +from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member + +# pylint: disable=g-bare-generic +# Generic queues.Queue used without subscripting in the file. 
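+
+# The tests below share a common pattern: subclass gp.GetElementProducerFn,
+# implement __call__ to yield the worker's share of the data, and drive the
+# pool from a "with" statement. A minimal sketch of that pattern (the names
+# _Producer, options and records are illustrative only):
+#
+#   class _Producer(gp.GetElementProducerFn):
+#
+#     def __call__(self, *, worker_index: int, worker_count: int, **kwargs):
+#       return iter(range(worker_index, 12, worker_count))
+#
+#   options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1)
+#   with gp.GrainPool(
+#       ctx=mp.get_context("spawn"),
+#       get_element_producer_fn=_Producer(),
+#       options=options,
+#   ) as grain_pool:
+#     records = [element.record for element in grain_pool]  # round-robin order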
+ + +class GrainPoolTest(absltest.TestCase): + + def _join_and_assert_process_exitcode(self, process: multiprocessing.Process): + # The process can be potentially terminated forcibly and needs a moment to + # finalize and update the exitcode. + process.join(timeout=gp._PROCESS_JOIN_TIMEOUT) + self.assertIn(process.exitcode, {0, -signal.SIGTERM}) + + def test_pool_with_flags_not_parsed(self): + class GetElementProducerFn(gp.GetElementProducerFn): + + def __call__(self, *, worker_index: int, worker_count: int, **kwargs): + del self + return iter(range(worker_index, 14, worker_count)) + + get_element_producer_fn = GetElementProducerFn() + # unparse the flags explicitly + flags.FLAGS.unparse_flags() + + _ = gp.GrainPool( + ctx=mp.get_context("spawn"), + get_element_producer_fn=get_element_producer_fn, + options=MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1), + ) + + def test_pool_equal_split_in_memory_data_source(self): + in_memory_ds = data_sources.SharedMemoryDataSource(range(12)) + + # 12 elements in the `in_memory_ds` are divided + # equally among 4 processes. + class GetElementProducerFn(gp.GetElementProducerFn): + + def __call__(self, *, worker_index: int, worker_count: int, **kwargs): + del self + return iter(range(worker_index, 12, worker_count)) + + get_element_producer_fn = GetElementProducerFn() + + output_elements = [] + with gp.GrainPool( + ctx=mp.get_context("spawn"), + get_element_producer_fn=get_element_producer_fn, + options=MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1), + ) as grain_pool: + for element in grain_pool: + output_elements.append(element) + # turn each element in `in_memory_ds` to their negatives. + in_memory_ds[element.record] = -in_memory_ds[element.record] + + self.assertEqual( + output_elements, [gp.GrainPoolElement(x, x % 4) for x in range(12)] + ) + + self.assertEqual(list(iter(in_memory_ds)), [-x for x in range(12)]) + + def test_pool_equal_split(self): + ctx = mp.get_context("spawn") + + # 16 elements divide equally among 4 processes + class GetElementProducerFn(gp.GetElementProducerFn): + + def __call__(self, *, worker_index: int, worker_count: int, **kwargs): + del self + return iter(range(worker_index, 16, worker_count)) + + get_element_producer_fn = GetElementProducerFn() + + options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + output_elements = [] + with gp.GrainPool( + ctx=ctx, + get_element_producer_fn=get_element_producer_fn, + options=options, + ) as grain_pool: + for element in grain_pool: + output_elements.append(element) + expected_elements = list( + map( + lambda x: gp.GrainPoolElement(x, x % options.num_workers), range(16) + ) + ) + self.assertEqual(expected_elements, output_elements) + # Make sure num_processes processes were launched. + self.assertLen(grain_pool.processes, options.num_workers) + # Make sure all child processes exited successfully. 
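+    # (_join_and_assert_process_exitcode also accepts -SIGTERM, since a child
+    # may be terminated forcibly while the pool is shutting down.)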
+ for child_process in grain_pool.processes: + self._join_and_assert_process_exitcode(child_process) + + def test_pool_non_equal_split(self): + ctx = mp.get_context("spawn") + + # 14 elements do not divide equally among 4 processes + class GetElementProducerFn(gp.GetElementProducerFn): + + def __call__(self, *, worker_index: int, worker_count: int, **kwargs): + del self + return iter(range(worker_index, 14, worker_count)) + + get_element_producer_fn = GetElementProducerFn() + + options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + output_elements = [] + with gp.GrainPool( + ctx=ctx, + get_element_producer_fn=get_element_producer_fn, + options=options, + ) as grain_pool: + for element in grain_pool: + output_elements.append(element) + expected_elements = list( + map( + lambda x: gp.GrainPoolElement(x, x % options.num_workers), range(14) + ) + ) + self.assertEqual(expected_elements, output_elements) + # Make sure all child processes exited successfully. + for child_process in grain_pool.processes: + self._join_and_assert_process_exitcode(child_process) + + @absltest.skipIf( + platform.system() == "Windows", "SIGKILL signal not available on Windows." + ) + def test_pool_kill_child(self): + ctx = mp.get_context("spawn") + + class GetElementProducerFn(gp.GetElementProducerFn): + + def __call__(self, *, worker_index: int, worker_count: int, **kwargs): + del self + return iter(range(worker_index, 14, worker_count)) + + get_element_producer_fn = GetElementProducerFn() + + options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + with gp.GrainPool( + ctx=ctx, + get_element_producer_fn=get_element_producer_fn, + options=options, + ) as grain_pool: + child_pid = grain_pool.processes[0].pid + os.kill(child_pid, signal.SIGKILL) + + self.assertEqual( + grain_pool.processes[0].exitcode, -1 * signal.SIGKILL.value + ) + for child_process in grain_pool.processes[1:]: + self._join_and_assert_process_exitcode(child_process) + + def test_pool_object_deletion(self): + ctx = mp.get_context("spawn") + + class GetElementProducerFn(gp.GetElementProducerFn): + + def __call__(self, *, worker_index: int, worker_count: int, **kwargs): + del self + return iter(range(worker_index, 14, worker_count)) + + get_element_producer_fn = GetElementProducerFn() + + options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + + # Users should generally use the with statement, here we test if GrainPool + # was created without the "with statement", that object deletion would + # have child processes gracefully exited. 
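+    # (GrainPool.__del__ calls _shutdown, which sets the workers' termination
+    # event and joins the child processes.)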
+ grain_pool = gp.GrainPool( + ctx=ctx, + get_element_producer_fn=get_element_producer_fn, + options=options, + ) + + child_processes = grain_pool.processes + grain_pool.__del__() + + for child_process in child_processes: + self._join_and_assert_process_exitcode(child_process) + + +def _make_uniform_element_producer_fn( + last_seen_index: int = -1, +) -> gp.GetElementProducerFn: + + class _RoundrobinElementProducerFn(gp.GetElementProducerFn): + + def __call__( + self, *, worker_index: int, worker_count: int, **kwargs + ) -> Iterator[int]: + del self + yield from range(10)[last_seen_index + 1 + worker_index :: worker_count] + + return _RoundrobinElementProducerFn() + + +class RoundrobinRecordElementProducerFn(gp.GetElementProducerFn): + + def __call__( + self, *, worker_index: int, worker_count: int, **kwargs + ) -> Iterator[record.Record[int]]: + del self + for i in range(5)[worker_index::worker_count]: + yield record.Record(record.RecordMetadata(i), i) + + +class NonUniformElementProducerFn(gp.GetElementProducerFn): + + def __call__( + self, *, worker_index: int, worker_count: int, **kwargs + ) -> Iterator[int]: + del self, worker_count + for _ in range(worker_index * 3): + yield worker_index + + +class MultiProcessIteratorTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name="two_workers", + get_element_producer_fn=_make_uniform_element_producer_fn(), + multiprocessing_options=MultiprocessingOptions(num_workers=2), + worker_index_to_start_reading=0, + expected=list(range(10)), + ), + dict( + testcase_name="five_workers", + get_element_producer_fn=_make_uniform_element_producer_fn(), + multiprocessing_options=MultiprocessingOptions(num_workers=5), + worker_index_to_start_reading=0, + expected=list(range(10)), + ), + dict( + testcase_name="from_checkpoint", + get_element_producer_fn=_make_uniform_element_producer_fn(5), + multiprocessing_options=MultiprocessingOptions(num_workers=2), + worker_index_to_start_reading=1, + expected=[7, 6, 9, 8], + ), + dict( + testcase_name="non_uniform", + get_element_producer_fn=NonUniformElementProducerFn(), + multiprocessing_options=MultiprocessingOptions(num_workers=3), + worker_index_to_start_reading=0, + expected=[1, 2, 1, 2, 1, 2, 2, 2, 2], + ), + dict( + testcase_name="record_producer_fn", + get_element_producer_fn=RoundrobinRecordElementProducerFn(), + multiprocessing_options=MultiprocessingOptions(num_workers=3), + worker_index_to_start_reading=0, + expected=[ + record.Record(record.RecordMetadata(i), i) for i in range(5) + ], + ), + ) + def test_produces_correct_data( + self, + get_element_producer_fn: gp.GetElementProducerFn, + multiprocessing_options: MultiprocessingOptions, + worker_index_to_start_reading: int, + expected: Any, + ): + with gp.MultiProcessIterator( + get_element_producer_fn, + multiprocessing_options, + worker_index_to_start_reading, + ) as iterator: + actual = list(iterator) + self.assertEqual(actual, expected) + + @parameterized.named_parameters( + dict( + testcase_name="two_workers", + get_element_producer_fn=_make_uniform_element_producer_fn(), + multiprocessing_options=MultiprocessingOptions(num_workers=2), + worker_index_to_start_reading=1, + num_iters=5, + expected_last_worker_index=1, + ), + dict( + testcase_name="five_workers", + get_element_producer_fn=_make_uniform_element_producer_fn(), + multiprocessing_options=MultiprocessingOptions(num_workers=5), + worker_index_to_start_reading=0, + num_iters=7, + expected_last_worker_index=1, + ), + dict( + 
testcase_name="five_workers_incomplete_round", + get_element_producer_fn=_make_uniform_element_producer_fn(), + multiprocessing_options=MultiprocessingOptions(num_workers=5), + worker_index_to_start_reading=0, + num_iters=3, + expected_last_worker_index=2, + ), + dict( + testcase_name="from_checkpoint", + get_element_producer_fn=_make_uniform_element_producer_fn(5), + multiprocessing_options=MultiprocessingOptions(num_workers=2), + worker_index_to_start_reading=0, + num_iters=3, + expected_last_worker_index=0, + ), + dict( + testcase_name="non_uniform_record_producer_fn", + get_element_producer_fn=NonUniformElementProducerFn(), + multiprocessing_options=MultiprocessingOptions(num_workers=3), + worker_index_to_start_reading=0, + num_iters=6, + expected_last_worker_index=2, + ), + ) + def test_get_state( + self, + get_element_producer_fn: gp.GetElementProducerFn, + multiprocessing_options: MultiprocessingOptions, + worker_index_to_start_reading: int, + num_iters: int, + expected_last_worker_index: int, + ): + with gp.MultiProcessIterator( + get_element_producer_fn, + multiprocessing_options, + worker_index_to_start_reading, + ) as iterator: + for _ in range(num_iters): + _ = next(iterator) + actual_last_worker_index = iterator.get_last_worker_index() + self.assertEqual(actual_last_worker_index, expected_last_worker_index) + + def test_fails_with_zero_workers(self): + with self.assertRaisesRegex( + ValueError, "Number of processes must be at least 1" + ): + with gp.MultiProcessIterator( + _make_uniform_element_producer_fn(), + MultiprocessingOptions(num_workers=0), + 0, + ) as iterator: + list(iterator) + + def test_propagates_error(self): + error_msg = "very unique error" + + class FailingGetElementProducerFn(gp.GetElementProducerFn): + + def __call__( + self, *, worker_index: int, worker_count: int, **kwargs + ) -> Iterator[int]: + del self, worker_index, worker_count + raise ValueError(error_msg) + + failing_get_element_producer_fn = FailingGetElementProducerFn() + + with gp.MultiProcessIterator( + failing_get_element_producer_fn, + MultiprocessingOptions(num_workers=2), + 0, + ) as iterator: + with self.assertRaisesRegex(ValueError, error_msg): + list(iterator) + + def test_reports_worker_crash(self): + + class FailingGetElementProducerFn(gp.GetElementProducerFn): + + def __call__( + self, *, worker_index: int, worker_count: int, **kwargs + ) -> Iterator[int]: + del self, worker_index, worker_count + sys.exit(12) + + failing_get_element_producer_fn = FailingGetElementProducerFn() + + with gp.MultiProcessIterator( + failing_get_element_producer_fn, + MultiprocessingOptions(num_workers=2), + 0, + ) as iterator: + with self.assertRaisesRegex( + RuntimeError, "was terminated unexpectedly with exit code 12" + ): + list(iterator) + + def test_reports_unpicklable_element_producer_fn(self): + error_msg = "UnpicklableObject is not picklable" + + class UnpicklableObject: + + def __getstate__(self): + raise ValueError(error_msg) + + local_state = UnpicklableObject() + + class GetElementProducerFnWithUnpicklableClosure(gp.GetElementProducerFn): + + def __call__( + self, *, worker_index: int, worker_count: int, **kwargs + ) -> Iterator[int]: + del self, worker_index, worker_count + yield 1 if local_state is None else 2 + + get_element_producer_fn_with_unpicklable_closure = ( + GetElementProducerFnWithUnpicklableClosure() + ) + + with gp.MultiProcessIterator( + get_element_producer_fn_with_unpicklable_closure, + MultiprocessingOptions(num_workers=2), + 0, + ) as iterator: + with 
self.assertRaisesRegex(ValueError, error_msg): + list(iterator) + + def test_worker_init_fn(self): + + def _set_worker_index_and_count(worker_index: int, worker_count: int): + gp.monkey_patched_index_and_count = (worker_index, worker_count) + + class GetElementProducerFnReturningGlobal(gp.GetElementProducerFn): + + def __call__( + self, *, worker_index: int, worker_count: int, **kwargs + ) -> Iterator[tuple[int, int]]: + del self, worker_index, worker_count + yield gp.monkey_patched_index_and_count # pytype: disable=module-attr + + with gp.MultiProcessIterator( + GetElementProducerFnReturningGlobal(), + MultiprocessingOptions(num_workers=2), + 0, + worker_init_fn=_set_worker_index_and_count, + ) as iterator: + result = list(iterator) + self.assertEqual(result, [(0, 2), (1, 2)]) + +# pylint: enable=g-bare-generic + +if __name__ == "__main__": + absltest.main()