@@ -1940,12 +1940,17 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
1940
1940
using namespace data_type ;
1941
1941
// ======================= blocking =================================
1942
1942
1943
- auto bcast_amount
1944
- = static_cast < size_t >(jcp. id ) * jcp.ih * jcp.iw * jcp. src_dsz ;
1943
+ auto bcast_amount = static_cast < size_t >(jcp. id ) * jcp. ih * jcp. iw
1944
+ * jcp.src_dsz * jcp.ic ;
1945
1945
auto wei_amount = static_cast <size_t >(jcp.oc ) * jcp.kd * jcp.kh * jcp.kw
1946
- * jcp.wei_dsz ;
1947
-
1948
- jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc;
1946
+ * jcp.wei_dsz * jcp.ic ;
1947
+
1948
+ jcp.loop_order
1949
+ = (one_of (isa, avx2, avx2_vnni, avx2_vnni_2) && jcp.mb > jcp.nthr
1950
+ && bcast_amount > brg_blocking_t ::L2
1951
+ && wei_amount > brg_blocking_t ::L2)
1952
+ ? loop_gcndhw
1953
+ : ((bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc);
1949
1954
jcp.brgemm_kernel_loop_order
1950
1955
= brgemm_kernel_loop_order_t ::brgemm_lo_default;
1951
1956
@@ -1963,6 +1968,13 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
1963
1968
const bool small_amx_job = est_amx_job < 64 || jcp.oc < 256 ;
1964
1969
auto start_ocb
1965
1970
= (is_amx (isa) && jcp.is_os_blocking && small_amx_job) ? 2 : 4 ;
1971
+ if (one_of (isa, avx2, avx2_vnni, avx2_vnni_2)
1972
+ && jcp.loop_order == loop_gcndhw)
1973
+ start_ocb = 2 ;
1974
+ if (one_of (isa, avx2, avx2_vnni, avx2_vnni_2)
1975
+ && jcp.oh * jcp.ow >= 150 * 150 )
1976
+ start_ocb = 2 ;
1977
+
1966
1978
start_ocb = nstl::min (div_up (jcp.oc , jcp.acc_simd_w ), start_ocb);
1967
1979
1968
1980
auto finish_ocb = 1 ;
@@ -2215,7 +2227,6 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
2215
2227
#endif
2216
2228
2217
2229
// ============ end blocking ===========================================
2218
-
2219
2230
jcp.brg_type
2220
2231
= (jcp.use_uker && one_of (jcp.exec_type , exec_base, exec_trans))
2221
2232
? brgemm_static_offs
0 commit comments