Skip to content

Commit 3f81b4e

Browse files
authored
vulkan: support GET_ROWS for k-quants (#16235)
The dequantize functions are copy/pasted from mul_mm_funcs.comp with very few changes - add a_offset and divide iqs by 2. It's probably possible to call these functions from mul_mm_funcs and avoid the duplication, but I didn't go that far in this change.
1 parent ace6a54 commit 3f81b4e

File tree

4 files changed

+162
-8
lines changed

4 files changed

+162
-8
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3256,6 +3256,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
32563256
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32573257
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32583258
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3259+
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3260+
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3261+
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3262+
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3263+
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32593264
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32603265
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32613266
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -3275,6 +3280,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
32753280
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32763281
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32773282
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3283+
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3284+
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3285+
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3286+
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3287+
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32783288
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32793289
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
32803290
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -12613,6 +12623,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1261312623
case GGML_TYPE_Q5_0:
1261412624
case GGML_TYPE_Q5_1:
1261512625
case GGML_TYPE_Q8_0:
12626+
case GGML_TYPE_Q2_K:
12627+
case GGML_TYPE_Q3_K:
12628+
case GGML_TYPE_Q4_K:
12629+
case GGML_TYPE_Q5_K:
12630+
case GGML_TYPE_Q6_K:
1261612631
case GGML_TYPE_IQ1_S:
1261712632
case GGML_TYPE_IQ1_M:
1261812633
case GGML_TYPE_IQ2_XXS:

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,139 @@ vec2 get_dm(uint ib, uint a_offset) {
478478
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
479479
}
480480
#endif
481+
482+
#if defined(DATA_A_Q2_K)
483+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
484+
iqs /= 2;
485+
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
486+
const uint scalesi = iqs / 8; // 0..15
487+
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
488+
489+
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
490+
const uint scales = data_a[a_offset + ib].scales[scalesi];
491+
const vec2 d = vec2(data_a[a_offset + ib].d);
492+
493+
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
494+
}
495+
vec2 get_dm(uint ib, uint a_offset) {
496+
return vec2(1, 0);
497+
}
498+
#endif
499+
500+
#if defined(DATA_A_Q3_K)
501+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
502+
iqs /= 2;
503+
const uint n = iqs / 64; // 0,1
504+
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
505+
const uint hmi = (iqs % 16) * 2; // 0,2,4..30
506+
const uint j = (iqs % 64) / 4; // 0..3
507+
const uint is = iqs / 8; // 0..15
508+
const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
509+
const uint qsshift = halfsplit * 2; // 0,2,4,6
510+
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
511+
512+
const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
513+
| (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
514+
const float dl = float(data_a[a_offset + ib].d) * float(us - 32);
515+
516+
return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
517+
dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
518+
}
519+
vec2 get_dm(uint ib, uint a_offset) {
520+
return vec2(1, 0);
521+
}
522+
#endif
523+
524+
#if defined(DATA_A_Q4_K)
525+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
526+
iqs /= 2;
527+
const uint n = iqs / 32; // 0,1,2,3
528+
const uint b = (iqs % 32) / 16; // 0,1
529+
const uint is = 2 * n + b; // 0..7
530+
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
531+
532+
const vec2 loadd = vec2(data_a[a_offset + ib].d);
533+
534+
const uint scidx0 = (is < 4) ? is : (is + 4);
535+
const uint scidx1 = (is < 4) ? is : (is - 4);
536+
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
537+
const uint scidxshift1 = (is < 4) ? 0 : 2;
538+
const uint mbidx0 = is + 4;
539+
const uint mbidx1 = (is < 4) ? is + 4 : is;
540+
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
541+
const uint mbidxshift0 = (is < 4) ? 0 : 4;
542+
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
543+
const uint mbidxshift1 = (is < 4) ? 0 : 2;
544+
545+
const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
546+
const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
547+
548+
const float d = loadd.x * sc;
549+
const float m = -loadd.y * mbyte;
550+
551+
return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF), m),
552+
fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
553+
}
554+
vec2 get_dm(uint ib, uint a_offset) {
555+
return vec2(1, 0);
556+
}
557+
#endif
558+
559+
#if defined(DATA_A_Q5_K)
560+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
561+
iqs /= 2;
562+
const uint n = iqs / 32; // 0,1,2,3
563+
const uint b = (iqs % 32) / 16; // 0,1
564+
const uint is = 2 * n + b; // 0..7
565+
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
566+
const uint qhi = (iqs % 16) * 2; // 0,2,4..30
567+
568+
const uint8_t hm = uint8_t(1 << (iqs / 16));
569+
570+
const vec2 loadd = vec2(data_a[a_offset + ib].d);
571+
572+
const uint scidx0 = (is < 4) ? is : (is + 4);
573+
const uint scidx1 = (is < 4) ? is : (is - 4);
574+
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
575+
const uint scidxshift1 = (is < 4) ? 0 : 2;
576+
const uint mbidx0 = is + 4;
577+
const uint mbidx1 = (is < 4) ? is + 4 : is;
578+
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
579+
const uint mbidxshift0 = (is < 4) ? 0 : 4;
580+
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
581+
const uint mbidxshift1 = (is < 4) ? 0 : 2;
582+
583+
const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
584+
const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
585+
586+
const float d = loadd.x * sc;
587+
const float m = -loadd.y * mbyte;
588+
589+
return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
590+
fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
591+
}
592+
vec2 get_dm(uint ib, uint a_offset) {
593+
return vec2(1, 0);
594+
}
595+
#endif
596+
597+
#if defined(DATA_A_Q6_K)
598+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
599+
iqs /= 2;
600+
const uint n = iqs / 64; // 0,1
601+
const uint b = (iqs % 64) / 32; // 0,1
602+
const uint is_b = (iqs % 16) / 8; // 0,1
603+
const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
604+
const uint is = 8 * n + qhshift + is_b; // 0..15
605+
const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
606+
const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
607+
608+
const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]);
609+
610+
return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
611+
dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
612+
}
613+
vec2 get_dm(uint ib, uint a_offset) {
614+
return vec2(1, 0);
615+
}
616+
#endif

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ struct block_q2_K_packed32
245245

