Skip to content

Commit c201aa7

Browse files
author
Iwan Kawrakow
committed
iq3_kt: AVX2 GEMV
1 parent 6153d0e commit c201aa7

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

ggml/src/iqk/iqk_gemm_ktquants.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,92 @@ void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
531531
}
532532
}
533533

534+
template <int nrc_y>
535+
void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
536+
assert(n%QK_K == 0);
537+
const int nb = n/QK_K;
538+
539+
Trellis3<true, true> trellis;
540+
541+
auto shifts = _mm_set_epi32(0, 0, 4, 0);
542+
543+
constexpr int k_acc = nrc_y;
544+
545+
__m256 accd[k_acc];
546+
const block_q8_2_x4 * y[nrc_y];
547+
for (int iy = 0; iy < nrc_y; ++iy) {
548+
y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
549+
}
550+
551+
__m256i xv[4], sv[4], dot[4];
552+
__m256 scales[2];
553+
554+
auto sum_4 = [&dot] () {
555+
// dot[k] has 8 values from block k
556+
// 0 1 0 1 0 1 0 1
557+
dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1]));
558+
// 2 3 2 3 2 3 2 3
559+
dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3]));
560+
// 0 1 2 3 0 1 2 3
561+
dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2]));
562+
return _mm256_cvtepi32_ps(dot[0]);
563+
};
564+
565+
auto compute_dot = [&dot, &xv, &sv] (const int8_t * y) {
566+
for (int k = 0; k < 4; ++k) {
567+
auto yv = _mm256_loadu_si256((const __m256i *)y + k);
568+
#ifdef HAVE_FANCY_SIMD
569+
//dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
570+
dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], _mm256_sign_epi8(yv, sv[k]));
571+
#else
572+
auto p = _mm256_maddubs_epi16(xv[k], _mm256_sign_epi8(yv, sv[k]));
573+
dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1));
574+
#endif
575+
}
576+
};
577+
578+
for (int ix = 0; ix < nrc_x; ++ix) {
579+
const float * dptr = (const float *)((const char*)vx + ix*bx);
580+
auto d = _mm256_set1_ps(dptr[0] * 1.01f);
581+
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
582+
583+
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
584+
585+
for (int i = 0; i < nb; ++i) {
586+
auto ql = (const uint16_t *)x[i].ql;
587+
auto sign_bits = _mm256_loadu_si256((const __m256i *)x[i].qh);
588+
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
589+
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
590+
auto s32 = _mm256_cvtepi8_epi32(s8);
591+
auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
592+
auto scales_l = _mm256_castps256_ps128(all_scales);
593+
auto scales_h = _mm256_extractf128_ps(all_scales, 1);
594+
scales[0] = _mm256_set_m128(scales_l, scales_l);
595+
scales[1] = _mm256_set_m128(scales_h, scales_h);
596+
auto mask = _mm256_set1_epi8(1);
597+
for (int i128 = 0; i128 < 2; ++i128) {
598+
for (int k = 0; k < 4; ++k) {
599+
xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
600+
sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1));
601+
mask = _mm256_slli_epi16(mask, 1);
602+
}
603+
for (int iy = 0; iy < nrc_y; ++iy) {
604+
const block_q8_2_x4& yb = y[iy][2*i+i128];
605+
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
606+
dy = _mm256_mul_ps(scales[i128], dy);
607+
auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
608+
compute_dot(yb.qs);
609+
accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
610+
}
611+
}
612+
}
613+
614+
for (int iy = 0; iy < nrc_y; ++iy) {
615+
info.store(ix, iy, hsum_float_8(accd[iy]));
616+
}
617+
}
618+
}
619+
534620
inline __m256 abs_ps(__m256 vals) {
535621
// Clear sign-bit of all the 32-bit floats in vals
536622
__m256 sign_bit = _mm256_set1_ps(-0.0f);
@@ -947,6 +1033,14 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
9471033
return false;
9481034
}
9491035

1036+
if (typeA == GGML_TYPE_IQ3_KT) {
1037+
if (typeB == GGML_TYPE_Q8_2_X4) {
1038+
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_q8_2_x4_T, kernels);
1039+
return true;
1040+
}
1041+
return false;
1042+
}
1043+
9501044
if (ggml_type(typeB) != GGML_TYPE_F32) {
9511045
return false;
9521046
}

0 commit comments

Comments
 (0)