cpu: x64: pooling: refactor to use io injector (step 1)
asimonov1 committed Feb 5, 2025
1 parent 6149300 commit d98dd7d
Showing 3 changed files with 43 additions and 234 deletions.
268 changes: 40 additions & 228 deletions src/cpu/x64/jit_uni_pool_kernel.cpp
@@ -45,11 +45,7 @@ jit_uni_pool_kernel<isa>::~jit_uni_pool_kernel() = default;
template <cpu_isa_t isa>
jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md)
: jit_generator(jit_name(), isa), jpp(ajpp), bf16_emu_(nullptr) {
if (use_bf16_emulation())
bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
bf16_emu_reserv_4, bf16_emu_reserv_5);
: jit_generator(jit_name(), isa), jpp(ajpp) {

bool has_f8_e5m2_binary_postops = false;
bool has_f8_e4m3_binary_postops = false;
@@ -109,6 +105,31 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
= utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
this, jpp.post_ops, bsp);
}

io::io_tail_conf_t io_tail_conf(jpp.c_block, jpp.c_tail,
k_c_tail_mask.getIdx(), vmm_c_tail_mask.getIdx(), tmp_gpr);

utils::optional_t<io::io_emu_bf16_conf_t> io_bf16_conf;
if (use_bf16_emulation())
io_bf16_conf = io::io_emu_bf16_conf_t(bf16_emu_reserv_1,
bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4,
bf16_emu_reserv_5);

utils::optional_t<io::io_emu_fp8_conf_t> io_fp8_conf;
if (use_fp8_emulation() || has_f8_e5m2_binary_postops
|| has_f8_e4m3_binary_postops)
io_fp8_conf = io::io_emu_fp8_conf_t(fp8_emu_reserv_1, fp8_emu_reserv_2,
fp8_emu_reserv_3, fp8_emu_reserv_4, fp8_emu_reserv_5,
fp8_tmp_mask, fp8_emu_reg64);

using io_mdt_helper = io::jit_io_multi_dt_helper_t<Vmm>;

typename io_mdt_helper::data_types_t dtypes = {jpp.src_dt, jpp.dst_dt};
if (jpp.ind_dt != data_type::undef) dtypes.insert(jpp.ind_dt);
if (jpp.needs_f32_accum_for_bf16) dtypes.insert(data_type::f32);

io_ = io_mdt_helper(this, jpp.isa, dtypes, {}, io_tail_conf, io_bf16_conf,
{}, utils::nullopt, io_fp8_conf);
}
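For context, the new block replaces per-feature members (bf16_emu_, and later the fp8 emulators) with one io::jit_io_multi_dt_helper_t that owns the tail mask and all emulation state; passing an unengaged optional (utils::nullopt in the diff) simply disables that feature. A minimal host-side sketch of the same optional-config idiom, using std::optional in place of oneDNN's utils::optional_t (the names below are illustrative, not the library API):

#include <cstdint>
#include <cstdio>
#include <optional>

// Hypothetical stand-ins for the injector config structs.
struct tail_conf_t { int simd_w; int tail_len; };
struct bf16_conf_t { int n_reserved_vregs; };

struct io_helper_t {
    tail_conf_t tail;
    std::optional<bf16_conf_t> bf16; // engaged only when emulation is needed

    void prepare_tail_mask() const {
        // Low tail_len bits set -- the mask the removed prepare_tail_mask()
        // materialized with kmovw on avx512_core.
        uint64_t mask = (1ULL << tail.tail_len) - 1ULL;
        std::printf("tail mask = 0x%llx\n", (unsigned long long)mask);
    }
};

int main() {
    bool use_bf16_emulation = true; // bf16 data, no native bf16 on the ISA
    std::optional<bf16_conf_t> bf16_conf;
    if (use_bf16_emulation) bf16_conf = bf16_conf_t{5};

    io_helper_t io{{16, 7}, bf16_conf};
    io.prepare_tail_mask(); // prints 0x7f: 7 active lanes of 16
    return 0;
}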

static status_t set_binary_postops_formats(
@@ -463,30 +484,6 @@ static int reg_ind(int shift, int bc, int j, int ur_bc, int ur_w) noexcept {
return shift * ur_bc * ur_w + bc * ur_w + j;
};

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::prepare_tail_mask() {
if (is_superset(isa, avx512_core)) {
size_t c_tail_mask = (1ULL << jpp.c_tail) - 1ULL;
mov(tmp_gpr.cvt32(), c_tail_mask);
kmovw(k_c_tail_mask, tmp_gpr.cvt32());
} else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
constexpr int max_words_in_ymm = 8;

// for 'avx2_vnni_2' mask works with 2 x xf16 elements,
// in case of 'c_tail % 2 != 0' load/store an additional word
// for the remaining element.
auto dt_elem_div = isa == avx2_vnni_2 ? 2 : 1;
auto mask_offset = max_words_in_ymm - (jpp.c_tail / dt_elem_div);
auto mask_register
= isa == avx2_vnni_2 ? xmm_c_tail_mask : vmm_c_tail_mask;
static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0,
0, 0, 0, 0, 0, 0, 0};
mov(tmp_gpr, reinterpret_cast<size_t>(&mask[mask_offset]));
vmovups(mask_register, ptr[tmp_gpr]);
}
}
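The deleted helper encoded two masking schemes: on avx512_core, a k-register bitmask with the low c_tail bits set; on AVX/AVX2, a ymm mask read through a sliding window over eight all-ones dwords followed by eight zeros. A standalone sketch of that arithmetic (plain C++ evaluated at run time here, where the kernel emitted it at JIT time):

#include <cstdint>
#include <cstdio>

// avx512_core path: a k-register bitmask with the low c_tail bits set.
uint64_t k_tail_mask(int c_tail) { return (1ULL << c_tail) - 1ULL; }

// avx/avx2 path: slide a window over eight all-ones dwords followed by
// eight zeros, so the first c_tail lanes of the ymm mask are active.
// avx2_vnni_2 packs two xf16 elements per dword, hence the division by 2.
void ymm_tail_mask(int c_tail, bool is_avx2_vnni_2, uint32_t out[8]) {
    static const uint32_t table[16] = {0xffffffff, 0xffffffff, 0xffffffff,
            0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0, 0,
            0, 0, 0, 0, 0, 0};
    const int dt_elem_div = is_avx2_vnni_2 ? 2 : 1;
    const int offset = 8 - c_tail / dt_elem_div;
    for (int i = 0; i < 8; ++i) out[i] = table[offset + i];
}

int main() {
    std::printf("k-mask, c_tail=7: 0x%llx\n",
            (unsigned long long)k_tail_mask(7)); // 0x7f
    uint32_t m[8];
    ymm_tail_mask(7, /*is_avx2_vnni_2=*/false, m);
    for (uint32_t v : m) std::printf("%d", v ? 1 : 0); // 11111110
    std::printf("\n");
    return 0;
}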

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::put_one_in_vmm() {
mov(tmp_gpr, 1);
@@ -518,69 +515,8 @@ template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::load(const data_type_t dt, const int idx,
const reg64_t &reg_ptr, const int offset,
const bool is_c_tail_proccessing) {
if (dt == data_type::bf16) {
/*TODO: maybe use vpmovzxwd + vpslld,
* in order to free up vmm_idx() register */
if (is_c_tail_proccessing && !jpp.is_c_padded) {
Vmm vmm_to_load = Vmm(idx) | k_c_tail_mask | T_z;
vpmovzxwd(vmm_to_load, ptr[reg_ptr + offset]);
vpslld(vmm_to_load, vmm_to_load, 16);
} else {
vmovups(Ymm(idx), ptr[reg_ptr + offset]);
vpermw(Vmm(idx) | k_mask_cvt | T_z, vmm_idx(), Vmm(idx));
}
} else if (dt == data_type::f16) {
Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
? Vmm(idx) | k_c_tail_mask | T_z
: Vmm(idx);
vcvtph2psx(vmm_to_load, ptr[reg_ptr + offset]);
} else if (utils::one_of(dt, data_type::f8_e5m2, data_type::f8_e4m3)) {
Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
? Vmm(idx) | k_c_tail_mask | T_z
: Vmm(idx);
if (dt == data_type::f8_e5m2)
f8_e5m2_emu_->vcvt_f8_to_f32(vmm_to_load, ptr[reg_ptr + offset]);
else if (dt == data_type::f8_e4m3)
f8_e4m3_emu_->vcvt_f8_to_f32(vmm_to_load, ptr[reg_ptr + offset]);
} else {
if (is_c_tail_proccessing && !jpp.is_c_padded) {
if (isa == avx || isa == avx2) {
vmaskmovps(Vmm(idx), vmm_c_tail_mask, ptr[reg_ptr + offset]);
} else {
vmovups(Zmm(idx) | k_c_tail_mask | T_z, ptr[reg_ptr + offset]);
}
} else {
uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]);
}
}
}

template <>
inline void jit_uni_pool_kernel<avx2_vnni_2>::load(const data_type_t dt,
const int idx, const reg64_t &reg_ptr, const int offset,
const bool is_c_tail_proccessing) {
if (is_c_tail_proccessing) {
vmaskmovps(Xmm(idx), xmm_c_tail_mask, ptr[reg_ptr + offset]);
if (jpp.c_tail % 2 != 0) {
const int tail_pos = jpp.c_tail - 1;
auto word_addr
= ptr[reg_ptr + offset + tail_pos * sizeof(bfloat16_t)];
vpinsrw(Xmm(idx), Xmm(idx), word_addr, tail_pos);
}
}
if (dt == data_type::bf16) {
if (is_c_tail_proccessing)
vpmovzxwd(Ymm(idx), Xmm(idx));
else
vpmovzxwd(Ymm(idx), ptr[reg_ptr + offset]);
vpslld(Ymm(idx), Ymm(idx), 16);
} else if (dt == data_type::f16) {
if (is_c_tail_proccessing)
vcvtph2ps(Ymm(idx), Xmm(idx));
else
vcvtph2ps(Ymm(idx), ptr[reg_ptr + offset]);
} else
assert(!"invalid data type");
io_[dt]->load(vmmword[reg_ptr + offset], Vmm(idx),
is_c_tail_proccessing && !jpp.is_c_padded);
}
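The removed bf16 branch above documents the whole trick of loading bf16: zero-extend each 16-bit element (vpmovzxwd) and shift left by 16 (vpslld), since a bf16 value is exactly the high half of the corresponding f32 bit pattern. A scalar model of that up-conversion:

#include <cstdint>
#include <cstring>
#include <cstdio>

// bf16 -> f32: place the 16 bf16 bits in the top half of a 32-bit word --
// the lane-wise equivalent of vpmovzxwd + vpslld(..., 16).
float bf16_to_f32(uint16_t v) {
    uint32_t bits = (uint32_t)v << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    std::printf("%g\n", bf16_to_f32(0x3f80)); // 1
    std::printf("%g\n", bf16_to_f32(0x4049)); // 3.140625
    return 0;
}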

template <>
@@ -599,64 +535,16 @@ template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::store(const data_type_t dt, const int idx,
const reg64_t &reg_ptr, const int offset,
const bool is_c_tail_proccessing) {
if (utils::one_of(dt, data_type::bf16, data_type::f16)) {
if (is_c_tail_proccessing) {
if (jpp.is_c_padded) {
vmovdqu16(Ymm(idx) | k_c_tail_mask | T_z, Ymm(idx));
vmovups(yword[reg_ptr + offset], Ymm(idx));
} else
vmovdqu16(ptr[reg_ptr + offset] | k_c_tail_mask, Ymm(idx));
} else
vmovups(yword[reg_ptr + offset], Ymm(idx));
} else if (utils::one_of(dt, data_type::f8_e5m2, data_type::f8_e4m3)) {
if (is_c_tail_proccessing) {
if (jpp.is_c_padded) {
vmovdqu8(Xmm(idx) | k_c_tail_mask | T_z, Xmm(idx));
vmovdqu8(yword[reg_ptr + offset], Xmm(idx));
} else
vmovdqu8(ptr[reg_ptr + offset] | k_c_tail_mask, Xmm(idx));
} else
vmovdqu8(yword[reg_ptr + offset], Xmm(idx));
} else {
if (is_c_tail_proccessing) {
if (!jpp.is_c_padded) {
if (isa == avx || isa == avx2)
vmaskmovps(
ptr[reg_ptr + offset], vmm_c_tail_mask, Vmm(idx));
else
vmovups(ptr[reg_ptr + offset] | k_c_tail_mask, Zmm(idx));
} else {
if (jpp.with_postops) {
if (isa == avx || isa == avx2) {
uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
uni_vblendvps(
Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask);
} else
uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx));
}
uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
}
if (is_c_tail_proccessing && jpp.is_c_padded) {
if (isa == avx || isa == avx2) {
uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
uni_vblendvps(Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask);
} else
uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx));
}
}

template <>
inline void jit_uni_pool_kernel<avx2_vnni_2>::store(const data_type_t dt,
const int idx, const reg64_t &reg_ptr, const int offset,
const bool is_c_tail_proccessing) {
if (utils::one_of(dt, data_type::bf16, data_type::f16)) {
if (is_c_tail_proccessing) {
vmaskmovps(ptr[reg_ptr + offset], xmm_c_tail_mask, Xmm(idx));
if (jpp.c_tail % 2 != 0) {
const int tail_pos = jpp.c_tail - 1;
auto word_addr = ptr[reg_ptr + offset + tail_pos * 2];
vpextrw(word_addr, Xmm(idx), tail_pos);
}
} else
vmovups(xword[reg_ptr + offset], Xmm(idx));
} else
assert(!"datatype not supported");
io_[dt]->store(Vmm(idx), vmmword[reg_ptr + offset],
is_c_tail_proccessing && !jpp.is_c_padded);
}
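On the store side, the predicate handed to the injector (is_c_tail_proccessing && !jpp.is_c_padded) preserves the old policy: mask the tail store only when channels are not padded to a full block; with padding, store the whole vector, zeroing the tail lanes first when post-ops may have dirtied them. A scalar restatement of that policy (a sketch of the removed branches, not the injector's code):

#include <cstdio>

// Without channel padding only the first c_tail lanes may be written
// (vmaskmovps / k-masked stores). With a padded layout the full vector is
// stored, and padding lanes are zeroed when post-ops may hold garbage.
void store_tail(float *dst, const float *src, int lanes, int c_tail,
        bool c_padded, bool with_postops) {
    if (!c_padded) {
        for (int i = 0; i < c_tail; ++i) dst[i] = src[i]; // masked store
    } else {
        for (int i = 0; i < lanes; ++i)
            dst[i] = (i < c_tail || !with_postops) ? src[i] : 0.0f;
    }
}

int main() {
    float src[8] = {1, 2, 3, 4, 5, 6, 7, 8}, dst[8] = {0};
    store_tail(dst, src, 8, 5, /*c_padded=*/false, /*with_postops=*/false);
    for (float v : dst) std::printf("%g ", v); // 1 2 3 4 5 0 0 0
    std::printf("\n");
    return 0;
}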

template <>
@@ -887,24 +775,9 @@ inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
if (aux_input_offset >= iw * c_off) continue;
int input_offset = dt_size * aux_input_offset;
if (jpp.is_backward) {
auto inpyr = yreg(inpr_i);
load(jpp.src_dt, reg_idx(inpr_i), aux_reg_input,
input_offset, is_tail_processing(bci));
uni_vaddps(inpvr, inpvr, accvr);
if (jpp.is_bf16) {
if (!isa_has_bf16(jpp.isa))
bf16_emu_->vcvtneps2bf16(inpyr, zreg(inpr_i));
else
vcvtneps2bf16(inpyr, inpvr);
} else if (jpp.is_f16) {
vcvtps2ph(inpyr, inpvr, _op_mxcsr);
} else if (jpp.is_fp8) {
auto inpxr = xreg(inpr_i);
if (jpp.src_dt == data_type::f8_e5m2)
f8_e5m2_emu_->vcvt_f32_to_f8(inpxr, zreg(inpr_i));
else if (jpp.src_dt == data_type::f8_e4m3)
f8_e4m3_emu_->vcvt_f32_to_f8(inpxr, zreg(inpr_i));
}
store(jpp.src_dt, reg_idx(inpr_i), aux_reg_input,
input_offset, is_tail_processing(bci));
} else {
@@ -955,34 +828,8 @@ inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
for (int jj = 0; jj < ur_w; jj++) {
for (int bci = 0; bci < ur_bc; bci++) {
const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
const auto accvr = vreg(accr_i);
const auto output_offset
= dt_size * (jj * c_off + bci * c_block);
const auto accyr = yreg(accr_i);
if (jpp.is_bf16) {
if (isa == avx2_vnni_2) {
auto accxr = xreg(accr_i);
vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
} else {
const auto acczr = zreg(accr_i);
if (!isa_has_bf16(jpp.isa))
bf16_emu_->vcvtneps2bf16(accyr, acczr);
else
vcvtneps2bf16(accyr, accvr);
}
} else if (jpp.is_f16) {
if (isa == avx2_vnni_2) {
auto accxr = xreg(accr_i);
vcvtps2ph(accxr, accyr, _op_mxcsr);
} else
vcvtps2ph(accyr, accvr, _op_mxcsr);
} else if (jpp.is_fp8) {
const auto accxr = xreg(accr_i);
if (jpp.src_dt == data_type::f8_e5m2)
f8_e5m2_emu_->vcvt_f32_to_f8(accxr, accvr);
else if (jpp.src_dt == data_type::f8_e4m3)
f8_e4m3_emu_->vcvt_f32_to_f8(accxr, accvr);
}
store(jpp.dst_dt, reg_idx(accr_i), reg_output, output_offset,
is_tail_processing(bci));
}
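Each block deleted above performed the same job: down-convert the f32 accumulator to the destination type (bf16 via vcvtneps2bf16 or its emulation, f16 via vcvtps2ph, fp8 via the emulators) before a raw store; the injector's store path now owns that conversion. For reference, a scalar model of the round-to-nearest-even rounding that vcvtneps2bf16 applies on normal values (NaN handling elided):

#include <cstdint>
#include <cstring>
#include <cstdio>

// Scalar reference for f32 -> bf16 with round-to-nearest-even.
uint16_t f32_to_bf16_rne(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    uint32_t lsb = (bits >> 16) & 1u; // lsb of the mantissa that is kept
    bits += 0x7fffu + lsb;            // bias so that ties round to even
    return (uint16_t)(bits >> 16);
}

int main() {
    float x = 1.0f + 3.0f / 512.0f; // f32 bits 0x3f80c000
    std::printf("0x%04x\n", f32_to_bf16_rne(x)); // 0x3f81: rounded up
    return 0;
}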
@@ -1129,34 +976,7 @@ inline void jit_uni_pool_kernel<isa>::max_step_fwd(int ur_w, int ur_bc,
for_(int jj = 0; jj < ur_w; jj++)
for (int bci = 0; bci < ur_bc; bci++) {
const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
const auto accvr = vreg(accr_i);
const auto output_offset = jpp.dt_size * (jj * c_off + bci * c_block);
auto accyr = yreg(accr_i);
if (jpp.is_bf16) {
if (isa == avx2_vnni_2) {
auto accxr = xreg(accr_i);
vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
} else {
auto acczr = zreg(accr_i);
if (!isa_has_bf16(jpp.isa))
bf16_emu_->vcvtneps2bf16(accyr, acczr);
else
vcvtneps2bf16(accyr, accvr);
}
} else if (jpp.is_f16) {
if (isa == avx2_vnni_2) {
auto accxr = xreg(accr_i);
vcvtps2ph(accxr, accyr, _op_mxcsr);
} else
vcvtps2ph(accyr, accvr, _op_mxcsr);
} else if (jpp.is_fp8) {
auto accxr = xreg(accr_i);
auto acczr = zreg(accr_i);
if (jpp.src_dt == data_type::f8_e5m2)
f8_e5m2_emu_->vcvt_f32_to_f8(accxr, acczr);
else if (jpp.src_dt == data_type::f8_e4m3)
f8_e4m3_emu_->vcvt_f32_to_f8(accxr, acczr);
}
store(jpp.dst_dt, reg_idx(accr_i), reg_output, output_offset,
is_tail_processing(bci));

@@ -1416,19 +1236,9 @@ inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
vmaskmovps(
vmmword[aux_reg_input + inp_offset], cvtvr, inpvr);
} else {
auto indzr = zreg(inpr_i);
auto indyr = yreg(inpr_i);
vpcmpeqd(k_store_mask, indvr, vmm_k_offset);
vblendmps(vmm_tmp | k_store_mask | T_z, outvr, outvr);
vaddps(inpvr, inpvr, vmm_tmp);
if (jpp.is_bf16 && !jpp.needs_f32_accum_for_bf16) {
if (!isa_has_bf16(jpp.isa))
bf16_emu_->vcvtneps2bf16(indyr, indzr);
else
vcvtneps2bf16(indyr, inpvr);
} else if (jpp.is_f16) {
vcvtps2ph(indyr, inpvr, _op_mxcsr);
}
store(input_dt, inpvr.getIdx(), aux_reg_input, inp_offset,
is_tail_processing(bci));
}
@@ -1592,7 +1402,8 @@ void jit_uni_pool_kernel<isa>::generate() {
xor_(rcx, rdi);
xor_(rdi, rcx);
#endif
if (use_bf16_emulation()) bf16_emu_->init_vcvtneps2bf16();

if (use_bf16_emulation()) io_.init_bf16();

mov(reg_input, ptr[reg_param + GET_OFF(src)]);
mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
@@ -1763,15 +1574,15 @@ void jit_uni_pool_kernel<isa>::generate() {
// care of c tail processing if number of channels
// is not divided by number of channels in block
L(ur_bc_tail_label);
if (jpp.c_tail != 0) prepare_tail_mask();
if (jpp.c_tail != 0) io_.prepare_tail_mask();
perform_ker(jpp.ur_bc_tail, jpp.c_tail != 0);

L(finish_label);
} else if (jpp.c_tail != 0) {
jmp(finish_label, T_NEAR);

L(c_tail_processing_label);
prepare_tail_mask();
io_.prepare_tail_mask();
perform_ker(jpp.ur_bc, true);

L(finish_label);
@@ -1792,6 +1603,7 @@
}
if (f8_e5m2_emu_) f8_e5m2_emu_->prepare_table();
if (f8_e4m3_emu_) f8_e4m3_emu_->prepare_table();
io_.prepare_table_fp8();
}
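io_.prepare_table_fp8() follows the same pattern as the emulators' prepare_table() calls above it: constants referenced by the fp8 conversion sequences are emitted into the code buffer after the kernel body. As a reminder of the format those sequences decode, a scalar model of f8_e5m2 (a sketch of the format only, not the emulator's instruction sequence):

#include <cmath>
#include <cstdint>
#include <cstdio>

// f8_e5m2: sign 1, exponent 5 with bias 15, mantissa 2 (Inf/NaN elided).
// Every 8-bit input maps to a fixed f32, which is why the conversion can
// lean on constants emitted after the kernel body.
float f8_e5m2_to_f32(uint8_t v) {
    const int sign = (v >> 7) & 1, exp = (v >> 2) & 0x1f, man = v & 3;
    const float mag = exp == 0
            ? std::ldexp(man / 4.0f, -14)              // subnormal
            : std::ldexp(1.0f + man / 4.0f, exp - 15); // normal
    return sign ? -mag : mag;
}

int main() {
    std::printf("%g %g %g\n", f8_e5m2_to_f32(0x3c), // 1
            f8_e5m2_to_f32(0xbc),                   // -1
            f8_e5m2_to_f32(0x04));                  // 2^-14, smallest normal
    return 0;
}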

template struct jit_uni_pool_kernel<sse41>;