[BUG] FSDP engine: torch.optim.AdamW inherits bf16 dtype from model.parameters(), causing late-stage convergence to plateau ~3× higher than DS-Z3 / Megatron #1292

@yulangz

Summary

When using the FSDP backend with actor.dtype: bfloat16 and the default actor.optimizer.type: adam, AdamW's optimizer state (exp_avg, exp_avg_sq) is silently created in bf16 instead of fp32, because torch.optim.AdamW(self.model.parameters(), ...) inherits the parameter dtype. This causes the late-stage loss to plateau noticeably higher than DeepSpeed ZeRO-3 (LLaMA-Factory) or AReaL Megatron, both of which maintain fp32 master weights and fp32 m/v by default.

Reproduction

8 × L20X (143 GB), single node, identical recipe across three frameworks:

  • model: Qwen/Qwen3-8B, bf16, gradient_checkpointing, no rope_scaling
  • data: 3520 SWE-bench-style trajectories, 32K cutoff, sorted by instance_id, no shuffling
  • training: 4 epochs, gbs=64 (per_dev=1, gas=8, world=8), lr=1e-4 cosine, warmup_ratio=0.1, wd=0.01, clip=1.0, bf16
  • 220 optimizer steps, identical lr schedule across frameworks (verified to <1e-9)
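
The step count follows from the recipe above. A quick arithmetic check, assuming every sample is used once per epoch and that warmup steps are obtained by rounding `warmup_ratio * total_steps` (the rounding rule is an assumption; frameworks differ):

```python
# Values taken from the recipe above.
num_samples = 3520
per_device_batch = 1
grad_accum_steps = 8
world_size = 8
epochs = 4
warmup_ratio = 0.1

global_batch = per_device_batch * grad_accum_steps * world_size
steps_per_epoch = num_samples // global_batch
total_steps = steps_per_epoch * epochs
# Assumption: warmup is rounded to the nearest whole step.
warmup_steps = round(warmup_ratio * total_steps)

print(global_batch, steps_per_epoch, total_steps, warmup_steps)
```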

Final loss after 220 steps:

| Framework | Optimizer state dtype | Final loss |
|---|---|---|
| LLaMA-Factory + DeepSpeed ZeRO-3 | fp32 (DS default) | 0.0464 |
| AReaL Megatron d8t1p1c1 | fp32 (DistributedOptimizer default) | 0.0480 |
| AReaL Megatron d4t1p1c2 (CP=2) | fp32 | 0.0467 |
| AReaL Megatron d4t2p1c1 (TP=2) | fp32 | 0.0466 |
| AReaL Megatron d4t1p2c1 (PP=2) | fp32 | 0.0470 |
| AReaL FSDP d8p1t1 | bf16 | 0.1309 (~3× higher) |

LF and Megatron variants all sit at 0.046–0.048; AReaL FSDP plateaus at 0.131. The gap opens around step 50, becomes visible by step 100 (LF: 0.152, FSDP: 0.177), and grows to 0.08+ absolute by step 200. grad_norm stays well-behaved (~0.2–0.3), so the run is not diverging; it's a plateau caused by noisy Adam updates as √v̂ loses precision in bf16 when v̂ is small.
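
A toy illustration of one way a bf16 second-moment buffer can lose information: with beta2 = 0.999, each EMA step changes `exp_avg_sq` by roughly 0.1% or less, which is below half a bf16 ulp, so the stored value can round straight back to where it was. The sketch below emulates bf16 rounding in pure Python (the `to_bf16` helper is hypothetical, not a torch API) and mimics the two-step unfused update `v.mul_(beta2)` then `v.add_((1 - beta2) * g²)`. It is not a measurement from the run above.

```python
import struct

def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

beta2 = 0.999
v0 = to_bf16(1e-6)   # small second-moment estimate, stored in bf16
g2 = 0.5 * v0        # squared gradients have shrunk to half of v0

v_bf16 = v0          # bf16 state, rounded after every operation
v_ref = v0           # full-precision reference
for _ in range(5000):
    v_bf16 = to_bf16(v_bf16 * beta2)
    v_bf16 = to_bf16(v_bf16 + (1 - beta2) * g2)
    v_ref = v_ref * beta2 + (1 - beta2) * g2

print("bf16 state / v0:", v_bf16 / v0)
print("reference  / v0:", v_ref / v0)
```

In this toy setting the bf16 state never moves, while the reference decays toward `g2`. Adam then divides by a stale √v̂, which mis-scales the effective step size.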

Root cause

areal/engine/fsdp_engine.py around the AdamW construction:

self.optimizer = torch.optim.AdamW(
    self.model.parameters(),  # bf16 params after model.to(bfloat16)
    lr=lr,
    weight_decay=weight_decay,
    ...
    fused=not (self.is_vision_model and self.parallel_helper.tp_enabled),
)

Verification:

import torch
m = torch.nn.Linear(8, 8).bfloat16()
opt = torch.optim.AdamW(m.parameters(), lr=1e-4, fused=False)
m(torch.randn(2, 8, dtype=torch.bfloat16)).sum().backward()
opt.step()
print(opt.state[m.weight]['exp_avg'].dtype)     # torch.bfloat16
print(opt.state[m.weight]['exp_avg_sq'].dtype)  # torch.bfloat16

The FSDP2 MixedPrecisionPolicy(param_dtype=bf16, reduce_dtype=fp32) set up in parallelize_model only governs forward/backward compute and gradient reduction; it does not maintain fp32 master weights, so the optimizer keeps everything in bf16.

DeepSpeed ZeRO-3 and Megatron Core's DistributedOptimizer both default to maintaining fp32 master weights + fp32 m/v even when the model is bf16; this is the standard mixed-precision recipe and the reason their final loss matches.

Suggested fixes

Two options, in increasing order of intrusiveness:

Option A: switch the default to adam_bf16 for bf16 models

AnyPrecisionAdamW (already in areal/engine/fsdp_utils/optimizer.py) keeps m/v in bf16 but uses Kahan summation for the parameter update, recovering fp32-equivalent training stability. Per the AnyPrecision paper this is essentially indistinguishable from fp32 Adam in convergence.

Currently AReaL warns "may be less stable" for adam_bf16, but for bf16 dense/MoE models it appears to be the more correct choice than the current bf16 m/v + non-Kahan AdamW.
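
A minimal sketch of why Kahan compensation helps a bf16 parameter update, using pure-Python bf16 emulation. The `to_bf16` helper and the constant update of -1e-4 are illustrative assumptions; this shows the general compensated-summation idea, not the exact code in areal/engine/fsdp_utils/optimizer.py.

```python
import struct

def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

update = -1e-4   # per-step parameter delta, far below half a bf16 ulp at 1.0
steps = 1000     # exact result would be 1.0 + 1000 * (-1e-4) = 0.9

p_naive = 1.0    # bf16 parameter, plain update
p_kahan = 1.0    # bf16 parameter, compensated update
comp = 0.0       # bf16 compensation buffer holding the not-yet-applied delta

for _ in range(steps):
    # Plain update: the delta is rounded away on every step.
    p_naive = to_bf16(p_naive + update)

    # Compensated update: accumulate the delta, apply what fits,
    # and keep the remainder for later steps.
    comp = to_bf16(comp + update)
    old = p_kahan
    p_kahan = to_bf16(p_kahan + comp)
    comp = to_bf16(comp + (old - p_kahan))

print("naive:", p_naive)
print("kahan:", p_kahan, "pending:", comp)
```

The plain bf16 parameter stays at 1.0 because each delta is smaller than the rounding granularity, whereas the compensated one tracks the exact trajectory to within bf16 resolution.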

Option B: maintain fp32 master weights via FSDP2 mixed precision

Load the model in fp32 and let MixedPrecisionPolicy(param_dtype=bf16) cast forward/backward to bf16, while the underlying parameter storage (and therefore the model.parameters() seen by AdamW) stays fp32. This is the standard FSDP2 mixed-precision recipe and should match DS-Z3 / Megatron numerics up to reduction ordering.
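
A single-process sketch of the key property Option B relies on: when parameter storage is fp32, AdamW creates fp32 state even though the forward pass runs in bf16. Here `torch.autocast` on CPU is only a stand-in for FSDP2's MixedPrecisionPolicy, chosen so the example runs without a distributed setup; it is not the proposed implementation.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 8)  # parameters are created and kept in fp32
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, fused=False)

# Stand-in for MixedPrecisionPolicy(param_dtype=bf16): compute in bf16,
# leave the stored parameters untouched.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(torch.randn(2, 8))

out.float().sum().backward()
opt.step()

state = opt.state[model.weight]
print(model.weight.dtype, state["exp_avg"].dtype, state["exp_avg_sq"].dtype)
```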

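Memory cost: +16 GB of parameter storage for an 8B model (2 extra bytes per parameter, summed over all ranks before sharding). With fp32 parameters, AdamW's m/v also become fp32, which adds further memory on top. On L20X 143 GB this is well within budget.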
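
A back-of-the-envelope check of that cost, assuming exactly 8e9 parameters, 1 GB = 1e9 bytes, and even sharding across the 8 ranks (all three are simplifying assumptions). It also counts the two AdamW state tensors, which move from bf16 to fp32 along with the parameters.

```python
N_PARAMS = 8e9          # assumption: round 8B parameter count
WORLD_SIZE = 8          # ranks sharing the sharded storage
BYTES_BF16, BYTES_FP32 = 2, 4
GB = 1e9

def extra_gb(tensors_per_param: int) -> float:
    """Extra memory (GB, whole job) when the given tensors go from bf16 to fp32."""
    return N_PARAMS * tensors_per_param * (BYTES_FP32 - BYTES_BF16) / GB

param_extra = extra_gb(1)   # parameter storage
state_extra = extra_gb(2)   # exp_avg and exp_avg_sq

print("params: +%.0f GB total, +%.0f GB per rank" % (param_extra, param_extra / WORLD_SIZE))
print("m/v:    +%.0f GB total, +%.0f GB per rank" % (state_extra, state_extra / WORLD_SIZE))
```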

A docs/warning addition pointing users to adam_bf16 would be a low-cost first step.

Environment

  • AReaL: wht/swe-dev branch, latest commit 52aa0cebb fix(megatron): pass absolute total_train_steps as lr_decay_steps
  • transformers: 4.55 / 4.57.1 (FSDP path)
  • torch: 2.x with FSDP2
  • 8 × NVIDIA L20X (143 GB), single node

Related

While debugging this I also found that areal/engine/megatron_engine.py was passing lr_decay_steps = total_train_steps - warmup_steps to OptimizerParamScheduler, which Megatron Core then internally subtracts warmup from again, causing cosine to reach min_lr ~10% earlier than intended. Fixed in the same branch.
