
[megatron] fix: dynamic context parallel batch splitting and loss normalization#5869

Open
Kite0011 wants to merge 2 commits into verl-project:main from Kite0011:fix_dcp

Conversation

Contributor

Kite0011 commented Apr 2, 2026

What does this PR do?

Fixes the batch splitting logic and the loss normalization of Dynamic Context Parallel (DCP) in the Megatron engine.

Changes:

  1. Rewrite dynamic_cp_split_batch: fix the early return when the number of sequences is smaller than dp_size, which left CP disabled; use the correct power-of-2 CP size computation and support uneven sequence assignment
  2. Fix DCP loss normalization in forward_backward_batch: in DCP mode, sp_size should be the full DP world size, and batch_num_tokens should not be all-reduced within the DP group (tokens inside a CP subgroup are duplicated)
  3. Megatron recomputation patch compatibility: add a _recover_function_args check to avoid crashing on older Megatron versions (without saved_objects support)
  • Configuration notes:
    • When dynamic_context_parallel is enabled, cp_size must be set to 1
    • When dynamic_context_parallel is enabled, it is recommended to set max_seqlen_per_dp_cp_rank equal to max_token_len_per_gpu
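The power-of-2 CP size computation described in item 1 can be sketched as follows. This is a minimal illustration, not the actual verl code; the function name and the standalone-argument form are hypothetical, only the config knobs (max_seqlen_per_dp_cp_rank) are from the PR description.

```python
import math

def compute_local_cp_size(max_seq_len: int, max_seqlen_per_dp_cp_rank: int, world_size: int) -> int:
    """Hypothetical sketch: how many CP ranks the longest sequence needs."""
    # Minimum number of ranks so that each shard fits the per-rank budget.
    needed = math.ceil(max_seq_len / max_seqlen_per_dp_cp_rank)
    # Round up to the next power of two so CP subgroups tile the DP group evenly.
    cp_size = 1 << (needed - 1).bit_length()
    # The CP size can never exceed the number of available ranks.
    return min(cp_size, world_size)

print(compute_local_cp_size(10_000, 4_096, 8))  # needs 3 ranks -> rounded up to 4
```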

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

verl/utils/megatron_utils.py — dynamic_cp_split_batch:

  • local_cp_size computation: ceil(max_seq_len / max_seqlen_per_dp_cp_rank), then rounded up to a power of 2
  • New: when there are too few sequences to cover all DP subgroups, automatically grow local_cp_size so that every subgroup receives at least one sequence
  • Data splitting: distribute sequences evenly across local_dp_rank (handling the remainder), replacing the previous ceil-based allocation (which could index out of bounds)
  • Write dp_size=local_dp_size for downstream loss normalization
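The even split with remainder handling can be sketched roughly as below; the function and variable names are illustrative, not the actual implementation in megatron_utils.py.

```python
def split_sequences_evenly(num_seqs: int, local_dp_size: int, local_dp_rank: int) -> range:
    """Illustrative sketch: each DP subgroup gets floor(n/d) sequences, with the
    remainder handed out one-by-one to the lowest ranks, so indices never go out
    of bounds (unlike a naive ceil-based chunking)."""
    base, rem = divmod(num_seqs, local_dp_size)
    # Ranks below `rem` get one extra sequence.
    start = local_dp_rank * base + min(local_dp_rank, rem)
    count = base + (1 if local_dp_rank < rem else 0)
    return range(start, start + count)

# 10 sequences across 4 subgroups -> sizes 3, 3, 2, 2 with no overlap or gap
print([list(split_sequences_evenly(10, 4, r)) for r in range(4)])
```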

verl/workers/engine/megatron/transformer_impl.py — forward_backward_batch:

  • In DCP mode, sp_size is set to mpu.get_data_parallel_world_size() (rather than context_parallel_size)
  • In DCP mode, the DP all-reduce of batch_num_tokens is skipped
  • dp_size uniformly uses mpu.get_data_parallel_world_size()

verl/models/mcore/patch.py — apply_patch_megatron_recomputation_backward:

  • Add a hasattr(rd.CheckpointFunction, '_recover_function_args') guard for compatibility with older Megatron versions
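A minimal sketch of the version guard. The stand-in class and the boolean return are hypothetical; the real code checks rd.CheckpointFunction from Megatron's checkpointing module and the patch body is elided here.

```python
# Stand-in for Megatron's CheckpointFunction; the real patch would check
# rd.CheckpointFunction (import path assumed, see the bullet above).
class CheckpointFunction:  # hypothetical stand-in
    pass

def apply_patch_megatron_recomputation_backward(cls) -> bool:
    """Return True if the patch was applied, False if skipped for old Megatron."""
    if not hasattr(cls, "_recover_function_args"):
        # Older Megatron without saved_objects support: patching would crash.
        return False
    ...  # install the patched recomputation backward here
    return True

print(apply_patch_megatron_recomputation_backward(CheckpointFunction))  # False on old Megatron
```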

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Contributor Author

Kite0011 commented Apr 2, 2026

@ISEEKYAN


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a compatibility check for older Megatron versions and refactors the dynamic context parallel (DCP) batch splitting logic. It also modifies the forward-backward pass to accommodate DCP. However, the review identifies several critical correctness issues related to gradient scaling and loss normalization. Specifically, skipping the global token count synchronization and using inconsistent values for sequence and data parallel sizes (sp_size and dp_size) in DCP mode are likely to cause mathematical errors and gradient bias during distributed training.

Comment on lines +591 to +596

```python
if self.engine_config.dynamic_context_parallel:
    pass
else:
    torch.distributed.all_reduce(
        batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()
    )
```

high

Skipping the all_reduce of batch_num_tokens when dynamic_context_parallel is enabled can cause incorrect gradient scaling. Under dynamic batching, different ranks usually process different token counts. If each rank normalizes only by its local token count and DDP then averages those losses, the gradients become biased (you get an average of local averages rather than the true global average loss). The "tokens are duplicated within a CP subgroup" rationale in the PR description does not match standard Megatron Context Parallel behavior: CP normally shards the sequence within the subgroup rather than replicating it. Even if sequences were replicated, inconsistent normalization denominators would still be mathematically wrong in distributed training.

References
  1. In distributed training, loss normalization must use the global token count; otherwise unbalanced token counts across ranks bias the gradients.

Contributor Author


In the DCP scenario, every rank initially receives the full batch, so the allreduce here can be skipped.

Contributor Author


Key design: In DCP mode, context_parallel_size is set to 1 in the Megatron config. This means Megatron treats all
ranks as a single DP group. All ranks receive the identical batch from the dataloader. The dynamic CP grouping happens
at runtime inside dynamic_cp_split_batch.

> Concern 1 (skip all-reduce for batch_num_tokens):
> Since all ranks start with the same batch, loss_mask.sum() already equals the global token count T_total on every rank. An all-reduce would incorrectly produce W × T_total. Skipping it is correct.
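The author's claim can be checked with toy numbers. This is a standalone arithmetic sketch (W and T_total are made-up values), not verl code: when all W ranks hold the identical batch, each local loss_mask.sum() already equals the global count, and a SUM all-reduce would inflate it W-fold.

```python
# Toy check: every rank holds the same batch, so the local token count
# already equals the global one, and an all-reduce would over-count.
W = 8                                  # total ranks (hypothetical)
T_total = 1234                         # tokens in the shared batch (hypothetical)
local_counts = [T_total] * W           # each rank's loss_mask.sum()
all_reduced = sum(local_counts)        # what a SUM all-reduce would produce
print(local_counts[0], all_reduced)    # the all-reduce yields W * T_total
```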

Comment on lines +584 to +587

```python
if self.engine_config.dynamic_context_parallel:
    tu.assign_non_tensor(data, sp_size=mpu.get_data_parallel_world_size())
else:
    tu.assign_non_tensor(data, sp_size=self.engine_config.context_parallel_size)
```

high

Setting sp_size to the full world size (mpu.get_data_parallel_world_size()) in DCP mode is logically inconsistent. In Megatron, sp_size (or cp_size) normally represents the sequence-parallel degree. Since DCP shards the sequence within a CP group, gradients are not duplicated across CP ranks, so for normalization purposes sp_size should effectively be 1, or be set to the actual dynamic local_cp_size. Setting it to the whole world size would shrink the gradients too much (if sp_size is used as a divisor in the loss computation to account for CP duplication).

References
  1. The product of the parallelism degrees (TP, PP, DP, CP) should equal the total number of GPUs. Setting several parallel dimensions to the world size at once is logically contradictory and produces incorrect scaling factors.

Contributor Author


sp_size is not used in the loss formula (loss = -masked_sum / batch_num_tokens * dp_size). It's only consumed in
prepare_micro_batches() for computing max_token_len = max_token_len_per_gpu * sp_size, which determines micro-batch
splitting granularity. Setting it to W correctly represents the total token capacity across all ranks.
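The role the reply assigns to sp_size can be illustrated with made-up values (this is a toy sketch of the stated relationship, not the prepare_micro_batches code): sp_size only scales the micro-batch token budget, it never enters the loss.

```python
# sp_size only scales the micro-batch splitting budget (per the author's reply).
max_token_len_per_gpu = 8192   # hypothetical value
sp_size = 8                    # full DP world size W in DCP mode (hypothetical)
# A group of sp_size ranks collectively holds this many tokens, so micro-batches
# are packed against the aggregate capacity rather than a single GPU's budget.
max_token_len = max_token_len_per_gpu * sp_size
print(max_token_len)  # 65536
```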

# print(f"rank={torch.distributed.get_rank()}, local_cp_size={local_cp_size} max_seq_len={max_seq_len}")
# --- attach metadata used by downstream model forward and loss ---
tu.assign_non_tensor_data(batch, "local_cp_size", local_cp_size)
tu.assign_non_tensor(batch, dp_size=local_dp_size)

high

Overriding dp_size with local_dp_size may cause incorrect gradient scaling. If local_cp_size ranks process the same sequence (duplicating gradients), DDP's all-reduce (which averages over the full world size) already yields the average gradient over the local_dp_size independent shards. For example, with world size 8 and local_cp_size 4, local_dp_size is 2. DDP computes (4*G1 + 4*G2) / 8 = (G1 + G2) / 2, which is exactly the correct average. If, per the formula mentioned in the docstring, the loss is additionally multiplied by dp_size (2), the final gradient becomes the sum rather than the average of the shard gradients, effectively doubling the learning rate.

References
  1. Under hybrid parallelism, loss normalization must precisely compensate for DDP's averaging and CP's duplication/summation behavior.
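The reviewer's arithmetic above can be reproduced with toy values (a standalone sketch; G1 and G2 are made-up shard gradients):

```python
# Reviewer's example: world size 8, local_cp_size 4 -> local_dp_size 2.
G1, G2 = 1.0, 3.0                 # per-shard gradients (toy values)
ddp_avg = (4 * G1 + 4 * G2) / 8   # DDP mean over all 8 ranks (4 copies of each shard)
true_avg = (G1 + G2) / 2          # correct average over the 2 independent shards
scaled = ddp_avg * 2              # extra multiply by dp_size, as criticized
print(ddp_avg, true_avg, scaled)  # scaled equals the SUM of the shards, not the mean
```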

Contributor Author


This works together with the unchanged batch_num_tokens = T_total. Within a CP subgroup of size C, all ranks compute
the same loss (outputs are all-gathered via postprocess_thd_engine). After DDP averaging across W = C × D ranks:

> grad_avg = Σ_j(C × loss_j × D / T_total) / (C × D) = Σ_j(loss_j) / T_total ✓
>
> The three changes form a consistent system — batch_num_tokens = T_total (no all-reduce needed), dp_size = D (compensates for C-fold gradient duplication in DDP), and the math produces the correct global average loss.
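The identity in the reply can be sanity-checked with toy numbers. This is a standalone arithmetic sketch (C, D, T_total, and the per-subgroup losses are made up), not a test of the actual training loop:

```python
# Toy check of: sum_j(C * loss_j * D / T_total) / (C * D) == sum_j(loss_j) / T_total
C, D = 4, 2                  # CP subgroup size and DP subgroup count (toy values)
T_total = 100.0              # global token count (toy value)
losses = [3.0, 5.0]          # per-DP-subgroup masked loss sums (toy values)
# Each of the C ranks in subgroup j contributes loss_j * D / T_total;
# DDP then averages over all C * D ranks.
grad_avg = sum(C * (lj * D / T_total) for lj in losses) / (C * D)
expected = sum(losses) / T_total
print(grad_avg, expected, abs(grad_avg - expected) < 1e-12)
```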
