diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 4d85c5dc083d1..8c8647b147369 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -45,6 +45,7 @@
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
+#include "ggml-cuda/topk-moe.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
@@ -2825,6 +2826,44 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     GGML_ASSERT(unary_ops.size() == num_unary);
 #endif
 
+    //TODO: remove special case once ggml_can_fuse can handle empty nodes
+    std::initializer_list<ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
+    std::initializer_list<ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+
+    if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
+
+        if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
+            return false;
+        }
+
+        for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
+        }
+        ggml_tensor * softmax = cgraph->nodes[node_idx];
+        ggml_tensor * weights = cgraph->nodes[node_idx+8];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
+    if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
+
+        if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
+            return false;
+        }
+
+        for (size_t i = 0; i < topk_moe_ops.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
+        }
+
+        ggml_tensor * softmax = cgraph->nodes[node_idx];
+        ggml_tensor * weights = cgraph->nodes[node_idx+4];
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2915,6 +2954,22 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
                 static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
                 if (!disable_fusion) {
+                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
+                        ggml_tensor * weights = cgraph->nodes[i+8];
+                        ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                        ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+                        i += 8;
+                        continue;
+                    }
+
+                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
+                        ggml_tensor * weights = cgraph->nodes[i+4];
+                        ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                        ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+                        i += 4;
+                        continue;
+                    }
+
                     if (node->op == GGML_OP_ADD) {
                         int n_fuse = 0;
                         ggml_op ops[8];
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
new file mode 100644
index 0000000000000..039f284719648
--- /dev/null
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -0,0 +1,259 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+#include "topk-moe.cuh"
+
+#include <initializer_list>
+
+/*
+    This kernel does the following:
+    1. softmax over the logits per token [n_experts, n_tokens]
+    2. argmax reduce over the top-k (n_expert_used) logits
+    3. write weights + ids to global memory
+    4. optionally normalize the weights
+
+    It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
+*/
+template <int n_experts, bool with_norm>
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+                                                                  float * weights,
+                                                                  int32_t * ids,
+                                                                  const int n_rows,
+                                                                  const int n_expert_used) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    logits += n_experts * row;
+    weights += n_expert_used * row;
+    ids += n_experts * row;
+
+    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+
+    float logits_r[experts_per_thread];
+
+#pragma unroll
+    for (int i = 0; i < n_experts; i += WARP_SIZE) {
+        const int expert = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+    }
+
+    float max_val = logits_r[0];
+
+#pragma unroll
+    for (int i = 1; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        max_val = max(val, max_val);
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float wt[experts_per_thread];
+    float tmp = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        wt[i] = expf(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = wt[i] * inv_sum;
+    }
+
+    //at this point, each thread holds a portion of softmax,
+    //we do the argmax reduce over n_expert_used, each time marking
+    //the expert weight as -inf to exclude from the next iteration
+
+    float wt_sum = 0.f;
+
+    extern __shared__ float data_topk_shared[];
+    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
+
+    for (int k = 0; k < n_expert_used; k++) {
+        float max_val = wt[0];
+        int max_expert = threadIdx.x;
+
+#pragma unroll
+        for (int i = 1; i < experts_per_thread; i++) {
+            const int expert = threadIdx.x + i * WARP_SIZE;
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+                max_val = wt[i];
+                max_expert = expert;
+            }
+        }
+
+#pragma unroll
+        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+            const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+            if (val > max_val || (val == max_val && expert < max_expert)) {
+                max_val = val;
+                max_expert = expert;
+            }
+        }
+
+        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;
+
+            wt_shared_ptr[k] = max_val;
+            ids[k] = max_expert;
+            if constexpr (with_norm) {
+                wt_sum += max_val;
+            }
+        }
+    }
+
+    if constexpr (with_norm) {
+        wt_sum = warp_reduce_sum(wt_sum);
+        const float inv_sum = 1.0f / wt_sum;
+
+        for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
+            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+        }
+    }
+
+    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
+        weights[i] = wt_shared_ptr[i];
+    }
+}
+
+template <bool with_norm>
+static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
+                                 const float * logits,
+                                 float * weights,
+                                 int32_t * ids,
+                                 const int n_rows,
+                                 const int n_expert,
+                                 const int n_expert_used) {
+    const int rows_per_block = 4;
+    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
+    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
+    cudaStream_t stream = ctx.stream();
+
+    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
+
+    switch (n_expert) {
+        case 1:
+            topk_moe_cuda<1, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 2:
+            topk_moe_cuda<2, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 4:
+            topk_moe_cuda<4, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 8:
+            topk_moe_cuda<8, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 16:
+            topk_moe_cuda<16, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 32:
+            topk_moe_cuda<32, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 64:
+            topk_moe_cuda<64, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 128:
+            topk_moe_cuda<128, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 256:
+            topk_moe_cuda<256, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 512:
+            topk_moe_cuda<512, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        default:
+            GGML_ASSERT(false && "fatal error");
+            break;
+    }
+}
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor * logits,
+                           ggml_tensor * weights,
+                           ggml_tensor * ids,
+                           const bool with_norm) {
+    GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    const int n_experts = logits->ne[0];
+    const int n_rows = logits->ne[1];
+
+    const float * logits_d = (const float *) logits->src[0]->data;
+    float * weights_d = (float *) weights->data;
+    int32_t * ids_d = (int32_t *) ids->data;
+
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
+
+    cudaStream_t stream = ctx.stream();
+
+    const int n_expert_used = weights->ne[1];
+
+    if (with_norm) {
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    } else {
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    }
+}
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
+    if (scale != 1.0f || max_bias != 0.0f) {
+        return false;
+    }
+
+    // don't fuse when masks or sinks are present
+    if (softmax->src[1] || softmax->src[2]) {
+        return false;
+    }
+
+    const int n_expert = softmax->ne[0];
+    // n_expert must be a power of 2
+    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
+        return false;
+    }
+
+    return true;
+}
+
+std::initializer_list<ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+    static std::initializer_list<ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                       GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                       GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
+
+    static std::initializer_list<ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                          GGML_OP_VIEW,     GGML_OP_GET_ROWS };
+
+    if (norm) {
+        return norm_ops;
+    }
+    return no_norm_ops;
+}
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
new file mode 100644
index 0000000000000..6613fb56507ea
--- /dev/null
+++ b/ggml/src/ggml-cuda/topk-moe.cuh
@@ -0,0 +1,14 @@
+#include "common.cuh"
+#include "ggml.h"
+
+#include <initializer_list>
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor * logits,
+                           ggml_tensor * weights,
+                           ggml_tensor * top_k,
+                           const bool with_norm);
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
+
+std::initializer_list<ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 9f2e417f1ff4b..b49a964b6343c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -929,6 +929,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
+
 
     if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
         weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
@@ -952,6 +953,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(weights, "ffn_moe_weights_scaled", il);
     }
 
+    //call early so that topk-moe can be used
+    ggml_build_forward_expand(gf, weights);
+
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
 
     if (weight_before_ffn) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 592631f3ed21a..f87eaa8a3b0c8 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4418,6 +4418,49 @@ struct test_argsort : public test_case {
     }
 };
 
+struct test_topk_moe: public test_case {
+    const std::array<int64_t, 4> ne;
+    const int n_expert_used;
+    const bool with_norm;
+    test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
+        : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
+        GGML_ASSERT(n_expert_used <= ne[0]);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR3(ne, n_expert_used, with_norm);
+    }
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "TOPK_MOE";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int n_expert = ne[0];
+        const int n_tokens = ne[1];
+
+        ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+        ggml_tensor * probs = ggml_soft_max(ctx, logits);
+        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+
+        ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+
+        if (with_norm) {
+            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
+
+            out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
+            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        }
+
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_SUM
 struct test_sum : public test_case {
     const ggml_type type;
@@ -6588,6 +6631,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
     test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
 
+    for (bool with_norm : {false, true}) {
+        test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
+        test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
+        test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
+    }
+
 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));
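
The fused op's contract follows the four steps listed in the kernel comment: softmax over the per-token logits, iterative argmax selection of the n_expert_used largest probabilities (ties resolved toward the lower expert index, matching the shuffle reduce in the kernel), writing the selected ids and weights, and optionally renormalizing the selected weights so they sum to 1. The standalone C++ sketch below is not part of the patch; it is a minimal per-token CPU reference of that contract, intended only as an illustration or a sanity-check aid, and the helper name topk_moe_ref and the sample values in main are invented for the example.

// Reference for the math of the fused top-k MoE op (illustration only, not part of the patch).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// For one token: softmax over n_experts logits, keep the n_expert_used largest
// probabilities (lower expert index wins ties), optionally renormalize them to sum to 1.
static void topk_moe_ref(const std::vector<float> & logits, int n_expert_used, bool with_norm,
                         std::vector<float> & weights, std::vector<int32_t> & ids) {
    const int n_experts = (int) logits.size();

    // 1. softmax over the logits
    const float max_val = *std::max_element(logits.begin(), logits.end());
    std::vector<float> p(n_experts);
    float sum = 0.0f;
    for (int e = 0; e < n_experts; e++) {
        p[e] = std::exp(logits[e] - max_val);
        sum += p[e];
    }
    for (float & v : p) {
        v /= sum;
    }

    // 2. top-k by probability; stable sort keeps the lower index first on ties
    std::vector<int32_t> order(n_experts);
    std::iota(order.begin(), order.end(), 0);
    std::stable_sort(order.begin(), order.end(), [&](int32_t a, int32_t b) { return p[a] > p[b]; });

    // 3. write the selected ids and weights
    ids.assign(order.begin(), order.begin() + n_expert_used);
    weights.resize(n_expert_used);
    for (int k = 0; k < n_expert_used; k++) {
        weights[k] = p[ids[k]];
    }

    // 4. optionally normalize the selected weights
    if (with_norm) {
        const float wsum = std::accumulate(weights.begin(), weights.end(), 0.0f);
        for (float & w : weights) {
            w /= wsum;
        }
    }
}

int main() {
    std::vector<float> weights;
    std::vector<int32_t> ids;
    topk_moe_ref({0.1f, 2.0f, -1.0f, 0.5f}, /*n_expert_used=*/2, /*with_norm=*/true, weights, ids);
    for (size_t k = 0; k < ids.size(); k++) {
        std::printf("expert %d weight %.4f\n", (int) ids[k], weights[k]); // expects experts 1 and 3
    }
    return 0;
}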