diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp
index b9cf8b83039..dca6c01a364 100644
--- a/src/cpu/x64/jit_uni_pool_kernel.cpp
+++ b/src/cpu/x64/jit_uni_pool_kernel.cpp
@@ -34,6 +34,9 @@ using namespace alg_kind;
 
 #define GET_OFF(field) offsetof(jit_pool_call_s, field)
 
+constexpr int sse41_single_block_size
+        = cpu_isa_traits<sse41>::vlen / sizeof(float);
+
 static bcast_set_t get_supported_bcast_strategies() {
     return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
             broadcasting_strategy_t::no_broadcast};
@@ -45,11 +48,7 @@ jit_uni_pool_kernel<isa>::~jit_uni_pool_kernel() = default;
 template <cpu_isa_t isa>
 jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
         const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md)
-    : jit_generator(jit_name(), isa), jpp(ajpp), bf16_emu_(nullptr) {
-    if (use_bf16_emulation())
-        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
-                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
-                bf16_emu_reserv_4, bf16_emu_reserv_5);
+    : jit_generator(jit_name(), isa), jpp(ajpp) {
 
     bool has_f8_e5m2_binary_postops = false;
     bool has_f8_e4m3_binary_postops = false;
@@ -81,16 +80,13 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
                 fp8_emu_reserv_4, fp8_emu_reserv_5, fp8_emu_reg64);
     }
 
+    const auto tail_size
+            = isa == sse41 ? jpp.c_tail % sse41_single_block_size : jpp.c_tail;
+
     if (jpp.with_postops) {
         static constexpr bool preserve_gpr = true;
         static constexpr bool preserve_vmm = true;
         static constexpr bool use_exact_tail_scalar_bcast = false;
-        static constexpr int sse41_single_block_size
-                = cpu_isa_traits<sse41>::vlen / sizeof(float);
-        size_t postop_tail = static_cast<size_t>(jpp.c_tail);
-        const bool high_half_block_empty = isa == sse41
-                && static_cast<size_t>(jpp.c_tail) > sse41_single_block_size;
-        if (high_half_block_empty) postop_tail -= sse41_single_block_size;
 
         const binary_injector::rhs_arg_static_params_t rhs_sp {
                 static_cast<std::size_t>(this->xmm4.getIdx()), this->r14,
@@ -99,7 +95,8 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
                 memory_desc_wrapper(jpp.tag_kind == jit_memory_tag_kind_t::ncsp
                                 ? jpp.tmp_md
                                 : *dst_md),
-                postop_tail, k_c_tail_mask, use_exact_tail_scalar_bcast};
+                static_cast<size_t>(tail_size), k_c_tail_mask,
+                use_exact_tail_scalar_bcast};
 
         const binary_injector::static_params_t bsp {reg_param,
                 get_supported_bcast_strategies(), rhs_sp, f8_e5m2_emu_.get(),
@@ -109,6 +106,34 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
                 = utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
                         this, jpp.post_ops, bsp);
     }
+
+    io::io_tail_conf_t io_tail_conf(jpp.c_block, tail_size,
+            k_c_tail_mask.getIdx(), vmm_c_tail_mask.getIdx(), tmp_gpr);
+
+    utils::optional_t<io::io_emu_bf16_conf_t> io_bf16_conf;
+    if (use_bf16_emulation())
+        io_bf16_conf = io::io_emu_bf16_conf_t(bf16_emu_reserv_1,
+                bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4,
+                bf16_emu_reserv_5);
+
+    utils::optional_t<io::io_emu_fp8_conf_t> io_fp8_conf;
+    if (use_fp8_emulation() || has_f8_e5m2_binary_postops
+            || has_f8_e4m3_binary_postops)
+        io_fp8_conf = io::io_emu_fp8_conf_t(fp8_emu_reserv_1, fp8_emu_reserv_2,
+                fp8_emu_reserv_3, fp8_emu_reserv_4, fp8_emu_reserv_5,
+                fp8_tmp_mask, fp8_emu_reg64);
+
+    using io_mdt_helper = io::jit_io_multi_dt_helper_t<Vmm>;
+
+    typename io_mdt_helper::data_types_t dtypes = {jpp.src_dt, jpp.dst_dt};
+    // Indices of type s32 are stored/loaded as f32: jit_io_helper_t does not
+    // support integers, but it moves those 4 bytes as f32 without any
+    // additional conversion. jit_io_helper_t is not used for processing
+    // indices of type u8.
+    if (jpp.ind_dt == data_type::s32) dtypes.insert(data_type::f32);
+    if (jpp.needs_f32_accum_for_bf16) dtypes.insert(data_type::f32);
+
+    io_ = io_mdt_helper(this, jpp.isa, dtypes, {}, io_tail_conf, io_bf16_conf,
+            {}, utils::nullopt, io_fp8_conf);
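+
+    // Illustration (assuming avx512_core, jpp.c_block = 16, jpp.c_tail = 3):
+    // with this io_tail_conf, io_.prepare_tail_mask() emits roughly
+    //     mov(tmp_gpr.cvt32(), (1 << 3) - 1);
+    //     kmovw(k_c_tail_mask, tmp_gpr.cvt32());
+    // which is the same mask the removed prepare_tail_mask() built, so the
+    // io helper and the postops injector now share a single tail mask.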
 }
 
 static status_t set_binary_postops_formats(
@@ -463,30 +488,6 @@ static int reg_ind(int shift, int bc, int j, int ur_bc, int ur_w) noexcept {
     return shift * ur_bc * ur_w + bc * ur_w + j;
 };
 
-template <cpu_isa_t isa>
-inline void jit_uni_pool_kernel<isa>::prepare_tail_mask() {
-    if (is_superset(isa, avx512_core)) {
-        size_t c_tail_mask = (1ULL << jpp.c_tail) - 1ULL;
-        mov(tmp_gpr.cvt32(), c_tail_mask);
-        kmovw(k_c_tail_mask, tmp_gpr.cvt32());
-    } else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
-        constexpr int max_words_in_ymm = 8;
-
-        // for 'avx2_vnni_2' mask works with 2 x xf16 elements,
-        // in case of 'c_tail % 2 != 0' load/store an additional word
-        // for the remaining element.
-        auto dt_elem_div = isa == avx2_vnni_2 ? 2 : 1;
-        auto mask_offset = max_words_in_ymm - (jpp.c_tail / dt_elem_div);
-        auto mask_register
-                = isa == avx2_vnni_2 ? xmm_c_tail_mask : vmm_c_tail_mask;
-        static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
-                0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0,
-                0, 0, 0, 0, 0, 0, 0};
-        mov(tmp_gpr, reinterpret_cast<size_t>(&mask[mask_offset]));
-        vmovups(mask_register, ptr[tmp_gpr]);
-    }
-}
-
 template <cpu_isa_t isa>
 inline void jit_uni_pool_kernel<isa>::put_one_in_vmm() {
     mov(tmp_gpr, 1);
@@ -518,184 +519,180 @@ template <cpu_isa_t isa>
 inline void jit_uni_pool_kernel<isa>::load(const data_type_t dt,
         const int idx, const reg64_t &reg_ptr, const int offset,
         const bool is_c_tail_proccessing) {
-    if (dt == data_type::bf16) {
-        /*TODO: maybe use vpmovzxwd + vpslld,
-         * in order to free up vmm_idx() register */
-        if (is_c_tail_proccessing && !jpp.is_c_padded) {
-            Vmm vmm_to_load = Vmm(idx) | k_c_tail_mask | T_z;
-            vpmovzxwd(vmm_to_load, ptr[reg_ptr + offset]);
-            vpslld(vmm_to_load, vmm_to_load, 16);
-        } else {
-            vmovups(Ymm(idx), ptr[reg_ptr + offset]);
-            vpermw(Vmm(idx) | k_mask_cvt | T_z, vmm_idx(), Vmm(idx));
-        }
-    } else if (dt == data_type::f16) {
-        Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
-                ? Vmm(idx) | k_c_tail_mask | T_z
-                : Vmm(idx);
-        vcvtph2psx(vmm_to_load, ptr[reg_ptr + offset]);
-    } else if (utils::one_of(dt, data_type::f8_e5m2, data_type::f8_e4m3)) {
-        Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
-                ? Vmm(idx) | k_c_tail_mask | T_z
-                : Vmm(idx);
-        if (dt == data_type::f8_e5m2)
-            f8_e5m2_emu_->vcvt_f8_to_f32(vmm_to_load, ptr[reg_ptr + offset]);
-        else if (dt == data_type::f8_e4m3)
-            f8_e4m3_emu_->vcvt_f8_to_f32(vmm_to_load, ptr[reg_ptr + offset]);
-    } else {
-        if (is_c_tail_proccessing && !jpp.is_c_padded) {
-            if (isa == avx || isa == avx2) {
-                vmaskmovps(Vmm(idx), vmm_c_tail_mask, ptr[reg_ptr + offset]);
-            } else {
-                vmovups(Zmm(idx) | k_c_tail_mask | T_z, ptr[reg_ptr + offset]);
-            }
-        } else {
-            uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]);
-        }
-    }
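+    // The io helper emits per data type roughly what the branches removed
+    // above did by hand: vpmovzxwd + vpslld for bf16, vcvtph2psx for f16,
+    // the fp8 emulator conversions for f8, and a plain or masked vmovups
+    // for f32. The boolean argument selects tail processing in place of the
+    // ad-hoc k_c_tail_mask/vmm_c_tail_mask code.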
+    io_[dt]->load(vmmword[reg_ptr + offset], Vmm(idx),
+            is_c_tail_proccessing && !jpp.is_c_padded);
 }
 
-template <>
-inline void jit_uni_pool_kernel<avx2_vnni_2>::load(const data_type_t dt,
-        const int idx, const reg64_t &reg_ptr, const int offset,
+template <cpu_isa_t isa>
+inline void jit_uni_pool_kernel<isa>::store(const data_type_t dt, const int idx,
+        const reg64_t &reg_ptr, const int offset,
         const bool is_c_tail_proccessing) {
-    if (is_c_tail_proccessing) {
-        vmaskmovps(Xmm(idx), xmm_c_tail_mask, ptr[reg_ptr + offset]);
-        if (jpp.c_tail % 2 != 0) {
-            const int tail_pos = jpp.c_tail - 1;
-            auto word_addr
-                    = ptr[reg_ptr + offset + tail_pos * sizeof(bfloat16_t)];
-            vpinsrw(Xmm(idx), Xmm(idx), word_addr, tail_pos);
-        }
-    }
-    if (dt == data_type::bf16) {
-        if (is_c_tail_proccessing)
-            vpmovzxwd(Ymm(idx), Xmm(idx));
-        else
-            vpmovzxwd(Ymm(idx), ptr[reg_ptr + offset]);
-        vpslld(Ymm(idx), Ymm(idx), 16);
-    } else if (dt == data_type::f16) {
-        if (is_c_tail_proccessing)
-            vcvtph2ps(Ymm(idx), Xmm(idx));
-        else
-            vcvtph2ps(Ymm(idx), ptr[reg_ptr + offset]);
-    } else
-        assert(!"invalid data type");
+    if (is_c_tail_proccessing && jpp.is_c_padded && jpp.with_postops)
+        pad_with_zeros(idx);
+    io_[dt]->store(Vmm(idx), vmmword[reg_ptr + offset],
+            is_c_tail_proccessing && !jpp.is_c_padded);
 }
 
-template <>
-inline void jit_uni_pool_kernel<sse41>::load(const data_type_t dt,
-        const int idx, const reg64_t &reg_ptr, const int offset,
-        const bool is_c_tail_proccessing) {
-    if (is_c_tail_proccessing && !jpp.is_c_padded) {
-        const auto dt_size = types::data_type_size(dt);
-        for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
-            pinsrd(Xmm(idx), ptr[reg_ptr + offset + i * dt_size], i);
+template <cpu_isa_t isa>
+inline void jit_uni_pool_kernel<isa>::pad_with_zeros(const int idx) {
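+    // Zero the vector lanes past the channel tail so padded channels hold
+    // zeros rather than stale data before a full-width store. Worked sse41
+    // example (illustrative): jpp.c_tail = 6 in the high-half pass gives
+    // c_tail % 4 == 2, so tail_mask = ~0b0011 and uni_vblendps zeroes
+    // lanes 2 and 3 of the xmm.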
+    if (isa == sse41) {
+        uni_vxorps(xmm_tmp_1, xmm_tmp_1, xmm_tmp_1);
+        if (jpp.c_tail <= sse41_single_block_size && sse_high_half) {
+            uni_vmovups(Vmm(idx), xmm_tmp_1);
+        } else if ((jpp.c_tail < sse41_single_block_size && !sse_high_half)
+                || (jpp.c_tail > sse41_single_block_size && sse_high_half)) {
+            const auto c_tail = jpp.c_tail % sse41_single_block_size;
+            std::bitset<8> tail_mask((1 << c_tail) - 1);
+            tail_mask.flip();
+            uni_vblendps(Vmm(idx), Vmm(idx), xmm_tmp_1, tail_mask.to_ulong());
+        }
+    } else if (isa == avx || isa == avx2) {
+        uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
+        uni_vblendvps(Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask);
     } else
-        uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]);
+        uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx));
 }
 
 template <cpu_isa_t isa>
-inline void jit_uni_pool_kernel<isa>::store(const data_type_t dt,
-        const int idx, const reg64_t &reg_ptr, const int offset,
-        const bool is_c_tail_proccessing) {
-    if (utils::one_of(dt, data_type::bf16, data_type::f16)) {
-        if (is_c_tail_proccessing) {
-            if (jpp.is_c_padded) {
-                vmovdqu16(Ymm(idx) | k_c_tail_mask | T_z, Ymm(idx));
-                vmovups(yword[reg_ptr + offset], Ymm(idx));
-            } else
-                vmovdqu16(ptr[reg_ptr + offset] | k_c_tail_mask, Ymm(idx));
-        } else
-            vmovups(yword[reg_ptr + offset], Ymm(idx));
-    } else if (utils::one_of(dt, data_type::f8_e5m2, data_type::f8_e4m3)) {
-        if (is_c_tail_proccessing) {
-            if (jpp.is_c_padded) {
-                vmovdqu8(Xmm(idx) | k_c_tail_mask | T_z, Xmm(idx));
-                vmovdqu8(yword[reg_ptr + offset], Xmm(idx));
-            } else
-                vmovdqu8(ptr[reg_ptr + offset] | k_c_tail_mask, Xmm(idx));
-        } else
-            vmovdqu8(yword[reg_ptr + offset], Xmm(idx));
-    } else {
-        if (is_c_tail_proccessing) {
-            if (!jpp.is_c_padded) {
-                if (isa == avx || isa == avx2)
-                    vmaskmovps(
-                            ptr[reg_ptr + offset], vmm_c_tail_mask, Vmm(idx));
-                else
-                    vmovups(ptr[reg_ptr + offset] | k_c_tail_mask, Zmm(idx));
+inline void jit_uni_pool_kernel<isa>::load_indices(
+        const int indr_i, const int step_index, bool is_c_tail_processing) {
+    if (jpp.ind_dt == data_type::u8) {
+        auto indvr = vreg(indr_i);
+        auto indxr = xreg(indr_i);
+        if (isa == sse41) {
+            if (is_c_tail_processing && !jpp.is_c_padded) {
+                for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
+                    pinsrb(indxr, ptr[reg_index + step_index + i], i);
             } else {
-                if (jpp.with_postops) {
-                    if (isa == avx || isa == avx2) {
-                        uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
-                        uni_vblendvps(
-                                Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask);
-                    } else
-                        uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx));
-                }
-                uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
+                movd(indxr, ptr[reg_index + step_index]);
             }
-        } else
-            uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
+            pmovzxbd(indvr, indxr);
+        } else if (isa == avx || isa == avx2) {
+            if (is_c_tail_processing && !jpp.is_c_padded) {
+                for (int i = 0; i < jpp.c_tail; i++)
+                    vpinsrb(indxr, indxr, ptr[reg_index + step_index + i], i);
+            } else {
+                vmovq(indxr, ptr[reg_index + step_index]);
+            }
+            if (!mayiuse(avx2)) {
+                avx_pmovzxbd(indvr, indxr, xmm_tmp);
+            } else {
+                vpmovzxbd(indvr, indxr);
+            }
+        } else {
+            if (is_c_tail_processing && !jpp.is_c_padded) {
+                vpmovzxbd(indvr | k_c_tail_mask | T_z,
+                        ptr[reg_index + step_index]);
+            } else {
+                vpmovzxbd(indvr, ptr[reg_index + step_index]);
+            }
+        }
+    } else {
+        assert(jpp.ind_dt == data_type::s32);
+
+        // Load 4-byte values without conversion; the values are actually
+        // integers.
+        auto indvr = vreg(indr_i);
+        io_[data_type::f32]->load(vmmword[reg_index + step_index], indvr,
+                is_c_tail_processing && !jpp.is_c_padded);
+    }
 }
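+// Illustration: for jpp.c_tail = 3 on avx2, the tail path above emits three
+// vpinsrb byte loads into xmm followed by one vpmovzxbd widening the bytes
+// to dwords; avx_pmovzxbd() is the fallback for avx1, which lacks a
+// ymm-wide vpmovzxbd and therefore needs xmm_tmp as scratch.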
 
-template <>
-inline void jit_uni_pool_kernel<avx2_vnni_2>::store(const data_type_t dt,
-        const int idx, const reg64_t &reg_ptr, const int offset,
-        const bool is_c_tail_proccessing) {
-    if (utils::one_of(dt, data_type::bf16, data_type::f16)) {
-        if (is_c_tail_proccessing) {
-            vmaskmovps(ptr[reg_ptr + offset], xmm_c_tail_mask, Xmm(idx));
-            if (jpp.c_tail % 2 != 0) {
-                const int tail_pos = jpp.c_tail - 1;
-                auto word_addr = ptr[reg_ptr + offset + tail_pos * 2];
-                vpextrw(word_addr, Xmm(idx), tail_pos);
+template <cpu_isa_t isa>
+inline void jit_uni_pool_kernel<isa>::store_indices(const int indr_i,
+        const int step_index, const bool is_c_tail_processing,
+        const bool is_first_w_block) {
+    if (jpp.ind_dt == data_type::u8) {
+        auto xr = xreg(indr_i);
+        if (isa == sse41) {
+            for (int i = 0; i < (jpp.c_block / 2); ++i) {
+                if (is_c_tail_processing
+                        && i + (sse_high_half ? (jpp.c_block / 2) : 0)
+                                >= jpp.c_tail) {
+                    if (jpp.is_c_padded)
+                        mov(ptr[reg_index + step_index + i],
+                                tmp_gpr.cvt8()); // fill padded tail with zeros
+                    else
+                        break; // tail end
+                } else {
+                    // The byte to store sits in the least significant 8 bits
+                    // of each 32-bit xmm lane, hence bytes 0, 4, 8 and 12 of
+                    // xmm are stored.
+                    pextrb(ptr[reg_index + step_index + i], xr, 4 * i);
+                }
             }
-        } else
-            vmovups(xword[reg_ptr + offset], Xmm(idx));
-    } else
-        assert(!"datatype not supported");
-}
-
-template <>
-inline void jit_uni_pool_kernel<sse41>::store(const data_type_t dt,
-        const int idx, const reg64_t &reg_ptr, const int offset,
-        const bool is_c_tail_proccessing) {
-    if (is_c_tail_proccessing) {
-        if (!jpp.is_c_padded) {
-            const auto dt_size = types::data_type_size(dt);
-            for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
-                pextrd(ptr[reg_ptr + offset + i * dt_size], Xmm(idx), i);
-        } else {
-            if (jpp.with_postops) {
-                static constexpr auto xmm_half = 4;
-                const auto tail_size = (jpp.c_without_padding > jpp.c_block)
-                        ? jpp.c_without_padding % (jpp.c - jpp.c_block)
-                        : jpp.c_without_padding;
-                const auto tail_size_real = (tail_size >= xmm_half)
-                        ? tail_size - xmm_half
-                        : tail_size;
-                uni_vxorps(xmm_tmp_1, xmm_tmp_1, xmm_tmp_1);
-                if (tail_size <= xmm_half && sse_high_half) {
-                    // just zero out upper half padding and don't write anything else
-                    uni_vmovups(vmmword[reg_ptr + offset], xmm_tmp_1);
-                    return;
+        } else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
+            auto yr = yreg(indr_i);
+            if (is_c_tail_processing && !jpp.is_c_padded) {
+                const int max_nr_of_vals = jpp.c_tail > (jpp.c_block / 2)
+                        ? (jpp.c_block / 2)
+                        : jpp.c_tail;
+                for (int i = 0; i < max_nr_of_vals; ++i) {
+                    // Same byte layout as in the sse41 branch above.
+                    vpextrb(ptr[reg_index + step_index + i], xr, 4 * i);
                 }
-                if ((tail_size < xmm_half && !sse_high_half)
-                        || (tail_size > xmm_half && sse_high_half)) {
-                    std::bitset<8> tail_mask((1 << tail_size_real) - 1);
-                    tail_mask.flip();
-                    uni_vblendps(Vmm(idx), Vmm(idx), xmm_tmp_1,
-                            tail_mask.to_ulong());
+                if (jpp.c_tail > (jpp.c_block / 2)) {
+                    Xmm higher_128bits(vmm_mask.getIdx());
+                    vextractf128(higher_128bits, yr, 1);
+                    for (int i = 0; i < jpp.c_tail - (jpp.c_block / 2); ++i) {
+                        // Same byte layout, upper xmm half.
+                        vpextrb(ptr[reg_index + step_index + (jpp.c_block / 2)
+                                        + i],
+                                higher_128bits, 4 * i);
+                    }
+                }
+            } else {
+                if (is_c_tail_processing) {
+                    assert(jpp.is_c_padded);
+                    vandps(yr, yr, vmm_c_tail_mask);
+                }
+                if (is_first_w_block) {
+                    vmovd(xmm_tmp, reg_shuf_mask);
+                    uni_vpbroadcastd(vmm_tmp, xmm_tmp);
+                }
+                if (mayiuse(avx2)) {
+                    vpshufb(yr, yr, vmm_tmp);
+                    vmovd(ptr[reg_index + step_index], xr);
+                    vperm2i128(yr, yr, yr, 0x1u);
+                    vmovd(ptr[reg_index + step_index + (jpp.c_block / 2)], xr);
+                } else {
+                    Xmm t(vmm_mask.getIdx());
+                    vextractf128(t, yr, 0);
+                    vpshufb(t, t, xmm_tmp);
+                    vmovd(ptr[reg_index + step_index], t);
+                    vextractf128(t, yr, 1);
+                    vpshufb(t, t, xmm_tmp); // same shuffle on the upper half
+                    vmovd(ptr[reg_index + step_index + (jpp.c_block / 2)], t);
                 }
             }
-            uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
+        } else {
+            auto vr = vreg(indr_i);
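+            // vr holds non-negative s32 kernel offsets; vpmovusdb narrows
+            // them to u8 with unsigned saturation. For padded tails the two
+            // knotw calls below temporarily invert k_c_tail_mask so that
+            // vpxord zeroes exactly the padded lanes before the full-width
+            // store.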
+            if (is_c_tail_processing) {
+                if (jpp.is_c_padded) {
+                    knotw(k_c_tail_mask, k_c_tail_mask);
+                    vpxord(vr | k_c_tail_mask, vr, vr);
+                    knotw(k_c_tail_mask, k_c_tail_mask);
+                    vpmovusdb(ptr[reg_index + step_index], vr);
+                } else
+                    vpmovusdb(ptr[reg_index + step_index], vr | k_c_tail_mask);
+            } else {
+                vpmovusdb(ptr[reg_index + step_index], vr);
+            }
         }
-    } else
-        uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
+    } else {
+        assert(jpp.ind_dt == data_type::s32);
+
+        // Store 4-byte values without conversion; the values are actually
+        // integers.
+        auto idx = reg_idx(indr_i);
+        if (is_c_tail_processing && jpp.is_c_padded) pad_with_zeros(idx);
+        io_[data_type::f32]->store(Vmm(idx), vmmword[reg_index + step_index],
+                is_c_tail_processing && !jpp.is_c_padded);
+    }
 }
 
 template <cpu_isa_t isa>
@@ -887,39 +884,15 @@ inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
                     if (aux_input_offset >= iw * c_off) continue;
                     int input_offset = dt_size * aux_input_offset;
                     if (jpp.is_backward) {
-                        auto inpyr = yreg(inpr_i);
                         load(jpp.src_dt, reg_idx(inpr_i), aux_reg_input,
                                 input_offset, is_tail_processing(bci));
                         uni_vaddps(inpvr, inpvr, accvr);
-                        if (jpp.is_bf16) {
-                            if (!isa_has_bf16(jpp.isa))
-                                bf16_emu_->vcvtneps2bf16(inpyr, zreg(inpr_i));
-                            else
-                                vcvtneps2bf16(inpyr, inpvr);
-                        } else if (jpp.is_f16) {
-                            vcvtps2ph(inpyr, inpvr, _op_mxcsr);
-                        } else if (jpp.is_fp8) {
-                            auto inpxr = xreg(inpr_i);
-                            if (jpp.src_dt == data_type::f8_e5m2)
-                                f8_e5m2_emu_->vcvt_f32_to_f8(inpxr, zreg(inpr_i));
-                            else if (jpp.src_dt == data_type::f8_e4m3)
-                                f8_e4m3_emu_->vcvt_f32_to_f8(inpxr, zreg(inpr_i));
-                        }
                         store(jpp.src_dt, reg_idx(inpr_i), aux_reg_input,
                                 input_offset, is_tail_processing(bci));
                     } else {
-                        if (jpp.is_bf16 || jpp.is_f16 || jpp.is_fp8
-                                || is_tail_processing(bci)
-                                || (isa == sse41
-                                        && c_off % (jpp.c_block / 2) != 0)) {
-                            load(jpp.src_dt, vmm_tmp_1.getIdx(), aux_reg_input,
-                                    input_offset, is_tail_processing(bci));
-
-                            uni_vaddps(accvr, accvr, vmm_tmp_1);
-                        } else {
-                            uni_vaddps(accvr, accvr,
-                                    ptr[aux_reg_input + input_offset]);
-                        }
+                        load(jpp.src_dt, vmm_tmp_1.getIdx(), aux_reg_input,
+                                input_offset, is_tail_processing(bci));
+                        uni_vaddps(accvr, accvr, vmm_tmp_1);
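+                        // The old fast path added straight from memory for
+                        // full f32 vectors; funneling everything through
+                        // load() keeps one code path, and for that case the
+                        // io helper still emits a single uni_vmovups (see
+                        // the removed load() above), at the cost of one
+                        // extra vector register move.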
                     }
                 }
             }
@@ -955,34 +928,8 @@ inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
     for (int jj = 0; jj < ur_w; jj++) {
         for (int bci = 0; bci < ur_bc; bci++) {
             const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
-            const auto accvr = vreg(accr_i);
             const auto output_offset
                     = dt_size * (jj * c_off + bci * c_block);
-            const auto accyr = yreg(accr_i);
-            if (jpp.is_bf16) {
-                if (isa == avx2_vnni_2) {
-                    auto accxr = xreg(accr_i);
-                    vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
-                } else {
-                    const auto acczr = zreg(accr_i);
-                    if (!isa_has_bf16(jpp.isa))
-                        bf16_emu_->vcvtneps2bf16(accyr, acczr);
-                    else
-                        vcvtneps2bf16(accyr, accvr);
-                }
-            } else if (jpp.is_f16) {
-                if (isa == avx2_vnni_2) {
-                    auto accxr = xreg(accr_i);
-                    vcvtps2ph(accxr, accyr, _op_mxcsr);
-                } else
-                    vcvtps2ph(accyr, accvr, _op_mxcsr);
-            } else if (jpp.is_fp8) {
-                const auto accxr = xreg(accr_i);
-                if (jpp.src_dt == data_type::f8_e5m2)
-                    f8_e5m2_emu_->vcvt_f32_to_f8(accxr, accvr);
-                else if (jpp.src_dt == data_type::f8_e4m3)
-                    f8_e4m3_emu_->vcvt_f32_to_f8(accxr, accvr);
-            }
             store(jpp.dst_dt, reg_idx(accr_i), reg_output, output_offset,
                     is_tail_processing(bci));
         }
@@ -1129,136 +1076,19 @@ inline void jit_uni_pool_kernel<isa>::max_step_fwd(int ur_w, int ur_bc,
     for_(int jj = 0; jj < ur_w; jj++)
     for (int bci = 0; bci < ur_bc; bci++) {
         const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
-        const auto accvr = vreg(accr_i);
         const auto output_offset = jpp.dt_size * (jj * c_off + bci * c_block);
-        auto accyr = yreg(accr_i);
-        if (jpp.is_bf16) {
-            if (isa == avx2_vnni_2) {
-                auto accxr = xreg(accr_i);
-                vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
-            } else {
-                auto acczr = zreg(accr_i);
-                if (!isa_has_bf16(jpp.isa))
-                    bf16_emu_->vcvtneps2bf16(accyr, acczr);
-                else
-                    vcvtneps2bf16(accyr, accvr);
-            }
-        } else if (jpp.is_f16) {
-            if (isa == avx2_vnni_2) {
-                auto accxr = xreg(accr_i);
-                vcvtps2ph(accxr, accyr, _op_mxcsr);
-            } else
-                vcvtps2ph(accyr, accvr, _op_mxcsr);
-        } else if (jpp.is_fp8) {
-            auto accxr = xreg(accr_i);
-            auto acczr = zreg(accr_i);
-            if (jpp.src_dt == data_type::f8_e5m2)
-                f8_e5m2_emu_->vcvt_f32_to_f8(accxr, acczr);
-            else if (jpp.src_dt == data_type::f8_e4m3)
-                f8_e4m3_emu_->vcvt_f32_to_f8(accxr, acczr);
-        }
+        const bool is_c_tail_processing = is_tail_processing(bci);
         store(jpp.dst_dt, reg_idx(accr_i), reg_output, output_offset,
-                is_tail_processing(bci));
+                is_c_tail_processing);
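+        // The bf16/f16/fp8 conversion chains removed above now happen
+        // inside store() via the io helper, so the loop body only deals
+        // with f32 accumulators.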
         if (jpp.is_training) {
             const size_t step_index = (jj * c_off + bci * c_block)
                     * types::data_type_size(jpp.ind_dt);
 
             const auto indr_i = reg_ind(2, bci, jj, ur_bc, ur_w);
-            auto vr = vreg(indr_i);
-            if (jpp.ind_dt == data_type::u8) {
-                auto xr = xreg(indr_i);
-                if (isa == sse41) {
-                    for (int i = 0; i < (jpp.c_block / 2); ++i) {
-                        if (is_tail_processing(bci)
-                                && i + (sse_high_half ? (jpp.c_block / 2) : 0)
-                                        >= jpp.c_tail) {
-                            if (jpp.is_c_padded)
-                                mov(ptr[reg_index + step_index + i],
-                                        tmp_gpr.cvt8()); // fill padded tail with zeros
-                            else
-                                break; // tail end
-                        } else {
-                            // bytes which should be stored are located in
-                            // least significant bits(8 to be precise) of 32 bits parts
-                            // of xmm thus we need to store 0, 4, 8 and 12 byte of xmm
-                            pextrb(ptr[reg_index + step_index + i], xr, 4 * i);
-                        }
-                    }
-                } else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
-                    auto yr = yreg(indr_i);
-                    if (is_tail_processing(bci) && !jpp.is_c_padded) {
-                        const int max_nr_of_vals
-                                = jpp.c_tail > (jpp.c_block / 2)
-                                ? (jpp.c_block / 2)
-                                : jpp.c_tail;
-                        for (int i = 0; i < max_nr_of_vals; ++i) {
-                            // bytes which should be stored are located in
-                            // least significant bits(8 to be precise) of 32 bits parts
-                            // of xmm thus we need to store 0, 4, 8 and 12 byte of xmm
-                            vpextrb(ptr[reg_index + step_index + i], xr, 4 * i);
-                        }
-
-                        if (jpp.c_tail > (jpp.c_block / 2)) {
-                            Xmm higher_128bits(vmm_mask.getIdx());
-                            vextractf128(higher_128bits, yr, 1);
-                            for (int i = 0; i < jpp.c_tail - (jpp.c_block / 2);
-                                    ++i) {
-                                // bytes which should be stored are located in
-                                // least significant bits(8 to be precise) of 32 bits parts
-                                // of xmm thus we need to store 0, 4, 8 and 12 byte of xmm
-                                vpextrb(ptr[reg_index + step_index
-                                                + (jpp.c_block / 2) + i],
-                                        higher_128bits, 4 * i);
-                            }
-                        }
-                    } else {
-                        if (is_tail_processing(bci)) {
-                            assert(jpp.is_c_padded);
-                            vandps(yr, yr, vmm_c_tail_mask);
-                        }
-                        if (jj == 0) {
-                            vmovd(xmm_tmp, reg_shuf_mask);
-                            uni_vpbroadcastd(vmm_tmp, xmm_tmp);
-                        }
-                        if (mayiuse(avx2)) {
-                            vpshufb(yr, yr, vmm_tmp);
-                            vmovd(ptr[reg_index + step_index], xr);
-                            vperm2i128(yr, yr, yr, 0x1u);
-                            vmovd(ptr[reg_index + step_index
-                                          + (jpp.c_block / 2)],
-                                    xr);
-                        } else {
-                            Xmm t(vmm_mask.getIdx());
-                            vextractf128(t, yr, 0);
-                            vpshufb(t, t, xmm_tmp);
-                            vmovd(ptr[reg_index + step_index], t);
-                            vextractf128(t, yr, 1);
-                            vpshufb(t, t,
-                                    xmm_tmp); // ymm_tmp[:128]==ymm_tmp[127:0]
-                            vmovd(ptr[reg_index + step_index
-                                          + (jpp.c_block / 2)],
-                                    t);
-                        }
-                    }
-                } else {
-                    if (is_tail_processing(bci)) {
-                        if (jpp.is_c_padded) {
-                            knotw(k_c_tail_mask, k_c_tail_mask);
-                            vpxord(vr | k_c_tail_mask, vr, vr);
-                            knotw(k_c_tail_mask, k_c_tail_mask);
-                            vpmovusdb(ptr[reg_index + step_index], vr);
-                        } else
-                            vpmovusdb(ptr[reg_index + step_index],
-                                    vr | k_c_tail_mask);
-                    } else {
-                        vpmovusdb(ptr[reg_index + step_index], vr);
-                    }
-                }
-            } else {
-                store(jpp.ind_dt, vr.getIdx(), reg_index, step_index,
-                        is_tail_processing(bci));
-            }
+            const bool is_first_w_block = jj == 0;
+            store_indices(
+                    indr_i, step_index, is_c_tail_processing, is_first_w_block);
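+            // is_first_w_block lets store_indices() (re)broadcast the u8
+            // shuffle mask from reg_shuf_mask only once per row, replacing
+            // the inline "if (jj == 0)" check removed above.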
         }
     }
 }
 
@@ -1300,48 +1130,14 @@ inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
     for (int bci = 0; bci < ur_bc; bci++) {
         const auto outr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
         auto out_offset = output_dt_size * (jj * output_c_off + bci * c_block);
+        const bool is_c_tail_processing = is_tail_processing(bci);
         load(jpp.dst_dt, reg_idx(outr_i), reg_output, out_offset,
-                is_tail_processing(bci));
+                is_c_tail_processing);
         const size_t step_index = (jj * output_c_off + bci * c_block)
                 * types::data_type_size(jpp.ind_dt);
 
         const auto indr_i = reg_ind(1, bci, jj, ur_bc, ur_w);
-        auto indvr = vreg(indr_i);
-        if (jpp.ind_dt == data_type::u8) {
-            auto indxr = xreg(indr_i);
-            if (isa == sse41) {
-                if (is_tail_processing(bci) && !jpp.is_c_padded) {
-                    for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
-                        pinsrb(indxr, ptr[reg_index + step_index + i], i);
-                } else {
-                    movd(indxr, ptr[reg_index + step_index]);
-                }
-                pmovzxbd(indvr, indxr);
-            } else if (isa == avx || isa == avx2) {
-                if (is_tail_processing(bci) && !jpp.is_c_padded) {
-                    for (int i = 0; i < jpp.c_tail; i++)
-                        vpinsrb(indxr, indxr, ptr[reg_index + step_index + i],
-                                i);
-                } else {
-                    vmovq(indxr, ptr[reg_index + step_index]);
-                }
-                if (!mayiuse(avx2)) {
-                    avx_pmovzxbd(indvr, indxr, xmm_tmp);
-                } else {
-                    vpmovzxbd(indvr, indxr);
-                }
-            } else {
-                if (is_tail_processing(bci) && !jpp.is_c_padded) {
-                    vpmovzxbd(indvr | k_c_tail_mask | T_z,
-                            ptr[reg_index + step_index]);
-                } else {
-                    vpmovzxbd(indvr, ptr[reg_index + step_index]);
-                }
-            }
-        } else {
-            load(jpp.ind_dt, indvr.getIdx(), reg_index, step_index,
-                    is_tail_processing(bci));
-        }
+        load_indices(indr_i, step_index, is_c_tail_processing);
     }
     uni_vmovq(xmm_tmp, reg_k_shift);
     uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
 
@@ -1349,11 +1145,6 @@ inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
     if (jpp.simple_alg && jpp.ndims == 5) {
         push(reg_input);
         push(reg_output);
-        if (isa == sse41) {
-            // Save rdi since it is used in maskmovdqu
-            assert(dst_ptr == rdi);
-            push(dst_ptr);
-        }
         mov(aux_reg_input_d, reg_input);
         mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
         mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]);
@@ -1385,50 +1176,27 @@ inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
                     load(input_dt, reg_idx(inpr_i), aux_reg_input, inp_offset,
                             is_tail_processing(bci));
                     if (isa == sse41) {
-                        mov(dst_ptr, aux_reg_input);
-                        add(dst_ptr, inp_offset);
                         movups(cvtvr, indvr);
                         pcmpeqd(cvtvr, vmm_k_offset);
-                        addps(inpvr, outvr);
-                        if (is_tail_processing(bci)) {
-                            Label end_cond_move[4];
-                            for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2);
-                                    i++) {
-                                pextrd(tmp_gpr.cvt32(), cvtvr, i);
-                                cmp(tmp_gpr, 0);
-                                je(end_cond_move[i], T_NEAR);
-                                pextrd(ptr[dst_ptr + i * jpp.dt_size], inpvr,
-                                        i);
-                                L(end_cond_move[i]);
-                            }
-                        } else
-                            maskmovdqu(inpvr, cvtvr);
+                        vandps(cvtvr, cvtvr, outvr);
+                        addps(inpvr, cvtvr);
+                        store(input_dt, inpvr.getIdx(), aux_reg_input,
+                                inp_offset, is_tail_processing(bci));
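+                        // With maskmovdqu (rdi-bound, byte-masked,
+                        // non-temporal) gone: pcmpeqd leaves all-ones in
+                        // exactly the lanes whose saved index matches the
+                        // current kernel offset, vandps keeps out_grad only
+                        // in those lanes, and a regular store() writes the
+                        // full vector back.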
                     } else if (isa == avx || isa == avx2) {
                         if (mayiuse(avx2)) {
                             vpcmpeqd(cvtvr, indvr, vmm_k_offset);
                         } else {
                             avx_pcmpeqd(cvtvr, indvr, vmm_k_offset, xmm_tmp);
                         }
-                        vaddps(inpvr, inpvr, outvr);
-                        if (is_tail_processing(bci)) {
-                            vandps(cvtvr, cvtvr, vmm_c_tail_mask);
-                        }
-                        vmaskmovps(
-                                vmmword[aux_reg_input + inp_offset], cvtvr, inpvr);
+                        uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
+                        vblendvps(vmm_tmp, vmm_tmp, outvr, cvtvr);
+                        vaddps(inpvr, inpvr, vmm_tmp);
+                        store(input_dt, inpvr.getIdx(), aux_reg_input,
+                                inp_offset, is_tail_processing(bci));
                     } else {
-                        auto indzr = zreg(inpr_i);
-                        auto indyr = yreg(inpr_i);
                         vpcmpeqd(k_store_mask, indvr, vmm_k_offset);
                         vblendmps(vmm_tmp | k_store_mask | T_z, outvr, outvr);
                         vaddps(inpvr, inpvr, vmm_tmp);
-                        if (jpp.is_bf16 && !jpp.needs_f32_accum_for_bf16) {
-                            if (!isa_has_bf16(jpp.isa))
-                                bf16_emu_->vcvtneps2bf16(indyr, indzr);
-                            else
-                                vcvtneps2bf16(indyr, inpvr);
-                        } else if (jpp.is_f16) {
-                            vcvtps2ph(indyr, inpvr, _op_mxcsr);
-                        }
                         store(input_dt, inpvr.getIdx(), aux_reg_input,
                                 inp_offset, is_tail_processing(bci));
                     }
@@ -1469,11 +1237,6 @@ inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
         dec(ki);
         cmp(ki, 0);
         jg(kd_label, T_NEAR);
-        if (isa == sse41) {
-            // Save rdi since it is used in maskmovdqu
-            assert(dst_ptr == rdi);
-            pop(dst_ptr);
-        }
         pop(reg_output);
         pop(reg_input);
     }
@@ -1565,8 +1328,6 @@ void jit_uni_pool_kernel<isa>::generate() {
 
     this->preamble();
 
-    Label idx_table;
-
     int ow = jpp.ow;
     int iw = jpp.iw;
     int kw = jpp.kw;
@@ -1585,14 +1346,7 @@ void jit_uni_pool_kernel<isa>::generate() {
     const size_t input_dt_size
             = jpp.needs_f32_accum_for_bf16 ? sizeof(float) : jpp.dt_size;
 
-#if defined(_WIN32)
-    // Always mimic the Unix ABI (see the note about maskmovdqu in the header
-    // file).
-    xor_(rdi, rcx);
-    xor_(rcx, rdi);
-    xor_(rdi, rcx);
-#endif
-    if (use_bf16_emulation()) bf16_emu_->init_vcvtneps2bf16();
+    if (use_bf16_emulation()) io_.init_bf16();
 
     mov(reg_input, ptr[reg_param + GET_OFF(src)]);
     mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
@@ -1603,14 +1357,6 @@ void jit_uni_pool_kernel<isa>::generate() {
     mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]);
     mov(reg_nbc, ptr[reg_param + GET_OFF(ur_bc)]);
 
-    if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) {
-        mov(tmp_gpr.cvt32(), 0xAAAAAAAA);
-        kmovd(k_mask_cvt, tmp_gpr.cvt32());
-
-        mov(tmp_gpr, idx_table);
-        vmovups(vmm_idx(), ptr[tmp_gpr]);
-    }
-
     auto process_oi = [&](int ur_w, int ur_bc, int lpad, int rpad,
                               bool with_c_tail_proccessing,
                               bool inc_reg = true) {
@@ -1763,7 +1509,7 @@ void jit_uni_pool_kernel<isa>::generate() {
         // care of c tail processing if number of channels
        // is not divided by number of channels in block
         L(ur_bc_tail_label);
-        if (jpp.c_tail != 0) prepare_tail_mask();
+        if (jpp.c_tail != 0) io_.prepare_tail_mask();
         perform_ker(jpp.ur_bc_tail, jpp.c_tail != 0);
 
         L(finish_label);
@@ -1771,7 +1517,7 @@ void jit_uni_pool_kernel<isa>::generate() {
         jmp(finish_label, T_NEAR);
 
         L(c_tail_processing_label);
-        prepare_tail_mask();
+        io_.prepare_tail_mask();
         perform_ker(jpp.ur_bc, true);
 
         L(finish_label);
@@ -1781,17 +1527,9 @@ void jit_uni_pool_kernel<isa>::generate() {
 
     if (jpp.with_eltwise && postops_injector_)
        postops_injector_->prepare_table(/* generate = */ true);
-
-    if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) {
-        align(64);
-        L(idx_table);
-        const uint16_t _idx[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7,
-                8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15};
-        for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
-            dw(_idx[i]);
-    }
     if (f8_e5m2_emu_) f8_e5m2_emu_->prepare_table();
     if (f8_e4m3_emu_) f8_e4m3_emu_->prepare_table();
+    io_.prepare_table_fp8();
 }
 
 template struct jit_uni_pool_kernel;
diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp
index 221cdc7acb3..6f239882d04 100644
--- a/src/cpu/x64/jit_uni_pool_kernel.hpp
+++ b/src/cpu/x64/jit_uni_pool_kernel.hpp
@@ -27,14 +27,13 @@
 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
 #include "cpu/x64/jit_generator.hpp"
 #include "cpu/x64/jit_primitive_conf.hpp"
+#include "cpu/x64/utils/jit_io_helper.hpp"
 
 namespace dnnl {
 namespace impl {
 namespace cpu {
 namespace x64 {
 
-struct bf16_emulation_t;
-
 template <cpu_isa_t isa>
 struct jit_uni_pool_kernel : public jit_generator {
 
@@ -81,9 +80,8 @@ struct jit_uni_pool_kernel : public jit_generator {
     Ymm ymm_tmp_1 = Ymm(0);
     Vmm vmm_tmp_1 = Vmm(0);
 
-    // Used only for avx and if c tail is present
+    // Used only for avx and if c tail is present; shared with
+    // jit_io_multi_dt_helper_t
     Vmm vmm_c_tail_mask = Vmm(2);
-    Xmm xmm_c_tail_mask = Xmm(2);
 
     Vmm vmm_ker_area_h = Vmm(2);
     Vmm vmm_one = Vmm(2);
@@ -92,14 +90,6 @@ struct jit_uni_pool_kernel : public jit_generator {
 
     Vmm vmm_k_offset = Vmm(1);
 
-    // Used only for avx512 when bf16 is present
-    inline Vmm vmm_idx() {
-        if (!jpp.is_backward) {
-            return (jpp.is_training) ? Vmm(4) : Vmm(1);
-        } else
-            return Vmm(4);
-    }
-
     Zmm bf16_emu_reserv_1 = Zmm(5);
     Zmm bf16_emu_reserv_2 = Zmm(6);
     Zmm bf16_emu_reserv_3 = Zmm(7);
@@ -114,33 +104,23 @@ struct jit_uni_pool_kernel : public jit_generator {
     Reg64 fp8_emu_reg64 = bf16_emu_reserv_4;
     Xbyak::Opmask fp8_tmp_mask = Xbyak::Opmask(3);
 
-    Opmask k_c_tail_mask = Opmask(4);
-    Opmask k_mask_cvt = Opmask(5);
-    Opmask k_store_mask = Opmask(6);
-
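+    // abi_param1/abi_not_param1 come from jit_generator and resolve to
+    // rdi/rcx on Unix-like systems and rcx/rdi on Windows; with maskmovdqu
+    // gone there is no rdi constraint, so the manual Windows register swap
+    // in generate() could be dropped as well.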
-    // Here be some (tame) dragons. This kernel does not follow the regular
-    // OS-agnostic ABI pattern because when isa is sse41 it uses maskmovdqu
-    // instruction which has its destination hardcoded in rdi. Therefore:
-    // - all registers are hardcoded
-    // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI
-    //
-    // While this is only required by the backward pass, the quirk above
-    // is applied to the forward pass as well to keep things simpler.
+    // Shared with jit_io_multi_dt_helper_t and jit_uni_postops_injector_t.
+    Opmask k_c_tail_mask = Opmask(4);
+    Opmask k_store_mask = Opmask(5);
 
     using reg64_t = const Reg64;
-    reg64_t reg_param = rdi; // Always mimic the Unix ABI
+    reg64_t reg_param = abi_param1;
     reg64_t reg_input = r8;
     reg64_t aux_reg_input = r9;
     reg64_t reg_index = r10;
     reg64_t reg_output = r12;
     reg64_t reg_kd_pad_shift = r13;
-    reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu
 
     reg64_t kj = r14;
     reg64_t oi_iter = r15;
     reg64_t reg_kh = rax;
     reg64_t reg_k_shift = rbx;
-    reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above
+    reg64_t tmp_gpr = abi_not_param1;
     reg64_t reg_ker_area_h = rdx;
     reg64_t reg_nbc = rsi;
 
@@ -158,7 +138,6 @@ struct jit_uni_pool_kernel : public jit_generator {
 
     int prev_kw;
 
-    void prepare_tail_mask();
     void put_one_in_vmm();
     void uni_broadcast_reg_val(const int reg_idx, const int vmm_idx);
     void push_vmm_val(const int idx);
@@ -167,6 +146,10 @@ struct jit_uni_pool_kernel : public jit_generator {
             const int offset, const bool is_c_tail_proccessing);
     void store(const data_type_t dt, const int idx, const reg64_t &reg_ptr,
             const int offset, const bool is_c_tail_proccessing);
+    void pad_with_zeros(int idx);
+    void load_indices(int indr_i, int step_index, bool is_c_tail_processing);
+    void store_indices(int indr_i, int step_index, bool is_c_tail_processing,
+            bool is_first_w_block);
 
     void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r,
             bool with_c_tail_proccessing);
@@ -271,11 +254,11 @@ struct jit_uni_pool_kernel : public jit_generator {
         return jpp.is_fp8 && is_superset(isa, avx512_core_fp16);
     }
 
-    std::unique_ptr<bf16_emulation_t> bf16_emu_;
     std::unique_ptr<fp8_emulation_e5m2_t> f8_e5m2_emu_;
     std::unique_ptr<fp8_emulation_e4m3_t> f8_e4m3_emu_;
     std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
             postops_injector_;
+    io::jit_io_multi_dt_helper_t<Vmm> io_;
 };
 
 } // namespace x64
diff --git a/src/cpu/x64/utils/jit_io_helper.cpp b/src/cpu/x64/utils/jit_io_helper.cpp
index d8e06455273..63c1a4d86cd 100644
--- a/src/cpu/x64/utils/jit_io_helper.cpp
+++ b/src/cpu/x64/utils/jit_io_helper.cpp
@@ -206,7 +206,7 @@ bool jit_io_helper_t<Vmm>::is_data_type_supported(const data_type_t dt) {
         case data_type::f16:
             return is_superset(isa_, avx512_core_fp16) || isa_ == avx2_vnni_2;
         case data_type::f8_e4m3:
-        case data_type::f8_e5m2: return is_superset(isa_, avx512_core_amx);
+        case data_type::f8_e5m2: return is_superset(isa_, avx512_core_fp16);
         default: assert(!"Unsupported data type");
     }
     return false;
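Reviewer note (not part of the patch): the pattern every call site above migrates to is sketched below. The names example_kernel_t and src_dt_ are hypothetical; the helper types, constructor argument order, and load/store signatures are taken from the calls visible in this diff, and the trailing optional configs are omitted on the assumption that they default to empty.

// Illustrative sketch only; compiles inside oneDNN's cpu/x64 JIT context,
// not as a standalone program.
struct example_kernel_t : public jit_generator {
    data_type_t src_dt_;
    io::jit_io_multi_dt_helper_t<Xbyak::Zmm> io_;

    example_kernel_t(data_type_t src_dt, int c_block, int c_tail)
        : jit_generator("example_kernel", avx512_core), src_dt_(src_dt) {
        // One tail descriptor: opmask k4 and vmm2 back the shared c-tail mask.
        const io::io_tail_conf_t tail_conf(c_block, c_tail,
                /*tail_opmask_idx=*/4, /*tail_vmm_mask_idx=*/2,
                /*reg_tmp=*/rcx);
        io_ = io::jit_io_multi_dt_helper_t<Xbyak::Zmm>(
                this, avx512_core, {src_dt, data_type::f32}, {}, tail_conf);
    }

    void generate() override {
        preamble();
        io_.prepare_tail_mask(); // emit the mask setup once per kernel
        // Data-type-aware load: converts src_dt_ to f32, honoring the tail.
        io_[src_dt_]->load(ptr[abi_param1], Xbyak::Zmm(0), /*tail=*/true);
        io_[data_type::f32]->store(
                Xbyak::Zmm(0), ptr[abi_param2], /*tail=*/true);
        postamble();
    }
};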