Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
feb65b7
[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization
xiaoxi-wangfj Jul 3, 2025
a7de66c
[PyTorch/Common] Fuse permute+pad and unpermute+unpad support with_me…
xiaoxi-wangfj Dec 11, 2025
f550684
[PyTorch]format code
xiaoxi-wangfj Dec 11, 2025
6069277
[Common]perf expert_idx loaded once
xiaoxi-wangfj Dec 11, 2025
053abee
Merge branch 'main' into fused_perm_pad
xiaoxi-wangfj Dec 12, 2025
1ea08f7
fix: pad_offsets can be None
xiaoxi-wangfj Dec 17, 2025
ac12a91
Merge branch 'main' into fused_perm_pad
xiaoxi-wangfj Dec 17, 2025
230939c
add padding + merging probs bwd support. Not tested
tdophung Dec 11, 2025
f301462
Fix garbage initialized act grad
tdophung Dec 11, 2025
7ed584c
all test passing for jax permutation + pad
tdophung Dec 17, 2025
7998ce8
change tokens_per_experts APIs to num_out_tokens with conservative a…
tdophung Dec 17, 2025
dd5c72a
change test permutation to reduce test time
tdophung Dec 19, 2025
ce187b6
triggering PR refresh
tdophung Dec 19, 2025
7dc9ccb
format code
tdophung Dec 20, 2025
1fbe99c
Remove some tests cases from pytorch side. Add a separate toekn_dispa…
tdophung Dec 20, 2025
592f675
format code
tdophung Dec 20, 2025
1d43279
remove chance for inefficiency in moving between CPU and GPU, remove …
tdophung Dec 20, 2025
4169a4e
fix lint in jax
tdophung Dec 22, 2025
c619adf
account for both jax newer and older than version 0.8.2. Adjusted gpu…
tdophung Dec 22, 2025
405b341
format code
tdophung Dec 22, 2025
7cad5c5
fix typo
tdophung Dec 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 352 additions & 0 deletions tests/pytorch/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from transformer_engine.pytorch import (
moe_permute as te_permute,
moe_permute_with_probs as te_permute_with_probs,
moe_permute_and_pad_with_probs as te_permute_and_pad_with_probs,
moe_unpermute as te_unpermute,
moe_sort_chunks_by_index as te_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
Expand All @@ -24,6 +25,7 @@
MXFP8Quantizer,
)
import transformer_engine_torch as tex
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding
import copy

seed = 1234
Expand Down Expand Up @@ -653,6 +655,297 @@ def _test_permutation_mask_map(
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")


def _test_permutation_and_padding_mask_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
with_merging_probs=False,
align_size=16,
BENCHMARK=False,
):
# Compare the unfused path (te_permute_with_probs followed by Fp8Padding, and
# Fp8Unpadding followed by te_unpermute) against the fused path
# (te_permute_and_pad_with_probs, and te_unpermute called with pad_offsets).
#
# Forward outputs and input gradients of both paths are checked with
# torch.testing.assert_close when BENCHMARK is false; when BENCHMARK is true the
# checks are skipped and kernel timings are printed instead.
#
# Parameters:
#   te_dtype           -- tex.DType value; kFloat32 / kFloat16 / kBFloat16 are handled,
#                         anything else skips the test.
#   num_tokens, num_expert, hidden_size -- sizes of the generated inputs.
#   topK               -- only used for the skip guard, the default num_out_tokens and
#                         the log line.
#   num_out_tokens     -- number of True entries placed in the routing map; None means
#                         num_tokens * topK.
#   with_merging_probs -- when true, probs are passed to te_unpermute as merging_probs
#                         and their gradients are compared as well.
#   align_size         -- per-expert token counts are rounded up to a multiple of this.
#   BENCHMARK          -- selects timing mode instead of assertion mode.
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")

if num_out_tokens is None:
num_out_tokens = num_tokens * topK

print(
"permutation and padding:"
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} with_probs:{with_merging_probs} align_size:{align_size} {te_dtype}"
)

# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
else:
pytest.skip("Invalid dtype.")

# Build a boolean routing map with exactly num_out_tokens True entries: fill the
# first num_out_tokens slots of a flat vector with 1.0, shuffle it with a random
# permutation, and reshape to (num_tokens, num_expert).
_tmp_tensor = torch.zeros((num_tokens * num_expert,))
_tmp_tensor[: int(num_out_tokens)] = 1.0
_tmp_idx = torch.randperm(num_tokens * num_expert)
routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()

