diff --git a/grain/_src/python/dataset/transformations/batch.py b/grain/_src/python/dataset/transformations/batch.py
index 2ba2c690..1fdc4281 100644
--- a/grain/_src/python/dataset/transformations/batch.py
+++ b/grain/_src/python/dataset/transformations/batch.py
@@ -86,9 +86,8 @@ def __call__(self, values: Sequence[T]) -> T:
     def _batch_fn(*xs: Sequence[T]) -> T:
       # If the thread pool is not available or the elements are not NumPy
       # arrays, fall back to the standard serial `np.stack` operation.
-      if (self._parallel_batch_executor is None) or not isinstance(
-          xs[0], np.ndarray
-      ):
+      all_ndarray = all(isinstance(x, np.ndarray) for x in xs)
+      if (self._parallel_batch_executor is None) or not all_ndarray:
         return np.stack(xs)
       xs = cast(Sequence[np.ndarray], xs)
       # Fall back to the standard serial `np.stack` operation if the size of
diff --git a/grain/_src/python/dataset/transformations/batch_test.py b/grain/_src/python/dataset/transformations/batch_test.py
index 02f9d66b..9a48f5bc 100644
--- a/grain/_src/python/dataset/transformations/batch_test.py
+++ b/grain/_src/python/dataset/transformations/batch_test.py
@@ -69,6 +69,13 @@ def test_batch_single_value_parallel_batch_enabled_success(self):
     batched_values = make_batch_parallel(values)
     self.assertEqual(batched_values.shape, (1, 3))
 
+  def test_batch_non_numpy_values(self):
+    values = [np.asarray([1, 2, 3]), [4, 5, 6]]
+    make_batch_parallel = batch._MakeBatchParallel()
+    batched_values = make_batch_parallel(values)
+    self.assertIsInstance(batched_values, np.ndarray)
+    self.assertEqual(batched_values.shape, (2, 3))
+
   def test_batch_two_values_success(self):
     values = [np.asarray([1, 2, 3]), np.asarray([4, 5, 6])]
     batched_values = batch.make_batch(values)
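
Note (illustrative, not part of the patch): the old guard only inspected `xs[0]`, so a batch whose first element was an ndarray but whose later elements were plain lists would have been routed to the parallel path. A minimal sketch of the patched condition, using a hypothetical `uses_parallel_path` helper that is not Grain's API:

    import numpy as np

    def uses_parallel_path(*xs) -> bool:
      # Mirrors the patched guard: the parallel batcher is only used when
      # every element is a NumPy array, not just the first one.
      return all(isinstance(x, np.ndarray) for x in xs)

    mixed = (np.asarray([1, 2, 3]), [4, 5, 6])
    print(uses_parallel_path(*mixed))  # False -> serial np.stack fallback
    print(np.stack(mixed).shape)       # (2, 3); np.stack still handles the mix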