Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 66 additions & 26 deletions init2winit/dataset_lib/criteo_terabyte_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@

# Change to the path to raw dataset files.
RAW_CRITEO1TB_FILE_PATH = ''
CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
input_shape=(13 + 26,),
train_size=4_195_197_692,
# We assume the tie breaking example went to the validation set, because
# the test set in the mlperf version has 89_137_318 examples.
valid_size=89_137_319,
test_size=89_137_318,
))
CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
input_shape=(13 + 26,),
train_size=4_195_197_692,
# We assume the tie breaking example went to the validation set, because
# the test set in the mlperf version has 89_137_318 examples.
valid_size=89_137_319,
test_size=89_137_318,
)
)
CRITEO1TB_METADATA = {
'apply_one_hot_in_loss': True,
}
Expand Down Expand Up @@ -155,17 +157,26 @@ def criteo_tsv_reader(
return ds


def _convert_to_numpy_iterator_fn(
num_batches, per_host_eval_batch_size, tf_dataset, split_size):
"""Make eval iterator. This function is called at the start of each eval."""
# Some hosts could see different numbers of examples, which in some cases
# could lead to some hosts not having enough examples to make the same number
# of batches. This makes pmap hang because it is waiting for a batch that will
# never come from the host with less data.
#
# We assume that all files have the same number of examples in them, and while
# this may not always be true, it dramatically simplifies/speeds up the logic
# because the alternative is to run the same iterator (without data file
def _eval_numpy_iterator(
num_batches, per_host_eval_batch_size, tf_dataset, split_size
):
"""Yields padded numpy batches from a tf.data eval split.

Caps the number of batches to the split size divided across hosts. If the
source runs out before `num_batches`, yields zero-filled batches so every
host sees the same count.

Args:
num_batches (int): number of batches to process.
per_host_eval_batch_size (int): batch size per host.
tf_dataset (tfds.Dataset): source tensorflow dataset.
split_size (int): total number of examples in the eval split.

Yields:
Padded numpy batches.
"""
# We assume all files have the same number of examples. This simplifies the
# logic because the alternative is to run the same iterator (without data file
# sharding) on each host, skipping (num_hosts - 1) / num_hosts batches.
#
# Any final partial batches are padded to be the full batch size, so we can
Expand Down Expand Up @@ -218,14 +229,25 @@ def get_criteo1tb(shuffle_rng,
num_batches_to_prefetch = (hps.num_tf_data_prefetches
if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE)

num_device_prefetches = hps.get('num_device_prefetches', 0)

train_dataset = criteo_tsv_reader(
split='train',
shuffle_rng=shuffle_rng,
file_path=train_file_path,
num_dense_features=hps.num_dense_features,
batch_size=per_host_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch)
train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
data_utils.log_rss('train dataset created')
if num_device_prefetches > 0:
train_iterator_fn = lambda: data_utils.prefetch_iterator(
tfds.as_numpy(train_dataset), num_device_prefetches
)
data_utils.log_rss(
f'using prefetching with {num_device_prefetches} in the train dataset'
)
else:
train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
eval_train_dataset = criteo_tsv_reader(
split='eval_train',
shuffle_rng=None,
Expand All @@ -234,10 +256,12 @@ def get_criteo1tb(shuffle_rng,
batch_size=per_host_eval_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch)
eval_train_iterator_fn = functools.partial(
_convert_to_numpy_iterator_fn,
_eval_numpy_iterator,
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=eval_train_dataset,
split_size=hps.train_size)
split_size=hps.train_size,
)
data_utils.log_rss('eval_train dataset created')
validation_dataset = criteo_tsv_reader(
split='validation',
shuffle_rng=None,
Expand All @@ -246,10 +270,12 @@ def get_criteo1tb(shuffle_rng,
batch_size=per_host_eval_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch)
validation_iterator_fn = functools.partial(
_convert_to_numpy_iterator_fn,
_eval_numpy_iterator,
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=validation_dataset,
split_size=hps.valid_size)
split_size=hps.valid_size,
)
data_utils.log_rss('validation dataset created')
test_dataset = criteo_tsv_reader(
split='test',
shuffle_rng=None,
Expand All @@ -258,10 +284,24 @@ def get_criteo1tb(shuffle_rng,
batch_size=per_host_eval_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch)
test_iterator_fn = functools.partial(
_convert_to_numpy_iterator_fn,
_eval_numpy_iterator,
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=test_dataset,
split_size=hps.test_size)
split_size=hps.test_size,
)
data_utils.log_rss('test dataset created')

# Cache all the eval_train/validation/test iterators to avoid re-processing the same data files.
eval_train_iterator_fn = data_utils.CachedIteratorFactory(
eval_train_iterator_fn(None), 'eval_train'
)
validation_iterator_fn = data_utils.CachedIteratorFactory(
validation_iterator_fn(None), 'validation'
)
test_iterator_fn = data_utils.CachedIteratorFactory(
test_iterator_fn(None), 'test'
)

return data_utils.Dataset(
train_iterator_fn,
eval_train_iterator_fn,
Expand Down
83 changes: 83 additions & 0 deletions init2winit/dataset_lib/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
"""Common code used by different models."""

import collections
import itertools
import queue
import resource
import threading
from typing import Any, Iterator

from absl import logging
import flax.linen as nn
import jax
from jax.nn import one_hot
Expand All @@ -33,6 +39,83 @@
])


def log_rss(msg: str):
  """Logs `msg` together with the process's peak resident set size.

  NOTE(review): `ru_maxrss` is the *peak* (high-water-mark) RSS since process
  start, not the current usage, and its unit is platform-dependent —
  kilobytes on Linux, bytes on macOS. The `/ 1024` conversion below assumes
  Linux-style kilobytes, so the logged "MB" figure is only correct on Linux.
  """
  rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
  logging.info('%s — RSS: %.1f MB', msg, rss_mb)


def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike],
                      num_prefetch: int) -> Iterator[jax.typing.ArrayLike]:
  """Wraps the given iterator with background-thread prefetching.

  A daemon thread eagerly pulls up to `num_prefetch` items from `source_iter`
  into a bounded queue while the consumer works on earlier items.

  Bug fix: if `source_iter` raised, the previous implementation's producer
  thread died without ever enqueueing the end-of-iteration sentinel, so the
  consumer blocked on `buf.get()` forever. Exceptions are now captured in the
  producer and re-raised to the consumer.

  Args:
    source_iter: The iterator to wrap.
    num_prefetch: The number of items to prefetch. 0 disables prefetching.

  Yields:
    Items from `source_iter`, in order.

  Raises:
    ValueError: If `num_prefetch` is negative.
  """
  if num_prefetch < 0:
    raise ValueError(f'num_prefetch must be non-negative, got {num_prefetch}')
  elif num_prefetch == 0:
    yield from source_iter
    return

  buf = queue.Queue(maxsize=num_prefetch)
  sentinel = object()  # Signals that the producer finished (or failed).
  error = []  # Holds an exception raised by `source_iter`, if any.

  def producer():
    try:
      for item in source_iter:
        buf.put(item)
    except Exception as e:  # pylint: disable=broad-except
      # Ferry the failure to the consumer thread instead of dying silently.
      error.append(e)
    finally:
      # Always enqueue the sentinel so the consumer cannot block forever.
      buf.put(sentinel)

  t = threading.Thread(target=producer, daemon=True)
  t.start()

  while True:
    item = buf.get()
    if item is sentinel:
      if error:
        raise error[0]
      return
    yield item


