[megatron] fix: dynamic context parallel batch splitting and loss normalization #5869
base: main
Changes from 1 commit
```diff
@@ -581,15 +581,21 @@ def load_checkpoint(
         offload_megatron_optimizer(self.optimizer)

     def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> Any:
-        tu.assign_non_tensor(data, sp_size=self.engine_config.context_parallel_size)
+        if self.engine_config.dynamic_context_parallel:
+            tu.assign_non_tensor(data, sp_size=mpu.get_data_parallel_world_size())
+        else:
+            tu.assign_non_tensor(data, sp_size=self.engine_config.context_parallel_size)
```
Comment on lines +584 to +587

**Contributor:** In DCP mode, setting …

**Contributor (Author):** sp_size is not used in the loss formula (loss = -masked_sum / batch_num_tokens * dp_size). It's only consumed in …
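The normalization the author describes can be checked with a small sketch. This is plain Python with toy values (the numbers, the helper name, and the two-rank setup are all illustrative, not from the PR):

```python
# Toy check of the loss normalization discussed above:
#   loss = -masked_sum / batch_num_tokens * dp_size
# followed by DDP averaging the gradients over dp_size ranks.

def normalized_loss(masked_sum, batch_num_tokens, dp_size):
    """Per-rank loss contribution; DDP later divides by the world size."""
    return -masked_sum / batch_num_tokens * dp_size

# Two DP ranks; the global batch has 10 valid (unmasked) tokens in total.
dp_size = 2
batch_num_tokens = 10            # global token count
rank_losses = [normalized_loss(6.0, batch_num_tokens, dp_size),
               normalized_loss(4.0, batch_num_tokens, dp_size)]

# DDP's average over dp_size ranks cancels the *dp_size factor and
# recovers the global token-mean loss.
global_mean = sum(rank_losses) / dp_size
print(global_mean)               # -(6.0 + 4.0) / 10 = -1.0
```

The *dp_size factor exists only so that DDP's built-in averaging does not shrink the per-token normalization.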
```diff
         # compute num_tokens in global batch for loss normalization
         batch_num_tokens = data["loss_mask"].sum().to(get_device_id())
-        torch.distributed.all_reduce(
-            batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()
-        )
+        if self.engine_config.dynamic_context_parallel:
+            pass
+        else:
+            torch.distributed.all_reduce(
+                batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()
+            )
```
|
Comment on lines +591 to +596

**Contributor:** When … is enabled, …

**Contributor (Author):** In the DCP scenario, every rank receives the full data from the start, so the all-reduce here can be skipped.

**Contributor (Author):** Key design: In DCP mode, context_parallel_size is set to 1 in the Megatron config. This means Megatron treats all … Concern 1 (skip all-reduce for batch_num_tokens): …
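The author's reason for skipping the all-reduce can be illustrated with a plain-Python stand-in for the distributed logic (toy mask and two-rank sharding are assumptions for illustration; no torch required):

```python
# Why the all-reduce is a no-op in DCP mode: every rank already holds
# the full global batch, so the local loss_mask sum IS the global count.

global_loss_mask = [1, 1, 0, 1, 1, 1, 0, 1]    # toy mask, 6 valid tokens

# Non-DCP: each DP rank sees only its shard, so a SUM all-reduce over
# the data-parallel group is needed to recover the global token count.
shards = [global_loss_mask[:4], global_loss_mask[4:]]
per_rank_counts = [sum(s) for s in shards]      # [3, 3]
allreduced = sum(per_rank_counts)               # 6, needs communication

# DCP: every rank starts with the full batch, so the local sum already
# equals the global count and no communication is required.
dcp_local = sum(global_loss_mask)               # 6, no all-reduce

print(allreduced, dcp_local)                    # 6 6
```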
```diff
         tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item())
-        tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size())
+        tu.assign_non_tensor(data, dp_size=mpu.get_data_parallel_world_size())

         vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
         if vpp_size is not None and vpp_size > 1:
```
Overriding dp_size with local_dp_size may cause incorrect gradient scaling. If local_cp_size ranks process the same sequence (so its gradient is duplicated), DDP's all-reduce, which averages over the entire world size, already yields the average of the local_dp_size independent shard gradients. For example, with a world size of 8 and local_cp_size of 4, local_dp_size is 2. DDP computes (4*G1 + 4*G2) / 8 = (G1 + G2) / 2, which is exactly the correct average. Multiplying again by dp_size (2), as the formula in the docstring suggests, turns the final gradient into the sum rather than the average of the shard gradients, effectively doubling the learning rate.
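The reviewer's arithmetic can be verified directly. A toy check with the same numbers (world size 8, local_cp_size 4; the gradient values G1, G2 are made up for illustration):

```python
# Numeric check of the gradient-scaling concern above.
world_size, local_cp_size = 8, 4
local_dp_size = world_size // local_cp_size     # 2 independent shards

G1, G2 = 3.0, 5.0                               # per-shard gradients (toy)

# DDP all-reduce averages over the full world; each shard gradient
# appears local_cp_size times (once per CP-duplicate rank).
ddp_avg = (local_cp_size * G1 + local_cp_size * G2) / world_size
correct_avg = (G1 + G2) / local_dp_size         # average over 2 shards

print(ddp_avg == correct_avg)                   # True: DDP already averages

# Multiplying the loss by dp_size (= local_dp_size = 2) on top of this
# turns the average into a sum, i.e. an effective 2x learning rate.
scaled = ddp_avg * local_dp_size
print(scaled == G1 + G2)                        # True: sum, not average
```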
This works together with the unchanged batch_num_tokens = T_total. Within a CP subgroup of size C, all ranks compute the same loss (outputs are all-gathered via postprocess_thd_engine). After DDP averaging across W = C × D ranks:

grad_avg = Σ_j (C × loss_j × D / T_total) / (C × D) = Σ_j loss_j / T_total ✓

The three changes form a consistent system: batch_num_tokens = T_total (no all-reduce needed), dp_size = D (compensates for the C-fold gradient duplication in DDP), and the math produces the correct global average loss.
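The identity above can be confirmed numerically. A minimal sketch with assumed toy values for C, D, T_total, and the per-shard losses:

```python
# Numeric check of the author's identity:
#   grad_avg = sum_j(C * loss_j * D / T_total) / (C * D)
#            = sum_j(loss_j) / T_total

C, D, T_total = 4, 2, 10            # CP subgroup size, DP size, global tokens
losses = [0.6, 0.4]                 # per-shard losses, one per DP shard j

# Each loss_j is duplicated across its C-rank CP subgroup and scaled by
# D; DDP then averages over all W = C * D ranks.
lhs = sum(C * lj * D / T_total for lj in losses) / (C * D)
rhs = sum(losses) / T_total         # desired global token-mean loss

print(abs(lhs - rhs) < 1e-12)       # True: the C and D factors cancel
```

The C in the numerator cancels the C in the denominator, and the D cancels DDP's division by D, leaving exactly the global average, which is the consistency the author claims.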