Skip to content

Commit d98dd7d

Browse files
committed
cpu: x64: pooling: refactor to use io injector (step 1)
1 parent 6149300 commit d98dd7d

File tree

3 files changed

+43
-234
lines changed

3 files changed

+43
-234
lines changed

src/cpu/x64/jit_uni_pool_kernel.cpp

+40-228
Original file line number | Diff line number | Diff line change
@@ -45,11 +45,7 @@ jit_uni_pool_kernel<isa>::~jit_uni_pool_kernel() = default;
4545
template <cpu_isa_t isa>
4646
jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
4747
const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md)
48-
: jit_generator(jit_name(), isa), jpp(ajpp), bf16_emu_(nullptr) {
49-
if (use_bf16_emulation())
50-
bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
51-
bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
52-
bf16_emu_reserv_4, bf16_emu_reserv_5);
48+
: jit_generator(jit_name(), isa), jpp(ajpp) {
5349

5450
bool has_f8_e5m2_binary_postops = false;
5551
bool has_f8_e4m3_binary_postops = false;
@@ -109,6 +105,31 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
109105
= utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
110106
this, jpp.post_ops, bsp);
111107
}
108+
109+
io::io_tail_conf_t io_tail_conf(jpp.c_block, jpp.c_tail,
110+
k_c_tail_mask.getIdx(), vmm_c_tail_mask.getIdx(), tmp_gpr);
111+
112+
utils::optional_t<io::io_emu_bf16_conf_t> io_bf16_conf;
113+
if (use_bf16_emulation())
114+
io_bf16_conf = io::io_emu_bf16_conf_t(bf16_emu_reserv_1,
115+
bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4,
116+
bf16_emu_reserv_5);
117+
118+
utils::optional_t<io::io_emu_fp8_conf_t> io_fp8_conf;
119+
if (use_fp8_emulation() || has_f8_e5m2_binary_postops
120+
|| has_f8_e4m3_binary_postops)
121+
io_fp8_conf = io::io_emu_fp8_conf_t(fp8_emu_reserv_1, fp8_emu_reserv_2,
122+
fp8_emu_reserv_3, fp8_emu_reserv_4, fp8_emu_reserv_5,
123+
fp8_tmp_mask, fp8_emu_reg64);
124+
125+
using io_mdt_helper = io::jit_io_multi_dt_helper_t<Vmm>;
126+
127+
typename io_mdt_helper::data_types_t dtypes = {jpp.src_dt, jpp.dst_dt};
128+
if (jpp.ind_dt != data_type::undef) dtypes.insert(jpp.ind_dt);
129+
if (jpp.needs_f32_accum_for_bf16) dtypes.insert(data_type::f32);
130+
131+
io_ = io_mdt_helper(this, jpp.isa, dtypes, {}, io_tail_conf, io_bf16_conf,
132+
{}, utils::nullopt, io_fp8_conf);
112133
}
113134

