Skip to content

Commit 47710eb

Browse files
committed
cpu: x64: matmul: enables f16 with int weights decompression for brgemm
1 parent 59fae60 commit 47710eb

File tree

4 files changed

+179
-64
lines changed

4 files changed

+179
-64
lines changed

src/cpu/x64/matmul/brgemm_matmul.cpp

+14-7
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
6262
= everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32);
6363
const bool is_bf16_with_int_wei = src_dt == bf16
6464
&& one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, bf16, f32);
65+
const bool is_f16_with_int_wei = src_dt == f16
66+
&& one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, f16, f32);
6567

6668
auto check_bias = [&]() -> bool {
6769
const auto bia_dt = weights_md(1)->data_type;
@@ -86,7 +88,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
8688
if (N() == DNNL_RUNTIME_DIM_VAL) ok = false;
8789
}
8890
// Impl supports f32 scales only for non-weight decompression
89-
if (!is_bf16_with_int_wei) {
91+
if (!(is_bf16_with_int_wei || is_f16_with_int_wei)) {
9092
ok = ok && one_of(asc.get_data_type(DNNL_ARG_SRC), undef, f32);
9193
ok = ok && one_of(asc.get_data_type(DNNL_ARG_WEIGHTS), undef, f32);
9294
ok = ok && one_of(asc.get_data_type(DNNL_ARG_DST), undef, f32);
@@ -97,7 +99,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
9799
auto check_attr_zero_points
98100
= [&]() -> bool { return attr()->zero_points_.common(); };
99101
const bool problem_dt_correct = one_of(true, is_int8, is_f8, is_bf16,
100-
is_f32, is_f16, is_bf16_with_int_wei);
102+
is_f32, is_f16, is_bf16_with_int_wei, is_f16_with_int_wei);
101103

102104
auto src_d = memory_desc_wrapper(src_md_);
103105
auto weights_d = memory_desc_wrapper(weights_md_);
@@ -150,9 +152,12 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
150152
// non-amx isa. s8s8 problem type is exception to avoid compensations
151153
// processing for tail kernel
152154
const auto backup_isa = is_amx && bgmmc_.is_runtime_M && !is_s8s8
153-
? (is_f16 ? avx512_core_fp16
154-
: (is_bf16 ? avx512_core_bf16
155-
: (is_int8 ? avx512_core_vnni : avx512_core)))
155+
? (is_f16 || is_f16_with_int_wei
156+
? avx512_core_fp16
157+
: (is_bf16 || is_bf16_with_int_wei
158+
? avx512_core_bf16
159+
: (is_int8 ? avx512_core_vnni
160+
: avx512_core)))
156161
: isa;
157162
for_(int i_bs = 0; i_bs < 2; i_bs++)
158163
for_(int i_init = 0; i_init < 2; i_init++)
@@ -788,7 +793,8 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
788793
ctx.current_K_start = k;
789794
ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
790795
ctx.scales_ptr = (void *)brgmm_ctx.get_oscales_ptr(n, k);
791-
if (bgmmc.blocked_B && isa == avx512_core_fp16) {
796+
if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
797+
&& isa == avx512_core_fp16) {
792798
cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src,
793799
bgmmc.wei_n_blk * ctx.current_K_iters);
794800
} else {
@@ -805,7 +811,8 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
805811
ctx.current_K_start = k;
806812
ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;
807813
ctx.scales_ptr = (void *)brgmm_ctx.get_oscales_ptr(n, k);
808-
if (bgmmc.blocked_B && isa == avx512_core_fp16) {
814+
if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
815+
&& isa == avx512_core_fp16) {
809816
cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src,
810817
bgmmc.wei_n_blk * ctx.current_K_iters);
811818
} else {

0 commit comments

Comments
 (0)