Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
11e1aff
[DEV] feat(MoE): Refactor cuda_graph_scope (#1917)
buptzyb Nov 5, 2025
b16a78c
main golden
buptzyb Nov 5, 2025
3681d7a
fix cudagraph ut
buptzyb Nov 13, 2025
6f139f1
enum CudaGraphScope
buptzyb Nov 13, 2025
a35489f
minor fixes
buptzyb Nov 14, 2025
ecf173d
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Nov 20, 2025
0337f20
minor updates
buptzyb Nov 20, 2025
20cb013
update
buptzyb Nov 20, 2025
63a958f
remove None check in language_module
buptzyb Nov 21, 2025
a3607af
update hybridep cudagraph ut
buptzyb Nov 21, 2025
be5c462
add nvtx_range_pop
buptzyb Nov 26, 2025
97f73e9
revert breaking API change
buptzyb Nov 26, 2025
e825232
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Nov 26, 2025
40a89ce
revert breaking API change
buptzyb Nov 26, 2025
aa0db95
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Dec 1, 2025
9639ab9
cherry-pick dev PR 2353
buptzyb Dec 1, 2025
8ce611c
Replay "[Dev] feat(MoE): Refactor cuda_graph_scope - part2 (#2353)" (…
buptzyb Dec 2, 2025
0823434
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Dec 2, 2025
f9857d2
Add functools.wraps
buptzyb Dec 4, 2025
f975fd3
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Dec 5, 2025
625e5b7
update test_fp8_param cudagraph ut
buptzyb Dec 5, 2025
14a9668
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Dec 8, 2025
96953f4
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Dec 10, 2025
4361b7f
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Dec 11, 2025
0ccb425
improve recompute checks
buptzyb Dec 17, 2025
de638ad
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Dec 17, 2025
b7abdd7
Merge branch 'main' into robinz/refactor_cuda_graph_scope
buptzyb Dec 18, 2025
85583d0
fix backward compatibility
buptzyb Dec 18, 2025
0698a59
disable hybridep ut
buptzyb Dec 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.moe_layer import BaseMoELayer
from megatron.core.transformer.utils import set_model_to_sequence_parallel
from megatron.core.utils import get_asyncio_loop, get_model_config, unwrap_model
Expand Down Expand Up @@ -987,7 +988,7 @@ def generate_all_output_tokens_static_batch(
# Check whether CUDA graphs are enabled
enable_cuda_graph = (
model_config.cuda_graph_impl == "local"
and model_config.cuda_graph_scope != "full_iteration"
and CudaGraphScope.full_iteration not in model_config.cuda_graph_scope
)

# Pad batch tokens if necessary
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import logging
import os
from typing import Optional, Tuple
Expand All @@ -21,7 +21,7 @@
is_vp_last_stage,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.enums import AttnBackend, CudaGraphScope
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import ensure_metadata_has_dp_cp_group
Expand Down Expand Up @@ -142,7 +142,7 @@ def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
# Use is_cg_capturable=True for full iteration CUDA graphs to avoid torch.equal checks
is_cg_capturable = (
hasattr(self.config, 'cuda_graph_scope')
and self.config.cuda_graph_scope == 'full_iteration'
and CudaGraphScope.full_iteration in self.config.cuda_graph_scope
)
if is_cg_capturable and not is_te_min_version("2.7.0"):
from megatron.core.utils import get_te_version
Expand Down
7 changes: 4 additions & 3 deletions megatron/core/models/gpt/fine_grained_callables.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import weakref
from contextlib import nullcontext
Expand Down Expand Up @@ -358,7 +358,8 @@ def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor)
else:
pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states)

local_tokens, probs, _ = layer.mlp.router_and_preprocess(pre_mlp_layernorm_output)
probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output)
local_tokens, probs = layer.mlp.preprocess(pre_mlp_layernorm_output, probs, routing_map)

# Detach here for mlp_bda residual connection
node.layer_state.residual = node.detach(hidden_states)
Expand Down Expand Up @@ -400,7 +401,7 @@ def submodule_moe_forward(node: ScheduleNode, dispatched_tokens: torch.Tensor):
pre_mlp_layernorm_output = getattr(node.layer_state, 'pre_mlp_layernorm_output', None)
shared_expert_output = layer.mlp.shared_experts_compute(pre_mlp_layernorm_output)
expert_output, mlp_bias = layer.mlp.routed_experts_compute(
dispatched_tokens, dispatched_probs, pre_mlp_layernorm_output
dispatched_tokens, dispatched_probs
)

