From da5f85a7c878f0399c7b8a5d2fcfb9d729e567ea Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 11 Mar 2025 15:46:49 +0100 Subject: [PATCH 01/82] first LM commit --- algoperf/workloads/lm/__init__.py | 0 algoperf/workloads/lm/dev/data_pytorch.py | 42 ++++++++++ algoperf/workloads/lm/input_pipeline.py | 82 ++++++++++++++++++++ algoperf/workloads/lm/lm_pytorch/__init__.py | 0 algoperf/workloads/lm/lm_pytorch/workload.py | 36 +++++++++ algoperf/workloads/lm/test_01.py | 22 ++++++ algoperf/workloads/lm/test_input_pipeline.py | 68 ++++++++++++++++ algoperf/workloads/lm/workload.py | 66 ++++++++++++++++ 8 files changed, 316 insertions(+) create mode 100644 algoperf/workloads/lm/__init__.py create mode 100644 algoperf/workloads/lm/dev/data_pytorch.py create mode 100644 algoperf/workloads/lm/input_pipeline.py create mode 100644 algoperf/workloads/lm/lm_pytorch/__init__.py create mode 100644 algoperf/workloads/lm/lm_pytorch/workload.py create mode 100644 algoperf/workloads/lm/test_01.py create mode 100644 algoperf/workloads/lm/test_input_pipeline.py create mode 100644 algoperf/workloads/lm/workload.py diff --git a/algoperf/workloads/lm/__init__.py b/algoperf/workloads/lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py new file mode 100644 index 000000000..d0081a75d --- /dev/null +++ b/algoperf/workloads/lm/dev/data_pytorch.py @@ -0,0 +1,42 @@ + +import torch + +from datasets import Dataset, load_from_disk +from torch.utils.data import DataLoader + +trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" +vocab_size = 50280 +seq_len = 2048 +sampler = 'sequential' +sampler_seed = None +num_workers = 4 + +train_set = load_from_disk(trainset_path) # + +""" +>>> type(train_set) + + +>>> len(train_set) +7501407 + +>>> train_set[0] +{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} + +>>> type(train_set[0]['input_ids']) + + +# In PyTorch we do: +trainloader = DataLoader( + train_set, + sampler = ..., + batch_size = ..., + num_workers = ..., + pin_memory = ..., + ) + +# PyTorch’s DataLoader expects an iterable dataset, +# which means it calls __getitem__() and __len__() on train_set. + +""" + diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py new file mode 100644 index 000000000..7424dd6d5 --- /dev/null +++ b/algoperf/workloads/lm/input_pipeline.py @@ -0,0 +1,82 @@ +"""Input pipeline for a LM dataset.""" +import functools +import os + +from datasets import Dataset, load_from_disk +from typing import Dict, List, Optional, Union + +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from algoperf import data_utils +from algoperf.pytorch_utils import pytorch_setup + +RANK = pytorch_setup()[1] +# Avoid multithreading in all processes but the first (rank 0). 
+AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None + + +def get_lm_dataset(data_rng, + split: str, + data_dir: str, + is_training: bool, + vocab_size: int, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + vocab_path: Optional[str] = None): + """Load HF dataset and return a TF dataset.""" + + dataset_path = os.path.join(data_dir, split) + dataset = load_from_disk(dataset_path) # Loads HF arrow dataset + + is_training = split == "train" + shuffle = split in ['train', 'eval_train'] + + def tf_generator(): + """Generates data in a TensorFlow-friendly format.""" + for example in dataset: + yield { + "inputs": tf.convert_to_tensor(example["input_ids"][:-1], dtype=tf.int32), + "targets": tf.convert_to_tensor(example["input_ids"][1:], dtype=tf.int32), + } + + # Create a TensorFlow dataset from the generator function + ds = tf.data.Dataset.from_generator( + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), + } + ) + + # Avoid creating too many threads when using PyTorch DDP. + if RANK != 0: + options = tf.data.Options() + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) + + if shuffle: + print(f"Shuffling dataset with seed: {data_rng[0]}, type={type(data_rng[0])}") + ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) + + if is_training: + ds = ds.repeat() + + # Batch the dataset, ensuring the last batch is dropped if not full during training + ds = ds.batch(global_batch_size, drop_remainder=is_training) + ds = ds.prefetch(AUTOTUNE) + + # Limit the dataset to a fixed number of batches if `num_batches` is specified + if num_batches: + ds = ds.take(num_batches) + + # Shard the dataset across multiple GPUs/TPUs if necessary + ds = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) + + return ds \ No newline at end of file diff --git a/algoperf/workloads/lm/lm_pytorch/__init__.py b/algoperf/workloads/lm/lm_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py new file mode 100644 index 000000000..904657b1d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -0,0 +1,36 @@ +"""LM workload implemented in PyTorch.""" + +import contextlib +from typing import Any, Dict, Optional, Tuple + +from absl import logging +import jax +import tensorflow as tf +import torch +import torch.distributed as dist +from torch.nn import DataParallel as DP +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP + +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + +class LmWorkload(BaseLmWorkload): + """LM PyTorch workload.""" + + def init_model_fn(): + pass + + def model_fn(): + pass + + def _build_input_queue(): + pass + + def eval_step(): + pass diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/test_01.py new file mode 100644 index 000000000..e33ddf3e7 --- /dev/null +++ b/algoperf/workloads/lm/test_01.py @@ -0,0 +1,22 @@ +import os +import tensorflow as tf +import torch +from datasets import load_from_disk + +from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +DATASET_PATH = 
"/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" +BATCH_SIZE = 2 +SEED = 42 # Fixed random seed for reproducibility + +tf_seed = SEED + +# Load the dataset +ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, +) diff --git a/algoperf/workloads/lm/test_input_pipeline.py b/algoperf/workloads/lm/test_input_pipeline.py new file mode 100644 index 000000000..47c11969f --- /dev/null +++ b/algoperf/workloads/lm/test_input_pipeline.py @@ -0,0 +1,68 @@ +import os +import tensorflow as tf +import torch +from datasets import load_from_disk + +from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" +BATCH_SIZE = 2 +SEED = 42 # Fixed random seed for reproducibility + + +def test_tf_dataset(): + """Tests if get_lm_dataset correctly loads the HF dataset as a TensorFlow dataset.""" + + print(f"Loading dataset from: {DATASET_PATH}") + + tf_seed = SEED + + # Load the dataset + ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, + ) + + print("Testing TensorFlow Dataset Output...") + for batch in ds.take(2): # Take two batches to test + print("Inputs:", batch["inputs"].numpy()) # Convert to NumPy for inspection + print("Targets:", batch["targets"].numpy()) + +def test_pytorch_dataloader(): + """Tests if the TensorFlow dataset can be converted to PyTorch format correctly.""" + + # Use the same TensorFlow-compatible seed + tf_seed = tf.constant(SEED, dtype=tf.int64) + + # Load the dataset + ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, + global_batch_size=BATCH_SIZE, + ) + + def _input_queue_generator(): + """Generator that converts TF dataset batches to PyTorch tensors.""" + for batch in iter(ds): + batch = {k: torch.tensor(v.numpy()) for k, v in batch.items()} # Convert to PyTorch tensors + yield batch + + dataloader = _input_queue_generator() + + print("\nTesting PyTorch DataLoader Output...") + for _ in range(2): # Take two batches + batch = next(dataloader) + print("Inputs:", batch["inputs"]) + print("Targets:", batch["targets"]) + +# Run tests +if __name__ == "__main__": + test_tf_dataset() + test_pytorch_dataloader() \ No newline at end of file diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py new file mode 100644 index 000000000..d070cabec --- /dev/null +++ b/algoperf/workloads/lm/workload.py @@ -0,0 +1,66 @@ +"""LM workload parent class.""" + +import abc +import math +import os +from typing import Any, Dict, Optional, Tuple + +import jax +import numpy as np +import torch + +from algoperf import spec +from algoperf.workloads.lm import input_pipeline + +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ + + +class BaseLmWorkload(spec.Workload): + """A LM workload.""" + + _vocab_size: int = 32000 + + def __init__(self) -> None: + super().__init__() + self._tokenizer = None + + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + is_training = split == 'train' + ds, self._tokenizer = input_pipeline.get_lm_dataset( 
+ data_rng, + split, + data_dir, + is_training=is_training, + vocab_size=self._vocab_size, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + + for batch in iter(ds): + yield batch + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the loss function at (label_batch, logits_batch).""" + pass \ No newline at end of file From a12a36404ce907c8e50e67c8e4a5eb25baa9a2f3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 12 Mar 2025 15:49:04 +0100 Subject: [PATCH 02/82] lm data pipeline --- algoperf/workloads/lm/input_pipeline.py | 11 +-- algoperf/workloads/lm/test_01.py | 96 +++++++++++++++++++++---- datasets/dataset_setup.py | 96 +++++++++++++++++++++++++ datasets/lm_preprocess.py | 0 4 files changed, 185 insertions(+), 18 deletions(-) create mode 100644 datasets/lm_preprocess.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 7424dd6d5..a14cebeda 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -5,6 +5,7 @@ from datasets import Dataset, load_from_disk from typing import Dict, List, Optional, Union +import jax import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -17,7 +18,7 @@ AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None -def get_lm_dataset(data_rng, +def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, is_training: bool, @@ -37,11 +38,12 @@ def get_lm_dataset(data_rng, def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: + input_ids = example["input_ids"].numpy().astype(np.int32) # torch tensor TODO: remove numpy conversion yield { - "inputs": tf.convert_to_tensor(example["input_ids"][:-1], dtype=tf.int32), - "targets": tf.convert_to_tensor(example["input_ids"][1:], dtype=tf.int32), + "inputs": tf.convert_to_tensor(input_ids[:-1], dtype=tf.int32), + "targets": tf.convert_to_tensor(input_ids[1:], dtype=tf.int32), } - + # Create a TensorFlow dataset from the generator function ds = tf.data.Dataset.from_generator( tf_generator, @@ -58,7 +60,6 @@ def tf_generator(): ds = ds.with_options(options) if shuffle: - print(f"Shuffling dataset with seed: {data_rng[0]}, type={type(data_rng[0])}") ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) if is_training: diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/test_01.py index e33ddf3e7..977fae11a 100644 --- a/algoperf/workloads/lm/test_01.py +++ b/algoperf/workloads/lm/test_01.py @@ -1,22 +1,92 @@ + import os +import numpy as np import tensorflow as tf import torch + from datasets import load_from_disk +from absl import app +from absl import flags +from absl import logging + +from algoperf.profiler import PassThroughProfiler +from algoperf import random_utils as prng +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +tf.config.set_visible_devices([], 'GPU') + +# 
Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' +# (nico) +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +flags.DEFINE_enum( + 'framework', + None, + enum_values=['jax', 'pytorch'], + help='Whether to use Jax or Pytorch for the submission. Controls among ' + 'other things if the Jax or Numpy RNG library is used for RNG.') + +FLAGS = flags.FLAGS +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" BATCH_SIZE = 2 -SEED = 42 # Fixed random seed for reproducibility - -tf_seed = SEED - -# Load the dataset -ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, -) +RNG_SEED = 1996 # Fixed random seed for reproducibility + + +def main(_): + profiler = PassThroughProfiler() + if FLAGS.framework == 'pytorch': + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + + rng = prng.PRNGKey(RNG_SEED) + data_rng, _, _, _ = prng.split(rng, 4) + + print(f"data_rng = {data_rng}") + + # Load the dataset + ds = get_lm_dataset( + data_rng=data_rng, + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, + ) + # Check if `ds` acts as a generator + if hasattr(ds, '__iter__'): + print("Dataset is an iterable/generator.") + + # Fetch first batch + try: + first_batch = next(iter(ds)) + print(f"Successfully retrieved first batch.") + except Exception as e: + print(f"Error retrieving first batch: {e}") + return + + # Print structure of a batch + print(f"First batch keys: {first_batch.keys()}") + print(f"First batch shapes:") + for key, value in first_batch.items(): + print(f" - {key}: {value.shape} (dtype: {value.dtype})") + + # Validate batch dimensions + assert "inputs" in first_batch and "targets" in first_batch, "Missing expected keys!" + assert first_batch["inputs"].shape[0] == BATCH_SIZE, "Batch size mismatch!" + assert first_batch["inputs"].shape == first_batch["targets"].shape, "Inputs and targets should have the same shape!" 
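+  # Note: these asserts only check keys, batch size, and shape agreement; they do not
+  # verify the one-token shift between inputs and targets.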
+ + print(f"Dataset is correctly batched and structured.") + print(f"Test completed successfully.") + +if __name__ == '__main__': + flags.mark_flag_as_required('framework') + app.run(main) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index efe923dbe..14dd24545 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -76,13 +76,21 @@ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer +from datasets import lm_preprocess +import datasets as hf_datasets +# from datasets import load_dataset, Dataset +from transformers import AutoTokenizer + +import math import functools +import itertools import os import shutil import subprocess import tarfile +from typing import Dict, List, Any from absl import app from absl import flags from absl import logging @@ -126,6 +134,9 @@ flags.DEFINE_boolean('librispeech', False, 'If --all=false, whether or not to download LibriSpeech.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('mnist', False, 'If --all=false, whether or not to download MNIST.') @@ -699,6 +710,86 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) +def download_finewebedu(data_dir, tmp_dir): + """Download FineWebEdu-10B.""" + + # data_dir = "/fast/najroldi/data" + + tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") + data_dir = os.path.join(data_dir, 'finewebedu') + + _maybe_mkdir(tmp_dir) + _maybe_mkdir(data_dir) + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + # cache_dir=tmp_dir + ) + + ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size + + seq_len = 2048 + max_seq_length = seq_len+1 + map_setup = dict(batched=True, batch_size=1024, num_proc=8) + + # Tokenize + tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of tokenizer = {len(tokenizer)}") + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] + return tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False + ) + + tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + tokenized_dataset = ds.map( + tokenize, + remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', + 'language_score', 'token_count', 'score', 'int_score'], + **map_setup + ) + tokenizer.model_max_length = seq_len + + # Concat in chunks of max_seq_len + def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + """Concatenate text and generate chunks of max_seq_length""" + concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + if total_length >= max_seq_length: + total_length = (total_length // max_seq_length) * max_seq_length + result = { + k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] + for k, t in concatenated_examples.items() + } + return result + + lm_dataset = tokenized_dataset.map( + concat_chunck, + **map_setup + ) + + n_tokens = len(lm_dataset) * max_seq_length + logging.info(f"Number of tokens in dataset: {n_tokens:_}") + + # Split dataset into training and 
validation sets + # TODO: avoid (single doc) contamination between train and val + VAL_TOKENS = 10_000_000 + val_samples = VAL_TOKENS // max_seq_length + 1 + val_dataset = lm_dataset.select(range(val_samples)) + train_dataset = lm_dataset.select(range(val_samples, len(lm_dataset))) + logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length :_}") + logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length :_}") + + # Save datasets + train_dataset.save_to_disk(os.path.join(data_dir, f"train")) + val_dataset.save_to_disk(os.path.join(data_dir, f"val")) + + def main(_): data_dir = FLAGS.data_dir tmp_dir = FLAGS.temp_dir @@ -781,6 +872,11 @@ def main(_): logging.info('Downloading WMT...') download_wmt(data_dir) + if FLAGS.all or FLAGS.finewebedu: + if not FLAGS.skip_download: + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir) + # pylint: enable=logging-format-interpolation # pylint: enable=consider-using-with diff --git a/datasets/lm_preprocess.py b/datasets/lm_preprocess.py new file mode 100644 index 000000000..e69de29bb From ca83ab8954a9e164dc538cb4749847812ee0e032 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 14 Mar 2025 11:31:08 +0100 Subject: [PATCH 03/82] testing --- algoperf/workloads/lm/{ => dev}/test_01.py | 0 .../lm/{ => dev}/test_input_pipeline.py | 0 algoperf/workloads/lm/input_pipeline.py | 37 +++++---- .../workloads/lm/lm_jax/__init__.py | 0 algoperf/workloads/lm/lm_jax/workload.py | 20 +++++ algoperf/workloads/lm/lm_pytorch/workload.py | 56 ++++++++++++- algoperf/workloads/lm/test.py | 37 +++++++++ algoperf/workloads/lm/workload.py | 80 ++++++++++++++----- datasets/dataset_setup.py | 25 ++++-- 9 files changed, 211 insertions(+), 44 deletions(-) rename algoperf/workloads/lm/{ => dev}/test_01.py (100%) rename algoperf/workloads/lm/{ => dev}/test_input_pipeline.py (100%) rename datasets/lm_preprocess.py => algoperf/workloads/lm/lm_jax/__init__.py (100%) create mode 100644 algoperf/workloads/lm/lm_jax/workload.py create mode 100644 algoperf/workloads/lm/test.py diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/dev/test_01.py similarity index 100% rename from algoperf/workloads/lm/test_01.py rename to algoperf/workloads/lm/dev/test_01.py diff --git a/algoperf/workloads/lm/test_input_pipeline.py b/algoperf/workloads/lm/dev/test_input_pipeline.py similarity index 100% rename from algoperf/workloads/lm/test_input_pipeline.py rename to algoperf/workloads/lm/dev/test_input_pipeline.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index a14cebeda..f0024e4a6 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -15,6 +15,10 @@ RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). +# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# automatic optimization (AUTOTUNE), while other processes disable it (None). +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine the optimal +# number of elements to prefetch or parallelize for dataset operations, improving performance. 
AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None @@ -30,34 +34,36 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, """Load HF dataset and return a TF dataset.""" dataset_path = os.path.join(data_dir, split) - dataset = load_from_disk(dataset_path) # Loads HF arrow dataset + dataset = load_from_disk(dataset_path) is_training = split == "train" shuffle = split in ['train', 'eval_train'] + dataset.set_format("tensorflow") # tf.int64 + def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: - input_ids = example["input_ids"].numpy().astype(np.int32) # torch tensor TODO: remove numpy conversion yield { - "inputs": tf.convert_to_tensor(input_ids[:-1], dtype=tf.int32), - "targets": tf.convert_to_tensor(input_ids[1:], dtype=tf.int32), + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], } - # Create a TensorFlow dataset from the generator function + # Create a TensorFlow dataset ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), - } - ) + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + } + ) # Avoid creating too many threads when using PyTorch DDP. - if RANK != 0: + # Limits TensorFlow's threading for non-primary processes (RANK != 0) + if RANK != 0: options = tf.data.Options() - options.threading.private_threadpool_size = 1 - ds = ds.with_options(options) + options.threading.private_threadpool_size = 1 # restrict dataset operations to a single thread + ds = ds.with_options(options) # apply threading restrictions if shuffle: ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) @@ -66,6 +72,9 @@ def tf_generator(): ds = ds.repeat() # Batch the dataset, ensuring the last batch is dropped if not full during training + # i.e. it groups consecutive elements into fixed-size chunks. 
+ # Instead of processing individual elements, the dataset yields batches (tensors with multiple elements), + # improving efficiency and parallelism in training ds = ds.batch(global_batch_size, drop_remainder=is_training) ds = ds.prefetch(AUTOTUNE) diff --git a/datasets/lm_preprocess.py b/algoperf/workloads/lm/lm_jax/__init__.py similarity index 100% rename from datasets/lm_preprocess.py rename to algoperf/workloads/lm/lm_jax/__init__.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py new file mode 100644 index 000000000..4cdb42409 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -0,0 +1,20 @@ +"""LM workload implemented in Jax.""" + +import functools +from typing import Dict, Optional, Tuple + +from flax import jax_utils +import jax +import jax.numpy as jnp +import numpy as np + +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload + + +class LmWorkload(BaseLmWorkload): + + @property + def eval_batch_size(self) -> int: + return 131_072 diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 904657b1d..9ee21ccb6 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -29,8 +29,58 @@ def init_model_fn(): def model_fn(): pass - def _build_input_queue(): - pass - + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + per_device_batch_size = int(global_batch_size / N_GPUS) + + # Only create and iterate over tf input pipeline in one Python process to + # avoid creating too many threads. + if RANK == 0: + np_iter = super()._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + while True: + if RANK == 0: + batch = next(np_iter) + inputs = torch.as_tensor( + batch['inputs'], dtype=torch.float32, device=DEVICE) + targets = torch.as_tensor( + batch['targets'], dtype=torch.float32, device=DEVICE) + # Send batch to other devices when using DDP. + if USE_PYTORCH_DDP: + dist.broadcast(inputs, src=0) + inputs = inputs[0] # TODO: check + dist.broadcast(targets, src=0) + targets = targets[0] # TODO: check + else: + batch = {} + inputs = torch.empty((N_GPUS, per_device_batch_size, 39), + dtype=torch.float32, + device=DEVICE) + dist.broadcast(inputs, src=0) + inputs = inputs[RANK] + targets = torch.empty((N_GPUS, per_device_batch_size, 1), + dtype=torch.float32, + device=DEVICE) + dist.broadcast(targets, src=0) + targets = targets[RANK] + + batch = { + 'inputs': inputs, + 'targets': targets, + # 'weights': weights, + } + yield batch + + def eval_step(): pass diff --git a/algoperf/workloads/lm/test.py b/algoperf/workloads/lm/test.py new file mode 100644 index 000000000..7e693d0af --- /dev/null +++ b/algoperf/workloads/lm/test.py @@ -0,0 +1,37 @@ +""" +Test data pipaline in JAX and PyTorch. + +Instantiate a workload and loops over the input queue. 
+""" + +import jax +import numpy as np +import torch + +import algoperf.workloads.lm.lm_jax.workload as lm_jax +# import algoperf.workloads.lm.lm_pytorch.workload as lm_pytorch + + +data_rng = jax.random.PRNGKey(0) +split = 'train' +data_dir = "/fast/najroldi/data/finewebedu" +global_batch_size = 8 +num_batches = 10 +repeat_final_dataset = False + +# ------------------------------------------------------------------------------ +# JAX +# ------------------------------------------------------------------------------ + +# 1 GPU +workload = lm_jax.LmWorkload() + +input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + +next(input_queue) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index d070cabec..63d2c707e 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -32,7 +32,7 @@ def _build_input_queue(self, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): is_training = split == 'train' - ds, self._tokenizer = input_pipeline.get_lm_dataset( + ds = input_pipeline.get_lm_dataset( data_rng, split, data_dir, @@ -41,26 +41,66 @@ def _build_input_queue(self, global_batch_size=global_batch_size, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) - + for batch in iter(ds): yield batch - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - - def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. 
- logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the loss function at (label_batch, logits_batch).""" + def _eval_model_on_split(): + pass + + def eval_period_time_sec(): + pass + + def has_reached_test_target(): + pass + + def has_reached_validation_target(): + pass + + def init_model_fn(): + pass + + def is_output_params(): + pass + + def loss_fn(): + pass + + def loss_type(): + pass + + def max_allowed_runtime_sec(): + pass + + def model_fn(): + pass + + def num_eval_train_examples(): + pass + + def num_test_examples(): + pass + + def num_train_examples(): + pass + + def num_validation_examples(): + pass + + def step_hint(): + pass + + def test_target_value(): + pass + + def train_mean(): + pass + + def train_stddev(): + pass + + def validation_target_value(): + pass + + def target_metric_name(): pass \ No newline at end of file diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 14dd24545..aab793832 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -76,10 +76,8 @@ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer -from datasets import lm_preprocess import datasets as hf_datasets -# from datasets import load_dataset, Dataset from transformers import AutoTokenizer import math @@ -721,6 +719,9 @@ def download_finewebedu(data_dir, tmp_dir): _maybe_mkdir(tmp_dir) _maybe_mkdir(data_dir) + # Use local disk instead of NFS for temp storage + os.environ["TMPDIR"] = tmp_dir + ds = hf_datasets.load_dataset( 'HuggingFaceFW/fineweb-edu', name='sample-10BT', @@ -745,7 +746,6 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return_special_tokens_mask=False, return_attention_mask=False ) - tokenizer.model_max_length = 1e30 # prevent truncation during tokenization tokenized_dataset = ds.map( tokenize, @@ -754,8 +754,21 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: **map_setup ) tokenizer.model_max_length = seq_len + + tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) + from datasets import load_from_disk + tokenized_dataset = load_from_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) # Concat in chunks of max_seq_len + # TODO: this might take to much memory + # TODO: bug fix: Python's shutil.rmtree tried to delete a .nfs* file, but it was still in use (OSError: [Errno 16] Device or resource busy + # TODO: bug fix: I am losing tokens in the concat-chunk: num_tokens before split: 9_944_182_212 + # (1) loss happening because of batched=True: potentially losing the last tokens in the last batch of the 1024 batched examples + # NOTE: the current approach leads to data loss at batch boundaries, + # but concatenation *cannot* happen if batched=False, + # because concat_chunck relies on processing multiple examples at once. + # (2) loss happening because of nproc>1: potentially losing the last tokens in each process + # TODO: this does not allow to later change the seq_len... 
not a problem in AlgoPerf, but bad in plainLM def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -767,13 +780,11 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map( - concat_chunck, + concat_chunck,\ **map_setup ) - - n_tokens = len(lm_dataset) * max_seq_length + n_tokens = len(lm_dataset) * max_seq_length # 9_944_182_212 logging.info(f"Number of tokens in dataset: {n_tokens:_}") # Split dataset into training and validation sets From e3e78dc6443c5485af64bfe986951f72d9754f99 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 17 Mar 2025 11:18:41 +0100 Subject: [PATCH 04/82] LM workload tested torch pipeline --- algoperf/data_utils.py | 2 +- .../lm/dev/test_build_input_queue_torch.py | 80 +++++++++++++++++++ .../workloads/lm/{test.py => dev/test_jax.py} | 19 ++++- algoperf/workloads/lm/input_pipeline.py | 3 +- algoperf/workloads/lm/lm_jax/workload.py | 5 +- algoperf/workloads/lm/lm_pytorch/workload.py | 68 +++++++++------- algoperf/workloads/lm/workload.py | 7 +- submission_runner.py | 2 +- 8 files changed, 146 insertions(+), 40 deletions(-) create mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_torch.py rename algoperf/workloads/lm/{test.py => dev/test_jax.py} (63%) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..068c21c03 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree.map(_prepare, batch) + return jax.tree_util.tree_map(_prepare, batch) def pad(tensor: np.ndarray, diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py new file mode 100644 index 000000000..86b1ca6b7 --- /dev/null +++ b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py @@ -0,0 +1,80 @@ + +import jax +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + +n_gpus = max(N_GPUS, jax.local_device_count()) + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + # batch = next(input_queue) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # Start test. 
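+  # Draw a number of batches and check that each is a dict of per-device input/target
+  # tensors with the expected dtype and (local_batch_size, seq_len) shape.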
+ for _ in range(100): + + batch = next(input_queue) + assert type(batch) == dict + + assert 'inputs' in batch + assert 'targets' in batch + + assert type(batch['inputs']) == torch.Tensor + assert type(batch['targets']) == torch.Tensor + + assert batch['inputs'].dtype == dtype + assert batch['targets'].dtype == dtype + + assert batch['inputs'].shape == (local_batch_size, seq_len) + assert batch['targets'].shape == (local_batch_size, seq_len) + + sync_ddp() + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() + diff --git a/algoperf/workloads/lm/test.py b/algoperf/workloads/lm/dev/test_jax.py similarity index 63% rename from algoperf/workloads/lm/test.py rename to algoperf/workloads/lm/dev/test_jax.py index 7e693d0af..4ba3de631 100644 --- a/algoperf/workloads/lm/test.py +++ b/algoperf/workloads/lm/dev/test_jax.py @@ -15,6 +15,7 @@ data_rng = jax.random.PRNGKey(0) split = 'train' data_dir = "/fast/najroldi/data/finewebedu" +seq_len = 2048 global_batch_size = 8 num_batches = 10 repeat_final_dataset = False @@ -34,4 +35,20 @@ num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) -next(input_queue) +batch = next(input_queue) +assert type(batch) == dict + +assert 'inputs' in batch +assert 'targets' in batch + +assert type(batch['inputs']) == np.ndarray +assert type(batch['targets']) == np.ndarray + +assert batch['inputs'].dtype == np.int64 +assert batch['targets'].dtype == np.int64 + +assert batch['inputs'].shape == (1, global_batch_size, seq_len) +assert batch['targets'].shape == (1, global_batch_size, seq_len) + +print(f"JAX devices = {jax.devices()}") +print("1") diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index f0024e4a6..e74490a16 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -25,7 +25,6 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - is_training: bool, vocab_size: int, global_batch_size: int, num_batches: Optional[int] = None, @@ -39,7 +38,7 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, is_training = split == "train" shuffle = split in ['train', 'eval_train'] - dataset.set_format("tensorflow") # tf.int64 + dataset.set_format("tensorflow") # tf.int64 # TODO: is this needed? 
def tf_generator(): """Generates data in a TensorFlow-friendly format.""" diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 4cdb42409..773f8c54c 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -14,7 +14,4 @@ class LmWorkload(BaseLmWorkload): - - @property - def eval_batch_size(self) -> int: - return 131_072 + pass diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 9ee21ccb6..0ff7884c7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,7 +1,7 @@ """LM workload implemented in PyTorch.""" import contextlib -from typing import Any, Dict, Optional, Tuple +from typing import Dict, Iterator, Optional, Tuple from absl import logging import jax @@ -22,12 +22,6 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" - - def init_model_fn(): - pass - - def model_fn(): - pass def _build_input_queue(self, data_rng: jax.random.PRNGKey, @@ -35,8 +29,12 @@ def _build_input_queue(self, data_dir: str, global_batch_size: int, num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) + + seq_len = 2048 # TODO: define it somewehere else + DTYPE = torch.int32 # TODO: decide between int32 and int64. # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. @@ -48,36 +46,50 @@ def _build_input_queue(self, global_batch_size=global_batch_size, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) + weights = None + while True: + # Only iterate over tf input pipeline in one Python process to + # avoid creating too many threads. if RANK == 0: - batch = next(np_iter) - inputs = torch.as_tensor( - batch['inputs'], dtype=torch.float32, device=DEVICE) - targets = torch.as_tensor( - batch['targets'], dtype=torch.float32, device=DEVICE) + batch = next(np_iter) # pylint: disable=stop-iteration-return + inputs = torch.as_tensor(batch['inputs'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + targets = torch.as_tensor(batch['targets'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: - dist.broadcast(inputs, src=0) - inputs = inputs[0] # TODO: check - dist.broadcast(targets, src=0) - targets = targets[0] # TODO: check + if not_train: + # During eval, the batch size of the remainder might be different. + per_device_batch_size = torch.tensor(len(targets[0]), dtype=DTYPE, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + # We don't broadcast the shard for RANK 0. + dist.broadcast(inputs[1:], src=0) + dist.broadcast(targets[1:], src=0) + + # RANK 0 extracts his shard. If not DDP, this just flattens. + inputs, targets = inputs[0], targets[0] + else: - batch = {} - inputs = torch.empty((N_GPUS, per_device_batch_size, 39), - dtype=torch.float32, - device=DEVICE) + # Receive batch from rank 0. + if not_train: + # During eval, the batch size of the remainder might be different. + per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + + # N_GPUS - 1 since we don't broadcast the shard for RANK 0. 
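+        # Each non-zero rank receives the stacked shards for ranks 1..N_GPUS-1 and
+        # selects its own slice via RANK - 1 below.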
+ inputs = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) + targets = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) dist.broadcast(inputs, src=0) - inputs = inputs[RANK] - targets = torch.empty((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) dist.broadcast(targets, src=0) - targets = targets[RANK] - + # RANK - 1 since we don't broadcast the shard for RANK 0. + inputs, targets = inputs[RANK-1], targets[RANK-1] + + if weights is None: + weights = torch.ones(per_device_batch_size, device=DEVICE) batch = { 'inputs': inputs, 'targets': targets, - # 'weights': weights, + 'weights': weights, } yield batch diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 63d2c707e..7b1313dd7 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -31,12 +31,10 @@ def _build_input_queue(self, global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): - is_training = split == 'train' ds = input_pipeline.get_lm_dataset( data_rng, split, data_dir, - is_training=is_training, vocab_size=self._vocab_size, global_batch_size=global_batch_size, num_batches=num_batches, @@ -103,4 +101,7 @@ def validation_target_value(): pass def target_metric_name(): - pass \ No newline at end of file + pass + + def eval_batch_size(): + pass diff --git a/submission_runner.py b/submission_runner.py index a2521e77b..6fac50d99 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -234,7 +234,7 @@ def train_once( dropout_rate = hyperparameters.dropout_rate if hasattr(hyperparameters, 'aux_dropout_rate'): aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( + model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ From e6194950fc524793906127f09b330a8329ad079f Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 17 Mar 2025 11:34:10 +0100 Subject: [PATCH 05/82] LM workload - fix torch tests --- .../lm/dev/test_build_input_queue_torch.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py index 86b1ca6b7..66205d091 100644 --- a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py @@ -41,30 +41,33 @@ def test_dataloader_torch(): data_dir=data_dir, global_batch_size=global_batch_size) - # batch = next(input_queue) - print(f"RANK {RANK} of {N_GPUS}") sync_ddp() # Start test. 
for _ in range(100): - + batch = next(input_queue) - assert type(batch) == dict + assert type(batch) == dict assert 'inputs' in batch assert 'targets' in batch - assert type(batch['inputs']) == torch.Tensor - assert type(batch['targets']) == torch.Tensor + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype - assert batch['inputs'].dtype == dtype - assert batch['targets'].dtype == dtype + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) - assert batch['inputs'].shape == (local_batch_size, seq_len) - assert batch['targets'].shape == (local_batch_size, seq_len) - - sync_ddp() + assert torch.equal(inputs[:,1:], targets[:,:-1]) print(f"=== ALL TEST PASSED ===") From d8e9c56738de817e561e79cffee638ab7197eaed Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:44:36 +0100 Subject: [PATCH 06/82] add LM tests, remove dev files --- algoperf/workloads/lm/dev/data_pytorch.py | 42 --------- algoperf/workloads/lm/dev/test_01.py | 92 ------------------- .../lm/dev/test_build_input_queue_torch.py | 83 ----------------- .../workloads/lm/dev/test_input_pipeline.py | 68 -------------- algoperf/workloads/lm/dev/test_jax.py | 54 ----------- 5 files changed, 339 deletions(-) delete mode 100644 algoperf/workloads/lm/dev/data_pytorch.py delete mode 100644 algoperf/workloads/lm/dev/test_01.py delete mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_torch.py delete mode 100644 algoperf/workloads/lm/dev/test_input_pipeline.py delete mode 100644 algoperf/workloads/lm/dev/test_jax.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py deleted file mode 100644 index d0081a75d..000000000 --- a/algoperf/workloads/lm/dev/data_pytorch.py +++ /dev/null @@ -1,42 +0,0 @@ - -import torch - -from datasets import Dataset, load_from_disk -from torch.utils.data import DataLoader - -trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" -vocab_size = 50280 -seq_len = 2048 -sampler = 'sequential' -sampler_seed = None -num_workers = 4 - -train_set = load_from_disk(trainset_path) # - -""" ->>> type(train_set) - - ->>> len(train_set) -7501407 - ->>> train_set[0] -{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} - ->>> type(train_set[0]['input_ids']) - - -# In PyTorch we do: -trainloader = DataLoader( - train_set, - sampler = ..., - batch_size = ..., - num_workers = ..., - pin_memory = ..., - ) - -# PyTorch’s DataLoader expects an iterable dataset, -# which means it calls __getitem__() and __len__() on train_set. 
- -""" - diff --git a/algoperf/workloads/lm/dev/test_01.py b/algoperf/workloads/lm/dev/test_01.py deleted file mode 100644 index 977fae11a..000000000 --- a/algoperf/workloads/lm/dev/test_01.py +++ /dev/null @@ -1,92 +0,0 @@ - -import os -import numpy as np -import tensorflow as tf -import torch - -from datasets import load_from_disk - -from absl import app -from absl import flags -from absl import logging - -from algoperf.profiler import PassThroughProfiler -from algoperf import random_utils as prng -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.input_pipeline import get_lm_dataset - - -tf.config.set_visible_devices([], 'GPU') - -# Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' -# (nico) -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - -flags.DEFINE_enum( - 'framework', - None, - enum_values=['jax', 'pytorch'], - help='Whether to use Jax or Pytorch for the submission. Controls among ' - 'other things if the Jax or Numpy RNG library is used for RNG.') - -FLAGS = flags.FLAGS -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - - -DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" -BATCH_SIZE = 2 -RNG_SEED = 1996 # Fixed random seed for reproducibility - - -def main(_): - profiler = PassThroughProfiler() - if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - - rng = prng.PRNGKey(RNG_SEED) - data_rng, _, _, _ = prng.split(rng, 4) - - print(f"data_rng = {data_rng}") - - # Load the dataset - ds = get_lm_dataset( - data_rng=data_rng, - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, - ) - # Check if `ds` acts as a generator - if hasattr(ds, '__iter__'): - print("Dataset is an iterable/generator.") - - # Fetch first batch - try: - first_batch = next(iter(ds)) - print(f"Successfully retrieved first batch.") - except Exception as e: - print(f"Error retrieving first batch: {e}") - return - - # Print structure of a batch - print(f"First batch keys: {first_batch.keys()}") - print(f"First batch shapes:") - for key, value in first_batch.items(): - print(f" - {key}: {value.shape} (dtype: {value.dtype})") - - # Validate batch dimensions - assert "inputs" in first_batch and "targets" in first_batch, "Missing expected keys!" - assert first_batch["inputs"].shape[0] == BATCH_SIZE, "Batch size mismatch!" - assert first_batch["inputs"].shape == first_batch["targets"].shape, "Inputs and targets should have the same shape!" 
- - print(f"Dataset is correctly batched and structured.") - print(f"Test completed successfully.") - -if __name__ == '__main__': - flags.mark_flag_as_required('framework') - app.run(main) diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py deleted file mode 100644 index 66205d091..000000000 --- a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py +++ /dev/null @@ -1,83 +0,0 @@ - -import jax -import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload - -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - -n_gpus = max(N_GPUS, jax.local_device_count()) - -def sync_ddp(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - -def test_dataloader_torch(): - # Test config. - rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' - split = 'train' - global_batch_size = 8 - dtype = torch.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - print(f"RANK {RANK} of {N_GPUS}") - sync_ddp() - - # Start test. - for _ in range(100): - - batch = next(input_queue) - - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:,1:], targets[:,:-1]) - - print(f"=== ALL TEST PASSED ===") - - -def main(): - profiler = PassThroughProfiler() - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - test_dataloader_torch() - - -if __name__ == '__main__': - main() - diff --git a/algoperf/workloads/lm/dev/test_input_pipeline.py b/algoperf/workloads/lm/dev/test_input_pipeline.py deleted file mode 100644 index 47c11969f..000000000 --- a/algoperf/workloads/lm/dev/test_input_pipeline.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import tensorflow as tf -import torch -from datasets import load_from_disk - -from algoperf.workloads.lm.input_pipeline import get_lm_dataset - -DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" -BATCH_SIZE = 2 -SEED = 42 # Fixed random seed for reproducibility - - -def test_tf_dataset(): - """Tests if get_lm_dataset correctly loads the HF dataset as a TensorFlow dataset.""" - - print(f"Loading dataset from: {DATASET_PATH}") - - tf_seed = SEED - - # Load the dataset - ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, - ) - - print("Testing TensorFlow Dataset Output...") - for batch in ds.take(2): # Take two batches to test - print("Inputs:", batch["inputs"].numpy()) # Convert to NumPy for inspection - print("Targets:", batch["targets"].numpy()) - -def 
test_pytorch_dataloader(): - """Tests if the TensorFlow dataset can be converted to PyTorch format correctly.""" - - # Use the same TensorFlow-compatible seed - tf_seed = tf.constant(SEED, dtype=tf.int64) - - # Load the dataset - ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, - global_batch_size=BATCH_SIZE, - ) - - def _input_queue_generator(): - """Generator that converts TF dataset batches to PyTorch tensors.""" - for batch in iter(ds): - batch = {k: torch.tensor(v.numpy()) for k, v in batch.items()} # Convert to PyTorch tensors - yield batch - - dataloader = _input_queue_generator() - - print("\nTesting PyTorch DataLoader Output...") - for _ in range(2): # Take two batches - batch = next(dataloader) - print("Inputs:", batch["inputs"]) - print("Targets:", batch["targets"]) - -# Run tests -if __name__ == "__main__": - test_tf_dataset() - test_pytorch_dataloader() \ No newline at end of file diff --git a/algoperf/workloads/lm/dev/test_jax.py b/algoperf/workloads/lm/dev/test_jax.py deleted file mode 100644 index 4ba3de631..000000000 --- a/algoperf/workloads/lm/dev/test_jax.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Test data pipaline in JAX and PyTorch. - -Instantiate a workload and loops over the input queue. -""" - -import jax -import numpy as np -import torch - -import algoperf.workloads.lm.lm_jax.workload as lm_jax -# import algoperf.workloads.lm.lm_pytorch.workload as lm_pytorch - - -data_rng = jax.random.PRNGKey(0) -split = 'train' -data_dir = "/fast/najroldi/data/finewebedu" -seq_len = 2048 -global_batch_size = 8 -num_batches = 10 -repeat_final_dataset = False - -# ------------------------------------------------------------------------------ -# JAX -# ------------------------------------------------------------------------------ - -# 1 GPU -workload = lm_jax.LmWorkload() - -input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) - -batch = next(input_queue) -assert type(batch) == dict - -assert 'inputs' in batch -assert 'targets' in batch - -assert type(batch['inputs']) == np.ndarray -assert type(batch['targets']) == np.ndarray - -assert batch['inputs'].dtype == np.int64 -assert batch['targets'].dtype == np.int64 - -assert batch['inputs'].shape == (1, global_batch_size, seq_len) -assert batch['targets'].shape == (1, global_batch_size, seq_len) - -print(f"JAX devices = {jax.devices()}") -print("1") From 6b4ff12356c5f41b01ce703801b556a11079d354 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:44:58 +0100 Subject: [PATCH 07/82] add LM tests, remove dev files --- algoperf/workloads/lm/dev/data_pytorch.py | 42 ++++++ .../lm/dev/test_build_input_queue_jax.py | 127 ++++++++++++++++++ .../lm/tests/test_build_input_queue_torch.py | 87 ++++++++++++ 3 files changed, 256 insertions(+) create mode 100644 algoperf/workloads/lm/dev/data_pytorch.py create mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_jax.py create mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_torch.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py new file mode 100644 index 000000000..d0081a75d --- /dev/null +++ b/algoperf/workloads/lm/dev/data_pytorch.py @@ -0,0 +1,42 @@ + +import torch + +from datasets import Dataset, load_from_disk +from torch.utils.data import DataLoader + 
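+# Dev-only scratch notes: hard-coded paths and DataLoader settings used while
+# prototyping the PyTorch input pipeline for the HF arrow dataset.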
+trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" +vocab_size = 50280 +seq_len = 2048 +sampler = 'sequential' +sampler_seed = None +num_workers = 4 + +train_set = load_from_disk(trainset_path) # + +""" +>>> type(train_set) + + +>>> len(train_set) +7501407 + +>>> train_set[0] +{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} + +>>> type(train_set[0]['input_ids']) + + +# In PyTorch we do: +trainloader = DataLoader( + train_set, + sampler = ..., + batch_size = ..., + num_workers = ..., + pin_memory = ..., + ) + +# PyTorch’s DataLoader expects an iterable dataset, +# which means it calls __getitem__() and __len__() on train_set. + +""" + diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py new file mode 100644 index 000000000..08354be74 --- /dev/null +++ b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py @@ -0,0 +1,127 @@ + +# TODO: redo with pmap!! + +import os +import jax +import tensorflow as tf +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_jax.workload import LmWorkload + +# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make +# it unavailable to JAX. +tf.config.set_visible_devices([], 'GPU') + +# Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' + + +N_GPUS = jax.local_device_count() + +print(f"jax.local_devices() = {jax.local_devices()}") +print(f"jax.local_device_count() = {jax.local_device_count()}") + +print(f"N_GPUS = {N_GPUS}") + +def check_batch(batch): + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:,1:], targets[:,:-1]) + + +def process_shard(batch): + inputs, targets = batch['inputs'], batch['targets'] + jax.debug.print("Processing on GPU with inputs: {shape}", shape=inputs.shape) + jax.debug.print("inputs {inputs}", inputs=inputs) + jax.debug.callback(check_batch, batch) + return inputs, targets + +# Apply process_batch across devices, sharding batch across devices +pmap_process = jax.pmap(process_shard, axis_name='batch') + + +def test_dataloader_jax(): + # Test config. 
+ rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = np.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + batch = next(input_queue) + + inputs, targets = batch['inputs'], batch['targets'] + print(f"Processing on GPU with inputs: {inputs.shape}") + + inputs, targets = pmap_process(batch) + print(f"Processing on GPU with inputs: {inputs.shape}") + print(f"Processing on GPU with inputs: {inputs}") + + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs[0]: {inputs[0]}") + # print(f"inputs[1]: {inputs[1]}") + + # for device_id in range(2): + # # Access the sharded data for each GPU + # print(inputs.shape) + # device_inputs = inputs[device_id] + # print(f" GPU {device_id} Inputs: {device_inputs.shape}") + + # @jax.pmap + # def process_batch(batch): + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + + # return inputs, targets + + # inputs, targets = batch['inputs'], batch['targets'] #process_batch(batch) + # print(f"inputs: {inputs[0]}") + + + +def main(): + test_dataloader_jax() + + +if __name__ == '__main__': + main() + diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py new file mode 100644 index 000000000..83a18ec15 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -0,0 +1,87 @@ + +import jax +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # batch = next(input_queue) + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs: {inputs}") + + # Start test. 
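  # Each batch should form a next-token-prediction pair: the last assert in the
  # loop below checks that `targets` equals `inputs` shifted left by one token.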
+ for _ in range(100): + + batch = next(input_queue) + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:,1:], targets[:,:-1]) + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() + From 3c5c847eb1489fa11a65c98c0f3327bd3c23c088 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:45:41 +0100 Subject: [PATCH 08/82] Stop tracking .gitignore --- .gitignore | 28 ---------------------------- 1 file changed, 28 deletions(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 7d35f0ccc..000000000 --- a/.gitignore +++ /dev/null @@ -1,28 +0,0 @@ -__pycache__/* -__pycache__ -*egg-info -*eggs -.vscode/ -env/ -venv/ -workdir/ -makefile -*.out -*.sh -*.swp -*/data/ -*events.out.tfevents* -algoperf/workloads/librispeech_conformer/data_dir -algoperf/workloads/librispeech_conformer/work_dir -*.flac -*.npy -*.csv -*.vocab -wandb/ -*.txt -scoring/plots/ - -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv - -algoperf/_version.py From 20d841b1932408bc905051dc2e188f3a43e0d749 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:47:55 +0100 Subject: [PATCH 09/82] Remove dev/ from repo, keep locally --- algoperf/workloads/lm/dev/data_pytorch.py | 42 ------ .../lm/dev/test_build_input_queue_jax.py | 127 ------------------ 2 files changed, 169 deletions(-) delete mode 100644 algoperf/workloads/lm/dev/data_pytorch.py delete mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_jax.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py deleted file mode 100644 index d0081a75d..000000000 --- a/algoperf/workloads/lm/dev/data_pytorch.py +++ /dev/null @@ -1,42 +0,0 @@ - -import torch - -from datasets import Dataset, load_from_disk -from torch.utils.data import DataLoader - -trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" -vocab_size = 50280 -seq_len = 2048 -sampler = 'sequential' -sampler_seed = None -num_workers = 4 - -train_set = load_from_disk(trainset_path) # - -""" ->>> type(train_set) - - ->>> len(train_set) -7501407 - ->>> train_set[0] -{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} - ->>> type(train_set[0]['input_ids']) - - -# In PyTorch we do: -trainloader = DataLoader( - train_set, - sampler = ..., - batch_size = ..., - num_workers = ..., - pin_memory = ..., - ) - -# PyTorch’s DataLoader expects an iterable dataset, -# which means it calls __getitem__() and __len__() on train_set. - -""" - diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py deleted file mode 100644 index 08354be74..000000000 --- a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py +++ /dev/null @@ -1,127 +0,0 @@ - -# TODO: redo with pmap!! 
- -import os -import jax -import tensorflow as tf -import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_jax.workload import LmWorkload - -# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make -# it unavailable to JAX. -tf.config.set_visible_devices([], 'GPU') - -# Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' - - -N_GPUS = jax.local_device_count() - -print(f"jax.local_devices() = {jax.local_devices()}") -print(f"jax.local_device_count() = {jax.local_device_count()}") - -print(f"N_GPUS = {N_GPUS}") - -def check_batch(batch): - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:,1:], targets[:,:-1]) - - -def process_shard(batch): - inputs, targets = batch['inputs'], batch['targets'] - jax.debug.print("Processing on GPU with inputs: {shape}", shape=inputs.shape) - jax.debug.print("inputs {inputs}", inputs=inputs) - jax.debug.callback(check_batch, batch) - return inputs, targets - -# Apply process_batch across devices, sharding batch across devices -pmap_process = jax.pmap(process_shard, axis_name='batch') - - -def test_dataloader_jax(): - # Test config. 
- rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' - split = 'train' - global_batch_size = 8 - dtype = np.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - batch = next(input_queue) - - inputs, targets = batch['inputs'], batch['targets'] - print(f"Processing on GPU with inputs: {inputs.shape}") - - inputs, targets = pmap_process(batch) - print(f"Processing on GPU with inputs: {inputs.shape}") - print(f"Processing on GPU with inputs: {inputs}") - - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - # print(f"inputs[0]: {inputs[0]}") - # print(f"inputs[1]: {inputs[1]}") - - # for device_id in range(2): - # # Access the sharded data for each GPU - # print(inputs.shape) - # device_inputs = inputs[device_id] - # print(f" GPU {device_id} Inputs: {device_inputs.shape}") - - # @jax.pmap - # def process_batch(batch): - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - - # return inputs, targets - - # inputs, targets = batch['inputs'], batch['targets'] #process_batch(batch) - # print(f"inputs: {inputs[0]}") - - - -def main(): - test_dataloader_jax() - - -if __name__ == '__main__': - main() - From f3ba0593d955c657b6da8a07eede425509dbc6b9 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 10:00:44 +0100 Subject: [PATCH 10/82] fix comments --- algoperf/workloads/lm/input_pipeline.py | 2 +- datasets/dataset_setup.py | 27 +++++++------------------ 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index e74490a16..bae1f5e45 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -38,7 +38,7 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, is_training = split == "train" shuffle = split in ['train', 'eval_train'] - dataset.set_format("tensorflow") # tf.int64 # TODO: is this needed? + dataset.set_format("tensorflow") # tf.int64 # TODO (nico): is this needed? 
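  # set_format("tensorflow") makes the formatted columns come back as
  # tf.Tensors when examples are indexed or iterated; since tf_generator below
  # only slices example["input_ids"] and from_generator fixes the dtype via
  # output_signature, the default (numpy) format would likely work as well.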
def tf_generator(): """Generates data in a TensorFlow-friendly format.""" diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index aab793832..8299133c1 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -711,8 +711,6 @@ def download_wmt(data_dir): def download_finewebedu(data_dir, tmp_dir): """Download FineWebEdu-10B.""" - # data_dir = "/fast/najroldi/data" - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") data_dir = os.path.join(data_dir, 'finewebedu') @@ -726,7 +724,7 @@ def download_finewebedu(data_dir, tmp_dir): 'HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train', - # cache_dir=tmp_dir + cache_dir=tmp_dir ) ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size @@ -756,19 +754,11 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - from datasets import load_from_disk - tokenized_dataset = load_from_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) # Concat in chunks of max_seq_len - # TODO: this might take to much memory - # TODO: bug fix: Python's shutil.rmtree tried to delete a .nfs* file, but it was still in use (OSError: [Errno 16] Device or resource busy - # TODO: bug fix: I am losing tokens in the concat-chunk: num_tokens before split: 9_944_182_212 - # (1) loss happening because of batched=True: potentially losing the last tokens in the last batch of the 1024 batched examples - # NOTE: the current approach leads to data loss at batch boundaries, - # but concatenation *cannot* happen if batched=False, - # because concat_chunck relies on processing multiple examples at once. - # (2) loss happening because of nproc>1: potentially losing the last tokens in each process - # TODO: this does not allow to later change the seq_len... 
not a problem in AlgoPerf, but bad in plainLM + # TODO (nico): this might take to much memory + # TODO (nico): bug fix: Python's shutil.rmtree tried to delete .nfs file, but it was still in use (OSError: [Errno 16] Device or resource busy + # TODO (nico): make it sequential or increase batch_size in the map_setup def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -780,15 +770,12 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map( - concat_chunck,\ - **map_setup - ) - n_tokens = len(lm_dataset) * max_seq_length # 9_944_182_212 + lm_dataset = tokenized_dataset.map(concat_chunck, **map_setup) + n_tokens = len(lm_dataset) * max_seq_length logging.info(f"Number of tokens in dataset: {n_tokens:_}") # Split dataset into training and validation sets - # TODO: avoid (single doc) contamination between train and val + # TODO (nico): avoid (single doc) contamination, by splitting before concatenation VAL_TOKENS = 10_000_000 val_samples = VAL_TOKENS // max_seq_length + 1 val_dataset = lm_dataset.select(range(val_samples)) From 381451f04a34e4a78a5256f92e1e7c092e0eadeb Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 10:46:45 +0100 Subject: [PATCH 11/82] add class specifications --- algoperf/workloads/lm/lm_jax/workload.py | 36 +++- algoperf/workloads/lm/lm_pytorch/workload.py | 26 ++- algoperf/workloads/lm/workload.py | 201 +++++++++++++------ datasets/dataset_setup.py | 6 +- 4 files changed, 199 insertions(+), 70 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 773f8c54c..84377b4bc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,17 +1,47 @@ """LM workload implemented in Jax.""" import functools -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple +from absl import logging from flax import jax_utils +from flax import linen as nn +from flax.training import common_utils import jax import jax.numpy as jnp import numpy as np +import optax from algoperf import param_utils +from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload - class LmWorkload(BaseLmWorkload): - pass + """LM JAX workload.""" + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """aux_dropout_rate is used as attention_dropout_rate.""" + pass + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + pass + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + pass diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 0ff7884c7..404dc2532 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ 
b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -23,6 +23,24 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """aux_dropout_rate is used as attention_dropout_rate.""" + pass + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + pass + def _build_input_queue(self, data_rng: jax.random.PRNGKey, split: str, @@ -93,6 +111,10 @@ def _build_input_queue(self, } yield batch - - def eval_step(): + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" pass diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 7b1313dd7..e36d54625 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -5,6 +5,9 @@ import os from typing import Any, Dict, Optional, Tuple +from absl import flags +import torch.distributed as dist + import jax import numpy as np import torch @@ -12,17 +15,98 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +FLAGS = flags.FLAGS + USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ class BaseLmWorkload(spec.Workload): - """A LM workload.""" + """LM workload.""" _vocab_size: int = 32000 def __init__(self) -> None: super().__init__() - self._tokenizer = None + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return 'ppl' + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result['validation/ppl'] > self.validation_target_value + + @property + def validation_target_value(self) -> float: + pass + + def has_reached_test_target(self, eval_result: float) -> bool: + return eval_result['test/ppl'] > self.test_target_value + + @property + def test_target_value(self) -> float: + pass + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + pass + + @property + def num_eval_train_examples(self) -> int: + pass + + @property + def num_validation_examples(self) -> int: + pass + + @property + def num_test_examples(self) -> int: + pass + + @property + def eval_batch_size(self) -> int: + pass + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + pass + + @property + def eval_period_time_sec(self) -> int: + pass + + @property + def step_hint(self) -> int: + """Approx. 
steps the baseline can do in the allowed runtime budget.""" + pass + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return 'silu' + + @property + def glu(self) -> bool: + return True def _build_input_queue(self, data_rng: jax.random.PRNGKey, @@ -43,65 +127,58 @@ def _build_input_queue(self, for batch in iter(ds): yield batch - def _eval_model_on_split(): - pass - - def eval_period_time_sec(): - pass - - def has_reached_test_target(): - pass - - def has_reached_validation_target(): - pass - - def init_model_fn(): - pass - - def is_output_params(): - pass - - def loss_fn(): - pass - - def loss_type(): - pass - - def max_allowed_runtime_sec(): - pass - - def model_fn(): - pass - - def num_eval_train_examples(): - pass - - def num_test_examples(): - pass - - def num_train_examples(): - pass - - def num_validation_examples(): - pass - - def step_hint(): - pass - - def test_target_value(): - pass - - def train_mean(): - pass - - def train_stddev(): - pass - - def validation_target_value(): - pass - - def target_metric_name(): - pass + @abc.abstractmethod + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True) + + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + loss += self._eval_batch(params, eval_batch) + if USE_PYTORCH_DDP: + dist.all_reduce(loss) + mean_loss = loss.item() / num_examples + return {'loss': mean_loss} - def eval_batch_size(): + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). 
+ """ pass + + + diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 8299133c1..fb8701f4d 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -711,11 +711,11 @@ def download_wmt(data_dir): def download_finewebedu(data_dir, tmp_dir): """Download FineWebEdu-10B.""" - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") data_dir = os.path.join(data_dir, 'finewebedu') - - _maybe_mkdir(tmp_dir) + tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None \ + else os.path.expanduser("~/.cache/huggingface/datasets") _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) # Use local disk instead of NFS for temp storage os.environ["TMPDIR"] = tmp_dir From f111d2e8baada7af619504a87974fa78f3e34d55 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 11:29:37 +0100 Subject: [PATCH 12/82] add workload LM info --- algoperf/workloads/workloads.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4712f4e25..6b99a25a6 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -114,6 +114,7 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, + 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' }, @@ -150,6 +151,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', + 'lm', 'ogbg', 'wmt' ] From 808d398ee2cf78e92cea29e2d0696eb6ce592929 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 11:32:48 +0100 Subject: [PATCH 13/82] restore data_utils.py tree map --- algoperf/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 068c21c03..37d1bd20f 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. 
return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree_util.tree_map(_prepare, batch) + return jax.tree.map(_prepare, batch) def pad(tensor: np.ndarray, From 35f8f8942cb993628f1b20c3d29346e4d7b40e95 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 14:38:41 +0100 Subject: [PATCH 14/82] fixed NFS bug --- datasets/dataset_setup.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index fb8701f4d..a68da3ff5 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -708,26 +708,28 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) -def download_finewebedu(data_dir, tmp_dir): +def download_finewebedu(data_dir, tmp_dir=None): """Download FineWebEdu-10B.""" data_dir = os.path.join(data_dir, 'finewebedu') - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None \ - else os.path.expanduser("~/.cache/huggingface/datasets") + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets') + _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) - # Use local disk instead of NFS for temp storage os.environ["TMPDIR"] = tmp_dir ds = hf_datasets.load_dataset( 'HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train', - cache_dir=tmp_dir + cache_dir=cache_dir ) - ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size + # Shuffle so that multiproc has shards of similar size. + ds = ds.shuffle(seed=1996) seq_len = 2048 max_seq_length = seq_len+1 @@ -754,11 +756,8 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - + # Concat in chunks of max_seq_len - # TODO (nico): this might take to much memory - # TODO (nico): bug fix: Python's shutil.rmtree tried to delete .nfs file, but it was still in use (OSError: [Errno 16] Device or resource busy - # TODO (nico): make it sequential or increase batch_size in the map_setup def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} From cbb6ee67c6eb4828b574987d45fde508e5f1db67 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 15:02:27 +0100 Subject: [PATCH 15/82] train/val split before concat --- datasets/dataset_setup.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index a68da3ff5..5e27211e8 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -756,8 +756,21 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - - # Concat in chunks of max_seq_len + + # Find how many entries to take from dataset to have VAL_TOKENS in validation set. + VAL_TOKENS = 10_000_000 + tokens_accumulated, num_examples_for_val = 0, 0 + for example in tokenized_dataset: + tokens_accumulated += len(example['input_ids']) + num_examples_for_val += 1 + if tokens_accumulated >= VAL_TOKENS: + break + # Split in train and valid. 
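  # The first num_examples_for_val documents (roughly VAL_TOKENS tokens) become
  # the validation set and the rest the training set; splitting before the
  # concat/chunk step keeps any single document from spanning both splits.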
+ val_dataset = tokenized_dataset.select(range(num_examples_for_val)) + train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset))) + + # Concat in chunks of max_seq_len. + # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -769,18 +782,11 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map(concat_chunck, **map_setup) - n_tokens = len(lm_dataset) * max_seq_length - logging.info(f"Number of tokens in dataset: {n_tokens:_}") - - # Split dataset into training and validation sets - # TODO (nico): avoid (single doc) contamination, by splitting before concatenation - VAL_TOKENS = 10_000_000 - val_samples = VAL_TOKENS // max_seq_length + 1 - val_dataset = lm_dataset.select(range(val_samples)) - train_dataset = lm_dataset.select(range(val_samples, len(lm_dataset))) - logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length :_}") - logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length :_}") + # Concat text in validation and train sets. + val_dataset = val_dataset.map(concat_chunck, **map_setup) + train_dataset = train_dataset.map(concat_chunck, **map_setup) + logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") + logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}") # Save datasets train_dataset.save_to_disk(os.path.join(data_dir, f"train")) From 868987c2fd72ced8107048e20de44a7e303074e8 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 19 Mar 2025 09:41:05 +0100 Subject: [PATCH 16/82] renamed datasets to avoid conflict with HF --- {datasets => datasets_algoperf}/README.md | 0 .../dataset_setup.py | 17 ++++++++++------- .../librispeech_preprocess.py | 2 +- .../librispeech_tokenizer.py | 0 4 files changed, 11 insertions(+), 8 deletions(-) rename {datasets => datasets_algoperf}/README.md (100%) rename {datasets => datasets_algoperf}/dataset_setup.py (98%) rename {datasets => datasets_algoperf}/librispeech_preprocess.py (98%) rename {datasets => datasets_algoperf}/librispeech_tokenizer.py (100%) diff --git a/datasets/README.md b/datasets_algoperf/README.md similarity index 100% rename from datasets/README.md rename to datasets_algoperf/README.md diff --git a/datasets/dataset_setup.py b/datasets_algoperf/dataset_setup.py similarity index 98% rename from datasets/dataset_setup.py rename to datasets_algoperf/dataset_setup.py index 5e27211e8..21811e729 100644 --- a/datasets/dataset_setup.py +++ b/datasets_algoperf/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 datasets_algoperf/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -126,15 +126,15 @@ flags.DEFINE_boolean('fastmri', False, 'If --all=false, whether or not to download FastMRI.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('imagenet', False, 'If --all=false, whether or not to download Imagenet.') flags.DEFINE_boolean('librispeech', False, 'If --all=false, whether or not to download LibriSpeech.') 
-flags.DEFINE_boolean('finewebedu', - False, - 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('mnist', False, 'If --all=false, whether or not to download MNIST.') @@ -727,6 +727,8 @@ def download_finewebedu(data_dir, tmp_dir=None): split='train', cache_dir=cache_dir ) + # TODO (nico): maybe save intermediate dataset to avoid re-downloading + # and allow re-chunking with different seq_len? # Shuffle so that multiproc has shards of similar size. ds = ds.shuffle(seed=1996) @@ -747,6 +749,7 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return_attention_mask=False ) tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info(f"Tokenizing...") tokenized_dataset = ds.map( tokenize, remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', @@ -783,6 +786,7 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: } return result # Concat text in validation and train sets. + logging.info(f"Concatenating and chunking...") val_dataset = val_dataset.map(concat_chunck, **map_setup) train_dataset = train_dataset.map(concat_chunck, **map_setup) logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") @@ -876,9 +880,8 @@ def main(_): download_wmt(data_dir) if FLAGS.all or FLAGS.finewebedu: - if not FLAGS.skip_download: - logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir) + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir, tmp_dir) # pylint: enable=logging-format-interpolation diff --git a/datasets/librispeech_preprocess.py b/datasets_algoperf/librispeech_preprocess.py similarity index 98% rename from datasets/librispeech_preprocess.py rename to datasets_algoperf/librispeech_preprocess.py index a8c5cae1d..cd291e5b3 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets_algoperf/librispeech_preprocess.py @@ -15,7 +15,7 @@ from pydub import AudioSegment import tensorflow as tf -from datasets import librispeech_tokenizer +from datasets_algoperf import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/datasets_algoperf/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to datasets_algoperf/librispeech_tokenizer.py From dd59dedc97f99e994221775b1e980d845bfb908c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 19 Mar 2025 09:55:11 +0100 Subject: [PATCH 17/82] renamed datasets to dataset --- {datasets_algoperf => dataset}/README.md | 0 {datasets_algoperf => dataset}/dataset_setup.py | 6 +++--- {datasets_algoperf => dataset}/librispeech_preprocess.py | 2 +- {datasets_algoperf => dataset}/librispeech_tokenizer.py | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename {datasets_algoperf => dataset}/README.md (100%) rename {datasets_algoperf => dataset}/dataset_setup.py (99%) rename {datasets_algoperf => dataset}/librispeech_preprocess.py (98%) rename {datasets_algoperf => dataset}/librispeech_tokenizer.py (100%) diff --git a/datasets_algoperf/README.md b/dataset/README.md similarity index 100% rename from datasets_algoperf/README.md rename to dataset/README.md diff --git a/datasets_algoperf/dataset_setup.py b/dataset/dataset_setup.py similarity index 99% rename from datasets_algoperf/dataset_setup.py rename to dataset/dataset_setup.py index 21811e729..0c7f33de6 100644 --- a/datasets_algoperf/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -56,7 +56,7 @@ Example command: 
-python3 datasets_algoperf/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -74,8 +74,8 @@ from algoperf.workloads.wmt import tokenizer from algoperf.workloads.wmt.input_pipeline import \ normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from dataset import librispeech_preprocess +from dataset import librispeech_tokenizer import datasets as hf_datasets from transformers import AutoTokenizer diff --git a/datasets_algoperf/librispeech_preprocess.py b/dataset/librispeech_preprocess.py similarity index 98% rename from datasets_algoperf/librispeech_preprocess.py rename to dataset/librispeech_preprocess.py index cd291e5b3..b96881332 100644 --- a/datasets_algoperf/librispeech_preprocess.py +++ b/dataset/librispeech_preprocess.py @@ -15,7 +15,7 @@ from pydub import AudioSegment import tensorflow as tf -from datasets_algoperf import librispeech_tokenizer +from dataset import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets_algoperf/librispeech_tokenizer.py b/dataset/librispeech_tokenizer.py similarity index 100% rename from datasets_algoperf/librispeech_tokenizer.py rename to dataset/librispeech_tokenizer.py From 496b9c31f0bdd9a50e18a6907146969fd98e73cf Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 10:52:54 +0100 Subject: [PATCH 18/82] fix style --- .gitignore | 28 +++++++++++ algoperf/workloads/lm/input_pipeline.py | 50 ++++++++----------- algoperf/workloads/lm/lm_jax/workload.py | 15 +----- algoperf/workloads/lm/lm_pytorch/workload.py | 46 +++++++++-------- .../lm/tests/test_build_input_queue_torch.py | 18 +++---- algoperf/workloads/lm/workload.py | 12 ++--- 6 files changed, 86 insertions(+), 83 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..916a29ff4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +__pycache__/* +__pycache__ +*egg-info +*eggs +.vscode/ +env/ +venv/ +workdir/ +makefile +*.out +*.sh +*.swp +*/data/ +*events.out.tfevents* +algoperf/workloads/librispeech_conformer/data_dir +algoperf/workloads/librispeech_conformer/work_dir +*.flac +*.npy +*.csv +*.vocab +wandb/ +*.txt +scoring/plots/ + +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv + +algoperf/_version.py \ No newline at end of file diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index bae1f5e45..53fe79276 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -1,24 +1,22 @@ """Input pipeline for a LM dataset.""" import functools import os +from typing import Optional -from datasets import Dataset, load_from_disk -from typing import Dict, List, Optional, Union - +from datasets import load_from_disk import jax -import numpy as np import tensorflow as tf -import tensorflow_datasets as tfds from algoperf import data_utils from algoperf.pytorch_utils import pytorch_setup RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). -# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# This ensures that only the primary process (RANK == 0) uses TensorFlow's # automatic optimization (AUTOTUNE), while other processes disable it (None). 
-# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine the optimal -# number of elements to prefetch or parallelize for dataset operations, improving performance. +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine +# the optimal number of elements to prefetch or parallelize for dataset +# operations, improving performance. AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None @@ -44,25 +42,24 @@ def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: yield { - "inputs": example["input_ids"][:-1], - "targets": example["input_ids"][1:], + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], } # Create a TensorFlow dataset ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), - } - ) + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + }) # Avoid creating too many threads when using PyTorch DDP. # Limits TensorFlow's threading for non-primary processes (RANK != 0) - if RANK != 0: + if RANK != 0: options = tf.data.Options() - options.threading.private_threadpool_size = 1 # restrict dataset operations to a single thread - ds = ds.with_options(options) # apply threading restrictions + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) if shuffle: ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) @@ -70,10 +67,7 @@ def tf_generator(): if is_training: ds = ds.repeat() - # Batch the dataset, ensuring the last batch is dropped if not full during training - # i.e. it groups consecutive elements into fixed-size chunks. - # Instead of processing individual elements, the dataset yields batches (tensors with multiple elements), - # improving efficiency and parallelism in training + # Batch the dataset, grouping consecutive elements into fixed-size chunks. 
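  # With drop_remainder=is_training, a trailing partial batch is dropped during
  # training but kept for eval splits, where shard_and_maybe_pad_np can pad it
  # up to a full device batch.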
ds = ds.batch(global_batch_size, drop_remainder=is_training) ds = ds.prefetch(AUTOTUNE) @@ -83,9 +77,9 @@ def tf_generator(): # Shard the dataset across multiple GPUs/TPUs if necessary ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) - return ds \ No newline at end of file + return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 84377b4bc..64d538dda 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,22 +1,11 @@ """LM workload implemented in Jax.""" -import functools -from typing import Any, Dict, Iterator, Optional, Tuple +from typing import Dict, Optional, Tuple -from absl import logging -from flax import jax_utils -from flax import linen as nn -from flax.training import common_utils -import jax -import jax.numpy as jnp -import numpy as np -import optax - -from algoperf import param_utils -from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload + class LmWorkload(BaseLmWorkload): """LM JAX workload.""" diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 404dc2532..e57d26390 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -3,16 +3,10 @@ import contextlib from typing import Dict, Iterator, Optional, Tuple -from absl import logging import jax -import tensorflow as tf import torch import torch.distributed as dist -from torch.nn import DataParallel as DP -import torch.nn.functional as F -from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload @@ -41,16 +35,17 @@ def model_fn( update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: pass - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) - + seq_len = 2048 # TODO: define it somewehere else DTYPE = torch.int32 # TODO: decide between int32 and int64. @@ -65,20 +60,25 @@ def _build_input_queue(self, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) weights = None - + while True: # Only iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return - inputs = torch.as_tensor(batch['inputs'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - targets = torch.as_tensor(batch['targets'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + inputs = torch.as_tensor( + batch['inputs'], dtype=DTYPE, + device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + targets = torch.as_tensor( + batch['targets'], dtype=DTYPE, + device=DEVICE) # (N_GPUS, global_batch_size, seq_len) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: if not_train: # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.tensor(len(targets[0]), dtype=DTYPE, device=DEVICE) + per_device_batch_size = torch.tensor( + len(targets[0]), dtype=DTYPE, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # We don't broadcast the shard for RANK 0. dist.broadcast(inputs[1:], src=0) @@ -95,12 +95,16 @@ def _build_input_queue(self, dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the shard for RANK 0. - inputs = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) - targets = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) + inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), + dtype=DTYPE, + device=DEVICE) + targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), + dtype=DTYPE, + device=DEVICE) dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) # RANK - 1 since we don't broadcast the shard for RANK 0. - inputs, targets = inputs[RANK-1], targets[RANK-1] + inputs, targets = inputs[RANK - 1], targets[RANK - 1] if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py index 83a18ec15..639e71491 100644 --- a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -1,11 +1,6 @@ - import jax import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec + from algoperf.profiler import PassThroughProfiler from algoperf.pytorch_utils import pytorch_init from algoperf.pytorch_utils import pytorch_setup @@ -29,20 +24,20 @@ def test_dataloader_torch(): seq_len = 2048 local_batch_size = global_batch_size // N_GPUS - + workload = LmWorkload() data_rng = jax.random.PRNGKey(rng_seed) - + input_queue = workload._build_input_queue( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=global_batch_size) - + print(f"RANK {RANK} of {N_GPUS}") sync_ddp() - + # batch = next(input_queue) # inputs, targets = batch['inputs'], batch['targets'] # print(f"inputs.shape: {inputs.shape}") @@ -71,7 +66,7 @@ def test_dataloader_torch(): assert inputs.shape == (local_batch_size, seq_len) assert targets.shape == (local_batch_size, seq_len) - assert torch.equal(inputs[:,1:], targets[:,:-1]) + assert torch.equal(inputs[:, 1:], targets[:, :-1]) print(f"=== ALL TEST PASSED ===") @@ -84,4 +79,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index e36d54625..3d04be3c5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -3,14 +3,11 @@ import abc import math import os -from typing import Any, Dict, 
Optional, Tuple +from typing import Dict, Optional from absl import flags -import torch.distributed as dist - import jax -import numpy as np -import torch +import torch.distributed as dist from algoperf import spec from algoperf.workloads.lm import input_pipeline @@ -155,7 +152,7 @@ def _eval_model_on_split(self, global_batch_size, num_batches, repeat_final_dataset=True) - + for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) loss += self._eval_batch(params, eval_batch) @@ -179,6 +176,3 @@ def loss_fn( (not synced across devices). """ pass - - - From 50989eb6a8a54c43225a4243f770a4419d431a81 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 10:57:06 +0100 Subject: [PATCH 19/82] fix formatting --- algoperf/workloads/lm/lm_pytorch/workload.py | 1 - submission_runner.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e57d26390..be6c94c46 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,6 +1,5 @@ """LM workload implemented in PyTorch.""" -import contextlib from typing import Dict, Iterator, Optional, Tuple import jax diff --git a/submission_runner.py b/submission_runner.py index d7df006bb..f8a66452d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -234,7 +234,7 @@ def train_once( dropout_rate = hyperparameters.dropout_rate if hasattr(hyperparameters, 'aux_dropout_rate'): aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( + model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ @@ -384,8 +384,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). if prepare_for_eval is not None: From 5af0fdc1437d924e2e162de5100e66782d01a7e5 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:02:22 +0100 Subject: [PATCH 20/82] fix style --- algoperf/workloads/lm/lm_pytorch/workload.py | 16 ++++++++-------- algoperf/workloads/lm/workload.py | 1 + 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index be6c94c46..606f16ad7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -45,8 +45,8 @@ def _build_input_queue( not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) - seq_len = 2048 # TODO: define it somewehere else - DTYPE = torch.int32 # TODO: decide between int32 and int64. + seq_len = self._seq_len # TODO: define it somewehere else? + dtype = torch.int32 # TODO: decide between int32 and int64. # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
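The hunks that follow reformat the rank-0 read-and-broadcast path of
_build_input_queue. As a reading aid, here is a minimal sketch of that pattern,
not part of the patch itself; the function name get_local_shard, the fixed
int32 dtype, and the argument list are illustrative, and torch.distributed is
assumed to be initialized as in pytorch_utils.pytorch_setup().

import torch
import torch.distributed as dist

def get_local_shard(global_batch, rank, n_gpus, per_device_bs, seq_len, device):
  """Rank 0 keeps shard 0 and broadcasts the rest; other ranks take row rank-1."""
  if rank == 0:
    # global_batch: (n_gpus, per_device_bs, seq_len) tensor already on `device`.
    dist.broadcast(global_batch[1:], src=0)
    return global_batch[0]
  buf = torch.empty((n_gpus - 1, per_device_bs, seq_len),
                    dtype=torch.int32,
                    device=device)
  dist.broadcast(buf, src=0)
  return buf[rank - 1]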
@@ -66,10 +66,10 @@ def _build_input_queue( if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( - batch['inputs'], dtype=DTYPE, + batch['inputs'], dtype=dtype, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) targets = torch.as_tensor( - batch['targets'], dtype=DTYPE, + batch['targets'], dtype=dtype, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) # Send batch to other devices when using DDP. @@ -77,7 +77,7 @@ def _build_input_queue( if not_train: # During eval, the batch size of the remainder might be different. per_device_batch_size = torch.tensor( - len(targets[0]), dtype=DTYPE, device=DEVICE) + len(targets[0]), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # We don't broadcast the shard for RANK 0. dist.broadcast(inputs[1:], src=0) @@ -90,15 +90,15 @@ def _build_input_queue( # Receive batch from rank 0. if not_train: # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE) + per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the shard for RANK 0. inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=DTYPE, + dtype=dtype, device=DEVICE) targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=DTYPE, + dtype=dtype, device=DEVICE) dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 3d04be3c5..aa6d188b3 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -21,6 +21,7 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 32000 + _seq_len: int = 2048 def __init__(self) -> None: super().__init__() From 26830999b92d26c729171cae141ee7abb3409463 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:32:47 +0100 Subject: [PATCH 21/82] fix style --- algoperf/workloads/lm/workload.py | 2 +- dataset/dataset_setup.py | 91 +++++++++++++++++++------------ 2 files changed, 56 insertions(+), 37 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index aa6d188b3..4eb6c74a5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -24,7 +24,7 @@ class BaseLmWorkload(spec.Workload): _seq_len: int = 2048 def __init__(self) -> None: - super().__init__() + pass @property def target_metric_name(self) -> str: diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 0c7f33de6..8f0b09ab7 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -80,7 +80,6 @@ import datasets as hf_datasets from transformers import AutoTokenizer -import math import functools import itertools import os @@ -713,7 +712,9 @@ def download_finewebedu(data_dir, tmp_dir=None): data_dir = os.path.join(data_dir, 'finewebedu') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets') + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) @@ -722,75 +723,93 @@ def download_finewebedu(data_dir, tmp_dir=None): os.environ["TMPDIR"] = tmp_dir ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - 
split='train', - cache_dir=cache_dir - ) - # TODO (nico): maybe save intermediate dataset to avoid re-downloading + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + # TODO (nico): maybe save intermediate dataset to avoid re-downloading # and allow re-chunking with different seq_len? # Shuffle so that multiproc has shards of similar size. ds = ds.shuffle(seed=1996) seq_len = 2048 - max_seq_length = seq_len+1 + max_seq_length = seq_len + 1 map_setup = dict(batched=True, batch_size=1024, num_proc=8) # Tokenize - tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of tokenizer = {len(tokenizer)}") + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] - return tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False - ) - tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization logging.info(f"Tokenizing...") tokenized_dataset = ds.map( - tokenize, - remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', - 'language_score', 'token_count', 'score', 'int_score'], - **map_setup - ) - tokenizer.model_max_length = seq_len - + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score' + ], + **map_setup) + lm_tokenizer.model_max_length = seq_len + tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - # Find how many entries to take from dataset to have VAL_TOKENS in validation set. - VAL_TOKENS = 10_000_000 + # Find how many entries to take from dataset to have val_tokens in validation set. + val_tokens = 10_000_000 # TODO: decide this value. tokens_accumulated, num_examples_for_val = 0, 0 for example in tokenized_dataset: tokens_accumulated += len(example['input_ids']) num_examples_for_val += 1 - if tokens_accumulated >= VAL_TOKENS: - break + if tokens_accumulated >= val_tokens: + break # Split in train and valid. val_dataset = tokenized_dataset.select(range(num_examples_for_val)) - train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset))) + train_dataset = tokenized_dataset.select( + range(num_examples_for_val, len(tokenized_dataset))) # Concat in chunks of max_seq_len. # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. 
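  # Illustrative arithmetic for the NOTE above: with max_seq_length = 2049 and
  # a map() batch whose documents concatenate to 10_000 tokens, total_length is
  # rounded down to 4 * 2049 = 8_196, so the last 1_804 tokens of that batch
  # are discarded.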
   def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
     """Concatenate text and generate chunks of max_seq_length"""
-    concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
+    concatenated_examples = {
+        k: list(itertools.chain(*examples[k])) for k in examples.keys()
+    }
     total_length = len(concatenated_examples[list(examples.keys())[0]])
     if total_length >= max_seq_length:
-      total_length = (total_length // max_seq_length) * max_seq_length
+      total_length = (total_length // max_seq_length) * max_seq_length
     result = {
-      k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
-      for k, t in concatenated_examples.items()
+        k: [
+            t[i:i + max_seq_length]
+            for i in range(0, total_length, max_seq_length)
+        ] for k, t in concatenated_examples.items()
     }
     return result
+
   # Concat text in validation and train sets.
   logging.info(f"Concatenating and chunking...")
   val_dataset = val_dataset.map(concat_chunck, **map_setup)
   train_dataset = train_dataset.map(concat_chunck, **map_setup)
-  logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}")
-  logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}")
+  logging.info(
+      f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}")
+  logging.info(
+      f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}"
+  )

   # Save datasets
   train_dataset.save_to_disk(os.path.join(data_dir, f"train"))

From 6b7ee29684ee9bf1f9564032f65c09373212c4a4 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 20 Mar 2025 11:36:27 +0100
Subject: [PATCH 22/82] fix yapf

---
 submission_runner.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/submission_runner.py b/submission_runner.py
index f8a66452d..468a04c7c 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -384,8 +384,8 @@ def train_once(
         train_step_end_time - train_state['last_step_end_time'])

     # Check if submission is eligible for an untimed eval.
-    if ((train_step_end_time - train_state['last_eval_time'])
-        >= workload.eval_period_time_sec or train_state['training_complete']):
+    if ((train_step_end_time - train_state['last_eval_time']) >=
+        workload.eval_period_time_sec or train_state['training_complete']):
       # Prepare for evaluation (timed).
       if prepare_for_eval is not None:
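For context on the chunking above: concat_chunck concatenates tokenized documents and cuts them into blocks of max_seq_length = seq_len + 1 tokens, so that each block can later be split into an input/target pair shifted by one token for next-token prediction (the same seq_len + 1 convention the HF data loader below relies on when it slices tokens[:seq_len + 1] and then takes [:-1] / [1:]). A minimal, self-contained sketch of that relationship follows; token_stream and the sizes are made up for illustration and are not part of the patches:

seq_len = 8
max_seq_length = seq_len + 1  # one extra token so inputs/targets can be shifted views

# Hypothetical stream of token ids standing in for the concatenated documents.
token_stream = list(range(30))

# Drop leftover tokens that do not fill a whole chunk, as concat_chunck does.
total_length = (len(token_stream) // max_seq_length) * max_seq_length
chunks = [
    token_stream[i:i + max_seq_length]
    for i in range(0, total_length, max_seq_length)
]

for chunk in chunks:
  inputs, targets = chunk[:-1], chunk[1:]  # each has exactly seq_len tokens
  assert len(inputs) == seq_len and len(targets) == seq_len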
From 46b645b2ac4a4f4b93fe4ee6324b07f412fb81b3 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 20 Mar 2025 11:38:40 +0100
Subject: [PATCH 23/82] fix style

---
 dataset/dataset_setup.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py
index 8f0b09ab7..6587f1439 100644
--- a/dataset/dataset_setup.py
+++ b/dataset/dataset_setup.py
@@ -797,7 +797,8 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
         k: [
             t[i:i + max_seq_length]
             for i in range(0, total_length, max_seq_length)
-        ] for k, t in concatenated_examples.items()
+        ] for k,
+        t in concatenated_examples.items()
     }
     return result

From b3ae6474be93f07c578f885bae484773b8a65515 Mon Sep 17 00:00:00 2001
From: rka97
Date: Thu, 27 Mar 2025 15:56:25 +0000
Subject: [PATCH 24/82] HF datasets pipeline

---
 algoperf/workloads/lm/input_pipeline.py       |  75 ++++++++++-
 .../lm/tests/test_hf_input_pipeline.py        | 116 ++++++++++++++++++
 2 files changed, 190 insertions(+), 1 deletion(-)
 create mode 100644 algoperf/workloads/lm/tests/test_hf_input_pipeline.py

diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py
index 53fe79276..ea4cb9d63 100644
--- a/algoperf/workloads/lm/input_pipeline.py
+++ b/algoperf/workloads/lm/input_pipeline.py
@@ -3,12 +3,17 @@
 import os
 from typing import Optional

-from datasets import load_from_disk
 import jax
+import jax.numpy as jnp
 import tensorflow as tf
+import torch
+import torch.nn.functional as F
+from transformers import GPT2Tokenizer

 from algoperf import data_utils
 from algoperf.pytorch_utils import pytorch_setup
+from datasets import load_dataset
+from datasets import load_from_disk

 RANK = pytorch_setup()[1]
 # Avoid multithreading in all processes but the first (rank 0).
@@ -20,6 +25,74 @@
 AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None


+def get_hf_dataloader(cache_dir: str,
+                      data_rng: jax.random.PRNGKey,
+                      batch_size: int = 8,
+                      seq_len: int = 32,
+                      framework: str = "torch",
+                      split="train"):
+  """
+  Create a data loader from HuggingFace's FineWeb dataset.
+ + Args: + cache_dir: Directory to cache the dataset + batch_size: Number of sequences per batch + seq_len: Length of each sequence + framework: Either "torch" or "jax" to specify output tensor type + split: Dataset split to load + """ + # Initialize tokenizer and get vocab size + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + # Load the FineWeb dataset in streaming mode + fw = load_dataset( + "HuggingFaceFW/fineweb-edu", + name="sample-10BT", + split=split, + streaming=True, + cache_dir=cache_dir) + fw = fw.batch(batch_size=batch_size, drop_last_batch=True) + if split in ['train', 'eval_train']: + fw = fw.shuffle(seed=int(data_rng[-1])) + + def _tokenize(x): + """Tokenize and pad text to seq_len+1 tokens.""" + if framework == "torch": + tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = F.pad(tokens, pad_length, value=tokenizer.pad_token_id) + elif framework == "jax": + tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = jnp.pad( + tokens, + pad_length, + mode="constant", + constant_values=tokenizer.pad_token_id) + return tokens[:seq_len + 1] + + def batch_iterator(): + for doc in fw: + if framework == "torch": + token_ids = torch.stack([_tokenize(x) for x in doc['text']]) + # Take first seq_len+1 tokens and convert to one-hot + tokens = F.one_hot(token_ids, num_classes=vocab_size).float() + # Split into input/target + inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :] + inputs, targets = inputs.to("cuda"), targets.to("cuda") + elif framework == "jax": + token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) + tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) + inputs, targets = tokens[:, :-1], tokens[:, 1:] + devices = jax.devices("gpu") + inputs, targets = jax.device_put(inputs), jax.device_put(targets) + yield inputs, targets + + return batch_iterator() + + def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, diff --git a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py new file mode 100644 index 000000000..36bab0d02 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py @@ -0,0 +1,116 @@ +"""Tests for LM HuggingFace input pipeline.""" +import os + +import jax +import jax.numpy as jnp +import torch +from transformers import GPT2Tokenizer + +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + +def main(): + # Setup test environment + cache_dir = "/home/ak4605/data" + if not os.path.exists(cache_dir): + raise FileNotFoundError(f"Cache directory {cache_dir} not found") + + data_rng = jax.random.PRNGKey(42) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + + print("Running JAX output shapes and types test...") + batch_size = 8 + seq_len = 32 + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == jnp.float32, \ + f"Expected 
inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == jnp.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert jnp.all(jnp.sum(inputs, axis=-1) == 1), "Inputs should be one-hot encoded" + assert jnp.all(jnp.sum(targets, axis=-1) == 1), "Targets should be one-hot encoded" + print("✓ JAX test passed") + + print("\nRunning Torch output shapes and types test...") + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="torch", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == torch.float32, \ + f"Expected inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == torch.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert torch.all(torch.sum(inputs, dim=-1) == 1), "Inputs should be one-hot encoded" + assert torch.all(torch.sum(targets, dim=-1) == 1), "Targets should be one-hot encoded" + print("✓ Torch test passed") + + print("\nTesting consistent batching with same seed...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Input batches should be identical with same seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Target batches should be identical with same seed" + print("✓ Consistent batching test passed") + + print("\nTesting eval split doesn't shuffle...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(999)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Eval inputs should be identical regardless of seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Eval targets should be identical regardless of seed" + print("✓ Eval no shuffling test passed") + + print("\nAll tests passed successfully!") + + +if __name__ == "__main__": + main() From f095d4b167dabc0e1aeb925b871f32f427fc22c8 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 27 Mar 2025 17:03:05 +0000 Subject: [PATCH 25/82] Testing with linear model --- algoperf/workloads/lm/input_pipeline.py | 1 - algoperf/workloads/lm/lm_jax/models.py | 18 +++++++++ algoperf/workloads/lm/lm_jax/workload.py | 26 +++++++++++-- algoperf/workloads/lm/lm_pytorch/models.py | 18 +++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 32 +++++++++++++-- .../workloads/lm/tests/test_linear_model.py | 39 +++++++++++++++++++ algoperf/workloads/lm/workload.py | 17 ++------ 7 files changed, 129 insertions(+), 22 deletions(-) create mode 100644 algoperf/workloads/lm/lm_jax/models.py create mode 100644 algoperf/workloads/lm/lm_pytorch/models.py create mode 100644 
algoperf/workloads/lm/tests/test_linear_model.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index ea4cb9d63..cc658501e 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -86,7 +86,6 @@ def batch_iterator(): token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] - devices = jax.devices("gpu") inputs, targets = jax.device_put(inputs), jax.device_put(targets) yield inputs, targets diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py new file mode 100644 index 000000000..edfc102fa --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -0,0 +1,18 @@ +from flax import linen as nn +import jax.numpy as jnp + +class LinearModel(nn.Module): + vocab_size: int + + @nn.compact + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + x = nn.Dense( + 512, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(inputs) + return nn.Dense( + self.vocab_size, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 64d538dda..30b0c7867 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,8 +2,12 @@ from typing import Dict, Optional, Tuple +import jax.numpy as jnp +from flax import jax_utils +from algoperf import param_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_jax.models import LinearModel class LmWorkload(BaseLmWorkload): @@ -14,18 +18,32 @@ def init_model_fn( rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - pass + + model = LinearModel(vocab_size=self._vocab_size) + input_shape = (1, self._seq_len, self._vocab_size) + variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) + model_state, params = variables.pop('params') + + self._param_shapes = param_utils.jax_param_shapes(params) + self._param_types = param_utils.jax_param_types(self._param_shapes) + model_state = jax_utils.replicate(model_state) + params = jax_utils.replicate(params) + + return params, model_state def model_fn( self, params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - pass + + del mode, rng, update_batch_norm # Not used for linear model + inputs = batch['inputs'] + logits = self._model.apply({'params': params, **model_state}, inputs) + return logits, model_state def _eval_batch(self, params: spec.ParameterContainer, diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py new file mode 100644 index 000000000..545763924 --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/models.py @@ -0,0 +1,18 @@ +import torch +import torch.nn as nn + +class LinearLayer(nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.bottleneck = nn.Linear(vocab_size, 512) + self.output = nn.Linear(512, vocab_size) + self.reset_parameters() + + def 
reset_parameters(self): + nn.init.normal_(self.bottleneck.weight, std=0.02) + nn.init.zeros_(self.bottleneck.bias) + nn.init.normal_(self.output.weight, std=0.02) + nn.init.zeros_(self.output.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output(self.bottleneck(x)) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 606f16ad7..3395aa08f 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -5,10 +5,13 @@ import jax import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_pytorch.models import LinearLayer USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -21,18 +24,39 @@ def init_model_fn( rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - pass + + if hasattr(self, '_model'): + self._model.reset_parameters() + return self._model, None + + torch.manual_seed(rng[0]) + self._model = LinearLayer(vocab_size=self._vocab_size) + self._param_shapes = param_utils.pytorch_param_shapes(self._model) + self._param_types = param_utils.pytorch_param_types(self._param_shapes) + self._model.to(DEVICE) + + if N_GPUS > 1: + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) + + return self._model, None def model_fn( self, params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - pass + + del model_state, rng, update_batch_norm # Not used for linear model + model = params + inputs = batch['inputs'].float() # Convert one-hot to float + logits = model(inputs) + return logits, None def _build_input_queue( self, diff --git a/algoperf/workloads/lm/tests/test_linear_model.py b/algoperf/workloads/lm/tests/test_linear_model.py new file mode 100644 index 000000000..31cd1d577 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_linear_model.py @@ -0,0 +1,39 @@ +import jax +import jax.numpy as jnp +import torch + +TEST_SEQ_LEN = 512 + +def test_pytorch_linear(): + from algoperf.workloads.lm.lm_pytorch.models import LinearLayer + vocab_size = 32000 + model = LinearLayer(vocab_size) + + batch_size = 8 + seq_len = TEST_SEQ_LEN + inputs = torch.randn(batch_size, seq_len, vocab_size) + outputs = model(inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not torch.isnan(outputs).any() + +def test_jax_linear(): + from algoperf.workloads.lm.lm_jax.models import LinearModel + + vocab_size = 32000 + seq_len = TEST_SEQ_LEN + batch_size = 8 + model = LinearModel(vocab_size) + rng = jax.random.PRNGKey(0) + params = model.init(rng, jnp.ones((1, seq_len, vocab_size))) + + inputs = jax.random.normal(rng, (batch_size, seq_len, vocab_size)) + outputs = model.apply(params, inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not jnp.isnan(outputs).any() + +if __name__ == '__main__': + test_pytorch_linear() + 
test_jax_linear() + print("All tests passed!") diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 4eb6c74a5..a06b17fdc 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -20,8 +20,8 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" - _vocab_size: int = 32000 - _seq_len: int = 2048 + _vocab_size: int = 50257 + _seq_len: int = 512 def __init__(self) -> None: pass @@ -106,6 +106,7 @@ def activation(self) -> str: def glu(self) -> bool: return True + @abc.abstractmethod def _build_input_queue(self, data_rng: jax.random.PRNGKey, split: str, @@ -113,17 +114,7 @@ def _build_input_queue(self, global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): - ds = input_pipeline.get_lm_dataset( - data_rng, - split, - data_dir, - vocab_size=self._vocab_size, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) - - for batch in iter(ds): - yield batch + """Build an input queue for the given split.""" @abc.abstractmethod def _eval_batch(self, From 0c22f3df420968cf820cbcc826f84a61751f95f5 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 12:28:05 -0400 Subject: [PATCH 26/82] lm workload with linear model --- .../workloads/cifar/cifar_jax/workload.py | 11 -- algoperf/workloads/lm/input_pipeline.py | 2 +- algoperf/workloads/lm/lm_jax/models.py | 5 +- algoperf/workloads/lm/lm_jax/workload.py | 82 +++++++++-- algoperf/workloads/lm/lm_pytorch/workload.py | 129 ++++++++++-------- algoperf/workloads/lm/workload.py | 59 ++++---- pyproject.toml | 3 +- .../nesterov/jax/submission.py | 8 +- submission_runner.py | 6 +- 9 files changed, 187 insertions(+), 118 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f827fac87..fd990eeaa 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -71,17 +71,6 @@ def _build_input_queue( cache, repeat_final_dataset) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics - # and we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cc658501e..440de64c1 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -87,7 +87,7 @@ def batch_iterator(): tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield inputs, targets + yield {'inputs': inputs, 'targets': targets} return batch_iterator() diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index edfc102fa..72ee5bd83 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -7,12 +7,13 @@ class LinearModel(nn.Module): @nn.compact def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: x = nn.Dense( - 512, + 10, kernel_init=nn.initializers.normal(0.02), bias_init=nn.initializers.zeros )(inputs) return nn.Dense( self.vocab_size, kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros + bias_init=nn.initializers.zeros, + name="output" )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 30b0c7867..7cb50302f 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,33 +2,57 @@ from typing import Dict, Optional, Tuple +import jax import jax.numpy as jnp +import optax from flax import jax_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using HuggingFace FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="jax", + split=split) + return loader + def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - model = LinearModel(vocab_size=self._vocab_size) + self._model = LinearModel(vocab_size=self._vocab_size) input_shape = (1, self._seq_len, self._vocab_size) - variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) - model_state, params = variables.pop('params') - + params_rng, init_rng = jax.random.split(rng) + print(params_rng) + # variables = model.init(init_rng, jnp.ones(input_shape, jnp.float32)) + variables = jax.jit(self._model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32)) + params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) - + params = sharding_utils.shard_replicated(params) + model_state = None 
return params, model_state def model_fn( @@ -40,15 +64,51 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm # Not used for linear model + del mode, rng, update_batch_norm, model_state inputs = batch['inputs'] - logits = self._model.apply({'params': params, **model_state}, inputs) - return logits, model_state + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in JAX.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + loss = -jnp.sum(label_batch * jax.nn.log_softmax(logits_batch, axis=-1)) + else: + # Dense labels + loss = -jax.nn.log_softmax(logits_batch)[jnp.arange(label_batch.shape[0]), label_batch] + + if mask_batch is not None: + loss = loss * mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) + return loss diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3395aa08f..0d0281690 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -66,68 +66,38 @@ def _build_input_queue( global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: - not_train = split != 'train' - per_device_batch_size = int(global_batch_size / N_GPUS) - - seq_len = self._seq_len # TODO: define it somewehere else? - dtype = torch.int32 # TODO: decide between int32 and int64. - - # Only create and iterate over tf input pipeline in one Python process to - # avoid creating too many threads. - if RANK == 0: - np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="torch", + split=split) + seq_len = self._seq_len weights = None - - while True: - # Only iterate over tf input pipeline in one Python process to - # avoid creating too many threads. 
- if RANK == 0: - batch = next(np_iter) # pylint: disable=stop-iteration-return - inputs = torch.as_tensor( - batch['inputs'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - targets = torch.as_tensor( - batch['targets'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - - # Send batch to other devices when using DDP. - if USE_PYTORCH_DDP: - if not_train: - # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.tensor( - len(targets[0]), dtype=dtype, device=DEVICE) - dist.broadcast(per_device_batch_size, src=0) - # We don't broadcast the shard for RANK 0. - dist.broadcast(inputs[1:], src=0) - dist.broadcast(targets[1:], src=0) - - # RANK 0 extracts his shard. If not DDP, this just flattens. - inputs, targets = inputs[0], targets[0] - - else: - # Receive batch from rank 0. - if not_train: - # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE) + + dtype = torch.long + is_train = split == 'train' + + for batch in loader: + inputs, targets = batch + + if USE_PYTORCH_DDP: + if not is_train: + # During eval, the batch size of the remainder might be different + per_device_batch_size = torch.tensor( + len(targets[0]), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) - - # N_GPUS - 1 since we don't broadcast the shard for RANK 0. - inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=dtype, - device=DEVICE) - targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=dtype, - device=DEVICE) + + # Broadcast to all devices dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) - # RANK - 1 since we don't broadcast the shard for RANK 0. 
- inputs, targets = inputs[RANK - 1], targets[RANK - 1] + + if weights is None: + weights = torch.ones(inputs.shape[0], device=DEVICE) if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) @@ -138,10 +108,51 @@ def _build_input_queue( } yield batch + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'output.weight' in param_name or 'output.bias' in param_name + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(targets * log_probs) + return loss + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in PyTorch.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1) + loss = -torch.sum(label_batch * log_probs, dim=-1) + else: + # Dense labels + loss = torch.nn.functional.cross_entropy( + logits_batch, + label_batch, + reduction='none') + + if mask_batch is not None: + loss = loss * mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index a06b17fdc..c10bf13e8 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -11,6 +11,7 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS @@ -21,10 +22,13 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 512 + _seq_len: int = 5 + warmup_factor: float = 0.1 def __init__(self) -> None: - pass + super().__init__() + self._param_shapes = None + self._param_types = None @property def target_metric_name(self) -> str: @@ -36,14 +40,14 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property def validation_target_value(self) -> float: - pass + return 20.0 # Target perplexity - def has_reached_test_target(self, eval_result: float) -> bool: - return eval_result['test/ppl'] > self.test_target_value + def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: + return eval_result['test/ppl'] <= self.test_target_value @property def test_target_value(self) -> float: - pass + return 20.0 # Target perplexity @property def loss_type(self) -> spec.LossType: @@ -51,23 +55,23 @@ def loss_type(self) -> spec.LossType: @property def num_train_examples(self) -> int: - pass + return 1000000 # Example size @property def num_eval_train_examples(self) -> int: - pass + return 10000 # Subset for evaluation @property def num_validation_examples(self) -> int: - pass + return 50000 @property def num_test_examples(self) -> int: - pass + return 50000 @property def eval_batch_size(self) -> int: - pass + return 8 
@property def train_mean(self): @@ -79,16 +83,16 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - pass + return 3600 * 4 # 4 hours @property def eval_period_time_sec(self) -> int: - pass + return 600 # 10 minutes @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - pass + return 100000 @property def pre_ln(self) -> bool: @@ -116,13 +120,22 @@ def _build_input_queue(self, repeat_final_dataset: bool = False): """Build an input queue for the given split.""" - @abc.abstractmethod def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False) + + loss_dict = self.loss_fn(batch['targets'], logits) + return loss_dict['summed'] def _eval_model_on_split(self, split: str, @@ -145,9 +158,10 @@ def _eval_model_on_split(self, num_batches, repeat_final_dataset=True) + loss = 0.0 for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch) + loss += self._eval_batch(params, eval_batch, model_state, rng) if USE_PYTORCH_DDP: dist.all_reduce(loss) mean_loss = loss.item() / num_examples @@ -155,16 +169,11 @@ def _eval_model_on_split(self, # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. + @abc.abstractmethod def loss_fn( self, - label_batch: spec.Tensor, # Dense or one-hot labels. + label_batch: spec.Tensor, logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ - pass + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling.""" diff --git a/pyproject.toml b/pyproject.toml index f4ebdaee3..745c6c680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ version_file = "algoperf/_version.py" [project.optional-dependencies] # All workloads full = [ - "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,lm]", ] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] @@ -96,6 +96,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] +lm = ["transformers", "datasets"] # Frameworks jax_core_deps = [ diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 49e46109b..c570e382b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -90,12 +90,6 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -# @functools.partial( -# jax.pmap, -# axis_name='batch', -# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), -# static_broadcasted_argnums=(0, 1), -# donate_argnums=(2, 3, 4)) def train_step(workload, opt_update_fn, model_state, @@ -272,6 +266,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/submission_runner.py b/submission_runner.py index fa300916e..fd1eb8259 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -250,7 +250,8 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -712,7 +713,8 @@ def main(_): 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', - 'criteo1tb' + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From 99c7b9b70a374a25d6ac29c4f9a0f7c95e57c1aa Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 12:46:53 -0400 Subject: [PATCH 27/82] add nanodo model --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 345 ++++++++++++++++++ algoperf/workloads/lm/lm_jax/workload.py | 56 ++- .../paper_baselines/adamw/jax/submission.py | 4 +- 3 files changed, 386 insertions(+), 19 deletions(-) create mode 100644 algoperf/workloads/lm/lm_jax/nanodo_model.py diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py new file mode 100644 index 000000000..d21fd5090 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -0,0 +1,345 @@ +# Self-contained version of the DecoderOnly Transformer from NanoDO + +import dataclasses +from functools import partial + +from flax import linen as nn +import jax +import jax.numpy as jnp + +# =========== Transformer Decoder-only Model ========== + + + +@dataclasses.dataclass +class DoConfig: + """Hyper-parameters for Transformer decoder-only.""" + + D: int # model/embed dim = qkv dim + H: int # num attention heads + L: int # max context/sequence length + N: int # number of transformer block layers + V: int # vocab size + F: int # FF inner dimension + kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() + 
embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", out_axis=0 + ) + dtype: jnp.dtype = jnp.float32 + rmsnorm_epsilon: float = 1e-6 + multiple_of: int = 256 + tie_embeddings: bool = True # Whether to tie input and output embeddings + + +class Mlp(nn.Module): + """Multilayer perceptron with GLU activation.""" + + cfg: DoConfig + + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + # Use Xavier uniform initialization explicitly + xavier_init = nn.initializers.xavier_uniform() + linear = partial( + nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + ) + hidden_dim = cfg.multiple_of * ( + (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = linear(cfg.D)(x_BxLxF) + return x_BxLxD + +@partial(jax.jit, static_argnums=(0,1,2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack([ + jnp.cos(freqs)[None, :, None, :], + jnp.sin(freqs)[None, :, None, :] + ], axis=3) + + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) + +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack([ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1] + ], axis=-1) + + return rotated_x_r2.reshape(*x.shape) + + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) + + return rotated_q, rotated_k + + +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: DoConfig + + def setup(self): + cfg = self.cfg + assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}" + self.Dh = cfg.D // cfg.H + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.H, self.Dh), + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + self.multilinear_query = self.multilinear(name="query") + self.multilinear_key = self.multilinear(name="key") + self.multilinear_value = self.multilinear(name="value") + self.output_projection = nn.DenseGeneral( + features=cfg.D, + name="attn_out_proj", + # axis=(-2, -1), # + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) + + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + + # Scale queries + q_BxLxHxDh /= self.Dh**0.5 + + # Compute attention scores + att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = 
jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) + + return out_BxLxD + + +class TBlock(nn.Module): + """Transformer Block.""" + + docfg: DoConfig + + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg + + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD + + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) + + return x_BxLxD + z_BxLxD + + +class TransformerDo(nn.Module): + """Transformer decoder-only.""" + + docfg: DoConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.V, + features=cfg.D, + embedding_init=cfg.embed_init, + ) + + self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.V, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name="output_proj" + ) + + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV + + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. 
+ + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + batch_size = y_BxL.shape[0] + seq_len = y_BxL.shape[1] + + # Store original input + original_input = y_BxL + + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.L: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + ) + + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) + + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] + + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + model = TransformerDo(cfg) + + # Print model info + print(f"\nModel Configuration:") + print(f" - Model dimension (D): {cfg.D}") + print(f" - Number of heads (H): {cfg.H}") + print(f" - Max sequence length (L): {cfg.L}") + print(f" - Number of layers (N): {cfg.N}") + print(f" - Vocabulary size (V): {cfg.V}") + print(f" - Feed forward dimension (F): {cfg.F}") + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + ) + + # Initialize model parameters + print("\nInitializing model parameters...") + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f"Total parameters: {param_count:,}") + + # Make a prediction (forward pass) + print("\nRunning forward pass...") + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)") + print(f"Output data type: {logits.dtype}") + + # Print sample logits (first 5 positions of the first sequence) + print("\nSample logits (first sequence, first 5 positions, first 5 values):") + for position in range(min(5, L)): + print(f" Position {position}: {logits[0, position, :5]}") + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + # Test the predict function + print("\nTesting predict function...") + # Use a shorter + short_seq = x_BxL[:, :10] + print(f"Input sequence shape: {short_seq.shape}") + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git 
a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 7cb50302f..9fdfe6f60 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -10,7 +10,8 @@ from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + TransformerDo, DoConfig, init_rope, apply_rope) from algoperf.workloads.lm.input_pipeline import get_hf_dataloader @@ -42,12 +43,22 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - self._model = LinearModel(vocab_size=self._vocab_size) - input_shape = (1, self._seq_len, self._vocab_size) + # Initialize NanoDO transformer model + cfg = DoConfig( + D=512, # model dim + H=8, # num heads + L=self._seq_len, + N=6, # num layers + V=self._vocab_size, + F=2048, # feedforward dim + dtype=jnp.float32 + ) + self._model = TransformerDo(cfg) + input_shape = (1, self._seq_len) # For token IDs + params_rng, init_rng = jax.random.split(rng) - print(params_rng) - # variables = model.init(init_rng, jnp.ones(input_shape, jnp.float32)) - variables = jax.jit(self._model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32)) + variables = jax.jit(self._model.init)({'params': params_rng}, + jnp.ones(input_shape, jnp.int32)) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -66,6 +77,11 @@ def model_fn( del mode, rng, update_batch_norm, model_state inputs = batch['inputs'] + + # Convert one-hot inputs to token IDs if needed + if inputs.ndim == 3: # one-hot encoded + inputs = jnp.argmax(inputs, axis=-1) + logits = self._model.apply({'params': params}, inputs) return logits, None @@ -76,23 +92,29 @@ def loss_fn( mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in JAX.""" - vocab_size = logits_batch.shape[-1] + # Convert one-hot labels to token IDs if needed + if len(label_batch.shape) == len(logits_batch.shape): # one-hot + label_batch = jnp.argmax(label_batch, axis=-1) - if len(label_batch.shape) == len(logits_batch.shape): - # One-hot labels - loss = -jnp.sum(label_batch * jax.nn.log_softmax(logits_batch, axis=-1)) - else: - # Dense labels - loss = -jax.nn.log_softmax(logits_batch)[jnp.arange(label_batch.shape[0]), label_batch] + # Reshape for sequence modeling + logits = logits_batch.reshape(-1, logits_batch.shape[-1]) + labels = label_batch.reshape(-1) + + # Compute cross-entropy loss + loss = -jnp.sum( + jax.nn.log_softmax(logits)[jnp.arange(labels.shape[0]), labels]) if mask_batch is not None: - loss = loss * mask_batch + mask = mask_batch.reshape(-1) + loss = loss * mask + n_valid = mask.sum() + else: + n_valid = labels.shape[0] - n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] return { - 'summed': loss.sum(), + 'summed': loss, 'n_valid_examples': n_valid, - 'per_example': loss + 'per_example': loss / n_valid # Return per-token loss } def is_output_params(self, param_name: str) -> bool: diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 6c6d19ef8..dca9a6b95 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py 
+++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -75,7 +75,6 @@ def _loss_fn(params): spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True,) - jax.debug.print("logits: {logits}", logits=logits) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, @@ -163,7 +162,6 @@ def update_params( replicated, # loss replicated # grad_norm )) - # print(batch) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, @@ -229,6 +227,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From 706d9f74046a0f1c90256ae584b45e30a38e4349 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 13:26:15 -0400 Subject: [PATCH 28/82] torch model --- algoperf/param_utils.py | 2 + .../workloads/lm/lm_pytorch/plainlm_model.py | 298 ++++++++++++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 57 ++-- .../adamw/pytorch/submission.py | 2 + 4 files changed, 341 insertions(+), 18 deletions(-) create mode 100644 algoperf/workloads/lm/lm_pytorch/plainlm_model.py diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 05d882404..24f981546 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -43,6 +43,8 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_BIAS elif 'in_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'qkv' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV elif 'kv_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py new file mode 100644 index 000000000..627a0e16d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -0,0 +1,298 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn +from dataclasses import dataclass +from typing import Tuple + + + +@dataclass +class ModelConfig: + vocab_size: int + seq_len: int + dim: int + expand: float + n_layers: int + n_heads: int + rmsnorm_eps: float = 1e-6 + tie_embeddings: bool = False + + +class MLP(nn.Module): + + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + + # Initialize with Xavier uniform + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis(dim: int, + end: int, + theta: float = 10000.0, + condense_ratio: int = 1): + inv_freqs = 1.0 / (theta**(torch.arange( + 0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim)) + t = torch.arange(end, dtype=torch.float32, + device=inv_freqs.device) / condense_ratio + freqs = torch.outer(t, inv_freqs).float() + return torch.stack([ + torch.cos(freqs)[None, :, None, :], + torch.sin(freqs)[None, :, None, :] + ], + dim=4) + + +def apply_rotary_emb_complex_like( + q: torch.Tensor, k: torch.Tensor, + freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], 
dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - + qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + + +class Attention(nn.Module): + + def __init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.dim % cfg.n_heads == 0 + self.dim = cfg.dim + self.n_heads = cfg.n_heads + self.head_dim = cfg.dim // cfg.n_heads + + self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) + self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) + + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + k = k.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + v = v.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis) # (bsz, seqlen, nh, h_dim) + + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True) # (bsz, nh, seqlen, h_dim) + + out = out.transpose(1, 2).contiguous().view(bsz, seqlen, + d) # (bsz, seqlen, d) + + return self.w_out(out) + + +class Block(nn.Module): + + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) + self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.layer_id = layer_id + + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, cfg): + super().__init__() + self.n_layers = cfg.n_layers + self.cfg = cfg + head_dim = cfg.dim // cfg.n_heads + assert cfg.dim % cfg.n_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.n_layers)]) + self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer('freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0:cfg.seq_len], + persistent=False) + + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() + + # Move model to device (which will also move freqs_cis) + if torch.cuda.is_available(): + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.dim // self.cfg.n_heads + new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000) + self.register_buffer('freqs_cis', new_freqs[0:max(L, self.cfg.seq_len)], persistent=False) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies 
for current sequence length and ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + + def predict(self, x, k=1): + """Generate k tokens autoregressively. + + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + # For debugging + predictions = [] + + batch_size = x.shape[0] + seq_len = x.shape[1] + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Print top 5 tokens for debugging + if i == 0: + print("\nPyTorch detailed prediction:") + top5_values, top5_indices = torch.topk(next_token_logits[0], 5) + for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): + prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() + print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + predictions.append(next_token.item()) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + print(f" Full predictions step by step: {predictions}") + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith("fc2.weight"): # mlp/glu output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + if n.endswith("w_out.weight"): # attn output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print("Initializing transformer model and running forward pass...") + + seq_length = 512 + + # Define model configuration + config = ModelConfig( + vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + dim=768, # Embedding dimension + expand=4.0, # MLP expansion factor + n_layers=12, # Number of transformer layers + n_heads=12, # Number of attention heads + rmsnorm_eps=1e-6, # RMSNorm epsilon + tie_embeddings=True # Tie embedding and output weights 
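+      # NOTE: illustrative demo values for this standalone script only; per the
+      # LM workload's init_model_fn in this patch series, the benchmark itself
+      # builds a smaller config (dim=512, n_layers=6, n_heads=8).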
+ ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 0d0281690..45ad0828f 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -11,7 +11,7 @@ from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_pytorch.models import LinearLayer +from algoperf.workloads.lm.lm_pytorch.plainlm_model import Transformer, ModelConfig USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -26,11 +26,23 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: if hasattr(self, '_model'): - self._model.reset_parameters() + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() return self._model, None torch.manual_seed(rng[0]) - self._model = LinearLayer(vocab_size=self._vocab_size) + cfg = ModelConfig( + vocab_size=self._vocab_size, + seq_len=self._seq_len, + dim=512, # Model dimension + expand=4, # MLP expansion factor + n_layers=6, # Number of transformer layers + n_heads=8, # Number of attention heads + rmsnorm_eps=1e-6, + tie_embeddings=True + ) + self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) self._model.to(DEVICE) @@ -46,15 +58,20 @@ def init_model_fn( def model_fn( self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del model_state, rng, update_batch_norm # Not used for linear model + del model_state, rng, update_batch_norm model = params - inputs = batch['inputs'].float() # Convert one-hot to float + + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) + logits = model(inputs) return logits, None @@ -83,13 +100,14 @@ def _build_input_queue( is_train = split == 'train' for batch in loader: - inputs, targets = batch + inputs = batch['inputs'] + targets = batch['targets'] if USE_PYTORCH_DDP: if not is_train: # During eval, the batch size of the remainder might be different per_device_batch_size = torch.tensor( - len(targets[0]), dtype=dtype, device=DEVICE) + targets.shape[0], dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # Broadcast to all devices @@ -97,10 +115,8 @@ def _build_input_queue( dist.broadcast(targets, src=0) if weights is None: - weights = torch.ones(inputs.shape[0], device=DEVICE) - - if weights is None: - weights = torch.ones(per_device_batch_size, device=DEVICE) + batch_size = targets.shape[0] if not USE_PYTORCH_DDP else per_device_batch_size.item() + weights = torch.ones((batch_size, seq_len), device=DEVICE) batch = { 'inputs': inputs, 'targets': 
targets, @@ -110,7 +126,7 @@ def _build_input_queue( def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" - return 'output.weight' in param_name or 'output.bias' in param_name + return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name def _eval_batch(self, params: spec.ParameterContainer, @@ -121,11 +137,17 @@ def _eval_batch(self, model = params logits, _ = self.model_fn( model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - targets = batch['targets'] - # Calculate cross-entropy loss - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - loss = -torch.sum(targets * log_probs) + # Handle both one-hot and token ID targets + targets = batch['targets'] + if targets.dim() == 3: # one-hot + loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) + else: # token IDs + loss = torch.nn.functional.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + reduction='sum' + ) return loss def loss_fn( self, @@ -146,7 +168,6 @@ def loss_fn( logits_batch, label_batch, reduction='none') - if mask_batch is not None: loss = loss * mask_batch diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 21d9b6b57..bdeaaf95b 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -173,6 +173,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From c335e341913dc6b1a747f2d3407e71a8d8e66ab6 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 29 May 2025 14:22:50 +0000 Subject: [PATCH 29/82] lm workload dataset integration in jax --- .../workloads/cifar/cifar_jax/workload.py | 11 - algoperf/workloads/lm/input_pipeline.py | 12 +- algoperf/workloads/lm/lm_jax/models.py | 3 +- algoperf/workloads/lm/lm_jax/workload.py | 68 +++- algoperf/workloads/lm/lm_pytorch/workload.py | 49 +-- algoperf/workloads/lm/workload.py | 313 +++++++++--------- .../nesterov/jax/submission.py | 8 +- submission_runner.py | 6 +- 8 files changed, 261 insertions(+), 209 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f827fac87..fd990eeaa 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -71,17 +71,6 @@ def _build_input_queue( cache, repeat_final_dataset) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics - # and we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cc658501e..8f68fcb55 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -87,19 +87,19 @@ def batch_iterator(): tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield inputs, targets - + batch = { + "inputs": inputs, + "targets": targets, + } + yield batch return batch_iterator() def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - vocab_size: int, global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False, - vocab_path: Optional[str] = None): + num_batches: Optional[int] = None): """Load HF dataset and return a TF dataset.""" dataset_path = os.path.join(data_dir, split) diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index edfc102fa..7913f2c67 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -14,5 +14,6 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: return nn.Dense( self.vocab_size, kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros + bias_init=nn.initializers.zeros, + name="output" )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 30b0c7867..6ad0e7d3d 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,16 +2,36 @@ from typing import Dict, Optional, Tuple +import jax import jax.numpy as jnp +import optax from flax import jax_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using pre-cached FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_lm_dataset( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + return loader def init_model_fn( self, @@ -21,14 +41,15 @@ def init_model_fn( model = LinearModel(vocab_size=self._vocab_size) input_shape = (1, self._seq_len, self._vocab_size) - variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) - model_state, params = variables.pop('params') - + params_rng, init_rng = jax.random.split(rng) + variables = jax.jit(model.init)({'params': params_rng}, + jnp.ones(input_shape, jnp.float32)) + params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) - + params = sharding_utils.shard_replicated(params) + model_state = None + self._model = model return params, 
model_state def model_fn( @@ -40,15 +61,40 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm # Not used for linear model - inputs = batch['inputs'] - logits = self._model.apply({'params': params, **model_state}, inputs) - return logits, model_state + del mode, rng, update_batch_norm, model_state + inputs = jax.nn.one_hot(batch['inputs'], self._vocab_size, axis=-1) + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, # One-hot labels. + logits_batch: spec.Tensor, # Dense logits. + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: Optional[float] = 0.0) -> Dict[str, spec.Tensor]: + del mask_batch, label_smoothing + logits_flat = logits_batch.reshape(-1, self._vocab_size) + targets = jax.nn.one_hot(label_batch, self._vocab_size, axis=-1) + targets_flat = targets.reshape(-1, self._vocab_size) + # Cross-entropy loss + loss = -jnp.sum(targets_flat * jax.nn.log_softmax(logits_flat, axis=-1)) + n_valid_examples = logits_flat.shape[0] + return {'summed': loss, 'n_valid_examples': n_valid_examples} + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) + return loss diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3395aa08f..2c6862160 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -66,35 +66,30 @@ def _build_input_queue( global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: - not_train = split != 'train' - per_device_batch_size = int(global_batch_size / N_GPUS) - - seq_len = self._seq_len # TODO: define it somewehere else? - dtype = torch.int32 # TODO: decide between int32 and int64. - - # Only create and iterate over tf input pipeline in one Python process to - # avoid creating too many threads. - if RANK == 0: - np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="torch", + split=split) + seq_len = self._seq_len weights = None - + while True: # Only iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
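      # Rank 0 pulls batches from the host input pipeline; under DDP the other
      # ranks receive the same batch through the dist.broadcast calls below.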
if RANK == 0: - batch = next(np_iter) # pylint: disable=stop-iteration-return + batch = next(dataset_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( batch['inputs'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len) targets = torch.as_tensor( batch['targets'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: @@ -138,10 +133,22 @@ def _build_input_queue( } yield batch + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'output.weight' in param_name or 'output.bias' in param_name + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(targets * log_probs) + return loss diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index a06b17fdc..e6b33e3e4 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -11,160 +11,171 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS -USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +USE_PYTORCH_DDP = "LOCAL_RANK" in os.environ class BaseLmWorkload(spec.Workload): - """LM workload.""" - - _vocab_size: int = 50257 - _seq_len: int = 512 - - def __init__(self) -> None: - pass - - @property - def target_metric_name(self) -> str: - """The name of the target metric (useful for scoring/processing code).""" - return 'ppl' - - def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] > self.validation_target_value - - @property - def validation_target_value(self) -> float: - pass - - def has_reached_test_target(self, eval_result: float) -> bool: - return eval_result['test/ppl'] > self.test_target_value - - @property - def test_target_value(self) -> float: - pass - - @property - def loss_type(self) -> spec.LossType: - return spec.LossType.SOFTMAX_CROSS_ENTROPY - - @property - def num_train_examples(self) -> int: - pass - - @property - def num_eval_train_examples(self) -> int: - pass - - @property - def num_validation_examples(self) -> int: - pass - - @property - def num_test_examples(self) -> int: - pass - - @property - def eval_batch_size(self) -> int: - pass - - @property - def train_mean(self): - raise NotImplementedError - - @property - def train_stddev(self): - raise NotImplementedError - - @property - def max_allowed_runtime_sec(self) -> int: - pass - - @property - def eval_period_time_sec(self) -> int: - pass - - @property - def step_hint(self) -> int: - """Approx. 
steps the baseline can do in the allowed runtime budget.""" - pass - - @property - def pre_ln(self) -> bool: - return True - - @property - def attention_temp(self) -> float: - return 1.0 - - @property - def activation(self) -> str: - return 'silu' - - @property - def glu(self) -> bool: - return True - - @abc.abstractmethod - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): - """Build an input queue for the given split.""" - - @abc.abstractmethod - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - num_batches = int(math.ceil(num_examples / global_batch_size)) - if split not in self._eval_iters: - # These iterators will repeat indefinitely. - self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset=True) - - for _ in range(num_batches): - eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch) - if USE_PYTORCH_DDP: - dist.all_reduce(loss) - mean_loss = loss.item() / num_examples - return {'loss': mean_loss} - - # Does NOT apply regularization, which is left to the submitter to do in - # `update_params`. - def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ - pass + """LM workload.""" + + _vocab_size: int = 50257 + _seq_len: int = 512 + + def __init__(self) -> None: + pass + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return "ppl" + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result["validation/ppl"] > self.validation_target_value + + @property + def validation_target_value(self) -> float: + pass + + def has_reached_test_target(self, eval_result: float) -> bool: + return eval_result["test/ppl"] > self.test_target_value + + @property + def test_target_value(self) -> float: + pass + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + pass + + @property + def num_eval_train_examples(self) -> int: + pass + + @property + def num_validation_examples(self) -> int: + pass + + @property + def num_test_examples(self) -> int: + pass + + @property + def eval_batch_size(self) -> int: + return 8 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + pass + + @property + def eval_period_time_sec(self) -> int: + pass + + @property + def step_hint(self) -> int: + """Approx. steps the baseline can do in the allowed runtime budget.""" + # FIXME: should replace this with a real value later. + return 10000 + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return "silu" + + @property + def glu(self) -> bool: + return True + + @abc.abstractmethod + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): + """Build an input queue for the given split.""" + + @abc.abstractmethod + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True, + ) + + loss = 0.0 + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + loss += self._eval_batch(params, eval_batch, model_state, rng) + if USE_PYTORCH_DDP: + dist.all_reduce(loss) + mean_loss = loss.item() / num_examples + return {"loss": mean_loss} + + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. 
+ logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ + pass diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 49e46109b..c570e382b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -90,12 +90,6 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -# @functools.partial( -# jax.pmap, -# axis_name='batch', -# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), -# static_broadcasted_argnums=(0, 1), -# donate_argnums=(2, 3, 4)) def train_step(workload, opt_update_fn, model_state, @@ -272,6 +266,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/submission_runner.py b/submission_runner.py index fa300916e..fd1eb8259 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -250,7 +250,8 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -712,7 +713,8 @@ def main(_): 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', - 'criteo1tb' + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From af8cce4d61e7f79916d7293127121ebaa4a4d7ce Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 03:20:46 +0000 Subject: [PATCH 30/82] set package versions for transformers and datasets --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 745c6c680..5e9c21f47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] -lm = ["transformers", "datasets"] +lm = ["transformers==4.25.4", "datasets==3.6.0"] # Frameworks jax_core_deps = [ From d68c54e0aa023570abc94cea97f5757bfb0baca8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 04:02:41 +0000 Subject: [PATCH 31/82] use train_test_split method to shuffle and split fineweb-edu dataset --- dataset/dataset_setup.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 6587f1439..7a83a03f6 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -770,18 +770,10 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - # Find how many entries to take from dataset to have val_tokens in validation set. - val_tokens = 10_000_000 # TODO: decide this value. - tokens_accumulated, num_examples_for_val = 0, 0 - for example in tokenized_dataset: - tokens_accumulated += len(example['input_ids']) - num_examples_for_val += 1 - if tokens_accumulated >= val_tokens: - break # Split in train and valid. 
- val_dataset = tokenized_dataset.select(range(num_examples_for_val)) - train_dataset = tokenized_dataset.select( - range(num_examples_for_val, len(tokenized_dataset))) + dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) + train_dataset = dataset_split_dict['train'] + val_dataset = dataset_split_dict['test'] # Concat in chunks of max_seq_len. # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. From 9737367473f35b206333edc46f9c193ec8dda821 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 19:45:32 +0000 Subject: [PATCH 32/82] modifications to fwedu datasetup --- dataset/dataset_setup.py | 164 +++++++++++++++++---------------------- 1 file changed, 73 insertions(+), 91 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 7a83a03f6..584189c4a 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -191,6 +191,7 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') +flags.DEFINE_boolean('skip_tokenization', False, 'Skip Fineweb-edu tokenization.') FLAGS = flags.FLAGS @@ -707,106 +708,87 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) -def download_finewebedu(data_dir, tmp_dir=None): +def download_finewebedu(data_dir, + tmp_dir=None, + skip_download=False, + skip_tokenization=False): """Download FineWebEdu-10B.""" - data_dir = os.path.join(data_dir, 'finewebedu') - tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, - 'lm') if tmp_dir is not None else os.path.expanduser( - '~/.cache/huggingface/datasets') - - _maybe_mkdir(data_dir) - _maybe_mkdir(tmp_dir) - _maybe_mkdir(cache_dir) - - os.environ["TMPDIR"] = tmp_dir - - ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - split='train', - cache_dir=cache_dir) - # TODO (nico): maybe save intermediate dataset to avoid re-downloading - # and allow re-chunking with different seq_len? - - # Shuffle so that multiproc has shards of similar size. 
- ds = ds.shuffle(seed=1996) - - seq_len = 2048 - max_seq_length = seq_len + 1 - map_setup = dict(batched=True, batch_size=1024, num_proc=8) - - # Tokenize - lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") - - def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq - add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] - return lm_tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False) - - lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization - logging.info(f"Tokenizing...") - tokenized_dataset = ds.map( - tokenize, - remove_columns=[ - 'text', - 'id', - 'dump', - 'url', - 'file_path', - 'language', - 'language_score', - 'token_count', - 'score', - 'int_score' - ], - **map_setup) - lm_tokenizer.model_max_length = seq_len - - tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) + if not skip_download: + data_dir = os.path.join(data_dir, 'finewebedu') + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') + + _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) + + os.environ["TMPDIR"] = tmp_dir + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + else: + ds = hf_datasets.load_from_disk(tmp_dir, 'fwedu_10B_raw') + + if not skip_tokenization: + # Tokenize + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + + def add_eos(seq): + return seq + lm_tokenizer.eos_token if seq else seq + + def add_eos_batched(seqs): + return [add_eos(seq) for seq in seqs] + + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info("Tokenizing...") + tokenized_dataset = ds.map( + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score' + ],) + + tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + else: + tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) train_dataset = dataset_split_dict['train'] val_dataset = dataset_split_dict['test'] - # Concat in chunks of max_seq_len. - # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. 
- def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Concatenate text and generate chunks of max_seq_length""" - concatenated_examples = { - k: list(itertools.chain(*examples[k])) for k in examples.keys() - } - total_length = len(concatenated_examples[list(examples.keys())[0]]) - if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length - result = { - k: [ - t[i:i + max_seq_length] - for i in range(0, total_length, max_seq_length) - ] for k, - t in concatenated_examples.items() - } - return result - - # Concat text in validation and train sets. - logging.info(f"Concatenating and chunking...") - val_dataset = val_dataset.map(concat_chunck, **map_setup) - train_dataset = train_dataset.map(concat_chunck, **map_setup) - logging.info( - f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") - logging.info( - f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}" - ) + # Convert to tensorflow_datasets.Dataset objects + train_dataset = train_dataset.to_tf_dataset() + val_dataset = train_dataset.to_tf_dataset() # Save datasets - train_dataset.save_to_disk(os.path.join(data_dir, f"train")) - val_dataset.save_to_disk(os.path.join(data_dir, f"val")) + train_dataset.Save(os.path.join(data_dir, "train")) + val_dataset.save(os.path.join(data_dir, "val")) + + return def main(_): @@ -893,7 +875,7 @@ def main(_): if FLAGS.all or FLAGS.finewebedu: logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir, tmp_dir) + download_finewebedu(data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization) # pylint: enable=logging-format-interpolation From 1bf0750e094a695176e8e3bc45ffd979abe9e237 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 19:46:26 +0000 Subject: [PATCH 33/82] rename fwedu data dir --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 584189c4a..ae27aab18 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -715,7 +715,7 @@ def download_finewebedu(data_dir, """Download FineWebEdu-10B.""" if not skip_download: - data_dir = os.path.join(data_dir, 'finewebedu') + data_dir = os.path.join(data_dir, 'fineweb_edu_10B') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser( From a33339117b4c79d5fa946f4f7ed029087ab5a630 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 20:46:21 +0000 Subject: [PATCH 34/82] fix --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index ae27aab18..289a1faa6 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -734,7 +734,7 @@ def download_finewebedu(data_dir, cache_dir=cache_dir) ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) else: - ds = hf_datasets.load_from_disk(tmp_dir, 'fwedu_10B_raw') + ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) if not skip_tokenization: # Tokenize From 05dc4dd7102670cebb8ac3a8875b34387d57b9b6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 21:22:57 +0000 Subject: [PATCH 35/82] add back batch mapping in tokenization for fwedu --- dataset/dataset_setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 289a1faa6..f50274615 100644 --- 
a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -769,7 +769,10 @@ def add_eos_batched(seqs): 'token_count', 'score', 'int_score' - ],) + ], + batched=True, + batch_size=1024, + num_proc=8) tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) else: From b374cf8db62e99e1594dea90b46a7f69a5bb04c6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:12:24 +0000 Subject: [PATCH 36/82] debugging --- dataset/dataset_setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index f50274615..2c46f4ebc 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -779,9 +779,11 @@ def add_eos_batched(seqs): tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. + print(type(tokenized_dataset)) dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) train_dataset = dataset_split_dict['train'] val_dataset = dataset_split_dict['test'] + print(type(train_dataset)) # Convert to tensorflow_datasets.Dataset objects train_dataset = train_dataset.to_tf_dataset() From c0c1e3c32c46d65cb7511891b32429aeeb05f90c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:13:48 +0000 Subject: [PATCH 37/82] debugging --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 2c46f4ebc..c18e72ea4 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -776,7 +776,7 @@ def add_eos_batched(seqs): tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) else: - tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. print(type(tokenized_dataset)) From f76dc392fa83a1da25194d401aa03a9dd6dc9c6a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:23:24 +0000 Subject: [PATCH 38/82] debugging --- dataset/dataset_setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index c18e72ea4..414b78609 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -778,6 +778,7 @@ def add_eos_batched(seqs): else: tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset.to_tf_dataset() # Split in train and valid. print(type(tokenized_dataset)) dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) From e805fa7997daae83deea4e5336801af195270c1a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:45:07 +0000 Subject: [PATCH 39/82] use tfds to shuffle and split dataset --- dataset/dataset_setup.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 414b78609..747d06d27 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -778,20 +778,18 @@ def add_eos_batched(seqs): else: tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) - tokenized_dataset.to_tf_dataset() - # Split in train and valid. 
- print(type(tokenized_dataset)) - dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) - train_dataset = dataset_split_dict['train'] - val_dataset = dataset_split_dict['test'] - print(type(train_dataset)) - # Convert to tensorflow_datasets.Dataset objects - train_dataset = train_dataset.to_tf_dataset() - val_dataset = train_dataset.to_tf_dataset() + tokenized_dataset = tokenized_dataset.to_tf_dataset() - # Save datasets - train_dataset.Save(os.path.join(data_dir, "train")) + # Shuffle dataset + dataset_size = tokenized_dataset.cardinality().numpy() + shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0) + train_size = int(0.9 * dataset_size) + train_dataset = shuffled_dataset.take(train_size) + val_dataset = shuffled_dataset.skip(train_size) + + # Split in train and valid. + train_dataset.save(os.path.join(data_dir, "train")) val_dataset.save(os.path.join(data_dir, "val")) return From c9e9abcdf0cc9c817c1683f7a40d94a9372752f3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 2 Oct 2025 03:40:29 +0000 Subject: [PATCH 40/82] add command for fineweb-edu --- dataset/README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dataset/README.md b/dataset/README.md index 1aeb83239..50ca11985 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -453,3 +453,13 @@ The preprocessing script will generate `.npy` files for audio data, `features.cs ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` + +### Fineweb-EDU 10B +From `algorithmic-efficiency` run: + +```bash +python3 python3 datasets/dataset_setup.py \ +--data_dir $DATA_DIR \ +--temp_dir $DATA_DIR/tmp \ +--fineweb_edu +``` \ No newline at end of file From e4323deca83a86ad1d703f056157dfcb0e0b1650 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 2 Oct 2025 03:42:16 +0000 Subject: [PATCH 41/82] fix --- dataset/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/README.md b/dataset/README.md index 50ca11985..1bfd9bf73 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -458,7 +458,7 @@ python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_v From `algorithmic-efficiency` run: ```bash -python3 python3 datasets/dataset_setup.py \ +python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --fineweb_edu From f0c6e75ad70cb2c4242014c1522abb3b3bf9aa2e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 3 Oct 2025 06:23:26 +0000 Subject: [PATCH 42/82] update calls to sharing utils --- algoperf/workloads/lm/lm_jax/workload.py | 4 ++-- algoperf/workloads/lm/workload.py | 2 +- .../baselines/external_tuning/jax_nadamw_full_budget.py | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index e73a5bfaf..81dde95fc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -7,7 +7,7 @@ import optax from flax import jax_utils from algoperf import param_utils -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel @@ -79,7 +79,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params 
= sharding_utils.shard_replicated(params) + params = jax_sharding_utils.replicate(params) model_state = None return params, model_state diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 6b71c7952..2a9777354 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -92,7 +92,7 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - return 100000 + return 7000 @property def pre_ln(self) -> bool: diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..6e40cdab1 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -394,6 +394,8 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 + elif workload_name == 'lm': + return 128 elif workload_name == 'mnist': return 16 else: From f4ffbe709f6a867ea95ae55f4b47032caee98c4a Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 6 Oct 2025 17:09:11 +0000 Subject: [PATCH 43/82] Fix torch sharding issue, update input pipeline and workload classes to use int32 for tensor types and add dropout rate parameter --- algoperf/workloads/lm/input_pipeline.py | 4 +- algoperf/workloads/lm/lm_jax/workload.py | 5 ++- algoperf/workloads/lm/lm_pytorch/workload.py | 37 ++++++++++--------- .../lm/tests/test_build_input_queue_torch.py | 15 +++++--- algoperf/workloads/lm/workload.py | 3 +- 5 files changed, 37 insertions(+), 27 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index db345700e..c010b32af 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -119,8 +119,8 @@ def tf_generator(): ds = tf.data.Dataset.from_generator( tf_generator, output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), }) # Avoid creating too many threads when using PyTorch DDP. 
diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 81dde95fc..1f6b3c2b2 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -90,8 +90,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm, model_state + update_batch_norm: bool, + dropout_rate: float) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed if inputs.ndim == 3: # one-hot encoded diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 36e441e7e..e5dafdd3c 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -6,7 +6,8 @@ import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP - +from itertools import islice +from algoperf import data_utils from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec @@ -84,19 +85,22 @@ def _build_input_queue( num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" - from algoperf.workloads.lm.input_pipeline import get_hf_dataloader - - loader = get_hf_dataloader( - cache_dir=data_dir, + from algoperf.workloads.lm.input_pipeline import get_lm_dataset + local_batch_size = global_batch_size // N_GPUS + + loader = get_lm_dataset( data_rng=data_rng, - batch_size=global_batch_size, - seq_len=self._seq_len, - framework="torch", - split=split) + split=split, + data_dir=data_dir, + global_batch_size=local_batch_size, + num_batches=num_batches + ) + if USE_PYTORCH_DDP: + loader = islice(loader, RANK, None, N_GPUS) seq_len = self._seq_len weights = None - dtype = torch.long + dtype = torch.int32 is_train = split == 'train' for batch in loader: @@ -109,17 +113,16 @@ def _build_input_queue( per_device_batch_size = torch.tensor( targets.shape[0], dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) - + local_batch_size = per_device_batch_size.item() # Broadcast to all devices - dist.broadcast(inputs, src=0) - dist.broadcast(targets, src=0) + #dist.broadcast(inputs, src=0) + #dist.broadcast(targets, src=0) if weights is None: - batch_size = targets.shape[0] if not USE_PYTORCH_DDP else per_device_batch_size.item() - weights = torch.ones((batch_size, seq_len), device=DEVICE) + weights = torch.ones((local_batch_size, seq_len), device=DEVICE) batch = { - 'inputs': inputs, - 'targets': targets, + 'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype), + 'targets': torch.tensor(targets, device=DEVICE, dtype=dtype), 'weights': weights, } yield batch diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py index 639e71491..827272037 100644 --- a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -17,9 +17,9 @@ def sync_ddp(): def test_dataloader_torch(): # Test config. 
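  # NOTE: data_dir below points at a machine-specific local copy of the
  # FineWeb-Edu tokens; it is an assumption of this test, not a packaged fixture.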
rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' + data_dir = '/home/ak4605/data/finewebedu/' split = 'train' - global_batch_size = 8 + global_batch_size = 64 dtype = torch.int32 seq_len = 2048 @@ -44,35 +44,40 @@ def test_dataloader_torch(): # print(f"inputs: {inputs}") # Start test. - for _ in range(100): + for _ in range(1): batch = next(input_queue) + print(f"RANK {RANK} got batch") assert type(batch) == dict assert 'inputs' in batch assert 'targets' in batch inputs, targets = batch['inputs'], batch['targets'] - + print(f"RANK {RANK} inputs.shape: {inputs.shape}") + print(f"RANK {RANK} targets.shape: {targets.shape}") + print(f"RANK {RANK} type(inputs): {type(inputs)}") assert type(inputs) == torch.Tensor assert type(targets) == torch.Tensor assert inputs.device == DEVICE assert targets.device == DEVICE - assert inputs.dtype == dtype assert targets.dtype == dtype + print(local_batch_size, seq_len) assert inputs.shape == (local_batch_size, seq_len) assert targets.shape == (local_batch_size, seq_len) assert torch.equal(inputs[:, 1:], targets[:, :-1]) + print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") print(f"=== ALL TEST PASSED ===") def main(): profiler = PassThroughProfiler() + print(USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS) pytorch_init(USE_PYTORCH_DDP, RANK, profiler) test_dataloader_torch() diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 2a9777354..986a98297 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -132,7 +132,8 @@ def _eval_batch(self, model_state, spec.ForwardPassMode.EVAL, rng, - update_batch_norm=False) + update_batch_norm=False, + dropout_rate=None) loss_dict = self.loss_fn(batch['targets'], logits) return loss_dict['summed'] From 5c85c7e278ffa540d65b1d49f0bd1d0cad732052 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 6 Oct 2025 17:39:35 +0000 Subject: [PATCH 44/82] test working, lm workload training not working (debugging) --- algoperf/workloads/lm/lm_jax/workload.py | 3 +- .../lm/tests/test_build_input_queue_jax.py | 60 +++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_jax.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 1f6b3c2b2..5401ad240 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -33,9 +33,10 @@ def _build_input_queue(self, split=split, data_dir=data_dir, global_batch_size=global_batch_size) + loader = map(jax_sharding_utils.shard_along_batch_dim, loader) return loader - def _build_input_queue(self, + def _build_hf_input_queue(self, data_rng: jax.random.PRNGKey, split: str, data_dir: str, diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_jax.py b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py new file mode 100644 index 000000000..b9adc70d2 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py @@ -0,0 +1,60 @@ +import jax +import jax.numpy as jnp + +from algoperf.profiler import PassThroughProfiler +from algoperf.workloads.lm.lm_jax.workload import LmWorkload +import os + +RANK = os.environ.get('RANK', 0) + +def test_dataloader_jax(): + # Test config. 
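+  # NOTE: data_dir is a machine-specific path and global_batch_size is an
+  # arbitrary test value; the queue is expected to yield full global batches of
+  # shape (global_batch_size, seq_len), sharded along the batch dimension.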
+ rng_seed = 1996 + data_dir = '/home/ak4605/data/finewebedu/' + split = 'train' + global_batch_size = 64 + dtype = jnp.int32 + seq_len = 2048 + + workload = LmWorkload() + data_rng = jax.random.PRNGKey(rng_seed) + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + for _ in range(1): + + batch = next(input_queue) + print(f"RANK {RANK} got batch") + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + print(f"RANK {RANK} inputs.shape: {inputs.shape}") + print(f"RANK {RANK} targets.shape: {targets.shape}") + print(f"RANK {RANK} type(inputs): {type(inputs)}") + + jax.debug.inspect_array_sharding(inputs, callback=print) + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (global_batch_size, seq_len) + assert targets.shape == (global_batch_size, seq_len) + + assert jnp.equal(inputs[:, 1:], targets[:, :-1]).all() + print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + test_dataloader_jax() + + +if __name__ == '__main__': + main() From a59dfda3a7ce87b5cad550f2332aaf049f59c8f6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 6 Oct 2025 18:33:29 +0000 Subject: [PATCH 45/82] updates to input_pipeline and model spec --- algoperf/workloads/lm/input_pipeline.py | 257 +++++++++---------- algoperf/workloads/lm/lm_jax/nanodo_model.py | 2 +- algoperf/workloads/lm/lm_jax/workload.py | 36 +-- algoperf/workloads/lm/lm_pytorch/workload.py | 5 +- algoperf/workloads/lm/workload.py | 98 +++---- 5 files changed, 187 insertions(+), 211 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index c010b32af..e674170e4 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -1,154 +1,129 @@ """Input pipeline for a LM dataset.""" + import functools import os from typing import Optional import jax -import jax.numpy as jnp import tensorflow as tf -import torch -import torch.nn.functional as F -from transformers import GPT2Tokenizer from algoperf import data_utils -from algoperf.pytorch_utils import pytorch_setup -from datasets import load_dataset -from datasets import load_from_disk - -RANK = pytorch_setup()[1] -# Avoid multithreading in all processes but the first (rank 0). -# This ensures that only the primary process (RANK == 0) uses TensorFlow's -# automatic optimization (AUTOTUNE), while other processes disable it (None). -# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine -# the optimal number of elements to prefetch or parallelize for dataset -# operations, improving performance. -AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None - - -def get_hf_dataloader(cache_dir: str, - data_rng: jax.random.PRNGKey, - batch_size: int = 8, - seq_len: int = 32, - framework: str = "torch", - split="train"): + +AUTOTUNE = tf.data.experimental.AUTOTUNE +PAD_ID = -1 + +TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'} + +SEQUENCE_LENGTH = 2048 +MAX_CORPUS_CHARS = 1_000_000_000 +SHUFFLE_BUFFER_SIZE = 1_000_000 +VOCAB_SIZE = 50_257 + + +def batch_with_padding( + dataset: tf.data.Dataset, + batch_size, + padded_shapes=None, + padding_id=PAD_ID, +): + """Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size. 
+ + Args: + dataset: tf.data.Dataset + batch_size: batch size of resulting batched dataset + padded_shapes: shapes of the padded batches + padding_id: value for padding, for elements in new batch + + Returns: """ - Create a data loader from HuggingFace's FineWeb dataset. - - Args: - cache_dir: Directory to cache the dataset - batch_size: Number of sequences per batch - seq_len: Length of each sequence - framework: Either "torch" or "jax" to specify output tensor type - split: Dataset split to load - """ - # Initialize tokenizer and get vocab size - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - vocab_size = tokenizer.vocab_size - # Load the FineWeb dataset in streaming mode - fw = load_dataset( - "HuggingFaceFW/fineweb-edu", - name="sample-10BT", - split=split, - streaming=True, - cache_dir=cache_dir) - fw = fw.batch(batch_size=batch_size, drop_last_batch=True) - if split in ['train', 'eval_train']: - fw = fw.shuffle(seed=int(data_rng[-1])) - - def _tokenize(x): - """Tokenize and pad text to seq_len+1 tokens.""" - if framework == "torch": - tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze() - pad_length = seq_len - tokens.shape[0] - if pad_length > 0: - tokens = F.pad(tokens, pad_length, value=tokenizer.pad_token_id) - elif framework == "jax": - tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze() - pad_length = seq_len - tokens.shape[0] - if pad_length > 0: - tokens = jnp.pad( - tokens, - pad_length, - mode="constant", - constant_values=tokenizer.pad_token_id) - return tokens[:seq_len + 1] - - def batch_iterator(): - for doc in fw: - if framework == "torch": - token_ids = torch.stack([_tokenize(x) for x in doc['text']]) - # Take first seq_len+1 tokens and convert to one-hot - tokens = F.one_hot(token_ids, num_classes=vocab_size).float() - # Split into input/target - inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :] - inputs, targets = inputs.to("cuda"), targets.to("cuda") - elif framework == "jax": - token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) - tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) - inputs, targets = tokens[:, :-1], tokens[:, 1:] - inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield {'inputs': inputs, 'targets': targets} - - return batch_iterator() - - -def get_lm_dataset(data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None): + batched_dataset = dataset.batch(batch_size, drop_remainder=False) + + # tf.data.Dataset.padded.batch pads elements in the batch so we call it + # again with batch_size=1 to pad each element in original batch. + padded_batched_dataset = batched_dataset.padded_batch( + 1, padded_shapes=padded_shapes, padding_values=padding_id + ) + + # Remove extra dimension resulting from the batch_size=1. 
+ padded_batched_dataset = padded_batched_dataset.unbatch() + + return padded_batched_dataset + + +def get_data_iter(data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None,): + + ds = get_lm_dataset(data_rng, split, data_dir, global_batch_size, num_batches) + + it = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) + + return iter(it) + +def get_lm_dataset( + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, +): """Load HF dataset and return a TF dataset.""" - - dataset_path = os.path.join(data_dir, split) - dataset = load_from_disk(dataset_path) - - is_training = split == "train" - shuffle = split in ['train', 'eval_train'] - - dataset.set_format("tensorflow") # tf.int64 # TODO (nico): is this needed? - - def tf_generator(): - """Generates data in a TensorFlow-friendly format.""" - for example in dataset: - yield { - "inputs": example["input_ids"][:-1], - "targets": example["input_ids"][1:], - } - - # Create a TensorFlow dataset - ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), - }) - - # Avoid creating too many threads when using PyTorch DDP. - # Limits TensorFlow's threading for non-primary processes (RANK != 0) - if RANK != 0: - options = tf.data.Options() - options.threading.private_threadpool_size = 1 - ds = ds.with_options(options) - - if shuffle: - ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) - - if is_training: - ds = ds.repeat() - - # Batch the dataset, grouping consecutive elements into fixed-size chunks. 
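A toy sketch of how batch_with_padding above behaves on a final partial batch; the dataset, shapes, and values here are illustrative only (PAD_ID is -1, as defined above):

import tensorflow as tf

toy = tf.data.Dataset.from_tensor_slices(
    {'input_ids': tf.reshape(tf.range(10, dtype=tf.int64), (5, 2))})
padded = batch_with_padding(
    toy, batch_size=2, padded_shapes={'input_ids': (2, None)})
for batch in padded:
  print(batch['input_ids'].numpy())
# The first two batches come out as [[0 1], [2 3]] and [[4 5], [6 7]]; the last
# is [[8 9], [-1 -1]]: the leftover element is padded with PAD_ID so every
# batch keeps the full batch_size.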
- ds = ds.batch(global_batch_size, drop_remainder=is_training) - ds = ds.prefetch(AUTOTUNE) - - # Limit the dataset to a fixed number of batches if `num_batches` is specified - if num_batches: - ds = ds.take(num_batches) - - # Shard the dataset across multiple GPUs/TPUs if necessary - ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + if split not in TFDS_SPLIT_NAME: + raise NotImplementedError + + shuffle_seed = jax.random.randint(data_rng, (), -2**31, 2**31-1) + + data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split]) + tokens_ds = tf.data.Dataset.load(data_dir) + + # tokens + tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices) + + # sequences + sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True) + + # get inputs and outputs + sequences_ds = sequences_ds.map( + lambda x: { + 'inputs': x['input_ids'][:SEQUENCE_LENGTH], + 'targets': x['input_ids'][1:], + }, + num_parallel_calls=AUTOTUNE, + ) + + # batch + if split == 'train': + shuffled_sequences_ds = sequences_ds.shuffle( + SHUFFLE_BUFFER_SIZE, seed=shuffle_seed + ) + repeated_sequences_dataset = shuffled_sequences_ds.repeat() + ds = repeated_sequences_dataset.batch( + global_batch_size, drop_remainder=False + ).take(100).prefetch(tf.data.experimental.AUTOTUNE) + elif split == 'eval_train': + ds = batch_with_padding( + sequences_ds, + global_batch_size, + padded_shapes={ + 'inputs': (global_batch_size, None), + 'targets': (global_batch_size, None), + }, + ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation + elif split == 'validation': + ds = batch_with_padding( + sequences_ds, + global_batch_size, + padded_shapes={ + 'inputs': (global_batch_size, None), + 'targets': (global_batch_size, None), + }, + ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size return ds diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index d21fd5090..ed469e1bd 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -3,9 +3,9 @@ import dataclasses from functools import partial -from flax import linen as nn import jax import jax.numpy as jnp +from flax import linen as nn # =========== Transformer Decoder-only Model ========== diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 5401ad240..49547fcef 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -4,16 +4,14 @@ import jax import jax.numpy as jnp -import optax -from flax import jax_utils -from algoperf import param_utils -from algoperf import jax_sharding_utils -from algoperf import spec -from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_jax.models import LinearModel -from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset + +from algoperf import jax_sharding_utils, param_utils, spec +from algoperf.workloads.lm.input_pipeline import get_data_iter from algoperf.workloads.lm.lm_jax.nanodo_model import ( - TransformerDo, DoConfig, init_rope, apply_rope) + DoConfig, + TransformerDo, +) +from algoperf.workloads.lm.workload import BaseLmWorkload class LmWorkload(BaseLmWorkload): @@ -28,7 +26,7 @@ def _build_input_queue(self, """Build an input queue using pre-cached FineWeb dataset.""" del num_batches del repeat_final_dataset - loader = get_lm_dataset( + loader = 
get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, @@ -46,14 +44,8 @@ def _build_hf_input_queue(self, """Build an input queue using HuggingFace FineWeb dataset.""" del num_batches del repeat_final_dataset - loader = get_hf_dataloader( - cache_dir=data_dir, - data_rng=data_rng, - batch_size=global_batch_size, - seq_len=self._seq_len, - framework="jax", - split=split) - return loader + iter = get_data_iter(data_rng, split, data_dir, global_batch_size) + return iter def init_model_fn( self, @@ -63,10 +55,10 @@ def init_model_fn( # Initialize NanoDO transformer model cfg = DoConfig( - D=512, # model dim - H=8, # num heads + D=2048, # model dim + H=16, # num heads L=self._seq_len, - N=6, # num layers + N=12, # num layers V=self._vocab_size, F=2048, # feedforward dim dtype=jnp.float32 @@ -92,7 +84,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e5dafdd3c..5797de654 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -63,9 +63,10 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del model_state, rng, update_batch_norm + del model_state, rng, update_batch_norm, dropout_rate model = params # Convert one-hot inputs to token IDs if needed diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 986a98297..8f17553ff 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -1,21 +1,20 @@ """LM workload parent class.""" import abc +from absl import logging import math import os from typing import Dict, Optional -from absl import flags import jax import torch.distributed as dist +from absl import flags from algoperf import spec -from algoperf.workloads.lm import input_pipeline -from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS -USE_PYTORCH_DDP = "LOCAL_RANK" in os.environ +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ class BaseLmWorkload(spec.Workload): @@ -63,7 +62,7 @@ def num_eval_train_examples(self) -> int: @property def num_validation_examples(self) -> int: - return 50000 + return 50000 @property def num_test_examples(self) -> int: @@ -111,53 +110,60 @@ def glu(self) -> bool: return True @abc.abstractmethod - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): """Build an input queue for the given split.""" - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: + def _eval_batch( + self, + params: 
spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: """Evaluate the model on a single batch.""" logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False, - dropout_rate=None) - + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) + loss_dict = self.loss_fn(batch['targets'], logits) return loss_dict['summed'] - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset=True) + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True, + ) loss = 0.0 for _ in range(num_batches): @@ -168,13 +174,15 @@ def _eval_model_on_split(self, mean_loss = loss.item() / num_examples return {'loss': mean_loss} + # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. @abc.abstractmethod def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling.""" From 1c3cb6649b26c87e4bd7afd9c83fac84af9372ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 6 Oct 2025 22:15:38 +0000 Subject: [PATCH 46/82] add defaults for lm workload --- algoperf/workloads/lm/lm_jax/workload.py | 10 +++++----- algoperf/workloads/lm/workload.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 49547fcef..76739b590 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -54,13 +54,13 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: # Initialize NanoDO transformer model - cfg = DoConfig( - D=2048, # model dim - H=16, # num heads + cfg = DoConfig(u + D=self._emb_dim, # embedding dim + H=self._n_heads, # num heads L=self._seq_len, - N=12, # num layers + N=self._n_layers, # num layers V=self._vocab_size, - F=2048, # feedforward dim + F=self._mlp_dim, # feedforward dim dtype=jnp.float32 ) self._model = TransformerDo(cfg) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 8f17553ff..5cc783dba 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -21,7 +21,11 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 5 + _seq_len: int = 2048 + _emb_dim: 
int = 1024 + _n_heads: int = 8 + _n_layers: int = 12 + _mlp_dim: int = 4096 warmup_factor: float = 0.1 def __init__(self) -> None: From af91b120b2d5bd055f486aabdb3a881e28f3d231 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 01:03:33 +0000 Subject: [PATCH 47/82] refactor eval pipeline and loss fn for lm --- algoperf/workloads/lm/input_pipeline.py | 8 +- algoperf/workloads/lm/lm_jax/workload.py | 92 +++++++++++-------- algoperf/workloads/lm/lm_pytorch/workload.py | 28 ++++-- algoperf/workloads/lm/workload.py | 52 +++++++---- .../external_tuning/jax_nadamw_full_budget.py | 4 +- submission_runner.py | 2 +- 6 files changed, 116 insertions(+), 70 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index e674170e4..91d6ae53c 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -10,13 +10,13 @@ from algoperf import data_utils AUTOTUNE = tf.data.experimental.AUTOTUNE -PAD_ID = -1 +PAD_ID = tf.constant(-1, dtype=tf.int64) TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'} -SEQUENCE_LENGTH = 2048 +SEQUENCE_LENGTH = 1024 MAX_CORPUS_CHARS = 1_000_000_000 -SHUFFLE_BUFFER_SIZE = 1_000_000 +SHUFFLE_BUFFER_SIZE = 1024 VOCAB_SIZE = 50_257 @@ -74,7 +74,7 @@ def get_lm_dataset( global_batch_size: int, num_batches: Optional[int] = None, ): - """Load HF dataset and return a TF dataset.""" + """Load preprocessed TF dataset.""" if split not in TFDS_SPLIT_NAME: raise NotImplementedError diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 76739b590..c3d84104b 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,9 +1,11 @@ """LM workload implemented in Jax.""" -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import jax import jax.numpy as jnp +import optax +from flax.training import common_utils from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter @@ -54,7 +56,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: # Initialize NanoDO transformer model - cfg = DoConfig(u + cfg = DoConfig( D=self._emb_dim, # embedding dim H=self._n_heads, # num heads L=self._seq_len, @@ -84,7 +86,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed @@ -93,41 +95,58 @@ def model_fn( logits = self._model.apply({'params': params}, inputs) return logits, None - def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling in JAX.""" - # Convert one-hot labels to token IDs if needed - if len(label_batch.shape) == len(logits_batch.shape): # one-hot - label_batch = jnp.argmax(label_batch, axis=-1) - - # Reshape for sequence modeling - logits = logits_batch.reshape(-1, logits_batch.shape[-1]) - labels = label_batch.reshape(-1) - - # Compute cross-entropy loss - loss = -jnp.sum( - jax.nn.log_softmax(logits)[jnp.arange(labels.shape[0]), labels]) - - if 
mask_batch is not None: - mask = mask_batch.reshape(-1) - loss = loss * mask - n_valid = mask.sum() - else: - n_valid = labels.shape[0] + + def compute_weighted_cross_entropy( + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable + """Compute weighted cross entropy and entropy for log probs and targets. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array. + weights: array of shape [batch, length]. + label_smoothing: label smoothing constant, used to determine the on and off + values. + + Returns: + {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + """ + if logits.ndim != targets.ndim + 1: + raise ValueError( + f'Incorrect shapes. Got shape {logits.shape} logits and ' + f'{targets.shape} targets.' + ) + smoothed_targets = optax.smooth_labels( + common_utils.onehot(targets, self._vocab_size), label_smoothing + ) + per_example_losses = -jnp.sum( + smoothed_targets * jax.nn.log_softmax(logits), axis=-1 + ) + if weights is None: + weights = jnp.ones_like(targets) + per_example_losses = jnp.where(weights, per_example_losses, 0.0) + summed_loss = per_example_losses.sum() + n_valid_examples = weights.sum() return { - 'summed': loss, - 'n_valid_examples': n_valid, - 'per_example': loss / n_valid # Return per-token loss + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } - def is_output_params(self, param_name: str) -> bool: - """Return whether the given parameter is an output parameter.""" - return param_name.contains('output') + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) + def _eval_batch(self, params: spec.ParameterContainer, @@ -140,5 +159,6 @@ def _eval_batch(self, targets = batch['targets'] # Calculate cross-entropy loss - loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) - return loss + # TODO(kasimbeg): add weights? 
+ loss_metrics = self.compute_weighted_cross_entropy(logits, targets) + return loss_metrics diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 5797de654..ddf99204d 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,18 +1,19 @@ """LM workload implemented in PyTorch.""" -from typing import Dict, Iterator, Optional, Tuple +from itertools import islice +from typing import Any, Dict, Iterator, Optional, Tuple import jax import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from itertools import islice -from algoperf import data_utils -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec + +from algoperf import data_utils, param_utils, pytorch_utils, spec +from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( + ModelConfig, + Transformer, +) from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_pytorch.plainlm_model import Transformer, ModelConfig USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -153,6 +154,7 @@ def _eval_batch(self, reduction='sum' ) return loss + def loss_fn( self, label_batch: spec.Tensor, @@ -181,3 +183,15 @@ def loss_fn( 'n_valid_examples': n_valid, 'per_example': loss } + +def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 5cc783dba..b1fa3d2a8 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -1,13 +1,11 @@ """LM workload parent class.""" import abc -from absl import logging import math import os -from typing import Dict, Optional +from typing import Any, Dict, Optional import jax -import torch.distributed as dist from absl import flags from algoperf import spec @@ -21,7 +19,7 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 2048 + _seq_len: int = 1024 _emb_dim: int = 1024 _n_heads: int = 8 _n_layers: int = 12 @@ -169,24 +167,38 @@ def _eval_model_on_split( repeat_final_dataset=True, ) - loss = 0.0 + eval_metrics = {} for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch, model_state, rng) - if USE_PYTORCH_DDP: - dist.all_reduce(loss) - mean_loss = loss.item() / num_examples - return {'loss': mean_loss} + metrics = self._eval_batch(params, eval_batch) + for metric_name, metric_value in metrics.items(): + if metric_name not in eval_metrics: + eval_metrics[metric_name] = 0.0 + eval_metrics[metric_name] += metric_value + eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + + return eval_results - - # Does NOT apply regularization, which is left to the submitter to do in - # `update_params`. 
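For reference, the label-smoothed loss in compute_weighted_cross_entropy above can also be written without materializing the smoothed [batch, length, vocab] target tensor: with alpha = label_smoothing and V the vocabulary size, optax.smooth_labels gives each wrong class weight alpha/V and the true class weight 1 - alpha + alpha/V, so the per-token loss equals -(1 - alpha) * log p[target] - (alpha / V) * sum_k log p[k]. A small numerical check of that identity (toy shapes and seeds, not workload values):

import jax
import jax.numpy as jnp
import optax
from flax.training import common_utils

V, a = 11, 0.1
logits = jax.random.normal(jax.random.PRNGKey(0), (3, 5, V))
targets = jax.random.randint(jax.random.PRNGKey(1), (3, 5), 0, V)
log_p = jax.nn.log_softmax(logits)
dense = -jnp.sum(
    optax.smooth_labels(common_utils.onehot(targets, V), a) * log_p, axis=-1)
sparse = -((1 - a) * jnp.take_along_axis(log_p, targets[..., None], axis=-1).squeeze(-1)
           + (a / V) * log_p.sum(axis=-1))
print(jnp.allclose(dense, sparse, atol=1e-5))  # True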
@abc.abstractmethod + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0, - ) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling.""" + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in JAX.""" + return self.compute_weighted_cross_entropy( + logits_batch, + label_batch, + weights=mask_batch, + label_smoothing=label_smoothing + ) + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') \ No newline at end of file diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 6e40cdab1..9b4192de2 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -11,7 +11,7 @@ Tuple, Union, ) - +from absl import logging # isort: on import chex import jax @@ -395,7 +395,7 @@ def get_batch_size(workload_name): elif workload_name == 'wmt': return 128 elif workload_name == 'lm': - return 128 + return 64 elif workload_name == 'mnist': return 16 else: diff --git a/submission_runner.py b/submission_runner.py index 1c51ec58f..64a67e781 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -53,7 +53,7 @@ # Environment variables os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disables tensorRT, cuda warnings. # disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false --xla_dump_to=/logs/xla_dump_jax_lm_10_06_bsz64_seq1028 --xla_dump_hlo_as_proto' # TODO(znado): make a nicer registry of workloads that lookup in. 
BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR From 6b55adf5a65184d09d62a734db8fd3b6c33fdce2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 03:41:09 +0000 Subject: [PATCH 48/82] refactor evaluation pipeline for lm --- algoperf/workloads/lm/input_pipeline.py | 15 ++++++++++--- algoperf/workloads/lm/lm_jax/workload.py | 28 +++++++----------------- algoperf/workloads/lm/workload.py | 26 ++++++++++++---------- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 91d6ae53c..3a2e46923 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -5,6 +5,7 @@ from typing import Optional import jax +import numpy as np import tensorflow as tf from algoperf import data_utils @@ -106,7 +107,7 @@ def get_lm_dataset( repeated_sequences_dataset = shuffled_sequences_ds.repeat() ds = repeated_sequences_dataset.batch( global_batch_size, drop_remainder=False - ).take(100).prefetch(tf.data.experimental.AUTOTUNE) + ).prefetch(tf.data.experimental.AUTOTUNE) elif split == 'eval_train': ds = batch_with_padding( sequences_ds, @@ -115,7 +116,11 @@ def get_lm_dataset( 'inputs': (global_batch_size, None), 'targets': (global_batch_size, None), }, - ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation + ) + ds = ds.map(lambda x: {'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation elif split == 'validation': ds = batch_with_padding( sequences_ds, @@ -124,6 +129,10 @@ def get_lm_dataset( 'inputs': (global_batch_size, None), 'targets': (global_batch_size, None), }, - ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size + ) + ds = ds.map(lambda x: {'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index c3d84104b..bb19d6c30 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -28,26 +28,13 @@ def _build_input_queue(self, """Build an input queue using pre-cached FineWeb dataset.""" del num_batches del repeat_final_dataset - loader = get_data_iter( + ds = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=global_batch_size) - loader = map(jax_sharding_utils.shard_along_batch_dim, loader) - return loader - - def _build_hf_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): - """Build an input queue using HuggingFace FineWeb dataset.""" - del num_batches - del repeat_final_dataset - iter = get_data_iter(data_rng, split, data_dir, global_batch_size) - return iter + ds = map(jax_sharding_utils.shard_along_batch_dim, ds) + return ds def init_model_fn( self, @@ -156,9 +143,10 @@ def _eval_batch(self, """Evaluate the model on a single batch.""" logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - targets = batch['targets'] - # Calculate cross-entropy loss # TODO(kasimbeg): add weights? 
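The eval_train and validation splits above attach a 'weights' entry so that positions padded by batch_with_padding drop out of the loss; a toy sketch of that mask (values here are made up, PAD_ID is -1):

import numpy as np

inputs = np.array([[15, 87, 3, 9],
                   [-1, -1, -1, -1]])  # second row is padding from batch_with_padding
weights = np.where(inputs == -1, 0.0, 1.0)
print(weights)        # [[1. 1. 1. 1.], [0. 0. 0. 0.]]
print(weights.sum())  # 4.0 -- only real tokens feed 'n_valid_examples'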
- loss_metrics = self.compute_weighted_cross_entropy(logits, targets) - return loss_metrics + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b1fa3d2a8..b8e1ea144 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -2,6 +2,7 @@ import abc import math +import numpy as np import os from typing import Any, Dict, Optional @@ -44,11 +45,11 @@ def validation_target_value(self) -> float: return 20.0 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: - return eval_result['test/ppl'] <= self.test_target_value + return True # No test targets @property def test_target_value(self) -> float: - return 20.0 # Target perplexity + return None # No test targets @property def loss_type(self) -> spec.LossType: @@ -60,19 +61,19 @@ def num_train_examples(self) -> int: @property def num_eval_train_examples(self) -> int: - return 10000 # Subset for evaluation + return 500 # Subset for evaluation. # TODO(kasimbeg): update @property def num_validation_examples(self) -> int: - return 50000 + return 500 # TODO(kasimbeg update) @property def num_test_examples(self) -> int: - return 50000 + return 0 @property def eval_batch_size(self) -> int: - return 8 + return 32 @property def train_mean(self): @@ -84,7 +85,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 4 # 4 hours + return 3600 * 5 # 4 hours @property def eval_period_time_sec(self) -> int: @@ -93,7 +94,7 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - return 7000 + return 54000 @property def pre_ln(self) -> bool: @@ -141,7 +142,7 @@ def _eval_batch( ) loss_dict = self.loss_fn(batch['targets'], logits) - return loss_dict['summed'] + return loss_dict def _eval_model_on_split( self, @@ -170,12 +171,15 @@ def _eval_model_on_split( eval_metrics = {} for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - metrics = self._eval_batch(params, eval_batch) + metrics = self._eval_batch(params, eval_batch, model_state, rng) for metric_name, metric_value in metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + + eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + eval_results['ppl'] = np.exp(eval_results['loss']) + print(eval_results) return eval_results From 210d671fe7e78502cf321a52c0dfcafe6fa3580c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 03:43:42 +0000 Subject: [PATCH 49/82] remove temporary flag for hlo dumps --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 64a67e781..1c51ec58f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -53,7 +53,7 @@ # Environment variables os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disables tensorRT, cuda warnings. 
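Returning to _eval_model_on_split above: each batch contributes a summed loss and a valid-token count, _normalize_eval_metrics divides by the accumulated denominator, and perplexity is the exponential of that mean loss. A toy sketch with made-up per-batch totals:

import numpy as np

totals = {'loss': 0.0, 'denominator': 0.0}
for summed, n_valid in [(410.0, 100.0), (205.0, 50.0)]:  # summed loss, valid tokens
  totals['loss'] += summed
  totals['denominator'] += n_valid
denom = totals.pop('denominator')
results = {k: v / denom for k, v in totals.items()}
results['ppl'] = float(np.exp(results['loss']))
print(results)  # {'loss': 4.1, 'ppl': ~60.3}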
# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false --xla_dump_to=/logs/xla_dump_jax_lm_10_06_bsz64_seq1028 --xla_dump_hlo_as_proto' +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR From 0ad7788302fdc8c5ea22379a0f15c047f75988af Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 03:45:45 +0000 Subject: [PATCH 50/82] fix in workload target condition check --- algoperf/workloads/lm/workload.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b8e1ea144..374b91ce6 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -38,7 +38,7 @@ def target_metric_name(self) -> str: return 'ppl' def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] > self.validation_target_value + return eval_result['validation/ppl'] < self.validation_target_value @property def validation_target_value(self) -> float: @@ -178,9 +178,7 @@ def _eval_model_on_split( eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) - eval_results['ppl'] = np.exp(eval_results['loss']) - print(eval_results) - + eval_results['ppl'] = np.exp(eval_results['loss']) return eval_results @abc.abstractmethod From 01921d5f6d0068e1d92808ad224b50ab19b60b15 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 8 Oct 2025 23:36:28 +0000 Subject: [PATCH 51/82] fix in mlp for glu --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index ed469e1bd..bd7213620 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -44,6 +44,10 @@ def __call__(self, x_BxLxD: jax.Array): linear = partial( nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype ) + # Adjust hidden dimension to keep the number of parameters invariant to + # the activation function used since the GLU MLP has 3 * hidden_dim * D + # parameters instead of 2 * hidden_dim * D parameters + hidden_dim = cfg.F * 2 / 3 hidden_dim = cfg.multiple_of * ( (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of ) From e42045083c1d28aba5fa5dd15f6993d4a8312880 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 10 Oct 2025 04:14:40 +0000 Subject: [PATCH 52/82] Fix OOM error in weighted cross entropy calculation --- algoperf/workloads/lm/lm_jax/workload.py | 44 +++++++++++-------- .../workloads/lm/lm_pytorch/plainlm_model.py | 2 +- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index bb19d6c30..c052794c8 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -84,21 +84,19 @@ def model_fn( def compute_weighted_cross_entropy( - self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.1, - ) -> Dict[str, spec.Tensor]: # differentiable + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy 
and entropy for log probs and targets. - Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: array of shape [batch, length]. label_smoothing: label smoothing constant, used to determine the on and off values. - Returns: {'summed': scalar summed loss, 'n_valid_examples': scalar number of valid examples in batch, 'per_example': 1-d array of per-example losses} @@ -108,18 +106,26 @@ def compute_weighted_cross_entropy( f'Incorrect shapes. Got shape {logits.shape} logits and ' f'{targets.shape} targets.' ) - smoothed_targets = optax.smooth_labels( - common_utils.onehot(targets, self._vocab_size), label_smoothing - ) - - per_example_losses = -jnp.sum( - smoothed_targets * jax.nn.log_softmax(logits), axis=-1 - ) - if weights is None: - weights = jnp.ones_like(targets) - per_example_losses = jnp.where(weights, per_example_losses, 0.0) + # Compute log probabilities + log_probs = jax.nn.log_softmax(logits, axis=-1) + # Extract log probability of the target class + # Shape: [batch, length] + target_log_probs = jnp.take_along_axis( + log_probs, + targets[..., None], + axis=-1 + ).squeeze(-1) + # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p) + # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. + confidence = 1.0 - label_smoothing + smoothing_term = label_smoothing / self._vocab_size + per_example_losses = -1.0 * (confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1)) + if weights is not None: + per_example_losses = jnp.where(weights, per_example_losses, 0.0) + n_valid_examples = weights.sum() + else: + n_valid_examples = targets.shape[0] * targets.shape[1] summed_loss = per_example_losses.sum() - n_valid_examples = weights.sum() return { 'summed': summed_loss, 'n_valid_examples': n_valid_examples, diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 627a0e16d..225b98767 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -16,7 +16,7 @@ class ModelConfig: n_layers: int n_heads: int rmsnorm_eps: float = 1e-6 - tie_embeddings: bool = False + tie_embeddings: bool = True class MLP(nn.Module): From 3b31ad521d0037f80391de31582517cc291877be Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 10 Oct 2025 04:15:27 +0000 Subject: [PATCH 53/82] fix issue with checkpointing bool --- algoperf/checkpoint_utils.py | 47 ++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 2c8441d9c..00f05ba5d 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -5,7 +5,7 @@ """ import os -from typing import Sequence, Tuple +from typing import Sequence, Tuple, Optional import numpy as np import torch @@ -14,7 +14,8 @@ from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint from tensorflow.io import gfile # pytype: disable=import-error - +import orbax.checkpoint as ocp +from orbax.checkpoint.type_handlers import NumpyHandler from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -29,6 +30,48 @@ int, ] +class BoolHandler(NumpyHandler): + """ + An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. + It works by treating the scalar as a 0-dimensional array. 
+ """ + + def typestr(self) -> str: + """Unique string identifier for this handler.""" + return 'np.bool_' + + async def serialize( + self, + values: Sequence[np.bool_], + infos: Sequence, + args: Optional[Sequence[ocp.SaveArgs]] = None, + ): + """ + Serializes a sequence of np.bool_ scalars by first converting them + to 0-dim numpy arrays and then calling the parent NumpyHandler. + """ + # Convert each scalar np.bool_ to a 0-dimensional np.ndarray + array_values = [np.asarray(v, dtype=np.bool_) for v in values] + # Use the parent class's robust serialization logic + return await super().serialize(array_values, infos, args) + + async def deserialize( + self, + infos: Sequence, + args: Optional[Sequence[ocp.RestoreArgs]] = None, + ) -> Sequence[np.bool_]: + """ + Deserializes into a sequence of np.bool_ scalars by calling the + parent handler and then converting the resulting 0-dim arrays. + """ + # Parent deserialize will return a sequence of 0-dimensional np.ndarray + results = await super().deserialize(infos, args) + + # Convert each 0-d array back to an np.bool_ scalar using .item() + scalar_results = [np.bool_(r.item()) for r in results] + return scalar_results + +ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True) def maybe_restore_checkpoint( framework: str, From bbc114fe730e351d3a721d78f6165f343e4c25cb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Oct 2025 04:33:15 +0000 Subject: [PATCH 54/82] increase buffer size --- algoperf/workloads/lm/input_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 3a2e46923..2fd27113a 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -17,7 +17,7 @@ SEQUENCE_LENGTH = 1024 MAX_CORPUS_CHARS = 1_000_000_000 -SHUFFLE_BUFFER_SIZE = 1024 +SHUFFLE_BUFFER_SIZE = 100_000 VOCAB_SIZE = 50_257 From 2b162e8d87603ad7ae2ac5020a26fd8c2bce974d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Oct 2025 04:42:19 +0000 Subject: [PATCH 55/82] remove _eval_batch from jax workload --- algoperf/workloads/lm/lm_jax/workload.py | 17 ----------- algoperf/workloads/lm/workload.py | 36 +++++++++++------------- 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index c052794c8..801b1e0b4 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -139,20 +139,3 @@ def _normalize_eval_metrics( del num_examples eval_denominator = total_metrics.pop('denominator') return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) - - - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - # Calculate cross-entropy loss - # TODO(kasimbeg): add weights? 
- metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - return { - 'loss': metrics['summed'], - 'denominator': metrics['n_valid_examples'], - } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 374b91ce6..f5d2cda38 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -124,25 +124,6 @@ def _build_input_queue( ): """Build an input queue for the given split.""" - def _eval_batch( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - ) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False, - ) - - loss_dict = self.loss_fn(batch['targets'], logits) - return loss_dict def _eval_model_on_split( self, @@ -181,6 +162,23 @@ def _eval_model_on_split( eval_results['ppl'] = np.exp(eval_results['loss']) return eval_results + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + # Calculate cross-entropy loss + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } + + @abc.abstractmethod def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] From 617e1a3f3810bb73f15d998c25e54fa79ef04315 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Oct 2025 04:45:44 +0000 Subject: [PATCH 56/82] add todo for pytorch _eval_batch cleanup --- algoperf/workloads/lm/lm_pytorch/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index ddf99204d..71a8afd93 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -148,6 +148,7 @@ def _eval_batch(self, if targets.dim() == 3: # one-hot loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) else: # token IDs + # TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch. 
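One possible shape for the helper that the TODO above refers to, mirroring the JAX compute_weighted_cross_entropy; this is a hedged sketch, not code from the repository (the name, signature, and defaults are assumptions):

import torch
import torch.nn.functional as F

def compute_weighted_cross_entropy(logits, targets, weights=None,
                                    label_smoothing=0.1):
  # logits: [batch, length, vocab]; targets: [batch, length] (long);
  # weights: optional [batch, length] mask of valid positions.
  vocab_size = logits.shape[-1]
  log_probs = F.log_softmax(logits, dim=-1)
  target_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
  per_example = -((1.0 - label_smoothing) * target_log_probs
                  + (label_smoothing / vocab_size) * log_probs.sum(dim=-1))
  if weights is not None:
    per_example = per_example * weights
    n_valid = weights.sum()
  else:
    n_valid = torch.tensor(float(per_example.numel()), device=logits.device)
  return {'summed': per_example.sum(),
          'n_valid_examples': n_valid,
          'per_example': per_example}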
loss = torch.nn.functional.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1), From 64ea658c04a2d13db75ab0b8fd1204cfe43f8746 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 01:34:05 +0000 Subject: [PATCH 57/82] add target setting algorithm for fineweb edu lm workload --- .../jax_nadamw_target_setting.py | 427 ++++++++++++++++++ .../fineweb_edu_lm/tuning_search_space.json | 11 + 2 files changed, 438 insertions(+) create mode 100644 algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py create mode 100644 algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py new file mode 100644 index 000000000..9fa6823d5 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -0,0 +1,427 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +# isort: on +import chex +import jax +import jax.numpy as jnp +import optax + +from algoperf import jax_sharding_utils, spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. 
+ """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + step_hint = 0.75 * step_hint + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) + optimizer_state = opt_init_fn(params_zeros_like) + + return optimizer_state, opt_update_fn + + +def train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, +): + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + dropout_rate=dropout_rate, + ) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container + ) + # Compute mean loss and grad + loss = summed_loss / n_valid_examples + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + 
current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + dropout_rate = hyperparameters.dropout_rate + + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = ( + jax_sharding_utils.get_batch_dim_sharding() + ) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated, # label_smoothing + replicated, # dropout_rate + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated, # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings, + ) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = ( + jitted_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, + ) + ) + + # Log loss, grad_norm. + if global_step % 1 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step + ) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. 
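As a side note on scale_by_nadam above, the update it implements at step t (g_t the gradient, symbols as in the code, debias=True), written out, is:

  m_t    = b1 * m_{t-1} + (1 - b1) * g_t
  v_t    = b2 * v_{t-1} + (1 - b2) * g_t**2
  mbar_t = b1 * m_t + (1 - b1) * g_t            # Nesterov look-ahead on the updated first moment
  mhat_t = mbar_t / (1 - b1**t);  vhat_t = v_t / (1 - b2**t)
  update = mhat_t / (sqrt(vhat_t + eps_root) + eps)

with decoupled weight decay and the learning-rate scaling applied afterwards by the optax.chain in nadamw.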
+ if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'lm': + return 64 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json new file mode 100644 index 000000000..e6945d69a --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json @@ -0,0 +1,11 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.0003955553491092581, + "one_minus_beta1": 0.06124602712, + "beta2": 0.9535169492059872, + "weight_decay": 0.03268700808664715, + "warmup_factor": 0.0375 + } +] \ No newline at end of file From b38ade083282348a5000220bf3ca11f79b5c9e9a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 01:34:49 +0000 Subject: [PATCH 58/82] update step hint for lm workload --- algoperf/workloads/lm/input_pipeline.py | 2 +- algoperf/workloads/lm/workload.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 2fd27113a..04bd90216 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -17,7 +17,7 @@ SEQUENCE_LENGTH = 1024 MAX_CORPUS_CHARS = 1_000_000_000 -SHUFFLE_BUFFER_SIZE = 100_000 +SHUFFLE_BUFFER_SIZE = 1000 VOCAB_SIZE = 50_257 diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index f5d2cda38..b9610f919 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -57,7 +57,7 @@ def loss_type(self) -> spec.LossType: @property def num_train_examples(self) -> int: - return 1000000 # Example size + return 8_749_870 # sequences of 1024 tokens each @property def num_eval_train_examples(self) -> int: @@ -94,7 +94,7 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. 
steps the baseline can do in the allowed runtime budget.""" - return 54000 + return 72000 @property def pre_ln(self) -> bool: @@ -159,7 +159,7 @@ def _eval_model_on_split( eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) - eval_results['ppl'] = np.exp(eval_results['loss']) + eval_results['ppl'] = np.exp(eval_results['loss']).item() return eval_results From 65369f239a3110748890473cef415dcb087fe6c0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 01:36:42 +0000 Subject: [PATCH 59/82] update target --- algoperf/workloads/lm/workload.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b9610f919..0bed0b34d 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -38,11 +38,11 @@ def target_metric_name(self) -> str: return 'ppl' def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] < self.validation_target_value + return eval_result['validation/ppl'] <= self.validation_target_value @property def validation_target_value(self) -> float: - return 20.0 # Target perplexity + return 25.5477 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return True # No test targets @@ -73,7 +73,7 @@ def num_test_examples(self) -> int: @property def eval_batch_size(self) -> int: - return 32 + return 64 @property def train_mean(self): @@ -85,16 +85,16 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 5 # 4 hours + return 3600 * 5 # 4 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: - return 600 # 10 minutes + return 600 # 10 minutes TODO(kasimbeg): update @property def step_hint(self) -> int: """Approx. 
steps the baseline can do in the allowed runtime budget.""" - return 72000 + return 72_000 @property def pre_ln(self) -> bool: From 6171b2d2fb6a0243993b10d03f0c284eb2c86801 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 23:04:56 +0000 Subject: [PATCH 60/82] update eval split sizes for lm workload and target setting point --- algoperf/workloads/lm/input_pipeline.py | 4 ++-- algoperf/workloads/lm/workload.py | 8 ++++---- .../fineweb_edu_lm/jax_nadamw_target_setting.py | 4 ++-- .../fineweb_edu_lm/tuning_search_space.json | 12 ++++++------ 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 04bd90216..79fdfbbcb 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -120,7 +120,7 @@ def get_lm_dataset( ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) - ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'validation': ds = batch_with_padding( sequences_ds, @@ -133,6 +133,6 @@ def get_lm_dataset( ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) - ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) return ds diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 0bed0b34d..466769d96 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -61,11 +61,11 @@ def num_train_examples(self) -> int: @property def num_eval_train_examples(self) -> int: - return 500 # Subset for evaluation. # TODO(kasimbeg): update + return 10_000 # Subset for evaluation. @property def num_validation_examples(self) -> int: - return 500 # TODO(kasimbeg update) + return 100_000 # sequences @property def num_test_examples(self) -> int: @@ -85,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 5 # 4 hours TODO(kasimbeg): update + return 3600 * 14 # 14 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: - return 600 # 10 minutes TODO(kasimbeg): update + return 1200 # 20 minutes TODO(kasimbeg): update @property def step_hint(self) -> int: diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py index 9fa6823d5..1fef611ac 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -170,8 +170,8 @@ def init_optimizer_state( del rng def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. step_hint = 0.75 * step_hint + # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( init_value=0.0, @@ -343,7 +343,7 @@ def update_params( ) # Log loss, grad_norm. 
- if global_step % 1 == 0 and workload.metrics_logger is not None: + if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step ) diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json index e6945d69a..ce0f75623 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json @@ -1,11 +1,11 @@ [ { "dropout_rate": 0.0, - "label_smoothing": 0.1, - "learning_rate": 0.0003955553491092581, - "one_minus_beta1": 0.06124602712, - "beta2": 0.9535169492059872, - "weight_decay": 0.03268700808664715, - "warmup_factor": 0.0375 + "label_smoothing": 0.0, + "learning_rate": 0.00038418421332238876, + "one_minus_beta1": 0.01564758865, + "beta2": 0.992362328914093, + "weight_decay": 0.25551270901641954, + "warmup_factor": 0.05 } ] \ No newline at end of file From d7a885cd7270dfbd8203f41276c3313ddbd63929 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 17 Oct 2025 04:01:11 +0000 Subject: [PATCH 61/82] Porting workload input pipeline to torch - Added `limit_tf_threads` parameter to `pytorch_init` to control TensorFlow threading based on workload type. Dataloader was going OOM otherwise. - Updated input pipeline to support "None" for weights (for memory). - Modified Transformer model's `forward` method to optionally return loss during training. Should be better to fuse the loss later. - Adjusted torch LM workload configuration for model dimensions and parameters to match jax. - Updated transformers version in `pyproject.toml`, older version seems unavailable. --- algoperf/pytorch_utils.py | 6 +- algoperf/workloads/lm/input_pipeline.py | 8 +- .../workloads/lm/lm_pytorch/plainlm_model.py | 74 ++++++----- algoperf/workloads/lm/lm_pytorch/workload.py | 118 ++++++------------ pyproject.toml | 2 +- submission_runner.py | 3 +- 6 files changed, 90 insertions(+), 121 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index af09e67fc..c7537a884 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -27,7 +27,7 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: return use_pytorch_ddp, rank, device, n_gpus -def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: +def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads = True) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. @@ -39,7 +39,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: if use_pytorch_ddp: # Avoid tf input pipeline creating too many threads. - if rank != 0: + if rank != 0 and limit_tf_threads: tf.config.threading.set_intra_op_parallelism_threads(1) tf.config.threading.set_inter_op_parallelism_threads(1) @@ -47,10 +47,8 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: profiler.set_local_rank(rank) # Only log once (for local rank == 0). if rank != 0: - def logging_pass(*args): pass - logging.info = logging_pass # Initialize the process group. 
dist.init_process_group('nccl') diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 04bd90216..ee54427e1 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -107,7 +107,13 @@ def get_lm_dataset( repeated_sequences_dataset = shuffled_sequences_ds.repeat() ds = repeated_sequences_dataset.batch( global_batch_size, drop_remainder=False - ).prefetch(tf.data.experimental.AUTOTUNE) + ) + ds = ds.map(lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': None, + }) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'eval_train': ds = batch_with_padding( sequences_ds, diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 225b98767..5de5bf310 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -159,7 +159,7 @@ def __init__(self, cfg): if cfg.tie_embeddings: self.tie_weights() - def forward(self, x): + def forward(self, x, targets=None): # x: (bsz, seqlen) x = self.embed_tokens(x) # (bsz, seqlen, dim) L = x.shape[1] @@ -178,7 +178,12 @@ def forward(self, x): for layer in self.layers: x = layer(x, freqs_cis) # (bsz, seqlen, dim) - return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: + loss = F.cross_entropy( + out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100) + return out, loss + return out def predict(self, x, k=1): """Generate k tokens autoregressively. @@ -190,11 +195,6 @@ def predict(self, x, k=1): Returns: Tuple of (input_ids, predicted_ids) """ - # For debugging - predictions = [] - - batch_size = x.shape[0] - seq_len = x.shape[1] # Store original input original_input = x.clone() @@ -202,6 +202,7 @@ def predict(self, x, k=1): # Generate k tokens autoregressively for i in range(k): + # Get logits for the entire sequence logits = self(generated_input) @@ -212,24 +213,20 @@ def predict(self, x, k=1): # This is a common issue - the model gets stuck repeating the last token last_token_id = generated_input[:, -1] next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) - - # Print top 5 tokens for debugging - if i == 0: - print("\nPyTorch detailed prediction:") - top5_values, top5_indices = torch.topk(next_token_logits[0], 5) - for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): - prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() - print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") - + # Get the most likely token next_token = torch.argmax(next_token_logits, dim=-1) - predictions.append(next_token.item()) # Append the predicted token to the sequence next_token = next_token.unsqueeze(1) # Add sequence dimension generated_input = torch.cat([generated_input, next_token], dim=1) - print(f" Full predictions step by step: {predictions}") + # For debugging, print predictions for the first item in the batch + print("\nPyTorch detailed prediction (first item in batch):") + predicted_sequence = generated_input[0, -k:].tolist() + print(f" Predicted token IDs: {predicted_sequence}") + for i, token_id in enumerate(predicted_sequence): + print(f" Step {i+1}: Predicted token {token_id}") # Return all tokens, not just the last k return original_input, generated_input[:, -k:] @@ -269,30 +266,43 @@ def count_params(self, non_embedding=True): def main(): 
print("Initializing transformer model and running forward pass...") - seq_length = 512 + seq_length = 1024 # Define model configuration config = ModelConfig( - vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece + vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece seq_len=seq_length, # Maximum sequence length - dim=768, # Embedding dimension + dim=1024, # Embedding dimension expand=4.0, # MLP expansion factor n_layers=12, # Number of transformer layers - n_heads=12, # Number of attention heads + n_heads=8, # Number of attention heads rmsnorm_eps=1e-6, # RMSNorm epsilon tie_embeddings=True # Tie embedding and output weights ) - def tie_weights(self): - self.lm_head.weight = self.embed_tokens.weight + # Instantiate the model + model = Transformer(config) + print(f"Model has {model.count_params():,} parameters.") - def count_params(self, non_embedding=True): - n_params = sum(p.numel() for p in self.parameters()) - if non_embedding: - n_params -= self.embed_tokens.weight.numel() - if (not self.lm_head.weight - is self.embed_tokens.weight): # if no weight tying - n_params -= self.lm_head.weight.numel() - return n_params + # Create some random input data + batch_size = 2 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + + # Move data to the same device as the model + if torch.cuda.is_available(): + input_ids = input_ids.cuda() + + # Run a forward pass + print(f"Running forward pass with input shape: {input_ids.shape}") + logits = model(input_ids) + print(f"Output logits shape: {logits.shape}") + # Run prediction + print("Running prediction...") + original_input, predicted_ids = model.predict(input_ids[:, :10], k=5) + print(f"Original input shape for prediction: {original_input.shape}") + print(f"Predicted IDs shape: {predicted_ids.shape}") + print(f"Predicted IDs: {predicted_ids}") +if __name__ == "__main__": + main() diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 71a8afd93..e4c03c4f5 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -14,6 +14,7 @@ Transformer, ) from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.input_pipeline import get_data_iter USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -37,10 +38,11 @@ def init_model_fn( cfg = ModelConfig( vocab_size=self._vocab_size, seq_len=self._seq_len, - dim=512, # Model dimension - expand=4, # MLP expansion factor - n_layers=6, # Number of transformer layers - n_heads=8, # Number of attention heads + dim=self._emb_dim, # Model dimension + expand=self._mlp_dim // self._emb_dim, # MLP expansion factor + # FIXME(rka97): fix expansion factor + n_layers=self._n_layers, # Number of transformer layers + n_heads=self._n_heads, # Number of attention heads rmsnorm_eps=1e-6, tie_embeddings=True ) @@ -65,7 +67,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state, rng, update_batch_norm, dropout_rate model = params @@ -87,10 +89,8 @@ def _build_input_queue( num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" - from algoperf.workloads.lm.input_pipeline import get_lm_dataset local_batch_size 
= global_batch_size // N_GPUS - - loader = get_lm_dataset( + loader = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, @@ -99,33 +99,12 @@ def _build_input_queue( ) if USE_PYTORCH_DDP: loader = islice(loader, RANK, None, N_GPUS) - seq_len = self._seq_len - weights = None - dtype = torch.int32 - is_train = split == 'train' - for batch in loader: - inputs = batch['inputs'] - targets = batch['targets'] - - if USE_PYTORCH_DDP: - if not is_train: - # During eval, the batch size of the remainder might be different - per_device_batch_size = torch.tensor( - targets.shape[0], dtype=dtype, device=DEVICE) - dist.broadcast(per_device_batch_size, src=0) - local_batch_size = per_device_batch_size.item() - # Broadcast to all devices - #dist.broadcast(inputs, src=0) - #dist.broadcast(targets, src=0) - - if weights is None: - weights = torch.ones((local_batch_size, seq_len), device=DEVICE) batch = { - 'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype), - 'targets': torch.tensor(targets, device=DEVICE, dtype=dtype), - 'weights': weights, + 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), + 'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64), + 'weights': None, } yield batch @@ -133,66 +112,41 @@ def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - model = params - logits, _ = self.model_fn( - model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - - # Handle both one-hot and token ID targets - targets = batch['targets'] - if targets.dim() == 3: # one-hot - loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) - else: # token IDs - # TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch. 
- loss = torch.nn.functional.cross_entropy( - logits.view(-1, logits.size(-1)), - targets.view(-1), - reduction='sum' - ) - return loss - - def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + # FIXME(rka97): Implement label smoothing + def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in PyTorch.""" - vocab_size = logits_batch.shape[-1] + vocab_size = logits.size(-1) - if len(label_batch.shape) == len(logits_batch.shape): + if len(labels.shape) == len(logits.shape): # One-hot labels - log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1) - loss = -torch.sum(label_batch * log_probs, dim=-1) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(labels * log_probs, dim=-1) else: # Dense labels loss = torch.nn.functional.cross_entropy( - logits_batch, - label_batch, + logits.view(-1, vocab_size), + labels.view(-1), reduction='none') - if mask_batch is not None: - loss = loss * mask_batch + loss = loss.view_as(labels) + + if weights is not None: + loss = loss * weights - n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + n_valid = weights.sum() if weights is not None else torch.tensor(labels.numel(), dtype=torch.float32, device=labels.device) return { 'summed': loss.sum(), 'n_valid_examples': n_valid, - 'per_example': loss + 'per_example': loss, } -def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, Any] - ) -> Dict[str, float]: - """Normalize eval metrics.""" - del num_examples - if USE_PYTORCH_DDP: - for metric in total_metrics.values(): - dist.all_reduce(metric) - total_metrics = {k: v.item() for k, v in total_metrics.items()} - eval_denominator = total_metrics.pop('denominator') - return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 76bcfb7ca..b93c9794e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] -lm = ["transformers==4.25.4", "datasets==3.6.0"] +lm = ["transformers==4.26", "datasets==3.6.0"] # Frameworks jax_core_deps = [ diff --git a/submission_runner.py b/submission_runner.py index 1c51ec58f..1c50cd6d9 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -784,7 +784,8 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + limit_tf_threads = (base_workload != 'lm') + pytorch_init(USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads) # TODO: remove once issue resolved. 
if FLAGS.pytorch_eval_num_workers != 0: From 1f0439aaf6bbb7f0670a4dc0564a41c86e509270 Mon Sep 17 00:00:00 2001 From: rka97 Date: Sat, 18 Oct 2025 06:41:33 +0000 Subject: [PATCH 62/82] Fix OOM bug in lm eval --- algoperf/random_utils.py | 4 +-- algoperf/workloads/lm/lm_pytorch/workload.py | 28 ++++++++++++++----- algoperf/workloads/lm/workload.py | 15 +++++++--- .../pytorch_nadamw_full_budget.py | 2 ++ 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index 1dc773e80..07efa2bdf 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e4c03c4f5..b2ffac18e 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,5 +1,6 @@ """LM workload implemented in PyTorch.""" +import contextlib from itertools import islice from typing import Any, Dict, Iterator, Optional, Tuple @@ -8,7 +9,7 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import data_utils, param_utils, pytorch_utils, spec +from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( ModelConfig, Transformer, @@ -72,12 +73,23 @@ def model_fn( del model_state, rng, update_batch_norm, dropout_rate model = params - # Convert one-hot inputs to token IDs if needed - inputs = augmented_and_preprocessed_input_batch['inputs'] - if inputs.dim() == 3: # one-hot encoded + # Set model to eval or train mode based on the mode parameter + if mode == spec.ForwardPassMode.EVAL: + model.eval() + elif mode == spec.ForwardPassMode.TRAIN: + model.train() + contexts = { + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + } + with contexts[mode](): + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded inputs = inputs.argmax(dim=-1) - logits = model(inputs) + logits = model(inputs) + return logits, None def _build_input_queue( @@ -90,12 +102,14 @@ def _build_input_queue( repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" local_batch_size = global_batch_size // N_GPUS + # In DDP mode, pass local_device_count=1 to prevent shard_and_maybe_pad_np + # from seeing all GPUs via torch.cuda.device_count() loader = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=local_batch_size, - num_batches=num_batches + num_batches=num_batches, ) if USE_PYTORCH_DDP: loader = islice(loader, RANK, None, N_GPUS) @@ -104,7 +118,7 @@ def _build_input_queue( batch = { 'inputs': torch.tensor(batch['inputs'], 
device=DEVICE, dtype=dtype), 'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64), - 'weights': None, + 'weights': torch.tensor(batch['weights'], device=DEVICE, dtype=torch.float32) if batch['weights'] is not None else None, } yield batch diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 466769d96..73e784f3a 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -73,7 +73,7 @@ def num_test_examples(self) -> int: @property def eval_batch_size(self) -> int: - return 64 + return 256 @property def train_mean(self): @@ -138,6 +138,11 @@ def _eval_model_on_split( ) -> Dict[str, float]: """Run a full evaluation of the model.""" num_batches = int(math.ceil(num_examples / global_batch_size)) + + # Handle edge case where num_batches is 0 (e.g., test split with 0 examples) + if num_batches == 0: + return {'loss': 0.0, 'ppl': 1.0} + if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( @@ -159,7 +164,7 @@ def _eval_model_on_split( eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) - eval_results['ppl'] = np.exp(eval_results['loss']).item() + eval_results['ppl'] = np.exp(eval_results['loss']).item() return eval_results @@ -173,9 +178,11 @@ def _eval_batch(self, params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) # Calculate cross-entropy loss metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + # CRITICAL: Detach tensors to free computation graph and activations + # Without this, all intermediate activations are kept in memory! return { - 'loss': metrics['summed'], - 'denominator': metrics['n_valid_examples'], + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), } diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..9b544e380 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -372,6 +372,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From b11c1938447c3cb68a9635ffa75648ec97c3e5d2 Mon Sep 17 00:00:00 2001 From: rka97 Date: Sat, 18 Oct 2025 20:42:14 +0000 Subject: [PATCH 63/82] repeat dataset --- algoperf/workloads/lm/input_pipeline.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 7a55e81fd..ab7c64479 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -98,14 +98,12 @@ def get_lm_dataset( }, num_parallel_calls=AUTOTUNE, ) - - # batch + sequences_ds = sequences_ds.repeat() if split == 'train': - shuffled_sequences_ds = sequences_ds.shuffle( + ds = sequences_ds.shuffle( SHUFFLE_BUFFER_SIZE, seed=shuffle_seed ) - repeated_sequences_dataset = shuffled_sequences_ds.repeat() - ds = repeated_sequences_dataset.batch( + ds = ds.batch( global_batch_size, drop_remainder=False ) ds = ds.map(lambda x: { From 42d1d1a5379257015ca93847d539ef710e307067 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 20 Oct 2025 17:26:07 +0000 Subject: [PATCH 64/82] label 
smoothing default fix --- algoperf/workloads/lm/input_pipeline.py | 7 ++++--- algoperf/workloads/lm/lm_jax/workload.py | 4 +--- algoperf/workloads/lm/workload.py | 8 ++++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 79fdfbbcb..1716399c0 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -119,7 +119,8 @@ def get_lm_dataset( ) ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], - 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) + }) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'validation': ds = batch_with_padding( @@ -132,7 +133,7 @@ def get_lm_dataset( ) ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], - 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) + }) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) - return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 801b1e0b4..91a2592b4 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -4,8 +4,6 @@ import jax import jax.numpy as jnp -import optax -from flax.training import common_utils from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter @@ -88,7 +86,7 @@ def compute_weighted_cross_entropy( logits: spec.Tensor, targets: spec.Tensor, weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.1, + label_smoothing: float = 0.0, ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy and entropy for log probs and targets. 
Args: diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 466769d96..21a7e8fbb 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -2,11 +2,11 @@ import abc import math -import numpy as np import os from typing import Any, Dict, Optional import jax +import numpy as np from absl import flags from algoperf import spec @@ -85,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 14 # 14 hours TODO(kasimbeg): update + return 3600 * 14 # 14 hours @property def eval_period_time_sec(self) -> int: - return 1200 # 20 minutes TODO(kasimbeg): update + return 1200 # 20 minutes @property def step_hint(self) -> int: @@ -172,7 +172,7 @@ def _eval_batch(self, logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) # Calculate cross-entropy loss - metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + metrics = self.loss_fn(batch['targets'], logits, batch['weights']) return { 'loss': metrics['summed'], 'denominator': metrics['n_valid_examples'], From d95f2bfb6290a47c0a81580f4c9e90e84c6bbd53 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 00:05:02 +0000 Subject: [PATCH 65/82] Make sure to take the correct number of batches in lm --- algoperf/workloads/lm/input_pipeline.py | 26 +++++++++++--------- algoperf/workloads/lm/lm_jax/workload.py | 13 +++++----- algoperf/workloads/lm/lm_pytorch/workload.py | 10 ++++---- algoperf/workloads/lm/workload.py | 12 ++++----- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index ab7c64479..68cb54d1e 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -54,14 +54,14 @@ def batch_with_padding( def get_data_iter(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - global_batch_size: int, + batch_size: int, num_batches: Optional[int] = None,): - ds = get_lm_dataset(data_rng, split, data_dir, global_batch_size, num_batches) + ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches) it = map( functools.partial( - data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size ), ds, ) @@ -72,7 +72,7 @@ def get_lm_dataset( data_rng: jax.random.PRNGKey, split: str, data_dir: str, - global_batch_size: int, + batch_size: int, num_batches: Optional[int] = None, ): """Load preprocessed TF dataset.""" @@ -104,8 +104,9 @@ def get_lm_dataset( SHUFFLE_BUFFER_SIZE, seed=shuffle_seed ) ds = ds.batch( - global_batch_size, drop_remainder=False + batch_size, drop_remainder=False ) + ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.map(lambda x: { 'inputs': x['inputs'], 'targets': x['targets'], @@ -115,12 +116,13 @@ def get_lm_dataset( elif split == 'eval_train': ds = batch_with_padding( sequences_ds, - global_batch_size, + batch_size, padded_shapes={ - 'inputs': (global_batch_size, None), - 'targets': (global_batch_size, None), + 'inputs': (batch_size, None), + 'targets': (batch_size, None), }, ) + ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) @@ -128,15 +130,15 @@ def get_lm_dataset( elif split == 'validation': ds = batch_with_padding( sequences_ds, - global_batch_size, + 
batch_size, padded_shapes={ - 'inputs': (global_batch_size, None), - 'targets': (global_batch_size, None), + 'inputs': (batch_size, None), + 'targets': (batch_size, None), }, ) + ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) - return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 801b1e0b4..760b87306 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -23,16 +23,17 @@ def _build_input_queue(self, split: str, data_dir: str, global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None): """Build an input queue using pre-cached FineWeb dataset.""" - del num_batches - del repeat_final_dataset + del cache, repeat_final_dataset ds = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, - global_batch_size=global_batch_size) + batch_size=global_batch_size, + num_batches=num_batches) ds = map(jax_sharding_utils.shard_along_batch_dim, ds) return ds @@ -73,7 +74,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index b2ffac18e..b5f93ce2e 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -98,17 +98,17 @@ def _build_input_queue( split: str, data_dir: str, global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" + del cache, repeat_final_dataset local_batch_size = global_batch_size // N_GPUS - # In DDP mode, pass local_device_count=1 to prevent shard_and_maybe_pad_np - # from seeing all GPUs via torch.cuda.device_count() loader = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, - global_batch_size=local_batch_size, + batch_size=local_batch_size, num_batches=num_batches, ) if USE_PYTORCH_DDP: diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 73e784f3a..8f17fd930 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -4,7 +4,7 @@ import math import numpy as np import os -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Iterator import jax from absl import flags @@ -119,9 +119,10 @@ def _build_input_queue( split: str, data_dir: str, global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, num_batches: Optional[int] = None, - repeat_final_dataset: bool = False, - ): + ) -> Iterator[Dict[str, Any]]: """Build an input queue for the given split.""" @@ -150,8 +151,7 @@ def _eval_model_on_split( split, data_dir, 
global_batch_size, - num_batches, - repeat_final_dataset=True, + num_batches=num_batches ) eval_metrics = {} @@ -175,7 +175,7 @@ def _eval_batch(self, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" logits, _ = self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False, 0.0) # Calculate cross-entropy loss metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) # CRITICAL: Detach tensors to free computation graph and activations From 0dc16db94c20973d2ae1f31231cfae91bef0801b Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 00:22:23 +0000 Subject: [PATCH 66/82] Properly handle repetition in LM training and evaluation splits --- algoperf/workloads/lm/input_pipeline.py | 4 +++- algoperf/workloads/lm/workload.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index aeeff80a9..e701d1bcb 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -98,7 +98,6 @@ def get_lm_dataset( }, num_parallel_calls=AUTOTUNE, ) - sequences_ds = sequences_ds.repeat() if split == 'train': ds = sequences_ds.shuffle( SHUFFLE_BUFFER_SIZE, seed=shuffle_seed @@ -107,6 +106,7 @@ def get_lm_dataset( batch_size, drop_remainder=False ) ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() ds = ds.map(lambda x: { 'inputs': x['inputs'], 'targets': x['targets'], @@ -123,6 +123,7 @@ def get_lm_dataset( }, ) ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) @@ -138,6 +139,7 @@ def get_lm_dataset( }, ) ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 1f966ca03..8f17fd930 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -2,11 +2,11 @@ import abc import math +import numpy as np import os from typing import Any, Dict, Optional, Iterator import jax -import numpy as np from absl import flags from algoperf import spec @@ -85,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 14 # 14 hours + return 3600 * 14 # 14 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: - return 1200 # 20 minutes + return 1200 # 20 minutes TODO(kasimbeg): update @property def step_hint(self) -> int: From 7edb702c2f4a4eb8a88bd35a40ea7a255e6f09d8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Oct 2025 01:41:23 +0000 Subject: [PATCH 67/82] move eval_batch from shared class to framework specific classes since pytorch calls detatch --- algoperf/workloads/lm/lm_jax/workload.py | 17 +++++++++++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 18 ++++++++++++++++++ algoperf/workloads/lm/workload.py | 18 ------------------ 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 91a2592b4..3809c8258 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ 
b/algoperf/workloads/lm/lm_jax/workload.py @@ -130,6 +130,23 @@ def compute_weighted_cross_entropy( 'per_example': per_example_losses, } + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + # Calculate cross-entropy loss + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + # CRITICAL: Detach tensors to free computation graph and activations + # Without this, all intermediate activations are kept in memory! + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } + def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] ) -> Dict[str, float]: diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index b2ffac18e..4d87c5ba7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -153,6 +153,24 @@ def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tenso 'per_example': loss, } + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + # Calculate cross-entropy loss + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + # CRITICAL: Detach tensors to free computation graph and activations + # Without this, all intermediate activations are kept in memory! + return { + 'loss': metrics['summed'].detatch(), + 'denominator': metrics['n_valid_examples'].detatch(), + } + def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] ) -> Dict[str, float]: diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b4e7d7bb6..1c4c53fc8 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -168,24 +168,6 @@ def _eval_model_on_split( return eval_results - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - # Calculate cross-entropy loss - metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - # CRITICAL: Detach tensors to free computation graph and activations - # Without this, all intermediate activations are kept in memory! 
- return { - 'loss': metrics['summed'].detach(), - 'denominator': metrics['n_valid_examples'].detach(), - } - - @abc.abstractmethod def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] From 73e3ea6679d9b7fc5ccf6d75a2e4b1c9021e8d22 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 02:13:04 +0000 Subject: [PATCH 68/82] Refactor imports and clean up unused code in LM workload and related modules --- algoperf/checkpoint_utils.py | 7 +- algoperf/workloads/lm/input_pipeline.py | 1 - algoperf/workloads/lm/lm_jax/models.py | 3 +- algoperf/workloads/lm/lm_jax/nanodo_model.py | 2 +- algoperf/workloads/lm/lm_pytorch/models.py | 1 + .../workloads/lm/lm_pytorch/plainlm_model.py | 9 +- algoperf/workloads/lm/lm_pytorch/workload.py | 9 +- .../lm/tests/test_build_input_queue_jax.py | 60 --------- .../lm/tests/test_build_input_queue_torch.py | 86 ------------- .../lm/tests/test_hf_input_pipeline.py | 116 ------------------ .../workloads/lm/tests/test_linear_model.py | 39 ------ algoperf/workloads/lm/workload.py | 4 +- .../external_tuning/jax_nadamw_full_budget.py | 2 +- dataset/dataset_setup.py | 1 - 14 files changed, 18 insertions(+), 322 deletions(-) delete mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_jax.py delete mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_torch.py delete mode 100644 algoperf/workloads/lm/tests/test_hf_input_pipeline.py delete mode 100644 algoperf/workloads/lm/tests/test_linear_model.py diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 00f05ba5d..6d61e9d7f 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -5,17 +5,18 @@ """ import os -from typing import Sequence, Tuple, Optional +from typing import Optional, Sequence, Tuple import numpy as np +import orbax.checkpoint as ocp import torch from absl import logging from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint -from tensorflow.io import gfile # pytype: disable=import-error -import orbax.checkpoint as ocp from orbax.checkpoint.type_handlers import NumpyHandler +from tensorflow.io import gfile # pytype: disable=import-error + from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index e701d1bcb..cfa2f36cd 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -5,7 +5,6 @@ from typing import Optional import jax -import numpy as np import tensorflow as tf from algoperf import data_utils diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index 72ee5bd83..ae8d935bf 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -1,5 +1,6 @@ -from flax import linen as nn import jax.numpy as jnp +from flax import linen as nn + class LinearModel(nn.Module): vocab_size: int diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index bd7213620..9126d31e8 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -284,7 +284,7 @@ def main(): model = TransformerDo(cfg) # Print model info - print(f"\nModel Configuration:") + print("\nModel Configuration:") print(f" - Model dimension (D): {cfg.D}") print(f" - Number of heads (H): {cfg.H}") print(f" - Max sequence length (L): {cfg.L}") diff 
--git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py index 545763924..b88e457d8 100644 --- a/algoperf/workloads/lm/lm_pytorch/models.py +++ b/algoperf/workloads/lm/lm_pytorch/models.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn + class LinearLayer(nn.Module): def __init__(self, vocab_size: int): super().__init__() diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 5de5bf310..9dc8be522 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -1,10 +1,10 @@ import math -import torch -import torch.nn.functional as F -from torch import nn from dataclasses import dataclass from typing import Tuple +import torch +import torch.nn.functional as F +from torch import nn @dataclass @@ -257,8 +257,7 @@ def count_params(self, non_embedding=True): n_params = sum(p.numel() for p in self.parameters()) if non_embedding: n_params -= self.embed_tokens.weight.numel() - if (not self.lm_head.weight - is self.embed_tokens.weight): # if no weight tying + if (self.lm_head.weight is not self.embed_tokens.weight): # if no weight tying n_params -= self.lm_head.weight.numel() return n_params diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 6a59770bb..9713e84b0 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -10,12 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP from algoperf import param_utils, pytorch_utils, spec +from algoperf.workloads.lm.input_pipeline import get_data_iter from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( ModelConfig, Transformer, ) from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.input_pipeline import get_data_iter USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -162,13 +162,10 @@ def _eval_batch(self, """Evaluate the model on a single batch.""" logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - # Calculate cross-entropy loss metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - # CRITICAL: Detach tensors to free computation graph and activations - # Without this, all intermediate activations are kept in memory! return { - 'loss': metrics['summed'].detatch(), - 'denominator': metrics['n_valid_examples'].detatch(), + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), } def _normalize_eval_metrics( diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_jax.py b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py deleted file mode 100644 index b9adc70d2..000000000 --- a/algoperf/workloads/lm/tests/test_build_input_queue_jax.py +++ /dev/null @@ -1,60 +0,0 @@ -import jax -import jax.numpy as jnp - -from algoperf.profiler import PassThroughProfiler -from algoperf.workloads.lm.lm_jax.workload import LmWorkload -import os - -RANK = os.environ.get('RANK', 0) - -def test_dataloader_jax(): - # Test config. 
- rng_seed = 1996 - data_dir = '/home/ak4605/data/finewebedu/' - split = 'train' - global_batch_size = 64 - dtype = jnp.int32 - seq_len = 2048 - - workload = LmWorkload() - data_rng = jax.random.PRNGKey(rng_seed) - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - for _ in range(1): - - batch = next(input_queue) - print(f"RANK {RANK} got batch") - - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - print(f"RANK {RANK} inputs.shape: {inputs.shape}") - print(f"RANK {RANK} targets.shape: {targets.shape}") - print(f"RANK {RANK} type(inputs): {type(inputs)}") - - jax.debug.inspect_array_sharding(inputs, callback=print) - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (global_batch_size, seq_len) - assert targets.shape == (global_batch_size, seq_len) - - assert jnp.equal(inputs[:, 1:], targets[:, :-1]).all() - print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") - - print(f"=== ALL TEST PASSED ===") - - -def main(): - profiler = PassThroughProfiler() - test_dataloader_jax() - - -if __name__ == '__main__': - main() diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py deleted file mode 100644 index 827272037..000000000 --- a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py +++ /dev/null @@ -1,86 +0,0 @@ -import jax -import torch - -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload - -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - - -def sync_ddp(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - -def test_dataloader_torch(): - # Test config. - rng_seed = 1996 - data_dir = '/home/ak4605/data/finewebedu/' - split = 'train' - global_batch_size = 64 - dtype = torch.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - print(f"RANK {RANK} of {N_GPUS}") - sync_ddp() - - # batch = next(input_queue) - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - # print(f"inputs: {inputs}") - - # Start test. 
- for _ in range(1): - - batch = next(input_queue) - print(f"RANK {RANK} got batch") - - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - print(f"RANK {RANK} inputs.shape: {inputs.shape}") - print(f"RANK {RANK} targets.shape: {targets.shape}") - print(f"RANK {RANK} type(inputs): {type(inputs)}") - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - assert inputs.dtype == dtype - assert targets.dtype == dtype - - print(local_batch_size, seq_len) - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:, 1:], targets[:, :-1]) - print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") - - print(f"=== ALL TEST PASSED ===") - - -def main(): - profiler = PassThroughProfiler() - print(USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS) - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - test_dataloader_torch() - - -if __name__ == '__main__': - main() diff --git a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py deleted file mode 100644 index 36bab0d02..000000000 --- a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Tests for LM HuggingFace input pipeline.""" -import os - -import jax -import jax.numpy as jnp -import torch -from transformers import GPT2Tokenizer - -from algoperf.workloads.lm.input_pipeline import get_hf_dataloader - - -def main(): - # Setup test environment - cache_dir = "/home/ak4605/data" - if not os.path.exists(cache_dir): - raise FileNotFoundError(f"Cache directory {cache_dir} not found") - - data_rng = jax.random.PRNGKey(42) - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - vocab_size = tokenizer.vocab_size - - print("Running JAX output shapes and types test...") - batch_size = 8 - seq_len = 32 - loader = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="train", - data_rng=data_rng) - inputs, targets = next(loader) - assert inputs.shape == (batch_size, seq_len, vocab_size), \ - f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" - assert targets.shape == (batch_size, seq_len, vocab_size), \ - f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" - assert inputs.dtype == jnp.float32, \ - f"Expected inputs dtype float32, got {inputs.dtype}" - assert targets.dtype == jnp.float32, \ - f"Expected targets dtype float32, got {targets.dtype}" - assert jnp.all(jnp.sum(inputs, axis=-1) == 1), "Inputs should be one-hot encoded" - assert jnp.all(jnp.sum(targets, axis=-1) == 1), "Targets should be one-hot encoded" - print("✓ JAX test passed") - - print("\nRunning Torch output shapes and types test...") - loader = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="torch", - split="train", - data_rng=data_rng) - inputs, targets = next(loader) - assert inputs.shape == (batch_size, seq_len, vocab_size), \ - f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" - assert targets.shape == (batch_size, seq_len, vocab_size), \ - f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" - assert inputs.dtype == torch.float32, \ - f"Expected inputs dtype float32, got {inputs.dtype}" - assert targets.dtype == 
torch.float32, \ - f"Expected targets dtype float32, got {targets.dtype}" - assert torch.all(torch.sum(inputs, dim=-1) == 1), "Inputs should be one-hot encoded" - assert torch.all(torch.sum(targets, dim=-1) == 1), "Targets should be one-hot encoded" - print("✓ Torch test passed") - - print("\nTesting consistent batching with same seed...") - loader1 = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="train", - data_rng=jax.random.PRNGKey(42)) - batch1 = next(loader1) - - loader2 = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="train", - data_rng=jax.random.PRNGKey(42)) - batch2 = next(loader2) - - assert jnp.array_equal(batch1[0], batch2[0]), "Input batches should be identical with same seed" - assert jnp.array_equal(batch1[1], batch2[1]), "Target batches should be identical with same seed" - print("✓ Consistent batching test passed") - - print("\nTesting eval split doesn't shuffle...") - loader1 = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="eval", - data_rng=jax.random.PRNGKey(42)) - batch1 = next(loader1) - - loader2 = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="eval", - data_rng=jax.random.PRNGKey(999)) - batch2 = next(loader2) - - assert jnp.array_equal(batch1[0], batch2[0]), "Eval inputs should be identical regardless of seed" - assert jnp.array_equal(batch1[1], batch2[1]), "Eval targets should be identical regardless of seed" - print("✓ Eval no shuffling test passed") - - print("\nAll tests passed successfully!") - - -if __name__ == "__main__": - main() diff --git a/algoperf/workloads/lm/tests/test_linear_model.py b/algoperf/workloads/lm/tests/test_linear_model.py deleted file mode 100644 index 31cd1d577..000000000 --- a/algoperf/workloads/lm/tests/test_linear_model.py +++ /dev/null @@ -1,39 +0,0 @@ -import jax -import jax.numpy as jnp -import torch - -TEST_SEQ_LEN = 512 - -def test_pytorch_linear(): - from algoperf.workloads.lm.lm_pytorch.models import LinearLayer - vocab_size = 32000 - model = LinearLayer(vocab_size) - - batch_size = 8 - seq_len = TEST_SEQ_LEN - inputs = torch.randn(batch_size, seq_len, vocab_size) - outputs = model(inputs) - - assert outputs.shape == (batch_size, seq_len, vocab_size) - assert not torch.isnan(outputs).any() - -def test_jax_linear(): - from algoperf.workloads.lm.lm_jax.models import LinearModel - - vocab_size = 32000 - seq_len = TEST_SEQ_LEN - batch_size = 8 - model = LinearModel(vocab_size) - rng = jax.random.PRNGKey(0) - params = model.init(rng, jnp.ones((1, seq_len, vocab_size))) - - inputs = jax.random.normal(rng, (batch_size, seq_len, vocab_size)) - outputs = model.apply(params, inputs) - - assert outputs.shape == (batch_size, seq_len, vocab_size) - assert not jnp.isnan(outputs).any() - -if __name__ == '__main__': - test_pytorch_linear() - test_jax_linear() - print("All tests passed!") diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index e0af589e3..f15e4b8a7 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -2,11 +2,11 @@ import abc import math -import numpy as np import os -from typing import Any, Dict, Optional, Iterator +from typing import Any, Dict, Iterator, Optional import jax +import numpy as np from absl import flags from algoperf import spec diff --git 
a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 9b4192de2..ccfa25360 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -11,7 +11,7 @@ Tuple, Union, ) -from absl import logging + # isort: on import chex import jax diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 872e2ef0b..8fecaf419 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -81,7 +81,6 @@ from transformers import AutoTokenizer import functools -import itertools import os import shutil import subprocess From 91988af436f452021e98a61e5144a89d14418e20 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 02:20:31 +0000 Subject: [PATCH 69/82] pass linter checks --- algoperf/checkpoint_utils.py | 73 +-- algoperf/pytorch_utils.py | 6 +- algoperf/workloads/lm/input_pipeline.py | 54 +- algoperf/workloads/lm/lm_jax/models.py | 20 - algoperf/workloads/lm/lm_jax/nanodo_model.py | 575 +++++++++--------- algoperf/workloads/lm/lm_jax/workload.py | 134 ++-- algoperf/workloads/lm/lm_pytorch/models.py | 19 - .../workloads/lm/lm_pytorch/plainlm_model.py | 556 +++++++++-------- algoperf/workloads/lm/lm_pytorch/workload.py | 182 +++--- algoperf/workloads/lm/workload.py | 33 +- algoperf/workloads/workloads.py | 286 ++++----- dataset/dataset_setup.py | 171 +++--- submission_runner.py | 28 +- 13 files changed, 1091 insertions(+), 1046 deletions(-) delete mode 100644 algoperf/workloads/lm/lm_jax/models.py delete mode 100644 algoperf/workloads/lm/lm_pytorch/models.py diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 6d61e9d7f..af05111cd 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -31,49 +31,52 @@ int, ] + class BoolHandler(NumpyHandler): + """ + An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. + It works by treating the scalar as a 0-dimensional array. + """ + + def typestr(self) -> str: + """Unique string identifier for this handler.""" + return 'np.bool_' + + async def serialize( + self, + values: Sequence[np.bool_], + infos: Sequence, + args: Optional[Sequence[ocp.SaveArgs]] = None, + ): """ - An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. - It works by treating the scalar as a 0-dimensional array. + Serializes a sequence of np.bool_ scalars by first converting them + to 0-dim numpy arrays and then calling the parent NumpyHandler. """ + # Convert each scalar np.bool_ to a 0-dimensional np.ndarray + array_values = [np.asarray(v, dtype=np.bool_) for v in values] + # Use the parent class's robust serialization logic + return await super().serialize(array_values, infos, args) + + async def deserialize( + self, + infos: Sequence, + args: Optional[Sequence[ocp.RestoreArgs]] = None, + ) -> Sequence[np.bool_]: + """ + Deserializes into a sequence of np.bool_ scalars by calling the + parent handler and then converting the resulting 0-dim arrays. 
+ """ + # Parent deserialize will return a sequence of 0-dimensional np.ndarray + results = await super().deserialize(infos, args) - def typestr(self) -> str: - """Unique string identifier for this handler.""" - return 'np.bool_' + # Convert each 0-d array back to an np.bool_ scalar using .item() + scalar_results = [np.bool_(r.item()) for r in results] + return scalar_results - async def serialize( - self, - values: Sequence[np.bool_], - infos: Sequence, - args: Optional[Sequence[ocp.SaveArgs]] = None, - ): - """ - Serializes a sequence of np.bool_ scalars by first converting them - to 0-dim numpy arrays and then calling the parent NumpyHandler. - """ - # Convert each scalar np.bool_ to a 0-dimensional np.ndarray - array_values = [np.asarray(v, dtype=np.bool_) for v in values] - # Use the parent class's robust serialization logic - return await super().serialize(array_values, infos, args) - - async def deserialize( - self, - infos: Sequence, - args: Optional[Sequence[ocp.RestoreArgs]] = None, - ) -> Sequence[np.bool_]: - """ - Deserializes into a sequence of np.bool_ scalars by calling the - parent handler and then converting the resulting 0-dim arrays. - """ - # Parent deserialize will return a sequence of 0-dimensional np.ndarray - results = await super().deserialize(infos, args) - - # Convert each 0-d array back to an np.bool_ scalar using .item() - scalar_results = [np.bool_(r.item()) for r in results] - return scalar_results ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True) + def maybe_restore_checkpoint( framework: str, optimizer_state: spec.OptimizerState, diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index c7537a884..e24b0f141 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -27,7 +27,9 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: return use_pytorch_ddp, rank, device, n_gpus -def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads = True) -> None: +def pytorch_init( + use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads=True +) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. @@ -47,8 +49,10 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_ profiler.set_local_rank(rank) # Only log once (for local rank == 0). if rank != 0: + def logging_pass(*args): pass + logging.info = logging_pass # Initialize the process group. 
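# Note on the call below: torch.distributed's default env:// rendezvous is
# used, so the launcher is expected to export MASTER_ADDR, MASTER_PORT, RANK
# and WORLD_SIZE (torchrun also sets LOCAL_RANK for per-node device
# placement). An assumed invocation, adjust flags to your setup:
#
#   torchrun --nproc_per_node=<num_gpus> submission_runner.py --framework=pytorch ...
#
# Each spawned worker reaches this point once; only rank 0 keeps
# logging.info output, since non-zero ranks replace it with a no-op above.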
dist.init_process_group('nccl') diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cfa2f36cd..3007371fc 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -50,14 +50,15 @@ def batch_with_padding( return padded_batched_dataset -def get_data_iter(data_rng: jax.random.PRNGKey, +def get_data_iter( + data_rng: jax.random.PRNGKey, split: str, data_dir: str, batch_size: int, - num_batches: Optional[int] = None,): - + num_batches: Optional[int] = None, +): ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches) - + it = map( functools.partial( data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size @@ -67,6 +68,7 @@ def get_data_iter(data_rng: jax.random.PRNGKey, return iter(it) + def get_lm_dataset( data_rng: jax.random.PRNGKey, split: str, @@ -78,7 +80,7 @@ def get_lm_dataset( if split not in TFDS_SPLIT_NAME: raise NotImplementedError - shuffle_seed = jax.random.randint(data_rng, (), -2**31, 2**31-1) + shuffle_seed = jax.random.randint(data_rng, (), -(2**31), 2**31 - 1) data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split]) tokens_ds = tf.data.Dataset.load(data_dir) @@ -98,19 +100,17 @@ def get_lm_dataset( num_parallel_calls=AUTOTUNE, ) if split == 'train': - ds = sequences_ds.shuffle( - SHUFFLE_BUFFER_SIZE, seed=shuffle_seed - ) - ds = ds.batch( - batch_size, drop_remainder=False - ) + ds = sequences_ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=shuffle_seed) + ds = ds.batch(batch_size, drop_remainder=False) ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.repeat() - ds = ds.map(lambda x: { - 'inputs': x['inputs'], - 'targets': x['targets'], - 'weights': None, - }) + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': None, + } + ) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'eval_train': ds = batch_with_padding( @@ -123,10 +123,13 @@ def get_lm_dataset( ) ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.repeat() - ds = ds.map(lambda x: {'inputs': x['inputs'], - 'targets': x['targets'], - 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) - }) + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'validation': ds = batch_with_padding( @@ -139,9 +142,12 @@ def get_lm_dataset( ) ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.repeat() - ds = ds.map(lambda x: {'inputs': x['inputs'], - 'targets': x['targets'], - 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) - }) + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) return ds diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py deleted file mode 100644 index ae8d935bf..000000000 --- a/algoperf/workloads/lm/lm_jax/models.py +++ /dev/null @@ -1,20 +0,0 @@ -import jax.numpy as jnp -from flax import linen as nn - - -class LinearModel(nn.Module): - vocab_size: int - - @nn.compact - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - x = nn.Dense( - 10, - kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros - )(inputs) - return nn.Dense( - self.vocab_size, - kernel_init=nn.initializers.normal(0.02), - 
bias_init=nn.initializers.zeros, - name="output" - )(x) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index 9126d31e8..a1644f569 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -1,4 +1,7 @@ -# Self-contained version of the DecoderOnly Transformer from NanoDO +""" +Originally based on code from the NanoDO repository under the Apache 2.0 license: +https://github.com/google-deepmind/nanodo +""" import dataclasses from functools import partial @@ -7,343 +10,345 @@ import jax.numpy as jnp from flax import linen as nn -# =========== Transformer Decoder-only Model ========== - - @dataclasses.dataclass class DoConfig: - """Hyper-parameters for Transformer decoder-only.""" - - D: int # model/embed dim = qkv dim - H: int # num attention heads - L: int # max context/sequence length - N: int # number of transformer block layers - V: int # vocab size - F: int # FF inner dimension - kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() - embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( - 1.0, "fan_in", "normal", out_axis=0 - ) - dtype: jnp.dtype = jnp.float32 - rmsnorm_epsilon: float = 1e-6 - multiple_of: int = 256 - tie_embeddings: bool = True # Whether to tie input and output embeddings + """Hyper-parameters for Transformer decoder-only.""" + + D: int # model/embed dim = qkv dim + H: int # num attention heads + L: int # max context/sequence length + N: int # number of transformer block layers + V: int # vocab size + F: int # FF inner dimension + kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() + embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0 + ) + dtype: jnp.dtype = jnp.float32 + rmsnorm_epsilon: float = 1e-6 + multiple_of: int = 256 + tie_embeddings: bool = True # Whether to tie input and output embeddings class Mlp(nn.Module): - """Multilayer perceptron with GLU activation.""" - - cfg: DoConfig - - @nn.compact - def __call__(self, x_BxLxD: jax.Array): - cfg = self.cfg - # Use Xavier uniform initialization explicitly - xavier_init = nn.initializers.xavier_uniform() - linear = partial( - nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype - ) - # Adjust hidden dimension to keep the number of parameters invariant to - # the activation function used since the GLU MLP has 3 * hidden_dim * D - # parameters instead of 2 * hidden_dim * D parameters - hidden_dim = cfg.F * 2 / 3 - hidden_dim = cfg.multiple_of * ( - (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of - ) - # Double the hidden dimension for GLU - x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) - # Apply GLU activation - x_BxLxF = nn.glu(x_BxLx2F, axis=-1) - x_BxLxD = linear(cfg.D)(x_BxLxF) - return x_BxLxD - -@partial(jax.jit, static_argnums=(0,1,2)) -def init_rope(dim=256, seq_len=128, n_heads=4): - """Initialize rotary embeddings.""" - def precompute_freqs_cis_jax(dim, end, theta=10000.0): - inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) - t = jnp.arange(end) / 1.0 - freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) - return jnp.stack([ - jnp.cos(freqs)[None, :, None, :], - jnp.sin(freqs)[None, :, None, :] - ], axis=3) - - freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) - return freqs_cis.transpose(0, 1, 2, 4, 3) + """Multilayer perceptron with GLU activation.""" -@jax.jit -def apply_rope(q, k, freqs_cis): - """Apply rotary 
embeddings to Q and K.""" - def rotate_tensor(x): - # Split into real and imaginary parts - x_r2 = x.reshape(*x.shape[:-1], -1, 2) - L = x.shape[1] - freqs = freqs_cis[:, :L, :, :, :] + cfg: DoConfig - # Apply rotation - rotated_x_r2 = jnp.stack([ - x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], - x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1] - ], axis=-1) + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + # Use Xavier uniform initialization explicitly + xavier_init = nn.initializers.xavier_uniform() + linear = partial( + nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + ) + # Adjust hidden dimension to keep the number of parameters invariant to + # the activation function used since the GLU MLP has 3 * hidden_dim * D + # parameters instead of 2 * hidden_dim * D parameters + hidden_dim = cfg.F * 2 / 3 + hidden_dim = cfg.multiple_of * ( + (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = linear(cfg.D)(x_BxLxF) + return x_BxLxD - return rotated_x_r2.reshape(*x.shape) - # Apply rotation to Q and K separately - rotated_q = rotate_tensor(q) - rotated_k = rotate_tensor(k) +@partial(jax.jit, static_argnums=(0, 1, 2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack( + [jnp.cos(freqs)[None, :, None, :], jnp.sin(freqs)[None, :, None, :]], + axis=3, + ) - return rotated_q, rotated_k + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) -class CausalAttn(nn.Module): - """Causal attention layer with rotary embeddings.""" +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack( + [ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1], + ], + axis=-1, + ) - cfg: DoConfig + return rotated_x_r2.reshape(*x.shape) - def setup(self): - cfg = self.cfg - assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}" - self.Dh = cfg.D // cfg.H + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) - # Initialize rotary embeddings - self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + return rotated_q, rotated_k - # Maps D -> (H, Dh) - self.multilinear = partial( - nn.DenseGeneral, - axis=-1, - features=(cfg.H, self.Dh), - kernel_init=cfg.kernel_init, - use_bias=False, - dtype=cfg.dtype, - ) - self.multilinear_query = self.multilinear(name="query") - self.multilinear_key = self.multilinear(name="key") - self.multilinear_value = self.multilinear(name="value") - self.output_projection = nn.DenseGeneral( - features=cfg.D, - name="attn_out_proj", - # axis=(-2, -1), # - kernel_init=cfg.kernel_init, - use_bias=False, - dtype=cfg.dtype, - ) +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: DoConfig + + def setup(self): + cfg = self.cfg + assert cfg.D % cfg.H == 0, f'D 
{cfg.D} not divisible by H {cfg.H}' + self.Dh = cfg.D // cfg.H + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.H, self.Dh), + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + self.multilinear_query = self.multilinear(name='query') + self.multilinear_key = self.multilinear(name='key') + self.multilinear_value = self.multilinear(name='value') + self.output_projection = nn.DenseGeneral( + features=cfg.D, + name='attn_out_proj', + # axis=(-2, -1), # + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) - def __call__(self, x_BxLxD: jax.Array): - cfg = self.cfg + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg - # Project inputs to Q, K, V - q_BxLxHxDh = self.multilinear_query(x_BxLxD) - k_BxLxHxDh = self.multilinear_key(x_BxLxD) - v_BxLxHxDh = self.multilinear_value(x_BxLxD) + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) - # Apply rotary embeddings to Q and K - q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) - # Scale queries - q_BxLxHxDh /= self.Dh**0.5 + # Scale queries + q_BxLxHxDh /= self.Dh**0.5 - # Compute attention scores - att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh) + # Compute attention scores + att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) - # Causal attention mask - L = x_BxLxD.shape[1] - mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) - # Apply mask and softmax - _NEG_INF = jnp.finfo(cfg.dtype).min - att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) - att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) - att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) - # Compute attention output - out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", att_BxHxLxL, v_BxLxHxDh) + # Compute attention output + out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh) - # Reshape and project output - out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) - # Output projection - out_BxLxD = self.output_projection(out_BxLxD) + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) - return out_BxLxD + return out_BxLxD class TBlock(nn.Module): - """Transformer Block.""" + """Transformer Block.""" - docfg: DoConfig + docfg: DoConfig - @nn.compact - def __call__(self, in_BxLxD: jax.Array): - cfg = self.docfg + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg - # x = x + attn( attn_norm(x) ) - x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( - in_BxLxD - ) - x_BxLxD = CausalAttn(cfg)(x_BxLxD) - x_BxLxD += in_BxLxD + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD - 
# x = x + mlp( mlp_norm(x) ) - z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( - x_BxLxD - ) - z_BxLxD = Mlp(cfg)(z_BxLxD) + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) - return x_BxLxD + z_BxLxD + return x_BxLxD + z_BxLxD class TransformerDo(nn.Module): - """Transformer decoder-only.""" - - docfg: DoConfig - - def setup(self): - cfg = self.docfg - self.embed = nn.Embed( - num_embeddings=cfg.V, - features=cfg.D, - embedding_init=cfg.embed_init, - ) - - self.blocks = [TBlock(cfg) for _ in range(cfg.N)] - self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) - - # Output projection - tied to input embeddings if configured - if cfg.tie_embeddings: - self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) - else: - self.output_proj = nn.Dense( - cfg.V, - kernel_init=cfg.embed_init, - dtype=cfg.dtype, - name="output_proj" - ) - - def __call__(self, y_BxL: jax.Array): - # For training on concatenated examples. - y_BxLxD = self.embed(y_BxL) - for block in self.blocks: - y_BxLxD = block(y_BxLxD) - y_BxLxD = self.out_ln(y_BxLxD) - logits_BxLxV = self.output_proj(y_BxLxD) - return logits_BxLxV - - def predict(self, y_BxL: jax.Array, k: int = 1): - """Generate k tokens autoregressively. - - Args: - y_BxL: Input token sequence of shape (batch_size, seq_len) - k: Number of tokens to predict - - Returns: - Tuple of (input_ids, predicted_ids) - """ - cfg = self.docfg - batch_size = y_BxL.shape[0] - seq_len = y_BxL.shape[1] - - # Store original input - original_input = y_BxL - - # Make sure we don't exceed the model's context length - if seq_len + k > cfg.L: - raise ValueError( - f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" - ) - - # Generate k tokens autoregressively - for _ in range(k): - # Get logits for the entire sequence - logits = self(y_BxL) - - # Get the logits for the last token in each sequence - next_token_logits = logits[:, -1, :] - - # Get the most likely token - next_token = jnp.argmax(next_token_logits, axis=-1) - - # Append the predicted token to the sequence - y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) - - # Return original input and the k predicted tokens - return original_input, y_BxL[:, -k:] + """Transformer decoder-only.""" + docfg: DoConfig -# =========== Demo Code ========== + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.V, + features=cfg.D, + embedding_init=cfg.embed_init, + ) + self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) -def main(): - """Create and run the DecoderOnly Transformer model.""" - # Initialize model configuration with smaller parameters for demo - B, L = (2, 128) # Batch size, sequence length - cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) - model = TransformerDo(cfg) - - # Print model info - print("\nModel Configuration:") - print(f" - Model dimension (D): {cfg.D}") - print(f" - Number of heads (H): {cfg.H}") - print(f" - Max sequence length (L): {cfg.L}") - print(f" - Number of layers (N): {cfg.N}") - print(f" - Vocabulary size (V): {cfg.V}") - print(f" - Feed forward dimension (F): {cfg.F}") - - # Create random input tokens (simulated token IDs) - rng_key = jax.random.PRNGKey(42) - input_rng, init_rng = jax.random.split(rng_key) - - # Generate random token IDs (integers between 0 and vocab_size-1) - x_BxL = 
jax.random.randint( - input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 - ) + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.V, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj' + ) - # Initialize model parameters - print("\nInitializing model parameters...") - params = model.init(init_rng, x_BxL) + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV - # Print parameter count - param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) - print(f"Total parameters: {param_count:,}") + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. - # Make a prediction (forward pass) - print("\nRunning forward pass...") - logits = model.apply(params, x_BxL) + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict - # Print output shape and sample values - print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)") - print(f"Output data type: {logits.dtype}") + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + seq_len = y_BxL.shape[1] - # Print sample logits (first 5 positions of the first sequence) - print("\nSample logits (first sequence, first 5 positions, first 5 values):") - for position in range(min(5, L)): - print(f" Position {position}: {logits[0, position, :5]}") + # Store original input + original_input = y_BxL - # Get predictions (token with highest logit at each position) - predictions = jnp.argmax(logits, axis=-1) - print("\nPredicted token IDs (first sequence, first 10 positions):") - print(predictions[0, :10]) + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.L: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + ) - # Test the predict function - print("\nTesting predict function...") - # Use a shorter - short_seq = x_BxL[:, :10] - print(f"Input sequence shape: {short_seq.shape}") + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) - # Predict 5 tokens - k = 5 - original, predicted = model.apply(params, short_seq, k, method=model.predict) + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] - # Get predictions (token with highest logit at each position) - predictions = jnp.argmax(logits, axis=-1) - print("\nPredicted token IDs (first sequence, first 10 positions):") - print(predictions[0, :10]) + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) - print("\nDone!") + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] -if __name__ == "__main__": - main() + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + model = TransformerDo(cfg) + + # Print model info + 
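  # For reference, with tied embeddings the parameter count printed further
  # below can be computed directly from the config. A hypothetical helper
  # (not part of the model code) that mirrors the layer definitions above:
  def count_params_from_cfg(cfg: DoConfig) -> int:
    # GLU MLP hidden width: F rounded up to a multiple of cfg.multiple_of.
    f_hidden = cfg.multiple_of * ((cfg.F + cfg.multiple_of - 1) // cfg.multiple_of)
    per_block = (
      4 * cfg.D * cfg.D       # q, k, v and attention output projections
      + 3 * cfg.D * f_hidden  # GLU MLP: D -> 2*f_hidden and f_hidden -> D
      + 2 * cfg.D             # two RMSNorm scales per block
    )
    # embeddings + transformer blocks + final RMSNorm (output proj is tied)
    return cfg.V * cfg.D + cfg.N * per_block + cfg.D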
print('\nModel Configuration:') + print(f' - Model dimension (D): {cfg.D}') + print(f' - Number of heads (H): {cfg.H}') + print(f' - Max sequence length (L): {cfg.L}') + print(f' - Number of layers (N): {cfg.N}') + print(f' - Vocabulary size (V): {cfg.V}') + print(f' - Feed forward dimension (F): {cfg.F}') + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + ) + + # Initialize model parameters + print('\nInitializing model parameters...') + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f'Total parameters: {param_count:,}') + + # Make a prediction (forward pass) + print('\nRunning forward pass...') + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print( + f'\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)' + ) + print(f'Output data type: {logits.dtype}') + + # Print sample logits (first 5 positions of the first sequence) + print('\nSample logits (first sequence, first 5 positions, first 5 values):') + for position in range(min(5, L)): + print(f' Position {position}: {logits[0, position, :5]}') + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + # Test the predict function + print('\nTesting predict function...') + # Use a shorter + short_seq = x_BxL[:, :10] + print(f'Input sequence shape: {short_seq.shape}') + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + print('\nDone!') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index ad7eac8aa..5b736fad7 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -16,47 +16,52 @@ class LmWorkload(BaseLmWorkload): """LM JAX workload.""" - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None): + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ): """Build an input queue using pre-cached FineWeb dataset.""" del cache, repeat_final_dataset ds = get_data_iter( - data_rng=data_rng, - split=split, - data_dir=data_dir, - batch_size=global_batch_size, - num_batches=num_batches) + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=global_batch_size, + num_batches=num_batches, + ) ds = map(jax_sharding_utils.shard_along_batch_dim, ds) return ds def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - + self, + 
rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: # Initialize NanoDO transformer model cfg = DoConfig( - D=self._emb_dim, # embedding dim - H=self._n_heads, # num heads - L=self._seq_len, - N=self._n_layers, # num layers - V=self._vocab_size, - F=self._mlp_dim, # feedforward dim - dtype=jnp.float32 + D=self._emb_dim, # embedding dim + H=self._n_heads, # num heads + L=self._seq_len, + N=self._n_layers, # num layers + V=self._vocab_size, + F=self._mlp_dim, # feedforward dim + dtype=jnp.float32, ) self._model = TransformerDo(cfg) input_shape = (1, self._seq_len) # For token IDs params_rng, init_rng = jax.random.split(rng) - variables = jax.jit(self._model.init)({'params': params_rng}, - jnp.ones(input_shape, jnp.int32)) + variables = jax.jit(self._model.init)( + {'params': params_rng}, jnp.ones(input_shape, jnp.int32) + ) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -65,14 +70,15 @@ def init_model_fn( return params, model_state def model_fn( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed @@ -81,14 +87,13 @@ def model_fn( logits = self._model.apply({'params': params}, inputs) return logits, None - def compute_weighted_cross_entropy( - self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0, - ) -> Dict[str, spec.Tensor]: # differentiable + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy and entropy for log probs and targets. Args: logits: [batch, length, num_classes] float array. @@ -110,15 +115,15 @@ def compute_weighted_cross_entropy( # Extract log probability of the target class # Shape: [batch, length] target_log_probs = jnp.take_along_axis( - log_probs, - targets[..., None], - axis=-1 + log_probs, targets[..., None], axis=-1 ).squeeze(-1) # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p) # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. 
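    # Spelled out: with smoothing rate α over V classes, the target
    # distribution is q(y) = (1 - α) * 1[y == target] + α / V, so
    #   -Σ_y q(y) log p(y) = -(1 - α) * log p[target] - (α / V) * Σ_y log p(y),
    # and (α / V) * Σ_y log p(y) equals α * mean(log_p). A quick numeric
    # check of the identity (illustrative only):
    #   import numpy as np
    #   V, a, t = 5, 0.1, 2
    #   p = np.full(V, 0.1); p[t] = 0.6        # any valid distribution
    #   q = np.full(V, a / V); q[t] += 1 - a
    #   lhs = -(q * np.log(p)).sum()
    #   rhs = -((1 - a) * np.log(p[t]) + (a / V) * np.log(p).sum())
    #   assert np.isclose(lhs, rhs)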
confidence = 1.0 - label_smoothing smoothing_term = label_smoothing / self._vocab_size - per_example_losses = -1.0 * (confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1)) + per_example_losses = -1.0 * ( + confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1) + ) if weights is not None: per_example_losses = jnp.where(weights, per_example_losses, 0.0) n_valid_examples = weights.sum() @@ -131,22 +136,27 @@ def compute_weighted_cross_entropy( 'per_example': per_example_losses, } - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - # Calculate cross-entropy loss - metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - # CRITICAL: Detach tensors to free computation graph and activations - # Without this, all intermediate activations are kept in memory! - return { - 'loss': metrics['summed'], - 'denominator': metrics['n_valid_examples'], - } + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + # Calculate cross-entropy loss + metrics = self.compute_weighted_cross_entropy( + logits, batch['targets'], batch['weights'] + ) + # CRITICAL: Detach tensors to free computation graph and activations + # Without this, all intermediate activations are kept in memory! + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py deleted file mode 100644 index b88e457d8..000000000 --- a/algoperf/workloads/lm/lm_pytorch/models.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import torch.nn as nn - - -class LinearLayer(nn.Module): - def __init__(self, vocab_size: int): - super().__init__() - self.bottleneck = nn.Linear(vocab_size, 512) - self.output = nn.Linear(512, vocab_size) - self.reset_parameters() - - def reset_parameters(self): - nn.init.normal_(self.bottleneck.weight, std=0.02) - nn.init.zeros_(self.bottleneck.bias) - nn.init.normal_(self.output.weight, std=0.02) - nn.init.zeros_(self.output.bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.output(self.bottleneck(x)) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 9dc8be522..f7e7f9e62 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -1,3 +1,9 @@ +""" +Originally based on the plainLM codebase: +https://github.com/Niccolo-Ajroldi/plainLM +under the MIT license https://github.com/Niccolo-Ajroldi/plainLM/blob/main/LICENSE. 
+""" + import math from dataclasses import dataclass from typing import Tuple @@ -9,299 +15,313 @@ @dataclass class ModelConfig: - vocab_size: int - seq_len: int - dim: int - expand: float - n_layers: int - n_heads: int - rmsnorm_eps: float = 1e-6 - tie_embeddings: bool = True + vocab_size: int + seq_len: int + dim: int + expand: float + n_layers: int + n_heads: int + rmsnorm_eps: float = 1e-6 + tie_embeddings: bool = True class MLP(nn.Module): - - def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): - super().__init__() - hidden_dim = multiple_of * ( - (hidden_dim + multiple_of - 1) // multiple_of) - self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) - self.fc2 = nn.Linear(hidden_dim, dim, bias=False) - self.glu = nn.GLU(dim=2) - - # Initialize with Xavier uniform - nn.init.xavier_uniform_(self.fc1.weight) - nn.init.xavier_uniform_(self.fc2.weight) - - def forward(self, x): - # x: (bsz, T, dim) - return self.fc2(self.glu(self.fc1(x))) - - -def precompute_freqs_cis(dim: int, - end: int, - theta: float = 10000.0, - condense_ratio: int = 1): - inv_freqs = 1.0 / (theta**(torch.arange( - 0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim)) - t = torch.arange(end, dtype=torch.float32, - device=inv_freqs.device) / condense_ratio - freqs = torch.outer(t, inv_freqs).float() - return torch.stack([ - torch.cos(freqs)[None, :, None, :], - torch.sin(freqs)[None, :, None, :] - ], - dim=4) + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + + # Initialize with Xavier uniform + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1 +): + inv_freqs = 1.0 / ( + theta + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=torch.device('cpu')) + / dim + ) + ) + t = ( + torch.arange(end, dtype=torch.float32, device=inv_freqs.device) + / condense_ratio + ) + freqs = torch.outer(t, inv_freqs).float() + return torch.stack( + [torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], + dim=4, + ) def apply_rotary_emb_complex_like( - q: torch.Tensor, k: torch.Tensor, - freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - # Rotate query and key vectors using RoPE - qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() - rotated_qk_r2 = torch.stack( - [ - qk_r2[..., 0] * freqs_cis[..., 0] - - qk_r2[..., 1] * freqs_cis[..., 1], - qk_r2[..., 1] * freqs_cis[..., 0] + - qk_r2[..., 0] * freqs_cis[..., 1], - ], - -1, - ).flatten(3) - rotated_qk = rotated_qk_r2 - return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) class Attention(nn.Module): + def 
__init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.dim % cfg.n_heads == 0 + self.dim = cfg.dim + self.n_heads = cfg.n_heads + self.head_dim = cfg.dim // cfg.n_heads - def __init__(self, cfg: ModelConfig): - super().__init__() - assert cfg.dim % cfg.n_heads == 0 - self.dim = cfg.dim - self.n_heads = cfg.n_heads - self.head_dim = cfg.dim // cfg.n_heads - - self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) - self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) + self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) - def forward(self, x, freqs_cis): - bsz, seqlen, d = x.shape # (bsz, seqlen, d) + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) - q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) - q = q.view(bsz, seqlen, self.n_heads, - self.head_dim) # (bsz, seqlen, nh, h_dim) - k = k.view(bsz, seqlen, self.n_heads, - self.head_dim) # (bsz, seqlen, nh, h_dim) - v = v.view(bsz, seqlen, self.n_heads, - self.head_dim) # (bsz, seqlen, nh, h_dim) + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + k = k.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + v = v.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) - q, k = apply_rotary_emb_complex_like( - q, k, freqs_cis=freqs_cis) # (bsz, seqlen, nh, h_dim) + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis + ) # (bsz, seqlen, nh, h_dim) - q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) - k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) - v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) - out = F.scaled_dot_product_attention( - q, k, v, is_causal=True) # (bsz, nh, seqlen, h_dim) + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True + ) # (bsz, nh, seqlen, h_dim) - out = out.transpose(1, 2).contiguous().view(bsz, seqlen, - d) # (bsz, seqlen, d) + out = ( + out.transpose(1, 2).contiguous().view(bsz, seqlen, d) + ) # (bsz, seqlen, d) - return self.w_out(out) + return self.w_out(out) class Block(nn.Module): + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) + self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.layer_id = layer_id - def __init__(self, layer_id: int, cfg: ModelConfig): - super().__init__() - self.attn = Attention(cfg) - self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) - self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.layer_id = layer_id - - def forward(self, x, freqs_cis): - # x: (bsz, seqlen, dim) - x = x + self.attn(self.attn_norm(x), freqs_cis) - x = x + self.mlp(self.mlp_norm(x)) - return x + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x class Transformer(nn.Module): - - def __init__(self, cfg): - super().__init__() - self.n_layers = cfg.n_layers - self.cfg = cfg - head_dim = cfg.dim // cfg.n_heads - assert cfg.dim % cfg.n_heads == 0 - - self.embed_tokens = nn.Embedding(cfg.vocab_size, 
cfg.dim) - self.layers = nn.ModuleList( - [Block(idx, cfg) for idx in range(cfg.n_layers)]) - self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) - - # Initialize freqs_cis on CPU first (more memory efficient) - self.register_buffer('freqs_cis', - precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0:cfg.seq_len], - persistent=False) - - # init all weights, scale residual branches - self.apply(self._init_weights) - self._scale_residual_branches() - - # Move model to device (which will also move freqs_cis) - if torch.cuda.is_available(): - self.cuda() - - if cfg.tie_embeddings: - self.tie_weights() - - def forward(self, x, targets=None): - # x: (bsz, seqlen) - x = self.embed_tokens(x) # (bsz, seqlen, dim) - L = x.shape[1] - - # Make sure we have enough precomputed frequencies - if L > self.freqs_cis.shape[1]: - # Need to recompute for longer sequence - head_dim = self.cfg.dim // self.cfg.n_heads - new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000) - self.register_buffer('freqs_cis', new_freqs[0:max(L, self.cfg.seq_len)], persistent=False) - if torch.cuda.is_available(): - self.freqs_cis = self.freqs_cis.cuda() - - # Select the frequencies for current sequence length and ensure correct device - freqs_cis = self.freqs_cis[:, :L, :].to(x.device) - - for layer in self.layers: - x = layer(x, freqs_cis) # (bsz, seqlen, dim) - out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) - if targets is not None: - loss = F.cross_entropy( - out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100) - return out, loss - return out - - def predict(self, x, k=1): - """Generate k tokens autoregressively. - - Args: - x: Input token sequence of shape (batch_size, seq_len) - k: Number of tokens to predict - - Returns: - Tuple of (input_ids, predicted_ids) - """ - - # Store original input - original_input = x.clone() - generated_input = x.clone() - - # Generate k tokens autoregressively - for i in range(k): - - # Get logits for the entire sequence - logits = self(generated_input) - - # Get the logits for the last token in each sequence - next_token_logits = logits[:, -1, :] - - # Zero out the last token ID to prevent repetition - # This is a common issue - the model gets stuck repeating the last token - last_token_id = generated_input[:, -1] - next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) - - # Get the most likely token - next_token = torch.argmax(next_token_logits, dim=-1) - - # Append the predicted token to the sequence - next_token = next_token.unsqueeze(1) # Add sequence dimension - generated_input = torch.cat([generated_input, next_token], dim=1) - - # For debugging, print predictions for the first item in the batch - print("\nPyTorch detailed prediction (first item in batch):") - predicted_sequence = generated_input[0, -k:].tolist() - print(f" Predicted token IDs: {predicted_sequence}") - for i, token_id in enumerate(predicted_sequence): - print(f" Step {i+1}: Predicted token {token_id}") - - # Return all tokens, not just the last k - return original_input, generated_input[:, -k:] - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - - def _scale_residual_branches(self): - for n, p in self.named_parameters(): - if 
n.endswith("fc2.weight"): # mlp/glu output layer - torch.nn.init.normal_(p, - mean=0.0, - std=0.02 / math.sqrt(2 * self.n_layers)) - if n.endswith("w_out.weight"): # attn output layer - torch.nn.init.normal_(p, - mean=0.0, - std=0.02 / math.sqrt(2 * self.n_layers)) - - def tie_weights(self): - self.lm_head.weight = self.embed_tokens.weight - - def count_params(self, non_embedding=True): - n_params = sum(p.numel() for p in self.parameters()) - if non_embedding: - n_params -= self.embed_tokens.weight.numel() - if (self.lm_head.weight is not self.embed_tokens.weight): # if no weight tying - n_params -= self.lm_head.weight.numel() - return n_params - - -def main(): - print("Initializing transformer model and running forward pass...") - - seq_length = 1024 - - # Define model configuration - config = ModelConfig( - vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece - seq_len=seq_length, # Maximum sequence length - dim=1024, # Embedding dimension - expand=4.0, # MLP expansion factor - n_layers=12, # Number of transformer layers - n_heads=8, # Number of attention heads - rmsnorm_eps=1e-6, # RMSNorm epsilon - tie_embeddings=True # Tie embedding and output weights + def __init__(self, cfg): + super().__init__() + self.n_layers = cfg.n_layers + self.cfg = cfg + head_dim = cfg.dim // cfg.n_heads + assert cfg.dim % cfg.n_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.n_layers)] + ) + self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer( + 'freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0 : cfg.seq_len], + persistent=False, ) - # Instantiate the model - model = Transformer(config) - print(f"Model has {model.count_params():,} parameters.") - - # Create some random input data - batch_size = 2 - input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() - # Move data to the same device as the model + # Move model to device (which will also move freqs_cis) if torch.cuda.is_available(): - input_ids = input_ids.cuda() - - # Run a forward pass - print(f"Running forward pass with input shape: {input_ids.shape}") - logits = model(input_ids) - print(f"Output logits shape: {logits.shape}") - - # Run prediction - print("Running prediction...") - original_input, predicted_ids = model.predict(input_ids[:, :10], k=5) - print(f"Original input shape for prediction: {original_input.shape}") - print(f"Predicted IDs shape: {predicted_ids.shape}") - print(f"Predicted IDs: {predicted_ids}") - -if __name__ == "__main__": - main() + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x, targets=None): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.dim // self.cfg.n_heads + new_freqs = precompute_freqs_cis( + head_dim, max(L, self.cfg.seq_len), 500000 + ) + self.register_buffer( + 'freqs_cis', new_freqs[0 : max(L, self.cfg.seq_len)], persistent=False + ) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies for current sequence length and 
ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: + loss = F.cross_entropy( + out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100 + ) + return out, loss + return out + + def predict(self, x, k=1): + """Generate k tokens autoregressively. + + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + # For debugging, print predictions for the first item in the batch + print('\nPyTorch detailed prediction (first item in batch):') + predicted_sequence = generated_input[0, -k:].tolist() + print(f' Predicted token IDs: {predicted_sequence}') + for i, token_id in enumerate(predicted_sequence): + print(f' Step {i + 1}: Predicted token {token_id}') + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith('fc2.weight'): # mlp/glu output layer + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) + ) + if n.endswith('w_out.weight'): # attn output layer + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) + ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if ( + self.lm_head.weight is not self.embed_tokens.weight + ): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print('Initializing transformer model and running forward pass...') + + seq_length = 1024 + + # Define model configuration + config = ModelConfig( + vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + dim=1024, # Embedding dimension + expand=4.0, # MLP expansion factor + n_layers=12, # Number of transformer layers + n_heads=8, # Number of attention heads + rmsnorm_eps=1e-6, # RMSNorm epsilon + tie_embeddings=True, # Tie embedding and output weights + ) + + # Instantiate the model + model = Transformer(config) + print(f'Model has 
{model.count_params():,} parameters.') + + # Create some random input data + batch_size = 2 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + + # Move data to the same device as the model + if torch.cuda.is_available(): + input_ids = input_ids.cuda() + + # Run a forward pass + print(f'Running forward pass with input shape: {input_ids.shape}') + logits = model(input_ids) + print(f'Output logits shape: {logits.shape}') + + # Run prediction + print('Running prediction...') + original_input, predicted_ids = model.predict(input_ids[:, :10], k=5) + print(f'Original input shape for prediction: {original_input.shape}') + print(f'Predicted IDs shape: {predicted_ids.shape}') + print(f'Predicted IDs: {predicted_ids}') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 9713e84b0..115fae4f6 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -24,28 +24,28 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: if hasattr(self, '_model'): - # Reinitialize weights but keep same config - self._model.apply(self._model._init_weights) - self._model._scale_residual_branches() - return self._model, None + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() + return self._model, None torch.manual_seed(rng[0]) cfg = ModelConfig( - vocab_size=self._vocab_size, - seq_len=self._seq_len, - dim=self._emb_dim, # Model dimension - expand=self._mlp_dim // self._emb_dim, # MLP expansion factor - # FIXME(rka97): fix expansion factor - n_layers=self._n_layers, # Number of transformer layers - n_heads=self._n_heads, # Number of attention heads - rmsnorm_eps=1e-6, - tie_embeddings=True + vocab_size=self._vocab_size, + seq_len=self._seq_len, + dim=self._emb_dim, # Model dimension + expand=self._mlp_dim // self._emb_dim, # MLP expansion factor + # FIXME(rka97): fix expansion factor + n_layers=self._n_layers, # Number of transformer layers + n_heads=self._n_heads, # Number of attention heads + rmsnorm_eps=1e-6, + tie_embeddings=True, ) self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) @@ -53,23 +53,23 @@ def init_model_fn( self._model.to(DEVICE) if N_GPUS > 1: - if USE_PYTORCH_DDP: - self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) - else: - self._model = torch.nn.DataParallel(self._model) + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) return self._model, None def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + 
update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state, rng, update_batch_norm, dropout_rate model = params @@ -93,32 +93,39 @@ def model_fn( return logits, None def _build_input_queue( - self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" del cache, repeat_final_dataset local_batch_size = global_batch_size // N_GPUS loader = get_data_iter( - data_rng=data_rng, - split=split, - data_dir=data_dir, - batch_size=local_batch_size, - num_batches=num_batches, + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=local_batch_size, + num_batches=num_batches, ) if USE_PYTORCH_DDP: - loader = islice(loader, RANK, None, N_GPUS) + loader = islice(loader, RANK, None, N_GPUS) dtype = torch.int32 for batch in loader: batch = { - 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), - 'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64), - 'weights': torch.tensor(batch['weights'], device=DEVICE, dtype=torch.float32) if batch['weights'] is not None else None, + 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), + 'targets': torch.tensor( + batch['targets'], device=DEVICE, dtype=torch.int64 + ), + 'weights': torch.tensor( + batch['weights'], device=DEVICE, dtype=torch.float32 + ) + if batch['weights'] is not None + else None, } yield batch @@ -127,7 +134,13 @@ def is_output_params(self, param_name: str) -> bool: return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name # FIXME(rka97): Implement label smoothing - def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + def compute_weighted_cross_entropy( + self, + logits: spec.Tensor, + labels: spec.Tensor, + weights: spec.Tensor, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in PyTorch.""" vocab_size = logits.size(-1) @@ -138,44 +151,53 @@ def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tenso else: # Dense labels loss = torch.nn.functional.cross_entropy( - logits.view(-1, vocab_size), - labels.view(-1), - reduction='none') + logits.view(-1, vocab_size), labels.view(-1), reduction='none' + ) loss = loss.view_as(labels) if weights is not None: loss = loss * weights - n_valid = weights.sum() if weights is not None else torch.tensor(labels.numel(), dtype=torch.float32, device=labels.device) + n_valid = ( + weights.sum() + if weights is not None + else torch.tensor( + labels.numel(), dtype=torch.float32, device=labels.device + ) + ) return { - 'summed': loss.sum(), - 'n_valid_examples': n_valid, - 'per_example': loss, + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss, } - - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ 
= self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - return { - 'loss': metrics['summed'].detach(), - 'denominator': metrics['n_valid_examples'].detach(), - } + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + metrics = self.compute_weighted_cross_entropy( + logits, batch['targets'], batch['weights'] + ) + return { + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), + } def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, Any] - ) -> Dict[str, float]: - """Normalize eval metrics.""" - del num_examples - if USE_PYTORCH_DDP: - for metric in total_metrics.values(): - dist.all_reduce(metric) - total_metrics = {k: v.item() for k, v in total_metrics.items()} - eval_denominator = total_metrics.pop('denominator') - return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index f15e4b8a7..43dd60ab5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -45,11 +45,11 @@ def validation_target_value(self) -> float: return 25.5477 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: - return True # No test targets + return True # No test targets @property def test_target_value(self) -> float: - return None # No test targets + return None # No test targets @property def loss_type(self) -> spec.LossType: @@ -61,11 +61,11 @@ def num_train_examples(self) -> int: @property def num_eval_train_examples(self) -> int: - return 10_000 # Subset for evaluation. + return 10_000 # Subset for evaluation. @property def num_validation_examples(self) -> int: - return 100_000 # sequences + return 100_000 # sequences @property def num_test_examples(self) -> int: @@ -85,7 +85,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 14 # 14 hours TODO(kasimbeg): update + return 3600 * 14 # 14 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: @@ -125,7 +125,6 @@ def _build_input_queue( ) -> Iterator[Dict[str, Any]]: """Build an input queue for the given split.""" - def _eval_model_on_split( self, split: str, @@ -147,11 +146,7 @@ def _eval_model_on_split( if split not in self._eval_iters: # These iterators will repeat indefinitely. 
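+      # Cached per split so repeated evals reuse the same pipeline; because
+      # the iterator repeats, drawing `num_batches` batches never exhausts it.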
self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches=num_batches + rng, split, data_dir, global_batch_size, num_batches=num_batches ) eval_metrics = {} @@ -167,7 +162,6 @@ def _eval_model_on_split( eval_results['ppl'] = np.exp(eval_results['loss']).item() return eval_results - @abc.abstractmethod def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] @@ -175,19 +169,20 @@ def _normalize_eval_metrics( """Normalize eval metrics.""" def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in JAX.""" return self.compute_weighted_cross_entropy( logits_batch, label_batch, weights=mask_batch, - label_smoothing=label_smoothing + label_smoothing=label_smoothing, ) def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" - return param_name.contains('output') \ No newline at end of file + return param_name.contains('output') diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 114b1adb4..391f16f51 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -9,151 +9,153 @@ BASE_WORKLOADS_DIR = 'algoperf/workloads/' WORKLOADS = { - 'cifar': { - 'workload_path': 'cifar/cifar', 'workload_class_name': 'CifarWorkload' - }, - 'criteo1tb': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', - }, - 'criteo1tb_test': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', - }, - 'criteo1tb_layernorm': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' - }, - 'criteo1tb_embed_init': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload' - }, - 'criteo1tb_resnet': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload' - }, - 'fastmri': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIWorkload', - }, - 'fastmri_model_size': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIModelSizeWorkload', - }, - 'fastmri_tanh': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRITanhWorkload', - }, - 'fastmri_layernorm': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRILayerNormWorkload', - }, - 'imagenet_resnet': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetWorkload', - }, - 'imagenet_resnet_silu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetSiLUWorkload', - }, - 'imagenet_resnet_gelu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetGELUWorkload', - }, - 'imagenet_resnet_large_bn_init': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', - }, - 'imagenet_vit': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitWorkload', - }, - 'imagenet_vit_glu': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 
'ImagenetVitGluWorkload', - }, - 'imagenet_vit_post_ln': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitPostLNWorkload', - }, - 'imagenet_vit_map': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitMapWorkload', - }, - 'librispeech_conformer': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerWorkload', - }, - 'librispeech_conformer_attention_temperature': { - 'workload_path': - 'librispeech_conformer/librispeech', - 'workload_class_name': - 'LibriSpeechConformerAttentionTemperatureWorkload', - }, - 'librispeech_conformer_layernorm': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', - }, - 'librispeech_conformer_gelu': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerGeluWorkload', - }, - 'librispeech_deepspeech': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', - }, - 'librispeech_deepspeech_tanh': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', - }, - 'librispeech_deepspeech_no_resnet': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', - }, - 'librispeech_deepspeech_norm_and_spec_aug': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', - }, - 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, - 'mnist': { - 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' - }, - 'ogbg': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload' - }, - 'ogbg_gelu': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload' - }, - 'ogbg_silu': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' - }, - 'ogbg_model_size': { - 'workload_path': 'ogbg/ogbg', - 'workload_class_name': 'OgbgModelSizeWorkload' - }, - 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, - 'wmt_post_ln': { - 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN' - }, - 'wmt_attention_temp': { - 'workload_path': 'wmt/wmt', - 'workload_class_name': 'WmtWorkloadAttentionTemp' - }, - 'wmt_glu_tanh': { - 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadGLUTanH' - }, + 'cifar': { + 'workload_path': 'cifar/cifar', + 'workload_class_name': 'CifarWorkload', + }, + 'criteo1tb': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', + }, + 'criteo1tb_test': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', + }, + 'criteo1tb_layernorm': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload', + }, + 'criteo1tb_embed_init': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload', + }, + 'criteo1tb_resnet': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload', + }, + 'fastmri': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIWorkload', + }, + 'fastmri_model_size': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIModelSizeWorkload', + }, + 'fastmri_tanh': { + 
'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRITanhWorkload', + }, + 'fastmri_layernorm': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRILayerNormWorkload', + }, + 'imagenet_resnet': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetWorkload', + }, + 'imagenet_resnet_silu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetSiLUWorkload', + }, + 'imagenet_resnet_gelu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetGELUWorkload', + }, + 'imagenet_resnet_large_bn_init': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', + }, + 'imagenet_vit': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitWorkload', + }, + 'imagenet_vit_glu': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitGluWorkload', + }, + 'imagenet_vit_post_ln': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitPostLNWorkload', + }, + 'imagenet_vit_map': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitMapWorkload', + }, + 'librispeech_conformer': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerWorkload', + }, + 'librispeech_conformer_attention_temperature': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerAttentionTemperatureWorkload', + }, + 'librispeech_conformer_layernorm': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', + }, + 'librispeech_conformer_gelu': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerGeluWorkload', + }, + 'librispeech_deepspeech': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', + }, + 'librispeech_deepspeech_tanh': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', + }, + 'librispeech_deepspeech_no_resnet': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', + }, + 'librispeech_deepspeech_norm_and_spec_aug': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', + }, + 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, + 'mnist': { + 'workload_path': 'mnist/mnist', + 'workload_class_name': 'MnistWorkload', + }, + 'ogbg': {'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'}, + 'ogbg_gelu': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgGeluWorkload', + }, + 'ogbg_silu': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgSiluWorkload', + }, + 'ogbg_model_size': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgModelSizeWorkload', + }, + 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, + 'wmt_post_ln': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadPostLN', + }, + 'wmt_attention_temp': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadAttentionTemp', + }, + 'wmt_glu_tanh': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadGLUTanH', + }, } BASE_WORKLOADS = [ - 
'criteo1tb', - 'fastmri', - 'imagenet_resnet', - 'imagenet_vit', - 'librispeech_conformer', - 'librispeech_deepspeech', - 'lm', - 'ogbg', - 'wmt' + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'lm', + 'ogbg', + 'wmt', ] diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 8fecaf419..de5e9d271 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -72,8 +72,7 @@ from torchvision.datasets import CIFAR10 from algoperf.workloads.wmt import tokenizer -from algoperf.workloads.wmt.input_pipeline import \ - normalize_feature_names +from algoperf.workloads.wmt.input_pipeline import normalize_feature_names from dataset import librispeech_preprocess from dataset import librispeech_tokenizer @@ -111,38 +110,41 @@ 'files will be deleted.', ) flags.DEFINE_boolean( - 'all', - False, - 'Whether or not to download all datasets. If false, can download some ' - 'combination of datasets by setting the individual dataset flags below.') - -flags.DEFINE_boolean('criteo1tb', - False, - 'If --all=false, whether or not to download Criteo 1TB.') -flags.DEFINE_boolean('cifar', - False, - 'If --all=false, whether or not to download CIFAR-10.') -flags.DEFINE_boolean('fastmri', - False, - 'If --all=false, whether or not to download FastMRI.') -flags.DEFINE_boolean('finewebedu', - False, - 'If --all=false, whether or not to download FineWebEdu.') -flags.DEFINE_boolean('imagenet', - False, - 'If --all=false, whether or not to download Imagenet.') -flags.DEFINE_boolean('librispeech', - False, - 'If --all=false, whether or not to download LibriSpeech.') -flags.DEFINE_boolean('mnist', - False, - 'If --all=false, whether or not to download MNIST.') -flags.DEFINE_boolean('ogbg', - False, - 'If --all=false, whether or not to download OGBG.') -flags.DEFINE_boolean('wmt', - False, - 'If --all=false, whether or not to download WMT.') + 'all', + False, + 'Whether or not to download all datasets. If false, can download some ' + 'combination of datasets by setting the individual dataset flags below.', +) + +flags.DEFINE_boolean( + 'criteo1tb', False, 'If --all=false, whether or not to download Criteo 1TB.' +) +flags.DEFINE_boolean( + 'cifar', False, 'If --all=false, whether or not to download CIFAR-10.' +) +flags.DEFINE_boolean( + 'fastmri', False, 'If --all=false, whether or not to download FastMRI.' +) +flags.DEFINE_boolean( + 'finewebedu', False, 'If --all=false, whether or not to download FineWebEdu.' +) +flags.DEFINE_boolean( + 'imagenet', False, 'If --all=false, whether or not to download Imagenet.' +) +flags.DEFINE_boolean( + 'librispeech', + False, + 'If --all=false, whether or not to download LibriSpeech.', +) +flags.DEFINE_boolean( + 'mnist', False, 'If --all=false, whether or not to download MNIST.' +) +flags.DEFINE_boolean( + 'ogbg', False, 'If --all=false, whether or not to download OGBG.' +) +flags.DEFINE_boolean( + 'wmt', False, 'If --all=false, whether or not to download WMT.' +) flags.DEFINE_string( 'data_dir', @@ -199,7 +201,9 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') -flags.DEFINE_boolean('skip_tokenization', False, 'Skip Fineweb-edu tokenization.') +flags.DEFINE_boolean( + 'skip_tokenization', False, 'Skip Fineweb-edu tokenization.' 
+) FLAGS = flags.FLAGS @@ -773,30 +777,32 @@ def download_wmt(data_dir): ) -def download_finewebedu(data_dir, - tmp_dir=None, - skip_download=False, - skip_tokenization=False): +def download_finewebedu( + data_dir, tmp_dir=None, skip_download=False, skip_tokenization=False +): """Download FineWebEdu-10B.""" - if not skip_download: + if not skip_download: data_dir = os.path.join(data_dir, 'fineweb_edu_10B') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, - 'lm') if tmp_dir is not None else os.path.expanduser( - '~/.cache/huggingface/datasets') + cache_dir = ( + os.path.join(tmp_dir, 'lm') + if tmp_dir is not None + else os.path.expanduser('~/.cache/huggingface/datasets') + ) _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) _maybe_mkdir(cache_dir) - os.environ["TMPDIR"] = tmp_dir + os.environ['TMPDIR'] = tmp_dir ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - split='train', - cache_dir=cache_dir) + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir, + ) ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) else: ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) @@ -804,10 +810,9 @@ def download_finewebedu(data_dir, if not skip_tokenization: # Tokenize lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + logging.info(f'Vocab size of lm_tokenizer = {len(lm_tokenizer)}') def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - def add_eos(seq): return seq + lm_tokenizer.eos_token if seq else seq @@ -815,33 +820,41 @@ def add_eos_batched(seqs): return [add_eos(seq) for seq in seqs] return lm_tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False) + add_eos_batched(examples['text']), + return_special_tokens_mask=False, + return_attention_mask=False, + ) - lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization - logging.info("Tokenizing...") + lm_tokenizer.model_max_length = ( + 1e30 # prevent truncation during tokenization + ) + logging.info('Tokenizing...') tokenized_dataset = ds.map( - tokenize, - remove_columns=[ - 'text', - 'id', - 'dump', - 'url', - 'file_path', - 'language', - 'language_score', - 'token_count', - 'score', - 'int_score' - ], - batched=True, - batch_size=1024, - num_proc=8) - - tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score', + ], + batched=True, + batch_size=1024, + num_proc=8, + ) + + tokenized_dataset.save_to_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) else: - tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset = hf_datasets.load_from_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) # Convert to tensorflow_datasets.Dataset objects tokenized_dataset = tokenized_dataset.to_tf_dataset() @@ -854,10 +867,10 @@ def add_eos_batched(seqs): val_dataset = shuffled_dataset.skip(train_size) # Split in train and valid. 
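+  # Note: the splits are written below in tf.data's snapshot format and can be
+  # read back later with tf.data.Dataset.load(os.path.join(data_dir, 'train')).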
- train_dataset.save(os.path.join(data_dir, "train")) - val_dataset.save(os.path.join(data_dir, "val")) + train_dataset.save(os.path.join(data_dir, 'train')) + val_dataset.save(os.path.join(data_dir, 'val')) - return + return def main(_): @@ -949,7 +962,9 @@ def main(_): if FLAGS.all or FLAGS.finewebedu: logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization) + download_finewebedu( + data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization + ) # pylint: enable=logging-format-interpolation diff --git a/submission_runner.py b/submission_runner.py index 1c50cd6d9..857d4479f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -253,12 +253,12 @@ def train_once( model_params, model_state = workload.init_model_fn(model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ - 'librispeech_conformer', - 'ogbg', - 'criteo1tb', - 'imagenet_vit', - 'librispeech_deepspeech', - 'lm' + 'librispeech_conformer', + 'ogbg', + 'criteo1tb', + 'imagenet_vit', + 'librispeech_deepspeech', + 'lm', ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -784,8 +784,10 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - limit_tf_threads = (base_workload != 'lm') - pytorch_init(USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads) + limit_tf_threads = base_workload != 'lm' + pytorch_init( + USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads + ) # TODO: remove once issue resolved. if FLAGS.pytorch_eval_num_workers != 0: @@ -797,11 +799,11 @@ def main(_): workload_metadata = WORKLOADS[FLAGS.workload] if base_workload in [ - 'librispeech_conformer', - 'librispeech_deepspeech', - 'imagenet_vit', - 'criteo1tb', - 'lm' + 'librispeech_conformer', + 'librispeech_deepspeech', + 'imagenet_vit', + 'criteo1tb', + 'lm', ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From bb4a3809ee6978ce0cd0f7e5c15ba1aef2ca6a6e Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 04:33:40 +0000 Subject: [PATCH 70/82] Refactor loss function in LM workloads to unify label handling and improve clarity --- algoperf/workloads/lm/lm_jax/workload.py | 52 ++++++------- algoperf/workloads/lm/lm_pytorch/workload.py | 77 ++++++++++++-------- algoperf/workloads/lm/workload.py | 19 +++-- 3 files changed, 87 insertions(+), 61 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 5b736fad7..13738086a 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -87,35 +87,38 @@ def model_fn( logits = self._model.apply({'params': params}, inputs) return logits, None - def compute_weighted_cross_entropy( + def loss_fn( self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0, ) -> Dict[str, spec.Tensor]: # differentiable - """Compute weighted cross entropy and entropy for log probs and targets. + """Compute weighted cross entropy. + Args: - logits: [batch, length, num_classes] float array. - targets: categorical targets [batch, length] int array. - weights: array of shape [batch, length]. - label_smoothing: label smoothing constant, used to determine the on and off - values. + label_batch: categorical targets [batch, length] int array. 
+ logits_batch: [batch, length, num_classes] float array. + mask_batch: weights array of shape [batch, length]. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + Returns: {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} + valid examples in batch, 'per_example': 2d array of per-example losses} """ - if logits.ndim != targets.ndim + 1: + if logits_batch.ndim != label_batch.ndim + 1: raise ValueError( - f'Incorrect shapes. Got shape {logits.shape} logits and ' - f'{targets.shape} targets.' + f'Incorrect shapes. Got shape {logits_batch.shape} logits and ' + f'{label_batch.shape} targets.' ) # Compute log probabilities - log_probs = jax.nn.log_softmax(logits, axis=-1) + log_probs = jax.nn.log_softmax(logits_batch, axis=-1) # Extract log probability of the target class # Shape: [batch, length] target_log_probs = jnp.take_along_axis( - log_probs, targets[..., None], axis=-1 + log_probs, label_batch[..., None], axis=-1 ).squeeze(-1) # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p) # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. @@ -124,11 +127,11 @@ def compute_weighted_cross_entropy( per_example_losses = -1.0 * ( confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1) ) - if weights is not None: - per_example_losses = jnp.where(weights, per_example_losses, 0.0) - n_valid_examples = weights.sum() + if mask_batch is not None: + per_example_losses = mask_batch * per_example_losses + n_valid_examples = mask_batch.sum() else: - n_valid_examples = targets.shape[0] * targets.shape[1] + n_valid_examples = label_batch.shape[0] * label_batch.shape[1] summed_loss = per_example_losses.sum() return { 'summed': summed_loss, @@ -147,12 +150,11 @@ def _eval_batch( logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False ) - # Calculate cross-entropy loss - metrics = self.compute_weighted_cross_entropy( - logits, batch['targets'], batch['weights'] + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], ) - # CRITICAL: Detach tensors to free computation graph and activations - # Without this, all intermediate activations are kept in memory! 
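+    # No explicit detach is needed here: these are plain JAX arrays with no
+    # autograd graph attached, so returning them does not pin activations.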
return { 'loss': metrics['summed'], 'denominator': metrics['n_valid_examples'], diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 115fae4f6..2f5c33ebf 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -133,42 +133,59 @@ def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name - # FIXME(rka97): Implement label smoothing - def compute_weighted_cross_entropy( + def loss_fn( self, - logits: spec.Tensor, - labels: spec.Tensor, - weights: spec.Tensor, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: spec.Tensor, label_smoothing: float = 0.0, ) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling in PyTorch.""" - vocab_size = logits.size(-1) - - if len(labels.shape) == len(logits.shape): - # One-hot labels - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - loss = -torch.sum(labels * log_probs, dim=-1) - else: - # Dense labels - loss = torch.nn.functional.cross_entropy( - logits.view(-1, vocab_size), labels.view(-1), reduction='none' - ) - loss = loss.view_as(labels) + """Compute weighted cross-entropy loss. + + Args: + label_batch: Target labels of shape [batch, length] (int). + logits_batch: Predicted logits of shape [batch, length, vocab_size] (float). + mask_batch: Optional weights of shape [batch, length] (float). Used to mask + out padding tokens or weight examples differently. If None, all examples + are weighted equally. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + + Returns: + Dictionary containing: + - 'summed': Scalar tensor with the sum of all weighted losses. + - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples. + - 'per_example': Tensor of shape [batch, length] with individual losses per example. 
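+
+    Example (shapes only, values are illustrative):
+      logits_batch: [batch, seq_len, vocab_size] floats,
+      label_batch / mask_batch: [batch, seq_len];
+      the mean loss over non-masked tokens is 'summed' / 'n_valid_examples'.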
+ """ + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) - if weights is not None: - loss = loss * weights + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch - n_valid = ( - weights.sum() - if weights is not None + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None else torch.tensor( - labels.numel(), dtype=torch.float32, device=labels.device + label_batch.numel(), dtype=torch.float32, device=label_batch.device ) ) + return { - 'summed': loss.sum(), - 'n_valid_examples': n_valid, - 'per_example': loss, + 'summed': per_example_losses.sum(), + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } def _eval_batch( @@ -182,8 +199,10 @@ def _eval_batch( logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False ) - metrics = self.compute_weighted_cross_entropy( - logits, batch['targets'], batch['weights'] + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], ) return { 'loss': metrics['summed'].detach(), diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 43dd60ab5..56d9fabcc 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -125,6 +125,16 @@ def _build_input_queue( ) -> Iterator[Dict[str, Any]]: """Build an input queue for the given split.""" + @abc.abstractmethod + def _eval_batch( + self, + params: spec.ParameterContainer, + eval_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[str, float]: + """Evaluate the model on a single batch.""" + def _eval_model_on_split( self, split: str, @@ -168,6 +178,7 @@ def _normalize_eval_metrics( ) -> Dict[str, float]: """Normalize eval metrics.""" + @abc.abstractmethod def loss_fn( self, label_batch: spec.Tensor, @@ -175,13 +186,7 @@ def loss_fn( mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0, ) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling in JAX.""" - return self.compute_weighted_cross_entropy( - logits_batch, - label_batch, - weights=mask_batch, - label_smoothing=label_smoothing, - ) + """Compute cross-entropy loss for language modeling.""" def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" From a58fbd57ebd9fde597d96b3eba34f89929ffcab4 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 08:46:00 +0000 Subject: [PATCH 71/82] Fix init in both models to be the same, add lm model diff test --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 25 +- .../workloads/lm/lm_pytorch/plainlm_model.py | 20 +- tests/modeldiffs/lm/compare.py | 868 ++++++++++++++++++ 3 files changed, 893 insertions(+), 20 deletions(-) create mode 100644 tests/modeldiffs/lm/compare.py diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index a1644f569..1227e57b2 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -21,14 +21,17 @@ class DoConfig: N: int # number of transformer block layers V: int # vocab size 
F: int # FF inner dimension - kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() - embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 - ) + attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + use_residual_scaling: bool = True dtype: jnp.dtype = jnp.float32 rmsnorm_epsilon: float = 1e-6 multiple_of: int = 256 - tie_embeddings: bool = True # Whether to tie input and output embeddings + tie_embeddings: bool = True # Whether to tie input and output embed + + def __post_init__(self): + self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.N)) class Mlp(nn.Module): @@ -40,9 +43,8 @@ class Mlp(nn.Module): def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg # Use Xavier uniform initialization explicitly - xavier_init = nn.initializers.xavier_uniform() linear = partial( - nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype ) # Adjust hidden dimension to keep the number of parameters invariant to # the activation function used since the GLU MLP has 3 * hidden_dim * D @@ -55,7 +57,7 @@ def __call__(self, x_BxLxD: jax.Array): x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) # Apply GLU activation x_BxLxF = nn.glu(x_BxLx2F, axis=-1) - x_BxLxD = linear(cfg.D)(x_BxLxF) + x_BxLxD = nn.Dense(cfg.D, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF) return x_BxLxD @@ -122,7 +124,7 @@ def setup(self): nn.DenseGeneral, axis=-1, features=(cfg.H, self.Dh), - kernel_init=cfg.kernel_init, + kernel_init=cfg.attention_init, use_bias=False, dtype=cfg.dtype, ) @@ -134,7 +136,7 @@ def setup(self): features=cfg.D, name='attn_out_proj', # axis=(-2, -1), # - kernel_init=cfg.kernel_init, + kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, use_bias=False, dtype=cfg.dtype, ) @@ -265,6 +267,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1): # Get the logits for the last token in each sequence next_token_logits = logits[:, -1, :] + last_token_id = y_BxL[:, -1] + # Prevent predicting the same token consecutively + next_token_logits = next_token_logits.at[jnp.arange(len(last_token_id)), last_token_id].set(float('-inf')) # Get the most likely token next_token = jnp.argmax(next_token_logits, axis=-1) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index f7e7f9e62..af4232b7e 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -23,6 +23,7 @@ class ModelConfig: n_heads: int rmsnorm_eps: float = 1e-6 tie_embeddings: bool = True + use_residual_scaling: bool = True class MLP(nn.Module): @@ -32,10 +33,8 @@ def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) self.fc2 = nn.Linear(hidden_dim, dim, bias=False) self.glu = nn.GLU(dim=2) - - # Initialize with Xavier uniform - nn.init.xavier_uniform_(self.fc1.weight) - nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.weight, std=0.02) + nn.init.normal_(self.fc2.weight, std=0.02) def forward(self, x): # x: (bsz, T, dim) @@ -89,6 +88,11 @@ def __init__(self, cfg: ModelConfig): self.w_qkv = 
nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + # Split into Q, K, V sections + wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) + for w in [wq, wk, wv]: + nn.init.normal_(w, std=0.02) + nn.init.normal_(self.w_out.weight, std=0.02) def forward(self, x, freqs_cis): bsz, seqlen, d = x.shape # (bsz, seqlen, d) @@ -254,15 +258,11 @@ def _init_weights(self, module): if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(module.weight, std=0.02) def _scale_residual_branches(self): for n, p in self.named_parameters(): - if n.endswith('fc2.weight'): # mlp/glu output layer - torch.nn.init.normal_( - p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) - ) - if n.endswith('w_out.weight'): # attn output layer + if n.endswith('fc2.weight') or n.endswith('w_out.weight'): # mlp/glu output layer torch.nn.init.normal_( p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) ) diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py new file mode 100644 index 000000000..5b95f934c --- /dev/null +++ b/tests/modeldiffs/lm/compare.py @@ -0,0 +1,868 @@ +""" +Test file to verify that JAX and PyTorch implementations produce identical outputs +when given the same weights and inputs. + +Tests are performed module-by-module: +1. RMSNorm +2. RoPE (Rotary Position Embeddings) +3. MLP +4. Attention +5. Transformer Block +6. Full Model +""" + +import os +import sys + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jax.numpy as jnp +import numpy as np +import torch +import torch.nn.functional as F +from absl import flags, logging +from absl.testing import absltest, parameterized + +# Import JAX implementation +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + CausalAttn, + DoConfig, + Mlp, + TBlock, + TransformerDo, + apply_rope, + init_rope, +) + +# Import PyTorch implementation +from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( + MLP, + Attention, + Block, + ModelConfig, + Transformer, + apply_rotary_emb_complex_like, + precompute_freqs_cis, +) + +FLAGS = flags.FLAGS +# Needed to avoid UnparsedFlagAccessError +FLAGS(sys.argv) + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def assert_close(jax_output, torch_output, rtol=1e-5, atol=1e-6, name=''): + """Assert that JAX and PyTorch outputs are close.""" + jax_np = np.array(jax_output) + torch_np = torch_output.detach().cpu().numpy() + + mse = np.mean((jax_np - torch_np) ** 2) + max_diff = np.max(np.abs(jax_np - torch_np)) + + logging.info(f'\n{name} Comparison:') + logging.info(f' MSE: {mse:.8e}') + logging.info(f' Max Difference: {max_diff:.8e}') + + np.testing.assert_allclose( + jax_np, + torch_np, + rtol=rtol, + atol=atol, + err_msg=f'{name} outputs do not match', + ) + + +# ============================================================================ +# Test Functions (unchanged) +# ============================================================================ + + +def test_rmsnorm(): + """Test that RMSNorm produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RMSNorm') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + eps = 1e-6 + + # Create random input + np_input = np.random.randn(batch_size, seq_len, 
dim).astype(np.float32) + + # Initialize PyTorch RMSNorm + torch_norm = torch.nn.RMSNorm(dim, eps=eps) + torch_input = torch.tensor(np_input) + + # Initialize JAX RMSNorm (using Flax's RMSNorm from nanodo) + from flax import linen as nn + + flax_norm = nn.RMSNorm(epsilon=eps) + jax_input = jnp.array(np_input) + flax_params = flax_norm.init(jax.random.PRNGKey(0), jax_input) + + # Copy weights from PyTorch to JAX + with torch.no_grad(): + flax_params['params']['scale'] = jnp.array(torch_norm.weight.numpy()) + + # Forward pass + with torch.no_grad(): + torch_output = torch_norm(torch_input) + + jax_output = flax_norm.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='RMSNorm') + logging.info('✓ RMSNorm test passed') + + +def test_rope(): + """Test that RoPE produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RoPE (Rotary Position Embeddings)') + logging.info('=' * 70) + + batch_size, seq_len, n_heads, dim = 2, 16, 4, 128 + head_dim = dim // n_heads + + # Initialize RoPE + torch_freqs = precompute_freqs_cis(head_dim, seq_len, theta=500000) + jax_freqs = init_rope(dim, seq_len, n_heads) + + # Create random Q and K + np_q = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + ) + np_k = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + ) + + # PyTorch forward + torch_q = torch.tensor(np_q) + torch_k = torch.tensor(np_k) + with torch.no_grad(): + torch_q_rot, torch_k_rot = apply_rotary_emb_complex_like( + torch_q, torch_k, freqs_cis=torch_freqs + ) + + # JAX forward + jax_q = jnp.array(np_q) + jax_k = jnp.array(np_k) + jax_q_rot, jax_k_rot = apply_rope(jax_q, jax_k, jax_freqs) + + # Compare + assert_close(jax_q_rot, torch_q_rot, name='RoPE Q') + assert_close(jax_k_rot, torch_k_rot, name='RoPE K') + logging.info('✓ RoPE test passed') + + +def copy_mlp_params(pytorch_mlp, flax_params): + """Copy MLP parameters from PyTorch to JAX.""" + new_params = flax_params.copy() + + # Handle compiled models + if hasattr(pytorch_mlp, '_orig_mod'): + pytorch_mlp = pytorch_mlp._orig_mod + + # Copy fc1 and fc2 weights (transposed for JAX) + new_params['params']['Dense_0']['kernel'] = ( + pytorch_mlp.fc1.weight.detach().numpy().T + ) + new_params['params']['Dense_1']['kernel'] = ( + pytorch_mlp.fc2.weight.detach().numpy().T + ) + + return new_params + + +def test_mlp(): + """Test that MLP produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing MLP') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + hidden_dim = 1024 + + # Initialize PyTorch MLP + pytorch_mlp = MLP(dim=dim, hidden_dim=hidden_dim) + + # Initialize JAX MLP + cfg = DoConfig( + D=dim, + H=4, + L=128, + N=2, + V=1000, + F=hidden_dim, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_mlp = Mlp(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_mlp.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_mlp_params(pytorch_mlp, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_mlp(torch_input) + + jax_output = flax_mlp.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='MLP') + logging.info('✓ MLP test passed') + + +def copy_attention_params(pytorch_attn, 
flax_params): + """Copy attention parameters from PyTorch to JAX.""" + # Handle compiled models + if hasattr(pytorch_attn, '_orig_mod'): + pytorch_attn = pytorch_attn._orig_mod + + n_heads = pytorch_attn.n_heads + head_dim = pytorch_attn.head_dim + dim = pytorch_attn.dim + + # Split PyTorch's combined qkv weights + w_qkv = pytorch_attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + # Reshape for Flax's DenseGeneral format [D, H, Dh] + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + new_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': {'kernel': pytorch_attn.w_out.weight.detach().numpy().T}, + } + + return {'params': new_params} + + +def test_attention(): + """Test that Attention produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Attention') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + + # Initialize PyTorch Attention + config = ModelConfig( + vocab_size=1000, + seq_len=seq_len, + dim=dim, + expand=4.0, + n_layers=1, + n_heads=n_heads, + rmsnorm_eps=1e-6, + ) + pytorch_attn = Attention(config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Attention + cfg = DoConfig( + D=dim, + H=n_heads, + L=seq_len, + N=1, + V=1000, + F=1024, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_attn = CausalAttn(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_attn.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_attention_params(pytorch_attn, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_attn(torch_input, freqs_cis) + + jax_output = flax_attn.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Attention') + logging.info('✓ Attention test passed') + + +def copy_block_params(pytorch_block, flax_params): + """Copy block parameters from PyTorch to JAX.""" + # Copy attention parameters + attn_params = copy_attention_params(pytorch_block.attn, {'params': {}})[ + 'params' + ] + + # Copy MLP parameters + pytorch_mlp = pytorch_block.mlp + mlp_params = { + 'Dense_0': {'kernel': pytorch_mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_mlp.fc2.weight.detach().numpy().T}, + } + + # Copy RMSNorm parameters + norm_params = { + 'attn_norm': {'scale': pytorch_block.attn_norm.weight.detach().numpy()}, + 'mlp_norm': {'scale': pytorch_block.mlp_norm.weight.detach().numpy()}, + } + + return { + 'params': { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': norm_params['attn_norm'], + 'RMSNorm_1': norm_params['mlp_norm'], + } + } + + +def test_block(): + """Test that Transformer Block produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Transformer Block') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + expand = 4.0 + + # Initialize PyTorch Block + config = ModelConfig( + vocab_size=1000, + seq_len=seq_len, + dim=dim, + expand=expand, + n_layers=1, + 
n_heads=n_heads, + rmsnorm_eps=1e-6, + ) + pytorch_block = Block(layer_id=0, cfg=config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Block + cfg = DoConfig( + D=dim, + H=n_heads, + L=seq_len, + N=1, + V=1000, + F=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_block = TBlock(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_block.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_block_params(pytorch_block, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_block(torch_input, freqs_cis) + + jax_output = flax_block.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Block') + logging.info('✓ Block test passed') + + +def copy_full_model_params(pytorch_model, flax_params, config): + """Copy all parameters from PyTorch model to JAX model.""" + # Handle tied embeddings case + if hasattr(pytorch_model, '_orig_mod'): + pytorch_model = pytorch_model._orig_mod + + n_layers = config.n_layers + n_heads = config.n_heads + dim = config.dim + head_dim = dim // n_heads + + new_params = {'params': {}} + + # Copy embedding weights + new_params['params']['embed'] = { + 'embedding': pytorch_model.embed_tokens.weight.detach().numpy() + } + + # Copy each transformer block + for i in range(n_layers): + pytorch_block = pytorch_model.layers[i] + + # Attention params + w_qkv = pytorch_block.attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + attn_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': { + 'kernel': pytorch_block.attn.w_out.weight.detach().numpy().T + }, + } + + # MLP params + mlp_params = { + 'Dense_0': {'kernel': pytorch_block.mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_block.mlp.fc2.weight.detach().numpy().T}, + } + + # Norm params + attn_norm = {'scale': pytorch_block.attn_norm.weight.detach().numpy()} + mlp_norm = {'scale': pytorch_block.mlp_norm.weight.detach().numpy()} + + # Assemble block params + block_key = f'blocks_{i}' + new_params['params'][block_key] = { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': attn_norm, + 'RMSNorm_1': mlp_norm, + } + + # Copy output norm + new_params['params']['out_ln'] = { + 'scale': pytorch_model.out_norm.weight.detach().numpy() + } + + # Handle output projection (tied or untied) + if not config.tie_embeddings: + new_params['params']['output_proj'] = { + 'kernel': pytorch_model.lm_head.weight.detach().numpy().T + } + + return new_params + + +def test_full_model(): + """Test that full Transformer model produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Full Transformer Model') + logging.info('=' * 70) + + batch_size, seq_len = 2, 32 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + + # Initialize PyTorch model + pytorch_config = ModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + 
dim=dim, + expand=expand, + n_layers=n_layers, + n_heads=n_heads, + rmsnorm_eps=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = DoConfig( + D=dim, + H=n_heads, + L=seq_len, + N=n_layers, + V=vocab_size, + F=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Forward pass + with torch.no_grad(): + torch_logits = pytorch_model(torch_tokens) + + jax_logits = jax_model.apply(jax_params, jax_tokens) + + # Compare + assert_close( + jax_logits, torch_logits, rtol=1e-4, atol=1e-5, name='Full Model' + ) + logging.info('✓ Full Model test passed') + + +def test_prediction(): + """Test that autoregressive generation produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Autoregressive Prediction') + logging.info('=' * 70) + + batch_size, seq_len = 1, 10 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + k = 5 # Number of tokens to predict + + # Initialize PyTorch model + pytorch_config = ModelConfig( + vocab_size=vocab_size, + seq_len=seq_len + k, + dim=dim, + expand=expand, + n_layers=n_layers, + n_heads=n_heads, + rmsnorm_eps=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = DoConfig( + D=dim, + H=n_heads, + L=seq_len + k, + N=n_layers, + V=vocab_size, + F=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Predict k tokens + with torch.no_grad(): + _, torch_predictions = pytorch_model.predict(torch_tokens, k=k) + + _, jax_predictions = jax_model.apply( + jax_params, jax_tokens, k, method=jax_model.predict + ) + + # Compare predictions + torch_pred_np = torch_predictions.cpu().numpy() + jax_pred_np = np.array(jax_predictions) + + logging.info(f'\nPyTorch predictions: {torch_pred_np[0]}') + logging.info(f'JAX predictions: {jax_pred_np[0]}') + + # Check if predictions match exactly + if np.array_equal(torch_pred_np, jax_pred_np): + logging.info('✓ Predictions match exactly!') + else: + matching = np.sum(torch_pred_np == jax_pred_np) + total = torch_pred_np.size + logging.info( + f'⚠ Predictions differ: {matching}/{total} tokens match ({matching / total * 100:.1f}%)' + ) + logging.info( + ' (Note: Small numerical differences can lead to different argmax results)' + ) + + +def test_initialization_statistics(): + """Verify initialization follows expected distributions.""" + logging.info('\n' + '=' * 70) + 
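+  # Reference for the checks below (descriptive note only, derived from the
+  # two model implementations being compared): token embeddings and the
+  # Q/K/V projections are drawn with std 0.02, while the residual-branch
+  # output projections (attention w_out / attn_out_proj and MLP fc2 /
+  # Dense_1) are scaled down to std = 0.02 / sqrt(2 * n_layers), i.e.
+  # roughly 0.0041 for the 12-layer configuration used in this test.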
logging.info('Testing Initialization Statistics') + logging.info('=' * 70) + + # Initialize models + jax_cfg = DoConfig(D=512, H=8, L=1024, N=12, V=50000, F=2048) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32) + ) + + pytorch_cfg = ModelConfig( + vocab_size=50000, seq_len=1024, dim=512, expand=4.0, n_layers=12, n_heads=8 + ) + pytorch_model = Transformer(pytorch_cfg) + + logging.info('Initialization Statistics Check:') + + # Check embedding + jax_embed = jax_params['params']['embed']['embedding'] + torch_embed = pytorch_model.embed_tokens.weight.detach().numpy() + + logging.info('\nToken Embedding (should be ~0.02 std):') + logging.info( + f' JAX: mean={jax_embed.mean():.6f}, std={jax_embed.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_embed.mean():.6f}, std={torch_embed.std():.6f}' + ) + + # Assert embedding std is close to 0.02 + assert abs(jax_embed.std() - 0.02) < 0.005, ( + f'JAX embedding std {jax_embed.std():.6f} not close to 0.02' + ) + assert abs(torch_embed.std() - 0.02) < 0.005, ( + f'PyTorch embedding std {torch_embed.std():.6f} not close to 0.02' + ) + assert abs(jax_embed.mean()) < 0.01, ( + f'JAX embedding mean {jax_embed.mean():.6f} not close to 0' + ) + assert abs(torch_embed.mean()) < 0.01, ( + f'PyTorch embedding mean {torch_embed.mean():.6f} not close to 0' + ) + + # Check first layer attention Q + jax_q = jax_params['params']['blocks_0']['CausalAttn_0']['query']['kernel'] + torch_q_weight = ( + pytorch_model.layers[0].attn.w_qkv.weight[:512].detach().numpy() + ) + + logging.info('\nAttention Q:') + logging.info(f' JAX: mean={jax_q.mean():.6f}, std={jax_q.std():.6f}') + logging.info( + f' PyTorch: mean={torch_q_weight.mean():.6f}, std={torch_q_weight.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_q.mean()) < 0.01, ( + f'JAX Q mean {jax_q.mean():.6f} not close to 0' + ) + assert abs(torch_q_weight.mean()) < 0.01, ( + f'PyTorch Q mean {torch_q_weight.mean():.6f} not close to 0' + ) + + # Check stds are similar + # Allow 20% difference due to random initialization + assert abs(jax_q.std() - torch_q_weight.std()) / torch_q_weight.std() < 0.2, ( + f'Q std differs too much: JAX {jax_q.std():.6f} vs PyTorch {torch_q_weight.std():.6f}' + ) + + # Check first layer attention output (should be scaled) + jax_attn_out = jax_params['params']['blocks_0']['CausalAttn_0'][ + 'attn_out_proj' + ]['kernel'] + torch_attn_out = pytorch_model.layers[0].attn.w_out.weight.detach().numpy() + + logging.info('\nAttention Output:') + logging.info( + f' JAX: mean={jax_attn_out.mean():.6f}, std={jax_attn_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_attn_out.mean():.6f}, std={torch_attn_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_attn_out.mean()) < 0.01, ( + f'JAX attn out mean {jax_attn_out.mean():.6f} not close to 0' + ) + assert abs(torch_attn_out.mean()) < 0.01, ( + f'PyTorch attn out mean {torch_attn_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_attn_out.std() - torch_attn_out.std()) / torch_attn_out.std() < 0.2 + ), ( + f'Attention output std differs too much: JAX {jax_attn_out.std():.6f} vs PyTorch {torch_attn_out.std():.6f}' + ) + + # Check MLP fc2 (should be scaled) + jax_mlp_out = jax_params['params']['blocks_0']['Mlp_0']['Dense_1']['kernel'] + torch_mlp_out = pytorch_model.layers[0].mlp.fc2.weight.detach().numpy() + + logging.info('\nMLP Output:') + logging.info( + f' JAX: 
mean={jax_mlp_out.mean():.6f}, std={jax_mlp_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_mlp_out.mean():.6f}, std={torch_mlp_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_mlp_out.mean()) < 0.01, ( + f'JAX MLP out mean {jax_mlp_out.mean():.6f} not close to 0' + ) + assert abs(torch_mlp_out.mean()) < 0.01, ( + f'PyTorch MLP out mean {torch_mlp_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_mlp_out.std() - torch_mlp_out.std()) / torch_mlp_out.std() < 0.2 + ), ( + f'MLP output std differs too much: JAX {jax_mlp_out.std():.6f} vs PyTorch {torch_mlp_out.std():.6f}' + ) + + logging.info('\n✓ Initialization statistics test passed') + + +def test_initialization_impact(): + """Test that initialization produces similar initial losses.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Initialization Impact') + logging.info('=' * 70) + + # Create identical inputs + batch_size, seq_len = 4, 128 + vocab_size = 50000 + + np.random.seed(42) + tokens = np.random.randint(0, vocab_size, size=(batch_size, seq_len)) + + # Initialize both models with same seed + jax_cfg = DoConfig(D=512, H=8, L=seq_len, N=12, V=vocab_size, F=2048) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.array(tokens, dtype=jnp.int32) + ) + + torch.manual_seed(42) + pytorch_cfg = ModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + dim=512, + expand=4.0, + n_layers=12, + n_heads=8, + ) + pytorch_model = Transformer(pytorch_cfg) + + # Forward pass + jax_logits = jax_model.apply(jax_params, jnp.array(tokens, dtype=jnp.int32)) + + with torch.no_grad(): + torch_logits = pytorch_model(torch.tensor(tokens, dtype=torch.long)) + + # Compute losses + targets = tokens[:, 1:] + jax_loss = -jax.nn.log_softmax(jax_logits[:, :-1]).mean() + torch_loss = F.cross_entropy( + torch_logits[:, :-1].reshape(-1, vocab_size), + torch.tensor(targets.reshape(-1), dtype=torch.long), + ) + + logging.info('\nInitial Loss Comparison:') + logging.info(f' JAX: {jax_loss:.4f}') + logging.info(f' PyTorch: {torch_loss.item():.4f}') + logging.info(f' Difference: {abs(jax_loss - torch_loss.item()):.6f}') + + # Check that losses are in reasonable range for random init + # With vocab_size=50000, random init should give loss around log(50000) ≈ 10.82 + expected_loss = np.log(vocab_size) + + assert 8.0 < jax_loss < 13.0, ( + f'JAX loss {jax_loss:.4f} outside expected range [8.0, 13.0]' + ) + assert 8.0 < torch_loss.item() < 13.0, ( + f'PyTorch loss {torch_loss.item():.4f} outside expected range [8.0, 13.0]' + ) + + # Both losses should be within 10% of log(vocab_size) + assert abs(jax_loss - expected_loss) / expected_loss < 0.1, ( + f'JAX loss {jax_loss:.4f} too far from expected {expected_loss:.4f}' + ) + assert abs(torch_loss.item() - expected_loss) / expected_loss < 0.1, ( + f'PyTorch loss {torch_loss.item():.4f} too far from expected {expected_loss:.4f}' + ) + + logging.info( + '\nNote: Losses are in expected range for random initialization.' 
+ ) + logging.info(f' Expected ~log(vocab_size) = {expected_loss:.4f}') + logging.info('\n✓ Initialization impact test passed') + + +# ============================================================================ +# Test Class +# ============================================================================ + +named_parameters = [ + dict(testcase_name='rmsnorm', test_fn=test_rmsnorm), + dict(testcase_name='rope', test_fn=test_rope), + dict(testcase_name='mlp', test_fn=test_mlp), + dict(testcase_name='attention', test_fn=test_attention), + dict(testcase_name='block', test_fn=test_block), + dict(testcase_name='full_model', test_fn=test_full_model), + dict(testcase_name='prediction', test_fn=test_prediction), + dict( + testcase_name='initialization_statistics', + test_fn=test_initialization_statistics, + ), + dict( + testcase_name='initialization_impact', test_fn=test_initialization_impact + ), +] + + +class ModelMatchingTest(parameterized.TestCase): + """Tests for JAX vs PyTorch model matching.""" + + @parameterized.named_parameters(*named_parameters) + def test_model_matching(self, test_fn): + """Run individual model matching test.""" + test_fn() + + +if __name__ == '__main__': + absltest.main() From b59afa0120f98e7aecc04b3393addb2acbdafe23 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 09:06:29 +0000 Subject: [PATCH 72/82] Refactor model configuration classes to make them consistent between JAX and PyTorch, also unify initialization to be the same in both --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 79 ++++---- algoperf/workloads/lm/lm_jax/workload.py | 16 +- .../workloads/lm/lm_pytorch/plainlm_model.py | 61 +++---- algoperf/workloads/lm/lm_pytorch/workload.py | 11 +- tests/modeldiffs/lm/compare.py | 169 ++++++++++-------- 5 files changed, 180 insertions(+), 156 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index 1227e57b2..2b47c1735 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -12,32 +12,33 @@ @dataclasses.dataclass -class DoConfig: +class ModelConfig: """Hyper-parameters for Transformer decoder-only.""" - D: int # model/embed dim = qkv dim - H: int # num attention heads - L: int # max context/sequence length - N: int # number of transformer block layers - V: int # vocab size - F: int # FF inner dimension + model_dim: int # model/embed dim = qkv dim + num_heads: int # num attention heads + seq_len: int # max context/sequence length + num_layers: int # number of transformer block layers + vocab_size: int # vocab size + expanded_model_dim: int # FF inner dimension + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 + use_residual_scaling: bool = True + tie_embeddings: bool = True # Whether to tie input and output embed + + dtype: jnp.dtype = jnp.float32 attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) - use_residual_scaling: bool = True - dtype: jnp.dtype = jnp.float32 - rmsnorm_epsilon: float = 1e-6 - multiple_of: int = 256 - tie_embeddings: bool = True # Whether to tie input and output embed def __post_init__(self): - self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.N)) + self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.num_layers)) class Mlp(nn.Module): """Multilayer perceptron with GLU 
activation.""" - cfg: DoConfig + cfg: ModelConfig @nn.compact def __call__(self, x_BxLxD: jax.Array): @@ -49,15 +50,15 @@ def __call__(self, x_BxLxD: jax.Array): # Adjust hidden dimension to keep the number of parameters invariant to # the activation function used since the GLU MLP has 3 * hidden_dim * D # parameters instead of 2 * hidden_dim * D parameters - hidden_dim = cfg.F * 2 / 3 + hidden_dim = cfg.expanded_model_dim * 2 / 3 hidden_dim = cfg.multiple_of * ( - (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + (cfg.expanded_model_dim + cfg.multiple_of - 1) // cfg.multiple_of ) # Double the hidden dimension for GLU x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) # Apply GLU activation x_BxLxF = nn.glu(x_BxLx2F, axis=-1) - x_BxLxD = nn.Dense(cfg.D, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF) + x_BxLxD = nn.Dense(cfg.model_dim, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF) return x_BxLxD @@ -109,21 +110,21 @@ def rotate_tensor(x): class CausalAttn(nn.Module): """Causal attention layer with rotary embeddings.""" - cfg: DoConfig + cfg: ModelConfig def setup(self): cfg = self.cfg - assert cfg.D % cfg.H == 0, f'D {cfg.D} not divisible by H {cfg.H}' - self.Dh = cfg.D // cfg.H + assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}' + self.Dh = cfg.model_dim // cfg.num_heads # Initialize rotary embeddings - self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads) # Maps D -> (H, Dh) self.multilinear = partial( nn.DenseGeneral, axis=-1, - features=(cfg.H, self.Dh), + features=(cfg.num_heads, self.Dh), kernel_init=cfg.attention_init, use_bias=False, dtype=cfg.dtype, @@ -133,7 +134,7 @@ def setup(self): self.multilinear_key = self.multilinear(name='key') self.multilinear_value = self.multilinear(name='value') self.output_projection = nn.DenseGeneral( - features=cfg.D, + features=cfg.model_dim, name='attn_out_proj', # axis=(-2, -1), # kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, @@ -183,7 +184,7 @@ def __call__(self, x_BxLxD: jax.Array): class TBlock(nn.Module): """Transformer Block.""" - docfg: DoConfig + docfg: ModelConfig @nn.compact def __call__(self, in_BxLxD: jax.Array): @@ -208,17 +209,17 @@ def __call__(self, in_BxLxD: jax.Array): class TransformerDo(nn.Module): """Transformer decoder-only.""" - docfg: DoConfig + docfg: ModelConfig def setup(self): cfg = self.docfg self.embed = nn.Embed( - num_embeddings=cfg.V, - features=cfg.D, + num_embeddings=cfg.vocab_size, + features=cfg.model_dim, embedding_init=cfg.embed_init, ) - self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)] self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) # Output projection - tied to input embeddings if configured @@ -226,7 +227,7 @@ def setup(self): self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) else: self.output_proj = nn.Dense( - cfg.V, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj' + cfg.vocab_size, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj' ) def __call__(self, y_BxL: jax.Array): @@ -255,9 +256,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1): original_input = y_BxL # Make sure we don't exceed the model's context length - if seq_len + k > cfg.L: + if seq_len + k > cfg.seq_len: raise 
ValueError( - f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.seq_len})" ) # Generate k tokens autoregressively @@ -288,17 +289,17 @@ def main(): """Create and run the DecoderOnly Transformer model.""" # Initialize model configuration with smaller parameters for demo B, L = (2, 128) # Batch size, sequence length - cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + cfg = ModelConfig(model_dim=128, num_heads=4, seq_len=L, num_layers=2, vocab_size=256, expanded_model_dim=4 * 128) model = TransformerDo(cfg) # Print model info print('\nModel Configuration:') - print(f' - Model dimension (D): {cfg.D}') - print(f' - Number of heads (H): {cfg.H}') - print(f' - Max sequence length (L): {cfg.L}') - print(f' - Number of layers (N): {cfg.N}') - print(f' - Vocabulary size (V): {cfg.V}') - print(f' - Feed forward dimension (F): {cfg.F}') + print(f' - Model dimension (D): {cfg.model_dim}') + print(f' - Number of heads (H): {cfg.num_heads}') + print(f' - Max sequence length (L): {cfg.seq_len}') + print(f' - Number of layers (N): {cfg.num_layers}') + print(f' - Vocabulary size (V): {cfg.vocab_size}') + print(f' - Feed forward dimension (F): {cfg.expanded_model_dim}') # Create random input tokens (simulated token IDs) rng_key = jax.random.PRNGKey(42) @@ -306,7 +307,7 @@ def main(): # Generate random token IDs (integers between 0 and vocab_size-1) x_BxL = jax.random.randint( - input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + input_rng, shape=(B, L), minval=0, maxval=cfg.vocab_size, dtype=jnp.int32 ) # Initialize model parameters diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 13738086a..effb12089 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -8,7 +8,7 @@ from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter from algoperf.workloads.lm.lm_jax.nanodo_model import ( - DoConfig, + ModelConfig, TransformerDo, ) from algoperf.workloads.lm.workload import BaseLmWorkload @@ -46,13 +46,13 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None, ) -> spec.ModelInitState: # Initialize NanoDO transformer model - cfg = DoConfig( - D=self._emb_dim, # embedding dim - H=self._n_heads, # num heads - L=self._seq_len, - N=self._n_layers, # num layers - V=self._vocab_size, - F=self._mlp_dim, # feedforward dim + cfg = ModelConfig( + model_dim=self._emb_dim, # embedding dim + num_heads=self._n_heads, # num heads + seq_len=self._seq_len, + num_layers=self._n_layers, # num layers + vocab_size=self._vocab_size, + expanded_model_dim=self._mlp_dim, # feedforward dim dtype=jnp.float32, ) self._model = TransformerDo(cfg) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index af4232b7e..8186638e7 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -15,15 +15,16 @@ @dataclass class ModelConfig: - vocab_size: int + model_dim: int + num_heads: int seq_len: int - dim: int - expand: float - n_layers: int - n_heads: int - rmsnorm_eps: float = 1e-6 - tie_embeddings: bool = True + num_layers: int + vocab_size: int + expanded_model_dim: int + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 use_residual_scaling: bool = True + tie_embeddings: bool = True class MLP(nn.Module): @@ -81,13 
+82,13 @@ def apply_rotary_emb_complex_like( class Attention(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() - assert cfg.dim % cfg.n_heads == 0 - self.dim = cfg.dim - self.n_heads = cfg.n_heads - self.head_dim = cfg.dim // cfg.n_heads + assert cfg.model_dim % cfg.num_heads == 0 + self.dim = cfg.model_dim + self.n_heads = cfg.num_heads + self.head_dim = cfg.model_dim // cfg.num_heads - self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) - self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False) + self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False) # Split into Q, K, V sections wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) for w in [wq, wk, wv]: @@ -131,9 +132,9 @@ class Block(nn.Module): def __init__(self, layer_id: int, cfg: ModelConfig): super().__init__() self.attn = Attention(cfg) - self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) - self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.mlp = MLP(dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of) + self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.layer_id = layer_id def forward(self, x, freqs_cis): @@ -144,19 +145,19 @@ def forward(self, x, freqs_cis): class Transformer(nn.Module): - def __init__(self, cfg): + def __init__(self, cfg: ModelConfig): super().__init__() - self.n_layers = cfg.n_layers + self.n_layers = cfg.num_layers self.cfg = cfg - head_dim = cfg.dim // cfg.n_heads - assert cfg.dim % cfg.n_heads == 0 + head_dim = cfg.model_dim // cfg.num_heads + assert cfg.model_dim % cfg.num_heads == 0 - self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim) self.layers = nn.ModuleList( - [Block(idx, cfg) for idx in range(cfg.n_layers)] + [Block(idx, cfg) for idx in range(cfg.num_layers)] ) - self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False) # Initialize freqs_cis on CPU first (more memory efficient) self.register_buffer( @@ -184,7 +185,7 @@ def forward(self, x, targets=None): # Make sure we have enough precomputed frequencies if L > self.freqs_cis.shape[1]: # Need to recompute for longer sequence - head_dim = self.cfg.dim // self.cfg.n_heads + head_dim = self.cfg.model_dim // self.cfg.num_heads new_freqs = precompute_freqs_cis( head_dim, max(L, self.cfg.seq_len), 500000 ) @@ -290,11 +291,11 @@ def main(): config = ModelConfig( vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece seq_len=seq_length, # Maximum sequence length - dim=1024, # Embedding dimension - expand=4.0, # MLP expansion factor - n_layers=12, # Number of transformer layers - n_heads=8, # Number of attention heads - rmsnorm_eps=1e-6, # RMSNorm epsilon + model_dim=1024, # Embedding dimension + expanded_model_dim=4.0, # MLP expansion factor + num_layers=12, # Number of transformer layers + num_heads=8, # Number of attention heads + rmsnorm_epsilon=1e-6, # RMSNorm epsilon tie_embeddings=True, # Tie embedding and output weights ) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 
2f5c33ebf..3d185636b 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -39,12 +39,11 @@ def init_model_fn( cfg = ModelConfig( vocab_size=self._vocab_size, seq_len=self._seq_len, - dim=self._emb_dim, # Model dimension - expand=self._mlp_dim // self._emb_dim, # MLP expansion factor - # FIXME(rka97): fix expansion factor - n_layers=self._n_layers, # Number of transformer layers - n_heads=self._n_heads, # Number of attention heads - rmsnorm_eps=1e-6, + model_dim=self._emb_dim, # Model dimension + expanded_model_dim=self._mlp_dim, # MLP expansion factor + num_layers=self._n_layers, # Number of transformer layers + num_heads=self._n_heads, # Number of attention heads + rmsnorm_epsilon=1e-6, tie_embeddings=True, ) self._model = Transformer(cfg) diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py index 5b95f934c..f681597d8 100644 --- a/tests/modeldiffs/lm/compare.py +++ b/tests/modeldiffs/lm/compare.py @@ -28,24 +28,28 @@ # Import JAX implementation from algoperf.workloads.lm.lm_jax.nanodo_model import ( CausalAttn, - DoConfig, Mlp, TBlock, TransformerDo, apply_rope, init_rope, ) +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + ModelConfig as JaxModelConfig, +) # Import PyTorch implementation from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( MLP, Attention, Block, - ModelConfig, Transformer, apply_rotary_emb_complex_like, precompute_freqs_cis, ) +from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( + ModelConfig as PyTorchModelConfig, +) FLAGS = flags.FLAGS # Needed to avoid UnparsedFlagAccessError @@ -192,13 +196,13 @@ def test_mlp(): pytorch_mlp = MLP(dim=dim, hidden_dim=hidden_dim) # Initialize JAX MLP - cfg = DoConfig( - D=dim, - H=4, - L=128, - N=2, - V=1000, - F=hidden_dim, + cfg = JaxModelConfig( + model_dim=dim, + num_heads=4, + seq_len=128, + num_layers=2, + vocab_size=1000, + expanded_model_dim=hidden_dim, dtype=jnp.float32, rmsnorm_epsilon=1e-6, ) @@ -266,26 +270,26 @@ def test_attention(): batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 # Initialize PyTorch Attention - config = ModelConfig( + config = PyTorchModelConfig( vocab_size=1000, seq_len=seq_len, - dim=dim, - expand=4.0, - n_layers=1, - n_heads=n_heads, - rmsnorm_eps=1e-6, + model_dim=dim, + expanded_model_dim=1024, + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, ) pytorch_attn = Attention(config) freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) # Initialize JAX Attention - cfg = DoConfig( - D=dim, - H=n_heads, - L=seq_len, - N=1, - V=1000, - F=1024, + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=1024, dtype=jnp.float32, rmsnorm_epsilon=1e-6, ) @@ -354,26 +358,26 @@ def test_block(): expand = 4.0 # Initialize PyTorch Block - config = ModelConfig( + config = PyTorchModelConfig( vocab_size=1000, seq_len=seq_len, - dim=dim, - expand=expand, - n_layers=1, - n_heads=n_heads, - rmsnorm_eps=1e-6, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, ) pytorch_block = Block(layer_id=0, cfg=config) freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) # Initialize JAX Block - cfg = DoConfig( - D=dim, - H=n_heads, - L=seq_len, - N=1, - V=1000, - F=int(dim * expand), + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=int(dim * 
expand), dtype=jnp.float32, rmsnorm_epsilon=1e-6, ) @@ -408,9 +412,9 @@ def copy_full_model_params(pytorch_model, flax_params, config): if hasattr(pytorch_model, '_orig_mod'): pytorch_model = pytorch_model._orig_mod - n_layers = config.n_layers - n_heads = config.n_heads - dim = config.dim + n_layers = config.num_layers + n_heads = config.num_heads + dim = config.model_dim head_dim = dim // n_heads new_params = {'params': {}} @@ -489,27 +493,27 @@ def test_full_model(): expand = 4.0 # Initialize PyTorch model - pytorch_config = ModelConfig( + pytorch_config = PyTorchModelConfig( vocab_size=vocab_size, seq_len=seq_len, - dim=dim, - expand=expand, - n_layers=n_layers, - n_heads=n_heads, - rmsnorm_eps=1e-6, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, tie_embeddings=True, ) pytorch_model = Transformer(pytorch_config) pytorch_model.eval() # Initialize JAX model - jax_config = DoConfig( - D=dim, - H=n_heads, - L=seq_len, - N=n_layers, - V=vocab_size, - F=int(dim * expand), + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), dtype=jnp.float32, rmsnorm_epsilon=1e-6, tie_embeddings=True, @@ -557,27 +561,27 @@ def test_prediction(): k = 5 # Number of tokens to predict # Initialize PyTorch model - pytorch_config = ModelConfig( + pytorch_config = PyTorchModelConfig( vocab_size=vocab_size, seq_len=seq_len + k, - dim=dim, - expand=expand, - n_layers=n_layers, - n_heads=n_heads, - rmsnorm_eps=1e-6, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, tie_embeddings=True, ) pytorch_model = Transformer(pytorch_config) pytorch_model.eval() # Initialize JAX model - jax_config = DoConfig( - D=dim, - H=n_heads, - L=seq_len + k, - N=n_layers, - V=vocab_size, - F=int(dim * expand), + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len + k, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), dtype=jnp.float32, rmsnorm_epsilon=1e-6, tie_embeddings=True, @@ -633,14 +637,26 @@ def test_initialization_statistics(): logging.info('=' * 70) # Initialize models - jax_cfg = DoConfig(D=512, H=8, L=1024, N=12, V=50000, F=2048) + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=1024, + num_layers=12, + vocab_size=50000, + expanded_model_dim=2048, + dtype=jnp.float32) jax_model = TransformerDo(jax_cfg) jax_params = jax_model.init( jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32) ) - pytorch_cfg = ModelConfig( - vocab_size=50000, seq_len=1024, dim=512, expand=4.0, n_layers=12, n_heads=8 + pytorch_cfg = PyTorchModelConfig( + vocab_size=50000, + seq_len=1024, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, ) pytorch_model = Transformer(pytorch_cfg) @@ -771,20 +787,27 @@ def test_initialization_impact(): tokens = np.random.randint(0, vocab_size, size=(batch_size, seq_len)) # Initialize both models with same seed - jax_cfg = DoConfig(D=512, H=8, L=seq_len, N=12, V=vocab_size, F=2048) + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=seq_len, + num_layers=12, + vocab_size=vocab_size, + expanded_model_dim=2048, + ) jax_model = TransformerDo(jax_cfg) jax_params = jax_model.init( jax.random.PRNGKey(42), jnp.array(tokens, dtype=jnp.int32) ) torch.manual_seed(42) - pytorch_cfg = ModelConfig( + pytorch_cfg = 
PyTorchModelConfig( vocab_size=vocab_size, seq_len=seq_len, - dim=512, - expand=4.0, - n_layers=12, - n_heads=8, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, ) pytorch_model = Transformer(pytorch_cfg) From d35cddebdb2f62f49665313a79188510684c12df Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 23 Oct 2025 17:00:58 +0000 Subject: [PATCH 73/82] Add query-key normalization to CausalAttn and Attention classes, including learned scaling factor --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 61 +++++++++++++++---- .../workloads/lm/lm_pytorch/plainlm_model.py | 26 ++++++-- tests/modeldiffs/lm/compare.py | 3 +- 3 files changed, 72 insertions(+), 18 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index 2b47c1735..d08e9b7bf 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -25,14 +25,19 @@ class ModelConfig: rmsnorm_epsilon: float = 1e-6 use_residual_scaling: bool = True tie_embeddings: bool = True # Whether to tie input and output embed + qknorm_epsilon: float = 1e-6 dtype: jnp.dtype = jnp.float32 - attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + attention_init: nn.initializers.Initializer = nn.initializers.normal( + stddev=0.02 + ) linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) def __post_init__(self): - self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.num_layers)) + self.residual_init = nn.initializers.normal( + stddev=0.02 / jnp.sqrt(2 * self.num_layers) + ) class Mlp(nn.Module): @@ -43,7 +48,6 @@ class Mlp(nn.Module): @nn.compact def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg - # Use Xavier uniform initialization explicitly linear = partial( nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype ) @@ -58,7 +62,14 @@ def __call__(self, x_BxLxD: jax.Array): x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) # Apply GLU activation x_BxLxF = nn.glu(x_BxLx2F, axis=-1) - x_BxLxD = nn.Dense(cfg.model_dim, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF) + x_BxLxD = nn.Dense( + cfg.model_dim, + use_bias=False, + dtype=cfg.dtype, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + )(x_BxLxF) return x_BxLxD @@ -114,8 +125,11 @@ class CausalAttn(nn.Module): def setup(self): cfg = self.cfg - assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}' + assert cfg.model_dim % cfg.num_heads == 0, ( + f'D {cfg.model_dim} not divisible by H {cfg.num_heads}' + ) self.Dh = cfg.model_dim // cfg.num_heads + self.eps = cfg.qknorm_epsilon # Initialize rotary embeddings self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads) @@ -129,15 +143,22 @@ def setup(self): use_bias=False, dtype=cfg.dtype, ) - self.multilinear_query = self.multilinear(name='query') self.multilinear_key = self.multilinear(name='key') self.multilinear_value = self.multilinear(name='value') + # See Henry et al. 
(2020) "Query Key Normalization for Transformers" + seq_len = cfg.seq_len + attn_scale0 = jnp.log2(seq_len**2 - seq_len) + self.attn_scale = self.param( + 'attn_scale', nn.initializers.constant(attn_scale0), () + ) self.output_projection = nn.DenseGeneral( features=cfg.model_dim, name='attn_out_proj', # axis=(-2, -1), # - kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, use_bias=False, dtype=cfg.dtype, ) @@ -153,8 +174,9 @@ def __call__(self, x_BxLxD: jax.Array): # Apply rotary embeddings to Q and K q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) - # Scale queries - q_BxLxHxDh /= self.Dh**0.5 + # Apply QK normalization + q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps + k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps # Compute attention scores att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) @@ -166,6 +188,9 @@ def __call__(self, x_BxLxD: jax.Array): # Apply mask and softmax _NEG_INF = jnp.finfo(cfg.dtype).min att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = ( + self.attn_scale * att_BxHxLxL + ) # Learned scaling factor for QK norm att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) @@ -227,7 +252,10 @@ def setup(self): self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) else: self.output_proj = nn.Dense( - cfg.vocab_size, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj' + cfg.vocab_size, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name='output_proj', ) def __call__(self, y_BxL: jax.Array): @@ -270,7 +298,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1): next_token_logits = logits[:, -1, :] last_token_id = y_BxL[:, -1] # Prevent predicting the same token consecutively - next_token_logits = next_token_logits.at[jnp.arange(len(last_token_id)), last_token_id].set(float('-inf')) + next_token_logits = next_token_logits.at[ + jnp.arange(len(last_token_id)), last_token_id + ].set(float('-inf')) # Get the most likely token next_token = jnp.argmax(next_token_logits, axis=-1) @@ -289,7 +319,14 @@ def main(): """Create and run the DecoderOnly Transformer model.""" # Initialize model configuration with smaller parameters for demo B, L = (2, 128) # Batch size, sequence length - cfg = ModelConfig(model_dim=128, num_heads=4, seq_len=L, num_layers=2, vocab_size=256, expanded_model_dim=4 * 128) + cfg = ModelConfig( + model_dim=128, + num_heads=4, + seq_len=L, + num_layers=2, + vocab_size=256, + expanded_model_dim=4 * 128, + ) model = TransformerDo(cfg) # Print model info diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 8186638e7..edee8318c 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -23,6 +23,7 @@ class ModelConfig: expanded_model_dim: int multiple_of: int = 256 rmsnorm_epsilon: float = 1e-6 + qknorm_epsilon: float = 1e-6 use_residual_scaling: bool = True tie_embeddings: bool = True @@ -92,9 +93,14 @@ def __init__(self, cfg: ModelConfig): # Split into Q, K, V sections wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) for w in [wq, wk, wv]: - nn.init.normal_(w, std=0.02) + nn.init.normal_(w, std=0.02) nn.init.normal_(self.w_out.weight, std=0.02) + self.eps = cfg.qknorm_epsilon # e.g., 1e-6 + seq_len = cfg.seq_len + 
attn_scale0 = math.log2(seq_len**2 - seq_len) + self.attn_scale = nn.Parameter(torch.tensor(attn_scale0)) + def forward(self, x, freqs_cis): bsz, seqlen, d = x.shape # (bsz, seqlen, d) @@ -117,10 +123,14 @@ def forward(self, x, freqs_cis): k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + # Apply QK normalization + q = q / torch.norm(q, dim=-1, keepdim=True) + self.eps + k = k / torch.norm(k, dim=-1, keepdim=True) + self.eps + q *= self.attn_scale + out = F.scaled_dot_product_attention( - q, k, v, is_causal=True + q, k, v, is_causal=True, scale=1.0 ) # (bsz, nh, seqlen, h_dim) - out = ( out.transpose(1, 2).contiguous().view(bsz, seqlen, d) ) # (bsz, seqlen, d) @@ -133,7 +143,11 @@ def __init__(self, layer_id: int, cfg: ModelConfig): super().__init__() self.attn = Attention(cfg) self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) - self.mlp = MLP(dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of) + self.mlp = MLP( + dim=cfg.model_dim, + hidden_dim=cfg.expanded_model_dim, + multiple_of=cfg.multiple_of, + ) self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.layer_id = layer_id @@ -263,7 +277,9 @@ def _init_weights(self, module): def _scale_residual_branches(self): for n, p in self.named_parameters(): - if n.endswith('fc2.weight') or n.endswith('w_out.weight'): # mlp/glu output layer + if n.endswith('fc2.weight') or n.endswith( + 'w_out.weight' + ): # mlp/glu output layer torch.nn.init.normal_( p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) ) diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py index f681597d8..e1d85eba7 100644 --- a/tests/modeldiffs/lm/compare.py +++ b/tests/modeldiffs/lm/compare.py @@ -644,7 +644,8 @@ def test_initialization_statistics(): num_layers=12, vocab_size=50000, expanded_model_dim=2048, - dtype=jnp.float32) + dtype=jnp.float32, + ) jax_model = TransformerDo(jax_cfg) jax_params = jax_model.init( jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32) From ffb816329d1a9f5272956a8ad04ba2e307401ee2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 24 Oct 2025 19:46:11 +0000 Subject: [PATCH 74/82] update target --- algoperf/workloads/lm/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index e0af589e3..79f65040c 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -42,7 +42,7 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property def validation_target_value(self) -> float: - return 25.5477 # Target perplexity + return 22.432 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return True # No test targets From 202e5cb79e237178d47fdb391fc29c4c4fbc3b8a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sun, 26 Oct 2025 23:58:00 +0000 Subject: [PATCH 75/82] add pytorch nadamw_target_setting --- .../pytorch_nadamw_target_setting.py | 403 ++++++++++++++++++ submission_runner.py | 2 +- 2 files changed, 404 insertions(+), 1 deletion(-) create mode 100644 algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..196c1f809 --- /dev/null +++ 
b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py @@ -0,0 +1,403 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed.nn as dist_nn +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step'] + ) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
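+
+    Example (illustrative only; `model`, `inputs`, `targets`, and `loss_fn`
+    are placeholders, not defined in this file):
+
+      optimizer = NAdamW(model.parameters(), lr=1e-3)
+      optimizer.zero_grad()
+      loss = loss_fn(model(inputs), targets)
+      loss.backward()
+      optimizer.step()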
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) + + return loss + + +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + step_hint = step_hint * 0.75 + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + dropout_rate=hyperparameters.dropout_rate, + ) + + label_smoothing = ( + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
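+    # Descriptive note: torch.distributed.nn.all_reduce is autograd-aware, so
+    # summing both the per-rank loss and the per-rank example count across
+    # ranks before dividing reproduces the exact global mean loss (and hence
+    # the same gradient scale) that a single-process run would compute.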
+ summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip + ) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + elif workload_name == 'lm': + return 64 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/submission_runner.py b/submission_runner.py index 857d4479f..1bb763cf2 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -258,7 +258,6 @@ def train_once( 'criteo1tb', 'imagenet_vit', 'librispeech_deepspeech', - 'lm', ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -267,6 +266,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'lm' ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: From 98e491ad712ba2be1d17452f67be877d2abaf679 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 27 Oct 2025 00:42:12 +0000 Subject: [PATCH 76/82] docker updates for a100 --- docker/scripts/startup.sh | 6 +++--- scoring/utils/run_workloads.py | 1 + scoring/utils/workload_metadata_external_tuning.json | 4 ++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 35ac30461..d92107e90 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -174,7 +174,7 @@ fi # Check if arguments are valid VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ - "wmt" "mnist") + "wmt" "mnist" "fineweb_edu_10B") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ @@ -185,7 +185,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ "librispeech_deepspeech_tanh" \ "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" - "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "lm") VALID_RULESETS=("self" "external") # Set data and experiment paths @@ -221,7 +221,7 @@ TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}" if [[ "${FRAMEWORK}" == "jax" ]]; then COMMAND_PREFIX="python" else - COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8" + COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0 --standalone --nnodes=1 --nproc_per_node=4" fi # Set data directory and bucket (bucket is only relevant in internal mode) diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 273881c5a..c76ef6e32 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -270,6 +270,7 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency' f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c7d4ae195..5138e9acf 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -30,5 +30,9 @@ "librispeech_conformer": { "max_steps": 80000, "dataset": "librispeech" + }, + "lm" : { + "max_steps": 55000, + "dataset":"fineweb_edu_10B" } } From f0f7774beb47ffa48e0d1621ec997e4455db006e Mon Sep 17 00:00:00 2001 From: 
Priya Kasimbeg Date: Thu, 6 Nov 2025 03:53:03 +0000 Subject: [PATCH 77/82] update budgets for a100 hardware weightclass --- algoperf/workloads/criteo1tb/workload.py | 4 ++-- algoperf/workloads/fastmri/workload.py | 4 ++-- .../workloads/imagenet_resnet/workload.py | 4 ++-- algoperf/workloads/imagenet_vit/workload.py | 4 ++-- .../librispeech_conformer/workload.py | 4 ++-- .../librispeech_jax/workload.py | 6 +++++- .../librispeech_pytorch/workload.py | 6 +++++- algoperf/workloads/ogbg/workload.py | 4 ++-- algoperf/workloads/wmt/workload.py | 4 ++-- docker/build_docker_images.sh | 14 ++++++------- scoring/performance_profile.py | 1 + scoring/score_submissions.py | 4 +++- scoring/scoring_utils.py | 20 +++++++++++++++++++ scoring/utils/run_workloads.py | 7 ++++++- .../workload_metadata_external_tuning.json | 2 +- 15 files changed, 62 insertions(+), 26 deletions(-) diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 2cb7e5450..fb38eacc3 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -95,11 +95,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 7_703 # ~2.1 hours. + return 8915 # ~2.4 hours. @property def eval_period_time_sec(self) -> int: - return 2 * 60 # 2 mins. + return 356 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 0b1ecfaa1..5a8afa2e9 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,11 +95,11 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 4_430 # ~1.2 hours + return 2745 # ~0.7 hours @property def eval_period_time_sec(self) -> int: - return 80 + return 110 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index ef696e328..b5263e0a6 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -103,11 +103,11 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 66_159 # ~18.4 hours + return 49918 # ~13.8 hours @property def eval_period_time_sec(self) -> int: - return 510 # 8.5 minutes. + return 1996 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 2a0070ba4..f8f4f2659 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -88,11 +88,11 @@ def eval_batch_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 69_768 # ~19.4 hours + return 64_292 # ~17.8 hours @property def eval_period_time_sec(self) -> int: - return 7 * 60 # 7 mins. + return 2571 # 7 mins. 
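+    # As elsewhere in this change, the period is sized for roughly 25 evals
+    # within the runtime budget (64_292 s / 25 ≈ 2_571 s).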
def _build_dataset( self, diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 791270719..327e8bc39 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -80,11 +80,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 58_015 # ~16.1 hours + return 43680 # ~16.1 hours @property def eval_period_time_sec(self) -> int: - return 24 * 60 + return 1747 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3a320b0dd..2a8fd29d0 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -100,7 +100,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36_949 # ~12.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 672f3440f..119049b34 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -96,7 +96,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36949 # 10.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 8717e46d6..53206200f 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -88,11 +88,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 12_011 # ~3.3 hours + return 11303 # ~3.1 hours @property def eval_period_time_sec(self) -> int: - return 4 * 60 + return 452. 
# approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index 40e4262dd..d972a5486 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -89,11 +89,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 43_336 # ~12.0 hours + return 16114 # ~12.0 hours @property def eval_period_time_sec(self) -> int: - return 14 * 60 + return 644 @property def step_hint(self) -> int: diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 6b5e67ceb..22590b9fd 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -27,7 +27,7 @@ then GIT_BRANCH='main' # Set default argument fi -FRAMEWORKS=( "jax" "pythorch" "both" ) +FRAMEWORKS=( "jax" "pytorch") if [[ -n "$FRAMEWORK" ]]; then @@ -45,10 +45,10 @@ do echo "On branch: ${GIT_BRANCH}" echo $DOCKER_BUILD_COMMAND eval $DOCKER_BUILD_COMMAND - echo $DOCKER_TAG_COMMAND - eval $DOCKER_TAG_COMMAND - echo $DOCKER_PUSH_COMMAND - eval $DOCKER_PUSH_COMMAND - echo "To pull container run: " - echo $DOCKER_PULL_COMMAND + # echo $DOCKER_TAG_COMMAND + # eval $DOCKER_TAG_COMMAND + # echo $DOCKER_PUSH_COMMAND + # eval $DOCKER_PUSH_COMMAND + # echo "To pull container run: " + # echo $DOCKER_PULL_COMMAND done diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 4f2ae9c57..b200c6865 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,6 +71,7 @@ 'wer', 'l1_loss', 'loss', + 'ppl' ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 3423df2e1..4b7bed2b5 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -123,6 +123,8 @@ def get_summary_df(workload, workload_df, include_test_split=False): workload_df['accumulated_submission_time'] / workload_df['global_step'] ).iloc[-1][-1] + summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) + # test metrics if include_test_split: test_metric, test_target = scoring_utils.get_workload_metrics_and_targets( @@ -157,7 +159,7 @@ def get_summary_df(workload, workload_df, include_test_split=False): return summary_df -def get_submission_summary(df, include_test_split=True): +def get_submission_summary(df, include_test_split=False): """Summarizes the submission results into metric and time tables organized by workload. """ diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 5be6c790c..cb63eab4b 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -240,3 +240,23 @@ def get_workload_metrics_and_targets(workload, split='validation'): metric = f'test/{metric_name}' target = workload_obj.test_target_value return metric, target + + +def get_workload_stephint(workload): + workload_name = re.match(WORKLOAD_NAME_PATTERN, workload).group(1) + framework = re.match(WORKLOAD_NAME_PATTERN, workload).group(2) + workload_metadata = copy.copy(WORKLOADS[workload_name]) + + # Extend path according to framework. 
+ workload_metadata['workload_path'] = os.path.join( + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'{framework}', + 'workload.py', + ) + workload_init_kwargs = {} + workload_obj = workloads_registry.import_workload( + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) + return workload_obj.step_hint diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 273881c5a..c6764e9de 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -241,7 +241,8 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: - run_key = prng.fold_in(rng_subkey, hash(workload)) + workload_foldin = hash(workload) % 9 + run_key = prng.fold_in(rng_subkey, workload_foldin) run_seed = run_key[0] # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() @@ -270,6 +271,10 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' +<<<<<<< Updated upstream +======= + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' +>>>>>>> Stashed changes f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c7d4ae195..3d9f78ca1 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -24,7 +24,7 @@ "dataset": "librispeech" }, "criteo1tb": { - "max_steps": 10666, + "max_steps": 15666, "dataset": "criteo1tb" }, "librispeech_conformer": { From b93eb3ca97871ed30ddd6b08806a3bbc1ca0bdae Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 03:56:48 +0000 Subject: [PATCH 78/82] formatting --- algoperf/workloads/criteo1tb/workload.py | 2 +- algoperf/workloads/fastmri/workload.py | 2 +- algoperf/workloads/imagenet_resnet/workload.py | 4 ++-- algoperf/workloads/imagenet_vit/workload.py | 2 +- algoperf/workloads/librispeech_conformer/workload.py | 2 +- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 2 +- algoperf/workloads/ogbg/workload.py | 4 ++-- algoperf/workloads/wmt/workload.py | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index fb38eacc3..4d2196cd5 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -95,7 +95,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 8915 # ~2.4 hours. + return 8_915 # ~2.4 hours. 
@property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 5a8afa2e9..b87dfc755 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,7 +95,7 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 2745 # ~0.7 hours + return 2_745 # ~0.7 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index b5263e0a6..de8458c92 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -103,11 +103,11 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 49918 # ~13.8 hours + return 49_918 # ~13.8 hours @property def eval_period_time_sec(self) -> int: - return 1996 # approx 25 evals + return 1_996 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index f8f4f2659..4da02614f 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 2571 # 7 mins. + return 2_571 # 7 mins. def _build_dataset( self, diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 327e8bc39..5a0a546e4 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -80,7 +80,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 43680 # ~16.1 hours + return 43_680 # ~16.1 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 119049b34..c6bb149f7 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -96,7 +96,7 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 36949 # 10.3 hours + return 36_949 # 10.3 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 53206200f..002576268 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -88,11 +88,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 11303 # ~3.1 hours + return 11_303 # ~3.1 hours @property def eval_period_time_sec(self) -> int: - return 452. 
# approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index d972a5486..2e232214e 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -89,7 +89,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 16114 # ~12.0 hours + return 16_114 # ~12.0 hours @property def eval_period_time_sec(self) -> int: From 88b0e47fe9694d35d651a77c1acf8ea9491df5ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 03:57:34 +0000 Subject: [PATCH 79/82] revert changes to docker build shell script --- docker/build_docker_images.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 22590b9fd..aa94222ea 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -45,10 +45,10 @@ do echo "On branch: ${GIT_BRANCH}" echo $DOCKER_BUILD_COMMAND eval $DOCKER_BUILD_COMMAND - # echo $DOCKER_TAG_COMMAND - # eval $DOCKER_TAG_COMMAND - # echo $DOCKER_PUSH_COMMAND - # eval $DOCKER_PUSH_COMMAND - # echo "To pull container run: " - # echo $DOCKER_PULL_COMMAND + echo $DOCKER_TAG_COMMAND + eval $DOCKER_TAG_COMMAND + echo $DOCKER_PUSH_COMMAND + eval $DOCKER_PUSH_COMMAND + echo "To pull container run: " + echo $DOCKER_PULL_COMMAND done From fa946d861aab6d803d88046edb05caa27a79c4ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 04:00:09 +0000 Subject: [PATCH 80/82] fix merge conflict --- scoring/utils/run_workloads.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index c6764e9de..d8e0172fa 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -271,10 +271,7 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' -<<<<<<< Updated upstream -======= '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' ->>>>>>> Stashed changes f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' From 02f835d0961227e56750360e73c7a8ed4213801b Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 6 Nov 2025 16:59:54 +0000 Subject: [PATCH 81/82] rename models.py --- .../workloads/lm/lm_jax/{nanodo_model.py => models.py} | 0 algoperf/workloads/lm/lm_jax/workload.py | 2 +- .../lm/lm_pytorch/{plainlm_model.py => models.py} | 0 algoperf/workloads/lm/lm_pytorch/workload.py | 2 +- tests/modeldiffs/lm/compare.py | 8 ++++---- 5 files changed, 6 insertions(+), 6 deletions(-) rename algoperf/workloads/lm/lm_jax/{nanodo_model.py => models.py} (100%) rename algoperf/workloads/lm/lm_pytorch/{plainlm_model.py => models.py} (100%) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/models.py similarity index 100% rename from algoperf/workloads/lm/lm_jax/nanodo_model.py rename to algoperf/workloads/lm/lm_jax/models.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index effb12089..3862b73dc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -7,7 +7,7 @@ from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter -from algoperf.workloads.lm.lm_jax.nanodo_model import ( +from algoperf.workloads.lm.lm_jax.models 
import ( ModelConfig, TransformerDo, ) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/models.py similarity index 100% rename from algoperf/workloads/lm/lm_pytorch/plainlm_model.py rename to algoperf/workloads/lm/lm_pytorch/models.py diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3d185636b..a052a8452 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -11,7 +11,7 @@ from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter -from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( +from algoperf.workloads.lm.lm_pytorch.models import ( ModelConfig, Transformer, ) diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py index e1d85eba7..e1ca8e06c 100644 --- a/tests/modeldiffs/lm/compare.py +++ b/tests/modeldiffs/lm/compare.py @@ -26,7 +26,7 @@ from absl.testing import absltest, parameterized # Import JAX implementation -from algoperf.workloads.lm.lm_jax.nanodo_model import ( +from algoperf.workloads.lm.lm_jax.models import ( CausalAttn, Mlp, TBlock, @@ -34,12 +34,12 @@ apply_rope, init_rope, ) -from algoperf.workloads.lm.lm_jax.nanodo_model import ( +from algoperf.workloads.lm.lm_jax.models import ( ModelConfig as JaxModelConfig, ) # Import PyTorch implementation -from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( +from algoperf.workloads.lm.lm_pytorch.models import ( MLP, Attention, Block, @@ -47,7 +47,7 @@ apply_rotary_emb_complex_like, precompute_freqs_cis, ) -from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( +from algoperf.workloads.lm.lm_pytorch.models import ( ModelConfig as PyTorchModelConfig, ) From 0abf39d3e8888103dbecf952ab75d08be85657a9 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 6 Nov 2025 17:05:57 +0000 Subject: [PATCH 82/82] rename workload --- algoperf/workloads/{lm => finewebedu_lm}/__init__.py | 0 .../finewebedu_lm_jax}/__init__.py | 0 .../lm_jax => finewebedu_lm/finewebedu_lm_jax}/models.py | 0 .../finewebedu_lm_jax}/workload.py | 6 +++--- .../finewebedu_lm_pytorch}/__init__.py | 0 .../finewebedu_lm_pytorch}/models.py | 0 .../finewebedu_lm_pytorch}/workload.py | 6 +++--- .../workloads/{lm => finewebedu_lm}/input_pipeline.py | 0 algoperf/workloads/{lm => finewebedu_lm}/workload.py | 0 algoperf/workloads/ogbg/workload.py | 2 +- algoperf/workloads/workloads.py | 7 +++++-- .../archived_paper_baselines/adamw/pytorch/submission.py | 4 ++-- .../archived_paper_baselines/nesterov/jax/submission.py | 4 ++-- .../baselines/external_tuning/jax_nadamw_full_budget.py | 2 +- .../external_tuning/pytorch_nadamw_full_budget.py | 2 +- .../fineweb_edu_lm/jax_nadamw_target_setting.py | 2 +- .../fineweb_edu_lm/pytorch_nadamw_target_setting.py | 2 +- scoring/performance_profile.py | 2 +- scoring/utils/workload_metadata_external_tuning.json | 2 +- submission_runner.py | 6 +++--- tests/modeldiffs/lm/compare.py | 8 ++++---- 21 files changed, 29 insertions(+), 26 deletions(-) rename algoperf/workloads/{lm => finewebedu_lm}/__init__.py (100%) rename algoperf/workloads/{lm/lm_jax => finewebedu_lm/finewebedu_lm_jax}/__init__.py (100%) rename algoperf/workloads/{lm/lm_jax => finewebedu_lm/finewebedu_lm_jax}/models.py (100%) rename algoperf/workloads/{lm/lm_jax => finewebedu_lm/finewebedu_lm_jax}/workload.py (96%) rename algoperf/workloads/{lm/lm_pytorch => finewebedu_lm/finewebedu_lm_pytorch}/__init__.py (100%) rename 
algoperf/workloads/{lm/lm_pytorch => finewebedu_lm/finewebedu_lm_pytorch}/models.py (100%) rename algoperf/workloads/{lm/lm_pytorch => finewebedu_lm/finewebedu_lm_pytorch}/workload.py (97%) rename algoperf/workloads/{lm => finewebedu_lm}/input_pipeline.py (100%) rename algoperf/workloads/{lm => finewebedu_lm}/workload.py (100%) diff --git a/algoperf/workloads/lm/__init__.py b/algoperf/workloads/finewebedu_lm/__init__.py similarity index 100% rename from algoperf/workloads/lm/__init__.py rename to algoperf/workloads/finewebedu_lm/__init__.py diff --git a/algoperf/workloads/lm/lm_jax/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py similarity index 100% rename from algoperf/workloads/lm/lm_jax/__init__.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py similarity index 100% rename from algoperf/workloads/lm/lm_jax/models.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py similarity index 96% rename from algoperf/workloads/lm/lm_jax/workload.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py index 3862b73dc..ee4cffbbc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py @@ -6,12 +6,12 @@ import jax.numpy as jnp from algoperf import jax_sharding_utils, param_utils, spec -from algoperf.workloads.lm.input_pipeline import get_data_iter -from algoperf.workloads.lm.lm_jax.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( ModelConfig, TransformerDo, ) -from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload class LmWorkload(BaseLmWorkload): diff --git a/algoperf/workloads/lm/lm_pytorch/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py similarity index 100% rename from algoperf/workloads/lm/lm_pytorch/__init__.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py similarity index 100% rename from algoperf/workloads/lm/lm_pytorch/models.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py similarity index 97% rename from algoperf/workloads/lm/lm_pytorch/workload.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py index a052a8452..a25ca334a 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py @@ -10,12 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP from algoperf import param_utils, pytorch_utils, spec -from algoperf.workloads.lm.input_pipeline import get_data_iter -from algoperf.workloads.lm.lm_pytorch.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( ModelConfig, Transformer, ) -from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from 
algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/finewebedu_lm/input_pipeline.py similarity index 100% rename from algoperf/workloads/lm/input_pipeline.py rename to algoperf/workloads/finewebedu_lm/input_pipeline.py diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py similarity index 100% rename from algoperf/workloads/lm/workload.py rename to algoperf/workloads/finewebedu_lm/workload.py diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 002576268..771b103a0 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 452 # approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 391f16f51..e90300a36 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -113,7 +113,10 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, - 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, + 'finewebedu_lm': { + 'workload_path': 'finewebedu_lm/finewebedu_lm', + 'workload_class_name': 'LmWorkload', + }, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload', @@ -153,7 +156,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', - 'lm', + 'finewebedu_lm', 'ogbg', 'wmt', ] diff --git a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py index 8fa4e27f6..7c50ff4ff 100644 --- a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py +++ b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py @@ -189,8 +189,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'lm': - return 4 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/archived_paper_baselines/nesterov/jax/submission.py b/algorithms/archived_paper_baselines/nesterov/jax/submission.py index cc8eba3c5..061acc3de 100644 --- a/algorithms/archived_paper_baselines/nesterov/jax/submission.py +++ b/algorithms/archived_paper_baselines/nesterov/jax/submission.py @@ -292,8 +292,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 - elif workload_name == 'lm': - return 8 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index ccfa25360..323022598 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -394,7 +394,7 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 - elif workload_name == 'lm': + elif workload_name == 'finewebedu_lm': return 64 elif workload_name == 'mnist': return 16 diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py 
b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 9b544e380..2abf74c73 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -372,7 +372,7 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'lm': + elif workload_name == 'finewebedu_lm': return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py index 1fef611ac..b7adf6cd6 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -395,7 +395,7 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 - elif workload_name == 'lm': + elif workload_name == 'finewebedu_lm': return 64 elif workload_name == 'mnist': return 16 diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py index 196c1f809..b881747d8 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py @@ -373,7 +373,7 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'lm': + elif workload_name == 'finewebedu_lm': return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index b200c6865..043a65791 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,7 +71,7 @@ 'wer', 'l1_loss', 'loss', - 'ppl' + 'ppl', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index f133f2462..0ba0d99ee 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -31,7 +31,7 @@ "max_steps": 80000, "dataset": "librispeech" }, - "lm" : { + "finewebedu_lm" : { "max_steps": 55000, "dataset":"fineweb_edu_10B" } diff --git a/submission_runner.py b/submission_runner.py index 1bb763cf2..01d9894d8 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -266,7 +266,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', - 'lm' + 'finewebedu_lm', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -784,7 +784,7 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - limit_tf_threads = base_workload != 'lm' + limit_tf_threads = base_workload != 'finewebedu_lm' pytorch_init( USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads ) @@ -803,7 +803,7 @@ def main(_): 'librispeech_deepspeech', 'imagenet_vit', 'criteo1tb', - 'lm', + 'finewebedu_lm', ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py index e1ca8e06c..709e3125f 100644 --- a/tests/modeldiffs/lm/compare.py +++ b/tests/modeldiffs/lm/compare.py @@ 
-26,7 +26,7 @@ from absl.testing import absltest, parameterized # Import JAX implementation -from algoperf.workloads.lm.lm_jax.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( CausalAttn, Mlp, TBlock, @@ -34,12 +34,12 @@ apply_rope, init_rope, ) -from algoperf.workloads.lm.lm_jax.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( ModelConfig as JaxModelConfig, ) # Import PyTorch implementation -from algoperf.workloads.lm.lm_pytorch.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( MLP, Attention, Block, @@ -47,7 +47,7 @@ apply_rotary_emb_complex_like, precompute_freqs_cis, ) -from algoperf.workloads.lm.lm_pytorch.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( ModelConfig as PyTorchModelConfig, )
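
Note on the runtime-budget patches above: the updated `eval_period_time_sec` values are annotated with "approx 25 evals", which suggests they were chosen as roughly `max_allowed_runtime_sec / 25`. The sketch below only illustrates that relationship; the budget numbers are copied from the patches, but the derivation itself is an assumption inferred from those comments, and the committed values can differ from the rounded quotient by a few seconds.

```python
# Illustrative sketch: recompute approximate eval periods from the A100 budgets
# introduced in patches 77-78, assuming ~25 evaluations per run.
BUDGETS_SEC = {
    'criteo1tb': 8_915,
    'fastmri': 2_745,
    'imagenet_resnet': 49_918,
    'imagenet_vit': 64_292,
    'librispeech_conformer': 43_680,
    'librispeech_deepspeech': 36_949,
    'ogbg': 11_303,
    'wmt': 16_114,
}

for workload, budget in BUDGETS_SEC.items():
    # Round to the nearest second; the values committed in the patches
    # may differ slightly from this estimate.
    print(f'{workload}: eval_period ≈ {round(budget / 25)} s')
```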