Commit 2f74c35

graph : make FA compatible with MLA + add initial Metal kernels (#12953)
* graph : make mla compatible with FA
* metal : add exp FA kernels for DeepSeek models
  ggml-ci
* llama : minor naming updates
  ggml-ci
* ggml : disable FA for DS head sizes
* tests : add FA tests for MLA shapes
  ggml-ci
1 parent 207c22e commit 2f74c35

File tree

8 files changed: +117, -26 lines


ggml/src/ggml-cuda/ggml-cuda.cu (+4)
@@ -3237,6 +3237,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
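In other words, the CUDA backend now declines FA for the DeepSeek MLA head size and ggml falls back to the regular attention path. A hypothetical condensed form of the head-size gate above (cuda_fa_head_size_supported is not a function in the tree; for GGML_OP_FLASH_ATTN_EXT, src[0] is Q, so ne[0] is the Q/K head size):

```cpp
#include "ggml.h"

// Hypothetical condensation of the checks above, for illustration only:
// head sizes 192 and 576 (DeepSeek MLA) are rejected, so the CUDA backend
// reports FLASH_ATTN_EXT as unsupported and the non-FA path is used instead.
static bool cuda_fa_head_size_supported(const struct ggml_tensor * op) {
    const int64_t hsk = op->src[0]->ne[0]; // src[0] is Q; ne[0] is the Q/K head size
    return hsk != 192 && hsk != 576;
}
```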

ggml/src/ggml-metal/ggml-metal.m (+77, -6)

Large diffs are not rendered by default.

ggml/src/ggml-metal/ggml-metal.metal (+17)
@@ -3546,6 +3546,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
 template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
 template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;

 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
@@ -3556,6 +3557,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
 template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
 template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
 #endif

 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3566,6 +3568,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
 template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
 template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;

 template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
@@ -3575,6 +3578,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
 template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
 template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;

 template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
@@ -3584,6 +3588,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
 template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
 template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;

 template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
@@ -3593,6 +3598,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
 template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
 template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;

 template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
@@ -3602,6 +3608,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
 template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
 template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;

 #undef FA_TYPES

@@ -4009,6 +4016,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;

+template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
+
 #undef FA_TYPES

 template<typename T>
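These are the first variants in this file where the K and V head sizes differ this much: with MLA absorption, K rows are 576 wide (the 512-dim compressed KV latent plus the 64-dim RoPE part of DeepSeek-V2/V3-style models) while V rows are the 512-dim latent only, hence the `hk576_hv512` suffix; the vec variants also use a different final template argument (2 instead of the 4 used for the h256 shapes). A hypothetical helper showing how such a suffix could be derived from the two head sizes (this is not the actual dispatch code in ggml-metal.m, whose diff is not rendered above):

```cpp
#include <string>

// Hypothetical illustration of the naming convention used by the template
// instantiations above: equal head sizes map to "h<N>", unequal ones
// (192/128 and the new DeepSeek MLA 576/512 case) to "hk<K>_hv<V>".
static std::string fa_kernel_suffix(int hsk, int hsv) {
    if (hsk == hsv) {
        return "h" + std::to_string(hsk);                            // e.g. "h128"
    }
    return "hk" + std::to_string(hsk) + "_hv" + std::to_string(hsv); // e.g. "hk576_hv512"
}
```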

ggml/src/ggml-vulkan/ggml-vulkan.cpp (+1)
@@ -9261,6 +9261,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             case 112:
             case 128:
             case 256:
+            case 575: // DeepSeek MLA
                 break;
             default:
                 return false;

src/llama-context.cpp (+3, -8)
@@ -484,7 +484,7 @@ ggml_tensor * llama_context::build_rope_shift(

     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;

     ggml_tensor * tmp;

@@ -504,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(

         tmp = ggml_rope_ext_inplace(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);

         tmp = ggml_cpy(ctx0, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
         tmp = ggml_rope_ext_inplace(ctx0, cur,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
     }

     return tmp;
@@ -2278,11 +2278,6 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }

-    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
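The renamed yarn_attn_factor keeps the same DeepSeek2-specific override as before: for that arch the YaRN attention factor is recomputed from freq_scale alone instead of taking cparams.yarn_attn_factor. A standalone numeric sketch of that expression (the freq_scale value is illustrative, not taken from the commit):

```cpp
#include <cmath>
#include <cstdio>

// Standalone sketch of the DeepSeek2 override used above:
//   yarn_attn_factor = 1 / (1 + 0.1 * ln(1 / freq_scale))
int main() {
    const float freq_scale = 0.25f; // assumed 4x context extension, illustrative only
    const float yarn_attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
    printf("yarn_attn_factor = %f\n", yarn_attn_factor); // ~0.878 for this freq_scale
    return 0;
}
```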

src/llama-graph.cpp (+8, -6)
@@ -1200,9 +1200,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;

-    // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
-    const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
-
     const auto n_tokens = q->ne[1];
     const auto n_head = q->ne[2];
     const auto n_kv = k->ne[1];
@@ -1231,7 +1228,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+        if (v_mla) {
+            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+        }
+
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1274,9 +1276,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             kqv = ggml_mul_mat(ctx0, v_mla, kqv);
         }

-        ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
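This is the core of the FA + MLA change: with the absorption optimization, ggml_flash_attn_ext returns its output in the compressed KV space (v_mla->ne[0] columns per head, 512 for DeepSeek-V2/V3), and v_mla then projects each head back to the full value dimension before the usual 2D reshape. A minimal, illustrative sketch of how that projection composes in ggml, with assumed DeepSeek-like dimensions (it only builds the ops, it does not evaluate them):

```cpp
#include "ggml.h"

int main() {
    // illustrative sketch only - dimensions assumed for a DeepSeek-like model:
    // kv_lora_rank = 512, value head size = 128, 16 heads, 4 tokens
    const int64_t n_lora = 512, n_embd_head_v = 128, n_head = 16, n_tokens = 4;

    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ true };
    struct ggml_context * ctx = ggml_init(params);

    // FA output under absorption: one n_lora-wide row per (head, token)
    struct ggml_tensor * fa_out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_head*n_tokens);
    // per-head decompression matrices (the v_mla tensor in build_attn_mha)
    struct ggml_tensor * v_mla  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_lora, n_embd_head_v, n_head);

    // same three calls as the new FA branch above
    struct ggml_tensor * cur = ggml_reshape_4d(ctx, fa_out, n_lora, 1, n_head, n_tokens);
    cur = ggml_mul_mat(ctx, v_mla, cur);                          // -> [128, 1, 16, 4]
    cur = ggml_reshape_2d(ctx, cur, cur->ne[0]*n_head, n_tokens); // -> [128*16, 4]

    ggml_free(ctx);
    return 0;
}
```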

src/llama-model.cpp (+3, -3)
@@ -10050,7 +10050,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
     const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
     const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k));
-    const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+    const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));

     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -10127,13 +10127,13 @@ struct llm_build_deepseek2 : public llm_graph_context {

                 q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(q_pe, "q_pe", il);

                 k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(k_pe, "k_pe", il);

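For a rough feel of the two scale factors defined above, here is an illustrative plug-in of the formulas; attn_factor, rope_yarn_log_mul and freq_scale are model hyperparameters at runtime, and the values below are assumptions for the example, not taken from any particular DeepSeek checkpoint:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // assumed values for illustration only
    const float attn_factor       = 1.0f;
    const float rope_yarn_log_mul = 0.1f;
    const float freq_scale        = 0.025f; // 40x context extension
    const int   n_embd_head_k     = 192;    // DeepSeek-V2/V3: 128 (nope) + 64 (rope) dims

    // same expressions as in llm_build_deepseek2 above
    const float mscale   = attn_factor * (1.0f + rope_yarn_log_mul * logf(1.0f / freq_scale));
    const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k));

    printf("mscale = %.4f, kq_scale = %.6f\n", mscale, kq_scale);
    return 0;
}
```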

tests/test-backend-ops.cpp (+4, -3)
@@ -4428,10 +4428,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());

-    for (int hsk : { 64, 80, 128, 192, 256, }) {
-    for (int hsv : { 64, 80, 128, 192, 256, }) {
-        if (hsk != 192 && hsk != hsv) continue;
+    for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
+    for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
+        if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
         if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+        if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA

     for (bool mask : { true, false } ) {
     for (float max_bias : { 0.0f, 8.0f }) {
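With these shapes in the loop, the new DeepSeek MLA case (HSK 576 / HSV 512) is exercised by the standard backend-ops tests. Assuming the usual CMake build layout, something like `./build/bin/test-backend-ops test -o FLASH_ATTN_EXT` should run just the FLASH_ATTN_EXT cases against the available backends (check the tool's usage output for the exact filter flags in your build).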

0 commit comments