Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import asyncio
import concurrent
Expand Down Expand Up @@ -659,7 +659,7 @@ def generate_all_output_tokens_static_batch(
# Check whether CUDA graphs are enabled
enable_cuda_graph = (
model_config.cuda_graph_impl == "local"
and model_config.cuda_graph_scope != "full_iteration"
and "full_iteration" not in model_config.cuda_graph_scope
)

# Pad batch tokens if necessary
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import logging
import os
from typing import Optional, Tuple
Expand Down Expand Up @@ -136,7 +136,7 @@ def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
# Use is_cg_capturable=True for full iteration CUDA graphs to avoid torch.equal checks
is_cg_capturable = (
hasattr(self.config, 'cuda_graph_scope')
and self.config.cuda_graph_scope == 'full_iteration'
and 'full_iteration' in self.config.cuda_graph_scope
)
if is_cg_capturable and not is_te_min_version("2.7.0"):
from megatron.core.utils import get_te_version
Expand Down
3 changes: 2 additions & 1 deletion megatron/core/models/gpt/fine_grained_callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,8 @@ def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor)
with get_fine_grained_offloading_context(layer.offload_mlp_norm):
pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states)

local_tokens, probs, _ = layer.mlp.router_and_preprocess(pre_mlp_layernorm_output)
probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output)
local_tokens, probs, _ = layer.mlp.preprocess(pre_mlp_layernorm_output, probs, routing_map)

