Fix accuracy of max-pooling backpropagation for bfloat16 data #2386

Open · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions src/common/memory_tracking.hpp
@@ -271,6 +271,7 @@ enum {
key_pool_dst_plain2blocked_cvt,
key_pool_ind_plain2blocked_cvt,
key_pool_src_bf16cvt,
key_pool_src_f32_accum,
key_pool_src_plain2blocked_cvt,
key_pool_reduction,
key_precomputed_scales,
2 changes: 2 additions & 0 deletions src/cpu/x64/jit_primitive_conf.hpp
@@ -526,6 +526,8 @@ struct jit_pool_conf_t {
bool with_binary;
int nthr;
memory_desc_t tmp_md;
bool needs_f32_accum_for_bf16;
Contributor:

Please add support for acc_mode (a separate commit is totally fine), which would allow preserving the former behavior (accumulation in bf16) and avoid issues like the one recently reported for softmax.
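For context, a minimal standalone sketch (not part of this PR; the values and loop count are made up) of what "accumulation in bf16" means for max-pooling backward: when several overlapping windows scatter-add their gradients into the same diff_src element, rounding the running sum to bf16's 8-bit mantissa after every addition can drop small updates entirely, which is exactly what an f32 accumulator avoids.

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

// Round a float to the nearest bf16 value (round-to-nearest-even) and back,
// emulating what storing each partial sum as bf16 does to the accumulator.
static float round_to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    uint32_t lsb = (bits >> 16) & 1;
    bits += 0x7fff + lsb; // round to nearest even on the upper 16 bits
    bits &= 0xffff0000u;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

int main() {
    // Pretend 8 overlapping pooling windows all selected the same source
    // element as their maximum, each contributing a small gradient.
    const float grad = 0.0078125f; // 2^-7, exactly representable in bf16
    float base = 64.f;             // a pre-existing large partial sum
    float acc_f32 = base;          // f32 accumulator path
    float acc_bf16 = base;         // bf16 accumulation path
    for (int i = 0; i < 8; ++i) {
        acc_f32 += grad;
        acc_bf16 = round_to_bf16(acc_bf16 + grad); // update rounds away to 64
    }
    std::printf("f32 accumulation:  %g\n", acc_f32);  // 64.0625
    std::printf("bf16 accumulation: %g\n", acc_bf16); // 64
    return 0;
}
```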

Author:

Yes, this is in progress.

Author (@asimonov1), Jan 21, 2025:

I added partial support for acc_mode. The jit_uni_pooling max-pooling backpropagation for bf16 data in the axb and aBx16b/aBx8b formats falls back to the old implementation (without the f32 accumulator) when the 'relaxed' or 'any' accumulation mode is specified. The available modes are 'strict', 'relaxed', 'any', 'f32', 's32', and 'f16'; 's32' and 'f16' are simply ignored, while 'f32' and 'strict' allow the new implementation (the f32 accumulator is not needed when the strides are larger than the corresponding kernel sizes). A user-facing sketch is shown after this comment.

benchdnn is updated to use a zero error threshold for max-pooling with the 'strict' or 'f32' accumulation modes.

If this approach is OK, the docs will need to be updated; they currently look incorrect/incomplete (https://oneapi-src.github.io/oneDNN/dev_guide_pooling.html). For example, in the abx format case our jit_uni_pooling implementation converts inputs/outputs to/from f32 arrays, so its accumulation mode is effectively always 'strict'. The 'relaxed' mode is not always faster than 'strict' but uses less memory. The f64 data type can be used on GPUs only (?).

I also noticed that the GPU version is out of scope (?) of MFDNN-11050 and MFDNN-11396; I did not test it.
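For reference, a minimal sketch of the user-facing side, assuming the acc_mode plumbing described above; the shapes, format tag, and sizes are illustrative only, and the calls use the public oneDNN C++ attribute API (primitive_attr::set_accumulation_mode):

```cpp
#include "dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);

    // Illustrative bf16 NHWC (axb) tensors: 56x56 -> 28x28 max-pooling,
    // 3x3 kernel, stride 2, no dilation.
    memory::dims src_dims = {32, 64, 56, 56}, dst_dims = {32, 64, 28, 28};
    memory::dims kernel = {3, 3}, strides = {2, 2}, dilation = {0, 0};
    memory::dims pad_l = {0, 0}, pad_r = {1, 1};
    auto src_md = memory::desc(
            src_dims, memory::data_type::bf16, memory::format_tag::nhwc);
    auto dst_md = memory::desc(
            dst_dims, memory::data_type::bf16, memory::format_tag::nhwc);

    // 'relaxed' / 'any' would keep the former bf16 accumulation path;
    // 'strict' / 'f32' select the f32-accumulator path added in this PR.
    primitive_attr attr;
    attr.set_accumulation_mode(accumulation_mode::relaxed);

    auto fwd_pd = pooling_forward::primitive_desc(eng,
            prop_kind::forward_training, algorithm::pooling_max, src_md,
            dst_md, strides, kernel, dilation, pad_l, pad_r);
    auto bwd_pd = pooling_backward::primitive_desc(eng,
            algorithm::pooling_max, src_md /* diff_src */,
            dst_md /* diff_dst */, strides, kernel, dilation, pad_l, pad_r,
            fwd_pd, attr);
    (void)bwd_pd;
    return 0;
}
```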

dim_t f32_accum_block_size;
};

struct jit_pool_call_s {
832 changes: 313 additions & 519 deletions src/cpu/x64/jit_uni_pool_kernel.cpp

Large diffs are not rendered by default.

55 changes: 20 additions & 35 deletions src/cpu/x64/jit_uni_pool_kernel.hpp
@@ -27,14 +27,13 @@
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
#include "cpu/x64/utils/jit_io_helper.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

struct bf16_emulation_t;

template <cpu_isa_t isa>
struct jit_uni_pool_kernel : public jit_generator {

@@ -45,10 +44,12 @@ struct jit_uni_pool_kernel : public jit_generator {

DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel)

static status_t init_conf(jit_pool_conf_t &jbp,
memory_tracking::registrar_t &scratchpad, primitive_attr_t &attr,
static status_t init_conf(jit_pool_conf_t &jpp, primitive_attr_t &attr,
const pooling_pd_t *ppd);

static void init_scratchpad(const jit_pool_conf_t &jpp,
memory_tracking::registrar_t &scratchpad);

private:
using Xmm = Xbyak::Xmm;
using Ymm = Xbyak::Ymm;
@@ -79,9 +80,8 @@ struct jit_uni_pool_kernel : public jit_generator {
Ymm ymm_tmp_1 = Ymm(0);
Vmm vmm_tmp_1 = Vmm(0);

// Used only for avx and if c tail is present
// Used only for avx and if c tail is present; is shared with jit_io_multi_dt_helper_t
Vmm vmm_c_tail_mask = Vmm(2);
Xmm xmm_c_tail_mask = Xmm(2);

Vmm vmm_ker_area_h = Vmm(2);
Vmm vmm_one = Vmm(2);
@@ -90,14 +90,6 @@

Vmm vmm_k_offset = Vmm(1);

// Used only for avx512 when bf16 is present
inline Vmm vmm_idx() {
if (!jpp.is_backward) {
return (jpp.is_training) ? Vmm(4) : Vmm(1);
} else
return Vmm(4);
}

Zmm bf16_emu_reserv_1 = Zmm(5);
Zmm bf16_emu_reserv_2 = Zmm(6);
Zmm bf16_emu_reserv_3 = Zmm(7);
@@ -112,33 +104,23 @@
Reg64 fp8_emu_reg64 = bf16_emu_reserv_4;
Xbyak::Opmask fp8_tmp_mask = Xbyak::Opmask(3);

Opmask k_c_tail_mask = Opmask(4);
Opmask k_mask_cvt = Opmask(5);
Opmask k_store_mask = Opmask(6);

// Here be some (tame) dragons. This kernel does not follow the regular
// OS-agnostic ABI pattern because when isa is sse41 it uses maskmovdqu
// instruction which has its destination hardcoded in rdi. Therefore:
// - all registers are hardcoded
// - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI
//
// While this is only required by the backward pass, the quirk above
// is applied to the forward pass as well to keep things simpler.
Opmask k_c_tail_mask = Opmask(
4); // is shared with jit_io_multi_dt_helper_t and jit_uni_postops_injector_t
Opmask k_store_mask = Opmask(5);

using reg64_t = const Reg64;
reg64_t reg_param = rdi; // Always mimic the Unix ABI
reg64_t reg_param = abi_param1;
reg64_t reg_input = r8;
reg64_t aux_reg_input = r9;
reg64_t reg_index = r10;
reg64_t reg_output = r12;
reg64_t reg_kd_pad_shift = r13;
reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu

reg64_t kj = r14;
reg64_t oi_iter = r15;
reg64_t reg_kh = rax;
reg64_t reg_k_shift = rbx;
reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above
reg64_t tmp_gpr = abi_not_param1;
reg64_t reg_ker_area_h = rdx;
reg64_t reg_nbc = rsi;

Expand All @@ -156,15 +138,18 @@ struct jit_uni_pool_kernel : public jit_generator {

int prev_kw;

void prepare_tail_mask();
void put_one_in_vmm();
void uni_broadcast_reg_val(const int reg_idx, const int vmm_idx);
void push_vmm_val(const int idx);
void pop_vmm_val(const int idx);
void load(const int idx, const reg64_t &reg_ptr, const int offset,
const bool is_c_tail_proccessing);
void store(const int idx, const reg64_t &reg_ptr, const int offset,
const bool is_c_tail_proccessing);
void load(const data_type_t dt, const int idx, const reg64_t &reg_ptr,
const int offset, const bool is_c_tail_proccessing);
void store(const data_type_t dt, const int idx, const reg64_t &reg_ptr,
const int offset, const bool is_c_tail_proccessing);
void pad_with_zeros(int idx);
void load_indices(int indr_i, int step_index, bool is_c_tail_processing);
void store_indices(int indr_i, int step_index, bool is_c_tail_processing,
bool is_first_w_block);

void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r,
bool with_c_tail_proccessing);
@@ -269,11 +254,11 @@ struct jit_uni_pool_kernel : public jit_generator {
return jpp.is_fp8 && is_superset(isa, avx512_core_fp16);
}

std::unique_ptr<bf16_emulation_t> bf16_emu_;
std::unique_ptr<fp8_emulation_e5m2_t> f8_e5m2_emu_;
std::unique_ptr<fp8_emulation_e4m3_t> f8_e4m3_emu_;
std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
postops_injector_;
io::jit_io_multi_dt_helper_t<Vmm> io_;
};

} // namespace x64