# Random probabilities masked by the routing map and normalised per token row.
# NOTE(review): a row with no routed expert has row_sums == 0, so the division
# below would be 0/0 for that row -- confirm this cannot affect the comparisons.
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
probs = probs.to(dtype)
probs.requires_grad_(True)

# Per-expert token counts (column sums, on CPU) and the same counts rounded up to
# the next multiple of align_size; their total is the padded output row count.
tokens_per_expert = routing_map.sum(dim=0).cpu()
target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()

# Forward input plus the upstream gradients fed to the two backward passes.
permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
permute_pad_bwd_input = torch.rand(
(num_permute_pad_out_tokens, hidden_size), dtype=dtype
).cuda()
unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
permute_pad_fwd_input.requires_grad_(True)

restore_shape = permute_pad_fwd_input.shape
###################################################################################################################################
#
# moe_permute_with_probs and Fp8Padding, moe_unpermute and Fp8Unpadding
#
###################################################################################################################################
# permute + padding
permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
permute_pad_fwd_input,
probs,
routing_map,
num_out_tokens=num_out_tokens,
)
tokens_per_expert_list = tokens_per_expert.tolist()
fp8_padding = Fp8Padding(num_expert, align_size)
permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
# probs get a trailing dimension via unsqueeze(-1) before going through the same padding module
permuted_paded_probs, _ = fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)

permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)

# unpadding + unpermute

# Detach so the unpad/unpermute stage has its own leaf tensor and its own .grad.
unpermute_unpad_fwd_input = permuted_paded_output.detach()
unpermute_unpad_fwd_input.requires_grad_(True)

fp8_unpadding = Fp8Unpadding(num_expert, align_size)
unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)

probs_naive = probs
unpermuted_unpaded_output = te_unpermute(
unpaded_output,
row_id_map,
merging_probs=probs_naive if with_merging_probs else None,
restore_shape=restore_shape,
)

unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)

###################################################################################################################################
#
# fused moe_permute_with_probs + Fp8Padding, fused moe_unpermute + Fp8Unpadding
#
###################################################################################################################################
# fusion permute_and_pad
# Fresh leaf copies so the fused path accumulates gradients separately from the
# unfused path.
fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach()
fusion_permute_and_pad_fwd_input.requires_grad_(True)
probs_fusion = probs_naive.detach().clone()
probs_fusion.requires_grad_(True)

# NOTE(review): this rebinds row_id_map and target_tokens_per_expert, replacing the
# values produced by the unfused path above; every later read of row_id_map
# (including the BENCHMARK closures) sees the fused one.
(
fusion_permuted_padded_output,
fusion_permuted_padded_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
) = te_permute_and_pad_with_probs(
fusion_permute_and_pad_fwd_input,
probs_fusion,
routing_map,
tokens_per_expert,
align_size,
)
# Match the trailing dimension of permuted_paded_probs for the comparison below.
fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)

fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)

# fusion unpad and unpermute
fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach()
fusion_unpermute_unpad_fwd_input.requires_grad_(True)

# Unpadding is requested from te_unpermute itself by passing pad_offsets.
fusion_unpermuted_unpaded_output = te_unpermute(
fusion_unpermute_unpad_fwd_input,
row_id_map,
merging_probs=probs_fusion if with_merging_probs else None,
restore_shape=restore_shape,
pad_offsets=pad_offsets,
)

fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
fusion_unpermuted_unpaded_output.backward(fusion_unpermute_bwd_input, retain_graph=True)

###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)

# Everything is upcast to float32 before comparison.
permuted_paded_output_ = permuted_paded_output.float()
fusion_permuted_padded_output_ = fusion_permuted_padded_output.float()
permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float()
fusion_permute_and_pad_fwd_input_grad = fusion_permute_and_pad_fwd_input.grad.float()

unpermuted_unpaded_output_ = unpermuted_unpaded_output.float()
fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float()
unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float()
fusion_unpermute_unpad_fwd_input_grad = fusion_unpermute_unpad_fwd_input.grad.float()

