diff --git a/examples/gemma/input_pipeline.py b/examples/gemma/input_pipeline.py index da9ae4733..c7448847d 100644 --- a/examples/gemma/input_pipeline.py +++ b/examples/gemma/input_pipeline.py @@ -15,15 +15,11 @@ """Input pipeline for a LM1B dataset.""" import os -import typing +from typing import Any import tensorflow as tf import tensorflow_datasets as tfds import tokenizer -from clu import deterministic_data - -if typing.TYPE_CHECKING: - from train import TrainConfig AUTOTUNE = tf.data.experimental.AUTOTUNE Features = dict[str, tf.Tensor] @@ -324,7 +320,7 @@ def filter_fn(x): def get_datasets( - config: "TrainConfig", + config: Any, *, n_devices: int, vocab_path: str | None = None, diff --git a/examples/gemma/main.py b/examples/gemma/main.py index f4185e216..cd97f3f10 100644 --- a/examples/gemma/main.py +++ b/examples/gemma/main.py @@ -18,21 +18,24 @@ that can be easily tested and imported in Colab. """ -import jax -import tensorflow as tf -import train -from absl import app, flags, logging +from absl import app +from absl import flags +from absl import logging from clu import platform +import train +import jax from ml_collections import config_flags +import tensorflow as tf + FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( - 'config', - 'configs/default.py', - 'File path to the training hyperparameter configuration.', - lock_config=True, + 'config', + 'configs/default.py', + 'File path to the training hyperparameter configuration.', + lock_config=True, ) flags.mark_flags_as_required(['workdir']) @@ -51,11 +54,11 @@ def main(argv): # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( - f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}' + f'process_index: {jax.process_index()}, ' + f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( - platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) diff --git a/examples/gemma/tokenizer.py b/examples/gemma/tokenizer.py index fe740f3be..3c694220d 100644 --- a/examples/gemma/tokenizer.py +++ b/examples/gemma/tokenizer.py @@ -25,7 +25,8 @@ import tensorflow as tf import tensorflow_text as tftxt from absl import logging -from sentencepiece import SentencePieceTrainer, SentencePieceProcessor +from sentencepiece import SentencePieceProcessor +from sentencepiece import SentencePieceTrainer Features = dict[str, tf.Tensor] @@ -190,5 +191,5 @@ def __call__(self, features: Features) -> Features: def load_sentencepiece_processor(vocab_path: str): spp = SentencePieceProcessor() - spp.load(vocab_path) + spp.Load(vocab_path) return spp diff --git a/examples/gemma/train.py b/examples/gemma/train.py index b5bc07745..aaafce7bc 100644 --- a/examples/gemma/train.py +++ b/examples/gemma/train.py @@ -22,26 +22,26 @@ import dataclasses import os +from typing import Any +from absl import logging +from clu import metric_writers +from clu import periodic_actions +from flax import nnx import input_pipeline -import jax -import jax.numpy as jnp +import sampler as sampler_lib import tokenizer import transformer as transformer_lib +import utils +from flax.training import checkpoints +from flax.training import common_utils +import jax +from jax import random +import jax.numpy as jnp +import ml_collections import 
numpy as np import optax -import sampler as sampler_lib import tensorflow as tf -import utils -from absl import logging -from clu import metric_writers, periodic_actions -from jax import random -from jax.sharding import Mesh, NamedSharding -from jax.sharding import PartitionSpec as P -from utils import TrainState - -from flax import nnx -from flax.training import checkpoints, common_utils @dataclasses.dataclass(unsafe_hash=True) @@ -53,13 +53,13 @@ class MeshRules: def __call__(self, *keys: str) -> tuple[str, ...]: return tuple( - getattr(self, key) if key is not None else None - for key in keys + getattr(self, key) if key is not None else None for key in keys ) @dataclasses.dataclass(unsafe_hash=True) class TrainConfig: + """Configuration for training a model.""" # Path to load or store sentencepiece vocab file. vocab_path: str | None # Vocabulary size if `vocab_path` is not given. @@ -107,10 +107,11 @@ class TrainConfig: # Gemma transformer name. # Possible values defined in transformer.TransformerConfig: - # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) + # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, + # ...) transformer_name: str | None # or alternatively define the model using the dict of parameters - transformer_params: dict | None + transformer_params: dict[Any, Any] | None # Whether to save model checkpoints. save_checkpoints: bool @@ -157,8 +158,8 @@ def __post_init__(self): def rsqrt_schedule( - init_value: float, - shift: int = 0, + init_value: float, + shift: int = 0 ): """Applies a reverse square-root schedule. @@ -182,20 +183,20 @@ def schedule(count): def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): """Creates a rsqrt schedule with linear warmup.""" return optax.join_schedules( - [ - optax.linear_schedule( - init_value=0, - end_value=learning_rate, - transition_steps=warmup_steps, - ), - rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), - ], - boundaries=[warmup_steps], + [ + optax.linear_schedule( + init_value=0, + end_value=learning_rate, + transition_steps=warmup_steps, + ), + rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), + ], + boundaries=[warmup_steps], ) def compute_weighted_cross_entropy( - logits, targets, weights=None, label_smoothing=0.0 + logits, targets, weights=None, label_smoothing=0.0 ): """Compute weighted cross entropy and entropy for log probs and targets. @@ -211,18 +212,18 @@ def compute_weighted_cross_entropy( """ if logits.ndim != targets.ndim + 1: raise ValueError( - 'Incorrect shapes. Got shape %s logits and %s targets' - % (str(logits.shape), str(targets.shape)) + 'Incorrect shapes. Got shape %s logits and %s targets' + % (str(logits.shape), str(targets.shape)) ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( - confidence * jnp.log(confidence) - + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + confidence * jnp.log(confidence) + + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_targets = common_utils.onehot( - targets, vocab_size, on_value=confidence, off_value=low_confidence + targets, vocab_size, on_value=confidence, off_value=low_confidence ) loss = -jnp.sum(soft_targets * nnx.log_softmax(logits), axis=-1) @@ -249,8 +250,8 @@ def compute_weighted_accuracy(logits, targets, weights=None): """ if logits.ndim != targets.ndim + 1: raise ValueError( - 'Incorrect shapes. 
Got shape %s logits and %s targets' - % (str(logits.shape), str(targets.shape)) + 'Incorrect shapes. Got shape %s logits and %s targets' + % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) @@ -264,13 +265,13 @@ def compute_weighted_accuracy(logits, targets, weights=None): def compute_metrics(logits, labels, weights, label_smoothing=0.0): """Compute summary metrics.""" loss, weight_sum = compute_weighted_cross_entropy( - logits, labels, weights, label_smoothing + logits, labels, weights, label_smoothing ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { - 'loss': loss, - 'accuracy': acc, - 'denominator': weight_sum, + 'loss': loss, + 'accuracy': acc, + 'denominator': weight_sum, } return metrics @@ -280,10 +281,10 @@ def compute_metrics(logits, labels, weights, label_smoothing=0.0): def train_step( - state: TrainState, - batch, - learning_rate_fn, - label_smoothing=0.0, + state: utils.TrainState, + batch, + learning_rate_fn, + label_smoothing=0.0, ): """Perform a single training step.""" # X_position and X_segmentation are needed only when using "packed examples" @@ -293,16 +294,18 @@ def train_step( # like a normal, unpacked sequence example. train_keys = ['inputs', 'inputs_position', 'inputs_segmentation', 'targets'] (inputs, inputs_positions, inputs_segmentation, targets) = ( - batch.get(k, None) for k in train_keys + batch.get(k, None) for k in train_keys ) # TODO: this should be defined globally pad_id = 0 weights = jnp.where(inputs > pad_id, 1, 0).astype(jnp.float32) input_mask = inputs > pad_id - attention_mask = transformer_lib.make_causal_attn_mask(input_mask) # (B, L, L) + # (B, L, L) + attention_mask = transformer_lib.make_causal_attn_mask(input_mask) # inputs_segmentation: (B, L) - mask = inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :] # (B, L, L) + # mask: (B, L, L) + mask = inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :] attention_mask = jnp.logical_and(mask, attention_mask) def loss_fn(params): @@ -310,14 +313,14 @@ def loss_fn(params): module = nnx.merge(state.graphdef, params) logits, _ = module( - inputs, - positions=inputs_positions, - attention_mask=attention_mask, - cache=None, + inputs, + positions=inputs_positions, + attention_mask=attention_mask, + cache=None, ) loss, weight_sum = compute_weighted_cross_entropy( - logits, targets, weights, label_smoothing + logits, targets, weights, label_smoothing ) mean_loss = loss / weight_sum return mean_loss, logits @@ -334,10 +337,10 @@ def loss_fn(params): def eval_step( - params: nnx.State, - batch, - graphdef: nnx.GraphDef[transformer_lib.Transformer], - label_smoothing=0.0, + params: nnx.State, + batch, + graphdef: nnx.GraphDef[transformer_lib.Transformer], + label_smoothing=0.0, ): """Calculate evaluation metrics on a batch.""" inputs, targets = batch['inputs'], batch['targets'] @@ -351,21 +354,21 @@ def eval_step( module = nnx.merge(graphdef, params) logits, _ = module( - inputs, - positions=inputs_positions, - attention_mask=attention_mask, - cache=None, + inputs, + positions=inputs_positions, + attention_mask=attention_mask, + cache=None, ) return compute_metrics(logits, targets, weights, label_smoothing) def evaluate( - *, - jit_eval_step, - state: TrainState, - eval_ds: tf.data.Dataset, - num_eval_steps: int, + *, + jit_eval_step, + state: utils.TrainState, + eval_ds: tf.data.Dataset, + num_eval_steps: int, ): """Evaluate the target an return a dictionary with 
the metrics.""" logging.info('Gathering evaluation metrics.') @@ -379,13 +382,107 @@ def evaluate( eval_metrics_sums = jax.tree.map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree.map( - lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop - eval_metrics_sums, + lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop + eval_metrics_sums, ) return eval_summary -def train_and_evaluate(config: TrainConfig, workdir: str): +def get_fake_batch(batch_size: int) -> Any: + """Returns fake data for the given batch size.""" + rng = jax.random.PRNGKey(0) + batch = {} + for k in ( + 'inputs', + 'inputs_position', + 'inputs_segmentation', + 'targets', + 'targets_position', + 'targets_segmentation', + ): + batch[k] = jax.random.randint(rng, (batch_size, 128), 0, 9999999, jnp.int32) + return batch + + +def get_apply_fn_and_args( + config: ml_collections.ConfigDict, + vocab_size: int | None = None, +): + """Returns the apply function and args for the given config.""" + if vocab_size is None: + vocab_size = config.vocab_size + + # Build Model and Optimizer + # --------------------------------------------------------------------------- + if config.transformer_name is not None: + model_config = transformer_lib.TransformerConfig.from_version_name( + config.transformer_name, + num_embed=vocab_size, + dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, + axis_rules=config.axis_rules, + ) + else: + assert config.transformer_params is not None + model_config = transformer_lib.TransformerConfig.from_dict( + **config.transformer_params, + num_embed=vocab_size, + dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, + axis_rules=config.axis_rules, + ) + + # Mesh definition + devices_array = utils.create_device_mesh(config) + mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) + + rng = jax.random.PRNGKey(config.seed) + rng, init_rng = jax.random.split(rng) + + def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): + return transformer_lib.Transformer(config, rngs=nnx.Rngs(params=key)) + + learning_rate_fn = create_learning_rate_schedule( + learning_rate=config.learning_rate, warmup_steps=config.warmup_steps + ) + + optimizer = optax.adamw( + learning_rate_fn, + b1=0.9, + b2=0.98, + eps=1e-9, + weight_decay=config.weight_decay, + ) + + state, state_sharding = utils.setup_initial_state( + constructor, optimizer, model_config, init_rng, mesh + ) + data_sharding = jax.NamedSharding(mesh, jax.P(config.data_sharding)) + jit_train_step = jax.jit( + train_step, + in_shardings=( + state_sharding, + data_sharding, + ), # type: ignore + out_shardings=(state_sharding, None), # type: ignore + static_argnames=('learning_rate_fn', 'label_smoothing'), + donate_argnums=0, + ) + + batch = get_fake_batch(config.per_device_batch_size) + batch = jax.tree.map(lambda x: jnp.asarray(x, device=data_sharding), batch) + + return ( + jit_train_step, + (state, batch, learning_rate_fn, 0.0), + dict(), + ( + state_sharding, + data_sharding, + rng, + ), + ) + + +def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. 
Args: @@ -406,59 +503,28 @@ def train_and_evaluate(config: TrainConfig, workdir: str): # --------------------------------------------------------------------------- logging.info('Initializing dataset.') train_ds, eval_ds, encoder = input_pipeline.get_datasets( - n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path + n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path ) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) logging.info('Initializing model, optimizer, and step functions.') - # Build Model and Optimizer - # --------------------------------------------------------------------------- - if config.transformer_name is not None: - model_config = transformer_lib.TransformerConfig.from_version_name( - config.transformer_name, - num_embed=vocab_size, - dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, - axis_rules=config.axis_rules, - ) - else: - assert config.transformer_params is not None - model_config = transformer_lib.TransformerConfig.from_dict( - **config.transformer_params, - num_embed=vocab_size, - dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, - axis_rules=config.axis_rules, - ) - - # Mesh definition - devices_array = utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) start_step = 0 - rng = jax.random.PRNGKey(config.seed) - rng, init_rng = jax.random.split(rng) - rng, inference_rng = random.split(rng) - - def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): - return transformer_lib.Transformer(config, rngs=nnx.Rngs(params=key)) - learning_rate_fn = create_learning_rate_schedule( - learning_rate=config.learning_rate, warmup_steps=config.warmup_steps - ) + ( + jit_train_step, + (state, _, learning_rate_fn, _), + _, + ( + state_sharding, + data_sharding, + rng, + ), + ) = get_apply_fn_and_args(config, vocab_size) - optimizer = optax.adamw( - learning_rate_fn, - b1=0.9, - b2=0.98, - eps=1e-9, - weight_decay=config.weight_decay, - ) - - state, state_sharding = utils.setup_initial_state( - constructor, optimizer, model_config, init_rng, mesh - ) - data_sharding = NamedSharding(mesh, P(config.data_sharding)) + _, inference_rng = random.split(rng) if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. @@ -467,38 +533,28 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): start_step = int(state.step) writer = metric_writers.create_default_writer( - workdir, just_logging=jax.process_index() > 0 + workdir, just_logging=jax.process_index() > 0 ) if start_step == 0: writer.write_hparams(dataclasses.asdict(config)) # compile multidevice versions of train/eval/predict step fn. 
- jit_train_step = jax.jit( - train_step, - in_shardings=( - state_sharding, - data_sharding, - ), # type: ignore - out_shardings=(state_sharding, None), # type: ignore - static_argnames=("learning_rate_fn", "label_smoothing"), - donate_argnums=0, - ) jit_eval_step = jax.jit( - eval_step, - in_shardings=( - state_sharding.params, - data_sharding, - ), # type: ignore - out_shardings=None, # type: ignore - static_argnames=("graphdef", "label_smoothing"), + eval_step, + in_shardings=( + state_sharding.params, + data_sharding, + ), # type: ignore + out_shardings=None, # type: ignore + static_argnames=('graphdef', 'label_smoothing'), ) vocab = tokenizer.load_sentencepiece_processor(vocab_path) - sampler = sampler_lib.Sampler( - transformer=nnx.merge(state.graphdef, state.params), - vocab=vocab, - cache_size=1024, + sampler = sampler_lib.Sampler( + transformer=nnx.merge(state.graphdef, state.params), + vocab=vocab, + cache_size=1024, ) # Main Train Loop @@ -509,12 +565,12 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): logging.info('Starting training loop.') hooks = [] report_progress = periodic_actions.ReportProgress( - num_train_steps=config.num_train_steps, writer=writer + num_train_steps=config.num_train_steps, writer=writer ) if jax.process_index() == 0: hooks += [ - report_progress, - periodic_actions.Profile(logdir=workdir, num_profile_steps=5), + report_progress, + periodic_actions.Profile(logdir=workdir, num_profile_steps=5), ] train_metrics = [] with metric_writers.ensure_flushes(writer): @@ -525,12 +581,12 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): with jax.profiler.StepTraceAnnotation('train', step_num=step): with report_progress.timed('data'): batch = next(train_iter) - batch = jax.tree.map(lambda x: jnp.asarray(x, device=data_sharding), batch) + batch = jax.tree.map( + lambda x: jnp.asarray(x, device=data_sharding), batch + ) with report_progress.timed('train_step'): - state, metrics = jit_train_step( - state, batch, learning_rate_fn, 0.0 - ) + state, metrics = jit_train_step(state, batch, learning_rate_fn, 0.0) train_metrics.append(metrics) # Quick indication that training is happening. @@ -541,14 +597,17 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): # Write batch loss and lr every step to TB # without overwhelming the stdout: if jax.process_index() == 0: - tb_writer = writer._writers[-1] + tb_writer = writer._writers[-1] # pylint: disable=protected-access lr = train_metrics[-1]['learning_rate'] train_batch_loss = train_metrics[-1]['loss'] denominator = train_metrics[-1]['denominator'] - tb_writer.write_scalars(step, { - "train_learning_rate": lr, - "train_loss": train_batch_loss / denominator, - }) + tb_writer.write_scalars( + step, + { + 'train_learning_rate': lr, + 'train_loss': train_batch_loss / denominator, + }, + ) # Periodic metric handling. 
if (step > 0 and step % config.eval_every_steps == 0) or is_last_step: @@ -569,33 +628,33 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): # update sampler's transformer state: sampler.transformer_state = state.params exemplars = sampler( - config.prompts, - total_generation_steps=config.num_predict_steps, - temperature=config.sampling_temperature, - top_p=config.sampling_top_p, - seed=inference_rng, - echo=True, + config.prompts, + total_generation_steps=config.num_predict_steps, + temperature=config.sampling_temperature, + top_p=config.sampling_top_p, + seed=inference_rng, + echo=True, ) - writer.write_texts(step, {'samples': exemplars.text}) + writer.write_texts(step, {'samples': exemplars.text[0]}) with report_progress.timed('eval'): eval_results = evaluate( - jit_eval_step=jit_eval_step, - state=state, - eval_ds=eval_ds, - num_eval_steps=config.num_eval_steps, + jit_eval_step=jit_eval_step, + state=state, + eval_ds=eval_ds, + num_eval_steps=config.num_eval_steps, ) # (clipped) perplexity after averaging log-perplexity eval_results['perplexity'] = jnp.clip( - jnp.exp(eval_results['loss']), max=1.0e4 + jnp.exp(eval_results['loss']), max=1.0e4 ) writer.write_scalars( - step, {'eval_' + k: v for k, v in eval_results.items()} + step, {'eval_' + k: v for k, v in eval_results.items()} ) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = ( - step % config.checkpoint_every_steps == 0 or is_last_step + step % config.checkpoint_every_steps == 0 or is_last_step ) if config.save_checkpoints and save_checkpoint: logging.info('Saving checkpoint step %d.', step) diff --git a/examples/gemma/utils.py b/examples/gemma/utils.py index 18f6909cc..076400e0a 100644 --- a/examples/gemma/utils.py +++ b/examples/gemma/utils.py @@ -15,7 +15,7 @@ # Copied over from MaxText (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py). import logging -from typing import Any, TYPE_CHECKING +from typing import Any from collections.abc import Callable import jax @@ -28,9 +28,6 @@ from flax import nnx from flax.training import train_state -if TYPE_CHECKING: - from train import TrainConfig - Dtype = Any Shape = tuple[int, ...] @@ -43,7 +40,7 @@ class TrainState(train_state.TrainState): # ----------------------------------------------------------------------------- -def create_device_mesh(config: "TrainConfig"): +def create_device_mesh(config: Any): """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas.""" devices = jax.devices() num_devices = len(devices) diff --git a/examples/mnist/train.py b/examples/mnist/train.py index 0886a1963..f6e32ca95 100644 --- a/examples/mnist/train.py +++ b/examples/mnist/train.py @@ -20,6 +20,7 @@ # See issue #620. 
# pytype: disable=wrong-keyword-args +from typing import Any, Callable from absl import logging from flax import linen as nn @@ -51,6 +52,23 @@ def __call__(self, x): return x +def get_fake_batch(batch_size: int) -> Any: + """Returns fake data for the given batch size.""" + rng = jax.random.PRNGKey(0) + images = jax.random.randint(rng, (batch_size, 28, 28, 1), 0, 255, jnp.uint8) + labels = jax.random.randint(rng, (batch_size,), 0, 10, jnp.int32) + return images, labels + + +def get_apply_fn_and_args( + config: ml_collections.ConfigDict, +) -> tuple[Any, tuple[Any, ...], dict[str, Any], tuple[Any, ...]]: + """Returns the apply function and args for the given config.""" + state = create_train_state(jax.random.key(0), config) + batch = get_fake_batch(config.batch_size) + return apply_model, (state, *batch), dict(), () + + @jax.jit def apply_model(state, images, labels): """Computes gradients, loss and accuracy for a single batch.""" diff --git a/examples/wmt/models.py b/examples/wmt/models.py index 5da0f7065..e4d89f8a8 100644 --- a/examples/wmt/models.py +++ b/examples/wmt/models.py @@ -338,7 +338,7 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None): output of a transformer encoder. """ config = self.config - assert inputs.ndim == 2 # (batch, len) + assert inputs.ndim == 2, inputs.shape # (batch, len) # Input Embedding if self.shared_embedding is None: diff --git a/examples/wmt/train.py b/examples/wmt/train.py index 0f9f2bec5..edeee5cd9 100644 --- a/examples/wmt/train.py +++ b/examples/wmt/train.py @@ -23,12 +23,17 @@ import collections import functools import os +from typing import Any from absl import logging from clu import metric_writers from clu import periodic_actions from flax import jax_utils from flax import linen as nn +import bleu +import decode +import input_pipeline +import models from flax.training import checkpoints from flax.training import common_utils from flax.training import dynamic_scale as dynamic_scale_lib @@ -41,11 +46,6 @@ import orbax.checkpoint as ocp import tensorflow as tf -import bleu -import decode -import input_pipeline -import models - class TrainState(train_state.TrainState): dynamic_scale: dynamic_scale_lib.DynamicScale @@ -250,7 +250,7 @@ def loss_fn(params): if state.dynamic_scale: # if is_fin == False the gradients contain Inf/NaNs and optimizer state and # params should be restored (= skip this step). 
- select_fn = functools.partial(jnp.where, is_fin) + select_fn = functools.partial(jnp.where, is_fin) # pylint: disable=undefined-variable new_state = new_state.replace( opt_state=jax.tree_util.tree_map( select_fn, new_state.opt_state, state.opt_state @@ -259,7 +259,7 @@ def loss_fn(params): select_fn, new_state.params, state.params ), ) - metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"] + metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"] # pylint: disable=undefined-variable return new_state, metrics @@ -394,6 +394,111 @@ def evaluate( return eval_summary +def get_fake_batch(batch_size: int) -> Any: + """Returns fake data for the given batch size.""" + rng = jax.random.PRNGKey(0) + batch = {} + for k in ( + "inputs", + "inputs_position", + "inputs_segmentation", + "targets", + "targets_position", + "targets_segmentation", + ): + batch[k] = jax.random.randint( + rng, + (batch_size, 256), + 0, + 9999999, + dtype=jnp.int32, + ) + batch = common_utils.shard(batch) + return batch + + +def get_apply_fn_and_args( + config: ml_collections.ConfigDict, vocab_size: int | None = None, +) -> tuple[Any, tuple[Any, ...], dict[str, Any], tuple[Any, ...]]: + """Returns the apply function and args for the given config.""" + if vocab_size is None: + vocab_size = config.vocab_size + dtype = preferred_dtype(config) + learning_rate_fn = create_learning_rate_schedule( + learning_rate=config.learning_rate, warmup_steps=config.warmup_steps + ) + train_config = models.TransformerConfig( + vocab_size=vocab_size, + output_vocab_size=vocab_size, + share_embeddings=config.share_embeddings, + logits_via_embedding=config.logits_via_embedding, + dtype=dtype, + emb_dim=config.emb_dim, + num_heads=config.num_heads, + num_layers=config.num_layers, + qkv_dim=config.qkv_dim, + mlp_dim=config.mlp_dim, + max_len=max(config.max_target_length, config.max_eval_target_length), + dropout_rate=config.dropout_rate, + attention_dropout_rate=config.attention_dropout_rate, + deterministic=False, + decode=False, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.normal(stddev=1e-6), + ) + p_train_step = jax.pmap( + functools.partial( + train_step, + config=train_config, + learning_rate_fn=learning_rate_fn, + label_smoothing=config.label_smoothing, + ), + axis_name="batch", + donate_argnums=(0,), + ) # pytype: disable=wrong-arg-types + + dynamic_scale = None + if dtype == jnp.float16: + dynamic_scale = dynamic_scale_lib.DynamicScale() + eval_config = train_config.replace(deterministic=True) + m = models.Transformer(eval_config) + rng = jax.random.key(config.seed) + rng, init_rng = jax.random.split(rng) + input_shape = (config.per_device_batch_size, config.max_target_length) + target_shape = (config.per_device_batch_size, config.max_target_length) + initial_variables = jax.jit(m.init)( + init_rng, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + ) + state = TrainState.create( + apply_fn=m.apply, + params=initial_variables["params"], + tx=optax.adamw( + learning_rate=learning_rate_fn, + b1=0.9, + b2=0.98, + eps=1e-9, + weight_decay=config.weight_decay, + ), + dynamic_scale=dynamic_scale, + ) + state = jax_utils.replicate(state) + batch = get_fake_batch( + jax.local_device_count() * config.per_device_batch_size + ) + + # We init the first set of dropout PRNG keys, but update it afterwards inside + # the main pmap"d training update for performance. 
+ dropout_rngs = jax.random.split(rng, jax.local_device_count()) + return ( + jax.jit(p_train_step), + (state, batch), + dict(dropout_rng=dropout_rngs), + (train_config,), + ) + + def translate_and_calculate_bleu( *, p_pred_step, @@ -497,94 +602,17 @@ def decode_tokens(toks): logging.info("Initializing model, optimizer, and step functions.") - dtype = preferred_dtype(config) - # Build Model and Optimizer # --------------------------------------------------------------------------- - train_config = models.TransformerConfig( - vocab_size=vocab_size, - output_vocab_size=vocab_size, - share_embeddings=config.share_embeddings, - logits_via_embedding=config.logits_via_embedding, - dtype=dtype, - emb_dim=config.emb_dim, - num_heads=config.num_heads, - num_layers=config.num_layers, - qkv_dim=config.qkv_dim, - mlp_dim=config.mlp_dim, - max_len=max(config.max_target_length, config.max_eval_target_length), - dropout_rate=config.dropout_rate, - attention_dropout_rate=config.attention_dropout_rate, - deterministic=False, - decode=False, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.normal(stddev=1e-6), - ) - eval_config = train_config.replace(deterministic=True) - predict_config = train_config.replace(deterministic=True, decode=True) - start_step = 0 - rng = jax.random.key(config.seed) - rng, init_rng = jax.random.split(rng) - input_shape = (config.per_device_batch_size, config.max_target_length) - target_shape = (config.per_device_batch_size, config.max_target_length) - - m = models.Transformer(eval_config) - initial_variables = jax.jit(m.init)( - init_rng, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - ) - - # Create train state with Adam optimizer and weight decay. - learning_rate_fn = create_learning_rate_schedule( - learning_rate=config.learning_rate, warmup_steps=config.warmup_steps - ) - dynamic_scale = None - if dtype == jnp.float16: - dynamic_scale = dynamic_scale_lib.DynamicScale() - state = TrainState.create( - apply_fn=m.apply, - params=initial_variables["params"], - tx=optax.adamw( - learning_rate=learning_rate_fn, - b1=0.9, - b2=0.98, - eps=1e-9, - weight_decay=config.weight_decay, - ), - dynamic_scale=dynamic_scale, - ) - - # We access model params only via state.params - del initial_variables - - if config.restore_checkpoints: - # Restore unreplicated optimizer + model state from last checkpoint. - state = checkpoints.restore_checkpoint(workdir, state) - # Grab last step. - start_step = int(state.step) - - writer = metric_writers.create_default_writer( - workdir, just_logging=jax.process_index() > 0 - ) - if start_step == 0: - writer.write_hparams(dict(config)) - - # Replicate state. - state = jax_utils.replicate(state) # compile multidevice versions of train/eval/predict step and cache init fn. 
- p_train_step = jax.pmap( - functools.partial( - train_step, - config=train_config, - learning_rate_fn=learning_rate_fn, - label_smoothing=config.label_smoothing, - ), - axis_name="batch", - donate_argnums=(0,), - ) # pytype: disable=wrong-arg-types + p_train_step, (state, _,), kwargs, (train_config,) = ( + get_apply_fn_and_args(config, vocab_size) + ) + dropout_rngs = kwargs["dropout_rng"] + eval_config = train_config.replace(deterministic=True) + predict_config = train_config.replace(deterministic=True, decode=True) p_eval_step = jax.pmap( functools.partial(eval_step, config=eval_config), axis_name="batch" ) @@ -604,14 +632,23 @@ def decode_tokens(toks): static_broadcasted_argnums=(3, 4), ) # eos token, max_length are constant + if config.restore_checkpoints: + # Restore unreplicated optimizer + model state from last checkpoint. + unreplicated_state = jax_utils.unreplicate(state) + state = checkpoints.restore_checkpoint(workdir, unreplicated_state) + # Grab last step. + start_step = int(state.step) + state = jax_utils.replicate(state) + + writer = metric_writers.create_default_writer( + workdir, just_logging=jax.process_index() > 0 + ) + if start_step == 0: + writer.write_hparams(dict(config)) + # Main Train Loop # --------------------------------------------------------------------------- - # We init the first set of dropout PRNG keys, but update it afterwards inside - # the main pmap"d training update for performance. - dropout_rngs = jax.random.split(rng, jax.local_device_count()) - del rng - logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( @@ -649,8 +686,8 @@ def decode_tokens(toks): metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop("denominator") summary = jax.tree_util.tree_map( - lambda x: x / denominator, metrics_sums - ) # pylint: disable=cell-var-from-loop + lambda x: x / denominator, metrics_sums # pylint: disable=cell-var-from-loop + ) summary["learning_rate"] = lr summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary)
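# Usage sketch: one way the new `get_apply_fn_and_args` helpers introduced in
# this change could be driven for a single-step smoke test, assuming the MNIST
# example's `train.py` and its `configs/default.py` are importable from the
# working directory; `smoke_test_one_step` is a hypothetical name and not part
# of the example code itself.
import jax

import train  # examples/mnist/train.py
from configs import default as default_config  # assumed config module


def smoke_test_one_step():
  config = default_config.get_config()
  # `get_apply_fn_and_args` returns the step function, its positional args
  # (train state plus a fake batch), keyword args, and any extras.
  apply_fn, args, kwargs, _ = train.get_apply_fn_and_args(config)
  # The first call triggers compilation; block until the outputs are ready so
  # a caller timing this function measures real device work.
  out = apply_fn(*args, **kwargs)
  jax.block_until_ready(out)
  return out


if __name__ == '__main__':
  smoke_test_one_step()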