if layer.recompute_pre_mlp_layernorm:
Expand Down
6 changes: 3 additions & 3 deletions megatron/core/models/gpt/gpt_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

from collections import OrderedDict
from typing import Dict, Literal, Optional
Expand All @@ -21,7 +21,7 @@
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.quantization.utils import get_quant_config_or_none
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.enums import CudaGraphScope, ModelType
from megatron.core.transformer.multi_token_prediction import (
MTPLossAutoScaler,
MTPLossLoggingHelper,
Expand Down Expand Up @@ -371,7 +371,7 @@ def _preprocess(
and (
(
self.config.cuda_graph_impl == "local"
and self.config.cuda_graph_scope != "full_iteration"
and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
)
or self.config.flash_decode
)
Expand Down
9 changes: 5 additions & 4 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import contextlib
from functools import partial
Expand All @@ -18,6 +18,7 @@
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
drain_embedding_wgrad_compute,
Expand Down Expand Up @@ -650,7 +651,7 @@ def forward_backward_no_pipelining(
if (
hasattr(config, 'cuda_graph_impl')
and config.cuda_graph_impl == "local"
and config.cuda_graph_scope != "full_iteration"
and CudaGraphScope.full_iteration not in config.cuda_graph_scope
):
create_cudagraphs()

Expand Down Expand Up @@ -1914,7 +1915,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None):
if (
hasattr(config, 'cuda_graph_impl')
and config.cuda_graph_impl == "local"
and config.cuda_graph_scope != "full_iteration"
and CudaGraphScope.full_iteration not in config.cuda_graph_scope
):
create_cudagraphs()
nvtx_range_pop(suffix="misc")
Expand Down Expand Up @@ -2298,7 +2299,7 @@ def enable_grad_sync():
if (
hasattr(config, 'cuda_graph_impl')
and config.cuda_graph_impl == "local"
and config.cuda_graph_scope != "full_iteration"
and CudaGraphScope.full_iteration not in config.cuda_graph_scope
):
create_cudagraphs()

Expand Down
3 changes: 2 additions & 1 deletion megatron/core/safe_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from megatron.core.enums import ModelType
from megatron.core.optimizer import OptimizerConfig
from megatron.core.rerun_state_machine import RerunDiagnostic, RerunMode, RerunState
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.enums import AttnBackend, CudaGraphScope

SAFE_GLOBALS = [
SimpleNamespace,
Expand All @@ -25,6 +25,7 @@
UInt32DType,
Namespace,
AttnBackend,
CudaGraphScope,
ModelType,
OptimizerConfig,
RerunDiagnostic,
Expand Down
5 changes: 3 additions & 2 deletions megatron/core/ssm/mamba_block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024, Tri Dao, Albert Gu.

# Some of this code was adopted from https://github.com/state-spaces/mamba/
Expand All @@ -22,6 +22,7 @@
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols
from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
Expand Down Expand Up @@ -245,7 +246,7 @@ def forward(
(
(
self.config.cuda_graph_impl == "local"
and self.config.cuda_graph_scope != "full_iteration"
and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
)
or self.config.flash_decode
)
Expand Down
5 changes: 5 additions & 0 deletions megatron/core/tensor_parallel/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,11 @@ def checkpoint(self, run_function, *args):

def _recompute(self, _):
"""Used as a hook to recompute the output."""

if self.ctx is None:
# The recomputation has been triggered already. Just return.
return

if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
Expand Down
6 changes: 3 additions & 3 deletions megatron/core/transformer/attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass
Expand Down Expand Up @@ -41,7 +41,7 @@
from ..models.common.embeddings.yarn_rotary_pos_embedding import (
_yarn_get_concentration_factor_from_config,
)
from .enums import AttnMaskType
from .enums import AttnMaskType, CudaGraphScope
from .transformer_config import TransformerConfig

try:
Expand Down Expand Up @@ -851,7 +851,7 @@ def forward(
if (
in_decode_mode
and self.config.cuda_graph_impl == "local"
and self.config.cuda_graph_scope != "full_iteration"
and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
and inference_context.is_static_batching()
):
raise ValueError(f"CUDA graphs must use flash decode with static batching!")
Expand Down
Loading
Loading