Skip to content

Commit 78aa2bb

Browse files
authored Mar 19, 2025
[CPU]Fix by_channel quant for avx2 (#29532)
### Details:
- *Fix dot_product kernel with AVX2 support*

### Tickets:
- *ticket-id*
1 parent df40180 commit 78aa2bb

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed
 

‎src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp

+2 -2
Original file line number | Diff line number | Diff line change
@@ -275,8 +275,8 @@ static void quant_u8_by_channel_kernel(const T* src,
275275
}
276276
}
277277
#endif
278-
for (size_t i = 0; i < seq_dim; ++i) {
279-
for (; j < hidden_dims; j++) {
278+
for (; j < hidden_dims; j++) {
279+
for (size_t i = 0; i < seq_dim; ++i) {
280280
float tmp = src[i * src_stride + j];
281281
dst[i * dst_stride + j] = static_cast<uint8_t>(std::round(tmp / scale[j] + zp[j]));
282282
}

‎src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

+4 -4
Original file line number | Diff line number | Diff line change
@@ -681,10 +681,10 @@ static void dot_product_block_by_channel(TA* a, uint8_t* b, float* c, const size
681681
auto va2 = mm256_uni_loadu_ps(a + i + vec_len_f32_avx2 * 2);
682682
auto va3 = mm256_uni_loadu_ps(a + i + vec_len_f32_avx2 * 3);
683683

684-
auto vb0_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i));
685-
auto vb1_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + vec_len_f32_avx2));
686-
auto vb2_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + vec_len_f32_avx2 * 2));
687-
auto vb3_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + vec_len_f32_avx2 * 3));
684+
auto vb0_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + params_offset + i));
685+
auto vb1_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + params_offset + i + vec_len_f32_avx2));
686+
auto vb2_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + params_offset + i + vec_len_f32_avx2 * 2));
687+
auto vb3_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + params_offset + i + vec_len_f32_avx2 * 3));
688688

689689
auto vb0_256 = _mm256_cvtepu8_epi32(vb0_128);
690690
auto vb1_256 = _mm256_cvtepu8_epi32(vb1_128);

0 commit comments

Comments
 (0)
Please sign in to comment.