272 changes: 258 additions & 14 deletions init2winit/dataset_lib/criteo_terabyte_dataset.py
@@ -34,17 +34,19 @@
import tensorflow as tf
import tensorflow_datasets as tfds


# Change this to the path of the raw dataset files.
RAW_CRITEO1TB_FILE_PATH = ''
CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
input_shape=(13 + 26,),
train_size=4_195_197_692,
# We assume the tie breaking example went to the validation set, because
# the test set in the mlperf version has 89_137_318 examples.
valid_size=89_137_319,
test_size=89_137_318,
))
PREPROCESSED_CRITEO1TB_FILE_PATH = '' # pylint: disable=invalid-name
CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
input_shape=(13 + 26,),
train_size=4_195_197_692,
        # We assume the tie-breaking example went to the validation set, because
# the test set in the mlperf version has 89_137_318 examples.
valid_size=89_137_319,
test_size=89_137_318,
)
)
CRITEO1TB_METADATA = {
'apply_one_hot_in_loss': True,
}
@@ -142,9 +144,9 @@ def criteo_tsv_reader(
ds = ds.repeat()
ds = ds.interleave(
tf.data.TextLineDataset,
cycle_length=128,
cycle_length=64,
block_length=batch_size // 8,
num_parallel_calls=128,
num_parallel_calls=64,
deterministic=False)
if is_training:
ds = ds.shuffle(buffer_size=524_288 * 100, seed=data_shuffle_seed)
@@ -155,6 +157,94 @@
return ds


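# Feature spec for the preprocessed records: the 13 dense and 26 categorical
# Criteo features are stored as one flat float vector per example, plus a
# scalar click label.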
_ARRAYRECORD_FEATURE_SPEC = {
'inputs': tf.io.FixedLenFeature([13 + 26], tf.float32),
'targets': tf.io.FixedLenFeature([1], tf.float32),
}


@tf.function
def _parse_arrayrecord_example_fn(serialized_examples):
"""Parse a batch of serialized tf.train.Examples from ArrayRecord."""
parsed = tf.io.parse_example(serialized_examples, _ARRAYRECORD_FEATURE_SPEC)
return {
'inputs': parsed['inputs'],
'targets': tf.squeeze(parsed['targets'], axis=-1),
}


def criteo_arrayrecord_reader(
split, shuffle_rng, file_path, batch_size, num_batches_to_prefetch
):
"""Input reader for preprocessed Criteo ArrayRecord data.

Args:
split: one of {'train', 'eval_train', 'validation', 'test'}.
    shuffle_rng: jax.random.PRNGKey used to derive the shuffle seed; only the
      train split uses it (may be None for eval splits).
file_path: glob pattern for .array_record files.
batch_size: per-host batch size.
num_batches_to_prefetch: number of batches to prefetch.

Returns:
A tf.data.Dataset object.
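
  Example (hypothetical paths, for illustration only):
    ds = criteo_arrayrecord_reader(
        split='train',
        shuffle_rng=jax.random.PRNGKey(0),
        file_path='/data/criteo_preprocessed/train/*.array_record',
        batch_size=65_536,
        num_batches_to_prefetch=tf.data.AUTOTUNE,
    )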
"""
  # Import here to avoid a hard dependency for TSV-only users. (The exact
  # module providing `ArrayRecordDataset` is an assumption; adjust the import
  # to your environment.)
  from array_record.python import array_record_dataset as ar_dataset  # pylint: disable=g-import-not-at-top

if split not in ['train', 'eval_train', 'validation', 'test']:
raise ValueError(f'Invalid split name {split}.')
data_shuffle_seed = None

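  # Only the train split is shuffled; the TF shuffle seed is derived from the
  # JAX PRNG key.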
is_training = split == 'train'
if is_training:
_, data_shuffle_seed = jax.random.split(shuffle_rng, 2)
data_shuffle_seed = data_utils.convert_jax_to_tf_random_seed(
data_shuffle_seed
)

# Discover all matching files.
all_files = sorted(tf.io.gfile.glob(file_path))
if not all_files:
raise ValueError(f'No ArrayRecord files found matching: {file_path}')

# Shard files across hosts.
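  # Host i takes every num_hosts-th file, starting at its process index, so
  # hosts read disjoint, roughly equal shards.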
index = jax.process_index()
num_hosts = jax.process_count()
host_files = all_files[index::num_hosts]

# Interleave per-file datasets, with batch+parse inside each file's
# sub-pipeline. This is critical for performance: interleaving dense float
# tensors (post-parse) is much faster than interleaving raw byte strings
# and batching them later.
file_ds = tf.data.Dataset.from_tensor_slices(host_files)
if is_training:
file_ds = file_ds.repeat()
file_ds = file_ds.shuffle(
buffer_size=2 * len(host_files), seed=data_shuffle_seed
)

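  # Note: shuffling happens at the file level only; examples within a file
  # keep their on-disk order.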
ds = file_ds.interleave(
lambda f: (
ar_dataset.ArrayRecordDataset([f])
.batch(
batch_size,
drop_remainder=is_training,
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False,
)
.map(
_parse_arrayrecord_example_fn,
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False,
)
),
cycle_length=64,
block_length=batch_size // 8,
num_parallel_calls=64,
deterministic=False,
)
ds = ds.prefetch(num_batches_to_prefetch)
return ds


def _convert_to_numpy_iterator_fn(
num_batches, per_host_eval_batch_size, tf_dataset, split_size):
"""Make eval iterator. This function is called at the start of each eval."""
@@ -211,12 +301,143 @@ def get_criteo1tb(shuffle_rng,
per_host_eval_batch_size = eval_batch_size // process_count
per_host_batch_size = batch_size // process_count

use_raw_tsv = hps.get('use_raw_tsv', False)
num_batches_to_prefetch = (
hps.num_tf_data_prefetches
if hps.num_tf_data_prefetches > 0
else tf.data.AUTOTUNE
)
num_device_prefetches = hps.get('num_device_prefetches', 0)

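  # Dispatch on the `use_raw_tsv` hparam: the legacy TSV pipeline or the
  # preprocessed ArrayRecord pipeline.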
if use_raw_tsv:
return _get_criteo1tb_tsv(
shuffle_rng,
per_host_batch_size,
per_host_eval_batch_size,
hps,
num_batches_to_prefetch,
num_device_prefetches,
)
else:
return _get_criteo1tb_arrayrecord(
shuffle_rng,
per_host_batch_size,
per_host_eval_batch_size,
hps,
num_batches_to_prefetch,
num_device_prefetches,
)


def _get_criteo1tb_arrayrecord(
shuffle_rng,
per_host_batch_size,
per_host_eval_batch_size,
hps,
num_batches_to_prefetch,
num_device_prefetches,
):
"""Load Criteo 1TB from preprocessed ArrayRecord files."""
base = hps.get('preprocessed_data_path', PREPROCESSED_CRITEO1TB_FILE_PATH)
train_file_path = os.path.join(base, 'train', '*')
validation_file_path = os.path.join(
base, 'val_set_second_half_of_day23_not_used', '*'
)
test_file_path = os.path.join(base, 'eval', '*')

train_dataset = criteo_arrayrecord_reader(
split='train',
shuffle_rng=shuffle_rng,
file_path=train_file_path,
batch_size=per_host_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch,
)
data_utils.log_rss('train arrayrecord dataset created')

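  # Optionally wrap the numpy iterator in a host-side prefetcher so input
  # processing overlaps with training steps.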
if num_device_prefetches > 0:
train_iterator_fn = lambda: data_utils.prefetch_iterator(
tfds.as_numpy(train_dataset), num_device_prefetches
)
data_utils.log_rss(
        f'prefetching {num_device_prefetches} batches for the train dataset'
)
else:
train_iterator_fn = lambda: tfds.as_numpy(train_dataset)

eval_train_dataset = criteo_arrayrecord_reader(
split='eval_train',
shuffle_rng=None,
file_path=train_file_path,
batch_size=per_host_eval_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch,
)
eval_train_iterator_fn = functools.partial(
_convert_to_numpy_iterator_fn,
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=eval_train_dataset,
split_size=hps.train_size,
)
data_utils.log_rss('eval_train arrayrecord dataset created')

validation_dataset = criteo_arrayrecord_reader(
split='validation',
shuffle_rng=None,
file_path=validation_file_path,
batch_size=per_host_eval_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch,
)
validation_iterator_fn = functools.partial(
_convert_to_numpy_iterator_fn,
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=validation_dataset,
split_size=hps.valid_size,
)
data_utils.log_rss('validation arrayrecord dataset created')

test_dataset = criteo_arrayrecord_reader(
split='test',
shuffle_rng=None,
file_path=test_file_path,
batch_size=per_host_eval_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch,
)
test_iterator_fn = functools.partial(
_convert_to_numpy_iterator_fn,
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=test_dataset,
split_size=hps.test_size,
)
data_utils.log_rss('test arrayrecord dataset created')

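  # Cache the eval_train/validation/test iterators so repeated evals do not
  # re-process the same data files.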
eval_train_iterator_fn = data_utils.CachedEvalIterator(
eval_train_iterator_fn, 'eval_train'
)
validation_iterator_fn = data_utils.CachedEvalIterator(
validation_iterator_fn, 'validation'
)
test_iterator_fn = data_utils.CachedEvalIterator(test_iterator_fn, 'test')

return data_utils.Dataset(
train_iterator_fn,
eval_train_iterator_fn,
validation_iterator_fn,
test_iterator_fn,
)


def _get_criteo1tb_tsv(
shuffle_rng,
per_host_batch_size,
per_host_eval_batch_size,
hps,
num_batches_to_prefetch,
num_device_prefetches,
):
"""Load Criteo 1TB from raw TSV files (legacy path)."""
train_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'train/*/*')
validation_file_path = os.path.join(
RAW_CRITEO1TB_FILE_PATH, 'val_set_second_half_of_day23_not_used/*')
test_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'eval/day_23/*')
num_batches_to_prefetch = (hps.num_tf_data_prefetches
if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE)

train_dataset = criteo_tsv_reader(
split='train',
@@ -225,7 +446,16 @@ def get_criteo1tb(shuffle_rng,
num_dense_features=hps.num_dense_features,
batch_size=per_host_batch_size,
num_batches_to_prefetch=num_batches_to_prefetch)
train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
data_utils.log_rss('train dataset created')
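  # Same optional host-side prefetching as in the ArrayRecord path.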
if num_device_prefetches > 0:
train_iterator_fn = lambda: data_utils.prefetch_iterator(
tfds.as_numpy(train_dataset), num_device_prefetches
)
data_utils.log_rss(
        f'prefetching {num_device_prefetches} batches for the train dataset'
)
else:
train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
eval_train_dataset = criteo_tsv_reader(
split='eval_train',
shuffle_rng=None,
@@ -238,6 +468,7 @@
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=eval_train_dataset,
split_size=hps.train_size)
data_utils.log_rss('eval_train dataset created')
validation_dataset = criteo_tsv_reader(
split='validation',
shuffle_rng=None,
@@ -250,6 +481,7 @@
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=validation_dataset,
split_size=hps.valid_size)
data_utils.log_rss('validation dataset created')
test_dataset = criteo_tsv_reader(
split='test',
shuffle_rng=None,
@@ -262,6 +494,18 @@
per_host_eval_batch_size=per_host_eval_batch_size,
tf_dataset=test_dataset,
split_size=hps.test_size)
data_utils.log_rss('test dataset created')

# Cache all the eval_train/validation/test iterators to avoid re-processing the
# same data files.
eval_train_iterator_fn = data_utils.CachedEvalIterator(
eval_train_iterator_fn, 'eval_train'
)
validation_iterator_fn = data_utils.CachedEvalIterator(
validation_iterator_fn, 'validation'
)
test_iterator_fn = data_utils.CachedEvalIterator(test_iterator_fn, 'test')

return data_utils.Dataset(
train_iterator_fn,
eval_train_iterator_fn,