From d4477d90a8432759c4c7aee7b297e6d66ef8f626 Mon Sep 17 00:00:00 2001 From: Ihor Indyk Date: Thu, 19 Feb 2026 14:58:10 -0800 Subject: [PATCH] Move ipc-related modules into a separate folder. PiperOrigin-RevId: 872579501 --- grain/BUILD | 2 +- grain/_src/python/BUILD | 72 +------------------ grain/_src/python/data_loader.py | 21 +++--- grain/_src/python/data_loader_test.py | 2 +- grain/_src/python/dataset/BUILD | 10 +-- grain/_src/python/dataset/stats_test.py | 2 +- grain/_src/python/dataset/stats_utils.py | 2 +- grain/_src/python/dataset/stats_utils_test.py | 2 +- .../transformations/process_prefetch.py | 12 ++-- grain/_src/python/ipc/BUILD | 68 ++++++++++++++++++ .../queue.py} | 2 +- .../queue_test.py} | 2 +- .../python/{ => ipc}/shared_memory_array.py | 1 + .../{ => ipc}/shared_memory_array_test.py | 31 ++++---- .../python/{ => ipc}/variable_size_queue.py | 0 .../{ => ipc}/variable_size_queue_test.py | 2 +- grain/_src/python/operations.py | 6 +- grain/multiprocessing.py | 2 +- grain/python/__init__.py | 3 +- 19 files changed, 124 insertions(+), 118 deletions(-) create mode 100644 grain/_src/python/ipc/BUILD rename grain/_src/python/{multiprocessing_common.py => ipc/queue.py} (98%) rename grain/_src/python/{multiprocessing_common_test.py => ipc/queue_test.py} (98%) rename grain/_src/python/{ => ipc}/shared_memory_array.py (99%) rename grain/_src/python/{ => ipc}/shared_memory_array_test.py (93%) rename grain/_src/python/{ => ipc}/variable_size_queue.py (100%) rename grain/_src/python/{ => ipc}/variable_size_queue_test.py (99%) diff --git a/grain/BUILD b/grain/BUILD index 5fe820eb7..e1441364f 100644 --- a/grain/BUILD +++ b/grain/BUILD @@ -54,7 +54,7 @@ py_library( "//grain/_src/python:options", "//grain/_src/python:record", "//grain/_src/python:samplers", - "//grain/_src/python:shared_memory_array", + "//grain/_src/python/ipc:shared_memory_array", "//grain/_src/python/dataset", "//grain/_src/python/dataset:base", "//grain/_src/python/dataset:elastic_iterator", diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index 3fe798268..45cb690fe 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -68,26 +68,6 @@ py_test( ], ) -py_library( - name = "multiprocessing_common", - srcs = [ - "multiprocessing_common.py", - ], - srcs_version = "PY3", -) - -py_test( - name = "multiprocessing_common_test", - srcs = [ - "multiprocessing_common_test.py", - ], - srcs_version = "PY3", - deps = [ - ":multiprocessing_common", - "@abseil-py//absl/testing:absltest", - ], -) - py_library( name = "operations", srcs = [ @@ -96,8 +76,8 @@ py_library( srcs_version = "PY3", deps = [ ":record", - ":shared_memory_array", "//grain/_src/core:tree_lib", + "//grain/_src/python/ipc:shared_memory_array", "@abseil-py//absl/logging", "@pypi//numpy:pkg", ], @@ -159,7 +139,6 @@ py_library( ":options", ":record", ":samplers", - ":shared_memory_array", "//grain/_src/core:monitoring", "//grain/_src/core:sharding", "//grain/_src/core:transforms", @@ -167,7 +146,7 @@ py_library( "//grain/_src/python/checkpoint:base", "//grain/_src/python/dataset", "//grain/_src/python/dataset:base", - "@abseil-py//absl/logging", + "//grain/_src/python/ipc:shared_memory_array", "@pypi//etils:pkg", "@pypi//numpy:pkg", ], @@ -196,10 +175,10 @@ py_test( ":operations", ":options", ":samplers", - ":shared_memory_array", "//grain/_src/core:sharding", "//grain/_src/core:transforms", "//grain/_src/python/dataset", + "//grain/_src/python/ipc:shared_memory_array", "//grain/_src/python/testing:experimental", "@abseil-py//absl/flags", "@abseil-py//absl/testing:absltest", @@ -275,48 +254,3 @@ py_test( "@abseil-py//absl/testing:absltest", ], ) - -py_library( - name = "shared_memory_array", - srcs = ["shared_memory_array.py"], - srcs_version = "PY3", - deps = [ - "//grain/_src/core:tree_lib", - "@pypi//numpy:pkg", - ], -) - -py_test( - name = "shared_memory_array_test", - srcs = ["shared_memory_array_test.py"], - srcs_version = "PY3", - # Skip bazel test due to JAX installation issue. This is tested with pytest instead. - target_compatible_with = select({ - "//conditions:default": ["@platforms//:incompatible"], - }), - deps = [ - ":operations", - ":record", - ":shared_memory_array", - "@abseil-py//absl/testing:absltest", - "@abseil-py//absl/testing:parameterized", - "@pypi//jax:pkg", - "@pypi//numpy:pkg", - ], -) - -py_library( - name = "variable_size_queue", - srcs = ["variable_size_queue.py"], - srcs_version = "PY3", -) - -py_test( - name = "variable_size_queue_test", - srcs = ["variable_size_queue_test.py"], - srcs_version = "PY3", - deps = [ - ":variable_size_queue", - "@abseil-py//absl/testing:absltest", - ], -) diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py index f6caafaee..5b7da3e9e 100644 --- a/grain/_src/python/data_loader.py +++ b/grain/_src/python/data_loader.py @@ -32,14 +32,13 @@ from grain._src.python import operations as ops from grain._src.python import options from grain._src.python import record +from grain._src.python import samplers from grain._src.python.checkpoint import base as checkpoint_base from grain._src.python.dataset import base as dataset_base from grain._src.python.dataset import dataset from grain._src.python.dataset.transformations import batch as batch_ds from grain._src.python.dataset.transformations import flatmap -from grain._src.python.operations import Operation -from grain._src.python.samplers import Sampler -from grain._src.python.shared_memory_array import SharedMemoryArray +from grain._src.python.ipc import shared_memory_array import numpy as np from grain._src.core import monitoring @@ -108,7 +107,9 @@ def copy_if_applied(element: Any) -> Any: ): return element - shared_memory_arr = SharedMemoryArray(element.shape, element.dtype) + shared_memory_arr = shared_memory_array.SharedMemoryArray( + element.shape, element.dtype + ) np.copyto(shared_memory_arr, element, casting="no") return shared_memory_arr.metadata @@ -121,7 +122,7 @@ class _SamplerMapDataset(dataset.MapDataset[record.Record]): def __init__( self, data_source: dataset_base.RandomAccessDataSource, - sampler: Sampler, + sampler: samplers.Sampler, shard_options: sharding.ShardOptions, ): super().__init__(dataset.MapDataset.source(data_source)) @@ -212,7 +213,7 @@ def __init__( parent: dataset.IterDataset[_T], shard_options: sharding.ShardOptions, worker_count: int, - sampler: Sampler, + sampler: samplers.Sampler, data_source: dataset_base.RandomAccessDataSource, ): super().__init__(parent) @@ -239,7 +240,7 @@ def __init__( parent: dataset.DatasetIterator[_T], shard_options: sharding.ShardOptions | None, worker_count: int, - sampler: Sampler, + sampler: samplers.Sampler, data_source: dataset_base.RandomAccessDataSource, ): super().__init__(parent) @@ -350,8 +351,8 @@ def __init__( self, *, data_source: dataset_base.RandomAccessDataSource, - sampler: Sampler, - operations: Sequence[transforms.Transformation | Operation] = (), + sampler: samplers.Sampler, + operations: Sequence[transforms.Transformation | ops.Operation] = (), worker_count: Optional[int] = 0, worker_buffer_size: int = 1, shard_options: sharding.ShardOptions | None = None, @@ -653,7 +654,7 @@ def flat_map(self, element: record.Record) -> Sequence[record.Record]: def _apply_transform_to_dataset( - transform: transforms.Transformation | Operation, + transform: transforms.Transformation | ops.Operation, ds: dataset.IterDataset, ) -> dataset.IterDataset: """Applies the `transform` to the dataset.""" diff --git a/grain/_src/python/data_loader_test.py b/grain/_src/python/data_loader_test.py index f53fef171..7598ebb96 100644 --- a/grain/_src/python/data_loader_test.py +++ b/grain/_src/python/data_loader_test.py @@ -30,12 +30,12 @@ from grain._src.python import data_loader as data_loader_lib from grain._src.python import options from grain._src.python import samplers -from grain._src.python import shared_memory_array # pylint: disable=g-importing-member from grain._src.python.data_sources import ArrayRecordDataSource from grain._src.python.data_sources import RangeDataSource from grain._src.python.data_sources import SharedMemoryDataSource from grain._src.python.dataset.transformations import batch +from grain._src.python.ipc import shared_memory_array from grain._src.python.operations import BatchOperation from grain._src.python.operations import FilterOperation from grain._src.python.operations import MapOperation diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index 8e3e81ad4..8fcd0bbcd 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -53,10 +53,10 @@ py_library( "//grain/_src/core:transforms", "//grain/_src/core:tree_lib", "//grain/_src/python:grain_logging", - "//grain/_src/python:multiprocessing_common", "//grain/_src/python:options", - "//grain/_src/python:shared_memory_array", "//grain/_src/python/checkpoint:base", + "//grain/_src/python/ipc:queue", + "//grain/_src/python/ipc:shared_memory_array", "//grain/proto:execution_summary_py_pb2", "@abseil-py//absl/flags", "@abseil-py//absl/logging", @@ -111,7 +111,7 @@ py_library( srcs_version = "PY3", deps = [ "//grain/_src/core:tree_lib", - "//grain/_src/python:shared_memory_array", + "//grain/_src/python/ipc:shared_memory_array", "//grain/proto:execution_summary_py_pb2", "@abseil-py//absl/logging", "@pypi//numpy:pkg", @@ -124,7 +124,7 @@ py_test( srcs_version = "PY3", deps = [ ":stats_utils", - "//grain/_src/python:shared_memory_array", + "//grain/_src/python/ipc:shared_memory_array", "//grain/proto:execution_summary_py_pb2", "@abseil-py//absl/testing:absltest", "@pypi//numpy:pkg", @@ -145,7 +145,7 @@ py_test( "//grain/_src/core:pytest", "//grain/_src/core:transforms", "//grain/_src/python:options", - "//grain/_src/python:shared_memory_array", + "//grain/_src/python/ipc:shared_memory_array", "//grain/proto:execution_summary_py_pb2", "@abseil-py//absl/flags", "@abseil-py//absl/testing:absltest", diff --git a/grain/_src/python/dataset/stats_test.py b/grain/_src/python/dataset/stats_test.py index 6c17d24b8..abd003c9c 100644 --- a/grain/_src/python/dataset/stats_test.py +++ b/grain/_src/python/dataset/stats_test.py @@ -28,9 +28,9 @@ from grain._src.core import pytest from grain._src.core import transforms from grain._src.python import options -from grain._src.python import shared_memory_array from grain._src.python.dataset import dataset from grain._src.python.dataset import stats +from grain._src.python.ipc import shared_memory_array from grain.proto import execution_summary_pb2 import numpy as np diff --git a/grain/_src/python/dataset/stats_utils.py b/grain/_src/python/dataset/stats_utils.py index a95ce130c..8498c0680 100644 --- a/grain/_src/python/dataset/stats_utils.py +++ b/grain/_src/python/dataset/stats_utils.py @@ -19,7 +19,7 @@ from absl import logging from grain._src.core import tree_lib -from grain._src.python import shared_memory_array +from grain._src.python.ipc import shared_memory_array from grain.proto import execution_summary_pb2 import numpy as np diff --git a/grain/_src/python/dataset/stats_utils_test.py b/grain/_src/python/dataset/stats_utils_test.py index 486ba5ed0..edbf29848 100644 --- a/grain/_src/python/dataset/stats_utils_test.py +++ b/grain/_src/python/dataset/stats_utils_test.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from grain._src.python import shared_memory_array from grain._src.python.dataset import stats_utils +from grain._src.python.ipc import shared_memory_array from grain.proto import execution_summary_pb2 import numpy as np diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py index 156eba673..5a8cda83b 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch.py +++ b/grain/_src/python/dataset/transformations/process_prefetch.py @@ -32,13 +32,13 @@ from grain._src.core.config import config import multiprocessing as mp from grain._src.python import grain_logging -from grain._src.python import multiprocessing_common -from grain._src.python import shared_memory_array from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset import stats as dataset_stats from grain._src.python.dataset.transformations import interleave from grain._src.python.dataset.transformations import prefetch +from grain._src.python.ipc import queue as grain_queue +from grain._src.python.ipc import shared_memory_array T = TypeVar("T") @@ -213,7 +213,7 @@ def _put_dataset_elements_in_buffer( if set_state_request_count.value > 0: set_state_request_count.value -= 1 parent_exhausted = False - if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + if not grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types (_SetStateIsDone(), None, None, None), buffer, should_stop.is_set, @@ -234,7 +234,7 @@ def _put_dataset_elements_in_buffer( try: element = it.__next__() except Exception as e: # pylint: disable=broad-except - multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types (None, None, None, e), buffer, should_stop.is_set ) parent_exhausted = True @@ -244,7 +244,7 @@ def _put_dataset_elements_in_buffer( # __next__ method. if not it._stats._config.is_prefetch: # pylint: disable=protected-access it._stats.record_bytes_produced(element) # pylint: disable=protected-access - if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + if not grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types (element, it.get_state(), next_index, None), buffer, should_stop.is_set, @@ -258,7 +258,7 @@ def _put_dataset_elements_in_buffer( except Exception as e: # pylint: disable=broad-except _clear_queue_and_maybe_unlink_shm(buffer) _clear_queue_and_maybe_unlink_shm(set_state_queue) - multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types + grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types (None, None, None, e), buffer, should_stop.is_set ) return diff --git a/grain/_src/python/ipc/BUILD b/grain/_src/python/ipc/BUILD new file mode 100644 index 000000000..ab654319d --- /dev/null +++ b/grain/_src/python/ipc/BUILD @@ -0,0 +1,68 @@ +package(default_visibility = ["//grain:__subpackages__"]) + +licenses(["notice"]) + +py_library( + name = "queue", + srcs = [ + "queue.py", + ], + srcs_version = "PY3", +) + +py_test( + name = "queue_test", + srcs = [ + "queue_test.py", + ], + srcs_version = "PY3", + deps = [ + ":queue", + "@abseil-py//absl/testing:absltest", + ], +) + +py_library( + name = "shared_memory_array", + srcs = ["shared_memory_array.py"], + srcs_version = "PY3", + deps = [ + "//grain/_src/core:tree_lib", + "@pypi//numpy:pkg", + ], +) + +py_test( + name = "shared_memory_array_test", + srcs = ["shared_memory_array_test.py"], + srcs_version = "PY3", + # Skip bazel test due to JAX installation issue. This is tested with pytest instead. + target_compatible_with = select({ + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":shared_memory_array", + "//grain/_src/python:operations", + "//grain/_src/python:record", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@pypi//jax:pkg", + "@pypi//numpy:pkg", + ], +) + +py_library( + name = "variable_size_queue", + srcs = ["variable_size_queue.py"], + srcs_version = "PY3", +) + +py_test( + name = "variable_size_queue_test", + srcs = ["variable_size_queue_test.py"], + srcs_version = "PY3", + deps = [ + ":variable_size_queue", + "@abseil-py//absl/testing:absltest", + ], +) diff --git a/grain/_src/python/multiprocessing_common.py b/grain/_src/python/ipc/queue.py similarity index 98% rename from grain/_src/python/multiprocessing_common.py rename to grain/_src/python/ipc/queue.py index a40f66b58..07cb4d101 100644 --- a/grain/_src/python/multiprocessing_common.py +++ b/grain/_src/python/ipc/queue.py @@ -17,7 +17,7 @@ import multiprocessing from multiprocessing import pool import queue -from typing import TypeVar, Union, Callable +from typing import Callable, TypeVar, Union T = TypeVar('T') diff --git a/grain/_src/python/multiprocessing_common_test.py b/grain/_src/python/ipc/queue_test.py similarity index 98% rename from grain/_src/python/multiprocessing_common_test.py rename to grain/_src/python/ipc/queue_test.py index 4293fd7b5..e5ea48a93 100644 --- a/grain/_src/python/multiprocessing_common_test.py +++ b/grain/_src/python/ipc/queue_test.py @@ -18,7 +18,7 @@ import queue from absl.testing import absltest -from grain._src.python import multiprocessing_common +from grain._src.python.ipc import queue as multiprocessing_common class MultiProcessingCommonTest(absltest.TestCase): diff --git a/grain/_src/python/shared_memory_array.py b/grain/_src/python/ipc/shared_memory_array.py similarity index 99% rename from grain/_src/python/shared_memory_array.py rename to grain/_src/python/ipc/shared_memory_array.py index a7ef243e2..8dab94068 100644 --- a/grain/_src/python/shared_memory_array.py +++ b/grain/_src/python/ipc/shared_memory_array.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Shared memory array.""" + from __future__ import annotations import dataclasses diff --git a/grain/_src/python/shared_memory_array_test.py b/grain/_src/python/ipc/shared_memory_array_test.py similarity index 93% rename from grain/_src/python/shared_memory_array_test.py rename to grain/_src/python/ipc/shared_memory_array_test.py index 65994882f..a1acb971e 100644 --- a/grain/_src/python/shared_memory_array_test.py +++ b/grain/_src/python/ipc/shared_memory_array_test.py @@ -22,16 +22,19 @@ from absl.testing import absltest from absl.testing import parameterized import multiprocessing +from grain._src.python import operations from grain._src.python import record -from grain._src.python.operations import BatchOperation -from grain._src.python.shared_memory_array import copy_to_shm -from grain._src.python.shared_memory_array import open_from_shm -from grain._src.python.shared_memory_array import SharedMemoryArray -from grain._src.python.shared_memory_array import SharedMemoryArrayMetadata -from grain._src.python.shared_memory_array import unlink_shm +from grain._src.python.ipc import shared_memory_array import jax import numpy as np +SharedMemoryArray = shared_memory_array.SharedMemoryArray +SharedMemoryArrayMetadata = shared_memory_array.SharedMemoryArrayMetadata +copy_to_shm = shared_memory_array.copy_to_shm +open_from_shm = shared_memory_array.open_from_shm +unlink_shm = shared_memory_array.unlink_shm +BatchOperation = operations.BatchOperation + def _create_and_delete_shm() -> SharedMemoryArrayMetadata: data = np.array([[1, 2], [3, 4]], dtype=np.int32) @@ -64,15 +67,13 @@ def test_batch_dict_of_data_with_shared_memory(self, mode): else: data = list(map(jax.numpy.array, data)) - input_data = iter( - [ - record.Record( - record.RecordMetadata(index=idx, record_key=idx + 1), - {"a": item}, - ) - for idx, item in enumerate(data) - ] - ) + input_data = iter([ + record.Record( + record.RecordMetadata(index=idx, record_key=idx + 1), + {"a": item}, + ) + for idx, item in enumerate(data) + ]) batch_operation = BatchOperation(batch_size=2) batch_operation._enable_shared_memory() diff --git a/grain/_src/python/variable_size_queue.py b/grain/_src/python/ipc/variable_size_queue.py similarity index 100% rename from grain/_src/python/variable_size_queue.py rename to grain/_src/python/ipc/variable_size_queue.py diff --git a/grain/_src/python/variable_size_queue_test.py b/grain/_src/python/ipc/variable_size_queue_test.py similarity index 99% rename from grain/_src/python/variable_size_queue_test.py rename to grain/_src/python/ipc/variable_size_queue_test.py index e6034f8ce..ba1063ce9 100644 --- a/grain/_src/python/variable_size_queue_test.py +++ b/grain/_src/python/ipc/variable_size_queue_test.py @@ -20,7 +20,7 @@ from absl.testing import absltest import multiprocessing as mp -from grain._src.python import variable_size_queue +from grain._src.python.ipc import variable_size_queue def _consumer_function_for_test(q, result): diff --git a/grain/_src/python/operations.py b/grain/_src/python/operations.py index 1103b788a..9a86f4b0e 100644 --- a/grain/_src/python/operations.py +++ b/grain/_src/python/operations.py @@ -23,7 +23,7 @@ from absl import logging from grain._src.core import tree_lib from grain._src.python import record -from grain._src.python.shared_memory_array import SharedMemoryArray +from grain._src.python.ipc import shared_memory_array import numpy as np _IN = TypeVar("_IN") @@ -202,7 +202,9 @@ def stacking_function(*args): shape, dtype = (len(args),) + first_arg.shape, first_arg.dtype if not self._use_shared_memory or dtype.hasobject: return np.stack(args) - return np.stack(args, out=SharedMemoryArray(shape, dtype=dtype)).metadata + return np.stack( + args, out=shared_memory_array.SharedMemoryArray(shape, dtype=dtype) + ).metadata return tree_lib.map_structure( stacking_function, input_records[0], *input_records[1:] diff --git a/grain/multiprocessing.py b/grain/multiprocessing.py index 72f2445e8..ce933465c 100644 --- a/grain/multiprocessing.py +++ b/grain/multiprocessing.py @@ -17,5 +17,5 @@ # pylint: disable=g-importing-member # pylint: disable=unused-import +from grain._src.python.ipc.shared_memory_array import SharedMemoryArray from grain._src.python.options import MultiprocessingOptions -from grain._src.python.shared_memory_array import SharedMemoryArray diff --git a/grain/python/__init__.py b/grain/python/__init__.py index 0ac9d7d1f..ed97769ec 100644 --- a/grain/python/__init__.py +++ b/grain/python/__init__.py @@ -64,7 +64,7 @@ IterDataset, DatasetIterator, ) - +from grain._src.python.ipc.shared_memory_array import SharedMemoryArray from grain._src.python.load import load from grain._src.python.operations import ( BatchOperation, @@ -80,7 +80,6 @@ Sampler, SequentialSampler, ) -from grain._src.python.shared_memory_array import SharedMemoryArray from grain.python import experimental # These are imported only if Orbax is present.