xe: jit: gemm: TLB warmup #2631

Open · wants to merge 2 commits into main
4 changes: 2 additions & 2 deletions src/gpu/intel/jit/gemm/gemm_walk_orders.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
- * Copyright 2021-2024 Intel Corporation
+ * Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -196,7 +196,7 @@ inline void gemm_linear_order_args(compute::kernel_arg_list_t &arg_list,
arg_list.set(argn++, group_count);
}

- gws[0] = lws[0] * group_count;
+ gws[0] = lws[0] * (group_count + info.extraWGs());
gws[1] = lws[1];
}

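The dispatch-side half of the change is this one line: when the kernel was built with TLB warmup, CommonDriverInfo::extraWGs() (added in driver_info.hpp below) reports one extra workgroup, and the linear dispatch enqueues it beyond the groups needed to cover the C tile grid. A minimal sketch of the resulting host arithmetic; div_up, m, and n are illustrative names, not the verbatim driver code:

    // Hypothetical driver-side sketch.
    int64_t groups_mn = div_up(m, info.wgTile(LoopM))
                      * div_up(n, info.wgTile(LoopN));
    gws[0] = lws[0] * (groups_mn + info.extraWGs()); // +1 group iff FlagExtraWG
    gws[1] = lws[1];
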
3 changes: 2 additions & 1 deletion src/gpu/intel/jit/gemm/generator/generator.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
- * Copyright 2019-2024 Intel Corporation
+ * Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -48,6 +48,7 @@
#include "pieces/remask.cxx"
#include "pieces/row_column_sums.cxx"
#include "pieces/state_utils.cxx"
#include "pieces/tlb_warmup.cxx"
#include "pieces/walk_orders.cxx"

#include "pieces/quantization.cxx"
3 changes: 2 additions & 1 deletion src/gpu/intel/jit/gemm/generator/pieces/driver_info.cxx
@@ -1,5 +1,5 @@
/*******************************************************************************
- * Copyright 2019-2024 Intel Corporation
+ * Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -77,6 +77,7 @@ CommonDriverInfo BLASKernelGenerator<hw>::driverInfo(GEMMProblem problem, const
if (problem.alpha.pointer()) info.flags |= FlagAlphaPtr;
if (problem.beta.pointer()) info.flags |= FlagBetaPtr;
if (strategy.nondeterministic(problem)) info.flags |= FlagNondeterministic;
+ if (strategy.tlbWarmup) info.flags |= FlagExtraWG;
info.flags |= (strategy.fillGoal << FlagShiftFillGoal) & FlagMaskFillGoal;
info.slm = int(gemmSLMSize(hw, problem, strategy));
info.perKSLM = int(gemmPerKSLMSize(hw, problem, strategy));
23 changes: 21 additions & 2 deletions src/gpu/intel/jit/gemm/generator/pieces/gemm.cxx
@@ -116,6 +116,23 @@ void BLASKernelGenerator<hw>::gemm(GEMMProblem &problem, GEMMStrategy &strategy,
jmpi(1 | f1[0], lPadThread);
}

+ // Check if this is a TLB warmup thread, and perform warmup if so.
+ if (strategy.tlbWarmup) {
+     Label lNotTLBWarmup;
+     state.groupIDMN = state.ra.alloc_sub<uint32_t>(getHint(HintType::LongTerm, strategy));
+     add(1 | ge | f1[0], state.groupIDMN.d(), state.inputs.groupIDMN, -1);
+     jmpi(1 | f1[0], lNotTLBWarmup);
+     status << "TLB warmup" << status_stream::endl;
+     auto mstate = state;
+     moveR0(strategy, mstate);
+     gemmGetBatchIDs(problem, strategy, mstate);
+     gemmOffsetBatchABC(problem, strategy, mstate);
+     gemmSetupABC(problem, strategy, mstate);
+     gemmTLBWarmup(problem, strategy, mstate);
+     epilogue(strategy, mstate);
+     mark(lNotTLBWarmup);
+ }

// Scale LDs/offsets.
gemmScaleInputs(problem, strategy, state);

@@ -232,8 +249,10 @@ void BLASKernelGenerator<hw>::gemm(GEMMProblem &problem, GEMMStrategy &strategy,
if (!strategy.linearOrder()) stub();
if (problem.batch != BatchMode::None) stub(); // need to wrangle groupIDK also

- state.groupIDMN = state.ra.alloc_sub<uint32_t>(getHint(HintType::LongTerm, strategy));
- mov(1, state.groupIDMN, state.inputs.groupIDMN);
+ if (state.groupIDMN == state.inputs.groupIDMN) {
+     state.groupIDMN = state.ra.alloc_sub<uint32_t>(getHint(HintType::LongTerm, strategy));
+     mov(1, state.groupIDMN, state.inputs.groupIDMN);
+ }

if (state.effTempC == state.inputs.tempC)
state.effTempC = state.ra.alloc_sub<uint64_t>(getHint(HintType::LongTerm, strategy));
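
Kernel-side, the extra group is peeled off before any GEMM work: every thread computes groupIDMN - 1, group 0 branches to the warmup path and exits through the epilogue, and the remaining groups continue with the decremented ID. That is also why the second hunk only allocates groupIDMN when the warmup path has not already done so. A C-like sketch of the emitted control flow, assuming a linear walk order so that "group 0" is well defined:

    // Sketch of the generated control flow, not verbatim kernel code.
    int gid = raw_groupIDMN - 1;
    if (gid < 0) {                      // the extra workgroup from the driver
        setup_batch_and_ABC_pointers(); // gemmGetBatchIDs/OffsetBatchABC/SetupABC
        warm_tlb_for_all_arrays();      // gemmTLBWarmup
        return;                         // epilogue; computes no C tile
    }
    // ... remaining groups run the normal GEMM using gid ...
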
174 changes: 174 additions & 0 deletions src/gpu/intel/jit/gemm/generator/pieces/tlb_warmup.cxx
@@ -0,0 +1,174 @@
/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/


#include "generator.hpp"
#include "hw_utils.hpp"
#include "layout_utils.hpp"
#include "state_utils.hpp"
#include "ngen_object_helpers.hpp"

#include "internal/namespace_start.hxx"

using namespace ngen;
using namespace ngen::utils;
using std::vector;



template <HW hw>
void BLASKernelGenerator<hw>::gemmTLBWarmup(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state)
{
auto lid = state.ra.allocSub<uint32_t>();
int whose = 0;

emad(1, lid, state.inputs.localIDM, state.inputs.localIDN, strategy.wg[LoopM], strategy, state);
if (strategy.kParallelLocal)
emad(1, lid, lid, state.inputs.localIDK, strategy.wg[LoopM] * strategy.wg[LoopN], strategy, state);

if (problem.quantized2DA()) {
auto mq = state.ra.allocSub<uint32_t>();
auto kq = state.ra.allocSub<uint32_t>();
divDown(mq, state.inputs.m, problem.aqGroupM, strategy, state);
divDown(kq, state.inputs.k, problem.aqGroupK, strategy, state);
if (problem.aScale2D) {
tlbWarmup(problem.Ta_scale, problem.A_scale, strategy.A_scale, state.inputs.aScalePtr,
mq, kq, state.inputs.ldaq, lid, whose++, problem, strategy, state);
}
if (problem.aoPtrDims == 2) {
tlbWarmup(problem.Tao, problem.AO, strategy.AO, state.inputs.aoPtr,
mq, kq, state.inputs.ldaq, lid, whose++, problem, strategy, state);
}
state.ra.safeRelease(mq);
state.ra.safeRelease(kq);
}

if (problem.quantized2DB()) {
auto kq = state.ra.allocSub<uint32_t>();
auto nq = state.ra.allocSub<uint32_t>();
divDown(kq, state.inputs.k, problem.bqGroupK, strategy, state);
divDown(nq, state.inputs.n, problem.bqGroupN, strategy, state);
if (problem.bScale2D) {
tlbWarmup(problem.Tb_scale, problem.B_scale, strategy.B_scale, state.inputs.bScalePtr,
kq, nq, state.inputs.ldbq, lid, whose++, problem, strategy, state);
}
if (problem.boPtrDims == 2) {
tlbWarmup(problem.Tbo, problem.BO, strategy.BO, state.inputs.boPtr,
kq, nq, state.inputs.ldbq, lid, whose++, problem, strategy, state);
}
state.ra.safeRelease(kq);
state.ra.safeRelease(nq);
}

tlbWarmup(problem.Ta_ext, problem.A, strategy.A, state.effA,
state.inputs.m, state.inputs.k, state.inputs.lda, lid, whose++,
problem, strategy, state);
tlbWarmup(problem.Tb_ext, problem.B, strategy.B, state.effB,
state.inputs.k, state.inputs.n, state.inputs.ldb, lid, whose++,
problem, strategy, state);

state.ra.safeRelease(lid);
}
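
gemmTLBWarmup spreads the arrays across the warmup workgroup's threads: the local IDs are flattened into lid, and each tlbWarmup call receives a distinct whose index, so only the thread with lid == whose touches that array. With 2D quantization on both sides, up to six arrays (A/B scales, A/B zero points, A, B) are warmed concurrently. An illustrative C sketch of the assignment; the indices shift down when quantization arrays are absent:

    int lid = localIDM + wg_m * (localIDN + wg_n * localIDK); // flattened local ID
    int whose = 0;
    if (lid == whose++) warm(A_scales);       // only when A is 2D-quantized
    if (lid == whose++) warm(A_zero_points);
    if (lid == whose++) warm(B_scales);       // only when B is 2D-quantized
    if (lid == whose++) warm(B_zero_points);
    if (lid == whose++) warm(A);
    if (lid == whose++) warm(B);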

template <HW hw>
void BLASKernelGenerator<hw>::tlbWarmup(Type T, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy,
const Subregister &ptr, const Subregister &r, const Subregister &c,
const Subregister &ld, const Subregister &lid, int whose,
const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state)
{
auto flag = state.raVFlag.alloc();
const uint32_t byteLimit = 256 * 1024 * 1024;

auto bytes = state.ra.allocSub<uint64_t>();
emul(1, bytes, ld, isColMajor(atype.layout) ? c : r, strategy, state);
cmp(1 | nz | flag, bytes.ud(1), 0);
cmp(1 | ~flag | gt | flag, bytes.ud(), byteLimit / T);
emulConstant(1, bytes.ud(), bytes.ud(), T, strategy, state);
mov(1 | flag, bytes.ud(), byteLimit);

state.raVFlag.safeRelease(flag);

tlbWarmup(astrategy.base, ptr, bytes.ud(), lid, whose, problem, strategy, state);

state.ra.safeRelease(bytes);
}
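
This overload computes how many bytes to warm: the element count is ld times the outer dimension, scaled to bytes and clamped to 256 MB. The pair of cmp instructions performs the clamp without a 64-bit compare: saturate if the high dword of the product is nonzero, or if the low dword already exceeds byteLimit divided by the element size (checked before the multiply, so the scaling cannot overflow). A scalar equivalent for reference; T_size stands for the element size in bytes:

    uint64_t elems = (uint64_t)ld * (col_major ? c : r);
    uint64_t bytes = elems * T_size;
    const uint32_t byteLimit = 256u << 20;    // warm at most 256 MB
    if (bytes > byteLimit) bytes = byteLimit;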

template <HW hw>
void BLASKernelGenerator<hw>::tlbWarmup(AddressBase base, const Subregister &ptr, const Subregister &bytes,
const Subregister &lid, int whose,
const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state)
{
bool a64 = base.isA64();
auto Taddr = a64 ? DataType::uq : DataType::ud;
const int simd = elementsPerGRF<uint32_t>(hw);
const int log2Stride = 16; // 64kb stride.
const int log2TwiddleStride = 6;

int udStride = a64 ? 2 : 1;
auto addr = state.ra.allocRange(udStride);
auto addr0 = addr[0].retype(Taddr);
auto addrLo = addr0.ud(0)(udStride);
auto off = state.ra.allocRange(udStride);
auto off0 = off[0].ud(0)(udStride);
auto twiddle = state.ra.alloc().ud();
auto data = state.ra.alloc().ud();
auto count = state.ra.alloc().d();
auto flag = state.raVFlag.alloc();

extendIndexVec(simd, state);

auto iv = accessIndexVec(0, state)(1);

cmp(1 | nz | flag, lid, whose); /* Check if we are responsible thread */

shl(simd, off0, iv, log2Stride);
shl(simd, twiddle, iv, log2TwiddleStride);
eadd(simd, addr0, ptr, off0, strategy, state);
xor_(simd, addrLo, addrLo, twiddle); /* Perturb low bits to avoid cache hotspotting */

add(1, count, bytes, ((simd + 1) << log2Stride) - 1);
shr(1, count, count, log2Stride);
add(simd, count, count[0], -iv);

Label lTop, lSkip;
jmpi(1 | flag, lSkip);

mark(lTop);
add(simd | gt | flag, count, count, -simd);
if (hw >= HW::XeHPC)
load(simd | flag, null, D8U32 | L1C_L3C, base, addr);
else if (hw >= HW::XeHPG)
load(simd | flag, data, D8U32 | L1C_L3C, base, addr);
else
load(simd | flag, data, scattered_byte(), base, addr);
xor_(simd, addrLo, addrLo, twiddle);
add(simd, twiddle, twiddle, simd << log2TwiddleStride);
and_(simd, twiddle, twiddle, 0xFFF); /* Don't cross 4K page boundaries */
eadd(simd, addr0, addr0, simd << log2Stride, strategy, state);
xor_(simd, addrLo, addrLo, twiddle);
jmpi(1 | flag, lTop);
mark(lSkip);

releaseIndexVec(state);
state.raVFlag.safeRelease(flag);
state.ra.safeRelease(off);
state.ra.safeRelease(twiddle);
state.ra.safeRelease(addr);
state.ra.safeRelease(data);
state.ra.safeRelease(count);
}
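
The loop touches one cached byte every 64 KB (1 << log2Stride) across the full SIMD width, so each iteration covers simd * 64 KB with a single gathered load (written to the null register on XeHPC+, which accepts null-destination loads). The XOR "twiddle" perturbs address bits 6-11 so successive touches do not all land at offset 0 of their 64 KB block; it advances each iteration and is masked to 0xFFF so the perturbation stays inside one 4 KB page. Per lane, the address stream is roughly:

    // C sketch of the per-lane address pattern (lane = 0 .. simd-1).
    uint32_t twiddle = lane << 6;
    for (uint64_t off = (uint64_t)lane << 16; off < bytes; off += (uint64_t)simd << 16) {
        touch_one_byte((base + off) ^ twiddle);    // D8U32 load, cached in L1/L3
        twiddle = (twiddle + (simd << 6)) & 0xFFF; // advance, stay in a 4 KB page
    }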

#include "internal/namespace_end.hxx"
3 changes: 3 additions & 0 deletions src/gpu/intel/jit/gemm/generator/strategy.cpp
@@ -182,6 +182,9 @@ void GEMMStrategy::preflight(HW hw, const GEMMProblem &problem)

extendedAtomicFMA &= !problem.needsASums() && !problem.needsBSums();

+ if (tlbWarmup && !linearOrder())
+     cWalkOrder = WalkOrder::SimpleLinear;

// Default SIMD setting.
if (fmaSIMD == 0) {
fmaSIMD = std::min(32, 2 * GRF::bytes(hw) / std::max<int>({Ta.paddedSize(), Tb.paddedSize(), Tc.paddedSize()}));
3 changes: 3 additions & 0 deletions src/gpu/intel/jit/gemm/generator/strategy_parser.cpp
@@ -379,6 +379,8 @@ void parseStrategy(const char *str, HW hw, const GEMMProblem &problem, GEMMStrat
strategy.reverse[LoopM] = true;
else if (mod == "rn")
strategy.reverse[LoopN] = true;
else if (mod == "wt")
strategy.tlbWarmup = true;
else if (mod == "kb" || mod == "kv") {
if (mod == "kb") strategy.kParallel = true;
if (mod == "kv") {
@@ -886,6 +888,7 @@ std::string unparseStrategy(HW hw, const GEMMProblem &problem, const GEMMStrateg
if (strategy.panelCheck) s << " up";
if (strategy.reverse[LoopM]) s << " rm";
if (strategy.reverse[LoopN]) s << " rn";
+ if (strategy.tlbWarmup) s << " wt";

if (strategy.checkAdd32 && !strategy.emulate.emulate64) s << " ch";
if (!strategy.checkAdd32 && strategy.emulate.emulate64) s << " nch";
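
The new "wt" modifier round-trips like its neighbors: parsing sets strategy.tlbWarmup, and unparsing appends " wt". A usage sketch against the two signatures shown above, assuming a default-constructed GEMMStrategy is a valid starting point here:

    GEMMStrategy strategy;
    parseStrategy("wt", hw, problem, strategy);              // tlbWarmup = true
    std::string s = unparseStrategy(hw, problem, strategy);  // contains " wt"
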
4 changes: 3 additions & 1 deletion src/gpu/intel/jit/gemm/include/driver_info.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
- * Copyright 2019-2024 Intel Corporation
+ * Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -69,6 +69,7 @@ enum DriverInfoFlags : uint32_t {
FlagNondeterministic = 0x4000, // Kernel produces nondeterministic results.
FlagMaskFillGoal = 0xF0000, // Fraction of available thread slots to fill, in sixteenths
FlagShiftFillGoal = 16, // (starting bit)
+ FlagExtraWG = 0x400000, // Add an additional workgroup.
};

// Driver information, shared by all kernel types.
@@ -116,6 +117,7 @@ struct CommonDriverInfo {
bool betaPtr() const { return flags & FlagBetaPtr; }
bool fixedWGK() const { return flags & FlagFixedWGK; }
bool nondeterministic() const { return flags & FlagNondeterministic; }
+ int extraWGs() const { return (flags & FlagExtraWG) ? 1 : 0; }

int wgTile(LoopType l) const { return unroll[l] * wg[l]; }
int kPadding() const { return (kParallel() || kParallelVariable()) ? blockingAlt[LoopK] : 0; }
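
For reference, the new flag occupies bit 22, clear of the fill-goal field in bits 16-19, and extraWGs() is the accessor that the dispatch code in gemm_walk_orders.hpp folds into the global size. Decoding sketch, derived from the constants above:

    //   FlagMaskFillGoal = 0x000F0000  -> bits 16..19
    //   FlagExtraWG      = 0x00400000  -> bit 22
    int extra = (info.flags & FlagExtraWG) ? 1 : 0; // == info.extraWGs()
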
5 changes: 5 additions & 0 deletions src/gpu/intel/jit/gemm/include/generator.hpp
@@ -465,6 +465,11 @@ class BLASKernelGenerator : public GENERATOR_BASE(hw) {
bool gemmFusedPostOpsFinalize(ngen::Label &labelLateExit, GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
void gemmRedirectToTempC(GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);

+ // tlb_warmup.cxx
+ void tlbWarmup(ngen::AddressBase base, const ngen::Subregister &ptr, const ngen::Subregister &bytes, const ngen::Subregister &lid, int whose, const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state);
+ void tlbWarmup(Type T, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy, const ngen::Subregister &base, const ngen::Subregister &r, const ngen::Subregister &c, const ngen::Subregister &ld, const ngen::Subregister &lid, int whose, const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state);
+ void gemmTLBWarmup(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);

// gemm_setup.cpp
void gemmCheck32(const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
void gemmGetBatchIDs(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);
3 changes: 2 additions & 1 deletion src/gpu/intel/jit/gemm/include/strategy.hpp
@@ -212,7 +212,8 @@ struct GEMMStrategyPOD : public CommonStrategy {
bool kDescRem = false; // Allow descriptor-based k remainder handling for A/B.
bool slmA = false, slmB = false; // Whether to copy A/B to SLM.
bool splitCopy = false; // Separate SLM copy and compute threads?
- ZPAD(C, 2)
+ bool tlbWarmup = false; // Enable TLB warmup?
+ ZPAD(C, 1)
int slmBuffers = 0; // # of A/B SLM buffers, 0 for none.
int unrollKSLM = 0; // k unroll for SLM copies (0 = auto = unroll[LoopK]/slmCopies)
int unrollKSLMMasked = 0; // Alternate value to use with masking (0 = same as unrollKSLM)