diff --git a/.github/workflows/hopper-build-and-test.yml b/.github/workflows/hopper-build-and-test.yml index 8ee10ede8c..20becfb0de 100644 --- a/.github/workflows/hopper-build-and-test.yml +++ b/.github/workflows/hopper-build-and-test.yml @@ -149,12 +149,12 @@ jobs: fi # flagtree tle # python tutorials - python3 python/tutorials/tle/01-sparse-mla.py + python3 python/tutorials/tle/01-fft.py python3 python/tutorials/tle/02-moe_align_block_size.py python3 python/tutorials/tle/03-topk.py - python3 python/tutorials/tle/04-fft.py - python3 python/tutorials/tle/05-deepseek_v32_topk_selector.py - python3 python/tutorials/tle/06-cluster-gemm.py + python3 python/tutorials/tle/04-cluster-gemm.py + python3 python/tutorials/tle/deepseek_v32/01-topk_selector.py + python3 python/tutorials/tle/deepseek_v32/02-sparse-mla.py # python unit test python3 -m pytest -s python/test/tle/integration python3 -m pytest -s python/test/tle/unit diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 37a2f7fbc1..4f782b5036 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -2,6 +2,7 @@ #define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H #include "triton/Conversion/MLIRTypes.h" +#include namespace mlir::triton { enum class ProgramIDDim : uint32_t; @@ -66,6 +67,15 @@ class TargetInfoBase { unsigned numLaneToReduce, unsigned interleave) const = 0; +#ifdef __TLE__ + // Optional fastpath for CTA-wide boolean OR reduction. + // Returns std::nullopt when unsupported so callers can fall back. + virtual std::optional + ctaReduceOrPredicate(RewriterBase &rewriter, Location loc, Value pred) const { + return std::nullopt; + } +#endif + virtual std::string getMulhiFuncName(Type resultElementTy) const = 0; // Emits LLVM code with |rewriter| to print a message following the given // format from the device. 
|formatStrStart| is the pointer to the start of diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp index e5ff03a937..045d559ac2 100644 --- a/lib/Analysis/Alias.cpp +++ b/lib/Analysis/Alias.cpp @@ -2,10 +2,23 @@ #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Support/LLVM.h" +#ifdef __TLE__ +#include "triton/Dialect/Triton/IR/Types.h" +#endif #include "triton/Dialect/TritonGPU/IR/Dialect.h" namespace mlir { +#ifdef __TLE__ +static bool isTritonPtrLikeType(Type type) { + if (isa(type)) + return true; + if (auto tensorTy = dyn_cast(type)) + return isa(tensorTy.getElementType()); + return false; +} +#endif + AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { if (lhs == rhs) return lhs; @@ -45,6 +58,14 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation( // Treat local pointer views as aliases of their source memdesc. aliasInfo = AliasInfo(operands[0]->getValue()); pessimistic = false; + } else if (isTritonPtrLikeType(result.getType())) { + // Propagate aliases through pointer-producing/view-like ops such as + // tt.splat/tt.broadcast/tt.addptr chains so shared buffers stay live + // across pointer arithmetic users. + aliasInfo = AliasInfo(); + for (auto *operand : operands) + aliasInfo = AliasInfo::join(aliasInfo, operand->getValue()); + pessimistic = false; #endif } else if (isa(op)) { aliasInfo = AliasInfo(); diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 159392f174..3c261d9fd2 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1156,12 +1156,34 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, } void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { +#ifdef __TLE__ + if (!vec) + return; + if (auto intAttr = dyn_cast_or_null(attr)) { + // Scalar hints are only valid for rank-1 vectors. Ignore mismatched hints + // to avoid shrinking rank information (which can cause out-of-bounds axis + // queries later in vectorization analysis). 
+ if (vec->size() == 1 || vec->empty()) + *vec = DimVectorT(1, intAttr.getValue().getZExtValue()); + return; + } + if (auto dense_attr = dyn_cast_or_null(attr)) { + SmallVector vals; + vals.reserve(dense_attr.getNumElements()); + for (APInt v : dense_attr.getValues()) + vals.push_back(v.getSExtValue()); + if (vec->empty() || vals.size() == vec->size()) + *vec = DimVectorT(vals.begin(), vals.end()); + return; + } +#else if (auto int_attr = dyn_cast_or_null(attr)) *vec = DimVectorT(1, int_attr.getValue().getZExtValue()); if (auto dense_attr = dyn_cast_or_null(attr)) { auto vals = dense_attr.getValues(); *vec = DimVectorT(vals.begin(), vals.end()); } +#endif } /*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) { @@ -1221,7 +1243,17 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { return rhs; if (rhs.getRank() == 0) return lhs; +#ifdef __TLE__ + if (lhs.getRank() != rhs.getRank()) { + // Be conservative when malformed/mismatched hints have polluted rank + // information. Prefer correctness and robustness over optimistic metadata. 
+ const int rank = std::max(lhs.getRank(), rhs.getRank()); + return AxisInfo(DimVectorT(rank, 1), DimVectorT(rank, 1), + DimVectorT(rank, 1)); + } +#else assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks"); +#endif DimVectorT contiguity; DimVectorT divisibility; DimVectorT constancy; @@ -1258,11 +1290,20 @@ unsigned ModuleAxisInfoAnalysis::getContiguity(Value offsetsValue, auto tensorTy = cast(offsetsValue.getType()); auto linAttr = gpu::toLinearEncoding(tensorTy); auto order = linAttr.getOrder(); +#ifdef __TLE__ + if (order.empty()) + return 1; +#endif unsigned align = getAlignment(offsetsValue, elementBitWidth); auto uniqueContigPerThread = linAttr.getContigPerThread(); +#ifdef __TLE__ + if (order[0] >= uniqueContigPerThread.size()) + return align; +#else assert(order[0] < uniqueContigPerThread.size() && "Unexpected uniqueContigPerThread size"); +#endif unsigned contiguity = uniqueContigPerThread[order[0]]; LDBG("getContiguity uniqueContigPerThread = " << contiguity); contiguity = std::min(align, contiguity); @@ -1291,6 +1332,13 @@ unsigned ModuleAxisInfoAnalysis::getAlignment(Value offsetsValue, return 1; auto linAttr = gpu::toLinearEncoding(tensorTy); auto order = linAttr.getOrder(); +#ifdef __TLE__ + if (order.empty()) + return 1; + if (order[0] >= axisInfo->getRank()) { + return 1; + } +#endif auto divisibility = axisInfo->getDivisibility(order[0]); auto elemNumBytes = std::max(elementBitWidth / 8, 1); @@ -1323,6 +1371,12 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { return 1; auto linAttr = gpu::toLinearEncoding(tensorTy); auto maskOrder = linAttr.getOrder(); +#ifdef __TLE__ + if (maskOrder.empty()) + return 1; + if (maskOrder[0] >= axisInfo->getRank()) + return 1; +#endif auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " << alignment); diff --git a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp 
b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp index f33cb37cbf..6e39bc067f 100644 --- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -84,8 +84,14 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { auto promotedOperands = this->getTypeConverter()->promoteOperands( callOp.getLoc(), /*opOperands=*/callOp->getOperands(), adaptor.getOperands(), rewriter); +#ifdef __TLE__ + // The call-site offset is sufficient to derive the callee frame base. + // Kernel callers do not carry `allocation.offset` themselves. + if (!callOp->hasAttr("allocation.offset")) { +#else if (!caller->hasAttr("allocation.offset") || !callOp->hasAttr("allocation.offset")) { +#endif auto base = LLVM::getStackPointer(rewriter, caller); promotedOperands.push_back(base); } else { diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index a17526f102..e77292e74f 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -1,8 +1,10 @@ #include "ReduceScanCommon.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include using namespace mlir; using namespace mlir::triton; @@ -31,6 +33,13 @@ struct ReduceOpConversion "Unexpected srcLayout in ReduceOpConversion"); Location loc = op->getLoc(); +#ifdef __TLE__ + if (succeeded( + tryLowerCtaOrReductionFastpath(op, adaptor, helper, rewriter))) { + return success(); + } +#endif + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); std::map, SmallVector> accs; std::map, SmallVector> indices; @@ -79,6 +88,48 @@ struct ReduceOpConversion private: const TargetInfoBase &targetInfo; +#ifdef __TLE__ + bool isOrReductionToScalarI1(triton::ReduceOp op) 
const { + if (op.getNumOperands() != 1 || op.getNumResults() != 1) + return false; + if (!op.getResult()[0].getType().isInteger(1)) + return false; + auto *combine = op.getSingleCombiner(); + return combine && isa(combine); + } + + LogicalResult + tryLowerCtaOrReductionFastpath(triton::ReduceOp op, OpAdaptor adaptor, + ReduceOpHelper &helper, + ConversionPatternRewriter &rewriter) const { + if (helper.isWarpSynchronous() || !isOrReductionToScalarI1(op)) + return failure(); + + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::optional threadAny; + for (unsigned i = 0; i < srcValues.size(); ++i) { + Value pred = srcValues[i][0]; + if (!pred.getType().isInteger(1)) { + Type predTy = pred.getType(); + pred = b.icmp_ne(pred, b.int_val(predTy.getIntOrFloatBitWidth(), 0)); + } + threadAny = threadAny ? b.or_(*threadAny, pred) : pred; + } + if (!threadAny) + return failure(); + + std::optional ctaAny = + targetInfo.ctaReduceOrPredicate(rewriter, loc, *threadAny); + if (!ctaAny) + return failure(); + + rewriter.replaceOp(op, *ctaAny); + return success(); + } +#endif + void accumulate(Location loc, ConversionPatternRewriter &rewriter, Region &combineOp, SmallVector &acc, ValueRange cur, Value pred = {}) const { diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 02372c68bc..b49871486d 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -906,6 +906,7 @@ void populateTleRawPatterns(TritonGPUTypeConverter &typeConverter, .add, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 9306a1c1c2..3c480d14eb 100644 --- 
a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -66,6 +66,32 @@ class SwizzleShmemConvert : public OpRewritePattern { if (newInnerCvtEnc == cvtEncoding) return failure(); rewriter.setInsertionPoint(trans); +#ifdef __TLE__ + // If the source is already loaded from a shared memdesc allocated in this + // function, directly retag that memdesc encoding and avoid inserting an + // additional local_alloc staging buffer. + if (auto srcLocalLoad = trans.getSrc().getDefiningOp()) { + Value srcMemDesc = srcLocalLoad.getSrc(); + auto srcMemDescTy = dyn_cast(srcMemDesc.getType()); + if (srcMemDescTy && srcMemDescTy.getShape() == srcTy.getShape() && + srcMemDescTy.getElementType() == srcTy.getElementType() && + srcMemDesc.getDefiningOp()) { + auto updatedMemDescTy = MemDescType::get( + srcMemDescTy.getShape(), srcMemDescTy.getElementType(), + newInnerCvtEnc, srcMemDescTy.getMemorySpace(), + srcMemDescTy.getMutableMemory(), srcMemDescTy.getAllocShape()); + srcMemDesc.setType(updatedMemDescTy); + auto newTrans = rewriter.create( + trans.getLoc(), srcMemDesc, ArrayRef({1, 0})); + auto localLoadOp = rewriter.create( + trans.getLoc(), sharedLoadTy, newTrans, srcLocalLoad.getToken()); + rewriter.modifyOpInPlace(cvtOp, [&]() { + cvtOp.getSrcMutable().assign(localLoadOp.getResult()); + }); + return success(); + } + } +#endif auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext()); auto alloc = LocalAllocOp::create( rewriter, trans.getLoc(), diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index d0558b75e8..e641e2a078 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -27,8 +27,6 @@ namespace gpu { #ifdef __TLE__ static bool isLikelyRemotePtr(Value ptr) { - constexpr StringLiteral kRemoteShardCarrierAttr = - 
"tle.remote_shard_id_carrier"; SmallVector worklist{ptr}; DenseSet visited; while (!worklist.empty()) { @@ -41,8 +39,6 @@ static bool isLikelyRemotePtr(Value ptr) { if (def->getName().getStringRef() == "tle.remote_pointers") return true; if (auto addPtr = dyn_cast(def)) { - if (addPtr->hasAttr(kRemoteShardCarrierAttr)) - return true; worklist.push_back(addPtr.getPtr()); worklist.push_back(addPtr.getOffset()); continue; diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index d288faed52..5ae73f69e1 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -38,28 +38,27 @@ namespace mlir::triton::gpu { namespace { #ifdef __TLE__ -static bool touchesTleDistributedPointerPath(Value value, - DenseSet &visited) { +static bool touchesTleRemotePointerPath(Value value, DenseSet &visited) { if (!visited.insert(value).second) return false; Operation *def = value.getDefiningOp(); if (!def) return false; StringRef opName = def->getName().getStringRef(); - if (opName == "tle.local_pointers" || opName == "tle.remote_pointers") + if (opName == "tle.remote_pointers") return true; if (auto ifOp = dyn_cast(def)) { auto result = dyn_cast(value); if (!result) return false; unsigned idx = result.getResultNumber(); - return touchesTleDistributedPointerPath(ifOp.thenYield().getOperand(idx), - visited) || - touchesTleDistributedPointerPath(ifOp.elseYield().getOperand(idx), - visited); + return touchesTleRemotePointerPath(ifOp.thenYield().getOperand(idx), + visited) || + touchesTleRemotePointerPath(ifOp.elseYield().getOperand(idx), + visited); } for (Value operand : def->getOperands()) { - if (touchesTleDistributedPointerPath(operand, visited)) + if (touchesTleRemotePointerPath(operand, visited)) return true; } return false; @@ -1313,7 +1312,7 @@ void LayoutRematerialization::hoistConvertDotOperand( #ifdef __TLE__ { 
DenseSet visited; - if (touchesTleDistributedPointerPath(convertOp.getSrc(), visited)) + if (touchesTleRemotePointerPath(convertOp.getSrc(), visited)) return; } #endif diff --git a/python/test/tle/integration/test_tle_distributed.py b/python/test/tle/integration/test_tle_distributed.py index 1470a5955d..78d5f51018 100644 --- a/python/test/tle/integration/test_tle_distributed.py +++ b/python/test/tle/integration/test_tle_distributed.py @@ -182,13 +182,29 @@ def _remote_const_shard_vectorized_load_kernel( @triton.jit -def _remote_pointer_input_disallowed_kernel(out_ptr, mesh: tl.constexpr, BLOCK: tl.constexpr): +def _remote_pointer_input_allowed_kernel(out_ptr, mesh: tl.constexpr, BLOCK: tl.constexpr): offs = tl.arange(0, BLOCK) + pid = tl.program_id(0) smem = tle.gpu.alloc([BLOCK], dtype=tl.float16, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) local_ptr = tle.gpu.local_ptr(smem, (offs, )) + vals = tl.cast(offs + pid * BLOCK, tl.float16) + tl.store(local_ptr, vals) + tle.distributed_barrier(mesh) remote_ptr = tle.remote(local_ptr, 0, scope=mesh) vals = tl.load(remote_ptr) - tl.store(out_ptr + offs, vals) + tl.store(out_ptr + pid * BLOCK + offs, vals) + + +@triton.jit +def _remote_pointer_scalar_input_allowed_kernel(out_ptr, mesh: tl.constexpr): + pid = tl.program_id(0) + smem = tle.gpu.alloc([1], dtype=tl.float16, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + local_scalar_ptr = tle.gpu.local_ptr(smem, (0, )) + tl.store(local_scalar_ptr, tl.cast(pid, tl.float16)) + tle.distributed_barrier(mesh) + remote_scalar_ptr = tle.remote(local_scalar_ptr, 0, scope=mesh) + val = tl.load(remote_scalar_ptr) + tl.store(out_ptr + pid, val) @triton.jit @@ -277,6 +293,60 @@ def _remote_rank0_dsmem_atomic_add_kernel(out_ptr, mesh: tl.constexpr): tl.store(out_ptr + idx, counter) +@triton.jit +def _remote_rank0_dsmem_scalar_ptr_atomic_add_kernel(out_ptr, mesh: tl.constexpr): + rank = tle.shard_id(mesh, "cluster_x") + smem = tle.gpu.alloc([1], 
dtype=tl.int32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + local_scalar_ptr = tle.gpu.local_ptr(smem, (0, )) + + if rank == 0: + tl.store(local_scalar_ptr, 0) + tle.distributed_barrier(mesh) + + remote_rank0 = tle.remote(smem, 0, scope=mesh) + remote_scalar_ptr = tle.gpu.local_ptr(remote_rank0, (0, )) + tl.atomic_add(remote_scalar_ptr, 1, sem="relaxed", scope="cta") + tle.distributed_barrier(mesh) + + if rank == 0: + counter = tl.load(local_scalar_ptr) + tl.store(out_ptr, counter) + + +@triton.jit +def _remote_rank0_dsmem_buffer_vs_ptr_remote_atomic_add_kernel(out_ptr, mesh: tl.constexpr, BLOCK: tl.constexpr): + rank = tle.shard_id(mesh, "cluster_x") + zeros = tl.zeros((BLOCK, ), dtype=tl.int32) + ones = tl.full((BLOCK, ), 1, tl.int32) + + smem = tle.gpu.alloc([2], dtype=tl.int32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + local_counter0_ptr = tle.gpu.local_ptr(smem, (0, )) + local_counter1_ptr = tle.gpu.local_ptr(smem, (1, )) + + if rank == 0: + tl.store(local_counter0_ptr, 0) + tl.store(local_counter1_ptr, 0) + tle.distributed_barrier(mesh) + + # Buffer-level remote + local_ptr path. + remote_rank0_buffer = tle.remote(smem, 0, scope=mesh) + remote_counter0_ptrs = tle.gpu.local_ptr(remote_rank0_buffer, (zeros, )) + + # Pointer-level remote path with derived addptr tensor pointer. 
+ remote_counter1_scalar_ptr = tle.remote(local_counter1_ptr, 0, scope=mesh) + remote_counter1_ptrs = remote_counter1_scalar_ptr + zeros + + tl.atomic_add(remote_counter0_ptrs, ones, sem="relaxed", scope="cta") + tl.atomic_add(remote_counter1_ptrs, ones, sem="relaxed", scope="cta") + tle.distributed_barrier(mesh) + + if rank == 0: + counter0 = tl.load(local_counter0_ptr) + counter1 = tl.load(local_counter1_ptr) + tl.store(out_ptr + 0, counter0) + tl.store(out_ptr + 1, counter1) + + @triton.jit def _remote_scan_shared_scratch_kernel(out_ptr, mesh: tl.constexpr, BLOCK: tl.constexpr): rank = tle.shard_id(mesh, "cluster_x") @@ -754,8 +824,7 @@ def test_remote_read_peer_smem_same_cluster(self): num_warps=4, ) assert compiled.metadata.cluster_dims == (2, 1, 1) - assert ("tle.remote_pointers" in compiled.asm["ttgir"]) or ("tle.remote_shard_id_carrier" - in compiled.asm["ttgir"]) + assert "tle.remote_pointers" in compiled.asm["ttgir"] assert "\"ttg.num-ctas\" = 1" in compiled.asm["ttgir"] assert "mapa.shared::cluster" in compiled.asm["ptx"] @@ -907,15 +976,15 @@ def test_remote_const_shard_load_high_block_encoding_no_regression(self): ptx = compiled.asm["ptx"] # local/remote pointer tensors should keep the load-friendly encoding. - assert "tensor<256x!tt.ptr, #blocked>" in ttgir + assert re.search(r"tensor<256x!tt\.ptr,\s*#blocked[0-9]*>", ttgir) is not None assert re.search( r"\"tle\.remote_pointers\"\(%[^,]+,\s*%[^)]+\)\s*:\s*" - r"\(tensor<256x!tt\.ptr,\s*#blocked>,\s*i32\)\s*->\s*" - r"tensor<256x!tt\.ptr,\s*#blocked>", + r"\(tensor<256x!tt\.ptr,\s*#blocked[0-9]*>,\s*i32\)\s*->\s*" + r"tensor<256x!tt\.ptr,\s*#blocked[0-9]*>", ttgir, ) is not None # Avoid re-introducing a degraded blocked1 -> blocked convert path. 
- assert "ttg.convert_layout %remote_ptr : tensor<256x!tt.ptr, #blocked1> -> tensor<256x!tt.ptr, #blocked>" not in ttgir + assert "ttg.convert_layout %remote_ptr : tensor<256x!tt.ptr, #blocked1> -> tensor<256x!tt.ptr, #blocked>" not in ttgir assert "ld.shared::cluster.b16" in ptx def test_remote_const_shard_vectorized_load_lowering_same_cluster(self): @@ -954,17 +1023,61 @@ def test_remote_const_shard_vectorized_load_lowering_same_cluster(self): expected = torch.cat([expected_chunk, expected_chunk], dim=0) torch.testing.assert_close(out, expected, atol=0.0, rtol=0.0) - def test_remote_pointer_input_disallowed(self): - out = torch.empty((32, ), device="cuda", dtype=torch.float16) - with pytest.raises(Exception, match="only accepts tle.buffered_tensor"): - _remote_pointer_input_disallowed_kernel.warmup( - out, - mesh=BLOCK_CLUSTER_MESH, - BLOCK=32, - grid=(1, ), - num_ctas=1, - num_warps=4, - ) + def test_remote_pointer_input_allowed(self): + block = 32 + grid = 1 + cluster_size = 2 + out = torch.empty((grid * cluster_size * block, ), device="cuda", dtype=torch.float16) + + compiled = _remote_pointer_input_allowed_kernel.warmup( + out, + mesh=BLOCK_CLUSTER_MESH, + BLOCK=block, + grid=(grid, ), + num_ctas=1, + num_warps=4, + ) + ttgir = compiled.asm["ttgir"] + assert "\"tle.remote_pointers\"" in ttgir + + _remote_pointer_input_allowed_kernel[(grid, )]( + out, + mesh=BLOCK_CLUSTER_MESH, + BLOCK=block, + num_ctas=1, + num_warps=4, + ) + torch.cuda.synchronize() + + expected_chunk = torch.arange(0, block, device="cuda", dtype=torch.float16) + expected = torch.cat([expected_chunk, expected_chunk], dim=0) + torch.testing.assert_close(out, expected, atol=0.0, rtol=0.0) + + def test_remote_pointer_scalar_input_allowed(self): + grid = 1 + cluster_size = 2 + out = torch.empty((grid * cluster_size, ), device="cuda", dtype=torch.float16) + + compiled = _remote_pointer_scalar_input_allowed_kernel.warmup( + out, + mesh=BLOCK_CLUSTER_MESH, + grid=(grid, ), + num_ctas=1, + 
num_warps=4, + ) + ttgir = compiled.asm["ttgir"] + assert "\"tle.remote_pointers\"" in ttgir + + _remote_pointer_scalar_input_allowed_kernel[(grid, )]( + out, + mesh=BLOCK_CLUSTER_MESH, + num_ctas=1, + num_warps=4, + ) + torch.cuda.synchronize() + + expected = torch.zeros_like(out) + torch.testing.assert_close(out, expected, atol=0.0, rtol=0.0) def test_remote_buffer_const_shard_vectorized_load_lowering_same_cluster(self): block_m = 32 @@ -1090,6 +1203,76 @@ def test_remote_rank0_dsmem_atomic_add_runtime_cluster8_stable(self): torch.cuda.synchronize() assert int(out.item()) == 8 + def test_remote_rank0_dsmem_scalar_ptr_atomic_add_lowering_cluster8(self): + out = torch.empty((1, ), device="cuda", dtype=torch.int32) + compiled = _remote_rank0_dsmem_scalar_ptr_atomic_add_kernel.warmup( + out, + mesh=BLOCK_CLUSTER_MESH_8, + grid=(1, ), + num_ctas=1, + num_warps=4, + ) + assert compiled.metadata.cluster_dims == (8, 1, 1) + ptx = compiled.asm["ptx"] + assert "atom.shared::cluster.cta.relaxed.add.u32" in ptx + assert "atom.shared.shared::cluster" not in ptx + + def test_remote_rank0_dsmem_scalar_ptr_atomic_add_runtime_cluster8_stable(self): + out = torch.empty((1, ), device="cuda", dtype=torch.int32) + for _ in range(512): + _remote_rank0_dsmem_scalar_ptr_atomic_add_kernel[(1, )]( + out, + mesh=BLOCK_CLUSTER_MESH_8, + num_ctas=1, + num_warps=4, + ) + torch.cuda.synchronize() + assert int(out.item()) == 8 + + def test_remote_rank0_dsmem_buffer_vs_ptr_remote_atomic_add_lowering_cluster8(self): + block = 128 + out = torch.empty((2, ), device="cuda", dtype=torch.int32) + compiled = _remote_rank0_dsmem_buffer_vs_ptr_remote_atomic_add_kernel.warmup( + out, + mesh=BLOCK_CLUSTER_MESH_8, + BLOCK=block, + grid=(1, ), + num_ctas=1, + num_warps=4, + ) + assert compiled.metadata.cluster_dims == (8, 1, 1) + ptx = compiled.asm["ptx"] + assert "atom.shared::cluster.cta.relaxed.add.u32" in ptx + + @pytest.mark.parametrize("num_warps", [16, 32]) + def 
test_remote_rank0_dsmem_buffer_vs_ptr_remote_atomic_add_runtime_cluster8_stable(self, num_warps): + block = num_warps * 32 + expected = 8 * block + out = torch.empty((2, ), device="cuda", dtype=torch.int32) + + compiled = _remote_rank0_dsmem_buffer_vs_ptr_remote_atomic_add_kernel.warmup( + out, + mesh=BLOCK_CLUSTER_MESH_8, + BLOCK=block, + grid=(1, ), + num_ctas=1, + num_warps=num_warps, + ) + assert compiled.metadata.cluster_dims == (8, 1, 1) + assert "atom.shared::cluster.cta.relaxed.add.u32" in compiled.asm["ptx"] + + for _ in range(128): + _remote_rank0_dsmem_buffer_vs_ptr_remote_atomic_add_kernel[(1, )]( + out, + mesh=BLOCK_CLUSTER_MESH_8, + BLOCK=block, + num_ctas=1, + num_warps=num_warps, + ) + torch.cuda.synchronize() + assert int(out[0].item()) == expected + assert int(out[1].item()) == expected + def test_remote_scan_shared_scratch_compile_regression_cluster8(self): block = 64 out = torch.empty((8 * block, ), device="cuda", dtype=torch.int32) diff --git a/python/test/tle/integration/test_tle_topk_smem_fallback.py b/python/test/tle/integration/test_tle_topk_smem_fallback.py new file mode 100644 index 0000000000..31aa9f29bf --- /dev/null +++ b/python/test/tle/integration/test_tle_topk_smem_fallback.py @@ -0,0 +1,125 @@ +import importlib.util +from pathlib import Path + +import pytest +import torch +import triton +import triton.language as tl +import triton.experimental.tle.language as tle + + +def _load_topk_module(): + repo_root = Path(__file__).resolve().parents[4] + mod_path = repo_root / "python" / "tutorials" / "tle" / "deepseek_v32" / "01-topk_selector.py" + spec = importlib.util.spec_from_file_location("tle_topk_selector_tutorial", mod_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +_TOPK_MOD = _load_topk_module() + + +def _recall(pred: torch.Tensor, ref: torch.Tensor) -> float: + pred_set = set(pred[0].cpu().tolist()) + ref_set = set(ref[0].cpu().tolist()) + return len(pred_set & ref_set) 
/ ref.shape[1] + + +@triton.jit +def _fallback_only_kernel( + x_ptr, + out_ptr, + stride_xm, + stride_xn, + stride_outm, + stride_outn, + seq_len, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + row_ptr = x_ptr + pid * stride_xm + out_row = out_ptr + pid * stride_outm + hist = tle.gpu.alloc([4096], dtype=tl.int32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + write_cnt = tle.gpu.alloc([1], dtype=tl.int32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + eq_cnt = tle.gpu.alloc([1], dtype=tl.int32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + _TOPK_MOD._tle_topk_smem_overflow_fallback_fullscan( + row_ptr, + out_row, + stride_xn, + stride_outn, + tl.zeros((), dtype=tl.int32), + seq_len, + seq_len, + tle.gpu.local_ptr(hist, (0, )), + tle.gpu.local_ptr(write_cnt, (0, )), + tle.gpu.local_ptr(eq_cnt, (0, )), + TOPK=TOPK, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_tle_topk_smem_recall_seq262144(): + torch.manual_seed(1) + seq_len = 262144 + topk = 2048 + x = torch.randn((1, seq_len), device=_TOPK_MOD.DEVICE, dtype=torch.float32) + starts = torch.zeros((1, ), device=_TOPK_MOD.DEVICE, dtype=torch.int32) + ends = torch.full((1, ), seq_len, device=_TOPK_MOD.DEVICE, dtype=torch.int32) + ref = torch.topk(x, topk, dim=-1)[1] + + smem_out = _TOPK_MOD.tle_topk_selector_smem( + x, + starts, + ends, + topk, + block_size=1024, + assume_aligned=True, + ) + assert _recall(smem_out, ref) == 1.0 + + if _TOPK_MOD._supports_tle_cluster_remote(): + cluster_out = _TOPK_MOD.tle_topk_selector_smem_cluster( + x, + starts, + ends, + topk, + block_size=1024, + assume_aligned=True, + ) + assert _recall(cluster_out, ref) == 1.0 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("num_warps", [16, 32]) +def test_tle_topk_fallback_fullscan_stable_high_warps(num_warps): + torch.manual_seed(1) + 
seq_len = 262144 + topk = 2048 + x = torch.randn((1, seq_len), device=_TOPK_MOD.DEVICE, dtype=torch.float32) + ref = torch.topk(x, topk, dim=-1)[1] + + outputs = [] + for _ in range(3): + out = torch.full((1, topk), -1, device=_TOPK_MOD.DEVICE, dtype=torch.int32) + _fallback_only_kernel[(1, )]( + x, + out, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + seq_len=seq_len, + TOPK=topk, + BLOCK_SIZE=1024, + num_warps=num_warps, + num_stages=1, + ) + outputs.append(out.clone()) + assert _recall(out, ref) == 1.0 + + out_sets = [set(o[0].cpu().tolist()) for o in outputs] + assert out_sets[0] == out_sets[1] + assert out_sets[1] == out_sets[2] diff --git a/python/test/tle/unit/test_tle.py b/python/test/tle/unit/test_tle.py index acda9e52b9..8fbd7f56bf 100644 --- a/python/test/tle/unit/test_tle.py +++ b/python/test/tle/unit/test_tle.py @@ -153,6 +153,7 @@ def test_tle_module_import(self): """Test TLE module import""" # Check if main functions are importable assert hasattr(tle, 'gpu') + assert hasattr(tle, 'cumsum') assert hasattr(tle.gpu, 'alloc') assert hasattr(tle.gpu, 'copy') assert hasattr(tle.gpu, 'local_ptr') diff --git a/python/test/tle/unit/test_tle_cumsum.py b/python/test/tle/unit/test_tle_cumsum.py new file mode 100644 index 0000000000..77903c025e --- /dev/null +++ b/python/test/tle/unit/test_tle_cumsum.py @@ -0,0 +1,266 @@ +# flagtree tle +""" +Unit tests for TLE cumsum helper. 
+""" + +import re + +import pytest +import torch +import triton +import triton.language as tl +import triton.experimental.tle.language as tle + + +def _require_cuda(): + try: + torch.cuda.init() + except Exception as exc: + pytest.skip(f"CUDA init failed: {exc}") + + +@pytest.fixture(scope="module", autouse=True) +def _cuda_guard(): + _require_cuda() + + +@triton.jit +def _tle_cumsum_masked_kernel(x_ptr, exclusive_ptr, total_ptr, n, BLOCK: tl.constexpr, REVERSE: tl.constexpr): + offs = tl.arange(0, BLOCK) + mask = offs < n + x = tl.load(x_ptr + offs, mask=mask, other=0) + exclusive, total = tle.cumsum(x, axis=0, reverse=REVERSE) + tl.store(exclusive_ptr + offs, exclusive, mask=mask) + tl.store(total_ptr, total) + + +@triton.jit +def _tle_cumsum_ptx_kernel(x_ptr, exclusive_ptr, total_ptr, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + x = tl.load(x_ptr + offs) + exclusive, total = tle.cumsum(x, axis=0, reverse=False) + tl.store(exclusive_ptr + offs, exclusive) + tl.store(total_ptr, total) + + +@triton.jit +def _tle_cumsum_callee_shared_kernel(hist_ptr, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + x = tl.load(hist_ptr + offs) + exclusive, _ = tle.cumsum(x, axis=0, reverse=False) + tl.store(hist_ptr + offs, exclusive) + + +@triton.jit +def _tle_cumsum_call_shared_kernel(exclusive_ptr, sentinel_ptr, BLOCK: tl.constexpr): + sentinel_value = 123456789 + offs = tl.arange(0, BLOCK) + smem = tle.gpu.alloc([BLOCK * 2], dtype=tl.int32, scope=tle.gpu.smem) + base = tle.gpu.local_ptr(smem, (0, )) + + tl.store(base + offs, offs + 1) + tl.store(base + (BLOCK + offs), sentinel_value) + _tle_cumsum_callee_shared_kernel(base, BLOCK=BLOCK) + tl.debug_barrier() + + exclusive = tl.load(base + offs) + sentinel = tl.load(base + (BLOCK + offs)) + tl.store(exclusive_ptr + offs, exclusive) + tl.store(sentinel_ptr + offs, sentinel) + + +@triton.jit +def _tle_cumsum_scalar_base_addptr_kernel(exclusive_ptr, sentinel_ptr, BLOCK: tl.constexpr): + sentinel_value = 123456789 + offs = 
tl.arange(0, BLOCK) + smem = tle.gpu.alloc([BLOCK * 2], dtype=tl.int32, scope=tle.gpu.smem) + base = tle.gpu.local_ptr(smem, (0, )) + data_ptrs = base + offs + sentinel_ptrs = base + (BLOCK + offs) + + tl.store(data_ptrs, offs + 1) + tl.store(sentinel_ptrs, sentinel_value) + x = tl.load(data_ptrs) + exclusive, _ = tle.cumsum(x, axis=0, reverse=False) + tl.store(data_ptrs, exclusive) + tl.debug_barrier() + + tl.store(exclusive_ptr + offs, tl.load(data_ptrs)) + tl.store(sentinel_ptr + offs, tl.load(sentinel_ptrs)) + + +def _pick_expected_dtype(input_dtype: torch.dtype) -> torch.dtype: + if input_dtype in (torch.int8, torch.int16): + return torch.int32 + if input_dtype == torch.bfloat16: + return torch.float32 + return input_dtype + + +def _make_input(dtype: torch.dtype, block: int) -> torch.Tensor: + if dtype in (torch.float16, torch.float32, torch.bfloat16): + return torch.randn((block, ), device="cuda", dtype=dtype) + if dtype == torch.int8: + return torch.randint(-32, 32, (block, ), device="cuda", dtype=dtype) + if dtype == torch.int16: + return torch.randint(-512, 512, (block, ), device="cuda", dtype=dtype) + if dtype == torch.int32: + return torch.randint(-2048, 2048, (block, ), device="cuda", dtype=dtype) + raise AssertionError(f"unsupported dtype for test: {dtype}") + + +def _entry_block(ptx: str) -> str: + m = re.search(r"\.entry\s+([^\(]+)\(", ptx) + assert m is not None, "failed to locate PTX entry" + begin = ptx.find("{", m.end()) + assert begin >= 0, "failed to locate PTX entry body" + depth = 0 + for i in range(begin, len(ptx)): + if ptx[i] == "{": + depth += 1 + elif ptx[i] == "}": + depth -= 1 + if depth == 0: + return ptx[m.start():i + 1] + raise AssertionError("failed to parse PTX entry block") + + +@pytest.mark.parametrize( + "dtype, n, block, reverse, num_warps", + [ + (torch.int8, 511, 512, False, 16), + (torch.int16, 257, 512, False, 16), + (torch.int32, 512, 512, False, 16), + (torch.int32, 256, 256, True, 8), + (torch.int32, 512, 512, True, 
16), + (torch.float16, 127, 128, False, 4), + (torch.float32, 128, 128, True, 4), + (torch.float32, 512, 512, True, 16), + (torch.bfloat16, 193, 256, False, 8), + ], +) +def test_tle_cumsum_exclusive_and_total(dtype, n, block, reverse, num_warps): + if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + pytest.skip("bfloat16 is not supported on this GPU") + + x = _make_input(dtype, block) + out_dtype = _pick_expected_dtype(dtype) + exclusive = torch.empty((block, ), device="cuda", dtype=out_dtype) + total = torch.empty((1, ), device="cuda", dtype=out_dtype) + + _tle_cumsum_masked_kernel[(1, )]( + x, + exclusive, + total, + n, + BLOCK=block, + REVERSE=reverse, + num_warps=num_warps, + ) + + x_valid = x[:n].to(out_dtype) + if reverse: + expected_exclusive = torch.flip(torch.cumsum(torch.flip(x_valid, dims=[0]), dim=0, dtype=out_dtype), + dims=[0]) - x_valid + else: + expected_exclusive = torch.cumsum(x_valid, dim=0, dtype=out_dtype) - x_valid + expected_total = torch.sum(x_valid, dim=0, dtype=out_dtype) + + if out_dtype in (torch.float16, torch.bfloat16): + torch.testing.assert_close(exclusive[:n], expected_exclusive, atol=2e-2, rtol=2e-2) + torch.testing.assert_close(total[0], expected_total, atol=2e-2, rtol=2e-2) + elif out_dtype == torch.float32: + # GPU parallel scan accumulation order differs from torch's sequential + # cumsum reference, especially in reverse mode. 
+ atol = 1e-5 if reverse else 2e-6 + rtol = 5e-4 if reverse else 1e-5 + torch.testing.assert_close(exclusive[:n], expected_exclusive, atol=atol, rtol=rtol) + torch.testing.assert_close(total[0], expected_total, atol=2e-6, rtol=1e-5) + else: + torch.testing.assert_close(exclusive[:n], expected_exclusive) + torch.testing.assert_close(total[0], expected_total) + + +def test_tle_cumsum_ptx_fastpath_regression_guard(): + block = 512 + x = torch.randint(-1024, 1024, (block, ), device="cuda", dtype=torch.int32) + exclusive = torch.empty_like(x) + total = torch.empty((1, ), device="cuda", dtype=torch.int32) + + compiled = _tle_cumsum_ptx_kernel.warmup( + x, + exclusive, + total, + BLOCK=block, + grid=(1, ), + num_warps=block // 32, + num_stages=1, + ) + + ttgir = compiled.asm["ttgir"] + assert "tle.exclusive_cumsum" in ttgir + assert "\"tt.scan\"" not in ttgir + assert ttgir.count("ttg.convert_layout") == 0 + + ptx = _entry_block(compiled.asm["ptx"]) + assert len(re.findall(r"\bbar\.sync\b", ptx)) == 2 + # TRT/CUB-aligned lowering keeps a single 32-lane scan in the round path. 
+ assert len(re.findall(r"\bshfl\.sync\.up\b", ptx)) == 5 + assert len(re.findall(r"\bshfl\.sync\.idx\b", ptx)) == 1 + assert len(re.findall(r"\bshfl\.sync\b", ptx)) <= 6 + assert len(re.findall(r"@%p\d+\s+ld\.shared", ptx)) == 0 + assert len(re.findall(r"@%p\d+\s+st\.shared", ptx)) == 0 + assert len(re.findall(r"\bselp\b", ptx)) == 0 + + +def test_tle_cumsum_call_shared_frame_regression(): + block = 512 + num_warps = block // 32 + exclusive = torch.empty((block, ), device="cuda", dtype=torch.int32) + sentinel = torch.empty((block, ), device="cuda", dtype=torch.int32) + + compiled = _tle_cumsum_call_shared_kernel.warmup( + exclusive, + sentinel, + BLOCK=block, + grid=(1, ), + num_warps=num_warps, + num_stages=1, + ) + ttgir = compiled.asm["ttgir"] + assert "tt.call" in ttgir, "regression scenario requires cross-function call frame" + + _tle_cumsum_call_shared_kernel[(1, )]( + exclusive, + sentinel, + BLOCK=block, + num_warps=num_warps, + num_stages=1, + ) + + x = torch.arange(1, block + 1, device="cuda", dtype=torch.int32) + expected_exclusive = torch.cumsum(x, dim=0, dtype=torch.int32) - x + expected_sentinel = torch.full((block, ), 123456789, device="cuda", dtype=torch.int32) + torch.testing.assert_close(exclusive, expected_exclusive) + torch.testing.assert_close(sentinel, expected_sentinel) + + +def test_tle_cumsum_scalar_base_addptr_alias_regression(): + block = 512 + num_warps = block // 32 + exclusive = torch.empty((block, ), device="cuda", dtype=torch.int32) + sentinel = torch.empty((block, ), device="cuda", dtype=torch.int32) + + _tle_cumsum_scalar_base_addptr_kernel[(1, )]( + exclusive, + sentinel, + BLOCK=block, + num_warps=num_warps, + num_stages=1, + ) + + x = torch.arange(1, block + 1, device="cuda", dtype=torch.int32) + expected_exclusive = torch.cumsum(x, dim=0, dtype=torch.int32) - x + expected_sentinel = torch.full((block, ), 123456789, device="cuda", dtype=torch.int32) + torch.testing.assert_close(exclusive, expected_exclusive) + 
torch.testing.assert_close(sentinel, expected_sentinel) diff --git a/python/test/tle/unit/test_tle_gpu_local_ptr.py b/python/test/tle/unit/test_tle_gpu_local_ptr.py index 2f1bbb2499..37e1d9d56e 100644 --- a/python/test/tle/unit/test_tle_gpu_local_ptr.py +++ b/python/test/tle/unit/test_tle_gpu_local_ptr.py @@ -4,6 +4,8 @@ TLE local pointers before writing results back to global memory. """ +import re + import pytest import torch import triton @@ -66,6 +68,66 @@ def _local_pointer_store_kernel(out_ptr, numel, value, BLOCK: tl.constexpr): tl.store(out_tile, out_vals, mask=mask) +@triton.jit +def _local_pointer_full_view_store_kernel(out_ptr, numel, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < numel + + smem_tile = tle.gpu.alloc([BLOCK], dtype=tl.int32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + smem_ptrs = tle.gpu.local_ptr(smem_tile) + + vals = tl.arange(0, BLOCK) + tl.store(smem_ptrs, vals) + + out_vals = tl.load(smem_ptrs, mask=mask, other=-1) + tl.store(out_ptr + offsets, out_vals, mask=mask) + + +@triton.jit +def _local_pointer_local_load_none_kernel(out_ptr, BLOCK: tl.constexpr): + idx = tl.arange(0, BLOCK) + smem_tile = tle.gpu.alloc([BLOCK], dtype=tl.int32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + smem_ptrs = tle.gpu.local_ptr(smem_tile) + tl.store(smem_ptrs, idx + 3) + vals = tl.load(smem_ptrs) + tl.store(out_ptr + idx, vals) + + +@triton.jit +def _local_pointer_local_load_full_indices_kernel(out_ptr, BLOCK: tl.constexpr): + idx = tl.arange(0, BLOCK) + smem_tile = tle.gpu.alloc([BLOCK], dtype=tl.int32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + smem_ptrs = tle.gpu.local_ptr(smem_tile, (idx, )) + tl.store(smem_ptrs, idx + 5) + vals = tl.load(smem_ptrs) + tl.store(out_ptr + idx, vals) + + +@triton.jit +def _local_pointer_full_view_2d_copy_kernel( + x_ptr, + out_ptr, + stride_xm, + stride_xn, + stride_om, + stride_on, + ROWS: 
tl.constexpr, + COLS: tl.constexpr, +): + smem = tle.gpu.alloc([ROWS, COLS], dtype=tl.float32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + rows = tl.arange(0, ROWS)[:, None] + cols = tl.arange(0, COLS)[None, :] + x_tile = x_ptr + rows * stride_xm + cols * stride_xn + tle.gpu.copy(x_tile, smem, [ROWS, COLS]) + + full_ptrs = tle.gpu.local_ptr(smem) + vals = tl.load(full_ptrs) + + out_tile = out_ptr + rows * stride_om + cols * stride_on + tl.store(out_tile, vals) + + @triton.jit def _local_pointer_conditional_mask_store_kernel(out_ptr, numel, BLOCK: tl.constexpr): pid = tl.program_id(0) @@ -76,7 +138,7 @@ def _local_pointer_conditional_mask_store_kernel(out_ptr, numel, BLOCK: tl.const ptrs = tle.gpu.local_ptr(smem, (idx, )) # Keep the masked store inside an scf.if region. This used to trigger a - # verifier failure in `triton-tle-assign-local-pointers-encoding` because the + # verifier failure in `triton-tle-select-encodings` because the # store mask did not match the pointer layout. 
if pid == 0: tl.store(ptrs, idx, mask=mask) @@ -235,6 +297,33 @@ def _local_pointer_dynamic_scalar_load_after_vector_store_kernel( tl.store(out_ptr + i, scalar_val) +@triton.jit +def _local_pointer_full_view_dot_kernel( + a_ptr, + out_ptr, + stride_ai, + stride_aj, + stride_oi, + stride_oj, + BLOCK: tl.constexpr, +): + offs_i = tl.arange(0, BLOCK)[:, None] + offs_j = tl.arange(0, BLOCK)[None, :] + + a_tile_ptr = a_ptr + offs_i * stride_ai + offs_j * stride_aj + a_tile = tl.load(a_tile_ptr) + + smem = tle.gpu.alloc([BLOCK, BLOCK], dtype=tl.float16, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) + smem_ptr = tle.gpu.local_ptr(smem) + tl.store(smem_ptr, a_tile) + + staged = tl.load(smem_ptr) + acc = tl.dot(staged, tl.trans(staged), out_dtype=tl.float32) + + out_ptrs = out_ptr + offs_i * stride_oi + offs_j * stride_oj + tl.store(out_ptrs, acc.to(tl.float16)) + + class TestTLELocalPointerKernel: """Ensure kernels can perform load/compute/store entirely via local pointers.""" @@ -264,6 +353,86 @@ def test_local_pointer_store_populates_constant(self): expected = torch.full_like(out, value) torch.testing.assert_close(out, expected, atol=1e-7, rtol=0) + def test_local_pointer_none_generates_full_view_1d(self): + block = 128 + numel = block - 9 + out = torch.full((block, ), -1, device="cuda", dtype=torch.int32) + + compiled = _local_pointer_full_view_store_kernel.warmup( + out, + numel, + BLOCK=block, + grid=(1, ), + num_warps=4, + ) + ttgir = compiled.asm["ttgir"] + assert "ttg.local_store" in ttgir + line = next((line for line in ttgir.splitlines() if "tle.gpu.local_pointers" in line), None) + if line is not None: + line_lhs = line.split(":", 1)[0] + assert "tle.gpu.local_pointers" in line_lhs + assert "," not in line_lhs + + _local_pointer_full_view_store_kernel[(1, )]( + out, + numel, + BLOCK=block, + num_warps=4, + ) + expected = torch.arange(block, device="cuda", dtype=torch.int32) + expected[numel:] = -1 + torch.testing.assert_close(out, expected, 
atol=0, rtol=0) + + def test_local_pointer_none_generates_full_view_2d(self): + rows = 16 + cols = 32 + x = torch.randn((rows, cols), device="cuda", dtype=torch.float32) + out = torch.empty_like(x) + + _local_pointer_full_view_2d_copy_kernel[(1, )]( + x, + out, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + ROWS=rows, + COLS=cols, + ) + torch.testing.assert_close(out, x, atol=1e-6, rtol=1e-6) + + def test_local_pointer_none_load_rewrites_to_local_load(self): + block = 64 + out = torch.empty((block, ), device="cuda", dtype=torch.int32) + compiled = _local_pointer_local_load_none_kernel.warmup( + out, + BLOCK=block, + grid=(1, ), + num_warps=4, + ) + ttgir = compiled.asm["ttgir"] + assert "ttg.local_load" in ttgir + + _local_pointer_local_load_none_kernel[(1, )](out, BLOCK=block, num_warps=4) + expected = torch.arange(block, device="cuda", dtype=torch.int32) + 3 + torch.testing.assert_close(out, expected, atol=0, rtol=0) + + def test_local_pointer_full_indices_load_rewrites_to_local_load(self): + block = 64 + out = torch.empty((block, ), device="cuda", dtype=torch.int32) + compiled = _local_pointer_local_load_full_indices_kernel.warmup( + out, + BLOCK=block, + grid=(1, ), + num_warps=4, + ) + ttgir = compiled.asm["ttgir"] + assert "ttg.local_load" in ttgir + + _local_pointer_local_load_full_indices_kernel[(1, )](out, BLOCK=block, num_warps=4) + expected = torch.arange(block, device="cuda", dtype=torch.int32) + 5 + torch.testing.assert_close(out, expected, atol=0, rtol=0) + def test_local_pointer_conditional_mask_store_compiles(self): block = 512 numel = block - 7 @@ -390,6 +559,41 @@ def test_local_pointer_scalar_dynamic_index_inserts_barrier(self): rtol=0, ) + def test_local_pointer_full_view_dot_avoids_pointer_convert_layout(self): + block = 32 + a = torch.randn((block, block), device="cuda", dtype=torch.float16) + out = torch.empty_like(a) + + compiled = _local_pointer_full_view_dot_kernel.warmup( + a, + out, + a.stride(0), + a.stride(1), + 
out.stride(0), + out.stride(1), + BLOCK=block, + grid=(1, ), + num_warps=4, + num_stages=2, + ) + ttgir = compiled.asm["ttgir"] + assert "ttg.local_load" in ttgir + assert re.search(r"ttg\.convert_layout .*-> tensor<.*!tt\.ptr", ttgir) is None + + _local_pointer_full_view_dot_kernel[(1, )]( + a, + out, + a.stride(0), + a.stride(1), + out.stride(0), + out.stride(1), + BLOCK=block, + num_warps=4, + num_stages=2, + ) + expected = (a.float() @ a.float().T).to(torch.float16) + torch.testing.assert_close(out, expected, atol=2e-1, rtol=2e-1) + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/python/triton/experimental/tle/__init__.py b/python/triton/experimental/tle/__init__.py index 6d5b6024b9..186038eeac 100644 --- a/python/triton/experimental/tle/__init__.py +++ b/python/triton/experimental/tle/__init__.py @@ -1,24 +1,13 @@ # flagtree tle - from . import language + try: from . import raw except ModuleNotFoundError: raw = None -from .language.gpu import ( - extract_tile, - insert_tile, - alloc, - copy, -) - __all__ = [ "language", - "extract_tile", - "insert_tile", - "alloc", - "copy", ] if raw is not None: diff --git a/python/triton/experimental/tle/language/__init__.py b/python/triton/experimental/tle/language/__init__.py index 443dad4a49..9a5f92b6e7 100644 --- a/python/triton/experimental/tle/language/__init__.py +++ b/python/triton/experimental/tle/language/__init__.py @@ -1,6 +1,8 @@ # flagtree tle from .core import ( - load, ) + cumsum, + load, +) from .distributed import ( B, P, @@ -21,11 +23,9 @@ sharding, ) -from . import distributed, gpu, raw -from .gpu import extract_tile, insert_tile - __all__ = [ "load", + "cumsum", "device_mesh", "S", "P", @@ -42,6 +42,6 @@ "distributed", "gpu", "raw", - "extract_tile", - "insert_tile", ] + +from . 
import distributed, gpu, raw diff --git a/python/triton/experimental/tle/language/core.py b/python/triton/experimental/tle/language/core.py index 7da37bcc39..94c870ee3d 100644 --- a/python/triton/experimental/tle/language/core.py +++ b/python/triton/experimental/tle/language/core.py @@ -1,6 +1,17 @@ # flagtree tle import triton.language.core as tl + +def _tle_pick_sum_dtype(in_dtype, dtype): + if dtype is not None: + return dtype + if in_dtype.is_int_signed(): + return tl.int32 if in_dtype.int_bitwidth < 32 else None + if in_dtype.is_int_unsigned(): + return tl.uint32 if in_dtype.int_bitwidth < 32 else None + return None + + # ----------------------- # Non-Atomic Memory Operations # ----------------------- @@ -56,3 +67,45 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c cache_modifier=cache_modifier, eviction_policy=eviction_policy, volatile=volatile, _semantic=_semantic) x.handle.set_attr("tt.load.async", _semantic.builder.get_bool_attr(is_async)) return x + + +@tl.builtin +def cumsum(input, axis=0, reverse=False, dtype: tl.constexpr = None, _semantic=None, _generator=None): + """ + Compute exclusive cumulative sum and total sum along :code:`axis`. 
+ + Returns a tuple :code:`(exclusive_sum, total_sum)` where: + - :code:`exclusive_sum[i] = sum(input[:i])` (or reverse-exclusive when ``reverse=True``) + - :code:`total_sum = sum(input)` + """ + axis = tl._unwrap_if_constexpr(axis) + reverse = tl._unwrap_if_constexpr(reverse) + dtype = tl._unwrap_if_constexpr(dtype) + input = tl._promote_bfloat16_to_float32(input, _semantic=_semantic) + out_dtype: tl.constexpr = _tle_pick_sum_dtype(input.dtype, dtype) + if out_dtype is not None: + input = input.to(out_dtype, _semantic=_semantic) + + if not isinstance(input, tl.tensor): + input = _semantic.to_tensor(input) + + input_ty = input.type + if not input_ty.is_block(): + zero = tl.full((), 0, input_ty, _semantic=_semantic) + return zero, input + + exclusive_ty = input_ty + total_ty = input_ty.scalar + exclusive_ir = exclusive_ty.to_ir(_semantic.builder) + total_ir = total_ty.to_ir(_semantic.builder) + + cumsum_op = _semantic.builder.create_exclusive_cumsum( + exclusive_ir, + total_ir, + input.handle, + int(axis), + bool(reverse), + ) + exclusive_sum = tl.tensor(cumsum_op.get_result(0), exclusive_ty) + total_sum = tl.tensor(cumsum_op.get_result(1), total_ty) + return exclusive_sum, total_sum diff --git a/python/triton/experimental/tle/language/distributed.py b/python/triton/experimental/tle/language/distributed.py index 5017028106..cbb5208190 100644 --- a/python/triton/experimental/tle/language/distributed.py +++ b/python/triton/experimental/tle/language/distributed.py @@ -680,18 +680,25 @@ def _create_remote_pointers_tensor( tensor: tl.tensor, shard_id_tensor: tl.tensor, _semantic, -) -> tl.tensor | None: +) -> tl.tensor: builder = _semantic.builder - remote_type = tensor.type.to_ir(builder) - try: - remote_op = builder.create_remote_pointers( - remote_type, - tensor.handle, - shard_id_tensor.handle, - ) - except AttributeError: - return None - return tl.tensor(remote_op.get_result(0), tensor.type) + if not tensor.dtype.is_ptr(): + raise TypeError("remote(pointer, ...) 
requires a pointer tensor input") + if not hasattr(builder, "create_remote_pointers"): + raise RuntimeError("remote pointer lowering requires TLE remote_pointers support in the active Triton build") + remote_ptr_dtype = tl.pointer_type(tensor.dtype.element_ty, 7) + if tensor.type.is_block(): + remote_type = tl.block_type(remote_ptr_dtype, list(tensor.shape)).to_ir(builder) + else: + remote_type = remote_ptr_dtype.to_ir(builder) + remote_op = builder.create_remote_pointers( + remote_type, + tensor.handle, + shard_id_tensor.handle, + ) + if tensor.type.is_block(): + return tl.tensor(remote_op.get_result(0), tl.block_type(remote_ptr_dtype, list(tensor.shape))) + return tl.tensor(remote_op.get_result(0), remote_ptr_dtype) def _remote_pointer( @@ -704,42 +711,33 @@ def _remote_pointer( raise TypeError(f"tensor must be tl.tensor, got {type(tensor).__name__}") if not tensor.dtype.is_ptr(): raise TypeError("remote(pointer, ...) internal path requires a pointer tensor") + if tensor.dtype.address_space == 7: + # Pointer is already in cluster-shared space. Preserve compatibility + # for existing callsites that re-annotate with shard_id=0. + if isinstance(shard_id, (int, tuple, list)): + linear_shard_id = _normalize_compile_time_remote_shard_id(shard_id, scope) + if linear_shard_id == 0: + return tensor + raise ValueError("remote(pointer, ...) on cluster-shared pointers only supports shard_id=0") + raise ValueError("remote(pointer, ...) on cluster-shared pointers requires compile-time shard_id=0") + if tensor.dtype.address_space != 3: - raise ValueError("remote(pointer, ...) internal path requires shared-memory pointers (addrspace=3)") + raise ValueError("remote(pointer, ...) internal path requires shared-memory pointers (addrspace=3) " + "or cluster-shared pointers (addrspace=7)") # Compile-time constant shard id path. 
if isinstance(shard_id, (int, tuple, list)): linear_shard_id = _normalize_compile_time_remote_shard_id(shard_id, scope) - # Prefer explicit remote_pointers op so remote metadata survives - # downstream layout/materialization rewrites. shard_id_tensor = _semantic.to_tensor(int(linear_shard_id)) shard_id_tensor = _normalize_runtime_remote_shard_id_tensor(shard_id_tensor) - remote_ptr = _create_remote_pointers_tensor(tensor, shard_id_tensor, _semantic) - if remote_ptr is not None: - return remote_ptr - - # Compatibility fallback for older TLE extensions. - tensor.handle.set_attr("tle.remote_cta_id", _semantic.builder.get_int32_attr(int(linear_shard_id))) - return tensor + return _create_remote_pointers_tensor(tensor, shard_id_tensor, _semantic) # Runtime shard id path. This materializes a TLE op that carries the # runtime i32 shard id through lowering. shard_id_tensor = shard_id if isinstance(shard_id, tl.tensor) else _semantic.to_tensor(shard_id) shard_id_tensor = _normalize_runtime_remote_shard_id_tensor(shard_id_tensor) - # Preferred path: keep remote semantics through a dedicated TLE op so the - # shard-id survives local_pointers lowering. - remote_ptr = _create_remote_pointers_tensor(tensor, shard_id_tensor, _semantic) - if remote_ptr is not None: - return remote_ptr - - # Compatibility fallback for older TLE extensions. - # Represent runtime shard_id with a marked addptr op. The lowering rewrites - # pointer arithmetic to use the original base pointer and consumes the - # runtime i32 from addptr's offset operand as cluster CTA id. 
- remote_ptr = _semantic.add(tensor, shard_id_tensor, sanitize_overflow=True) - remote_ptr.handle.set_attr("tle.remote_shard_id_carrier", _semantic.builder.get_unit_attr()) - return remote_ptr + return _create_remote_pointers_tensor(tensor, shard_id_tensor, _semantic) @tl.builtin @@ -755,6 +753,8 @@ def remote( Supported input: - tle buffered_tensor: returns a remote-marked buffered tensor; caller should then use `tle.gpu.local_ptr(...)` to materialize remote pointers. + - tl.tensor shared-memory pointer (scalar or tensor): returns remote + pointer directly. `shard_id` is the target block id inside the current thread block cluster. When `scope` is provided, launch cluster dimensions are inferred from that @@ -767,6 +767,11 @@ def remote( if scope is not None: _apply_mesh_cluster_launch(scope, _semantic) + # Direct pointer path: support local_ptr scalar/tensor values and return + # remote pointer with preserved shape. + if isinstance(tensor, tl.tensor): + return _remote_pointer(tensor, shard_id, scope=scope, _semantic=_semantic) + # Buffered tensor path: carry remote metadata and let `local_ptr` materialize # remote pointers later. if _is_buffered_tensor_like(tensor): @@ -801,9 +806,6 @@ def remote( setattr(remote_buffer, "_tle_remote_scope", scope) return remote_buffer - if isinstance(tensor, tl.tensor): - raise TypeError("remote(...) 
only accepts tle.buffered_tensor; " - "use remote(buffered_tensor, shard_id, scope) + local_ptr(...)") raise TypeError(f"tensor must be tle.buffered_tensor, got {type(tensor).__name__}") diff --git a/python/triton/experimental/tle/language/gpu/__init__.py b/python/triton/experimental/tle/language/gpu/__init__.py index 805119bb19..addd7392c0 100644 --- a/python/triton/experimental/tle/language/gpu/__init__.py +++ b/python/triton/experimental/tle/language/gpu/__init__.py @@ -5,8 +5,6 @@ copy, memory_space, local_ptr, - extract_tile, - insert_tile, ) from .types import (layout, shared_layout, swizzled_shared_layout, tensor_memory_layout, nv_mma_shared_layout, scope, buffered_tensor, buffered_tensor_type, smem, tmem) @@ -20,8 +18,6 @@ "copy", "local_ptr", "storage_kind", - "extract_tile", - "insert_tile", "layout", "memory_space", "shared_layout", diff --git a/python/triton/experimental/tle/language/gpu/core.py b/python/triton/experimental/tle/language/gpu/core.py index 397660a67b..254b558379 100644 --- a/python/triton/experimental/tle/language/gpu/core.py +++ b/python/triton/experimental/tle/language/gpu/core.py @@ -801,9 +801,10 @@ def local_ptr( Args: buffer: Local memory buffer tensor returned by ``tle.alloc``. - indices: Tuple of integer index tensors. The tuple length must equal - the rank of ``buffer`` and every tensor must have the same shape. - The output pointer tensor will have that same shape. + indices: Optional tuple of integer index tensors. If provided, tuple + length must equal ``rank(buffer)`` and every tensor must have the + same shape. If ``None``, emit a full-view pointer tensor over + ``buffer`` shape (or scalar pointer for rank-0 buffer). _semantic: Semantic analyzer (internal use). _generator: Triton code generator (internal use). 
@@ -822,46 +823,51 @@ def local_ptr( remote_scope = getattr(buffer, "_tle_remote_scope", None) remote_buffer_marker = remote_shard_id is not None + buffer_shape = tuple(int(tl._unwrap_if_constexpr(dim)) for dim in buffer.type.shape) indices = tl._unwrap_if_constexpr(indices) - if indices is None: - raise ValueError("local_ptr indices must be provided as a tuple of tensors") - if isinstance(indices, tl.tuple): + no_indices = indices is None + if no_indices: + indices_tuple = tuple() + elif isinstance(indices, tl.tuple): indices_tuple = tuple(indices.values) elif isinstance(indices, (tuple, list)): indices_tuple = tuple(indices) else: - raise ValueError("local_ptr indices must be a tuple or list of tensors") - - buffer_shape = tuple(int(tl._unwrap_if_constexpr(dim)) for dim in buffer.type.shape) - if len(indices_tuple) != len(buffer_shape): + raise ValueError("local_ptr indices must be a tuple/list of tensors or None") + if not no_indices and len(indices_tuple) != len(buffer_shape): raise ValueError(f"local_ptr indices must provide {len(buffer_shape)} tensors, got {len(indices_tuple)}") idx_tensors: list[tensor] = [] view_shape: Optional[tuple[int, ...]] = None scalar_index_flags: list[bool] = [] - for idx in indices_tuple: - idx_tensor = idx if isinstance(idx, tensor) else _semantic.to_tensor(idx) - if not idx_tensor.dtype.is_int(): - raise ValueError("local_ptr indices must use integer dtypes") - is_scalar_index = not idx_tensor.type.is_block() - scalar_index_flags.append(is_scalar_index) - if is_scalar_index: + if not no_indices: + for idx in indices_tuple: + idx_tensor = idx if isinstance(idx, tensor) else _semantic.to_tensor(idx) + if not idx_tensor.dtype.is_int(): + raise ValueError("local_ptr indices must use integer dtypes") + is_scalar_index = not idx_tensor.type.is_block() + scalar_index_flags.append(is_scalar_index) + if is_scalar_index: + idx_tensors.append(idx_tensor) + continue + if view_shape is None: + view_shape = tuple(idx_tensor.shape) + elif 
tuple(idx_tensor.shape) != view_shape: + raise ValueError("local_ptr indices must have identical shapes") idx_tensors.append(idx_tensor) - continue - if view_shape is None: - view_shape = tuple(idx_tensor.shape) - elif tuple(idx_tensor.shape) != view_shape: - raise ValueError("local_ptr indices must have identical shapes") - idx_tensors.append(idx_tensor) - - if not idx_tensors: - raise ValueError("local_ptr indices cannot be empty") - all_scalar_indices = all(scalar_index_flags) - any_scalar_indices = any(scalar_index_flags) - if any_scalar_indices and not all_scalar_indices: - raise ValueError("local_ptr indices must be either all scalar or all tensors with identical shapes") - if not all_scalar_indices and view_shape is None: - view_shape = tuple() + + if not idx_tensors: + raise ValueError("local_ptr indices cannot be empty") + all_scalar_indices = all(scalar_index_flags) + any_scalar_indices = any(scalar_index_flags) + if any_scalar_indices and not all_scalar_indices: + raise ValueError("local_ptr indices must be either all scalar or all tensors with identical shapes") + if not all_scalar_indices and view_shape is None: + view_shape = tuple() + else: + all_scalar_indices = len(buffer_shape) == 0 + if not all_scalar_indices: + view_shape = buffer_shape try: from .semantic import TLESemantic @@ -875,7 +881,14 @@ def local_ptr( insert_block = _semantic.builder.get_insertion_block() if insert_block is None: raise RuntimeError("TLE local_ptr called without an insertion block") - if all_scalar_indices: + if no_indices: + if len(buffer_shape) == 0: + result_ty = ptr_dtype + result_ir = ptr_dtype.to_ir(_semantic.builder) + else: + result_ty = tl.block_type(ptr_dtype, list(buffer_shape)) + result_ir = result_ty.to_ir(_semantic.builder) + elif all_scalar_indices: result_ty = ptr_dtype result_ir = ptr_dtype.to_ir(_semantic.builder) else: @@ -887,10 +900,10 @@ def local_ptr( result_tensor = tl.tensor(local_ptr_op.get_result(0), result_ty) if remote_buffer_marker: - if 
all_scalar_indices: - raise ValueError("local_ptr does not yet support scalar indices on remote buffers") # Keep remote semantics attached to the source buffered tensor and - # materialize them only when pointer view is requested. + # materialize them only when pointer view is requested. This applies + # to both block and scalar pointer views so remote/local_ptr semantics + # stay aligned with local shared-memory local_ptr. from triton.experimental.tle.language import distributed as _tle_distributed result_tensor = _tle_distributed._remote_pointer( result_tensor, diff --git a/python/tutorials/tle/04-fft.py b/python/tutorials/tle/01-fft.py similarity index 100% rename from python/tutorials/tle/04-fft.py rename to python/tutorials/tle/01-fft.py diff --git a/python/tutorials/tle/01-sparse-mla.py b/python/tutorials/tle/01-sparse-mla.py deleted file mode 100644 index 02b0c5c272..0000000000 --- a/python/tutorials/tle/01-sparse-mla.py +++ /dev/null @@ -1,224 +0,0 @@ -# flagtree -""" -Sparse MLA Forward -================== - -This module implements a Triton kernel for the forward pass of a sparse MLA (Multi-Headed Attention) mechanism. -It demonstrates the use of Triton's TLE (Triton Language Extensions) for efficient memory access and computation. 
-""" - -import torch -import triton -import triton.language as tl -import triton.experimental.tle.language as tle - -spar_mla_fwd_configs = [ - triton.Config({'num_stages': 4, 'num_warps': 8}), - # triton.Config({'num_stages': 2, 'num_warps': 4}), -] - - -@triton.autotune( # Decorate the kernel - configs=spar_mla_fwd_configs, - key=['K', 'is_causal'], -) -@triton.jit -def triton_sparse_mla_fwd(q, kv, indices, sm_scale: tl.constexpr, output, lse, stride_qb, stride_qh, stride_qm, - stride_qd, stride_kvb, stride_kvg, stride_kvn, stride_kvd, stride_tb, stride_tg, stride_tm, - stride_tt, # topk,for indices - stride_ob, stride_oh, stride_om, stride_od, stride_lb, stride_lh, stride_lm, B: tl.constexpr, - SQ: tl.constexpr, # seqlen - SKV: tl.constexpr, K: tl.constexpr, # topk - D: tl.constexpr, # QKV dim - TD: tl.constexpr, # tail dim - DP: tl.constexpr, TDP: tl.constexpr, H: tl.constexpr, # q_head_dim - G: tl.constexpr, # group_size - VG: tl.constexpr, # H/G KV groups - BK: tl.constexpr, BH: tl.constexpr, is_causal: tl.constexpr): - i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_g, i_bh = i_gbh // G, i_gbh % G - q_base = q + i_b * stride_qb + i_sq * stride_qm + i_gbh * (BH * stride_qh) - tq_base = q_base + D * stride_qd - kv_base = kv + i_b * stride_kvb + i_g * stride_kvg - tkv_base = kv_base + D * stride_kvd - t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg - o_base = output + i_b * stride_ob + i_sq * stride_om + i_gbh * (BH * stride_oh) - l_base = lse + i_b * stride_lb + i_sq * stride_lm + i_gbh * (BH * stride_lh) - - offs_h = tl.arange(0, BH) - offs_d = tl.arange(0, DP) - offs_td = tl.arange(0, TDP) - offs_od = tl.arange(0, DP) - offs_t = tl.arange(0, BK) - mask_h = i_bh * BH + offs_h < G - mask_d = offs_d < D - mask_td = offs_td < TD - mask_od = mask_d - - q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd - q_msk = mask_h[:, None] & mask_d[None, :] - q_blk = tl.load(q_ptr, q_msk, other=0.0) 
- - tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd - tq_msk = mask_h[:, None] & mask_td[None, :] - tq_blk = tl.load(tq_ptr, tq_msk, other=0.0) - - max_prev = tl.full([BH], float('-inf'), dtype=tl.float32) - sum_exp = tl.full([BH], 1.0, dtype=tl.float32) - acc = tl.zeros([BH, DP], dtype=tl.float32) - - log_scale: tl.constexpr = sm_scale * 1.44269504 - - max_col = i_sq if is_causal else SQ - 1 - - NK = tl.cdiv(K, BK) - for ck in tl.range(NK, num_stages=0): - if ck * BK <= max_col: - t_ptr = (BK * ck + offs_t) * stride_tt - t_msk = t_ptr < K - t_ptr += t_base - kv_ids = tl.load(t_ptr, t_msk, other=-1) - mask_ids = (kv_ids <= max_col) & (kv_ids >= 0) - - kv_ptr = kv_base + offs_d[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn - kv_msk = mask_d[:, None] & mask_ids[None, :] - kv_blk = tle.load(kv_ptr, kv_msk, other=0.0, is_async=True) # [DP, BK] - - tkv_ptr = tkv_base + offs_td[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn - tkv_msk = mask_td[:, None] & mask_ids[None, :] - tkv_blk = tle.load(tkv_ptr, tkv_msk, other=0.0, is_async=False) # [TDP, BK] - - qk = tl.dot(tq_blk, tkv_blk, out_dtype=tl.float32) - qk = tl.dot(q_blk, kv_blk, qk, out_dtype=tl.float32) - - qk = tl.where(mask_ids[None, :], qk, float('-inf')) # [BH, BK] - - new_max = tl.maximum(max_prev, tl.max(qk, axis=1)) - alpha = tl.math.exp2((max_prev - new_max) * log_scale) - exp_qk = tl.math.exp2(qk * log_scale - new_max[:, None] * log_scale) - sum_qk = tl.sum(exp_qk, axis=1) - sum_exp = sum_exp * alpha + sum_qk - acc = acc * alpha[:, None] - exp_qk = exp_qk.to(tl.bfloat16) - acc = tl.dot(exp_qk, tl.trans(kv_blk), acc, out_dtype=tl.float32) # [BH, BK] @ [BK, DP] = [BH, DP] - - max_prev = new_max - - out_vals = acc / sum_exp[:, None] - o_ptr = o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od - o_msk = mask_h[:, None] & mask_od[None, :] - tl.store(o_ptr, out_vals.to(q_blk.dtype), o_msk) - - fin_log = max_prev * log_scale + 
tl.math.log2(sum_exp.to(tl.float32)) # lse / ln2 - l_ptr = l_base + offs_h * stride_lh - l_msk = mask_h - tl.store(l_ptr, fin_log.to(q_blk.dtype), l_msk) - - -def triton_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512): - is_causal = True - assert not return_p_sum, "This kernel file is for fwd only" - assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() - B, SQ, H, DT = q.shape - _, S, VG, _ = kv.shape - - # assert DT == 576, "you should assign dim otherwise" - D = d_v - - assert kv.shape[-1] == DT - TD = DT - D - DP = triton.next_power_of_2(D) - TDP = triton.next_power_of_2(TD) - assert kv.shape[0] == B - _, _, _, K = indices.shape - assert indices.shape == (B, SQ, VG, K) - G = H // VG - if sm_scale is None: - sm_scale = DT**-0.5 - BH = 32 - NH = triton.cdiv(G, BH) - BK = 32 - output = torch.zeros((B, SQ, H, D), device=q.device, dtype=q.dtype) - lse = torch.full((B, SQ, H), float('-inf'), device=q.device, dtype=q.dtype) - grid = (B, SQ, VG * NH) # (SQ//BQ, B*H) - triton_sparse_mla_fwd[grid]( - q, kv, indices, sm_scale, output, lse, q.stride(0), q.stride(2), q.stride(1), q.stride(3), # [B, H, SQ, DT] - kv.stride(0), kv.stride(2), kv.stride(1), kv.stride(3), # [B, VG, SKV, DT] - indices.stride(0), indices.stride(2), indices.stride(1), indices.stride(3), # [B, VG, SQ, K] - output.stride(0), output.stride(2), output.stride(1), output.stride(3), # [B, H, SQ, D] - lse.stride(0), lse.stride(2), lse.stride(1), # [B, H, SQ] - B, SQ, S, K, D, TD, DP, TDP, H, G, VG, BK, BH, - # BD, - is_causal) - # sparse_mla_fwd[grid](q, kv, indices, output) - return output, lse - - -def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True, d_v=512): - q = q.float() - kv = kv.float() - indices = indices.transpose(1, 2) - b, sq, h, dim_q = q.shape - b, sk, g, _ = kv.shape - - dim = d_v - # assert kv.shape[-1] == 576, "you should assign dim otherwise" - # dim = 512 - k = kv - v = kv[..., :dim] - - b, _, 
_, dim_v = v.shape - g_index = g - h_index = h // g - compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, - device="cuda").view(-1, - 1) >= torch.arange(1 - 1, sk * 1, 1, dtype=torch.int32, - device="cuda").view(1, -1) - - mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) - mask = mask[..., :-1] - mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :1 - 1, 0] = True - mask = mask.view(b, g_index, 1, sq, sk) - - q = q.view(b, sq, g, -1, dim_q) - score = torch.einsum("bmghd,bngd->bghmn", q, k) - sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale - score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) - p = score.softmax(dim=-1) - p = p.view(b, g_index, h_index, -1, sq, sk) - p = p.view(b, g, -1, sq, sk) - o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) - o = o.reshape(b, sq, h, dim_v) - return o.to(torch.bfloat16) - - -def test_sparse_mla_fwd(B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16): - torch.random.manual_seed(0) - q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) - kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) - - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") - for b in range(B): - for t in range(S): - for h in range(HKV): - i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i - - ref_bf16_out = ref_sparse_mla_fwd_interface(q, kv, indices, d_v=DV) - - triton_bf16_out, triton_bf16_lse = triton_sparse_mla_fwd_interface(q, kv, indices, d_v=DV) - print("triton bf16 done \n triton lse tensor: \n", triton_bf16_lse) - print() - - assert torch.allclose( - triton_bf16_out.float(), - ref_bf16_out.float(), - atol=1e-1, - rtol=1e-1, - ), "Triton sparse MLA fwd bf16 does not match reference" - print("Triton sparse MLA fwd bf16 matches reference!") - - -if __name__ == "__main__": - test_sparse_mla_fwd(B=1, S=128, 
SKV=1024, H=32, HKV=1, DQK=256 + 32, DV=256, topk=64, dtype=torch.bfloat16) diff --git a/python/tutorials/tle/03-topk.py b/python/tutorials/tle/03-topk.py index 15918854c4..4c774b1454 100644 --- a/python/tutorials/tle/03-topk.py +++ b/python/tutorials/tle/03-topk.py @@ -5,6 +5,7 @@ This tutorial implements Top-K over the last dimension of an (M, N) tensor and compares: - radix: Triton radix-select kernel +- triton: Triton streaming top-k kernel - torch: torch.topk """ @@ -30,11 +31,10 @@ def get_topmask_and_fullmask(x): @triton.jit -def fpval_to_key_with_nan(x, x_bits): +def fpval_to_key(x_bits): tm, fm = get_topmask_and_fullmask(x_bits) mask = tl.where((x_bits & tm) != 0, fm, tm) - key = x_bits ^ mask - return tl.where(x == x, key, fm) + return x_bits ^ mask @triton.jit @@ -44,6 +44,18 @@ def key_to_fpval(x): return x ^ mask +@triton.jit +def indx_to_key(indx): + max_u16 = tl.full(indx.shape, 0xFFFF, dtype=tl.uint32) + return max_u16 - indx.to(tl.uint32) + + +@triton.jit +def key_to_indx(indx_key): + max_u16 = tl.full(indx_key.shape, 0xFFFF, dtype=tl.uint32) + return (max_u16 - indx_key.to(tl.uint32)).to(tl.int32) + + @triton.jit def topk_kernel_radix_triton( X, @@ -53,20 +65,14 @@ def topk_kernel_radix_triton( stride_ym, n_cols, K: tl.constexpr, - K_PAD: tl.constexpr, BLOCK_N: tl.constexpr, RADIX_BITS: tl.constexpr, ): pid = tl.program_id(0) - # Stage 0: setup dtype metadata and key packing types. + # Stage 0: setup dtype metadata. 
x_dtype = X.dtype.element_ty x_nbits: tl.constexpr = x_dtype.primitive_bitwidth - if x_nbits < 16: - y_nbits: tl.constexpr = 32 - else: - y_nbits: tl.constexpr = x_nbits * 2 x_utype = tl.dtype(f"uint{x_nbits}") - x_ultype = tl.dtype(f"uint{y_nbits}") RADIX_SIZE: tl.constexpr = 1 << RADIX_BITS RADIX_MASK: tl.constexpr = RADIX_SIZE - 1 @@ -76,6 +82,7 @@ def topk_kernel_radix_triton( desired = tl.full((), 0, dtype=x_utype) desired_mask = tl.full((), 0, dtype=x_utype) k_to_find = tl.full((), K, dtype=tl.int32) + k_limit = tl.full((), K, dtype=tl.int32) n_tiles = tl.cdiv(n_cols, BLOCK_N) # Stage 1: shared-memory histogram storage for each radix digit. @@ -97,7 +104,7 @@ def topk_kernel_radix_triton( x_ptrs = X + pid * stride_xm + offs_n x = tl.load(x_ptrs, mask=mask_n, other=float("-inf")) x_bits = x.to(x_utype, bitcast=True) - x_key = fpval_to_key_with_nan(x, x_bits) + x_key = fpval_to_key(x_bits) matches = (x_key & desired_mask) == desired digit = ((x_key >> digit_pos) & RADIX_MASK).to(tl.int32) valid = mask_n & matches @@ -108,46 +115,20 @@ def topk_kernel_radix_triton( # Compute descending cumulative histogram in-place. 
cumsum_desc = tl.cumsum(counts, axis=0, reverse=True) - tl.store(smem_count_ptrs, cumsum_desc) - - selected_scalar = 0 - counts_gt_scalar = 0 - found = 0 - for rev in tl.static_range(RADIX_SIZE): - d = RADIX_SIZE - 1 - rev - cum_d = tl.load(tle.gpu.local_ptr(smem_counts, (d, ))) - if d + 1 < RADIX_SIZE: - cum_next = tl.load(tle.gpu.local_ptr(smem_counts, (d + 1, ))) - else: - cum_next = 0 - take = (found == 0) & (cum_d >= k_to_find) & (cum_next < k_to_find) - selected_scalar = tl.where(take, d, selected_scalar) - counts_gt_scalar = tl.where(take, cum_next, counts_gt_scalar) - found = tl.where(take, 1, found) - - selected_u = selected_scalar.to(x_utype) + + cond = cumsum_desc >= k_to_find + selected = tl.max(tl.where(cond, bins, 0), axis=0).to(tl.int32) + counts_gt = tl.max(tl.where(bins == (selected + 1), cumsum_desc, 0), axis=0) + + selected_u = selected.to(x_utype) desired = desired | (selected_u << digit_pos) desired_mask = desired_mask | (tl.full((), RADIX_MASK, dtype=x_utype) << digit_pos) - k_to_find = k_to_find - counts_gt_scalar + k_to_find = k_to_find - counts_gt # Stage 3: compact candidates with shared-memory atomic write count. 
thr_key = desired - - min_val = tl.full((), float("-inf"), tl.float32).to(x_dtype) - min_bits = min_val.to(x_utype, bitcast=True) - min_key = fpval_to_key_with_nan(min_val, min_bits) - min_packed = min_key.to(x_ultype) << 16 - offs_k = tl.arange(0, K_PAD) - - smem_selected = tle.gpu.alloc( - [K_PAD], - dtype=x_ultype, - layout=None, - scope=tle.gpu.smem, - nv_mma_shared_layout=False, - ) - smem_selected_ptrs = tle.gpu.local_ptr(smem_selected, (offs_k, )) - tl.store(smem_selected_ptrs, tl.full([K_PAD], min_packed, dtype=x_ultype)) + thr_bits = key_to_fpval(thr_key) + thr_val = thr_bits.to(x_dtype, bitcast=True) smem_write_count = tle.gpu.alloc( [1], @@ -165,47 +146,88 @@ def topk_kernel_radix_triton( mask_n = offs_n < n_cols x_ptrs = X + pid * stride_xm + offs_n x = tl.load(x_ptrs, mask=mask_n, other=float("-inf")) - x_bits = x.to(x_utype, bitcast=True) - x_key = fpval_to_key_with_nan(x, x_bits) - idx_key = (n_cols - offs_n).to(x_ultype) - packed = (x_key.to(x_ultype) << 16) | idx_key - take_gt = mask_n & (x_key > thr_key) + take_gt = mask_n & (x > thr_val) pos = tl.atomic_add(write_count_ptrs, one, mask=take_gt, sem="relaxed", scope="cta") - write_mask = take_gt & (pos < K_PAD) - dst_ptrs = tle.gpu.local_ptr(smem_selected, (pos.to(tl.int32), )) - tl.store(dst_ptrs, packed, mask=write_mask) + write_mask = take_gt & (pos < k_limit) + out_pos = pos.to(tl.int32) + yv_ptrs = Yv + pid * stride_ym + out_pos + yi_ptrs = Yi + pid * stride_ym + out_pos + tl.store(yv_ptrs, x, mask=write_mask) + tl.store(yi_ptrs, offs_n.to(tl.int32), mask=write_mask) # Pass 2: fill remaining slots with values equal to threshold (first-come-first-serve). 
- for t in tl.range(0, n_tiles): - offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N) - mask_n = offs_n < n_cols + cur_count = tl.load(tle.gpu.local_ptr(smem_write_count, (0, ))) + if cur_count < k_limit: + for t in tl.range(0, n_tiles): + cur_count = tl.load(tle.gpu.local_ptr(smem_write_count, (0, ))) + if cur_count < k_limit: + offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_cols + x_ptrs = X + pid * stride_xm + offs_n + x = tl.load(x_ptrs, mask=mask_n, other=float("-inf")) + take_eq = mask_n & (x == thr_val) + pos = tl.atomic_add(write_count_ptrs, one, mask=take_eq, sem="relaxed", scope="cta") + write_mask = take_eq & (pos < k_limit) + out_pos = pos.to(tl.int32) + yv_ptrs = Yv + pid * stride_ym + out_pos + yi_ptrs = Yi + pid * stride_ym + out_pos + tl.store(yv_ptrs, x, mask=write_mask) + tl.store(yi_ptrs, offs_n.to(tl.int32), mask=write_mask) + + +@triton.jit +def topk_kernel_streaming_triton( + X, + Yv, + Yi, + stride_xm, + stride_ym, + n_cols, + K: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + + x_dtype: tl.constexpr = X.dtype.element_ty + x_nbits: tl.constexpr = x_dtype.primitive_bitwidth + x_utype = tl.dtype(f"uint{x_nbits}") + if x_nbits < 16: + packed_nbits: tl.constexpr = 32 + else: + packed_nbits: tl.constexpr = x_nbits * 2 + x_packtype = tl.dtype(f"uint{packed_nbits}") + + n_tiles = tl.cdiv(n_cols, BLOCK_N) + offs_n = (n_tiles - 1) * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_cols + + x_ptrs = X + pid * stride_xm + offs_n + x = tl.load(x_ptrs, mask=mask_n, other=float("-inf")) + x_key = fpval_to_key(x.to(x_utype, bitcast=True)) + x_pack = (x_key.to(x_packtype) << 16) | indx_to_key(offs_n).to(x_packtype) + acc = tl.topk(x_pack, K) + + for _ in tl.range(0, n_tiles - 1): + acc = tl.bitonic_merge(acc) + offs_n -= BLOCK_N x_ptrs = X + pid * stride_xm + offs_n - x = tl.load(x_ptrs, mask=mask_n, other=float("-inf")) - x_bits = x.to(x_utype, bitcast=True) - x_key = fpval_to_key_with_nan(x, x_bits) - idx_key = 
(n_cols - offs_n).to(x_ultype) - packed = (x_key.to(x_ultype) << 16) | idx_key - take_eq = mask_n & (x_key == thr_key) - pos = tl.atomic_add(write_count_ptrs, one, mask=take_eq, sem="relaxed", scope="cta") - write_mask = take_eq & (pos < K_PAD) - dst_ptrs = tle.gpu.local_ptr(smem_selected, (pos.to(tl.int32), )) - tl.store(dst_ptrs, packed, mask=write_mask) - - selected_packed = tl.load(smem_selected_ptrs) - - # Stage 4: unpack final packed keys and write outputs. - topk = tl.sort(selected_packed, dim=0, descending=True) - idx_mask = tl.full(topk.shape, (1 << 16) - 1, dtype=topk.dtype) - idx_raw = (topk & idx_mask).to(tl.uint32) - y_indices = (n_cols - idx_raw.to(tl.int32)).to(tl.int32) - y_values_raw = (topk >> 16).to(x_utype) - y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True) - - mask_k = offs_k < K - yv_ptrs = Yv + pid * stride_ym + offs_k - yi_ptrs = Yi + pid * stride_ym + offs_k - tl.store(yv_ptrs, y_values, mask=mask_k) - tl.store(yi_ptrs, y_indices, mask=mask_k) + x = tl.load(x_ptrs, mask=tl.full([BLOCK_N], True, tl.int1), other=float("-inf")) + x_key = fpval_to_key(x.to(x_utype, bitcast=True)) + x_pack = (x_key.to(x_packtype) << 16) | indx_to_key(offs_n).to(x_packtype) + acc = tl.maximum(acc, tl.topk(x_pack, K)) + + # Rotate index-key into high bits, then sort by descending key. 
+ acc = (acc << (packed_nbits - 16)) | (acc >> 16) + acc = tl.sort(acc, descending=True) + + y_indx_key = (acc >> (packed_nbits - 16)).to(tl.uint32) + y_idx = key_to_indx(y_indx_key) + y_val_bits = acc.to(x_utype) + y_vals = key_to_fpval(y_val_bits).to(x_dtype, bitcast=True) + + offs_k = tl.arange(0, K) + tl.store(Yv + pid * stride_ym + offs_k, y_vals) + tl.store(Yi + pid * stride_ym + offs_k, y_idx) def triton_radix_topk( @@ -219,12 +241,6 @@ def triton_radix_topk( n_rows, n_cols = x.shape if k > n_cols: raise ValueError(f"k={k} must be <= N={n_cols}") - if n_cols > 65535: - raise ValueError(f"N={n_cols} too large for 16-bit packed index encoding (max 65535)") - - k_pad = triton.next_power_of_2(k) - if k_pad > 1024: - raise ValueError(f"k={k} too large for radix kernel (K_PAD={k_pad}, max 1024)") if out_vals is None: y_vals = torch.empty((n_rows, k), device=x.device, dtype=x.dtype) @@ -243,8 +259,17 @@ def triton_radix_topk( num_batch = n_rows num_blocks = num_batch - block_n_radix = max(k_pad, min(512, triton.next_power_of_2(n_cols))) + # Tuned heuristic from empirical sweeps: + # - medium/large N prefers BLOCK_N=1024 and higher warp count + # - very small N should avoid over-large BLOCK_N + block_n_radix = max(32, triton.next_power_of_2(n_cols)) block_n_radix = min(block_n_radix, 1024) + if block_n_radix <= 64: + num_warps = 2 + elif block_n_radix <= 128: + num_warps = 4 + else: + num_warps = 8 topk_kernel_radix_triton[(num_blocks, )]( x, y_vals, @@ -253,10 +278,61 @@ def triton_radix_topk( y_vals.stride(0), n_cols, K=k, - K_PAD=k_pad, BLOCK_N=block_n_radix, RADIX_BITS=4, - num_warps=4, + num_warps=num_warps, + num_stages=1, + ) + return y_vals, y_idx + + +def triton_topk( + x: torch.Tensor, + k: int, + out_vals: torch.Tensor | None = None, + out_idx: torch.Tensor | None = None, +): + assert x.is_cuda, "input must be on CUDA" + assert x.ndim == 2, "input must be 2D (M, N)" + n_rows, n_cols = x.shape + if k > n_cols: + raise ValueError(f"k={k} must be <= 
N={n_cols}") + if n_cols > 65535: + raise ValueError(f"triton_topk supports N <= 65535, got N={n_cols}") + + if out_vals is None: + y_vals = torch.empty((n_rows, k), device=x.device, dtype=x.dtype) + else: + y_vals = out_vals + assert y_vals.shape == (n_rows, k) + assert y_vals.dtype == x.dtype + assert y_vals.device == x.device + if out_idx is None: + y_idx = torch.empty((n_rows, k), device=x.device, dtype=torch.int32) + else: + y_idx = out_idx + assert y_idx.shape == (n_rows, k) + assert y_idx.dtype == torch.int32 + assert y_idx.device == x.device + + block_n = max(32, triton.next_power_of_2(min(n_cols, 1024))) + if block_n <= 64: + num_warps = 2 + elif block_n <= 128: + num_warps = 4 + else: + num_warps = 8 + + topk_kernel_streaming_triton[(n_rows, )]( + x, + y_vals, + y_idx, + x.stride(0), + y_vals.stride(0), + n_cols, + K=k, + BLOCK_N=block_n, + num_warps=num_warps, num_stages=1, ) return y_vals, y_idx @@ -276,12 +352,22 @@ def _get_dtype(name: str): def run_correctness(m: int, n: int, k: int, dtype: torch.dtype): torch.manual_seed(0) x = torch.rand((m, n), device=DEVICE, dtype=dtype) - t_vals, _ = torch.topk(x, k, dim=1) - y_vals, y_idx = triton_radix_topk(x, k) - torch.testing.assert_close(y_vals, t_vals, rtol=1e-3, atol=1e-3) - gathered = x.gather(1, y_idx.to(torch.int64)) - torch.testing.assert_close(gathered, y_vals, rtol=1e-3, atol=1e-3) - print("Correctness check passed (radix).") + t_vals, _ = torch.topk(x, k, dim=1, sorted=False) + t_vals_sorted = torch.sort(t_vals, dim=1, descending=True).values + + y_vals_radix, y_idx_radix = triton_radix_topk(x, k) + y_vals_radix_sorted = torch.sort(y_vals_radix, dim=1, descending=True).values + torch.testing.assert_close(y_vals_radix_sorted, t_vals_sorted, rtol=1e-3, atol=1e-3) + gathered_radix = x.gather(1, y_idx_radix.to(torch.int64)) + torch.testing.assert_close(gathered_radix, y_vals_radix, rtol=1e-3, atol=1e-3) + + y_vals_triton, y_idx_triton = triton_topk(x, k) + y_vals_triton_sorted = 
torch.sort(y_vals_triton, dim=1, descending=True).values + torch.testing.assert_close(y_vals_triton_sorted, t_vals_sorted, rtol=1e-3, atol=1e-3) + gathered_triton = x.gather(1, y_idx_triton.to(torch.int64)) + torch.testing.assert_close(gathered_triton, y_vals_triton, rtol=1e-3, atol=1e-3) + + print("Correctness check passed (radix + triton).") if "--only_unit_test" in sys.argv: @@ -290,9 +376,9 @@ def run_correctness(m: int, n: int, k: int, dtype: torch.dtype): run_correctness(_args.batch, _args.seq_len, _args.K, _dtype) sys.exit(0) -_BENCH_PROVIDERS = ["radix", "torch"] -_BENCH_NAMES = ["Triton-RadixSelect", "Torch-TopK"] -_BENCH_STYLES = [("red", "-"), ("orange", "-")] +_BENCH_PROVIDERS = ["radix", "triton", "torch"] +_BENCH_NAMES = ["Triton-RadixSelect", "Triton-TopK", "Torch-TopK"] +_BENCH_STYLES = [("red", "-"), ("blue", "-"), ("orange", "-")] @triton.testing.perf_report( @@ -310,7 +396,7 @@ def run_correctness(m: int, n: int, k: int, dtype: torch.dtype): line_names=_BENCH_NAMES, styles=_BENCH_STYLES, ylabel="ms", - plot_name="tle-topk-radix-vs-torch", + plot_name="tle-topk-radix-vs-triton-vs-torch", args={}, )) def benchmark(M, N, K, provider, dtype): @@ -323,15 +409,21 @@ def benchmark(M, N, K, provider, dtype): quantiles = [0.5, 0.2, 0.8] if provider == "radix": - if N > 65535: - return float("nan"), float("nan"), float("nan") - k_pad = triton.next_power_of_2(K) - if k_pad > 1024: - return float("nan"), float("nan"), float("nan") def run_kernel(): triton_radix_topk(x, K, out_vals=y_vals, out_idx=y_idx) + ms, min_ms, max_ms = triton.testing.do_bench( + run_kernel, + quantiles=quantiles, + warmup=bench_warmup, + rep=bench_rep, + ) + elif provider == "triton": + + def run_kernel(): + triton_topk(x, K, out_vals=y_vals, out_idx=y_idx) + ms, min_ms, max_ms = triton.testing.do_bench( run_kernel, quantiles=quantiles, @@ -341,7 +433,7 @@ def run_kernel(): else: def run_kernel(): - torch.topk(x, K, dim=1) + torch.topk(x, K, dim=1, sorted=False) ms, min_ms, max_ms = 
triton.testing.do_bench( run_kernel, @@ -358,6 +450,7 @@ def main(argv=None): parser.add_argument("--seq_len", type=int, default=1024, help="sequence length") parser.add_argument("--K", type=int, default=2, help="topk") parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) + parser.add_argument("--skip_correctness", action="store_true", help="skip correctness check before benchmark") parser.add_argument("--show_plots", action="store_true", help="show plots in benchmark") args = parser.parse_args(argv) @@ -365,7 +458,8 @@ def main(argv=None): check_m = args.batch check_n = min(args.seq_len, 256) check_k = min(args.K, check_n) - run_correctness(check_m, check_n, check_k, dtype) + if not args.skip_correctness: + run_correctness(check_m, check_n, check_k, dtype) benchmark.run(print_data=True, show_plots=args.show_plots, dtype=dtype) diff --git a/python/tutorials/tle/06-cluster-gemm.py b/python/tutorials/tle/04-cluster-gemm.py similarity index 99% rename from python/tutorials/tle/06-cluster-gemm.py rename to python/tutorials/tle/04-cluster-gemm.py index 761580f1a9..6d2cee7800 100644 --- a/python/tutorials/tle/06-cluster-gemm.py +++ b/python/tutorials/tle/04-cluster-gemm.py @@ -370,8 +370,7 @@ def _verify_remote_lowering( raise RuntimeError(f"unexpected cluster_dims={cluster_dims}, expect (2, 1, 1)") ptx = compiled.asm.get("ptx", "") ttgir = compiled.asm.get("ttgir", "") - has_remote = (("mapa.shared::cluster" in ptx) or ("tle.remote_pointers" in ttgir) or ("tle.remote_cta_id" in ttgir) - or ("tle.remote_shard_id_carrier" in ttgir)) + has_remote = ("mapa.shared::cluster" in ptx) or ("tle.remote_pointers" in ttgir) if not has_remote: raise RuntimeError("remote lowering evidence not found in PTX/TTGIR") diff --git a/python/tutorials/tle/05-deepseek_v32_topk_selector.py b/python/tutorials/tle/05-deepseek_v32_topk_selector.py deleted file mode 100644 index bcd5357141..0000000000 --- 
a/python/tutorials/tle/05-deepseek_v32_topk_selector.py +++ /dev/null @@ -1,894 +0,0 @@ -""" -DeepSeek V3-2 Top-K Selector with Triton and TLE (TLE Tutorial) -============================================================== - -This tutorial adapts the TileLang DeepSeek V3-2 top-k selector example and -implements two kernels: -- A Triton version rewritten with the radix-select flow used in `03-topk.py`. -- A TLE version that keeps the shared-memory DeepSeek-style selector - (`tle.gpu.alloc` + `tle.gpu.local_ptr`). - -If TileLang is installed, the script will also run the original TileLang kernel -and compare correctness and performance. - -Notes ------ -- Input dtype is assumed to be float32 for the 32-bit radix refinement. -- `SMEM_INPUT_SIZE` bounds the number of candidates carried into stage-2. - If the threshold bucket exceeds this size, results are approximate. -""" - -# %% -# Setup -# ----- - -import argparse -from typing import Optional - -import torch -import triton -import triton.language as tl -import triton.experimental.tle.language as tle - -try: - import tilelang - import tilelang.language as T - - _HAVE_TILELANG = True -except Exception: # pragma: no cover - optional dependency - tilelang = None - T = None - _HAVE_TILELANG = False - -DEVICE = triton.runtime.driver.active.get_active_torch_device() -RADIX_BITS = 8 -RADIX = 1 << RADIX_BITS - -# %% -# Key conversions -# --------------- - - -@triton.jit -def _convert_to_uint16(x): - hval = x.to(tl.float16) - bits = hval.to(tl.uint16, bitcast=True) - sign_mask = tl.full(hval.shape, 0x8000, tl.uint16) - bits = tl.where(x < 0, ~bits, bits | sign_mask) - return bits >> 8 - - -@triton.jit -def _convert_to_uint32(x): - bits = x.to(tl.uint32, bitcast=True) - sign_mask = tl.full(bits.shape, 0x80000000, tl.uint32) - bits = tl.where(x < 0, ~bits, bits | sign_mask) - return bits - - -# %% -# Triton kernel (radix-select, based on 03-topk) -# ------------------------------ - - -@triton.jit -def 
triton_topk_selector_kernel( - x_ptr, - out_ptr, - starts_ptr, - ends_ptr, - hist_ptr, - num_ptr, - stride_xm, - stride_xn, - stride_outm, - stride_outn, - stride_hist, - stride_num, - seq_len, - RADIX_BITS: tl.constexpr, - ASSUME_ALIGNED: tl.constexpr, - TOPK: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - row_start = tl.load(starts_ptr + pid).to(tl.int32) - row_end = tl.load(ends_ptr + pid).to(tl.int32) - - row_ptr = x_ptr + pid * stride_xm - out_row = out_ptr + pid * stride_outm - hist_row = hist_ptr + pid * stride_hist - num_row = num_ptr + pid * stride_num - - if ASSUME_ALIGNED: - tl.assume(row_start == 0) - tl.assume(row_end == seq_len) - tl.assume(stride_xn == 1) - tl.assume(stride_outn == 1) - seq_len = tl.multiple_of(seq_len, BLOCK_SIZE) - - lane = tl.arange(0, BLOCK_SIZE) - ones = tl.full([BLOCK_SIZE], 1, tl.int32) - RADIX_SIZE: tl.constexpr = 1 << RADIX_BITS - RADIX_MASK: tl.constexpr = RADIX_SIZE - 1 - hist_idx = tl.arange(0, RADIX_SIZE) - - desired = tl.full((), 0, dtype=tl.uint32) - desired_mask = tl.full((), 0, dtype=tl.uint32) - k_to_find = tl.full((), TOPK, dtype=tl.int32) - n_tiles = tl.cdiv(seq_len, BLOCK_SIZE) - - # MSD radix-select on 32-bit float keys. 
- for digit_pos in tl.static_range(32 - RADIX_BITS, -1, -RADIX_BITS): - tl.store(hist_row + hist_idx, 0) - for t in tl.range(0, n_tiles): - offs = t * BLOCK_SIZE + lane - in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) - x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf")) - x_key = _convert_to_uint32(x) - matches = (x_key & desired_mask) == desired - digit = ((x_key >> digit_pos) & RADIX_MASK).to(tl.int32) - tl.atomic_add(hist_row + digit, ones, mask=in_range & matches) - - counts = tl.load(hist_row + hist_idx) - cumsum_desc = tl.cumsum(counts, axis=0, reverse=True) - tl.store(hist_row + hist_idx, cumsum_desc) - - cond = cumsum_desc >= k_to_find - selected = tl.max(tl.where(cond, hist_idx, 0), axis=0).to(tl.int32) - counts_gt = tl.max(tl.where(hist_idx == (selected + 1), cumsum_desc, 0), axis=0) - - selected_u = selected.to(tl.uint32) - desired = desired | (selected_u << digit_pos) - desired_mask = desired_mask | (tl.full((), RADIX_MASK, dtype=tl.uint32) << digit_pos) - k_to_find = k_to_find - counts_gt - - thr_key = desired - - # Compact candidates: first all keys > threshold, then keys == threshold. 
- tl.store(num_row + tl.arange(0, 2), 0) - num_ptrs = num_row + tl.zeros([BLOCK_SIZE], tl.int32) - - for t in tl.range(0, n_tiles): - offs = t * BLOCK_SIZE + lane - in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) - x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf")) - x_key = _convert_to_uint32(x) - take_gt = in_range & (x_key > thr_key) - pos = tl.atomic_add(num_ptrs, ones, mask=take_gt) - tl.store(out_row + pos * stride_outn, offs.to(tl.int32), mask=take_gt & (pos < TOPK)) - - for t in tl.range(0, n_tiles): - offs = t * BLOCK_SIZE + lane - in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) - x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf")) - x_key = _convert_to_uint32(x) - take_eq = in_range & (x_key == thr_key) - pos = tl.atomic_add(num_ptrs, ones, mask=take_eq) - tl.store(out_row + pos * stride_outn, offs.to(tl.int32), mask=take_eq & (pos < TOPK)) - - -# %% -# TLE kernel (shared memory) -# -------------------------- - - -@triton.jit -def tle_topk_selector_kernel( - x_ptr, - out_ptr, - starts_ptr, - ends_ptr, - stride_xm, - stride_xn, - stride_outm, - stride_outn, - seq_len, - RADIX: tl.constexpr, - HIST_SIZE: tl.constexpr, - ASSUME_ALIGNED: tl.constexpr, - TOPK: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - N_TILES: tl.constexpr, - SMEM_INPUT: tl.constexpr, - NUM_INPUT_TILES: tl.constexpr, -): - pid = tl.program_id(0) - row_start = tl.load(starts_ptr + pid).to(tl.int32) - row_end = tl.load(ends_ptr + pid).to(tl.int32) - - row_ptr = x_ptr + pid * stride_xm - out_row = out_ptr + pid * stride_outm - - if ASSUME_ALIGNED: - tl.assume(row_start == 0) - tl.assume(row_end == seq_len) - tl.assume(stride_xn == 1) - tl.assume(stride_outn == 1) - seq_len = tl.multiple_of(seq_len, BLOCK_SIZE) - - lane = tl.arange(0, BLOCK_SIZE) - ones = tl.full([BLOCK_SIZE], 1, tl.int32) - - s_histogram = tle.gpu.alloc( - [HIST_SIZE], - dtype=tl.int32, - layout=None, - scope=tle.gpu.smem, - 
nv_mma_shared_layout=False, - ) - s_num_input = tle.gpu.alloc( - [2], - dtype=tl.int32, - layout=None, - scope=tle.gpu.smem, - nv_mma_shared_layout=False, - ) - s_input_idx = tle.gpu.alloc( - [2, SMEM_INPUT], - dtype=tl.int32, - layout=None, - scope=tle.gpu.smem, - nv_mma_shared_layout=False, - ) - - hist_idx = tl.arange(0, RADIX) - hist_last = tl.full([1], RADIX, tl.int32) - - hist_ptrs = tle.gpu.local_ptr(s_histogram, (hist_idx, )) - hist_last_ptrs = tle.gpu.local_ptr(s_histogram, (hist_last, )) - tl.store(hist_ptrs, 0) - tl.store(hist_last_ptrs, 0) - tl.store(tle.gpu.local_ptr(s_num_input, (tl.arange(0, 2), )), 0) - tl.debug_barrier() - - l_new_topk = tl.full((), TOPK, tl.int32) - - # stage 1 - for t in tl.static_range(N_TILES): - offs = t * BLOCK_SIZE + lane - in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) - x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=0.0) - bin_u16 = _convert_to_uint16(x) - bin_i32 = bin_u16.to(tl.int32) - hist_bin_ptrs = tle.gpu.local_ptr(s_histogram, (bin_i32, )) - tl.atomic_add(hist_bin_ptrs, ones, mask=in_range) - - rev_idx = (RADIX - 1) - hist_idx - hist_rev = tl.load(tle.gpu.local_ptr(s_histogram, (rev_idx, ))) - hist_cum_rev = tl.cumsum(hist_rev, axis=0) - tl.store(tle.gpu.local_ptr(s_histogram, (rev_idx, )), hist_cum_rev) - tl.debug_barrier() - - hist_cum = tl.load(hist_ptrs) - hist_cum_next = tl.load(tle.gpu.local_ptr(s_histogram, (hist_idx + 1, )), mask=hist_idx + 1 < RADIX, other=0) - cond = (hist_cum > l_new_topk) & (hist_cum_next <= l_new_topk) - cand = tl.where(cond, hist_idx.to(tl.int32), -1) - threshold = tl.max(cand, axis=0) - hist_next = tl.max(tl.where(hist_idx == threshold + 1, hist_cum, 0), axis=0) - l_new_topk = tl.maximum(l_new_topk - hist_next, 0) - - num_ptrs = tle.gpu.local_ptr(s_num_input, (tl.zeros([BLOCK_SIZE], tl.int32), )) - for t in tl.static_range(N_TILES): - offs = t * BLOCK_SIZE + lane - in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) - x = 
tl.load(row_ptr + offs * stride_xn, mask=in_range, other=0.0) - bin_u16 = _convert_to_uint16(x) - bin_i32 = bin_u16.to(tl.int32) - gt_thr = bin_i32 > threshold - eq_thr = bin_i32 == threshold - - pos = tl.atomic_add(tle.gpu.local_ptr(s_histogram, (bin_i32 + 1, )), ones, mask=in_range & gt_thr) - pos = tl.where(in_range & gt_thr, pos, 0) - tl.store(out_row + pos * stride_outn, offs.to(tl.int32), mask=in_range & gt_thr & (pos < TOPK)) - - pos_eq = tl.atomic_add(num_ptrs, ones, mask=in_range & eq_thr & (l_new_topk > 0)) - pos_eq = tl.where(in_range & eq_thr, pos_eq, 0) - tl.store( - tle.gpu.local_ptr(s_input_idx, (tl.zeros([BLOCK_SIZE], tl.int32), pos_eq)), - offs.to(tl.int32), - mask=in_range & eq_thr & (pos_eq < SMEM_INPUT) & (l_new_topk > 0), - ) - - # stage 2 - for round_id in tl.static_range(4): - r_idx = round_id & 1 - next_idx = r_idx ^ 1 - start_pos = TOPK - l_new_topk - - tl.store(hist_ptrs, 0) - tl.store(hist_last_ptrs, 0) - num_ptrs_next = tle.gpu.local_ptr(s_num_input, (tl.full([BLOCK_SIZE], next_idx, tl.int32), )) - tl.store(num_ptrs_next, 0, mask=lane == 0) - tl.debug_barrier() - - num_ptrs_r = tle.gpu.local_ptr(s_num_input, (tl.full([BLOCK_SIZE], r_idx, tl.int32), )) - l_num_input = tl.max(tl.load(num_ptrs_r), axis=0).to(tl.int32) - max_input = tl.full((), SMEM_INPUT, tl.int32) - l_num_input = tl.minimum(l_num_input, max_input) - active = l_new_topk > 0 - - shift = 24 - round_id * 8 - for t in tl.static_range(NUM_INPUT_TILES): - offs = t * BLOCK_SIZE + lane - valid = offs < l_num_input - cand_idx = tl.load( - tle.gpu.local_ptr(s_input_idx, (tl.full([BLOCK_SIZE], r_idx, tl.int32), offs)), - mask=valid, - other=0, - ) - x = tl.load(row_ptr + cand_idx * stride_xn, mask=valid, other=0.0) - bin_u32 = _convert_to_uint32(x) - bin_i32 = ((bin_u32 >> shift) & 0xFF).to(tl.int32) - tl.atomic_add(tle.gpu.local_ptr(s_histogram, (bin_i32, )), ones, mask=valid & active) - - rev_idx = (RADIX - 1) - hist_idx - hist_rev = tl.load(tle.gpu.local_ptr(s_histogram, (rev_idx, 
))) - hist_cum_rev = tl.cumsum(hist_rev, axis=0) - tl.store(tle.gpu.local_ptr(s_histogram, (rev_idx, )), hist_cum_rev) - tl.debug_barrier() - - hist_cum = tl.load(hist_ptrs) - hist_cum_next = tl.load(tle.gpu.local_ptr(s_histogram, (hist_idx + 1, )), mask=hist_idx + 1 < RADIX, other=0) - cond = (hist_cum > l_new_topk) & (hist_cum_next <= l_new_topk) - cand = tl.where(cond, hist_idx.to(tl.int32), -1) - threshold = tl.max(cand, axis=0) - hist_next = tl.max(tl.where(hist_idx == threshold + 1, hist_cum, 0), axis=0) - l_new_topk = tl.maximum(l_new_topk - hist_next, 0) - - for t in tl.static_range(NUM_INPUT_TILES): - offs = t * BLOCK_SIZE + lane - valid = offs < l_num_input - cand_idx = tl.load( - tle.gpu.local_ptr(s_input_idx, (tl.full([BLOCK_SIZE], r_idx, tl.int32), offs)), - mask=valid, - other=0, - ) - x = tl.load(row_ptr + cand_idx * stride_xn, mask=valid, other=0.0) - bin_u32 = _convert_to_uint32(x) - bin_i32 = ((bin_u32 >> shift) & 0xFF).to(tl.int32) - - gt_thr = bin_i32 > threshold - eq_thr = bin_i32 == threshold - pos = tl.atomic_add(tle.gpu.local_ptr(s_histogram, (bin_i32 + 1, )), ones, mask=valid & gt_thr & active) - pos = tl.where(valid & gt_thr & active, pos, 0) - out_pos = pos + start_pos - tl.store( - out_row + out_pos * stride_outn, - cand_idx, - mask=valid & gt_thr & active & (out_pos < TOPK), - ) - - if round_id == 3: - pos_eq = tl.atomic_add( - tle.gpu.local_ptr(s_histogram, (bin_i32 + 1, )), - ones, - mask=valid & eq_thr & active & (l_new_topk > 0), - ) - pos_eq = tl.where(valid & eq_thr & active, pos_eq, 0) - out_pos = pos_eq + start_pos - tl.store( - out_row + out_pos * stride_outn, - cand_idx, - mask=valid & eq_thr & active & (out_pos < TOPK) & (l_new_topk > 0), - ) - else: - num_ptrs = tle.gpu.local_ptr(s_num_input, (tl.full([BLOCK_SIZE], next_idx, tl.int32), )) - pos_eq = tl.atomic_add(num_ptrs, ones, mask=valid & eq_thr & active & (l_new_topk > 0)) - pos_eq = tl.where(valid & eq_thr & active, pos_eq, 0) - tl.store( - 
tle.gpu.local_ptr(s_input_idx, (tl.full([BLOCK_SIZE], next_idx, tl.int32), pos_eq)), - cand_idx, - mask=valid & eq_thr & active & (pos_eq < SMEM_INPUT) & (l_new_topk > 0), - ) - - -# %% -# TileLang reference (optional) -# ----------------------------- - -if _HAVE_TILELANG: - _TL_PASS_CONFIGS = { - tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, - } - _TL_KERNEL_CACHE = {} - - def convert_to_uint16(x): - hval = T.Cast(T.float16, x) - bits_uint = T.reinterpret(T.uint16, hval) - bits_uint = T.if_then_else(x < 0, ~bits_uint & 0xFFFF, bits_uint | 0x8000) - return bits_uint >> 8 - - def convert_to_uint32(x): - bits_uint = T.reinterpret(T.uint32, x) - bits_uint = T.if_then_else( - x < 0, - ~bits_uint & T.Cast(T.uint32, 0xFFFFFFFF), - bits_uint | T.Cast(T.uint32, 0x80000000), - ) - return bits_uint - - @tilelang.jit(pass_configs=_TL_PASS_CONFIGS) - def _tilelang_topk_impl(topk, in_dtype=T.float32, out_dtype=T.int32): - batch = T.dynamic("batch") - seq_len = T.dynamic("seq_len") - RADIX_LOCAL = 1 << 8 - BLOCK_SIZE = 1024 - SMEM_INPUT_SIZE = 4096 - - @T.prim_func - def tl_topk_kernel( - input: T.Tensor[(batch, seq_len), in_dtype], - index: T.Tensor[(batch, topk), out_dtype], - starts: T.Tensor[(batch), out_dtype], - ends: T.Tensor[(batch), out_dtype], - ): - with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): - tx = T.get_thread_binding() - - s_threshold_bin_id = T.alloc_shared([1], T.int32) - s_histogram = T.alloc_shared([RADIX_LOCAL + 1], T.int32) - s_num_input = T.alloc_shared([2], T.int32) - s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], T.int32) - - l_threshold_bin_id = T.alloc_var(T.int32) - l_new_topk = T.alloc_var(T.int32) - l_num_input = T.alloc_var(T.int32) - l_bin_id32 = T.alloc_var(T.int32) - l_val = T.alloc_var(T.int32) - l_start_pos = T.alloc_var(T.int32) - l_start_idx = T.alloc_var(T.int32) - l_end_idx = T.alloc_var(T.int32) - l_out_pos = T.alloc_var(T.int32) - - l_new_topk = topk - l_start_idx = starts[bx] - l_end_idx = ends[bx] - - 
T.fill(s_histogram, 0) - T.fill(s_num_input[0], 0) - T.sync_threads() - for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): - input_idx = s * BLOCK_SIZE + tx - if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: - inval_int16 = convert_to_uint16(input[bx, input_idx]) - T.atomic_add(s_histogram[inval_int16], 1) - T.sync_threads() - - if tx < RADIX_LOCAL: - for i in T.serial(8): - offset = 1 << i - T.sync_threads(3, RADIX_LOCAL) - if tx < RADIX_LOCAL - offset: - l_val = s_histogram[tx] + s_histogram[tx + offset] - T.sync_threads(3, RADIX_LOCAL) - if tx < RADIX_LOCAL - offset: - s_histogram[tx] = l_val - - T.sync_threads(3, RADIX_LOCAL) - if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: - s_threshold_bin_id[0] = tx - T.sync_threads() - l_threshold_bin_id = s_threshold_bin_id[0] - l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] - T.sync_threads() - - for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): - T.sync_threads() - input_idx = s * BLOCK_SIZE + tx - if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: - bin_id = convert_to_uint16(input[bx, input_idx]) - l_bin_id32 = T.Cast(T.int32, bin_id) - if l_bin_id32 > l_threshold_bin_id: - pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) - index[bx, pos] = input_idx - elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: - pos = T.atomic_add(s_num_input[0], 1, return_prev=True) - s_input_idx[0, pos] = input_idx - - for round in T.serial(4): - if l_new_topk <= 0: - T.loop_break() - - r_idx = round % 2 - l_start_pos = topk - l_new_topk - - T.sync_threads() - T.fill(s_histogram, 0) - if tx == 0: - s_num_input[r_idx ^ 1] = 0 - T.sync_threads() - - l_num_input = s_num_input[r_idx] - for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): - if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast( - T.int32, - ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF), - ) - 
T.atomic_add(s_histogram[l_bin_id32], 1) - T.sync_threads() - - if tx < RADIX_LOCAL: - for i in T.serial(8): - offset = 1 << i - T.sync_threads(3, RADIX_LOCAL) - if tx < RADIX_LOCAL - offset: - l_val = s_histogram[tx] + s_histogram[tx + offset] - T.sync_threads(3, RADIX_LOCAL) - if tx < RADIX_LOCAL - offset: - s_histogram[tx] = l_val - - T.sync_threads(3, RADIX_LOCAL) - if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: - s_threshold_bin_id[0] = tx - T.sync_threads() - - l_threshold_bin_id = s_threshold_bin_id[0] - l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] - T.sync_threads() - - for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): - T.sync_threads() - if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast( - T.int32, - ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF), - ) - if l_bin_id32 > l_threshold_bin_id: - pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos - index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] - elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: - if round == 3: - l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, - return_prev=True) + l_start_pos - if l_out_pos < topk: - index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] - else: - pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True) - s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] - - return tl_topk_kernel - - def tilelang_topk_selector(input, starts, ends, topk, out: Optional[torch.Tensor] = None): - batch, _ = input.shape - if out is None: - out = torch.zeros((batch, topk), dtype=torch.int32, device=input.device) - kernel = _TL_KERNEL_CACHE.get(topk) - if kernel is None: - kernel = _tilelang_topk_impl(topk) - _TL_KERNEL_CACHE[topk] = kernel - kernel(input, out, starts, ends) - return out - - -# %% -# Python wrappers -# --------------- - - -def _allocate_triton_scratch(batch, smem_input, device): - 
hist = torch.empty((batch, RADIX + 1), dtype=torch.int32, device=device) - num = torch.empty((batch, 2), dtype=torch.int32, device=device) - return hist, num - - -def triton_topk_selector( - x, - starts, - ends, - topk, - block_size=1024, - num_warps=32, - smem_input=4096, - out: Optional[torch.Tensor] = None, - scratch=None, - assume_aligned: Optional[bool] = None, -): - if x.dtype != torch.float32: - x = x.float() - batch, seq_len = x.shape - if out is None: - out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) - if scratch is None: - scratch = _allocate_triton_scratch(batch, smem_input, x.device) - if len(scratch) == 3: - hist, num, _ = scratch - else: - hist, num = scratch - - if assume_aligned is None: - assume_aligned = (x.is_contiguous() and out.is_contiguous() and (seq_len % block_size == 0) - and torch.all(starts == 0).item() and torch.all(ends == seq_len).item()) - - grid = (batch, ) - triton_topk_selector_kernel[grid]( - x, - out, - starts, - ends, - hist, - num, - x.stride(0), - x.stride(1), - out.stride(0), - out.stride(1), - hist.stride(0), - num.stride(0), - seq_len, - RADIX_BITS=8, - ASSUME_ALIGNED=assume_aligned, - TOPK=topk, - BLOCK_SIZE=block_size, - num_warps=num_warps, - ) - return out - - -def tle_topk_selector( - x, - starts, - ends, - topk, - block_size=1024, - num_warps=32, - smem_input=4096, - out: Optional[torch.Tensor] = None, - assume_aligned: Optional[bool] = None, -): - if x.dtype != torch.float32: - x = x.float() - batch, seq_len = x.shape - if out is None: - out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) - - n_tiles = triton.cdiv(seq_len, block_size) - num_input_tiles = triton.cdiv(smem_input, block_size) - hist_size = RADIX * 2 - if assume_aligned is None: - assume_aligned = (x.is_contiguous() and out.is_contiguous() and (seq_len % block_size == 0) - and torch.all(starts == 0).item() and torch.all(ends == seq_len).item()) - - grid = (batch, ) - tle_topk_selector_kernel[grid]( - x, - out, 
- starts, - ends, - x.stride(0), - x.stride(1), - out.stride(0), - out.stride(1), - seq_len, - RADIX=RADIX, - HIST_SIZE=hist_size, - ASSUME_ALIGNED=assume_aligned, - TOPK=topk, - BLOCK_SIZE=block_size, - N_TILES=n_tiles, - SMEM_INPUT=smem_input, - NUM_INPUT_TILES=num_input_tiles, - num_warps=num_warps, - ) - return out - - -# %% -# Correctness & benchmarking -# -------------------------- - - -def _torch_topk_indices(x, starts, ends, topk): - batch, _ = x.shape - out = torch.empty((batch, topk), dtype=torch.int32, device=x.device) - for i in range(batch): - start = int(starts[i].item()) - end = int(ends[i].item()) - vals, idx = torch.topk(x[i, start:end], topk, dim=0) - out[i] = idx.to(torch.int32) + start - return out - - -def _recall(pred, ref): - batch = pred.shape[0] - k = ref.shape[1] - hits = 0 - for i in range(batch): - pred_set = set(pred[i].tolist()) - ref_set = set(ref[i].tolist()) - hits += len(pred_set & ref_set) - return hits / (batch * k) - - -_BENCH_PROVIDERS = ["triton", "tle", "torch"] + (["tilelang"] if _HAVE_TILELANG else []) -_BENCH_NAMES = ["Triton-Radix", "TLE-DeepSeek", "Torch-TopK"] + (["TileLang"] if _HAVE_TILELANG else []) -_BENCH_STYLES = [("red", "-"), ("orange", "-"), ("green", "-")] + ([("blue", "-")] if _HAVE_TILELANG else []) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch", "seq_len", "topk"], - x_vals=[ - (64, 4096, 128), - (64, 8192, 256), - (64, 32768, 1024), - (64, 32768, 2048), - ], - x_log=True, - line_arg="provider", - line_vals=_BENCH_PROVIDERS, - line_names=_BENCH_NAMES, - styles=_BENCH_STYLES, - ylabel="ms", - plot_name="tle-deepseek-v32-topk-selector", - args={}, - )) -def benchmark(batch, seq_len, topk, provider, block_size, smem_input, num_warps, warmup, rep): - if topk > smem_input: - return float("nan"), float("nan"), float("nan") - - torch.manual_seed(1) - x = torch.randn(batch, seq_len, device=DEVICE, dtype=torch.float32) - starts = torch.zeros(batch, dtype=torch.int32, device=DEVICE) 
- ends = torch.full((batch, ), seq_len, dtype=torch.int32, device=DEVICE) - assume_aligned = (seq_len % block_size == 0) - quantiles = [0.5, 0.2, 0.8] - - if provider == "triton": - triton_scratch = _allocate_triton_scratch(batch, smem_input, x.device) - triton_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) - - def run(): - triton_topk_selector( - x, - starts, - ends, - topk, - block_size=block_size, - num_warps=num_warps, - smem_input=smem_input, - out=triton_out, - scratch=triton_scratch, - assume_aligned=assume_aligned, - ) - - elif provider == "tle": - tle_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) - - def run(): - tle_topk_selector( - x, - starts, - ends, - topk, - block_size=block_size, - num_warps=num_warps, - smem_input=smem_input, - out=tle_out, - assume_aligned=assume_aligned, - ) - - elif provider == "torch": - - def run(): - torch.topk(x, topk, dim=-1)[1] - - else: - if not _HAVE_TILELANG: - return float("nan"), float("nan"), float("nan") - tilelang_out = torch.zeros((batch, topk), dtype=torch.int32, device=x.device) - - def run(): - tilelang_topk_selector(x, starts, ends, topk, out=tilelang_out) - - ms, min_ms, max_ms = triton.testing.do_bench( - run, - quantiles=quantiles, - warmup=warmup, - rep=rep, - ) - return ms, max_ms, min_ms - - -def run_correctness(batch, seq_len, topk, block_size, smem_input, num_warps): - torch.manual_seed(1) - x = torch.randn(batch, seq_len, device=DEVICE, dtype=torch.float32) - starts = torch.zeros(batch, dtype=torch.int32, device=DEVICE) - ends = torch.full((batch, ), seq_len, dtype=torch.int32, device=DEVICE) - assume_aligned = (seq_len % block_size == 0) - - ref = _torch_topk_indices(x, starts, ends, topk) - - triton_out = triton_topk_selector( - x, - starts, - ends, - topk, - block_size=block_size, - num_warps=num_warps, - smem_input=smem_input, - assume_aligned=assume_aligned, - ) - tle_out = tle_topk_selector( - x, - starts, - ends, - topk, - 
block_size=block_size, - num_warps=num_warps, - smem_input=smem_input, - assume_aligned=assume_aligned, - ) - - print(f"Triton recall vs torch.topk: {_recall(triton_out, ref):.4f}") - print(f"TLE recall vs torch.topk: {_recall(tle_out, ref):.4f}") - - if _HAVE_TILELANG: - tilelang_out = tilelang_topk_selector(x, starts, ends, topk) - print(f"TileLang recall vs torch.topk: {_recall(tilelang_out, ref):.4f}") - print(f"Triton recall vs TileLang: {_recall(triton_out, tilelang_out):.4f}") - print(f"TLE recall vs TileLang: {_recall(tle_out, tilelang_out):.4f}") - else: - print("TileLang not available; skipping TileLang correctness.") - - -def run_bench(block_size, smem_input, num_warps, warmup, rep, show_plots): - benchmark.run( - print_data=True, - show_plots=show_plots, - block_size=block_size, - smem_input=smem_input, - num_warps=num_warps, - warmup=warmup, - rep=rep, - ) - - -# %% -# Main -# ---- - - -def main(argv=None): - parser = argparse.ArgumentParser() - parser.add_argument("--batch", type=int, default=64, help="batch size") - parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") - parser.add_argument("--topk", type=int, default=128, help="top-k") - parser.add_argument("--block_size", type=int, default=1024, help="block size (threads)") - parser.add_argument("--smem_input", type=int, default=4096, help="candidate buffer size") - parser.add_argument("--num_warps", type=int, default=32, help="num warps") - parser.add_argument("--warmup", type=int, default=5, help="warmup iters") - parser.add_argument("--rep", type=int, default=20, help="benchmark iters") - parser.add_argument("--show_plots", action="store_true", help="show plots in benchmark") - parser.add_argument("--skip_correctness", action="store_true", help="skip correctness check") - parser.add_argument("--skip_bench", action="store_true", help="skip benchmark") - args = parser.parse_args(argv) - - if args.topk > args.smem_input: - raise ValueError("topk must be <= smem_input to 
avoid truncation") - - if not args.skip_correctness: - run_correctness( - batch=args.batch, - seq_len=args.seq_len, - topk=args.topk, - block_size=args.block_size, - smem_input=args.smem_input, - num_warps=args.num_warps, - ) - - if not args.skip_bench: - run_bench( - block_size=args.block_size, - smem_input=args.smem_input, - num_warps=args.num_warps, - warmup=args.warmup, - rep=args.rep, - show_plots=args.show_plots, - ) - - -if __name__ == "__main__": - main() diff --git a/python/tutorials/tle/deepseek_v32/01-topk_selector.py b/python/tutorials/tle/deepseek_v32/01-topk_selector.py new file mode 100644 index 0000000000..85302d1200 --- /dev/null +++ b/python/tutorials/tle/deepseek_v32/01-topk_selector.py @@ -0,0 +1,4121 @@ +""" +DeepSeek V3-2 Top-K Selector with Triton and TLE (TLE Tutorial) +============================================================== + +This tutorial adapts the TileLang DeepSeek V3-2 top-k selector example and +implements a TLE kernel: +- A TRT-style 4-step selector in TLE: + step0 (fp16-mapped 11-bit) + step1/2/3 (uint32-mapped 11/11/10-bit) + with shared-memory histogram/final-sort flow. + +If TileLang is installed, the script will also run the original TileLang kernel +and compare correctness and performance. + +Notes +----- +- Input dtype is assumed to be float32 for the 32-bit radix refinement. 
+""" + +# %% +# Setup +# ----- + +import argparse +import hashlib +import urllib.request +from functools import lru_cache +from typing import Optional + +import torch +import triton +import triton.language as tl +import triton.experimental.tle.language as tle + +try: + import tilelang + import tilelang.language as T + + _HAVE_TILELANG = True +except Exception: # pragma: no cover - optional dependency + tilelang = None + T = None + _HAVE_TILELANG = False + +DEVICE = triton.runtime.driver.active.get_active_torch_device() +RADIX_BITS = 8 +RADIX = 1 << RADIX_BITS +TLE_FIXED_BLOCK_SIZE = 512 +TLE_FIXED_NUM_WARPS = TLE_FIXED_BLOCK_SIZE // 32 +TLE_FIXED_NUM_STAGES = 1 +TLE_RADIX_FINAL_SEQ_LEN_THRESHOLD = 12288 +TLE_SMEM_BLOCK_SIZE = 1024 +TLE_SMEM_NUM_WARPS = TLE_SMEM_BLOCK_SIZE // 32 +TLE_SMEM_NUM_STAGES = 1 +TLE_SMEM_INPUT_SIZE = 4096 +TLE_SMEM_CLUSTER_SIZE = 8 +BLOCK_CLUSTER_MESH_8 = tle.device_mesh({"block_cluster": [("cluster_x", TLE_SMEM_CLUSTER_SIZE)]}) + +# %% +# Key conversions +# --------------- + + +@triton.jit +def _convert_to_uint32(x): + bits = x.to(tl.uint32, bitcast=True) + sign_mask = tl.full(bits.shape, 0x80000000, tl.uint32) + bits = tl.where(x < 0, ~bits, bits | sign_mask) + return bits + + +@triton.jit +def _convert_to_uint16_hi8(x): + h = x.to(tl.float16) + bits = h.to(tl.uint16, bitcast=True) + sign_mask = tl.full(bits.shape, 0x8000, tl.uint16) + bits = tl.where(x < 0, ~bits, bits | sign_mask) + return (bits >> 8).to(tl.int32) + + +@triton.jit +def _convert_to_uint16_hi11(x): + h = x.to(tl.float16) + bits = h.to(tl.uint16, bitcast=True) + sign_mask = tl.full(bits.shape, 0x8000, tl.uint16) + bits = tl.where(x < 0, ~bits, bits | sign_mask) + return (bits >> 5).to(tl.int32) + + +@triton.jit +def _convert_to_trt_uint32(x): + bits = x.to(tl.uint32, bitcast=True) + sign_mask = tl.full(bits.shape, 0x80000000, tl.uint32) + sign_set = (bits & sign_mask) != 0 + inv = (~bits) & tl.full(bits.shape, 0x7FFFFFFF, tl.uint32) + return tl.where(sign_set, bits, inv) + 
+ +@triton.jit +def _convert_to_trt_uint16_hi11(x): + h = x.to(tl.float16) + bits = h.to(tl.uint16, bitcast=True) + sign_mask = tl.full(bits.shape, 0x8000, tl.uint16) + sign_set = (bits & sign_mask) != 0 + inv = (~bits) & tl.full(bits.shape, 0x7FFF, tl.uint16) + mapped = tl.where(sign_set, bits, inv) + return (mapped >> 5).to(tl.int32) + + +@triton.jit +def processHistogramStep( + row_ptr, + stride_xn, + row_start, + row_end, + seq_len, + step_idx: tl.constexpr, + logit_pattern, + s_step_thresholds_ptr, + found_topk_values, + hist_base_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + s_threshold_bin_idx_ptr, + s_final_bin_size_ptr, + ASSUME_ALIGNED: tl.constexpr, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + VEC: tl.constexpr = 4 + FINAL_SORT_ITEMS: tl.constexpr = 2048 + RADIX11_SIZE: tl.constexpr = 2048 + RADIX11_MASK: tl.constexpr = 0x7FF + RADIX10_SIZE: tl.constexpr = 1024 + RADIX10_MASK: tl.constexpr = 0x3FF + + lane = tl.arange(0, BLOCK_SIZE) + vec = tl.arange(0, VEC) + ones = tl.full([BLOCK_SIZE], 1, tl.int32) + ones_vec_2d = tl.full([BLOCK_SIZE, VEC], 1, tl.int32) + zeros = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + zeros_vec_2d = tl.zeros([BLOCK_SIZE, VEC], dtype=tl.int32) + + for clear_round in tl.range(0, RADIX11_SIZE // BLOCK_SIZE): + clear_bins = clear_round * BLOCK_SIZE + lane + tl.store(hist_base_ptr + clear_bins, 0) + tl.debug_barrier() + + if step_idx == 2: + step1_threshold = tl.load(s_step_thresholds_ptr + 1) + logit_pattern = (step1_threshold.to(tl.uint32) & RADIX11_MASK) << 21 + elif step_idx == 3: + step1_threshold = tl.load(s_step_thresholds_ptr + 1) + step2_threshold = tl.load(s_step_thresholds_ptr + 2) + logit_pattern = ((step1_threshold.to(tl.uint32) & RADIX11_MASK) << 21) | ( + (step2_threshold.to(tl.uint32) & RADIX11_MASK) << 10) + + n_tiles = tl.cdiv(seq_len, BLOCK_SIZE) + n_vec_full = seq_len // (BLOCK_SIZE * VEC) + rem_tiles = (seq_len - n_vec_full * BLOCK_SIZE * VEC) // BLOCK_SIZE + + if 
ASSUME_ALIGNED: + for t in tl.range(0, n_vec_full): + base = t * BLOCK_SIZE * VEC + lane * VEC + offs = base[:, None] + vec[None, :] + x_vec = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x_vec) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x_vec) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = tl.full([BLOCK_SIZE, VEC], True, tl.int1) + elif step_idx == 2: + partial = ((key ^ logit_pattern) >> 21) == 0 + else: + partial = ((key ^ logit_pattern) >> 10) == 0 + + tl.atomic_add( + hist_base_ptr + digit, + ones_vec_2d, + mask=partial, + sem="relaxed", + scope="cta", + ) + for t in tl.range(0, rem_tiles): + offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane + x = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = tl.full([BLOCK_SIZE], True, tl.int1) + elif step_idx == 2: + partial = ((key ^ logit_pattern) >> 21) == 0 + else: + partial = ((key ^ logit_pattern) >> 10) == 0 + + tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=partial, + sem="relaxed", + scope="cta", + ) + else: + for t in tl.range(0, n_tiles): + offs = t * BLOCK_SIZE + lane + in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf")) + key = _convert_to_trt_uint32(x) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & 
RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = in_range + elif step_idx == 2: + partial = in_range & (((key ^ logit_pattern) >> 21) == 0) + else: + partial = in_range & (((key ^ logit_pattern) >> 10) == 0) + + tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=partial, + sem="relaxed", + scope="cta", + ) + tl.debug_barrier() + + # TRT-style threshold search with per-round early-exit. + tl.store(s_threshold_bin_idx_ptr, -1) + tl.store(s_final_bin_size_ptr, 0) + tl.debug_barrier() + threshold_bin_ptrs = s_threshold_bin_idx_ptr + zeros + final_bin_size_ptrs = s_final_bin_size_ptr + zeros + last_value = found_topk_values + threshold_found = False + threshold_rounds = tl.where( + step_idx == 3, + RADIX10_SIZE // BLOCK_SIZE, + RADIX11_SIZE // BLOCK_SIZE, + ) + for round_idx in tl.range(0, threshold_rounds): + if not threshold_found: + bins = round_idx * BLOCK_SIZE + lane + counts = tl.load(hist_base_ptr + bins) + prefix_sum, counts_total = tle.cumsum(counts, axis=0, reverse=False) + prefix_sum = prefix_sum + last_value + total_sum = last_value + counts_total + next_prefix_sum = prefix_sum + counts + threshold_mask = (prefix_sum < TOPK) & (next_prefix_sum >= TOPK) + threshold_bin = bins + threshold_bin_size = next_prefix_sum - prefix_sum + tl.store(threshold_bin_ptrs, threshold_bin, mask=threshold_mask) + tl.store(final_bin_size_ptrs, threshold_bin_size, mask=threshold_mask) + found_round = tl.reduce_or(threshold_mask, axis=0) + threshold_found = found_round + last_value = total_sum + + threshold_bin_idx = tl.load(s_threshold_bin_idx_ptr) + final_bin_size = tl.load(s_final_bin_size_ptr) + tl.store(s_step_thresholds_ptr + step_idx, threshold_bin_idx) + + use_final = (step_idx < 3) & (threshold_bin_idx >= 0) & (final_bin_size <= FINAL_SORT_ITEMS) + if use_final: + tl.store(s_final_cnt_ptr, 0) + + found_ptrs = s_found_topk_values_ptr + zeros + final_cnt_ptrs = s_final_cnt_ptr + zeros + if ASSUME_ALIGNED: + found_ptrs_vec_2d = s_found_topk_values_ptr + 
zeros_vec_2d + final_cnt_ptrs_vec_2d = s_final_cnt_ptr + zeros_vec_2d + for t in tl.range(0, n_vec_full): + base = t * BLOCK_SIZE * VEC + lane * VEC + offs = base[:, None] + vec[None, :] + x_vec = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x_vec) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x_vec) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = tl.full([BLOCK_SIZE, VEC], True, tl.int1) + elif step_idx == 2: + partial = ((key ^ logit_pattern) >> 21) == 0 + else: + partial = ((key ^ logit_pattern) >> 10) == 0 + + take_lt = partial & (digit < threshold_bin_idx) + out_pos_lt = tl.atomic_add( + found_ptrs_vec_2d, + ones_vec_2d, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_lt, + offs.to(tl.int32), + mask=take_lt & (out_pos_lt < TOPK), + ) + + if step_idx == 3: + take_eq = partial & (digit == threshold_bin_idx) + out_pos_eq = tl.atomic_add( + hist_base_ptr + digit, + ones_vec_2d, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_eq, + offs.to(tl.int32), + mask=take_eq & (out_pos_eq < TOPK), + ) + elif use_final: + take_eq_final = partial & (digit == threshold_bin_idx) + final_pos = tl.atomic_add( + final_cnt_ptrs_vec_2d, + ones_vec_2d, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + hist_base_ptr + final_pos, + offs.to(tl.int32), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + hist_base_ptr + (FINAL_SORT_ITEMS + final_pos), + x_vec.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + + for t in tl.range(0, rem_tiles): + offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane + x = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x) 
+ elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = tl.full([BLOCK_SIZE], True, tl.int1) + elif step_idx == 2: + partial = ((key ^ logit_pattern) >> 21) == 0 + else: + partial = ((key ^ logit_pattern) >> 10) == 0 + + take_lt = partial & (digit < threshold_bin_idx) + out_pos_lt = tl.atomic_add( + found_ptrs, + ones, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_lt, + offs.to(tl.int32), + mask=take_lt & (out_pos_lt < TOPK), + ) + + if step_idx == 3: + take_eq = partial & (digit == threshold_bin_idx) + out_pos_eq = tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_eq, + offs.to(tl.int32), + mask=take_eq & (out_pos_eq < TOPK), + ) + elif use_final: + take_eq_final = partial & (digit == threshold_bin_idx) + final_pos = tl.atomic_add( + final_cnt_ptrs, + ones, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + hist_base_ptr + final_pos, + offs.to(tl.int32), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + hist_base_ptr + (FINAL_SORT_ITEMS + final_pos), + x.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + else: + for t in tl.range(0, n_tiles): + offs = t * BLOCK_SIZE + lane + in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf")) + key = _convert_to_trt_uint32(x) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = in_range + elif step_idx == 2: + 
partial = in_range & (((key ^ logit_pattern) >> 21) == 0) + else: + partial = in_range & (((key ^ logit_pattern) >> 10) == 0) + + take_lt = partial & (digit < threshold_bin_idx) + out_pos_lt = tl.atomic_add( + found_ptrs, + ones, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_lt, + offs.to(tl.int32), + mask=take_lt & (out_pos_lt < TOPK), + ) + + if step_idx == 3: + take_eq = partial & (digit == threshold_bin_idx) + out_pos_eq = tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_eq, + offs.to(tl.int32), + mask=take_eq & (out_pos_eq < TOPK), + ) + elif use_final: + take_eq_final = partial & (digit == threshold_bin_idx) + final_pos = tl.atomic_add( + final_cnt_ptrs, + ones, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + hist_base_ptr + final_pos, + offs.to(tl.int32), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + hist_base_ptr + (FINAL_SORT_ITEMS + final_pos), + x.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + + if step_idx < 3: + if use_final: + need_final_sort = True + continue_to_next_step = False + else: + need_final_sort = False + continue_to_next_step = True + else: + tl.store(s_found_topk_values_ptr, TOPK) + need_final_sort = False + continue_to_next_step = False + + tl.debug_barrier() + return continue_to_next_step, need_final_sort, logit_pattern + + +@triton.jit +def _tle_final_select_radix( + hist_base_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + FINAL_SORT_ITEMS: tl.constexpr, +): + RADIX_BITS_FINAL: tl.constexpr = 8 + RADIX_SIZE_FINAL: tl.constexpr = 1 << RADIX_BITS_FINAL + RADIX_MASK_FINAL: tl.constexpr = RADIX_SIZE_FINAL - 1 + DIGIT_START: tl.constexpr = 32 - RADIX_BITS_FINAL + + lane = tl.arange(0, BLOCK_SIZE) + ones = tl.full([BLOCK_SIZE], 1, 
tl.int32) + zeros = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + bins = tl.arange(0, RADIX_SIZE_FINAL) + + s_radix_counts = tle.gpu.alloc( + [RADIX_SIZE_FINAL], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + radix_count_ptr = tle.gpu.local_ptr(s_radix_counts, (0, )) + radix_count_vec_ptr = tle.gpu.local_ptr(s_radix_counts, (bins, )) + + base_idx = tl.load(s_found_topk_values_ptr) + final_cnt = tl.minimum(tl.load(s_final_cnt_ptr), FINAL_SORT_ITEMS) + remain = tl.minimum(TOPK - base_idx, final_cnt) + if remain > 0: + desired = tl.zeros((), dtype=tl.uint32) + desired_mask = tl.zeros((), dtype=tl.uint32) + k_to_find = remain + 1 + + for digit_pos in tl.static_range(DIGIT_START, -1, -RADIX_BITS_FINAL): + tl.store(radix_count_ptr + lane, 0, mask=lane < RADIX_SIZE_FINAL) + tl.debug_barrier() + + cnt_tiles = tl.cdiv(final_cnt, BLOCK_SIZE) + for t in tl.range(0, cnt_tiles): + pos = t * BLOCK_SIZE + lane + valid = pos < final_cnt + x_bits_i32 = tl.load( + hist_base_ptr + (FINAL_SORT_ITEMS + pos), + mask=valid, + other=0, + ) + x = x_bits_i32.to(tl.float32, bitcast=True) + key = _convert_to_trt_uint32(x) + matches = (key & desired_mask) == desired + digit = ((key >> digit_pos) & RADIX_MASK_FINAL).to(tl.int32) + take = valid & matches + tl.atomic_add( + radix_count_ptr + digit, + ones, + mask=take, + sem="relaxed", + scope="cta", + ) + + tl.debug_barrier() + counts = tl.load(radix_count_vec_ptr) + prefix_sum, _ = tle.cumsum(counts, axis=0, reverse=False) + next_prefix_sum = prefix_sum + counts + threshold_mask = (prefix_sum < k_to_find) & (next_prefix_sum >= k_to_find) + threshold_init = tl.full((), RADIX_SIZE_FINAL, dtype=tl.int32) + threshold_bin = tl.min(tl.where(threshold_mask, bins, threshold_init), axis=0).to(tl.int32) + threshold_bin = tl.where(threshold_bin == RADIX_SIZE_FINAL, RADIX_SIZE_FINAL - 1, threshold_bin) + counts_lt = tl.max(tl.where(bins == threshold_bin, prefix_sum, 0), axis=0).to(tl.int32) + + desired = desired | 
(threshold_bin.to(tl.uint32) << digit_pos) + desired_mask = desired_mask | (tl.full((), RADIX_MASK_FINAL, dtype=tl.uint32) << digit_pos) + k_to_find = k_to_find - counts_lt + + thr_key = desired + found_ptrs = s_found_topk_values_ptr + zeros + cnt_tiles = tl.cdiv(final_cnt, BLOCK_SIZE) + for t in tl.range(0, cnt_tiles): + pos = t * BLOCK_SIZE + lane + valid = pos < final_cnt + idx = tl.load(hist_base_ptr + pos, mask=valid, other=0) + x_bits_i32 = tl.load( + hist_base_ptr + (FINAL_SORT_ITEMS + pos), + mask=valid, + other=0, + ) + x = x_bits_i32.to(tl.float32, bitcast=True) + key = _convert_to_trt_uint32(x) + take_lt = valid & (key < thr_key) + out_pos_gt = tl.atomic_add( + found_ptrs, + ones, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_gt, + idx, + mask=take_lt & (out_pos_gt < TOPK), + ) + + cur = tl.load(s_found_topk_values_ptr) + if cur < TOPK: + for t in tl.range(0, cnt_tiles): + cur = tl.load(s_found_topk_values_ptr) + if cur < TOPK: + pos = t * BLOCK_SIZE + lane + valid = pos < final_cnt + idx = tl.load(hist_base_ptr + pos, mask=valid, other=0) + x_bits_i32 = tl.load( + hist_base_ptr + (FINAL_SORT_ITEMS + pos), + mask=valid, + other=0, + ) + x = x_bits_i32.to(tl.float32, bitcast=True) + key = _convert_to_trt_uint32(x) + take_eq = valid & (key == thr_key) + out_pos_eq = tl.atomic_add( + found_ptrs, + ones, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_eq, + idx, + mask=take_eq & (out_pos_eq < TOPK), + ) + + tl.store(s_found_topk_values_ptr, TOPK) + + +# %% +# TLE kernel (shared memory) +# -------------------------- + + +@triton.jit +def tle_topk_selector_kernel( + x_ptr, + out_ptr, + starts_ptr, + ends_ptr, + stride_xm, + stride_xn, + stride_outm, + stride_outn, + seq_len, + ASSUME_ALIGNED: tl.constexpr, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + USE_RADIX_FINAL: tl.constexpr, +): + pid = tl.program_id(0) + row_start = tl.load(starts_ptr + pid).to(tl.int32) 
+ row_end = tl.load(ends_ptr + pid).to(tl.int32) + + row_ptr = x_ptr + pid * stride_xm + out_row = out_ptr + pid * stride_outm + row_len = row_end - row_start + + if ASSUME_ALIGNED: + tl.assume(row_start == 0) + tl.assume(row_end == seq_len) + tl.assume(stride_xn == 1) + tl.assume(stride_outn == 1) + seq_len = tl.multiple_of(seq_len, BLOCK_SIZE) + + lane = tl.arange(0, BLOCK_SIZE) + if row_len <= TOPK: + chunks: tl.constexpr = (TOPK + BLOCK_SIZE - 1) // BLOCK_SIZE + for chunk_idx in tl.range(0, chunks): + pos = chunk_idx * BLOCK_SIZE + lane + take_row = pos < row_len + tl.store(out_row + pos * stride_outn, (row_start + pos).to(tl.int32), mask=take_row) + take_pad = (pos >= row_len) & (pos < TOPK) + tl.store(out_row + pos * stride_outn, -1, mask=take_pad) + return + + FINAL_SORT_ITEMS: tl.constexpr = 2048 + HIST_SIZE: tl.constexpr = 4096 + + s_histogram = tle.gpu.alloc( + [HIST_SIZE], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + hist_base_ptr = tle.gpu.local_ptr(s_histogram, (0, )) + # TRT-style union reuse: + # - [0, FINAL_SORT_ITEMS): final indices (int32) + # - [FINAL_SORT_ITEMS, 2*FINAL_SORT_ITEMS): final logits bitcast(int32) + s_out_indices = tle.gpu.alloc( + [TOPK], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_final_cnt = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_threshold_bin_idx = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_final_bin_size = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_found_topk_values = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_step_thresholds = tle.gpu.alloc( + [4], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + 
s_final_cnt_ptr = tle.gpu.local_ptr(s_final_cnt, (0, )) + s_threshold_bin_idx_ptr = tle.gpu.local_ptr(s_threshold_bin_idx, (0, )) + s_final_bin_size_ptr = tle.gpu.local_ptr(s_final_bin_size, (0, )) + s_found_topk_values_ptr = tle.gpu.local_ptr(s_found_topk_values, (0, )) + s_step_thresholds_ptr = tle.gpu.local_ptr(s_step_thresholds, (0, )) + s_out_indices_ptr = tle.gpu.local_ptr(s_out_indices, (0, )) + tl.store(s_final_cnt_ptr, 0) + tl.store(s_threshold_bin_idx_ptr, -1) + tl.store(s_final_bin_size_ptr, 0) + tl.store(s_found_topk_values_ptr, 0) + + logit_pattern = tl.zeros((), dtype=tl.uint32) + continue_to_next_step = True + need_final_sort = False + init_chunks: tl.constexpr = (TOPK + BLOCK_SIZE - 1) // BLOCK_SIZE + for init_idx in tl.range(0, init_chunks): + pos = init_idx * BLOCK_SIZE + lane + tl.store(tle.gpu.local_ptr(s_out_indices, (pos, )), -1, mask=pos < TOPK) + + for step_idx in tl.static_range(0, 4): + if continue_to_next_step: + found_topk_values = tl.load(s_found_topk_values_ptr) + continue_to_next_step, step_need_final_sort, logit_pattern = processHistogramStep( + row_ptr, + stride_xn, + row_start, + row_end, + seq_len, + step_idx, + logit_pattern, + s_step_thresholds_ptr, + found_topk_values, + hist_base_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + s_threshold_bin_idx_ptr, + s_final_bin_size_ptr, + ASSUME_ALIGNED=ASSUME_ALIGNED, + TOPK=TOPK, + BLOCK_SIZE=BLOCK_SIZE, + ) + need_final_sort = need_final_sort | step_need_final_sort + + if need_final_sort: + if USE_RADIX_FINAL: + _tle_final_select_radix( + hist_base_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + TOPK=TOPK, + BLOCK_SIZE=BLOCK_SIZE, + FINAL_SORT_ITEMS=FINAL_SORT_ITEMS, + ) + else: + base_idx = tl.load(s_found_topk_values_ptr) + # Guard against stale/oversized counts to avoid out-of-bounds accesses + # in the shared-memory final buffers. 
+ final_cnt = tl.minimum(tl.load(s_final_cnt_ptr), FINAL_SORT_ITEMS) + sort_chunks = tl.cdiv(final_cnt, BLOCK_SIZE) + for sort_chunk in tl.range(0, sort_chunks): + pos = sort_chunk * BLOCK_SIZE + lane + valid = pos < final_cnt + logit_i_bits = tl.load( + tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + pos, )), + mask=valid, + other=0, + ) + logit_i = logit_i_bits.to(tl.float32, bitcast=True) + out_rank = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + for j in tl.range(0, final_cnt): + logit_j_bits = tl.load(tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + j, ))) + logit_j = logit_j_bits.to(tl.float32, bitcast=True) + better = (logit_i < logit_j) | ((logit_i == logit_j) & (pos < j)) + out_rank = out_rank + (valid & better).to(tl.int32) + dst_pos = base_idx + out_rank + take = valid & (dst_pos < TOPK) + idx_i = tl.load( + tle.gpu.local_ptr(s_histogram, (pos, )), + mask=take, + other=0, + ) + tl.store(tle.gpu.local_ptr(s_out_indices, (dst_pos, )), idx_i, mask=take) + tl.store(s_found_topk_values_ptr, TOPK) + + flush_chunks: tl.constexpr = (TOPK + BLOCK_SIZE - 1) // BLOCK_SIZE + for flush_chunk in tl.static_range(flush_chunks): + pos = flush_chunk * BLOCK_SIZE + lane + mask = pos < TOPK + out_vals = tl.load(tle.gpu.local_ptr(s_out_indices, (pos, )), mask=mask, other=-1) + tl.store(out_row + pos * stride_outn, out_vals, mask=mask) + + +@triton.jit +def _tle_topk_smem_overflow_fallback_fullscan( + row_ptr, + out_row, + stride_xn, + stride_outn, + row_start, + row_end, + seq_len, + hist_base_ptr, + s_write_count_ptr, + s_eq_count_ptr, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + RADIX_SIZE: tl.constexpr = 256 + CAND_ROUNDS: tl.constexpr = 4 + + lane = tl.arange(0, BLOCK_SIZE) + ones = tl.full([BLOCK_SIZE], 1, tl.int32) + zeros = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + + out_init_chunks: tl.constexpr = (TOPK + BLOCK_SIZE - 1) // BLOCK_SIZE + hist_clear_chunks: tl.constexpr = (RADIX_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE + num_scan_tiles = tl.cdiv(seq_len, 
BLOCK_SIZE) + + tl.store(s_write_count_ptr, 0) + tl.store(s_eq_count_ptr, 0) + for t in tl.range(0, out_init_chunks): + pos = t * BLOCK_SIZE + lane + tl.store(out_row + pos * stride_outn, -1, mask=pos < TOPK) + + for t in tl.range(0, hist_clear_chunks): + bins = t * BLOCK_SIZE + lane + tl.store(hist_base_ptr + bins, 0, mask=bins < RADIX_SIZE) + tl.debug_barrier() + + for t in tl.range(0, num_scan_tiles): + offs = t * BLOCK_SIZE + lane + valid = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + x = tl.load(row_ptr + offs * stride_xn, mask=valid, other=float("-inf")) + digit = _convert_to_uint16_hi8(x) + tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=valid, + sem="relaxed", + scope="cta", + ) + tl.debug_barrier() + + radix_bins = tl.arange(0, RADIX_SIZE) + zeros_radix = tl.zeros([RADIX_SIZE], dtype=tl.int32) + counts = tl.load(hist_base_ptr + radix_bins) + gt_exclusive, _ = tle.cumsum(counts, axis=0, reverse=True) + cumsum_desc = gt_exclusive + counts + threshold_mask = (cumsum_desc >= TOPK) & (gt_exclusive < TOPK) + coarse_threshold_bin = tl.sum( + tl.where(threshold_mask, radix_bins, zeros_radix), + axis=0, + ) + coarse_counts_gt = tl.sum( + tl.where(threshold_mask, gt_exclusive, zeros_radix), + axis=0, + ) + gt_cursors = tl.where(radix_bins > coarse_threshold_bin, gt_exclusive, zeros_radix) + tl.store(hist_base_ptr + radix_bins, gt_cursors) + remaining = TOPK - coarse_counts_gt + tl.store(s_write_count_ptr + zeros, coarse_counts_gt) + tl.debug_barrier() + + for t in tl.range(0, num_scan_tiles): + offs = t * BLOCK_SIZE + lane + valid = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + idx = offs.to(tl.int32) + x = tl.load(row_ptr + offs * stride_xn, mask=valid, other=float("-inf")) + digit = _convert_to_uint16_hi8(x) + + take_gt = valid & (digit > coarse_threshold_bin) + out_pos_gt = tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=take_gt, + sem="relaxed", + scope="cta", + ) + tl.store( + out_row + out_pos_gt * stride_outn, + idx, 
+ mask=take_gt & (out_pos_gt < TOPK), + ) + tl.debug_barrier() + + refine_prefix = tl.zeros((), dtype=tl.uint32) + refine_mask = tl.zeros((), dtype=tl.uint32) + for round_idx in tl.static_range(CAND_ROUNDS): + if remaining > 0: + for t in tl.range(0, hist_clear_chunks): + bins = t * BLOCK_SIZE + lane + tl.store(hist_base_ptr + bins, 0, mask=bins < RADIX_SIZE) + tl.debug_barrier() + + shift: tl.constexpr = 24 - round_idx * 8 + for t in tl.range(0, num_scan_tiles): + offs = t * BLOCK_SIZE + lane + valid = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + x = tl.load(row_ptr + offs * stride_xn, mask=valid, other=float("-inf")) + coarse_digit = _convert_to_uint16_hi8(x) + ordered = _convert_to_uint32(x) + prefix_match = (ordered & refine_mask) == refine_prefix + active = valid & (coarse_digit == coarse_threshold_bin) & prefix_match + digit = ((ordered >> shift) & 0xFF).to(tl.int32) + tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=active, + sem="relaxed", + scope="cta", + ) + tl.debug_barrier() + + radix_bins = tl.arange(0, RADIX_SIZE) + zeros_radix = tl.zeros([RADIX_SIZE], dtype=tl.int32) + counts = tl.load(hist_base_ptr + radix_bins) + gt_exclusive, _ = tle.cumsum(counts, axis=0, reverse=True) + cumsum_desc = gt_exclusive + counts + base_write = tl.load(s_write_count_ptr) + threshold_mask = (cumsum_desc >= remaining) & (gt_exclusive < remaining) + threshold_bin = tl.sum( + tl.where(threshold_mask, radix_bins, zeros_radix), + axis=0, + ) + counts_gt = tl.sum( + tl.where(threshold_mask, gt_exclusive, zeros_radix), + axis=0, + ) + gt_cursors = tl.where( + radix_bins > threshold_bin, + base_write + gt_exclusive, + zeros_radix, + ) + tl.store(hist_base_ptr + radix_bins, gt_cursors) + remaining = remaining - counts_gt + tl.store(s_write_count_ptr + zeros, base_write + counts_gt) + if round_idx == (CAND_ROUNDS - 1): + tl.store(s_eq_count_ptr, 0) + tl.debug_barrier() + + for t in tl.range(0, num_scan_tiles): + offs = t * BLOCK_SIZE + lane + valid = (offs < 
seq_len) & (offs >= row_start) & (offs < row_end) + idx = offs.to(tl.int32) + x = tl.load(row_ptr + offs * stride_xn, mask=valid, other=float("-inf")) + coarse_digit = _convert_to_uint16_hi8(x) + ordered = _convert_to_uint32(x) + prefix_match = (ordered & refine_mask) == refine_prefix + active = valid & (coarse_digit == coarse_threshold_bin) & prefix_match + digit = ((ordered >> shift) & 0xFF).to(tl.int32) + + take_gt = active & (digit > threshold_bin) + out_pos_gt = tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=take_gt, + sem="relaxed", + scope="cta", + ) + tl.store( + out_row + out_pos_gt * stride_outn, + idx, + mask=take_gt & (out_pos_gt < TOPK), + ) + + if remaining > 0: + take_eq = active & (digit == threshold_bin) + if round_idx == (CAND_ROUNDS - 1): + eq_pos = tl.atomic_add( + s_eq_count_ptr + zeros, + ones, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + take_eq_select = take_eq & (eq_pos < remaining) + out_pos_eq = tl.atomic_add( + s_write_count_ptr + zeros, + ones, + mask=take_eq_select, + sem="relaxed", + scope="cta", + ) + tl.store( + out_row + out_pos_eq * stride_outn, + idx, + mask=take_eq_select & (out_pos_eq < TOPK), + ) + tl.debug_barrier() + + threshold_u32 = threshold_bin.to(tl.uint32) + if round_idx == 0: + refine_prefix = threshold_u32 << 24 + refine_mask = tl.full((), 0xFF000000, tl.uint32) + elif round_idx == 1: + refine_prefix = refine_prefix | (threshold_u32 << 16) + refine_mask = tl.full((), 0xFFFF0000, tl.uint32) + elif round_idx == 2: + refine_prefix = refine_prefix | (threshold_u32 << 8) + refine_mask = tl.full((), 0xFFFFFF00, tl.uint32) + + +@triton.jit +def _tle_process_histogram_step_smem( + row_ptr, + stride_xn, + row_start, + row_end, + seq_len, + step_idx: tl.constexpr, + found_topk_values, + hist_base_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + s_threshold_bin_idx_ptr, + s_final_bin_size_ptr, + src_idx_ptr, + src_val_ptr, + src_count_ptr, + dst_idx_ptr, + dst_val_ptr, + dst_count_ptr, 
+ s_need_fallback_ptr, + ASSUME_ALIGNED: tl.constexpr, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + SMEM_INPUT_SIZE: tl.constexpr, +): + VEC: tl.constexpr = 4 + FINAL_SORT_ITEMS: tl.constexpr = 2048 + RADIX11_SIZE: tl.constexpr = 2048 + RADIX11_MASK: tl.constexpr = 0x7FF + RADIX10_SIZE: tl.constexpr = 1024 + RADIX10_MASK: tl.constexpr = 0x3FF + + lane = tl.arange(0, BLOCK_SIZE) + vec = tl.arange(0, VEC) + ones = tl.full([BLOCK_SIZE], 1, tl.int32) + ones_vec_2d = tl.full([BLOCK_SIZE, VEC], 1, tl.int32) + zeros = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + zeros_vec_2d = tl.zeros([BLOCK_SIZE, VEC], dtype=tl.int32) + + clear_rounds = tl.where(step_idx == 2, RADIX10_SIZE // BLOCK_SIZE, RADIX11_SIZE // BLOCK_SIZE) + threshold_rounds = clear_rounds + + for clear_round in tl.range(0, clear_rounds): + clear_bins = clear_round * BLOCK_SIZE + lane + tl.store(hist_base_ptr + clear_bins, 0) + + if step_idx == 0: + if ASSUME_ALIGNED: + n_vec_full = seq_len // (BLOCK_SIZE * VEC) + rem_tiles = (seq_len - n_vec_full * BLOCK_SIZE * VEC) // BLOCK_SIZE + + for t in tl.range(0, n_vec_full): + base = t * BLOCK_SIZE * VEC + lane * VEC + offs = base[:, None] + vec[None, :] + x_vec = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x_vec) + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + tl.atomic_add( + hist_base_ptr + digit, + ones_vec_2d, + sem="relaxed", + scope="cta", + ) + + for t in tl.range(0, rem_tiles): + offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane + x = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x) + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + tl.atomic_add( + hist_base_ptr + digit, + ones, + sem="relaxed", + scope="cta", + ) + else: + n_tiles = tl.cdiv(seq_len, BLOCK_SIZE) + for t in tl.range(0, n_tiles): + offs = t * BLOCK_SIZE + lane + valid = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + x = tl.load(row_ptr + offs * stride_xn, mask=valid, other=float("-inf")) + key = _convert_to_trt_uint32(x) + digit = ((key >> 21) & 
RADIX11_MASK).to(tl.int32) + tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=valid, + sem="relaxed", + scope="cta", + ) + else: + src_count = tl.minimum(tl.load(src_count_ptr), SMEM_INPUT_SIZE) + src_vec_full = src_count // (BLOCK_SIZE * VEC) + vec_processed = src_vec_full * BLOCK_SIZE * VEC + src_tail_tiles = tl.cdiv(src_count - vec_processed, BLOCK_SIZE) + + for t in tl.range(0, src_vec_full): + base = t * BLOCK_SIZE * VEC + lane * VEC + pos = base[:, None] + vec[None, :] + val_bits_vec = tl.load(src_val_ptr + pos) + x_vec = val_bits_vec.to(tl.float32, bitcast=True) + key = _convert_to_trt_uint32(x_vec) + if step_idx == 1: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + tl.atomic_add( + hist_base_ptr + digit, + ones_vec_2d, + sem="relaxed", + scope="cta", + ) + + for t in tl.range(0, src_tail_tiles): + pos = vec_processed + t * BLOCK_SIZE + lane + valid = pos < src_count + val_bits = tl.load(src_val_ptr + pos, mask=valid, other=0) + x = val_bits.to(tl.float32, bitcast=True) + key = _convert_to_trt_uint32(x) + if step_idx == 1: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + tl.atomic_add( + hist_base_ptr + digit, + ones, + mask=valid, + sem="relaxed", + scope="cta", + ) + tl.debug_barrier() + + tl.store(s_threshold_bin_idx_ptr, -1) + tl.store(s_final_bin_size_ptr, 0) + threshold_bin_ptrs = s_threshold_bin_idx_ptr + zeros + final_bin_size_ptrs = s_final_bin_size_ptr + zeros + last_value = found_topk_values + threshold_found = False + for round_idx in tl.range(0, threshold_rounds): + if not threshold_found: + bins = round_idx * BLOCK_SIZE + lane + counts = tl.load(hist_base_ptr + bins) + prefix_sum, counts_total = tle.cumsum(counts, axis=0, reverse=False) + prefix_sum = prefix_sum + last_value + total_sum = last_value + counts_total + next_prefix_sum = prefix_sum + counts + threshold_mask = (prefix_sum < TOPK) & (next_prefix_sum >= TOPK) + 
threshold_bin = bins + threshold_bin_size = next_prefix_sum - prefix_sum + tl.store(threshold_bin_ptrs, threshold_bin, mask=threshold_mask) + tl.store(final_bin_size_ptrs, threshold_bin_size, mask=threshold_mask) + found_round = tl.reduce_or(threshold_mask, axis=0) + threshold_found = found_round + last_value = total_sum + + tl.debug_barrier() + threshold_bin_idx = tl.load(s_threshold_bin_idx_ptr) + final_bin_size = tl.load(s_final_bin_size_ptr) + use_final = (step_idx < 2) & (threshold_bin_idx >= 0) & (final_bin_size <= FINAL_SORT_ITEMS) + if use_final: + tl.store(s_final_cnt_ptr, 0) + elif step_idx < 2: + tl.store(dst_count_ptr, 0) + + found_ptrs = s_found_topk_values_ptr + zeros + found_ptrs_vec_2d = s_found_topk_values_ptr + zeros_vec_2d + final_cnt_ptrs = s_final_cnt_ptr + zeros + final_cnt_ptrs_vec_2d = s_final_cnt_ptr + zeros_vec_2d + dst_count_ptrs = dst_count_ptr + zeros + dst_count_ptrs_vec_2d = dst_count_ptr + zeros_vec_2d + fallback_ptrs = s_need_fallback_ptr + zeros + fallback_ptrs_vec_2d = s_need_fallback_ptr + zeros_vec_2d + + if step_idx == 0: + if ASSUME_ALIGNED: + n_vec_full = seq_len // (BLOCK_SIZE * VEC) + rem_tiles = (seq_len - n_vec_full * BLOCK_SIZE * VEC) // BLOCK_SIZE + + for t in tl.range(0, n_vec_full): + base = t * BLOCK_SIZE * VEC + lane * VEC + offs = base[:, None] + vec[None, :] + idx = offs.to(tl.int32) + x_vec = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x_vec) + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + + take_lt = digit < threshold_bin_idx + out_pos_lt = tl.atomic_add( + found_ptrs_vec_2d, + ones_vec_2d, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_lt, + idx, + mask=take_lt & (out_pos_lt < TOPK), + ) + + if use_final: + take_eq_final = digit == threshold_bin_idx + final_pos = tl.atomic_add( + final_cnt_ptrs_vec_2d, + ones_vec_2d, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + hist_base_ptr + final_pos, + idx, + mask=take_eq_final & 
(final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + hist_base_ptr + (FINAL_SORT_ITEMS + final_pos), + x_vec.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + elif step_idx < 2: + take_eq_next = digit == threshold_bin_idx + dst_pos = tl.atomic_add( + dst_count_ptrs_vec_2d, + ones_vec_2d, + mask=take_eq_next, + sem="relaxed", + scope="cta", + ) + keep_eq = take_eq_next & (dst_pos < SMEM_INPUT_SIZE) + tl.store(dst_idx_ptr + dst_pos, idx, mask=keep_eq) + tl.store(dst_val_ptr + dst_pos, x_vec.to(tl.int32, bitcast=True), mask=keep_eq) + overflow_mask = take_eq_next & (dst_pos >= SMEM_INPUT_SIZE) + tl.atomic_or( + fallback_ptrs_vec_2d, + ones_vec_2d, + mask=overflow_mask, + sem="relaxed", + scope="cta", + ) + + for t in tl.range(0, rem_tiles): + offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane + idx = offs.to(tl.int32) + x = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x) + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + + take_lt = digit < threshold_bin_idx + out_pos_lt = tl.atomic_add( + found_ptrs, + ones, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_lt, + idx, + mask=take_lt & (out_pos_lt < TOPK), + ) + + if use_final: + take_eq_final = digit == threshold_bin_idx + final_pos = tl.atomic_add( + final_cnt_ptrs, + ones, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + hist_base_ptr + final_pos, + idx, + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + hist_base_ptr + (FINAL_SORT_ITEMS + final_pos), + x.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + elif step_idx < 2: + take_eq_next = digit == threshold_bin_idx + dst_pos = tl.atomic_add( + dst_count_ptrs, + ones, + mask=take_eq_next, + sem="relaxed", + scope="cta", + ) + keep_eq = take_eq_next & (dst_pos < SMEM_INPUT_SIZE) + tl.store(dst_idx_ptr + dst_pos, idx, mask=keep_eq) + tl.store(dst_val_ptr + dst_pos, x.to(tl.int32, bitcast=True), 
mask=keep_eq) + overflow_mask = take_eq_next & (dst_pos >= SMEM_INPUT_SIZE) + tl.atomic_or( + fallback_ptrs, + ones, + mask=overflow_mask, + sem="relaxed", + scope="cta", + ) + else: + n_tiles = tl.cdiv(seq_len, BLOCK_SIZE) + for t in tl.range(0, n_tiles): + offs = t * BLOCK_SIZE + lane + valid = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + idx = offs.to(tl.int32) + x = tl.load(row_ptr + offs * stride_xn, mask=valid, other=float("-inf")) + key = _convert_to_trt_uint32(x) + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + + take_lt = valid & (digit < threshold_bin_idx) + out_pos_lt = tl.atomic_add( + found_ptrs, + ones, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_lt, + idx, + mask=take_lt & (out_pos_lt < TOPK), + ) + + if use_final: + take_eq_final = valid & (digit == threshold_bin_idx) + final_pos = tl.atomic_add( + final_cnt_ptrs, + ones, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + hist_base_ptr + final_pos, + idx, + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + hist_base_ptr + (FINAL_SORT_ITEMS + final_pos), + x.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + elif step_idx < 2: + take_eq_next = valid & (digit == threshold_bin_idx) + dst_pos = tl.atomic_add( + dst_count_ptrs, + ones, + mask=take_eq_next, + sem="relaxed", + scope="cta", + ) + keep_eq = take_eq_next & (dst_pos < SMEM_INPUT_SIZE) + tl.store(dst_idx_ptr + dst_pos, idx, mask=keep_eq) + tl.store(dst_val_ptr + dst_pos, x.to(tl.int32, bitcast=True), mask=keep_eq) + overflow_mask = take_eq_next & (dst_pos >= SMEM_INPUT_SIZE) + tl.atomic_or( + fallback_ptrs, + ones, + mask=overflow_mask, + sem="relaxed", + scope="cta", + ) + else: + src_count = tl.minimum(tl.load(src_count_ptr), SMEM_INPUT_SIZE) + src_vec_full = src_count // (BLOCK_SIZE * VEC) + vec_processed = src_vec_full * BLOCK_SIZE * VEC + src_tail_tiles = tl.cdiv(src_count - vec_processed, 
BLOCK_SIZE) + + for t in tl.range(0, src_vec_full): + base = t * BLOCK_SIZE * VEC + lane * VEC + pos = base[:, None] + vec[None, :] + idx = tl.load(src_idx_ptr + pos) + val_bits_vec = tl.load(src_val_ptr + pos) + x_vec = val_bits_vec.to(tl.float32, bitcast=True) + key = _convert_to_trt_uint32(x_vec) + if step_idx == 1: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + take_lt = digit < threshold_bin_idx + out_pos_lt = tl.atomic_add( + found_ptrs_vec_2d, + ones_vec_2d, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_lt, + idx.to(tl.int32), + mask=take_lt & (out_pos_lt < TOPK), + ) + + if step_idx == 2: + take_eq = digit == threshold_bin_idx + out_pos_eq = tl.atomic_add( + found_ptrs_vec_2d, + ones_vec_2d, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_eq, + idx.to(tl.int32), + mask=take_eq & (out_pos_eq < TOPK), + ) + elif use_final: + take_eq_final = digit == threshold_bin_idx + final_pos = tl.atomic_add( + final_cnt_ptrs_vec_2d, + ones_vec_2d, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + hist_base_ptr + final_pos, + idx.to(tl.int32), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + hist_base_ptr + (FINAL_SORT_ITEMS + final_pos), + x_vec.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + elif step_idx < 2: + take_eq_next = digit == threshold_bin_idx + dst_pos = tl.atomic_add( + dst_count_ptrs_vec_2d, + ones_vec_2d, + mask=take_eq_next, + sem="relaxed", + scope="cta", + ) + keep_eq = take_eq_next & (dst_pos < SMEM_INPUT_SIZE) + tl.store(dst_idx_ptr + dst_pos, idx.to(tl.int32), mask=keep_eq) + tl.store(dst_val_ptr + dst_pos, val_bits_vec, mask=keep_eq) + overflow_mask = take_eq_next & (dst_pos >= SMEM_INPUT_SIZE) + tl.atomic_or( + fallback_ptrs_vec_2d, + ones_vec_2d, + mask=overflow_mask, + sem="relaxed", + scope="cta", + ) + + for 
t in tl.range(0, src_tail_tiles): + pos = vec_processed + t * BLOCK_SIZE + lane + valid = pos < src_count + idx = tl.load(src_idx_ptr + pos, mask=valid, other=0) + val_bits = tl.load(src_val_ptr + pos, mask=valid, other=0) + x = val_bits.to(tl.float32, bitcast=True) + key = _convert_to_trt_uint32(x) + if step_idx == 1: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + take_lt = valid & (digit < threshold_bin_idx) + out_pos_lt = tl.atomic_add( + found_ptrs, + ones, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_lt, + idx.to(tl.int32), + mask=take_lt & (out_pos_lt < TOPK), + ) + + if step_idx == 2: + take_eq = valid & (digit == threshold_bin_idx) + out_pos_eq = tl.atomic_add( + found_ptrs, + ones, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_ptr + out_pos_eq, + idx.to(tl.int32), + mask=take_eq & (out_pos_eq < TOPK), + ) + elif use_final: + take_eq_final = valid & (digit == threshold_bin_idx) + final_pos = tl.atomic_add( + final_cnt_ptrs, + ones, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + hist_base_ptr + final_pos, + idx.to(tl.int32), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + hist_base_ptr + (FINAL_SORT_ITEMS + final_pos), + x.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + elif step_idx < 2: + take_eq_next = valid & (digit == threshold_bin_idx) + dst_pos = tl.atomic_add( + dst_count_ptrs, + ones, + mask=take_eq_next, + sem="relaxed", + scope="cta", + ) + keep_eq = take_eq_next & (dst_pos < SMEM_INPUT_SIZE) + tl.store(dst_idx_ptr + dst_pos, idx.to(tl.int32), mask=keep_eq) + tl.store(dst_val_ptr + dst_pos, val_bits, mask=keep_eq) + overflow_mask = take_eq_next & (dst_pos >= SMEM_INPUT_SIZE) + tl.atomic_or( + fallback_ptrs, + ones, + mask=overflow_mask, + sem="relaxed", + scope="cta", + ) + + tl.debug_barrier() + if step_idx < 2: + 
if use_final: + continue_to_next_step = False + need_final_sort = True + else: + continue_to_next_step = True + need_final_sort = False + else: + tl.store(s_found_topk_values_ptr, TOPK) + continue_to_next_step = False + need_final_sort = False + return continue_to_next_step, need_final_sort + + +@triton.jit +def tle_topk_selector_kernel_smem( + x_ptr, + out_ptr, + starts_ptr, + ends_ptr, + stride_xm, + stride_xn, + stride_outm, + stride_outn, + seq_len, + ASSUME_ALIGNED: tl.constexpr, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + SMEM_INPUT_SIZE: tl.constexpr, + USE_RADIX_FINAL: tl.constexpr, +): + pid = tl.program_id(0) + row_start = tl.load(starts_ptr + pid).to(tl.int32) + row_end = tl.load(ends_ptr + pid).to(tl.int32) + + row_ptr = x_ptr + pid * stride_xm + out_row = out_ptr + pid * stride_outm + row_len = row_end - row_start + + if ASSUME_ALIGNED: + tl.assume(row_start == 0) + tl.assume(row_end == seq_len) + tl.assume(stride_xn == 1) + tl.assume(stride_outn == 1) + seq_len = tl.multiple_of(seq_len, BLOCK_SIZE) + + lane = tl.arange(0, BLOCK_SIZE) + if row_len <= TOPK: + chunks: tl.constexpr = (TOPK + BLOCK_SIZE - 1) // BLOCK_SIZE + for chunk_idx in tl.range(0, chunks): + pos = chunk_idx * BLOCK_SIZE + lane + take_row = pos < row_len + tl.store(out_row + pos * stride_outn, (row_start + pos).to(tl.int32), mask=take_row) + take_pad = (pos >= row_len) & (pos < TOPK) + tl.store(out_row + pos * stride_outn, -1, mask=take_pad) + return + + FINAL_SORT_ITEMS: tl.constexpr = 2048 + HIST_SIZE: tl.constexpr = 4096 + + s_histogram = tle.gpu.alloc( + [HIST_SIZE], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + hist_base_ptr = tle.gpu.local_ptr(s_histogram, (0, )) + s_out_indices = tle.gpu.alloc( + [TOPK], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_input_idx0 = tle.gpu.alloc( + [SMEM_INPUT_SIZE], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, 
+ ) + s_input_idx1 = tle.gpu.alloc( + [SMEM_INPUT_SIZE], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_input_val0 = tle.gpu.alloc( + [SMEM_INPUT_SIZE], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_input_val1 = tle.gpu.alloc( + [SMEM_INPUT_SIZE], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_input_count0 = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_input_count1 = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_need_fallback = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_fallback_eq_count = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_final_cnt = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_threshold_bin_idx = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_final_bin_size = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_found_topk_values = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + + s_input_idx0_ptr = tle.gpu.local_ptr(s_input_idx0, (0, )) + s_input_idx1_ptr = tle.gpu.local_ptr(s_input_idx1, (0, )) + s_input_val0_ptr = tle.gpu.local_ptr(s_input_val0, (0, )) + s_input_val1_ptr = tle.gpu.local_ptr(s_input_val1, (0, )) + s_input_count0_ptr = tle.gpu.local_ptr(s_input_count0, (0, )) + s_input_count1_ptr = tle.gpu.local_ptr(s_input_count1, (0, )) + s_need_fallback_ptr = tle.gpu.local_ptr(s_need_fallback, (0, )) + s_fallback_eq_count_ptr = tle.gpu.local_ptr(s_fallback_eq_count, (0, )) + 
s_final_cnt_ptr = tle.gpu.local_ptr(s_final_cnt, (0, )) + s_threshold_bin_idx_ptr = tle.gpu.local_ptr(s_threshold_bin_idx, (0, )) + s_final_bin_size_ptr = tle.gpu.local_ptr(s_final_bin_size, (0, )) + s_found_topk_values_ptr = tle.gpu.local_ptr(s_found_topk_values, (0, )) + s_out_indices_ptr = tle.gpu.local_ptr(s_out_indices, (0, )) + + tl.store(s_input_count0_ptr, 0) + tl.store(s_input_count1_ptr, 0) + tl.store(s_need_fallback_ptr, 0) + tl.store(s_fallback_eq_count_ptr, 0) + tl.store(s_final_cnt_ptr, 0) + tl.store(s_threshold_bin_idx_ptr, -1) + tl.store(s_final_bin_size_ptr, 0) + tl.store(s_found_topk_values_ptr, 0) + + init_chunks: tl.constexpr = (TOPK + BLOCK_SIZE - 1) // BLOCK_SIZE + for init_idx in tl.range(0, init_chunks): + pos = init_idx * BLOCK_SIZE + lane + tl.store(s_out_indices_ptr + pos, -1, mask=pos < TOPK) + + continue_to_next_step = True + need_final_sort = False + for step_idx in tl.static_range(0, 3): + if continue_to_next_step: + found_topk_values = tl.load(s_found_topk_values_ptr) + if step_idx == 0: + src_idx_ptr = s_input_idx0_ptr + src_val_ptr = s_input_val0_ptr + src_count_ptr = s_input_count0_ptr + dst_idx_ptr = s_input_idx0_ptr + dst_val_ptr = s_input_val0_ptr + dst_count_ptr = s_input_count0_ptr + elif step_idx == 1: + src_idx_ptr = s_input_idx0_ptr + src_val_ptr = s_input_val0_ptr + src_count_ptr = s_input_count0_ptr + dst_idx_ptr = s_input_idx1_ptr + dst_val_ptr = s_input_val1_ptr + dst_count_ptr = s_input_count1_ptr + else: + src_idx_ptr = s_input_idx1_ptr + src_val_ptr = s_input_val1_ptr + src_count_ptr = s_input_count1_ptr + dst_idx_ptr = s_input_idx0_ptr + dst_val_ptr = s_input_val0_ptr + dst_count_ptr = s_input_count0_ptr + + continue_to_next_step, step_need_final_sort = _tle_process_histogram_step_smem( + row_ptr, + stride_xn, + row_start, + row_end, + seq_len, + step_idx, + found_topk_values, + hist_base_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + s_threshold_bin_idx_ptr, + s_final_bin_size_ptr, + 
src_idx_ptr, + src_val_ptr, + src_count_ptr, + dst_idx_ptr, + dst_val_ptr, + dst_count_ptr, + s_need_fallback_ptr, + ASSUME_ALIGNED=ASSUME_ALIGNED, + TOPK=TOPK, + BLOCK_SIZE=BLOCK_SIZE, + SMEM_INPUT_SIZE=SMEM_INPUT_SIZE, + ) + need_final_sort = need_final_sort | step_need_final_sort + + if tl.load(s_need_fallback_ptr) != 0: + _tle_topk_smem_overflow_fallback_fullscan( + row_ptr, + out_row, + stride_xn, + stride_outn, + row_start, + row_end, + seq_len, + hist_base_ptr, + s_found_topk_values_ptr, + s_fallback_eq_count_ptr, + TOPK=TOPK, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + if need_final_sort: + if USE_RADIX_FINAL: + _tle_final_select_radix( + hist_base_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + TOPK=TOPK, + BLOCK_SIZE=BLOCK_SIZE, + FINAL_SORT_ITEMS=FINAL_SORT_ITEMS, + ) + else: + base_idx = tl.load(s_found_topk_values_ptr) + final_cnt = tl.minimum(tl.load(s_final_cnt_ptr), FINAL_SORT_ITEMS) + sort_chunks = tl.cdiv(final_cnt, BLOCK_SIZE) + for sort_chunk in tl.range(0, sort_chunks): + pos = sort_chunk * BLOCK_SIZE + lane + valid = pos < final_cnt + logit_i_bits = tl.load( + tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + pos, )), + mask=valid, + other=0, + ) + logit_i = logit_i_bits.to(tl.float32, bitcast=True) + out_rank = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + for j in tl.range(0, final_cnt): + logit_j_bits = tl.load(tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + j, ))) + logit_j = logit_j_bits.to(tl.float32, bitcast=True) + better = (logit_i < logit_j) | ((logit_i == logit_j) & (pos < j)) + out_rank = out_rank + (valid & better).to(tl.int32) + dst_pos = base_idx + out_rank + take = valid & (dst_pos < TOPK) + idx_i = tl.load( + tle.gpu.local_ptr(s_histogram, (pos, )), + mask=take, + other=0, + ) + tl.store(tle.gpu.local_ptr(s_out_indices, (dst_pos, )), idx_i, mask=take) + tl.store(s_found_topk_values_ptr, TOPK) + + flush_chunks: tl.constexpr = (TOPK + BLOCK_SIZE - 1) // BLOCK_SIZE + for flush_chunk in 
tl.static_range(flush_chunks): + pos = flush_chunk * BLOCK_SIZE + lane + mask = pos < TOPK + out_vals = tl.load(s_out_indices_ptr + pos, mask=mask, other=-1) + tl.store(out_row + pos * stride_outn, out_vals, mask=mask) + + +@triton.jit +def _tle_process_histogram_step_cluster( + row_ptr, + stride_xn, + row_start, + row_end, + seq_len, + step_idx: tl.constexpr, + logit_pattern, + cluster_rank, + is_rank0, + s_step_local_hist_ptr, + s_histogram_ptr, + s_histogram_rank0_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + s_step_thresholds_ptr, + s_step_thresholds_rank0_ptr, + s_threshold_bin_idx_ptr, + s_final_bin_size_ptr, + s_threshold_bin_idx_rank0_ptr, + s_final_bin_size_rank0_ptr, + mesh: tl.constexpr, + CLUSTER_SIZE: tl.constexpr, + ASSUME_ALIGNED: tl.constexpr, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + VEC: tl.constexpr = 4 + FINAL_SORT_ITEMS: tl.constexpr = 2048 + RADIX11_SIZE: tl.constexpr = 2048 + RADIX11_MASK: tl.constexpr = 0x7FF + RADIX10_SIZE: tl.constexpr = 1024 + RADIX10_MASK: tl.constexpr = 0x3FF + + lane = tl.arange(0, BLOCK_SIZE) + vec = tl.arange(0, VEC) + ones = tl.full([BLOCK_SIZE], 1, tl.int32) + ones_vec_2d = tl.full([BLOCK_SIZE, VEC], 1, tl.int32) + zeros = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + zeros_vec_2d = tl.zeros([BLOCK_SIZE, VEC], dtype=tl.int32) + s_histogram_rank0_ptr = tle.remote(s_histogram_rank0_ptr, 0, scope=mesh) + s_step_thresholds_rank0_ptr = tle.remote(s_step_thresholds_rank0_ptr, 0, scope=mesh) + s_threshold_bin_idx_rank0_ptr = tle.remote(s_threshold_bin_idx_rank0_ptr, 0, scope=mesh) + s_final_bin_size_rank0_ptr = tle.remote(s_final_bin_size_rank0_ptr, 0, scope=mesh) + s_out_indices_rank0_ptr = tle.remote(s_out_indices_ptr, 0, scope=mesh) + s_final_cnt_rank0_ptr = tle.remote(s_final_cnt_ptr, 0, scope=mesh) + s_found_topk_values_rank0_ptr = tle.remote(s_found_topk_values_ptr, 0, scope=mesh) + + clear_rounds = tl.where( + step_idx == 3, + RADIX10_SIZE // BLOCK_SIZE, + RADIX11_SIZE // 
BLOCK_SIZE, + ) + + for clear_round in tl.range(0, clear_rounds): + clear_bins = clear_round * BLOCK_SIZE + lane + tl.store(s_step_local_hist_ptr + clear_bins, 0) + if is_rank0: + tl.store(s_histogram_ptr + clear_bins, 0) + tle.distributed_barrier(mesh) + + if step_idx == 2: + step1_threshold = tl.load(s_step_thresholds_rank0_ptr + 1) + logit_pattern = (step1_threshold.to(tl.uint32) & RADIX11_MASK) << 21 + elif step_idx == 3: + step1_threshold = tl.load(s_step_thresholds_rank0_ptr + 1) + step2_threshold = tl.load(s_step_thresholds_rank0_ptr + 2) + logit_pattern = ((step1_threshold.to(tl.uint32) & RADIX11_MASK) << 21) | ( + (step2_threshold.to(tl.uint32) & RADIX11_MASK) << 10) + + n_tiles = tl.cdiv(seq_len, BLOCK_SIZE) + n_vec_full = seq_len // (BLOCK_SIZE * VEC) + rem_tiles = (seq_len - n_vec_full * BLOCK_SIZE * VEC) // BLOCK_SIZE + + if ASSUME_ALIGNED: + for t in tl.range(0, n_vec_full): + if (t % CLUSTER_SIZE) == cluster_rank: + base = t * BLOCK_SIZE * VEC + lane * VEC + offs = base[:, None] + vec[None, :] + x_vec = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x_vec) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x_vec) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = tl.full([BLOCK_SIZE, VEC], True, tl.int1) + elif step_idx == 2: + partial = ((key ^ logit_pattern) >> 21) == 0 + else: + partial = ((key ^ logit_pattern) >> 10) == 0 + + tl.atomic_add( + s_step_local_hist_ptr + digit, + ones_vec_2d, + mask=partial, + sem="relaxed", + scope="cta", + ) + + for t in tl.range(0, rem_tiles): + tile_idx = n_vec_full + t + if (tile_idx % CLUSTER_SIZE) == cluster_rank: + offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane + x = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x) + elif step_idx == 1: + digit 
= ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = tl.full([BLOCK_SIZE], True, tl.int1) + elif step_idx == 2: + partial = ((key ^ logit_pattern) >> 21) == 0 + else: + partial = ((key ^ logit_pattern) >> 10) == 0 + + tl.atomic_add( + s_step_local_hist_ptr + digit, + ones, + mask=partial, + sem="relaxed", + scope="cta", + ) + else: + for t in tl.range(0, n_tiles): + if (t % CLUSTER_SIZE) == cluster_rank: + offs = t * BLOCK_SIZE + lane + in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf")) + key = _convert_to_trt_uint32(x) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = in_range + elif step_idx == 2: + partial = in_range & (((key ^ logit_pattern) >> 21) == 0) + else: + partial = in_range & (((key ^ logit_pattern) >> 10) == 0) + + tl.atomic_add( + s_step_local_hist_ptr + digit, + ones, + mask=partial, + sem="relaxed", + scope="cta", + ) + for clear_round in tl.range(0, clear_rounds): + bins = clear_round * BLOCK_SIZE + lane + local_counts = tl.load(s_step_local_hist_ptr + bins) + tl.atomic_add( + s_histogram_rank0_ptr + bins, + local_counts, + sem="relaxed", + scope="gpu", + ) + tle.distributed_barrier(mesh) + + found_topk_values = tl.load(s_found_topk_values_rank0_ptr) + if is_rank0: + tl.store(s_threshold_bin_idx_ptr, -1) + tl.store(s_final_bin_size_ptr, 0) + threshold_bin_ptrs = s_threshold_bin_idx_ptr + zeros + final_bin_size_ptrs = s_final_bin_size_ptr + zeros + last_value = found_topk_values + threshold_found = False + for round_idx in tl.range(0, clear_rounds): + if not 
threshold_found: + bins = round_idx * BLOCK_SIZE + lane + counts = tl.load(s_histogram_ptr + bins) + prefix_sum, counts_total = tle.cumsum(counts, axis=0, reverse=False) + prefix_sum = prefix_sum + last_value + total_sum = last_value + counts_total + next_prefix_sum = prefix_sum + counts + threshold_mask = (prefix_sum < TOPK) & (next_prefix_sum >= TOPK) + threshold_bin = bins + threshold_bin_size = next_prefix_sum - prefix_sum + tl.store(threshold_bin_ptrs, threshold_bin, mask=threshold_mask) + tl.store(final_bin_size_ptrs, threshold_bin_size, mask=threshold_mask) + found_round = tl.reduce_or(threshold_mask, axis=0) + threshold_found = found_round + last_value = total_sum + + threshold_bin_idx_local = tl.load(s_threshold_bin_idx_ptr) + tl.store(s_step_thresholds_ptr + step_idx, threshold_bin_idx_local) + tle.distributed_barrier(mesh) + + threshold_bin_idx = tl.load(s_threshold_bin_idx_rank0_ptr) + final_bin_size = tl.load(s_final_bin_size_rank0_ptr) + use_final = (step_idx < 3) & (threshold_bin_idx >= 0) & (final_bin_size <= FINAL_SORT_ITEMS) + if is_rank0 and use_final: + tl.store(s_final_cnt_ptr, 0) + tle.distributed_barrier(mesh) + + found_ptrs = s_found_topk_values_rank0_ptr + zeros + final_cnt_ptrs = s_final_cnt_rank0_ptr + zeros + if ASSUME_ALIGNED: + found_ptrs_vec_2d = s_found_topk_values_rank0_ptr + zeros_vec_2d + final_cnt_ptrs_vec_2d = s_final_cnt_rank0_ptr + zeros_vec_2d + for t in tl.range(0, n_vec_full): + if (t % CLUSTER_SIZE) == cluster_rank: + base = t * BLOCK_SIZE * VEC + lane * VEC + offs = base[:, None] + vec[None, :] + x_vec = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x_vec) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x_vec) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = tl.full([BLOCK_SIZE, VEC], True, tl.int1) + elif step_idx == 2: + 
partial = ((key ^ logit_pattern) >> 21) == 0 + else: + partial = ((key ^ logit_pattern) >> 10) == 0 + + take_lt = partial & (digit < threshold_bin_idx) + out_pos_lt = tl.atomic_add( + found_ptrs_vec_2d, + ones_vec_2d, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_rank0_ptr + out_pos_lt, + offs.to(tl.int32), + mask=take_lt & (out_pos_lt < TOPK), + ) + + if step_idx == 3: + take_eq = partial & (digit == threshold_bin_idx) + out_pos_eq = tl.atomic_add( + s_histogram_rank0_ptr + digit, + ones_vec_2d, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_rank0_ptr + out_pos_eq, + offs.to(tl.int32), + mask=take_eq & (out_pos_eq < TOPK), + ) + elif use_final: + take_eq_final = partial & (digit == threshold_bin_idx) + final_pos = tl.atomic_add( + final_cnt_ptrs_vec_2d, + ones_vec_2d, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + s_histogram_rank0_ptr + final_pos, + offs.to(tl.int32), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + s_histogram_rank0_ptr + (FINAL_SORT_ITEMS + final_pos), + x_vec.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + + for t in tl.range(0, rem_tiles): + tile_idx = n_vec_full + t + if (tile_idx % CLUSTER_SIZE) == cluster_rank: + offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane + x = tl.load(row_ptr + offs) + key = _convert_to_trt_uint32(x) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = tl.full([BLOCK_SIZE], True, tl.int1) + elif step_idx == 2: + partial = ((key ^ logit_pattern) >> 21) == 0 + else: + partial = ((key ^ logit_pattern) >> 10) == 0 + + take_lt = partial & (digit < threshold_bin_idx) + out_pos_lt = tl.atomic_add( + found_ptrs, + ones, + 
mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_rank0_ptr + out_pos_lt, + offs.to(tl.int32), + mask=take_lt & (out_pos_lt < TOPK), + ) + + if step_idx == 3: + take_eq = partial & (digit == threshold_bin_idx) + out_pos_eq = tl.atomic_add( + s_histogram_rank0_ptr + digit, + ones, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_rank0_ptr + out_pos_eq, + offs.to(tl.int32), + mask=take_eq & (out_pos_eq < TOPK), + ) + elif use_final: + take_eq_final = partial & (digit == threshold_bin_idx) + final_pos = tl.atomic_add( + final_cnt_ptrs, + ones, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + s_histogram_rank0_ptr + final_pos, + offs.to(tl.int32), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + s_histogram_rank0_ptr + (FINAL_SORT_ITEMS + final_pos), + x.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + else: + for t in tl.range(0, n_tiles): + if (t % CLUSTER_SIZE) == cluster_rank: + offs = t * BLOCK_SIZE + lane + in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf")) + key = _convert_to_trt_uint32(x) + if step_idx == 0: + digit = _convert_to_trt_uint16_hi11(x) + elif step_idx == 1: + digit = ((key >> 21) & RADIX11_MASK).to(tl.int32) + elif step_idx == 2: + digit = ((key >> 10) & RADIX11_MASK).to(tl.int32) + else: + digit = (key & RADIX10_MASK).to(tl.int32) + + if step_idx < 2: + partial = in_range + elif step_idx == 2: + partial = in_range & (((key ^ logit_pattern) >> 21) == 0) + else: + partial = in_range & (((key ^ logit_pattern) >> 10) == 0) + + take_lt = partial & (digit < threshold_bin_idx) + out_pos_lt = tl.atomic_add( + found_ptrs, + ones, + mask=take_lt, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_rank0_ptr + out_pos_lt, + offs.to(tl.int32), + mask=take_lt & (out_pos_lt < TOPK), + ) + + if step_idx == 3: + 
take_eq = partial & (digit == threshold_bin_idx) + out_pos_eq = tl.atomic_add( + s_histogram_rank0_ptr + digit, + ones, + mask=take_eq, + sem="relaxed", + scope="cta", + ) + tl.store( + s_out_indices_rank0_ptr + out_pos_eq, + offs.to(tl.int32), + mask=take_eq & (out_pos_eq < TOPK), + ) + elif use_final: + take_eq_final = partial & (digit == threshold_bin_idx) + final_pos = tl.atomic_add( + final_cnt_ptrs, + ones, + mask=take_eq_final, + sem="relaxed", + scope="cta", + ) + tl.store( + s_histogram_rank0_ptr + final_pos, + offs.to(tl.int32), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + tl.store( + s_histogram_rank0_ptr + (FINAL_SORT_ITEMS + final_pos), + x.to(tl.int32, bitcast=True), + mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS), + ) + + tle.distributed_barrier(mesh) + + if step_idx < 3: + if use_final: + continue_to_next_step = False + need_final_sort = True + else: + continue_to_next_step = True + need_final_sort = False + else: + if is_rank0: + tl.store(s_found_topk_values_ptr, TOPK) + continue_to_next_step = False + need_final_sort = False + + tle.distributed_barrier(mesh) + return continue_to_next_step, need_final_sort, logit_pattern + + +@triton.jit +def tle_topk_selector_kernel_smem_cluster( + x_ptr, + out_ptr, + starts_ptr, + ends_ptr, + stride_xm, + stride_xn, + stride_outm, + stride_outn, + seq_len, + mesh: tl.constexpr, + CLUSTER_SIZE: tl.constexpr, + ASSUME_ALIGNED: tl.constexpr, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + USE_RADIX_FINAL: tl.constexpr, +): + cluster_pid = tl.program_id(0) + cluster_rank = tle.shard_id(mesh, "cluster_x") + pid = cluster_pid // CLUSTER_SIZE + is_rank0 = cluster_rank == 0 + + row_start = tl.load(starts_ptr + pid).to(tl.int32) + row_end = tl.load(ends_ptr + pid).to(tl.int32) + row_ptr = x_ptr + pid * stride_xm + out_row = out_ptr + pid * stride_outm + row_len = row_end - row_start + + if ASSUME_ALIGNED: + tl.assume(row_start == 0) + tl.assume(row_end == seq_len) + tl.assume(stride_xn == 1) + 
tl.assume(stride_outn == 1) + seq_len = tl.multiple_of(seq_len, BLOCK_SIZE) + + lane = tl.arange(0, BLOCK_SIZE) + if row_len <= TOPK: + if is_rank0: + chunks: tl.constexpr = (TOPK + BLOCK_SIZE - 1) // BLOCK_SIZE + for chunk_idx in tl.range(0, chunks): + pos = chunk_idx * BLOCK_SIZE + lane + take_row = pos < row_len + tl.store(out_row + pos * stride_outn, (row_start + pos).to(tl.int32), mask=take_row) + take_pad = (pos >= row_len) & (pos < TOPK) + tl.store(out_row + pos * stride_outn, -1, mask=take_pad) + return + + FINAL_SORT_ITEMS: tl.constexpr = 2048 + HIST_SIZE: tl.constexpr = 4096 + + s_histogram = tle.gpu.alloc( + [HIST_SIZE], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_step_local_hist = tle.gpu.alloc( + [HIST_SIZE], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_out_indices = tle.gpu.alloc( + [TOPK], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_final_cnt = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_threshold_bin_idx = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_final_bin_size = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_found_topk_values = tle.gpu.alloc( + [1], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + s_step_thresholds = tle.gpu.alloc( + [4], + dtype=tl.int32, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=False, + ) + + s_histogram_ptr = tle.gpu.local_ptr(s_histogram, (0, )) + s_step_local_hist_ptr = tle.gpu.local_ptr(s_step_local_hist, (0, )) + s_out_indices_ptr = tle.gpu.local_ptr(s_out_indices, (0, )) + s_final_cnt_ptr = tle.gpu.local_ptr(s_final_cnt, (0, )) + s_threshold_bin_idx_ptr = tle.gpu.local_ptr(s_threshold_bin_idx, (0, )) + 
s_final_bin_size_ptr = tle.gpu.local_ptr(s_final_bin_size, (0, )) + s_found_topk_values_ptr = tle.gpu.local_ptr(s_found_topk_values, (0, )) + s_step_thresholds_ptr = tle.gpu.local_ptr(s_step_thresholds, (0, )) + + s_histogram_rank0 = tle.remote(s_histogram, 0, scope=mesh) + s_threshold_bin_idx_rank0 = tle.remote(s_threshold_bin_idx, 0, scope=mesh) + s_final_bin_size_rank0 = tle.remote(s_final_bin_size, 0, scope=mesh) + s_step_thresholds_rank0 = tle.remote(s_step_thresholds, 0, scope=mesh) + + s_histogram_rank0_ptr = tle.gpu.local_ptr(s_histogram_rank0, (0, )) + s_threshold_bin_idx_rank0_ptr = tle.gpu.local_ptr(s_threshold_bin_idx_rank0, (0, )) + s_final_bin_size_rank0_ptr = tle.gpu.local_ptr(s_final_bin_size_rank0, (0, )) + s_step_thresholds_rank0_ptr = tle.gpu.local_ptr(s_step_thresholds_rank0, (0, )) + + if is_rank0: + tl.store(s_final_cnt_ptr, 0) + tl.store(s_threshold_bin_idx_ptr, -1) + tl.store(s_final_bin_size_ptr, 0) + tl.store(s_found_topk_values_ptr, 0) + for i in tl.static_range(4): + tl.store(s_step_thresholds_ptr + i, 0) + init_chunks: tl.constexpr = (TOPK + BLOCK_SIZE - 1) // BLOCK_SIZE + for init_idx in tl.range(0, init_chunks): + pos = init_idx * BLOCK_SIZE + lane + tl.store(s_out_indices_ptr + pos, -1, mask=is_rank0 & (pos < TOPK)) + tle.distributed_barrier(mesh) + + logit_pattern = tl.zeros((), dtype=tl.uint32) + continue_to_next_step = True + need_final_sort = False + + for step_idx in tl.static_range(0, 4): + if continue_to_next_step: + continue_to_next_step, step_need_final_sort, logit_pattern = _tle_process_histogram_step_cluster( + row_ptr, + stride_xn, + row_start, + row_end, + seq_len, + step_idx, + logit_pattern, + cluster_rank, + is_rank0, + s_step_local_hist_ptr, + s_histogram_ptr, + s_histogram_rank0_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + s_step_thresholds_ptr, + s_step_thresholds_rank0_ptr, + s_threshold_bin_idx_ptr, + s_final_bin_size_ptr, + s_threshold_bin_idx_rank0_ptr, + s_final_bin_size_rank0_ptr, 
+ mesh=mesh, + CLUSTER_SIZE=CLUSTER_SIZE, + ASSUME_ALIGNED=ASSUME_ALIGNED, + TOPK=TOPK, + BLOCK_SIZE=BLOCK_SIZE, + ) + need_final_sort = need_final_sort | step_need_final_sort + + if is_rank0 and need_final_sort: + if USE_RADIX_FINAL: + _tle_final_select_radix( + s_histogram_ptr, + s_out_indices_ptr, + s_final_cnt_ptr, + s_found_topk_values_ptr, + TOPK=TOPK, + BLOCK_SIZE=BLOCK_SIZE, + FINAL_SORT_ITEMS=FINAL_SORT_ITEMS, + ) + else: + base_idx = tl.load(s_found_topk_values_ptr) + final_cnt = tl.minimum(tl.load(s_final_cnt_ptr), FINAL_SORT_ITEMS) + sort_chunks = tl.cdiv(final_cnt, BLOCK_SIZE) + for sort_chunk in tl.range(0, sort_chunks): + pos = sort_chunk * BLOCK_SIZE + lane + valid = pos < final_cnt + logit_i_bits = tl.load( + tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + pos, )), + mask=valid, + other=0, + ) + logit_i = logit_i_bits.to(tl.float32, bitcast=True) + out_rank = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + for j in tl.range(0, final_cnt): + logit_j_bits = tl.load(tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + j, ))) + logit_j = logit_j_bits.to(tl.float32, bitcast=True) + better = (logit_i < logit_j) | ((logit_i == logit_j) & (pos < j)) + out_rank = out_rank + (valid & better).to(tl.int32) + dst_pos = base_idx + out_rank + take = valid & (dst_pos < TOPK) + idx_i = tl.load( + tle.gpu.local_ptr(s_histogram, (pos, )), + mask=take, + other=0, + ) + tl.store(s_out_indices_ptr + dst_pos, idx_i, mask=take) + tl.store(s_found_topk_values_ptr, TOPK) + tle.distributed_barrier(mesh) + + if is_rank0: + total_out = tl.minimum(tl.load(s_found_topk_values_ptr), TOPK) + flush_chunks = tl.cdiv(total_out, BLOCK_SIZE) + for flush_chunk in tl.range(0, flush_chunks): + pos = flush_chunk * BLOCK_SIZE + lane + mask = pos < total_out + out_vals = tl.load(s_out_indices_ptr + pos, mask=mask, other=-1) + tl.store(out_row + pos * stride_outn, out_vals, mask=mask) + tle.distributed_barrier(mesh) + + +@triton.jit +def triton_topk_selector_kernel( + x_ptr, + out_ptr, + cand0_ptr, 
+ cand1_ptr, + starts_ptr, + ends_ptr, + stride_xm, + stride_xn, + stride_outm, + stride_outn, + stride_c0m, + stride_c0n, + stride_c1m, + stride_c1n, + seq_len, + ASSUME_ALIGNED: tl.constexpr, + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + RADIX_BITS: tl.constexpr, +): + tl.static_assert(RADIX_BITS == 8, "triton_topk_selector_kernel currently expects 8-bit radix") + pid = tl.program_id(0) + row_start = tl.load(starts_ptr + pid).to(tl.int32) + row_end = tl.load(ends_ptr + pid).to(tl.int32) + row_ptr = x_ptr + pid * stride_xm + out_row = out_ptr + pid * stride_outm + cand0_row = cand0_ptr + pid * stride_c0m + cand1_row = cand1_ptr + pid * stride_c1m + + if ASSUME_ALIGNED: + tl.assume(row_start == 0) + tl.assume(row_end == seq_len) + tl.assume(stride_xn == 1) + tl.assume(stride_outn == 1) + seq_len = tl.multiple_of(seq_len, BLOCK_SIZE) + + lane = tl.arange(0, BLOCK_SIZE) + n_tiles = tl.cdiv(seq_len, BLOCK_SIZE) + RADIX_SIZE: tl.constexpr = 1 << RADIX_BITS + RADIX_MASK: tl.constexpr = RADIX_SIZE - 1 + bins = tl.arange(0, RADIX_SIZE) + + # Stage 1: 8-bit coarse prescreen on fp16-mapped keys. + coarse_counts = tl.zeros([RADIX_SIZE], dtype=tl.int32) + for t in tl.range(0, n_tiles): + offs = t * BLOCK_SIZE + lane + in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf")) + digit8 = _convert_to_uint16_hi8(x) + coarse_counts = coarse_counts + tl.histogram(digit8, RADIX_SIZE, mask=in_range) + + coarse_cumsum_desc = tl.cumsum(coarse_counts, axis=0, reverse=True) + topk_target = TOPK + coarse_cond = coarse_cumsum_desc > topk_target + coarse_threshold_bin = tl.max(tl.where(coarse_cond, bins, 0), axis=0).to(tl.int32) + coarse_counts_gt = tl.max(tl.where(bins == (coarse_threshold_bin + 1), coarse_cumsum_desc, 0), axis=0) + new_topk = topk_target - coarse_counts_gt + write_count = 0 + cand_count0 = 0 + + # Stage 2: write coarse winners and compact coarse-threshold candidates into cand0. 
+ for t in tl.range(0, n_tiles): + offs = t * BLOCK_SIZE + lane + in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end) + x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf")) + digit8 = _convert_to_uint16_hi8(x) + take = in_range & (digit8 > coarse_threshold_bin) + take_i32 = take.to(tl.int32) + pos = write_count + tl.cumsum(take_i32, axis=0) - 1 + mask = take & (pos < TOPK) + tl.store(out_row + pos * stride_outn, offs.to(tl.int32), mask=mask) + write_count = write_count + tl.sum(take_i32, axis=0) + + take_eq = in_range & (digit8 == coarse_threshold_bin) + take_eq_i32 = take_eq.to(tl.int32) + pos_eq = cand_count0 + tl.cumsum(take_eq_i32, axis=0) - 1 + tl.store(cand0_row + pos_eq * stride_c0n, offs.to(tl.int32), mask=take_eq) + cand_count0 = cand_count0 + tl.sum(take_eq_i32, axis=0) + + # Stage 3: four 8-bit refinements over compact candidate lists. + num_in = cand_count0 + for round_idx in tl.static_range(4): + if (new_topk > 0) & (num_in > 0): + shift: tl.constexpr = 24 - round_idx * 8 + desired = tl.zeros((), dtype=tl.uint32) + desired_mask = tl.zeros((), dtype=tl.uint32) + radix_mask_u32 = tl.zeros((), dtype=tl.uint32) + RADIX_MASK + k_to_find = new_topk + num_in_tiles = tl.cdiv(num_in, BLOCK_SIZE) + counts = tl.zeros([RADIX_SIZE], dtype=tl.int32) + + # Histogram on current candidate table. 
+ for t in tl.range(0, num_in_tiles): + pos = t * BLOCK_SIZE + lane + valid = pos < num_in + if round_idx & 1: + idx = tl.load(cand1_row + pos * stride_c1n, mask=valid, other=0) + else: + idx = tl.load(cand0_row + pos * stride_c0n, mask=valid, other=0) + x = tl.load(row_ptr + idx * stride_xn, mask=valid, other=float("-inf")) + x_key = _convert_to_uint32(x) + matches = (x_key & desired_mask) == desired + take = valid & matches + digit = ((x_key >> shift) & RADIX_MASK).to(tl.int32) + counts = counts + tl.histogram(digit, RADIX_SIZE, mask=take) + + cumsum_desc = tl.cumsum(counts, axis=0, reverse=True) + cond = cumsum_desc > k_to_find + threshold_bin = tl.max(tl.where(cond, bins, 0), axis=0).to(tl.int32) + counts_gt = tl.max(tl.where(bins == (threshold_bin + 1), cumsum_desc, 0), axis=0) + desired = desired | (threshold_bin.to(tl.uint32) << shift) + desired_mask = desired_mask | (radix_mask_u32 << shift) + new_topk = k_to_find - counts_gt + + out_count = write_count + next_count = 0 + for t in tl.range(0, num_in_tiles): + pos = t * BLOCK_SIZE + lane + valid = pos < num_in + if round_idx & 1: + idx = tl.load(cand1_row + pos * stride_c1n, mask=valid, other=0) + else: + idx = tl.load(cand0_row + pos * stride_c0n, mask=valid, other=0) + x = tl.load(row_ptr + idx * stride_xn, mask=valid, other=float("-inf")) + x_key = _convert_to_uint32(x) + digit = ((x_key >> shift) & RADIX_MASK).to(tl.int32) + + take_gt = valid & (digit > threshold_bin) + take_gt_i32 = take_gt.to(tl.int32) + out_pos_gt = out_count + tl.cumsum(take_gt_i32, axis=0) - 1 + out_mask_gt = take_gt & (out_pos_gt < TOPK) + tl.store(out_row + out_pos_gt * stride_outn, idx, mask=out_mask_gt) + out_count = out_count + tl.sum(take_gt_i32, axis=0) + + take_eq = valid & (digit == threshold_bin) + take_eq_i32 = take_eq.to(tl.int32) + if round_idx == 3: + out_pos_eq = out_count + tl.cumsum(take_eq_i32, axis=0) - 1 + out_mask_eq = take_eq & (out_pos_eq < TOPK) + tl.store(out_row + out_pos_eq * stride_outn, idx, 
mask=out_mask_eq) + out_count = out_count + tl.sum(take_eq_i32, axis=0) + else: + nxt_pos = next_count + tl.cumsum(take_eq_i32, axis=0) - 1 + if round_idx & 1: + tl.store(cand0_row + nxt_pos * stride_c0n, idx, mask=take_eq) + else: + tl.store(cand1_row + nxt_pos * stride_c1n, idx, mask=take_eq) + next_count = next_count + tl.sum(take_eq_i32, axis=0) + + write_count = out_count + num_in = next_count + + +# %% +# TileLang reference (optional) +# ----------------------------- + +if _HAVE_TILELANG: + _TL_PASS_CONFIGS = { + tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + } + _TL_KERNEL_CACHE = {} + + def convert_to_uint16(x): + hval = T.Cast(T.float16, x) + bits_uint = T.reinterpret(T.uint16, hval) + bits_uint = T.if_then_else(x < 0, ~bits_uint & 0xFFFF, bits_uint | 0x8000) + return bits_uint >> 8 + + def convert_to_uint32(x): + bits_uint = T.reinterpret(T.uint32, x) + bits_uint = T.if_then_else( + x < 0, + ~bits_uint & T.Cast(T.uint32, 0xFFFFFFFF), + bits_uint | T.Cast(T.uint32, 0x80000000), + ) + return bits_uint + + @tilelang.jit(pass_configs=_TL_PASS_CONFIGS) + def _tilelang_topk_impl(topk, in_dtype=T.float32, out_dtype=T.int32): + batch = T.dynamic("batch") + seq_len = T.dynamic("seq_len") + RADIX_LOCAL = 1 << 8 + BLOCK_SIZE = 1024 + SMEM_INPUT_SIZE = 4096 + + @T.prim_func + def tl_topk_kernel( + input: T.Tensor[(batch, seq_len), in_dtype], + index: T.Tensor[(batch, topk), out_dtype], + starts: T.Tensor[(batch), out_dtype], + ends: T.Tensor[(batch), out_dtype], + ): + with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): + tx = T.get_thread_binding() + + s_threshold_bin_id = T.alloc_shared([1], T.int32) + s_histogram = T.alloc_shared([RADIX_LOCAL + 1], T.int32) + s_num_input = T.alloc_shared([2], T.int32) + s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], T.int32) + + l_threshold_bin_id = T.alloc_var(T.int32) + l_new_topk = T.alloc_var(T.int32) + l_num_input = T.alloc_var(T.int32) + l_bin_id32 = T.alloc_var(T.int32) + l_val = T.alloc_var(T.int32) + 
l_start_pos = T.alloc_var(T.int32) + l_start_idx = T.alloc_var(T.int32) + l_end_idx = T.alloc_var(T.int32) + l_out_pos = T.alloc_var(T.int32) + + l_new_topk = topk + l_start_idx = starts[bx] + l_end_idx = ends[bx] + + T.fill(s_histogram, 0) + T.fill(s_num_input[0], 0) + T.sync_threads() + for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + input_idx = s * BLOCK_SIZE + tx + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: + inval_int16 = convert_to_uint16(input[bx, input_idx]) + T.atomic_add(s_histogram[inval_int16], 1) + T.sync_threads() + + if tx < RADIX_LOCAL: + for i in T.serial(8): + offset = 1 << i + T.sync_threads(3, RADIX_LOCAL) + if tx < RADIX_LOCAL - offset: + l_val = s_histogram[tx] + s_histogram[tx + offset] + T.sync_threads(3, RADIX_LOCAL) + if tx < RADIX_LOCAL - offset: + s_histogram[tx] = l_val + + T.sync_threads(3, RADIX_LOCAL) + if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx + T.sync_threads() + l_threshold_bin_id = s_threshold_bin_id[0] + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] + T.sync_threads() + + for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + T.sync_threads() + input_idx = s * BLOCK_SIZE + tx + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: + bin_id = convert_to_uint16(input[bx, input_idx]) + l_bin_id32 = T.Cast(T.int32, bin_id) + if l_bin_id32 > l_threshold_bin_id: + pos_gt0 = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + index[bx, pos_gt0] = input_idx + elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: + pos_eq0 = T.atomic_add(s_num_input[0], 1, return_prev=True) + s_input_idx[0, pos_eq0] = input_idx + + for round in T.serial(4): + if l_new_topk <= 0: + T.loop_break() + + r_idx = round % 2 + l_start_pos = topk - l_new_topk + + T.sync_threads() + T.fill(s_histogram, 0) + if tx == 0: + s_num_input[r_idx ^ 1] = 0 + T.sync_threads() + + l_num_input = s_num_input[r_idx] + for s 
in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): + if s * BLOCK_SIZE + tx < l_num_input: + l_bin_id32 = T.Cast( + T.int32, + ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> + (24 - round * 8)) & 0xFF), + ) + T.atomic_add(s_histogram[l_bin_id32], 1) + T.sync_threads() + + if tx < RADIX_LOCAL: + for i in T.serial(8): + offset = 1 << i + T.sync_threads(3, RADIX_LOCAL) + if tx < RADIX_LOCAL - offset: + l_val = s_histogram[tx] + s_histogram[tx + offset] + T.sync_threads(3, RADIX_LOCAL) + if tx < RADIX_LOCAL - offset: + s_histogram[tx] = l_val + + T.sync_threads(3, RADIX_LOCAL) + if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx + T.sync_threads() + + l_threshold_bin_id = s_threshold_bin_id[0] + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] + T.sync_threads() + + for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): + T.sync_threads() + if s * BLOCK_SIZE + tx < l_num_input: + l_bin_id32 = T.Cast( + T.int32, + ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> + (24 - round * 8)) & 0xFF), + ) + if l_bin_id32 > l_threshold_bin_id: + pos_gt_round = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, + return_prev=True) + l_start_pos + index[bx, pos_gt_round] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: + if round == 3: + l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, + return_prev=True) + l_start_pos + if l_out_pos < topk: + index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + else: + pos_eq_round = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True) + s_input_idx[r_idx ^ 1, pos_eq_round] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + + return tl_topk_kernel + + def tilelang_topk_selector(input, starts, ends, topk, out: Optional[torch.Tensor] = None): + batch, _ = input.shape + if out is None: + out = torch.zeros((batch, topk), dtype=torch.int32, device=input.device) + kernel = 
_TL_KERNEL_CACHE.get(topk) + if kernel is None: + kernel = _tilelang_topk_impl(topk) + _TL_KERNEL_CACHE[topk] = kernel + kernel(input, out, starts, ends) + return out + + +# %% +# Python wrappers +# --------------- + + +def _supports_tle_cluster_remote() -> bool: + if not torch.cuda.is_available(): + return False + major, _minor = torch.cuda.get_device_capability() + return major >= 9 + + +def tle_topk_selector( + x, + starts, + ends, + topk, + block_size=1024, + out: Optional[torch.Tensor] = None, + assume_aligned: Optional[bool] = None, + use_radix_final: Optional[bool] = None, +): + if x.dtype != torch.float32: + x = x.float() + batch, seq_len = x.shape + if out is None: + out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + tle_block_size = TLE_FIXED_BLOCK_SIZE + if use_radix_final is None: + use_radix_final = seq_len >= TLE_RADIX_FINAL_SEQ_LEN_THRESHOLD + + if assume_aligned is None: + assume_aligned = (x.is_contiguous() and out.is_contiguous() and (seq_len % tle_block_size == 0) + and torch.all(starts == 0).item() and torch.all(ends == seq_len).item()) + + batch, seq_len = x.shape + grid = (batch, ) + tle_topk_selector_kernel[grid]( + x, + out, + starts, + ends, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + seq_len, + ASSUME_ALIGNED=assume_aligned, + TOPK=topk, + BLOCK_SIZE=tle_block_size, + USE_RADIX_FINAL=use_radix_final, + num_warps=TLE_FIXED_NUM_WARPS, + num_stages=TLE_FIXED_NUM_STAGES, + ) + return out + + +def tle_topk_selector_1024threads( + x, + starts, + ends, + topk, + out: Optional[torch.Tensor] = None, + assume_aligned: Optional[bool] = None, + use_radix_final: Optional[bool] = None, +): + if x.dtype != torch.float32: + x = x.float() + batch, seq_len = x.shape + if out is None: + out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + tle_block_size = 1024 + if use_radix_final is None: + use_radix_final = seq_len >= TLE_RADIX_FINAL_SEQ_LEN_THRESHOLD + + if assume_aligned is None: + 
assume_aligned = (x.is_contiguous() and out.is_contiguous() and (seq_len % tle_block_size == 0) + and torch.all(starts == 0).item() and torch.all(ends == seq_len).item()) + + grid = (batch, ) + tle_topk_selector_kernel[grid]( + x, + out, + starts, + ends, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + seq_len, + ASSUME_ALIGNED=assume_aligned, + TOPK=topk, + BLOCK_SIZE=tle_block_size, + USE_RADIX_FINAL=use_radix_final, + num_warps=tle_block_size // 32, + num_stages=TLE_FIXED_NUM_STAGES, + ) + return out + + +def tle_topk_selector_smem( + x, + starts, + ends, + topk, + block_size=1024, + out: Optional[torch.Tensor] = None, + assume_aligned: Optional[bool] = None, + use_radix_final: Optional[bool] = None, +): + if x.dtype != torch.float32: + x = x.float() + batch, seq_len = x.shape + if out is None: + out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + tle_block_size = TLE_SMEM_BLOCK_SIZE + if use_radix_final is None: + use_radix_final = seq_len >= TLE_RADIX_FINAL_SEQ_LEN_THRESHOLD + if assume_aligned is None: + assume_aligned = (x.is_contiguous() and out.is_contiguous() and (seq_len % tle_block_size == 0) + and torch.all(starts == 0).item() and torch.all(ends == seq_len).item()) + + grid = (batch, ) + tle_topk_selector_kernel_smem[grid]( + x, + out, + starts, + ends, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + seq_len, + ASSUME_ALIGNED=assume_aligned, + TOPK=topk, + BLOCK_SIZE=tle_block_size, + SMEM_INPUT_SIZE=TLE_SMEM_INPUT_SIZE, + USE_RADIX_FINAL=use_radix_final, + num_warps=TLE_SMEM_NUM_WARPS, + num_stages=TLE_SMEM_NUM_STAGES, + ) + return out + + +def tle_topk_selector_smem_cluster( + x, + starts, + ends, + topk, + block_size=1024, + out: Optional[torch.Tensor] = None, + assume_aligned: Optional[bool] = None, + use_radix_final: Optional[bool] = None, +): + if not _supports_tle_cluster_remote(): + raise RuntimeError("TLE-Cluster requires CUDA SM90+") + if x.dtype != torch.float32: + x = x.float() + 
batch, seq_len = x.shape + if out is None: + out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + tle_block_size = TLE_SMEM_BLOCK_SIZE + if use_radix_final is None: + use_radix_final = seq_len >= TLE_RADIX_FINAL_SEQ_LEN_THRESHOLD + if assume_aligned is None: + assume_aligned = (x.is_contiguous() and out.is_contiguous() and (seq_len % tle_block_size == 0) + and torch.all(starts == 0).item() and torch.all(ends == seq_len).item()) + + grid = (batch, ) + tle_topk_selector_kernel_smem_cluster[grid]( + x, + out, + starts, + ends, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + seq_len, + mesh=BLOCK_CLUSTER_MESH_8, + CLUSTER_SIZE=TLE_SMEM_CLUSTER_SIZE, + ASSUME_ALIGNED=assume_aligned, + TOPK=topk, + BLOCK_SIZE=tle_block_size, + USE_RADIX_FINAL=use_radix_final, + num_ctas=1, + num_warps=TLE_SMEM_NUM_WARPS, + num_stages=TLE_SMEM_NUM_STAGES, + ) + return out + + +def triton_topk_selector( + x, + starts, + ends, + topk, + block_size=1024, + out: Optional[torch.Tensor] = None, + cand0: Optional[torch.Tensor] = None, + cand1: Optional[torch.Tensor] = None, + assume_aligned: Optional[bool] = None, +): + if x.dtype != torch.float32: + x = x.float() + batch, seq_len = x.shape + if out is None: + out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + if cand0 is None: + cand0 = torch.empty((batch, seq_len), dtype=torch.int32, device=x.device) + if cand1 is None: + cand1 = torch.empty((batch, seq_len), dtype=torch.int32, device=x.device) + + if assume_aligned is None: + assume_aligned = (x.is_contiguous() and out.is_contiguous() and (seq_len % block_size == 0) + and torch.all(starts == 0).item() and torch.all(ends == seq_len).item()) + + assert cand0.shape == (batch, seq_len) and cand0.dtype == torch.int32 and cand0.is_cuda + assert cand1.shape == (batch, seq_len) and cand1.dtype == torch.int32 and cand1.is_cuda + + # Triton kernel uses kernel-specific tuning to avoid slow/unstable configs. 
+ kernel_num_warps = 4 if block_size >= 1024 else 8 + + grid = (batch, ) + triton_topk_selector_kernel[grid]( + x, + out, + cand0, + cand1, + starts, + ends, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + cand0.stride(0), + cand0.stride(1), + cand1.stride(0), + cand1.stride(1), + seq_len, + ASSUME_ALIGNED=assume_aligned, + TOPK=topk, + BLOCK_SIZE=block_size, + RADIX_BITS=8, + num_warps=kernel_num_warps, + num_stages=1, + ) + return out + + +# %% +# TRT-LLM CUDA reference (optional) +# --------------------------------- + +TRTLLM_INDEXER_TOPK_KERNEL_URL = ("https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/" + "cpp/tensorrt_llm/kernels/indexerTopK.cu") +FLASHINFER_TOPK_CUH_URL = ("https://raw.githubusercontent.com/flashinfer-ai/flashinfer/refs/heads/main/" + "include/flashinfer/topk.cuh") +FLASHINFER_INCLUDE_BASE_URL = ("https://raw.githubusercontent.com/flashinfer-ai/flashinfer/refs/heads/main/" + "include/flashinfer/") +SGLANG_TOPK_KERNEL_URL = ("https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/" + "sgl-kernel/csrc/elementwise/topk.cu") + +TRTLLM_INDEXER_TOPK_BINDING_CPP = r""" +#include +#include + +namespace kernels { +void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, float* outLogitsAux, + int* outIndicesAux, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0, + int const stride1, int const next_n, int const topK, cudaStream_t const stream); +void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* indices, + int const numRows, int const numColumns, int const stride0, int const stride1, int const topK, + cudaStream_t const stream); +} + +void indexer_topk_decode(torch::Tensor logits, torch::Tensor seq_lens, torch::Tensor out, + torch::Tensor out_logits_aux, torch::Tensor out_indices_aux, int64_t next_n, int64_t topk) { + TORCH_CHECK(logits.is_cuda() && seq_lens.is_cuda() && out.is_cuda() && 
out_logits_aux.is_cuda() && out_indices_aux.is_cuda()); + TORCH_CHECK(logits.dtype() == torch::kFloat32); + TORCH_CHECK(seq_lens.dtype() == torch::kInt32); + TORCH_CHECK(out.dtype() == torch::kInt32); + TORCH_CHECK(out_logits_aux.dtype() == torch::kFloat32); + TORCH_CHECK(out_indices_aux.dtype() == torch::kInt32); + TORCH_CHECK(logits.dim() == 2 && seq_lens.dim() == 1 && out.dim() == 2); + + int numRows = static_cast(logits.size(0)); + int numColumns = static_cast(logits.size(1)); + TORCH_CHECK(seq_lens.size(0) * static_cast(next_n) == numRows); + TORCH_CHECK(out.size(0) == numRows && out.size(1) == topk); + TORCH_CHECK(out_logits_aux.size(0) == numRows && out_indices_aux.size(0) == numRows); + TORCH_CHECK(out_logits_aux.size(1) >= topk && out_indices_aux.size(1) >= topk); + + kernels::invokeIndexerTopKDecode( + logits.data_ptr(), + seq_lens.data_ptr(), + out.data_ptr(), + out_logits_aux.data_ptr(), + out_indices_aux.data_ptr(), + 200 * 1000, + numRows, + numColumns, + static_cast(logits.stride(0)), + static_cast(logits.stride(1)), + static_cast(next_n), + static_cast(topk), + at::cuda::getDefaultCUDAStream(logits.get_device())); +} + +void indexer_topk_prefill(torch::Tensor logits, torch::Tensor row_starts, torch::Tensor row_ends, torch::Tensor out, + int64_t topk) { + TORCH_CHECK(logits.is_cuda() && row_starts.is_cuda() && row_ends.is_cuda() && out.is_cuda()); + TORCH_CHECK(logits.dtype() == torch::kFloat32); + TORCH_CHECK(row_starts.dtype() == torch::kInt32); + TORCH_CHECK(row_ends.dtype() == torch::kInt32); + TORCH_CHECK(out.dtype() == torch::kInt32); + TORCH_CHECK(logits.dim() == 2 && row_starts.dim() == 1 && row_ends.dim() == 1 && out.dim() == 2); + + int numRows = static_cast(logits.size(0)); + int numColumns = static_cast(logits.size(1)); + TORCH_CHECK(row_starts.size(0) == numRows && row_ends.size(0) == numRows); + TORCH_CHECK(out.size(0) == numRows && out.size(1) == topk); + + kernels::invokeIndexerTopKPrefill( + logits.data_ptr(), + 
row_starts.data_ptr(), + row_ends.data_ptr(), + out.data_ptr(), + numRows, + numColumns, + static_cast(logits.stride(0)), + static_cast(logits.stride(1)), + static_cast(topk), + at::cuda::getDefaultCUDAStream(logits.get_device())); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("indexer_topk_decode", &indexer_topk_decode, "TRT-LLM indexerTopK decode"); + m.def("indexer_topk_prefill", &indexer_topk_prefill, "TRT-LLM indexerTopK prefill"); +} +""" + + +def _patch_trtllm_indexer_topk_source(src: str, prefill_threads: int = 512) -> str: + for old in [ + '#include "moeTopKFuncs.cuh"\n', + '#include "tensorrt_llm/common/config.h"\n', + '#include "tensorrt_llm/common/cudaTypeUtils.cuh"\n', + '#include "tensorrt_llm/common/envUtils.h"\n', + '#include "tensorrt_llm/kernels/noAuxTcKernels.h"\n', + ]: + src = src.replace(old, "") + + if prefill_threads != 512: + fn_marker = "void invokeIndexerTopKPrefill(" + fn_pos = src.find(fn_marker) + if fn_pos < 0: + raise RuntimeError("TRT-LLM source format changed: invokeIndexerTopKPrefill not found") + tail = src[fn_pos:] + marker = "constexpr int kNumThreadsPerBlock = 512;" + rel = tail.find(marker) + if rel < 0: + raise RuntimeError("TRT-LLM source format changed: prefill thread marker not found") + abs_pos = fn_pos + rel + replacement = f"constexpr int kNumThreadsPerBlock = {prefill_threads};" + src = src[:abs_pos] + replacement + src[abs_pos + len(marker):] + + # Make the standalone source compile under torch cpp_extension. 
+ shim = r""" +#include +#include +#include +#include +#include +#define TRTLLM_NAMESPACE_BEGIN +#define TRTLLM_NAMESPACE_END +#define TLLM_CHECK_WITH_INFO(cond, msg) TORCH_CHECK((cond), msg) +namespace tensorrt_llm { namespace common { +inline bool getEnvEnablePDL() { return false; } +inline void sync_check_cuda_error(cudaStream_t) { C10_CUDA_CHECK(cudaGetLastError()); } +}} // namespace tensorrt_llm::common +""" + + return shim + src + + +@lru_cache(maxsize=4) +def _load_embedded_trtllm_indexer_topk(prefill_threads: int = 512): + try: + from torch.utils.cpp_extension import load_inline + except Exception as ex: + print(f"warning: cannot import torch cpp_extension for trtllm topk: {ex}") + return None + + try: + with urllib.request.urlopen(TRTLLM_INDEXER_TOPK_KERNEL_URL, timeout=20) as resp: + cuda_src = resp.read().decode("utf-8") + except Exception as ex: + print(f"warning: failed to download trtllm indexerTopK.cu: {ex}") + return None + + cuda_src = _patch_trtllm_indexer_topk_source(cuda_src, prefill_threads=prefill_threads) + digest = hashlib.sha1((TRTLLM_INDEXER_TOPK_BINDING_CPP + cuda_src).encode("utf-8")).hexdigest()[:12] + ext_name = f"flagtree_trtllm_indexer_topk_{digest}" + + try: + module = load_inline( + name=ext_name, + cpp_sources=[TRTLLM_INDEXER_TOPK_BINDING_CPP], + cuda_sources=[cuda_src], + functions=None, + extra_cuda_cflags=["-O3"], + with_cuda=True, + verbose=False, + ) + except Exception as ex: + print(f"warning: failed to compile embedded trtllm indexerTopK kernel: {ex}") + return None + + decode_fn = getattr(module, "indexer_topk_decode", None) + prefill_fn = getattr(module, "indexer_topk_prefill", None) + if decode_fn is None: + print("warning: embedded trtllm topk module has no indexer_topk_decode symbol") + if prefill_fn is None: + print("warning: embedded trtllm topk module has no indexer_topk_prefill symbol") + if decode_fn is None and prefill_fn is None: + return None + return module + + +def trtllm_cuda_topk_selector_decode( + x: 
torch.Tensor, + starts: torch.Tensor, + ends: torch.Tensor, + topk: int, + out: Optional[torch.Tensor] = None, + out_logits_aux: Optional[torch.Tensor] = None, + out_indices_aux: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if x.dtype != torch.float32: + x = x.float() + if out is None: + out = torch.full((x.shape[0], topk), -1, dtype=torch.int32, device=x.device) + # TRT-LLM decode path uses auxiliary buffers in long-sequence branch. + aux_cols = 10 * topk + if out_logits_aux is None: + out_logits_aux = torch.empty((x.shape[0], aux_cols), dtype=torch.float32, device=x.device) + if out_indices_aux is None: + out_indices_aux = torch.empty((x.shape[0], aux_cols), dtype=torch.int32, device=x.device) + module = _load_embedded_trtllm_indexer_topk() + if module is None: + raise RuntimeError("TRT-LLM indexerTopK extension unavailable") + fn = getattr(module, "indexer_topk_decode", None) + if fn is None: + raise RuntimeError("TRT-LLM decode symbol unavailable") + fn(x, ends, out, out_logits_aux, out_indices_aux, 1, int(topk)) + return out + + +def trtllm_cuda_topk_selector_prefill( + x: torch.Tensor, + starts: torch.Tensor, + ends: torch.Tensor, + topk: int, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if x.dtype != torch.float32: + x = x.float() + if out is None: + out = torch.full((x.shape[0], topk), -1, dtype=torch.int32, device=x.device) + module = _load_embedded_trtllm_indexer_topk(prefill_threads=512) + if module is None: + raise RuntimeError("TRT-LLM indexerTopK extension unavailable") + fn = getattr(module, "indexer_topk_prefill", None) + if fn is None: + raise RuntimeError("TRT-LLM prefill symbol unavailable") + fn(x, starts, ends, out, int(topk)) + return out + + +def trtllm_cuda_topk_selector_prefill_1024threads( + x: torch.Tensor, + starts: torch.Tensor, + ends: torch.Tensor, + topk: int, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if x.dtype != torch.float32: + x = x.float() + if out is None: + out = 
torch.full((x.shape[0], topk), -1, dtype=torch.int32, device=x.device) + module = _load_embedded_trtllm_indexer_topk(prefill_threads=1024) + if module is None: + raise RuntimeError("TRT-LLM indexerTopK extension unavailable") + fn = getattr(module, "indexer_topk_prefill", None) + if fn is None: + raise RuntimeError("TRT-LLM prefill symbol unavailable") + fn(x, starts, ends, out, int(topk)) + return out + + +FLASHINFER_TOPK_BINDING_CPP = r""" +#include + +void flashinfer_topk_cuda(torch::Tensor logits, torch::Tensor out_indices, torch::Tensor out_values, + torch::Tensor row_states_buffer, int64_t topk); +int64_t flashinfer_row_state_nbytes(); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("flashinfer_topk_cuda", &flashinfer_topk_cuda, "FlashInfer TopKDispatch CUDA"); + m.def("flashinfer_row_state_nbytes", &flashinfer_row_state_nbytes, "FlashInfer RadixRowState bytes"); +} +""" + +FLASHINFER_TOPK_CUDA_SRC = r""" +#include +#include +#include +#include +#include "flashinfer/math.cuh" +#include "flashinfer/topk.cuh" + +void flashinfer_topk_cuda(torch::Tensor logits, torch::Tensor out_indices, torch::Tensor out_values, + torch::Tensor row_states_buffer, int64_t topk) { + TORCH_CHECK(logits.is_cuda() && out_indices.is_cuda() && out_values.is_cuda() && row_states_buffer.is_cuda()); + TORCH_CHECK(logits.dtype() == torch::kFloat32); + TORCH_CHECK(out_indices.dtype() == torch::kInt32); + TORCH_CHECK(out_values.dtype() == torch::kFloat32); + TORCH_CHECK(row_states_buffer.dtype() == torch::kUInt8); + TORCH_CHECK(logits.dim() == 2 && out_indices.dim() == 2 && out_values.dim() == 2); + + int num_rows = static_cast(logits.size(0)); + int max_len = static_cast(logits.size(1)); + TORCH_CHECK(out_indices.size(0) == num_rows && out_indices.size(1) == topk); + TORCH_CHECK(out_values.size(0) == num_rows && out_values.size(1) == topk); + TORCH_CHECK(topk > 0 && topk <= max_len); + + const int64_t need_nbytes = static_cast(num_rows) * + 
static_cast(sizeof(flashinfer::sampling::RadixRowState)); + TORCH_CHECK(row_states_buffer.numel() >= need_nbytes); + + auto* row_states = reinterpret_cast(row_states_buffer.data_ptr()); + auto err = flashinfer::sampling::TopKDispatch( + logits.data_ptr(), + out_indices.data_ptr(), + out_values.data_ptr(), + static_cast(num_rows), + static_cast(topk), + static_cast(max_len), + row_states, + at::cuda::getDefaultCUDAStream(logits.get_device())); + C10_CUDA_CHECK(err); +} + +int64_t flashinfer_row_state_nbytes() { + return static_cast(sizeof(flashinfer::sampling::RadixRowState)); +} +""" + + +@lru_cache(maxsize=1) +def _prepare_embedded_flashinfer_headers(): + import re + import tempfile + from pathlib import Path + + include_root = Path(tempfile.gettempdir()) / "flagtree_flashinfer_include_main" + header_root = include_root / "flashinfer" + header_root.mkdir(parents=True, exist_ok=True) + + queue = ["topk.cuh", "math.cuh"] + seen = set() + while queue: + header = queue.pop(0) + if header in seen: + continue + seen.add(header) + local_path = header_root / header + if local_path.exists(): + src = local_path.read_text(encoding="utf-8") + else: + header_url = FLASHINFER_TOPK_CUH_URL if header == "topk.cuh" else (FLASHINFER_INCLUDE_BASE_URL + header) + with urllib.request.urlopen(header_url, timeout=20) as resp: + src = resp.read().decode("utf-8") + local_path.write_text(src, encoding="utf-8") + for inc in re.findall(r'^#include\s+"([^"]+)"', src, flags=re.M): + if inc not in seen: + queue.append(inc) + return str(include_root) + + +@lru_cache(maxsize=1) +def _load_embedded_flashinfer_topk(): + try: + from pathlib import Path + from torch.utils.cpp_extension import load_inline + except Exception as ex: + print(f"warning: cannot import torch cpp_extension for flashinfer topk: {ex}") + return None + + try: + include_dir = _prepare_embedded_flashinfer_headers() + topk_src = (Path(include_dir) / "flashinfer" / "topk.cuh").read_text(encoding="utf-8") + except Exception as ex: + 
print(f"warning: failed to prepare flashinfer headers: {ex}") + return None + + digest = hashlib.sha1( + (FLASHINFER_TOPK_BINDING_CPP + FLASHINFER_TOPK_CUDA_SRC + topk_src).encode("utf-8")).hexdigest()[:12] + ext_name = f"flagtree_flashinfer_topk_{digest}" + + try: + module = load_inline( + name=ext_name, + cpp_sources=[FLASHINFER_TOPK_BINDING_CPP], + cuda_sources=[FLASHINFER_TOPK_CUDA_SRC], + functions=None, + extra_cuda_cflags=[ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + ], + with_cuda=True, + extra_include_paths=[include_dir], + verbose=False, + ) + except Exception as ex: + print(f"warning: failed to compile embedded flashinfer topk: {ex}") + return None + + if getattr(module, "flashinfer_topk_cuda", None) is None: + print("warning: embedded flashinfer module has no flashinfer_topk_cuda symbol") + return None + if getattr(module, "flashinfer_row_state_nbytes", None) is None: + print("warning: embedded flashinfer module has no flashinfer_row_state_nbytes symbol") + return None + return module + + +def flashinfer_cuda_topk_selector( + x: torch.Tensor, + starts: torch.Tensor, + ends: torch.Tensor, + topk: int, + out: Optional[torch.Tensor] = None, + out_values: Optional[torch.Tensor] = None, + row_states: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if x.dtype != torch.float32: + x = x.float() + if out is None: + out = torch.full((x.shape[0], topk), -1, dtype=torch.int32, device=x.device) + if out_values is None: + out_values = torch.empty((x.shape[0], topk), dtype=torch.float32, device=x.device) + module = _load_embedded_flashinfer_topk() + if module is None: + raise RuntimeError("FlashInfer topk extension unavailable") + if row_states is None: + row_state_nbytes = int(module.flashinfer_row_state_nbytes()) + row_states = torch.zeros((x.shape[0] * row_state_nbytes, ), dtype=torch.uint8, device=x.device) + # Current benchmark path uses starts==0 
and ends==seq_len; decode/ragged transforms are not used. + module.flashinfer_topk_cuda(x, out, out_values, row_states, int(topk)) + return out + + +SGLANG_TOPK_BINDING_CPP = r""" +#include +#include + +void fast_topk_interface( + const at::Tensor& score, + at::Tensor& indices, + const at::Tensor& lengths, + std::optional row_starts_opt); + +void sglang_fast_topk(torch::Tensor score, torch::Tensor lengths, torch::Tensor out) { + TORCH_CHECK(score.is_cuda() && lengths.is_cuda() && out.is_cuda()); + TORCH_CHECK(score.dtype() == torch::kFloat32); + TORCH_CHECK(lengths.dtype() == torch::kInt32); + TORCH_CHECK(out.dtype() == torch::kInt32); + TORCH_CHECK(score.dim() == 2 && lengths.dim() == 1 && out.dim() == 2); + TORCH_CHECK(score.size(0) == lengths.size(0)); + TORCH_CHECK(score.size(0) == out.size(0)); + TORCH_CHECK(out.size(1) == 2048, "sglang fast_topk kernel is fixed TopK=2048"); + fast_topk_interface(score, out, lengths, std::nullopt); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sglang_fast_topk", &sglang_fast_topk, "SGLang fast topk"); +} +""" + + +@lru_cache(maxsize=1) +def _load_embedded_sglang_topk(): + try: + from torch.utils.cpp_extension import load_inline + except Exception as ex: + print(f"warning: cannot import torch cpp_extension for sglang topk: {ex}") + return None + + try: + with urllib.request.urlopen(SGLANG_TOPK_KERNEL_URL, timeout=20) as resp: + cuda_src = resp.read().decode("utf-8") + except Exception as ex: + print(f"warning: failed to download sglang topk.cu: {ex}") + return None + + digest = hashlib.sha1((SGLANG_TOPK_BINDING_CPP + cuda_src).encode("utf-8")).hexdigest()[:12] + ext_name = f"flagtree_sglang_topk_{digest}" + try: + module = load_inline( + name=ext_name, + cpp_sources=[SGLANG_TOPK_BINDING_CPP], + cuda_sources=[cuda_src], + functions=None, + extra_cuda_cflags=["-O3"], + with_cuda=True, + verbose=False, + ) + except Exception as ex: + print(f"warning: failed to compile embedded sglang topk kernel: {ex}") + return None + 
+ fn = getattr(module, "sglang_fast_topk", None) + if fn is None: + print("warning: embedded sglang topk module has no sglang_fast_topk symbol") + return fn + + +def sglang_cuda_topk_selector( + x: torch.Tensor, + starts: torch.Tensor, + ends: torch.Tensor, + topk: int, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if topk != 2048: + raise RuntimeError("sglang fast_topk kernel only supports topk=2048") + if x.dtype != torch.float32: + x = x.float() + if out is None: + out = torch.full((x.shape[0], topk), -1, dtype=torch.int32, device=x.device) + lengths = (ends - starts).to(torch.int32) + fn = _load_embedded_sglang_topk() + if fn is None: + raise RuntimeError("SGLang topk extension unavailable") + fn(x, lengths, out) + return out + + +# %% +# Correctness & benchmarking +# -------------------------- + + +def _torch_topk_indices(x, starts, ends, topk): + batch, _ = x.shape + out = torch.empty((batch, topk), dtype=torch.int32, device=x.device) + for i in range(batch): + start = int(starts[i].item()) + end = int(ends[i].item()) + vals, idx = torch.topk(x[i, start:end], topk, dim=0) + out[i] = idx.to(torch.int32) + start + return out + + +def _recall(pred, ref): + batch = pred.shape[0] + k = ref.shape[1] + hits = 0 + for i in range(batch): + pred_set = set(pred[i].tolist()) + ref_set = set(ref[i].tolist()) + hits += len(pred_set & ref_set) + return hits / (batch * k) + + +_BENCH_PROVIDERS = (["triton"] + ["trtllm-prefill"] + ["trtllm-prefill-1024threads"] + ["flashinfer-cuda"] + + ["tle-trt"] + ["tle-trt-1024threads"] + ["tle-cluster"] + (["tilelang"] if _HAVE_TILELANG else [])) +_BENCH_NAMES = (["Triton"] + ["TRTLLM-Prefill"] + ["TRTLLM-Prefill-1024T"] + ["FlashInfer"] + ["TLE-TRT"] + + ["TLE-TRT-1024T"] + ["TLE-Cluster"] + (["TileLang"] if _HAVE_TILELANG else [])) +_BENCH_STYLES = ([("red", "-")] + [("black", "-")] + [("brown", "-")] + [("gray", "-")] + [("orange", "-")] + + [("olive", "-")] + [("teal", "-")] + ([("blue", "-")] if _HAVE_TILELANG else 
[])) +_BENCH_XVALS = [ + (1, 131072, 2048), + (1, 262144, 2048), + (1, 524288, 2048), + (64, 4096, 128), + (64, 8192, 256), + (64, 32768, 1024), + (64, 32768, 2048), + (64, 131072, 2048), + (64, 524288, 2048), +] +_TILELANG_SKIP_SEQ_LEN_MIN = 262144 + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch", "seq_len", "topk"], + x_vals=_BENCH_XVALS, + x_log=True, + line_arg="provider", + line_vals=_BENCH_PROVIDERS, + line_names=_BENCH_NAMES, + styles=_BENCH_STYLES, + ylabel="ms", + plot_name="topk-selector", + args={}, + )) +def benchmark(batch, seq_len, topk, provider, block_size, warmup, rep): + torch.manual_seed(1) + x = torch.randn(batch, seq_len, device=DEVICE, dtype=torch.float32) + starts = torch.zeros(batch, dtype=torch.int32, device=DEVICE) + ends = torch.full((batch, ), seq_len, dtype=torch.int32, device=DEVICE) + assume_aligned = (seq_len % block_size == 0) + quantiles = [0.5, 0.2, 0.8] + + if provider == "tle-trt": + tle_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + + def run(): + tle_topk_selector( + x, + starts, + ends, + topk, + block_size=block_size, + out=tle_out, + assume_aligned=assume_aligned, + ) + + elif provider == "tle-trt-1024threads": + tle_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + + def run(): + tle_topk_selector_1024threads( + x, + starts, + ends, + topk, + out=tle_out, + assume_aligned=assume_aligned, + ) + + elif provider == "tle-cluster": + if not _supports_tle_cluster_remote(): + return float("nan"), float("nan"), float("nan") + if batch >= 48 and seq_len >= 131072: + return float("nan"), float("nan"), float("nan") + tle_smem_cluster_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + + def run(): + tle_topk_selector_smem_cluster( + x, + starts, + ends, + topk, + block_size=block_size, + out=tle_smem_cluster_out, + assume_aligned=assume_aligned, + ) + + elif provider == "triton": + triton_out = torch.full((batch, topk), -1, 
dtype=torch.int32, device=x.device) + triton_cand0 = torch.empty((batch, seq_len), dtype=torch.int32, device=x.device) + triton_cand1 = torch.empty((batch, seq_len), dtype=torch.int32, device=x.device) + + def run(): + triton_topk_selector( + x, + starts, + ends, + topk, + block_size=block_size, + out=triton_out, + cand0=triton_cand0, + cand1=triton_cand1, + assume_aligned=assume_aligned, + ) + + elif provider == "trtllm-decode": + if _load_embedded_trtllm_indexer_topk() is None: + raise RuntimeError("TRT-LLM indexerTopK extension unavailable") + trtllm_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + trtllm_out_logits_aux = torch.empty((batch, 10 * topk), dtype=torch.float32, device=x.device) + trtllm_out_indices_aux = torch.empty((batch, 10 * topk), dtype=torch.int32, device=x.device) + + def run(): + trtllm_cuda_topk_selector_decode( + x, + starts, + ends, + topk, + out=trtllm_out, + out_logits_aux=trtllm_out_logits_aux, + out_indices_aux=trtllm_out_indices_aux, + ) + + elif provider == "trtllm-prefill": + if _load_embedded_trtllm_indexer_topk() is None: + raise RuntimeError("TRT-LLM indexerTopK extension unavailable") + trtllm_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + + def run(): + trtllm_cuda_topk_selector_prefill( + x, + starts, + ends, + topk, + out=trtllm_out, + ) + + elif provider == "trtllm-prefill-1024threads": + if _load_embedded_trtllm_indexer_topk(prefill_threads=1024) is None: + raise RuntimeError("TRT-LLM indexerTopK extension unavailable") + trtllm_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + + def run(): + trtllm_cuda_topk_selector_prefill_1024threads( + x, + starts, + ends, + topk, + out=trtllm_out, + ) + + elif provider == "flashinfer-cuda": + module = _load_embedded_flashinfer_topk() + if module is None: + return float("nan"), float("nan"), float("nan") + flashinfer_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + 
flashinfer_out_values = torch.empty((batch, topk), dtype=torch.float32, device=x.device) + row_state_nbytes = int(module.flashinfer_row_state_nbytes()) + flashinfer_row_states = torch.zeros((batch * row_state_nbytes, ), dtype=torch.uint8, device=x.device) + + def run(): + flashinfer_cuda_topk_selector( + x, + starts, + ends, + topk, + out=flashinfer_out, + out_values=flashinfer_out_values, + row_states=flashinfer_row_states, + ) + + elif provider == "sglang-cuda": + if topk != 2048: + return float("nan"), float("nan"), float("nan") + if _load_embedded_sglang_topk() is None: + return float("nan"), float("nan"), float("nan") + sglang_out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device) + + def run(): + sglang_cuda_topk_selector(x, starts, ends, topk, out=sglang_out) + + elif provider == "torch": + + def run(): + torch.topk(x, topk, dim=-1)[1] + + else: + if not _HAVE_TILELANG: + return float("nan"), float("nan"), float("nan") + if seq_len >= _TILELANG_SKIP_SEQ_LEN_MIN: + return float("nan"), float("nan"), float("nan") + tilelang_out = torch.zeros((batch, topk), dtype=torch.int32, device=x.device) + + def run(): + tilelang_topk_selector(x, starts, ends, topk, out=tilelang_out) + + ms, min_ms, max_ms = triton.testing.do_bench( + run, + quantiles=quantiles, + warmup=warmup, + rep=rep, + ) + return ms, max_ms, min_ms + + +def run_correctness(batch, seq_len, topk, block_size): + torch.manual_seed(1) + x = torch.randn(batch, seq_len, device=DEVICE, dtype=torch.float32) + starts = torch.zeros(batch, dtype=torch.int32, device=DEVICE) + ends = torch.full((batch, ), seq_len, dtype=torch.int32, device=DEVICE) + assume_aligned = (seq_len % block_size == 0) + + ref = _torch_topk_indices(x, starts, ends, topk) + + tle_out = tle_topk_selector( + x, + starts, + ends, + topk, + block_size=block_size, + assume_aligned=assume_aligned, + ) + tle_smem_out = tle_topk_selector_smem( + x, + starts, + ends, + topk, + block_size=block_size, + assume_aligned=assume_aligned, 
+ ) + tle_smem_cluster_out = None + if _supports_tle_cluster_remote(): + tle_smem_cluster_out = tle_topk_selector_smem_cluster( + x, + starts, + ends, + topk, + block_size=block_size, + assume_aligned=assume_aligned, + ) + + print(f"TLE recall vs torch.topk: {_recall(tle_out, ref):.4f}") + print(f"TLE-SMEM recall vs torch.topk: {_recall(tle_smem_out, ref):.4f}") + if tle_smem_cluster_out is not None: + print(f"TLE-Cluster recall vs torch.topk: {_recall(tle_smem_cluster_out, ref):.4f}") + else: + print("TLE-Cluster not available; skipping cluster correctness.") + triton_out = triton_topk_selector( + x, + starts, + ends, + topk, + block_size=block_size, + assume_aligned=assume_aligned, + ) + print(f"Triton recall vs torch.topk: {_recall(triton_out, ref):.4f}") + print(f"TLE recall vs Triton: {_recall(tle_out, triton_out):.4f}") + print(f"TLE-SMEM recall vs Triton: {_recall(tle_smem_out, triton_out):.4f}") + if tle_smem_cluster_out is not None: + print(f"TLE-Cluster recall vs Triton: {_recall(tle_smem_cluster_out, triton_out):.4f}") + + trtllm_fn = _load_embedded_trtllm_indexer_topk() + if trtllm_fn is not None: + trtllm_out = trtllm_cuda_topk_selector_decode(x, starts, ends, topk) + print(f"TRTLLM-CUDA-Decode recall vs torch.topk: {_recall(trtllm_out, ref):.4f}") + print(f"TRTLLM-CUDA-Decode recall vs Triton: {_recall(trtllm_out, triton_out):.4f}") + else: + print("TRTLLM-CUDA-Decode not available; skipping TRTLLM correctness.") + + flashinfer_mod = _load_embedded_flashinfer_topk() + if flashinfer_mod is not None: + flashinfer_out = flashinfer_cuda_topk_selector(x, starts, ends, topk) + print(f"FlashInfer-CUDA recall vs torch.topk: {_recall(flashinfer_out, ref):.4f}") + print(f"FlashInfer-CUDA recall vs Triton: {_recall(flashinfer_out, triton_out):.4f}") + else: + print("FlashInfer-CUDA not available; skipping FlashInfer correctness.") + + sglang_fn = _load_embedded_sglang_topk() + if topk == 2048 and sglang_fn is not None: + sglang_out = sglang_cuda_topk_selector(x, 
starts, ends, topk) + print(f"SGLang-CUDA recall vs torch.topk: {_recall(sglang_out, ref):.4f}") + print(f"SGLang-CUDA recall vs Triton: {_recall(sglang_out, triton_out):.4f}") + elif topk != 2048: + print("SGLang-CUDA only supports topk=2048; skipping SGLang correctness.") + else: + print("SGLang-CUDA not available; skipping SGLang correctness.") + + if _HAVE_TILELANG: + tilelang_out = tilelang_topk_selector(x, starts, ends, topk) + print(f"TileLang recall vs torch.topk: {_recall(tilelang_out, ref):.4f}") + print(f"TLE recall vs TileLang: {_recall(tle_out, tilelang_out):.4f}") + else: + print("TileLang not available; skipping TileLang correctness.") + + +def _parse_bench_x_vals(raw): + if not raw: + return None + vals = [] + for chunk in raw.split(","): + text = chunk.strip() + if not text: + continue + parts = text.split("x") + if len(parts) != 3: + raise ValueError(f"invalid --bench_x_vals item: {text!r}, expect BxSxK") + vals.append((int(parts[0]), int(parts[1]), int(parts[2]))) + if not vals: + raise ValueError("--bench_x_vals produced empty set") + return vals + + +def _parse_providers(raw): + if not raw: + return None + providers = [p.strip() for p in raw.split(",") if p.strip()] + if not providers: + raise ValueError("--providers produced empty set") + providers = ["trtllm-decode" if p == "trtllm-cuda" else p for p in providers] + unknown = [p for p in providers if p not in _BENCH_PROVIDERS] + if unknown: + raise ValueError(f"unknown providers: {unknown}, supported={list(_BENCH_PROVIDERS)}") + return providers + + +def run_bench(block_size, warmup, rep, show_plots, providers=None, bench_x_vals=None, quick_bench=False, + max_seq_len=None): + bench = benchmark.benchmarks + + x_vals = list(_BENCH_XVALS) + if quick_bench: + x_vals = [v for v in x_vals if v[1] <= 32768] + if max_seq_len is not None: + x_vals = [v for v in x_vals if v[1] <= max_seq_len] + if bench_x_vals is not None: + x_vals = list(bench_x_vals) + if not x_vals: + raise ValueError("no benchmark 
x_vals left after filtering") + + line_vals = list(_BENCH_PROVIDERS) + line_names = list(_BENCH_NAMES) + styles = list(_BENCH_STYLES) + if providers is not None: + index_map = {p: i for i, p in enumerate(_BENCH_PROVIDERS)} + selected_indices = [index_map[p] for p in providers] + line_vals = [line_vals[i] for i in selected_indices] + line_names = [line_names[i] for i in selected_indices] + styles = [styles[i] for i in selected_indices] + + bench.x_vals = x_vals + bench.line_vals = line_vals + bench.line_names = line_names + bench.styles = styles + print(f"[bench] providers={line_vals}, x_vals={x_vals}, warmup={warmup}, rep={rep}, block_size={block_size}") + + benchmark.run( + print_data=True, + show_plots=show_plots, + block_size=block_size, + warmup=warmup, + rep=rep, + ) + + +# %% +# Main +# ---- + + +def main(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--topk", type=int, default=128, help="top-k") + parser.add_argument("--block_size", type=int, default=1024, help="block size (threads)") + parser.add_argument("--warmup", type=int, default=5, help="warmup iters") + parser.add_argument("--rep", type=int, default=20, help="benchmark iters") + parser.add_argument("--show_plots", action="store_true", help="show plots in benchmark") + parser.add_argument("--skip_correctness", action="store_true", help="skip correctness check") + parser.add_argument("--skip_bench", action="store_true", help="skip benchmark") + parser.add_argument( + "--providers", + type=str, + default="", + help=( + "comma-separated providers for benchmark, e.g. " + "tle-trt,tle-trt-1024threads,tle-cluster,triton,trtllm-prefill,trtllm-prefill-1024threads,flashinfer-cuda"), + ) + parser.add_argument( + "--bench_x_vals", + type=str, + default="", + help="override x-values: comma-separated BxSxK triplets, e.g. 
64x4096x128,64x8192x256", + ) + parser.add_argument( + "--quick_bench", + action="store_true", + help="benchmark only default cases with seq_len<=32768", + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=None, + help="filter benchmark cases by seq_len <= max_seq_len", + ) + args = parser.parse_args(argv) + + if not args.skip_correctness: + run_correctness( + batch=args.batch, + seq_len=args.seq_len, + topk=args.topk, + block_size=args.block_size, + ) + + if not args.skip_bench: + providers = _parse_providers(args.providers) + bench_x_vals = _parse_bench_x_vals(args.bench_x_vals) + run_bench( + block_size=args.block_size, + warmup=args.warmup, + rep=args.rep, + show_plots=args.show_plots, + providers=providers, + bench_x_vals=bench_x_vals, + quick_bench=args.quick_bench, + max_seq_len=args.max_seq_len, + ) + + +if __name__ == "__main__": + main() diff --git a/python/tutorials/tle/deepseek_v32/02-sparse-mla.py b/python/tutorials/tle/deepseek_v32/02-sparse-mla.py new file mode 100644 index 0000000000..fe197b2c33 --- /dev/null +++ b/python/tutorials/tle/deepseek_v32/02-sparse-mla.py @@ -0,0 +1,1006 @@ +# flagtree +""" +Sparse MLA Forward +================== + +This tutorial provides: +- Triton sparse MLA forward kernel (no TLE API in kernel body) +- Triton+TLE sparse MLA forward kernel (shared-memory staging) +- optional TileLang sparse MLA forward kernel (inlined from TileLang example) +- correctness test and benchmark entry +""" + +import argparse +import math + +import torch +import triton +import triton.language as tl +import triton.experimental.tle.language as tle + +try: + import tilelang + from tilelang import language as T + + _HAVE_TILELANG = True +except Exception: # pragma: no cover - optional dependency + tilelang = None + T = None + _HAVE_TILELANG = False + +spar_mla_fwd_configs = [ + triton.Config({"num_stages": 1, "num_warps": 4}), + triton.Config({"num_stages": 1, "num_warps": 8}), + triton.Config({"num_stages": 1, "num_warps": 16}), + 
triton.Config({"num_stages": 1, "num_warps": 32}), + triton.Config({"num_stages": 2, "num_warps": 4}), + triton.Config({"num_stages": 2, "num_warps": 8}), + triton.Config({"num_stages": 2, "num_warps": 16}), + triton.Config({"num_stages": 2, "num_warps": 32}), + triton.Config({"num_stages": 4, "num_warps": 4}), + triton.Config({"num_stages": 4, "num_warps": 8}), + triton.Config({"num_stages": 4, "num_warps": 16}), + triton.Config({"num_stages": 4, "num_warps": 32}), +] +tle_spar_mla_fwd_configs = [ + triton.Config({"num_stages": 2, "num_warps": 4}), + triton.Config({"num_stages": 2, "num_warps": 8}), + triton.Config({"num_stages": 2, "num_warps": 16}), + triton.Config({"num_stages": 2, "num_warps": 32}), +] + + +@triton.autotune( + configs=spar_mla_fwd_configs, + key=["SQ", "K", "H", "D", "is_causal"], +) +@triton.jit +def triton_sparse_mla_fwd( + q, + kv, + indices, + sm_scale: tl.constexpr, + output, + lse, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kvb, + stride_kvg, + stride_kvn, + stride_kvd, + stride_tb, + stride_tg, + stride_tm, + stride_tt, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_lb, + stride_lh, + stride_lm, + B: tl.constexpr, + SQ: tl.constexpr, + SKV: tl.constexpr, + K: tl.constexpr, + D: tl.constexpr, + TD: tl.constexpr, + DP: tl.constexpr, + TDP: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + VG: tl.constexpr, + BK: tl.constexpr, + BH: tl.constexpr, + is_causal: tl.constexpr, +): + i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_g, i_bh = i_gbh // G, i_gbh % G + q_base = q + i_b * stride_qb + i_sq * stride_qm + i_gbh * (BH * stride_qh) + tq_base = q_base + D * stride_qd + kv_base = kv + i_b * stride_kvb + i_g * stride_kvg + tkv_base = kv_base + D * stride_kvd + t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg + o_base = output + i_b * stride_ob + i_sq * stride_om + i_gbh * (BH * stride_oh) + l_base = lse + i_b * stride_lb + i_sq * stride_lm + i_gbh * (BH * 
stride_lh) + + offs_h = tl.arange(0, BH) + offs_d = tl.arange(0, DP) + offs_td = tl.arange(0, TDP) + offs_od = tl.arange(0, DP) + offs_t = tl.arange(0, BK) + mask_h = i_bh * BH + offs_h < G + mask_d = offs_d < D + mask_td = offs_td < TD + mask_od = mask_d + + q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd + q_msk = mask_h[:, None] & mask_d[None, :] + q_blk = tl.load(q_ptr, q_msk, other=0.0) + + tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd + tq_msk = mask_h[:, None] & mask_td[None, :] + tq_blk = tl.load(tq_ptr, tq_msk, other=0.0) + + max_prev = tl.full([BH], float("-inf"), dtype=tl.float32) + sum_exp = tl.full([BH], 1.0, dtype=tl.float32) + acc = tl.zeros([BH, DP], dtype=tl.float32) + + log_scale: tl.constexpr = sm_scale * 1.44269504 + max_col = i_sq if is_causal else SQ - 1 + + NK = tl.cdiv(K, BK) + for ck in tl.range(NK, num_stages=0): + if ck * BK <= max_col: + t_ptr = (BK * ck + offs_t) * stride_tt + t_msk = t_ptr < K + t_ptr += t_base + kv_ids = tl.load(t_ptr, t_msk, other=-1) + mask_ids = (kv_ids <= max_col) & (kv_ids >= 0) + + kv_ptr = kv_base + offs_d[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn + kv_msk = mask_d[:, None] & mask_ids[None, :] + kv_blk = tl.load(kv_ptr, kv_msk, other=0.0) + + tkv_ptr = tkv_base + offs_td[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn + tkv_msk = mask_td[:, None] & mask_ids[None, :] + tkv_blk = tl.load(tkv_ptr, tkv_msk, other=0.0) + + qk = tl.dot(tq_blk, tkv_blk, out_dtype=tl.float32) + qk = tl.dot(q_blk, kv_blk, qk, out_dtype=tl.float32) + qk = tl.where(mask_ids[None, :], qk, float("-inf")) + + new_max = tl.maximum(max_prev, tl.max(qk, axis=1)) + alpha = tl.math.exp2((max_prev - new_max) * log_scale) + exp_qk = tl.math.exp2(qk * log_scale - new_max[:, None] * log_scale) + sum_qk = tl.sum(exp_qk, axis=1) + sum_exp = sum_exp * alpha + sum_qk + acc = acc * alpha[:, None] + exp_qk = exp_qk.to(tl.bfloat16) + acc = tl.dot(exp_qk, tl.trans(kv_blk), acc, 
out_dtype=tl.float32) + max_prev = new_max + + out_vals = acc / sum_exp[:, None] + o_ptr = o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od + o_msk = mask_h[:, None] & mask_od[None, :] + tl.store(o_ptr, out_vals.to(q_blk.dtype), o_msk) + + fin_log = max_prev * log_scale + tl.math.log2(sum_exp.to(tl.float32)) + l_ptr = l_base + offs_h * stride_lh + l_msk = mask_h + tl.store(l_ptr, fin_log.to(q_blk.dtype), l_msk) + + +@triton.autotune( + configs=tle_spar_mla_fwd_configs, + key=["SQ", "K", "H", "D", "is_causal"], +) +@triton.jit +def tle_sparse_mla_fwd( + q, + kv, + indices, + sm_scale: tl.constexpr, + output, + lse, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kvb, + stride_kvg, + stride_kvn, + stride_kvd, + stride_tb, + stride_tg, + stride_tm, + stride_tt, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_lb, + stride_lh, + stride_lm, + B: tl.constexpr, + SQ: tl.constexpr, + SKV: tl.constexpr, + K: tl.constexpr, + D: tl.constexpr, + TD: tl.constexpr, + DP: tl.constexpr, + TDP: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + VG: tl.constexpr, + BK: tl.constexpr, + BH: tl.constexpr, + is_causal: tl.constexpr, +): + # TileLang-style forward path: + # - stage Q and Q_tail once in shared memory; + # - load sparse KV/K_tail blocks directly from global memory per K tile; + # - online softmax on logits; + # - use probabilities directly for the second GEMM. 
+ i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_g, i_bh = i_gbh // G, i_gbh % G + q_base = q + i_b * stride_qb + i_sq * stride_qm + i_gbh * (BH * stride_qh) + tq_base = q_base + D * stride_qd + kv_base = kv + i_b * stride_kvb + i_g * stride_kvg + tkv_base = kv_base + D * stride_kvd + t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg + o_base = output + i_b * stride_ob + i_sq * stride_om + i_gbh * (BH * stride_oh) + l_base = lse + i_b * stride_lb + i_sq * stride_lm + i_gbh * (BH * stride_lh) + + offs_h = tl.arange(0, BH) + offs_d = tl.arange(0, DP) + offs_td = tl.arange(0, TDP) + offs_od = tl.arange(0, DP) + offs_t = tl.arange(0, BK) + mask_h = i_bh * BH + offs_h < G + mask_d = offs_d < D + mask_td = offs_td < TD + mask_od = mask_d + + q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd + q_msk = mask_h[:, None] & mask_d[None, :] + tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd + tq_msk = mask_h[:, None] & mask_td[None, :] + + q_smem = tle.gpu.alloc( + [BH, DP], + dtype=q.dtype.element_ty, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=True, + ) + tq_smem = tle.gpu.alloc( + [BH, TDP], + dtype=q.dtype.element_ty, + layout=None, + scope=tle.gpu.smem, + nv_mma_shared_layout=True, + ) + + q_smem_ptr = tle.gpu.local_ptr(q_smem) + q_blk = tl.load(q_ptr, q_msk, other=0.0) + tl.store(q_smem_ptr, q_blk) + + tq_smem_ptr = tle.gpu.local_ptr(tq_smem) + tq_blk = tl.load(tq_ptr, tq_msk, other=0.0) + tl.store(tq_smem_ptr, tq_blk) + + max_prev = tl.full([BH], float("-inf"), dtype=tl.float32) + sum_exp = tl.full([BH], 1.0, dtype=tl.float32) + acc = tl.zeros([BH, DP], dtype=tl.float32) + + log_scale: tl.constexpr = sm_scale * 1.44269504 + max_col = i_sq if is_causal else SQ - 1 + + NK = tl.cdiv(K, BK) + for ck in tl.range(NK, num_stages=2): + if ck * BK <= max_col: + t_ptr = (BK * ck + offs_t) * stride_tt + t_msk = t_ptr < K + t_ptr += t_base + kv_ids = 
tl.load(t_ptr, t_msk, other=-1) + mask_ids = (kv_ids <= max_col) & (kv_ids >= 0) + kv_ids_safe = tl.where(mask_ids, kv_ids, 0) + + kv_ptr = kv_base + offs_d[:, None] * stride_kvd + kv_ids_safe[None, :] * stride_kvn + kv_col = tl.load(kv_ptr, mask=mask_d[:, None], other=0.0) + + tkv_ptr = tkv_base + offs_td[:, None] * stride_kvd + kv_ids_safe[None, :] * stride_kvn + tkv_col = tl.load(tkv_ptr, mask=mask_td[:, None], other=0.0) + + tq_blk = tl.load(tq_smem_ptr) + qk = tl.dot(tq_blk, tkv_col, out_dtype=tl.float32) + q_blk = tl.load(q_smem_ptr) + qk = tl.dot(q_blk, kv_col, qk, out_dtype=tl.float32) + qk = tl.where(mask_ids[None, :], qk, float("-inf")) + + new_max = tl.maximum(max_prev, tl.max(qk, axis=1)) + alpha = tl.math.exp2((max_prev - new_max) * log_scale) + exp_qk = tl.math.exp2(qk * log_scale - new_max[:, None] * log_scale) + sum_qk = tl.sum(exp_qk, axis=1) + sum_exp = sum_exp * alpha + sum_qk + acc = acc * alpha[:, None] + exp_qk = exp_qk.to(q.dtype.element_ty) + acc = tl.dot(exp_qk, tl.trans(kv_col), acc, out_dtype=tl.float32) + max_prev = new_max + + out_vals = acc / sum_exp[:, None] + o_ptr = o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od + o_msk = mask_h[:, None] & mask_od[None, :] + tl.store(o_ptr, out_vals.to(q.dtype.element_ty), o_msk) + + fin_log = max_prev * log_scale + tl.math.log2(sum_exp.to(tl.float32)) + l_ptr = l_base + offs_h * stride_lh + l_msk = mask_h + tl.store(l_ptr, fin_log.to(q.dtype.element_ty), l_msk) + + +def _sparse_mla_fwd_interface_impl(kernel, q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512, bk=32): + is_causal = True + assert not return_p_sum, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + B, SQ, H, DT = q.shape + _, S, VG, _ = kv.shape + + D = d_v + assert kv.shape[-1] == DT + TD = DT - D + DP = triton.next_power_of_2(D) + TDP = triton.next_power_of_2(TD) + assert kv.shape[0] == B + _, _, _, K = indices.shape + assert 
indices.shape == (B, SQ, VG, K) + G = H // VG + if sm_scale is None: + sm_scale = DT**-0.5 + BH = 32 + NH = triton.cdiv(G, BH) + BK = bk + output = torch.zeros((B, SQ, H, D), device=q.device, dtype=q.dtype) + lse = torch.full((B, SQ, H), float("-inf"), device=q.device, dtype=q.dtype) + grid = (B, SQ, VG * NH) + kernel[grid]( + q, + kv, + indices, + sm_scale, + output, + lse, + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + kv.stride(0), + kv.stride(2), + kv.stride(1), + kv.stride(3), + indices.stride(0), + indices.stride(2), + indices.stride(1), + indices.stride(3), + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + lse.stride(0), + lse.stride(2), + lse.stride(1), + B, + SQ, + S, + K, + D, + TD, + DP, + TDP, + H, + G, + VG, + BK, + BH, + is_causal, + ) + return output, lse + + +def triton_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512): + return _sparse_mla_fwd_interface_impl( + triton_sparse_mla_fwd, + q, + kv, + indices, + sm_scale=sm_scale, + return_p_sum=return_p_sum, + d_v=d_v, + bk=32, + ) + + +def tle_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512): + return _sparse_mla_fwd_interface_impl( + tle_sparse_mla_fwd, + q, + kv, + indices, + sm_scale=sm_scale, + return_p_sum=return_p_sum, + d_v=d_v, + bk=64, + ) + + +if _HAVE_TILELANG: + + @tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + def tilelang_sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_I=64, + num_stages=2, + threads=256, + ): + assert dim == tilelang.math.next_power_of_2(dim), f"dim should be power-of-2, but got {dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"tail_dim should be power-of-2, but got {tail_dim}" + assert is_causal, "non-causal path is not 
implemented" + assert topk % block_I == 0, "topk should be divisible by block_I" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 + else: + sm_scale = sm_scale * 1.44269504 + + batch = T.dynamic("batch") + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != head_kv: + assert kv_group == 1, "automatic head padding only supports kv_group == 1" + + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + replicate_h = head_kv // 64 + else: + replicate_h = 1 + + H_per_block = padded_H if replicate_h == 1 else 64 + + @T.prim_func + def main(Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * replicate_h, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = 
T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) + + b_i, g_i = by, bz + s_i = bx if replicate_h == 1 else (bx // replicate_h) + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if replicate_h == 1 else (bx % replicate_h) * 64) + H1 = H0 + H_per_block + + T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_tail_shared, K_tail_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, 
policy=T.GemmWarpPolicy.FullRow) + + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) + T.copy(sumexp, Lse[b_i, s_i, H0:H1]) + + return main + +else: + tilelang_sparse_mla_fwd = None + + +def tilelang_sparse_mla_fwd_interface( + q, + kv, + indices, + sm_scale=None, + return_p_sum: bool = False, + d_v=512, + block_I=64, + num_stages=2, + threads=256, +): + if not _HAVE_TILELANG or tilelang_sparse_mla_fwd is None: + raise RuntimeError("TileLang is not installed, cannot run TileLang sparse MLA bench") + + is_causal = True + assert not return_p_sum, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, _seq_len_kv, kv_group, _ = kv.shape + + dim = d_v + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert dim == triton.next_power_of_2(dim), f"d_v should be power-of-2 for TileLang path, but got {dim}" + assert tail_dim == triton.next_power_of_2( + tail_dim), f"tail dim should be power-of-2 for TileLang path, but got {tail_dim}" + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + kernel = tilelang_sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group, + sm_scale, + is_causal, + block_I=block_I, + num_stages=num_stages, + threads=threads, + ) + out, lse = kernel(q, kv, indices) + return out, lse + + +def _build_sparse_mla_inputs(B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, topk=2048, dtype=torch.bfloat16, seed=0): + torch.random.manual_seed(seed) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b 
in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, :len(i_i)] = i_i + return q, kv, indices + + +def _resolve_tilelang_block_i(topk: int, block_i: int) -> int: + if block_i <= 0: + raise ValueError(f"tilelang block_I should be > 0, but got {block_i}") + if topk % block_i == 0: + return block_i + fallback = math.gcd(topk, block_i) + if fallback <= 0: + raise ValueError(f"cannot find a valid tilelang block_I for topk={topk}, block_I={block_i}") + return fallback + + +def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True, d_v=512): + q = q.float() + kv = kv.float() + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + + dim = d_v + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, + device="cuda").view(-1, + 1) >= torch.arange(1 - 1, sk * 1, 1, dtype=torch.int32, + device="cuda").view(1, -1) + + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, :1 - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + return o.to(torch.bfloat16) + + +def _sparse_mla_tflops(B, S, H, DQK, DV, topk, ms): + return (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12 + + +def _bench_ms(fn, warmup=200, rep=100): + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return float(ms if not isinstance(ms, tuple) else 
ms[0]) + + +_BENCH_PROVIDERS = (["triton"] + ["tle"] + (["tilelang"] if _HAVE_TILELANG else [])) +_BENCH_NAMES = (["Triton"] + ["TLE"] + (["TileLang"] if _HAVE_TILELANG else [])) +_BENCH_STYLES = ([("red", "-")] + [("orange", "-")] + ([("blue", "-")] if _HAVE_TILELANG else [])) +_BENCH_X_VALS = [ + # (B, S, SKV, H, HKV, DQK, DV, topk) + (1, 512, 1024, 128, 1, 192, 128, 512), + (1, 1024, 2048, 128, 1, 192, 128, 1024), + (1, 2048, 4096, 128, 1, 192, 128, 2048), + (1, 1024, 2048, 128, 1, 160, 128, 1024), +] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["B", "S", "SKV", "H", "HKV", "DQK", "DV", "topk"], + x_vals=_BENCH_X_VALS, + x_log=False, + line_arg="provider", + line_vals=_BENCH_PROVIDERS, + line_names=_BENCH_NAMES, + styles=_BENCH_STYLES, + ylabel="ms", + plot_name="tle-sparse-mla-fwd", + args={}, + )) +def benchmark_sparse_mla_fwd( + B, + S, + SKV, + H, + HKV, + DQK, + DV, + topk, + provider, + warmup, + rep, + tilelang_block_I, + tilelang_num_stages, + tilelang_threads, +): + dtype = torch.bfloat16 + q, kv, indices = _build_sparse_mla_inputs(B=B, S=S, SKV=SKV, H=H, HKV=HKV, DQK=DQK, topk=topk, dtype=dtype, seed=1) + quantiles = [0.5, 0.2, 0.8] + + if provider == "triton": + + def run(): + triton_sparse_mla_fwd_interface(q, kv, indices, d_v=DV) + + elif provider == "tle": + + def run(): + tle_sparse_mla_fwd_interface(q, kv, indices, d_v=DV) + + else: + if not _HAVE_TILELANG: + return float("nan"), float("nan"), float("nan") + resolved_block_i = _resolve_tilelang_block_i(topk, tilelang_block_I) + + def run(): + tilelang_sparse_mla_fwd_interface( + q, + kv, + indices, + d_v=DV, + block_I=resolved_block_i, + num_stages=tilelang_num_stages, + threads=tilelang_threads, + ) + + try: + ms, min_ms, max_ms = triton.testing.do_bench( + run, + quantiles=quantiles, + warmup=warmup, + rep=rep, + ) + except Exception as exc: # pragma: no cover - depends on runtime/resource limits + print(f"[bench:{provider}] failed for " + f"(B={B}, S={S}, SKV={SKV}, 
H={H}, HKV={HKV}, DQK={DQK}, DV={DV}, topk={topk}): {exc}") + return float("nan"), float("nan"), float("nan") + return ms, max_ms, min_ms + + +def run_bench_table(warmup=100, rep=50, show_plots=False, tilelang_block_I=64, tilelang_num_stages=2, + tilelang_threads=256): + benchmark_sparse_mla_fwd.run( + print_data=True, + show_plots=show_plots, + warmup=warmup, + rep=rep, + tilelang_block_I=tilelang_block_I, + tilelang_num_stages=tilelang_num_stages, + tilelang_threads=tilelang_threads, + ) + + +def test_sparse_mla_fwd( + B=1, + S=4096, + SKV=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_tle=True, + check_tilelang=False, + tilelang_block_I=64, + tilelang_num_stages=2, + tilelang_threads=256, +): + q, kv, indices = _build_sparse_mla_inputs(B=B, S=S, SKV=SKV, H=H, HKV=HKV, DQK=DQK, topk=topk, dtype=dtype, seed=0) + ref_bf16_out = ref_sparse_mla_fwd_interface(q, kv, indices, d_v=DV) + + triton_bf16_out, triton_bf16_lse = triton_sparse_mla_fwd_interface(q, kv, indices, d_v=DV) + print("triton (no TLE API) bf16 done \n triton lse tensor: \n", triton_bf16_lse) + print() + + assert torch.allclose( + triton_bf16_out.float(), + ref_bf16_out.float(), + atol=1e-1, + rtol=1e-1, + ), "Triton sparse MLA fwd bf16 does not match reference" + print("Triton sparse MLA fwd bf16 matches reference!") + + if check_tle: + tle_bf16_out, tle_bf16_lse = tle_sparse_mla_fwd_interface(q, kv, indices, d_v=DV) + print("tle bf16 done \n tle lse tensor: \n", tle_bf16_lse) + print() + assert torch.allclose( + tle_bf16_out.float(), + ref_bf16_out.float(), + atol=1e-1, + rtol=1e-1, + ), "TLE sparse MLA fwd bf16 does not match reference" + print("TLE sparse MLA fwd bf16 matches reference!") + + if check_tilelang: + if not _HAVE_TILELANG: + raise RuntimeError("TileLang is not installed, cannot run TileLang correctness check") + resolved_block_i = _resolve_tilelang_block_i(topk, tilelang_block_I) + tilelang_bf16_out, _tilelang_bf16_lse = 
tilelang_sparse_mla_fwd_interface( + q, + kv, + indices, + d_v=DV, + block_I=resolved_block_i, + num_stages=tilelang_num_stages, + threads=tilelang_threads, + ) + assert torch.allclose( + tilelang_bf16_out.float(), + ref_bf16_out.float(), + atol=1e-1, + rtol=1e-1, + ), "TileLang sparse MLA fwd bf16 does not match reference" + print("TileLang sparse MLA fwd bf16 matches reference!") + + +def bench_sparse_mla_fwd( + B=1, + S=4096, + SKV=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + warmup=250, + rep=100, + check_outputs=True, + tilelang_block_I=64, + tilelang_num_stages=2, + tilelang_threads=256, +): + q, kv, indices = _build_sparse_mla_inputs(B=B, S=S, SKV=SKV, H=H, HKV=HKV, DQK=DQK, topk=topk, dtype=dtype, seed=0) + results = [] + + def run_triton(): + return triton_sparse_mla_fwd_interface(q, kv, indices, d_v=DV) + + triton_out, _ = run_triton() + triton_ms = _bench_ms(run_triton, warmup=warmup, rep=rep) + triton_tflops = _sparse_mla_tflops(B, S, H, DQK, DV, topk, triton_ms) + results.append(("triton", triton_ms, triton_tflops)) + + tle_out = None + tilelang_out = None + + def run_tle(): + return tle_sparse_mla_fwd_interface(q, kv, indices, d_v=DV) + + try: + tle_out, _ = run_tle() + tle_ms = _bench_ms(run_tle, warmup=warmup, rep=rep) + tle_tflops = _sparse_mla_tflops(B, S, H, DQK, DV, topk, tle_ms) + results.append(("tle", tle_ms, tle_tflops)) + except Exception as exc: # pragma: no cover - depends on tle/runtime constraints + print(f"TLE bench skipped due to compile/runtime error: {exc}") + + if _HAVE_TILELANG: + resolved_block_i = _resolve_tilelang_block_i(topk, tilelang_block_I) + if resolved_block_i != tilelang_block_I: + print(f"TileLang block_I auto-adjusted from {tilelang_block_I} to {resolved_block_i} " + f"for topk={topk}.") + + def run_tilelang(): + return tilelang_sparse_mla_fwd_interface( + q, + kv, + indices, + d_v=DV, + block_I=resolved_block_i, + num_stages=tilelang_num_stages, + threads=tilelang_threads, + ) + 
+ try: + tilelang_out, _ = run_tilelang() + tilelang_ms = _bench_ms(run_tilelang, warmup=warmup, rep=rep) + tilelang_tflops = _sparse_mla_tflops(B, S, H, DQK, DV, topk, tilelang_ms) + results.append(("tilelang", tilelang_ms, tilelang_tflops)) + except Exception as exc: # pragma: no cover - depends on tilelang/runtime constraints + print(f"TileLang bench skipped due to compile/runtime error: {exc}") + else: + print("TileLang is not installed, skip TileLang bench.") + + print(f"{'provider':<18}{'ms':>10}{'tflops':>12}{'speedup':>12}") + for name, ms, tflops in results: + print(f"{name:<18}{ms:>10.3f}{tflops:>12.2f}{(triton_ms / ms):>12.2f}x") + + if check_outputs: + if tle_out is not None: + assert torch.allclose( + triton_out.float(), + tle_out.float(), + atol=1e-1, + rtol=1e-1, + ), "Triton output does not match TLE output" + print("Triton and TLE outputs match.") + if tilelang_out is not None: + assert torch.allclose( + triton_out.float(), + tilelang_out.float(), + atol=1e-1, + rtol=1e-1, + ), "Triton output does not match TileLang output" + print("Triton and TileLang outputs match.") + + +def _parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--mode", choices=["test", "bench", "bench-single"], default="bench") + parser.add_argument("--B", type=int, default=1) + parser.add_argument("--S", type=int, default=128) + parser.add_argument("--SKV", type=int, default=1024) + parser.add_argument("--H", type=int, default=32) + parser.add_argument("--HKV", type=int, default=1) + parser.add_argument("--DQK", type=int, default=288) + parser.add_argument("--DV", type=int, default=256) + parser.add_argument("--topk", type=int, default=64) + parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16") + parser.add_argument("--warmup", type=int, default=250) + parser.add_argument("--rep", type=int, default=100) + parser.add_argument("--show-plots", action="store_true") + parser.add_argument("--skip-output-check", action="store_true") + 
parser.add_argument("--skip-tle-check", action="store_true") + parser.add_argument("--check-tilelang", action="store_true") + parser.add_argument("--tilelang-block-I", type=int, default=64) + parser.add_argument("--tilelang-num-stages", type=int, default=2) + parser.add_argument("--tilelang-threads", type=int, default=256) + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 + if args.mode == "test": + test_sparse_mla_fwd( + B=args.B, + S=args.S, + SKV=args.SKV, + H=args.H, + HKV=args.HKV, + DQK=args.DQK, + DV=args.DV, + topk=args.topk, + dtype=dtype, + check_tle=not args.skip_tle_check, + check_tilelang=args.check_tilelang, + tilelang_block_I=args.tilelang_block_I, + tilelang_num_stages=args.tilelang_num_stages, + tilelang_threads=args.tilelang_threads, + ) + elif args.mode == "bench-single": + bench_sparse_mla_fwd( + B=args.B, + S=args.S, + SKV=args.SKV, + H=args.H, + HKV=args.HKV, + DQK=args.DQK, + DV=args.DV, + topk=args.topk, + dtype=dtype, + warmup=args.warmup, + rep=args.rep, + check_outputs=not args.skip_output_check, + tilelang_block_I=args.tilelang_block_I, + tilelang_num_stages=args.tilelang_num_stages, + tilelang_threads=args.tilelang_threads, + ) + else: + run_bench_table( + warmup=args.warmup, + rep=args.rep, + show_plots=args.show_plots, + tilelang_block_I=args.tilelang_block_I, + tilelang_num_stages=args.tilelang_num_stages, + tilelang_threads=args.tilelang_threads, + ) diff --git a/skills/tle-developer/SKILL.md b/skills/tle-developer/SKILL.md new file mode 100644 index 0000000000..1deee7ad6d --- /dev/null +++ b/skills/tle-developer/SKILL.md @@ -0,0 +1,72 @@ +--- +name: tle-developer +description: Self-contained orchestration skill for writing high-performance TLE kernels and shipping TLE feature changes with reproducible validation. 
+--- + +# TLE Developer + +## Mission +Use this skill to execute TLE work end-to-end: +intake -> implementation -> validation -> artifacts -> merge decision. + +## Scope +Use for: +1. Writing or optimizing TLE kernels. +2. Implementing TLE API/verifier/lowering/pipeline features. +3. Debugging correctness, performance, and regression issues. + +## Self-Contained Policy +1. Do not rely on documentation outside this skill folder. +2. Put all detailed guidance in `references/`. +3. Keep this file as orchestration-only (no duplicated deep details). + +## Required Input +Start every task with: +```text +Goal: +Non-goal: +Acceptance: +Impact scope (optional): +``` + +## Mandatory Read Order +1. `references/tle-sources.md` +2. `references/workflow-templates.md` + +## Operating Contract +1. Treat `references/tle-sources.md` as the technical source of truth for: + - quickstart, + - current TLE semantics contract, + - kernel patterns, + - feature-development file map, + - debug/perf procedures. +2. Treat `references/workflow-templates.md` as the source of truth for: + - intake, + - validation matrix, + - performance record, + - fix summary, + - lessons entry, + - merge package. + +## Non-Negotiable Guardrails +1. Never assume a specific python environment name. +2. Never assume a fixed build script name. +3. If native Triton files are modified for TLE-specific behavior, use compile-time guards like `#ifdef __TLE__` / `#endif`. +4. Do not use comment marker blocks (`// begin flagtree tle` / `// end flagtree tle`) as a policy mechanism. +5. Do not add `__TLE__` guards inside `third_party/tle` unless that subtree explicitly requires it. + +## Required Outputs Per Task +1. Validation commands and outcomes. +2. Fix Summary (when fixing bugs or regressions). +3. Lessons Entry (for fixes and optimization work). +4. Merge Decision Package (changed layers, risks, follow-ups). + +## Completion Checklist +1. Acceptance criteria mapped to tests. +2. 
Changes validated with reproducible commands.
+3. Artifacts filled from templates.
+4. Residual risks explicitly stated.
+
+## References
+1. `references/tle-sources.md`
+2. `references/workflow-templates.md`
diff --git a/skills/tle-developer/references/tle-sources.md b/skills/tle-developer/references/tle-sources.md
new file mode 100644
index 0000000000..0db3374582
--- /dev/null
+++ b/skills/tle-developer/references/tle-sources.md
@@ -0,0 +1,352 @@
+# TLE Practical Guide (Beginner to Advanced)
+
+This guide is self-contained and executable.
+It targets three jobs:
+1. write a working TLE kernel,
+2. optimize it to high performance,
+3. implement new TLE functionality in API/IR/lowering/pipeline and debug failures.
+
+## 1. First-Run Quickstart
+
+### 1.1 Environment Preflight
+Run from repo root:
+
+```bash
+<python> -V
+<python> -c "import torch, triton; print('torch', torch.__version__, 'cuda', torch.version.cuda); print('triton', triton.__version__)"
+<python> -c "import torch; print('cuda_available', torch.cuda.is_available()); print('device_count', torch.cuda.device_count())"
+```
+
+`<python>` can be any of:
+1. `python` (active shell env),
+2. `/path/to/venv/bin/python`,
+3. `conda run -n <env> python`.
+
+If C++ bindings need rebuild, use your repo's actual build entrypoint.
+Do not assume a specific script exists.
+
+```bash
+# Option A: project-provided build script (if present)
+if [ -x ./build.sh ]; then
+  ./build.sh
+elif [ -x ./scripts/build.sh ]; then
+  ./scripts/build.sh
+fi
+
+# Option B: editable python rebuild path (if your project uses setuptools/pyproject)
+<python> -m pip install -e .
+
+# Option C: CMake/Ninja path (if your project is cmake-based)
+ninja -C <build_dir>
+```
+
+If none of the above match your repo, define `<build_entrypoint>` explicitly in your task context.
+
+### 1.2 Minimal End-to-End Script (Host + Kernel + Check)
+Create and run this script directly.
+
+```python
+import torch
+import triton
+import triton.language as tl
+import triton.experimental.tle.language as tle
+@triton.jit
+def tle_axpy_kernel(x_ptr, y_ptr, out_ptr, n, alpha, BLOCK: tl.constexpr):
+    pid = tl.program_id(0)
+    offs = pid * BLOCK + tl.arange(0, BLOCK)
+    mask = offs < n
+
+    smem = tle.gpu.alloc([BLOCK], dtype=tl.float32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False)
+    ptrs = tle.gpu.local_ptr(smem, (tl.arange(0, BLOCK),))
+
+    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
+    y = tl.load(y_ptr + offs, mask=mask, other=0.0)
+    tl.store(ptrs, x, mask=mask)
+    z = tl.load(ptrs, mask=mask, other=0.0) * alpha + y
+    tl.store(ptrs, z, mask=mask)
+    tl.store(out_ptr + offs, tl.load(ptrs, mask=mask, other=0.0), mask=mask)
+
+
+def main():
+    torch.manual_seed(0)
+    n = 4096
+    block = 256
+    alpha = 1.25
+
+    x = torch.randn(n, device='cuda', dtype=torch.float32)
+    y = torch.randn(n, device='cuda', dtype=torch.float32)
+    out = torch.empty_like(x)
+
+    grid = (triton.cdiv(n, block),)
+    tle_axpy_kernel[grid](x, y, out, n, alpha, BLOCK=block)
+
+    ref = x * alpha + y
+    torch.testing.assert_close(out, ref, atol=1e-6, rtol=1e-6)
+    print('PASS: correctness check')
+
+    # Compiler artifact inspection (critical for debug/perf work)
+    compiled = tle_axpy_kernel.warmup(x, y, out, n, alpha, BLOCK=block, grid=grid)
+    ttgir = compiled.asm.get('ttgir', '')
+    ptx = compiled.asm.get('ptx', '')
+    print('TTGIR length:', len(ttgir))
+    print('PTX length:', len(ptx))
+    print('Has local pointers op:', 'tle.local_pointers' in ttgir)
+
+
+if __name__ == '__main__':
+    main()
+```
+
+Run:
+
+```bash
+<python> /tmp/tle_axpy_quickstart.py
+```
+
+## 2. Current TLE Semantics Baseline (No External File Needed)
+
+### 2.1 `local_ptr` Contract (Current Code)
+API form:
+```python
+ptr = tle.gpu.local_ptr(buffer, indices)
+```
+
+Rules:
+1. `buffer` must be a TLE buffered tensor from `tle.gpu.alloc`.
+2. `indices` must be tuple/list (or Triton tuple) and cannot be empty.
+3. Index count must equal buffer rank.
+4. 
Index dtype must be integer. +5. Either all scalar indices or all tensor indices. +6. Tensor-index mode requires all index tensors to have identical shape. +7. Mixed scalar/tensor index usage is invalid. + +### 2.2 Shared-Memory Pointer Semantics +1. Local pointers are shared-memory pointers in lowering semantics. +2. Load/store lowering must branch by pointer address space (shared vs global). + +### 2.3 Local Pointer Pipeline Invariants +NVIDIA TTGIR pipeline local pointer segment: +1. `add_early_assign_memory_space` +2. `add_select_encodings` +3. `add_insert_local_pointer_barriers` +4. `add_optimize_local_pointer_loads` +5. `add_optimize_local_pointer_stores` + +Do not reorder without proof and tests. + +### 2.4 TLE->LLVM Legality Requirements +TLE conversion path includes: +1. legal `mlir::gpu::GPUDialect`, +2. legal `mlir::UnrealizedConversionCastOp`, +3. registered local pointer conversion patterns. + +## 3. Kernel Authoring Patterns + +### 3.1 1D Local Staging Pattern +Use for elementwise fusion and short reuse windows. + +```python +smem = tle.gpu.alloc([BLOCK], dtype=tl.float32, layout=None, scope=tle.gpu.smem, nv_mma_shared_layout=False) +ptrs = tle.gpu.local_ptr(smem, (tl.arange(0, BLOCK),)) +vals = tl.load(global_ptrs, mask=mask, other=0.0) +tl.store(ptrs, vals, mask=mask) +out = tl.load(ptrs, mask=mask, other=0.0) +``` + +### 3.2 2D Tile Pointer Pattern +Use when loading and slicing tiles. + +```python +rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK)) +cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK)) +ptr = tle.gpu.local_ptr(tile_buf, (rows, cols)) +sub = tl.load(ptr) +``` + +### 3.3 `copy` vs `load/store` +1. Use `tle.gpu.copy` for explicit transfer operations and descriptor/TMA flows. +2. Use `local_ptr + tl.load/store` for custom indexing and compute choreography. 
+ +### 3.4 Distributed Entry Pattern +```python +import triton.experimental.tle.language as tle + +mesh = tle.device_mesh({"block_cluster": [("cluster_x", 2), ("cluster_y", 2)]}) +sid = tle.shard_id(mesh, "cluster_x") +tle.distributed_barrier(mesh) +``` + +## 4. High-Performance Optimization Playbook + +### 4.1 Parameter Priority (Most Impact First) +1. Tile sizes (`BLOCK_M`, `BLOCK_N`, `BLOCK_K` or 1D `BLOCK`). +2. `num_warps`. +3. `num_stages`. +4. Memory path choice (`copy` vs manual load/store). +5. Layout settings (`nv_mma_shared_layout`, swizzled layout choices). + +### 4.2 One-Change Benchmark Loop +For each candidate: +1. Keep shape/seed/grid fixed. +2. Change one parameter only. +3. Run correctness check. +4. Run timed benchmark. +5. Capture TTGIR/PTX evidence. + +Minimal timing skeleton: + +```python +import time + +def bench(fn, rep=50): + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(rep): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) / rep +``` + +### 4.3 Stop Conditions +Stop tuning when one is true: +1. no measurable improvement for 3 consecutive single-parameter trials, +2. regression risk rises (correctness instability, brittle masking), +3. achieved target acceptance performance. + +## 5. Debug Guide (Command-Level) + +### 5.1 Fast Triage Order +1. Reproduce with smallest shape that still fails. +2. Confirm correctness mismatch vs Torch reference. +3. Dump TTGIR/PTX via `warmup(...).asm`. +4. Identify layer: API, verifier, lowering, or runtime behavior. 
+
+### 5.2 Useful Commands
+Targeted tests first:
+
+```bash
+<python> -m pytest python/test/tle/unit/test_tle_gpu_local_ptr.py -vv -s
+<python> -m pytest python/test/tle/integration/test_tle_local_store.py -vv -s
+<python> -m pytest python/test/tle/integration/test_tle_distributed.py -vv -s
+```
+
+Search relevant code quickly:
+
+```bash
+rg -n "def local_ptr\(|analyze_local_pointer_operation" python/triton/experimental/tle/language/gpu
+rg -n "LocalPointersOp::verify|kSharedMemoryAddressSpace" third_party/tle/dialect/lib/IR/Ops.cpp
+rg -n "TleSelectEncodings|TleInsertLocalPointerBarriers" third_party/tle/dialect/lib/Transforms
+rg -n "add_early_assign_memory_space|add_select_encodings|add_insert_local_pointer_barriers" third_party/nvidia/backend/compiler.py
+rg -n "populateLocalPointersOpToLLVMPatterns|UnrealizedConversionCastOp|GPUDialect" third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
+```
+
+### 5.3 Symptom -> Likely Layer -> Action
+1. Verifier error on pointer/index shape:
+   - Layer: API/verifier.
+   - Action: validate index count/type/shape contract in local_ptr call and verifier.
+2. Compiles but wrong output:
+   - Layer: kernel logic or lowering mismatch.
+   - Action: reduce shape, isolate one tile, compare intermediate loads/stores.
+3. Intermittent mismatch after local store/load:
+   - Layer: ordering/barrier behavior.
+   - Action: inspect barrier insertion path and simplify control flow.
+4. No perf gain after local staging:
+   - Layer: layout conversions / pipeline.
+   - Action: count key TTGIR/PTX patterns before/after and verify traffic reduction.
+
+## 6. Implementing New TLE Features (Concrete File Map)
+
+Use this section when changing language semantics or compiler behavior.
+
+### 6.1 Python API Layer
+Typical files:
+1. `python/triton/experimental/tle/__init__.py`
+2. `python/triton/experimental/tle/language/__init__.py`
+3. `python/triton/experimental/tle/language/gpu/core.py`
+4. 
`python/triton/experimental/tle/language/gpu/semantic.py`
+
+What to do:
+1. expose API,
+2. enforce argument contract and error messages,
+3. add semantic checks and tests.
+
+### 6.2 IR and Verifier Layer
+Typical files:
+1. `third_party/tle/dialect/include/IR/TleOps.td`
+2. `third_party/tle/dialect/lib/IR/Ops.cpp`
+
+What to do:
+1. update op defs/types/attrs,
+2. add/adjust verifier invariants,
+3. keep diagnostics specific and actionable.
+
+### 6.3 Lowering/Conversion Layer
+Typical files:
+1. `third_party/tle/dialect/lib/Conversion/TleToLLVM/LocalPointersOpToLLVM.cpp`
+2. related conversion files under `third_party/tle/dialect/lib/Conversion/TleToLLVM/`.
+
+What to do:
+1. map op semantics to LLVM-compatible forms,
+2. preserve address-space correctness,
+3. handle shape/encoding consistency.
+
+### 6.4 Transform and Pass Wiring
+Typical files:
+1. `third_party/tle/dialect/lib/Transforms/TleSelectEncodings.cpp`
+2. `third_party/tle/dialect/lib/Transforms/TleInsertLocalPointerBarriers.cpp`
+3. `third_party/nvidia/backend/compiler.py`
+4. `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp`
+
+What to do:
+1. maintain pass ordering invariants,
+2. ensure conversion target legality is correct,
+3. ensure patterns are registered.
+
+### 6.5 Test Coverage Placement
+1. Unit semantics: `python/test/tle/unit/`
+2. Integration behavior: `python/test/tle/integration/`
+3. Backend-specific cases: `third_party/<backend>/python/test/`
+
+Minimum required test additions for semantic changes:
+1. one positive case,
+2. one negative contract case,
+3. one regression case that would fail without your fix.
+
+## 7. Validation Matrix and Done Criteria
+
+### 7.1 Validation Matrix
+1. targeted unit tests for changed API/verifier path,
+2. targeted integration tests for changed lowering path,
+3. backend-specific tests if pass/codegen changed,
+4. `ninja check-*` if C++ compiler components changed.
+
+### 7.2 Done Criteria
+A change is done only when:
+1. 
behavior contract is explicit, +2. tests cover positive + negative + regression, +3. commands and outcomes are reproducible, +4. Fix Summary and Lessons Entry are completed, +5. residual risk and follow-up are listed. + +## 8. API Surface Snapshot + +### `triton.experimental.tle` +- `device_mesh`, `S`, `P`, `B` +- `sharding`, `ShardingSpec` +- `ShardedTensor`, `make_sharded_tensor` +- `reshard`, `remote`, `shard_id`, `distributed_barrier`, `distributed_dot` +- `language`, optional `raw` + +### `triton.experimental.tle.language` +- `load`, `gpu`, `raw` + +### `tle.gpu` +- `pipeline`, `alloc`, `copy`, `local_ptr`, `memory_space` +- `layout`, `shared_layout`, `swizzled_shared_layout`, `tensor_memory_layout`, `nv_mma_shared_layout` +- `scope`, `smem`, `tmem`, `buffered_tensor`, `buffered_tensor_type` + +### `triton.experimental.tle.language.raw` +- `call` + +### `triton.experimental.tle.raw` +- `dialect`, `Input`, `InOut` diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 920147680b..fcd80f93c2 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -293,8 +293,10 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_optimize_thread_locality(pm) tle.passes.add_early_assign_memory_space(pm) # begin flagtree tle - tle.passes.add_assign_local_pointers_encoding(pm) + tle.passes.add_select_encodings(pm) tle.passes.add_insert_local_pointer_barriers(pm) + tle.passes.add_optimize_local_pointer_loads(pm) + tle.passes.add_optimize_local_pointer_stores(pm) # end flagtree tle passes.ttgpuir.add_accelerate_matmul(pm) passes.ttgpuir.add_remove_layout_conversions(pm) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Allocation.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Allocation.cpp index c62b97c546..9428b2a28a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Allocation.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Allocation.cpp @@ -1,14 
+1,22 @@ +#include +#include #include #include "Allocation.h" #include "TargetInfo.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Utility.h" #include "triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Tools/GenericSwizzling.h" #include "triton/Tools/LayoutUtils.h" +#ifdef __TLE__ +#include "tle/dialect/include/IR/Dialect.h" +#endif using namespace mlir; using namespace mlir::triton; @@ -42,6 +50,27 @@ struct AllocateSharedMemoryNv namespace mlir::triton::nvidia_gpu { +namespace { +#ifdef __TLE__ +static bool isTleCtaOrReduceFastpathCandidate(ReduceOp reduceOp) { + if (reduceOp.getNumOperands() != 1 || reduceOp.getNumResults() != 1) + return false; + if (!reduceOp.getResult()[0].getType().isInteger(1)) + return false; + auto *combine = reduceOp.getSingleCombiner(); + if (!combine || !isa(combine)) + return false; + + ReduceOpHelper helper(reduceOp); + if (!helper.isReduceWithinCTA()) + return false; + if (helper.isWarpSynchronous()) + return false; + return true; +} +#endif +} // namespace + static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy, RankedTensorType dstTy, TargetInfoBase &targetInfo) { @@ -61,6 +90,35 @@ static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy, std::function getNvidiaAllocationAnalysisScratchSizeFn(TargetInfoBase &targetInfo) { auto allocation = [&targetInfo](Operation *op) -> unsigned { +#ifdef __TLE__ + if (auto cumsumOp = dyn_cast(op)) { + auto srcTy = dyn_cast(cumsumOp.getSrc().getType()); + if (!srcTy || srcTy.getRank() != 1) + return 0; + int64_t axisExtent = srcTy.getShape()[0]; + if (ShapedType::isDynamic(axisExtent) || axisExtent <= 0) + return 0; + unsigned elemBytes = + static_cast(std::max(1, getBitwidth(srcTy) / 
8)); + // Scratch layout for cumsum lowering: + // [axisExtent data][numWarps warp-prefix slots][1 total slot] + int64_t numWarps = std::max(1, triton::gpu::lookupNumWarps(op)); + uint64_t totalBytes = (static_cast(axisExtent) + + static_cast(numWarps) + 1ull) * + elemBytes; + if (totalBytes > std::numeric_limits::max()) + return 0; + return static_cast(totalBytes); + } +#endif + if (auto reduceOp = dyn_cast(op)) { +#ifdef __TLE__ + // TLE fastpath lowers CTA-wide i1 OR reduce directly to bar.red.or.pred, + // so no shared scratch allocation is needed for this op. + if (isTleCtaOrReduceFastpathCandidate(reduceOp)) + return 0; +#endif + } if (auto cvtOp = dyn_cast(op)) { auto srcTy = cvtOp.getSrc().getType(); auto dstTy = cvtOp.getType(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 895829ec1d..7b0fdf89b3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -75,10 +75,19 @@ std::optional inferPtrAddrSpace(llvm::ArrayRef ptrElems) { #endif } +#ifdef __TLE__ +bool isSharedFamilyAddressSpace(unsigned addressSpace) { + return addressSpace == 3 || + addressSpace == + static_cast(NVVM::NVVMMemorySpace::SharedCluster); +} +#endif + bool isSharedPointerValue(llvm::ArrayRef ptrElems, unsigned defaultAddrSpace = 1) { #ifdef __TLE__ - return inferPtrAddrSpace(ptrElems).value_or(defaultAddrSpace) == 3; + return isSharedFamilyAddressSpace( + inferPtrAddrSpace(ptrElems).value_or(defaultAddrSpace)); #else return inferPtrAddrSpace(ptrElems).value_or(defaultAddrSpace) == 3; #endif @@ -120,21 +129,6 @@ unsigned getCanonicalIndex(unsigned index, unsigned freeVarMask) { return index & ~freeVarMask; } -#ifdef __TLE__ -bool canReuseCanonicalPointer(llvm::ArrayRef ptrElems, - size_t currentStart, size_t canonicalStart, - size_t width) { - if (currentStart + width > 
ptrElems.size() || - canonicalStart + width > ptrElems.size()) - return false; - for (size_t i = 0; i < width; ++i) { - if (ptrElems[currentStart + i] != ptrElems[canonicalStart + i]) - return false; - } - return true; -} -#endif - std::string getRegisterSizeCode(int size, bool is_float) { switch (size) { case 1: @@ -248,25 +242,23 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, Value llPtr = adaptor.getPtr(); Value llMask = adaptor.getMask(); Value llOther = adaptor.getOther(); -#ifdef __TLE__ - auto remoteCTAInfo = tte::getRemotePointerInfoFromValue(ptr, rewriter); -#endif // Determine the vectorization size Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(op.getType())); unsigned vec = getVectorSize(ptr); #ifdef __TLE__ - // remote metadata carriers can pessimize AxisInfo on the load pointer. - // Reuse the underlying pointer as a hint to recover vector width. - if (remoteCTAInfo.hasRemoteCTAId() && !llMask) { - vec = - std::max(vec, tte::inferTlePointerVectorSize(ptr, axisAnalysisPass)); - if (remoteCTAInfo.vectorHintPtr && remoteCTAInfo.vectorHintPtr != ptr) { - vec = std::max(vec, tte::inferTlePointerVectorSize( - remoteCTAInfo.vectorHintPtr, axisAnalysisPass)); - vec = std::max(vec, getVectorSize(remoteCTAInfo.vectorHintPtr)); - } + auto ptrTensorTy = dyn_cast(ptr.getType()); + auto ptrElemTy = ptrTensorTy + ? dyn_cast(ptrTensorTy.getElementType()) + : PointerType(); + bool isSharedTensorPtr = + ptrElemTy && isSharedFamilyAddressSpace(ptrElemTy.getAddressSpace()); + if (!llMask && isSharedTensorPtr) { + // For TLE local/shared pointer chains, AxisInfo divisibility can be + // conservative on packed contiguous lanes. Recover vector width from + // the pointer layout as a lower-bound hint. 
+ vec = std::max(vec, tte::inferTlePointerLayoutVectorHint(ptr)); } #endif unsigned numElems = getTotalElemsPerThread(ptr.getType()); @@ -287,47 +279,12 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, } // Get the LLVM values for pointers #ifdef __TLE__ - Value llBasePtr = llPtr; - if (remoteCTAInfo.basePtr != ptr) { - llBasePtr = rewriter.getRemappedValue(remoteCTAInfo.basePtr); - if (!llBasePtr) - return op.emitError("failed to remap remote base pointer"); - } - - auto ptrElems = unpackLLElements(loc, llBasePtr, rewriter); + auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); const bool isSharedPtr = isSharedPointerValue(ptrElems); - if (remoteCTAInfo.hasRemoteCTAId() && !isSharedPtr) - return op.emitError("remote shard_id requires shared-memory pointers"); - - auto ensureI32 = [&](Value v) -> Value { - if (!v) - return Value(); - if (v.getType().isInteger(32)) - return v; - if (auto intTy = dyn_cast(v.getType())) { - if (intTy.getWidth() > 32) - return rewriter.create(loc, rewriter.getI32Type(), v); - if (intTy.isUnsigned()) - return rewriter.create(loc, rewriter.getI32Type(), v); - return rewriter.create(loc, rewriter.getI32Type(), v); - } - return Value(); - }; - auto materializeRemoteCTAId = [&](Value v) -> Value { - if (!v) - return Value(); - if (Value scalar = ensureI32(v)) - return scalar; - auto elems = unpackLLElements(loc, v, rewriter); - if (elems.empty()) - return Value(); - return ensureI32(elems.front()); - }; - Value remoteDynamicCTAId = - materializeRemoteCTAId(remoteCTAInfo.dynamicCTAId); - if (remoteCTAInfo.dynamicCTAId && !remoteDynamicCTAId) - return op.emitError("runtime shard_id must lower to scalar integer"); + const bool isClusterSharedPtr = + inferPtrAddrSpace(ptrElems).value_or(1) == + static_cast(NVVM::NVVMMemorySpace::SharedCluster); #else auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); @@ -370,51 +327,6 @@ struct LoadOpConversion : 
public ConvertOpToLLVMPattern, << " valueElemNBits = " << valueElemNBits << " " << op.getType()); SmallVector loadedVals; -#ifdef __TLE__ - if (remoteCTAInfo.hasRemoteCTAId()) { - Value ctaId = remoteCTAInfo.constCTAId - ? b.i32_val(*remoteCTAInfo.constCTAId) - : remoteDynamicCTAId; - auto maybeStripShardOffset = [&](Value ptrVal) -> Value { - if (!remoteCTAInfo.stripShardOffsetFromPtr) - return ptrVal; - Value negCtaId = rewriter.create( - loc, rewriter.getI32Type(), b.i32_val(0), ctaId); - return b.gep(ptrVal.getType(), valueElemTy, ptrVal, negCtaId); - }; - for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { - if (auto canonicalVecStart = getCanonicalIndex(vecStart, regMask); - vecStart != canonicalVecStart && - canReuseCanonicalPointer(ptrElems, vecStart, canonicalVecStart, - vec)) { - for (size_t iVec = 0; iVec < vec; ++iVec) - loadedVals.push_back(loadedVals[canonicalVecStart + iVec]); - continue; - } - Value pred = llMask ? maskElems[vecStart] : b.true_val(); - Value sharedPtr = maybeStripShardOffset(ptrElems[vecStart]); - Type remoteLoadTy = vec == 1 - ? 
valueElemTy - : Type(LLVM::getVectorType(valueElemTy, vec)); - Value loadedVec = targetInfo.loadDShared(rewriter, loc, sharedPtr, - ctaId, remoteLoadTy, pred, op); - auto loadedElems = unpackLLVector(loc, loadedVec, rewriter); - assert(loadedElems.size() == vec); - for (size_t iVec = 0; iVec < vec; ++iVec) { - size_t idx = vecStart + iVec; - Value loaded = loadedElems[iVec]; - if (llMask && other) - loaded = b.select(maskElems[idx], loaded, otherElems[idx]); - loadedVals.push_back(loaded); - } - } - Type llvmResultStructTy = typeConverter->convertType(op.getType()); - Value resultStruct = packLLElements(loc, typeConverter, loadedVals, - rewriter, llvmResultStructTy); - rewriter.replaceOp(op, {resultStruct}); - return success(); - } -#endif for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { if (auto canonicalVecStart = getCanonicalIndex(vecStart, regMask); vecStart != canonicalVecStart) { @@ -488,10 +400,17 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, } } +#ifdef __TLE__ + Value addrVal = ptrElems[vecStart]; + const char *addrConstraint = "l"; + if (isClusterSharedPtr) { + addrVal = rewriter.create(loc, rewriter.getI32Type(), + addrVal); + addrConstraint = "r"; + } auto *addrOpr = - ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); + ptxBuilder.newAddrOperand(addrVal, addrConstraint, in_off); -#ifdef __TLE__ // Create L2 cache policy register only for global-memory accesses. 
Value l2PolicyReg; if (!isSharedPtr) @@ -501,7 +420,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, auto *ld = ptxBuilder.create<>("ld"); ld->o("volatile", op.getIsVolatile()); if (isSharedPtr) { - ld->shared(); + ld->o("shared::cluster", isClusterSharedPtr) + .o("shared", !isClusterSharedPtr); } else { ld->global() .o("ca", op.getCache() == triton::CacheModifier::CA) @@ -623,48 +543,13 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); #ifdef __TLE__ - auto remoteCTAInfo = tte::getRemotePointerInfoFromValue(ptr, rewriter); - Value llBasePtr = llPtr; - if (remoteCTAInfo.basePtr != ptr) { - llBasePtr = rewriter.getRemappedValue(remoteCTAInfo.basePtr); - if (!llBasePtr) - return op.emitError("failed to remap remote base pointer"); - } - auto ptrElems = unpackLLElements(loc, llBasePtr, rewriter); + auto ptrElems = unpackLLElements(loc, llPtr, rewriter); auto valueElems = unpackLLElements(loc, llValue, rewriter); assert(ptrElems.size() == valueElems.size()); const bool isSharedPtr = isSharedPointerValue(ptrElems); - if (remoteCTAInfo.hasRemoteCTAId() && !isSharedPtr) - return op.emitError("remote shard_id requires shared-memory pointers"); - - auto ensureI32 = [&](Value v) -> Value { - if (!v) - return Value(); - if (v.getType().isInteger(32)) - return v; - if (auto intTy = dyn_cast(v.getType())) { - if (intTy.getWidth() > 32) - return rewriter.create(loc, rewriter.getI32Type(), v); - if (intTy.isUnsigned()) - return rewriter.create(loc, rewriter.getI32Type(), v); - return rewriter.create(loc, rewriter.getI32Type(), v); - } - return Value(); - }; - auto materializeRemoteCTAId = [&](Value v) -> Value { - if (!v) - return Value(); - if (Value scalar = ensureI32(v)) - return scalar; - auto elems = unpackLLElements(loc, v, rewriter); - if (elems.empty()) - return Value(); - return ensureI32(elems.front()); - }; - Value remoteDynamicCTAId = - 
materializeRemoteCTAId(remoteCTAInfo.dynamicCTAId); - if (remoteCTAInfo.dynamicCTAId && !remoteDynamicCTAId) - return op.emitError("runtime shard_id must lower to scalar integer"); + const bool isClusterSharedPtr = + inferPtrAddrSpace(ptrElems).value_or(1) == + static_cast(NVVM::NVVMMemorySpace::SharedCluster); #else auto ptrElems = unpackLLElements(loc, llPtr, rewriter); auto valueElems = unpackLLElements(loc, llValue, rewriter); @@ -700,37 +585,6 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); uint32_t regMask = freeVarMasks[str_attr("reg")]; -#ifdef __TLE__ - if (remoteCTAInfo.hasRemoteCTAId()) { - Value ctaId = remoteCTAInfo.constCTAId - ? b.i32_val(*remoteCTAInfo.constCTAId) - : remoteDynamicCTAId; - auto maybeStripShardOffset = [&](Value ptrVal) -> Value { - if (!remoteCTAInfo.stripShardOffsetFromPtr) - return ptrVal; - Value negCtaId = rewriter.create( - loc, rewriter.getI32Type(), b.i32_val(0), ctaId); - return b.gep(ptrVal.getType(), valueElemTy, ptrVal, negCtaId); - }; - // Conservative path: preserve correctness for all remote pointer shapes. - for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { - if (!isCanonicalIndex(vecStart, regMask)) - continue; - for (size_t iVec = 0; iVec < vec; ++iVec) { - size_t idx = vecStart + iVec; - Value pred = threadPred ? 
threadPred : b.true_val(); - if (llMask) - pred = maybeAnd(rewriter, loc, pred, maskElems[idx]); - Value remotePtr = maybeStripShardOffset(ptrElems[idx]); - targetInfo.storeDShared(rewriter, loc, remotePtr, ctaId, - valueElems[idx], pred); - } - } - rewriter.eraseOp(op); - return success(); - } -#endif - const int numVecs = elemsPerThread / vec; for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { if (!isCanonicalIndex(vecStart, regMask)) { @@ -784,10 +638,17 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, pred = maybeAnd(rewriter, loc, pred, mask); } +#ifdef __TLE__ + Value addrVal = ptrElems[vecStart]; + const char *addrConstraint = "l"; + if (isClusterSharedPtr) { + addrVal = rewriter.create(loc, rewriter.getI32Type(), + addrVal); + addrConstraint = "r"; + } auto *asmAddr = - ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); + ptxBuilder.newAddrOperand(addrVal, addrConstraint, in_off); -#ifdef __TLE__ // Create L2 cache policy register only for global-memory accesses. 
Value l2PolicyReg; if (!isSharedPtr) @@ -796,7 +657,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, auto *ptxStoreInstr = ptxBuilder.create<>("st"); if (isSharedPtr) { - ptxStoreInstr->shared(); + ptxStoreInstr->o("shared::cluster", isClusterSharedPtr) + .o("shared", !isClusterSharedPtr); } else { ptxStoreInstr->global() .o("wb", op.getCache() == triton::CacheModifier::WB) @@ -1078,48 +940,11 @@ struct AtomicRMWOpConversion auto valElements = unpackLLElements(loc, llVal, rewriter); auto ptrElements = unpackLLElements(loc, llPtr, rewriter); -#ifdef __TLE__ - auto remoteCTAInfo = tte::getRemotePointerInfoFromValue(ptr, rewriter); - Value llBasePtr = llPtr; - if (remoteCTAInfo.basePtr != ptr) { - llBasePtr = rewriter.getRemappedValue(remoteCTAInfo.basePtr); - if (!llBasePtr) - return op.emitError("failed to remap remote base pointer"); - ptrElements = unpackLLElements(loc, llBasePtr, rewriter); - } -#endif const bool isSharedPtr = isSharedPointerValue(ptrElements); #ifdef __TLE__ - if (remoteCTAInfo.hasRemoteCTAId() && !isSharedPtr) - return op.emitError("remote shard_id requires shared-memory pointers"); - auto ensureI32 = [&](Value v) -> Value { - if (!v) - return Value(); - if (v.getType().isInteger(32)) - return v; - if (auto intTy = dyn_cast(v.getType())) { - if (intTy.getWidth() > 32) - return rewriter.create(loc, rewriter.getI32Type(), v); - if (intTy.isUnsigned()) - return rewriter.create(loc, rewriter.getI32Type(), v); - return rewriter.create(loc, rewriter.getI32Type(), v); - } - return Value(); - }; - auto materializeRemoteCTAId = [&](Value v) -> Value { - if (!v) - return Value(); - if (Value scalar = ensureI32(v)) - return scalar; - auto elems = unpackLLElements(loc, v, rewriter); - if (elems.empty()) - return Value(); - return ensureI32(elems.front()); - }; - Value remoteDynamicCTAId = - materializeRemoteCTAId(remoteCTAInfo.dynamicCTAId); - if (remoteCTAInfo.dynamicCTAId && !remoteDynamicCTAId) - return op.emitError("runtime shard_id must 
lower to scalar integer"); + const bool isClusterSharedPtr = + inferPtrAddrSpace(ptrElements).value_or(1) == + static_cast(NVVM::NVVMMemorySpace::SharedCluster); #endif SmallVector maskElements; if (llMask) @@ -1132,6 +957,18 @@ struct AtomicRMWOpConversion : valueTy; const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); + auto broadcastIfSplat = [&](SmallVector &elems) { + if (elems.size() == 1 && elemsPerThread > 1) + elems.assign(elemsPerThread, elems.front()); + }; + broadcastIfSplat(ptrElements); + broadcastIfSplat(valElements); + if (llMask) + broadcastIfSplat(maskElements); + if (ptrElements.size() != elemsPerThread || + valElements.size() != elemsPerThread || + (llMask && maskElements.size() != elemsPerThread)) + return op.emitError("unexpected element count in AtomicRMW lowering"); // packed: e.g. packed=2 for f16x2 // vec: e.g. .v2, .v4, .v8 version of atom instruction. unsigned vec, vecOrig; @@ -1180,9 +1017,8 @@ struct AtomicRMWOpConversion {triton::MemSyncScope::SYSTEM, triton::nvgpu::MemSyncScope::SYSTEM}}; #ifdef __TLE__ - const bool doPTXLDPromotion = - isPromotableToNVPTXLD(op) && vec == 1 && packed == 1 && - ScopeMap.count(op.getScope()) && !remoteCTAInfo.hasRemoteCTAId(); + const bool doPTXLDPromotion = isPromotableToNVPTXLD(op) && vec == 1 && + packed == 1 && ScopeMap.count(op.getScope()); #else const bool doPTXLDPromotion = isPromotableToNVPTXLD(op) && vec == 1 && packed == 1 && ScopeMap.count(op.getScope()); @@ -1200,24 +1036,7 @@ struct AtomicRMWOpConversion Value rmwPtr = ptrElements[i]; #ifdef __TLE__ - Value ctaId; - if (remoteCTAInfo.hasRemoteCTAId()) { - ctaId = remoteCTAInfo.constCTAId ? 
b.i32_val(*remoteCTAInfo.constCTAId) - : remoteDynamicCTAId; - if (remoteCTAInfo.stripShardOffsetFromPtr) { - Value negCtaId = rewriter.create( - loc, rewriter.getI32Type(), b.i32_val(0), ctaId); - rmwPtr = b.gep(rmwPtr.getType(), valueElemTy, rmwPtr, negCtaId); - } - rmwPtr = - targetInfo.mapSharedToClusterPointer(rewriter, loc, rmwPtr, ctaId); - } - auto rmwPtrTy = cast(rmwPtr.getType()); - const bool isClusterSharedPtr = - rmwPtrTy.getAddressSpace() == - static_cast(NVVM::NVVMMemorySpace::SharedCluster); - const bool useClusterSharedAtomic = - remoteCTAInfo.hasRemoteCTAId() || isClusterSharedPtr; + const bool useClusterSharedAtomic = isClusterSharedPtr; #else const bool isClusterSharedPtr = false; const bool useClusterSharedAtomic = false; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index e8f11e027a..3ac3f70407 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -493,29 +493,16 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, } else { #ifdef __TLE__ std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); - const bool predIsConstTrue = isConstantTruePred(pred); - PTXBuilder::Operand *outOpr = nullptr; - if (vec == 1) { - outOpr = builder.newOperand(elemConstraint, !predIsConstTrue); - } else if (predIsConstTrue) { - outOpr = builder.newListOperand(vec, elemConstraint); - } else { - // Initialize predicated outputs to avoid ptxas mis-optimizing undefined - // destination registers. - outOpr = builder.newListOperand(); - for (unsigned i = 0; i < vec; ++i) - outOpr->listAppend(builder.newOperand(elemConstraint, /*init=*/true)); - } + auto *outOpr = vec == 1 ? 
builder.newOperand(elemConstraint) + : builder.newListOperand(vec, elemConstraint); Value addrOperand = ptr; const char *addrConstraint = "l"; if (useCluster && isa(addrOperand.getType())) { addrOperand = rewriter.create(loc, i32_ty, addrOperand); addrConstraint = "r"; } - auto &ldExec = - ld(outOpr, builder.newAddrOperand(addrOperand, addrConstraint)); - if (!predIsConstTrue) - ldExec.predicate(pred, "b"); + ld(outOpr, builder.newAddrOperand(addrOperand, addrConstraint)) + .predicate(pred, "b"); #else std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) @@ -609,6 +596,34 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, return false; } +#ifdef __TLE__ +std::optional TargetInfo::ctaReduceOrPredicate(RewriterBase &rewriter, + Location loc, + Value pred) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (!pred.getType().isInteger(1)) { + Type predTy = pred.getType(); + pred = b.icmp_ne(pred, b.int_val(predTy.getIntOrFloatBitWidth(), 0)); + } + + PTXBuilder ptxBuilder; + auto *out = ptxBuilder.newOperand("=r", /*init=*/false); + auto *inPred = ptxBuilder.newOperand(pred, "b"); + const char *ptx = R"( +{ + .reg .pred p_out; + bar.red.or.pred p_out, 0, $1; + selp.u32 $0, 1, 0, p_out; +} +)"; + auto &barRedOr = *ptxBuilder.create(ptx); + barRedOr({out, inPred}, /*onlyAttachMLIRArgs=*/true); + Value outI32 = ptxBuilder.launch(rewriter, loc, i32_ty, + /*hasSideEffect=*/true); + return b.icmp_ne(outI32, b.i32_val(0)); +} +#endif + std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { std::string funcName = resultElementTy.isInteger(32) ? 
"__nv_umulhi" : "__nv_umul64hi"; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 4426fb7a61..29607bb48d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -55,6 +55,12 @@ class TargetInfo : public mlir::triton::TargetInfoBase { triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const override; +#ifdef __TLE__ + std::optional ctaReduceOrPredicate(RewriterBase &rewriter, + Location loc, + Value pred) const override; +#endif + std::string getMulhiFuncName(Type resultElementTy) const override; void printf(RewriterBase &rewriter, Value formatStrStart, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 4fa7375bdf..5ee7e3f6cc 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -23,6 +23,7 @@ #include "tle/dialect/include/Analysis/AxisInfoExt.h" #include "tle/dialect/include/Conversion/TleToLLVM/DSLRegionOpToLLVM.h" #include "tle/dialect/include/Conversion/TleToLLVM/DistributedBarrierOpToLLVM.h" +#include "tle/dialect/include/Conversion/TleToLLVM/ExclusiveCumsumOpToLLVM.h" #include "tle/dialect/include/Conversion/TleToLLVM/ExtractOpToLLVM.h" #include "tle/dialect/include/Conversion/TleToLLVM/LocalPointersOpToLLVM.h" #include "tle/dialect/include/Conversion/TleToLLVM/PackOpToLLVM.h" @@ -77,6 +78,9 @@ class TritonLLVMConversionTarget : public ConversionTarget { addIllegalDialect(); addIllegalDialect(); addIllegalDialect(); +#ifdef __TLE__ + addIllegalDialect(); +#endif addLegalOp(); // Warp specialization is lowered later. 
@@ -105,6 +109,7 @@ class TleLLVMConversionTarget : public ConversionTarget { } return hasLegalRegions && typeConverter.isLegal(op); }); + addLegalOp(); // Allow non-TLE ops to remain during this partial conversion. markUnknownOpDynamicallyLegal([](Operation *) -> bool { return true; }); } @@ -175,6 +180,8 @@ struct ConvertTritonGPUToLLVM patterns, benefit); mlir::triton::tle::populateInsertTileOpToLLVMPatterns(typeConverter, patterns, benefit); + mlir::triton::tle::populateExclusiveCumsumOpToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit); if (failed(applyPartialConversion(mod, target, std::move(patterns)))) { return signalPassFailure(); } @@ -197,6 +204,10 @@ struct ConvertTritonGPUToLLVM populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, computeCapability, patterns, axisInfoAnalysis, benefit); +#ifdef __TLE__ + mlir::triton::tle::populateRemotePointersOpToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit + 1); +#endif mlir::triton::populateReduceOpToLLVMPatterns(typeConverter, patterns, targetInfo, benefit); mlir::triton::populateScanOpToLLVMPatterns(typeConverter, patterns, diff --git a/third_party/tle/README.md b/third_party/tle/README.md index be28702b4a..93dd91d3ab 100644 --- a/third_party/tle/README.md +++ b/third_party/tle/README.md @@ -8,7 +8,7 @@ TLE is a language extension for Triton that exposes shared memory and pipeline c - **Shared Memory Management**: `tle.alloc()` - Allocate shared/tensor memory with custom layouts - **Data Movement**: `tle.copy()` - Efficient bidirectional copying between memory spaces -- **Local Operations**: `tle.local_ptr(buffer, indices)` + `tl.load/tl.store` - Access shared/tensor memory through pointer tensors +- **Local Operations**: `tle.local_ptr(buffer, indices=None)` + `tl.load/tl.store` - Access shared/tensor memory through pointer tensors (`indices=None` means full buffer view) - **Pipeline Optimization**: `tle.pipeline()` - Hardware-aware pipeline iteration ## Memory Scopes & 
Layouts @@ -28,9 +28,9 @@ def kernel(a_ptr, b_ptr, c_ptr, n, BLOCK_SIZE: tl.constexpr): # Allocate shared memory a_smem = tle.alloc([BLOCK_SIZE], dtype=tl.float32, scope=tle.smem) b_smem = tle.alloc([BLOCK_SIZE], dtype=tl.float32, scope=tle.smem) - idx = tl.arange(0, BLOCK_SIZE) - a_ptrs = tle.local_ptr(a_smem, (idx,)) - b_ptrs = tle.local_ptr(b_smem, (idx,)) + # Full-view pointers (equivalent to passing explicit full indices) + a_ptrs = tle.local_ptr(a_smem) + b_ptrs = tle.local_ptr(b_smem) # Pipeline iteration for memory hiding for offset in tle.pipeline(0, n, BLOCK_SIZE, num_stages=2): diff --git a/third_party/tle/dialect/include/Conversion/TleToLLVM/ExclusiveCumsumOpToLLVM.h b/third_party/tle/dialect/include/Conversion/TleToLLVM/ExclusiveCumsumOpToLLVM.h new file mode 100644 index 0000000000..bda2044690 --- /dev/null +++ b/third_party/tle/dialect/include/Conversion/TleToLLVM/ExclusiveCumsumOpToLLVM.h @@ -0,0 +1,13 @@ +#ifndef TLE_RAW_CONVERSION_TLETOLLVM_EXCLUSIVECUMSUMOPTOLLVM_H +#define TLE_RAW_CONVERSION_TLETOLLVM_EXCLUSIVECUMSUMOPTOLLVM_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" + +namespace mlir::triton::tle { +void populateExclusiveCumsumOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit); +} + +#endif diff --git a/third_party/tle/dialect/include/Conversion/TleToLLVM/LocalPointersOpToLLVM.h b/third_party/tle/dialect/include/Conversion/TleToLLVM/LocalPointersOpToLLVM.h index e553bcfb8a..caa9444481 100644 --- a/third_party/tle/dialect/include/Conversion/TleToLLVM/LocalPointersOpToLLVM.h +++ b/third_party/tle/dialect/include/Conversion/TleToLLVM/LocalPointersOpToLLVM.h @@ -8,6 +8,10 @@ namespace mlir::triton::tle { void populateLocalPointersOpToLLVMPatterns( mlir::LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit); 
-} + +void populateRemotePointersOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit); +} // namespace mlir::triton::tle #endif diff --git a/third_party/tle/dialect/include/Conversion/TleToLLVM/RemotePointerUtils.h b/third_party/tle/dialect/include/Conversion/TleToLLVM/RemotePointerUtils.h index 0a582bc19e..a7bd43b271 100644 --- a/third_party/tle/dialect/include/Conversion/TleToLLVM/RemotePointerUtils.h +++ b/third_party/tle/dialect/include/Conversion/TleToLLVM/RemotePointerUtils.h @@ -3,37 +3,8 @@ #include "mlir/IR/Value.h" -#include -#include - -namespace mlir { -class ConversionPatternRewriter; -} - -namespace mlir::triton { -class ModuleAxisInfoAnalysis; -} - namespace mlir::triton::tle { -struct RemotePointerInfo { - std::optional constCTAId; - Value dynamicCTAId; - Value basePtr; - Value vectorHintPtr; - bool stripShardOffsetFromPtr = false; - - bool hasRemoteCTAId() const { return constCTAId || dynamicCTAId; } -}; - -bool isTlePointerValue(Value ptr); - -RemotePointerInfo -getRemotePointerInfoFromValue(Value ptr, ConversionPatternRewriter &rewriter); - -unsigned inferTlePointerVectorSize(Value ptr, - ModuleAxisInfoAnalysis &axisAnalysisPass); - unsigned inferTlePointerLayoutVectorHint(Value ptr); } // namespace mlir::triton::tle diff --git a/third_party/tle/dialect/include/IR/TleOps.td b/third_party/tle/dialect/include/IR/TleOps.td index a35295f218..61ed930273 100644 --- a/third_party/tle/dialect/include/IR/TleOps.td +++ b/third_party/tle/dialect/include/IR/TleOps.td @@ -71,20 +71,29 @@ def Tle_LocalPointersOp : Tle_Op<"local_pointers", [Pure]> { let hasVerifier = 1; } -def Tle_DistributedBarrierOp - : Tle_Op<"distributed_barrier", [MemoryEffects<[MemRead, MemWrite]>]> { - let arguments = (ins OptionalAttr:$group_kind, - OptionalAttr:$group_rank, - OptionalAttr:$group_shape, - OptionalAttr:$group_axes, - OptionalAttr:$group_mask); +def Tle_ExclusiveCumsumOp : 
Tle_Op<"exclusive_cumsum", [Pure, SameOperandsAndResultEncoding]> { + let arguments = (ins TT_Tensor:$src, I32Attr:$axis, BoolAttr:$reverse); + let results = (outs TT_Tensor:$exclusive, TT_Type:$total); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($exclusive) `,` type($total)"; + let hasVerifier = 1; +} + +def Tle_DistributedBarrierOp : Tle_Op<"distributed_barrier", + [MemoryEffects<[MemRead, MemWrite]>]> { + let arguments = (ins + OptionalAttr:$group_kind, + OptionalAttr:$group_rank, + OptionalAttr:$group_shape, + OptionalAttr:$group_axes, + OptionalAttr:$group_mask + ); let assemblyFormat = "attr-dict"; let hasVerifier = 1; } def Tle_RemotePointersOp : Tle_Op<"remote_pointers", [Pure]> { - let arguments = (ins TT_Tensor:$src, TT_Int:$shard_id); - let results = (outs TT_Tensor:$result); + let arguments = (ins Tle_LocalPointerResultType:$src, TT_Int:$shard_id); + let results = (outs Tle_LocalPointerResultType:$result); let hasVerifier = 1; } diff --git a/third_party/tle/dialect/include/Transforms/Passes.td b/third_party/tle/dialect/include/Transforms/Passes.td index 5c14ada2b5..9f52f432cc 100644 --- a/third_party/tle/dialect/include/Transforms/Passes.td +++ b/third_party/tle/dialect/include/Transforms/Passes.td @@ -18,14 +18,14 @@ def TritonTleEarlyAssignMemorySpace "mlir::triton::tle::TleDialect"]; } -def TritonTleAssignLocalPointersEncoding - : Pass<"triton-tle-assign-local-pointers-encoding", "mlir::ModuleOp"> { - let summary = "assign shared encodings to local_pointers results"; +def TritonTleSelectEncodings + : Pass<"triton-tle-select-encodings", "mlir::ModuleOp"> { + let summary = "select shared encodings for local pointer and related users"; let description = [{ - This pass ensures `tle.local_pointers` ops produced by the TLE frontend - carry shared-memory pointer element types and a distributed encoding so - downstream conversions do not infer layouts during the LLVM lowering. 
+ This pass performs shared-memory pointer canonicalization and encoding + selection for `tle.local_pointers` and dependent users so downstream + conversion sees stable, cost-aware layouts. }]; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", @@ -50,6 +50,70 @@ def TritonTleInsertLocalPointerBarriers "mlir::triton::TritonDialect"]; } +def TritonTleOptimizeLocalPointerLoads + : Pass<"triton-tle-optimize-local-pointer-loads", "mlir::ModuleOp"> { + let summary = "rewrite full-view tle.local_pointers loads into ttg.local_load"; + + let description = [{ + This pass matches ``tl.load`` operations whose pointer operand is produced + by ``tle.local_pointers`` with full-view indexing (including zero-indices + form) and rewrites them into ``ttg.local_load`` directly from the backing + memdesc. The transform is conservative and only fires when the load has no + mask/other operands. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::tle::TleDialect"]; +} + +def TritonTleOptimizeLocalPointerStores + : Pass<"triton-tle-optimize-local-pointer-stores", "mlir::ModuleOp"> { + let summary = "rewrite tle.local_pointers stores into ttg.local_store"; + + let description = [{ + This pass rewrites ``tl.store`` operations that target pointers produced by + ``tle.local_pointers`` into ``ttg.local_store`` to avoid pointer-based + shared-memory stores in the backend. 
+ }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::tle::TleDialect"]; +} + +def TritonTleOptimizeExclusiveCumsumLayouts + : Pass<"triton-tle-optimize-exclusive-cumsum-layouts", "mlir::ModuleOp"> { + let summary = "fold convert_layout around tle.exclusive_cumsum"; + + let description = [{ + This pass folds patterns of the form: + ttg.convert_layout -> tle.exclusive_cumsum -> ttg.convert_layout + when the source and sink layouts can be made identical, so the cumsum + runs directly on the canonical layout and redundant converts are removed. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::tle::TleDialect"]; +} + +def TritonTleLowerExclusiveCumsum + : Pass<"triton-tle-lower-exclusive-cumsum", "mlir::ModuleOp"> { + let summary = "lower tle.exclusive_cumsum with a specialized Triton IR pattern"; + + let description = [{ + This pass lowers ``tle.exclusive_cumsum`` into Triton IR with a + cumsum-specialized pattern: one ``tt.scan`` for prefix accumulation and an + index-select ``tt.reduce`` to extract the tail value used as total sum. + This avoids a second add-reduction over the original input. 
+ }]; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::triton::TritonDialect", + "mlir::triton::tle::TleDialect"]; +} + def TritonTleLowerAsyncLoad : Pass<"triton-tle-lower-async-load", "mlir::ModuleOp"> { let summary = "Lower TLE async load operations"; diff --git a/third_party/tle/dialect/lib/Analysis/AxisInfoExt.cpp b/third_party/tle/dialect/lib/Analysis/AxisInfoExt.cpp index 07ab310bfc..da61ab9d3e 100644 --- a/third_party/tle/dialect/lib/Analysis/AxisInfoExt.cpp +++ b/third_party/tle/dialect/lib/Analysis/AxisInfoExt.cpp @@ -25,6 +25,15 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { return std::abs(lhs * rhs); } +int64_t saturatingMultiplyDivisor(int64_t lhs, int64_t rhs) { + constexpr int64_t kMax = (int64_t(1) << (sizeof(int64_t) * 8 - 2)); + if (lhs == 0 || rhs == 0) + return 0; + if (lhs > kMax / rhs) + return kMax; + return multiplyDivisor(lhs, rhs); +} + class TleLocalPointersOpAxisInfoVisitor final : public AxisInfoVisitor { public: AxisInfo @@ -34,16 +43,66 @@ class TleLocalPointersOpAxisInfoVisitor final : public AxisInfoVisitor { if (!local || operands.size() < 2) return AxisInfo(); - auto resultTy = dyn_cast(local.getResult().getType()); - if (!resultTy) - return AxisInfo(); - auto memDescTy = dyn_cast(local.getSrc().getType()); if (!memDescTy) return AxisInfo(); - const int rank = resultTy.getRank(); + auto resultTensorTy = + dyn_cast(local.getResult().getType()); + auto resultPtrTy = dyn_cast(local.getResult().getType()); + + // Scalar pointer result: preserve base shared-memory alignment so later + // `tt.splat + tt.addptr` can still infer vectorization width. 
+ if (!resultTensorTy && resultPtrTy) { + const int rank = 1; + int64_t elemBytes = + std::max(1, getPointeeBitWidth(resultPtrTy) / 8); + int64_t baseAlignBytes = elemBytes; + if (auto sharedEnc = dyn_cast( + memDescTy.getEncoding())) + baseAlignBytes = + std::max(baseAlignBytes, sharedEnc.getAlignment()); + + int64_t offsetDivElems = highestPowOf2Divisor(0); + bool hasConstOffset = true; + int64_t constOffsetElems = 0; + const auto memShape = memDescTy.getShape(); + const size_t maxTerms = std::min(memShape.size(), operands.size() - 1); + for (size_t i = 0; i < maxTerms; ++i) { + const AxisInfo &idxInfo = operands[i + 1]->getValue(); + if (idxInfo.getRank() == 0) + continue; + int64_t stride = 1; + for (size_t j = i + 1; j < memShape.size(); ++j) + stride *= memShape[j]; + int64_t strideDiv = highestPowOf2Divisor(stride); + int64_t idxDiv = idxInfo.getDivisibility(0); + if (idxInfo.getContiguity(0) > 1 && strideDiv != 1) + idxDiv = 1; + int64_t termDiv = multiplyDivisor(idxDiv, strideDiv); + offsetDivElems = std::gcd(offsetDivElems, termDiv); + + if (hasConstOffset && idxInfo.getConstantValue().has_value()) + constOffsetElems += idxInfo.getConstantValue().value() * stride; + else + hasConstOffset = false; + } + + int64_t offsetDivBytes = + saturatingMultiplyDivisor(offsetDivElems, elemBytes); + int64_t ptrDivBytes = std::gcd(baseAlignBytes, offsetDivBytes); + std::optional constantValue = std::nullopt; + if (hasConstOffset) + constantValue = constOffsetElems * elemBytes; + return AxisInfo(/*contiguity=*/{1}, + /*divisibility=*/{ptrDivBytes}, + /*constancy=*/{1}, constantValue); + } + + if (!resultTensorTy) + return AxisInfo(); + const int rank = resultTensorTy.getRank(); if (rank == 0) return AxisInfo(); @@ -123,13 +182,14 @@ class TleLocalPointersOpAxisInfoVisitor final : public AxisInfoVisitor { return AxisInfo(); // Pointer divisibility is tracked in bytes for alignment queries. 
- auto ptrTy = dyn_cast(resultTy.getElementType()); + auto ptrTy = dyn_cast(resultTensorTy.getElementType()); int64_t elemBytes = 1; if (ptrTy) elemBytes = std::max(1, getPointeeBitWidth(ptrTy) / 8); AxisInfo::DimVectorT byteDivisibility = offsetInfo.getDivisibility(); for (int d = 0; d < rank; ++d) - byteDivisibility[d] = multiplyDivisor(byteDivisibility[d], elemBytes); + byteDivisibility[d] = + saturatingMultiplyDivisor(byteDivisibility[d], elemBytes); std::optional constantValue = std::nullopt; if (offsetInfo.getConstantValue().has_value()) diff --git a/third_party/tle/dialect/lib/Conversion/TleToLLVM/CMakeLists.txt b/third_party/tle/dialect/lib/Conversion/TleToLLVM/CMakeLists.txt index da81241b55..e2fc03fe1a 100644 --- a/third_party/tle/dialect/lib/Conversion/TleToLLVM/CMakeLists.txt +++ b/third_party/tle/dialect/lib/Conversion/TleToLLVM/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TleToLLVM DistributedBarrierOpToLLVM.cpp DSLRegionOpToLLVM.cpp + ExclusiveCumsumOpToLLVM.cpp ExtractOpToLLVM.cpp LocalPointersOpToLLVM.cpp PackOpToLLVM.cpp diff --git a/third_party/tle/dialect/lib/Conversion/TleToLLVM/ExclusiveCumsumOpToLLVM.cpp b/third_party/tle/dialect/lib/Conversion/TleToLLVM/ExclusiveCumsumOpToLLVM.cpp new file mode 100644 index 0000000000..e1b166eb17 --- /dev/null +++ b/third_party/tle/dialect/lib/Conversion/TleToLLVM/ExclusiveCumsumOpToLLVM.cpp @@ -0,0 +1,461 @@ +#include "tle/dialect/include/Conversion/TleToLLVM/ExclusiveCumsumOpToLLVM.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "third_party/nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "tle/dialect/include/IR/Dialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" +#include +#include + 
+namespace { + +using namespace mlir; +namespace tle = mlir::triton::tle; + +static Value createZeroConstant(Location loc, + ConversionPatternRewriter &rewriter, Type ty) { + if (auto intTy = dyn_cast(ty)) { + return LLVM::ConstantOp::create(rewriter, loc, ty, + rewriter.getIntegerAttr(intTy, 0)); + } + if (auto floatTy = dyn_cast(ty)) { + return LLVM::ConstantOp::create(rewriter, loc, ty, + rewriter.getFloatAttr(floatTy, 0.0)); + } + return Value(); +} + +static Value createAdd(Location loc, ConversionPatternRewriter &rewriter, + Value lhs, Value rhs, Type elemTy) { + if (isa(elemTy)) + return LLVM::FAddOp::create(rewriter, loc, lhs, rhs); + if (isa(elemTy)) + return LLVM::AddOp::create(rewriter, loc, lhs, rhs); + return Value(); +} + +static Value createSub(Location loc, ConversionPatternRewriter &rewriter, + Value lhs, Value rhs, Type elemTy) { + if (isa(elemTy)) + return LLVM::FSubOp::create(rewriter, loc, lhs, rhs); + if (isa(elemTy)) + return LLVM::SubOp::create(rewriter, loc, lhs, rhs); + return Value(); +} + +static Value castToI32(Location loc, ConversionPatternRewriter &rewriter, + Value idx) { + auto i32Ty = rewriter.getI32Type(); + auto idxTy = dyn_cast(idx.getType()); + if (!idxTy) + return Value(); + if (idxTy.getWidth() == 32) + return idx; + if (idxTy.getWidth() > 32) + return LLVM::TruncOp::create(rewriter, loc, i32Ty, idx); + return LLVM::ZExtOp::create(rewriter, loc, i32Ty, idx); +} + +static Value getSharedElemPtr(Location loc, TritonLLVMOpBuilder &b, Value base, + Type elemTy, Value idx) { + return b.gep(base.getType(), elemTy, base, idx); +} + +static std::pair +createIfThenBlocks(ConversionPatternRewriter &rewriter, Location loc, + Value condition) { + Block *prevBlock = rewriter.getInsertionBlock(); + Block *ifBlock = rewriter.splitBlock(prevBlock, rewriter.getInsertionPoint()); + Block *thenBlock = rewriter.splitBlock(ifBlock, ifBlock->begin()); + rewriter.setInsertionPointToEnd(ifBlock); + cf::BranchOp::create(rewriter, loc, thenBlock); + 
rewriter.setInsertionPointToEnd(prevBlock); + cf::CondBranchOp::create(rewriter, loc, condition, ifBlock, thenBlock); + return {ifBlock, thenBlock}; +} + +static Value branchSelect(ConversionPatternRewriter &rewriter, Location loc, + Value condition, Value trueValue, Value falseValue) { + Block *prevBlock = rewriter.getInsertionBlock(); + Block *ifBlock = rewriter.splitBlock(prevBlock, rewriter.getInsertionPoint()); + Block *mergeBlock = rewriter.splitBlock(ifBlock, ifBlock->begin()); + mergeBlock->addArgument(trueValue.getType(), loc); + rewriter.setInsertionPointToEnd(ifBlock); + cf::BranchOp::create(rewriter, loc, mergeBlock, ValueRange{trueValue}); + rewriter.setInsertionPointToEnd(prevBlock); + cf::CondBranchOp::create(rewriter, loc, condition, ifBlock, mergeBlock, + ValueRange{falseValue}); + rewriter.setInsertionPointToStart(mergeBlock); + return mergeBlock->getArgument(0); +} + +static Value createWarpScanStepI32(Location loc, + ConversionPatternRewriter &rewriter, + Value val, int offset) { + auto intTy = dyn_cast(val.getType()); + if (!intTy || intTy.getWidth() != 32) + return Value(); + + mlir::triton::PTXBuilder ptxBuilder; + auto *out = ptxBuilder.newOperand("=r", /*init=*/false); + auto *in = ptxBuilder.newOperand(val, "r"); + auto i32Ty = rewriter.getI32Type(); + auto *offsetOpr = ptxBuilder.newOperand( + LLVM::ConstantOp::create(rewriter, loc, i32Ty, + rewriter.getI32IntegerAttr(offset)), + "r"); + auto *clampOpr = ptxBuilder.newOperand( + LLVM::ConstantOp::create(rewriter, loc, i32Ty, + rewriter.getI32IntegerAttr(0)), + "r"); + auto *maskOpr = ptxBuilder.newOperand( + LLVM::ConstantOp::create(rewriter, loc, i32Ty, + rewriter.getI32IntegerAttr(-1)), + "r"); + std::string ptx = "{\n" + " .reg .s32 r0;\n" + " .reg .pred p;\n" + " shfl.sync.up.b32 r0|p, $1, $2, $3, $4;\n" + " @p add.s32 r0, r0, $1;\n" + " mov.s32 $0, r0;\n" + "}\n"; + auto &shflScan = *ptxBuilder.create(ptx); + shflScan({out, in, offsetOpr, clampOpr, maskOpr}, + 
/*onlyAttachMLIRArgs=*/true); + return ptxBuilder.launch(rewriter, loc, val.getType(), + /*hasSideEffects=*/false); +} + +struct ExclusiveCumsumOpConversion + : public ConvertOpToLLVMPattern { + ExclusiveCumsumOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(tle::ExclusiveCumsumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = dyn_cast(op.getSrc().getType()); + if (!srcTy || srcTy.getRank() != 1) + return rewriter.notifyMatchFailure(op, "expects rank-1 tensor source"); + int64_t axisExtent = srcTy.getShape()[0]; + if (ShapedType::isDynamic(axisExtent) || axisExtent <= 0) + return rewriter.notifyMatchFailure( + op, "expects static, positive axis extent"); + + Location loc = op.getLoc(); + auto *typeConverter = getTypeConverter(); + Type elemTy = srcTy.getElementType(); + Type llvmElemTy = typeConverter->convertType(elemTy); + if (!isa(llvmElemTy)) { + return rewriter.notifyMatchFailure(op, + "unsupported element type for cumsum"); + } + + auto mod = op->getParentOfType(); + const int threadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + const int numWarps = triton::gpu::lookupNumWarps(op); + const int numThreadsPerCTA = threadsPerWarp * numWarps; + + Value zero = createZeroConstant(loc, rewriter, llvmElemTy); + if (!zero) + return rewriter.notifyMatchFailure(op, "failed to materialize zero"); + + auto inputVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto inputIndices = + emitIndices(loc, rewriter, targetInfo, srcTy.getEncoding(), srcTy, + /*withCTAOffset=*/false); + if (inputVals.size() != inputIndices.size()) + return rewriter.notifyMatchFailure(op, "value/index size mismatch"); + + Value baseSharedMem = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto b = 
TritonLLVMOpBuilder(loc, rewriter); + Value threadId = getThreadId(rewriter, loc); + Value trueVal = b.true_val(); + + auto getElemPtr = [&](Value logicalIndex) -> Value { + return getSharedElemPtr(loc, b, baseSharedMem, llvmElemTy, logicalIndex); + }; + + // TRT/CUB-aligned fastpath for topk histogram rounds: + // - rank-1, one element per thread + // - no reverse remapping + // - full CTA participates (axisExtent == threads per CTA) + // This avoids the extra shared-memory store/load round-trip used by the + // generic logical-index path. + if (!op.getReverse() && inputVals.size() == 1 && + axisExtent == static_cast(numThreadsPerCTA) && + threadsPerWarp == 32 && numWarps > 0 && numWarps <= 32) { + Value laneId = b.and_(threadId, b.i32_val(threadsPerWarp - 1)); + Value warpId = b.lshr(threadId, b.i32_val(5)); + Value orderedVal = inputVals.front(); + Value scanVal = orderedVal; + for (int offset = 1; offset < threadsPerWarp; offset <<= 1) { + if (Value scanStep = + createWarpScanStepI32(loc, rewriter, scanVal, offset)) { + scanVal = scanStep; + } else { + Value shfl = targetInfo.shuffleUp(rewriter, loc, scanVal, offset); + Value hasPred = b.icmp_sge(laneId, b.i32_val(offset)); + Value combined = createAdd(loc, rewriter, scanVal, shfl, llvmElemTy); + if (!combined) + return rewriter.notifyMatchFailure(op, + "unsupported add in warp scan"); + scanVal = branchSelect(rewriter, loc, hasPred, combined, scanVal); + } + } + + Value warpTotal = + targetInfo.shuffleIdx(rewriter, loc, scanVal, threadsPerWarp - 1); + Value isWarpTail = b.icmp_eq(laneId, b.i32_val(threadsPerWarp - 1)); + Value warpSlotBase = b.i32_val(static_cast(axisExtent)); + Value warpSlot = b.add(warpSlotBase, warpId); + Value warpSlotPtr = getElemPtr(warpSlot); + auto [tailStoreBlock, tailContBlock] = + createIfThenBlocks(rewriter, loc, isWarpTail); + rewriter.setInsertionPointToStart(tailStoreBlock); + targetInfo.storeShared(rewriter, loc, warpSlotPtr, warpTotal, trueVal); + 
rewriter.setInsertionPointToStart(tailContBlock); + + targetInfo.barrier(loc, rewriter); + + Value totalSlot = b.add(warpSlotBase, b.i32_val(numWarps)); + Value totalPtrFast = getElemPtr(totalSlot); + Value isThread0 = b.icmp_eq(threadId, b.i32_val(0)); + auto [ifBlock, thenBlock] = createIfThenBlocks(rewriter, loc, isThread0); + rewriter.setInsertionPointToStart(ifBlock); + Value running = zero; + for (int w = 0; w < numWarps; ++w) { + Value slot = b.add(warpSlotBase, b.i32_val(w)); + Value slotPtr = getElemPtr(slot); + Value warpSum = + targetInfo.loadShared(rewriter, loc, slotPtr, llvmElemTy, trueVal); + targetInfo.storeShared(rewriter, loc, slotPtr, running, trueVal); + Value next = createAdd(loc, rewriter, running, warpSum, llvmElemTy); + if (!next) { + return rewriter.notifyMatchFailure( + op, "unsupported add in block-prefix scan"); + } + running = next; + } + targetInfo.storeShared(rewriter, loc, totalPtrFast, running, trueVal); + rewriter.setInsertionPointToStart(thenBlock); + targetInfo.barrier(loc, rewriter); + + Value blockPrefixPtr = getElemPtr(warpSlot); + Value blockPrefix = targetInfo.loadShared(rewriter, loc, blockPrefixPtr, + llvmElemTy, trueVal); + Value inclusiveOrdered = + createAdd(loc, rewriter, scanVal, blockPrefix, llvmElemTy); + Value exclusiveOrdered = + createSub(loc, rewriter, inclusiveOrdered, orderedVal, llvmElemTy); + if (!exclusiveOrdered) + return rewriter.notifyMatchFailure( + op, "unsupported sub in ordered exclusive"); + + Value exclusiveRes = + packLLElements(loc, typeConverter, + SmallVector{exclusiveOrdered}, rewriter, srcTy); + Value totalRes = targetInfo.loadShared(rewriter, loc, totalPtrFast, + llvmElemTy, trueVal); + rewriter.replaceOp(op, ValueRange{exclusiveRes, totalRes}); + return success(); + } + + Value nMinus1 = b.i32_val(static_cast(axisExtent - 1)); + auto getLogicalIndex = [&](Value idx) -> Value { + if (!op.getReverse()) + return idx; + return b.sub(nMinus1, idx); + }; + + for (auto [val, idxVec] : 
llvm::zip_equal(inputVals, inputIndices)) { + if (idxVec.size() != 1) + return rewriter.notifyMatchFailure(op, "expects rank-1 indices"); + Value idx = castToI32(loc, rewriter, idxVec[0]); + if (!idx) + return rewriter.notifyMatchFailure(op, "index must be integer"); + Value logicalIndex = getLogicalIndex(idx); + Value ptr = getElemPtr(logicalIndex); + targetInfo.storeShared(rewriter, loc, ptr, val, trueVal); + } + + // Ensure all logical-index stores are visible before scan reads. + targetInfo.barrier(loc, rewriter); + + // Fast path (TRT-style): one logical element per thread order, then + // warp-scan + shared-memory cross-warp prefix scan. + // This is the dominant configuration for topk histogram threshold search. + if (inputVals.size() == 1 && axisExtent <= numThreadsPerCTA && + threadsPerWarp == 32 && numWarps > 0 && numWarps <= 32) { + Value axisExtentVal = b.i32_val(static_cast(axisExtent)); + Value laneId = b.and_(threadId, b.i32_val(threadsPerWarp - 1)); + Value warpId = b.lshr(threadId, b.i32_val(5)); + Value activeOrdered = b.icmp_ult(threadId, axisExtentVal); + + Value orderedPtr = getElemPtr(threadId); + Value orderedVal = targetInfo.loadShared(rewriter, loc, orderedPtr, + llvmElemTy, activeOrdered); + orderedVal = b.select(activeOrdered, orderedVal, zero); + + Value scanVal = orderedVal; + for (int offset = 1; offset < threadsPerWarp; offset <<= 1) { + if (Value scanStep = + createWarpScanStepI32(loc, rewriter, scanVal, offset)) { + scanVal = scanStep; + } else { + Value shfl = targetInfo.shuffleUp(rewriter, loc, scanVal, offset); + Value hasPred = b.icmp_sge(laneId, b.i32_val(offset)); + Value combined = createAdd(loc, rewriter, scanVal, shfl, llvmElemTy); + if (!combined) + return rewriter.notifyMatchFailure(op, + "unsupported add in warp scan"); + scanVal = branchSelect(rewriter, loc, hasPred, combined, scanVal); + } + } + + Value warpTotal = + targetInfo.shuffleIdx(rewriter, loc, scanVal, threadsPerWarp - 1); + Value isWarpTail = 
b.icmp_eq(laneId, b.i32_val(threadsPerWarp - 1)); + Value warpSlotBase = b.i32_val(static_cast(axisExtent)); + Value warpSlot = b.add(warpSlotBase, warpId); + Value warpSlotPtr = getElemPtr(warpSlot); + auto [tailStoreBlock, tailContBlock] = + createIfThenBlocks(rewriter, loc, isWarpTail); + rewriter.setInsertionPointToStart(tailStoreBlock); + targetInfo.storeShared(rewriter, loc, warpSlotPtr, warpTotal, trueVal); + rewriter.setInsertionPointToStart(tailContBlock); + + targetInfo.barrier(loc, rewriter); + + Value totalSlot = b.add(warpSlotBase, b.i32_val(numWarps)); + Value totalPtrFast = getElemPtr(totalSlot); + Value isThread0 = b.icmp_eq(threadId, b.i32_val(0)); + auto [ifBlock, thenBlock] = createIfThenBlocks(rewriter, loc, isThread0); + rewriter.setInsertionPointToStart(ifBlock); + Value running = zero; + for (int w = 0; w < numWarps; ++w) { + Value slot = b.add(warpSlotBase, b.i32_val(w)); + Value slotPtr = getElemPtr(slot); + Value warpSum = + targetInfo.loadShared(rewriter, loc, slotPtr, llvmElemTy, trueVal); + targetInfo.storeShared(rewriter, loc, slotPtr, running, trueVal); + Value next = createAdd(loc, rewriter, running, warpSum, llvmElemTy); + if (!next) { + return rewriter.notifyMatchFailure( + op, "unsupported add in block-prefix scan"); + } + running = next; + } + targetInfo.storeShared(rewriter, loc, totalPtrFast, running, trueVal); + rewriter.setInsertionPointToStart(thenBlock); + + targetInfo.barrier(loc, rewriter); + + Value blockPrefixPtr = getElemPtr(warpSlot); + Value blockPrefix = targetInfo.loadShared(rewriter, loc, blockPrefixPtr, + llvmElemTy, trueVal); + Value inclusiveOrdered = + createAdd(loc, rewriter, scanVal, blockPrefix, llvmElemTy); + Value exclusiveOrdered = + createSub(loc, rewriter, inclusiveOrdered, orderedVal, llvmElemTy); + if (!exclusiveOrdered) + return rewriter.notifyMatchFailure( + op, "unsupported sub in ordered exclusive"); + targetInfo.storeShared(rewriter, loc, orderedPtr, exclusiveOrdered, + activeOrdered); + // The 
gather below reads by logical index (reverse remap may read values + // produced by different threads/warps). Ensure all ordered exclusive + // stores are visible before any thread starts loading gathered values. + targetInfo.barrier(loc, rewriter); + + SmallVector exclusiveVals; + exclusiveVals.reserve(inputVals.size()); + for (auto idxVec : inputIndices) { + Value idx = castToI32(loc, rewriter, idxVec[0]); + Value logicalIndex = getLogicalIndex(idx); + Value ptr = getElemPtr(logicalIndex); + exclusiveVals.push_back( + targetInfo.loadShared(rewriter, loc, ptr, llvmElemTy, trueVal)); + } + Value exclusiveRes = + packLLElements(loc, typeConverter, exclusiveVals, rewriter, srcTy); + Value totalRes = targetInfo.loadShared(rewriter, loc, totalPtrFast, + llvmElemTy, trueVal); + rewriter.replaceOp(op, ValueRange{exclusiveRes, totalRes}); + return success(); + } + + // Fallback path: generic serial scan in shared memory by thread-0. + Value isThread0 = b.icmp_eq(threadId, b.i32_val(0)); + Type i8PtrTy = LLVM::LLVMPointerType::get( + rewriter.getContext(), targetInfo.getSharedAddressSpace()); + unsigned elemBytes = static_cast( + std::max(1, srcTy.getElementTypeBitWidth() / 8)); + int64_t totalByteOffset = + static_cast(axisExtent) * static_cast(elemBytes); + if (totalByteOffset > std::numeric_limits::max()) { + return rewriter.notifyMatchFailure(op, + "shared scratch offset exceeds i32"); + } + Value baseI8 = b.bitcast(baseSharedMem, i8PtrTy); + Value totalOffsetBytes = b.i32_val(static_cast(totalByteOffset)); + Value totalPtrI8 = b.gep(i8PtrTy, i8_ty, baseI8, totalOffsetBytes); + Value totalPtr = b.bitcast(totalPtrI8, baseSharedMem.getType()); + + Value running = zero; + for (int64_t i = 0; i < axisExtent; ++i) { + Value idx = b.i32_val(static_cast(i)); + Value ptr = getElemPtr(idx); + Value inVal = + targetInfo.loadShared(rewriter, loc, ptr, llvmElemTy, isThread0); + targetInfo.storeShared(rewriter, loc, ptr, running, isThread0); + Value next = createAdd(loc, rewriter, 
running, inVal, llvmElemTy); + if (!next) + return rewriter.notifyMatchFailure(op, + "unsupported add for element type"); + running = next; + } + targetInfo.storeShared(rewriter, loc, totalPtr, running, isThread0); + + targetInfo.barrier(loc, rewriter); + + SmallVector exclusiveVals; + exclusiveVals.reserve(inputVals.size()); + for (auto idxVec : inputIndices) { + Value idx = castToI32(loc, rewriter, idxVec[0]); + Value logicalIndex = getLogicalIndex(idx); + Value ptr = getElemPtr(logicalIndex); + exclusiveVals.push_back( + targetInfo.loadShared(rewriter, loc, ptr, llvmElemTy, trueVal)); + } + Value exclusiveRes = + packLLElements(loc, typeConverter, exclusiveVals, rewriter, srcTy); + Value totalRes = + targetInfo.loadShared(rewriter, loc, totalPtr, llvmElemTy, trueVal); + + rewriter.replaceOp(op, ValueRange{exclusiveRes, totalRes}); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void tle::populateExclusiveCumsumOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/tle/dialect/lib/Conversion/TleToLLVM/LocalPointersOpToLLVM.cpp b/third_party/tle/dialect/lib/Conversion/TleToLLVM/LocalPointersOpToLLVM.cpp index 80f8d86bae..0f1106cbc2 100644 --- a/third_party/tle/dialect/lib/Conversion/TleToLLVM/LocalPointersOpToLLVM.cpp +++ b/third_party/tle/dialect/lib/Conversion/TleToLLVM/LocalPointersOpToLLVM.cpp @@ -3,6 +3,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "tle/dialect/include/IR/Dialect.h" @@ -18,24 +19,23 @@ namespace { using namespace mlir; namespace ttg = mlir::triton::gpu; namespace tle = mlir::triton::tle; -constexpr 
llvm::StringLiteral kRemoteShardCarrierAttr = - "tle.remote_shard_id_carrier"; -constexpr llvm::StringLiteral kTTContiguityAttr = "tt.contiguity"; -constexpr llvm::StringLiteral kTTDivisibilityAttr = "tt.divisibility"; -constexpr llvm::StringLiteral kTTConstancyAttr = "tt.constancy"; - -void copyAxisInfoAttrs(Operation *src, Operation *dst) { - if (!src || !dst) - return; - auto tryCopy = [&](StringRef name) { - if (dst->getDiscardableAttr(name)) - return; - if (auto attr = src->getDiscardableAttr(name)) - dst->setDiscardableAttr(name, attr); - }; - tryCopy(kTTContiguityAttr); - tryCopy(kTTDivisibilityAttr); - tryCopy(kTTConstancyAttr); + +Value mapSharedToClusterPointer(ConversionPatternRewriter &rewriter, + Location loc, Value ptr, Value ctaId) { + auto ptrTy = dyn_cast(ptr.getType()); + if (!ptrTy) + return Value(); + const unsigned sharedAddrSpace = + static_cast(NVVM::NVVMMemorySpace::Shared); + const unsigned clusterSharedAddrSpace = + static_cast(NVVM::NVVMMemorySpace::SharedCluster); + if (ptrTy.getAddressSpace() == clusterSharedAddrSpace) + return ptr; + if (ptrTy.getAddressSpace() != sharedAddrSpace) + return Value(); + auto clusterPtrTy = + LLVM::LLVMPointerType::get(rewriter.getContext(), clusterSharedAddrSpace); + return NVVM::MapaOp::create(rewriter, loc, clusterPtrTy, ptr, ctaId); } struct LocalPointersOpConversion @@ -130,29 +130,61 @@ struct LocalPointersOpConversion return reportFailure("shared memory offsets rank mismatch"); auto indexVals = adaptor.getIndices(); - if (indexVals.size() != bufferRank) - return reportFailure("indices must provide buffer-rank values"); + const bool hasExplicitIndices = !indexVals.empty(); + if (hasExplicitIndices) { + if (indexVals.size() != bufferRank) + return reportFailure("indices must provide buffer-rank values"); + } else { + if (!resultTensorTy && bufferRank != 0) + return reportFailure( + "zero-index scalar local_pointers requires rank-0 buffer"); + if (resultTensorTy && resultTensorTy.getShape() != 
memDescTy.getShape()) + return reportFailure( + "zero-index tensor local_pointers requires full buffer shape"); + } SmallVector> indexElems; - indexElems.reserve(indexVals.size()); - for (Value indexVal : indexVals) { - if (resultTensorTy) { - auto elems = unpackLLElements(loc, indexVal, rewriter); - if (elems.size() != outVals.size()) - return reportFailure( - "indices tensors must match local_pointers result shape"); - indexElems.push_back(std::move(elems)); - } else { - Value scalar = ensureI32(indexVal); - if (!scalar) - return reportFailure("scalar indices must lower to i32 values"); - indexElems.push_back(SmallVector{scalar}); + if (hasExplicitIndices) { + indexElems.reserve(indexVals.size()); + for (Value indexVal : indexVals) { + if (resultTensorTy) { + auto elems = unpackLLElements(loc, indexVal, rewriter); + if (elems.size() != outVals.size()) + return reportFailure( + "indices tensors must match local_pointers result shape"); + indexElems.push_back(std::move(elems)); + } else { + Value scalar = ensureI32(indexVal); + if (!scalar) + return reportFailure("scalar indices must lower to i32 values"); + indexElems.push_back(SmallVector{scalar}); + } + } + } else if (resultTensorTy) { + auto fullCoords = + emitIndices(loc, rewriter, targetInfo, resultTensorTy.getEncoding(), + resultTensorTy, + /*withCTAOffset=*/false); + if (fullCoords.size() != outVals.size()) + return reportFailure( + "failed to synthesize full indices for local_pointers"); + indexElems.assign(bufferRank, SmallVector{}); + for (size_t idx = 0; idx < fullCoords.size(); ++idx) { + if (fullCoords[idx].size() != bufferRank) + return reportFailure("synthesized full indices rank mismatch"); + for (size_t dim = 0; dim < bufferRank; ++dim) { + Value coord = ensureI32(fullCoords[idx][dim]); + if (!coord) + return reportFailure( + "synthesized full indices must lower to i32 values"); + indexElems[dim].push_back(coord); + } } } for (size_t idx = 0; idx < outVals.size(); ++idx) { SmallVector idxCoords; - 
idxCoords.reserve(indexVals.size()); + idxCoords.reserve(bufferRank); for (size_t dim = 0; dim < indexElems.size(); ++dim) { Value val = ensureI32(indexElems[dim][idx]); if (!val) @@ -165,7 +197,10 @@ struct LocalPointersOpConversion } Value elemOffset; - if (auto paddedEnc = dyn_cast(sharedEnc)) { + if (bufferRank == 0) { + elemOffset = b.i32_val(0); + } else if (auto paddedEnc = + dyn_cast(sharedEnc)) { auto order = ttg::getOrder(sharedEnc, memDescTy.getShape()); elemOffset = LLVM::linearize(rewriter, loc, idxCoords, bufferShape, order); @@ -177,9 +212,43 @@ struct LocalPointersOpConversion logicalOffsets.push_back({dim, offset}); LinearLayout sharedLayout = ttg::toLinearLayout(memDescTy); sharedLayout = sharedLayout.sublayout({kOffset}, dimNames); - elemOffset = applyLinearLayout(loc, rewriter, sharedLayout.invert(), - logicalOffsets)[0] - .second; + LinearLayout invSharedLayout = sharedLayout.invert(); + + // Be robust to non-canonical input ordering produced by upstream + // transformations: reorder offsets to match the inverted layout's + // expected in-dim order before applying the mapping. 
+ SmallVector> orderedLogicalOffsets; + orderedLogicalOffsets.reserve(invSharedLayout.getNumInDims()); + for (StringAttr inDim : invSharedLayout.getInDimNames()) { + bool found = false; + for (auto &logical : logicalOffsets) { + if (logical.first == inDim) { + orderedLogicalOffsets.push_back(logical); + found = true; + break; + } + } + if (!found) + return reportFailure( + "missing logical offset for inverted shared-layout in-dim"); + } + + auto remappedOffsets = applyLinearLayout(loc, rewriter, invSharedLayout, + orderedLogicalOffsets); + if (remappedOffsets.empty()) + return reportFailure("failed to remap shared-memory linear offsets"); + + bool foundOffset = false; + for (auto &mapped : remappedOffsets) { + if (mapped.first == kOffset) { + elemOffset = mapped.second; + foundOffset = true; + break; + } + } + if (!foundOffset) + return reportFailure( + "remapped shared layout does not contain offset"); } Value byteOffset = elemOffset; @@ -212,29 +281,67 @@ struct LocalPointersOpConversion }; struct RemotePointersOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : public ConvertOpToLLVMPattern { + RemotePointersOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} LogicalResult matchAndRewrite(tle::RemotePointersOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value offset = op.getShardId(); - if (auto srcTy = dyn_cast(op.getSrc().getType())) { - auto shardTy = dyn_cast(offset.getType()); - if (!shardTy || shardTy.getShape() != srcTy.getShape() || - shardTy.getEncoding() != srcTy.getEncoding()) { - auto offsetTy = RankedTensorType::get( - srcTy.getShape(), offset.getType(), srcTy.getEncoding()); - offset = - rewriter.create(op.getLoc(), offsetTy, offset); - } + auto loc = op.getLoc(); + auto *typeConverter = getTypeConverter(); + auto reportFailure = [&](StringRef msg) -> LogicalResult { + llvm::errs() << 
"[RemotePointersOpConversion] " << msg << "\n"; + return rewriter.notifyMatchFailure(op, msg); + }; + + auto srcElems = unpackLLElements(loc, adaptor.getSrc(), rewriter); + if (srcElems.empty()) + return reportFailure("expected non-empty source pointer elements"); + + auto shardElems = unpackLLElements(loc, adaptor.getShardId(), rewriter); + if (shardElems.empty()) + return reportFailure("expected non-empty shard_id elements"); + if (shardElems.size() != 1 && shardElems.size() != srcElems.size()) + return reportFailure( + "shard_id must be scalar or match source pointer element count"); + + auto ensureI32 = [&](Value v) -> Value { + if (!v) + return Value(); + if (v.getType().isInteger(32)) + return v; + auto intTy = dyn_cast(v.getType()); + if (!intTy) + return Value(); + if (intTy.getWidth() > 32) + return rewriter.create(loc, rewriter.getI32Type(), v); + if (intTy.isUnsigned()) + return rewriter.create(loc, rewriter.getI32Type(), v); + return rewriter.create(loc, rewriter.getI32Type(), v); + }; + + SmallVector mappedPtrs; + mappedPtrs.reserve(srcElems.size()); + for (auto [idx, srcPtr] : llvm::enumerate(srcElems)) { + if (!isa(srcPtr.getType())) + return reportFailure("source elements must lower to LLVM pointers"); + Value shardVal = + shardElems.size() == 1 ? 
shardElems.front() : shardElems[idx]; + Value ctaId = ensureI32(shardVal); + if (!ctaId) + return reportFailure("shard_id must lower to i32 scalar elements"); + Value mappedPtr = mapSharedToClusterPointer(rewriter, loc, srcPtr, ctaId); + if (!mappedPtr) + return reportFailure("source pointers must lower to " + "shared/cluster-shared address space"); + mappedPtrs.push_back(mappedPtr); } - auto addPtr = rewriter.create(op.getLoc(), op.getType(), - op.getSrc(), offset); - addPtr->setAttr(kRemoteShardCarrierAttr, rewriter.getUnitAttr()); - copyAxisInfoAttrs(op.getOperation(), addPtr.getOperation()); - copyAxisInfoAttrs(op.getSrc().getDefiningOp(), addPtr.getOperation()); - rewriter.replaceOp(op, addPtr.getResult()); + + Value packed = + packLLElements(loc, typeConverter, mappedPtrs, rewriter, op.getType()); + rewriter.replaceOp(op, packed); return success(); } }; @@ -245,6 +352,11 @@ void tle::populateLocalPointersOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, patterns.getContext(), - benefit); +} + +void tle::populateRemotePointersOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + (void)targetInfo; + patterns.add(typeConverter, benefit); } diff --git a/third_party/tle/dialect/lib/Conversion/TleToLLVM/RemotePointerUtils.cpp b/third_party/tle/dialect/lib/Conversion/TleToLLVM/RemotePointerUtils.cpp index 07ef867132..14b7903057 100644 --- a/third_party/tle/dialect/lib/Conversion/TleToLLVM/RemotePointerUtils.cpp +++ b/third_party/tle/dialect/lib/Conversion/TleToLLVM/RemotePointerUtils.cpp @@ -1,188 +1,14 @@ #include "tle/dialect/include/Conversion/TleToLLVM/RemotePointerUtils.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Transforms/DialectConversion.h" -#include 
"triton/Analysis/AxisInfo.h" -#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/IR/BuiltinTypes.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "llvm/ADT/DenseSet.h" -namespace { - -using namespace mlir; -constexpr llvm::StringLiteral kRemoteShardCarrierAttr = - "tle.remote_shard_id_carrier"; - -Value peelRemoteMetadataCarrier(Value ptr) { - llvm::DenseSet visited; - Value current = ptr; - while (current && visited.insert(current).second) { - Operation *def = current.getDefiningOp(); - if (!def) - break; - if (auto convert = dyn_cast(def)) { - current = convert.getSrc(); - continue; - } - if (auto bcast = dyn_cast(def)) { - current = bcast.getSrc(); - continue; - } - if (auto expand = dyn_cast(def)) { - current = expand.getSrc(); - continue; - } - if (auto reshape = dyn_cast(def)) { - current = reshape.getSrc(); - continue; - } - break; - } - return current; -} - -bool isTlePointerProducer(Operation *op) { - if (!op) - return false; - StringRef name = op->getName().getStringRef(); - return name == "tle.local_pointers" || name == "tle.remote_pointers"; -} - -} // namespace +#include namespace mlir::triton::tle { -bool isTlePointerValue(Value ptr) { - Value carrier = peelRemoteMetadataCarrier(ptr); - return isTlePointerProducer(carrier.getDefiningOp()); -} - -RemotePointerInfo -getRemotePointerInfoFromValue(Value ptr, ConversionPatternRewriter &rewriter) { - RemotePointerInfo info; - info.basePtr = ptr; - info.vectorHintPtr = ptr; - - Value carrier = peelRemoteMetadataCarrier(ptr); - info.vectorHintPtr = carrier; - Operation *defOp = carrier.getDefiningOp(); - if (!defOp) - return info; - - // Dedicated remote op path: recover vectorization hint from the source - // pointer and derive shard id from the shard operand directly. 
- if (defOp->getName().getStringRef() == "tle.remote_pointers") { - if (defOp->getNumOperands() >= 1) - info.vectorHintPtr = defOp->getOperand(0); - if (defOp->getNumOperands() >= 2) { - Value shard = defOp->getOperand(1); - APInt shardConst; - if (matchPattern(shard, m_ConstantInt(&shardConst))) { - info.constCTAId = static_cast(shardConst.getSExtValue()); - } else { - Value remappedShard = rewriter.getRemappedValue(shard); - info.dynamicCTAId = remappedShard ? remappedShard : shard; - } - } - } - - auto ctaAttr = defOp->getAttrOfType("tle.remote_cta_id"); - if (ctaAttr) - info.constCTAId = static_cast(ctaAttr.getInt()); - - if (auto addPtrOp = dyn_cast(defOp); - addPtrOp && addPtrOp->hasAttr(kRemoteShardCarrierAttr)) { - // Keep the remote carrier pointer as the lowering source and strip the - // synthetic shard offset during remote memory op lowering. - info.basePtr = ptr; - info.vectorHintPtr = addPtrOp.getPtr(); - info.stripShardOffsetFromPtr = true; - Value shardOffset = addPtrOp.getOffset(); - APInt shardConst; - if (matchPattern(shardOffset, m_ConstantInt(&shardConst))) { - info.constCTAId = static_cast(shardConst.getSExtValue()); - } else { - if (auto splatOp = shardOffset.getDefiningOp()) { - APInt splatConst; - if (matchPattern(splatOp.getSrc(), m_ConstantInt(&splatConst))) { - info.constCTAId = static_cast(splatConst.getSExtValue()); - } - } - DenseElementsAttr denseConst; - if (!info.constCTAId && - matchPattern(shardOffset, m_Constant(&denseConst)) && - denseConst.isSplat() && denseConst.getElementType().isInteger(32)) { - info.constCTAId = static_cast( - denseConst.getSplatValue().getSExtValue()); - } else if (!info.constCTAId) { - info.dynamicCTAId = rewriter.getRemappedValue(shardOffset); - } - } - } - - return info; -} - -unsigned inferTlePointerVectorSize(Value ptr, - ModuleAxisInfoAnalysis &axisAnalysisPass) { - auto tensorTy = dyn_cast(ptr.getType()); - if (!tensorTy || tensorTy.getRank() == 0) - return 1; - if (!tensorTy.getEncoding()) - 
return 1; - - auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr); - if (!axisInfo || axisInfo->getRank() == 0) - return 1; - - SmallVector order; - Attribute encoding = tensorTy.getEncoding(); - if (auto dist = dyn_cast(encoding)) { - order = triton::gpu::getOrder(dist, tensorTy.getShape()); - } else if (auto shared = - dyn_cast(encoding)) { - order = triton::gpu::getOrder(shared, tensorTy.getShape()); - } else { - order = triton::gpu::getOrder(tensorTy); - } - auto contigPerThread = triton::gpu::getContigPerThread(tensorTy); - if (contigPerThread.empty()) - return 1; - - auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); - if (pointeeBitWidth == 0) - return 1; - unsigned elemBytes = std::max(1, pointeeBitWidth / 8); - unsigned maxByType = std::max(1, 128 / pointeeBitWidth); - unsigned elemsPerThread = std::max( - 1, static_cast( - triton::gpu::getTotalElemsPerThread(ptr.getType()))); - unsigned best = 1; - // Local/remote pointer tensors can carry an encoding order whose leading - // axis does not correspond to the flattened row-major contiguous dimension. - // Probe all axes and keep the best legal vector width. 
- for (unsigned axis = 0; axis < static_cast(axisInfo->getRank()) && - axis < contigPerThread.size(); - ++axis) { - unsigned contiguity = std::max( - 1, - std::min(std::max(1, axisInfo->getContiguity(axis)), - contigPerThread[axis])); - unsigned divisibility = - std::max(1, axisInfo->getDivisibility(axis)); - unsigned alignment = std::min( - contiguity, std::max(1, divisibility / elemBytes)); - unsigned candidate = std::max( - 1, std::min(std::min(maxByType, alignment), - elemsPerThread)); - best = std::max(best, candidate); - } - return best; -} - unsigned inferTlePointerLayoutVectorHint(Value ptr) { auto tensorTy = dyn_cast(ptr.getType()); if (!tensorTy || tensorTy.getRank() == 0) diff --git a/third_party/tle/dialect/lib/IR/Ops.cpp b/third_party/tle/dialect/lib/IR/Ops.cpp index ca609fe787..b7b11119a3 100644 --- a/third_party/tle/dialect/lib/IR/Ops.cpp +++ b/third_party/tle/dialect/lib/IR/Ops.cpp @@ -5,6 +5,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" +#include #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" @@ -14,6 +15,8 @@ namespace mlir::triton::tle { namespace { // Triton shared-memory pointers map to LLVM address space 3 (NVVM shared). constexpr int kSharedMemoryAddressSpace = 3; +// Cluster-shared pointers map to LLVM address space 7 (NVVM shared::cluster). 
+constexpr int kClusterSharedMemoryAddressSpace = 7; } // namespace // ============================================================================ @@ -354,6 +357,21 @@ LogicalResult LocalPointersOp::verify() { return emitOpError() << "expects pointers to live in shared memory"; auto indices = getIndices(); + if (indices.empty()) { + if (resultTensorTy) { + if (resultTensorTy.getShape() != memDescTy.getShape()) + return emitOpError() + << "zero-index local_pointers expects tensor result shape to " + "match buffer shape"; + return success(); + } + if (!memDescTy.getShape().empty()) + return emitOpError() + << "zero-index scalar local_pointers is only valid for rank-0 " + "buffers"; + return success(); + } + if (indices.size() != memDescTy.getShape().size()) return emitOpError() << "expects indices count to match buffer rank"; @@ -400,6 +418,42 @@ LogicalResult LocalPointersOp::verify() { return success(); } +LogicalResult ExclusiveCumsumOp::verify() { + auto srcTy = dyn_cast(getSrc().getType()); + if (!srcTy) + return emitOpError() << "expects src to be a ranked tensor"; + + auto exclusiveTy = dyn_cast(getExclusive().getType()); + if (!exclusiveTy) + return emitOpError() << "expects exclusive result to be a ranked tensor"; + if (exclusiveTy != srcTy) + return emitOpError() << "expects exclusive result type to match src type"; + + // Keep semantics aligned with current DeepSeek topk use: scan over a single + // per-block histogram vector. 
+ if (srcTy.getRank() != 1) + return emitOpError() << "currently only rank-1 tensors are supported"; + int64_t axisExtent = srcTy.getShape()[0]; + if (ShapedType::isDynamic(axisExtent) || axisExtent <= 0) + return emitOpError() << "currently only static, positive axis extent is " + "supported"; + if (axisExtent > static_cast(std::numeric_limits::max())) + return emitOpError() << "axis extent is too large"; + + const int64_t rank = srcTy.getRank(); + int64_t axis = static_cast(getAxis()); + if (axis < 0) + axis += rank; + if (axis != 0) + return emitOpError() << "currently only axis=0 is supported"; + + if (getTotal().getType() != srcTy.getElementType()) + return emitOpError() << "expects total result type to match src element " + "type"; + + return success(); +} + LogicalResult DistributedBarrierOp::verify() { auto *op = getOperation(); auto kindAttr = op->getAttrOfType("group_kind"); @@ -482,20 +536,67 @@ LogicalResult DistributedBarrierOp::verify() { } LogicalResult RemotePointersOp::verify() { - auto srcTy = dyn_cast(getSrc().getType()); - if (!srcTy) - return emitOpError() << "expects src operand to be a ranked tensor"; - auto resultTy = dyn_cast(getResult().getType()); - if (!resultTy) - return emitOpError() << "expects result to be a ranked tensor"; - if (srcTy != resultTy) - return emitOpError() << "expects result type to match src type"; - - auto ptrTy = dyn_cast(srcTy.getElementType()); - if (!ptrTy) - return emitOpError() << "expects src/result element type to be tt.ptr"; - if (ptrTy.getAddressSpace() != kSharedMemoryAddressSpace) - return emitOpError() << "expects pointers to live in shared memory"; + Type srcTy = getSrc().getType(); + Type resultTy = getResult().getType(); + auto getPtrInfo = [&](Type ty, triton::PointerType &ptr, bool &isTensor, + ArrayRef &shape, + Attribute &encoding) -> LogicalResult { + if (auto tensorTy = dyn_cast(ty)) { + ptr = dyn_cast(tensorTy.getElementType()); + if (!ptr) + return emitOpError() + << "expects tensor src/result 
element type to be tt.ptr"; + isTensor = true; + shape = tensorTy.getShape(); + encoding = tensorTy.getEncoding(); + return success(); + } + if (auto ptrTy = dyn_cast(ty)) { + ptr = ptrTy; + isTensor = false; + shape = ArrayRef(); + encoding = Attribute(); + return success(); + } + return emitOpError() << "expects src/result to be tensor> or " + "tt.ptr"; + }; + + triton::PointerType srcPtrTy; + triton::PointerType resultPtrTy; + bool srcIsTensor = false; + bool resultIsTensor = false; + ArrayRef srcShape; + ArrayRef resultShape; + Attribute srcEncoding; + Attribute resultEncoding; + if (failed(getPtrInfo(srcTy, srcPtrTy, srcIsTensor, srcShape, srcEncoding)) || + failed(getPtrInfo(resultTy, resultPtrTy, resultIsTensor, resultShape, + resultEncoding))) + return failure(); + + if (srcIsTensor != resultIsTensor) + return emitOpError() << "expects src/result to both be scalar pointers or " + "both be pointer tensors"; + if (srcIsTensor) { + if (srcShape != resultShape) + return emitOpError() << "expects src/result pointer tensor shapes to " + "match"; + if (srcEncoding && resultEncoding && srcEncoding != resultEncoding) + return emitOpError() << "expects src/result pointer tensor encodings to " + "match"; + } + if (srcPtrTy.getPointeeType() != resultPtrTy.getPointeeType()) + return emitOpError() << "expects src/result pointer pointee types to " + "match"; + + if (srcPtrTy.getAddressSpace() != kSharedMemoryAddressSpace) + return emitOpError() + << "expects src pointers to live in shared memory (addrspace=3)"; + if (resultPtrTy.getAddressSpace() != kClusterSharedMemoryAddressSpace) + return emitOpError() + << "expects result pointers to live in cluster shared memory " + "(addrspace=7)"; if (!getShardId().getType().isInteger(32)) return emitOpError() << "expects shard_id to be i32"; diff --git a/third_party/tle/dialect/lib/Transforms/CMakeLists.txt b/third_party/tle/dialect/lib/Transforms/CMakeLists.txt index aad8f0e9ca..bd81d269af 100644 --- 
a/third_party/tle/dialect/lib/Transforms/CMakeLists.txt +++ b/third_party/tle/dialect/lib/Transforms/CMakeLists.txt @@ -1,8 +1,12 @@ # flagtree tle add_triton_library(TritonTLETransforms TleEarlyAssignMemorySpace.cpp - TleAssignLocalPointersEncoding.cpp + TleSelectEncodings.cpp TleInsertLocalPointerBarriers.cpp + TleOptimizeLocalPointerLoads.cpp + TleOptimizeLocalPointerStores.cpp + TleOptimizeExclusiveCumsumLayouts.cpp + TleLowerExclusiveCumsum.cpp TleLowerAsyncLoad.cpp TleTileToLLVMUtils.cpp ExtractTileToLLVM.cpp diff --git a/third_party/tle/dialect/lib/Transforms/TleAssignLocalPointersEncoding.cpp b/third_party/tle/dialect/lib/Transforms/TleAssignLocalPointersEncoding.cpp deleted file mode 100644 index 49216aecf0..0000000000 --- a/third_party/tle/dialect/lib/Transforms/TleAssignLocalPointersEncoding.cpp +++ /dev/null @@ -1,399 +0,0 @@ -// MIT License - -// Copyright (c) 2025 The FlagOS Contributors - -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: - -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. - -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// flagtree tle - -#include "tle/dialect/include/IR/Dialect.h" -#include "tle/dialect/include/Transforms/Passes.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "llvm/ADT/DenseSet.h" - -namespace mlir::triton::tle { - -#define GEN_PASS_DEF_TRITONTLEASSIGNLOCALPOINTERSENCODING -#include "tle/dialect/include/Transforms/Passes.h.inc" - -namespace { - -// Triton shared-memory pointers use LLVM address space 3 (NVVM shared). -constexpr int kSharedMemoryAddressSpace = 3; -constexpr StringLiteral kBarrierGroupAttr = "tle.barrier_group"; -constexpr StringLiteral kTTContiguityAttr = "tt.contiguity"; -constexpr StringLiteral kTTDivisibilityAttr = "tt.divisibility"; -constexpr StringLiteral kTTConstancyAttr = "tt.constancy"; - -static Operation *peelAxisInfoCarrier(Value value) { - llvm::DenseSet visited; - Value current = value; - while (current && visited.insert(current).second) { - Operation *def = current.getDefiningOp(); - if (!def) - break; - if (auto convert = dyn_cast(def)) { - current = convert.getSrc(); - continue; - } - if (auto bcast = dyn_cast(def)) { - current = bcast.getSrc(); - continue; - } - if (auto expand = dyn_cast(def)) { - current = expand.getSrc(); - continue; - } - if (auto reshape = dyn_cast(def)) { - current = reshape.getSrc(); - continue; - } - return def; - } - return current ? 
current.getDefiningOp() : nullptr; -} - -static void copyAxisInfoAttrs(Operation *src, Operation *dst) { - if (!src || !dst) - return; - auto tryCopy = [&](StringRef name) { - if (dst->getDiscardableAttr(name)) - return; - if (auto attr = src->getDiscardableAttr(name)) - dst->setDiscardableAttr(name, attr); - }; - tryCopy(kTTContiguityAttr); - tryCopy(kTTDivisibilityAttr); - tryCopy(kTTConstancyAttr); -} - -static void -collectConsumerEncodings(Value root, - llvm::SmallVectorImpl &loadEncodings, - llvm::SmallVectorImpl &storeEncodings) { - llvm::SmallVector worklist; - llvm::DenseSet visited; - auto enqueue = [&](Value v) { - if (!v) - return; - if (!visited.insert(v).second) - return; - worklist.push_back(v); - }; - - enqueue(root); - while (!worklist.empty()) { - Value current = worklist.pop_back_val(); - for (OpOperand &use : current.getUses()) { - Operation *owner = use.getOwner(); - if (auto load = dyn_cast(owner)) { - auto loadTy = dyn_cast(load.getResult().getType()); - if (loadTy && loadTy.getEncoding()) - loadEncodings.push_back(loadTy.getEncoding()); - continue; - } - if (auto store = dyn_cast(owner)) { - auto valueTy = dyn_cast(store.getValue().getType()); - if (valueTy && valueTy.getEncoding()) - storeEncodings.push_back(valueTy.getEncoding()); - continue; - } - if (auto convert = dyn_cast(owner)) { - enqueue(convert.getResult()); - continue; - } - if (auto bcast = dyn_cast(owner)) { - enqueue(bcast.getResult()); - continue; - } - if (auto expand = dyn_cast(owner)) { - enqueue(expand.getResult()); - continue; - } - if (auto reshape = dyn_cast(owner)) { - enqueue(reshape.getResult()); - continue; - } - if (auto remote = dyn_cast(owner)) { - enqueue(remote.getResult()); - continue; - } - } - } -} - -class AssignLocalPointersEncodingPass - : public impl::TritonTleAssignLocalPointersEncodingBase< - AssignLocalPointersEncodingPass> { - void runOnOperation() override { - ModuleOp module = getOperation(); - OpBuilder builder(module.getContext()); - 
module.walk([&](triton::tle::LocalPointersOp op) { - // Always tag local pointer ops so barrier insertion can track hazards - // across different pointer views of the same alloc. - tagDependencyGroup(op, builder); - - auto tensorTy = dyn_cast(op.getResult().getType()); - auto scalarPtrTy = - dyn_cast(op.getResult().getType()); - if (!tensorTy && !scalarPtrTy) - return; - auto ptrTy = - tensorTy ? dyn_cast(tensorTy.getElementType()) - : scalarPtrTy; - if (!ptrTy) - return; - bool updated = false; - Type updatedResultTy = op.getResult().getType(); - const auto desiredAddrSpace = kSharedMemoryAddressSpace; - if (ptrTy.getAddressSpace() != desiredAddrSpace) { - ptrTy = - triton::PointerType::get(ptrTy.getPointeeType(), desiredAddrSpace); - updated = true; - } - - if (!tensorTy) { - if (updated) - op.getResult().setType(ptrTy); - return; - } - - auto encoding = tensorTy.getEncoding(); - Attribute userEncoding; - SmallVector loadConsumerEncodings; - SmallVector storeConsumerEncodings; - collectConsumerEncodings(op.getResult(), loadConsumerEncodings, - storeConsumerEncodings); - auto pickConsistentEncoding = - [](ArrayRef encodings) -> Attribute { - Attribute selected; - for (Attribute enc : encodings) { - if (!selected) - selected = enc; - else if (selected != enc) - return Attribute(); - } - return selected; - }; - // Pointer tensor encoding should follow load consumers first; stores can - // be bridged via convert_layout on the value path. 
- userEncoding = pickConsistentEncoding(loadConsumerEncodings); - if (!userEncoding) - userEncoding = pickConsistentEncoding(storeConsumerEncodings); - if (userEncoding && userEncoding != encoding) { - encoding = userEncoding; - updated = true; - } - if (!encoding) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(op); - int numWarps = triton::gpu::maybeLookupNumWarps(op).value_or(1); - int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(builder); - int numCTAs = triton::gpu::lookupNumCTAs(builder); - encoding = triton::gpu::getDefaultBlockedEncoding( - module.getContext(), tensorTy.getShape(), numWarps, threadsPerWarp, - numCTAs); - updated = true; - } - - if (updated) - updatedResultTy = - RankedTensorType::get(tensorTy.getShape(), ptrTy, encoding); - - if (updated) - op.getResult().setType(updatedResultTy); - - if (updated) { - llvm::DenseSet visited; - auto updateUserResultTypes = [&](auto &&self, Value ptrVal) -> void { - if (!ptrVal || !visited.insert(ptrVal).second) - return; - auto ptrTensorTy = cast(ptrVal.getType()); - auto ptrEncoding = ptrTensorTy.getEncoding(); - auto ptrElemTy = - cast(ptrTensorTy.getElementType()) - .getPointeeType(); - auto loadTy = RankedTensorType::get(ptrTensorTy.getShape(), ptrElemTy, - ptrTensorTy.getEncoding()); - auto convertOperandEncoding = [&](Operation *insertBefore, Value v, - Attribute encoding) -> Value { - auto vTy = dyn_cast(v.getType()); - if (!vTy) - return v; - if (vTy.getEncoding() == encoding) - return v; - auto convertedTy = RankedTensorType::get( - vTy.getShape(), vTy.getElementType(), encoding); - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(insertBefore); - auto converted = builder.create( - insertBefore->getLoc(), convertedTy, v); - return converted.getResult(); - }; - for (OpOperand &use : ptrVal.getUses()) { - Operation *owner = use.getOwner(); - if (auto load = dyn_cast(owner)) { - if (Value mask = load.getMask()) { - Value convertedMask = - 
convertOperandEncoding(owner, mask, ptrEncoding); - if (convertedMask != mask) - load.getMaskMutable().assign(convertedMask); - } - if (Value other = load.getOther()) { - Value convertedOther = - convertOperandEncoding(owner, other, ptrEncoding); - if (convertedOther != other) - load.getOtherMutable().assign(convertedOther); - } - auto oldLoadTy = - dyn_cast(load.getResult().getType()); - if (oldLoadTy != loadTy) { - load.getResult().setType(loadTy); - if (oldLoadTy) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointAfter(load); - auto bridge = builder.create( - load.getLoc(), oldLoadTy, load.getResult()); - load.getResult().replaceAllUsesExcept(bridge.getResult(), - bridge.getOperation()); - } - } - continue; - } - if (auto store = dyn_cast(owner)) { - auto valueTy = - dyn_cast(store.getValue().getType()); - if (valueTy) { - Value convertedValue = convertOperandEncoding( - owner, store.getValue(), ptrEncoding); - if (convertedValue != store.getValue()) - store.getValueMutable().assign(convertedValue); - } - if (Value mask = store.getMask()) { - Value convertedMask = - convertOperandEncoding(owner, mask, ptrEncoding); - if (convertedMask != mask) - store.getMaskMutable().assign(convertedMask); - } - continue; - } - if (auto atomic = dyn_cast(owner)) { - Value val = atomic.getVal(); - Value convertedVal = - convertOperandEncoding(owner, val, ptrEncoding); - if (convertedVal != val) - atomic.getValMutable().assign(convertedVal); - if (Value mask = atomic.getMask()) { - Value convertedMask = - convertOperandEncoding(owner, mask, ptrEncoding); - if (convertedMask != mask) - atomic.getMaskMutable().assign(convertedMask); - } - atomic.getResult().setType(loadTy); - continue; - } - if (auto cas = dyn_cast(owner)) { - Value cmp = cas.getCmp(); - Value convertedCmp = - convertOperandEncoding(owner, cmp, ptrEncoding); - if (convertedCmp != cmp) - cas.getCmpMutable().assign(convertedCmp); - Value val = cas.getVal(); - Value convertedVal = - 
convertOperandEncoding(owner, val, ptrEncoding); - if (convertedVal != val) - cas.getValMutable().assign(convertedVal); - cas.getResult().setType(loadTy); - continue; - } - if (auto remote = dyn_cast(owner)) { - if (remote.getResult().getType() != ptrTensorTy) - remote.getResult().setType(ptrTensorTy); - self(self, remote.getResult()); - continue; - } - } - }; - updateUserResultTypes(updateUserResultTypes, op.getResult()); - } - - auto desiredEncoding = - cast(updatedResultTy).getEncoding(); - if (desiredEncoding) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(op); - SmallVector newOperands; - newOperands.reserve(op->getNumOperands()); - newOperands.push_back(op.getSrc()); - bool updatedOperands = false; - for (Value operand : op.getIndices()) { - auto operandTy = dyn_cast(operand.getType()); - if (!operandTy) { - newOperands.push_back(operand); - continue; - } - if (operandTy.getEncoding() == desiredEncoding) { - newOperands.push_back(operand); - continue; - } - auto convertedTy = RankedTensorType::get(operandTy.getShape(), - operandTy.getElementType(), - desiredEncoding); - auto converted = builder.create( - op.getLoc(), convertedTy, operand); - newOperands.push_back(converted); - updatedOperands = true; - } - if (updatedOperands) - op->setOperands(newOperands); - } - }); - - // remote_pointers should preserve source pointer axis properties so later - // passes can reason about remote operands without dialect-specific - // visitors. 
- module.walk([&](triton::tle::RemotePointersOp op) { - Operation *srcDef = peelAxisInfoCarrier(op.getSrc()); - copyAxisInfoAttrs(srcDef, op.getOperation()); - }); - } - - void tagDependencyGroup(triton::tle::LocalPointersOp op, OpBuilder &builder) { - auto alloc = op.getSrc().getDefiningOp(); - if (!alloc) - return; - auto groupAttr = alloc->getAttrOfType(kBarrierGroupAttr); - if (!groupAttr) { - groupAttr = builder.getI64IntegerAttr(nextBarrierGroupId++); - alloc->setAttr(kBarrierGroupAttr, groupAttr); - } - op->setAttr(kBarrierGroupAttr, groupAttr); - } - - int64_t nextBarrierGroupId = 0; -}; - -} // namespace -} // namespace mlir::triton::tle diff --git a/third_party/tle/dialect/lib/Transforms/TleInsertLocalPointerBarriers.cpp b/third_party/tle/dialect/lib/Transforms/TleInsertLocalPointerBarriers.cpp index bf662a3169..aed48041c3 100644 --- a/third_party/tle/dialect/lib/Transforms/TleInsertLocalPointerBarriers.cpp +++ b/third_party/tle/dialect/lib/Transforms/TleInsertLocalPointerBarriers.cpp @@ -24,7 +24,9 @@ #include "tle/dialect/include/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -88,8 +90,29 @@ class InsertLocalPointerBarriersPass Operation *owner = use.getOwner(); if (auto convert = dyn_cast(owner)) { tryTrackDerived(owner, convert.getSrc(), convert.getResult()); + } else if (auto splat = dyn_cast(owner)) { + tryTrackDerived(owner, splat.getSrc(), splat.getResult()); } else if (auto bcast = dyn_cast(owner)) { tryTrackDerived(owner, bcast.getSrc(), bcast.getResult()); + } else if (auto expand = dyn_cast(owner)) { + tryTrackDerived(owner, expand.getSrc(), expand.getResult()); + } else if (auto reshape = dyn_cast(owner)) { + tryTrackDerived(owner, reshape.getSrc(), reshape.getResult()); + } else if (auto addptr = dyn_cast(owner)) { + // Only propagate along the 
pointer operand. + if (use.getOperandNumber() == 0) + tryTrackDerived(owner, addptr.getPtr(), addptr.getResult()); + } else if (auto call = dyn_cast(owner)) { + auto it = pointerGroups.find(current); + if (it == pointerGroups.end()) + continue; + unsigned operandIdx = use.getOperandNumber(); + auto callee = module.lookupSymbol(call.getCallee()); + if (!callee || operandIdx >= callee.getNumArguments()) + continue; + Value calleeArg = callee.getArgument(operandIdx); + if (pointerGroups.try_emplace(calleeArg, it->second).second) + worklist.push_back(calleeArg); } } } @@ -108,11 +131,17 @@ class InsertLocalPointerBarriersPass void processBlock(Block &block) { llvm::DenseMap dirtyGroups; for (Operation &op : block) { - if (!dirtyGroups.empty() && op.getNumRegions() > 0 && - opHasLoadNeedingBarrier(op, dirtyGroups)) { - OpBuilder builder(&op); - builder.create(op.getLoc()); - dirtyGroups.clear(); + if (!dirtyGroups.empty() && op.getNumRegions() > 0) { + bool handledByIfSpecialization = false; + if (auto ifOp = dyn_cast(&op)) + handledByIfSpecialization = tryHandleUniformIf(ifOp, dirtyGroups); + + if (!handledByIfSpecialization && + opHasLoadNeedingBarrier(op, dirtyGroups)) { + OpBuilder builder(&op); + builder.create(op.getLoc()); + dirtyGroups.clear(); + } } if (auto store = dyn_cast(&op)) { @@ -124,42 +153,107 @@ class InsertLocalPointerBarriersPass continue; OpBuilder builder(load); builder.create(load.getLoc()); - dirtyGroups[*group] = false; + // A CTA barrier synchronizes all shared-memory groups, not only the + // group used by this load. Clearing all dirty groups avoids emitting + // redundant back-to-back barriers for consecutive loads from different + // tracked groups. + dirtyGroups.clear(); } else if (isa(&op)) { dirtyGroups.clear(); } for (Region &nested : op.getRegions()) processRegion(nested); + + // Propagate write hazards from nested regions to the parent block. 
+ // Without this, a store inside scf.if/scf.for may not mark parent state + // dirty, so a subsequent outer load can miss the required barrier. + markGroupsWrittenByNestedRegions(op, dirtyGroups); + } + } + + bool tryHandleUniformIf(scf::IfOp ifOp, + const llvm::DenseMap &dirtyGroups) { + if (!isUniformCondition(ifOp.getCondition())) + return false; + + for (Region ®ion : ifOp->getRegions()) { + if (!regionHasLoadNeedingBarrier(region, dirtyGroups)) + continue; + if (region.empty() || region.front().empty()) + continue; + + Block &entry = region.front(); + if (isa(entry.front())) + continue; + + OpBuilder builder(&entry, entry.begin()); + builder.create(ifOp.getLoc()); } + return true; + } + + bool isUniformCondition(Value cond) const { + if (isa_and_nonnull(cond.getDefiningOp())) + return true; + + auto reduce = cond.getDefiningOp(); + if (!reduce || !cond.getType().isInteger(1)) + return false; + + Operation *combiner = reduce.getSingleCombiner(); + return combiner && isa(combiner); + } + + bool regionHasLoadNeedingBarrier( + Region ®ion, const llvm::DenseMap &dirtyGroups) const { + for (Block &block : region) { + for (Operation &nestedOp : block) { + if (auto load = dyn_cast(&nestedOp)) { + if (auto group = lookupPointerGroup(load.getPtr()); + group && dirtyGroups.lookup(*group)) + return true; + } + if (nestedOp.getNumRegions() > 0 && + opHasLoadNeedingBarrier(nestedOp, dirtyGroups)) + return true; + } + } + return false; } bool opHasLoadNeedingBarrier( Operation &op, const llvm::DenseMap &dirtyGroups) const { - bool needsBarrier = false; for (Region ®ion : op.getRegions()) { - for (Block &block : region) { - for (Operation &nestedOp : block) { - if (auto load = dyn_cast(&nestedOp)) { - if (auto group = lookupPointerGroup(load.getPtr()); - group && dirtyGroups.lookup(*group)) { - needsBarrier = true; - break; - } - } - if (nestedOp.getNumRegions() > 0 && - opHasLoadNeedingBarrier(nestedOp, dirtyGroups)) { - needsBarrier = true; - break; - } + if 
(regionHasLoadNeedingBarrier(region, dirtyGroups)) + return true; + } + return false; + } + + void markGroupsWrittenByNestedRegions( + Operation &op, llvm::DenseMap &dirtyGroups) const { + if (op.getNumRegions() == 0) + return; + llvm::DenseSet writtenGroups; + for (Region ®ion : op.getRegions()) + collectWrittenGroups(region, writtenGroups); + for (int64_t group : writtenGroups) + dirtyGroups[group] = true; + } + + void collectWrittenGroups(Region ®ion, + llvm::DenseSet &writtenGroups) const { + for (Block &block : region) { + for (Operation &nestedOp : block) { + if (auto store = dyn_cast(&nestedOp)) { + if (auto group = lookupPointerGroup(store.getPtr())) + writtenGroups.insert(*group); } - if (needsBarrier) - break; + for (Region &deeperRegion : nestedOp.getRegions()) + collectWrittenGroups(deeperRegion, writtenGroups); } - if (needsBarrier) - break; } - return needsBarrier; } std::optional lookupPointerGroup(Value ptr) const { diff --git a/third_party/tle/dialect/lib/Transforms/TleLowerExclusiveCumsum.cpp b/third_party/tle/dialect/lib/Transforms/TleLowerExclusiveCumsum.cpp new file mode 100644 index 0000000000..bb54ffa637 --- /dev/null +++ b/third_party/tle/dialect/lib/Transforms/TleLowerExclusiveCumsum.cpp @@ -0,0 +1,192 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "tle/dialect/include/IR/Dialect.h" +#include "tle/dialect/include/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "llvm/ADT/SmallVector.h" +#include + +namespace mlir::triton::tle { + +#define GEN_PASS_DEF_TRITONTLELOWEREXCLUSIVECUMSUM +#include "tle/dialect/include/Transforms/Passes.h.inc" + +namespace { + +static Value createAddOp(OpBuilder &builder, Location loc, Value lhs, Value rhs, + Type elemTy) { + if (isa(elemTy)) + return builder.create(loc, lhs, rhs).getResult(); + if (elemTy.isIntOrIndex()) + return builder.create(loc, lhs, rhs).getResult(); + return nullptr; +} + +static Value createSubOp(OpBuilder &builder, Location loc, Value lhs, Value rhs, + Type elemTy) { + if (isa(elemTy)) + return builder.create(loc, lhs, rhs).getResult(); + if (elemTy.isIntOrIndex()) + return builder.create(loc, lhs, rhs).getResult(); + return nullptr; +} + +static LogicalResult buildScanAddRegion(OpBuilder &builder, triton::ScanOp scan, + Type elemTy, Location loc) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(&scan.getCombineOp()); + block->addArgument(elemTy, loc); + block->addArgument(elemTy, loc); + builder.setInsertionPointToEnd(block); + Value sum = createAddOp(builder, loc, block->getArgument(0), + block->getArgument(1), elemTy); + if (!sum) + return failure(); + 
builder.create(loc, ValueRange{sum}); + return success(); +} + +static LogicalResult buildReduceSelectByIndexRegion(OpBuilder &builder, + triton::ReduceOp reduce, + Type elemTy, Location loc, + bool pickMaxIndex) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(&reduce.getCombineOp()); + Type idxTy = builder.getI32Type(); + // Reduce with 2 operands: (idx, value). Region argument order is: + // (lhs_idx, lhs_val, rhs_idx, rhs_val). + block->addArgument(idxTy, loc); + block->addArgument(elemTy, loc); + block->addArgument(idxTy, loc); + block->addArgument(elemTy, loc); + builder.setInsertionPointToEnd(block); + + Value lhsIdx = block->getArgument(0); + Value lhsVal = block->getArgument(1); + Value rhsIdx = block->getArgument(2); + Value rhsVal = block->getArgument(3); + + arith::CmpIPredicate pred = + pickMaxIndex ? arith::CmpIPredicate::sgt : arith::CmpIPredicate::slt; + Value chooseLhs = builder.create(loc, pred, lhsIdx, rhsIdx); + Value selectedIdx = + builder.create(loc, chooseLhs, lhsIdx, rhsIdx); + Value selectedVal = + builder.create(loc, chooseLhs, lhsVal, rhsVal); + builder.create(loc, + ValueRange{selectedIdx, selectedVal}); + return success(); +} + +class LowerExclusiveCumsumPass + : public impl::TritonTleLowerExclusiveCumsumBase { + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector ops; + module.walk([&](tle::ExclusiveCumsumOp op) { ops.push_back(op); }); + + for (tle::ExclusiveCumsumOp op : ops) { + if (!op) + continue; + + auto srcTy = dyn_cast(op.getSrc().getType()); + if (!srcTy) { + op.emitOpError("expects ranked tensor input"); + signalPassFailure(); + return; + } + int64_t axisExtent = srcTy.getShape()[0]; + if (ShapedType::isDynamic(axisExtent) || axisExtent <= 0) { + op.emitOpError("expects static, positive axis extent"); + signalPassFailure(); + return; + } + if (axisExtent > + static_cast(std::numeric_limits::max())) { + op.emitOpError("axis extent is too large for 
tt.make_range"); + signalPassFailure(); + return; + } + + const Type elemTy = srcTy.getElementType(); + OpBuilder builder(op); + + auto scan = builder.create( + op.getLoc(), ValueRange{op.getSrc()}, static_cast(op.getAxis()), + op.getReverse()); + if (failed(buildScanAddRegion(builder, scan, elemTy, op.getLoc()))) { + op.emitOpError("failed to build add combiner for triton.scan"); + signalPassFailure(); + return; + } + Value inclusive = scan.getResult()[0]; + Value exclusive = + createSubOp(builder, op.getLoc(), inclusive, op.getSrc(), elemTy); + if (!exclusive) { + op.emitOpError("unsupported element type for exclusive subtraction"); + signalPassFailure(); + return; + } + + RankedTensorType idxTy = RankedTensorType::get( + srcTy.getShape(), builder.getI32Type(), srcTy.getEncoding()); + Value indices = builder + .create( + op.getLoc(), idxTy, /*start=*/0u, + /*end=*/static_cast(axisExtent)) + .getResult(); + auto reduce = builder.create( + op.getLoc(), ValueRange{indices, inclusive}, + static_cast(op.getAxis())); + bool pickMaxIndex = !op.getReverse(); + if (failed(buildReduceSelectByIndexRegion(builder, reduce, elemTy, + op.getLoc(), pickMaxIndex))) { + op.emitOpError( + "failed to build index-select combiner for triton.reduce"); + signalPassFailure(); + return; + } + Value total = reduce.getResult()[1]; + + if (exclusive.getType() != op.getExclusive().getType() || + total.getType() != op.getTotal().getType()) { + op.emitOpError("lowered value types do not match op result types"); + signalPassFailure(); + return; + } + + op.getExclusive().replaceAllUsesWith(exclusive); + op.getTotal().replaceAllUsesWith(total); + op.erase(); + } + } +}; + +} // namespace +} // namespace mlir::triton::tle diff --git a/third_party/tle/dialect/lib/Transforms/TleOptimizeExclusiveCumsumLayouts.cpp b/third_party/tle/dialect/lib/Transforms/TleOptimizeExclusiveCumsumLayouts.cpp new file mode 100644 index 0000000000..1f4a7deb91 --- /dev/null +++ 
b/third_party/tle/dialect/lib/Transforms/TleOptimizeExclusiveCumsumLayouts.cpp @@ -0,0 +1,100 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "tle/dialect/include/IR/Dialect.h" +#include "tle/dialect/include/Transforms/Passes.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::tle { + +#define GEN_PASS_DEF_TRITONTLEOPTIMIZEEXCLUSIVECUMSUMLAYOUTS +#include "tle/dialect/include/Transforms/Passes.h.inc" + +namespace { + +class OptimizeExclusiveCumsumLayoutsPass + : public impl::TritonTleOptimizeExclusiveCumsumLayoutsBase< + OptimizeExclusiveCumsumLayoutsPass> { +public: + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector ops; + module.walk([&](tle::ExclusiveCumsumOp op) { ops.push_back(op); }); + + for (tle::ExclusiveCumsumOp op : ops) { + auto srcCvt = op.getSrc().getDefiningOp(); + if (!srcCvt) + continue; + + auto srcBaseTy = dyn_cast(srcCvt.getSrc().getType()); + auto cumsumTy = dyn_cast(op.getExclusive().getType()); + if (!srcBaseTy || !cumsumTy) + continue; + if (srcBaseTy.getShape() != cumsumTy.getShape() || + srcBaseTy.getElementType() != cumsumTy.getElementType()) + continue; + if (srcBaseTy.getEncoding() == cumsumTy.getEncoding()) + continue; + + SmallVector outCvts; + bool allUsersMatch = true; + for (OpOperand &use : op.getExclusive().getUses()) { + auto outCvt = dyn_cast(use.getOwner()); + if (!outCvt || outCvt.getSrc() != op.getExclusive()) { + allUsersMatch = false; + break; + } + auto outTy = dyn_cast(outCvt.getType()); + if (!outTy || outTy.getShape() != srcBaseTy.getShape() || + outTy.getElementType() != srcBaseTy.getElementType() || + outTy.getEncoding() != srcBaseTy.getEncoding()) { + allUsersMatch = false; + break; + } + outCvts.push_back(outCvt); + } + if (!allUsersMatch) + continue; + + OpBuilder builder(op); + auto newOp = builder.create( + op.getLoc(), TypeRange{srcBaseTy, op.getTotal().getType()}, + srcCvt.getSrc(), op.getAxisAttr(), op.getReverseAttr()); + + for (auto cvt : outCvts) + cvt.replaceAllUsesWith(newOp.getExclusive()); + 
op.getTotal().replaceAllUsesWith(newOp.getTotal()); + + for (auto cvt : outCvts) + cvt.erase(); + op.erase(); + if (srcCvt->use_empty()) + srcCvt.erase(); + } + } +}; + +} // namespace +} // namespace mlir::triton::tle diff --git a/third_party/tle/dialect/lib/Transforms/TleOptimizeLocalPointerLoads.cpp b/third_party/tle/dialect/lib/Transforms/TleOptimizeLocalPointerLoads.cpp new file mode 100644 index 0000000000..03472a2861 --- /dev/null +++ b/third_party/tle/dialect/lib/Transforms/TleOptimizeLocalPointerLoads.cpp @@ -0,0 +1,201 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "tle/dialect/include/IR/Dialect.h" +#include "tle/dialect/include/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include + +namespace mlir::triton::tle { + +#define GEN_PASS_DEF_TRITONTLEOPTIMIZELOCALPOINTERLOADS +#include "tle/dialect/include/Transforms/Passes.h.inc" + +namespace { + +namespace ttg = mlir::triton::gpu; + +static Value stripConvertLayouts(Value value) { + Value current = value; + while (auto cvt = current.getDefiningOp()) + current = cvt.getSrc(); + return current; +} + +static Value stripIndexValueWrappers(Value value) { + Value current = value; + while (true) { + if (auto cvt = current.getDefiningOp()) { + current = cvt.getSrc(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto trunc = current.getDefiningOp()) { + current = trunc.getIn(); + continue; + } + if (auto cast = current.getDefiningOp()) { + current = cast.getIn(); + continue; + } + break; + } + return current; +} + +static bool matchZeroStartMakeRange(Value value, int64_t extent) { + Value current = stripIndexValueWrappers(value); + auto range = current.getDefiningOp(); + if (!range) + return false; + return range.getStart() == 0 && range.getEnd() == extent; +} + +static bool matchFullIndexTensorForAxis(Value index, size_t axis, + ArrayRef shape) { + auto indexTy = dyn_cast(index.getType()); + if (!indexTy || !indexTy.getElementType().isInteger()) + return false; + if (indexTy.getShape() != shape) + return false; + + Value current = stripIndexValueWrappers(index); + if (shape.size() == 1) + return matchZeroStartMakeRange(current, shape.front()); + + auto bcast = 
current.getDefiningOp(); + if (!bcast) + return false; + + auto bcastSrcTy = dyn_cast(bcast.getSrc().getType()); + if (!bcastSrcTy || bcastSrcTy.getRank() != static_cast(shape.size())) + return false; + for (auto [dim, dimSize] : llvm::enumerate(shape)) { + const int64_t expected = dim == axis ? dimSize : 1; + if (bcastSrcTy.getShape()[dim] != expected) + return false; + } + + current = stripIndexValueWrappers(bcast.getSrc()); + while (auto expand = current.getDefiningOp()) + current = stripIndexValueWrappers(expand.getSrc()); + + auto rangeTy = dyn_cast(current.getType()); + if (!rangeTy || rangeTy.getRank() != 1) + return false; + if (rangeTy.getShape()[0] != shape[axis]) + return false; + + return matchZeroStartMakeRange(current, shape[axis]); +} + +static std::optional matchFullViewMemDesc(triton::LoadOp load) { + if (load.getMask() || load.getOther()) + return std::nullopt; + if (load.getIsVolatile()) + return std::nullopt; + if (load.getCache() != triton::CacheModifier::NONE || + load.getEvict() != triton::EvictionPolicy::NORMAL) + return std::nullopt; + + auto loadTy = dyn_cast(load.getType()); + if (!loadTy) + return std::nullopt; + + Value ptr = stripConvertLayouts(load.getPtr()); + auto localPointers = ptr.getDefiningOp(); + if (!localPointers) + return std::nullopt; + + auto ptrTy = dyn_cast(localPointers.getResult().getType()); + if (!ptrTy) + return std::nullopt; + + auto memDescTy = dyn_cast(localPointers.getSrc().getType()); + if (!memDescTy) + return std::nullopt; + + auto memDescShape = memDescTy.getShape(); + if (loadTy.getShape() != memDescShape || ptrTy.getShape() != memDescShape) + return std::nullopt; + if (loadTy.getElementType() != memDescTy.getElementType()) + return std::nullopt; + + auto indices = localPointers.getIndices(); + if (indices.empty()) + return localPointers.getSrc(); + if (indices.size() != memDescShape.size()) + return std::nullopt; + + for (auto [axis, index] : llvm::enumerate(indices)) + if 
(!matchFullIndexTensorForAxis(index, axis, memDescShape)) + return std::nullopt; + + return localPointers.getSrc(); +} + +class OptimizeLocalPointerLoadsPass + : public impl::TritonTleOptimizeLocalPointerLoadsBase< + OptimizeLocalPointerLoadsPass> { + void runOnOperation() override { + ModuleOp module = getOperation(); + + struct RewriteItem { + triton::LoadOp load; + Value memDesc; + }; + SmallVector rewrites; + + module.walk([&](triton::LoadOp load) { + if (auto memDesc = matchFullViewMemDesc(load)) + rewrites.push_back({load, *memDesc}); + }); + + for (RewriteItem &item : rewrites) { + if (!item.load || !item.memDesc) + continue; + OpBuilder builder(item.load); + auto localLoad = builder.create( + item.load.getLoc(), item.load.getType(), item.memDesc); + item.load.replaceAllUsesWith(localLoad.getResult()); + item.load.erase(); + } + } +}; + +} // namespace +} // namespace mlir::triton::tle diff --git a/third_party/tle/dialect/lib/Transforms/TleOptimizeLocalPointerStores.cpp b/third_party/tle/dialect/lib/Transforms/TleOptimizeLocalPointerStores.cpp new file mode 100644 index 0000000000..f90ead863b --- /dev/null +++ b/third_party/tle/dialect/lib/Transforms/TleOptimizeLocalPointerStores.cpp @@ -0,0 +1,112 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "tle/dialect/include/IR/Dialect.h" +#include "tle/dialect/include/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir::triton::tle { + +#define GEN_PASS_DEF_TRITONTLEOPTIMIZELOCALPOINTERSTORES +#include "tle/dialect/include/Transforms/Passes.h.inc" + +namespace { + +namespace ttg = mlir::triton::gpu; + +static Value stripConvertLayouts(Value value) { + Value current = value; + while (auto cvt = current.getDefiningOp()) + current = cvt.getSrc(); + return current; +} + +class OptimizeLocalPointerStoresPass + : public impl::TritonTleOptimizeLocalPointerStoresBase< + OptimizeLocalPointerStoresPass> { + void runOnOperation() override { + ModuleOp module = getOperation(); + + SmallVector stores; + module.walk([&](triton::StoreOp store) { stores.push_back(store); }); + + for (triton::StoreOp store : stores) { + if (!store) + continue; + + Value ptr = stripConvertLayouts(store.getPtr()); + auto localPointers = ptr.getDefiningOp(); + if (!localPointers) + continue; + + auto valueTy = dyn_cast(store.getValue().getType()); + auto memDescTy = + dyn_cast(localPointers.getSrc().getType()); + if (!valueTy || !memDescTy) + continue; + + if (!store.getBoundaryCheck().empty()) + continue; + if (valueTy.getShape() != memDescTy.getShape()) + continue; + if 
(valueTy.getElementType() != memDescTy.getElementType()) + continue; + + OpBuilder builder(store); + Value valueToStore = store.getValue(); + + if (Value mask = store.getMask()) { + auto maskTy = dyn_cast(mask.getType()); + if (!maskTy || maskTy.getShape() != valueTy.getShape()) + continue; + if (maskTy.getEncoding() != valueTy.getEncoding()) { + auto targetMaskTy = + RankedTensorType::get(maskTy.getShape(), maskTy.getElementType(), + valueTy.getEncoding()); + mask = builder + .create(store.getLoc(), targetMaskTy, + mask) + .getResult(); + } + Value oldValue = builder.create( + store.getLoc(), valueTy, localPointers.getSrc()); + valueToStore = builder + .create(store.getLoc(), mask, + valueToStore, oldValue) + .getResult(); + } + + builder.create(store.getLoc(), valueToStore, + localPointers.getSrc()); + store.erase(); + } + } +}; + +} // namespace +} // namespace mlir::triton::tle diff --git a/third_party/tle/dialect/lib/Transforms/TleSelectEncodings.cpp b/third_party/tle/dialect/lib/Transforms/TleSelectEncodings.cpp new file mode 100644 index 0000000000..76f05d9018 --- /dev/null +++ b/third_party/tle/dialect/lib/Transforms/TleSelectEncodings.cpp @@ -0,0 +1,946 @@ +// MIT License + +// Copyright (c) 2025 The FlagOS Contributors + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+ +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// flagtree tle + +#include "tle/dialect/include/IR/Dialect.h" +#include "tle/dialect/include/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir::triton::tle { + +#define GEN_PASS_DEF_TRITONTLESELECTENCODINGS +#include "tle/dialect/include/Transforms/Passes.h.inc" + +namespace { + +// Triton shared-memory pointers use LLVM address space 3 (NVVM shared). 
+constexpr int kSharedMemoryAddressSpace = 3; +constexpr StringLiteral kBarrierGroupAttr = "tle.barrier_group"; +constexpr StringLiteral kTTContiguityAttr = "tt.contiguity"; +constexpr StringLiteral kTTDivisibilityAttr = "tt.divisibility"; +constexpr StringLiteral kTTConstancyAttr = "tt.constancy"; + +static Value stripConvertLayouts(Value value) { + Value current = value; + while (auto convert = current.getDefiningOp()) + current = convert.getSrc(); + return current; +} + +static Attribute getStrippedTensorEncoding(Value value) { + Value stripped = stripConvertLayouts(value); + auto strippedTy = dyn_cast(stripped.getType()); + if (!strippedTy) + return Attribute(); + return strippedTy.getEncoding(); +} + +static bool isConstantLikeTensorValue(Value value) { + Value cur = stripConvertLayouts(value); + if (!isa(cur.getType())) + return false; + if (isa_and_nonnull(cur.getDefiningOp())) + return true; + if (auto splat = cur.getDefiningOp()) { + Value src = splat.getSrc(); + if (isa_and_nonnull(src.getDefiningOp())) + return true; + } + return false; +} + +static Value stripIndexValueWrappers(Value value) { + Value current = value; + while (true) { + if (auto convert = current.getDefiningOp()) { + current = convert.getSrc(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto trunc = current.getDefiningOp()) { + current = trunc.getIn(); + continue; + } + if (auto cast = current.getDefiningOp()) { + current = cast.getIn(); + continue; + } + break; + } + return current; +} + +static bool matchZeroStartMakeRange(Value value, int64_t extent) { + Value current = stripIndexValueWrappers(value); + auto range = current.getDefiningOp(); + if (!range) + return false; + return range.getStart() == 0 && range.getEnd() == extent; +} + +static bool matchFullIndexTensorForAxis(Value index, size_t axis, + ArrayRef shape) { + auto indexTy = 
dyn_cast(index.getType()); + if (!indexTy || !indexTy.getElementType().isInteger()) + return false; + if (indexTy.getShape() != shape) + return false; + + Value current = stripIndexValueWrappers(index); + if (shape.size() == 1) + return matchZeroStartMakeRange(current, shape.front()); + + auto bcast = current.getDefiningOp(); + if (!bcast) + return false; + + auto bcastSrcTy = dyn_cast(bcast.getSrc().getType()); + if (!bcastSrcTy || bcastSrcTy.getRank() != static_cast(shape.size())) + return false; + for (auto [dim, dimSize] : llvm::enumerate(shape)) { + const int64_t expected = dim == axis ? dimSize : 1; + if (bcastSrcTy.getShape()[dim] != expected) + return false; + } + + current = stripIndexValueWrappers(bcast.getSrc()); + while (auto expand = current.getDefiningOp()) + current = stripIndexValueWrappers(expand.getSrc()); + + auto rangeTy = dyn_cast(current.getType()); + if (!rangeTy || rangeTy.getRank() != 1) + return false; + if (rangeTy.getShape()[0] != shape[axis]) + return false; + + return matchZeroStartMakeRange(current, shape[axis]); +} + +// Loads of full-view local_pointers are later rewritten to ttg.local_load. +// They should not bias local_pointers encoding inference toward load layouts. 
+static bool isRewritableFullViewLocalPointerLoad(triton::LoadOp load) { + if (load.getMask() || load.getOther()) + return false; + if (load.getIsVolatile()) + return false; + if (load.getCache() != triton::CacheModifier::NONE || + load.getEvict() != triton::EvictionPolicy::NORMAL) + return false; + + auto loadTy = dyn_cast(load.getType()); + if (!loadTy) + return false; + + Value ptr = stripConvertLayouts(load.getPtr()); + auto localPointers = ptr.getDefiningOp(); + if (!localPointers) + return false; + + auto ptrTy = dyn_cast(localPointers.getResult().getType()); + if (!ptrTy) + return false; + + auto memDescTy = + dyn_cast(localPointers.getSrc().getType()); + if (!memDescTy) + return false; + + auto memDescShape = memDescTy.getShape(); + if (loadTy.getShape() != memDescShape || ptrTy.getShape() != memDescShape) + return false; + if (loadTy.getElementType() != memDescTy.getElementType()) + return false; + + auto indices = localPointers.getIndices(); + if (indices.empty()) + return true; + if (indices.size() != memDescShape.size()) + return false; + + for (auto [axis, index] : llvm::enumerate(indices)) + if (!matchFullIndexTensorForAxis(index, axis, memDescShape)) + return false; + return true; +} + +static int64_t getScfLoopDepth(Operation *op) { + int64_t depth = 0; + for (Operation *cur = op; cur; cur = cur->getParentOp()) + if (isa(cur)) + ++depth; + return depth; +} + +static bool valueFeedsDot(Value root) { + llvm::SmallVector worklist; + llvm::DenseSet visited; + auto enqueue = [&](Value v) { + if (v && visited.insert(v).second) + worklist.push_back(v); + }; + enqueue(root); + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (isa(owner)) + return true; + if (auto convert = dyn_cast(owner)) { + enqueue(convert.getResult()); + continue; + } + if (auto trans = dyn_cast(owner)) { + enqueue(trans.getResult()); + continue; + } + if (auto bcast = 
dyn_cast(owner)) { + enqueue(bcast.getResult()); + continue; + } + if (auto expand = dyn_cast(owner)) { + enqueue(expand.getResult()); + continue; + } + if (auto reshape = dyn_cast(owner)) { + enqueue(reshape.getResult()); + continue; + } + } + } + return false; +} + +struct EncodingVote { + Attribute encoding; + int64_t score; +}; + +using CachedConversionKey = std::pair; +using CachedConversionMap = + llvm::DenseMap>; + +static Value getOrCreateCachedConvertLayout(OpBuilder &builder, + Operation *insertBefore, Value v, + Attribute encoding, + CachedConversionMap &cache) { + Value stripped = stripConvertLayouts(v); + auto strippedTy = dyn_cast(stripped.getType()); + if (strippedTy && strippedTy.getEncoding() == encoding) + return stripped; + + auto vTy = dyn_cast(v.getType()); + if (!vTy) + return v; + if (vTy.getEncoding() == encoding) + return v; + + CachedConversionKey key{v, encoding}; + auto it = cache.find(key); + if (it != cache.end()) { + for (Value candidate : it->second) { + Operation *def = candidate.getDefiningOp(); + if (!def) + continue; + if (def->getBlock() != insertBefore->getBlock()) + continue; + if (def->isBeforeInBlock(insertBefore)) + return candidate; + } + } + + auto convertedTy = + RankedTensorType::get(vTy.getShape(), vTy.getElementType(), encoding); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(insertBefore); + auto converted = builder.create( + insertBefore->getLoc(), convertedTy, v); + Value convertedValue = converted.getResult(); + cache[key].push_back(convertedValue); + return convertedValue; +} + +static Operation *peelAxisInfoCarrier(Value value) { + llvm::DenseSet visited; + Value current = value; + while (current && visited.insert(current).second) { + Operation *def = current.getDefiningOp(); + if (!def) + break; + if (auto convert = dyn_cast(def)) { + current = convert.getSrc(); + continue; + } + if (auto bcast = dyn_cast(def)) { + current = bcast.getSrc(); + continue; + } + if (auto expand = 
dyn_cast(def)) { + current = expand.getSrc(); + continue; + } + if (auto reshape = dyn_cast(def)) { + current = reshape.getSrc(); + continue; + } + return def; + } + return current ? current.getDefiningOp() : nullptr; +} + +static void copyAxisInfoAttrs(Operation *src, Operation *dst) { + if (!src || !dst) + return; + auto tryCopy = [&](StringRef name) { + if (dst->getDiscardableAttr(name)) + return; + if (auto attr = src->getDiscardableAttr(name)) + dst->setDiscardableAttr(name, attr); + }; + tryCopy(kTTContiguityAttr); + tryCopy(kTTDivisibilityAttr); + tryCopy(kTTConstancyAttr); +} + +static void +collectConsumerEncodingVotes(Value root, + llvm::SmallVectorImpl &votes) { + auto rootLocal = + stripConvertLayouts(root).getDefiningOp(); + bool preferMaskForScalarLocalPointers = false; + if (rootLocal) { + if (auto memDescTy = + dyn_cast(rootLocal.getSrc().getType())) { + int64_t elemCount = 1; + for (int64_t dim : memDescTy.getShape()) { + if (dim <= 0) { + elemCount = 0; + break; + } + elemCount *= dim; + } + preferMaskForScalarLocalPointers = (elemCount == 1); + } + } + + llvm::SmallVector worklist; + llvm::DenseSet visited; + auto enqueue = [&](Value v) { + if (!v) + return; + if (!visited.insert(v).second) + return; + worklist.push_back(v); + }; + + enqueue(root); + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (auto load = dyn_cast(owner)) { + if (isRewritableFullViewLocalPointerLoad(load)) + continue; + if (Attribute loadEncoding = + getStrippedTensorEncoding(load.getResult())) { + const int64_t depthFactor = 1 + getScfLoopDepth(owner); + int64_t score = 8 * depthFactor; + if (valueFeedsDot(load.getResult())) + score += 128 * depthFactor; + votes.push_back({loadEncoding, score}); + } + continue; + } + if (auto store = dyn_cast(owner)) { + if (Attribute valueEncoding = + getStrippedTensorEncoding(store.getValue())) { + const int64_t depthFactor = 1 + 
getScfLoopDepth(owner); + int64_t score = 2 * depthFactor; + if (Operation *def = store.getValue().getDefiningOp(); + def && isa(def)) + score += 8 * depthFactor; + votes.push_back({valueEncoding, score}); + } + if (Value mask = store.getMask()) + if (Attribute maskEncoding = getStrippedTensorEncoding(mask)) + votes.push_back({maskEncoding, 2 * (1 + getScfLoopDepth(owner))}); + continue; + } + if (auto atomic = dyn_cast(owner)) { + const int64_t depthFactor = 1 + getScfLoopDepth(owner); + const int64_t valScore = + (preferMaskForScalarLocalPointers ? 8 : 24) * depthFactor; + const int64_t maskScoreBase = + (preferMaskForScalarLocalPointers ? 48 : 12) * depthFactor; + const int64_t resultScore = + (preferMaskForScalarLocalPointers ? 0 : 12) * depthFactor; + if (Attribute valEncoding = getStrippedTensorEncoding(atomic.getVal())) + votes.push_back({valEncoding, valScore}); + if (Value mask = atomic.getMask()) { + if (Attribute maskEncoding = getStrippedTensorEncoding(mask)) { + int64_t maskScore = maskScoreBase; + if (preferMaskForScalarLocalPointers && + isConstantLikeTensorValue(mask)) + maskScore = depthFactor; + votes.push_back({maskEncoding, maskScore}); + } + } + if (resultScore > 0) + if (Attribute resultEncoding = + getStrippedTensorEncoding(atomic.getResult())) + votes.push_back({resultEncoding, resultScore}); + continue; + } + if (auto cas = dyn_cast(owner)) { + const int64_t depthFactor = 1 + getScfLoopDepth(owner); + const int64_t valScore = + (preferMaskForScalarLocalPointers ? 8 : 24) * depthFactor; + const int64_t cmpScore = + (preferMaskForScalarLocalPointers ? 48 : 12) * depthFactor; + const int64_t resultScore = + (preferMaskForScalarLocalPointers ? 
0 : 12) * depthFactor; + if (Attribute cmpEncoding = getStrippedTensorEncoding(cas.getCmp())) + votes.push_back({cmpEncoding, cmpScore}); + if (Attribute valEncoding = getStrippedTensorEncoding(cas.getVal())) + votes.push_back({valEncoding, valScore}); + if (resultScore > 0) + if (Attribute resultEncoding = + getStrippedTensorEncoding(cas.getResult())) + votes.push_back({resultEncoding, resultScore}); + continue; + } + if (auto convert = dyn_cast(owner)) { + enqueue(convert.getResult()); + continue; + } + if (auto bcast = dyn_cast(owner)) { + enqueue(bcast.getResult()); + continue; + } + if (auto expand = dyn_cast(owner)) { + enqueue(expand.getResult()); + continue; + } + if (auto reshape = dyn_cast(owner)) { + enqueue(reshape.getResult()); + continue; + } + if (auto remote = dyn_cast(owner)) { + enqueue(remote.getResult()); + continue; + } + } + } +} + +static Attribute pickDominantEncoding(ArrayRef votes, + Attribute fallback) { + if (votes.empty()) + return fallback; + + llvm::DenseMap scoreByEncoding; + llvm::SmallVector order; + for (const EncodingVote &vote : votes) { + if (!vote.encoding) + continue; + auto [it, inserted] = scoreByEncoding.try_emplace(vote.encoding, 0); + if (inserted) + order.push_back(vote.encoding); + it->second += vote.score; + } + if (order.empty()) + return fallback; + + Attribute best = order.front(); + int64_t bestScore = scoreByEncoding.lookup(best); + for (Attribute encoding : order) { + int64_t score = scoreByEncoding.lookup(encoding); + if (score > bestScore) { + best = encoding; + bestScore = score; + continue; + } + if (score == bestScore && encoding == fallback) + best = encoding; + } + return best; +} + +static bool isPointerTensorType(Type type) { + auto tensorTy = dyn_cast(type); + if (!tensorTy) + return false; + return isa(tensorTy.getElementType()); +} + +static void bridgeResultTypeToOldEncoding(Value result, Type oldType, + OpBuilder &builder) { + if (result.getType() == oldType) + return; + auto oldTensorTy = 
dyn_cast(oldType); + if (!oldTensorTy) + return; + Operation *def = result.getDefiningOp(); + if (!def) + return; + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(def); + auto bridge = builder.create( + def->getLoc(), oldTensorTy, result); + result.replaceAllUsesExcept(bridge.getResult(), bridge.getOperation()); +} + +static bool tryFoldPointerConvertLayout(triton::gpu::ConvertLayoutOp convert, + OpBuilder &builder, + CachedConversionMap &cache) { + auto srcTy = dyn_cast(convert.getSrc().getType()); + auto dstTy = dyn_cast(convert.getType()); + if (!srcTy || !dstTy) + return false; + if (!isa(srcTy.getElementType()) || + !isa(dstTy.getElementType())) + return false; + + Value srcPtr = convert.getSrc(); + Value convertedPtr = convert.getResult(); + Attribute srcEncoding = srcTy.getEncoding(); + auto srcElemTy = + cast(srcTy.getElementType()).getPointeeType(); + auto srcLoadTy = + RankedTensorType::get(srcTy.getShape(), srcElemTy, srcEncoding); + + SmallVector uses; + uses.reserve(convertedPtr.getNumUses()); + for (OpOperand &use : convertedPtr.getUses()) { + Operation *owner = use.getOwner(); + if (!isa(owner)) + return false; + uses.push_back(&use); + } + + auto convertOperandEncoding = [&](Operation *insertBefore, Value v, + Attribute encoding) -> Value { + return getOrCreateCachedConvertLayout(builder, insertBefore, v, encoding, + cache); + }; + + for (OpOperand *use : uses) { + Operation *owner = use->getOwner(); + use->set(srcPtr); + + if (auto load = dyn_cast(owner)) { + if (Value mask = load.getMask()) { + Value convertedMask = convertOperandEncoding(owner, mask, srcEncoding); + if (convertedMask != mask) + load.getMaskMutable().assign(convertedMask); + } + if (Value other = load.getOther()) { + Value convertedOther = + convertOperandEncoding(owner, other, srcEncoding); + if (convertedOther != other) + load.getOtherMutable().assign(convertedOther); + } + Type oldType = load.getResult().getType(); + if (oldType != srcLoadTy) { + 
load.getResult().setType(srcLoadTy); + bridgeResultTypeToOldEncoding(load.getResult(), oldType, builder); + } + continue; + } + + if (auto store = dyn_cast(owner)) { + Value value = store.getValue(); + Value convertedValue = convertOperandEncoding(owner, value, srcEncoding); + if (convertedValue != value) + store.getValueMutable().assign(convertedValue); + if (Value mask = store.getMask()) { + Value convertedMask = convertOperandEncoding(owner, mask, srcEncoding); + if (convertedMask != mask) + store.getMaskMutable().assign(convertedMask); + } + continue; + } + + if (auto atomic = dyn_cast(owner)) { + Value val = atomic.getVal(); + Value convertedVal = convertOperandEncoding(owner, val, srcEncoding); + if (convertedVal != val) + atomic.getValMutable().assign(convertedVal); + if (Value mask = atomic.getMask()) { + Value convertedMask = convertOperandEncoding(owner, mask, srcEncoding); + if (convertedMask != mask) + atomic.getMaskMutable().assign(convertedMask); + } + Type oldType = atomic.getResult().getType(); + if (oldType != srcLoadTy) { + atomic.getResult().setType(srcLoadTy); + bridgeResultTypeToOldEncoding(atomic.getResult(), oldType, builder); + } + continue; + } + + auto cas = cast(owner); + Value cmp = cas.getCmp(); + Value convertedCmp = convertOperandEncoding(owner, cmp, srcEncoding); + if (convertedCmp != cmp) + cas.getCmpMutable().assign(convertedCmp); + Value val = cas.getVal(); + Value convertedVal = convertOperandEncoding(owner, val, srcEncoding); + if (convertedVal != val) + cas.getValMutable().assign(convertedVal); + Type oldType = cas.getResult().getType(); + if (oldType != srcLoadTy) { + cas.getResult().setType(srcLoadTy); + bridgeResultTypeToOldEncoding(cas.getResult(), oldType, builder); + } + } + + if (convertedPtr.use_empty()) + convert.erase(); + return true; +} + +class SelectEncodingsPass + : public impl::TritonTleSelectEncodingsBase { + void runOnOperation() override { + ModuleOp module = getOperation(); + OpBuilder 
builder(module.getContext()); + CachedConversionMap userOperandConversionCache; + CachedConversionMap indexOperandConversionCache; + module.walk([&](triton::tle::LocalPointersOp op) { + // Always tag local pointer ops so barrier insertion can track hazards + // across different pointer views of the same alloc. + tagDependencyGroup(op, builder); + + auto tensorTy = dyn_cast(op.getResult().getType()); + auto scalarPtrTy = + dyn_cast(op.getResult().getType()); + if (!tensorTy && !scalarPtrTy) + return; + auto ptrTy = + tensorTy ? dyn_cast(tensorTy.getElementType()) + : scalarPtrTy; + if (!ptrTy) + return; + bool updated = false; + Type updatedResultTy = op.getResult().getType(); + const auto desiredAddrSpace = kSharedMemoryAddressSpace; + if (ptrTy.getAddressSpace() != desiredAddrSpace) { + ptrTy = + triton::PointerType::get(ptrTy.getPointeeType(), desiredAddrSpace); + updated = true; + } + + if (!tensorTy) { + if (updated) + op.getResult().setType(ptrTy); + return; + } + + auto encoding = tensorTy.getEncoding(); + SmallVector votes; + collectConsumerEncodingVotes(op.getResult(), votes); + for (Value index : op.getIndices()) { + Attribute indexEncoding = getStrippedTensorEncoding(index); + if (!indexEncoding) + continue; + const bool constantLike = isConstantLikeTensorValue(index); + int64_t elemCount = 1; + if (auto indexTy = dyn_cast(index.getType())) { + for (int64_t dim : indexTy.getShape()) { + if (dim <= 0) { + elemCount = 0; + break; + } + elemCount *= dim; + } + } + const int64_t depthFactor = 1 + getScfLoopDepth(op.getOperation()); + int64_t baseScore = constantLike ? 
1 : 12; + if (!constantLike) { + if (elemCount >= 1024) + baseScore = 192; + else if (elemCount >= 256) + baseScore = 64; + } + const int64_t score = baseScore * depthFactor; + votes.push_back({indexEncoding, score}); + } + Attribute userEncoding = pickDominantEncoding(votes, encoding); + if (userEncoding && userEncoding != encoding) { + encoding = userEncoding; + updated = true; + } + if (!encoding) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(op); + int numWarps = triton::gpu::maybeLookupNumWarps(op).value_or(1); + int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(builder); + int numCTAs = triton::gpu::lookupNumCTAs(builder); + encoding = triton::gpu::getDefaultBlockedEncoding( + module.getContext(), tensorTy.getShape(), numWarps, threadsPerWarp, + numCTAs); + updated = true; + } + + if (updated) + updatedResultTy = + RankedTensorType::get(tensorTy.getShape(), ptrTy, encoding); + + if (updated) + op.getResult().setType(updatedResultTy); + + if (updated) { + llvm::DenseSet visited; + auto updateUserResultTypes = [&](auto &&self, Value ptrVal) -> void { + if (!ptrVal || !visited.insert(ptrVal).second) + return; + auto ptrTensorTy = cast(ptrVal.getType()); + auto ptrEncoding = ptrTensorTy.getEncoding(); + auto ptrElemTy = + cast(ptrTensorTy.getElementType()) + .getPointeeType(); + auto loadTy = RankedTensorType::get(ptrTensorTy.getShape(), ptrElemTy, + ptrTensorTy.getEncoding()); + auto convertOperandEncoding = [&](Operation *insertBefore, Value v, + Attribute encoding) -> Value { + return getOrCreateCachedConvertLayout( + builder, insertBefore, v, encoding, userOperandConversionCache); + }; + for (OpOperand &use : ptrVal.getUses()) { + Operation *owner = use.getOwner(); + if (auto load = dyn_cast(owner)) { + if (isRewritableFullViewLocalPointerLoad(load)) + continue; + if (Value mask = load.getMask()) { + Value convertedMask = + convertOperandEncoding(owner, mask, ptrEncoding); + if (convertedMask != mask) + 
load.getMaskMutable().assign(convertedMask); + } + if (Value other = load.getOther()) { + Value convertedOther = + convertOperandEncoding(owner, other, ptrEncoding); + if (convertedOther != other) + load.getOtherMutable().assign(convertedOther); + } + auto oldLoadTy = + dyn_cast(load.getResult().getType()); + if (oldLoadTy != loadTy) { + load.getResult().setType(loadTy); + if (oldLoadTy) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(load); + auto bridge = builder.create( + load.getLoc(), oldLoadTy, load.getResult()); + load.getResult().replaceAllUsesExcept(bridge.getResult(), + bridge.getOperation()); + } + } + continue; + } + if (auto store = dyn_cast(owner)) { + auto valueTy = + dyn_cast(store.getValue().getType()); + if (valueTy) { + Value convertedValue = convertOperandEncoding( + owner, store.getValue(), ptrEncoding); + if (convertedValue != store.getValue()) + store.getValueMutable().assign(convertedValue); + } + if (Value mask = store.getMask()) { + Value convertedMask = + convertOperandEncoding(owner, mask, ptrEncoding); + if (convertedMask != mask) + store.getMaskMutable().assign(convertedMask); + } + continue; + } + if (auto atomic = dyn_cast(owner)) { + Value val = atomic.getVal(); + Value convertedVal = + convertOperandEncoding(owner, val, ptrEncoding); + if (convertedVal != val) + atomic.getValMutable().assign(convertedVal); + if (Value mask = atomic.getMask()) { + Value convertedMask = + convertOperandEncoding(owner, mask, ptrEncoding); + if (convertedMask != mask) + atomic.getMaskMutable().assign(convertedMask); + } + auto oldAtomicTy = + dyn_cast(atomic.getResult().getType()); + if (oldAtomicTy != loadTy) { + atomic.getResult().setType(loadTy); + if (oldAtomicTy) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(atomic); + auto bridge = builder.create( + atomic.getLoc(), oldAtomicTy, atomic.getResult()); + atomic.getResult().replaceAllUsesExcept( + bridge.getResult(), 
bridge.getOperation()); + } + } + continue; + } + if (auto cas = dyn_cast(owner)) { + Value cmp = cas.getCmp(); + Value convertedCmp = + convertOperandEncoding(owner, cmp, ptrEncoding); + if (convertedCmp != cmp) + cas.getCmpMutable().assign(convertedCmp); + Value val = cas.getVal(); + Value convertedVal = + convertOperandEncoding(owner, val, ptrEncoding); + if (convertedVal != val) + cas.getValMutable().assign(convertedVal); + auto oldCasTy = + dyn_cast(cas.getResult().getType()); + if (oldCasTy != loadTy) { + cas.getResult().setType(loadTy); + if (oldCasTy) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(cas); + auto bridge = builder.create( + cas.getLoc(), oldCasTy, cas.getResult()); + cas.getResult().replaceAllUsesExcept(bridge.getResult(), + bridge.getOperation()); + } + } + continue; + } + if (auto remote = dyn_cast(owner)) { + auto remoteResultTy = + dyn_cast(remote.getResult().getType()); + if (!remoteResultTy) + continue; + auto remoteElemPtrTy = dyn_cast( + remoteResultTy.getElementType()); + if (!remoteElemPtrTy) + continue; + auto desiredElemPtrTy = triton::PointerType::get( + ptrElemTy, remoteElemPtrTy.getAddressSpace()); + auto desiredRemoteTy = RankedTensorType::get( + ptrTensorTy.getShape(), desiredElemPtrTy, ptrEncoding); + if (remoteResultTy != desiredRemoteTy) + remote.getResult().setType(desiredRemoteTy); + self(self, remote.getResult()); + continue; + } + } + }; + updateUserResultTypes(updateUserResultTypes, op.getResult()); + } + + auto desiredEncoding = + cast(updatedResultTy).getEncoding(); + if (desiredEncoding) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(op); + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + newOperands.push_back(op.getSrc()); + bool updatedOperands = false; + for (Value operand : op.getIndices()) { + auto operandTy = dyn_cast(operand.getType()); + if (!operandTy) { + newOperands.push_back(operand); + continue; + } + if 
(operandTy.getEncoding() == desiredEncoding) { + newOperands.push_back(operand); + continue; + } + auto converted = getOrCreateCachedConvertLayout( + builder, op.getOperation(), operand, desiredEncoding, + indexOperandConversionCache); + newOperands.push_back(converted); + updatedOperands = (converted != operand) || updatedOperands; + } + if (updatedOperands) + op->setOperands(newOperands); + } + }); + + // remote_pointers should preserve source pointer axis properties so later + // passes can reason about remote operands without dialect-specific + // visitors. + module.walk([&](triton::tle::RemotePointersOp op) { + Operation *srcDef = peelAxisInfoCarrier(op.getSrc()); + copyAxisInfoAttrs(srcDef, op.getOperation()); + }); + + // Fold pointer convert_layout around local/remote pointer users after + // encoding updates to avoid leaving convert chains on ptr tensors. + bool changed = true; + while (changed) { + changed = false; + SmallVector ptrConverts; + module.walk([&](triton::gpu::ConvertLayoutOp convert) { + if (isPointerTensorType(convert.getType()) && + isPointerTensorType(convert.getSrc().getType())) + ptrConverts.push_back(convert); + }); + for (triton::gpu::ConvertLayoutOp convert : ptrConverts) { + if (convert->getBlock() == nullptr) + continue; + changed |= tryFoldPointerConvertLayout(convert, builder, + userOperandConversionCache); + } + } + } + + void tagDependencyGroup(triton::tle::LocalPointersOp op, OpBuilder &builder) { + auto alloc = op.getSrc().getDefiningOp(); + if (!alloc) + return; + auto groupAttr = alloc->getAttrOfType(kBarrierGroupAttr); + if (!groupAttr) { + groupAttr = builder.getI64IntegerAttr(nextBarrierGroupId++); + alloc->setAttr(kBarrierGroupAttr, groupAttr); + } + op->setAttr(kBarrierGroupAttr, groupAttr); + } + + int64_t nextBarrierGroupId = 0; +}; + +} // namespace +} // namespace mlir::triton::tle diff --git a/third_party/tle/test/Analysis/test-alignment.mlir b/third_party/tle/test/Analysis/test-alignment.mlir index 
20cd049ac6..5a3456b4cb 100644 --- a/third_party/tle/test/Analysis/test-alignment.mlir +++ b/third_party/tle/test/Analysis/test-alignment.mlir @@ -10,6 +10,6 @@ tt.func @tle_remote_pointers_axis_info(%arg0: !tt.ptr {tt.divisibility = %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> // expected-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0}} %c0_i32 = arith.constant 0 : i32 - %3 = "tle.remote_pointers"(%2, %c0_i32) : (tensor<128x!tt.ptr>, i32) -> tensor<128x!tt.ptr> + %3 = "tle.remote_pointers"(%2, %c0_i32) : (tensor<128x!tt.ptr>, i32) -> tensor<128x!tt.ptr> tt.return } diff --git a/third_party/tle/test/GPU/test_tle_axisinfo_hints.mlir b/third_party/tle/test/GPU/test_tle_axisinfo_hints.mlir new file mode 100644 index 0000000000..2dd64dc036 --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_axisinfo_hints.mlir @@ -0,0 +1,38 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +// flagtree tle + +// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s + +tt.func @ignore_rank_mismatched_axis_hints() { + // CHECK: tt.make_range {{.*}} => contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK: tt.expand_dims {{.*}} => contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + // Scalar tt.* hints on rank-2 tensor results are malformed and must be + // ignored. Otherwise AxisInfo rank may be shrunk and later vectorization can + // query out-of-bounds dimensions. 
+ // CHECK: tt.broadcast {{.*}} => contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [2, 1], constant_value = + %2 = tt.broadcast %1 {tt.contiguity = 8 : i32, tt.divisibility = 16 : i32, tt.constancy = 4 : i32} : tensor<1x128xi32> -> tensor<2x128xi32> + tt.return +} diff --git a/third_party/tle/test/GPU/test_tle_cumsum_scratch_alias.mlir b/third_party/tle/test/GPU/test_tle_cumsum_scratch_alias.mlir new file mode 100644 index 0000000000..ee4ac55380 --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_cumsum_scratch_alias.mlir @@ -0,0 +1,26 @@ +// RUN: triton-opt %s -pass-pipeline='builtin.module(allocate-shared-memory-nv{compute-capability=120 ptx-version=88})' | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [16], order = [0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = true, elementBitWidth = 32, rank = 1}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @cumsum_scratch_no_alias + // CHECK: ttg.local_alloc {allocation.offset = 0 : i32 + // CHECK: tle.exclusive_cumsum + // CHECK-SAME: allocation.offset = 4096 : i32 + tt.func @cumsum_scratch_no_alias(%out: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<1> : tensor<512xi32, #blocked> + %offs = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<1024xi32, #shared, #smem, mutable> + %base = "tle.local_pointers"(%alloc, %c0_i32) {tle.barrier_group = 0 : i64} : (!ttg.memdesc<1024xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + %base_tensor = tt.splat %base : !tt.ptr -> tensor<512x!tt.ptr, #blocked> + %ptrs = tt.addptr %base_tensor, %offs : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> + tt.store %ptrs, %cst_1 : tensor<512x!tt.ptr, #blocked> + %x = tt.load %ptrs : 
tensor<512x!tt.ptr, #blocked> + %exclusive, %total = "tle.exclusive_cumsum"(%x) {axis = 0 : i32, reverse = false} : (tensor<512xi32, #blocked>) -> (tensor<512xi32, #blocked>, i32) + tt.store %ptrs, %exclusive : tensor<512x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/tle/test/GPU/test_tle_exclusive_cumsum_layout_propagation.mlir b/third_party/tle/test/GPU/test_tle_exclusive_cumsum_layout_propagation.mlir new file mode 100644 index 0000000000..7d669e8453 --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_exclusive_cumsum_layout_propagation.mlir @@ -0,0 +1,58 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+// +// flagtree tle + +// RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @remove_cumsum_layout_sandwich + // CHECK: %[[EX:.*]], %[[TOT:.*]] = "tle.exclusive_cumsum"(%arg0) {{.*}} : (tensor<256xi32, #blocked>) -> (tensor<256xi32, #blocked>, i32) + // CHECK-NOT: ttg.convert_layout %[[EX]] + tt.func @remove_cumsum_layout_sandwich(%arg0: tensor<256xi32, #blocked>) -> (tensor<256xi32, #blocked>, i32) { + %0 = ttg.convert_layout %arg0 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> + %exclusive, %total = "tle.exclusive_cumsum"(%0) {axis = 0 : i32, reverse = false} : (tensor<256xi32, #blocked1>) -> (tensor<256xi32, #blocked1>, i32) + %1 = ttg.convert_layout %exclusive : tensor<256xi32, #blocked1> -> tensor<256xi32, #blocked> + tt.return %1, %total : tensor<256xi32, #blocked>, i32 + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @keep_convert_for_non_convert_user + // CHECK: %[[EX:.*]], %[[TOT:.*]] = tle.exclusive_cumsum %arg0 {axis = 0 : i32, reverse = false} : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked>, i32 + // CHECK: %[[ADD:.*]] = arith.addi %[[EX]], %[[EX]] : tensor<256xi32, #blocked> + // CHECK: %[[RET:.*]] = ttg.convert_layout %[[ADD]] : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> + tt.func 
@keep_convert_for_non_convert_user(%arg0: tensor<256xi32, #blocked>) -> (tensor<256xi32, #blocked1>, i32) { + %0 = ttg.convert_layout %arg0 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> + %exclusive, %total = "tle.exclusive_cumsum"(%0) {axis = 0 : i32, reverse = false} : (tensor<256xi32, #blocked1>) -> (tensor<256xi32, #blocked1>, i32) + %1 = arith.addi %exclusive, %exclusive : tensor<256xi32, #blocked1> + tt.return %1, %total : tensor<256xi32, #blocked1>, i32 + } +} diff --git a/third_party/tle/test/GPU/test_tle_insert_local_pointer_barriers.mlir b/third_party/tle/test/GPU/test_tle_insert_local_pointer_barriers.mlir new file mode 100644 index 0000000000..ba181e62ae --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_insert_local_pointer_barriers.mlir @@ -0,0 +1,208 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+// +// flagtree tle + +// RUN: triton-opt %s -split-input-file -triton-tle-insert-local-pointer-barriers | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @uniform_reduce_or_if_barrier + tt.func @uniform_reduce_or_if_barrier() { + %idx = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked> + %zeros = arith.constant dense<0> : tensor<32xi32, #blocked> + %ones = arith.constant dense<1> : tensor<32xi32, #blocked> + %true_mask = arith.constant dense : tensor<32xi1, #blocked> + %smem0 = ttg.local_alloc : () -> !ttg.memdesc<32xi32, #shared, #smem, mutable> + %smem1 = ttg.local_alloc : () -> !ttg.memdesc<32xi32, #shared, #smem, mutable> + %ptr0 = "tle.local_pointers"(%smem0, %idx) {tle.barrier_group = 0 : i64} : (!ttg.memdesc<32xi32, #shared, #smem, mutable>, tensor<32xi32, #blocked>) -> tensor<32x!tt.ptr, #blocked> + %ptr1 = "tle.local_pointers"(%smem1, %idx) {tle.barrier_group = 1 : i64} : (!ttg.memdesc<32xi32, #shared, #smem, mutable>, tensor<32xi32, #blocked>) -> tensor<32x!tt.ptr, #blocked> + tt.store %ptr0, %zeros : tensor<32x!tt.ptr, #blocked> + // CHECK: %[[FOUND:.*]] = "tt.reduce"(%{{.*}}) + // CHECK-NOT: gpu.barrier + // CHECK: scf.if %[[FOUND]] { + %found = "tt.reduce"(%true_mask) <{axis = 0 : i32}> ({ + ^bb0(%lhs: i1, %rhs: i1): + %or = arith.ori %lhs, %rhs : i1 + tt.reduce.return %or : i1 + }) : (tensor<32xi1, #blocked>) -> i1 + scf.if %found { + tt.store %ptr1, %ones : tensor<32x!tt.ptr, #blocked> + } else { + // CHECK: } else { + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: %[[LOAD:.*]] = tt.load %{{.*}} : tensor<32x!tt.ptr, #blocked> + %load = tt.load %ptr0 : tensor<32x!tt.ptr, #blocked> + 
tt.store %ptr1, %load : tensor<32x!tt.ptr, #blocked> + } + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @splat_store_scalar_load_barrier + tt.func @splat_store_scalar_load_barrier() { + %c0 = arith.constant 0 : i32 + %idx = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked> + %vals = arith.constant dense<7> : tensor<32xi32, #blocked> + %mask = arith.cmpi eq, %idx, %idx : tensor<32xi32, #blocked> + %smem = ttg.local_alloc : () -> !ttg.memdesc<1xi32, #shared, #smem, mutable> + %ptr = "tle.local_pointers"(%smem, %c0) {tle.barrier_group = 5 : i64} : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + %ptrs = tt.splat %ptr : !tt.ptr -> tensor<32x!tt.ptr, #blocked> + // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x!tt.ptr, #blocked> + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: %[[L:.*]] = tt.load %{{.*}} : !tt.ptr + tt.store %ptrs, %vals, %mask : tensor<32x!tt.ptr, #blocked> + %l = tt.load %ptr : !tt.ptr + tt.store %ptr, %l : !tt.ptr + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @addptr_store_scalar_load_barrier + tt.func @addptr_store_scalar_load_barrier() { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c7 = arith.constant 7 : i32 + %smem = ttg.local_alloc : () -> !ttg.memdesc<4xi32, #shared, 
#smem, mutable> + %ptr = "tle.local_pointers"(%smem, %c0) {tle.barrier_group = 6 : i64} : (!ttg.memdesc<4xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + %ptr_next = tt.addptr %ptr, %c1 : !tt.ptr, i32 + // CHECK: tt.store %{{.*}}, %{{.*}} : !tt.ptr + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: %[[L:.*]] = tt.load %{{.*}} : !tt.ptr + tt.store %ptr_next, %c7 : !tt.ptr + %l = tt.load %ptr : !tt.ptr + tt.store %ptr, %l : !tt.ptr + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @single_barrier_for_consecutive_group_loads + tt.func @single_barrier_for_consecutive_group_loads() { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %smem0 = ttg.local_alloc : () -> !ttg.memdesc<1xi32, #shared, #smem, mutable> + %smem1 = ttg.local_alloc : () -> !ttg.memdesc<1xi32, #shared, #smem, mutable> + %ptr0 = "tle.local_pointers"(%smem0, %c0) {tle.barrier_group = 10 : i64} : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + %ptr1 = "tle.local_pointers"(%smem1, %c0) {tle.barrier_group = 11 : i64} : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + tt.store %ptr0, %c0 : !tt.ptr + tt.store %ptr1, %c1 : !tt.ptr + // CHECK: tt.store %{{.*}}, %{{.*}} : !tt.ptr + // CHECK: tt.store %{{.*}}, %{{.*}} : !tt.ptr + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: %[[L0:.*]] = tt.load %{{.*}} : !tt.ptr + // CHECK-NEXT: %[[L1:.*]] = tt.load %{{.*}} : !tt.ptr + // CHECK-NOT: gpu.barrier + %l0 = tt.load %ptr0 : !tt.ptr + %l1 = tt.load %ptr1 : !tt.ptr + %sum = arith.addi %l0, %l1 : i32 + tt.store %ptr0, %sum : !tt.ptr + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp 
= [32], warpsPerCTA = [1], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func private @callee_pointer_arg_barrier + tt.func private @callee_pointer_arg_barrier(%ptr: !tt.ptr) { + %cond = arith.constant true + %v0 = arith.constant 0 : i32 + %v1 = arith.constant 1 : i32 + scf.if %cond { + tt.store %ptr, %v1 : !tt.ptr + } + // CHECK: scf.if %{{.*}} { + // CHECK: tt.store %{{.*}}, %{{.*}} : !tt.ptr + // CHECK: } + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: %[[L:.*]] = tt.load %{{.*}} : !tt.ptr + %l = tt.load %ptr : !tt.ptr + %sum = arith.addi %l, %v0 : i32 + tt.store %ptr, %sum : !tt.ptr + tt.return + } + + // CHECK-LABEL: tt.func @caller_passes_local_pointer + tt.func @caller_passes_local_pointer() { + %c0 = arith.constant 0 : i32 + %smem = ttg.local_alloc : () -> !ttg.memdesc<1xi32, #shared, #smem, mutable> + %ptr = "tle.local_pointers"(%smem, %c0) {tle.barrier_group = 9 : i64} : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + tt.call @callee_pointer_arg_barrier(%ptr) : (!tt.ptr) -> () + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @nested_store_outer_load_barrier + tt.func @nested_store_outer_load_barrier() { + %idx = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked> + %vals = arith.constant dense<7> : tensor<32xi32, #blocked> + %cond = arith.constant true + %smem = ttg.local_alloc : () -> !ttg.memdesc<32xi32, #shared, #smem, mutable> 
+ %ptr = "tle.local_pointers"(%smem, %idx) {tle.barrier_group = 3 : i64} : (!ttg.memdesc<32xi32, #shared, #smem, mutable>, tensor<32xi32, #blocked>) -> tensor<32x!tt.ptr, #blocked> + scf.if %cond { + tt.store %ptr, %vals : tensor<32x!tt.ptr, #blocked> + } + // CHECK: scf.if %{{.*}} { + // CHECK: tt.store %{{.*}}, %{{.*}} : tensor<32x!tt.ptr, #blocked> + // CHECK: } + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: %[[LOAD:.*]] = tt.load %{{.*}} : tensor<32x!tt.ptr, #blocked> + %load = tt.load %ptr : tensor<32x!tt.ptr, #blocked> + tt.store %ptr, %load : tensor<32x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/tle/test/GPU/test_tle_local_ptr_vectorized_load.mlir b/third_party/tle/test/GPU/test_tle_local_ptr_vectorized_load.mlir new file mode 100644 index 0000000000..3c235d402c --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_local_ptr_vectorized_load.mlir @@ -0,0 +1,28 @@ +// RUN: triton-opt %s -pass-pipeline='builtin.module(allocate-shared-memory-nv{compute-capability=120 ptx-version=88}, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=88}, canonicalize, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize, cse, symbol-dce, convert-nvvm-to-llvm)' | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [16, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @local_ptr_v4_load() { + %c0_i32 = arith.constant 0 : i32 + %smem = ttg.local_alloc {tle.barrier_group = 0 : i64} : () -> !ttg.memdesc<4096xi32, #shared, #smem, mutable> + %base = "tle.local_pointers"(%smem, %c0_i32) {tle.barrier_group = 0 : i64} : (!ttg.memdesc<4096xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + %basev = tt.splat 
%base : !tt.ptr -> tensor<512x4x!tt.ptr, #blocked> + + %row = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %row2d = tt.expand_dims %row {axis = 1 : i32} : tensor<512xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<512x1xi32, #blocked> + %rowb = tt.broadcast %row2d : tensor<512x1xi32, #blocked> -> tensor<512x4xi32, #blocked> + %col = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %col2d = tt.expand_dims %col {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x4xi32, #blocked> + %colb = tt.broadcast %col2d : tensor<1x4xi32, #blocked> -> tensor<512x4xi32, #blocked> + %offs = arith.addi %rowb, %colb : tensor<512x4xi32, #blocked> + + %ptrs = tt.addptr %basev, %offs : tensor<512x4x!tt.ptr, #blocked>, tensor<512x4xi32, #blocked> + %vals = tt.load %ptrs : tensor<512x4x!tt.ptr, #blocked> + tt.return + } +} + +// CHECK: ld.shared.v4.b32 diff --git a/third_party/tle/test/GPU/test_tle_lower_exclusive_cumsum.mlir b/third_party/tle/test/GPU/test_tle_lower_exclusive_cumsum.mlir new file mode 100644 index 0000000000..22c8e4c8b5 --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_lower_exclusive_cumsum.mlir @@ -0,0 +1,61 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +// flagtree tle + +// RUN: triton-opt %s -split-input-file -triton-tle-lower-exclusive-cumsum | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @lower_exclusive_cumsum_i32 + tt.func @lower_exclusive_cumsum_i32(%arg0: tensor<16xi32, #blocked>) -> (tensor<16xi32, #blocked>, i32) { + // CHECK: %[[SCAN:.*]] = "tt.scan"(%arg0) + // CHECK: %[[EXCLUSIVE:.*]] = arith.subi %[[SCAN]], %arg0 : tensor<16xi32, #blocked> + // CHECK: %[[RANGE:.*]] = tt.make_range + // CHECK: %[[REDUCE:.*]]:2 = "tt.reduce"(%[[RANGE]], %[[SCAN]]) + // CHECK: arith.cmpi sgt + // CHECK-NOT: "tle.exclusive_cumsum" + // CHECK: tt.return %[[EXCLUSIVE]], %[[REDUCE]]#1 + %exclusive, %total = "tle.exclusive_cumsum"(%arg0) {axis = 0 : i32, reverse = false} : (tensor<16xi32, #blocked>) -> (tensor<16xi32, #blocked>, i32) + tt.return %exclusive, %total : tensor<16xi32, #blocked>, i32 + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @lower_exclusive_cumsum_f32_reverse + tt.func @lower_exclusive_cumsum_f32_reverse(%arg0: tensor<16xf32, #blocked>) -> 
(tensor<16xf32, #blocked>, f32) { + // CHECK: %[[SCAN:.*]] = "tt.scan"(%arg0) <{axis = 0 : i32, reverse = true}> + // CHECK: %[[EXCLUSIVE:.*]] = arith.subf %[[SCAN]], %arg0 : tensor<16xf32, #blocked> + // CHECK: %[[RANGE:.*]] = tt.make_range + // CHECK: %[[REDUCE:.*]]:2 = "tt.reduce"(%[[RANGE]], %[[SCAN]]) + // CHECK: arith.cmpi slt + // CHECK-NOT: "tle.exclusive_cumsum" + // CHECK: tt.return %[[EXCLUSIVE]], %[[REDUCE]]#1 + %exclusive, %total = "tle.exclusive_cumsum"(%arg0) {axis = 0 : i32, reverse = true} : (tensor<16xf32, #blocked>) -> (tensor<16xf32, #blocked>, f32) + tt.return %exclusive, %total : tensor<16xf32, #blocked>, f32 + } +} diff --git a/third_party/tle/test/GPU/test_tle_optimize_exclusive_cumsum_layouts.mlir b/third_party/tle/test/GPU/test_tle_optimize_exclusive_cumsum_layouts.mlir new file mode 100644 index 0000000000..f3d9c7019a --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_optimize_exclusive_cumsum_layouts.mlir @@ -0,0 +1,58 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +// flagtree tle + +// RUN: triton-opt %s -split-input-file -triton-tle-optimize-exclusive-cumsum-layouts | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @fold_cvt_cumsum_cvt + // CHECK: %[[EX:.*]], %[[TOT:.*]] = tle.exclusive_cumsum %arg0 {axis = 0 : i32, reverse = false} + // CHECK-NOT: ttg.convert_layout %[[EX]] + tt.func @fold_cvt_cumsum_cvt(%arg0: tensor<256xi32, #blocked>) -> (tensor<256xi32, #blocked>, i32) { + %0 = ttg.convert_layout %arg0 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> + %exclusive, %total = "tle.exclusive_cumsum"(%0) {axis = 0 : i32, reverse = false} : (tensor<256xi32, #blocked1>) -> (tensor<256xi32, #blocked1>, i32) + %1 = ttg.convert_layout %exclusive : tensor<256xi32, #blocked1> -> tensor<256xi32, #blocked> + tt.return %1, %total : tensor<256xi32, #blocked>, i32 + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @keep_when_non_convert_user_exists + // CHECK: %[[INCVT:.*]] = ttg.convert_layout %arg0 + // CHECK: %[[EX2:.*]], %[[TOT2:.*]] = tle.exclusive_cumsum %[[INCVT]] + // CHECK: 
arith.addi %[[EX2]], %[[EX2]] + tt.func @keep_when_non_convert_user_exists(%arg0: tensor<256xi32, #blocked>) -> (tensor<256xi32, #blocked1>, i32) { + %0 = ttg.convert_layout %arg0 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> + %exclusive, %total = "tle.exclusive_cumsum"(%0) {axis = 0 : i32, reverse = false} : (tensor<256xi32, #blocked1>) -> (tensor<256xi32, #blocked1>, i32) + %1 = arith.addi %exclusive, %exclusive : tensor<256xi32, #blocked1> + tt.return %1, %total : tensor<256xi32, #blocked1>, i32 + } +} diff --git a/third_party/tle/test/GPU/test_tle_reduce_or_bar_red.mlir b/third_party/tle/test/GPU/test_tle_reduce_or_bar_red.mlir new file mode 100644 index 0000000000..48db7e416f --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_reduce_or_bar_red.mlir @@ -0,0 +1,49 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+// +// flagtree tle + +// RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv | FileCheck %s --check-prefix=ALLOC +// RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv --convert-triton-gpu-to-llvm -reconcile-unrealized-casts | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // ALLOC-LABEL: tt.func @reduce_or_bar_red + // ALLOC-NOT: allocation.offset + // CHECK-LABEL: llvm.func @reduce_or_bar_red + // CHECK: llvm.inline_asm + // CHECK-SAME: bar.red.or.pred + tt.func @reduce_or_bar_red(%arg0: tensor<128xi1, #blocked>, %out_ptr: !tt.ptr) { + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%lhs: i1, %rhs: i1): + %1 = arith.ori %lhs, %rhs : i1 + tt.reduce.return %1 : i1 + }) : (tensor<128xi1, #blocked>) -> i1 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> + %3 = arith.constant dense<0> : tensor<128xi32, #blocked> + %4 = arith.cmpi eq, %2, %3 : tensor<128xi32, #blocked> + %5 = tt.splat %out_ptr : !tt.ptr -> tensor<128x!tt.ptr, #blocked> + %6 = tt.splat %0 : i1 -> tensor<128xi1, #blocked> + tt.store %5, %6, %4 : tensor<128x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/tle/test/GPU/test_tle_remote_atomic_splat_ptr_lowering.mlir b/third_party/tle/test/GPU/test_tle_remote_atomic_splat_ptr_lowering.mlir new file mode 100644 index 0000000000..e0e4a94ac7 --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_remote_atomic_splat_ptr_lowering.mlir @@ -0,0 +1,21 @@ +// RUN: triton-opt %s -pass-pipeline='builtin.module(allocate-shared-memory-nv{compute-capability=90 ptx-version=80}, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=90 ptx-version=80})' | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = 
[4], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() { + %c0_i32 = arith.constant 0 : i32 + %ones = arith.constant dense<1> : tensor<128xi32, #blocked> + %pred = arith.constant dense : tensor<128xi1, #blocked> + %smem = ttg.local_alloc : () -> !ttg.memdesc<16xi32, #shared, #smem, mutable> + %counter_local_ptr = "tle.local_pointers"(%smem, %c0_i32) : (!ttg.memdesc<16xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + %counter_remote_ptr = "tle.remote_pointers"(%counter_local_ptr, %c0_i32) : (!tt.ptr, i32) -> !tt.ptr + %counter_ptrs = tt.splat %counter_remote_ptr : !tt.ptr -> tensor<128x!tt.ptr, #blocked> + %old = tt.atomic_rmw add, relaxed, cta, %counter_ptrs, %ones, %pred : (tensor<128x!tt.ptr, #blocked>, tensor<128xi32, #blocked>, tensor<128xi1, #blocked>) -> tensor<128xi32, #blocked> + tt.return + } +} + +// CHECK: atom.shared::cluster.cta.relaxed.add.u32 diff --git a/third_party/tle/test/GPU/test_tle_remote_call_arg_store_lowering.mlir b/third_party/tle/test/GPU/test_tle_remote_call_arg_store_lowering.mlir new file mode 100644 index 0000000000..65f8b78777 --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_remote_call_arg_store_lowering.mlir @@ -0,0 +1,40 @@ +// RUN: triton-opt %s -pass-pipeline='builtin.module(allocate-shared-memory-nv{compute-capability=90 ptx-version=80}, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=90 ptx-version=80})' | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", 
"ttg.threads-per-warp" = 32 : i32} { + tt.func private @callee(%counter_remote_ptr: !tt.ptr, %out_remote_ptr: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %lane = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> + %zeros = arith.constant dense<0> : tensor<128xi32, #blocked> + %ones = arith.constant dense<1> : tensor<128xi32, #blocked> + %pred = arith.constant dense : tensor<128xi1, #blocked> + %counter_ptrs = tt.splat %counter_remote_ptr : !tt.ptr -> tensor<128x!tt.ptr, #blocked> + %counter_ptrs_2 = tt.addptr %counter_ptrs, %zeros : tensor<128x!tt.ptr, #blocked>, tensor<128xi32, #blocked> + %pos = tt.atomic_rmw add, relaxed, cta, %counter_ptrs_2, %ones, %pred : (tensor<128x!tt.ptr, #blocked>, tensor<128xi32, #blocked>, tensor<128xi1, #blocked>) -> tensor<128xi32, #blocked> + %out_ptrs = tt.splat %out_remote_ptr : !tt.ptr -> tensor<128x!tt.ptr, #blocked> + %dst = tt.addptr %out_ptrs, %pos : tensor<128x!tt.ptr, #blocked>, tensor<128xi32, #blocked> + tt.store %dst, %lane : tensor<128x!tt.ptr, #blocked> + tt.return + } + + tt.func public @caller() { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %smem = ttg.local_alloc : () -> !ttg.memdesc<2048xi32, #shared, #smem, mutable> + %counter_local_ptr = "tle.local_pointers"(%smem, %c0_i32) : (!ttg.memdesc<2048xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + %out_local_ptr = "tle.local_pointers"(%smem, %c1_i32) : (!ttg.memdesc<2048xi32, #shared, #smem, mutable>, i32) -> !tt.ptr + %counter_remote_ptr = "tle.remote_pointers"(%counter_local_ptr, %c0_i32) : (!tt.ptr, i32) -> !tt.ptr + %out_remote_ptr = "tle.remote_pointers"(%out_local_ptr, %c0_i32) : (!tt.ptr, i32) -> !tt.ptr + tt.call @callee(%counter_remote_ptr, %out_remote_ptr) : (!tt.ptr, !tt.ptr) -> () + tt.return + } +} + +// CHECK: llvm.func internal @callee(%arg0: !llvm.ptr<7> +// CHECK-SAME: %arg1: !llvm.ptr<7> +// CHECK: atom.shared::cluster.cta.relaxed.add.u32 +// CHECK: st.shared::cluster.b32 +// CHECK: nvvm.mapa 
diff --git a/third_party/tle/test/GPU/test_tle_select_encodings.mlir b/third_party/tle/test/GPU/test_tle_select_encodings.mlir new file mode 100644 index 0000000000..3b735530ce --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_select_encodings.mlir @@ -0,0 +1,137 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+// +// flagtree tle + +// RUN: triton-opt %s -split-input-file -triton-tle-select-encodings | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @atomic_votes_drive_encoding + tt.func @atomic_votes_drive_encoding() { + %idx = arith.constant dense<0> : tensor<32x4xi32, #blocked4> + %ones = arith.constant dense<1> : tensor<32x4xi32, #blocked> + %mask = arith.constant dense : tensor<32x4xi1, #blocked> + %smem = ttg.local_alloc : () -> !ttg.memdesc<1xi32, #shared, #smem, mutable> + %ptrs = "tle.local_pointers"(%smem, %idx) : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, #blocked4>) -> tensor<32x4x!tt.ptr, #blocked4> + %ptrs_blocked = ttg.convert_layout %ptrs : tensor<32x4x!tt.ptr, #blocked4> -> tensor<32x4x!tt.ptr, #blocked> + %old = tt.atomic_rmw add, relaxed, cta, %ptrs_blocked, %ones, %mask : (tensor<32x4x!tt.ptr, #blocked>, tensor<32x4xi32, #blocked>, tensor<32x4xi1, #blocked>) -> tensor<32x4xi32, #blocked> + tt.return + } + // CHECK: %[[A_IDX:.*]] = arith.constant dense<0> : tensor<32x4xi32, #[[A_IDX_ENC:[A-Za-z0-9_]+]]> + // CHECK: %[[A_ONES:.*]] = arith.constant dense<1> : tensor<32x4xi32, #[[A_DATA_ENC:[A-Za-z0-9_]+]]> + // CHECK: %[[A_MASK:.*]] = arith.constant dense : tensor<32x4xi1, #[[A_DATA_ENC]]> + // CHECK: %[[A_IDX_CAST:.*]] = ttg.convert_layout %[[A_IDX]] : tensor<32x4xi32, #[[A_IDX_ENC]]> -> tensor<32x4xi32, #[[A_DATA_ENC]]> + // CHECK: %[[A_PTRS:.*]] = "tle.local_pointers"(%{{.*}}, %[[A_IDX_CAST]]) {{.*}} : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, 
#[[A_DATA_ENC]]>) -> tensor<32x4x!tt.ptr, #[[A_DATA_ENC]]> + // CHECK: tt.atomic_rmw add, relaxed, cta, %{{.*}}, %[[A_ONES]], %[[A_MASK]] : (tensor<32x4x!tt.ptr, #[[A_DATA_ENC]]>, tensor<32x4xi32, #[[A_DATA_ENC]]>, tensor<32x4xi1, #[[A_DATA_ENC]]>) -> tensor<32x4xi32, #[[A_DATA_ENC]]> +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @index_convert_reused + tt.func @index_convert_reused() { + %idx = arith.constant dense<0> : tensor<32x4xi32, #blocked4> + %vals = arith.constant dense<1> : tensor<32x4xi32, #blocked> + %mask = arith.constant dense : tensor<32x4xi1, #blocked> + %smem0 = ttg.local_alloc : () -> !ttg.memdesc<1xi32, #shared, #smem, mutable> + %smem1 = ttg.local_alloc : () -> !ttg.memdesc<1xi32, #shared, #smem, mutable> + %ptr0 = "tle.local_pointers"(%smem0, %idx) : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, #blocked4>) -> tensor<32x4x!tt.ptr, #blocked4> + %ptr1 = "tle.local_pointers"(%smem1, %idx) : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, #blocked4>) -> tensor<32x4x!tt.ptr, #blocked4> + %ptr0_blocked = ttg.convert_layout %ptr0 : tensor<32x4x!tt.ptr, #blocked4> -> tensor<32x4x!tt.ptr, #blocked> + %ptr1_blocked = ttg.convert_layout %ptr1 : tensor<32x4x!tt.ptr, #blocked4> -> tensor<32x4x!tt.ptr, #blocked> + tt.store %ptr0_blocked, %vals, %mask : tensor<32x4x!tt.ptr, #blocked> + tt.store %ptr1_blocked, %vals, %mask : tensor<32x4x!tt.ptr, #blocked> + tt.return + } + // CHECK: %[[B_IDX:.*]] = arith.constant dense<0> : tensor<32x4xi32, 
#[[B_IDX_ENC:[A-Za-z0-9_]+]]> + // CHECK: %[[B_VALS:.*]] = arith.constant dense<1> : tensor<32x4xi32, #[[B_DATA_ENC:[A-Za-z0-9_]+]]> + // CHECK: %[[B_MASK:.*]] = arith.constant dense : tensor<32x4xi1, #[[B_DATA_ENC]]> + // CHECK: %[[B_IDX_CAST:.*]] = ttg.convert_layout %[[B_IDX]] : tensor<32x4xi32, #[[B_IDX_ENC]]> -> tensor<32x4xi32, #[[B_DATA_ENC]]> + // CHECK: %[[B_PTR0:.*]] = "tle.local_pointers"(%{{.*}}, %[[B_IDX_CAST]]) {{.*}} : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, #[[B_DATA_ENC]]>) -> tensor<32x4x!tt.ptr, #[[B_DATA_ENC]]> + // CHECK: %[[B_PTR1:.*]] = "tle.local_pointers"(%{{.*}}, %[[B_IDX_CAST]]) {{.*}} : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, #[[B_DATA_ENC]]>) -> tensor<32x4x!tt.ptr, #[[B_DATA_ENC]]> + // CHECK-NOT: ttg.convert_layout %[[B_IDX]] + // CHECK: tt.store %{{.*}}, %[[B_VALS]], %[[B_MASK]] : tensor<32x4x!tt.ptr, #[[B_DATA_ENC]]> + // CHECK: tt.store %{{.*}}, %[[B_VALS]], %[[B_MASK]] : tensor<32x4x!tt.ptr, #[[B_DATA_ENC]]> +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @scalar_local_ptr_prefers_atomic_mask_encoding + tt.func @scalar_local_ptr_prefers_atomic_mask_encoding() { + %idx = arith.constant dense<0> : tensor<32x4xi32, #blocked> + %ones = arith.constant dense<1> : tensor<32x4xi32, #blocked> + %mask_seed = arith.constant dense<0> : tensor<32x4xi32, #blocked4> + %mask = arith.cmpi eq, %mask_seed, %mask_seed : tensor<32x4xi32, #blocked4> + %smem = ttg.local_alloc : () -> !ttg.memdesc<1xi32, #shared, #smem, mutable> + %ptr = 
"tle.local_pointers"(%smem, %idx) : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, #blocked>) -> tensor<32x4x!tt.ptr, #blocked> + %ptr_b4 = ttg.convert_layout %ptr : tensor<32x4x!tt.ptr, #blocked> -> tensor<32x4x!tt.ptr, #blocked4> + %ones_b4 = ttg.convert_layout %ones : tensor<32x4xi32, #blocked> -> tensor<32x4xi32, #blocked4> + %old = tt.atomic_rmw add, relaxed, cta, %ptr_b4, %ones_b4, %mask : (tensor<32x4x!tt.ptr, #blocked4>, tensor<32x4xi32, #blocked4>, tensor<32x4xi1, #blocked4>) -> tensor<32x4xi32, #blocked4> + tt.return + } + // CHECK: %[[IDX:.*]] = arith.constant dense<0> : tensor<32x4xi32, #[[IDX_ENC:[A-Za-z0-9_]+]]> + // CHECK: %[[MASK_SEED:.*]] = arith.constant dense<0> : tensor<32x4xi32, #[[MASK_ENC:[A-Za-z0-9_]+]]> + // CHECK: %[[MASK:.*]] = arith.cmpi eq, %[[MASK_SEED]], %[[MASK_SEED]] : tensor<32x4xi32, #[[MASK_ENC]]> + // CHECK: %[[IDX_CAST:.*]] = ttg.convert_layout %[[IDX]] : tensor<32x4xi32, #[[IDX_ENC]]> -> tensor<32x4xi32, #[[MASK_ENC]]> + // CHECK: %[[PTR:.*]] = "tle.local_pointers"(%{{.*}}, %[[IDX_CAST]]) {{.*}} : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, #[[MASK_ENC]]>) -> tensor<32x4x!tt.ptr, #[[MASK_ENC]]> + // CHECK: tt.atomic_rmw add, relaxed, cta, %{{.*}}, %{{.*}}, %[[MASK]] : (tensor<32x4x!tt.ptr, #[[MASK_ENC]]>, tensor<32x4xi32, #[[MASK_ENC]]>, tensor<32x4xi1, #[[MASK_ENC]]>) -> tensor<32x4xi32, #[[MASK_ENC]]> +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @fold_pointer_convert_for_atomic_users + tt.func 
@fold_pointer_convert_for_atomic_users() { + %idx = arith.constant dense<0> : tensor<32x4xi32, #blocked> + %vals = arith.constant dense<1> : tensor<32x4xi32, #blocked4> + %mask = arith.constant dense : tensor<32x4xi1, #blocked4> + %smem = ttg.local_alloc : () -> !ttg.memdesc<1xi32, #shared, #smem, mutable> + %ptr = "tle.local_pointers"(%smem, %idx) : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, #blocked>) -> tensor<32x4x!tt.ptr, #blocked> + %ptr_b4 = ttg.convert_layout %ptr : tensor<32x4x!tt.ptr, #blocked> -> tensor<32x4x!tt.ptr, #blocked4> + %old = tt.atomic_rmw add, relaxed, cta, %ptr_b4, %vals, %mask : (tensor<32x4x!tt.ptr, #blocked4>, tensor<32x4xi32, #blocked4>, tensor<32x4xi1, #blocked4>) -> tensor<32x4xi32, #blocked4> + tt.return + } + // CHECK: %[[PTR:.*]] = "tle.local_pointers"(%{{.*}}, %{{.*}}) {{.*}} : (!ttg.memdesc<1xi32, #shared, #smem, mutable>, tensor<32x4xi32, #[[P_ENC:[A-Za-z0-9_]+]]>) -> tensor<32x4x!tt.ptr, #[[P_ENC]]> + // CHECK-NOT: ttg.convert_layout %[[PTR]] : tensor<32x4x!tt.ptr, #[[P_ENC]]> + // CHECK: tt.atomic_rmw add, relaxed, cta, %[[PTR]], %{{.*}}, %{{.*}} : (tensor<32x4x!tt.ptr, #[[P_ENC]]>, tensor<32x4xi32, #[[P_ENC]]>, tensor<32x4xi1, #[[P_ENC]]>) -> tensor<32x4xi32, #[[P_ENC]]> +} diff --git a/third_party/tle/test/GPU/test_tle_tritongpu_to_llvm_exclusive_cumsum.mlir b/third_party/tle/test/GPU/test_tle_tritongpu_to_llvm_exclusive_cumsum.mlir new file mode 100644 index 0000000000..1ddabaf3ba --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_tritongpu_to_llvm_exclusive_cumsum.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s -pass-pipeline='builtin.module(allocate-shared-memory-nv{compute-capability=120 ptx-version=88}, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=88}, canonicalize, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize, cse, symbol-dce, convert-nvvm-to-llvm)' | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = 
[1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @cumsum_to_llvm(%arg0: tensor<128xi32, #blocked>, %out: !tt.ptr) { + %exclusive, %total = "tle.exclusive_cumsum"(%arg0) {axis = 0 : i32, reverse = false} : (tensor<128xi32, #blocked>) -> (tensor<128xi32, #blocked>, i32) + tt.store %out, %total : !tt.ptr + tt.return + } +} + +// CHECK: llvm.func @cumsum_to_llvm +// CHECK-NOT: tle.exclusive_cumsum diff --git a/third_party/tle/test/GPU/test_tle_tritontoTritonGPU_exclusive_cumsum.mlir b/third_party/tle/test/GPU/test_tle_tritontoTritonGPU_exclusive_cumsum.mlir new file mode 100644 index 0000000000..7393e6c413 --- /dev/null +++ b/third_party/tle/test/GPU/test_tle_tritontoTritonGPU_exclusive_cumsum.mlir @@ -0,0 +1,36 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// flagtree tle + +// RUN: triton-opt %s -convert-triton-to-tritongpu='target=cuda:80 num-warps=4' | FileCheck %s + +// CHECK: tt.func @exclusive_cumsum_type_conversion() -> i32 +// CHECK: %[[IN:.*]] = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> +// CHECK: %[[EX:.*]], %[[TOT:.*]] = tle.exclusive_cumsum %[[IN]] {axis = 0 : i32, reverse = false} : tensor<512xi32, #blocked> -> tensor<512xi32, #blocked>, i32 +module { + tt.func @exclusive_cumsum_type_conversion() -> i32 { + %in = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> + %exclusive, %total = "tle.exclusive_cumsum"(%in) {axis = 0 : i32, reverse = false} : (tensor<512xi32>) -> (tensor<512xi32>, i32) + tt.return %total : i32 + } +} diff --git a/third_party/tle/triton_tle.cc b/third_party/tle/triton_tle.cc index 555ecee631..d79d0d5f58 100644 --- a/third_party/tle/triton_tle.cc +++ b/third_party/tle/triton_tle.cc @@ -181,6 +181,14 @@ void init_triton_tle_ir(py::module &&m) { return self.create(resultTy, memDesc, indices); }) + .def("create_exclusive_cumsum", + [](TritonOpBuilder &self, Type exclusiveTy, Type totalTy, Value src, + int axis, bool reverse) -> OpState { + auto &builder = self.getBuilder(); + return self.create( + TypeRange{exclusiveTy, totalTy}, src, + builder.getI32IntegerAttr(axis), builder.getBoolAttr(reverse)); + }) .def("create_distributed_barrier", [](TritonOpBuilder &self) -> void { self.create( @@ -252,10 +260,21 @@ void init_triton_tle_ir(py::module &&m) { void init_triton_tle_passes(py::module &&m) { ADD_PASS_WRAPPER_0("add_early_assign_memory_space", tle::createTritonTleEarlyAssignMemorySpace); + ADD_PASS_WRAPPER_0("add_select_encodings", + 
tle::createTritonTleSelectEncodings); + // Backward-compatible alias. ADD_PASS_WRAPPER_0("add_assign_local_pointers_encoding", - tle::createTritonTleAssignLocalPointersEncoding); + tle::createTritonTleSelectEncodings); ADD_PASS_WRAPPER_0("add_insert_local_pointer_barriers", tle::createTritonTleInsertLocalPointerBarriers); + ADD_PASS_WRAPPER_0("add_optimize_local_pointer_loads", + tle::createTritonTleOptimizeLocalPointerLoads); + ADD_PASS_WRAPPER_0("add_optimize_local_pointer_stores", + tle::createTritonTleOptimizeLocalPointerStores); + ADD_PASS_WRAPPER_0("add_optimize_exclusive_cumsum_layouts", + tle::createTritonTleOptimizeExclusiveCumsumLayouts); + ADD_PASS_WRAPPER_0("add_lower_exclusive_cumsum", + tle::createTritonTleLowerExclusiveCumsum); ADD_PASS_WRAPPER_0("add_lower_async_load", tle::createTritonTleLowerAsyncLoad); ADD_PASS_WRAPPER_0("add_lower_tma_copy", tle::createTritonTleLowerTmaCopy); diff --git a/third_party/triton_shared b/third_party/triton_shared new file mode 160000 index 0000000000..1c203ca99e --- /dev/null +++ b/third_party/triton_shared @@ -0,0 +1 @@ +Subproject commit 1c203ca99e7ea6c364e02c9b0a8a73df948193f9 diff --git a/tle.md b/tle.md index dcccbfa3ed..c8fb499f15 100644 --- a/tle.md +++ b/tle.md @@ -500,10 +500,10 @@ a_smem_ptrs = tle.gpu.local_ptr( ) ``` -- Signature: `tle.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr` -- Purpose: Build arbitrary-shaped pointer views over shared memory buffer for `tl.load/tl.store`. +- Signature: `tle.gpu.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr` +- Purpose: Build arbitrary-shaped pointer views over shared memory buffers for `tl.load/tl.store/tl.atomic*`. - Parameters: - - `buffer`: buffered tensor returned by `tle.alloc` (SMEM/TMEM). + - `buffer`: buffered tensor returned by `tle.gpu.alloc` (SMEM/TMEM). - `indices`: optional tuple of integer tensors. Tuple length must equal `rank(buffer)`, and all tensors must have identical shapes. If omitted/`None`, backend treats it as full indices. 
- Semantics: - If `indices` is provided: output pointer tensor shape equals common shape of index tensors. @@ -515,82 +515,287 @@ a_smem_ptrs = tle.gpu.local_ptr( Example 1: 1D slice ```python -smem = tle.alloc([BLOCK], dtype=tl.float32, scope=tle.smem) +smem = tle.gpu.alloc([BLOCK], dtype=tl.float32, scope=tle.gpu.smem) # Slice [offset, offset + SLICE) idx = offset + tl.arange(0, SLICE) -slice_ptr = tle.local_ptr(smem, (idx,)) +slice_ptr = tle.gpu.local_ptr(smem, (idx,)) vals = tl.load(slice_ptr) ``` Example 2: K-dimension tiling (matrix slice) ```python -smem_a = tle.alloc([BM, BK], dtype=tl.float16, scope=tle.smem) +smem_a = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.smem) # Slice (BM, KW), where KW is K-dimension slice rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, KW)) cols = tl.broadcast_to(tl.arange(0, KW)[None, :] + k_start, (BM, KW)) -a_slice = tle.local_ptr(smem_a, (rows, cols)) +a_slice = tle.gpu.local_ptr(smem_a, (rows, cols)) a_vals = tl.load(a_slice) ``` Example 3: arbitrary gather view ```python -smem = tle.alloc([H, W], dtype=tl.float32, scope=tle.smem) +smem = tle.gpu.alloc([H, W], dtype=tl.float32, scope=tle.gpu.smem) # Take an offset column per row rows = tl.broadcast_to(tl.arange(0, H)[:, None], (H, SLICE)) cols = tl.broadcast_to(1 + tl.arange(0, SLICE)[None, :], (H, SLICE)) -gather_ptr = tle.local_ptr(smem, (rows, cols)) +gather_ptr = tle.gpu.local_ptr(smem, (rows, cols)) out = tl.load(gather_ptr) ``` -###### 3.3.1.1.4 `tle.gpu.copy` +Supported downstream ops: -Memory copy: +- `tl.load` +- `tl.store` +- `tl.atomic_add/and/cas/max/min/or/xchg/xor` + +Practical notes: + +- Atomic ops require element dtype/backend support; use integer/float types supported by target hardware. +- For local-pointer load-after-store hazards, TLE backend pass `TleInsertLocalPointerBarriers` inserts barriers automatically; add manual barriers only for custom synchronization patterns outside pass coverage. 
+ +Example 4: load/store/atomic on the same `local_ptr` ```python -tle.gpu.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK]) +smem_i32 = tle.gpu.alloc([BLOCK], dtype=tl.int32, scope=tle.gpu.smem) +ptr = tle.gpu.local_ptr(smem_i32, (tl.arange(0, BLOCK),)) + +tl.store(ptr, tl.zeros([BLOCK], dtype=tl.int32)) +tl.atomic_add(ptr, 1) +vals = tl.load(ptr) ``` -###### 3.3.1.1.5 `tl.load/tl.store/tl.atomic*` for `tle.local_ptr` +###### 3.3.1.1.4 `tle.gpu.local_ptr` (for remote) -Shared-memory pointers from `tle.local_ptr` can be directly used by: +- Signature: `tle.gpu.local_ptr(remote_buffer, indices=None) -> tl.tensor | tl.ptr` +- Purpose: materialize pointer views for remote shared/local buffers returned by `tle.remote(...)`. +- Inputs: + - `remote_buffer`: result of `tle.remote(buffer, shard_id, scope)`, where `buffer` is typically allocated by `tle.gpu.alloc`. + - `indices`: same rules as local mode (`None` for full view, or tuple of integer tensors with identical shapes). +- Semantics: + - Pointer shape/linearization rules are identical to local `tle.gpu.local_ptr`. + - Address resolution targets the remote shard selected by `shard_id`. + - Use `tle.distributed_barrier(...)` when cross-shard producer/consumer ordering is required. 
-- `tl.load` -- `tl.store` -- `tl.atomic_add/and/cas/max/min/or/xchg/xor` +Example: read remote SMEM tile from neighbor shard + +```python +smem = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.storage_kind.smem) +remote_smem = tle.remote(smem, shard_id=(node_rank, next_device), scope=mesh) + +rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK)) +cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK)) +remote_ptr = tle.gpu.local_ptr(remote_smem, (rows, cols)) + +vals = tl.load(remote_ptr) +``` + +###### 3.3.1.1.5 `tle.gpu.copy` + +Memory copy: + +```python +tle.gpu.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK]) +``` #### 3.3.2 DSA -##### 3.3.2.1 Memory Management +This section is rewritten from `triton_v3.2.x` (`python/triton/experimental/tle/language/dsa` and its README). +DSA APIs are split into: + +- Generic DSA APIs under `tle.dsa.*` +- Backend-specific address spaces under `tle.dsa.ascend.*` + +##### 3.3.2.1 Memory and Data Movement ###### 3.3.2.1.1 `tle.dsa.alloc` -Allocate memory (Ascend): +- Signature: `tle.dsa.alloc(shape, dtype, mem_addr_space)` +- Purpose: allocate DSA local buffers in a target memory space. + +Ascend memory spaces exposed in source: + +- `tle.dsa.ascend.UB` +- `tle.dsa.ascend.L1` +- `tle.dsa.ascend.L0A` +- `tle.dsa.ascend.L0B` +- `tle.dsa.ascend.L0C` ```python -a_ub = tle.dsa.alloc( - [XBLOCK, YBLOCK], - dtype=tl.float32, - layout=tle.dsa.ascend.NZ, - scope=tle.dsa.ascend.UB, -) +a_ub = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) +b_l1 = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.L1) ``` ###### 3.3.2.1.2 `tle.dsa.copy` -Memory copy: +- Signature: `tle.dsa.copy(src, dst, shape, inter_no_alias=False)` +- Purpose: explicit movement between GMEM pointers and DSA local buffers (both directions). 
+ +```python +tle.dsa.copy(x_ptrs, a_ub, [tail_m, tail_n]) # GMEM -> local buffer +tle.dsa.copy(a_ub, out_ptrs, [tail_m, tail_n]) # local buffer -> GMEM +``` + +###### 3.3.2.1.3 `tle.dsa.local_ptr` + +- Signature: `tle.dsa.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr` +- Purpose: build pointer views over DSA local buffers (for example UB/L1) for explicit local-memory access patterns. +- Parameters: + - `buffer`: DSA buffered tensor, typically from `tle.dsa.alloc`. + - `indices`: optional tuple of integer tensors. If omitted/`None`, backend treats it as full indices. +- Semantics: + - Shape and indexing behavior follow `tle.gpu.local_ptr` (same pointer-view model). + - Intended for DSA-local data access paths that require explicit pointer materialization. + +Example: + +```python +a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK)) +cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK)) +a_ptr = tle.dsa.local_ptr(a_ub, (rows, cols)) +a_val = tl.load(a_ptr) +``` + +###### 3.3.2.1.4 `tle.dsa.local_ptr` (for remote) + +- Signature: `tle.dsa.local_ptr(remote_buffer, indices=None) -> tl.tensor | tl.ptr` +- Purpose: materialize pointer views over remote DSA local buffers obtained from `tle.remote(...)`. +- Inputs: + - `remote_buffer`: result of `tle.remote(dsa_buffer, shard_id, scope)`. + - `indices`: same rules as local DSA mode. +- Semantics: + - Same pointer-view semantics as local DSA mode. + - Pointer dereference is routed to the remote shard selected by `shard_id`. + - Pair with `tle.distributed_barrier` when cross-shard ordering is required. 
+ +Example: + +```python +a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +remote_a_ub = tle.remote(a_ub, shard_id=peer_rank, scope=mesh) + +rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK)) +cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK)) +remote_ptr = tle.dsa.local_ptr(remote_a_ub, (rows, cols)) +remote_val = tl.load(remote_ptr) +``` + +###### 3.3.2.1.5 `tle.dsa.to_tensor` / `tle.dsa.to_buffer` + +- `tle.dsa.to_tensor(buffer, writable=True)`: convert a DSA buffer to a tensor view for tensor expressions. +- `tle.dsa.to_buffer(tensor, space)`: convert a tensor value back to a buffer in a target DSA address space. + +```python +c_val = tle.dsa.to_tensor(c_ub, writable=True) +result = c_val * 0.5 +d_ub = tle.dsa.to_buffer(result, tle.dsa.ascend.UB) +tle.dsa.copy(d_ub, out_ptrs, [tail_m, tail_n]) +``` + +##### 3.3.2.2 Elementwise Compute Ops (buffer-based) + +Builtins provided by source: + +- `tle.dsa.add` +- `tle.dsa.sub` +- `tle.dsa.mul` +- `tle.dsa.div` +- `tle.dsa.max` +- `tle.dsa.min` + +- Common signature: `tle.dsa.<op>(lhs, rhs, out)` +- Compute model: elementwise binary op over DSA local buffers. +- Shape rules: + - `lhs`, `rhs`, `out` must have the same rank and shape. + - No implicit broadcast is assumed in this API layer. +- Dtype rules: + - Three operands should use the same dtype in practice. + - Integer dtypes are typical for index/count paths; float dtypes are typical for activation/math paths. +- Memory-space rules: + - Buffers should be allocated in compatible DSA local spaces (for example UB/L1 combinations allowed by backend). + - Keep hot operands/results in local space to avoid extra GMEM traffic.
+ +Per-op semantics: + +- `tle.dsa.add(lhs, rhs, out)`: `out = lhs + rhs` +- `tle.dsa.sub(lhs, rhs, out)`: `out = lhs - rhs` +- `tle.dsa.mul(lhs, rhs, out)`: `out = lhs * rhs` +- `tle.dsa.div(lhs, rhs, out)`: `out = lhs / rhs` (backend-dependent precision/rounding) +- `tle.dsa.max(lhs, rhs, out)`: `out = max(lhs, rhs)` +- `tle.dsa.min(lhs, rhs, out)`: `out = min(lhs, rhs)` + +In-place usage: + +- You can reuse the same output buffer across steps, for example `tle.dsa.mul(tmp, b, tmp)`. +- Avoid aliasing inputs/outputs unless backend semantics explicitly allow it. + +Example 1: arithmetic chain `((a - b) * b) / scale` ```python -tle.dsa.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK]) +a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +b_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +scale_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +tmp_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +out_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) + +tle.dsa.copy(a_ptrs, a_ub, [BM, BK]) +tle.dsa.copy(b_ptrs, b_ub, [BM, BK]) +tle.dsa.copy(scale_ptrs, scale_ub, [BM, BK]) + +tle.dsa.sub(a_ub, b_ub, tmp_ub) # tmp = a - b +tle.dsa.mul(tmp_ub, b_ub, tmp_ub) # tmp = tmp * b +tle.dsa.div(tmp_ub, scale_ub, out_ub) # out = tmp / scale + +tle.dsa.copy(out_ub, out_ptrs, [BM, BK]) ``` -###### 3.3.2.1.3 `tle.dsa.local_load` +Example 2: clamp by `max` + `min` + +```python +x_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +floor_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +ceil_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +tmp_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +y_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) + 
+tle.dsa.copy(x_ptrs, x_ub, [BM, BK]) +tle.dsa.copy(floor_ptrs, floor_ub, [BM, BK]) +tle.dsa.copy(ceil_ptrs, ceil_ub, [BM, BK]) + +tle.dsa.max(x_ub, floor_ub, tmp_ub) # tmp = max(x, floor) +tle.dsa.min(tmp_ub, ceil_ub, y_ub) # y = min(tmp, ceil) -Load from local memory: +tle.dsa.copy(y_ub, y_ptrs, [BM, BK]) +``` + +##### 3.3.2.3 Loop and Hint APIs + +Source includes: + +- `tle.dsa.pipeline(...)` +- `tle.dsa.parallel(...)` +- `tle.dsa.hint(...)` (used as `with tle.dsa.hint(...)` compile-time hints) ```python -aval = tle.dsa.local_load(a_smem) +with tle.dsa.hint(inter_no_alias=True): + tle.dsa.copy(x_ptr + offs, a_ub, [tail_size], inter_no_alias=True) +``` + +##### 3.3.2.4 Slice/View Utilities + +Source includes: + +- `tle.dsa.extract_slice` +- `tle.dsa.insert_slice` +- `tle.dsa.extract_element` +- `tle.dsa.subview` + +```python +sub = tle.dsa.extract_slice(full, offsets=(0, k0), sizes=(BM, BK), strides=(1, 1)) +full = tle.dsa.insert_slice(full, sub, offsets=(0, k0), sizes=(BM, BK), strides=(1, 1)) +elem = tle.dsa.extract_element(sub, indice=(i, j)) ``` #### 3.3.3 Struct API Cookbook @@ -625,14 +830,22 @@ count_ptr = tle.gpu.local_ptr(counts, (idx,)) tl.atomic_add(count_ptr, 1) ``` -##### 3.3.3.3 DSA local-buffer flow (`dsa.alloc` + `dsa.copy` + `dsa.local_load`) +##### 3.3.3.3 DSA local-buffer flow (`dsa.alloc` + `dsa.copy` + `dsa.to_tensor/to_buffer`) Use this for DSA backends that expose dedicated local buffer spaces. 
```python -a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, layout=tle.dsa.ascend.NZ, scope=tle.dsa.ascend.UB) +a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +b_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +c_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) + tle.dsa.copy(a_ptrs, a_ub, [BM, BK]) -a_val = tle.dsa.local_load(a_ub) +tle.dsa.copy(b_ptrs, b_ub, [BM, BK]) +tle.dsa.add(a_ub, b_ub, c_ub) + +c_val = tle.dsa.to_tensor(c_ub, writable=True) +out_ub = tle.dsa.to_buffer(c_val, tle.dsa.ascend.UB) +tle.dsa.copy(out_ub, out_ptrs, [BM, BK]) ``` ### 3.4 TLE-Raw diff --git a/tle_cn.md b/tle_cn.md index 065cc8b4b1..1379ea9ff2 100644 --- a/tle_cn.md +++ b/tle_cn.md @@ -494,10 +494,10 @@ a_smem_ptrs = tle.gpu.local_ptr( ) ``` -- Signature: `tle.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr` -- Purpose: 在 shared memory buffer 上构建任意形状 pointer view,用于 `tl.load/tl.store`。 +- Signature: `tle.gpu.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr` +- Purpose: 在 shared memory buffer 上构建任意形状 pointer view,可用于 `tl.load/tl.store/tl.atomic*`。 - Parameters: - - `buffer`: 由 `tle.alloc` 返回的 buffered_tensor(SMEM/TMEM)。 + - `buffer`: 由 `tle.gpu.alloc` 返回的 buffered_tensor(SMEM/TMEM)。 - `indices`: 可选整数 tensor 元组,长度必须等于 `rank(buffer)`,且每个 tensor 形状相同;若省略/传 `None`,由后端按 full indices 语义处理。 - Semantics: - 当显式传入 `indices` 时,输出 pointer tensor 形状等于 indices 的公共形状。 @@ -509,82 +509,287 @@ a_smem_ptrs = tle.gpu.local_ptr( Example 1:1D slice ```python -smem = tle.alloc([BLOCK], dtype=tl.float32, scope=tle.smem) +smem = tle.gpu.alloc([BLOCK], dtype=tl.float32, scope=tle.gpu.smem) # Slice [offset, offset + SLICE) idx = offset + tl.arange(0, SLICE) -slice_ptr = tle.local_ptr(smem, (idx,)) +slice_ptr = tle.gpu.local_ptr(smem, (idx,)) vals = tl.load(slice_ptr) ``` Example 2:K 维切片(矩阵) ```python -smem_a = tle.alloc([BM, BK], dtype=tl.float16, scope=tle.smem) +smem_a = tle.gpu.alloc([BM, BK], 
dtype=tl.float16, scope=tle.gpu.smem) # Slice (BM, KW), KW 是 K 维子切片 rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, KW)) cols = tl.broadcast_to(tl.arange(0, KW)[None, :] + k_start, (BM, KW)) -a_slice = tle.local_ptr(smem_a, (rows, cols)) +a_slice = tle.gpu.local_ptr(smem_a, (rows, cols)) a_vals = tl.load(a_slice) ``` Example 3:任意 gather view ```python -smem = tle.alloc([H, W], dtype=tl.float32, scope=tle.smem) +smem = tle.gpu.alloc([H, W], dtype=tl.float32, scope=tle.gpu.smem) # 每行取偏移列 rows = tl.broadcast_to(tl.arange(0, H)[:, None], (H, SLICE)) cols = tl.broadcast_to(1 + tl.arange(0, SLICE)[None, :], (H, SLICE)) -gather_ptr = tle.local_ptr(smem, (rows, cols)) +gather_ptr = tle.gpu.local_ptr(smem, (rows, cols)) out = tl.load(gather_ptr) ``` -###### 3.3.1.1.4 `tle.gpu.copy` +支持的下游操作: -内存拷贝: +- `tl.load` +- `tl.store` +- `tl.atomic_add/and/cas/max/min/or/xchg/xor` + +实践说明: + +- 原子操作是否可用取决于元素 dtype 和后端硬件能力,建议优先使用目标硬件已验证支持的整数/浮点类型。 +- 对于 local_ptr 的 load-after-store hazard,TLE 后端 pass `TleInsertLocalPointerBarriers` 会自动插入 barrier;只有在超出该 pass 覆盖范围的自定义同步模式下,才需要手动加 barrier。 + +Example 4:同一 `local_ptr` 上执行 load/store/atomic ```python -tle.gpu.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK]) +smem_i32 = tle.gpu.alloc([BLOCK], dtype=tl.int32, scope=tle.gpu.smem) +ptr = tle.gpu.local_ptr(smem_i32, (tl.arange(0, BLOCK),)) + +tl.store(ptr, tl.zeros([BLOCK], dtype=tl.int32)) +tl.atomic_add(ptr, 1) +vals = tl.load(ptr) ``` -###### 3.3.1.1.5 `tl.load/tl.store/tl.atomic*` for `tle.local_ptr` +###### 3.3.1.1.4 `tle.gpu.local_ptr`(for remote) -`tle.local_ptr` 返回的 Shared Memory 指针可直接用于: +- Signature: `tle.gpu.local_ptr(remote_buffer, indices=None) -> tl.tensor | tl.ptr` +- 用途:对 `tle.remote(...)` 返回的远端 shared/local buffer 构建指针视图。 +- 输入: + - `remote_buffer`:由 `tle.remote(buffer, shard_id, scope)` 返回,`buffer` 通常来自 `tle.gpu.alloc`。 + - `indices`:与本地模式一致(`None` 代表 full-view,或传入同形状整数 tensor 元组)。 +- 语义: + - 指针形状、索引和线性化规则与本地 `tle.gpu.local_ptr` 完全一致。 + - 地址解析会路由到 
`shard_id` 指定的远端分片。 + - 跨分片读写若需要顺序保证,需配合 `tle.distributed_barrier(...)`。 -- `tl.load` -- `tl.store` -- `tl.atomic_add/and/cas/max/min/or/xchg/xor` +Example:读取邻居分片上的远端 SMEM tile + +```python +smem = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.storage_kind.smem) +remote_smem = tle.remote(smem, shard_id=(node_rank, next_device), scope=mesh) + +rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK)) +cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK)) +remote_ptr = tle.gpu.local_ptr(remote_smem, (rows, cols)) + +vals = tl.load(remote_ptr) +``` + +###### 3.3.1.1.5 `tle.gpu.copy` + +内存拷贝: + +```python +tle.gpu.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK]) +``` #### 3.3.2 DSA -##### 3.3.2.1 内存管理 +本节基于 `triton_v3.2.x` 中 `python/triton/experimental/tle/language/dsa` 及其 README 重写。 +DSA API 分为两层: + +- 通用 DSA API:`tle.dsa.*` +- 后端特定地址空间:`tle.dsa.ascend.*` + +##### 3.3.2.1 内存与数据搬运 ###### 3.3.2.1.1 `tle.dsa.alloc` -分配内存(Ascend): +- Signature: `tle.dsa.alloc(shape, dtype, mem_addr_space)` +- 用途:在目标地址空间分配 DSA 本地 buffer。 + +源码中 Ascend 暴露的地址空间: + +- `tle.dsa.ascend.UB` +- `tle.dsa.ascend.L1` +- `tle.dsa.ascend.L0A` +- `tle.dsa.ascend.L0B` +- `tle.dsa.ascend.L0C` ```python -a_ub = tle.dsa.alloc( - [XBLOCK, YBLOCK], - dtype=tl.float32, - layout=tle.dsa.ascend.NZ, - scope=tle.dsa.ascend.UB, -) +a_ub = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) +b_l1 = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.L1) ``` ###### 3.3.2.1.2 `tle.dsa.copy` -内存拷贝: +- Signature: `tle.dsa.copy(src, dst, shape, inter_no_alias=False)` +- 用途:在 GMEM 指针与 DSA 本地 buffer 之间做显式搬运(双向)。 + +```python +tle.dsa.copy(x_ptrs, a_ub, [tail_m, tail_n]) # GMEM -> 本地 buffer +tle.dsa.copy(a_ub, out_ptrs, [tail_m, tail_n]) # 本地 buffer -> GMEM +``` + +###### 3.3.2.1.3 `tle.dsa.local_ptr` + +- Signature: `tle.dsa.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr` +- 用途:在 DSA 本地 buffer(如 
UB/L1)上构建指针视图,用于显式本地访存路径。 +- 参数: + - `buffer`:DSA buffered tensor,通常由 `tle.dsa.alloc` 分配。 + - `indices`:可选整数 tensor 元组;省略/传 `None` 时按 full indices 语义处理。 +- 语义: + - 指针视图模型与 `tle.gpu.local_ptr` 一致(形状和索引规则相同)。 + - 适用于需要显式 materialize 指针的 DSA 本地访问流程。 + +Example: + +```python +a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK)) +cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK)) +a_ptr = tle.dsa.local_ptr(a_ub, (rows, cols)) +a_val = tl.load(a_ptr) +``` + +###### 3.3.2.1.4 `tle.dsa.local_ptr`(for remote) + +- Signature: `tle.dsa.local_ptr(remote_buffer, indices=None) -> tl.tensor | tl.ptr` +- 用途:对 `tle.remote(...)` 返回的远端 DSA 本地 buffer 构建指针视图。 +- 输入: + - `remote_buffer`:由 `tle.remote(dsa_buffer, shard_id, scope)` 返回。 + - `indices`:与本地 DSA 模式一致。 +- 语义: + - 与本地 DSA 模式保持相同的指针视图规则。 + - 指针解引用会路由到 `shard_id` 指定的远端分片。 + - 需要跨分片时序保证时,配合 `tle.distributed_barrier` 使用。 + +Example: + +```python +a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +remote_a_ub = tle.remote(a_ub, shard_id=peer_rank, scope=mesh) + +rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK)) +cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK)) +remote_ptr = tle.dsa.local_ptr(remote_a_ub, (rows, cols)) +remote_val = tl.load(remote_ptr) +``` + +###### 3.3.2.1.5 `tle.dsa.to_tensor` / `tle.dsa.to_buffer` + +- `tle.dsa.to_tensor(buffer, writable=True)`:把 DSA buffer 转成 tensor 视图以参与 tensor 表达式。 +- `tle.dsa.to_buffer(tensor, space)`:把 tensor 值转回指定地址空间的 DSA buffer。 + +```python +c_val = tle.dsa.to_tensor(c_ub, writable=True) +result = c_val * 0.5 +d_ub = tle.dsa.to_buffer(result, tle.dsa.ascend.UB) +tle.dsa.copy(d_ub, out_ptrs, [tail_m, tail_n]) +``` + +##### 3.3.2.2 向量算子(buffer 形态) + +源码内置: + +- `tle.dsa.add` +- `tle.dsa.sub` +- `tle.dsa.mul` +- `tle.dsa.div` +- `tle.dsa.max` +- `tle.dsa.min` + +- 通用签名:`tle.dsa.(lhs, rhs, out)` +- 计算模型:对 DSA 本地 buffer 做逐元素二元运算。 +- 形状规则: + - 
`lhs`、`rhs`、`out` 的 rank 和 shape 应一致。 + - 该 API 层不默认做隐式 broadcast。 +- 类型规则: + - 三个操作数在实践中建议使用相同 dtype。 + - 整数类型常用于索引/计数路径,浮点类型常用于激活/数值计算路径。 +- 地址空间规则: + - buffer 应分配在后端支持的兼容 DSA 本地地址空间(例如 UB/L1 组合)。 + - 热数据尽量留在本地空间,避免额外 GMEM 往返。 + +各算子语义: + +- `tle.dsa.add(lhs, rhs, out)`:`out = lhs + rhs` +- `tle.dsa.sub(lhs, rhs, out)`:`out = lhs - rhs` +- `tle.dsa.mul(lhs, rhs, out)`:`out = lhs * rhs` +- `tle.dsa.div(lhs, rhs, out)`:`out = lhs / rhs`(精度与舍入行为取决于后端实现) +- `tle.dsa.max(lhs, rhs, out)`:`out = max(lhs, rhs)` +- `tle.dsa.min(lhs, rhs, out)`:`out = min(lhs, rhs)` + +原地/复用建议: + +- 可以在多步计算中复用输出 buffer,例如 `tle.dsa.mul(tmp, b, tmp)`。 +- 除非后端明确保证别名安全,否则不要让输入输出随意别名。 + +Example 1:算术链路 `((a - b) * b) / scale` + +```python +a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +b_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +scale_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +tmp_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +out_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) + +tle.dsa.copy(a_ptrs, a_ub, [BM, BK]) +tle.dsa.copy(b_ptrs, b_ub, [BM, BK]) +tle.dsa.copy(scale_ptrs, scale_ub, [BM, BK]) + +tle.dsa.sub(a_ub, b_ub, tmp_ub) # tmp = a - b +tle.dsa.mul(tmp_ub, b_ub, tmp_ub) # tmp = tmp * b +tle.dsa.div(tmp_ub, scale_ub, out_ub) # out = tmp / scale + +tle.dsa.copy(out_ub, out_ptrs, [BM, BK]) +``` + +Example 2:用 `max` + `min` 做 clamp + +```python +x_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +floor_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +ceil_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +tmp_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +y_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) + +tle.dsa.copy(x_ptrs, x_ub, [BM, BK]) 
+tle.dsa.copy(floor_ptrs, floor_ub, [BM, BK]) +tle.dsa.copy(ceil_ptrs, ceil_ub, [BM, BK]) + +tle.dsa.max(x_ub, floor_ub, tmp_ub) # tmp = max(x, floor) +tle.dsa.min(tmp_ub, ceil_ub, y_ub) # y = min(tmp, ceil) + +tle.dsa.copy(y_ub, y_ptrs, [BM, BK]) +``` + +##### 3.3.2.3 循环与 Hint API + +源码包含: + +- `tle.dsa.pipeline(...)` +- `tle.dsa.parallel(...)` +- `tle.dsa.hint(...)`(以 `with tle.dsa.hint(...)` 形式提供编译期 hint) ```python -tle.dsa.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK]) +with tle.dsa.hint(inter_no_alias=True): + tle.dsa.copy(x_ptr + offs, a_ub, [tail_size], inter_no_alias=True) ``` -###### 3.3.2.1.3 `tle.dsa.local_load` +##### 3.3.2.4 切片与视图 API + +源码包含: -内存加载: +- `tle.dsa.extract_slice` +- `tle.dsa.insert_slice` +- `tle.dsa.extract_element` +- `tle.dsa.subview` ```python -aval = tle.dsa.local_load(a_smem) +sub = tle.dsa.extract_slice(full, offsets=(0, k0), sizes=(BM, BK), strides=(1, 1)) +full = tle.dsa.insert_slice(full, sub, offsets=(0, k0), sizes=(BM, BK), strides=(1, 1)) +elem = tle.dsa.extract_element(sub, indice=(i, j)) ``` #### 3.3.3 Struct API 组合示例 @@ -619,14 +824,22 @@ count_ptr = tle.gpu.local_ptr(counts, (idx,)) tl.atomic_add(count_ptr, 1) ``` -##### 3.3.3.3 DSA 本地缓冲流程(`dsa.alloc` + `dsa.copy` + `dsa.local_load`) +##### 3.3.3.3 DSA 本地缓冲流程(`dsa.alloc` + `dsa.copy` + `dsa.to_tensor/to_buffer`) 适用于暴露专用本地缓冲空间的 DSA 后端。 ```python -a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, layout=tle.dsa.ascend.NZ, scope=tle.dsa.ascend.UB) +a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +b_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) +c_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB) + tle.dsa.copy(a_ptrs, a_ub, [BM, BK]) -a_val = tle.dsa.local_load(a_ub) +tle.dsa.copy(b_ptrs, b_ub, [BM, BK]) +tle.dsa.add(a_ub, b_ub, c_ub) + +c_val = tle.dsa.to_tensor(c_ub, writable=True) +out_ub = tle.dsa.to_buffer(c_val, tle.dsa.ascend.UB) 
+tle.dsa.copy(out_ub, out_ptrs, [BM, BK]) ``` ### 3.4 TLE-Raw