Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions init2winit/dataset_lib/criteo_terabyte_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@

# Change to the path to raw dataset files.
RAW_CRITEO1TB_FILE_PATH = ''
CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
input_shape=(13 + 26,),
train_size=4_195_197_692,
# We assume the tie breaking example went to the validation set, because
# the test set in the mlperf version has 89_137_318 examples.
valid_size=89_137_319,
test_size=89_137_318,
))
CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
input_shape=(13 + 26,),
train_size=4_195_197_692,
# We assume the tie breaking example went to the validation set, because
# the test set in the mlperf version has 89_137_318 examples.
valid_size=89_137_319,
test_size=89_137_318,
)
)
CRITEO1TB_METADATA = {
'apply_one_hot_in_loss': True,
}
Expand Down Expand Up @@ -142,9 +144,9 @@ def criteo_tsv_reader(
ds = ds.repeat()
ds = ds.interleave(
tf.data.TextLineDataset,
cycle_length=128,
cycle_length=64,
block_length=batch_size // 8,
num_parallel_calls=128,
num_parallel_calls=64,
deterministic=False)
if is_training:
ds = ds.shuffle(buffer_size=524_288 * 100, seed=data_shuffle_seed)
Expand Down Expand Up @@ -218,14 +220,25 @@ def get_criteo1tb(shuffle_rng,
num_batches_to_prefetch = (hps.num_tf_data_prefetches
if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE)

num_device_prefetches = hps.get('num_device_prefetches', 0)

train_dataset = criteo_tsv_reader(
split='train',
shuffle_rng=shuffle_rng,
file_path=train_file_path,
num_dense_features=hps.num_dense_features,
batch_size=per_host_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch)
train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
data_utils.log_rss('train dataset created')
if num_device_prefetches > 0:
train_iterator_fn = lambda: data_utils.prefetch_iterator(
tfds.as_numpy(train_dataset), num_device_prefetches
)
data_utils.log_rss(
f'using prefetching with {num_device_prefetches} in the train dataset'
)
else:
train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
eval_train_dataset = criteo_tsv_reader(
split='eval_train',
shuffle_rng=None,
Expand All @@ -238,6 +251,7 @@ def get_criteo1tb(shuffle_rng,
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=eval_train_dataset,
split_size=hps.train_size)
data_utils.log_rss('eval_train dataset created')
validation_dataset = criteo_tsv_reader(
split='validation',
shuffle_rng=None,
Expand All @@ -250,6 +264,7 @@ def get_criteo1tb(shuffle_rng,
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=validation_dataset,
split_size=hps.valid_size)
data_utils.log_rss('validation dataset created')
test_dataset = criteo_tsv_reader(
split='test',
shuffle_rng=None,
Expand All @@ -262,6 +277,18 @@ def get_criteo1tb(shuffle_rng,
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=test_dataset,
split_size=hps.test_size)
data_utils.log_rss('test dataset created')

# Cache all the eval_train/validation/test iterators to avoid re-processing the
# same data files.
eval_train_iterator_fn = data_utils.CachedEvalIterator(
eval_train_iterator_fn, 'eval_train'
)
validation_iterator_fn = data_utils.CachedEvalIterator(
validation_iterator_fn, 'validation'
)
test_iterator_fn = data_utils.CachedEvalIterator(test_iterator_fn, 'test')

return data_utils.Dataset(
train_iterator_fn,
eval_train_iterator_fn,
Expand Down
96 changes: 96 additions & 0 deletions init2winit/dataset_lib/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
"""Common code used by different models."""

import collections
import queue
import resource
import threading
from typing import Iterator

from absl import logging
import flax.linen as nn
import jax
from jax.nn import one_hot
Expand All @@ -33,6 +38,97 @@
])


def log_rss(msg: str):
  """Logs `msg` together with the process peak resident set size in MB.

  Args:
    msg: Free-form message to prefix the memory report with.

  NOTE(review): `ru_maxrss` is reported in kilobytes on Linux (so dividing by
  1024 yields MB), but in bytes on macOS — assumes a Linux host; confirm.
  """
  peak_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  logging.info('%s — RSS: %.1f MB', msg, peak_kb / 1024)


def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike],
                      num_prefetch: int) -> Iterator[jax.typing.ArrayLike]:
  """Wraps the given iterator with background-thread prefetching.

  A daemon thread eagerly pulls up to `num_prefetch` items from `source_iter`
  into a bounded queue so the consumer overlaps its work with data production.
  Exceptions raised by the source are re-raised in the consumer.

  Args:
    source_iter: The iterator to wrap.
    num_prefetch: The number of items to prefetch (queue capacity).

  Yields:
    Prefetched items from `source_iter`.

  Raises:
    Whatever exception `source_iter` raises, re-raised on the consumer side.
  """
  buf = queue.Queue(maxsize=num_prefetch)
  sentinel = object()  # Used to signal end of iterator.

  # Items are tagged (is_error, payload) so that an Exception *instance*
  # legitimately yielded by the source is not mistaken for a producer failure.
  def producer():
    try:
      for item in source_iter:
        buf.put((False, item))
    except Exception as e:  # pylint: disable=broad-except
      buf.put((True, e))
    # Always terminate the stream, whether the source finished or failed.
    buf.put((False, sentinel))

  # Daemon thread: does not block interpreter shutdown if the consumer
  # abandons the generator early.
  t = threading.Thread(target=producer, daemon=True)
  t.start()

  while True:
    is_error, item = buf.get()
    if is_error:
      raise item
    if item is sentinel:
      return
    yield item


class CachedEvalIterator:
  """Lazily caches eval batches, which are typically small enough to fit into host memory.

  The first full traversal pulls batches from the factory-built iterator and
  records each one; subsequent traversals replay from the in-memory cache
  without touching the underlying data pipeline again. Once the source is
  exhausted, both the factory and the iterator are dropped so they can be
  garbage collected.

  NOTE(review): assumes calls are not interleaved — two simultaneously active
  generators would both drain the same underlying iterator; confirm callers
  iterate one epoch at a time.
  """

  def __init__(self, iterator_factory, split_name='eval'):
    # Zero-arg-like factory (called with None) producing a fresh batch iterator.
    self._factory = iterator_factory
    # Name used only in log messages.
    self._split_name = split_name
    self._cache = []  # Batches captured so far, in source order.
    self._iterator = None  # Live source iterator; created lazily on demand.
    self._fully_cached = False  # True once the source has been exhausted.

  def __call__(self, num_batches=None):
    """Yields batches: cached ones first, then fresh ones (caching as it goes).

    Args:
      num_batches: If given, stop after yielding this many batches; otherwise
        yield the whole split.
    """
    emitted = 0

    # Replay the already-cached prefix, capped by num_batches if provided.
    if num_batches is None:
      limit = len(self._cache)
    else:
      limit = min(num_batches, len(self._cache))
    idx = 0
    while idx < limit:
      yield self._cache[idx]
      idx += 1
      emitted += 1

    if num_batches is not None and emitted >= num_batches:
      return
    if self._fully_cached:
      return

    # Need more batches than are cached: make sure the source is running.
    if self._iterator is None:
      logging.info('Building %s cache lazily...', self._split_name)
      self._iterator = iter(self._factory(None))

    for batch in self._iterator:
      self._cache.append(batch)  # Record before yielding so nothing is lost.
      yield batch
      emitted += 1
      if num_batches is not None and emitted >= num_batches:
        return

    # Source exhausted: release it so only the cache stays alive.
    self._fully_cached = True
    self._factory = None
    self._iterator = None
    rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
    logging.info(
        '%s cache complete: %d batches — RSS: %.1f MB',
        self._split_name,
        len(self._cache),
        rss_mb,
    )


def iterator_as_numpy(iterator):
for x in iterator:
yield jax.tree.map(lambda y: y._numpy(), x) # pylint: disable=protected-access
Expand Down