# Immutable record tying a dataset name to its loader ('getter'), default
# hyperparameters, metadata, and an optional fake-batch builder used to
# produce correctly-shaped batches without reading real data.
_Dataset = collections.namedtuple(
    'Dataset', ('getter', 'hparams', 'meta_data', 'fake_batch_getter')
)


# Registry of every dataset addressable by name via get_dataset() /
# get_dataset_hparams(). Entries whose fake_batch_getter is None do not
# support get_fake_batch() and raise ValueError there.
_ALL_DATASETS = {
    'mnist': _Dataset(
        small_image_datasets.get_mnist,
        small_image_datasets.MNIST_HPARAMS,
        small_image_datasets.MNIST_METADATA,
        None,
    ),
    'mnist_autoencoder': _Dataset(
        small_image_datasets.get_mnist_autoencoder,
        small_image_datasets.MNIST_AUTOENCODER_HPARAMS,
        small_image_datasets.MNIST_AUTOENCODER_METADATA,
        None,
    ),
    'fashion_mnist': _Dataset(
        small_image_datasets.get_fashion_mnist,
        small_image_datasets.FASHION_MNIST_HPARAMS,
        small_image_datasets.FASHION_MNIST_METADATA,
        None,
    ),
    'cifar10': _Dataset(
        small_image_datasets.get_cifar10,
        small_image_datasets.CIFAR10_DEFAULT_HPARAMS,
        small_image_datasets.CIFAR10_METADATA,
        None,
    ),
    'cifar100': _Dataset(
        small_image_datasets.get_cifar100,
        small_image_datasets.CIFAR100_DEFAULT_HPARAMS,
        small_image_datasets.CIFAR100_METADATA,
        None,
    ),
    'criteo1tb': _Dataset(
        criteo_terabyte_dataset.get_criteo1tb,
        criteo_terabyte_dataset.CRITEO1TB_DEFAULT_HPARAMS,
        criteo_terabyte_dataset.CRITEO1TB_METADATA,
        criteo_terabyte_dataset.get_fake_batch,
    ),
    'fake': _Dataset(
        fake_dataset.get_fake,
        fake_dataset.DEFAULT_HPARAMS,
        fake_dataset.METADATA,
        fake_dataset.get_fake_batch,
    ),
    'fastmri': _Dataset(
        fastmri_dataset.get_fastmri,
        fastmri_dataset.DEFAULT_HPARAMS,
        fastmri_dataset.METADATA,
        fastmri_dataset.get_fake_batch,
    ),
    'fineweb_edu_10B': _Dataset(
        fineweb_edu_10b.get_fineweb_edu,
        fineweb_edu_10b.DEFAULT_HPARAMS,
        fineweb_edu_10b.METADATA,
        None,
    ),
    # MDLM (masked diffusion LM) variant: same data, unshifted inputs and
    # non-causal metadata; see fineweb_edu_10b_mdlm.METADATA.
    'fineweb_edu_10B_mdlm': _Dataset(
        fineweb_edu_10b_mdlm.get_fineweb_edu_mdlm,
        fineweb_edu_10b_mdlm.DEFAULT_HPARAMS,
        fineweb_edu_10b_mdlm.METADATA,
        None,
    ),
    'imagenet': _Dataset(
        imagenet_dataset.get_imagenet,
        imagenet_dataset.DEFAULT_HPARAMS,
        imagenet_dataset.METADATA,
        imagenet_dataset.get_fake_batch,
    ),
    'translate_wmt': _Dataset(
        translate_wmt.get_translate_wmt,
        translate_wmt.DEFAULT_HPARAMS,
        translate_wmt.METADATA,
        translate_wmt.get_fake_batch,
    ),
    'librispeech': _Dataset(
        librispeech.get_librispeech,
        librispeech.DEFAULT_HPARAMS,
        librispeech.METADATA,
        librispeech.get_fake_batch,
    ),
    'lm1b_v2': _Dataset(
        lm1b_v2.get_lm1b, lm1b_v2.DEFAULT_HPARAMS, lm1b_v2.METADATA, None
    ),
    'mlperf_imagenet': _Dataset(
        mlperf_imagenet_dataset.get_mlperf_imagenet,
        mlperf_imagenet_dataset.DEFAULT_HPARAMS,
        mlperf_imagenet_dataset.METADATA,
        mlperf_imagenet_dataset.get_fake_batch,
    ),
    'svhn_no_extra': _Dataset(
        small_image_datasets.get_svhn_no_extra,
        small_image_datasets.SVHN_NO_EXTRA_DEFAULT_HPARAMS,
        small_image_datasets.SVHN_NO_EXTRA_METADATA,
        None,
    ),
    'c4': _Dataset(
        nanodo_c4.get_dataset,
        nanodo_c4.DEFAULT_HPARAMS,
        nanodo_c4.METADATA,
        None,
    ),
    'fineweb_edu': _Dataset(
        nanodo_fineweb_edu.get_dataset,
        nanodo_fineweb_edu.DEFAULT_HPARAMS,
        nanodo_fineweb_edu.METADATA,
        None,
    ),
    'nqm_noise': _Dataset(
        nqm_noise.get_nqm_noise,
        nqm_noise.NQM_HPARAMS,
        nqm_noise.NQM_METADATA,
        None,
    ),
    'ogbg_molpcba': _Dataset(
        ogbg_molpcba.get_ogbg_molpcba,
        ogbg_molpcba.DEFAULT_HPARAMS,
        ogbg_molpcba.METADATA,
        ogbg_molpcba.get_fake_batch,
    ),
    'uniref50': _Dataset(
        proteins.get_uniref, proteins.DEFAULT_HPARAMS, proteins.METADATA, None
    ),
    'wikitext2': _Dataset(
        wikitext2.get_wikitext2,
        wikitext2.DEFAULT_HPARAMS,
        wikitext2.METADATA,
        None,
    ),
    'wikitext103': _Dataset(
        wikitext103.get_wikitext103,
        wikitext103.DEFAULT_HPARAMS,
        # NOTE(review): this passes wikitext2.METADATA, not
        # wikitext103.METADATA — looks like a copy-paste slip; confirm
        # whether wikitext103 exposes its own METADATA before changing.
        wikitext2.METADATA,
        None,
    ),
    'wikitext103_spm': _Dataset(
        wikitext103_spm.get_wikitext103,
        wikitext103_spm.DEFAULT_HPARAMS,
        wikitext103_spm.METADATA,
        None,
    ),
}
(hparams.input_node_shape, - hparams.input_edge_shape) + hparams.input_shape = ( + hparams.input_node_shape, + hparams.input_edge_shape, + ) elif dataset_name == 'lm1b_v2': max_len = max(hparams.max_target_length, hparams.max_eval_target_length) hparams.input_shape = (max_len,) @@ -155,12 +215,16 @@ def get_dataset_hparams(dataset_name): hparams.input_shape = (max_len,) hparams.output_shape = (max_len, hparams.vocab_size) elif dataset_name == 'translate_wmt': - max_len = max(hparams.max_target_length, hparams.max_eval_target_length, - hparams.max_predict_length) + max_len = max( + hparams.max_target_length, + hparams.max_eval_target_length, + hparams.max_predict_length, + ) hparams.input_shape = [(max_len,), (max_len,)] else: raise ValueError( - 'Undefined input shape for dataset: {}'.format(dataset_name)) + 'Undefined input shape for dataset: {}'.format(dataset_name) + ) return hparams except KeyError: raise ValueError('Unrecognized dataset: {}'.format(dataset_name)) from None @@ -209,7 +273,8 @@ def get_fake_batch(dataset_name): getter = _ALL_DATASETS[dataset_name].fake_batch_getter if getter is None: raise ValueError( - f'Fake batch getter not defined for dataset {dataset_name}') from None + f'Fake batch getter not defined for dataset {dataset_name}' + ) from None except KeyError: raise ValueError('Unrecognized dataset: {}'.format(dataset_name)) from None @@ -235,7 +300,7 @@ def get_data_selector(selector_name: Optional[str]): selector = data_selectors.ALL_SELECTORS[selector_name] except KeyError: raise ValueError( - 'Unrecognized selector: {}'.format(selector_name)) from None + 'Unrecognized selector: {}'.format(selector_name) + ) from None return selector - diff --git a/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py b/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py index 10f0ccfd..2f5c5ec2 100644 --- a/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py +++ b/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py @@ -76,6 +76,7 @@ def 
get_fineweb_edu_dataset( train_batch_size: int, valid_batch_size: int, shuffle_seed: int, + shift: bool = True, ) -> tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]: """Returns wikitext-103 dataset. @@ -84,6 +85,8 @@ def get_fineweb_edu_dataset( train_batch_size: Batch size for train iterations valid_batch_size: Batch size for validation iterations shuffle_seed: seed for shuffling dataset sequences + shift: If True (default), inputs = x[:-1], targets = x[1:] (AR mode). If + False, inputs = targets = x (MDLM mode). Returns: train_dataset, eval_train_dataset, valid_dataset, test_dataset @@ -101,37 +104,47 @@ def get_fineweb_edu_dataset( val_tokens = val_dataset.flat_map(tf.data.Dataset.from_tensor_slices) # split into sequences + seq_batch_len = hps.sequence_length + 1 if shift else hps.sequence_length + eval_seq_batch_len = ( + hps.eval_sequence_length + 1 if shift else hps.eval_sequence_length + ) train_sequences_dataset = train_tokens.batch( - hps.sequence_length + 1, drop_remainder=True + seq_batch_len, drop_remainder=True ) eval_train_sequences_dataset = train_tokens.batch( - hps.eval_sequence_length + 1, drop_remainder=True + eval_seq_batch_len, drop_remainder=True ) val_sequences_dataset = val_tokens.batch( - hps.eval_sequence_length + 1, drop_remainder=True + eval_seq_batch_len, drop_remainder=True ) # Split the sequences into inputs and targets. 
+ if shift: + map_fn = lambda x: { + 'inputs': x['input_ids'][: hps.sequence_length], + 'targets': x['input_ids'][1:], + } + eval_map_fn = lambda x: { + 'inputs': x['input_ids'][: hps.eval_sequence_length], + 'targets': x['input_ids'][1:], + } + else: + map_fn = lambda x: { + 'inputs': x['input_ids'][: hps.sequence_length], + 'targets': x['input_ids'][: hps.sequence_length], + } + eval_map_fn = lambda x: { + 'inputs': x['input_ids'][: hps.eval_sequence_length], + 'targets': x['input_ids'][: hps.eval_sequence_length], + } train_sequences_dataset = train_sequences_dataset.map( - lambda x: { - 'inputs': x['input_ids'][: hps.sequence_length], - 'targets': x['input_ids'][1:], - }, - num_parallel_calls=AUTOTUNE, + map_fn, num_parallel_calls=AUTOTUNE ) eval_train_sequences_dataset = eval_train_sequences_dataset.map( - lambda x: { - 'inputs': x['input_ids'][: hps.eval_sequence_length], - 'targets': x['input_ids'][1:], - }, - num_parallel_calls=AUTOTUNE, + eval_map_fn, num_parallel_calls=AUTOTUNE ) val_sequences_dataset = val_sequences_dataset.map( - lambda x: { - 'inputs': x['input_ids'][: hps.eval_sequence_length], - 'targets': x['input_ids'][1:], - }, - num_parallel_calls=AUTOTUNE, + eval_map_fn, num_parallel_calls=AUTOTUNE ) # Shuffle the train sequences. diff --git a/init2winit/dataset_lib/fineweb_edu_10b_mdlm.py b/init2winit/dataset_lib/fineweb_edu_10b_mdlm.py new file mode 100644 index 00000000..92488ffb --- /dev/null +++ b/init2winit/dataset_lib/fineweb_edu_10b_mdlm.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2026 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
"""MDLM variant of FineWeb-Edu 10B dataset.