def _history_keeping_iterator(iterable, history, num_items=None):
for val in itertools.islice(iterable, num_items):
history.append(val)
yield val


class CachedIteratorFactory:
  """A callable that caches batches produced by a source iterator.

  Each call replays up to `num_batches` items from the cache, then pulls any
  still-needed items from the underlying source, appending them to the cache
  as they stream through. Once the source is fully drained, the reference to
  it is dropped so its resources can be reclaimed.
  """

  def __init__(self, source: Iterator[Any], split_name: str = 'eval'):
    self._split_name = split_name
    self._source = iter(source)
    self.cache = []

  def __call__(self, num_batches: int | None = None) -> Iterator[Any]:
    # Replay what we already have (a slice copy, so the cache may safely grow
    # while the caller consumes these).
    for item in self.cache[:num_batches]:
      yield item

    source_drained = self._source is None
    cache_suffices = num_batches is not None and num_batches <= len(self.cache)
    if source_drained or cache_suffices:
      return

    if num_batches is None:
      still_needed = None
    else:
      still_needed = num_batches - len(self.cache)
    size_before = len(self.cache)
    yield from _history_keeping_iterator(self._source, self.cache, still_needed)
    # Fewer items than requested (or an unbounded pull) means the source is
    # exhausted; drop it to free memory.
    if still_needed is None or len(self.cache) - size_before < still_needed:
      self._source = None
    log_rss(f'{self._split_name} cache {len(self.cache)} batches')


