Skip to content

Commit 0a8001c

Browse files
committed
kernel: add kernel for moe permutation with mask map
1 parent 4357525 commit 0a8001c

File tree

2 files changed

+448
-0
lines changed

2 files changed

+448
-0
lines changed

src/kernels/moe/permutation_kernel_test.cu

+81
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@ torch::Tensor unpermute_with_index_map(torch::Tensor permuted_tokens,
2121
int64_t n_tokens,
2222
int64_t topk);
2323

24+
// Permutes `tokens` into expert-major order using a mask-derived row map.
// tokens:  [n_tokens, dim] input activations.
// indices: [n_tokens, topk] expert assignment per token slot (the test
//          passes int32 indices -- presumably required by the kernel).
// Returns {permuted_tokens, row_id_map}.
// NOTE(review): the row_id_map encoding is defined in the .cu file and is
// not visible here; the test treats it as interchangeable with the stable
// argsort order produced by permute_mask_ref -- confirm against the kernel.
std::tuple<torch::Tensor, torch::Tensor> permute_with_mask_map(
    torch::Tensor tokens,
    torch::Tensor indices);

// Inverse of permute_with_mask_map: restores rows to token order, weights
// each expert copy by `probs`, and reduces over the topk dimension.
// permuted_tokens: [n_permuted_tokens, dim]
// row_id_map:      map produced by permute_with_mask_map
// probs:           [n_tokens, topk] per-slot routing weights
// Returns a [n_tokens, dim] tensor (see unpermute_mask_ref below).
torch::Tensor unpermute_with_mask_map(torch::Tensor permuted_tokens,
                                      torch::Tensor row_id_map,
                                      torch::Tensor probs,
                                      int64_t n_tokens,
                                      int64_t topk);
33+
2434
} // namespace kernel::moe
2535

2636
namespace {
@@ -66,6 +76,47 @@ torch::Tensor unpermute_index_ref(
6676
return tokens.sum(/*dim=*/1);
6777
}
6878

79+
std::tuple<torch::Tensor, torch::Tensor> permute_mask_ref(
80+
const torch::Tensor& tokens, // [n_tokens, dim]
81+
const torch::Tensor& topk_indices // [n_tokens, topk]
82+
) {
83+
const auto n_tokens = tokens.size(0);
84+
const auto topk = topk_indices.size(1);
85+
86+
auto flatten_indices = topk_indices.view({-1});
87+
// idx, sorted by (experts, tokens)
88+
auto sorted_incices = flatten_indices.argsort(/*stable=*/true);
89+
90+
// idx => token_indices, [n_permuted_tokens]
91+
auto token_indices = sorted_incices.div(topk, /*rounding_mode=*/"floor");
92+
auto permuted_tokens = tokens.index_select(
93+
/*dim=*/0, token_indices);
94+
95+
return {permuted_tokens, sorted_incices};
96+
}
97+
98+
// Reference inverse of permute_mask_ref: scatter each permuted row back to
// the flat (token, k) slot it came from, weight every expert copy by its
// routing probability, and reduce over the topk dimension.
//
// permuted_tokens: [n_permuted_tokens, dim]
// sorted_indices:  [n_permuted_tokens] original flat slot of each row
// probs:           [n_tokens, topk]
// Returns [n_tokens, dim].
torch::Tensor unpermute_mask_ref(
    const torch::Tensor& permuted_tokens,  // [n_permuted_tokens, dim]
    const torch::Tensor& sorted_indices,   // [n_permuted_tokens]
    const torch::Tensor& probs,            // [n_token, topk]
    int64_t n_tokens,
    int64_t topk) {
  // As produced by permute_mask_ref, sorted_indices is a permutation, so
  // the scatter below overwrites every zeroed row.
  auto restored = torch::zeros_like(permuted_tokens);
  restored.index_copy_(/*dim=*/0, sorted_indices, permuted_tokens);

  // [n_permuted_tokens, dim] => [n_tokens, topk, dim], token-major again.
  auto per_slot = restored.reshape({n_tokens, topk, -1});

  // Weight each expert copy and sum over topk:
  // [n_tokens, topk, dim] * [n_tokens, topk, 1] => [n_tokens, dim].
  return (per_slot * probs.unsqueeze(/*dim=*/-1)).sum(/*dim=*/1);
}
119+
69120
} // namespace
70121

71122
class PermuteTest
@@ -111,6 +162,36 @@ TEST_P(PermuteTest, Index) {
111162
torch::allclose(tokens, unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
112163
}
113164

165+
// End-to-end check of the mask-map kernels against the torch reference:
// the permutation must match exactly, and unpermuting with softmax probs
// must both match the reference and reconstruct the original tokens
// (the probs sum to 1 across topk).
TEST_P(PermuteTest, Mask) {
  const auto [dtype, n_tokens, dim, n_experts, topk] = GetParam();

  const auto options = torch::dtype(dtype).device(torch::kCUDA);

  // Random activations plus a random router to derive expert assignments.
  const auto tokens = torch::randn({n_tokens, dim}, options);
  const auto router_logits = torch::randn({n_tokens, n_experts}, options);

  auto [topk_weights, topk_indices] = router_logits.topk(topk, /*dim=*/-1);
  auto probs = topk_weights.softmax(/*dim=*/-1);

  // Kernel under test (takes int32 expert indices).
  auto [permuted_tokens, sorted_indices] = kernel::moe::permute_with_mask_map(
      tokens, topk_indices.to(torch::kInt32));

  // Torch reference.
  auto [ref_permuted_tokens, ref_sorted_indices] =
      permute_mask_ref(tokens, topk_indices);

  EXPECT_TRUE(torch::allclose(permuted_tokens, ref_permuted_tokens));

  auto unpermute_out = kernel::moe::unpermute_with_mask_map(
      permuted_tokens, sorted_indices, probs, n_tokens, topk);
  auto ref_unpermute_out = unpermute_mask_ref(
      ref_permuted_tokens, ref_sorted_indices, probs, n_tokens, topk);

  EXPECT_TRUE(torch::allclose(
      unpermute_out, ref_unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
  // Softmax weights sum to 1 per token, so unpermute reproduces the input.
  EXPECT_TRUE(
      torch::allclose(tokens, unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
}
194+
114195
INSTANTIATE_TEST_SUITE_P(
115196
Moe,
116197
PermuteTest,

0 commit comments

Comments
 (0)