@@ -21,6 +21,16 @@ torch::Tensor unpermute_with_index_map(torch::Tensor permuted_tokens,
                                         int64_t n_tokens,
                                         int64_t topk);
 
+std::tuple<torch::Tensor, torch::Tensor> permute_with_mask_map(
+    torch::Tensor tokens,
+    torch::Tensor indices);
+
+torch::Tensor unpermute_with_mask_map(torch::Tensor permuted_tokens,
+                                      torch::Tensor row_id_map,
+                                      torch::Tensor probs,
+                                      int64_t n_tokens,
+                                      int64_t topk);
+
 }  // namespace kernel::moe
 
 namespace {
@@ -66,6 +76,47 @@ torch::Tensor unpermute_index_ref(
   return tokens.sum(/*dim=*/1);
 }
 
+std::tuple<torch::Tensor, torch::Tensor> permute_mask_ref(
+    const torch::Tensor& tokens,       // [n_tokens, dim]
+    const torch::Tensor& topk_indices  // [n_tokens, topk]
+) {
+  const auto n_tokens = tokens.size(0);
+  const auto topk = topk_indices.size(1);
+
+  auto flatten_indices = topk_indices.view({-1});
+  // flattened indices, sorted by (expert, token)
+  auto sorted_indices = flatten_indices.argsort(/*stable=*/true);
+
+  // flattened idx => token_indices, [n_permuted_tokens]
+  auto token_indices = sorted_indices.div(topk, /*rounding_mode=*/"floor");
+  auto permuted_tokens = tokens.index_select(
+      /*dim=*/0, token_indices);
+
+  return {permuted_tokens, sorted_indices};
+}
+
+torch::Tensor unpermute_mask_ref(
+    const torch::Tensor& permuted_tokens,  // [n_permuted_tokens, dim]
+    const torch::Tensor& sorted_indices,   // [n_permuted_tokens]
+    const torch::Tensor& probs,            // [n_tokens, topk]
+    int64_t n_tokens,
+    int64_t topk) {
+  auto tokens = torch::zeros_like(permuted_tokens);
+
+  // [n_permuted_tokens, dim], restore back to the original token-major order
+  tokens.index_copy_(
+      /*dim=*/0, sorted_indices, permuted_tokens);
+  // [n_permuted_tokens, dim] => [n_tokens, topk, dim]
+  tokens = tokens.reshape({n_tokens, topk, -1});
+
+  // apply probs
+  // [n_tokens, topk, dim] * [n_tokens, topk, 1]
+  tokens *= probs.unsqueeze(/*dim=*/-1);
+
+  // [n_tokens, dim], sum over topk
+  return tokens.sum(/*dim=*/1);
+}
+
 }  // namespace
 
 class PermuteTest
@@ -111,6 +162,36 @@ TEST_P(PermuteTest, Index) {
       torch::allclose(tokens, unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
 }
 
+TEST_P(PermuteTest, Mask) {
+  const auto [dtype, n_tokens, dim, n_experts, topk] = GetParam();
+
+  const auto options = torch::dtype(dtype).device(torch::kCUDA);
+
+  const auto tokens = torch::randn({n_tokens, dim}, options);
+  const auto gating_logit = torch::randn({n_tokens, n_experts}, options);
+
+  auto [weights, indices] = gating_logit.topk(topk, /*dim=*/-1);
+  auto probs = weights.softmax(/*dim=*/-1);
+
+  auto [permuted_tokens, sorted_indices] =
+      kernel::moe::permute_with_mask_map(tokens, indices.to(torch::kInt32));
+
+  auto [ref_permuted_tokens, ref_sorted_indices] =
+      permute_mask_ref(tokens, indices);
+
+  EXPECT_TRUE(torch::allclose(permuted_tokens, ref_permuted_tokens));
+
+  auto unpermute_out = kernel::moe::unpermute_with_mask_map(
+      permuted_tokens, sorted_indices, probs, n_tokens, topk);
+
+  auto ref_unpermute_out = unpermute_mask_ref(
+      ref_permuted_tokens, ref_sorted_indices, probs, n_tokens, topk);
+  EXPECT_TRUE(torch::allclose(
+      unpermute_out, ref_unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
+  EXPECT_TRUE(
+      torch::allclose(tokens, unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
+}
+
 INSTANTIATE_TEST_SUITE_P(
     Moe,
     PermuteTest,
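
The final assertion in the Mask test holds because probs is a softmax over the
top-k dimension: the per-token weights sum to 1, so scattering the permuted rows
back, weighting by probs, and summing over topk reproduces the original tokens.
A minimal standalone sketch of that round trip, mirroring the reference logic
above (CPU float32 tensors, the sizes below, and the main() wrapper are
illustrative assumptions, not part of this diff):

#include <iostream>
#include <torch/torch.h>

int main() {
  const int64_t n_tokens = 4, dim = 8, n_experts = 6, topk = 2;
  auto tokens = torch::randn({n_tokens, dim});
  auto gating_logit = torch::randn({n_tokens, n_experts});

  auto [weights, indices] = gating_logit.topk(topk, /*dim=*/-1);
  auto probs = weights.softmax(/*dim=*/-1);  // sums to 1 over topk

  // permute: rows grouped by (expert, token), as in permute_mask_ref
  auto sorted_indices = indices.view({-1}).argsort(/*stable=*/true);
  auto token_indices = sorted_indices.div(topk, /*rounding_mode=*/"floor");
  auto permuted = tokens.index_select(/*dim=*/0, token_indices);

  // unpermute: scatter rows back, weight by probs, reduce over topk
  auto out = torch::zeros_like(permuted);
  out.index_copy_(/*dim=*/0, sorted_indices, permuted);
  out = out.reshape({n_tokens, topk, -1});
  out *= probs.unsqueeze(/*dim=*/-1);
  out = out.sum(/*dim=*/1);

  // expect 1 (true): the round trip is lossless up to float rounding
  std::cout << torch::allclose(tokens, out, /*rtol=*/1e-5, /*atol=*/1e-5)
            << std::endl;
  return 0;
}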