Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 149 additions & 84 deletions init2winit/dataset_lib/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from init2winit.dataset_lib import imagenet_dataset
from init2winit.dataset_lib import librispeech
from init2winit.dataset_lib import lm1b_v2
# fineweb_edu_10b and fineweb_edu_10b_mdlm must be imported after lm1b_v2;
# importing them earlier triggers a TF v2 eager execution error.
from init2winit.dataset_lib import fineweb_edu_10b # pylint: disable=g-bad-import-order
from init2winit.dataset_lib import fineweb_edu_10b_mdlm # pylint: disable=g-bad-import-order
from init2winit.dataset_lib import mlperf_imagenet_dataset
from init2winit.dataset_lib import nanodo_c4
from init2winit.dataset_lib import nanodo_fineweb_edu
Expand All @@ -39,93 +42,148 @@
from init2winit.dataset_lib import wikitext2

_Dataset = collections.namedtuple(
'Dataset', ('getter', 'hparams', 'meta_data', 'fake_batch_getter'))
'Dataset', ('getter', 'hparams', 'meta_data', 'fake_batch_getter')
)

_ALL_DATASETS = {
'mnist':
_Dataset(small_image_datasets.get_mnist,
small_image_datasets.MNIST_HPARAMS,
small_image_datasets.MNIST_METADATA, None),
'mnist_autoencoder':
_Dataset(small_image_datasets.get_mnist_autoencoder,
small_image_datasets.MNIST_AUTOENCODER_HPARAMS,
small_image_datasets.MNIST_AUTOENCODER_METADATA, None),
'fashion_mnist':
_Dataset(small_image_datasets.get_fashion_mnist,
small_image_datasets.FASHION_MNIST_HPARAMS,
small_image_datasets.FASHION_MNIST_METADATA, None),
'cifar10':
_Dataset(small_image_datasets.get_cifar10,
small_image_datasets.CIFAR10_DEFAULT_HPARAMS,
small_image_datasets.CIFAR10_METADATA, None),
'cifar100':
_Dataset(small_image_datasets.get_cifar100,
small_image_datasets.CIFAR100_DEFAULT_HPARAMS,
small_image_datasets.CIFAR100_METADATA, None),
'criteo1tb':
_Dataset(criteo_terabyte_dataset.get_criteo1tb,
criteo_terabyte_dataset.CRITEO1TB_DEFAULT_HPARAMS,
criteo_terabyte_dataset.CRITEO1TB_METADATA,
criteo_terabyte_dataset.get_fake_batch),
'fake':
_Dataset(fake_dataset.get_fake, fake_dataset.DEFAULT_HPARAMS,
fake_dataset.METADATA, fake_dataset.get_fake_batch),
'fastmri':
_Dataset(fastmri_dataset.get_fastmri, fastmri_dataset.DEFAULT_HPARAMS,
fastmri_dataset.METADATA, fastmri_dataset.get_fake_batch),
'mnist': _Dataset(
small_image_datasets.get_mnist,
small_image_datasets.MNIST_HPARAMS,
small_image_datasets.MNIST_METADATA,
None,
),
'mnist_autoencoder': _Dataset(
small_image_datasets.get_mnist_autoencoder,
small_image_datasets.MNIST_AUTOENCODER_HPARAMS,
small_image_datasets.MNIST_AUTOENCODER_METADATA,
None,
),
'fashion_mnist': _Dataset(
small_image_datasets.get_fashion_mnist,
small_image_datasets.FASHION_MNIST_HPARAMS,
small_image_datasets.FASHION_MNIST_METADATA,
None,
),
'cifar10': _Dataset(
small_image_datasets.get_cifar10,
small_image_datasets.CIFAR10_DEFAULT_HPARAMS,
small_image_datasets.CIFAR10_METADATA,
None,
),
'cifar100': _Dataset(
small_image_datasets.get_cifar100,
small_image_datasets.CIFAR100_DEFAULT_HPARAMS,
small_image_datasets.CIFAR100_METADATA,
None,
),
'criteo1tb': _Dataset(
criteo_terabyte_dataset.get_criteo1tb,
criteo_terabyte_dataset.CRITEO1TB_DEFAULT_HPARAMS,
criteo_terabyte_dataset.CRITEO1TB_METADATA,
criteo_terabyte_dataset.get_fake_batch,
),
'fake': _Dataset(
fake_dataset.get_fake,
fake_dataset.DEFAULT_HPARAMS,
fake_dataset.METADATA,
fake_dataset.get_fake_batch,
),
'fastmri': _Dataset(
fastmri_dataset.get_fastmri,
fastmri_dataset.DEFAULT_HPARAMS,
fastmri_dataset.METADATA,
fastmri_dataset.get_fake_batch,
),
'fineweb_edu_10B': _Dataset(
fineweb_edu_10b.get_fineweb_edu,
fineweb_edu_10b.DEFAULT_HPARAMS,
fineweb_edu_10b.METADATA, None),
'imagenet':
_Dataset(imagenet_dataset.get_imagenet,
imagenet_dataset.DEFAULT_HPARAMS, imagenet_dataset.METADATA,
imagenet_dataset.get_fake_batch),
'translate_wmt':
_Dataset(translate_wmt.get_translate_wmt, translate_wmt.DEFAULT_HPARAMS,
translate_wmt.METADATA, translate_wmt.get_fake_batch),
'librispeech':
_Dataset(librispeech.get_librispeech, librispeech.DEFAULT_HPARAMS,
librispeech.METADATA, librispeech.get_fake_batch),
'lm1b_v2':
_Dataset(lm1b_v2.get_lm1b, lm1b_v2.DEFAULT_HPARAMS, lm1b_v2.METADATA,
None),
'mlperf_imagenet':
_Dataset(mlperf_imagenet_dataset.get_mlperf_imagenet,
mlperf_imagenet_dataset.DEFAULT_HPARAMS,
mlperf_imagenet_dataset.METADATA,
mlperf_imagenet_dataset.get_fake_batch),
'svhn_no_extra':
_Dataset(small_image_datasets.get_svhn_no_extra,
small_image_datasets.SVHN_NO_EXTRA_DEFAULT_HPARAMS,
small_image_datasets.SVHN_NO_EXTRA_METADATA, None),
fineweb_edu_10b.METADATA,
None,
),
'fineweb_edu_10B_mdlm': _Dataset(
fineweb_edu_10b_mdlm.get_fineweb_edu_mdlm,
fineweb_edu_10b_mdlm.DEFAULT_HPARAMS,
fineweb_edu_10b_mdlm.METADATA,
None,
),
'imagenet': _Dataset(
imagenet_dataset.get_imagenet,
imagenet_dataset.DEFAULT_HPARAMS,
imagenet_dataset.METADATA,
imagenet_dataset.get_fake_batch,
),
'translate_wmt': _Dataset(
translate_wmt.get_translate_wmt,
translate_wmt.DEFAULT_HPARAMS,
translate_wmt.METADATA,
translate_wmt.get_fake_batch,
),
'librispeech': _Dataset(
librispeech.get_librispeech,
librispeech.DEFAULT_HPARAMS,
librispeech.METADATA,
librispeech.get_fake_batch,
),
'lm1b_v2': _Dataset(
lm1b_v2.get_lm1b, lm1b_v2.DEFAULT_HPARAMS, lm1b_v2.METADATA, None
),
'mlperf_imagenet': _Dataset(
mlperf_imagenet_dataset.get_mlperf_imagenet,
mlperf_imagenet_dataset.DEFAULT_HPARAMS,
mlperf_imagenet_dataset.METADATA,
mlperf_imagenet_dataset.get_fake_batch,
),
'svhn_no_extra': _Dataset(
small_image_datasets.get_svhn_no_extra,
small_image_datasets.SVHN_NO_EXTRA_DEFAULT_HPARAMS,
small_image_datasets.SVHN_NO_EXTRA_METADATA,
None,
),
'c4': _Dataset(
nanodo_c4.get_dataset,
nanodo_c4.DEFAULT_HPARAMS,
nanodo_c4.METADATA, None),
nanodo_c4.METADATA,
None,
),
'fineweb_edu': _Dataset(
nanodo_fineweb_edu.get_dataset,
nanodo_fineweb_edu.DEFAULT_HPARAMS,
nanodo_fineweb_edu.METADATA, None),
'nqm_noise':
_Dataset(nqm_noise.get_nqm_noise, nqm_noise.NQM_HPARAMS,
nqm_noise.NQM_METADATA, None),
'ogbg_molpcba':
_Dataset(ogbg_molpcba.get_ogbg_molpcba, ogbg_molpcba.DEFAULT_HPARAMS,
ogbg_molpcba.METADATA, ogbg_molpcba.get_fake_batch),
'uniref50':
_Dataset(proteins.get_uniref, proteins.DEFAULT_HPARAMS,
proteins.METADATA, None),
'wikitext2':
_Dataset(wikitext2.get_wikitext2, wikitext2.DEFAULT_HPARAMS,
wikitext2.METADATA, None),
'wikitext103':
_Dataset(wikitext103.get_wikitext103, wikitext103.DEFAULT_HPARAMS,
wikitext2.METADATA, None),
'wikitext103_spm':
_Dataset(wikitext103_spm.get_wikitext103,
wikitext103_spm.DEFAULT_HPARAMS,
wikitext103_spm.METADATA, None),
nanodo_fineweb_edu.METADATA,
None,
),
'nqm_noise': _Dataset(
nqm_noise.get_nqm_noise,
nqm_noise.NQM_HPARAMS,
nqm_noise.NQM_METADATA,
None,
),
'ogbg_molpcba': _Dataset(
ogbg_molpcba.get_ogbg_molpcba,
ogbg_molpcba.DEFAULT_HPARAMS,
ogbg_molpcba.METADATA,
ogbg_molpcba.get_fake_batch,
),
'uniref50': _Dataset(
proteins.get_uniref, proteins.DEFAULT_HPARAMS, proteins.METADATA, None
),
'wikitext2': _Dataset(
wikitext2.get_wikitext2,
wikitext2.DEFAULT_HPARAMS,
wikitext2.METADATA,
None,
),
'wikitext103': _Dataset(
wikitext103.get_wikitext103,
wikitext103.DEFAULT_HPARAMS,
wikitext2.METADATA,
None,
),
'wikitext103_spm': _Dataset(
wikitext103_spm.get_wikitext103,
wikitext103_spm.DEFAULT_HPARAMS,
wikitext103_spm.METADATA,
None,
),
}


