Skip to content

Commit

Permalink
xe: jit: gemm: SLM remasking for m/n grouped scales
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad authored and umar456 committed Feb 11, 2025
1 parent 3f14034 commit 6c5e592
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 17 deletions.
69 changes: 53 additions & 16 deletions src/gpu/intel/jit/gemm/generator/pieces/k_loop.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -871,35 +871,70 @@ void BLASKernelGenerator<hw>::kLoop(KLoop type, const GEMMProblem &problem, GEMM
if (remaskA && remaskB && loadBFirst)
ls.swapLast2();

// A/B quantization parameter repacking.
// A/B quantization parameter repacking and remasking.
auto reqRepackAq = every(kaq_load);
auto reqRepackBq = every(kbq_load);
auto reqRepackAqLate = every(kaq_loadLate);
auto reqRepackBqLate = every(kbq_loadLate);

bool remaskAs = as2D && (minOPCount > 1) && (problem.aqGroupK == 1);
bool remaskBs = bs2D && (minOPCount > 1) && (problem.bqGroupK == 1);
bool remaskAq = (ao2D || as2D) && (minOPCount > 1) && (problem.aqGroupK == 1);
bool remaskBq = (ao2D || bs2D) && (minOPCount > 1) && (problem.bqGroupK == 1);
int iremaskScale = 2;
if (dequantize2DA) ls.schedule(reqRepackAq, [&](Iteration h) {
if (remaskAs) {
int ms, ks;

auto doRemaskAq = [&](Iteration h, bool slm) {
if (!remaskAq) return;
int ms, ks;
Subregister offK;
if (slm && (state.effCoopA == CoopSplit::K || state.effCoopA == CoopSplit::FullK)) {
offK = state.ra.allocSub<uint32_t>();
mulConstant(1, offK, state.lidN, state.ka_slm);
}
if (as2D) {
getLayoutDims(state.A_scaleLayout, ms, ks);
setupTeardownRemask(Ta_scale, iremaskScale, true, ks, state.K, strategy, state, -h.counterOffset());
remaskLayout(Ta_scale, iremaskScale, true, state.A_scaleLayout, state.A_scaleRegs, strategy, state, h % ks);
setupTeardownRemask(Ta_scale, iremaskScale, false, ks, state.K, strategy, state);
remaskLayoutSingle(Ta_scale, iremaskScale, true, ks, state.K,
state.A_scaleLayout, state.A_scaleRegs, strategy, state,
-h.counterOffset(), offK);
}
if (ao2D) {
getLayoutDims(state.A_offsetLayout, ms, ks);
remaskLayoutSingle(Tao, iremaskScale, true, ks, state.K,
state.A_offsetLayout, state.A_offsetRegs, strategy, state,
-h.counterOffset(), offK);
}
state.ra.safeRelease(offK);
};

auto doRemaskBq = [&](Iteration h, bool slm) {
if (!remaskBq) return;
int ks, ns;
Subregister offK;
if (slm && (state.effCoopB == CoopSplit::K || state.effCoopB == CoopSplit::FullK)) {
offK = state.ra.allocSub<uint32_t>();
mulConstant(1, offK, state.lidM, state.ka_slm);
}
if (bs2D) {
getLayoutDims(state.B_scaleLayout, ks, ns);
remaskLayoutSingle(Tb_scale, iremaskScale, false, ks, state.K,
state.B_scaleLayout, state.B_scaleRegs, strategy, state,
-h.counterOffset(), offK);
}
if (bo2D) {
getLayoutDims(state.B_offsetLayout, ks, ns);
remaskLayoutSingle(Tbo, iremaskScale, false, ks, state.K,
state.B_offsetLayout, state.B_offsetRegs, strategy, state,
-h.counterOffset(), offK);
}
state.ra.safeRelease(offK);
};

if (dequantize2DA) ls.schedule(reqRepackAq, [&](Iteration h) {
if (A_remActive(h)) doRemaskAq(h, false);
if (ao2D) gemmRepack2DOffsetData(Ta_ext, Tao, state.Tao_int, state.A_offsetLayout, state.Ar_offsetLayout, state.A_offsetRegs, state.Ar_offsetRegs, problem, strategy, state);
if (as2D) gemmRepack2DQuantizationData(Ta_scale, state.Ta_scaleOp, state.A_scaleLayout, state.Ar_scaleLayout, state.A_scaleRegs, state.Ar_scaleRegs, problem, strategy, state);
});

if (dequantize2DB) ls.schedule(reqRepackBq, [&](Iteration h) {
if (remaskBs) {
int ks, ns;
getLayoutDims(state.B_scaleLayout, ks, ns);
setupTeardownRemask(Tb_scale, iremaskScale, true, ks, state.K, strategy, state, -h.counterOffset());
remaskLayout(Tb_scale, iremaskScale, false, state.B_scaleLayout, state.B_scaleRegs, strategy, state, h % ks);
setupTeardownRemask(Tb_scale, iremaskScale, false, ks, state.K, strategy, state);
}
if (B_remActive(h)) doRemaskBq(h, false);
if (bo2D) gemmRepack2DOffsetData(Tb_ext, Tbo, state.Tbo_int, state.B_offsetLayout, state.Br_offsetLayout, state.B_offsetRegs, state.Br_offsetRegs, problem, strategy, state);
if (bs2D) gemmRepack2DQuantizationData(Tb_scale, state.Tb_scaleOp, state.B_scaleLayout, state.Br_scaleLayout, state.B_scaleRegs, state.Br_scaleRegs, problem, strategy, state);
});
Expand Down Expand Up @@ -1058,10 +1093,12 @@ void BLASKernelGenerator<hw>::kLoop(KLoop type, const GEMMProblem &problem, GEMM

if (slmDequantize2D) ls.schedule(reqSLMRepackQ, [&](Iteration h) {
if (slmDequantize2DA) {
if (slmRemActive(h)) doRemaskAq(h, true);
if (ao2D) gemmRepack2DOffsetData(Ta_ext, problem.Tao, state.Tao_int, state.A_offsetLayout, state.Ar_offsetLayout, state.A_offsetRegs, state.Ar_offsetRegs, problem, strategy, state);
if (as2D) gemmRepack2DQuantizationData(problem.Ta_scale, state.Ta_scaleOp, state.A_scaleLayout, state.Ar_scaleLayout, state.A_scaleRegs, state.Ar_scaleRegs, problem, strategy, state);
}
if (slmDequantize2DB) {
if (slmRemActive(h)) doRemaskBq(h, true);
if (bo2D) gemmRepack2DOffsetData(Tb_ext, problem.Tbo, state.Tbo_int, state.B_offsetLayout, state.Br_offsetLayout, state.B_offsetRegs, state.Br_offsetRegs, problem, strategy, state);
if (bs2D) gemmRepack2DQuantizationData(problem.Tb_scale, state.Tb_scaleOp, state.B_scaleLayout, state.Br_scaleLayout, state.B_scaleRegs, state.Br_scaleRegs, problem, strategy, state);
}
Expand Down
15 changes: 14 additions & 1 deletion src/gpu/intel/jit/gemm/generator/pieces/remask.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ void BLASKernelGenerator<hw>::setupTeardownRemask(Type T, int index, bool setup,
}

template <HW hw>
void BLASKernelGenerator<hw>::remaskLayout(Type T, int index, bool column, const std::vector<RegisterBlock> &layout, const GRFMultirange &regs, const CommonStrategy &strategy, CommonState &state, int offset)
void BLASKernelGenerator<hw>::remaskLayout(Type T, int index, bool column,
const std::vector<RegisterBlock> &layout, const GRFMultirange &regs,
const CommonStrategy &strategy, CommonState &state, int offset)
{
for (auto &block: layout) {
auto crosspack = block.crosspack;
Expand Down Expand Up @@ -164,4 +166,15 @@ void BLASKernelGenerator<hw>::remaskLayout(Type T, int index, bool column, const
}
}

template <HW hw>
void BLASKernelGenerator<hw>::remaskLayoutSingle(Type T, int index, bool column, int nq, Subregister remQ,
const std::vector<RegisterBlock> &layout, const GRFMultirange &regs,
const CommonStrategy &strategy, CommonState &state,
int fixedOffQ, const Subregister &variableOffQ, int maskOff)
{
setupTeardownRemask(T, index, true, nq, remQ, strategy, state, fixedOffQ, variableOffQ);
remaskLayout(T, index, column, layout, regs, strategy, state, maskOff);
setupTeardownRemask(T, index, false, nq, remQ, strategy, state);
}

#include "internal/namespace_end.hxx"
1 change: 1 addition & 0 deletions src/gpu/intel/jit/gemm/include/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ class BLASKernelGenerator : public GENERATOR_BASE(hw) {

void setupTeardownRemask(Type T, int index, bool setup, int nq, ngen::Subregister remQ, const CommonStrategy &strategy, CommonState &state, int fixedOffQ = 0, const ngen::Subregister &variableOffQ = ngen::Subregister());
void remaskLayout(Type T, int index, bool column, const std::vector<RegisterBlock> &layout, const GRFMultirange &regs, const CommonStrategy &strategy, CommonState &state, int offset = 0);
void remaskLayoutSingle(Type T, int index, bool column, int nq, ngen::Subregister remQ, const std::vector<RegisterBlock> &layout, const GRFMultirange &regs, const CommonStrategy &strategy, CommonState &state, int fixedOffQ = 0, const ngen::Subregister &variableOffQ = ngen::Subregister(), int maskOff = 0);

void setAddrRemainder(Type T, const ngen::GRFRange &addr, const RegisterBlock &block, const ngen::Subregister &remR, const ngen::Subregister &remC, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy, const CommonStrategy &strategy, CommonState &state);
void setAddrRemainder(Type T, const std::vector<ngen::GRFRange> &addr, const std::vector<RegisterBlock> &layout, const ngen::Subregister &remR, const ngen::Subregister &remC, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy, const CommonStrategy &strategy, CommonState &state);
Expand Down

0 comments on commit 6c5e592

Please sign in to comment.