Skip to content

Commit

Permalink
xe: jit: gemm: TLB warmup support
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad committed Feb 8, 2025
1 parent c07fe05 commit 65fd254
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/gpu/intel/jit/gemm/gemm_walk_orders.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2024 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -196,7 +196,7 @@ inline void gemm_linear_order_args(compute::kernel_arg_list_t &arg_list,
arg_list.set(argn++, group_count);
}

gws[0] = lws[0] * group_count;
gws[0] = lws[0] * (group_count + info.extraWGs());
gws[1] = lws[1];
}

Expand Down
3 changes: 2 additions & 1 deletion src/gpu/intel/jit/gemm/generator/generator.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -48,6 +48,7 @@
#include "pieces/remask.cxx"
#include "pieces/row_column_sums.cxx"
#include "pieces/state_utils.cxx"
#include "pieces/tlb_warmup.cxx"
#include "pieces/walk_orders.cxx"

#include "pieces/quantization.cxx"
Expand Down
3 changes: 2 additions & 1 deletion src/gpu/intel/jit/gemm/generator/pieces/driver_info.cxx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -77,6 +77,7 @@ CommonDriverInfo BLASKernelGenerator<hw>::driverInfo(GEMMProblem problem, const
if (problem.alpha.pointer()) info.flags |= FlagAlphaPtr;
if (problem.beta.pointer()) info.flags |= FlagBetaPtr;
if (strategy.nondeterministic(problem)) info.flags |= FlagNondeterministic;
if (strategy.tlbWarmup) info.flags |= FlagExtraWG;
info.flags |= (strategy.fillGoal << FlagShiftFillGoal) & FlagMaskFillGoal;
info.slm = int(gemmSLMSize(hw, problem, strategy));
info.perKSLM = int(gemmPerKSLMSize(hw, problem, strategy));
Expand Down
23 changes: 21 additions & 2 deletions src/gpu/intel/jit/gemm/generator/pieces/gemm.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,23 @@ void BLASKernelGenerator<hw>::gemm(GEMMProblem &problem, GEMMStrategy &strategy,
jmpi(1 | f1[0], lPadThread);
}

// Check if this is a TLB warmup thread, and perform warmup if so.
if (strategy.tlbWarmup) {
Label lNotTLBWarmup;
state.groupIDMN = state.ra.alloc_sub<uint32_t>(getHint(HintType::LongTerm, strategy));
add(1 | ge | f1[0], state.groupIDMN.d(), state.inputs.groupIDMN, -1);
jmpi(1 | f1[0], lNotTLBWarmup);
status << "TLB warmup" << status_stream::endl;
auto mstate = state;
moveR0(strategy, mstate);
gemmGetBatchIDs(problem, strategy, mstate);
gemmOffsetBatchABC(problem, strategy, mstate);
gemmSetupABC(problem, strategy, mstate);
gemmTLBWarmup(problem, strategy, mstate);
epilogue(strategy, mstate);
mark(lNotTLBWarmup);
}

// Scale LDs/offsets.
gemmScaleInputs(problem, strategy, state);

Expand Down Expand Up @@ -232,8 +249,10 @@ void BLASKernelGenerator<hw>::gemm(GEMMProblem &problem, GEMMStrategy &strategy,
if (!strategy.linearOrder()) stub();
if (problem.batch != BatchMode::None) stub(); // need to wrangle groupIDK also

state.groupIDMN = state.ra.alloc_sub<uint32_t>(getHint(HintType::LongTerm, strategy));
mov(1, state.groupIDMN, state.inputs.groupIDMN);
if (state.groupIDMN == state.inputs.groupIDMN) {
state.groupIDMN = state.ra.alloc_sub<uint32_t>(getHint(HintType::LongTerm, strategy));
mov(1, state.groupIDMN, state.inputs.groupIDMN);
}

if (state.effTempC == state.inputs.tempC)
state.effTempC = state.ra.alloc_sub<uint64_t>(getHint(HintType::LongTerm, strategy));
Expand Down
174 changes: 174 additions & 0 deletions src/gpu/intel/jit/gemm/generator/pieces/tlb_warmup.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/


#include "generator.hpp"
#include "hw_utils.hpp"
#include "layout_utils.hpp"
#include "state_utils.hpp"
#include "ngen_object_helpers.hpp"

#include "internal/namespace_start.hxx"

using namespace ngen;
using namespace ngen::utils;
using std::vector;



// Emit TLB warmup code for all global-memory matrices the GEMM kernel will
//   touch: the A/B matrices themselves and, when 2D quantization is enabled,
//   their scale and zero-point (offset) arrays.
// Each matrix is warmed by a different thread of the warmup workgroup: the
//   running counter `whose` assigns one matrix per linear local thread ID.
template <HW hw>
void BLASKernelGenerator<hw>::gemmTLBWarmup(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state)
{
    auto lid = state.ra.allocSub<uint32_t>();
    int whose = 0;

    // Flatten (localIDM, localIDN[, localIDK]) into a single linear local ID
    //   used to pick the responsible thread for each warmup below.
    emad(1, lid, state.inputs.localIDM, state.inputs.localIDN, strategy.wg[LoopM], strategy, state);
    if (strategy.kParallelLocal)
        emad(1, lid, lid, state.inputs.localIDK, strategy.wg[LoopM] * strategy.wg[LoopN], strategy, state);

    if (problem.quantized2DA()) {
        // A-side quantization data is (m / aqGroupM) x (k / aqGroupK), with
        //   leading dimension ldaq (rounded down: partial groups not stored).
        auto mq = state.ra.allocSub<uint32_t>();
        auto kq = state.ra.allocSub<uint32_t>();
        divDown(mq, state.inputs.m, problem.aqGroupM, strategy, state);
        divDown(kq, state.inputs.k, problem.aqGroupK, strategy, state);
        if (problem.aScale2D) {
            tlbWarmup(problem.Ta_scale, problem.A_scale, strategy.A_scale, state.inputs.aScalePtr,
                      mq, kq, state.inputs.ldaq, lid, whose++, problem, strategy, state);
        }
        if (problem.aoPtrDims == 2) {
            tlbWarmup(problem.Tao, problem.AO, strategy.AO, state.inputs.aoPtr,
                      mq, kq, state.inputs.ldaq, lid, whose++, problem, strategy, state);
        }
        state.ra.safeRelease(mq);
        state.ra.safeRelease(kq);
    }

    if (problem.quantized2DB()) {
        // B-side quantization data is (k / bqGroupK) x (n / bqGroupN), with
        //   leading dimension ldbq.
        auto kq = state.ra.allocSub<uint32_t>();
        auto nq = state.ra.allocSub<uint32_t>();
        divDown(kq, state.inputs.k, problem.bqGroupK, strategy, state);
        divDown(nq, state.inputs.n, problem.bqGroupN, strategy, state);
        if (problem.bScale2D) {
            tlbWarmup(problem.Tb_scale, problem.B_scale, strategy.B_scale, state.inputs.bScalePtr,
                      kq, nq, state.inputs.ldbq, lid, whose++, problem, strategy, state);
        }
        if (problem.boPtrDims == 2) {
            tlbWarmup(problem.Tbo, problem.BO, strategy.BO, state.inputs.boPtr,
                      kq, nq, state.inputs.ldbq, lid, whose++, problem, strategy, state);
        }
        state.ra.safeRelease(kq);
        state.ra.safeRelease(nq);
    }

    // Warm up the main A (m x k, leading dimension lda) and B (k x n,
    //   leading dimension ldb) matrices last, each on its own thread.
    tlbWarmup(problem.Ta_ext, problem.A, strategy.A, state.effA,
              state.inputs.m, state.inputs.k, state.inputs.lda, lid, whose++,
              problem, strategy, state);
    tlbWarmup(problem.Tb_ext, problem.B, strategy.B, state.effB,
              state.inputs.k, state.inputs.n, state.inputs.ldb, lid, whose++,
              problem, strategy, state);

    state.ra.safeRelease(lid);
}

// TLB warmup for a single r x c matrix of element type T with leading
//   dimension ld (in elements). Computes the matrix extent in bytes,
//   clamped to a 256MB cap, and forwards to the address-range overload.
// `whose` selects which thread (by linear local ID `lid`) performs the work.
template <HW hw>
void BLASKernelGenerator<hw>::tlbWarmup(Type T, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy,
                                        const Subregister &ptr, const Subregister &r, const Subregister &c,
                                        const Subregister &ld, const Subregister &lid, int whose,
                                        const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state)
{
    auto flag = state.raVFlag.alloc();
    const uint32_t byteLimit = 256 * 1024 * 1024;   // warm at most 256MB per matrix

    auto bytes = state.ra.allocSub<uint64_t>();
    // Element count = ld * (outer dimension): columns if column-major, rows otherwise.
    emul(1, bytes, ld, isColMajor(atype.layout) ? c : r, strategy, state);
    // flag := element count overflows 32 bits (high dword nonzero)...
    cmp(1 | nz | flag, bytes.ud(1), 0);
    // ...or (if not) exceeds the byte limit expressed in elements.
    cmp(1 | ~flag | gt | flag, bytes.ud(), byteLimit / T);
    // Convert elements -> bytes; then clamp to the cap where flag is set.
    emulConstant(1, bytes.ud(), bytes.ud(), T, strategy, state);
    mov(1 | flag, bytes.ud(), byteLimit);

    state.raVFlag.safeRelease(flag);

    tlbWarmup(astrategy.base, ptr, bytes.ud(), lid, whose, problem, strategy, state);

    state.ra.safeRelease(bytes);
}

// TLB warmup over a raw address range [ptr, ptr + bytes): issue sparse loads
//   at 64KB intervals so each page's translation is brought into the TLB,
//   touching only one small load per page.
// Only the thread whose linear local ID `lid` equals `whose` executes the
//   loop; all other threads branch around it.
template <HW hw>
void BLASKernelGenerator<hw>::tlbWarmup(AddressBase base, const Subregister &ptr, const Subregister &bytes,
                                        const Subregister &lid, int whose,
                                        const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state)
{
    bool a64 = base.isA64();
    auto Taddr = a64 ? DataType::uq : DataType::ud;         // 64-bit or 32-bit addresses
    const int simd = elementsPerGRF<uint32_t>(hw);          // lanes per warmup load
    const int log2Stride = 16;                              // 64kb stride.
    const int log2TwiddleStride = 6;

    int udStride = a64 ? 2 : 1;                             // dwords per address element
    auto addr = state.ra.allocRange(udStride);
    auto addr0 = addr[0].retype(Taddr);
    auto addrLo = addr0.ud(0)(udStride);                    // low dword of each address
    auto off = state.ra.allocRange(udStride);
    auto off0 = off[0].ud(0)(udStride);
    auto twiddle = state.ra.alloc().ud();
    auto data = state.ra.alloc().ud();
    auto count = state.ra.alloc().d();
    auto flag = state.raVFlag.alloc();

    extendIndexVec(simd, state);

    auto iv = accessIndexVec(0, state)(1);                  // lane indices 0..simd-1

    cmp(1 | nz | flag, lid, whose);                         /* Check if we are responsible thread */

    // Per-lane setup: lane i starts at ptr + (i << 16), with a small per-lane
    //   XOR twiddle on the low address bits.
    shl(simd, off0, iv, log2Stride);
    shl(simd, twiddle, iv, log2TwiddleStride);
    eadd(simd, addr0, ptr, off0, strategy, state);
    xor_(simd, addrLo, addrLo, twiddle);                    /* Perturb low bits to avoid cache hotspotting */

    // count = number of 64KB pages to touch (rounded up), biased so that the
    //   per-lane pre-decrement in the loop below yields the right trip count.
    add(1, count, bytes, ((simd + 1) << log2Stride) - 1);
    shr(1, count, count, log2Stride);
    add(simd, count, count[0], -iv);

    Label lTop, lSkip;
    jmpi(1 | flag, lSkip);                                  // not our thread -> skip warmup

    mark(lTop);
    // Decrement per-lane counters; flag lanes that still have pages to touch.
    add(simd | gt | flag, count, count, -simd);
    // One small load per page; prefetch-style (null destination) where supported.
    if (hw >= HW::XeHPC)
        load(simd | flag, null, D8U32 | L1C_L3C, base, addr);
    else if (hw >= HW::XeHPG)
        load(simd | flag, data, D8U32 | L1C_L3C, base, addr);
    else
        load(simd | flag, data, scattered_byte(), base, addr);
    // Advance each lane by simd pages: strip the old twiddle, step the
    //   twiddle pattern, bump the address, and re-apply the new twiddle.
    xor_(simd, addrLo, addrLo, twiddle);
    add(simd, twiddle, twiddle, simd << log2TwiddleStride);
    and_(simd, twiddle, twiddle, 0xFFF);                    /* Don't cross 4K page boundaries */
    eadd(simd, addr0, addr0, simd << log2Stride, strategy, state);
    xor_(simd, addrLo, addrLo, twiddle);
    jmpi(1 | flag, lTop);                                   // loop while any lane has work
    mark(lSkip);

    releaseIndexVec(state);
    state.raVFlag.safeRelease(flag);
    state.ra.safeRelease(off);
    state.ra.safeRelease(twiddle);
    state.ra.safeRelease(addr);
    state.ra.safeRelease(data);
    state.ra.safeRelease(count);
}

#include "internal/namespace_end.hxx"
3 changes: 3 additions & 0 deletions src/gpu/intel/jit/gemm/generator/strategy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ void GEMMStrategy::preflight(HW hw, const GEMMProblem &problem)

extendedAtomicFMA &= !problem.needsASums() && !problem.needsBSums();

if (tlbWarmup && !linearOrder())
cWalkOrder = WalkOrder::SimpleLinear;

// Default SIMD setting.
if (fmaSIMD == 0) {
fmaSIMD = std::min(32, 2 * GRF::bytes(hw) / std::max<int>({Ta.paddedSize(), Tb.paddedSize(), Tc.paddedSize()}));
Expand Down
3 changes: 3 additions & 0 deletions src/gpu/intel/jit/gemm/generator/strategy_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ void parseStrategy(const char *str, HW hw, const GEMMProblem &problem, GEMMStrat
strategy.reverse[LoopM] = true;
else if (mod == "rn")
strategy.reverse[LoopN] = true;
else if (mod == "wt")
strategy.tlbWarmup = true;
else if (mod == "kb" || mod == "kv") {
if (mod == "kb") strategy.kParallel = true;
if (mod == "kv") {
Expand Down Expand Up @@ -886,6 +888,7 @@ std::string unparseStrategy(HW hw, const GEMMProblem &problem, const GEMMStrateg
if (strategy.panelCheck) s << " up";
if (strategy.reverse[LoopM]) s << " rm";
if (strategy.reverse[LoopN]) s << " rn";
if (strategy.tlbWarmup) s << " wt";

if (strategy.checkAdd32 && !strategy.emulate.emulate64) s << " ch";
if (!strategy.checkAdd32 && strategy.emulate.emulate64) s << " nch";
Expand Down
4 changes: 3 additions & 1 deletion src/gpu/intel/jit/gemm/include/driver_info.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -69,6 +69,7 @@ enum DriverInfoFlags : uint32_t {
FlagNondeterministic = 0x4000, // Kernel produces nondeterministic results.
FlagMaskFillGoal = 0xF0000, // Fraction of available thread slots to fill, in sixteenths
FlagShiftFillGoal = 16, // (starting bit)
FlagExtraWG = 0x400000, // Add an additional workgroup.
};

// Driver information, shared by all kernel types.
Expand Down Expand Up @@ -116,6 +117,7 @@ struct CommonDriverInfo {
bool betaPtr() const { return flags & FlagBetaPtr; }
bool fixedWGK() const { return flags & FlagFixedWGK; }
bool nondeterministic() const { return flags & FlagNondeterministic; }
int extraWGs() const { return (flags & FlagExtraWG) ? 1 : 0; }

int wgTile(LoopType l) const { return unroll[l] * wg[l]; }
int kPadding() const { return (kParallel() || kParallelVariable()) ? blockingAlt[LoopK] : 0; }
Expand Down
5 changes: 5 additions & 0 deletions src/gpu/intel/jit/gemm/include/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,11 @@ class BLASKernelGenerator : public GENERATOR_BASE(hw) {
bool gemmFusedPostOpsFinalize(ngen::Label &labelLateExit, GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
void gemmRedirectToTempC(GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);

// tlb_warmup.cxx
void tlbWarmup(ngen::AddressBase base, const ngen::Subregister &ptr, const ngen::Subregister &bytes, const ngen::Subregister &lid, int whose, const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state);
void tlbWarmup(Type T, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy, const ngen::Subregister &base, const ngen::Subregister &r, const ngen::Subregister &c, const ngen::Subregister &ld, const ngen::Subregister &lid, int whose, const CommonProblem &problem, const CommonStrategy &strategy, CommonState &state);
void gemmTLBWarmup(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);

// gemm_setup.cpp
void gemmCheck32(const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
void gemmGetBatchIDs(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);
Expand Down
3 changes: 2 additions & 1 deletion src/gpu/intel/jit/gemm/include/strategy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ struct GEMMStrategyPOD : public CommonStrategy {
bool kDescRem = false; // Allow descriptor-based k remainder handling for A/B.
bool slmA = false, slmB = false; // Whether to copy A/B to SLM.
bool splitCopy = false; // Separate SLM copy and compute threads?
ZPAD(C, 2)
bool tlbWarmup = false; // Enable TLB warmup?
ZPAD(C, 1)
int slmBuffers = 0; // # of A/B SLM buffers, 0 for none.
int unrollKSLM = 0; // k unroll for SLM copies (0 = auto = unroll[LoopK]/slmCopies)
int unrollKSLMMasked = 0; // Alternate value to use with masking (0 = same as unrollKSLM)
Expand Down

0 comments on commit 65fd254

Please sign in to comment.