Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions verl/models/mcore/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,12 @@ def apply_patch_megatron_recomputation_backward():
import megatron.core.tensor_parallel.random as rd
import torch

# Only apply patch if megatron CheckpointFunction has saved_objects support
# (i.e., has _recover_function_args). Older megatron versions don't save
# non-tensor args and this patch would crash on ctx.saved_objects access.
if not hasattr(rd.CheckpointFunction, '_recover_function_args'):
return

_fork_rng = rd._fork_rng
_set_all_rng_states = rd._set_all_rng_states
detach_variable = rd.detach_variable
Expand Down
88 changes: 55 additions & 33 deletions verl/utils/megatron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,47 +1368,69 @@ def get_megatron_module_device(models: list[Any]) -> str:
def dynamic_cp_split_batch(
batch: TensorDict, engine_config: McoreEngineConfig, dp_size: int, dp_rank: int
) -> TensorDict:
"""Split a micro-batch for Dynamic Context Parallel.

Preconditions (enforced by caller):
- engine.context_parallel_size is 1 when DCP is enabled, so dp_size
already spans the full DP×CP space.
- max_token_len_per_gpu == max_seqlen_per_dp_cp_rank.

The function decides a *local_cp_size* (power-of-2) based on the longest
sequence in the micro-batch. Consecutive dp_ranks are grouped into CP
sub-groups of that size, and the remaining dp dimension becomes the
effective data parallelism (*local_dp_size = dp_size / local_cp_size*).

After splitting, ``batch["dp_size"]`` is overwritten to *local_dp_size*
so that the downstream loss formula
``loss = -masked_sum / batch_num_tokens * dp_size``
automatically compensates for the duplicated CP gradients during the DP
all-reduce.
"""
Split the batch into sub-batches for dynamic context parallel.
import math

we can split a microbatch into several sub-batches with different local_cp_size values, but for simplicity,
we currently only split the batch using a single fixed local_cp_size.

