@@ -263,7 +263,7 @@ def routing_reference_no_aux(expert_logits,
 
 
 # TopK -> Softmax
-def routing_reference_renormalize(expert_logits, top_k, num_experts, padding):
+def routing_reference_renormalize(expert_logits, top_k, padding):
     topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1)
     topk_values = torch.nn.functional.softmax(topk_values.float(), dim=-1)
 
@@ -279,8 +279,7 @@ def routing_reference_renormalize(expert_logits, top_k, num_experts, padding):
 
 
 # Softmax->TopK -> Normalize
-def routing_reference_renormalize_naive(expert_logits, top_k, num_experts,
-                                        padding):
+def routing_reference_renormalize_naive(expert_logits, top_k, padding):
     norm_topk_prob = True
     scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1)
     topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1)
@@ -1002,7 +1001,6 @@ class TestMoeFp4:
         {
             "num_experts": 256,
             "top_k": 8,
-            "padding": 8,
             "n_groups": 8,
             "top_k_groups": 4,
             "routed_scaling": 2.5,
@@ -1014,7 +1012,6 @@ class TestMoeFp4:
         {
             "num_experts": 72,
             "top_k": 6,
-            "padding": 8,
             "n_groups": 1,
             "top_k_groups": 1,
             "routed_scaling": 2.5,
@@ -1026,7 +1023,6 @@ class TestMoeFp4:
         {
             "num_experts": 128,
             "top_k": 8,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1038,7 +1034,6 @@ class TestMoeFp4:
         {
             "num_experts": 128,
             "top_k": 4,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1050,7 +1045,6 @@ class TestMoeFp4:
         {
             "num_experts": 512,
             "top_k": 10,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1080,7 +1074,6 @@ def test_autotune(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 72,
             "top_k": 6,
-            "padding": 8,
             "n_groups": 1,
             "top_k_groups": 1,
             "routed_scaling": 2.5,
@@ -1110,7 +1103,6 @@ def test_autotune_fp8_fp4(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 256,
             "top_k": 8,
-            "padding": 8,
             "n_groups": 8,
             "top_k_groups": 4,
             "routed_scaling": 2.5,
@@ -1122,7 +1114,6 @@ def test_autotune_fp8_fp4(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 128,
             "top_k": 4,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1134,7 +1125,6 @@ def test_autotune_fp8_fp4(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 512,
             "top_k": 10,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1166,7 +1156,6 @@ def test_no_autotune(self, num_tokens, hidden_size, intermediate_size,
         {
             "num_experts": 128,
             "top_k": 4,
-            "padding": 8,
             "n_groups": None,
             "top_k_groups": None,
             "routed_scaling": None,
@@ -1305,10 +1294,10 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int,
                 routed_scaling, padding)
         elif routing_method_type == RoutingMethodType.Renormalize:
             permute_info, scores = routing_reference_renormalize(
-                expert_logits, top_k, num_experts, padding)
+                expert_logits, top_k, padding)
         elif routing_method_type == RoutingMethodType.RenormalizeNaive:
             permute_info, scores = routing_reference_renormalize_naive(
-                expert_logits, top_k, num_experts, padding)
+                expert_logits, top_k, padding)
 
         args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
                         top_k, padding, hidden_states_fp4_bytes,
@@ -1552,10 +1541,10 @@ def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int,
                 routed_scaling, padding)
         elif routing_method_type == RoutingMethodType.Renormalize:
             permute_info, scores = routing_reference_renormalize(
-                expert_logits, top_k, num_experts, padding)
+                expert_logits, top_k, padding)
         elif routing_method_type == RoutingMethodType.RenormalizeNaive:
             permute_info, scores = routing_reference_renormalize_naive(
-                expert_logits, top_k, num_experts, padding)
+                expert_logits, top_k, padding)
 
         args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
                         top_k, padding, hidden_states_fp8, None,
@@ -2028,10 +2017,10 @@ def test_moe_mxe2m1_weights(num_tokens, hidden_size, intermediate_size,
         sf_block_size)  # ue8m0 scaling factors
     if routing_method_type == RoutingMethodType.Renormalize:
         permute_info, scores = routing_reference_renormalize(
-            expert_logits, top_k, num_experts, padding)
+            expert_logits, top_k, padding)
     elif routing_method_type == RoutingMethodType.RenormalizeNaive:
         permute_info, scores = routing_reference_renormalize_naive(
-            expert_logits, top_k, num_experts, padding)
+            expert_logits, top_k, padding)
     else:
         raise ValueError("Invalid routing method type")
 
0 commit comments