Fix accuracy of max-pooling backpropagation for bfloat16 data #2386
base: main
Conversation
@@ -367,6 +368,11 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
    }
    assert(jpp.ur > 0);

    jpp.needs_f32_accum_for_bf16 = jpp.is_bf16
            && jpp.alg == alg_kind::pooling_max && jpp.is_backward
I don't think the issue is limited to the max algorithm; accumulation on backward happens for all algorithms... We will need to adjust the threshold in benchdnn for lower data types, but that's not tied to this PR.
Yes, the average pooling algorithms show a similar loss of accuracy. The scope of the task is defined in MFDNN-11050 and MFDNN-11396 (actually, MFDNN-11050 should be reopened, as its fix was rolled back).
src/cpu/x64/jit_uni_pool_kernel.cpp
Outdated
@@ -18,6 +18,7 @@
 #include <bitset>

 #include "common/dnnl_thread.hpp"
 #include "common/memory_desc.hpp"
This one should already come with "cpu/cpu_pooling_pd.hpp", so it's not needed as a standalone include.
Removed.
    memory_desc_init_by_tag(
            jpp.tmp_md, ndims, dims, data_type::f32, fmt_tag);

    scratchpad.book<char>(key_pool_src_f32_accum, tmp_d.size());
Please consider improving the code by detaching an init_scratchpad function out of init_conf, as it's done in most places elsewhere.
Done.
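For reference, a minimal sketch of what such a detached helper could look like, built from the booking call quoted above and the usual oneDNN memory_tracking conventions; the exact placement and signature of the function are assumptions, not the code actually merged in this PR:

template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::init_scratchpad(const jit_pool_conf_t &jpp,
        memory_tracking::registrar_t &scratchpad) {
    using namespace memory_tracking::names;
    if (jpp.needs_f32_accum_for_bf16) {
        // Book the temporary f32 accumulation buffer, sized from the f32
        // memory descriptor prepared by init_conf().
        const memory_desc_wrapper tmp_d(jpp.tmp_md);
        scratchpad.book<char>(key_pool_src_f32_accum, tmp_d.size());
    }
}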
src/cpu/x64/jit_uni_pool_kernel.cpp
Outdated
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::load32(const int idx,
There are already a lot of existing routines to support loading. I highly recommend changing the loading/storing implementation to rely on the io_injector, which is more flexible.
Yes, the loading/storing functions need refactoring. I did not know how to do that properly, and I did not know about the io_injector; I have to investigate it. Could it be done as a separate task?
If doing it as a separate task, I definitely recommend refactoring the existing implementation first and then applying f32 accumulation together with acc_mode support.
The io injector is now used to load/store tensor data, but indices are processed the same way as before because the io injector does not support their data types: it converts integers to floats during loading. (It does not convert the data on store when they are stored as s32, though; s32 and f32 are stored by one common function, store_f32, which looks like a bug.)
@@ -483,6 +483,151 @@ class bwd_pooling_transpose_facade_t
    const dim_t c_tail_;
};

struct bwd_f32_accum_for_bf16_t {
Could you please help me understand why a new class is needed when the only change happens inside the kernel (using a different buffer and instructions to accumulate the inputs) and shouldn't be at the parallelization/balancing level?
I do not understand the idea. This class contains the implementation of copying data, with conversion to bf16, into the corresponding place of src_diff (the bwd_f32_accum_for_bf16_t::cvt_* functions).
Merge it with bwd_pooling_transpose_facade_t? The class bwd_pooling_transpose_facade_t supports the ncsp format: it reorders both tensors (src and dst) to a blocked format and then back to the original format, and it does not support the case ur_bc > 1 (only one block is processed per iteration). The class bwd_f32_accum_for_bf16_t is simpler: it only needs to convert dst (dst_diff) from f32 to bf16 without data reordering, but it supports the case ur_bc > 1 for nspc.
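Conceptually, the role of the cvt_* helpers can be sketched as below; this is an illustration only, assuming oneDNN's cvt_float_to_bfloat16 and threading helpers are used, and the function name and slicing scheme are hypothetical rather than the actual PR code:

#include "common/bfloat16.hpp"
#include "common/dnnl_thread.hpp"

// Illustration: convert the f32 accumulation buffer into the bf16 output
// tensor in parallel, without any layout reordering.
static void cvt_accum_to_bf16(dnnl::impl::bfloat16_t *out,
        const float *f32_accum, dnnl::impl::dim_t nelems, int nthr) {
    using namespace dnnl::impl;
    parallel(nthr, [&](int ithr, int nthr_) {
        dim_t start {0}, end {0};
        balance211(nelems, nthr_, ithr, start, end); // split work across threads
        if (start >= end) return;
        cvt_float_to_bfloat16(
                out + start, f32_accum + start, (size_t)(end - start));
    });
}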
@@ -526,6 +526,8 @@ struct jit_pool_conf_t {
    bool with_binary;
    int nthr;
    memory_desc_t tmp_md;
    bool needs_f32_accum_for_bf16;
Please add support for acc_mode (a separate commit is totally fine), which would allow preserving the former behavior (accumulation in bf16) and avoid issues like the one recently reported for softmax.
Yes, this is in progress.
I added partial support for acc_mode. The jit_uni_pooling implementation of max-pooling backpropagation for the axb and aBx16b/aBx8b layouts with bf16 data switches to the old implementation (without the f32 accumulator) if the 'relaxed' or 'any' accumulation mode is specified. We have strict, relaxed, any, f32, s32, and f16 modes; s32 and f16 are simply ignored, and the f32 accumulator is used for the strict and f32 modes (a sketch of this dispatch is given below).
benchdnn is updated to use a zero error threshold for max pooling with the strict and f32 accumulation modes.
If this approach is OK, then the docs need to be updated. It looks like the docs are not correct/complete (https://oneapi-src.github.io/oneDNN/dev_guide_pooling.html): for the abx format our jit_uni_pooling implementation converts inputs/outputs to/from f32 arrays, so its accumulation mode is actually always strict. The 'relaxed' mode is not necessarily faster than strict, but it uses less memory. The f64 data type can be used on GPUs only (?).
I also noticed that the GPU version is out of scope (?) of MFDNN-11050 and MFDNN-11396; I did not test it.
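For reference, a minimal sketch of that dispatch rule, written against the public dnnl_accumulation_mode_t enum; the helper name and parameter list are hypothetical, not the exact PR code:

#include "oneapi/dnnl/dnnl_types.h"

// Hypothetical helper illustrating the rule described above: the f32
// accumulator is used for bf16 backward max pooling unless the user asked
// for the 'relaxed' or 'any' accumulation mode.
static bool use_f32_accum_for_bf16(bool is_bf16, bool is_backward,
        bool is_max_pooling, dnnl_accumulation_mode_t acc_mode) {
    const bool relaxed_accum = acc_mode == dnnl_accumulation_mode_relaxed
            || acc_mode == dnnl_accumulation_mode_any;
    return is_bf16 && is_max_pooling && is_backward && !relaxed_accum;
}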
Do not use f32 accumulator in jit_uni_pooling for max pooling back propagation with bf16 if 'relaxed' or 'any' accumulation mode is specified. Use zero error threshold in tests for max pooling if 'strict' or 'f32' accumulation mode is specified.
[MFDNN-11050] (bf16 backward max pooling returns incorrect results)
[MFDNN-11396] (BF16 pooling_backward performance regression on SPR)
(Also [MFDNN-12863] (JIT max pool implementation works incorrectly for small data types and large kernels))
It was found earlier that bf16 backward max pooling returns incorrect results (MFDNN-11050). An initial accuracy fix led to a significant performance regression (MFDNN-11396), so it was rolled back.
The reason for the accuracy issue is that even a sum of relatively small numbers is not exact in bf16, e.g. bf16(256.0) + bf16(1.0) is bf16(256.0). Such summation takes place when some pooling strides are smaller than the corresponding kernel sizes.
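A standalone illustration of this effect (not oneDNN code; bf16 rounding is emulated on top of float by dropping the low 16 bits with round-to-nearest-even):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 keeps only 7 mantissa bits, so repeatedly accumulating small
// gradients into a bf16 value can silently drop them, while accumulating
// in f32 and converting once at the end keeps them.
static float round_to_bf16(float x) {
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    // round-to-nearest-even on the 16 bits that bf16 discards
    const std::uint32_t rounding_bias = 0x7fffu + ((bits >> 16) & 1u);
    bits = (bits + rounding_bias) & 0xffff0000u;
    float r;
    std::memcpy(&r, &bits, sizeof(r));
    return r;
}

int main() {
    // bf16 accumulation: every partial sum is rounded back to bf16
    float acc_bf16 = round_to_bf16(256.0f);
    acc_bf16 = round_to_bf16(acc_bf16 + round_to_bf16(1.0f)); // still 256
    acc_bf16 = round_to_bf16(acc_bf16 + round_to_bf16(1.0f)); // still 256

    // f32 accumulation: round only the final result
    const float acc_f32 = 256.0f + 1.0f + 1.0f; // 258, exactly representable in bf16

    std::printf("bf16 accumulation: %g\n", acc_bf16);               // prints 256
    std::printf("f32 accumulation : %g\n", round_to_bf16(acc_f32)); // prints 258
    return 0;
}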
The current fix uses additional f32 accumulation arrays, one per thread. The size of those arrays for src_diff is the same as in the existing ncsp implementation (the ncsp implementation creates f32 arrays for dst_diff, src_diff and indices, reorders the data, and uses those arrays during the calculations). The ncsp case is not affected by this PR.
I have done some manual measurements on a machine with an SPR processor. In some cases this implementation works faster than the original version, sometimes slower, but it is significantly faster than the non-optimized implementation (the one used after the first fix for MFDNN-11050).
The following tables contain performance data for the axb and aBx16b layouts for the original implementation (main branch), the fixed version (this PR), and the fallback implementation (the one used when the optimized implementation is skipped).
Sketch of the script used to run the tests:
axb
aBx16b
The current implementation also fixes the bug [MFDNN-12863] (JIT max pool implementation works incorrectly for small data types and large kernels).