From bdc428d0960ce8455e89e58f0a7a17bf9a55e9ba Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Mon, 22 Dec 2025 07:01:08 +0000 Subject: [PATCH 1/3] [FP8-Flow-MoE] Fuse permute+pad and unpermute+unpad ops for FP8 optimization Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --- tests/pytorch/test_permutation.py | 658 +++++++++++++++++- .../common/triton/permutation.py | 38 +- transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/permutation.py | 104 ++- .../pytorch/triton/permutation.py | 50 +- 5 files changed, 804 insertions(+), 47 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index e8a7bedc87..9a0cf6fb7c 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. +import os import random import torch @@ -13,6 +14,7 @@ from transformer_engine.pytorch import ( moe_permute as te_permute, moe_permute_with_probs as te_permute_with_probs, + moe_permute_and_pad_with_probs as te_permute_and_pad_with_probs, moe_unpermute as te_unpermute, moe_sort_chunks_by_index as te_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs, @@ -24,6 +26,7 @@ MXFP8Quantizer, ) import transformer_engine_torch as tex +from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding import copy seed = 1234 @@ -653,6 +656,522 @@ def _test_permutation_mask_map( print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") +def _test_permutation_and_padding_mask_map( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + with_merging_probs=False, + align_size=16, + BENCHMARK=False, +): + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens is None: + num_out_tokens = num_tokens * topK + + print( + "permutation and padding:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} 
topK:{topK}" + f" with_merging_probs:{with_merging_probs} align_size:{align_size} {te_dtype}" + ) + + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + else: + pytest.skip("Invalid dtype.") + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + + probs = torch.rand(num_tokens, num_expert).cuda() * routing_map + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs = probs.to(dtype) + probs.requires_grad_(True) + + tokens_per_expert = routing_map.sum(dim=0).cpu() + target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long() + num_permute_pad_out_tokens = target_tokens_per_expert.sum().item() + + permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + permute_pad_bwd_input = torch.rand( + (num_permute_pad_out_tokens, hidden_size), dtype=dtype + ).cuda() + unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + permute_pad_fwd_input.requires_grad_(True) + + restore_shape = permute_pad_fwd_input.shape + ################################################################################################################################### + # + # moe_permute_with_probs and Fp8Padding, moe_unpermute and Fp8Unpadding + # + ################################################################################################################################### + # permute + padding + permuted_output, permuted_probs, row_id_map = te_permute_with_probs( + permute_pad_fwd_input, + probs, + routing_map, + num_out_tokens=num_out_tokens, + ) + tokens_per_expert_list = tokens_per_expert.tolist() + fp8_padding = 
Fp8Padding(num_expert, align_size) + permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list) + permuted_paded_probs, _ = fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list) + + permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True) + + # unpadding + unpermute + + unpermute_unpad_fwd_input = permuted_paded_output.detach() + unpermute_unpad_fwd_input.requires_grad_(True) + + fp8_unpadding = Fp8Unpadding(num_expert, align_size) + unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list) + + probs_naive = probs + unpermuted_unpaded_output = te_unpermute( + unpaded_output, + row_id_map, + merging_probs=probs_naive if with_merging_probs else None, + restore_shape=restore_shape, + ) + + unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # fusion moe_permute_with_probs and Fp8Padding, fusion fusion moe_unpermute and Fp8Unpadding + # + ################################################################################################################################### + # fusion permute_and_pad + fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach() + fusion_permute_and_pad_fwd_input.requires_grad_(True) + probs_fusion = probs_naive.detach().clone() + probs_fusion.requires_grad_(True) + + ( + fusion_permuted_padded_output, + fusion_permuted_padded_probs, + row_id_map, + pad_offsets, + target_tokens_per_expert, + ) = te_permute_and_pad_with_probs( + fusion_permute_and_pad_fwd_input, + probs_fusion, + routing_map, + tokens_per_expert, + align_size, + ) + fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1) + + fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach() + fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True) + + # fusion unpad and unpermute + 
fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach() + fusion_unpermute_unpad_fwd_input.requires_grad_(True) + + fusion_unpermuted_unpaded_output = te_unpermute( + fusion_unpermute_unpad_fwd_input, + row_id_map, + merging_probs=probs_fusion if with_merging_probs else None, + restore_shape=restore_shape, + pad_offsets=pad_offsets, + ) + + fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach() + fusion_unpermuted_unpaded_output.backward(fusion_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + permuted_paded_output_ = permuted_paded_output.float() + fusion_permuted_padded_output_ = fusion_permuted_padded_output.float() + permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float() + fusion_permute_and_pad_fwd_input_grad = fusion_permute_and_pad_fwd_input.grad.float() + + unpermuted_unpaded_output_ = unpermuted_unpaded_output.float() + fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float() + unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float() + fusion_unpermute_unpad_fwd_input_grad = fusion_unpermute_unpad_fwd_input.grad.float() + + if not BENCHMARK: + torch.testing.assert_close( + permuted_paded_output_, + fusion_permuted_padded_output_, + msg=f"Mismatch in te_permute_and_pad fwd", + **tols, + ) + torch.testing.assert_close( + permute_pad_fwd_input_grad, + fusion_permute_and_pad_fwd_input_grad, + msg=f"Mismatch in te_permute_and_pad bwd", + **tols, + ) + torch.testing.assert_close( + unpermuted_unpaded_output_, + fusion_unpermuted_unpaded_output_, + msg=f"Mismatch in te_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + unpermute_unpad_fwd_input_grad, + 
fusion_unpermute_unpad_fwd_input_grad, + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) + torch.testing.assert_close( + permuted_paded_probs.float(), + fusion_permuted_padded_probs.float(), + msg=f"Mismatch in te_permute_and_pad bwd", + **tols, + ) + if with_merging_probs: + torch.testing.assert_close( + probs_naive.grad.float(), + probs_fusion.grad.float(), + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + if BENCHMARK: + + def permute_and_pad(): + permuted_output, permuted_probs, row_id_map = te_permute_with_probs( + permute_pad_fwd_input, + probs, + routing_map, + num_out_tokens=num_out_tokens, + ) + fp8_padding(permuted_output, tokens_per_expert_list) + fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list) + + def fusion_permute_and_pad(): + ( + fusion_permuted_padded_output, + fusion_permuted_padded_probs, + row_id_map, + pad_offsets, + target_tokens_per_expert, + ) = te_permute_and_pad_with_probs( + fusion_permute_and_pad_fwd_input, + probs, + routing_map, + tokens_per_expert, + align_size, + ) + fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1) + + t1 = perf_test_cuda_kernel(lambda: permute_and_pad()) + + t2 = perf_test_cuda_kernel(lambda: fusion_permute_and_pad()) + + print(f"permute_and_pad\t\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + permuted_paded_output, + permute_pad_bwd_input, + forward_input=[permute_pad_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + fusion_permuted_padded_output, + fusion_permute_pad_bwd_input, + forward_input=[fusion_permute_and_pad_fwd_input], + 
retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + def unpad_unpermute(): + unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list) + unpermuted_unpaded_output = te_unpermute( + unpaded_output, row_id_map, restore_shape=restore_shape + ) + + unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True) + + t1 = perf_test_cuda_kernel(lambda: unpad_unpermute()) + t2 = perf_test_cuda_kernel( + lambda: te_unpermute( + fusion_unpermute_unpad_fwd_input, + row_id_map, + restore_shape=restore_shape, + pad_offsets=pad_offsets, + ) + ) + print(f"unpermute_and_unpad\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + unpermuted_unpaded_output, + unpermute_unpad_bwd_input, + forward_input=([unpermute_unpad_fwd_input, probs]), + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + fusion_unpermuted_unpaded_output, + fusion_unpermute_bwd_input, + forward_input=([fusion_unpermute_unpad_fwd_input, probs]), + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + +def _test_permutation_and_padding_with_merging_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + align_size=16, + BENCHMARK=False, +): + """ + Test the combination of merging_probs AND pad_offsets together in moe_unpermute. + This specifically tests the backward pass fix where pad_offsets must be used + when computing gradients with merging_probs. 
+ """ + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + "permutation and padding with merging probs:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} align_size:{align_size} {te_dtype}" + ) + + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + else: + pytest.skip("Invalid dtype.") + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + + probs = torch.rand(num_tokens, num_expert).cuda() * routing_map + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs = probs.to(dtype) + probs.requires_grad_(True) + + tokens_per_expert = routing_map.sum(dim=0).cpu() + target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long() + num_permute_pad_out_tokens = target_tokens_per_expert.sum().item() + + permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + permute_pad_bwd_input = torch.rand( + (num_permute_pad_out_tokens, hidden_size), dtype=dtype + ).cuda() + unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + permute_pad_fwd_input.requires_grad_(True) + + restore_shape = permute_pad_fwd_input.shape + ################################################################################################################################### + # + # Reference: moe_permute_with_probs + Fp8Padding, then Fp8Unpadding + moe_unpermute with merging_probs + # + 
################################################################################################################################### + # permute + padding + permuted_output, permuted_probs, row_id_map = te_permute_with_probs( + permute_pad_fwd_input, + probs, + routing_map, + num_out_tokens=num_out_tokens, + ) + tokens_per_expert_list = tokens_per_expert.tolist() + fp8_padding = Fp8Padding(num_expert, align_size) + permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list) + + permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True) + + # Reference: unpadding + unpermute WITH merging_probs + ref_unpermute_fwd_input = permuted_paded_output.detach() + ref_unpermute_fwd_input.requires_grad_(True) + + ref_probs = probs.detach() + ref_probs.requires_grad_(True) + + fp8_unpadding = Fp8Unpadding(num_expert, align_size) + unpaded_output = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list) + ref_unpermuted_output = te_unpermute( + unpaded_output, row_id_map, ref_probs, restore_shape=restore_shape + ) + + ref_unpermuted_output.backward(unpermute_unpad_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # Fused: moe_permute_and_pad_with_probs, then moe_unpermute with BOTH merging_probs AND pad_offsets + # + ################################################################################################################################### + # fusion permute_and_pad + fusion_permute_fwd_input = permute_pad_fwd_input.detach() + fusion_permute_fwd_input.requires_grad_(True) + fusion_probs = probs.detach() + fusion_probs.requires_grad_(True) + + ( + fusion_permuted_padded_output, + fusion_permuted_padded_probs, + fused_row_id_map, + pad_offsets, + _, + ) = te_permute_and_pad_with_probs( + fusion_permute_fwd_input, + fusion_probs, + routing_map, + tokens_per_expert, + align_size, + ) + + fusion_permute_pad_bwd_input = 
permute_pad_bwd_input.detach() + fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True) + + # Fused: unpermute with BOTH merging_probs AND pad_offsets + fusion_unpermute_fwd_input = fusion_permuted_padded_output.detach() + fusion_unpermute_fwd_input.requires_grad_(True) + + fusion_merging_probs = probs.detach() + fusion_merging_probs.requires_grad_(True) + + fusion_unpermuted_output = te_unpermute( + fusion_unpermute_fwd_input, + fused_row_id_map, + fusion_merging_probs, + restore_shape=restore_shape, + pad_offsets=pad_offsets, + ) + + fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach() + fusion_unpermuted_output.backward(fusion_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + # Check forward pass + ref_unpermuted_output_ = ref_unpermuted_output.float() + fusion_unpermuted_output_ = fusion_unpermuted_output.float() + + if not BENCHMARK: + torch.testing.assert_close( + ref_unpermuted_output_, + fusion_unpermuted_output_, + msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets fwd", + **tols, + ) + + # Check backward pass - activation gradients + ref_unpermute_fwd_input_grad = ref_unpermute_fwd_input.grad.float() + fusion_unpermute_fwd_input_grad = fusion_unpermute_fwd_input.grad.float() + + torch.testing.assert_close( + ref_unpermute_fwd_input_grad, + fusion_unpermute_fwd_input_grad, + msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets bwd (act_grad)", + **tols, + ) + + # Check backward pass - probs gradients + ref_probs_grad = ref_probs.grad.float() + fusion_probs_grad = fusion_merging_probs.grad.float() + + torch.testing.assert_close( + ref_probs_grad, + fusion_probs_grad, + 
msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets bwd (probs_grad)", + **tols, + ) + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + if BENCHMARK: + + def ref_unpad_unpermute(): + unpaded = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list) + return te_unpermute(unpaded, row_id_map, ref_probs, restore_shape=restore_shape) + + def fused_unpermute(): + return te_unpermute( + fusion_unpermute_fwd_input, + fused_row_id_map, + fusion_merging_probs, + restore_shape=restore_shape, + pad_offsets=pad_offsets, + ) + + t1 = perf_test_cuda_kernel(lambda: ref_unpad_unpermute()) + t2 = perf_test_cuda_kernel(lambda: fused_unpermute()) + print(f"unpermute_unpad_with_probs\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + ref_unpermuted_output, + unpermute_unpad_bwd_input, + forward_input=[ref_unpermute_fwd_input, ref_probs], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + fusion_unpermuted_output, + fusion_unpermute_bwd_input, + forward_input=[fusion_unpermute_fwd_input, fusion_merging_probs], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"unpermute_unpad_with_probs\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + def _test_permutation_mask_map_fp8( te_dtype, num_tokens, @@ -1126,7 +1645,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn): @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) def test_permutation_index_map( te_dtype, @@ -1155,7 
+1674,7 @@ def test_permutation_index_map( @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) def test_permutation_mask_map( te_dtype, @@ -1180,6 +1699,74 @@ def test_permutation_mask_map( ) +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_out_tokens", [None]) +@pytest.mark.parametrize( + "num_tokens, num_expert, hidden_size, topK", + [ + (4096, 8, 1280, 2), + (4096, 64, 4096, 6), + (4096, 256, 7168, 6), + (4096, 512, 9216, 8), + ], +) +@pytest.mark.parametrize("with_merging_probs", [True, False]) +def test_permutation_and_padding_mask_map( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + with_merging_probs, +): + BENCHMARK = False + + _test_permutation_and_padding_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_merging_probs=with_merging_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_out_tokens", [None]) +@pytest.mark.parametrize( + "num_tokens, num_expert, hidden_size, topK", + [ + (4096, 8, 1280, 2), + (4096, 64, 4096, 6), + (4096, 256, 7168, 6), + (4096, 512, 9216, 8), + ], +) +def test_permutation_and_padding_with_merging_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + """Test moe_unpermute backward pass with BOTH merging_probs AND pad_offsets.""" + BENCHMARK = False + + _test_permutation_and_padding_with_merging_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=BENCHMARK, + ) + + @pytest.mark.parametrize("te_dtype", 
_te_dtypes) def test_permutation_mask_map_empty_input(te_dtype): with_probs = True @@ -1201,9 +1788,9 @@ def test_permutation_mask_map_empty_input(te_dtype): @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) -@pytest.mark.parametrize("tp_size", [1, 2, 8]) +@pytest.mark.parametrize("tp_size", [1, 2]) def test_permutation_mask_map_alongside_probs( te_dtype, num_tokens, @@ -1253,10 +1840,10 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("recipe", fp8_recipes) def test_permutation_mask_map_fp8( @@ -1341,7 +1928,7 @@ def test_permutation_mask_map_topk1_no_probs( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) -@pytest.mark.parametrize("tp_size", [1, 2, 8]) +@pytest.mark.parametrize("tp_size", [2, 8]) @pytest.mark.parametrize("hidden_size", [4096]) def test_chunk_permutation( te_dtype, @@ -1376,6 +1963,10 @@ def test_chunk_permutation_empty_input(te_dtype): ) +@pytest.mark.skipif( + os.getenv("RUN_BENCHMARK_TESTS", "0") != "1", + reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k single_case", +) def test_permutation_single_case(): print("GPU:", torch.cuda.get_device_name(0)) @@ -1413,6 +2004,26 @@ def 
test_permutation_single_case(): BENCHMARK=Benchmark, ) + _test_permutation_and_padding_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=Benchmark, + ) + + _test_permutation_and_padding_with_merging_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=Benchmark, + ) + _test_moe_chunk_sort( te_dtype=te_dtype, num_tokens=num_tokens, @@ -1479,6 +2090,30 @@ def benchmark_single_case( ) torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("permutation_and_padding_mask_map") + _test_permutation_and_padding_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("permutation_and_padding_with_merging_probs") + _test_permutation_and_padding_with_merging_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs") _test_permutation_mask_map_alongside_probs( te_dtype=te_dtype, @@ -1495,7 +2130,12 @@ def benchmark_single_case( torch.cuda.nvtx.range_pop() -def benchmark_multiple_cases(): +@pytest.mark.skipif( + os.getenv("RUN_BENCHMARK_TESTS", "0") != "1", + reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark", +) +def test_benchmark_multiple_cases(): + """Benchmark test - skipped by default. 
Run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark""" print("GPU:", torch.cuda.get_device_name(0)) # te_dtype = tex.DType.kFloat32 @@ -1537,4 +2177,4 @@ def benchmark_multiple_cases(): if __name__ == "__main__": - benchmark_multiple_cases() + test_benchmark_multiple_cases() diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 87a9c24533..de30c7c532 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -200,6 +200,7 @@ def _permute_kernel( probs_ptr, scale_ptr, permuted_scale_ptr, + pad_offsets_ptr, # sizes scale_hidden_dim, # strides @@ -224,8 +225,11 @@ def _permute_kernel( hidden_size: tl.constexpr, PERMUTE_PROBS: tl.constexpr, PERMUTE_SCALE: tl.constexpr, + FUSION_PAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): + expert_idx = 0 + pid_t = tl.program_id(0) pid_h = tl.program_id(1) cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -246,6 +250,15 @@ def _permute_kernel( dst_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert ).to(tl.int64) + if FUSION_PAD or PERMUTE_PROBS: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + if FUSION_PAD: + pad_off = tl.load(pad_offsets_ptr + expert_idx) + dst_row = dst_row + pad_off output_off = dst_row * stride_output_token + cur_off * stride_output_hidden if PERMUTE_SCALE: permuted_scale_off = ( @@ -253,11 +266,6 @@ def _permute_kernel( ) tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) if PERMUTE_PROBS: - expert_idx = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert prob = tl.load(probs_ptr + prob_off) if pid_h == 0: @@ -297,6 +305,7 @@ def _unpermute_kernel( row_id_map_ptr, merging_probs_ptr, 
permuted_probs_ptr, + pad_offsets_ptr, # strides stride_row_id_map_token, stride_row_id_map_expert, @@ -318,10 +327,12 @@ def _unpermute_kernel( PROBS_LOAD_WIDTH: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr, + FUSION_UNPAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): data_type = input_ptr.dtype.element_ty compute_type = tl.float32 + expert_idx = 0 pid_t = tl.program_id(0) pid_h = tl.program_id(1) @@ -348,15 +359,19 @@ def _unpermute_kernel( src_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert ).to(tl.int64) - input_off = src_row * stride_input_token + current_offset * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - if WITH_MERGING_PROBS: + if FUSION_UNPAD or WITH_MERGING_PROBS: expert_idx = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert ) + if FUSION_UNPAD: + pad_off = tl.load(pad_offsets_ptr + expert_idx) + src_row = src_row + pad_off + input_off = src_row * stride_input_token + current_offset * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + inp = inp.to(compute_type) + if WITH_MERGING_PROBS: merging_prob_off = ( pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert ) @@ -407,6 +422,7 @@ def _unpermute_bwd_with_merging_probs_kernel( fwd_input_ptr, merging_probs_ptr, row_id_map_ptr, + pad_offsets_ptr, # strides stride_row_id_map_token, stride_row_id_map_expert, @@ -427,6 +443,7 @@ def _unpermute_bwd_with_merging_probs_kernel( num_experts: tl.constexpr, hidden_size: tl.constexpr, PROBS_LOAD_WIDTH: tl.constexpr, + FUSION_UNPAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): data_type = fwd_output_grad_ptr.dtype.element_ty @@ -450,6 +467,9 @@ def _unpermute_bwd_with_merging_probs_kernel( + pid * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert ) + if FUSION_UNPAD: + pad_off = tl.load(pad_offsets_ptr + 
expert_idx) + dst_row = dst_row + pad_off prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) current_start = 0 while current_start < hidden_size: diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5341af3d74..9f4a9678eb 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -34,6 +34,7 @@ from transformer_engine.pytorch.permutation import ( moe_permute, moe_permute_with_probs, + moe_permute_and_pad_with_probs, moe_unpermute, moe_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs, diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 9fce9cefcf..d15814585e 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""MoE Permutaion API""" +"""MoE Permutation API""" import warnings from typing import Optional, Tuple import torch @@ -191,6 +191,7 @@ def forward( routing_map: torch.Tensor, num_out_tokens: int, probs: torch.Tensor, + pad_offsets: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # pylint: disable=missing-function-docstring if not inp.numel(): @@ -201,6 +202,8 @@ def forward( assert routing_map.is_cuda, "TransformerEngine needs CUDA." if probs is not None: assert probs.is_cuda, "TransformerEngine needs CUDA." + if pad_offsets is not None: + assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." 
assert inp.size(0) == routing_map.size(0), "Permute not possible" num_tokens, hidden_size = inp.size() @@ -250,6 +253,7 @@ def forward( row_id_map, probs, fp8_scale, + pad_offsets, num_tokens, num_experts, num_out_tokens, @@ -292,7 +296,7 @@ def forward( requires_grad=output.requires_grad, ) - ctx.save_for_backward(row_id_map) + ctx.save_for_backward(row_id_map, pad_offsets) ctx.num_experts = num_experts ctx.num_tokens = num_tokens ctx.hidden_size = hidden_size @@ -307,12 +311,12 @@ def backward( ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring if not permuted_act_grad.numel(): - return permuted_act_grad, None, None, ctx.probs + return permuted_act_grad, None, None, ctx.probs, None act_grad = None probs_grad = None if ctx.needs_input_grad[0]: - (row_id_map,) = ctx.saved_tensors + row_id_map, pad_offsets = ctx.saved_tensors assert not isinstance( permuted_act_grad, QuantizedTensor ), "The backward of moe_permute does not support FP8." @@ -321,13 +325,14 @@ def backward( row_id_map, None, permuted_probs_grad, + pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.hidden_size, ) if not ctx.needs_input_grad[3]: probs_grad = None - return act_grad, None, None, probs_grad + return act_grad, None, None, probs_grad, None class _moe_unpermute_mask_map(torch.autograd.Function): @@ -340,6 +345,7 @@ def forward( row_id_map: torch.Tensor, merging_probs: Optional[torch.Tensor], restore_shape: Optional[torch.Size], + pad_offsets: Optional[torch.Tensor], ) -> torch.Tensor: # pylint: disable=missing-function-docstring if not inp.numel(): @@ -358,6 +364,8 @@ def forward( # Device check assert inp.is_cuda, "TransformerEngine needs CUDA." assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if pad_offsets is not None: + assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." 
assert not isinstance( inp, QuantizedTensor @@ -367,15 +375,16 @@ def forward( row_id_map, merging_probs, None, + pad_offsets, num_tokens, num_experts, hidden_size, ) if with_probs: - ctx.save_for_backward(inp, row_id_map, merging_probs) + ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets) else: - ctx.save_for_backward(row_id_map) + ctx.save_for_backward(row_id_map, pad_offsets) ctx.num_experts = num_experts ctx.num_tokens = num_tokens ctx.num_permuted_tokens = inp.size(0) @@ -387,15 +396,15 @@ def forward( def backward(ctx, unpermuted_act_grad): # pylint: disable=missing-function-docstring if not unpermuted_act_grad.numel(): - return unpermuted_act_grad, None, ctx.merging_probs, None + return unpermuted_act_grad, None, ctx.merging_probs, None, None act_grad = None probs_grad = None if ctx.needs_input_grad[0]: if ctx.with_probs: - fwd_input, row_id_map, merging_probs = ctx.saved_tensors + fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors else: - (row_id_map,) = ctx.saved_tensors + row_id_map, pad_offsets = ctx.saved_tensors fp8 = isinstance(unpermuted_act_grad, QuantizedTensor) per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor) @@ -441,6 +450,7 @@ def backward(ctx, unpermuted_act_grad): row_id_map, fwd_input, merging_probs, + pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, @@ -453,6 +463,7 @@ def backward(ctx, unpermuted_act_grad): row_id_map, None, fp8_scale, + pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, @@ -497,7 +508,7 @@ def backward(ctx, unpermuted_act_grad): if not ctx.needs_input_grad[2]: probs_grad = None - return act_grad, None, probs_grad, None + return act_grad, None, probs_grad, None, None def moe_permute( @@ -537,7 +548,9 @@ def moe_permute( if map_type == "index": return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num) if map_type == "mask": - output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, 
num_out_tokens, None) + output, row_id_map, _ = _moe_permute_mask_map.apply( + inp, routing_map, num_out_tokens, None, None + ) return output, row_id_map raise ValueError("map_type should be one of 'mask' or 'index'") @@ -570,11 +583,67 @@ def moe_permute_with_probs( By default, set to '-1', meaning no tokens are dropped. """ output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( - inp, routing_map, num_out_tokens, probs + inp, routing_map, num_out_tokens, probs, None ) return output, permuted_probs, row_id_map +def moe_permute_and_pad_with_probs( + inp: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + tokens_per_expert: torch.Tensor, + align_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + """ + Permute the tokens and probs based on the routing_map. + Token with the same index will be grouped together. + Tokens with the same designated expert will be grouped together. + The routing_map indicates which experts were selected by each token. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens and is + of shape [num_tokens, num_experts]. It will be permuted with the tokens + according to the routing_map. + routing_map: torch.Tensor + The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'. + The values in it: 1 means the token is routed to this expert and 0 means not. + tokens_per_expert : torch.Tensor + Tensor of shape `[num_experts]` containing actual token counts per expert. + align_size : int + the alignment size for the input tensor. + """ + assert ( + tokens_per_expert is not None + ), "tokens_per_expert must be provided to the fused permute padding function." 
+ assert align_size > 0, f"align_size must be positive, got {align_size}" + + # Ensure tokens_per_expert is on the same device as input to avoid device transfers + if tokens_per_expert.device != inp.device: + tokens_per_expert = tokens_per_expert.to(inp.device) + + # Calculate aligned token counts per expert + target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long() + + if torch.equal(tokens_per_expert, target_tokens_per_expert): + pad_offsets = None + else: + pad_lengths = target_tokens_per_expert - tokens_per_expert + cum_pad = torch.cumsum(pad_lengths, dim=0) + pad_offsets = torch.cat( + [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]] + ) + + output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( + inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets + ) + return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert + + def moe_unpermute( inp: torch.Tensor, row_id_map: torch.Tensor, @@ -582,6 +651,7 @@ def moe_unpermute( restore_shape: Optional[torch.Size] = None, map_type: str = "mask", probs: Optional[torch.Tensor] = None, + pad_offsets: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Unpermute a tensor with permuted tokens, and optionally merge the tokens with their @@ -605,6 +675,10 @@ def moe_unpermute( Options are: 'mask', 'index'. probs : torch.Tensor, default = None Renamed to merging_probs. Keep for backward compatibility. + pad_offsets : torch.Tensor, default = None + Tensor of per-expert cumulative padding offsets used to remove padding added + during permutation. This is the fourth output of `moe_permute_and_pad_with_probs` + and is required when unpermuting padded outputs. 
""" if probs is not None: if merging_probs is not None: @@ -616,7 +690,9 @@ def moe_unpermute( if map_type == "index": return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs) if map_type == "mask": - return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape) + return _moe_unpermute_mask_map.apply( + inp, row_id_map, merging_probs, restore_shape, pad_offsets + ) raise ValueError("map_type should be one of 'mask' or 'index'") diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 8f953e9c31..27662e1b28 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -123,6 +123,7 @@ def permute_with_mask_map( row_id_map: torch.Tensor, probs: torch.Tensor, scale: torch.Tensor, + pad_offsets: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, @@ -142,6 +143,9 @@ def permute_with_mask_map( The probabilities of the input tensor. If it is not None, it will be permuted. scale : torch.Tensor The scale of the input tensor. If it is not None, it will be permuted. + pad_offsets : torch.Tensor + Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding. + If it is not None, it will be allocated output buffers with aligned sizes. num_tokens : int Number of tokens in the input tensor. num_experts : int @@ -153,18 +157,18 @@ def permute_with_mask_map( scale_hidden_dim : int Hidden size of the scale tensor. 
""" - output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") - if probs is not None: - permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") - else: - permuted_probs = None - - if scale is not None: - permuted_scale = torch.empty( - (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda" - ) - else: - permuted_scale = None + # Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed, + # since the kernel doesn't write to padding positions. + alloc = torch.zeros if pad_offsets is not None else torch.empty + output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") + permuted_probs = ( + alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None + ) + permuted_scale = ( + torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda") + if scale is not None + else None + ) # pylint: disable=unnecessary-lambda-assignment grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _permute_kernel[grid]( @@ -173,6 +177,7 @@ def permute_with_mask_map( probs, scale, permuted_scale, + pad_offsets, scale_hidden_dim, row_id_map.stride(0), row_id_map.stride(1), @@ -193,6 +198,7 @@ def permute_with_mask_map( hidden_size, PERMUTE_PROBS=probs is not None, PERMUTE_SCALE=scale is not None, + FUSION_PAD=pad_offsets is not None, ) return output, permuted_scale, permuted_probs @@ -202,6 +208,7 @@ def unpermute_with_mask_map( row_id_map: torch.Tensor, merging_probs: Union[torch.Tensor, None], permuted_probs: Union[torch.Tensor, None], + pad_offsets: Union[torch.Tensor, None], num_tokens: int, num_experts: int, hidden_size: int, @@ -220,6 +227,9 @@ def unpermute_with_mask_map( to reduce the unpermuted tokens. permuted_probs : torch.Tensor The permuted probabilities of the input tensor. If it is not None, it will be unpermuted. 
+ pad_offsets : torch.Tensor + Per-expert padding offsets of shape `[num_experts]` for FP8 fused unpadding. + If it is not None, it will remove the previously fused padding. num_tokens : int Number of tokens in the permuted tensor. num_experts : int @@ -241,6 +251,7 @@ def unpermute_with_mask_map( row_id_map, merging_probs, permuted_probs, + pad_offsets, row_id_map.stride(0), row_id_map.stride(1), inp.stride(0), @@ -259,6 +270,7 @@ def unpermute_with_mask_map( PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), WITH_MERGING_PROBS=merging_probs is not None, PERMUTE_PROBS=permuted_probs is not None, + FUSION_UNPAD=pad_offsets is not None, ) return output, unpermuted_probs @@ -268,6 +280,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs( row_id_map: torch.Tensor, fwd_input: torch.Tensor, merging_probs: torch.Tensor, + pad_offsets: Union[torch.Tensor, None], num_tokens: int, num_experts: int, num_out_tokens: int, @@ -286,6 +299,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs( The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`. merging_probs : torch.Tensor The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`. + pad_offsets : torch.Tensor + Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding. + If it is not None, it will be allocated output buffers with aligned sizes. num_tokens : int Number of tokens in the permuted tensor. num_experts : int @@ -295,9 +311,11 @@ def unpermute_with_mask_map_bwd_with_merging_probs( hidden_size : int Hidden size of the output tensor. """ - act_grad = torch.empty( - (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" - ) + # Use zeros when pad_offsets is used because padding slots won't be written to + # by the kernel. This matches the behavior of Fp8Unpadding.backward which zeros + # out the padding slots. 
+ alloc = torch.zeros if pad_offsets is not None else torch.empty + act_grad = alloc((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda") merging_probs_grad = torch.empty( (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda" ) @@ -307,6 +325,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs( fwd_input, merging_probs, row_id_map, + pad_offsets, row_id_map.stride(0), row_id_map.stride(1), fwd_output_grad.stride(0), @@ -324,6 +343,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs( num_experts, hidden_size, PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), + FUSION_UNPAD=pad_offsets is not None, ) return act_grad, merging_probs_grad From 620b4413088536828cd1b4b682e3f2f62f45e20f Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Thu, 25 Dec 2025 10:09:18 +0000 Subject: [PATCH 2/3] [FP8-Flow-MoE] fix dirty pad scale Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --- transformer_engine/pytorch/triton/permutation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 27662e1b28..985d11c644 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -165,7 +165,7 @@ def permute_with_mask_map( alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None ) permuted_scale = ( - torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda") + alloc((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda") if scale is not None else None ) From f4115c3a93958255b24f87c698a58192a243d14f Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Fri, 26 Dec 2025 02:51:28 +0000 Subject: [PATCH 3/3] [PyTorch]: add moe fp8 flow under blockwise recipe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 
add fp8 rowwise scaling-aware transpose op for wgrad columnwise. 2. support Float8BlockwiseQTensor input in grouped_linear. 3. _rowwise_scale_inv is propagated with a COMPACT layout along the `dispatch → permute → GroupedLinear` path. Signed-off-by: xiaoxi-wangfj <690912414@qq.com> Co-authored-by: dantesuu@gmail.com Co-authored-by: xzhu@zhejianglab.org Co-authored-by: 123sssmmm@gmail.com --- .../pytorch/module/grouped_linear.py | 39 ++- transformer_engine/pytorch/permutation.py | 14 +- .../pytorch/tensor/float8_blockwise_tensor.py | 91 +++++++ .../blockwise_scaling_aware_fp8_transpose.py | 230 ++++++++++++++++++ 4 files changed, 359 insertions(+), 15 deletions(-) create mode 100644 transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c4d35a9c2c..6589a1cf8e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -44,6 +44,7 @@ from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockwiseQTensor from ..quantized_tensor import ( QuantizedTensorStorage, Quantizer, @@ -143,7 +144,12 @@ def forward( inp_view = inp.reshape(-1, in_features) inputmats: list if fp8 and not debug: - inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers) + if isinstance(inp_view, Float8BlockwiseQTensor): + inputmats = inp_view.split_scaling_aware_fp8_transpose( + m_splits, input_quantizers + ) + else: + inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, input_quantizers, m_splits, activation_dtype @@ -343,18 +349,28 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Unfused bias grad
and multi-tensor quantize for i in range(ctx.num_gemms): grad_biases[i] = grad_output_mats[i].sum(dim=0) + if isinstance(grad_output_view, Float8BlockwiseQTensor): + grad_output = grad_output_view.split_scaling_aware_fp8_transpose( + ctx.m_splits, ctx.grad_output_quantizers + ) + else: + grad_output = tex.split_quantize( + grad_output_view, + ctx.m_splits, + ctx.grad_output_quantizers, + ) + else: + # Multi-tensor quantize + if isinstance(grad_output_view, Float8BlockwiseQTensor): + grad_output = grad_output_view.split_scaling_aware_fp8_transpose( + ctx.m_splits, ctx.grad_output_quantizers + ) + else: grad_output = tex.split_quantize( grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, ) - else: - # Multi-tensor quantize - grad_output = tex.split_quantize( - grad_output_view, - ctx.m_splits, - ctx.grad_output_quantizers, - ) elif ctx.debug: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) for i in range(ctx.num_gemms): @@ -781,9 +797,10 @@ def forward( """ debug = self.is_debug_iter() - assert not isinstance( - inp, QuantizedTensorStorage - ), "GroupedLinear doesn't support input tensor in FP8." + if not isinstance(inp, Float8BlockwiseQTensor): + assert not isinstance( + inp, QuantizedTensorStorage + ), "GroupedLinear doesn't support input tensor in FP8." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." 
is_grad_enabled = torch.is_grad_enabled() diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index d15814585e..78766ad67d 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -224,7 +224,10 @@ def forward( fake_dtype = inp.dtype # blockwise scaling if blockwise_recipe: - fp8_scale = inp._rowwise_scale_inv.T.contiguous() + if inp._rowwise_data.shape[0] == inp._rowwise_scale_inv.shape[0]: + fp8_scale = inp._rowwise_scale_inv + else: + fp8_scale = inp._rowwise_scale_inv.T.contiguous() scale_hidden_dim = fp8_scale.shape[1] assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" inp = inp._rowwise_data @@ -275,7 +278,7 @@ def forward( shape=output.shape, dtype=fake_dtype, rowwise_data=output, - rowwise_scale_inv=permuted_scale.T.contiguous(), + rowwise_scale_inv=permuted_scale, columnwise_data=None, columnwise_scale_inv=None, fp8_dtype=fp8_dtype, @@ -423,7 +426,10 @@ def backward(ctx, unpermuted_act_grad): unpermuted_act_grad = unpermuted_act_grad._data # blockwise scaling elif blockwise_recipe: - fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous() + if unpermuted_act_grad._rowwise_data.shape[0] == unpermuted_act_grad._rowwise_scale_inv.shape[0]: + fp8_scale = unpermuted_act_grad._rowwise_scale_inv + else: + fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous() unpermuted_act_grad = unpermuted_act_grad._rowwise_data scale_hidden_dim = fp8_scale.shape[1] assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" @@ -485,7 +491,7 @@ def backward(ctx, unpermuted_act_grad): shape=act_grad.shape, dtype=fake_dtype, rowwise_data=act_grad, - rowwise_scale_inv=permuted_scale.T.contiguous(), + rowwise_scale_inv=permuted_scale, columnwise_data=None, columnwise_scale_inv=None, fp8_dtype=fp8_dtype, diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py 
b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 01e03e5355..6d3e4cdbd4 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -17,6 +17,9 @@ from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple +from ..triton.blockwise_scaling_aware_fp8_transpose import ( + blockwise_scaling_aware_fp8_transpose, +) aten = torch.ops.aten @@ -437,6 +440,94 @@ def untyped_storage(self) -> torch.UntypedStorage: return data.untyped_storage() return torch.UntypedStorage(0, device=self.device) + def split_scaling_aware_fp8_transpose(self, m_splits, quantizers): + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "split_transpose_quantize only supports rowwise_data inputs." + + # Temporary solution: perf fp8flow + assert ( + self._rowwise_scale_inv.shape[0] == self._rowwise_data.shape[0] + ), "rowwise_data and rowwise_scale_inv must have same M (rows)." + if ( + self._is_gemm_ready_format() + and self._rowwise_data.shape[0] == self._rowwise_scale_inv.shape[0] + ): + self._data_format = tex.Float8BlockScaleTensorFormat.COMPACT + assert ( + not self._is_gemm_ready_format() + ), "Only COMPACT input format is supported." 
+ + rowwise_usage = quantizers[0].rowwise_usage + device = self._rowwise_data.device + kept = [i for i, m in enumerate(m_splits) if m > 0] + m_splits_kept = [m_splits[i] for i in kept] + + if len(m_splits_kept) > 0: + ( + rowwise_data_list, + rowwise_scale_inv_t_list, + columnwise_data_list, + columnwise_scale_inv_list, + ) = blockwise_scaling_aware_fp8_transpose( + self._rowwise_data, self._rowwise_scale_inv, m_splits_kept + ) + + if len(m_splits_kept) != len(m_splits): + K = self._rowwise_data.shape[1] + empty_rw_data = ( + torch.empty((0, K), dtype=self._rowwise_data.dtype, device=device) + if rowwise_usage + else None + ) + empty_rw_si_t = ( + torch.empty( + (self._rowwise_scale_inv.shape[1], 0), + dtype=self._rowwise_scale_inv.dtype, + device=device, + ) + if rowwise_usage + else None + ) + empty_cw_data = torch.empty( + (K, 0), dtype=self._rowwise_data.dtype, device=device + ) + empty_cw_si = torch.empty( + (0, K), dtype=self._rowwise_scale_inv.dtype, device=device + ) + + results = [] + kept_idx = 0 + for i, m in enumerate(m_splits): + if m == 0: + rowwise_data = empty_rw_data + rowwise_scale_inv_t = empty_rw_si_t + columnwise_data = empty_cw_data + columnwise_scale_inv = empty_cw_si + else: + rowwise_data = rowwise_data_list[kept_idx] if rowwise_usage else None + rowwise_scale_inv_t = ( + rowwise_scale_inv_t_list[kept_idx] if rowwise_usage else None + ) + columnwise_data = columnwise_data_list[kept_idx] + columnwise_scale_inv = columnwise_scale_inv_list[kept_idx] + kept_idx += 1 + + results.append( + Float8BlockwiseQTensorStorage( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv_t, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self._fp8_dtype, + quantizer=quantizers[i], + is_2D_scaled=self._is_2D_scaled, + data_format=tex.Float8BlockScaleTensorFormat.GEMM_READY, + ) + ) + + return results + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git 
a/transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py b/transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py new file mode 100644 index 0000000000..52f2481138 --- /dev/null +++ b/transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py @@ -0,0 +1,230 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""PyTorch wrapper functions and scaling_aware_fp8_transpose Triton kernels.""" +import torch +import triton +import triton.language as tl + + +@triton.jit +def _scaling_aware_fp8_transpose_kernel( + # input pointers + rowwise_data_ptrs, + rowwise_scale_inv_ptrs, + columnwise_data_ptrs, + columnwise_scale_inv_ptrs, + rowwise_scale_inv_t_ptrs, + rows_ptr, + # sizes + cols, + rsi_cols, + # strides + stride_rowwise_data_r, + stride_rsi_r, + # metas + BLOCK_SIZE: tl.constexpr, +): + pid_group_index = tl.program_id(0) + pid_row = tl.program_id(1) + pid_col = tl.program_id(2) + + rows = tl.load(rows_ptr + pid_group_index) + nbrows = (rows + BLOCK_SIZE - 1) // BLOCK_SIZE + if pid_row >= nbrows: + return + + row_base = tl.load(rowwise_data_ptrs + pid_group_index).to( + tl.pointer_type(tl.uint8) + ) + rsi_base = tl.load(rowwise_scale_inv_ptrs + pid_group_index).to( + tl.pointer_type(tl.float32) + ) + col_base = tl.load(columnwise_data_ptrs + pid_group_index).to( + tl.pointer_type(tl.uint8) + ) + csi_base = tl.load(columnwise_scale_inv_ptrs + pid_group_index).to( + tl.pointer_type(tl.float32) + ) + + r_start = pid_row * BLOCK_SIZE + c_start = pid_col * BLOCK_SIZE + r_offsets = r_start + tl.arange(0, BLOCK_SIZE) + c_offsets = c_start + tl.arange(0, BLOCK_SIZE) + valid_r = r_offsets < rows + valid_c = c_offsets < cols + data = tl.load( + row_base + (r_offsets[:, None] * stride_rowwise_data_r + c_offsets[None, :]), + mask=valid_r[:, None] & valid_c[None, :], + other=0, + ) + + rsi_c_offsets = pid_col + tl.arange(0, 1) + valid_rsi_c = 
rsi_c_offsets < rsi_cols + si = tl.load( + rsi_base + r_offsets[:, None] * stride_rsi_r + rsi_c_offsets[None, :], + mask=valid_r[:, None] & valid_rsi_c[None, :], + other=0.0, + ) + + # Write rowwise_scale_inv.T + rst_base = tl.load(rowwise_scale_inv_t_ptrs + pid_group_index).to( + tl.pointer_type(tl.float32) + ) + tl.store( + rst_base + (rsi_c_offsets[:, None] * rows + r_offsets[None, :]), + si.T, + mask=valid_rsi_c[:, None] & valid_r[None, :], + ) + + # For the current block-row (128 rows), take the per-channel max of rowwise_scale_inv + # This max value becomes the columnwise scaling factor for this block + target_si = tl.max(si, axis=0) + tl.store(csi_base + (pid_row * cols + c_offsets), target_si, mask=valid_c) + + # FP8 decode/encode + sign = (data >> 7) & 1 + exp = (data >> 3) & 0xF + mant = data & 0x7 + # log2_t = tl.log2(target_si) + # log2_si = tl.log2(si + 1e-30) + # kf = log2_t - log2_si + # k = tl.cast(tl.floor(kf + 0.5), tl.int32) + bits_target = tl.cast(target_si, tl.uint32, bitcast=True) + bits_si = tl.cast(si, tl.uint32, bitcast=True) + exp_t = ((bits_target & 0x7F800000) >> 23) - 127 + exp_s = ((bits_si & 0x7F800000) >> 23) - 127 + k_approx = exp_t[None, :] - exp_s + k = tl.cast(k_approx, tl.int32) + exp_new = exp - k + exp_new = tl.where(exp_new < 1, 0, exp_new) + new_data = (sign << 7) | (exp_new << 3) | mant + new_data = tl.where(exp == 0, 0, new_data) + + # write columnwise_data (uint8) to [K,M] (c, r) + tl.store( + col_base + (c_offsets[:, None] * rows + r_offsets[None, :]), + new_data.T, + mask=valid_c[:, None] & valid_r[None, :], + ) + + +def blockwise_scaling_aware_fp8_transpose( + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + m_splits: list[int], + block_size: int = 128, +): + """ + Scaling-aware FP8 transpose that converts row-wise quantized FP8 tensors to a + column-wise layout in the FP8 domain. + + The input is split along the M dimension according to ``m_splits``. 
For each split, + the kernel transposes FP8 data from shape ``[m_i, cols]`` to ``[cols, m_i]`` while + producing column-wise scaling factors at block-row granularity. The operation is + performed without dequantizing to higher precision types. + + Parameters + ---------- + rowwise_data : torch.Tensor + Row-wise FP8-encoded data stored as ``uint8`` with shape + ``[sum(m_splits), cols]``. + + rowwise_scale_inv : torch.Tensor + Row-wise scaling factors associated with ``rowwise_data`` with shape + ``[sum(m_splits), rsi_cols]``. + + m_splits : list[int] + Sizes of splits along the M dimension. Each entry ``m_i`` defines the number of + rows in one group. + + block_size : int, optional + Tile size for the blockwise transpose and scaling-aware conversion. + + Returns + ------- + rowwise_data_list : list[torch.Tensor] + List of input views split by ``m_splits``, each with shape ``[m_i, cols]`` and + dtype matching ``rowwise_data``. + + rowwise_scale_inv_t_list : list[torch.Tensor] + List of transposed row-wise inverse scaling tensors, each with shape + ``[nbcols, m_i]``, where ``nbcols = ceil(cols / block_size)`` and dtype matching + ``rowwise_scale_inv``. + + columnwise_data_list : list[torch.Tensor] + List of column-wise FP8-encoded output tensors, each with shape ``[cols, m_i]`` + and dtype matching ``rowwise_data`` (raw FP8 bits in ``uint8``). + + columnwise_scale_inv_list : list[torch.Tensor] + List of column-wise inverse scaling tensors at block-row granularity, each with + shape ``[nbrows_i, cols]``, where ``nbrows_i = ceil(m_i / block_size)`` and dtype + matching ``rowwise_scale_inv``. 
+ + """ + assert len(m_splits) > 0, "m_splits can not be zero" + device = rowwise_data.device + data_dtype = rowwise_data.dtype + scale_dtype = rowwise_scale_inv.dtype + + cols = rowwise_data.shape[1] + rsi_cols = rowwise_scale_inv.shape[1] + # Number of block-rows (along the M dimension) for each tensor, + # since each Mi differs, we must take the maximum among them + nbrows_list = [(m + block_size - 1) // block_size for m in m_splits] + nbcols = (cols + block_size - 1) // block_size + + rowwise_data_list = list(torch.split(rowwise_data, m_splits, dim=0)) + rowwise_scale_inv_list = list(torch.split(rowwise_scale_inv, m_splits, dim=0)) + rowwise_scale_inv_t_list = [ + torch.empty((nbcols, m), dtype=scale_dtype, device=device) for m in m_splits + ] + columnwise_data_list = [ + torch.empty((cols, m), dtype=data_dtype, device=device) for m in m_splits + ] + columnwise_scale_inv_list = [ + torch.empty((nb, cols), dtype=scale_dtype, device=device) for nb in nbrows_list + ] + + rowwise_data_ptrs = torch.as_tensor([t.data_ptr() for t in rowwise_data_list]).to( + device=device, non_blocking=True + ) + rowwise_scale_inv_ptrs = torch.as_tensor( + [t.data_ptr() for t in rowwise_scale_inv_list] + ).to(device=device, non_blocking=True) + rowwise_scale_inv_t_ptrs = torch.as_tensor( + [t.data_ptr() for t in rowwise_scale_inv_t_list] + ).to(device=device, non_blocking=True) + columnwise_data_ptrs = torch.as_tensor( + [t.data_ptr() for t in columnwise_data_list] + ).to(device=device, non_blocking=True) + columnwise_scale_inv_ptrs = torch.as_tensor( + [t.data_ptr() for t in columnwise_scale_inv_list] + ).to(device=device, non_blocking=True) + + rows_t = torch.as_tensor(m_splits, dtype=torch.int32).to( + device=device, non_blocking=True + ) + + grid = (len(m_splits), max(nbrows_list), nbcols) + _scaling_aware_fp8_transpose_kernel[grid]( + rowwise_data_ptrs, + rowwise_scale_inv_ptrs, + columnwise_data_ptrs, + columnwise_scale_inv_ptrs, + rowwise_scale_inv_t_ptrs, + rows_t, + cols, + 
rsi_cols, + rowwise_data.stride(0), + rowwise_scale_inv.stride(0), + BLOCK_SIZE=block_size, + ) + + return ( + rowwise_data_list, + rowwise_scale_inv_t_list, + columnwise_data_list, + columnwise_scale_inv_list, + )