From 5e88dddf272aeb5bb678bb58422c7b4988b5b473 Mon Sep 17 00:00:00 2001 From: init2winit Team Date: Sun, 22 Feb 2026 23:27:30 -0800 Subject: [PATCH] internal PiperOrigin-RevId: 873892307 --- .../dataset_lib/criteo_terabyte_dataset.py | 272 +++++++++++++++++- init2winit/dataset_lib/data_utils.py | 96 +++++++ 2 files changed, 354 insertions(+), 14 deletions(-) diff --git a/init2winit/dataset_lib/criteo_terabyte_dataset.py b/init2winit/dataset_lib/criteo_terabyte_dataset.py index db1268b8..6be85e86 100644 --- a/init2winit/dataset_lib/criteo_terabyte_dataset.py +++ b/init2winit/dataset_lib/criteo_terabyte_dataset.py @@ -34,17 +34,19 @@ import tensorflow as tf import tensorflow_datasets as tfds - # Change to the path to raw dataset files. RAW_CRITEO1TB_FILE_PATH = '' -CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - input_shape=(13 + 26,), - train_size=4_195_197_692, - # We assume the tie breaking example went to the validation set, because - # the test set in the mlperf version has 89_137_318 examples. - valid_size=89_137_319, - test_size=89_137_318, -)) +PREPROCESSED_CRITEO1TB_FILE_PATH = '' # pylint: disable=invalid-name +CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + input_shape=(13 + 26,), + train_size=4_195_197_692, + # We assume the tie breaking example went to the validation set, because + # the test set in the mlperf version has 89_137_318 examples. + valid_size=89_137_319, + test_size=89_137_318, + ) +) CRITEO1TB_METADATA = { 'apply_one_hot_in_loss': True, } @@ -142,9 +144,9 @@ def criteo_tsv_reader( ds = ds.repeat() ds = ds.interleave( tf.data.TextLineDataset, - cycle_length=128, + cycle_length=64, block_length=batch_size // 8, - num_parallel_calls=128, + num_parallel_calls=64, deterministic=False) if is_training: ds = ds.shuffle(buffer_size=524_288 * 100, seed=data_shuffle_seed) @@ -155,6 +157,94 @@ def criteo_tsv_reader( return ds +_ARRAYRECORD_FEATURE_SPEC = { + 'inputs': tf.io.FixedLenFeature([13 + 26], tf.float32), + 'targets': tf.io.FixedLenFeature([1], tf.float32), +} + + +@tf.function +def _parse_arrayrecord_example_fn(serialized_examples): + """Parse a batch of serialized tf.train.Examples from ArrayRecord.""" + parsed = tf.io.parse_example(serialized_examples, _ARRAYRECORD_FEATURE_SPEC) + return { + 'inputs': parsed['inputs'], + 'targets': tf.squeeze(parsed['targets'], axis=-1), + } + + +def criteo_arrayrecord_reader( + split, shuffle_rng, file_path, batch_size, num_batches_to_prefetch +): + """Input reader for preprocessed Criteo ArrayRecord data. + + Args: + split: one of {'train', 'eval_train', 'validation', 'test'}. + shuffle_rng: jax.random.PRNGKey for shuffling (train). + file_path: glob pattern for .array_record files. + batch_size: per-host batch size. + num_batches_to_prefetch: number of batches to prefetch. + + Returns: + A tf.data.Dataset object. + """ + # Import here to avoid hard dependency for TSV-only users. + if split not in ['train', 'eval_train', 'validation', 'test']: + raise ValueError(f'Invalid split name {split}.') + data_shuffle_seed = None + + is_training = split == 'train' + if is_training: + _, data_shuffle_seed = jax.random.split(shuffle_rng, 2) + data_shuffle_seed = data_utils.convert_jax_to_tf_random_seed( + data_shuffle_seed + ) + + # Discover all matching files. + all_files = sorted(tf.io.gfile.glob(file_path)) + if not all_files: + raise ValueError(f'No ArrayRecord files found matching: {file_path}') + + # Shard files across hosts. + index = jax.process_index() + num_hosts = jax.process_count() + host_files = all_files[index::num_hosts] + + # Interleave per-file datasets, with batch+parse inside each file's + # sub-pipeline. This is critical for performance: interleaving dense float + # tensors (post-parse) is much faster than interleaving raw byte strings + # and batching them later. + file_ds = tf.data.Dataset.from_tensor_slices(host_files) + if is_training: + file_ds = file_ds.repeat() + file_ds = file_ds.shuffle( + buffer_size=2 * len(host_files), seed=data_shuffle_seed + ) + + ds = file_ds.interleave( + lambda f: ( + ar_dataset.ArrayRecordDataset([f]) + .batch( + batch_size, + drop_remainder=is_training, + num_parallel_calls=tf.data.AUTOTUNE, + deterministic=False, + ) + .map( + _parse_arrayrecord_example_fn, + num_parallel_calls=tf.data.AUTOTUNE, + deterministic=False, + ) + ), + cycle_length=64, + block_length=batch_size // 8, + num_parallel_calls=64, + deterministic=False, + ) + ds = ds.prefetch(num_batches_to_prefetch) + return ds + + def _convert_to_numpy_iterator_fn( num_batches, per_host_eval_batch_size, tf_dataset, split_size): """Make eval iterator. This function is called at the start of each eval.""" @@ -211,12 +301,143 @@ def get_criteo1tb(shuffle_rng, per_host_eval_batch_size = eval_batch_size // process_count per_host_batch_size = batch_size // process_count + use_raw_tsv = hps.get('use_raw_tsv', False) + num_batches_to_prefetch = ( + hps.num_tf_data_prefetches + if hps.num_tf_data_prefetches > 0 + else tf.data.AUTOTUNE + ) + num_device_prefetches = hps.get('num_device_prefetches', 0) + + if use_raw_tsv: + return _get_criteo1tb_tsv( + shuffle_rng, + per_host_batch_size, + per_host_eval_batch_size, + hps, + num_batches_to_prefetch, + num_device_prefetches, + ) + else: + return _get_criteo1tb_arrayrecord( + shuffle_rng, + per_host_batch_size, + per_host_eval_batch_size, + hps, + num_batches_to_prefetch, + num_device_prefetches, + ) + + +def _get_criteo1tb_arrayrecord( + shuffle_rng, + per_host_batch_size, + per_host_eval_batch_size, + hps, + num_batches_to_prefetch, + num_device_prefetches, +): + """Load Criteo 1TB from preprocessed ArrayRecord files.""" + base = hps.get('preprocessed_data_path', PREPROCESSED_CRITEO1TB_FILE_PATH) + train_file_path = os.path.join(base, 'train', '*') + validation_file_path = os.path.join( + base, 'val_set_second_half_of_day23_not_used', '*' + ) + test_file_path = os.path.join(base, 'eval', '*') + + train_dataset = criteo_arrayrecord_reader( + split='train', + shuffle_rng=shuffle_rng, + file_path=train_file_path, + batch_size=per_host_batch_size, + num_batches_to_prefetch=num_batches_to_prefetch, + ) + data_utils.log_rss('train arrayrecord dataset created') + + if num_device_prefetches > 0: + train_iterator_fn = lambda: data_utils.prefetch_iterator( + tfds.as_numpy(train_dataset), num_device_prefetches + ) + data_utils.log_rss( + f'using prefetching with {num_device_prefetches} in the train dataset' + ) + else: + train_iterator_fn = lambda: tfds.as_numpy(train_dataset) + + eval_train_dataset = criteo_arrayrecord_reader( + split='eval_train', + shuffle_rng=None, + file_path=train_file_path, + batch_size=per_host_eval_batch_size, + num_batches_to_prefetch=num_batches_to_prefetch, + ) + eval_train_iterator_fn = functools.partial( + _convert_to_numpy_iterator_fn, + per_host_eval_batch_size=per_host_eval_batch_size, + tf_dataset=eval_train_dataset, + split_size=hps.train_size, + ) + data_utils.log_rss('eval_train arrayrecord dataset created') + + validation_dataset = criteo_arrayrecord_reader( + split='validation', + shuffle_rng=None, + file_path=validation_file_path, + batch_size=per_host_eval_batch_size, + num_batches_to_prefetch=num_batches_to_prefetch, + ) + validation_iterator_fn = functools.partial( + _convert_to_numpy_iterator_fn, + per_host_eval_batch_size=per_host_eval_batch_size, + tf_dataset=validation_dataset, + split_size=hps.valid_size, + ) + data_utils.log_rss('validation arrayrecord dataset created') + + test_dataset = criteo_arrayrecord_reader( + split='test', + shuffle_rng=None, + file_path=test_file_path, + batch_size=per_host_eval_batch_size, + num_batches_to_prefetch=num_batches_to_prefetch, + ) + test_iterator_fn = functools.partial( + _convert_to_numpy_iterator_fn, + per_host_eval_batch_size=per_host_eval_batch_size, + tf_dataset=test_dataset, + split_size=hps.test_size, + ) + data_utils.log_rss('test arrayrecord dataset created') + + eval_train_iterator_fn = data_utils.CachedEvalIterator( + eval_train_iterator_fn, 'eval_train' + ) + validation_iterator_fn = data_utils.CachedEvalIterator( + validation_iterator_fn, 'validation' + ) + test_iterator_fn = data_utils.CachedEvalIterator(test_iterator_fn, 'test') + + return data_utils.Dataset( + train_iterator_fn, + eval_train_iterator_fn, + validation_iterator_fn, + test_iterator_fn, + ) + + +def _get_criteo1tb_tsv( + shuffle_rng, + per_host_batch_size, + per_host_eval_batch_size, + hps, + num_batches_to_prefetch, + num_device_prefetches, +): + """Load Criteo 1TB from raw TSV files (legacy path).""" train_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'train/*/*') validation_file_path = os.path.join( RAW_CRITEO1TB_FILE_PATH, 'val_set_second_half_of_day23_not_used/*') test_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'eval/day_23/*') - num_batches_to_prefetch = (hps.num_tf_data_prefetches - if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE) train_dataset = criteo_tsv_reader( split='train', @@ -225,7 +446,16 @@ def get_criteo1tb(shuffle_rng, num_dense_features=hps.num_dense_features, batch_size=per_host_batch_size, num_batches_to_prefetch=num_batches_to_prefetch) - train_iterator_fn = lambda: tfds.as_numpy(train_dataset) + data_utils.log_rss('train dataset created') + if num_device_prefetches > 0: + train_iterator_fn = lambda: data_utils.prefetch_iterator( + tfds.as_numpy(train_dataset), num_device_prefetches + ) + data_utils.log_rss( + f'using prefetching with {num_device_prefetches} in the train dataset' + ) + else: + train_iterator_fn = lambda: tfds.as_numpy(train_dataset) eval_train_dataset = criteo_tsv_reader( split='eval_train', shuffle_rng=None, @@ -238,6 +468,7 @@ def get_criteo1tb(shuffle_rng, per_host_eval_batch_size=per_host_eval_batch_size, tf_dataset=eval_train_dataset, split_size=hps.train_size) + data_utils.log_rss('eval_train dataset created') validation_dataset = criteo_tsv_reader( split='validation', shuffle_rng=None, @@ -250,6 +481,7 @@ def get_criteo1tb(shuffle_rng, per_host_eval_batch_size=per_host_eval_batch_size, tf_dataset=validation_dataset, split_size=hps.valid_size) + data_utils.log_rss('validation dataset created') test_dataset = criteo_tsv_reader( split='test', shuffle_rng=None, @@ -262,6 +494,18 @@ def get_criteo1tb(shuffle_rng, per_host_eval_batch_size=per_host_eval_batch_size, tf_dataset=test_dataset, split_size=hps.test_size) + data_utils.log_rss('test dataset created') + + # Cache all the eval_train/validation/test iterators to avoid re-processing the + # same data files. + eval_train_iterator_fn = data_utils.CachedEvalIterator( + eval_train_iterator_fn, 'eval_train' + ) + validation_iterator_fn = data_utils.CachedEvalIterator( + validation_iterator_fn, 'validation' + ) + test_iterator_fn = data_utils.CachedEvalIterator(test_iterator_fn, 'test') + return data_utils.Dataset( train_iterator_fn, eval_train_iterator_fn, diff --git a/init2winit/dataset_lib/data_utils.py b/init2winit/dataset_lib/data_utils.py index 393bcb5f..0eb77a90 100644 --- a/init2winit/dataset_lib/data_utils.py +++ b/init2winit/dataset_lib/data_utils.py @@ -16,7 +16,12 @@ """Common code used by different models.""" import collections +import queue +import resource +import threading +from typing import Iterator +from absl import logging import flax.linen as nn import jax from jax.nn import one_hot @@ -33,6 +38,97 @@ ]) +def log_rss(msg: str): + """Logs the current memory usage and prints the given message.""" + rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 + logging.info('%s — RSS: %.1f MB', msg, rss_mb) + + +def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike], + num_prefetch: int) -> Iterator[jax.typing.ArrayLike]: + """Wraps the given iterator with prefetching. + + Args: + source_iter: The iterator to wrap. + num_prefetch: The number of items to prefetch. + + Yields: + Prefetched items from `source_iter`. + """ + buf = queue.Queue(maxsize=num_prefetch) + sentinel = object() # Used to signal end of iterator + + def producer(): + try: + for item in source_iter: + buf.put(item) + except Exception as e: # pylint: disable=broad-except + buf.put(e) + buf.put(sentinel) + + t = threading.Thread(target=producer, daemon=True) + t.start() + + while True: + item = buf.get() + if item is sentinel: + return + if isinstance(item, Exception): + raise item + yield item + + +class CachedEvalIterator: + """Lazily caches eval batches, which are typically small enough to fit into host memory.""" + + def __init__(self, iterator_factory, split_name='eval'): + self._factory = iterator_factory + self._split_name = split_name + self._cache = [] + self._iterator = None + self._fully_cached = False + + def __call__(self, num_batches=None): + yielded = 0 + + limit = ( + len(self._cache) + if num_batches is None + else min(len(self._cache), num_batches) + ) + for i in range(limit): + yield self._cache[i] + yielded += 1 + + if num_batches is not None and yielded >= num_batches: + return + + if self._fully_cached: + return + + if self._iterator is None: + logging.info('Building %s cache lazily...', self._split_name) + self._iterator = iter(self._factory(None)) + + for batch in self._iterator: + self._cache.append(batch) + yield batch + yielded += 1 + if num_batches is not None and yielded >= num_batches: + return + + self._fully_cached = True + self._factory = None + self._iterator = None + rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 + logging.info( + '%s cache complete: %d batches — RSS: %.1f MB', + self._split_name, + len(self._cache), + rss_mb, + ) + + def iterator_as_numpy(iterator): for x in iterator: yield jax.tree.map(lambda y: y._numpy(), x) # pylint: disable=protected-access