
Commit 778eb84

dongfengy authored and suyoggupta committed

[None][chore] Clean up unused and confusing code in moe test (NVIDIA#9019)

Signed-off-by: Dongfeng Yu <[email protected]>

1 parent 3eb78f7 commit 778eb84

1 file changed

tests/unittest/_torch/thop/parallel/test_moe.py

Lines changed: 8 additions & 19 deletions
@@ -263,7 +263,7 @@ def routing_reference_no_aux(expert_logits,
 
 
 # TopK -> Softmax
-def routing_reference_renormalize(expert_logits, top_k, num_experts, padding):
+def routing_reference_renormalize(expert_logits, top_k, padding):
     topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1)
     topk_values = torch.nn.functional.softmax(topk_values.float(), dim=-1)
 
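For context on the routine being changed, here is a minimal runnable sketch of the "TopK -> Softmax" reference above, assuming only PyTorch; the function name is illustrative, and the padding/permute_info bookkeeping of the real test helper is omitted. It also shows why num_experts could be dropped from the signature: the expert count is recoverable from expert_logits.shape[-1] when needed.

```python
import torch

def renormalize_routing(expert_logits: torch.Tensor, top_k: int):
    # Pick the top-k logits per token first, then softmax over just those
    # k values, so the selected expert weights sum to 1 per token.
    # num_experts is not needed here: it is expert_logits.shape[-1].
    topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_values.float(), dim=-1)
    return topk_weights, topk_idx

# Example: 4 tokens routed over 8 experts, top-2 experts per token.
weights, idx = renormalize_routing(torch.randn(4, 8), top_k=2)
assert torch.allclose(weights.sum(dim=-1), torch.ones(4))
```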
@@ -279,8 +279,7 @@ def routing_reference_renormalize(expert_logits, top_k, num_experts, padding):
 
 
 # Softmax->TopK -> Normalize
-def routing_reference_renormalize_naive(expert_logits, top_k, num_experts,
-                                        padding):
+def routing_reference_renormalize_naive(expert_logits, top_k, padding):
     norm_topk_prob = True
     scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1)
     topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1)
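The "Softmax->TopK -> Normalize" counterpart, again as a sketch with an illustrative name and without the padding bookkeeping. Note that with norm_topk_prob=True this is mathematically equivalent to the TopK -> Softmax variant (dividing full-softmax probabilities by their top-k sum cancels the global partition function), so the two references should differ only in floating-point rounding.

```python
import torch

def renormalize_naive_routing(expert_logits: torch.Tensor, top_k: int,
                              norm_topk_prob: bool = True):
    # Softmax over all experts first, then keep the top-k probabilities
    # and renormalize them so they sum to 1 per token.
    scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1)
    topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1)
    if norm_topk_prob:
        topk_values = topk_values / topk_values.sum(dim=-1, keepdim=True)
    return topk_values, topk_idx
```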
@@ -1002,7 +1001,6 @@ class TestMoeFp4:
         {
             "num_experts": 256,
             "top_k": 8,
-            "padding": 8,
             "n_groups": 8,
             "top_k_groups": 4,
             "routed_scaling": 2.5,
@@ -1014,7 +1012,6 @@ class TestMoeFp4:
         {
             "num_experts": 72,
             "top_k": 6,
-            "padding": 8,
             "n_groups": 1,
             "top_k_groups": 1,
             "routed_scaling": 2.5,
@@ -1026,7 +1023,6 @@ class TestMoeFp4:
         {
             "num_experts": 128,
             "top_k": 8,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1038,7 +1034,6 @@ class TestMoeFp4:
         {
             "num_experts": 128,
             "top_k": 4,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1050,7 +1045,6 @@ class TestMoeFp4:
         {
             "num_experts": 512,
             "top_k": 10,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1080,7 +1074,6 @@ def test_autotune(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 72,
             "top_k": 6,
-            "padding": 8,
             "n_groups": 1,
             "top_k_groups": 1,
             "routed_scaling": 2.5,
@@ -1110,7 +1103,6 @@ def test_autotune_fp8_fp4(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 256,
             "top_k": 8,
-            "padding": 8,
             "n_groups": 8,
             "top_k_groups": 4,
             "routed_scaling": 2.5,
@@ -1122,7 +1114,6 @@ def test_autotune_fp8_fp4(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 128,
             "top_k": 4,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1134,7 +1125,6 @@ def test_autotune_fp8_fp4(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 512,
             "top_k": 10,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1166,7 +1156,6 @@ def test_no_autotune(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 128,
             "top_k": 4,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1305,10 +1294,10 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int,
                 routed_scaling, padding)
         elif routing_method_type == RoutingMethodType.Renormalize:
             permute_info, scores = routing_reference_renormalize(
-                expert_logits, top_k, num_experts, padding)
+                expert_logits, top_k, padding)
         elif routing_method_type == RoutingMethodType.RenormalizeNaive:
             permute_info, scores = routing_reference_renormalize_naive(
-                expert_logits, top_k, num_experts, padding)
+                expert_logits, top_k, padding)
 
         args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
                         top_k, padding, hidden_states_fp4_bytes,
@@ -1552,10 +1541,10 @@ def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int,
                 routed_scaling, padding)
         elif routing_method_type == RoutingMethodType.Renormalize:
             permute_info, scores = routing_reference_renormalize(
-                expert_logits, top_k, num_experts, padding)
+                expert_logits, top_k, padding)
         elif routing_method_type == RoutingMethodType.RenormalizeNaive:
             permute_info, scores = routing_reference_renormalize_naive(
-                expert_logits, top_k, num_experts, padding)
+                expert_logits, top_k, padding)
 
         args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
                         top_k, padding, hidden_states_fp8, None,
@@ -2028,10 +2017,10 @@ def test_moe_mxe2m1_weights(num_tokens, hidden_size, intermediate_size,
                            sf_block_size)  # ue8m0 scaling factors
     if routing_method_type == RoutingMethodType.Renormalize:
         permute_info, scores = routing_reference_renormalize(
-            expert_logits, top_k, num_experts, padding)
+            expert_logits, top_k, padding)
     elif routing_method_type == RoutingMethodType.RenormalizeNaive:
         permute_info, scores = routing_reference_renormalize_naive(
-            expert_logits, top_k, num_experts, padding)
+            expert_logits, top_k, padding)
     else:
         raise ValueError("Invalid routing method type")
 