Skip to content

Commit

Permalink
cpu: x64: brgemm: add f32:bf16 support on avx512_core and avx2
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Jan 6, 2025
1 parent a46eea2 commit de434eb
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 18 deletions.
22 changes: 18 additions & 4 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ void init_kernel_datatype(
brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8)
&& utils::one_of(dt_b, data_type::u8, data_type::s8);
brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32);
// Note: f32:bf16 is treated as the f32 case, while f32:f16 is already
// treated as the f16 case. The two should probably be handled consistently.
brg->is_f32 = (dt_a == data_type::f32)
&& utils::one_of(dt_b, data_type::f32, data_type::bf16);
brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);
brg->is_fp8 = one_of(dt_a, data_type::f8_e5m2, data_type::f8_e4m3)
&& one_of(dt_b, data_type::f8_e5m2, data_type::f8_e4m3);
Expand Down Expand Up @@ -131,9 +134,20 @@ void set_isa_impl(brgemm_desc_t *brg) {
is_isa_ok(avx512_core_fp16), avx512_core_fp16, is_isa_ok(avx2),
avx2);
} else if (brg->is_bf16) {
brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx),
avx512_core_amx, is_isa_ok(avx512_core_bf16), avx512_core_bf16,
is_isa_ok(avx2_vnni_2), avx2_vnni_2);
if (brg->dt_a == data_type::f32 && brg->dt_b == data_type::bf16) {
// Handle the f32:bf16 case separately: bf16 data is upconverted to f32
// on AVX512_CORE and AVX2.
brg->isa_impl = utils::map(true, isa_undef,
is_isa_ok(avx512_core_amx), avx512_core_amx,
is_isa_ok(avx512_core_bf16), avx512_core_bf16,
is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2_vnni_2),
avx2_vnni_2, is_isa_ok(avx2), avx2);
} else {
brg->isa_impl = utils::map(true, isa_undef,
is_isa_ok(avx512_core_amx), avx512_core_amx,
is_isa_ok(avx512_core_bf16), avx512_core_bf16,
is_isa_ok(avx2_vnni_2), avx2_vnni_2);
}
} else if (brg->is_f16) {
if (everyone_is(data_type::f16, brg->dt_a, brg->dt_b)) {
brg->isa_impl = utils::map(true, isa_undef,
Expand Down
44 changes: 32 additions & 12 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2216,12 +2216,22 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
} else {
uni_vcvtph2psx(vmm_load, addr);
}
} else if (brg.dt_b == data_type::bf16
&& brg.isa_impl == avx2_vnni_2) {
if (rd % 2 == 0)
vcvtneebf162ps(vmm_load, addr);
else
vcvtneobf162ps(vmm_load, addr);
} else if (brg.dt_b == data_type::bf16) {
if (brg.isa_impl == avx2_vnni_2) {
if (rd % 2 == 0)
vcvtneebf162ps(vmm_load, addr);
else
vcvtneobf162ps(vmm_load, addr);
} else if (utils::one_of(brg.isa_impl, avx512_core, avx2)) {
// Upconvert bf16 to f32: zero-extend each 16-bit element to 32 bits,
// then shift it left by 16 bits.
uni_vpmovzxwd(vmm_load, addr);
uni_vpslld(vmm_load, vmm_load, 16);
} else if (is_ld_tail
&& !is_superset(brg.isa_impl, avx512_core)) {
load_bytes(vmm_load, addr, ldb_B_offset(0, true));
} else {
uni_vmovups(vmm_load, addr);
}
} else if (is_ld_tail) {
if (is_superset(brg.isa_impl, avx512_core)) {
uni_vmovups(vmm_load, addr);
Expand Down Expand Up @@ -2274,12 +2284,22 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
} else {
uni_vcvtph2psx(vmm_load, addr);
}
} else if (brg.dt_b == data_type::bf16
&& brg.isa_impl == avx2_vnni_2) {
if (rd % 2 == 0)
vcvtneebf162ps(vmm_load, addr);
else
vcvtneobf162ps(vmm_load, addr);
} else if (brg.dt_b == data_type::bf16) {
if (brg.isa_impl == avx2_vnni_2) {
if (rd % 2 == 0)
vcvtneebf162ps(vmm_load, addr);
else
vcvtneobf162ps(vmm_load, addr);
} else if (utils::one_of(brg.isa_impl, avx512_core, avx2)) {
// Upconvert bf16 to f32: zero-extend each 16-bit element to 32 bits,
// then shift it left by 16 bits.
uni_vpmovzxwd(vmm_load, addr);
uni_vpslld(vmm_load, vmm_load, 16);
} else if (is_ld_tail
&& !is_superset(brg.isa_impl, avx512_core)) {
load_bytes(vmm_load, addr, ldb_B_offset(0, true));
} else {
uni_vmovups(vmm_load, addr);
}
} else if (is_ld_tail) {
if (is_superset(brg.isa_impl, avx512_core)) {
uni_vmovups(vmm_load, addr);
Expand Down
3 changes: 2 additions & 1 deletion src/cpu/x64/cpu_isa_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,8 @@ inline data_type_t get_mac_emu_data_type(const data_type_t data_type,
using namespace data_type;
if (req_emulation) switch (data_type) {
case bf16:
if (isa == avx2_vnni_2) return f32;
if (utils::one_of(isa, avx2, avx2_vnni_2, avx512_core))
return f32;
break;
case f16:
if (utils::one_of(isa, avx2, avx2_vnni_2, avx512_core,
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/inputs/brgemm/test_brgemm_bf16
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
--reset

--dt=bf16,bf16:bf16:f32
--dt=bf16,bf16:bf16:f32,f32:bf16:f32
--bia_dt=undef,f32,bf16
--beta=0,1
--attr-post-ops=,sum:2,relu
Expand Down

0 comments on commit de434eb

Please sign in to comment.