diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index 3fe798268..20421ab37 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -243,7 +243,9 @@ py_library( name = "options", srcs = ["options.py"], srcs_version = "PY3", - deps = ["@abseil-py//absl/logging"], + deps = [ + "@abseil-py//absl/logging", + ], ) py_test( diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 7ceb2a8b4..b1ff477ed 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -143,10 +143,15 @@ def __init__( self._next_buffered_index = 0 self._buffer = collections.deque() self._lock = threading.Lock() - self._prefetch_buffer_size = ( - read_options.prefetch_buffer_size if read_options.num_threads > 0 else 0 - ) + + assert isinstance(read_options.num_threads, int) + assert isinstance(read_options.prefetch_buffer_size, int) self._num_threads = read_options.num_threads + self._prefetch_buffer_size = read_options.prefetch_buffer_size + + if self._num_threads == 0: + self._prefetch_buffer_size = 0 + self._allow_nones = allow_nones if self._prefetch_buffer_size > 0: self._executor = futures.ThreadPoolExecutor( @@ -254,7 +259,7 @@ def __str__(self) -> str: f" allow_nones={self._allow_nones})" ) - def set_prefetch_buffer_size(self, buffer_size: int): + def _set_prefetch_buffer_size(self, buffer_size: int): self._prefetch_buffer_size = buffer_size # The executor is created in the constructor only if the prefetch buffer # size is greater than 0. If the user changes the prefetch buffer size, we @@ -272,7 +277,7 @@ def set_prefetch_buffer_size(self, buffer_size: int): self._executor.shutdown() delattr(self, "_executor") - def set_num_threads(self, num_threads: int) -> None: + def _set_num_threads(self, num_threads: int) -> None: self._num_threads = num_threads old_executor = None # Accounts for the case where the executor does not exit. This can diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 0855b0b4b..f3de16815 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -162,7 +162,7 @@ def test_set_prefetch_buffer_size_0_to_positive(self): self.assertEqual(next(ds_iter), 0) # Setting prefetch_buffer_size to 2. - ds_iter.set_prefetch_buffer_size(2) + ds_iter._set_prefetch_buffer_size(2) self.assertEqual(ds_iter._prefetch_buffer_size, 2) self.assertEqual(next(ds_iter), 1) self.assertTrue(hasattr(ds_iter, '_executor')) @@ -183,7 +183,7 @@ def test_set_prefetch_buffer_size_positive_to_0(self): self.assertLen(ds_iter._buffer, 2) # Setting prefetch_buffer_size to 0. - ds_iter.set_prefetch_buffer_size(0) + ds_iter._set_prefetch_buffer_size(0) self.assertEqual(ds_iter._prefetch_buffer_size, 0) # Should consume buffer first. self.assertEqual(next(ds_iter), 1) @@ -207,7 +207,7 @@ def test_set_prefetch_buffer_size_increase(self): self.assertLen(ds_iter._buffer, 1) # Setting prefetch_buffer_size to 2. - ds_iter.set_prefetch_buffer_size(2) + ds_iter._set_prefetch_buffer_size(2) self.assertEqual(ds_iter._prefetch_buffer_size, 2) self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 2) @@ -227,7 +227,7 @@ def test_set_prefetch_buffer_size_decrease(self): self.assertLen(ds_iter._buffer, 2) # Setting prefetch_buffer_size to 1. - ds_iter.set_prefetch_buffer_size(1) + ds_iter._set_prefetch_buffer_size(1) self.assertEqual(ds_iter._prefetch_buffer_size, 1) self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 1) @@ -328,7 +328,7 @@ def test_set_num_threads_decrease_threads(self): self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) # Decrease threads - ds_iter.set_num_threads(5) + ds_iter._set_num_threads(5) self.assertEqual(ds_iter._num_threads, 5) self.assertEqual(ds_iter._executor._max_workers, 5) self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20))) @@ -345,7 +345,7 @@ def test_set_num_threads_increase_threads(self): self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) # Increase threads - ds_iter.set_num_threads(10) + ds_iter._set_num_threads(10) self.assertEqual(ds_iter._num_threads, 10) self.assertEqual(ds_iter._executor._max_workers, 10) self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20))) @@ -360,7 +360,7 @@ def test_set_num_threads_decrease_to_zero(self): ) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) # Decrease threads to 0 - ds_iter.set_num_threads(0) + ds_iter._set_num_threads(0) self.assertEqual(ds_iter._num_threads, 0) self.assertFalse(hasattr(ds_iter, '_executor')) self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20))) @@ -370,13 +370,13 @@ def test_set_num_threads_increase_from_zero(self): self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) - ds_iter.set_num_threads(0) + ds_iter._set_num_threads(0) self.assertEqual(ds_iter._num_threads, 0) self.assertFalse(hasattr(ds_iter, '_executor')) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5, 10))) # Increase threads from 0 - ds_iter.set_num_threads(5) + ds_iter._set_num_threads(5) self.assertEqual(ds_iter._num_threads, 5) self.assertEqual(ds_iter._executor._max_workers, 5) self.assertEqual([next(ds_iter) for _ in range(10)], list(range(10, 20))) diff --git a/grain/_src/python/options.py b/grain/_src/python/options.py index 293e3441a..f855aa932 100644 --- a/grain/_src/python/options.py +++ b/grain/_src/python/options.py @@ -12,9 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Dataclasses for holdings options.""" +from __future__ import annotations + import dataclasses from absl import logging +class AutotuneParameter: + + def __init__(self, *args, **kwargs): + raise NotImplementedError @dataclasses.dataclass(slots=True) @@ -41,25 +47,29 @@ class ReadOptions: # benchmarks reading from remote hard drives. # These values should work well for datasets with elements between 1 and # 10 KiB on disk. - num_threads: int = 16 - prefetch_buffer_size: int = 500 + num_threads: int | AutotuneParameter = 16 + prefetch_buffer_size: int | AutotuneParameter = 500 def __post_init__(self): - if self.num_threads < 0: + if isinstance(self.num_threads, int) and self.num_threads < 0: raise ValueError( f'num_threads must be non-negative, got {self.num_threads}' ) - if self.prefetch_buffer_size < 0: + + if ( + isinstance(self.prefetch_buffer_size, int) + and self.prefetch_buffer_size < 0 + ): raise ValueError( 'prefetch_buffer_size must be non-negative, got' f' {self.prefetch_buffer_size}' ) + # Avoid warning when setting prefetch_buffer_size=0, since this is commonly # used to disable prefetching. - if ( - self.prefetch_buffer_size < self.num_threads - and self.prefetch_buffer_size != 0 - ): + buffer_size = int(self.prefetch_buffer_size) + num_threads = int(self.num_threads) + if buffer_size < num_threads and buffer_size != 0: logging.warning( 'prefetch_buffer_size=%s is smaller than num_threads=%s. This will' ' limit the number of threads that can actually be used in parallel'