Skip to content

Commit

Permalink
cpu: x64: pooling: support acc mode in max pooling backprop for bf16
Browse files Browse the repository at this point in the history
Do not use an f32 accumulator in jit_uni_pooling for max pooling
backpropagation with bf16 when the 'relaxed' or 'any' accumulation
mode is specified.
In tests, use a zero error threshold for max pooling only when the
'strict' or 'f32' accumulation mode is specified; otherwise use the
epsilon-based threshold.
  • Loading branch information
asimonov1 committed Jan 26, 2025
1 parent f1bb9e5 commit f5d5a1d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/cpu/x64/jit_uni_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,9 @@ status_t jit_uni_pool_kernel<isa>::init_conf(
}
assert(jpp.ur > 0);

jpp.needs_f32_accum_for_bf16 = jpp.is_bf16
const bool is_relaxed_acc = utils::one_of(
attr.acc_mode_, accumulation_mode::relaxed, accumulation_mode::any);
jpp.needs_f32_accum_for_bf16 = !is_relaxed_acc && jpp.is_bf16
&& jpp.alg == alg_kind::pooling_max && jpp.is_backward
&& (jpp.stride_d < jpp.kd || jpp.stride_h < jpp.kh
|| jpp.stride_w < jpp.kw);
Expand Down
11 changes: 11 additions & 0 deletions tests/benchdnn/inputs/pool/test_pool_bfloat16
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,14 @@

--attr-post-ops=add:bf16,linear:0.5:-1
--batch=set_all_small

# Backward propagation without f32 accumulator
--attr-post-ops=

--alg=max
--tag=axb,aBx8b,aBx16b

--dir=BWD_D
--attr-acc-mode=relaxed
--batch=set_all
--batch=set_topologies
8 changes: 6 additions & 2 deletions tests/benchdnn/pool/pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,13 @@ bool cuda_check_correctness(const prb_t *prb,

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
const args_t &ref_args) {
const bool is_strict_acc
= prb->attr.acc_mode == dnnl_accumulation_mode_strict
|| prb->attr.acc_mode == dnnl_accumulation_mode_f32;
// Threshold to compensate division error. CPU could live with 6.f coeff.
const float trh
= prb->alg == alg_t::max ? 0.f : 10.f * epsilon_dt(prb->dt[1]);
const float trh = (prb->alg == alg_t::max && is_strict_acc)
? 0.f
: 10.f * epsilon_dt(prb->dt[1]);
cmp.set_threshold(trh);
// Backward may have most zeroes for ker_in_pad with huge kernels problems.
const float zero_percent = (prb->dir & FLAG_FWD) ? 99.f : 100.f;
Expand Down

0 comments on commit f5d5a1d

Please sign in to comment.