diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index e82397b854f..db5c8936022 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import asyncio import concurrent @@ -659,7 +659,7 @@ def generate_all_output_tokens_static_batch( # Check whether CUDA graphs are enabled enable_cuda_graph = ( model_config.cuda_graph_impl == "local" - and model_config.cuda_graph_scope != "full_iteration" + and "full_iteration" not in model_config.cuda_graph_scope ) # Pad batch tokens if necessary diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py index d855322c2df..8f90fb3ba47 100644 --- a/megatron/core/models/common/language_module/language_module.py +++ b/megatron/core/models/common/language_module/language_module.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import logging import os from typing import Optional, Tuple @@ -136,7 +136,7 @@ def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: # Use is_cg_capturable=True for full iteration CUDA graphs to avoid torch.equal checks is_cg_capturable = ( hasattr(self.config, 'cuda_graph_scope') - and self.config.cuda_graph_scope == 'full_iteration' + and 'full_iteration' in self.config.cuda_graph_scope ) if is_cg_capturable and not is_te_min_version("2.7.0"): from megatron.core.utils import get_te_version diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index fd1cc3d33c6..6755dd8e8d5 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -358,7 +358,8 @@ def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor) else: pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) - local_tokens, probs, _ = layer.mlp.router_and_preprocess(pre_mlp_layernorm_output) + probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output) + local_tokens, probs, _ = layer.mlp.preprocess(pre_mlp_layernorm_output, probs, routing_map) # Detach here for mlp_bda residual connection node.layer_state.residual = node.detach(hidden_states) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 654827dc6fb..e134920ad41 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -371,7 +371,7 @@ def _preprocess( and ( ( self.config.cuda_graph_impl == "local" - and self.config.cuda_graph_scope != "full_iteration" + and "full_iteration" not in self.config.cuda_graph_scope ) or self.config.flash_decode ) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index e83f8d90635..48f4a483e53 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -648,7 +648,7 @@ def forward_backward_no_pipelining( if ( hasattr(config, 'cuda_graph_impl') and config.cuda_graph_impl == "local" - and config.cuda_graph_scope != "full_iteration" + and "full_iteration" not in config.cuda_graph_scope ): create_cudagraphs() @@ -1912,7 +1912,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): if ( hasattr(config, 'cuda_graph_impl') and config.cuda_graph_impl == "local" - and config.cuda_graph_scope != "full_iteration" + and "full_iteration" not in config.cuda_graph_scope ): create_cudagraphs() nvtx_range_pop(suffix="misc") @@ -2296,7 +2296,7 @@ def enable_grad_sync(): if ( hasattr(config, 'cuda_graph_impl') and config.cuda_graph_impl == "local" - and config.cuda_graph_scope != "full_iteration" + and "full_iteration" not in config.cuda_graph_scope ): create_cudagraphs() diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index 01b9f4eac66..8e23a3b2aae 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2024, Tri Dao, Albert Gu. # Some of this code was adopted from https://github.com/state-spaces/mamba/ @@ -301,7 +301,7 @@ def forward( ( ( self.config.cuda_graph_impl == "local" - and self.config.cuda_graph_scope != "full_iteration" + and "full_iteration" not in self.config.cuda_graph_scope ) or self.config.flash_decode ) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index d4e990041ca..0a63f8f728d 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -782,7 +782,7 @@ def forward( if ( in_decode_mode and self.config.cuda_graph_impl == "local" - and self.config.cuda_graph_scope != "full_iteration" + and "full_iteration" not in self.config.cuda_graph_scope and inference_context.is_static_batching() ): raise ValueError(f"CUDA graphs must use flash decode with static batching!") diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index f75eff7399a..4d6893908ed 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import gc import inspect @@ -22,7 +22,7 @@ get_cuda_rng_tracker, ) from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import ( get_attr_wrapped_model, @@ -1059,9 +1059,12 @@ def __init__( ), "RNG tracker does not support cudagraphs!" assert config.cuda_graph_impl == "local", "Option cuda_graph_impl=local not enabled." - assert "expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", ""), ( - "expandable_segments:True may not be safe when using CUDA Graphs, and may result in" - "a crash due to illegal memory access or other undefined behaviour." + assert ( + "expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", "") + or os.getenv("NCCL_GRAPH_REGISTER", "") == "0" + ), ( + "Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using " + "CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True." ) self.cudagraph_runners = [] @@ -1311,23 +1314,40 @@ def _layer_is_graphable(layer, config): Check if a layer is graphable. """ + # Only GraphableMegatronModule can be graphed. + if not isinstance(layer, GraphableMegatronModule): + return False + + # If cuda_graph_scope is not set, every layer is graphed. + if not config.cuda_graph_scope: + return True + # import modules here to avoid a circular import from megatron.core.ssm.mamba_layer import MambaLayer from megatron.core.transformer.identity_op import IdentityOp + from megatron.core.transformer.mlp import MLP + from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.transformer_layer import TransformerLayer - if isinstance(layer, MambaLayer) and config.cuda_graph_scope == "full": + if isinstance(layer, MambaLayer) and 'mamba' in config.cuda_graph_scope: # mamba layer. return True if isinstance(layer, TransformerLayer): - if config.cuda_graph_scope == 'attn': - if not ( - isinstance(layer.self_attention, IdentityOp) - and isinstance(layer.cross_attention, IdentityOp) - ): - # attn layer. - return True - else: + if 'attn' in config.cuda_graph_scope and not ( + isinstance(layer.self_attention, IdentityOp) + and isinstance(layer.cross_attention, IdentityOp) + ): + # attn layer. + return True + if ( + 'moe' in config.cuda_graph_scope + or 'moe_router' in config.cuda_graph_scope + or 'moe_preprocess' in config.cuda_graph_scope + ) and isinstance(layer.mlp, MoELayer): + # moe layer. + return True + if 'mlp' in config.cuda_graph_scope and isinstance(layer.mlp, MLP): + # mlp layer. return True return False @@ -1346,18 +1366,17 @@ def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]): assert ( config.cuda_graph_impl == "transformer_engine" ), "Option cuda_graph_impl=transformer_engine not enabled." - assert "expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", ""), ( - "expandable_segments:True may not be safe when using CUDA Graphs, and may result in" - "a crash due to illegal memory access or other undefined behaviour." + assert ( + "expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", "") + or os.getenv("NCCL_GRAPH_REGISTER", "") == "0" + ), ( + "Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using " + "CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True." ) - assert config.cuda_graph_scope != "full_iteration", ( + assert "full_iteration" not in config.cuda_graph_scope, ( "full_iteration cuda graph is not supported for cuda_graph_impl=transformer_engine. " "Please use cuda_graph_impl=local instead." ) - assert config.cuda_graph_scope in [ - 'full', - 'attn', - ], f"--cuda-graph-scope should be full or attn, got {config.cuda_graph_scope}." self.model = model self.config = config @@ -1440,6 +1459,16 @@ def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]): f'{len(self.flattened_callables)} graphable layers.', ) + # One helper object can only capture CUDA Graphs once. Use this flag to check if the graphs + # have been created. + self._graphs_created = False + + def graphs_created(self): + """ + Returns whether the CUDA Graphs have been created. + """ + return self._graphs_created + def _get_cuda_graph_input_data(self): """ Create the CUDA Graph capturing input data. @@ -1480,8 +1509,13 @@ def get_rotary_pos_emb(transformer_module, transformer_input): from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.transformer_layer import TransformerLayer - contains_self_attn = isinstance(layer, TransformerLayer) and not isinstance( - layer.self_attention, IdentityOp + contains_self_attn = ( + isinstance(layer, TransformerLayer) + and not isinstance(layer.self_attention, IdentityOp) + and ( + not self.config.cuda_graph_scope + or 'attn' in self.config.cuda_graph_scope + ) ) if is_te_min_version("1.10.0"): # te.make_graphed_callables() accepts keyword arguments since 1.10.0. @@ -1590,6 +1624,8 @@ def _start_capturing(self): """ Start capturing CUDA Graphs. """ + assert not self._graphs_created, "CUDA Graphs have already been created." + torch.distributed.barrier() gc.collect() torch.cuda.empty_cache() @@ -1623,6 +1659,8 @@ def _finish_capturing(self, start_time): gc.collect() torch.cuda.empty_cache() + self._graphs_created = True + def create_cudagraphs(self): """ Capture CUDA Graphs per TransformerLayer per microbatch. diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 4ad83ce4a8f..893b2e7b99a 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. from abc import ABC, abstractmethod from dataclasses import dataclass @@ -9,7 +9,12 @@ from megatron.core import parallel_state, tensor_parallel, utils from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.moe.moe_utils import get_default_pg_collection +from megatron.core.transformer.moe.moe_utils import ( + MoECudaGraphPartialCaptureSignal, + MoECudaGraphTensorStore, + get_default_pg_collection, + maybe_skip_or_early_return_by_cudagraph, +) from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.moe.token_dispatcher import ( MoEAllGatherTokenDispatcher, @@ -169,16 +174,29 @@ def __init__( if self.shared_expert_overlap: self.token_dispatcher.set_shared_experts(self.shared_experts) - def router_and_preprocess(self, hidden_states: torch.Tensor): - """Compute and preprocess token routing for dispatch. + # Cudagraph tensor store for resuming the forward pass from the end of the cudagraph. + self.cudagraph_tensor_store = MoECudaGraphTensorStore() + + @maybe_skip_or_early_return_by_cudagraph("route") + def route(self, hidden_states: torch.Tensor): + """Compute token routing for preprocessing. This method uses the router to determine which experts to send each token to, - producing routing probabilities and a mapping. It then preprocesses the - hidden states and probabilities for the token dispatcher. The original - hidden states are returned as a residual connection. + producing routing probabilities and a mapping. """ - residual = hidden_states probs, routing_map = self.router(hidden_states) + return probs, routing_map + + @maybe_skip_or_early_return_by_cudagraph("preprocess") + def preprocess( + self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor + ): + """Preprocess token routing for dispatch. + + This method preprocesses the hidden states and routing probabilities for the token + dispatcher. The original hidden states are returned as a residual connection. + """ + residual = hidden_states hidden_states, probs = self.token_dispatcher.dispatch_preprocess( hidden_states, routing_map, probs ) @@ -186,12 +204,14 @@ def router_and_preprocess(self, hidden_states: torch.Tensor): def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor): """Dispatches tokens to assigned expert ranks via communication. + This method performs the actual communication (e.g., All-to-All) to distribute tokens and their associated probabilities to the devices hosting their assigned experts. """ return self.token_dispatcher.token_dispatch(hidden_states, probs) + @maybe_skip_or_early_return_by_cudagraph("shared_experts_compute") def shared_experts_compute(self, hidden_states: torch.Tensor): """Computes the output of the shared experts. @@ -273,8 +293,18 @@ def forward(self, hidden_states: torch.Tensor): # MoE forward: route -> dispatch -> compute -> combine def custom_forward(hidden_states): - shared_expert_output = self.shared_experts_compute(hidden_states) - hidden_states, probs, residual = self.router_and_preprocess(hidden_states) + try: + shared_expert_output = self.shared_experts_compute(hidden_states) + probs, routing_map = self.route(hidden_states) + hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map) + except MoECudaGraphPartialCaptureSignal as e: + # This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator. + # It means we should early-return from the MoE layer forward pass. + # This happens when we are partially capturing the CUDA graph of the MoE layer, + # like cuda_graph_scope=["moe_router", "moe_preprocess"]. + # We need to return the intermediate tensors as CUDA graph outputs. + return e.get_early_return_outputs(hidden_states, shared_expert_output) + dispatched_input, probs = self.dispatch(hidden_states, probs) output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual) output = self.combine(output, shared_expert_output) @@ -282,7 +312,7 @@ def custom_forward(hidden_states): if self.moe_layer_recompute: if self.config.fp8: - output, mlp_bias = te_checkpoint( + outputs = te_checkpoint( custom_forward, False, tensor_parallel.random.get_cuda_rng_tracker, @@ -290,11 +320,11 @@ def custom_forward(hidden_states): hidden_states, ) else: - output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states) + outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states) else: - output, mlp_bias = custom_forward(hidden_states) + outputs = custom_forward(hidden_states) - return output, mlp_bias + return outputs def backward_dw(self): """Compute weight gradients for experts and shared experts.""" diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index dc857129834..408dd565cd4 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -1,12 +1,14 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import math +from dataclasses import dataclass from typing import List, Optional, Union import torch from megatron.core import parallel_state from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.cuda_graphs import is_graph_capturing try: import transformer_engine as te # pylint: disable=unused-import @@ -905,12 +907,16 @@ class RandomSTE(torch.autograd.Function): """ generator = None + random_logits = None @staticmethod def forward(ctx, logits): """ Forward pass returns random logits with rank-specific seed. """ + if is_graph_capturing() and RandomSTE.random_logits is not None: + return RandomSTE.random_logits + if RandomSTE.generator is None: global_rank = torch.distributed.get_rank() base_seed = 42 @@ -918,8 +924,8 @@ def forward(ctx, logits): RandomSTE.generator = torch.Generator(device=logits.device) RandomSTE.generator.manual_seed(seed) - random_logits = logits.clone().normal_(generator=RandomSTE.generator) - return random_logits + RandomSTE.random_logits = logits.clone().normal_(generator=RandomSTE.generator) + return RandomSTE.random_logits @staticmethod def backward(ctx, grad_output): @@ -1028,3 +1034,242 @@ def get_default_pg_collection(): with_context_parallel=True ) return pg_collection + + +class MoECudaGraphPartialCaptureSignal(Exception): + """ + Used to early-return from a MoE layer forward pass in CUDA graph capture. + This signal is raised when we are partially capturing the CUDA graph of the MoE layer, + and the related intermediate tensors are recorded in self.kwargs. + Call self.get_early_return_outputs() to collect the CUDA graph outputs. + """ + + def __init__(self, moe_layer, return_step: str, **kwargs): + self.moe_layer = moe_layer + self.return_step = return_step + self.kwargs = kwargs + + def get_early_return_outputs( + self, hidden_states: torch.Tensor, shared_expert_output: torch.Tensor + ): + """ + Get the CUDA graph early return outputs for the MoE layer, including the intermediate + tensors and the intermediate attributes of the token dispatcher. + """ + if self.return_step == "route": + # Capturing the router step returns three intermediate tensors: + # hidden states, routing probabilities, and routing map. + outputs = [hidden_states, self.kwargs['probs'], self.kwargs['routing_map']] + elif self.return_step == "preprocess": + # Capturing the preprocess step returns three intermediate tensors: + # hidden states, routing probabilities, and residual connection. + # It also returns the intermediate attributes of the token dispatcher, recorded in + # "token_dispatcher.cudagraph_attrs". + outputs = [self.kwargs['hidden_states'], self.kwargs['probs'], self.kwargs['residual']] + valid_cudagraph_attrs = [] + for attr_name in self.moe_layer.token_dispatcher.cudagraph_attrs: + hier_attr_name = attr_name.split('.') + attr = self.moe_layer.token_dispatcher + for name in hier_attr_name: + attr = getattr(attr, name, None) + if attr is None: + break + if isinstance(attr, torch.Tensor): + outputs.append(attr) + valid_cudagraph_attrs.append(attr_name) + if self.moe_layer.token_dispatcher.valid_cudagraph_attrs is None: + self.moe_layer.token_dispatcher.valid_cudagraph_attrs = valid_cudagraph_attrs + else: + assert ( + self.moe_layer.token_dispatcher.valid_cudagraph_attrs == valid_cudagraph_attrs + ), ( + "valid_cudagraph_attrs mismatch: " + f"{self.moe_layer.token_dispatcher.valid_cudagraph_attrs} != " + f"{valid_cudagraph_attrs}" + ) + # Also return the shared expert output, if it is not None. + if shared_expert_output is not None: + outputs.append(shared_expert_output) + return outputs + + +@dataclass +class MoECudaGraphTensorStore: + """Storage for tensors used in CUDA graph replay for MoE layers. + + This dataclass stores intermediate tensors computed during CUDA graph replay + that need to be resumed from the end of the CUDA graph scope to skip redundant computations. + + Attributes: + hidden_states (Optional[torch.Tensor]): The hidden states output from the CUDA graph replay. + probs (Optional[torch.Tensor]): The routing probabilities for each token-expert pair. + routing_map (Optional[torch.Tensor]): The sparse mapping indicating which experts + were selected for each token. Used to skip the normal router step. + residual (Optional[torch.Tensor]): The residual connection tensor before routing. + Used to skip the normal preprocess step. + shared_expert_output (Optional[torch.Tensor]): The output from shared experts + computation. Used to skip the normal shared expert computation step. + """ + + hidden_states: Optional[torch.Tensor] = None + probs: Optional[torch.Tensor] = None + routing_map: Optional[torch.Tensor] = None + residual: Optional[torch.Tensor] = None + shared_expert_output: Optional[torch.Tensor] = None + + def is_empty(self) -> bool: + """Check if the store has any non-None tensors. + + Returns: + bool: True if all fields are None, False otherwise. + """ + return all( + getattr(self, field_name) is None + for field_name in [ + 'hidden_states', + 'probs', + 'routing_map', + 'residual', + 'shared_expert_output', + ] + ) + + def set(self, **kwargs): + """Set the tensors in the store from keyword arguments.""" + for field_name, value in kwargs.items(): + assert field_name in [ + 'hidden_states', + 'probs', + 'routing_map', + 'residual', + 'shared_expert_output', + ], f"Invalid field name: {field_name}" + if value is not None: + assert isinstance( + value, torch.Tensor + ), f"Value must be a torch.Tensor, got {type(value)} for field {field_name}" + setattr(self, field_name, value) + + def clear(self): + """Reset all stored tensors to None.""" + for field_name in [ + 'hidden_states', + 'probs', + 'routing_map', + 'residual', + 'shared_expert_output', + ]: + setattr(self, field_name, None) + + +def maybe_skip_or_early_return_by_cudagraph(step_condition): + """ + Decorator to skip certain codepaths in the MoE layer forward pass in CUDA graph replay, + or early return from the MoE layer forward pass in CUDA graph capture. + + Args: + step_condition: The step condition to check. Can be "shared_experts_compute", "route", + or "preprocess". If "shared_experts_compute", the shared experts computation will be + skipped in replay if it is in the CUDA graph scope. If "route" or "preprocess", the + router or preprocess will be skipped in replay if it is in the CUDA graph scope, or + early return from the MoE layer forward pass if it is in CUDA graph capturing mode. + + Returns: + A decorator function that wraps the MoE layer forward pass. + """ + + def maybe_raise_signal(moe_layer, **kwargs): + """ + Check if the MoE layer should early return for CUDA graph capture. + If so, raise a MoECudaGraphPartialCaptureSignal. + """ + if ( + moe_layer.config.cuda_graph_impl == "transformer_engine" + and moe_layer.training + and is_graph_capturing() + ): + if ( + step_condition == "route" + and 'moe_router' in moe_layer.config.cuda_graph_scope + and 'moe_preprocess' not in moe_layer.config.cuda_graph_scope + ): + raise MoECudaGraphPartialCaptureSignal(moe_layer, "route", **kwargs) + elif ( + step_condition == "preprocess" + and 'moe_preprocess' in moe_layer.config.cuda_graph_scope + ): + raise MoECudaGraphPartialCaptureSignal(moe_layer, "preprocess", **kwargs) + + def decorator(func): + def wrapped_func(moe_layer, *args, **kwargs): + """ + Check if we should skip executing the original function based on the current + step condition and the tensor store status. If the tensor can be found in the store, + it indicates that it is already computed by the CUDA graph replay, so we can skip it. + Otherwise, we execute the original function and check if we should raise a signal to + early return in CUDA graph capture. + """ + # The non-cudagraph codepath just calls the original function. + if not is_graph_capturing() and moe_layer.cudagraph_tensor_store.is_empty(): + return func(moe_layer, *args, **kwargs) + + assert ( + not is_graph_capturing() or moe_layer.cudagraph_tensor_store.is_empty() + ), "cudagraph_tensor_store cannot be used when it is capturing cuda graph." + if step_condition == "shared_experts_compute": + if moe_layer.cudagraph_tensor_store.shared_expert_output is None: + # Don't skip the shared expert computation. + shared_expert_output = func(moe_layer, *args, **kwargs) + else: + # Skip the shared expert computation and get value from store. + shared_expert_output = moe_layer.cudagraph_tensor_store.shared_expert_output + return shared_expert_output + elif step_condition == "route": + if moe_layer.cudagraph_tensor_store.probs is None: + # Don't skip the router. + assert ( + moe_layer.cudagraph_tensor_store.routing_map is None + and moe_layer.cudagraph_tensor_store.residual is None + ), "both routing_map and residual must be None if probs is None" + probs, routing_map = func(moe_layer, *args, **kwargs) + + # Maybe early return after the router. + maybe_raise_signal(moe_layer, probs=probs, routing_map=routing_map) + else: + # Skip the router and get value from store. + assert ( + moe_layer.cudagraph_tensor_store.routing_map is not None + or moe_layer.cudagraph_tensor_store.residual is not None + ), "either routing_map or residual must be given if probs is given" + probs, routing_map = ( + moe_layer.cudagraph_tensor_store.probs, + moe_layer.cudagraph_tensor_store.routing_map, + ) + return probs, routing_map + elif step_condition == "preprocess": + if moe_layer.cudagraph_tensor_store.residual is None: + # Don't skip the preprocess. + hidden_states, probs, residual = func(moe_layer, *args, **kwargs) + + # Maybe early return after the preprocess. + maybe_raise_signal( + moe_layer, hidden_states=hidden_states, probs=probs, residual=residual + ) + else: + # Skip the preprocess and get value from store. + assert ( + moe_layer.cudagraph_tensor_store.probs is not None + ), "probs must not be None if residual is not None" + assert ( + moe_layer.cudagraph_tensor_store.routing_map is None + ), "routing_map must be None if residual is not None" + hidden_states, probs, residual = ( + moe_layer.cudagraph_tensor_store.hidden_states, + moe_layer.cudagraph_tensor_store.probs, + moe_layer.cudagraph_tensor_store.residual, + ) + return hidden_states, probs, residual + + return wrapped_func + + return decorator diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 46f94ebe79a..1711c2b5d1b 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -76,6 +76,11 @@ def __init__( self.tp_rank = utils.get_pg_rank(self.tp_group) self.ep_size = utils.get_pg_size(self.ep_group) + # Attributes that need to be captured in cudagraph. These attributes are returned + # as cudagraph outputs when the cuda_graph_scope contains moe_preprocess. + self.cudagraph_attrs = [] + self.valid_cudagraph_attrs = None + @abstractmethod def dispatch_preprocess( self, tokens: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor @@ -241,6 +246,10 @@ def __init__( # device token permutation is enabled and **AllGahter** is performed. self.global_local_map = None + # Attributes that need to be captured in cudagraph. These attributes are returned + # as cudagraph outputs when the cuda_graph_scope contains moe_preprocess. + self.cudagraph_attrs = ['routing_map'] + def dispatch_preprocess( self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor ): @@ -433,12 +442,38 @@ def __init__( "before_finish": 3, "no_sync": 4, } - self.cuda_dtoh_point = "before_permutation_1" + if ( + config.cuda_graph_impl == "transformer_engine" + and 'moe_preprocess' in config.cuda_graph_scope + ): + self.cuda_dtoh_point = "before_ep_alltoall" + else: + self.cuda_dtoh_point = "before_permutation_1" if MoEAlltoAllTokenDispatcher.cuda_dtoh_stream is None: MoEAlltoAllTokenDispatcher.cuda_dtoh_stream = torch.cuda.Stream() + # Attributes that need to be captured in cudagraph. These attributes are returned + # as cudagraph outputs when the cuda_graph_scope contains moe_preprocess. + self.cudagraph_attrs = [ + 'tokens_per_expert', + 'input_splits', + 'output_splits', + 'output_splits_tp', + 'num_out_tokens', + 'num_global_tokens_per_local_expert', + 'reversed_local_input_permutation_mapping', + 'routing_map', + ] + self.shared_experts = None + def set_shared_experts(self, shared_experts): + """Set shared expert to the dispatcher.""" + super().set_shared_experts(shared_experts) + if shared_experts.use_shared_expert_gate: + self.cudagraph_attrs.append('shared_experts.gate_score') + self.cudagraph_attrs.append('shared_experts.cached_fc1_input') + def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: """ Preprocesses the token routing map for All-to-All communication and token permutation. @@ -1340,6 +1375,7 @@ def __init__( num_experts=self.tp_size * self.config.num_moe_experts, config=self.config, ) + self.cudagraph_attrs = ['_comm_manager.token_probs', '_comm_manager.token_indices'] elif self.config.moe_flex_dispatcher_backend == "hybridep": self._comm_manager = _HybridEPManager( group=self.tp_ep_group, @@ -1347,6 +1383,7 @@ def __init__( num_experts=self.tp_size * self.config.num_moe_experts, config=self.config, ) + self.cudagraph_attrs = ['_comm_manager.token_probs', '_comm_manager.routing_map'] else: raise ValueError( f"Invalid backend: {self.config.moe_flex_dispatcher_backend}" diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index aead6133f22..59a3b3b086a 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -522,7 +522,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs): kwargs.get('inference_context') is not None or kwargs.get('inference_params') is not None ) - and self.config.cuda_graph_scope == 'full_iteration' + and 'full_iteration' in self.config.cuda_graph_scope ): if kwargs['inference_context'].is_static_batching(): using_cuda_graph = kwargs['inference_context'].is_decode_only() diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index d14f991046e..c95777c9652 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1,6 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import logging import warnings from dataclasses import dataclass from typing import Callable, List, Literal, Optional, Tuple, Union @@ -30,8 +29,6 @@ except ImportError: HAVE_PACKAGING = False -logger = logging.getLogger(__name__) - @dataclass class TransformerConfig(ModelParallelConfig): @@ -693,11 +690,10 @@ class TransformerConfig(ModelParallelConfig): excluding optimizer) is enabled. "transformer_engine": capture the CUDA graph using TE make_graphed_callables().""" - cuda_graph_scope: str = "full" + cuda_graph_scope: Optional[List[str]] = None """Determines the CUDA graphs capturing scope. - When cuda_graph_impl is set to "transformer_engine", valid values are "full" and "attn". - "Full" scope captures a whole Transformer layer. "Attn" scope only captures operations in - TransformerLayer._forward_attention(). + When cuda_graph_impl is set to "transformer_engine", valid values are "attn", "mlp", "moe", + "moe_router", "moe_preprocess", "mamba". None means the full layer. When cuda_graph_impl is set to "local", "full_iteration" can be specified as cuda_graph_scope to enable whole iteration CUDA graph. All other values enable layerwise CUDA graph.""" @@ -921,7 +917,7 @@ def __post_init__(self): if self.moe_enable_deepep: if self.moe_token_dispatcher_type != "flex": raise ValueError("DeepEP backend is only supported with flex token dispatcher.") - logger.warning( + warnings.warn( "moe_enable_deepep is deprecated." "Please use --moe-flex-dispatcher-backend=deepep instead." ) @@ -1513,24 +1509,133 @@ def __post_init__(self): ], f"Invalid cuda graph implementation: {self.cuda_graph_impl}" if self.cpu_offloading: raise ValueError("CUDA graphs not supported with CPU offloading.") - if self.recompute_granularity: - if ( - self.recompute_granularity != "selective" - or self.cuda_graph_impl != "transformer_engine" - or self.cuda_graph_scope != "attn" - ): - raise ValueError("CUDA graphs not supported with activation recomputation.") + + if self.cuda_graph_scope is None: + self.cuda_graph_scope = [] + elif not isinstance(self.cuda_graph_scope, list): + assert isinstance(self.cuda_graph_scope, str), ( + "cuda_graph_scope must be a string or a list of strings, " + f"got {self.cuda_graph_scope}." + ) + self.cuda_graph_scope = [self.cuda_graph_scope] + + if self.cuda_graph_impl == "local": + assert not self.cuda_graph_scope or self.cuda_graph_scope == ["full_iteration"], ( + "For local cuda graph implementation, the only valid value " + "for cuda_graph_scope is full_iteration. " + "To use other scopes, use cuda_graph_impl=transformer_engine." + ) + + if self.cuda_graph_impl == "transformer_engine": + assert "full_iteration" not in self.cuda_graph_scope, ( + "To use full iteration cuda graph, please use " + "cuda_graph_impl=transformer_engine instead of cuda_graph_impl=local." + ) + for scope in self.cuda_graph_scope: + assert scope in [ + 'attn', + 'mlp', + 'moe', + 'moe_router', + 'moe_preprocess', + 'mamba', + ], ( + "--cuda-graph-scope should be attn, mlp, moe, moe_router, moe_preprocess, " + f"or mamba, got {self.cuda_graph_scope}." + ) + + assert ( + 'moe' not in self.cuda_graph_scope or 'moe_router' not in self.cuda_graph_scope + ), 'cuda_graph_scope must not contain both moe and moe_router.' + if 'moe_preprocess' in self.cuda_graph_scope: + assert ( + 'moe_router' in self.cuda_graph_scope + ), 'moe_preprocess cuda graph is only supported with moe_router cuda graph.' + if self.num_moe_experts is None or self.num_moe_experts <= 1: + assert ( + 'moe' not in self.cuda_graph_scope + and 'moe_router' not in self.cuda_graph_scope + ), 'moe cuda graph is only supported for MoE.' else: - for module in self.recompute_modules: - if module in ['core_attn', 'mla_up_proj']: - raise ValueError( - f'attn cuda graph is not supported with {module} recompute.' + if self.moe_layer_freq == 1 or ( + isinstance(self.moe_layer_freq, list) and 0 not in self.moe_layer_freq + ): + assert 'mlp' not in self.cuda_graph_scope, ( + 'mlp cuda graph is only supported for dense layers, ' + 'but not found in the model.' + ) + if ( + self.moe_expert_capacity_factor is None + or not self.moe_pad_expert_input_to_capacity + ): + assert ( + 'moe' not in self.cuda_graph_scope + ), 'moe cuda graph is only supported with drop-padding MoE.' + if self.moe_token_dispatcher_type == 'alltoall' and ( + self.moe_expert_capacity_factor is not None + or self.moe_router_padding_for_fp8 + ): + assert 'moe_preprocess' not in self.cuda_graph_scope, ( + 'moe_preprocess cuda graph is not supported when there are ' + 'DtoH copies and synchronizations in the preprocess step.' ) + + if self.recompute_granularity: + if self.recompute_granularity != "selective" or not self.cuda_graph_scope: + raise ValueError( + "Full-layer CUDA graphs not supported with activation recomputation." + ) + elif self.cuda_graph_scope != ['full_iteration']: + # For scoped CUDA graphs, only the non-graphed parts of the layer can be + # recomputed. So check if there are overlaps between the recomputed parts + # and the graphed parts. + if "attn" in self.cuda_graph_scope: + for module in self.recompute_modules: + if module in ['core_attn', 'mla_up_proj']: + raise ValueError( + f'attn cuda graph is not supported with {module} recompute.' + ) + if "mlp" in self.cuda_graph_scope and "mlp" in self.recompute_modules: + raise ValueError(f'mlp cuda graph is not supported with mlp recompute.') + if "moe" in self.cuda_graph_scope: + for module in self.recompute_modules: + if module in ['moe_act', 'moe', 'shared_experts']: + raise ValueError( + f'moe cuda graph is not supported with {module} recompute.' + ) + if "moe_router" in self.cuda_graph_scope: + for module in self.recompute_modules: + if module in ['moe', 'shared_experts']: + raise ValueError( + f'moe_router cuda graph is not supported with {module} ' + 'recompute.' + ) if "layernorm" in self.recompute_modules: - warnings.warn( - "input_layernorm recompute is not supported with attention " - "cudagraph. Will only recompute the pre_mlp_layernorm." - ) + if ( + "attn" in self.cuda_graph_scope + and "mlp" in self.cuda_graph_scope + and ( + "moe" in self.cuda_graph_scope + or "moe_router" in self.cuda_graph_scope + ) + ): + raise ValueError( + 'cuda graph is not supported with layernorm recompute.' + ) + if "attn" in self.cuda_graph_scope: + warnings.warn( + "input_layernorm recompute is not supported with attention " + "cudagraph. Will only recompute the pre_mlp_layernorm." + ) + if ( + "mlp" in self.cuda_graph_scope + or "moe" in self.cuda_graph_scope + or "moe_router" in self.cuda_graph_scope + ): + warnings.warn( + "pre_mlp_layernorm recompute is not supported with mlp/moe " + "cudagraph. Will only recompute the input_layernorm." + ) if self.moe_token_dispatcher_type in ["allgather"]: if self.variable_seq_lengths is True: diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index a5babece9d0..01bda804c0f 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -15,6 +15,7 @@ from megatron.core.dist_checkpointing.utils import apply_prefix_mapping from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.cuda_graphs import is_graph_capturing from megatron.core.transformer.enums import LayerType from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp from megatron.core.transformer.mlp import MLP @@ -371,19 +372,29 @@ def __init__( # [Module 9: BiasDropoutFusion] self.mlp_bda = build_module(submodules.mlp_bda) + self.is_moe_layer = isinstance(self.mlp, MoELayer) + self.recompute_input_layernorm = False self.recompute_pre_mlp_layernorm = False self.recompute_mlp = False if self.config.recompute_granularity == 'selective': if "layernorm" in self.config.recompute_modules: - if ( - not isinstance(self.input_layernorm, IdentityOp) - and self.config.cuda_graph_impl == "none" + if not isinstance(self.input_layernorm, IdentityOp) and ( + self.config.cuda_graph_impl == "none" + or 'attn' not in self.config.cuda_graph_scope ): self.recompute_input_layernorm = True if self.config.fp8: self.self_attention.set_for_recompute_input_layernorm() - if not isinstance(self.pre_mlp_layernorm, IdentityOp): + if not isinstance(self.pre_mlp_layernorm, IdentityOp) and ( + self.config.cuda_graph_impl == "none" + or (not self.is_moe_layer and 'mlp' not in self.config.cuda_graph_scope) + or ( + self.is_moe_layer + and 'moe' not in self.config.cuda_graph_scope + and 'moe_router' not in self.config.cuda_graph_scope + ) + ): self.recompute_pre_mlp_layernorm = True if self.config.fp8: if isinstance(self.mlp, MoELayer): @@ -395,7 +406,7 @@ def __init__( set_save_original_input(self.mlp.linear_fc1) if "mlp" in self.config.recompute_modules: - if not isinstance(self.mlp, MoELayer): + if not self.is_moe_layer: self.recompute_mlp = True # @jcasper how should we handle nvfuser? @@ -584,7 +595,19 @@ def _forward_mlp(self, hidden_states, inference_context=None): and not isinstance(self.mlp, IdentityOp) ) - if self.recompute_mlp: + if ( + self.is_moe_layer + and self.config.cuda_graph_impl == "transformer_engine" + and self.training + and is_graph_capturing() + and 'moe_router' in self.config.cuda_graph_scope + ): + assert ( + not self.recompute_pre_mlp_layernorm + ), "Recomputation is not supported for CUDA graph." + cudagraph_outputs = self.mlp(pre_mlp_layernorm_output) + return cudagraph_outputs + [residual] + elif self.recompute_mlp: if self.config.fp8: # import here to avoid circular import from megatron.core.extensions.transformer_engine import te_checkpoint @@ -613,7 +636,6 @@ def _forward_mlp(self, hidden_states, inference_context=None): bias_chunks = [bias for _, bias in outputs if bias is not None] bias_output = torch.stack(bias_chunks, dim=0).sum(dim=0) if bias_chunks else None mlp_output_with_bias = (mlp_output, bias_output) - else: mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) @@ -625,6 +647,19 @@ def _forward_mlp(self, hidden_states, inference_context=None): ) nvtx_range_pop(suffix="mlp") + return self._forward_post_mlp(mlp_output_with_bias, residual) + + def _forward_post_mlp(self, mlp_output_with_bias, residual): + """ + Perform operations after the MLP computation. + + Args: + mlp_output_with_bias (Tensor): Output tensor of the MLP layer with bias. + residual (Tensor): Residual tensor. + + Returns: + output (Tensor): Transformed hidden states of shape [s, b, h]. + """ # TODO: could we move `bias_dropout_add_exec_handler` itself # inside the module provided in the `bias_dropout_add_spec` module? nvtx_range_push(suffix="mlp_bda") @@ -679,7 +714,9 @@ def get_layer_static_inputs(self, seq_length, micro_batch_size): """ static_inputs = super().get_layer_static_inputs(seq_length, micro_batch_size) - if not isinstance(self.self_attention, IdentityOp): + if not isinstance(self.self_attention, IdentityOp) and ( + not self.config.cuda_graph_scope or 'attn' in self.config.cuda_graph_scope + ): slen_per_cp = seq_length // self.config.context_parallel_size static_inputs["attention_mask"] = ( ~(torch.tril(torch.ones((slen_per_cp, seq_length))).bool()) @@ -693,18 +730,28 @@ def _get_submodules_under_cudagraphs(self): """ Get the submodules that are covered by cudagraphs. """ - if self.config.cuda_graph_scope == 'full': - submodules = [self] - else: - assert ( - self.config.cuda_graph_scope == 'attn' - ), f"Invalid cuda_graph_scope {self.config.cuda_graph_scope}" - submodules = [ + if not self.config.cuda_graph_scope: + return super()._get_submodules_under_cudagraphs() + + submodules = [] + if 'attn' in self.config.cuda_graph_scope: + submodules += [ self.input_layernorm, self.self_attention, self.pre_cross_attn_layernorm, self.cross_attention, ] + if (not self.is_moe_layer and 'mlp' in self.config.cuda_graph_scope) or ( + self.is_moe_layer and 'moe' in self.config.cuda_graph_scope + ): + submodules += [self.pre_mlp_layernorm, self.mlp] + elif self.is_moe_layer and 'moe_router' in self.config.cuda_graph_scope: + submodules += [self.pre_mlp_layernorm, self.mlp.router] + if ( + self.config.moe_shared_expert_intermediate_size is not None + and not self.config.moe_shared_expert_overlap + ): + submodules += [self.mlp.shared_experts] return submodules def _te_cuda_graph_capture(self, *args, **kwargs): @@ -715,12 +762,31 @@ def _te_cuda_graph_capture(self, *args, **kwargs): attribute can be set to control the scope of the CUDA graph. 2. If context is None, it cannot be returned as output. """ - hidden_states, context = self._forward_attention(*args, **kwargs) - - if self.config.cuda_graph_scope == "full": + context = None + if not self.config.cuda_graph_scope or 'attn' in self.config.cuda_graph_scope: + hidden_states, context = self._forward_attention(*args, **kwargs) + else: + if len(args) > 0: + hidden_states = args[0] + else: + hidden_states = kwargs.pop("hidden_states") + + if ( + not self.config.cuda_graph_scope + or (not self.is_moe_layer and 'mlp' in self.config.cuda_graph_scope) + or ( + self.is_moe_layer + and ( + 'moe' in self.config.cuda_graph_scope + or 'moe_router' in self.config.cuda_graph_scope + ) + ) + ): hidden_states = self._forward_mlp(hidden_states) - cuda_graph_outputs = [hidden_states] - + if not isinstance(hidden_states, list) and not isinstance(hidden_states, tuple): + cuda_graph_outputs = [hidden_states] + else: + cuda_graph_outputs = list(hidden_states) if context is not None: cuda_graph_outputs.append(context) return tuple(cuda_graph_outputs) @@ -732,6 +798,11 @@ def _te_cuda_graph_replay(self, *args, **kwargs): However, CUDA graph accepts only Tensor inputs. Hence, `inference_context` and `packed_seq_params` are excluded from input list. """ + context = None + if self.config.cuda_graph_scope and 'attn' not in self.config.cuda_graph_scope: + hidden_states, context = self._forward_attention(*args, **kwargs) + args = (hidden_states,) + kwargs = {} assert (kwargs.get('inference_context') is None) and ( kwargs.get('packed_seq_params') is None @@ -741,19 +812,69 @@ def _te_cuda_graph_replay(self, *args, **kwargs): "For inference cuda graph, please use cuda_graph_impl=local instead." ) - cuda_graph_output = super()._te_cuda_graph_replay(*args, **kwargs) + cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) if kwargs.get('context') is not None: - context = cuda_graph_output[-1] - cuda_graph_output = cuda_graph_output[:-1] + context = cuda_graph_output.pop() + + if ( + not self.config.cuda_graph_scope + or (not self.is_moe_layer and 'mlp' in self.config.cuda_graph_scope) + or (self.is_moe_layer and 'moe' in self.config.cuda_graph_scope) + ): + # CUDA Graph captures the whole MLP/MoE part. CUDA Graph output is the layer output. + assert len(cuda_graph_output) == 1, "CUDA Graph output should be the layer output." + output = cuda_graph_output.pop() + elif self.is_moe_layer and 'moe_router' in self.config.cuda_graph_scope: + # CUDA Graph partially captures the MoE. + # The rest of the layer should go to the normal pass. + shared_expert_output, routing_map, residual = None, None, None + mlp_residual = cuda_graph_output.pop() + if ( + self.config.moe_shared_expert_intermediate_size is not None + and not self.config.moe_shared_expert_overlap + ): + # The shared expert output is the fourth element in the CUDA graph output. + shared_expert_output = cuda_graph_output.pop() + + # Split cudagraph outputs into function outputs and attribute outputs, and + # process them separately. Function outputs should have three tensors. + func_output, attr_outputs = cuda_graph_output[:3], cuda_graph_output[3:] + if 'moe_preprocess' in self.config.cuda_graph_scope: + hidden_states, probs, residual = func_output + valid_cudagraph_attrs = self.mlp.token_dispatcher.valid_cudagraph_attrs + assert len(attr_outputs) == len( + valid_cudagraph_attrs + ), f"attr_outputs: {len(attr_outputs)} != {len(valid_cudagraph_attrs)}" + for i, attr_name in enumerate(valid_cudagraph_attrs): + hier_attr_name = attr_name.split('.') + attr = self.mlp.token_dispatcher + for name in hier_attr_name[:-1]: + attr = getattr(attr, name) + setattr(attr, hier_attr_name[-1], attr_outputs[i]) + else: + hidden_states, probs, routing_map = func_output + assert not attr_outputs, "cuda_graph_attr_outputs should be empty" + + # Resume the MoELayer forward pass from the end of the CUDA graph scope. + # The MoE layer will skip redundant computations when we pass in the calculated values + # through the keyword arguments. See MoELayer.forward docstring for more details. + nvtx_range_push(suffix="mlp") + self.mlp.cudagraph_tensor_store.set( + hidden_states=hidden_states, + probs=probs, + routing_map=routing_map, + residual=residual, + shared_expert_output=shared_expert_output, + ) + mlp_output_with_bias = self.mlp(hidden_states) + self.mlp.cudagraph_tensor_store.clear() + nvtx_range_pop(suffix="mlp") + + output = self._forward_post_mlp(mlp_output_with_bias, mlp_residual) else: - context = None - if self.config.cuda_graph_scope == "attn": - # CUDA Graph only covers the attention layer. Feed-forward - # layer still goes through the normal pass. + # CUDA Graph does not capture the MLP/MoE part at all. output = self._forward_mlp(*cuda_graph_output) - else: - output = cuda_graph_output[0] return output, context def _get_te_cuda_graph_replay_args(self, *args, **kwargs): @@ -826,7 +947,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs): (kwargs.get('inference_context') is not None) or (kwargs.get('inference_params') is not None) ) - and self.config.cuda_graph_scope != 'full_iteration' + and 'full_iteration' not in self.config.cuda_graph_scope ): if kwargs['inference_context'].is_static_batching(): using_cuda_graph = kwargs['inference_context'].is_decode_only() diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index d1e062edd02..82288710947 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -772,7 +772,7 @@ def validate_args(args, defaults={}): if args.rank == 0: print('accumulate and all-reduce gradients in fp32 for ' 'bfloat16 data type.', flush=True) - if args.cuda_graph_impl == "local" and args.cuda_graph_scope=="full_iteration": + if args.cuda_graph_impl == "local" and "full_iteration" in args.cuda_graph_scope: if not args.inference_dynamic_batching: assert not args.check_for_nan_in_loss_and_grad, \ "--no-check-for-nan-in-loss-and-grad should be set with full_iteration CUDA graph" @@ -1228,9 +1228,12 @@ def validate_args(args, defaults={}): if args.transformer_impl == 'transformer_engine' and not args.te_rng_tracker: args.te_rng_tracker = True warn_rank_0("te_rng_tracker is not enabled, enabling it for CUDA graphs.", args.rank) - assert "expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", ""), ( - "expandable_segments:True may not be safe when using CUDA Graphs with some specific parallel settings. " - "The training may crash with illegal memory access." + assert ( + "expandable_segments:True" not in os.getenv("PYTORCH_CUDA_ALLOC_CONF", "") + or os.getenv("NCCL_GRAPH_REGISTER", "") == "0" + ), ( + "Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using " + "CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True." ) assert ( args.recompute_granularity != 'full' @@ -1434,22 +1437,27 @@ def _add_inference_args(parser): help="Number of CUDA graph warmup steps") group.add_argument('--external-cuda-graph', action='store_true', help='Deprecated. Use --cuda-graph-impl=transformer_engine instead. ' - 'Use TE make_graphed_callables() to capture the CUDA graph.') + 'Use TE make_graphed_callables() to capture the CUDA graph. ' + 'Use --cuda-graph-scope=\"attn\", \"mlp\", \"moe\", \"moe_router\", \"moe_preprocess\", \"mamba\" for partial capture. ') group.add_argument('--cuda-graph-impl', type=str, default='none', choices=['none', 'local', 'transformer_engine'], help='Determines the CUDA graph capture implementation. ' '"none": no CUDA graph. ' '"local": capture the CUDA graph using MCore local implementation. --cuda-graph-scope=\"full_iteration\" enables whole iteration CUDA graph. ' '"transformer_engine": capture the CUDA graph using TE make_graphed_callables().') - group.add_argument('--cuda-graph-scope', type=str, default='full', - choices=['full', 'attn', 'full_iteration'], - help='Determines the CUDA graphs capturing scope. Valid values are ' - '\"full\", \"attn\" and \"full_iteration\". \"Full\" scope captures a whole ' - 'Transformer layer. \"Attn\" scope only captures operations in ' - 'TransformerLayer._forward_attention(). \"ful_iteration\" scope captures a ' - 'whole iteration. ' - 'full_iteration scope is only supported with --cuda-graph-impl=local, ' - 'attn scope is only supported with --cuda-graph-impl=transformer_engine.') + group.add_argument('--cuda-graph-scope', nargs='+', type=str, default=[], + help='Determines the CUDA graphs capturing scope. ' + 'choices: "attn", "mlp", "moe", "moe_router", "moe_preprocess", "mamba", "full_iteration". ' + '"attn": captures operations in TransformerLayer._forward_attention(). ' + '"mlp": captures operations in TransformerLayer._forward_mlp() for a dense layer. ' + '"moe": captures operations in TransformerLayer._forward_mlp() for a MoE layer. ' + '"moe_router": captures operations in TransformerLayer._forward_mlp() up to MoELayer.router(), ' + 'including the shared experts if they are not overlapped with EP comm. ' + '"moe_preprocess": captures operations in MoELayer.preprocess(). Must be used together with "moe_router". ' + '"mamba": captures the mamba layer. ' + '"full_iteration": captures a whole iteration. ' + 'full_iteration scope is only supported with --cuda-graph-impl=local, other scopes are only supported with --cuda-graph-impl=transformer_engine. ' + 'If not specified, the default scope is to capture the whole Transformer layer.') group.add_argument('--inference-max-requests', type=int, default=8, help='Maximum number of requests for inference.', dest='inference_max_batch_size') diff --git a/megatron/training/training.py b/megatron/training/training.py index b0d5b950a32..aa39ab05b35 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -2234,7 +2234,7 @@ def train( eval_iterations = 0 # Wrap forward_backward_func for Full iteration CUDA graph forward_backward_func = get_forward_backward_func() - if args.cuda_graph_impl == "local" and args.cuda_graph_scope=="full_iteration": + if args.cuda_graph_impl == "local" and "full_iteration" in args.cuda_graph_scope: forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) def get_e2e_base_metrics(): @@ -2365,12 +2365,13 @@ def get_e2e_base_metrics(): # Capture CUDA Graphs. if ( args.cuda_graph_impl == "transformer_engine" - and iteration == args.cuda_graph_warmup_steps + and not cuda_graph_helper.graphs_created() + and iteration - start_iteration == args.cuda_graph_warmup_steps ): - if iteration > start_iteration and should_disable_forward_pre_hook(args): + if args.cuda_graph_warmup_steps > 0 and should_disable_forward_pre_hook(args): disable_forward_pre_hook(model, param_sync=False) cuda_graph_helper.create_cudagraphs() - if iteration > start_iteration and should_disable_forward_pre_hook(args): + if args.cuda_graph_warmup_steps > 0 and should_disable_forward_pre_hook(args): enable_forward_pre_hook(model) cuda_graph_helper.cuda_graph_set_manual_hooks() @@ -2445,8 +2446,11 @@ def get_e2e_base_metrics(): # Set the manual hooks here since it's not set right after the capturing. if ( args.cuda_graph_impl == "transformer_engine" - and iteration == args.cuda_graph_warmup_steps + and args.cuda_graph_warmup_steps == 0 ): + assert ( + cuda_graph_helper.graphs_created() + ), "CUDA Graphs should have been created." cuda_graph_helper.cuda_graph_set_manual_hooks() iteration += 1 @@ -2646,7 +2650,7 @@ def evaluate( eval_batch_size = args.global_batch_size eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size) forward_backward_func = get_forward_backward_func() - if args.cuda_graph_impl == "local" and args.cuda_graph_scope=="full_iteration": + if args.cuda_graph_impl == "local" and "full_iteration" in args.cuda_graph_scope: forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) if eval_iters is None: diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..309b2533461 --- /dev/null +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/golden_values_dev_dgx_h100.json @@ -0,0 +1,644 @@ +{ + "lm loss": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 10.93663, + "2": 10.9327, + "3": 10.94263, + "4": 10.94969, + "5": 10.95052, + "6": 10.94157, + "7": 10.94484, + "8": 10.93674, + "9": 10.94996, + "10": 10.93686, + "11": 10.94102, + "12": 10.93763, + "13": 10.9235, + "14": 10.93428, + "15": 10.88791, + "16": 10.87434, + "17": 10.86896, + "18": 10.86065, + "19": 10.86311, + "20": 10.78063, + "21": 10.73125, + "22": 10.60283, + "23": 10.73278, + "24": 10.61888, + "25": 10.55212, + "26": 10.62704, + "27": 10.6391, + "28": 10.5908, + "29": 10.59809, + "30": 10.37777, + "31": 10.1201, + "32": 10.46078, + "33": 10.45538, + "34": 10.20107, + "35": 10.25779, + "36": 10.20889, + "37": 10.33688, + "38": 10.16827, + "39": 10.40875, + "40": 10.05239, + "41": 10.09432, + "42": 10.17894, + "43": 9.74205, + "44": 9.8904, + "45": 9.74009, + "46": 9.72707, + "47": 10.09139, + "48": 9.75298, + "49": 9.40106, + "50": 9.83667, + "51": 9.77071, + "52": 9.65705, + "53": 10.03051, + "54": 9.87899, + "55": 9.79604, + "56": 9.52924, + "57": 9.36583, + "58": 9.75331, + "59": 9.48065, + "60": 9.40785, + "61": 9.60145, + "62": 9.90753, + "63": 9.2583, + "64": 9.68397, + "65": 8.80003, + "66": 9.60779, + "67": 9.25408, + "68": 9.71438, + "69": 9.71682, + "70": 9.6617, + "71": 9.52466, + "72": 9.47116, + "73": 9.38822, + "74": 8.80223, + "75": 9.33966, + "76": 8.93574, + "77": 9.99333, + "78": 9.64731, + "79": 9.28114, + "80": 9.29588, + "81": 9.39589, + "82": 9.60893, + "83": 9.21629, + "84": 9.33891, + "85": 9.52979, + "86": 8.95817, + "87": 9.51641, + "88": 9.68228, + "89": 9.50664, + "90": 9.75348, + "91": 9.23465, + "92": 9.25972, + "93": 8.94517, + "94": 8.69188, + "95": 9.44591, + "96": 9.4101, + "97": 9.20087, + "98": 9.58175, + "99": 8.75818, + "100": 9.29466 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 22750260.0, + "2": 22953110.0, + "3": 22604450.0, + "4": 23266322.0, + "5": 22735560.0, + "6": 23061920.0, + "7": 22793342.0, + "8": 22960820.0, + "9": 22865664.0, + "10": 22950364.0, + "11": 22499674.0, + "12": 22456088.0, + "13": 22948060.0, + "14": 22384512.0, + "15": 22846272.0, + "16": 22856858.0, + "17": 22836412.0, + "18": 22590058.0, + "19": 22627048.0, + "20": 22712308.0, + "21": 22762624.0, + "22": 22816888.0, + "23": 22545124.0, + "24": 22794440.0, + "25": 22841936.0, + "26": 22549680.0, + "27": 22464820.0, + "28": 22453684.0, + "29": 22534640.0, + "30": 22636152.0, + "31": 22989488.0, + "32": 22594070.0, + "33": 22566010.0, + "34": 22855504.0, + "35": 22813688.0, + "36": 22595396.0, + "37": 22499360.0, + "38": 22926126.0, + "39": 22825392.0, + "40": 22675666.0, + "41": 22671586.0, + "42": 22682140.0, + "43": 23013940.0, + "44": 22764458.0, + "45": 22678992.0, + "46": 22915276.0, + "47": 22642868.0, + "48": 22954190.0, + "49": 23786668.0, + "50": 22934008.0, + "51": 23866222.0, + "52": 23807290.0, + "53": 24007532.0, + "54": 22871610.0, + "55": 23571284.0, + "56": 23954310.0, + "57": 24211632.0, + "58": 23914404.0, + "59": 23771838.0, + "60": 23813560.0, + "61": 23797288.0, + "62": 23739984.0, + "63": 23916692.0, + "64": 23895952.0, + "65": 24150562.0, + "66": 23796504.0, + "67": 25032232.0, + "68": 23673188.0, + "69": 23648580.0, + "70": 23903504.0, + "71": 24864636.0, + "72": 24767108.0, + "73": 24850612.0, + "74": 24132990.0, + "75": 24146528.0, + "76": 25025540.0, + "77": 24358472.0, + "78": 24910064.0, + "79": 23810516.0, + "80": 24821440.0, + "81": 25020512.0, + "82": 23851244.0, + "83": 24961024.0, + "84": 25144020.0, + "85": 24823608.0, + "86": 23153096.0, + "87": 24850204.0, + "88": 24749150.0, + "89": 22505554.0, + "90": 24059620.0, + "91": 23839038.0, + "92": 23874568.0, + "93": 24769548.0, + "94": 23992452.0, + "95": 25189838.0, + "96": 23909262.0, + "97": 24713068.0, + "98": 23832506.0, + "99": 23983474.0, + "100": 24101108.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 763142656.0, + "2": 778734592.0, + "3": 772525056.0, + "4": 803593216.0, + "5": 803593216.0, + "6": 803593216.0, + "7": 801299456.0, + "8": 803593216.0, + "9": 801840128.0, + "10": 803593216.0, + "11": 802987008.0, + "12": 803593216.0, + "13": 802987008.0, + "14": 801299456.0, + "15": 803593216.0, + "16": 801840128.0, + "17": 803593216.0, + "18": 802987008.0, + "19": 801299456.0, + "20": 803593216.0, + "21": 801299456.0, + "22": 803593216.0, + "23": 801299456.0, + "24": 803593216.0, + "25": 801299456.0, + "26": 803593216.0, + "27": 801299456.0, + "28": 803593216.0, + "29": 801299456.0, + "30": 803593216.0, + "31": 801299456.0, + "32": 803593216.0, + "33": 801840128.0, + "34": 803593216.0, + "35": 801840128.0, + "36": 803593216.0, + "37": 802987008.0, + "38": 801299456.0, + "39": 803593216.0, + "40": 801299456.0, + "41": 803593216.0, + "42": 801840128.0, + "43": 803593216.0, + "44": 801840128.0, + "45": 803593216.0, + "46": 801840128.0, + "47": 803593216.0, + "48": 801840128.0, + "49": 803593216.0, + "50": 801840128.0, + "51": 801299456.0, + "52": 803593216.0, + "53": 801299456.0, + "54": 803593216.0, + "55": 801840128.0, + "56": 803593216.0, + "57": 801840128.0, + "58": 803593216.0, + "59": 801840128.0, + "60": 803593216.0, + "61": 801299456.0, + "62": 803593216.0, + "63": 801299456.0, + "64": 802987008.0, + "65": 803593216.0, + "66": 801299456.0, + "67": 803593216.0, + "68": 801299456.0, + "69": 803593216.0, + "70": 801840128.0, + "71": 803593216.0, + "72": 801299456.0, + "73": 803593216.0, + "74": 803593216.0, + "75": 802987008.0, + "76": 803593216.0, + "77": 801840128.0, + "78": 803593216.0, + "79": 801299456.0, + "80": 802987008.0, + "81": 803593216.0, + "82": 801840128.0, + "83": 803593216.0, + "84": 801299456.0, + "85": 802987008.0, + "86": 803593216.0, + "87": 801840128.0, + "88": 803593216.0, + "89": 801299456.0, + "90": 802987008.0, + "91": 803593216.0, + "92": 801299456.0, + "93": 803593216.0, + "94": 801299456.0, + "95": 803593216.0, + "96": 801299456.0, + "97": 803593216.0, + "98": 801299456.0, + "99": 802987008.0, + "100": 803593216.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 993582592.0, + "2": 1210942464.0, + "3": 1210942464.0, + "4": 1210942464.0, + "5": 1210942464.0, + "6": 1210942464.0, + "7": 1210942464.0, + "8": 1210942464.0, + "9": 1210942464.0, + "10": 1210942464.0, + "11": 1210942464.0, + "12": 1210942464.0, + "13": 1210942464.0, + "14": 1210942464.0, + "15": 1210942464.0, + "16": 1210942464.0, + "17": 1210942464.0, + "18": 1210942464.0, + "19": 1210942464.0, + "20": 1210942464.0, + "21": 1210942464.0, + "22": 1210942464.0, + "23": 1210942464.0, + "24": 1210942464.0, + "25": 1210942464.0, + "26": 1210942464.0, + "27": 1210942464.0, + "28": 1210942464.0, + "29": 1210942464.0, + "30": 1210942464.0, + "31": 1210942464.0, + "32": 1210942464.0, + "33": 1210942464.0, + "34": 1210942464.0, + "35": 1210942464.0, + "36": 1210942464.0, + "37": 1210942464.0, + "38": 1210942464.0, + "39": 1210942464.0, + "40": 1210942464.0, + "41": 1210942464.0, + "42": 1210942464.0, + "43": 1210942464.0, + "44": 1210942464.0, + "45": 1210942464.0, + "46": 1210942464.0, + "47": 1210942464.0, + "48": 1210942464.0, + "49": 1210942464.0, + "50": 1210942464.0, + "51": 1210942464.0, + "52": 1210942464.0, + "53": 1210942464.0, + "54": 1210942464.0, + "55": 1210942464.0, + "56": 1210942464.0, + "57": 1210942464.0, + "58": 1210942464.0, + "59": 1210942464.0, + "60": 1210942464.0, + "61": 1210942464.0, + "62": 1210942464.0, + "63": 1210942464.0, + "64": 1210942464.0, + "65": 1210942464.0, + "66": 1210942464.0, + "67": 1210942464.0, + "68": 1210942464.0, + "69": 1210942464.0, + "70": 1210942464.0, + "71": 1210942464.0, + "72": 1210942464.0, + "73": 1210942464.0, + "74": 1210942464.0, + "75": 1210942464.0, + "76": 1210942464.0, + "77": 1210942464.0, + "78": 1210942464.0, + "79": 1210942464.0, + "80": 1210942464.0, + "81": 1210942464.0, + "82": 1210942464.0, + "83": 1210942464.0, + "84": 1210942464.0, + "85": 1210942464.0, + "86": 1210942464.0, + "87": 1210942464.0, + "88": 1210942464.0, + "89": 1210942464.0, + "90": 1210942464.0, + "91": 1210942464.0, + "92": 1210942464.0, + "93": 1210942464.0, + "94": 1210942464.0, + "95": 1210942464.0, + "96": 1210942464.0, + "97": 1210942464.0, + "98": 1210942464.0, + "99": 1210942464.0, + "100": 1210942464.0 + } + }, + "mtp_1 loss": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 10.88689, + "2": 10.90485, + "3": 10.90869, + "4": 10.86903, + "5": 10.91601, + "6": 10.906, + "7": 10.90268, + "8": 10.88984, + "9": 10.90425, + "10": 10.89144, + "11": 10.93384, + "12": 10.91647, + "13": 10.91108, + "14": 10.91974, + "15": 10.88488, + "16": 10.9077, + "17": 10.87571, + "18": 10.91379, + "19": 10.9092, + "20": 10.87837, + "21": 10.87896, + "22": 10.85583, + "23": 10.88007, + "24": 10.87245, + "25": 10.85859, + "26": 10.8696, + "27": 10.87702, + "28": 10.88641, + "29": 10.88866, + "30": 10.85422, + "31": 10.79713, + "32": 10.86631, + "33": 10.8781, + "34": 10.83982, + "35": 10.84165, + "36": 10.85012, + "37": 10.85556, + "38": 10.83674, + "39": 10.86355, + "40": 10.82887, + "41": 10.8341, + "42": 10.84469, + "43": 10.78828, + "44": 10.82123, + "45": 10.78831, + "46": 10.7823, + "47": 10.82898, + "48": 10.78985, + "49": 10.71269, + "50": 10.77382, + "51": 10.76639, + "52": 10.7397, + "53": 10.80285, + "54": 10.77365, + "55": 10.76066, + "56": 10.71068, + "57": 10.66686, + "58": 10.74378, + "59": 10.69209, + "60": 10.66474, + "61": 10.7073, + "62": 10.77206, + "63": 10.61812, + "64": 10.7178, + "65": 10.49439, + "66": 10.67106, + "67": 10.57534, + "68": 10.6873, + "69": 10.6816, + "70": 10.66836, + "71": 10.64586, + "72": 10.60925, + "73": 10.56508, + "74": 10.37144, + "75": 10.51183, + "76": 10.39914, + "77": 10.75182, + "78": 10.6268, + "79": 10.46827, + "80": 10.47524, + "81": 10.51083, + "82": 10.58769, + "83": 10.4381, + "84": 10.45057, + "85": 10.55084, + "86": 10.28076, + "87": 10.51088, + "88": 10.60323, + "89": 10.50794, + "90": 10.60274, + "91": 10.38238, + "92": 10.38703, + "93": 10.23076, + "94": 10.08438, + "95": 10.42616, + "96": 10.44905, + "97": 10.32215, + "98": 10.4966, + "99": 10.04765, + "100": 10.33491 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 51.30209, + "2": 1.41746, + "3": 1.28029, + "4": 10.57024, + "5": 0.66643, + "6": 0.67893, + "7": 0.65727, + "8": 0.66196, + "9": 0.66227, + "10": 0.65877, + "11": 0.65828, + "12": 0.65862, + "13": 0.65727, + "14": 0.65896, + "15": 0.65851, + "16": 0.66826, + "17": 0.65878, + "18": 0.65573, + "19": 0.65631, + "20": 0.65579, + "21": 0.65091, + "22": 0.65603, + "23": 0.65158, + "24": 0.65266, + "25": 0.65816, + "26": 0.65194, + "27": 0.6541, + "28": 0.65515, + "29": 0.65439, + "30": 0.65241, + "31": 0.65597, + "32": 0.65551, + "33": 0.65318, + "34": 0.6553, + "35": 0.65725, + "36": 0.65926, + "37": 0.65606, + "38": 0.65571, + "39": 0.65846, + "40": 0.65642, + "41": 0.65509, + "42": 0.66105, + "43": 0.65448, + "44": 0.65534, + "45": 0.65304, + "46": 0.65227, + "47": 0.64871, + "48": 0.65257, + "49": 0.65485, + "50": 0.65054, + "51": 0.67883, + "52": 0.6571, + "53": 0.65671, + "54": 0.65877, + "55": 0.65584, + "56": 0.65072, + "57": 0.64951, + "58": 0.65703, + "59": 0.65106, + "60": 0.64536, + "61": 0.64416, + "62": 0.64816, + "63": 0.64084, + "64": 0.6396, + "65": 0.64182, + "66": 0.64004, + "67": 0.64101, + "68": 0.63928, + "69": 0.65723, + "70": 0.6828, + "71": 0.64052, + "72": 0.64287, + "73": 0.64136, + "74": 0.64252, + "75": 0.64617, + "76": 0.64857, + "77": 0.64304, + "78": 0.64068, + "79": 0.64048, + "80": 0.64091, + "81": 0.64179, + "82": 0.64793, + "83": 0.641, + "84": 0.64077, + "85": 0.64011, + "86": 0.64018, + "87": 0.64132, + "88": 0.63901, + "89": 0.6407, + "90": 0.64277, + "91": 0.64132, + "92": 0.64123, + "93": 0.65051, + "94": 0.65036, + "95": 0.64542, + "96": 0.64561, + "97": 0.6504, + "98": 0.64563, + "99": 0.64524, + "100": 0.65049 + } + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/golden_values_dev_dgxh100_coreweave.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/golden_values_dev_dgxh100_coreweave.json new file mode 100644 index 00000000000..309b2533461 --- /dev/null +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/golden_values_dev_dgxh100_coreweave.json @@ -0,0 +1,644 @@ +{ + "lm loss": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 10.93663, + "2": 10.9327, + "3": 10.94263, + "4": 10.94969, + "5": 10.95052, + "6": 10.94157, + "7": 10.94484, + "8": 10.93674, + "9": 10.94996, + "10": 10.93686, + "11": 10.94102, + "12": 10.93763, + "13": 10.9235, + "14": 10.93428, + "15": 10.88791, + "16": 10.87434, + "17": 10.86896, + "18": 10.86065, + "19": 10.86311, + "20": 10.78063, + "21": 10.73125, + "22": 10.60283, + "23": 10.73278, + "24": 10.61888, + "25": 10.55212, + "26": 10.62704, + "27": 10.6391, + "28": 10.5908, + "29": 10.59809, + "30": 10.37777, + "31": 10.1201, + "32": 10.46078, + "33": 10.45538, + "34": 10.20107, + "35": 10.25779, + "36": 10.20889, + "37": 10.33688, + "38": 10.16827, + "39": 10.40875, + "40": 10.05239, + "41": 10.09432, + "42": 10.17894, + "43": 9.74205, + "44": 9.8904, + "45": 9.74009, + "46": 9.72707, + "47": 10.09139, + "48": 9.75298, + "49": 9.40106, + "50": 9.83667, + "51": 9.77071, + "52": 9.65705, + "53": 10.03051, + "54": 9.87899, + "55": 9.79604, + "56": 9.52924, + "57": 9.36583, + "58": 9.75331, + "59": 9.48065, + "60": 9.40785, + "61": 9.60145, + "62": 9.90753, + "63": 9.2583, + "64": 9.68397, + "65": 8.80003, + "66": 9.60779, + "67": 9.25408, + "68": 9.71438, + "69": 9.71682, + "70": 9.6617, + "71": 9.52466, + "72": 9.47116, + "73": 9.38822, + "74": 8.80223, + "75": 9.33966, + "76": 8.93574, + "77": 9.99333, + "78": 9.64731, + "79": 9.28114, + "80": 9.29588, + "81": 9.39589, + "82": 9.60893, + "83": 9.21629, + "84": 9.33891, + "85": 9.52979, + "86": 8.95817, + "87": 9.51641, + "88": 9.68228, + "89": 9.50664, + "90": 9.75348, + "91": 9.23465, + "92": 9.25972, + "93": 8.94517, + "94": 8.69188, + "95": 9.44591, + "96": 9.4101, + "97": 9.20087, + "98": 9.58175, + "99": 8.75818, + "100": 9.29466 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 22750260.0, + "2": 22953110.0, + "3": 22604450.0, + "4": 23266322.0, + "5": 22735560.0, + "6": 23061920.0, + "7": 22793342.0, + "8": 22960820.0, + "9": 22865664.0, + "10": 22950364.0, + "11": 22499674.0, + "12": 22456088.0, + "13": 22948060.0, + "14": 22384512.0, + "15": 22846272.0, + "16": 22856858.0, + "17": 22836412.0, + "18": 22590058.0, + "19": 22627048.0, + "20": 22712308.0, + "21": 22762624.0, + "22": 22816888.0, + "23": 22545124.0, + "24": 22794440.0, + "25": 22841936.0, + "26": 22549680.0, + "27": 22464820.0, + "28": 22453684.0, + "29": 22534640.0, + "30": 22636152.0, + "31": 22989488.0, + "32": 22594070.0, + "33": 22566010.0, + "34": 22855504.0, + "35": 22813688.0, + "36": 22595396.0, + "37": 22499360.0, + "38": 22926126.0, + "39": 22825392.0, + "40": 22675666.0, + "41": 22671586.0, + "42": 22682140.0, + "43": 23013940.0, + "44": 22764458.0, + "45": 22678992.0, + "46": 22915276.0, + "47": 22642868.0, + "48": 22954190.0, + "49": 23786668.0, + "50": 22934008.0, + "51": 23866222.0, + "52": 23807290.0, + "53": 24007532.0, + "54": 22871610.0, + "55": 23571284.0, + "56": 23954310.0, + "57": 24211632.0, + "58": 23914404.0, + "59": 23771838.0, + "60": 23813560.0, + "61": 23797288.0, + "62": 23739984.0, + "63": 23916692.0, + "64": 23895952.0, + "65": 24150562.0, + "66": 23796504.0, + "67": 25032232.0, + "68": 23673188.0, + "69": 23648580.0, + "70": 23903504.0, + "71": 24864636.0, + "72": 24767108.0, + "73": 24850612.0, + "74": 24132990.0, + "75": 24146528.0, + "76": 25025540.0, + "77": 24358472.0, + "78": 24910064.0, + "79": 23810516.0, + "80": 24821440.0, + "81": 25020512.0, + "82": 23851244.0, + "83": 24961024.0, + "84": 25144020.0, + "85": 24823608.0, + "86": 23153096.0, + "87": 24850204.0, + "88": 24749150.0, + "89": 22505554.0, + "90": 24059620.0, + "91": 23839038.0, + "92": 23874568.0, + "93": 24769548.0, + "94": 23992452.0, + "95": 25189838.0, + "96": 23909262.0, + "97": 24713068.0, + "98": 23832506.0, + "99": 23983474.0, + "100": 24101108.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 763142656.0, + "2": 778734592.0, + "3": 772525056.0, + "4": 803593216.0, + "5": 803593216.0, + "6": 803593216.0, + "7": 801299456.0, + "8": 803593216.0, + "9": 801840128.0, + "10": 803593216.0, + "11": 802987008.0, + "12": 803593216.0, + "13": 802987008.0, + "14": 801299456.0, + "15": 803593216.0, + "16": 801840128.0, + "17": 803593216.0, + "18": 802987008.0, + "19": 801299456.0, + "20": 803593216.0, + "21": 801299456.0, + "22": 803593216.0, + "23": 801299456.0, + "24": 803593216.0, + "25": 801299456.0, + "26": 803593216.0, + "27": 801299456.0, + "28": 803593216.0, + "29": 801299456.0, + "30": 803593216.0, + "31": 801299456.0, + "32": 803593216.0, + "33": 801840128.0, + "34": 803593216.0, + "35": 801840128.0, + "36": 803593216.0, + "37": 802987008.0, + "38": 801299456.0, + "39": 803593216.0, + "40": 801299456.0, + "41": 803593216.0, + "42": 801840128.0, + "43": 803593216.0, + "44": 801840128.0, + "45": 803593216.0, + "46": 801840128.0, + "47": 803593216.0, + "48": 801840128.0, + "49": 803593216.0, + "50": 801840128.0, + "51": 801299456.0, + "52": 803593216.0, + "53": 801299456.0, + "54": 803593216.0, + "55": 801840128.0, + "56": 803593216.0, + "57": 801840128.0, + "58": 803593216.0, + "59": 801840128.0, + "60": 803593216.0, + "61": 801299456.0, + "62": 803593216.0, + "63": 801299456.0, + "64": 802987008.0, + "65": 803593216.0, + "66": 801299456.0, + "67": 803593216.0, + "68": 801299456.0, + "69": 803593216.0, + "70": 801840128.0, + "71": 803593216.0, + "72": 801299456.0, + "73": 803593216.0, + "74": 803593216.0, + "75": 802987008.0, + "76": 803593216.0, + "77": 801840128.0, + "78": 803593216.0, + "79": 801299456.0, + "80": 802987008.0, + "81": 803593216.0, + "82": 801840128.0, + "83": 803593216.0, + "84": 801299456.0, + "85": 802987008.0, + "86": 803593216.0, + "87": 801840128.0, + "88": 803593216.0, + "89": 801299456.0, + "90": 802987008.0, + "91": 803593216.0, + "92": 801299456.0, + "93": 803593216.0, + "94": 801299456.0, + "95": 803593216.0, + "96": 801299456.0, + "97": 803593216.0, + "98": 801299456.0, + "99": 802987008.0, + "100": 803593216.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 993582592.0, + "2": 1210942464.0, + "3": 1210942464.0, + "4": 1210942464.0, + "5": 1210942464.0, + "6": 1210942464.0, + "7": 1210942464.0, + "8": 1210942464.0, + "9": 1210942464.0, + "10": 1210942464.0, + "11": 1210942464.0, + "12": 1210942464.0, + "13": 1210942464.0, + "14": 1210942464.0, + "15": 1210942464.0, + "16": 1210942464.0, + "17": 1210942464.0, + "18": 1210942464.0, + "19": 1210942464.0, + "20": 1210942464.0, + "21": 1210942464.0, + "22": 1210942464.0, + "23": 1210942464.0, + "24": 1210942464.0, + "25": 1210942464.0, + "26": 1210942464.0, + "27": 1210942464.0, + "28": 1210942464.0, + "29": 1210942464.0, + "30": 1210942464.0, + "31": 1210942464.0, + "32": 1210942464.0, + "33": 1210942464.0, + "34": 1210942464.0, + "35": 1210942464.0, + "36": 1210942464.0, + "37": 1210942464.0, + "38": 1210942464.0, + "39": 1210942464.0, + "40": 1210942464.0, + "41": 1210942464.0, + "42": 1210942464.0, + "43": 1210942464.0, + "44": 1210942464.0, + "45": 1210942464.0, + "46": 1210942464.0, + "47": 1210942464.0, + "48": 1210942464.0, + "49": 1210942464.0, + "50": 1210942464.0, + "51": 1210942464.0, + "52": 1210942464.0, + "53": 1210942464.0, + "54": 1210942464.0, + "55": 1210942464.0, + "56": 1210942464.0, + "57": 1210942464.0, + "58": 1210942464.0, + "59": 1210942464.0, + "60": 1210942464.0, + "61": 1210942464.0, + "62": 1210942464.0, + "63": 1210942464.0, + "64": 1210942464.0, + "65": 1210942464.0, + "66": 1210942464.0, + "67": 1210942464.0, + "68": 1210942464.0, + "69": 1210942464.0, + "70": 1210942464.0, + "71": 1210942464.0, + "72": 1210942464.0, + "73": 1210942464.0, + "74": 1210942464.0, + "75": 1210942464.0, + "76": 1210942464.0, + "77": 1210942464.0, + "78": 1210942464.0, + "79": 1210942464.0, + "80": 1210942464.0, + "81": 1210942464.0, + "82": 1210942464.0, + "83": 1210942464.0, + "84": 1210942464.0, + "85": 1210942464.0, + "86": 1210942464.0, + "87": 1210942464.0, + "88": 1210942464.0, + "89": 1210942464.0, + "90": 1210942464.0, + "91": 1210942464.0, + "92": 1210942464.0, + "93": 1210942464.0, + "94": 1210942464.0, + "95": 1210942464.0, + "96": 1210942464.0, + "97": 1210942464.0, + "98": 1210942464.0, + "99": 1210942464.0, + "100": 1210942464.0 + } + }, + "mtp_1 loss": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 10.88689, + "2": 10.90485, + "3": 10.90869, + "4": 10.86903, + "5": 10.91601, + "6": 10.906, + "7": 10.90268, + "8": 10.88984, + "9": 10.90425, + "10": 10.89144, + "11": 10.93384, + "12": 10.91647, + "13": 10.91108, + "14": 10.91974, + "15": 10.88488, + "16": 10.9077, + "17": 10.87571, + "18": 10.91379, + "19": 10.9092, + "20": 10.87837, + "21": 10.87896, + "22": 10.85583, + "23": 10.88007, + "24": 10.87245, + "25": 10.85859, + "26": 10.8696, + "27": 10.87702, + "28": 10.88641, + "29": 10.88866, + "30": 10.85422, + "31": 10.79713, + "32": 10.86631, + "33": 10.8781, + "34": 10.83982, + "35": 10.84165, + "36": 10.85012, + "37": 10.85556, + "38": 10.83674, + "39": 10.86355, + "40": 10.82887, + "41": 10.8341, + "42": 10.84469, + "43": 10.78828, + "44": 10.82123, + "45": 10.78831, + "46": 10.7823, + "47": 10.82898, + "48": 10.78985, + "49": 10.71269, + "50": 10.77382, + "51": 10.76639, + "52": 10.7397, + "53": 10.80285, + "54": 10.77365, + "55": 10.76066, + "56": 10.71068, + "57": 10.66686, + "58": 10.74378, + "59": 10.69209, + "60": 10.66474, + "61": 10.7073, + "62": 10.77206, + "63": 10.61812, + "64": 10.7178, + "65": 10.49439, + "66": 10.67106, + "67": 10.57534, + "68": 10.6873, + "69": 10.6816, + "70": 10.66836, + "71": 10.64586, + "72": 10.60925, + "73": 10.56508, + "74": 10.37144, + "75": 10.51183, + "76": 10.39914, + "77": 10.75182, + "78": 10.6268, + "79": 10.46827, + "80": 10.47524, + "81": 10.51083, + "82": 10.58769, + "83": 10.4381, + "84": 10.45057, + "85": 10.55084, + "86": 10.28076, + "87": 10.51088, + "88": 10.60323, + "89": 10.50794, + "90": 10.60274, + "91": 10.38238, + "92": 10.38703, + "93": 10.23076, + "94": 10.08438, + "95": 10.42616, + "96": 10.44905, + "97": 10.32215, + "98": 10.4966, + "99": 10.04765, + "100": 10.33491 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 51.30209, + "2": 1.41746, + "3": 1.28029, + "4": 10.57024, + "5": 0.66643, + "6": 0.67893, + "7": 0.65727, + "8": 0.66196, + "9": 0.66227, + "10": 0.65877, + "11": 0.65828, + "12": 0.65862, + "13": 0.65727, + "14": 0.65896, + "15": 0.65851, + "16": 0.66826, + "17": 0.65878, + "18": 0.65573, + "19": 0.65631, + "20": 0.65579, + "21": 0.65091, + "22": 0.65603, + "23": 0.65158, + "24": 0.65266, + "25": 0.65816, + "26": 0.65194, + "27": 0.6541, + "28": 0.65515, + "29": 0.65439, + "30": 0.65241, + "31": 0.65597, + "32": 0.65551, + "33": 0.65318, + "34": 0.6553, + "35": 0.65725, + "36": 0.65926, + "37": 0.65606, + "38": 0.65571, + "39": 0.65846, + "40": 0.65642, + "41": 0.65509, + "42": 0.66105, + "43": 0.65448, + "44": 0.65534, + "45": 0.65304, + "46": 0.65227, + "47": 0.64871, + "48": 0.65257, + "49": 0.65485, + "50": 0.65054, + "51": 0.67883, + "52": 0.6571, + "53": 0.65671, + "54": 0.65877, + "55": 0.65584, + "56": 0.65072, + "57": 0.64951, + "58": 0.65703, + "59": 0.65106, + "60": 0.64536, + "61": 0.64416, + "62": 0.64816, + "63": 0.64084, + "64": 0.6396, + "65": 0.64182, + "66": 0.64004, + "67": 0.64101, + "68": 0.63928, + "69": 0.65723, + "70": 0.6828, + "71": 0.64052, + "72": 0.64287, + "73": 0.64136, + "74": 0.64252, + "75": 0.64617, + "76": 0.64857, + "77": 0.64304, + "78": 0.64068, + "79": 0.64048, + "80": 0.64091, + "81": 0.64179, + "82": 0.64793, + "83": 0.641, + "84": 0.64077, + "85": 0.64011, + "86": 0.64018, + "87": 0.64132, + "88": 0.63901, + "89": 0.6407, + "90": 0.64277, + "91": 0.64132, + "92": 0.64123, + "93": 0.65051, + "94": 0.65036, + "95": 0.64542, + "96": 0.64561, + "97": 0.6504, + "98": 0.64563, + "99": 0.64524, + "100": 0.65049 + } + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/golden_values_dev_dgxh100_eos.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/golden_values_dev_dgxh100_eos.json new file mode 100644 index 00000000000..e8c2bae571f --- /dev/null +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/golden_values_dev_dgxh100_eos.json @@ -0,0 +1,644 @@ +{ + "lm loss": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 10.93663, + "2": 10.9327, + "3": 10.94263, + "4": 10.94969, + "5": 10.95052, + "6": 10.94157, + "7": 10.94484, + "8": 10.93674, + "9": 10.94996, + "10": 10.93686, + "11": 10.94102, + "12": 10.93763, + "13": 10.9235, + "14": 10.93428, + "15": 10.88791, + "16": 10.87434, + "17": 10.86896, + "18": 10.86065, + "19": 10.86311, + "20": 10.78063, + "21": 10.73125, + "22": 10.60283, + "23": 10.73278, + "24": 10.61888, + "25": 10.55212, + "26": 10.62704, + "27": 10.6391, + "28": 10.5908, + "29": 10.59809, + "30": 10.37777, + "31": 10.1201, + "32": 10.46078, + "33": 10.45538, + "34": 10.20107, + "35": 10.25779, + "36": 10.20889, + "37": 10.33688, + "38": 10.16827, + "39": 10.40875, + "40": 10.05239, + "41": 10.09432, + "42": 10.17894, + "43": 9.74205, + "44": 9.8904, + "45": 9.74009, + "46": 9.72707, + "47": 10.09139, + "48": 9.75298, + "49": 9.40106, + "50": 9.83667, + "51": 9.77071, + "52": 9.65705, + "53": 10.03051, + "54": 9.87899, + "55": 9.79604, + "56": 9.52924, + "57": 9.36583, + "58": 9.75331, + "59": 9.48065, + "60": 9.40785, + "61": 9.60145, + "62": 9.90753, + "63": 9.2583, + "64": 9.68397, + "65": 8.80003, + "66": 9.60779, + "67": 9.25408, + "68": 9.71438, + "69": 9.71682, + "70": 9.6617, + "71": 9.52466, + "72": 9.47116, + "73": 9.38822, + "74": 8.80223, + "75": 9.33966, + "76": 8.93574, + "77": 9.99333, + "78": 9.64731, + "79": 9.28114, + "80": 9.29588, + "81": 9.39589, + "82": 9.60893, + "83": 9.21629, + "84": 9.33891, + "85": 9.52979, + "86": 8.95817, + "87": 9.51641, + "88": 9.68228, + "89": 9.50664, + "90": 9.75348, + "91": 9.23465, + "92": 9.25972, + "93": 8.94517, + "94": 8.69188, + "95": 9.44591, + "96": 9.4101, + "97": 9.20087, + "98": 9.58175, + "99": 8.75818, + "100": 9.29466 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 22750260.0, + "2": 22953110.0, + "3": 22604450.0, + "4": 23266322.0, + "5": 22735560.0, + "6": 23061920.0, + "7": 22793342.0, + "8": 22960820.0, + "9": 22865664.0, + "10": 22950364.0, + "11": 22499674.0, + "12": 22456088.0, + "13": 22948060.0, + "14": 22384512.0, + "15": 22846272.0, + "16": 22856858.0, + "17": 22836412.0, + "18": 22590058.0, + "19": 22627048.0, + "20": 22712308.0, + "21": 22762624.0, + "22": 22816888.0, + "23": 22545124.0, + "24": 22794440.0, + "25": 22841936.0, + "26": 22549680.0, + "27": 22464820.0, + "28": 22453684.0, + "29": 22534640.0, + "30": 22636152.0, + "31": 22989488.0, + "32": 22594070.0, + "33": 22566010.0, + "34": 22855504.0, + "35": 22813688.0, + "36": 22595396.0, + "37": 22499360.0, + "38": 22926126.0, + "39": 22825392.0, + "40": 22675666.0, + "41": 22671586.0, + "42": 22682140.0, + "43": 23013940.0, + "44": 22764458.0, + "45": 22678992.0, + "46": 22915276.0, + "47": 22642868.0, + "48": 22954190.0, + "49": 23786668.0, + "50": 22934008.0, + "51": 23866222.0, + "52": 23807290.0, + "53": 24007532.0, + "54": 22871610.0, + "55": 23571284.0, + "56": 23954310.0, + "57": 24211632.0, + "58": 23914404.0, + "59": 23771838.0, + "60": 23813560.0, + "61": 23797288.0, + "62": 23739984.0, + "63": 23916692.0, + "64": 23895952.0, + "65": 24150562.0, + "66": 23796504.0, + "67": 25032232.0, + "68": 23673188.0, + "69": 23648580.0, + "70": 23903504.0, + "71": 24864636.0, + "72": 24767108.0, + "73": 24850612.0, + "74": 24132990.0, + "75": 24146528.0, + "76": 25025540.0, + "77": 24358472.0, + "78": 24910064.0, + "79": 23810516.0, + "80": 24821440.0, + "81": 25020512.0, + "82": 23851244.0, + "83": 24961024.0, + "84": 25144020.0, + "85": 24823608.0, + "86": 23153096.0, + "87": 24850204.0, + "88": 24749150.0, + "89": 22505554.0, + "90": 24059620.0, + "91": 23839038.0, + "92": 23874568.0, + "93": 24769548.0, + "94": 23992452.0, + "95": 25189838.0, + "96": 23909262.0, + "97": 24713068.0, + "98": 23832506.0, + "99": 23983474.0, + "100": 24101108.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 769688064.0, + "2": 775359488.0, + "3": 769690624.0, + "4": 801299456.0, + "5": 803593216.0, + "6": 801299456.0, + "7": 803593216.0, + "8": 803593216.0, + "9": 801299456.0, + "10": 803593216.0, + "11": 801299456.0, + "12": 803593216.0, + "13": 801299456.0, + "14": 803593216.0, + "15": 803593216.0, + "16": 801299456.0, + "17": 803593216.0, + "18": 801299456.0, + "19": 803593216.0, + "20": 801299456.0, + "21": 803593216.0, + "22": 803593216.0, + "23": 801840128.0, + "24": 803593216.0, + "25": 802987008.0, + "26": 801299456.0, + "27": 802987008.0, + "28": 801299456.0, + "29": 801299456.0, + "30": 803593216.0, + "31": 801299456.0, + "32": 803593216.0, + "33": 801299456.0, + "34": 803593216.0, + "35": 801299456.0, + "36": 801299456.0, + "37": 803593216.0, + "38": 801299456.0, + "39": 803593216.0, + "40": 801299456.0, + "41": 803593216.0, + "42": 801299456.0, + "43": 801299456.0, + "44": 803593216.0, + "45": 802987008.0, + "46": 801299456.0, + "47": 803593216.0, + "48": 801299456.0, + "49": 803593216.0, + "50": 801299456.0, + "51": 801299456.0, + "52": 803593216.0, + "53": 802446336.0, + "54": 801299456.0, + "55": 803593216.0, + "56": 802987008.0, + "57": 801299456.0, + "58": 801840128.0, + "59": 801299456.0, + "60": 803593216.0, + "61": 801840128.0, + "62": 801299456.0, + "63": 803593216.0, + "64": 802446336.0, + "65": 803593216.0, + "66": 801840128.0, + "67": 801299456.0, + "68": 803593216.0, + "69": 801840128.0, + "70": 801299456.0, + "71": 803593216.0, + "72": 803593216.0, + "73": 802987008.0, + "74": 801299456.0, + "75": 803593216.0, + "76": 803593216.0, + "77": 801299456.0, + "78": 801299456.0, + "79": 803593216.0, + "80": 801840128.0, + "81": 801299456.0, + "82": 803593216.0, + "83": 801299456.0, + "84": 801299456.0, + "85": 803593216.0, + "86": 801299456.0, + "87": 801299456.0, + "88": 803593216.0, + "89": 801840128.0, + "90": 803593216.0, + "91": 802987008.0, + "92": 801299456.0, + "93": 803593216.0, + "94": 801299456.0, + "95": 801299456.0, + "96": 803593216.0, + "97": 801840128.0, + "98": 803593216.0, + "99": 802987008.0, + "100": 801299456.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 988765184.0, + "2": 1206831616.0, + "3": 1210116096.0, + "4": 1210116096.0, + "5": 1210116096.0, + "6": 1210116096.0, + "7": 1210116096.0, + "8": 1210116096.0, + "9": 1210116096.0, + "10": 1210116096.0, + "11": 1210116096.0, + "12": 1210116096.0, + "13": 1210116096.0, + "14": 1210116096.0, + "15": 1210116096.0, + "16": 1210116096.0, + "17": 1210116096.0, + "18": 1210116096.0, + "19": 1210116096.0, + "20": 1210116096.0, + "21": 1210116096.0, + "22": 1210116096.0, + "23": 1210116096.0, + "24": 1210116096.0, + "25": 1210116096.0, + "26": 1210116096.0, + "27": 1210116096.0, + "28": 1210116096.0, + "29": 1210116096.0, + "30": 1210116096.0, + "31": 1210116096.0, + "32": 1210116096.0, + "33": 1210116096.0, + "34": 1210116096.0, + "35": 1210116096.0, + "36": 1210116096.0, + "37": 1210116096.0, + "38": 1210116096.0, + "39": 1210116096.0, + "40": 1210116096.0, + "41": 1210116096.0, + "42": 1210116096.0, + "43": 1210116096.0, + "44": 1210116096.0, + "45": 1210116096.0, + "46": 1210116096.0, + "47": 1210116096.0, + "48": 1210116096.0, + "49": 1210116096.0, + "50": 1210116096.0, + "51": 1210116096.0, + "52": 1210116096.0, + "53": 1210116096.0, + "54": 1210116096.0, + "55": 1210116096.0, + "56": 1210116096.0, + "57": 1210116096.0, + "58": 1210116096.0, + "59": 1210116096.0, + "60": 1210116096.0, + "61": 1210116096.0, + "62": 1210116096.0, + "63": 1210116096.0, + "64": 1210116096.0, + "65": 1210116096.0, + "66": 1210116096.0, + "67": 1210116096.0, + "68": 1210116096.0, + "69": 1210116096.0, + "70": 1210116096.0, + "71": 1210116096.0, + "72": 1210116096.0, + "73": 1210116096.0, + "74": 1210116096.0, + "75": 1210116096.0, + "76": 1210116096.0, + "77": 1210116096.0, + "78": 1210116096.0, + "79": 1210116096.0, + "80": 1210116096.0, + "81": 1210116096.0, + "82": 1210116096.0, + "83": 1210116096.0, + "84": 1210116096.0, + "85": 1210116096.0, + "86": 1210116096.0, + "87": 1210116096.0, + "88": 1210116096.0, + "89": 1210116096.0, + "90": 1210116096.0, + "91": 1210116096.0, + "92": 1210116096.0, + "93": 1210116096.0, + "94": 1210116096.0, + "95": 1210116096.0, + "96": 1210116096.0, + "97": 1210116096.0, + "98": 1210116096.0, + "99": 1210116096.0, + "100": 1210116096.0 + } + }, + "mtp_1 loss": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 10.88689, + "2": 10.90485, + "3": 10.90869, + "4": 10.86903, + "5": 10.91601, + "6": 10.906, + "7": 10.90268, + "8": 10.88984, + "9": 10.90425, + "10": 10.89144, + "11": 10.93384, + "12": 10.91647, + "13": 10.91108, + "14": 10.91974, + "15": 10.88488, + "16": 10.9077, + "17": 10.87571, + "18": 10.91379, + "19": 10.9092, + "20": 10.87837, + "21": 10.87896, + "22": 10.85583, + "23": 10.88007, + "24": 10.87245, + "25": 10.85859, + "26": 10.8696, + "27": 10.87702, + "28": 10.88641, + "29": 10.88866, + "30": 10.85422, + "31": 10.79713, + "32": 10.86631, + "33": 10.8781, + "34": 10.83982, + "35": 10.84165, + "36": 10.85012, + "37": 10.85556, + "38": 10.83674, + "39": 10.86355, + "40": 10.82887, + "41": 10.8341, + "42": 10.84469, + "43": 10.78828, + "44": 10.82123, + "45": 10.78831, + "46": 10.7823, + "47": 10.82898, + "48": 10.78985, + "49": 10.71269, + "50": 10.77382, + "51": 10.76639, + "52": 10.7397, + "53": 10.80285, + "54": 10.77365, + "55": 10.76066, + "56": 10.71068, + "57": 10.66686, + "58": 10.74378, + "59": 10.69209, + "60": 10.66474, + "61": 10.7073, + "62": 10.77206, + "63": 10.61812, + "64": 10.7178, + "65": 10.49439, + "66": 10.67106, + "67": 10.57534, + "68": 10.6873, + "69": 10.6816, + "70": 10.66836, + "71": 10.64586, + "72": 10.60925, + "73": 10.56508, + "74": 10.37144, + "75": 10.51183, + "76": 10.39914, + "77": 10.75182, + "78": 10.6268, + "79": 10.46827, + "80": 10.47524, + "81": 10.51083, + "82": 10.58769, + "83": 10.4381, + "84": 10.45057, + "85": 10.55084, + "86": 10.28076, + "87": 10.51088, + "88": 10.60323, + "89": 10.50794, + "90": 10.60274, + "91": 10.38238, + "92": 10.38703, + "93": 10.23076, + "94": 10.08438, + "95": 10.42616, + "96": 10.44905, + "97": 10.32215, + "98": 10.4966, + "99": 10.04765, + "100": 10.33491 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 100, + "step_interval": 1, + "values": { + "1": 58.67467, + "2": 1.49483, + "3": 1.38721, + "4": 11.78499, + "5": 0.75759, + "6": 0.75678, + "7": 0.76144, + "8": 0.80382, + "9": 0.74706, + "10": 0.74893, + "11": 0.75091, + "12": 0.75087, + "13": 0.74803, + "14": 0.75316, + "15": 0.80396, + "16": 0.75267, + "17": 0.75378, + "18": 0.75457, + "19": 0.75484, + "20": 0.75428, + "21": 0.75639, + "22": 0.81363, + "23": 0.75607, + "24": 0.75553, + "25": 0.75564, + "26": 0.75334, + "27": 0.75722, + "28": 0.76027, + "29": 0.8113, + "30": 0.75278, + "31": 0.75471, + "32": 0.75104, + "33": 0.75271, + "34": 0.74877, + "35": 0.74765, + "36": 0.80549, + "37": 0.75089, + "38": 0.75395, + "39": 0.75254, + "40": 0.76025, + "41": 0.75356, + "42": 0.75573, + "43": 0.79632, + "44": 0.77927, + "45": 0.75515, + "46": 0.75759, + "47": 0.75978, + "48": 0.75749, + "49": 0.75504, + "50": 0.75616, + "51": 0.77974, + "52": 0.76581, + "53": 0.76997, + "54": 0.76705, + "55": 0.76737, + "56": 0.77352, + "57": 0.77833, + "58": 0.81195, + "59": 0.77251, + "60": 0.7711, + "61": 0.77181, + "62": 0.77006, + "63": 0.76957, + "64": 0.77251, + "65": 0.82259, + "66": 0.77112, + "67": 0.7683, + "68": 0.77335, + "69": 0.77022, + "70": 0.77335, + "71": 0.77822, + "72": 0.77769, + "73": 0.79476, + "74": 0.7728, + "75": 0.7711, + "76": 0.76863, + "77": 0.77228, + "78": 0.77031, + "79": 0.76995, + "80": 0.77286, + "81": 0.76616, + "82": 0.76752, + "83": 0.76583, + "84": 0.77264, + "85": 0.76732, + "86": 0.76873, + "87": 0.77239, + "88": 0.77971, + "89": 0.76112, + "90": 0.76225, + "91": 0.75814, + "92": 0.76144, + "93": 0.75796, + "94": 0.76412, + "95": 0.777, + "96": 0.77207, + "97": 0.7628, + "98": 0.76325, + "99": 0.76204, + "100": 0.7668 + } + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml new file mode 100644 index 00000000000..ef2b76069a1 --- /dev/null +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml @@ -0,0 +1,96 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: :4096:8 +MODEL_ARGS: + --num-layers: 13 + --hidden-size: 512 + --num-attention-heads: 8 + --mtp-num-layers: 1 + --micro-batch-size: 2 + --global-batch-size: 32 + --seq-length: 1024 + --max-position-embeddings: 1024 + --position-embedding-type: rope + --rotary-base: 10000 + --untie-embeddings-and-output-weights: true + --disable-bias-linear: true + --attention-dropout: 0.0 + --hidden-dropout: 0.0 + --train-iters: 100 + --lr-decay-iters: 320000 + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 4 + --pipeline-model-parallel-size: 2 + --expert-model-parallel-size: 2 + --expert-tensor-parallel-size: 2 + --pipeline-model-parallel-layout: Et\\|\\(tt\\|\\)*6mL # Et|(tt|)*6mL + --sequence-parallel: true + --num-experts: 8 + --use-distributed-optimizer: true + --overlap-grad-reduce: true + --overlap-param-gather: true + --moe-token-dispatcher-type: alltoall + --moe-router-load-balancing-type: global_aux_loss + --moe-router-topk: 2 + --moe-router-dtype: fp32 + --moe-router-fusion: true + --moe-router-enable-expert-bias: true + --moe-router-score-function: sigmoid + --moe-router-pre-softmax: true + --moe-ffn-hidden-size: 1024 + --moe-shared-expert-intermediate-size: 512 + --moe-grouped-gemm: true + --moe-layer-freq: ([0]*4+[1]*9) + --moe-permute-fusion: true + --deterministic-mode: true + --no-gradient-accumulation-fusion: true + --attention-softmax-in-fp32: true + --use-checkpoint-opt_param-scheduler: true + --use-mcore-models: true + --bf16: true + --fp8-format: hybrid + --fp8-recipe: blockwise + --first-last-layers-bf16: true + --no-bias-gelu-fusion: true + --recompute-granularity: selective + --recompute-modules: "[moe_act]" + --cuda-graph-impl: transformer_engine + --cuda-graph-scope: "[attn mlp moe_router moe_preprocess]" + --log-memory-to-tensorboard: true + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --log-interval: 1 + --timing-log-level: 0 + --save-interval: 50 + --eval-interval: 1000 + --eval-iters: 10 + --data-path: ${DATA_PATH}/text/the_pile/shard00/my-gpt3_00_text_document + --data-cache-path: ${DATA_CACHE_PATH} + --vocab-file: ${DATA_PATH}/text/the_pile/shard00/bpe/vocab.json + --merge-file: ${DATA_PATH}/text/the_pile/shard00/bpe/merges.txt + --save: ${CHECKPOINT_SAVE_PATH} + --load: ${CHECKPOINT_LOAD_PATH} + --ckpt-fully-parallel-load: true + --ckpt-format: torch_dist + --ckpt-assume-constant-structure: true +TEST_TYPE: ckpt-resume +METRICS: + - "iteration-time" + - "lm loss" + - "num-zeros" + - "mem-allocated-bytes" + - "mem-max-allocated-bytes" + - "mtp_1 loss" diff --git a/tests/test_utils/recipes/moe.yaml b/tests/test_utils/recipes/moe.yaml index 8164ca37df8..8b6e1191aca 100644 --- a/tests/test_utils/recipes/moe.yaml +++ b/tests/test_utils/recipes/moe.yaml @@ -124,6 +124,11 @@ products: - environment: [dev] scope: [mr] platforms: [dgx_h100] + - test_case: [gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph] + products: + - environment: [dev] + scope: [mr] + platforms: [dgx_h100] ####################################################################### # Super important MR tests that run for both DEV and LTS per MR # ####################################################################### diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 2d87b3c6adb..20b970f44df 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import asyncio import random @@ -80,7 +80,7 @@ class DynamicEngineTestConfig: return_log_probs: bool = False materialize_only_last_token_logits: bool = True skip_prompt_log_probs_for_dynamic_inference: bool = False - cuda_graph_scope: str = "full_iteration" + cuda_graph_scope: List[str] = None force_build_cuda_graphs: bool = False # If False, do not build cuda graphs in the tests, even if # num_cuda_graphs is set. @@ -111,6 +111,9 @@ def __post_init__(self): if self.context_max_tokens_override is None: self.context_max_tokens_override = self.num_requests * self.max_sequence_length + if self.cuda_graph_scope is None: + self.cuda_graph_scope = ["full_iteration"] + @dataclass class DynamicEngineTestEnv: @@ -403,7 +406,7 @@ def teardown_method(self, method): not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) @pytest.mark.parametrize("num_cuda_graphs", [None, 1, 4]) - @pytest.mark.parametrize("cuda_graph_scope", ["full", "full_iteration"]) + @pytest.mark.parametrize("cuda_graph_scope", [[], ["full_iteration"]]) def test_simple(self, num_cuda_graphs, cuda_graph_scope) -> None: """Simple test that runs without errors, and validates output.""" diff --git a/tests/unit_tests/transformer/test_cuda_graphs.py b/tests/unit_tests/transformer/test_cuda_graphs.py index 685e3674374..b4da65aa056 100644 --- a/tests/unit_tests/transformer/test_cuda_graphs.py +++ b/tests/unit_tests/transformer/test_cuda_graphs.py @@ -1,7 +1,8 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import os import random +import sys import time import types @@ -9,6 +10,7 @@ import torch from megatron.core import parallel_state +from megatron.core.enums import ModelType from megatron.core.inference.contexts import DynamicInferenceContext from megatron.core.inference.engines import DynamicInferenceEngine from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( @@ -27,6 +29,7 @@ ) from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec +from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator from megatron.core.pipeline_parallel.schedules import set_current_microbatch from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.ssm.mamba_block import MambaStack @@ -39,6 +42,14 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_fa_min_version, is_te_min_version +from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args +from megatron.training.global_vars import ( + destroy_global_vars, + get_args, + set_args, + set_global_variables, +) +from megatron.training.training import setup_model_and_optimizer from tests.unit_tests.test_utilities import Utils @@ -715,6 +726,256 @@ def test_capture_freeze_gc(self): ) +def is_deep_ep_available(): + from megatron.core.transformer.moe.fused_a2a import HAVE_DEEP_EP + + return HAVE_DEEP_EP + + +def is_hybrid_ep_available(): + from megatron.core.transformer.moe.fused_a2a import HAVE_HYBRIDEP + + return HAVE_HYBRIDEP + + +class TestPartialCudaGraph: + """Test that CUDA graph outputs match non-CUDA graph outputs for various scopes.""" + + def setup_method(self, method): + self.seq_length = 512 + self.micro_batch_size = 2 + # Store original environment variable values + self.original_env = { + 'CUDA_DEVICE_MAX_CONNECTIONS': os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS'), + 'NVTE_ALLOW_NONDETERMINISTIC_ALGO': os.environ.get('NVTE_ALLOW_NONDETERMINISTIC_ALGO'), + } + os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = '0' + + def teardown_method(self, method): + # Restore original environment variable values + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + Utils.destroy_model_parallel() + destroy_global_vars() + destroy_num_microbatches_calculator() + + def model_provider( + self, + pre_process=True, + post_process=True, + layer_spec_fn=get_gpt_layer_with_transformer_engine_spec, + **config_kwargs, + ): + model_parallel_cuda_manual_seed(123) + args = get_args() + config = core_transformer_config_from_args(args) + transformer_layer_spec = layer_spec_fn() + return GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + ) + + def create_test_args( + self, cuda_graph_impl, cuda_graph_scope, cuda_graph_warmup_steps, ep_size, **kwargs + ): + destroy_global_vars() + destroy_num_microbatches_calculator() + + sys.argv = ['test_cuda_graphs.py'] + args = parse_args() + args.num_layers = 4 + args.mtp_num_layers = 1 + args.vocab_size = 1024 + args.hidden_size = 128 + args.num_attention_heads = 8 + args.max_position_embeddings = 512 + args.global_batch_size = self.micro_batch_size * 8 + args.micro_batch_size = self.micro_batch_size + args.create_attention_mask_in_dataloader = True + args.seq_length = self.seq_length + args.tensor_model_parallel_size = 2 + args.sequence_parallel = True + args.pipeline_model_parallel_size = 1 + args.context_parallel_size = 1 + args.expert_model_parallel_size = ep_size + args.train_iters = 10 + args.lr = 3e-5 + args.bf16 = True + args.add_bias_linear = False + args.swiglu = True + args.use_distributed_optimizer = True + args.position_embedding_type = "rope" + args.rotary_percent = 1.0 + args.hidden_dropout = 0.0 + args.attention_dropout = 0.0 + + # MoE settings + args.num_experts = 4 + args.expert_model_parallel_size = ep_size + args.moe_shared_expert_intermediate_size = 1024 + args.moe_layer_freq = "[0,0,1,1]" + args.moe_permute_fusion = True + args.moe_router_fusion = True + args.moe_router_topk = 2 + + # CUDA graph settings + args.cuda_graph_impl = cuda_graph_impl + args.cuda_graph_scope = cuda_graph_scope + args.cuda_graph_warmup_steps = cuda_graph_warmup_steps + args.use_te_rng_tracker = cuda_graph_impl != "none" + + for key, value in kwargs.items(): + assert hasattr(args, key) + setattr(args, key, value) + + validate_args(args) + set_global_variables(args, False) + return args + + def get_batch(self, seq_length, micro_batch_size): + data = list(range(seq_length)) + input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + labels = 1 + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + attention_mask = torch.ones( + (micro_batch_size, 1, seq_length, seq_length), dtype=bool + ).cuda() + loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda() + return input_ids, labels, position_ids, attention_mask, loss_mask + + def _run_test_helper( + self, ep_size, cuda_graph_impl, cuda_graph_scope, cuda_graph_warmup_steps, **kwargs + ): + """Test fp8_param with gpt_model.""" + args = self.create_test_args( + cuda_graph_impl, cuda_graph_scope, cuda_graph_warmup_steps, ep_size, **kwargs + ) + + set_args(args) + torch.manual_seed(123) + Utils.initialize_model_parallel( + tensor_model_parallel_size=2, expert_model_parallel_size=ep_size + ) + + input_ids, labels, position_ids, attention_mask, loss_mask = self.get_batch( + self.seq_length, self.micro_batch_size + ) + + gpt_model, optimizer, _ = setup_model_and_optimizer( + self.model_provider, ModelType.encoder_or_decoder + ) + assert len(gpt_model) == 1 # Assume only one model in the model provider. + + loss_list = [] + + cuda_graph_helper = None + if cuda_graph_impl == "transformer_engine": + from megatron.core.transformer.cuda_graphs import TECudaGraphHelper + + cuda_graph_helper = TECudaGraphHelper( + model=gpt_model, + config=gpt_model[0].config, + seq_length=self.seq_length, + micro_batch_size=self.micro_batch_size, + optimizers=[optimizer], + ) + + for i in range(100): + gpt_model[0].zero_grad_buffer() + optimizer.zero_grad() + + # Capture CUDA graphs after warmup if helper is provided + if cuda_graph_helper is not None and i == cuda_graph_warmup_steps: + cuda_graph_helper.create_cudagraphs() + + output = gpt_model[0].forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + # Check output shapes + assert output.shape[0] == self.micro_batch_size + assert output.shape[1] == self.seq_length + + # Verify gradients + loss = output.mean() + loss.backward() + + for param in gpt_model[0].parameters(): + assert param.main_grad is not None + + update_successful, _, _ = optimizer.step() + assert update_successful + + loss_list.append(loss.item()) + + return torch.tensor(loss_list) + + @pytest.mark.skipif( + not (HAVE_TE and is_te_min_version("1.14.0")), + reason="Partial CUDA graph support requires TransformerEngine version >= 1.14.0", + ) + @pytest.mark.parametrize("ep_size", [1, 4]) + @pytest.mark.parametrize("moe_dropless_dispatcher", [False, True]) + @pytest.mark.parametrize("moe_dispatcher_type", ["alltoall", "deepep", "hybridep"]) + def test_moe_partial_cudagraph(self, ep_size, moe_dropless_dispatcher, moe_dispatcher_type): + extra_kwargs = {} + if moe_dispatcher_type == "deepep": + if not is_deep_ep_available(): + pytest.skip("Deep EP is not available") + extra_kwargs["moe_token_dispatcher_type"] = "flex" + extra_kwargs["moe_flex_dispatcher_backend"] = "deepep" + elif moe_dispatcher_type == "hybridep": + if not is_hybrid_ep_available(): + pytest.skip("Hybrid EP is not available") + extra_kwargs["moe_token_dispatcher_type"] = "flex" + extra_kwargs["moe_flex_dispatcher_backend"] = "hybridep" + else: + extra_kwargs["moe_token_dispatcher_type"] = moe_dispatcher_type + if not moe_dropless_dispatcher: + if moe_dispatcher_type == "deepep": + pytest.skip("Deep EP doesn't support drop&pad MoE") + extra_kwargs["moe_expert_capacity_factor"] = 1.0 + extra_kwargs["moe_pad_expert_input_to_capacity"] = True + + loss_list_ref = self._run_test_helper(ep_size, "none", None, 0, **extra_kwargs) + for cuda_graph_scope in [ + None, + ["attn"], + ["moe"], + ["mlp", "moe_router"], + ["attn", "mlp", "moe_router", "moe_preprocess"], + ]: + if moe_dropless_dispatcher and (cuda_graph_scope is None or "moe" in cuda_graph_scope): + # Dropless MoE doesn't work with "moe" scope cudagraph. Skip. + continue + cuda_graph_warmup_steps = 3 + loss_list = self._run_test_helper( + ep_size, + "transformer_engine", + cuda_graph_scope, + cuda_graph_warmup_steps, + **extra_kwargs, + ) + assert torch.equal(loss_list, loss_list_ref) + + if __name__ == "__main__": test = TestParallelTransformerBlockCudagraphs() @@ -729,3 +990,8 @@ def test_capture_freeze_gc(self): test = TestCaptureFreezeGC() test.test_capture_freeze_gc() + + test = TestPartialCudaGraph() + test.setup_method(method=None) + test.test_moe_partial_cudagraph(4, True, "alltoall") + test.teardown_method(method=None) diff --git a/tools/checkpoint/checkpoint_inspector.py b/tools/checkpoint/checkpoint_inspector.py index 34afa27755f..c45b56e9f8a 100644 --- a/tools/checkpoint/checkpoint_inspector.py +++ b/tools/checkpoint/checkpoint_inspector.py @@ -1,3 +1,5 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + # python checkpoint_inspector.py inspect /path/to/checkpoint # torchrun --nproc_per_node=8 --nnodes=1 checkpoint_inspector.py convert-torch-dist-to-fsdp-dtensor /path/to/input_checkpoint /path/to/output_checkpoint --swiglu import gc