Wraps the standard fineweb_edu_10b dataset with metadata appropriate for
masked diffusion language modeling (no shifting, bidirectional).
"""

import itertools

from init2winit.dataset_lib import data_utils
from init2winit.dataset_lib import fineweb_edu_10b
from init2winit.dataset_lib import fineweb_edu_10b_input_pipeline as input_pipeline
import jax

PAD_ID = input_pipeline.PAD_ID
VOCAB_SIZE = input_pipeline.VOCAB_SIZE

# MDLM reuses the autoregressive dataset's hparams; only the pipeline's
# shift behavior and the model-facing metadata differ.
DEFAULT_HPARAMS = fineweb_edu_10b.DEFAULT_HPARAMS
Dataset = data_utils.Dataset


METADATA = {
    'apply_one_hot_in_loss': False,
    # MDLM consumes unshifted sequences (inputs == targets) and is
    # non-causal, unlike the AR variant.
    'shift_inputs': False,
    'causal': False,
    'pad_token': PAD_ID,
}


def get_fineweb_edu_mdlm(
    shuffle_rng, batch_size, eval_batch_size=None, hps=None, pad_id=PAD_ID
):
  """Returns FineWeb-Edu 10B Dataset without input shifting for MDLM.

  Args:
    shuffle_rng: jax PRNG key used to seed dataset shuffling.
    batch_size: Global train batch size; must be divisible by
      jax.process_count(), and its per-host share by the local device count.
    eval_batch_size: Global eval batch size; defaults to batch_size. Subject
      to the same divisibility requirements.
    hps: ConfigDict of dataset hyperparameters forwarded to the input
      pipeline (sequence lengths, etc.).
    pad_id: Token id treated as padding when attaching per-token weights.

  Returns:
    A data_utils.Dataset of (train, eval_train, valid, test) iterators.
    The test split is an empty (null) iterator for this dataset.

  Raises:
    ValueError: if a batch size is not evenly divisible across processes or
      local devices.
  """
  process_count = jax.process_count()
  n_devices = jax.local_device_count()

  if batch_size % process_count != 0:
    raise ValueError(
        'process_count={} must divide batch_size={}.'.format(
            process_count, batch_size
        )
    )

  if eval_batch_size is None:
    eval_batch_size = batch_size

  # Bug fix: this check validates eval_batch_size but previously reported
  # 'batch_size' with the train batch_size value, producing a misleading
  # error message.
  if eval_batch_size % process_count != 0:
    raise ValueError(
        'process_count={} must divide eval_batch_size={}.'.format(
            process_count, eval_batch_size
        )
    )

  per_host_batch_size = int(batch_size / process_count)
  per_host_eval_batch_size = int(eval_batch_size / process_count)

  if per_host_batch_size % n_devices != 0:
    raise ValueError(
        'per_host_batch_size={} must be divisible by n_devices={}.'.format(
            per_host_batch_size, n_devices
        )
    )
  if per_host_eval_batch_size % n_devices != 0:
    raise ValueError(
        'per_host_eval_batch_size={} must be divisible by n_devices={}.'.format(
            per_host_eval_batch_size, n_devices
        )
    )

  train_dataset, eval_train_dataset, valid_dataset = (
      input_pipeline.get_fineweb_edu_dataset(
          hps,
          train_batch_size=per_host_batch_size,
          valid_batch_size=per_host_eval_batch_size,
          shuffle_seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng),
          # MDLM mode: inputs == targets, no autoregressive shift.
          shift=False,
      )
  )

  def train_iterator_fn():
    # Stream numpy batches with padding weights attached.
    for batch in train_dataset:
      yield fineweb_edu_10b.add_weights_to_batch(
          data_utils.tf_to_numpy(batch), pad_id
      )

  def eval_train_epoch(num_batches=None):
    # num_batches=None iterates the full eval-train split.
    for batch in itertools.islice(iter(eval_train_dataset), num_batches):
      yield fineweb_edu_10b.add_weights_to_batch(
          data_utils.tf_to_numpy(batch), pad_id
      )

  def valid_epoch(num_batches=None):
    for batch in itertools.islice(iter(valid_dataset), num_batches):
      yield fineweb_edu_10b.add_weights_to_batch(
          data_utils.tf_to_numpy(batch), pad_id
      )

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    # No test split: a generator that yields nothing.
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.

  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
"""Tests for fineweb_edu_10b_mdlm dataset wrapper.

