Patch summary (commentary only; this text precedes the first "diff --git"
header and is not part of any hunk).

filter.py:
  * Imports Sequence from typing.
  * Adds _filter_element(element): returns the element when it is not None
    and self._filter_fn(element) is truthy, otherwise returns None.
  * __getitem__ now delegates its keep/drop decision to _filter_element
    (still inside the record_self_time() context).
  * Adds _getitems(indices): fetches elements through
    self._parent._getitems(indices) and applies _filter_element to each one,
    recording self time with num_elements=len(indices).

filter_test.py:
  * Adds test_filter_data_with_get_items, which calls _getitems on a
    FilterMapDataset built with FilterEvenElementsOnly for indices
    [0, 1, 4, 5, 8, 9] and expects [None, 1, None, 5, None, 9].

diff --git a/grain/_src/python/dataset/transformations/filter.py b/grain/_src/python/dataset/transformations/filter.py
index 40b0ef1f..2c320d45 100644
--- a/grain/_src/python/dataset/transformations/filter.py
+++ b/grain/_src/python/dataset/transformations/filter.py
@@ -14,7 +14,7 @@
 """Filter transformation for LazyDataset."""
 
 import functools
-from typing import Any, Callable, TypeVar, Union
+from typing import Any, Callable, Sequence, TypeVar, Union
 
 from absl import logging
 from grain._src.core import transforms
@@ -53,14 +53,22 @@ def _transform_name(self):
   def __len__(self) -> int:
     return len(self._parent)
 
+  def _filter_element(self, element: T | None) -> T | None:
+    if element is not None and self._filter_fn(element):
+      return element
+    return None
+
   def __getitem__(self, index):
     if isinstance(index, slice):
       return self.slice(index)
     element = self._parent[index]
     with self._stats.record_self_time():
-      if element is not None and self._filter_fn(element):
-        return element
-      return None
+      return self._filter_element(element)
+
+  def _getitems(self, indices: Sequence[int]) -> Sequence[T | None]:
+    elements = self._parent._getitems(indices)  # pylint: disable=protected-access
+    with self._stats.record_self_time(num_elements=len(indices)):
+      return [self._filter_element(element) for element in elements]
 
   @property
   def _element_spec(self) -> Any:
diff --git a/grain/_src/python/dataset/transformations/filter_test.py b/grain/_src/python/dataset/transformations/filter_test.py
index 72d8aa0f..744422bf 100644
--- a/grain/_src/python/dataset/transformations/filter_test.py
+++ b/grain/_src/python/dataset/transformations/filter_test.py
@@ -95,6 +95,23 @@ def test_filter_even_elements_only(self):
     ]
     self.assertEqual(expected_data, actual_data)
 
+  def test_filter_data_with_get_items(self):
+    filter_even_elts_ds = filter_dataset.FilterMapDataset(
+        self.range_ds, FilterEvenElementsOnly()
+    )
+    indices = [0, 1, 4, 5, 8, 9]
+    # Expected results for indices [0, 1, 4, 5, 8, 9] with
+    # FilterEvenElementsOnly:
+    # 0 -> 0%2=0 -> False -> None
+    # 1 -> 1%2=1 -> True -> 1
+    # 4 -> 4%2=0 -> False -> None
+    # 5 -> 5%2=1 -> True -> 5
+    # 8 -> 8%2=0 -> False -> None
+    # 9 -> 9%2=1 -> True -> 9
+    expected_data = [None, 1, None, 5, None, 9]
+    actual_data = filter_even_elts_ds._getitems(indices)
+    self.assertEqual(expected_data, actual_data)
+
   def test_element_spec(self):
     ds = filter_dataset.FilterMapDataset(
         self.range_ds, FilterEvenElementsOnly()