
Commit be08d72

am17an authored and struct committed
CUDA: add a fused top-K MoE kernel (ggml-org#16130)
* CUDA: add a fused top-K MoE kernel

  This kernel does the following:
  1. softmax over the logits per token [n_experts, n_tokens]
  2. argmax reduce over the top-k (n_experts_used) logits
  3. write weights + ids to global memory

  It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models.

* Refactor into ggml_cuda_should_use_topk_moe
* Review: use a better coalescing pattern, use WARP_SIZE, store logits into registers beforehand
* Review: format + micro-optimizations
* Fix bug: fix tie breakers
* Add optional norm + clean up code
* Use smem for the final write
* Add bounds check
* Use a better memory pattern for writeback
1 parent 9fdf159 commit be08d72
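
For orientation, here is a minimal, unfused host-side C++ sketch of the per-token computation the fused kernel implements: a numerically stable softmax over the expert logits, n_expert_used rounds of argmax with the winner masked out (ties going to the lower expert index), and an optional renormalization of the selected weights. The function and variable names below are illustrative only and are not part of the commit.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference (unfused) top-k MoE routing for a single token; illustrative only.
// logits: n_experts values; writes n_expert_used weights and expert ids.
static void topk_moe_reference(const float * logits, int n_experts, int n_expert_used,
                               bool with_norm, float * weights, int32_t * ids) {
    // 1. softmax over the logits (subtract the max for numerical stability)
    float max_val = logits[0];
    for (int e = 1; e < n_experts; e++) {
        max_val = std::max(max_val, logits[e]);
    }
    std::vector<float> wt(n_experts);
    float sum = 0.0f;
    for (int e = 0; e < n_experts; e++) {
        wt[e] = std::exp(logits[e] - max_val);
        sum  += wt[e];
    }
    for (int e = 0; e < n_experts; e++) {
        wt[e] /= sum;
    }

    // 2. top-k by iterative argmax: pick the largest remaining weight n_expert_used
    //    times, breaking ties toward the lower expert index, then mask it with -inf
    float wt_sum = 0.0f;
    for (int k = 0; k < n_expert_used; k++) {
        int best = 0;
        for (int e = 1; e < n_experts; e++) {
            if (wt[e] > wt[best]) {
                best = e;
            }
        }
        weights[k] = wt[best];
        ids[k]     = best;
        wt_sum    += wt[best];
        wt[best]   = -INFINITY;
    }

    // 3. optionally renormalize the selected weights so they sum to 1
    if (with_norm) {
        for (int k = 0; k < n_expert_used; k++) {
            weights[k] /= wt_sum;
        }
    }
}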

File tree

5 files changed: +381 -0 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 55 additions & 0 deletions
@@ -45,6 +45,7 @@
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
+#include "ggml-cuda/topk-moe.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
@@ -2825,6 +2826,44 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     GGML_ASSERT(unary_ops.size() == num_unary);
 #endif

+    //TODO: remove special case once ggml_can_fuse can handle empty nodes
+    std::initializer_list<enum ggml_op> topk_moe_ops           = ggml_cuda_topk_moe_ops(false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+
+    if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
+
+        if (node_idx + topk_moe_ops_with_norm.size() > (size_t) cgraph->n_nodes) {
+            return false;
+        }
+
+        for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
+        }
+
+        ggml_tensor * softmax = cgraph->nodes[node_idx];
+        ggml_tensor * weights = cgraph->nodes[node_idx+8];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
+    if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
+
+        if (node_idx + topk_moe_ops.size() > (size_t) cgraph->n_nodes) {
+            return false;
+        }
+
+        for (size_t i = 0; i < topk_moe_ops.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
+        }
+
+        ggml_tensor * softmax = cgraph->nodes[node_idx];
+        ggml_tensor * weights = cgraph->nodes[node_idx+4];
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2915,6 +2954,22 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
         if (!disable_fusion) {

+            if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
+                ggml_tensor * weights          = cgraph->nodes[i+8];
+                ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+                i += 8;
+                continue;
+            }
+
+            if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
+                ggml_tensor * weights          = cgraph->nodes[i+4];
+                ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+                i += 4;
+                continue;
+            }
+
             if (node->op == GGML_OP_ADD) {
                 int n_fuse = 0;
                 ggml_op ops[8];
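
The hard-coded node offsets in the two fusion paths above (+8, +4, and +3) follow from the op sequences returned by ggml_cuda_topk_moe_ops, defined in topk-moe.cu below; my reading of the mapping, written out as comments for reference:

// With normalization, the matcher expects 9 consecutive nodes starting at node_idx:
//   [0] SOFT_MAX  [1] RESHAPE  [2] ARGSORT  [3] VIEW (selected experts)
//   [4] GET_ROWS  [5] RESHAPE  [6] SUM_ROWS [7] DIV  [8] RESHAPE (final weights)
//   -> weights read from node_idx + 8, ids from node_idx + 3, and i advances by 8 after fusing.
//
// Without normalization, it expects 5 consecutive nodes:
//   [0] SOFT_MAX  [1] RESHAPE  [2] ARGSORT  [3] VIEW (selected experts)  [4] GET_ROWS (weights)
//   -> weights read from node_idx + 4, ids from node_idx + 3, and i advances by 4 after fusing.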

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "topk-moe.cuh"

#include <initializer_list>

/*
    This kernel does the following:
    1. softmax over the logits per token [n_experts, n_tokens]
    2. argmax reduce over the top-k (n_experts_used) logits
    3. write weights + ids to global memory
    4. optionally normalize the weights

    It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
template <size_t n_experts, bool with_norm>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float *       weights,
                                                                  int32_t *     ids,
                                                                  const int     n_rows,
                                                                  const int     n_expert_used) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= n_rows) {
        return;
    }

    logits  += n_experts * row;
    weights += n_expert_used * row;
    ids     += n_experts * row;

    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

    float logits_r[experts_per_thread];

#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert        = i + threadIdx.x;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
    }

    float max_val = logits_r[0];

#pragma unroll
    for (int i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        max_val         = max(val, max_val);
    }

    max_val = warp_reduce_max(max_val);

    float wt[experts_per_thread];
    float tmp = 0.f;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i]           = expf(val - max_val);
        tmp            += wt[i];
    }

    tmp = warp_reduce_sum(tmp);

    const float inv_sum = 1.0f / tmp;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
    }

    //at this point, each thread holds a portion of softmax,
    //we do the argmax reduce over n_expert_used, each time marking
    //the expert weight as -inf to exclude from the next iteration

    float wt_sum = 0.f;

    extern __shared__ float data_topk_shared[];
    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;

    for (int k = 0; k < n_expert_used; k++) {
        float max_val    = wt[0];
        int   max_expert = threadIdx.x;

#pragma unroll
        for (int i = 1; i < experts_per_thread; i++) {
            const int expert = threadIdx.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val    = wt[i];
                max_expert = expert;
            }
        }

#pragma unroll
        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val    = val;
                max_expert = expert;
            }
        }

        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            wt_shared_ptr[k] = max_val;
            ids[k]           = max_expert;
            if constexpr (with_norm) {
                wt_sum += max_val;
            }
        }
    }

    if constexpr (with_norm) {
        wt_sum              = warp_reduce_sum(wt_sum);
        const float inv_sum = 1.0f / wt_sum;

        for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
        }
    }

    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
        weights[i] = wt_shared_ptr[i];
    }
}

template <bool with_norm>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float *               logits,
                                 float *                     weights,
                                 int32_t *                   ids,
                                 const int                   n_rows,
                                 const int                   n_expert,
                                 const int                   n_expert_used) {
    const int    rows_per_block = 4;
    dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3         block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);

    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
            topk_moe_cuda<2, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
            topk_moe_cuda<4, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
            topk_moe_cuda<8, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
            topk_moe_cuda<16, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
            topk_moe_cuda<32, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
            topk_moe_cuda<64, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
            topk_moe_cuda<128, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
            topk_moe_cuda<256, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
            topk_moe_cuda<512, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }
}

void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor *         logits,
                           ggml_tensor *               weights,
                           ggml_tensor *               ids,
                           const bool                  with_norm) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts = logits->ne[0];
    const int n_rows    = logits->ne[1];

    const float * logits_d  = (const float *) logits->src[0]->data;
    float *       weights_d = (float *) weights->data;
    int32_t *     ids_d     = (int32_t *) ids->data;

    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    cudaStream_t stream = ctx.stream();

    const int n_expert_used = weights->ne[1];

    if (with_norm) {
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    } else {
        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    }
}

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
    float scale    = 1.0f;
    float max_bias = 0.0f;

    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (scale != 1.0f || max_bias != 0.0f) {
        return false;
    }

    // don't fuse when masks or sinks are present
    if (softmax->src[1] || softmax->src[2]) {
        return false;
    }

    const int n_expert = softmax->ne[0];
    // n_expert must be a power of 2
    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
        return false;
    }

    return true;
}

std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };

    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                               GGML_OP_VIEW,     GGML_OP_GET_ROWS };

    if (norm) {
        return norm_ops;
    }
    return no_norm_ops;
}
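
As a concrete illustration of the launch configuration in launch_topk_moe_cuda above (the workload numbers below are chosen for the example and are not taken from the commit): each block runs one warp per token and four tokens per block, with just enough shared memory to stage the selected weights.

// Illustrative launch geometry for an example workload of 1024 tokens with 8 experts used per token.
constexpr int warp_size      = 32;   // assumed NVIDIA warp size (WARP_SIZE in the ggml CUDA backend)
constexpr int rows_per_block = 4;    // as in launch_topk_moe_cuda above
constexpr int n_rows         = 1024; // example token count
constexpr int n_expert_used  = 8;    // example top-k

constexpr int n_blocks      = (n_rows + rows_per_block - 1) / rows_per_block;          // 256 blocks
constexpr int n_threads     = warp_size * rows_per_block;                               // 128 threads per block
constexpr int nbytes_shared = n_expert_used * rows_per_block * (int) sizeof(float);     // 128 bytes per block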

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
#include "common.cuh"
#include "ggml.h"

#include <initializer_list>

void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor *         logits,
                           ggml_tensor *               weights,
                           ggml_tensor *               top_k,
                           const bool                  with_norm);

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);

std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);

src/llama-graph.cpp

Lines changed: 4 additions & 0 deletions
@@ -932,6 +932,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);

+
     if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
         weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
@@ -955,6 +956,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(weights, "ffn_moe_weights_scaled", il);
     }

+    //call early so that topk-moe can be used
+    ggml_build_forward_expand(gf, weights);
+
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

     if (weight_before_ffn) {
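
For context, my reading of why the expand is placed here: ggml_cuda_can_fuse matches its op list against consecutive cgraph nodes, so the router chain built just above (soft_max -> top_k -> get_rows, plus the optional normalization) has to land back-to-back in the graph before the expert mat-muls are added. A minimal sketch of that pattern using the standard ggml API, with illustrative variable names:

// Router pattern that the CUDA backend can fuse (illustrative sketch, not the commit's code).
ggml_tensor * probs    = ggml_soft_max(ctx0, logits);                    // [n_expert, n_tokens]
ggml_tensor * selected = ggml_top_k(ctx0, probs, n_expert_used);         // argsort + view under the hood
ggml_tensor * weights  = ggml_get_rows(ctx0,
        ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected);  // [1, n_expert_used, n_tokens]

// Expanding weights now keeps the soft_max/reshape/argsort/view/get_rows nodes
// consecutive in the cgraph, which is what the fusion matcher checks for.
ggml_build_forward_expand(gf, weights);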
