Skip to content

Commit c2755ae

Browse files
committed
cpu: x64: pooling: support acc mode in max pooling backprop for bf16
Do not use an f32 accumulator in jit_uni_pooling for max pooling backward propagation with bf16 when the 'relaxed' or 'any' accumulation mode is specified. In tests, use a zero error threshold for max pooling only when the 'strict' or 'f32' accumulation mode is specified; otherwise use the epsilon-based threshold.
1 parent 5382c81 commit c2755ae

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

src/cpu/x64/jit_uni_pool_kernel.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,9 @@ status_t jit_uni_pool_kernel<isa>::init_conf(
367367
}
368368
assert(jpp.ur > 0);
369369

370-
jpp.needs_f32_accum_for_bf16 = jpp.is_bf16
370+
const bool is_relaxed_acc = utils::one_of(
371+
attr.acc_mode_, accumulation_mode::relaxed, accumulation_mode::any);
372+
jpp.needs_f32_accum_for_bf16 = !is_relaxed_acc && jpp.is_bf16
371373
&& jpp.alg == alg_kind::pooling_max && jpp.is_backward
372374
&& (jpp.stride_d < jpp.kd || jpp.stride_h < jpp.kh
373375
|| jpp.stride_w < jpp.kw);

tests/benchdnn/inputs/pool/test_pool_bfloat16

+11
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,14 @@
2222

2323
--attr-post-ops=add:bf16,linear:0.5:-1
2424
--batch=set_all_small
25+
26+
# Backward propagation without f32 accumulator
27+
--attr-post-ops=
28+
29+
--alg=max
30+
--tag=axb,aBx8b,aBx16b
31+
32+
--dir=BWD_D
33+
--attr-acc-mode=relaxed
34+
--batch=set_all
35+
--batch=set_topologies

tests/benchdnn/pool/pool.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,13 @@ bool cuda_check_correctness(const prb_t *prb,
192192

193193
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
194194
const args_t &ref_args) {
195+
const bool is_strict_acc
196+
= prb->attr.acc_mode == dnnl_accumulation_mode_strict
197+
|| prb->attr.acc_mode == dnnl_accumulation_mode_f32;
195198
// Threshold to compensate division error. CPU could live with 6.f coeff.
196-
const float trh
197-
= prb->alg == alg_t::max ? 0.f : 10.f * epsilon_dt(prb->dt[1]);
199+
const float trh = (prb->alg == alg_t::max && is_strict_acc)
200+
? 0.f
201+
: 10.f * epsilon_dt(prb->dt[1]);
198202
cmp.set_threshold(trh);
199203
// Backward may have most zeroes for ker_in_pad with huge kernels problems.
200204
const float zero_percent = (prb->dir & FLAG_FWD) ? 99.f : 100.f;

0 commit comments

Comments
 (0)