diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index 162b5e60d6d..4f92ba53eba 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -1941,12 +1941,17 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, using namespace data_type; // ======================= blocking ================================= - auto bcast_amount - = static_cast(jcp.id) * jcp.ih * jcp.iw * jcp.src_dsz; + auto bcast_amount = static_cast(jcp.id) * jcp.ih * jcp.iw + * jcp.src_dsz * jcp.ic; auto wei_amount = static_cast(jcp.oc) * jcp.kd * jcp.kh * jcp.kw - * jcp.wei_dsz; - - jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc; + * jcp.wei_dsz * jcp.ic; + + jcp.loop_order + = (one_of(isa, avx2, avx2_vnni, avx2_vnni_2) && jcp.mb > jcp.nthr + && bcast_amount > brg_blocking_t::L2 + && wei_amount > brg_blocking_t::L2) + ? loop_gcndhw + : ((bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc); jcp.brgemm_kernel_loop_order = brgemm_kernel_loop_order_t::brgemm_lo_default; @@ -1964,6 +1969,13 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, const bool small_amx_job = est_amx_job < 64 || jcp.oc < 256; auto start_ocb = (is_amx(isa) && jcp.is_os_blocking && small_amx_job) ? 2 : 4; + if (one_of(isa, avx2, avx2_vnni, avx2_vnni_2) + && jcp.loop_order == loop_gcndhw) + start_ocb = 2; + if (one_of(isa, avx2, avx2_vnni, avx2_vnni_2) + && jcp.oh * jcp.ow >= 150 * 150) + start_ocb = 2; + start_ocb = nstl::min(div_up(jcp.oc, jcp.acc_simd_w), start_ocb); auto finish_ocb = 1; @@ -2216,7 +2228,6 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, #endif // ============ end blocking =========================================== - jcp.brg_type = (jcp.use_uker && one_of(jcp.exec_type, exec_base, exec_trans)) ? brgemm_static_offs