Skip to content
6 changes: 2 additions & 4 deletions megatron/core/models/gpt/fine_grained_callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def forward_func(

shared_expert_output = layer.mlp.shared_experts_compute(pre_mlp_layernorm_output)
probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output)
local_tokens, probs, _ = layer.mlp.preprocess(
local_tokens, probs = layer.mlp.preprocess(
pre_mlp_layernorm_output, probs, routing_map
)
return hidden_states, local_tokens, probs, shared_expert_output
Expand Down Expand Up @@ -519,9 +519,7 @@ def submodule_moe_forward(node: ScheduleNode, dispatched_tokens: torch.Tensor):
# backward graph from connecting to dispatch submodule
token_dispatcher._comm_manager.dispatched_probs = dispatched_probs

expert_output, _ = layer.mlp.routed_experts_compute(
dispatched_tokens, dispatched_probs, None
)
expert_output, _ = layer.mlp.routed_experts_compute(dispatched_tokens, dispatched_probs)

if layer.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
Expand Down
5 changes: 5 additions & 0 deletions megatron/core/tensor_parallel/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,11 @@ def checkpoint(self, run_function, *args):

def _recompute(self, _):
"""Used as a hook to recompute the output."""

if self.ctx is None:
# The recomputation has been triggered already. Just return.
return

if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
Expand Down
6 changes: 5 additions & 1 deletion megatron/core/transformer/cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,7 +1835,11 @@ def _get_cuda_graph_input_data(self):
sample_args, sample_kwargs = self._get_sample_arguments(order, chunk_id_list)

def get_make_graphed_callables_kwargs():
kwargs = {'allow_unused_input': True, '_order': order}
kwargs = {
'allow_unused_input': True,
'_order': order,
'retain_graph_in_backward': self.config.cuda_graph_retain_backward_graph,
}

# Calculate the number of warmup iterations per layer per microbatch inside TE
# make_graphed_callables(). There are two rules:
Expand Down
15 changes: 7 additions & 8 deletions megatron/core/transformer/moe/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import internal_api

try:
import transformer_engine as te # pylint: disable=unused-import
Expand Down Expand Up @@ -222,9 +223,8 @@ def preprocess(
"""Preprocess token routing for dispatch.

This method preprocesses the hidden states and routing probabilities for the token
dispatcher. The original hidden states are returned as a residual connection.
dispatcher.
"""
residual = hidden_states
# Project the hidden_states from hidden dimension down to latent dimension.
if self.config.moe_latent_size:
assert (
Expand All @@ -234,7 +234,7 @@ def preprocess(
hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
hidden_states, routing_map, probs
)
return hidden_states, probs, residual
return hidden_states, probs

def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
"""Dispatches tokens to assigned expert ranks via communication.
Expand Down Expand Up @@ -273,9 +273,8 @@ def shared_experts_compute(self, hidden_states: torch.Tensor):

return shared_expert_output

def routed_experts_compute(
self, hidden_states: torch.Tensor, probs: torch.Tensor, residual: torch.Tensor
):
@internal_api
def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tensor):
"""Computes the output of the routed experts on the dispatched tokens.

This method first post-processes the dispatched input to get permuted tokens
Expand Down Expand Up @@ -342,7 +341,7 @@ def custom_forward(hidden_states, padding_mask=None):
try:
shared_expert_output = self.shared_experts_compute(hidden_states)
probs, routing_map = self.route(hidden_states, padding_mask=padding_mask)
hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map)
hidden_states, probs = self.preprocess(hidden_states, probs, routing_map)
except MoECudaGraphPartialCaptureSignal as e:
# This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator.
# It means we should early-return from the MoE layer forward pass.
Expand All @@ -352,7 +351,7 @@ def custom_forward(hidden_states, padding_mask=None):
return e.get_early_return_outputs(hidden_states, shared_expert_output)

dispatched_input, probs = self.dispatch(hidden_states, probs)
output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
output, mlp_bias = self.routed_experts_compute(dispatched_input, probs)
assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}"
output = self.combine(output, shared_expert_output)

Expand Down
68 changes: 28 additions & 40 deletions megatron/core/transformer/moe/moe_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import functools
import math
from dataclasses import dataclass
from typing import List, Optional, Union
Expand Down Expand Up @@ -1142,17 +1143,24 @@ def get_early_return_outputs(
"""
Get the CUDA graph early return outputs for the MoE layer, including the intermediate
tensors and the intermediate attributes of the token dispatcher.

The returned output tensors are in the order of:
- routed experts path outputs
- hidden states, probs, and routing map for capturing router
- hidden states and probs for capturing router and preprocess
- intermediate attributes of the token dispatcher (if capturing the preprocess step)
- shared expert path output (if exists)
"""
if self.return_step == "route":
# Capturing the router step returns three intermediate tensors:
# hidden states, routing probabilities, and routing map.
outputs = [hidden_states, self.kwargs['probs'], self.kwargs['routing_map']]
elif self.return_step == "preprocess":
# Capturing the preprocess step returns three intermediate tensors:
# hidden states, routing probabilities, and residual connection.
# Capturing the preprocess step returns two intermediate tensors:
# hidden states and routing probabilities.
# It also returns the intermediate attributes of the token dispatcher, recorded in
# "token_dispatcher.cudagraph_attrs".
outputs = [self.kwargs['hidden_states'], self.kwargs['probs'], self.kwargs['residual']]
outputs = [self.kwargs['hidden_states'], self.kwargs['probs']]
valid_cudagraph_attrs = []
for attr_name in self.moe_layer.token_dispatcher.cudagraph_attrs:
hier_attr_name = attr_name.split('.')
Expand Down Expand Up @@ -1180,6 +1188,7 @@ def get_early_return_outputs(
return outputs


@internal_api
@dataclass
class MoECudaGraphTensorStore:
"""Storage for tensors used in CUDA graph replay for MoE layers.
Expand All @@ -1192,16 +1201,13 @@ class MoECudaGraphTensorStore:
probs (Optional[torch.Tensor]): The routing probabilities for each token-expert pair.
routing_map (Optional[torch.Tensor]): The sparse mapping indicating which experts
were selected for each token. Used to skip the normal router step.
residual (Optional[torch.Tensor]): The residual connection tensor before routing.
Used to skip the normal preprocess step.
shared_expert_output (Optional[torch.Tensor]): The output from shared experts
computation. Used to skip the normal shared expert computation step.
"""

hidden_states: Optional[torch.Tensor] = None
probs: Optional[torch.Tensor] = None
routing_map: Optional[torch.Tensor] = None
residual: Optional[torch.Tensor] = None
shared_expert_output: Optional[torch.Tensor] = None

def is_empty(self) -> bool:
Expand All @@ -1212,13 +1218,7 @@ def is_empty(self) -> bool:
"""
return all(
getattr(self, field_name) is None
for field_name in [
'hidden_states',
'probs',
'routing_map',
'residual',
'shared_expert_output',
]
for field_name in ['hidden_states', 'probs', 'routing_map', 'shared_expert_output']
)

def set(self, **kwargs):
Expand All @@ -1228,7 +1228,6 @@ def set(self, **kwargs):
'hidden_states',
'probs',
'routing_map',
'residual',
'shared_expert_output',
], f"Invalid field name: {field_name}"
if value is not None:
Expand All @@ -1239,13 +1238,7 @@ def set(self, **kwargs):

def clear(self):
"""Reset all stored tensors to None."""
for field_name in [
'hidden_states',
'probs',
'routing_map',
'residual',
'shared_expert_output',
]:
for field_name in ['hidden_states', 'probs', 'routing_map', 'shared_expert_output']:
setattr(self, field_name, None)


Expand Down Expand Up @@ -1288,6 +1281,8 @@ def maybe_raise_signal(moe_layer, **kwargs):
raise MoECudaGraphPartialCaptureSignal(moe_layer, "preprocess", **kwargs)

def decorator(func):

@functools.wraps(func)
def wrapped_func(moe_layer, *args, **kwargs):
"""
Check if we should skip executing the original function based on the current
Expand Down Expand Up @@ -1316,46 +1311,39 @@ def wrapped_func(moe_layer, *args, **kwargs):
# Don't skip the router.
assert (
moe_layer.cudagraph_tensor_store.routing_map is None
and moe_layer.cudagraph_tensor_store.residual is None
), "both routing_map and residual must be None if probs is None"
), "routing_map must be None if probs is None"
probs, routing_map = func(moe_layer, *args, **kwargs)

# Maybe early return after the router.
maybe_raise_signal(moe_layer, probs=probs, routing_map=routing_map)
else:
# Skip the router and get value from store.
assert (
moe_layer.cudagraph_tensor_store.routing_map is not None
or moe_layer.cudagraph_tensor_store.residual is not None
), "either routing_map or residual must be given if probs is given"
probs, routing_map = (
moe_layer.cudagraph_tensor_store.probs,
moe_layer.cudagraph_tensor_store.routing_map,
)
return probs, routing_map
elif step_condition == "preprocess":
if moe_layer.cudagraph_tensor_store.residual is None:
if (
moe_layer.cudagraph_tensor_store.is_empty()
or moe_layer.cudagraph_tensor_store.routing_map is not None
):
# Don't skip the preprocess.
hidden_states, probs, residual = func(moe_layer, *args, **kwargs)
hidden_states, probs = func(moe_layer, *args, **kwargs)

# Maybe early return after the preprocess.
maybe_raise_signal(
moe_layer, hidden_states=hidden_states, probs=probs, residual=residual
)
maybe_raise_signal(moe_layer, hidden_states=hidden_states, probs=probs)
else:
# Skip the preprocess and get value from store.
assert (
moe_layer.cudagraph_tensor_store.probs is not None
), "probs must not be None if residual is not None"
assert (
moe_layer.cudagraph_tensor_store.routing_map is None
), "routing_map must be None if residual is not None"
hidden_states, probs, residual = (
moe_layer.cudagraph_tensor_store.hidden_states is not None
and moe_layer.cudagraph_tensor_store.probs is not None
), "hidden_states and probs must be given in moe_preprocess cudagraph replay"
hidden_states, probs = (
moe_layer.cudagraph_tensor_store.hidden_states,
moe_layer.cudagraph_tensor_store.probs,
moe_layer.cudagraph_tensor_store.residual,
)
return hidden_states, probs, residual
return hidden_states, probs

return wrapped_func

Expand Down
Loading
Loading