Skip to content

Commit feb65b7

Browse files
committed
[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization
1.Fused `moe_permute_with_probs` + `Fp8Padding` and fused `moe_unpermute` + `Fp8Unpadding`, that can remove the explicit padding/unpadding of moe expert, improved performance and reduced peak gpu memory usage. 2.Add tests of fused permute/pad and unpermute/unpad. Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
1 parent 5afbb0e commit feb65b7

5 files changed

Lines changed: 480 additions & 23 deletions

File tree

tests/pytorch/test_permutation.py

Lines changed: 355 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from transformer_engine.pytorch import (
1414
moe_permute as te_permute,
1515
moe_permute_with_probs as te_permute_with_probs,
16+
moe_permute_and_pad_with_probs as te_permute_and_pad_with_probs,
1617
moe_unpermute as te_unpermute,
1718
moe_sort_chunks_by_index as te_sort_chunks_by_index,
1819
moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
@@ -24,6 +25,7 @@
2425
MXFP8Quantizer,
2526
)
2627
import transformer_engine_torch as tex
28+
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding
2729
import copy
2830

2931
seed = 1234
@@ -653,6 +655,303 @@ def _test_permutation_mask_map(
653655
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
654656

655657

658+
def _test_permutation_and_padding_mask_map(
659+
te_dtype,
660+
num_tokens,
661+
num_expert,
662+
hidden_size,
663+
topK,
664+
num_out_tokens,
665+
align_size=16,
666+
BENCHMARK=False,
667+
):
668+
if topK > num_expert:
669+
pytest.skip("topK should be smaller than the number of experts.")
670+
671+
if num_out_tokens == None:
672+
num_out_tokens = num_tokens * topK
673+
674+
print(
675+
"permutation and padding:"
676+
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} align_size:{align_size} {te_dtype}"
677+
)
678+
679+
# Convert TE dtypes to PyTorch dtypes
680+
if te_dtype == tex.DType.kFloat32:
681+
dtype = torch.float32
682+
elif te_dtype == tex.DType.kFloat16:
683+
dtype = torch.float16
684+
elif te_dtype == tex.DType.kBFloat16:
685+
dtype = torch.bfloat16
686+
else:
687+
pytest.skip("Invalid dtype.")
688+
689+
_tmp_tensor = torch.zeros((num_tokens * num_expert,))
690+
_tmp_tensor[: int(num_out_tokens)] = 1.0
691+
_tmp_idx = torch.randperm(num_tokens * num_expert)
692+
routing_map = (
693+
torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
694+
)
695+
696+
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
697+
row_sums = probs.sum(dim=1, keepdim=True)
698+
probs = probs / row_sums
699+
probs = probs.to(dtype)
700+
probs.requires_grad_(True)
701+
702+
tokens_per_expert = routing_map.sum(dim=0).cpu()
703+
target_tokens_per_expert = (
704+
torch.ceil(tokens_per_expert / align_size) * align_size
705+
).long()
706+
num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()
707+
708+
permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
709+
permute_pad_bwd_input = torch.rand(
710+
(num_permute_pad_out_tokens, hidden_size), dtype=dtype
711+
).cuda()
712+
unpermute_unpad_bwd_input = torch.rand(
713+
(num_tokens, hidden_size), dtype=dtype
714+
).cuda()
715+
permute_pad_fwd_input.requires_grad_(True)
716+
717+
restore_shape = permute_pad_fwd_input.shape
718+
###################################################################################################################################
719+
#
720+
# moe_permute_with_probs and Fp8Padding, moe_unpermute and Fp8Unpadding
721+
#
722+
###################################################################################################################################
723+
# permute + padding
724+
permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
725+
permute_pad_fwd_input,
726+
probs,
727+
routing_map,
728+
num_out_tokens=num_out_tokens,
729+
)
730+
tokens_per_expert_list = tokens_per_expert.tolist()
731+
fp8_padding = Fp8Padding(num_expert, align_size)
732+
permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
733+
permuted_paded_probs, _ = fp8_padding(
734+
permuted_probs.unsqueeze(-1), tokens_per_expert_list
735+
)
736+
737+
permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)
738+
739+
# unpadding + unpermute
740+
741+
unpermute_unpad_fwd_input = permuted_paded_output.detach()
742+
unpermute_unpad_fwd_input.requires_grad_(True)
743+
744+
fp8_unpadding = Fp8Unpadding(num_expert, align_size)
745+
unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)
746+
unpermuted_unpaded_output = te_unpermute(
747+
unpaded_output, row_id_map, restore_shape=restore_shape
748+
)
749+
750+
unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)
751+
752+
###################################################################################################################################
753+
#
754+
# fusion moe_permute_with_probs and Fp8Padding, fusion fusion moe_unpermute and Fp8Unpadding
755+
#
756+
###################################################################################################################################
757+
# fusion permute_and_pad
758+
fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach()
759+
fusion_permute_and_pad_fwd_input.requires_grad_(True)
760+
probs = probs.detach()
761+
probs.requires_grad_(True)
762+
763+
(
764+
fusion_permuted_padded_output,
765+
fusion_permuted_padded_probs,
766+
row_id_map,
767+
pad_offsets,
768+
target_tokens_per_expert,
769+
) = te_permute_and_pad_with_probs(
770+
fusion_permute_and_pad_fwd_input,
771+
probs,
772+
routing_map,
773+
tokens_per_expert,
774+
align_size,
775+
)
776+
fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)
777+
778+
fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
779+
fusion_permuted_padded_output.backward(
780+
fusion_permute_pad_bwd_input, retain_graph=True
781+
)
782+
783+
# fusion unpad and unpermute
784+
fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach()
785+
fusion_unpermute_unpad_fwd_input.requires_grad_(True)
786+
787+
fusion_unpermuted_unpaded_output = te_unpermute(
788+
fusion_unpermute_unpad_fwd_input,
789+
row_id_map,
790+
restore_shape=restore_shape,
791+
pad_offsets=pad_offsets,
792+
)
793+
794+
fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
795+
fusion_unpermuted_unpaded_output.backward(
796+
fusion_unpermute_bwd_input, retain_graph=True
797+
)
798+
799+
###################################################################################################################################
800+
#
801+
# Results Check
802+
#
803+
###################################################################################################################################
804+
tols = dtype_tols(te_dtype)
805+
806+
permuted_paded_output_ = permuted_paded_output.float()
807+
fusion_permuted_padded_output_ = fusion_permuted_padded_output.float()
808+
permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float()
809+
fusion_permute_and_pad_fwd_input_grad = (
810+
fusion_permute_and_pad_fwd_input.grad.float()
811+
)
812+
813+
unpermuted_unpaded_output_ = unpermuted_unpaded_output.float()
814+
fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float()
815+
unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float()
816+
fusion_unpermute_unpad_fwd_input_grad = (
817+
fusion_unpermute_unpad_fwd_input.grad.float()
818+
)
819+
820+
if not BENCHMARK:
821+
torch.testing.assert_close(
822+
permuted_paded_output_,
823+
fusion_permuted_padded_output_,
824+
msg=f"Mismatch in te_permute_and_pad fwd",
825+
**tols,
826+
)
827+
torch.testing.assert_close(
828+
permute_pad_fwd_input_grad,
829+
fusion_permute_and_pad_fwd_input_grad,
830+
msg=f"Mismatch in te_permute_and_pad bwd",
831+
**tols,
832+
)
833+
torch.testing.assert_close(
834+
unpermuted_unpaded_output_,
835+
fusion_unpermuted_unpaded_output_,
836+
msg=f"Mismatch in te_unpermute fwd",
837+
**tols,
838+
)
839+
torch.testing.assert_close(
840+
unpermute_unpad_fwd_input_grad,
841+
fusion_unpermute_unpad_fwd_input_grad,
842+
msg=f"Mismatch in te_unpermute bwd",
843+
**tols,
844+
)
845+
torch.testing.assert_close(
846+
permuted_paded_probs.float(),
847+
fusion_permuted_padded_probs.float(),
848+
msg=f"Mismatch in te_permute_and_pad bwd",
849+
**tols,
850+
)
851+
852+
###################################################################################################################################
853+
#
854+
# Benchmark
855+
#
856+
###################################################################################################################################
857+
if BENCHMARK:
858+
859+
def permute_and_pad():
860+
permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
861+
permute_pad_fwd_input,
862+
probs,
863+
routing_map,
864+
num_out_tokens=num_out_tokens,
865+
)
866+
fp8_padding(permuted_output, tokens_per_expert_list)
867+
fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)
868+
869+
def fusion_permute_and_pad():
870+
(
871+
fusion_permuted_padded_output,
872+
fusion_permuted_padded_probs,
873+
row_id_map,
874+
pad_offsets,
875+
target_tokens_per_expert,
876+
) = te_permute_and_pad_with_probs(
877+
fusion_permute_and_pad_fwd_input,
878+
probs,
879+
routing_map,
880+
tokens_per_expert,
881+
align_size,
882+
)
883+
fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)
884+
885+
t1 = perf_test_cuda_kernel(lambda: permute_and_pad())
886+
887+
t2 = perf_test_cuda_kernel(lambda: fusion_permute_and_pad())
888+
889+
print(f"permute_and_pad\t\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
890+
891+
t1 = perf_test_cuda_kernel(
892+
lambda: backward_wrapper(
893+
permuted_paded_output,
894+
permute_pad_bwd_input,
895+
forward_input=[permute_pad_fwd_input],
896+
retain_graph=True,
897+
accumulate_grad=False,
898+
)
899+
)
900+
t2 = perf_test_cuda_kernel(
901+
lambda: backward_wrapper(
902+
fusion_permuted_padded_output,
903+
fusion_permute_pad_bwd_input,
904+
forward_input=[fusion_permute_and_pad_fwd_input],
905+
retain_graph=True,
906+
accumulate_grad=False,
907+
)
908+
)
909+
print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
910+
911+
def unpad_unpermute():
912+
unpaded_output = fp8_unpadding(
913+
unpermute_unpad_fwd_input, tokens_per_expert_list
914+
)
915+
unpermuted_unpaded_output = te_unpermute(
916+
unpaded_output, row_id_map, restore_shape=restore_shape
917+
)
918+
919+
unpermuted_unpaded_output.backward(
920+
unpermute_unpad_bwd_input, retain_graph=True
921+
)
922+
923+
t1 = perf_test_cuda_kernel(lambda: unpad_unpermute())
924+
t2 = perf_test_cuda_kernel(
925+
lambda: te_unpermute(
926+
fusion_unpermute_unpad_fwd_input,
927+
row_id_map,
928+
restore_shape=restore_shape,
929+
pad_offsets=pad_offsets,
930+
)
931+
)
932+
print(f"unpermute_and_unpad\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
933+
934+
t1 = perf_test_cuda_kernel(
935+
lambda: backward_wrapper(
936+
unpermuted_unpaded_output,
937+
unpermute_unpad_bwd_input,
938+
forward_input=([unpermute_unpad_fwd_input, probs]),
939+
retain_graph=True,
940+
accumulate_grad=False,
941+
)
942+
)
943+
t2 = perf_test_cuda_kernel(
944+
lambda: backward_wrapper(
945+
fusion_unpermuted_unpaded_output,
946+
fusion_unpermute_bwd_input,
947+
forward_input=([fusion_unpermute_unpad_fwd_input, probs]),
948+
retain_graph=True,
949+
accumulate_grad=False,
950+
)
951+
)
952+
print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
953+
954+
656955
def _test_permutation_mask_map_fp8(
657956
te_dtype,
658957
num_tokens,
@@ -1180,6 +1479,40 @@ def test_permutation_mask_map(
11801479
)
11811480

11821481

1482+
@pytest.mark.parametrize("te_dtype", _te_dtypes)
1483+
@pytest.mark.parametrize("num_out_tokens", [None])
1484+
@pytest.mark.parametrize(
1485+
"num_tokens, num_expert, hidden_size, topK",
1486+
[
1487+
(4096, 64, 1280, 7),
1488+
(4096, 64, 2048, 6),
1489+
(4096, 160, 5120, 6),
1490+
(4096, 256, 7168, 8),
1491+
(4096, 384, 8192, 8),
1492+
(4096, 512, 9216, 8),
1493+
],
1494+
)
1495+
def test_permutation_and_padding_mask_map(
1496+
te_dtype,
1497+
num_tokens,
1498+
num_expert,
1499+
hidden_size,
1500+
topK,
1501+
num_out_tokens,
1502+
):
1503+
BENCHMARK = False
1504+
1505+
_test_permutation_and_padding_mask_map(
1506+
te_dtype=te_dtype,
1507+
num_tokens=num_tokens,
1508+
num_expert=num_expert,
1509+
hidden_size=hidden_size,
1510+
topK=topK,
1511+
num_out_tokens=num_out_tokens,
1512+
BENCHMARK=BENCHMARK,
1513+
)
1514+
1515+
11831516
@pytest.mark.parametrize("te_dtype", _te_dtypes)
11841517
def test_permutation_mask_map_empty_input(te_dtype):
11851518
with_probs = True
@@ -1413,6 +1746,16 @@ def test_permutation_single_case():
14131746
BENCHMARK=Benchmark,
14141747
)
14151748

1749+
_test_permutation_and_padding_mask_map(
1750+
te_dtype=te_dtype,
1751+
num_tokens=num_tokens,
1752+
num_expert=num_expert,
1753+
hidden_size=hidden_size,
1754+
topK=topK,
1755+
num_out_tokens=num_out_tokens,
1756+
BENCHMARK=Benchmark,
1757+
)
1758+
14161759
_test_moe_chunk_sort(
14171760
te_dtype=te_dtype,
14181761
num_tokens=num_tokens,
@@ -1479,6 +1822,18 @@ def benchmark_single_case(
14791822
)
14801823
torch.cuda.nvtx.range_pop()
14811824

1825+
torch.cuda.nvtx.range_push("permutation_and_padding_mask_map")
1826+
_test_permutation_and_padding_mask_map(
1827+
te_dtype=te_dtype,
1828+
num_tokens=num_tokens,
1829+
num_expert=num_expert,
1830+
hidden_size=hidden_size,
1831+
topK=topK,
1832+
num_out_tokens=num_out_tokens,
1833+
BENCHMARK=True,
1834+
)
1835+
torch.cuda.nvtx.range_pop()
1836+
14821837
torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs")
14831838
_test_permutation_mask_map_alongside_probs(
14841839
te_dtype=te_dtype,

0 commit comments

Comments
 (0)