Commit 613b6c3

Refactor into ggml_cuda_should_use_topk_moe

1 parent 2141b8b

4 files changed: +45 -40 lines changed


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 25 deletions
@@ -2835,29 +2835,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
 
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-
-        float scale    = 1.0f;
-        float max_bias = 0.0f;
-
-        memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
-        memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
-
-        if (scale != 1.0f || max_bias != 0.0f) {
-            return false;
-        }
-
-        // don't fuse when masks or sinks are present
-        if (softmax->src[1] || softmax->src[2]) {
-            return false;
-        }
-
-        const int n_expert = softmax->ne[0];
-        // n_expert must be a power of 2
-        if (n_expert & (n_expert - 1) != 0 || n_expert > 512) {
-            return false;
+        if (ggml_cuda_should_use_topk_moe(softmax)) {
+            return true;
         }
-
-        return true;
     }
 
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
@@ -2927,8 +2907,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return true;
     }
 
-
-
     return false;
 }
 
@@ -3010,7 +2988,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
                 continue;
             }
-
         }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
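
The scale and max_bias checks that the first hunk moves out of ggml_cuda_can_fuse read their values from softmax->op_params, GGML's per-node parameter storage: ggml_soft_max_ext packs the two floats at the front of that array, and consumers copy them back out with memcpy because op_params is declared as an int32_t array. A minimal sketch of that read pattern, assuming the ggml headers are on the include path; the helper name read_soft_max_params is illustrative and not part of this commit:

#include <string.h>
#include "ggml.h"

// Illustrative only: soft_max stores scale at op_params[0] and max_bias at
// op_params[1]; memcpy copies the float bytes out of the int32_t-typed storage
// instead of dereferencing a reinterpreted float pointer.
static void read_soft_max_params(const ggml_tensor * softmax, float & scale, float & max_bias) {
    scale    = 1.0f;
    max_bias = 0.0f;
    memcpy(&scale,    (const float *) softmax->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
}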

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 39 additions & 13 deletions
@@ -1,3 +1,4 @@
+#include "ggml.h"
 #include "topk-moe.cuh"
 
 /*
@@ -10,10 +11,10 @@
 */
 template <size_t n_experts>
 __global__ void topk_moe_cuda(const float * logits,
-                              float *       weights,
-                              int32_t *     ids,
-                              const int     n_rows,
-                              const int     n_expert_used) {
+                              float *       weights,
+                              int32_t *     ids,
+                              const int     n_rows,
+                              const int     n_expert_used) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -94,12 +95,12 @@ __global__ void topk_moe_cuda(const float * logits,
 }
 
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
-                                 const float * logits,
-                                 float *       weights,
-                                 int32_t *     ids,
-                                 const int     n_rows,
-                                 const int     n_expert,
-                                 const int     n_expert_used) {
+                                 const float * logits,
+                                 float *       weights,
+                                 int32_t *     ids,
+                                 const int     n_rows,
+                                 const int     n_expert,
+                                 const int     n_expert_used) {
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(32, rows_per_block, 1);
@@ -143,9 +144,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 }
 
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
-                           ggml_tensor * logits,
-                           ggml_tensor * weights,
-                           ggml_tensor * ids) {
+                           const ggml_tensor * logits,
+                           ggml_tensor *       weights,
+                           ggml_tensor *       ids) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -163,3 +164,28 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
 
     launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
 }
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax) {
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+    if (scale != 1.0f || max_bias != 0.0f) {
+        return false;
+    }
+
+    // don't fuse when masks or sinks are present
+    if (softmax->src[1] || softmax->src[2]) {
+        return false;
+    }
+
+    const int n_expert = softmax->ne[0];
+    // n_expert must be a power of 2
+    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
+        return false;
+    }
+
+    return true;
+}
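
One detail worth noting beyond the code motion: the helper parenthesizes the power-of-two test. The inline version removed from ggml-cuda.cu reads n_expert & (n_expert - 1) != 0, and because != binds tighter than & in C++, that groups as n_expert & ((n_expert - 1) != 0), which only inspects the lowest bit of n_expert, whereas the (n_expert & (n_expert - 1)) != 0 form above is the usual power-of-two check. A standalone sketch, not part of the commit, contrasting the two groupings on a few expert counts:

#include <stdio.h>

// Illustrative only: rejected_old spells out how the un-parenthesized expression
// groups, which rejects odd n_expert but accepts any even value; rejected_new
// matches the check in ggml_cuda_should_use_topk_moe above.
static bool rejected_old(int n_expert) { return (n_expert & ((n_expert - 1) != 0)) || n_expert > 512; }
static bool rejected_new(int n_expert) { return ((n_expert & (n_expert - 1)) != 0) || n_expert > 512; }

int main(void) {
    const int counts[] = { 8, 60, 96, 128 };
    for (int n_expert : counts) {
        printf("n_expert=%3d  old grouping rejects: %d  new check rejects: %d\n",
               n_expert, (int) rejected_old(n_expert), (int) rejected_new(n_expert));
    }
    return 0;
}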

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
 #include "common.cuh"
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, ggml_tensor * logits, ggml_tensor * weights, ggml_tensor * top_k);
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const ggml_tensor * logits, ggml_tensor * weights, ggml_tensor * top_k);
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax);
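
Taken together, the two declarations give the backend a gate-then-dispatch pattern: probe a soft_max node with ggml_cuda_should_use_topk_moe and only then run the fused kernel via ggml_cuda_op_topk_moe. A rough usage sketch, assuming the ggml CUDA backend sources are on the include path; the wrapper name try_fused_topk_moe and the assumption that the soft_max node's first source holds the raw router logits are mine, not part of this commit:

#include "ggml.h"
#include "topk-moe.cuh"

// Illustrative wrapper, not upstream code: check eligibility first, then launch
// the fused softmax + top-k path; returning false lets the caller fall back to
// the unfused soft_max / argsort graph.
static bool try_fused_topk_moe(ggml_backend_cuda_context & ctx,
                               ggml_tensor * softmax,
                               ggml_tensor * weights,
                               ggml_tensor * ids) {
    if (!ggml_cuda_should_use_topk_moe(softmax)) {
        return false;
    }
    // assumption: src[0] of the soft_max node is the pre-softmax logits tensor
    ggml_cuda_op_topk_moe(ctx, softmax->src[0], weights, ids);
    return true;
}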

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -4419,7 +4419,7 @@ struct test_topk_moe: public test_case {
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
-        return "TOPK_GATED_MOE";
+        return "TOPK_MOE";
     }
 
     bool run_whole_graph() override { return true; }
