diff --git a/src/cpu/x64/brgemm/brgemm_utils.cpp b/src/cpu/x64/brgemm/brgemm_utils.cpp index f97e100d539..d34407c605c 100644 --- a/src/cpu/x64/brgemm/brgemm_utils.cpp +++ b/src/cpu/x64/brgemm/brgemm_utils.cpp @@ -53,7 +53,10 @@ void init_kernel_datatype( brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8) && utils::one_of(dt_b, data_type::u8, data_type::s8); brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16); - brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32); + // Note: f32:bf16 is treated as f32 case while f32:f16 has already been + // treated as f16. Probably, need a common ground here. + brg->is_f32 = (dt_a == data_type::f32) + && utils::one_of(dt_b, data_type::f32, data_type::bf16); brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b); brg->is_fp8 = one_of(dt_a, data_type::f8_e5m2, data_type::f8_e4m3) && one_of(dt_b, data_type::f8_e5m2, data_type::f8_e4m3); @@ -131,9 +134,20 @@ void set_isa_impl(brgemm_desc_t *brg) { is_isa_ok(avx512_core_fp16), avx512_core_fp16, is_isa_ok(avx2), avx2); } else if (brg->is_bf16) { - brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx), - avx512_core_amx, is_isa_ok(avx512_core_bf16), avx512_core_bf16, - is_isa_ok(avx2_vnni_2), avx2_vnni_2); + if (brg->dt_a == data_type::f32 && brg->dt_b == data_type::bf16) { + // Distinguish f32:bf16 case upconversion for bf16 on AVX512_CORE + // and AVX2. + brg->isa_impl = utils::map(true, isa_undef, + is_isa_ok(avx512_core_amx), avx512_core_amx, + is_isa_ok(avx512_core_bf16), avx512_core_bf16, + is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2_vnni_2), + avx2_vnni_2, is_isa_ok(avx2), avx2); + } else { + brg->isa_impl = utils::map(true, isa_undef, + is_isa_ok(avx512_core_amx), avx512_core_amx, + is_isa_ok(avx512_core_bf16), avx512_core_bf16, + is_isa_ok(avx2_vnni_2), avx2_vnni_2); + } } else if (brg->is_f16) { if (everyone_is(data_type::f16, brg->dt_a, brg->dt_b)) { brg->isa_impl = utils::map(true, isa_undef, diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index 3ffbf3eb28c..c28ccc668ea 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -2216,12 +2216,22 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, } else { uni_vcvtph2psx(vmm_load, addr); } - } else if (brg.dt_b == data_type::bf16 - && brg.isa_impl == avx2_vnni_2) { - if (rd % 2 == 0) - vcvtneebf162ps(vmm_load, addr); - else - vcvtneobf162ps(vmm_load, addr); + } else if (brg.dt_b == data_type::bf16) { + if (brg.isa_impl == avx2_vnni_2) { + if (rd % 2 == 0) + vcvtneebf162ps(vmm_load, addr); + else + vcvtneobf162ps(vmm_load, addr); + } else if (utils::one_of(brg.isa_impl, avx512_core, avx2)) { + // Upconvert: load 16 bits and move them 16 bits left. + uni_vpmovzxwd(vmm_load, addr); + uni_vpslld(vmm_load, vmm_load, 16); + } else if (is_ld_tail + && !is_superset(brg.isa_impl, avx512_core)) { + load_bytes(vmm_load, addr, ldb_B_offset(0, true)); + } else { + uni_vmovups(vmm_load, addr); + } } else if (is_ld_tail) { if (is_superset(brg.isa_impl, avx512_core)) { uni_vmovups(vmm_load, addr); @@ -2274,12 +2284,22 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, } else { uni_vcvtph2psx(vmm_load, addr); } - } else if (brg.dt_b == data_type::bf16 - && brg.isa_impl == avx2_vnni_2) { - if (rd % 2 == 0) - vcvtneebf162ps(vmm_load, addr); - else - vcvtneobf162ps(vmm_load, addr); + } else if (brg.dt_b == data_type::bf16) { + if (brg.isa_impl == avx2_vnni_2) { + if (rd % 2 == 0) + vcvtneebf162ps(vmm_load, addr); + else + vcvtneobf162ps(vmm_load, addr); + } else if (utils::one_of(brg.isa_impl, avx512_core, avx2)) { + // Upconvert: load 16 bits and move them 16 bits left. + uni_vpmovzxwd(vmm_load, addr); + uni_vpslld(vmm_load, vmm_load, 16); + } else if (is_ld_tail + && !is_superset(brg.isa_impl, avx512_core)) { + load_bytes(vmm_load, addr, ldb_B_offset(0, true)); + } else { + uni_vmovups(vmm_load, addr); + } } else if (is_ld_tail) { if (is_superset(brg.isa_impl, avx512_core)) { uni_vmovups(vmm_load, addr); diff --git a/src/cpu/x64/cpu_isa_traits.hpp b/src/cpu/x64/cpu_isa_traits.hpp index ffed0fed916..5862217057d 100644 --- a/src/cpu/x64/cpu_isa_traits.hpp +++ b/src/cpu/x64/cpu_isa_traits.hpp @@ -494,7 +494,8 @@ inline data_type_t get_mac_emu_data_type(const data_type_t data_type, using namespace data_type; if (req_emulation) switch (data_type) { case bf16: - if (isa == avx2_vnni_2) return f32; + if (utils::one_of(isa, avx2, avx2_vnni_2, avx512_core)) + return f32; break; case f16: if (utils::one_of(isa, avx2, avx2_vnni_2, avx512_core, diff --git a/tests/benchdnn/inputs/brgemm/test_brgemm_bf16 b/tests/benchdnn/inputs/brgemm/test_brgemm_bf16 index 5a792528a81..311e1b50d81 100644 --- a/tests/benchdnn/inputs/brgemm/test_brgemm_bf16 +++ b/tests/benchdnn/inputs/brgemm/test_brgemm_bf16 @@ -1,6 +1,6 @@ --reset ---dt=bf16,bf16:bf16:f32 +--dt=bf16,bf16:bf16:f32,f32:bf16:f32 --bia_dt=undef,f32,bf16 --beta=0,1 --attr-post-ops=,sum:2,relu