From 3f261d7e7ffe90dcabae33eeea31f65eabc7662a Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Fri, 12 Dec 2025 05:46:22 +0000 Subject: [PATCH 1/6] Fuse permute+pad and unpermute+unpad ops for FP8/FP4 precision This can remove explicit padding/unpadding around GroupedMLP, which improves throughput and reduces peak memory usage Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --- .../core/extensions/transformer_engine.py | 8 ++ megatron/core/transformer/moe/moe_utils.py | 25 ++++- .../core/transformer/moe/token_dispatcher.py | 33 ++++-- .../core/transformer/transformer_config.py | 22 ++++ megatron/training/arguments.py | 4 + .../transformer/moe/test_token_dispatcher.py | 102 +++++++++++++++++- 6 files changed, 181 insertions(+), 13 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index bf5b228b6c9..42dbdec0790 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2328,6 +2328,14 @@ def fused_apply_rotary_pos_emb_thd( fused_sort_chunks_by_index_with_probs = None fused_unpermute = None +try: + from transformer_engine.pytorch.permutation import moe_permute_and_pad_with_probs + + fused_permute_and_pad_with_probs = moe_permute_and_pad_with_probs + +except ImportError: + fused_permute_and_pad_with_probs = None + try: from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index f8b7d234fff..4f8cd314562 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -19,6 +19,7 @@ fused_compute_score_for_moe_aux_loss, fused_moe_aux_loss, fused_permute, + fused_permute_and_pad_with_probs, fused_permute_with_probs, fused_sort_chunks_by_index, fused_sort_chunks_by_index_with_probs, @@ -227,6 +228,8 @@ def permute( num_out_tokens: Optional[int] = None, fused: bool = False, drop_and_pad: bool = False, + tokens_per_expert: Optional[torch.Tensor] = None, + align_size: int = -1, ): """Permute the tokens and probs based on the mask. Tokens with the same designated expert will be grouped together. @@ -257,11 +260,18 @@ def permute( return permuted_input, None, sorted_indices if fused and probs is not None: - if not HAVE_TE or fused_permute_with_probs is None: - raise ValueError( - "fused_permute_with_probs is not available. Please install TE >= 2.1.0." + if tokens_per_expert is not None and align_size > 0: + return fused_permute_and_pad_with_probs( + tokens, probs, routing_map, tokens_per_expert, align_size + ) + else: + if not HAVE_TE or fused_permute_with_probs is None: + raise ValueError( + "fused_permute_with_probs is not available. Please install TE >= 2.1.0." + ) + return fused_permute_with_probs( + tokens, probs, routing_map, num_out_tokens=num_out_tokens ) - return fused_permute_with_probs(tokens, probs, routing_map, num_out_tokens=num_out_tokens) num_tokens, hidden = tokens.shape num_experts = routing_map.shape[1] @@ -315,6 +325,7 @@ def unpermute( routing_map: torch.Tensor = None, fused: bool = False, drop_and_pad: bool = False, + pad_offsets: Optional[torch.Tensor] = None, ): """ Restore the original order of tokens after permutation. If probs are provided, it @@ -344,7 +355,11 @@ def unpermute( if not HAVE_TE or fused_unpermute is None: raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.") return fused_unpermute( - permuted_tokens, sorted_indices, merging_probs=probs, restore_shape=restore_shape + permuted_tokens, + sorted_indices, + merging_probs=probs, + restore_shape=restore_shape, + pad_offsets=pad_offsets, ) _, hidden = restore_shape diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 0beae556cf7..bfbd26f63fb 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1258,13 +1258,29 @@ def get_permuted_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> self.hidden_shape_before_permute = hidden_states.shape assert self.dispatched_probs.dtype == torch.float32, "DeepEP only supports float32 probs" - hidden_states, permuted_probs, self.reversed_mapping_for_combine = permute( - hidden_states, - self.dispatched_routing_map, - probs=self.dispatched_probs, - num_out_tokens=self.tokens_per_expert.sum().item(), - fused=self.permute_fusion, - ) + if self.config.moe_permute_padding_for_quantization: + ( + hidden_states, + permuted_probs, + self.reversed_mapping_for_combine, + self.pad_offsets, + self.tokens_per_expert, + ) = permute( + hidden_states, + self.dispatched_routing_map, + probs=self.dispatched_probs, + fused=self.permute_fusion, + tokens_per_expert=self.tokens_per_expert, + align_size=get_align_size_for_quantization(self.config), + ) + else: + hidden_states, permuted_probs, self.reversed_mapping_for_combine = permute( + hidden_states, + self.dispatched_routing_map, + probs=self.dispatched_probs, + num_out_tokens=self.tokens_per_expert.sum().item(), + fused=self.permute_fusion, + ) if self.router_dtype == "fp64": permuted_probs = permuted_probs.to(torch.float64) return hidden_states, permuted_probs @@ -1276,6 +1292,9 @@ def get_restored_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> restore_shape=self.hidden_shape_before_permute, routing_map=self.dispatched_routing_map, fused=self.permute_fusion, + pad_offsets=( + self.pad_offsets if self.config.moe_permute_padding_for_quantization else None + ), ) return hidden_states diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 434dbcd1a99..17d30e75403 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -496,6 +496,12 @@ class TransformerConfig(ModelParallelConfig): """[Compatibility alias for moe_router_padding_for_quantization] Enabling this will also enable moe_router_padding_for_quantization.""" + moe_permute_padding_for_quantization: Optional[bool] = False + """Enable padding during MoE token permutation and corresponding unpadding during unpermutation + so that the number of tokens in each expert's permuted block is aligned to a multiple of 16 / 32 + for quantized precisions such as FP8 and FP4. This can remove explicit padding/ + unpadding around GroupedMLP kernels, which improves throughput and reduces peak memory usage.""" + moe_router_num_groups: Optional[int] = None """Number of groups to divide experts into for group-limited routing. When using group-limited routing: @@ -1419,6 +1425,22 @@ def __post_init__(self): "moe_router_padding_for_quantization." ) + if self.moe_permute_padding_for_quantization: + if self.fp8 is None and self.fp4 is None: + raise ValueError( + "moe_permute_padding_for_quantization requires a quantized precision recipe(e.g., fp8 or fp4) to be enabled." + ) + + if not self.moe_permute_fusion: + raise ValueError( + "moe_permute_padding_for_quantization currently requires fused permute." + ) + + from megatron.core.transformer.moe.moe_utils import fused_permute_and_pad_with_probs + + if fused_permute_and_pad_with_probs is None: + raise ValueError("fused_permute_and_pad_with_probs is not available. Please install TE >= 2.12.0.") + if ( self.moe_router_topk == 1 and self.moe_router_score_function == "softmax" diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 649fe442f45..346f7d4f339 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -3207,6 +3207,10 @@ def _add_moe_args(parser): group.add_argument('--moe-router-padding-for-fp8', action='store_true', help='[Compatibility alias for --moe-router-padding-for-quantization] ' 'Enabling this will also enable --moe-router-padding-for-quantization.') + group.add_argument('--moe-permute-padding-for-quantization', action='store_true', + help='Enable padding during MoE token permutation (and unpadding during unpermutation) ' + 'so that the number of tokens in each expert permuted block is aligned to a multiple of 16/32 for FP8/FP4 precision. ' + 'This can remove explicit padding/unpadding around GroupedMLP, which improves throughput and reduces peak memory usage.') group.add_argument('--moe-aux-loss-coeff', type=float, nargs='+', default=0.0, help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.') group.add_argument('--moe-z-loss-coeff', type=float, default=None, diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py index c2462ef73ad..73bfe5185e7 100644 --- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_token_dispatcher.py @@ -7,7 +7,10 @@ import torch from megatron.core import config, parallel_state -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.moe_utils import get_capacity from megatron.core.transformer.transformer_config import TransformerConfig @@ -343,6 +346,75 @@ def dispatcher_router_padding_for_fp8_test(self): grad_1, hidden_states.grad ), "Gradients do not match between padded and non-padded versions" + @pytest.mark.internal + def dispatcher_permute_padding_for_quantization_test(self): + """Test permute padding behavior for FP8 training. + + Run the dispatch flow twice with identical routing: + 1) moe_permute_padding_for_quantization = False + -> naive permute + quantization padding/unpadding + 2) moe_permute_padding_for_quantization = True + -> fused permute+pad and fused unpermute+unpad + + """ + self.config.fp8 = "hybrid" + self.config.moe_grouped_gemm = True + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=self.config.num_moe_experts, moe_grouped_gemm=self.config.moe_grouped_gemm + ) + moe_layer = ( + MoELayer(self.config, transformer_layer_spec.submodules.mlp.submodules) + .cuda() + .to(dtype=self.test_dtype) + ) + moe_layer.set_layer_number(0) + num_tokens = 32 + hidden_states = torch.randn( + (num_tokens, moe_layer.config.hidden_size), dtype=self.test_dtype + ).cuda() + hidden_states.requires_grad = True + probs, indices = moe_layer.router(hidden_states) + + # First run with moe_permute_padding_for_quantization = False, navie permute + pad + moe_layer.config.moe_permute_padding_for_quantization = False + (permuted_input_1, tokens_per_expert_1, permuted_probs_1) = token_permutation( + moe_layer.token_dispatcher, hidden_states, probs, indices + ) + actual_tokens_per_expert = tokens_per_expert_1.tolist() + permuted_input_paded_1, tokens_per_expert_1 = moe_layer.experts.quantization_padding( + permuted_input_1, actual_tokens_per_expert + ) + permuted_probs_paded_1, _ = moe_layer.experts.quantization_padding( + permuted_probs_1.unsqueeze(-1), actual_tokens_per_expert + ) + + restored_hidden_states_1 = moe_layer.experts.quantization_unpadding( + permuted_input_paded_1, actual_tokens_per_expert + ) + restored_hidden_states_1, _ = token_unpermutation( + moe_layer.token_dispatcher, restored_hidden_states_1 + ) + + # Run with moe_permute_padding_for_quantization = True, Fuse permute + pad + moe_layer.config.moe_permute_padding_for_quantization = True + (permuted_input_paded_2, tokens_per_expert_paded_2, permuted_probs_paded_2) = ( + token_permutation(moe_layer.token_dispatcher, hidden_states, probs, indices) + ) + restored_hidden_states_2, _ = token_unpermutation( + moe_layer.token_dispatcher, permuted_input_paded_2 + ) + + # Check that the results are the same + torch.testing.assert_close( + permuted_input_paded_2, permuted_input_paded_1 + ), "permuted hidden states do not match between between permute+pad and fused_permute_pad versions" + torch.testing.assert_close( + permuted_probs_paded_2.unsqueeze(-1), permuted_probs_paded_1 + ), "permuted probs do not match between between permute+pad and fused_permute_pad versions" + torch.testing.assert_close( + restored_hidden_states_1, restored_hidden_states_2 + ), "Restored hidden states do not match between unpermute+unpad and fused_unpermute_unpad versions" + def set_params(self): # TODO: Set consistent parameters for various parallelisms. raise NotImplementedError @@ -525,3 +597,31 @@ def test_router_padding_for_fp8_forward_backward( ) container.dispatcher_router_padding_for_fp8_test() config.ENABLE_EXPERIMENTAL = False + + @pytest.mark.skipif( + not is_te_min_version("2.12.0"), + reason="TE 2.12.0 is required for MoE fused_permute_pad with FP8.", + ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + @pytest.mark.timeout(120) + @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2)]) + @pytest.mark.parametrize("moe_flex_dispatcher_backend", ["deepep"]) + def test_permute_padding_for_quantization(self, tp_size, ep_size, moe_flex_dispatcher_backend): + if moe_flex_dispatcher_backend == "deepep" and not is_deep_ep_available(): + pytest.skip("Deep EP is not available") + container = MoEModelTestContainer( + tp_size=tp_size, + ep_size=ep_size, + pp_size=1, + num_moe_experts=32, + moe_router_topk=4, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="flex", + moe_pad_expert_input_to_capacity=False, + moe_permute_fusion=True, + hidden_size=1024, + moe_flex_dispatcher_backend=moe_flex_dispatcher_backend, + test_dtype=torch.bfloat16, + ) + container.dispatcher_permute_padding_for_quantization_test() From 76aff2e29cb4c7329db6da6d0a8602d8ae50dd99 Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Tue, 6 Jan 2026 09:52:26 +0000 Subject: [PATCH 2/6] set fused_permute_pad to default Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --- megatron/core/transformer/moe/moe_utils.py | 25 +++-- .../core/transformer/moe/token_dispatcher.py | 42 +++----- .../core/transformer/transformer_config.py | 22 ---- megatron/training/arguments.py | 4 - .../transformer/moe/test_token_dispatcher.py | 102 +----------------- 5 files changed, 36 insertions(+), 159 deletions(-) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 4f8cd314562..17e5acda81d 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -250,6 +250,10 @@ def permute( and pads the number of tokens to the expert capacity. If set to true, routing_map has a fixed number of non-zeros in each column. + tokens_per_expert : torch.Tensor + Tensor of shape `[num_experts]` containing actual token counts per expert. + align_size : int + the alignment size for the input tensor for fp8 or fp4. """ if fused and probs is None: if not HAVE_TE or fused_permute is None: @@ -260,18 +264,23 @@ def permute( return permuted_input, None, sorted_indices if fused and probs is not None: - if tokens_per_expert is not None and align_size > 0: + if not HAVE_TE or ( + fused_permute_and_pad_with_probs is None and fused_permute_with_probs is None + ): + raise ValueError( + "Transformer Engine (TE) fused kernel is not available. " + "fused_permute_with_probs typically requires TE >= 2.1.0, and " + "fused_permute_and_pad_with_probs` typically requires TE >= 2.12.0. " + ) + if fused_permute_and_pad_with_probs is not None: return fused_permute_and_pad_with_probs( tokens, probs, routing_map, tokens_per_expert, align_size ) else: - if not HAVE_TE or fused_permute_with_probs is None: - raise ValueError( - "fused_permute_with_probs is not available. Please install TE >= 2.1.0." - ) - return fused_permute_with_probs( + output, permuted_probs, row_id_map = fused_permute_with_probs( tokens, probs, routing_map, num_out_tokens=num_out_tokens ) + return output, permuted_probs, row_id_map, None, tokens_per_expert num_tokens, hidden = tokens.shape num_experts = routing_map.shape[1] @@ -347,6 +356,10 @@ def unpermute( fused (bool, optional): Whether use the fused unpermute function. drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop and pads the number of tokens to the expert capacity. + pad_offsets : torch.Tensor, default = None + Tensor of per-expert cumulative padding offsets used to remove padding added + during permutation. This is the fourth output of `moe_permute_and_pad_with_probs` + and is required when unpermuting padded outputs. Returns: torch.Tensor: The tokens restored to their original order. diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index bfbd26f63fb..6952d4165cf 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1258,29 +1258,21 @@ def get_permuted_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> self.hidden_shape_before_permute = hidden_states.shape assert self.dispatched_probs.dtype == torch.float32, "DeepEP only supports float32 probs" - if self.config.moe_permute_padding_for_quantization: - ( - hidden_states, - permuted_probs, - self.reversed_mapping_for_combine, - self.pad_offsets, - self.tokens_per_expert, - ) = permute( - hidden_states, - self.dispatched_routing_map, - probs=self.dispatched_probs, - fused=self.permute_fusion, - tokens_per_expert=self.tokens_per_expert, - align_size=get_align_size_for_quantization(self.config), - ) - else: - hidden_states, permuted_probs, self.reversed_mapping_for_combine = permute( - hidden_states, - self.dispatched_routing_map, - probs=self.dispatched_probs, - num_out_tokens=self.tokens_per_expert.sum().item(), - fused=self.permute_fusion, - ) + ( + hidden_states, + permuted_probs, + self.reversed_mapping_for_combine, + self.pad_offsets, + self.tokens_per_expert, + ) = permute( + hidden_states, + self.dispatched_routing_map, + probs=self.dispatched_probs, + num_out_tokens=self.tokens_per_expert.sum().item(), + fused=self.permute_fusion, + tokens_per_expert=self.tokens_per_expert, + align_size=get_align_size_for_quantization(self.config), + ) if self.router_dtype == "fp64": permuted_probs = permuted_probs.to(torch.float64) return hidden_states, permuted_probs @@ -1292,9 +1284,7 @@ def get_restored_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> restore_shape=self.hidden_shape_before_permute, routing_map=self.dispatched_routing_map, fused=self.permute_fusion, - pad_offsets=( - self.pad_offsets if self.config.moe_permute_padding_for_quantization else None - ), + pad_offsets=self.pad_offsets, ) return hidden_states diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 17d30e75403..434dbcd1a99 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -496,12 +496,6 @@ class TransformerConfig(ModelParallelConfig): """[Compatibility alias for moe_router_padding_for_quantization] Enabling this will also enable moe_router_padding_for_quantization.""" - moe_permute_padding_for_quantization: Optional[bool] = False - """Enable padding during MoE token permutation and corresponding unpadding during unpermutation - so that the number of tokens in each expert's permuted block is aligned to a multiple of 16 / 32 - for quantized precisions such as FP8 and FP4. This can remove explicit padding/ - unpadding around GroupedMLP kernels, which improves throughput and reduces peak memory usage.""" - moe_router_num_groups: Optional[int] = None """Number of groups to divide experts into for group-limited routing. When using group-limited routing: @@ -1425,22 +1419,6 @@ def __post_init__(self): "moe_router_padding_for_quantization." ) - if self.moe_permute_padding_for_quantization: - if self.fp8 is None and self.fp4 is None: - raise ValueError( - "moe_permute_padding_for_quantization requires a quantized precision recipe(e.g., fp8 or fp4) to be enabled." - ) - - if not self.moe_permute_fusion: - raise ValueError( - "moe_permute_padding_for_quantization currently requires fused permute." - ) - - from megatron.core.transformer.moe.moe_utils import fused_permute_and_pad_with_probs - - if fused_permute_and_pad_with_probs is None: - raise ValueError("fused_permute_and_pad_with_probs is not available. Please install TE >= 2.12.0.") - if ( self.moe_router_topk == 1 and self.moe_router_score_function == "softmax" diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 346f7d4f339..649fe442f45 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -3207,10 +3207,6 @@ def _add_moe_args(parser): group.add_argument('--moe-router-padding-for-fp8', action='store_true', help='[Compatibility alias for --moe-router-padding-for-quantization] ' 'Enabling this will also enable --moe-router-padding-for-quantization.') - group.add_argument('--moe-permute-padding-for-quantization', action='store_true', - help='Enable padding during MoE token permutation (and unpadding during unpermutation) ' - 'so that the number of tokens in each expert permuted block is aligned to a multiple of 16/32 for FP8/FP4 precision. ' - 'This can remove explicit padding/unpadding around GroupedMLP, which improves throughput and reduces peak memory usage.') group.add_argument('--moe-aux-loss-coeff', type=float, nargs='+', default=0.0, help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.') group.add_argument('--moe-z-loss-coeff', type=float, default=None, diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py index 73bfe5185e7..c2462ef73ad 100644 --- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_token_dispatcher.py @@ -7,10 +7,7 @@ import torch from megatron.core import config, parallel_state -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.moe_utils import get_capacity from megatron.core.transformer.transformer_config import TransformerConfig @@ -346,75 +343,6 @@ def dispatcher_router_padding_for_fp8_test(self): grad_1, hidden_states.grad ), "Gradients do not match between padded and non-padded versions" - @pytest.mark.internal - def dispatcher_permute_padding_for_quantization_test(self): - """Test permute padding behavior for FP8 training. - - Run the dispatch flow twice with identical routing: - 1) moe_permute_padding_for_quantization = False - -> naive permute + quantization padding/unpadding - 2) moe_permute_padding_for_quantization = True - -> fused permute+pad and fused unpermute+unpad - - """ - self.config.fp8 = "hybrid" - self.config.moe_grouped_gemm = True - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=self.config.num_moe_experts, moe_grouped_gemm=self.config.moe_grouped_gemm - ) - moe_layer = ( - MoELayer(self.config, transformer_layer_spec.submodules.mlp.submodules) - .cuda() - .to(dtype=self.test_dtype) - ) - moe_layer.set_layer_number(0) - num_tokens = 32 - hidden_states = torch.randn( - (num_tokens, moe_layer.config.hidden_size), dtype=self.test_dtype - ).cuda() - hidden_states.requires_grad = True - probs, indices = moe_layer.router(hidden_states) - - # First run with moe_permute_padding_for_quantization = False, navie permute + pad - moe_layer.config.moe_permute_padding_for_quantization = False - (permuted_input_1, tokens_per_expert_1, permuted_probs_1) = token_permutation( - moe_layer.token_dispatcher, hidden_states, probs, indices - ) - actual_tokens_per_expert = tokens_per_expert_1.tolist() - permuted_input_paded_1, tokens_per_expert_1 = moe_layer.experts.quantization_padding( - permuted_input_1, actual_tokens_per_expert - ) - permuted_probs_paded_1, _ = moe_layer.experts.quantization_padding( - permuted_probs_1.unsqueeze(-1), actual_tokens_per_expert - ) - - restored_hidden_states_1 = moe_layer.experts.quantization_unpadding( - permuted_input_paded_1, actual_tokens_per_expert - ) - restored_hidden_states_1, _ = token_unpermutation( - moe_layer.token_dispatcher, restored_hidden_states_1 - ) - - # Run with moe_permute_padding_for_quantization = True, Fuse permute + pad - moe_layer.config.moe_permute_padding_for_quantization = True - (permuted_input_paded_2, tokens_per_expert_paded_2, permuted_probs_paded_2) = ( - token_permutation(moe_layer.token_dispatcher, hidden_states, probs, indices) - ) - restored_hidden_states_2, _ = token_unpermutation( - moe_layer.token_dispatcher, permuted_input_paded_2 - ) - - # Check that the results are the same - torch.testing.assert_close( - permuted_input_paded_2, permuted_input_paded_1 - ), "permuted hidden states do not match between between permute+pad and fused_permute_pad versions" - torch.testing.assert_close( - permuted_probs_paded_2.unsqueeze(-1), permuted_probs_paded_1 - ), "permuted probs do not match between between permute+pad and fused_permute_pad versions" - torch.testing.assert_close( - restored_hidden_states_1, restored_hidden_states_2 - ), "Restored hidden states do not match between unpermute+unpad and fused_unpermute_unpad versions" - def set_params(self): # TODO: Set consistent parameters for various parallelisms. raise NotImplementedError @@ -597,31 +525,3 @@ def test_router_padding_for_fp8_forward_backward( ) container.dispatcher_router_padding_for_fp8_test() config.ENABLE_EXPERIMENTAL = False - - @pytest.mark.skipif( - not is_te_min_version("2.12.0"), - reason="TE 2.12.0 is required for MoE fused_permute_pad with FP8.", - ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - @pytest.mark.timeout(120) - @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2)]) - @pytest.mark.parametrize("moe_flex_dispatcher_backend", ["deepep"]) - def test_permute_padding_for_quantization(self, tp_size, ep_size, moe_flex_dispatcher_backend): - if moe_flex_dispatcher_backend == "deepep" and not is_deep_ep_available(): - pytest.skip("Deep EP is not available") - container = MoEModelTestContainer( - tp_size=tp_size, - ep_size=ep_size, - pp_size=1, - num_moe_experts=32, - moe_router_topk=4, - moe_router_load_balancing_type="aux_loss", - moe_token_dispatcher_type="flex", - moe_pad_expert_input_to_capacity=False, - moe_permute_fusion=True, - hidden_size=1024, - moe_flex_dispatcher_backend=moe_flex_dispatcher_backend, - test_dtype=torch.bfloat16, - ) - container.dispatcher_permute_padding_for_quantization_test() From 6d46c57de222b612c76770bf4ff5eef5ff05b112 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Wed, 14 Jan 2026 16:01:42 +0800 Subject: [PATCH 3/6] Update moe_utils.py --- megatron/core/transformer/moe/moe_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index ab1d7cf66c0..31916ec2d78 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -14,7 +14,7 @@ from megatron.core.transformer.cuda_graphs import is_graph_capturing from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import internal_api +from megatron.core.utils import internal_api, is_te_min_version try: import transformer_engine as te # pylint: disable=unused-import @@ -371,12 +371,15 @@ def unpermute( if fused: if not HAVE_TE or fused_unpermute is None: raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.") + extra_kwargs = {} + if is_te_min_version("2.12.0"): + extra_kwargs["pad_offsets"] = pad_offsets return fused_unpermute( permuted_tokens, sorted_indices, merging_probs=probs, restore_shape=restore_shape, - pad_offsets=pad_offsets, + **extra_kwargs, ) _, hidden = restore_shape From b96bd1eb254711eec9cbd059ddfee645feab36c1 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 26 Jan 2026 11:27:43 +0800 Subject: [PATCH 4/6] Update token_dispatcher.py --- megatron/core/transformer/moe/token_dispatcher.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index c8a59ecce88..68fdf1a247c 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -294,11 +294,13 @@ def dispatch_postprocess(self, hidden_states, probs): tokens_per_expert = self.local_map.sum(dim=0).long().cpu() - (permuted_local_hidden_states, _, self.reversed_local_input_permutation_mapping) = permute( - hidden_states, - self.local_map, - num_out_tokens=tokens_per_expert.sum(), - fused=self.config.moe_permute_fusion, + permuted_local_hidden_states, _, self.reversed_local_input_permutation_mapping, _, _ = ( + permute( + hidden_states, + self.local_map, + num_out_tokens=tokens_per_expert.sum(), + fused=self.config.moe_permute_fusion, + ) ) self.local_probs = self.local_probs.T.contiguous().masked_select( @@ -634,6 +636,8 @@ def dispatch_preprocess( permutated_local_input_tokens, permuted_probs, self.reversed_local_input_permutation_mapping, + _, + _, ) = permute( hidden_states, self.routing_map, From db81c3a9fb2e69ecc8644669eb13ca612517a258 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 26 Jan 2026 11:28:03 +0800 Subject: [PATCH 5/6] Update moe_utils.py --- megatron/core/transformer/moe/moe_utils.py | 39 ++++++++++++++-------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 0700a2432f0..8aab966850c 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -248,8 +248,13 @@ def permute( drop_and_pad: bool = False, tokens_per_expert: Optional[torch.Tensor] = None, align_size: int = -1, -): -) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: +) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], +]: """Permute the tokens and probs based on the mask. Tokens with the same designated expert will be grouped together. The shape of mask is [tokens, num_experts], it indicates which experts were selected @@ -258,6 +263,9 @@ def permute( When drop_and_pad=True, in routing_map, the number of non-zeros in each column equals to expert capacity. This function exploits this feature to use ops that support cuda graph. + If the fused permute and pad kernel is available, it will pad the tokens to the align_size + and return the padded permuted tokens, pad_offsets and padded tokens per expert. + Args: tokens (torch.Tensor): The input token tensor, [num_tokens, hidden]. routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts]. @@ -269,14 +277,20 @@ def permute( and pads the number of tokens to the expert capacity. If set to true, routing_map has a fixed number of non-zeros in each column. - tokens_per_expert : torch.Tensor - Tensor of shape `[num_experts]` containing actual token counts per expert. - align_size : int - the alignment size for the input tensor for fp8 or fp4. + tokens_per_expert (torch.Tensor, optional): Tensor of shape `[num_experts]` containing + actual token counts per expert. + align_size (int, optional): The alignment size for the input tensor for fp8 or fp4. Returns: - Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - The permuted tokens, permuted probs, and sorted indices. + Tuple[ + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + The permuted tokens, (optional) permuted probs, sorted indices, + (optional) pad_offsets, (optional) padded_tokens_per_expert. """ if fused and probs is None: if not HAVE_TE or fused_permute is None: @@ -284,7 +298,7 @@ def permute( permuted_input, sorted_indices = fused_permute( tokens, routing_map, num_out_tokens=num_out_tokens ) - return permuted_input, None, sorted_indices + return permuted_input, None, sorted_indices, None, tokens_per_expert if fused and probs is not None: if not HAVE_TE or ( @@ -346,7 +360,7 @@ def permute( # use the mapping to permute the tokens permuted_input = tokens.index_select(0, sorted_indices) - return permuted_input, permuted_probs, sorted_indices + return permuted_input, permuted_probs, sorted_indices, None, tokens_per_expert def unpermute( @@ -358,7 +372,6 @@ def unpermute( fused: bool = False, drop_and_pad: bool = False, pad_offsets: Optional[torch.Tensor] = None, -): ) -> torch.Tensor: """ Restore the original order of tokens after permutation. If probs are provided, it @@ -380,10 +393,10 @@ def unpermute( fused (bool, optional): Whether use the fused unpermute function. drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop and pads the number of tokens to the expert capacity. - pad_offsets : torch.Tensor, default = None + pad_offsets (torch.Tensor, optional): Tensor of per-expert cumulative padding offsets used to remove padding added during permutation. This is the fourth output of `moe_permute_and_pad_with_probs` - and is required when unpermuting padded outputs. + and is required when unpermuting padded outputs. Defaults to None. Returns: torch.Tensor: The tokens restored to their original order. From 9057392e321d3659e887c0229a080f4650b27695 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Wed, 4 Feb 2026 17:15:22 +0800 Subject: [PATCH 6/6] Update moe_utils.py --- megatron/core/transformer/moe/moe_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 66f3997b9ee..47debdd27df 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -359,7 +359,7 @@ def permute( "fused_permute_with_probs typically requires TE >= 2.1.0, and " "fused_permute_and_pad_with_probs` typically requires TE >= 2.12.0. " ) - if fused_permute_and_pad_with_probs is not None: + if fused_permute_and_pad_with_probs is not None and tokens_per_expert is not None: return fused_permute_and_pad_with_probs( tokens, probs, routing_map, tokens_per_expert, align_size )