cpu: x64: matmul: add f32:f16 support on avx512_core and avx2
dzarukin committed Jan 6, 2025
1 parent de434eb commit 8373e71
Showing 6 changed files with 102 additions and 37 deletions.
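From the user's side, the commit enables a matmul with an f32 source/destination and f16 weights on CPUs without native f16 compute ISAs (avx512_core, avx2). Below is a minimal sketch of such a call, assuming the oneDNN v3.x C++ API; the shapes, layouts, and the absence of bias and attributes are illustrative and not taken from the patch:

#include <unordered_map>
#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    // f32 source/destination with f16 weights -- the combination this commit
    // enables on avx512_core and avx2 (illustrative problem sizes).
    const memory::dim M = 64, K = 512, N = 128;
    memory::desc src_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::f16, memory::format_tag::ab);
    memory::desc dst_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

    matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
    matmul prim(pd);

    memory src(src_md, eng), wei(wei_md, eng), dst(dst_md, eng);
    prim.execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei},
                               {DNNL_ARG_DST, dst}});
    strm.wait();
    return 0;
}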
12 changes: 6 additions & 6 deletions src/cpu/matmul/ref_matmul.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -61,8 +61,8 @@ struct ref_matmul_t : public primitive_t {
f8_e4m3, f4_e2m1, f4_e3m0),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_MATMUL((src_type == wei_type
|| utils::one_of(wei_type, u8, s8, u4, s4,
f4_e3m0)),
|| utils::one_of(wei_type, f16, u8, s8, u4,
s4, f4_e3m0)),
VERBOSE_UNSUPPORTED_DT);
/* int8 weights decompression support */
VDISPATCH_MATMUL(IMPLICATION(utils::one_of(wei_type, u8, s8),
@@ -82,10 +82,10 @@ struct ref_matmul_t : public primitive_t {
utils::one_of(
bia_type, f32, bf16, f16, f8_e5m2, f8_e4m3)
&& IMPLICATION(
src_type == f32, bia_type == f32)
&& IMPLICATION(src_type == f16,
wei_type == f32, bia_type == f32)
&& IMPLICATION(wei_type == f16,
utils::one_of(bia_type, f32, f16))
&& IMPLICATION(src_type == bf16,
&& IMPLICATION(wei_type == bf16,
utils::one_of(bia_type, f32, bf16))
// TODO: any implication on allowed bias
// data type for fp8?
16 changes: 12 additions & 4 deletions src/cpu/x64/matmul/brgemm_matmul.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2024 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -60,6 +60,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
= everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32);
const bool is_f16
= everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32);
const bool is_f32_f16
= src_dt == f32 && wei_dt == f16 && one_of(dst_dt, f16, f32);
const bool is_bf16_with_int_wei = src_dt == bf16
&& one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, bf16, f32);
const bool is_f16_with_int_wei = src_dt == f16
@@ -117,8 +119,9 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {

auto check_attr_zero_points
= [&]() -> bool { return attr()->zero_points_.common(); };
const bool problem_dt_correct = one_of(true, is_int8, is_f8, is_bf16,
is_f32, is_f16, is_bf16_with_int_wei, is_f16_with_int_wei);
const bool problem_dt_correct
= one_of(true, is_int8, is_f8, is_bf16, is_f32, is_f16, is_f32_f16,
is_bf16_with_int_wei, is_f16_with_int_wei);

auto src_d = memory_desc_wrapper(src_md_);
auto weights_d = memory_desc_wrapper(weights_md_);
@@ -156,6 +159,11 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
CHECK(init_brgemm_matmul_conf(isa, bgmmc_, *desc(), src_md_, weights_md_,
dst_md_, bias_md_, attr_));

// f32:f16 configuration on AVX2 doesn't support tails with proper
// instruction sequence in copy routines. Anchor: F32_F16_AVX2_NO_TAIL.
VDISPATCH_MATMUL(IMPLICATION(is_f32_f16 && isa == avx2, bgmmc_.N % 8 == 0),
"unsupported configuration");

const float alpha = 1.0;
const float beta = 1.0;
const float beta_init = 0.0;
@@ -171,7 +179,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
// non-amx isa. s8s8 problem type is an exception to avoid compensations
// processing for tail kernel
const auto backup_isa = is_amx && bgmmc_.is_runtime_M && !is_s8s8
? (is_f16 || is_f16_with_int_wei
? (is_f16 || is_f32_f16 || is_f16_with_int_wei
? avx512_core_fp16
: (is_bf16 || is_bf16_with_int_wei
? avx512_core_bf16
37 changes: 30 additions & 7 deletions src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2024 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -544,7 +544,7 @@ struct jit_brgemm_matmul_copy_a_transposed_impl_t
, m_loop_dst_shift(columns_step * dst_stride)
, k_loop_src_shift(rows_step * src_stride)
, k_loop_dst_shift(rows_step * tr_typesize)
, is_f32(everyone_is(data_type::f32, conf_->src_dt, conf_->wei_dt))
, is_f32(conf_->src_dt == data_type::f32)
, is_bf32(conf_->is_bf32)
, is_dynamic_src_ld(conf_->is_runtime_M)
// See the note in `create_brgemm_matmul_copy_b` why `orig_src_dt` used.
@@ -3390,7 +3390,13 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::load_data(

switch (dt_in_) {
case data_type::f32: uni_vmovups(vmm, op); break;
case data_type::f16: vcvtph2psx(vmm, op); break;
case data_type::f16:
if (is_superset(conf_->isa, avx512_core_fp16)) {
vcvtph2psx(vmm, op);
} else {
vcvtph2ps(vmm, op);
}
break;
case data_type::s8: uni_vpmovsxbd(vmm, op); break;
case data_type::u8: uni_vpmovzxbd(vmm, op); break;
// For int4, we see two int4 as one int8 and extend them int32
@@ -3593,7 +3599,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t
, avx512_core_dot_product_(
do_compute_compensation_ && !isa_has_int8_vnni(conf->isa))
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt` used.
, use_fp16_instructions_(conf_->isa == avx512_core_fp16
, use_fp16_instructions_(is_subset(conf_->isa, avx512_core_fp16)
&& conf_->orig_wei_dt == data_type::f16
&& conf_->wei_dt == data_type::f32)
, max_tmp_idx(16
@@ -3987,7 +3993,11 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
vcvtdq2ps(src_load, src_load);
maybe_apply_scales(src_reg, i * scales_K_stride_, is_tail);
} else if (use_fp16_instructions_) {
vcvtph2psx(src_load, addr);
if (conf_->isa == avx512_core_fp16) {
vcvtph2psx(src_load, addr);
} else {
vcvtph2ps(src_load, addr);
}
} else {
vmovdqu8(src_load, addr);
}
@@ -4131,6 +4141,7 @@ void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
if (!nrows) return;

const int columns_tail = ncolumns % k_blk_step_;

auto load = [this, nrows, columns_tail](int i) {
auto vmm_src = src_vmm(i);

@@ -4149,11 +4160,23 @@
uni_vpxor(vmm_src, vmm_src, vmm_src);
return;
}

if (columns_tail > 0) {
load_bytes(vmm_src, reg_src, i * src_stride_,
columns_tail * typesize_);
} else
uni_vmovups(vmm_src, ptr[reg_src + i * src_stride_]);
if (use_fp16_instructions_) {
// For f32:f16 case need to convert raw bytes after `load_bytes`
// into f32 values.
vcvtph2ps(vmm_src, Xmm(vmm_src.getIdx()));
}
} else {
if (use_fp16_instructions_) {
// For non-tailed case can use the convert instruction directly.
vcvtph2ps(vmm_src, ptr[reg_src + i * src_stride_]);
} else {
uni_vmovups(vmm_src, ptr[reg_src + i * src_stride_]);
}
}

L(load_done);
};
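For reference, the f16-to-f32 loads added in this file reduce to an ISA-dependent instruction choice: vcvtph2psx where avx512_core_fp16 is available, plain vcvtph2ps (F16C) on avx512_core and avx2. The following is a minimal standalone Xbyak-style sketch with a hypothetical generator name and illustrative register/pointer choices, not code from the patch:

#include "xbyak/xbyak.h"

// Hypothetical generator mirroring the conversion choice made in the copy
// routines above.
struct f16_to_f32_load_demo_t : Xbyak::CodeGenerator {
    explicit f16_to_f32_load_demo_t(bool has_avx512_fp16) {
        if (has_avx512_fp16)
            vcvtph2psx(ymm0, ptr[rdi]); // AVX512-FP16 form
        else
            vcvtph2ps(ymm0, ptr[rdi]); // F16C form: 8 f16 values -> 8 f32 values
        ret();
    }
};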
56 changes: 44 additions & 12 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2024 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -190,6 +190,11 @@ status_t check_isa_with_datatype(
&& IMPLICATION(bm_conf_utils.is_f16(),
one_of(isa, avx512_core_amx_fp16, avx512_core_fp16,
avx2_vnni_2))
// `avx512_core_amx_fp16` is not supported for plain upconversion
// as HW supports native compute.
&& IMPLICATION(bm_conf_utils.is_f32_f16(),
one_of(isa, avx512_core_fp16, avx2_vnni_2, avx512_core,
avx2))
&& IMPLICATION(bm_conf_utils.is_int8_with_bf16_dst(),
is_superset(isa, avx512_core) || isa == avx2_vnni_2)
&& IMPLICATION(bm_conf_utils.is_bf16_with_int_wei(),
@@ -202,12 +207,12 @@
}

status_t check_datatype_cfg(const brgemm_matmul_conf_utils_t &bm_conf_utils) {
const bool ok
= one_of(true, bm_conf_utils.is_f32(), bm_conf_utils.is_bf16(),
bm_conf_utils.is_f16(), bm_conf_utils.is_bf32(),
bm_conf_utils.is_f8(), bm_conf_utils.is_int8(),
bm_conf_utils.is_bf16_with_int_wei(),
bm_conf_utils.is_f16_with_int_wei())
const bool ok = one_of(true, bm_conf_utils.is_f32(),
bm_conf_utils.is_bf16(), bm_conf_utils.is_f16(),
bm_conf_utils.is_f32_f16(), bm_conf_utils.is_bf32(),
bm_conf_utils.is_f8(), bm_conf_utils.is_int8(),
bm_conf_utils.is_bf16_with_int_wei(),
bm_conf_utils.is_f16_with_int_wei())
&& IMPLICATION(bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei(),
bm_conf_utils.with_weights_decompression());
@@ -242,6 +247,10 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
&& attr.fpmath_.apply_to_int_)
, bf16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == bf16
&& one_of(bgmmc.dst_dt, bf16, f32))
// Keep this var separate from f16_dt to not slip f16:f16 on avx512_core and
// avx2 as there's no kernel for such combination.
, f32_f16_dt(bgmmc.src_dt == f32 && bgmmc.wei_dt == f16
&& one_of(bgmmc.dst_dt, f16, f32))
, f16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == f16
&& one_of(bgmmc.dst_dt, f16, f32))
, A_any_layout(A_any_layout)
@@ -362,7 +371,8 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
= this->is_int8() && is_superset(bgmmc.isa, avx512_core);
const bool is_adbc_allowed
= (this->is_bf16() || this->is_f32() || this->is_bf32()
|| this->is_f16() || this->is_bf16_with_int_wei()
|| this->is_f16() || this->is_f32_f16()
|| this->is_bf16_with_int_wei()
|| this->is_f16_with_int_wei())
&& !xf16_avx2_vnni_2;
bgmmc.src_tag = is_adbc_allowed
@@ -465,8 +475,10 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
}

if (this->is_bf16() || this->is_bf16_with_int_wei()
|| ((this->is_f16() || this->is_f16_with_int_wei())
&& bgmmc.isa != avx512_core_fp16))
|| ((this->is_f16() || this->is_f32_f16()
|| this->is_f16_with_int_wei())
&& (is_superset(bgmmc.isa, avx512_core_amx)
|| is_superset(bgmmc.isa, avx2_vnni_2))))
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
@@ -476,7 +488,7 @@
}
// Note: bf32 assumes f32 blocking
if (this->is_f32() || this->is_bf32() || this->is_f16()
|| this->is_f16_with_int_wei())
|| this->is_f32_f16() || this->is_f16_with_int_wei())
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
@@ -730,7 +742,7 @@ void compute_blocking_heuristic_amx(const brgemm_matmul_conf_t &bgmmc,
= div_up(static_cast<int>(bgmmc.K), min_k_per_thread);
const bool is_amx_xf16 = bgmmc.is_amx
&& (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
|| bm_conf_utils.is_bf32()
|| bm_conf_utils.is_f32_f16() || bm_conf_utils.is_bf32()
|| bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei());
const bool is_amx_int8 = bgmmc.is_amx && bm_conf_utils.is_int8();
@@ -1284,6 +1296,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.is_bf32 = bm_conf_utils.is_bf32();
bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei();
bgmmc.is_f16_with_int_wei = bm_conf_utils.is_f16_with_int_wei();
bgmmc.is_f32_f16 = bm_conf_utils.is_f32_f16();
bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression();
bgmmc.is_int4_weights = one_of(bgmmc.wei_dt, data_type::s4, data_type::u4);

@@ -1301,6 +1314,15 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
} else if (bm_conf_utils.is_f32_f16() && is_superset(bgmmc.isa, avx2)) {
// Note 1: Keep this branch separately from f16 one to have different
// ISA conditions (f16 includes f16:f32 and f16:f16 combinations).
// Note 2: If `use_buffer_b()` is false, let the kernel perform the
// conversion. Otherwise, make the copy_b routine handle the conversion
// and set kernel data types to f32.
// Note 3: Since `use_buffer_b()` depends on `bgmmc.wei_tag`, which is
// set later in the code due to its dependencies, the update of data
// types to f32 happens below in ANCHOR: `CONVERT_F32_F16_DATA_TYPES`.
} else if (bgmmc.is_f16_with_int_wei && bgmmc.isa != avx512_core_fp16) {
bgmmc.src_dt = f16;
bgmmc.wei_dt = f16;
@@ -1427,6 +1449,15 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
&& bgmmc.is_oscale_per_k && bgmmc.is_oscale_per_n
&& bgmmc.transposed_B;

if (bm_conf_utils.is_f32_f16() && is_superset(bgmmc.isa, avx2)
&& bm_conf_utils.use_buffer_b()) {
// ANCHOR: `CONVERT_F32_F16_DATA_TYPES`
bgmmc.src_dt = f32;
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
}

// int4 weights decompression only supports plain and transpose layouts
// TODO: enable int4 reorder and extend support to blocked weights
// layout when needed
@@ -1629,6 +1660,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16);

if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
|| bm_conf_utils.is_f32_f16()
|| bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei()) {
// empirical observation for performance breakpoint between amx and vnni
11 changes: 6 additions & 5 deletions src/cpu/x64/matmul/brgemm_matmul_utils.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2024 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -190,6 +190,7 @@ struct brgemm_matmul_conf_t {
bool is_bf32 = false;
bool is_bf16_with_int_wei = false;
bool is_f16_with_int_wei = false;
bool is_f32_f16 = false;
bool is_int4_weights = false;
bool req_wei_vnni_downconvert = false;
bool is_runtime_M = false;
@@ -230,10 +231,8 @@ struct brgemm_matmul_conf_utils_t {
inline bool use_buffer_b(bool use_heuristic = true) const {
if (bgmmc.is_runtime_N) return true;
if (bgmmc.is_bf16_with_int_wei) return true;
if (bgmmc.is_f16_with_int_wei) return true;
if (bgmmc.apply_scales_in_buffer_b) return true;
if (utils::one_of(true, bgmmc.is_runtime_N, bgmmc.is_bf16_with_int_wei,
bgmmc.is_f16_with_int_wei, bgmmc.apply_scales_in_buffer_b))
return true;

if (bgmmc.is_amx)
// use b_buffer for AMX when:
@@ -302,6 +301,8 @@ struct brgemm_matmul_conf_utils_t {

inline bool is_bf16_with_int_wei() const { return bf16_with_int_wei_dt; }

inline bool is_f32_f16() const { return f32_f16_dt; }

inline bool is_f16_with_int_wei() const { return f16_with_int_wei_dt; }

inline bool with_weights_decompression() const {
@@ -339,7 +340,7 @@ struct brgemm_matmul_conf_utils_t {
brgemm_matmul_conf_t &bgmmc;

const bool f32_dt, bf16_dt, f16_dt, f8_dt, int8_dt, bf32_dt;
const bool weights_decompression_support, bf16_with_int_wei_dt,
const bool weights_decompression_support, bf16_with_int_wei_dt, f32_f16_dt,
f16_with_int_wei_dt;

const bool A_any_layout;
7 changes: 4 additions & 3 deletions tests/benchdnn/inputs/matmul/test_matmul_float16
@@ -1,7 +1,8 @@
# f16
--reset

--dt=f16:f16:f32,f16
--skip-impl=ref
--dt=f16:f16:f32,f16,f32:f16:f32
--stag=ab,ba --wtag=ab,ba --dtag=ab
--runtime_dims_masks=0,2:1,1:0,3:1
--bia_dt=undef,f32 --bia_mask=2
@@ -28,13 +29,13 @@

# test any
--reset
--dt=f16:f16:f32,f16
--dt=f16:f16:f32,f16,f32:f16:f32
--stag=ab,ba,any --wtag=ab,ba,any --dtag=ab,any
--batch=shapes_2d

# 3d
--reset
--dt=f16:f16:f32,f16
--dt=f16:f16:f32,f16,f32:f16:f32
--stag=abc,acb --wtag=abc,acb --dtag=abc
--bia_dt=undef,f32 --bia_mask=4,6
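
# A standalone reproducer for the new coverage could look like the
# hypothetical batch entry below (problem descriptor and tags are
# illustrative, not part of the patch).
--reset
--dt=f32:f16:f32
--stag=ab --wtag=ba --dtag=ab
64x512:512x128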

