Summary
When using the FSDP backend with actor.dtype: bfloat16 and the default actor.optimizer.type: adam, AdamW's optimizer state (exp_avg, exp_avg_sq) is silently created in bf16 instead of fp32, because torch.optim.AdamW(self.model.parameters(), ...) inherits the parameter dtype. This causes the late-stage loss to plateau noticeably higher than DeepSpeed ZeRO-3 (LLaMA-Factory) or AReaL Megatron, both of which maintain fp32 master weights and fp32 m/v by default.
Reproduction
8 × L20X (143 GB), single node, identical recipe across three frameworks:
- model: Qwen/Qwen3-8B, bf16, gradient_checkpointing, no rope_scaling
- data: 3520 SWE-bench-style trajectories, 32K cutoff, sorted by instance_id, no shuffling
- training: 4 epochs, gbs=64 (per_dev=1, gas=8, world=8), lr=1e-4 cosine, warmup_ratio=0.1, wd=0.01, clip=1.0, bf16
- 220 optimizer steps, identical lr schedule across frameworks (verified to <1e-9)
Final loss after 220 steps:
| Framework | Optimizer state dtype | Final loss |
|---|---|---|
| LLaMA-Factory + DeepSpeed ZeRO-3 | fp32 (DS default) | 0.0464 |
| AReaL Megatron d8t1p1c1 | fp32 (DistributedOptimizer default) | 0.0480 |
| AReaL Megatron d4t1p1c2 (CP=2) | fp32 | 0.0467 |
| AReaL Megatron d4t2p1c1 (TP=2) | fp32 | 0.0466 |
| AReaL Megatron d4t1p2c1 (PP=2) | fp32 | 0.0470 |
| AReaL FSDP d8p1t1 | bf16 ❗ | 0.1309 ← ~3× higher |
LF and all Megatron variants finish at 0.046–0.048; AReaL FSDP plateaus at 0.131. The gap opens around step 50, is visible by step 100 (LF: 0.152, FSDP: 0.177), and grows to more than 0.08 absolute by step 200. grad_norm stays well-behaved (~0.2–0.3), so this is not a training divergence but a plateau, consistent with noisy Adam updates under bf16 rounding: bf16 keeps only 8 significant bits, so small increments to v̂ (and hence √v̂) and small lr-scaled updates to the bf16 weights are partly or entirely rounded away.
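A minimal sketch of the rounding effect, assuming a single weight initialized to 1.0, the recipe's lr=1e-4, and weight decay disabled so only the Adam step is visible. On Adam's first step the update is about lr in magnitude, while just below 1.0 bf16 can only represent values about 0.004 apart, so the bf16 weight does not move at all and the fp32 one does. bf16 rounding is relative, so whether an update survives depends on its size relative to the weight; as the cosine schedule shrinks lr, more updates fall below the threshold, which would be consistent with the gap widening late in training.

```python
import torch

def one_adamw_step(dtype, lr=1e-4):
    # A single weight initialized to 1.0 and stored in `dtype`.
    p = torch.nn.Parameter(torch.ones(1, dtype=dtype))
    # weight_decay=0 so only the Adam step itself changes the weight.
    opt = torch.optim.AdamW([p], lr=lr, weight_decay=0.0, foreach=False)
    # Any nonzero gradient gives a first-step update of about lr in magnitude.
    p.grad = torch.full_like(p, 0.5)
    opt.step()
    return p.detach().float().item(), opt.state[p]["exp_avg_sq"].dtype

print(one_adamw_step(torch.float32))   # weight moves to ~0.9999; state is float32
print(one_adamw_step(torch.bfloat16))  # (1.0, torch.bfloat16): update rounded away
```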
Root cause
areal/engine/fsdp_engine.py, around the AdamW construction:

```python
self.optimizer = torch.optim.AdamW(
    self.model.parameters(),  # bf16 params after model.to(bfloat16)
    lr=lr,
    weight_decay=weight_decay,
    ...
    fused=not (self.is_vision_model and self.parallel_helper.tp_enabled),
)
```
Verification:

```python
import torch

m = torch.nn.Linear(8, 8).bfloat16()
opt = torch.optim.AdamW(m.parameters(), lr=1e-4, fused=False)
m(torch.randn(2, 8, dtype=torch.bfloat16)).sum().backward()
opt.step()
print(opt.state[m.weight]['exp_avg'].dtype)     # torch.bfloat16
print(opt.state[m.weight]['exp_avg_sq'].dtype)  # torch.bfloat16
```
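The FSDP2 MixedPrecisionPolicy(param_dtype=bf16, reduce_dtype=fp32) set up in parallelize_model only governs forward/backward compute and gradient reduction: it casts parameters to param_dtype for all-gather and compute, but the sharded parameters that the optimizer updates keep the dtype the model was loaded in. Because the model is already bf16, there is no fp32 master copy, and the optimizer keeps everything in bf16.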
DeepSpeed ZeRO-3 and Megatron Core's DistributedOptimizer both default to maintaining fp32 master weights + fp32 m/v even when the model is bf16; this is the standard mixed-precision recipe and the reason their final loss matches.
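For comparison, a minimal single-process sketch of the fp32-master-weights pattern: compute runs on a bf16 copy of the model, while AdamW only ever sees fp32 tensors, so m, v and the update itself stay fp32. DeepSpeed and Megatron add sharding and fused kernels on top of this; the toy model, loss, and the names `master` and `train_step` are illustrative, not taken from either framework.

```python
import torch

torch.manual_seed(0)
# bf16 "compute" model, as produced by model.to(torch.bfloat16).
model = torch.nn.Linear(8, 8).bfloat16()
# fp32 master copy of every parameter; the optimizer only sees these.
master = [p.detach().float().requires_grad_(True) for p in model.parameters()]
opt = torch.optim.AdamW(master, lr=1e-4, weight_decay=0.01)

def train_step(x):
    model.zero_grad(set_to_none=True)
    model(x).float().pow(2).mean().backward()   # forward/backward in bf16
    for mp, p in zip(master, model.parameters()):
        mp.grad = p.grad.float()                # upcast grads to fp32
    opt.step()                                  # m, v and the update in fp32
    opt.zero_grad(set_to_none=True)
    with torch.no_grad():
        for mp, p in zip(master, model.parameters()):
            p.copy_(mp)                         # cast fp32 master back to bf16

train_step(torch.randn(4, 8, dtype=torch.bfloat16))
print(opt.state[master[0]]["exp_avg_sq"].dtype)  # torch.float32
```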
Suggested fixes
Two options, in increasing order of intrusiveness:
Option A: switch the default to adam_bf16 for bf16 models
AnyPrecisionAdamW (already in areal/engine/fsdp_utils/optimizer.py) keeps m/v in bf16 but uses Kahan summation for the parameter update, recovering fp32-equivalent training stability. Per the AnyPrecision paper this is essentially indistinguishable from fp32 Adam in convergence.
Currently AReaL warns "may be less stable" for adam_bf16, but for bf16 dense/MoE models it appears to be the more correct choice than the current bf16 m/v + non-Kahan AdamW.
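A sketch of the Kahan-compensated update idea behind Option A, shown in isolation rather than inside an optimizer. This is not AReaL's or AnyPrecisionAdamW's code; `naive_update` and `kahan_update` are hypothetical names. A bf16 compensation buffer carries the part of each update that the bf16 weight could not absorb, so 100 updates of -1e-4 still move the weight by about 0.01 instead of being rounded away.

```python
import torch

def naive_update(p, delta):
    # Plain bf16 update: p + delta is rounded back to bf16 every step.
    p.add_(delta)

def kahan_update(p, comp, delta):
    # Compensated update: comp holds the part of earlier updates
    # that bf16 could not represent in p.
    comp.add_(delta)
    before = p.clone()
    p.add_(comp)            # apply as much as bf16 can represent
    comp.add_(before - p)   # subtract what was actually applied

p_naive = torch.ones(1, dtype=torch.bfloat16)
p_kahan = torch.ones(1, dtype=torch.bfloat16)
comp = torch.zeros(1, dtype=torch.bfloat16)
for _ in range(100):        # 100 updates of -1e-4, i.e. -0.01 in total
    naive_update(p_naive, -1e-4)
    kahan_update(p_kahan, comp, -1e-4)

print(p_naive.item())  # 1.0: every update was rounded away
print(p_kahan.item())  # close to 0.99, within one bf16 step
```

The buffer can stay in bf16 because it only holds a small residual, where bf16's relative precision is sufficient; the cost is one extra bf16 tensor per parameter.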
Option B: maintain fp32 master weights via FSDP2 mixed precision
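Load the model in fp32 and let MixedPrecisionPolicy(param_dtype=bf16) cast forward/backward to bf16, while the underlying sharded parameter storage (and therefore model.parameters() as seen by AdamW) stays fp32. This is the standard FSDP2 mixed-precision recipe and should match DS-Z3 / Megatron behavior closely (up to all-reduce ordering and kernel differences).
Memory cost: +16 GB of parameter storage on an 8B model (4 instead of 2 bytes per parameter), plus +32 GB because AdamW's m/v follow the parameters to fp32, i.e. roughly +48 GB in total, or ~6 GB per GPU with 8-way sharding. On L20X 143 GB this is well within budget.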
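The memory arithmetic for Option B as a rough sketch. It assumes 8e9 parameters and decimal GB, and ignores gradients, activations and allocator overhead; `adamw_state_gb` is a hypothetical helper.

```python
# Back-of-the-envelope memory for parameters plus AdamW m/v (decimal GB).
def adamw_state_gb(n_params, param_bytes, state_bytes):
    params = n_params * param_bytes / 1e9
    m_and_v = 2 * n_params * state_bytes / 1e9   # exp_avg + exp_avg_sq
    return params, m_and_v

n = 8e9
bf16_params, bf16_mv = adamw_state_gb(n, 2, 2)   # current FSDP path: all bf16
fp32_params, fp32_mv = adamw_state_gb(n, 4, 4)   # Option B: fp32 storage and m/v
print(fp32_params - bf16_params)                 # 16.0 GB extra for parameters
print(fp32_mv - bf16_mv)                         # 32.0 GB extra for m and v
print((fp32_params + fp32_mv - bf16_params - bf16_mv) / 8)  # 6.0 GB per GPU, 8-way sharded
```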
A docs/warning addition pointing users to adam_bf16 would be a low-cost first step.
Environment
- AReaL: wht/swe-dev branch, latest commit 52aa0cebb fix(megatron): pass absolute total_train_steps as lr_decay_steps
- transformers: 4.55 / 4.57.1 (FSDP path)
- torch: 2.x with FSDP2
- 8 × NVIDIA L20X (143 GB), single node
Related
While debugging this I also found that areal/engine/megatron_engine.py was passing lr_decay_steps = total_train_steps - warmup_steps to OptimizerParamScheduler, which Megatron Core then internally subtracts warmup from again, causing cosine to reach min_lr ~10% earlier than intended. Fixed in the same branch.
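A simplified sketch (hypothetical function names, not Megatron Core's code) of a warmup + cosine schedule whose decay phase spans lr_decay_steps - warmup_steps, which is the behavior described above for OptimizerParamScheduler. With this recipe's 220 total and 22 warmup steps, passing the already-subtracted value makes lr reach min_lr at step 198 instead of 220, i.e. 22 steps (10%) early.

```python
import math

def cosine_lr(step, max_lr, min_lr, warmup_steps, lr_decay_steps):
    if step <= warmup_steps:
        return max_lr * step / warmup_steps          # linear warmup
    if step > lr_decay_steps:
        return min_lr                                # schedule exhausted
    # The decay phase spans (lr_decay_steps - warmup_steps) steps.
    decay_ratio = (step - warmup_steps) / (lr_decay_steps - warmup_steps)
    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    return min_lr + coeff * (max_lr - min_lr)

def first_step_at_min_lr(lr_decay_steps, total=220, warmup=22, max_lr=1e-4, min_lr=0.0):
    return next(s for s in range(warmup + 1, total + 1)
                if cosine_lr(s, max_lr, min_lr, warmup, lr_decay_steps) <= min_lr)

print(first_step_at_min_lr(220))       # 220: absolute total steps (the fix)
print(first_step_at_min_lr(220 - 22))  # 198: warmup subtracted twice
```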