@@ -62,6 +62,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
62
62
= everyone_is (f16, src_dt, wei_dt) && one_of (dst_dt, f16, f32);
63
63
const bool is_bf16_with_int_wei = src_dt == bf16
64
64
&& one_of (wei_dt, s8, u8, s4, u4) && one_of (dst_dt, bf16, f32);
65
+ const bool is_f16_with_int_wei = src_dt == f16
66
+ && one_of (wei_dt, s8, u8, s4, u4) && one_of (dst_dt, f16, f32);
65
67
66
68
auto check_bias = [&]() -> bool {
67
69
const auto bia_dt = weights_md (1 )->data_type ;
@@ -86,7 +88,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
86
88
if (N () == DNNL_RUNTIME_DIM_VAL) ok = false ;
87
89
}
88
90
// Impl supports f32 scales only for non-weight decompression
89
- if (!is_bf16_with_int_wei) {
91
+ if (!( is_bf16_with_int_wei || is_f16_with_int_wei) ) {
90
92
ok = ok && one_of (asc.get_data_type (DNNL_ARG_SRC), undef, f32);
91
93
ok = ok && one_of (asc.get_data_type (DNNL_ARG_WEIGHTS), undef, f32);
92
94
ok = ok && one_of (asc.get_data_type (DNNL_ARG_DST), undef, f32);
@@ -97,7 +99,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
97
99
auto check_attr_zero_points
98
100
= [&]() -> bool { return attr ()->zero_points_ .common (); };
99
101
const bool problem_dt_correct = one_of (true , is_int8, is_f8, is_bf16,
100
- is_f32, is_f16, is_bf16_with_int_wei);
102
+ is_f32, is_f16, is_bf16_with_int_wei, is_f16_with_int_wei );
101
103
102
104
auto src_d = memory_desc_wrapper (src_md_);
103
105
auto weights_d = memory_desc_wrapper (weights_md_);
@@ -150,9 +152,12 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
150
152
// non-amx isa. s8s8 problem type is an exception to avoid compensations
151
153
// processing for tail kernel
152
154
const auto backup_isa = is_amx && bgmmc_.is_runtime_M && !is_s8s8
153
- ? (is_f16 ? avx512_core_fp16
154
- : (is_bf16 ? avx512_core_bf16
155
- : (is_int8 ? avx512_core_vnni : avx512_core)))
155
+ ? (is_f16 || is_f16_with_int_wei
156
+ ? avx512_core_fp16
157
+ : (is_bf16 || is_bf16_with_int_wei
158
+ ? avx512_core_bf16
159
+ : (is_int8 ? avx512_core_vnni
160
+ : avx512_core)))
156
161
: isa;
157
162
for_ (int i_bs = 0 ; i_bs < 2 ; i_bs++)
158
163
for_ (int i_init = 0 ; i_init < 2 ; i_init++)
@@ -788,7 +793,8 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
788
793
ctx.current_K_start = k;
789
794
ctx.current_K_iters = nstl::min (bgmmc.K_blk , bgmmc.K );
790
795
ctx.scales_ptr = (void *)brgmm_ctx.get_oscales_ptr (n, k);
791
- if (bgmmc.blocked_B && isa == avx512_core_fp16) {
796
+ if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
797
+ && isa == avx512_core_fp16) {
792
798
cvt_float16_to_float ((float *)ctx.tr_src , (float16_t *)ctx.src ,
793
799
bgmmc.wei_n_blk * ctx.current_K_iters );
794
800
} else {
@@ -805,7 +811,8 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
805
811
ctx.current_K_start = k;
806
812
ctx.current_K_iters = bgmmc.K % bgmmc.K_blk ;
807
813
ctx.scales_ptr = (void *)brgmm_ctx.get_oscales_ptr (n, k);
808
- if (bgmmc.blocked_B && isa == avx512_core_fp16) {
814
+ if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
815
+ && isa == avx512_core_fp16) {
809
816
cvt_float16_to_float ((float *)ctx.tr_src , (float16_t *)ctx.src ,
810
817
bgmmc.wei_n_blk * ctx.current_K_iters );
811
818
} else {
0 commit comments