246246
#if defined(DATA_A_Q2_K)
247247
#define QUANT_K QUANT_K_Q2_K
248+
#define QUANT_R 1
248249
#define A_TYPE block_q2_K
249250
#define A_TYPE_PACKED16 block_q2_K_packed16
250251
#define A_TYPE_PACKED32 block_q2_K_packed32
@@ -270,6 +271,7 @@ struct block_q3_K_packed16
270271

271272
#if defined(DATA_A_Q3_K)
272273
#define QUANT_K QUANT_K_Q3_K
274+
#define QUANT_R 1
273275
#define A_TYPE block_q3_K
274276
#define A_TYPE_PACKED16 block_q3_K_packed16
275277
#endif
@@ -304,6 +306,7 @@ struct block_q4_K_packed128
304306

305307
#if defined(DATA_A_Q4_K)
306308
#define QUANT_K QUANT_K_Q4_K
309+
#define QUANT_R 1
307310
#define A_TYPE block_q4_K
308311
#define A_TYPE_PACKED16 block_q4_K_packed16
309312
#define A_TYPE_PACKED32 block_q4_K_packed32
@@ -334,6 +337,7 @@ struct block_q5_K_packed128
334337

335338
#if defined(DATA_A_Q5_K)
336339
#define QUANT_K QUANT_K_Q5_K
340+
#define QUANT_R 1
337341
#define A_TYPE block_q5_K
338342
#define A_TYPE_PACKED16 block_q5_K_packed16
339343
#endif
@@ -358,6 +362,7 @@ struct block_q6_K_packed16
358362

359363
#if defined(DATA_A_Q6_K)
360364
#define QUANT_K QUANT_K_Q6_K
365+
#define QUANT_R 1
361366
#define A_TYPE block_q6_K
362367
#define A_TYPE_PACKED16 block_q6_K_packed16
363368
#endif

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -589,16 +589,14 @@ void process_shaders() {
589589
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
590590
}
591591

592-
if (!string_ends_with(tname, "_k")) {
593-
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
592+
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
594593

595-
if (tname == "f16") {
596-
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
597-
} else {
598-
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
599-
}
600-
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
594+
if (tname == "f16") {
595+
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
596+
} else {
597+
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
601598
}
599+
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
602600
}
603601

604602
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});

0 commit comments

Comments
 (0)