# Detach here for mlp_bda residual connection
node.layer_state.residual = node.detach(hidden_states)
Expand Down
2 changes: 1 addition & 1 deletion megatron/core/models/gpt/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def _preprocess(
and (
(
self.config.cuda_graph_impl == "local"
and self.config.cuda_graph_scope != "full_iteration"
and "full_iteration" not in self.config.cuda_graph_scope
)
or self.config.flash_decode
)
Expand Down
6 changes: 3 additions & 3 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def forward_backward_no_pipelining(
if (
hasattr(config, 'cuda_graph_impl')
and config.cuda_graph_impl == "local"
and config.cuda_graph_scope != "full_iteration"
and "full_iteration" not in config.cuda_graph_scope
):
create_cudagraphs()

Expand Down Expand Up @@ -1921,7 +1921,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None):
if (
hasattr(config, 'cuda_graph_impl')
and config.cuda_graph_impl == "local"
and config.cuda_graph_scope != "full_iteration"
and "full_iteration" not in config.cuda_graph_scope
):
create_cudagraphs()
nvtx_range_pop(suffix="misc")
Expand Down Expand Up @@ -2308,7 +2308,7 @@ def enable_grad_sync():
if (
hasattr(config, 'cuda_graph_impl')
and config.cuda_graph_impl == "local"
and config.cuda_graph_scope != "full_iteration"
and "full_iteration" not in config.cuda_graph_scope
):
create_cudagraphs()

Expand Down
4 changes: 2 additions & 2 deletions megatron/core/ssm/mamba_block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024, Tri Dao, Albert Gu.

# Some of this code was adopted from https://github.com/state-spaces/mamba/
Expand Down Expand Up @@ -301,7 +301,7 @@ def forward(
(
(
self.config.cuda_graph_impl == "local"
and self.config.cuda_graph_scope != "full_iteration"
and "full_iteration" not in self.config.cuda_graph_scope
)
or self.config.flash_decode
)
Expand Down
2 changes: 1 addition & 1 deletion megatron/core/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ def forward(
if (
in_decode_mode
and self.config.cuda_graph_impl == "local"
and self.config.cuda_graph_scope != "full_iteration"
and "full_iteration" not in self.config.cuda_graph_scope
and inference_context.is_static_batching()
):
raise ValueError(f"CUDA graphs must use flash decode with static batching!")
Expand Down
86 changes: 62 additions & 24 deletions megatron/core/transformer/cuda_graphs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import gc
import inspect
Expand All @@ -22,7 +22,7 @@
get_cuda_rng_tracker,
)
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import (
get_attr_wrapped_model,
Expand Down Expand Up @@ -1059,9 +1059,12 @@ def __init__(
), "RNG tracker does not support cudagraphs!"

assert config.cuda_graph_impl == "local", "Option cuda_graph_impl=local not enabled."
assert "expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", ""), (
"expandable_segments:True may not be safe when using CUDA Graphs, and may result in"
"a crash due to illegal memory access or other undefined behaviour."
assert (
"expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", "")
or os.getenv("NCCL_GRAPH_REGISTER", "") == "0"
), (
"Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using "
"CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True."
)

self.cudagraph_runners = []
Expand Down Expand Up @@ -1311,23 +1314,40 @@ def _layer_is_graphable(layer, config):
Check if a layer is graphable.
"""

# Only GraphableMegatronModule can be graphed.
if not isinstance(layer, GraphableMegatronModule):
return False

# If cuda_graph_scope is not set, every layer is graphed.
if not config.cuda_graph_scope:
return True

# import modules here to avoid a circular import
from megatron.core.ssm.mamba_layer import MambaLayer
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_layer import TransformerLayer

if isinstance(layer, MambaLayer) and config.cuda_graph_scope == "full":
if isinstance(layer, MambaLayer) and 'mamba' in config.cuda_graph_scope:
# mamba layer.
return True
if isinstance(layer, TransformerLayer):
if config.cuda_graph_scope == 'attn':
if not (
isinstance(layer.self_attention, IdentityOp)
and isinstance(layer.cross_attention, IdentityOp)
):
# attn layer.
return True
else:
if 'attn' in config.cuda_graph_scope and not (
isinstance(layer.self_attention, IdentityOp)
and isinstance(layer.cross_attention, IdentityOp)
):
# attn layer.
return True
if (
'moe' in config.cuda_graph_scope
or 'moe_router' in config.cuda_graph_scope
or 'moe_preprocess' in config.cuda_graph_scope
) and isinstance(layer.mlp, MoELayer):
# moe layer.
return True
if 'mlp' in config.cuda_graph_scope and isinstance(layer.mlp, MLP):
# mlp layer.
return True
return False

Expand All @@ -1346,18 +1366,17 @@ def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]):
assert (
config.cuda_graph_impl == "transformer_engine"
), "Option cuda_graph_impl=transformer_engine not enabled."
assert "expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", ""), (
"expandable_segments:True may not be safe when using CUDA Graphs, and may result in"
"a crash due to illegal memory access or other undefined behaviour."
assert (
"expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", "")
or os.getenv("NCCL_GRAPH_REGISTER", "") == "0"
), (
"Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using "
"CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True."
)
assert config.cuda_graph_scope != "full_iteration", (
assert "full_iteration" not in config.cuda_graph_scope, (
"full_iteration cuda graph is not supported for cuda_graph_impl=transformer_engine. "
"Please use cuda_graph_impl=local instead."
)
assert config.cuda_graph_scope in [
'full',
'attn',
], f"--cuda-graph-scope should be full or attn, got {config.cuda_graph_scope}."

self.model = model
self.config = config
Expand Down Expand Up @@ -1440,6 +1459,16 @@ def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]):
f'{len(self.flattened_callables)} graphable layers.',
)

# One helper object can only capture CUDA Graphs once. Use this flag to check if the graphs
# have been created.
self._graphs_created = False

def graphs_created(self):
"""
Returns whether the CUDA Graphs have been created.
"""
return self._graphs_created

def _get_cuda_graph_input_data(self):
"""
Create the CUDA Graph capturing input data.
Expand Down Expand Up @@ -1480,8 +1509,13 @@ def get_rotary_pos_emb(transformer_module, transformer_input):
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.transformer_layer import TransformerLayer

contains_self_attn = isinstance(layer, TransformerLayer) and not isinstance(
layer.self_attention, IdentityOp
contains_self_attn = (
isinstance(layer, TransformerLayer)
and not isinstance(layer.self_attention, IdentityOp)
and (
not self.config.cuda_graph_scope
or 'attn' in self.config.cuda_graph_scope
)
)
if is_te_min_version("1.10.0"):
# te.make_graphed_callables() accepts keyword arguments since 1.10.0.
Expand Down Expand Up @@ -1590,6 +1624,8 @@ def _start_capturing(self):
"""
Start capturing CUDA Graphs.
"""
assert not self._graphs_created, "CUDA Graphs have already been created."

torch.distributed.barrier()
gc.collect()
torch.cuda.empty_cache()
Expand Down Expand Up @@ -1623,6 +1659,8 @@ def _finish_capturing(self, start_time):
gc.collect()
torch.cuda.empty_cache()

self._graphs_created = True

def create_cudagraphs(self):
"""
Capture CUDA Graphs per TransformerLayer per microbatch.
Expand Down
58 changes: 44 additions & 14 deletions megatron/core/transformer/moe/moe_layer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

from abc import ABC, abstractmethod
from dataclasses import dataclass
Expand All @@ -9,7 +9,12 @@
from megatron.core import parallel_state, tensor_parallel, utils
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import get_default_pg_collection
from megatron.core.transformer.moe.moe_utils import (
MoECudaGraphPartialCaptureSignal,
MoECudaGraphTensorStore,
get_default_pg_collection,
maybe_skip_or_early_return_by_cudagraph,
)
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.token_dispatcher import (
MoEAllGatherTokenDispatcher,
Expand Down Expand Up @@ -169,29 +174,44 @@ def __init__(
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts)

def router_and_preprocess(self, hidden_states: torch.Tensor):
"""Compute and preprocess token routing for dispatch.
# Cudagraph tensor store for resuming the forward pass from the end of the cudagraph.
self.cudagraph_tensor_store = MoECudaGraphTensorStore()

@maybe_skip_or_early_return_by_cudagraph("route")
def route(self, hidden_states: torch.Tensor):
"""Compute token routing for preprocessing.

This method uses the router to determine which experts to send each token to,
producing routing probabilities and a mapping. It then preprocesses the
hidden states and probabilities for the token dispatcher. The original
hidden states are returned as a residual connection.
producing routing probabilities and a mapping.
"""
residual = hidden_states
probs, routing_map = self.router(hidden_states)
return probs, routing_map

@maybe_skip_or_early_return_by_cudagraph("preprocess")
def preprocess(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
):
"""Preprocess token routing for dispatch.

This method preprocesses the hidden states and routing probabilities for the token
dispatcher. The original hidden states are returned as a residual connection.
"""
residual = hidden_states
hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
hidden_states, routing_map, probs
)
return hidden_states, probs, residual

def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
"""Dispatches tokens to assigned expert ranks via communication.

This method performs the actual communication (e.g., All-to-All) to distribute
tokens and their associated probabilities to the devices hosting their assigned
experts.
"""
return self.token_dispatcher.token_dispatch(hidden_states, probs)

@maybe_skip_or_early_return_by_cudagraph("shared_experts_compute")
def shared_experts_compute(self, hidden_states: torch.Tensor):
"""Computes the output of the shared experts.

Expand Down Expand Up @@ -273,28 +293,38 @@ def forward(self, hidden_states: torch.Tensor):

# MoE forward: route -> dispatch -> compute -> combine
def custom_forward(hidden_states):
shared_expert_output = self.shared_experts_compute(hidden_states)
hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
try:
shared_expert_output = self.shared_experts_compute(hidden_states)
probs, routing_map = self.route(hidden_states)
hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map)
except MoECudaGraphPartialCaptureSignal as e:
# This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator.
# It means we should early-return from the MoE layer forward pass.
# This happens when we are partially capturing the CUDA graph of the MoE layer,
# like cuda_graph_scope=["moe_router", "moe_preprocess"].
# We need to return the intermediate tensors as CUDA graph outputs.
return e.get_early_return_outputs(hidden_states, shared_expert_output)

dispatched_input, probs = self.dispatch(hidden_states, probs)
output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
output = self.combine(output, shared_expert_output)
return output, mlp_bias

if self.moe_layer_recompute:
if self.config.fp8:
output, mlp_bias = te_checkpoint(
outputs = te_checkpoint(
custom_forward,
False,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
hidden_states,
)
else:
output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
else:
output, mlp_bias = custom_forward(hidden_states)
outputs = custom_forward(hidden_states)

return output, mlp_bias
return outputs

def backward_dw(self):
"""Compute weight gradients for experts and shared experts."""
Expand Down
Loading
Loading