114135
static status_t set_binary_postops_formats(
@@ -463,30 +484,6 @@ static int reg_ind(int shift, int bc, int j, int ur_bc, int ur_w) noexcept {
463484
return shift * ur_bc * ur_w + bc * ur_w + j;
464485
};
465486

466-
template <cpu_isa_t isa>
467-
inline void jit_uni_pool_kernel<isa>::prepare_tail_mask() {
468-
if (is_superset(isa, avx512_core)) {
469-
size_t c_tail_mask = (1ULL << jpp.c_tail) - 1ULL;
470-
mov(tmp_gpr.cvt32(), c_tail_mask);
471-
kmovw(k_c_tail_mask, tmp_gpr.cvt32());
472-
} else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
473-
constexpr int max_words_in_ymm = 8;
474-
475-
// for 'avx2_vnni_2' mask works with 2 x xf16 elements,
476-
// in case of 'c_tail % 2 != 0' load/store an additional word
477-
// for the remaining element.
478-
auto dt_elem_div = isa == avx2_vnni_2 ? 2 : 1;
479-
auto mask_offset = max_words_in_ymm - (jpp.c_tail / dt_elem_div);
480-
auto mask_register
481-
= isa == avx2_vnni_2 ? xmm_c_tail_mask : vmm_c_tail_mask;
482-
static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
483-
0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0,
484-
0, 0, 0, 0, 0, 0, 0};
485-
mov(tmp_gpr, reinterpret_cast<size_t>(&mask[mask_offset]));
486-
vmovups(mask_register, ptr[tmp_gpr]);
487-
}
488-
}
489-
490487
template <cpu_isa_t isa>
491488
inline void jit_uni_pool_kernel<isa>::put_one_in_vmm() {
492489
mov(tmp_gpr, 1);
@@ -518,69 +515,8 @@ template <cpu_isa_t isa>
518515
inline void jit_uni_pool_kernel<isa>::load(const data_type_t dt, const int idx,
519516
const reg64_t &reg_ptr, const int offset,
520517
const bool is_c_tail_proccessing) {
521-
if (dt == data_type::bf16) {
522-
/*TODO: maybe use vpmovzxwd + vpslld,
523-
* in order to free up vmm_idx() register */
524-
if (is_c_tail_proccessing && !jpp.is_c_padded) {
525-
Vmm vmm_to_load = Vmm(idx) | k_c_tail_mask | T_z;
526-
vpmovzxwd(vmm_to_load, ptr[reg_ptr + offset]);
527-
vpslld(vmm_to_load, vmm_to_load, 16);
528-
} else {
529-
vmovups(Ymm(idx), ptr[reg_ptr + offset]);
530-
vpermw(Vmm(idx) | k_mask_cvt | T_z, vmm_idx(), Vmm(idx));
531-
}
532-
} else if (dt == data_type::f16) {
533-
Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
534-
? Vmm(idx) | k_c_tail_mask | T_z
535-
: Vmm(idx);
536-
vcvtph2psx(vmm_to_load, ptr[reg_ptr + offset]);
537-
} else if (utils::one_of(dt, data_type::f8_e5m2, data_type::f8_e4m3)) {
538-
Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
539-
? Vmm(idx) | k_c_tail_mask | T_z
540-
: Vmm(idx);
541-
if (dt == data_type::f8_e5m2)
542-
f8_e5m2_emu_->vcvt_f8_to_f32(vmm_to_load, ptr[reg_ptr + offset]);
543-
else if (dt == data_type::f8_e4m3)
544-
f8_e4m3_emu_->vcvt_f8_to_f32(vmm_to_load, ptr[reg_ptr + offset]);
545-
} else {
546-
if (is_c_tail_proccessing && !jpp.is_c_padded) {
547-
if (isa == avx || isa == avx2) {
548-
vmaskmovps(Vmm(idx), vmm_c_tail_mask, ptr[reg_ptr + offset]);
549-
} else {
550-
vmovups(Zmm(idx) | k_c_tail_mask | T_z, ptr[reg_ptr + offset]);
551-
}
552-
} else {
553-
uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]);
554-
}
555-
}
556-
}
557-
558-
template <>
559-
inline void jit_uni_pool_kernel<avx2_vnni_2>::load(const data_type_t dt,
560-
const int idx, const reg64_t &reg_ptr, const int offset,
561-
const bool is_c_tail_proccessing) {
562-
if (is_c_tail_proccessing) {
563-
vmaskmovps(Xmm(idx), xmm_c_tail_mask, ptr[reg_ptr + offset]);
564-
if (jpp.c_tail % 2 != 0) {
565-
const int tail_pos = jpp.c_tail - 1;
566-
auto word_addr
567-
= ptr[reg_ptr + offset + tail_pos * sizeof(bfloat16_t)];
568-
vpinsrw(Xmm(idx), Xmm(idx), word_addr, tail_pos);
569-
}
570-
}
571-
if (dt == data_type::bf16) {
572-
if (is_c_tail_proccessing)
573-
vpmovzxwd(Ymm(idx), Xmm(idx));
574-
else
575-
vpmovzxwd(Ymm(idx), ptr[reg_ptr + offset]);
576-
vpslld(Ymm(idx), Ymm(idx), 16);
577-
} else if (dt == data_type::f16) {
578-
if (is_c_tail_proccessing)
579-
vcvtph2ps(Ymm(idx), Xmm(idx));
580-
else
581-
vcvtph2ps(Ymm(idx), ptr[reg_ptr + offset]);
582-
} else
583-
assert(!"invalid data type");
518+
io_[dt]->load(vmmword[reg_ptr + offset], Vmm(idx),
519+
is_c_tail_proccessing && !jpp.is_c_padded);
584520
}
585521