def iterator_as_numpy(iterator):
for x in iterator:
yield jax.tree.map(lambda y: y._numpy(), x) # pylint: disable=protected-access
Expand Down
70 changes: 70 additions & 0 deletions init2winit/dataset_lib/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,32 @@
test_parameters = zip(test_names, image_formats, batch_axes, input_shapes)


class PrefetchIteratorTest(absltest.TestCase):
  """Unit tests for data_utils.prefetch_iterator."""

  def test_output_matches_original(self):
    """Prefetching must not change the yielded sequence, for several sizes."""
    for size in (0, 1, 20):
      expected = list(range(size))
      actual = list(
          data_utils.prefetch_iterator(iter(expected), num_prefetch=4))
      self.assertEqual(actual, expected)

  def test_various_buffer_sizes(self):
    """Output is unchanged regardless of the prefetch buffer size."""
    expected = list(range(10))
    for buf_size in (1, 2, 5, 10, 20):
      actual = list(data_utils.prefetch_iterator(iter(expected), buf_size))
      self.assertEqual(actual, expected, f'Failed with buffer_size={buf_size}')

  def test_numpy_arrays(self):
    """Numpy array items pass through prefetching unmodified."""
    arrays = [np.arange(i, i + 3) for i in range(5)]
    prefetched = list(
        data_utils.prefetch_iterator(iter(arrays), num_prefetch=2))
    for want, got in zip(arrays, prefetched):
      np.testing.assert_array_equal(got, want)


class DataUtilsTest(parameterized.TestCase):
"""Unit tests for datasets.py."""

Expand Down Expand Up @@ -81,5 +107,49 @@ def test_padding_seq2seq(self):
np.array_equal(padded_batch['weights'], expected_weights_array))


class CachedIteratorFactoryTest(absltest.TestCase):
  """Tests that CachedIteratorFactory caches correctly and frees memory."""

  def _make_cached_iter_factory(self, data):
    """Builds a CachedIteratorFactory over `data` for the tests below."""
    return data_utils.CachedIteratorFactory(iter(data), split_name='test')

  def test_progressive_caching(self):
    """Successive calls grow the cache and replay earlier items."""
    factory = self._make_cached_iter_factory(list(range(20)))
    self.assertEqual(list(factory(num_batches=5)), list(range(5)))
    self.assertEqual(list(factory(num_batches=10)), list(range(10)))
    self.assertEqual(list(factory(num_batches=3)), list(range(3)))

  def test_varying_num_batches(self):
    """Cache size tracks the largest request against a 10-element source."""
    factory = self._make_cached_iter_factory(list(range(10)))
    self.assertEqual(list(factory(num_batches=5)), list(range(5)))
    self.assertLen(factory.cache, 5)
    self.assertEqual(list(factory(num_batches=10)), list(range(10)))
    self.assertLen(factory.cache, 10)
    # Requesting more batches than exist returns everything available.
    self.assertEqual(list(factory(num_batches=12)), list(range(10)))

  def test_full_iteration(self):
    """An unbounded request yields the whole dataset, twice in a row."""
    data = list(range(10))
    factory = self._make_cached_iter_factory(data)
    self.assertEqual(list(factory(num_batches=None)), data)
    self.assertEqual(list(factory(num_batches=None)), data)

  def test_source_freed_after_exhaustion(self):
    """The source iterator reference is dropped once it is drained."""
    factory = self._make_cached_iter_factory(list(range(5)))
    self.assertIsNotNone(factory._source)  # pylint: disable=protected-access
    list(factory(num_batches=None))
    self.assertIsNone(factory._source)  # pylint: disable=protected-access


if __name__ == '__main__':
absltest.main()