diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp index b3f18bdca01..43b33052201 100644 --- a/src/common/memory_tracking.hpp +++ b/src/common/memory_tracking.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * Copyright 2024-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -271,6 +271,7 @@ enum { key_pool_dst_plain2blocked_cvt, key_pool_ind_plain2blocked_cvt, key_pool_src_bf16cvt, + key_pool_src_f32_accum, key_pool_src_plain2blocked_cvt, key_pool_reduction, key_precomputed_scales, diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp index f3bf75e4572..9c471b45bab 100644 --- a/src/cpu/x64/jit_primitive_conf.hpp +++ b/src/cpu/x64/jit_primitive_conf.hpp @@ -526,6 +526,8 @@ struct jit_pool_conf_t { bool with_binary; int nthr; memory_desc_t tmp_md; + bool needs_f32_accum_for_bf16; + dim_t f32_accum_block_size; }; struct jit_pool_call_s { diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp index 8d35ff6dfd9..e7478829cd7 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.cpp +++ b/src/cpu/x64/jit_uni_pool_kernel.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * Copyright 2018 YANDEX LLC * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,6 +18,7 @@ #include #include "common/dnnl_thread.hpp" +#include "common/memory_desc.hpp" #include "cpu/cpu_pooling_pd.hpp" #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" @@ -422,6 +423,29 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, * nscr); } + jpp.needs_f32_accum_for_bf16 = jpp.is_bf16 + && jpp.alg == alg_kind::pooling_max && jpp.is_backward + && (jpp.stride_d < jpp.kd || jpp.stride_h < jpp.kh + || jpp.stride_w < jpp.kw); + jpp.f32_accum_block_size = jpp.ur_bc * jpp.c_block; + if (jpp.needs_f32_accum_for_bf16) { + auto tmp_d = memory_desc_wrapper(jpp.tmp_md); + assert(tmp_d.is_zero() + && (fmt_tag == nspc_fmt_tag || fmt_tag == blocked_fmt_tag)); + + dims_t dims {}; + utils::array_copy(dims, src_d.dims(), ndims); + + const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); + dims[0] = nstl::min(dnnl_get_max_threads(), jpp.mb * nb2_c); + dims[1] = jpp.f32_accum_block_size; + + memory_desc_init_by_tag( + jpp.tmp_md, ndims, dims, data_type::f32, fmt_tag); + + scratchpad.book(key_pool_src_f32_accum, tmp_d.size()); + } + jpp.post_ops = attr.post_ops_; return status::success; @@ -511,15 +535,22 @@ inline void jit_uni_pool_kernel::load(const int idx, else if (jpp.src_dt == data_type::f8_e4m3) f8_e4m3_emu_->vcvt_f8_to_f32(vmm_to_load, ptr[reg_ptr + offset]); } else { - if (is_c_tail_proccessing && !jpp.is_c_padded) { - if (isa == avx || isa == avx2) { - vmaskmovps(Vmm(idx), vmm_c_tail_mask, ptr[reg_ptr + offset]); - } else { - vmovups(Zmm(idx) | k_c_tail_mask | T_z, ptr[reg_ptr + offset]); - } + load32(idx, reg_ptr, offset, is_c_tail_proccessing); + } +} + +template +inline void jit_uni_pool_kernel::load32(const int idx, + const reg64_t ®_ptr, const int offset, + const bool is_c_tail_proccessing) { + if (is_c_tail_proccessing && !jpp.is_c_padded) { + if (isa == avx || isa == avx2) { + vmaskmovps(Vmm(idx), vmm_c_tail_mask, ptr[reg_ptr + offset]); } else { - uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]); + vmovups(Zmm(idx) | 
k_c_tail_mask | T_z, ptr[reg_ptr + offset]); } + } else { + uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]); } } @@ -551,6 +582,13 @@ inline void jit_uni_pool_kernel::load(const int idx, assert(!"invalid data type"); } +template <> +inline void jit_uni_pool_kernel::load32(const int idx, + const reg64_t ®_ptr, const int offset, + const bool is_c_tail_proccessing) { + assert(!"invalid data type"); +} + template <> inline void jit_uni_pool_kernel::load(const int idx, const reg64_t ®_ptr, const int offset, @@ -562,6 +600,13 @@ inline void jit_uni_pool_kernel::load(const int idx, uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]); } +template <> +inline void jit_uni_pool_kernel::load32(const int idx, + const reg64_t ®_ptr, const int offset, + const bool is_c_tail_proccessing) { + assert(!"invalid data type"); +} + template inline void jit_uni_pool_kernel::store(const int idx, const reg64_t ®_ptr, const int offset, @@ -585,27 +630,33 @@ inline void jit_uni_pool_kernel::store(const int idx, } else vmovdqu8(yword[reg_ptr + offset], Xmm(idx)); } else { - if (is_c_tail_proccessing) { - if (!jpp.is_c_padded) { - if (isa == avx || isa == avx2) - vmaskmovps( - ptr[reg_ptr + offset], vmm_c_tail_mask, Vmm(idx)); - else - vmovups(ptr[reg_ptr + offset] | k_c_tail_mask, Zmm(idx)); - } else { - if (jpp.with_postops) { - if (isa == avx || isa == avx2) { - uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1); - uni_vblendvps( - Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask); - } else - uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx)); - } - uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx)); + store32(idx, reg_ptr, offset, is_c_tail_proccessing); + } +} + +template +inline void jit_uni_pool_kernel::store32(const int idx, + const reg64_t ®_ptr, const int offset, + const bool is_c_tail_proccessing) { + if (is_c_tail_proccessing) { + if (!jpp.is_c_padded) { + if (isa == avx || isa == avx2) + vmaskmovps(ptr[reg_ptr + offset], vmm_c_tail_mask, Vmm(idx)); + else + vmovups(ptr[reg_ptr + offset] | k_c_tail_mask, Zmm(idx)); + } else { + if (jpp.with_postops) { + if (isa == avx || isa == avx2) { + uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1); + uni_vblendvps( + Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask); + } else + uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx)); } - } else uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx)); - } + } + } else + uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx)); } template <> @@ -626,6 +677,13 @@ inline void jit_uni_pool_kernel::store(const int idx, assert(!"datatype not supported"); } +template <> +inline void jit_uni_pool_kernel::store32(const int idx, + const reg64_t ®_ptr, const int offset, + const bool is_c_tail_proccessing) { + assert(!"datatype not supported"); +} + template <> inline void jit_uni_pool_kernel::store(const int idx, const reg64_t ®_ptr, const int offset, @@ -664,6 +722,13 @@ inline void jit_uni_pool_kernel::store(const int idx, uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx)); } +template <> +inline void jit_uni_pool_kernel::store32(const int idx, + const reg64_t ®_ptr, const int offset, + const bool is_c_tail_proccessing) { + assert(!"datatype not supported"); +} + template bool jit_uni_pool_kernel::post_ops_ok(jit_pool_conf_t &jpp, const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) { @@ -1237,8 +1302,20 @@ inline void jit_uni_pool_kernel::max_step_bwd(int ur_w, int ur_bc, int kw = jpp.kw; int stride_w = jpp.stride_w; int c_block = jpp.c_block; - const int c_off + const int output_c_off = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? 
jpp.c : c_block; + const int input_c_off = jpp.needs_f32_accum_for_bf16 + ? jpp.f32_accum_block_size + : output_c_off; + const size_t input_dt_size + = jpp.needs_f32_accum_for_bf16 ? sizeof(float) : jpp.dt_size; + const auto store_input_fun = jpp.needs_f32_accum_for_bf16 + ? &jit_uni_pool_kernel::store32 + : &jit_uni_pool_kernel::store; + const auto load_input_fun = jpp.needs_f32_accum_for_bf16 + ? &jit_uni_pool_kernel::load32 + : &jit_uni_pool_kernel::load; + Label kd_label, kh_label; const auto is_tail_processing = [&](int bc) { @@ -1256,9 +1333,9 @@ inline void jit_uni_pool_kernel::max_step_bwd(int ur_w, int ur_bc, for_(int jj = 0; jj < ur_w; jj++) for (int bci = 0; bci < ur_bc; bci++) { const auto outr_i = reg_ind(0, bci, jj, ur_bc, ur_w); - auto out_offset = jpp.dt_size * (jj * c_off + bci * c_block); + auto out_offset = jpp.dt_size * (jj * output_c_off + bci * c_block); load(reg_idx(outr_i), reg_output, out_offset, is_tail_processing(bci)); - const size_t step_index = (jj * c_off + bci * c_block) + const size_t step_index = (jj * output_c_off + bci * c_block) * types::data_type_size(jpp.ind_dt); const auto indr_i = reg_ind(1, bci, jj, ur_bc, ur_w); @@ -1334,12 +1411,12 @@ inline void jit_uni_pool_kernel::max_step_bwd(int ur_w, int ur_bc, const auto inpr_i = reg_ind(2, bci, jj, ur_bc, ur_w); const auto inpvr = vreg(inpr_i); const auto cvtvr = vreg(reg_ind(3, bci, jj, ur_bc, ur_w)); - int aux_inp_offset - = (ki + jj * stride_w - pad_l) * c_off + bci * c_block; - if (aux_inp_offset >= iw * c_off) continue; - int inp_offset = jpp.dt_size * aux_inp_offset; - load(reg_idx(inpr_i), aux_reg_input, inp_offset, - is_tail_processing(bci)); + int aux_inp_offset = (ki + jj * stride_w - pad_l) * input_c_off + + bci * c_block; + if (aux_inp_offset >= iw * input_c_off) continue; + int inp_offset = input_dt_size * aux_inp_offset; + (this->*load_input_fun)(reg_idx(inpr_i), aux_reg_input, + inp_offset, is_tail_processing(bci)); if (isa == sse41) { mov(dst_ptr, aux_reg_input); add(dst_ptr, inp_offset); @@ -1377,7 +1454,7 @@ inline void jit_uni_pool_kernel::max_step_bwd(int ur_w, int ur_bc, vpcmpeqd(k_store_mask, indvr, vmm_k_offset); vblendmps(vmm_tmp | k_store_mask | T_z, outvr, outvr); vaddps(inpvr, inpvr, vmm_tmp); - if (jpp.is_bf16) { + if (jpp.is_bf16 && !jpp.needs_f32_accum_for_bf16) { if (!isa_has_bf16(jpp.isa)) bf16_emu_->vcvtneps2bf16(indyr, indzr); else @@ -1385,8 +1462,8 @@ inline void jit_uni_pool_kernel::max_step_bwd(int ur_w, int ur_bc, } else if (jpp.is_f16) { vcvtps2ph(indyr, inpvr, _op_mxcsr); } - store(inpvr.getIdx(), aux_reg_input, inp_offset, - is_tail_processing(bci)); + (this->*store_input_fun)(inpvr.getIdx(), aux_reg_input, + inp_offset, is_tail_processing(bci)); } } @@ -1404,13 +1481,13 @@ inline void jit_uni_pool_kernel::max_step_bwd(int ur_w, int ur_bc, if (with_c_tail_proccessing && (isa == avx || isa == avx2)) pop_vmm_val(vmm_c_tail_mask.getIdx()); } - add(aux_reg_input, jpp.dt_size * iw * c_off); + add(aux_reg_input, input_dt_size * iw * input_c_off); inc(kj); cmp(kj, reg_kh); jl(kh_label, T_NEAR); } if (jpp.simple_alg && jpp.ndims == 5) { - add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_off); + add(aux_reg_input_d, input_dt_size * jpp.ih * iw * input_c_off); mov(tmp_gpr, reg_kd_pad_shift); uni_vmovq(xmm_tmp, tmp_gpr); @@ -1438,9 +1515,10 @@ inline void jit_uni_pool_kernel::max_step_bwd(int ur_w, int ur_bc, template void jit_uni_pool_kernel::zero_diff_src( int ur_bc, bool with_c_tail_proccessing) { - const int c_off = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) - ? 
jpp.c - : jpp.c_block; + const int c_off = jpp.needs_f32_accum_for_bf16 + ? jpp.f32_accum_block_size + : ((jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c + : jpp.c_block); Label l_skip, l_ih_loop, l_id_loop; @@ -1461,7 +1539,12 @@ void jit_uni_pool_kernel::zero_diff_src( Vmm vzero = vmm_tmp; uni_vpxor(vzero, vzero, vzero); - const int width_size = jpp.iw * c_off * jpp.dt_size; + const size_t dt_size + = jpp.needs_f32_accum_for_bf16 ? sizeof(float) : jpp.dt_size; + const int width_size = jpp.iw * c_off * dt_size; + const auto store_fun = jpp.needs_f32_accum_for_bf16 + ? &jit_uni_pool_kernel::store32 + : &jit_uni_pool_kernel::store; auto aux_reg_zero_ptr = tmp_gpr; @@ -1472,30 +1555,30 @@ void jit_uni_pool_kernel::zero_diff_src( L(l_ih_loop); { const auto vlen = cpu_isa_traits::vlen; - const int step = c_off * jpp.dt_size; + const int step = c_off * dt_size; // TODO: maybe a big code generated here for_(int i = 0; i < width_size; i += step) for (int bci = 0; bci < ur_bc; bci++) { - const int offs = i + bci * jpp.c_block * jpp.dt_size; + const int offs = i + bci * jpp.c_block * dt_size; if (isa == sse41) { bool is_needed_c_tail_processing = false; if (is_tail_processing(bci) && jpp.c_tail < (jpp.c_block / 2)) is_needed_c_tail_processing = true; - store(vzero.getIdx(), reg_zero_ptr, offs, + (this->*store_fun)(vzero.getIdx(), reg_zero_ptr, offs, is_needed_c_tail_processing); if (!is_tail_processing(bci) || (is_tail_processing(bci) && (jpp.is_c_padded || jpp.c_tail > (jpp.c_block / 2)))) { - store(vzero.getIdx(), reg_zero_ptr, offs + vlen, - is_tail_processing(bci)); + (this->*store_fun)(vzero.getIdx(), reg_zero_ptr, + offs + vlen, is_tail_processing(bci)); } } else { - store(vzero.getIdx(), reg_zero_ptr, offs, + (this->*store_fun)(vzero.getIdx(), reg_zero_ptr, offs, is_tail_processing(bci)); } } @@ -1526,11 +1609,17 @@ void jit_uni_pool_kernel::generate() { int c_block = jpp.c_block; int stride_w = jpp.stride_w; int l_pad = jpp.l_pad; - const int c_off + const int output_c_off = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c : c_block; + const int input_c_off = jpp.needs_f32_accum_for_bf16 + ? jpp.f32_accum_block_size + : output_c_off; int vlen = cpu_isa_traits::vlen; + const size_t input_dt_size + = jpp.needs_f32_accum_for_bf16 ? sizeof(float) : jpp.dt_size; + #if defined(_WIN32) // Always mimic the Unix ABI (see the note about maskmovdqu in the header // file). @@ -1587,15 +1676,17 @@ void jit_uni_pool_kernel::generate() { if (!inc_reg) return; - auto dt_size = jpp.dt_size; + auto output_dt_size = jpp.dt_size; auto shift = (isa == sse41) ? vlen : 0; add(reg_input, - dt_size * nstl::max(0, ur_w * stride_w - lpad) * c_off - shift); - add(reg_output, dt_size * ur_w * c_off - shift); + input_dt_size * nstl::max(0, ur_w * stride_w - lpad) + * input_c_off + - shift); + add(reg_output, output_dt_size * ur_w * output_c_off - shift); if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) { auto ishift = (isa == sse41) ? 
jpp.c_block / 2 : 0; auto ind_dt_size = types::data_type_size(jpp.ind_dt); - add(reg_index, (ur_w * c_off - ishift) * ind_dt_size); + add(reg_index, (ur_w * output_c_off - ishift) * ind_dt_size); } }; diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp index 3f4f0eb96a6..70f470989d3 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.hpp +++ b/src/cpu/x64/jit_uni_pool_kernel.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * Copyright 2018 YANDEX LLC * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -163,8 +163,12 @@ struct jit_uni_pool_kernel : public jit_generator { void pop_vmm_val(const int idx); void load(const int idx, const reg64_t ®_ptr, const int offset, const bool is_c_tail_proccessing); + void load32(const int idx, const reg64_t ®_ptr, const int offset, + const bool is_c_tail_proccessing); void store(const int idx, const reg64_t ®_ptr, const int offset, const bool is_c_tail_proccessing); + void store32(const int idx, const reg64_t ®_ptr, const int offset, + const bool is_c_tail_proccessing); void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r, bool with_c_tail_proccessing); diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp index 01810218be9..d06f36a7390 100644 --- a/src/cpu/x64/jit_uni_pooling.cpp +++ b/src/cpu/x64/jit_uni_pooling.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017 - 2024 Intel Corporation +* Copyright 2017 - 2025 Intel Corporation * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -483,6 +483,152 @@ class bwd_pooling_transpose_facade_t const dim_t c_tail_; }; +class bwd_f32_accum_for_bf16_t { +public: + using value_type = typename prec_traits::type; + + bwd_f32_accum_for_bf16_t(const jit_pool_conf_t &jpp, const exec_ctx_t &ctx); + + value_type *get_addr_2d(int ithr, dim_t ih) const { + return blk_data(ithr, 0, ih, 0); + } + + value_type *get_addr_3d(int ithr, dim_t id, dim_t ih) const { + return blk_data(ithr, 0, id, ih, 0); + } + + void zero_data(int ithr); + + void cvt_to_bf16_slice_2d(int ithr, bfloat16_t *dst, + memory_desc_wrapper const &dst_d, dim_t n, dim_t b_c, + dim_t ur_bc) const; + + void cvt_to_bf16_slice_3d(int ithr, bfloat16_t *dst, + memory_desc_wrapper const &dst_d, dim_t n, dim_t b_c, + dim_t ur_bc) const; + +private: + template + value_type *blk_data(Args... 
args) const { + assert(wsp_); + return wsp_ + accum_d_.blk_off(std::forward(args)...); + } + + const jit_pool_conf_t &jpp_; + value_type *wsp_ {nullptr}; + memory_desc_wrapper accum_d_ {nullptr}; +}; + +bwd_f32_accum_for_bf16_t::bwd_f32_accum_for_bf16_t( + const jit_pool_conf_t &jpp, const exec_ctx_t &ctx) + : jpp_ {jpp} { + if (jpp_.needs_f32_accum_for_bf16) { + accum_d_ = memory_desc_wrapper(jpp_.tmp_md); + auto &scratchpad = ctx.get_scratchpad_grantor(); + wsp_ = scratchpad.template get( + memory_tracking::names::key_pool_src_f32_accum); + assert(wsp_); + } +} + +void bwd_f32_accum_for_bf16_t::zero_data(int ithr) { + auto *data = blk_data(ithr); + memset(data, 0, + jpp_.tmp_md.format_desc.blocking.strides[0] * sizeof(value_type)); +} + +void bwd_f32_accum_for_bf16_t::cvt_to_bf16_slice_2d(int ithr, bfloat16_t *dst, + memory_desc_wrapper const &dst_d, dim_t n, dim_t b_c, + dim_t ur_bc) const { + + assert(wsp_ && (jpp_.ndims == 3 || jpp_.ndims == 4) + && (jpp_.tag_kind == jit_memory_tag_kind_t::nspc + || jpp_.tag_kind == jit_memory_tag_kind_t::blocked)); + + if (jpp_.tag_kind == jit_memory_tag_kind_t::nspc) { + if (jpp_.tmp_md.dims[1] == jpp_.c && b_c == 0 + && jpp_.c == ur_bc * jpp_.c_block) { + // all channels + const size_t nelems = jpp_.ih * jpp_.iw * jpp_.c; + auto *cur_src = blk_data(ithr); + auto *cur_dst = dst + dst_d.blk_off(n); + cvt_float_to_bfloat16(cur_dst, cur_src, nelems); + } else { + const auto c_b = jpp_.c_block * b_c; + const auto c_e = nstl::min( + static_cast(jpp_.c), jpp_.c_block * (b_c + ur_bc)); + + if (c_b >= c_e) return; + + const size_t nelems = c_e - c_b; + if (jpp_.ndims == 4) { + for (dim_t h = 0; h < jpp_.ih; ++h) { + for (dim_t w = 0; w < jpp_.iw; ++w) { + auto *cur_src = blk_data(ithr, 0, h, w); + auto *cur_dst = dst + dst_d.blk_off(n, c_b, h, w); + cvt_float_to_bfloat16(cur_dst, cur_src, nelems); + } + } + } else { + for (dim_t w = 0; w < jpp_.iw; ++w) { + auto *cur_src = blk_data(ithr, 0, w); + auto *cur_dst = dst + dst_d.blk_off(n, c_b, w); + cvt_float_to_bfloat16(cur_dst, cur_src, nelems); + } + } + } + } else if (jpp_.tag_kind == jit_memory_tag_kind_t::blocked) { + assert(ur_bc == 1); + + const size_t nelems = jpp_.ih * jpp_.iw * jpp_.c_block; + auto *src_b = blk_data(ithr); + auto *dst_b = dst + dst_d.blk_off(n, b_c); + cvt_float_to_bfloat16(dst_b, src_b, nelems); + } +} + +void bwd_f32_accum_for_bf16_t::cvt_to_bf16_slice_3d(int ithr, bfloat16_t *dst, + memory_desc_wrapper const &dst_d, dim_t n, dim_t b_c, + dim_t ur_bc) const { + + assert(wsp_ && jpp_.ndims == 5 + && (jpp_.tag_kind == jit_memory_tag_kind_t::nspc + || jpp_.tag_kind == jit_memory_tag_kind_t::blocked)); + + if (jpp_.tag_kind == jit_memory_tag_kind_t::blocked) { + assert(ur_bc == 1); + const size_t nelems = jpp_.id * jpp_.ih * jpp_.iw * jpp_.c_block; + auto *src_b = blk_data(ithr); + auto *dst_b = dst + dst_d.blk_off(n, b_c); + cvt_float_to_bfloat16(dst_b, src_b, nelems); + } else if (jpp_.tag_kind == jit_memory_tag_kind_t::nspc) { + if (jpp_.tmp_md.dims[1] == jpp_.c && b_c == 0 + && jpp_.c == ur_bc * jpp_.c_block) { + // all channels + const size_t nelems = jpp_.id * jpp_.ih * jpp_.iw * jpp_.c; + cvt_float_to_bfloat16( + dst + dst_d.blk_off(n), blk_data(ithr), nelems); + } else { + const auto c_b = jpp_.c_block * b_c; + const auto c_e = nstl::min( + static_cast(jpp_.c), jpp_.c_block * (b_c + ur_bc)); + + if (c_b >= c_e) return; + + const size_t nelems = c_e - c_b; + for (dim_t id = 0; id < jpp_.id; ++id) { + for (dim_t h = 0; h < jpp_.ih; ++h) { + for (dim_t w = 0; w < jpp_.iw; ++w) { + 
auto *cur_src = blk_data(ithr, 0, id, h, w); + auto *cur_dst = dst + dst_d.blk_off(n, c_b, id, h, w); + cvt_float_to_bfloat16(cur_dst, cur_src, nelems); + } + } + } + } + } +} + } // namespace jit_uni_pooling_utils template @@ -906,6 +1052,8 @@ void jit_uni_pooling_bwd_t::execute_backward( diff_dst_d, indices_d, wsp_dt_, diff_src, diff_dst, indices, ctx); + bwd_f32_accum_for_bf16_t f32_accum(jpp, ctx); + auto get_first_ih = [&](int oh) { return nstl::min(nstl::max(oh * jpp.stride_h - jpp.t_pad, 0), jpp.ih); }; @@ -924,6 +1072,8 @@ void jit_uni_pooling_bwd_t::execute_backward( const auto c_off = jpp.is_plain() ? b_c * jpp.c_block : b_c; if (transpose_facade.should_transpose_src()) arg.src = transpose_facade.get_src_addr(ithr, ih, jpp); + else if (jpp.needs_f32_accum_for_bf16) + arg.src = f32_accum.get_addr_2d(ithr, ih); else arg.src = &diff_src[diff_src_d.blk_off(n, c_off, ih)]; @@ -950,6 +1100,8 @@ void jit_uni_pooling_bwd_t::execute_backward( if (transpose_facade.should_transpose_src()) arg.zero_ptr = transpose_facade.get_src_addr(ithr, zero_ih_start, jpp); + else if (jpp.needs_f32_accum_for_bf16) + arg.zero_ptr = f32_accum.get_addr_2d(ithr, zero_ih_start); else arg.zero_ptr = &diff_src[diff_src_d.blk_off(n, c_off, zero_ih_start, 0)]; @@ -978,6 +1130,10 @@ void jit_uni_pooling_bwd_t::execute_backward( if (transpose_facade.should_transpose_src()) transpose_facade.execute_transpose_output(ithr, n, b_c); + + if (jpp.needs_f32_accum_for_bf16) + f32_accum.cvt_to_bf16_slice_2d( + ithr, (bfloat16_t *)diff_src, diff_src_d, n, b_c, ur_bc); }; const int nthr = jpp.nthr; @@ -1029,6 +1185,12 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( const auto trans_src = transpose_facade.should_transpose_src(); const auto trans_dst = transpose_facade.should_transpose_dst(); + bwd_f32_accum_for_bf16_t f32_accum(jpp, ctx); + + const size_t input_dt_size = jpp.needs_f32_accum_for_bf16 + ? 
sizeof(bwd_f32_accum_for_bf16_t::value_type) + : jpp.dt_size; + auto get_last_ih = [&](int oh) { return nstl::min( nstl::max(oh * jpp.stride_h - jpp.t_pad + jpp.kh, 0), jpp.ih); @@ -1056,6 +1218,8 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( if (trans_src) arg.src = transpose_facade.get_src_addr_3d(ithr, id + kd, ih, jpp); + else if (jpp.needs_f32_accum_for_bf16) + arg.src = f32_accum.get_addr_3d(ithr, id + kd, ih); else arg.src = (const void *)&diff_src[diff_src_d.blk_off( n, c_off, id + kd, ih)]; @@ -1091,6 +1255,9 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( if (trans_src) arg.zero_ptr = transpose_facade.get_src_addr_3d( ithr, zero_id_start, zero_ih_start, jpp); + else if (jpp.needs_f32_accum_for_bf16) + arg.zero_ptr = f32_accum.get_addr_3d( + ithr, zero_id_start, zero_ih_start); else arg.zero_ptr = &diff_src[diff_src_d.blk_off( n, c_off, zero_id_start, zero_ih_start, 0)]; @@ -1135,18 +1302,34 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( const int nthr = jpp.nthr; if (jpp.simple_alg) { + const dim_t nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); + if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { - const dim_t nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - parallel_nd( - jpp.mb, jpp.od, nb2_c, [&](dim_t n, dim_t od, dim_t b2_c) { - const dim_t b_c = b2_c * jpp.ur_bc; - const dim_t ur_bc - = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c); - process_simple(n, b_c, od, ur_bc, first_ithr); - }); + if (!jpp.needs_f32_accum_for_bf16) { + parallel_nd(jpp.mb, jpp.od, nb2_c, + [&](dim_t n, dim_t od, dim_t b2_c) { + const dim_t b_c = b2_c * jpp.ur_bc; + const dim_t ur_bc = nstl::min( + dim_t(jpp.ur_bc), jpp.nb_c - b_c); + process_simple(n, b_c, od, ur_bc, first_ithr); + }); + } else { + parallel_nd_ext(nthr, jpp.mb, nb2_c, + [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b2_c) { + const dim_t b_c = b2_c * jpp.ur_bc; + const dim_t ur_bc = nstl::min( + dim_t(jpp.ur_bc), jpp.nb_c - b_c); + for (int od = 0; od < jpp.od; ++od) { + process_simple(n, b_c, od, ur_bc, ithr); + } + f32_accum.cvt_to_bf16_slice_3d(ithr, + (bfloat16_t *)diff_src, diff_src_d, n, b_c, + ur_bc); + }); + } } else { assert(jpp.ur_bc == 1); - if (trans_src || trans_dst) { + if (trans_src || trans_dst || jpp.needs_f32_accum_for_bf16) { parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) { if (trans_src) @@ -1158,6 +1341,10 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( if (trans_dst) transpose_facade.execute_transpose_output( ithr, n, b_c); + if (jpp.needs_f32_accum_for_bf16) + f32_accum.cvt_to_bf16_slice_3d(ithr, + (bfloat16_t *)diff_src, diff_src_d, n, + b_c, 1); }); } else { parallel_nd(jpp.mb, jpp.nb_c, jpp.od, @@ -1168,31 +1355,35 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( } } else { const data_t zero_val = 0; - if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { - const size_t chunk_size = (size_t)jpp.ih * jpp.iw * jpp.c; - parallel_nd(jpp.mb, jpp.id, [&](dim_t n, dim_t id) { - const size_t offset = ((size_t)n * jpp.id + id) * chunk_size; - PRAGMA_OMP_SIMD() - for (size_t idx = 0; idx < chunk_size; ++idx) - diff_src[offset + idx] = zero_val; - }); - } else { - if (!trans_src) { - const size_t chunk_size - = (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block; - parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, - [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) { - const size_t offset - = ((size_t)n * jpp.nb_c + b_c) * chunk_size; - PRAGMA_OMP_SIMD() - for (size_t idx = 0; idx < chunk_size; ++idx) - diff_src[offset + idx] = zero_val; - }); + if (!jpp.needs_f32_accum_for_bf16) { + if 
(jpp.tag_kind == jit_memory_tag_kind_t::nspc) { + const size_t chunk_size = (size_t)jpp.ih * jpp.iw * jpp.c; + parallel_nd(jpp.mb, jpp.id, [&](dim_t n, dim_t id) { + const size_t offset + = ((size_t)n * jpp.id + id) * chunk_size; + PRAGMA_OMP_SIMD() + for (size_t idx = 0; idx < chunk_size; ++idx) + diff_src[offset + idx] = zero_val; + }); + } else { + if (!trans_src) { + const size_t chunk_size + = (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block; + parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, + [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) { + const size_t offset + = ((size_t)n * jpp.nb_c + b_c) + * chunk_size; + PRAGMA_OMP_SIMD() + for (size_t idx = 0; idx < chunk_size; ++idx) + diff_src[offset + idx] = zero_val; + }); + } } } const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - if (trans_src || trans_dst) { + if (trans_src || trans_dst || jpp.needs_f32_accum_for_bf16) { parallel_nd_ext(nthr, jpp.mb, nb2_c, [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b2_c) { const dim_t b_c = b2_c * jpp.ur_bc; @@ -1202,16 +1393,19 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( ithr, n, b_c); size_t block_size = jpp.c_block * jpp.id * jpp.ih - * jpp.iw * jpp.dt_size; + * jpp.iw * input_dt_size; const void *src = transpose_facade.get_src_addr_3d( ithr, 0, 0, jpp); std::memset((void *)src, zero_val, block_size); } + if (jpp.needs_f32_accum_for_bf16) + f32_accum.zero_data(ithr); + + const dim_t ur_bc + = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c); for (dim_t kd = 0; kd < jpp.kd; ++kd) { - const dim_t ur_bc = nstl::min( - dim_t(jpp.ur_bc), jpp.nb_c - b_c); for (int od = 0; od < jpp.od; ++od) { const dim_t ik = static_cast(od) * jpp.stride_d; @@ -1236,6 +1430,11 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( if (trans_src) transpose_facade.execute_transpose_output( ithr, n, b_c); + + if (jpp.needs_f32_accum_for_bf16) + f32_accum.cvt_to_bf16_slice_3d(ithr, + (bfloat16_t *)diff_src, diff_src_d, n, b_c, + ur_bc); }); } else { for (dim_t kd = 0; kd < jpp.kd; ++kd) {
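
Note on the approach taken by this patch: when bf16 max-pooling backward has overlapping windows (any stride smaller than the corresponding kernel dimension), the diff routes the diff_src accumulation through a per-thread f32 scratch buffer (key_pool_src_f32_accum, accessed by the new load32/store32 helpers) and converts back to bf16 only once per slice via cvt_to_bf16_slice_2d/3d (cvt_float_to_bfloat16). The stand-alone sketch below illustrates why the single final conversion matters; it is not part of the patch, and round_to_bf16() is a simplified stand-in for oneDNN's bf16 conversion (NaN handling omitted).

// toy_bf16_accum.cpp -- why diff_src is accumulated in f32 for overlapping windows.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float to the nearest bfloat16 value (round-to-nearest-even).
static float round_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    const uint32_t lsb = (u >> 16) & 1u;
    u += 0x7fffu + lsb; // RNE on the 16 mantissa bits that get truncated
    u &= 0xffff0000u;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main() {
    // With stride < kernel, many diff_dst elements can point (via argmax) at
    // the same diff_src element; model that as n_contrib gradients of 1.0f
    // landing on a single location.
    const int n_contrib = 300;

    // (a) read-modify-write directly in bf16: every partial sum is re-rounded,
    // and the accumulator gets stuck at 256 (the next representable step is 2).
    float acc_bf16 = 0.f;
    for (int i = 0; i < n_contrib; ++i)
        acc_bf16 = round_to_bf16(acc_bf16 + 1.0f);

    // (b) what the patch does: accumulate in an f32 scratch buffer and convert
    // to bf16 once at the end (cf. cvt_to_bf16_slice_2d/3d).
    float acc_f32 = 0.f;
    for (int i = 0; i < n_contrib; ++i)
        acc_f32 += 1.0f;
    const float final_bf16 = round_to_bf16(acc_f32);

    std::printf("bf16 accumulation: %g, f32 accumulation: %g\n",
            acc_bf16, final_bf16); // prints 256 vs 300
    return 0;
}

The bf16 accumulator loses every contribution once the running sum reaches 256, while accumulating in f32 and rounding once preserves the total; avoiding that drift for overlapping pooling windows is what the per-thread f32 scratchpad and the final cvt_float_to_bfloat16 pass provide.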