diff --git a/src/gpu/intel/jit/gemm/generator/pieces/k_loop.cxx b/src/gpu/intel/jit/gemm/generator/pieces/k_loop.cxx
index 79ab46fa935..a6d10135f7d 100644
--- a/src/gpu/intel/jit/gemm/generator/pieces/k_loop.cxx
+++ b/src/gpu/intel/jit/gemm/generator/pieces/k_loop.cxx
@@ -871,35 +871,78 @@ void BLASKernelGenerator<hw>::kLoop(KLoop type, const GEMMProblem &problem, GEMM
     if (remaskA && remaskB && loadBFirst)
         ls.swapLast2();
 
-    // A/B quantization parameter repacking.
+    // A/B quantization parameter repacking and remasking.
     auto reqRepackAq = every(kaq_load);
     auto reqRepackBq = every(kbq_load);
     auto reqRepackAqLate = every(kaq_loadLate);
     auto reqRepackBqLate = every(kbq_loadLate);
 
-    bool remaskAs = as2D && (minOPCount > 1) && (problem.aqGroupK == 1);
-    bool remaskBs = bs2D && (minOPCount > 1) && (problem.bqGroupK == 1);
+    bool remaskAq = (ao2D || as2D) && (minOPCount > 1) && (problem.aqGroupK == 1);
+    bool remaskBq = (bo2D || bs2D) && (minOPCount > 1) && (problem.bqGroupK == 1);
     int iremaskScale = 2;
-    if (dequantize2DA) ls.schedule(reqRepackAq, [&](Iteration h) {
-        if (remaskAs) {
-            int ms, ks;
+
+    // Remask 2D A scales/offsets: each active layout goes through remaskLayoutSingle
+    //   with state.K as remQ and -h.counterOffset() as the fixed offset.
+    // slm = true marks the SLM repack path; if state.effCoopA is then CoopSplit::K or
+    //   FullK, offK = lidN * ka_slm is passed as the variable offset.
+    // NOTE(review): the replaced code passed (h % ks) as the remaskLayout offset;
+    //   maskOff is left at its default of 0 here -- confirm this is intended.
+    auto doRemaskAq = [&](Iteration h, bool slm) {
+        if (!remaskAq) return;
+        int ms, ks;
+        Subregister offK;
+        if (slm && (state.effCoopA == CoopSplit::K || state.effCoopA == CoopSplit::FullK)) {
+            offK = state.ra.allocSub<uint32_t>();
+            mulConstant(1, offK, state.lidN, state.ka_slm);
+        }
+        if (as2D) {
             getLayoutDims(state.A_scaleLayout, ms, ks);
-            setupTeardownRemask(Ta_scale, iremaskScale, true, ks, state.K, strategy, state, -h.counterOffset());
-            remaskLayout(Ta_scale, iremaskScale, true, state.A_scaleLayout, state.A_scaleRegs, strategy, state, h % ks);
-            setupTeardownRemask(Ta_scale, iremaskScale, false, ks, state.K, strategy, state);
+            remaskLayoutSingle(Ta_scale, iremaskScale, true, ks, state.K,
+                               state.A_scaleLayout, state.A_scaleRegs, strategy, state,
+                               -h.counterOffset(), offK);
+        }
+        if (ao2D) {
+            getLayoutDims(state.A_offsetLayout, ms, ks);
+            remaskLayoutSingle(Tao, iremaskScale, true, ks, state.K,
+                               state.A_offsetLayout, state.A_offsetRegs, strategy, state,
+                               -h.counterOffset(), offK);
         }
+        state.ra.safeRelease(offK);
+    };
+
+    // B counterpart of doRemaskAq for 2D B scales/offsets; the variable offset for
+    //   the SLM path is lidM * kb_slm when state.effCoopB is CoopSplit::K or FullK.
+    auto doRemaskBq = [&](Iteration h, bool slm) {
+        if (!remaskBq) return;
+        int ks, ns;
+        Subregister offK;
+        if (slm && (state.effCoopB == CoopSplit::K || state.effCoopB == CoopSplit::FullK)) {
+            offK = state.ra.allocSub<uint32_t>();
+            mulConstant(1, offK, state.lidM, state.kb_slm);
+        }
+        if (bs2D) {
+            getLayoutDims(state.B_scaleLayout, ks, ns);
+            remaskLayoutSingle(Tb_scale, iremaskScale, false, ks, state.K,
+                               state.B_scaleLayout, state.B_scaleRegs, strategy, state,
+                               -h.counterOffset(), offK);
+        }
+        if (bo2D) {
+            getLayoutDims(state.B_offsetLayout, ks, ns);
+            remaskLayoutSingle(Tbo, iremaskScale, false, ks, state.K,
+                               state.B_offsetLayout, state.B_offsetRegs, strategy, state,
+                               -h.counterOffset(), offK);
+        }
+        state.ra.safeRelease(offK);
+    };
+
+    if (dequantize2DA) ls.schedule(reqRepackAq, [&](Iteration h) {
+        if (A_remActive(h)) doRemaskAq(h, false);
         if (ao2D) gemmRepack2DOffsetData(Ta_ext, Tao, state.Tao_int, state.A_offsetLayout, state.Ar_offsetLayout, state.A_offsetRegs, state.Ar_offsetRegs, problem, strategy, state);
         if (as2D) gemmRepack2DQuantizationData(Ta_scale, state.Ta_scaleOp, state.A_scaleLayout, state.Ar_scaleLayout, state.A_scaleRegs, state.Ar_scaleRegs, problem, strategy, state);
     });
 
     if (dequantize2DB) ls.schedule(reqRepackBq, [&](Iteration h) {
-        if (remaskBs) {
-            int ks, ns;
-            getLayoutDims(state.B_scaleLayout, ks, ns);
-            setupTeardownRemask(Tb_scale, iremaskScale, true, ks, state.K, strategy, state, -h.counterOffset());
-            remaskLayout(Tb_scale, iremaskScale, false, state.B_scaleLayout, state.B_scaleRegs, strategy, state, h % ks);
-            setupTeardownRemask(Tb_scale, iremaskScale, false, ks, state.K, strategy, state);
-        }
+        if (B_remActive(h)) doRemaskBq(h, false);
         if (bo2D) gemmRepack2DOffsetData(Tb_ext, Tbo, state.Tbo_int, state.B_offsetLayout, state.Br_offsetLayout, state.B_offsetRegs, state.Br_offsetRegs, problem, strategy, state);
         if (bs2D) gemmRepack2DQuantizationData(Tb_scale, state.Tb_scaleOp, state.B_scaleLayout, state.Br_scaleLayout, state.B_scaleRegs, state.Br_scaleRegs, problem, strategy, state);
     });
@@ -1058,10 +1101,12 @@ void BLASKernelGenerator<hw>::kLoop(KLoop type, const GEMMProblem &problem, GEMM
 
     if (slmDequantize2D) ls.schedule(reqSLMRepackQ, [&](Iteration h) {
         if (slmDequantize2DA) {
+            if (slmRemActive(h)) doRemaskAq(h, true);
             if (ao2D) gemmRepack2DOffsetData(Ta_ext, problem.Tao, state.Tao_int, state.A_offsetLayout, state.Ar_offsetLayout, state.A_offsetRegs, state.Ar_offsetRegs, problem, strategy, state);
             if (as2D) gemmRepack2DQuantizationData(problem.Ta_scale, state.Ta_scaleOp, state.A_scaleLayout, state.Ar_scaleLayout, state.A_scaleRegs, state.Ar_scaleRegs, problem, strategy, state);
         }
         if (slmDequantize2DB) {
+            if (slmRemActive(h)) doRemaskBq(h, true);
             if (bo2D) gemmRepack2DOffsetData(Tb_ext, problem.Tbo, state.Tbo_int, state.B_offsetLayout, state.Br_offsetLayout, state.B_offsetRegs, state.Br_offsetRegs, problem, strategy, state);
             if (bs2D) gemmRepack2DQuantizationData(problem.Tb_scale, state.Tb_scaleOp, state.B_scaleLayout, state.Br_scaleLayout, state.B_scaleRegs, state.Br_scaleRegs, problem, strategy, state);
         }
diff --git a/src/gpu/intel/jit/gemm/generator/pieces/remask.cxx b/src/gpu/intel/jit/gemm/generator/pieces/remask.cxx
index c5d28f50262..0cf27c10caf 100644
--- a/src/gpu/intel/jit/gemm/generator/pieces/remask.cxx
+++ b/src/gpu/intel/jit/gemm/generator/pieces/remask.cxx
@@ -110,7 +110,9 @@ void BLASKernelGenerator<hw>::setupTeardownRemask(Type T, int index, bool setup,
 }
 
 template <HW hw>
-void BLASKernelGenerator<hw>::remaskLayout(Type T, int index, bool column, const std::vector<RegisterBlock> &layout, const GRFMultirange &regs, const CommonStrategy &strategy, CommonState &state, int offset)
+void BLASKernelGenerator<hw>::remaskLayout(Type T, int index, bool column,
+                                           const std::vector<RegisterBlock> &layout, const GRFMultirange &regs,
+                                           const CommonStrategy &strategy, CommonState &state, int offset)
 {
     for (auto &block: layout) {
         auto crosspack = block.crosspack;
@@ -164,4 +166,19 @@ void BLASKernelGenerator<hw>::remaskLayout(Type T, int index, bool column, const
     }
 }
 
+// Convenience wrapper: set up remasking (setupTeardownRemask with setup = true),
+//   remask every block of the layout (remaskLayout), then tear the remask state down.
+// nq/remQ/fixedOffQ/variableOffQ are forwarded to the setup call (nq/remQ also to the
+//   teardown call); maskOff is forwarded to remaskLayout as its offset.
+template <HW hw>
+void BLASKernelGenerator<hw>::remaskLayoutSingle(Type T, int index, bool column, int nq, Subregister remQ,
+                                                 const std::vector<RegisterBlock> &layout, const GRFMultirange &regs,
+                                                 const CommonStrategy &strategy, CommonState &state,
+                                                 int fixedOffQ, const Subregister &variableOffQ, int maskOff)
+{
+    setupTeardownRemask(T, index, true, nq, remQ, strategy, state, fixedOffQ, variableOffQ);
+    remaskLayout(T, index, column, layout, regs, strategy, state, maskOff);
+    setupTeardownRemask(T, index, false, nq, remQ, strategy, state);
+}
+
 #include "internal/namespace_end.hxx"
diff --git a/src/gpu/intel/jit/gemm/include/generator.hpp b/src/gpu/intel/jit/gemm/include/generator.hpp
index 7e7e40bcd69..a2c04813df1 100644
--- a/src/gpu/intel/jit/gemm/include/generator.hpp
+++ b/src/gpu/intel/jit/gemm/include/generator.hpp
@@ -276,6 +276,7 @@ class BLASKernelGenerator : public GENERATOR_BASE(hw) {
 
     void setupTeardownRemask(Type T, int index, bool setup, int nq, ngen::Subregister remQ, const CommonStrategy &strategy, CommonState &state, int fixedOffQ = 0, const ngen::Subregister &variableOffQ = ngen::Subregister());
     void remaskLayout(Type T, int index, bool column, const std::vector<RegisterBlock> &layout, const GRFMultirange &regs, const CommonStrategy &strategy, CommonState &state, int offset = 0);
+    void remaskLayoutSingle(Type T, int index, bool column, int nq, ngen::Subregister remQ, const std::vector<RegisterBlock> &layout, const GRFMultirange &regs, const CommonStrategy &strategy, CommonState &state, int fixedOffQ = 0, const ngen::Subregister &variableOffQ = ngen::Subregister(), int maskOff = 0);
     void setAddrRemainder(Type T, const ngen::GRFRange &addr, const RegisterBlock &block, const ngen::Subregister &remR, const ngen::Subregister &remC, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy, const CommonStrategy &strategy, CommonState &state);
     void setAddrRemainder(Type T, const std::vector<ngen::GRFRange> &addr, const std::vector<RegisterBlock> &layout, const ngen::Subregister &remR, const ngen::Subregister &remC, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy, const CommonStrategy &strategy, CommonState &state);
 