Commit 3b1daf0

xe: jit: gemm: TLB warmup support

1 parent c07fe05

10 files changed: +212 -4 lines

src/gpu/intel/jit/gemm/gemm_walk_orders.hpp (+1, -1)

@@ -196,7 +196,7 @@ inline void gemm_linear_order_args(compute::kernel_arg_list_t &arg_list,
         arg_list.set(argn++, group_count);
     }
 
-    gws[0] = lws[0] * group_count;
+    gws[0] = lws[0] * (group_count + info.extraWGs());
     gws[1] = lws[1];
 }
 
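Host-side, the only change needed for warmup is the dispatch size: when the kernel sets FlagExtraWG, info.extraWGs() returns 1 and one extra workgroup is appended to the linear dispatch. A minimal sketch of the sizing arithmetic, with simplified standalone names (the real logic is the hunk above):

    // Sketch only: mirrors gws[0] = lws[0] * (group_count + info.extraWGs()).
    // extra_wgs is CommonDriverInfo::extraWGs(), i.e. 0 or 1.
    size_t linear_gws0(size_t lws0, size_t group_count, int extra_wgs) {
        return lws0 * (group_count + extra_wgs);
    }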

src/gpu/intel/jit/gemm/generator/generator.cpp (+1)

@@ -48,6 +48,7 @@
 #include "pieces/remask.cxx"
 #include "pieces/row_column_sums.cxx"
 #include "pieces/state_utils.cxx"
+#include "pieces/tlb_warmup.cxx"
 #include "pieces/walk_orders.cxx"
 
 #include "pieces/quantization.cxx"

src/gpu/intel/jit/gemm/generator/pieces/driver_info.cxx (+1)

@@ -77,6 +77,7 @@ CommonDriverInfo BLASKernelGenerator<hw>::driverInfo(GEMMProblem problem, const
     if (problem.alpha.pointer()) info.flags |= FlagAlphaPtr;
     if (problem.beta.pointer()) info.flags |= FlagBetaPtr;
     if (strategy.nondeterministic(problem)) info.flags |= FlagNondeterministic;
+    if (strategy.tlbWarmup) info.flags |= FlagExtraWG;
     info.flags |= (strategy.fillGoal << FlagShiftFillGoal) & FlagMaskFillGoal;
     info.slm = int(gemmSLMSize(hw, problem, strategy));
     info.perKSLM = int(gemmPerKSLMSize(hw, problem, strategy));

src/gpu/intel/jit/gemm/generator/pieces/gemm.cxx (+21, -2)

@@ -116,6 +116,23 @@ void BLASKernelGenerator<hw>::gemm(GEMMProblem &problem, GEMMStrategy &strategy,
         jmpi(1 | f1[0], lPadThread);
     }
 
+    // Check if this is a TLB warmup thread, and perform warmup if so.
+    if (strategy.tlbWarmup) {
+        Label lNotTLBWarmup;
+        state.groupIDMN = state.ra.alloc_sub<uint32_t>(getHint(HintType::LongTerm, strategy));
+        add(1 | ge | f1[0], state.groupIDMN.d(), state.inputs.groupIDMN, -1);
+        jmpi(1 | f1[0], lNotTLBWarmup);
+        status << "TLB warmup" << status_stream::endl;
+        auto mstate = state;
+        moveR0(strategy, mstate);
+        gemmGetBatchIDs(problem, strategy, mstate);
+        gemmOffsetBatchABC(problem, strategy, mstate);
+        gemmSetupABC(problem, strategy, mstate);
+        gemmTLBWarmup(problem, strategy, mstate);
+        epilogue(strategy, mstate);
+        mark(lNotTLBWarmup);
+    }
+
     // Scale LDs/offsets.
     gemmScaleInputs(problem, strategy, state);
 
@@ -232,8 +249,10 @@ void BLASKernelGenerator<hw>::gemm(GEMMProblem &problem, GEMMStrategy &strategy,
     if (!strategy.linearOrder()) stub();
     if (problem.batch != BatchMode::None) stub();   // need to wrangle groupIDK also
 
-    state.groupIDMN = state.ra.alloc_sub<uint32_t>(getHint(HintType::LongTerm, strategy));
-    mov(1, state.groupIDMN, state.inputs.groupIDMN);
+    if (state.groupIDMN == state.inputs.groupIDMN) {
+        state.groupIDMN = state.ra.alloc_sub<uint32_t>(getHint(HintType::LongTerm, strategy));
+        mov(1, state.groupIDMN, state.inputs.groupIDMN);
+    }
 
     if (state.effTempC == state.inputs.tempC)
         state.effTempC = state.ra.alloc_sub<uint64_t>(getHint(HintType::LongTerm, strategy));
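Together these two hunks set up the warmup convention: the extra workgroup from extraWGs() arrives as linear group 0, performs the warmup, and exits; every other group shifts its ID down by one and runs the normal GEMM. A scalar pseudocode sketch of the emitted control flow (illustrative C++ only; the generator emits nGEN instructions, and kernel_entry, tlb_warmup, and run_gemm are hypothetical stand-ins):

    void tlb_warmup();            // stands in for gemmTLBWarmup(...) + epilogue(...)
    void run_gemm(int32_t id);    // stands in for the rest of the GEMM kernel

    void kernel_entry(uint32_t groupIDMN_raw) {
        int32_t groupIDMN = int32_t(groupIDMN_raw) - 1;  // add(1 | ge | f1[0], ...)
        if (groupIDMN < 0) {      // linear group 0 is the warmup thread
            tlb_warmup();
            return;
        }
        run_gemm(groupIDMN);      // real groups compute with the shifted ID
    }

The second hunk guards the later re-allocation of state.groupIDMN: with warmup enabled it already holds the shifted ID, so it is re-allocated only when it still aliases the raw kernel input.
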
src/gpu/intel/jit/gemm/generator/pieces/tlb_warmup.cxx (new file, +173)

@@ -0,0 +1,173 @@
+/*******************************************************************************
+* INTEL CONFIDENTIAL
+* Copyright 2025 Intel Corporation.
+*
+* This software and the related documents are Intel copyrighted materials, and
+* your use of them is governed by the express license under which they were
+* provided to you (License). Unless the License provides otherwise, you may not
+* use, modify, copy, publish, distribute, disclose or transmit this software or
+* the related documents without Intel's prior written permission.
+*
+* This software and the related documents are provided as is, with no express
+* or implied warranties, other than those that are expressly stated in the
+* License.
+*******************************************************************************/
+
+
+#include "generator.hpp"
+#include "hw_utils.hpp"
+#include "layout_utils.hpp"
+#include "state_utils.hpp"
+#include "ngen_object_helpers.hpp"
+
+#include "internal/namespace_start.hxx"
+
+using namespace ngen;
+using namespace ngen::utils;
+using std::vector;
+
+
+template <HW hw>
+void BLASKernelGenerator<hw>::gemmTLBWarmup(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state)
+{
+    auto lid = state.ra.allocSub<uint32_t>();
+    int whose = 0;
+
+    emad(1, lid, state.inputs.localIDM, state.inputs.localIDN, strategy.wg[LoopM], strategy, state);
+    if (strategy.kParallelLocal)
+        emad(1, lid, lid, state.inputs.localIDK, strategy.wg[LoopM] * strategy.wg[LoopN], strategy, state);
+
+    if (problem.quantized2DA()) {
+        auto mq = state.ra.allocSub<uint32_t>();
+        auto kq = state.ra.allocSub<uint32_t>();
+        divDown(mq, state.inputs.m, problem.aqGroupM, strategy, state);
+        divDown(kq, state.inputs.k, problem.aqGroupK, strategy, state);
+        if (problem.aScale2D) {
+            tlbWarmup(problem.Ta_scale, problem.A_scale, strategy.A_scale, state.inputs.aScalePtr,
+                      mq, kq, state.inputs.ldaq, lid, whose++, problem, strategy, state);
+        }
+        if (problem.aoPtrDims == 2) {
+            tlbWarmup(problem.Tao, problem.AO, strategy.AO, state.inputs.aoPtr,
+                      mq, kq, state.inputs.ldaq, lid, whose++, problem, strategy, state);
+        }
+        state.ra.safeRelease(mq);
+        state.ra.safeRelease(kq);
+    }
+
+    if (problem.quantized2DB()) {
+        auto kq = state.ra.allocSub<uint32_t>();
+        auto nq = state.ra.allocSub<uint32_t>();
+        divDown(kq, state.inputs.k, problem.bqGroupK, strategy, state);
+        divDown(nq, state.inputs.n, problem.bqGroupN, strategy, state);
+        if (problem.bScale2D) {
+            tlbWarmup(problem.Tb_scale, problem.B_scale, strategy.B_scale, state.inputs.bScalePtr,
+                      kq, nq, state.inputs.ldbq, lid, whose++, problem, strategy, state);
+        }
+        if (problem.boPtrDims == 2) {
+            tlbWarmup(problem.Tbo, problem.BO, strategy.BO, state.inputs.boPtr,
+                      kq, nq, state.inputs.ldbq, lid, whose++, problem, strategy, state);
+        }
+        state.ra.safeRelease(kq);
+        state.ra.safeRelease(nq);
+    }
+
+    tlbWarmup(problem.Ta_ext, problem.A, strategy.A, state.effA,
+              state.inputs.m, state.inputs.k, state.inputs.lda, lid, whose++,
+              problem, strategy, state);
+    tlbWarmup(problem.Tb_ext, problem.B, strategy.B, state.effB,
+              state.inputs.k, state.inputs.n, state.inputs.ldb, lid, whose++,
+              problem, strategy, state);
+
+    state.ra.safeRelease(lid);
+}
+
+template <HW hw>
+void BLASKernelGenerator<hw>::tlbWarmup(Type T, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy,
+                                        const Subregister &ptr, const Subregister &r, const Subregister &c,
+                                        const Subregister &ld, const Subregister &lid, int whose,
+                                        const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state)
+{
+    auto flag = state.raVFlag.alloc();
+    const uint32_t byteLimit = 256 * 1024 * 1024;
+
+    auto bytes = state.ra.allocSub<uint64_t>();
+    emul(1, bytes, ld, isColMajor(atype.layout) ? c : r, strategy, state);
+    cmp(1 | nz | flag, bytes.ud(1), 0);
+    cmp(1 | ~flag | gt | flag, bytes.ud(), byteLimit / T);
+    emulConstant(1, bytes.ud(), bytes.ud(), T, strategy, state);
+    mov(1 | flag, bytes.ud(), byteLimit);
+
+    state.raVFlag.safeRelease(flag);
+
+    tlbWarmup(astrategy.base, ptr, bytes.ud(), lid, whose, problem, strategy, state);
+
+    state.ra.safeRelease(bytes);
+}
+
+template <HW hw>
+void BLASKernelGenerator<hw>::tlbWarmup(AddressBase base, const Subregister &ptr, const Subregister &bytes,
+                                        const Subregister &lid, int whose,
+                                        const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state)
+{
+    bool a64 = base.isA64();
+    auto Taddr = a64 ? DataType::uq : DataType::ud;
+    const int simd = elementsPerGRF<uint32_t>(hw);
+    const int log2Stride = 16;              // 64kb stride.
+    const int log2TwiddleStride = 6;
+
+    int udStride = a64 ? 2 : 1;
+    auto addr = state.ra.allocRange(udStride);
+    auto addr0 = addr[0].retype(Taddr);
+    auto addrLo = addr0.ud(0)(udStride);
+    auto off = state.ra.allocRange(udStride);
+    auto off0 = off[0].ud(0)(udStride);
+    auto twiddle = state.ra.alloc().ud();
+    auto data = state.ra.alloc().ud();
+    auto count = state.ra.alloc().d();
+    auto flag = state.raVFlag.alloc();
+
+    extendIndexVec(simd, state);
+
+    auto iv = accessIndexVec(0, state)(1);
+
+    cmp(1 | nz | flag, lid, whose);         /* Check if we are responsible thread */
+
+    shl(simd, off0, iv, log2Stride);
+    shl(simd, twiddle, iv, log2TwiddleStride);
+    eadd(simd, addr0, ptr, off0, strategy, state);
+    xor_(simd, addrLo, addrLo, twiddle);    /* Perturb low bits to avoid cache hotspotting */
+
+    add(1, count, bytes, ((simd + 1) << log2Stride) - 1);
+    shr(1, count, count, log2Stride);
+    add(simd, count, count[0], -iv);
+
+    Label lTop, lSkip;
+    jmpi(1 | flag, lSkip);
+
+    mark(lTop);
+    add(simd | gt | flag, count, count, -simd);
+    if (hw >= HW::XeHPC)
+        load(simd | flag, null, D8U32 | L1C_L3C, base, addr);
+    else if (hw >= HW::XeHPG)
+        load(simd | flag, data, D8U32 | L1C_L3C, base, addr);
+    else
+        load(simd | flag, data, scattered_byte(), base, addr);
+    xor_(simd, addrLo, addrLo, twiddle);
+    add(simd, twiddle, twiddle, simd << log2TwiddleStride);
+    and_(simd, twiddle, twiddle, 0xFFF);    /* Don't cross 4K page boundaries */
+    eadd(simd, addr0, addr0, simd << log2Stride, strategy, state);
+    xor_(simd, addrLo, addrLo, twiddle);
+    jmpi(1 | flag, lTop);
+    mark(lSkip);
+
+    releaseIndexVec(state);
+    state.raVFlag.safeRelease(flag);
+    state.ra.safeRelease(off);
+    state.ra.safeRelease(twiddle);
+    state.ra.safeRelease(addr);
+    state.ra.safeRelease(data);
+    state.ra.safeRelease(count);
+}
+
+#include "internal/namespace_end.hxx"
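Each buffer (A, B, and any 2D quantization scale/offset arrays) is assigned to a different local thread via the `whose` index, so the per-buffer warmups run concurrently within the warmup workgroup. For a given buffer, the byte count is computed as ld times the outer dimension and clamped to the 256MB byteLimit; the loop then issues one small masked load per 64KB, XORing a rotating 64-byte "twiddle" into the low address bits so successive touches avoid cache set hotspotting, with the 0xFFF mask keeping the perturbation inside a 4KB page. A simplified host-side model of the generated address sequence (an illustration under assumptions: per-lane bookkeeping is flattened into one linear page index, and boundary handling is approximated):

    #include <cstdint>
    #include <vector>

    // Model of the addresses the warmup loop touches for `bytes` bytes at `base`.
    std::vector<uint64_t> warmup_addresses(uint64_t base, uint64_t bytes) {
        const uint64_t stride = uint64_t(1) << 16;        // 64KB between touches
        std::vector<uint64_t> out;
        for (uint64_t p = 0; p * stride < bytes; p++) {
            uint32_t twiddle = uint32_t(p << 6) & 0xFFF;  // rotate low bits within 4KB
            out.push_back((base + p * stride) ^ twiddle); // one touch per 64KB page stride
        }
        return out;
    }

Each address is touched with a byte-granularity load (D8U32 with L1/L3 caching on XeHPG and newer, scattered bytes on older hardware), which should be enough to fault the page translation into the TLB without displacing much cache content.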

src/gpu/intel/jit/gemm/generator/strategy.cpp (+3)

@@ -182,6 +182,9 @@ void GEMMStrategy::preflight(HW hw, const GEMMProblem &problem)
 
     extendedAtomicFMA &= !problem.needsASums() && !problem.needsBSums();
 
+    if (tlbWarmup && !linearOrder())
+        cWalkOrder = WalkOrder::SimpleLinear;
+
     // Default SIMD setting.
     if (fmaSIMD == 0) {
         fmaSIMD = std::min(32, 2 * GRF::bytes(hw) / std::max<int>({Ta.paddedSize(), Tb.paddedSize(), Tc.paddedSize()}));

src/gpu/intel/jit/gemm/generator/strategy_parser.cpp (+3)

@@ -379,6 +379,8 @@ void parseStrategy(const char *str, HW hw, const GEMMProblem &problem, GEMMStrat
         strategy.reverse[LoopM] = true;
     else if (mod == "rn")
         strategy.reverse[LoopN] = true;
+    else if (mod == "wt")
+        strategy.tlbWarmup = true;
     else if (mod == "kb" || mod == "kv") {
         if (mod == "kb") strategy.kParallel = true;
         if (mod == "kv") {

@@ -886,6 +888,7 @@ std::string unparseStrategy(HW hw, const GEMMProblem &problem, const GEMMStrateg
     if (strategy.panelCheck) s << " up";
     if (strategy.reverse[LoopM]) s << " rm";
     if (strategy.reverse[LoopN]) s << " rn";
+    if (strategy.tlbWarmup) s << " wt";
 
     if (strategy.checkAdd32 && !strategy.emulate.emulate64) s << " ch";
     if (!strategy.checkAdd32 && strategy.emulate.emulate64) s << " nch";
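In strategy strings the feature is toggled by the new "wt" modifier, and unparseStrategy() writes it back out so strings round-trip. A hedged usage sketch (the surrounding strategy tokens are elided as "..."; assume `hw`, `problem`, and an initialized `strategy` as elsewhere in this file):

    // "... wt ..." stands for an otherwise-valid strategy string.
    parseStrategy("... wt ...", hw, problem, strategy); // "wt" sets strategy.tlbWarmup
    std::string s = unparseStrategy(hw, problem, strategy);
    // s now contains " wt", so parsing s reproduces the same strategy.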

src/gpu/intel/jit/gemm/include/driver_info.hpp (+2)

@@ -69,6 +69,7 @@ enum DriverInfoFlags : uint32_t {
     FlagNondeterministic = 0x4000,  // Kernel produces nondeterministic results.
     FlagMaskFillGoal = 0xF0000,     // Fraction of available thread slots to fill, in sixteenths
     FlagShiftFillGoal = 16,         //   (starting bit)
+    FlagExtraWG = 0x400000,         // Add an additional workgroup.
 };
 
 // Driver information, shared by all kernel types.

@@ -116,6 +117,7 @@ struct CommonDriverInfo {
     bool betaPtr() const { return flags & FlagBetaPtr; }
     bool fixedWGK() const { return flags & FlagFixedWGK; }
     bool nondeterministic() const { return flags & FlagNondeterministic; }
+    int extraWGs() const { return (flags & FlagExtraWG) ? 1 : 0; }
 
     int wgTile(LoopType l) const { return unroll[l] * wg[l]; }
     int kPadding() const { return (kParallel() || kParallelVariable()) ? blockingAlt[LoopK] : 0; }

src/gpu/intel/jit/gemm/include/generator.hpp (+5)

@@ -465,6 +465,11 @@ class BLASKernelGenerator : public GENERATOR_BASE(hw) {
     bool gemmFusedPostOpsFinalize(ngen::Label &labelLateExit, GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
     void gemmRedirectToTempC(GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
 
+    // tlb_warmup.cxx
+    void tlbWarmup(ngen::AddressBase base, const ngen::Subregister &ptr, const ngen::Subregister &bytes, const ngen::Subregister &lid, int whose, const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state);
+    void tlbWarmup(Type T, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy, const ngen::Subregister &base, const ngen::Subregister &r, const ngen::Subregister &c, const ngen::Subregister &ld, const ngen::Subregister &lid, int whose, const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state);
+    void gemmTLBWarmup(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);
+
     // gemm_setup.cpp
     void gemmCheck32(const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
     void gemmGetBatchIDs(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);

src/gpu/intel/jit/gemm/include/strategy.hpp (+2, -1)

@@ -212,7 +212,8 @@ struct GEMMStrategyPOD : public CommonStrategy {
     bool kDescRem = false;              // Allow descriptor-based k remainder handling for A/B.
     bool slmA = false, slmB = false;    // Whether to copy A/B to SLM.
     bool splitCopy = false;             // Separate SLM copy and compute threads?
-    ZPAD(C, 2)
+    bool tlbWarmup = false;             // Enable TLB warmup?
+    ZPAD(C, 1)
     int slmBuffers = 0;                 // # of A/B SLM buffers, 0 for none.
     int unrollKSLM = 0;                 // k unroll for SLM copies (0 = auto = unroll[LoopK]/slmCopies)
     int unrollKSLMMasked = 0;           //   Alternate value to use with masking (0 = same as unrollKSLM)
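The new flag is carved out of existing padding, so the POD's size and field offsets are unchanged: one byte of ZPAD(C, 2) becomes the bool and the pad shrinks to ZPAD(C, 1). A sketch of the idiom with a hypothetical ZPAD definition (the real macro lives elsewhere in the gemmstone headers):

    // Hypothetical padding macro: reserves n explicit, zero-initialized pad bytes.
    #define ZPAD(id, n) unsigned char zpad##id[n] = {};

    struct ExamplePOD {
        bool splitCopy = false;
        bool tlbWarmup = false; // new flag occupies a former pad byte
        ZPAD(C, 1)              // was ZPAD(C, 2) before the flag was added
        int slmBuffers = 0;     // later fields keep their offsets
    };

Keeping the layout byte-for-byte stable is the usual reason for explicit zero-initialized padding in a POD like this, e.g. so it can be compared or hashed deterministically.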
