[megatron] fix: dynamic context parallel batch splitting and loss normalization #5869
Kite0011 wants to merge 2 commits into verl-project:main
Conversation
Code Review
This pull request introduces a compatibility check for older Megatron versions and refactors the dynamic context parallel (DCP) batch splitting logic. It also modifies the forward-backward pass to accommodate DCP. However, the review identifies several critical correctness issues related to gradient scaling and loss normalization. Specifically, skipping the global token count synchronization and using inconsistent values for sequence and data parallel sizes (sp_size and dp_size) in DCP mode are likely to cause mathematical errors and gradient bias during distributed training.
```python
if self.engine_config.dynamic_context_parallel:
    pass
else:
    torch.distributed.all_reduce(
        batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()
    )
```
Skipping the all-reduce of batch_num_tokens when dynamic_context_parallel is enabled will cause incorrect gradient scaling. In dynamic-batch mode, different ranks usually process different numbers of tokens. If each rank normalizes only by its local token count and DDP then averages those losses, the gradients will be biased (you get the average of local averages, not the true global average loss). The rationale given in the PR description — that tokens are duplicated within a CP subgroup — does not match standard Megatron Context Parallel behavior: CP normally splits the sequence within the subgroup rather than replicating it. Even if sequences were replicated, inconsistent normalization denominators would still produce mathematical errors in distributed training.
References
- In distributed training, loss normalization must use the global token count; otherwise, unbalanced token counts across ranks bias the gradients.
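The bias the reviewer describes can be seen with two lines of arithmetic. A minimal sketch with illustrative numbers (two ranks with unbalanced token counts; all values are hypothetical):

```python
# Hypothetical per-rank values: sum of token losses and local token counts.
rank_loss_sums = [8.0, 3.0]   # rank 0, rank 1
rank_num_tokens = [4, 1]      # unbalanced dynamic batches

# Each rank normalizes by its LOCAL token count, then DDP averages the losses:
local_means = [s / n for s, n in zip(rank_loss_sums, rank_num_tokens)]
biased = sum(local_means) / len(local_means)   # average of local averages: 2.5

# Normalizing by the GLOBAL token count (what the all-reduce provides):
global_tokens = sum(rank_num_tokens)           # 5
unbiased = sum(s / global_tokens for s in rank_loss_sums)   # true mean: 11/5 = 2.2
```

The two results differ whenever the per-rank token counts differ, which is exactly the dynamic-batch case.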
In the DCP scenario, every rank receives the full batch from the start, so the all-reduce here can be skipped.
Key design: In DCP mode, context_parallel_size is set to 1 in the Megatron config. This means Megatron treats all ranks as a single DP group. All ranks receive the identical batch from the dataloader. The dynamic CP grouping happens at runtime inside dynamic_cp_split_batch.

Concern 1 (skip all-reduce for batch_num_tokens): since all ranks start with the same batch, loss_mask.sum() already equals the global token count T_total on every rank. An all-reduce would incorrectly produce W × T_total. Skipping it is correct.
```python
if self.engine_config.dynamic_context_parallel:
    tu.assign_non_tensor(data, sp_size=mpu.get_data_parallel_world_size())
else:
    tu.assign_non_tensor(data, sp_size=self.engine_config.context_parallel_size)
```
Setting sp_size to the entire world size (mpu.get_data_parallel_world_size()) in DCP mode is logically inconsistent. In Megatron, sp_size (or cp_size) normally denotes the sequence-parallel degree. Because DCP splits the sequence within a CP group, gradients are not duplicated across CP ranks, so for normalization purposes sp_size should effectively be 1, or be set to the actual dynamic local_cp_size. Setting it to the full world size will over-shrink the gradients if sp_size is used as a divisor in the loss computation to account for CP duplication.
References
- The product of the parallelism degrees (TP, PP, DP, CP) should equal the total number of GPUs. Setting multiple parallel dimensions to the world size simultaneously is a logical contradiction and yields incorrect scaling factors.
sp_size is not used in the loss formula (loss = -masked_sum / batch_num_tokens * dp_size). It's only consumed in
prepare_micro_batches() for computing max_token_len = max_token_len_per_gpu * sp_size, which determines micro-batch
splitting granularity. Setting it to W correctly represents the total token capacity across all ranks.
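Per the reply, sp_size only scales the micro-batch token budget. A sketch of that computation as described (the names `max_token_len_per_gpu` and `prepare_micro_batches` come from the discussion; the numbers are illustrative):

```python
max_token_len_per_gpu = 8192   # per-GPU token budget (illustrative)
sp_size = 8                    # W in DCP mode: all ranks contribute capacity

# The only consumer of sp_size, per the reply: micro-batch splitting granularity.
max_token_len = max_token_len_per_gpu * sp_size   # 65536 tokens per micro-batch group
```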
```python
# print(f"rank={torch.distributed.get_rank()}, local_cp_size={local_cp_size} max_seq_len={max_seq_len}")
# --- attach metadata used by downstream model forward and loss ---
tu.assign_non_tensor_data(batch, "local_cp_size", local_cp_size)
tu.assign_non_tensor(batch, dp_size=local_dp_size)
```
Overwriting dp_size with local_dp_size may cause incorrect gradient scaling. If local_cp_size ranks process the same sequence (duplicating its gradients), DDP's all-reduce (which averages over the full world size) already produces the average gradient of the local_dp_size independent shards. For example, with world size 8 and local_cp_size 4, local_dp_size is 2, and DDP computes (4·G1 + 4·G2) / 8 = (G1 + G2) / 2, which is exactly the correct average. If the loss is then additionally multiplied by dp_size (2), as the formula in the docstring suggests, the final gradient becomes the sum of the shard gradients rather than their average, effectively doubling the learning rate.
References
- Under hybrid parallelism, loss normalization must exactly compensate for DDP's averaging and CP's duplication/summation behavior.
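The arithmetic in this concern can be checked directly. A sketch with the reviewer's own numbers (world size 8, local_cp_size 4, local_dp_size 2; the gradient values are illustrative):

```python
W, C, D = 8, 4, 2          # world size, CP subgroup size, number of DP shards
G1, G2 = 0.6, 1.4          # per-shard gradients (illustrative)

# DDP averaging alone: C copies of each shard gradient, averaged over W ranks.
ddp_avg = (C * G1 + C * G2) / W       # = (G1 + G2) / 2 — already the shard mean

# An extra multiplication of the loss by dp_size = D would turn mean into sum:
with_extra_dp = D * ddp_avg           # = G1 + G2 — effective learning rate doubled
```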
This works together with the unchanged batch_num_tokens = T_total. Within a CP subgroup of size C, all ranks compute the same loss (outputs are all-gathered via postprocess_thd_engine). After DDP averaging across W = C × D ranks:

grad_avg = Σ_j(C × loss_j × D / T_total) / (C × D) = Σ_j(loss_j) / T_total ✓

The three changes form a consistent system — batch_num_tokens = T_total (no all-reduce needed), dp_size = D (compensates for the C-fold gradient duplication in DDP), and the math produces the correct global average loss.
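The claimed identity can be verified numerically. A sketch under the reply's assumptions (C ranks per subgroup duplicate the same loss, each rank's loss is scaled by D / T_total, DDP averages over W = C × D ranks; all values illustrative):

```python
C, D = 4, 2                  # CP subgroup size, number of DP shards
W = C * D                    # total ranks
T_total = 10                 # global token count (unchanged, no all-reduce)
shard_losses = [3.0, 7.0]    # one masked loss sum per DP shard

# Each of the C ranks in a subgroup contributes the identical scaled loss:
per_rank = [loss * D / T_total for loss in shard_losses for _ in range(C)]

grad_avg = sum(per_rank) / W           # DDP mean over all W ranks
expected = sum(shard_losses) / T_total # correct global average: 10/10 = 1.0
```

The C in the numerator (duplication) cancels against the C in DDP's W-way average, and the explicit factor D cancels the remaining 1/D, leaving the global average.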
What does this PR do?
Fixes the batch splitting logic and loss normalization of Dynamic Context Parallel (DCP) in the Megatron engine.
Related changes:
- fixes an issue where … did not take effect; switches to the correct power-of-2 CP size computation, supporting uneven sequence assignment
- batch_num_tokens should not be all-reduced within the DP group (because tokens are duplicated within a CP subgroup)
- fixes a crash on (… saved_objects support)
Checklist Before Starting
- Title format: `[{modules}] {type}: {description}` (This will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`, like `[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If the change is breaking, prepend `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`
Test
API and Usage Example
# Add code snippet or script demonstrating how to use this
Design & Code Changes
verl/utils/megatron_utils.py — dynamic_cp_split_batch:
verl/workers/engine/megatron/transformer_impl.py — forward_backward_batch:
verl/models/mcore/patch.py — apply_patch_megatron_recomputation_backward:
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- … the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR changes the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.