From c051dfb3395c07392930d5dd90fcb7a951460f5a Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Wed, 17 Dec 2025 03:47:00 -0800 Subject: [PATCH 1/4] improve recompute checks Signed-off-by: Robin Zhang --- megatron/core/tensor_parallel/random.py | 5 + megatron/core/transformer/cuda_graphs.py | 6 +- megatron/core/transformer/moe/moe_layer.py | 13 +- megatron/core/transformer/moe/moe_utils.py | 67 +++++----- .../core/transformer/transformer_config.py | 104 +++++++--------- .../core/transformer/transformer_layer.py | 114 ++++++++++++------ megatron/training/arguments.py | 3 - 7 files changed, 160 insertions(+), 152 deletions(-) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 617d2803c12..5d5389a52d2 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -627,6 +627,11 @@ def checkpoint(self, run_function, *args): def _recompute(self, _): """Used as a hook to recompute the output.""" + + if self.ctx is None: + # The recomputation has been triggered already. Just return. + return + if not torch.autograd._is_checkpoint_valid(): raise RuntimeError( "Checkpointing is not compatible with .grad(), " diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 27e6c65c738..a1cb764b01d 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1785,7 +1785,11 @@ def _get_cuda_graph_input_data(self): sample_args, sample_kwargs = self._get_sample_arguments(order) def get_make_graphed_callables_kwargs(): - kwargs = {'allow_unused_input': True, '_order': order} + kwargs = { + 'allow_unused_input': True, + '_order': order, + 'retain_graph_in_backward': self.config.cuda_graph_retain_backward_graph, + } # Calculate the number of warmup iterations per layer per microbatch inside TE # make_graphed_callables(). 
There are two rules: diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 10d10f667fe..9d4923b4aa6 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -194,13 +194,12 @@ def preprocess( """Preprocess token routing for dispatch. This method preprocesses the hidden states and routing probabilities for the token - dispatcher. The original hidden states are returned as a residual connection. + dispatcher. """ - residual = hidden_states hidden_states, probs = self.token_dispatcher.dispatch_preprocess( hidden_states, routing_map, probs ) - return hidden_states, probs, residual + return hidden_states, probs def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor): """Dispatches tokens to assigned expert ranks via communication. @@ -239,9 +238,7 @@ def shared_experts_compute(self, hidden_states: torch.Tensor): return shared_expert_output - def routed_experts_compute( - self, hidden_states: torch.Tensor, probs: torch.Tensor, residual: torch.Tensor - ): + def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tensor): """Computes the output of the routed experts on the dispatched tokens. This method first post-processes the dispatched input to get permuted tokens @@ -296,7 +293,7 @@ def custom_forward(hidden_states): try: shared_expert_output = self.shared_experts_compute(hidden_states) probs, routing_map = self.route(hidden_states) - hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map) + hidden_states, probs = self.preprocess(hidden_states, probs, routing_map) except MoECudaGraphPartialCaptureSignal as e: # This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator. # It means we should early-return from the MoE layer forward pass. 
@@ -306,7 +303,7 @@ def custom_forward(hidden_states): return e.get_early_return_outputs(hidden_states, shared_expert_output) dispatched_input, probs = self.dispatch(hidden_states, probs) - output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual) + output, mlp_bias = self.routed_experts_compute(dispatched_input, probs) output = self.combine(output, shared_expert_output) return output, mlp_bias diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 28cff06f5ec..537ef53bc90 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -1,5 +1,6 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +import functools import math from dataclasses import dataclass from typing import List, Optional, Union @@ -1070,17 +1071,24 @@ def get_early_return_outputs( """ Get the CUDA graph early return outputs for the MoE layer, including the intermediate tensors and the intermediate attributes of the token dispatcher. + + The returned output tensors are in the order of: + - routed experts path outputs + - hidden states, probs, and routing map for capturing router + - hidden states and probs for capturing router and preprocess + - intermediate attributes of the token dispatcher (if capturing the preprocess step) + - shared expert path output (if exists) """ if self.return_step == "route": # Capturing the router step returns three intermediate tensors: # hidden states, routing probabilities, and routing map. outputs = [hidden_states, self.kwargs['probs'], self.kwargs['routing_map']] elif self.return_step == "preprocess": - # Capturing the preprocess step returns three intermediate tensors: - # hidden states, routing probabilities, and residual connection. + # Capturing the preprocess step returns two intermediate tensors: + # hidden states and routing probabilities. 
# It also returns the intermediate attributes of the token dispatcher, recorded in # "token_dispatcher.cudagraph_attrs". - outputs = [self.kwargs['hidden_states'], self.kwargs['probs'], self.kwargs['residual']] + outputs = [self.kwargs['hidden_states'], self.kwargs['probs']] valid_cudagraph_attrs = [] for attr_name in self.moe_layer.token_dispatcher.cudagraph_attrs: hier_attr_name = attr_name.split('.') @@ -1120,8 +1128,6 @@ class MoECudaGraphTensorStore: probs (Optional[torch.Tensor]): The routing probabilities for each token-expert pair. routing_map (Optional[torch.Tensor]): The sparse mapping indicating which experts were selected for each token. Used to skip the normal router step. - residual (Optional[torch.Tensor]): The residual connection tensor before routing. - Used to skip the normal preprocess step. shared_expert_output (Optional[torch.Tensor]): The output from shared experts computation. Used to skip the normal shared expert computation step. """ @@ -1129,7 +1135,6 @@ class MoECudaGraphTensorStore: hidden_states: Optional[torch.Tensor] = None probs: Optional[torch.Tensor] = None routing_map: Optional[torch.Tensor] = None - residual: Optional[torch.Tensor] = None shared_expert_output: Optional[torch.Tensor] = None def is_empty(self) -> bool: @@ -1140,13 +1145,7 @@ def is_empty(self) -> bool: """ return all( getattr(self, field_name) is None - for field_name in [ - 'hidden_states', - 'probs', - 'routing_map', - 'residual', - 'shared_expert_output', - ] + for field_name in ['hidden_states', 'probs', 'routing_map', 'shared_expert_output'] ) def set(self, **kwargs): @@ -1156,7 +1155,6 @@ def set(self, **kwargs): 'hidden_states', 'probs', 'routing_map', - 'residual', 'shared_expert_output', ], f"Invalid field name: {field_name}" if value is not None: @@ -1167,13 +1165,7 @@ def set(self, **kwargs): def clear(self): """Reset all stored tensors to None.""" - for field_name in [ - 'hidden_states', - 'probs', - 'routing_map', - 'residual', - 
'shared_expert_output', - ]: + for field_name in ['hidden_states', 'probs', 'routing_map', 'shared_expert_output']: setattr(self, field_name, None) @@ -1216,6 +1208,8 @@ def maybe_raise_signal(moe_layer, **kwargs): raise MoECudaGraphPartialCaptureSignal(moe_layer, "preprocess", **kwargs) def decorator(func): + + @functools.wraps(func) def wrapped_func(moe_layer, *args, **kwargs): """ Check if we should skip executing the original function based on the current @@ -1244,46 +1238,39 @@ def wrapped_func(moe_layer, *args, **kwargs): # Don't skip the router. assert ( moe_layer.cudagraph_tensor_store.routing_map is None - and moe_layer.cudagraph_tensor_store.residual is None - ), "both routing_map and residual must be None if probs is None" + ), "routing_map must be None if probs is None" probs, routing_map = func(moe_layer, *args, **kwargs) # Maybe early return after the router. maybe_raise_signal(moe_layer, probs=probs, routing_map=routing_map) else: # Skip the router and get value from store. - assert ( - moe_layer.cudagraph_tensor_store.routing_map is not None - or moe_layer.cudagraph_tensor_store.residual is not None - ), "either routing_map or residual must be given if probs is given" probs, routing_map = ( moe_layer.cudagraph_tensor_store.probs, moe_layer.cudagraph_tensor_store.routing_map, ) return probs, routing_map elif step_condition == "preprocess": - if moe_layer.cudagraph_tensor_store.residual is None: + if ( + moe_layer.cudagraph_tensor_store.is_empty() + or moe_layer.cudagraph_tensor_store.routing_map is not None + ): # Don't skip the preprocess. - hidden_states, probs, residual = func(moe_layer, *args, **kwargs) + hidden_states, probs = func(moe_layer, *args, **kwargs) # Maybe early return after the preprocess. - maybe_raise_signal( - moe_layer, hidden_states=hidden_states, probs=probs, residual=residual - ) + maybe_raise_signal(moe_layer, hidden_states=hidden_states, probs=probs) else: # Skip the preprocess and get value from store. 
assert ( - moe_layer.cudagraph_tensor_store.probs is not None - ), "probs must not be None if residual is not None" - assert ( - moe_layer.cudagraph_tensor_store.routing_map is None - ), "routing_map must be None if residual is not None" - hidden_states, probs, residual = ( + moe_layer.cudagraph_tensor_store.hidden_states is not None + and moe_layer.cudagraph_tensor_store.probs is not None + ), "hidden_states and probs must be given in moe_preprocess cudagraph replay" + hidden_states, probs = ( moe_layer.cudagraph_tensor_store.hidden_states, moe_layer.cudagraph_tensor_store.probs, - moe_layer.cudagraph_tensor_store.residual, ) - return hidden_states, probs, residual + return hidden_states, probs return wrapped_func diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index e2705bd9f51..c65b64cfabb 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -713,11 +713,11 @@ class TransformerConfig(ModelParallelConfig): determines the scope of graph capture.""" cuda_graph_use_single_mempool: bool = False - """When set to true, cudagraphs will be captured inside a single mempool, in which all - cudagraphs may only be used once per step. If false, cudagraphs may be reused across - microbatches. Enabling may reduce cudagraph memory overheads due to memory fragmentation, - however may greatly increase the number of cudagraphs created when the number of microbatches - is high.""" + """[For `local` implementation only] When set to true, cudagraphs will be captured inside a + single mempool, in which all cudagraphs may only be used once per step. If false, cudagraphs may + be reused across microbatches. 
Enabling may reduce cudagraph memory overheads due to memory + fragmentation, however may greatly increase the number of cudagraphs created when the number of + microbatches is high.""" cuda_graph_retain_backward_graph: bool = False """When set to true, cudagraph backward passes will be graph captured with 'retain_grad=True' @@ -1722,64 +1722,46 @@ def __post_init__(self): ) if self.recompute_granularity: - if self.recompute_granularity != "selective" or not self.cuda_graph_scope: - raise ValueError( - "Full-layer CUDA graphs not supported with activation recomputation." - ) - elif self.cuda_graph_scope != [CudaGraphScope.full_iteration]: - # For scoped CUDA graphs, only the non-graphed parts of the layer can be - # recomputed. So check if there are overlaps between the recomputed parts - # and the graphed parts. - if CudaGraphScope.attn in self.cuda_graph_scope: - for module in self.recompute_modules: - if module in ['core_attn', 'mla_up_proj']: - raise ValueError( - f'attn cuda graph is not supported with {module} recompute.' - ) + if self.recompute_granularity != "selective": + assert self.cuda_graph_scope == [ + CudaGraphScope.full_iteration + ], "full recompute is only supported with full iteration CUDA graph." + else: + # The recompute module should be inside or outside of the graph scope. + # Recompute module covering graph scope is not allowed. + if "moe" in self.recompute_modules: + assert ( + CudaGraphScope.moe_router not in self.cuda_graph_scope + ), "moe recompute is not supported with moe_router CUDA graph." + # Graphed recompute module doesn't accept random number. 
if ( - CudaGraphScope.mlp in self.cuda_graph_scope - and "mlp" in self.recompute_modules + not self.cuda_graph_scope + or CudaGraphScope.full_iteration in self.cuda_graph_scope ): - raise ValueError(f'mlp cuda graph is not supported with mlp recompute.') - if CudaGraphScope.moe in self.cuda_graph_scope: - for module in self.recompute_modules: - if module in ['moe_act', 'moe', 'shared_experts']: - raise ValueError( - f'moe cuda graph is not supported with {module} recompute.' - ) - if CudaGraphScope.moe_router in self.cuda_graph_scope: - for module in self.recompute_modules: - if module in ['moe', 'shared_experts']: - raise ValueError( - f'moe_router cuda graph is not supported with {module} ' - 'recompute.' - ) - if "layernorm" in self.recompute_modules: - if ( - CudaGraphScope.attn in self.cuda_graph_scope - and CudaGraphScope.mlp in self.cuda_graph_scope - and ( - CudaGraphScope.moe in self.cuda_graph_scope - or CudaGraphScope.moe_router in self.cuda_graph_scope - ) - ): - raise ValueError( - 'cuda graph is not supported with layernorm recompute.' - ) - if CudaGraphScope.attn in self.cuda_graph_scope: - warnings.warn( - "input_layernorm recompute is not supported with attention " - "cudagraph. Will only recompute the pre_mlp_layernorm." - ) - if ( - CudaGraphScope.mlp in self.cuda_graph_scope - or CudaGraphScope.moe in self.cuda_graph_scope - or CudaGraphScope.moe_router in self.cuda_graph_scope - ): - warnings.warn( - "pre_mlp_layernorm recompute is not supported with mlp/moe " - "cudagraph. Will only recompute the input_layernorm." - ) + full_cudagraph = True + else: + full_cudagraph = False + if self.attention_dropout != 0.0: + assert ( + not full_cudagraph and CudaGraphScope.attn not in self.cuda_graph_scope + ) or "core_attn" not in self.recompute_modules, ( + "attention dropout is not supported with graphed attention " + "recomputation." 
+ ) + if self.hidden_dropout != 0.0: + assert ( + (not full_cudagraph and CudaGraphScope.mlp not in self.cuda_graph_scope) + or "mlp" not in self.recompute_modules + ) and ( + (not full_cudagraph and CudaGraphScope.moe not in self.cuda_graph_scope) + or "moe" not in self.recompute_modules + ), "hidden dropout is not supported with graphed MLP/MoE recomputation." + if self.moe_input_jitter_eps is not None: + assert ( + not full_cudagraph and CudaGraphScope.moe not in self.cuda_graph_scope + ) or "moe" not in self.recompute_modules, ( + "moe_input_jitter_eps is not supported with graphed moe recomputation." + ) if self.moe_token_dispatcher_type in ["allgather"]: if self.variable_seq_lengths is True: diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 3ea40577009..7bfaaaf7f0e 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -380,24 +380,55 @@ def __init__( self.recompute_mlp = False if self.config.recompute_granularity == 'selective': if "layernorm" in self.config.recompute_modules: - if not isinstance(self.input_layernorm, IdentityOp) and ( - self.config.cuda_graph_impl == "none" - or CudaGraphScope.attn not in self.config.cuda_graph_scope - ): + if not isinstance(self.input_layernorm, IdentityOp): self.recompute_input_layernorm = True if self.config.fp8 or self.config.fp4: self.self_attention.set_for_recompute_input_layernorm() - if not isinstance(self.pre_mlp_layernorm, IdentityOp) and ( - self.config.cuda_graph_impl == "none" - or ( + + def can_recompute_pre_mlp_layernorm_for_cudagraph(): + if ( not self.is_moe_layer - and CudaGraphScope.mlp not in self.config.cuda_graph_scope - ) - or ( - self.is_moe_layer - and CudaGraphScope.moe not in self.config.cuda_graph_scope - and CudaGraphScope.moe_router not in self.config.cuda_graph_scope + or CudaGraphScope.moe_router not in self.config.cuda_graph_scope + ): + # Not a MoE layer, or not 
capturing the router part. + return True + if ( + self.config.moe_shared_expert_intermediate_size is not None + and self.config.moe_shared_expert_overlap + ): + # If shared expert overlap is used, we cannot make the pre-mlp layernorm + # recomputation, because the shared expert takes the layernorm output as + # input, and it is outside of the CUDA graph scope. + log_single_rank( + logger, + logging.WARNING, + "pre_mlp_layernorm recompute is not supported with moe router " + "cudagraph + shared expert overlap. Disabling pre_mlp_layernorm " + "recompute.", + ) + return False + if ( + CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope + and self.config.moe_token_dispatcher_type == "alltoall" + ): + # Only when capturing the preprocess part and using alltoall token + # dispatcher can we make the pre-mlp layernorm recomputation. + # Because in other cases the layernorm output returns directly as one of the + # outputs of the cudagraph, which will be allocated a static buffer, thus + # not able to be released. + return True + log_single_rank( + logger, + logging.WARNING, + "pre_mlp_layernorm recompute is only supported with moe router + " + "preprocess cudagraph will alltoall token dispatcher. Disabling " + "pre_mlp_layernorm recompute.", ) + return False + + if ( + not isinstance(self.pre_mlp_layernorm, IdentityOp) + and can_recompute_pre_mlp_layernorm_for_cudagraph() ): self.recompute_pre_mlp_layernorm = True if self.config.fp8 or self.config.fp4: @@ -632,20 +663,7 @@ def _forward_mlp(self, hidden_states, inference_context=None): and not isinstance(self.mlp, IdentityOp) ) - if ( - self.is_moe_layer - and self.config.cuda_graph_impl == "transformer_engine" - and self.training - and is_graph_capturing() - and CudaGraphScope.moe_router in self.config.cuda_graph_scope - ): - assert ( - not self.recompute_pre_mlp_layernorm - ), "Recomputation is not supported for CUDA graph." 
- cudagraph_outputs = self.mlp(pre_mlp_layernorm_output) - nvtx_range_pop(suffix="mlp") - return cudagraph_outputs + [residual] - elif self.recompute_mlp: + if self.recompute_mlp: if self.config.fp8 or self.config.fp4: # import here to avoid circular import from megatron.core.extensions.transformer_engine import te_checkpoint @@ -685,7 +703,23 @@ def _forward_mlp(self, hidden_states, inference_context=None): ) nvtx_range_pop(suffix="mlp") - return self._forward_post_mlp(mlp_output_with_bias, residual) + if ( + self.is_moe_layer + and self.config.cuda_graph_impl == "transformer_engine" + and self.training + and is_graph_capturing() + and CudaGraphScope.moe_router in self.config.cuda_graph_scope + ): + if self.recompute_pre_mlp_layernorm: + # Register the recompute hooks on all the cudagraph output tensors, because some + # tensors are in parallel execution paths and they all need pre_mlp_layernorm to be + # recomputed in backward pass. For example, the router path and the shared expert + # path. Registering the hook on only one path is therefore risky. + for tensor in mlp_output_with_bias[1:]: + self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(tensor) + return list(mlp_output_with_bias) + [residual] + else: + return self._forward_post_mlp(mlp_output_with_bias, residual) def _forward_post_mlp(self, mlp_output_with_bias, residual): """ @@ -875,20 +909,19 @@ def _te_cuda_graph_replay(self, *args, **kwargs): elif self.is_moe_layer and CudaGraphScope.moe_router in self.config.cuda_graph_scope: # CUDA Graph partially captures the MoE. # The rest of the layer should go to the normal pass. - shared_expert_output, routing_map, residual = None, None, None - mlp_residual = cuda_graph_output.pop() + shared_expert_output, routing_map = None, None + # residual is the last element in the CUDA graph output. 
+ residual = cuda_graph_output.pop() if ( self.config.moe_shared_expert_intermediate_size is not None and not self.config.moe_shared_expert_overlap ): - # The shared expert output is the fourth element in the CUDA graph output. + # The shared expert output is the second-to-last element in the CUDA graph output. shared_expert_output = cuda_graph_output.pop() - # Split cudagraph outputs into function outputs and attribute outputs, and - # process them separately. Function outputs should have three tensors. - func_output, attr_outputs = cuda_graph_output[:3], cuda_graph_output[3:] if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope: - hidden_states, probs, residual = func_output + # CUDA graph output is [hidden_states, probs] + attribute outputs. + (hidden_states, probs), attr_outputs = cuda_graph_output[:2], cuda_graph_output[2:] valid_cudagraph_attrs = self.mlp.token_dispatcher.valid_cudagraph_attrs assert len(attr_outputs) == len( valid_cudagraph_attrs @@ -900,8 +933,12 @@ def _te_cuda_graph_replay(self, *args, **kwargs): attr = getattr(attr, name) setattr(attr, hier_attr_name[-1], attr_outputs[i]) else: - hidden_states, probs, routing_map = func_output - assert not attr_outputs, "cuda_graph_attr_outputs should be empty" + # CUDA graph output is [hidden_states, probs, routing_map]. + assert len(cuda_graph_output) == 3, ( + "CUDA graph output should be [hidden_states, probs, routing_map], " + f"but got {len(cuda_graph_output)} elements" + ) + hidden_states, probs, routing_map = cuda_graph_output # Resume the MoELayer forward pass from the end of the CUDA graph scope. 
# The MoE layer will skip redundant computations when we pass in the calculated values @@ -911,14 +948,13 @@ def _te_cuda_graph_replay(self, *args, **kwargs): hidden_states=hidden_states, probs=probs, routing_map=routing_map, - residual=residual, shared_expert_output=shared_expert_output, ) mlp_output_with_bias = self.mlp(hidden_states) self.mlp.cudagraph_tensor_store.clear() nvtx_range_pop(suffix="mlp") - output = self._forward_post_mlp(mlp_output_with_bias, mlp_residual) + output = self._forward_post_mlp(mlp_output_with_bias, residual) else: # CUDA Graph does not capture the MLP/MoE part at all. output = self._forward_mlp(*cuda_graph_output) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index c157d062c53..6f521ed109f 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1290,9 +1290,6 @@ def validate_args(args, defaults={}): "Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using " "CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True." ) - assert ( - args.recompute_granularity != 'full' - ), 'recompute_granularity must not be full when CUDA Graphs are enabled.' 
if args.cuda_graph_scope == "full" or ( isinstance(args.cuda_graph_scope, list) and "full" in args.cuda_graph_scope ): From a72fb61a9103f90098beaa3634f38e47245f35e4 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Thu, 18 Dec 2025 00:44:24 -0800 Subject: [PATCH 2/4] fix backward compatibility Signed-off-by: Robin Zhang --- megatron/core/models/gpt/fine_grained_callables.py | 4 ++-- megatron/core/transformer/moe/moe_layer.py | 2 ++ megatron/core/transformer/moe/moe_utils.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 60094976a9a..f9296a934bc 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -384,7 +384,7 @@ def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor) pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output) - local_tokens, probs, _ = layer.mlp.preprocess(pre_mlp_layernorm_output, probs, routing_map) + local_tokens, probs = layer.mlp.preprocess(pre_mlp_layernorm_output, probs, routing_map) # Detach here for mlp_bda residual connection node.layer_state.residual = node.detach(hidden_states) @@ -426,7 +426,7 @@ def submodule_moe_forward(node: ScheduleNode, dispatched_tokens: torch.Tensor): pre_mlp_layernorm_output = getattr(node.layer_state, 'pre_mlp_layernorm_output', None) shared_expert_output = layer.mlp.shared_experts_compute(pre_mlp_layernorm_output) expert_output, mlp_bias = layer.mlp.routed_experts_compute( - dispatched_tokens, dispatched_probs, pre_mlp_layernorm_output + dispatched_tokens, dispatched_probs ) if layer.recompute_pre_mlp_layernorm: diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 9d4923b4aa6..51d69ba31b6 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ 
b/megatron/core/transformer/moe/moe_layer.py @@ -24,6 +24,7 @@ ) from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import internal_api try: import transformer_engine as te # pylint: disable=unused-import @@ -238,6 +239,7 @@ def shared_experts_compute(self, hidden_states: torch.Tensor): return shared_expert_output + @internal_api def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tensor): """Computes the output of the routed experts on the dispatched tokens. diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 537ef53bc90..b855fac92e5 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -1116,6 +1116,7 @@ def get_early_return_outputs( return outputs +@internal_api @dataclass class MoECudaGraphTensorStore: """Storage for tensors used in CUDA graph replay for MoE layers. 
From 48337a3fe79e11aba81c93d8b6f2a26e1802f735 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Mon, 5 Jan 2026 18:23:42 -0800 Subject: [PATCH 3/4] latent moe cudagraph for layernorm recompute --- megatron/core/transformer/transformer_layer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 7bfaaaf7f0e..2ac60f00eff 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -407,12 +407,12 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): "recompute.", ) return False - if ( - CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope - and self.config.moe_token_dispatcher_type == "alltoall" + if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope and ( + self.config.moe_token_dispatcher_type == "alltoall" + or self.config.moe_latent_size ): # Only when capturing the preprocess part and using alltoall token - # dispatcher can we make the pre-mlp layernorm recomputation. + # dispatcher or latent MoE can we make the pre-mlp layernorm recomputation. # Because in other cases the layernorm output returns directly as one of the # outputs of the cudagraph, which will be allocated a static buffer, thus # not able to be released. @@ -421,8 +421,8 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): logger, logging.WARNING, "pre_mlp_layernorm recompute is only supported with moe router + " - "preprocess cudagraph will alltoall token dispatcher. Disabling " - "pre_mlp_layernorm recompute.", + "preprocess cudagraph will alltoall token dispatcher or latent MoE. 
" + "Disabling pre_mlp_layernorm recompute.", ) return False From 2b474a5f728a7f929f2d188061c283a78af74e74 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Tue, 6 Jan 2026 21:51:20 -0800 Subject: [PATCH 4/4] fix lint Signed-off-by: Robin Zhang --- megatron/core/models/gpt/fine_grained_callables.py | 4 +--- megatron/core/transformer/transformer_layer.py | 8 ++------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 61a5700b4f5..71c5c19749c 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -519,9 +519,7 @@ def submodule_moe_forward(node: ScheduleNode, dispatched_tokens: torch.Tensor): # backward graph from connecting to dispatch submodule token_dispatcher._comm_manager.dispatched_probs = dispatched_probs - expert_output, _ = layer.mlp.routed_experts_compute( - dispatched_tokens, dispatched_probs - ) + expert_output, _ = layer.mlp.routed_experts_compute(dispatched_tokens, dispatched_probs) if layer.recompute_pre_mlp_layernorm: # discard the output of the pre-mlp layernorm and register the recompute diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index afc01a66d2d..ce90aaf357a 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -974,9 +974,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): # and should be skipped here. 
if self.config.overlap_moe_expert_parallel_comm: probs, routing_map = self.mlp.route(hidden_states) - hidden_states, probs = self.mlp.preprocess( - hidden_states, probs, routing_map - ) + hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) nvtx_range_pop(suffix="mlp") return residual, hidden_states, probs, shared_expert_output mlp_output_with_bias = self.mlp(hidden_states) @@ -994,9 +992,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): hidden_states = self.pre_mlp_layernorm(residual) shared_expert_output = self.mlp.shared_experts_compute(hidden_states) probs, routing_map = self.mlp.route(hidden_states) - hidden_states, probs = self.mlp.preprocess( - hidden_states, probs, routing_map - ) + hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) return residual, hidden_states, probs, shared_expert_output # CUDA Graph does not capture the MLP/MoE part at all.