if not BENCHMARK:
torch.testing.assert_close(
permuted_paded_output_,
fusion_permuted_padded_output_,
msg=f"Mismatch in te_permute_and_pad fwd",
**tols,
)
torch.testing.assert_close(
permute_pad_fwd_input_grad,
fusion_permute_and_pad_fwd_input_grad,
msg=f"Mismatch in te_permute_and_pad bwd",
**tols,
)
torch.testing.assert_close(
unpermuted_unpaded_output_,
fusion_unpermuted_unpaded_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
unpermute_unpad_fwd_input_grad,
fusion_unpermute_unpad_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
# NOTE(review): this check compares the permuted+padded probs returned by the two
# forward calls, yet the message says "bwd" -- looks like a copy-paste of the
# message above; confirm and relabel.
torch.testing.assert_close(
permuted_paded_probs.float(),
fusion_permuted_padded_probs.float(),
msg=f"Mismatch in te_permute_and_pad bwd",
**tols,
)
# probs only take part in te_unpermute (as merging_probs) when with_merging_probs is set.
if with_merging_probs:
torch.testing.assert_close(
probs_naive.grad.float(),
probs_fusion.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
)

###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
if BENCHMARK:
# NOTE(review): the next two lines are review-page residue from the extraction, not Python.
Comment thread
tdophung marked this conversation as resolved.

# Unfused forward: permute, then pad tokens and probs separately.
def permute_and_pad():
permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
permute_pad_fwd_input,
probs,
routing_map,
num_out_tokens=num_out_tokens,
)
fp8_padding(permuted_output, tokens_per_expert_list)
fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)

# Fused forward. Note it is timed with `probs`, not `probs_fusion`.
def fusion_permute_and_pad():
(
fusion_permuted_padded_output,
fusion_permuted_padded_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
) = te_permute_and_pad_with_probs(
fusion_permute_and_pad_fwd_input,
probs,
routing_map,
tokens_per_expert,
align_size,
)
fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)

t1 = perf_test_cuda_kernel(lambda: permute_and_pad())

t2 = perf_test_cuda_kernel(lambda: fusion_permute_and_pad())

print(f"permute_and_pad\t\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
permuted_paded_output,
permute_pad_bwd_input,
forward_input=[permute_pad_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
fusion_permuted_padded_output,
fusion_permute_pad_bwd_input,
forward_input=[fusion_permute_and_pad_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

# Unfused unpad + unpermute. It reads row_id_map from the enclosing scope, which
# at this point is the value returned by the fused call.
# NOTE(review): indentation was lost in extraction, so it is unclear whether the
# .backward(...) call below belongs to this helper or to the enclosing block --
# confirm against the repository file.
def unpad_unpermute():
unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)
unpermuted_unpaded_output = te_unpermute(
unpaded_output, row_id_map, restore_shape=restore_shape
)

unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)

t1 = perf_test_cuda_kernel(lambda: unpad_unpermute())
t2 = perf_test_cuda_kernel(
lambda: te_unpermute(
fusion_unpermute_unpad_fwd_input,
row_id_map,
restore_shape=restore_shape,
pad_offsets=pad_offsets,
)
)
print(f"unpermute_and_unpad\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
unpermuted_unpaded_output,
unpermute_unpad_bwd_input,
forward_input=([unpermute_unpad_fwd_input, probs]),
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
fusion_unpermuted_unpaded_output,
fusion_unpermute_bwd_input,
forward_input=([fusion_unpermute_unpad_fwd_input, probs]),
retain_graph=True,
accumulate_grad=False,
)
)
print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")


def _test_permutation_mask_map_fp8(
te_dtype,
num_tokens,
Expand Down Expand Up @@ -1180,6 +1473,43 @@ def test_permutation_mask_map(
)


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
    "num_tokens, num_expert, hidden_size, topK",
    [
        (4096, 64, 1280, 7),
        (4096, 64, 2048, 6),
        (4096, 160, 5120, 6),
        (4096, 256, 7168, 8),
        (4096, 384, 8192, 8),
        (4096, 512, 9216, 8),
    ],
)
@pytest.mark.parametrize("with_merging_probs", [True, False])
def test_permutation_and_padding_mask_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    with_merging_probs,
):
    """Run the fused permute+pad / unpermute+unpad comparison in assertion mode.

    Every parametrized combination is forwarded unchanged to
    ``_test_permutation_and_padding_mask_map`` with benchmarking switched off,
    so the helper performs its ``assert_close`` checks rather than timing runs.
    ``align_size`` is left at the helper's default.
    """
    case = dict(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_merging_probs=with_merging_probs,
    )
    _test_permutation_and_padding_mask_map(**case, BENCHMARK=False)


@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_permutation_mask_map_empty_input(te_dtype):
with_probs = True
Expand Down Expand Up @@ -1413,6 +1743,16 @@ def test_permutation_single_case():
BENCHMARK=Benchmark,
)

_test_permutation_and_padding_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=Benchmark,
)

_test_moe_chunk_sort(
te_dtype=te_dtype,
num_tokens=num_tokens,
Expand Down Expand Up @@ -1479,6 +1819,18 @@ def benchmark_single_case(
)
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("permutation_and_padding_mask_map")
_test_permutation_and_padding_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs")
_test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype,
Expand Down
Loading