diff --git a/init2winit/dataset_lib/datasets.py b/init2winit/dataset_lib/datasets.py index c333c7d0..e9a22ec3 100644 --- a/init2winit/dataset_lib/datasets.py +++ b/init2winit/dataset_lib/datasets.py @@ -25,7 +25,10 @@ from init2winit.dataset_lib import imagenet_dataset from init2winit.dataset_lib import librispeech from init2winit.dataset_lib import lm1b_v2 +# We get TF v2 eager execution error if we import fineweb_edu_10b +# and fineweb_edu_10b_mdlm before lm1b_v2 from init2winit.dataset_lib import fineweb_edu_10b # pylint: disable=g-bad-import-order +from init2winit.dataset_lib import fineweb_edu_10b_mdlm # pylint: disable=g-bad-import-order from init2winit.dataset_lib import mlperf_imagenet_dataset from init2winit.dataset_lib import nanodo_c4 from init2winit.dataset_lib import nanodo_fineweb_edu @@ -39,93 +42,148 @@ from init2winit.dataset_lib import wikitext2 _Dataset = collections.namedtuple( - 'Dataset', ('getter', 'hparams', 'meta_data', 'fake_batch_getter')) + 'Dataset', ('getter', 'hparams', 'meta_data', 'fake_batch_getter') +) _ALL_DATASETS = { - 'mnist': - _Dataset(small_image_datasets.get_mnist, - small_image_datasets.MNIST_HPARAMS, - small_image_datasets.MNIST_METADATA, None), - 'mnist_autoencoder': - _Dataset(small_image_datasets.get_mnist_autoencoder, - small_image_datasets.MNIST_AUTOENCODER_HPARAMS, - small_image_datasets.MNIST_AUTOENCODER_METADATA, None), - 'fashion_mnist': - _Dataset(small_image_datasets.get_fashion_mnist, - small_image_datasets.FASHION_MNIST_HPARAMS, - small_image_datasets.FASHION_MNIST_METADATA, None), - 'cifar10': - _Dataset(small_image_datasets.get_cifar10, - small_image_datasets.CIFAR10_DEFAULT_HPARAMS, - small_image_datasets.CIFAR10_METADATA, None), - 'cifar100': - _Dataset(small_image_datasets.get_cifar100, - small_image_datasets.CIFAR100_DEFAULT_HPARAMS, - small_image_datasets.CIFAR100_METADATA, None), - 'criteo1tb': - _Dataset(criteo_terabyte_dataset.get_criteo1tb, - criteo_terabyte_dataset.CRITEO1TB_DEFAULT_HPARAMS, - 
criteo_terabyte_dataset.CRITEO1TB_METADATA, - criteo_terabyte_dataset.get_fake_batch), - 'fake': - _Dataset(fake_dataset.get_fake, fake_dataset.DEFAULT_HPARAMS, - fake_dataset.METADATA, fake_dataset.get_fake_batch), - 'fastmri': - _Dataset(fastmri_dataset.get_fastmri, fastmri_dataset.DEFAULT_HPARAMS, - fastmri_dataset.METADATA, fastmri_dataset.get_fake_batch), + 'mnist': _Dataset( + small_image_datasets.get_mnist, + small_image_datasets.MNIST_HPARAMS, + small_image_datasets.MNIST_METADATA, + None, + ), + 'mnist_autoencoder': _Dataset( + small_image_datasets.get_mnist_autoencoder, + small_image_datasets.MNIST_AUTOENCODER_HPARAMS, + small_image_datasets.MNIST_AUTOENCODER_METADATA, + None, + ), + 'fashion_mnist': _Dataset( + small_image_datasets.get_fashion_mnist, + small_image_datasets.FASHION_MNIST_HPARAMS, + small_image_datasets.FASHION_MNIST_METADATA, + None, + ), + 'cifar10': _Dataset( + small_image_datasets.get_cifar10, + small_image_datasets.CIFAR10_DEFAULT_HPARAMS, + small_image_datasets.CIFAR10_METADATA, + None, + ), + 'cifar100': _Dataset( + small_image_datasets.get_cifar100, + small_image_datasets.CIFAR100_DEFAULT_HPARAMS, + small_image_datasets.CIFAR100_METADATA, + None, + ), + 'criteo1tb': _Dataset( + criteo_terabyte_dataset.get_criteo1tb, + criteo_terabyte_dataset.CRITEO1TB_DEFAULT_HPARAMS, + criteo_terabyte_dataset.CRITEO1TB_METADATA, + criteo_terabyte_dataset.get_fake_batch, + ), + 'fake': _Dataset( + fake_dataset.get_fake, + fake_dataset.DEFAULT_HPARAMS, + fake_dataset.METADATA, + fake_dataset.get_fake_batch, + ), + 'fastmri': _Dataset( + fastmri_dataset.get_fastmri, + fastmri_dataset.DEFAULT_HPARAMS, + fastmri_dataset.METADATA, + fastmri_dataset.get_fake_batch, + ), 'fineweb_edu_10B': _Dataset( fineweb_edu_10b.get_fineweb_edu, fineweb_edu_10b.DEFAULT_HPARAMS, - fineweb_edu_10b.METADATA, None), - 'imagenet': - _Dataset(imagenet_dataset.get_imagenet, - imagenet_dataset.DEFAULT_HPARAMS, imagenet_dataset.METADATA, - imagenet_dataset.get_fake_batch), - 
'translate_wmt': - _Dataset(translate_wmt.get_translate_wmt, translate_wmt.DEFAULT_HPARAMS, - translate_wmt.METADATA, translate_wmt.get_fake_batch), - 'librispeech': - _Dataset(librispeech.get_librispeech, librispeech.DEFAULT_HPARAMS, - librispeech.METADATA, librispeech.get_fake_batch), - 'lm1b_v2': - _Dataset(lm1b_v2.get_lm1b, lm1b_v2.DEFAULT_HPARAMS, lm1b_v2.METADATA, - None), - 'mlperf_imagenet': - _Dataset(mlperf_imagenet_dataset.get_mlperf_imagenet, - mlperf_imagenet_dataset.DEFAULT_HPARAMS, - mlperf_imagenet_dataset.METADATA, - mlperf_imagenet_dataset.get_fake_batch), - 'svhn_no_extra': - _Dataset(small_image_datasets.get_svhn_no_extra, - small_image_datasets.SVHN_NO_EXTRA_DEFAULT_HPARAMS, - small_image_datasets.SVHN_NO_EXTRA_METADATA, None), + fineweb_edu_10b.METADATA, + None, + ), + 'fineweb_edu_10B_mdlm': _Dataset( + fineweb_edu_10b_mdlm.get_fineweb_edu_mdlm, + fineweb_edu_10b_mdlm.DEFAULT_HPARAMS, + fineweb_edu_10b_mdlm.METADATA, + None, + ), + 'imagenet': _Dataset( + imagenet_dataset.get_imagenet, + imagenet_dataset.DEFAULT_HPARAMS, + imagenet_dataset.METADATA, + imagenet_dataset.get_fake_batch, + ), + 'translate_wmt': _Dataset( + translate_wmt.get_translate_wmt, + translate_wmt.DEFAULT_HPARAMS, + translate_wmt.METADATA, + translate_wmt.get_fake_batch, + ), + 'librispeech': _Dataset( + librispeech.get_librispeech, + librispeech.DEFAULT_HPARAMS, + librispeech.METADATA, + librispeech.get_fake_batch, + ), + 'lm1b_v2': _Dataset( + lm1b_v2.get_lm1b, lm1b_v2.DEFAULT_HPARAMS, lm1b_v2.METADATA, None + ), + 'mlperf_imagenet': _Dataset( + mlperf_imagenet_dataset.get_mlperf_imagenet, + mlperf_imagenet_dataset.DEFAULT_HPARAMS, + mlperf_imagenet_dataset.METADATA, + mlperf_imagenet_dataset.get_fake_batch, + ), + 'svhn_no_extra': _Dataset( + small_image_datasets.get_svhn_no_extra, + small_image_datasets.SVHN_NO_EXTRA_DEFAULT_HPARAMS, + small_image_datasets.SVHN_NO_EXTRA_METADATA, + None, + ), 'c4': _Dataset( nanodo_c4.get_dataset, nanodo_c4.DEFAULT_HPARAMS, - 
nanodo_c4.METADATA, None), + nanodo_c4.METADATA, + None, + ), 'fineweb_edu': _Dataset( nanodo_fineweb_edu.get_dataset, nanodo_fineweb_edu.DEFAULT_HPARAMS, - nanodo_fineweb_edu.METADATA, None), - 'nqm_noise': - _Dataset(nqm_noise.get_nqm_noise, nqm_noise.NQM_HPARAMS, - nqm_noise.NQM_METADATA, None), - 'ogbg_molpcba': - _Dataset(ogbg_molpcba.get_ogbg_molpcba, ogbg_molpcba.DEFAULT_HPARAMS, - ogbg_molpcba.METADATA, ogbg_molpcba.get_fake_batch), - 'uniref50': - _Dataset(proteins.get_uniref, proteins.DEFAULT_HPARAMS, - proteins.METADATA, None), - 'wikitext2': - _Dataset(wikitext2.get_wikitext2, wikitext2.DEFAULT_HPARAMS, - wikitext2.METADATA, None), - 'wikitext103': - _Dataset(wikitext103.get_wikitext103, wikitext103.DEFAULT_HPARAMS, - wikitext2.METADATA, None), - 'wikitext103_spm': - _Dataset(wikitext103_spm.get_wikitext103, - wikitext103_spm.DEFAULT_HPARAMS, - wikitext103_spm.METADATA, None), + nanodo_fineweb_edu.METADATA, + None, + ), + 'nqm_noise': _Dataset( + nqm_noise.get_nqm_noise, + nqm_noise.NQM_HPARAMS, + nqm_noise.NQM_METADATA, + None, + ), + 'ogbg_molpcba': _Dataset( + ogbg_molpcba.get_ogbg_molpcba, + ogbg_molpcba.DEFAULT_HPARAMS, + ogbg_molpcba.METADATA, + ogbg_molpcba.get_fake_batch, + ), + 'uniref50': _Dataset( + proteins.get_uniref, proteins.DEFAULT_HPARAMS, proteins.METADATA, None + ), + 'wikitext2': _Dataset( + wikitext2.get_wikitext2, + wikitext2.DEFAULT_HPARAMS, + wikitext2.METADATA, + None, + ), + 'wikitext103': _Dataset( + wikitext103.get_wikitext103, + wikitext103.DEFAULT_HPARAMS, + wikitext2.METADATA, + None, + ), + 'wikitext103_spm': _Dataset( + wikitext103_spm.get_wikitext103, + wikitext103_spm.DEFAULT_HPARAMS, + wikitext103_spm.METADATA, + None, + ), } @@ -144,8 +202,10 @@ def get_dataset_hparams(dataset_name): # TODO(mbadura): Refactor to explicitly support different input specs if 'input_shape' not in hparams or hparams.input_shape is None: if 'input_edge_shape' in hparams and 'input_node_shape' in hparams: - hparams.input_shape = 
(hparams.input_node_shape, - hparams.input_edge_shape) + hparams.input_shape = ( + hparams.input_node_shape, + hparams.input_edge_shape, + ) elif dataset_name == 'lm1b_v2': max_len = max(hparams.max_target_length, hparams.max_eval_target_length) hparams.input_shape = (max_len,) @@ -155,12 +215,16 @@ def get_dataset_hparams(dataset_name): hparams.input_shape = (max_len,) hparams.output_shape = (max_len, hparams.vocab_size) elif dataset_name == 'translate_wmt': - max_len = max(hparams.max_target_length, hparams.max_eval_target_length, - hparams.max_predict_length) + max_len = max( + hparams.max_target_length, + hparams.max_eval_target_length, + hparams.max_predict_length, + ) hparams.input_shape = [(max_len,), (max_len,)] else: raise ValueError( - 'Undefined input shape for dataset: {}'.format(dataset_name)) + 'Undefined input shape for dataset: {}'.format(dataset_name) + ) return hparams except KeyError: raise ValueError('Unrecognized dataset: {}'.format(dataset_name)) from None @@ -209,7 +273,8 @@ def get_fake_batch(dataset_name): getter = _ALL_DATASETS[dataset_name].fake_batch_getter if getter is None: raise ValueError( - f'Fake batch getter not defined for dataset {dataset_name}') from None + f'Fake batch getter not defined for dataset {dataset_name}' + ) from None except KeyError: raise ValueError('Unrecognized dataset: {}'.format(dataset_name)) from None @@ -235,7 +300,7 @@ def get_data_selector(selector_name: Optional[str]): selector = data_selectors.ALL_SELECTORS[selector_name] except KeyError: raise ValueError( - 'Unrecognized selector: {}'.format(selector_name)) from None + 'Unrecognized selector: {}'.format(selector_name) + ) from None return selector - diff --git a/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py b/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py index 10f0ccfd..2f5c5ec2 100644 --- a/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py +++ b/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py @@ -76,6 +76,7 @@ def 
get_fineweb_edu_dataset( train_batch_size: int, valid_batch_size: int, shuffle_seed: int, + shift: bool = True, ) -> tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]: """Returns wikitext-103 dataset. @@ -84,6 +85,8 @@ def get_fineweb_edu_dataset( train_batch_size: Batch size for train iterations valid_batch_size: Batch size for validation iterations shuffle_seed: seed for shuffling dataset sequences + shift: If True (default), inputs = x[:-1], targets = x[1:] (AR mode). If + False, inputs = targets = x (MDLM mode). Returns: train_dataset, eval_train_dataset, valid_dataset, test_dataset @@ -101,37 +104,47 @@ def get_fineweb_edu_dataset( val_tokens = val_dataset.flat_map(tf.data.Dataset.from_tensor_slices) # split into sequences + seq_batch_len = hps.sequence_length + 1 if shift else hps.sequence_length + eval_seq_batch_len = ( + hps.eval_sequence_length + 1 if shift else hps.eval_sequence_length + ) train_sequences_dataset = train_tokens.batch( - hps.sequence_length + 1, drop_remainder=True + seq_batch_len, drop_remainder=True ) eval_train_sequences_dataset = train_tokens.batch( - hps.eval_sequence_length + 1, drop_remainder=True + eval_seq_batch_len, drop_remainder=True ) val_sequences_dataset = val_tokens.batch( - hps.eval_sequence_length + 1, drop_remainder=True + eval_seq_batch_len, drop_remainder=True ) # Split the sequences into inputs and targets. 
+ if shift: + map_fn = lambda x: { + 'inputs': x['input_ids'][: hps.sequence_length], + 'targets': x['input_ids'][1:], + } + eval_map_fn = lambda x: { + 'inputs': x['input_ids'][: hps.eval_sequence_length], + 'targets': x['input_ids'][1:], + } + else: + map_fn = lambda x: { + 'inputs': x['input_ids'][: hps.sequence_length], + 'targets': x['input_ids'][: hps.sequence_length], + } + eval_map_fn = lambda x: { + 'inputs': x['input_ids'][: hps.eval_sequence_length], + 'targets': x['input_ids'][: hps.eval_sequence_length], + } train_sequences_dataset = train_sequences_dataset.map( - lambda x: { - 'inputs': x['input_ids'][: hps.sequence_length], - 'targets': x['input_ids'][1:], - }, - num_parallel_calls=AUTOTUNE, + map_fn, num_parallel_calls=AUTOTUNE ) eval_train_sequences_dataset = eval_train_sequences_dataset.map( - lambda x: { - 'inputs': x['input_ids'][: hps.eval_sequence_length], - 'targets': x['input_ids'][1:], - }, - num_parallel_calls=AUTOTUNE, + eval_map_fn, num_parallel_calls=AUTOTUNE ) val_sequences_dataset = val_sequences_dataset.map( - lambda x: { - 'inputs': x['input_ids'][: hps.eval_sequence_length], - 'targets': x['input_ids'][1:], - }, - num_parallel_calls=AUTOTUNE, + eval_map_fn, num_parallel_calls=AUTOTUNE ) # Shuffle the train sequences. diff --git a/init2winit/dataset_lib/fineweb_edu_10b_mdlm.py b/init2winit/dataset_lib/fineweb_edu_10b_mdlm.py new file mode 100644 index 00000000..92488ffb --- /dev/null +++ b/init2winit/dataset_lib/fineweb_edu_10b_mdlm.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2026 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MDLM variant of FineWeb-Edu 10B dataset. + +Wraps the standard fineweb_edu_10b dataset with metadata appropriate for +masked diffusion language modeling (no shifting, bidirectional). +""" + +import itertools + +from init2winit.dataset_lib import data_utils +from init2winit.dataset_lib import fineweb_edu_10b +from init2winit.dataset_lib import fineweb_edu_10b_input_pipeline as input_pipeline +import jax + +PAD_ID = input_pipeline.PAD_ID +VOCAB_SIZE = input_pipeline.VOCAB_SIZE + +DEFAULT_HPARAMS = fineweb_edu_10b.DEFAULT_HPARAMS +Dataset = data_utils.Dataset + + +METADATA = { + 'apply_one_hot_in_loss': False, + 'shift_inputs': False, + 'causal': False, + 'pad_token': PAD_ID, +} + + +def get_fineweb_edu_mdlm( + shuffle_rng, batch_size, eval_batch_size=None, hps=None, pad_id=PAD_ID +): + """Returns FineWeb-Edu 10B Dataset without input shifting for MDLM.""" + process_count = jax.process_count() + n_devices = jax.local_device_count() + + if batch_size % process_count != 0: + raise ValueError( + 'process_count={} must divide batch_size={}.'.format( + process_count, batch_size + ) + ) + + if eval_batch_size is None: + eval_batch_size = batch_size + + if eval_batch_size % process_count != 0: + raise ValueError( + 'process_count={} must divide batch_size={}.'.format( + process_count, batch_size + ) + ) + + per_host_batch_size = int(batch_size / process_count) + per_host_eval_batch_size = int(eval_batch_size / process_count) + + if per_host_batch_size % n_devices != 0: + raise ValueError( + 'per_host_batch_size={} must be divisible by n_devices={}.'.format( + 
per_host_batch_size, n_devices + ) + ) + if per_host_eval_batch_size % n_devices != 0: + raise ValueError( + 'per_host_eval_batch_size={} must be divisible by n_devices={}.'.format( + per_host_eval_batch_size, n_devices + ) + ) + + train_dataset, eval_train_dataset, valid_dataset = ( + input_pipeline.get_fineweb_edu_dataset( + hps, + train_batch_size=per_host_batch_size, + valid_batch_size=per_host_eval_batch_size, + shuffle_seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng), + shift=False, + ) + ) + + def train_iterator_fn(): + for batch in train_dataset: + yield fineweb_edu_10b.add_weights_to_batch( + data_utils.tf_to_numpy(batch), pad_id + ) + + def eval_train_epoch(num_batches=None): + for batch in itertools.islice(iter(eval_train_dataset), num_batches): + yield fineweb_edu_10b.add_weights_to_batch( + data_utils.tf_to_numpy(batch), pad_id + ) + + def valid_epoch(num_batches=None): + for batch in itertools.islice(iter(valid_dataset), num_batches): + yield fineweb_edu_10b.add_weights_to_batch( + data_utils.tf_to_numpy(batch), pad_id + ) + + # pylint: disable=unreachable + def test_epoch(*args, **kwargs): + del args + del kwargs + return + yield # This yield is needed to make this a valid (null) iterator. + + return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) diff --git a/init2winit/dataset_lib/test_fineweb_edu_10b_mdlm.py b/init2winit/dataset_lib/test_fineweb_edu_10b_mdlm.py new file mode 100644 index 00000000..f31d2083 --- /dev/null +++ b/init2winit/dataset_lib/test_fineweb_edu_10b_mdlm.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2026 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for fineweb_edu_10b_mdlm dataset wrapper. + +Verifies that the MDLM variant does NOT shift inputs (inputs == targets) +and that padding is applied correctly. + +""" + +from unittest import mock + +from absl.testing import absltest +from init2winit.dataset_lib import fineweb_edu_10b_input_pipeline as input_pipeline +from ml_collections.config_dict import config_dict +import numpy as np +import tensorflow as tf + + +def _make_fake_token_dataset(num_tokens): + """Creates a fake dataset mimicking the tokenized FineWeb-Edu format. + + Args: + num_tokens: Number of tokens in the dataset. + + Returns: + A tf.data.Dataset where each element is a dict with 'input_ids' as a 1D + tensor (one document). The pipeline flat_maps these into a stream of + individual token dicts, then re-batches into sequences. + """ + # Simulate a single "document" containing all tokens. 
+ tokens = np.arange(num_tokens, dtype=np.int64) + ds = tf.data.Dataset.from_tensor_slices({'input_ids': [tokens]}) + return ds + + +def _make_hps(seq_len=8): + return config_dict.ConfigDict( + dict( + sequence_length=seq_len, + max_target_length=seq_len, + max_eval_target_length=seq_len, + eval_sequence_length=seq_len, + ) + ) + + +class InputPipelineShiftTest(absltest.TestCase): + """Tests that shift=True produces shifted data and shift=False does not.""" + + def _get_first_batch(self, shift, seq_len=8, num_tokens=100): + """Runs the pipeline on fake data and returns the first train batch.""" + hps = _make_hps(seq_len) + fake_ds = _make_fake_token_dataset(num_tokens) + + with mock.patch.object(tf.data.Dataset, 'load', return_value=fake_ds): + train_ds, _, _ = input_pipeline.get_fineweb_edu_dataset( + hps, + train_batch_size=2, + valid_batch_size=2, + shuffle_seed=0, + shift=shift, + ) + # Take one batch (undo repeat). + batch = next(iter(train_ds.take(1))) + inputs = batch['inputs'].numpy() + targets = batch['targets'].numpy() + return inputs, targets + + def test_shift_true_targets_offset_by_one(self): + """With shift=True (AR), targets[i] == inputs[i] + 1.""" + inputs, targets = self._get_first_batch(shift=True) + + # Inputs and targets should differ. + self.assertFalse(np.array_equal(inputs, targets)) + + # For contiguous token ids, targets should be inputs shifted by 1. 
+ np.testing.assert_array_equal(targets, inputs + 1) + + def test_shift_false_inputs_equal_targets(self): + """With shift=False (MDLM), inputs == targets exactly.""" + inputs, targets = self._get_first_batch(shift=False) + np.testing.assert_array_equal(inputs, targets) + + def test_shift_false_sequence_length_correct(self): + """With shift=False, sequences should be seq_len long, not seq_len+1.""" + seq_len = 8 + inputs, targets = self._get_first_batch(shift=False, seq_len=seq_len) + self.assertEqual(inputs.shape[-1], seq_len) + self.assertEqual(targets.shape[-1], seq_len) + + def test_shift_true_sequence_length_correct(self): + """With shift=True, sequences come from seq_len+1 tokens.""" + seq_len = 8 + inputs, targets = self._get_first_batch(shift=True, seq_len=seq_len) + self.assertEqual(inputs.shape[-1], seq_len) + self.assertEqual(targets.shape[-1], seq_len) + + +class PaddingTest(absltest.TestCase): + """Tests that eval batches are padded correctly.""" + + def test_eval_batch_padding_applied(self): + """Eval batches should be padded to batch_size with PAD_ID.""" + hps = _make_hps(seq_len=4) + # 13 tokens -> 3 sequences of length 4 (shift=False, drop_remainder=True). + # With batch_size=2, we get 1 full batch + 1 partial batch (1 real + 1 pad). + fake_ds = _make_fake_token_dataset(13) + + with mock.patch.object(tf.data.Dataset, 'load', return_value=fake_ds): + _, _, valid_ds = input_pipeline.get_fineweb_edu_dataset( + hps, + train_batch_size=2, + valid_batch_size=2, + shuffle_seed=0, + shift=False, + ) + + batches = list(valid_ds.as_numpy_iterator()) + # Should have 2 batches: first full, second padded. + self.assertLen(batches, 2) + + padded_batch = batches[1] + pad_id = int(input_pipeline.PAD_ID.numpy()) + + # The second row of the padded batch should be all PAD_ID. 
+ np.testing.assert_array_equal(padded_batch['inputs'][1], np.full(4, pad_id)) + np.testing.assert_array_equal( + padded_batch['targets'][1], np.full(4, pad_id) + ) + + def test_eval_batch_padding_not_in_full_batches(self): + """Full eval batches should contain no padding.""" + hps = _make_hps(seq_len=4) + fake_ds = _make_fake_token_dataset(13) + + with mock.patch.object(tf.data.Dataset, 'load', return_value=fake_ds): + _, _, valid_ds = input_pipeline.get_fineweb_edu_dataset( + hps, + train_batch_size=2, + valid_batch_size=2, + shuffle_seed=0, + shift=False, + ) + + batches = list(valid_ds.as_numpy_iterator()) + full_batch = batches[0] + pad_id = int(input_pipeline.PAD_ID.numpy()) + + # No element in the full batch should be PAD_ID. + self.assertTrue(np.all(full_batch['inputs'] != pad_id)) + self.assertTrue(np.all(full_batch['targets'] != pad_id)) + + +if __name__ == '__main__': + absltest.main() diff --git a/init2winit/model_lib/losses.py b/init2winit/model_lib/losses.py index 9dfc5635..c0fdb2d2 100644 --- a/init2winit/model_lib/losses.py +++ b/init2winit/model_lib/losses.py @@ -428,6 +428,9 @@ def weighted_mean_absolute_error(logits, targets, weights=None): ), 'ctc': (ctc_loss, jax.nn.log_softmax), 'mean_absolute_error': (weighted_mean_absolute_error, None), + # The below loss is needed for the masked diffusion language model. + # This is because the loss calculation is part of the model there. + 'passthrough': (lambda logits, targets, weights=None: (logits, 1.0), None), } diff --git a/init2winit/model_lib/mdlm_rope_nanodo.py b/init2winit/model_lib/mdlm_rope_nanodo.py new file mode 100644 index 00000000..feaa70dd --- /dev/null +++ b/init2winit/model_lib/mdlm_rope_nanodo.py @@ -0,0 +1,266 @@ +# coding=utf-8 +# Copyright 2026 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MDLM (Masked Diffusion Language Model) with RoPE transformer. + +Bidirectional transformer for masked diffusion, reusing building blocks from +rope_nanodo.py. Implements the MDLM training objective from +"Simple and Effective Masked Diffusion Language Models" (Sahoo et al., 2024). +""" + +# pylint: disable=invalid-name +from flax import linen as nn +from init2winit import utils +from init2winit.model_lib import base_model +from init2winit.model_lib import mdlm_schedules +from init2winit.model_lib import model_utils +from init2winit.model_lib import rope_nanodo +import jax +import jax.numpy as jnp +from ml_collections.config_dict import config_dict + +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + emb_dim=512, + num_heads=8, + num_layers=12, + mlp_dim=2048, + rng_seed=-1, + computation_dtype='bfloat16', + model_dtype='float32', + optimizer='adam', + batch_size=256, + lr_hparams={'base_lr': 0.01, 'schedule': 'constant'}, + opt_hparams={ + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'weight_decay': 0.0, + }, + l2_decay_factor=0.0005, + l2_decay_rank_threshold=2, + grad_clip=None, + label_smoothing=0.0, + use_shallue_label_smoothing=False, + normalization='rmsnorm', + mlp_activation='glu', + qk_norm=True, + tie_embeddings=True, + noise_schedule='log_linear', + epsilon=1e-7, + ) +) + + +class TimestepEmbedding(nn.Module): + """Embeds scalar timestep into D-dimensional vector via MLP.""" + + D: int + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, t_B: jax.Array): + half_d = self.D // 2 + freq = jnp.exp( + -jnp.log(10000.0) 
* jnp.arange(half_d, dtype=jnp.float32) / half_d + ) + angles = t_B[:, None] * freq[None, :] + sincos = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1) + + h = nn.Dense(self.D, dtype=self.dtype)(sincos) + h = nn.gelu(h) + h = nn.Dense(self.D, dtype=self.dtype)(h) + return h + + +class MDLMTransformer(nn.Module): + """Bidirectional transformer for MDLM.""" + + docfg: rope_nanodo.DoConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.V + 1, + features=cfg.D, + embedding_init=cfg.embed_init, + ) + + self.time_embed = TimestepEmbedding(D=cfg.D, dtype=cfg.dtype) + + self.blocks = [rope_nanodo.TBlock(cfg) for _ in range(cfg.N)] + if cfg.normalization == 'layernorm': + self.out_ln = nn.LayerNorm(dtype=cfg.dtype, use_bias=False) + elif cfg.normalization == 'rmsnorm': + self.out_ln = nn.RMSNorm(dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + else: + raise ValueError(f'Unknown normalization: {cfg.normalization}') + + if cfg.tie_embeddings: + self.output_proj = None + else: + self.output_proj = nn.Dense( + cfg.V, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj' + ) + + def __call__(self, z_BxL: jax.Array, t_B: jax.Array, train: bool): + del train + cfg = self.docfg + + z_BxLxD = self.embed(z_BxL) + + t_BxD = self.time_embed(t_B) + z_BxLxD = z_BxLxD + t_BxD[:, None, :] + + for block in self.blocks: + z_BxLxD = block(z_BxLxD) + z_BxLxD = self.out_ln(z_BxLxD) + + if self.output_proj is not None: + logits_BxLxV = self.output_proj(z_BxLxD) + else: + embed_matrix = self.embed.embedding[: cfg.V] + logits_BxLxV = z_BxLxD.astype(jnp.float32) @ embed_matrix.T + + return logits_BxLxV + + +class MDLMModel(base_model.BaseModel): + """MDLM model with diffusion training objective.""" + + def build_flax_module(self): + config = rope_nanodo.DoConfig( + D=self.hps['emb_dim'], + H=self.hps['num_heads'], + N=self.hps['num_layers'], + V=self.hps['vocab_size'], + L=self.hps['input_shape'][0], + F=self.hps['mlp_dim'], + 
dtype=utils.dtype_from_str(self.hps['computation_dtype']), + mlp_activation=self.hps['mlp_activation'], + normalization=self.hps['normalization'], + qk_norm=self.hps['qk_norm'], + tie_embeddings=self.hps['tie_embeddings'], + is_causal=False, + ) + return MDLMTransformer(config) + + def get_fake_inputs(self, hps): + dummy_inputs = [ + jnp.zeros((hps.batch_size, *hps.input_shape), dtype=jnp.int32), + jnp.zeros((hps.batch_size,), dtype=jnp.float32), # t_B timestep + ] + return dummy_inputs + + def _mask_and_forward(self, params, batch, rng, train): + vocab_size = self.hps['vocab_size'] + mask_id = vocab_size + + get_alpha, _ = mdlm_schedules.get_schedule(self.hps['noise_schedule']) + + rng_t, rng_mask = jax.random.split(rng) + + x_BxL = batch['inputs'] + B = x_BxL.shape[0] + + t_B = jax.random.uniform(rng_t, shape=(B,), minval=1e-5, maxval=1.0) + + alpha_t_B = get_alpha(t_B) + + mask_prob_BxL = jnp.broadcast_to((1.0 - alpha_t_B)[:, None], x_BxL.shape) + mask_draws_BxL = jax.random.uniform(rng_mask, shape=x_BxL.shape) + is_masked_BxL = mask_draws_BxL < mask_prob_BxL + + z_BxL = jnp.where(is_masked_BxL, mask_id, x_BxL) + + variables = {'params': params} + logits_BxLxV = self.flax_module.apply(variables, z_BxL, t_B, train=train) + + _NEG_INF = jnp.finfo(logits_BxLxV.dtype).min + logits_BxLxV = jnp.where( + is_masked_BxL[:, :, None], + logits_BxLxV, + _NEG_INF, + ) + unmasked_targets = jax.nn.one_hot(x_BxL, vocab_size) + logits_BxLxV = jnp.where( + is_masked_BxL[:, :, None], + logits_BxLxV, + unmasked_targets * 1e6, + ) + + return logits_BxLxV, z_BxL, is_masked_BxL, t_B, alpha_t_B + + def _compute_elbo(self, params, batch, rng, train): + logits_BxLxV, _, is_masked_BxL, t_B, alpha_t_B = self._mask_and_forward( + params, batch, rng, train + ) + + _, get_alpha_deriv = mdlm_schedules.get_schedule(self.hps['noise_schedule']) + + x_BxL = batch['inputs'] + B = x_BxL.shape[0] + + log_probs_BxLxV = jax.nn.log_softmax(logits_BxLxV, axis=-1) + + targets_BxL = x_BxL + 
log_prob_true_BxL = log_probs_BxLxV[ + jnp.arange(B)[:, None], + jnp.arange(x_BxL.shape[1])[None, :], + targets_BxL, + ] + + alpha_deriv_B = get_alpha_deriv(t_B) + weight_B = -alpha_deriv_B / (1.0 - alpha_t_B + self.hps['epsilon']) + + loss_BxL = -log_prob_true_BxL * is_masked_BxL.astype( + log_prob_true_BxL.dtype + ) + + weighted_loss_BxL = weight_B[:, None] * loss_BxL + + pad_weights = batch.get('weights') + if pad_weights is not None: + weighted_loss_BxL = weighted_loss_BxL * pad_weights + + total_loss = jnp.sum(weighted_loss_BxL) + if pad_weights is not None: + num_tokens = jnp.sum(pad_weights) + else: + num_tokens = jnp.array(B * x_BxL.shape[1], dtype=jnp.float32) + return total_loss / num_tokens + + def inference(self, params, batch, rng): + logits_BxLxV, z_BxL, is_masked_BxL, t_B, _ = self._mask_and_forward( + params, batch, rng, train=False + ) + predictions_BxL = jnp.argmax(logits_BxLxV, axis=-1) + return predictions_BxL, z_BxL, is_masked_BxL, t_B + + def evaluate_batch(self, params, batch_stats, batch): + rng = batch['eval_rng'] + loss = self._compute_elbo(params, batch, rng, train=False) + return self.metrics_bundle.single_from_model_output(normalized_loss=loss) + + def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): + loss = self._compute_elbo(params, batch, dropout_rng, train=True) + + if self.hps.get('l2_decay_factor'): + l2_loss = model_utils.l2_regularization( + params, self.hps.l2_decay_rank_threshold + ) + loss += 0.5 * self.hps.l2_decay_factor * l2_loss + + return loss, {} diff --git a/init2winit/model_lib/mdlm_schedules.py b/init2winit/model_lib/mdlm_schedules.py new file mode 100644 index 00000000..60fa38f8 --- /dev/null +++ b/init2winit/model_lib/mdlm_schedules.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright 2026 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Noise schedules for Masked Diffusion Language Models (MDLM). + +Each schedule defines alpha_t (noise level) as a function of t in [0, 1], +where alpha_0 ~ 1 (clean) and alpha_1 ~ 0 (fully masked). +""" + +import jax.numpy as jnp + + +def log_linear_alpha(t): + return 1.0 - t + + +def log_linear_alpha_derivative(t): + del t + return -1.0 + + +def cosine_alpha(t): + return jnp.cos(jnp.pi * t / 2.0) + + +def cosine_alpha_derivative(t): + return -jnp.pi / 2.0 * jnp.sin(jnp.pi * t / 2.0) + + +def geometric_alpha(t): + return (1.0 - t) ** 2 + + +def geometric_alpha_derivative(t): + return -2.0 * (1.0 - t) + + +_ALL_SCHEDULES = { + 'log_linear': (log_linear_alpha, log_linear_alpha_derivative), + 'cosine': (cosine_alpha, cosine_alpha_derivative), + 'geometric': (geometric_alpha, geometric_alpha_derivative), +} + + +def get_schedule(name): + try: + return _ALL_SCHEDULES[name] + except KeyError: + raise ValueError(f'Unknown noise schedule: {name}') from None diff --git a/init2winit/model_lib/metrics.py b/init2winit/model_lib/metrics.py index 90de68d6..83cd426f 100644 --- a/init2winit/model_lib/metrics.py +++ b/init2winit/model_lib/metrics.py @@ -481,12 +481,39 @@ def compute(self): return _Metric -def compute_wer(decoded, - decoded_paddings, - targets, - target_paddings, - tokenizer, - tokenizer_type='SPM'): +def mdlm_perplexity(): + """Metrics that calculate the perplexity in addition to the cross entropy loss for the masked diffusion language model.""" + + @flax.struct.dataclass + class _MDLMPerplexity(metrics.Metric): + """Calculates the perplexity based on 
def mdlm_perplexity():
  """Returns a CLU metric class computing MDLM perplexity.

  The metric accumulates the already-normalized diffusion cross-entropy
  loss reported by the model under `normalized_loss`, and computes
  perplexity as exp(mean normalized loss) over all merged batches.
  (The raw cross-entropy itself is reported by a separate metric in the
  collection, not by this class.)
  """

  @flax.struct.dataclass
  class _MDLMPerplexity(metrics.Metric):
    """Perplexity derived from the masked diffusion LM's cross-entropy."""

    # Running sum of per-batch normalized losses.
    total: np.float32
    # Number of batches merged in; each batch contributes weight 1.0.
    weight: np.float32

    @classmethod
    def from_model_output(cls, normalized_loss, **_):
      # One batch's contribution: its normalized loss with unit weight.
      return cls(total=normalized_loss, weight=1.0)

    def merge(self, other):
      return type(self)(
          total=self.total + other.total, weight=self.weight + other.weight
      )

    def compute(self):
      # exp of the average normalized loss over all merged batches.
      return jnp.exp(self.total / self.weight)

  return _MDLMPerplexity
@jax.jit
def apply_rope(q, k, freqs_cis):
  """Applies rotary position embeddings to queries and keys.

  Args:
    q: queries of shape (B, L, H, Dh) with an even head dimension Dh.
    k: keys with the same shape as `q`.
    freqs_cis: precomputed rotation table whose second axis is sliced to
      the sequence length L and whose trailing axis holds the (cos, sin)
      pair for each feature pair.

  Returns:
    A (rotated_q, rotated_k) pair, each cast back to the input dtype.
  """
  input_dtype = q.dtype

  def _rotate(x):
    # View the last dim as Dh/2 complex pairs: (..., Dh) -> (..., Dh/2, 2).
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    seq_len = x.shape[1]
    table = freqs_cis[:, :seq_len, :, :, :]
    cos, sin = table[..., 0], table[..., 1]
    real, imag = pairs[..., 0], pairs[..., 1]
    # Complex multiply (real + i*imag) * (cos + i*sin).
    rotated = jnp.stack(
        [real * cos - imag * sin, imag * cos + real * sin], axis=-1
    )
    return rotated.reshape(*x.shape)

  return _rotate(q).astype(input_dtype), _rotate(k).astype(input_dtype)
def __call__(self, x_BxLxD: jax.Array):
  """Runs multi-head attention with rotary embeddings on a (B, L, D) input."""
  cfg = self.cfg

  # Project the input to per-head queries, keys and values: (B, L, H, Dh).
  q_BxLxHxDh = self.multilinear_query(x_BxLxD)
  k_BxLxHxDh = self.multilinear_key(x_BxLxD)
  v_BxLxHxDh = self.multilinear_value(x_BxLxD)

  if cfg.qk_norm:
    # L2-normalize queries and keys along the head dimension, then restore
    # a learned logit temperature via attn_scale (created in setup).
    q_norm = jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True)
    k_norm = jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True)
    q_BxLxHxDh = q_BxLxHxDh / (q_norm + self.eps)
    k_BxLxHxDh = k_BxLxHxDh / (k_norm + self.eps)
    q_BxLxHxDh = q_BxLxHxDh * self.attn_scale

  # Rotary position embeddings for queries and keys.
  q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis)

  # NOTE(review): scale=1.0 relies on qk_norm's attn_scale for logit
  # scaling; when cfg.qk_norm is False the logits get no 1/sqrt(Dh)
  # scaling at all — confirm that is intended.
  out_BxLxHxDh = jax.nn.dot_product_attention(
      q_BxLxHxDh,
      k_BxLxHxDh,
      v_BxLxHxDh,
      is_causal=cfg.is_causal,
      scale=1.0,
  )

  # Merge heads back to (B, L, D) and apply the output projection.
  out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape)
  return self.output_projection(out_BxLxD)
'image_reconstruction_metrics', 'rope_nanodo': 'classification_metrics', + 'mdlm_rope_nanodo': 'mdlm_metrics', 'vit': 'classification_metrics', 'wide_resnet': 'classification_metrics', 'xformer_translate_binary': 'classification_metrics',