Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.moe_layer import BaseMoELayer
from megatron.core.transformer.utils import set_model_to_sequence_parallel
from megatron.core.utils import get_asyncio_loop, get_model_config, unwrap_model
Expand Down Expand Up @@ -851,7 +852,7 @@ def generate_all_output_tokens_static_batch(
# Check whether CUDA graphs are enabled
enable_cuda_graph = (
model_config.cuda_graph_impl == "local"
and "full_iteration" not in model_config.cuda_graph_scope
and CudaGraphScope.full_iteration not in model_config.cuda_graph_scope
)

# Pad batch tokens if necessary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
is_vp_last_stage,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.enums import AttnBackend, CudaGraphScope
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import ensure_metadata_has_dp_cp_group
Expand Down Expand Up @@ -144,8 +144,7 @@ def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
# Use is_cg_capturable=True for full iteration CUDA graphs to avoid torch.equal checks
is_cg_capturable = (
hasattr(self.config, 'cuda_graph_scope')
and self.config.cuda_graph_scope
and 'full_iteration' in self.config.cuda_graph_scope
and CudaGraphScope.full_iteration in self.config.cuda_graph_scope
)
if is_cg_capturable and not is_te_min_version("2.7.0"):
from megatron.core.utils import get_te_version
Expand Down
4 changes: 2 additions & 2 deletions megatron/core/models/gpt/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.quantization.utils import get_quant_config_or_none
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.enums import CudaGraphScope, ModelType
from megatron.core.transformer.multi_token_prediction import (
MTPLossAutoScaler,
MTPLossLoggingHelper,
Expand Down Expand Up @@ -374,7 +374,7 @@ def _preprocess(
and (
(
self.config.cuda_graph_impl == "local"
and "full_iteration" not in self.config.cuda_graph_scope
and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
)
or self.config.flash_decode
)
Expand Down
7 changes: 4 additions & 3 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
drain_embedding_wgrad_compute,
Expand Down Expand Up @@ -656,7 +657,7 @@ def forward_backward_no_pipelining(
if (
hasattr(config, 'cuda_graph_impl')
and config.cuda_graph_impl == "local"
and "full_iteration" not in config.cuda_graph_scope
and CudaGraphScope.full_iteration not in config.cuda_graph_scope
):
create_cudagraphs()

Expand Down Expand Up @@ -1923,7 +1924,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None):
if (
hasattr(config, 'cuda_graph_impl')
and config.cuda_graph_impl == "local"
and "full_iteration" not in config.cuda_graph_scope
and CudaGraphScope.full_iteration not in config.cuda_graph_scope
):
create_cudagraphs()
nvtx_range_pop(suffix="misc")
Expand Down Expand Up @@ -2310,7 +2311,7 @@ def enable_grad_sync():
if (
hasattr(config, 'cuda_graph_impl')
and config.cuda_graph_impl == "local"
and "full_iteration" not in config.cuda_graph_scope
and CudaGraphScope.full_iteration not in config.cuda_graph_scope
):
create_cudagraphs()

Expand Down
3 changes: 2 additions & 1 deletion megatron/core/ssm/mamba_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
Expand Down Expand Up @@ -294,7 +295,7 @@ def forward(
(
(
self.config.cuda_graph_impl == "local"
and "full_iteration" not in self.config.cuda_graph_scope
and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
)
or self.config.flash_decode
)
Expand Down
4 changes: 2 additions & 2 deletions megatron/core/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ..models.common.embeddings.yarn_rotary_pos_embedding import (
_yarn_get_concentration_factor_from_config,
)
from .enums import AttnMaskType
from .enums import AttnMaskType, CudaGraphScope
from .transformer_config import TransformerConfig

try:
Expand Down Expand Up @@ -828,7 +828,7 @@ def forward(
if (
in_decode_mode
and self.config.cuda_graph_impl == "local"
and "full_iteration" not in self.config.cuda_graph_scope
and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
and inference_context.is_static_batching()
):
raise ValueError(f"CUDA graphs must use flash decode with static batching!")
Expand Down
40 changes: 32 additions & 8 deletions megatron/core/transformer/cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
get_all_rng_states,
get_cuda_rng_tracker,
)
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
Expand Down Expand Up @@ -1344,24 +1345,24 @@ def _layer_is_graphable(layer, config):
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_layer import TransformerLayer

if isinstance(layer, MambaLayer) and 'mamba' in config.cuda_graph_scope:
if isinstance(layer, MambaLayer) and CudaGraphScope.mamba in config.cuda_graph_scope:
# mamba layer.
return True
if isinstance(layer, TransformerLayer):
if 'attn' in config.cuda_graph_scope and not (
if CudaGraphScope.attn in config.cuda_graph_scope and not (
isinstance(layer.self_attention, IdentityOp)
and isinstance(layer.cross_attention, IdentityOp)
):
# attn layer.
return True
if (
'moe' in config.cuda_graph_scope
or 'moe_router' in config.cuda_graph_scope
or 'moe_preprocess' in config.cuda_graph_scope
CudaGraphScope.moe in config.cuda_graph_scope
or CudaGraphScope.moe_router in config.cuda_graph_scope
or CudaGraphScope.moe_preprocess in config.cuda_graph_scope
) and isinstance(layer.mlp, MoELayer):
# moe layer.
return True
if 'mlp' in config.cuda_graph_scope and isinstance(layer.mlp, MLP):
if CudaGraphScope.mlp in config.cuda_graph_scope and isinstance(layer.mlp, MLP):
# mlp layer.
return True
return False
Expand All @@ -1388,7 +1389,7 @@ def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]):
"Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using "
"CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True."
)
assert "full_iteration" not in config.cuda_graph_scope, (
assert CudaGraphScope.full_iteration not in config.cuda_graph_scope, (
"full_iteration cuda graph is not supported for cuda_graph_impl=transformer_engine. "
"Please use cuda_graph_impl=local instead."
)
Expand Down Expand Up @@ -1529,7 +1530,7 @@ def get_rotary_pos_emb(transformer_module, transformer_input):
and not isinstance(layer.self_attention, IdentityOp)
and (
not self.config.cuda_graph_scope
or 'attn' in self.config.cuda_graph_scope
or CudaGraphScope.attn in self.config.cuda_graph_scope
)
)
if is_te_min_version("1.10.0"):
Expand Down Expand Up @@ -1712,3 +1713,26 @@ def cuda_graph_set_manual_hooks(self):
model_chunk = self.model[chunk_number]
for layer in layers:
layer.setup_manual_hooks(model_chunk._make_forward_pre_hook)

def delete_cuda_graphs(self):
    """
    Delete all CUDA graphs.

    Walks every layer in ``self.callables_per_chunk``, calls ``reset()`` on each
    entry of ``layer.cuda_graphs`` when the installed TE version is at least
    2.10.0, then empties the layer's ``cuda_graphs`` and
    ``cuda_graph_manual_hooks`` lists. The number of graphs that were and were
    not reset is logged, and ``self._graphs_created`` is cleared.

    Raises:
        AssertionError: If ``self._graphs_created`` is falsy, i.e. there is
            nothing to delete.
    """
    assert self._graphs_created, "CUDA Graphs have not been created."

    # The version check does not depend on the layer or graph being visited,
    # so evaluate it once up front instead of once per graph.
    # NOTE(review): presumably graph.reset() is only available from TE 2.10.0 —
    # confirm against the Transformer Engine release notes.
    can_reset_graphs = is_te_min_version("2.10.0")

    graphs_deleted, graphs_not_deleted = 0, 0
    for layers in self.callables_per_chunk:
        for layer in layers:
            for graph in layer.cuda_graphs:
                if can_reset_graphs:
                    graph.reset()
                    graphs_deleted += 1
                else:
                    graphs_not_deleted += 1
            # Drop the references even when reset() was unavailable, so the
            # layer no longer points at the old graphs or hooks.
            layer.cuda_graphs = []
            layer.cuda_graph_manual_hooks = []

    log_single_rank(
        logger,
        logging.INFO,
        f'{graphs_deleted} graphs deleted, {graphs_not_deleted} graphs not deleted.',
    )
    self._graphs_created = False
12 changes: 12 additions & 0 deletions megatron/core/transformer/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,15 @@ class AttnBackend(enum.Enum):
unfused = 3
local = 4
auto = 5


class CudaGraphScope(enum.Enum):
    """Scopes that can be selected for CUDA graph capture.

    Call sites test for membership of these members in
    ``config.cuda_graph_scope`` instead of comparing raw strings.
    """

    # Whole-iteration capture.
    full_iteration = 1

    # Per-module scopes.
    attn = 2
    mlp = 3

    # The following three are only used for MoELayer.
    moe = 4
    moe_router = 5
    moe_preprocess = 6

    # Only used for MambaLayer.
    mamba = 7
8 changes: 8 additions & 0 deletions megatron/core/transformer/moe/fused_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,14 @@ def init_hybrid_ep_buffer(
)


def reset_hybrid_ep_buffer():
    '''
    Drop the module-level HybridEP buffer.

    Rebinds the global ``_hybrid_ep_buffer`` to ``None``; nothing is returned.
    '''
    global _hybrid_ep_buffer
    _hybrid_ep_buffer = None


class HybridEPDispatch(torch.autograd.Function):
'''
Fused dispatch operation for permute + dispatch a2a + permute using the HybridEP backend
Expand Down
7 changes: 4 additions & 3 deletions megatron/core/transformer/moe/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from megatron.core.fp8_utils import get_fp8_align_size
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import is_graph_capturing
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.transformer_config import TransformerConfig

try:
Expand Down Expand Up @@ -1205,13 +1206,13 @@ def maybe_raise_signal(moe_layer, **kwargs):
):
if (
step_condition == "route"
and 'moe_router' in moe_layer.config.cuda_graph_scope
and 'moe_preprocess' not in moe_layer.config.cuda_graph_scope
and CudaGraphScope.moe_router in moe_layer.config.cuda_graph_scope
and CudaGraphScope.moe_preprocess not in moe_layer.config.cuda_graph_scope
):
raise MoECudaGraphPartialCaptureSignal(moe_layer, "route", **kwargs)
elif (
step_condition == "preprocess"
and 'moe_preprocess' in moe_layer.config.cuda_graph_scope
and CudaGraphScope.moe_preprocess in moe_layer.config.cuda_graph_scope
):
raise MoECudaGraphPartialCaptureSignal(moe_layer, "preprocess", **kwargs)

Expand Down
8 changes: 5 additions & 3 deletions megatron/core/transformer/moe/token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
gather_from_sequence_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.fused_a2a import (
fused_combine,
fused_dispatch,
Expand Down Expand Up @@ -436,7 +437,7 @@ def __init__(
}
if (
config.cuda_graph_impl == "transformer_engine"
and 'moe_preprocess' in config.cuda_graph_scope
and CudaGraphScope.moe_preprocess in config.cuda_graph_scope
):
self.cuda_dtoh_point = "before_ep_alltoall"
else:
Expand Down Expand Up @@ -1077,8 +1078,9 @@ def combine(
)
# Release the used handle/num_permuted_tokens which could change in each iteration
self.handle = None
self.num_permuted_tokens = None
self.num_dispatched_tokens = None
if not self.drop_and_pad:
self.num_permuted_tokens = None
self.num_dispatched_tokens = None
return hidden_states

def get_permuted_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> torch.Tensor:
Expand Down
4 changes: 2 additions & 2 deletions megatron/core/transformer/transformer_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.enums import LayerType
from megatron.core.transformer.enums import CudaGraphScope, LayerType
from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
Expand Down Expand Up @@ -555,7 +555,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs):
kwargs.get('inference_context') is not None
or kwargs.get('inference_params') is not None
)
and 'full_iteration' in self.config.cuda_graph_scope
and CudaGraphScope.full_iteration in self.config.cuda_graph_scope
):
if kwargs['inference_context'].is_static_batching():
using_cuda_graph = kwargs['inference_context'].is_decode_only()
Expand Down
Loading
Loading