Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/hopper-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,12 @@ jobs:
fi
# flagtree tle
# python tutorials
python3 python/tutorials/tle/01-sparse-mla.py
python3 python/tutorials/tle/01-fft.py
python3 python/tutorials/tle/02-moe_align_block_size.py
python3 python/tutorials/tle/03-topk.py
python3 python/tutorials/tle/04-fft.py
python3 python/tutorials/tle/05-deepseek_v32_topk_selector.py
python3 python/tutorials/tle/06-cluster-gemm.py
python3 python/tutorials/tle/04-cluster-gemm.py
python3 python/tutorials/tle/deepseek_v32/01-topk_selector.py
python3 python/tutorials/tle/deepseek_v32/02-sparse-mla.py
# python unit test
python3 -m pytest -s python/test/tle/integration
python3 -m pytest -s python/test/tle/unit
Expand Down
10 changes: 10 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H

#include "triton/Conversion/MLIRTypes.h"
#include <optional>

namespace mlir::triton {
enum class ProgramIDDim : uint32_t;
Expand Down Expand Up @@ -66,6 +67,15 @@ class TargetInfoBase {
unsigned numLaneToReduce,
unsigned interleave) const = 0;

#ifdef __TLE__
  /// Optional target hook: reduce an i1 predicate with a boolean OR across
  /// the whole CTA. The base implementation reports "unsupported" (an empty
  /// optional) so generic lowerings fall back to the standard reduction path;
  /// targets that have a cheap CTA-wide vote/ballot primitive override this.
  virtual std::optional<Value>
  ctaReduceOrPredicate(RewriterBase &rewriter, Location loc, Value pred) const {
    return {};
  }
#endif

virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
// Emits LLVM code with |rewriter| to print a message following the given
// format from the device. |formatStrStart| is the pointer to the start of
Expand Down
21 changes: 21 additions & 0 deletions lib/Analysis/Alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,23 @@

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Support/LLVM.h"
#ifdef __TLE__
#include "triton/Dialect/Triton/IR/Types.h"
#endif
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace mlir {

#ifdef __TLE__
// Returns true when `type` is a triton pointer, either as a scalar
// triton::PointerType or as the element type of a ranked tensor of pointers.
static bool isTritonPtrLikeType(Type type) {
  Type scalarTy = type;
  if (auto rankedTy = dyn_cast<RankedTensorType>(type))
    scalarTy = rankedTy.getElementType();
  return isa<triton::PointerType>(scalarTy);
}
#endif

AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
if (lhs == rhs)
return lhs;
Expand Down Expand Up @@ -45,6 +58,14 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation(
// Treat local pointer views as aliases of their source memdesc.
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (isTritonPtrLikeType(result.getType())) {
// Propagate aliases through pointer-producing/view-like ops such as
// tt.splat/tt.broadcast/tt.addptr chains so shared buffers stay live
// across pointer arithmetic users.
aliasInfo = AliasInfo();
for (auto *operand : operands)
aliasInfo = AliasInfo::join(aliasInfo, operand->getValue());
pessimistic = false;
#endif
} else if (isa<ub::PoisonOp>(op)) {
aliasInfo = AliasInfo();
Expand Down
54 changes: 54 additions & 0 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1156,12 +1156,34 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber,
}

// Seeds one AxisInfo dimension vector (contiguity/divisibility/constancy)
// from a user-provided hint attribute. Scalar IntegerAttr hints describe a
// rank-1 vector; DenseElementsAttr hints provide one value per dimension.
// Unrecognized attribute kinds leave `vec` untouched.
void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
#ifdef __TLE__
  // Guard against callers passing a null output vector.
  if (!vec)
    return;
  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr)) {
    // Scalar hints are only valid for rank-1 vectors. Ignore mismatched hints
    // to avoid shrinking rank information (which can cause out-of-bounds axis
    // queries later in vectorization analysis).
    if (vec->size() == 1 || vec->empty())
      *vec = DimVectorT(1, intAttr.getValue().getZExtValue());
    return;
  }
  if (auto dense_attr = dyn_cast_or_null<DenseElementsAttr>(attr)) {
    // Materialize the per-dimension hint values first so we can validate the
    // element count against the existing rank before overwriting `vec`.
    SmallVector<int64_t> vals;
    vals.reserve(dense_attr.getNumElements());
    for (APInt v : dense_attr.getValues<APInt>())
      vals.push_back(v.getSExtValue());
    // Only accept the hint when it matches the current rank (or when no rank
    // has been established yet); a mismatched hint is silently dropped.
    if (vec->empty() || vals.size() == vec->size())
      *vec = DimVectorT(vals.begin(), vals.end());
    return;
  }
#else
  // Upstream behavior: apply hints unconditionally, with no rank validation.
  if (auto int_attr = dyn_cast_or_null<IntegerAttr>(attr))
    *vec = DimVectorT(1, int_attr.getValue().getZExtValue());
  if (auto dense_attr = dyn_cast_or_null<DenseElementsAttr>(attr)) {
    auto vals = dense_attr.getValues<int>();
    *vec = DimVectorT(vals.begin(), vals.end());
  }
#endif
}

