1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml
@@ -146,6 +146,7 @@ logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embed
cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
+z_loss_weight: 0.0 # z-loss regularization weight, 0.0 to disable

# Multi-Token Prediction Configs
# The number of auxiliary prediction layers to use for MTP.
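
For context: z-loss is the auxiliary stabilizer used in PaLM/T5X-style training. It penalizes the squared log of the softmax partition function, which keeps the normalizer close to 1 and discourages logit drift. A minimal JAX sketch of the per-token quantity that z_loss_weight scales (illustrative only; the actual computation happens inside max_utils.cross_entropy_with_logits):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def z_loss_term(logits, z_loss_weight):
  # log_z is the log of the softmax partition function at each position.
  log_z = logsumexp(logits, axis=-1)
  # Weighted squared log-partition; a zero weight disables the penalty.
  return z_loss_weight * jnp.square(log_z)

With the default z_loss_weight: 0.0 the term is a no-op; a small coefficient such as 1e-4 (the value reported for PaLM) is a typical starting point, settable in base.yml or as a key=value override like any other config entry.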
17 changes: 13 additions & 4 deletions src/MaxText/train.py
@@ -143,12 +143,14 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
intermediate_outputs = {}

one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
-xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0)
+xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, config.z_loss_weight)
xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
# Mask out paddings at the end of each example.
xent = xent * (data["targets_segmentation"] != 0)
+z_loss = z_loss * (data["targets_segmentation"] != 0)
total_loss = jnp.sum(xent)
total_weights = jnp.sum(data["targets_segmentation"] != 0)
+z_loss = jnp.sum(z_loss)
# If gradient accumulation is enabled, we don't need to divide total_loss
# by total_weights and then multiply the computed gradient by total_weights,
# since it's equivalent to computing the gradient from total_loss.
@@ -161,6 +163,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
loss = total_loss
else:
loss = total_loss / (total_weights + EPS)
+z_loss = z_loss / (total_weights + EPS)

# Calculate and Add MTP Loss
mtp_loss = 0.0
@@ -186,6 +189,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
"total_weights": total_weights,
"moe_lb_loss": moe_lb_loss,
"mtp_loss": mtp_loss,
"z_loss": z_loss,
}
return loss, aux
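
The new code path assumes that max_utils.cross_entropy_with_logits returns both the per-token cross entropy and the per-token z-loss term when given a nonzero weight, and that the z-loss already contributes to the first return value so it flows into the gradient. A self-contained sketch of that assumed contract in the T5X style (the name and body below are illustrative, not MaxText's implementation):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def cross_entropy_with_z_loss(logits, one_hot_targets, z_loss_weight):
  # Hypothetical helper mirroring the (xent, z_loss) contract used above.
  log_z = logsumexp(logits, axis=-1)
  log_softmax = logits - log_z[..., None]
  xent = -jnp.sum(one_hot_targets * log_softmax, axis=-1)
  z_loss = z_loss_weight * jnp.square(log_z)
  # Fold the penalty into the returned cross entropy so it is optimized,
  # while also returning it separately for logging.
  return xent + z_loss, z_loss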

@@ -222,6 +226,7 @@ def accumulate_gradient(acc_grad_and_loss, data):
acc_grad_and_loss["loss"] += aux["total_loss"]
acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]
acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"]
acc_grad_and_loss["z_loss"] += aux["z_loss"]
acc_grad_and_loss["grad"] = jax.tree_util.tree_map(
lambda x, y: x + y, cur_batch_gradient, acc_grad_and_loss["grad"]
)
@@ -236,7 +241,7 @@ def reshape_to_microbatch_accumulations(batch_arr):

data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data)
init_grad = jax.tree_util.tree_map(jnp.zeros_like, state.params)
init_grad_and_loss = {"loss": 0.0, "grad": init_grad, "total_weights": 0, "moe_lb_loss": 0.0, "mtp_loss": 0.0}
init_grad_and_loss = {"loss": 0.0, "grad": init_grad, "total_weights": 0, "moe_lb_loss": 0.0, "mtp_loss": 0.0, "z_loss": 0.0}

grad_and_loss, aux = jax.lax.scan(
accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps
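
For readers unfamiliar with the accumulation pattern here: each microbatch's losses are summed into the scan carry and normalized after the scan finishes. A toy, self-contained example of the same jax.lax.scan pattern (the per-microbatch "loss" and "z_loss" below are made up purely for illustration):

import jax
import jax.numpy as jnp

def accumulate(carry, microbatch):
  # Stand-in for one microbatch's forward pass and loss computation.
  loss = jnp.sum(microbatch)
  z_loss = 0.1 * loss  # pretend auxiliary term
  carry = {"loss": carry["loss"] + loss, "z_loss": carry["z_loss"] + z_loss}
  return carry, loss

steps = 4
data = jnp.ones((steps, 8))  # leading axis indexes microbatches
init = {"loss": jnp.zeros(()), "z_loss": jnp.zeros(())}
totals, per_step_losses = jax.lax.scan(accumulate, init, data, length=steps)
mean_z_loss = totals["z_loss"] / steps  # normalize once, after accumulation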
@@ -259,8 +264,9 @@ def reshape_to_microbatch_accumulations(batch_arr):
(loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, state.params, *extra_dpo_args, is_train=True)
intermediate_outputs = aux["intermediate_outputs"]
total_weights = aux["total_weights"]
-moe_lb_loss = aux["moe_lb_loss"]
-mtp_loss = aux["mtp_loss"]
+moe_lb_loss = aux["moe_lb_loss"] / config.gradient_accumulation_steps
+mtp_loss = aux["mtp_loss"] / config.gradient_accumulation_steps
+z_loss = aux["z_loss"] / total_weights

if config.gradient_clipping_threshold > 0:
grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold)
@@ -294,6 +300,7 @@ def move(path, value):
"learning/moe_lb_loss": moe_lb_loss,
"learning/mtp_loss": mtp_loss,
"learning/total_weights": total_weights,
"learning/z_loss": z_loss,
}
if not config.optimizer_memory_host_offload:
scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads)
@@ -335,6 +342,7 @@ def eval_step(model, config, state, data, dropout_rng):
total_weights = aux["total_weights"]
moe_lb_loss = aux["moe_lb_loss"]
mtp_loss = aux["mtp_loss"]
z_loss = aux["z_loss"]
metrics = {
"scalar": {
"evaluation/loss": loss,
@@ -343,6 +351,7 @@ def eval_step(model, config, state, data, dropout_rng):
"evaluation/moe_lb_loss": moe_lb_loss,
"evaluation/mtp_loss": mtp_loss,
"evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate,
"evaluation/z_loss": z_loss,
},
}
if config.use_dpo:
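
Putting the pieces together: the z_loss values plumbed into the learning/z_loss and evaluation/z_loss scalars come from the per-token penalty, masked by targets_segmentation and normalized by total_weights, just as loss_fn does above. A toy, self-contained illustration of that masking and normalization (shapes and values below are made up):

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Toy batch: [batch, length, vocab] logits plus a segmentation mask where 0 marks padding.
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 11))
segmentation = jnp.array([[1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 0]])
z_loss_weight = 1e-4

per_token = z_loss_weight * jnp.square(logsumexp(logits, axis=-1))
per_token = per_token * (segmentation != 0)          # mask out paddings
total_weights = jnp.sum(segmentation != 0)
z_loss_avg = jnp.sum(per_token) / total_weights      # per-token average, as normalized in loss_fn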