Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2328,6 +2328,14 @@ def fused_apply_rotary_pos_emb_thd(
fused_sort_chunks_by_index_with_probs = None
fused_unpermute = None

try:
from transformer_engine.pytorch.permutation import moe_permute_and_pad_with_probs

fused_permute_and_pad_with_probs = moe_permute_and_pad_with_probs

except ImportError:
fused_permute_and_pad_with_probs = None

try:
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

Expand Down
25 changes: 20 additions & 5 deletions megatron/core/transformer/moe/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
fused_compute_score_for_moe_aux_loss,
fused_moe_aux_loss,
fused_permute,
fused_permute_and_pad_with_probs,
fused_permute_with_probs,
fused_sort_chunks_by_index,
fused_sort_chunks_by_index_with_probs,
Expand Down Expand Up @@ -227,6 +228,8 @@ def permute(
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
tokens_per_expert: Optional[torch.Tensor] = None,
align_size: int = -1,
):
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
Expand Down Expand Up @@ -257,11 +260,18 @@ def permute(
return permuted_input, None, sorted_indices

if fused and probs is not None:
if not HAVE_TE or fused_permute_with_probs is None:
raise ValueError(
"fused_permute_with_probs is not available. Please install TE >= 2.1.0."
if tokens_per_expert is not None and align_size > 0:
Comment thread
xiaoxi-wangfj marked this conversation as resolved.
Outdated
return fused_permute_and_pad_with_probs(
tokens, probs, routing_map, tokens_per_expert, align_size
)
else:
if not HAVE_TE or fused_permute_with_probs is None:
raise ValueError(
"fused_permute_with_probs is not available. Please install TE >= 2.1.0."
)
return fused_permute_with_probs(
tokens, probs, routing_map, num_out_tokens=num_out_tokens
)
return fused_permute_with_probs(tokens, probs, routing_map, num_out_tokens=num_out_tokens)

num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]
Expand Down Expand Up @@ -315,6 +325,7 @@ def unpermute(
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False,
pad_offsets: Optional[torch.Tensor] = None,
Comment thread
xiaoxi-wangfj marked this conversation as resolved.
):
"""
Restore the original order of tokens after permutation. If probs are provided, it
Expand Down Expand Up @@ -344,7 +355,11 @@ def unpermute(
if not HAVE_TE or fused_unpermute is None:
raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
return fused_unpermute(
permuted_tokens, sorted_indices, merging_probs=probs, restore_shape=restore_shape
permuted_tokens,
sorted_indices,
merging_probs=probs,
restore_shape=restore_shape,
pad_offsets=pad_offsets,
Comment thread
xiaoxi-wangfj marked this conversation as resolved.
Outdated
)

_, hidden = restore_shape
Expand Down
33 changes: 26 additions & 7 deletions megatron/core/transformer/moe/token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,13 +1258,29 @@ def get_permuted_hidden_states_by_experts(self, hidden_states: torch.Tensor) ->

self.hidden_shape_before_permute = hidden_states.shape
assert self.dispatched_probs.dtype == torch.float32, "DeepEP only supports float32 probs"
hidden_states, permuted_probs, self.reversed_mapping_for_combine = permute(
hidden_states,
self.dispatched_routing_map,
probs=self.dispatched_probs,
num_out_tokens=self.tokens_per_expert.sum().item(),
fused=self.permute_fusion,
)
if self.config.moe_permute_padding_for_quantization:
(
hidden_states,
permuted_probs,
self.reversed_mapping_for_combine,
self.pad_offsets,
self.tokens_per_expert,
) = permute(
hidden_states,
self.dispatched_routing_map,
probs=self.dispatched_probs,
fused=self.permute_fusion,
tokens_per_expert=self.tokens_per_expert,
align_size=get_align_size_for_quantization(self.config),
)
else:
hidden_states, permuted_probs, self.reversed_mapping_for_combine = permute(
hidden_states,
self.dispatched_routing_map,
probs=self.dispatched_probs,
num_out_tokens=self.tokens_per_expert.sum().item(),
fused=self.permute_fusion,
)
if self.router_dtype == "fp64":
permuted_probs = permuted_probs.to(torch.float64)
return hidden_states, permuted_probs
Expand All @@ -1276,6 +1292,9 @@ def get_restored_hidden_states_by_experts(self, hidden_states: torch.Tensor) ->
restore_shape=self.hidden_shape_before_permute,
routing_map=self.dispatched_routing_map,
fused=self.permute_fusion,
pad_offsets=(
self.pad_offsets if self.config.moe_permute_padding_for_quantization else None
),
)
return hidden_states

Expand Down
22 changes: 22 additions & 0 deletions megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,12 @@ class TransformerConfig(ModelParallelConfig):
"""[Compatibility alias for moe_router_padding_for_quantization]
Enabling this will also enable moe_router_padding_for_quantization."""

moe_permute_padding_for_quantization: Optional[bool] = False
Comment thread
xiaoxi-wangfj marked this conversation as resolved.
Outdated
"""Enable padding during MoE token permutation and corresponding unpadding during unpermutation
so that the number of tokens in each expert's permuted block is aligned to a multiple of 16 / 32
for quantized precisions such as FP8 and FP4. This can remove explicit padding/
unpadding around GroupedMLP kernels, which improves throughput and reduces peak memory usage."""

moe_router_num_groups: Optional[int] = None
"""Number of groups to divide experts into for group-limited routing.
When using group-limited routing:
Expand Down Expand Up @@ -1419,6 +1425,22 @@ def __post_init__(self):
"moe_router_padding_for_quantization."
)

if self.moe_permute_padding_for_quantization:
if self.fp8 is None and self.fp4 is None:
raise ValueError(
"moe_permute_padding_for_quantization requires a quantized precision recipe(e.g., fp8 or fp4) to be enabled."
)

if not self.moe_permute_fusion:
raise ValueError(
"moe_permute_padding_for_quantization currently requires fused permute."
)

from megatron.core.transformer.moe.moe_utils import fused_permute_and_pad_with_probs

if fused_permute_and_pad_with_probs is None:
raise ValueError("fused_permute_and_pad_with_probs is not available. Please install TE >= 2.12.0.")

if (
self.moe_router_topk == 1
and self.moe_router_score_function == "softmax"
Expand Down
4 changes: 4 additions & 0 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3207,6 +3207,10 @@ def _add_moe_args(parser):
group.add_argument('--moe-router-padding-for-fp8', action='store_true',
help='[Compatibility alias for --moe-router-padding-for-quantization] '
'Enabling this will also enable --moe-router-padding-for-quantization.')
group.add_argument('--moe-permute-padding-for-quantization', action='store_true',
Comment thread
xiaoxi-wangfj marked this conversation as resolved.
Outdated
help='Enable padding during MoE token permutation (and unpadding during unpermutation) '
'so that the number of tokens in each expert permuted block is aligned to a multiple of 16/32 for FP8/FP4 precision. '
'This can remove explicit padding/unpadding around GroupedMLP, which improves throughput and reduces peak memory usage.')
group.add_argument('--moe-aux-loss-coeff', type=float, nargs='+', default=0.0,
help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
group.add_argument('--moe-z-loss-coeff', type=float, default=None,
Expand Down
102 changes: 101 additions & 1 deletion tests/unit_tests/transformer/moe/test_token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import torch

from megatron.core import config, parallel_state
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.moe_utils import get_capacity
from megatron.core.transformer.transformer_config import TransformerConfig
Expand Down Expand Up @@ -343,6 +346,75 @@ def dispatcher_router_padding_for_fp8_test(self):
grad_1, hidden_states.grad
), "Gradients do not match between padded and non-padded versions"

@pytest.mark.internal
def dispatcher_permute_padding_for_quantization_test(self):
"""Test permute padding behavior for FP8 training.

Run the dispatch flow twice with identical routing:
1) moe_permute_padding_for_quantization = False
-> naive permute + quantization padding/unpadding
2) moe_permute_padding_for_quantization = True
-> fused permute+pad and fused unpermute+unpad

"""
self.config.fp8 = "hybrid"
self.config.moe_grouped_gemm = True
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
num_experts=self.config.num_moe_experts, moe_grouped_gemm=self.config.moe_grouped_gemm
)
moe_layer = (
MoELayer(self.config, transformer_layer_spec.submodules.mlp.submodules)
.cuda()
.to(dtype=self.test_dtype)
)
moe_layer.set_layer_number(0)
num_tokens = 32
hidden_states = torch.randn(
(num_tokens, moe_layer.config.hidden_size), dtype=self.test_dtype
).cuda()
hidden_states.requires_grad = True
probs, indices = moe_layer.router(hidden_states)

# First run with moe_permute_padding_for_quantization = False, navie permute + pad
moe_layer.config.moe_permute_padding_for_quantization = False
(permuted_input_1, tokens_per_expert_1, permuted_probs_1) = token_permutation(
moe_layer.token_dispatcher, hidden_states, probs, indices
)
actual_tokens_per_expert = tokens_per_expert_1.tolist()
permuted_input_paded_1, tokens_per_expert_1 = moe_layer.experts.quantization_padding(
permuted_input_1, actual_tokens_per_expert
)
permuted_probs_paded_1, _ = moe_layer.experts.quantization_padding(
permuted_probs_1.unsqueeze(-1), actual_tokens_per_expert
)

restored_hidden_states_1 = moe_layer.experts.quantization_unpadding(
permuted_input_paded_1, actual_tokens_per_expert
)
restored_hidden_states_1, _ = token_unpermutation(
moe_layer.token_dispatcher, restored_hidden_states_1
)

# Run with moe_permute_padding_for_quantization = True, Fuse permute + pad
moe_layer.config.moe_permute_padding_for_quantization = True
(permuted_input_paded_2, tokens_per_expert_paded_2, permuted_probs_paded_2) = (
token_permutation(moe_layer.token_dispatcher, hidden_states, probs, indices)
)
restored_hidden_states_2, _ = token_unpermutation(
moe_layer.token_dispatcher, permuted_input_paded_2
)

# Check that the results are the same
torch.testing.assert_close(
permuted_input_paded_2, permuted_input_paded_1
), "permuted hidden states do not match between between permute+pad and fused_permute_pad versions"
torch.testing.assert_close(
permuted_probs_paded_2.unsqueeze(-1), permuted_probs_paded_1
), "permuted probs do not match between between permute+pad and fused_permute_pad versions"
torch.testing.assert_close(
restored_hidden_states_1, restored_hidden_states_2
), "Restored hidden states do not match between unpermute+unpad and fused_unpermute_unpad versions"

def set_params(self):
# TODO: Set consistent parameters for various parallelisms.
raise NotImplementedError
Expand Down Expand Up @@ -525,3 +597,31 @@ def test_router_padding_for_fp8_forward_backward(
)
container.dispatcher_router_padding_for_fp8_test()
config.ENABLE_EXPERIMENTAL = False

@pytest.mark.skipif(
not is_te_min_version("2.12.0"),
reason="TE 2.12.0 is required for MoE fused_permute_pad with FP8.",
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.internal
@pytest.mark.timeout(120)
@pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2)])
@pytest.mark.parametrize("moe_flex_dispatcher_backend", ["deepep"])
def test_permute_padding_for_quantization(self, tp_size, ep_size, moe_flex_dispatcher_backend):
if moe_flex_dispatcher_backend == "deepep" and not is_deep_ep_available():
pytest.skip("Deep EP is not available")
container = MoEModelTestContainer(
tp_size=tp_size,
ep_size=ep_size,
pp_size=1,
num_moe_experts=32,
moe_router_topk=4,
moe_router_load_balancing_type="aux_loss",
moe_token_dispatcher_type="flex",
moe_pad_expert_input_to_capacity=False,
moe_permute_fusion=True,
hidden_size=1024,
moe_flex_dispatcher_backend=moe_flex_dispatcher_backend,
test_dtype=torch.bfloat16,
)
container.dispatcher_permute_padding_for_quantization_test()