cpu: x64: matmul: add f32:f16 support on avx512_core and avx2
dzarukin committed Dec 14, 2024
1 parent 61af1ce commit b9440a2
Showing 6 changed files with 85 additions and 30 deletions.
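At the library level, this commit lets a matmul with an f32 source/destination and f16 weights run through the brgemm-based x64 implementation on plain avx512_core and avx2 machines (the reference matmul is extended to accept the data-type combination as well). A minimal sketch of such a problem through the oneDNN C++ API; the API calls are the standard v3.x ones, and the shapes and tags are illustrative rather than taken from this patch:

#include "dnnl.hpp"
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);

    // f32 activations, f16 weights, f32 destination. N is a multiple of 8,
    // which also matches the avx2 tail restriction added in this patch.
    const memory::dim M = 128, K = 256, N = 64;
    memory::desc src_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::f16, memory::format_tag::any);
    memory::desc dst_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

    matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
    matmul prim(pd);
    (void)prim;
    return 0;
}

Which implementation actually serves the primitive still depends on the build and the CPU; the point of the patch is that plain avx512_core and avx2 now cover this combination too.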
src/cpu/matmul/ref_matmul.hpp (8 changes: 4 additions & 4 deletions)
@@ -62,7 +62,7 @@ struct ref_matmul_t : public primitive_t {
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_MATMUL(
(src_type == wei_type
|| utils::one_of(wei_type, u8, s8, u4, s4)),
|| utils::one_of(wei_type, f16, u8, s8, u4, s4)),
VERBOSE_UNSUPPORTED_DT);
/* int8 weights decompression support */
VDISPATCH_MATMUL(IMPLICATION(utils::one_of(wei_type, u8, s8),
@@ -82,10 +82,10 @@ struct ref_matmul_t : public primitive_t {
utils::one_of(
bia_type, f32, bf16, f16, f8_e5m2, f8_e4m3)
&& IMPLICATION(
src_type == f32, bia_type == f32)
&& IMPLICATION(src_type == f16,
wei_type == f32, bia_type == f32)
&& IMPLICATION(wei_type == f16,
utils::one_of(bia_type, f32, f16))
&& IMPLICATION(src_type == bf16,
&& IMPLICATION(wei_type == bf16,
utils::one_of(bia_type, f32, bf16))
// TODO: any implication on allowed bias
// data type for fp8?
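The interleaved old/new lines above change the bias checks to key on the weights type instead of the source type, so an f32:f16 problem may carry an f32 or f16 bias. A condensed, self-contained restatement of just the three implications visible in this hunk (illustrative only, not library code; the surrounding check additionally limits the bias to f32/bf16/f16/f8 types and only applies when a bias is present):

#include <algorithm>
#include <initializer_list>

enum class dt { f32, bf16, f16 };

// IMPLICATION(a, b) in oneDNN expands to (!(a) || (b)).
bool bias_dt_ok(dt wei_type, dt bia_type) {
    auto implies = [](bool cause, bool effect) { return !cause || effect; };
    auto one_of = [](dt v, std::initializer_list<dt> set) {
        return std::find(set.begin(), set.end(), v) != set.end();
    };
    return implies(wei_type == dt::f32, bia_type == dt::f32)
            && implies(wei_type == dt::f16, one_of(bia_type, {dt::f32, dt::f16}))
            && implies(wei_type == dt::bf16, one_of(bia_type, {dt::f32, dt::bf16}));
}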
src/cpu/x64/matmul/brgemm_matmul.cpp (14 changes: 11 additions & 3 deletions)
@@ -60,6 +60,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
= everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32);
const bool is_f16
= everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32);
const bool is_f32_f16
= src_dt == f32 && wei_dt == f16 && one_of(dst_dt, f16, f32);
const bool is_bf16_with_int_wei = src_dt == bf16
&& one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, bf16, f32);
const bool is_f16_with_int_wei = src_dt == f16
@@ -117,8 +119,9 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {

auto check_attr_zero_points
= [&]() -> bool { return attr()->zero_points_.common(); };
const bool problem_dt_correct = one_of(true, is_int8, is_f8, is_bf16,
is_f32, is_f16, is_bf16_with_int_wei, is_f16_with_int_wei);
const bool problem_dt_correct
= one_of(true, is_int8, is_f8, is_bf16, is_f32, is_f16, is_f32_f16,
is_bf16_with_int_wei, is_f16_with_int_wei);

auto src_d = memory_desc_wrapper(src_md_);
auto weights_d = memory_desc_wrapper(weights_md_);
@@ -156,6 +159,11 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
CHECK(init_brgemm_matmul_conf(isa, bgmmc_, *desc(), src_md_, weights_md_,
dst_md_, bias_md_, attr_));

    // The f32:f16 configuration on AVX2 doesn't have a proper instruction
    // sequence for handling tails in the copy routines. Anchor: F32_F16_AVX2_NO_TAIL.
VDISPATCH_MATMUL(IMPLICATION(is_f32_f16 && isa == avx2, bgmmc_.N % 8 == 0),
"unsupported configuration");

const float alpha = 1.0;
const float beta = 1.0;
const float beta_init = 0.0;
@@ -171,7 +179,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
    // non-amx isa. s8s8 problem type is an exception to avoid compensations
// processing for tail kernel
const auto backup_isa = is_amx && bgmmc_.is_runtime_M && !is_s8s8
? (is_f16 || is_f16_with_int_wei
? (is_f16 || is_f32_f16 || is_f16_with_int_wei
? avx512_core_fp16
: (is_bf16 || is_bf16_with_int_wei
? avx512_core_bf16
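The F32_F16_AVX2_NO_TAIL check above exists because the AVX2 copy routines convert f16 weights to f32 with the F16C instruction, which only comes in a full-width, unmasked form; a column tail therefore cannot be converted straight from memory without over-reading. A rough intrinsics sketch of that ISA limitation (my illustration, not code from this patch):

#include <immintrin.h>

// AVX2 + F16C: converts one full group of 8 f16 values to f32. There is no
// masked/partial form of this conversion before AVX-512, hence the
// "N % 8 == 0" dispatch requirement on the avx2 path.
static inline __m256 load_8xf16_as_f32(const void *src) {
    const __m128i raw = _mm_loadu_si128(static_cast<const __m128i *>(src));
    return _mm256_cvtph_ps(raw);
}

// On avx512_core the same conversion accepts a k-mask (e.g. via
// _mm512_maskz_cvtph_ps), so tails are not a problem there.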
src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp (35 changes: 30 additions & 5 deletions)
@@ -3390,7 +3390,13 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::load_data(

switch (dt_in_) {
case data_type::f32: uni_vmovups(vmm, op); break;
case data_type::f16: vcvtph2psx(vmm, op); break;
case data_type::f16:
if (is_superset(conf_->isa, avx512_core_fp16)) {
vcvtph2psx(vmm, op);
} else {
vcvtph2ps(vmm, op);
}
break;
case data_type::s8: uni_vpmovsxbd(vmm, op); break;
case data_type::u8: uni_vpmovzxbd(vmm, op); break;
        // For int4, we treat two int4 as one int8 and extend them to int32
@@ -3593,7 +3599,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t
, avx512_core_dot_product_(
do_compute_compensation_ && !isa_has_int8_vnni(conf->isa))
        // See the note in `create_brgemm_matmul_copy_b` on why `orig_wei_dt` is used.
, use_fp16_instructions_(conf_->isa == avx512_core_fp16
, use_fp16_instructions_(is_subset(conf_->isa, avx512_core_fp16)
&& conf_->orig_wei_dt == data_type::f16
&& conf_->wei_dt == data_type::f32)
, max_tmp_idx(16
@@ -3987,7 +3993,11 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
vcvtdq2ps(src_load, src_load);
maybe_apply_scales(src_reg, i * scales_K_stride_, is_tail);
} else if (use_fp16_instructions_) {
vcvtph2psx(src_load, addr);
if (conf_->isa == avx512_core_fp16) {
vcvtph2psx(src_load, addr);
} else {
vcvtph2ps(src_load, addr);
}
} else {
vmovdqu8(src_load, addr);
}
@@ -4131,6 +4141,7 @@ void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
if (!nrows) return;

const int columns_tail = ncolumns % k_blk_step_;

auto load = [this, nrows, columns_tail](int i) {
auto vmm_src = src_vmm(i);

@@ -4149,11 +4160,25 @@
uni_vpxor(vmm_src, vmm_src, vmm_src);
return;
}

if (columns_tail > 0) {
load_bytes(vmm_src, reg_src, i * src_stride_,
columns_tail * typesize_);
} else
uni_vmovups(vmm_src, ptr[reg_src + i * src_stride_]);
if (use_fp16_instructions_) {
                // For the f32:f16 case, the raw bytes loaded by `load_bytes`
                // must be converted into f32 values.
vcvtph2ps(vmm_src, Xmm(vmm_src.getIdx()));
}
} else {
const auto src_offset = (i * src_stride_);
const auto addr = EVEX_compress_addr(reg_src, src_offset);
if (use_fp16_instructions_) {
                // For the non-tail case, the convert instruction can be used on memory directly.
vcvtph2ps(vmm_src, addr);
} else {
uni_vmovups(vmm_src, addr);
}
}

L(load_done);
};
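In the Ymm specialization of copy_row_x_col above, a column tail is handled by loading only the valid f16 bytes into the register (load_bytes) and then converting the whole register in place, while the non-tail branch converts straight from memory. A rough C++ model of that pattern with F16C intrinsics (illustrative only; the helper names are mine, not oneDNN's):

#include <immintrin.h>
#include <cstdint>
#include <cstring>

// Tail case: mirror "load_bytes, then vcvtph2ps on the register" -- copy only
// the valid f16 elements into a zeroed temporary, then convert the full
// 8-element group (the zero padding lanes convert to 0.0f and are ignored).
static inline __m256 load_f16_tail(const uint16_t *src, int n_valid) {
    alignas(16) uint16_t tmp[8] = {};
    std::memcpy(tmp, src, static_cast<size_t>(n_valid) * sizeof(uint16_t));
    return _mm256_cvtph_ps(_mm_load_si128(reinterpret_cast<const __m128i *>(tmp)));
}

// Non-tail case: convert directly from memory, as in the else branch above.
static inline __m256 load_f16_full(const uint16_t *src) {
    return _mm256_cvtph_ps(
            _mm_loadu_si128(reinterpret_cast<const __m128i *>(src)));
}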
src/cpu/x64/matmul/brgemm_matmul_utils.cpp (41 changes: 30 additions & 11 deletions)
@@ -190,6 +190,9 @@ status_t check_isa_with_datatype(
&& IMPLICATION(bm_conf_utils.is_f16(),
one_of(isa, avx512_core_amx_fp16, avx512_core_fp16,
avx2_vnni_2))
&& IMPLICATION(bm_conf_utils.is_f32_f16(),
one_of(isa, avx512_core_amx_fp16, avx512_core_fp16,
avx2_vnni_2, avx512_core, avx2))
&& IMPLICATION(bm_conf_utils.is_int8_with_bf16_dst(),
is_superset(isa, avx512_core) || isa == avx2_vnni_2)
&& IMPLICATION(bm_conf_utils.is_bf16_with_int_wei(),
@@ -202,12 +205,12 @@ }
}

status_t check_datatype_cfg(const brgemm_matmul_conf_utils_t &bm_conf_utils) {
const bool ok
= one_of(true, bm_conf_utils.is_f32(), bm_conf_utils.is_bf16(),
bm_conf_utils.is_f16(), bm_conf_utils.is_bf32(),
bm_conf_utils.is_f8(), bm_conf_utils.is_int8(),
bm_conf_utils.is_bf16_with_int_wei(),
bm_conf_utils.is_f16_with_int_wei())
const bool ok = one_of(true, bm_conf_utils.is_f32(),
bm_conf_utils.is_bf16(), bm_conf_utils.is_f16(),
bm_conf_utils.is_f32_f16(), bm_conf_utils.is_bf32(),
bm_conf_utils.is_f8(), bm_conf_utils.is_int8(),
bm_conf_utils.is_bf16_with_int_wei(),
bm_conf_utils.is_f16_with_int_wei())
&& IMPLICATION(bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei(),
bm_conf_utils.with_weights_decompression());
@@ -242,6 +245,10 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
&& attr.fpmath_.apply_to_int_)
, bf16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == bf16
&& one_of(bgmmc.dst_dt, bf16, f32))
    // Keep this flag separate from f16_dt so that f16:f16 doesn't slip through on
    // avx512_core and avx2, where there is no kernel for that combination.
, f32_f16_dt(bgmmc.src_dt == f32 && bgmmc.wei_dt == f16
&& one_of(bgmmc.dst_dt, f16, f32))
, f16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == f16
&& one_of(bgmmc.dst_dt, f16, f32))
, A_any_layout(A_any_layout)
@@ -362,7 +369,8 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
= this->is_int8() && is_superset(bgmmc.isa, avx512_core);
const bool is_adbc_allowed
= (this->is_bf16() || this->is_f32() || this->is_bf32()
|| this->is_f16() || this->is_bf16_with_int_wei()
|| this->is_f16() || this->is_f32_f16()
|| this->is_bf16_with_int_wei()
|| this->is_f16_with_int_wei())
&& !xf16_avx2_vnni_2;
bgmmc.src_tag = is_adbc_allowed
@@ -465,8 +473,9 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
}

if (this->is_bf16() || this->is_bf16_with_int_wei()
|| ((this->is_f16() || this->is_f16_with_int_wei())
&& bgmmc.isa != avx512_core_fp16))
|| ((this->is_f16() || this->is_f32_f16()
|| this->is_f16_with_int_wei())
&& is_superset(bgmmc.isa, avx512_core_amx)))
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
@@ -476,7 +485,7 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
}
// Note: bf32 assumes f32 blocking
if (this->is_f32() || this->is_bf32() || this->is_f16()
|| this->is_f16_with_int_wei())
|| this->is_f32_f16() || this->is_f16_with_int_wei())
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
@@ -730,7 +739,7 @@ void compute_blocking_heuristic_amx(const brgemm_matmul_conf_t &bgmmc,
= div_up(static_cast<int>(bgmmc.K), min_k_per_thread);
const bool is_amx_xf16 = bgmmc.is_amx
&& (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
|| bm_conf_utils.is_bf32()
|| bm_conf_utils.is_f32_f16() || bm_conf_utils.is_bf32()
|| bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei());
const bool is_amx_int8 = bgmmc.is_amx && bm_conf_utils.is_int8();
@@ -1284,6 +1293,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.is_bf32 = bm_conf_utils.is_bf32();
bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei();
bgmmc.is_f16_with_int_wei = bm_conf_utils.is_f16_with_int_wei();
bgmmc.is_f32_f16 = bm_conf_utils.is_f32_f16();
bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression();
bgmmc.is_int4_weights = one_of(bgmmc.wei_dt, data_type::s4, data_type::u4);

@@ -1301,6 +1311,13 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
} else if (bm_conf_utils.is_f32_f16() && is_superset(bgmmc.isa, avx2)) {
        // Keep this branch separate from the f16 one to allow a less restrictive
        // ISA condition; otherwise the same upconvert mechanism applies.
bgmmc.src_dt = f32;
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
} else if (bgmmc.is_f16_with_int_wei && bgmmc.isa != avx512_core_fp16) {
bgmmc.src_dt = f16;
bgmmc.wei_dt = f16;
@@ -1518,6 +1535,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bm_conf_utils.is_bf16_with_int_wei(),
(bgmmc.is_amx
&& (bm_conf_utils.is_f16()
|| bm_conf_utils.is_f32_f16()
|| bm_conf_utils.is_f16_with_int_wei())))
&& (bgmmc.isa != avx2_vnni_2) // no perf study yet.
&& bgmmc.lda_big_pow2() && bgmmc.M >= 1024;
@@ -1629,6 +1647,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16);

if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
|| bm_conf_utils.is_f32_f16()
|| bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei()) {
// empirical observation for performance breakpoint between amx and vnni
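In init_brgemm_matmul_conf, the new f32:f16 branch sets both src_dt and wei_dt to f32 and sizes the copy buffers for f32 elements: the f16 weights are upconverted once by the B copy routine, and the existing f32 brgemm kernel then does the math. A scalar model of that dataflow (a conceptual sketch, not the JIT implementation; _cvtsh_ss is the F16C scalar f16-to-f32 conversion):

#include <immintrin.h>
#include <cstdint>
#include <vector>

// C = A(f32, MxK) * B(f16, KxN), accumulated in f32, row-major.
void matmul_f32_f16_model(const float *A, const uint16_t *B_f16, float *C,
        int M, int N, int K) {
    // "copy_b" step: upconvert the f16 weights into an f32 scratch buffer.
    std::vector<float> B(static_cast<size_t>(K) * N);
    for (size_t i = 0; i < B.size(); ++i)
        B[i] = _cvtsh_ss(B_f16[i]);

    // Plain f32 GEMM on the upconverted weights, as the f32 kernel would do.
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k)
                acc += A[static_cast<size_t>(m) * K + k]
                        * B[static_cast<size_t>(k) * N + n];
            C[static_cast<size_t>(m) * N + n] = acc;
        }
}

Because both operands end up in f32, the result should match an f32:f32 matmul run on pre-converted weights, up to the usual rounding of the weights to f16.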
src/cpu/x64/matmul/brgemm_matmul_utils.hpp (10 changes: 6 additions & 4 deletions)
@@ -190,6 +190,7 @@ struct brgemm_matmul_conf_t {
bool is_bf32 = false;
bool is_bf16_with_int_wei = false;
bool is_f16_with_int_wei = false;
bool is_f32_f16 = false;
bool is_int4_weights = false;
bool req_wei_vnni_downconvert = false;
bool is_runtime_M = false;
@@ -230,10 +231,9 @@ struct brgemm_matmul_conf_utils_t {
inline bool use_buffer_b(bool use_heuristic = true) const {
if (bgmmc.is_runtime_N) return true;
if (bgmmc.is_bf16_with_int_wei) return true;
if (bgmmc.is_f32_f16) return true;
if (bgmmc.is_f16_with_int_wei) return true;
if (bgmmc.apply_scales_in_buffer_b) return true;
if (utils::one_of(true, bgmmc.is_runtime_N, bgmmc.is_bf16_with_int_wei,
bgmmc.is_f16_with_int_wei, bgmmc.apply_scales_in_buffer_b))
return true;

if (bgmmc.is_amx)
// use b_buffer for AMX when:
@@ -302,6 +302,8 @@ struct brgemm_matmul_conf_utils_t {

inline bool is_bf16_with_int_wei() const { return bf16_with_int_wei_dt; }

inline bool is_f32_f16() const { return f32_f16_dt; }

inline bool is_f16_with_int_wei() const { return f16_with_int_wei_dt; }

inline bool with_weights_decompression() const {
@@ -339,7 +341,7 @@ struct brgemm_matmul_conf_utils_t {
brgemm_matmul_conf_t &bgmmc;

const bool f32_dt, bf16_dt, f16_dt, f8_dt, int8_dt, bf32_dt;
const bool weights_decompression_support, bf16_with_int_wei_dt,
const bool weights_decompression_support, bf16_with_int_wei_dt, f32_f16_dt,
f16_with_int_wei_dt;

const bool A_any_layout;
tests/benchdnn/inputs/matmul/test_matmul_float16 (7 changes: 4 additions & 3 deletions)
@@ -1,7 +1,8 @@
# f16
--reset

--dt=f16:f16:f32,f16
--skip-impl=ref
--dt=f16:f16:f32,f16,f32:f16:f32
--stag=ab,ba --wtag=ab,ba --dtag=ab
--runtime_dims_masks=0,2:1,1:0,3:1
--bia_dt=undef,f32 --bia_mask=2
@@ -28,13 +29,13 @@

# test any
--reset
--dt=f16:f16:f32,f16
--dt=f16:f16:f32,f16,f32:f16:f32
--stag=ab,ba,any --wtag=ab,ba,any --dtag=ab,any
--batch=shapes_2d

# 3d
--reset
--dt=f16:f16:f32,f16
--dt=f16:f16:f32,f16,f32:f16:f32
--stag=abc,acb --wtag=abc,acb --dtag=abc
--bia_dt=undef,f32 --bia_mask=4,6