Expand All @@ -144,8 +202,10 @@ def get_dataset_hparams(dataset_name):
# TODO(mbadura): Refactor to explicitly support different input specs
if 'input_shape' not in hparams or hparams.input_shape is None:
if 'input_edge_shape' in hparams and 'input_node_shape' in hparams:
hparams.input_shape = (hparams.input_node_shape,
hparams.input_edge_shape)
hparams.input_shape = (
hparams.input_node_shape,
hparams.input_edge_shape,
)
elif dataset_name == 'lm1b_v2':
max_len = max(hparams.max_target_length, hparams.max_eval_target_length)
hparams.input_shape = (max_len,)
Expand All @@ -155,12 +215,16 @@ def get_dataset_hparams(dataset_name):
hparams.input_shape = (max_len,)
hparams.output_shape = (max_len, hparams.vocab_size)
elif dataset_name == 'translate_wmt':
max_len = max(hparams.max_target_length, hparams.max_eval_target_length,
hparams.max_predict_length)
max_len = max(
hparams.max_target_length,
hparams.max_eval_target_length,
hparams.max_predict_length,
)
hparams.input_shape = [(max_len,), (max_len,)]
else:
raise ValueError(
'Undefined input shape for dataset: {}'.format(dataset_name))
'Undefined input shape for dataset: {}'.format(dataset_name)
)
return hparams
except KeyError:
raise ValueError('Unrecognized dataset: {}'.format(dataset_name)) from None
Expand Down Expand Up @@ -209,7 +273,8 @@ def get_fake_batch(dataset_name):
getter = _ALL_DATASETS[dataset_name].fake_batch_getter
if getter is None:
raise ValueError(
f'Fake batch getter not defined for dataset {dataset_name}') from None
f'Fake batch getter not defined for dataset {dataset_name}'
) from None
except KeyError:
raise ValueError('Unrecognized dataset: {}'.format(dataset_name)) from None

