diff --git a/init2winit/dataset_lib/criteo_terabyte_dataset.py b/init2winit/dataset_lib/criteo_terabyte_dataset.py index 65ce0e51..51513917 100644 --- a/init2winit/dataset_lib/criteo_terabyte_dataset.py +++ b/init2winit/dataset_lib/criteo_terabyte_dataset.py @@ -37,14 +37,16 @@ # Change to the path to raw dataset files. RAW_CRITEO1TB_FILE_PATH = '' -CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - input_shape=(13 + 26,), - train_size=4_195_197_692, - # We assume the tie breaking example went to the validation set, because - # the test set in the mlperf version has 89_137_318 examples. - valid_size=89_137_319, - test_size=89_137_318, -)) +CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + input_shape=(13 + 26,), + train_size=4_195_197_692, + # We assume the tie breaking example went to the validation set, because + # the test set in the mlperf version has 89_137_318 examples. + valid_size=89_137_319, + test_size=89_137_318, + ) +) CRITEO1TB_METADATA = { 'apply_one_hot_in_loss': True, } @@ -155,17 +157,26 @@ def criteo_tsv_reader( return ds -def _convert_to_numpy_iterator_fn( - num_batches, per_host_eval_batch_size, tf_dataset, split_size): - """Make eval iterator. This function is called at the start of each eval.""" - # Some hosts could see different numbers of examples, which in some cases - # could lead to some hosts not having enough examples to make the same number - # of batches. This makes pmap hang because it is waiting for a batch that will - # never come from the host with less data. - # - # We assume that all files have the same number of examples in them, and while - # this may not always be true, it dramatically simplifies/speeds up the logic - # because the alternative is to run the same iterator (without data file +def _eval_numpy_iterator( + num_batches, per_host_eval_batch_size, tf_dataset, split_size +): + """Yields padded numpy batches from a tf.data eval split. 
+ + Caps the number of batches to the split size divided across hosts. If the + source runs out before `num_batches`, yields zero-filled batches so every + host sees the same count. + + Args: + num_batches (int): number of batches to process. + per_host_eval_batch_size (int): batch size per host. + tf_dataset (tfds.Dataset): source tensorflow dataset. + split_size (int): total number of examples in the eval split. + + Yields: + Padded numpy batches. + """ + # We assume all files have the same number of examples. This simplifies the + # logic because the alternative is to run the same iterator (without data file # sharding) on each host, skipping (num_hosts - 1) / num_hosts batches. # # Any final partial batches are padded to be the full batch size, so we can @@ -218,6 +229,8 @@ def get_criteo1tb(shuffle_rng, num_batches_to_prefetch = (hps.num_tf_data_prefetches if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE) + num_device_prefetches = hps.get('num_device_prefetches', 0) + train_dataset = criteo_tsv_reader( split='train', shuffle_rng=shuffle_rng, @@ -225,7 +238,16 @@ def get_criteo1tb(shuffle_rng, num_dense_features=hps.num_dense_features, batch_size=per_host_batch_size, num_batches_to_prefetch=num_batches_to_prefetch) - train_iterator_fn = lambda: tfds.as_numpy(train_dataset) + data_utils.log_rss('train dataset created') + if num_device_prefetches > 0: + train_iterator_fn = lambda: data_utils.prefetch_iterator( + tfds.as_numpy(train_dataset), num_device_prefetches + ) + data_utils.log_rss( + f'using prefetching with {num_device_prefetches} in the train dataset' + ) + else: + train_iterator_fn = lambda: tfds.as_numpy(train_dataset) eval_train_dataset = criteo_tsv_reader( split='eval_train', shuffle_rng=None, @@ -234,10 +256,12 @@ def get_criteo1tb(shuffle_rng, batch_size=per_host_eval_batch_size, num_batches_to_prefetch=num_batches_to_prefetch) eval_train_iterator_fn = functools.partial( - _convert_to_numpy_iterator_fn, + _eval_numpy_iterator, 
per_host_eval_batch_size=per_host_eval_batch_size, tf_dataset=eval_train_dataset, - split_size=hps.train_size) + split_size=hps.train_size, + ) + data_utils.log_rss('eval_train dataset created') validation_dataset = criteo_tsv_reader( split='validation', shuffle_rng=None, @@ -246,10 +270,12 @@ def get_criteo1tb(shuffle_rng, batch_size=per_host_eval_batch_size, num_batches_to_prefetch=num_batches_to_prefetch) validation_iterator_fn = functools.partial( - _convert_to_numpy_iterator_fn, + _eval_numpy_iterator, per_host_eval_batch_size=per_host_eval_batch_size, tf_dataset=validation_dataset, - split_size=hps.valid_size) + split_size=hps.valid_size, + ) + data_utils.log_rss('validation dataset created') test_dataset = criteo_tsv_reader( split='test', shuffle_rng=None, @@ -258,10 +284,24 @@ def get_criteo1tb(shuffle_rng, batch_size=per_host_eval_batch_size, num_batches_to_prefetch=num_batches_to_prefetch) test_iterator_fn = functools.partial( - _convert_to_numpy_iterator_fn, + _eval_numpy_iterator, per_host_eval_batch_size=per_host_eval_batch_size, tf_dataset=test_dataset, - split_size=hps.test_size) + split_size=hps.test_size, + ) + data_utils.log_rss('test dataset created') + + # Cache all the eval_train/validation/test iterators to avoid re-processing the same data files. 
+ eval_train_iterator_fn = data_utils.CachedIteratorFactory( + eval_train_iterator_fn(None), 'eval_train' + ) + validation_iterator_fn = data_utils.CachedIteratorFactory( + validation_iterator_fn(None), 'validation' + ) + test_iterator_fn = data_utils.CachedIteratorFactory( + test_iterator_fn(None), 'test' + ) + return data_utils.Dataset( train_iterator_fn, eval_train_iterator_fn, diff --git a/init2winit/dataset_lib/data_utils.py b/init2winit/dataset_lib/data_utils.py index 393bcb5f..5f0ec34d 100644 --- a/init2winit/dataset_lib/data_utils.py +++ b/init2winit/dataset_lib/data_utils.py @@ -16,7 +16,13 @@ """Common code used by different models.""" import collections +import itertools +import queue +import resource +import threading +from typing import Any, Iterator +from absl import logging import flax.linen as nn import jax from jax.nn import one_hot @@ -33,6 +39,83 @@ ]) +def log_rss(msg: str): + """Logs the current memory usage and prints the given message.""" + rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 + logging.info('%s — RSS: %.1f MB', msg, rss_mb) + + +def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike], + num_prefetch: int) -> Iterator[jax.typing.ArrayLike]: + """Wraps the given iterator with prefetching. + + Args: + source_iter: The iterator to wrap. + num_prefetch: The number of items to prefetch. + + Yields: + Prefetched items from `source_iter`. 
+ """ + if num_prefetch < 0: + raise ValueError(f'num_prefetch must be non-negative, got {num_prefetch}') + elif num_prefetch == 0: + yield from source_iter + return + + buf = queue.Queue(maxsize=num_prefetch) + sentinel = object() # Used to signal end of iterator + + def producer(): + for item in source_iter: + buf.put(item) + buf.put(sentinel) + + t = threading.Thread(target=producer, daemon=True) + t.start() + + while True: + item = buf.get() + if item is sentinel: + return + yield item + + +def _history_keeping_iterator(iterable, history, num_items=None): + for val in itertools.islice(iterable, num_items): + history.append(val) + yield val + + +class CachedIteratorFactory: + """A callable that caches batches from an iterator factory. + + On each call, yields up to `num_batches` items — first from the cache, + then by continuing to iterate the underlying source (caching as it goes). + Once the source is fully exhausted, the source iterator is freed. + """ + + def __init__(self, source: Iterator[Any], split_name: str = 'eval'): + self._split_name = split_name + self._source = iter(source) + self.cache = [] + + def __call__(self, num_batches: int | None = None) -> Iterator[Any]: + yield from self.cache[:num_batches] + + # If we have exhausted the source or cached enough batches, we're done. 
+ if self._source is None or ( + num_batches is not None and num_batches <= len(self.cache) + ): + return + + remaining = None if num_batches is None else num_batches - len(self.cache) + before = len(self.cache) + yield from _history_keeping_iterator(self._source, self.cache, remaining) + if remaining is None or len(self.cache) - before < remaining: + self._source = None + log_rss(f'{self._split_name} cache {len(self.cache)} batches') + + def iterator_as_numpy(iterator): for x in iterator: yield jax.tree.map(lambda y: y._numpy(), x) # pylint: disable=protected-access diff --git a/init2winit/dataset_lib/test_data_utils.py b/init2winit/dataset_lib/test_data_utils.py index 184a0287..fdbce54e 100644 --- a/init2winit/dataset_lib/test_data_utils.py +++ b/init2winit/dataset_lib/test_data_utils.py @@ -36,6 +36,32 @@ test_parameters = zip(test_names, image_formats, batch_axes, input_shapes) +class PrefetchIteratorTest(absltest.TestCase): + """Unit tests for data_utils.prefetch_iterator.""" + + def test_output_matches_original(self): + """Test that the output matches the original iterator.""" + sizes = [0, 1, 20] + for size in sizes: + data = list(range(size)) + result = list(data_utils.prefetch_iterator(iter(data), num_prefetch=4)) + self.assertEqual(result, data) + + def test_various_buffer_sizes(self): + """Test that the output matches the original iterator for various buffer sizes.""" + data = list(range(10)) + for buf_size in [1, 2, 5, 10, 20]: + result = list(data_utils.prefetch_iterator(iter(data), buf_size)) + self.assertEqual(result, data, f'Failed with buffer_size={buf_size}') + + def test_numpy_arrays(self): + """Test that the output matches the original iterator for numpy arrays.""" + arrays = [np.arange(i, i + 3) for i in range(5)] + result = list(data_utils.prefetch_iterator(iter(arrays), num_prefetch=2)) + for expected, actual in zip(arrays, result): + np.testing.assert_array_equal(actual, expected) + + class DataUtilsTest(parameterized.TestCase): """Unit 
class CachedIteratorFactoryTest(absltest.TestCase):
  """Tests that CachedIteratorFactory caches correctly and frees memory."""

  def _make_cached_iter_factory(self, data):
    """Builds a CachedIteratorFactory over `data` for the 'test' split."""
    return data_utils.CachedIteratorFactory(iter(data), split_name='test')

  def test_progressive_caching(self):
    """Requests of growing and then shrinking size are served consistently."""
    factory = self._make_cached_iter_factory(list(range(20)))
    for requested in (5, 10, 3):
      self.assertEqual(
          list(factory(num_batches=requested)), list(range(requested))
      )

  def test_varying_num_batches(self):
    """Test varying num_batches with a 10-element source."""
    factory = self._make_cached_iter_factory(list(range(10)))
    self.assertEqual(list(factory(num_batches=5)), list(range(5)))
    self.assertLen(factory.cache, 5)
    self.assertEqual(list(factory(num_batches=10)), list(range(10)))
    self.assertLen(factory.cache, 10)
    # Requesting more batches than exist yields everything available.
    self.assertEqual(list(factory(num_batches=12)), list(range(10)))

  def test_full_iteration(self):
    """Two full passes (num_batches=None) both return the entire dataset."""
    data = list(range(10))
    factory = self._make_cached_iter_factory(data)
    self.assertEqual(list(factory(num_batches=None)), data)
    self.assertEqual(list(factory(num_batches=None)), data)

  def test_source_freed_after_exhaustion(self):
    """The source iterator reference is dropped once it is exhausted."""
    factory = self._make_cached_iter_factory(list(range(5)))
    self.assertIsNotNone(factory._source)  # pylint: disable=protected-access
    list(factory(num_batches=None))
    self.assertIsNone(factory._source)  # pylint: disable=protected-access


if __name__ == '__main__':
  absltest.main()