diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp index 938e1629290..2ccb67a4b37 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.cpp +++ b/src/cpu/x64/jit_uni_pool_kernel.cpp @@ -17,9 +17,6 @@ #include -#include "common/dnnl_thread.hpp" - -#include "cpu/cpu_pooling_pd.hpp" #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" #include "cpu/x64/jit_avx512_core_fp8cvt.hpp" #include "cpu/x64/jit_uni_pool_kernel.hpp" @@ -34,11 +31,6 @@ using namespace alg_kind; #define GET_OFF(field) offsetof(jit_pool_call_s, field) -static bcast_set_t get_supported_bcast_strategies() { - return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc, - broadcasting_strategy_t::no_broadcast}; -} - template jit_uni_pool_kernel::~jit_uni_pool_kernel() = default; @@ -102,8 +94,8 @@ jit_uni_pool_kernel::jit_uni_pool_kernel( postop_tail, k_c_tail_mask, use_exact_tail_scalar_bcast}; const binary_injector::static_params_t bsp {reg_param, - get_supported_bcast_strategies(), rhs_sp, f8_e5m2_emu_.get(), - f8_e4m3_emu_.get()}; + jit_uni_pooling_utils::get_supported_bcast_strategies(), rhs_sp, + f8_e5m2_emu_.get(), f8_e4m3_emu_.get()}; postops_injector_ = utils::make_unique>( @@ -111,345 +103,6 @@ jit_uni_pool_kernel::jit_uni_pool_kernel( } } -static status_t set_binary_postops_formats( - post_ops_t &post_ops, const memory_desc_t *dst_md) { - for (int idx = 0; idx < post_ops.len(); ++idx) { - if (!post_ops.contain(primitive_kind::binary, idx)) continue; - - auto &src1_md = post_ops.entry_[idx].binary.src1_desc; - const memory_desc_wrapper src1_mdw(src1_md); - if (!src1_mdw.format_any()) { - if (src1_mdw.is_blocking_desc()) - continue; - else - return status::unimplemented; - } - - const memory_desc_wrapper dst_mdw(dst_md); - assert(!dst_mdw.format_any()); - - CHECK(memory_desc_init_by_blocking_desc( - src1_md, dst_mdw.blocking_desc())); - } - - return status::success; -} - -template -status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, - memory_tracking::registrar_t &scratchpad, primitive_attr_t &attr, - const pooling_pd_t *ppd) { - - const auto &pd = *ppd->desc(); - const memory_desc_wrapper src_d( - ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md()); - const memory_desc_wrapper dst_d( - ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md()); - - const int ndims = src_d.ndims(); - - jpp.nthr = dnnl_get_max_threads(); - jpp.is_training = pd.prop_kind == prop_kind::forward_training; - jpp.is_backward = pd.prop_kind == prop_kind::backward_data; - - jpp.id = (ndims == 5) ? src_d.dims()[2] : 1; - jpp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; - jpp.iw = src_d.dims()[ndims - 1]; - jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1; - jpp.ow = dst_d.dims()[ndims - 1]; - jpp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; - - const bool is_avx512 = is_superset(isa, avx512_core); - jpp.ndims = ndims; - jpp.mb = src_d.dims()[0]; - jpp.c_without_padding = src_d.dims()[1]; - jpp.c_block = is_avx512 ? 16 : 8; - - jpp.alg = pd.alg_kind; - - jpp.src_dt = jpp.is_backward ? pd.diff_src_desc.data_type - : pd.src_desc.data_type; - jpp.dst_dt = jpp.is_backward ? pd.diff_dst_desc.data_type - : pd.dst_desc.data_type; - - jpp.tmp_md = memory_desc_t(); - - jpp.is_bf16 = (src_d.data_type() == data_type::bf16 - && dst_d.data_type() == data_type::bf16); - jpp.is_f16 = (src_d.data_type() == data_type::f16 - && dst_d.data_type() == data_type::f16); - jpp.is_fp8 = utils::one_of(src_d.data_type(), data_type::f8_e5m2, - data_type::f8_e4m3) - && utils::one_of( - dst_d.data_type(), data_type::f8_e5m2, data_type::f8_e4m3); - - using namespace format_tag; - - const auto blocked_fmt_tag = is_avx512 - ? utils::pick(ndims - 3, nCw16c, nChw16c, nCdhw16c) - : utils::pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); - - // src_d.data_type() is equal to dst_d.data_type(). This is checked in init - auto ncsp_fmt_tag = format_tag::undef; - - const unsigned int L3_cache_size_per_core - = platform::get_per_core_cache_size(3); - const size_t block_size - = ((size_t)jpp.id * jpp.ih * jpp.iw + jpp.od * jpp.oh * jpp.ow) - * jpp.c_block * types::data_type_size(src_d.data_type()); - - const bool forward_ncsp_allowed = !jpp.is_backward - && jpp.c_without_padding > 3 - && ((jpp.ih > 1 && jpp.iw > 1 - && block_size <= L3_cache_size_per_core) - || utils::one_of(src_d.data_type(), data_type::bf16, - data_type::f16, data_type::f8_e5m2, - data_type::f8_e4m3)); - - const bool backward_ncsp_allowed = jpp.is_backward - && ((jpp.ih > 1 && jpp.iw > 1 && jpp.c_without_padding > 1 - && block_size <= L3_cache_size_per_core) - || (utils::one_of(src_d.data_type(), data_type::bf16, - data_type::f16) - && !(jpp.alg == pooling_max - && block_size > L3_cache_size_per_core))); - - ncsp_fmt_tag = ((forward_ncsp_allowed || backward_ncsp_allowed) && is_avx512 - && ndims <= 5) - ? utils::pick(ndims - 3, ncw, nchw, ncdhw) - : format_tag::undef; - - const auto nspc_fmt_tag = (ndims <= 5) - ? utils::pick(ndims - 3, nwc, nhwc, ndhwc) - : format_tag::undef; - - const auto fmt_tag = src_d.matches_one_of_tag( - blocked_fmt_tag, ncsp_fmt_tag, nspc_fmt_tag); - - VDISPATCH_POOLING_IC( - dst_d.matches_tag(fmt_tag), VERBOSE_UNSUPPORTED_TAG_S, "dst"); - - VDISPATCH_POOLING_IC( - post_ops_ok(jpp, attr, dst_d), VERBOSE_UNSUPPORTED_POSTOP); - - if (fmt_tag == ncsp_fmt_tag) { - // transform input to blocked f32, call f32 jit, transform result to - // plain output - jpp.is_bf16 = false; - jpp.is_f16 = false; - jpp.is_fp8 = false; - jpp.dt_size = types::data_type_size(data_type::f32); - jpp.tag_kind = jit_memory_tag_kind_t::ncsp; - - // used to initialize binary post-ops - if (ppd->is_fwd() && jpp.with_binary) { - CHECK(memory_desc_init_by_tag(jpp.tmp_md, ndims, dst_d.md_->dims, - data_type::f32, blocked_fmt_tag)); - } - } else { - jpp.dt_size = types::data_type_size(src_d.data_type()); - jpp.tag_kind = (fmt_tag == nspc_fmt_tag) - ? jit_memory_tag_kind_t::nspc - : jit_memory_tag_kind_t::blocked; - } - - if (ppd->is_fwd() && jpp.with_binary) { - CHECK(set_binary_postops_formats(attr.post_ops_, - jpp.tag_kind == jit_memory_tag_kind_t::ncsp ? &jpp.tmp_md - : dst_d.md_)); - } - - jpp.isa = (jpp.is_bf16 && mayiuse(avx512_core_bf16)) - ? avx512_core_bf16 - : ((jpp.is_fp8 && mayiuse(avx512_core_fp16)) ? avx512_core_fp16 - : isa); - - // disabling verbose dispatch messages for unsupported isa for - // better readability - if (!mayiuse(isa)) return status::unimplemented; - - VDISPATCH_POOLING_IC( - (fmt_tag != format_tag::undef), VERBOSE_UNSUPPORTED_TAG); - VDISPATCH_POOLING_IC(IMPLICATION(jpp.is_bf16, - utils::one_of(jpp.isa, avx512_core_bf16, - avx512_core, avx2_vnni_2)), - VERBOSE_ISA_DT_MISMATCH); - VDISPATCH_POOLING_IC( - IMPLICATION(jpp.is_f16, - utils::one_of(jpp.isa, avx512_core_fp16, avx2_vnni_2)), - VERBOSE_ISA_DT_MISMATCH); - VDISPATCH_POOLING_IC( - IMPLICATION(jpp.is_fp8, utils::one_of(jpp.isa, avx512_core_fp16)), - VERBOSE_ISA_DT_MISMATCH); - VDISPATCH_POOLING_IC( - utils::one_of(pd.alg_kind, pooling_max, pooling_avg_include_padding, - pooling_avg_exclude_padding), - VERBOSE_BAD_ALGORITHM); - - const bool is_xf16_avx2_vnni_2 - = (jpp.is_bf16 || jpp.is_f16) && isa == avx2_vnni_2; - // note: avx2_vnni_2 only supports nxc format - VDISPATCH_POOLING_IC(IMPLICATION(is_xf16_avx2_vnni_2, - jpp.tag_kind == jit_memory_tag_kind_t::nspc), - "isa, format tag mismatch"); - - // note: avx2_vnni_2 only supports FWD direction - VDISPATCH_POOLING_IC(IMPLICATION(is_xf16_avx2_vnni_2, !jpp.is_backward), - "isa, propagation kind mismatch"); - - jpp.c = jpp.tag_kind == jit_memory_tag_kind_t::blocked - ? utils::rnd_up(jpp.c_without_padding, jpp.c_block) - : jpp.c_without_padding; - if (jpp.tag_kind == jit_memory_tag_kind_t::blocked) - assert(src_d.padded_dims()[1] == jpp.c); - jpp.nb_c = utils::div_up(jpp.c, jpp.c_block); - jpp.c_tail = jpp.c_without_padding % jpp.c_block; - jpp.is_c_padded = jpp.tag_kind == jit_memory_tag_kind_t::blocked - && src_d.padded_dims()[1] != jpp.c_without_padding; - - jpp.stride_d = (ndims == 5) ? pd.strides[0] : 1; - jpp.stride_h = (ndims == 3) ? 1 : pd.strides[ndims - 4]; - jpp.stride_w = pd.strides[ndims - 3]; - jpp.kd = (ndims == 5) ? pd.kernel[0] : 1; - jpp.kh = (ndims == 3) ? 1 : pd.kernel[ndims - 4]; - jpp.kw = pd.kernel[ndims - 3]; - - jpp.f_pad = (ndims == 5) ? pd.padding[0][0] : 0; - jpp.t_pad = (ndims == 3) ? 0 : pd.padding[0][ndims - 4]; - jpp.l_pad = pd.padding[0][ndims - 3]; - - const int back_pad = calculate_end_padding( - jpp.f_pad, jpp.od, jpp.id, jpp.stride_d, jpp.kd); - const int bottom_pad = calculate_end_padding( - jpp.t_pad, jpp.oh, jpp.ih, jpp.stride_h, jpp.kh); - const int right_pad = calculate_end_padding( - jpp.l_pad, jpp.ow, jpp.iw, jpp.stride_w, jpp.kw); - - VDISPATCH_POOLING_IC( - !(jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw - || back_pad >= jpp.kd || bottom_pad >= jpp.kh - || right_pad >= jpp.kw), - VERBOSE_UNSUPPORTED_PAD_FEATURE, ""); - - jpp.ind_dt = ppd->workspace_md() ? ppd->workspace_md()->data_type - : data_type::undef; - - jpp.simple_alg = jpp.is_training - || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d); - - jpp.ur = 0; - if (jpp.alg == pooling_max) { - jpp.ur = is_avx512 ? 16 : 4; - - if (utils::one_of(isa, avx, avx2, avx2_vnni_2) && jpp.c_tail > 0) - // Additional register needed for tail mask - jpp.ur -= 1; - - if (jpp.is_training) - jpp.ur = is_avx512 ? 9 : 3; - else if (jpp.is_backward) - jpp.ur = is_avx512 ? 6 : 3; - } else { - if (jpp.is_backward) - jpp.ur = is_avx512 ? 12 : 6; - else - jpp.ur = is_avx512 ? 24 : 12; - } - if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) { - jpp.ur = (!isa_has_bf16(jpp.isa)) - ? jpp.ur - 4 // Free registers for AVX512 emulation - : jpp.ur - 1; // Free register for cvt from bf16/f16 to f32 - } - - if (jpp.is_fp8) { - // TODO: Optimize the ur if native FP8 support is available - jpp.ur = jpp.ur - 4; - } - assert(jpp.ur > 0); - - jpp.needs_f32_accum_for_bf16 = jpp.is_bf16 - && jpp.alg == alg_kind::pooling_max && jpp.is_backward - && (jpp.stride_d < jpp.kd || jpp.stride_h < jpp.kh - || jpp.stride_w < jpp.kw); - - // select jpp.ur_bc - if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { - auto min_ur_w = nstl::max(1, utils::div_up(jpp.l_pad, jpp.stride_w)); - int min_ur_w1 = utils::div_up(right_pad, jpp.stride_w); - if (min_ur_w < min_ur_w1) { min_ur_w = min_ur_w1; } - jpp.ur_bc = nstl::min(jpp.nb_c, nstl::max(1, jpp.ur / min_ur_w)); - //take into account threading - to have enough work for parallelization - float best_eff = 0; - for (int ur_bc = jpp.ur_bc; ur_bc > 0; ur_bc--) { - - const auto nb2_c = utils::div_up(jpp.nb_c, ur_bc); - auto work = jpp.is_backward - ? (ndims == 5 && jpp.simple_alg ? jpp.od : 1) - : (ndims == 5 ? jpp.od : jpp.oh); - work *= jpp.mb * nb2_c; - auto eff = (float)work / utils::rnd_up(work, jpp.nthr); - if (eff > best_eff) { - - best_eff = eff; - jpp.ur_bc = ur_bc; - } - if (eff > 0.9f) break; // Heuristic threshold - } - - //take into account cache re-usage after zeroing on backward - if (jpp.is_backward && ndims < 5 && !jpp.needs_f32_accum_for_bf16) { - const int L2 = platform::get_per_core_cache_size(2) / jpp.dt_size; - int ur_bc = nstl::max(1, L2 / (jpp.kh * jpp.iw * jpp.c_block)); - jpp.ur_bc = nstl::min(jpp.ur_bc, ur_bc); - } - - jpp.ur_bc_tail = jpp.nb_c % jpp.ur_bc; - } else { - jpp.ur_bc = 1; - jpp.ur_bc_tail = 0; - } - - // scratchpad for c_block slice of input and/or output - using namespace memory_tracking::names; - const int nscr = nstl::min(dnnl_get_max_threads(), jpp.mb * jpp.nb_c); - if (jpp.tag_kind == jit_memory_tag_kind_t::ncsp) { - scratchpad.book(key_pool_src_plain2blocked_cvt, - static_cast(jpp.c_block) * jpp.id * jpp.ih * jpp.iw - * nscr, - jpp.dt_size); - scratchpad.book(key_pool_dst_plain2blocked_cvt, - static_cast(jpp.c_block) * jpp.od * jpp.oh * jpp.ow - * nscr, - jpp.dt_size); - scratchpad.book(key_pool_ind_plain2blocked_cvt, - static_cast(jpp.c_block) * jpp.od * jpp.oh * jpp.ow - * nscr); - } - - jpp.f32_accum_block_size = jpp.ur_bc * jpp.c_block; - if (jpp.needs_f32_accum_for_bf16) { - auto tmp_d = memory_desc_wrapper(jpp.tmp_md); - assert(tmp_d.is_zero() - && (fmt_tag == nspc_fmt_tag || fmt_tag == blocked_fmt_tag)); - - dims_t dims {}; - utils::array_copy(dims, src_d.dims(), ndims); - - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - dims[0] = nstl::min(dnnl_get_max_threads(), jpp.mb * nb2_c); - dims[1] = jpp.f32_accum_block_size; - - memory_desc_init_by_tag( - jpp.tmp_md, ndims, dims, data_type::f32, fmt_tag); - - scratchpad.book(key_pool_src_f32_accum, tmp_d.size()); - } - - jpp.post_ops = attr.post_ops_; - - return status::success; -} - static int reg_ind(int shift, int bc, int j, int ur_bc, int ur_w) noexcept { return shift * ur_bc * ur_w + bc * ur_w + j; }; @@ -728,46 +381,6 @@ inline void jit_uni_pool_kernel::store32(const int idx, assert(!"datatype not supported"); } -template -bool jit_uni_pool_kernel::post_ops_ok(jit_pool_conf_t &jpp, - const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) { - const auto &post_ops = attr.post_ops_; - const auto &entries = post_ops.entry_; - jpp.with_postops = false; - jpp.with_eltwise = false; - jpp.with_binary = false; - - if (!jpp.is_backward) { - for (const auto &entry : entries) { - if (entry.is_eltwise()) { - const auto alg = entry.eltwise.alg; - jpp.with_eltwise = eltwise_injector::is_supported( - isa, alg, data_type::f32); - } else if (entry.is_binary()) { - const bool is_bf16_ok = IMPLICATION( - entry.binary.src1_desc.data_type == data_type::bf16, - utils::one_of(isa, avx512_core, avx2_vnni_2)); - const bool is_f16_ok = IMPLICATION( - entry.binary.src1_desc.data_type == data_type::f16, - utils::one_of(isa, avx512_core_fp16, avx2_vnni_2)); - const bool is_fp8_ok = IMPLICATION( - utils::one_of(entry.binary.src1_desc.data_type, - data_type::f8_e5m2, data_type::f8_e4m3), - utils::one_of(isa, avx512_core_fp16)); - if (!(is_bf16_ok && is_f16_ok && is_fp8_ok)) return false; - - jpp.with_binary = true; - } else - return false; - } - - jpp.with_postops = jpp.with_eltwise || jpp.with_binary; - } - - return binary_injector::binary_args_broadcast_supported( - post_ops, dst_d, get_supported_bcast_strategies()); -} - template void jit_uni_pool_kernel::apply_postops(int ur_bc, int ur_w, int c_block, const std::function &is_tail_predicate) { diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp index 70f470989d3..6589745cdc8 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.hpp +++ b/src/cpu/x64/jit_uni_pool_kernel.hpp @@ -18,12 +18,8 @@ #ifndef CPU_X64_JIT_UNI_POOL_KERNEL_HPP #define CPU_X64_JIT_UNI_POOL_KERNEL_HPP -#include -#include #include -#include "common/memory_tracking.hpp" - #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" #include "cpu/x64/jit_generator.hpp" #include "cpu/x64/jit_primitive_conf.hpp" @@ -33,6 +29,13 @@ namespace impl { namespace cpu { namespace x64 { +namespace jit_uni_pooling_utils { +inline bcast_set_t get_supported_bcast_strategies() { + return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc, + broadcasting_strategy_t::no_broadcast}; +} +} // namespace jit_uni_pooling_utils + struct bf16_emulation_t; template @@ -45,10 +48,6 @@ struct jit_uni_pool_kernel : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel) - static status_t init_conf(jit_pool_conf_t &jbp, - memory_tracking::registrar_t &scratchpad, primitive_attr_t &attr, - const pooling_pd_t *ppd); - private: using Xmm = Xbyak::Xmm; using Ymm = Xbyak::Ymm; @@ -257,9 +256,6 @@ struct jit_uni_pool_kernel : public jit_generator { void apply_postops(int ur_bc, int ur_w, int c_block, const std::function &is_tail_predicate); - static bool post_ops_ok(jit_pool_conf_t &jpp, const primitive_attr_t &attr, - const memory_desc_wrapper &dst_d); - inline bool use_bf16_emulation() const { return jpp.is_bf16 && !isa_has_bf16(jpp.isa) && isa != avx2_vnni_2; } diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp index db46492c4ca..7cb978c9f9c 100644 --- a/src/cpu/x64/jit_uni_pooling.cpp +++ b/src/cpu/x64/jit_uni_pooling.cpp @@ -13,17 +13,18 @@ * limitations under the License. *******************************************************************************/ +#include #include #include #include "oneapi/dnnl/dnnl_types.h" -#include "common/c_types_map.hpp" #include "common/dnnl_thread.hpp" #include "common/nstl.hpp" #include "common/type_helpers.hpp" #include "cpu/x64/jit_uni_pooling.hpp" +#include "cpu/x64/jit_uni_reorder.hpp" namespace dnnl { namespace impl { @@ -32,6 +33,390 @@ namespace x64 { namespace jit_uni_pooling_utils { +static status_t set_binary_postops_formats( + post_ops_t &post_ops, const memory_desc_t *dst_md) { + for (int idx = 0; idx < post_ops.len(); ++idx) { + if (!post_ops.contain(primitive_kind::binary, idx)) continue; + + auto &src1_md = post_ops.entry_[idx].binary.src1_desc; + const memory_desc_wrapper src1_mdw(src1_md); + if (!src1_mdw.format_any()) { + if (src1_mdw.is_blocking_desc()) + continue; + else + return status::unimplemented; + } + + const memory_desc_wrapper dst_mdw(dst_md); + assert(!dst_mdw.format_any()); + + CHECK(memory_desc_init_by_blocking_desc( + src1_md, dst_mdw.blocking_desc())); + } + + return status::success; +} + +static bool post_ops_ok(cpu_isa_t isa, jit_pool_conf_t &jpp, + const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) { + const auto &post_ops = attr.post_ops_; + const auto &entries = post_ops.entry_; + jpp.with_postops = false; + jpp.with_eltwise = false; + jpp.with_binary = false; + + if (!jpp.is_backward) { + for (const auto &entry : entries) { + if (entry.is_eltwise()) { + const auto alg = entry.eltwise.alg; + jpp.with_eltwise = eltwise_injector::is_supported( + isa, alg, data_type::f32); + } else if (entry.is_binary()) { + const bool is_bf16_ok = IMPLICATION( + entry.binary.src1_desc.data_type == data_type::bf16, + utils::one_of(isa, avx512_core, avx2_vnni_2)); + const bool is_f16_ok = IMPLICATION( + entry.binary.src1_desc.data_type == data_type::f16, + utils::one_of(isa, avx512_core_fp16, avx2_vnni_2)); + const bool is_fp8_ok = IMPLICATION( + utils::one_of(entry.binary.src1_desc.data_type, + data_type::f8_e5m2, data_type::f8_e4m3), + utils::one_of(isa, avx512_core_fp16)); + if (!(is_bf16_ok && is_f16_ok && is_fp8_ok)) return false; + + jpp.with_binary = true; + } else + return false; + } + + jpp.with_postops = jpp.with_eltwise || jpp.with_binary; + } + + return binary_injector::binary_args_broadcast_supported( + post_ops, dst_d, get_supported_bcast_strategies()); +} + +static status_t init_conf(cpu_isa_t isa, jit_pool_conf_t &jpp, + primitive_attr_t &attr, const pooling_pd_t *ppd) { + + using namespace alg_kind; + + const auto &pd = *ppd->desc(); + const memory_desc_wrapper src_d( + ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md()); + const memory_desc_wrapper dst_d( + ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md()); + + const int ndims = src_d.ndims(); + + jpp.nthr = dnnl_get_max_threads(); + jpp.is_training = pd.prop_kind == prop_kind::forward_training; + jpp.is_backward = pd.prop_kind == prop_kind::backward_data; + + jpp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jpp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; + jpp.iw = src_d.dims()[ndims - 1]; + jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jpp.ow = dst_d.dims()[ndims - 1]; + jpp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; + + const bool is_avx512 = is_superset(isa, avx512_core); + jpp.ndims = ndims; + jpp.mb = src_d.dims()[0]; + jpp.c_without_padding = src_d.dims()[1]; + jpp.c_block = is_avx512 ? 16 : 8; + + jpp.alg = pd.alg_kind; + + jpp.src_dt = jpp.is_backward ? pd.diff_src_desc.data_type + : pd.src_desc.data_type; + jpp.dst_dt = jpp.is_backward ? pd.diff_dst_desc.data_type + : pd.dst_desc.data_type; + + jpp.tmp_md = memory_desc_t(); + + jpp.is_bf16 = (src_d.data_type() == data_type::bf16 + && dst_d.data_type() == data_type::bf16); + jpp.is_f16 = (src_d.data_type() == data_type::f16 + && dst_d.data_type() == data_type::f16); + jpp.is_fp8 = utils::one_of(src_d.data_type(), data_type::f8_e5m2, + data_type::f8_e4m3) + && utils::one_of( + dst_d.data_type(), data_type::f8_e5m2, data_type::f8_e4m3); + + using namespace format_tag; + + const auto blocked_fmt_tag = is_avx512 + ? utils::pick(ndims - 3, nCw16c, nChw16c, nCdhw16c) + : utils::pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); + + // src_d.data_type() is equal to dst_d.data_type(). This is checked in init + auto ncsp_fmt_tag = format_tag::undef; + + const unsigned int L3_cache_size_per_core + = platform::get_per_core_cache_size(3); + const size_t block_size + = ((size_t)jpp.id * jpp.ih * jpp.iw + jpp.od * jpp.oh * jpp.ow) + * jpp.c_block * types::data_type_size(src_d.data_type()); + + const bool forward_ncsp_allowed = !jpp.is_backward + && jpp.c_without_padding > 3 + && ((jpp.ih > 1 && jpp.iw > 1 + && block_size <= L3_cache_size_per_core) + || utils::one_of(src_d.data_type(), data_type::bf16, + data_type::f16, data_type::f8_e5m2, + data_type::f8_e4m3)); + + const bool backward_ncsp_allowed = jpp.is_backward + && ((jpp.ih > 1 && jpp.iw > 1 && jpp.c_without_padding > 1 + && block_size <= L3_cache_size_per_core) + || (utils::one_of(src_d.data_type(), data_type::bf16, + data_type::f16) + && !(jpp.alg == pooling_max + && block_size > L3_cache_size_per_core))); + + ncsp_fmt_tag = ((forward_ncsp_allowed || backward_ncsp_allowed) && is_avx512 + && ndims <= 5) + ? utils::pick(ndims - 3, ncw, nchw, ncdhw) + : format_tag::undef; + + const auto nspc_fmt_tag = (ndims <= 5) + ? utils::pick(ndims - 3, nwc, nhwc, ndhwc) + : format_tag::undef; + + const auto fmt_tag = src_d.matches_one_of_tag( + blocked_fmt_tag, ncsp_fmt_tag, nspc_fmt_tag); + + VDISPATCH_POOLING_IC( + dst_d.matches_tag(fmt_tag), VERBOSE_UNSUPPORTED_TAG_S, "dst"); + + VDISPATCH_POOLING_IC( + post_ops_ok(isa, jpp, attr, dst_d), VERBOSE_UNSUPPORTED_POSTOP); + + if (fmt_tag == ncsp_fmt_tag) { + // transform input to blocked f32, call f32 jit, transform result to + // plain output + jpp.is_bf16 = false; + jpp.is_f16 = false; + jpp.is_fp8 = false; + jpp.dt_size = types::data_type_size(data_type::f32); + jpp.tag_kind = jit_memory_tag_kind_t::ncsp; + + // used to initialize binary post-ops + if (ppd->is_fwd() && jpp.with_binary) { + CHECK(memory_desc_init_by_tag(jpp.tmp_md, ndims, dst_d.md_->dims, + data_type::f32, blocked_fmt_tag)); + } + } else { + jpp.dt_size = types::data_type_size(src_d.data_type()); + jpp.tag_kind = (fmt_tag == nspc_fmt_tag) + ? jit_memory_tag_kind_t::nspc + : jit_memory_tag_kind_t::blocked; + } + + if (ppd->is_fwd() && jpp.with_binary) { + CHECK(set_binary_postops_formats(attr.post_ops_, + jpp.tag_kind == jit_memory_tag_kind_t::ncsp ? &jpp.tmp_md + : dst_d.md_)); + } + + jpp.isa = (jpp.is_bf16 && mayiuse(avx512_core_bf16)) + ? avx512_core_bf16 + : ((jpp.is_fp8 && mayiuse(avx512_core_fp16)) ? avx512_core_fp16 + : isa); + + // disabling verbose dispatch messages for unsupported isa for + // better readability + if (!mayiuse(isa)) return status::unimplemented; + + VDISPATCH_POOLING_IC( + (fmt_tag != format_tag::undef), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_POOLING_IC(IMPLICATION(jpp.is_bf16, + utils::one_of(jpp.isa, avx512_core_bf16, + avx512_core, avx2_vnni_2)), + VERBOSE_ISA_DT_MISMATCH); + VDISPATCH_POOLING_IC( + IMPLICATION(jpp.is_f16, + utils::one_of(jpp.isa, avx512_core_fp16, avx2_vnni_2)), + VERBOSE_ISA_DT_MISMATCH); + VDISPATCH_POOLING_IC( + IMPLICATION(jpp.is_fp8, utils::one_of(jpp.isa, avx512_core_fp16)), + VERBOSE_ISA_DT_MISMATCH); + VDISPATCH_POOLING_IC( + utils::one_of(pd.alg_kind, pooling_max, pooling_avg_include_padding, + pooling_avg_exclude_padding), + VERBOSE_BAD_ALGORITHM); + + const bool is_xf16_avx2_vnni_2 + = (jpp.is_bf16 || jpp.is_f16) && isa == avx2_vnni_2; + // note: avx2_vnni_2 only supports nxc format + VDISPATCH_POOLING_IC(IMPLICATION(is_xf16_avx2_vnni_2, + jpp.tag_kind == jit_memory_tag_kind_t::nspc), + "isa, format tag mismatch"); + + // note: avx2_vnni_2 only supports FWD direction + VDISPATCH_POOLING_IC(IMPLICATION(is_xf16_avx2_vnni_2, !jpp.is_backward), + "isa, propagation kind mismatch"); + + jpp.c = jpp.tag_kind == jit_memory_tag_kind_t::blocked + ? utils::rnd_up(jpp.c_without_padding, jpp.c_block) + : jpp.c_without_padding; + if (jpp.tag_kind == jit_memory_tag_kind_t::blocked) + assert(src_d.padded_dims()[1] == jpp.c); + jpp.nb_c = utils::div_up(jpp.c, jpp.c_block); + jpp.c_tail = jpp.c_without_padding % jpp.c_block; + jpp.is_c_padded = jpp.tag_kind == jit_memory_tag_kind_t::blocked + && src_d.padded_dims()[1] != jpp.c_without_padding; + + jpp.stride_d = (ndims == 5) ? pd.strides[0] : 1; + jpp.stride_h = (ndims == 3) ? 1 : pd.strides[ndims - 4]; + jpp.stride_w = pd.strides[ndims - 3]; + jpp.kd = (ndims == 5) ? pd.kernel[0] : 1; + jpp.kh = (ndims == 3) ? 1 : pd.kernel[ndims - 4]; + jpp.kw = pd.kernel[ndims - 3]; + + jpp.f_pad = (ndims == 5) ? pd.padding[0][0] : 0; + jpp.t_pad = (ndims == 3) ? 0 : pd.padding[0][ndims - 4]; + jpp.l_pad = pd.padding[0][ndims - 3]; + + const int back_pad = calculate_end_padding( + jpp.f_pad, jpp.od, jpp.id, jpp.stride_d, jpp.kd); + const int bottom_pad = calculate_end_padding( + jpp.t_pad, jpp.oh, jpp.ih, jpp.stride_h, jpp.kh); + const int right_pad = calculate_end_padding( + jpp.l_pad, jpp.ow, jpp.iw, jpp.stride_w, jpp.kw); + + VDISPATCH_POOLING_IC( + !(jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw + || back_pad >= jpp.kd || bottom_pad >= jpp.kh + || right_pad >= jpp.kw), + VERBOSE_UNSUPPORTED_PAD_FEATURE, ""); + + jpp.ind_dt = ppd->workspace_md() ? ppd->workspace_md()->data_type + : data_type::undef; + + jpp.simple_alg = jpp.is_training + || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d); + + jpp.ur = 0; + if (jpp.alg == pooling_max) { + jpp.ur = is_avx512 ? 16 : 4; + + if (utils::one_of(isa, avx, avx2, avx2_vnni_2) && jpp.c_tail > 0) + // Additional register needed for tail mask + jpp.ur -= 1; + + if (jpp.is_training) + jpp.ur = is_avx512 ? 9 : 3; + else if (jpp.is_backward) + jpp.ur = is_avx512 ? 6 : 3; + } else { + if (jpp.is_backward) + jpp.ur = is_avx512 ? 12 : 6; + else + jpp.ur = is_avx512 ? 24 : 12; + } + if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) { + jpp.ur = (!isa_has_bf16(jpp.isa)) + ? jpp.ur - 4 // Free registers for AVX512 emulation + : jpp.ur - 1; // Free register for cvt from bf16/f16 to f32 + } + + if (jpp.is_fp8) { + // TODO: Optimize the ur if native FP8 support is available + jpp.ur = jpp.ur - 4; + } + assert(jpp.ur > 0); + + jpp.needs_f32_accum_for_bf16 = jpp.is_bf16 + && jpp.alg == alg_kind::pooling_max && jpp.is_backward + && (jpp.stride_d < jpp.kd || jpp.stride_h < jpp.kh + || jpp.stride_w < jpp.kw); + + // select jpp.ur_bc + if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { + auto min_ur_w = nstl::max(1, utils::div_up(jpp.l_pad, jpp.stride_w)); + int min_ur_w1 = utils::div_up(right_pad, jpp.stride_w); + if (min_ur_w < min_ur_w1) { min_ur_w = min_ur_w1; } + jpp.ur_bc = nstl::min(jpp.nb_c, nstl::max(1, jpp.ur / min_ur_w)); + //take into account threading - to have enough work for parallelization + float best_eff = 0; + for (int ur_bc = jpp.ur_bc; ur_bc > 0; ur_bc--) { + + const auto nb2_c = utils::div_up(jpp.nb_c, ur_bc); + auto work = jpp.is_backward + ? (ndims == 5 && jpp.simple_alg ? jpp.od : 1) + : (ndims == 5 ? jpp.od : jpp.oh); + work *= jpp.mb * nb2_c; + auto eff = (float)work / utils::rnd_up(work, jpp.nthr); + if (eff > best_eff) { + + best_eff = eff; + jpp.ur_bc = ur_bc; + } + if (eff > 0.9f) break; // Heuristic threshold + } + + //take into account cache re-usage after zeroing on backward + if (jpp.is_backward && ndims < 5 && !jpp.needs_f32_accum_for_bf16) { + const int L2 = platform::get_per_core_cache_size(2) / jpp.dt_size; + int ur_bc = nstl::max(1, L2 / (jpp.kh * jpp.iw * jpp.c_block)); + jpp.ur_bc = nstl::min(jpp.ur_bc, ur_bc); + } + + jpp.ur_bc_tail = jpp.nb_c % jpp.ur_bc; + } else { + jpp.ur_bc = 1; + jpp.ur_bc_tail = 0; + } + + jpp.f32_accum_block_size = jpp.ur_bc * jpp.c_block; + if (jpp.needs_f32_accum_for_bf16) { + assert(memory_desc_wrapper(jpp.tmp_md).is_zero() + && (fmt_tag == nspc_fmt_tag || fmt_tag == blocked_fmt_tag)); + + dims_t dims {}; + utils::array_copy(dims, src_d.dims(), ndims); + + const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); + dims[0] = nstl::min(dnnl_get_max_threads(), jpp.mb * nb2_c); + dims[1] = jpp.f32_accum_block_size; + + memory_desc_init_by_tag( + jpp.tmp_md, ndims, dims, data_type::f32, fmt_tag); + } + + jpp.post_ops = attr.post_ops_; + + return status::success; +} + +static void init_scratchpad( + jit_pool_conf_t const &jpp, memory_tracking::registrar_t &scratchpad) { + + // scratchpad for c_block slice of input and/or output + using namespace memory_tracking::names; + const int nscr = nstl::min(dnnl_get_max_threads(), jpp.mb * jpp.nb_c); + if (jpp.tag_kind == jit_memory_tag_kind_t::ncsp) { + scratchpad.book(key_pool_src_plain2blocked_cvt, + static_cast(jpp.c_block) * jpp.id * jpp.ih * jpp.iw + * nscr, + jpp.dt_size); + scratchpad.book(key_pool_dst_plain2blocked_cvt, + static_cast(jpp.c_block) * jpp.od * jpp.oh * jpp.ow + * nscr, + jpp.dt_size); + scratchpad.book(key_pool_ind_plain2blocked_cvt, + static_cast(jpp.c_block) * jpp.od * jpp.oh * jpp.ow + * nscr); + } + + if (jpp.needs_f32_accum_for_bf16) { + auto tmp_d = memory_desc_wrapper(jpp.tmp_md); + scratchpad.book(key_pool_src_f32_accum, tmp_d.size()); + } +} + struct trans_wrapper_t { trans_wrapper_t(data_type_t inp_dt, dim_t inp_str, data_type_t out_dt, dim_t out_str, dim_t ysize, dim_t xsize) @@ -630,6 +1015,35 @@ void bwd_f32_accum_for_bf16_t::cvt_to_bf16_slice_3d(int ithr, bfloat16_t *dst, } // namespace jit_uni_pooling_utils +template +status_t jit_uni_pooling_fwd_t::pd_t::init(engine_t *engine) { + using namespace utils; + + VDISPATCH_POOLING(is_fwd(), VERBOSE_BAD_PROPKIND); + VDISPATCH_POOLING(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); + VDISPATCH_POOLING( + everyone_is(d_type, src_md()->data_type, dst_md()->data_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_POOLING(attr()->has_default_values( + primitive_attr_t::skip_mask_t::post_ops, d_type), + VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_POOLING(!is_dilated(), VERBOSE_UNSUPPORTED_FEATURE, + "does not support dilations"); + VDISPATCH_POOLING( + set_default_params() == status::success, VERBOSE_UNSUPPORTED_TAG); + + const bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + CHECK(jit_uni_pooling_utils::init_conf(isa, jpp_, attr_, this)); + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_pooling_utils::init_scratchpad(jpp_, scratchpad); + + return status::success; +} + template jit_uni_pooling_fwd_t::jit_uni_pooling_fwd_t(const pd_t *apd) : primitive_t(apd), kernel_(nullptr), trans_ctx_(nullptr) {} @@ -974,6 +1388,35 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, } } +template +status_t jit_uni_pooling_bwd_t::pd_t::init(engine_t *engine) { + using namespace utils; + + VDISPATCH_POOLING( + set_default_params() == status::success, VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_POOLING(!is_fwd(), VERBOSE_BAD_PROPKIND); + VDISPATCH_POOLING(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); + VDISPATCH_POOLING(everyone_is(d_type, diff_src_md()->data_type, + diff_dst_md()->data_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_POOLING(attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_POOLING(!is_dilated(), VERBOSE_UNSUPPORTED_FEATURE, + "does not support dilations"); + + if (desc()->alg_kind == alg_kind::pooling_max) { + const auto ws_dt = hint_fwd_pd_->workspace_md()->data_type; + init_default_ws(ws_dt); + VDISPATCH_POOLING(compare_ws(hint_fwd_pd_), VERBOSE_WS_MISMATCH); + } + + CHECK(jit_uni_pooling_utils::init_conf(isa, jpp_, attr_, this)); + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_pooling_utils::init_scratchpad(jpp_, scratchpad); + + return status::success; +} + template jit_uni_pooling_bwd_t::jit_uni_pooling_bwd_t(const pd_t *apd) : primitive_t(apd), kernel_(nullptr), trans_ctx_(nullptr) {} diff --git a/src/cpu/x64/jit_uni_pooling.hpp b/src/cpu/x64/jit_uni_pooling.hpp index 5b58e9ec26f..c5b3a8f2037 100644 --- a/src/cpu/x64/jit_uni_pooling.hpp +++ b/src/cpu/x64/jit_uni_pooling.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,18 +17,13 @@ #ifndef CPU_X64_JIT_UNI_POOLING_HPP #define CPU_X64_JIT_UNI_POOLING_HPP -#include #include #include "common/c_types_map.hpp" -#include "common/dnnl_thread.hpp" #include "common/primitive.hpp" -#include "common/type_helpers.hpp" -#include "common/utils.hpp" #include "cpu/cpu_pooling_pd.hpp" #include "cpu/x64/jit_uni_pool_kernel.hpp" -#include "cpu/x64/jit_uni_reorder.hpp" namespace dnnl { namespace impl { @@ -48,35 +43,7 @@ struct jit_uni_pooling_fwd_t : public primitive_t { DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", jpp_.isa, ""), jit_uni_pooling_fwd_t); - status_t init(engine_t *engine) { - using namespace utils; - - VDISPATCH_POOLING(is_fwd(), VERBOSE_BAD_PROPKIND); - VDISPATCH_POOLING(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); - VDISPATCH_POOLING(everyone_is(d_type, src_md()->data_type, - dst_md()->data_type), - VERBOSE_UNSUPPORTED_DT); - VDISPATCH_POOLING( - attr()->has_default_values( - primitive_attr_t::skip_mask_t::post_ops, d_type), - VERBOSE_UNSUPPORTED_ATTR); - VDISPATCH_POOLING(!is_dilated(), VERBOSE_UNSUPPORTED_FEATURE, - "does not support dilations"); - VDISPATCH_POOLING(set_default_params() == status::success, - VERBOSE_UNSUPPORTED_TAG); - - const bool is_training - = desc_.prop_kind == prop_kind::forward_training; - if (desc()->alg_kind == alg_kind::pooling_max && is_training) - init_default_ws(); - - auto scratchpad = scratchpad_registry().registrar(); - - CHECK(jit_uni_pool_kernel::init_conf( - jpp_, scratchpad, attr_, this)); - - return status::success; - } + status_t init(engine_t *engine); jit_pool_conf_t jpp_; }; @@ -124,35 +91,7 @@ struct jit_uni_pooling_bwd_t : public primitive_t { DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", jpp_.isa, ""), jit_uni_pooling_bwd_t); - status_t init(engine_t *engine) { - using namespace utils; - - VDISPATCH_POOLING(set_default_params() == status::success, - VERBOSE_UNSUPPORTED_TAG); - VDISPATCH_POOLING(!is_fwd(), VERBOSE_BAD_PROPKIND); - VDISPATCH_POOLING(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); - VDISPATCH_POOLING(everyone_is(d_type, diff_src_md()->data_type, - diff_dst_md()->data_type), - VERBOSE_UNSUPPORTED_DT); - VDISPATCH_POOLING( - attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); - VDISPATCH_POOLING(!is_dilated(), VERBOSE_UNSUPPORTED_FEATURE, - "does not support dilations"); - - if (desc()->alg_kind == alg_kind::pooling_max) { - const auto ws_dt = hint_fwd_pd_->workspace_md()->data_type; - init_default_ws(ws_dt); - VDISPATCH_POOLING( - compare_ws(hint_fwd_pd_), VERBOSE_WS_MISMATCH); - } - - auto scratchpad = scratchpad_registry().registrar(); - - CHECK(jit_uni_pool_kernel::init_conf( - jpp_, scratchpad, attr_, this)); - - return status::success; - } + status_t init(engine_t *engine); jit_pool_conf_t jpp_; };