From 39b8a86f34262263a6ba71f91199a3062be23262 Mon Sep 17 00:00:00 2001 From: init2winit Team Date: Mon, 23 Feb 2026 00:11:39 -0800 Subject: [PATCH] internal PiperOrigin-RevId: 873907256 --- .../dataset_lib/criteo_terabyte_dataset.py | 49 +++++++--- init2winit/dataset_lib/data_utils.py | 96 +++++++++++++++++++ 2 files changed, 134 insertions(+), 11 deletions(-) diff --git a/init2winit/dataset_lib/criteo_terabyte_dataset.py b/init2winit/dataset_lib/criteo_terabyte_dataset.py index db1268b8..a8f27c3d 100644 --- a/init2winit/dataset_lib/criteo_terabyte_dataset.py +++ b/init2winit/dataset_lib/criteo_terabyte_dataset.py @@ -37,14 +37,16 @@ # Change to the path to raw dataset files. RAW_CRITEO1TB_FILE_PATH = '' -CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - input_shape=(13 + 26,), - train_size=4_195_197_692, - # We assume the tie breaking example went to the validation set, because - # the test set in the mlperf version has 89_137_318 examples. - valid_size=89_137_319, - test_size=89_137_318, -)) +CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + input_shape=(13 + 26,), + train_size=4_195_197_692, + # We assume the tie breaking example went to the validation set, because + # the test set in the mlperf version has 89_137_318 examples. + valid_size=89_137_319, + test_size=89_137_318, + ) +) CRITEO1TB_METADATA = { 'apply_one_hot_in_loss': True, } @@ -142,9 +144,9 @@ def criteo_tsv_reader( ds = ds.repeat() ds = ds.interleave( tf.data.TextLineDataset, - cycle_length=128, + cycle_length=64, block_length=batch_size // 8, - num_parallel_calls=128, + num_parallel_calls=64, deterministic=False) if is_training: ds = ds.shuffle(buffer_size=524_288 * 100, seed=data_shuffle_seed) @@ -218,6 +220,8 @@ def get_criteo1tb(shuffle_rng, num_batches_to_prefetch = (hps.num_tf_data_prefetches if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE) + num_device_prefetches = hps.get('num_device_prefetches', 0) + train_dataset = criteo_tsv_reader( split='train', shuffle_rng=shuffle_rng, @@ -225,7 +229,16 @@ def get_criteo1tb(shuffle_rng, num_dense_features=hps.num_dense_features, batch_size=per_host_batch_size, num_batches_to_prefetch=num_batches_to_prefetch) - train_iterator_fn = lambda: tfds.as_numpy(train_dataset) + data_utils.log_rss('train dataset created') + if num_device_prefetches > 0: + train_iterator_fn = lambda: data_utils.prefetch_iterator( + tfds.as_numpy(train_dataset), num_device_prefetches + ) + data_utils.log_rss( + f'using prefetching with {num_device_prefetches} in the train dataset' + ) + else: + train_iterator_fn = lambda: tfds.as_numpy(train_dataset) eval_train_dataset = criteo_tsv_reader( split='eval_train', shuffle_rng=None, @@ -238,6 +251,7 @@ def get_criteo1tb(shuffle_rng, per_host_eval_batch_size=per_host_eval_batch_size, tf_dataset=eval_train_dataset, split_size=hps.train_size) + data_utils.log_rss('eval_train dataset created') validation_dataset = criteo_tsv_reader( split='validation', shuffle_rng=None, @@ -250,6 +264,7 @@ def get_criteo1tb(shuffle_rng, per_host_eval_batch_size=per_host_eval_batch_size, tf_dataset=validation_dataset, split_size=hps.valid_size) + data_utils.log_rss('validation dataset created') test_dataset = criteo_tsv_reader( split='test', shuffle_rng=None, @@ -262,6 +277,18 @@ def get_criteo1tb(shuffle_rng, per_host_eval_batch_size=per_host_eval_batch_size, tf_dataset=test_dataset, split_size=hps.test_size) + data_utils.log_rss('test dataset created') + + # Cache all the eval_train/validation/test iterators to avoid re-processing the + # same data files. + eval_train_iterator_fn = data_utils.CachedEvalIterator( + eval_train_iterator_fn, 'eval_train' + ) + validation_iterator_fn = data_utils.CachedEvalIterator( + validation_iterator_fn, 'validation' + ) + test_iterator_fn = data_utils.CachedEvalIterator(test_iterator_fn, 'test') + return data_utils.Dataset( train_iterator_fn, eval_train_iterator_fn, diff --git a/init2winit/dataset_lib/data_utils.py b/init2winit/dataset_lib/data_utils.py index 393bcb5f..0eb77a90 100644 --- a/init2winit/dataset_lib/data_utils.py +++ b/init2winit/dataset_lib/data_utils.py @@ -16,7 +16,12 @@ """Common code used by different models.""" import collections +import queue +import resource +import threading +from typing import Iterator +from absl import logging import flax.linen as nn import jax from jax.nn import one_hot @@ -33,6 +38,97 @@ ]) +def log_rss(msg: str): + """Logs the current memory usage and prints the given message.""" + rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 + logging.info('%s — RSS: %.1f MB', msg, rss_mb) + + +def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike], + num_prefetch: int) -> Iterator[jax.typing.ArrayLike]: + """Wraps the given iterator with prefetching. + + Args: + source_iter: The iterator to wrap. + num_prefetch: The number of items to prefetch. + + Yields: + Prefetched items from `source_iter`. + """ + buf = queue.Queue(maxsize=num_prefetch) + sentinel = object() # Used to signal end of iterator + + def producer(): + try: + for item in source_iter: + buf.put(item) + except Exception as e: # pylint: disable=broad-except + buf.put(e) + buf.put(sentinel) + + t = threading.Thread(target=producer, daemon=True) + t.start() + + while True: + item = buf.get() + if item is sentinel: + return + if isinstance(item, Exception): + raise item + yield item + + +class CachedEvalIterator: + """Lazily caches eval batches, which are typically small enough to fit into host memory.""" + + def __init__(self, iterator_factory, split_name='eval'): + self._factory = iterator_factory + self._split_name = split_name + self._cache = [] + self._iterator = None + self._fully_cached = False + + def __call__(self, num_batches=None): + yielded = 0 + + limit = ( + len(self._cache) + if num_batches is None + else min(len(self._cache), num_batches) + ) + for i in range(limit): + yield self._cache[i] + yielded += 1 + + if num_batches is not None and yielded >= num_batches: + return + + if self._fully_cached: + return + + if self._iterator is None: + logging.info('Building %s cache lazily...', self._split_name) + self._iterator = iter(self._factory(None)) + + for batch in self._iterator: + self._cache.append(batch) + yield batch + yielded += 1 + if num_batches is not None and yielded >= num_batches: + return + + self._fully_cached = True + self._factory = None + self._iterator = None + rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 + logging.info( + '%s cache complete: %d batches — RSS: %.1f MB', + self._split_name, + len(self._cache), + rss_mb, + ) + + def iterator_as_numpy(iterator): for x in iterator: yield jax.tree.map(lambda y: y._numpy(), x) # pylint: disable=protected-access