#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "topk-moe.cuh"

#include <cstring>
#include <initializer_list>

/*
    This kernel does the following:
    1. softmax over the logits per token [n_experts, n_tokens]
    2. argmax reduce over the top-k (n_expert_used) logits
    3. write the weights + ids to global memory
    4. optionally normalize the weights

    It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
*/
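// one warp processes one row (one token); the launcher below packs 4 rows (warps)
// into each block, matching __launch_bounds__(4 * WARP_SIZE, 1)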
template <size_t n_experts, bool with_norm>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float * weights,
                                                                   int32_t * ids,
                                                                   const int n_rows,
                                                                   const int n_expert_used) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= n_rows) {
        return;
    }

    logits += n_experts * row;
    weights += n_expert_used * row;
    ids += n_experts * row;

    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

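    // each lane caches its slice of the row's logits in registers:
    // lane t owns experts t, t + WARP_SIZE, t + 2*WARP_SIZE, ... (experts_per_thread of them)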
    float logits_r[experts_per_thread];

#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert = i + threadIdx.x;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
    }

    float max_val = logits_r[0];

#pragma unroll
    for (int i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        max_val = max(val, max_val);
    }

    max_val = warp_reduce_max(max_val);

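    // numerically stable softmax: subtract the warp-wide max before exponentiating,
    // then divide by the warp-wide sum of the exponentials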
    float wt[experts_per_thread];
    float tmp = 0.f;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i] = expf(val - max_val);
        tmp += wt[i];
    }

    tmp = warp_reduce_sum(tmp);

    const float inv_sum = 1.0f / tmp;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
    }

    // at this point each thread holds a portion of the softmax;
    // we do an argmax reduce over n_expert_used iterations, each time marking
    // the selected expert's weight as -inf to exclude it from the next iteration

    float wt_sum = 0.f;

    extern __shared__ float data_topk_shared[];
    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;

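    // the selected weights are staged in shared memory so that every lane of the warp
    // can later normalize and write them out, not just the lane that owned the winner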
    for (int k = 0; k < n_expert_used; k++) {
        float max_val = wt[0];
        int max_expert = threadIdx.x;

#pragma unroll
        for (int i = 1; i < experts_per_thread; i++) {
            const int expert = threadIdx.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val = wt[i];
                max_expert = expert;
            }
        }

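        // butterfly reduction: after log2(WARP_SIZE) exchanges every lane holds the
        // warp-wide maximum together with its expert index (ties go to the lower index)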
#pragma unroll
        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
            const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val = val;
                max_expert = expert;
            }
        }

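        // only the lane that owns the winning expert records it and knocks it out of
        // the running so that the next iteration selects the runner-up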
        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            wt_shared_ptr[k] = max_val;
            ids[k] = max_expert;
            if constexpr (with_norm) {
                wt_sum += max_val;
            }
        }
    }

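    // if requested, rescale the selected weights so they sum to 1; the per-lane partial
    // sums are combined with a warp reduction so that every lane sees the full sum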
    if constexpr (with_norm) {
        wt_sum = warp_reduce_sum(wt_sum);
        const float inv_sum = 1.0f / wt_sum;

        for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
        }
    }

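    // write the (possibly normalized) weights back to global memory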
    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
        weights[i] = wt_shared_ptr[i];
    }
}

template <bool with_norm>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float * logits,
                                 float * weights,
                                 int32_t * ids,
                                 const int n_rows,
                                 const int n_expert,
                                 const int n_expert_used) {
    const int rows_per_block = 4;
    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

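    // one n_expert_used-sized slice of shared memory per row (warp) in the block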
    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);

    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
            topk_moe_cuda<2, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
            topk_moe_cuda<4, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
            topk_moe_cuda<8, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
            topk_moe_cuda<16, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
            topk_moe_cuda<32, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
            topk_moe_cuda<64, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
            topk_moe_cuda<128, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
            topk_moe_cuda<256, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
            topk_moe_cuda<512, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }
}

void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor * logits,
                           ggml_tensor * weights,
                           ggml_tensor * ids,
                           const bool with_norm) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts = logits->ne[0];
    const int n_rows = logits->ne[1];

    const float * logits_d = (const float *) logits->src[0]->data;
    float * weights_d = (float *) weights->data;
    int32_t * ids_d = (int32_t *) ids->data;

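    // ids is expected to keep the row stride of the full argsort output (n_experts);
    // the kernel only writes the first n_expert_used entries of each row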
    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    cudaStream_t stream = ctx.stream();

    const int n_expert_used = weights->ne[1];

    if (with_norm) {
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    } else {
        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    }
}

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
    float scale = 1.0f;
    float max_bias = 0.0f;

    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (scale != 1.0f || max_bias != 0.0f) {
        return false;
    }

    // don't fuse when masks or sinks are present
    if (softmax->src[1] || softmax->src[2]) {
        return false;
    }

    const int n_expert = softmax->ne[0];
    // n_expert must be a power of 2 and at most 512 (the largest kernel instantiated above)
    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
        return false;
    }

    return true;
}

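// op sequences that the fused kernel replaces: soft_max over the router logits,
// argsort + view to select the top-k experts, get_rows to gather their probabilities,
// and (when normalizing) sum_rows + div to rescale the selected weights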
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };

    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                               GGML_OP_VIEW,     GGML_OP_GET_ROWS };

    if (norm) {
        return norm_ops;
    }
    return no_norm_ops;
}