586522
template <>
@@ -599,64 +535,16 @@ template <cpu_isa_t isa>
599535
inline void jit_uni_pool_kernel<isa>::store(const data_type_t dt, const int idx,
600536
const reg64_t &reg_ptr, const int offset,
601537
const bool is_c_tail_proccessing) {
602-
if (utils::one_of(dt, data_type::bf16, data_type::f16)) {
603-
if (is_c_tail_proccessing) {
604-
if (jpp.is_c_padded) {
605-
vmovdqu16(Ymm(idx) | k_c_tail_mask | T_z, Ymm(idx));
606-
vmovups(yword[reg_ptr + offset], Ymm(idx));
607-
} else
608-
vmovdqu16(ptr[reg_ptr + offset] | k_c_tail_mask, Ymm(idx));
609-
} else
610-
vmovups(yword[reg_ptr + offset], Ymm(idx));
611-
} else if (utils::one_of(dt, data_type::f8_e5m2, data_type::f8_e4m3)) {
612-
if (is_c_tail_proccessing) {
613-
if (jpp.is_c_padded) {
614-
vmovdqu8(Xmm(idx) | k_c_tail_mask | T_z, Xmm(idx));
615-
vmovdqu8(yword[reg_ptr + offset], Xmm(idx));
616-
} else
617-
vmovdqu8(ptr[reg_ptr + offset] | k_c_tail_mask, Xmm(idx));
618-
} else
619-
vmovdqu8(yword[reg_ptr + offset], Xmm(idx));
620-
} else {
621-
if (is_c_tail_proccessing) {
622-
if (!jpp.is_c_padded) {
623-
if (isa == avx || isa == avx2)
624-
vmaskmovps(
625-
ptr[reg_ptr + offset], vmm_c_tail_mask, Vmm(idx));
626-
else
627-
vmovups(ptr[reg_ptr + offset] | k_c_tail_mask, Zmm(idx));
628-
} else {
629-
if (jpp.with_postops) {
630-
if (isa == avx || isa == avx2) {
631-
uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
632-
uni_vblendvps(
633-
Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask);
634-
} else
635-
uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx));
636-
}
637-
uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
638-
}
538+
if (is_c_tail_proccessing && jpp.is_c_padded) {
539+
if (isa == avx || isa == avx2) {
540+
uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
541+
uni_vblendvps(Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask);
639542
} else
640-
uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
543+
uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx));
641544
}
642-
}
643545

644-
template <>
645-
inline void jit_uni_pool_kernel<avx2_vnni_2>::store(const data_type_t dt,
646-
const int idx, const reg64_t &reg_ptr, const int offset,
647-
const bool is_c_tail_proccessing) {
648-
if (utils::one_of(dt, data_type::bf16, data_type::f16)) {
649-
if (is_c_tail_proccessing) {
650-
vmaskmovps(ptr[reg_ptr + offset], xmm_c_tail_mask, Xmm(idx));
651-
if (jpp.c_tail % 2 != 0) {
652-
const int tail_pos = jpp.c_tail - 1;
653-
auto word_addr = ptr[reg_ptr + offset + tail_pos * 2];
654-
vpextrw(word_addr, Xmm(idx), tail_pos);
655-
}
656-
} else
657-
vmovups(xword[reg_ptr + offset], Xmm(idx));
658-
} else
659-
assert(!"datatype not supported");
546+
io_[dt]->store(Vmm(idx), vmmword[reg_ptr + offset],
547+
is_c_tail_proccessing && !jpp.is_c_padded);
660548
}
661549

