diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp index 2b37f9fc05a..8752a8240af 100644 --- a/src/common/memory_tracking.hpp +++ b/src/common/memory_tracking.hpp @@ -305,6 +305,7 @@ enum { key_decompression_zero_points, key_src_quantized, key_src_dequantized_scales, + key_src_grouped_sum, // These two keys should always be the last ones, // even though they are not in alphabetical order key_nested, diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp index 52cbf163ca2..e8b88515348 100644 --- a/src/cpu/x64/brgemm/brgemm.cpp +++ b/src/cpu/x64/brgemm/brgemm.cpp @@ -82,7 +82,8 @@ void brgemm_desc_t::cleanup_dst_md() { void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *scratch, const brgemm_dynamic_values_t *dynamic_values, - const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) { + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -105,6 +106,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.ptr_wei_scales = ptr_wei_scales; brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; brgemm_p.ic = ic; assert(brg_kernel); @@ -116,7 +118,8 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, const void *addr_A, const void *addr_B, const brgemm_batch_element_t *batch, void *ptr_C, void *scratch, const brgemm_dynamic_values_t *dynamic_values, - const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) { + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; 
brgemm_p.batch = batch; @@ -133,6 +136,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.ptr_wei_scales = ptr_wei_scales; brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; brgemm_p.ic = ic; if (dynamic_values) { brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA; @@ -148,7 +152,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, const brgemm_post_ops_data_t &post_ops_data, void *scratch, const brgemm_dynamic_values_t *dynamic_values, - const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) { + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -178,6 +183,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.ptr_wei_scales = ptr_wei_scales; brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; brgemm_p.ic = ic; if (dynamic_values) { brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA; @@ -194,7 +200,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, const brgemm_post_ops_data_t &post_ops_data, void *scratch, const brgemm_dynamic_values_t *dynamic_values, - const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) { + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -224,6 +231,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t 
*brg_kernel, int bs, brgemm_p.ptr_wei_scales = ptr_wei_scales; brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; brgemm_p.ic = ic; if (dynamic_values) { brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA; @@ -318,6 +326,12 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa, CHECK(brgemm_blocking(brg)); + brg->src_sum_group_size = wei_d.dims()[1]; + if (brg->with_src_dyn_quant) { + brg->src_sum_group_size = brg->rd_block; + brg->src_grouped_sum_stride = div_up(wei_d.dims()[1], brg->src_sum_group_size); + } + // avx2_vnni_2 kernel with xf16 data type requires blocked weights. if (brg->isa_impl == avx2_vnni_2 && brg->is_xf16() && brg->LDB % brg->ld_block > 0) diff --git a/src/cpu/x64/brgemm/brgemm.hpp b/src/cpu/x64/brgemm/brgemm.hpp index e53fdf18999..bbd39ffe6f4 100644 --- a/src/cpu/x64/brgemm/brgemm.hpp +++ b/src/cpu/x64/brgemm/brgemm.hpp @@ -175,7 +175,7 @@ void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, void *scratch = nullptr, const brgemm_dynamic_values_t *dynamic_values = nullptr, const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, - const void *ptr_src_scales = nullptr, size_t ic = 0); + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version) /// @@ -205,7 +205,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, void *scratch = nullptr, const brgemm_dynamic_values_t *dynamic_values = nullptr, const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, - const void *ptr_src_scales = nullptr, size_t ic = 0); + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// Execute BRGEMM kernel (brgemm_addr version) /// @@ -234,7 +234,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t 
*brg_kernel, const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr, const brgemm_dynamic_values_t *dynamic_values = nullptr, const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, - const void *ptr_src_scales = nullptr, size_t ic = 0); + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version) /// @@ -267,7 +267,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr, const brgemm_dynamic_values_t *dynamic_values = nullptr, const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, - const void *ptr_src_scales = nullptr, size_t ic = 0); + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// AMX utilities: Creates a palette based on BRGEMM descriptor /// diff --git a/src/cpu/x64/brgemm/brgemm_types.hpp b/src/cpu/x64/brgemm/brgemm_types.hpp index f9a619e4912..e84bb8386b3 100644 --- a/src/cpu/x64/brgemm/brgemm_types.hpp +++ b/src/cpu/x64/brgemm/brgemm_types.hpp @@ -321,6 +321,8 @@ struct brgemm_desc_t { bool with_src_dyn_quant = false; int src_scales_group_size = 0; int src_scales_stride = 0; + int src_sum_group_size = 0; + int src_grouped_sum_stride = 0; bool is_row_major() const { assert(layout != brgemm_layout_undef); @@ -500,6 +502,7 @@ struct brgemm_kernel_params_t { const void *ptr_wei_scales = nullptr; const void *ptr_wei_zero_points = nullptr; const void *ptr_src_scales = nullptr; + const void *ptr_src_grouped_sum = nullptr; size_t ic; dim_t dynamic_LDA = 0; dim_t dynamic_LDB = 0; diff --git a/src/cpu/x64/brgemm/brgemm_utils.cpp b/src/cpu/x64/brgemm/brgemm_utils.cpp index f71487920ef..36df3e8c428 100644 --- a/src/cpu/x64/brgemm/brgemm_utils.cpp +++ b/src/cpu/x64/brgemm/brgemm_utils.cpp @@ -231,7 +231,7 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, 
const int adj_ld_block2) { if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2; if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1; if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1; - if (brg->with_src_dyn_quant) max_bcast_block -= 2; + if (brg->with_src_dyn_quant) max_bcast_block -= 1; if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2; max_bcast_block /= adj_ld_block2; @@ -298,15 +298,23 @@ status_t brgemm_blocking(brgemm_desc_t *brg) { = (brg->is_f16 && brg->isa_impl == avx512_core_fp16) ? 1 : data_type_vnni_granularity(brg->dt_a); + int rd_unroll = one_of(brg->dt_b, data_type::nf4, data_type::u4, data_type::s4, data_type::f4_e2m1) ? 32 : 4; - if (brg->with_grouped_wei_decomp) { + if (brg->with_grouped_wei_decomp && !brg->with_src_dyn_quant) { auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size); min_group_size = nstl::min(min_group_size, brg->src_scales_group_size); rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity); rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity); + brg->rd_block = rd_unroll * vnni_granularity; + } else if (brg->with_src_dyn_quant) { + brg->rd_block = 32; + auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size); + min_group_size = nstl::min(min_group_size, brg->src_scales_group_size); + brg->rd_block = nstl::min(brg->rd_block, min_group_size); + } else { + brg->rd_block = rd_unroll * vnni_granularity; } - brg->rd_block = rd_unroll * vnni_granularity; brg->rdb = brg->reduce_dim / brg->rd_block; brg->rdb_tail = brg->reduce_dim % brg->rd_block; diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index 0ed1c4c104b..8445c886e42 100644 
--- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -203,6 +203,7 @@ struct jit_brgemm_kernel_t : public jit_generator { const reg64_t reg_aux_wei_zp = reg_rdb_loop; const reg64_t reg_ic = reg_rdb_loop; const reg64_t reg_src_scales = reg_rdb_loop; + const reg64_t reg_src_grouped_sum = reg_rdb_loop; const reg64_t reg_tmp_read_values = reg_rdb_loop; const reg64_t reg_aux_scales = reg_aux_B; @@ -280,12 +281,13 @@ struct jit_brgemm_kernel_t : public jit_generator { constexpr static int reg_src_scales_offs_ = 336; constexpr static int reg_aux_src_scales_offs_ = 344; constexpr static int reg_aux2_src_scales_offs_ = 352; - // constexpr static int stack_space_needed_ = 360; + constexpr static int reg_src_grouped_sum_offs_ = 360; + constexpr static int reg_aux_src_grouped_sum_offs_ = 368; + constexpr static int reg_aux2_src_grouped_sum_offs_ = 376; // these are used for FP8 as temporary push/pop spaces - constexpr static int reg_val_tmp_1_ = 368; - constexpr static int reg_val_tmp_2_ = 376; - constexpr static int stack_space_needed_ = 384; - // regsiters for dynamic quant + constexpr static int reg_val_tmp_1_ = 384; + constexpr static int reg_val_tmp_2_ = 392; + constexpr static int stack_space_needed_ = 400; bool is_ldb_loop_ = false; @@ -323,7 +325,7 @@ struct jit_brgemm_kernel_t : public jit_generator { } if (brg.with_src_dyn_quant) { - used_vregs += 2; + used_vregs += 1; } if (brg.with_src_dyn_quant && brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) { @@ -971,6 +973,12 @@ void jit_brgemm_kernel_t::copy_post_ops_stack_values_to_aux( mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]); mov(ptr[rsp + reg_aux_src_scales_offs_], reg_src_scales); mov(ptr[rsp + reg_aux2_src_scales_offs_], reg_src_scales); + + if (brg.with_wei_decomp_zero_points) { + mov(reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]); + mov(ptr[rsp + reg_aux_src_grouped_sum_offs_], reg_src_grouped_sum); + mov(ptr[rsp + 
reg_aux2_src_grouped_sum_offs_], reg_src_grouped_sum); + } } if (brg.zp_type_b != brgemm_broadcast_t::none) { mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]); @@ -1048,6 +1056,9 @@ void jit_brgemm_kernel_t::read_params() { if (brg.with_src_dyn_quant) { mov(reg_src_scales, ptr[param1 + GET_OFF(ptr_src_scales)]); mov(ptr[rsp + reg_src_scales_offs_], reg_src_scales); + + mov(reg_src_grouped_sum, ptr[param1 + GET_OFF(ptr_src_grouped_sum)]); + mov(ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum); } if (brg.zp_type_c != brgemm_broadcast_t::none) { @@ -2296,21 +2307,10 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, }; auto vmm_zero_point = [&](int ld) { - int idx = isa_num_vregs(brg.isa_impl) - 3 - ld; + int idx = isa_num_vregs(brg.isa_impl) - 2 - ld; return Vmm(idx); }; - static const int8_t negative_one[64] = { - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1 - }; - static const int8_t mask_low_half[64] = { 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, @@ -2328,33 +2328,18 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, if (brg.with_wei_decomp_zero_points) { mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_]); if (brg.wei_decomp_zero_points_stride == 0) { - auto reg_ptr_8 = Reg8(reg_ptr.getIdx()); - mov(reg_ptr_8, ptr[reg_local_wei_zp]); - uni_vpbroadcastb(vmm_zero_point(0), reg_ptr_8); + auto reg_ptr_32 = Reg32(reg_ptr.getIdx()); + movzx(reg_ptr_32, ptr[reg_local_wei_zp]); + uni_vmovq(Xmm(vmm_zero_point(0).getIdx()), reg_ptr); + uni_vbroadcastss(vmm_zero_point(0), Xmm(vmm_zero_point(0).getIdx())); } else { - static const int8_t 
index_table[64] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C, - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C, - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C, - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C - }; - - auto vmm_indexes = Vmm(isa_num_vregs(brg.isa_impl) - 1); - mov(reg_ptr, (size_t)index_table); - uni_vmovups(vmm_indexes, ptr[reg_ptr]); - for (int ld = 0; ld < ld_block2; ld++) { uni_vpmovzxbd(vmm_zero_point(ld), ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]); - vpshufb(vmm_zero_point(ld), vmm_zero_point(ld), vmm_indexes); } } } - auto vmm_neg_one = Vmm(isa_num_vregs(brg.isa_impl) - 1); - mov(reg_ptr, (size_t)negative_one); - uni_vmovups(vmm_neg_one, ptr[reg_ptr]); - - auto vmm_mask_low_half = Vmm(isa_num_vregs(brg.isa_impl) - 2); + auto vmm_mask_low_half = Vmm(isa_num_vregs(brg.isa_impl) - 1); mov(reg_ptr, (size_t)mask_low_half); uni_vmovups(vmm_mask_low_half, ptr[reg_ptr]); @@ -2409,22 +2394,28 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, auto vmm = accm(ld_block2, bd, ld); vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding); } - if (brg.with_wei_decomp_zero_points) { - uni_vpxor(bcst(), bcst(), vmm_neg_one); - uni_vpsubb(bcst(), bcst(), vmm_neg_one); - for (int ld = 0; ld < ld_block2; ld++) { - auto vmm = accm(ld_block2, bd, ld); - Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld); - vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? 
EvexEncoding : VexEncoding); - } - } } } auto reg_local_src_scales = reg_local_wei_zp; + auto reg_local_src_grouped_sum = reg_local_wei_zp; auto vmm_src_scales = bcst(); - mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]); + auto vmm_src_grouped_sum = bcst(); + if (brg.with_wei_decomp_zero_points) { + mov(reg_local_src_grouped_sum, ptr[rsp + reg_aux2_src_grouped_sum_offs_ + accums_stack_space]); + for (int bd = bd_b; bd < bd_e; bd++) { + for (int ld = 0; ld < ld_block2; ld++) { + auto vmm_accm = accm(ld_block2, bd, ld); + Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld); + uni_vbroadcastss(vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof(int32_t)]); + uni_vpmulld(vmm_src_grouped_sum, vmm_src_grouped_sum, vmm_zp); + uni_vpsubd(vmm_accm, vmm_accm, vmm_src_grouped_sum); + } + } + } + + mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]); for (int bd = bd_b; bd < bd_e; bd++) { uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]); for (int ld = 0; ld < ld_block2; ld++) { @@ -3073,6 +3064,11 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, if (brg.with_src_dyn_quant) { ic_group_shift(reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_, brg.src_scales_group_size, sizeof(float)); + + if (brg.with_wei_decomp_zero_points) { + ic_group_shift(reg_aux_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_, + brg.src_sum_group_size, sizeof(int32_t)); + } } mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]); @@ -3306,6 +3302,10 @@ void jit_brgemm_kernel_t::bdb_loop() { mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]); add(reg_src_scales, bd_block2 * brg.bd_block * brg.src_scales_stride * sizeof(float)); mov(ptr[rsp + reg_src_scales_offs_], reg_src_scales); + + mov(reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]); + add(reg_src_grouped_sum, 
bd_block2 * brg.bd_block * brg.src_grouped_sum_stride * sizeof(int32_t)); + mov(ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum); } advance_bd_block2_post_op_regs(bd_block2); diff --git a/src/cpu/x64/jit_brgemm_inner_product.cpp b/src/cpu/x64/jit_brgemm_inner_product.cpp index 5879b5b5a89..f6f2092cb8f 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product.cpp @@ -143,21 +143,26 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int8_t* qsrc = nullptr; float* src_dscales = nullptr; + int32_t* src_grouped_sum = nullptr; if (jbgp.with_src_dynamic_quant) { qsrc = scratchpad.template get(key_src_quantized); src_dscales = scratchpad.template get(key_src_dequantized_scales); + src_grouped_sum = scratchpad.template get(key_src_grouped_sum); int ic_groups = div_up(jbgp.ic, jbgp.src_quant_group_size); + int ic_sum_groups = div_up(jbgp.ic, jbgp.src_sum_group_size); auto src_ptr = reinterpret_cast(src); auto qsrc_ptr = qsrc; auto src_dscales_ptr = src_dscales; - int vec_loop_end = (ic_groups - 1) * jbgp.src_quant_group_size; + auto src_grouped_sum_ptr = src_grouped_sum; + int vec_loop_end = rnd_dn(jbgp.ic, jbgp.src_quant_group_size); parallel_nd(jbgp.mb, [&](int mb) { src_quantization_runtime_params_t rt_params = {}; rt_params.src_ptr = src_ptr + mb * jbgp.ic; rt_params.qsrc_ptr = qsrc_ptr + mb * jbgp.ic; rt_params.src_scales_ptr = src_dscales_ptr + mb * ic_groups; + rt_params.src_grouped_sum_ptr = src_grouped_sum_ptr + mb * ic_sum_groups; rt_params.ic_size = vec_loop_end; (*brg_src_quant_kernel_)(&rt_params); @@ -175,6 +180,18 @@ status_t brgemm_inner_product_fwd_t::execute_forward( qsrc_ptr[mb * jbgp.ic + ic] = std::round(src_ptr[mb * jbgp.ic + ic] * qscale); } } + + if (jbgp.wei_decomp_zero_points_dt) { + // NOTE(review): the tail must start at the first *sum* group not covered by the + // vectorized kernel; dividing by src_quant_group_size recomputed already-written + // sum groups whenever src_sum_group_size < src_quant_group_size. + for (int icb = vec_loop_end / jbgp.src_sum_group_size; icb < ic_sum_groups; icb++) { + int ic_begin = icb * jbgp.src_sum_group_size; + int ic_end = nstl::min(static_cast((icb + 1) * jbgp.src_sum_group_size), 
jbgp.ic); + int sum = 0; + for (int ic = ic_begin; ic < ic_end; ic++) { + sum += qsrc_ptr[mb * jbgp.ic + ic]; + } + src_grouped_sum_ptr[mb * ic_sum_groups + icb] = sum; + } + } }); src = reinterpret_cast(qsrc); @@ -429,10 +446,12 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int wei_scales_offset = 0; int wei_zero_points_offset = 0; int src_scales_offset = 0; + int src_grouped_sum_offset = 0; if (jbgp.weights_decompression) { wei_scales_offset = wei_scales_oc_stride * oc * wei_scales_dt_size; wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); + src_grouped_sum_offset = n * div_up(jbgp.ic, jbgp.src_sum_group_size); } auto ptr_D = dst + dst_off; @@ -456,10 +475,12 @@ status_t brgemm_inner_product_fwd_t::execute_forward( brgemm_kernel_execute_postops(brg_kernel, gemm_batch, addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data, - scratch, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); + scratch, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); } else { brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch, - (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); + (void *)ptr_C, is_amx ? 
(void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); } } @@ -534,10 +555,12 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int wei_scales_offset = 0; int wei_zero_points_offset = 0; int src_scales_offset = 0; + int src_grouped_sum_offset = 0; if (jbgp.weights_decompression) { wei_scales_offset = wei_scales_oc_stride * oc * wei_scales_dt_size; wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); + src_grouped_sum_offset = n * div_up(jbgp.ic, jbgp.src_sum_group_size); } auto brg_kernel_ic_tail = brg_kernels_[brg_ker_ic_tail_idx].get(); @@ -560,10 +583,12 @@ status_t brgemm_inner_product_fwd_t::execute_forward( nullptr, false, 1, false, false, dst_scales}; brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch, - (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); + (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); } else { brgemm_kernel_execute(brg_kernel_ic_tail, 1, addr_batch, - (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); + (void *)ptr_C, is_amx ? 
(void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); } } }; diff --git a/src/cpu/x64/jit_brgemm_inner_product.hpp b/src/cpu/x64/jit_brgemm_inner_product.hpp index 118b6b79fc4..041c27abd37 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.hpp +++ b/src/cpu/x64/jit_brgemm_inner_product.hpp @@ -265,6 +265,8 @@ struct brgemm_inner_product_fwd_t : public primitive_t { if (pd()->jbgp_.with_src_dynamic_quant) { src_quantization_compile_params_t jcp = {}; jcp.ic_quant_block = pd()->jbgp_.src_quant_group_size; + jcp.with_src_grouped_sum = !pd()->attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS); + jcp.src_sum_group_size = pd()->jbgp_.src_sum_group_size; jcp.src_dt = pd()->jbgp_.orig_src_dt; jcp.qsrc_dt = data_type::s8; diff --git a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp index 929bf649867..5db750157ce 100644 --- a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp @@ -1421,6 +1421,7 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, jbgp.with_src_dynamic_quant = false; if (jbgp.weights_decompression) { jbgp.src_quant_group_size = jbgp.ic; + jbgp.src_sum_group_size = jbgp.ic; if (!attr.src_dyn_quant_params_.has_default_values()) { jbgp.with_src_dynamic_quant = true; } @@ -1477,6 +1478,13 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, if (jbgp.with_src_dynamic_quant) { jbgp.orig_src_dt = jbgp.src_dt; jbgp.src_dt = s8; + + size_t rd_unroll = 32; + auto min_group_size = nstl::min(jbgp.wei_scales_ic_group_size, jbgp.wei_zero_points_ic_group_size); + min_group_size = nstl::min(min_group_size, jbgp.src_quant_group_size); + jbgp.src_sum_group_size = nstl::min(rd_unroll, min_group_size); + + assert(jbgp.src_quant_group_size % jbgp.src_sum_group_size == 0); } jbgp.bia_dt = jbgp.with_bias @@ -1691,6 +1699,8 
@@ void jit_brgemm_ip_conf_t::init_scratchpad_base( if (jbgp.with_src_dynamic_quant) { scratchpad.book(key_src_quantized, jbgp.mb * jbgp.ic, sizeof(int8_t)); scratchpad.book(key_src_dequantized_scales, jbgp.mb * div_up(jbgp.ic, jbgp.src_quant_group_size), sizeof(float)); + if (jbgp.wei_decomp_zero_points_dt) + scratchpad.book(key_src_grouped_sum, jbgp.mb * div_up(jbgp.ic, jbgp.src_sum_group_size), sizeof(int32_t)); } } diff --git a/src/cpu/x64/jit_brgemm_primitive_conf.hpp b/src/cpu/x64/jit_brgemm_primitive_conf.hpp index 5f7ebd2cf0e..60cbe03f718 100644 --- a/src/cpu/x64/jit_brgemm_primitive_conf.hpp +++ b/src/cpu/x64/jit_brgemm_primitive_conf.hpp @@ -113,6 +113,7 @@ struct jit_brgemm_primitive_conf_t { bool with_src_dynamic_quant; size_t src_quant_group_size; + size_t src_sum_group_size; data_type_t orig_src_dt; }; diff --git a/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp b/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp index 90256ade02d..b6aff7896b0 100644 --- a/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp +++ b/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp @@ -43,6 +43,39 @@ void jit_brgemm_src_quantization_kernel_t::load_src(Vmm vmm_load, const Xby } } +template +void jit_brgemm_src_quantization_kernel_t::horiz_op(Vmm vmm_src, Vmm vmm_aux, op_type type) { + auto uni_op = [&](const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { + if (type == op_type::max) { + uni_vmaxps(x1, x2, op); + } else if (type == op_type::sum) { + uni_vpaddd(x1, x2, op); + } else { + assert(!"unsupported op type"); + } + }; + + if (isa == avx512_core) { + Xbyak::Zmm zmm_src = Xbyak::Zmm(vmm_src.getIdx()); + Xbyak::Zmm zmm_aux = Xbyak::Zmm(vmm_aux.getIdx()); + vshuff32x4(zmm_aux, zmm_src, zmm_src, 0x4E); + uni_op(zmm_src, zmm_src, zmm_aux); + vshuff32x4(zmm_aux, zmm_src, zmm_src, 0xB1); + uni_op(zmm_src, zmm_src, zmm_aux); + } else if (isa == avx2) { + Xbyak::Ymm ymm_src = Xbyak::Ymm(vmm_src.getIdx()); + Xbyak::Ymm ymm_aux = 
Xbyak::Ymm(vmm_aux.getIdx()); + vperm2i128(ymm_aux, ymm_src, ymm_src, 0x01); + uni_op(ymm_src, ymm_src, ymm_aux); + } else { + assert(!"unsupported isa"); + } + uni_vshufps(vmm_aux, vmm_src, vmm_src, 0x4E); + uni_op(vmm_src, vmm_src, vmm_aux); + uni_vshufps(vmm_aux, vmm_src, vmm_src, 0xB1); + uni_op(vmm_src, vmm_src, vmm_aux); +} + template void jit_brgemm_src_quantization_kernel_t::generate() { preamble(); @@ -50,6 +83,7 @@ void jit_brgemm_src_quantization_kernel_t::generate() { mov(reg_src, ptr[param1 + GET_OFF(src_ptr)]); mov(reg_qsrc, ptr[param1 + GET_OFF(qsrc_ptr)]); mov(reg_src_scales, ptr[param1 + GET_OFF(src_scales_ptr)]); + mov(reg_src_grouped_sum, ptr[param1 + GET_OFF(src_grouped_sum_ptr)]); mov(reg_ic_size, ptr[param1 + GET_OFF(ic_size)]); Xbyak::Label ic_loop_label; @@ -58,6 +92,7 @@ void jit_brgemm_src_quantization_kernel_t::generate() { size_t src_dt_size = types::data_type_size(jcp_.src_dt); size_t qsrc_dt_size = types::data_type_size(jcp_.qsrc_dt); size_t src_scales_dt_size = types::data_type_size(data_type::f32); + size_t src_grouped_sum_dt_size = types::data_type_size(data_type::s32); static const float negative_zero[16] = { -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, @@ -89,6 +124,7 @@ void jit_brgemm_src_quantization_kernel_t::generate() { jl(ic_end_label, T_NEAR); assert(!(jcp_.ic_quant_block % vec_size)); + assert(!(jcp_.src_sum_group_size % vec_size)); int ic_blocks = jcp_.ic_quant_block / vec_size; uni_vpxor(vmm_max(), vmm_max(), vmm_max()); @@ -98,25 +134,7 @@ void jit_brgemm_src_quantization_kernel_t::generate() { uni_vmaxps(vmm_max(), vmm_max(), vmm_src()); } - if (isa == avx512_core) { - Xbyak::Zmm max_zmm = Xbyak::Zmm(vmm_max().getIdx()); - Xbyak::Zmm aux_zmm = Xbyak::Zmm(vmm_aux().getIdx()); - vshuff32x4(aux_zmm, max_zmm, max_zmm, 0x4E); - uni_vmaxps(max_zmm, max_zmm, aux_zmm); - vshuff32x4(aux_zmm, max_zmm, max_zmm, 0xB1); - uni_vmaxps(max_zmm, max_zmm, aux_zmm); - } else if (isa == avx2) { - Xbyak::Ymm max_ymm = 
Xbyak::Ymm(vmm_max().getIdx()); - Xbyak::Ymm aux_ymm = Xbyak::Ymm(vmm_aux().getIdx()); - vperm2i128(aux_ymm, max_ymm, max_ymm, 0x01); - uni_vmaxps(max_ymm, max_ymm, aux_ymm); - } else { - assert(!"unsupported isa"); - } - uni_vshufps(vmm_aux(), vmm_max(), vmm_max(), 0x4E); - uni_vmaxps(vmm_max(), vmm_max(), vmm_aux()); - uni_vshufps(vmm_aux(), vmm_max(), vmm_max(), 0xB1); - uni_vmaxps(vmm_max(), vmm_max(), vmm_aux()); + horiz_op(vmm_max(), vmm_aux(), op_type::max); auto vmm_dscale = vmm_max(); uni_vbroadcastss(vmm_dscale, Xmm(vmm_dscale.getIdx())); @@ -126,11 +144,25 @@ void jit_brgemm_src_quantization_kernel_t::generate() { uni_vdivps(vmm_qscale(), vmm_one(), vmm_dscale); uni_vmovss(ptr[reg_src_scales], Xmm(vmm_dscale.getIdx())); + if (jcp_.with_src_grouped_sum) { + uni_vxorps(vmm_src_sum_accum(), vmm_src_sum_accum(), vmm_src_sum_accum()); + } for (int icb = 0; icb < ic_blocks; icb++) { load_src(vmm_src(), ptr[reg_src + icb * vec_size * src_dt_size]); uni_vmulps(vmm_src(), vmm_src(), vmm_qscale()); uni_vcvtps2dq(vmm_src(), vmm_src()); + if (jcp_.with_src_grouped_sum) { + uni_vpaddd(vmm_src_sum_accum(), vmm_src_sum_accum(), vmm_src()); + + if (((icb + 1) * vec_size) % jcp_.src_sum_group_size == 0) { + horiz_op(vmm_src_sum_accum(), vmm_aux(), op_type::sum); + uni_vmovss(ptr[reg_src_grouped_sum], Xmm(vmm_src_sum_accum().getIdx())); + uni_vxorps(vmm_src_sum_accum(), vmm_src_sum_accum(), vmm_src_sum_accum()); + add(reg_src_grouped_sum, src_grouped_sum_dt_size); + } + } + if (isa == avx512_core) { vpmovsdb(ptr[reg_qsrc + icb * vec_size * qsrc_dt_size], vmm_src()); } else { diff --git a/src/cpu/x64/jit_brgemm_src_quantization_kernel.hpp b/src/cpu/x64/jit_brgemm_src_quantization_kernel.hpp index 15c2621940e..b93b19ee25d 100644 --- a/src/cpu/x64/jit_brgemm_src_quantization_kernel.hpp +++ b/src/cpu/x64/jit_brgemm_src_quantization_kernel.hpp @@ -32,6 +32,8 @@ namespace x64 { struct src_quantization_compile_params_t { size_t ic_quant_block; + bool with_src_grouped_sum; + 
size_t src_sum_group_size; data_type_t src_dt; data_type_t qsrc_dt; }; @@ -40,6 +42,7 @@ struct src_quantization_runtime_params_t { const void *src_ptr; const void *qsrc_ptr; const void *src_scales_ptr; + const void *src_grouped_sum_ptr; size_t ic_size; }; @@ -76,6 +79,9 @@ struct jit_brgemm_src_quantization_kernel_t : public jit_src_quantization_kernel void generate() override; void load_src(Vmm vmm_load, const Xbyak::Address& addr); + enum class op_type {max, sum}; + void horiz_op(Vmm vmm_src, Vmm vmm_aux, op_type op); + Vmm vmm_src() { return Vmm(0); } @@ -104,11 +110,16 @@ struct jit_brgemm_src_quantization_kernel_t : public jit_src_quantization_kernel return Vmm(6); } + Vmm vmm_src_sum_accum() { + return Vmm(7); + } + Xbyak::Reg64 reg_src = r8; Xbyak::Reg64 reg_qsrc = r9; Xbyak::Reg64 reg_src_scales = r10; Xbyak::Reg64 reg_ic_size = r11; Xbyak::Reg64 reg_tmp = r12; + Xbyak::Reg64 reg_src_grouped_sum = r13; size_t vec_size; };