@@ -531,6 +531,92 @@ void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
531531 }
532532}
533533
// Matrix multiplication kernel: A holds rows quantized as IQ3_KT
// (trellis-coded 3-bit-ish weights), B holds nrc_y activation rows in
// Q8_2_X4 format.  Computes nrc_x x nrc_y dot products of length n and
// stores the float results through `info`.
template <int nrc_y>
void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n%QK_K == 0);
    const int nb = n/QK_K;   // number of QK_K-sized super-blocks per row

    // Trellis decoder that regenerates weight magnitudes from the packed
    // 16-bit indices stored in block_iq3_kt::ql.
    Trellis3<true, true> trellis;

    // Per-32-bit-lane shift amounts used to split the packed 4-bit scales:
    // lane 0 keeps the low nibbles, lane 1 is shifted right by 4 to expose
    // the high nibbles.  Lanes 2/3 are unused by the cvtepi8_epi32 below,
    // which only reads the low 8 bytes.
    auto shifts = _mm_set_epi32(0, 0, 4, 0);

    constexpr int k_acc = nrc_y;   // one float accumulator per B row

    __m256 accd[k_acc];
    const block_q8_2_x4 * y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
    }

    __m256i xv[4], sv[4], dot[4];   // weight magnitudes, signs, per-block dots
    __m256 scales[2];

    // Horizontally reduce dot[0..3] (each holding 8 partial sums for one
    // 32-element block) into one vector laid out as 0 1 2 3 0 1 2 3.
    auto sum_4 = [&dot] () {
        // dot[k] has 8 values from block k
        // 0 1 0 1 0 1 0 1
        dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1]));
        // 2 3 2 3 2 3 2 3
        dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3]));
        // 0 1 2 3 0 1 2 3
        dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2]));
        return _mm256_cvtepi32_ps(dot[0]);
    };

    // Integer dot product of the 4x32 weights in xv[] against 4x32 signed
    // int8 activations at y.  The per-weight signs sv[] (+1/-1 bytes) are
    // applied to the activations via _mm256_sign_epi8 so that xv can stay
    // non-negative and serve as the unsigned operand of dpbusd/maddubs.
    auto compute_dot = [&dot, &xv, &sv] (const int8_t * y) {
        for (int k = 0; k < 4; ++k) {
            auto yv = _mm256_loadu_si256((const __m256i *)y + k);
#ifdef HAVE_FANCY_SIMD
            // dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
            dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], _mm256_sign_epi8(yv, sv[k]));
#else
            auto p = _mm256_maddubs_epi16(xv[k], _mm256_sign_epi8(yv, sv[k]));
            dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1));
#endif
        }
    };

    for (int ix = 0; ix < nrc_x; ++ix) {
        // Each A row starts with one float row scale, followed by the blocks.
        const float * dptr = (const float *)((const char *)vx + ix*bx);
        // NOTE(review): the 1.01f fudge factor presumably corrects a small
        // systematic bias of the trellis decoder -- confirm against the
        // matching dequantization path.
        auto d = _mm256_set1_ps(dptr[0] * 1.01f);
        const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);

        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();

        for (int i = 0; i < nb; ++i) {
            auto ql = (const uint16_t *)x[i].ql;              // trellis indices
            auto sign_bits = _mm256_loadu_si256((const __m256i *)x[i].qh);
            // Unpack the 8 packed 4-bit block scales into 8 floats:
            // broadcast the 32-bit word, shift lane 1 right by 4, then mask
            // each byte to its nibble; bytes 0-3 = scales 0-3 (low nibbles),
            // bytes 4-7 = scales 4-7 (high nibbles).
            auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
            s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
            auto s32 = _mm256_cvtepi8_epi32(s8);
            auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
            // Duplicate each 4-scale half into both 128-bit lanes so
            // scales[i128] matches sum_4()'s 0 1 2 3 0 1 2 3 output layout.
            auto scales_l = _mm256_castps256_ps128(all_scales);
            auto scales_h = _mm256_extractf128_ps(all_scales, 1);
            scales[0] = _mm256_set_m128(scales_l, scales_l);
            scales[1] = _mm256_set_m128(scales_h, scales_h);
            auto mask = _mm256_set1_epi8(1);
            for (int i128 = 0; i128 < 2; ++i128) {   // two 128-weight halves per super-block
                for (int k = 0; k < 4; ++k) {
                    // Decode 32 weight magnitudes for block k of this half.
                    xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
                    // Expand the current sign bit of each byte into a +1/-1
                    // byte: cmpeq yields 0xFF (-1) where the bit is set, and
                    // OR-ing with 1 turns 0 into +1 while leaving -1 intact.
                    sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1));
                    // epi16 shift is safe here: the mask bit never leaves its
                    // byte during the 8 uses (bits 0..7).
                    mask = _mm256_slli_epi16(mask, 1);
                }
                for (int iy = 0; iy < nrc_y; ++iy) {
                    const block_q8_2_x4& yb = y[iy][2*i+i128];
                    // yb.d holds 16-bit (bf16-style) scales; widen to 32 bits
                    // and shift into the high half to reinterpret as float.
                    auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
                    dy = _mm256_mul_ps(scales[i128], dy);
                    // Broadcast the low 4 combined scales to both lanes to
                    // line up with sum_4()'s 0 1 2 3 0 1 2 3 result.
                    auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
                    compute_dot(yb.qs);
                    accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
                }
            }
        }

        // Final horizontal sum of each accumulator -> one float per (ix, iy).
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, hsum_float_8(accd[iy]));
        }
    }
}
619+
534620inline __m256 abs_ps (__m256 vals) {
535621 // Clear sign-bit of all the 32-bit floats in vals
536622 __m256 sign_bit = _mm256_set1_ps (-0 .0f );
@@ -947,6 +1033,14 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
9471033 return false ;
9481034 }
9491035
    // IQ3_KT A-matrix: only the Q8_2_X4 quantized B type is handled by the
    // kernel registered here; any other B type reports no kernel available.
    if (typeA == GGML_TYPE_IQ3_KT) {
        if (typeB == GGML_TYPE_Q8_2_X4) {
            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_q8_2_x4_T, kernels);
            return true;
        }
        return false;
    }
1043+
9501044 if (ggml_type (typeB) != GGML_TYPE_F32) {
9511045 return false ;
9521046 }
0 commit comments