@@ -45,11 +45,7 @@ jit_uni_pool_kernel<isa>::~jit_uni_pool_kernel() = default;
45
45
template <cpu_isa_t isa>
46
46
jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
47
47
const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md)
48
- : jit_generator(jit_name(), isa), jpp(ajpp), bf16_emu_(nullptr ) {
49
- if (use_bf16_emulation ())
50
- bf16_emu_ = utils::make_unique<bf16_emulation_t >(this ,
51
- bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
52
- bf16_emu_reserv_4, bf16_emu_reserv_5);
48
+ : jit_generator(jit_name(), isa), jpp(ajpp) {
53
49
54
50
bool has_f8_e5m2_binary_postops = false ;
55
51
bool has_f8_e4m3_binary_postops = false ;
@@ -109,6 +105,31 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
109
105
= utils::make_unique<injector::jit_uni_postops_injector_t <isa>>(
110
106
this , jpp.post_ops , bsp);
111
107
}
108
+
109
+ io::io_tail_conf_t io_tail_conf (jpp.c_block , jpp.c_tail ,
110
+ k_c_tail_mask.getIdx (), vmm_c_tail_mask.getIdx (), tmp_gpr);
111
+
112
+ utils::optional_t <io::io_emu_bf16_conf_t > io_bf16_conf;
113
+ if (use_bf16_emulation ())
114
+ io_bf16_conf = io::io_emu_bf16_conf_t (bf16_emu_reserv_1,
115
+ bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4,
116
+ bf16_emu_reserv_5);
117
+
118
+ utils::optional_t <io::io_emu_fp8_conf_t > io_fp8_conf;
119
+ if (use_fp8_emulation () || has_f8_e5m2_binary_postops
120
+ || has_f8_e4m3_binary_postops)
121
+ io_fp8_conf = io::io_emu_fp8_conf_t (fp8_emu_reserv_1, fp8_emu_reserv_2,
122
+ fp8_emu_reserv_3, fp8_emu_reserv_4, fp8_emu_reserv_5,
123
+ fp8_tmp_mask, fp8_emu_reg64);
124
+
125
+ using io_mdt_helper = io::jit_io_multi_dt_helper_t <Vmm>;
126
+
127
+ typename io_mdt_helper::data_types_t dtypes = {jpp.src_dt , jpp.dst_dt };
128
+ if (jpp.ind_dt != data_type::undef) dtypes.insert (jpp.ind_dt );
129
+ if (jpp.needs_f32_accum_for_bf16 ) dtypes.insert (data_type::f32);
130
+
131
+ io_ = io_mdt_helper (this , jpp.isa , dtypes, {}, io_tail_conf, io_bf16_conf,
132
+ {}, utils::nullopt, io_fp8_conf);
112
133
}
113
134
114
135
static status_t set_binary_postops_formats (
@@ -463,30 +484,6 @@ static int reg_ind(int shift, int bc, int j, int ur_bc, int ur_w) noexcept {
463
484
return shift * ur_bc * ur_w + bc * ur_w + j;
464
485
};
465
486
466
- template <cpu_isa_t isa>
467
- inline void jit_uni_pool_kernel<isa>::prepare_tail_mask() {
468
- if (is_superset (isa, avx512_core)) {
469
- size_t c_tail_mask = (1ULL << jpp.c_tail ) - 1ULL ;
470
- mov (tmp_gpr.cvt32 (), c_tail_mask);
471
- kmovw (k_c_tail_mask, tmp_gpr.cvt32 ());
472
- } else if (utils::one_of (isa, avx, avx2, avx2_vnni_2)) {
473
- constexpr int max_words_in_ymm = 8 ;
474
-
475
- // for 'avx2_vnni_2' mask works with 2 x xf16 elements,
476
- // in case of 'c_tail % 2 != 0' load/store an additional word
477
- // for the remaining element.
478
- auto dt_elem_div = isa == avx2_vnni_2 ? 2 : 1 ;
479
- auto mask_offset = max_words_in_ymm - (jpp.c_tail / dt_elem_div);
480
- auto mask_register
481
- = isa == avx2_vnni_2 ? xmm_c_tail_mask : vmm_c_tail_mask;
482
- static const uint32_t mask[16 ] = {0xffffffff , 0xffffffff , 0xffffffff ,
483
- 0xffffffff , 0xffffffff , 0xffffffff , 0xffffffff , 0xffffffff , 0 ,
484
- 0 , 0 , 0 , 0 , 0 , 0 , 0 };
485
- mov (tmp_gpr, reinterpret_cast <size_t >(&mask[mask_offset]));
486
- vmovups (mask_register, ptr[tmp_gpr]);
487
- }
488
- }
489
-
490
487
template <cpu_isa_t isa>
491
488
inline void jit_uni_pool_kernel<isa>::put_one_in_vmm() {
492
489
mov (tmp_gpr, 1 );
@@ -518,69 +515,8 @@ template <cpu_isa_t isa>
518
515
inline void jit_uni_pool_kernel<isa>::load(const data_type_t dt, const int idx,
519
516
const reg64_t ®_ptr, const int offset,
520
517
const bool is_c_tail_proccessing) {
521
- if (dt == data_type::bf16) {
522
- /* TODO: maybe use vpmovzxwd + vpslld,
523
- * in order to free up vmm_idx() register */
524
- if (is_c_tail_proccessing && !jpp.is_c_padded ) {
525
- Vmm vmm_to_load = Vmm (idx) | k_c_tail_mask | T_z;
526
- vpmovzxwd (vmm_to_load, ptr[reg_ptr + offset]);
527
- vpslld (vmm_to_load, vmm_to_load, 16 );
528
- } else {
529
- vmovups (Ymm (idx), ptr[reg_ptr + offset]);
530
- vpermw (Vmm (idx) | k_mask_cvt | T_z, vmm_idx (), Vmm (idx));
531
- }
532
- } else if (dt == data_type::f16) {
533
- Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
534
- ? Vmm (idx) | k_c_tail_mask | T_z
535
- : Vmm (idx);
536
- vcvtph2psx (vmm_to_load, ptr[reg_ptr + offset]);
537
- } else if (utils::one_of (dt, data_type::f8_e5m2, data_type::f8_e4m3)) {
538
- Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
539
- ? Vmm (idx) | k_c_tail_mask | T_z
540
- : Vmm (idx);
541
- if (dt == data_type::f8_e5m2)
542
- f8_e5m2_emu_->vcvt_f8_to_f32 (vmm_to_load, ptr[reg_ptr + offset]);
543
- else if (dt == data_type::f8_e4m3)
544
- f8_e4m3_emu_->vcvt_f8_to_f32 (vmm_to_load, ptr[reg_ptr + offset]);
545
- } else {
546
- if (is_c_tail_proccessing && !jpp.is_c_padded ) {
547
- if (isa == avx || isa == avx2) {
548
- vmaskmovps (Vmm (idx), vmm_c_tail_mask, ptr[reg_ptr + offset]);
549
- } else {
550
- vmovups (Zmm (idx) | k_c_tail_mask | T_z, ptr[reg_ptr + offset]);
551
- }
552
- } else {
553
- uni_vmovups (Vmm (idx), ptr[reg_ptr + offset]);
554
- }
555
- }
556
- }
557
-
558
- template <>
559
- inline void jit_uni_pool_kernel<avx2_vnni_2>::load(const data_type_t dt,
560
- const int idx, const reg64_t ®_ptr, const int offset,
561
- const bool is_c_tail_proccessing) {
562
- if (is_c_tail_proccessing) {
563
- vmaskmovps (Xmm (idx), xmm_c_tail_mask, ptr[reg_ptr + offset]);
564
- if (jpp.c_tail % 2 != 0 ) {
565
- const int tail_pos = jpp.c_tail - 1 ;
566
- auto word_addr
567
- = ptr[reg_ptr + offset + tail_pos * sizeof (bfloat16_t )];
568
- vpinsrw (Xmm (idx), Xmm (idx), word_addr, tail_pos);
569
- }
570
- }
571
- if (dt == data_type::bf16) {
572
- if (is_c_tail_proccessing)
573
- vpmovzxwd (Ymm (idx), Xmm (idx));
574
- else
575
- vpmovzxwd (Ymm (idx), ptr[reg_ptr + offset]);
576
- vpslld (Ymm (idx), Ymm (idx), 16 );
577
- } else if (dt == data_type::f16) {
578
- if (is_c_tail_proccessing)
579
- vcvtph2ps (Ymm (idx), Xmm (idx));
580
- else
581
- vcvtph2ps (Ymm (idx), ptr[reg_ptr + offset]);
582
- } else
583
- assert (!" invalid data type" );
518
+ io_[dt]->load (vmmword[reg_ptr + offset], Vmm (idx),
519
+ is_c_tail_proccessing && !jpp.is_c_padded );
584
520
}
585
521
586
522
template <>
@@ -599,64 +535,16 @@ template <cpu_isa_t isa>
599
535
inline void jit_uni_pool_kernel<isa>::store(const data_type_t dt, const int idx,
600
536
const reg64_t ®_ptr, const int offset,
601
537
const bool is_c_tail_proccessing) {
602
- if (utils::one_of (dt, data_type::bf16, data_type::f16)) {
603
- if (is_c_tail_proccessing) {
604
- if (jpp.is_c_padded ) {
605
- vmovdqu16 (Ymm (idx) | k_c_tail_mask | T_z, Ymm (idx));
606
- vmovups (yword[reg_ptr + offset], Ymm (idx));
607
- } else
608
- vmovdqu16 (ptr[reg_ptr + offset] | k_c_tail_mask, Ymm (idx));
609
- } else
610
- vmovups (yword[reg_ptr + offset], Ymm (idx));
611
- } else if (utils::one_of (dt, data_type::f8_e5m2, data_type::f8_e4m3)) {
612
- if (is_c_tail_proccessing) {
613
- if (jpp.is_c_padded ) {
614
- vmovdqu8 (Xmm (idx) | k_c_tail_mask | T_z, Xmm (idx));
615
- vmovdqu8 (yword[reg_ptr + offset], Xmm (idx));
616
- } else
617
- vmovdqu8 (ptr[reg_ptr + offset] | k_c_tail_mask, Xmm (idx));
618
- } else
619
- vmovdqu8 (yword[reg_ptr + offset], Xmm (idx));
620
- } else {
621
- if (is_c_tail_proccessing) {
622
- if (!jpp.is_c_padded ) {
623
- if (isa == avx || isa == avx2)
624
- vmaskmovps (
625
- ptr[reg_ptr + offset], vmm_c_tail_mask, Vmm (idx));
626
- else
627
- vmovups (ptr[reg_ptr + offset] | k_c_tail_mask, Zmm (idx));
628
- } else {
629
- if (jpp.with_postops ) {
630
- if (isa == avx || isa == avx2) {
631
- uni_vxorps (ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
632
- uni_vblendvps (
633
- Vmm (idx), ymm_tmp_1, Vmm (idx), vmm_c_tail_mask);
634
- } else
635
- uni_vmovups (Vmm (idx) | k_c_tail_mask | T_z, Vmm (idx));
636
- }
637
- uni_vmovups (vmmword[reg_ptr + offset], Vmm (idx));
638
- }
538
+ if (is_c_tail_proccessing && jpp.is_c_padded ) {
539
+ if (isa == avx || isa == avx2) {
540
+ uni_vxorps (ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
541
+ uni_vblendvps (Vmm (idx), ymm_tmp_1, Vmm (idx), vmm_c_tail_mask);
639
542
} else
640
- uni_vmovups (vmmword[reg_ptr + offset] , Vmm (idx));
543
+ uni_vmovups (Vmm (idx) | k_c_tail_mask | T_z , Vmm (idx));
641
544
}
642
- }
643
545
644
- template <>
645
- inline void jit_uni_pool_kernel<avx2_vnni_2>::store(const data_type_t dt,
646
- const int idx, const reg64_t ®_ptr, const int offset,
647
- const bool is_c_tail_proccessing) {
648
- if (utils::one_of (dt, data_type::bf16, data_type::f16)) {
649
- if (is_c_tail_proccessing) {
650
- vmaskmovps (ptr[reg_ptr + offset], xmm_c_tail_mask, Xmm (idx));
651
- if (jpp.c_tail % 2 != 0 ) {
652
- const int tail_pos = jpp.c_tail - 1 ;
653
- auto word_addr = ptr[reg_ptr + offset + tail_pos * 2 ];
654
- vpextrw (word_addr, Xmm (idx), tail_pos);
655
- }
656
- } else
657
- vmovups (xword[reg_ptr + offset], Xmm (idx));
658
- } else
659
- assert (!" datatype not supported" );
546
+ io_[dt]->store (Vmm (idx), vmmword[reg_ptr + offset],
547
+ is_c_tail_proccessing && !jpp.is_c_padded );
660
548
}
661
549
662
550
template <>
@@ -887,24 +775,9 @@ inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
887
775
if (aux_input_offset >= iw * c_off) continue ;
888
776
int input_offset = dt_size * aux_input_offset;
889
777
if (jpp.is_backward ) {
890
- auto inpyr = yreg (inpr_i);
891
778
load (jpp.src_dt , reg_idx (inpr_i), aux_reg_input,
892
779
input_offset, is_tail_processing (bci));
893
780
uni_vaddps (inpvr, inpvr, accvr);
894
- if (jpp.is_bf16 ) {
895
- if (!isa_has_bf16 (jpp.isa ))
896
- bf16_emu_->vcvtneps2bf16 (inpyr, zreg (inpr_i));
897
- else
898
- vcvtneps2bf16 (inpyr, inpvr);
899
- } else if (jpp.is_f16 ) {
900
- vcvtps2ph (inpyr, inpvr, _op_mxcsr);
901
- } else if (jpp.is_fp8 ) {
902
- auto inpxr = xreg (inpr_i);
903
- if (jpp.src_dt == data_type::f8_e5m2)
904
- f8_e5m2_emu_->vcvt_f32_to_f8 (inpxr, zreg (inpr_i));
905
- else if (jpp.src_dt == data_type::f8_e4m3)
906
- f8_e4m3_emu_->vcvt_f32_to_f8 (inpxr, zreg (inpr_i));
907
- }
908
781
store (jpp.src_dt , reg_idx (inpr_i), aux_reg_input,
909
782
input_offset, is_tail_processing (bci));
910
783
} else {
@@ -955,34 +828,8 @@ inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
955
828
for (int jj = 0 ; jj < ur_w; jj++) {
956
829
for (int bci = 0 ; bci < ur_bc; bci++) {
957
830
const auto accr_i = reg_ind (0 , bci, jj, ur_bc, ur_w);
958
- const auto accvr = vreg (accr_i);
959
831
const auto output_offset
960
832
= dt_size * (jj * c_off + bci * c_block);
961
- const auto accyr = yreg (accr_i);
962
- if (jpp.is_bf16 ) {
963
- if (isa == avx2_vnni_2) {
964
- auto accxr = xreg (accr_i);
965
- vcvtneps2bf16 (accxr, accyr, Xbyak::VexEncoding);
966
- } else {
967
- const auto acczr = zreg (accr_i);
968
- if (!isa_has_bf16 (jpp.isa ))
969
- bf16_emu_->vcvtneps2bf16 (accyr, acczr);
970
- else
971
- vcvtneps2bf16 (accyr, accvr);
972
- }
973
- } else if (jpp.is_f16 ) {
974
- if (isa == avx2_vnni_2) {
975
- auto accxr = xreg (accr_i);
976
- vcvtps2ph (accxr, accyr, _op_mxcsr);
977
- } else
978
- vcvtps2ph (accyr, accvr, _op_mxcsr);
979
- } else if (jpp.is_fp8 ) {
980
- const auto accxr = xreg (accr_i);
981
- if (jpp.src_dt == data_type::f8_e5m2)
982
- f8_e5m2_emu_->vcvt_f32_to_f8 (accxr, accvr);
983
- else if (jpp.src_dt == data_type::f8_e4m3)
984
- f8_e4m3_emu_->vcvt_f32_to_f8 (accxr, accvr);
985
- }
986
833
store (jpp.dst_dt , reg_idx (accr_i), reg_output, output_offset,
987
834
is_tail_processing (bci));
988
835
}
@@ -1129,34 +976,7 @@ inline void jit_uni_pool_kernel<isa>::max_step_fwd(int ur_w, int ur_bc,
1129
976
for_ (int jj = 0 ; jj < ur_w; jj++)
1130
977
for (int bci = 0 ; bci < ur_bc; bci++) {
1131
978
const auto accr_i = reg_ind (0 , bci, jj, ur_bc, ur_w);
1132
- const auto accvr = vreg (accr_i);
1133
979
const auto output_offset = jpp.dt_size * (jj * c_off + bci * c_block);
1134
- auto accyr = yreg (accr_i);
1135
- if (jpp.is_bf16 ) {
1136
- if (isa == avx2_vnni_2) {
1137
- auto accxr = xreg (accr_i);
1138
- vcvtneps2bf16 (accxr, accyr, Xbyak::VexEncoding);
1139
- } else {
1140
- auto acczr = zreg (accr_i);
1141
- if (!isa_has_bf16 (jpp.isa ))
1142
- bf16_emu_->vcvtneps2bf16 (accyr, acczr);
1143
- else
1144
- vcvtneps2bf16 (accyr, accvr);
1145
- }
1146
- } else if (jpp.is_f16 ) {
1147
- if (isa == avx2_vnni_2) {
1148
- auto accxr = xreg (accr_i);
1149
- vcvtps2ph (accxr, accyr, _op_mxcsr);
1150
- } else
1151
- vcvtps2ph (accyr, accvr, _op_mxcsr);
1152
- } else if (jpp.is_fp8 ) {
1153
- auto accxr = xreg (accr_i);
1154
- auto acczr = zreg (accr_i);
1155
- if (jpp.src_dt == data_type::f8_e5m2)
1156
- f8_e5m2_emu_->vcvt_f32_to_f8 (accxr, acczr);
1157
- else if (jpp.src_dt == data_type::f8_e4m3)
1158
- f8_e4m3_emu_->vcvt_f32_to_f8 (accxr, acczr);
1159
- }
1160
980
store (jpp.dst_dt , reg_idx (accr_i), reg_output, output_offset,
1161
981
is_tail_processing (bci));
1162
982
@@ -1416,19 +1236,9 @@ inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
1416
1236
vmaskmovps (
1417
1237
vmmword[aux_reg_input + inp_offset], cvtvr, inpvr);
1418
1238
} else {
1419
- auto indzr = zreg (inpr_i);
1420
- auto indyr = yreg (inpr_i);
1421
1239
vpcmpeqd (k_store_mask, indvr, vmm_k_offset);
1422
1240
vblendmps (vmm_tmp | k_store_mask | T_z, outvr, outvr);
1423
1241
vaddps (inpvr, inpvr, vmm_tmp);
1424
- if (jpp.is_bf16 && !jpp.needs_f32_accum_for_bf16 ) {
1425
- if (!isa_has_bf16 (jpp.isa ))
1426
- bf16_emu_->vcvtneps2bf16 (indyr, indzr);
1427
- else
1428
- vcvtneps2bf16 (indyr, inpvr);
1429
- } else if (jpp.is_f16 ) {
1430
- vcvtps2ph (indyr, inpvr, _op_mxcsr);
1431
- }
1432
1242
store (input_dt, inpvr.getIdx (), aux_reg_input, inp_offset,
1433
1243
is_tail_processing (bci));
1434
1244
}
@@ -1592,7 +1402,8 @@ void jit_uni_pool_kernel<isa>::generate() {
1592
1402
xor_ (rcx, rdi);
1593
1403
xor_ (rdi, rcx);
1594
1404
#endif
1595
- if (use_bf16_emulation ()) bf16_emu_->init_vcvtneps2bf16 ();
1405
+
1406
+ if (use_bf16_emulation ()) io_.init_bf16 ();
1596
1407
1597
1408
mov (reg_input, ptr[reg_param + GET_OFF (src)]);
1598
1409
mov (reg_output, ptr[reg_param + GET_OFF (dst)]);
@@ -1763,15 +1574,15 @@ void jit_uni_pool_kernel<isa>::generate() {
1763
1574
// care of c tail processing if number of channels
1764
1575
// is not divided by number of channels in block
1765
1576
L (ur_bc_tail_label);
1766
- if (jpp.c_tail != 0 ) prepare_tail_mask ();
1577
+ if (jpp.c_tail != 0 ) io_. prepare_tail_mask ();
1767
1578
perform_ker (jpp.ur_bc_tail , jpp.c_tail != 0 );
1768
1579
1769
1580
L (finish_label);
1770
1581
} else if (jpp.c_tail != 0 ) {
1771
1582
jmp (finish_label, T_NEAR);
1772
1583
1773
1584
L (c_tail_processing_label);
1774
- prepare_tail_mask ();
1585
+ io_. prepare_tail_mask ();
1775
1586
perform_ker (jpp.ur_bc , true );
1776
1587
1777
1588
L (finish_label);
@@ -1792,6 +1603,7 @@ void jit_uni_pool_kernel<isa>::generate() {
1792
1603
}
1793
1604
if (f8_e5m2_emu_) f8_e5m2_emu_->prepare_table ();
1794
1605
if (f8_e4m3_emu_) f8_e4m3_emu_->prepare_table ();
1606
+ io_.prepare_table_fp8 ();
1795
1607
}
1796
1608
1797
1609
template struct jit_uni_pool_kernel <sse41>;
0 commit comments