Skip to content

Commit

Permalink
[FORK][FEATURE] DQ IP: reduce aux vecs counts required for microkernel
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry-gorokhov committed Jan 28, 2025
1 parent 77250c1 commit bc4e68a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 49 deletions.
3 changes: 1 addition & 2 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,8 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl == avx2) max_bcast_block -= 5;
if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2;
if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1;
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1;
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0 && !brg->with_src_dyn_quant) max_bcast_block -= 1;
if (brg->with_src_dyn_quant) max_bcast_block -= 1;
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;

max_bcast_block /= adj_ld_block2;

Expand Down
92 changes: 45 additions & 47 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,17 +320,13 @@ struct jit_brgemm_kernel_t : public jit_generator {
used_vregs += 1;
}

if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0) {
if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0 && !brg.with_src_dyn_quant) {
used_vregs += 1;
}

if (brg.with_src_dyn_quant) {
used_vregs += 1;
}

if (brg.with_src_dyn_quant && brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) {
used_vregs += brg.ld_block2;
}
return isa_num_vregs(brg.isa_impl) - used_vregs;
}

Expand Down Expand Up @@ -2306,11 +2302,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
};

auto vmm_zero_point = [&](int ld) {
int idx = isa_num_vregs(brg.isa_impl) - 2 - ld;
return Vmm(idx);
};

static const int8_t mask_low_half[64] = {
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
Expand All @@ -2321,30 +2312,11 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);

auto reg_local_wei_scales = reg_bdb_loop;
auto reg_local_wei_zp = reg_ldb_loop;
auto reg_ptr = reg_local_wei_scales;

if (brg.with_wei_decomp_zero_points) {
mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_]);
if (brg.wei_decomp_zero_points_stride == 0) {
auto reg_ptr_32 = Reg32(reg_ptr.getIdx());
movzx(reg_ptr_32, ptr[reg_local_wei_zp]);
uni_vmovq(Xmm(vmm_zero_point(0).getIdx()), reg_ptr);
uni_vbroadcastss(vmm_zero_point(0), Xmm(vmm_zero_point(0).getIdx()));
} else {
for (int ld = 0; ld < ld_block2; ld++) {
uni_vpmovzxbd(vmm_zero_point(ld), ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]);
}
}
}

auto reg_ptr = reg_bdb_loop;
auto vmm_mask_low_half = Vmm(isa_num_vregs(brg.isa_impl) - 1);
mov(reg_ptr, (size_t)mask_low_half);
uni_vmovups(vmm_mask_low_half, ptr[reg_ptr]);

mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);

const int vec_size = vreg_traits<Vmm>::vlen;
auto accums_stack_space = bd_e * ld_block2 * vec_size;
sub(rsp, accums_stack_space);
Expand Down Expand Up @@ -2397,42 +2369,68 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
}
}

auto reg_local_src_scales = reg_local_wei_zp;
auto reg_local_src_grouped_sum = reg_local_wei_zp;
auto vmm_src_scales = bcst();
auto vmm_zero_point = [&](int ld) {
return load(ld);
};

auto reg_local_wei_zp = reg_ldb_loop;
auto reg_local_src_grouped_sum = reg_bdb_loop;
auto vmm_tmp = Vmm(isa_num_vregs(brg.isa_impl) - 1);
auto vmm_src_grouped_sum = bcst();

if (brg.with_wei_decomp_zero_points) {
mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_ + accums_stack_space]);
if (brg.wei_decomp_zero_points_stride == 0) {
Vmm vmm_zp = vmm_zero_point(0);
auto reg_ptr_32 = Reg32(reg_ptr.getIdx());
movzx(reg_ptr_32, ptr[reg_local_wei_zp]);
uni_vmovq(Xmm(vmm_zp.getIdx()), reg_ptr);
uni_vbroadcastss(vmm_zp, Xmm(vmm_zp.getIdx()));
}

mov(reg_local_src_grouped_sum, ptr[rsp + reg_aux2_src_grouped_sum_offs_ + accums_stack_space]);
for (int bd = bd_b; bd < bd_e; bd++) {
uni_vbroadcastss(vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof(int32_t)]);
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm_accm = accm(ld_block2, bd, ld);
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
uni_vbroadcastss(vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof(int32_t)]);
uni_vpmulld(vmm_src_grouped_sum, vmm_src_grouped_sum, vmm_zp);
uni_vpsubd(vmm_accm, vmm_accm, vmm_src_grouped_sum);
if (bd == bd_b && brg.wei_decomp_zero_points_stride != 0) {
uni_vpmovzxbd(vmm_zp, ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]);
}

auto vmm_accm = accm(ld_block2, bd, ld);
uni_vpmulld(vmm_tmp, vmm_src_grouped_sum, vmm_zp);
uni_vpsubd(vmm_accm, vmm_accm, vmm_tmp);
}
}
}

auto wei_scale = [&](int ld) {
return load(ld);
};

auto reg_local_src_scales = reg_ldb_loop;
auto reg_local_wei_scales = reg_bdb_loop;
auto vmm_src_scales = bcst();

mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_ + accums_stack_space]);
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);
if (brg.wei_decomp_scales_stride == 0) {
uni_vbroadcastss(wei_scale(0), ptr[reg_local_wei_scales]);
}

for (int bd = bd_b; bd < bd_e; bd++) {
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
for (int ld = 0; ld < ld_block2; ld++) {
if (brg.wei_decomp_scales_stride == 0) {
uni_vbroadcastss(load(ld), ptr[reg_local_wei_scales]);
} else {
uni_vmovups(load(ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]);
auto vmm_wei_scale = brg.wei_decomp_scales_stride == 0 ? wei_scale(0) : wei_scale(ld);
if (bd == bd_b && brg.wei_decomp_scales_stride != 0) {
uni_vmovups(vmm_wei_scale, ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]);
}
}
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm_accm = accm(ld_block2, bd, ld);

auto vmm_accm = accm(ld_block2, bd, ld);
uni_vcvtdq2ps(vmm_accm, vmm_accm);
uni_vmulps(vmm_accm, vmm_accm, vmm_src_scales);
uni_vmulps(load(ld), vmm_accm, load(ld));
uni_vmulps(vmm_tmp, vmm_accm, vmm_src_scales);
uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
uni_vaddps(vmm_accm, vmm_accm, load(ld));
uni_vfmadd231ps(vmm_accm, vmm_tmp, vmm_wei_scale);
}
}

Expand Down

0 comments on commit bc4e68a

Please sign in to comment.