/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
Expand Down Expand Up @@ -1221,7 +1243,17 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
return rhs;
if (rhs.getRank() == 0)
return lhs;
#ifdef __TLE__
if (lhs.getRank() != rhs.getRank()) {
// Be conservative when malformed/mismatched hints have polluted rank
// information. Prefer correctness and robustness over optimistic metadata.
const int rank = std::max(lhs.getRank(), rhs.getRank());
return AxisInfo(DimVectorT(rank, 1), DimVectorT(rank, 1),
DimVectorT(rank, 1));
}
#else
assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks");
#endif
DimVectorT contiguity;
DimVectorT divisibility;
DimVectorT constancy;
Expand Down Expand Up @@ -1258,11 +1290,20 @@ unsigned ModuleAxisInfoAnalysis::getContiguity(Value offsetsValue,
auto tensorTy = cast<RankedTensorType>(offsetsValue.getType());
auto linAttr = gpu::toLinearEncoding(tensorTy);
auto order = linAttr.getOrder();
#ifdef __TLE__
if (order.empty())
return 1;
#endif
unsigned align = getAlignment(offsetsValue, elementBitWidth);

auto uniqueContigPerThread = linAttr.getContigPerThread();
#ifdef __TLE__
if (order[0] >= uniqueContigPerThread.size())
return align;
#else
assert(order[0] < uniqueContigPerThread.size() &&
"Unexpected uniqueContigPerThread size");
#endif
unsigned contiguity = uniqueContigPerThread[order[0]];
LDBG("getContiguity uniqueContigPerThread = " << contiguity);
contiguity = std::min(align, contiguity);
Expand Down Expand Up @@ -1291,6 +1332,13 @@ unsigned ModuleAxisInfoAnalysis::getAlignment(Value offsetsValue,
return 1;
auto linAttr = gpu::toLinearEncoding(tensorTy);
auto order = linAttr.getOrder();
#ifdef __TLE__
if (order.empty())
return 1;
if (order[0] >= axisInfo->getRank()) {
return 1;
}
#endif

auto divisibility = axisInfo->getDivisibility(order[0]);
auto elemNumBytes = std::max<unsigned>(elementBitWidth / 8, 1);
Expand Down Expand Up @@ -1323,6 +1371,12 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
return 1;
auto linAttr = gpu::toLinearEncoding(tensorTy);
auto maskOrder = linAttr.getOrder();
#ifdef __TLE__
if (maskOrder.empty())
return 1;
if (maskOrder[0] >= axisInfo->getRank())
return 1;
#endif
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
Expand Down
6 changes: 6 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,14 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
auto promotedOperands = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter);
#ifdef __TLE__
// The call-site offset is sufficient to derive the callee frame base.
// Kernel callers do not carry `allocation.offset` themselves.
if (!callOp->hasAttr("allocation.offset")) {
#else
if (!caller->hasAttr("allocation.offset") ||
!callOp->hasAttr("allocation.offset")) {
#endif
auto base = LLVM::getStackPointer(rewriter, caller);
promotedOperands.push_back(base);
} else {
Expand Down
51 changes: 51 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "ReduceScanCommon.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include <optional>

using namespace mlir;
using namespace mlir::triton;
Expand Down Expand Up @@ -31,6 +33,13 @@ struct ReduceOpConversion
"Unexpected srcLayout in ReduceOpConversion");
Location loc = op->getLoc();

#ifdef __TLE__
if (succeeded(
tryLowerCtaOrReductionFastpath(op, adaptor, helper, rewriter))) {
return success();
}
#endif

auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
Expand Down Expand Up @@ -79,6 +88,48 @@ struct ReduceOpConversion
private:
const TargetInfoBase &targetInfo;

#ifdef __TLE__
bool isOrReductionToScalarI1(triton::ReduceOp op) const {
if (op.getNumOperands() != 1 || op.getNumResults() != 1)
return false;
if (!op.getResult()[0].getType().isInteger(1))
return false;
auto *combine = op.getSingleCombiner();
return combine && isa<arith::OrIOp>(combine);
}

  // Attempts to lower a boolean OR ("any") reduction through the target's
  // CTA-wide predicate-OR fastpath instead of the generic shared-memory
  // reduction. Returns failure() when the fastpath does not apply so the
  // caller falls back to the standard lowering.
  // NOTE(review): `unpackInputs` emits IR via `rewriter` before the two
  // later failure() returns; this relies on the dialect-conversion driver
  // rolling back (or DCE cleaning up) those ops when the pattern reports
  // failure — TODO confirm this is acceptable here.
  LogicalResult
  tryLowerCtaOrReductionFastpath(triton::ReduceOp op, OpAdaptor adaptor,
                                 ReduceOpHelper &helper,
                                 ConversionPatternRewriter &rewriter) const {
    // Warp-synchronous reductions already have a cheap lowering; only take
    // the fastpath for genuine cross-warp boolean OR reductions.
    if (helper.isWarpSynchronous() || !isOrReductionToScalarI1(op))
      return failure();

    Location loc = op.getLoc();
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
    // Fold all per-thread elements into one i1 "any element set" predicate.
    std::optional<Value> threadAny;
    for (unsigned i = 0; i < srcValues.size(); ++i) {
      Value pred = srcValues[i][0];
      if (!pred.getType().isInteger(1)) {
        // Normalize wider integers to i1 via `value != 0`.
        Type predTy = pred.getType();
        pred = b.icmp_ne(pred, b.int_val(predTy.getIntOrFloatBitWidth(), 0));
      }
      threadAny = threadAny ? b.or_(*threadAny, pred) : pred;
    }
    if (!threadAny)
      return failure();

    // Ask the target to OR the per-thread predicate across the CTA; an empty
    // optional means the target has no fastpath and we must fall back.
    std::optional<Value> ctaAny =
        targetInfo.ctaReduceOrPredicate(rewriter, loc, *threadAny);
    if (!ctaAny)
      return failure();

    rewriter.replaceOp(op, *ctaAny);
    return success();
  }
#endif

void accumulate(Location loc, ConversionPatternRewriter &rewriter,
Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
Value pred = {}) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,7 @@ void populateTleRawPatterns(TritonGPUTypeConverter &typeConverter,
.add<TleDSLRegionOpPattern, TleExtractTileOpPattern,
TleInsertTileOpPattern, GenericOpPattern<tle::LocalPointersOp>,
GenericOpPattern<tle::RemotePointersOp>,
GenericOpPattern<tle::ExclusiveCumsumOp>,
GenericOpPattern<tle::DistributedBarrierOp>,
GenericOpPattern<tle::YieldOp>,
GenericOpPattern<tle::ExtractAllocatedPtrOp>,
Expand Down
26 changes: 26 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,32 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
if (newInnerCvtEnc == cvtEncoding)
return failure();
rewriter.setInsertionPoint(trans);
#ifdef __TLE__
// If the source is already loaded from a shared memdesc allocated in this
// function, directly retag that memdesc encoding and avoid inserting an
// additional local_alloc staging buffer.
if (auto srcLocalLoad = trans.getSrc().getDefiningOp<LocalLoadOp>()) {
Value srcMemDesc = srcLocalLoad.getSrc();
auto srcMemDescTy = dyn_cast<MemDescType>(srcMemDesc.getType());
if (srcMemDescTy && srcMemDescTy.getShape() == srcTy.getShape() &&
srcMemDescTy.getElementType() == srcTy.getElementType() &&
srcMemDesc.getDefiningOp<LocalAllocOp>()) {
auto updatedMemDescTy = MemDescType::get(
srcMemDescTy.getShape(), srcMemDescTy.getElementType(),
newInnerCvtEnc, srcMemDescTy.getMemorySpace(),
srcMemDescTy.getMutableMemory(), srcMemDescTy.getAllocShape());
srcMemDesc.setType(updatedMemDescTy);
auto newTrans = rewriter.create<MemDescTransOp>(
trans.getLoc(), srcMemDesc, ArrayRef<int32_t>({1, 0}));
auto localLoadOp = rewriter.create<LocalLoadOp>(
trans.getLoc(), sharedLoadTy, newTrans, srcLocalLoad.getToken());
rewriter.modifyOpInPlace(cvtOp, [&]() {
cvtOp.getSrcMutable().assign(localLoadOp.getResult());
});
return success();
}
}
#endif
auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext());
auto alloc = LocalAllocOp::create(
rewriter, trans.getLoc(),
Expand Down
4 changes: 0 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ namespace gpu {

#ifdef __TLE__
static bool isLikelyRemotePtr(Value ptr) {
constexpr StringLiteral kRemoteShardCarrierAttr =
"tle.remote_shard_id_carrier";
SmallVector<Value> worklist{ptr};
DenseSet<Value> visited;
while (!worklist.empty()) {
Expand All @@ -41,8 +39,6 @@ static bool isLikelyRemotePtr(Value ptr) {
if (def->getName().getStringRef() == "tle.remote_pointers")
return true;
if (auto addPtr = dyn_cast<triton::AddPtrOp>(def)) {
if (addPtr->hasAttr(kRemoteShardCarrierAttr))
return true;
worklist.push_back(addPtr.getPtr());
worklist.push_back(addPtr.getOffset());
continue;
Expand Down
17 changes: 8 additions & 9 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,27 @@ namespace mlir::triton::gpu {
namespace {

#ifdef __TLE__
static bool touchesTleDistributedPointerPath(Value value,
DenseSet<Value> &visited) {
static bool touchesTleRemotePointerPath(Value value, DenseSet<Value> &visited) {
if (!visited.insert(value).second)
return false;
Operation *def = value.getDefiningOp();
if (!def)
return false;
StringRef opName = def->getName().getStringRef();
if (opName == "tle.local_pointers" || opName == "tle.remote_pointers")
if (opName == "tle.remote_pointers")
return true;
if (auto ifOp = dyn_cast<scf::IfOp>(def)) {
auto result = dyn_cast<OpResult>(value);
if (!result)
return false;
unsigned idx = result.getResultNumber();
return touchesTleDistributedPointerPath(ifOp.thenYield().getOperand(idx),
visited) ||
touchesTleDistributedPointerPath(ifOp.elseYield().getOperand(idx),
visited);
return touchesTleRemotePointerPath(ifOp.thenYield().getOperand(idx),
visited) ||
touchesTleRemotePointerPath(ifOp.elseYield().getOperand(idx),
visited);
}
for (Value operand : def->getOperands()) {
if (touchesTleDistributedPointerPath(operand, visited))
if (touchesTleRemotePointerPath(operand, visited))
return true;
}
return false;
Expand Down Expand Up @@ -1313,7 +1312,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
#ifdef __TLE__
{
DenseSet<Value> visited;
if (touchesTleDistributedPointerPath(convertOp.getSrc(), visited))
if (touchesTleRemotePointerPath(convertOp.getSrc(), visited))
return;
}
#endif
Expand Down
Loading
Loading