From c7827a8f3abc68c18e4c4ad563d36e2cc8947576 Mon Sep 17 00:00:00 2001
From: Dmitrii Zarukin
Date: Mon, 6 Jan 2025 09:39:44 -0800
Subject: [PATCH] cpu: x64: matmul: add f32:bf16 support on avx512_core and
 avx2

---
 src/cpu/matmul/ref_matmul.hpp                 |  4 +-
 src/cpu/x64/matmul/brgemm_matmul.cpp          |  9 ++--
 .../x64/matmul/brgemm_matmul_copy_utils.cpp   | 28 ++++++++++-
 src/cpu/x64/matmul/brgemm_matmul_utils.cpp    | 49 ++++++++++++-------
 src/cpu/x64/matmul/brgemm_matmul_utils.hpp    |  5 +-
 .../inputs/matmul/test_matmul_bfloat16        |  6 +--
 6 files changed, 73 insertions(+), 28 deletions(-)

diff --git a/src/cpu/matmul/ref_matmul.hpp b/src/cpu/matmul/ref_matmul.hpp
index 8cd1f18dc50..c659bfe7f9d 100644
--- a/src/cpu/matmul/ref_matmul.hpp
+++ b/src/cpu/matmul/ref_matmul.hpp
@@ -61,8 +61,8 @@ struct ref_matmul_t : public primitive_t {
                             f8_e4m3, f4_e2m1, f4_e3m0),
                     VERBOSE_UNSUPPORTED_DT);
             VDISPATCH_MATMUL((src_type == wei_type
-                                     || utils::one_of(wei_type, f16, u8, s8, u4,
-                                             s4, f4_e3m0)),
+                                     || utils::one_of(wei_type, bf16, f16, u8,
+                                             s8, u4, s4, f4_e3m0)),
                     VERBOSE_UNSUPPORTED_DT);
             /* int8 weights decompression support */
             VDISPATCH_MATMUL(IMPLICATION(utils::one_of(wei_type, u8, s8),
diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp
index 2912319f320..44dd510013d 100644
--- a/src/cpu/x64/matmul/brgemm_matmul.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul.cpp
@@ -62,6 +62,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
             = everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32);
     const bool is_f32_f16
             = src_dt == f32 && wei_dt == f16 && one_of(dst_dt, f16, f32);
+    const bool is_f32_bf16
+            = src_dt == f32 && wei_dt == bf16 && one_of(dst_dt, bf16, f32);
     const bool is_bf16_with_int_wei = src_dt == bf16
             && one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, bf16, f32);
     const bool is_f16_with_int_wei = src_dt == f16
@@ -121,7 +123,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
             = [&]() -> bool { return attr()->zero_points_.common(); };
     const bool problem_dt_correct = one_of(true, is_int8, is_f8, is_bf16,
             is_f32, is_f16, is_f32_f16,
-            is_bf16_with_int_wei, is_f16_with_int_wei);
+            is_f32_bf16, is_bf16_with_int_wei, is_f16_with_int_wei);

     auto src_d = memory_desc_wrapper(src_md_);
     auto weights_d = memory_desc_wrapper(weights_md_);
@@ -161,7 +163,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {

     // f32:f16 configuration on AVX2 doesn't support tails with proper
     // instruction sequence in copy routines. Anchor: F32_F16_AVX2_NO_TAIL.
-    VDISPATCH_MATMUL(IMPLICATION(is_f32_f16 && isa == avx2, bgmmc_.N % 8 == 0),
+    VDISPATCH_MATMUL(IMPLICATION((is_f32_f16 || is_f32_bf16) && isa == avx2,
+                             bgmmc_.N % 8 == 0),
             "unsupported configuration");

     const float alpha = 1.0;
@@ -181,7 +184,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
     const auto backup_isa = is_amx && bgmmc_.is_runtime_M && !is_s8s8
             ? (is_f16 || is_f32_f16 || is_f16_with_int_wei
                             ? avx512_core_fp16
-                            : (is_bf16 || is_bf16_with_int_wei
+                            : (is_bf16 || is_f32_bf16 || is_bf16_with_int_wei
                                             ? avx512_core_bf16
                                             : (is_int8 ? avx512_core_vnni
                                                        : avx512_core)))
diff --git a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
index 7c616e68dfe..38e72f1e5ae 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
@@ -3390,6 +3390,11 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::load_data(

     switch (dt_in_) {
         case data_type::f32: uni_vmovups(vmm, op); break;
+        case data_type::bf16:
+            // Upconvert: zero-extend 16-bit values and shift left by 16.
+            uni_vpmovzxwd(vmm, op);
+            uni_vpslld(vmm, vmm, 16);
+            break;
         case data_type::f16:
             if (is_superset(conf_->isa, avx512_core_fp16)) {
                 vcvtph2psx(vmm, op);
@@ -3602,6 +3607,11 @@ struct jit_brgemm_matmul_copy_b_transposed_t
         , use_fp16_instructions_(is_subset(conf_->isa, avx512_core_fp16)
                   && conf_->orig_wei_dt == data_type::f16
                   && conf_->wei_dt == data_type::f32)
+        // This variable enables the upconversion from bf16 to f32, similarly
+        // to f16, mostly for proper tail handling.
+        , use_bf16_instructions_(is_subset(conf_->isa, avx512_core_bf16)
+                  && conf_->orig_wei_dt == data_type::bf16
+                  && conf_->wei_dt == data_type::f32)
         , max_tmp_idx(16
                   - (avx512_core_dot_product_
                                   ? 8
@@ -3648,6 +3658,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t
     const bool req_apply_scales_;
     const bool avx512_core_dot_product_;
     const bool use_fp16_instructions_;
+    const bool use_bf16_instructions_;
     const int max_tmp_idx;
     const dim_t src_stride_, tr_src_stride_, scales_K_stride_, typesize_scale_;

@@ -3793,8 +3804,10 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
         const int columns_tail, const bool use_int4_mask) {
     assert(IMPLICATION(use_int4_mask, is_src_int4_));
     if (columns_tail > 0) {
-        const int dt_step
-                = req_cvtps2xf16_ || use_fp16_instructions_ ? 1 : typesize_;
+        const int dt_step = req_cvtps2xf16_ || use_fp16_instructions_
+                        || use_bf16_instructions_
+                ? 1
+                : typesize_;
         const auto tail_mask = use_int4_mask
                 ? size_t(((size_t)1 << div_up(dt_step * columns_tail, 2)) - 1)
                 : size_t(((size_t)1 << dt_step * columns_tail) - 1);
@@ -3998,6 +4011,10 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
             } else {
                 vcvtph2ps(src_load, addr);
             }
+        } else if (use_bf16_instructions_) {
+            // Upconvert: zero-extend 16-bit values and shift left by 16.
+            uni_vpmovzxwd(src_load, addr);
+            uni_vpslld(src_load, src_load, 16);
         } else {
             vmovdqu8(src_load, addr);
         }
@@ -4168,11 +4185,18 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
                 // For f32:f16 case need to convert raw bytes after `load_bytes`
                 // into f32 values.
                 vcvtph2ps(vmm_src, Xmm(vmm_src.getIdx()));
+            } else if (use_bf16_instructions_) {
+                // Upconvert: shift the loaded 16-bit values left by 16.
+                uni_vpslld(vmm_src, vmm_src, 16);
             }
         } else {
             if (use_fp16_instructions_) {
                 // For non-tailed case can use the convert instruction directly.
                 vcvtph2ps(vmm_src, ptr[reg_src + i * src_stride_]);
+            } else if (use_bf16_instructions_) {
+                // Upconvert: zero-extend 16-bit values and shift left by 16.
+                uni_vpmovzxwd(vmm_src, ptr[reg_src + i * src_stride_]);
+                uni_vpslld(vmm_src, vmm_src, 16);
             } else {
                 uni_vmovups(vmm_src, ptr[reg_src + i * src_stride_]);
             }
diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
index 1521080600c..19a882fc7bc 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -195,6 +195,11 @@ status_t check_isa_with_datatype(
             && IMPLICATION(bm_conf_utils.is_f32_f16(),
                     one_of(isa, avx512_core_fp16, avx2_vnni_2, avx512_core,
                             avx2))
+            // `avx512_core_amx` is not supported for plain upconversion as
+            // the HW supports native bf16 compute.
+            && IMPLICATION(bm_conf_utils.is_f32_bf16(),
+                    one_of(isa, avx512_core_bf16, avx2_vnni_2, avx512_core,
+                            avx2))
             && IMPLICATION(bm_conf_utils.is_int8_with_bf16_dst(),
                     is_superset(isa, avx512_core) || isa == avx2_vnni_2)
             && IMPLICATION(bm_conf_utils.is_bf16_with_int_wei(),
@@ -207,12 +212,13 @@ status_t check_isa_with_datatype(
 }

 status_t check_datatype_cfg(const brgemm_matmul_conf_utils_t &bm_conf_utils) {
-    const bool ok = one_of(true, bm_conf_utils.is_f32(),
-            bm_conf_utils.is_bf16(), bm_conf_utils.is_f16(),
-            bm_conf_utils.is_f32_f16(), bm_conf_utils.is_bf32(),
-            bm_conf_utils.is_f8(), bm_conf_utils.is_int8(),
-            bm_conf_utils.is_bf16_with_int_wei(),
-            bm_conf_utils.is_f16_with_int_wei())
+    const bool ok
+            = one_of(true, bm_conf_utils.is_f32(), bm_conf_utils.is_bf16(),
+                    bm_conf_utils.is_f16(), bm_conf_utils.is_f32_f16(),
+                    bm_conf_utils.is_f32_bf16(), bm_conf_utils.is_bf32(),
+                    bm_conf_utils.is_f8(), bm_conf_utils.is_int8(),
+                    bm_conf_utils.is_bf16_with_int_wei(),
+                    bm_conf_utils.is_f16_with_int_wei())
             && IMPLICATION(bm_conf_utils.is_bf16_with_int_wei()
                                || bm_conf_utils.is_f16_with_int_wei(),
                     bm_conf_utils.with_weights_decompression());
@@ -251,6 +257,10 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
     // avx2 as there's no kernel for such combination.
     , f32_f16_dt(bgmmc.src_dt == f32 && bgmmc.wei_dt == f16
               && one_of(bgmmc.dst_dt, f16, f32))
+    // Keep this var separate from bf16_dt to not let bf16:bf16 slip in on
+    // avx512_core and avx2 as there's no kernel for such combination.
+    , f32_bf16_dt(bgmmc.src_dt == f32 && bgmmc.wei_dt == bf16
+              && one_of(bgmmc.dst_dt, bf16, f32))
     , f16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == f16
               && one_of(bgmmc.dst_dt, f16, f32))
     , A_any_layout(A_any_layout)
@@ -372,7 +382,7 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
     const bool is_adbc_allowed
             = (this->is_bf16() || this->is_f32() || this->is_bf32()
                       || this->is_f16() || this->is_f32_f16()
-                      || this->is_bf16_with_int_wei()
+                      || this->is_f32_bf16() || this->is_bf16_with_int_wei()
                       || this->is_f16_with_int_wei())
             && !xf16_avx2_vnni_2;
     bgmmc.src_tag = is_adbc_allowed
@@ -475,7 +485,7 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
     }

     if (this->is_bf16() || this->is_bf16_with_int_wei()
-            || ((this->is_f16() || this->is_f32_f16()
+            || ((this->is_f16() || this->is_f32_f16() || this->is_f32_bf16()
                         || this->is_f16_with_int_wei())
                     && (is_superset(bgmmc.isa, avx512_core_amx)
                             || is_superset(bgmmc.isa, avx2_vnni_2))))
@@ -488,7 +498,8 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
     }
     // Note: bf32 assumes f32 blocking
     if (this->is_f32() || this->is_bf32() || this->is_f16()
-            || this->is_f32_f16() || this->is_f16_with_int_wei())
+            || this->is_f32_f16() || this->is_f32_bf16()
+            || this->is_f16_with_int_wei())
         switch (n_blk) {
             case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
             case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
@@ -742,7 +753,8 @@ void compute_blocking_heuristic_amx(const brgemm_matmul_conf_t &bgmmc,
             = div_up(static_cast<int>(bgmmc.K), min_k_per_thread);
     const bool is_amx_xf16 = bgmmc.is_amx
             && (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
-                    || bm_conf_utils.is_f32_f16() || bm_conf_utils.is_bf32()
+                    || bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16()
+                    || bm_conf_utils.is_bf32()
                     || bm_conf_utils.is_bf16_with_int_wei()
                     || bm_conf_utils.is_f16_with_int_wei());
     const bool is_amx_int8 = bgmmc.is_amx && bm_conf_utils.is_int8();
@@ -1297,6 +1309,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei();
     bgmmc.is_f16_with_int_wei = bm_conf_utils.is_f16_with_int_wei();
     bgmmc.is_f32_f16 = bm_conf_utils.is_f32_f16();
+    bgmmc.is_f32_bf16 = bm_conf_utils.is_f32_bf16();
     bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression();
     bgmmc.is_int4_weights = one_of(bgmmc.wei_dt, data_type::s4, data_type::u4);

@@ -1314,15 +1327,17 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
         bgmmc.wei_dt = f32;
         bgmmc.tr_a_dt_sz = types::data_type_size(f32);
         bgmmc.tr_b_dt_sz = types::data_type_size(f32);
-    } else if (bm_conf_utils.is_f32_f16() && is_superset(bgmmc.isa, avx2)) {
+    } else if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
+            && is_superset(bgmmc.isa, avx2)) {
         // Note 1: Keep this branch separately from f16 one to have different
-        // ISA conditions (f16 includes f16:f32 and f16:f16 combinations).
+        // ISA conditions (f16 includes f16:f32 and f16:f16 combinations). The
+        // same applies to bf16 (which includes bf16:bf16).
         // Note 2: If `use_buffer_b()` is false, let the kernel perform the
         // conversion. Otherwise, make the copy_b routine handle the conversion
         // and set kernel data types to f32.
         // Note 3: Since `use_buffer_b()` depends on `bgmmc.wei_tag`, which is
         // set later in the code due to its dependencies, the update of data
-        // types to f32 happens below in ANCHOR: `CONVERT_F32_F16_DATA_TYPES`.
+        // types to f32 happens below in ANCHOR: `CONVERT_F32_XF16_DATA_TYPES`.
     } else if (bgmmc.is_f16_with_int_wei && bgmmc.isa != avx512_core_fp16) {
         bgmmc.src_dt = f16;
         bgmmc.wei_dt = f16;
@@ -1449,9 +1464,9 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
             && bgmmc.is_oscale_per_k && bgmmc.is_oscale_per_n
             && bgmmc.transposed_B;

-    if (bm_conf_utils.is_f32_f16() && is_superset(bgmmc.isa, avx2)
-            && bm_conf_utils.use_buffer_b()) {
-        // ANCHOR: `CONVERT_F32_F16_DATA_TYPES`
+    if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
+            && is_superset(bgmmc.isa, avx2) && bm_conf_utils.use_buffer_b()) {
+        // ANCHOR: `CONVERT_F32_XF16_DATA_TYPES`
         bgmmc.src_dt = f32;
         bgmmc.wei_dt = f32;
         bgmmc.tr_a_dt_sz = types::data_type_size(f32);
@@ -1660,7 +1675,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16);

     if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
-            || bm_conf_utils.is_f32_f16()
+            || bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16()
             || bm_conf_utils.is_bf16_with_int_wei()
             || bm_conf_utils.is_f16_with_int_wei()) {
         // empirical observation for performance breakpoint between amx and vnni
diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp
index b81c87f331a..0978e9b6672 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp
@@ -191,6 +191,7 @@ struct brgemm_matmul_conf_t {
     bool is_bf16_with_int_wei = false;
     bool is_f16_with_int_wei = false;
     bool is_f32_f16 = false;
+    bool is_f32_bf16 = false;
     bool is_int4_weights = false;
     bool req_wei_vnni_downconvert = false;
     bool is_runtime_M = false;
@@ -303,6 +304,8 @@ struct brgemm_matmul_conf_utils_t {

     inline bool is_f32_f16() const { return f32_f16_dt; }

+    inline bool is_f32_bf16() const { return f32_bf16_dt; }
+
     inline bool is_f16_with_int_wei() const { return f16_with_int_wei_dt; }

     inline bool with_weights_decompression() const {
@@ -341,7 +344,7 @@ struct brgemm_matmul_conf_utils_t {

     const bool f32_dt, bf16_dt, f16_dt, f8_dt, int8_dt, bf32_dt;
     const bool weights_decompression_support, bf16_with_int_wei_dt, f32_f16_dt,
-            f16_with_int_wei_dt;
+            f32_bf16_dt, f16_with_int_wei_dt;

     const bool A_any_layout;
     const bool B_any_layout;
diff --git a/tests/benchdnn/inputs/matmul/test_matmul_bfloat16 b/tests/benchdnn/inputs/matmul/test_matmul_bfloat16
index 03550499b3c..1f3c5d782d9 100644
--- a/tests/benchdnn/inputs/matmul/test_matmul_bfloat16
+++ b/tests/benchdnn/inputs/matmul/test_matmul_bfloat16
@@ -1,7 +1,7 @@
 # bf16
 --reset
---dt=bf16:bf16:f32,bf16
+--dt=bf16:bf16:f32,bf16,f32:bf16:f32
 --stag=ab,ba --wtag=ab,ba --dtag=ab
 --runtime_dims_masks=0,2:1,1:0,3:1
 --bia_dt=undef,f32 --bia_mask=2

@@ -28,13 +28,13 @@

 # test any
 --reset
---dt=bf16:bf16:f32,bf16
+--dt=bf16:bf16:f32,bf16,f32:bf16:f32
 --stag=ab,ba,any --wtag=ab,ba,any --dtag=ab,any
 --batch=shapes_2d

 # 3d
 --reset
---dt=bf16:bf16:f32,bf16
+--dt=bf16:bf16:f32,bf16,f32:bf16:f32
 --stag=abc,acb --wtag=abc,acb --dtag=abc
 --bia_dt=undef,f32 --bia_mask=4,6
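
Note on the conversion trick used by the copy routines above: bf16 keeps
f32's sign, exponent, and top 7 mantissa bits, i.e. it is exactly the upper
16 bits of an IEEE f32 value. The JIT pair uni_vpmovzxwd + uni_vpslld
(zero-extend each 16-bit element, then shift it left by 16) therefore
reproduces the exact f32 bit pattern. A minimal scalar sketch of the same
upconversion, for reference only (the helper name bf16_to_f32 and the sample
value are illustrative, not part of this patch):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Scalar model of the uni_vpmovzxwd + uni_vpslld sequence: place the
    // 16 bf16 bits into the high half of a 32-bit word, which is already
    // a valid f32 bit pattern.
    static float bf16_to_f32(uint16_t bits) {
        const uint32_t u = static_cast<uint32_t>(bits) << 16;
        float f;
        std::memcpy(&f, &u, sizeof(f)); // reinterpret the bits as f32
        return f;
    }

    int main() {
        // 0x3f80 is the bf16 pattern for 1.0f (high half of 0x3f800000).
        std::printf("%f\n", bf16_to_f32(0x3f80)); // prints 1.000000
        return 0;
    }

The new benchdnn coverage can also be exercised standalone with a command
along these lines (the shape is an arbitrary example, not taken from the
patch):

    ./benchdnn --matmul --dt=f32:bf16:f32 --stag=ab --wtag=ab --dtag=ab 64x128:128x256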