
@bzantium (Contributor) commented on Sep 16, 2025

This PR enables the z_loss (logit regularization) feature in the main training loop.

Previously, train.py hardcoded the z_loss factor to 0.0, disabling this feature. This change introduces a new config flag, z_loss_weight, allowing users to control the regularization strength. It also adds the z_loss component to the training and evaluation metrics.

Why is this change being made?

The z_loss term, z_loss_weight * log(Z)^2 where Z is the softmax normalizer (the sum of the exponentiated logits), is a useful regularization technique for penalizing large logits, which can improve model stability and prevent logit drift. The core function max_utils.cross_entropy_with_logits already supported this term, but it was inaccessible from the training script.
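
For reference, here is a minimal, self-contained sketch of this loss in JAX. It illustrates the concept only and is not MaxText's actual max_utils implementation; the function name and argument names are placeholders.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def cross_entropy_with_z_loss(logits, one_hot_targets, z_loss_weight=0.0):
  """Per-token cross-entropy plus z_loss = z_loss_weight * log(Z)^2."""
  log_z = logsumexp(logits, axis=-1)            # log of the softmax normalizer Z
  log_probs = logits - log_z[..., None]         # log-softmax
  xent = -jnp.sum(one_hot_targets * log_probs, axis=-1)
  z_loss = z_loss_weight * jnp.square(log_z)    # penalizes large |log Z|, i.e. logit drift
  return xent + z_loss, z_loss
```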

This PR "re-enables" this dormant feature, providing a valuable hyperparameter for tuning.

This solution is:

  • Flexible: Users can now enable and tune z_loss strength via config.
  • Backward-compatible: The new z_loss_weight flag defaults to 0.0, ensuring existing configurations are unaffected.
  • Correct: It correctly integrates the z_loss value into the metrics logging pipeline, handling both standard and gradient-accumulated (GA) steps.

Changes in this PR

  • base.yml:

    • Adds a new hyperparameter, z_loss_weight: 0.0, near other regularization flags like dropout_rate.
  • loss_fn (train.py):

    • Calls max_utils.cross_entropy_with_logits with config.z_loss_weight instead of 0.0 (see the sketch after this list).
    • Captures the z_loss component, applies the targets_segmentation mask, and computes its jnp.sum().
    • Returns z_loss in the aux dict, handling GA and non-GA cases consistently:
      • If GA > 1: Returns the sum of z_loss (to be accumulated).
      • If GA == 1: Returns the average of z_loss (to be logged directly).
  • train_step (train.py):

    • GA Branch:
      • The z_loss sum is now correctly accumulated in accumulate_gradient.
      • The final learning/z_loss metric is calculated as an average per token (total z_loss sum / total_weights) to ensure consistent logging.
    • Non-GA Branch:
      • The z_loss average from loss_fn is retrieved and logged directly as learning/z_loss.
  • eval_step (train.py):

    • Captures the average z_loss from loss_fn and logs it as evaluation/z_loss.
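
To make the wiring above concrete, the following sketch shows how loss_fn could pass config.z_loss_weight through, mask with targets_segmentation, and return the GA-dependent aux value. The names max_utils.cross_entropy_with_logits, config.z_loss_weight, and targets_segmentation come from this PR; the exact function signature, the import path, the gradient_accumulation_steps flag name, and the surrounding structure are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
from MaxText import max_utils  # import path assumed

def loss_fn(params, data, config, model_apply):
  logits = model_apply(params, data["inputs"])
  one_hot_targets = jax.nn.one_hot(data["targets"], logits.shape[-1])

  # Pass the configurable weight instead of the previously hardcoded 0.0.
  xent, z_loss = max_utils.cross_entropy_with_logits(
      logits, one_hot_targets, config.z_loss_weight)

  # Mask out padding/packing tokens, mirroring the existing cross-entropy mask.
  mask = data["targets_segmentation"] != 0
  xent = jnp.where(mask, xent, 0.0)
  z_loss = jnp.where(mask, z_loss, 0.0)

  total_weights = jnp.sum(mask)
  loss = jnp.sum(xent) / total_weights

  if config.gradient_accumulation_steps > 1:
    aux_z_loss = jnp.sum(z_loss)                   # raw sum, accumulated across GA steps
  else:
    aux_z_loss = jnp.sum(z_loss) / total_weights   # per-token average, logged directly

  return loss, {"total_weights": total_weights, "z_loss": aux_z_loss}
```

In the GA branch of train_step, the accumulated z_loss sum is then divided by the accumulated total_weights before being logged as learning/z_loss, so both branches report a per-token average.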

Fixes: #2352 ([feature request] Make z_loss factor configurable)

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.