"""
input_ids = batch["input_ids"]
assert input_ids.is_nested, "input_ids must be a nested tensor"
seq_len_effective: torch.Tensor = input_ids.offsets().diff()
max_seq_len = max(seq_len_effective)
# if num of sequences is less than dp_size, we don't need to split the batch
local_cp_size = None
if len(seq_len_effective) < dp_size:
local_cp_size = dp_size
return batch
max_seq_len = int(max(seq_len_effective))
num_seqs = len(seq_len_effective)
max_seqlen_per_dp_cp_rank = engine_config.max_seqlen_per_dp_cp_rank

# --- determine local_cp_size ---
local_cp_size = math.ceil(max_seq_len / max_seqlen_per_dp_cp_rank)
local_cp_size = 1 << (local_cp_size - 1).bit_length() if local_cp_size > 1 else 1

# Every DP sub-group must get at least one sequence; increase CP if needed.
min_cp_for_coverage = math.ceil(dp_size / num_seqs) if num_seqs > 0 else dp_size
if min_cp_for_coverage > 1:
min_cp_for_coverage = 1 << (min_cp_for_coverage - 1).bit_length()
local_cp_size = max(local_cp_size, min_cp_for_coverage)
local_cp_size = min(local_cp_size, dp_size)

local_dp_size = dp_size // local_cp_size
local_dp_rank = dp_rank // local_cp_size

# --- split data across local_dp groups ---
if local_dp_size > 1:
base_count = num_seqs // local_dp_size
remainder = num_seqs % local_dp_size
if local_dp_rank < remainder:
start_idx = local_dp_rank * (base_count + 1)
count = base_count + 1
else:
start_idx = remainder * (base_count + 1) + (local_dp_rank - remainder) * base_count
count = base_count
end_idx = start_idx + count
selected_indices = list(range(start_idx, end_idx))
batch = tu.index_select_tensor_dict(batch, selected_indices)
decision = "split"
else:
# decide the local_cp_size based on the max_seq_len and dp_size
max_seqlen_per_dp_cp_rank = engine_config.max_seqlen_per_dp_cp_rank
import math
selected_indices = list(range(num_seqs))
decision = "no_split_full_cp"

local_cp_size = math.ceil(max_seq_len / max_seqlen_per_dp_cp_rank)
# round up to the nearest power of 2, for [1,2,3,4,5,6,7,8] -> [1,2,4,4,8,8,8,8]
local_cp_size = 1 << (local_cp_size - 1).bit_length()

assert local_cp_size <= dp_size, (
"local_cp_size must be less than or equal to dp_size, try to increase max_seqlen_per_dp_cp_rank"
)
if local_cp_size < dp_size:
# split the batch into local_cp_size sub-batches
local_dp_rank = dp_rank // local_cp_size
local_dp_size = dp_size // local_cp_size
indices = list(range(len(seq_len_effective)))
num_seq_per_local_cp = math.ceil(len(seq_len_effective) / local_dp_size)
start_idx = local_dp_rank * num_seq_per_local_cp
end_idx = min(start_idx + num_seq_per_local_cp, len(seq_len_effective))
selected_indices = indices[start_idx:end_idx]
batch = tu.index_select_tensor_dict(batch, selected_indices)

# print(f"rank={torch.distributed.get_rank()}, local_cp_size={local_cp_size} max_seq_len={max_seq_len}")
# --- attach metadata used by downstream model forward and loss ---
tu.assign_non_tensor_data(batch, "local_cp_size", local_cp_size)
tu.assign_non_tensor(batch, dp_size=local_dp_size)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

dp_size 覆盖为 local_dp_size 可能会导致梯度缩放错误。如果 local_cp_size 个 rank 处理同一个序列(导致梯度重复),DDP 的 all-reduce(在整个 world size 上取平均)已经会自动产生 local_dp_size 个独立 shard 的平均梯度。例如,若 world size 为 8,local_cp_size 为 4,则 local_dp_size 为 2。DDP 计算 (4*G1 + 4*G2) / 8 = (G1 + G2) / 2,这正是正确的平均值。如果按照 docstring 提到的公式再乘以 dp_size (2),最终梯度将变为 shard 梯度的总和而非平均值,导致学习率实际上翻倍。

References
  1. 在混合并行模式下,loss 归一化必须精确补偿 DDP 的平均行为和 CP 的重复/求和行为。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works together with the unchanged batch_num_tokens = T_total. Within a CP subgroup of size C, all ranks compute
the same loss (outputs are all-gathered via postprocess_thd_engine). After DDP averaging across W = C × D ranks:

▎ grad_avg = Σ_j(C × loss_j × D / T_total) / (C × D) = Σ_j(loss_j) / T_total ✓

▎ The three changes form a consistent system — batch_num_tokens = T_total (no all-reduce needed), dp_size = D
(compensates for C-fold gradient duplication in DDP), and the math produces the correct global average loss.


return batch


Expand Down
16 changes: 11 additions & 5 deletions verl/workers/engine/megatron/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,15 +581,21 @@ def load_checkpoint(
offload_megatron_optimizer(self.optimizer)

def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> Any:
tu.assign_non_tensor(data, sp_size=self.engine_config.context_parallel_size)
if self.engine_config.dynamic_context_parallel:
tu.assign_non_tensor(data, sp_size=mpu.get_data_parallel_world_size())
else:
tu.assign_non_tensor(data, sp_size=self.engine_config.context_parallel_size)
Comment on lines +584 to +587
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

在 DCP 模式下将 sp_size 设置为整个 world size (mpu.get_data_parallel_world_size()) 在逻辑上是不一致的。在 Megatron 中,sp_size (或 cp_size) 通常代表序列并行度。由于 DCP 在 CP 组内切分序列,梯度并不会在 CP rank 之间重复,因此为了归一化,sp_size 实际上应该为 1,或者设置为实际的动态 local_cp_size。将其设置为整个 world size 会导致梯度被过度缩小(如果 sp_size 在 loss 计算中被用作除数来处理 CP 重复的话)。

References
  1. 并行度参数(TP, PP, DP, CP)的乘积应等于总 GPU 数量。将多个并行维度同时设为 world size 会导致逻辑矛盾和错误的缩放因子。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sp_size is not used in the loss formula (loss = -masked_sum / batch_num_tokens * dp_size). It's only consumed in
prepare_micro_batches() for computing max_token_len = max_token_len_per_gpu * sp_size, which determines micro-batch
splitting granularity. Setting it to W correctly represents the total token capacity across all ranks.


# compute num_tokens in global batch for loss normalization
batch_num_tokens = data["loss_mask"].sum().to(get_device_id())
torch.distributed.all_reduce(
batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()
)
if self.engine_config.dynamic_context_parallel:
pass
else:
torch.distributed.all_reduce(
batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()
)
Comment on lines +591 to +596
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

在开启 dynamic_context_parallel 时跳过 batch_num_tokens 的 all_reduce 会导致梯度缩放错误。在动态 batch 模式下,不同 rank 处理的 token 数量通常不同。如果每个 rank 仅使用其本地 token 计数进行归一化,随后 DDP 对这些 loss 取平均,将导致梯度产生偏差(得到的是本地平均值的平均值,而非真正的全局平均 loss)。PR 描述中提到的“CP 子组内 token 重复”的理由与标准 Megatron Context Parallel 行为不符,CP 通常是在子组内切分序列而非复制。即使序列被复制,不一致的归一化分母也会导致分布式训练中的数学错误。

References
  1. 分布式训练中,loss 归一化必须使用全局 token 总数,否则在 token 数量不均衡时会导致梯度偏差。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在dcp的场景下,每个卡初始就会拿到所有数据,所以可以跳过这里的allreduce

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Key design: In DCP mode, context_parallel_size is set to 1 in the Megatron config. This means Megatron treats all
ranks as a single DP group. All ranks receive the identical batch from the dataloader. The dynamic CP grouping happens
at runtime inside dynamic_cp_split_batch.

▎ Concern 1 (skip all-reduce for batch_num_tokens):
▎ Since all ranks start with the same batch, loss_mask.sum() already equals the global token count T_total on every
rank. An all-reduce would incorrectly produce W × T_total. Skipping it is correct.

tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item())
tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size())
tu.assign_non_tensor(data, dp_size=mpu.get_data_parallel_world_size())

vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
if vpp_size is not None and vpp_size > 1:
Expand Down