feat(train): Enable z_loss regularization and logging via z_loss_weight flag
#2353
This PR enables the `z_loss` (logit regularization) feature in the main training loop. Previously, `train.py` hardcoded the `z_loss` factor to `0.0`, disabling this feature. This change introduces a new config flag, `z_loss_weight`, allowing users to control the regularization strength. It also adds the `z_loss` component to the training and evaluation metrics.

## Why is this change being made?
The `z_loss` term (`weight * log(Z)^2`, where `Z` is the softmax normalizer) is a useful regularization technique for penalizing large logits, which can improve model stability and prevent logit drift. The core function `max_utils.cross_entropy_with_logits` already supported this, but the option was inaccessible from the training script. This PR re-enables the dormant feature, providing a valuable hyperparameter for tuning.
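For illustration only, here is a self-contained sketch of the math (not the actual `max_utils.cross_entropy_with_logits` implementation; the function name and shapes below are assumptions):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def cross_entropy_with_z_loss(logits, targets, z_loss_weight=0.0):
  """Per-token cross entropy plus z_loss (illustrative, not MaxText's code).

  logits: [batch, seq, vocab]; targets: integer token ids of shape [batch, seq].
  """
  log_z = logsumexp(logits, axis=-1)                     # log of the normalizer Z
  log_probs = logits - log_z[..., None]                  # log softmax
  xent = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
  z_loss = z_loss_weight * jnp.square(log_z)             # weight * log(Z)^2
  return xent + z_loss, z_loss
```

Penalizing `log(Z)^2` pulls the log-normalizer toward zero, which keeps logit magnitudes from drifting upward over training.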
This solution:

- lets users control the `z_loss` strength via config;
- keeps the new `z_loss_weight` flag defaulted to `0.0`, so existing configurations are unaffected;
- pipes the `z_loss` value into the metrics logging pipeline, handling both standard and gradient-accumulated (GA) steps.

## Changes in this PR
**`base.yml`**: adds `z_loss_weight: 0.0`, near other regularization flags like `dropout_rate`.
**`loss_fn` (`train.py`)**:

- Calls `max_utils.cross_entropy_with_logits` with `config.z_loss_weight` instead of a hardcoded `0.0`.
- Extracts the `z_loss` component, applies the `targets_segmentation` mask, and computes its `jnp.sum()`.
- Returns `z_loss` in the `aux` dict, handling the GA and non-GA cases consistently (sketched below):
  - GA: the summed `z_loss` (to be accumulated).
  - Non-GA: the averaged `z_loss` (to be logged directly).
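A rough sketch of that reduction and `aux` packaging, assuming a boolean GA switch and the shapes described above (an illustrative helper, not MaxText's exact code):

```python
import jax.numpy as jnp

def reduce_z_loss(z_loss_per_token, targets_segmentation, grad_accum):
  """Mask padding, sum z_loss, and package it for the aux dict (illustrative).

  z_loss_per_token: [batch, seq] from the loss function.
  targets_segmentation: [batch, seq]; 0 marks padding positions.
  """
  mask = targets_segmentation != 0
  z_loss_sum = jnp.sum(z_loss_per_token * mask)
  total_weights = jnp.sum(mask)
  if grad_accum:
    return {"z_loss": z_loss_sum}                  # raw sum, accumulated across microbatches
  return {"z_loss": z_loss_sum / total_weights}    # per-token mean, logged directly
```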
**`train_step` (`train.py`)**:

- GA: the `z_loss` sum is now correctly accumulated in `accumulate_gradient`, and the `learning/z_loss` metric is calculated as a per-token average (total `z_loss` sum / `total_weights`) to ensure consistent logging.
- Non-GA: the `z_loss` average from `loss_fn` is retrieved and logged directly as `learning/z_loss`.
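In code, the wiring might look like this (illustrative only; the `metrics`/`aux` key names other than `learning/z_loss` are assumptions):

```python
def record_train_z_loss(metrics, aux, grad_accum, total_weights):
  """Attach the per-token learning/z_loss metric (illustrative sketch)."""
  if grad_accum:
    # aux["z_loss"] holds the raw sum accumulated by accumulate_gradient;
    # normalize by the token count so the metric is a per-token average.
    metrics["learning/z_loss"] = aux["z_loss"] / total_weights
  else:
    # loss_fn already returned the per-token mean in the non-GA case.
    metrics["learning/z_loss"] = aux["z_loss"]
  return metrics
```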
**`eval_step` (`train.py`)**: retrieves `z_loss` from `loss_fn` and logs it as `evaluation/z_loss`.
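The evaluation side mirrors the training wiring (again a sketch; key names other than `evaluation/z_loss` are assumptions):

```python
def record_eval_z_loss(metrics, aux):
  """Log the z_loss returned by loss_fn under the evaluation namespace."""
  metrics["evaluation/z_loss"] = aux["z_loss"]
  return metrics
```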
Fixes: #2352