From b12f0f4bdc209b023178560e54a917c7047a18b8 Mon Sep 17 00:00:00 2001
From: RJ Skerry-Ryan
Date: Thu, 19 Feb 2026 21:20:17 -0800
Subject: [PATCH] Fall back to np.stack if *any* array is not an ndarray, not
 just the first.

PiperOrigin-RevId: 872706726
---
 grain/_src/python/dataset/transformations/batch.py      | 5 ++---
 grain/_src/python/dataset/transformations/batch_test.py | 7 +++++++
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/grain/_src/python/dataset/transformations/batch.py b/grain/_src/python/dataset/transformations/batch.py
index 2ba2c6903..1fdc42815 100644
--- a/grain/_src/python/dataset/transformations/batch.py
+++ b/grain/_src/python/dataset/transformations/batch.py
@@ -86,9 +86,8 @@ def __call__(self, values: Sequence[T]) -> T:
     def _batch_fn(*xs: Sequence[T]) -> T:
       # If the thread pool is not available or the elements are not NumPy
       # arrays, fall back to the standard serial `np.stack` operation.
-      if (self._parallel_batch_executor is None) or not isinstance(
-          xs[0], np.ndarray
-      ):
+      all_ndarray = all(isinstance(x, np.ndarray) for x in xs)
+      if (self._parallel_batch_executor is None) or not all_ndarray:
         return np.stack(xs)
       xs = cast(Sequence[np.ndarray], xs)
       # Fall back to the standard serial `np.stack` operation if the size of
diff --git a/grain/_src/python/dataset/transformations/batch_test.py b/grain/_src/python/dataset/transformations/batch_test.py
index 02f9d66b0..9a48f5bc3 100644
--- a/grain/_src/python/dataset/transformations/batch_test.py
+++ b/grain/_src/python/dataset/transformations/batch_test.py
@@ -69,6 +69,13 @@ def test_batch_single_value_parallel_batch_enabled_success(self):
     batched_values = make_batch_parallel(values)
     self.assertEqual(batched_values.shape, (1, 3))
 
+  def test_batch_non_numpy_values(self):
+    values = [np.asarray([1, 2, 3]), [4, 5, 6]]
+    make_batch_parallel = batch._MakeBatchParallel()
+    batched_values = make_batch_parallel(values)
+    self.assertIsInstance(batched_values, np.ndarray)
+    self.assertEqual(batched_values.shape, (2, 3))
+
   def test_batch_two_values_success(self):
     values = [np.asarray([1, 2, 3]), np.asarray([4, 5, 6])]
     batched_values = batch.make_batch(values)