
FSDPRefWorkerBase loads in pure bf16 while policy/critic use mixed precision, causing KL NaN on long sequences #1694

@jamesbraza

Description

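In SkyRL v0.2.0's FSDP backend, the workers load their models at different precisions. The policy and critic workers hardcode bf16=false, while FSDPRefWorkerBase honors trainer.bf16.

This leads to a case where, by default: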

  • Policy and Critic: mixed precision (weights fp32, forward pass autocasts to bf16)
  • Ref: pure bf16 (weights bf16, forward pass autocasts to bf16)
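To see why bf16 storage hurts with depth, here's a stdlib-only toy. The helpers `to_fp32` and `to_bf16` are hypothetical and are not SkyRL code. `to_bf16` emulates bfloat16 by rounding a float32 bit pattern to its top 16 bits with round-to-nearest-even, so only 8 significant bits survive. In the toy, a value (think of one coordinate of a hidden state) accumulates 64 small per-layer updates. In fp32 it keeps them. In bf16, any update smaller than half a bf16 step at the current magnitude (2^-8 near 1.0) is rounded away every time. This is a caricature: real layers mix bf16 and fp32 ops, so the actual drift depends on the model and its inputs.

```python
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 (hypothetical helper).

    bf16 is the top 16 bits of a float32: 1 sign, 8 exponent, 7 mantissa bits.
    Adding 0x7FFF plus the lowest kept bit before truncating gives
    round-to-nearest-even.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

NUM_LAYERS = 64   # e.g. Qwen3-32B's depth
UPDATE = 0.003    # hypothetical per-layer contribution to one hidden value

def accumulate(round_fn) -> float:
    """Add UPDATE once per layer, rounding the running value with round_fn."""
    x = 1.0
    for _ in range(NUM_LAYERS):
        x = round_fn(x + UPDATE)
    return x

fp32_out = accumulate(to_fp32)  # about 1.0 + 64 * 0.003 = 1.192
bf16_out = accumulate(to_bf16)  # every update is rounded away, so it stays 1.0
print(f"fp32: {fp32_out:.6f}  bf16: {bf16_out:.6f}")
```

The bf16 path never moves because 1.003 is closer to 1.0 than to the next bf16 value, 1.0078125.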

The asymmetry causes no observable failure on small models (e.g. ≤14B) with short sequences (e.g. ≤16k tokens). On large dense models with long sequences, however, the ref's pure-bf16 attention overflows. One bad key/value position then poisons every later position in the sequence, because under causal attention each later position attends back to it.
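The "poisons every later position" effect can be shown in miniature with a toy single-head causal attention over scalar queries, keys, and values. `causal_attention` is a hypothetical plain-Python function, not SkyRL's attention kernel. Key 2 is set to +inf as a stand-in for a saturated score. The max-subtracted softmax then evaluates exp(inf - inf) = nan for every query allowed to see that key.

```python
import math

def causal_attention(q, k, v):
    """Toy single-head causal attention with a 1-dim head (one scalar per position)."""
    out = []
    for i in range(len(q)):
        # Causal mask: query i may only attend to keys 0..i.
        logits = [q[i] * k[j] for j in range(i + 1)]
        m = max(logits)  # subtract the max for a "stable" softmax
        weights = [math.exp(x - m) for x in logits]  # exp(inf - inf) = exp(nan) = nan
        z = sum(weights)
        out.append(sum(w * v[j] for j, w in enumerate(weights)) / z)
    return out

q = [1.0] * 5
k = [0.5, 0.25, math.inf, 0.1, 0.3]  # position 2 holds a saturated key
v = [1.0, 2.0, 3.0, 4.0, 5.0]

out = causal_attention(q, k, v)
print(out)
```

Positions 0 and 1 stay finite because the causal mask hides key 2 from them. Every position from 2 onward comes out nan, and any later layer reading those positions would carry the nan forward.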

The cause:

  1. The ref's weights are bf16 while the policy's are fp32, and across many layers (e.g. Qwen3-32B has 64) the bf16 rounding error compounds
    • This only matters with use_kl_loss=true and kl_loss_coef>0, which wire the reference model into the policy loss
  2. Eventually an attention dot product saturates to ±inf
  3. The inf turns log_probs_base into NaN
  4. The NaN propagates through the KL term, making the final loss NaN
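Steps 3 and 4 can be sketched with the stdlib as well. The sketch assumes a k1-style KL estimate (`log_probs - log_probs_base`, averaged over tokens) added to the policy-gradient loss. SkyRL's actual estimator and aggregation may differ, and `log_softmax` and `policy_loss` are hypothetical names. The point is that nan * kl_loss_coef is nan for any positive coefficient, however small. So the ref's NaN reaches the final loss whenever the KL term is wired in.

```python
import math

def log_softmax(logits):
    """Max-subtracted log-softmax over a list of floats."""
    m = max(logits)
    shifted = [x - m for x in logits]
    lse = math.log(sum(math.exp(s) for s in shifted))
    return [s - lse for s in shifted]

def policy_loss(pg_loss, log_probs, log_probs_base, use_kl_loss, kl_loss_coef):
    """Hypothetical loss: policy-gradient term plus an optional KL penalty."""
    if not (use_kl_loss and kl_loss_coef > 0):
        return pg_loss  # the ref model never touches the loss
    # k1 KL estimate per token (an assumption), averaged over tokens.
    kl = sum(lp - lb for lp, lb in zip(log_probs, log_probs_base)) / len(log_probs)
    return pg_loss + kl_loss_coef * kl

token = 1  # index of the sampled token in this toy 3-token vocabulary
policy_logits = [0.2, 1.5, -0.3]
ref_logits = [0.2, float("nan"), -0.3]  # NaN arriving from a poisoned hidden state

log_probs = [log_softmax(policy_logits)[token]]
log_probs_base = [log_softmax(ref_logits)[token]]

clean = policy_loss(0.7, log_probs, log_probs_base, use_kl_loss=False, kl_loss_coef=0.001)
poisoned = policy_loss(0.7, log_probs, log_probs_base, use_kl_loss=True, kl_loss_coef=0.001)
print(clean, poisoned)  # clean stays 0.7; poisoned is nan
```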

Workarounds

Make the ref worker match the bf16=false that is hardcoded for the policy and critic:

  • Configure trainer.bf16=false on the launch CLI
  • Patch HFModelWrapper.__init__ globally to force bf16=False
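Here is a sketch of the second workaround. It assumes only what this issue states: `HFModelWrapper.__init__` accepts a parameter named `bf16`. `force_bf16_false` is a hypothetical helper. The demo applies it to a stand-in class (`FakeModelWrapper`, with a made-up `pretrain` parameter) because SkyRL's import path and full constructor signature aren't shown here. Binding through `inspect.signature` overrides `bf16` whether callers pass it positionally or by keyword. If the workers run in separate processes (e.g. Ray actors), the patch has to execute inside each worker process, not just the driver.

```python
import functools
import inspect

def force_bf16_false(cls):
    """Patch cls.__init__ in place so every construction runs with bf16=False."""
    original_init = cls.__init__
    sig = inspect.signature(original_init)
    if "bf16" not in sig.parameters:
        raise TypeError(f"{cls.__name__}.__init__ has no 'bf16' parameter")

    @functools.wraps(original_init)
    def patched_init(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.arguments["bf16"] = False  # override positional, keyword, or default bf16
        original_init(*bound.args, **bound.kwargs)

    cls.__init__ = patched_init
    return cls

# Hypothetical stand-in for HFModelWrapper; the real constructor takes more arguments.
class FakeModelWrapper:
    def __init__(self, pretrain, bf16=True):
        self.pretrain = pretrain
        self.bf16 = bf16

force_bf16_false(FakeModelWrapper)

print(FakeModelWrapper("Qwen/Qwen3-32B").bf16)             # default overridden
print(FakeModelWrapper("Qwen/Qwen3-32B", True).bf16)       # positional overridden
print(FakeModelWrapper("Qwen/Qwen3-32B", bf16=True).bf16)  # keyword overridden
```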
