From 8373e71f1ed17883a46ccac0a623e7ccccb8f0f7 Mon Sep 17 00:00:00 2001 From: Dmitrii Zarukin Date: Mon, 6 Jan 2025 09:39:27 -0800 Subject: [PATCH] cpu: x64: matmul: add f32:f16 support on avx512_core and avx2 --- src/cpu/matmul/ref_matmul.hpp | 12 ++-- src/cpu/x64/matmul/brgemm_matmul.cpp | 16 ++++-- .../x64/matmul/brgemm_matmul_copy_utils.cpp | 37 +++++++++--- src/cpu/x64/matmul/brgemm_matmul_utils.cpp | 56 +++++++++++++++---- src/cpu/x64/matmul/brgemm_matmul_utils.hpp | 11 ++-- .../inputs/matmul/test_matmul_float16 | 7 ++- 6 files changed, 102 insertions(+), 37 deletions(-) diff --git a/src/cpu/matmul/ref_matmul.hpp b/src/cpu/matmul/ref_matmul.hpp index 6d61e25013c..8cd1f18dc50 100644 --- a/src/cpu/matmul/ref_matmul.hpp +++ b/src/cpu/matmul/ref_matmul.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,8 +61,8 @@ struct ref_matmul_t : public primitive_t { f8_e4m3, f4_e2m1, f4_e3m0), VERBOSE_UNSUPPORTED_DT); VDISPATCH_MATMUL((src_type == wei_type - || utils::one_of(wei_type, u8, s8, u4, s4, - f4_e3m0)), + || utils::one_of(wei_type, f16, u8, s8, u4, + s4, f4_e3m0)), VERBOSE_UNSUPPORTED_DT); /* int8 weights decompression support */ VDISPATCH_MATMUL(IMPLICATION(utils::one_of(wei_type, u8, s8), @@ -82,10 +82,10 @@ struct ref_matmul_t : public primitive_t { utils::one_of( bia_type, f32, bf16, f16, f8_e5m2, f8_e4m3) && IMPLICATION( - src_type == f32, bia_type == f32) - && IMPLICATION(src_type == f16, + wei_type == f32, bia_type == f32) + && IMPLICATION(wei_type == f16, utils::one_of(bia_type, f32, f16)) - && IMPLICATION(src_type == bf16, + && IMPLICATION(wei_type == bf16, utils::one_of(bia_type, f32, bf16)) // TODO: any implication on allowed bias // data type for fp8? diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp index 56115a31cdb..2912319f320 100644 --- a/src/cpu/x64/matmul/brgemm_matmul.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -60,6 +60,8 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { = everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32); const bool is_f16 = everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32); + const bool is_f32_f16 + = src_dt == f32 && wei_dt == f16 && one_of(dst_dt, f16, f32); const bool is_bf16_with_int_wei = src_dt == bf16 && one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, bf16, f32); const bool is_f16_with_int_wei = src_dt == f16 @@ -117,8 +119,9 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { auto check_attr_zero_points = [&]() -> bool { return attr()->zero_points_.common(); }; - const bool problem_dt_correct = one_of(true, is_int8, is_f8, is_bf16, - is_f32, is_f16, is_bf16_with_int_wei, is_f16_with_int_wei); + const bool problem_dt_correct + = one_of(true, is_int8, is_f8, is_bf16, is_f32, is_f16, is_f32_f16, + is_bf16_with_int_wei, is_f16_with_int_wei); auto src_d = memory_desc_wrapper(src_md_); auto weights_d = memory_desc_wrapper(weights_md_); @@ -156,6 +159,11 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { CHECK(init_brgemm_matmul_conf(isa, bgmmc_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_)); + // The f32:f16 configuration on AVX2 doesn't support tails since the copy + // routines lack a proper instruction sequence for them. Anchor: F32_F16_AVX2_NO_TAIL. + VDISPATCH_MATMUL(IMPLICATION(is_f32_f16 && isa == avx2, bgmmc_.N % 8 == 0), + "unsupported configuration"); + const float alpha = 1.0; const float beta = 1.0; const float beta_init = 0.0; @@ -171,7 +179,7 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { // non-amx isa. s8s8 problem type is an exception to avoid compensations // processing for tail kernel const auto backup_isa = is_amx && bgmmc_.is_runtime_M && !is_s8s8 - ? (is_f16 || is_f16_with_int_wei + ? (is_f16 || is_f32_f16 || is_f16_with_int_wei ? avx512_core_fp16 : (is_bf16 || is_bf16_with_int_wei ? avx512_core_bf16 diff --git a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp index 9a101dc111b..7c616e68dfe 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -544,7 +544,7 @@ struct jit_brgemm_matmul_copy_a_transposed_impl_t , m_loop_dst_shift(columns_step * dst_stride) , k_loop_src_shift(rows_step * src_stride) , k_loop_dst_shift(rows_step * tr_typesize) - , is_f32(everyone_is(data_type::f32, conf_->src_dt, conf_->wei_dt)) + , is_f32(conf_->src_dt == data_type::f32) , is_bf32(conf_->is_bf32) , is_dynamic_src_ld(conf_->is_runtime_M) // See the note in `create_brgemm_matmul_copy_b` why `orig_src_dt` used.
@@ -3390,7 +3390,13 @@ void jit_brgemm_matmul_copy_b_f32_t::load_data( switch (dt_in_) { case data_type::f32: uni_vmovups(vmm, op); break; - case data_type::f16: vcvtph2psx(vmm, op); break; + case data_type::f16: + if (is_superset(conf_->isa, avx512_core_fp16)) { + vcvtph2psx(vmm, op); + } else { + vcvtph2ps(vmm, op); + } + break; case data_type::s8: uni_vpmovsxbd(vmm, op); break; case data_type::u8: uni_vpmovzxbd(vmm, op); break; // For int4, we see two int4 as one int8 and extend them int32 @@ -3593,7 +3599,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t , avx512_core_dot_product_( do_compute_compensation_ && !isa_has_int8_vnni(conf->isa)) // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt` used. - , use_fp16_instructions_(conf_->isa == avx512_core_fp16 + , use_fp16_instructions_(is_subset(conf_->isa, avx512_core_fp16) && conf_->orig_wei_dt == data_type::f16 && conf_->wei_dt == data_type::f32) , max_tmp_idx(16 @@ -3987,7 +3993,11 @@ void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( vcvtdq2ps(src_load, src_load); maybe_apply_scales(src_reg, i * scales_K_stride_, is_tail); } else if (use_fp16_instructions_) { - vcvtph2psx(src_load, addr); + if (conf_->isa == avx512_core_fp16) { + vcvtph2psx(src_load, addr); + } else { + vcvtph2ps(src_load, addr); + } } else { vmovdqu8(src_load, addr); } @@ -4131,6 +4141,7 @@ void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( if (!nrows) return; const int columns_tail = ncolumns % k_blk_step_; + auto load = [this, nrows, columns_tail](int i) { auto vmm_src = src_vmm(i); @@ -4149,11 +4160,23 @@ uni_vpxor(vmm_src, vmm_src, vmm_src); return; } + if (columns_tail > 0) { load_bytes(vmm_src, reg_src, i * src_stride_, columns_tail * typesize_); - } else - uni_vmovups(vmm_src, ptr[reg_src + i * src_stride_]); + if (use_fp16_instructions_) { + // For the f32:f16 case, the raw bytes loaded by `load_bytes` need to + // be converted into f32 values. + vcvtph2ps(vmm_src, Xmm(vmm_src.getIdx())); + } + } else { + if (use_fp16_instructions_) { + // For the non-tailed case, the convert instruction can be used directly. + vcvtph2ps(vmm_src, ptr[reg_src + i * src_stride_]); + } else { + uni_vmovups(vmm_src, ptr[reg_src + i * src_stride_]); + } + } L(load_done); }; diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp index 5a70f5c4efc..1521080600c 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -190,6 +190,11 @@ status_t check_isa_with_datatype( && IMPLICATION(bm_conf_utils.is_f16(), one_of(isa, avx512_core_amx_fp16, avx512_core_fp16, avx2_vnni_2)) + // `avx512_core_amx_fp16` is not supported for plain upconversion + // as HW supports native compute.
+ && IMPLICATION(bm_conf_utils.is_f32_f16(), + one_of(isa, avx512_core_fp16, avx2_vnni_2, avx512_core, + avx2)) && IMPLICATION(bm_conf_utils.is_int8_with_bf16_dst(), is_superset(isa, avx512_core) || isa == avx2_vnni_2) && IMPLICATION(bm_conf_utils.is_bf16_with_int_wei(), @@ -202,12 +207,12 @@ } status_t check_datatype_cfg(const brgemm_matmul_conf_utils_t &bm_conf_utils) { - const bool ok - = one_of(true, bm_conf_utils.is_f32(), bm_conf_utils.is_bf16(), - bm_conf_utils.is_f16(), bm_conf_utils.is_bf32(), - bm_conf_utils.is_f8(), bm_conf_utils.is_int8(), - bm_conf_utils.is_bf16_with_int_wei(), - bm_conf_utils.is_f16_with_int_wei()) + const bool ok = one_of(true, bm_conf_utils.is_f32(), + bm_conf_utils.is_bf16(), bm_conf_utils.is_f16(), + bm_conf_utils.is_f32_f16(), bm_conf_utils.is_bf32(), + bm_conf_utils.is_f8(), bm_conf_utils.is_int8(), + bm_conf_utils.is_bf16_with_int_wei(), + bm_conf_utils.is_f16_with_int_wei()) && IMPLICATION(bm_conf_utils.is_bf16_with_int_wei() || bm_conf_utils.is_f16_with_int_wei(), bm_conf_utils.with_weights_decompression()); @@ -242,6 +247,10 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t( && attr.fpmath_.apply_to_int_) , bf16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == bf16 && one_of(bgmmc.dst_dt, bf16, f32)) + // Keep this var separate from f16_dt so that f16:f16 doesn't slip through on + // avx512_core and avx2, as there's no kernel for such a combination. + , f32_f16_dt(bgmmc.src_dt == f32 && bgmmc.wei_dt == f16 + && one_of(bgmmc.dst_dt, f16, f32)) , f16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == f16 && one_of(bgmmc.dst_dt, f16, f32)) , A_any_layout(A_any_layout) @@ -362,7 +371,8 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md, = this->is_int8() && is_superset(bgmmc.isa, avx512_core); const bool is_adbc_allowed = (this->is_bf16() || this->is_f32() || this->is_bf32() - || this->is_f16() || this->is_bf16_with_int_wei() + || this->is_f16() || this->is_f32_f16() + || this->is_bf16_with_int_wei() || this->is_f16_with_int_wei()) && !xf16_avx2_vnni_2; bgmmc.src_tag = is_adbc_allowed @@ -465,8 +475,10 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout( } if (this->is_bf16() || this->is_bf16_with_int_wei() - || ((this->is_f16() || this->is_f16_with_int_wei()) - && bgmmc.isa != avx512_core_fp16)) + || ((this->is_f16() || this->is_f32_f16() + || this->is_f16_with_int_wei()) + && (is_superset(bgmmc.isa, avx512_core_amx) + || is_superset(bgmmc.isa, avx2_vnni_2)))) switch (n_blk) { case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a; case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a; @@ -476,7 +488,7 @@ } // Note: bf32 assumes f32 blocking if (this->is_f32() || this->is_bf32() || this->is_f16() - || this->is_f16_with_int_wei()) + || this->is_f32_f16() || this->is_f16_with_int_wei()) switch (n_blk) { case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b; case 48: return bgmmc.ndims == 3 ?
aCB16b48c : BA16a48b; @@ -730,7 +742,7 @@ void compute_blocking_heuristic_amx(const brgemm_matmul_conf_t &bgmmc, = div_up(static_cast(bgmmc.K), min_k_per_thread); const bool is_amx_xf16 = bgmmc.is_amx && (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16() - || bm_conf_utils.is_bf32() + || bm_conf_utils.is_f32_f16() || bm_conf_utils.is_bf32() || bm_conf_utils.is_bf16_with_int_wei() || bm_conf_utils.is_f16_with_int_wei()); const bool is_amx_int8 = bgmmc.is_amx && bm_conf_utils.is_int8(); @@ -1284,6 +1296,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, bgmmc.is_bf32 = bm_conf_utils.is_bf32(); bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei(); bgmmc.is_f16_with_int_wei = bm_conf_utils.is_f16_with_int_wei(); + bgmmc.is_f32_f16 = bm_conf_utils.is_f32_f16(); bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression(); bgmmc.is_int4_weights = one_of(bgmmc.wei_dt, data_type::s4, data_type::u4); @@ -1301,6 +1314,15 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, bgmmc.wei_dt = f32; bgmmc.tr_a_dt_sz = types::data_type_size(f32); bgmmc.tr_b_dt_sz = types::data_type_size(f32); + } else if (bm_conf_utils.is_f32_f16() && is_superset(bgmmc.isa, avx2)) { + // Note 1: Keep this branch separate from the f16 one to have different + // ISA conditions (f16 covers the f16:f32 and f16:f16 combinations). + // Note 2: If `use_buffer_b()` is false, let the kernel perform the + // conversion. Otherwise, make the copy_b routine handle the conversion + // and set kernel data types to f32. + // Note 3: Since `use_buffer_b()` depends on `bgmmc.wei_tag`, which is + // set later in the code due to its dependencies, the update of data + // types to f32 happens below at ANCHOR: `CONVERT_F32_F16_DATA_TYPES`. } else if (bgmmc.is_f16_with_int_wei && bgmmc.isa != avx512_core_fp16) { bgmmc.src_dt = f16; bgmmc.wei_dt = f16; @@ -1427,6 +1449,15 @@ && bgmmc.is_oscale_per_k && bgmmc.is_oscale_per_n && bgmmc.transposed_B; + if (bm_conf_utils.is_f32_f16() && is_superset(bgmmc.isa, avx2) + && bm_conf_utils.use_buffer_b()) { + // ANCHOR: `CONVERT_F32_F16_DATA_TYPES` + bgmmc.src_dt = f32; + bgmmc.wei_dt = f32; + bgmmc.tr_a_dt_sz = types::data_type_size(f32); + bgmmc.tr_b_dt_sz = types::data_type_size(f32); + } + // int4 weights decompression only supports plain and transpose layouts // TODO: enable int4 reorder and extend support to blocked weights // layout when needed @@ -1629,6 +1660,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16); if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16() + || bm_conf_utils.is_f32_f16() || bm_conf_utils.is_bf16_with_int_wei() || bm_conf_utils.is_f16_with_int_wei()) { // empirical observation for performance breakpoint between amx and vnni diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp index f0d76b2aae5..b81c87f331a 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -190,6 +190,7 @@ struct brgemm_matmul_conf_t { bool is_bf32 = false; bool is_bf16_with_int_wei = false; bool is_f16_with_int_wei = false; + bool is_f32_f16 = false; bool is_int4_weights = false; bool req_wei_vnni_downconvert = false; bool is_runtime_M = false; @@ -230,10 +231,8 @@ struct brgemm_matmul_conf_utils_t { inline bool use_buffer_b(bool use_heuristic = true) const { if (bgmmc.is_runtime_N) return true; if (bgmmc.is_bf16_with_int_wei) return true; + if (bgmmc.is_f16_with_int_wei) return true; if (bgmmc.apply_scales_in_buffer_b) return true; - if (utils::one_of(true, bgmmc.is_runtime_N, bgmmc.is_bf16_with_int_wei, - bgmmc.is_f16_with_int_wei, bgmmc.apply_scales_in_buffer_b)) - return true; if (bgmmc.is_amx) // use b_buffer for AMX when: @@ -302,6 +301,8 @@ struct brgemm_matmul_conf_utils_t { inline bool is_bf16_with_int_wei() const { return bf16_with_int_wei_dt; } + inline bool is_f32_f16() const { return f32_f16_dt; } + inline bool is_f16_with_int_wei() const { return f16_with_int_wei_dt; } inline bool with_weights_decompression() const { @@ -339,7 +340,7 @@ struct brgemm_matmul_conf_utils_t { brgemm_matmul_conf_t &bgmmc; const bool f32_dt, bf16_dt, f16_dt, f8_dt, int8_dt, bf32_dt; - const bool weights_decompression_support, bf16_with_int_wei_dt, + const bool weights_decompression_support, bf16_with_int_wei_dt, f32_f16_dt, f16_with_int_wei_dt; const bool A_any_layout; diff --git a/tests/benchdnn/inputs/matmul/test_matmul_float16 b/tests/benchdnn/inputs/matmul/test_matmul_float16 index 727e4663310..ca5621ced19 100644 --- a/tests/benchdnn/inputs/matmul/test_matmul_float16 +++ b/tests/benchdnn/inputs/matmul/test_matmul_float16 @@ -1,7 +1,8 @@ # f16 --reset ---dt=f16:f16:f32,f16 +--skip-impl=ref +--dt=f16:f16:f32,f16,f32:f16:f32 --stag=ab,ba --wtag=ab,ba --dtag=ab --runtime_dims_masks=0,2:1,1:0,3:1 --bia_dt=undef,f32 --bia_mask=2 @@ -28,13 +29,13 @@ # test any --reset ---dt=f16:f16:f32,f16 +--dt=f16:f16:f32,f16,f32:f16:f32 --stag=ab,ba,any --wtag=ab,ba,any --dtag=ab,any --batch=shapes_2d # 3d --reset ---dt=f16:f16:f32,f16 +--dt=f16:f16:f32,f16,f32:f16:f32 --stag=abc,acb --wtag=abc,acb --dtag=abc --bia_dt=undef,f32 --bia_mask=4,6
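---

For reference, below is a minimal usage sketch of the combination this patch enables. It is not part of the patch: it assumes the oneDNN v3.x C++ API, and the problem sizes, format tags, and engine index are illustrative only. With this change, an f32 source / f16 weights / f32 destination matmul on avx512_core or avx2 should be served by the brgemm-based implementation rather than falling back to the reference kernel, which is presumably why `--skip-impl=ref` is added to the benchdnn batch above. Note that on plain AVX2 the dispatch check above requires N % 8 == 0 for this configuration.

// f32:f16 matmul usage sketch (illustrative only, not part of the patch).
// Assumes the oneDNN v3.x C++ API; sizes, tags, and engine index are arbitrary.
#include <unordered_map>
#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;

    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    const memory::dim M = 64, K = 128, N = 48; // N % 8 == 0 for the AVX2 path

    // f32 activations, f16 weights, f32 destination: the combination enabled
    // on avx512_core and avx2 by this patch.
    memory::desc src_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::f16, memory::format_tag::ab);
    memory::desc dst_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

    matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
    matmul prim(pd);

    // Library-allocated buffers; in a real application these would be filled
    // with user data (weights stored as f16).
    memory src_mem(src_md, eng), wei_mem(wei_md, eng), dst_mem(dst_md, eng);

    std::unordered_map<int, memory> args{{DNNL_ARG_SRC, src_mem},
            {DNNL_ARG_WEIGHTS, wei_mem}, {DNNL_ARG_DST, dst_mem}};
    prim.execute(strm, args);
    strm.wait();
    return 0;
}

The weights stay in f16 in memory; as the copy-routine changes above show, they are upconverted to f32 on the fly via vcvtph2psx on avx512_core_fp16 and vcvtph2ps elsewhere, either in the copy_b routines or in the kernel itself. The same combination is exercised by the `--dt=f32:f16:f32` cases added to `test_matmul_float16`.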