Expand All @@ -235,7 +300,7 @@ def get_data_selector(selector_name: Optional[str]):
selector = data_selectors.ALL_SELECTORS[selector_name]
except KeyError:
raise ValueError(
'Unrecognized selector: {}'.format(selector_name)) from None
'Unrecognized selector: {}'.format(selector_name)
) from None

return selector

49 changes: 31 additions & 18 deletions init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def get_fineweb_edu_dataset(
train_batch_size: int,
valid_batch_size: int,
shuffle_seed: int,
shift: bool = True,
) -> tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
"""Returns the FineWeb-Edu dataset.

Expand All @@ -84,6 +85,8 @@ def get_fineweb_edu_dataset(
train_batch_size: Batch size for train iterations
valid_batch_size: Batch size for validation iterations
shuffle_seed: seed for shuffling dataset sequences
shift: If True (default), inputs = x[:-1], targets = x[1:] (AR mode). If
False, inputs = targets = x (MDLM mode).

Returns:
train_dataset, eval_train_dataset, valid_dataset
Expand All @@ -101,37 +104,47 @@ def get_fineweb_edu_dataset(
val_tokens = val_dataset.flat_map(tf.data.Dataset.from_tensor_slices)

# split into sequences
seq_batch_len = hps.sequence_length + 1 if shift else hps.sequence_length
eval_seq_batch_len = (
hps.eval_sequence_length + 1 if shift else hps.eval_sequence_length
)
train_sequences_dataset = train_tokens.batch(
hps.sequence_length + 1, drop_remainder=True
seq_batch_len, drop_remainder=True
)
eval_train_sequences_dataset = train_tokens.batch(
hps.eval_sequence_length + 1, drop_remainder=True
eval_seq_batch_len, drop_remainder=True
)
val_sequences_dataset = val_tokens.batch(
hps.eval_sequence_length + 1, drop_remainder=True
eval_seq_batch_len, drop_remainder=True
)

# Split the sequences into inputs and targets.
if shift:
map_fn = lambda x: {
'inputs': x['input_ids'][: hps.sequence_length],
'targets': x['input_ids'][1:],
}
eval_map_fn = lambda x: {
'inputs': x['input_ids'][: hps.eval_sequence_length],
'targets': x['input_ids'][1:],
}
else:
map_fn = lambda x: {
'inputs': x['input_ids'][: hps.sequence_length],
'targets': x['input_ids'][: hps.sequence_length],
}
eval_map_fn = lambda x: {
'inputs': x['input_ids'][: hps.eval_sequence_length],
'targets': x['input_ids'][: hps.eval_sequence_length],
}
train_sequences_dataset = train_sequences_dataset.map(
lambda x: {
'inputs': x['input_ids'][: hps.sequence_length],
'targets': x['input_ids'][1:],
},
num_parallel_calls=AUTOTUNE,
map_fn, num_parallel_calls=AUTOTUNE
)
eval_train_sequences_dataset = eval_train_sequences_dataset.map(
lambda x: {
'inputs': x['input_ids'][: hps.eval_sequence_length],
'targets': x['input_ids'][1:],
},
num_parallel_calls=AUTOTUNE,
eval_map_fn, num_parallel_calls=AUTOTUNE
)
val_sequences_dataset = val_sequences_dataset.map(
lambda x: {
'inputs': x['input_ids'][: hps.eval_sequence_length],
'targets': x['input_ids'][1:],
},
num_parallel_calls=AUTOTUNE,
eval_map_fn, num_parallel_calls=AUTOTUNE
)

# Shuffle the train sequences.
Expand Down
Loading