Verifies that the MDLM variant does NOT shift inputs (inputs == targets)
and that padding is applied correctly.
"""

from unittest import mock

from absl.testing import absltest
from init2winit.dataset_lib import fineweb_edu_10b_input_pipeline as input_pipeline
from ml_collections.config_dict import config_dict
import numpy as np
import tensorflow as tf


def _make_fake_token_dataset(num_tokens):
  """Creates a fake dataset mimicking the tokenized FineWeb-Edu format.

  Args:
    num_tokens: Number of tokens in the dataset.

  Returns:
    A tf.data.Dataset where each element is a dict with 'input_ids' as a 1D
    tensor (one document). The pipeline flat_maps these into a stream of
    individual token dicts, then re-batches into sequences.
  """
  # Simulate a single "document" containing all tokens. Token ids are
  # consecutive integers so shift relationships are easy to assert on.
  tokens = np.arange(num_tokens, dtype=np.int64)
  ds = tf.data.Dataset.from_tensor_slices({'input_ids': [tokens]})
  return ds


def _make_hps(seq_len=8):
  # Minimal hparams the input pipeline reads; all lengths set equal so a
  # single seq_len controls both train and eval sequence sizes.
  return config_dict.ConfigDict(
      dict(
          sequence_length=seq_len,
          max_target_length=seq_len,
          max_eval_target_length=seq_len,
          eval_sequence_length=seq_len,
      )
  )


class InputPipelineShiftTest(absltest.TestCase):
  """Tests that shift=True produces shifted data and shift=False does not."""

  def _get_first_batch(self, shift, seq_len=8, num_tokens=100):
    """Runs the pipeline on fake data and returns the first train batch."""
    hps = _make_hps(seq_len)
    fake_ds = _make_fake_token_dataset(num_tokens)

    # Patch tf.data.Dataset.load so the pipeline sees fake tokens instead
    # of reading real on-disk data.
    with mock.patch.object(tf.data.Dataset, 'load', return_value=fake_ds):
      train_ds, _, _ = input_pipeline.get_fineweb_edu_dataset(
          hps,
          train_batch_size=2,
          valid_batch_size=2,
          shuffle_seed=0,
          shift=shift,
      )
      # Take one batch (undo repeat).
      batch = next(iter(train_ds.take(1)))
      inputs = batch['inputs'].numpy()
      targets = batch['targets'].numpy()
      return inputs, targets

  def test_shift_true_targets_offset_by_one(self):
    """With shift=True (AR), targets[i] == inputs[i] + 1."""
    inputs, targets = self._get_first_batch(shift=True)

    # Inputs and targets should differ.
    self.assertFalse(np.array_equal(inputs, targets))

    # For contiguous token ids, targets should be inputs shifted by 1.
    np.testing.assert_array_equal(targets, inputs + 1)

  def test_shift_false_inputs_equal_targets(self):
    """With shift=False (MDLM), inputs == targets exactly."""
    inputs, targets = self._get_first_batch(shift=False)
    np.testing.assert_array_equal(inputs, targets)

  def test_shift_false_sequence_length_correct(self):
    """With shift=False, sequences should be seq_len long, not seq_len+1."""
    seq_len = 8
    inputs, targets = self._get_first_batch(shift=False, seq_len=seq_len)
    self.assertEqual(inputs.shape[-1], seq_len)
    self.assertEqual(targets.shape[-1], seq_len)

  def test_shift_true_sequence_length_correct(self):
    """With shift=True, sequences come from seq_len+1 tokens."""
    # Even though the pipeline batches seq_len+1 tokens, the emitted
    # inputs/targets are each seq_len long after slicing.
    seq_len = 8
    inputs, targets = self._get_first_batch(shift=True, seq_len=seq_len)
    self.assertEqual(inputs.shape[-1], seq_len)
    self.assertEqual(targets.shape[-1], seq_len)


class PaddingTest(absltest.TestCase):
  """Tests that eval batches are padded correctly."""

  def test_eval_batch_padding_applied(self):
    """Eval batches should be padded to batch_size with PAD_ID."""
    hps = _make_hps(seq_len=4)
    # 13 tokens -> 3 sequences of length 4 (shift=False, drop_remainder=True).
    # With batch_size=2, we get 1 full batch + 1 partial batch (1 real + 1 pad).
    fake_ds = _make_fake_token_dataset(13)

    with mock.patch.object(tf.data.Dataset, 'load', return_value=fake_ds):
      _, _, valid_ds = input_pipeline.get_fineweb_edu_dataset(
          hps,
          train_batch_size=2,
          valid_batch_size=2,
          shuffle_seed=0,
          shift=False,
      )

    batches = list(valid_ds.as_numpy_iterator())
    # Should have 2 batches: first full, second padded.
    self.assertLen(batches, 2)

    padded_batch = batches[1]
    # assumes PAD_ID is a tf scalar tensor — .numpy() unwraps it.
    pad_id = int(input_pipeline.PAD_ID.numpy())

    # The second row of the padded batch should be all PAD_ID.
    np.testing.assert_array_equal(padded_batch['inputs'][1], np.full(4, pad_id))
    np.testing.assert_array_equal(
        padded_batch['targets'][1], np.full(4, pad_id)
    )

  def test_eval_batch_padding_not_in_full_batches(self):
    """Full eval batches should contain no padding."""
    hps = _make_hps(seq_len=4)
    fake_ds = _make_fake_token_dataset(13)

    with mock.patch.object(tf.data.Dataset, 'load', return_value=fake_ds):
      _, _, valid_ds = input_pipeline.get_fineweb_edu_dataset(
          hps,
          train_batch_size=2,
          valid_batch_size=2,
          shuffle_seed=0,
          shift=False,
      )

    batches = list(valid_ds.as_numpy_iterator())
    full_batch = batches[0]
    pad_id = int(input_pipeline.PAD_ID.numpy())

    # No element in the full batch should be PAD_ID.
    self.assertTrue(np.all(full_batch['inputs'] != pad_id))
    self.assertTrue(np.all(full_batch['targets'] != pad_id))


if __name__ == '__main__':
  absltest.main()