662550
template <>
@@ -887,24 +775,9 @@ inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
887775
if (aux_input_offset >= iw * c_off) continue;
888776
int input_offset = dt_size * aux_input_offset;
889777
if (jpp.is_backward) {
890-
auto inpyr = yreg(inpr_i);
891778
load(jpp.src_dt, reg_idx(inpr_i), aux_reg_input,
892779
input_offset, is_tail_processing(bci));
893780
uni_vaddps(inpvr, inpvr, accvr);
894-
if (jpp.is_bf16) {
895-
if (!isa_has_bf16(jpp.isa))
896-
bf16_emu_->vcvtneps2bf16(inpyr, zreg(inpr_i));
897-
else
898-
vcvtneps2bf16(inpyr, inpvr);
899-
} else if (jpp.is_f16) {
900-
vcvtps2ph(inpyr, inpvr, _op_mxcsr);
901-
} else if (jpp.is_fp8) {
902-
auto inpxr = xreg(inpr_i);
903-
if (jpp.src_dt == data_type::f8_e5m2)
904-
f8_e5m2_emu_->vcvt_f32_to_f8(inpxr, zreg(inpr_i));
905-
else if (jpp.src_dt == data_type::f8_e4m3)
906-
f8_e4m3_emu_->vcvt_f32_to_f8(inpxr, zreg(inpr_i));
907-
}
908781
store(jpp.src_dt, reg_idx(inpr_i), aux_reg_input,
909782
input_offset, is_tail_processing(bci));
910783
} else {
@@ -955,34 +828,8 @@ inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
955828
for (int jj = 0; jj < ur_w; jj++) {
956829
for (int bci = 0; bci < ur_bc; bci++) {
957830
const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
958-
const auto accvr = vreg(accr_i);
959831
const auto output_offset
960832
= dt_size * (jj * c_off + bci * c_block);
961-
const auto accyr = yreg(accr_i);
962-
if (jpp.is_bf16) {
963-
if (isa == avx2_vnni_2) {
964-
auto accxr = xreg(accr_i);
965-
vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
966-
} else {
967-
const auto acczr = zreg(accr_i);
968-
if (!isa_has_bf16(jpp.isa))
969-
bf16_emu_->vcvtneps2bf16(accyr, acczr);
970-
else
971-
vcvtneps2bf16(accyr, accvr);
972-
}
973-
} else if (jpp.is_f16) {
974-
if (isa == avx2_vnni_2) {
975-
auto accxr = xreg(accr_i);
976-
vcvtps2ph(accxr, accyr, _op_mxcsr);
977-
} else
978-
vcvtps2ph(accyr, accvr, _op_mxcsr);
979-
} else if (jpp.is_fp8) {
980-
const auto accxr = xreg(accr_i);
981-
if (jpp.src_dt == data_type::f8_e5m2)
982-
f8_e5m2_emu_->vcvt_f32_to_f8(accxr, accvr);
983-
else if (jpp.src_dt == data_type::f8_e4m3)
984-
f8_e4m3_emu_->vcvt_f32_to_f8(accxr, accvr);
985-
}
986833
store(jpp.dst_dt, reg_idx(accr_i), reg_output, output_offset,
987834
is_tail_processing(bci));
988835
}
@@ -1129,34 +976,7 @@ inline void jit_uni_pool_kernel<isa>::max_step_fwd(int ur_w, int ur_bc,
1129976
for_(int jj = 0; jj < ur_w; jj++)
1130977
for (int bci = 0; bci < ur_bc; bci++) {
1131978
const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
1132-
const auto accvr = vreg(accr_i);
1133979
const auto output_offset = jpp.dt_size * (jj * c_off + bci * c_block);
1134-
auto accyr = yreg(accr_i);
1135-
if (jpp.is_bf16) {
1136-
if (isa == avx2_vnni_2) {
1137-
auto accxr = xreg(accr_i);
1138-
vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
1139-
} else {
1140-
auto acczr = zreg(accr_i);
1141-
if (!isa_has_bf16(jpp.isa))
1142-
bf16_emu_->vcvtneps2bf16(accyr, acczr);
1143-
else
1144-
vcvtneps2bf16(accyr, accvr);
1145-
}
1146-
} else if (jpp.is_f16) {
1147-
if (isa == avx2_vnni_2) {
1148-
auto accxr = xreg(accr_i);
1149-
vcvtps2ph(accxr, accyr, _op_mxcsr);
1150-
} else
1151-
vcvtps2ph(accyr, accvr, _op_mxcsr);
1152-
} else if (jpp.is_fp8) {
1153-
auto accxr = xreg(accr_i);
1154-
auto acczr = zreg(accr_i);
1155-
if (jpp.src_dt == data_type::f8_e5m2)
1156-
f8_e5m2_emu_->vcvt_f32_to_f8(accxr, acczr);
1157-
else if (jpp.src_dt == data_type::f8_e4m3)
1158-
f8_e4m3_emu_->vcvt_f32_to_f8(accxr, acczr);
1159-
}
1160980
store(jpp.dst_dt, reg_idx(accr_i), reg_output, output_offset,
1161981
is_tail_processing(bci));
1162982

@@ -1416,19 +1236,9 @@ inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
14161236
vmaskmovps(
14171237
vmmword[aux_reg_input + inp_offset], cvtvr, inpvr);
14181238
} else {
1419-
auto indzr = zreg(inpr_i);
1420-
auto indyr = yreg(inpr_i);
14211239
vpcmpeqd(k_store_mask, indvr, vmm_k_offset);
14221240
vblendmps(vmm_tmp | k_store_mask | T_z, outvr, outvr);
14231241
vaddps(inpvr, inpvr, vmm_tmp);
1424-
if (jpp.is_bf16 && !jpp.needs_f32_accum_for_bf16) {
1425-
if (!isa_has_bf16(jpp.isa))
1426-
bf16_emu_->vcvtneps2bf16(indyr, indzr);
1427-
else
1428-
vcvtneps2bf16(indyr, inpvr);
1429-
} else if (jpp.is_f16) {
1430-
vcvtps2ph(indyr, inpvr, _op_mxcsr);
1431-
}
14321242
store(input_dt, inpvr.getIdx(), aux_reg_input, inp_offset,
14331243
is_tail_processing(bci));
14341244
}
@@ -1592,7 +1402,8 @@ void jit_uni_pool_kernel<isa>::generate() {
15921402
xor_(rcx, rdi);
15931403
xor_(rdi, rcx);
15941404
#endif
1595-
if (use_bf16_emulation()) bf16_emu_->init_vcvtneps2bf16();
1405+
1406+
if (use_bf16_emulation()) io_.init_bf16();
15961407

15971408
mov(reg_input, ptr[reg_param + GET_OFF(src)]);
15981409
mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
@@ -1763,15 +1574,15 @@ void jit_uni_pool_kernel<isa>::generate() {
17631574
// care of c tail processing if number of channels
17641575
// is not divided by number of channels in block
17651576
L(ur_bc_tail_label);
1766-
if (jpp.c_tail != 0) prepare_tail_mask();
1577+
if (jpp.c_tail != 0) io_.prepare_tail_mask();
17671578
perform_ker(jpp.ur_bc_tail, jpp.c_tail != 0);
17681579

17691580
L(finish_label);
17701581
} else if (jpp.c_tail != 0) {
17711582
jmp(finish_label, T_NEAR);
17721583

17731584
L(c_tail_processing_label);
1774-
prepare_tail_mask();
1585+
io_.prepare_tail_mask();
17751586
perform_ker(jpp.ur_bc, true);
17761587

17771588
L(finish_label);
@@ -1792,6 +1603,7 @@ void jit_uni_pool_kernel<isa>::generate() {
17921603
}
17931604
if (f8_e5m2_emu_) f8_e5m2_emu_->prepare_table();
17941605
if (f8_e4m3_emu_) f8_e4m3_emu_->prepare_table();
1606+
io_.prepare_table_fp8();
17951607
}
17961608

17971609
template struct jit_uni_pool_kernel<sse41>;

0 commit comments

Comments
 (0)