Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion grain/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ py_library(
"//grain/_src/python:options",
"//grain/_src/python:record",
"//grain/_src/python:samplers",
"//grain/_src/python:shared_memory_array",
"//grain/_src/python/ipc:shared_memory_array",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
"//grain/_src/python/dataset:elastic_iterator",
Expand Down
72 changes: 3 additions & 69 deletions grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,6 @@ py_test(
],
)

py_library(
name = "multiprocessing_common",
srcs = [
"multiprocessing_common.py",
],
srcs_version = "PY3",
)

py_test(
name = "multiprocessing_common_test",
srcs = [
"multiprocessing_common_test.py",
],
srcs_version = "PY3",
deps = [
":multiprocessing_common",
"@abseil-py//absl/testing:absltest",
],
)

py_library(
name = "operations",
srcs = [
Expand All @@ -96,8 +76,8 @@ py_library(
srcs_version = "PY3",
deps = [
":record",
":shared_memory_array",
"//grain/_src/core:tree_lib",
"//grain/_src/python/ipc:shared_memory_array",
"@abseil-py//absl/logging",
"@pypi//numpy:pkg",
],
Expand Down Expand Up @@ -159,15 +139,14 @@ py_library(
":options",
":record",
":samplers",
":shared_memory_array",
"//grain/_src/core:monitoring",
"//grain/_src/core:sharding",
"//grain/_src/core:transforms",
"//grain/_src/core:tree_lib",
"//grain/_src/python/checkpoint:base",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
"@abseil-py//absl/logging",
"//grain/_src/python/ipc:shared_memory_array",
"@pypi//etils:pkg",
"@pypi//numpy:pkg",
],
Expand Down Expand Up @@ -196,10 +175,10 @@ py_test(
":operations",
":options",
":samplers",
":shared_memory_array",
"//grain/_src/core:sharding",
"//grain/_src/core:transforms",
"//grain/_src/python/dataset",
"//grain/_src/python/ipc:shared_memory_array",
"//grain/_src/python/testing:experimental",
"@abseil-py//absl/flags",
"@abseil-py//absl/testing:absltest",
Expand Down Expand Up @@ -275,48 +254,3 @@ py_test(
"@abseil-py//absl/testing:absltest",
],
)

py_library(
name = "shared_memory_array",
srcs = ["shared_memory_array.py"],
srcs_version = "PY3",
deps = [
"//grain/_src/core:tree_lib",
"@pypi//numpy:pkg",
],
)

py_test(
name = "shared_memory_array_test",
srcs = ["shared_memory_array_test.py"],
srcs_version = "PY3",
# Skip bazel test due to JAX installation issue. This is tested with pytest instead.
target_compatible_with = select({
"//conditions:default": ["@platforms//:incompatible"],
}),
deps = [
":operations",
":record",
":shared_memory_array",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@pypi//jax:pkg",
"@pypi//numpy:pkg",
],
)

py_library(
name = "variable_size_queue",
srcs = ["variable_size_queue.py"],
srcs_version = "PY3",
)

py_test(
name = "variable_size_queue_test",
srcs = ["variable_size_queue_test.py"],
srcs_version = "PY3",
deps = [
":variable_size_queue",
"@abseil-py//absl/testing:absltest",
],
)
21 changes: 11 additions & 10 deletions grain/_src/python/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@
from grain._src.python import operations as ops
from grain._src.python import options
from grain._src.python import record
from grain._src.python import samplers
from grain._src.python.checkpoint import base as checkpoint_base
from grain._src.python.dataset import base as dataset_base
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import batch as batch_ds
from grain._src.python.dataset.transformations import flatmap
from grain._src.python.operations import Operation
from grain._src.python.samplers import Sampler
from grain._src.python.shared_memory_array import SharedMemoryArray
from grain._src.python.ipc import shared_memory_array
import numpy as np

from grain._src.core import monitoring
Expand Down Expand Up @@ -108,7 +107,9 @@ def copy_if_applied(element: Any) -> Any:
):
return element

shared_memory_arr = SharedMemoryArray(element.shape, element.dtype)
shared_memory_arr = shared_memory_array.SharedMemoryArray(
element.shape, element.dtype
)
np.copyto(shared_memory_arr, element, casting="no")
return shared_memory_arr.metadata

Expand All @@ -121,7 +122,7 @@ class _SamplerMapDataset(dataset.MapDataset[record.Record]):
def __init__(
self,
data_source: dataset_base.RandomAccessDataSource,
sampler: Sampler,
sampler: samplers.Sampler,
shard_options: sharding.ShardOptions,
):
super().__init__(dataset.MapDataset.source(data_source))
Expand Down Expand Up @@ -212,7 +213,7 @@ def __init__(
parent: dataset.IterDataset[_T],
shard_options: sharding.ShardOptions,
worker_count: int,
sampler: Sampler,
sampler: samplers.Sampler,
data_source: dataset_base.RandomAccessDataSource,
):
super().__init__(parent)
Expand All @@ -239,7 +240,7 @@ def __init__(
parent: dataset.DatasetIterator[_T],
shard_options: sharding.ShardOptions | None,
worker_count: int,
sampler: Sampler,
sampler: samplers.Sampler,
data_source: dataset_base.RandomAccessDataSource,
):
super().__init__(parent)
Expand Down Expand Up @@ -350,8 +351,8 @@ def __init__(
self,
*,
data_source: dataset_base.RandomAccessDataSource,
sampler: Sampler,
operations: Sequence[transforms.Transformation | Operation] = (),
sampler: samplers.Sampler,
operations: Sequence[transforms.Transformation | ops.Operation] = (),
worker_count: Optional[int] = 0,
worker_buffer_size: int = 1,
shard_options: sharding.ShardOptions | None = None,
Expand Down Expand Up @@ -653,7 +654,7 @@ def flat_map(self, element: record.Record) -> Sequence[record.Record]:


def _apply_transform_to_dataset(
transform: transforms.Transformation | Operation,
transform: transforms.Transformation | ops.Operation,
ds: dataset.IterDataset,
) -> dataset.IterDataset:
"""Applies the `transform` to the dataset."""
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
from grain._src.python import data_loader as data_loader_lib
from grain._src.python import options
from grain._src.python import samplers
from grain._src.python import shared_memory_array
# pylint: disable=g-importing-member
from grain._src.python.data_sources import ArrayRecordDataSource
from grain._src.python.data_sources import RangeDataSource
from grain._src.python.data_sources import SharedMemoryDataSource
from grain._src.python.dataset.transformations import batch
from grain._src.python.ipc import shared_memory_array
from grain._src.python.operations import BatchOperation
from grain._src.python.operations import FilterOperation
from grain._src.python.operations import MapOperation
Expand Down
10 changes: 5 additions & 5 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ py_library(
"//grain/_src/core:transforms",
"//grain/_src/core:tree_lib",
"//grain/_src/python:grain_logging",
"//grain/_src/python:multiprocessing_common",
"//grain/_src/python:options",
"//grain/_src/python:shared_memory_array",
"//grain/_src/python/checkpoint:base",
"//grain/_src/python/ipc:queue",
"//grain/_src/python/ipc:shared_memory_array",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/flags",
"@abseil-py//absl/logging",
Expand Down Expand Up @@ -111,7 +111,7 @@ py_library(
srcs_version = "PY3",
deps = [
"//grain/_src/core:tree_lib",
"//grain/_src/python:shared_memory_array",
"//grain/_src/python/ipc:shared_memory_array",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/logging",
"@pypi//numpy:pkg",
Expand All @@ -124,7 +124,7 @@ py_test(
srcs_version = "PY3",
deps = [
":stats_utils",
"//grain/_src/python:shared_memory_array",
"//grain/_src/python/ipc:shared_memory_array",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/testing:absltest",
"@pypi//numpy:pkg",
Expand All @@ -145,7 +145,7 @@ py_test(
"//grain/_src/core:pytest",
"//grain/_src/core:transforms",
"//grain/_src/python:options",
"//grain/_src/python:shared_memory_array",
"//grain/_src/python/ipc:shared_memory_array",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/flags",
"@abseil-py//absl/testing:absltest",
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from grain._src.core import pytest
from grain._src.core import transforms
from grain._src.python import options
from grain._src.python import shared_memory_array
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
from grain._src.python.ipc import shared_memory_array
from grain.proto import execution_summary_pb2
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from absl import logging
from grain._src.core import tree_lib
from grain._src.python import shared_memory_array
from grain._src.python.ipc import shared_memory_array
from grain.proto import execution_summary_pb2
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/stats_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from grain._src.python import shared_memory_array
from grain._src.python.dataset import stats_utils
from grain._src.python.ipc import shared_memory_array
from grain.proto import execution_summary_pb2
import numpy as np

Expand Down
12 changes: 6 additions & 6 deletions grain/_src/python/dataset/transformations/process_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
from grain._src.core.config import config
import multiprocessing as mp
from grain._src.python import grain_logging
from grain._src.python import multiprocessing_common
from grain._src.python import shared_memory_array
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
from grain._src.python.dataset.transformations import interleave
from grain._src.python.dataset.transformations import prefetch
from grain._src.python.ipc import queue as grain_queue
from grain._src.python.ipc import shared_memory_array


T = TypeVar("T")
Expand Down Expand Up @@ -213,7 +213,7 @@ def _put_dataset_elements_in_buffer(
if set_state_request_count.value > 0:
set_state_request_count.value -= 1
parent_exhausted = False
if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
if not grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types
(_SetStateIsDone(), None, None, None),
buffer,
should_stop.is_set,
Expand All @@ -234,7 +234,7 @@ def _put_dataset_elements_in_buffer(
try:
element = it.__next__()
except Exception as e: # pylint: disable=broad-except
multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types
(None, None, None, e), buffer, should_stop.is_set
)
parent_exhausted = True
Expand All @@ -244,7 +244,7 @@ def _put_dataset_elements_in_buffer(
# __next__ method.
if not it._stats._config.is_prefetch: # pylint: disable=protected-access
it._stats.record_bytes_produced(element) # pylint: disable=protected-access
if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
if not grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types
(element, it.get_state(), next_index, None),
buffer,
should_stop.is_set,
Expand All @@ -258,7 +258,7 @@ def _put_dataset_elements_in_buffer(
except Exception as e: # pylint: disable=broad-except
_clear_queue_and_maybe_unlink_shm(buffer)
_clear_queue_and_maybe_unlink_shm(set_state_queue)
multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types
(None, None, None, e), buffer, should_stop.is_set
)
return
Expand Down
Loading