@@ -3546,6 +3546,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_at
3546
3546
// Explicit instantiations of kernel_flash_attn_ext using half4x4 / dequantize_f16 in both type slots.
// The two trailing integers match the sizes in each host_name (_hN -> N, N; _hkA_hvB -> A, B).
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_f16_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 192 , 192 >;
3547
3547
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 192 , 128 >;
3548
3548
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 , 256 >;
3549
+ template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 576 , 512 >;
3549
3550
3550
3551
// bfloat4x4 / dequantize_bf16 instantiation of kernel_flash_attn_ext, compiled only when GGML_METAL_USE_BF16 is defined.
// The host_name literal carries no padding spaces, matching the "kernel_flash_attn_ext_bf16_h128" form of sibling kernels.
#if defined(GGML_METAL_USE_BF16)
3551
3552
template [[host_name("kernel_flash_attn_ext_bf16_h64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 , 64 >;
@@ -3556,6 +3557,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_at
3556
3557
// Explicit instantiations of kernel_flash_attn_ext using bfloat4x4 / dequantize_bf16 in both type slots.
// The two trailing integers match the sizes in each host_name (_hN -> N, N; _hkA_hvB -> A, B).
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_bf16_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 192 >;
3557
3558
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 128 >;
3558
3559
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 , 256 >;
3560
+ template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 576 , 512 >;
3559
3561
#endif
3560
3562
3561
3563
// block_q4_0 / dequantize_q4_0 instantiation of kernel_flash_attn_ext with trailing sizes 64, 64.
// The host_name literal carries no padding spaces, matching the "kernel_flash_attn_ext_q4_0_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q4_0_h64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 64 , 64 >;
@@ -3566,6 +3568,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_at
3566
3568
// Explicit instantiations of kernel_flash_attn_ext using block_q4_0 / dequantize_q4_0 in both type slots.
// The two trailing integers match the sizes in each host_name (_hN -> N, N; _hkA_hvB -> A, B).
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_q4_0_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 192 , 192 >;
3567
3569
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 192 , 128 >;
3568
3570
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 256 , 256 >;
3571
+ template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 576 , 512 >;
3569
3572
3570
3573
// Explicit instantiations of kernel_flash_attn_ext using block_q4_1 / dequantize_q4_1 in both type slots.
// The two trailing integers repeat the size N from the _hN suffix of each host_name.
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_q4_1_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q4_1_h64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 64 , 64 >;
3571
3574
template [[host_name("kernel_flash_attn_ext_q4_1_h80")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 80 , 80 >;
@@ -3575,6 +3578,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_at
3575
3578
// Explicit instantiations of kernel_flash_attn_ext using block_q4_1 / dequantize_q4_1 in both type slots.
// The two trailing integers match the sizes in each host_name (_hN -> N, N; _hkA_hvB -> A, B).
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_q4_1_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 192 , 192 >;
3576
3579
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 192 , 128 >;
3577
3580
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 256 , 256 >;
3581
+ template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 576 , 512 >;
3578
3582
3579
3583
// Explicit instantiations of kernel_flash_attn_ext using block_q5_0 / dequantize_q5_0 in both type slots.
// The two trailing integers repeat the size N from the _hN suffix of each host_name.
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_q5_0_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q5_0_h64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 64 , 64 >;
3580
3584
template [[host_name("kernel_flash_attn_ext_q5_0_h80")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 80 , 80 >;
@@ -3584,6 +3588,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_at
3584
3588
// Explicit instantiations of kernel_flash_attn_ext using block_q5_0 / dequantize_q5_0 in both type slots.
// The two trailing integers match the sizes in each host_name (_hN -> N, N; _hkA_hvB -> A, B).
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_q5_0_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 192 , 192 >;
3585
3589
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 192 , 128 >;
3586
3590
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 256 , 256 >;
3591
+ template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 576 , 512 >;
3587
3592
3588
3593
// Explicit instantiations of kernel_flash_attn_ext using block_q5_1 / dequantize_q5_1 in both type slots.
// The two trailing integers repeat the size N from the _hN suffix of each host_name.
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_q5_1_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q5_1_h64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 64 , 64 >;
3589
3594
template [[host_name("kernel_flash_attn_ext_q5_1_h80")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 80 , 80 >;
@@ -3593,6 +3598,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_at
3593
3598
// Explicit instantiations of kernel_flash_attn_ext using block_q5_1 / dequantize_q5_1 in both type slots.
// The two trailing integers match the sizes in each host_name (_hN -> N, N; _hkA_hvB -> A, B).
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_q5_1_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 192 , 192 >;
3594
3599
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 192 , 128 >;
3595
3600
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 256 , 256 >;
3601
+ template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 576 , 512 >;
3596
3602
3597
3603
// Explicit instantiations of kernel_flash_attn_ext using block_q8_0 / dequantize_q8_0 in both type slots.
// The two trailing integers repeat the size N from the _hN suffix of each host_name.
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_q8_0_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q8_0_h64")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 64 , 64 >;
3598
3604
template [[host_name("kernel_flash_attn_ext_q8_0_h80")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 80 , 80 >;
@@ -3602,6 +3608,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_at
3602
3608
// Explicit instantiations of kernel_flash_attn_ext using block_q8_0 / dequantize_q8_0 in both type slots.
// The two trailing integers match the sizes in each host_name (_hN -> N, N; _hkA_hvB -> A, B).
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_q8_0_h128" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 192 , 192 >;
3603
3609
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 192 , 128 >;
3604
3610
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 256 , 256 >;
3611
+ template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 576 , 512 >;
3605
3612
3606
3613
#undef FA_TYPES
3607
3614
@@ -4009,6 +4016,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
4009
4016
// Explicit instantiations of kernel_flash_attn_ext_vec for block_q5_1 and block_q8_0 with the *_t4 dequantizers.
// The integers 256, 256 repeat the size from the _h256 suffix; the final argument (4) is interpreted by
// kernel_flash_attn_ext_vec, whose definition is not visible here.
// host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_vec_q5_0_h256" form of sibling kernels.
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8 , dequantize_q5_1_t4, block_q5_1, 8 , dequantize_q5_1_t4, 256 , 256 , 4 >;
4010
4017
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8 , dequantize_q8_0_t4, block_q8_0, 8 , dequantize_q8_0_t4, 256 , 256 , 4 >;
4011
4018
4019
+ // Explicit instantiations of kernel_flash_attn_ext_vec for the hk576/hv512 sizes (trailing integers 576, 512),
+ // one per type/dequantizer pair; the bf16 variant is compiled only when GGML_METAL_USE_BF16 is defined.
+ // The final argument (2) is interpreted by kernel_flash_attn_ext_vec, whose definition is not visible here.
+ // host_name literals carry no padding spaces, matching the "kernel_flash_attn_ext_vec_q5_0_h256" form of sibling kernels.
+ template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 576 , 512 , 2 >;
4020
+ #if defined(GGML_METAL_USE_BF16)
4021
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1 , dequantize_bf16_t4, bfloat4, 1 , dequantize_bf16_t4, 576 , 512 , 2 >;
4022
+ #endif
4023
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8 , dequantize_q4_0_t4, block_q4_0, 8 , dequantize_q4_0_t4, 576 , 512 , 2 >;
4024
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8 , dequantize_q4_1_t4, block_q4_1, 8 , dequantize_q4_1_t4, 576 , 512 , 2 >;
4025
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8 , dequantize_q5_0_t4, block_q5_0, 8 , dequantize_q5_0_t4, 576 , 512 , 2 >;
4026
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8 , dequantize_q5_1_t4, block_q5_1, 8 , dequantize_q5_1_t4, 576 , 512 , 2 >;
4027
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8 , dequantize_q8_0_t4, block_q8_0, 8 , dequantize_q8_0_t4, 576 , 512 , 2 >;
4028
+
4012
4029
#undef FA_TYPES
4013
4030
4014
4031
template <typename T>
0 commit comments