From 1ede0bd748d3467464d16691b3d300fd6b88f76d Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:56:21 +0800 Subject: [PATCH 001/156] Introduce T.deallocate_tmem and T.transpose (#1971) * Fix tcgen05 barrier allocation planning regression * Add explicit TMEM deallocation and shared transpose copy * Add transpose operation to documentation and update implementation in copy_op.py * LINT FIX * Refactor: extract transpose from CopyNode into standalone TransposeNode op Remove the transpose annotation logic from CopyNode (GetTranspose, MakeIndices transpose branch, MakePredicate transpose branch, and GetCopyInst early return). Transpose is now handled by the independent TransposeNode registered as tl.tileop.transpose. Co-Authored-By: Claude Opus 4.6 * Refactor CopyNode to remove transpose handling from index generation and predicate creation. Simplify MakeIndices and MakePredicate methods by eliminating unnecessary transpose checks and related logic. Update associated checks to ensure index consistency. Clean up unused GetTranspose method in copy.h. * lint fix --------- Co-authored-by: Claude Opus 4.6 --- docs/programming_guides/instructions.md | 2 + examples/gemm/example_gemm_schedule.py | 82 ------- src/op/builtin.cc | 5 + src/op/builtin.h | 10 + src/op/transpose.cc | 216 ++++++++++++++++++ src/op/transpose.h | 70 ++++++ src/transform/lower_shared_tmem.cc | 177 +++++++++++--- .../multi_version_buffer_rewriter.cc | 3 +- .../plan_update_buffer_allocation_location.cc | 25 +- .../test_tilelang_language_transpose.py | 117 ++++++++++ ...tilelang_transform_lower_shared_barrier.py | 71 ++++++ ...st_tilelang_transform_lower_shared_tmem.py | 94 ++++++++ tilelang/language/__init__.py | 2 +- tilelang/language/allocate.py | 7 +- tilelang/language/builtin.py | 26 +++ tilelang/language/copy_op.py | 35 +++ 16 files changed, 826 insertions(+), 116 deletions(-) delete mode 100644 examples/gemm/example_gemm_schedule.py create mode 100644 src/op/transpose.cc create mode 100644 src/op/transpose.h create mode 100644 testing/python/language/test_tilelang_language_transpose.py create mode 100644 testing/python/transform/test_tilelang_transform_lower_shared_tmem.py diff --git a/docs/programming_guides/instructions.md b/docs/programming_guides/instructions.md index 4bcbb147d2..a260b915ae 100644 --- a/docs/programming_guides/instructions.md +++ b/docs/programming_guides/instructions.md @@ -145,6 +145,7 @@ signatures, behaviors, constraints, and examples, refer to API Reference Data movement - `T.copy(src, dst, ...)`: Move tiles between Global/Shared/Fragment. - `T.async_copy(src, dst, ...)`: Explicit async global→shared copy via `cp.async`. +- `T.transpose(src, dst)`: Transpose a 2D shared buffer: `dst[j, i] = src[i, j]`. - `T.c2d_im2col(img, col, ...)`: 2D im2col transform for conv. Memory allocation and descriptors @@ -153,6 +154,7 @@ Memory allocation and descriptors - `T.alloc_var(dtype, [init], scope='local.var')`: Scalar var buffer (1 elem). - `T.alloc_barrier(arrive_count)`: Allocate and initialize one or more mbarriers. - `T.alloc_tmem(shape, dtype)`: Tensor memory (TMEM) buffer (Hopper+). +- `T.deallocate_tmem(buffer)`: Explicitly release a TMEM buffer at the current site. - `T.alloc_reducer(shape, dtype, op='sum', replication=None)`: Reducer buf. - `T.alloc_descriptor(kind, dtype)`: Generic descriptor allocator. - `T.alloc_wgmma_desc(dtype='uint64')` diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py deleted file mode 100644 index c921932bbd..0000000000 --- a/examples/gemm/example_gemm_schedule.py +++ /dev/null @@ -1,82 +0,0 @@ -import tilelang -import tilelang.language as T - - -@tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): - @T.prim_func - def gemm_schedule( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Enable rasterization for better L2 Cache Locality - T.use_swizzle(panel_size=10) - - # Clear the local buffer - T.clear(C_local) - - # Auto pipeline the computation - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - T.copy(A[by * block_M, ko * block_K], A_shared) - - # Instead of using - # T.copy(B[k * block_K, bx * block_N], B_shared) - # we can also use Parallel to auto map the thread - # bindings and vectorize the copy operation. - for k, j in T.Parallel(block_K, block_N): - B_shared[k, j] = B[ko * block_K + k, bx * block_N + j] - - T.gemm(A_shared, B_shared, C_local) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - return gemm_schedule - - -def main(): - kernel = matmul(1024, 1024, 1024, 128, 128, 32) - - import torch - - a = torch.randn(1024, 1024).cuda().half() - b = torch.randn(1024, 1024).cuda().half() - - c = kernel(a, b) - - ref_c = a @ b - - print("c:") - print(c) - print("ref_c:") - print(ref_c) - - # Get CUDA Source - print("CUDA Source:") - print(kernel.get_kernel_source()) - - torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) - print("All check passed.") - - -def run_regression_perf(): - kernel = matmul(1024, 1024, 1024, 128, 128, 32) - import torch - - a = torch.randn(1024, 1024).cuda().half() - b = torch.randn(1024, 1024).cuda().half() - from tilelang.profiler import do_bench - - def run_kernel_only(): - kernel(a, b) - - return do_bench(run_kernel_only, backend="cupti") - - -if __name__ == "__main__": - main() diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 3943e32070..15c1459edb 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -211,6 +211,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(deallocate_tmem) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory) .set_num_inputs(2) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index a3d8535b2a..db13db5d1f 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -342,6 +342,16 @@ TVM_DLL const Op &ptx_tcgen05_mma_ss(); */ TVM_DLL const Op &ptx_tcgen05_mma_ts(); +/*! + * \brief Frontend TMEM deallocation marker. + * + * deallocate_tmem(tmem_buffer_data) + * + * This op is produced by the TileLang Python frontend and must be lowered by + * LowerSharedTmem into ptx_deallocate_tensor_memory(access_ptr, num_cols). + */ +TVM_DLL const Op &deallocate_tmem(); + /*! * \brief tvm intrinsics for initializing tensor memory * diff --git a/src/op/transpose.cc b/src/op/transpose.cc new file mode 100644 index 0000000000..8d045b5322 --- /dev/null +++ b/src/op/transpose.cc @@ -0,0 +1,216 @@ +/*! + * \file tl/op/transpose.cc + * \brief Transpose operator: dst[j, i] = src[i, j] using SIMT loops. + */ + +#include "transpose.h" + +#include +#include +#include + +#include "../target/utils.h" +#include "../transform/common/loop_fusion_utils.h" +#include "../transform/loop_partition.h" +#include "../transform/loop_vectorize.h" +#include "utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +Transpose::Transpose(Array args, Map annotations) { + ObjectPtr node = tvm::ffi::make_object(); + Array rgs[2]; + Buffer bf[2]; + for (int i = 0; i < 2; i++) { + auto region = NormalizeToBufferRegion(args[i]); + rgs[i] = region->region; + bf[i] = region->buffer; + } + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + data_ = std::move(node); +} + +TileOperator TransposeNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return Transpose(op); +} + +Array TransposeNode::MakeIterVars() const { + // Use src_range as the iteration domain (src is the "inner" side). + Array loop_vars; + size_t idx = 0; + for (size_t i = 0; i < src_range.size(); i++) { + if (is_one(src_range[i]->extent)) + continue; + Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype); + idx++; + loop_vars.push_back( + {Range(0, src_range[i]->extent), var, IterVarType::kDataPar}); + } + return loop_vars; +} + +Array TransposeNode::MakeIndices(const Array &ivs, + int src_dst) const { + Array indices; + Array ranges = src_dst == 0 ? src_range : dst_range; + + if (src_dst == 1) { + // Transpose: reverse the loop variable assignment for non-trivial dims. + std::vector nontrivial; + for (size_t i = 0; i < ranges.size(); i++) { + if (!is_one(ranges[i]->extent)) + nontrivial.push_back(i); + } + ICHECK(nontrivial.size() == ivs.size()) + << "Transpose: nontrivial dims (" << nontrivial.size() + << ") != ivs size (" << ivs.size() << ") for dst=" << dst->name; + size_t N = nontrivial.size(); + size_t nt_idx = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (is_one(ranges[i]->extent)) { + indices.push_back(ranges[i]->min); + } else { + size_t rev = N - 1 - nt_idx; + indices.push_back(ranges[i]->min + ivs[rev]->var); + nt_idx++; + } + } + } else { + // Source: direct mapping. + size_t idx = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (is_one(ranges[i]->extent)) + indices.push_back(ranges[i]->min); + else { + indices.push_back(ranges[i]->min + ivs[idx]->var); + idx++; + } + } + ICHECK(idx == ivs.size()) + << "idx = " << idx << ", ivs.size() = " << ivs.size() + << " src name = " << src->name << ", dst name = " << dst->name; + } + return indices; +} + +PrimExpr TransposeNode::MakePredicate(arith::Analyzer *analyzer, + const Array &ivs, + Array extents, + int src_dst) const { + bool do_transpose = (src_dst == 1); + Array ranges = src_dst == 0 ? src_range : dst_range; + + size_t num_nontrivial = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (!is_one(ranges[i]->extent)) + num_nontrivial++; + } + + Array cond_list; + ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; + size_t idx = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (is_one(ranges[i]->extent)) + continue; + size_t iv_idx = do_transpose ? (num_nontrivial - 1 - idx) : idx; + PrimExpr cond = ranges[i]->min + ivs[iv_idx]->var < extents[i]; + if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { + cond_list.push_back(cond); + } + cond = ranges[i]->min + ivs[iv_idx]->var >= 0; + if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { + cond_list.push_back(cond); + } + idx++; + } + if (cond_list.empty()) + return {}; + PrimExpr cond = cond_list[0]; + for (size_t i = 1; i < cond_list.size(); i++) + cond = And(cond, cond_list[i]); + return cond; +} + +For TransposeNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { + Array loop_vars = MakeIterVars(); + bool is_scalar = loop_vars.empty(); + + for (const auto &iv : loop_vars) + analyzer->Bind(iv->var, iv->dom); + + Array src_indices = MakeIndices(loop_vars, 0); + Array dst_indices = MakeIndices(loop_vars, 1); + + PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0); + PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); + + PrimExpr value = BufferLoad(src, src_indices); + if (src->dtype != dst->dtype) + value = Cast(dst->dtype, value); + if (src_predicate.defined()) + value = if_then_else(src_predicate, value, make_zero(dst->dtype)); + + Stmt body = BufferStore(dst, value, dst_indices); + if (dst_predicate.defined()) + body = IfThenElse(dst_predicate, body); + if (is_scalar) { + return For(Var("i"), 0, 1, ForKind::kSerial, body); + } + + for (int i = loop_vars.size() - 1; i >= 0; i--) { + body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, + ForKind::kParallel, body); + } + return Downcast(body); +} + +Stmt TransposeNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + // Transpose always uses normal SIMT lowering (no TMA/LDSM/etc.). + bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU; + auto simt_loop = MakeSIMTLoop(analyzer); + auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); + + if (is_cpu_target || IsLocalBuffer(src) || IsLocalBuffer(dst)) { + auto vectorized_loop = VectorizeLoop(fused_loop, T.layout_map); + return vectorized_loop; + } else { + auto par_op = ParallelOp(fused_loop); + std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, + InferLevel::kFree}; + for (auto level : levels) { + par_op->InferLayout({T.target, + T.thread_bounds, + T.layout_map, + analyzer, + false, + T.buffer_remap, + {}}, + level); + } + auto loop_layout = par_op->GetLoopLayout(); + return LowerParallelLoop(par_op->GetRoot(), loop_layout, T.thread_var, + analyzer, T.layout_map, + par_op->GetPredicate(T.thread_var)); + } +} + +LayoutMap TransposeNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + // Transpose always uses SIMT loops; no special layout inference needed. + return {}; +} + +TIR_REGISTER_TL_TILE_OP(Transpose, transpose) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { TransposeNode::RegisterReflection(); } + +} // namespace tl +} // namespace tvm diff --git a/src/op/transpose.h b/src/op/transpose.h new file mode 100644 index 0000000000..2d7ca30d1e --- /dev/null +++ b/src/op/transpose.h @@ -0,0 +1,70 @@ +/*! + * \file tl/op/transpose.h + * \brief Transpose operation for 2D shared memory buffers. + */ + +#ifndef TVM_TL_OP_TRANSPOSE_H_ +#define TVM_TL_OP_TRANSPOSE_H_ + +#include "operator.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/// Node class for transpose operations: dst[j, i] = src[i, j] +class TransposeNode : public TileOperatorNode { +public: + Buffer src, dst; + Array src_range, dst_range; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Transpose", TransposeNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &TransposeNode::src) + .def_ro("dst", &TransposeNode::dst) + .def_ro("src_range", &TransposeNode::src_range) + .def_ro("dst_range", &TransposeNode::dst_range); + } + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + TileOperator Clone() const override; + +private: + /// Create iterator variables for dimensions with extent > 1. + Array MakeIterVars() const; + + /// Generate source (src_dst=0) or destination (src_dst=1) index expressions. + /// For the destination side, non-trivial dimension indices are reversed to + /// implement the transpose. + Array MakeIndices(const Array &ivs, int src_dst) const; + + /// Build boundary predicate with transposed index mapping for dst. + PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array &ivs, + Array extents, int src_dst) const; + + /// Build a SIMT-style nested parallel loop implementing the transpose. + For MakeSIMTLoop(arith::Analyzer *analyzer) const; +}; + +/// Wrapper class for transpose operations +class Transpose : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Transpose, TileOperator, + TransposeNode); + TVM_DLL + Transpose(Array args, + Map annotations = Map()); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_TRANSPOSE_H_ diff --git a/src/transform/lower_shared_tmem.cc b/src/transform/lower_shared_tmem.cc index bae3ab1e2d..679749d945 100644 --- a/src/transform/lower_shared_tmem.cc +++ b/src/transform/lower_shared_tmem.cc @@ -15,20 +15,113 @@ #include #include #include +#include namespace tvm { namespace tl { using namespace tir; +using VarSet = std::unordered_set; + +/*! + * \brief Collect TMEM buffers explicitly deallocated on fallthrough paths. + * + * A "fallthrough path" is one that reaches the end of the statement without + * hitting thread_return(). Buffers deallocated on every such path already + * have an explicit dealloc, so we can skip the auto-dealloc at block end. + * + * \return {buffers deallocated on fallthrough, whether the stmt can + * fallthrough} + */ +static std::pair CollectFallthroughDeallocs(const Stmt &stmt) { + if (!stmt.defined()) + return {{}, true}; + + // Unwrap transparent wrapper nodes + if (auto *n = stmt.as()) + return CollectFallthroughDeallocs(n->body); + if (auto *n = stmt.as()) + return CollectFallthroughDeallocs(n->body); + if (auto *n = stmt.as()) + return CollectFallthroughDeallocs(n->body); + if (auto *n = stmt.as()) + return CollectFallthroughDeallocs(n->block->body); + if (auto *n = stmt.as()) + return CollectFallthroughDeallocs(n->body); + + // Sequential: accumulate deallocs; stop if any child doesn't fallthrough + if (auto *seq = stmt.as()) { + VarSet deallocs; + for (const auto &child : seq->seq) { + auto [d, ft] = CollectFallthroughDeallocs(child); + if (!ft) + return {{}, false}; + deallocs.insert(d.begin(), d.end()); + } + return {std::move(deallocs), true}; + } + + // Branch: collect deallocs only from branches that can fallthrough + if (auto *iff = stmt.as()) { + auto [then_d, then_ft] = CollectFallthroughDeallocs(iff->then_case); + auto [else_d, else_ft] = + iff->else_case.defined() + ? CollectFallthroughDeallocs(iff->else_case.value()) + : std::pair{{}, true}; + VarSet deallocs; + if (then_ft) + deallocs.insert(then_d.begin(), then_d.end()); + if (else_ft) + deallocs.insert(else_d.begin(), else_d.end()); + return {std::move(deallocs), then_ft || else_ft}; + } + + // Leaf: detect deallocate_tmem and thread_return + if (auto *eval = stmt.as()) { + if (auto *call = eval->value.as()) { + if (call->op.same_as(tl::deallocate_tmem())) { + ICHECK_EQ(call->args.size(), 1U); + auto *buf = call->args[0].as(); + ICHECK(buf) << "tl.deallocate_tmem expects a buffer data Var"; + return {{tvm::ffi::GetRef(buf)}, true}; + } + if (call->op.same_as(builtin::thread_return())) { + return {{}, false}; + } + } + } + + return {{}, true}; +} + class SharedTmemRewriter : public StmtExprMutator { public: - static Stmt Rewrite(Stmt body) { + static Stmt Rewrite(Stmt body, Target target) { SharedTmemRewriter rewriter; + rewriter.target_ = std::move(target); return rewriter(body); } private: + int GetNumColsAllocated(const Buffer &buffer) const { + ICHECK_EQ(buffer->shape.size(), 2U); + + auto analyzer = std::make_shared(); + arith::ConstIntBound phy_col_bounds = + analyzer->const_int_bound(buffer->shape[1]); + int num_cols_required = phy_col_bounds->max_value; + ICHECK(num_cols_required <= 512) + << "The number of columns required for tmem buffer " << buffer->name + << " is " << num_cols_required + << ", which exceeds the maximum of 512 columns"; + + int num_cols_allocated = 32; // Align num_cols_allocated to power of 2 + for (; num_cols_allocated < num_cols_required; num_cols_allocated *= 2) { + } + return num_cols_allocated; + } + Stmt VisitStmt_(const BlockNode *op) final { Block block = tvm::ffi::GetRef(op); Array alloc_buffers = op->alloc_buffers; @@ -64,6 +157,8 @@ class SharedTmemRewriter : public StmtExprMutator { ICHECK(thread_var_.defined()) << "thread_var_ is not defined"; + auto [fallthrough_deallocs, _] = CollectFallthroughDeallocs(op->body); + for (auto buffer : tmem_buffers) { buffer_data_to_buffer_.Set(buffer->data, buffer); } @@ -145,22 +240,10 @@ class SharedTmemRewriter : public StmtExprMutator { auto data = buffer->data; auto old_buffer = buffer_data_to_buffer_.at(data); auto new_buffer = buffer_remap_.at(old_buffer); + int num_cols_allocated = GetNumColsAllocated(old_buffer); - // Tmem physical coord range analysis - ICHECK(old_buffer->shape.size() == 2); - - auto analyzer = std::make_shared(); - arith::ConstIntBound phy_col_bounds = - analyzer->const_int_bound(old_buffer->shape[1]); - int num_cols_required = phy_col_bounds->max_value; - ICHECK(num_cols_required <= 512) - << "The number of columns required for tmem buffer " - << old_buffer->name << " is " << num_cols_required - << ", which exceeds the maximum of 512 columns"; - - int num_cols_allocated = 32; // Align num_cols_allocated to power of 2 - for (; num_cols_allocated < num_cols_required; num_cols_allocated *= 2) - ; + tmem_num_cols_allocated_.insert({data, num_cols_allocated}); + tmem_call_annotations_.insert({data, tmem_call_ann}); auto new_buffer_access = new_buffer.access_ptr(1, DataType::Handle(), 1, PrimExpr(0), PrimExpr(1)); @@ -168,10 +251,12 @@ class SharedTmemRewriter : public StmtExprMutator { {new_buffer_access, PrimExpr(num_cols_allocated)}, tmem_call_ann); init_mtmem_calls_.push_back(Evaluate(alloc_call)); - auto dealloc_call = Call( - DataType::Handle(), tl::ptx_deallocate_tensor_memory(), - {new_buffer_access, PrimExpr(num_cols_allocated)}, tmem_call_ann); - dealloc_tmem_calls_.push_back(Evaluate(dealloc_call)); + if (!fallthrough_deallocs.count(data)) { + auto dealloc_call = Call( + DataType::Handle(), tl::ptx_deallocate_tensor_memory(), + {new_buffer_access, PrimExpr(num_cols_allocated)}, tmem_call_ann); + dealloc_tmem_calls_.push_back(Evaluate(dealloc_call)); + } } auto compare_by_buffer_name = [&](const Stmt &a, const Stmt &b) { auto call_a = a.as()->value.as(); @@ -184,8 +269,8 @@ class SharedTmemRewriter : public StmtExprMutator { compare_by_buffer_name); Array new_body; - auto target = Target::Current(); - auto warp_size = TargetGetWarpSize(target); + ICHECK(target_.defined()) << "LowerSharedTmem requires a bound target"; + auto warp_size = TargetGetWarpSize(target_); auto thread_var_div_warp_size = FloorDiv(thread_var_->var, IntImm(thread_var_->var->dtype, warp_size)); new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0), @@ -197,11 +282,17 @@ class SharedTmemRewriter : public StmtExprMutator { Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(), {StringImm("shared")}))); new_body.push_back(block->body); - new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0), - dealloc_tmem_calls_.size() > 1 - ? SeqStmt(dealloc_tmem_calls_) - : dealloc_tmem_calls_.back(), - Stmt())); + if (!dealloc_tmem_calls_.empty()) { + if (tmem_call_ann.find("use_2cta") != tmem_call_ann.end()) { + new_body.push_back( + Evaluate(Call(DataType::Handle(), tl::cluster_sync(), {}))); + } + new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0), + dealloc_tmem_calls_.size() > 1 + ? SeqStmt(dealloc_tmem_calls_) + : dealloc_tmem_calls_.back(), + Stmt())); + } auto block_ptr = block.CopyOnWrite(); block_ptr->annotations.erase(attr::kLayoutMap); @@ -261,6 +352,30 @@ class SharedTmemRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tl::deallocate_tmem())) { + ICHECK_EQ(op->args.size(), 1U); + Var buffer_data = Downcast(op->args[0]); + auto num_cols_it = tmem_num_cols_allocated_.find(buffer_data); + ICHECK(num_cols_it != tmem_num_cols_allocated_.end()) + << "tl.deallocate_tmem expects a TMEM buffer allocated in the same " + "or an enclosing block"; + ICHECK(buffer_data_to_buffer_.count(buffer_data)) + << "TMEM buffer for tl.deallocate_tmem is not tracked"; + Buffer old_buffer = buffer_data_to_buffer_.at(buffer_data); + ICHECK(buffer_remap_.count(old_buffer)) + << "TMEM buffer for tl.deallocate_tmem has not been remapped"; + Buffer new_buffer = buffer_remap_[old_buffer]; + auto new_buffer_access = new_buffer.access_ptr(1, DataType::Handle(), 1, + PrimExpr(0), PrimExpr(1)); + + Map ann; + auto ann_it = tmem_call_annotations_.find(buffer_data); + if (ann_it != tmem_call_annotations_.end()) { + ann = ann_it->second; + } + return Call(DataType::Handle(), tl::ptx_deallocate_tensor_memory(), + {new_buffer_access, PrimExpr(num_cols_it->second)}, ann); + } if (op->op.same_as(builtin::tvm_access_ptr())) { ICHECK_EQ(op->args.size(), 5U); Var buffer_data = Downcast(op->args[1]); @@ -299,19 +414,23 @@ class SharedTmemRewriter : public StmtExprMutator { // This is a workaround for cpu backend, // we need to define a thread_var for the serial loop. IterVar thread_var_; + Target target_; Map var_remap_; Map buffer_data_to_buffer_; Map buffer_remap_; // Mapping from data Var of a Buffer to Buffer, for lookup std::unordered_map buffer_map_; + std::unordered_map + tmem_num_cols_allocated_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + tmem_call_annotations_; Map layout_map_; }; PrimFunc LowerSharedTmem(PrimFunc f) { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerSharedTmem: Require the target attribute"; - SharedTmemRewriter rewriter; - f.CopyOnWrite()->body = rewriter.Rewrite(f->body); + f.CopyOnWrite()->body = SharedTmemRewriter::Rewrite(f->body, target.value()); return f; } diff --git a/src/transform/multi_version_buffer_rewriter.cc b/src/transform/multi_version_buffer_rewriter.cc index a3a5cab4f4..bfe3c70b93 100644 --- a/src/transform/multi_version_buffer_rewriter.cc +++ b/src/transform/multi_version_buffer_rewriter.cc @@ -457,7 +457,8 @@ class MultiVersionBufferRewriter : public StmtExprMutator { if (barrier_only_) { Array filtered; for (auto buffer : versioned_buffers) { - if (buffer.scope() == "shared.barrier") { + if (buffer.scope() == "shared.barrier" || + buffer.scope() == "shared.cluster_barrier") { filtered.push_back(buffer); } } diff --git a/src/transform/plan_update_buffer_allocation_location.cc b/src/transform/plan_update_buffer_allocation_location.cc index f6302a502a..1d12aac6eb 100644 --- a/src/transform/plan_update_buffer_allocation_location.cc +++ b/src/transform/plan_update_buffer_allocation_location.cc @@ -142,6 +142,13 @@ class BufferAllocationLocator : public StmtExprMutator { } // create buffers to be allocated at each stmts for (const auto &buffer : buffer_alloc_recorder) { + // Shared barriers must stay in their original block so the accompanying + // barrier_init annotation remains attached to the block that owns the + // initialization. Moving them into injected opaque blocks causes + // LowerSharedBarrier to see barrier buffers without local annotations. + if (IsBarrierBuffer(buffer)) { + continue; + } // Prefer the LCA derived from the underlying data var. If missing, fall // back to Buffer LCA. const StmtNode *stmt = nullptr; @@ -250,10 +257,18 @@ class BufferAllocationLocator : public StmtExprMutator { Stmt VisitStmt_(const BlockNode *op) final { ICHECK(!op->init.defined()); ffi::Array alloc_buffers; + ffi::Array preserved_barrier_buffers; + for (const Buffer &buf : op->alloc_buffers) { + if (IsBarrierBuffer(buf)) { + alloc_buffers.push_back(buf); + preserved_barrier_buffers.push_back(buf); + PushBinding(buf->data, buf); + } + } auto it = alloc_buffers_.find(op); if (it != alloc_buffers_.end()) { - alloc_buffers = it->second; for (const Buffer &buf : it->second) { + alloc_buffers.push_back(buf); PushBinding(buf->data, buf); } } @@ -280,6 +295,9 @@ class BufferAllocationLocator : public StmtExprMutator { PopBinding(buf->data); } } + for (const Buffer &buf : preserved_barrier_buffers) { + PopBinding(buf->data); + } ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); @@ -332,6 +350,11 @@ class BufferAllocationLocator : public StmtExprMutator { ffi::Map> buffer_data_to_buffers_; /*! \brief Buffers that are allocated within a BlockNode, and may be moved. */ std::unordered_set managed_allocations_; + + static bool IsBarrierBuffer(const Buffer &buffer) { + String scope = buffer.scope(); + return scope == "shared.barrier" || scope == "shared.cluster_barrier"; + } }; PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { diff --git a/testing/python/language/test_tilelang_language_transpose.py b/testing/python/language/test_tilelang_language_transpose.py new file mode 100644 index 0000000000..8d9b8bd34c --- /dev/null +++ b/testing/python/language/test_tilelang_language_transpose.py @@ -0,0 +1,117 @@ +"""Tests for T.transpose shared memory transpose primitive.""" + +import tilelang +import tilelang.language as T +import torch + + +def tilelang_transpose(M, N, block_M, block_N, dtype=T.float16): + """Kernel: read tile from A into shared, transpose in shared, write to B. + + A is (M, N), B is (M, N). + B = A.T.T = A when block_M == M and block_N == N (single tile). + Actually: we read A tile (block_M, block_N) into shared, + transpose to (block_N, block_M) in shared, then write to B + so B[bx*block_N + j, by*block_M + i] = A[by*block_M + i, bx*block_N + j] + i.e., B = A.T + """ + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((N, M), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + tile = T.alloc_shared((block_M, block_N), dtype) + tile_T = T.alloc_shared((block_N, block_M), dtype) + + # Load from global to shared + T.copy( + A[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], + tile, + ) + # Transpose in shared memory + T.transpose(tile, tile_T) + # Store transposed tile back to global + T.copy( + tile_T, + B[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M], + ) + + return main + + +def run_tilelang_transpose(M=128, N=128, block_M=128, block_N=128, dtype=T.float16): + program = tilelang_transpose(M, N, block_M, block_N, dtype) + kernel = tilelang.compile( + program, + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + }, + ) + a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + expected = a.T + torch.testing.assert_close(b, expected, rtol=1e-2, atol=1e-2) + print(f"PASS: transpose M={M}, N={N}, block_M={block_M}, block_N={block_N}") + + +def tilelang_transpose_square(M, block_M, dtype=T.float16): + """Simpler test: square transpose with single tile.""" + + @T.prim_func + def main( + A: T.Tensor((M, M), dtype), + B: T.Tensor((M, M), dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(M, block_M), threads=128) as (bx, by): + tile = T.alloc_shared((block_M, block_M), dtype) + tile_T = T.alloc_shared((block_M, block_M), dtype) + + T.copy( + A[by * block_M : (by + 1) * block_M, bx * block_M : (bx + 1) * block_M], + tile, + ) + T.transpose(tile, tile_T) + T.copy( + tile_T, + B[bx * block_M : (bx + 1) * block_M, by * block_M : (by + 1) * block_M], + ) + + return main + + +def run_tilelang_transpose_square(M=256, block_M=128, dtype=T.float16): + program = tilelang_transpose_square(M, block_M, dtype) + kernel = tilelang.compile( + program, + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + }, + ) + a = torch.randn(M, M, device="cuda", dtype=getattr(torch, dtype)) + b = kernel(a) + expected = a.T + torch.testing.assert_close(b, expected, rtol=1e-2, atol=1e-2) + print(f"PASS: square transpose M={M}, block_M={block_M}") + + +def test_tilelang_transpose(): + run_tilelang_transpose(M=128, N=128, block_M=128, block_N=128) + run_tilelang_transpose(M=256, N=256, block_M=128, block_N=128) + run_tilelang_transpose(M=128, N=256, block_M=128, block_N=256) + + +def test_tilelang_transpose_square(): + run_tilelang_transpose_square(M=128, block_M=128) + run_tilelang_transpose_square(M=256, block_M=128) + run_tilelang_transpose_square(M=512, block_M=128) + + +if __name__ == "__main__": + test_tilelang_transpose() + test_tilelang_transpose_square() diff --git a/testing/python/transform/test_tilelang_transform_lower_shared_barrier.py b/testing/python/transform/test_tilelang_transform_lower_shared_barrier.py index 4d734fdbd8..7b1b2648fc 100644 --- a/testing/python/transform/test_tilelang_transform_lower_shared_barrier.py +++ b/testing/python/transform/test_tilelang_transform_lower_shared_barrier.py @@ -4,6 +4,7 @@ from tilelang.utils.target import determine_target import tilelang.language as T import tilelang.testing +from tilelang.engine.phase import LowerAndLegalize from tvm import tir auto_target = tvm.target.Target(determine_target("auto")) @@ -43,6 +44,19 @@ def _collect_shuffle_elect(stmt): return _collect_calls(stmt, "tl.tl_shuffle_elect") +def _collect_barrier_blocks(stmt): + blocks = [] + + def visitor(node): + if isinstance(node, tvm.tir.Block): + barrier_bufs = [buf for buf in node.alloc_buffers if buf.scope() in ("shared.barrier", "shared.cluster_barrier")] + if barrier_bufs: + blocks.append(node) + + tvm.tir.stmt_functor.post_order_visit(stmt, visitor) + return blocks + + def test_single_barrier(): """Single barrier with one arrive count.""" @@ -105,5 +119,62 @@ def func(): assert len(_collect_fence_barrier_init(body)) == 0 +def test_plan_update_keeps_barrier_init_with_tcgen05_no_tma(): + """Regression for tcgen05 no-TMA kernels after pass reordering.""" + + pass_configs = { + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } + + @T.prim_func + def func( + X: T.Tensor((256, 256), T.float16), + Y: T.Tensor((256, 256), T.float16), + ): + with T.Kernel(2, 2, threads=256) as (bx, by): + A_shared = T.alloc_shared((128, 128), T.float16) + B_shared = T.alloc_shared((128, 128), T.float16) + C_tmem = T.alloc_tmem([128, 128], T.float32) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((128, 128), T.float32) + Y_shared = T.alloc_shared((128, 128), T.float16) + + for ko in T.Pipelined(2, num_stages=2): + T.copy(X[by * 128, ko * 128], A_shared) + T.copy(X[bx * 128, ko * 128], B_shared) + T.tcgen05_gemm( + A_shared, + B_shared, + C_tmem, + transpose_B=True, + mbar=mbar, + clear_accum=ko == 0, + ) + T.mbarrier_wait_parity(mbar, ko % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, Y_shared) + T.copy(Y_shared, Y[by * 128, bx * 128]) + + target = tvm.target.Target("cuda -arch=sm_100") + with tvm.transform.PassContext(config=pass_configs), target: + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = LowerAndLegalize(mod, target) + mod = tl.transform.LowerSharedTmem()(mod) + mod = tl.transform.IfStmtBinding()(mod) + mod = tl.transform.PlanAndUpdateBufferAllocationLocation()(mod) + + barrier_blocks = _collect_barrier_blocks(mod["main"].body) + assert len(barrier_blocks) == 1 + assert "barrier_init" in barrier_blocks[0].annotations + + mod = tl.transform.LowerSharedBarrier()(mod) + + body = mod["main"].body + assert len(_collect_init_barrier_calls(body)) == 1 + assert len(_collect_fence_barrier_init(body)) == 1 + + if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_lower_shared_tmem.py b/testing/python/transform/test_tilelang_transform_lower_shared_tmem.py new file mode 100644 index 0000000000..9c5d0a70b0 --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_lower_shared_tmem.py @@ -0,0 +1,94 @@ +# ruff: noqa +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +TARGET = tvm.target.Target("cuda -arch=sm_100") + + +def _apply(func): + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(TARGET)(mod) + mod = tl.transform.LowerSharedTmem()(mod) + return mod + + +def _collect_calls(stmt, op_name: str): + calls = [] + + def visitor(node): + if isinstance(node, tvm.tir.Call) and hasattr(node, "op") and hasattr(node.op, "name") and node.op.name == op_name: + calls.append(node) + + tvm.tir.stmt_functor.post_order_visit(stmt, visitor) + return calls + + +def test_explicit_deallocate_tmem_suppresses_auto_dealloc(): + """Explicit T.deallocate_tmem on fallthrough suppresses auto-dealloc.""" + + @T.prim_func + def func(): + with T.Kernel(1, threads=128): + C_tmem = T.alloc_tmem([128, 128], T.float32) + T.deallocate_tmem(C_tmem) + + mod = _apply(func) + body = mod["main"].body + assert len(_collect_calls(body, "tl.ptx_init_tensor_memory")) == 1 + assert len(_collect_calls(body, "tl.ptx_deallocate_tensor_memory")) == 1 + assert len(_collect_calls(body, "tl.deallocate_tmem")) == 0 + + dealloc_call = _collect_calls(body, "tl.ptx_deallocate_tensor_memory")[0] + assert dealloc_call.args[1].value == 128 + + +def test_explicit_deallocate_only_suppresses_matching_buffer(): + """Only the explicitly-deallocated buffer skips auto-dealloc; others keep it.""" + + @T.prim_func + def func(): + with T.Kernel(1, threads=128): + A_tmem = T.alloc_tmem([128, 128], T.float32) + B_tmem = T.alloc_tmem([128, 64], T.float32) + T.deallocate_tmem(A_tmem) + + mod = _apply(func) + body = mod["main"].body + + dealloc_calls = _collect_calls(body, "tl.ptx_deallocate_tensor_memory") + # A_tmem: 1 explicit (auto suppressed); B_tmem: 1 auto = 2 total + assert len(dealloc_calls) == 2 + + dealloc_num_cols = sorted(call.args[1].value for call in dealloc_calls) + assert dealloc_num_cols == [64, 128] + + +def test_dealloc_before_thread_return_keeps_auto_dealloc(): + """Dealloc on non-fallthrough path (before thread_return) does NOT suppress auto-dealloc.""" + + @T.prim_func + def func(): + with T.Kernel(1, threads=128): + C_tmem = T.alloc_tmem([128, 128], T.float32) + tx = T.get_thread_binding() + + if tx < 32: + T.deallocate_tmem(C_tmem) + T.thread_return() + + mod = _apply(func) + body = mod["main"].body + + dealloc_calls = _collect_calls(body, "tl.ptx_deallocate_tensor_memory") + # 1 explicit (non-fallthrough) + 1 auto (block end) = 2 + assert len(dealloc_calls) == 2 + assert [call.args[1].value for call in dealloc_calls] == [128, 128] + + +if __name__ == "__main__": + test_explicit_deallocate_tmem_suppresses_auto_dealloc() + test_explicit_deallocate_only_suppresses_matching_buffer() + test_dealloc_before_thread_return_keeps_auto_dealloc() diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 2e40136370..8ce71e5f31 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -54,7 +54,7 @@ alloc_global, # noqa: F401 ) from tvm.script.parser.tir import allocate as allocate # noqa: F401 -from .copy_op import copy, async_copy, tma_copy, c2d_im2col # noqa: F401 +from .copy_op import copy, async_copy, tma_copy, transpose, c2d_im2col # noqa: F401 from tilelang.tileop.base import GemmWarpPolicy # noqa: F401 from .gemm_op import gemm, gemm_v1, gemm_v2, wgmma_gemm, tcgen05_gemm # noqa: F401 from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401 diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 39f313b312..a2d6b96c1d 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -200,7 +200,9 @@ def alloc_tmem(shape: ShapeType, dtype: DType) -> Buffer: Key properties and requirements: - The number of columns allocated must be a power of 2 and at least 32. - - TMEM allocations are dynamic and must be explicitly deallocated. + - TMEM allocations are dynamic. TileLang deallocates them automatically at + the end of the allocation block unless you call ``T.deallocate_tmem`` to + take manual control of the lifetime. - Both allocation and deallocation must be performed by the same warp. - The base address of the TMEM allocation is stored in shared memory and used as the offset for TCGEN5.MMA accumulator tensors. - Only TCGEN5.MMA and specific TMEM load/store instructions can access TMEM; all pre-processing must occur before data is loaded into TMEM, and all post-processing after data is retrieved. @@ -214,7 +216,8 @@ def alloc_tmem(shape: ShapeType, dtype: DType) -> Buffer: Note: - TMEM is only available on supported architectures (e.g., Hopper and later). - - The buffer returned should be used according to TMEM access restrictions and deallocated appropriately. + - The buffer returned should be used according to TMEM access restrictions. + Use ``T.deallocate_tmem`` only when you need an earlier, explicit release. """ assert len(shape) == 2, "shape must be a 2D tensor for TMEM allocation" diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 3a27e60699..c76bd42fb5 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -252,6 +252,32 @@ def _index_dtype(buf: tir.Buffer) -> str: ) +def deallocate_tmem(tmem: tir.Buffer) -> None: + """Explicitly deallocate a TMEM buffer allocated by ``T.alloc_tmem``. + + By default, TileLang inserts a TMEM deallocation automatically at the end + of the allocation block. Calling ``T.deallocate_tmem(buf)`` suppresses that + automatic tail deallocation for ``buf`` and lowers an explicit deallocation + at the call site instead. + + Notes: + - The deallocation must obey the hardware TMEM rules: it should be issued by + the same warp that performed the allocation. + - Once this API is used, the buffer lifetime is user-managed for the current + block; deallocating too early or conditionally is the user's responsibility. + + Args: + tmem: A TMEM buffer previously returned by ``T.alloc_tmem``. + """ + + if not isinstance(tmem, tir.Buffer): + raise TypeError(f"T.deallocate_tmem expects a tvm.tir.Buffer, but got {type(tmem)}.") + if tmem.scope() != "shared.tmem": + raise ValueError(f"T.deallocate_tmem expects a shared.tmem buffer, but got scope={tmem.scope()}.") + + return evaluate(tir.call_intrin("handle", tir.op.Op.get("tl.deallocate_tmem"), tmem.data)) + + def create_tma_descriptor(*args): """Create a Tensor Memory Access (TMA) descriptor. diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py index 84d6b8dac4..7e9cdddb08 100644 --- a/tilelang/language/copy_op.py +++ b/tilelang/language/copy_op.py @@ -229,6 +229,41 @@ def tma_copy( return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.tma_copy"), src, dst, annotations=ann if ann else None) +def transpose( + src: BufferLikeType, + dst: BufferLikeType, +) -> tir.PrimExpr: + """Transpose a 2D buffer in shared memory: dst[j, i] = src[i, j]. + + Both src and dst should be shared memory buffers. + If src has shape (M, N), dst should have shape (N, M). + + Args: + src: Source buffer or region of shape (..., M, N). + dst: Destination buffer or region of shape (..., N, M). + + Returns: + tir.Call: A handle to the transpose operation. + """ + src_extent = get_extent(src) + dst_extent = get_extent(dst) + + assert src_extent is not None, "Cannot deduce extent for transpose src." + assert dst_extent is not None, "Cannot deduce extent for transpose dst." + assert len(src_extent) >= 2, "Transpose requires at least 2D buffers." + assert len(dst_extent) >= 2, "Transpose requires at least 2D buffers." + + src = to_buffer_region(src, access_type="r") + dst = to_buffer_region(dst, access_type="w") + + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.transpose"), + src, + dst, + ) + + def c2d_im2col( img: BufferLikeType, col: BufferLikeType, From a9b8d53b6d3735f20d80739c57ff64b357486544 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 30 Mar 2026 16:08:39 +0800 Subject: [PATCH 002/156] Add `annotations` parameter to `alloc_buffer` in `tilelang/language/ast/ir.py` (#1996) * Initial plan * Add annotations parameter support to alloc_buffer in tilelang/language/ast/ir.py Agent-Logs-Url: https://github.com/tile-ai/tilelang/sessions/17577985-06fa-4b35-b714-185004d91524 Co-authored-by: LeiWang1999 <34334180+LeiWang1999@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: LeiWang1999 <34334180+LeiWang1999@users.noreply.github.com> --- tilelang/language/ast/ir.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/tilelang/language/ast/ir.py b/tilelang/language/ast/ir.py index d2f56598b6..49db2ca641 100644 --- a/tilelang/language/ast/ir.py +++ b/tilelang/language/ast/ir.py @@ -451,6 +451,7 @@ def alloc_buffer( offset_factor: int = 0, buffer_type: str = "default", axis_separators: Optional[List[int]] = None, + annotations: Optional[Dict[str, Any]] = None, ) -> Buffer: """The buffer allocation function. @@ -486,6 +487,10 @@ def alloc_buffer( axis_separators : List[int], optional The separators between input axes when generating flattened output axes. + annotations : Dict[str, Any], optional + Additional annotation hints for the buffer, e.g. to guide code generation + for specific backends. + Returns ------- res : Buffer @@ -496,18 +501,10 @@ def alloc_buffer( strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides] else: strides = [] - return _ffi_api.AllocBuffer( # type: ignore[attr-defined] # pylint: disable=no-member - shape, - dtype, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - ) + args = [shape, dtype, data, strides, elem_offset, scope, align, offset_factor, buffer_type, axis_separators] + if annotations is not None: + args.append(annotations) + return _ffi_api.AllocBuffer(*args) # type: ignore[attr-defined] # pylint: disable=no-member def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range: From 1c561f6d46074c2948b9230f3de0a7aa27c60880 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 30 Mar 2026 16:14:35 +0800 Subject: [PATCH 003/156] [Bugfix] Raise error on zero grid dimension instead of silent clamp (#1994) * [Bugfix] Raise error on zero grid dimension instead of silent clamp (#1993) Fix ThreadWorkLoad::Extract() silently clamping zero grid dims to 1, which caused either CUDA_ERROR_ILLEGAL_ADDRESS crashes (dynamic case) or silent wrong results (static case). Closes #1993 * lint fix --- 3rdparty/tvm | 2 +- .../python/issue/test_tilelang_issue_1993.py | 76 +++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 testing/python/issue/test_tilelang_issue_1993.py diff --git a/3rdparty/tvm b/3rdparty/tvm index fab43e41c0..12b47d3162 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit fab43e41c004e888ded30d45df25ccc8e2612617 +Subproject commit 12b47d316230fc777d13d4199200530e8c9529e1 diff --git a/testing/python/issue/test_tilelang_issue_1993.py b/testing/python/issue/test_tilelang_issue_1993.py new file mode 100644 index 0000000000..63581a23e5 --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1993.py @@ -0,0 +1,76 @@ +import torch +import pytest + +import tilelang +import tilelang.testing +import tilelang.language as T + + +@tilelang.jit +def _issue1993_dynamic_grid(): + num_tokens = T.dynamic("num_tokens") + + @T.prim_func + def kernel(out: T.Tensor[(num_tokens,), T.float32]): + with T.Kernel(num_tokens, threads=1) as pid: + out[pid] = T.float32(1.0) + + return kernel + + +@tilelang.jit +def _issue1993_static_grid(): + + @T.prim_func + def kernel(out: T.Tensor[(4,), T.float32]): + with T.Kernel(0, threads=1) as pid: + out[pid] = T.float32(1.0) + + return kernel + + +@tilelang.testing.requires_cuda +def test_issue_1993_dynamic_zero_grid_dim(): + """Regression test for issue #1993. + + When a dynamic grid dimension resolves to 0 at runtime, the runtime + should raise an error instead of silently clamping to 1 and launching + the kernel (which would write through a NULL pointer and crash with + CUDA_ERROR_ILLEGAL_ADDRESS). + """ + kernel = _issue1993_dynamic_grid() + + # Positive case: should work correctly + out = torch.zeros(4, dtype=torch.float32, device="cuda") + kernel(out) + torch.cuda.synchronize() + assert out.eq(1.0).all() + + # Zero case: should raise an error, not crash with illegal memory access + out_empty = torch.zeros(0, dtype=torch.float32, device="cuda") + with pytest.raises(Exception): # noqa: B017 + kernel(out_empty) + torch.cuda.synchronize() + + +@tilelang.testing.requires_cuda +def test_issue_1993_static_zero_grid_dim(): + """Regression test for issue #1993. + + When T.Kernel(0) is used with a static constant, the runtime should + raise an error instead of silently clamping to 1 and executing a + spurious CTA. + """ + kernel = _issue1993_static_grid() + + out = torch.zeros(4, dtype=torch.float32, device="cuda") + with pytest.raises(Exception): # noqa: B017 + kernel(out) + torch.cuda.synchronize() + + # Buffer should remain untouched + assert out.eq(0.0).all() + + +if __name__ == "__main__": + tilelang.testing.main() From 6a859f19411602ca8931774fce77b7eea3b5dc89 Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Tue, 31 Mar 2026 13:21:13 +0800 Subject: [PATCH 004/156] [BugFix] Fix missing barrier init attrs when TMA is disabled (#1995) * add check for kRows when allocating tmem * Add FirstOwningBlockCollector to map allocated buffers to their declaring blocks This change introduces the FirstOwningBlockCollector class, which collects the first Block that declares each allocated buffer variable. The BufferAllocationLocator is updated to utilize this new collector, ensuring that shared barriers remain associated with their declaring blocks during buffer allocation. This enhancement improves the management of buffer allocations and addresses potential issues with opaque child blocks in pipelined For loops. * replace `T.gemm` with `T.tcgen05_gemm` in fa sm100 examples * lint --------- Co-authored-by: LeiWang1999 --- .../flash_attention_sm100/gqa_bwd_bshd.py | 14 +++++----- .../flash_attention_sm100/gqa_fwd_bshd.py | 18 +++++-------- .../flash_attention_sm100/mha_bwd_bshd.py | 14 +++++----- .../flash_attention_sm100/mha_fwd_bshd.py | 26 +++++++------------ src/transform/lower_shared_tmem.cc | 12 +++++++++ 5 files changed, 42 insertions(+), 42 deletions(-) diff --git a/examples/flash_attention_sm100/gqa_bwd_bshd.py b/examples/flash_attention_sm100/gqa_bwd_bshd.py index a80da2f4b9..33661a4947 100644 --- a/examples/flash_attention_sm100/gqa_bwd_bshd.py +++ b/examples/flash_attention_sm100/gqa_bwd_bshd.py @@ -67,7 +67,7 @@ def main( else: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -80,7 +80,7 @@ def main( for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.copy(acc_s, acc_s_cast) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.reduce_sum(acc_s, scores_sum, dim=1) for i in T.Parallel(block_M): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] @@ -207,7 +207,7 @@ def main( for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) @@ -216,18 +216,18 @@ def main( qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.copy(qkT_cast, qkT_shared) - T.gemm(qkT_shared, do, dv, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(qkT_shared, do, dv, policy=T.GemmWarpPolicy.FullRow) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale T.copy(dsT_cast, dsT_shared) - T.gemm(dsT_shared, q, dk, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(dsT_shared, q, dk, policy=T.GemmWarpPolicy.FullRow) T.clear(dq) - T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.tcgen05_gemm(dsT_shared, K_shared, dq, transpose_A=True) for i, j in T.Parallel(block_N, dim): T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) diff --git a/examples/flash_attention_sm100/gqa_fwd_bshd.py b/examples/flash_attention_sm100/gqa_fwd_bshd.py index da945f046a..353f171146 100644 --- a/examples/flash_attention_sm100/gqa_fwd_bshd.py +++ b/examples/flash_attention_sm100/gqa_fwd_bshd.py @@ -96,13 +96,12 @@ def main( for k in T.Pipelined(loop_range, num_stages=1): T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) - T.gemm( + T.tcgen05_gemm( Q_shared, K_shared, S_tmem, transpose_B=True, mbar=mbar_s, - wg_wait=-1, clear_accum=True, ) T.mbarrier_wait_parity(mbar_s, k % 2) @@ -150,12 +149,11 @@ def main( T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) - T.gemm( + T.tcgen05_gemm( P_operand, V_shared, D_tmem, mbar=mbar_d, - wg_wait=-1, clear_accum=True, ) T.mbarrier_wait_parity(mbar_d, k % 2) @@ -291,23 +289,21 @@ def main( T.mbarrier_wait_parity(mbar_bmm1_empty[stage_id], parity_inv) if stage_id == 0: - T.gemm( + T.tcgen05_gemm( Q_shared, K_shared_0, S_tmem, transpose_B=True, mbar=mbar_bmm1_full[stage_id], - wg_wait=-1, clear_accum=True, ) else: - T.gemm( + T.tcgen05_gemm( Q_shared, K_shared_1, S_tmem, transpose_B=True, mbar=mbar_bmm1_full[stage_id], - wg_wait=-1, clear_accum=True, ) T.mbarrier_arrive(mbar_dma1_empty[stage_id]) @@ -316,21 +312,19 @@ def main( T.mbarrier_wait_parity(mbar_dma2_full[stage_id], parity) if stage_id == 0: - T.gemm( + T.tcgen05_gemm( P_tmem, V_shared_0, O_tmem, mbar=mbar_bmm2_full[stage_id], - wg_wait=-1, clear_accum=is_clear_accum, ) else: - T.gemm( + T.tcgen05_gemm( P_tmem, V_shared_1, O_tmem, mbar=mbar_bmm2_full[stage_id], - wg_wait=-1, clear_accum=is_clear_accum, ) diff --git a/examples/flash_attention_sm100/mha_bwd_bshd.py b/examples/flash_attention_sm100/mha_bwd_bshd.py index de64cb2736..0a04edd0aa 100644 --- a/examples/flash_attention_sm100/mha_bwd_bshd.py +++ b/examples/flash_attention_sm100/mha_bwd_bshd.py @@ -62,7 +62,7 @@ def main( else: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -75,7 +75,7 @@ def main( for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.copy(acc_s, acc_s_cast) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.reduce_sum(acc_s, scores_sum, dim=1) for i in T.Parallel(block_M): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] @@ -197,7 +197,7 @@ def main( for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) @@ -206,18 +206,18 @@ def main( qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) - T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale - T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + T.tcgen05_gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) T.copy(dsT_cast, dsT_shared) T.clear(dq) - T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.tcgen05_gemm(dsT_shared, K_shared, dq, transpose_A=True) for i, j in T.Parallel(block_N, dim): T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) diff --git a/examples/flash_attention_sm100/mha_fwd_bshd.py b/examples/flash_attention_sm100/mha_fwd_bshd.py index 97c231ad6c..acf5a7c3df 100644 --- a/examples/flash_attention_sm100/mha_fwd_bshd.py +++ b/examples/flash_attention_sm100/mha_fwd_bshd.py @@ -94,13 +94,12 @@ def main( T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) # GEMM 1: S = Q @ K^T -> S_tmem (tcgen05mma_ss) - T.gemm( + T.tcgen05_gemm( Q_shared, K_shared, S_tmem, transpose_B=True, mbar=mbar_s, - wg_wait=-1, clear_accum=True, ) T.mbarrier_wait_parity(mbar_s, k % 2) @@ -150,12 +149,11 @@ def main( T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) # GEMM 2: D = P @ V -> D_tmem (ss: mma_ss; ts: mma_ts) - T.gemm( + T.tcgen05_gemm( P_operand, V_shared, D_tmem, mbar=mbar_d, - wg_wait=-1, clear_accum=True, ) T.mbarrier_wait_parity(mbar_d, k % 2) @@ -286,23 +284,21 @@ def main( T.mbarrier_wait_parity(mbar_bmm1_empty[stage_id], parity_inv) if stage_id == 0: - T.gemm( + T.tcgen05_gemm( Q_shared, K_shared_0, S_tmem, transpose_B=True, mbar=mbar_bmm1_full[stage_id], - wg_wait=-1, clear_accum=True, ) else: - T.gemm( + T.tcgen05_gemm( Q_shared, K_shared_1, S_tmem, transpose_B=True, mbar=mbar_bmm1_full[stage_id], - wg_wait=-1, clear_accum=True, ) T.mbarrier_arrive(mbar_dma1_empty[stage_id]) @@ -311,21 +307,19 @@ def main( T.mbarrier_wait_parity(mbar_dma2_full[stage_id], parity) if stage_id == 0: - T.gemm( + T.tcgen05_gemm( P_tmem, V_shared_0, O_tmem, mbar=mbar_bmm2_full[stage_id], - wg_wait=-1, clear_accum=is_clear_accum, ) else: - T.gemm( + T.tcgen05_gemm( P_tmem, V_shared_1, O_tmem, mbar=mbar_bmm2_full[stage_id], - wg_wait=-1, clear_accum=is_clear_accum, ) @@ -478,15 +472,15 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--batch", type=int, default=2) - parser.add_argument("--heads", type=int, default=4) - parser.add_argument("--seq_len", type=int, default=256) + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--heads", type=int, default=16) + parser.add_argument("--seq_len", type=int, default=16384) parser.add_argument("--dim", type=int, default=128) parser.add_argument("--is_causal", action="store_true") parser.add_argument( "--variant", choices=["ss", "ts", "wasp"], - default="ss", + default="wasp", help="ss: pipeline 128t; ts: single-path 256t mma_ts; wasp: warp-specialized (fallback to ts if fail)", ) args = parser.parse_args() diff --git a/src/transform/lower_shared_tmem.cc b/src/transform/lower_shared_tmem.cc index 679749d945..6e5f27d39b 100644 --- a/src/transform/lower_shared_tmem.cc +++ b/src/transform/lower_shared_tmem.cc @@ -242,6 +242,18 @@ class SharedTmemRewriter : public StmtExprMutator { auto new_buffer = buffer_remap_.at(old_buffer); int num_cols_allocated = GetNumColsAllocated(old_buffer); + // Check that the number of rows doesn't exceed the tmem limit + { + auto analyzer = std::make_shared(); + arith::ConstIntBound phy_row_bounds = + analyzer->const_int_bound(old_buffer->shape[0]); + int num_rows_required = phy_row_bounds->max_value; + ICHECK(num_rows_required <= 128) + << "The number of rows required for tmem buffer " + << old_buffer->name << " is " << num_rows_required + << ", which exceeds the maximum of 128 rows"; + } + tmem_num_cols_allocated_.insert({data, num_cols_allocated}); tmem_call_annotations_.insert({data, tmem_call_ann}); From 0f7c21474d4a84f9d9a516e456e31811d58994d3 Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Tue, 31 Mar 2026 15:52:25 +0800 Subject: [PATCH 005/156] [BugFix] Add missing fences in GEMM SM100 examples and canonicalize the order of blockIdx (#1980) * Add TCGEN05 thread synchronization fences in GEMM examples and builtins - Introduced `tcgen05_before_thread_sync` and `tcgen05_after_thread_sync` builtins to manage thread synchronization in TCGEN05 operations. - Updated `gemm` and `gemm_2cta` examples to include synchronization calls before and after thread barriers, ensuring correct execution order. - Adjusted kernel launch parameters in GEMM examples for consistency in block dimensions. - Enhanced memory copy operations to align with the new synchronization logic, improving performance and correctness. * lint --- examples/gemm_sm100/gemm_tcgen5mma_ws.py | 24 +++++++++++-------- .../gemm_tcgen5mma_ws_persistent.py | 8 +++++++ src/op/builtin.cc | 10 ++++++++ src/op/builtin.h | 13 ++++++++++ src/target/codegen_cuda.cc | 10 ++++++++ src/target/codegen_cutedsl.cc | 10 ++++++++ tilelang/language/builtin.py | 8 +++++++ 7 files changed, 73 insertions(+), 10 deletions(-) diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws.py b/examples/gemm_sm100/gemm_tcgen5mma_ws.py index 188cd57864..ff147a2827 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma_ws.py +++ b/examples/gemm_sm100/gemm_tcgen5mma_ws.py @@ -16,7 +16,7 @@ def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_ B: T.Tensor[[K, N], in_dtype] C = T.empty((M, N), out_dtype) - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by): A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype) C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) @@ -35,12 +35,12 @@ def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_ for k in T.serial(k_iters): T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) T.tma_copy( - A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :], barrier=loaded[k % num_stages], ) T.tma_copy( - B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], + B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared[k % num_stages, :, :], barrier=loaded[k % num_stages], ) @@ -48,6 +48,7 @@ def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_ elif tx < 64: # warp 1: issue tcgen5 for k in T.serial(k_iters): T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1) + T.tcgen05_after_thread_sync() T.tcgen05_gemm( A_shared[k % num_stages, :, :], B_shared[k % num_stages, :, :], @@ -59,13 +60,14 @@ def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_ # Wait for all tcgen5 to finish T.mbarrier_wait_parity(tmem_full, 0) + T.tcgen05_after_thread_sync() T.copy(C_tmem, C_local) if use_tma_store: T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) + T.copy(C_shared, C[bx * block_M, by * block_N]) else: T.copy(C_local, C_local_cast) - T.copy(C_local_cast, C[by * block_M, bx * block_N]) # STG256 + T.copy(C_local_cast, C[bx * block_M, by * block_N]) # STG256 return C @@ -79,7 +81,7 @@ def gemm_2cta(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, B: T.Tensor[[K, N], in_dtype] C = T.empty((M, N), out_dtype) - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128, cluster_dims=2) as (by, bx): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128, cluster_dims=2) as (bx, by): A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) B_shared = T.alloc_shared((num_stages, block_K, block_N // 2), in_dtype) # Each cta hold half of B C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) @@ -100,12 +102,12 @@ def gemm_2cta(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, for k in T.serial(k_iters): T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) T.tma_copy( - A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :], barrier=loaded[k % num_stages], ) T.tma_copy( - B[k * block_K : (k + 1) * block_K, (bx * 2 + cta_id) * (block_N // 2) : (bx * 2 + cta_id + 1) * (block_N // 2)], + B[k * block_K : (k + 1) * block_K, (by * 2 + cta_id) * (block_N // 2) : (by * 2 + cta_id + 1) * (block_N // 2)], B_shared[k % num_stages, :, :], barrier=loaded[k % num_stages], ) @@ -113,6 +115,7 @@ def gemm_2cta(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, elif cta_id == 0 and tx < 64: # Only warp 1 on leader cta issues tcgen5 for k in T.serial(k_iters): T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1) + T.tcgen05_after_thread_sync() T.tcgen05_gemm( A_shared[k % num_stages, :, :], B_shared[k % num_stages, :, :], @@ -125,13 +128,14 @@ def gemm_2cta(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, # Wait for all tcgen5 to finish T.mbarrier_wait_parity(tmem_full, 0) + T.tcgen05_after_thread_sync() T.copy(C_tmem, C_local) if use_tma_store: T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) + T.copy(C_shared, C[bx * block_M, by * block_N]) else: T.copy(C_local, C_local_cast) - T.copy(C_local_cast, C[by * block_M, bx * block_N]) + T.copy(C_local_cast, C[bx * block_M, by * block_N]) return C diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py b/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py index 2db2f1f1be..82a82aaa1a 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py +++ b/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py @@ -81,9 +81,11 @@ def gemm_persistent( if bx * block_M < M and by * block_N < N: T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1) + T.tcgen05_after_thread_sync() for k in T.serial(k_blocks): phase = w * k_blocks + k T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1) + T.tcgen05_after_thread_sync() if w & 1 == 0: T.tcgen05_gemm( A_shared[k % num_stages, :, :], @@ -114,11 +116,13 @@ def gemm_persistent( if bx * block_M < M and by * block_N < N: T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1) + T.tcgen05_after_thread_sync() T.sync_threads(1, 128) if (w & 1) == 0: T.copy(C_tmem_0, C_local) else: T.copy(C_tmem_1, C_local) + T.tcgen05_before_thread_sync() T.mbarrier_arrive(tmem_empty[w & 1]) if use_tma_store: @@ -216,9 +220,11 @@ def gemm_persistent_2cta( if bx * block_M < M and by * block_N < N: T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1) + T.tcgen05_after_thread_sync() for k in T.serial(k_blocks): phase = w * k_blocks + k T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1) + T.tcgen05_after_thread_sync() if w & 1 == 0: T.tcgen05_gemm( A_shared[phase % num_stages, :, :], @@ -250,11 +256,13 @@ def gemm_persistent_2cta( if bx * block_M < M and by * block_N < N: T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1) + T.tcgen05_after_thread_sync() T.sync_threads(1, 128) if (w & 1) == 0: T.copy(C_tmem_0, C_local) else: T.copy(C_tmem_1, C_local) + T.tcgen05_before_thread_sync() T.mbarrier_arrive(tmem_empty[w & 1], 0) if use_tma_store: diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 15c1459edb..59b726c51a 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -478,6 +478,16 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(tcgen05_before_thread_sync) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tcgen05_after_thread_sync) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(warp_reduce_sum) .set_num_inputs(1) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index db13db5d1f..0d8a570253 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -705,6 +705,19 @@ TVM_DLL const Op &initialize_tcgen05_descriptor(); */ TVM_DLL const Op &tcgen05_mma_arrive(); +/*! + * \brief TCGEN05 fence before a thread-block-wide sync (__syncthreads / + * bar.sync). Matches PTX \c tcgen05.fence::before_thread_sync (DeepGEMM / + * Blackwell UMMA sequencing). + */ +TVM_DLL const Op &tcgen05_before_thread_sync(); + +/*! + * \brief TCGEN05 fence after a thread-block-wide sync. Matches PTX \c + * tcgen05.fence::after_thread_sync. + */ +TVM_DLL const Op &tcgen05_after_thread_sync(); + /*! * \brief tilelang intrinsic for setting the start address of a descriptor * buffer for wgmma/utcmma. diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index b7ef2e0163..10ae3e9796 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2622,6 +2622,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ss << ""; } print_extern_call_stmt(ss.str()); + } else if (op->op.same_as(tl::tcgen05_before_thread_sync())) { + ICHECK_EQ(op->args.size(), 0U) + << "tcgen05_before_thread_sync expects no arguments"; + need_tcgen05_common_h_ = true; + print_extern_call_stmt("tl::tcgen05_before_thread_sync"); + } else if (op->op.same_as(tl::tcgen05_after_thread_sync())) { + ICHECK_EQ(op->args.size(), 0U) + << "tcgen05_after_thread_sync expects no arguments"; + need_tcgen05_common_h_ = true; + print_extern_call_stmt("tl::tcgen05_after_thread_sync"); } else if (op->op.same_as(builtin::ptx_ldmatrix())) { // arg 0: whether the matrix is loaded in column major format or not. // arg 1: number of matrices to load. diff --git a/src/target/codegen_cutedsl.cc b/src/target/codegen_cutedsl.cc index b5fd373bc9..d7cf6b2cf0 100644 --- a/src/target/codegen_cutedsl.cc +++ b/src/target/codegen_cutedsl.cc @@ -903,6 +903,16 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op, ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument"; PrintIndent(); stream << "tl.tcgen05_mma_arrive(" << PrintExpr_(op->args[0]) << ")\n"; + } else if (op->op.same_as(tl::tcgen05_before_thread_sync())) { + ICHECK_EQ(op->args.size(), 0U) + << "tcgen05_before_thread_sync expects no arguments"; + PrintIndent(); + stream << "tl.tcgen05_before_thread_sync()\n"; + } else if (op->op.same_as(tl::tcgen05_after_thread_sync())) { + ICHECK_EQ(op->args.size(), 0U) + << "tcgen05_after_thread_sync expects no arguments"; + PrintIndent(); + stream << "tl.tcgen05_after_thread_sync()\n"; } else if (op->op.same_as(builtin::ptx_ldmatrix())) { // arg 0: whether the matrix is loaded in column major format or not. // arg 1: number of matrices to load. diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index c76bd42fb5..7ec535c5c9 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -1061,6 +1061,14 @@ def tcgen05_mma_arrive(mbar: tir.Buffer | BufferLoad | PrimExpr, arrive_2cta: bo return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar, annotations=ann) +def tcgen05_before_thread_sync(): + return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_before_thread_sync")) + + +def tcgen05_after_thread_sync(): + return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_after_thread_sync")) + + def ptx_mma_sm70( shape, A_layout, From eb6f05c753876de79989a0dd9290576a2442840e Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Tue, 31 Mar 2026 22:07:13 +0800 Subject: [PATCH 006/156] [Refactor] Refactor CUDA atomic helpers (#2001) * [Refactor] Refactor CUDA atomic helpers * update --- src/tl_templates/cuda/atomic.h | 630 +++++++++++++++------------------ 1 file changed, 287 insertions(+), 343 deletions(-) diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index ecc550c041..9bd1ff6eab 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -49,6 +49,205 @@ template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) { } #endif +// Helpers for atomic operations + +namespace tl_atomic_detail { + +TL_DEVICE bool IsRelaxedMemoryOrder(int memory_order) { + return memory_order == int(cuda::memory_order_relaxed); +} + +TL_DEVICE bool IsReleaseLikeMemoryOrder(int memory_order) { + return memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume); +} + +TL_DEVICE bool IsAcquireMemoryOrder(int memory_order) { + return memory_order == int(cuda::memory_order_acquire); +} + +TL_DEVICE bool IsAcqRelLikeMemoryOrder(int memory_order) { + return memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst); +} + +template TL_DEVICE unsigned short PackBits16(const T &val) { + return *reinterpret_cast(&val); +} + +template TL_DEVICE T UnpackBits16(unsigned short val) { + return *reinterpret_cast(&val); +} + +TL_DEVICE void tl_atomic_add_f16(unsigned short &ret, unsigned long long addr, + unsigned short val, int memory_order) { + if (IsReleaseLikeMemoryOrder(memory_order)) { + asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret) + : "l"(addr), "h"(val) + : "memory"); + } else if (IsAcquireMemoryOrder(memory_order)) { + asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret) + : "l"(addr), "h"(val) + : "memory"); + } else if (IsAcqRelLikeMemoryOrder(memory_order)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret) + : "l"(addr), "h"(val) + : "memory"); + } +} + +TL_DEVICE void tl_atomic_add_bf16(unsigned short &ret, unsigned long long addr, + unsigned short val, int memory_order) { + if (IsReleaseLikeMemoryOrder(memory_order)) { + asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret) + : "l"(addr), "h"(val) + : "memory"); + } else if (IsAcquireMemoryOrder(memory_order)) { + asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret) + : "l"(addr), "h"(val) + : "memory"); + } else if (IsAcqRelLikeMemoryOrder(memory_order)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret) + : "l"(addr), "h"(val) + : "memory"); + } +} + +TL_DEVICE void tl_atomic_add_v2_f16(unsigned short &ret_x, + unsigned short &ret_y, + unsigned long long addr, + unsigned short val_x, unsigned short val_y, + int memory_order) { + if (IsReleaseLikeMemoryOrder(memory_order)) { + asm volatile( + "atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_x), "=h"(ret_y) + : "l"(addr), "h"(val_x), "h"(val_y) + : "memory"); + } else if (IsAcquireMemoryOrder(memory_order)) { + asm volatile( + "atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_x), "=h"(ret_y) + : "l"(addr), "h"(val_x), "h"(val_y) + : "memory"); + } else if (IsAcqRelLikeMemoryOrder(memory_order)) { + asm volatile( + "atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_x), "=h"(ret_y) + : "l"(addr), "h"(val_x), "h"(val_y) + : "memory"); + } +} + +TL_DEVICE void tl_atomic_add_v2_bf16(unsigned short &ret_x, + unsigned short &ret_y, + unsigned long long addr, + unsigned short val_x, unsigned short val_y, + int memory_order) { + if (IsReleaseLikeMemoryOrder(memory_order)) { + asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_x), "=h"(ret_y) + : "l"(addr), "h"(val_x), "h"(val_y) + : "memory"); + } else if (IsAcquireMemoryOrder(memory_order)) { + asm volatile("atom.acquire.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_x), "=h"(ret_y) + : "l"(addr), "h"(val_x), "h"(val_y) + : "memory"); + } else if (IsAcqRelLikeMemoryOrder(memory_order)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_x), "=h"(ret_y) + : "l"(addr), "h"(val_x), "h"(val_y) + : "memory"); + } +} + +TL_DEVICE void tl_atomic_add_v2_f32(float &ret_x, float &ret_y, + unsigned long long addr, float val_x, + float val_y, int memory_order) { + if (IsReleaseLikeMemoryOrder(memory_order)) { + asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_x), "=f"(ret_y) + : "l"(addr), "f"(val_x), "f"(val_y) + : "memory"); + } else if (IsAcquireMemoryOrder(memory_order)) { + asm volatile("atom.acquire.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_x), "=f"(ret_y) + : "l"(addr), "f"(val_x), "f"(val_y) + : "memory"); + } else if (IsAcqRelLikeMemoryOrder(memory_order)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_x), "=f"(ret_y) + : "l"(addr), "f"(val_x), "f"(val_y) + : "memory"); + } +} + +TL_DEVICE void tl_atomic_add_v4_f32(float &ret_x, float &ret_y, float &ret_z, + float &ret_w, unsigned long long addr, + float val_x, float val_y, float val_z, + float val_w, int memory_order) { + if (IsReleaseLikeMemoryOrder(memory_order)) { + asm volatile( + "atom.release.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" + : "=f"(ret_x), "=f"(ret_y), "=f"(ret_z), "=f"(ret_w) + : "l"(addr), "f"(val_x), "f"(val_y), "f"(val_z), "f"(val_w) + : "memory"); + } else if (IsAcquireMemoryOrder(memory_order)) { + asm volatile( + "atom.acquire.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" + : "=f"(ret_x), "=f"(ret_y), "=f"(ret_z), "=f"(ret_w) + : "l"(addr), "f"(val_x), "f"(val_y), "f"(val_z), "f"(val_w) + : "memory"); + } else if (IsAcqRelLikeMemoryOrder(memory_order)) { + asm volatile( + "atom.acq_rel.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" + : "=f"(ret_x), "=f"(ret_y), "=f"(ret_z), "=f"(ret_w) + : "l"(addr), "f"(val_x), "f"(val_y), "f"(val_z), "f"(val_w) + : "memory"); + } +} + +// Fallback implementations: do atomicAdd sequentially. + +template TL_DEVICE void AtomicAddx2Scalar(T *ref, T x, T y) { + atomicAdd(ref + 0, x); + atomicAdd(ref + 1, y); +} + +template +TL_DEVICE void AtomicAddx4Scalar(T *ref, T x, T y, T z, T w) { + atomicAdd(ref + 0, x); + atomicAdd(ref + 1, y); + atomicAdd(ref + 2, z); + atomicAdd(ref + 3, w); +} + +TL_DEVICE float2 AtomicAddx2ScalarRet(float *ref, float2 add_val) { + float2 ret; + ret.x = atomicAdd(ref + 0, add_val.x); + ret.y = atomicAdd(ref + 1, add_val.y); + return ret; +} + +template +TL_DEVICE float4 AtomicAddx4ScalarRet(dst_dtype *ref, float4 add_val) { + float4 ret; + ret.x = atomicAdd(ref + 0, add_val.x); + ret.y = atomicAdd(ref + 1, add_val.y); + ret.z = atomicAdd(ref + 2, add_val.z); + ret.w = atomicAdd(ref + 3, add_val.w); + return ret; +} + +} // namespace tl_atomic_detail + template TL_DEVICE void AtomicMax(T1 *ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { @@ -180,63 +379,29 @@ TL_DEVICE void AtomicAdd(T1 *address, T2 val, using NT1 = typename normalize_atomic_type::type; if constexpr (std::is_same_v || std::is_same_v) { - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { atomicAdd(reinterpret_cast(address), static_cast(val)); } else { // Since atomic ref do not support memory order, we need to inline ptx // code here for each situation if constexpr (std::is_same_v) { // fp16 - __half ret_val; - unsigned short ret_val_cast = - *reinterpret_cast(&ret_val); + unsigned short ret_val_cast; unsigned long long ref_address = reinterpret_cast(address); - unsigned short val_cast = *reinterpret_cast(&val); - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } + unsigned short val_cast = + tl_atomic_detail::PackBits16(cuda_cast(val)); + tl_atomic_detail::tl_atomic_add_f16(ret_val_cast, ref_address, val_cast, + memory_order); } else if constexpr (std::is_same_v) { // bf16 - __nv_bfloat16 ret_val; - unsigned short ret_val_cast = - *reinterpret_cast(&ret_val); + unsigned short ret_val_cast; unsigned long long ref_address = reinterpret_cast(address); - unsigned short val_cast = *reinterpret_cast(&val); - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } + unsigned short val_cast = + tl_atomic_detail::PackBits16(cuda_cast(val)); + tl_atomic_detail::tl_atomic_add_bf16(ret_val_cast, ref_address, + val_cast, memory_order); } } } else { @@ -259,65 +424,32 @@ TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val, using NT1 = typename normalize_atomic_type::type; if constexpr (std::is_same_v || std::is_same_v) { - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { return static_cast( atomicAdd(reinterpret_cast(address), static_cast(val))); } else { if constexpr (std::is_same_v) { // fp16 - __half ret_val; - unsigned short ret_val_cast = - *reinterpret_cast(&ret_val); + unsigned short ret_val_cast; unsigned long long ref_address = reinterpret_cast(address); - unsigned short val_cast = *reinterpret_cast(&val); - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } - return static_cast(*reinterpret_cast<__half *>(&ret_val_cast)); + unsigned short val_cast = + tl_atomic_detail::PackBits16(cuda_cast(val)); + tl_atomic_detail::tl_atomic_add_f16(ret_val_cast, ref_address, val_cast, + memory_order); + return static_cast( + tl_atomic_detail::UnpackBits16<__half>(ret_val_cast)); } else if constexpr (std::is_same_v) { // bf16 - __nv_bfloat16 ret_val; - unsigned short ret_val_cast = - *reinterpret_cast(&ret_val); + unsigned short ret_val_cast; unsigned long long ref_address = reinterpret_cast(address); - unsigned short val_cast = *reinterpret_cast(&val); - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;" - : "=h"(ret_val_cast) - : "l"(ref_address), "h"(val_cast) - : "memory"); - } + unsigned short val_cast = + tl_atomic_detail::PackBits16(cuda_cast(val)); + tl_atomic_detail::tl_atomic_add_bf16(ret_val_cast, ref_address, + val_cast, memory_order); return static_cast( - *reinterpret_cast<__nv_bfloat16 *>(&ret_val_cast)); + tl_atomic_detail::UnpackBits16<__nv_bfloat16>(ret_val_cast)); } } } else { @@ -350,7 +482,7 @@ template TL_DEVICE void AtomicAddx2(half_t *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { half2 add_val = ToHalf2(val); - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { atomicAdd(reinterpret_cast(ref), add_val); } else { // Since atomicAdd does not support memory order, atomic_ref does not @@ -358,37 +490,14 @@ TL_DEVICE void AtomicAddx2(half_t *ref, ValType val, // Note: Vectorized atomic operations only support global space // Note: for 16-bit value, we need to reinterpret_cast the value to unsigned // short and use "h" register in assembly - unsigned short add_val_x_cast = - *reinterpret_cast(&add_val.x); - unsigned short add_val_y_cast = - *reinterpret_cast(&add_val.y); + unsigned short add_val_x_cast = tl_atomic_detail::PackBits16(add_val.x); + unsigned short add_val_y_cast = tl_atomic_detail::PackBits16(add_val.y); unsigned long long ref_addr = reinterpret_cast(ref); - __half ret_val_x, ret_val_y; - unsigned short ret_val_x_cast = - *reinterpret_cast(&ret_val_x); - unsigned short ret_val_y_cast = - *reinterpret_cast(&ret_val_y); - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile( - "atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile( - "atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile( - "atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } + unsigned short ret_val_x_cast; + unsigned short ret_val_y_cast; + tl_atomic_detail::tl_atomic_add_v2_f16(ret_val_x_cast, ret_val_y_cast, + ref_addr, add_val_x_cast, + add_val_y_cast, memory_order); } } @@ -397,42 +506,19 @@ TL_DEVICE half2 AtomicAddx2Ret(half_t *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { half2 add_val = ToHalf2(val); - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { return atomicAdd(reinterpret_cast(ref), add_val); } else { - unsigned short add_val_x_cast = - *reinterpret_cast(&add_val.x); - unsigned short add_val_y_cast = - *reinterpret_cast(&add_val.y); + unsigned short add_val_x_cast = tl_atomic_detail::PackBits16(add_val.x); + unsigned short add_val_y_cast = tl_atomic_detail::PackBits16(add_val.y); unsigned long long ref_addr = reinterpret_cast(ref); - __half ret_val_x, ret_val_y; - unsigned short ret_val_x_cast = - *reinterpret_cast(&ret_val_x); - unsigned short ret_val_y_cast = - *reinterpret_cast(&ret_val_y); - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile( - "atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile( - "atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile( - "atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } - return half2(*reinterpret_cast<__half *>(&ret_val_x_cast), - *reinterpret_cast<__half *>(&ret_val_y_cast)); + unsigned short ret_val_x_cast; + unsigned short ret_val_y_cast; + tl_atomic_detail::tl_atomic_add_v2_f16(ret_val_x_cast, ret_val_y_cast, + ref_addr, add_val_x_cast, + add_val_y_cast, memory_order); + return half2(tl_atomic_detail::UnpackBits16<__half>(ret_val_x_cast), + tl_atomic_detail::UnpackBits16<__half>(ret_val_y_cast)); } } @@ -452,37 +538,17 @@ template TL_DEVICE void AtomicAddx2(bfloat16_t *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { __nv_bfloat162 add_val = ToBfloat162(val); - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { atomicAdd(reinterpret_cast<__nv_bfloat162 *>(ref), add_val); } else { - unsigned short add_val_x_cast = - *reinterpret_cast(&add_val.x); - unsigned short add_val_y_cast = - *reinterpret_cast(&add_val.y); + unsigned short add_val_x_cast = tl_atomic_detail::PackBits16(add_val.x); + unsigned short add_val_y_cast = tl_atomic_detail::PackBits16(add_val.y); unsigned long long ref_addr = reinterpret_cast(ref); - __nv_bfloat162 ret_val; - unsigned short ret_val_x_cast = - *reinterpret_cast(&ret_val.x); - unsigned short ret_val_y_cast = - *reinterpret_cast(&ret_val.y); - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } + unsigned short ret_val_x_cast; + unsigned short ret_val_y_cast; + tl_atomic_detail::tl_atomic_add_v2_bf16(ret_val_x_cast, ret_val_y_cast, + ref_addr, add_val_x_cast, + add_val_y_cast, memory_order); } } @@ -490,80 +556,51 @@ template TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *ref, src_type *val, int memory_order = int(cuda::memory_order_relaxed)) { - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { return atomicAdd(reinterpret_cast<__nv_bfloat162 *>(ref), static_cast<__nv_bfloat162>( *reinterpret_cast(val))); } else { __nv_bfloat162 add_val = *reinterpret_cast(val); - unsigned short add_val_x_cast = - *reinterpret_cast(&add_val.x); - unsigned short add_val_y_cast = - *reinterpret_cast(&add_val.y); + unsigned short add_val_x_cast = tl_atomic_detail::PackBits16(add_val.x); + unsigned short add_val_y_cast = tl_atomic_detail::PackBits16(add_val.y); unsigned long long ref_addr = reinterpret_cast(ref); - __nv_bfloat162 ret_val; - unsigned short ret_val_x_cast = - *reinterpret_cast(&ret_val.x); - unsigned short ret_val_y_cast = - *reinterpret_cast(&ret_val.y); - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } - return __nv_bfloat162(*reinterpret_cast<__nv_bfloat16 *>(&ret_val_x_cast), - *reinterpret_cast<__nv_bfloat16 *>(&ret_val_y_cast)); + unsigned short ret_val_x_cast; + unsigned short ret_val_y_cast; + tl_atomic_detail::tl_atomic_add_v2_bf16(ret_val_x_cast, ret_val_y_cast, + ref_addr, add_val_x_cast, + add_val_y_cast, memory_order); + return __nv_bfloat162( + tl_atomic_detail::UnpackBits16<__nv_bfloat16>(ret_val_x_cast), + tl_atomic_detail::UnpackBits16<__nv_bfloat16>(ret_val_y_cast)); } } #endif -#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) template TL_DEVICE float2 ToFloat2(T *val) { return *reinterpret_cast(val); } TL_DEVICE float2 ToFloat2(float2 val) { return val; } +template TL_DEVICE float4 ToFloat4(T *val) { + return *reinterpret_cast(val); +} + +TL_DEVICE float4 ToFloat4(float4 val) { return val; } + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) template TL_DEVICE void AtomicAddx2(float *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { float2 add_val = ToFloat2(val); - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { atomicAdd(reinterpret_cast(ref), add_val); } else { unsigned long long ref_addr = reinterpret_cast(ref); float2 ret_val; - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" - : "=f"(ret_val.x), "=f"(ret_val.y) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" - : "=f"(ret_val.x), "=f"(ret_val.y) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" - : "=f"(ret_val.x), "=f"(ret_val.y) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) - : "memory"); - } + tl_atomic_detail::tl_atomic_add_v2_f32(ret_val.x, ret_val.y, ref_addr, + add_val.x, add_val.y, memory_order); } } @@ -572,44 +609,22 @@ TL_DEVICE float2 AtomicAddx2Ret(float *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { float2 add_val = ToFloat2(val); - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { return atomicAdd(reinterpret_cast(ref), add_val); } else { unsigned long long ref_addr = reinterpret_cast(ref); float2 ret_val; - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" - : "=f"(ret_val.x), "=f"(ret_val.y) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" - : "=f"(ret_val.x), "=f"(ret_val.y) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" - : "=f"(ret_val.x), "=f"(ret_val.y) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) - : "memory"); - } + tl_atomic_detail::tl_atomic_add_v2_f32(ret_val.x, ret_val.y, ref_addr, + add_val.x, add_val.y, memory_order); return ret_val; } } -template TL_DEVICE float4 ToFloat4(T *val) { - return *reinterpret_cast(val); -} - -TL_DEVICE float4 ToFloat4(float4 val) { return val; } - template TL_DEVICE void AtomicAddx4(dst_dtype *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { float4 add_val = ToFloat4(val); - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { atomicAdd(reinterpret_cast(ref), add_val); } else { // Since atomicAdd does not support memory order, atomic_ref does not @@ -617,33 +632,9 @@ TL_DEVICE void AtomicAddx4(dst_dtype *ref, ValType val, // Note: Vectorized atomic operations only support global space unsigned long long ref_addr = reinterpret_cast(ref); float4 ret_val; - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " - "{%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), - "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), - "f"(add_val.z), "f"(add_val.w) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " - "{%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), - "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), - "f"(add_val.z), "f"(add_val.w) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " - "{%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), - "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), - "f"(add_val.z), "f"(add_val.w) - : "memory"); - } + tl_atomic_detail::tl_atomic_add_v4_f32( + ret_val.x, ret_val.y, ret_val.z, ret_val.w, ref_addr, add_val.x, + add_val.y, add_val.z, add_val.w, memory_order); } } @@ -652,61 +643,24 @@ TL_DEVICE float4 AtomicAddx4Ret(dst_dtype *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { float4 add_val = ToFloat4(val); - if (memory_order == int(cuda::memory_order_relaxed)) { + if (tl_atomic_detail::IsRelaxedMemoryOrder(memory_order)) { return atomicAdd(reinterpret_cast(ref), add_val); } else { unsigned long long ref_addr = reinterpret_cast(ref); float4 ret_val; - if (memory_order == int(cuda::memory_order_release) || - memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.global.gpu.release.add.v4.f32 {%0,%1,%2,%3}, [%4], " - "{%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), - "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), - "f"(add_val.z), "f"(add_val.w) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.global.gpu.acquire.add.v4.f32 {%0,%1,%2,%3}, [%4], " - "{%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), - "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), - "f"(add_val.z), "f"(add_val.w) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || - memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.global.gpu.acq_rel.add.v4.f32 {%0,%1,%2,%3}, [%4], " - "{%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), - "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), - "f"(add_val.z), "f"(add_val.w) - : "memory"); - } + tl_atomic_detail::tl_atomic_add_v4_f32( + ret_val.x, ret_val.y, ret_val.z, ret_val.w, ref_addr, add_val.x, + add_val.y, add_val.z, add_val.w, memory_order); return ret_val; } } #else -template TL_DEVICE float2 ToFloat2(T *val) { - return *reinterpret_cast(val); -} - -TL_DEVICE float2 ToFloat2(float2 val) { return val; } - -template TL_DEVICE float4 ToFloat4(T *val) { - return *reinterpret_cast(val); -} - -TL_DEVICE float4 ToFloat4(float4 val) { return val; } - template TL_DEVICE void AtomicAddx2(float *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { (void)memory_order; float2 add_val = ToFloat2(val); - atomicAdd(ref + 0, add_val.x); - atomicAdd(ref + 1, add_val.y); + tl_atomic_detail::AtomicAddx2Scalar(ref, add_val.x, add_val.y); } template @@ -715,10 +669,7 @@ AtomicAddx2Ret(float *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { (void)memory_order; float2 add_val = ToFloat2(val); - float2 ret; - ret.x = atomicAdd(ref + 0, add_val.x); - ret.y = atomicAdd(ref + 1, add_val.y); - return ret; + return tl_atomic_detail::AtomicAddx2ScalarRet(ref, add_val); } template @@ -726,10 +677,8 @@ TL_DEVICE void AtomicAddx4(dst_dtype *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { (void)memory_order; float4 add_val = ToFloat4(val); - atomicAdd(ref + 0, add_val.x); - atomicAdd(ref + 1, add_val.y); - atomicAdd(ref + 2, add_val.z); - atomicAdd(ref + 3, add_val.w); + tl_atomic_detail::AtomicAddx4Scalar(ref, add_val.x, add_val.y, add_val.z, + add_val.w); } template @@ -738,12 +687,7 @@ AtomicAddx4Ret(dst_dtype *ref, ValType val, int memory_order = int(cuda::memory_order_relaxed)) { (void)memory_order; float4 add_val = ToFloat4(val); - float4 ret; - ret.x = atomicAdd(ref + 0, add_val.x); - ret.y = atomicAdd(ref + 1, add_val.y); - ret.z = atomicAdd(ref + 2, add_val.z); - ret.w = atomicAdd(ref + 3, add_val.w); - return ret; + return tl_atomic_detail::AtomicAddx4ScalarRet(ref, add_val); } #endif From 8c3b043076b00c4ae5ffc27196b2457e095a56ab Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 31 Mar 2026 22:36:01 +0800 Subject: [PATCH 007/156] [Bugfix] Fix CuTeDSL autotune cache invalid ELF header (#1967) (#1972) * [Bugfix] Fix CuTeDSL autotune cache saving .py as .so (#1967) The autotune cache had no CuTeDSL-specific branch, causing it to save the Python source file (kernel.py) as kernel_lib.so. On reload, importlib treated the .so extension as a native extension module and failed with "invalid ELF header". Fix: add cutedsl branches in _save_kernel_to_disk and _load_kernel_from_disk to use KERNEL_PY_PATH ("kernel.py") instead of KERNEL_LIB_PATH ("kernel_lib.so"). Also saves launcher .so and cubin artifacts when present. Closes #1967 Co-Authored-By: Claude Opus 4.6 * [BugFix] Change _all_dtypes from set to list for deterministic order set has non-deterministic iteration order across processes, causing pytest-xdist workers to collect test parameters in different orders and fail with "Different tests were collected between gw3 and gw2". Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 --- .../test_tilelang_autotune_cutedsl_cache.py | 120 ++++++++++++++++++ tilelang/autotuner/param.py | 32 ++++- tilelang/jit/kernel.py | 2 +- tilelang/language/dtypes.py | 4 +- 4 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 testing/python/autotune/test_tilelang_autotune_cutedsl_cache.py diff --git a/testing/python/autotune/test_tilelang_autotune_cutedsl_cache.py b/testing/python/autotune/test_tilelang_autotune_cutedsl_cache.py new file mode 100644 index 0000000000..bfdf58fc8e --- /dev/null +++ b/testing/python/autotune/test_tilelang_autotune_cutedsl_cache.py @@ -0,0 +1,120 @@ +"""Regression test for #1967: CuTeDSL autotune cache saved .py as .so → "invalid ELF header".""" + +import os +import pytest +import tilelang +import tilelang.testing +import tilelang.language as T +from tilelang.autotuner.param import AutotuneResult +from tilelang.env import env + + +def test_cutedsl_save_creates_kernel_py(tmp_path): + """_save_kernel_to_disk should write kernel.py (not kernel_lib.so) for CuTeDSL.""" + original_tmp_dir = env.TILELANG_TMP_DIR + env.TILELANG_TMP_DIR = str(tmp_path / "tmp") + os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True) + + try: + src_dir = tmp_path / "src" + src_dir.mkdir() + (src_dir / "kernel.py").write_text("# cutedsl kernel\n") + (src_dir / "kernel.cubin").write_bytes(b"fake_cubin") + + class FakeLibGen: + launcher_libpath = None + + class FakeAdapter: + libpath = str(src_dir / "kernel.py") + lib_generator = FakeLibGen() + + def get_kernel_source(self, kernel_only=True): + return "# src" + + class FakeKernel: + execution_backend = "cutedsl" + adapter = FakeAdapter() + kernel_source = "# src" + params = [] + + cache = tmp_path / "cache" + cache.mkdir() + AutotuneResult()._save_kernel_to_disk(cache, FakeKernel()) + + assert (cache / "kernel.py").exists() + assert not (cache / "kernel_lib.so").exists() + assert (cache / "kernel.cubin").exists() + finally: + env.TILELANG_TMP_DIR = original_tmp_dir + + +def _is_cutedsl_available(): + try: + from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available + + check_cutedsl_available() + return True + except (ImportError, AssertionError): + return False + + +# Define autotune kernel at module level so closures don't capture module objects +def _make_vec_add_autotuned(): + from tilelang.autotuner import autotune + + @autotune(configs=[{"threads": t} for t in (128, 256)], warmup=3, rep=5) + @tilelang.jit(out_idx=[-1], target="cutedsl") + def vec_add(n: int, dtype: str = "float32", threads: int = 128): + num_blocks = n // threads + + @T.prim_func + def kernel(a: T.Tensor((n,), dtype), b: T.Tensor((n,), dtype), c: T.Tensor((n,), dtype)): + with T.Kernel(num_blocks, threads=threads) as bx: + for i in T.Parallel(threads): + c[bx * threads + i] = a[bx * threads + i] + b[bx * threads + i] + + return kernel + + return vec_add + + +@tilelang.testing.requires_cuda +@pytest.mark.skipif(not _is_cutedsl_available(), reason="CuTeDSL not installed") +def test_cutedsl_autotune_cache_roundtrip(tmp_path): + """Autotune + CuTeDSL: save → reload from disk → verify correctness.""" + import torch + from tilelang.autotuner import AutoTuner + + original_cache_dir, original_tmp_dir = env.TILELANG_CACHE_DIR, env.TILELANG_TMP_DIR + env.TILELANG_CACHE_DIR = str(tmp_path / "cache") + env.TILELANG_TMP_DIR = str(tmp_path / "tmp") + os.makedirs(env.TILELANG_CACHE_DIR, exist_ok=True) + os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True) + original_cache_enabled = env.is_cache_enabled() + tilelang.enable_cache() + AutoTuner._memory_cache.clear() + + try: + vec_add = _make_vec_add_autotuned() + N = 256 + a = torch.randn(N, device="cuda", dtype=torch.float32) + b = torch.randn(N, device="cuda", dtype=torch.float32) + ref = a + b + + # Pass 1: cache miss + torch.testing.assert_close(vec_add(N)(a, b), ref, atol=1e-5, rtol=1e-5) + + # Pass 2: clear memory cache → force disk reload (was "invalid ELF header" before fix) + AutoTuner._memory_cache.clear() + vec_add._tuner_cache.clear() + torch.testing.assert_close(vec_add(N)(a, b), ref, atol=1e-5, rtol=1e-5) + finally: + env.TILELANG_CACHE_DIR = original_cache_dir + env.TILELANG_TMP_DIR = original_tmp_dir + if not original_cache_enabled: + tilelang.disable_cache() + AutoTuner._memory_cache.clear() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index 4542099868..b3a8e6ef99 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -224,6 +224,9 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo kernel_lib_file = KERNEL_CUBIN_PATH elif kernel.execution_backend == "tvm_ffi": kernel_lib_file = EXECUTABLE_PATH + elif kernel.execution_backend == "cutedsl": + # cutedsl only generates a Python source file as the "library", so save that instead of a .so + kernel_lib_file = KERNEL_PY_PATH else: kernel_lib_file = KERNEL_LIB_PATH @@ -251,6 +254,31 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo if verbose: logger.debug(f"Saving kernel executable to file: {kernel_lib_path}") self._safe_write_executable(executable, kernel_lib_path) + elif kernel.execution_backend == "cutedsl": + # Save the Python source file (CuTeDSL "library" is a .py, not a .so) + src_lib_path = kernel.adapter.libpath + if verbose: + logger.debug(f"Saving CuTeDSL kernel Python source to file: {kernel_lib_path}") + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) + + # Save launcher .so if present (compiled C++ launcher for TMA etc.) + lib_gen = kernel.adapter.lib_generator + launcher_src = getattr(lib_gen, "launcher_libpath", None) + if launcher_src and os.path.exists(launcher_src): + launcher_name = getattr(lib_gen, "launcher_libname", os.path.basename(launcher_src)) + dst_launcher = os.path.join(cache_path, launcher_name) + if verbose: + logger.debug(f"Saving CuTeDSL launcher library to file: {dst_launcher}") + self._safe_write_file(dst_launcher, "wb", lambda f: f.write(self._load_binary(launcher_src))) + + # Save cubin if already generated (generated during autotuning benchmark) + src_dir = os.path.dirname(src_lib_path) + src_cubin = os.path.join(src_dir, "kernel.cubin") + if os.path.exists(src_cubin): + dst_cubin = os.path.join(cache_path, KERNEL_CUBIN_PATH) + if verbose: + logger.debug(f"Saving CuTeDSL cubin to file: {dst_cubin}") + self._safe_write_file(dst_cubin, "wb", lambda f: f.write(self._load_binary(src_cubin))) else: src_lib_path = kernel.adapter.libpath if verbose: @@ -275,7 +303,7 @@ def _load_kernel_from_disk( target: str | Target = "auto", target_host: str | Target = None, out_idx: list[int] | int | None = None, - execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch"] = "tvm_ffi", + execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi", pass_configs: dict = None, compile_flags: list[str] | str | None = None, func: Callable = None, @@ -306,6 +334,8 @@ def _load_kernel_from_disk( kernel_lib_file = KERNEL_CUBIN_PATH elif execution_backend == "tvm_ffi": kernel_lib_file = EXECUTABLE_PATH + elif execution_backend == "cutedsl": + kernel_lib_file = KERNEL_PY_PATH else: kernel_lib_file = KERNEL_LIB_PATH diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index bd7533787b..d2a594d28f 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -156,7 +156,7 @@ def from_database( target: str | Target, target_host: str | Target, out_idx: list[int] | int, - execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch"], + execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, ): diff --git a/tilelang/language/dtypes.py b/tilelang/language/dtypes.py index 88fe8ad73f..28a9d7cd4a 100644 --- a/tilelang/language/dtypes.py +++ b/tilelang/language/dtypes.py @@ -591,7 +591,7 @@ class bfloat16x2(dtype): ... bfloat16 = dtype("bfloat16") bfloat16x2 = dtype("bfloat16x2") -_all_dtypes = { +_all_dtypes = [ "bool", "short", "int", @@ -757,7 +757,7 @@ class bfloat16x2(dtype): ... "float4_e2m1fnx64", "bfloat16", "bfloat16x2", -} +] __all__ = list(_all_dtypes) + [ "dtype", From a82fa719b8000fe629e6d59dca97139f8a880524 Mon Sep 17 00:00:00 2001 From: William Date: Wed, 1 Apr 2026 01:29:17 +0800 Subject: [PATCH 008/156] fix: fix copy+cast vectorize loop to use wider vector load/store instrcution (#2004) * fix copy+cast vectorize loop to use wider vector load/store instrcution * clean test * fix test * fix format * test fix --------- Co-authored-by: LeiWang1999 --- examples/gemm/test_example_gemm.py | 5 - src/transform/loop_vectorize.cc | 26 ++- ...t_tilelang_transform_decouple_type_cast.py | 168 ++++++++++++++++++ 3 files changed, 186 insertions(+), 13 deletions(-) diff --git a/examples/gemm/test_example_gemm.py b/examples/gemm/test_example_gemm.py index 5f69364be6..fb0ae3ab4b 100644 --- a/examples/gemm/test_example_gemm.py +++ b/examples/gemm/test_example_gemm.py @@ -1,7 +1,6 @@ import tilelang.testing import example_gemm_autotune import example_gemm_intrinsics -import example_gemm_schedule import example_gemm @@ -14,10 +13,6 @@ def test_example_gemm_intrinsics(): example_gemm_intrinsics.main(M=1024, N=1024, K=1024) -def test_example_gemm_schedule(): - example_gemm_schedule.main() - - def test_example_gemm(): example_gemm.main() diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index c192cd3356..d612208438 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -84,6 +84,7 @@ struct BufferVectorInfo { int vector_size; bool is_store; Array indices; + bool is_cast = false; // true for CastNode constraints (vs CallNode) }; Array GetBufferStrides(const Buffer &buffer) { @@ -215,6 +216,7 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { int local_fragment_min = initial_vector_size_; int memory_min = initial_vector_size_; int call_node_min = initial_vector_size_; + int non_cast_call_node_min = initial_vector_size_; bool has_global_or_shared_buffer = false; auto is_local_or_fragment = [](const Buffer &buf) { @@ -232,13 +234,16 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { << " -> vector_size=" << info.vector_size << (info.is_store ? " [store]" : " [load]") << "\n"; } else { - std::cerr << " [cast/call] -> vector_size=" << info.vector_size - << "\n"; + std::cerr << " [" << (info.is_cast ? "cast" : "call") + << "] -> vector_size=" << info.vector_size << "\n"; } } if (!buffer.defined()) { - // CastNode, CallNode do not have buffer defined. call_node_min = arith::ZeroAwareGCD(call_node_min, info.vector_size); + if (!info.is_cast) { + non_cast_call_node_min = + arith::ZeroAwareGCD(non_cast_call_node_min, info.vector_size); + } } else if (is_local_or_fragment(buffer)) { local_fragment_min = arith::ZeroAwareGCD(local_fragment_min, info.vector_size); @@ -269,12 +274,15 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { } } else if (has_global_or_shared_buffer) { // Has memory buffers and simple case (no SeqStmt): - // ignore local/fragment constraints - vector_size_ = arith::ZeroAwareGCD(memory_min, call_node_min); + // ignore local/fragment constraints AND cast constraints. + // Cast constraints are ignored because DecoupleTypeCast will later + // split mixed-type operations into separate loops, allowing memory + // copies to use wider vectors independently of cast width limits. + vector_size_ = arith::ZeroAwareGCD(memory_min, non_cast_call_node_min); if (verbose) { std::cerr << " [Strategy] Has memory buffers (simple case), using " - "memory_min=" - << memory_min + << "memory_min=" << memory_min + << ", non_cast_call_node_min=" << non_cast_call_node_min << " (ignoring local/fragment_min=" << local_fragment_min << ")" << "\n"; } @@ -655,7 +663,9 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { } int cast_vector_size = arith::ZeroAwareGCD(max_lanes, initial_vector_size_); // Record cast constraint (use empty buffer to indicate cast) - buffer_vector_infos_.push_back({Buffer(), cast_vector_size, false, {}}); + // Mark is_cast=true so Plan() can distinguish cast from other call nodes + buffer_vector_infos_.push_back( + {Buffer(), cast_vector_size, false, {}, /*is_cast=*/true}); return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } diff --git a/testing/python/transform/test_tilelang_transform_decouple_type_cast.py b/testing/python/transform/test_tilelang_transform_decouple_type_cast.py index fd9746e000..7a4fdaf081 100644 --- a/testing/python/transform/test_tilelang_transform_decouple_type_cast.py +++ b/testing/python/transform/test_tilelang_transform_decouple_type_cast.py @@ -211,5 +211,173 @@ def kernel(b: T.Tensor[(16,), T.float32]): assert "local_cast" not in source, "Should not have cast buffer when dtypes match" +# ============================================================================= +# End-to-end correctness + vectorization tests for DecoupleTypeCast +# ============================================================================= + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(10) +def test_e2e_bf16_global_to_frag(): + """bf16 global -> float32 frag -> bf16 global: roundtrip should be lossless. + + With 1024 bf16 elements and 64 threads, each thread handles 16 bf16 = 256 bits, + so the kernel should use 256-bit load/store (load_global_256 / store_global_256). + """ + import torch + + @tilelang.jit(out_idx=[1]) + def kernel_fn(): + @T.prim_func + def main( + A: T.Tensor((1024,), dtype=T.bfloat16), + B: T.Tensor((1024,), dtype=T.bfloat16), + ): + with T.Kernel(1, threads=64): + a_frag = T.alloc_fragment((1024,), dtype=T.float32) + T.copy(A, a_frag) + T.copy(a_frag, B) + + return main + + kernel = kernel_fn() + + # Check vectorization: 256-bit load/store + source = kernel.get_kernel_source() + assert "load_global_256" in source, "Expected 256-bit global load" + assert "store_global_256" in source, "Expected 256-bit global store" + + # Correctness + a = torch.randn(1024, device="cuda", dtype=torch.bfloat16) + b = kernel(a) + torch.testing.assert_close(b, a, rtol=0, atol=0) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(8) +def test_e2e_bf16_global_shared_frag(): + """bf16 global -> shared -> float32 frag -> bf16 global: roundtrip should be lossless. + + Shared memory path uses TMA for global->shared, then 128-bit for shared->local. + """ + import torch + + @tilelang.jit(out_idx=[1]) + def kernel_fn(): + @T.prim_func + def main( + A: T.Tensor((1024,), dtype=T.bfloat16), + B: T.Tensor((1024,), dtype=T.bfloat16), + ): + with T.Kernel(1, threads=64): + a_shared = T.alloc_shared((1024,), dtype=T.bfloat16) + a_frag = T.alloc_fragment((1024,), dtype=T.float32) + T.copy(A, a_shared) + T.copy(a_shared, a_frag) + T.copy(a_frag, B) + + return main + + kernel = kernel_fn() + + # Check: shared path should NOT use 256-bit (shared doesn't support it) + source = kernel.get_kernel_source() + assert "uint4" in source, f"Expected uint4 store in {source}" + + # Correctness + a = torch.randn(1024, device="cuda", dtype=torch.bfloat16) + b = kernel(a) + torch.testing.assert_close(b, a, rtol=0, atol=0) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9) +def test_e2e_fp8_global_to_frag(): + """fp8 global -> float32 frag -> fp8 global: roundtrip should be lossless. + + Verifies that cast constraints do not pollute the memory access layout. + With 1024 fp8 elements and 64 threads, each thread handles 16 fp8 = 128 bits, + so the kernel should use fp8_e4_16_t (128-bit) loads/stores. + """ + import torch + + @tilelang.jit(out_idx=[1]) + def kernel_fn(): + @T.prim_func + def main( + A: T.Tensor((1024,), dtype=T.float8_e4m3fn), + B: T.Tensor((1024,), dtype=T.float8_e4m3fn), + ): + with T.Kernel(1, threads=64): + a_frag = T.alloc_fragment((1024,), dtype=T.float32) + T.copy(A, a_frag) + T.copy(a_frag, B) + + return main + + kernel = kernel_fn() + source = kernel.get_kernel_source() + assert "fp8_e4_16_t" in source, ( + "Expected fp8_e4_16_t (128-bit) loads/stores for N=1024. Cast constraints may be polluting layout decisions." + ) + + a = (torch.randn(1024, device="cuda", dtype=torch.float32) * 0.5).to(torch.float8_e4m3fn) + b = kernel(a) + torch.testing.assert_close( + b.to(torch.float32), + a.to(torch.float32), + rtol=0, + atol=0, + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9) +def test_e2e_fp8_manual_decouple(): + """fp8 with manually decoupled copy stages: same result as auto-decoupled. + + Tests: fp8 global -> fp8 frag -> float32 frag -> fp8 frag -> fp8 global + """ + import torch + + @tilelang.jit(out_idx=[1]) + def kernel_fn(): + @T.prim_func + def main( + A: T.Tensor((1024,), dtype=T.float8_e4m3fn), + B: T.Tensor((1024,), dtype=T.float8_e4m3fn), + ): + with T.Kernel(1, threads=64): + a_frag = T.alloc_fragment((1024,), dtype=T.float8_e4m3fn) + b_frag = T.alloc_fragment((1024,), dtype=T.float32) + c_frag = T.alloc_fragment((1024,), dtype=T.float8_e4m3fn) + T.copy(A, a_frag) + T.copy(a_frag, b_frag) + T.copy(b_frag, c_frag) + T.copy(c_frag, B) + + return main + + kernel = kernel_fn() + + # Check vectorization + source = kernel.get_kernel_source() + assert "fp8_e4_16_t" in source, "Expected fp8_e4_16_t in kernel source" + + # Correctness + a = (torch.randn(1024, device="cuda", dtype=torch.float32) * 0.5).to(torch.float8_e4m3fn) + b = kernel(a) + torch.testing.assert_close( + b.to(torch.float32), + a.to(torch.float32), + rtol=0, + atol=0, + ) + + if __name__ == "__main__": test_no_transform_if_then_else_condition() + test_e2e_bf16_global_to_frag() + test_e2e_bf16_global_shared_frag() + test_e2e_fp8_global_to_frag() + test_e2e_fp8_manual_decouple() From e45ecf7dc4bcb62de96e863991a01e648e5c138c Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Thu, 2 Apr 2026 12:42:47 +0800 Subject: [PATCH 009/156] [Feature] Support T.annotate_compile_flags, T.annotate_pass_configs, and out_idx as PrimFunc attrs (#2006) Allow configuring pass configs, compile flags, and out_idx directly inside function bodies using T.annotate_compile_flags(), T.annotate_pass_configs(), and T.empty()+return. These are stored as proper PrimFunc attrs (tilelang_compile_flags, tilelang_pass_configs, tilelang_out_idx) instead of monkey-patching, and merged at compile time. Annotations can be placed before or after tensor type annotations. Co-authored-by: Claude Opus 4.6 --- .../test_tilelang_language_func_attrs.py | 214 ++++++++++++++++++ tilelang/autotuner/param.py | 4 +- tilelang/jit/__init__.py | 23 +- tilelang/language/eager/__init__.py | 2 +- tilelang/language/eager/builder.py | 72 +++++- 5 files changed, 303 insertions(+), 12 deletions(-) create mode 100644 testing/python/language/test_tilelang_language_func_attrs.py diff --git a/testing/python/language/test_tilelang_language_func_attrs.py b/testing/python/language/test_tilelang_language_func_attrs.py new file mode 100644 index 0000000000..a64ab73928 --- /dev/null +++ b/testing/python/language/test_tilelang_language_func_attrs.py @@ -0,0 +1,214 @@ +"""Test T.annotate_compile_flags, T.annotate_pass_configs, and out_idx via PrimFunc attrs.""" + +import pytest +import torch +import tilelang +from tilelang import language as T +from tilelang.transform import PassConfigKey + + +def test_out_idx_via_attr_lazy(): + """out_idx should be stored as PrimFunc attr when using T.empty + return.""" + + @T.prim_func + def kernel(A): + A: T.Tensor[[128, 128], T.float32] + B = T.empty([128, 128], T.float32) + with T.Kernel(1): + for i in T.serial(128): + for j in T.serial(128): + B[i, j] = A[i, j] + 1.0 + return B + + assert "tilelang_out_idx" in kernel.attrs + assert list(kernel.attrs["tilelang_out_idx"]) == [-1] + + compiled = tilelang.compile(kernel) + a = torch.randn(128, 128, device="cuda") + b = compiled(a) + torch.testing.assert_close(b, a + 1.0) + + +def test_all_attrs_together_lazy(): + """annotate_pass_configs, annotate_compile_flags, and out_idx should all work together.""" + + @T.prim_func + def kernel(A): + A: T.Tensor[[64, 64], T.float32] + T.annotate_pass_configs({PassConfigKey.TL_ENABLE_FAST_MATH: True}) + T.annotate_compile_flags(["--use_fast_math"]) + B = T.empty([64, 64], T.float32) + with T.Kernel(1): + for i in T.serial(64): + for j in T.serial(64): + B[i, j] = A[i, j] * 2.0 + return B + + attrs = kernel.attrs + assert "tilelang_out_idx" in attrs + assert "tilelang_pass_configs" in attrs + assert "tilelang_compile_flags" in attrs + + compiled = tilelang.compile(kernel) + a = torch.randn(64, 64, device="cuda") + b = compiled(a) + torch.testing.assert_close(b, a * 2.0) + + +def test_eager_mode_attrs(): + """Eager mode should support annotate_pass_configs and out_idx via T.empty.""" + + @tilelang.jit + def kernel(A): + M, N = T.const("M N") + A: T.Tensor[[M, N], T.float32] + B = T.empty([M, N], T.float32) + T.annotate_pass_configs({PassConfigKey.TL_ENABLE_FAST_MATH: True}) + with T.Kernel(1): + for i in T.serial(M): + for j in T.serial(N): + B[i, j] = A[i, j] + 1.0 + return B + + a = torch.randn(32, 32, device="cuda") + result = kernel(a) + torch.testing.assert_close(result, a + 1.0) + + +def test_out_idx_conflict_detection(): + """Specifying both T.empty return and external out_idx should raise ValueError.""" + + @T.prim_func + def kernel(A): + A: T.Tensor[[32, 32], T.float32] + B = T.empty([32, 32], T.float32) + with T.Kernel(1): + for i in T.serial(32): + for j in T.serial(32): + B[i, j] = A[i, j] + return B + + with pytest.raises(ValueError, match="Out index conflict"): + tilelang.compile(kernel, out_idx=[-1]) + + +def test_no_out_idx_when_not_using_empty(): + """When T.empty is not used, tilelang_out_idx attr should not be present.""" + + @T.prim_func + def kernel(A, B): + A: T.Tensor[[32, 32], T.float32] + B: T.Tensor[[32, 32], T.float32] + with T.Kernel(1): + for i in T.serial(32): + for j in T.serial(32): + B[i, j] = A[i, j] + + assert kernel.attrs is None or "tilelang_out_idx" not in kernel.attrs + + compiled = tilelang.compile(kernel, out_idx=[-1]) + a = torch.randn(32, 32, device="cuda") + b = compiled(a) + torch.testing.assert_close(b, a) + + +def test_pass_configs_only_lazy(): + """annotate_pass_configs should work without T.empty or annotate_compile_flags.""" + + @T.prim_func + def kernel(A, B): + A: T.Tensor[[32, 32], T.float32] + B: T.Tensor[[32, 32], T.float32] + T.annotate_pass_configs({PassConfigKey.TL_ENABLE_FAST_MATH: True}) + with T.Kernel(1): + for i in T.serial(32): + for j in T.serial(32): + B[i, j] = A[i, j] + 1.0 + + assert "tilelang_pass_configs" in kernel.attrs + assert kernel.attrs is None or "tilelang_out_idx" not in kernel.attrs + + compiled = tilelang.compile(kernel, out_idx=[-1]) + a = torch.randn(32, 32, device="cuda") + b = compiled(a) + torch.testing.assert_close(b, a + 1.0) + + +def test_compile_flags_only_lazy(): + """annotate_compile_flags should work standalone.""" + + @T.prim_func + def kernel(A, B): + A: T.Tensor[[32, 32], T.float32] + B: T.Tensor[[32, 32], T.float32] + T.annotate_compile_flags(["--use_fast_math"]) + with T.Kernel(1): + for i in T.serial(32): + for j in T.serial(32): + B[i, j] = A[i, j] + 1.0 + + assert "tilelang_compile_flags" in kernel.attrs + + compiled = tilelang.compile(kernel, out_idx=[-1]) + a = torch.randn(32, 32, device="cuda") + b = compiled(a) + torch.testing.assert_close(b, a + 1.0) + + +def test_annotations_before_tensor_type(): + """Annotations placed before tensor type annotations should work.""" + + @T.prim_func + def kernel(A, B): + T.annotate_pass_configs({PassConfigKey.TL_ENABLE_FAST_MATH: True}) + T.annotate_compile_flags(["--use_fast_math"]) + A: T.Tensor[[32, 32], T.float32] + B: T.Tensor[[32, 32], T.float32] + with T.Kernel(1): + for i in T.serial(32): + for j in T.serial(32): + B[i, j] = A[i, j] + 1.0 + + assert "tilelang_pass_configs" in kernel.attrs + assert "tilelang_compile_flags" in kernel.attrs + + compiled = tilelang.compile(kernel, out_idx=[-1]) + a = torch.randn(32, 32, device="cuda") + b = compiled(a) + torch.testing.assert_close(b, a + 1.0) + + +def test_annotations_after_tensor_type(): + """Annotations placed after tensor type annotations should work.""" + + @T.prim_func + def kernel(A, B): + A: T.Tensor[[32, 32], T.float32] + B: T.Tensor[[32, 32], T.float32] + T.annotate_pass_configs({PassConfigKey.TL_ENABLE_FAST_MATH: True}) + T.annotate_compile_flags(["--use_fast_math"]) + with T.Kernel(1): + for i in T.serial(32): + for j in T.serial(32): + B[i, j] = A[i, j] + 1.0 + + assert "tilelang_pass_configs" in kernel.attrs + assert "tilelang_compile_flags" in kernel.attrs + + compiled = tilelang.compile(kernel, out_idx=[-1]) + a = torch.randn(32, 32, device="cuda") + b = compiled(a) + torch.testing.assert_close(b, a + 1.0) + + +if __name__ == "__main__": + test_out_idx_via_attr_lazy() + test_all_attrs_together_lazy() + test_eager_mode_attrs() + test_out_idx_conflict_detection() + test_no_out_idx_when_not_using_empty() + test_pass_configs_only_lazy() + test_compile_flags_only_lazy() + test_annotations_before_tensor_type() + test_annotations_after_tensor_type() + print("All tests passed!") diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index b3a8e6ef99..aa5254b996 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -417,7 +417,9 @@ def save_to_disk(self, path: Path, verbose: bool = False): "w", lambda f: json.dump( { - "out_idx": getattr(self.func, "out_idx_override", None), + "out_idx": list(self.func.attrs["tilelang_out_idx"]) + if (self.func.attrs and "tilelang_out_idx" in self.func.attrs) + else None, }, f, ), diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index a8bbe08fad..2297874b65 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -90,10 +90,27 @@ def compile( assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}" - if hasattr(func, "out_idx_override"): - if func.out_idx_override is not None and out_idx is not None: + # Merge function-level attrs from PrimFunc + func_attrs = func.attrs + if func_attrs and "tilelang_out_idx" in func_attrs: + func_out_idx = list(func_attrs["tilelang_out_idx"]) + if out_idx is not None: raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors") - out_idx = func.out_idx_override or out_idx + out_idx = func_out_idx + if func_attrs and "tilelang_pass_configs" in func_attrs: + func_pc = dict(func_attrs["tilelang_pass_configs"]) + if pass_configs is not None: + # External pass_configs override function-level ones + func_pc.update(pass_configs) + pass_configs = func_pc + if func_attrs and "tilelang_compile_flags" in func_attrs: + func_cf = list(func_attrs["tilelang_compile_flags"]) + if compile_flags is not None: + if isinstance(compile_flags, str): + func_cf.append(compile_flags) + else: + func_cf.extend(compile_flags) + compile_flags = func_cf return cached( func=func, diff --git a/tilelang/language/eager/__init__.py b/tilelang/language/eager/__init__.py index 1710681263..e97f11ba51 100644 --- a/tilelang/language/eager/__init__.py +++ b/tilelang/language/eager/__init__.py @@ -1,2 +1,2 @@ -from .builder import prim_func, macro, PrimFunc, JITFunc, Ref, const # noqa: F401 +from .builder import prim_func, macro, PrimFunc, JITFunc, Ref, const, annotate_compile_flags, annotate_pass_configs # noqa: F401 from ..dtypes import * diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index f5ff966561..812d54638b 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -181,6 +181,8 @@ def __init__(self): self.constexpr_var = set() self.eager_jit: EagerJITStage = "none" self.eager_jit_subs: dict[str, PrimExpr] = {} + self.func_pass_configs: dict[str, Any] | None = None + self.func_compile_flags: list[str] | str | None = None self.current_file = "" self.current_line = 0 self.current_macro_name = "" @@ -754,7 +756,6 @@ class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): span: Span | None ir_gen: IRGenerator[_P, _T] | None orig_func: Callable[_P, _T] | None - out_idx_override: list[int] | None else: PrimFunc = tvm.tir.PrimFunc @@ -935,6 +936,67 @@ def kernel(A, B): return builder.eager_jit_subs[name] +def annotate_compile_flags(flags: list[str] | str) -> None: + """ + Annotate additional device compile flags inside a function body. + + The flags will be merged with any externally provided compile_flags + at compilation time. Can be placed before or after tensor type annotations. + + Example:: + + @tilelang.jit + def kernel(A, B): + T.annotate_compile_flags(["--use_fast_math"]) + ... + """ + builder = Builder.current() + if builder is None: + raise JITNoBuilderError("T.annotate_compile_flags() can only be used inside @tilelang.jit or @T.prim_func") + if builder.eager_jit == "phase1": + return + builder.func_compile_flags = flags + + +def annotate_pass_configs(configs: dict[str, Any]) -> None: + """ + Annotate pass configuration inside a function body. + + The configs will be merged with any externally provided pass_configs + at compilation time (function-level configs take lower priority, i.e. + external configs override). Can be placed before or after tensor type annotations. + + Example:: + + @tilelang.jit + def kernel(A, B): + T.annotate_pass_configs({ + PassConfigKey.TL_ENABLE_FAST_MATH: True, + }) + ... + """ + builder = Builder.current() + if builder is None: + raise JITNoBuilderError("T.annotate_pass_configs() can only be used inside @tilelang.jit or @T.prim_func") + if builder.eager_jit == "phase1": + return + builder.func_pass_configs = configs + + +def _patch_prim_func_attrs(pf: PrimFunc, builder: Builder) -> PrimFunc: + """Attach function-level out_idx, pass_configs and compile_flags as PrimFunc attrs.""" + if builder.out_idx: + pf = pf.with_attr("tilelang_out_idx", builder.out_idx) + if builder.func_pass_configs is not None: + pf = pf.with_attr("tilelang_pass_configs", builder.func_pass_configs) + if builder.func_compile_flags is not None: + flags = builder.func_compile_flags + if isinstance(flags, str): + flags = [flags] + pf = pf.with_attr("tilelang_compile_flags", flags) + return pf + + @dataclass class TirTemplate(Generic[_P, _T]): """ @@ -1019,8 +1081,7 @@ def get_tir(self, tensor_args, given_tensor_args, kwargs): with builder.prim_func(self.name): self.ir_gen.gen(builder)(**tensor_args, **kwargs) pf = builder.get() - if builder.out_idx: - pf.out_idx_override = builder.out_idx + pf = _patch_prim_func_attrs(pf, builder) return pf @@ -1117,8 +1178,6 @@ def _build_tir_template(self, *args, **kwargs) -> TirTemplate[_P, _T]: self.ir_gen.gen(builder)(**self.tensor_args, **kwargs) pf = builder.get() pf.orig_func = self.orig_func - if builder.out_idx: - pf.out_idx_override = builder.out_idx return TirTemplate.create(self.orig_func.__name__, pf, builder.constexpr_var, self.ir_gen) else: raise ValueError(f"Invalid jit mode: {self.mode}, expected 'lazy' or 'eager'") @@ -1218,9 +1277,8 @@ def impl(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, with builder.prim_func(func.__name__): ir_gen.gen(builder)(**annot) prim_func = builder.get() + prim_func = _patch_prim_func_attrs(prim_func, builder) prim_func.orig_func = func - if builder.out_idx: - prim_func.out_idx_override = builder.out_idx return prim_func except Exception as e: logger.fatal(f"Failed to build prim_func from {func.__name__}\nargs={annot}\nsource={ir_gen.source}") From 6e6295fb3528ba4a66dd6561553478894eb60011 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:46:18 +0800 Subject: [PATCH 010/156] [BugFix] Fix CI failures: clean /tmp on self-hosted runners and skip CuTeDSL alloc_global tests (#2009) 1. Add CI step to clean stale JIT temp files (/tmp/*.so, *.cu, *.cubin, tvm-debug-mode-tempdirs) before tests on self-hosted runners. These files accumulate across CI runs and can fill the disk, causing g++ to be killed (SIGTERM) during JIT compilation. 2. Skip CuTeDSL-incompatible example tests that use alloc_global (flash_decoding, deepseek_mla), since the CuTeDSL wrapper does not yet support alloc_global buffers. Co-authored-by: Claude Opus 4.6 --- .github/workflows/ci.yml | 6 ++++++ examples/deepseek_mla/test_example_mla_decode.py | 5 +++++ examples/flash_decoding/test_example_flash_decoding.py | 6 ++++++ 3 files changed, 17 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe67b12008..24821b8cd2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -342,6 +342,12 @@ jobs: exit "${rc}" fi + - name: Clean up stale /tmp files (self-hosted runners) + if: startsWith(matrix.runner.name, 'self-hosted') + run: | + rm -f /tmp/tmp*.so /tmp/tmp*.cu /tmp/tmp*.cubin /tmp/tmp*.cpp + rm -rf /tmp/tvm-debug-mode-tempdirs /tmp/tilelang_cutedsl_* + - name: Run examples with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) if: contains(matrix.runner.toolkit, 'CUDA') run: | diff --git a/examples/deepseek_mla/test_example_mla_decode.py b/examples/deepseek_mla/test_example_mla_decode.py index a269ea57ae..00e30023a4 100644 --- a/examples/deepseek_mla/test_example_mla_decode.py +++ b/examples/deepseek_mla/test_example_mla_decode.py @@ -1,9 +1,14 @@ +import os +import pytest import tilelang.testing import example_mla_decode +_is_cutedsl = os.environ.get("TILELANG_TARGET", "").lower() == "cutedsl" + @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@pytest.mark.skipif(_is_cutedsl, reason="CuTeDSL backend does not support alloc_global yet") def test_example_mla_decode(): example_mla_decode.main() diff --git a/examples/flash_decoding/test_example_flash_decoding.py b/examples/flash_decoding/test_example_flash_decoding.py index 2cbcd84043..3181df2d56 100644 --- a/examples/flash_decoding/test_example_flash_decoding.py +++ b/examples/flash_decoding/test_example_flash_decoding.py @@ -1,16 +1,22 @@ +import os +import pytest import tilelang.testing import example_gqa_decode import example_mha_inference import example_gqa_decode_varlen_logits +_is_cutedsl = os.environ.get("TILELANG_TARGET", "").lower() == "cutedsl" + @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_le(8, 9) +@pytest.mark.skipif(_is_cutedsl, reason="CuTeDSL backend does not support alloc_global yet") def test_example_example_gqa_decode(): example_gqa_decode.main() +@pytest.mark.skipif(_is_cutedsl, reason="CuTeDSL backend does not support alloc_global yet") def test_example_example_mha_inference(): example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False) From 5f70374c48bb1d52e9260e995bf7adbcd64341e5 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Fri, 3 Apr 2026 14:30:05 +0800 Subject: [PATCH 011/156] [Test] Add 1D TMA regression test for issue #1842 (#2005) Add a regression test covering 1D single-dimension tensor TMA copy (global -> shared -> global) with warp specialization disabled. The underlying bug was fixed in #1840, but the test suite only covered 2D descriptor-based TMA paths. This test ensures the 1D bulk copy path (cp.async.bulk) also works correctly with proper mbarrier allocation. Closes #1842 Co-authored-by: Claude Opus 4.6 --- .../issue/test_tilelang_issue_tma_no_ws.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/testing/python/issue/test_tilelang_issue_tma_no_ws.py b/testing/python/issue/test_tilelang_issue_tma_no_ws.py index 928f46bfa0..0a6ae9f463 100644 --- a/testing/python/issue/test_tilelang_issue_tma_no_ws.py +++ b/testing/python/issue/test_tilelang_issue_tma_no_ws.py @@ -69,6 +69,45 @@ def tma_copy(x: T.Tensor((M, K), T.float16)): torch.cuda.synchronize() +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tma_lower_1d_no_warp_specialized(): + """Regression for issue #1842: 1D TMA load fails when warp specialization is disabled. + + A single-dimension tensor copy (global -> shared -> global) using 1D bulk + TMA must compile and produce correct results when + ``tl.disable_warp_specialized=True``. + """ + + length = 7168 + + @T.prim_func + def tma_copy_1d( + a: T.Tensor((length,), T.float32), + out: T.Tensor((length,), T.float32), + ): + with T.Kernel(1, threads=256): + a_shared = T.alloc_shared((length,), T.float32) + T.copy(a, a_shared) + T.copy(a_shared, out) + + pass_configs = { + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } + kernel = _compile_tvm_ffi(tma_copy_1d, pass_configs, out_idx=[1]) + + src = kernel.get_kernel_source() + assert "tl::tma_load" in src + assert "mbarrier_mem" in src + assert "tl::tma_store" in src + + t = torch.randn((length,), device="cuda", dtype=torch.float32) + out = kernel(t) + torch.testing.assert_close(out, t) + torch.cuda.synchronize() + + @tilelang.testing.requires_cuda_compute_version(9, 0) def test_tma_lower_no_warp_specialized_2d_descriptor_uses_args1_barrier(): """Cover the 2D-descriptor TMA barrier rewrite path (barrier at args[1]).""" From 6fc3afa2a35755d56854067d8f82a8326a6fe8b6 Mon Sep 17 00:00:00 2001 From: William Date: Sat, 4 Apr 2026 13:30:21 +0800 Subject: [PATCH 012/156] [BugFix] Fix auto vectorization for binary operations after wider copy instructions (#1986) * fix f32x2 vectorize for wider shape * clean code * fix: use string::operator+= to satisfy clang-tidy performance check Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Freebase6912 Co-authored-by: Claude Opus 4.6 --- src/target/codegen_cuda.cc | 98 ++++++++++++++----- .../python/cuda/test_cuda_f32x2_intrinsics.py | 8 +- 2 files changed, 79 insertions(+), 27 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 10ae3e9796..4e001b5601 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -870,7 +870,11 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t, // (__hadd2, __hsub2, etc.) are available on SM80+ (bf16) / SM53+ (fp16). // The tl::*2 C++ helpers have compile-time arch guards with scalar // fallbacks, so we can emit them unconditionally for 16-bit types. - if (t.lanes() == 2) { + // + // When lanes > 2 and is even, we decompose the vector operation into + // lanes/2 independent x2 packed operations on consecutive pairs. + int lanes = t.lanes(); + if (lanes >= 2 && lanes % 2 == 0) { bool is_f32x2 = t.is_float() && t.bits() == 32; bool is_bf16x2 = t.is_bfloat16(); bool is_fp16x2 = t.is_float16(); @@ -902,29 +906,77 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t, tl_func = "max2"; if (!tl_func.empty()) { - // For bfloat16x2 / float16x2 both map to uint1 in generated CUDA, - // so we must cast arguments to the correct native type and cast the - // result back to uint1. - bool need_cast = is_bf16x2 || is_fp16x2; - std::string native_type; - if (is_bf16x2) - native_type = "__nv_bfloat162"; - else if (is_fp16x2) - native_type = "__half2"; - - std::string lhs_str = PrintExpr(lhs); - std::string rhs_str = PrintExpr(rhs); - - if (need_cast) { - std::string cast_lhs = - "tl::from_uint1<" + native_type + ">(" + lhs_str + ")"; - std::string cast_rhs = - "tl::from_uint1<" + native_type + ">(" + rhs_str + ")"; - os << "tl::to_uint1(tl::" << tl_func << "(" << cast_lhs << ", " - << cast_rhs << "))"; - } else { - os << "tl::" << tl_func << "(" << lhs_str << ", " << rhs_str << ")"; + // Decompose into lanes/2 independent x2 packed operations. + // + // Vector type → CUDA struct mapping: + // bf16x2 -> uint1 {.x} bf16x4 -> uint2 {.x, .y} + // bf16x8 -> uint4 {.x,.y,.z,.w} (each field = one nv_bfloat162) + // fp16x2 -> uint1 {.x} fp16x4 -> uint2 {.x, .y} ... + // f32x2 -> float2 {.x, .y} f32x4 -> float4 {.x,.y,.z,.w} + // + // For bf16/fp16: each 32-bit field already packs a pair of elements, + // so we apply tl::*2 on each field directly. + // For f32: consecutive pairs of 32-bit fields form a float2, + // so we apply tl::*2 on each float2 pair. + int num_pairs = lanes / 2; + static const char access[] = {'x', 'y', 'z', 'w'}; + + std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(t, stream); + stream << ' ' << sret << ";\n"; + int ssa_scope = BeginScope(); + { + std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); + std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); + + if (is_bf16x2 || is_fp16x2) { + std::string native_type = is_bf16x2 ? "__nv_bfloat162" : "__half2"; + for (int p = 0; p < num_pairs; ++p) { + std::string field(1, access[p]); + std::string pair_lhs = "tl::from_uint1<"; + pair_lhs += native_type; + pair_lhs += ">(*(uint1*)(&("; + pair_lhs += vlhs; + pair_lhs += "."; + pair_lhs += field; + pair_lhs += ")))"; + std::string pair_rhs = "tl::from_uint1<"; + pair_rhs += native_type; + pair_rhs += ">(*(uint1*)(&("; + pair_rhs += vrhs; + pair_rhs += "."; + pair_rhs += field; + pair_rhs += ")))"; + this->PrintIndent(); + stream << "*(uint1*)(&(" << sret << "." << field + << ")) = tl::to_uint1(tl::" << tl_func << "(" << pair_lhs + << ", " << pair_rhs << "));\n"; + } + } else { + // f32: apply tl::*2 on each consecutive pair of float fields, + // reinterpreted as float2. + for (int p = 0; p < num_pairs; ++p) { + std::string field(1, access[p * 2]); + std::string pair_lhs = "*(float2*)(&("; + pair_lhs += vlhs; + pair_lhs += "."; + pair_lhs += field; + pair_lhs += "))"; + std::string pair_rhs = "*(float2*)(&("; + pair_rhs += vrhs; + pair_rhs += "."; + pair_rhs += field; + pair_rhs += "))"; + this->PrintIndent(); + stream << "*(float2*)(&(" << sret << "." << field + << ")) = tl::" << tl_func << "(" << pair_lhs << ", " + << pair_rhs << ");\n"; + } + } } + EndScope(ssa_scope); + os << sret; return; } } diff --git a/testing/python/cuda/test_cuda_f32x2_intrinsics.py b/testing/python/cuda/test_cuda_f32x2_intrinsics.py index da958c3684..98e1181124 100644 --- a/testing/python/cuda/test_cuda_f32x2_intrinsics.py +++ b/testing/python/cuda/test_cuda_f32x2_intrinsics.py @@ -120,12 +120,12 @@ def _make_auto_vec_binary_kernel(py_op, dtype_tl): @T.prim_func def main( - A: T.Tensor((M, 2), dtype=dtype_tl), - B: T.Tensor((M, 2), dtype=dtype_tl), - C: T.Tensor((M, 2), dtype=dtype_tl), + A: T.Tensor((M, 4), dtype=dtype_tl), + B: T.Tensor((M, 4), dtype=dtype_tl), + C: T.Tensor((M, 4), dtype=dtype_tl), ): with T.Kernel(1, 1, threads=M) as (bx, by): - for i, v in T.Parallel(M, 2): + for i, v in T.Parallel(M, 4): C[i, v] = py_op(A[i, v], B[i, v]) return main From 01c714d31b9d012e8e135b772d4ef1816ece6225 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Sun, 5 Apr 2026 02:57:29 +0800 Subject: [PATCH 013/156] fix: add cudaGetLastError check after cuLaunchKernel in TVM FFI backend (#2000) add cuda get last error in tvm ffi to align Cython backend --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 12b47d3162..882a774844 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 12b47d316230fc777d13d4199200530e8c9529e1 +Subproject commit 882a774844993d103ae6e317ba3c7bbb5952b662 From bb79425348c13aab3dc774fd95c78e8c3f86fc89 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sun, 5 Apr 2026 17:17:41 +0800 Subject: [PATCH 014/156] [CI] Remove legacy dequantize gemm test (#2013) Remove example_dequant_groupedgemm_bf16_mxfp4_hopper references from regression and test files after its deletion. This cleanup ensures that the codebase remains consistent and free of unused imports. --- ...e_dequant_groupedgemm_bf16_mxfp4_hopper.py | 571 ------------------ .../regression_example_dequantize_gemm.py | 5 - .../test_example_dequantize_gemm.py | 7 - examples/gemm/regression_example_gemm.py | 5 - 4 files changed, 588 deletions(-) delete mode 100644 examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py deleted file mode 100644 index 501fe11cfc..0000000000 --- a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ /dev/null @@ -1,571 +0,0 @@ -import tilelang -import tilelang.language as T -from tilelang.quantize import _tir_u8_to_f4_to_bf16 -from tilelang import tvm as tvm -from tvm import DataType -import torch -from dequantize_utils import torch_convert_bit_twiddling, assert_similar -from tilelang.autotuner import set_autotune_inputs -import argparse - - -def get_configs(): - """ - Generate a list of hyperparameter configuration dictionaries for tuning. - - Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K', - 'num_stages', 'threads', and 'split'. The function returns the Cartesian - product of the parameter value lists: - - block_M, block_N, block_K: tiling sizes - - num_stages: pipeline stages - - threads: thread counts - - split: K-splitting factor - - Returns: - List[dict]: A list of configuration dictionaries covering all combinations. - """ - import itertools - - iter_params = dict( - block_M=[128], - block_N=[64, 128, 256], - block_K=[128], - num_stages=[0, 1, 2], - threads=[128, 256, 512], - split=[1], - ) - return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs()) -@tilelang.jit(out_idx=[-1]) -def matmul( - M, - N, - K, - topk, - E, - padding_M, - in_dtype, - out_dtype, - accum_dtype, - source_format=T.uint32, - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=128, - block_N=256, - block_K=128, - num_stages=2, - threads=256, - split=1, -): - """ - Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype. - - The generated kernel accepts: - - A: dense matrix with element type `in_dtype` and shape (M, K). - - B: packed quantized matrix for all experts, stored as uint8 with `num_bits` bits per element, shape (E, N, QK), where QK = K / (8/num_bits). - - Scale: per-expert, per-block scale/exponent information for dequantizing B, shape (E, N, K // scale_size). - - Bias: per-expert, per-output bias, shape (E, N). - - topk_weights: router weights for the top-k experts for each token, shape (M, topk). - - sorted_token_ids: flattened and padded tensor of token indices, shape (padding_M,). - - expert_ids: expert id for each token in the padded batch, shape (padding_M // block_M,). - - C: output tensor, shape (M, topk, N). - - The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: - - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. - - fast_dequant (False): uses a simple elementwise dequantization helper. - - Parameters: - M, N, K (int): matrix dimensions (A is MxK, result is (M, topk, N)). K must be divisible by (block_K * split). - topk (int): number of experts selected per token. - E (int): number of experts. - padding_M (int): padded number of tokens after grouping and block alignment. - in_dtype (str): element type of A (e.g., T.bfloat16). - out_dtype (str): output tensor element type (e.g., T.bfloat16). - accum_dtype (str): accumulation type used for the inner GEMM. - source_format (str, optional): format string passed to intrinsic selector (default "uint"). - num_bits (int, optional): number of bits per quantized element in B (default 4). - scale_size (int, optional): number of elements grouped per scale entry (default 32). - fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). - block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). - num_stages (int, optional): pipelining stages for K loop (default 2). - threads (int, optional): threads per block used by the kernel (default 256). - split (int, optional): split factor along K used by the scheduler (default 1). - with_bias (bool, optional): whether to add Bias to the output (default False). - - Returns: - A T.prim_func implementing the grouped, pipelined GEMM that: - - loads tiled blocks of A and packed B for each expert to shared memory, - - dequantizes B via the chosen path into a shared dequantized tile, - - performs a tiled GEMM accumulating into local fragments, - - applies per-token topk weights and bias, - - writes the final (M, topk, N) block to the global output tensor. - - Notes: - - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - - An assertion enforces that K % (block_K * split) == 0. - """ - - num_elems_per_byte = 8 // num_bits - storage_dtype = T.uint8 - QK = K // num_elems_per_byte - Block_QK = block_K // num_elems_per_byte - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, Block_QK) - Bias_shared_shape = block_N - B_dequantize_shared_shape = (block_N, block_K) - assert K % (block_K * split) == 0 - - from tilelang.quantize import get_mxfp_intrin_group - - # fast_dequant_bf16_fp4_twiddling - mxfp_intrin_info = get_mxfp_intrin_group( - out_dtype=in_dtype, - source_format=source_format, - source_bit=num_bits, - storage_dtype=storage_dtype, - use_twiddling=True, - ) - import_source = mxfp_intrin_info["c_source"] - func_name = mxfp_intrin_info["func_name"] - assert import_source is not None, "mxfp_intrin_info is not found" - assert func_name is not None, "mxfp_intrin_info is not found" - import_source = import_source - - # the dequant part is the same as in dequant_gemm - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): - """ - Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. - The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: - - Loads packed FP4 elements from B_shared into per-thread local registers. - - Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values. - - Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two). - - Writes the scaled BF16 results into B_dequantize_shared. - - Notes: - - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. - - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. - - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. - """ - assert in_dtype in ["fp4"] - assert out_dtype in [T.bfloat16] - - # Some variables for dequantization in each thread - MAX_TRANSACTION_SIZE_BITS = 128 - local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits - local_compress_size = local_size // num_elems_per_byte - - @T.macro - def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k): - # import fast_dequantize plugin - """ - Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16 - in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4, - applying per-block scale factors from Scale. - - This routine is a tiled, thread-parallel helper that: - - Imports and calls an external dequantization function (via `import_source`/`func_name`) - to expand compressed uint8-packed FP4 values into BF16 fragments in-thread. - - Loads the corresponding per-block scale entry, interprets it as an exponent bias - (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor. - - Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place. - - Parameters: - - B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout). - - B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values. - - Scale_shared: per-block scale tensor; entries are interpreted such that the multiplicative scale - = 2^(Scale - 127). - - k: block index along the K dimension used to select the appropriate Scale entries. - - Side effects: - - Mutates B_dequantize_shared in shared memory. - - Calls an external intrinsic function (must be provided by the environment via `import_source` - and `func_name`) to perform the low-level unpacking/dequantization. - """ - T.import_source(import_source) - - tx = T.get_thread_binding() - - B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) - B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) - Scale_local_thread = T.alloc_local((1,), storage_dtype) - Scale_local_thread_exponent = T.alloc_local((1,), out_dtype) - - for i in T.serial(0, block_N * block_K // threads // local_size): - # First, load data from share memory to register. - # Prepare for dequant. - index_base = i * threads * local_compress_size + tx * local_compress_size - for v in T.vectorized(0, local_compress_size): - index = index_base + v - B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] - index_scale = index_base // (scale_size // num_elems_per_byte) - si = index_scale // (block_K // scale_size) - sj = index_scale % (block_K // scale_size) - Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj] - Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0])) - - # Then, dequant. - T.call_extern( - func_name, - T.access_ptr(B_local_thread, "r"), - T.access_ptr(B_dequantize_local_thread, "w"), - 1, - dtype=out_dtype, - ) - - # Finally, store the dequantized data to shared memory. - for v in T.Parallel(local_size): - B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0] - - for v in T.vectorized(0, local_size): - index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] - - return fast_dequant_bf16_fp4_twiddling - - def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): - assert in_dtype in ["fp4"] - assert out_dtype in [T.bfloat16] - - @T.macro - def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): - B_local = T.alloc_fragment(B_shared_shape, storage_dtype) - B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) - - T.copy(B_shared, B_local) - for i, j in T.Parallel(block_N, block_K): - B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( - num_bits, - B_local[i, j // num_elems_per_byte], - j % num_elems_per_byte, - Scale_shared[ - i, k * block_K // scale_size + j // scale_size - ], # Scale is the exponential part, within the representation of uint8 - dtype=out_dtype, - ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) - T.copy(B_dequantize_local, B_dequantize_shared) - - return simple_dequant_bf16_fp4 - - @T.prim_func - def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((E, N, QK), storage_dtype), - Scale: T.Tensor((E, N, K // scale_size), storage_dtype), - Bias: T.Tensor((E, N), out_dtype), - # Add fusedmoe tensors - topk_weights: T.Tensor((M * topk), out_dtype), - sorted_token_ids: T.Tensor((padding_M), T.int32), - expert_ids: T.Tensor((padding_M // block_M), T.int32), - C: T.Tensor((M, topk, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, storage_dtype) - B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) - Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - topk_weights_shared = T.alloc_shared((block_M), out_dtype) - sorted_token_ids_shared = T.alloc_shared((block_M), T.int32) - expert_id = T.alloc_local((1), T.int32) # the expert id for the current block - # To use 1D TMA, the last dim of Scale_shared must have stride=1 - # May use much more shared memory than necessary - Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) - - T.annotate_layout( - { - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - } - ) - T.use_swizzle(10) - - if threads == 512: - T.disable_warp_group_reg_alloc() - - T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared) - expert_id[0] = expert_ids[by] - - # Get the topk weights of each token in the current block - for i in T.Parallel(block_M): - if sorted_token_ids_shared[i] != -1: - topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]] - - # Get bias and scale based on the expert id - if with_bias: - T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared) - else: - T.clear(Bias_shared) - - T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared) - - for i, j in T.Parallel(block_M, block_N): - C_local[i, j] = Bias_shared[j] - - tx = T.get_thread_binding() - - for k in T.Pipelined(K // block_K, num_stages=num_stages): - # Each thread copies 4 bytes, local size is 16 - for copy_i in T.serial(block_M * block_K // threads // 16): - base = copy_i * threads * 16 + tx * 16 - if sorted_token_ids_shared[base // block_K] != -1: - for copy_j in T.vectorized(16): - A_shared[base // block_K, base % block_K + copy_j] = A[ - sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j - ] - - T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared) - if fast_dequant: - get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) - else: - get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) - - T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) - - for i, j in T.Parallel(block_M, block_N): - C_local[i, j] = C_local[i, j] * topk_weights_shared[i] - - T.copy(C_local, C_shared) - for copy_i in T.serial(block_M * block_N // threads // 16): - base = copy_i * threads * 16 + tx * 16 - if sorted_token_ids_shared[base // block_N] != -1: - for copy_j in T.vectorized(16): - C[ - sorted_token_ids_shared[base // block_N] // topk, - sorted_token_ids_shared[base // block_N] % topk, - bx * block_N + base % block_N + copy_j, - ] = C_shared[base // block_N, base % block_N + copy_j] - - return main - - -def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256): - dtypeC = T.bfloat16 - M, K = A.shape - E, N, QK = qB.shape - topk = topk_weights.shape[0] // M - scale_size = K // Scale.shape[2] - assert scale_size == 32 # MXFP4 - - # Initialize output tensor - C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda") - - # Iterate over sorted_token_ids - for idx in range(len(sorted_token_ids)): # padding_M - token_id = sorted_token_ids[idx] - if token_id == -1: - continue - expert_id = expert_ids[idx // block_M] - topk_idx = token_id % topk - - # Get the token embedding - token_embedding = A[token_id // topk] - - # Dequantize the expert weights - B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K) - B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16)) - - # Compute the output for this token-expert pair - # token_embedding @ B.T + bias - output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id] - output = output.to(torch.__getattribute__(dtypeC)) - - # Apply the topk weight - weight = topk_weights[token_id] - output = output * weight - - # Store the result - C[token_id // topk, topk_idx] = output - - return C - - -def get_data(m, n, k, qk, scale_size, topk, E, block_M): - A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) - qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts. - Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda") - Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) - - weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) - # topk_weights: Router weights for the top-k experts for each token. - # Shape: (m, topk) - # tokens_experts: A flattened tensor of expert assignments for each token. - # For each of m tokens, topk unique experts are chosen. Shape: (m * topk,) - topk_weights, tokens_experts = torch.topk(weights, topk, dim=-1) - tokens_experts = tokens_experts.reshape(m * topk) - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - topk_weights = topk_weights.reshape(m * topk) - - sorted_expert_vals, sorted_indices = torch.sort(tokens_experts, stable=True) - sorted_token_ids = sorted_indices - unique_expert_ids, counts = torch.unique_consecutive(sorted_expert_vals, return_counts=True) - expert_ids = [] - padded_token_ids = [] - start = 0 - for eid, cnt in zip(unique_expert_ids.tolist(), counts.tolist()): - end = start + cnt - group_token_ids = sorted_token_ids[start:end] - pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt - if pad_len > 0: - # -1 for padding (`M` instead in vLLM moe_align_block_size()) - group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")]) - padded_token_ids.append(group_token_ids) - expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M)) - start = end - - # sorted_token_ids: The final flattened and padded tensor of token indices. - sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,) - # expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`. - expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M,) - padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding - - return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M - - -def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): - # Tunable parameters - block_M, block_N, block_K = 128, 256, 128 # noqa: F841 - num_stages = 1 # noqa: F841 - threads = 512 # noqa: F841 - split = 1 # noqa: F841 - - total_flops = 2 * m * n * k * topk - num_bits = 4 - num_elems_per_byte = 8 // num_bits - qk = k // num_elems_per_byte - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) - - if tune: - with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): - # Autotune with inputs manually composed - kernel = matmul( - m, - n, - k, - topk, - E, - padding_M, - T.bfloat16, - T.bfloat16, - T.float32, - num_bits=num_bits, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias, - ) - else: - kernel = matmul( - m, - n, - k, - topk, - E, - padding_M, - T.bfloat16, - T.bfloat16, - T.float32, - num_bits=num_bits, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias, - block_M=block_M, - block_N=block_N, - block_K=block_K, - num_stages=num_stages, - threads=threads, - split=split, - ) - - output = kernel( - A, - qB, - Scale, - Bias, - topk_weights, - sorted_token_ids, - expert_ids, - ) - print("Tilelang kernel run finished.") - - ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow... - - latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) - print("Tilelang: {:.2f} ms".format(latency)) - print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) - - diff = (output - ref_output).abs() - max_val = diff.max() - max_idx = diff.argmax() - print(f"max abs diff: {max_val} at index: {max_idx}") - assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference - print("All checks pass. ✅") - - -def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): - block_M, block_N, block_K = 128, 256, 128 - num_stages = 1 - threads = 512 - split = 1 - num_bits = 4 - num_elems_per_byte = 8 // num_bits - qk = k // num_elems_per_byte - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) - - if tune: - with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): - kernel = matmul( - m, - n, - k, - topk, - E, - padding_M, - "bfloat16", - "bfloat16", - "float32", - num_bits=num_bits, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias, - ) - else: - kernel = matmul( - m, - n, - k, - topk, - E, - padding_M, - "bfloat16", - "bfloat16", - "float32", - num_bits=num_bits, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias, - block_M=block_M, - block_N=block_N, - block_K=block_K, - num_stages=num_stages, - threads=threads, - split=split, - ) - - return tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), backend="cupti") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--M", type=int, default=256, help="M") # From gpt-oss-20b MoE's first gemm - parser.add_argument("--N", type=int, default=256, help="N") - parser.add_argument("--K", type=int, default=256, help="K") - parser.add_argument("--scale_size", type=int, default=32, help="scale size") - parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token - parser.add_argument("--E", type=int, default=32, help="E") # number of experts - parser.add_argument("--tune", action="store_true", help="tune configs") - args = parser.parse_args() - main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune) diff --git a/examples/dequantize_gemm/regression_example_dequantize_gemm.py b/examples/dequantize_gemm/regression_example_dequantize_gemm.py index 4ab03784ff..51b7c53e00 100644 --- a/examples/dequantize_gemm/regression_example_dequantize_gemm.py +++ b/examples/dequantize_gemm/regression_example_dequantize_gemm.py @@ -4,7 +4,6 @@ import example_dequant_gemm_fp4_hopper import example_dequant_gemm_w4a8 import example_dequant_gemv_fp16xint4 -import example_dequant_groupedgemm_bf16_mxfp4_hopper def regression_example_dequant_gemv_fp16xint4(): @@ -23,10 +22,6 @@ def regression_example_dequant_gemm_bf16_mxfp4_hopper(): tilelang.testing.process_func(example_dequant_gemm_bf16_mxfp4_hopper.run_regression_perf) -def regression_example_dequant_groupedgemm_bf16_mxfp4_hopper(): - tilelang.testing.process_func(example_dequant_groupedgemm_bf16_mxfp4_hopper.run_regression_perf) - - def regression_example_dequant_gemm_w4a8(): tilelang.testing.process_func(example_dequant_gemm_w4a8.run_regression_perf) diff --git a/examples/dequantize_gemm/test_example_dequantize_gemm.py b/examples/dequantize_gemm/test_example_dequantize_gemm.py index 3e4183e79c..021402a363 100644 --- a/examples/dequantize_gemm/test_example_dequantize_gemm.py +++ b/examples/dequantize_gemm/test_example_dequantize_gemm.py @@ -3,7 +3,6 @@ import example_dequant_gemv_fp16xint4 import example_dequant_gemm_fp4_hopper import example_dequant_gemm_bf16_mxfp4_hopper -import example_dequant_groupedgemm_bf16_mxfp4_hopper import example_dequant_gemm_w4a8 @@ -24,12 +23,6 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper(): example_dequant_gemm_bf16_mxfp4_hopper.main() -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_eq(9, 0) -def test_example_dequant_groupedgemm_bf16_mxfp4_hopper(): - example_dequant_groupedgemm_bf16_mxfp4_hopper.main() - - @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_dequant_gemm_w4a8(): diff --git a/examples/gemm/regression_example_gemm.py b/examples/gemm/regression_example_gemm.py index 3583cf16ac..4976020598 100644 --- a/examples/gemm/regression_example_gemm.py +++ b/examples/gemm/regression_example_gemm.py @@ -2,7 +2,6 @@ import example_gemm import example_gemm_autotune import example_gemm_intrinsics -import example_gemm_schedule def regression_example_gemm_autotune(): @@ -13,10 +12,6 @@ def regression_example_gemm_intrinsics(): tilelang.testing.process_func(example_gemm_intrinsics.run_regression_perf, M=1024, N=1024, K=1024) -def regression_example_gemm_schedule(): - tilelang.testing.process_func(example_gemm_schedule.run_regression_perf) - - def regression_example_gemm(): tilelang.testing.process_func(example_gemm.run_regression_perf) From 1ff58bce447291345c6e9fcf5f0ce65ba826c165 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 02:31:26 +0800 Subject: [PATCH 015/156] [CI] [pre-commit.ci] autoupdate (#2014) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/mirrors-clang-format: v22.1.0 → v22.1.2](https://github.com/pre-commit/mirrors-clang-format/compare/v22.1.0...v22.1.2) - [github.com/astral-sh/ruff-pre-commit: v0.15.4 → v0.15.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.15.4...v0.15.9) - [github.com/codespell-project/codespell: v2.4.1 → v2.4.2](https://github.com/codespell-project/codespell/compare/v2.4.1...v2.4.2) - [github.com/jackdewinter/pymarkdown: v0.9.35 → v0.9.36](https://github.com/jackdewinter/pymarkdown/compare/v0.9.35...v0.9.36) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 01d574555c..9edadd6592 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,19 +30,19 @@ repos: args: [--ignore-case] files: ^docs/spelling_wordlist\.txt$ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v22.1.0 # sync with requirements-lint.txt + rev: v22.1.2 # sync with requirements-lint.txt hooks: - id: clang-format types_or: [c++, c] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.4 # sync with requirements-lint.txt + rev: v0.15.9 # sync with requirements-lint.txt hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] - id: ruff-format args: [--exit-non-zero-on-format] - repo: https://github.com/codespell-project/codespell - rev: v2.4.1 # sync with requirements-lint.txt + rev: v2.4.2 # sync with requirements-lint.txt hooks: - id: codespell additional_dependencies: [".[toml]"] @@ -53,7 +53,7 @@ repos: ^.*\brequirements\b.*\.txt$ ) - repo: https://github.com/jackdewinter/pymarkdown - rev: v0.9.35 + rev: v0.9.36 hooks: - id: pymarkdown args: ["--config", ".pymarkdown", "fix"] From 4f75940612f1398dad2cc5211e4c44c32b18ac09 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:28:48 +0800 Subject: [PATCH 016/156] [BugFix] Enhance CUDA vectorization for binary operations (#2015) Enhance CUDA vectorization for binary operations by supporting wider tensor dimensions. Update kernel generation to accommodate variable width, improving auto-vectorization for float32 types. Add tests for width-8 scenarios to ensure correct emission of packed operations. --- src/target/codegen_cuda.cc | 81 ++++++++++++++----- .../python/cuda/test_cuda_f32x2_intrinsics.py | 26 ++++-- 2 files changed, 79 insertions(+), 28 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 4e001b5601..a1ee2e9a5c 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -909,15 +909,17 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t, // Decompose into lanes/2 independent x2 packed operations. // // Vector type → CUDA struct mapping: - // bf16x2 -> uint1 {.x} bf16x4 -> uint2 {.x, .y} - // bf16x8 -> uint4 {.x,.y,.z,.w} (each field = one nv_bfloat162) - // fp16x2 -> uint1 {.x} fp16x4 -> uint2 {.x, .y} ... - // f32x2 -> float2 {.x, .y} f32x4 -> float4 {.x,.y,.z,.w} + // bf16/fp16 x2..x8 -> uint1..uint4 (one packed x2 pair per field) + // bf16/fp16 x12/x16 -> ulonglong3/4 (two packed x2 pairs per field) + // f32x2 -> float2 {.x, .y} + // f32x4 -> float4 {.x,.y,.z,.w} + // f32x6/x8 -> ulonglong3/4 (one float2 pair per field) // // For bf16/fp16: each 32-bit field already packs a pair of elements, - // so we apply tl::*2 on each field directly. - // For f32: consecutive pairs of 32-bit fields form a float2, - // so we apply tl::*2 on each float2 pair. + // so we apply tl::*2 on each field directly for <= 8 lanes. For + // 12/16 lanes, each 64-bit field stores two x2 pairs. + // For f32: float4 stores pairs at {x,z}; ulonglong3/4 stores one + // float2 pair per field at {x,y,z,w}. int num_pairs = lanes / 2; static const char access[] = {'x', 'y', 'z', 'w'}; @@ -933,31 +935,66 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t, if (is_bf16x2 || is_fp16x2) { std::string native_type = is_bf16x2 ? "__nv_bfloat162" : "__half2"; for (int p = 0; p < num_pairs; ++p) { - std::string field(1, access[p]); + int field_idx = lanes <= 8 ? p : (p / 2); + ICHECK_LT(field_idx, 4); + int pair_offset = lanes <= 8 ? 0 : (p % 2); + std::string field(1, access[field_idx]); std::string pair_lhs = "tl::from_uint1<"; pair_lhs += native_type; - pair_lhs += ">(*(uint1*)(&("; - pair_lhs += vlhs; - pair_lhs += "."; - pair_lhs += field; - pair_lhs += ")))"; + pair_lhs += ">("; + if (lanes <= 8) { + pair_lhs += "*(uint1*)(&("; + pair_lhs += vlhs; + pair_lhs += "."; + pair_lhs += field; + pair_lhs += "))"; + } else { + pair_lhs += "*(((uint1*)(&("; + pair_lhs += vlhs; + pair_lhs += "."; + pair_lhs += field; + pair_lhs += "))) + "; + pair_lhs += std::to_string(pair_offset); + pair_lhs += ")"; + } + pair_lhs += ")"; std::string pair_rhs = "tl::from_uint1<"; pair_rhs += native_type; - pair_rhs += ">(*(uint1*)(&("; - pair_rhs += vrhs; - pair_rhs += "."; - pair_rhs += field; - pair_rhs += ")))"; + pair_rhs += ">("; + if (lanes <= 8) { + pair_rhs += "*(uint1*)(&("; + pair_rhs += vrhs; + pair_rhs += "."; + pair_rhs += field; + pair_rhs += "))"; + } else { + pair_rhs += "*(((uint1*)(&("; + pair_rhs += vrhs; + pair_rhs += "."; + pair_rhs += field; + pair_rhs += "))) + "; + pair_rhs += std::to_string(pair_offset); + pair_rhs += ")"; + } + pair_rhs += ")"; this->PrintIndent(); - stream << "*(uint1*)(&(" << sret << "." << field - << ")) = tl::to_uint1(tl::" << tl_func << "(" << pair_lhs - << ", " << pair_rhs << "));\n"; + if (lanes <= 8) { + stream << "*(uint1*)(&(" << sret << "." << field + << ")) = tl::to_uint1(tl::" << tl_func << "(" << pair_lhs + << ", " << pair_rhs << "));\n"; + } else { + stream << "*(((uint1*)(&(" << sret << "." << field << "))) + " + << pair_offset << ") = tl::to_uint1(tl::" << tl_func + << "(" << pair_lhs << ", " << pair_rhs << "));\n"; + } } } else { // f32: apply tl::*2 on each consecutive pair of float fields, // reinterpreted as float2. for (int p = 0; p < num_pairs; ++p) { - std::string field(1, access[p * 2]); + int field_idx = lanes <= 4 ? (p * 2) : p; + ICHECK_LT(field_idx, 4); + std::string field(1, access[field_idx]); std::string pair_lhs = "*(float2*)(&("; pair_lhs += vlhs; pair_lhs += "."; diff --git a/testing/python/cuda/test_cuda_f32x2_intrinsics.py b/testing/python/cuda/test_cuda_f32x2_intrinsics.py index 98e1181124..63043a84e7 100644 --- a/testing/python/cuda/test_cuda_f32x2_intrinsics.py +++ b/testing/python/cuda/test_cuda_f32x2_intrinsics.py @@ -115,17 +115,17 @@ def _lower_to_cuda_source(func, target: str = SM80_TARGET) -> str: } -def _make_auto_vec_binary_kernel(py_op, dtype_tl): +def _make_auto_vec_binary_kernel(py_op, dtype_tl, width: int = 4): """Build a kernel that uses T.Parallel to let the vectoriser emit tl::2.""" @T.prim_func def main( - A: T.Tensor((M, 4), dtype=dtype_tl), - B: T.Tensor((M, 4), dtype=dtype_tl), - C: T.Tensor((M, 4), dtype=dtype_tl), + A: T.Tensor((M, width), dtype=dtype_tl), + B: T.Tensor((M, width), dtype=dtype_tl), + C: T.Tensor((M, width), dtype=dtype_tl), ): with T.Kernel(1, 1, threads=M) as (bx, by): - for i, v in T.Parallel(M, 4): + for i, v in T.Parallel(M, width): C[i, v] = py_op(A[i, v], B[i, v]) return main @@ -219,13 +219,27 @@ def test_codegen_abs2(dtype_name): # float32: auto-vectorization should emit tl::2 on SM100+ @tilelang.testing.requires_cuda @pytest.mark.parametrize("op_key", _AUTO_VEC_OP_NAMES) -def test_codegen_auto_vec_f32_sm100(op_key): +def test_codegen_auto_vec_f32(op_key): py_op, tl_func = _AUTO_VEC_OPS[op_key] func = _make_auto_vec_binary_kernel(py_op, T.float32) src = _lower_to_cuda_source(func, target=SM100_TARGET) assert f"tl::{tl_func}" in src, f"Expected tl::{tl_func} in SM100 auto-vectorised CUDA source for float32 {op_key}" +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(10) +@pytest.mark.parametrize("op_key", _AUTO_VEC_OP_NAMES) +def test_codegen_auto_vec_f32_width8(op_key): + py_op, tl_func = _AUTO_VEC_OPS[op_key] + func = _make_auto_vec_binary_kernel(py_op, T.float32, width=8) + src = _lower_to_cuda_source(func, target=SM100_TARGET) + assert "\x00" not in src, "Generated CUDA source should not contain embedded NUL bytes" + for field in "xyzw": + assert f".{field})) = tl::{tl_func}(" in src, ( + f"Expected {field}-field packed tl::{tl_func} emission in width-8 float32 auto-vectorised source" + ) + + # float32: auto-vectorization should NOT emit tl::2 before SM100 @tilelang.testing.requires_cuda @pytest.mark.parametrize("op_key", _AUTO_VEC_OP_NAMES) From 868c740063ca32310f72f4453ccc90e768d261bf Mon Sep 17 00:00:00 2001 From: _Kerman Date: Tue, 7 Apr 2026 14:20:26 +0800 Subject: [PATCH 017/156] [Docs] fix arrow direction in ir_transform_diagram.png (#2016) --- docs/_static/img/ir_transform_diagram.png | Bin 85502 -> 74048 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/docs/_static/img/ir_transform_diagram.png b/docs/_static/img/ir_transform_diagram.png index 3bd86891394c90db12f98c5dcc43f02330aa3f93..f6cbc9da4a59daba03db4234921a6e3f10c4f260 100644 GIT binary patch literal 74048 zcmdSBbz9p_^F9m(O7P&pNs*SeKybG}v0{bd4#llNaVrq4P&8Pv;-$E|YjG&6S;AD5Wbm*lu~AS^@Z@BrR8deKLs3vr^8k;L zBlMC<;mBX8&Z;s{)Y%^|{g7|4-pgt`qo9y6{rN$aQ)N6tL8af6lahGrp1PlgL-@GTTLI`>_DkIjYkdANpnalq6R87r1Gsv~%2b2+bC9jgx28q3mZ(olC! zhokQ6*Se|aa}96pib*?Ei*!_g=&=G1rwh)(Ldk=Rv9F0@w{NbflZ2eN29uMy;=Go( zn-GOCAQS-p?@LZvmQEwuf8YQ0*E|mP|9s@nkX5b9|9%TOB--FnYT8+0_)o_$xgvQVG z#hL%Vx8euY-;ZaP*pfF9ENW=e4<{3JC0Fyc>`%_ycjzCszx$l$T^e*P{%<{S|Mx3X zM9?FeV+^&(&$H71*|3Dj|4v1M{vYYc zsSdd}Rcj->=`p`lWMM?U{i_)^EoAd$ZPNtYRA(jW--_#V_5YXtW#@CNVbR+P z0!yq9{qQl?_X_{Ei+7FU|5{+qL2NkCsfvYRk<-cjZT%*D(h2Lu!%|)m*Cw7CSxUvT zEiFU7&G)t5G4q&h3>#r8&H0KR&ng_sS+>7$&cIdGJd-Q>;jN9P_wDdAtC5OTfg;{b zUNvDg=_JlBxNivoU-!OlpDEw=^{e)aeKiGv5x9@yj$QwQg2yC$TH&ga#bI+4-||p^ z=-Yqgr-sYQ$aAc5ajUs&Rx8OHmq zy>K7S{mKRS>zvH5MCMgF!>fwrnRT>@a82xQyk@rVTA_4}0@Lh+?RmGucNvkMYWEo| zd|$G-CSG*?lVn3MdXEbLkiR|3D;2amx~&%!%~C>VRHx+tQ_4Jm?D|>C6Q~izoNeX} zhFLiy#Tn;nMTM10?@9$lA2{xB>9u|&R?xQWR}!c~weIp<_ol2YR@s*QTZJ@~|Iwg# zP&8yD#VJIKAY4R|!o0t3^xQ8T@8&H+lf&CT!1x4GyO2eqHA2AiPwWiwr&fH_V z6jkH?hvbTO_^=5^&^*)T(8p1g47Z1ECwxHMS(y+%y8p$(bppR@%QV3UVybKF<6K4J|8PCL)p@Bq!8Ze_Mg1c@8jU!x&6m$h#I+W>XO;o-n{cqKHYmUk# z)Jm5xp4=>2_p6Wd|5kTXyMNAmf%UPi(lAE}JQDNwc@yOR+f&W4DH1~a8GYPyvb&hY z2>2)APfd3>pTOwl+@jIemDJdjofSUnfy7(}268fc(LtAm*DXMweY z#`afewbTG+yEN}cznzmC12#msOp&9I)n@cZyhr0eNI{bVOX?q!&G>qJ0ovaLU zf*saZR&&<>}qk>Z>+{4+;5W{ zl1~apw?~ps5Px=f`H!Nw^jp0YuFQXNB_Lk^94^k&TTQ5D5lX4xsPZWxIs6B$jVA;} zBb|OX{K{HM=-A`7PMbKJLT!P(RYPp604iqye2DosBK+-Of2VCC;tPUX1>7o z=$GDM*E@U8&F_7JutA-dQ|s&eYWUc8HxqD)ye6P?5j5?eCT##dLOcfDKQ$L%)^g|% zGKmY2pm1Ji})%2z3k{`see69s1w74cc;FQ2)RWCEGE*pC5q}sflu%#n)fi=+tU@R5of! zkQV`OZueF8)i@sDFbyG=H}@w6>|)PVf0Q{F<3Jx)liPgUi`(oNG>M&F1}#?ti>9|{ z@vlQJc$ZF(NA}zLvwR{&PArwZ@~!o2-8ZW$A7Wp3HV@b|kR7b~DK>S11+Qq@{kjv2 z2}-^T@=^0T`js0K(-GeAh5Yep|Hexsk@^ZOq%T$`l23wJ_5|C?x%_?4xXd)mmlz%D z9#ZBsbcJP9#jAL_hGXF%ipNjp%Afmn+S;I^BFJiSM=DZu%h?sP3||f1WKE=Nl`K~{ zUc}c%TpPZ=Fy4KIbMcaeg>u!#(4+t7*umU%!|HZJF7cu0-DJX^arjB3mZaE6Oap!Y zIguXtwqA-ceZm8Hq~?%(?*J3gzaALxWx(NAI-E8U)PJg^rjudYqs~{vSBt11RqynR zDIOdD^yO4);$?1DKy!(Hi|g6Fs)%dhOc?P)`>H~HI~=h+>xO*L|AmHgD07|!;J0%U z6yGHP9tA-S6MD~JBp-K7$q6o`Xa}oXrtU9#GznT9z0M0rXAPi$4vBjKbDtJypNbl- z%fP9kfY6Ll&-khNOB(V;FgZA+cgiZB^}OonFH#04o$&*kYF`nwy+)28OT&X8pr@Fa%? zZIScd@A=GVY1WqUn_XPixd>+MT)2Gj=K*FIT&aYGIPEuj`8gGK=Nx#{B-z*6C2NjxXN2rRU%#C8nIHSXUNth%MFR&5~EP zMDc9WfhF>f=i9!L_EfcJKfN+d!J%Q4$7rWj$wz5t2)$<7kX>ar28RRQKWlN7w&wssD zt}u$wG144OV` zHg5+TJW19yA=k8O*7)9~*O@iyiY$tq@ezc$`4DnmWy zXSR4kTAro%sjy3wVFlA9BV12hjE%$#@!m&l$WiRiz7PEk7-7KiEO#{aqTgMHso7x0fvp+=a99x4vQ9zqZ4KTFLlo9P9prC zq~L1_nN9T0?(uHaR0Q=tZF&;Qx;A|!qf}>eWS9+=HST`Ti31jZ>6XCIYt{Dhn@7KFLoy(^u4H z-w}xuXLkJB%B=|S7=WX7V@n_P$nM?2D}kzSPM$HV5;&RkZ2K3}_Vp&Q7^i z+@w>Wsb)5x*6;Q_YTF95aWluiLs;}zj9LiG5NwpuA%RDQvA0$f=( zLb%mhn&cEzNdB?LB24j{uyoP!CHmz19>0=(wyF0EJQcJ8*pL*zFm7dlL>)7XcXwTM zdApx%nyo~x{3hbPVJY{{v(3s+Hy;ops!TyoCk?e7(zL)*IpZ(VKK!olkMn4!bS_@z z-W1d?CggwH^Y6d|U1N-<2DXf#LvEAvZ1gx&-4`V3hS+ zOE+3S;+xd31V?faoZfx$xp{q$j5BZAy&gSbf;B{Rm=8nyHYlAHCkS`uI4=C+ zy|eI(o1L}vW-4^bjmz&@rwW;=>NH0_oelfD5k}6#9i{#R7`gv2B06w!fE=3|?rbm_ zh53D_S(5?`=k$?4TUI zw^2Cm4MtA~_~*+kMBHcQ3GXI20QZVVUwRtyJcO>TNvW!YkwbAAwwbOErg{1XP7;;h zt&g^yGBmX!ZZwON1lJO_&at|%z!)fx0M~dD8^q-X)TQ;02(k;iDQ_JQRz1)NpC8G0 zoJ9r~K))59E}m4IU^IQ~?7}8N<*OH*jU6*xco<-{AcCSaaX}7OL?$GQ+(^3GIu}n2 ztj_n?k>IZH^GUF7@JDVowM}Hci-ZBw@v8IC$mE;5m@5naXB2r-2kCnGXYAI6!G>g& zH>P>wus}`xL#q-G>+5jRH%>p=k&Xej5BU#6RrSmm%y^Y-A{(m*11^pNZyC&*C9+)D z2>0GRV&ib}rI;fFfRm2zm?k;p;5baMg_l!x_^5z%RL5W)(7rLhuXd4)c%@|k&}ra= zrL#o-;6Z$OZS4stzI)nf=+k#&KAyQ&EZe7G^xFyeL!vnq(MLZSTfj4*JM~6~L${{~ zIxGn2zA!jgo`#g~o+cjA<{T~iQz^2Law8mpSPjY$? zbS9DUV*Bjq(`$`6#$8VryU8OT6W0ueb6@{E==LA0_#pc~R`CrOt0%+3s0n7Y6;IEx z?KrihVSME`=hO6dc2u*yMjYTFnG3bFbHHYTiNnNym6?r0W?exW=XK&v1eov`{AVzf zg%|Fg^dBZr$05^?o3J4*z&yOw3DFGG6i}2QmJfc^`upJ++l?nMqo%|VQHHm3*lM0( zXAvktWYB2wd-d9gcl*Sg>cgE38=kon7iC3zd+guST&HSopBu+)5YN`>D%?{5xPX@Q zpeM(7dDO#!NSyjVG-A~=N2sim4<@;!6AAQ)%D!4Fo2^MOlH-L&P>m_CFz#20%R=_= z6S(&@`T&2&d(qn4p;0FkCvo~8UVP=1s{>ygix_DX0R9|wKg!8v`=g`0>40=fteIRJ~;Na zowe8F{&*|k27cFG8!vyZT~+#YN{PmIzQ!(&{3XxG(Hot;>$wNty2`<3e#&d>l^M_R#_iZPTfW2rN?|n20t&uio)C>kc;MP+1{PJrwYsyAw4v!X5 z4ZF)c+AT0S<);@OmYT-D%%9yxQEP3~ZqX`wLlSHg`m7I?V|udlh7Md|&vnLUAENdW z4h%r>_Srnj=DLn}FW*Xj<7W4DCR%Axf6@4`&O>$QtdRf2qg zeMzYA^CJfu_au6U?t2sBFM*AU9`mf7UhBqcOIT-nB-8j(!6IFs*ie=T`3>EL#b?AX zi<=d-4{Ib*uEJ(GJ~z=P%0G9*?@%B$|GJf^=lzbN``X>8%BCC-F=BkFB{<`njKULIkQpb*C&#)*>aQ~uR^p5%7(P3wS z0^!eQ-ul@bIP^xxU^<;Lt4KERZuwg=B`qn>u|D<)e-9)*(Aesl8e_GKp1)&#J7nCo zh3XIQIP@3iSj8UQ=2XNYo8+$Yj>@-Bl!bR}Q-*b+lqUXX7iAXx6Y%QR)8bJwb2QT} zwg_U1bwvx~su|F666tOC=ruIm6#UVj2?D`|Y<6qxOx^5ry#(~bU=Lk-rYP;5 z>q?qOf?Z#C6s?*Nlc&n(pNNz2LwQ3^}Fa1U`eLhk%@{hE_&PN_MiTI}WaUn-w!um6e>R`rz;vO7`ZN<<`r2J16k(H*px_WAQsnsIh$k2k*Q@_F8zIx!iEF^yjQdhSSNw!Z+Cp~S9 z4|Pj_>Qq*m6TUs5n^HR^QnfYZoY4I?#73ltz&y3v&f&#MFOGS2+KOmr67B$zdAsqs zzC|r<xt68A702aXlv6ue zCIIu9P6ozB1ebf?(Ay~$nte3oR;uo;PQZw02qw70P(N=U*L6Zm`L0QpIK-<3ChVe(ErS9U+2p@0FK=8wb)%k!tqR_^ZjA`Rl?%M`xM!Bjf#AJ9& zne2e*oSJj|x3svuS2tT3+t$;?qQwaNW_-0Z5tF&y9oHKvoBEa->lt|7@1a}XhPH_4 zMyY}l4w~y3>YJY9$v}sQ`P!l)+;kSc`W7)xuk-!Vp0|P17)=~Ui-aaYq<6hPHKOwnnXH zCU)E?$hfxA{1*g#H5UL3&>nEBi{g5lkYh489MF+?XP8HL*M8 zv`;+m!*g}RZ3mz8?~O_inl-L&B<64tUZ_XduR*e7p9djgbAlG3k*!0%P3m&4#K@@G zNOrHJ<45F21mjhNk^VR<@*+GVzEv(q>!{llw-A5SlL1zjBazaiR?TgtUFfDqeli{E z(-b<7%hfmm`$IFFh5@!~iPP+ru3x;T>tD^AZ5y6_;V&EE;X!MLEfm`F`@g^Yyv34I zm1W&UgQpvd2*txgS;Z-jbLRM3%Ux$73#AC92Pl3_HbgXzTpBx<=n^mhaEC5qE{z9V zx#WP$RON{m1FE&7=yD>*Xt&b>6Cv&}l}B|KufKTgf3P!aqy@6l!-W}#hz);YfrEs5 zp~#dd8ejY^V8>Se?#m{P4z3mcTrh0_Ayq4efu;tDuY#g|NqI)XReN7e=aWKvlo5J!F0GnxJ0W6tkdbc0AIuK)KGfxCtyNvJ*_xdb zX=(B3@}^J~veV)hU_mcit~ZN_ci489XT>t1UEWEE**Hn%KBIqR#7k=)^CV`Qb-vZ+ z^GpmZI^k<>iHsmPr&!`Ib)^CyJqxZ9S}FE#t#7}mlmEdgMF19h&d_zgg$`)ssUW>Z z(M4r;kN^f9l0jgi5Cpq$4IQK4Qb9AtGhmHhY|n^==GlRV$hYTtE96#Of9-2zXLIm5KdCC_I&(>JYYIO76NZB+7# zv&V`U284Z~_?2+%r#HB4w57z1#@S(m7yb6vm-AnQk05He9Py4KIywN4nT2gGJz0Ns z`GR+$B_E)&J2@B(@lQ|7imH9PZ{0eXn|Ef3Ce6x-O8c8uJ(+6lVn!yHIQN$iL2-i| z-xov)tdC27&fRFSB(JdWu+=X$30ddk;^3GYUpO)1!pc3nT^oUlXq7eZ5kPF5)yG^+ z(gac>P0Fn#vo#s>yej*Hs-ev8=Jd?>JgVf6NQkUW&Y56V?xZV=I0$M`#j$89d}3dr zsnO3~zCa5hI%tH`r6e4aUt#W*B9Jrz}Q3TIq)4<>a zw*L`5bETtRNeA0~i@rmGpuBabq)#sm2?Tcdxpe_!Q_+Ns-*4W&>TtE(Bby)afl=*R zZ|a#2iG{{fMmk(4SCmxlLlx$oCywZpEC+fAZRQEi2J9gu`rA&h8C+ zM<^)wmIUu2LWK5Oy5~(AO7t6D6*mzwb`p{J2S3WUbee_ir@qpI(+6*-vKIFmf2F}& zCm)tu4fHB+wj=52U+v;#P0ls2>OQz`E^sB}3YDg?8O+o^%fvtHnxEG^bFjNoFi86_ zGOEm_AFkW`@;Y|*=~Xt1G8jF0HP>E1U+bkXCPc7<7nXI}%d;Yt9BjAQDKuega9wYkSG{VAfF!!M3M3DU7p5jof&)Z}VDon#U4Vkt)Q zU}^>M0K!eK&lPX*geDRVVO6{^EF{TCy;qAp!@i=wT(H*GgEybenZh?-RTKV*{& zqPcw`Pbu_DPnSnnkV`D}JuqE7?u%Co2D)QAnxtsNc3kB|@8!p1nqWj3$0Jib({hVJ zQFp_Rx?@YrMe*_udfdySq_ETRPG?l($LoZTPhN4jpg5(VkI<&;HtoG$uxO$SNL2y@ny4JEs3xkBWdSgbdwI~7<}&G~?yxk~RV zWiHY}=10b+R)n97Z*ZEe&r3%I&dwx-dCMlSfeGjR^FBXF!J_$%0F;^ggNsdAIN+92 zOywjI%Q2)kjS5xGw;4l3WbpjLQO~B^ove0Fe5YEG2LPi41j4@&*il`}xU}3$I(A4t zYP`2!+!Z{?@Yvi3`{A4yz~xJ2uENaZe~e_W`l%RN)q5`NOU#=eFY&sL?WBLZQSbj4 zbZ=lbyvbPkAW-0PQlolJG`7ph<-~@~)A;gNRB@+>hCzKdwb(5KnWrfP-vi3WKvqBR3c=Mc6X{>B^3CIuWYnGmKeqbWFbkt zyo!nnF1V^@cv={okc&CyZp#hX!)3LiU+LC;hiOI(js5Xb#Haq`_rO4Y>Y0kWuGOod zv^;`D<_mHGWFg#G?hTO~_R!^>gi~`VyB3FwJLiV_&Ln~N8=_YZN!vZbPfk7@z>N*V zsPRX7Io<4z?EOJiautC#7zh+SI!2$Sr_N{0sC#c#b|xmGm1V0UZ$7(0Oeoe3OkS>p zkeJ6BZ;3|C%U)l`&RcxtRyh28fsK3V-nvFZ2utqddhCD~F?6?J;ZDDD$KN;U|IyzO zW`S*59VJ8nLj<(p2MU+a^zf%=#x3|-68q872Z$%VBo@LsV|7zM4U7nJgzx6`fsG=qBB!1!suT;3AWkw zP=If;H&;c2eEn=+gR$P-N57;4OlzA^0s0EoG++Ql2?WLMvwb+~5)S(&Jq%^g8&UWL z5_k7^)c}OmU3wao`hI$)RXQlT|Kv-FNba7XbA0#^xF<4&=bPa2IAOU?9)&N(N`b8@ypr3xGPq3IBSER$L_h_wdpXPbl0I{ z{cFleyf87PzP7VDflk^&UQ3^&>d8Q=#aeR!7!uqej_sdY%iS~$xi{5w8OfxnI*$69 zt-af!YbE}R3;NuLO%>BlDnqO&@Djz83)e4~A`&&5L0D^d!2(@p6C$?fbHBdi@k4}dNsm~K~T(L;~Q2_xK|wuHdoi2r>n2|lk+9z8Kbw1 zv#wLPv*ank4Zm>Cv3BmQl68Onr!fJo)rh+i&0h_yH1giufZJLnyOgj znDX*~uq`ToUD0$3cM*vXm)e%Pd8wefmb&4+?#Q2IwOTWQk4ChB*H{;89&QT1v3?vJ z<9E3S9gC!ra4l1V|1v&^bcn2^Lvgs?{E@W(M>2P>T$>lVk z40>Qs89rixKFquYS<87pXj8>A&ARP1a&#Ct8rb^NxHJ|8#*K3l^afeRN<((P{Q
I{n7pFrSkkZ|qH%0&xUH^>VAx~3-B ztcxgQ6}vJR;>@i9aO{Q3GJxK!1O=?K5ex0{yk$x@-N3p4KZmW%h&i7@8U+9`AK(?0 zuBTZ|-WtdMTs29t`HEk7VYb#K5V$mQ!H*=+Kb7(!#4@i75y8`n?7G5pr?J7Bsg`Nu z_X@1LEYA5tg&SuY*^4!b_gmMUB}qaXY5ilZ`{R{r%cqW?<5f3T&*^KD(1ti(V?G}d zTAn^(*qjz&G+1+0yTH}NqNV@oeN67rFVy8-XArjgQm2)-#BjNEXRQbtn_rfgWDK=g;$7_1mns5l$=YPY zwn5T^T-}sBXaq53K&Lbz=DJyZVt^c^uLFRO_AH_ltv7r|?^-$Bzpo1_i^V0w0H;5+^rr&Nx!o zVS}Rx3ZaRXC|e z1N9hzI=w+t7M!=39bXr@3C27Ql09k?$OQr39Io$ zO`SihCrOaTOm!dGPRvPZj*bYWXJKx7X}##=3VhSvfv@kk)t1gkhx^r-7F_)Nm=#cYJ#jo`~THv8Jhsz z0Qj$7U|kRzs~FMa;4(sLBFL!BF>c>>gHR=nE5@EsO7sb`0Kr6H1bL(dF_ z4Ou9RBOk)*@f(*k?k~fzHbvsBr&_yiKU^_RK4nqwZ=AK8$e6O5aXDGEC?=pQ4TA#E zL+P1)CV*5-iSKp_)XKya#Bn$1rATo{{o4;8S?%KmMN3FLvfnrqk4*Unz(hL`TReHt zyP=Ytzi>KOkztUcjW}Hip~p=EkbXQpWY@5iNHa3$=@?xJVGu4~!ark4CB{InoPCAt z-Gp#1uOu8Rtk-)kl6JKRft*6PapNc6eR%?dYfKnuilv*mWSFS3qd2CUc0GJ4WH4zI^vZ}K`TfW=f_E39Lu}?j%Byvy{h=A zeNTpCe)S9`3IYN{4k2uq^5$rQOJa}{(*=E6os1tx1gNWh)YsvyOk#LG!J{Pjp<;eH^Mk{PpLUJ7%;%aJ;aprQPVy|sl$K{ z8ID1%MZ>G1v|jykYU2#Iw^QyJ=H2Qf*3h$yP@SFy@({v5UTWk}HmT z^($k9TBkmTlY-1_Yo;3r5m-u1G|O+(;D%a z_G&O-S2aNfIn)7T)p4x};}IC$K)8+y(V~Zy2c}ZKc`uUgGSJ)=tV{JQh#VT=+cy6^ zu9ns*66Sp0cKYNc_bSbCTGHKII55Byn0M}$0k#((3`it{!Cov}BMTaQVWfhqxv1?I zpZm_e5!y%i9UCxvjx%K4a=g6{X%Zxo$T}u6#DNuaXQ6;m*Sue0$bOBsiWF=9aPFB5 zsB8moCp2{GXRkUzLzdGX0lyHRo!!rB4pIn0N}#6Fl|^*r=S>gB3Nn44WuBZN6?NiOPHQ^Q-BNi^`8{gX+y8P*56qZKj-Gk7 z2AlInQ1omsFK`}CUiY1ok^ZS0z4&knixxgU-%X7BgHZk`3#qX$&`czCX5Ildfi!}+ zMBK27z25LB(9 z)5?1>x*vIWFA75(aId3_5CgcC2NyeY8tDdnEuKev?b`>)O1w_Xd-0g(^1(7cf5n%! zey^TO%Cm}5)HT4X#r*ONM!$g#;{%`qf_NDv7m+1ytu?&MBX4YzDFXD*;60#IK*i{VtdV9hP0b5-&^G8+5dBaC_8oOyQbnLA=AU?w$j zNi%2Bu4*RYBs$xmfoPr+=4-Fx?Uhwt@v>D4 z9jdV-o8!EZZGq|(N;)5(ep8u)qvP5Wf*!V@TU={8e?ZRJqRvE*r{RP7TB<~8_kpg4 zhHJ#04MFb2+2rIT-rDstuled5EcRp8H`W0nZq0A4JXBxbEv^@aw1{MTALGf-ZnWtz z$Q2uWudeMs{jFbeZ+<%BFeLDx{o=vUpn2NR1N-p`!w|1?^I627`cYG{_ZSOS441+~ ztI@Szf(?R$NwANhyk2RNJ&E7Oa@<|iXVBE}qN$5t?ZI91g6zPdg`!?QFHz*2RN15U zizuDFI)P7e(F~_%ysH5v>@?u%r`g}pR}Zh4(xYHVrXoA$?b7eF^=nU&3BAe(C3de{ zlb_4%$<7*iBA^AC`gIYr>+|oWxx=Xs4a%>STWdVu>m{)OK{%b2GwW?hLH=})aLqeS z_1fE_95K)syHH_Y{2lp0*+-t@=>Smtsm4fH03m~@W1*X_M1`Ed{{A9_!tNj^dh4FtQ1Yu4gH zaeF}N1-1HVtcReBaGt5PdsB0_6?NENaMJy$o>*hugTseAsSLu$I9cK=q^?alP0S}I z8HLvPFP^J13|;=pd#@M|y<`skcqDjJIAtBK{a|u1VJzogy1ao=hhS+}-^>5dj{{(@ zyQ@|eDjZ5-*Ye$~|HXF_`ZjsnF1DC$gk@9gD+hzVSN@k{?*8i2?M?O0sHmMMxZCX! zZo;jvo%vZoA73c?_u}+)aZ%Xt(KL-5FNsY01g}z>v^5@bna_3tBfT)k@TD*AV^07b z?!7bqF_fm*2;~SjP~j>T5b!3ODL(MsS7e7>`z4+*h`z+=79)5^QcsLtbj@>>hMcja zw;T1OZh__Z)@soXfUL`blHzZubkJe#l4X$L>`6_IP%o1zKbZ07nXFq; zw+Vvh{yd%bJ5D+2Hns2tGLzaLgW0I8M-or7-#%2$V`i2fErnqRq9XB6Gzm?2OSU7e zJG)5{ID~wOuiaS@>28T|{&Vz1kw>o`S7We_)v0ip+$P7Lu)0#$t;T7Ur;4W}6)ySn zL2b$%BuZx9de4olT7vBGLyH32$$P}yi6TO6JNIYOpPSYFNZ#2DH}{hNET=26_Y@gi z1$pe%JUj%5A~F!7Y=tEy_A7W?&if{y!t8A%C*7l&;W{-DPIfq(YeU#%DIZ|hGNqX` z()%oTd}slW@ArPl-sLOTb*t;e5(4>0+wq9 zWGgV(QyU`qNf7^y99gdYeEwOrG~}{>dCen?NqcW?v?ZP2_x$VpxGg>u1r^{Lm0y4v zvp)4QBp?$p(&e7=EjWY%yktA2&Ue~o038*pE#|_l z=8qgHq&M0#T|n8uA$s)32+IgXz-Z}fNqTf+SJ55}B=@q%E-POZ$s9`eOpzE$|4oAa zzUih(x@d)ucY>gaE`Cp}OB%qFYcC|#x0-1dAt0RLuSxV#1tVK5VsHBsDY^KJ^W%UU z`EozvN6JOdc7`3pNR$9STt(I*dB)(n$Zm4eRRf7%hnoomj`JN*z|v{m_TAz(J&C4X z>-(X5WPWhL)%ThB@w(+V-Qn%@Frn#0g<16?H$am#J!l<)dEc|YzKdiaU^zi1e`7I5 zP3$Kyl(bLE1QPEYejb_I;aQW<3{`|^U&IuK*Ifg)i7ohX+7vdHzfUM!tnfoN6bWtg z_DE+jJ5YV`jJNnY9MGlR0MxI$hQH8kpCdME+S^V)%3~3Ea8%lkVQf8LBDGnWzBXmO zmUh?g`pb>StL`%Es{gA>&uiJ}O=v!)#BKK!S~zWO`0)OAU&3vrGoqsVXZ+Lk#Tv7` zTpY2~vwcI)ND;?~DW^cIR+)8sD^MEkq>Ib6V%^PeO$#{tbs*6h&!v9XnOdis%#B;} zDaeo|-eBY;4w8LW_q9b}7_+6M_4?DSq*5$?21T&_sN_XDrH zn8PB`?q1Dsxbwn9WFW2_^WG&|rVQr1m|27-GYsg2<{F*EXa=+r&z?mXf|eZIT+aL6 zT;{@G%fdG^2C5BiOkQ9|kE_j(gGn{BUE zcM_Eh%6Rk_AuUshMc*iC2jvqDh#w|e)$-SF6vENo{<`8^Kcg|cA=Xe5*&9nTU(r`P zNI3STpi_GqIiOJA$=V6>vyZ&tAx#~5byx>P3)rSyC+#tJpOhHXN;LOt!if<>qZK_W z>%8kZ#X$I{JaGvGO1nq(guQd8lZ1CWL)a?tmfnC`K^aUXcV)78sHczIBE(5x5_D_W zFQ-Avkd}vG8oG2jzbI;kqZR+@*M4$XnwM0tZ8BO7cXKd(&g+x3L+T9Rjc9vy6S#r^ z_xizU!2LYXn>D#a^DmSWZ_;cO;5N=#4zBTKEb!Jly49hW+lJKA_4L@=FW33DbJ?c9 zdmM+4kLxqMnI1&3KPB1Lae?Zd`jYTQ)N7#HA3@4)wj;BIOfYlqknF`l_Q8dYcYg0O z->srv?*7F94VJ?2-d+NSz}FwkZa+W%aHtc#)TN+O{96C-%Rb4=CV$=b`GgpsU*~@p z^0<{L@JBL$BW1|Bfu{}OD$L&Qrm^8}3lHIalwQo$r8Hsn`*25_qg6M|eLFdr;5`1r zvA2iK=~P3=)Pm=5MN-WteO-5@^%`O%JKag50I;&Exji%O7rpHwjlcyEpJxR#uALcN5|oIx__qihTtwZiL< z?Ae!DA=-0{;Abf8Irto`DT9h zAdvm5(8B3-V^f~Z8#e*fU3KzS-}M*BdnsqF-0EPq70(V`%#P>{*(0OoM!q7Xqvo6mt99~N_iSZUFqFtKZl=SCld@@*L#|BXa$R(K z0{1tOQ)vZwxHj{Z!-S5q&KsUJd@BI6_%PwkJaucB=C%*AB>}4QnMMNuW+Z?Vdba3F ze@l!ho*}O=4PYOZuw-(l*Bkxh1BP)E|0(K;Avt5#D$;J_GT~>EJOD7p#r3?B3V;RW zXD=J=;p|a{KxXwr^M^4~*Z01xQ-M@VQ`7GX=41B(Pyq8VO!AWjF*G>1rjz--Ygwj4 z_wACrM=8sj#@>T()gJLzoH?#@Wv?=}w!B)&)h^9_Z}RFlC}?Q$ zjc8$=+r^Ssh-UH9<Ek@FupxrE9V9h$pIX zsqZ+rJajbs$hM!U>1C5$%dZ4ukRPAq@=*$jDJd#}vvC{=AJWi*4YB$k2Mig2Pn}?b;(+2CjOKYs%xgB=gy0;?G`}PD`LY(Q5 zM9|}}f2EbtfBk!cIX)6e9fDJ*=Am0Y1I~xxuACg?4XUn9-&B+S+AhjCG$WYr0-WPH z+_`_kZ|>Gj^&4#ymr+F5;GxKZAZX9e?7XomEI$A!arF2w+5G0O8&q`BlAFfm(yNdp zAJJCFv)kG2FJIP)usjsP<^-AyY$)_~dG<9|U3^Nl--jm3Quety#!uPg%EQx8>Gm(| zn=hC3)Ont$4R==|+xVBK1D_7kR#}=~Wywet{H*@v)-yu+JLvlu!P)Ly{ea*8@d+6e zl0Zu=bSg%u6QxHD+-|Yxubv+sJ2BLxi%Wa2?6^#TWfuavAXh7aU#i%L3PL1{EW|-D zp*@rS@WUAt<1d6Y_57UzZKHlUEUUN4n|TqAd&HTZJbB)Rl-S`%@XG`tUU!Xs4~FUP zOnVw9lqVCqz%{%P{D>QU$%ES=IK%kw@^Ltw1wSTAY1l}JC}Qor(@`*J%5UX5LPhUQ z4Bj+$faLP1;6`qj*Q(xyKpc=$Q#G?NI?hbO-L&-+h8UN_cRHilh+Z5G1)6kR4|UHo zG())57*)S8zx(op62*DEY-18*t-}Au*n3Ac)qQKDs7Mi!2m&e{1w@*N^d<<>q$o-+ z(nWgj0Rce-q)BfU5Ge`0H|ag0h7t(UOMp-VB!t|}`<`>|@7(cyW8C$BviF*6u34Tr zpSid2n^xtEuS{;hADA+^(zUi}xBIq%s;a~77SD;}X>d}ZYu|HNB^ z9UJf`a2_Oqh3u3{WV~k(qtJi1#4YAH|NGm-sHN8rmGtHQoZ#lJBU97d)#+xauUy&b z5C)L6-2&U1Y|@1?nS^TeG*Ir(xDIwixGz}EfRkH@=Y!ybS>QGnRx z7z@KNvM-S|2!-DvYu7q6tT>T3M|aAe^qfVZ^|{YkHp!c+*twH!f78IJefvnw-cG}$ z*#RicjcU?8)_~?K8SBIUEO(-ZinvxeA(QG6X>ZhC_R8>0X4u)Nzd(W#pq>M}mIGY# zBSM;Ywo^n*%>jFaqy(G*aD6s(E`k*+AXr%$5p$Nabz>9o)|2{keO-L+aZDDRj}iLZ zSUdh&E%kHD+sWpl`K?p;N87GrQ@;)z6m0W*4Gy=oPf$HYrWu}VmshSWu(*a+h}0Zh z?Mq6i{pny$Feqnb*2p^r7?b4Mm%<4j%0#9iLTR8-kbF?XF>7IBd(viUNv&21i1XD( zvX&IiWTmg^ZEQWlCBr=%I+-b}ZTjo$^==%_pHQ7V7O^(h<35)AY~s9@Pr!>!PhJH2 z>LKAbfFO(Xv$g6evHUkzD|gL8i-FJkg=cQU*wwYzK*nAEKWX>|BpK%Gw59U3@TFJI zehmijd%ka;PVxC5hZ)0LLnCRsCE5iP2E&eNW%U0BeF*UTVc?SYH2%X$Ic(b_5zVA@ zNrutqBksls>382;x7-enG`;8`Iri&VM(Oh_dz`zzmnb$q;$L$Y>XNV>hoGH;%Y~RsaWLp>#o4;`46;*;RiTZ3p9Gk-w_J1=way z@;;2pM2};O(yQYemELG8O%_=%N zBepMv4aLEf?KiJhPF|3}Y@T6BksvS`*}~7+qLzg1N6tB$a+ZYC{=0^?JiT6_I;t1P zse0QaEVzE9osxs&zu8}pPraHzmbWE^MqFB_g&*RYj6E9 z2Z%bCLb9(!1cujzl3e<*Lsu)H`ufHtl7$jb1X@nbk86%S_F0m}D)H3$4o58euRpq1 zEw98ZgKnMEPLp$_zTJx8V54m5ICXZ{O!=~RU5N4%JBUY)!YS-K8>;~bT+?9XNa56T z`&wdPhxCOiUi;ZEYrwE#Z-!n6_!H=j9beFwT^1@%oxgooN|S;cY$#Mgj_nq$ zSa-bQ?Qqu*Afq}q#=3)cFQdCf^IbtVI^#VRDQa35Qe0BMo>a!7^d}Gb)qgX7eZQ=b zq~y@t$9Si%;{8dqaOOzq+U(=*!DFk<&JT~gKnBE2t03T_UA^%Bz)h371)RW!1VGYe z0Z6N!&z6XBuU~xPm~%xOarH-}P}W>kq_{q9{un98|CcxK=RHb}TK?R}TSepNG=~C$ zf<;{wr7IZ#!QJ_xQfm!h1=vl_5CY^mWVeN)!=6PS#0&pz!l^rd)Aaqyev4K{=C7K& z{|njMO$*zkR%cDUH5dG`gr_GV`lK?d&e1@j+aTC=V1)Vh2T6;UImEXswS+`l{*63JG(S(r%1PeaM7+I0caFTz{&o$Tr zOD!`s{1yer?!zdN<$`3(G{2%=+D{#P=8*>{Y&K zD-pJsnJX7T$v&GS*3;WxO|mLBO_AU+@MQ()Vs zlqsOZu5smkZwglJPeTPaK@=RSo?zQSFa=eu+N;Y$#nAI^Rasvnlf{&U?`03GoC2sg z-Z6z<&C&hD&KP|Z6JkfvB*gxZee&9VFsWoH%H!zl`nDqT$ETh}#WV_5&+@cDDcAGp z6hVqq-)q=EafDCBzxz1aN0ssrEpX6>sP@N9Omx(m&I1*Trnl!zBnj4I5t@xbu06ls zu|3yNKkO4#0S-Q1@4u!Z{c{iO=K&eWR#D#!&+Akj4U903b(L)!^j@~9Gnaf;S;NZu zLE$dY$Yx1hhW|uOe(Qd~`2|tUspD!}{O2s1M~css6m}x(p{y|{6bW9L?;W2Al776U zSVl~+q^=J5B>Lyd{lsq3kn-Dpq@d6C4(CJq{)C5Lnhf~7e%BA5%*`9~uF$?ma~Tad zJoIZGZ#c?C%%!Z47CJ;dtREkDqs%%gZb0SgNpF_~1jMxZq-JY)H$R-R-RT$u2N6z^ z$6)>tULNpR7Zsj?kwa7-&@^-&uHW8M{)azdQQW`PTiW2oAa4uuIS6V~Fb_?xBZXM+xGW z>!raDsp+ibIkARqQ(Y&Ke=mi}lT!I+ZeeC3Ke@&t$vOXOcYc1k&%dxBaK?PQckr*W7$4Gx8`BO)U_kh z8BO#Us6+JL3`NK68fMN`%OIS3yq>QYvt!6QQ-00~yUFgF1oMoq`OKCxG&qoc->n+1 z&bo(9_fx-1=ezxseUom^!04r}O}EZ|y$HO+L;aNI-K_j}cF|Y5?yB>H{i21Qo~l?p zi%ex5R>NxZiB#Uy{e#BnqmTgZ$RI#fZh`(YG$AP7%sYA{gvjpvd?&l&bU zj=t$Gi6y6|E*X(V4Wz!&*T-b8xJtmbhV5daM}&{!b@2QAPM0VGZkLwj|2YX(xPdpZ z?H5waEq%0t;K#Gu-AIHG9}WncG4d7tX?*Ap-3T$*FEc*>Y2e1e1p0cza{$D3*UzAn zMB!b&fyJ6ldDL*_r8d$Zm%~%XD#odw{uZH-LUW=y;sZMbXxyBftuUYQG@E&!bDz|@ zm)1S6_Is>UtYO(yu=^#fj!NNLD4I1fbN92|1B#K4XgALr_zTa^{-2&7AI_JT=w!G~(g5S)@QKs~J$(#K0!LU)Qz7VdRQvm-?d+5y%Z!M_y z-484EGhW5o=tkqxj<3FPX*K>0b8;g_P@suGjB6hJA>Afg@X@_x!Iwo2FCyi{dRp89 zkNz?(VFmd3!tDArp_sVA99ev^;SKWAt1*;=-I0u>0Me8GYe(6#IdSp&(yrbgLgdq% z;qSu8$z){XdpKdeasJ}bM@AG&_1>UYyfUkM1k%gXCgrR*AMGv~;tJ8mtvaU5vLDzl z?QyGy-w?VL9yN1s==4Kv*ENn0sR)U36lJ!^)|Jv5;S1Li2gq3krbBefW(1aG7BoG~ zG8%l%I2K>{WH-W@p!Qyn9!}ShkE=_aE%I3?;5mrGaA`CHXeJ0RS0w|LEK)0^`E_w{I9)*a#Gk| ziuXN5CSG%S%`0Ypg(J^Nr2Xru^C``*7)RTH-Z=MG3v;{EE!*@~_? zc>FBI-l1|x^?LG(sA;1{Ur4xC_cyJ@=vHw17Ki;38=T)b+)&D#q7+%QY1UGKJ0#vJ zGQTQAhT?TagUz;HRCHf2;h$NXx)+D7@+m@iMeA!xb@SDb;7(aQ|G;L`c=A zZWPp=`-kbK#^d$lfmai+dLUtS6)rYpE@CxKzJZ-CSZovG>3kO2PZ%`s5p6&P+jLdk zq-ebRtX8~v^Z{9 z$;}*1DN&&7>IPlakB)^HJ;J>iZxT}9j#|M9F!{j$;u_}C{72elHU0eT`F z(%_j5@<++nCH8D2N=z`$2}r1c{(;^&>KD$+t zC3jEczQj&ZtZ=QWOyXOJ*_4&!^QNEV=ud8^nT8ua4^H_?@?4-;GYQp z(a%LUeN?H&eD*`X|2TS-KVI0+&7d5IGMBUQE%n5wt>KwgOEc@SMG6&qBAY}zt<}+) z!Vd8K+sJ*xuKdQM9H%dP+wY8RbJwcJ<4j(9=}lc__5UN(N~48rdP3mnS+s5r(3*q9azJKNWBiq|-<%5>|L_fcO8Pp%B3wr!FV ziiJssfgXeX*G8fE=0w^XcYju&n*wtqoS?Hoe5=QL{Q*NU$#YBSOpC>V_!;{Sh83{- z&`(TW21ic8T~j15?`R`0!zJB8RKM^^@O+gcw0P`_-ki+Y#IT(1*Mf(m-bxq#wK?#8 zrw9M1n_u!`$Te_znR(Q}u(5a+whPIGVRWpxlJo9d?j8Fj!%^5dPGWX%O=DGQ%&mzJ zk(MDSbon2+;v$dD2x0Jhx)pPo#EKC8Yl~CtD_zVdkHL>Y`N<8bH28G)4CC&Yh?;S` zvlEFnfqO58=c%TB=liGxbpQsw`Qd)5Sq`7ZL-e(2uwYdh=vB)}oO<76MXeAzaF3?( zhrg>+OV2yrhbl0n-?6EWtezStM@|E*sg2BVY-37Do0*O z&(pf13k=zX>UiQ}yw=#x9iyr{n*-XN2lDndB|{TvMtbYzWqhGkNIc+_Q5v>q7Bn=GL4HjwjRzzGg=ghO#|>F4pJ5n5MZn4FoSg+7}* z7diQSVl`}FRitu0Kl8Brqp4@=B{M}$@nG&Gt76!^e`itQ+L?go@5j1Dz??9)MpMDd z(}DmnWZe@1nCWM4AHO-~-k~a;Te~*J_)j{Uc-tSd(-t3jB-)eBj12qQst*hKIZgSfCfBA9e*8zTImHvSr8x)$P?wMXqKN_nW-KLb^m9?WX3SH$+4Ve0f6c_&T8~zK6o?EAIz!t*|nK zZ}8*;^`;*%mHvp%OyfqUjR`d`e-XbJfYzz3nO|QZFBYA9dX9Le_d<4UZoKAd(ol8H zQtyA#(@erwhB9k+8U=Vh`Pt$CcG@M2U;hH;2^!PKy(l^jY}ZwDxlws8 zS9zY5;x{MZY*6d9{_qx)GQd2DAV=bmf2) zpB4yZ^G7k~Dt~6ri~Z2&)+-yoBmRQ9R{-0<>vL^%4)8sI8DRLJsQTqy0z8u#|5aiv zhV8aqcSf|7bk(A2xw*T=`it<*7N`EBUpunj8w#R_ZtBn$`6To5L=7>4P{)b36r5>C z1XDGxvny#Efl72pK2nOYk9k(IV~&wbID7k=l}@09#+<6!Yo$I8TR1WrSf44fXuVd7 z4j2l3)}-Vffe?g_`a@nv)vqdlC!ipV#LVKN=S{<3&;Nfx_rPSb`*BCHkv@CgJ7?ui zD&{~V;&5E3lX!}J3tCHkA_H9!eJ8B+ORm?<0$_>d#{yfP@zV1=DiL-vv7PZ{H*iBv zDHjC%*wO!9?tlwOmP}15F}PWhRGU1S`V`c6eEVfhD(G?Z*#ew*Vnc5XTej=J7AyS! z<(vH?J>`watYA0oYUbOKe~<|p{Us3 zxaLH7$Z$4zrhM{YyRil6D_!W@4bIn>-&Iz-)%-z6*EUebRN7`r-(994R2P6ieC~V% zJ!|bdfR86iV3UT7T?~EDlIYMQ{OhxF?F7&BTFD+u5s%RY_X4~d6Yfi6 z=klhN809F%p5ZCQ=*(yA8`#dL-q3;fI!~b6xqCMG8}!N$2zU!TO;ewsKC@nMVmBSP zyA#wX>iw3R618Ogo9>AzUvDta)lbA0R@uum56Oir@MfMLVd74>(3R7XFj z-+n8@!}uL+`!F!>rcD3^2lU2HA1wB+>XG9F8HA>#21l&);n{up^_Z2vAX>9VxmtJ3fT=BY(lj1tK5%C=L1&86t- zN4LD1Q(AD83VSK5JY({RNcyW8bfW9DkJoHw&vR-yT6g` z>=c*OHRKd1I4FQg_#rTM3VU+4>+kG%${lq~(U;$)KcQZC+Pc3=!2JOK9snzZI_t_Y z!YX&{oDoI>3O+`5>bg`3H*snBpdt)@{B3}D1(e(BK*(W2!?j{OMRNY?f`W>pKpUmB z!`xfbg)Po58bwFPWUD+sK6)#{j&beV%Dg0nPhNXt!C~N@b#zTSow#c7m)ge4s~nt3 zbZ_Xq1E^Vj?6A2;TAYY4lM!wo|Dc5`}4WNLL5T zdsV#9vL)lFnSCQllUv+cBrVRp5+kxn<6C@BTlQ}F{Uoq@2RX)c*rEAI>>#R>16H4e zxc+M=Nk7Xd;BWevR*vM85o7RH(vECUju3ZJUuw`tV|k96l0az#?XjryzMn2e0ZXVW zUnVB)giOsV;m}&R^C;SG(q)Y{$|}hWCJ*r@79$&DU$~KT;BY)*ayD#K;iIPVvKTJE zXa`S^B+Az4q)y(3T;l@P+>wt4Ey&)s27?N^EiC>--ySt*M#Vy{NT!u02a~QBqz_fy z3@7Gs=|0zr71R~A==Q>;!=(@(kHKw;lX{LqK40Ds+68v%chA zi^R#Kh=CjYCnrlK-8ID{2=J^H^{>e<>1vZ6-QmqdJmCRq#E&14a?+tUe_ywzJ@OET|`tP(i>h3xbaKDu_M^( zDFszQ1SIaLQCdTaDO~AL8jUgyI4V-5zL6_Sk?S+ve;K@b&k*d2+dN%gO6prS z)DV)}+GmKvgg7dwGXGFebv{3O!#B%nN3bPM>f+r;2V>7>T4vPR$)IfO+2)08FtWAT z?+aI8vPU>VX7!CUG3Ql%I(q0vC3_k~69dNqV?OoBC0NP(YnqBEHexA@2+qbHB~4Fl=A)W~dM%Fq^=BUkahTzTc($ zJd56xdm>x!V}KM3)_NaRs`~l5xtP~-rIS;s&J0udq5J{LiME-RX5C|6tgfeZqRjtj z(yVo>_jF{1>_r`BP>D0prE+-fyoXCy#fZ^x+UTZCp3V2)UIjL7J}paR)QRj$-(6fu zWM5gH>?Iw+i(>62-UH_QGH@8|H%K{H;_X0*H*Rm`P5F!c_3I8~(t$(SYZ$mSOm;`lIn{VuSgt<0WvhsCROlB7MSaeF*m1>WDQop(t zAXwt16l9!cJ9z4^`6S0OjrPAcrdCWu5tplGuDFA}r5j?a>cZmfMMdCR5j=fiiVL^A$jA+Hl5lcjF5~+ zL7z41q81)|n2efdzNqubP8HrjXyY39^Uk4L**zW)KVeBZ-X#4(&aVa>idaI1wj`S~ zOeTDHD+XxHZ<`9Pi}CHXRFjR^+&xr7ZJF$|?$Mj?!Gc|emV%ZPG=8z3yQXLwYBZNG zR8@2L&m=REe5(%(1u236v5~lOP;7+Agn`fx`^eg=`MM{4457HAw!R=H-xJO3R=?%* zcfbb_s2V07-U^LpWe8ty$fd`e+p?Umdzf!!pRFCvT$zjsPI`Zb!#U*%n!A^%ojRdB z$Bom!NAUSM!QRP}q!L#u{RnMWD}q#wmZqrR{MrhJUxk$Jk4~kXAvTW4#n6I5w5`;r>8u()@TP$NhnG?2KDA;IH_GQnl2SJKGok{C6F)kI%Q>8_ z5}`6p%8aky&|zeC;&1Mf%X)kuMi_HdtaF5%2OOQuTrsae>LUiv{6AKLxbiy` zRPAaK7<@qP%0poRuN(Rcbn$rsr7m0IO+Eu2P}i;^MlJldxZG#dL>qf&H(7g)>W>GMz=r3UzWFa4WMIN(>O9gLL@J13OC;f}LhkRjIgD2&T2rctxu`>w{ zIru+nB3oY59g2zcxg`d~Z3mU0L>qds?!KJ_B-W>bW5AsS;SG8$`BK#F*nIP`1DWDv zP)q5;l~eD-(1 zHvB_qUKsfF%&7l=UsY#P4!?cQ%=I|T`=iviA$!zSC}4&5!XMY}ZdwPESiB_CAG&8ArhHq3KtJqj)LOvu;bc)h0E2NYAWi z`K!pU&5pFm$PWbgjpS?1G|hQ7S?Q(0Yr4yS}=5tFG?AB-!8}o3nepeJ_fJ+v0KECWDiV)=+{Bvh>zdMge zxN}5qZR+FYb(DJ8^8`Ze6$~~sHb!y|r;A}1Ot}_Xl2q#|LkUMjceC@IY%Q^VjU_Zq z9Y@328&MJBZ?x3I>XOvxtYhty@G*EZ@z=pl<_6*}D)gyPa9{xzGbZ+X+HU?DeoG3! z)zHz=BbS$Ha4_xU8Y=>;JOA4w)ew8sy)B#fqn#ynw*oAIF>d(sFnf5OmrD20d`(a7 zv7|@f+y|OXO>qcrA<49TdwEON2d0%3cdhpH_YRV(Z`}9(ep|s}K@j#F&nGtH`g6^z zLz>j3o><$WcZ`PHW#3N}d+V<_yCde>?2PZ|ejZ%bB1Ck6i@_kWkMxhNv9;}5KaeN< z1Q9!DH^gB!AK#pCgFIRaZ@fI~)+W{zYrijqnS0nwoaSPTs~g*4ldz9)DkvNf#;_e} zLbSHJ=AN4ud(HMI_XxBsK42g7W>RZN8A>O7C6<`Xxe_z=nah6zlkgM}L+B7!|3XBg zaB1)-J)lX!#b@33Hc1Fn5tPUy1&~lNo5@ zHG__I;}|#gjaMt-l2i0JmDTcD8OV0UtPB1)hxAfhcGH~MwE$Y76n5)B)5l%Y+>?Q1rL4%}sK##2wKv0!4k~4{TCO$Awm(Ol zh^+`Ny;d&xW9Dot$DeR+>}|~yi^Ax0?C{_i9f)F#*z;#W&4^q*!f;EJQddbJ0+Vze zv*$%_;`e5#LwXY|D*!thvMRWt-mqMRb-vf*NTsGtQ2Qh=YUsf76g_Wh6N1U>NRp<(4`2_s&my4v*$d-owqe|C*k}BrVRRl|TwEw~ngU zrrm}lJJdH(r%5c8g(n3djkJQrh0Gp%Eq4}0jq!+oOoa%xK20?rTU2NB$(IwfOSNt| z8|a(hYiZirf0;gKM$kcTy82fjt~xmx2PvSb?XR_~ydUaaaw8drSW_udDCo3rDP|pW z-EQ3DQAVZ0)*0YdvvqFrFYeP44ySk!Ys18>jSf0`_e?6m7sk(8$%Q5{VRR{4vRg9} zwUf%Q;Rcy3Ib(E1+N6u=x-*>?Qu4Ck?j@41{GstOAHR2R(>C`vfdQk`hTc9!ws3BSSIhq9TlZ(Fr91*vawDbQ43LGJ6(`%gLMCh>C13Cb zS&l;0Tbt?diE##@@vcqbc9Wj;U&Fv{L6Dbcd&y=syOchaU`nm*SsJa&19Td4zE;Fb zXqTGGq|RI~Q>I87d~d+%N@C%?3h#ySl;UDL)&k54wV`%P0P(vs{>NT5x!Kq4a)2p^ zc)T*|I!`pw*FTOOwE6W0-t=H$Ics&vt;A0Y-qiSh1Dxa{cUozS_v|06JhXbp<24y2 z_owmoK)R^u?+r2Y_wf;7eEFkfvht4~*Qgv6Smg5NHYw!xWWZmO*SP(Fp06z;sw(;` zSXxK@e0mD2@wuKyY9iu^a{DFDR>dX*kqn|I5lgV&vL|;M-6rbDYrlT}Y~KJW(LY89 z$WChPcel0@&u#G7X;;D*$9vJTRXH~b?9{sLg8p{ik)NEjNpz6%Lj~p@hIIH5cNYuc zXgO^yJOs6wei1 z5W{@Cx#0a`jG8}DGGX2UFB9~GINQ9KL@(bv%97UHf~`d%tOU%->by~ULVdHSr=O| zR9k-i2)V7GA$Uhv1nRU|KD&*ZtDOH#&-bO_2w_ZQob(p~g_DqgI4L$bA zn=ddKF8aZ2`t#@tBv|cH&Qmk}9G}qa*VTGFVl@iUNoT62H*GK-5m8E(DXiB^zBn|u zN`bq{f+h_EQo~P%xY7f46PiE9-DD%D5F{YD7}6mN+I%p4ZQ zb?#4gv+y;CS`mJ^233cNYpVujU#&)x@_2h&4AINpj@=DL zC2`4u4z2r2k$1kCC>t9aqRXs4*9mRgEO8T_hwiTO0rmxIX$3W4Lo$3dph}bMnqwh( zAT+|B*Chia0yjlPMK=&*r2!aB;&WP@RWgya^j0>#YFT+%Q}qybZhm04{zY0c{N#;@K?hy2FR70#L1uf1>KPsNjW zuI;cHnG5c>R7FzjnF}bSnpJ*o!Qt_6kJkT zn#+|fanr?g2ELc1HD0XO|j1xn0sri zzI&bZbyjLmU9|eHU#)XRf8oG=^yonM-h;z5h5XKL>}~G0dFWI_-j$FXPUC-~1dUcVmI<*b1m8;O?z$Ugbk^v z(1IHQw7PuZ85tR4hds$++s&;dQ0UH<1mZujww9Nej1{GSp*9>DGRZ?)O3?b(J5b^a z?w*Os$$3zZf3g=7wnJ|fo~RYme5V4%!P*OK#K+%0 z8SXZEJ~jr20w75OwGLhYB{qxNv@wv8LXz)5Xn;oH!wwAfQUkuXw45JjPn$86%ft=K z(Po*Hjx-2c<$UCzc~4M97r~t04W+9*s{VeU%q&b zG>SjAj?^P1{@07-|MMd02A9Ljt$KEV(`Qdmt@4xEY0fnU++`%pQ7F1D6od?RXm&Ae zluNX_L-iCjHlyWWW+(Qm6sa0CcmBrQyp_MH`8!HZ$Hk{108vCaR!|qN;ia&uyjqtB zz+tW!DIu$@q$|3|0ekvxr4`7ePPf)eAG`2m?-ysFz$_r$XF^G|}u);!%5eOypL(gJ}G+otS= zRt&nk(@nS(#0|B5+yk!ain?{v!TQX^~5qFH4Vd5x5E+!zGfM+6jEIFuq!O+Syox+4h&^TofA;n~&E1>|>Skg2hdP z8ObDKI~d>J-(O$o6+p~xXlR)4E=TEg(uZH4$jDx+m``E94(#aug=YrV#g4WRlMSy1 zx@F7tIK(0%u3n3?&n`K7{{qST`+tHHzQto#*qc1eF-k}kw`@NkD0e0#>N$8lq7^TlKlc~B?&=L4;$rql_3Tfchn0!Q`bJ68BE&DCx6R)EMkZb{jp6=LZYk$ z3dL4%w99zdr7XO=!+$|I+t?>dTCY0_d1CG2(1)lGI=nEg%k|hTq@B#y> zHBHXi`BB_&Rs86kZ9i7*92nwVf>ccAqBVQ>F;Vz6`j)-FkN`=!0-DI(CTx;XeOFR% z<~S!)zv|1iNh*@3h5!PoU`unyJ(o_rV(4nKi{p!iE8Zp&K6)wjBT0*`iGMXF?-!ix0%RSH4rf%T7H7tz$C)#{8 zYqCNe$IQTMy^B=uBV*NYf^JOhMi zSOZ_Rb2$Sf=M0I|yTh~5xnho}rZHo`2N+3(iy|4YbZP*lN^?}`bxc6@0kRTbpyiQT z(|RGWzOwWEV2w51hW<3?R_*bonM+7MBo?pwej24j)qXMnID`IfDzL9FYM3+?(H*fMEnsr;F6GT~z^gz9h>QAg=S+LDY2FD>7X*U@ zif++Yt-|`I4>Gt+eP0Cv2{&GH4K~qNePD2Ag&~*%5)k*aHWqShKDyETFrfJsD)8NU zQ$0DH_3`qWRu*+4j~p)#jq#)7Rc7D_dH%I@BqaR|3aT_ocfyw!liX*@u6mUAKQ5MC zQqG$zI%45HSKBk$N&OZ)ZaAu`{Ta?G)+4hyd6~SQZCTwotx>6UEbfa?r`6CvMEn zv2;YEko<`_)p~5!YT?pKP`fZl|L8<~t7i2NNVz=1|0iIMCbrrkmlh6RK>&k4sq(6a z-68N-bzKPl&HIN{HBqxzS{Y(;*L1d^_>1Dy6qZbwvrYXu?MAg?Gc?(+=^DvXvkRNa zSc}?;rcXY<&$CnYmz9_WM*6!CT{a_=T5BoM%`9;CulJuit6vh)OM9Wwn%lq13Of-!!e+?B`qoe5>y4R$JKd zUfXdCLr*_tLT&zc3+%AIx#@7XKe*s%I|jccy+Uo6GNf_X0LqcK= z*qyNT)M|~aI+4!$&yEMpy$^ZLMfY|#$kfgq3r-Krhbj!@c)GMF8k8I$dfs*)N%fU% zY9xds?dO=Z*b2M?_|Q?$F4Q2PHy*QIdUAx*K-F%T5hHxfWj)W&$j?RJH7Yn9VdKzi z7SW;}%5oFIr7XZ@?CdQ9eDCr!~I!iqh zXwD)y8yxp-AYNAiw$Sr7kl*v+K5oCGY$+4(Zi;oN0<*WZwKHY?Z z^i&pp{Z{--c=SN)^FhvW(}f@Wrs8iuARw-JxD!-OXEz^kq*1CJ>yxU=jh9f{+Q`mY z&8p-(wDIcLWbpy7B&SxEaI=I@`5@Af~aRMUFx;{uaIQ0(?n*nkGLMHwOJ;`?|kN3mWX^>an77+-}n#a2gU$xdlM> zyG?){8NS;-7^5FW!hco6A*OY;{8_)0DM-F`P$&+YrEB_bQ;mzzniSrgOyzaJI*>C) zs*f?YD33VHk88nJ57&)Mqb4^+uYFrjz{^mI53InxP1i4XcyJsxPA% z-@J03cnKYTAr;_nD4GZ{{T}?0)KtvI3f~;e7Z)c=0gT$XaO^oM6q??l?SLwp1(B74#S0dIdtr84;tPxUOMkwNo z(3D%<4Y?J8zVGGPn;-2wWvAWbwza`hr>vQz*9@P0&U zp{QMTTZMd#1p&dci8hJ@!jNgV@++~q@0@M{A8?)8aQ-PbHlu6A`beqKQJt^Y7AgE! z7UPGJ1Y)$fWC4ow%9WHWL8WH*jzIoDJe)&-F{aAiS>l*3{^ECh6y;tsC!9>n<y*%&xLz_^C9=NV8>g(be-N~@XWR?yc) zH1Gb7@e|t_>=@lo>uG31NdZr3YpAJU|5+`#7NOxPlST~$M(hlv9Pfl0N7t8ootTDk zZc&)e_wer4`nywaL|wY7U{%sw>odGbL6X%%xio82Yz{wALhiZLn{#i-8S+Dpo7Vmb zRAmtv>`JdvVMrfyX;)~B;8K_TV<7w@W+=6b-$ewj0L{4MZ(QJpqloGj!@Gvv-M7m3Xdqm}+(z<=e?^jQA+PvLDwlE& z?WT-ckq%$)K_-{dZ?)OiYZW$1>b;o- zm~{Mo;XooeEI7h`pAkctrRS4Jx*T^!6EkhRTdV;7=k=8HkFNF26@$k73hapvuKQ41 zo;8R~H~&UEdmi7=zj4-hRABDcvAcmCbM*A|SLha5wddAW(k%|wcG8y?w(s1_YYxjt zDm#iHENYx5OPI>kUS$a@J`PetPrLfpU1ZMJ0L1koObXGdWbec~=*nl6K*-?%Xb#Uq zYl3_vEE5mt<%8EC%aZ&suO)w-WjS>Bgi&heTtkD%XWH=b@$o6NEG(e)WhNCl#CZo= zT-^Ax8R=IMOq4?(5oVcS^+(3Wf*VPVTaSXlHGhA^D}0*P4t10Kp*7acVU&2wdTIW z0NTJ`=ieq|{!q$!RJLz)Er6l#muxm3{HMws`d*}SChc7Na($a!^Adn1DL1sO=An=@ zMSm7cDKFmlf!AT42{7eyvPdWqR6^l2i{S$6FC6PSKGv`3Gu5qeI1&j7!`wufBiNw zQYpa7;a)=8W;kA21*0xdahMzJ7n6<#F4k~9#lPvCEJ_3ESm~v#+-2pTHY>EIh-!tX zj8FysoP89q&}95P>WQskk*va}?79}60WIipoU@uk4PwZMIM$$5>>}ZkPs&raEo`Ue zy5Ha<;o?`fHGKytVTJf+H&*%hrPEBifj4)MM?1B5G1$-2pS0P8XyT02sT)- z;Ox8zn_a=@|@Ve|4(x(%wDr7CGE+Jn@T`0rV(kVxcg!oXzNd{(@x{3 zmawhoPnti&=r58Aa>W0XMT0h071>{2TpBs<#%CaKKR&Hu9tnxWMQ6}ucmQ!( zrGA5>aH1Zg!0I%!q<06e23Rk%J6Tw1lTFFat5vSOv56I6g__)o`r538x$Yxmg8MP2F8e9#8J||9-&+@0lU+>|48hqlPTXTEJ6&fyea)jQ+#fx48DJJ zpdm;-lP6C6T<&wa<9z~nA`~$I(lKeiX!*blL4To|*jWzPv7f-4^5b=dy!CNMp%7J2 zklAgn53JNG-C`$pyA>AqoTy!>w{AC;d`|(SfGHjx@N7s1I}o4$TW^Q(SMIa$2dc5qkDC(8`YGGhu|I-2gme})H66ifW?A;*l( z*8IV_ZPndr8(dykc=jXuV?TEN<9l#2)H4qOh&=~UwM-H|s9}Kh`3b@>R=mvnuT^`6 zOC0T1svNVS8NjKchU1qA)OsXh=55yw#aBrYh$Oo_f8FjVjDGvP#%q=M0#2 zet(n%oP)oky<;eX+9e5^DEZh`+S3vv_%}nFqk^PbKTe||ZK^@B<+4b;!*z&Il1it%NJuIBRJ`S3HNoe5z*|5K~#KUj11pkLk zobW!xxLS#je7e?ced1n&pWe|JCFyi5QAoxCu(^>;HV@ zZ$TwPtoIi>1m~a&ynhBl6B|X5`0K@NcWM-|>O>ZVP4|P4{H#8ThRUfj|33-_=;AG~ zn*{5{_x)hO|B`F=0;1pg7d`>Eu?E-(ES31a-*5u|s!R0zHMXx6{X8~p#LK@)u4D7X zc3|4Y_#aJ|Bz<)uPSm2C89b1biW&@5)Lk&-xoVRLO+1g^v`$=0~@Jua14$J zR19bxn16RTo9}^R6t_^s6_4IL|EDu_4=j>Bj^m&GEdlfo#gG0eTlR+Q@3L^eA6@|! z#J&;w`{D4|1M~SPNn`uQDzpCGIiJkmhS6`z;_d#K=aUcrv}PD@{&#EH^uJm2VSJu{ znfa%4LJvmspKjCv+!NsCTHOB%q0|VCEV;;Sv9uw{(#5=#Gf=m8?7|L!yWz6y|EWeQ zfd#>o4Nd%t;3DkI?lvdmnN z4pkieH;?-4eJ4Sr(llWSo5}p)X!~v(>CX;Ajx4%}mz^}lcKiOeR3OX&icGcVJ6OcT zx9uD!&@BDHeu~6(@3A>V&1F~4YC+2SAReF#)WH1YZV>JkV~U$hJP-jV_Ya?F{8^Sh zqYow(+#516lwKt)EOA^AiKG=96!iX^XX&ksKTKEWo9FN9CI1;g=F+#32IOg-3B^N6-~#2p=Q|NaJoIRF=1xZte7eef{9B~>kOO&9Do&)1(GUhO zOJ&tCD8KT9c`=|g*;ZY46e_O*$D+)HhI^fMTr6&FrXk*5OnrM_xj53j)XdE~Z*)M? zrer@f9G^}9xsBF^1X-Wz*MFozB~P$TVNN zm$)w+O&SpC$@KoLhR~kbyB?t>y!70%xx1mne3nt{V_aC{g+aw#qQRXBIIVS)52QQl zrfy!+%~1);t0)8ZfApoWeaj@VmEtK^prCoQJWE0BYmTf-BNCupQ7j4U*oiq@1k&wK zGuF#Ys&9t3sJB$WC9=au)B_;@cbIu3mE19jdLgBvi*@1GcM99<8i|#b%`9?Y=}vQP zGI>oV95gF8>BMu;zKP`zl^9#^2n50oBVragv2G63$0^o^%o4!#K3L2QZxN^fOLuRC zUohQE-nY>0q30(yz&4+{v#QgP_jb2ulDRgW*vpTzk{V$VXy5X~7!6nsup(yW-ltY? z{oMQRm1N{TR{7d+=DY|(4$$#)b;M@6HN3zk3o@3OWL#wDl&0~q@6-(Ww5p(YFx zP9c()Zi_{$97M|lpN2$Wq_F(>iQV4GtwG;l2?pY#m|uZXk*TWK2(Gl4#r6TFcK{?) z{Bz%E?&fqD;if_B=Ej1vU#0vLZ>!?d&j^0AQxm+e$}P5GpnCr?`oJi52sQ0A#f|~! z+*q)x9DYYCOQQJEnFR&wz|kac^=t~1E$()pQrI~v@ng26z{}FYJe}x;H*5kZR|@3Y z>=~YsK<`AEQN%$mI_A|j(@kGt>QoeK>Q3&ckAo`PV9g018=^S{2vV@+VWobRlP2N) z_un}Y>B}pGDW=YN!6Y12#_;9)*z(d@^L>k6#M=ay$b4r|^=eHKv^+vt8w)(HhyIQ1 zRRb&k$v-`dO8m~2dDh_`4dF!Xyy!l4_fCLJfeLBPr;-kqk0|mn z`x~2<57TOuR@Rha+x7a%Ij(A0#j%5L!*hhu13%uuXpniLPO9&+uskhtUf<|pRd6=y)B9&97aXu#*)+GDeI+Ep$>@6ojXD?n$Ld{ud zr*Ktg-FAc%z98s99s{xu!6<|hIG`$6TkoWgdArZzR1f~JG1iFBJhIW4*E0oSp!gNO zy-%aK70lt}^!k10a`7<#Tac5dk(%f2Ti>;M4Te~o-yO&u`u(x{#DCIi(K zdyzADIh^;I&CQ%x#7KeHmj+;6|MRKIkahntsH}3|<@WOp{=QC10H(PPJl|)OqY91C zb1^o&t(xiOKI7}{WDT!kc*X+gA0J$a-%?|AkMbMH&N<3Zw=64criyMb%20zf3s6e% zJNS064A=y0rK^pbr-`XgrYtuO3ryZ%~dd|cIXUg4$>)H{i)&sBzglF^u4%ZGDTuC zZYes8C+c95a^^8r?9n+uY&C@WQCADR`lZ8OouNYq+fJGSt;m?ptuzWPdHu88OP3#; zHwlPQ@e?YNs6shA=U#7VQNYu6X)A@% z)KnurlI6C#X}Ku)i+lv_O=2`LudLD=UT_^S@o~b76^P{A-#?ZCB)Ae-_4k$z5fsee z);3l`3jb>&U5Ga#7{ zl4}}F;fT+oE&~w=1#*+h_b(*cR?ABb!VEa}%J+U2J?}z(cCN|O2&uGMbIm$g*i5ld zbSBAeL&VFAD$n(RtBBX;|~r+NV+93bx{QdzP@K z_P)~Txw>ZH+lD{9EX9Ka`*(tHW*~*_COVK z%mtRFss5^M+Hp$?SiC;!4qqt^?ZMDogY`ESMNvH(dm^Id4G7D~s)hhWQc)Dp#cAp_ zqHGRqg}((?w2YSZfh#g9@5pK`36)llfeWuR2m&X0{ZdLgs8Q>NuS8Hmp@#OpoZ1zq zA%J{+1gFUP$FzF>%s)9T*98zj_Aqdxmy-(AnY!BYs+tdC;EA0@F7J+tB(6oC;$7Z; z{tsQ>qKi(e$@(lQ#VGX+tiRRwFseWv=f3eh6txi1O-wY10jzq0KPu)8_rD}b7Dnhr zGUsh$B~wn*vpqko{c;HTGOfvKjLzWIt`M(*fUWS~BIPd+3(%(V7TEfse7|%`t7akF z#EogjB)YZrx;u zcPFrpCQE>5ndq5i*Y^MM!gw2yytlW(f+`d}{^^0urpHjD3))o3;&t4&BXsFpi_UW9xqW;1@{`P4v zB%@0CD>GTK#Q7r)1mNJ@Fn^oFzo#x(jZ{yj7J-Ts`Uq~g{EuvTm>bcXvH&UqD_U-{ z3@3htI@_!@Zr2Q`@NfKf}&|@Mp2YDZac6{L`M|ZwLYy zJ+aEass>`%9;Op;NRyb**&lLgv!nkmlmtw%A@xJQY=5I*Kq>9-m4Ev4^zokw0etJJ zKf|o}0IolLviqd|jFF}CWromo&-$~*Xmh3;nJ*g3~z0xp6bk=b-i1&@8{qfk0Tk`4?94p z95;dN-K!Uc`|h4w@B5xK_DMnFgzL|oZxq3Vl?09hR=Jb+w`cs$=R>$Yn1AAp70(*F zAy><8;M?;4%)(P6p^M8J^Yel+C#ai}FWwtubD?qlkmJ?s*l6=MO|>3FJ{E`&?`uXw zcRq>gQXsk65{9)=fmH#CTWGxUAhcj(m8f_vV`Fz<4%#+h879{1&)5nsu>T(AfPW;7 z0~K)dR=VMOn-?xvseX-b8wgolJedL|;oIs?8(5;sqr>JIuG1|j=Dd9&`=OP}yOlkg z=T>#9*3J!YY0sSJyyen0p?eB$_qV{o13FO1av`0(2uMT|DnL)dowc0A>~$Yu`Ry!z zB4Fw`b{OcoLFLwS6nCdlx39mhP3TN+wfgRyGM>woaKU~1vxXVWb-syvH_*z_8<&+* zu25)KytQ?GV?Jav4f4FH@4yYf6sOCpy8qx}Zr*rnR459_hXzi*4ec?2VQri9CZ5%B*Z2Tx=@$c2t{7xF;C>YU z?2mvk8VK9@B9e(|4dNlK1nrwzQi<`YlpZxnq7inESVXWqp)=#>*3- z93(5q3ix&><*5x#+uF(pQ+_iWJuy4XPo!*dWnhDbnqTZLt-j6NpEvHyFsf{5+Bu3T zaW>52HwRK=G*whpn_{O~_M5>a0~FJ4;Z5J0N7k*pOV%bJ5UZs1?{1^U&^i>YA6$w> z%bnSMy}5`pT}&17N!gkmag}@HAeCRcj~FDRu*5xVY|6}C>_5j1o||s_2|wEH%^eDG z9vC?#xaxiW^w-41tKBCBpVZ!ioiK*7G~BzrVYQ6c@6?s$uggYKgmZR9_D>Q_S7-8! z&2w%OrxRRCo*J=DkIfgc3M_=x?-m55U-KpnvwiV;|CKh&%X4cnYQ?R)6K`L;I*K+D z;#a39OK(medz6(>NV4aB97;oAq|u;$oAwyahkL{M9^)h^qXvwzfAtx1D_6cMN0?63 zbEZxb)!J$I(;(Vd+op^TW|y`VRoR-xo_UHa%Vgdvt$ilSECm zS5scAHO$+iCq713Zm1B*oqMycRZ#ifpzLt?cd{q-j{c{xG)}g|!1ObXie?PR>X?1( z^1d{(y0cKd=>RG;gS;*KqnM1G+<@>Ea4Sgb($oE0Mo%_xkfx^lduhnjeZRmJEhBVV zPB=fg+-+7R?bNAp_p*Z6--7!tlI>(~%(&F5Itq%FwAY^UmX+MqNV-ZFWw|o~_98~6 z0eh|p0{z|u?gBEH=5vSEQ8GNVllJF3mxNBsuO+}I_Y`M*>DXg;hW36CBGfyib&_|F zHuT<5pwp} z{rsj-=kKY|KZATrFnr_GW}i3f(v6nVSGg8y`-(?S#;#d1XUTrk(TzTfi+kQCq9;0{jNZO$yHa`Qu{e@jOR-??|b|2SPCDyEw7SUI1+V>n~ffFu9?g4 zxAjk~pTNOQXY1DNN>00&-qx&ffvGlMwjIc4%LEAfvfo?BdRkCdYS)E`WOWLq?Jb(1`m$!nlHE7B_4LACS0^Q9 z-j8e?m)?I%11mH~vp&;x|Mlz4;%mHIq3cGrf1%6E<2&GFp)*%m-BQDr-s&>uIAPuj zHmvm(b%m1xK?7uuqzVI(-p{8nSNj#GBqXZcJ!_|%ipqj3t5XH&IA=AY0QNY1y$9pgjCrPPeo328TrHFjL`SI@J1^5cXC6w210hx zZ>Cr97I!a6*Ug2so+`)-m9gkF)96wLt(Wdu6mdmlD4MhDh_>6HM(*=3dw|hEDfd2* z_qYfk+hE{yo_W4=ARPm`-%K$+F=078xBDT(e&1P3ssqfO@vaw|s3WL9&@<-6HFkYA z{mYZMNW`@xHa{pM#emD{^c<_r6b-}KWZ$czbkO3HY;p1O5s!?oqE#X3@$L#ZkHz{_@(;!;*xdYDdP-Ef!>1mjw{B2(ez&4asK^OtRe zX53_=I(9q;rriuGo$uI~Vwue5np{4ZgZrKw)`DmrPBZJVeaPCVI8AoI&MDdDxb-B7 zHeeU7Z{cOwY$iOrSNzhs1-Kr=-<43ri%<^_4%X1>B*15Yo`c7`R@k>@m9{JHE20s@ z&9stJOcd*$;Ci(`#)d}oYF~lw)+ofEYuPD@8y0x6-l^;TD`~nsLdm&j@*4R}zNpf& zb@P$GeLh_R8Bn$Rw?)AC=y#$#1p>dM_F@*^igP=0wUIAgy>qs}(Jw0WF~w~l*Cz9? zC~Tf8Kz$Wyi(1?E{S@s7hT>%*zE{Vx=gc%RzwEn-W}CF!HE=9ANlSQX!iEITfke5j z3jp+7znq6)WyFo)4rZj*>E1grr*KdcCp$Xce+N))Ea8jhDUg(USjdT#k`sraJg?=w zwaBI~tGI=0y||k`?-1w_48Pau$UGW>ux*EE@Hz-hesa>Fh(wp9WckFVmQ(Zm6gPdR zDsB#vqi0R1bvmu#GL5?GUNfH$4;VA`rAa0~MGY z#iFz?^tx1?Iu1Z?`(alX8Om?p#XIJfkQSZ;c8#aU1%o*iu4RJt>}R3jp1snYnJ&7+2PSH^rS&Vs*z%iB9)THb+!Q5&vQ=C8$*`1sJ(Wl9DR&GVX9tiRbt-IL4(7n zQ-MA0%k=I6&X290U-3dD3-Dzm>`seIbUoUeza9f-wDnK1V#ix3OA7zdHJ_@%m*h(R zpK%9H?`+-N&uP%=eKv_q=P=Uwqz#!)8m7#!6K2Q2ES<8u~RPl`=L44U$XBr zy>DEN_f5EehS`WpDq_2HmP(3O`aFM9Wg=gr$!QlHAc5DXq2Kx?$0wlitGfPE1Nfcf zkNQ!mKe9RTF)h5?eNjlO1@`YRcl+IWs(E0}QSg~;S&N^H=Te-V)-s&u&+~&Wcj3E* zB(^_frMd@s5a!9q!|8+&4w#vE-`GBSv{`o3DN@;G#|hNz1=_3pa%z^H(E7s1PmfR zn&_f7r_huVo@)r1qc~q!$ikfs6L`Bu(ceXc7UA%e3VW-0xws20EZ`h26dN5qTkyCJ z>$+oyGiwp&agv4$O|lcd^}}u*Bb|OK&$KhfNO4XR?@p|}#3OK&zo$e^B*hpew`J0|^(j1ltjtq+)A6mLP##uti zD$JIiWpVDL@ni=57!)vu8H3YZwwCwU>GH_{``@u}2o*P7buw3?>+~iWbXb4(Okq!- z#VqwzQ$8h%FR*R<%P*R0#h>>vd;-nFITP|CLFdcla?%1~e?Gb)$~9u?nl%2douCd& zk#m}=EHLqUU-xtDi)UX1Thl$_epaBy;Va-Q#n-0vX>EJ)Coyi|WGzDv59>Mki!bAH z;ZbNKK5Oar%H#MN``Ze~Q=*jb0Cx`7GB$*M%!sU|@=x6^|Irt%EG1(f^8ts$%ZfHt zgtn9#LG%LcnOKuO1~nnJpb>$hQ|pS@_(zs|X}hO>It@5T*tk1b8skRY`PAPkDo!vm zmpN0b5AEIDPq!wT-S%qNl~9EeaPk?&T`RD#lxlUwXiJ2j)sMjM!F{y;! zJtetAUQPJUB&~+=BkN*@%j_@iIpQ~U zBofZ=slT^RwQ3PHf-5hblJY7FT_RC@tKH$5Ry(`Aks_JghRNpqvBS=60-n5kXMGeh z?EHvJa@heUNnf1$l9pd|M@e{}z^d6h+^YJNuaFKiBCvt6?kHd9DT8|VeuOylK`N?* zd~iPN*v0o6wVysMx1XQuHZgxh5W_~$tACC!73}nsva|QxZPf+1mgW#t9J#>pO6WSA z8e#d`PlaPu(&?8EvH?msmp)5KT5J=wAtAYG@$pXD;I&J-!uI7w1~FN+Ba`^cxsHox z2D7CrEAb(Sog$r}IqRlqR6Z4d1V6EUsW&eDJ0#OwNJJ7s8Y(JW0<&)XQCZHBt53KL z@N)LuH}UH6cP4vHk3sWIWp*YFHmw7)#r{2X4p>jW5p7>T`;pT9^p4Mcls+#M+lof< zqkg2H;_~USvem2m#mi%ru{?!)HSYB&>8&rNYeIvb;pR3!nu}+DmJX>Y2bOF)T<&Ot zTdN8)Xtw~zsQ1hD6=U1MYGRg z$pZ-@uzxO3h9JqL;Q98f{*5j# z-A8=M%4w%bVUPZk&-W23{vQz5t6&+^oT;J)D#A6*qetw+8AV#KW_N$C#$1BjNV*<68rix2Y?cI_B${|kL zt0~d3LX^*lWOS<`T1R@e&K9goM0xx_&c2%5q%VG(XtSA>o|$qqYa^%Ab4xl@Ap7dr zQusa~2y&pL3EZhCIzl-&K3Gd$SNGj6I15pM?of_O^o0|2k;_4*U|s?nU|aujnfExp z^B)ZaIpr@`+{WaI7-?Sme}-}n34cY=_(2@5anyk~j$Mr3x^Imak1r+TVZ$cu2o3A$ z!NNlmtcbAHDsHV>0pUDe?SGM~!^AoLFt6>G2;5GgVO_a^&pMEXG4SSp@%BzBV;&86`3 z1Rj^z+>$2&;vs8$UKR?LghvQ?+6YHq% z6DEWAmJ~wO3dMM{N|M*lzm@koor=h>JBcxulHN z$Z&QA3-Frfpoq;@*Q{bIjQY#)iyL9nvl9WzQ`)gv)AWfq-qrcWlY*ZH_*4u$&o6N2 z;T>h#zn2;mB70<53F%15Fjp1t3e!2*(NqujT(uU2&u%m1!Osx<2TM1 zOwLx16i1NSO%;jtF@_pfO*N{lm)k~N1_^EmQN*)oz2_<+%VyrT&>TDM#T1cHU9e(x zkJUsWr>veb9IB3Y#P+UHi7^|X3C4OIghcunD`8fetSb)9NF1ZF&u{%&MUth8#q$TX z`BSWsGyA5vVsboUMdEAiS0DU;)xp0(ba|?`w3Q@-M}U)j&}s(fM~J@E073Zq!)Mjs zxpiH|UzF-Ig+0o+o~Hp4d$*+?5^}HtUo@cyo3)a44MQ6f0s-;YSxcKC&yhwtxQ@q2 zA&R8b-x=vX&>a%em3iAiZSNN%9w|JXp{I@ZQ^B4a2%F&uE;KuuPX0{jlMT0j7`1{< zM7hC3POT%=Jz$lwfO(!N8iNqfkH5m8_EL5?EsVD_g;BgNJWVqPf0j_iSFyr1vLxkY z#i-_YlbMF=c>G=G)L+oDqP6k}dn6gzDlkEern z2E1@x8`4$j9RggSiD&nW7|mCgh(s?uoNpm8`IfGQ)kl;?gJJqZw`1AKagQlpek6F& z($n%Nhy9@8XJ2N$caQz_+0Vi1ZEh!BAyb!e_`0{3f|=W|*Y}(B$^{p6_7;n21RaHs zw^rXIqwDSI&YsS@K`@fr8fE1f^!#62ap(^CG%$#x(?~}OGGe1p=55$!4vaLs1`Y_~ z8BFK1mrIuzqjr?)ZkRtz_a=Z;g1aE^NI}-6x}08f0yBB!k7z1W@2@{er#g<%Wht*z zYFl=MMuWfX5-9sQ8wuO{|3brDJf_m+<5hUDXA2(5 zHJ%kr_jL(6+r=g4D8JQhN}-n=T$qNeyGdBep*G!cV%f>BUO^8%Cs&IX-gY6zpY8{a zP1CsGPP{x+`tU8@F{obo{;5uL-^H`M+YK%<=IfdGLJ8uZE$`h6v7N`hdZxKkOJ^D%i6ZBAix$$aP`{hqx zjumn<8M-X|6bA9o-t`<$H5YC<5T9N*{v5}M4d~J7_xDX5t?=^^9un8Z#iUnNLwoO* z(^~PK%+cA}Q=$TV>l-aR&0f?xN#05xhe8TD?nHxI=JfaR>|?=Vax%E^DCr{_6{$mxkF&Pr=aBbx z$qUAw-_(Yl^BMGSn^Jmnhxo7$rnR^~afOv?U8xIWB3X&p4P=;#%j0(VX)%f)gb$=6 zhqBP-bdk0Q+K*YyeF<#eOz}xRq1zP5#6T||?)=CZ;2zR`l>F*>e?(303t|z%FX6M2 z<8JIjeEX|>D|!WUZ5nLqe0nB=EiA1|tru8+kY`YECl35KX z=7T`hubXEAqhP;0r7K5IzEq7z*-qbR-lp;BqU`5YDBr* z)7_tCd!P9J%!oykLRnffP(ir;=&s0!LUFjY% zn3T#D!Wl*_xorx{H%LfhL+p*Ne0x8er!OYODLFd+o=xI&H08|)c#jOWLq`aG%n_rw z1^Awqj+2smM77P)^(oca)wFi!&FP9;2hx1drU%6Ve1+FZ#hrj0=3w)vQi1HoXBWwx zgVH1w3@gLa#4ptCX+-FY(~NlW1Cltt6!!b)tqZ@@qc1i1iKI{(T%yO?taFG+<~g3x z8A4I_Y(uoPBd8H{sIlr{!;u7+@mz`az#0$~jxIE2%Q{LPZh<|*8<2T0YI0GYao4N7 zudbS)L*-uxa!?-dr;>YGBDK6Gub9Z_+~@Jl!ZXS;zfusJ;>(9|d*uCzZLS1~o7nYL zfxn#67daVx$Or@A*rx%|i!XcZ$@QW(uQ2lRupcRKZ+h9H2g0q1t_k^^e~?H2h8xYK$8mwrXBTp2duD)%(lNEd zyd~P585bMhLW+a)vQ@S>7ky~px@L7nV`Zc_cBn~wYAqd#waJ&B9e5|tU4uQJ~ zQHbFT;<)L6Svul&1W4yT#=rrqO7|db-PO7rg10xGtS{0aB8;?|xas%!_e9sI1!*~_ z_ZJ>PDN%q2TeVvb7}dh$=XIls{_mchax5{;r)@b$9>8ssZGDbF5`D%f{xRmXsXp?h znE`IRrM#5PSdRrfu+9BpV;zE8_{5qB{@v07S6SXQA;56BKNTFS^_sIZD`bbCr7@Gi zr?Kyye%)$bq>5XU6 z4`T_vnXiuHBx^WLuV`Fv=WhsG-R!d<(Q;~|a)LVGU}9&=lh|jkykyc$U(OO^P^{UI zkqtRxoMtS%2;!$4c8eLhq8C6Zoi; zS`FcHsn#N1?s4jMf2=?(VP}w}2(#BJ_{EF07`WV6ePAaV++1E()RnZfv~(5_5TJJ4 zu<0yF=oY%Jn`CI0$+~)MH4@s%)Fts12q(0<9qsLP4#HM(iE}Q@IjZ$uB5+LU8Q6%2 zn#&e4nKK5pE7S2v*DDDINoAEQ0EZ2+84sqXaGt1|nr_-OC5>6anwb0)9bj)^QVgui zt_aU0H4>pQDDV9NWf!mYLcD?9;t{{rJAS7#3wKQ}dCFZOVLYksG$2g&aL=#mSh~HeK5{|^N!mV<-|Ln z?198FpKfThfFB#FcJ`B;#SzZpUI_hhHh&ZTug-ye_aS7C5B?wF;5ta?%uxck8N@Wc zqJ$)=qd_=GW=}l;W#)GqQ-Al2?+eimuhAp@V|^tvT$g#-LD9B~{o*}WpZESQaX!>q z>K511tuiy_=<yCIp!G=d0E$mm&b#Zx_;!o za6$sQCwhOiGl3`X$*N;Shs07+TzDrKe+n34W05@@ihHKf#Oo$xg2I zo8+1)ILjOs>K>v_6zLk2ULF)!X8DZ0?dQY7>er zJnSX$6J3m%-=pS6jK)&GrImk^TNwIvx+e-Q$uJHZ3nG=!i{q0wuuIrDT>JB-&F26c zu_u?$o@(7e;vla$^D%HV?=3A(eXwpKSy)(ZYu&;1jJZ}A0pbzNlzWV>i@`qlHPP|J z%O1|Rk)tOHaxu@0zO;npA?V(fD~&W@*eGYV-8X-&Jgqg_+_bs`-RWZfzpwI>~T%to_4P>&y9PluC+3kJ$Ryx(?y$IIrDb5Lx2u@4Jn zCP5uzQIe+47u4OxiU}s~Ki|2hOKNh!(cN#k$-vLTiYs2U@+>*1Y_v9ia(sS&M?V@@ zf{77+4j(1U(kdk;Hu6IL*+=HRY#%YwPe8KI#YpkT%Rmb2_TB)a4AZX%Zlake=r0#E z*kg&s=_ya1#?ullq%Y~f#}gCn{Hiky)x?Bpcu@4kNAZt4nRbZG4pL>p0EpY^nF!WUqu64>DN|hw8VZ?~Gk83* zpz16=!8!9;r5Q70mS)dBj5m+B@q7Gj{A(?9%hTmE&KaR05+2@P15?GRe!Ka*;)Cb@^_Z^GIS2l=zevO}j>IccQVzq=e18hXGD)G> zDS1d^6vr6HEvsTunVCDoHhT|mL@&qnr*SlUVLm1CDbf=6F{B0*^SIT=LX zlFxpf%|55iC%##cOq#B1^g;!<+6Qt^N=7_9WUcm+NV2S` zXYP^-rsa_*SA#pWx%gL_h+v=r|byF5s%BGzK_p`Y7Bm!VnF`{gGJaaJ+C=w2U; zUfi$YexrzSa~aEoCCE@e-n30xA-^STs^Wu4kZdn;Hry{Evy5+Af>EjQGCXq@fUr|4 zTL_C5_CEFhGSa{hnJ2p_QhR-jLAAcIKaN5H-2q&r9LwaN7B0o8%~`)UEPRoPK8d0P zyCym})PQ|TpChOcQE2QLmmK1*yjA{PX#Gx^}+65%hMz1t&;bKw7jtzZE0t>%%BL6{@|BepN|~pgSxV!G%2t~#DIxi zfRs2NEMnsIF?&H>G_vE~Z9GvawoBK_zQ^~)GPi;8-tNN(<0EGdcjQL+ zAM|#~1<46^n>kl&Gr6iFwOlnOvXS$UbiG8GJcakmX6l5|m0b>2zu+~eD;BH31d%WFGGCH`OouSJeYZucF|OtjmPcpVUQd>){v}UQB~t>amTR#cbIZhdj28 zC;~`KkM~#5EC~~x%eyZZJ0%AgBDT6^J2Bawxaj}tYxDONzYbWqwW{tXM{Rrl~&!<;;Ff-`TtXYKnNYxFcC*a)8vVko%91tP|5m{=4B}Igq+c<@cK@d&99ig+H zVlW8*h*FjUEVxft4K=E&>MF`~nLZZI>)~vAF=M}P3)RE1{}kHw+a;%jb7^e5CemqO zdEubO>SNOnbsMx@XY_-1JoJiX?Q%e})Up1fhEw#D#m}|fk*d1SkYx1E^+p%`g zHatB?TvWl&)WYEF^jC&YmT5}KF0$)u-J%FZWq6@~EXhf)Xz_~)MV9Q#M^Sa~#CUWG zZ56Zlr^O7PBmL5Ao@W=uC{u`7@%Sq%i?ZVchkH~OMDbkUSp#eX*FyJuex&s(rVY^` zWy#7?8j?9pA=8U$+t5!W^H1OCyDU-~F05G;Ydj#(2zn&&s%>-aRVbEWGYmv=seEW> zZDlp)sq39JXs+AjvR^Fh$Egfov)Q3HhdcAHi3f8WVwo{-_^;kh#m|HZ5x5*>$==x) zsi%I$)Pspl7Xw%xntq{*KZ-cBxF|;KeYK$uE0H5x_`vhKYxD?P={ha7<2#euS(;= zNBgg`t=wS@;{R0;L2U99j#ThJ)&2bi#+-wK^WSCl(xjeJ{a2-c1y&wRE%_g<1Mm4; z68zs0yW%I@urtu-p>smel<1+@-6z}N@vASC#r?&9H7HIkiL+z}{c%apsOel4zsght ztty@`P6!Rp-x4XDVzt^{lF`wjm)3ZhoJ3f9cq*SUy*FNEDU$m^+Ae`6Q3)jN;&~Yx zcqCH9yfw0nlw0(q8Qwy`ZceN( z@fPgLI<(k!drpD-^Ray#!T3o#oq)&TVqa3pp#|As-`Bk>_=X!+tR#m2nz@IGuL-86 zrBDCW_t=@t+uPEKc^Xi;y-Md-WZ$@M3(rS8=vQvH1ete$ND_j49+`a6vN8|PwT3a z+ojbMiM1N@4-}qhk)um(_|Ly0)#sni9b~YY81H_BNn<0{9iCrsl<^(h;bWcL&$y#+ zjR`!h>T;vS)$r~(!E5HaW_)nIH*A?(HDb!Pci1QBmnU6ZMb-{D#TQ#7y%x(Fcy z-w#$T=z6q*u*L@8jgrpyy6x&ij$*zN?f)7U84O@pT%J;uB*<-h&t$X^MzSg*zW>|^ zIlDCa@)=73W3V|Yo;U$aZD_W6Lk`N1RvK?d^BroGH=U-+IduhqvGfSb2NlWvzi2z_ zzc&13%@->Hio3fzl;Um$in~McQrz83p~Wfg?(PJ4*WzwL3c-W3>Gzzw_w3#Mb^n08 z68PB6Gw*rk_zo%ZI_zoq_8P}{YTvjFC0ro~m?H4y<%B2*8%;SR@_RfDyg1m z*~y3Gs8p}Zw2^$HwuRZUybbPPWrCTwO+zi7!yd+-1-Zw^A z*6QH+1m*%H6rbysX|;vtvoPO55cnX=!x>Dgz^1T$yPqg6jF!?nW%{1Xxw=pRHVSgn zqQ`hLrPFLDaFS#3gB0rdcqWL%;ZID-y$cGmKE0pk za(ePT&SvZs%7F)cTmtyAoK zpx9oi3?*6e8VR(Yh+bZfXz9tajd-+24lpq?@-va3y*eyZ=%!z)_K@vbg4e&TJgul{ zL4N&`v{%w_pBl^BD7RyFgMtn3*dxC_tDY7`DF^Ua+@*OA{m~Zz;ltEXh&~ZAKFG`d zfoQF}Grfy@e#kt7auQ=M@IIS%F&dBS8+#3n0!DZ|eQx`3Qu+Lm$D5!@ia>CKFvfBZP+<-`wMa&%m znI3AhX0oBpsr6kiq*4DC8#8mU^roQV7V>9{FWPz9Gcp1{t9RJ!--3J~O9c)@+;-il z9>0K|&J-A)bO+tlGRh}iF;hrW2m6>F&sauhSLHYPS;EG<579l+;c^gUhLdbx_S=J) z&DWak;^?6YqxFhWcv0zL+4VNJnXoX)dx5cge;GuabG)N^^ASYO@XemF{WAf9;@9_^ zzfk_QbP{LVNEd}k7H;ptBm%~_&b5FVSI*S?S_I+u1HyX%F+4iP)wQ2{`v%UP;1Vf> zA3uJq;^5(Rj=E&em6w-m(=&Lpu=jZqrtdcnh@dx?$j0oxS5>}@L(q}3Wgsz{Kbccb zw$KG9y_NA&`5*>ZpivVi$=cd~=j?1mRzGmu5@a%N{gd1yANgUS*`P>cI``?7AEzkQ z6@^t!jmVDp6cITx#Ot1FK<}28UJd>o(k_xmoQq2f$c3LBTV3dN0hUU6%SZ1F)LcTD z&9pKxRVt1N4-6Iuz{EM+q_Z4bmzu9saZwSx+vpFd+VzKkH8nN+oZU7icR{FhmEjD9 z4^CDqW6Nb-M*OCS8yC?>Y9bI{|1K*>$m(djp{@TEkD1CU(ZzaFa}*HJ}3$x2NrJSff{qt*4Ea!WewKH9fw$+|KWpMC z_>)d;-sK=&Q8lWjebFhh_va}_F@360qkQhJcEiZHbb+`fE4@(*tRqj_f0BI8DZKWO zUGbYZq_UKi2y~?6^nCa2Tp2$q+i4}{Txz-F`i}U%)8HiI2j22_+ttuG--Taxb$HLs z&Dp~R5J;rI_56G*SW=U7E(zKF(UHVL1}w3f(#I)1q)}LPRRvr8uqyO1uFo~Xl$cSK zeLtS!+)#CM?)P8?y>cGZ)uZ+SMn&x>_?MR^S?t>XzF9Akl_N1<8X0hqGt+vu7uyu* z<=MIw!l4S8<7BxPzo@P_Iftq7G23?nN5{NgUk49*Sq@+L%}odv&OpX_*B*lonVbq6 zVN}*o_#;@9#suMjPFQ{pJN?i3KsdWG*`fr$p}c|u`98nZP!ut5);)BN%8PN;Brzfb z(NB&$4rln$h(NJZFOO#%Bc;vHZhOX5okysS?J*3iJ$@S!P?w*E8*HuftyhQ-9W*+)ii#(gSWU5hp-*`Sbn#HPy3X zx$56>6XDUNBCQ+Fhta0>t4%9z%>fDJIWPW5r$wvhdfthQ}5(p>c$ccEXDsrH(G0^NJRm7SUr zoi{7GiQ#qpcJRRB;fH>idIyb!%(wzx>kc8!HlK9Dv){qP8nUQWw?XzQHx%%BPhy(1 zF|p9kFz_5I*5|>)znHszE+a5!Ew0_zsjjm9D;WCvr~|EIKP_J$&24k$K&c68nV)yM z;^I9_F9ykX?;Du;O_fX#{8gi?F>|p@zJ^Wri0<93(I@1aAEouwYVPTjNCXOI`QlLG z8#wl4#eFO|+YTpxVbj9o`nps@y4qydJ_d&_bnf-XX}oK=wj$hwD&hR2V%AVEnXvz) zOEvFt9~FHd$nwg|{=;`Ql>0=CW>mkO4DD=vP0e3X&h&(Up4x81A>*!^$=auc?P|;$eufW#r?@_Cv6(9ppkiM2g;Hwgg+1|~Um+P!Tgi7%O$fKt9aE5J-&(V)3nRjz^ zB4;~M-hC)`uIc%efjgkeNEMc^i=?;r{oFw0AT=Wi=wD2Ctug)$Z=+d~5Csll_yhUy zs^&hPCtlAIT%C7=B7+o%eV{~+6&sUvG$g;U*-$mKpOQQ>wzTKIE95P~NmKMYC|@$) zA7v|l^rM)`g49qSzqPuCV|G-nGFfmjr-s^xmm86DuQtl=n70ZxM5*i%4B~YcnC_s#s6n@Yo=|^Xn(;?mb3oV<9bEcoJHK zBahJ{)H|=>YODA40K4C8?x2%QfyVX1bu?qgj#KP)YW>x4vIf`E>>)WL>jJ>cL)Ia= zw4KI0siBu(8pNAb*H>zMyQYFmbjhpP{NEQ`r~&r7x0FBofgqPGQ3_Y|!6%G7(UGxK zA%EHNgY8_E!&>-a7!J>1jx$GRAYl_8mu6@;(G96{C?_EOU498&Bt~b~cQ@fc@F{F( zvWmS4nL9oE;{~jAAHA=G&tnF;jAaW5Lq#b4pb=&hbYa?T@1t9(r>I$wC-W5m^Rp|- zu36LkIT$2ib|*lYQqbmTm@O%CPh*Ro5g#8P;x+l@z_aklue@H^3Kp34wluV&mjq1e zA^f{Wr?r~;W7t|GQFxA#g7I8p2}U(1uBvz)c9{>H-%!1Sj<6(UIK-W^C$A+qZP5D_ zF1ri2C)QasYpY8U`XTZApoQrIIxlug>XhRu$!%V?9teCar#NPC?E1O)U6fUJQBfq% z6(LBvJ&*#!)FB7{j!m3YCTU-o{-Mlo?cq5woR7Vt%n5Ab^`|)SCo4MbAEUE-k-|NG zRtity@A!Z>$!%^=naSLxweb?IAQ@7?knq$xS{Q~9>X-H1XPmgw1(=@}UVolBAk*N~ z0tGs^g4HUuZ;$EqtzB}`6amx_Y24I>xr@
    !z$v58Q1Q6{Er$al4K*NA)oW7cxf zSeB3Plt(_#E1t%p1{&+l8S{#54Evpo+m;`Sm+|{!g<>Zc2)Yi_7tE>E*HP$TEHp-@ zY#%tt+>GsF(~IFmf3A_ZXs^AWjFOMYTICYGOHh8pZ*n9GTl z`ZZmhN1k>o#azWu5e5e~+5FtvWHB$dR|>1ijB0c8;&4Y2XOrofnFTniAOxvgI|I>W z6z53IfTt4wo2ZSNP5ZKSgZ=)(5X7Mo3(emjNm?2HNwW1Z7|nGJy$jyXrM887jq;{Ywdlr}Pjb z>~#sR_gq8S_SF8mJGFTUlC@iX@qsqz!pgHrbpsJ78KPsB(ZoAzg|fv_b1P%tkW!W+ zuTLeE)M`5Vcy#e52}qpYA(2qLAp>Wj1QP%*O$*?Iuq5Qg;v5~uIlp&Aol9Y4eV)Y5 zE203K0;k`lg!Gt^v|>%=@G^3ZV>bO0KAknt8jmDUP*%p|#LSxbJyEFMDVZ?lVmU{$3A3C<**Tv^aehttyw9ZyJ>q&fQh$#St*5Qed-Nzcb+_?#v-#;l31R+0-Cq=?RoUOg zVt?9tV^b-4%AT0vD+I3T3iYYPvRBaW6G85`2MYb4*N?!1(RB^YwF#aFq5UH2wBG_k4o* z(zDgjD^m7&C8Mi~U+eQ+^VDduF|9LvUu$Ew2iYYfc?w_%s5$mWczySrijk}`$D7`z zukSiw)|W5JryjztNM(^4>YKtMg?Kz{q4vQ}td8anfl$h{%EvQEn+Hf%7`+>0ykmKJ zUqx%jb4J3#vhRS zEmN2KMwL*+M5dLR|KCzwXl0rS2j92XwJ^6V*)>&7O*f0k@LoY!Sf5WO^aFoFOdM(+ z@hQ`?9GlIsqWR(BTQIOBO-Dqr)Kb6>qTqoFxgzkeDKU{v>(+7Dmz7 z83*ak$f%FjMG-Isi}KtxkTK^Tqr0b0o=Bd8*quiVyVeqk0k^5I6Vp3NN=XSFkN2DO zz01M4`L@X!o$O&ShVD^l$;WFUUyxWdrp~qsm1b39Fs|jie!)~cEqQ@zHM@~%s@*Cr z^mB6wqzROsY=g77gKm&Q$MK;(A=zYDv@B{=-nhH^J?!0hPVDyvivHEIcV36u^t3)Y7DT|N-8h5&O+KhFX%(>TIe6? z_`Wd+N`Gu~>!-^{1cYB<@jvI`GZ>hGZJ5mUWtmQysFehc3KpOprzHK)XqVe(7EQL{ zC|-|iDh!GB96w4KB%YQ!o{lnCRJXiDJ{;2)*%|6`J00GAyPDkhQ!4RRd~gYEP%nty zTky|$^y@5b>^gb4vBQ7CaryGJAr|x<0{c9`sOv9ntf#UW?SqD3omUg%$LO`xXHLO> z8HIGaYt+YGUMg#3xo3b0o3{IGJ@Qwq;QT~(GwU#{HFODgGMX8;J6M{lJO8@7+UV_C z@@j|!8ft^Kv7m|A^NCJ-1UwuSf9?;uIjbi4(Ei)L{p**Bu+I| ze`$E{Q;xPvs_A1VT`f`d&z;5(MkEHR<^6njShC(9L*qyM8NSG=m8O*?w0TkP} z^WudSqA9l9;9Go6eIE}nE_0_|XvD+|qCtO1t0;qqWA7t*j1*NfInei>nA6R3w@f19 zJ?jY1!2a^ty1oZ0yo`0F%$3X8{(|80A8ozH{a3WP0yI^e-{U8!1rQzK1ShE)^@kUn3iBgwPNCxB=}Tm3$L0YMe{AGr1T>*%%VmiC|ClVoU-j5n+x** zUZLRA5uquBp_|o5~EJz4UQkZeKl_Kx9<%VrZ2s;DSaS|0>pe5RhW~CxAF}IF(w1Jg! zq)}LgWs1-qk8qxgKvuPIY&;ZLN5EBEyqOL86m|d3snz9c{YkUU!*}wn31CN3v*bs< zY2g=M-=~B#r=fFDe(M@k+RKd5?FlHJ&+a>5#p{^bf7!kwl{xG1SVGlZlc5+Yt^weLp;le1(3{xH%d=U;zA<5yG z_KhO~=9y&NWAH!2Cx&5gkR6`2Cy@-W1a^~65nPf0ykQH~#o{~p8cIoQa}+ANnG(V)-Rk;qq_eUYT{*)hJypJbZc3ZQz= z=e^;$XSkp@Spn4OFlRAr`nD1+_oZhwQ{a~gJ@ulw zeHTb0vQ1%R!s1BuQgL3s371`Mt6M&>76Lq+RtKSBTMbK*b0Ts&y!vMs9t7vEe$ zdIZ$)eGeCc+#Sx~fDBWS!hDjKc>tJue^EWX)f;HZ1wBC_>(m@UkG^3PN&POgcJA#f zv%4VrSJU6;wN~n!UDuGIBuNHl`?u_jYL#XXub^48xxtUBCUfZSOubBWGZ*dfEcdna(1 zPlmWQrO$SKAB^qIYUq3e)k0p=197c)>?bx~pEneeNsgD{c};Z;x+fNNyeu;K>`i0H zKsuvezPMo{b`EiSQgO`IBf;qb6KjS8^}%HGTN}*(E;#7pLQBUOEDf)Jck&B#C;z?# zCC4!bgWeme@sVs008FxhTL2U5b`#J zULr9#ekxp7WutUS6Wch>U*#WNXrEt12?x(AB-sZxKG>cK0BK)hb68ZrjLf3SV8V)v zo|UU=>=^Ro_Fwo56S@cV4W9sF z!=xt5QJ}?NB+@9OKb;`2Ss9K#hlqGSRU_%x2o$FE5kc9dR5a-V(;lCwUmYzzDxrfz zzPf~(>&dpAGO-5tP=wCYk&yUf=Xf%1iGYef($r1;gi?LunubSA5(0)Vt{&yn-bs(p z$&IWJCys{e)goQHyipRIP%#PQJXUvhE+~7=Q2A=4&W~fn?+?M3uxCJ>DKf>#Y)tOe znr|k~$?#k?1S>2vN@OtbMq|O(jq+sq#iC2xho9@m4)x~$W*=2-+f&4h2Sd~EMnm>I zzy-(Gr9a(Br0#wpMDNq0W**?$?XLnkxs{Eu+sj*8S`hR$;qJSYMd zOYk4{&7WzH@#YJVJ5uvR?c*Zrrnd$T+WZ?&>;twBoX|X~GKoe%_vRQcxTUjmF79cD zTH8@0K|2b{{3in1ZDtK**eh1eHN7OidbP6hc|bmt89qbi;v8S=j(hyIDE%qIha=9s zjg5?S=j%ZZt*MIYQCKA>>~}lzKCFfhe*)UcrW@UupP4qVyz!MuNlgxYyxDct&iHvh zM-kcOZi+Fmiu~E|>k^|lXYXu^c0#d1g=||0NX+508aux_4si{bR8+ia=(SK({;y)E zJ|a|FA{li_eZ3^HY<%?6ei4clE#^9tNr_b5V<7q5M?#RMfC-zMp8bom3MHwaws#iV zrCXyC`WYq?qgqu{Li(ee91?uh&}B!(>lyS2g;2*0*Aa4C`ljWG&~V9Cp^yKdD{}Pv z7CMz52W7HL_KUUVWu&8)6SbEvjjCyX#C*)(VCXTypd3BEYCk9OK(tq|BtwBos`cD} zH~pxo`J44<0TWRQClG@mH#;Fb$hkK^mEh2t_m@5}`Wlxqp@fPEc49kbO$H@@p@Hv( zN#x~LX2`M#&{NO-HXozql}%hXHi8O14iES(Ev1+^H7vUs3LR3ABd?Sg6W~`5i8B!l zZ2zUQS>pbovgv*bghu-3UQ_gs)5(5pa-&Fm7~@>cnNF2xrA`o-wbQogdB#|I>#VRw zjvN+w&;2ligfZJV*Z0lU`b4D1hM4`_Hds#C(1 zR*V*+(gNoSmDd9|Lnp&IX`^#Sm!-(}0zWy~&-$!HEM`A5AFraes)a&5r}hfY?E?|# z4Mi++_;|qfO!q}5phmyPcmMU&wgV(#-eiaYYw-RCJaMbsL)L zZbI^glcA-%rx6}>^n11qPIY;sLzz(l16iuEuFQ*ra9sOdLP_cvj7kS`X!f8fuxToP zG)B+Jdr-1)Ew5V=u=yS>d~h;!IvC|S0%^1erNWA1DG&kwt-!J#t2vy2=%>0^9n8@< zvOS&Le_x=zx8Fzbzejij_jXo;so?QSYd%52C_!>@2@9Q02^H8Lh1ph<$gQmIrFX%x zuJCP0*|K>;5lYs(TJ^uwYKDtHwp5i&(Pz_)S;S*Fka&6aKtq*@;R9bJu~n{M4}e~$ zq6(gvlZ7*(n28{r`0_z%z6Fy6Mc{kWa?G~lG&uIN zWQMo{ki$eyY<%i3ez$;}YSQqFUNuB|A_s99)hh!sTnyUH+1~6o-tVV{+(trYJHvJ`F{bO;)H!_y@XwW((`7KikZgwO0I##BS z4VEEIPg`3AXKP8FuYB;V>QDp3=lRz#+(_zH6?A+-ftbGg!w-Cnks|EG- zcv19c`D(*Pop@#mri#2vYKMKUn~3RF_6+QotjTAp_w&5{-ZL%yl8F`*QT{bIZp`#( zm5{1KA^VMsr2Ph@N*f8q{n&CprP^Tp(57*PEQIOcM|Vj!Ni)J~W?$oLhuSIBL1gMR zH_wqfnQgfQAfv>k+`ks%TgqRtRdib1O;QZCKkUR z?pTOw`D?UkTHoyau9uEs6dUG{#KHq75h7}_6hmiw-UT~ubn^2VH>DJWrTo0vw%AYV zAo(|Kv6Ls4gWMwpjKzF|3>%sdIiU{NrlETT;MdVYFJFOBC+tw|)WKY(G!|jW9QdGa z+^gJIgSQ)W z9N=1QB|K59U2L$Dbt$ z2=^ky+QSy%dh^zf&wNrc{Nd!I)Yp6>{NW&OHyjL?T0!ZGEIn zkqmaBCJ3t`+a5)CtH!-J1{Zw`J)~Z*14M-T!y?OICm}`>FqG!3>HSGs{<8pM!M>qQ zqY?huj2%KC^qTG|p68>(FDV{aLKVA70K?zOjKihkN?H0OXrFns5Y*M&vKh5@wVfHB z2Rqdl3JA9leu$7A96ecgSx72L=>_1! zglXAr(%?m$6_j_f=WGKgl5GPvM}M>JRSflztJg8%!(9Kz7Gjasin3Ri+EpmXoi&kp zR^tiU3%cO9n0s509s>u*@XqKQE^~N?O4Fp!m6jha$~BSe9u_)A@-=*9ZRl?g*FO>0$bj)_vFft@eU`&w@q&8@ZFcx}($s}5#+$d4^9;pkFr&sqe zw-Fz;j-(_YI6}bLSq!62NNL%%cFkqGO?hJLj&92AUunUE{*xq8pGyz)7G@z#1)6`B z`AINSpOZt8_Nr$4yb2BAU3+m0;VRtXc43^OrLtg+$aeO9JxF00x;*#q%k!zy40w6A zGbT=_Bb%WuSV6&i6rSXEc# z^Wh+SbtL3?a#Hqjl~O0T|NeYvAOhA_PrQKdh+>X_<%pHf4&qsc_($eA3%2m zlhjbhQPLg>Ak&qs`N8l0etgG5)PaUyECWd)w;;ys97elEuIq z@kJkI!hj_vL-Qox-ZwSJIVTA!?q9c06~>xJ&1J!%+#tzn`6ee-RjjCf?w|9@v7xH+ zWryr^hcV>w9TL@T-+aXwQbfkJ4@d0a( z`k~TFU%*Xld39a=ErfWKZ!JKl!l!S7;_GjR;W-CVGR>sz9$rWWJT*YH8t&8!6_DcSihV{<9 z(`*~8()(9~SGVOYRqH1TaYj?G!mUovgPYZ@R(B6630uoz5z$vUq;9muQF|dtYt4UK zW#4RYGsLq~fL@4V5fQ%;cs6q@aVbT6ECn!JfML!!B^)QF^0F**3!9(?*Z&3!mxCkN z;7x1%$$#Y-D0>89aAr#+eBzrnq-Y4?Ds)FNLaFF7Mff2Uu}x#~{PD)?t1X3g5K<#d z&pb2|Ci>OglJa0|q?2_JD*SaHI%483Xfh zEr!GA*fiP+RYahKBYIAQ4?7O>@4fT+<)-q}csA{cqaxg!g%!Qp<0cQ0on$_vYg`73 zD99ws7r;$ZV}e}Qk7+f!UR0b}|19}bSXu&{AOd!=`{eR9lT{Md%G1dI09J<)YCNT* zVRcM~1{3hr5Epi8GQRx=jni0Q*LH}Yhe`;n8H#=IjN8Ue-yqubzT40sw-7IfyrWF(br0KhRXm z=)^Cr;p&$~vqp{F1ObIH-CLQ?S*=`vp8K;MteT9^;-8@zab+^ zdR!(>LLV8Jvg;hf;mTQFsro+v;H4z7yHa7VRE20a8?M+pvU^61-s5^ajiDs?$oDiR z;kdVdA1#S8=y$M($L?x=C*kHtPhEgLKk>*_d8APCVo-?9mH~?EiA>Bwi!uKbIQ|;d z9kL}hWks_tqlCFP&6)LU3EiQf$|j3AqWx7W+DeWKN@i1{{NJSV*YcnH7HZ#yCh6%- zeH2P0hhhJ7=l<&vU1b^;4a(4MFZ5b#1`2&r;sPhBgvTQMf2Wp1#}6mRV1PnH3N&Gt zVZe<3F+^a+@TLV1vzDCPi0ZcET|m9%#7bh%#lKid;30{jy0X#A<>v&KYJIFz&!az8 z`4cS48Yy5q`qd_n6i&0jc=!S5<*LgIp{BRfMvmylp#5^37&(a1erL4kZ0FN8`j4t` z8^U#oh!r`PSo`c0G{F{7I0^IkST9L-7p+O$TtK6{DEsbyab-=kLBf>#?@*k__3pHQGW1L?L0=o{BuZc7*3v!(b15~#e{ zEH5kjR_JvXw+Eo7_)a)J_JhBT1#sKee!}O$Ckx{FGd8ATL*&^;^!^Y<50Eo3w-6-pYP)>GKc8z(oLDrm3!I?}K=wlE=oYko^qt1TJ zqABD)!bWXj_ary^UjIsAw6S0K>ydxEz5W;M{aQHi`u1zc8;F-j2R1*XC&z4|YjE_n1+jBFx|V?p7xG8O_xVK81L^f%e+NTZ(oddp->{xBbkHCZajh7smIVI z*N95JW%VayO-tRlrtVR<jNjE8FYY^iZCD7<)?-YCk0>&VaYNyD9GQ)e37-?bewVXuFZI3`Z*-1g`<+?5l z9b*G1lZVJqNbwz+Q{Bbyg&@bABY3OWi(fd*Y}T&*zLa`1GB(os*`X!-c6Tb9j-P5> z1oP=FWhC?Xn>85s^T)VmprJU%R*}g#PA$LJHkseB&ow86lBFjv~gDs$q8uFW* zL9P8hnO^Y&YFVq@u5JMe-GWJDDbc< zk*;eEHglO!W6iBsTE5>%Ko-Zv3&>(C@RQL>f*GxFrZf4@5sx#A0u%z9eXnV>zppa| zf|fT-H~?dtbN6HtcTcFBn+Qvj*`|$;{QQt3Q7HEsKB2(Pn|W55xhSOmquP+C5;o!mXj$6uy4tUt{=v7Ra?`567! zC`R@c46~KXJURL#0^ZC&!F<47|MaTlEY;>mMeq@IgDqUmH%PRX-%vI>&_Vwi0MmZy zjK@B=HU(V_+AAt&oc0Sb2C=ntS+0LAhv+>!BlIsH9?VQWNe<( z)i=f;mhFo{9I^o$CA^(0T8)Gy!(}wMsHw8v+^5>yHp9jCq`NKA@{kn%v_I*MP5vs8 zgF5n2(f0J-ZI-48S+X#aTaL-AxbVoPBnSA-Vg2G|f!eFC{4Lvb2F~d$*i%Xoi zg%bx^lV*whhQ?{}x%Y*j8cA@{z00rmarf6nIovj_iPJkKFbf^Vyf#mgwm;~UqZVjtD9@XTbr1-8SLwun*RX`5R^Wo~7FG^RkBzB9T% z9EQ?SfbSY|_Eq!lEWQ)rT-7A4teExExW#9)Gpkv8@5&}qyFZ$t{heL7z%{XGS!h4h z4TWqVO3x`)_T}3ye>e!-)LV%y&u!==ks&x!VSdWOFT07{Vtl9$KhX9+!>(T>2=C7tG2c}k?60z%sGRLACJ&q>%+V$YaM@@xCP;Us zb7it+uP-oJ`|(wx`)V}L9dziIYMnd(rDL=8d;BO?9_Uu91g(GfbBow4@j3w2u9FWD zS4C|1`_;)kVmRgBmQ_}sXk}~v)e#iJqu0psrs?DyF%G*Z5N@i(7wpOg8IS1Q#7fuAzIn|uUQg%&nK zP&*bIOj6qsx8!Wrgs-Zn=Z&RIg*j=uXl~)fWkX#6wNAR}cReBG{46yU91P0<>)VIl z^7XqEm6cMqh;Y^~PP#e_kQ^Dlv^r#_WEgBp#GbIdQqHP-e7K z>Q5ufhe?xtF7iYIJ^r`*Dd$9(dJm!FA|-xg;I z=pjaRdv!P`@)4FeNfUYBR5($!c0uP4$4r*kT4xfw(=yZ z&o|D*r4>ylyW(~!MmwH)cNz_3YxpUl1-xG$bQ?rgR6oVWob$)ymw51`czq|aKoC=a z?EC=#W_v9ttIo4o-ZS-d`7Jv39^iZf{~8G+(nx8t{hh*>Jvwn6PU^3Njq7qzF3#P0 zFRYN3B?FQg^KY8&s^jk0aO%S&wwTOoLsuc^t@T!`0!BtV`=1~#P=-l&uc1_$bCE!y z%uTH49^LsYISyN7JE`YNIvZHuPKx{;CQ+}1js3K~f_i;(C2=HgbK<+-Xf_{5DpkQQ zw5rQqa}EAgxM|40-CzZLgt{qhj4k>&U3=zXYbz39O;pYm`fjlLTD=mi+n=~-DUuMO zdTF}ut9LM{)Tl73r5zgcV1MiPzSy@g*pWp!UNy}K^zeo74dqwfPW8A^Y$B|%u1u1h zeR#U@{yT@Wo-X4;D zLNB&vXsH1I?@PTCFf<+#FjrfYqtBdgGaHHyVHqpi5rAA~U>gjtf-Xbv!r_YpyQe*W z8AvF>I>DW7%VkaR##29qB7R#W624tZSTi#Vf~GQ@s}%5ltUWh`6y~ARXt2|7_#iodHKRUH5&<79gwlcAJE^lF=Ci8BLrvxLi_m4eGS@ z*5`^4`ogNUX?s=Y2jeARGnEBFuYBS|*0aoL`E7yf?W>mo;Q$0429lk88?;3)z3~N& zN2zKPLCo_4uE-3HsuP?qhbT`%wL2FS?D+Bd4R`9n&MMXv)xk(`aMdIXraQVxAv!(J zH+$E1Bf~P7@k*r3Bee{WU~s5a1elN*76WqBgEM)$-_3F|yx07}qVCXRjt**4C=@w3 zMX01FJidst1*x{rT&W3{OWCk^QXjH5JpZC$UBQFni*jfAT!M?bqj;JDmehvgps8Cf z#~CpSPuAjof;gecbbS6W7UXsH%=|L=7eIpZ^a%G?iKH8XfAax#U)yac|2xYI{Zm3Y zw1el=22scr!dg-CSmefJ;K?m0C?HpssivdvSLgEIi#Wjkhm*p;38x?U3pfb9CqjC? z114)}YOZ60IN(BE7TZsT5)uUa0v`cC>}Xg~RYIV~wxuwpE3D9N^;m+%pcY6CN0{b` z=M-mcB*vTVQ?vQ4^&mN+&~A)R-THl@w81yqU28W`I3g@efTfFa_;N!M@X@pf+;}ea z^)+lSv^vna0G(fA28w<8CYx}456!8SNKx+FKF{U09KV|Z5yaZG%U=uS0lPN+5pnTw z(o_W4Z}sXt&-1c`9aIbQ@G@44et5I%3o#Eo5Wx)S6)R;@(40-!@iFCZxx3O-6QO55 z)TjXcST{rd+C+~_zw;5o7J};^4zKwlhCD>mp*^7{E3MJ(L6y61ON5K+0<{8mW_4iB zX)hJ;0R&O>ZstUZe~b{6kd1flxRi^9U&xE z7v1o>*7T&uytM`D^8#t7(J*%`xcI@M51K;lIqhI5I5zr;p=@l;RxJAsdOgpXzh|VF z?vt^z(7+q4>D|=4fXA5*`_MD-O@wMd^`idYvdyuWkAxtnggy<8>=GB?XVfgr290*> zjTP1k{i?w1bERVVCuo2)qJcSUv9^X)dDvNk6`EINuxm^AuJSqw+BX6gvA74RY`?hZ zPFNA67Yu!q^>qjN8qD@Sx_dTKW$U!3oVIwJvlxm&i~iGA6z{Vxyzk4~zq(JM<j1O>v()NI-(^uzy$jkNw2wa8HQ?yiT8tHA>l06Dsj zD5Xx^fg zR_%H^2!=W%8+l<|N>2aqC82NGL^>JtLF*lH8OFTa&GZ{`iT-eAm0vQfff;W@YVDvL z7a)S*8I>WiOdmH{6zv*dt4{aBpVtoRO%BCrJA$f9d+TjxN0B6}&XXd(%^{cY{<_r0 zx_~SV^Hzb6M{^w>|4iPum@}WOqN)IQgpK~L@>|uA;&5XIUI+UEgDAMGBESgbG0tLD z5#>K08d)B8>UB3jkBEXA8D8+$kBtY@*Wy3HW!BcbG z3F?TQbLR=IVs+a`wxclz_yzzN}d+5~H)y}-OHUa#R<&IzmPdE)a2HTH?rym?hkC-LgxAD3r>m3BZ#63gR zpk0qIt_BB@v91sOpKvJs<)whKg?nJuEvP4HF7K}=PBO&OHx^hzawfk~O_>vmg7_?; zmP-v*IvKm2XjhCVFb7aG%}=|Q-bN?LNR2FEeDW?GK z9&#_wv$rc(q_WA0urEvHD2Ui z9>BlpxQH2jl=hXKg5aIEFnWA1_M(H=PFmzFjgjBuIe+w4#05&O3IkCm{B~>a_Ia6Z zy(O!HGauF-5dv)0wthC(UqC@Crdt3MJULqpjN?6nHAbr-r_evmotfrs2(y+DqqIQ zNRlAGYPTVvS42gVDN@^tONr4rRIesmQuK^H4@RJ@-@4$8%JAV%W2sILhfR#GN_V37D^dixkTuF^1%` zrtu}jrDoQUeL!9-(eX7zt2kB=>{g{#4Nc?F0liUBp$IWhh zo1Z-SZsu>ny9$RAl7RjBD+2q=vvUYYt4g9{pXa@n$Yu=7s=eNSrjCMI=d)pI+2eTX zk}ZW9{X1-|tkP;GGp-*qE;>~!{SO*ey+vV?sM2@}Qh`h!(H1QnS%IQb# z-s)=<AdeU1~QhZA7~;~l5537j7nw|g3;6qE5%hqFtHekXFu9le5C#8uKe01JT;Q#*U=8< z2d&pHXFU5`zSLW73WNM6JHh-`5Wg4Dy=W!dGY@Afp!9=Wy^MjiQ9@(Yoz$=}TN=Qy zNKa_J&;P{3e>MEJ_O~=?RZQC?mye5KV$DZ7iT74LJ#0i1IcTe+SF|s}G!9VygIECc zUd+K^r!R5ogrZRCQ+v?_uKcga*`CWtP`!U%iT z;^`XncZt6zpG36P3SDKfO|4;iS^jQwm0ic~Q=ad2^Qz_Wa;AvQ921dX$o^`qrYPD8 zwIE=|j3;q(llaVE#2^!;@xo_+y;!g1Tk2L{g1Azvc1-8#IwEsPQM#b1u0hp9Po{T{;K2@ZEGu zgi-`VG(trD{Q~SS@_w9A_zyYcpK7wnWCk5~YseZ19XSiNIs*faKi=g!P!~lr{OWk< z26x@6W=0)1ERgQx8_t(H`3!&yW~q-&RAo;n+0sB_V-Sxk0gGRdma+$b76P6xQhesa zdlc4(hl%K!`owMLqg%T%4qw088NSpR4UrCIeXhP6P^Bs`(-`)%v$NBV(Ffj*l(81h z%Vt{Sg{BHe<85+)y=sXU4(sadmSuAb6M z{;Bn|JN2^@ocK{YtGBT+;p)T>tXS;tFr24k5HNOYZl6>=1^crxB{z(B;R)dJ@P41n z-s~J(9wP)`^_~j9^W#XJzA>Xw=pWAfo!O!pj+0go$<<@Scs2rr7A8P zMT3bE@E%JnU+Del@jb^ok=mixmjlgR9@$ud05=go|8?xp?OoH)4T+V40`GzqPrZHr z#4pP-5*B+&SuWd(^60G#W5Sa)h)_n$gPO$|?lJ*26opgGEyC>V%m-{(5X}YYbmRmCFnNW>fr;Fs<8SyFkqP zr;q=PauxGG{`uo%g15bj8l_AXG5gH&cc#yA&B|;5=|>H3nS6n3FAehP^54E%;b#Rn zQ-#se>qlAcu*Pl_$I&~?%w*cFtR3{nv3C{&Q%om+%DElfqul2i1v_>)-7LveEDYUl zMO61MSdRIt$~wtW+=+a`g`Ibb7ts1r+gbEHw-p4>?lQ$uL(U$rWBTxk@LAV{7ag{F zczB+fz6&55+uGPHmu7*=#52c-9{00kko#vz+6S6u)>PwdUehYSoeBGVGc=sX)H%ZQ z$>L_waJSV-&TeQAS>Gp>L)To~{9nagdpy&7AMc`5r;?LuE4fteGUp_bTT}=&cVTWP z5}HfO%&nn9TB9&Khg=#xWBs-jn$kPM%HwcHx033g&6vB{@=< zbohNRjs#QF2`R69`^9D)T9gIM%_%W@KE-XA_%t1W05y>wvB_|H}_^fCFMfyvA`re3#h1!mOiUp2s80Woz@9XiK zNKI8ujl08xUu(vEo{?P$iFgP(pv?ak9N2rbU=(3m;~YOHo8PfJMwgn#yYCL8ec3SJ z5<=jO^#jcDbltgsJ9yuTjNiN&LYJPnCx`RI5H^@Oi9e3qD$duWTd$SsbCPDmSoP#Z zmt)2WBAK=naWkyA5_`(_+Vw@(nT!?v)Zq@F3k{k&ycBtCs-v}ck%`wEx)-tD=85rS zYwU@n^9k730EW)EhRVopB9UF(y<~*<^y(ZL+FXVIGo*-@mlz3o9SG0;qDhFG!tpM@ zE)XGQZ*X(+94*bHz~6CDo53<=(8-{W>oT!18GvB@#%^nIxdWeXq_u?ye0W5i z{hUvA|HdSk&E5%RJ69+aXMFbMJc`nNIMsfZSn55{KbQogfFdm2pkn<|rQZu>i{$BR zGrVL3!XGnu>)G?<4uhqi$1q>$2h>n@3{!dpClUev1rN1`M~MYD`lR;5Iez&Mw3C=^ z+bhczc2fe^A_E#^j#^s!Xd&-p_b=m2OTy`eC2R;T`zGyBPBLg{=AJ&RyeP#EuEEA} z$4Oks{WZN3bcw2I-ynx^2;z`TerKPbwg{GYNj0=Ot)h!&(^mE&aHF}vv*C^K0ZF7#qZvEfvKXMMa#7k z52`cuxEG(s*Xg@Bmj*@j7e=Wzgs>u*#gGZ!&2l%mG0%G{?tpym`IN9FD|H)rvWDXG z8R<`QT_%Nn?$)*m_zO*jRlpg+Qwf0 zKgYXY9ohbbv49_fVj&Eg!{flPAw~g~L?sGqgOb+zn>O*gZ|Jwne~q z?K&)y{%tZCNM)Wik6tISth-^GM^F;l3T18(Ua<8OU)T%SPpoLDkdc7ZR9*rhetaOr z0EjZ07V~eCJ-Vgd&VOR4h^BN(%8H(<_N@q|uPI3p7aTP%Me0T&+!I-A3{tFDo2x z8=e-JsR*5kjs!T)_?2~GJ3OpWONx+zQ4F{8^84s5S$4ZKLYAuDLt)LT9PL(J*+v9y$H_VFR+PlfK&C&55;aX6FJs@!G3U}+K~ zu_V`hl}97`j45eis?AQSlu{7i-s+$*NjX=I*DFe?_m7_ zM0NGh8CmO*GDyXZ1vCJ7qWu^CIJ>sp8vCPU@zutaLGuF3x12MVcjwDNomc0mx`8_BGmJ$=vh z@0T?Ug}{WIY|aTe^!d?6G$$tQpl)+&D literal 85502 zcmeFZbyQp1);Em11a~il;#S;3&;o%9P>L0TwzxaRLvbitB$T#L+zJ#g1h)Xi-QC?k z&OPVebMJk=F`n`M`;Iqb?Xg$ZUOUUC{MMXvB|;mbLWoC;hk}AasHUp;90dhS0R;v1 z8xGb(2{Sk~^5KT+@?1p$rDTA9^Wlr9#Va*SO-&T8hdd4n1}ZHI<{woaE?HE%f8~`? zIZ@F6QI3X!5^jTn@mC$l!~Ks>+{5+9JAdBM;i!Mrekg;Z{j)Zf0v!FHdDL%zyvF)c zEaKsY>!_;lf`UTH{>O!?_WaQS3W_|6nj+|>C+bcnb}G5jRY$#*Xc;pnJi&vcPhU|W z5*36=j?Tr1hqryHC(r(UeRD1oR{&MHr6@1YgX)$3>+X$ofttZj%?{)6^xBu$X-x+xwk~p?61ziP3BmmUE zQhq%KL~+~>*hTwgKF~kRQeNoqXMa}`##q7mE}Zw2f#RR<0RrE$Ii~*r$`Y@jSf?fB z<^Mei|7!3z4Vv%5*Z&;5EgEkT0u8Z>QU0Fu-~0Zjz|gM2=^}YvGmzrV{r9*=L2rQX zfZ6&^N=pB865#_R$l~&?na3}}|9%R`35-$Aq*ytiVT1-rW&%%{xtT7m_kP4?To|cIt~ZH zp4i|7{Q&aRI@he<;cKtGc5wKchSH2EWos?2j+X3@O6n={jG_` z|JDsCM$b0R;>wk}_UDLmM`Qbq{GKhGpY2m6z$H5HbuRiEHq5X5y!$3ln{_)@<3-9K z+u{eZL{ZZTlIIlHB?NM=lQnDyS5oFzMz6&w+~_V0C@by7K?3vou;U$Ho(0%nIw>#6(#it*L)V6*o$Yv9p zP33|)Y;WUDduW_byG{WP@kw0;&hs+?N|%SadS^~Efv%qoOr9v;nHBP70{;)%7FS4^ z_5XnFDqybl`}mS9qKGF3a7{z)%b})rrbnKB%lA~fyjf?!JB2+TYNfN<#NDnL=D#De*8BJ{ka;`lz~-t+m3Eke=km%4%tLV?kB@m5S29I5zg!i3UH|Ome@}5WA)0fDE#11l%D?sKQHzzUW2;OiZYu8b_loA> zG}CFvM_1D{IpWzKraVEaQr0=$oK~+AulXXF98fK%vE5lg8Qjnkt+>IWbm@*qM$fDN zTjKmSdT8hOtV>>2wCV^;!bK=tgnBI%0DG490Dl2cN~o%D*U&eF{3?!EU2F> zs1#DEygF1I675>qEz6&W_lQH@I?r!-KVNPe9Dn^av)J?V220nC7W^^Jc{$hU7__P` z|9_~1-wb~&PNvJ(tPg)LjvY%2c388u8wb@PpROeRYTIu*jJGgHqgn-t;XejY*KHMq z%rT8BlpYX{z+lk}rd|ebHw)s1{5Yj9=HvRG`r|)gXA-a_i9F+x7;O@~`%c$OpAmT* zt`i_1be4LWth4^W)U^fBN*xExrDC$*tGu6Uy%F?nda!W*j!lQ8lvV17L9!o8k#>TB zVG_C}00~-W7O1qlWXz7YutpBWw}Uhzql1ka6f3 zI2_~y>Zp}m^G+;leIP1bs{hgAzZCLc9q1v+@8?$VG&h`Wi3q4f`}`$t5Ni{az2dG* zWIEYG31>S0yI)EbYsl@>@V&6U!;j?t(cXra@qN$TqPnN_n*Q$epK9Ee45M$2x&X^q5ES@|X1h2Q-u1K1CjcU$>FF9?!)nS~a&sG*oJo{%+A!{{U} zXWG{->ck(VH?K8N7)9pZmN0z8yu~PP`SgyWwa#2~NMF*`rWc-l`-;Ka=81Dp`LjX0 z1xb6`?g?UJ_YCPKev2n6Z-{r&_i?krV_(b=mMpy^o}7oyGdEUeT$|oQf-WSoFJz|K z`HOY(h#1(6SHuT<1Kci)^BP|y`&IR2>dF;A+i?Yw&QfKQNTIdeq*i)ueGs(bsi~f) zaGFa#Vu#I3{0)@=Wz7qiivPi)Twpkf=s=i&#PSL_TcKHg)bt-L8GG9kSRG!Q9-=|5 zi+H;&sJrs2N|O{Vb!J6z=N5U0TY!uh!wp#KkrMm_|Ngf7bMVK)4(m4f^dIBwrjgk7 z>2&k3Ou-i|cBs3Fku5B&zlL7-Pe^cc7wh#!{v;6#VtM|%Ayx0uVB~7E+{&*#CW_)8 zqhCDnl_4i1wP>;PPrjFdkzIx`yD4`{()()R9zX!4oA(2F3IB8@MVvXzBiy#Tt<2h4 zpXDLap#x4{JTa=U)fl|iXf!uH>$0&j!&~&&g7)rv=(E6%Ta1DvqJU8-Ze3nai1t zwVU*XvGnNMoci;Uo#Hro%iF2}@gg56|IeN&AcBZI2&7?&Mnddz60x*grGBNzoXFh?g-gsShInddyWpo{^6sMNPz` z6rUnu*03^H?Si&4+6_JOw>=4GQgdXQM)MF9f#B4 znjS9}RhMB83QLgJ_9wHjR=>DO+$%J1s#DTAKAh9^9BH!8cJhr(@#82?#$V^?(_B?Yy7KY7+vNw-$=am-)+geERxv%q2kqJ^6Zlk&XvmxhNY*q z(b_`7(2&)&pWO73*kBVdcTSnpr~6lFz9t_n>BsBkeFOZk{V>Y|B5W~*zE?-!^~QBJ zdEjm}hB`t3=fOe!bOW$27VeJzv3wIVjxUZ0fY3mrEY3OKgAu7k24Ny@kfRtcLVEwh zSEx^nuv^OVt^Jbj1}8d?(`Wre0Q8(XV3#~`@6sxE1duGkuJ zyfM8Au86$OS&H&A!wCaIJiq$0_7_;U{sv7P3 z9v)UxtV+AS-rEbzem|;UJo9)_ay#`7I$~ME{4~@eC!UTaOC_c-g$1(Z@?=;Dwru6o zRlOX);PobR9DdF$NjDcl6lg02h;U%>@y+;fVh!sQ!CvyrBP)|pgsQ4UH4-k*@TEnA zP_Lh^V|QXte9{&V-0p}rAD$o|F&5X{i4Jo zy0OFSt2u0snkKWJl7WT8GpVkK@G6pbZ-g?qubG0wz`m64?WS3V z=+dLT-6t8{*%9Y~j6t3h&4Q40iHp&WM+-_`@w4KbGHrKu01KBT|HPB zZm%86OX}w0h$<_Q{MJ}R*VFi>#}h>q@@P;iXYuXbnN7*H@A5M(2a{C&7#v1_2vC2< zY%aOB8(b^Pc1DNQVQFc)lu`4dF#nc;&KNy3HoFrs$}*`9?4sTt{`I4@8%JSG4$YfR zlC2Bc;#3ovWO?Fc{m7i&?-@cXtR&pJ9{Y%FOugdpC6ikiXy@I;D^a{UF2lGW*L8Tb z4$h|M_lBQ$cr?ungQrA+BX)KoY*#uE-TEK37Gu9Nm8EB;yCID7vU5F6hteP0owYa3 z@>yKbep4hw{0Y}ubivG%#NowjC~nMugV%rHtsk-SeNHBc!e0&=2o!B`POMF1+D0+% z%=v(>;R@lVL9-e`#~*H+0J|!%X-0G-l*1n1DSxD^9A!XG;sTzvleGRtdhEn3;QW`J zT(E37m)STpSDpE6ax8_yhUb^)Y1}#1Z$3^RG!|f=gpE3fE5^ar<#%oDe1&S;{tx#H z$*lps zuvtwVA_@I@{OO9yVdo(OL3`myyMR^9PCYW%naFSAy2OaMewFn7P$wiYGJu6P)4OnP3-LWL(|QNl_TbNgP)buUZ_74;^(|2S~O65SMX0`$QE;8mE?q# zoBn^`55+g!%m8t4On>(6@15cWM=oC+M@ODt9b^J7Q|ODAAg~row%!!5LO4i6Qdtt7 zm3H9EU?&|cEs0MQMg*#R(wjpZe+RIEysQbep4qegnASab=x^U{GQl+W8fcp#FfWs# zA9yk}@y)Gl)&wwT{^O8qNqR3OhX6)4oAHXjS+l`YxaB61uDZdw2ei34^sAe$JCfJv z=A=?$ck2_^ev#Xlhk9fESb5y4qZJVj{_uBDDfK zgl+N3F%(LnVc{V+1XR7Q+WV*~sxR`B#*y{76nu}1s@Plos{*GKT;T^Z$>tZH2YQE9 zN~(D{KYRbO?!75w0ZhJ)QN(8b&z+z4I*zr4>LrB6l1qX_W2nP4=Hoz=0@xd!^qk!C zQF9|9Foam))=PJGigB0GHI67nG4ak&;=o3G9l>0*_a@6J2;s4F`*H%6N{jIOtnT~0 zp&ZaV48r$&?>EdplLf})s{uNxck?|LnvJ8!YrvAzTyn?{;km9NgyB(Sj|`3HTqQ+Q zah!ONznmYt@`;dheP`kdXi? zR&%$!`sFmvs+OrZ85#c=sLCqTL-r53nm>Z^_L{g{l29*`>OTc7XXg71maoDUZzl%r z1;}!A%b9L_dKM^dB6h6ALDd?=z=CJjN9kBbuOkH}6^orhL`wW-OUY#Fh#FV$J^nghIktlj_#o#bemK&_sR9^~N;n1PfU2;J$ z3cPZ(FLBX&0J^5P3fWAP0LF$UT1Y|)b*HfDYYwlws7>s_MXs@=e+si|4f&WjyUU&J ziN7~B59|a_2E++}=>rDUJQhk!*E0)MjW@Vq?k`+CE*dkgS#pz+WD|d1jWek3DB!~u zn|M)m-+|tk9Nco3f6e=w;fhQ#ixBRYvj_nG-Ugww&^pvu+G4P(q`=iR+qsA~d6u72 zPMN_9**y)4j94T|m;wu0Yn_9At$jLzUpCZ*#sE5_sa4?fWmM8~7~EjIG2do;Xi#IC zKaaA$P@YXt|3~(}$jSg1SXF#n=KAtaNzB|Lr*6l57r2VfKQ3Q!M7EG^X?&P!a;F(- z_iU)zU0%JJ{9UmxO#@ejEMghwo~FdQxo14MV%##oF{k2K3URIT2LV@;piOTHgAYKR z5$j~k5NSHP4>wt$pn34%Qt%-IPAuU86INDa10}LkHP!-W=&`ee`P3!psho<0o(jQx z=b{SRm17A!5r458j6+=mU!<$u6G?@xpu|IBuwmK(0b>6LAJQZz46u9YsXaH|pgPb+ zgf2ZkOFS>%vPfQH-Xbh}WfhyzVO8-y<+^u@usJGntT#V<)$-P>*SKN1B3$coWY|E< zuc3RPlkV$RC4M0b&oOZ6SX*90#m(3YuaX|S8p6~etU65)FvxpuC-Z0M8rB?$7M#H@FCy8I%Q~H;|Z3MAR_#C z=Cs%x8H29`(CBby@f$Jy2!DhUg-OcKWYtVukx!v-;SC9z?$9d|5s(@-9)qilbx7u2g^mZssIAUzP2tIoxCoXHz!Dckg#*jwp3)Up8K86TGRW+UpNUZ;ZU z$%Ac1I~Rl1=xWVW_TJl_wb&~rvH7i9{GB4b!zZuzbyV54MuaHc_NEpSz9DiBcjkiC z-3n{X@T_kFpst7&`w2^uEtBXAP28=J%B$Qvw#7d3FDCMIKyYOaG^b}`Xu=b(3Y?(q zSlc9Q?VE5!S~*Sd$=UT0uVfzLDJX|3lJFN)&K-1wS1((NjJ0!jX5Ad1ag>az>|N+c zuQb;fGCm(-mL{U+N*^F#gv8Uy*?#aiWD zG)=#1KhgzA*-aKVUSyT+K7X`_Cv*h3Tk*WjX)>6UQg+ z0vDehLNl4jT^G2=otpA(t&kk|)bw~yaHd|rcW*AXt)FVT!ku|yJ+?Q&tz-RZ_J+75 zb>go zK4k6mGQKP9b&!b;uqLJ3s+Op21J(L(tByy%shBb_TtLhj3$FbhCv{j5mx{?47U+%0 zpmECZGb?$)PwFB&Un_ARs^f8HobiUvSf}IgSg?w^ptwVhzoB3k1`qy3!WTFO_K|Pk z+CkgE*znzG*4BK5zjFeBEpu;56nu_x1J(NyA9IP_4RQBR@;32Z>ihbZ@e@uyA>svE#Du@v+eyC1)7Rc7$%z0PX@G`$n+H|KK=c&p(I8-D4?8 zp3-4&D*hVpx?z3>;_dZexRszyM;S65*cW(Bu;lRL6Wj`pryoeBzu-N`DlXvu_&ch& z6eUINfR9c2#8bmeIUH_bPfS_Lc+Z{AhShD|JEUd!+A%%<5!uv|1J0DgtDnz28JG1f z0-`;oON~uB-L|Qwdyf6vo2)2%iG`5!9z(?QBtaMJYFqvCciOOK1{da=mwaX`*Hjz4 z+$m&x4q#}GrEXKiQOyWGQZ9(#?ydm^uZDsAdhqZK1GE>gNRJxOn`~~9;|S6Y+>cht zSq`qkZBj93_|4a~;A19Qh4C6^7dOe=H`e#L0%lD2agehlpKWe=+FYhJ~XH0Q3m~8_Tku4k{WcYO45(c)+wRP|9bN7qM{i?rWUx-j^S#pswT5UjQ#Pr^A4Qf%cR&&;C2Bn zs?>2r-?|%pswXK~Js*`>SuyP0-jGV?beJe&ZXCruR+CEA8xYVHJS#Vq@WOj%Ah0rB zDPzJZUbPDKU8Uv>_DSr`t2=cYZs>NyP84)Db-d&BV5<)Evb^8H_)^~EPJEinwiTY% zIHT(N^x}Ni%!GJ2JEI41l&f$lYdwDqyPOGtIxUBA^D&jIY z;9Z5VlYlhjxO0U8L0nvdKI66V(K3z3-o#MfL;4PbK3dL;@TUX8t$K3wrkBZQe{&O| z#`2M_PaotkRl<4?*WXONwgSstf8ZLkk@0<%=YVl|49=N1e`_;n9^2-QMH1c2A8IZ9oAe5|zOR&h|@W3=&E*}m%p;Wf> zs7JuJhx0()w6lUY)0ozY^D_J(U|yM@fB2zPLNm=es9B-n5e60!y{Ev3bVb(q$yX$1 zP45zapE=ATA_d*UbiwlqH@jhr$;bdPcis%R$0odtoAj8Cc&yN;JGQnuNfdkn|J?k6)hO1(8X_{xv+1AS(6Uj$f;rfz(g3+=udNO}5?}Pi%~$56xXJO^Np_M~(@1<}PrnoENl%w1 zV=a|+_tiT)*fim5{L#si80%;USQn+|roz~qB47C|+<|)}y>w5tWhgs*AG8L}1{(GSL}j`tjguFN-FHEm7*9YVGJDXexM%3BLE) zbe3{JUOoTOzm)^NnN}Eb}TJ`+H8`CPrs*l1rw6bnM5=8O$pYkhcGaX!$wSTm4 zyw5EdQHEN*2x1@td+HXr-}PoT2K;j7m%Mo@yqD79@_Y{eyq0T)#-8V2FeF+&TGR72 zv&3W!P?{== zj$k3v%D7ayW3qz^mU8cX_0ea)@=W?+SD`k424Wz4{fS(8PD=X*^|0ylkXBJ58lL4ldxE86lHrn& zSISAQvi*W#_kcULiL)0qr@6R^D@%o230Y@x>!{MvdBfJg#Na;u z!q97r>Sdo+fv4n@lGMS3Jm za)M}cM|+iN=nQMa<70nyZ`>UICOX`80}2}?e7vV25Aq_6iBQLvUhl7hNs+wB&lLPw zSHZlNK^}ofSDDF|s6G`tR$P785By)h!-WqY6)bQJTFKV3}oIqx5asBwo)YiFGreD(O5qz=r4ty}`97 zI|IKKc^3Wrm`FLA55#*mZ$)#Y@X0p; z@~P{x=^2Vs@sf3g1~Z_ot%nDUtmB5H9*wK}Re%<3Jz>`mpEvBOkn#6U+{rLx3ZRty znux1=V*Ly-AKRM~3WzOyu3P5SMiXDvM>McMg!zfi{%zKx`Nx~K+0!#z$Sjj*jB2(8 z)W@YYLB@HIgJITqx+oJ46EDb&t!Mrw1SAnc;tt-5|k{K+~)q zophR!cC`tU#%FL40vsJ5h~(oxdXQ-oVYt**FuzYc4Kd##hX)c? zT03&m5xr^xJ=LN4y@lOg6mP#0?D~3Lu8_$#aeIu>20R~g8xxfIc8@Rjdr-u>r#aFp zYx9L;-({sNcLp9J#N`p~B%d(bk>u-ly_B1--REX`hJu%fdqK@!mm^k9=UDm>?rDdW zr2Udj1(794zUvweLL-oh1I;G*o_QwQ=iNaqvUXEJM*Epr6x}C7^7- zPZdpy)C-ye23v5)&!N^4)!8B8AfPFGhXCS}JUeZNK(_xeKhG#r{e#2yCUNK9)LQUp za*f>{)_Q}sUQ(q>K)iJ*Tb9B6%hK53{ucI92?WZxof=XU-5{P{vE^ zeb;(5b2MDRowT}uYU;@SOFd{i2N+yoiPP&cP&=84Jk*vIL0g*W4)n~z`orHDXqWL| zTXb_rIq*qo#upA|TX6U~o^!9zjO9gehjv&K`n|KEVnQt66${~R!K!V5Tcp%^ab5In zhixbIpN@?8IIl3ok27ZYa0-69*S#erzJ)zL8sTJ0d5sFn~gA zDH$1?`ANl0V`#2jB>3zX@hE1-wj>ZW`02MTE@rW{n0u|S5K>4tR;uNRsa%KRt&WxB;dY5=3y2v*07OO9BkxGu_C~u2}+fY0%&aMIt3BSP9~8<0&|V#l4lB{wN9TQLJaR!yWWgR?Hwd~>%Mxc2Iw_zQL3%U0QzW3)NurLw`k#G=GWNgrqAIDH9Cm*v0Rjl+t9 zY2{b6RB)mfUp>S(xiC+8Z=k{U`wOZ9K07rz*ezejb_~KfkXjmFCMa_og>&Yak7}a# z)nC&h&hMUcaM?&q+%3YRRf@$cdwd2ablL22gqw+%zF%x}Vt{v|UUrVuWsIUEtPV&YHi%1zB3u67c9Or~QwX z5~(w6zZQo<&Y_1XcQ~BTK{GnYhVyWI4wDL9Cv%h3aU#u@i-ITPmNCpxWy4@)Hmk^= z^Zv;fFS?@nv{=!JK|Shlr(y3VChksk$KJe|VNn|C@@?L)qu;G1eCb9yiG9nrg2cK|q#W#HE&t2ckQ71P0Dl6!UusI8I5f1uzwlLu^(a1SISfVKB zoHepsN^5MFDZs{(UQpy6J7{C($;K!#Qg;F8@alEVRC03>pMj zlQ~HViK}<=J8?*b+heE<*Th6tk(mSrePyAFZgT8N%4L*JN<2#XF^hapvwZhT9m&i< zLk=Y0^t~YE+OG9xnTWZu=FaiBNnNH7JTL)eB!2!a80Qq?8+oB0N8>}x0)b}79zhSC zym^;0kYD>l#Yt}H5pvObU@>rp`UMs(&X9gNu`4~|*wV!|RK$)L79ee7+v`6#pnL`0 zc*beFT||)aqS^0`#*fBqz9d9ke{Vrl?_Nn>r_j_5TbJjj7xjF{-Ydzvb?U@76V4M= z41%&+WycDKW5piblAf2zYVP$Dj4D2pnCyr=xUE&b>MZN#_|10NV!KMWOl}W=8j& zNBPfx^2ok}JM&te}z-NWqHF9!ChHdQ!WS; z7J?_{48#c)+EKbRxS@kMzM22J_tC*!J0qt3R` z!3ik`FHG>Oi9Aoqqc_9fK`%?FiaQMs`k=EDy&MrYl3-=Y`unB0Tym5vm3GsNpdfCmbWk{! z!UA39BQg0>;iONWdi~y5gpGgPd>$2(vCg5He&YK0G=VDn)K&S6Pd{;SW!s9&)O&Sb zUFE!X?u%wBrIRv3kNFZZglXSb52SPo2z#-ou;QUar%y6yGAnPTJF(+rPOrWw88PhI zdKELF82Y&Dy5nsy+0xg^@_a|4mncS9hxE3@LkSY%QEXA!X@u(eouXXOX-vB2SVwXvM!?h;2>WzSjWFCO*qEPS*pM&*i;(c>1qm`|aPW{&o2e@bQ_+umUm}++ zx?`TmA{+NqY$y^3)>8dcejv{7v~RoiKU}=bPcX#wpFeI3xxwl)Ylu{i35iihPJ9P3 z0=MR$YL^sMkRCJ~rf5o=S~GI^dDzaZXc}Jn<7BdEo{PVQ3s3)y?SHf6#lYd*|K()_ zFLcQx&z!<=aCyc1DZwqSR(jG*36|nh3~~?Y zM{p}Jwg~-jXx;^@7SOXaBFB=k3NuhyU+BesHy zilck^UU>PM2^;3bbQ5z!kdi&79%#a;w>-VcIH+n3u)PKf2jzi_1@^KOYx!RibM;?D#RVkkK7J;_ z_@Rk-Mte;3-+)4Xu*f_&ctEJ3hl4ZLPp#;2sTXyn_IQ`T&=%1_()1tC zGk6uzF#&@>xfd5+rqhe=p9_}tfzGvE4^akNcz+8SY%n<#y6WwOwp2ETvj9`q7~TytW=h23Ox21K}3)-m2XGWi7&s=Sj;v_ z4PUp18_}Em3JwOk8^ElbG~_3BCaRTQl-?zhIa@U{9kUK98pe{1rrw=9mv7~g#O(79 zM*A^5hE=(-@s&>Oql7Ech0NR~I{mDSX}A5YwsDc%Ei9H0>BkW5Jk{Y_^X6*x+#|QG z_W@9>{mk#}em*~gmFDHC+r8_n+kF8Njs;m_H@CRZMh#}Vl~>x1v>Ch8Uhde&d}9gkr;=|MU12%tw7b-@Sj8Hzc}Ub)K? z49ffp3N{&&n7b(CaK9a$w;G~?Ts^ptz&=#GT*+6wT@s{=lYmYN;Dtnqr-myLLFmLc z!L>pBD&baq4OE4!V67ZX{#^iafy;nvmYve$a)0_1n1cmc>V+ zuOzS)b}DaW6@~-eKZ4}l-%vW!Bl=pz%Nt0SeBRO^{AY+5@MZ9CQ8MwINb6)3-x01V zXo0B|GMED>>M*_W&9YsX8LlP)-c!5EoDeE5uiyRz3^EMvd&Aa8CkPbpP=AC^3@!pXeaLBXGI58pVzHCF zZ$&54Z^apwF4}AkqL*)9!=Sg88(dvcrT(yetRc(m(wGsnMI%tg^r5Fi(7^i*j5L&6 zT%iuXOF<@SlUac8iOfd48*nD$r-27jNsRzJ^MKP-rIjb1O4Lc^#Sv z#Rl4Hlu<|K~rb5h`s zcLT36>Vi}dOzpOVJRwkK2nvl9773(fAJ#pHmZ(1lfW&RHmPiS_43N^OA_DmkWvR48 zFhVpag~bKJSdNi!X5vLQ&PtK0>EDzcQb`fB$fVNZj~-#w7|Olm?o>V(HAW zb6sh`X@?)fiJ8y7S}|kbMw2z(hhbrSb+W`QyedTT&V?nmtb01Cz zy6~@vb@K)<0tXO9g>||VjnC#rtX3gr)Ov25G#)g%vidSwMMlP@e?!(>S=pe^Efz!( z70cTJp)waGyC(VPe*sw1N47B4e9G)6wh-SG_phOul0svP8=QG8Eg3MHuzZx-rmQPn zDk4epU`!o0xsYW7ykyhzXVeKrjSk?Gi0CUHa8fJ|yI`~-OdH&Aco^aZiO`uKD?OE# zLW;u9aUOCH@Sudmz_AzQJ&inwH4{)hh`zay*`LUWz$yEdrHMHFfy0mCw>h>jQ3f-a zsYm&OCw#6z#c2#DsTd_8Pf-9`fSl_E13qtPGPn_5k)`!g^Jfx*ZF1mS&ii@ZS$I3Z zi=N;K@zRYg4&63S*p$KxxRSOMXMB8j4(Jyj&V34;w3Fm~(ES>0A!C)l6_C{w?x+WY zD1FElmQLY@E64oydVy3*p2yB88yI52dr3Xh@G#Zl5O`*mzK zl6HF>e5{+ZCS5N}&o4*mS^o3RGD}!SHlE%A!!Cf-<%o(XYvPoQU$>0E=f-{MtCGpk zKx--$oy7Y1=zX+p)El$@>uT%hknio7_P?fa)tA!9iYnd}d5ZUVi-9u8e?%LKcZ<=} z;@=X{EK{#JIfRJkbPB$UI+Os+Q^}?GJ!0wPwb*&5corxQ%IKTO%K@BUvz}PNmZz(! zNe_=n2Up9>ki0$Z6ZY?+jUXcz8&;Tw8)3{L>e&|n8}j<9jdT4*-R`H z$B|c7YXW_s;cs_(Fyya0NyM3y!X{u)qI1>((7*RWTBY=lh%f*^U;sd;`VfwU5 zlF~X%XMfK>J%$|87Tb6Oo4etn%+XF6hB82=)spHc9K$^gs; zC9F?mMX7)wm31{yJZzDwiKid&^e%dC=mPgq2a3pO^ zH22$V=y})Ie*Scec6rZogF!5aJc-|PnQzy zd+~9-Y@HCNT zo~)=Bci-W`zX>FM*K+~=7!c*mEJ>9DavgI#61QCo=bk4kWIoF@wdsJa#&#-7 zKTqH9bO*@g^&I(^5{W)oiV24v``i1WAKvP$Qf@l)H3cn=9o)Rre3|pGWny60E_)kM z>Awb5@@8^sfF4`PJ*ijyMmqWwg{f+{yzM$?nQSMuxlnR3{Co<-;)me2v+Gt29NGq$ z%UXgQ)rbV_BgKavT6}&Hq&Bx-E;OG#!#7dYWP{u6^P+bAvu5qqYVr0OwKDFA)kj+g_g!GrQ_ZgRebTBd-;pC(!kI#eDU#R)}MdGHr3?`2s8e~ z`SoD<>pu_IRrvEnTcD%@b8*Z`l>b>+{@UdRZCAlQTKSkcrK`&7Vb@ZF%Bwy!|n$_C&&M>+XLrt1}Z|MzvLq|%fZ z>&x+i-ZzQzb})`1CT_5H6ZY8pV~ynOziQCBG!YQe*X~%pX8@ypL%igH&qq2~q_3}P zw}<}aAT#53mA^RhOnZswA8rd6gsi639jduWM13RNFz#j)E1@xPS$eV-V3Ms{OWUFW zQ8;7b-dSj6+F9kaa6$_qF+KM8k3+(VV;k9uEb=ZOijbc2wuWQD3z-qF{+SQIRhOKI zK`i1Co{DNLE@p=$Z@*tPn+hQ%BZEeo+VR81Y{JZI5|t<~4^cbzqKe#?_k;!5Pngr_ z#DBe*AHr)km~MI+50pmNOl0}OQ!!4`{7mP`pP5D`jt?I4itYW7{Xn8B_VY!pBa>RH zhF9*!k7=1Yyq_6vY<^o*DEGhqRrr5KpMHzt7&IM8=$|~;yMtjzYJ4ChVSrAM)~^@1 zf6tvN#xFr*=h3Ab=E%9zQkZC$n98%qDK79%#_q2Ra+|bcn`@?BjA-s{$;o8SsL85{sgqKm0wrUHXn=i+) zEJxB{nuUy7!Y+P0w2$USm86Z_Le^~_ zs!cj%&Ns}u=NZT2s8M=(v(SAA`>vjti31XF@^7Td{0cu7`J?U7cQHTe@uj`IXLcfM z81BIv{Bf!c8+-fg^qnP(GfDS#EzLOkZ_OYw5GkXf&l@(oWVY;OqpKx7sMzp(p`*6? ztai0t;q#cu>eaQkZcwK2KhBXa5Wo)}@dMwzhkINR#Sz9v4lgc#giE{QhgNyvX%kqy z?qpYvip^@L84Z^F7T7M`TQL3=kI`z5`bfJFV=XY~Y+vSAlUDx7T=X2(uH-oLxbe7K ze+%v>No>t9_?Yd_y>AjJ>Jvd_FbLhK?2$zdSco2+do>0N}Yb) zb-Sbg@kjy|}apOT(xu)lka-6=F1@3HLZ@SK|q|e8CP@hkRBx2_@&^ND-R;&__$~2MXseK2r z_FruQJTPAiK@-R(d{ZVw5A2_`<+Ux>KmT09JR8tD&g)`_9Wz^Kr~V!`&Hm@jnVL>KRnpB3otH|E|>_gOwsa z(n*uw4~H+2S8UPln^0JOMFE7J#GJlIjS}t30?|P0#r-cYGXC{hdTic{<5jG2_sY(M zSLsVK!gy8Max7Uylfn=aEjjv3HboXtyLU)?<(mch$h;Ji2;y`SaP>Y{cj(6T>-URH zg+3|!LM!gmQ+mnDwq0Fgyx*xBk{TFnPw>jeLhgdTh)O?6haHyJhYb<{L*%4ICzLl2 zWt{XTlUJhzOVW(3g?Gq80C7Gv(Egw3JHVZ%Ef0>Zvq&;sv%FureeI7kZWQVkqCo3m zoX3|nda<%NI`0xP-qg_!?d$RMUJ<}u00gXyTk?q-1iOFh)uD4_o)+rg@sIYYEz5P# zf_SCD=AqouDLDDenNRz<+U*9Ap^ef`;P7?y2#PG?HgN3GDWUc@wG**edO`Nq7I%UNX*1kpO%H8c_4;{-4E8l9Pgc_hE`=&#h``D{Yco z5w3<7GYr9rAM0oFmtmgb36ph}-&Gyws&f$^r;n5E1X26ZiIvxWyw}D9;VmXOF}^TV zq}Gw%@f3b$xcOt?mJN?IGMG#By=~tB0%b=cgqb5jtoCs0Y z*-V}JNk#X4Zr!oqxOh*yAnL!pj>uf6wzXc)+2F;@p7`PcqH;>BLzfnl7Cm5GnKR{h zyTV!651cQctv5NdJdNi6o^0PNnYE73i!Bs=Gu-w3)@nEV{KDZ4biUxFzr7N(!G}0N z2{zs}$3HHfLNgQxb~nx=wU+-Bt-hqt;_W)xu4b$TZ96TC{g!ao5xml>>;7$cpbaMZ zFW~r^;WO?6rG>rDxP|8JZd$wrTHz(?{%afCmUe!ASSQ%tRY zHyS|}k=_^R-`!FpKr3nIOS9h&#kc*9?Wn?g%JsPFNf>9Cfo_|lo*1!XOAepzHNPGS zf4Y?8es1?byaP;3R56WsJov@;tA$fmSC(+5>$uxEp|W9!>$~?oQ9zF@<1YYah&mpH zW|{~U>O@BeG2M|!{oUhGa%#W^_m`)d5?Jsb1HYDsWsF{o1&cCgJ)+^miZG=X+n$~n zHw%GZQM~QtaMp>wdbX+^?RWpR)#Rb@8w9G)Q;O(Zg1&o&kfY-%8a~D*uqWPf^JoUm zNfzvh&u-rfbI#6pS?8iDqVGbWQkgEwbRpa8FGBiUc+P}OzQI3wW>Bxw2e6|#@PBM_ z&{&MPGIga{XAM?_a-nZ!lGrR zhj2qN;~D>k(mK@1;(Cudwtvijf3e=*S4oVI^>MLLW+;v?%*wGq)TNkDw^x{@=l?*_PILLQPmT|%F5=*6&s zn1mPy_s%@&Etd+AGjo<1#|cAgk3}}v_%y-uig3-O6^1HWtuXt%j_)yNA9XE=dUzQ% zB6tT~Hn>@2Dy!1x>r*~LKWq{;+B>rPrPmgxCIMq&0ve7~sL7Ws*^={<+NXwQFRPyP zuVC-t#VLti`QmXuq|o9PRZ99s|H+!!uZWoaTfSr(elLc9B)h%wvS3vRyk?)5gci)K z5n{6XP|EG&S2|2?2tgIprjDUydj6L}<9K*kd za8xoHIZ=~_9$l5CQti%3CqtPeq_@@d zH-$xmQZA7h@tL6P*DtK&>X#wjpQD*mBERxDj6sD0fnFvra2$xcd+li^2hzN>U+JPw z;P?CLKb)vCUipAIK+CMz7`0bBRr!w3<4TBO|KWHJ9_WcNb-4s7?)NedO2Yjr-?wlo z%&P^Faz;_-kAa%XXhHoh zmO4V+%x4}YWxsMchOA>L!bhDy^Qa6#pCVnI-}Pw#YxhW7CopKx(&KyDcZw(OC2kXZK9I5=l1BBp{qr(XYqi1sGzF6*s@HfQ zodW$kRaV7M zt*^h!9Khf|ZMxE|DR%3nMF7z${M$()HlN#0i?2*Hn^H%80)(^D&Nd!BuEw*EnNs-h zj^D&R4Zj%-`8>+K*n;*R;k!W+^2=cEZ zm4)1+Y93GE!kA>bt+RYg_el>IKks+v?Wa5x-)deP^qg+7kWSTv>w`;=kDF>27y8;=S#9WYL_Z=~0%O%V&Y`s# zkAihuByWMC&^|-m^3|G~&aT9SANA>zkULgOs*KE5{#pV|D&}eeG7tm%69$AR@-SFR z1^1?lGVn0~Sb&e>rZ8$=K!|T1SfR$7A#bfJ@xGzsJawEwZW@~>cLc=!rZLEj=nsnz4MU{_2rR0d>o3vqw44N$)k&Qpuc#MfbH%HXhb zku(uD+w@E33&9$9Q313&_qjNl`=-!DJR$$;JW*av(Z#sj!@PcBsv99$UQK2JjEJX$ zyrXcWmWoB*oBE)NYL#P*OE@q{oLeOA=Mi~#B6Tgyocbu+!8VAeIhx93;O?ct)rlH(nV_rs>BEEie zh#%TlUY+NBj6Se1&-v1*^FKtmC0QRg^!58(a9EE~0lmF{<>+HPWJ@fxtZw54cD55E zC)|uNTH^SBT0@bc*v`B!(fCgdKGT(8JS-aGVKB{%U-%Tc8cgC97dGLfqx$2iICzWz z;f@UTci(k%`D6>a4UiQsJyP0|hbWKGa0H&W00}}R%#0=bw}{-nkc8I?2@(l`3e`)**-mkz4C_~-wQ8y1WF`vyPd%Ygfd)ETfRc>a z=NW>JZQHauZmx%pcDiPOp)6qv>5Lb_ql6a`+vTOR7{qP?W}>Ns&ENmg6y^p!0K5Nd z2Sl}A7?&{Y4$ln60WoWM6Y7F$=RZPlSSoweGCGtNneKET}!XC;Mq47rX-E@b@Lpqz9+Tr0%55x@os5*bHjUDMpKQ+i6A3&o%ZTYmOE-G|=FALUpA z;Xlm>Wyb|VG-HSx{?)o+ z?O6}hFF0((1!)5|6!b6~l!jgQB#lP5$+@?|m3yPh;jutf{cHXR_ht{lg5@fF?&|#* z4#9hsNnbZgzC3M>ATN5#&+a@Md>nQiy~11y+MHQRlhDVtzp0Z#Lf$8^`Y!zp&DHYx z)5Il+HPhr(BfHt{0h%fzY@AkpH+)7d=+EsS5|1BEm^(ObyN~U3a>kB_h#l7vY2C}X zKs#%tnnEhPM1YJxR!-Q(7{z&C@7TDG_v!@Xd>Btp`DceqA^T0L;c=PuxFZGyEP|B@ z{W5Ln(`X3AuIo4!9)UascIG3{E6OB62f^)&_-E$NJE5uz4M;t39HE1oS!!9ue9P_$ z|1BDSz9UIqJ?DQ}kHbTkj0npNp1LvD4gcF^Pg~nG>;_&$B8a2pX>D<#I`7(j%`3d+ zbG*ONWVkW&(0n`+Fl(b=MYSfJ@`mRogfJ^-_?DWr)8x1F*gk!PSR~MNLhFo{o|0v_ zP=dY(?vA3rsxE~j&@mdiWbcNbHn39~d9?0(NPiU-e;eX4F;kd0jvwPQaMlZ;vtvs0 zky@z!p|iAn8nWdj59ApyCV=XgJhvD@vg;ceAi_}IT_M{qM4 z;(-d#zwKDn*pcKf(+~2RDWA#(kSy0`_3ojvzI^~aFRG~~+HZOtRFxrX51?NJ!4U^b zXrGDBmk`IvOD#DP5|uQY4^&j}BYFv}l~s&8*h0x7BIvLquty9}f5dLc($I zaEa#)uz8l(6h~huvat*e)0VH}m90JZw^`8=hO|~01wxgBP+KT5^~R4 z?DsluD%LAJfHf;QdsqeO#ffgqYwVTr*iV`s1VX zdd7$yUX7K+a{8i z%2eH^C=Di-T*kmZ+*^pnNpCN9o%yUb$S>nYd`mA-p|vil*VmIJv;byaHDZ4<{Sw+H zkiPVA%B1M9b@TfmN^UH*6{31AfVhYlX$js_V>L?tH8vLS>U}N*F`i+NHwH4(C~BUbATZQJkY4Kguq|5l1&vKd zp~CywUTHP7ev?w5{rmt0oyhPrz09$q_L?V3jha*B8k%of$!b?v(8^1)1|#spAg21) zZ*bJ%`b0#ie^wCr@`V6n9^-{vI!ZH{U{QV>@rC%@pn&AnXmo}1$pm1wO!I63TfHg6 z=uyv<)#b8hgj$xXE9b;kmdDVvYs4N9aQHj7+zT58#{ z7HZE8^h9V6pQp|1Y&m=MI9;B7ewM4&d@^}=`U$hu^SwLFw#%y^m>ul%^a%g}ybMsy z3-|(k_q`C*wsW`1&fk!c4)3<8=6p-$O=u%n{b4)UgSYuQ#6QKbrj zyJO$@_)rQ?yLThSL{s!MZs40AWw zj|M--^(@H< zurOm?p`4Qcd|zk9Jr%%5Tn|H3Ie@iVjxpifYifb|_&4AE3Eq8)pE~-c9kt7b0#nMg z>bpxbskrw5K9`MdO|nqq`!ZEm+(p#Ep&(>ku9`DPynn|wy4_VPio6N_ZcK~x2cgjd zOU^O&)|_?B<1l(1s#cNcj2?L3S5SHSL=wf@n3^<7}wEi{GH?>1$fC1EH?F66-M zoY30rMfooIdwp1gyjzJ_Dp7^k#2|sGFI`zaO9?3Tu!-K+xC20)Br4fJsBV02=N}-4aLsaU zpTl`5uHmy`=1F@63tlizM1gdofuLiaRC-+~Yp>8Wv;IPN283_BIZQ&1J>&SELeBpx zwGm9AS9@L6b!t*ZG$f?LH4d_V|S zo0a^y{`+g5G1lDb($cQuZ8=Lyom*oW#we{kNmC7XmY%mXCI#*X*I(@Ob^5gi1 z3@z%1&ZR{aA9s8c%)0g`^W9EqR*dARCmi6SvWZ(PAo~fC>R3zrwjP`g?}xd2bbiP`nsw#n*2uB2Zsiw>|r>+wU#OI0hB+_De0&k9q-wsja$-l7sB) z&2OL>-Yv0or#2V$#Vz7M-`r2i_0uw4p(fxHChRB-vMX`;LF2b(v6sO`4m=n}8NmG` ze{EBFf1*6U&|Ih?qZc`Ys7;4Tk9_`3kq`(~c5IG#9nus*5X{FRP}u20wd*i(x0I9c zy#&B6Z0L5W%yc)OaV}Jy&2-+xHe+c0E-voMr0t1hVWCdDcSJob{EY7mC-Q9E>0onu zKZZJDZb}~1v<<1Dq{Z zUjTGhc$HqJQfP+3-}G>0>EVO%VX7l&zu91oty* zC}Q64pjI1Vds@veDO9ZJbHPqN_Nu``$8gtq0ZO$+cBc7t0<#6nv~i94@#EiD1aHvLbGb+C=OCQSKAy^8wqT5tYO9wlqcJqDO9%{+R`W*I^aRkmhfyXg1{U2;aMQ+NT;h+YX$#-7hvs+qI!}?B<;|KYITn7T6@u z?sa{&AF;m5Z~4`e zHLAh=nnoEuz&h4WcAj+{D^=Q9Z`O<0IacT2Rnd;kfe`5w#Qsp%1lju~4jwLfd4zU1 zxHuxzfYKRfj;+%xz@3@qxn^|(ZuU*zY?wD|yjpYj$?HE%&^kB8e6} zYL2bRCAQetaKIc{rI~Qkbg|pfXSq4_gf@qEyE(gw-n8iM@Yx8KY4`2}9tiU7e=4Js zNw|YLeQ;}ZIXqD7gg5WgOAQ%CYwyZ?20~^SiFB5Wm{^UN{W$EbURAFp-je_X<>B|Z z5aHe2(rwQ~Hn7cxc}L1Keb+Om84^H^bd4qd6cDG@12VUpZSteF2TKFu+E;u==7h*s z2b^%gTJ+1-A(0F&%(DP=I|-{YC$sEHT@z@TANh?%4s4|-C|i0Cbbe1<00ECA_uL-% zp|LnvA-J|xpOs8FVZz<+1$$00J6B->ggX^WIU{)1d z1vSy&vi6+ZspB*Di1Mkw^oVsUFI6y844BVxc*fp}$hz0vi*xhwSj95$q3H-U_VLXU z%4m<#jq)VJd(ULTBYUR;|K%HbdzB=McE{x?^^Q5dyQ#2vyO(64?;M0ld<5bwA=o8c zGGF2Ho^>e@a)_Mb&5OMYy(0+ie%);4TQ^EO2Z5nzs5Myr`h)<{xG4ziha&cfZkmB7 zX)9jx8W~;}hB&%Mzt`U1rEa!TdSsc80|wF8-p!`(KL-q6>tTGwjxR-K5>^n|T!fdV z*v|0KXp;A`WCLYv<5!Ao=0GK^Vf8j zn7E)*N7F<q;*q^Y1XKR@xeh|=T9OTX;a{I^- zRk?$-l75mm0iNp%*s6&;8=Dp6`Ohtxzij){%lM$RMv?JlLtl{f8uMk-_pL0&In8?b z3H+jf285cA@XW~b&F!&;V1HGV6Q+jEc|PH~lhC;a%oFXlhSEC|)>-W131{XQ&TTu|smXSLr1{m%pPG_TORCil}Z-=VGzoIC) zD(cX^QQb{Cpgj6gNLiQDg?mc>kw>1m-q_I166-9VzYvLWVeFDhl+a+ih20Vb#{&5# z=XRTNiTww9HT|o?rPVBHAHiH0WKO6|T5N{5{Bj)$bv=wXCiUhB&~{y$kGK1K5~)sI zo0lq&aW7Vk=RC-$GFviSU3wJ`HuStM_Rd=0-l+<@4!$nmUe2|%V%*QZFT1-eR(A>Z zvma48GI!O!^!8DeI$v=aYw9qQ;mM6A_C1@A-;Al?iyi_G+(tT?aA|W*g?gt;zMPkPT zm)gV?ZT{niaVm2>w1;CJ-ov?%at=SpT?*|s3!}$S^LYr#7JRRmT5mFFNS>P(sPdbv zqL&+zdp)aSfxF)d(PUb@@gs;om3^)5k+f`HO$&8gjJGx%;i~v?(`04GI01|F2V&jY zm^if&`Gkz#E6rfsf7yMKeknP4-d=N7O@|=3epqrA`y>Ic_^#GS-_dOAO2<6PGC|~d zriqs0RI@=mRk@dfp}Rv+qAzZ+;dM+<-RhGvVId(ELDxCMeVl}|oE00N8yuygyJ_Cz zud>pTNAhE^T%Z!k(%oCIZ|S|=p!~2Rs-tF{vu5n)B$%^Lu4EQ1`nGzF(>wKSEY2@ZRKhS9%&@2Tj z;k0F(VItO`RA0=ils2U(93R*UxGL5l3r&bhOe^_8xdV7{)vRz!RzsM%6oa*44 zs@c{&O&ngqvsa#Odcu6P~U>zn#D z-dOnFRGyY5oc@%Ky%(yI9uZoZWH=ADifnJOy=f^K1I#gzwYSvSF4bl%yUX0PK+{~D zg{_Y|26Ey0g6WoYes|I(88dq!_o>wZ5%x!qdCuddd}SSalI zE%u((ap_103$%e~c~bG}Q;pO)LdW|WA6Lr<5uU%`Nk(a%AEZ<=h8;>?w?&5ba3LYb#<%Ty5-UCMp(z2u&UmJYTD=z%Q(tW{U$aD+% zHV0l`x5+7Sg^cayR*DhOkNn6uA*@7{t*BHfeCktlV1=3T>x`^Wrb2pTUsM>^(7c}8 zO_SlZD>KlA?^Lf))Kt!BK$@}Y`?@bqHLE1LqQYC7J9G?40`Hb11xM=Z0l-E^GynQPWzU3cS|DotuYR*dWw zo!!J^NNAlsuNL7mQC?BHM_)tDp4yF@M@C$wG01eXE2Zb`{*>FcciLZB_6-EJ0-Xs2 z7c0|x;UMs-5Z@=yn!6LFjB1-{UbiGbG!9lV*G}sya+3vx_W?^!Vc@Z0FevoBfLArOMzZXePr-O2t&$ z)*Gq!t;b}`t5Muf7jwOWA=ubVarpDvpyh=G*efrGnPfkon^xU+AiTQ0F%qPp#3;7m zSOpRUG}M95>2_VomUO4Q+=Q)H=3>gmHZ2*TR)^-T}% zx7IC43Mn5i>|PbSnW?>;72#Pa1o~KksJw?H-5)Qi>}#p^^PN>$lRcYs-^fU*l?Y+z z!1*w&{65YkkaBE;ujPpsZBo`9$h8+ma+xICqC+^f?ga>W(k(?6>O(k#GRn>F>cTM zKS|-vFC;kOyVcB0d)Qp!giyKF`2?0l#@HU$G?!Uz;aR(1%gyMeh4Ad7$Ev}*S3lYA zFb|CR=W)hiO6mn+7IJNFOYKHE+9DK?Z%lnmNO zH9UWS&>BWAqM^+_+D1vFmCVk+KSYx`<1t;)7z|dg_jSV*wrnwBG%wax9{KChWYI*3AMM){8!L zA@nO%1nsvL4uLafLRZ+o=AScu0c|Jvo*$y6$9e)_LYP~YS50gh2ziUgIto|*`%H&Zwps=*;%?TCy@uQ6fi)e zHiBSEV{>R{n>LLlKlaY9=Rn<%hWFljD~qKo`XgGq9z{V%wpsf!ymd8GI0iSPCC+Zy z_k|Jd&_klsL@kiJxZtgwTFbZQKo?sO1MLeh)TG2>sOXU8jUNb6?ja2_qiirk={01f27d-?JOe0Fc9 zZe{vn;j(4oQgZ6_L(%kVIO7Wk>2V<=_YgyF`7VCCEfPI-P2G|f{Z_7-i4n#sW=jTWiLSd#^lZ9TZB_bOQtBnrAYSU>IVa@zqa3Pf4 z?x4mXFv#C)-orWGFJ@2T63|cg3z8kNT==xOFoS(#|rzx@E zyHoVYKe?I73RuOp&dsQ}KjHFq)bmALLm+|pP+AL^zgxpgsjGh3#+ z4LpBbF^eTGbTT{3dRLq6I&dXRmmRAyqCe3*rrqhx?&qk4JvET*QX|hj+hS%l7e+bC z1^_{M)f?7K86G;`dDS&}UZ!{6b8kNHp$pSWxF=B{E$r*XK5p&qWYBEQ_4Di1jo~Za zIZO`Xue6=J+Et0Do<3J@qmTjTYDMQ^&mdv2NXs*HVX~5JO}_%yg={Hk$qE&GZnV|S zEJyAoo(2EI7QBP^JwsuB%xgMp%@m>U#sr*(bZwh9fgN6QlzIi9>LPt2YbhLhNzHAv zRaB~S8T!QT`T%Xp4Ud#6H+RL7UN>LC5^n^rl9vKsEPg6o_a8v=Piq#|Q_gklB)!@iz&36Cu*~OnJ>17i0ct*e&k8 zyG+*-ksvWH^FdXQt|iZhP&}`0yPB}k+`KHn+663q&V`CkCMwwiJzt8cjHnZZ=% z_2rPh-gxyHiT-eX4x35MO00%d(GmAt)H9eJ?NAKMr%x*1VwxgtNbrb_^;sVM!O&lx zVa?|VzCzz%9@oW&1+Pqn88M0;(Enlxiy+phS9Vs6ngM?lsz>5m=#K|q^iVEjD*A|A zw^b2(5kofj?Y@4#8fJG^Wry>khK7Etu9CE}(X%mT$Ygqa6;;OHxl!MsoZJT7S1tHt zfHLdGq{&qv$1raZr$uCg5lNXfzdE2tEMd0PfF&hm{Ym1UPSEMKt;d)07+{fq#F{>P zrMOZc-5A+^y_IUNqwf?>=~u^dWn*d@$YAgc;_T8e(r_Wr@b>PIXAXGsy==Dq>EOU# z*AY>w$!@rh8}9I1_LV(RD~mo>sEPMa3q>>%SJPo5g1^sQPl8NOFW}SwL3cw?n~Z}r zvooZ;eL>KuiCNt(jxNX!xm}-+g=q}YLKbwxI1p<4}gC|*2 zA#i}MY(s}!Oyv+H^C(tp0Ubnf6`c*j-{;YCp?a+ohd$9l1DL|NRi#UyQa_I&cxaqs zwxKdt48_+$JdZHUau++6?qLcPjI=CXd)9+X#M+G25ksScWae2&9U?jhVa38=-`YUG zB}kt@#Tw%EYL@=`Vi~`LR8!<>AZ4;=WzdZuvDc*$GZBNMj!EnULkmB@E^~gHH~{r6RfHr#F6dzX0!->R`x3%0=DQT5I5Tj2=fd4ZbIl?SGniSz@f$1vJ_LqU zQ{3NNOxWjVOF5L#>amcXGqT8QMO_mO%40MTp%aTL9RN9mnenv)DFS1fptJc)cY&-X}n&qkeoIL*6!19x#;OWA^PhE{Ye z9N&{stAI5vX)xP_l~?7OxkMvu112Y)40v09H@gMfg1>$|DseF%`6BmHVG;$-?Ui+p z=9MArLb!DK9p+#$Tl|O6s+xAVpGk8m%2+2#crHg0ftJIyI2q*(s&`0|N=Ze(*2=(M zu1PjXNAyt!n}$Xk5kaFa;aJsuCXbPCSW->@@>1hgDD+}F5ebj~uH|3CN<Z$FTQ?sf9Es0&Lap~)7Q#l>X9KAtc2QDRlFPmRtNkCoD;b7Fw z@!mT}FB;mo2#cX+E^Et&N$#R;o-XJCnFvM$2NmaKeY}F^SP0*l=lt6XfO~T(TCS0* zCEY1#4`v(5hB4NpQ$rp@+LS<=UNN?kY3a476F*}O_EmIol@28KT1;VhQ#xyd>n_HO1^7Oy6_!E3;jF>0_ASCBA;4`kui`>{%UT!7W0lxHR3i@toJ2117 z@OoizJLQub-dX=F^~w!07?@f$mgNgiocXLan?<#h3$6_j5V@pldTp!?MS1YIm3~^M zIso^TpS~CR;nJoL;vnSbol9@y6Rh|w+>EtD$=M)7BFJPd4TGL}y{Q`wk2?r!&UtAu ziudn8sgVOk4`D(1UjrqP8EG~54p09cizv1T6Forjhw7lzD27Wqfb{8D# zeX8EJ#CxKVaJ{Jln+2DA7znMf5-xME$WndRH^=9t;AJ@VD6T=yR%fo{M7-}c5B_AO zfYje2r%2`qk_)OO6=eSJI>zrYP|u?HX&k{Ov~XVwjm#5x6w(}8P`aSSfzE{?J*?^-PTI}H-r z<2z3lsa_ob^6N1FM5PBPrZmXMW^sK4`pvHVGo(MiQM;lbN+_e01qygnk6Sy#+U&Xv zAZ@W!WN~)d`=xDVgGB=ByOAU3$gqS)U)8YG8`GYM6yzQ|JFc`>LWw7YlWy9Mm)zV+QNYZ#r` zHvG`4s;diEkWWxw8)KWe+1jWAcvTFR3P2`M|9sMy7>rzY$5#swE&P8k!+)vaW1^%Y zTjBg9ys<;>Gtc~99b`!TfQ`8Z62p~$dHQ1 z-l-vb?LnUlcgXw!rzZdc9X;7Zke;jt8%ho17C*5xaecFjOJdO&k;adQrl z5tMkdzWIhn64VGe>c1v4Af6~{Q@i92)R!eB@>(te+OZ5e`M_OwXY+<0mhWF^1j9a# z$~Kk@Ue7vhjWt<^td|Z{(I6iwbVg{bN;o; zNo@=+HKpB{_cgXiJ6q)KMShs48+!FO`Z>!yB%q+PkC5@*o|2?jAozvfn5t;ttus8!4Pu^5|6FqgLd<+^>XImK^=nySKE zdqPpuv$_?|TO)AQtyY2jg>4V&fA-q%oj{#{vRym}#yG$AiSVOLi5#+^fqo)UbYc3~ z#ll;wg2Sc*W&0=6 zX<YeKObY^ZmiC9rtQs_&lw}V8se*D!x6IzF@L`^Ysy2o2MKF6#9hSA9xtVOoL0tWjpr2aMRZRJQ0lk1t%iiyAAHOI z2Pz;>M-^cuVee{eZx!?bxa-`(sM{ zxOY{}UvDuzk!H(MU^b)>aQ`UDRF|A8bN{X;+@+`)7Cf4IsbS`0k>W$yjZ670w1kl>6w8q%O>r#;h-U@&jYTg=rf@l1NO(LbJjdb)H zEhVNPid`NT_EN25vj*(uJd%y3<;g6kt=28}b4~`id~DKIm|J|)Rmy%NtKfvN=+IqN zb_v*<{8qZk$)b5WhtWC^tZEZu+M`r_kCdN;834R z-E4wLjjULXraiW^sTpa<)_y&vo9Wi`w6;D}qVo`QV7oZ0P%?Obe)|)4?|lu2elVndhO`rugEoqsT;(+eNo3@{- z#@Z8f3#s(ayqs;H?=i4VV#P4pXL3RKX1i>~yQ_3|-*C%JluCs7XtB#s0mYXvH)i;j zKgss_$jQUMYA>^{APoJJO%HZ}TK&JMZOxOcFbviiLelI&Ya zulpG1w*C3_DqCF0d#8ROyO%Gh#Q#rp`1|(D?C-*wXT2%dq>bf{cpkZ&zVlyloD)Yu?R=pISf* zLn*yxB#|W&ecoLoc>QD%dW_n+OLG=Xe?AkB5&2OL5Z$uHIhx3tt*YkcmM8SFn%3$| zdVNdf@tO%7$Gnd)`>J_dYpvjP8!A3-Ro)=R5c-`mH} zT7dHzAIc$zu1DJy_Mt%+?rVkFEwiA#@o(y6hD0g0M>(>28$n?(XhxzJ+m)=e%+G{=SRn zioJog)|eyzW6UwfLYwk9oy~G?K~vb3^urm^-P;Vsx3#Zg?MRdF&i4YpagB-sp#O-% zVof0mk))zyU*1jtty(p$JM`fAOj4u8ngv_7;lPaM&iq|B-o}6w52N{9eS+@pNrC|L z-)NB1_UT2RdF7OFQ!Mn?!G8O07>(`_?d@`rW5P|Wy|iDv>XiS=d=_3csv`^@*;?9< zM!b-0w&yC6Tshg8)yKqIpZvfV$HS4{2PP5EXD-wrfuF9F>NHC}&F8EzGpMdD(hOf_ zHh?xZLIT1n(up-Nb z3VhZ0FSQy_q_a2O)Q01%Gu-9W@;5%}>jU(p)+PI3<%h?(8*Fjqt6q0dO({qD4jlR& zu1lnW=0Q3MThJAsRutAu+!SC`6wGXt_Xn#$gO{Xc-tKPs_heU2gd5 ze)~y&v{0fk|Dc5eQqkgF1^LGqo{g{>A4nUxB3S1;9i1 zqcgzRNb~XvM-qoSX0kR#ggK^Og^Su-C$_S2O|YxNMAF%BoE<7c!(*3@%vd&QpOk%G z%S(>aPm-SznOkfvgm3Awf}IzZLZ^me`Ps^zN>G>=BOsQEg1#KUt(nT6mxJI5i1w;O$~G-?a2)$aF261|9K z1wp|9+3DwwuXt=VWEYaj6*SA0zh9esIR-?V-1Y5#?W^R_#FsrSfoW>pc15v^BYG4d zuQLD-s6y|MCV;`PAv1{Km}st7E7;}LZZn$A4ic=QG2XI~eP-^Bgoj5q zyDdh2NgwMC6o-H#DxH;@V~RUv_`6$PAKN%J-u>8hhdN@nqv;+j0@d`bWNMG1;yPN7 z)X8_+YqEX~x-ChkuVfkIwA(^TEaemQQOPJ1%R`!|&gCFP;(EvSvq4RU3XDIq5ncw! zcfM${Wx6kWJTCU~iwWkF(Y;ctqP0Z=ve6~f(MJ;G@p2K>(v0F9Uh- z@y#-WdP(24Sk76sATalHAMkR-O9=jj4hCaaf4cTJxx4q2hleQh^;-4cQs?A7KM(`8 z<(q0Um20Zq)e<~gzaGB0;)OK_8}@hJlWt<~c5^7@iCf2$fz4KBWhu9r`f*C0A^sE5 zdP;4b9fAxa1>#oje9Vl{%G3x#je-@tXuH|$@=*;@mV`XaJk9!v;YH|V~-eHVqN$CyO>n*S1trf48CiixYq;RyB_WLM%sJ#b(F4u-$mImPmSNhn%M*9wa zODtJQU@u$18opCS_4#q`py`4f;jmN z1@6R4v{4D-ez>>m{goUIJxeQ=(T-MmG2v=o>SLxoN(NHFDt7XlwI~Y2NpkI_H)3T{mJ3u){&=d zaN7D-iQ;+T7g=v-i{@CREjjV$S>`!<)A%j{ost~AiSabn$oEPN?Up0_7*zMsCB>~n2h9CQgfAk&#XxkZXQ-6A9An|0y>gL+fPj0SApmRTOVeO31=NjJq zX74Z^)ZC}4>t?8G+3vlEtQ`HKhAF*~MDm^=$G!UDJ~Uyb$Tz44J>gkNhX&P9$q#@H zG#L@Bn<%bcIZ$-wBKP3t*1Dpdzoz?@f8j|2DNVr4n%S$rqV*5H@!CREY8GpS;|FDQ zC}pA&IZQn`s7NxP5XmcFr*k0Ux@Htb>7qH(ZT!E5*AHzUY#kz|H44oSx{G5(cNofTI zM!7bX^XI4bG>&!bT(0x?Di`~T`>#FkPj&WT*>qxRUu0(Ucu4}!Tc`cUxG z(G{t8xrh~Ly%@x`Pj+txRkar=#QY~6w>1P;ASCZDEV8x>CRZvdMSj=ay&N8^*?7AG ztWSS?$)AWmMIUS|$alZm%5sn-6dSZt?XXRJHfUa9{ZQt^vYB-hjJm zNoupCxG=hFugoR2yQwbyWTlPWUCL4s*S%4W&W(TEbvQh8@hT|l3Zy!$drKvrQaVC^ zdS~Rc`?(Pg{aQ$ETh2aj)cd8RdBYiUm+wiyAwSDQwH~oCOdug2a#G^J?k^Bmbb{H-DX*d2kdG``OOi8d|SE5Vi^SrdfXVjjUV|+TBqZ*yng_cCt%i? z&mKbcW~=<}|J)t4olJe)*kX7)Z>y2$xE^o2KB{N=|bT7lAAhoU(967`&a_pr;v^a9gU!^!Fc`L13=lo+~0)ssx#rvqmO zMjgHve$iSBSO-|Da(mOOJ+-o%iHAC4_cZzgSphq+xSE;g(U+zAdiz!M@`hcX+|lT1 zn)SoUPZK41pmz2SpEp@32qjA3FHcNy&aEs9j19Y72UFE@;K7J~~2EdOJ`U z7z?xf(uRPrn4F?6Eb05ji;Qp=kGM_T%*YxtN_KSmTIbGak5P%{GP`-a>iv*Nxbm($1uOk! z$YhuRFt+q(anKoV68**B=W~85uo0eDp~|iWC_u&Vh&^Unm9A2bv!K{90<8=;*Y6QV zX>MX*PoJxl2*2Tqz&E`$q>)hZY|f-ZeX`_HK2~l1&<}z~Pp*)mI6s@XDI+dgh`_C) z+LffJYiY^4q4w5U2p{_{VDwsiC0YqCCgh$AM)hY>DV_4( z3lBM!gcWm2z&ES+Ki?1aq}*-WG?zag2HjBY)7U`!n>ZKgD`(v~YA^WZ0L@yxy9il_ zx-I#4oIg*#Jn`~=gqGs%(0^{&{P23b06h2#oRjHZ0R4mqMqfE_%1uOLVG`2seItEN zSj_1?J=L{C>(l-jhJM3_zXkrEPygZ2|Ktq*f7Fdu#c#ou6I;&0m^m-b=d`Ih-1F#g zQ9`Ob_j+9oE8A+7DgLnJSPx|Dwfz5mumKZ^orE;mCN@gJ-EXBqdi8t^~= zxzGP-&;PglNlBBV(c${9!sLyGJC%D-YYI&^t*-375xe-MWQ2%M_n&ov2a!K3FZpoi zSZhsLTdwWta(TBo?eDJj*k1RyIr5r5J00TZQE*%BkW$qp1Kz3rQ`1v^pZ%r5-g-Tlcm{i$06xM=Et_)Q>hJPa&#&+7sahzAY8 zRDH)e#a2l9_zC~q-}{qIoR6D#_%}fs#Df@M@%{?e21W@o|1iuaUbhgtb05zqm%>3_ zAa2kYY~7$)<+4ND4k(9FGD`gp>-@gGBMtk?o@d00#B(Z-Sb$XbfXUV+tT=z<(V59rf8HVc|-BHhO9Ny_}(Q6Iuqz3e4N)7q?EGLbl zXyWAI_X}?0eE&EFNtyHkPgr?aYbH<-L4Hmfbb{t7<0hru0lQ@LiF_BiP2pTYPt5mp ztX@+!UAp9Knn73&D_vBqSIJovS<-j%s3#FtY|^&}A?|f+Q87dx7Ayz%N9wv9n|vPK zXn!^;1IqK`9kXI)=6=eb9;@IbpqgVr%1^<_fQHM=?}!q0?>DI11b_xP2S#qSkiphA z*t+X=w>A|)X|s6|0HYa+zNs3< zP?V~!q?lekRn)^eDb0j04QTq*^!uNUzr0^zCcDIe1iYO7X`Q)q-f{p4DDpXp?NO(U zw4g(Nnt{1ZnGDshE2xo@PEEa2x|c?_z`RxBrWt)XW7NnKt*SfMc39?tnUY@qdClCe5PZG zh?RRrLA5)%G{d*@DOG(iT6y<1ah?)Cv6F8!UKQ6e+je1U0na;cx2&TfSPs@1;A{wu zFpalBFrKMKRxyN#Y-W+cbRxaNy!_;wUE)TZ-yTDh91x@4X#O*xqsG^hkSDvsLrR+m zn2LE~53;Ixhx%QfOW}d0?cSqaAKjxwHOa>XXBzLhg6N$tyOO18VI7&rI6cQ?jV!XH znNr#cuOMO@IV{JAXE)P)-7f?UjOiA7&tCoA<^kAp8*1^wxex!xI*FAHs_1KbDtJv9 z+^4O?)3@f zaYuO}cl&lbqDb>A-5QpOIh3XY-zPkymjh@Hka7mfg|K1~5c(g4yurVu6tYw?C_%sd z(}_|heOp)gJQq%YR)@1%rbVn{gqiL~5Y9rJF7aAf+fi#bmVqZBuz1Zk#$Hco=00wo zNDhbb3W=sk$SU9-~z?e7S+{wIP6?WrozW%9376Srz&p^bzRT26y3gW zskpC(SFy!;`&IiNF$F}ab?_7>=4vU3cE$GX8IFk?iQk-r#TsN=9cZ34Rmai!mIP2_ z3b%;>vRAcK?N=uOrjyFZxRJgGP|mVvC$|oudV*CV&iGCsj?qa3^nxQw+NazcJZ zm+r^n+teP)8(+Y`Po~H2`nqN|WoapGg5d5-X1U6}qwke0dhFTOCMe2-S+5U1ZS{6R z<~>7r&$v;s#%gfxXZ$CFJEgwuKZA*JLm}ehC%#Nj*Fq=d=M0`%De>4HC~xf;xV5f( zM)R&-iLtF2KjcnjSBr5C! z89{SqZif!!9^;JOOOk|VSb2z3g0dgcm+JVug>|Nm66dfORkZben^A}y3_G8KQUGca zy+u8ZFg*+Ub4*Q!_bsx2q7CVB0C9J_Tnw^9bgQJ$dD2#`)y7X##X{5i;d1J|tWv9D zC`%Uqvj+6-jThoXa8?J@_5lu-eQe~pQLfb$A!q7QHW2uB@rV>?pB}pkuHW#PC83Lf zJ^_b6t4itwc3*}rI_Z10KdcN?eHTu(#;K@#r_qTOm2@YtnI@?b_c~+jwtNK$y zeo4;XN-MCN=0~C9z9<&XObW#OI5F}Ij`0RyDn%0&%eUK&X5g_wwC!&?@q9T9sKAz# zNm!RsC0^xg!@obZflof`Fh9)D@oPpQI1i+Cb-vLCrA*?HFGWll*6sKTyi*Faeq^rK zB;2I!H5CB%fHG~G2Q8Z+d*79~YGGaOy#ao)*^HcS-%i@=@lX<$=MV8}g)=!8@BfT; zgde}nT z#1BRH_@r=IqBPz+PlwVNsc-q!EN#+NWY=1W4+18ZoV%I!r`V?z2rmixD)%5wg`ar6 zhAm(1INq4ns_V|F$)uZ@lS|&iIHBiDk5O!6DVCAuG-y{+9$^D>{lU*&jx1;qWq`acP0n`CBjC>|RD7tv)SRhwg4@lkW z%+1AYj=&XqS$Bs#5jz+WPQox`TDP(h17w9lRz~-`&Ap==v_u`pS)0uj<>(m@Xd`K;?DE?-t=lF`T~~6vHtJt?Z^? zeXgzc^{966^neB5<5vDv55A(k2mr%6Kr?SgmWJfxF2!=1QILV=j4&(r+V`GWRJ$)) zKxC5MIvq5R(TKJLWmaBi1GKMj+GmzwT(`rO;2T@C<@Z--E*%h1^BoZr-=^$BP9JyJ zJPI4nHr)UU8h3mj5p;)k?%kX++0@_=H8#cP=g)+8j=Ny>NbyDZ50&`^m8ZXcld8}@ zIsWvK|JlR%`fQ$9Pl0JfZA;9Y6VW-MjZa;lZALz}0~(~5Oq(CgC4s7rKC9+eY_BcFpm zaD9w;Y?5PEB)lv!4ZK3V2$YJqr9aq48ablB{?K4Hg3?J}M~mXYq}Y|J(j)@k;8VQdhgMKzSK;-Av2N z&o1A{<0wN-=CK9kbNWUHoz`>6G#udMRhn0G2kn3_9MtI1PfCZCt>c0PKGVxNMl|np zrvCCW=V0H1x=Ify@dTcDm7*K<5cHVW>}NRyp7kJLd!U(==e9RJHi+z5D|JxPZXwHe z+vB8Pi{kx}a$92*3%{M3n_H@@IhjkUz=GccD=q%!hkmq+7h=8wuA>T^TY`=6An6*W1I?|OKTi{;J&6!jq zzN7kxkSCDw2h=p3Y755A?@YS~JT4!Uj=n5YJmM>Y(@|rW%4K%HIsPWK@B%bboHDv( zmp|Qk4yKco&tmmgCGzO%;^GUtQ`uAL$&&By+v;G&8qwa%2nPoC8n)bWU!o#&3ceL0m?UL9H?}rdzCg6rl#g?x3~#W` zeCl?Hvddk((whVhUa+$$drp)O!P-`-_USJofwAM}C%?xFMfIT6gTS7W{z_0KCG z`_mo^qM7J{x8L!6^qz@?4SPr_;rwZ@<%PVfCCbkuh)DTVE?mD4|1|i(bif_G-mgFy z+~_2j)lYc5@zci7X2gG7iYHBYJemI~61eisV$5fWUyUb-fHeb0IWo*4h&`=hv>AO# z-_5x3DAD|J<+IO1kC#4wyj1kWU5oCg7+YlkjlsQ8f!?-VMcii;^$&yC;?)#lMgQ9o z0w)rYoXr$|Faht82|xW@*rWI43^N+81PkY>?jgkDN69d6!p#4@)w_|swLZ@qMTj6>-H z@uf%rZ|8p8RsYC71@%vj925cz?=UbAd*q&w<`adVp_upiqyM$awqJcIe82ee2Zx03 z&rA0HycFq~0DgSbD_|j{lc>*sd0|2VfVWqaLs|Jhc>80S`HfI^P=7P%Da5qWBX1!d zmuX(VqRjP^?+lMh)RHI^leoDtT&V*{<=MiIIaPmMs^#wWo&s2XeiLHOui>IzvyVX% zXDCe^5|CvT|D!BpJ)Ul2{6m(0j|xOH`^a-z2ew}aSr*@Dc(io=cRGAGM%#hUxc7NE;xB>@pYl$cPCrR^9heX&D!e?0&nIpT0%(agEO& zZ>nDrzxDhlJtdE8?3okFX8J+yBYQ|?zx<;Rz)M*5_)NG*b+?Hh`F)7-_cOmd6S!Zr z5PDpvxWntYU$XweBCk@kKu%WEAAeKc*Ld?d^Z*F|w2!j1sT$fcd(=F*cYw}zk&FFe z`Y(ZoH9bD_=_9pQsv|#BWne_puXTQb0iM}U@<^-!N8J7&2W>v^D#HO-ei$DhG$tsNq$3MVa%GKkQ zsgI9bpE(x!^|wwSJtM6`rqg)|=z(e%5)BQ0Q`yis8Ls8SbnHdHnOxUz;(iwdK0f-x z&(BLrb>4Is&_NkI{mn*sz&$8n8uaJfj$wL`O7!*dH`B@ zY?~&}fn>A|%_b7&Q0y}W@h8=Reos0KN+uwlQ@uk_*ngRelau-GRr|EF>;6D_h! zYM$AN31YShgB8o)+`6Ib?cF^D)3xi zxhtnFb?6Tl@z<63aa(;@)^Ahw^QfksM#Rw^GQ3cCkZ=kD=PWw66f1gt?dNtq_G5|* zyllstpjCuxVRO_Dk%k1DaY<*xW^1vNE4NQGb{PkScN`u$%PejDg1MkO%O>`!i-Y&M zoIn@Bp6EA^6kt{=zzL_po6ooov?7lbn$h2CF*oita?2bV2r718g|FO}$9HjH&pl4+3s&e!gx$Q4jW93;;d!~6g6L6CI1*5=r1Y+i(IW5 z97?we-e%^ynHu1*JyUfMsy=a-wBaoKc7&ei$S`^2s(XXcc zjuR163akpZLPzfLMr}65(Og93w|#F{e(bRu1F_o;`L~ulA`8NOFoz9+Fo=>4EAjV} zV_oiO=g;ssa?c`&CD&B9ciHX`21XBoJ@XYl>buVTI6M7j=Zd3VV)0D97YXx4Vy=}A z#Co4x%oY2KUh*G^RXnTn=6X$Hw!nfqBan)#-YYs>r$kXRA`VpeXC_@R@hp*cImwgxy<3h!`k9;P@DzftGy1`pK7KzJ(^iwPqa zYki+bhr{?5VqU;&rT87zAy{Aj8*E`67})l7k0&JUvqHSOLUoWRCY(;DZ#~mPpv5%8^rZFP(=~|8UJgRsTUT7k!JW zw5F8LWLB#C&|y^xc4e$cD0U6xa624&w*S@_<@Uhm5uGD4SwQ~1z)=ag2p zagr`BEMxkdM?%wPtblBstu-6GT zrmoEz1+1#;z0YUIJWVLx`=TgsVKjMu&w1NK7u4CG&TakT3lxTa3|ad4dTSwS#-^}Z zsB@ul8V750SN{I;Wu68?G2=nJU{#;woP%ChP1R0&^~^|oqE&gYJrXdiw-;bDUkl#OEB&V0cq-^ArPIp1`8~1R;Np57!^N z{oALbyuG~}a`~j}_AtgoVmc7-kADSZ3xM)!wAUJPotc4w#6B1|?ij5vi z4FUZixtW{>lAET%nV=e}YytsJtOK#x$ClGJqNIR@wz3I0~J7>M0+u45d^|*1W(lc?iS-B(rTY^gDEw36>RdtKHFsi3* zbYiWZQL<(;T1^3cQnhEV{e6L7Li7FiV~Qe?{$kd*W&!hPix>i$87#P_{`eJO2D z>!RfaisvgHLbG($E#HLrZf%zLNfxH~s%!h#(FU0)SKZa-1l3dut(m39RZOCs?hS+j zoXPS+=U)~ajnRT9q*Ldw(ggTo3l zj7i>2_>V7sy}%--p~_Avf`{PJj}*BVTlI=Xp~jkky;J3LMADe?TvU9hW3aN$qZM3r zqv-^;uf)2Dbr%}lIoemOgKh8I6-awlPBVL)qDBnGG(_tzafq@O`bIVNaw8YVH>41#F8bAmmh1S`fAuF5XdgGm&@@5av& zZ?}0WsBUrNe@H1T*2}*bIif~izhhqfJ>norf`Y#aD^G^LUf;`FkuqHK@=pdRUl@D< z1?HkIxzvZ?0SU>khIZh6sykAD%N3RT-d48;vCjOi)WRVBo7DNrTtW@RwBL2`+*pWhRiO+{mj)N&QjW)$)bLFfs%Uf24`(AwgxWnUB?BbU70&%XGIkQgk zhM6UcZx)#IR)Ewo=6k{$2hh{$_;e0Z1JX~5J>|brlM7c8y zyDi+|y_Q1MUBgUG^Oi@*qQVnSdMg--BHZkt_}N*Ttzjl7-(gGw)_s{`ZtEtLj-15( zQ{D|Z8%U?g3;BZqd_7>Ut^nDD0D-X-JH2-wpIkZj8?;?$W^QpPc!kL)%E{R{g;@l* z!9wf)@cOo9D#YQ{PP>yOmj2fE+qL(yN*+;3JA1vXPM3+67xJFWiK zcV^whi@ueZ;|Hxt^yhG6v9jE#|`p@B8Oj44O*20jG&R^6`W`i%s}O z@Wc3O&+R5kUcvaXk@^V@o&=gA^21OsL^sr}|-?gHe3T;)49MEy{l1H|uB{LpLUkxfOP3X$O;) zcT$1XnMvIbv4>@i9%J{Nhg0BY(61%PhVd7^MM$434aWO&-uKBTtGe^Sr7swC&AE)%$7Mt|694_kEpD)`QpHp0F)1_;cRt z^Cu-{zAa|Cj$DgDYO1cTzKpP5Y40O3a%Z#OyPiGFhRZ1k4qFFa*lPGQ4C(2mOydys zrc9)+9tNBvdLUy6n-X({9%k{WfjJVr5xDVWZB>5SOb~2y@QrB1VAJ1mjg4Y9+DgtN z#jK%BF|?uySAwt}eQD2&T#o?Uoz1t;cN-1|b%T_1<1FFJSHJHTw}*7%q=_v$C97Es zZ20)9nKUl?T~fja75BBSYp>ahD4Etms8wRCiIjKq7c10IMWbPpPE?6wInU}DpiOlY zh^z1(lsxvYcg$DFs#Ur7Z-Gi50RryGN<4~Rpkb@_0PE58ss)M=opPz3%+iWG=rkCC zetj2KQ8UqXfRD^Ebk=@)6dpyQp<;HmdsV7zUb*}`hU@OjIQImNz18)1X;u!NgT`RtX^Avo7*=aGrcQG;+Vkqtr-%nf${l_RS}uYtCihb5Y(i{=Tnn16ha5b6P2dJ&r#2P= zc>eLTN*KvGn*o}ACdb|aqGn&PPfCB&1dWE4bC6eEskzBbS1i7jVeijal@#2ZZ z#St*+8;0lV(-@a1K=x=gg4YxEd3Qd4=A7({Ad(NBYMiIu^hnQ(W3{%gtEWpOk zJfNZuOR^wR-QJ|CZrk2#<8uPPUVRaJgpWa)u!|P$_WTlZ_?)DG&3I1?_Qg66GKYUm z94!%D1#*O5*`iBwI!XohVW(1(#jA<+)^FI9s)AOipXlMP>cB6F&`&oO-fS=y@59;M zGG^Q%V1L$Xy{Csq*;Xu8G8}p`_bI2sdt78`mh+?8THdD1og!2)co0q9f?=n$(C2>7 zY8V^HdmIv%?Rad=BxvWw*N0Dyq`kxM^|8xcwbCzU{O=YMuTC}&Zf0!<3UUtGk9!3y zOl}CrP%PY3dz;Nrxbi6GOuz@bRcBBSJn7O4Sf+~>x3G&9l@vd-EwdsnQnakxA9A23 z!o78`F?L3P89_^qj66YJO8s_i(XD_qTqeJn%coe07>qbQ7%ivABw{?J33$*;Sw=!1 zjVQmiZ)zY&6Eytolfn6xtITvl22+h)nV5S|D)z-$l88DYpdvwQI-+%GpZv3i6Eo3EXm&k*L z@6gV)z-Ef_UiK=2z_VUTr)eb1L!tDn#hV5rl*KXv(=}>%`m2C>Et+c=$;%hEoeyZy zXm9*V&fw1X2M!Jv?nDG9eVux-lwa8sNyVwz?Qi?{5*VxKlArbE^GJQ5#nud<~9Luvc?Nx#Q>?RoXn^gFZ zo?xNZsjNt(!rs0Is%M68<}I|g@+cRwOD&&B5FbSzB|~h&k=YtpiSdU1@JBW~$^hOwRh1&jXG1Zf(AOmDd&p?V%rkOOxvKT5|03!>9Y@GP;Gtu9 zhy@!qxHQgmRETRJtAaJ(hwM8xS=_-sHtx_XwZ6>s_X4~pOYVFxCkc#z=;vG4SAntB zF>UEXiybYn@gNk6f2V{|n<=kT5&^(sMN4c6Gy8s7MunR@pT^`P^2)$_om+n0@V7qx zd3|h#MyGnq+>kS1SoQkQ#?szL>K!Q(34Gr)lToT@pl&Cw&-Yx98ZPEI zw-GG6E0)_SPY9CS+zwuQcZ`Rwtq)`uT#=qa>6`mf>IR?DY>y}vu~apUuTQIdthSmR zb9!;Esg5c|tJ!$lmn9L;9UuwatOnLdmxZ^KbxblM8}e|0(;bflGm%xra5`6m45lJ@ z%d*Dh?HvL_I~Q-Y)cAo-L>a<_O7)n8jr z*AtTspS{f@pr@0hj(AB~->Q=hAKadsaCmgmCP|w@-!=CfOV3#2c)|c*Dsj_VYT}N{ zN0cmm7TkGzm7lRIB~Qq+8X&d|W~ysPQbur2~DCvXwP&~-LFW4nWvaj)VSOv~n!(KH87 zi^SQg8cSy0;=6Q48vnWq&L`^(He}#Uo+g7bW@*n|&X?oYHbW?Tg@f!#_8M*XDadQ} zb#c*w>Rl+SNO{UgqMtU;?DK~@!Vv*I`emjr1OdsXL%J|$Do{zfR+f+(V0eyx{K=pD zF|?fgx{YDsEcebxWF|5NPr^@O+^c!7Z;sS_n(D+BsE4dSCZE+O^yDhY^K_Ej-x+_m zbc1}V-N{C9|nm0N(Iw|Qd1_tM|FeGF_Oa%NfJ zBL2drIQjBZG2^(R3km5u+JMBIPbc>N<6&O@Ud{svWVOdj zuk4cC&+k1EL;V*|shWj-oIg*H#VG{bw+}k%?>y7j2BO3TWH$-|7=zWbYSb@3lFwpn zaBR{%63noxrFT!+mOa6BYm%NXHH|#Y_!Le>CV?L!>g=v~MGv-i|DtSA3Dp0cH~z^| znl7uy8hmOw_z+#stK3>C2#OEJyu+}&n>W-~8%&Ppb-0U0A;-x)b2sh6v89;D`0|ZB zRxCUjp;jB$Be8bWr~yL`JSkoE^hs6AIj0g~$ak8YJ}~XZ?^E2{+)S1bwl`pKLv=6O zZ?MiNFzWXEnKUNf&$rQR1h0Kc^qQyiOXtT*KHe@mY0x{?Y^7fcxe}2iwkEEP%g^v_ z?&9IxBAbLbd`Vo)4^zHETI0g+H=%MzFmH8q#a;j024$W|gr42E|2XQjb0{x%Wj^9~ zqH}>$bnC_EQ>?;kJvkJJQFQrZS%|r2g4*z~qL6m3 zdx|&nns5qYckkxCXGYOIeJret5BP^mO#0YWCE^x|u$jZ1FY@bJNH50~m$NyXs=TQd zD`>y9`#P2!kNKA4aF!5A(Q&m3jBB2jO4l*IiA`cU%`rZOXette#_f#mlsTO+$V;y5 zO}ObBF1=1|50@@Y+`SKY-ouO1X{qzOjQIlmtmRvWTpeZ3H9xukl3Hgdal*&!SMOp@ zOjWFI>ZaMSTPL$PPo~qNP4X>D7No!RrXclo++As?q{av~KLlUm ze2Tf+YSBWtdh3CruTR81dhAV0rBU+tvKhVE zgl01Rk$hgk`m?#YD^)f?I+=pif<&kEnd#;BI*?lJ8CH$;1Z@egj6;?G2>378xf+om|+NViR%pJ zV>)yvhy2L&w=d6Q_;!0PZL{s@zVi-d19^ewnQ2lMOXmzPo2dk0{?)$0qHc4`8dzUs zKeRSG)8K=*9oFRl#2&G+c~DbiamI_=XQz1uR{dF*+oEn0-OoD*JO@6O`Z15&MuP;BS*6g3W&bEZ&1=T>Ywpb)%PIDgtu&c$P}nra$?4;dtrvie9(m92HR}w z@xmi$7~`@r?~dKqZ02dL_X8bDR4LJm=ggWDP%R{cZ?egiO0*_Dkmyn!{OnU5yIOBv ze>hQMs{FLE=K1aRI=Q#VF+)e)H1OisdrA0#eob@l;JS4d+~&i2U+vachE6@~+>$h^ z7l;l}p+5Rro>u5eIPdNwZSqd5$ZA5K)LQsShqR7h8*vU*;1^OIDM&b>k8`Zf22kLI zEUhfCtt_By)+e-}l4PUsd|{cPY>tXnz?7Wjrs&{G!Xcc zTt_&dLXb5Uq|va7V-~T-wY3niv^9I0?8KlbwRV7(zy35CItg8Ht5g^n--fV+KDa~ajCa-poPos z<(ZmNeJcfDOD`KuZksA_6IB>cDiDtmx2vy0=^szn{S{`xbB*fl zII#xm^p+Nn75yZb9O)b}Y}{ye54ZrOhQ9C`H`s=c#JByYR&GPy`%s5@{PySgy7K>z zy}yczq|4TDVXSbca4)ow!rk579SSJiCGo_naCdiicP-qV!rk31fy39`d+%@e?(tv$ zXN+?dBl99Ma>a@jbH4L==c43#t7&f}ydJDJrXMi~18T?3JpQKtLqmcht-g_y z0iN5-%(SuAzP5^>7PXB15X7fQwK^({bxx7$%MsdLYCdw>b_vG+oF6nH%$rI(A3uaT zLdxtSzoyCFs@mglkQMmw6*CEbH_ljXMYpI*T!`Vxfwq|||0_pq3%>yjDnM{pSHB|8%`1-N96yF-+)jo2?%@HxbfI-qG#g)w z2>d8F&C?{jvh-AFU>KLVED9;CkCvAg-PV5HK;tuZSoT-&rr{aiTirl>DtH za)QE$GoZ|hfhzbNzw#6+LI`VCG6F*Z(nird9W^)iM9E=Nl#KR>IKZd*3Txa00$eE9 z%3LH`dxN$RulL8GO8GdYa-TuZydh*YvK3LrB8^R8pVKN+I1G6#LJJ|%G%hXaBP`e%Dp{}EBRhGHcddfY^$4H7-KJ|`J9LI+P z{-&U<@<2`@s?0JOPHDmB`sDJvX=7Qe%_Dp62KRQA-35SE5MxzVc}YrD2L1?&wdF%9 zH)raEJv^^-o`#@{zYtMKErU~Q4W6{=xf8@^l@X+FJ-36{>b>R$0ryrPh9R`$-!N3A zZUbS`NK{G&H=w4;=6PM5T?^$~qa%$6Y!y zGW^I=hq5TAeaIu_{#6e6>jj*NV@R0GwNEPuJyI*m@r4(g651*$4{ifuKj_Nuh2=a> zkb}@0qp#x^e#Q;dbV0gS;5q}Q%BwHCF(pj@X-Jzv$(Oh{rV2XDB%(_YzAGha2kfuH zEphi}Kdo(ip+u?XX|5frWWz4R8G`q(QTgxMJ4rLQs@uw!#N{cTHG276wj%Lg@F+x9 zkqqUvA=wVRVTXKmGFsLUuY&nTn8u3ysv| z=mV{~Kbn}B7@-y<5{2GL-_?=F(l1%5t2Dp@V&+nbGC(+S+=7wO}!5L=1> zw-8i)Spy#9b^4C1Zy|JNC3ik-uBV!Mzc2hg_4oH5BP4tG5wOnLQ zI5fdWdehmKI2nYcYfk@T2{17sY1bh${!a0Q?ed&(deKBE40(+na|8xts6IhLv_E#r zE)%s6!x6>l4FZ+j*0Yh0UA-3XP_-Ux(uSs6`tv{oB#zdlpK4OR}{ z<;>i~PDoPXMU7@jt!h5LEy|X#VJBUN3XD0Z=tW4$&q#R$-85UL2>71YCJ@gno~4cz zGYkkO(T%DSkB~o!TFzEAQnYeD{~DQq;{H`D8h35gv+UPP*qqmmE!50cITA857$Zmw zdJhiVJA-YM2~$z@;PZQ4oENP98f{DZ2a*6%d-{TCU=54)6Xvkpp@4Rba-lPSXeZKW zU5G9WA;@|1h5pEcYK(Zft#7kNn}7Xl5$82X3FI`E-Q8sNg=mAi#^M~M zWD%H9ycu$zH<<^w&+$KjMXQ-_3j_4Z2%J0mE)E@I*8)KQ*i_1A>)q~F||Zb9*l&MVHF z#`tU&sa>Pe98dGTMiT1IdFcx>eBJ2wOvMOrFH&;1Xble-)dJ=auHf3na*(ELed5Xm zCe+Kq9k3fHPbr}cNQ!&xW4H@9(#!}6n0DvY`bkxQw|69+rjRz5nbD;wr?9S+F6VPZ zv)yb=)6eP}uOr5Fk_Wt2TL~`9U?sZ&16yX2g2g@nI}}jEUZA&RmkNDP?{ewT z2pe>`-@vG^4u+iBe(|TFT}drwVfjx@;@h?b%s}D5H}<|!ur~eZ%C=fo(s`&K&ZtE~ z*6{9yP(~Z1Rp+|P7@r|uA4mSIVVXO3!AzN$LQz9(LCNyGG$}6R?Ouc%1YBteUc=*v zFW31L#{zBnoE+0uS076>mpP0`?dkUC)jwbHq+uBI4eF&zdZ1(p>}N-@V6Jolfv;1g z6(`2=9Wlh~FS#Co`?PiyUMvBmurreWx{ymE1&8)su8VU9tyQmtxTLn|%HI zPL0>- zcNjL#BR8PSK4t`9h20$`(iBJI*L%z@;yDkmXO`?hJK-by(8OE^pUo>p%%w;7(gWGE z0$1FcDT$kTqHsg2T_k(0B@a6_y;uPpA7}6E>SY9j+1LrdyI^I&f3x zl6vY(Y~`*1R|9RO$(PO5L|&H$KA+}mgNy9O&^RDi0=*LcI(TcQhwmH7Dl13;bZo5R zJ5k~6oKYNE)cqiM@a|R(7aiPm;!EjfS;Wr1yN;|7 zp$3ajQMF~Sp?&l)-q>amJMqFjbJc&NpM&E48T!bH&^Ch%{>#FtKn54@xE261)zfMm2e0`ul~ob?zCf=EtfHXzPT7V{ z%_?@br^_E6-k}N!v|p};IA_{bim?AC>#gC+zl2L(3wX5h6j0>>n#4`n1?A(I)8Q5U zOJ*g7|v#qG$MJT)mduxeLlC}>C#zugKXwn^3Vc4M>znVn{f`>;xaywGDv;6VCaNKbs zFVhRLnILW@O#>n$$l+ll{uTQ+EgQ&vRK(L>ukkKIiey~nB$}hbNq64LT+e|ryZf=X zE9leLVF$;^!9}B8ErxpMovWZ;x0XZw&4KIGExt4)blepkYW)>#YkLswG^O3I6AA^& zZTfh3;^H<)DV}0DbZ2Fvw~@h+wMe9dFYD9pRawv^5HMiFeS?RTpsdbS_@@FBjj0!dB6rQnrJf;cCOQPBRAFpJKM1N z6sx+ZxwwJv6Amt8M_($P^(ka}4mzv01zIO25$F1f2FU(a`C^2wH8LcT{+LMK%YJM& zWchG|tjOj*`S|{yZIOhZP>5Eb;Fiq)oUgD(wKM?ZH07G&8!DYX!P-D-N14BV)k2X!Y{t^n{)B-G|` z=d*qE0|`pH+kT7Q2FH_5oU~+Hvifr7CfN!}X~E4%bkYhWAc`p{3+m_`=6I;gMc?po z2EQ+;p&;JZ&5ab|0n@W;1o3%nb+Pz+hol!CBg8ibVPMFKdmOcn zP1FebHYM%%j|yhzpfx{bEN9xd9{@YbnZd&qsl%|{t_c5HG87Y2KT7uITy<*;bu)|J zhhTG+^$job#gkmmTXV3Yf2mE-oBi`sgZaG6l?aa3uQDInKxTwVc0qgs(~xCnBE^nL z|2n^CfNCt1@&N^AnX!^5mT!_D`_Ve&(KS8V^ofj*1BgVCaKq0T4Q4wU=~pSjqEW#* z$0#_anF_uFriIekXw&b=r|?2PC<9=40BPJvW=Vy_BiVz7gnN%|GhQU338`qsE$gwWUA;db!2bs)@ z3|&v82&_QjN)QHoIwlithO0#c+FhbCYO&xJkNc=$hrg!0oKwRtT#cS9VxJf;3A{nv zL)XgaLjIE=q3#EB9USJyA3BaS}xES+14|jq~Nd=2JOnnO{ig}%Q z+?#XEIkMnlAUngHYYS#bqzi2|axxM>_}-nEp63cUy=p@bBmNjv#QC$87GsT;Jc2$) zjC~p%O?rL#Aa?8pd5~{;R~q(9mid8mxkkTpOVS{tY3QBS25+ z5eoHqYh1ch_(U_>A|gRwP;V|qzvw4E1LOh#|DvTX5&xfv=m*>y3%*8?FBX64md6WW z5qR5VL^lV!-Z}&eU>_B&FU-uJ%EGKqZIS+-*m*d<>y0q_>Y9@jlY-)|b zz-LcWEwhiGF;$VZLHeHk?O9E4t_C^-o#*pk!UzWLkQN5YNaQe%q_JR?=4{?wP=Kq2 z1s3bas*`P#eT#oiv1GzX_>WT6e~p44p7ExS^zvtG#bI>QM^EZeUKb^aL^-Ag8fJUS zr~3U_jMJsi%=?8|3A-ODIp2qs@5zx~O-xY>W!2tlE*F+(dHJW_lp%3w1e)PUR z;1HUx5wz{IP%^vDn$P0WdUQ8HZ~#-({8(4f_2%)z18-}#8>bq_8U_pEf4!9dYe`5 z&es!zlP~tiKWoVU`C9%pgnrO{U>uzvfGFEQw@-Y%r{eNx^E(2F7joN2V_V^-*%owK zU#yLCZ88Fg%#hzLmcH zFP6}8ZPkXl4)c*zD;{*(eibxE;4o5mg6MI1X~2>I)OkN>8jk1($Wg}4CXIRBWb`H> zgynfO-Ff@Gy78{-I*f)TeY0BJM7OKgQ5&uh*RAd{`|-$8hm&(IHC3yS+eWgU#Xluu11hpX-CJ8R~(l zQ-1}W{*uc}Rp+zno8UxG2kqsCKM&_++CQoD|DWsGlrK(oMD(VKD7)}Fj@zW4=Y|ra zdlLc2Ajcn1{p)zYJY=84O8=5~M_quR3y1Le2Lz08kw0*O%eZE_R-vfg7;$g zMw}RwjE<1e`v)5G9`~4=Xrs`cBhuuN@bxQgSlElj&2}ZY%2Kn#BEh$PkUGhxA<6_>EluwFW zHl)|u46xpa5usS21Ld|NdVY6SK0|(ohwkACF7p9t4-lQ@C}GXw{p?<5TD%uL$?bwW z?9Pd;9IG>#oViU>9A_7-J91KYdrV7};vBi+x}T&Vo*EYgWS@F^{A3NC_;--=UsQ0h z7PPjaF3%-b@rfP6!fA}Ag-gGn2F>zk*-xfxL)bQ93xFVS?+ZsI)*J%h_@YCDlP`wT zLhb~)8N9%BwRXA};*Wq-tR+-ooqk?Wm{qGkv|^@y5n-p0aMho;QmkcFY?nEc zI;%E=f+n_wDuXV)u&}Uq4j<6t-cbk}+Pe9raT>czDv@tR*-p9X!zqi%^~t0lwD~+eY@=#3V`jFkU3Cx1w>){#l55 z{I(lzs*bkwq;)Ec&OiK4wU|10Zn)a}_Kw$^l0*2=Wlz>vP&g;glI4o2bh< zIXM}aeRZBoW-{JS$|*eT=)T(Y`cz7KbntucM^#5nx=Z1h{h7KluhkTB`u6r@8nCc z9i$Q!cYpiqHnf?Chog2fKGW-o7XIfIzn>OIrk`|Tx;>WN(c?{5z2ocoN#aS}eg*ll z_3=gSOgT7OJYBN`e`|!4*^T@{Z)g&YSY|f-*rq_h|2n^&dt_82%R=o;L(S4jNJxkj zx9a<;H-yHHWrSp!4CUSCujNdfsaBR`+YFlcv>3vov506_5z5ToA)dr74Pz(i@0D>$ z8`kD4J#iD6JnaWNM)en1t8pu?9f^4*JGbWTqbHWj$SJJjPUwS^_^u|mt*Qo|kmiUj zaJt9~{?vsaM^VB>} zAZAk4AmxO>Z660sFJ;fJ1I?TQ)|zy*lbseZvM}&3XBg`JiFbe&Fht z_gJRc%#G@IevZpJ?}A>|$2ya%5f!iZW1zp^hQg!soyp*nc^G|T!TMbOj7yOAXjwrw4idf2&jx+GlIbt%A|xCMQdZlb-f zH1rG%3eXW5nlID_?)ap@Xdfasgds+KwXZJx{7a6bWSaM~;gv4xw45B1Pwgs6MTt~l zljrPK>RRHb^R;T9r@Mgf5?tqrqlxr9_kJ)$kDXs$EeoT_ySg$5#tXmMB${K0#zu7B z9_J#@tIRsGe7d)HOiyfPmQ14R`V@k{IZ~`Wf}z?yDws)<3-B<}c0=Am&gQipDI=?BhQ=tj>?%D5hhk?3VpNx!#*ZxixaDRPhG$s{N<%}u@=e)(t!Nj0HquTvG(>$qIGat%E&9==HEe(&7z zN2FZ@I>6nCd58?aG-Wqt$VhbfP_mHk7j@R)y4s+6c_04*PHnYX9ARk%dR8REcid=* zh|b2(;Mvdq>h_a)4^$I$6d7VXqFPL1w1GumuT>p4_h-({^s0EV4Z9Mr`i^(rFLz87?CGbnKHw*HLYzMpi|+IBg$5=v~=!Ti~*CPOh>uSS=~pIQ^-d_oYlWyZ>^2k35Bvg?f2t_(xfpl;ZobY5BGQh{^Z{KBa!gK68a%wcTo^m#jy_8ac};B51llwqF91;kGh{#<=K8oc&cw;Vw@LbZ2P# z?hc>8l}*Trdk8&jUD-oF755KYN|mkuDpLDr(V}qV0v9~bLw;mI4uwcC&qBg6%N-%@ zR{t)ryMVthpjr(n4s2awVOYz)Iz17clbg@trUN9e3k^GItsTL3EnZL8Jf*E)M;S+2 zq*9`q9bB2%S!^WP=$rg1Spvrquysfn(j07cQ!d{!^1Qg#Npcl-TCLG19lqs;&p}gg zDCfN|W{>Ud@_42+IXkLU9og12%VasOFehFtT3+6-(BtS^rqANA!x#I}*nt&eT5KA> z2qJ4An21#nc%6MNbyw%avQ*Lc28bw5bNWiVDq}AO&33=o{`hU@3QAF@z%{f#!TPa< z)`;IQjrwa3`7fwA=GB06-odyHvqE;AC3ka)9_LXx4ra22r50}6tnKoOe^mT7%aH?M zbEW!8Lx%UkFdYUY`o`GYg}VfGmB)bNa@gdBVoKn5Vy-|Z(71pcFhFp-9PD3`*Rnj3 z)48CZ#)}kBCVua2#Ng~fO=f)O8F0>Jo(nRSJR$MHJ-0+4*7xiuB(6M%zv_?RB6G|j zkb{*DYz6G#e_65PF&5WKqo-(X*uH9SQ~TB>Ew|<0AHAYwlt5Re4k50qoh0aNBv2 zZl4c={LS*6n^pdyQ)+P9_oUOVEduWiX;wN(?~=$l4qEIw4PSp2TpWgx z&#Cij@JdqQ-#|s$PRpBZjGtg=_uIZ#-sn(W``x1_*qA3I2uMkyqAQR=biW?_GQzz7 z6Pj-cNnc(@C#u@Edh8Q%KrK5$u8ZjCoMr>A8)STK18QRNs=VN;_dl&0dfk+Qd)`Pt z0nn}ToNJ|ZCk!- z#`)Spd(4W=M>5>&3@NNIk^CmTEqTbT+qn*W`&|8Hl;PoCsnq4)GI?0SckeXZa|Ga-+4NG9_Qf2-2QfrV1m>~>nniOn8I~)R$(T8kla3PWO}7piO3uw#a>jy{eX^} zZE)$m_l}Dr)ycHYFsgIoqT6pZ&wkVOi8)AGBj?k`wPccT_-V;SYT<3Fk?#$tDE19y z-m?dPQ>CbSw*)6e)G6I6T_K(K9iLd|QLOb1#!X|sz+-WX>}r|rE4?w?@q~#KF#6(! z`F;A11oDkY=sKpkmLo8lRC&>aCHfmT%5jqA`{*XsCGQj`T7oAm-6x}z*0+1;WzHBI zQ6s{|pCq64tXd0!Fj)|HVNZgCE_~+jxnFknX*V^p9gBRidpI`vs#zql8*ad}xz|6v ze=6N^sV0=fwh{CKSA8>!5uMLzJ`&_sDv>4Jc7lU-k2T<}KvC_S5KCr<2CC(AyQmf# zsFm7;H!XkG)D?`6(%PngL>Thu1ydZ=0qdM$r8`^b@U6JgoIu4L^P!ApPL*m= z0`j6Q2B8EVQ9DtZN%55YHt$F#5<`IHRzlB4H~ds!Pjf0@QSk^yIu6`M%fM~Lf}!#H zwI>ULpMPgyt(Wmz%FOyZ;Ht$rmGtU;t>J+ zQ&NN{dQcMa)jDVMXt7ggMzhvlB3qI_&JP)A~q>4 z@dymbIHvk|p|DHKT#fHNTbv6&sNxQ*-0gw%-@fr<`dmP4A(Y@%RiPco$c|~-Jl2)2WjjAlRDM`on}mA1 zPEV~On;>@+Y@3t@AFUa^9XgkQNW#qkwIV2+VVOJysy+wxqYcjd6*dV`ZY{PS zeIA6iHuvK1e-Bi9H*9@pn~_Z#h+>FG42gNxix zsP{Ml@)VU6RROXWI!g0+50khayDUyY_wD-!Jdtl)bY6e8)gT0<@{X{b8gR6~Nqa9P zJvBXO{i2ue>LTaSY3eX3YT<=*1h$dJ@uOt%Z`u3M$-jNc^J-m{5#?McMOkQ_yZN*3 z4;f#s4S&{(F$`2An~ZM(rv>T@=3N>C22L1)YBzK>bMYxdHLm_e%yjd z1})3=E6V6FLGkQ&;brip764QBY~YgReN_o&UCN6yJCLSxbZ4T8spsss;q4nS5ab~K z>}d$pza9<3FYoGHCvFDE$A=V0PF`M@aYJ!;u=9(vPrD%(_q={Ei(mN#&M^GItF5vC zwGH3>qeZSGN&-t3cVhL*xVNJF2@h3`%Y)^^*LI10q~M$FWIg+XOQIH>nV9f(s;WRq zHP$)Z>dvD=HMn9n-_uEcj5gb5JeY?!{Gez%b;=;}MU31AA2j=UA#izgoW8p)s5kS( za(`uvUi&S|=EtP|Z3-osjw_dn#n-LruDeCz_p8>E_13j=ExTVeemuy4_iiA2C3H5; za2+8Jy9IGnRzHdypbKmrdpKyt#a2aVa{S;dg6Ep4?CceGNCz#<8Zt}w`{LSOQ@E!{M9|1P3vWs)!lwxXn= z%j9nYXje^=yKrd0)(43%sta$~7@Rom$mDt5T!z%mSadMxgh|;hQHTkzg6z2waa-W~ zKjKKIT3~E-4AWQQ>?(QTh<5eXV0hmicS0OTl7&(iPPGAE29f$6%I*io36)*Y$F9KGAoFGQ3~E?AR6*{AF3LLnZ`UQ2#*)(1fj0Zb@JCzk}sjL zI~{5QDha}efpi1uFDJuwN8N+$BTn0s zGhglmBI0h)Wi*h(*r{4(j_^nC_NPucT}8b&4(80cn{Xdq&J$I)fEeG6mf7`UrX4>bRGA#I3 z8TD%ckVi#E0Lpa$YQG2jOg2Vx)lj8?#bapz2%Ebd&T$x@0{+6OEG<6q*(hcAhzU5J zk~LSPa@eioUwFd@DL=Yp<)Xb75}oyRd`0{&Cy4TkMX$WzCu=#?n(5~$_XSEBnh*`cnihxo>}i+{%y~ZGn9#=6 zE9yG-NmtU0!uQ$j2%$XW>D``S8j7BL-iql79unmB<8=L# zD4SMm&3N}gw6aMu&ox(??PTkOU#`ZlE-(Ylw#2Da(Hc(MZH#O3N{-H>(f4xoGa5Xu#=FVhL261Od5&Hm=t`A1O6Dh=%hcotyn8+(2_$YSW@?DCh3 zhSU-W-!@z}br8&%7^8Ex(%BQHapS9qeSzHoH@O8d(pS_9;mS71;-uP$H}k@-VRT!p zMm$j(RYd7;>a95r`qhe_5J`Ey+uZ|e`H2aF7edRz{;p({Q6&9hZ zkvs}#tq@NNsN`6_BO&e3NDiEbn#g?a-I)|O*qiqv%mXcM@a}Tl(;(hnIU)jnTAL9} z#%ZiPEv-}VMtuE!u4ZlD_Zk6F(6C5s0RJlIYtjDYgsvCgzbZrcJ)}g{;|Z9z zdWc6OQy)V*1Y0~FguBTe9l;wAdrglLZ)zSNA(Y#BTYMp7BklDID1e$7R&_;;!VeaE znL6{dHW($^*-9H{4R)Ykb{J{Vbp1iCc|__9P7GwM4xk0NVUli7vlB+xF}M6X5yus}P>{U!Ft zRcSkSMqS+F5Y5uRBU*r@tl%cxLEBA8<|wKDm&?66Y=4OjVFU@Nm^SirbK*l^g?jMP zOv*r{uvT>L1hU{S5KI21T_8eM&62Z|yqvqW!w^`QpXRGs3-8MeUBFTH(>7D|O?TqK zP;>(mX$Bga)N%5OE>!!w=CxZ{k7q{r87ZQk{hySuWpG+yU;@`69nb<(J1BfLS0k}m zkv~-1eimWxOR%86;jU42qJ75DR1P}Os^=o+RF}T}II~MJ9e>^6! zl~tGUW=ogKS0USFGCxO=$H*-lquZ&@eXcz#Ies<<#YK=T8E;B<43xQ9ct@;T&j_b3%c75rQ^{dcgZ*p?&tu2Y_)Gc$!Lke~O$cEK1>$8RdP zlTdENPb&gJ@8In+;a1(9uqr+iAby=?ZT%gGVvqts$Z4b%K_8uH)1ui*EiWWYW@;@I- z`4=j@$GVrl8USoxsb#b&JBJes_;4xwG$-Nl?b(8)Y}(4M)8;x@a~cX0J$HP;oot9Y z*mvE1!A-k{FEr4lx;ne)!uEDq65&e{W%+c5C%d{C^GXS@U)AJbh)KSi{nNz%)fqlc zX{|Wx09XBT>+U5Oqc)aOzfP}r4;-Jqs=6n4Sh&#-Wz404DQDWL;Of zoG@p~STTm^J*@6p&z+`w=u*B*2Ad2?RK>$Rt8-agY(9Bu;OZ+5ah!JVB9!AkdNa9B4kw=lgux*x{UsXuz2*SN{$#(ZMp z<`0bAM5t0(ioC8r>d6cPzzt3m%2W(B36*eTGka(ym!i3=h`z{!L+h#>w2z}e z&uGx_=87L#Ak90qFxn4Zf$6H8wl5DRS6kMZ=eBG!I2(7X!!VvSAw=7mVI0VPc;uhu zJq4KFrAQNdnKIWCGH<{;H$EgWC=gt|W=ND0C!v?mwavONWwP4ya{@H0 z-u!x5&VI>0{*pKwcr@bZgK4)2nwfxmiWg4?KadSAzu69CEt3O!- zXPr91zRm}ayj(2c0!9!(cMG=SJ1Hn9+1%sE_0z&pIcp|uga;bQ4dUVC**bjw-yx+r z7Rb7QXEq7_I1XWFkrn-(8u1Jf*mUGbn1p54V6504Vmt*iGf7$onz^dmr$5hWAKSA} zJ1^9~iNi*=VPekjG9@gY9~T~dd&>(DZrzkG%ND#-!-X^5P4r<`#&$88jXXX!ER&8C z?%dECD|5EJG(f_X3wOM6XU%~C>f;_Qi)yQjYs-nBP>v{?h;U&XdL>c3lWl3Een{wk zz|T`$N%b&Cjmc37sKM;m^-_44dI2k9WTl%`mW14qR6RXH83+z8R4?~}12;|GTZWXy zrK&}drT1m!%c0)#X^91ymc}a2U0SA<4h6lME2KOdLuJ+9;KF{y{SdyIoRHdx;&!ya z(HOZ~rQDJmsXXW!s37M$>i^NVN^)4ht|pDH;V$fBhM|VNxKf&GwO(XZ&1~!`*Ce>u zQHk+NX#mz7CI8}-9t97hOzsBukw2~Z&iLrr6(G( zi9pijswu7dDUS&VElhd_Vlw!RArmdkox`o zjr+Gpj+AKQvt%cX%(lSK?j3uLb~k+b zgJ(j8`GxP7cvH#O{KeLU2t!Vg zQ#Dem>u-WK3r~syv+iyP=_o|tv_c`5^m!|8Uvf}>cfSBXN>;cie794vv3>Cp2dBu2 z^W2S*@kw%Bp~Z`JO`DmkjVBG$*VgpA^t36nt|cH6=v->j+kU+|vq}2mVXv!mxqZo_e|Oa3`}yWa z?cHXe#cd_RJ;Ol(-p+6!YxuD84Rg5y54%YzR{be&Fnat?Cym+oQ)XdtnBM@BUeY2& zgiM8RC20f(dco0iSt)_ActleAbF7t>2ng?hwel6wx#JBSFQ4_cGi{bct4>@k1iUiV zxLG|{Aa`qEAn~b14v;@}FoW=;dv9>jrhwL(%^}0pE7=>k%@#AHUYOI&Ftxk)WX+i<#PGQv3@ExXRtr~Qc)Y)sNAr;42B_CGNT%6{!Ev@ zp8r77{IUw$bo>Y8_T67~(q-U`*Lxs}VnprslK%6(iZf?tIV@kpW~W@rF*bJ~OSidV zO#4s&=b(wa@D&w7cjZ}^Q@IP`9St0G;)QeWmmN%HiTGO%5pL8O7mUzw$DOC=9+~Dr zf2&?7&8b!|s6Sog6UGte+@Agf_D)h>N3DyrAqsjrR4_(*2w4&{BiAFy_SBj%yvo03 zO8PGFq*M$xFb;WjQQ!rsOMx8O_k~$1HiC$dO4DrRCzNpV;Mz@P^c@hJCh;~Q-Vo`y z`Ri#)w3x#UyBj&{m=CO%U#2+ook&d330L?8U*!fn1thTvt%e>H9yuUngg&}yR6*80 z=Q7EqN!s})`p6znK65@m>Lo=Lww$9F5Iz&ctiLe(>yXq zTG_iZcjZm=wW!B)rF`+2aX@K;0$aYsYH*y)M-eCGV*dnS1gqd&uA+^yt^2Tg3+X~w z#pTdM^HvlCd}?@#4gu832)p~dbvb-{E3Ey^lr9&o(6Mb=B_$vyPesspx>jG0hroVW z(=iaqLVXDTAv9;fHbSHfG%5rOpsT9glkDswrnqE?4|4qsGMlnORr-g9Rsk)P%$9qM zNZ#JZB{lg112|vB_<2AMLm&w@^M zn+qhT7R7!AuqkOG#n*&0zae!N`_-1i{#(#)C};R#`ur$^^vNnq?pAHml0atd&v^QE z_?@%*!0N)1GegV4%)m{5?m|Q{()`7OTS>u;?gO{%hb9Q!s(9a@S8L)Hmt#i~JR%GM z*$uZT!H7=>$>29GL}?5&-MQLr&>#UF5-Et!Who*jsQ`ZjHo`KHs4;1I@LPtGDcGHR zy|(|xN6;)a>GGc`#!~1~QTT_#3BCN`s%+1CpFed{*LI@h5+<9;e0&(q4U9p{_%IaI zwzKI4JzsH{s$qq^L&x2;l~->VmQl6F6tO+9>bu=`TzJGp6?p2}4cDiU8c)dWb|KM) z#%El6ys=j;6fc~z2 z$g#J|7X|FKPIsD<f0cGqNFsc`Io7UOm@x{Sr4Gf6s?K9o=gq zIk-O~ujjwvosaoiavs(^sD|+VBON|fjQb~g;T`E4qi8WZySW z%4PQ2{C9)^)f7{NFDas}FMb$3Z-y$-5%t1HB-sx6_(PxA=v|b8GU2tb2|g_^T!O1M z8O+-|;Ym20eaoE6IopKeThZmU&>8$1o-lhnD~T{z=dgFk)|g)!-iGHX{LT4nG2a|H zkLoukF@?g$oXL}w?8k<-i^(b$F#j&~{DYN8HQMw5EhAc>g|OJ|;Z6~rGOwp=J(U+Y zg0vowuOc{0c)RRQdhp)|)+Y!;SO2afzf2K)&!D?I`OC(jY-Q=87taa8tS@WO zF5eWuhg_(#ya+#$is=_zJo`vJknG{dUuj|`!aW-5-YIy{!L#3!8O|5|BISwiH zuKzRw|J%jTofQYpXYqNhem;(r!B@eTU4mS?3$h+2A?8c`J94eNn4#XBQc7dmzw7G% zajqL8f~0Ubvvy>G-^o`ljnn#@6n&bz3sz6XL|!hR2n z9WY)pB_l0Bh1{Ji+AVHxZ@));$PDh7$i=kow^uVI-DSdWvRd32DJZdmx*X&d;>OER zNXl6$YWh2yjV{b7x>g_Ed1ptb%Z7c_@DK9R3ptX>@38gKX1jCgbkwveEjuG=)79lG z>+&|%znc|01*qr?3}0857N{>V)i5Vi{Hz^GA}qu4f?r?_aS+7oN`-=9$9Rb9Hc*lO zxb-SFNTVg=K9~zi+dn=yzeA1XI-rq!nO@r&!OizIzN*LmrzV0s*A&z2l^;d6JlsrX zc8d{A8<#(*(sy+bei1$?Yq_%%T;DD~$)9q{-?B?BTHwiW8QwJtEkz6_=6qE9WBH_yc;8}wZ{ruiTICvKl+ z{ePI;dyFK;GfjrZ_883ss^s2EIFR+->&SFa5!&$Z@Rd7YmDju%A=7yh#0J9wOSs;+ znWn)<;QC-}q|8Jk_LNFX%&80c^|{sh?;GnW-NQ-=R$Z-TwX)>3)M2EQMpA!-#28>{ z(x|PVVm8HoB2fAgL(rhM`2bW|wO z#?zFy%%Z}H{K8ooHNd3FOIO#%?7>_~-F%mPMy9)hOTw+dj0_iTr|eUd!pNA0`GKt% z#tP~7vviIkPP_F|En#Kraj9kGhI&L<=z;O`diFB?7N!U!Y zPbbX{K+F>^m#9mSIejl4J@w{WMMEFyJ!^Q)W@VJBazBet(9>f-%kos6p_n(;Pk%A3 zWz)^!f|6k%;CU;(75t+`QNAdi@ST>eCB-43ZHH@rP1nWFd1M?P1y7T+R*~*JZ|R0i zDTi?m(>X!fAem>+5Xew&NcR7<_nkpaZ)?B%*u|}&Qnn&ZK%_STAqWam0@8c$B^2qs ziAoWqg(f8+LMWj_LMIV5R3Q*bfB=zRLxfO-0J-e*p7*_5_srch_v4+LPir!>{%bvD zJ>~bTf)sH-ULiuwXmn;37Xj)zt7B3h&gX~tZ8zbtl<0i_^*nULN(|pTsf<zpNpA$M1Jm7- zZ+yNRyI;ceFTO=9i{YdATD6A48_Q}ZyEe&8sxRAK_@aE@>SJICo$9IHNy18#3mkRI zCiC5Z-}z3Sk$*@bZunUnpDQ!z05Z|W&9EMOK$Yk+^$t4})mbu2)&uvyK~gO{9IHkKbUWue%}#htsbkos=DgdGf+9&c}@Q4d*Uw6iu%AW zFlAE+F)xN}*1SfUTb8FDq`14gU*r!P{}gov<#)RCYt);MR8ZG??_zk;rYP6jJPjckdxDkB8X?&AUwiUVqObtNk!N_bi+JJKqrCZ75XrZZOFhQrsz4>xFF=VaDp87uD7RaNF@9z&6TdWDr-v-Po$|~d4lPp;T}@C z(_l$X>LJOoO1|fhB-dlHLUqD34Cpy+k`2rl6WfuKG|EmG1MFISPku=?*Eao*Vu|XY zxID3ot9qfK-*uf6DNFn+&BQ`)m5DAbbuIj!SoBp8TDrat)!mW^OUSLT+6vRLDHvhI zl%c%yzL;(uUtjCJQ3dm%K-NO{%_oo5M+Lg-_!hR>tjPo!acJPz`#Xr48|z8yg&no{ z!L3?4ES)6HlmA*e|GkQ0s~BEsg5E7eY+}Hc8L$amV$!bl6)7D-(O1Lo!Bz`Mroq0kFCn|uckzuLND_COy&#zVg$d$)fM8t2iiZ)dMTg{~Wvo|I+UKrT89bt+Qi!pp9*dAXsJ z1l*g_vP_p;p@Je6oev?6hHWX5RR&QI-8?UYa@s}8?W};l_0r~cx1N`ujpUK`sa4s$M;;S_tmc|PVuq9r9wcsBs}Zdmc%nwomysU;Gr%GX~v z6e$!yUHF7lR=_Pc+n4U}fZmMI0Myb$)wct^r}HkFx`%~_J$UHPEHkhdt)_G7SEJ2% zr6cOihH7RZhrc#&sf&>-MjEROW=mZNY;O%rhc>L@&QSVH!eQR-+JKUidg1gP<85qG zb=$TYK;Apca_YE-m9!N{(_X0DaS&XhBpefDGSiO;(+2Oo{{+yZ@bHBt?`)xQc6Dy7o^qaY%##`5Qv!~l9PozK%^cy8{ zfgJaz(vL5r!m2Qww3XQdm=k{9Q`h&+)LzM|%bNcH`iMgk2?YKm?8r>Y6?8h`Xhye`y zN)+jd-_1)U2u)1wsX$fQQa)#cI5ADVL(t6%jjyd)`0Rk8hY=X+%=sCKt$Ps{N8!-E z4ISSx#$^4|DlEj+7_db6aHP4M_wMPIGo{cUQrt zYPfg9#XX(T@$>|BFO0silW>rD{*UPwq^Do^4BcxBPbJZft6VkVe`k#D**0;T}*D zlYj#EE<0k*=gkv+cg{Si;fTxSOSp8oY>+aLeCnlrMl_dmpxe!#;;*i@ASP-SLGughZ-4%{>~{e^jp>MwWG2?5?c6J!d6bO> zt)1*l-NdQaT+ zF^lu~<334QoWQo|=P+q`e#t*+-6ft~E?O-DU@z04U#t^HHIz}9)TUXAnmR556eY!c zF_75ETrb?J&0#kCQ79Dfgyuid=zqJ^7ms(c|M;4^F*qs1nbf6wU|xK*IXANDrX$S| z+Eu_3@=81{Q1QT{H@VuDk$U?H>1R=ux{8V+J4!*Mo5V@%`y61R3TKD2d_!Z@HMQn5FNaW9S*)1(0p?& zw9&N1ceNres^#Xe0QqgX=`v^4L5g65b_C7zwN-`LDi)`K(Cw|t25c?UfG%pxtkI68 zK7VpLAv8buEAIX5=a~EwT|Bp4H_9{^?L1VQ#w&NDg+^7Haot&o+wH+Ft z^~zU?^nFmbFmyOxEtixH-%pKO+uic^O5}A7I9RG%88#iDRNdIqx)GL39FUPNGg3s~PuXn{h?HaX`XP!);>UEQd9cvS5E?n$={k7coKwHm9@d@5gPSP_QE zKZz8aFk%%jcu5JbUfA09>uHncjS8PGZCZvxxUCvNWQm9}_^Smwl-Z;T4Y}_wS~w&0 zm^MR>kA-LAiAx?T<%FX@6M46pR=g+%n#q-v&Wr8DmT>KlHk2P&2 z{Ms00gofM3mx|S*uoKmqtT=MmsS+y|pWc9%T|GRXD$9!5I#0JOobK2k&4!l*vZ`j^xzQ0!>92Ko7@^{GDT)Zv}$=t zh9>Ll>xDhFIq_y8Ra5rSKV?wtzeReV(c+t(#@_Ft9townWRWk{+o3?ic2yf+>JQO7 za&!OzA|2db)8+cYXBlMxIOG}rk(ZJjqh#*P%*;t_ z15XV6-lrix=Fqz!e*~U&y_6i4+3#A;)q~~Ts*o}+nSa!vepNj1JJzBj4 zvID)O@0@L9AehQLmeHPkeAv2xKDRurPqP&CUc;NPmmynnj?#R0=DELUQHo@EnSVg~ zRLnG2QkhecHHL4_E<~2=0i7hiGu5Sjrpe{KnM_s!ru8Z!voSe*S`^6t8DirJJ5utw z$0o~dfvM5+2rE%0sSJZ0wq`#|t&f{9O&vR|&OE?d2O;GZvKT%Ygg%c@VnMn^X$A;K z8|G{A)0cvGu0vy#lI!W^h0{s6xDcV;T@meV0zECTykd8EcW7a}A6@y#=((8U&w0j+ z$I_9~Kn0!I)*?8$RJOWNF}!Q|nwW=_eT(Z!7+-0tk~*ndJCf;YcySlY)H@Cjgn%NL z1-Hh;-c^2#a=~kMNJ_){3o_0xe>~$5+-%qxJ{eFwc{!f_x!81xdI%$ZL5c*9nQ`*~^8LTB9gT)E*6b zW$s@(6toS#V2sCq2=LT@@R%&n3Yr?G&lx5h;o(Lh&>p`>ydnKTglZI=5gD#aV#_b6;6W^sqA+*L)`XVUF0 zd2$M)V)5z5tP`zZx1p&ry%MBhv#J)11a9zCWmctr@4W2{C`s(mJRa7=Cr(BlK3WB? zvhqp_YFh|rU{1jfX;doujj8jb0qNSr>M0a~Y6TdfG0I&A{p8s$OOWw_=8x>+TF;ka zp#_e`G=S>XO&MVYmgE^|tCk0LuK;VFP+2)lG5v$Ki^w;;;UaPUFJk0(wfdIjr@4U% z@5Je+UudUG8I9(?)m|#bVc=Xl=L6L!YJ_4HM||>#r0EWteE(cY*M9hxr35RZu7bQW z-`r`HK|vwarqF$7qRGW}#{GuU?i^Ab*)G@u^l@!_Ttb?1&?)ep%k8lft~Q2g|41H& zXnAazQZeL(m< z#^-85ilafS74|k9_Qz@fBPd)BC6~?I+N@nm2fo6T#qqx(vp=O&@tFOgi<4^-iN4y2 z14TZ_+SbUq!zQoGjjZ?djP5;3OlEWHHj;?X9rpjOsqUo8ib>amF-9l=HWnJP&^r$e z4?bU8hM!-ylW})IUo-@bhwVKK)Z3oMEK2HMo|Qeo-A~GJSmATZ`mPBW?Th=8^Y8Vh zi26-hsH$Cm4I2`ONmN5~x8D|4<^dhJzuw8}lA8LyL$nO;>1vl5w;DOUq%Tf4+oBSQ zvo|QPj)3bFOvtJ7ReTpXzQ2}s3Sh%k_W>x=G23s_W~)fXI#Z}?@0IA5j`n5IQ)lPG zMWE^x3;~}Ao1fGwbIUmb07<~@uZDqj3urar&52?)6#%$So*&r=N5K>I!`tr^!yPM! zBxQxIgrkD?HPO>X`I&60Cx|B{BX1>!l?&(j)FYT70revgK$2w~5~6%uX?ku#K**f0 zeaPTGi>Sx;IKJG_aIiJpo{v0+jc5?2BAZHlK=3{|6T&rZMikpHTc>ehCB7__gR91m z;dG6CVW{MCwV!eKGa_k6sQ^#!9Z*G9VH+wXuRXgV>r}Be)`(N_v6Er~Fe-S@-WC5D zlZ&)Gx6yq&rDz=d{V4K#FQ#Q9RUe>R(kZbD81* zIP_LPi9Yq&GEtgMToUhQ5Bm6GUVr12kSX8F6aHWib(^hEq~}s0dsmWj=2t3{Of21P zoZ%f8^r&r{Elis~M)dXtwX7WMkQC~zo`LBd36t{A3?`%Z2M;BY#rrDs^P(fi{1Anm znWp2HGfJYotUlo%$L}WmX#L;=^_gRe#AmTo=v8fasF6QZqvuC)34LuEK}kl_1e~0e zaD&KLm;=PJVK#awG2aD5Aj$k>?~%nMlW2&_hPw?HUCOJ5e@;VghA?+I~OkE<-(DqaCo@Rp~>BRTHIFhj$MzY z9)=@7O;+k0`9t@aRm+~Prgr`0TwFtnZ_5ICF)E9>=m3oG8IDexJ&QE~FKgpziwe0h z4o0dCKyVVt6};TeI$YgkeqswcD?Nu9o+Bb20a`gp+MaPE&ddp+xC}p=Mfc~2>1F{g z9>RC3&B52sS}OaeW-V<-sO3yD3;V4a1ZbAd<{I1v8jpeeTB*378u%7T|Lr2?4ob1 zkI-9X?->}hrov(&uVZ4)MWNC*-qxkAT_^)5z>)*ov2tKF551#^B3d9y3o3QP-O=1Y zj_vFPNV$qF3xiDIlNrlc8Yh*r7qhyIT9uRJN-M`H4)v1}#%8ts15SXKIkiea8vyta zd-A@wF*$T|KYbJIV&Y(kP^Ol62Phi%9>);h=qk2cEERLI*puMvG$i1srik|v* zM;MsB-lBgfkB!Dk>q?MiH_;gs?##ye379>LTHDh}#oIhAcZr+{83UR{*5}E0R$M(R z(&Rdmhr)7iNYYQ%9=fl#>gGu%>YRKu92(zHMKtHZo2jfu_~M(0A6~GA01d1RgJsSa zMbE0cOW6&w?h7nNtNw<;ZI$hw(5k|H*TjucCNDDL$fh50qHErAa!YFFU~Rz*nIYSd zX;x4ZQx|x&^Ul>&()pVu&75l}5WbKdIrLjX=BG}WM?fWrxG7rz6#V21n)l{OH#yDd z9SlaN%?+0KT39eyvkZ&WCf|B^-VXFBP4Kf4SxrDy04i&$X6nWZ!A2|of~o6UrPw@P%=pK>A->_2CRm`{ zY1U{V^S*A&B-nCOMm#A)2o$19TNoiUZL}DnuGt6}23_qWzTfYg^Kh#P9kZ3o>=^hx z<+hEIp~QBGkt=>Bt~#Du0xy6LxJ&pTE8N|>R$HkzY$pt=o!{>ndl2QFxHn#SyFhEt z?ZAlDqL%YWN=<#)CxHyjUk@c!mCvG!o^krjDWoG5r{;@~4!=vRihh66g9|@=D7$we z@QCxEMLlQuV=Mr#d>T$_VHUN*v`ux+G5nJ=z9)v0avi)l~@ z4etl00$Lh;M9y{(52K)IOF3(VgvIhlyOJ*WBX#m>q)+i3Uw5s!?e3vC$;0E?whXsg z1Hb2(F-mB2C^7tLi_x|$Q-%6LQq3b%(#1X~mn`oL06L zM$up5Vmjv+FCY#bWg-S!c55TEfagT}Dc!nM+qY@{X&|4<+AzN&|Io?J8j)(GC;BrZlfGH|n7cAnz zCz$+X>CC+0#unxA#>hI>)<`sbXL6&|CrT;0&Brzfb95!G{2jiLAbB88{8n1Ex2>D33Y&pIwDld${9u_7vq2{}i z2qr04`znOpnyqZ|*Vb2~U$=CB2)!E#tWSIW3ovpTC!3rL zZT!v;rh-oVI~4)kJtUR#w_lANH3pVizlY&!ZZ&KV^hsoyKsUG6;(+fX;Y z(SQ`1Ikwx`a!rY%{ z%+oJ8jIy~etc2&=_XBwLfIOwX37j9B)>`!@dr3*ib%%Kgqnv=_byUswlxxSV%cce3d>1e6k!6_Cl6!bjy| zckVUM1dp+~kF1}i{E}<@f{hwqxqPJ9m3x9#&HJ%=D`;1R`BrP)l9Q_ynKp&#pr?*oJj+(eoedakFsQLwWPCJDCtA8LCarxQcb72?bOx(b2 zk~?nMjaMMA4WL*ZRiadxKGOpd$a#ns(mz=P40L9|d$JV)bbU+xCg5lnE9OD=fE>E1M%P&6pjxXa69``4n4G5!)guHo!b;4jJt zPyXl%bG6hZjNMuf6G#_U&r+WBycv3<;m(;3wG(^+%a6CYykaQ{T* zV8FS&XE0XiX=?0vdhTTUo?^JuWd~*6Jm}TTS|T6uzPrF9ew&}Bvwo?Du5JwCeeI8* z*zGKO3W0^4P}eMOtzA2S_znMl0O3ySnN%(vOZpP5Z~DiLOMfMxrs4W`wxPeR%?Fy+ zfx084`9KauiBq*WR9Z3-`12zfGCgPiqF9lLmyao?UTJr@exl~LNF*oS3|Z7nL!f)` zSvRRdTx=mAL))=o&ScUvOB7p>*nMPN!94c|+fNkr{Yg;<#PizE3gpZi>8Fp6E$ZYE z8*6jtv&s-lfY093Oh)Fy{d{lS_=m+8`SqNpz9uT>b9P$6$m7MV%!LT`+2r3Hwi~X) zLS;b9QXG+VKQZv<-9LW{&z}36)i@@j?NXg;@AikC%fRr==BF^H(^<#_q$!{p%5azd z!6N+$OcthnLtj|W-i9zV;L{}4y-%RnCEZ7Tc;!;k9i^!O>eb7y|H_TOPvbE=ZKO)t zMEuphqMZc5=?|!+`Ls&G`sM!u D0EteT From cabdb86f25c78ff554217ea3ef6f55e5b389c77d Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:14:54 +0800 Subject: [PATCH 018/156] [codex] Fuse packed x2 mul-add into fma2 in CUDA codegen (#2017) * fuse packed x2 mul-add into fma2 * document packed x2 fma2 fusion rationale --------- Co-authored-by: Zhiwen Mo --- src/target/codegen_cuda.cc | 217 ++++++++++++------ .../python/cuda/test_cuda_f32x2_intrinsics.py | 34 +++ 2 files changed, 175 insertions(+), 76 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index a1ee2e9a5c..00ccbf7600 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -24,6 +24,28 @@ namespace codegen { using namespace tvm::tl::codegen; using namespace ffi; +namespace { + +bool CanEmitPackedX2Math(DataType t) { + int lanes = t.lanes(); + if (lanes < 2 || lanes % 2 != 0) { + return false; + } + + if (t.is_bfloat16() || t.is_float16()) { + return true; + } + + if (t.is_float() && t.bits() == 32) { + Target cur_target = Target::Current(/*allow_not_defined=*/true); + return cur_target.defined() && tl::TargetHasSMVersionGE(cur_target, 100); + } + + return false; +} + +} // namespace + struct CUDAMath { std::string operator()(DataType t, std::string name) const { if (t.is_float()) { @@ -875,26 +897,38 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t, // lanes/2 independent x2 packed operations on consecutive pairs. int lanes = t.lanes(); if (lanes >= 2 && lanes % 2 == 0) { - bool is_f32x2 = t.is_float() && t.bits() == 32; bool is_bf16x2 = t.is_bfloat16(); bool is_fp16x2 = t.is_float16(); - - // For f32x2, only emit packed ops on SM100+ (no native instructions - // before that). For bf16x2/fp16x2, the C++ helpers always have fallbacks. - bool should_emit = false; - if (is_bf16x2 || is_fp16x2) { - should_emit = true; - } else if (is_f32x2) { - Target cur_target = Target::Current(/*allow_not_defined=*/true); - should_emit = - cur_target.defined() && tl::TargetHasSMVersionGE(cur_target, 100); - } - - if (should_emit) { - // Map TIR binary-op strings to tl:: packed helpers. - // Note: fma (ternary) and abs (unary) cannot appear here. + if (CanEmitPackedX2Math(t)) { std::string tl_func; - if (op == "+") + bool use_fma = false; + PrimExpr fma_a, fma_b, fma_c; + + if (op == "+") { + // Fuse packed mul+add here instead of relying on NVCC to recover + // packed FMA from tl::mul2/tl::add2 (or the underlying __fmul2 / + // __fadd2-style helpers). Once the pairwise ops are emitted as + // separate calls, NVCC does not reliably contract them back to fma2. + auto try_fuse_mul_add = [&](const PrimExpr &maybe_mul, + const PrimExpr &addend) -> bool { + const MulNode *mul = maybe_mul.as(); + if (mul == nullptr || mul->dtype != t || mul->a.dtype() != t || + mul->b.dtype() != t || addend.dtype() != t) { + return false; + } + tl_func = "fma2"; + use_fma = true; + fma_a = mul->a; + fma_b = mul->b; + fma_c = addend; + return true; + }; + if (!try_fuse_mul_add(lhs, rhs)) { + try_fuse_mul_add(rhs, lhs); + } + } + + if (tl_func.empty() && op == "+") tl_func = "add2"; else if (op == "-") tl_func = "sub2"; @@ -929,86 +963,95 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t, stream << ' ' << sret << ";\n"; int ssa_scope = BeginScope(); { - std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); - std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); + std::vector packed_vecs; + if (use_fma) { + packed_vecs = { + SSAGetID(PrintExpr(fma_a), fma_a.dtype()), + SSAGetID(PrintExpr(fma_b), fma_b.dtype()), + SSAGetID(PrintExpr(fma_c), fma_c.dtype()), + }; + } else { + packed_vecs = { + SSAGetID(PrintExpr(lhs), lhs.dtype()), + SSAGetID(PrintExpr(rhs), rhs.dtype()), + }; + } if (is_bf16x2 || is_fp16x2) { std::string native_type = is_bf16x2 ? "__nv_bfloat162" : "__half2"; + auto make_half_pair = [&](const std::string &vec_name, + const std::string &field, + int pair_offset) { + std::string pair = "tl::from_uint1<"; + pair += native_type; + pair += ">("; + if (lanes <= 8) { + pair += "*(uint1*)(&("; + pair += vec_name; + pair += "."; + pair += field; + pair += "))"; + } else { + pair += "*(((uint1*)(&("; + pair += vec_name; + pair += "."; + pair += field; + pair += "))) + "; + pair += std::to_string(pair_offset); + pair += ")"; + } + pair += ")"; + return pair; + }; for (int p = 0; p < num_pairs; ++p) { int field_idx = lanes <= 8 ? p : (p / 2); ICHECK_LT(field_idx, 4); int pair_offset = lanes <= 8 ? 0 : (p % 2); std::string field(1, access[field_idx]); - std::string pair_lhs = "tl::from_uint1<"; - pair_lhs += native_type; - pair_lhs += ">("; - if (lanes <= 8) { - pair_lhs += "*(uint1*)(&("; - pair_lhs += vlhs; - pair_lhs += "."; - pair_lhs += field; - pair_lhs += "))"; - } else { - pair_lhs += "*(((uint1*)(&("; - pair_lhs += vlhs; - pair_lhs += "."; - pair_lhs += field; - pair_lhs += "))) + "; - pair_lhs += std::to_string(pair_offset); - pair_lhs += ")"; - } - pair_lhs += ")"; - std::string pair_rhs = "tl::from_uint1<"; - pair_rhs += native_type; - pair_rhs += ">("; - if (lanes <= 8) { - pair_rhs += "*(uint1*)(&("; - pair_rhs += vrhs; - pair_rhs += "."; - pair_rhs += field; - pair_rhs += "))"; - } else { - pair_rhs += "*(((uint1*)(&("; - pair_rhs += vrhs; - pair_rhs += "."; - pair_rhs += field; - pair_rhs += "))) + "; - pair_rhs += std::to_string(pair_offset); - pair_rhs += ")"; + std::vector pair_args; + pair_args.reserve(packed_vecs.size()); + for (const auto &vec_name : packed_vecs) { + pair_args.push_back( + make_half_pair(vec_name, field, pair_offset)); } - pair_rhs += ")"; this->PrintIndent(); if (lanes <= 8) { stream << "*(uint1*)(&(" << sret << "." << field - << ")) = tl::to_uint1(tl::" << tl_func << "(" << pair_lhs - << ", " << pair_rhs << "));\n"; + << ")) = tl::to_uint1(tl::" << tl_func << "("; } else { stream << "*(((uint1*)(&(" << sret << "." << field << "))) + " << pair_offset << ") = tl::to_uint1(tl::" << tl_func - << "(" << pair_lhs << ", " << pair_rhs << "));\n"; + << "("; + } + stream << pair_args[0]; + for (size_t i = 1; i < pair_args.size(); ++i) { + stream << ", " << pair_args[i]; } + stream << "));\n"; } } else { // f32: apply tl::*2 on each consecutive pair of float fields, // reinterpreted as float2. + auto make_float_pair = [&](const std::string &vec_name, + const std::string &field) { + return "*(float2*)(&(" + vec_name + "." + field + "))"; + }; for (int p = 0; p < num_pairs; ++p) { int field_idx = lanes <= 4 ? (p * 2) : p; ICHECK_LT(field_idx, 4); std::string field(1, access[field_idx]); - std::string pair_lhs = "*(float2*)(&("; - pair_lhs += vlhs; - pair_lhs += "."; - pair_lhs += field; - pair_lhs += "))"; - std::string pair_rhs = "*(float2*)(&("; - pair_rhs += vrhs; - pair_rhs += "."; - pair_rhs += field; - pair_rhs += "))"; + std::vector pair_args; + pair_args.reserve(packed_vecs.size()); + for (const auto &vec_name : packed_vecs) { + pair_args.push_back(make_float_pair(vec_name, field)); + } this->PrintIndent(); stream << "*(float2*)(&(" << sret << "." << field - << ")) = tl::" << tl_func << "(" << pair_lhs << ", " - << pair_rhs << ");\n"; + << ")) = tl::" << tl_func << "(" << pair_args[0]; + for (size_t i = 1; i < pair_args.size(); ++i) { + stream << ", " << pair_args[i]; + } + stream << ");\n"; } } } @@ -3366,6 +3409,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { // the correct native type (__nv_bfloat162 or __half2) and cast the // result back to uint1 to avoid the ambiguous uint1 bridge overload. std::string op_name; + std::vector packed_args(op->args.begin(), op->args.end()); if (op->op.same_as(tl::add2())) op_name = "add2"; else if (op->op.same_as(tl::sub2())) @@ -3381,6 +3425,27 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { else op_name = "abs2"; + if (op->op.same_as(tl::add2()) && op->args.size() == 2) { + // Keep explicit packed helper trees on the same fused path for the + // same reason as PrintVecBinaryOp: NVCC will not reliably rewrite + // tl::mul2(...) + tl::add2(...) back into packed fma2 on its own. + auto try_fuse_mul_add = [&](const PrimExpr &mul_expr, + const PrimExpr &addend) -> bool { + const CallNode *mul_call = mul_expr.as(); + if (mul_call == nullptr || !mul_call->op.same_as(tl::mul2()) || + mul_call->args.size() != 2 || mul_call->dtype != op->dtype || + addend.dtype() != op->dtype) { + return false; + } + op_name = "fma2"; + packed_args = {mul_call->args[0], mul_call->args[1], addend}; + return true; + }; + if (!try_fuse_mul_add(op->args[0], op->args[1])) { + try_fuse_mul_add(op->args[1], op->args[0]); + } + } + DataType dtype = op->dtype; bool need_cast = dtype.is_bfloat16() || dtype.is_float16(); std::string native_type; @@ -3391,8 +3456,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } // Helper lambda to print a casted argument expression. - auto print_arg = [&](int idx) -> std::string { - std::string arg_str = PrintExpr(op->args[idx]); + auto print_arg = [&](const PrimExpr &arg) -> std::string { + std::string arg_str = PrintExpr(arg); if (need_cast) { return "tl::from_uint1<" + native_type + ">(" + arg_str + ")"; } @@ -3405,9 +3470,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << "tl::" << op_name << "("; } - os << print_arg(0); - for (size_t i = 1; i < op->args.size(); ++i) { - os << ", " << print_arg(i); + os << print_arg(packed_args[0]); + for (size_t i = 1; i < packed_args.size(); ++i) { + os << ", " << print_arg(packed_args[i]); } os << ")"; diff --git a/testing/python/cuda/test_cuda_f32x2_intrinsics.py b/testing/python/cuda/test_cuda_f32x2_intrinsics.py index 63043a84e7..6dfb2d8099 100644 --- a/testing/python/cuda/test_cuda_f32x2_intrinsics.py +++ b/testing/python/cuda/test_cuda_f32x2_intrinsics.py @@ -131,6 +131,23 @@ def main( return main +def _make_auto_vec_fma_kernel(dtype_tl, width: int = 4): + """Build a kernel that lets CUDA codegen fuse mul + add into tl::fma2.""" + + @T.prim_func + def main( + A: T.Tensor((M, width), dtype=dtype_tl), + B: T.Tensor((M, width), dtype=dtype_tl), + C: T.Tensor((M, width), dtype=dtype_tl), + D: T.Tensor((M, width), dtype=dtype_tl), + ): + with T.Kernel(1, 1, threads=M) as (bx, by): + for i, v in T.Parallel(M, width): + D[i, v] = A[i, v] * B[i, v] + C[i, v] + + return main + + # =================================================================== # Parametrised op / dtype lists # =================================================================== @@ -250,6 +267,23 @@ def test_codegen_auto_vec_f32_no_sm80(op_key): assert f"tl::{tl_func}" not in src, f"tl::{tl_func} should NOT appear in SM80 auto-vectorised CUDA source for float32 {op_key}" +@tilelang.testing.requires_cuda +def test_codegen_auto_vec_fma_f32(): + func = _make_auto_vec_fma_kernel(T.float32) + src = _lower_to_cuda_source(func, target=SM100_TARGET) + assert "tl::fma2" in src, "Expected tl::fma2 in SM100 auto-vectorised CUDA source for float32 mul+add" + + +@tilelang.testing.requires_cuda +@pytest.mark.parametrize("dtype_name", ["bfloat16", "float16"]) +def test_codegen_auto_vec_fma_half_types(dtype_name): + dtype_tl, _ = _DTYPE_MAP[dtype_name] + func = _make_auto_vec_fma_kernel(dtype_tl, width=8) + src = _lower_to_cuda_source(func, target=SM80_TARGET) + assert "tl::fma2" in src, f"Expected tl::fma2 in CUDA source for {dtype_name} mul+add" + assert _NATIVE_CAST_TYPE[dtype_name] in src, f"Expected {_NATIVE_CAST_TYPE[dtype_name]} cast in CUDA source for {dtype_name}" + + # bfloat16 / float16: auto-vectorization should emit tl::2 on any target # (the C++ helpers have compile-time arch fallbacks). @tilelang.testing.requires_cuda From 37c1c0c68866bddfbbab935b712577678e6ec2e7 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 7 Apr 2026 21:05:45 +0800 Subject: [PATCH 019/156] [codex] Reduce slow pytest runtime in testing/python (#2018) Optimize bitwise reduce test runtime --- .../python/math/test_math_bitwise_reduce.py | 109 ++++++++++-------- 1 file changed, 64 insertions(+), 45 deletions(-) diff --git a/testing/python/math/test_math_bitwise_reduce.py b/testing/python/math/test_math_bitwise_reduce.py index 044e0ea376..4e70169955 100644 --- a/testing/python/math/test_math_bitwise_reduce.py +++ b/testing/python/math/test_math_bitwise_reduce.py @@ -1,3 +1,4 @@ +import pytest import tilelang import tilelang.language as T import torch @@ -43,40 +44,13 @@ def reduce_func( def run_single_bitwise_reduce( name, func, + a, clear=True, ): M, N = 32, 32 block_M, block_N = 32, 32 kernel = bitwise_reduce(M, N, block_M, block_N, name, func, clear) - # Generate test data that exercises all bit patterns for robust bitwise reduce testing - a = torch.zeros((M, N), device="cuda", dtype=torch.int32) - - # Fill with patterns that will produce meaningful results for bitwise operations: - # - Different bit patterns across rows/columns - # - Mix of 0s and 1s in various positions - # - Some all-1s and all-0s patterns for edge cases - for i in range(M): - for j in range(N): - # Create varied bit patterns: - # Row-based pattern: alternating bits based on row index - row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row - - # Column-based pattern: different bit positions set based on column - col_pattern = 1 << (j % 31) # Single bit set at different positions - - # Combine patterns with XOR to create diverse bit distributions - # Add some deterministic "noise" based on position - position_factor = (i * N + j) % 256 - - # Final value combines all patterns - a[i, j] = (row_pattern ^ col_pattern ^ position_factor) & 0xFFFFFFFF - - if i % 4 == 0: - a[i, j] &= ~(0x1 << (i // 4)) - elif i % 2 == 0: - a[i, j] |= 0x1 << (i // 2) - if name == "reduce_bitand": expected = torch.full((M,), -1, device="cuda", dtype=torch.int32) elif name == "reduce_bitor" or name == "reduce_bitxor": @@ -86,28 +60,73 @@ def run_single_bitwise_reduce( output = kernel(a, expected) - for i in range(M): - for j in range(N): - if name == "reduce_bitand": - expected[i] = expected[i] & a[i, j] - elif name == "reduce_bitor": - expected[i] = expected[i] | a[i, j] - elif name == "reduce_bitxor": - expected[i] = expected[i] ^ a[i, j] - else: - raise ValueError("Invalid name: {}".format(name)) + expected = reference_bitwise_reduce(name, a) assert torch.all(output == expected) print("✓ {} with clear={} test passed".format(name, clear)) +def reference_bitwise_reduce(name, a): + if name == "reduce_bitand": + op = torch.bitwise_and + identity = -1 + elif name == "reduce_bitor": + op = torch.bitwise_or + identity = 0 + elif name == "reduce_bitxor": + op = torch.bitwise_xor + identity = 0 + else: + raise ValueError("Invalid name: {}".format(name)) + + reduced = a + while reduced.shape[1] > 1: + if reduced.shape[1] % 2: + padding = torch.full( + (reduced.shape[0], 1), + identity, + device=reduced.device, + dtype=reduced.dtype, + ) + reduced = torch.cat([reduced, padding], dim=1) + reduced = op(reduced[:, 0::2], reduced[:, 1::2]) + return reduced[:, 0] + + +@pytest.fixture(scope="module") +def bitwise_reduce_input(): + M, N = 32, 32 + rows = torch.arange(M, dtype=torch.int32)[:, None] + cols = torch.arange(N, dtype=torch.int32)[None, :] + + row_pattern = (rows & 0xF) << (rows % 4) + col_pattern = torch.bitwise_left_shift(torch.ones_like(cols), cols % 31) + position_factor = (rows * N + cols) % 256 + + a = row_pattern ^ col_pattern ^ position_factor + + clear_rows = (rows % 4) == 0 + clear_bits = torch.bitwise_left_shift(torch.ones_like(rows), rows // 4) + a = torch.where(clear_rows, a & torch.bitwise_not(clear_bits), a) + + set_rows = ((rows % 4) != 0) & ((rows % 2) == 0) + set_bits = torch.bitwise_left_shift(torch.ones_like(rows), rows // 2) + a = torch.where(set_rows, a | set_bits, a) + + return a.to(device="cuda") + + +BITWISE_REDUCE_OPS = [ + ("reduce_bitand", T.reduce_bitand), + ("reduce_bitor", T.reduce_bitor), + ("reduce_bitxor", T.reduce_bitxor), +] + + @tilelang.testing.requires_cuda -def test_bitwise_reduce_ops(): - run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=True) - run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=True) - run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=True) - run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=False) - run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=False) - run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=False) +@pytest.mark.parametrize(("name", "func"), BITWISE_REDUCE_OPS, ids=[name for name, _ in BITWISE_REDUCE_OPS]) +@pytest.mark.parametrize("clear", [True, False], ids=["clear", "no-clear"]) +def test_bitwise_reduce_ops(bitwise_reduce_input, name, func, clear): + run_single_bitwise_reduce(name, func, bitwise_reduce_input, clear=clear) if __name__ == "__main__": From 3ee0988c86536b5867a84543fc72f1271a75f013 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 7 Apr 2026 23:54:55 +0800 Subject: [PATCH 020/156] [Refactor][Pipeline] Run pipeline rewriting before layout inference and stabilize tiled WS (#2002) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor optimization pipeline in phase.py to streamline barrier handling - Consolidated the handling of shared barriers and pipeline planning by removing redundant conditional checks. - Ensured that `LowerSharedBarrier`, `PipelinePlanning`, and `InjectSoftwarePipeline` are consistently applied, enhancing the clarity and efficiency of the optimization process. This change improves the maintainability of the code while preserving existing functionality. * Enhance tile operation handling with instruction annotation and refactor gemm.h - Added a new `InstructionAnnotation` pass to annotate tile operations with their instruction kind before layout inference, improving the optimization pipeline's ability to reason about instruction mixes. - Refactored `gemm.h` to move the `allowTcgen5Mma` and `allowWgmma` methods under the private section, enhancing code organization and encapsulation. These changes improve the clarity and maintainability of the code while preserving existing functionality. * Enhance multi-version buffer handling and introduce tile-level warp specialization - Updated `multi_version_buffer_rewriter.cc` to improve read/write access detection for tile operations by analyzing `tl.tileop.region` calls, ensuring accurate buffer access tracking. - Modified `phase.py` to integrate `ProducerConsumerWarpSpecializedTiled` before layout inference, allowing for high-level tile-op IR transformations that enhance producer/consumer splits. - Added a new `ProducerConsumerWarpSpecializedTiled` function in `__init__.py` to facilitate tile-level warp specialization, improving the optimization pipeline's efficiency. These changes enhance the handling of multi-version buffers and optimize the transformation process for tiled operations. * Enable tiled WS for stage-1 pipelines * Move Hopper pipeline planning before layout inference * Run pipeline rewriting before layout inference * Add pipeline refactor WIP: subtree_modified_ guard, debug prints, and plan - inject_pipeline.cc: guard reads/writes recalculation with subtree_modified_ flag to prevent local.var buffer promotion to kernel parameters - phase.py: add temporary debug prints after PipelinePlanning and InjectSoftwarePipeline (to be removed during CI fix work) - example_group_per_split_token_cast_to_fp8.py: add disable_cache for debugging - docs/plan.md: implementation plan for fixing CI test failures - draft.md: original design draft Co-Authored-By: Claude Opus 4.6 (1M context) * Remove debug artifacts: prints in phase.py, disable_cache in example - Remove print("After PipelinePlanning"), print(mod), print("After InjectSoftwarePipeline"), print(mod) from LowerAndLegalize in tilelang/engine/phase.py - Remove tilelang.disable_cache() from examples/cast/example_group_per_split_token_cast_to_fp8.py Co-Authored-By: Claude Opus 4.6 (1M context) * Fix LayoutInference crash on multi-versioned pipeline buffers After InjectSoftwarePipeline, multi-versioned buffers share the same data Var as the original but have an extra leading dimension (num_stages). LayoutInference's alias propagation and annotation handling tried to Reshape layouts between these buffers, which failed because the total element counts differ. Guard three Reshape call sites in layout_inference.cc to skip sibling buffers whose total storage size is incompatible with the source layout. This lets multi-versioned buffers get their own layout inference instead of inheriting an incompatible layout from the original buffer. Fixes compilation failures in dequantize_gemm, GDN, and other kernels that use software pipelining with shared memory buffers. Co-Authored-By: Claude Opus 4.6 (1M context) * Strip pipeline annotations on WS fallback for TMA kernels When ProducerConsumerWarpSpecializedTiled identifies a TMA kernel as a warp-specialization candidate but the tiled rewriter cannot handle it (e.g., conditional loop bodies like sparse block masks), the fallback previously returned the original function with num_stages annotations intact. PipelinePlanning and InjectSoftwarePipeline would then generate non-WS TMA pipeline code with broken barrier phase tracking for conditional pipeline bodies (barrier waits outside conditionals cause deadlocks when the condition is false). Fix: on WS fallback, strip num_stages annotations from pipeline loops so that the pipeline passes skip the function. The kernel runs unpipelined but correctly. Fixes CUDA_ERROR_LAUNCH_FAILED in blocksparse_gemm and related TMA kernels with conditional loop bodies. Co-Authored-By: Claude Opus 4.6 (1M context) * Add guarded phase-counter support to tiled WS pass Port PhaseCounter and StageExprReplacer from the legacy ProducerConsumerWarpSpecialized pass into the tiled WS pass to handle conditional loop bodies (e.g., sparse block masks). When the pipeline loop body is wrapped in an IfThenElse without else: 1. Unwrap the condition before classifying statements 2. Create separate producer/consumer PhaseCounters (local int32 buffers) 3. Use counter-based stage/parity expressions instead of loop-variable 4. Wrap producer and consumer bodies in the original condition 5. Increment counters at end of each guarded iteration 6. Rewrite shared-buffer stage indices via StageExprReplacer This ensures barrier parity stays correct when iterations are conditionally skipped, fixing CUDA_ERROR_LAUNCH_FAILED in blocksparse_gemm and related TMA kernels with conditional execution. Co-Authored-By: Claude Opus 4.6 (1M context) * Add GemmSPNode support to pipeline planning and injection Add GemmSPNode handling to: - inject_pipeline.cc AddReadsWritesForTileOp: model A, E, B as reads and C as write (E is the sparse metadata buffer) - pipeline_planning.cc: same access model for dependency analysis This makes sparse GEMM visible to the pipeline machinery for correct stage assignment and buffer multi-versioning. However, the tile-op's consumer-side buffer accesses still don't get stage-indexed because the pipeline body rewriter can't rewrite high-level tile-op Call arguments (an architectural limitation of running InjectSoftwarePipeline before LowerTileOp). Co-Authored-By: Claude Opus 4.6 (1M context) * Fix GemmSPNode::Lower to use MakeAccessPtrFromRegion Replace whole-buffer access_ptr(1) calls with MakeAccessPtrFromRegion for A, B, C, and E buffers in sparse GEMM lowering. This preserves stage-specific region offsets from pipeline multi-versioning, matching the dense GemmNode::Lower pattern. CUDA output now shows correct stage-indexed consumer accesses: gemm_sp_ss(..., (k%3)*8192, (k%3)*8192+27648, C_local, (k%3)*2048+49152) instead of always using stage-0 offsets. Note: gemm_sp still produces incorrect results because the kernel needs warp specialization but TiledWSCandidate::Check doesn't recognize it as a TMA candidate. The non-WS pipeline path generates structurally different code from the reference WS path. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix gemm_sp and seer_attention regressions Three changes in producer_consumer_ws_tiled.cc: 1. Require num_stages >= 2 for WS candidacy (was >= 1). Single-stage kernels like seer_attention don't need WS and the transformation produces incorrect results for them. 2. Add HasTmaPipeline() check to detect TMA kernels with pipeline annotations that are rejected by the full WS candidate check (e.g., kernels with manual layout annotations like gemm_sp). 3. Strip num_stages annotations for rejected TMA pipeline kernels to prevent InjectSoftwarePipeline from generating broken non-WS TMA pipeline code. One change in gemm_sp.cc: - Use MakeAccessPtrFromRegion for A, B, C, E buffer access pointers instead of whole-buffer access_ptr. This preserves stage-specific region offsets from pipeline multi-versioning. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix WS conditional body unwrapping for complex loop bodies Only unwrap IfThenElse wrapper when the then-branch is a simple flat sequence of tile-op Evaluate calls. Skip unwrapping for complex bodies with LetStmt, For, or other control flow that could break variable scoping when split into producer/consumer for WS. Fixes variable-used-before-definition error in blocksparse_attention sparse_gqa_decode_varlen_indice, which has a conditional loop body containing LetStmt bindings. Co-Authored-By: Claude Opus 4.6 (1M context) * Revert num_stages >= 2 to >= 1 for WS candidacy The >= 2 threshold broke test_num_stages_one_pure_tma_keeps_auto_warp_specialize. Pure TMA kernels with num_stages=1 should still be WS candidates. The seer_attention issue (num_stages=1 with manual layout) is handled by the has_manual_layout_ check, not the num_stages threshold. Co-Authored-By: Claude Opus 4.6 (1M context) * Propagate manual layouts onto MVB-expanded buffers via Layout::Expand When LayoutInference encounters an MVB-expanded buffer (with leading stage dimensions) whose trailing dimensions match the original layout, use Layout::Expand to propagate the manual layout instead of rejecting or skipping the buffer. Applied to all 3 layout propagation paths: annotated layout map, alias propagation, and finalization. Also remove the blanket !has_manual_layout_ WS candidate rejection since manual layouts now survive onto versioned shared buffers via Layout::Expand. Fixes test_sparse_ws_regular_metadata_copy_stays_in_producer. Co-Authored-By: Claude Opus 4.6 (1M context) * Extend conditional body unwrapping for LetStmt chains and restore has_manual_layout_ guard 1. LetStmt chain peeling: when IfThenElse then_case starts with LetStmt bindings, peel them and append to let_bindings before checking the simple-body guard. This allows WS for conditional bodies with variable definitions (e.g., sparse attention patterns). 2. Restore !has_manual_layout_ in WS candidacy check: removing it caused dequant_groupedgemm_bf16_mxfp4_hopper to fail because MXFP4 layouts don't survive MVB expansion. The Layout::Expand fix handles sparse metadata layouts but not all manual layout types. 3. Layout::Expand propagation (from Round 6) remains in place for future use when MVB learns to handle all manual layout types. Co-Authored-By: Claude Opus 4.6 (1M context) * Remove is_simple_body guard from conditional WS unwrapping The simple-body guard only accepted flat Evaluate sequences inside IfThenElse, blocking legitimate WS for complex conditional bodies like sparse flash attention (T.clear, T.reduce_max, T.gemm inside the guard). The LetStmt peeling already handles variable scoping. Fixes test_pure_tma_consumer_local_init_does_not_leak_into_producer and test_sparse_ws_regular_metadata_copy_stays_in_producer. The remaining test_mixed_tma_cp_async_shared_stage_barriers failure is a pre-existing issue on the original branch: the tiled WS pass produces SIMT copies instead of cp.async because LowerPTXAsyncCopy is not in the pass pipeline (the comment at phase.py:308 says it runs earlier but no actual call exists on either branch or reference). Co-Authored-By: Claude Opus 4.6 (1M context) * Preserve pipeline context for WS producer cp.async injection Wrap the WS producer loop body in a kPipelineContextNumStages AttrStmt so that LowerTileOp's pipelined_depth_ is > 0 when processing SIMT producer copies. This enables InjectPTXAsyncCopy to generate cp.async for global-to-shared copies in the WS producer branch. Without this, the WS rewriter strips all pipeline annotations from the rewritten loops, causing LowerTileOp to skip cp.async injection for SIMT producers. The consumer loop stays annotation-free since it doesn't need cp.async. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix mixed barrier protocol, consumer-init sinking, restore manual-layout guard 1. Mixed TMA+cp.async barrier: use ptx_cp_async_barrier_noinc for forward barrier arrival in mixed producer groups, matching the reference producer_consumer_ws.cc protocol. 2. Consumer-only pre-loop init sinking: in ReplacePipelineLoopInStmt, guard pre-loop siblings as consumer-only when they're not classified as producer (TMA/SIMT/cp.async). Fragment init (T.fill, T.clear) and local buffer init are placed in the consumer branch instead of the shared prelude. 3. Restored blanket has_manual_layout_ guard: the dtype-based heuristic to distinguish sparse metadata from MXFP4 layouts doesn't work because both use uint8. Dequant_groupedgemm_bf16_mxfp4_hopper requires the guard to prevent broken WS. Co-Authored-By: Claude Opus 4.6 (1M context) * Narrow consumer-init sinking, targeted manual-layout check, per-group barrier 1. Consumer-init sinking: only sink pre-loop stmts that are FillNode writing to fragment/local buffers. Keep block_mask setup and shared state in the shared prelude. 2. Manual-layout: attempt targeted check that only rejects when TMA copy destinations match layout_map entries. Collect layout_map vars and compare against TMA copy destinations. 3. Per-group cp.async barrier: use group-level cp.async flag (single group for now) instead of function-wide boolean. The 3 WS issue tests still fail because the layout_map annotation parsing falls through to the conservative rejection path. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix layout_map parsing with Var→Buffer mapping for manual-layout check The layout_map annotation uses Map before LayoutInference (not Map). Parse as Map and handle both key types: Buffer (post-inference) and Var (pre-inference). For Var keys, look up the corresponding alloc_buffer by data Var match. Compare collected manual-layout buffers against pipeline copy destinations: reject only when a manually-laid-out buffer is also a producer copy target (TMA/SIMT/cp.async) inside the pipeline. This allows sparse metadata (E_shared, SIMT-copied) onto the WS path while rejecting dequant MXFP4 (B_shared, SIMT-copied with swizzle). Co-Authored-By: Claude Opus 4.6 (1M context) * Fix SEGFAULT in manual-layout check, use DetectSwizzleMode 1. Fixed SEGFAULT: removed dangling layout_map_layouts_ vector that was never populated, causing OOB access. Now stores Buffer+Layout pairs in layout_map_entries_. 2. Use DetectSwizzleMode to distinguish swizzled layouts (MXFP4, incompatible with MVB) from non-swizzled (sparse metadata, safe). Swizzled layouts reject WS candidacy; non-swizzled layouts allow it. 3. Removed debug LOG(WARNING) from hot path. 4. Parse layout_map annotation keys as both Buffer and Var (via Map), resolving Var keys to alloc_buffers. Results: sparse metadata WS test PASSES, dequant PASSES (protected). Down to 2 failures: mixed barrier pattern + consumer init sinking. Co-Authored-By: Claude Opus 4.6 (1M context) * Simplify consumer-init sinking to Evaluate-only classification Sink pre-loop Evaluate nodes classified as kConsumer into consumer branch. Keep For loops (block_mask setup), producer copies, and other control flow in the shared prelude. This is simpler and safer than the FillNode scope check approach. The test_pure_tma_consumer_local_init test still fails because the T.fill statements are at a different structural level than the SeqStmt where the pipeline loop lives. Fixing this requires deeper IR structure analysis. Co-Authored-By: Claude Opus 4.6 (1M context) * Sink consumer-only pre-loop init into WS consumer branch Extract consumer-only pre-loop Evaluate statements (T.fill on fragments) from the shared prelude and prepend them to the consumer branch inside the WS if/else structure. This ensures fragment init like acc_o, logsum, scores_max appears only in the consumer branch, not in the shared prelude or producer branch. Uses a two-pass approach: first ReplacePipelineLoopInStmt extracts consumer-only stmts into extracted_consumer_init_, then the WS body is rebuilt with the extracted stmts prepended to the consumer branch. Fixes test_pure_tma_consumer_local_init_does_not_leak_into_producer. Co-Authored-By: Claude Opus 4.6 (1M context) * Finalize pipeline refactor fixes * Refactor CopyNode stride checks for TMA bulk load/store - Extracted the global stride validation logic into a new static method `CheckGlobalStrides` for better readability and reusability. - Updated `CheckBulkLoad` and `CheckBulkStore` methods to utilize the new stride checking function, improving code clarity and maintainability. - Enhanced documentation for the new method to clarify its purpose and requirements. * Remove redundant copy direction notes from documentation in `copy.h` for stride checks. This simplifies the comments while maintaining clarity on TMA requirements. * refactor * layout related fix * refactor mbarrier with software pipeline. * remove legacy pass * Enhance TMA barrier handling and merging logic - Updated `CopyToTmaCopyRewriter` to conditionally emit arrive barriers based on the last TMA copy. - Refactored barrier creation to allow merging of TMA barriers when conditions are met, reducing the number of barriers needed. - Adjusted consumer and producer logic to accommodate merged barriers, ensuring correct synchronization across TMA copies. - Improved documentation and comments for clarity on the new barrier merging behavior. * Update example_mhc_post.py to disable main execution and add test function call - Commented out the main function call in `example_mhc_post.py`. - Added a call to `tilelang.disable_cache()` and a new `test(n=4096, h=2560)` function for testing purposes. * Enhance pipeline barrier handling and TMA copy detection - Updated `RewritePipelineTmaBarriers` to accept additional parameters for loop variable and stage count, improving barrier synchronization. - Modified `HandleTileOp` to ensure only user-defined T.copy patterns are marked for pipeline barriers, preventing interference with T.tma_copy. - Adjusted test cases to enforce specific CUDA compute version requirements for better compatibility and consistency. * lint fix * Enhance convolution example and TMA handling - Added kernel source printing in `example_convolution.py` for debugging. - Disabled the main function call and invoked `run_regression_perf()` to streamline execution. - Introduced `MakeTmaLeaderCondition` function in `copy.cc` to improve TMA leader-thread condition handling. - Updated `LowerBulkCopy` and `LowerBulkCopy1D` methods to utilize the new TMA leader condition for better thread management. - Enhanced `Conv2DIm2ColOpNode` to support barrier annotations, improving synchronization in TMA operations. * enhance * Remove example_dequant_groupedgemm_bf16_mxfp4_hopper.py and clean up phase.py by removing unnecessary whitespace. Enhance TMA handling in producer_consumer_ws_tiled.cc to improve barrier synchronization and streamline TMA copy operations. * fix * fix * Eliminate MVB(barrier_only=true) late fixup from OptimizeForTarget Move pipeline barrier ownership into InjectSoftwarePipeline: create pipeline_mbar[num_stages] at final expanded size instead of pipeline_mbar[1] that required late MVB expansion. Key changes: - RewritePipelineTmaBarriers creates barriers at num_stages size - Barrier indices use FloorMod(loop_var - loop_min, num_stages) - barrier_init has num_stages entries (one per slot, arrive_count=1) - CopyToTmaCopyRewriter accepts PrimExpr barrier_id (was int) - phase.py: remove MVB(barrier_only=True), unify both paths under PlanAndUpdateBufferAllocationLocation - Fix pre-existing test expectation for tma_copies annotation Co-Authored-By: Claude Opus 4.6 (1M context) * Fix Codex review issues: num_stages=1 regression, im2col pipeline barriers, test fixes 1. Fix num_stages=1 regression: use tl_pipelined_num_stages annotation for barrier sizing instead of max_stage+1. Gate barrier creation on pipeline_depth > 1 so num_stages=1 kernels don't get multi-versioned pipeline barriers. 2. Extend RewritePipelineTmaBarriers to handle c2d_im2col: annotate im2col calls with pipeline barrier in CopyToTmaCopyRewriter. Add im2col TMA recognition to PipelinePlanning. Fix im2col Lower() to respect emit_arrive annotation from pipeline barriers. 3. Fix test_simple_pipeline: use annotation-checking approach instead of structural equality with hardcoded tma_copies annotation that fails on non-Hopper targets. 4. Remove dead kPipelineMVBStageExpr/kPipelineMVBParityExpr/ kPipelineMVBContextNumStages emission from inject_pipeline.cc since pipeline barriers are now created at final size. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix depth-1 barrier ownership, add regression tests, remove dead kPipelineMVB* 1. Fix depth-1 ownership gap: always create shared pipeline barrier for TMA copies even when pipeline_depth=1. This prevents LowerTileOp from allocating separate per-copy internal barriers, keeping num_stages=1 kernels at pipeline_mbar[1] (single slot). 2. Add checked-in lowering regression tests: - non-WS num_stages=3 TMA GEMM → asserts pipeline_mbar[3], no fallback - non-WS num_stages=1 TMA GEMM → asserts pipeline_mbar[1], no multi-slot - non-WS num_stages=3 im2col → asserts pipeline_mbar[3] feeds tma_load_im2col 3. Remove dead kPipelineMVB* definitions from pipeline_utils.h and consumption code from multi_version_buffer_rewriter.cc (stacks, explicit parity/version index logic, attr stripping). Co-Authored-By: Claude Opus 4.6 (1M context) * Fix GemmWMMA.lower() signature and autotuner cache backwards compatibility [P1] Add missing mbar_phase_expr parameter to GemmWMMA.lower() to match the interface expected by GemmPy.lower() dispatch. Without this, any RDNA kernel using the WMMA path would fail with TypeError. [P2] Make out_idx.json loading optional in autotuner load_from_disk(). Older cache directories don't have this file; fallback to compile_args.out_idx when the file is absent. Co-Authored-By: Claude Opus 4.6 (1M context) * Add .humanize to gitignore * Fix HIP swizzle codegen and CPU scalar GEMM region handling [P2] Update codegen_hip.cc to parse tvm_tuple(device_func, panel_size) format for threadblock_swizzle_pattern, matching the annotation format now emitted by T.use_swizzle(). Previously expected StringImmNode which would ICHECK-fail. [P2] Fix gemm_scalar.py: clear only the output tile region (not the whole buffer), and use the last two dimensions from regions to handle rank>2 buffers with leading singleton dimensions. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix HIP codegen to parse tvm_tuple swizzle annotation format Update codegen_hip.cc threadblock_swizzle_pattern handler to parse tvm_tuple(device_func, panel_size) format, matching CUDA and cutedsl codegens. Previously expected StringImmNode which ICHECK-fails with the current annotation format. Co-Authored-By: Claude Opus 4.6 (1M context) * Restore kPipelineMVB* annotations — still needed by WS tiled full MVB The kPipelineMVBStageExpr/ParityExpr/ContextNumStages annotations and their MVB consumption code cannot be removed: they are emitted by EmitImpl for ALL pipelines and consumed by the full MVB call inside ProducerConsumerWarpSpecializedTiled. Removing them broke non-TMA pipeline kernels like mhc_post that go through the WS path. Restores: annotation emission in inject_pipeline.cc, constant definitions in pipeline_utils.h, and consumption code in multi_version_buffer_rewriter.cc. Co-Authored-By: Claude Opus 4.6 (1M context) * Unify barrier multi-versioning in InjectSoftwarePipeline (Direction A) Move all pipeline barrier multi-versioning into InjectSoftwarePipeline via ExpandPipelineBarriers, eliminating the need for late MVB(barrier_only=true) in OptimizeForTarget. ExpandPipelineBarriers runs before BuildPipeline and handles: - ISP-created pipeline_mbar (for non-WS TMA pipelines) - User-written T.alloc_barrier (for manual WS pipelines like softpipe) Key design: only barriers with explicit ptx_arrive_barrier calls OR ISP-created local barriers are expanded. Barriers whose arrival is managed internally by tile-ops (e.g., tcgen05 MMA arrive) are left unchanged. This distinguishes pipeline sync barriers from hardware- managed barriers. Buffer expansion: barrier[N] -> barrier[N * num_stages] Index rewriting: barrier[idx] -> barrier[stage_expr * N + idx] Parity rewriting: user_parity -> (iteration_block + offset) % 2 barrier_init replication: [c0,c1] -> [c0,c1,c0,c1] for num_stages=2 Expanded buffers propagate to outer blocks via pending_buffer_remap_. barrier_init annotations are replicated in VisitStmt_(BlockNode). Co-Authored-By: Claude Opus 4.6 (1M context) * Fix Python 3.9 compat: replace star expression in index with tuple concat * fix * fix * revert changes * fix * enhance pipeline * enhance pipeline * implement descriptor reuse pass * Refactor pipeline management and enhance async copy handling - Updated `example_mhc_pre.py` to replace the main function call with a test function for better testing flexibility. - Modified `example_gqa_decode.py` to disable argument parsing and added latency measurement for regression performance. - Enhanced `copy.cc` and `copy.h` with new checks for pipeline-managed cp.async synchronization and improved async copy handling. - Updated `inject_pipeline.cc` and `pipeline_planning.cc` to refine the handling of tile operations and global/shared buffer checks. - Added tests to ensure correct behavior of async pipeline and descriptor allocation reuse. This commit improves the overall pipeline management and async copy capabilities, ensuring better performance and flexibility in the codebase. * enhance pipeline * enhance pipeline * enhance pipeline * enhance pipeline * enhance pipeline * Revert "enhance pipeline" This reverts commit c056ded3c9112f1a9bc81a3f0ab012c176f40c11. * preloop tma handling * lint fix * Clean up debug changes in examples * Refactor PTX async copy injection result * refactor * Refactor tile op access region collection * Clean up pipeline access and async analysis * Refactor head async wait analysis * refactor layout expand * auto tma fix * Allow auto TMA store for plain copy * slash sparse test fix * Unify warp specialization and internalize buffer versioning * Revert formatting-only changes * Keep deprecated TL_DISABLE_TMA_LOWER compat --------- Co-authored-by: Claude Opus 4.6 (1M context) --- .gitignore | 3 + examples/bitnet-1.58b/tokenization_bitnet.py | 4 +- examples/deepseek_mhc/example_mhc_bwd.py | 5 +- examples/deepseek_mhc/example_mhc_post.py | 26 +- examples/deepseek_mhc/example_mhc_pre.py | 36 +- examples/deepseek_mla/example_mla_decode.py | 4 +- .../benchmark/benchmark_nsa_fwd.py | 8 +- .../deepseek_nsa/example_tilelang_nsa_bwd.py | 8 +- .../example_tilelang_nsa_decode.py | 6 +- .../deepseek_nsa/example_tilelang_nsa_fwd.py | 6 +- .../example_tilelang_nsa_fwd_varlen.py | 8 +- examples/deepseek_v32/inference/convert.py | 3 +- examples/deepseek_v32/inference/kernel.py | 4 +- examples/deepseek_v32/sparse_mla_bwd.py | 1 - examples/deepseek_v32/sparse_mla_fwd.py | 5 +- .../example_dequant_gemm_bf16_fp4_hopper.py | 3 +- examples/dsa_sparse_finetune/indexer_bwd.py | 5 +- .../indexer_topk_reducesum.py | 6 +- .../dsa_sparse_finetune/sparse_mla_bwd.py | 5 +- .../dsa_sparse_finetune/sparse_mla_fwd.py | 5 +- .../sparse_mla_topk_reducesum.py | 5 +- .../flash_attention_sm100/gqa_bwd_bshd.py | 6 +- .../flash_attention_sm100/gqa_fwd_bshd.py | 6 +- .../flash_attention_sm100/mha_bwd_bshd.py | 6 +- .../flash_attention_sm100/mha_fwd_bshd.py | 6 +- examples/flash_decoding/example_gqa_decode.py | 4 +- .../fusedmoe/example_fusedmoe_tilelang.py | 9 +- examples/gdn/example_chunk_o_bwd.py | 2 +- examples/gdn/example_cumsum.py | 4 +- examples/gdn/example_wy_fast_bwd_split.py | 4 +- examples/gdn/test_example_gdn_compilation.py | 4 +- .../example_tilelang_gemm_fp8_intrinsic.py | 9 +- .../example_tilelang_gemm_fp8_sm100.py | 1 - examples/gemm_sm100/README.md | 1 - examples/gemm_sm100/gemm_mma.py | 5 +- examples/gemm_sm100/gemm_tcgen5mma.py | 5 +- .../example_tilelang_gemm_streamk.py | 2 +- examples/gemv/example_gemv.py | 5 +- .../grouped_gemm/example_grouped_gemm_bwd.py | 4 +- .../example_grouped_gemm_fwd_ptr.py | 2 +- examples/kda/chunk_bwd_intra.py | 1 - examples/kda/wy_fast_bwd.py | 2 +- .../example_linear_attn_bwd.py | 7 +- .../example_linear_attn_fwd.py | 5 +- .../linear_attention/example_retention_fwd.py | 2 +- .../example_vertical_slash_sparse_attn.py | 6 +- examples/norm/rms_norm.py | 2 +- examples/norm/test_rms_norm.py | 2 +- ...mple_warp_specialize_gemm_copy_1_gemm_0.py | 3 +- ...mple_warp_specialize_gemm_copy_gemm_0_1.py | 4 +- ...st_tma_copy_pipeline_2_stages\343\200\202" | 0 maint/scripts/regression_all.py | 1 - src/layout/gemm_layouts.cc | 75 +- src/layout/layout.cc | 28 +- src/layout/layout.h | 8 + src/op/atomic_add.cc | 17 +- src/op/atomic_reduce.cc | 34 +- src/op/builtin.cc | 2 +- src/op/builtin.h | 27 +- src/op/copy.cc | 394 +- src/op/copy.h | 56 +- src/op/fill.cc | 7 +- src/op/finalize_reducer.cc | 11 +- src/op/gemm.cc | 31 +- src/op/gemm.h | 8 +- src/op/gemm_py.cc | 28 +- src/op/gemm_py.h | 1 + src/op/gemm_sp.cc | 43 +- src/op/gemm_sp.h | 1 + src/op/gemm_sp_py.cc | 26 +- src/op/gemm_sp_py.h | 1 + src/op/operator.h | 66 +- src/op/reduce.cc | 24 +- src/op/reduce.h | 1 + src/op/transpose.cc | 16 +- src/op/utils.cc | 12 + src/op/utils.h | 6 + src/target/codegen_cuda.cc | 21 - src/target/codegen_cutedsl.cc | 19 - src/target/codegen_hip.cc | 44 +- src/transform/common/pipeline_utils.h | 112 + src/transform/common/tma_copy_utils.h | 28 - src/transform/inject_pipeline.cc | 2757 ++++++++- src/transform/instruction_annotation.cc | 239 + src/transform/layout_inference.cc | 78 +- src/transform/legalize_safe_memory_access.cc | 54 +- src/transform/lower_ptx_async_copy.cc | 43 +- src/transform/lower_tile_op.cc | 127 +- .../multi_version_buffer_rewriter.cc | 558 +- src/transform/multi_version_buffer_rewriter.h | 19 + src/transform/optimize_cp_async_sync.cc | 1235 ---- src/transform/pipeline_planning.cc | 675 ++- src/transform/producer_consumer_ws.cc | 5327 ++++++----------- src/transform/ptx_async_copy_injector.h | 11 +- .../reuse_local_descriptor_allocations.cc | 254 + testing/conftest.py | 11 +- .../test_tilelang_nested_loop_checker.py | 25 +- .../cache/test_tilelang_kernel_cache.py | 12 + .../test_tilelang_carver_recommend_hints.py | 1 + ..._tilelang_pass_config_disable_tma_lower.py | 30 + ...ng_pass_config_disable_warp_specialized.py | 5 +- .../python/cuda/test_cuda_f32x2_intrinsics.py | 17 +- .../python/issue/test_tilelang_issue_1001.py | 5 +- .../python/issue/test_tilelang_issue_1008.py | 10 +- .../python/issue/test_tilelang_issue_1106.py | 5 +- .../python/issue/test_tilelang_issue_1210.py | 5 +- .../python/issue/test_tilelang_issue_1263.py | 10 +- .../python/issue/test_tilelang_issue_1744.py | 5 +- .../issue/test_tilelang_issue_tma_no_ws.py | 228 +- ...issue_ws_simt_copy_full_producer_extent.py | 4 +- ...est_tilelang_kernel_bf16_gemm_tcgen5_ts.py | 5 +- .../test_tilelang_kernel_gemm_batched.py | 5 +- .../test_tilelang_kernel_gemm_with_stride.py | 5 +- .../test_tilelang_kernel_int8_gemm_tcgen5.py | 5 +- .../language/test_tilelang_language_all_of.py | 10 +- ...t_tilelang_language_annotate_safe_value.py | 5 +- .../language/test_tilelang_language_any_of.py | 10 +- .../test_tilelang_language_chain_equal.py | 5 +- .../language/test_tilelang_language_clear.py | 2 +- ...test_tilelang_language_composable_index.py | 5 +- .../language/test_tilelang_language_copy.py | 13 +- .../test_tilelang_language_let_layout.py | 5 +- .../test_tilelang_language_mask_op.py | 8 +- .../test_tilelang_language_pipeline.py | 10 +- .../language/test_tilelang_language_ptr.py | 2 +- .../test_tilelang_language_reshape.py | 30 +- .../test_tilelang_language_tma_copy.py | 11 +- .../test_tilelang_language_tma_store.py | 64 +- .../test_tilelang_language_transpose.py | 10 +- .../test_tilelang_language_wgmma_gemm.py | 6 +- .../language/test_tilelang_memory_leak.py | 5 +- .../test_tilelang_annotate_loop_layout.py | 1 + .../test_tilelang_runtime_tma_validation.py | 13 +- .../test_tilelang_tilelibrary_gemm.py | 20 +- .../test_tilelang_tilelibrary_gemm_sp.py | 2 +- .../test_tilelang_tilelibrary_gemm_sp_v2.py | 20 +- ...lang_transform_Inject_software_pipeline.py | 394 ++ ...tilelang_transform_lower_ptx_async_copy.py | 26 + ...tilelang_transform_lower_shared_barrier.py | 5 +- .../test_tilelang_transform_lower_tile_op.py | 106 + ...tilelang_transform_multi_version_buffer.py | 144 - ...lelang_transform_optimize_cp_async_sync.py | 554 -- ...ng_transform_pipeline_barrier_ownership.py | 145 + ...st_tilelang_transform_pipeline_planning.py | 158 +- ...tilelang_transform_producer_consumer_ws.py | 763 +-- ...form_reuse_local_descriptor_allocations.py | 105 + tilelang/autotuner/param.py | 15 +- tilelang/engine/phase.py | 42 +- tilelang/intrinsics/mma_sp_macro_generator.py | 151 +- tilelang/jit/adapter/cutedsl/wrapper.py | 3 +- tilelang/jit/kernel.py | 5 +- tilelang/language/dtypes.py | 6 +- tilelang/language/eager/ast.py | 2 +- tilelang/language/eager/builder.py | 3 +- tilelang/layout/swizzle.py | 61 +- tilelang/tileop/gemm/gemm_wmma.py | 6 +- tilelang/tileop/gemm_sp/gemm_sp_base.py | 8 +- tilelang/tileop/gemm_sp/gemm_sp_mma.py | 6 +- tilelang/transform/__init__.py | 62 +- tilelang/transform/pass_config.py | 46 +- 160 files changed, 8611 insertions(+), 7655 deletions(-) create mode 100644 "language\\test_tilelang_language_tma_copy.py::test_tma_copy_pipeline_2_stages\343\200\202" create mode 100644 src/transform/common/pipeline_utils.h delete mode 100644 src/transform/common/tma_copy_utils.h create mode 100644 src/transform/instruction_annotation.cc create mode 100644 src/transform/multi_version_buffer_rewriter.h delete mode 100644 src/transform/optimize_cp_async_sync.cc create mode 100644 src/transform/reuse_local_descriptor_allocations.cc create mode 100644 testing/python/components/test_tilelang_pass_config_disable_tma_lower.py create mode 100644 testing/python/transform/test_tilelang_transform_lower_tile_op.py delete mode 100644 testing/python/transform/test_tilelang_transform_multi_version_buffer.py delete mode 100644 testing/python/transform/test_tilelang_transform_optimize_cp_async_sync.py create mode 100644 testing/python/transform/test_tilelang_transform_pipeline_barrier_ownership.py create mode 100644 testing/python/transform/test_tilelang_transform_reuse_local_descriptor_allocations.py diff --git a/.gitignore b/.gitignore index f3d27fe55b..a12077bf62 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,6 @@ maint/host_checks/logs/* # perf regression test .perf_regression/ + +# agent skills folder +.humanize* diff --git a/examples/bitnet-1.58b/tokenization_bitnet.py b/examples/bitnet-1.58b/tokenization_bitnet.py index 2adfd6dee1..8db57a9c09 100644 --- a/examples/bitnet-1.58b/tokenization_bitnet.py +++ b/examples/bitnet-1.58b/tokenization_bitnet.py @@ -38,10 +38,10 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model" }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json" }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { diff --git a/examples/deepseek_mhc/example_mhc_bwd.py b/examples/deepseek_mhc/example_mhc_bwd.py index 7b08187a21..2961d87952 100644 --- a/examples/deepseek_mhc/example_mhc_bwd.py +++ b/examples/deepseek_mhc/example_mhc_bwd.py @@ -55,10 +55,7 @@ def sinkhorn_bwd_configs(n_stream, seqlen): ) @tilelang.jit( out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def sinkhorn_bwd_implicit_cg(n_stream: int, tilesize: int = 32, threads: int = 128): seqlen = T.dynamic("seqlen") diff --git a/examples/deepseek_mhc/example_mhc_post.py b/examples/deepseek_mhc/example_mhc_post.py index f643aa4fcc..9c9dc2f720 100644 --- a/examples/deepseek_mhc/example_mhc_post.py +++ b/examples/deepseek_mhc/example_mhc_post.py @@ -7,11 +7,7 @@ @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10}, ) def mhc_post_tilelang(a, b, c, d, x, hc: int, hidden: int, n_thr: int = 128, h_blk: int = 1024) -> tilelang.JITKernel: # rename for shorter code @@ -60,6 +56,11 @@ def mhc_post( comb_res_mix: torch.Tensor, ) -> torch.Tensor: out = torch.empty_like(residual) + print( + mhc_post_tilelang.get_kernel_source( + comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, residual.shape[-2], residual.shape[-1] + ) + ) mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, residual.shape[-2], residual.shape[-1]) return out @@ -108,17 +109,6 @@ def run_regression_perf(n: int = 4096, h: int = 2560, hc_mult: int = 4) -> float test_data = generate_test_data(n=n, h=h, hc_mult=hc_mult) out = torch.empty_like(test_data["residual"]) post_layer_mix = test_data["post_layer_mix"].squeeze(-1) - print( - mhc_post_tilelang.get_kernel_source( - test_data["comb_res_mix"], - test_data["residual"], - post_layer_mix, - test_data["x"], - out, - hc_mult, - h, - ) - ) def run_kernel_only(): mhc_post_tilelang( @@ -145,4 +135,6 @@ def main(): if __name__ == "__main__": - main() + # main() + tilelang.disable_cache() + test(n=4096, h=2560) diff --git a/examples/deepseek_mhc/example_mhc_pre.py b/examples/deepseek_mhc/example_mhc_pre.py index 81d9b8d101..28b6c32bf6 100644 --- a/examples/deepseek_mhc/example_mhc_pre.py +++ b/examples/deepseek_mhc/example_mhc_pre.py @@ -6,11 +6,7 @@ @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10}, ) def mhc_pre_big_fuse_tilelang( gemm_out_mul, @@ -446,36 +442,6 @@ def run_regression_perf( layer_input = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device) gemm_out_mul = torch.empty(n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device) gemm_out_sqrsum = torch.empty(n_splits, num_tokens, dtype=torch.float32, device=residual.device) - print( - mhc_pre_gemm_sqrsum_tilelang.get_kernel_source( - residual_flat.view(num_tokens, hc_mult * hidden_size), - fn, - gemm_out_mul.squeeze(0), - gemm_out_sqrsum.squeeze(0), - hc_mult3, - hc_mult * hidden_size, - ) - ) - print( - mhc_pre_big_fuse_tilelang.get_kernel_source( - gemm_out_mul, - gemm_out_sqrsum, - hc_scale, - hc_base, - residual_flat, - post_mix, - comb_mix, - layer_input, - hidden_size, - rms_eps, - hc_pre_eps, - hc_sinkhorn_eps, - hc_post_mult_value, - sinkhorn_repeat, - n_splits, - hc_mult, - ) - ) def run_kernel_only(): mhc_pre_gemm_sqrsum_tilelang( diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index e777145b07..4daa39f494 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -9,9 +9,7 @@ @tilelang.jit( out_idx=[4], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, + pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}, ) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): scale = float(softmax_scale * 1.44269504) # log2(e) diff --git a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py index ca98d01be9..697f3de38c 100644 --- a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py +++ b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py @@ -460,13 +460,7 @@ def get_configs(): @tilelang.autotune( configs=get_configs(), ) -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } -) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_sparse_attention( batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32 ): diff --git a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py index 3da285a9ba..2aa30a5bc5 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -18,13 +18,7 @@ import tilelang -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } -) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_kernel_fwd( batch, heads, diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index 381d92493e..79414a762b 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -12,11 +12,7 @@ # auto warp specialization may have some bugs. @tilelang.jit( out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}, ) def native_sparse_attention( batch, diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index 7b36d6e26f..abed2e41dd 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -10,11 +10,7 @@ @tilelang.jit( out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py index b52ebe42e2..1d5d942b40 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -17,13 +17,7 @@ from einops import rearrange -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } -) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) diff --git a/examples/deepseek_v32/inference/convert.py b/examples/deepseek_v32/inference/convert.py index 090be71455..cb912e1d8b 100644 --- a/examples/deepseek_v32/inference/convert.py +++ b/examples/deepseek_v32/inference/convert.py @@ -29,8 +29,7 @@ "wq_b": ("wq_b", None), "wk": ("wk", None), "k_norm": ("k_norm", None), - "weights_proj": ("weights_proj", None), -} + "weights_proj": ("weights_proj", None)} def main(hf_ckpt_path, save_path, n_experts, mp): diff --git a/examples/deepseek_v32/inference/kernel.py b/examples/deepseek_v32/inference/kernel.py index 4090d4beb8..9d9402d1a8 100644 --- a/examples/deepseek_v32/inference/kernel.py +++ b/examples/deepseek_v32/inference/kernel.py @@ -7,9 +7,7 @@ pass_configs = { tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, -} + tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True} FP8 = T.float8_e4m3fn BF16 = T.bfloat16 diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index 527de22b39..50192fa2bf 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -76,7 +76,6 @@ def postprocess_kernel( @tilelang.jit( out_idx=[-2], pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, }, diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index ddde11f5bc..5426d9072b 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -7,10 +7,7 @@ @tilelang.jit( out_idx=[-2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def sparse_mla_fwd( heads, diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py index 1e013dfbbc..2ae9bdf3eb 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -41,7 +41,7 @@ def get_configs(): ) @tilelang.jit( out_idx=[-1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def matmul( M, @@ -449,7 +449,6 @@ def run_regression_perf(m=4096, n=4096, k=4096, fast_dequant=True): threads=256, split=1, ) - print(kernel.get_kernel_source()) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) return profiler.do_bench(backend="cupti") diff --git a/examples/dsa_sparse_finetune/indexer_bwd.py b/examples/dsa_sparse_finetune/indexer_bwd.py index 68508ad4e4..54e02e4f18 100644 --- a/examples/dsa_sparse_finetune/indexer_bwd.py +++ b/examples/dsa_sparse_finetune/indexer_bwd.py @@ -13,10 +13,7 @@ FP32 = T.float32 INT32 = T.int32 -pass_configs = { - tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -} +pass_configs = {tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @tl.jit(pass_configs=pass_configs) diff --git a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py index d76eb02724..1066199cd0 100644 --- a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py @@ -14,11 +14,7 @@ FP32 = T.float32 INT32 = T.int32 -pass_configs = { - tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, - tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -} +pass_configs = {tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @tl.jit(pass_configs=pass_configs) diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py index 06eaa8eb30..ab0b4fc493 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_bwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -78,10 +78,7 @@ def postprocess_kernel( @tilelang.jit( out_idx=[-2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def bwd( H, diff --git a/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/examples/dsa_sparse_finetune/sparse_mla_fwd.py index d875236952..fcde71928b 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_fwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -9,10 +9,7 @@ @tilelang.jit( out_idx=[-2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def sparse_mla_fwd( heads, diff --git a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py index a03bc74f51..2fff8dd20f 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -12,10 +12,7 @@ FP32 = T.float32 INT32 = T.int32 -pass_configs = { - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -} +pass_configs = {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @tilelang.jit(pass_configs=pass_configs) diff --git a/examples/flash_attention_sm100/gqa_bwd_bshd.py b/examples/flash_attention_sm100/gqa_bwd_bshd.py index 33661a4947..95e1c35d60 100644 --- a/examples/flash_attention_sm100/gqa_bwd_bshd.py +++ b/examples/flash_attention_sm100/gqa_bwd_bshd.py @@ -12,11 +12,7 @@ import argparse -PASS_CFG = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, -} +PASS_CFG = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} @tilelang.jit(out_idx=[3, 4], pass_configs=PASS_CFG) diff --git a/examples/flash_attention_sm100/gqa_fwd_bshd.py b/examples/flash_attention_sm100/gqa_fwd_bshd.py index 353f171146..775cb45dd1 100644 --- a/examples/flash_attention_sm100/gqa_fwd_bshd.py +++ b/examples/flash_attention_sm100/gqa_fwd_bshd.py @@ -14,11 +14,7 @@ import argparse -PASS_CFG = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, -} +PASS_CFG = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} @tilelang.jit(out_idx=[3], pass_configs=PASS_CFG) diff --git a/examples/flash_attention_sm100/mha_bwd_bshd.py b/examples/flash_attention_sm100/mha_bwd_bshd.py index 0a04edd0aa..45406a2eda 100644 --- a/examples/flash_attention_sm100/mha_bwd_bshd.py +++ b/examples/flash_attention_sm100/mha_bwd_bshd.py @@ -11,11 +11,7 @@ import argparse -PASS_CFG = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, -} +PASS_CFG = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} @tilelang.jit(out_idx=[3, 4], pass_configs=PASS_CFG) diff --git a/examples/flash_attention_sm100/mha_fwd_bshd.py b/examples/flash_attention_sm100/mha_fwd_bshd.py index acf5a7c3df..db9f2472fd 100644 --- a/examples/flash_attention_sm100/mha_fwd_bshd.py +++ b/examples/flash_attention_sm100/mha_fwd_bshd.py @@ -16,11 +16,7 @@ import argparse -PASS_CFG = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, -} +PASS_CFG = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} @tilelang.jit(out_idx=[3], pass_configs=PASS_CFG) diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index a23623efb0..26b4801115 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -40,9 +40,9 @@ def get_heuristic_config() -> Tuple[Dict, int]: return cfg, sm_version -# TODO(lei): fix warp specialized and tma lower pass +# TODO(lei): fix warp specialized pass def get_pass_configs(): - return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} + return {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @autotune(configs=get_configs(), warmup=10, rep=10) diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py index 5dc3069167..d4a2ced46d 100644 --- a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -8,7 +8,7 @@ from example_fusedmoe_torch import * -@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +@tilelang.jit(pass_configs={"tl.disable_warp_specialized": True}) def moe_forward_tilelang_shared( d_hidden, d_expert, @@ -93,12 +93,7 @@ def kernel_shared( return kernel_shared -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } -) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def moe_forward_tilelang_routed( d_hidden, d_expert, diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 19233de62d..b369e03a8f 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -109,7 +109,7 @@ def prepare_output( @tilelang.jit( out_idx=[-4, -3, -2, -1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def tilelang_chunk_o_bwd_dqkwg( # task config diff --git a/examples/gdn/example_cumsum.py b/examples/gdn/example_cumsum.py index 0760b49645..9d4ca0222e 100644 --- a/examples/gdn/example_cumsum.py +++ b/examples/gdn/example_cumsum.py @@ -20,9 +20,7 @@ import torch -@tilelang.jit( - out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} -) +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_chunk_local_cumsum_scalar( # task config B, diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py index 822f745f23..5711010025 100644 --- a/examples/gdn/example_wy_fast_bwd_split.py +++ b/examples/gdn/example_wy_fast_bwd_split.py @@ -94,7 +94,7 @@ def prepare_output( @tilelang.jit( out_idx=[-5, -4, -3, -2, -1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def tilelang_wy_fast_bwd( # task config @@ -247,7 +247,7 @@ def kernel( return kernel -@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_wy_fast_bwd_split( # task config B, diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py index 3870393b80..b3d255a70a 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -1,6 +1,5 @@ import torch from tilelang import language as T -import tilelang.testing B = 1 S = 1024 # small but for test only. @@ -318,4 +317,5 @@ def test_example_chunk_delta_bwd_compilation(): if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + test_example_wy_fast_compilation() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 9cfd978227..17a606fc9e 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -202,12 +202,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in { - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - torch.float8_e5m2, - torch.float8_e5m2fnuz, - }: + elif in_dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}: A = torch.randn(M, K).to(in_dtype).cuda() B = torch.randn(N, K).to(in_dtype).cuda() else: @@ -227,8 +222,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # Get Reference Result ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype) - print(C) - print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py index 3c93b4a305..cb42d921ef 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -99,7 +99,6 @@ def calc_diff(x, y): out_idx=[2], target="cuda", pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, }, diff --git a/examples/gemm_sm100/README.md b/examples/gemm_sm100/README.md index 3d03184184..e9490b8654 100644 --- a/examples/gemm_sm100/README.md +++ b/examples/gemm_sm100/README.md @@ -87,7 +87,6 @@ block_M, block_N, block_K = 128, 256, 128 # Compile kernel jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required }) diff --git a/examples/gemm_sm100/gemm_mma.py b/examples/gemm_sm100/gemm_mma.py index 226e33c01e..e3a70df973 100644 --- a/examples/gemm_sm100/gemm_mma.py +++ b/examples/gemm_sm100/gemm_mma.py @@ -58,10 +58,7 @@ def main( func, out_idx=[2], target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) print(jit_kernel.get_kernel_source()) # 3. Test the kernel in Python with PyTorch data diff --git a/examples/gemm_sm100/gemm_tcgen5mma.py b/examples/gemm_sm100/gemm_tcgen5mma.py index 315c52b0ba..229908992a 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma.py +++ b/examples/gemm_sm100/gemm_tcgen5mma.py @@ -63,10 +63,7 @@ def main( func, out_idx=[2], target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) print(jit_kernel.get_kernel_source()) diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index 0aa4ed1f4a..48dc175a96 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -54,7 +54,7 @@ def cdiv(a, b): sm_patition_factor = max(blocking_tiles // total_sm, 1) -@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False}) +@tilelang.jit def tl_matmul_streamk( M, N, diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index a5ecffbd0e..ddbe4fd7a6 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -227,10 +227,7 @@ def get_block_template_configs(): rep=20, ) @tl.jit( - pass_configs={ - tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, out_idx=[2], ) def gemv_alloc_reducer( diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py index 49cce0d1dd..339f8bc1ae 100644 --- a/examples/grouped_gemm/example_grouped_gemm_bwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -5,7 +5,7 @@ import tilelang.language as T -@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_warp_specialized": True}) def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: @@ -157,7 +157,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets -@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_warp_specialized": True}) def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py b/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py index ba4d2107fb..4ce9e7320c 100644 --- a/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py +++ b/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py @@ -134,7 +134,7 @@ def run_tilelang_grouped_gemm_ptr( kernel = tl.compile( program, execution_backend=backend, - pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}, + pass_configs={"tl.disable_warp_specialized": True}, ) a_list, b_list, c_list, a_ptrs, b_ptrs, c_ptrs, batch_tile_offsets = construct_inputs(batch_sizes_list, K, N, block_M, device, dtype) refs = torch_grouped_gemm_ptr([a[:size] for a, size in zip(a_list, batch_sizes_list)], b_list) diff --git a/examples/kda/chunk_bwd_intra.py b/examples/kda/chunk_bwd_intra.py index 6c66732b4f..a4aa4f9d43 100644 --- a/examples/kda/chunk_bwd_intra.py +++ b/examples/kda/chunk_bwd_intra.py @@ -77,7 +77,6 @@ def get_configs(): @autotune(configs=get_configs(), warmup=5, rep=5) @tilelang.jit( out_idx=[-4, -3, -2, -1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, ) def tilelang_chunk_bwd_intra( # task config diff --git a/examples/kda/wy_fast_bwd.py b/examples/kda/wy_fast_bwd.py index 3a69b31621..8fd4b4d707 100644 --- a/examples/kda/wy_fast_bwd.py +++ b/examples/kda/wy_fast_bwd.py @@ -76,7 +76,7 @@ def get_configs(): @autotune(configs=get_configs(), warmup=3, rep=5) @tilelang.jit( out_idx=[-5, -4, -3, -2, -1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def tilelang_wy_fast_bwd( # task config diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index 82ae1d982a..e15c263060 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -9,12 +9,7 @@ from typing import Optional, Tuple -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } -) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tl_fused_chunk_bwd_kernel( B, S, diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index cdfd5cb721..fd64882d48 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -11,10 +11,7 @@ @tilelang.jit( out_idx=[4], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def tl_fused_chunk_fwd_kernel( B, diff --git a/examples/linear_attention/example_retention_fwd.py b/examples/linear_attention/example_retention_fwd.py index f45e383889..9b1e1cd8b1 100644 --- a/examples/linear_attention/example_retention_fwd.py +++ b/examples/linear_attention/example_retention_fwd.py @@ -6,7 +6,7 @@ import argparse -@tl.jit(out_idx=3, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +@tl.jit(out_idx=3, pass_configs={"tl.disable_warp_specialized": True}) def chunk_retention_fwd_kernel( B, S, diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py index f1c822b5d4..a03e5318b1 100644 --- a/examples/minference/example_vertical_slash_sparse_attn.py +++ b/examples/minference/example_vertical_slash_sparse_attn.py @@ -150,15 +150,15 @@ def vs_sparse_flashattn_ws( if tid >= 128: T.annotate_producer_reg_dealloc() - T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.tma_copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared, barrier=mbars[8]) T.mbarrier_arrive(mbarrier=mbars[8]) for bi in T.serial(block_count): k = block_offset[bi] T.mbarrier_wait_parity(mbarrier=mbars[bi % 2 + 4], parity=(((bi & 3) >> 1) ^ 1)) - T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :]) + T.tma_copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :], barrier=mbars[bi % 2]) T.mbarrier_arrive(mbarrier=mbars[bi % 2]) T.mbarrier_wait_parity(mbarrier=mbars[bi % 2 + 6], parity=(((bi & 3) >> 1) ^ 1)) - T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :]) + T.tma_copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :], barrier=mbars[bi % 2 + 2]) T.mbarrier_arrive(mbarrier=mbars[bi % 2 + 2]) else: T.annotate_consumer_reg_alloc() diff --git a/examples/norm/rms_norm.py b/examples/norm/rms_norm.py index 57bccc1a0f..f05782add0 100644 --- a/examples/norm/rms_norm.py +++ b/examples/norm/rms_norm.py @@ -33,7 +33,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): return main -@tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True}) +@tilelang.jit(out_idx=[-1]) def rms_norm(M, N, blk_m): dtype = T.float diff --git a/examples/norm/test_rms_norm.py b/examples/norm/test_rms_norm.py index 53db03d98c..0d14e93a8a 100644 --- a/examples/norm/test_rms_norm.py +++ b/examples/norm/test_rms_norm.py @@ -65,7 +65,7 @@ def ref_program(x): def test_rms_norm(M=1024, N=1024, blk_m=1): program = rms_norm(M, N, blk_m) - kernel = tilelang.compile(program, out_idx=-1, pass_configs={"tl.disable_tma_lower": True}) + kernel = tilelang.compile(program, out_idx=-1) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py index 8fe28c209a..aef3c0e90d 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py @@ -80,7 +80,6 @@ def run_regression_perf(M=16384, N=16384, K=16384): block_K = 64 jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K) - print(jit_kernel.get_kernel_source()) import torch @@ -99,4 +98,4 @@ def run_regression_perf(M=16384, N=16384, K=16384): if __name__ == "__main__": - tilelang.testing.main() + main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py index 5468aa6eac..ad6dd2909d 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py @@ -6,9 +6,7 @@ # @tilelang.jit @tilelang.jit( out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={}, ) def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): warp_group_num = 2 diff --git "a/language\\test_tilelang_language_tma_copy.py::test_tma_copy_pipeline_2_stages\343\200\202" "b/language\\test_tilelang_language_tma_copy.py::test_tma_copy_pipeline_2_stages\343\200\202" new file mode 100644 index 0000000000..e69de29bb2 diff --git a/maint/scripts/regression_all.py b/maint/scripts/regression_all.py index d6919cfeaf..0b39c8aa08 100644 --- a/maint/scripts/regression_all.py +++ b/maint/scripts/regression_all.py @@ -106,7 +106,6 @@ def regression_all(examples_root: str | os.PathLike[str] | None = None) -> None: text=True, env={ **os.environ, - "PYTHONNOUSERSITE": "1", "TL_PERF_REGRESSION_FORMAT": "json", }, ) diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 530b46f927..f46c47715f 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -415,6 +415,15 @@ bool TryGetSwizzleShapeInfo(const Buffer &buffer, SwizzleShapeInfo *info) { } // namespace +static Layout ExpandLayout2D(const Layout &base, const Buffer &buffer) { + Array leading_shape; + leading_shape.reserve(buffer->shape.size() - 2); + for (size_t i = 0; i + 2 < buffer->shape.size(); ++i) { + leading_shape.push_back(buffer->shape[i]); + } + return base->Expand(leading_shape); +} + // Layout swizzling for 32 bytes static Layout MakeQuarterBankSwizzleLayout2D(int stride, int continuous, int element_size) { @@ -440,12 +449,7 @@ Layout makeQuarterBankSwizzleLayout(const Buffer &buffer) { auto base = MakeQuarterBankSwizzleLayout2D(static_cast(info.stride), static_cast(info.continuous), info.element_size); - Array leading_shape; - leading_shape.reserve(buffer->shape.size() - 2); - for (size_t i = 0; i + 2 < buffer->shape.size(); ++i) { - leading_shape.push_back(buffer->shape[i]); - } - return base->Expand(leading_shape); + return ExpandLayout2D(base, buffer); } // Layout swizzling for 64 bytes @@ -473,12 +477,7 @@ Layout makeHalfBankSwizzleLayout(const Buffer &buffer) { auto base = MakeHalfBankSwizzleLayout2D(static_cast(info.stride), static_cast(info.continuous), info.element_size); - Array leading_shape; - leading_shape.reserve(buffer->shape.size() - 2); - for (size_t i = 0; i + 2 < buffer->shape.size(); ++i) { - leading_shape.push_back(buffer->shape[i]); - } - return base->Expand(leading_shape); + return ExpandLayout2D(base, buffer); } // Layout swizzling for 128 bytes @@ -506,12 +505,7 @@ Layout makeFullBankSwizzleLayout(const Buffer &buffer) { auto base = MakeFullBankSwizzleLayout2D(static_cast(info.stride), static_cast(info.continuous), info.element_size); - Array leading_shape; - leading_shape.reserve(buffer->shape.size() - 2); - for (size_t i = 0; i + 2 < buffer->shape.size(); ++i) { - leading_shape.push_back(buffer->shape[i]); - } - return base->Expand(leading_shape); + return ExpandLayout2D(base, buffer); } // Detail implementation please ref to @@ -902,6 +896,51 @@ Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack); } +Layout makeSwizzledLayout(const Buffer &buffer, bool k_inner, bool allow_pad) { + auto info = GetSwizzleShapeInfoChecked(buffer); + Layout base; + if (allow_pad) { + base = makeGemmABLayout( + static_cast(info.stride), static_cast(info.continuous), + static_cast(info.continuous), info.element_size, k_inner); + } else { + base = makeGemmABLayoutHopper( + static_cast(info.stride), static_cast(info.continuous), + static_cast(info.continuous), info.element_size, k_inner); + } + return ExpandLayout2D(base, buffer); +} + +Layout makeVoltaSwizzledLayout(const Buffer &buffer, bool is_a, bool k_inner) { + auto info = GetSwizzleShapeInfoChecked(buffer); + auto base = + makeGemmVoltaABLayout(static_cast(info.stride), + static_cast(info.continuous), is_a, k_inner); + return ExpandLayout2D(base, buffer); +} + +Layout makeWgmmaSwizzledLayout(const Buffer &buffer, int continuity, + bool k_inner) { + auto info = GetSwizzleShapeInfoChecked(buffer); + if (continuity < 0) + continuity = static_cast(info.continuous); + auto base = makeGemmABLayoutHopper(static_cast(info.stride), + static_cast(info.continuous), + continuity, info.element_size, k_inner); + return ExpandLayout2D(base, buffer); +} + +Layout makeTcgen05mmaSwizzledLayout(const Buffer &buffer, int continuity, + bool k_inner) { + auto info = GetSwizzleShapeInfoChecked(buffer); + if (continuity < 0) + continuity = static_cast(info.continuous); + auto base = makeGemmABLayoutSm100(static_cast(info.stride), + static_cast(info.continuous), + continuity, info.element_size, k_inner); + return ExpandLayout2D(base, buffer); +} + SwizzleMode DetectSwizzleMode(const Layout &layout, const Buffer &buffer) { SwizzleShapeInfo info; if (!TryGetSwizzleShapeInfo(buffer, &info)) { diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 717c98bd88..c5661dc91d 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -1136,32 +1136,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("tl.Fragment_condense_rep_var", [](Fragment fragment) { return fragment->CondenseReplicateVar(); }) .def("tl.make_swizzled_layout", - [](int stride, int continuous, int element_size, bool k_inner, - bool allow_pad = true) { - if (allow_pad) { - return makeGemmABLayout(stride, continuous, continuous, - element_size, k_inner); - } else { - return makeGemmABLayoutHopper(stride, continuous, continuous, - element_size, k_inner); - } + [](const Buffer &buffer, bool k_inner, bool allow_pad) { + return makeSwizzledLayout(buffer, k_inner, allow_pad); }) .def("tl.make_volta_swizzled_layout", - [](int stride, int mat_continuous, bool is_a, bool k_inner) { - return makeGemmVoltaABLayout(stride, mat_continuous, is_a, - k_inner); + [](const Buffer &buffer, bool is_a, bool k_inner) { + return makeVoltaSwizzledLayout(buffer, is_a, k_inner); }) .def("tl.make_wgmma_swizzled_layout", - [](int stride, int mat_continuous, int continuity, int element_size, - bool k_inner) { - return makeGemmABLayoutHopper(stride, mat_continuous, continuity, - element_size, k_inner); + [](const Buffer &buffer, int continuity, bool k_inner) { + return makeWgmmaSwizzledLayout(buffer, continuity, k_inner); }) .def("tl.make_tcgen05mma_swizzled_layout", - [](int stride, int mat_continuous, int continuity, int element_size, - bool k_inner) { - return makeGemmABLayoutSm100(stride, mat_continuous, continuity, - element_size, k_inner); + [](const Buffer &buffer, int continuity, bool k_inner) { + return makeTcgen05mmaSwizzledLayout(buffer, continuity, k_inner); }) .def("tl.make_full_bank_swizzled_layout", [](const Buffer &buffer) { diff --git a/src/layout/layout.h b/src/layout/layout.h index 927d2e1442..8043c5765d 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -273,6 +273,14 @@ Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous, Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous, int elementsize); +Layout makeSwizzledLayout(const Buffer &buffer, bool k_inner = true, + bool allow_pad = true); +Layout makeVoltaSwizzledLayout(const Buffer &buffer, bool is_a = true, + bool k_inner = true); +Layout makeWgmmaSwizzledLayout(const Buffer &buffer, int continuity = -1, + bool k_inner = true); +Layout makeTcgen05mmaSwizzledLayout(const Buffer &buffer, int continuity = -1, + bool k_inner = true); Layout makeFullBankSwizzleLayout(const Buffer &buffer); Layout makeHalfBankSwizzleLayout(const Buffer &buffer); Layout makeQuarterBankSwizzleLayout(const Buffer &buffer); diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 8e465814bd..816ce379b4 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -45,18 +45,23 @@ AtomicAdd::AtomicAdd(Array args, Map annotations) { << "AtomicAdd expects at least 2 arguments (src, dst), got " << args.size(); ObjectPtr node = tvm::ffi::make_object(); + std::vector access_regions; if (IsBufferLikeExpr(args[0])) { - auto region = NormalizeToBufferRegion(args[0]); - node->src = region->buffer; - node->src_range = region->region; + auto src_access = NormalizeToAccessRegion(args[0], kAccessRead); + node->src = src_access.region->buffer; + node->src_range = src_access.region->region; + access_regions.push_back(std::move(src_access)); } else { node->src_value = args[0]; } - auto region = NormalizeToBufferRegion(args[1]); - node->dst = region->buffer; - node->dst_range = region->region; + auto dst_access = NormalizeToAccessRegion(args[1], kAccessReadWrite); + dst_access.access_mask = kAccessReadWrite; + node->dst = dst_access.region->buffer; + node->dst_range = dst_access.region->region; + access_regions.push_back(std::move(dst_access)); + node->SetAccessRegions(std::move(access_regions)); // Copy annotations from the Call node node->annotations = annotations; diff --git a/src/op/atomic_reduce.cc b/src/op/atomic_reduce.cc index 0bc088469c..534189c5bf 100644 --- a/src/op/atomic_reduce.cc +++ b/src/op/atomic_reduce.cc @@ -31,18 +31,23 @@ AtomicMax::AtomicMax(Array args, Map annotations) { << "AtomicMax expects at least 2 arguments (src, dst), got " << args.size(); ObjectPtr node = tvm::ffi::make_object(); + std::vector access_regions; if (IsBufferLikeExpr(args[0])) { - auto region = NormalizeToBufferRegion(args[0]); - node->src = region->buffer; - node->src_range = region->region; + auto src_access = NormalizeToAccessRegion(args[0], kAccessRead); + node->src = src_access.region->buffer; + node->src_range = src_access.region->region; + access_regions.push_back(std::move(src_access)); } else { node->src_value = args[0]; } - auto region = NormalizeToBufferRegion(args[1]); - node->dst = region->buffer; - node->dst_range = region->region; + auto dst_access = NormalizeToAccessRegion(args[1], kAccessReadWrite); + dst_access.access_mask = kAccessReadWrite; + node->dst = dst_access.region->buffer; + node->dst_range = dst_access.region->region; + access_regions.push_back(std::move(dst_access)); + node->SetAccessRegions(std::move(access_regions)); node->annotations = annotations; data_ = std::move(node); @@ -67,18 +72,23 @@ AtomicMin::AtomicMin(Array args, Map annotations) { << "AtomicMin expects at least 2 arguments (src, dst), got " << args.size(); ObjectPtr node = tvm::ffi::make_object(); + std::vector access_regions; if (IsBufferLikeExpr(args[0])) { - auto region = NormalizeToBufferRegion(args[0]); - node->src = region->buffer; - node->src_range = region->region; + auto src_access = NormalizeToAccessRegion(args[0], kAccessRead); + node->src = src_access.region->buffer; + node->src_range = src_access.region->region; + access_regions.push_back(std::move(src_access)); } else { node->src_value = args[0]; } - auto region = NormalizeToBufferRegion(args[1]); - node->dst = region->buffer; - node->dst_range = region->region; + auto dst_access = NormalizeToAccessRegion(args[1], kAccessReadWrite); + dst_access.access_mask = kAccessReadWrite; + node->dst = dst_access.region->buffer; + node->dst_range = dst_access.region->region; + access_regions.push_back(std::move(dst_access)); + node->SetAccessRegions(std::move(access_regions)); node->annotations = annotations; data_ = std::move(node); diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 59b726c51a..662f945878 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -17,11 +17,11 @@ namespace tvm { namespace tl { TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool); -TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool); diff --git a/src/op/builtin.h b/src/op/builtin.h index 0d8a570253..6268528fac 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -34,6 +34,16 @@ static constexpr const char *kLoopPreferAsync = "parallel_prefer_async"; // for injected cp.async in this parallel loop subtree. Value should be Bool. static constexpr const char *kParallelAsyncWithoutAsyncCommitWait = "parallel_async_without_async_commit_wait"; +// Copy-op annotation key controlling whether cp.async commit/wait are managed +// by an enclosing transform (e.g. software pipeline / warp specialization). +// Value should be IntImm/Bool-like truthy scalar. +static constexpr const char *kAsyncCopyNoImplicitCommitWait = + "no_implicit_async_commit_wait"; +// Tile-op annotation key carrying an explicit mbarrier parity expression. +// Pipeline transforms set this on ops whose lowering would otherwise infer +// parity from surrounding loop context. +static constexpr const char *kPipelineMbarPhaseExpr = + "tl.pipeline_mbar_phase_expr"; static constexpr const char *kLocalVarInit = "tl.local_var_init"; // A PrimFunc-level attribute carrying a list of handle Vars // that must NOT be marked with the restrict qualifier in codegen. @@ -46,9 +56,21 @@ static constexpr const char *kNonRestrictParams = "tl.non_restrict_params"; static constexpr const char *kMinBlocksPerSM = "tl.min_blocks_per_sm"; } // namespace attr +inline Optional +GetAnnotatedMbarPhaseExpr(const Map &annotations) { + if (auto val = annotations.Get(attr::kPipelineMbarPhaseExpr)) { + if (val.value()->IsInstance()) { + return Downcast(val.value()); + } + LOG(FATAL) << "Annotation `" << attr::kPipelineMbarPhaseExpr + << "` expects a PrimExpr value, but got " + << val.value().GetTypeKey(); + } + return Optional(); +} + static constexpr const char *kDebugMergeSharedMemoryAllocations = "tl.debug_merge_shared_memory_allocations"; -static constexpr const char *kDisableTMALower = "tl.disable_tma_lower"; // PrimFunc attribute: set by LowerTileOp to indicate TMA operations were // actually generated. Read by OptimizeForTarget to pick the right pipeline. static constexpr const char *kHasTMA = "tl.has_tma"; @@ -57,6 +79,9 @@ static constexpr const char *kDisableSafeMemoryLegalize = static constexpr const char *kDisableWarpSpecialized = "tl.disable_warp_specialized"; static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth"; +// Deprecated compatibility-only pass config. It is no longer consumed by the +// lowering pipeline, but remains registered so legacy kernels keep working. +static constexpr const char *kDisableTMALower = "tl.disable_tma_lower"; static constexpr const char *kEnableAggressiveSharedMemoryMerge = "tl.enable_aggressive_shared_memory_merge"; static constexpr const char *kDisableFastMath = "tl.disable_fast_math"; diff --git a/src/op/copy.cc b/src/op/copy.cc index 8c4bf99a8f..0669e5a37e 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -14,6 +14,7 @@ #include "../transform/common/loop_fusion_utils.h" #include "../transform/loop_partition.h" #include "../transform/loop_vectorize.h" +#include "../transform/ptx_async_copy_injector.h" #include "utils.h" #include "builtin.h" @@ -29,6 +30,24 @@ using namespace tir; namespace { +/// Build a TMA leader-thread condition using tl_shuffle_elect. +/// \param thread_extent The number of threads in the current group +/// (e.g., full block extent for non-WS, producer_extent for WS). +/// The elected thread will be the first lane of the first warp in +/// the group. +static PrimExpr MakeTmaLeaderCondition(PrimExpr thread_extent) { + return Call(DataType::Bool(), tl_shuffle_elect(), {std::move(thread_extent)}); +} + +PrimExpr GetCopyMbarPhaseExpr(const Map &annotations, + const LowerArgs &T) { + PrimExpr phase = T.mbar_phase_expr; + if (auto explicit_phase = GetAnnotatedMbarPhaseExpr(annotations)) { + phase = explicit_phase.value(); + } + return phase; +} + // Rewrite scalar global->shared stores into ptx_cp_async calls. // This rewriter is applied before the global vectorize pass, so each generated // cp.async call starts with element-wise bytes and can be widened later. @@ -36,7 +55,9 @@ class CPAsyncStoreRewriter : public StmtMutator { public: Stmt Rewrite(const Stmt &stmt) { return VisitStmt(stmt); } - bool RewriteSuccess() const { return successfully_rewritten_; } + bool RewriteSuccess() const { + return rewritten_any_store_ && !failed_on_shared_store_; + } private: static bool IsZeroValue(const PrimExpr &e) { @@ -82,7 +103,6 @@ class CPAsyncStoreRewriter : public StmtMutator { Stmt VisitStmt_(const BufferStoreNode *op) final { if (!IsSharedBuffer(op->buffer)) { - successfully_rewritten_ = false; return StmtMutator::VisitStmt_(op); } @@ -92,19 +112,19 @@ class CPAsyncStoreRewriter : public StmtMutator { // combined so the generated cp.async is only issued when all guards hold. const BufferLoadNode *load = MatchZeroFillBufferLoad(op->value, &predicate); if (load == nullptr) { - successfully_rewritten_ = false; + failed_on_shared_store_ = true; return StmtMutator::VisitStmt_(op); } if (!IsGlobalBuffer(load->buffer)) { - successfully_rewritten_ = false; + failed_on_shared_store_ = true; return StmtMutator::VisitStmt_(op); } int bytes = op->value.dtype().bytes(); int vectorized_lanes = current_vectorized_lanes_; if (!IsValidCPAsyncTransferBytes(bytes * vectorized_lanes)) { - successfully_rewritten_ = false; + failed_on_shared_store_ = true; return StmtMutator::VisitStmt_(op); } @@ -129,7 +149,7 @@ class CPAsyncStoreRewriter : public StmtMutator { if (predicate.defined()) { args.push_back(predicate.value()); } - successfully_rewritten_ = true; + rewritten_any_store_ = true; return Evaluate(Call(DataType::Handle(), builtin::ptx_cp_async(), args)); } @@ -156,7 +176,8 @@ class CPAsyncStoreRewriter : public StmtMutator { return stmt; } - bool successfully_rewritten_ = true; + bool rewritten_any_store_ = false; + bool failed_on_shared_store_ = false; int current_vectorized_lanes_ = 1; }; @@ -168,15 +189,13 @@ class CPAsyncStoreRewriter : public StmtMutator { // etc. Copy::Copy(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); - Array rgs[2]; - Buffer bf[2]; - for (int i = 0; i < 2; i++) { - auto region = NormalizeToBufferRegion(args[i]); - rgs[i] = region->region; - bf[i] = region->buffer; - } - std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); - std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + auto src_access = NormalizeToAccessRegion(args[0], kAccessRead); + auto dst_access = NormalizeToAccessRegion(args[1], kAccessWrite); + node->src = src_access.region->buffer; + node->dst = dst_access.region->buffer; + node->src_range = src_access.region->region; + node->dst_range = dst_access.region->region; + node->SetAccessRegions({src_access, dst_access}); // Copy annotations from the Call node node->annotations = annotations; data_ = std::move(node); @@ -423,10 +442,6 @@ Layout CopyNode::ComputeLinearLayout(const Buffer &shared_tensor) const { LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { auto target = T.target; - using namespace tvm::transform; - PassContext pass_ctx = PassContext::Current(); - bool disable_tma_lower = - pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); CopyInst copy_inst; if (GetIsAsyncCopy()) { // Layout inference does not require a full cp.async legality proof (which @@ -450,9 +465,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, } copy_inst = CopyInst::kCPAsync; } else { - copy_inst = - GetCopyInst(target, disable_tma_lower || GetDisableTMA(), T.layout_map, - T.analyzer, T.buffer_oob, T.in_pipeline); + copy_inst = GetCopyInst(target, T.layout_map, T.analyzer, T.buffer_oob); } // If user annotated a loop layout on T.copy, enforce SIMT (normal) copy. @@ -620,6 +633,53 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, auto layout_map = par_op_->InferLayout(T, level); return layout_map; } +// Shared stride validation for TMA bulk load/store. +bool CopyNode::CheckGlobalStrides(const Buffer &buffer, + arith::Analyzer *analyzer) { + Array strides = buffer->strides; + if (strides.empty()) { + PrimExpr stride = 1; + strides.resize(buffer->shape.size()); + for (int i = static_cast(buffer->shape.size()) - 1; i >= 0; --i) { + strides.Set(i, stride); + stride *= buffer->shape[i]; + } + } + + if (!strides.empty() && + analyzer->CanProve(strides[strides.size() - 1] != 1, + arith::ProofStrength::kSymbolicBound)) { + LOG(WARNING) << "TMA bulk copy requires contiguous innermost global stride" + << ", but got " << strides[strides.size() - 1] + << " for buffer " << buffer->name + << ", fallback to normal copy."; + return false; + } + + for (size_t i = 0; i + 1 < strides.size(); ++i) { + PrimExpr stride_bytes = + cast(DataType::Int(64), strides[i]) * buffer->dtype.bytes(); + if (analyzer->CanProve( + FloorMod(stride_bytes, IntImm(DataType::Int(64), 16)) != 0, + arith::ProofStrength::kSymbolicBound)) { + LOG(WARNING) << "TMA bulk copy cannot support a global stride of " + << stride_bytes << " for buffer " << buffer->name + << ", fallback to normal copy."; + return false; + } + if (const int64_t *stride = + as_const_int(analyzer->Simplify(stride_bytes))) { + if (*stride >= (int64_t{1} << 40)) { + LOG(WARNING) << "TMA bulk copy cannot support a global stride of " + << stride_bytes << " for buffer " << buffer->name + << ", fallback to normal copy."; + return false; + } + } + } + return true; +} + // Checks if this copy can be lowered to a Bulk Load (TMA) instruction. // Requires: TMA support, global->shared scope, matching dtypes. bool CopyNode::CheckBulkLoad(Target target, arith::Analyzer *analyzer, @@ -654,6 +714,8 @@ bool CopyNode::CheckBulkLoad(Target target, arith::Analyzer *analyzer, << " vs. " << dst->dtype << " will be fallback to normal copy"; return false; } + if (!CheckGlobalStrides(src, analyzer)) + return false; return true; } @@ -765,6 +827,8 @@ bool CopyNode::CheckBulkStore(Target target, arith::Analyzer *analyzer, << " vs. " << dst->dtype << " will be fallback to normal copy"; return false; } + if (!CheckGlobalStrides(dst, analyzer)) + return false; return true; } @@ -804,15 +868,33 @@ bool CopyNode::CheckTMemStore(Target target) const { // - vectorized copy width (bytes) is one of {4, 8, 16} // - if OOB guards are required, only a *uniform* (scalar) source predicate // is supported (dst must be in-bounds) +bool CopyNode::CheckCPAsyncCopyPreconditions() const { + if (!IsGlobalBuffer(src) || !IsSharedBuffer(dst)) { + return false; + } + if (src->dtype != dst->dtype) { + return false; + } + return true; +} + +bool CopyNode::CheckPipelineManagedCPAsyncCopy() const { + return !GetIsTmaCopy() && !GetIsAsyncCopy() && + CheckCPAsyncCopyPreconditions(); +} + +bool CopyNode::CheckPipelineManagedCPAsyncCopy( + Target target, arith::Analyzer *analyzer) const { + return CheckPipelineManagedCPAsyncCopy() && + CheckCPAsyncCopy(target, LayoutMap(), analyzer); +} + bool CopyNode::CheckCPAsyncCopy(Target target, const LayoutMap &layout_map, arith::Analyzer *analyzer) const { if (!TargetHasAsyncCopy(target)) { return false; } - if (!IsGlobalBuffer(src) || !IsSharedBuffer(dst)) { - return false; - } - if (src->dtype != dst->dtype) { + if (!CheckCPAsyncCopyPreconditions()) { return false; } // Skip vectorize size check here because, during the Infer Layout stage, @@ -823,14 +905,9 @@ bool CopyNode::CheckCPAsyncCopy(Target target, const LayoutMap &layout_map, // Selects the most specific copy instruction for the given target and buffers. // Priority: BulkLoad1D, BulkStore1D, BulkLoad, BulkStore, LDSM, STSM, // TMemLoad, TMemStore, CPAsync, Normal. -CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, - const LayoutMap &layout_map, - arith::Analyzer *analyzer, bool buffer_oob, - bool in_pipeline) const { - // disable_tma_lower is from pass_configs - // when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True, - // we will not use tma for bulk load/store - +CopyInst CopyNode::GetCopyInst(Target target, const LayoutMap &layout_map, + arith::Analyzer *analyzer, + bool buffer_oob) const { // When is_tma_copy is set (from T.tma_copy()), force TMA path. if (GetIsTmaCopy()) { // Check if target is CuTeDSL backend @@ -853,38 +930,35 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, } } - // Check if target is CuTeDSL backend - bool is_cutedsl = TargetIsCuTeDSL(target); bool is_async_copy = GetIsAsyncCopy(); + bool no_implicit_commit_wait = GetNoImplicitAsyncCommitWait(); - if (is_async_copy) { + if (is_async_copy || no_implicit_commit_wait) { bool cp_async_supported = CheckCPAsyncCopy(target, layout_map, analyzer); ICHECK(cp_async_supported) - << "T.async_copy must lower to cp.async, but constraints were not " - "satisfied. Got src=" + << "Explicit async copy semantics require cp.async lowering, but " + "constraints were not satisfied. Got src=" << src->name << " (scope=" << src.scope() << ", dtype=" << src->dtype << "), dst=" << dst->name << " (scope=" << dst.scope() << ", dtype=" << dst->dtype << ")."; return CopyInst::kCPAsync; } + // Plain T.copy does not auto-upgrade to TMA loads anymore. Store-side TMA + // remains allowed because it is self-synchronized locally and does not + // participate in pipeline producer scheduling. + if (!GetDisableTMA()) { + bool is_cutedsl = TargetIsCuTeDSL(target); + if (!is_cutedsl && !buffer_oob && + CheckBulkStore1D(target, layout_map, analyzer)) { + return CopyInst::kBulkStore1D; + } else if (CheckBulkStore(target, analyzer)) { + return CopyInst::kBulkStore; + } + } + // Check tensor memory operations first (highest priority for SM100/Blackwell) - // 1d tma access can not support out of bound access - // NOTE: Skip BulkLoad1D/BulkStore1D for CuTeDSL backend because - // cp_async_bulk_shared_cluster_global (raw 1D TMA) combined with WGMMA - // in the same kernel triggers a ptxas ICE in the NVPTX backend. - // Falling through to descriptor-based BulkLoad/BulkStore avoids this. - if (!is_cutedsl && !disable_tma_lower && !buffer_oob && - CheckBulkLoad1D(target, layout_map, analyzer)) { - return CopyInst::kBulkLoad1D; - } else if (!is_cutedsl && !disable_tma_lower && !buffer_oob && - CheckBulkStore1D(target, layout_map, analyzer)) { - return CopyInst::kBulkStore1D; - } else if (!disable_tma_lower && CheckBulkLoad(target, analyzer)) { - return CopyInst::kBulkLoad; - } else if (!disable_tma_lower && CheckBulkStore(target, analyzer)) { - return CopyInst::kBulkStore; - } else if (CheckLDSMCopy(target)) { + if (CheckLDSMCopy(target)) { return CopyInst::kLDSM; } else if (CheckSTSMCopy(target)) { return CopyInst::kSTSM; @@ -892,15 +966,6 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, return CopyInst::kTMemLoad; } else if (CheckTMemStore(target)) { return CopyInst::kTMemStore; - } else if (in_pipeline) { - using namespace tvm::transform; - PassContext pass_ctx = PassContext::Current(); - bool enable_async_copy = - pass_ctx->GetConfig(kEnableAsyncCopy, Bool(true)).value(); - if (enable_async_copy && CheckCPAsyncCopy(target, layout_map, analyzer)) { - return CopyInst::kCPAsync; - } - return CopyInst::kNormal; } else { return CopyInst::kNormal; } @@ -910,14 +975,8 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, // functions. Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Target target = T.target; - - using namespace tvm::transform; - PassContext pass_ctx = PassContext::Current(); - bool disable_tma_lower = - pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); - auto copy_inst = GetCopyInst(target, disable_tma_lower || GetDisableTMA(), - T.layout_map, analyzer, /*buffer_oob=*/false, - /*in_pipeline=*/T.in_pipeline); + auto copy_inst = + GetCopyInst(target, T.layout_map, analyzer, /*buffer_oob=*/false); if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { auto tmem_copy = LowerTmemCopy(T, analyzer); ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy"; @@ -948,16 +1007,20 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } // Lowers copy to cp.async global->shared transfers. -// - T.copy (auto cp.async) keeps synchronous semantics by committing and -// waiting after the loop. +// - T.copy annotated for cp.async keeps synchronous semantics by committing +// and waiting after the loop. // - T.async_copy commits but does not wait (explicit async semantics). +// - Copies annotated with kAsyncCopyNoImplicitCommitWait emit only cp.async; +// an enclosing pass is responsible for commit/wait placement. Stmt CopyNode::LowerCPAsyncCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { using namespace tvm::transform; PassContext pass_ctx = PassContext::Current(); bool enable_async_copy = pass_ctx->GetConfig(kEnableAsyncCopy, Bool(true)).value(); - if ((!enable_async_copy || !T.in_pipeline) && !GetIsAsyncCopy()) { + bool no_implicit_commit_wait = GetNoImplicitAsyncCommitWait(); + bool explicit_async_semantics = no_implicit_commit_wait || GetIsAsyncCopy(); + if (!enable_async_copy && !explicit_async_semantics) { return LowerNormalCopy(T, analyzer); } @@ -982,25 +1045,43 @@ Stmt CopyNode::LowerCPAsyncCopy(const LowerArgs &T, LowerParallelLoop(par_op->GetRoot(), loop_layout, T.thread_var, analyzer, T.layout_map, par_op->GetPredicate(T.thread_var)); - CPAsyncStoreRewriter cp_async_rewriter; - Stmt cp_async_loop = cp_async_rewriter.Rewrite(lowered_loop); - if (!cp_async_rewriter.RewriteSuccess()) { - if (GetIsAsyncCopy()) { - LOG(FATAL) << "T.async_copy cannot be lowered to cp.async: no eligible " - "global->shared store was rewritten."; + bool async_without_implicit_commit_wait = + no_implicit_commit_wait || GetIsAsyncCopy(); + auto inject_result = + InjectPTXAsyncCopy(lowered_loop, /*enable_auto_async_copy=*/true, + async_without_implicit_commit_wait); + Stmt cp_async_loop = inject_result.stmt; + if (!inject_result.injected_ptx_async_copy) { + LOG(WARNING) << "cp.async rewrite miss for copy src=" << src->name + << " (scope=" << src.scope() << ", dtype=" << src->dtype + << "), dst=" << dst->name << " (scope=" << dst.scope() + << ", dtype=" << dst->dtype + << "), no_implicit_async_commit_wait=" + << no_implicit_commit_wait + << ", is_async_copy=" << GetIsAsyncCopy(); + if (no_implicit_commit_wait) { + LOG(WARNING) + << "Pipeline-managed async copy fallback to normal copy because " + "cp.async rewrite found no eligible global->shared store."; + return lowered_loop; + } + if (explicit_async_semantics) { + LOG(FATAL) << "Explicit async copy semantics require cp.async lowering, " + "but no eligible global->shared store was rewritten."; } LOG(WARNING) << "Fallback to normal copy because cp.async rewrite found " "no eligible global->shared store."; return LowerNormalCopy(T, analyzer); } - Stmt commit_group = - Evaluate(Call(DataType::Handle(), builtin::ptx_commit_group(), {})); + if (no_implicit_commit_wait) { + return cp_async_loop; + } if (GetIsAsyncCopy()) { + Stmt commit_group = + Evaluate(Call(DataType::Handle(), builtin::ptx_commit_group(), {})); return SeqStmt({cp_async_loop, commit_group}); } - Stmt wait_group = Evaluate(Call(DataType::Handle(), builtin::ptx_wait_group(), - {IntImm(DataType::Int(32), 0)})); - return SeqStmt({cp_async_loop, commit_group, wait_group}); + return cp_async_loop; } // Lowers the copy using standard load/store with loop transformations. @@ -1613,14 +1694,13 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, ICHECK(stride != nullptr && continuous != nullptr); // We also need to check if the shape satisfies the following doc: // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - if (StructuralEqual()(shared_layout, makeQuarterBankSwizzleLayout( - shared_tensor_unmapped))) { + SwizzleMode swizzle_mode = + DetectSwizzleMode(shared_layout, shared_tensor_unmapped); + if (swizzle_mode == SwizzleMode::kQuarter) { desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); - } else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout( - shared_tensor_unmapped))) { + } else if (swizzle_mode == SwizzleMode::kHalf) { desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); - } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout( - shared_tensor_unmapped))) { + } else if (swizzle_mode == SwizzleMode::kFull) { desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); } else if (StructuralEqual()( shared_layout, @@ -1708,7 +1788,7 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, << "Use T.tma_copy(src, dst, barrier=mbar[idx])."; } else if (T.AllocMBarrier) { // Internal mbarrier (T.copy()): allocate a single barrier slot. - // MultiVersionBuffer will expand it for pipelining stages. + // Pipeline buffer versioning expands it per stage when needed. barrier_base_id = T.AllocMBarrier(1); PrimExpr mbar_idx = IntImm(DataType::Int(32), barrier_base_id); mbar_handle = BufferLoad(T.mbarrier_buffer->value(), {mbar_idx}); @@ -1822,6 +1902,15 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, Evaluate(Call(DataType::Handle(), mbarrier_expect_tx(), {mbar_handle, total_bytes})); } + // When emit_arrive is set (by InjectSoftwarePipeline for pipeline-level + // barrier management), also emit arrive inside the thread-0 guard. + if (auto emit_arrive_val = annotations.Get("emit_arrive")) { + if (Downcast(emit_arrive_val.value())->value != 0) { + barrier_after_tma_stmt = + Evaluate(Call(DataType::Handle(), builtin::ptx_arrive_barrier(), + {mbar_handle})); + } + } } else { // T.copy() with TMA: keep expect_tx and arrive as separate control ops. // This lets downstream WS/barrier passes reason about the arrival @@ -1839,14 +1928,9 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, } // Thread-gated block: expect_tx + tma_load (+ optional arrive) - Stmt producer = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), + Stmt producer = IfThenElse(MakeTmaLeaderCondition(T.thread_bounds->extent), SeqStmt(producer_seq)); - // Annotate the producer with the shared buffer it writes to. - // PipelinePlanning uses this to identify TMA copy stages. - producer = AttrStmt(shared_tensor->data, "tl.tma_copy_write_buffer", - IntImm(DataType::Int(32), 1), producer); - // tma_copy (from T.tma_copy()) is fire-and-forget: only emit the // producer (expect_tx + tma_load). The user manages synchronization // (arrive + wait) explicitly. @@ -1856,13 +1940,15 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, // For T.copy() with TMA: emit producer + wait pair so the pipeline/WS // passes can split them into different stages. - Stmt wait_stmt = Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(), - {mbar_handle, T.mbar_phase_expr})); + Stmt wait_stmt = + Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(), + {mbar_handle, GetCopyMbarPhaseExpr(annotations, T)})); return SeqStmt({producer, wait_stmt}); } - tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); + tma_copy = + IfThenElse(MakeTmaLeaderCondition(T.thread_bounds->extent), tma_copy); return tma_copy; } @@ -1940,7 +2026,7 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, << "Use T.tma_copy(src, dst, barrier=mbar[idx])."; } else if (T.AllocMBarrier) { // Internal mbarrier (T.copy()): allocate a single barrier slot. - // MultiVersionBuffer will expand it for pipelining stages. + // Pipeline buffer versioning expands it per stage when needed. barrier_base_id = T.AllocMBarrier(1); PrimExpr mbar_idx = IntImm(DataType::Int(32), barrier_base_id); mbar_handle = BufferLoad(T.mbarrier_buffer->value(), {mbar_idx}); @@ -1979,7 +2065,6 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, // For 1D TMA loads with inline mbarrier: emit expect_tx + tma_load // (inside thread-gated block), and wait_parity after (all threads). - // The producer is annotated with the shared buffer for pipeline detection. if (is_load && barrier_base_id >= 0) { Stmt barrier_before_tma_stmt; Optional barrier_after_tma_stmt = std::nullopt; @@ -2004,13 +2089,9 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, producer_seq.push_back(barrier_after_tma_stmt.value()); } - Stmt producer = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), + Stmt producer = IfThenElse(MakeTmaLeaderCondition(T.thread_bounds->extent), SeqStmt(producer_seq)); - // Annotate the producer with the shared buffer it writes to. - producer = AttrStmt(shared_tensor->data, "tl.tma_copy_write_buffer", - IntImm(DataType::Int(32), 1), producer); - // tma_copy (from T.tma_copy()) is fire-and-forget: only emit the // producer (expect_tx + tma_load). The user manages synchronization // (arrive + wait) explicitly. @@ -2020,13 +2101,15 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, // For T.copy() with TMA: emit producer + wait pair so the pipeline/WS // passes can split them into different stages. - Stmt wait_stmt = Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(), - {mbar_handle, T.mbar_phase_expr})); + Stmt wait_stmt = + Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(), + {mbar_handle, GetCopyMbarPhaseExpr(annotations, T)})); return SeqStmt({producer, wait_stmt}); } - tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); + tma_copy = + IfThenElse(MakeTmaLeaderCondition(T.thread_bounds->extent), tma_copy); return tma_copy; } // Encodes the TMA descriptor into an array of PrimExpr for @@ -2061,8 +2144,11 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); - node->srcRegion_ = NormalizeToBufferRegion(args[0]); - node->dstRegion_ = NormalizeToBufferRegion(args[1]); + auto src_access = NormalizeToAccessRegion(args[0], kAccessRead); + auto dst_access = NormalizeToAccessRegion(args[1], kAccessWrite); + node->srcRegion_ = src_access.region; + node->dstRegion_ = dst_access.region; + node->SetAccessRegions({src_access, dst_access}); node->src_ = node->srcRegion_->buffer; node->dst_ = node->dstRegion_->buffer; node->nhw_step_ = args[2]; @@ -2072,6 +2158,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, node->dilation_ = args[6].as().value()->value; node->padding_ = args[7].as().value()->value; node->eviction_policy_ = args[8].as().value()->value; + node->annotations_ = annotations; data_ = std::move(node); } @@ -2087,8 +2174,14 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, ICHECK(TargetIsHopper(T.target)); ICHECK(IsGlobalBuffer(src_) && IsSharedBuffer(dst_)); ICHECK(src_->shape.size() == 4); - ICHECK(dst_->shape.size() == 2); ICHECK(src_->dtype == dst_->dtype); + + // Use dstRegion_ to derive tile dimensions and shared memory offset. + // dstRegion_ always has the correct ranges regardless of whether MVB + // added a leading stage dimension to the buffer — the last two ranges + // give the tile (pixel, channel) extents and mins. + size_t ndim = dstRegion_->region.size(); + ICHECK(ndim >= 2) << "im2col dstRegion must have at least 2 dims"; Layout shared_layout; if (T.layout_map.count(dst_)) { shared_layout = T.layout_map[dst_]; @@ -2120,8 +2213,10 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, desc.elem_stride = {1, stride_, stride_, 1}; desc.lower_corner = {-padding_, -padding_}; desc.upper_corner = {-padding_, -padding_}; - desc.smem_box_pixel = Downcast(dst_->shape[0])->value; - desc.smem_box_channel = Downcast(dst_->shape[1])->value; + desc.smem_box_pixel = + Downcast(dstRegion_->region[ndim - 2]->extent)->value; + desc.smem_box_channel = + Downcast(dstRegion_->region[ndim - 1]->extent)->value; desc.l2_promotion = static_cast(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); desc.oob_fill = static_cast(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); desc.interleave = static_cast(CU_TENSOR_MAP_INTERLEAVE_NONE); @@ -2181,11 +2276,16 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, // Allocate mbarrier(s) for TMA im2col load synchronization, // matching the protocol used by regular TMA loads. + // If a barrier was provided by the WS pass (via annotation), use it directly. int barrier_base_id = -1; PrimExpr mbar_handle; - if (T.AllocMBarrier) { - // Allocate a single barrier slot; MultiVersionBuffer will - // expand it for pipelining stages. + if (auto user_barrier = annotations_.Get("barrier")) { + // WS pass provided a barrier: use it without allocating a new one. + mbar_handle = Downcast(user_barrier.value()); + barrier_base_id = 0; + } else if (T.AllocMBarrier) { + // Allocate a single barrier slot; pipeline buffer versioning expands it + // per stage when needed. barrier_base_id = T.AllocMBarrier(1); PrimExpr mbar_idx = IntImm(DataType::Int(32), barrier_base_id); mbar_handle = BufferLoad(T.mbarrier_buffer->value(), {mbar_idx}); @@ -2196,7 +2296,22 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, args.push_back(create_desc); args.push_back(barrier_base_id >= 0 ? mbar_handle : PrimExpr(0)); auto dst_buffer = T.buffer_remap.count(dst_) ? T.buffer_remap[dst_] : dst_; - auto shared_addr = dst_buffer.access_ptr(2); + // Compute flat element offset from dstRegion_ mins and buffer strides. + // For a plain 2D buffer this is 0; for a versioned 3D buffer this + // resolves to stage_idx * pixel * channel — no special-casing needed. + PrimExpr flat_offset = IntImm(DataType::Int(32), 0); + { + PrimExpr stride = IntImm(DataType::Int(32), 1); + for (int i = static_cast(ndim) - 1; i >= 0; --i) { + flat_offset = flat_offset + dstRegion_->region[i]->min * stride; + stride = stride * dst_->shape[i]; + } + } + PrimExpr tile_elems = + IntImm(DataType::Int(32), desc.smem_box_pixel * desc.smem_box_channel); + PrimExpr shared_addr = dst_buffer.access_ptr( + /*access_mask=*/2, /*dtype=*/DataType::Handle(), /*content_lanes=*/1, + /*offset=*/flat_offset, /*extent=*/tile_elems); args.push_back(shared_addr); for (auto coord : global_coords) args.push_back(coord); @@ -2207,6 +2322,7 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)); if (barrier_base_id >= 0) { + bool ws_barrier = annotations_.Get("barrier").has_value(); // Total bytes transferred by im2col TMA copy PrimExpr total_bytes = IntImm(DataType::Int(32), desc.smem_box_pixel * desc.smem_box_channel * @@ -2214,28 +2330,42 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, Stmt barrier_before_tma_stmt = Evaluate(Call( DataType::Handle(), mbarrier_expect_tx(), {mbar_handle, total_bytes})); + + if (ws_barrier) { + // External barrier (WS pass or InjectSoftwarePipeline). + // Build: expect_tx + tma_load [+ arrive if emit_arrive is set]. + Array producer_seq{barrier_before_tma_stmt, tma_copy_stmt}; + if (auto emit_arrive_val = annotations_.Get("emit_arrive")) { + if (Downcast(emit_arrive_val.value())->value != 0) { + producer_seq.push_back( + Evaluate(Call(DataType::Handle(), builtin::ptx_arrive_barrier(), + {mbar_handle}))); + } + } + Stmt producer = + IfThenElse(MakeTmaLeaderCondition(T.thread_bounds->extent), + SeqStmt(producer_seq)); + return producer; + } + Stmt barrier_after_tma_stmt = Evaluate( Call(DataType::Handle(), builtin::ptx_arrive_barrier(), {mbar_handle})); // Thread-gated block: expect_tx + tma_load_im2col + arrive - Stmt producer = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), + Stmt producer = IfThenElse(MakeTmaLeaderCondition(T.thread_bounds->extent), SeqStmt({barrier_before_tma_stmt, tma_copy_stmt, barrier_after_tma_stmt})); - // Annotate the producer with the shared buffer it writes to. - // PipelinePlanning uses this to identify TMA copy stages. - producer = AttrStmt(dst_buffer->data, "tl.tma_copy_write_buffer", - IntImm(DataType::Int(32), 1), producer); - // Emit producer + wait pair for pipeline/WS passes. - Stmt wait_stmt = Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(), - {mbar_handle, T.mbar_phase_expr})); + Stmt wait_stmt = + Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(), + {mbar_handle, GetCopyMbarPhaseExpr(annotations_, T)})); return SeqStmt({producer, wait_stmt}); } - Stmt tma_copy = - IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy_stmt); + Stmt tma_copy = IfThenElse(MakeTmaLeaderCondition(T.thread_bounds->extent), + tma_copy_stmt); return tma_copy; } diff --git a/src/op/copy.h b/src/op/copy.h index 2f3a630702..d20f519815 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -6,6 +6,7 @@ #ifndef TVM_TL_OP_COPY_H_ #define TVM_TL_OP_COPY_H_ +#include "builtin.h" #include "operator.h" #include "parallel.h" @@ -124,6 +125,8 @@ class CopyNode : public TileOperatorNode { // - "disable_tma": Bool, whether to disable TMA acceleration // - "eviction_policy": IntImm, cache eviction policy (0=normal, 1=first, // 2=last) + // - attr::kAsyncCopyNoImplicitCommitWait: IntImm/Bool, suppress implicit + // cp.async commit/wait because an enclosing transform manages them // - attr::kParallelLoopLayout ("parallel_loop_layout"): Fragment, loop // layout hint applied to the outermost generated parallel loop of this // copy's SIMT loop nest. @@ -185,6 +188,15 @@ class CopyNode : public TileOperatorNode { return false; } + bool GetNoImplicitAsyncCommitWait() const { + if (auto val = annotations.Get(attr::kAsyncCopyNoImplicitCommitWait)) { + if (auto int_val = val->as()) { + return int_val->value != 0; + } + } + return false; + } + /*! * \brief Lower the copy operator to a TIR statement. * \param T Arguments for lowering. @@ -253,20 +265,38 @@ class CopyNode : public TileOperatorNode { */ bool CheckTMemStore(Target target) const; + /*! + * \brief Check target-independent cp.async prerequisites. + */ + bool CheckCPAsyncCopyPreconditions() const; + + /*! + * \brief Check whether this copy can participate in pipeline-managed + * cp.async synchronization using only target-independent prerequisites. + */ + bool CheckPipelineManagedCPAsyncCopy() const; + + /*! + * \brief Check whether this copy can participate in pipeline-managed + * cp.async synchronization for a concrete target. + */ + bool CheckPipelineManagedCPAsyncCopy(Target target, + arith::Analyzer *analyzer) const; + /*! * \brief Check if cp.async copy is supported. */ bool CheckCPAsyncCopy(Target target, const LayoutMap &layout_map, arith::Analyzer *analyzer) const; +protected: /*! * \brief Get the copy instruction type. */ - CopyInst GetCopyInst(Target target, bool disable_tma_lower, - const LayoutMap &layout_map, arith::Analyzer *analyzer, - bool buffer_oob = false, bool in_pipeline = false) const; + CopyInst GetCopyInst(Target target, const LayoutMap &layout_map, + arith::Analyzer *analyzer, + bool buffer_oob = false) const; -protected: /*! * \brief Generate lowering for bulk/global-to-shared copy. */ @@ -353,6 +383,17 @@ class CopyNode : public TileOperatorNode { */ TileOperator Clone() const; + /*! + * \brief Check that a global buffer's strides satisfy TMA requirements. + * + * Validates: contiguous innermost stride, 16-byte alignment for outer + * strides, and stride < 2^40. + * + * \return true if all stride checks pass. + */ + static bool CheckGlobalStrides(const Buffer &buffer, + arith::Analyzer *analyzer); + private: /*! * \brief Collect fragment buffers from expression and create fully replicated @@ -409,9 +450,10 @@ class Conv2DIm2ColOpNode : public TileOperatorNode { int padding_; // Padding amount int dilation_; // Dilation factor int kernel_; // Kernel size - int eviction_policy_; // Cache eviction policy - PrimExpr nhw_step_; // Step size in NHW dimensions - PrimExpr c_step_; // Step size in channel dimension + int eviction_policy_; // Cache eviction policy + PrimExpr nhw_step_; // Step size in NHW dimensions + PrimExpr c_step_; // Step size in channel dimension + Map annotations_; // Annotations from Call node TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode, TileOperatorNode); diff --git a/src/op/fill.cc b/src/op/fill.cc index 7eb22daa86..be0dd8dc10 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -60,9 +60,10 @@ using namespace tir; Fill::Fill(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); - BufferRegion region = NormalizeToBufferRegion(args[0]); - node->dst = region->buffer; - node->region = region->region; + AccessRegion dst_access = NormalizeToAccessRegion(args[0], kAccessWrite); + node->dst = dst_access.region->buffer; + node->region = dst_access.region->region; + node->SetAccessRegions({dst_access}); if (args[1]->dtype != node->dst->dtype) { node->value = Cast(node->dst->dtype, args[1]); diff --git a/src/op/finalize_reducer.cc b/src/op/finalize_reducer.cc index f65be34176..c6e01e923f 100644 --- a/src/op/finalize_reducer.cc +++ b/src/op/finalize_reducer.cc @@ -34,11 +34,12 @@ using namespace tir; FinalizeReducerOp::FinalizeReducerOp(Array args, Map annotations) { auto node = tvm::ffi::make_object(); - // Normalize any supported region expression - // (BufferRegion/BufferLoad/tl.region) to a BufferRegion, then take the - // underlying Buffer as reducer. - auto region = NormalizeToBufferRegion(args[0]); - node->reducer = region->buffer; + auto reducer_access = NormalizeToAccessRegion(args[0], kAccessReadWrite); + reducer_access.region = + BufferRegion::FullRegion(reducer_access.region->buffer); + reducer_access.access_mask = kAccessReadWrite; + node->reducer = reducer_access.region->buffer; + node->SetAccessRegions({reducer_access}); node->op = (ReducerOpType)*as_const_int(args[1]); data_ = std::move(node); } diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 0619975758..ebe717bd44 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -53,9 +53,14 @@ using namespace tir; Gemm::Gemm(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); - node->aRegion_ = NormalizeToBufferRegion(args[0]); - node->bRegion_ = NormalizeToBufferRegion(args[1]); - node->cRegion_ = NormalizeToBufferRegion(args[2]); + auto a_access = NormalizeToAccessRegion(args[0], kAccessRead); + auto b_access = NormalizeToAccessRegion(args[1], kAccessRead); + auto c_access = NormalizeToAccessRegion(args[2], kAccessReadWrite); + + node->aRegion_ = a_access.region; + node->bRegion_ = b_access.region; + node->cRegion_ = c_access.region; + node->SetAccessRegions({a_access, b_access, c_access}); node->a_ = node->aRegion_->buffer; node->b_ = node->bRegion_->buffer; @@ -95,9 +100,21 @@ Gemm::Gemm(Array args, Map annotations) { } node->cCoords_ = Array( {args[17].as().value(), args[18].as().value()}); + node->annotations_ = annotations; data_ = std::move(node); } +AccessRegions GemmNode::GetAccessRegions() const { + AccessRegions result; + result.reads.push_back(aRegion_); + result.reads.push_back(bRegion_); + if (!is_one(clearAccum_)) { + result.reads.push_back(cRegion_); + } + result.writes.push_back(cRegion_); + return result; +} + /** * @brief Create a copy of this GemmNode as a TileOperator. * @@ -556,8 +573,12 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { if (isTcgen05_) { return tcgen5mma_call; } - Stmt wait_stmt = Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(), - {mbar_, T.mbar_phase_expr})); + PrimExpr mbar_phase = T.mbar_phase_expr; + if (auto explicit_phase = GetAnnotatedMbarPhaseExpr(annotations_)) { + mbar_phase = explicit_phase.value(); + } + Stmt wait_stmt = Evaluate( + Call(DataType::Handle(), mbarrier_wait_parity(), {mbar_, mbar_phase})); return SeqStmt({tcgen5mma_call, wait_stmt}); } diff --git a/src/op/gemm.h b/src/op/gemm.h index f8d1ae5099..523e8bafb3 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -146,6 +146,7 @@ class GemmNode : public TileOperatorNode { tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA Array cCoords_; mutable GemmWarpPolicy policy_; + Map annotations_; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode); static void RegisterReflection() { @@ -173,17 +174,20 @@ class GemmNode : public TileOperatorNode { .def_ro("isTcgen05", &GemmNode::isTcgen05_) .def_ro("mbar", &GemmNode::mbar_) .def_ro("cCoords", &GemmNode::cCoords_) - .def_ro("policy", &GemmNode::policy_); + .def_ro("policy", &GemmNode::policy_) + .def_ro("annotations", &GemmNode::annotations_); } Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + AccessRegions GetAccessRegions() const override; TileOperator Clone() const; -private: GemmInst getGemmInst(int block_size, Target target) const; + +private: bool allowTcgen5Mma(Target target) const; bool allowWgmma(int block_size, Target target) const; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index b7a22a470a..86858d7cc3 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -53,9 +53,14 @@ using namespace tir; GemmPy::GemmPy(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); - node->aRegion_ = NormalizeToBufferRegion(args[0]); - node->bRegion_ = NormalizeToBufferRegion(args[1]); - node->cRegion_ = NormalizeToBufferRegion(args[2]); + auto a_access = NormalizeToAccessRegion(args[0], kAccessRead); + auto b_access = NormalizeToAccessRegion(args[1], kAccessRead); + auto c_access = NormalizeToAccessRegion(args[2], kAccessReadWrite); + + node->aRegion_ = a_access.region; + node->bRegion_ = b_access.region; + node->cRegion_ = c_access.region; + node->SetAccessRegions({a_access, b_access, c_access}); node->a_ = node->aRegion_->buffer; node->b_ = node->bRegion_->buffer; @@ -99,6 +104,17 @@ GemmPy::GemmPy(Array args, Map annotations) { data_ = std::move(node); } +AccessRegions GemmPyNode::GetAccessRegions() const { + AccessRegions result; + result.reads.push_back(aRegion_); + result.reads.push_back(bRegion_); + if (!is_one(clearAccum_)) { + result.reads.push_back(cRegion_); + } + result.writes.push_back(cRegion_); + return result; +} + /** * @brief Create a copy of this GemmPyNode as a TileOperator. * @@ -278,10 +294,14 @@ static int GetArchInt(Target target) { Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { + PrimExpr mbar_phase = T.mbar_phase_expr; + if (auto explicit_phase = GetAnnotatedMbarPhaseExpr(annotations_)) { + mbar_phase = explicit_phase.value(); + } // NOTE(wt): Decide GemmInst and compute warp partition on Python side auto prim_func = Downcast( (*f)(tvm::ffi::GetRef(this), T.layout_map, T.target, - T.thread_bounds, T.thread_var, T.mbar_phase_expr)); + T.thread_bounds, T.thread_var, mbar_phase)); ICHECK(prim_func->attrs.defined()); auto global_symbol = prim_func->attrs.GetAttr("global_symbol"); diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h index 7aa7a4d10a..d8dcfa7554 100644 --- a/src/op/gemm_py.h +++ b/src/op/gemm_py.h @@ -74,6 +74,7 @@ class GemmPyNode : public TileOperatorNode { Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + AccessRegions GetAccessRegions() const override; TileOperator Clone() const; diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index fd6f271011..6df31aa7f1 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -86,10 +86,15 @@ GemmSPWarpPolicyNode::computeWarpPartition(int M, int N, int block_size, */ GemmSP::GemmSP(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); - node->aRegion_ = NormalizeToBufferRegion(args[0]); - node->eRegion_ = NormalizeToBufferRegion(args[1]); - node->bRegion_ = NormalizeToBufferRegion(args[2]); - node->cRegion_ = NormalizeToBufferRegion(args[3]); + auto a_access = NormalizeToAccessRegion(args[0], kAccessRead); + auto e_access = NormalizeToAccessRegion(args[1], kAccessRead); + auto b_access = NormalizeToAccessRegion(args[2], kAccessRead); + auto c_access = NormalizeToAccessRegion(args[3], kAccessReadWrite); + node->aRegion_ = a_access.region; + node->eRegion_ = e_access.region; + node->bRegion_ = b_access.region; + node->cRegion_ = c_access.region; + node->SetAccessRegions({a_access, e_access, b_access, c_access}); node->a_ = node->aRegion_->buffer; node->e_ = node->eRegion_->buffer; node->b_ = node->bRegion_->buffer; @@ -113,6 +118,18 @@ GemmSP::GemmSP(Array args, Map annotations) { data_ = std::move(node); } +AccessRegions GemmSPNode::GetAccessRegions() const { + AccessRegions result; + result.reads.push_back(aRegion_); + result.reads.push_back(eRegion_); + result.reads.push_back(bRegion_); + if (!clearAccum_) { + result.reads.push_back(cRegion_); + } + result.writes.push_back(cRegion_); + return result; +} + /** * @brief Create a deep copy of this GemmSPNode wrapped as a TileOperator. * @@ -175,16 +192,20 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ss << ", " << wgWait_; } ss << ">"; - auto A_buffer = T.buffer_remap.count(a_) ? T.buffer_remap[a_] : a_; - auto B_buffer = T.buffer_remap.count(b_) ? T.buffer_remap[b_] : b_; - auto C_buffer = T.buffer_remap[c_]; - auto E_buffer = T.buffer_remap.count(e_) ? T.buffer_remap[e_] : e_; + // Build access pointers from regions to preserve stage-specific offsets + // from pipeline multi-versioning (matching dense GemmNode::Lower pattern). + PrimExpr Aptr = + MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true); + PrimExpr Bptr = + MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true); + PrimExpr Cptr = + MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true); + PrimExpr Eptr = + MakeAccessPtrFromRegion(eRegion_, /*r*/ 1, /*require_2d*/ false); auto new_call = Call(DataType::Handle(), tl::tl_gemm_sp(), - Array{StringImm(ss.str()), A_buffer.access_ptr(1), - B_buffer.access_ptr(1), C_buffer.access_ptr(3), - E_buffer.access_ptr(1)}); + Array{StringImm(ss.str()), Aptr, Bptr, Cptr, Eptr}); return Evaluate(new_call); } diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h index 8fb2db770b..c060f4efbc 100644 --- a/src/op/gemm_sp.h +++ b/src/op/gemm_sp.h @@ -77,6 +77,7 @@ class GemmSPNode : public TileOperatorNode { Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + AccessRegions GetAccessRegions() const override; TileOperator Clone() const; diff --git a/src/op/gemm_sp_py.cc b/src/op/gemm_sp_py.cc index 177e389782..8546228fcb 100644 --- a/src/op/gemm_sp_py.cc +++ b/src/op/gemm_sp_py.cc @@ -52,10 +52,16 @@ using namespace tir; GemmSPPy::GemmSPPy(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); - node->aRegion_ = NormalizeToBufferRegion(args[0]); - node->eRegion_ = NormalizeToBufferRegion(args[1]); - node->bRegion_ = NormalizeToBufferRegion(args[2]); - node->cRegion_ = NormalizeToBufferRegion(args[3]); + auto a_access = NormalizeToAccessRegion(args[0], kAccessRead); + auto e_access = NormalizeToAccessRegion(args[1], kAccessRead); + auto b_access = NormalizeToAccessRegion(args[2], kAccessRead); + auto c_access = NormalizeToAccessRegion(args[3], kAccessReadWrite); + + node->aRegion_ = a_access.region; + node->eRegion_ = e_access.region; + node->bRegion_ = b_access.region; + node->cRegion_ = c_access.region; + node->SetAccessRegions({a_access, e_access, b_access, c_access}); node->A = node->aRegion_->buffer; node->E = node->eRegion_->buffer; @@ -86,6 +92,18 @@ GemmSPPy::GemmSPPy(Array args, Map annotations) { data_ = std::move(node); } +AccessRegions GemmSPPyNode::GetAccessRegions() const { + AccessRegions result; + result.reads.push_back(aRegion_); + result.reads.push_back(eRegion_); + result.reads.push_back(bRegion_); + if (!is_one(clear_accum)) { + result.reads.push_back(cRegion_); + } + result.writes.push_back(cRegion_); + return result; +} + /** * @brief Create a copy of this GemmSPPyNode as a TileOperator. * diff --git a/src/op/gemm_sp_py.h b/src/op/gemm_sp_py.h index 59c276f168..7ced4d5663 100644 --- a/src/op/gemm_sp_py.h +++ b/src/op/gemm_sp_py.h @@ -70,6 +70,7 @@ class GemmSPPyNode : public TileOperatorNode { Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + AccessRegions GetAccessRegions() const override; TileOperator Clone() const; diff --git a/src/op/operator.h b/src/op/operator.h index 4377b10858..6743660bbf 100644 --- a/src/op/operator.h +++ b/src/op/operator.h @@ -14,6 +14,8 @@ #include #include #include +#include +#include #include "../layout/layout.h" @@ -27,6 +29,36 @@ using AllocMBarrierCallback = std::function; using LayoutMap = Map; using BufferMap = Map; +enum AccessMask : int { + kAccessRead = 1, + kAccessWrite = 2, + kAccessReadWrite = kAccessRead | kAccessWrite, +}; + +struct AccessRegion { + BufferRegion region; + int access_mask{kAccessReadWrite}; +}; + +struct AccessRegions { + Array reads; + Array writes; +}; + +inline void AppendAccessRegionByMask(const AccessRegion &access, + Array *reads, + Array *writes) { + if (!access.region.defined()) { + return; + } + if (access.access_mask & kAccessRead) { + reads->push_back(access.region); + } + if (access.access_mask & kAccessWrite) { + writes->push_back(access.region); + } +} + enum class InferLevel : uint8_t { kFree = 0, kCommon = 1, @@ -58,20 +90,11 @@ struct LowerArgs { // Map from LetStmt variable to its bound expression, for resolving // fragment buffer accesses through let bindings Map let_var_to_expr; - // Whether the current TileOp is nested inside a pipelined loop - // (i.e. a surrounding loop annotated with num_stages > 0). - bool in_pipeline = false; - // Expression for mbarrier wait parity. - // For pipeline_num_stages=1: ko % 2 - // For pipeline_num_stages=N: (ko / N) % 2 - // For non-loop contexts: 0 - PrimExpr mbar_phase_expr; - // Number of pipeline stages (from T.Pipelined num_stages annotation). - // Determines how many mbarriers to allocate per TMA copy operation. - int pipeline_num_stages = 1; - // Expression for mbarrier stage index: ko % pipeline_num_stages. - // Used to cycle through multiple mbarriers in pipelined loops. - PrimExpr mbar_stage_expr; + // Fallback mbarrier parity for ops that do not carry an explicit + // tl.pipeline_mbar_phase_expr annotation. LowerTileOp derives this from the + // nearest enclosing serial loop so non-pipelined TMA loops still alternate + // barrier phase correctly. + PrimExpr mbar_phase_expr = IntImm(DataType::Int(32), 0); // Pointer to the shared.barrier buffer for compiler-generated mbarriers. // Points to the LowerTileOpPass member so copy.cc sees the buffer // even when created lazily by the AllocMBarrier callback. @@ -108,7 +131,22 @@ class TileOperatorNode : public Object { virtual TileOperator Clone() const = 0; + virtual AccessRegions GetAccessRegions() const { + AccessRegions result; + for (const auto &access : access_regions_) { + AppendAccessRegionByMask(access, &result.reads, &result.writes); + } + return result; + } + + void SetAccessRegions(std::vector access_regions) { + access_regions_ = std::move(access_regions); + } + TVM_FFI_DECLARE_OBJECT_INFO("tl.TileOperator", TileOperatorNode, Object); + +protected: + std::vector access_regions_; }; class TileOperator : public ObjectRef { diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 460a68c9e6..32fc93f74e 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -33,8 +33,11 @@ using namespace tir; ReduceOp::ReduceOp(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); // Accept BufferRegion/BufferLoad for src/dst - node->srcRegion_ = NormalizeToBufferRegion(args[0]); - node->dstRegion_ = NormalizeToBufferRegion(args[1]); + auto src_access = NormalizeToAccessRegion(args[0], kAccessRead); + auto dst_access = NormalizeToAccessRegion(args[1], kAccessReadWrite); + node->srcRegion_ = src_access.region; + node->dstRegion_ = dst_access.region; + node->SetAccessRegions({src_access, dst_access}); node->src = node->srcRegion_->buffer; node->dst = node->dstRegion_->buffer; std::string reduce_type = args[2].as().value()->value; @@ -44,6 +47,16 @@ ReduceOp::ReduceOp(Array args, Map annotations) { data_ = std::move(node); } +AccessRegions ReduceOpNode::GetAccessRegions() const { + AccessRegions result; + result.reads.push_back(srcRegion_); + if (!clear) { + result.reads.push_back(dstRegion_); + } + result.writes.push_back(dstRegion_); + return result; +} + TileOperator ReduceOpNode::Clone() const { auto op = tvm::ffi::make_object(*this); return ReduceOp(op); @@ -557,8 +570,11 @@ CumSumOp::CumSumOp(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); // node->src = vmap[GetVarFromAccessPtr(args[0])]; // node->dst = vmap[GetVarFromAccessPtr(args[1])]; - node->srcRegion_ = NormalizeToBufferRegion(args[0]); - node->dstRegion_ = NormalizeToBufferRegion(args[1]); + auto src_access = NormalizeToAccessRegion(args[0], kAccessRead); + auto dst_access = NormalizeToAccessRegion(args[1], kAccessWrite); + node->srcRegion_ = src_access.region; + node->dstRegion_ = dst_access.region; + node->SetAccessRegions({src_access, dst_access}); node->src = node->srcRegion_->buffer; node->dst = node->dstRegion_->buffer; node->dim = args[2].as().value()->value; diff --git a/src/op/reduce.h b/src/op/reduce.h index 636abdd948..7c9db0c431 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -108,6 +108,7 @@ class ReduceOpNode : public TileOperatorNode { /// Infer memory layout for buffers LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + AccessRegions GetAccessRegions() const override; static const Op &Get(); TileOperator Clone() const; diff --git a/src/op/transpose.cc b/src/op/transpose.cc index 8d045b5322..3238a633c9 100644 --- a/src/op/transpose.cc +++ b/src/op/transpose.cc @@ -22,15 +22,13 @@ using namespace tir; Transpose::Transpose(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); - Array rgs[2]; - Buffer bf[2]; - for (int i = 0; i < 2; i++) { - auto region = NormalizeToBufferRegion(args[i]); - rgs[i] = region->region; - bf[i] = region->buffer; - } - std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); - std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + auto src_access = NormalizeToAccessRegion(args[0], kAccessRead); + auto dst_access = NormalizeToAccessRegion(args[1], kAccessWrite); + node->src = src_access.region->buffer; + node->dst = dst_access.region->buffer; + node->src_range = src_access.region->region; + node->dst_range = dst_access.region->region; + node->SetAccessRegions({src_access, dst_access}); data_ = std::move(node); } diff --git a/src/op/utils.cc b/src/op/utils.cc index 309d34662f..5839efa8f3 100644 --- a/src/op/utils.cc +++ b/src/op/utils.cc @@ -63,6 +63,18 @@ BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) { throw; // Unreachable } +AccessRegion NormalizeToAccessRegion(const PrimExpr &arg, + int default_access_mask) { + if (const auto *call = arg.as()) { + if (call->op.same_as(RegionOp::Get())) { + RegionOp region(call->args); + return {BufferRegion(region->GetBuffer(), region->GetRanges()), + region->GetAccessMask()}; + } + } + return {NormalizeToBufferRegion(arg), default_access_mask}; +} + PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, int rw_mask, bool require_2d) { Buffer buf = region->buffer; diff --git a/src/op/utils.h b/src/op/utils.h index 7ac3c93c98..77e21feda6 100644 --- a/src/op/utils.h +++ b/src/op/utils.h @@ -35,6 +35,12 @@ TVM_DLL bool IsBufferLikeExpr(const PrimExpr &expr); // Note: tvm_access_ptr is no longer supported here. TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg); +// Normalize an argument to BufferRegion together with an access mask. +// If the argument is a tl.region(...) bridge, preserve its encoded mask; +// otherwise fall back to the provided default mask. +TVM_DLL AccessRegion NormalizeToAccessRegion( + const PrimExpr &arg, int default_access_mask = kAccessReadWrite); + // Build a tvm_access_ptr(handle) from a BufferRegion. // - If `require_2d` is true, checks buffer ndim >= 2. // - For 1D regions (when allowed), offset=min, extent=extent. diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 00ccbf7600..1e4a11f722 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -3616,27 +3616,6 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) { const VarNode *buffer = op->node.as(); const StringImmNode *layout_str = op->value.as(); fragment_layouts[buffer] = layout_str->value; - } else if (op->attr_key == tir::attr::async_commit_queue_scope) { - const IntImmNode *queue_id = op->value.as(); - ICHECK(queue_id && queue_id->value == 0) - << "For CUDA, the index of an async queue must be 0."; - this->VisitStmt(op->body); - auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); - this->VisitExpr(commit_group, this->stream); - return; - } else if (op->attr_key == tir::attr::async_wait_queue_scope) { - auto wait_attrs = GetAsyncWaitAttributes(op); - auto queue_id = wait_attrs.first.as(); - ICHECK(queue_id && queue_id->value == 0) - << "For CUDA, the index of an async queue must be 0."; - auto wait_cnt = wait_attrs.second; - auto wait_group = - Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); - this->VisitExpr(wait_group, this->stream); - auto inner = op->body.as(); - ICHECK(inner); - this->VisitStmt(inner->body); - return; } else if (op->attr_key == "threadblock_swizzle_pattern") { this->PrintIndent(); std::string func_name; diff --git a/src/target/codegen_cutedsl.cc b/src/target/codegen_cutedsl.cc index d7cf6b2cf0..d0e44e089b 100644 --- a/src/target/codegen_cutedsl.cc +++ b/src/target/codegen_cutedsl.cc @@ -1689,25 +1689,6 @@ void CodeGenTileLangCuTeDSL::VisitStmt_(const AttrStmtNode *op) { } } VisitStmt(op->body); - } else if (op->attr_key == tir::attr::async_commit_queue_scope) { - const IntImmNode *queue_id = op->value.as(); - ICHECK(queue_id && queue_id->value == 0) - << "For CUDA, the index of an async queue must be 0."; - VisitStmt(op->body); - auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); - VisitExpr(commit_group, stream); - } else if (op->attr_key == tir::attr::async_wait_queue_scope) { - auto wait_attrs = GetAsyncWaitAttributes(op); - auto queue_id = wait_attrs.first.as(); - ICHECK(queue_id && queue_id->value == 0) - << "For CUDA, the index of an async queue must be 0."; - auto wait_cnt = wait_attrs.second; - auto wait_group = - Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); - VisitExpr(wait_group, stream); - auto inner = op->body.as(); - ICHECK(inner); - VisitStmt(inner->body); } else if (op->attr_key == "threadblock_swizzle_pattern") { this->PrintIndent(); std::string func_name; diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index 0bed839735..0e9afd3ee2 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -1177,32 +1177,26 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) { - if (op->attr_key == tir::attr::async_commit_queue_scope) { - const IntImmNode *queue_id = op->value.as(); - ICHECK(queue_id && queue_id->value == 0) - << "For CUDA, the index of an async queue must be 0."; - this->VisitStmt(op->body); - auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); - this->VisitExpr(commit_group, this->stream); - return; - } else if (op->attr_key == tir::attr::async_wait_queue_scope) { - auto wait_attrs = GetAsyncWaitAttributes(op); - auto queue_id = wait_attrs.first.as(); - ICHECK(queue_id && queue_id->value == 0) - << "For CUDA, the index of an async queue must be 0."; - auto wait_cnt = wait_attrs.second; - auto wait_group = - Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); - this->VisitExpr(wait_group, this->stream); - auto inner = op->body.as(); - ICHECK(inner); - this->VisitStmt(inner->body); - return; - } else if (op->attr_key == "threadblock_swizzle_pattern") { + if (op->attr_key == "threadblock_swizzle_pattern") { this->PrintIndent(); - const StringImmNode *pattern = op->value.as(); - ICHECK(pattern); - this->stream << "const dim3 blockIdx = " << pattern->value << "();\n"; + std::string func_name; + int panel_size = 0; + if (const auto *call = op->value.as()) { + if (call->op.same_as(tir::builtin::tvm_tuple()) && + call->args.size() >= 2) { + const auto *name_node = call->args[0].as(); + const auto *size_node = call->args[1].as(); + ICHECK(name_node && size_node) << "threadblock_swizzle_pattern expects " + "tvm_tuple(device_func, panel_size)"; + func_name = name_node->value; + panel_size = static_cast(size_node->value); + } + } + ICHECK(!func_name.empty() && panel_size > 0) + << "threadblock_swizzle_pattern: failed to extract func_name and " + "panel_size"; + this->stream << "const dim3 blockIdx = tl::" << func_name << "(" + << panel_size << ");\n"; this->VisitStmt(op->body); return; } diff --git a/src/transform/common/pipeline_utils.h b/src/transform/common/pipeline_utils.h new file mode 100644 index 0000000000..1722c0262e --- /dev/null +++ b/src/transform/common/pipeline_utils.h @@ -0,0 +1,112 @@ +/*! + * \file pipeline_utils.h + * \brief Shared utilities for software-pipeline and warp-specialization passes. + * + * Provides: + * - Pipeline annotation attribute keys + * - GetPipelineNumStages() — extract num_stages from loop annotations + * - ComputeThreadBounds() — derive thread bounds from an analyzer + IterVar + */ +#ifndef TVM_TL_TRANSFORM_COMMON_PIPELINE_UTILS_H_ +#define TVM_TL_TRANSFORM_COMMON_PIPELINE_UTILS_H_ + +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +// --------------------------------------------------------------------------- +// Pipeline annotation attribute keys +// --------------------------------------------------------------------------- + +/*! Marks the enclosing scope with the pipeline stage count. */ +static constexpr const char *kPipelineContextNumStages = + "tl.pipeline_context_num_stages"; +/*! Multi-version buffer: stage count for buffer expansion. */ +static constexpr const char *kPipelineMVBContextNumStages = + "tl.pipeline_mvb_num_stages"; +/*! Multi-version buffer: per-statement stage index expression. */ +static constexpr const char *kPipelineMVBStageExpr = + "tl.pipeline_mvb_stage_expr"; +/*! Multi-version buffer: per-statement parity expression. */ +static constexpr const char *kPipelineMVBParityExpr = + "tl.pipeline_mvb_parity_expr"; +/*! Per-statement TMA copy flag (1 = TMA eligible, 0 = not). */ +static constexpr const char *kPipelineTmaCopies = + "software_pipeline_tma_copies"; +/*! Per-statement async producer flag (1 = async copy producer, 0 = not). */ +static constexpr const char *kPipelineAsyncProducers = + "software_pipeline_async_producers"; +/*! Per-statement async producer group id (-1 = not an async producer). */ +static constexpr const char *kPipelineAsyncProducerGroups = + "software_pipeline_async_producer_groups"; + +// --------------------------------------------------------------------------- +// GetPipelineNumStages +// --------------------------------------------------------------------------- + +/*! + * \brief Extract the pipeline stage count from a For loop's annotations. + * + * Checks (in order): + * 1. "num_stages" — user-provided stage count + * 2. "tl_pipelined_num_stages" — set by InjectSoftwarePipeline + * 3. tir::attr::software_pipeline_stage — max(stage) + 1 + * + * \return The stage count, or nullopt if the loop is not pipelined. + */ +inline Optional GetPipelineNumStages(const ForNode *loop) { + if (auto num_stages = loop->annotations.Get("num_stages")) { + if (const auto *imm = num_stages->as()) { + return Integer(static_cast(imm->value)); + } + } + if (auto num_stages = loop->annotations.Get("tl_pipelined_num_stages")) { + if (const auto *imm = num_stages->as()) { + return Integer(static_cast(imm->value)); + } + } + if (auto stages_anno = + loop->annotations.Get(tir::attr::software_pipeline_stage)) { + auto stages = Downcast>(stages_anno.value()); + int max_stage = -1; + for (const auto &stage : stages) { + max_stage = std::max(max_stage, static_cast(stage->value)); + } + if (max_stage >= 0) { + return Integer(max_stage + 1); + } + } + return Optional(); +} + +// --------------------------------------------------------------------------- +// ComputeThreadBounds +// --------------------------------------------------------------------------- + +/*! + * \brief Compute the thread index bounds from an IterVar and an analyzer. + * + * \return Range covering the thread index, or [0, 1) if no bound is known. + */ +inline Range ComputeThreadBounds(const IterVar &thread_var, + const arith::Analyzer &analyzer) { + if (thread_var.defined() && + analyzer.const_int_bound.IsBound(thread_var->var)) { + auto const_int_bound = analyzer.const_int_bound(thread_var); + auto min_value = const_int_bound->min_value; + auto max_value = const_int_bound->max_value; + auto extent = max_value - min_value + 1; + auto dtype = thread_var->var.dtype(); + return Range::FromMinExtent(IntImm(dtype, min_value), + IntImm(dtype, extent)); + } + return Range::FromMinExtent(0, 1); +} + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TRANSFORM_COMMON_PIPELINE_UTILS_H_ diff --git a/src/transform/common/tma_copy_utils.h b/src/transform/common/tma_copy_utils.h deleted file mode 100644 index 8cd4e8dadb..0000000000 --- a/src/transform/common/tma_copy_utils.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef TVM_TL_TRANSFORM_COMMON_TMA_COPY_UTILS_H_ -#define TVM_TL_TRANSFORM_COMMON_TMA_COPY_UTILS_H_ - -#include - -namespace tvm { -namespace tl { - -using namespace tir; - -inline Stmt StripTmaCopyWriteBufferAttr(Stmt stmt) { - class TmaCopyWriteBufferAttrStripper : public StmtExprMutator { - public: - Stmt VisitStmt_(const AttrStmtNode *op) final { - if (op->attr_key == "tl.tma_copy_write_buffer") { - return VisitStmt(op->body); - } - return StmtExprMutator::VisitStmt_(op); - } - }; - - return TmaCopyWriteBufferAttrStripper()(std::move(stmt)); -} - -} // namespace tl -} // namespace tvm - -#endif // TVM_TL_TRANSFORM_COMMON_TMA_COPY_UTILS_H_ diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index bc82ee5718..d2731e85e5 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -3,15 +3,28 @@ * \brief Transform annotated loops into pipelined one that parallelize * producers and consumers */ +#include #include #include #include #include +#include +#include #include #include - -#include "common/tma_copy_utils.h" +#include + +#include "../layout/layout.h" +#include "../op/builtin.h" +#include "../op/copy.h" +#include "../op/gemm.h" +#include "../op/gemm_py.h" +#include "../op/operator.h" +#include "../op/region.h" +#include "../op/utils.h" +#include "common/mbarrier.h" +#include "common/pipeline_utils.h" #include "support/utils.h" #include "tir/schedule/utils.h" #include "tir/transforms/ir_utils.h" @@ -22,6 +35,92 @@ using namespace tir; using namespace ffi; namespace software_pipeline { +namespace { + +bool ShapesEqual(const Array &lhs, const Array &rhs, + arith::Analyzer *analyzer) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); ++i) { + if (!analyzer->CanProveEqual(lhs[i], rhs[i])) { + return false; + } + } + return true; +} + +Layout ExpandAnnotatedLayoutForMultiVersionedBuffer(const Layout &layout, + const Buffer &old_buffer, + const Buffer &new_buffer) { + if (!layout.defined() || + new_buffer->shape.size() <= old_buffer->shape.size()) { + return Layout(); + } + + arith::Analyzer analyzer; + if (!ShapesEqual(layout->InputShape(), old_buffer->shape, &analyzer)) { + return Layout(); + } + + size_t leading_ndim = new_buffer->shape.size() - old_buffer->shape.size(); + Array trailing_shape; + Array leading_shape; + for (size_t i = 0; i < leading_ndim; ++i) { + leading_shape.push_back(new_buffer->shape[i]); + } + for (size_t i = 0; i < old_buffer->shape.size(); ++i) { + trailing_shape.push_back(new_buffer->shape[leading_ndim + i]); + } + if (!ShapesEqual(trailing_shape, old_buffer->shape, &analyzer)) { + return Layout(); + } + + return layout->Expand(leading_shape); +} + +bool UpdateExpandedLayoutMapForRemappedAllocs( + const std::vector> &remapped_allocs, + Map *annotations) { + if (remapped_allocs.empty() || !annotations->count(attr::kLayoutMap)) { + return false; + } + + auto layout_map_ref = annotations->Get(attr::kLayoutMap); + if (!layout_map_ref.has_value()) { + return false; + } + auto layout_map = layout_map_ref.value().as>(); + if (!layout_map.has_value()) { + return false; + } + + Map updated_layout_map = layout_map.value(); + std::unordered_set visited; + bool changed = false; + for (const auto &[old_buffer, new_buffer] : remapped_allocs) { + if (!visited.insert(old_buffer->data.get()).second || + !updated_layout_map.count(old_buffer->data)) { + continue; + } + Layout layout = updated_layout_map[old_buffer->data]; + Layout expanded = ExpandAnnotatedLayoutForMultiVersionedBuffer( + layout, old_buffer, new_buffer); + if (!expanded.defined()) { + continue; + } + updated_layout_map.Set(old_buffer->data, expanded); + changed = true; + } + + if (changed) { + annotations->Set(attr::kLayoutMap, updated_layout_map); + } + return changed; +} + +} // namespace + struct LetWrapper { Var var; PrimExpr value; @@ -68,6 +167,18 @@ class BufferUsageCollector : public StmtExprVisitor { } void VisitExpr_(const CallNode *op) final { + if (auto tile_op = ParseOperator(tvm::ffi::GetRef(op)); + tile_op.defined()) { + AccessRegions access = tile_op->GetAccessRegions(); + for (const auto ®ion : access.reads) { + AddBuffer(region->buffer); + } + for (const auto ®ion : access.writes) { + AddBuffer(region->buffer); + } + StmtExprVisitor::VisitExpr_(op); + return; + } // Handle tvm_access_ptr which also accesses buffers if (op->op.same_as(builtin::tvm_access_ptr())) { if (op->args.size() > 1) { @@ -103,6 +214,29 @@ class BufferUsageCollector : public StmtExprVisitor { std::unordered_set used_buffers_; }; +class TileOpAccessCollector : public StmtExprVisitor { +public: + Array GetReads() const { return reads_; } + + Array GetWrites() const { return writes_; } + +private: + void VisitExpr_(const CallNode *op) final { + if (auto tile_op = ParseOperator(tvm::ffi::GetRef(op)); + tile_op.defined()) { + AccessRegions access = tile_op->GetAccessRegions(); + reads_.insert(reads_.end(), access.reads.begin(), access.reads.end()); + writes_.insert(writes_.end(), access.writes.begin(), access.writes.end()); + StmtExprVisitor::VisitExpr_(op); + return; + } + StmtExprVisitor::VisitExpr_(op); + } + + Array reads_; + Array writes_; +}; + /*! * \brief Create a block and infer the access region with the given body. * @@ -116,19 +250,27 @@ class BufferUsageCollector : public StmtExprVisitor { */ Block MakeBlock(const Stmt &body, const Map &buffer_data_to_buffer) { + Block block; if (const BlockRealizeNode *block_realize = body.as()) { if (is_one(block_realize->predicate)) { - // no need to create a new block - return block_realize->block; + block = block_realize->block; } } - Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", - /*body*/ body); + if (!block.defined()) { + block = Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"", /*body*/ body); + } Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); + TileOpAccessCollector collector; + collector(block->body); + Array tile_reads = collector.GetReads(); + Array tile_writes = collector.GetWrites(); BlockNode *n = block.CopyOnWrite(); n->reads = access[0]; + n->reads.insert(n->reads.end(), tile_reads.begin(), tile_reads.end()); n->writes = access[1]; + n->writes.insert(n->writes.end(), tile_writes.begin(), tile_writes.end()); return block; } @@ -136,6 +278,8 @@ Block MakeBlock(const Stmt &body, struct PipelineAnnotation { int stage; int order; + bool async{false}; + int async_group_id{-1}; }; using PipelineInfo = std::unordered_map()) { + if (attr->attr_key == tir::attr::async_scope || + attr->attr_key == tir::attr::async_commit_queue_scope || + attr->attr_key == tir::attr::async_wait_queue_scope || + attr->attr_key == tir::attr::async_wait_inflight_count) { + found = true; + return; + } + } + const auto *call = obj.as(); + if (!call) { + return; + } + if (call->op.same_as(builtin::ptx_cp_async()) || + call->op.same_as(tl::ptx_cp_async()) || + call->op.same_as(builtin::ptx_commit_group()) || + call->op.same_as(builtin::ptx_wait_group())) { + found = true; + } + }); + return found; +} + +class SimtProducerAnnotator : public StmtExprMutator { +public: + static Stmt Annotate(const Stmt &stmt, + Optional target = Optional()) { + SimtProducerAnnotator annotator(std::move(target)); + return annotator.VisitStmt(stmt); + } + +private: + explicit SimtProducerAnnotator(Optional target) + : target_(std::move(target)) {} + + Stmt VisitStmt_(const ForNode *op) final { + Stmt body = VisitStmt(op->body); + auto annotations = op->annotations; + // Keep the raw buffer-store cp.async path under outer pipeline-managed + // commit/wait semantics as well. + annotations.Set(attr::kParallelAsyncWithoutAsyncCommitWait, Bool(true)); + return For(op->loop_var, op->min, op->extent, op->kind, body, + op->thread_binding, annotations, op->step, op->span); + } + + PrimExpr VisitExpr_(const CallNode *op) final { + static const Op ©_op = Op::Get("tl.tileop.copy"); + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (!call->op.same_as(copy_op) || !CanUsePipelineManagedCPAsyncCopy(call)) { + return call; + } + // Tile-op copies lower through copy.cc, so they need an explicit + // per-copy marker to suppress their own implicit commit/wait. + auto annotations = call->annotations; + annotations.Set(attr::kAsyncCopyNoImplicitCommitWait, + IntImm(DataType::Int(32), 1)); + return Call(call->dtype, call->op, call->args, annotations, call->span); + } + + bool CanUsePipelineManagedCPAsyncCopy(const Call &call) const { + auto tile_op = ParseOperator(call); + const auto *copy = tile_op.as(); + if (copy == nullptr) { + return false; + } + if (!target_.defined()) { + return copy->CheckPipelineManagedCPAsyncCopy(); + } + return copy->CheckPipelineManagedCPAsyncCopy(target_.value(), &analyzer_); + } + + Optional target_; + mutable arith::Analyzer analyzer_; +}; + +class TileOpMbarPhaseAnnotator : public StmtExprMutator { +public: + static Stmt Annotate(const Stmt &stmt, PrimExpr phase_expr) { + TileOpMbarPhaseAnnotator annotator(std::move(phase_expr)); + return annotator.VisitStmt(stmt); + } + +private: + explicit TileOpMbarPhaseAnnotator(PrimExpr phase_expr) + : phase_expr_(std::move(phase_expr)) {} + + PrimExpr VisitExpr_(const CallNode *op) final { + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (!IsMbarPhaseConsumer(call)) { + return call; + } + if (call->annotations.count(attr::kPipelineMbarPhaseExpr)) { + return call; + } + auto annotations = call->annotations; + annotations.Set(attr::kPipelineMbarPhaseExpr, phase_expr_); + return Call(call->dtype, call->op, call->args, annotations, call->span); + } + + bool IsMbarPhaseConsumer(const Call &call) const { + auto tile_op = ParseOperator(call); + return tile_op.defined() && (tile_op.as() != nullptr || + tile_op.as() != nullptr || + tile_op.as() != nullptr || + tile_op.as() != nullptr); + } + + PrimExpr phase_expr_; +}; + +class AsyncCommitWaitAttrLowerer : public StmtExprMutator { +public: + static Stmt Lower(const Stmt &stmt) { + AsyncCommitWaitAttrLowerer lowerer; + return lowerer.VisitStmt(stmt); + } + +private: + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::async_commit_queue_scope) { + Stmt body = VisitStmt(op->body); + Stmt commit = + Evaluate(Call(DataType::Handle(), builtin::ptx_commit_group(), {})); + if (is_no_op(body)) { + return commit; + } + return SeqStmt({body, commit}); + } + if (op->attr_key == tir::attr::async_wait_queue_scope) { + auto wait_attrs = GetAsyncWaitAttributes(op); + Stmt body = op->body; + if (const auto *inner = op->body.as()) { + if (inner->attr_key == tir::attr::async_wait_inflight_count) { + body = inner->body; + } + } + body = VisitStmt(body); + Stmt wait = Evaluate(Call(DataType::Handle(), builtin::ptx_wait_group(), + {wait_attrs.second})); + if (is_no_op(body)) { + return wait; + } + return SeqStmt({wait, body}); + } + if (op->attr_key == tir::attr::async_wait_inflight_count) { + return VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } +}; + /*! * \brief Rewriter for the body of the software pipeline. This pass inserts * `floormod` to indices of the remapped buffer to select the version @@ -282,6 +588,22 @@ class PipelineBodyRewriter : public StmtExprMutator { if (call->op.same_as(builtin::tvm_access_ptr())) { return RewriteBufferAccess(call, {1}); } + if (call->op.same_as(RegionOp::Get()) && call->args.size() >= 2) { + if (auto load = call->args[0].as()) { + size_t num_extents = call->args.size() - 2; + if (load->indices.size() == num_extents + 1) { + Array new_args; + new_args.push_back(call->args[0]); + new_args.push_back(call->args[1]); + new_args.push_back(IntImm(DataType::Int(32), 1)); + for (size_t i = 2; i < call->args.size(); ++i) { + new_args.push_back(call->args[i]); + } + return Call(call->dtype, call->op, new_args, call->annotations, + call->span); + } + } + } return call; } @@ -316,12 +638,13 @@ class PipelineRewriter : public StmtExprMutator { PipelineRewriter(Map buffer_data_to_buffer, const Array &pipeline_allocs, const Array &local_allocs, const For &pipeline_loop, - const PipelineInfo &pipeline_info, + const PipelineInfo &pipeline_info, Optional target, const std::vector &loop_var_let_wrappers, const std::vector &loop_var_if_wrappers) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), pipeline_allocs_(pipeline_allocs), local_allocs_(local_allocs), pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info), + target_(std::move(target)), loop_var_let_wrappers_(loop_var_let_wrappers), loop_var_if_wrappers_(loop_var_if_wrappers) {} @@ -347,16 +670,33 @@ class PipelineRewriter : public StmtExprMutator { } // Step 2: Emit the pipeline prologue, body and epilogue. - Stmt prologue = EmitImpl(pipeline_loop_->min, - pipeline_loop_->min + max_stage_, true, true); - Stmt body = + Optional pipeline_num_stages = + GetPipelineNumStages(pipeline_loop_.get()); + Stmt prologue = StripPipelineContextAttrs(EmitImpl( + pipeline_loop_->min, pipeline_loop_->min + max_stage_, true, true)); + Stmt body = StripPipelineContextAttrs( EmitImpl(pipeline_loop_->min + max_stage_, - pipeline_loop_->min + pipeline_loop_->extent, false, false); - - Stmt epilogue = EmitImpl( + pipeline_loop_->min + pipeline_loop_->extent, false, false)); + Stmt epilogue = StripPipelineContextAttrs(EmitImpl( pipeline_loop_->min + pipeline_loop_->extent, - pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true); - SeqStmt stmt = SeqStmt({prologue, body, epilogue}); + pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true)); + + Array pipeline_parts; + for (const Stmt &part : {prologue, body, epilogue}) { + for (const Stmt &stmt : FlattenTopLevelSeq(part)) { + pipeline_parts.push_back(stmt); + } + } + + Stmt stmt = pipeline_parts.size() == 1 ? pipeline_parts[0] + : SeqStmt(pipeline_parts); + stmt = AsyncPipelineLoopWaitRelaxer(this)(stmt); + Array relaxed_pipeline_parts = FlattenTopLevelSeq(stmt); + relaxed_pipeline_parts = + RelaxTrailingConsumerWaits(std::move(relaxed_pipeline_parts), + PipelinedRetainGroups(pipeline_num_stages)); + stmt = relaxed_pipeline_parts.size() == 1 ? relaxed_pipeline_parts[0] + : SeqStmt(relaxed_pipeline_parts); // Step 3: Make a new block that contains new buffer allocations after // pipeline rewriting. @@ -367,6 +707,14 @@ class PipelineRewriter : public StmtExprMutator { alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc)); buffer_data_to_buffer_.erase(alloc->data); } + if (pipeline_num_stages) { + if (pipeline_num_stages.value()->value > 1) { + stmt = AttrStmt(Integer(0), kPipelineMVBContextNumStages, + Downcast(pipeline_num_stages.value()), stmt); + } + stmt = AttrStmt(Integer(0), kPipelineContextNumStages, + Downcast(pipeline_num_stages.value()), stmt); + } Block block = MakeBlock(stmt, buffer_data_to_buffer_); block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); return BlockRealize({}, Bool(true), block); @@ -528,169 +876,1390 @@ class PipelineRewriter : public StmtExprMutator { return Buffer(new_buffer); } - /*! Structure holding intermediate information for pipeline loop rewriting. */ - struct RewrittenBlockInfo { - PrimExpr predicate; - Block block; + struct AsyncStateGlobal { + std::unordered_set dst_buffers; + Optional producer_head{PrimExpr(-1)}; + + bool writes(const Buffer &buffer) const { + return dst_buffers.count(buffer.get()) > 0; + } }; - /*! - * \brief Emit the pipeline loop in the given range. - * \param start The start of the range - * \param end The end of the range - * \param unroll_loop Whether the loop should be unrolled. - * \return The result loop. - */ - Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop, - bool need_bound_check) { - PrimExpr new_loop_var; - PrimExpr extent = end - start; - auto make_nop = []() { - return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); + struct AsyncStateLocal { + struct PendingWait { + int insert_before{-1}; + PrimExpr wait_count{nullptr}; + + bool valid() const { return wait_count.defined(); } }; - bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); - if (is_unit_loop) { - new_loop_var = start; // use constants as the loop var for unit loops - } else { - new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); - // Bind the iteration domain [start, end) to strengthen analyzer facts. - analyzer_.Bind(Downcast(new_loop_var), - Range::FromMinExtent(start, end - start)); - } - // Keep the bound constraints active for all analysis below. - // Only meaningful when the loop var is symbolic (non-unit loop). - std::unique_ptr> ctx_lb_guard; - std::unique_ptr> ctx_ub_guard; - if (!is_unit_loop) { - Var loop_iter = Downcast(new_loop_var); - ctx_lb_guard.reset( - new With(&analyzer_, loop_iter >= start)); - ctx_ub_guard.reset( - new With(&analyzer_, loop_iter < end)); + std::unordered_set seen; + Optional producer_head; + Optional predicate; + std::vector> commit_groups; + std::map pending_waits; + std::unordered_map annotated_group_to_commit_group; + bool consumed{false}; + }; + + struct RewrittenStmtInfo { + int stage; + PrimExpr predicate; + Array reads; + Array writes; + PrimExpr access_index; + bool is_async; + Stmt stmt; + }; + + struct FinalStmtInfo { + int stage; + PrimExpr access_index; + PrimExpr predicate; + Stmt stmt; + }; + + enum class AsyncSyncStmtKind { kOther, kCommit, kWaitStatic, kWaitDynamic }; + + struct ClassifiedAsyncSyncStmt { + AsyncSyncStmtKind kind{AsyncSyncStmtKind::kOther}; + int wait_n{0}; + }; + + struct AsyncSyncSummary { + int commit{0}; + int wait{0}; + }; + + enum class HeadAsyncSyncKind { + kNone, + kCommit, + kWaitStatic, + kWaitDynamic, + kBlocked, + }; + + struct HeadAsyncSyncInfo { + HeadAsyncSyncKind kind{HeadAsyncSyncKind::kNone}; + int wait_n{0}; + + bool IsBoundary() const { + return kind == HeadAsyncSyncKind::kCommit || + kind == HeadAsyncSyncKind::kWaitDynamic || + kind == HeadAsyncSyncKind::kBlocked; } + }; - std::vector new_blocks; + enum class HeadSeqMode { + kSingletonOnly, + kTakeFirstElement, + }; - for (const Block &block : ordered_stmts_) { - int stage = pipeline_info_.at(block).stage; - PrimExpr inbound = Bool(true); - PrimExpr skewed_loop_var = new_loop_var - stage; - if (need_bound_check) - inbound = And( - pipeline_loop_->min <= skewed_loop_var, - (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent)); + struct DeterministicNoWaitCommitEffect { + bool deterministic{true}; + bool has_wait{false}; + int commit_groups{0}; - Block new_block = Downcast( - PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, - pipeline_loop_, max_stage_ != 1)(block)); + static DeterministicNoWaitCommitEffect Unknown() { + DeterministicNoWaitCommitEffect effect; + effect.deterministic = false; + return effect; + } - PrimExpr delta = start - pipeline_loop_->min; - PrimExpr normalized_access_index = - is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; + static DeterministicNoWaitCommitEffect Wait() { + DeterministicNoWaitCommitEffect effect; + effect.has_wait = true; + return effect; + } + }; - normalized_access_index = analyzer_.Simplify(normalized_access_index); + // Analyze a stmt for one specific question used by wait relaxation: + // can we prove that it contributes a deterministic number of commit groups + // without crossing a wait boundary? The analyzer exposes the effect as + // structured state instead of overloading std::optional with both + // "unknown" and "has wait" meanings. + class DeterministicNoWaitCommitAnalyzer { + public: + explicit DeterministicNoWaitCommitAnalyzer(const PipelineRewriter *rewriter) + : rewriter_(rewriter) {} + + DeterministicNoWaitCommitEffect Analyze(const Stmt &stmt) const { + if (const auto *let = stmt.as()) { + return Analyze(let->body); + } + if (const auto *attr = stmt.as()) { + return AnalyzeAttr(attr); + } + if (const auto *seq = stmt.as()) { + DeterministicNoWaitCommitEffect effect; + for (const Stmt &s : seq->seq) { + effect = Combine(effect, Analyze(s)); + if (!effect.deterministic) { + return effect; + } + } + return effect; + } + if (const auto *block = stmt.as()) { + return Analyze(block->body); + } + if (const auto *realize = stmt.as()) { + if (!is_one(realize->predicate)) { + return DeterministicNoWaitCommitEffect::Unknown(); + } + return Analyze(realize->block->body); + } + if (const auto *for_node = stmt.as()) { + return AnalyzeFor(for_node); + } + if (stmt.as()) { + return DeterministicNoWaitCommitEffect::Unknown(); + } + if (rewriter_->ContainsAsyncSyncScopes(stmt)) { + return DeterministicNoWaitCommitEffect::Unknown(); + } + return {}; + } - // Adjust the block predicate and the body according to the final loop - // bound - // [pipeline_loop_->min, extent). - if (!is_unit_loop) { - Var loop_iter = Downcast(new_loop_var); - inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}}); + private: + DeterministicNoWaitCommitEffect + AnalyzeAttr(const AttrStmtNode *attr) const { + if (PipelineRewriter::IsAsyncWaitQueueScope(attr) || + PipelineRewriter::IsAsyncWaitInflightCount(attr)) { + return DeterministicNoWaitCommitEffect::Wait(); } - new_block = Downcast(Substitute( - new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); + if (PipelineRewriter::IsAsyncCommitQueueScope(attr)) { + auto effect = Analyze(attr->body); + if (!effect.deterministic) { + return effect; + } + ++effect.commit_groups; + return effect; + } + return Analyze(attr->body); + } - // If there were Let-wrappers outside the original pipeline body that - // depended on the pipeline loop var, push them into each rewritten - // block with the correct per-block substitution. - // We iterate in reverse order so that earlier definitions scope over - // later ones. For example, if we have: - // id = ids[i] # depends on loop var - // id2 = ids2[id] # depends on id - // We want to produce: - // LetStmt(id, ids[...], - // LetStmt(id2, ids2[id], - // body)) - // So that id2's definition can reference id. - if (!loop_var_let_wrappers_.empty()) { - BlockNode *n = new_block.CopyOnWrite(); - Stmt inner = n->body; - for (auto it = loop_var_let_wrappers_.rbegin(); - it != loop_var_let_wrappers_.rend(); ++it) { - const auto &lw = *it; - PrimExpr substituted = Substitute( - lw.value, {{pipeline_loop_->loop_var, normalized_access_index}}); - inner = LetStmt(lw.var, substituted, inner); - } - n->body = inner; - } - - // Similarly, handle If-wrappers whose conditions depend on the - // pipeline loop var. - if (!loop_var_if_wrappers_.empty()) { - BlockNode *n = new_block.CopyOnWrite(); - Stmt inner = n->body; - for (auto it = loop_var_if_wrappers_.rbegin(); - it != loop_var_if_wrappers_.rend(); ++it) { - const auto &iw = *it; - PrimExpr substituted_condition = - Substitute(iw.condition, - {{pipeline_loop_->loop_var, normalized_access_index}}); - inner = IfThenElse(substituted_condition, inner, Stmt(), iw.span); - } - n->body = inner; - } - - new_blocks.push_back({inbound, new_block}); + DeterministicNoWaitCommitEffect AnalyzeFor(const ForNode *for_node) const { + if (for_node->thread_binding.defined()) { + return DeterministicNoWaitCommitEffect::Unknown(); + } + const int64_t *extent_imm = as_const_int(for_node->extent); + if (extent_imm == nullptr || *extent_imm < 0) { + return DeterministicNoWaitCommitEffect::Unknown(); + } + auto effect = Analyze(for_node->body); + if (!effect.deterministic) { + return effect; + } + effect.commit_groups *= static_cast(*extent_imm); + return effect; } - Array stmts; - for (const auto &block_info : new_blocks) { - stmts.push_back(BlockRealize({}, block_info.predicate, block_info.block)); + static DeterministicNoWaitCommitEffect + Combine(const DeterministicNoWaitCommitEffect &lhs, + const DeterministicNoWaitCommitEffect &rhs) { + if (!lhs.deterministic || !rhs.deterministic) { + return DeterministicNoWaitCommitEffect::Unknown(); + } + DeterministicNoWaitCommitEffect effect; + effect.has_wait = lhs.has_wait || rhs.has_wait; + effect.commit_groups = lhs.commit_groups + rhs.commit_groups; + return effect; } - Stmt new_loop{nullptr}; + const PipelineRewriter *rewriter_; + }; - if (stmts.empty()) { - return make_nop(); + Stmt + WrapLoopDependentWrappers(Stmt stmt, + const PrimExpr &normalized_access_index) const { + for (auto it = loop_var_if_wrappers_.rbegin(); + it != loop_var_if_wrappers_.rend(); ++it) { + const auto &iw = *it; + PrimExpr substituted_condition = Substitute( + iw.condition, {{pipeline_loop_->loop_var, normalized_access_index}}); + stmt = IfThenElse(substituted_condition, stmt, Stmt(), iw.span); } + for (auto it = loop_var_let_wrappers_.rbegin(); + it != loop_var_let_wrappers_.rend(); ++it) { + const auto &lw = *it; + PrimExpr substituted = Substitute( + lw.value, {{pipeline_loop_->loop_var, normalized_access_index}}); + stmt = LetStmt(lw.var, substituted, stmt); + } + return stmt; + } - if (stmts.size() == 1) { - new_loop = stmts[0]; + Stmt WrapPipelineStageContext(Stmt stmt, + const PrimExpr &normalized_access_index, + const Optional &pipeline_num_stages) { + if (!(pipeline_num_stages && pipeline_num_stages.value()->value > 1)) { + return stmt; + } + PrimExpr ns = IntImm(DataType::Int(32), pipeline_num_stages.value()->value); + PrimExpr stage_expr = + analyzer_.Simplify(FloorMod(normalized_access_index, ns)); + PrimExpr parity_expr = analyzer_.Simplify(FloorMod( + FloorDiv(normalized_access_index, ns), IntImm(DataType::Int(32), 2))); + stmt = AttrStmt(Integer(0), kPipelineMVBParityExpr, parity_expr, stmt); + stmt = AttrStmt(Integer(0), kPipelineMVBStageExpr, stage_expr, stmt); + return stmt; + } + + Optional + ComputePipelineMbarPhaseExpr(const PrimExpr &normalized_access_index, + const Optional &pipeline_num_stages) { + if (!pipeline_num_stages) { + return Optional(); + } + PrimExpr parity_expr; + if (pipeline_num_stages.value()->value <= 1) { + parity_expr = + FloorMod(normalized_access_index, IntImm(DataType::Int(32), 2)); } else { - new_loop = SeqStmt(stmts); + PrimExpr ns = + IntImm(DataType::Int(32), pipeline_num_stages.value()->value); + parity_expr = FloorMod(FloorDiv(normalized_access_index, ns), + IntImm(DataType::Int(32), 2)); } + return analyzer_.Simplify(parity_expr); + } - if (!is_unit_loop) { - Map preserved_annotations; - for (const auto &kv : pipeline_loop_->annotations) { - const String &key = kv.first; - if (kv.first != tir::attr::software_pipeline_stage && - kv.first != tir::attr::software_pipeline_order && - kv.first != tir::attr::software_pipeline_async_stages) { - preserved_annotations.Set(key, kv.second); - } + static bool IsAsyncCommitQueueScope(const AttrStmtNode *attr) { + return attr && attr->attr_key == tir::attr::async_commit_queue_scope; + } + + static bool IsAsyncWaitQueueScope(const AttrStmtNode *attr) { + return attr && attr->attr_key == tir::attr::async_wait_queue_scope; + } + + static bool IsAsyncWaitInflightCount(const AttrStmtNode *attr) { + return attr && attr->attr_key == tir::attr::async_wait_inflight_count; + } + + static int + PipelinedRetainGroups(const Optional &pipeline_num_stages) { + int retain = 1; + if (pipeline_num_stages) { + retain = + std::max(0, static_cast(pipeline_num_stages.value()->value) - 1); + } + return retain; + } + + Stmt StripPipelineContextAttrs(Stmt stmt) const { + while (const auto *attr = stmt.as()) { + if (attr->attr_key != kPipelineContextNumStages && + attr->attr_key != kPipelineMVBContextNumStages) { + break; } - new_loop = For(Downcast(new_loop_var), pipeline_loop_->min, extent, - unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, - std::move(new_loop), std::nullopt, preserved_annotations); + stmt = attr->body; } - return BlockRealize({}, Bool(true), - MakeBlock(new_loop, buffer_data_to_buffer_)); + return stmt; } - arith::Analyzer analyzer_; - Map buffer_data_to_buffer_; - Array pipeline_allocs_; - Array local_allocs_; - For pipeline_loop_; - PipelineInfo pipeline_info_; - int max_stage_ = -1; + Array FlattenTopLevelSeq(const Stmt &stmt) const { + if (const auto *seq = stmt.as()) { + return seq->seq; + } + return {stmt}; + } + + std::optional + TryGetStaticAsyncWaitCount(const AttrStmtNode *attr) const { + if (!IsAsyncWaitQueueScope(attr)) { + return std::nullopt; + } + const auto *inner = attr->body.as(); + if (!IsAsyncWaitInflightCount(inner)) { + return std::nullopt; + } + const int64_t *imm = as_const_int(inner->value); + if (!imm) { + return std::nullopt; + } + return static_cast(*imm); + } + + Stmt MakeStaticAsyncWaitStmtLike(const AttrStmtNode *attr, + int new_wait_n) const { + const auto *inner = attr->body.as(); + if (!IsAsyncWaitInflightCount(inner)) { + return AttrStmt(attr->node, attr->attr_key, attr->value, attr->body, + attr->span); + } + PrimExpr new_wait = make_const(inner->value.dtype(), new_wait_n); + Stmt new_inner = AttrStmt(inner->node, inner->attr_key, new_wait, + inner->body, inner->span); + return AttrStmt(attr->node, attr->attr_key, attr->value, new_inner, + attr->span); + } + + HeadAsyncSyncInfo AnalyzeHeadAsyncSync(const Stmt &stmt, + HeadSeqMode seq_mode) const { + if (const auto *let = stmt.as()) { + return AnalyzeHeadAsyncSync(let->body, seq_mode); + } + if (const auto *attr = stmt.as()) { + if (IsAsyncWaitQueueScope(attr)) { + if (auto wait_n = TryGetStaticAsyncWaitCount(attr)) { + return {HeadAsyncSyncKind::kWaitStatic, *wait_n}; + } + return {HeadAsyncSyncKind::kWaitDynamic, 0}; + } + if (IsAsyncCommitQueueScope(attr)) { + return {HeadAsyncSyncKind::kCommit, 0}; + } + if (IsAsyncWaitInflightCount(attr)) { + return {HeadAsyncSyncKind::kBlocked, 0}; + } + return AnalyzeHeadAsyncSync(attr->body, seq_mode); + } + if (const auto *seq = stmt.as()) { + if (seq->seq.empty()) { + return {}; + } + if (seq_mode == HeadSeqMode::kSingletonOnly && seq->seq.size() != 1) { + return {HeadAsyncSyncKind::kBlocked, 0}; + } + return AnalyzeHeadAsyncSync(seq->seq[0], seq_mode); + } + if (const auto *block = stmt.as()) { + return AnalyzeHeadAsyncSync(block->body, seq_mode); + } + if (const auto *realize = stmt.as()) { + if (is_one(realize->predicate)) { + return AnalyzeHeadAsyncSync(realize->block->body, seq_mode); + } + return {HeadAsyncSyncKind::kBlocked, 0}; + } + return {}; + } + + ClassifiedAsyncSyncStmt ClassifySimpleAsyncSyncStmt(const Stmt &stmt) const { + HeadAsyncSyncInfo info = + AnalyzeHeadAsyncSync(stmt, HeadSeqMode::kSingletonOnly); + switch (info.kind) { + case HeadAsyncSyncKind::kCommit: + return {AsyncSyncStmtKind::kCommit, 0}; + case HeadAsyncSyncKind::kWaitStatic: + return {AsyncSyncStmtKind::kWaitStatic, info.wait_n}; + case HeadAsyncSyncKind::kWaitDynamic: + return {AsyncSyncStmtKind::kWaitDynamic, 0}; + default: + return {}; + } + } + + bool ContainsAsyncSyncScopes(const Stmt &stmt) const { + bool found = false; + PostOrderVisit(stmt, [&](const ObjectRef &obj) { + if (found) { + return; + } + if (const auto *attr = obj.as()) { + if (IsAsyncCommitQueueScope(attr) || IsAsyncWaitQueueScope(attr)) { + found = true; + } + } + }); + return found; + } + + bool ContainsAsyncCommitScopes(const Stmt &stmt) const { + bool found = false; + PostOrderVisit(stmt, [&](const ObjectRef &obj) { + if (found) { + return; + } + if (const auto *attr = obj.as()) { + if (IsAsyncCommitQueueScope(attr)) { + found = true; + } + } + }); + return found; + } + + AsyncSyncSummary SummarizeAsyncSyncScopes(const Stmt &stmt) const { + AsyncSyncSummary summary; + PostOrderVisit(stmt, [&](const ObjectRef &obj) { + if (const auto *attr = obj.as()) { + if (IsAsyncCommitQueueScope(attr)) { + ++summary.commit; + } else if (IsAsyncWaitQueueScope(attr)) { + ++summary.wait; + } + } + }); + return summary; + } + + std::optional + TryGetDeterministicNoWaitCommitGroups(const Stmt &stmt) const { + auto effect = DeterministicNoWaitCommitAnalyzer(this).Analyze(stmt); + if (!effect.deterministic || effect.has_wait) { + return std::nullopt; + } + return effect.commit_groups; + } + + int GuaranteedNewGroupsBeforeNextWait(const Array &body, + int start_idx) const { + int guaranteed_groups = 0; + for (int i = start_idx, n = static_cast(body.size()); i < n; ++i) { + AsyncSyncSummary summary = SummarizeAsyncSyncScopes(body[i]); + if (summary.wait > 0) { + break; + } + if (summary.commit == 0) { + continue; + } + if (auto commits = TryGetDeterministicNoWaitCommitGroups(body[i])) { + guaranteed_groups += *commits; + continue; + } + break; + } + return guaranteed_groups; + } + + Stmt RewriteWaitStaticInSimpleWrapper(const Stmt &stmt, int new_wait_n, + bool *changed) const { + ClassifiedAsyncSyncStmt cls = ClassifySimpleAsyncSyncStmt(stmt); + if (cls.kind != AsyncSyncStmtKind::kWaitStatic) { + return stmt; + } + if (const auto *attr = stmt.as()) { + if (IsAsyncWaitQueueScope(attr)) { + *changed = true; + return MakeStaticAsyncWaitStmtLike(attr, new_wait_n); + } + } + if (const auto *let = stmt.as()) { + Stmt new_body = + RewriteWaitStaticInSimpleWrapper(let->body, new_wait_n, changed); + if (*changed) { + return LetStmt(let->var, let->value, new_body, let->span); + } + return stmt; + } + if (const auto *attr = stmt.as()) { + Stmt new_body = + RewriteWaitStaticInSimpleWrapper(attr->body, new_wait_n, changed); + if (*changed) { + return AttrStmt(attr->node, attr->attr_key, attr->value, new_body, + attr->span); + } + return stmt; + } + if (const auto *seq = stmt.as()) { + if (seq->seq.size() == 1) { + Stmt inner = + RewriteWaitStaticInSimpleWrapper(seq->seq[0], new_wait_n, changed); + if (*changed) { + return SeqStmt({inner}); + } + } + return stmt; + } + if (const auto *block = stmt.as()) { + Stmt inner = + RewriteWaitStaticInSimpleWrapper(block->body, new_wait_n, changed); + if (*changed) { + Block new_block = Downcast(stmt); + new_block.CopyOnWrite()->body = inner; + return new_block; + } + return stmt; + } + if (const auto *realize = stmt.as()) { + if (is_one(realize->predicate)) { + Stmt inner = RewriteWaitStaticInSimpleWrapper(realize->block->body, + new_wait_n, changed); + if (*changed) { + Block new_block = realize->block; + new_block.CopyOnWrite()->body = inner; + return BlockRealize(realize->iter_values, realize->predicate, + new_block, realize->span); + } + } + return stmt; + } + return stmt; + } + + std::optional TryGetHeadStaticWaitCount(const Stmt &stmt) const { + HeadAsyncSyncInfo info = + AnalyzeHeadAsyncSync(stmt, HeadSeqMode::kTakeFirstElement); + if (info.kind == HeadAsyncSyncKind::kWaitStatic) { + return info.wait_n; + } + return std::nullopt; + } + + std::optional TryGetFirstStaticWaitCount(const Stmt &stmt) const { + if (const auto *let = stmt.as()) { + return TryGetFirstStaticWaitCount(let->body); + } + if (const auto *attr = stmt.as()) { + HeadAsyncSyncInfo info = + AnalyzeHeadAsyncSync(stmt, HeadSeqMode::kTakeFirstElement); + if (info.kind == HeadAsyncSyncKind::kWaitStatic) { + return info.wait_n; + } + if (info.IsBoundary()) { + return std::nullopt; + } + return TryGetFirstStaticWaitCount(attr->body); + } + if (const auto *seq = stmt.as()) { + for (const Stmt &elem : seq->seq) { + HeadAsyncSyncInfo info = + AnalyzeHeadAsyncSync(elem, HeadSeqMode::kTakeFirstElement); + if (info.kind == HeadAsyncSyncKind::kWaitStatic) { + return info.wait_n; + } + if (info.IsBoundary() || ContainsAsyncSyncScopes(elem)) { + return std::nullopt; + } + } + return std::nullopt; + } + if (const auto *block = stmt.as()) { + return TryGetFirstStaticWaitCount(block->body); + } + if (const auto *realize = stmt.as()) { + if (is_one(realize->predicate)) { + return TryGetFirstStaticWaitCount(realize->block->body); + } + } + return std::nullopt; + } + + Stmt RewriteHeadStaticWaitInWrapper(const Stmt &stmt, int new_wait_n, + bool *changed) const { + if (const auto *let = stmt.as()) { + Stmt new_body = + RewriteHeadStaticWaitInWrapper(let->body, new_wait_n, changed); + if (*changed) { + return LetStmt(let->var, let->value, new_body, let->span); + } + return stmt; + } + if (const auto *attr = stmt.as()) { + if (IsAsyncWaitQueueScope(attr)) { + *changed = true; + return MakeStaticAsyncWaitStmtLike(attr, new_wait_n); + } + Stmt new_body = + RewriteHeadStaticWaitInWrapper(attr->body, new_wait_n, changed); + if (*changed) { + return AttrStmt(attr->node, attr->attr_key, attr->value, new_body, + attr->span); + } + return stmt; + } + if (const auto *seq = stmt.as()) { + if (seq->seq.empty()) { + return stmt; + } + Array new_seq = seq->seq; + new_seq.Set( + 0, RewriteHeadStaticWaitInWrapper(seq->seq[0], new_wait_n, changed)); + if (*changed) { + return SeqStmt(new_seq); + } + return stmt; + } + if (const auto *block = stmt.as()) { + Stmt new_body = + RewriteHeadStaticWaitInWrapper(block->body, new_wait_n, changed); + if (*changed) { + Block new_block = Downcast(stmt); + new_block.CopyOnWrite()->body = new_body; + return new_block; + } + return stmt; + } + if (const auto *realize = stmt.as()) { + if (is_one(realize->predicate)) { + Stmt new_body = RewriteHeadStaticWaitInWrapper(realize->block->body, + new_wait_n, changed); + if (*changed) { + Block new_block = realize->block; + new_block.CopyOnWrite()->body = new_body; + return BlockRealize(realize->iter_values, realize->predicate, + new_block, realize->span); + } + } + return stmt; + } + return stmt; + } + + Stmt RewriteFirstStaticWaitInWrapper(const Stmt &stmt, int new_wait_n, + bool *changed) const { + if (const auto *let = stmt.as()) { + Stmt new_body = + RewriteFirstStaticWaitInWrapper(let->body, new_wait_n, changed); + if (*changed) { + return LetStmt(let->var, let->value, new_body, let->span); + } + return stmt; + } + if (const auto *attr = stmt.as()) { + if (IsAsyncWaitQueueScope(attr)) { + *changed = true; + return MakeStaticAsyncWaitStmtLike(attr, new_wait_n); + } + if (IsAsyncCommitQueueScope(attr) || IsAsyncWaitInflightCount(attr)) { + return stmt; + } + Stmt new_body = + RewriteFirstStaticWaitInWrapper(attr->body, new_wait_n, changed); + if (*changed) { + return AttrStmt(attr->node, attr->attr_key, attr->value, new_body, + attr->span); + } + return stmt; + } + if (const auto *seq = stmt.as()) { + Array new_seq = seq->seq; + for (int i = 0, n = static_cast(new_seq.size()); i < n; ++i) { + Stmt updated = + RewriteFirstStaticWaitInWrapper(new_seq[i], new_wait_n, changed); + if (*changed) { + new_seq.Set(i, updated); + return SeqStmt(new_seq); + } + if (ContainsAsyncSyncScopes(new_seq[i])) { + return stmt; + } + } + return stmt; + } + if (const auto *block = stmt.as()) { + Stmt new_body = + RewriteFirstStaticWaitInWrapper(block->body, new_wait_n, changed); + if (*changed) { + Block new_block = Downcast(stmt); + new_block.CopyOnWrite()->body = new_body; + return new_block; + } + return stmt; + } + if (const auto *realize = stmt.as()) { + if (is_one(realize->predicate)) { + Stmt new_body = RewriteFirstStaticWaitInWrapper(realize->block->body, + new_wait_n, changed); + if (*changed) { + Block new_block = realize->block; + new_block.CopyOnWrite()->body = new_body; + return BlockRealize(realize->iter_values, realize->predicate, + new_block, realize->span); + } + } + return stmt; + } + return stmt; + } + + Stmt MaybeRelaxLoopWaits(const For &loop, int pre_outstanding_lb) const { + int retain = PipelinedRetainGroups(GetPipelineNumStages(loop.get())); + if (retain <= 0 || !loop.defined()) { + return loop; + } + const auto *seq = loop->body.as(); + if (!seq || seq->seq.empty()) { + return loop; + } + + Array body = seq->seq; + bool changed = false; + int outstanding_lb = std::max(0, pre_outstanding_lb); + int groups_since_wait_lb = 0; + bool seen_wait_boundary = false; + + for (int i = 0, n = static_cast(body.size()); i < n; ++i) { + ClassifiedAsyncSyncStmt cls = ClassifySimpleAsyncSyncStmt(body[i]); + if (cls.kind == AsyncSyncStmtKind::kCommit) { + ++outstanding_lb; + ++groups_since_wait_lb; + continue; + } + if (cls.kind == AsyncSyncStmtKind::kWaitDynamic) { + seen_wait_boundary = true; + outstanding_lb = 0; + groups_since_wait_lb = 0; + continue; + } + if (cls.kind == AsyncSyncStmtKind::kWaitStatic) { + int effective_wait_n = cls.wait_n; + if (cls.wait_n == 0) { + int groups_after_wait_lb = + GuaranteedNewGroupsBeforeNextWait(body, i + 1); + int per_sync_groups = groups_since_wait_lb; + bool uses_head_fallback = + (per_sync_groups == 0 && !seen_wait_boundary); + if (uses_head_fallback) { + per_sync_groups = 1; + } + int candidate_wait_n = + std::max(0, std::min(retain * per_sync_groups, 7)); + bool enough_pre_outstanding = + !uses_head_fallback || outstanding_lb >= (candidate_wait_n + 1); + if (candidate_wait_n > 0 && enough_pre_outstanding && + (!uses_head_fallback || groups_after_wait_lb > 0)) { + bool changed_wait = false; + body.Set(i, RewriteWaitStaticInSimpleWrapper( + body[i], candidate_wait_n, &changed_wait)); + if (changed_wait) { + changed = true; + effective_wait_n = candidate_wait_n; + } + } + } + seen_wait_boundary = true; + outstanding_lb = std::min(outstanding_lb, effective_wait_n); + groups_since_wait_lb = 0; + continue; + } + + AsyncSyncSummary summary = SummarizeAsyncSyncScopes(body[i]); + if (summary.wait == 0) { + if (auto commits = TryGetDeterministicNoWaitCommitGroups(body[i])) { + outstanding_lb += *commits; + groups_since_wait_lb += *commits; + continue; + } + } + if (summary.wait > 0) { + seen_wait_boundary = true; + } + outstanding_lb = 0; + groups_since_wait_lb = 0; + } + + if (!changed) { + return loop; + } + For new_loop = loop; + new_loop.CopyOnWrite()->body = body.size() == 1 ? body[0] : SeqStmt(body); + return new_loop; + } + + Stmt RelaxLoopWaitsInSimpleWrapper(const Stmt &stmt, int pre_outstanding_lb, + bool *changed) const { + if (const auto *loop = stmt.as()) { + Stmt relaxed = + MaybeRelaxLoopWaits(Downcast(stmt), pre_outstanding_lb); + *changed = !relaxed.same_as(stmt); + return relaxed; + } + if (const auto *let = stmt.as()) { + Stmt new_body = + RelaxLoopWaitsInSimpleWrapper(let->body, pre_outstanding_lb, changed); + if (*changed) { + return LetStmt(let->var, let->value, new_body, let->span); + } + return stmt; + } + if (const auto *attr = stmt.as()) { + Stmt new_body = RelaxLoopWaitsInSimpleWrapper( + attr->body, pre_outstanding_lb, changed); + if (*changed) { + return AttrStmt(attr->node, attr->attr_key, attr->value, new_body, + attr->span); + } + return stmt; + } + if (const auto *seq = stmt.as()) { + if (seq->seq.size() == 1) { + Stmt inner = RelaxLoopWaitsInSimpleWrapper(seq->seq[0], + pre_outstanding_lb, changed); + if (*changed) { + return SeqStmt({inner}); + } + } + return stmt; + } + if (const auto *block = stmt.as()) { + Stmt new_body = RelaxLoopWaitsInSimpleWrapper( + block->body, pre_outstanding_lb, changed); + if (*changed) { + Block new_block = Downcast(stmt); + new_block.CopyOnWrite()->body = new_body; + return new_block; + } + return stmt; + } + if (const auto *realize = stmt.as()) { + if (is_one(realize->predicate)) { + Stmt new_body = RelaxLoopWaitsInSimpleWrapper( + realize->block->body, pre_outstanding_lb, changed); + if (*changed) { + Block new_block = realize->block; + new_block.CopyOnWrite()->body = new_body; + return BlockRealize(realize->iter_values, realize->predicate, + new_block, realize->span); + } + } + return stmt; + } + return stmt; + } + + class AsyncPipelineLoopWaitRelaxer : public StmtExprMutator { + public: + explicit AsyncPipelineLoopWaitRelaxer(const PipelineRewriter *rewriter) + : rewriter_(rewriter) {} + + Stmt VisitStmt_(const SeqStmtNode *op) final { + Array visited; + visited.reserve(op->seq.size()); + for (const Stmt &stmt : op->seq) { + visited.push_back(this->VisitStmt(stmt)); + } + + int outstanding_lb = 0; + for (int i = 0, n = static_cast(visited.size()); i < n; ++i) { + Stmt current = visited[i]; + bool changed_loop = false; + current = rewriter_->RelaxLoopWaitsInSimpleWrapper( + current, outstanding_lb, &changed_loop); + if (changed_loop) { + visited.Set(i, current); + } + ClassifiedAsyncSyncStmt cls = + rewriter_->ClassifySimpleAsyncSyncStmt(current); + if (cls.kind == AsyncSyncStmtKind::kCommit) { + ++outstanding_lb; + continue; + } + if (cls.kind == AsyncSyncStmtKind::kWaitStatic) { + outstanding_lb = std::min(outstanding_lb, cls.wait_n); + continue; + } + if (cls.kind == AsyncSyncStmtKind::kWaitDynamic) { + outstanding_lb = 0; + continue; + } + AsyncSyncSummary summary = rewriter_->SummarizeAsyncSyncScopes(current); + if (summary.wait == 0) { + if (auto commits = + rewriter_->TryGetDeterministicNoWaitCommitGroups(current)) { + outstanding_lb += *commits; + continue; + } + } + if (summary.wait > 0) { + outstanding_lb = 0; + } + } + + if (visited.empty()) { + return Evaluate(0); + } + if (visited.size() == 1) { + return visited[0]; + } + return SeqStmt(visited); + } + + private: + const PipelineRewriter *rewriter_; + }; + + Array RelaxTrailingConsumerWaits(Array seq, int retain) const { + if (retain <= 0 || seq.size() <= 1) { + return seq; + } + std::vector suffix_wait_indices; + for (int i = static_cast(seq.size()) - 1; i >= 0; --i) { + if (ContainsAsyncCommitScopes(seq[i])) { + break; + } + auto first_wait = TryGetFirstStaticWaitCount(seq[i]); + if (!first_wait.has_value() || *first_wait != 0) { + break; + } + suffix_wait_indices.push_back(i); + } + if (suffix_wait_indices.size() <= 1) { + return seq; + } + for (size_t pos = 1; pos < suffix_wait_indices.size(); ++pos) { + bool changed = false; + int idx = suffix_wait_indices[pos]; + seq.Set(idx, RewriteFirstStaticWaitInWrapper(seq[idx], retain, &changed)); + } + return seq; + } + + void PopulateWaitCounts( + const std::vector &new_stmts, + arith::Analyzer *ana_normalized, + const std::unordered_map &buffer_to_commit_group, + std::map *async_states_local) { + for (size_t i = 0; i < new_stmts.size(); ++i) { + if (new_stmts[i].is_async) { + for (const BufferRegion &write_region : new_stmts[i].writes) { + (*async_states_local)[new_stmts[i].stage].seen.insert( + write_region->buffer.get()); + } + continue; + } + + int producer_stage_idx = -1; + for (const BufferRegion &read_region : new_stmts[i].reads) { + for (const auto &kv : async_states_) { + if (kv.first <= new_stmts[i].stage && + kv.second.writes(read_region->buffer)) { + ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first) + << "A dependency on multiple async stages is not supported"; + producer_stage_idx = kv.first; + } + } + } + + if (producer_stage_idx == -1) { + continue; + } + + auto &dep_local_state = (*async_states_local)[producer_stage_idx]; + int num_commit_group = dep_local_state.commit_groups.size(); + std::vector> producer_head_per_commit; + std::vector dependent_commit_groups; + + if (num_commit_group == 0) { + ICHECK(!dep_local_state.producer_head); + dependent_commit_groups.push_back(-1); + producer_head_per_commit.push_back( + async_states_[producer_stage_idx].producer_head); + } else { + ICHECK(dep_local_state.producer_head); + std::vector need_wait_count(num_commit_group, true); + for (const BufferRegion &read_region : new_stmts[i].reads) { + if (!async_states_[producer_stage_idx].writes(read_region->buffer)) { + continue; + } + auto commit_group_id = + buffer_to_commit_group.at(read_region->buffer.get()); + if (!need_wait_count[commit_group_id]) { + continue; + } + dependent_commit_groups.push_back(commit_group_id); + if (!dep_local_state.seen.count(read_region->buffer.get())) { + producer_head_per_commit.push_back( + dep_local_state.producer_head.value() - 1); + } else { + producer_head_per_commit.push_back( + dep_local_state.producer_head.value()); + } + need_wait_count[commit_group_id] = false; + } + } + + PrimExpr wait_count = [&]() { + PrimExpr sum = PrimExpr(0); + for (const Optional &producer_head : + producer_head_per_commit) { + if (producer_head && + ana_normalized->CanProve(producer_head.value() >= 0)) { + sum += analyzer_.Simplify(producer_head.value() - + new_stmts[i].access_index); + } else { + return PrimExpr(0); + } + } + return sum; + }(); + + for (int commit_group_id : dependent_commit_groups) { + auto &pending_wait = dep_local_state.pending_waits[commit_group_id]; + if (!pending_wait.valid()) { + pending_wait = {static_cast(i), wait_count}; + } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) { + pending_wait = {pending_wait.insert_before, wait_count}; + } + } + } + } + + std::vector CompletePipelineLoopStatements( + const std::vector &stmts, + const std::map &async_states_local, + arith::Analyzer *ana_normalized) const { + std::vector new_stmts; + new_stmts.reserve(stmts.size()); + for (const auto &stmt : stmts) { + new_stmts.push_back( + {stmt.stage, stmt.access_index, stmt.predicate, stmt.stmt}); + } + + std::vector commit_group_tags(new_stmts.size(), -1); + std::unordered_map commit_group_tag_to_stage; + int next_commit_group_tag = 0; + std::map> waits_before_stmt; + auto make_wait_stmt = [](int stage_id, PrimExpr wait_count, Stmt body) { + auto zero = make_zero(DataType::Int(32)); + return AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id, + AttrStmt(zero, tir::attr::async_wait_inflight_count, + wait_count, body)); + }; + auto merge_wait_before_stmt = [&](int insert_before, int stage_id, + PrimExpr wait_count) { + auto &waits_at_stmt = waits_before_stmt[insert_before]; + auto it = waits_at_stmt.find(stage_id); + if (it == waits_at_stmt.end()) { + waits_at_stmt.emplace(stage_id, ana_normalized->Simplify(wait_count)); + } else if (ana_normalized->CanProve(wait_count < it->second)) { + it->second = ana_normalized->Simplify(wait_count); + } + }; + + for (const auto &[stage_id, state] : async_states_local) { + if (!state.commit_groups.empty()) { + for (const auto &group_stmt_indices : state.commit_groups) { + int commit_group_tag = next_commit_group_tag++; + commit_group_tag_to_stage.emplace(commit_group_tag, stage_id); + for (size_t stmt_idx : group_stmt_indices) { + ICHECK(stmt_idx < new_stmts.size()); + commit_group_tags[stmt_idx] = commit_group_tag; + } + } + } + + for (const auto &[commit_group_id, pending_wait] : state.pending_waits) { + if (!pending_wait.valid()) { + continue; + } + PrimExpr wait_count = ana_normalized->Simplify(pending_wait.wait_count); + if (state.predicate && + !ana_normalized->CanProve(state.predicate.value())) { + PrimExpr predicate = + ana_normalized->Simplify(state.predicate.value()); + if (is_zero(predicate)) { + continue; + } + merge_wait_before_stmt(pending_wait.insert_before, stage_id, + wait_count); + continue; + } + + merge_wait_before_stmt(pending_wait.insert_before, stage_id, + wait_count); + } + } + + std::vector result; + for (size_t i = 0; i < new_stmts.size();) { + if (auto it = waits_before_stmt.find(i); it != waits_before_stmt.end()) { + for (const auto &[stage_id, wait_count] : it->second) { + Stmt wait_stmt = make_wait_stmt(stage_id, wait_count, Evaluate(0)); + if (auto state_it = async_states_local.find(stage_id); + state_it != async_states_local.end() && + state_it->second.predicate && + !ana_normalized->CanProve(state_it->second.predicate.value())) { + PrimExpr predicate = + ana_normalized->Simplify(state_it->second.predicate.value()); + if (is_zero(predicate)) { + continue; + } + wait_stmt = IfThenElse(predicate, wait_stmt, Evaluate(0)); + } + result.push_back({new_stmts[i].stage, new_stmts[i].access_index, + new_stmts[i].predicate, wait_stmt}); + } + } + + if (commit_group_tags[i] == -1) { + result.push_back(new_stmts[i]); + ++i; + continue; + } + + int commit_group_tag = commit_group_tags[i]; + int stage_id = commit_group_tag_to_stage.at(commit_group_tag); + Array group_stmts; + PrimExpr access_index = new_stmts[i].access_index; + PrimExpr predicate = new_stmts[i].predicate; + for (; i < new_stmts.size() && commit_group_tags[i] == commit_group_tag; + ++i) { + group_stmts.push_back(new_stmts[i].stmt); + } + Stmt group_body = + group_stmts.size() == 1 ? group_stmts[0] : SeqStmt(group_stmts); + Stmt commit_queue_scope = + AttrStmt(make_zero(DataType::Int(32)), + tir::attr::async_commit_queue_scope, stage_id, group_body); + if (!is_one(predicate) && !ana_normalized->CanProve(predicate)) { + PrimExpr simplified_predicate = ana_normalized->Simplify(predicate); + if (!is_zero(simplified_predicate)) { + commit_queue_scope = + IfThenElse(simplified_predicate, commit_queue_scope, Evaluate(0)); + } + } + result.push_back({stage_id, access_index, predicate, commit_queue_scope}); + } + return result; + } + + /*! + * \brief Emit the pipeline loop in the given range. + * \param start The start of the range + * \param end The end of the range + * \param unroll_loop Whether the loop should be unrolled. + * \return The result loop. + */ + Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop, + bool need_bound_check) { + PrimExpr new_loop_var; + PrimExpr extent = end - start; + Optional pipeline_num_stages = + GetPipelineNumStages(pipeline_loop_.get()); + auto make_nop = []() { + return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); + }; + + if (unroll_loop) { + if (const int64_t *extent_imm = as_const_int(extent)) { + if (*extent_imm > 1) { + Array expanded; + expanded.reserve(static_cast(*extent_imm)); + for (int64_t iter = 0; iter < *extent_imm; ++iter) { + PrimExpr unit_start = + analyzer_.Simplify(start + IntImm(extent.dtype(), iter)); + PrimExpr unit_end = + analyzer_.Simplify(start + IntImm(extent.dtype(), iter + 1)); + Stmt unit_stmt = + EmitImpl(unit_start, unit_end, false, need_bound_check); + expanded.push_back(StripPipelineContextAttrs(unit_stmt)); + } + Stmt result = expanded.size() == 1 ? expanded[0] : SeqStmt(expanded); + if (pipeline_num_stages) { + if (pipeline_num_stages.value()->value > 1) { + result = AttrStmt(Integer(0), kPipelineMVBContextNumStages, + Downcast(pipeline_num_stages.value()), + result); + } + result = AttrStmt(Integer(0), kPipelineContextNumStages, + Downcast(pipeline_num_stages.value()), + result); + } + return result; + } + } + } + + bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); + if (is_unit_loop) { + new_loop_var = start; // use constants as the loop var for unit loops + } else { + new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); + // Bind the iteration domain [start, end) to strengthen analyzer facts. + analyzer_.Bind(Downcast(new_loop_var), + Range::FromMinExtent(start, end - start)); + } + // Keep the bound constraints active for all analysis below. + // Only meaningful when the loop var is symbolic (non-unit loop). + std::unique_ptr> ctx_lb_guard; + std::unique_ptr> ctx_ub_guard; + if (!is_unit_loop) { + Var loop_iter = Downcast(new_loop_var); + ctx_lb_guard.reset( + new With(&analyzer_, loop_iter >= start)); + ctx_ub_guard.reset( + new With(&analyzer_, loop_iter < end)); + } + + arith::Analyzer ana_normalized; + if (!is_unit_loop) { + ana_normalized.Bind(Downcast(new_loop_var), + Range(pipeline_loop_->min, extent)); + } + + std::vector new_stmts; + std::map async_states_local; + std::unordered_map buffer_to_commit_group; + + for (const Block &block : ordered_stmts_) { + const auto &pipeline_anno = pipeline_info_.at(block); + int stage = pipeline_anno.stage; + PrimExpr inbound = Bool(true); + PrimExpr skewed_loop_var = new_loop_var - stage; + if (need_bound_check) + inbound = And( + pipeline_loop_->min <= skewed_loop_var, + (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent)); + + Block new_block = Downcast( + PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, + pipeline_loop_, max_stage_ != 1)(block)); + + PrimExpr delta = start - pipeline_loop_->min; + PrimExpr normalized_access_index = + is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; + + normalized_access_index = analyzer_.Simplify(normalized_access_index); + + // Adjust the block predicate and the body according to the final loop + // bound + // [pipeline_loop_->min, extent). + if (!is_unit_loop) { + Var loop_iter = Downcast(new_loop_var); + inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}}); + } + inbound = ana_normalized.Simplify(inbound); + if (is_zero(inbound)) { + continue; + } + new_block = Downcast(Substitute( + new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); + + Stmt rewritten_stmt = BlockRealize({}, inbound, new_block); + rewritten_stmt = WrapLoopDependentWrappers(std::move(rewritten_stmt), + normalized_access_index); + rewritten_stmt = WrapPipelineStageContext(std::move(rewritten_stmt), + normalized_access_index, + pipeline_num_stages); + Optional pipeline_mbar_phase = ComputePipelineMbarPhaseExpr( + normalized_access_index, pipeline_num_stages); + + bool is_async = pipeline_anno.async; + if (is_async) { + auto &local_state = async_states_local[stage]; + int commit_group_id = -1; + if (pipeline_anno.async_group_id >= 0) { + auto it = local_state.annotated_group_to_commit_group.find( + pipeline_anno.async_group_id); + if (it == local_state.annotated_group_to_commit_group.end()) { + commit_group_id = local_state.commit_groups.size(); + local_state.commit_groups.push_back({new_stmts.size()}); + local_state.annotated_group_to_commit_group.emplace( + pipeline_anno.async_group_id, commit_group_id); + } else { + commit_group_id = it->second; + local_state.commit_groups[commit_group_id].push_back( + new_stmts.size()); + } + } else if (local_state.commit_groups.empty() || local_state.consumed) { + commit_group_id = local_state.commit_groups.size(); + local_state.commit_groups.push_back({new_stmts.size()}); + } else { + commit_group_id = local_state.commit_groups.size() - 1; + local_state.commit_groups.back().push_back(new_stmts.size()); + } + + for (const BufferRegion &write_region : new_block->writes) { + async_states_[stage].dst_buffers.insert(write_region->buffer.get()); + buffer_to_commit_group[write_region->buffer.get()] = commit_group_id; + } + async_states_[stage].producer_head = normalized_access_index; + local_state.producer_head = normalized_access_index; + if (!local_state.predicate || + ana_normalized.CanProve(local_state.predicate.value())) { + local_state.predicate = inbound; + } else { + local_state.predicate = + ana_normalized.Simplify(local_state.predicate.value() & inbound); + } + rewritten_stmt = + SimtProducerAnnotator::Annotate(rewritten_stmt, target_); + rewritten_stmt = AttrStmt(make_zero(DataType::Int(32)), + tir::attr::async_scope, 1, rewritten_stmt); + } + if (pipeline_mbar_phase) { + rewritten_stmt = TileOpMbarPhaseAnnotator::Annotate( + rewritten_stmt, pipeline_mbar_phase.value()); + } + + new_stmts.push_back({stage, inbound, new_block->reads, new_block->writes, + normalized_access_index, is_async, rewritten_stmt}); + + for (const BufferRegion &read_region : new_block->reads) { + for (const auto &kv : async_states_) { + if (kv.first <= stage && kv.second.writes(read_region->buffer)) { + async_states_local[kv.first].consumed = true; + } + } + } + } + + PopulateWaitCounts(new_stmts, &ana_normalized, buffer_to_commit_group, + &async_states_local); + std::vector final_stmts = CompletePipelineLoopStatements( + new_stmts, async_states_local, &ana_normalized); + + Array stmts; + for (const auto &stmt_info : final_stmts) { + stmts.push_back(stmt_info.stmt); + } + + Stmt new_loop{nullptr}; + + if (stmts.empty()) { + return make_nop(); + } + + if (stmts.size() == 1) { + new_loop = stmts[0]; + } else { + new_loop = SeqStmt(stmts); + } + + if (!is_unit_loop) { + Map preserved_annotations; + for (const auto &kv : pipeline_loop_->annotations) { + const String &key = kv.first; + if (kv.first != tir::attr::software_pipeline_stage && + kv.first != tir::attr::software_pipeline_order && + kv.first != tir::attr::software_pipeline_async_stages && + kv.first != kPipelineAsyncProducers && + kv.first != kPipelineAsyncProducerGroups && + kv.first != kPipelineTmaCopies && kv.first != "num_stages") { + preserved_annotations.Set(key, kv.second); + } + } + if (pipeline_num_stages && + preserved_annotations.find("tl_pipelined_num_stages") == + preserved_annotations.end()) { + preserved_annotations.Set("tl_pipelined_num_stages", + pipeline_num_stages.value()); + } + new_loop = For(Downcast(new_loop_var), pipeline_loop_->min, extent, + unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, + std::move(new_loop), std::nullopt, preserved_annotations); + } + Stmt result = BlockRealize({}, Bool(true), + MakeBlock(new_loop, buffer_data_to_buffer_)); + if (pipeline_num_stages) { + if (pipeline_num_stages.value()->value > 1) { + result = + AttrStmt(Integer(0), kPipelineMVBContextNumStages, + Downcast(pipeline_num_stages.value()), result); + } + result = + AttrStmt(Integer(0), kPipelineContextNumStages, + Downcast(pipeline_num_stages.value()), result); + } + return result; + } + + arith::Analyzer analyzer_; + Map buffer_data_to_buffer_; + Array pipeline_allocs_; + Array local_allocs_; + For pipeline_loop_; + PipelineInfo pipeline_info_; + int max_stage_ = -1; Map buffer_remap_; + Optional target_; Array ordered_stmts_; std::vector loop_var_let_wrappers_; std::vector loop_var_if_wrappers_; + std::map async_states_; }; /*! @@ -728,11 +2297,478 @@ void BuildDependencyGraph(const Array &blocks, } } +// --------------------------------------------------------------------------- +// Helpers for pipeline-level TMA barrier management +// --------------------------------------------------------------------------- + +/*! + * \brief Rewrite a block's body, converting tl.tileop.copy calls to + * tl.tileop.tma_copy with barrier and emit_arrive annotations. + */ +class CopyToTmaCopyRewriter : public StmtExprMutator { +public: + CopyToTmaCopyRewriter(const Buffer &barrier_buf, PrimExpr barrier_id, + bool emit_arrive = true) + : barrier_buf_(barrier_buf), barrier_id_(std::move(barrier_id)), + emit_arrive_(emit_arrive) {} + + PrimExpr VisitExpr_(const CallNode *op) final { + static const Op ©_op = Op::Get("tl.tileop.copy"); + static const Op &tma_copy_op = Op::Get("tl.tileop.tma_copy"); + static const Op &im2col_op = Op::Get("tl.tileop.c2d_im2col"); + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(copy_op)) { + auto new_annotations = call->annotations; + new_annotations.Set("barrier", MakeBarrierRef(barrier_buf_, barrier_id_)); + new_annotations.Set("is_tma_copy", IntImm(DataType::Int(32), 1)); + new_annotations.Set("emit_arrive", + IntImm(DataType::Int(32), emit_arrive_ ? 1 : 0)); + return Call(call->dtype, tma_copy_op, call->args, new_annotations, + call->span); + } + // Annotate c2d_im2col with pipeline barrier so its Lower() uses it + // instead of allocating a separate internal barrier. + if (call->op.same_as(im2col_op)) { + auto new_annotations = call->annotations; + new_annotations.Set("barrier", MakeBarrierRef(barrier_buf_, barrier_id_)); + new_annotations.Set("emit_arrive", + IntImm(DataType::Int(32), emit_arrive_ ? 1 : 0)); + return Call(call->dtype, call->op, call->args, new_annotations, + call->span); + } + return call; + } + +private: + Buffer barrier_buf_; + PrimExpr barrier_id_; + bool emit_arrive_; +}; + +// --------------------------------------------------------------------------- +// ExpandPipelineBarriers — multi-version all barrier buffers for pipelining +// --------------------------------------------------------------------------- + +/// Collect all shared.barrier Buffer objects referenced in a statement. +class BarrierBufferCollector : public StmtExprVisitor { +public: + static std::vector + Collect(const Array &blocks, + const Map &buffer_data_to_buffer) { + BarrierBufferCollector c(buffer_data_to_buffer); + for (const auto &block : blocks) { + c(block->body); + } + return {c.barriers_.begin(), c.barriers_.end()}; + } + +private: + explicit BarrierBufferCollector(const Map &buf_map) + : buf_map_(buf_map) {} + + void VisitExpr_(const BufferLoadNode *op) final { + if (op->buffer.scope() == "shared.barrier" || + op->buffer.scope() == "shared.cluster_barrier") { + if (!seen_.count(op->buffer.get())) { + seen_.insert(op->buffer.get()); + barriers_.push_back(op->buffer); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + if (op->buffer.scope() == "shared.barrier" || + op->buffer.scope() == "shared.cluster_barrier") { + if (!seen_.count(op->buffer.get())) { + seen_.insert(op->buffer.get()); + barriers_.push_back(op->buffer); + } + } + StmtExprVisitor::VisitStmt_(op); + } + + // Also check barrier refs inside Call annotations (e.g., tma_copy barrier). + void VisitExpr_(const CallNode *op) final { + for (const auto &[key, val] : op->annotations) { + if (auto load = val.as()) { + if (load->buffer.scope() == "shared.barrier" || + load->buffer.scope() == "shared.cluster_barrier") { + if (!seen_.count(load->buffer.get())) { + seen_.insert(load->buffer.get()); + barriers_.push_back(load->buffer); + } + } + } + } + StmtExprVisitor::VisitExpr_(op); + } + + const Map &buf_map_; + std::unordered_set seen_; + std::vector barriers_; +}; + +/// Rewrite barrier references: expand indices and rewrite parity. +class BarrierIndexRewriter : public StmtExprMutator { +public: + BarrierIndexRewriter( + const std::unordered_map &old_to_new, + const std::unordered_map &old_shapes, + PrimExpr stage_expr, PrimExpr parity_cycle, Var loop_var, + PrimExpr loop_min) + : old_to_new_(old_to_new), old_shapes_(old_shapes), + stage_expr_(std::move(stage_expr)), + parity_cycle_(std::move(parity_cycle)), loop_var_(std::move(loop_var)), + loop_min_(std::move(loop_min)) {} + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = old_to_new_.find(load->buffer.get()); + if (it != old_to_new_.end()) { + auto *n = load.CopyOnWrite(); + PrimExpr old_size = old_shapes_.at(load->buffer.get()); + n->buffer = it->second; + n->indices.Set(0, stage_expr_ * old_size + n->indices[0]); + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = old_to_new_.find(store->buffer.get()); + if (it != old_to_new_.end()) { + auto *n = store.CopyOnWrite(); + PrimExpr old_size = old_shapes_.at(store->buffer.get()); + n->buffer = it->second; + n->indices.Set(0, stage_expr_ * old_size + n->indices[0]); + } + return store; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + + // Rewrite barrier refs inside annotations (e.g., tma_copy "barrier"). + bool anno_changed = false; + Map new_annos = call->annotations; + for (const auto &[key, val] : call->annotations) { + if (auto load = val.as()) { + auto it = old_to_new_.find(load->buffer.get()); + if (it != old_to_new_.end()) { + PrimExpr old_size = old_shapes_.at(load->buffer.get()); + auto new_load = BufferLoad( + it->second, {stage_expr_ * old_size + load->indices[0]}); + new_annos.Set(key, new_load); + anno_changed = true; + } + } + } + if (anno_changed) { + call = Call(call->dtype, call->op, call->args, new_annos, call->span); + } + + // Rewrite mbarrier_wait_parity parity argument. + if (call->op.same_as(mbarrier_wait_parity()) && call->args.size() >= 2) { + if (auto load = call->args[0].as()) { + // Check if the barrier ref (possibly already rewritten above) + // targets one of our expanded barriers. + const BufferNode *target = load->buffer.get(); + bool is_expanded = false; + for (const auto &[old_buf, new_buf] : old_to_new_) { + if (new_buf.get() == target) { + is_expanded = true; + break; + } + } + if (is_expanded) { + // Compute initial-phase offset from the user's original parity. + arith::Analyzer analyzer; + PrimExpr user_parity = call->args[1]; + PrimExpr user_parity_at_min = analyzer.Simplify( + tir::Substitute(user_parity, {{loop_var_, loop_min_}})); + // New parity = (iteration_block + offset) % 2 + PrimExpr offset = IntImm(DataType::Int(32), 0); + if (const int64_t *imm = as_const_int(user_parity_at_min)) { + offset = IntImm(DataType::Int(32), *imm % 2); + } + PrimExpr new_parity = FloorMod(parity_cycle_ + offset, 2); + Array new_args = call->args; + new_args.Set(1, new_parity); + return Call(call->dtype, call->op, new_args, call->annotations, + call->span); + } + } + } + return call; + } + +private: + const std::unordered_map &old_to_new_; + const std::unordered_map &old_shapes_; + PrimExpr stage_expr_; + PrimExpr parity_cycle_; + Var loop_var_; + PrimExpr loop_min_; +}; + +/// Expand all shared.barrier buffers in the pipeline body from [N] to +/// [N * num_stages], rewrite barrier indices to include stage offset, and +/// rewrite mbarrier_wait_parity parity expressions. +/// +/// This is the unified barrier multi-versioning path that replaces the old +/// late barrier-only fixup in OptimizeForTarget. +/// Returns a map of old→new barrier buffers for outer block alloc_buffers +/// update. +Map ExpandPipelineBarriers( + Array &original_order, PipelineInfo &pipeline_info, + Map &buffer_data_to_buffer, + std::unordered_set + &allocated_buffers, + Array &block_local_allocs, Array &pipeline_allocs, + Var loop_var, PrimExpr loop_min, int num_stages) { + if (num_stages <= 1) + return {}; + + // Only expand barriers that have explicit ptx_arrive_barrier calls in the + // loop body. This distinguishes pipeline synchronization barriers (where + // arrive/wait are user-managed and need per-stage slots) from barriers + // whose arrival is managed internally by tile-ops (e.g., tcgen05 MMA + // arrive barriers) — those should NOT be pipeline-expanded. + // ISP-created pipeline_mbar is handled specially: it's always in + // block_local_allocs and was just created, so include it too. + std::unordered_set local_barrier_set; + for (const Buffer &buf : block_local_allocs) { + if (buf.scope() == "shared.barrier" || + buf.scope() == "shared.cluster_barrier") + local_barrier_set.insert(buf.get()); + } + + // Find barriers that have explicit ptx_arrive_barrier calls. + class ArriveBarrierDetector : public StmtExprVisitor { + public: + std::unordered_set arrived_; + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::ptx_arrive_barrier()) && !op->args.empty()) { + if (auto load = op->args[0].as()) { + arrived_.insert(load->buffer.get()); + } + } + StmtExprVisitor::VisitExpr_(op); + } + }; + ArriveBarrierDetector arrive_det; + for (const auto &block : original_order) { + arrive_det(block->body); + } + + std::vector all_referenced = + BarrierBufferCollector::Collect(original_order, buffer_data_to_buffer); + std::vector barriers; + for (const Buffer &buf : all_referenced) { + // Include if: (a) it's an ISP-created local barrier, OR + // (b) it has explicit ptx_arrive_barrier calls. + if (local_barrier_set.count(buf.get()) || + arrive_det.arrived_.count(buf.get())) { + barriers.push_back(buf); + } + } + if (barriers.empty()) + return {}; + + PrimExpr ns = IntImm(DataType::Int(32), num_stages); + PrimExpr stage_expr = FloorMod(loop_var - loop_min, ns); + PrimExpr parity_cycle = FloorMod(FloorDiv(loop_var - loop_min, ns), 2); + + auto replace_in_array = [](Array &arr, const Buffer &old_buf, + const Buffer &new_buf) { + for (size_t i = 0; i < arr.size(); ++i) { + if (arr[i].same_as(old_buf)) { + arr.Set(i, new_buf); + } + } + }; + + // Create expanded buffer for each barrier. + std::unordered_map old_to_new; + std::unordered_map old_shapes; + for (const Buffer &buf : barriers) { + old_shapes[buf.get()] = buf->shape[0]; + ObjectPtr new_node = + tvm::ffi::make_object(*(buf.get())); + new_node->shape = {PrimExpr(num_stages) * buf->shape[0]}; + Buffer new_buf(new_node); + old_to_new[buf.get()] = new_buf; + + // Update all maps and alloc arrays. + buffer_data_to_buffer.Set(buf->data, new_buf); + allocated_buffers.erase(buf); + allocated_buffers.insert(new_buf); + replace_in_array(block_local_allocs, buf, new_buf); + replace_in_array(pipeline_allocs, buf, new_buf); + } + + // Rewrite all blocks. + BarrierIndexRewriter rewriter(old_to_new, old_shapes, stage_expr, + parity_cycle, loop_var, loop_min); + for (size_t i = 0; i < original_order.size(); ++i) { + Block old_block = original_order[i]; + Stmt new_body = rewriter(old_block->body); + if (!new_body.same_as(old_block->body)) { + // Also rewrite alloc_buffers in the block (barriers may be allocated + // here). + Array new_allocs; + for (const Buffer &ab : old_block->alloc_buffers) { + auto it = old_to_new.find(ab.get()); + new_allocs.push_back(it != old_to_new.end() ? it->second : ab); + } + Block new_block(old_block->iter_vars, old_block->reads, old_block->writes, + old_block->name_hint, new_body, old_block->init, + new_allocs, old_block->match_buffers, + old_block->annotations); + PipelineAnnotation anno = pipeline_info.at(old_block); + pipeline_info.erase(old_block); + pipeline_info.emplace(new_block, anno); + original_order.Set(i, new_block); + } + } + + // Return the old→new mapping for outer block alloc_buffers update. + Map result; + for (const auto &[old_ptr, new_buf] : old_to_new) { + for (const Buffer &old_buf : barriers) { + if (old_buf.get() == old_ptr) { + result.Set(old_buf, new_buf); + break; + } + } + } + return result; +} + +/*! + * \brief Rewrite TMA-eligible copy blocks in the pipeline body for + * pipeline-level barrier management. + * + * For each TMA copy: convert tl.tileop.copy → tl.tileop.tma_copy with a + * per-stage barrier slot and emit_arrive=1 so LowerTileOp emits arrive inside + * the thread-0 guard. + * + * For the first consumer stage block: prepend mbarrier_wait_parity with + * stage-indexed barrier reference and parity expression. + * + * \param original_order In/out: blocks in original pipeline order. + * \param pipeline_info In/out: block → PipelineAnnotation mapping. + * \param tma_copies Per-statement TMA flag array from PipelinePlanning. + * \param buffer_data_to_buffer In/out: buffer var → Buffer mapping. + * \param allocated_buffers In/out: set of allocated buffers. + * \param block_local_allocs In/out: buffers allocated in the pipeline + * block. + * \return The newly created barrier buffer (undefined if no TMA copies). + */ +Buffer RewritePipelineTmaBarriers( + Array &original_order, PipelineInfo &pipeline_info, + const Array &tma_copies, Map &buffer_data_to_buffer, + std::unordered_set + &allocated_buffers, + Array &block_local_allocs, Var loop_var, PrimExpr loop_min, + int num_stages) { + // Count TMA copies + int num_tma = 0; + for (const auto &tc : tma_copies) { + if (tc->value != 0) + num_tma++; + } + if (num_tma == 0) + return Buffer(); + + // Create pipeline barrier buffer with a single slot. The generic + // ExpandPipelineBarriers pass (called later) will expand it to + // num_stages slots along with all other barrier buffers. + Buffer barrier_buf = CreateMBarrierBuffer("pipeline_mbar", 1); + buffer_data_to_buffer.Set(barrier_buf->data, barrier_buf); + allocated_buffers.insert(barrier_buf); + block_local_allocs.push_back(barrier_buf); + + // Find the index of the last TMA copy for arrive emission. + int last_tma_idx = -1; + for (size_t i = 0; i < original_order.size(); i++) { + if (static_cast(tma_copies[i]->value) != 0) + last_tma_idx = static_cast(i); + } + + // Phase 1: Rewrite TMA copy blocks — all share barrier slot 0. + // ExpandPipelineBarriers (called later) will rewrite indices to be + // stage-dependent. Only the last TMA copy emits arrive. + for (size_t i = 0; i < original_order.size(); i++) { + if (static_cast(tma_copies[i]->value) == 0) + continue; + + bool is_last = (static_cast(i) == last_tma_idx); + Block old_block = original_order[i]; + CopyToTmaCopyRewriter rewriter(barrier_buf, + /*barrier_id=*/IntImm(DataType::Int(32), 0), + /*emit_arrive=*/is_last); + Stmt new_body = rewriter(old_block->body); + + Block new_block(old_block->iter_vars, old_block->reads, old_block->writes, + old_block->name_hint, new_body, old_block->init, + old_block->alloc_buffers, old_block->match_buffers, + old_block->annotations); + + PipelineAnnotation anno = pipeline_info.at(old_block); + pipeline_info.erase(old_block); + pipeline_info.emplace(new_block, anno); + original_order.Set(i, new_block); + } + + // Phase 2: Insert waits in consumer blocks (blocks that depend on TMA data). + // For simplicity, we insert waits before the first block whose stage > 0. + // This covers the common case where stage 0 = producers, stage 1 = consumer. + bool waits_inserted = false; + for (size_t i = 0; i < original_order.size(); i++) { + if (waits_inserted) + break; + Block old_block = original_order[i]; + int stage = pipeline_info.at(old_block).stage; + if (stage == 0) + continue; // still in producer stage + + // Wait on barrier slot 0 with single-slot parity. + // ExpandPipelineBarriers will rewrite index and parity for versioning. + Array wait_stmts; + { + PrimExpr barrier_ref = + MakeBarrierRef(barrier_buf, IntImm(DataType::Int(32), 0)); + PrimExpr ns = IntImm(DataType::Int(32), num_stages); + PrimExpr parity = FloorMod(FloorDiv(loop_var - loop_min, ns), 2); + wait_stmts.push_back(Evaluate(Call( + DataType::Handle(), mbarrier_wait_parity(), {barrier_ref, parity}))); + } + wait_stmts.push_back(old_block->body); + Stmt new_body = SeqStmt(wait_stmts); + + Block new_block(old_block->iter_vars, old_block->reads, old_block->writes, + old_block->name_hint, new_body, old_block->init, + old_block->alloc_buffers, old_block->match_buffers, + old_block->annotations); + + PipelineAnnotation anno = pipeline_info.at(old_block); + pipeline_info.erase(old_block); + pipeline_info.emplace(new_block, anno); + original_order.Set(i, new_block); + waits_inserted = true; + } + + return barrier_buf; +} + class PipelineInjector : private StmtExprMutator { public: static Stmt Inject(const PrimFunc &func) { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - PipelineInjector injector(global_symbol); + auto target = func->GetAttr(tvm::attr::kTarget); + PipelineInjector injector(global_symbol, target); for (const auto &kv : func->buffer_map) { const Buffer &buffer = kv.second; injector.buffer_data_to_buffer_.Set(buffer->data, buffer); @@ -741,8 +2777,9 @@ class PipelineInjector : private StmtExprMutator { } private: - explicit PipelineInjector(Optional global_symbol) - : global_symbol_(std::move(global_symbol)) {} + explicit PipelineInjector(Optional global_symbol, + Optional target) + : global_symbol_(std::move(global_symbol)), target_(std::move(target)) {} /*! * \brief Check the pipeline satisfies the following conditions: @@ -791,6 +2828,36 @@ class PipelineInjector : private StmtExprMutator { } } + bool HasOverlappableStages(const PipelineInfo &pipeline_info) const { + std::optional first_stage; + for (const auto &pair : pipeline_info) { + int stage = pair.second.stage; + if (!first_stage.has_value()) { + first_stage = stage; + } else if (stage != first_stage.value()) { + return true; + } + } + return false; + } + + Map + StripPipelineAnnotations(const Map &annotations) const { + Map preserved_annotations; + for (const auto &kv : annotations) { + const String &key = kv.first; + if (key != tir::attr::software_pipeline_stage && + key != tir::attr::software_pipeline_order && + key != tir::attr::software_pipeline_async_stages && + key != kPipelineAsyncProducers && + key != kPipelineAsyncProducerGroups && key != kPipelineTmaCopies && + key != "num_stages" && key != "tl_pipelined_num_stages") { + preserved_annotations.Set(key, kv.second); + } + } + return preserved_annotations; + } + Stmt VisitStmt_(const ForNode *op) final { // Step 1: Recursively rewrite the children first. For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); @@ -966,19 +3033,151 @@ class PipelineInjector : private StmtExprMutator { << ", but pipeline annotation is " << pipeline_orders << " with different size"; + std::unordered_set pipeline_async_stages; + if (auto async_annot = + op->annotations.Get(tir::attr::software_pipeline_async_stages)) { + for (const Integer &stage : + Downcast>(async_annot.value())) { + pipeline_async_stages.insert(static_cast(stage->value)); + } + } + Optional> pipeline_async_producers; + if (auto async_producers_anno = + op->annotations.Get(kPipelineAsyncProducers)) { + auto async_flags = Downcast>(async_producers_anno.value()); + CHECK_EQ(async_flags.size(), original_order.size()) + << "PrimFunc " << global_symbol_ << " has original order " + << original_order.Map( + [](const auto &block) { return block->name_hint; }) + << ", but async producer annotation is " << async_flags + << " with different size"; + pipeline_async_producers = async_flags; + } + Optional> pipeline_async_producer_groups; + if (auto async_groups_anno = + op->annotations.Get(kPipelineAsyncProducerGroups)) { + auto async_group_ids = + Downcast>(async_groups_anno.value()); + CHECK_EQ(async_group_ids.size(), original_order.size()) + << "PrimFunc " << global_symbol_ << " has original order " + << original_order.Map( + [](const auto &block) { return block->name_hint; }) + << ", but async producer group annotation is " << async_group_ids + << " with different size"; + pipeline_async_producer_groups = async_group_ids; + } + for (size_t i = 0; i < pipeline_stages.size(); i++) { int stage = static_cast(pipeline_stages[i]->value); + bool is_async_candidate = + pipeline_async_producers + ? (pipeline_async_producers.value()[i]->value != 0) + : (pipeline_async_stages.count(stage) > 0); + // Stages that already spell out async behavior themselves keep that + // ownership. The pipeline pass only injects async producer semantics for + // "plain" producer stages that do not already contain cp.async / async + // queue operations. + bool is_async = is_async_candidate && + !ContainsExplicitAsyncIntrinsics(original_order[i]->body); PipelineAnnotation stage_order{ - stage, /*order=*/static_cast(pipeline_orders[i]->value)}; + stage, + /*order=*/static_cast(pipeline_orders[i]->value), + /*async=*/is_async, + /*async_group_id=*/ + pipeline_async_producer_groups + ? static_cast( + pipeline_async_producer_groups.value()[i]->value) + : -1}; pipeline_info.emplace(original_order[i], stage_order); } ValidatePipelineBody(pipeline_info, original_order); + if (!HasOverlappableStages(pipeline_info)) { + if (const auto *realize = op->body.as()) { + const auto &block = realize->block; + for (const auto &buffer : block->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + allocated_buffers_.erase(buffer); + } + } + return For(for_node->loop_var, for_node->min, for_node->extent, + for_node->kind, for_node->body, for_node->thread_binding, + StripPipelineAnnotations(for_node->annotations), + for_node->step, for_node->span); + } + + // Step 3.5: Pipeline-level TMA barrier management. + // When TMA copies are present (without warp specialization), rewrite + // them to use tl.tileop.tma_copy with shared pipeline barriers and insert + // mbarrier_wait_parity before the first consumer stage. + // Creates pipeline_mbar[pipeline_depth] at final size so LowerTileOp + // uses the provided barrier instead of allocating separate per-copy ones. + Buffer pipeline_barrier_buf; + int num_pipeline_tma_copies = 0; + { + int max_stage = 0; + for (const auto &pair : pipeline_info) { + max_stage = std::max(max_stage, pair.second.stage); + } + // Use the actual pipeline depth (number of buffer copies) for barrier + // sizing, not the SW pipeline stage count (max_stage + 1). + // Even for pipeline_depth=1 we create a shared barrier so that + // LowerTileOp uses it instead of allocating separate per-copy barriers. + Optional pipelined_num_stages = GetPipelineNumStages(op); + int pipeline_depth = + pipelined_num_stages.defined() + ? static_cast(pipelined_num_stages.value()->value) + : max_stage + 1; + // Clamp to at least 1 so we always allocate at least one barrier slot. + pipeline_depth = std::max(pipeline_depth, 1); + if (max_stage > 0) { + if (auto tma_copies_anno = op->annotations.Get(kPipelineTmaCopies)) { + auto tma_copies = Downcast>(tma_copies_anno.value()); + if (tma_copies.size() == original_order.size()) { + for (const auto &tc : tma_copies) { + if (tc->value != 0) + num_pipeline_tma_copies++; + } + if (num_pipeline_tma_copies > 0) { + pipeline_barrier_buf = RewritePipelineTmaBarriers( + original_order, pipeline_info, tma_copies, + buffer_data_to_buffer_, allocated_buffers_, + block_local_allocs, op->loop_var, op->min, pipeline_depth); + } + } + } + } + } + // Step 4: Rewrite the pipeline body. // local_allocs contains buffers allocated in the pipeline block itself. // pipeline_allocs contains all buffers that need multi-versioning, // including buffers from outer blocks. + // Step 4.5: Expand all barrier buffers for pipelining. + // This handles both ISP-created pipeline_mbar AND user-written + // T.alloc_barrier, so that no late standalone barrier-only fixup is needed. + // Must run BEFORE local_allocs is copied from block_local_allocs. + { + Optional pipelined_ns = GetPipelineNumStages(op); + int barrier_depth = 1; + if (pipelined_ns.defined()) { + barrier_depth = static_cast(pipelined_ns.value()->value); + } else if (op->annotations.count("num_stages")) { + barrier_depth = static_cast( + Downcast(op->annotations.Get("num_stages").value()) + ->value); + } + Map barrier_remap = ExpandPipelineBarriers( + original_order, pipeline_info, buffer_data_to_buffer_, + allocated_buffers_, block_local_allocs, pipeline_allocs, op->loop_var, + op->min, barrier_depth); + // Register expanded barriers for outer block alloc_buffers update. + for (const auto &[old_buf, new_buf] : barrier_remap) { + pending_buffer_remap_.Set(old_buf, new_buf); + } + } + Array local_allocs = block_local_allocs; // Add nested block allocs to local_allocs for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { @@ -995,9 +3194,90 @@ class PipelineInjector : private StmtExprMutator { PipelineRewriter rewriter(buffer_data_to_buffer_, pipeline_allocs, local_allocs, tvm::ffi::GetRef(op), - pipeline_info, loop_var_let_wrappers, + pipeline_info, target_, loop_var_let_wrappers, loop_var_if_wrappers); Stmt pipeline = rewriter.BuildPipeline(); + subtree_modified_ = true; + + auto unwrap_outer_attrs = [](Stmt stmt) { + std::vector attrs; + while (const auto *attr = stmt.as()) { + attrs.push_back(Downcast(stmt)); + stmt = attr->body; + } + return std::make_pair(attrs, stmt); + }; + auto rewrap_outer_attrs = [](Stmt stmt, + const std::vector &attrs) { + for (auto it = attrs.rbegin(); it != attrs.rend(); ++it) { + stmt = AttrStmt((*it)->node, (*it)->attr_key, (*it)->value, stmt, + (*it)->span); + } + return stmt; + }; + + // Update barrier_init annotations for expanded barrier buffers. + // For pipeline_mbar (ISP-created): add new entry with arrive_count=1 per + // slot. For user barriers (T.alloc_barrier): replicate existing arrive + // counts across the expanded slots. + { + auto [outer_attrs, inner_stmt] = unwrap_outer_attrs(pipeline); + BlockRealize br = Downcast(inner_stmt); + Block block = br->block; + BlockNode *bn = block.CopyOnWrite(); + + Map> barrier_init_map; + if (bn->annotations.count("barrier_init")) { + barrier_init_map = Downcast>>( + bn->annotations.Get("barrier_init").value()); + } + bool changed = false; + + // Handle ISP-created pipeline barrier (needs new entry). + if (pipeline_barrier_buf.defined()) { + int num_slots = Downcast(pipeline_barrier_buf->shape[0])->value; + // After ExpandPipelineBarriers, pipeline_mbar has been expanded. + // Look up the expanded buffer via buffer_data_to_buffer_. + Buffer expanded_buf = + buffer_data_to_buffer_[pipeline_barrier_buf->data]; + int expanded_slots = Downcast(expanded_buf->shape[0])->value; + Array counts; + for (int s = 0; s < expanded_slots; ++s) { + counts.push_back(IntImm(DataType::Int(32), 1)); + } + barrier_init_map.Set(expanded_buf->data, counts); + changed = true; + } + + // Replicate existing barrier_init entries for expanded barriers. + Map> updated_init; + for (const auto &[var, counts] : barrier_init_map) { + Buffer buf = buffer_data_to_buffer_[var]; + int buf_size = Downcast(buf->shape[0])->value; + int orig_size = static_cast(counts.size()); + if (buf_size > orig_size && orig_size > 0 && + buf_size % orig_size == 0) { + // Replicate pattern to match expanded size. + Array new_counts; + for (int v = 0; v < buf_size; v += orig_size) { + for (const auto &c : counts) { + new_counts.push_back(c); + } + } + updated_init.Set(var, new_counts); + changed = true; + } else { + updated_init.Set(var, counts); + } + } + + if (changed) { + bn->annotations.Set("barrier_init", updated_init); + pipeline = rewrap_outer_attrs( + BlockRealize(br->iter_values, br->predicate, block, br->span), + outer_attrs); + } + } // Store the buffer remapping for updating outer block alloc_buffers for (const auto &kv : rewriter.GetBufferRemap()) { @@ -1011,20 +3291,25 @@ class PipelineInjector : private StmtExprMutator { }; if (!rewrap_fns.empty()) { if (pipeline_body_from_block) { - BlockRealize pipeline_realize = Downcast(pipeline); + auto [outer_attrs, inner_stmt] = unwrap_outer_attrs(pipeline); + BlockRealize pipeline_realize = Downcast(inner_stmt); Block pipeline_block = pipeline_realize->block; { BlockNode *block_node = pipeline_block.CopyOnWrite(); block_node->body = apply_wrappers(block_node->body); } - pipeline = BlockRealize(pipeline_realize->iter_values, - pipeline_realize->predicate, pipeline_block, - pipeline_realize->span); + pipeline = rewrap_outer_attrs( + BlockRealize(pipeline_realize->iter_values, + pipeline_realize->predicate, pipeline_block, + pipeline_realize->span), + outer_attrs); } else { pipeline = apply_wrappers(pipeline); } } + pipeline = AsyncCommitWaitAttrLowerer::Lower(pipeline); + if (const auto *realize = op->body.as()) { const auto &block = realize->block; for (const auto &buffer : block->alloc_buffers) { @@ -1041,28 +3326,168 @@ class PipelineInjector : private StmtExprMutator { allocated_buffers_.insert(buffer); } + bool outer_flag = subtree_modified_; + subtree_modified_ = false; Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + bool children_modified = subtree_modified_; + // Propagate to parent: if this subtree was modified, parent should know. + subtree_modified_ = outer_flag || children_modified; // Update alloc_buffers with any pending buffer remaps from pipeline // rewriting. This handles buffers allocated in this block but // multi-versioned during pipeline rewriting of inner loops. + bool allocs_changed = false; + bool layout_changed = false; Array new_alloc_buffers; + std::vector> remapped_allocs; for (const auto &buffer : block->alloc_buffers) { if (auto remapped = pending_buffer_remap_.Get(buffer)) { new_alloc_buffers.push_back(remapped.value()); - // Remove from pending after applying + remapped_allocs.emplace_back(buffer, remapped.value()); pending_buffer_remap_.erase(buffer); + allocs_changed = true; } else { new_alloc_buffers.push_back(buffer); } } - Array> access = - GetBlockReadWriteRegion(block, buffer_data_to_buffer_); - BlockNode *n = block.CopyOnWrite(); - n->reads = access[0]; - n->writes = access[1]; - n->alloc_buffers = std::move(new_alloc_buffers); + if (!remapped_allocs.empty()) { + auto ann = block->annotations; + if (UpdateExpandedLayoutMapForRemappedAllocs(remapped_allocs, &ann)) { + block.CopyOnWrite()->annotations = std::move(ann); + layout_changed = true; + } + } + + // Replicate barrier_init counts for any expanded barrier buffers. + if (allocs_changed && block->annotations.count("barrier_init")) { + Map> init_map = Downcast>>( + block->annotations.Get("barrier_init").value()); + Map> new_init; + bool init_changed = false; + for (const auto &[var, counts] : init_map) { + // Find the buffer for this var — it may have been remapped. + Buffer buf; + for (const auto &ab : new_alloc_buffers) { + if (ab->data.same_as(var)) { + buf = ab; + break; + } + } + if (buf.defined()) { + int buf_size = Downcast(buf->shape[0])->value; + int orig_size = static_cast(counts.size()); + if (buf_size > orig_size && orig_size > 0 && + buf_size % orig_size == 0) { + Array new_counts; + for (int v = 0; v < buf_size; v += orig_size) { + for (const auto &c : counts) + new_counts.push_back(c); + } + new_init.Set(var, new_counts); + init_changed = true; + continue; + } + } + new_init.Set(var, counts); + } + if (init_changed) { + BlockNode *bn = block.CopyOnWrite(); + bn->annotations.Set("barrier_init", new_init); + bn->alloc_buffers = new_alloc_buffers; + allocs_changed = false; // already applied + } + } + + bool modified = children_modified || allocs_changed || layout_changed; + if (modified) { + // Recalculate reads/writes only when the block was actually + // modified by pipeline rewriting. Unconditional recalculation + // can embed references to block-local buffers (e.g. local.var) + // into the block's own read/write annotations, which misleads + // downstream LCA analysis and causes those buffers to be + // promoted to kernel parameters. + // + // After recalculation: + // 1. Drop BufferRegions whose buffer is allocated in this block. + // 2. Widen to full-region any BufferRegion whose index + // expressions reference a data var of any buffer allocated + // in this block or any nested block. This prevents + // downstream LCA analysis from seeing those vars at the + // outer scope and promoting them to kernel parameters. + std::unordered_set local_bufs; + std::unordered_set local_data_vars; + for (const auto &buf : block->alloc_buffers) { + local_bufs.insert(buf.get()); + local_data_vars.insert(buf->data.get()); + } + // Also collect data vars from all nested blocks. + PostOrderVisit(block->body, [&](const ObjectRef &obj) { + if (auto *inner = obj.as()) { + for (const auto &buf : inner->alloc_buffers) { + local_data_vars.insert(buf->data.get()); + } + } + }); + auto region_uses_local_var = [&](const BufferRegion &br) -> bool { + for (const auto &range : br->region) { + bool found = false; + PostOrderVisit(range->min, [&](const ObjectRef &obj) { + if (found) + return; + if (auto *load = obj.as()) { + if (local_data_vars.count(load->buffer->data.get())) { + found = true; + } + } + if (auto *var = obj.as()) { + if (local_data_vars.count(var)) { + found = true; + } + } + }); + if (found) + return true; + PostOrderVisit(range->extent, [&](const ObjectRef &obj) { + if (found) + return; + if (auto *load = obj.as()) { + if (local_data_vars.count(load->buffer->data.get())) { + found = true; + } + } + if (auto *var = obj.as()) { + if (local_data_vars.count(var)) { + found = true; + } + } + }); + if (found) + return true; + } + return false; + }; + Array> access = + GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto sanitize = [&](const Array ®ions) { + Array out; + for (const auto &br : regions) { + if (local_bufs.count(br->buffer.get())) { + continue; // drop block-local buffer + } + if (region_uses_local_var(br)) { + out.push_back(BufferRegion::FullRegion(br->buffer)); + } else { + out.push_back(br); + } + } + return out; + }; + BlockNode *n = block.CopyOnWrite(); + n->reads = sanitize(access[0]); + n->writes = sanitize(access[1]); + n->alloc_buffers = std::move(new_alloc_buffers); + } for (const auto &buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(buffer->data); @@ -1093,11 +3518,16 @@ class PipelineInjector : private StmtExprMutator { Map buffer_data_to_buffer_; std::unordered_set allocated_buffers_; Map pending_buffer_remap_; + Optional target_; // Buffers from outer blocks that have been used in a pipeline loop. // Used to detect if the same buffer is used in multiple pipeline loops. std::unordered_set buffers_used_in_pipeline_; Optional global_symbol_; + // Track whether any pipeline was actually injected in the current + // subtree. Used to avoid unnecessary reads/writes recalculation + // on blocks whose descendants were not modified. + bool subtree_modified_ = false; }; } // namespace software_pipeline @@ -1111,7 +3541,6 @@ tir::transform::Pass InjectSoftwarePipeline() { auto *fptr = f.CopyOnWrite(); fptr->body = software_pipeline::PipelineInjector::Inject(f); fptr->body = ConvertSSA(std::move(fptr->body)); - fptr->body = StripTmaCopyWriteBufferAttr(std::move(fptr->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {}); diff --git a/src/transform/instruction_annotation.cc b/src/transform/instruction_annotation.cc new file mode 100644 index 0000000000..b2ab121fed --- /dev/null +++ b/src/transform/instruction_annotation.cc @@ -0,0 +1,239 @@ +/*! + * \file instruction_annotation.cc + * \brief Annotate tile operations with coarse-grained instruction kind. + * + * This pass runs **before** LayoutInference and LowerTileOp. It inspects + * every `tl.tileop.*` Call node and determines the instruction category that + * will eventually be selected during lowering. The result is stored as a + * string annotation (`tl_instruction_kind`) on the Call node so that later + * passes (e.g. warp specialization) can make structural decisions without + * needing the full lowered IR. + * + * For copy operations the classification is: + * - "tma" : will use TMA bulk load/store (descriptor or 1-D) + * - "cp_async" : will use cp.async + * - "sync" : synchronous copy (SIMT / LDSM / STSM / TMem / normal) + * + * For gemm operations the classification is: + * - "wgmma" : Hopper warp-group MMA + * - "tcgen5mma" : Blackwell TCGEN5 MMA + * - "mma" : Volta/Ampere tensor-core MMA + * - "mfma" : AMD CDNA matrix fused multiply-add + * - "scalar" : scalar fallback + * + * Because this pass runs before layout inference it intentionally uses only + * coarse checks (target arch, buffer scopes, shape alignment) that do not + * depend on the inferred memory layout. + */ + +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "../op/copy.h" +#include "../op/gemm.h" +#include "../op/operator.h" +#include "../op/utils.h" +#include "../target/utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +namespace { + +/// Annotation key written by this pass. +static constexpr const char *kInstructionKind = "tl_instruction_kind"; + +static bool IsAutoAsyncCopyEnabled(Target target, bool default_enabled = true) { + using namespace tvm::transform; + PassContext pass_ctx = PassContext::Current(); + return TargetHasAsyncCopy(target) && + pass_ctx->GetConfig(kEnableAsyncCopy, Bool(default_enabled)) + .value(); +} + +static bool CanUseAutoCPAsyncCopy(const CopyNode *copy, Target target, + arith::Analyzer *analyzer, + bool default_enabled = true) { + return copy != nullptr && !copy->GetIsTmaCopy() && !copy->GetIsAsyncCopy() && + IsAutoAsyncCopyEnabled(target, default_enabled) && + copy->CheckCPAsyncCopy(target, LayoutMap(), analyzer); +} + +// --------------------------------------------------------------------------- +// Classify copy ops +// --------------------------------------------------------------------------- + +/*! + * \brief Determine the coarse instruction kind for a CopyNode. + * + * The classification does **not** depend on layout_map (which is unavailable + * at this point). It mirrors the priority order in CopyNode::GetCopyInst but + * collapses BulkLoad/BulkLoad1D/BulkStore/BulkStore1D into "tma" and skips + * checks that require layout information. + */ +std::string ClassifyCopy(const CopyNode *copy, Target target, bool in_pipeline, + arith::Analyzer *analyzer) { + // Explicit T.tma_copy() — always TMA. + if (copy->GetIsTmaCopy()) { + // Verify target can do TMA at all. + if (copy->CheckBulkLoad(target, analyzer, /*check_last_dim=*/false) || + copy->CheckBulkStore(target, analyzer, /*check_last_dim=*/false)) { + return "tma"; + } + // User asked for TMA but target doesn't support it — leave unannotated + // so that LowerTileOp can produce a proper error later. + return "sync"; + } + + // Explicit T.async_copy() — always cp.async. + if (copy->GetIsAsyncCopy()) { + return "cp_async"; + } + + // Generic T.copy() stays synchronous here. Auto-TMA is only introduced by + // warp-specialized rewriting, which rewrites the op to explicit T.tma_copy. + + // LDSM / STSM / TMem — these are synchronous from the WS perspective. + if (copy->CheckLDSMCopy(target) || copy->CheckSTSMCopy(target) || + copy->CheckTMemLoad(target) || copy->CheckTMemStore(target)) { + return "sync"; + } + + // Inside a pipelined loop, eligible copies may be lowered to cp.async. + if (in_pipeline && CanUseAutoCPAsyncCopy(copy, target, analyzer, + /*default_enabled=*/false)) { + return "cp_async"; + } + + return "sync"; +} + +// --------------------------------------------------------------------------- +// Classify gemm ops +// --------------------------------------------------------------------------- + +std::string ClassifyGemm(const GemmNode *gemm, int block_size, Target target) { + GemmInst inst = gemm->getGemmInst(block_size, target); + switch (inst) { + case GemmInst::kWGMMA: + return "wgmma"; + case GemmInst::kTCGEN5MMA: + return "tcgen5mma"; + case GemmInst::kMMA: + return "mma"; + case GemmInst::kMFMA: + return "mfma"; + case GemmInst::kScalar: + return "scalar"; + default: + return "unknown"; + } +} + +// --------------------------------------------------------------------------- +// IR rewriter +// --------------------------------------------------------------------------- + +class InstructionAnnotator : public StmtExprMutator { +public: + static PrimFunc Annotate(PrimFunc f) { + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) + << "InstructionAnnotation: target attribute is required"; + + InstructionAnnotator annotator; + annotator.target_ = target.value(); + PrimFuncNode *fptr = f.CopyOnWrite(); + fptr->body = annotator.VisitStmt(f->body); + return f; + } + +private: + // Track threadIdx.x extent for gemm instruction selection. + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + if (auto *int_imm = op->value.as()) { + block_size_ = static_cast(int_imm->value); + } + } + } + return StmtExprMutator::VisitStmt_(op); + } + + // Track whether we are inside a pipelined loop. + Stmt VisitStmt_(const ForNode *op) final { + bool old_in_pipeline = in_pipeline_; + if (op->annotations.Get("num_stages")) { + in_pipeline_ = true; + } + Stmt result = StmtExprMutator::VisitStmt_(op); + in_pipeline_ = old_in_pipeline; + return result; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + + // Only process tile operators. + auto tile_op = ParseOperator(call); + if (!tile_op.defined()) + return call; + + // Skip if already annotated. + if (call->annotations.count(kInstructionKind)) + return call; + + std::string kind; + + if (auto *copy_node = tile_op.as()) { + kind = ClassifyCopy(copy_node, target_, in_pipeline_, &analyzer_); + } else if (auto *gemm_node = tile_op.as()) { + kind = ClassifyGemm(gemm_node, block_size_, target_); + } else { + // Other tile ops (reduce, fill, etc.) are synchronous. + kind = "sync"; + } + + // Create a new Call with the annotation added. + auto new_annotations = call->annotations; + new_annotations.Set(kInstructionKind, StringImm(kind)); + return Call(call->dtype, call->op, call->args, new_annotations, call->span); + } + + Target target_; + bool in_pipeline_{false}; + int block_size_{0}; + arith::Analyzer analyzer_; +}; + +} // namespace + +// --------------------------------------------------------------------------- +// Pass registration +// --------------------------------------------------------------------------- + +tvm::transform::Pass InstructionAnnotation() { + using namespace tir::transform; + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + return InstructionAnnotator::Annotate(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.InstructionAnnotation", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.InstructionAnnotation", + InstructionAnnotation); +} + +} // namespace tl +} // namespace tvm diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 6e9e8b5ba8..7935391d35 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -18,15 +18,16 @@ #include "../layout/layout.h" #include "../layout/utils.h" +#include "../op/builtin.h" #include "../op/copy.h" #include "../op/parallel.h" #include "../op/region.h" #include "../op/utils.h" #include "../target/utils.h" - #include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_visitor_with_analyzer.h" #include "common/loop_fusion_utils.h" +#include "common/pipeline_utils.h" #include "common/union_find.h" #include "layout_reducer.h" #include "parallel_loop_layout_validator.h" @@ -101,7 +102,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { auto thread_bounds = thread_bounds_vec_[cur_infer_id]; arith::Analyzer *cur_analyzer = analyzer_vec_[cur_infer_id].get(); auto buffer_oob = buffer_oob_vec_[cur_infer_id]; - bool in_pipeline = in_pipeline_vec_[cur_infer_id]; // Double-check that 'next' is valid ICHECK(next.defined()) << "infer_list_[" << cur_infer_id << "] is null inside run_infer_step."; @@ -129,7 +129,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { buffer_oob, {}, let_var_to_expr_, - in_pipeline}, + false}, level); // Process the returned updates @@ -312,10 +312,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size()) << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in " "length."; - ICHECK_EQ(in_pipeline_vec_.size(), infer_list_.size()) - << "Size mismatch: in_pipeline_vec_ and infer_list_ must match in " - "length."; - DLOG(INFO) << "[InferLayout] all participating operators:" << '\n'; for (int i = 0; i < infer_list_stmt_.size(); ++i) { DLOG(INFO) << " op " << i << ":" << infer_list_stmt_[i] << '\n'; @@ -401,7 +397,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } } } - Layout reshaped = shapes_equal ? rep_layout.value() @@ -539,17 +534,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } // Compute thread_var_ and thread_bounds_ thread_var_vec_.push_back(thread_var_); - if (analyzer_.const_int_bound.IsBound(thread_var_->var)) { - auto const_int_bound = analyzer_.const_int_bound(thread_var_); - auto min_value = const_int_bound->min_value; - auto max_value = const_int_bound->max_value; - auto extent = max_value - min_value + 1; - auto dtype = thread_var_->var.dtype(); - thread_bounds_vec_.push_back(Range::FromMinExtent( - IntImm(dtype, min_value), IntImm(dtype, extent))); - } else { - thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); - } + thread_bounds_vec_.push_back(CurrentThreadBounds()); analyzer_vec_.push_back(analyzer_.Clone()); // Compute buffer oob for each buffer in the op @@ -584,7 +569,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { // Add the tile operator to infer_list_ infer_list_stmt_.push_back(tvm::ffi::GetRef(op)); infer_list_.push_back(std::move(p)); - in_pipeline_vec_.push_back(pipelined_depth_ > 0); } } @@ -648,17 +632,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } void VisitStmt_(const ForNode *op) final { - bool enter_pipelined = false; - if (auto num_stages_anno = op->annotations.Get("num_stages")) { - const auto *imm = num_stages_anno->as(); - ICHECK(imm) << "For annotation num_stages must be IntImm, but got " - << num_stages_anno.value(); - enter_pipelined = imm->value > 0; - } - if (enter_pipelined) { - ++pipelined_depth_; - } - if (op->kind == ForKind::kParallel) { auto infer = ParallelOp(tvm::ffi::GetRef(op)); for (const auto &[buffer, _] : infer->GetIndiceMap()) { @@ -731,29 +704,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { }); infer_list_stmt_.push_back(tvm::ffi::GetRef(op)); infer_list_.push_back(std::move(infer)); - in_pipeline_vec_.push_back(pipelined_depth_ > 0); thread_var_vec_.push_back(thread_var_); - if (thread_var_.defined() && - analyzer_.const_int_bound.IsBound(thread_var_->var)) { - auto const_int_bound = analyzer_.const_int_bound(thread_var_); - auto dtype = thread_var_->var.dtype(); - auto extent = - const_int_bound->max_value - const_int_bound->min_value + 1; - thread_bounds_vec_.push_back(Range::FromMinExtent( - IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent))); - } else { - thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); - } + thread_bounds_vec_.push_back(CurrentThreadBounds()); analyzer_vec_.push_back(analyzer_.Clone()); buffer_oob_vec_.push_back(false); } else { IRVisitorWithAnalyzer::VisitStmt(op->body); } - - if (enter_pipelined) { - ICHECK_GT(pipelined_depth_, 0); - --pipelined_depth_; - } } void VisitStmt_(const BlockNode *op) final { @@ -803,6 +760,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } else { // Use the first buffer sharing this var as the base for dtype ratio int64_t base_bits = GetElementStorageBits(buffers[0]->dtype); + auto reshaped_layout = layout->Reshape(buffer->shape, &analyzer_, Integer(base_bits), Integer(GetElementStorageBits(buffer->dtype))); @@ -848,6 +806,10 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { }); } + Range CurrentThreadBounds() const { + return ComputeThreadBounds(thread_var_, analyzer_); + } + void VisitExpr_(const BufferLoadNode *op) final { // Collect buffer from BufferLoad if (op->buffer.defined() && op->buffer->data.defined()) { @@ -978,17 +940,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { // This is a floating access - record buffer with current thread_bounds if (floating_buffers_.find(buffer) != floating_buffers_.end()) return; // Already recorded - Range thread_bounds = Range::FromMinExtent(0, 1); - if (thread_var_.defined() && - analyzer_.const_int_bound.IsBound(thread_var_->var)) { - auto const_int_bound = analyzer_.const_int_bound(thread_var_); - auto dtype = thread_var_->var.dtype(); - auto extent = - const_int_bound->max_value - const_int_bound->min_value + 1; - thread_bounds = Range::FromMinExtent( - IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)); - } - floating_buffers_[buffer] = thread_bounds; + floating_buffers_[buffer] = CurrentThreadBounds(); + } + + Range CurrentThreadBounds() const { + return ComputeThreadBounds(thread_var_, analyzer_); } const std::unordered_set &nodes_in_tileops_; @@ -1017,10 +973,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { Map let_var_to_expr_; std::vector infer_list_stmt_; std::vector infer_list_; - // Whether the corresponding op was observed inside a pipelined loop - // (i.e., a surrounding For annotated with num_stages > 0). - std::vector in_pipeline_vec_; - int pipelined_depth_{0}; // Fragment buffers that have accesses outside of TileOps. // These "floating" buffers need fully replicated layouts since their // access patterns cannot be inferred from TileOp semantics. diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index 628079f4ea..b60cb099d1 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -104,12 +104,30 @@ struct SafeMemChecker : public StmtExprVisitor { continue; } + // Shared/local warning checks are best-effort only. Some swizzled + // indices generated by tile-op lowering contain floor-div/mod and + // bitwise arithmetic that the interval analyzer cannot robustly handle. + // Do not let warning-only analysis turn into a hard compilation error. + if (throw_warning && HasAnalyzerFragilePattern(index)) { + continue; + } + // We want to check if index < shape_dim can be proven. // If analyzer->CanProve(index < shape_dim) returns false, // it means we cannot prove the access is within bounds. PrimExpr upper_bound_cond = index < shape_dim; - if (!analyzer_->CanProve(upper_bound_cond, - arith::ProofStrength::kSymbolicBound)) { + bool can_prove_upper = false; + try { + can_prove_upper = analyzer_->CanProve( + upper_bound_cond, arith::ProofStrength::kSymbolicBound); + } catch (const Error &e) { + // Some layout-lowered sparse/global indices contain arithmetic that + // defeats interval reasoning. Safe-memory legalization should remain + // conservative in that case and emit an explicit runtime guard instead + // of hard-failing compilation. + can_prove_upper = false; + } + if (!can_prove_upper) { if (throw_warning) { LOG(WARNING) << "Index access may exceed buffer bounds: " << index << " >= " << shape_dim @@ -120,8 +138,14 @@ struct SafeMemChecker : public StmtExprVisitor { } // Check if index >= 0 can be proven. PrimExpr lower_bound_cond = index >= 0; - if (!analyzer_->CanProve(lower_bound_cond, - arith::ProofStrength::kSymbolicBound)) { + bool can_prove_lower = false; + try { + can_prove_lower = analyzer_->CanProve( + lower_bound_cond, arith::ProofStrength::kSymbolicBound); + } catch (const Error &e) { + can_prove_lower = false; + } + if (!can_prove_lower) { if (throw_warning) { LOG(WARNING) << "Index access may be negative: " << index << " < 0" << "; Buffer name: " << buffer->name; @@ -132,6 +156,28 @@ struct SafeMemChecker : public StmtExprVisitor { } } + static bool HasAnalyzerFragilePattern(const PrimExpr &expr) { + bool fragile = false; + PostOrderVisit(expr, [&](const ObjectRef &obj) { + if (obj->IsInstance() || obj->IsInstance() || + obj->IsInstance() || obj->IsInstance()) { + fragile = true; + return; + } + if (const auto *call = obj.as()) { + if (const auto *op_node = call->op.as()) { + String name = op_node->name; + if (name == "tir.bitwise_and" || name == "tir.bitwise_or" || + name == "tir.bitwise_xor" || name == "tir.shift_left" || + name == "tir.shift_right") { + fragile = true; + } + } + } + }); + return fragile; + } + Array GetConditions() { return _conditions; } private: diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index 2b7cbd1d30..1324456af3 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -36,8 +36,10 @@ class PTXAsyncCopyInjector : public StmtMutator { : enable_auto_async_copy_(enable_auto_async_copy), async_without_async_commit_wait_(async_without_async_commit_wait) {} + bool InjectedPTXAsyncCopy() const { return injected_ptx_async_copy_; } + Stmt Finalize(Stmt body) { - if (!pending_sync_copies_ || async_without_async_commit_wait_) { + if (!pending_sync_copies_ || UseExplicitAsyncSemantics()) { pending_sync_copies_ = false; uncommitted_sync_copies_ = false; return body; @@ -52,6 +54,17 @@ class PTXAsyncCopyInjector : public StmtMutator { return SeqStmt(seq); } + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::async_scope) { + ++explicit_async_scope_depth_; + Stmt body = this->VisitStmt(op->body); + --explicit_async_scope_depth_; + // `async_scope` is a lowering-only marker for cp.async semantics. + return body; + } + return StmtMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const ForNode *op) final { // Track nested vectorized loop extents so we can rewrite element-wise // copies (e.g. float16 stores) into `tir.ptx_cp_async` with element bytes, @@ -138,7 +151,7 @@ class PTXAsyncCopyInjector : public StmtMutator { } Stmt VisitStmt_(const SeqStmtNode *op) final { - if (async_without_async_commit_wait_) { + if (UseExplicitAsyncSemantics()) { return StmtMutator::VisitStmt_(op); } @@ -212,7 +225,7 @@ class PTXAsyncCopyInjector : public StmtMutator { } Stmt VisitStmt_(const IfThenElseNode *op) final { - if (async_without_async_commit_wait_) { + if (UseExplicitAsyncSemantics()) { return StmtMutator::VisitStmt_(op); } @@ -266,7 +279,8 @@ class PTXAsyncCopyInjector : public StmtMutator { TryInjectPTX(load, store, predicate.defined(), predicate.defined() ? predicate.value() : PrimExpr()); if (injected.defined()) { - if (!async_without_async_commit_wait_) { + injected_ptx_async_copy_ = true; + if (!UseExplicitAsyncSemantics()) { pending_sync_copies_ = true; uncommitted_sync_copies_ = true; } @@ -278,6 +292,10 @@ class PTXAsyncCopyInjector : public StmtMutator { } private: + bool UseExplicitAsyncSemantics() const { + return async_without_async_commit_wait_ || explicit_async_scope_depth_ > 0; + } + // A copy candidate represented after flattening source/destination indexing. struct CopyIndexInfo { PrimExpr src_index; @@ -688,20 +706,24 @@ class PTXAsyncCopyInjector : public StmtMutator { bool enable_auto_async_copy_{true}; bool async_without_async_commit_wait_{false}; + int explicit_async_scope_depth_{0}; int current_vectorized_lanes_{1}; std::vector active_vectorized_loops_; arith::Analyzer analyzer_; + bool injected_ptx_async_copy_{false}; bool pending_sync_copies_{false}; bool uncommitted_sync_copies_{false}; }; using namespace tir::transform; -Stmt InjectPTXAsyncCopy(const Stmt &body, bool enable_auto_async_copy, - bool async_without_async_commit_wait) { +PTXAsyncCopyInjectResult +InjectPTXAsyncCopy(const Stmt &body, bool enable_auto_async_copy, + bool async_without_async_commit_wait) { PTXAsyncCopyInjector injector(enable_auto_async_copy, async_without_async_commit_wait); - return injector.Finalize(injector(body)); + Stmt injected = injector(body); + return {injector.Finalize(injected), injector.InjectedPTXAsyncCopy()}; } tvm::transform::Pass LowerPTXAsyncCopy() { @@ -724,9 +746,10 @@ tvm::transform::Pass LowerPTXAsyncCopy() { ctx->GetConfig(kEnableAsyncCopy, Bool(true)).value(); auto *n = f.CopyOnWrite(); - PTXAsyncCopyInjector injector(enable_auto_async_copy, - /*async_without_async_commit_wait=*/false); - n->body = injector.Finalize(injector(n->body)); + auto inject_result = + InjectPTXAsyncCopy(n->body, enable_auto_async_copy, + /*async_without_async_commit_wait=*/false); + n->body = inject_result.stmt; return f; }; return CreatePrimFuncPass(pass_func, 0, "tl.LowerPTXAsyncCopy", {}); diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 06bec7d106..f955babfde 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -24,6 +24,7 @@ #include "arith/ir_mutator_with_analyzer.h" #include "common/mbarrier.h" +#include "common/pipeline_utils.h" #include "layout_reducer.h" #include "loop_partition.h" @@ -230,17 +231,17 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { // If any TMA copies allocated mbarriers, inject the barrier buffer // into the tilelang_root block with a barrier_init annotation. - // MultiVersionBuffer will expand it for pipelining, and + // Pipeline buffer versioning expands it for pipelining, and // LowerSharedBarrier will process it into ptx_init_barrier_thread_count. if (substituter.mbarrier_count_ > 0) { ICHECK(substituter.mbarrier_buffer_.defined()) << "mbarrier_buffer_ must have been created by AllocMBarrier " "callback"; Buffer mbar_buf = substituter.mbarrier_buffer_.value(); - // Update buffer shape in-place to final count. We use const_cast + // Update buffer shape in-place to final count. We use const_cast // because CopyOnWrite would create a new BufferNode, breaking identity - // with BufferLoad references already in the body. MultiVersionBuffer - // relies on buffer identity to remap accesses correctly. + // with BufferLoad references already in the body. Pipeline buffer + // versioning relies on buffer identity to remap accesses correctly. const_cast(mbar_buf.get())->shape = { IntImm(DataType::Int(32), substituter.mbarrier_count_)}; @@ -1050,19 +1051,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return workspace.access_ptr(2); // write }; - Range thread_bounds; - - if (analyzer_->const_int_bound.IsBound(thread_var_->var)) { - auto const_int_bound = analyzer_->const_int_bound(thread_var_); - auto min_value = const_int_bound->min_value; - auto max_value = const_int_bound->max_value; - auto extent = max_value + 1 - min_value; - thread_bounds = - Range::FromMinExtent(IntImm(thread_var_->var.dtype(), min_value), - IntImm(thread_var_->var.dtype(), extent)); - } else { - thread_bounds = Range::FromMinExtent(0, 1); - } + Range thread_bounds = CurrentThreadBounds(); // Convert let_bindings_ to Map for LowerArgs Map let_var_to_expr; @@ -1079,38 +1068,23 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return id; }; - // Compute mbarrier expressions from the enclosing loop and pipeline info. - // pipeline_num_stages: number of pipeline stages (from T.Pipelined - // annotation) mbar_stage_expr: ko % num_stages (cycles through multiple - // mbarriers) mbar_phase_expr: (ko / num_stages) % 2 (mbarrier parity for - // wait) - int pipeline_num_stages = 1; - PrimExpr mbar_phase_expr; - PrimExpr mbar_stage_expr = IntImm(DataType::Int(32), 0); - if (!loop_var_stack_.empty()) { - pipeline_num_stages = pipeline_num_stages_stack_.back(); - Var loop_var = loop_var_stack_.back(); - PrimExpr ns = IntImm(DataType::Int(32), pipeline_num_stages); - mbar_stage_expr = FloorMod(loop_var, ns); - mbar_phase_expr = - FloorMod(FloorDiv(loop_var, ns), IntImm(DataType::Int(32), 2)); - } else { - mbar_phase_expr = IntImm(DataType::Int(32), 0); - } - - auto lowered = tile_op->Lower( - LowerArgs{target_, thread_bounds, thread_var_->var, callback, - mbarrier_callback, layout_map_, buffer_remap_, - let_var_to_expr, - /*in_pipeline=*/pipelined_depth_ > 0, mbar_phase_expr, - pipeline_num_stages, mbar_stage_expr, &mbarrier_buffer_, - cluster_size_}, - analyzer_); + auto lowered = + tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var, + callback, mbarrier_callback, layout_map_, + buffer_remap_, let_var_to_expr, + loop_mbar_phase_stack_.empty() + ? PrimExpr(IntImm(DataType::Int(32), 0)) + : loop_mbar_phase_stack_.back(), + &mbarrier_buffer_, cluster_size_}, + analyzer_); return IRMutatorWithAnalyzer::VisitStmt(lowered); } Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == kPipelineContextNumStages) { + return VisitStmt(op->body); + } if (op->attr_key == tir::attr::thread_extent) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); @@ -1145,16 +1119,26 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { * @return Stmt The lowered statement. */ Stmt VisitStmt_(const ForNode *op) final { - // Track enclosing loop variables for mbarrier parity computation. - loop_var_stack_.push_back(op->loop_var); - // Track pipeline num_stages from the loop's annotation. - int num_stages = 1; - if (auto ns_anno = op->annotations.Get("num_stages")) { - if (auto *ns_int = ns_anno.value().as()) { - num_stages = static_cast(ns_int->value); + bool pushed_loop_mbar_phase = false; + if (op->kind == ForKind::kSerial) { + int num_stages = 1; + if (auto ns_anno = op->annotations.Get("num_stages")) { + if (const auto *ns_int = ns_anno.value().as()) { + if (ns_int->value > 1) { + num_stages = static_cast(ns_int->value); + } + } + } + PrimExpr phase_expr; + if (num_stages > 1) { + phase_expr = FloorMod(FloorDiv(op->loop_var, num_stages), + IntImm(DataType::Int(32), 2)); + } else { + phase_expr = FloorMod(op->loop_var, IntImm(DataType::Int(32), 2)); } + loop_mbar_phase_stack_.push_back(analyzer_->Simplify(phase_expr)); + pushed_loop_mbar_phase = true; } - pipeline_num_stages_stack_.push_back(num_stages); // Extract reducer info from annotations Map reducer_info; @@ -1164,25 +1148,11 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { .value(); } - bool enter_pipelined = false; - if (auto num_stages_anno = op->annotations.Get("num_stages")) { - const auto *imm = num_stages_anno->as(); - ICHECK(imm) << "For annotation num_stages must be IntImm, but got " - << num_stages_anno.value(); - enter_pipelined = imm->value > 0; - } - if (enter_pipelined) { - ++pipelined_depth_; - } - // First visit the body. For for_node = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); - if (enter_pipelined) { - ICHECK_GT(pipelined_depth_, 0); - --pipelined_depth_; + if (pushed_loop_mbar_phase) { + loop_mbar_phase_stack_.pop_back(); } - loop_var_stack_.pop_back(); - pipeline_num_stages_stack_.pop_back(); // Only process parallel loops if (op->kind != ForKind::kParallel) { @@ -1357,14 +1327,20 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { bool enable_auto_async_copy = ctx->GetConfig(kEnableAsyncCopy, Bool(true)).value(); bool should_enable_async_copy = - (enable_auto_async_copy && (pipelined_depth_ > 0)) || - parallel_prefer_async; - lowered = InjectPTXAsyncCopy(lowered, should_enable_async_copy, - parallel_async_without_async_commit_wait); + parallel_prefer_async || + (enable_auto_async_copy && parallel_async_without_async_commit_wait); + auto inject_result = + InjectPTXAsyncCopy(lowered, should_enable_async_copy, + parallel_async_without_async_commit_wait); + lowered = inject_result.stmt; } return lowered; } + Range CurrentThreadBounds() const { + return ComputeThreadBounds(thread_var_, *analyzer_); + } + Target target_; Map buffer_data_to_buffer_; Map layout_map_; @@ -1386,10 +1362,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { std::vector mbarrier_arrive_counts_; // The shared.barrier scope buffer created lazily by AllocMBarrier callback. Optional mbarrier_buffer_; - // Stack of enclosing loop variables for mbarrier parity computation. - std::vector loop_var_stack_; - // Stack of pipeline num_stages values from enclosing loop annotations. - std::vector pipeline_num_stages_stack_; + // Fallback mbarrier parity derived from the nearest enclosing serial loop. + std::vector loop_mbar_phase_stack_; // For ptx Node, we need to remap the buffer and indices // By access CallNode instead of BufferLoad Node. bool is_ptx_{false}; @@ -1404,7 +1378,6 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { // without recomputing indices, since swizzle is encoded in TMA descriptor // parameters rather than in memory indices. bool in_tma_context_{false}; - int pipelined_depth_{0}; }; namespace transform { diff --git a/src/transform/multi_version_buffer_rewriter.cc b/src/transform/multi_version_buffer_rewriter.cc index bfe3c70b93..e73e2e1b93 100644 --- a/src/transform/multi_version_buffer_rewriter.cc +++ b/src/transform/multi_version_buffer_rewriter.cc @@ -4,25 +4,115 @@ */ #include -#include #include #include #include #include -#include #include #include #include +#include +#include "../layout/layout.h" #include "../op/builtin.h" +#include "../op/operator.h" +#include "../op/region.h" #include "../op/utils.h" +#include "common/pipeline_utils.h" +#include "multi_version_buffer_rewriter.h" namespace tvm { namespace tl { using namespace tir; +namespace { + +bool ShapesEqual(const Array &lhs, const Array &rhs, + arith::Analyzer *analyzer) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); ++i) { + if (!analyzer->CanProveEqual(lhs[i], rhs[i])) { + return false; + } + } + return true; +} + +Layout ExpandAnnotatedLayoutForMultiVersionedBuffer(const Layout &layout, + const Buffer &old_buffer, + const Buffer &new_buffer) { + if (!layout.defined() || + new_buffer->shape.size() <= old_buffer->shape.size()) { + return Layout(); + } + + arith::Analyzer analyzer; + if (!ShapesEqual(layout->InputShape(), old_buffer->shape, &analyzer)) { + return Layout(); + } + + size_t leading_ndim = new_buffer->shape.size() - old_buffer->shape.size(); + Array trailing_shape; + Array leading_shape; + for (size_t i = 0; i < leading_ndim; ++i) { + leading_shape.push_back(new_buffer->shape[i]); + } + for (size_t i = 0; i < old_buffer->shape.size(); ++i) { + trailing_shape.push_back(new_buffer->shape[leading_ndim + i]); + } + if (!ShapesEqual(trailing_shape, old_buffer->shape, &analyzer)) { + return Layout(); + } + + return layout->Expand(leading_shape); +} + +bool UpdateExpandedLayoutMapForRemappedAllocs( + const std::vector> &remapped_allocs, + Map *annotations) { + if (remapped_allocs.empty() || !annotations->count(attr::kLayoutMap)) { + return false; + } + + auto layout_map_ref = annotations->Get(attr::kLayoutMap); + if (!layout_map_ref.has_value()) { + return false; + } + auto layout_map = layout_map_ref.value().as>(); + if (!layout_map.has_value()) { + return false; + } + + Map updated_layout_map = layout_map.value(); + std::unordered_set visited; + bool changed = false; + for (const auto &[old_buffer, new_buffer] : remapped_allocs) { + if (!visited.insert(old_buffer->data.get()).second || + !updated_layout_map.count(old_buffer->data)) { + continue; + } + Layout layout = updated_layout_map[old_buffer->data]; + Layout expanded = ExpandAnnotatedLayoutForMultiVersionedBuffer( + layout, old_buffer, new_buffer); + if (!expanded.defined()) { + continue; + } + updated_layout_map.Set(old_buffer->data, expanded); + changed = true; + } + + if (changed) { + annotations->Set(attr::kLayoutMap, updated_layout_map); + } + return changed; +} + +} // namespace + enum class Role : uint8_t { kConsumer, kProducer, kBoth }; class WarpSpecializedRoleMarker_ : public StmtVisitor { @@ -128,8 +218,8 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor { class MultiVersionBufferRewriter : public StmtExprMutator { public: - static PrimFunc Substitute(PrimFunc &f, bool barrier_only = false) { - auto rewriter = MultiVersionBufferRewriter(barrier_only); + static PrimFunc Substitute(PrimFunc f) { + auto rewriter = MultiVersionBufferRewriter(); rewriter.buffer_lca_ = DetectBufferAccessLCA(f); for (auto [buffer, _] : rewriter.buffer_lca_) { Var buffer_var = buffer->data; @@ -140,8 +230,7 @@ class MultiVersionBufferRewriter : public StmtExprMutator { } private: - explicit MultiVersionBufferRewriter(bool barrier_only = false) - : barrier_only_(barrier_only) {} + explicit MultiVersionBufferRewriter() = default; Array GetVersionedBuffers(const Array &seq_stmt, const Array &scoped_buffers) { @@ -161,6 +250,13 @@ class MultiVersionBufferRewriter : public StmtExprMutator { collect_stmts(attr->body); return; } + if (const auto *if_then_else = stmt.as()) { + collect_stmts(if_then_else->then_case); + if (if_then_else->else_case.defined()) { + collect_stmts(if_then_else->else_case.value()); + } + return; + } if (const auto *block_realize = stmt.as()) { collect_stmts(block_realize->block->body); return; @@ -183,8 +279,52 @@ class MultiVersionBufferRewriter : public StmtExprMutator { Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt); auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); - reads.push_back(access[0]); - writes.push_back(access[1]); + Array stmt_reads = access[0]; + Array stmt_writes = access[1]; + + // Supplement with tile-op analysis. + // GetBlockAccessRegion misses buffer references that are encoded as + // tl.tileop.region Call args or as plain BufferLoad args whose + // semantic role (read vs write) is only known to the tile-op. + // Let the tile-op report its own access regions, and fall back to + // RegionOp scanning for any ops that still do not expose them. + if (auto *eval = stmt.as()) { + if (auto *call = eval->value.as()) { + auto tile_op = ParseOperator(ffi::GetRef(call)); + if (tile_op.defined()) { + AccessRegions access = tile_op->GetAccessRegions(); + if (!access.reads.empty() || !access.writes.empty()) { + stmt_reads.insert(stmt_reads.end(), access.reads.begin(), + access.reads.end()); + stmt_writes.insert(stmt_writes.end(), access.writes.begin(), + access.writes.end()); + } else { + // Fallback: scan RegionOp-encoded args. + for (const auto &arg : call->args) { + if (auto *region_call = arg.as()) { + if (region_call->op.same_as(RegionOp::Get())) { + auto region_op = + ParseOperator(ffi::GetRef(region_call)); + if (auto *rn = region_op.as()) { + int mask = rn->GetAccessMask(); + auto br = BufferRegion(rn->GetBuffer(), rn->GetRanges()); + if (mask & 1) { // read + stmt_reads.push_back(br); + } + if (mask & 2) { // write + stmt_writes.push_back(br); + } + } + } + } + } + } + } + } + } + + reads.push_back(stmt_reads); + writes.push_back(stmt_writes); roles.push_back(marker.GetRole(stmt)); } @@ -274,21 +414,152 @@ class MultiVersionBufferRewriter : public StmtExprMutator { return Buffer(new_buffer); } + Array GetPipelineTopLevelStmts(const Stmt &pipeline_body) const { + Stmt current = pipeline_body; + while (true) { + if (const auto *realize = current.as()) { + current = realize->block->body; + continue; + } + if (const auto *block = current.as()) { + current = block->body; + continue; + } + break; + } + if (const auto *seq = current.as()) { + return seq->seq; + } + return {current}; + } + + Array CollectScopedBuffers() const { + Array scoped_buffers; + std::unordered_set seen; + for (auto [buffer, stmt] : buffer_lca_) { + if (!stmt.defined()) { + continue; + } + const StmtNode *lca = stmt.value().get(); + bool in_scope = false; + for (const StmtNode *ancestor : stmt_stack_) { + if (ancestor == lca) { + in_scope = true; + break; + } + } + if (!in_scope) { + continue; + } + if (!IsSharedBuffer(buffer) && buffer.scope() != "shared.barrier") { + continue; + } + if (seen.insert(buffer.get()).second) { + scoped_buffers.push_back(buffer); + } + } + for (auto it = stmt_stack_.rbegin(); it != stmt_stack_.rend(); ++it) { + if (!(*it)->IsInstance()) { + continue; + } + const auto *block = static_cast(*it); + auto map_it = block_alloc_buffers_.find(block); + const Array &buffers = map_it != block_alloc_buffers_.end() + ? map_it->second + : block->alloc_buffers; + for (const Buffer &buffer : buffers) { + if (!IsSharedBuffer(buffer) && buffer.scope() != "shared.barrier") { + continue; + } + if (seen.insert(buffer.get()).second) { + scoped_buffers.push_back(buffer); + } + } + } + return scoped_buffers; + } + + Array SelectVersionedBuffers(const Stmt &pipeline_body, + int num_stages) { + Array scoped_buffers = CollectScopedBuffers(); + Array versioned_buffers = GetVersionedBuffers( + GetPipelineTopLevelStmts(pipeline_body), scoped_buffers); + + std::unordered_set already; + for (const Buffer &buffer : versioned_buffers) { + already.insert(buffer.get()); + } + for (const Buffer &buffer : scoped_buffers) { + if (buffer.scope() == "shared.barrier" && !already.count(buffer.get())) { + versioned_buffers.push_back(buffer); + } + } + + if (num_stages <= 1) { + Array filtered; + for (const Buffer &buffer : versioned_buffers) { + if (buffer.scope() == "shared.barrier") { + filtered.push_back(buffer); + } + } + versioned_buffers = filtered; + } + return versioned_buffers; + } + + void EnsureVersionedBuffers(const Array &versioned_buffers, + int num_stages) { + for (const Buffer &buffer : versioned_buffers) { + if (buffer_remap_.count(buffer)) { + continue; + } + Var buffer_var = buffer->data; + Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages); + buffer_remap_.Set(buffer, new_buffer); + if (!buffer_data_to_buffer_.count(buffer_var)) { + buffer_data_to_buffer_.Set(buffer_var, buffer); + } + } + } + + PrimExpr CurrentVersionIndex() const { + if (!explicit_version_index_stack_.empty()) { + return explicit_version_index_stack_.back(); + } + return version_index_; + } + + PrimExpr CurrentParityCycle() const { + if (!explicit_parity_cycle_stack_.empty()) { + return explicit_parity_cycle_stack_.back(); + } + return parity_cycle_; + } + Stmt VisitStmt_(const BlockRealizeNode *op) final { BlockRealize block_realize = Downcast(StmtExprMutator::VisitStmt_(op)); Block block = block_realize->block; Array alloc_buffers; + std::vector> remapped_allocs; for (auto buffer : block->alloc_buffers) { if (buffer_remap_.count(buffer)) { Buffer new_buffer = buffer_remap_[buffer]; alloc_buffers.push_back(new_buffer); + remapped_allocs.emplace_back(buffer, new_buffer); } else { alloc_buffers.push_back(buffer); } } block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); + if (!remapped_allocs.empty()) { + auto ann = block->annotations; + if (UpdateExpandedLayoutMapForRemappedAllocs(remapped_allocs, &ann)) { + block.CopyOnWrite()->annotations = std::move(ann); + } + } + // Update barrier_init annotation: replicate arrive counts for versioned // barrier buffers so lower_shared_barrier sees the correct count. if (block->annotations.count("barrier_init")) { @@ -339,147 +610,67 @@ class MultiVersionBufferRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const ForNode *op) final { + Stmt VisitStmt_(const AttrStmtNode *op) final { stmt_stack_.push_back(op); - loop_stack_.emplace_back(op->loop_var, op->extent); - auto num_stages_anno = op->annotations.Get("num_stages"); - if (!num_stages_anno) { - auto for_node = StmtExprMutator::VisitStmt_(op); - loop_stack_.pop_back(); - stmt_stack_.pop_back(); - return for_node; - } - ICHECK(num_stages_anno->as()); - int num_stages = static_cast(num_stages_anno->as()->value); - - Stmt pipeline_body_root{nullptr}; - if (const auto *realize = op->body.as()) { - const auto &block = realize->block; - for (const auto &buffer : block->alloc_buffers) { - ICHECK(buffer->IsInstance()); - buffer_data_to_buffer_.Set(buffer->data, buffer); + bool pushed_explicit_version = false; + bool pushed_explicit_parity = false; + if (op->attr_key == kPipelineMVBContextNumStages) { + if (const int64_t *imm = as_const_int(op->value)) { + int num_stages = static_cast(*imm); + EnsureVersionedBuffers(SelectVersionedBuffers(op->body, num_stages), + num_stages); } - pipeline_body_root = block->body; - } else { - pipeline_body_root = op->body; + } else if (op->attr_key == kPipelineMVBStageExpr) { + explicit_version_index_stack_.push_back(op->value); + pushed_explicit_version = true; + } else if (op->attr_key == kPipelineMVBParityExpr) { + explicit_parity_cycle_stack_.push_back(op->value); + pushed_explicit_parity = true; } - const SeqStmtNode *pipeline_body_seq = nullptr; - { - // Traverse trivial wrappers (let/if) to find the actual SeqStmt body. - Stmt current = pipeline_body_root; - while (true) { - if (const auto *seq_stmt = current.as()) { - pipeline_body_seq = seq_stmt; - break; - } - if (const auto *if_then_else = current.as()) { - ICHECK(!if_then_else->else_case.defined()) - << "MultiVersionBuffer: Can't handle the body of the loop " - "because the IfThenElse node has an else branch"; - current = if_then_else->then_case; - continue; - } - if (const auto *let_stmt = current.as()) { - current = let_stmt->body; - continue; - } - LOG(FATAL) - << "MultiVersionBuffer: Can't handle the body of the loop because " - << "it is not a SeqStmt, IfThenElse without else, " - << "or LetStmt wrapping them, but got " << current->GetTypeKey(); - } - } - ICHECK(pipeline_body_seq != nullptr); + Stmt body = this->VisitStmt(op->body); - Array scoped_buffers; - std::unordered_set seen; - for (auto [buffer, stmt] : buffer_lca_) { - if (!stmt.defined()) - continue; - const StmtNode *lca = stmt.value().get(); - bool in_scope = false; - for (const StmtNode *ancestor : stmt_stack_) { - if (ancestor == lca) { - in_scope = true; - break; - } - } - if (!in_scope) - continue; - // Only double-buffer shared/barrier allocations; locals do not need - // versioning. - if (!IsSharedBuffer(buffer) && buffer.scope() != "shared.barrier") - continue; - if (seen.insert(buffer.get()).second) { - scoped_buffers.push_back(buffer); - } + if (pushed_explicit_version) { + explicit_version_index_stack_.pop_back(); } - for (auto it = stmt_stack_.rbegin(); it != stmt_stack_.rend(); ++it) { - if (!(*it)->IsInstance()) - continue; - const auto *block = static_cast(*it); - // Try cached alloc list first; fall back to the original IR node - // (the cache may not be populated yet during the recursive visit). - auto map_it = block_alloc_buffers_.find(block); - const Array &buffers = map_it != block_alloc_buffers_.end() - ? map_it->second - : block->alloc_buffers; - for (const Buffer &buffer : buffers) { - if (!IsSharedBuffer(buffer) && buffer.scope() != "shared.barrier") - continue; - if (seen.insert(buffer.get()).second) { - scoped_buffers.push_back(buffer); - } - } + if (pushed_explicit_parity) { + explicit_parity_cycle_stack_.pop_back(); } + stmt_stack_.pop_back(); - Array versioned_buffers = - GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers); - - // Barrier buffers always get versioned in pipelined loops — - // they don't fit the producer/consumer analysis above. - { - std::unordered_set already; - for (auto b : versioned_buffers) - already.insert(b.get()); - for (auto buffer : scoped_buffers) { - if (buffer.scope() == "shared.barrier" && - !already.count(buffer.get())) { - versioned_buffers.push_back(buffer); - } - } + if (op->attr_key == kPipelineMVBStageExpr || + op->attr_key == kPipelineMVBParityExpr || + op->attr_key == kPipelineMVBContextNumStages) { + return body; } + return AttrStmt(op->node, op->attr_key, op->value, body, op->span); + } - // In barrier_only mode, only version barrier buffers. - // Data buffer versioning is left to InjectSoftwarePipeline. - if (barrier_only_) { - Array filtered; - for (auto buffer : versioned_buffers) { - if (buffer.scope() == "shared.barrier" || - buffer.scope() == "shared.cluster_barrier") { - filtered.push_back(buffer); - } - } - versioned_buffers = filtered; + Stmt VisitStmt_(const ForNode *op) final { + stmt_stack_.push_back(op); + loop_stack_.emplace_back(op->loop_var, op->extent); + Optional num_stages_anno = GetPipelineNumStages(op); + if (!num_stages_anno) { + auto for_node = StmtExprMutator::VisitStmt_(op); + loop_stack_.pop_back(); + stmt_stack_.pop_back(); + return for_node; } - for (auto buffer : versioned_buffers) { - Var buffer_var = buffer->data; - Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages); - buffer_remap_.Set(buffer, new_buffer); - // Ensure the data var is discoverable so the barrier_init annotation - // update in VisitStmt_(BlockRealizeNode*) can find the remapped buffer. - if (!buffer_data_to_buffer_.count(buffer_var)) { - buffer_data_to_buffer_.Set(buffer_var, buffer); - } - } + int num_stages = num_stages_anno.value()->value; + EnsureVersionedBuffers(SelectVersionedBuffers(op->body, num_stages), + num_stages); + PrimExpr linear_index = loop_stack_[0].first; for (size_t i = 1; i < loop_stack_.size(); ++i) { linear_index = linear_index * loop_stack_[i].second + loop_stack_[i].first; } + PrimExpr old_version_index = version_index_; + PrimExpr old_parity_cycle = parity_cycle_; + Var old_pipeline_loop_var = pipeline_loop_var_; + PrimExpr old_pipeline_loop_min = pipeline_loop_min_; version_index_ = FloorMod(linear_index, num_stages); // Parity cycles every num_stages iterations for mbarrier phase tracking. parity_cycle_ = FloorMod(FloorDiv(linear_index, num_stages), 2); @@ -488,9 +679,10 @@ class MultiVersionBufferRewriter : public StmtExprMutator { pipeline_loop_var_ = op->loop_var; pipeline_loop_min_ = op->min; auto for_node = StmtExprMutator::VisitStmt_(op); - parity_cycle_ = PrimExpr(); // reset - pipeline_loop_var_ = Var(); - pipeline_loop_min_ = PrimExpr(); + version_index_ = old_version_index; + parity_cycle_ = old_parity_cycle; + pipeline_loop_var_ = old_pipeline_loop_var; + pipeline_loop_min_ = old_pipeline_loop_min; loop_stack_.pop_back(); stmt_stack_.pop_back(); @@ -505,13 +697,16 @@ class MultiVersionBufferRewriter : public StmtExprMutator { } Buffer old_buffer = load->buffer; const Buffer &new_buffer = (*it).second; + PrimExpr version_index = CurrentVersionIndex(); + ICHECK(version_index.defined()) + << "Versioned buffer load escaped pipeline stage context"; auto *n = load.CopyOnWrite(); n->buffer = new_buffer; if (old_buffer.scope() == "shared.barrier") { // Barrier: offset into expanded 1D array - n->indices.Set(0, version_index_ * old_buffer->shape[0] + n->indices[0]); + n->indices.Set(0, version_index * old_buffer->shape[0] + n->indices[0]); } else { - n->indices.insert(n->indices.begin(), version_index_); + n->indices.insert(n->indices.begin(), version_index); } return std::move(load); } @@ -524,12 +719,15 @@ class MultiVersionBufferRewriter : public StmtExprMutator { } Buffer old_buffer = store->buffer; const Buffer &new_buffer = (*it).second; + PrimExpr version_index = CurrentVersionIndex(); + ICHECK(version_index.defined()) + << "Versioned buffer store escaped pipeline stage context"; auto *n = store.CopyOnWrite(); n->buffer = new_buffer; if (old_buffer.scope() == "shared.barrier") { - n->indices.Set(0, version_index_ * old_buffer->shape[0] + n->indices[0]); + n->indices.Set(0, version_index * old_buffer->shape[0] + n->indices[0]); } else { - n->indices.insert(n->indices.begin(), version_index_); + n->indices.insert(n->indices.begin(), version_index); } return std::move(store); } @@ -539,6 +737,35 @@ class MultiVersionBufferRewriter : public StmtExprMutator { if (call->op.same_as(builtin::tvm_access_ptr())) { return RewriteBufferAccess(call, {1}); } + // Rewrite tl.tileop.region Calls for versioned buffers. + // The region encoding is: + // region(BufferLoad(buf, [min_0, ..., min_N]), access_mask, ext_0, ..., + // ext_N) + // After the recursive visit, VisitExpr_(BufferLoadNode*) prepends a + // version_index to the BufferLoad indices, yielding [version_index, + // min_0, ..., min_N]. We must also insert a matching extent (1) for the + // new leading dimension so that RegionOp's ndim == indices.size() + // invariant is preserved. + // + // Detection: if the BufferLoad has more indices than the number of extent + // args (args.size() - 2), a version index was prepended. + if (call->op.same_as(RegionOp::Get()) && call->args.size() >= 2) { + if (auto load = call->args[0].as()) { + size_t num_extents = + call->args.size() - 2; // args = [load, mask, ext...] + if (load->indices.size() == num_extents + 1) { + // Version index was prepended. Insert a unit extent to match. + Array new_args; + new_args.push_back(call->args[0]); // rewritten BufferLoad + new_args.push_back(call->args[1]); // access_mask + new_args.push_back(IntImm(DataType::Int(32), 1)); // stage extent + for (size_t i = 2; i < call->args.size(); ++i) { + new_args.push_back(call->args[i]); + } + return Call(call->dtype, call->op, new_args, call->annotations); + } + } + } // Rewrite parity for mbarrier_wait_parity on versioned barrier buffers. // The user writes single-barrier parity (e.g. k % 2 or (k+1) % 2). // After multi-versioning, each barrier is reused every num_stages @@ -547,27 +774,34 @@ class MultiVersionBufferRewriter : public StmtExprMutator { // (e.g. back-pressure barriers use (k+1)%2 so the first iteration // passes immediately). We detect this offset by evaluating the original // parity at the loop's initial value and preserving it. - if (call->op.same_as(mbarrier_wait_parity()) && parity_cycle_.defined()) { + PrimExpr parity_cycle = CurrentParityCycle(); + if (call->op.same_as(mbarrier_wait_parity()) && parity_cycle.defined()) { if (auto load = call->args[0].as()) { if (load->buffer.scope() == "shared.barrier") { - PrimExpr new_parity = parity_cycle_; + PrimExpr new_parity = parity_cycle; + arith::Analyzer analyzer; + PrimExpr init_orig = call->args[1]; + PrimExpr init_cycle = parity_cycle; + if (!explicit_parity_cycle_stack_.empty()) { + PrimExpr version_index = CurrentVersionIndex(); + ICHECK(version_index.defined()) + << "Explicit parity rewrite requires a version index"; + init_cycle = version_index; + } if (pipeline_loop_var_.defined()) { - arith::Analyzer analyzer; auto subst = [&](const Var &v) -> Optional { if (v.same_as(pipeline_loop_var_)) return pipeline_loop_min_; return Optional(); }; - PrimExpr init_orig = - analyzer.Simplify(tir::Substitute(call->args[1], subst)); - PrimExpr init_cycle = - analyzer.Simplify(tir::Substitute(parity_cycle_, subst)); - PrimExpr offset = - analyzer.Simplify(FloorMod(init_orig - init_cycle, 2)); - if (auto *imm = offset.as()) { - if (imm->value % 2 != 0) { - new_parity = FloorMod(parity_cycle_ + 1, 2); - } + init_orig = analyzer.Simplify(tir::Substitute(init_orig, subst)); + init_cycle = analyzer.Simplify(tir::Substitute(init_cycle, subst)); + } + PrimExpr offset = + analyzer.Simplify(FloorMod(init_orig - init_cycle, 2)); + if (const int64_t *imm = as_const_int(offset)) { + if (*imm % 2 != 0) { + new_parity = FloorMod(parity_cycle + 1, 2); } } Array new_args = call->args; @@ -596,6 +830,9 @@ class MultiVersionBufferRewriter : public StmtExprMutator { const Buffer &buffer = buffer_data_to_buffer_[buffer_var]; auto it = buffer_remap_.find(buffer); if (it != buffer_remap_.end()) { + PrimExpr version_index = CurrentVersionIndex(); + ICHECK(version_index.defined()) + << "Versioned access_ptr escaped pipeline stage context"; const Buffer &new_buffer = (*it).second; const PrimExpr &old_index = call->args[i + 1]; PrimExpr offset; @@ -604,14 +841,13 @@ class MultiVersionBufferRewriter : public StmtExprMutator { } else { offset = new_buffer->strides[0]; } - PrimExpr new_index = old_index + version_index_ * offset; + PrimExpr new_index = old_index + version_index * offset; new_args.Set(i + 1, new_index); } } return Call(call->dtype, call->op, new_args, call->annotations, call->span); } - bool barrier_only_; PrimExpr version_index_; PrimExpr parity_cycle_; // (k / num_stages) % 2 for mbarrier parity rewriting Var pipeline_loop_var_; // loop variable of the pipelined loop @@ -620,6 +856,8 @@ class MultiVersionBufferRewriter : public StmtExprMutator { // Track ancestor statements to query whether an LCA is inside the current // loop. std::vector stmt_stack_; + std::vector explicit_version_index_stack_; + std::vector explicit_parity_cycle_stack_; Map buffer_data_to_buffer_; Map> buffer_lca_; Map buffer_remap_; @@ -628,18 +866,8 @@ class MultiVersionBufferRewriter : public StmtExprMutator { std::unordered_map> block_alloc_buffers_; }; -using namespace tir::transform; - -tvm::transform::Pass MultiVersionBuffer(bool barrier_only) { - auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { - return MultiVersionBufferRewriter::Substitute(f, barrier_only); - }; - return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer); +PrimFunc ApplyMultiVersionBufferRewriter(PrimFunc f) { + return MultiVersionBufferRewriter::Substitute(std::move(f)); } } // namespace tl diff --git a/src/transform/multi_version_buffer_rewriter.h b/src/transform/multi_version_buffer_rewriter.h new file mode 100644 index 0000000000..01e248b68c --- /dev/null +++ b/src/transform/multi_version_buffer_rewriter.h @@ -0,0 +1,19 @@ +/*! + * \brief Internal helper for pipeline buffer multi-versioning. + * \file multi_version_buffer_rewriter.h + */ + +#ifndef TVM_TL_TRANSFORM_MULTI_VERSION_BUFFER_REWRITER_H_ +#define TVM_TL_TRANSFORM_MULTI_VERSION_BUFFER_REWRITER_H_ + +#include + +namespace tvm { +namespace tl { + +tir::PrimFunc ApplyMultiVersionBufferRewriter(tir::PrimFunc f); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TRANSFORM_MULTI_VERSION_BUFFER_REWRITER_H_ diff --git a/src/transform/optimize_cp_async_sync.cc b/src/transform/optimize_cp_async_sync.cc deleted file mode 100644 index 2d793e687a..0000000000 --- a/src/transform/optimize_cp_async_sync.cc +++ /dev/null @@ -1,1235 +0,0 @@ -/*! - * \file optimize_cp_async_sync.cc - * \brief Optimize explicit cp.async synchronization intrinsics. - */ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "../op/builtin.h" - -namespace tvm { -namespace tl { - -using namespace tir; - -namespace transform { - -class CPAsyncSyncOptimizer : public StmtExprMutator { -public: - Stmt VisitStmt_(const SeqStmtNode *op) final { - Array visited; - visited.reserve(op->seq.size()); - for (const Stmt &stmt : op->seq) { - visited.push_back(this->VisitStmt(stmt)); - } - - visited = MaybeMergeCommitsBeforeWait0(std::move(visited)); - - visited = MaybeSplitEpilogueWait(std::move(visited)); - - enum class UncommittedState { kUnknown, kZero, kNonZero }; - - UncommittedState uncommitted_state = UncommittedState::kUnknown; - std::optional last_wait_n; - bool last_wait_dynamic = false; - std::optional outstanding_committed_groups_exact = 0; - int outstanding_committed_groups_lb = 0; - - Array simplified; - simplified.reserve(visited.size()); - for (size_t stmt_idx = 0; stmt_idx < visited.size(); ++stmt_idx) { - const Stmt &stmt = visited[stmt_idx]; - Stmt current = stmt; - if (const auto *loop = current.as()) { - current = MaybeRelaxLoopWaits(Downcast(current), - outstanding_committed_groups_exact, - outstanding_committed_groups_lb); - } - - ClassifiedStmt cls = ClassifySimpleAsyncStmt(current); - switch (cls.kind) { - case AsyncStmtKind::kCPAsync: - uncommitted_state = UncommittedState::kNonZero; - simplified.push_back(current); - break; - case AsyncStmtKind::kCommit: { - if (uncommitted_state == UncommittedState::kZero) { - // Proven redundant commit: no cp.async issued since the last commit. - break; - } - bool commit_has_new_cpasync = - (uncommitted_state == UncommittedState::kNonZero); - simplified.push_back(current); - uncommitted_state = UncommittedState::kZero; - if (outstanding_committed_groups_exact.has_value() && - commit_has_new_cpasync) { - outstanding_committed_groups_exact = - AddWithCap(*outstanding_committed_groups_exact, /*inc=*/1); - } else if (outstanding_committed_groups_exact.has_value() && - !commit_has_new_cpasync) { - // Keep exact outstanding unchanged when this commit has no proven new - // cp.async group. - } else { - outstanding_committed_groups_exact = std::nullopt; - } - if (commit_has_new_cpasync) { - outstanding_committed_groups_lb = - AddWithCap(outstanding_committed_groups_lb, /*inc=*/1); - } - last_wait_n.reset(); - last_wait_dynamic = false; - break; - } - case AsyncStmtKind::kWaitStatic: - if (!last_wait_dynamic && last_wait_n.has_value() && - cls.wait_n >= *last_wait_n) { - // A weaker (or equal) wait is redundant when no commit happened in - // between. - break; - } - simplified.push_back(current); - last_wait_n = cls.wait_n; - last_wait_dynamic = false; - if (outstanding_committed_groups_exact.has_value()) { - outstanding_committed_groups_exact = - std::min(*outstanding_committed_groups_exact, cls.wait_n); - } - outstanding_committed_groups_lb = - std::min(outstanding_committed_groups_lb, cls.wait_n); - break; - case AsyncStmtKind::kWaitDynamic: - simplified.push_back(current); - last_wait_n.reset(); - last_wait_dynamic = true; - outstanding_committed_groups_exact = std::nullopt; - outstanding_committed_groups_lb = 0; - break; - case AsyncStmtKind::kOther: - simplified.push_back(current); - if (ContainsAsyncIntrinsics(current)) { - AsyncIntrinSummary summary = SummarizeAsyncIntrinsics(current); - if (summary.cp_async > 0 && summary.commit == 0 && - summary.wait == 0) { - // Preserve pending cp.async state across cp.async-only wrappers - // (e.g. prologue loops before a standalone commit). - uncommitted_state = UncommittedState::kNonZero; - break; - } - - if (summary.wait == 0) { - if (auto transfer = TryGetDeterministicNoWaitTransfer(current)) { - int guaranteed_new_groups = - std::min(transfer->groups_if_start_clear, - transfer->groups_if_start_pending); - outstanding_committed_groups_lb = AddWithCap( - outstanding_committed_groups_lb, guaranteed_new_groups); - - if (outstanding_committed_groups_exact.has_value()) { - if (uncommitted_state == UncommittedState::kZero) { - outstanding_committed_groups_exact = - AddWithCap(*outstanding_committed_groups_exact, - transfer->groups_if_start_clear); - } else if (uncommitted_state == UncommittedState::kNonZero) { - outstanding_committed_groups_exact = - AddWithCap(*outstanding_committed_groups_exact, - transfer->groups_if_start_pending); - } else { - outstanding_committed_groups_exact = std::nullopt; - } - } - - auto pending_to_state = [](bool pending) { - return pending ? UncommittedState::kNonZero - : UncommittedState::kZero; - }; - if (uncommitted_state == UncommittedState::kZero) { - uncommitted_state = - pending_to_state(transfer->pending_if_start_clear); - } else if (uncommitted_state == UncommittedState::kNonZero) { - uncommitted_state = - pending_to_state(transfer->pending_if_start_pending); - } else { - if (transfer->pending_if_start_clear == - transfer->pending_if_start_pending) { - uncommitted_state = - pending_to_state(transfer->pending_if_start_clear); - } else { - uncommitted_state = UncommittedState::kUnknown; - } - } - - break; - } - } - // Cross this unknown boundary conservatively. - uncommitted_state = UncommittedState::kUnknown; - last_wait_n.reset(); - last_wait_dynamic = false; - outstanding_committed_groups_exact = std::nullopt; - outstanding_committed_groups_lb = 0; - } - break; - } - } - - if (simplified.empty()) { - return Evaluate(0); - } - if (simplified.size() == 1) { - return simplified[0]; - } - return SeqStmt(simplified); - } - -private: - enum class AsyncStmtKind { - kOther, - kCPAsync, - kCommit, - kWaitStatic, - kWaitDynamic - }; - - enum class PendingAsyncState { kUnknown, kZero, kNonZero }; - - struct ClassifiedStmt { - AsyncStmtKind kind{AsyncStmtKind::kOther}; - int wait_n{0}; - }; - - struct AsyncIntrinSummary { - int cp_async = 0; - int commit = 0; - int wait = 0; - }; - - // Conservative deterministic transfer for statements that contain async - // intrinsics but no wait_group. - struct DeterministicNoWaitTransfer { - int groups_if_start_clear{0}; - bool pending_if_start_clear{false}; - int groups_if_start_pending{0}; - bool pending_if_start_pending{true}; - }; - - static constexpr int kOutstandingCap = 1024; - - static int AddWithCap(int base, int inc) { - int64_t sum = static_cast(base) + static_cast(inc); - return static_cast(std::min(sum, kOutstandingCap)); - } - - DeterministicNoWaitTransfer IdentityTransfer() const { - return {/*groups_if_start_clear=*/0, /*pending_if_start_clear=*/false, - /*groups_if_start_pending=*/0, /*pending_if_start_pending=*/true}; - } - - DeterministicNoWaitTransfer CPAsyncTransfer() const { - return {/*groups_if_start_clear=*/0, /*pending_if_start_clear=*/true, - /*groups_if_start_pending=*/0, /*pending_if_start_pending=*/true}; - } - - DeterministicNoWaitTransfer CommitTransfer() const { - return {/*groups_if_start_clear=*/0, /*pending_if_start_clear=*/false, - /*groups_if_start_pending=*/1, /*pending_if_start_pending=*/false}; - } - - DeterministicNoWaitTransfer - ComposeTransfer(const DeterministicNoWaitTransfer &first, - const DeterministicNoWaitTransfer &second) const { - auto compose_one = [&](int first_groups, bool first_pending) { - if (first_pending) { - return std::make_pair( - AddWithCap(first_groups, second.groups_if_start_pending), - second.pending_if_start_pending); - } - return std::make_pair( - AddWithCap(first_groups, second.groups_if_start_clear), - second.pending_if_start_clear); - }; - - auto [g0, p0] = - compose_one(first.groups_if_start_clear, first.pending_if_start_clear); - auto [g1, p1] = compose_one(first.groups_if_start_pending, - first.pending_if_start_pending); - return {/*groups_if_start_clear=*/g0, /*pending_if_start_clear=*/p0, - /*groups_if_start_pending=*/g1, /*pending_if_start_pending=*/p1}; - } - - DeterministicNoWaitTransfer RepeatTransfer(DeterministicNoWaitTransfer base, - int64_t times) const { - DeterministicNoWaitTransfer result = IdentityTransfer(); - while (times > 0) { - if (times & 1) { - result = ComposeTransfer(result, base); - } - times >>= 1; - if (times > 0) { - base = ComposeTransfer(base, base); - } - } - return result; - } - - std::optional - TryGetDeterministicNoWaitTransfer(const Stmt &stmt) const { - if (const auto *let = stmt.as()) { - return TryGetDeterministicNoWaitTransfer(let->body); - } - if (const auto *attr = stmt.as()) { - return TryGetDeterministicNoWaitTransfer(attr->body); - } - if (const auto *seq = stmt.as()) { - DeterministicNoWaitTransfer result = IdentityTransfer(); - for (const Stmt &s : seq->seq) { - auto part = TryGetDeterministicNoWaitTransfer(s); - if (!part.has_value()) { - return std::nullopt; - } - result = ComposeTransfer(result, *part); - } - return result; - } - if (const auto *block = stmt.as()) { - return TryGetDeterministicNoWaitTransfer(block->body); - } - if (const auto *realize = stmt.as()) { - if (!is_one(realize->predicate)) { - return std::nullopt; - } - return TryGetDeterministicNoWaitTransfer(realize->block->body); - } - if (const auto *for_node = stmt.as()) { - if (for_node->thread_binding.defined()) { - return std::nullopt; - } - const auto *extent_imm = for_node->extent.as(); - if (extent_imm == nullptr || extent_imm->value < 0) { - return std::nullopt; - } - auto body_transfer = TryGetDeterministicNoWaitTransfer(for_node->body); - if (!body_transfer.has_value()) { - return std::nullopt; - } - return RepeatTransfer(*body_transfer, extent_imm->value); - } - if (stmt.as()) { - return std::nullopt; - } - if (const auto *eval = stmt.as()) { - if (const auto *call = eval->value.as()) { - if (IsWaitCall(call)) { - return std::nullopt; - } - if (IsCPAsyncCall(call)) { - return CPAsyncTransfer(); - } - if (IsCommitCall(call)) { - return CommitTransfer(); - } - } - if (ContainsAsyncIntrinsics(stmt)) { - return std::nullopt; - } - return IdentityTransfer(); - } - if (ContainsAsyncIntrinsics(stmt)) { - return std::nullopt; - } - return IdentityTransfer(); - } - - Array MaybeMergeCommitsBeforeWait0(Array seq) const { - // Merge adjacent cp.async commit groups when the program is still using a - // full drain (wait_group(0)) as the synchronization point. - // - // Pattern (within a SeqStmt segment that ends at wait_group(0)): - // cp_async*; commit; cp_async*; commit; wait_group(0) - // => - // cp_async*; cp_async*; commit; wait_group(0) - // - // This reduces the number of committed groups without weakening the drain, - // and lets later wait relaxation derive the right retain count from the - // new (smaller) group topology. - const int n = static_cast(seq.size()); - if (n < 4) { - return seq; - } - - auto is_direct_commit = [&](const Stmt &s) -> bool { - const auto *eval = s.as(); - if (!eval) { - return false; - } - const auto *call = eval->value.as(); - return call && IsCommitCall(call); - }; - - Array out; - out.reserve(n); - int seg_start = 0; - - auto flush_segment = [&](int seg_end, bool merge_commits) { - if (!merge_commits) { - for (int j = seg_start; j < seg_end; ++j) { - out.push_back(seq[j]); - } - return; - } - - // We only want to merge commits in the immediately preceding cp.async - // region, so we analyze a maximal suffix that contains only: - // - cp.async-only statements (possibly wrapped in loops/attrs), and - // - standalone commit_group statements. - // This avoids blocking on unrelated statements (e.g. barriers) that may - // appear earlier in the segment. - auto is_cp_async_only_stmt = [&](const Stmt &s) -> bool { - ClassifiedStmt cls = ClassifySimpleAsyncStmt(s); - if (cls.kind == AsyncStmtKind::kCPAsync) { - return true; - } - if (!ContainsAsyncIntrinsics(s)) { - return false; - } - AsyncIntrinSummary summary = SummarizeAsyncIntrinsics(s); - return (summary.cp_async > 0 && summary.commit == 0 && - summary.wait == 0); - }; - auto is_direct_commit_stmt = [&](const Stmt &s) -> bool { - ClassifiedStmt cls = ClassifySimpleAsyncStmt(s); - return cls.kind == AsyncStmtKind::kCommit && is_direct_commit(s); - }; - - int merge_start = seg_end; - bool saw_cp_async = false; - for (int j = seg_end - 1; j >= seg_start; --j) { - if (is_direct_commit_stmt(seq[j])) { - merge_start = j; - continue; - } - if (is_cp_async_only_stmt(seq[j])) { - saw_cp_async = true; - merge_start = j; - continue; - } - // Stop at the first non-async statement. - break; - } - - std::vector commit_indices; - commit_indices.reserve(4); - bool has_complex_async = false; - for (int j = merge_start; j < seg_end; ++j) { - if (is_direct_commit_stmt(seq[j])) { - commit_indices.push_back(j); - continue; - } - if (is_cp_async_only_stmt(seq[j])) { - continue; - } - // Shouldn't happen given how merge_start is determined, but be safe. - has_complex_async = true; - break; - } - - if (has_complex_async || !saw_cp_async || commit_indices.size() != 2) { - for (int j = seg_start; j < seg_end; ++j) { - out.push_back(seq[j]); - } - return; - } - - // Emit prefix unchanged. - for (int j = seg_start; j < merge_start; ++j) { - out.push_back(seq[j]); - } - // Drop the first commit in the mergeable suffix so both copy regions are - // committed as a single group by the second commit. - int dropped_commit_idx = commit_indices[0]; - for (int j = merge_start; j < seg_end; ++j) { - if (j == dropped_commit_idx) { - continue; - } - out.push_back(seq[j]); - } - }; - - for (int i = 0; i < n; ++i) { - ClassifiedStmt cls = ClassifySimpleAsyncStmt(seq[i]); - if (cls.kind == AsyncStmtKind::kWaitStatic && cls.wait_n == 0) { - flush_segment(/*seg_end=*/i, /*merge_commits=*/true); - out.push_back(seq[i]); - seg_start = i + 1; - continue; - } - if (cls.kind == AsyncStmtKind::kWaitStatic || - cls.kind == AsyncStmtKind::kWaitDynamic) { - // For non-(wait0) waits, we don't attempt to merge commits because it - // can weaken synchronization unless we also adjust wait counts. - flush_segment(/*seg_end=*/i, /*merge_commits=*/false); - out.push_back(seq[i]); - seg_start = i + 1; - continue; - } - } - flush_segment(/*seg_end=*/n, /*merge_commits=*/false); - return out; - } - - Array MaybeSplitEpilogueWait(Array seq) const { - // Schedule cp.async drains in a software-pipeline epilogue more formally. - // - // In TileLang software pipelining, async global->shared copies are - // committed in the steady-state loop and consumed in one or more epilogue - // "consumer phases". A conservative lowering often emits: - // - // for ...: (contains cp.async + commit) - // ptx_wait_group(0) # full drain - // tvm_storage_sync("shared") - // ... consumer phase 0 ... - // tvm_storage_sync("shared") - // ... consumer phase 1 ... - // - // Draining all groups immediately after the loop can destroy overlap - // between the work in phase 0 and the last in-flight committed group(s) - // that are only needed in phase 1. We improve overlap by: - // - relaxing the post-loop wait_group(0) to keep some groups in flight, - // - inserting a final wait_group(0) right before the shared barrier that - // starts the next consumer phase. - // - // Unlike the earlier heuristic that looked for global stores, we identify - // consumer phases by detecting reads from buffers written by ptx_cp_async. - const int n = static_cast(seq.size()); - if (n < 6) { - return seq; - } - - auto is_shared_storage_sync = [&](const Stmt &s) -> bool { - const auto *eval = s.as(); - if (!eval) { - return false; - } - const auto *call = eval->value.as(); - if (!call || !call->op.same_as(builtin::tvm_storage_sync())) { - return false; - } - if (call->args.size() != 1) { - return false; - } - const auto *scope = call->args[0].as(); - return scope && scope->value == "shared"; - }; - - auto make_wait_stmt = [&](int wait_n) -> Stmt { - return Evaluate(Call(DataType::Handle(), builtin::ptx_wait_group(), - {IntImm(DataType::Int(32), wait_n)})); - }; - - auto access_ptr_buffer_var = [&](const PrimExpr &ptr) -> Optional { - // Support both `tl.access_ptr(BufferLoad, extent, rw_mask)` (frontend) - // and `tvm_access_ptr(ptype, data, offset, extent, rw_mask)` (lowered). - const auto *call = ptr.as(); - if (!call) { - return Optional(); - } - if (call->op.same_as(tl::access_ptr())) { - if (call->args.size() != 3) { - return Optional(); - } - const auto *base_load = call->args[0].as(); - if (!base_load) { - return Optional(); - } - return base_load->buffer->data; - } - if (call->op.same_as(builtin::tvm_access_ptr())) { - if (call->args.size() != 5) { - return Optional(); - } - if (call->args[1].as()) { - return Downcast(call->args[1]); - } - } - return Optional(); - }; - - auto collect_cp_async_dst_buffers = [&](const Stmt &s) { - std::unordered_set vars; - PostOrderVisit(s, [&](const ObjectRef &node) { - const auto *call = node.as(); - if (!call || !IsCPAsyncCall(call)) { - return; - } - if (call->args.empty()) { - return; - } - if (Optional buf_var = access_ptr_buffer_var(call->args[0])) { - vars.insert(buf_var.value().get()); - } - }); - return vars; - }; - - auto contains_async_smem_read = - [&](const Stmt &s, - const std::unordered_set &async_smem_vars) - -> bool { - if (async_smem_vars.empty()) { - return false; - } - bool found = false; - PostOrderVisit(s, [&](const ObjectRef &node) { - if (found) { - return; - } - const auto *load = node.as(); - if (!load) { - return; - } - if (async_smem_vars.count(load->buffer->data.get()) == 0) { - return; - } - // Only treat shared memory reads as cp.async consumers. - const String &scope = load->buffer.scope(); - if (scope == "shared" || scope == "shared.dyn") { - found = true; - } - }); - return found; - }; - - for (int i = 1; i + 1 < n; ++i) { - ClassifiedStmt cls = ClassifySimpleAsyncStmt(seq[i]); - if (cls.kind != AsyncStmtKind::kWaitStatic || cls.wait_n != 0) { - continue; - } - - const auto *loop = seq[i - 1].as(); - if (!loop) { - continue; - } - For loop_ref = Downcast(seq[i - 1]); - AsyncIntrinSummary loop_summary = SummarizeAsyncIntrinsics(loop_ref); - if (loop_summary.cp_async <= 0 || loop_summary.commit <= 0) { - continue; - } - - if (!is_shared_storage_sync(seq[i + 1])) { - continue; - } - - int retain = PipelinedRetainGroups(loop_ref); - if (retain <= 0) { - continue; - } - - // Avoid relaxing wait_group(0) into a no-op when we cannot prove there is - // at least (retain + 1) committed groups that can be drained here. - // - // When loop extent is a compile-time constant, we can conservatively - // lower-bound the total number of commit_group calls executed. Otherwise, - // fall back to the per-iteration count (syntactic). - int64_t min_commits = static_cast(loop_summary.commit); - if (const auto *ext = loop_ref->extent.as()) { - min_commits *= static_cast(ext->value); - } - if (min_commits < static_cast(retain + 1)) { - continue; - } - - std::unordered_set async_smem_vars = - collect_cp_async_dst_buffers(loop_ref); - if (async_smem_vars.empty()) { - continue; - } - - // Identify at least two "consumer phases" after the post-loop wait by - // scanning barrier-separated regions for reads from async-written shared - // buffers. - int insert_before_sync = -1; - int segment_start = i + 2; // after the first sync - int prev_sync = i + 1; - int found_phases = 0; - for (int j = segment_start; j <= n; ++j) { - bool end_segment = (j == n) || is_shared_storage_sync(seq[j]); - if (!end_segment) { - continue; - } - bool consumes = false; - for (int k = segment_start; k < j; ++k) { - if (contains_async_smem_read(seq[k], async_smem_vars)) { - consumes = true; - break; - } - } - if (consumes) { - ++found_phases; - if (found_phases == 2) { - // The barrier immediately before this segment starts the next - // consumer phase. - if (prev_sync > i + 1) { - insert_before_sync = prev_sync; - } - break; - } - } - // Start next segment after this sync (if any). - if (j < n) { - prev_sync = j; - segment_start = j + 1; - } - } - if (insert_before_sync == -1) { - continue; - } - - Array out; - out.reserve(n + 1); - for (int j = 0; j < n; ++j) { - if (j == i) { - bool changed = false; - out.push_back( - RewriteWaitStaticInSimpleWrapper(seq[j], retain, &changed)); - // If rewrite failed (non-simple wrapper), keep original. - if (!changed) { - out.Set(out.size() - 1, seq[j]); - } - continue; - } - if (j == insert_before_sync) { - // Drain all groups before the next consumer phase. - bool rewrote_existing = false; - if (!out.empty()) { - bool changed_prev = false; - Stmt prev = - RewriteWaitStaticInSimpleWrapper(out.back(), 0, &changed_prev); - if (changed_prev) { - out.Set(out.size() - 1, prev); - rewrote_existing = true; - } - } - if (!rewrote_existing) { - out.push_back(make_wait_stmt(0)); - } - } - out.push_back(seq[j]); - } - return out; - } - - return seq; - } - - Stmt MakeStaticWaitStmtLike(const Stmt &stmt, int new_wait_n) const { - const auto *eval = stmt.as(); - if (!eval) { - return stmt; - } - const auto *call = eval->value.as(); - if (!call || !IsWaitCall(call)) { - return stmt; - } - - DataType wait_dtype = - call->args.empty() ? DataType::Int(32) : call->args[0].dtype(); - Array args{make_const(wait_dtype, new_wait_n)}; - return Evaluate( - Call(call->dtype, call->op, args, call->annotations, call->span)); - } - - Stmt RewriteWaitStaticInSimpleWrapper(const Stmt &stmt, int new_wait_n, - bool *changed) const { - ClassifiedStmt cls = ClassifySimpleAsyncStmt(stmt); - if (cls.kind != AsyncStmtKind::kWaitStatic) { - return stmt; - } - - if (const auto *eval = stmt.as()) { - const auto *call = eval->value.as(); - if (call && IsWaitCall(call)) { - *changed = true; - return MakeStaticWaitStmtLike(stmt, new_wait_n); - } - } - if (const auto *let = stmt.as()) { - Stmt new_body = - RewriteWaitStaticInSimpleWrapper(let->body, new_wait_n, changed); - if (*changed) { - return LetStmt(let->var, let->value, new_body, let->span); - } - return stmt; - } - if (const auto *attr = stmt.as()) { - Stmt new_body = - RewriteWaitStaticInSimpleWrapper(attr->body, new_wait_n, changed); - if (*changed) { - return AttrStmt(attr->node, attr->attr_key, attr->value, new_body, - attr->span); - } - return stmt; - } - if (const auto *iff = stmt.as()) { - if (!iff->else_case.defined()) { - Stmt then_case = RewriteWaitStaticInSimpleWrapper(iff->then_case, - new_wait_n, changed); - if (*changed) { - return IfThenElse(iff->condition, then_case, Stmt(), iff->span); - } - } - return stmt; - } - if (const auto *seq = stmt.as()) { - if (seq->seq.size() == 1) { - Stmt inner = - RewriteWaitStaticInSimpleWrapper(seq->seq[0], new_wait_n, changed); - if (*changed) { - return SeqStmt({inner}); - } - } - return stmt; - } - if (const auto *block = stmt.as()) { - Stmt inner = - RewriteWaitStaticInSimpleWrapper(block->body, new_wait_n, changed); - if (*changed) { - Block new_block = Downcast(stmt); - BlockNode *n = new_block.CopyOnWrite(); - n->body = inner; - return new_block; - } - return stmt; - } - if (const auto *realize = stmt.as()) { - if (is_one(realize->predicate)) { - Stmt inner = RewriteWaitStaticInSimpleWrapper(realize->block->body, - new_wait_n, changed); - if (*changed) { - Block block = realize->block; - BlockNode *n = block.CopyOnWrite(); - n->body = inner; - return BlockRealize(realize->iter_values, realize->predicate, block, - realize->span); - } - } - return stmt; - } - - return stmt; - } - - void UpdatePendingStateWithTransfer( - PendingAsyncState *pending, - const DeterministicNoWaitTransfer &transfer) const { - auto pending_to_state = [](bool has_pending) { - return has_pending ? PendingAsyncState::kNonZero - : PendingAsyncState::kZero; - }; - - if (*pending == PendingAsyncState::kZero) { - *pending = pending_to_state(transfer.pending_if_start_clear); - return; - } - if (*pending == PendingAsyncState::kNonZero) { - *pending = pending_to_state(transfer.pending_if_start_pending); - return; - } - if (transfer.pending_if_start_clear == transfer.pending_if_start_pending) { - *pending = pending_to_state(transfer.pending_if_start_clear); - } else { - *pending = PendingAsyncState::kUnknown; - } - } - - int GuaranteedNewGroupsBeforeNextWait(const Array &body, - int start_idx) const { - PendingAsyncState pending = PendingAsyncState::kUnknown; - int guaranteed_groups = 0; - - for (int i = start_idx, n = static_cast(body.size()); i < n; ++i) { - AsyncIntrinSummary summary = SummarizeAsyncIntrinsics(body[i]); - if (summary.wait > 0) { - break; - } - if (summary.cp_async == 0 && summary.commit == 0) { - continue; - } - - ClassifiedStmt cls = ClassifySimpleAsyncStmt(body[i]); - if (cls.kind == AsyncStmtKind::kCPAsync) { - pending = PendingAsyncState::kNonZero; - continue; - } - if (cls.kind == AsyncStmtKind::kCommit) { - if (pending == PendingAsyncState::kNonZero) { - guaranteed_groups = AddWithCap(guaranteed_groups, 1); - } - pending = PendingAsyncState::kZero; - continue; - } - if (summary.cp_async > 0 && summary.commit == 0) { - pending = PendingAsyncState::kNonZero; - continue; - } - if (auto transfer = TryGetDeterministicNoWaitTransfer(body[i])) { - int guaranteed_new_groups = std::min(transfer->groups_if_start_clear, - transfer->groups_if_start_pending); - guaranteed_groups = - AddWithCap(guaranteed_groups, guaranteed_new_groups); - UpdatePendingStateWithTransfer(&pending, *transfer); - continue; - } - - // Unknown no-wait async shape: keep already guaranteed groups but drop - // pending precision for subsequent commit accounting. - pending = PendingAsyncState::kUnknown; - } - - return guaranteed_groups; - } - - Stmt MaybeRelaxUnrolledEpilogueLoopWaits(const For &loop, int retain) const { - if (!loop.defined() || loop->kind != ForKind::kUnrolled) { - return loop; - } - if (!loop->annotations.Get("tl_pipelined_num_stages")) { - return loop; - } - const auto *extent_imm = loop->extent.as(); - if (extent_imm == nullptr || extent_imm->value <= 1) { - return loop; - } - - const auto *seq = loop->body.as(); - if (!seq || seq->seq.empty()) { - return loop; - } - - int wait_stmt_idx = -1; - for (int i = 0, n = static_cast(seq->seq.size()); i < n; ++i) { - AsyncIntrinSummary summary = SummarizeAsyncIntrinsics(seq->seq[i]); - if (summary.cp_async > 0 || summary.commit > 0) { - return loop; - } - if (summary.wait == 0) { - continue; - } - ClassifiedStmt cls = ClassifySimpleAsyncStmt(seq->seq[i]); - if (summary.wait != 1 || cls.kind != AsyncStmtKind::kWaitStatic || - cls.wait_n != 0) { - return loop; - } - if (wait_stmt_idx >= 0) { - return loop; - } - wait_stmt_idx = i; - } - if (wait_stmt_idx < 0) { - return loop; - } - - Array relaxed_body = seq->seq; - bool changed = false; - relaxed_body.Set(wait_stmt_idx, - RewriteWaitStaticInSimpleWrapper( - relaxed_body[wait_stmt_idx], retain, &changed)); - if (!changed) { - return loop; - } - - For prefix_loop = loop; - ForNode *prefix = prefix_loop.CopyOnWrite(); - prefix->extent = IntImm(loop->extent.dtype(), - static_cast(extent_imm->value) - 1); - prefix->body = - relaxed_body.size() == 1 ? relaxed_body[0] : SeqStmt(relaxed_body); - - PrimExpr last_iter = - loop->min + IntImm(loop->extent.dtype(), extent_imm->value - 1); - Map vmap; - vmap.Set(loop->loop_var, last_iter); - Stmt tail_body = Substitute(loop->body, vmap); - return SeqStmt({prefix_loop, tail_body}); - } - - Stmt MaybeRelaxLoopWaits(const For &loop, - const std::optional & /*pre_outstanding_exact*/, - int pre_outstanding_lb) const { - if (!loop.defined()) { - return loop; - } - int retain = PipelinedRetainGroups(loop); - if (retain <= 0) { - return loop; - } - if (loop->kind == ForKind::kUnrolled) { - return MaybeRelaxUnrolledEpilogueLoopWaits(loop, retain); - } - if (loop->kind != ForKind::kSerial) { - return loop; - } - - const auto *seq = loop->body.as(); - if (!seq || seq->seq.empty()) { - return loop; - } - - Array body = seq->seq; - bool changed = false; - - PendingAsyncState pending = PendingAsyncState::kUnknown; - int outstanding_lb = std::max(0, pre_outstanding_lb); - int groups_since_wait_lb = 0; - bool seen_wait_boundary = false; - - for (int i = 0, n = static_cast(body.size()); i < n; ++i) { - ClassifiedStmt cls = ClassifySimpleAsyncStmt(body[i]); - if (cls.kind == AsyncStmtKind::kCPAsync) { - pending = PendingAsyncState::kNonZero; - continue; - } - if (cls.kind == AsyncStmtKind::kCommit) { - if (pending == PendingAsyncState::kNonZero) { - outstanding_lb = AddWithCap(outstanding_lb, 1); - groups_since_wait_lb = AddWithCap(groups_since_wait_lb, 1); - } - pending = PendingAsyncState::kZero; - continue; - } - if (cls.kind == AsyncStmtKind::kWaitDynamic) { - seen_wait_boundary = true; - pending = PendingAsyncState::kUnknown; - outstanding_lb = 0; - groups_since_wait_lb = 0; - continue; - } - if (cls.kind == AsyncStmtKind::kWaitStatic) { - int effective_wait_n = cls.wait_n; - if (cls.wait_n == 0) { - int groups_after_wait_lb = - GuaranteedNewGroupsBeforeNextWait(body, i + 1); - - int per_sync_groups = groups_since_wait_lb; - bool uses_head_fallback = - (per_sync_groups == 0 && !seen_wait_boundary); - if (uses_head_fallback) { - // Head wait: even with no in-iteration prefetch before it, keep - // one iteration's worth in flight when there is enough prologue - // outstanding and deterministic producer work after the wait. - per_sync_groups = 1; - } - - int candidate_wait_n = - std::max(0, std::min(retain * per_sync_groups, 7)); - bool enough_pre_outstanding = true; - if (uses_head_fallback) { - // Head wait has no in-iteration prefetch before it. Require - // pre-loop committed groups so wait_group(N) is not a no-op. - enough_pre_outstanding = outstanding_lb >= (candidate_wait_n + 1); - } - if (candidate_wait_n > 0 && enough_pre_outstanding && - (!uses_head_fallback || groups_after_wait_lb > 0)) { - bool changed_wait = false; - body.Set(i, RewriteWaitStaticInSimpleWrapper( - body[i], candidate_wait_n, &changed_wait)); - if (changed_wait) { - changed = true; - effective_wait_n = candidate_wait_n; - } - } - } - - seen_wait_boundary = true; - outstanding_lb = std::min(outstanding_lb, effective_wait_n); - groups_since_wait_lb = 0; - continue; - } - - if (!ContainsAsyncIntrinsics(body[i])) { - continue; - } - - AsyncIntrinSummary summary = SummarizeAsyncIntrinsics(body[i]); - if (summary.cp_async > 0 && summary.commit == 0 && summary.wait == 0) { - pending = PendingAsyncState::kNonZero; - continue; - } - if (summary.wait == 0) { - if (auto transfer = TryGetDeterministicNoWaitTransfer(body[i])) { - int guaranteed_new_groups = - std::min(transfer->groups_if_start_clear, - transfer->groups_if_start_pending); - outstanding_lb = AddWithCap(outstanding_lb, guaranteed_new_groups); - groups_since_wait_lb = - AddWithCap(groups_since_wait_lb, guaranteed_new_groups); - UpdatePendingStateWithTransfer(&pending, *transfer); - continue; - } - } - - if (summary.wait > 0) { - seen_wait_boundary = true; - } - pending = PendingAsyncState::kUnknown; - outstanding_lb = 0; - groups_since_wait_lb = 0; - } - - if (!changed) { - return loop; - } - For new_loop = loop; - ForNode *n = new_loop.CopyOnWrite(); - n->body = body.size() == 1 ? body[0] : SeqStmt(body); - return new_loop; - } - - int PipelinedRetainGroups(const For &loop) const { - // Keep (num_stages - 1) committed groups in flight when possible. - // This metadata is preserved by PipelinePlanning under the dedicated - // annotation key "tl_pipelined_num_stages". - int retain = 1; - if (!loop.defined()) { - return retain; - } - if (auto anno = loop->annotations.Get("tl_pipelined_num_stages")) { - int num_stages = -1; - if (const auto *imm = anno.value().as()) { - num_stages = static_cast(imm->value); - } - if (num_stages >= 1) { - retain = std::max(0, num_stages - 1); - } - } - return retain; - } - - bool IsCPAsyncCall(const CallNode *call) const { - return call && (call->op.same_as(builtin::ptx_cp_async()) || - call->op.same_as(tl::ptx_cp_async())); - } - - bool IsCommitCall(const CallNode *call) const { - return call && call->op.same_as(builtin::ptx_commit_group()); - } - - bool IsWaitCall(const CallNode *call) const { - return call && call->op.same_as(builtin::ptx_wait_group()); - } - - bool ContainsAsyncIntrinsics(const Stmt &stmt) const { - bool found = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (found) { - return; - } - const auto *call = node.as(); - if (!call) { - return; - } - if (IsCPAsyncCall(call) || IsCommitCall(call) || IsWaitCall(call)) { - found = true; - } - }); - return found; - } - - AsyncIntrinSummary SummarizeAsyncIntrinsics(const Stmt &stmt) const { - AsyncIntrinSummary summary; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - const auto *call = node.as(); - if (!call) { - return; - } - if (IsCPAsyncCall(call)) { - ++summary.cp_async; - } else if (IsCommitCall(call)) { - ++summary.commit; - } else if (IsWaitCall(call)) { - ++summary.wait; - } - }); - return summary; - } - - ClassifiedStmt ClassifySimpleAsyncStmt(const Stmt &stmt) const { - if (const auto *let = stmt.as()) { - return ClassifySimpleAsyncStmt(let->body); - } - if (const auto *attr = stmt.as()) { - return ClassifySimpleAsyncStmt(attr->body); - } - // Do not treat IfThenElse as a "simple wrapper": conditional execution can - // invalidate cp.async bookkeeping and make wait relaxation unsafe when the - // prefetch path is skipped at runtime (e.g. blocksparse kernels). - if (const auto *seq = stmt.as()) { - if (seq->seq.size() == 1) { - return ClassifySimpleAsyncStmt(seq->seq[0]); - } - return {}; - } - if (const auto *block = stmt.as()) { - return ClassifySimpleAsyncStmt(block->body); - } - if (const auto *realize = stmt.as()) { - if (is_one(realize->predicate)) { - return ClassifySimpleAsyncStmt(realize->block->body); - } - return {}; - } - - const auto *eval = stmt.as(); - if (!eval) { - return {}; - } - const auto *call = eval->value.as(); - if (!call) { - return {}; - } - if (IsCPAsyncCall(call)) { - return {AsyncStmtKind::kCPAsync, 0}; - } - if (IsCommitCall(call)) { - return {AsyncStmtKind::kCommit, 0}; - } - if (IsWaitCall(call)) { - if (!call->args.empty()) { - if (const auto *imm = call->args[0].as()) { - return {AsyncStmtKind::kWaitStatic, static_cast(imm->value)}; - } - } - return {AsyncStmtKind::kWaitDynamic, 0}; - } - return {}; - } -}; - -tvm::transform::Pass OptimizeCPAsyncSync() { - auto pass_func = [](PrimFunc f, const IRModule &m, - const tvm::transform::PassContext &ctx) { - PrimFuncNode *fptr = f.CopyOnWrite(); - fptr->body = CPAsyncSyncOptimizer()(std::move(fptr->body)); - return f; - }; - return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, - "tl.OptimizeCPAsyncSync", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tl.transform.OptimizeCPAsyncSync", - OptimizeCPAsyncSync); -} - -} // namespace transform -} // namespace tl -} // namespace tvm diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index 55a8b118df..191f3a93ca 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -7,10 +7,15 @@ #include #include "../op/builtin.h" +#include "../op/copy.h" +#include "../op/parallel.h" +#include "../op/region.h" #include "../op/utils.h" +#include "common/pipeline_utils.h" #include #include #include +#include #include #include #include @@ -232,9 +237,10 @@ class AsyncDependencyChainBuilder : public StmtExprVisitor { class BufferRegionCollector : public StmtExprVisitor { public: BufferRegionCollector(Map buffer_data_to_buffer, - const AsyncDependencyChainBuilder &chain_builder) + const AsyncDependencyChainBuilder &chain_builder, + Target target) : buffer_data_to_buffer_(buffer_data_to_buffer), - chain_builder_(chain_builder) {} + chain_builder_(chain_builder), target_(target) {} Array GetReads() const { return reads_; } @@ -244,7 +250,60 @@ class BufferRegionCollector : public StmtExprVisitor { bool GetTmaCopyPattern() const { return is_tma_copy_; } + bool HasNonCopyTileOp() const { return has_non_copy_tile_op_; } + private: + static bool IsGlobalLikeBuffer(const Buffer &buffer) { + return IsGlobalBuffer(buffer) || + (buffer.defined() && buffer.scope().empty()); + } + + void HandleTileOp(const TileOperator &tile_op) { + if (tile_op.as()) { + return; + } + if (const auto *parallel = tile_op.as()) { + BufferRegionCollector nested(buffer_data_to_buffer_, chain_builder_, + target_); + nested(parallel->GetRoot()); + reads_.insert(reads_.end(), nested.GetReads().begin(), + nested.GetReads().end()); + writes_.insert(writes_.end(), nested.GetWrites().begin(), + nested.GetWrites().end()); + is_global_copy_pattern_ = + is_global_copy_pattern_ || nested.GetGlobalCopyPattern(); + is_tma_copy_ = is_tma_copy_ || nested.GetTmaCopyPattern(); + has_non_copy_tile_op_ = + has_non_copy_tile_op_ || nested.HasNonCopyTileOp(); + return; + } + AccessRegions access = tile_op->GetAccessRegions(); + reads_.insert(reads_.end(), access.reads.begin(), access.reads.end()); + writes_.insert(writes_.end(), access.writes.begin(), access.writes.end()); + // Detect explicit TMA-like producer ops for pipeline planning. + // Plain T.copy no longer auto-upgrades to TMA in the generic pipeline + // path; only warp-specialized rewriting may turn it into + // tl.tileop.tma_copy. + if (const auto *copy = tile_op.as()) { + if (IsGlobalLikeBuffer(copy->src) && IsSharedBuffer(copy->dst)) { + is_global_copy_pattern_ = true; + } + } + // Conv2D im2col always uses TMA on Hopper. + if (const auto *im2col = tile_op.as()) { + if (IsGlobalLikeBuffer(im2col->src_) && IsSharedBuffer(im2col->dst_)) { + is_global_copy_pattern_ = true; + if (TargetIsHopper(target_)) { + is_tma_copy_ = true; + } + } + return; + } + if (!tile_op.as()) { + has_non_copy_tile_op_ = true; + } + } + Optional TryGetBufFromAccessPtr(const PrimExpr &expr) const { auto call = expr.as(); if (!call) @@ -271,22 +330,6 @@ class BufferRegionCollector : public StmtExprVisitor { return Optional(); } - void VisitStmt_(const AttrStmtNode *op) final { - if (op->attr_key == "tl.tma_copy_write_buffer") { - // TMA copy lowering annotates the producer with the shared buffer - // it writes to. Use this to detect TMA copy stages and track the - // written buffer for pipeline dependency analysis. - auto var = Downcast(op->node); - auto it = buffer_data_to_buffer_.find(var); - if (it != buffer_data_to_buffer_.end()) { - writes_.push_back(BufferRegion::FullRegion((*it).second)); - is_global_copy_pattern_ = true; - is_tma_copy_ = true; - } - } - StmtExprVisitor::VisitStmt_(op); - } - void VisitStmt_(const BufferStoreNode *op) final { Buffer store_buffer = op->buffer; Array indices = op->indices; @@ -317,7 +360,7 @@ class BufferRegionCollector : public StmtExprVisitor { auto load_region = BufferRegion(load_buffer, region); reads_.push_back(load_region); - if (IsGlobalBuffer(op->buffer) && !within_condition_expr_) { + if (IsGlobalLikeBuffer(op->buffer) && !within_condition_expr_) { // skip condition expr of if_then_else node // shared[i] = T.if_then_else(global[i] < n, register_a[i], register_b[i]) // is not a global read shared[i] = T.if_then_else(global[i] < n, @@ -328,6 +371,12 @@ class BufferRegionCollector : public StmtExprVisitor { void VisitExpr_(const CallNode *op) final { auto args = op->args; + if (auto tile_op = ParseOperator(tvm::ffi::GetRef(op)); + tile_op.defined()) { + HandleTileOp(tile_op); + StmtExprVisitor::VisitExpr_(op); + return; + } if (op->op.same_as(builtin::address_of())) { BufferRegion buffer_region; if (const auto *load = op->args[0].as()) { @@ -367,8 +416,7 @@ class BufferRegionCollector : public StmtExprVisitor { writes_.push_back(BufferRegion::FullRegion(dst_buf.value())); } if (src_buf.defined() && dst_buf.defined() && - (IsGlobalBuffer(src_buf.value()) || - src_buf.value().scope().empty()) && + IsGlobalLikeBuffer(src_buf.value()) && IsSharedBuffer(dst_buf.value())) { is_global_copy_pattern_ = true; } @@ -434,12 +482,14 @@ class BufferRegionCollector : public StmtExprVisitor { private: AsyncDependencyChainBuilder chain_builder_; Map buffer_data_to_buffer_; + Target target_; Array reads_; Array writes_; bool is_global_read_ = false; bool under_buffer_store_ = false; bool is_global_copy_pattern_ = false; bool is_tma_copy_ = false; + bool has_non_copy_tile_op_ = false; bool within_condition_expr_ = false; }; @@ -487,6 +537,7 @@ class PipelinePlanner : public StmtExprMutator { int order = -1, stage = -1; bool copy_stage = false; bool tma_copy = false; // true if this copy stage uses TMA (not cp.async) + bool conditional_execution = false; bool producer_for_copy = false; // Commit statements have no buffer writes, but they must be scheduled as a // part of their cp.async producer group (after the cp.async calls). @@ -541,9 +592,9 @@ class PipelinePlanner : public StmtExprMutator { } else if (call->op.same_as(builtin::ptx_wait_group())) { ++info.cp_async_wait_count; if (!call->args.empty()) { - if (const auto *imm = call->args[0].as()) { + if (const int64_t *imm = as_const_int(call->args[0])) { info.cp_async_wait_min_inflight = std::min( - info.cp_async_wait_min_inflight, static_cast(imm->value)); + info.cp_async_wait_min_inflight, static_cast(*imm)); } else { info.cp_async_wait_has_dynamic = true; } @@ -555,6 +606,346 @@ class PipelinePlanner : public StmtExprMutator { return info; } + bool MayBeConditionallyExecuted(const Stmt &stmt) const { + bool conditional = false; + PostOrderVisit(stmt, [&](const ObjectRef &node) { + if (conditional) { + return; + } + if (const auto *if_then_else = node.as()) { + conditional = true; + return; + } + if (const auto *realize = node.as()) { + if (!is_one(realize->predicate)) { + conditional = true; + } + } + }); + return conditional; + } + + bool IsAsyncProducerCandidate(const PipelineStageInfo &pinfo) const { + if (pinfo.conditional_execution) { + return false; + } + if (pinfo.is_tma_copy()) { + return false; + } + if (pinfo.has_cp_async_wait()) { + return false; + } + if (pinfo.has_cp_async_commit() && !pinfo.has_cp_async_call()) { + return false; + } + return pinfo.is_copy_stage() || pinfo.has_cp_async_call(); + } + + bool IsPureCopyStmt(const Stmt &stmt) const { + auto is_global_like_buffer = [](const Buffer &buffer) { + return IsGlobalBuffer(buffer) || + (buffer.defined() && buffer.scope().empty()); + }; + auto is_pure_raw_copy_value = [&](const PrimExpr &expr, + const auto &self) -> bool { + if (const auto *load = expr.as()) { + return is_global_like_buffer(load->buffer); + } + if (const auto *cast = expr.as()) { + return self(cast->value, self); + } + return false; + }; + + bool saw_copy = false; + bool saw_non_copy_tile_op = false; + bool saw_non_copy_buffer_store = false; + PostOrderVisit(stmt, [&](const ObjectRef &node) { + if (saw_non_copy_tile_op || saw_non_copy_buffer_store) { + return; + } + if (const auto *store = node.as()) { + saw_copy = true; + if ((!IsSharedBuffer(store->buffer) && + !IsLocalBuffer(store->buffer, /*allow_var=*/true)) || + !is_pure_raw_copy_value(store->value, is_pure_raw_copy_value)) { + saw_non_copy_buffer_store = true; + } + return; + } + const auto *call = node.as(); + if (call == nullptr) { + return; + } + auto tile_op = ParseOperator(tvm::ffi::GetRef(call)); + if (!tile_op.defined()) { + return; + } + if (tile_op.as()) { + return; + } + if (const auto *parallel = tile_op.as()) { + if (IsPureCopyStmt(parallel->GetRoot())) { + saw_copy = true; + } else { + saw_non_copy_tile_op = true; + } + return; + } + if (tile_op.as() || tile_op.as()) { + saw_copy = true; + } else { + saw_non_copy_tile_op = true; + } + }); + return saw_copy && !saw_non_copy_tile_op && !saw_non_copy_buffer_store; + } + + Optional GetSinglePureCopyTileOp(const Stmt &stmt) const { + Optional copy_tile_op; + bool saw_non_copy_tile_op = false; + bool saw_multiple_copy_ops = false; + PostOrderVisit(stmt, [&](const ObjectRef &node) { + if (saw_non_copy_tile_op || saw_multiple_copy_ops) { + return; + } + const auto *call = node.as(); + if (call == nullptr) { + return; + } + auto tile_op = ParseOperator(tvm::ffi::GetRef(call)); + if (!tile_op.defined()) { + return; + } + if (tile_op.as()) { + return; + } + if (tile_op.as() || tile_op.as()) { + if (copy_tile_op.defined()) { + saw_multiple_copy_ops = true; + copy_tile_op = Optional(); + } else { + copy_tile_op = tile_op; + } + } else { + saw_non_copy_tile_op = true; + copy_tile_op = Optional(); + } + }); + if (saw_non_copy_tile_op || saw_multiple_copy_ops) { + return Optional(); + } + return copy_tile_op; + } + + static bool IsGlobalLikeBuffer(const Buffer &buffer) { + return IsGlobalBuffer(buffer) || + (buffer.defined() && buffer.scope().empty()); + } + + void ClassifyCopyLikeStage(const Stmt &stmt, PipelineStageInfo *pinfo) const { + ICHECK(pinfo != nullptr); + if (pinfo->conditional_execution) { + return; + } + + // Explicit cp.async producer statements participate in the synthetic + // stage-0 producer schedule just like ordinary global->shared copies. + if (pinfo->has_cp_async_call()) { + pinfo->copy_stage = true; + return; + } + + if (pinfo->copy_stage) { + return; + } + + auto copy_tile_op = GetSinglePureCopyTileOp(stmt); + if (!copy_tile_op.defined()) { + return; + } + + if (const auto *copy = copy_tile_op.value().as()) { + if (!IsGlobalLikeBuffer(copy->src) || !IsSharedBuffer(copy->dst)) { + return; + } + pinfo->copy_stage = true; + return; + } + + if (const auto *im2col = copy_tile_op.value().as()) { + if (!IsGlobalLikeBuffer(im2col->src_) || !IsSharedBuffer(im2col->dst_)) { + return; + } + pinfo->copy_stage = true; + pinfo->tma_copy = TargetIsHopper(target_); + } + } + + void AnalyzeCopyLastUse( + std::vector *pipeline_stage_infos) const { + for (auto &pinfo : *pipeline_stage_infos) { + if (!pinfo.is_first_stage()) { + continue; + } + + for (int i = pinfo.original_stmt_index + 1; + i < static_cast(pipeline_stage_infos->size()); ++i) { + for (const BufferRegion &read : (*pipeline_stage_infos)[i].reads) { + if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), + [&](const BufferRegion &r) { + return r->buffer == read->buffer && + MayConflict(r->region, read->region); + }) != pinfo.writes.end()) { + pinfo.last_use_stmt_index = std::max(pinfo.last_use_stmt_index, i); + } + } + + if (!pinfo.is_copy_stage() || + (pinfo.cp_async_group >= 0 && + pinfo.cp_async_group == + (*pipeline_stage_infos)[i].cp_async_group)) { + continue; + } + + for (const BufferRegion &write : (*pipeline_stage_infos)[i].writes) { + if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), + [&](const BufferRegion &r) { + return r->buffer == write->buffer && + MayConflict(r->region, write->region); + }) != pinfo.writes.end()) { + LOG(FATAL) << "Pipeline planning error: Multiple writes to " + "overlapping buffer regions detected. " + << "Stage " << pinfo.original_stmt_index << " and stage " + << i << " are both writing to buffer '" + << write->buffer->name + << "' with overlapping regions. This is not supported " + "in pipeline planning."; + } + } + } + } + } + + bool EmitImplicitAsyncAnnotations( + const std::vector &pipeline_stage_infos, + Map *annotations) const { + if (!TargetHasAsyncCopy(target_) || !use_async_copy_) { + return false; + } + + std::vector async_group_ids(pipeline_stage_infos.size(), -1); + std::vector stmt_indices_by_order(pipeline_stage_infos.size()); + std::iota(stmt_indices_by_order.begin(), stmt_indices_by_order.end(), 0); + std::stable_sort(stmt_indices_by_order.begin(), stmt_indices_by_order.end(), + [&](int lhs, int rhs) { + if (pipeline_stage_infos[lhs].order != + pipeline_stage_infos[rhs].order) { + return pipeline_stage_infos[lhs].order < + pipeline_stage_infos[rhs].order; + } + return lhs < rhs; + }); + + int next_async_group_id = 0; + std::map, int> implicit_group_ids; + for (int stmt_idx : stmt_indices_by_order) { + const auto &pinfo = pipeline_stage_infos[stmt_idx]; + if (!IsAsyncProducerCandidate(pinfo)) { + continue; + } + auto key = std::make_pair(pinfo.stage, pinfo.last_use_stmt_index); + auto [it, inserted] = + implicit_group_ids.emplace(key, next_async_group_id); + if (inserted) { + ++next_async_group_id; + } + async_group_ids[stmt_idx] = it->second; + } + + if (next_async_group_id == 0) { + return false; + } + + std::vector async_producers; + std::vector async_producer_groups; + async_producers.reserve(pipeline_stage_infos.size()); + async_producer_groups.reserve(pipeline_stage_infos.size()); + std::unordered_set async_stage_ids; + for (size_t i = 0; i < pipeline_stage_infos.size(); ++i) { + bool is_async_producer = async_group_ids[i] != -1; + async_producers.push_back(Integer(is_async_producer ? 1 : 0)); + async_producer_groups.push_back(Integer(async_group_ids[i])); + if (is_async_producer) { + async_stage_ids.insert(pipeline_stage_infos[i].stage); + } + } + + annotations->Set(kPipelineAsyncProducers, Array(async_producers)); + annotations->Set(kPipelineAsyncProducerGroups, + Array(async_producer_groups)); + + std::vector sorted_async_stage_ids(async_stage_ids.begin(), + async_stage_ids.end()); + std::sort(sorted_async_stage_ids.begin(), sorted_async_stage_ids.end()); + std::vector async_stages; + async_stages.reserve(sorted_async_stage_ids.size()); + for (int stage_id : sorted_async_stage_ids) { + async_stages.push_back(Integer(stage_id)); + } + annotations->Set(tir::attr::software_pipeline_async_stages, + Array(async_stages)); + return true; + } + + void MaybeAnnotateLegacyAsyncPipelineLoop(const Stmt &pipeline_body_root, + const Array &pipeline_stmts, + const Array &order_array, + const Array &stage_array, + Map *annotations) { + if (!TargetHasAsyncCopy(target_) || !use_async_copy_) { + return; + } + ICHECK_EQ(pipeline_stmts.size(), order_array.size()); + ICHECK_EQ(pipeline_stmts.size(), stage_array.size()); + + AsyncDependencyChainBuilder chain_builder(buffer_data_to_buffer_); + chain_builder(pipeline_body_root); + + std::vector pipeline_stage_infos; + pipeline_stage_infos.reserve(pipeline_stmts.size()); + for (size_t i = 0; i < pipeline_stmts.size(); ++i) { + auto pinfo = MakePipelineStageInfo(pipeline_stmts[i], i, chain_builder); + ClassifyCopyLikeStage(pipeline_stmts[i], &pinfo); + pinfo.order = static_cast(order_array[i]->value); + pinfo.stage = static_cast(stage_array[i]->value); + if (!pinfo.is_copy_stage() && !pinfo.conditional_execution && + pinfo.stage == 0) { + bool reads_global = false; + bool writes_shared = false; + for (const BufferRegion &read : pinfo.reads) { + if (IsGlobalLikeBuffer(read->buffer)) { + reads_global = true; + break; + } + } + for (const BufferRegion &write : pinfo.writes) { + if (IsSharedBuffer(write->buffer)) { + writes_shared = true; + break; + } + } + if (reads_global && writes_shared) { + pinfo.copy_stage = true; + } + } + pipeline_stage_infos.push_back(std::move(pinfo)); + } + + AnalyzeCopyLastUse(&pipeline_stage_infos); + EmitImplicitAsyncAnnotations(pipeline_stage_infos, annotations); + } + PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx, AsyncDependencyChainBuilder &chain_builder) { @@ -563,20 +954,25 @@ class PipelinePlanner : public StmtExprMutator { Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); auto collector = - BufferRegionCollector(buffer_data_to_buffer_, chain_builder); + BufferRegionCollector(buffer_data_to_buffer_, chain_builder, target_); collector(block); PipelineStageInfo pinfo; pinfo.reads = std::move(collector.GetReads()); pinfo.writes = std::move(collector.GetWrites()); pinfo.original_stmt_index = idx; - pinfo.copy_stage = collector.GetGlobalCopyPattern(); - pinfo.tma_copy = collector.GetTmaCopyPattern(); + pinfo.conditional_execution = MayBeConditionallyExecuted(block->body); + bool pure_copy_stage = + collector.GetGlobalCopyPattern() && IsPureCopyStmt(block->body); + pinfo.copy_stage = pure_copy_stage; + pinfo.tma_copy = pure_copy_stage && !pinfo.conditional_execution && + collector.GetTmaCopyPattern(); auto async_info = AnalyzeAsyncIntrinsics(block->body); pinfo.cp_async_call_count = async_info.cp_async_call_count; pinfo.cp_async_commit_count = async_info.cp_async_commit_count; pinfo.cp_async_wait_count = async_info.cp_async_wait_count; pinfo.cp_async_wait_min_inflight = async_info.cp_async_wait_min_inflight; pinfo.cp_async_wait_has_dynamic = async_info.cp_async_wait_has_dynamic; + ClassifyCopyLikeStage(block->body, &pinfo); return std::move(pinfo); } @@ -623,9 +1019,53 @@ class PipelinePlanner : public StmtExprMutator { } } annotations.Set(tir::attr::software_pipeline_stage, stage_anno.value()); - if (TargetHasAsyncCopy(target_) && use_async_copy_) + if (TargetHasAsyncCopy(target_) && use_async_copy_) { + // Legacy explicit stage/order annotations do not carry per-statement + // async producer metadata yet, so keep the previous stage-level + // behavior as a fallback for these loops. annotations.Set(tir::attr::software_pipeline_async_stages, Array{0}); + } + Stmt pipeline_body_root{nullptr}; + const SeqStmtNode *pipeline_body_seq = nullptr; + if (const auto *realize = loop->body.as()) { + const auto &block = realize->block; + for (const auto &buffer : block->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + pipeline_body_root = block->body; + } else { + pipeline_body_root = loop->body; + } + { + Stmt current = pipeline_body_root; + while (true) { + if (const auto *seq_stmt = current.as()) { + pipeline_body_seq = seq_stmt; + break; + } + if (const auto *if_then_else = current.as()) { + ICHECK(!if_then_else->else_case.defined()) + << "Pipeline_Planning: Can't handle the body of the loop " + "because the IfThenElse node has an else branch"; + current = if_then_else->then_case; + continue; + } + if (const auto *let_stmt = current.as()) { + current = let_stmt->body; + continue; + } + LOG(FATAL) << "Pipeline_Planning: Can't handle the body of the loop " + << "because it is not a SeqStmt, IfThenElse without else, " + << "or LetStmt wrapping them, but got " + << current->GetTypeKey(); + } + } + ICHECK(pipeline_body_seq != nullptr); + MaybeAnnotateLegacyAsyncPipelineLoop(pipeline_body_root, + pipeline_body_seq->seq, order_array, + stage_array, &annotations); auto for_node = tvm::ffi::GetRef(loop); for_node.CopyOnWrite()->annotations = annotations; return for_node; @@ -871,24 +1311,11 @@ class PipelinePlanner : public StmtExprMutator { last_bound_consumer_stmt = consumer_stmt_idx; } - // Prioritize cp.async groups with earlier consumers when enforcing group - // boundary ordering. This helps place prefetches for near-term consumers - // earlier in the stage-0 schedule. std::vector cp_async_group_schedule_order; cp_async_group_schedule_order.reserve(cp_async_groups.size()); for (size_t group_id = 0; group_id < cp_async_groups.size(); ++group_id) { cp_async_group_schedule_order.push_back(static_cast(group_id)); } - std::stable_sort( - cp_async_group_schedule_order.begin(), - cp_async_group_schedule_order.end(), [&](int lhs_group, int rhs_group) { - int lhs_first_consumer = cp_async_group_first_consumer[lhs_group]; - int rhs_first_consumer = cp_async_group_first_consumer[rhs_group]; - if (lhs_first_consumer != rhs_first_consumer) { - return lhs_first_consumer < rhs_first_consumer; - } - return lhs_group < rhs_group; - }); // For every copy stage, mark all its dependency stages as producer_for_copy // Helper struct to manage copy stage dependency reads @@ -990,54 +1417,7 @@ class PipelinePlanner : public StmtExprMutator { // identifies the index of the last statement that consumes data produced by // copy stages, enabling optimal placement of copy operations in the // pipeline schedule. - for (auto &pinfo : pipeline_stage_infos) { - // Only analyze copy stages (memory copy operations) - if (!pinfo.is_first_stage()) - continue; - - // Check all subsequent statements to find the latest consumer - for (int i = pinfo.original_stmt_index + 1; - i < static_cast(flat_stmts.size()); i++) { - - // Check if any read operation in statement 'i' uses data written by - // this copy stage - for (const BufferRegion &read : pipeline_stage_infos[i].reads) { - // Look for overlapping buffer regions between this stage's writes and - // stage 'i's reads - if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), - [&](const BufferRegion &r) { - return r->buffer == read->buffer && - MayConflict(r->region, read->region); - }) != pinfo.writes.end()) { - // Update last_use_stmt_index to the maximum (latest) statement - // index that uses this data This ensures we capture the final - // consumer of the copied data - pinfo.last_use_stmt_index = std::max(pinfo.last_use_stmt_index, i); - } - } - // Check for write-after-write conflicts (multiple stages writing to - // same buffer region) This is important for pipeline correctness and - // affects last_use_stmt_index analysis - if (pinfo.is_copy_stage()) { - for (const BufferRegion &write : pipeline_stage_infos[i].writes) { - if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), - [&](const BufferRegion &r) { - return r->buffer == write->buffer && - MayConflict(r->region, write->region); - }) != pinfo.writes.end()) { - LOG(FATAL) << "Pipeline planning error: Multiple writes to " - "overlapping buffer regions detected. " - << "Stage " << pinfo.original_stmt_index - << " and stage " << i - << " are both writing to buffer '" - << write->buffer->name - << "' with overlapping regions. This is not supported " - "in pipeline planning."; - } - } - } - } - } + AnalyzeCopyLastUse(&pipeline_stage_infos); // Treat each explicit `cp_async* ; commit` producer group as a synthetic // copy stage for scheduling. All statements in the group share the same @@ -1080,6 +1460,27 @@ class PipelinePlanner : public StmtExprMutator { } } + // Order explicit cp.async producer groups by the lifetime of the data they + // introduce. Groups whose data dies earlier should be scheduled earlier in + // the synthetic stage-0 producer schedule, which also matches the desired + // wait rebinding behavior for wait_group(0) consumers. + std::stable_sort( + cp_async_group_schedule_order.begin(), + cp_async_group_schedule_order.end(), [&](int lhs_group, int rhs_group) { + int lhs_last_use = cp_async_groups[lhs_group].last_use_stmt_index; + int rhs_last_use = cp_async_groups[rhs_group].last_use_stmt_index; + if (lhs_last_use != rhs_last_use) { + return lhs_last_use < rhs_last_use; + } + int lhs_first_consumer = cp_async_group_first_consumer[lhs_group]; + int rhs_first_consumer = cp_async_group_first_consumer[rhs_group]; + if (lhs_first_consumer != rhs_first_consumer) { + return lhs_first_consumer < rhs_first_consumer; + } + return cp_async_groups[lhs_group].anchor_cp_async_stmt < + cp_async_groups[rhs_group].anchor_cp_async_stmt; + }); + // Making stages and orders int order_idx = 0; // Stage 1. Create pipeline stages and assign order @@ -1530,8 +1931,8 @@ class PipelinePlanner : public StmtExprMutator { // Preserve the original TileLang pipelining depth for downstream scheduling // (e.g. cp.async wait_group relaxation/splitting). We intentionally do NOT // keep the legacy key "num_stages" here because multiple downstream passes - // (e.g. MultiVersionBuffer/WarpSpecialized) treat it as an active pipeline - // marker and do not support nested pipelines. + // (e.g. internal buffer versioning / warp specialization) treat it as an + // active pipeline marker and do not support nested pipelines. annotations.Set("tl_pipelined_num_stages", Integer(num_stages)); std::vector orders, stages; @@ -1545,18 +1946,92 @@ class PipelinePlanner : public StmtExprMutator { annotations.Set(tir::attr::software_pipeline_stage, Array(stages)); annotations.Set(tir::attr::software_pipeline_order, Array(orders)); - // Only mark stage 0 as async for cp.async copies. TMA copies use - // mbarrier synchronization and don't need async_commit/wait_queue. - bool has_tma_copy = false; - for (const auto &pinfo : pipeline_stage_infos) { - if (pinfo.is_tma_copy()) { - has_tma_copy = true; - break; + // Propagate per-statement TMA eligibility so InjectSoftwarePipeline can + // rewrite TMA copies to use pipeline-level barrier management. + { + std::vector tma_copies; + tma_copies.reserve(pipeline_stage_infos.size()); + for (auto &pinfo : pipeline_stage_infos) { + tma_copies.push_back(Integer(pinfo.is_tma_copy() ? 1 : 0)); + } + annotations.Set(kPipelineTmaCopies, Array(tma_copies)); + } + + if (TargetHasAsyncCopy(target_) && use_async_copy_) { + std::vector async_group_ids(pipeline_stage_infos.size(), -1); + int next_async_group_id = 0; + + for (int scheduled_group_id : cp_async_group_schedule_order) { + const auto &group = cp_async_groups[scheduled_group_id]; + bool emitted_group = false; + for (int stmt_idx : group.cp_async_stmt_indices) { + if (!IsAsyncProducerCandidate(pipeline_stage_infos[stmt_idx])) { + continue; + } + async_group_ids[stmt_idx] = next_async_group_id; + emitted_group = true; + } + if (emitted_group) { + ++next_async_group_id; + } + } + + std::vector stmt_indices_by_order(pipeline_stage_infos.size()); + std::iota(stmt_indices_by_order.begin(), stmt_indices_by_order.end(), 0); + std::stable_sort(stmt_indices_by_order.begin(), + stmt_indices_by_order.end(), [&](int lhs, int rhs) { + if (pipeline_stage_infos[lhs].order != + pipeline_stage_infos[rhs].order) { + return pipeline_stage_infos[lhs].order < + pipeline_stage_infos[rhs].order; + } + return lhs < rhs; + }); + std::map, int> implicit_group_ids; + for (int stmt_idx : stmt_indices_by_order) { + const auto &pinfo = pipeline_stage_infos[stmt_idx]; + if (!IsAsyncProducerCandidate(pinfo) || + async_group_ids[stmt_idx] != -1) { + continue; + } + auto key = std::make_pair(pinfo.stage, pinfo.last_use_stmt_index); + auto [it, inserted] = + implicit_group_ids.emplace(key, next_async_group_id); + if (inserted) { + ++next_async_group_id; + } + async_group_ids[stmt_idx] = it->second; + } + + std::vector async_producers; + std::vector async_producer_groups; + async_producers.reserve(pipeline_stage_infos.size()); + async_producer_groups.reserve(pipeline_stage_infos.size()); + std::unordered_set async_stage_ids; + for (size_t i = 0; i < pipeline_stage_infos.size(); ++i) { + bool is_async_producer = async_group_ids[i] != -1; + async_producers.push_back(Integer(is_async_producer ? 1 : 0)); + async_producer_groups.push_back(Integer(async_group_ids[i])); + if (is_async_producer) { + async_stage_ids.insert(pipeline_stage_infos[i].stage); + } + } + annotations.Set(kPipelineAsyncProducers, Array(async_producers)); + annotations.Set(kPipelineAsyncProducerGroups, + Array(async_producer_groups)); + if (!async_stage_ids.empty()) { + std::vector sorted_async_stage_ids(async_stage_ids.begin(), + async_stage_ids.end()); + std::sort(sorted_async_stage_ids.begin(), sorted_async_stage_ids.end()); + std::vector async_stages; + async_stages.reserve(sorted_async_stage_ids.size()); + for (int stage_id : sorted_async_stage_ids) { + async_stages.push_back(Integer(stage_id)); + } + annotations.Set(tir::attr::software_pipeline_async_stages, + Array(async_stages)); } } - if (TargetHasAsyncCopy(target_) && use_async_copy_ && !has_tma_copy) - annotations.Set(tir::attr::software_pipeline_async_stages, - Array{0}); // Reconstruct the loop body with the flattened SeqStmt so that // InjectSoftwarePipeline sees the correct number of pipeline stages. diff --git a/src/transform/producer_consumer_ws.cc b/src/transform/producer_consumer_ws.cc index 3c550b29f3..e9080f0b0a 100644 --- a/src/transform/producer_consumer_ws.cc +++ b/src/transform/producer_consumer_ws.cc @@ -1,121 +1,72 @@ /*! * \file producer_consumer_ws.cc - * \brief Producer-consumer warp specialization for sm90+ async-copy pipelines. + * \brief Warp-specialized producer/consumer rewriting at the tile-op level. * - * Works on the inline barrier IR emitted by lowering passes such as - * LowerBulkCopy / LowerPTXAsyncCopy: - * SeqStmt({ - * AttrStmt("tl.tma_copy_write_buffer", buf, 1, - * IfThenElse(threadIdx.x == 0, - * SeqStmt({arrive_expect_tx(mbar, bytes), tma_load(...)}))), - * mbarrier_wait_parity(mbar, parity) - * }) + * This pass runs **before** LayoutInference and LowerTileOp, operating on + * high-level tile ops (`tl.tileop.copy`, `tl.tileop.gemm`, etc.). + * It recognizes pipelined producer/consumer structure directly from tile-op + * semantics and splits eligible loops into warp-specialized branches with + * explicit barrier synchronization. * - * The pass splits the pipelined loop into: - * producer: issues TMA / cp.async - * consumer: waits, computes, and releases buffers + * The output IR is equivalent to a hand-written warp-specialized kernel: + * - TMA-annotated copies become `tl.tileop.tma_copy` with barrier refs + * - Barriers (`mbarrier_wait_parity`, `ptx_arrive_barrier`) are inserted + * - The loop body is wrapped in `if (threadIdx.x >= consumer_extent)` * - * For pure-TMA loops we rewrite the forward-barrier protocol so the producer - * releases the barrier after issuing the TMA copy: - * expect_transaction -> tma_load -> arrive + * Limitations (v1): + * - Pure TMA pipelines only (no mixed TMA + cp.async) + * - No conditionally guarded loop bodies (phase counters) + * - Single pipelined loop per block + * - No pre-loop TMA prefetch / prologue optimizations */ +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "../op/copy.h" +#include "../op/fill.h" +#include "../op/gemm.h" +#include "../op/gemm_py.h" +#include "../op/operator.h" +#include "../op/region.h" #include "../op/utils.h" +#include "../target/utils.h" #include "common/mbarrier.h" -#include "common/tma_copy_utils.h" +#include "multi_version_buffer_rewriter.h" #include "warp_specialized_rewriter.h" -#include -#include -#include -#include -#include - namespace tvm { namespace tl { using namespace tir; -using namespace runtime; + +namespace { // --------------------------------------------------------------------------- -// Data structures +// Utility: flatten SeqStmt recursively // --------------------------------------------------------------------------- - -enum class AsyncProducerKind : uint8_t { kTma, kCpAsync }; - -struct AsyncCopyBlockInfo { - AsyncProducerKind kind; - Stmt producer_stmt; // TMA issue or cp.async enqueue+commit - Optional wait_stmt; // Existing forward wait for TMA blocks - Optional write_buffer_data; // shared buffer written by producer -}; - -using BufferDataToBufferMap = - std::unordered_map; -using BufferSet = std::unordered_set; -using VarSet = std::unordered_set; -using VarBindingMap = - std::unordered_map; - -struct LocalAccessSummary { - BufferSet all_read_buffers; - BufferSet all_write_buffers; - BufferSet branch_private_read_buffers; - BufferSet branch_private_write_buffers; - VarSet read_vars; - VarSet def_vars; - - bool HasTrackedDefs() const { - return !branch_private_write_buffers.empty() || !def_vars.empty(); - } -}; - -struct LocalLiveSet { - BufferSet buffers; - VarSet vars; - - bool NeedsAnyDef(const LocalAccessSummary &summary) const { - for (const auto &buf : summary.branch_private_write_buffers) { - if (buffers.count(buf)) { - return true; - } - } - for (const auto &var : summary.def_vars) { - if (vars.count(var)) { - return true; - } - } - return false; - } - - void KillDefs(const LocalAccessSummary &summary) { - for (const auto &buf : summary.branch_private_write_buffers) { - buffers.erase(buf); - } - for (const auto &var : summary.def_vars) { - vars.erase(var); +void FlattenSeqStmt(const Stmt &s, Array *out) { + if (auto *seq = s.as()) { + for (const auto &sub : seq->seq) { + FlattenSeqStmt(sub, out); } + } else { + out->push_back(s); } +} - void AddUses(const LocalAccessSummary &summary) { - buffers.insert(summary.branch_private_read_buffers.begin(), - summary.branch_private_read_buffers.end()); - vars.insert(summary.read_vars.begin(), summary.read_vars.end()); - } -}; +/// Annotation key marking that this function was transformed by the tiled WS +/// pass, so downstream passes can skip redundant transformations. +static constexpr const char *kTiledWSApplied = "tl_tiled_ws_applied"; // --------------------------------------------------------------------------- -// PhaseCounter: mutable int32 counter for guarded-loop phase tracking +// PhaseCounter: local counter for correct barrier parity in guarded loops // --------------------------------------------------------------------------- - -/*! - * \brief When a pipeline loop body is conditionally guarded (e.g. - * `if block_mask[k]: ...`), the loop-variable-based parity - * `(k / num_stages) % 2` can desynchronise because skipped iterations - * don't touch barriers. A PhaseCounter is a local int32[1] buffer - * that tracks the *actual* number of guarded-body entries so that - * parity/stage are always correct. - */ struct PhaseCounter { Buffer buf; @@ -137,7 +88,6 @@ struct PhaseCounter { return BufferStore(buf, Load() + 1, {IntImm(DataType::Int(32), 0)}); } - /*! Wrap a For-loop with Allocate + DeclBuffer + Init(0). */ Stmt WrapLoopWithAlloc(Stmt loop) const { Stmt body = SeqStmt({Init(), std::move(loop)}); body = DeclBuffer(buf, body); @@ -157,17 +107,9 @@ struct PhaseCounter { } }; -/*! - * \brief Replace the loop-variable-based stage expression with a - * phase-counter-based one inside producer / consumer statements. - * - * When `needs_phase_counter` is true, the barrier IDs already use - * `phase_counter->StageExpr(N)` but the shared-memory buffer offsets - * still embed `FloorMod(loop_var - loop_min, N)`. This mutator - * rewrites every matching FloorMod to the replacement expression so - * that stage indexing stays in sync with barrier indexing when loop - * iterations are conditionally skipped. - */ +// --------------------------------------------------------------------------- +// StageExprReplacer: rewrite loop-var-based stage indexing to counter-based +// --------------------------------------------------------------------------- class StageExprReplacer : public StmtExprMutator { public: static Stmt Replace(const Stmt &stmt, Var loop_var, PrimExpr loop_min, @@ -190,7 +132,6 @@ class StageExprReplacer : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } - /*! Match `loop_var`, `loop_var - loop_min`, or `loop_var - 0`. */ bool MatchLinearIdx(const PrimExpr &expr) const { if (expr.same_as(loop_var_)) return true; @@ -211,6 +152,202 @@ class StageExprReplacer : public StmtExprMutator { PrimExpr replacement_; }; +// --------------------------------------------------------------------------- +// Statement classification +// --------------------------------------------------------------------------- + +using BufferDataToBufferMap = + std::unordered_map; +using BufferSet = std::unordered_set; +using VarSet = std::unordered_set; +using BufferNodeMap = std::unordered_map; +using VarExprMap = std::unordered_map; + +struct LocalAccessSummary { + BufferSet read_buffers; + BufferSet write_buffers; + VarSet read_vars; + VarSet def_vars; + + bool HasTrackedDefs() const { + return !write_buffers.empty() || !def_vars.empty(); + } +}; + +struct LocalLiveSet { + BufferSet buffers; + VarSet vars; + + bool NeedsAnyDef(const LocalAccessSummary &summary) const { + for (const auto &buf : summary.write_buffers) { + if (buffers.count(buf)) { + return true; + } + } + for (const auto &var : summary.def_vars) { + if (vars.count(var)) { + return true; + } + } + return false; + } + + void AddUses(const LocalAccessSummary &summary) { + buffers.insert(summary.read_buffers.begin(), summary.read_buffers.end()); + vars.insert(summary.read_vars.begin(), summary.read_vars.end()); + } +}; + +static void MergeLocalAccessSummary(LocalAccessSummary *dst, + const LocalAccessSummary &src) { + dst->read_buffers.insert(src.read_buffers.begin(), src.read_buffers.end()); + dst->write_buffers.insert(src.write_buffers.begin(), src.write_buffers.end()); + dst->read_vars.insert(src.read_vars.begin(), src.read_vars.end()); + dst->def_vars.insert(src.def_vars.begin(), src.def_vars.end()); +} + +static Buffer CloneBranchPrivateBuffer(const Buffer &buffer, + const std::string &suffix) { + Type new_type = buffer->data->type_annotation; + if (IsFragmentBuffer(buffer)) { + const auto *ptr_type = buffer->data->type_annotation.as(); + ICHECK(ptr_type); + new_type = PointerType(ptr_type->element_type, "local"); + } + Var new_var(buffer->data->name_hint + suffix, new_type); + return Buffer(new_var, buffer->dtype, buffer->shape, buffer->strides, + buffer->elem_offset, buffer->name + suffix, + buffer->data_alignment, buffer->offset_factor, + buffer->buffer_type); +} + +class BufferRemapper : public StmtExprMutator { +public: + static Stmt Rewrite(const Stmt &stmt, const BufferNodeMap &buffer_remap) { + if (buffer_remap.empty()) { + return stmt; + } + BufferRemapper remapper(buffer_remap); + return remapper.VisitStmt(stmt); + } + +private: + explicit BufferRemapper(const BufferNodeMap &buffer_remap) + : buffer_remap_(buffer_remap) { + for (const auto &[old_buf, new_buf] : buffer_remap_) { + var_remap_.emplace(old_buf->data.get(), new_buf->data); + } + } + + Buffer RemapBuffer(const Buffer &buffer) const { + auto it = buffer_remap_.find(buffer.get()); + if (it != buffer_remap_.end()) { + return it->second; + } + return buffer; + } + + PrimExpr VisitExpr_(const VarNode *op) final { + auto it = var_remap_.find(op); + if (it != var_remap_.end()) { + return it->second; + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + Buffer new_buffer = RemapBuffer(load->buffer); + if (!new_buffer.same_as(load->buffer)) { + return BufferLoad(new_buffer, load->indices, load->predicate, load->span); + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + Buffer new_buffer = RemapBuffer(store->buffer); + if (!new_buffer.same_as(store->buffer)) { + return BufferStore(new_buffer, store->value, store->indices, + store->predicate, store->span); + } + return store; + } + + const BufferNodeMap &buffer_remap_; + VarExprMap var_remap_; +}; + +enum class TileStmtKind { + kTmaProducer, // TMA load producer (global->shared) + kCpAsyncProducer, // Explicit cp.async / commit / wait_group producer stmt + kSimtProducer, // Non-tile-op SIMT copy: For loop writing shared from global + kConsumer, // Compute (gemm, reduce, element-wise, etc.) + kOther // Unclassified +}; + +/// Detect if a statement is a SIMT global-to-shared memory copy. +/// Matches any statement that writes to shared memory and reads from global +/// memory, without reading shared or local buffers (which would indicate +/// consumer-side compute). This is intentionally broader than "pure direct +/// copy" so that T.Parallel with complex indexing / if_then_else (later +/// lowered to cp.async) is also captured. +class SimtProducerDetector : public StmtExprVisitor { +public: + static bool Detect(const Stmt &stmt) { + SimtProducerDetector d; + d(stmt); + return d.writes_shared_ && d.reads_global_ && !d.reads_shared_local_; + } + +private: + void VisitStmt_(const BufferStoreNode *op) final { + if (IsSharedBuffer(op->buffer)) { + writes_shared_ = true; + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const BufferLoadNode *op) final { + if (IsGlobalBuffer(op->buffer)) { + reads_global_ = true; + } + if (IsSharedBuffer(op->buffer) || IsLocalBuffer(op->buffer, true)) { + reads_shared_local_ = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + bool writes_shared_{false}; + bool reads_global_{false}; + bool reads_shared_local_{false}; +}; + +static const CallNode *GetEvaluateCallInSimpleWrapper(const Stmt &stmt) { + if (const auto *eval = stmt.as()) { + return eval->value.as(); + } + if (const auto *if_stmt = stmt.as()) { + if (!if_stmt->else_case.defined()) { + return GetEvaluateCallInSimpleWrapper(if_stmt->then_case); + } + return nullptr; + } + if (const auto *attr = stmt.as()) { + return GetEvaluateCallInSimpleWrapper(attr->body); + } + if (const auto *let = stmt.as()) { + return GetEvaluateCallInSimpleWrapper(let->body); + } + if (const auto *block = stmt.as()) { + return GetEvaluateCallInSimpleWrapper(block->body); + } + if (const auto *realize = stmt.as()) { + return GetEvaluateCallInSimpleWrapper(realize->block->body); + } + return nullptr; +} + class BufferDataToBufferCollector : public StmtExprVisitor { public: static BufferDataToBufferMap Collect(const Stmt &stmt) { @@ -220,13 +357,6 @@ class BufferDataToBufferCollector : public StmtExprVisitor { } private: - void VisitStmt_(const DeclBufferNode *op) final { - if (op->buffer.defined()) { - result_.emplace(op->buffer->data, op->buffer); - } - StmtExprVisitor::VisitStmt_(op); - } - void VisitStmt_(const BlockRealizeNode *op) final { CollectBuffers(op->block); StmtExprVisitor::VisitStmt_(op); @@ -255,52 +385,14 @@ class LocalAccessCollector : public StmtExprVisitor { return std::move(collector.summary_); } - static LocalAccessSummary - CollectExpr(const PrimExpr &expr, const BufferDataToBufferMap &buffer_map) { - LocalAccessCollector collector(buffer_map); - collector.VisitExpr(expr); - return std::move(collector.summary_); - } - private: - static bool IsTrackedBranchPrivateBuffer(const Buffer &buffer) { - return IsLocalBuffer(buffer, /*allow_var=*/true) || - IsFragmentBuffer(buffer); - } - - void MarkBufferAccess(const Buffer &buffer, int rw_mask) { - if (!buffer.defined()) { - return; - } - if (rw_mask & 1) { - summary_.all_read_buffers.insert(buffer); - if (IsTrackedBranchPrivateBuffer(buffer)) { - summary_.branch_private_read_buffers.insert(buffer); - } - } - if (rw_mask & 2) { - summary_.all_write_buffers.insert(buffer); - if (IsTrackedBranchPrivateBuffer(buffer)) { - summary_.branch_private_write_buffers.insert(buffer); - } - } - } - - void MarkRawBufferVarArg(const PrimExpr &expr, int rw_mask) { - const auto *var = expr.as(); - if (!var) { - return; - } - auto it = buffer_data_to_buffer_.find(GetRef(var)); - if (it == buffer_data_to_buffer_.end()) { - return; - } - MarkBufferAccess(it->second, rw_mask); - } - explicit LocalAccessCollector(const BufferDataToBufferMap &buffer_map) : buffer_data_to_buffer_(buffer_map) {} + static bool IsBranchPrivateBuffer(const Buffer &buffer) { + return IsFragmentBuffer(buffer) || IsLocalBuffer(buffer, true); + } + void VisitStmt_(const LetStmtNode *op) final { VisitExpr(op->value); summary_.def_vars.insert(op->var); @@ -318,12 +410,16 @@ class LocalAccessCollector : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode *op) final { - MarkBufferAccess(op->buffer, /*rw_mask=*/1); + if (IsBranchPrivateBuffer(op->buffer)) { + summary_.read_buffers.insert(op->buffer); + } StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode *op) final { - MarkBufferAccess(op->buffer, /*rw_mask=*/2); + if (IsBranchPrivateBuffer(op->buffer)) { + summary_.write_buffers.insert(op->buffer); + } StmtExprVisitor::VisitStmt_(op); } @@ -336,12 +432,51 @@ class LocalAccessCollector : public StmtExprVisitor { } void VisitExpr_(const CallNode *op) final { + if (auto tile_op = ParseOperator(ffi::GetRef(op)); + tile_op.defined()) { + if (const auto *copy = tile_op.as()) { + if (IsBranchPrivateBuffer(copy->src)) { + summary_.read_buffers.insert(copy->src); + } + if (IsBranchPrivateBuffer(copy->dst)) { + summary_.write_buffers.insert(copy->dst); + } + for (const auto &range : copy->src_range) { + VisitExpr(range->min); + VisitExpr(range->extent); + } + for (const auto &range : copy->dst_range) { + VisitExpr(range->min); + VisitExpr(range->extent); + } + return; + } + if (const auto *fill = tile_op.as()) { + if (IsBranchPrivateBuffer(fill->dst)) { + summary_.write_buffers.insert(fill->dst); + } + VisitExpr(fill->value); + for (const auto &range : fill->region) { + VisitExpr(range->min); + VisitExpr(range->extent); + } + return; + } + } + if (op->op.same_as(tl::access_ptr())) { ICHECK_EQ(op->args.size(), 3); const auto *base_load = op->args[0].as(); ICHECK(base_load); - int rw_mask = GetConstAccessMask(op->args[2]); - MarkBufferAccess(base_load->buffer, rw_mask); + if (IsBranchPrivateBuffer(base_load->buffer)) { + int rw_mask = GetConstAccessMask(op->args[2]); + if (rw_mask & 1) { + summary_.read_buffers.insert(base_load->buffer); + } + if (rw_mask & 2) { + summary_.write_buffers.insert(base_load->buffer); + } + } for (const auto &index : base_load->indices) { VisitExpr(index); } @@ -354,34 +489,27 @@ class LocalAccessCollector : public StmtExprVisitor { const auto *var = op->args[1].as(); ICHECK(var); auto it = buffer_data_to_buffer_.find(GetRef(var)); - if (it != buffer_data_to_buffer_.end()) { + if (it != buffer_data_to_buffer_.end() && + IsBranchPrivateBuffer(it->second)) { int rw_mask = GetConstAccessMask(op->args[4]); - MarkBufferAccess(it->second, rw_mask); + if (rw_mask & 1) { + summary_.read_buffers.insert(it->second); + } + if (rw_mask & 2) { + summary_.write_buffers.insert(it->second); + } } VisitExpr(op->args[2]); VisitExpr(op->args[3]); return; } - if (op->op.same_as(tl::warpgroup_fence_operand())) { - ICHECK_EQ(op->args.size(), 4); - MarkRawBufferVarArg(op->args[1], /*rw_mask=*/1); - } else if (op->op.same_as(tl::ptx_wgmma_ss())) { - ICHECK_EQ(op->args.size(), 15); - // WGMMA accumulates into C registers in place. - MarkRawBufferVarArg(op->args[10], /*rw_mask=*/3); - } else if (op->op.same_as(tl::ptx_wgmma_rs())) { - ICHECK_EQ(op->args.size(), 14); - MarkRawBufferVarArg(op->args[5], /*rw_mask=*/1); - MarkRawBufferVarArg(op->args[9], /*rw_mask=*/3); - } - StmtExprVisitor::VisitExpr_(op); } int GetConstAccessMask(const PrimExpr &expr) const { - if (const auto *imm = expr.as()) { - return static_cast(imm->value); + if (const int64_t *imm = as_const_int(expr)) { + return static_cast(*imm); } return 3; } @@ -391,962 +519,865 @@ class LocalAccessCollector : public StmtExprVisitor { VarSet bound_vars_; }; -class ProducerSimtCopyDetector : public StmtExprVisitor { -public: - static bool HasSimtCopy(const Stmt &stmt, - const BufferDataToBufferMap &buffer_map) { - ProducerSimtCopyDetector detector(buffer_map); - detector.VisitStmt(stmt); - return detector.has_global_read_ && detector.has_shared_write_; - } - -private: - explicit ProducerSimtCopyDetector(const BufferDataToBufferMap &buffer_map) - : buffer_data_to_buffer_(buffer_map) {} +enum class PreludeStmtPlacement : uint8_t { + kKeepSharedPrelude, + kProducerOnly, + kConsumerOnly, + kDuplicateToBoth, +}; - void VisitStmt_(const IfThenElseNode *op) final { - bool old_in_if_cond = in_if_cond_; - in_if_cond_ = true; - VisitExpr(op->condition); - in_if_cond_ = old_in_if_cond; - VisitStmt(op->then_case); - if (op->else_case.defined()) { - VisitStmt(op->else_case.value()); - } +static PreludeStmtPlacement +ClassifyPreludeStmt(const Stmt &stmt, const BufferDataToBufferMap &buffer_map, + const LocalLiveSet &producer_live_seed, + const LocalLiveSet &consumer_live_seed) { + LocalAccessSummary summary = LocalAccessCollector::Collect(stmt, buffer_map); + if (!summary.HasTrackedDefs()) { + return PreludeStmtPlacement::kKeepSharedPrelude; } - void VisitExpr_(const BufferLoadNode *op) final { - if (!in_if_cond_ && !in_async_copy_ && IsGlobalBuffer(op->buffer)) { - has_global_read_ = true; - } - StmtExprVisitor::VisitExpr_(op); + bool producer_needs = producer_live_seed.NeedsAnyDef(summary); + bool consumer_needs = consumer_live_seed.NeedsAnyDef(summary); + if (producer_needs && consumer_needs) { + return PreludeStmtPlacement::kDuplicateToBoth; } - - void VisitStmt_(const BufferStoreNode *op) final { - if (!in_if_cond_ && !in_async_copy_ && IsSharedBuffer(op->buffer)) { - has_shared_write_ = true; - } - StmtExprVisitor::VisitStmt_(op); + if (producer_needs) { + return PreludeStmtPlacement::kProducerOnly; } - - void VisitExpr_(const CallNode *op) final { - bool old_in_async_copy = in_async_copy_; - if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || - op->op.same_as(tma_store()) || op->op.same_as(tma_store_arrive()) || - op->op.same_as(tma_store_wait()) || - op->op.same_as(tl::ptx_cp_async()) || - op->op.same_as(builtin::ptx_cp_async())) { - in_async_copy_ = true; - } - - if (op->op.same_as(tl::access_ptr())) { - ICHECK_EQ(op->args.size(), 3); - const auto *base_load = op->args[0].as(); - ICHECK(base_load); - MarkAccess(base_load->buffer, GetConstAccessMask(op->args[2])); - for (const auto &index : base_load->indices) { - VisitExpr(index); - } - VisitExpr(op->args[1]); - in_async_copy_ = old_in_async_copy; - return; - } - - if (op->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(op->args.size(), 5); - const auto *var = op->args[1].as(); - ICHECK(var); - auto it = buffer_data_to_buffer_.find(GetRef(var)); - if (it != buffer_data_to_buffer_.end()) { - MarkAccess(it->second, GetConstAccessMask(op->args[4])); - } - VisitExpr(op->args[2]); - VisitExpr(op->args[3]); - in_async_copy_ = old_in_async_copy; - return; - } - - StmtExprVisitor::VisitExpr_(op); - in_async_copy_ = old_in_async_copy; + if (consumer_needs) { + return PreludeStmtPlacement::kConsumerOnly; } + return PreludeStmtPlacement::kKeepSharedPrelude; +} - void MarkAccess(const Buffer &buffer, int rw_mask) { - if (in_if_cond_ || in_async_copy_ || !buffer.defined()) { +static bool ContainsPtxCpAsync(const Stmt &stmt) { + bool found = false; + PostOrderVisit(stmt, [&](const ObjectRef &node) { + if (found) { return; } - if ((rw_mask & 1) && IsGlobalBuffer(buffer)) { - has_global_read_ = true; - } - if ((rw_mask & 2) && IsSharedBuffer(buffer)) { - has_shared_write_ = true; - } - } - - int GetConstAccessMask(const PrimExpr &expr) const { - if (const auto *imm = expr.as()) { - return static_cast(imm->value); + if (const auto *call = node.as()) { + if (call->op.same_as(builtin::ptx_cp_async()) || + call->op.same_as(tl::ptx_cp_async())) { + found = true; + } } - return 3; - } - - const BufferDataToBufferMap &buffer_data_to_buffer_; - bool has_global_read_{false}; - bool has_shared_write_{false}; - bool in_if_cond_{false}; - bool in_async_copy_{false}; -}; - -// --------------------------------------------------------------------------- -// Helpers (reused from warp_specialized_rewriter.cc patterns) -// --------------------------------------------------------------------------- + }); + return found; +} -static PrimExpr makeGetBarrier(const Buffer &barrier_buf, PrimExpr barrier_id) { - return MakeBarrierRef(barrier_buf, std::move(barrier_id)); +static bool IsPtxCommitGroup(const Stmt &stmt) { + const auto *call = GetEvaluateCallInSimpleWrapper(stmt); + return call && call->op.same_as(builtin::ptx_commit_group()); } -static Stmt makeArriveBarrier(const Buffer &barrier_buf, PrimExpr barrier_id) { - Array args = {makeGetBarrier(barrier_buf, std::move(barrier_id))}; - return Evaluate( - Call(DataType::Handle(), builtin::ptx_arrive_barrier(), args)); +static bool IsPtxWaitGroup(const Stmt &stmt) { + const auto *call = GetEvaluateCallInSimpleWrapper(stmt); + return call && call->op.same_as(builtin::ptx_wait_group()); } -static Stmt makeCpAsyncBarrierNoInc(const Buffer &barrier_buf, - PrimExpr barrier_id) { - auto call = Call(DataType::Handle(), tl::ptx_cp_async_barrier_noinc(), - {makeGetBarrier(barrier_buf, std::move(barrier_id))}); - return Evaluate(call); +static bool IsBarrierOrTmaControlCall(const CallNode *call) { + return call->op.same_as(mbarrier_wait_parity()) || + call->op.same_as(mbarrier_expect_tx()) || + call->op.same_as(builtin::ptx_arrive_barrier()) || + call->op.same_as(tl::ptx_arrive_cluster_barrier()) || + call->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || + call->op.same_as(builtin::ptx_cp_async_barrier()) || + call->op.same_as(tl::ptx_cp_async_barrier_noinc()) || + call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_store()) || + call->op.same_as(tma_store_arrive()) || + call->op.same_as(tma_store_wait()) || + call->op.same_as(builtin::tvm_storage_sync()); } -static Stmt makeParityWait(const Buffer &barrier_buf, PrimExpr barrier_id, - PrimExpr parity) { - auto call = Call( - DataType::Handle(), mbarrier_wait_parity(), - {makeGetBarrier(barrier_buf, std::move(barrier_id)), std::move(parity)}); - return Evaluate(call); -} - -static bool IsTrivialNoOpStmt(const Stmt &stmt) { - if (const auto *eval = stmt.as()) { - if (const auto *imm = eval->value.as()) { - return imm->value == 0; - } +static bool IsSyncGlobalToSharedCopyLikeStmt(const Stmt &stmt, Target target) { + const auto *call = GetEvaluateCallInSimpleWrapper(stmt); + if (!call) { + return false; } - if (const auto *seq = stmt.as()) { - for (const auto &s : seq->seq) { - if (!IsTrivialNoOpStmt(s)) { - return false; - } - } - return true; + auto tile_op = ParseOperator(ffi::GetRef(call)); + if (!tile_op.defined()) { + return false; + } + const auto *copy = tile_op.as(); + if (copy == nullptr || !copy->CheckCPAsyncCopyPreconditions() || + copy->GetIsTmaCopy() || copy->GetIsAsyncCopy()) { + return false; } - return false; -} -// --------------------------------------------------------------------------- -// AsyncCopyBlockExtractor -// --------------------------------------------------------------------------- + arith::Analyzer analyzer; + return !copy->CheckBulkLoad(target, &analyzer, /*check_last_dim=*/true); +} -/*! - * \brief Extract async producer blocks from a flattened loop body. - * - * Recognized patterns: - * - * Pattern 1: AttrStmt("tl.tma_copy_write_buffer", ...) + mbarrier_wait_parity - * Pattern 2: IfThenElse containing tma_load + mbarrier_wait_parity - * Pattern 3: one or more cp_async-only stmts + commit_group + wait_group(0) - * - * Everything else is classified as a compute statement. - */ -class AsyncCopyBlockExtractor { -public: - std::vector blocks; - std::vector compute_stmts; - - void Extract(const Array &flat_stmts) { - size_t i = 0; - while (i < flat_stmts.size()) { - if (i + 1 < flat_stmts.size() && - IsMbarrierWaitParity(flat_stmts[i + 1])) { - Optional write_buffer_data = - ExtractTmaCopyWriteBufferData(flat_stmts[i]); - // Check Pattern 1/2: TMA producer + wait pair, optionally wrapped in a - // simple guard/Block/Let/Attr shell. Recover the written shared buffer - // when the tl.tma_copy_write_buffer annotation survives under wrappers. - if (write_buffer_data.defined() || ContainsTmaLoad(flat_stmts[i])) { - blocks.push_back({AsyncProducerKind::kTma, - StripTmaCopyWriteBufferAttr(flat_stmts[i]), - Optional(flat_stmts[i + 1]), - write_buffer_data}); - i += 2; - continue; - } - } - if (ContainsPtxCpAsync(flat_stmts[i])) { - size_t cp_async_end = i; - while (cp_async_end + 1 < flat_stmts.size() && - ContainsPtxCpAsync(flat_stmts[cp_async_end + 1])) { - ++cp_async_end; - } - if (cp_async_end + 2 < flat_stmts.size() && - IsPtxCommitGroup(flat_stmts[cp_async_end + 1]) && - IsPtxWaitGroupZero(flat_stmts[cp_async_end + 2])) { - Array producer_seq; - producer_seq.reserve(cp_async_end - i + 2); - for (size_t j = i; j <= cp_async_end; ++j) { - producer_seq.push_back(flat_stmts[j]); - } - producer_seq.push_back(flat_stmts[cp_async_end + 1]); - Stmt producer_stmt = producer_seq.size() == 1 ? producer_seq[0] - : SeqStmt(producer_seq); - blocks.push_back({AsyncProducerKind::kCpAsync, producer_stmt, - Optional(), - GetCpAsyncDstBufferData(producer_stmt)}); - i = cp_async_end + 3; - continue; - } - } - compute_stmts.push_back(flat_stmts[i]); - i++; - } +static bool IsProducerMovableLoopPrefixStmt(const Stmt &stmt, Target target) { + if (IsSyncGlobalToSharedCopyLikeStmt(stmt, target)) { + return true; } -private: - static const CallNode *GetEvaluateCallInSimpleWrapper(const Stmt &stmt) { - if (const auto *eval = stmt.as()) { - return eval->value.as(); - } - if (const auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - IsTrivialNoOpStmt(if_stmt->else_case.value())) { - return GetEvaluateCallInSimpleWrapper(if_stmt->then_case); - } - return nullptr; - } - if (const auto *attr = stmt.as()) { - return GetEvaluateCallInSimpleWrapper(attr->body); - } - if (const auto *let = stmt.as()) { - return GetEvaluateCallInSimpleWrapper(let->body); - } - if (const auto *seq = stmt.as()) { - if (seq->seq.size() == 1) { - return GetEvaluateCallInSimpleWrapper(seq->seq[0]); - } - return nullptr; - } - if (const auto *block = stmt.as()) { - return GetEvaluateCallInSimpleWrapper(block->body); + bool has_allowed_work = false; + bool has_disallowed = false; + PostOrderVisit(stmt, [&](const ObjectRef &node) { + if (has_disallowed) { + return; } - if (const auto *realize = stmt.as()) { - if (is_one(realize->predicate)) { - return GetEvaluateCallInSimpleWrapper(realize->block->body); + if (const auto *call = node.as()) { + if (call->op.same_as(builtin::tvm_storage_sync())) { + const auto *scope = call->args[0].as(); + if (!scope || + (scope->value != "shared" && scope->value != "shared.dyn")) { + has_disallowed = true; + return; + } + has_allowed_work = true; + return; } - return nullptr; - } - return nullptr; - } - - static Optional ExtractTmaCopyWriteBufferData(const Stmt &stmt) { - if (const auto *attr = stmt.as()) { - if (attr->attr_key == "tl.tma_copy_write_buffer") { - const auto *v = attr->node.as(); - ICHECK(v); - return GetRef(v); + if (IsBarrierOrTmaControlCall(call)) { + has_disallowed = true; + return; } - return ExtractTmaCopyWriteBufferData(attr->body); } - if (const auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - IsTrivialNoOpStmt(if_stmt->else_case.value())) { - return ExtractTmaCopyWriteBufferData(if_stmt->then_case); + if (const auto *ld = node.as()) { + if (IsSharedBuffer(ld->buffer) || IsLocalBuffer(ld->buffer, true)) { + has_disallowed = true; + return; } - return Optional(); - } - if (const auto *let = stmt.as()) { - return ExtractTmaCopyWriteBufferData(let->body); - } - if (const auto *seq = stmt.as()) { - if (seq->seq.size() == 1) { - return ExtractTmaCopyWriteBufferData(seq->seq[0]); + if (IsGlobalBuffer(ld->buffer)) { + has_allowed_work = true; } - return Optional(); } - if (const auto *block = stmt.as()) { - return ExtractTmaCopyWriteBufferData(block->body); - } - if (const auto *realize = stmt.as()) { - if (is_one(realize->predicate)) { - return ExtractTmaCopyWriteBufferData(realize->block->body); + if (const auto *st = node.as()) { + if (IsSharedBuffer(st->buffer)) { + has_allowed_work = true; + return; } - return Optional(); + has_disallowed = true; } - return Optional(); - } + }); + return has_allowed_work && !has_disallowed; +} - static bool IsMbarrierWaitParity(const Stmt &stmt) { - const auto *call = GetEvaluateCallInSimpleWrapper(stmt); - return call && call->op.same_as(mbarrier_wait_parity()); - } +/// Classify a tile-op copy as TMA load producer, cp.async producer, or +/// consumer. Replicates the coarse checks from InstructionAnnotation inline so +/// that the tiled WS pass does not depend on a prior annotation pass. +static TileStmtKind ClassifyCopy(const CopyNode *copy, Target target) { + // Explicit T.tma_copy() is a load-side primitive: only treat valid + // global->shared TMA loads as producers. TMA stores consume previously + // produced shared data and must stay on the consumer side to preserve + // per-iteration ordering. + if (copy->GetIsTmaCopy()) { + arith::Analyzer analyzer; + if (copy->CheckBulkLoad(target, &analyzer, /*check_last_dim=*/false)) { + return TileStmtKind::kTmaProducer; + } + return TileStmtKind::kConsumer; // target doesn't support TMA + } + // Explicit T.async_copy() + if (copy->GetIsAsyncCopy()) { + return TileStmtKind::kCpAsyncProducer; + } + // Generic T.copy(): check if TMA is possible + { + arith::Analyzer analyzer; + if (!copy->GetDisableTMA() && + copy->CheckBulkLoad(target, &analyzer, /*check_last_dim=*/true)) { + return TileStmtKind::kTmaProducer; + } + } + return TileStmtKind::kConsumer; +} - static bool ContainsTmaLoad(const Stmt &stmt) { - bool found = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (auto *call = node.as()) { - if (call->op.same_as(tma_load()) || - call->op.same_as(tma_load_im2col())) { - found = true; +/// Classify a single statement in the pipeline loop body. +TileStmtKind ClassifyStmt(const Stmt &stmt, Target target) { + // Tile-op Calls: classify directly via CopyNode checks. + if (auto *eval = stmt.as()) { + if (auto *call = eval->value.as()) { + auto tile_op = ParseOperator(ffi::GetRef(call)); + if (tile_op.defined()) { + if (auto *copy = tile_op.as()) { + return ClassifyCopy(copy, target); + } + // Conv2D im2col lowers to tma_load_im2col on Hopper — treat as TMA + // producer so it goes to the producer warp group. + if (tile_op.as()) { + if (TargetIsHopper(target)) { + return TileStmtKind::kTmaProducer; + } } + return TileStmtKind::kConsumer; // non-copy tile-op } - }); - return found; + } } - - static bool ContainsPtxCpAsync(const Stmt &stmt) { - bool found = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (found) { - return; - } - if (const auto *call = node.as()) { - if (call->op.same_as(builtin::ptx_cp_async()) || - call->op.same_as(tl::ptx_cp_async())) { - found = true; - } - } - }); - return found; + // Explicit cp.async producer-side statements are already low-level builtins. + if (ContainsPtxCpAsync(stmt) || IsPtxCommitGroup(stmt) || + IsPtxWaitGroup(stmt)) { + return TileStmtKind::kCpAsyncProducer; } - - static bool IsPtxCommitGroup(const Stmt &stmt) { - const auto *call = GetEvaluateCallInSimpleWrapper(stmt); - return call && call->op.same_as(builtin::ptx_commit_group()); + // Non-tile-op: check for SIMT global-to-shared copy. + if (SimtProducerDetector::Detect(stmt)) { + return TileStmtKind::kSimtProducer; } + return TileStmtKind::kConsumer; +} - static bool IsPtxWaitGroupZero(const Stmt &stmt) { - const auto *call = GetEvaluateCallInSimpleWrapper(stmt); - if (!call || !call->op.same_as(builtin::ptx_wait_group())) { - return false; - } - ICHECK_EQ(call->args.size(), 1); - const auto *imm = call->args[0].as(); - ICHECK(imm); - return imm->value == 0; - } +bool IsProducer(TileStmtKind kind) { + return kind == TileStmtKind::kTmaProducer || + kind == TileStmtKind::kCpAsyncProducer || + kind == TileStmtKind::kSimtProducer; +} - static Optional AccessPtrBufferVar(const PrimExpr &ptr) { - const auto *call = ptr.as(); - if (!call) { - return Optional(); - } - if (call->op.same_as(tl::access_ptr())) { - ICHECK_EQ(call->args.size(), 3); - const auto *base_load = call->args[0].as(); - ICHECK(base_load); - return base_load->buffer->data; - } - if (call->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(call->args.size(), 5); - const auto *var = call->args[1].as(); - ICHECK(var); - return GetRef(var); - } - ICHECK(false) << "Expected tl.access_ptr or tvm_access_ptr"; - throw; - } +// --------------------------------------------------------------------------- +// Helpers: create barrier IR nodes +// --------------------------------------------------------------------------- - static Optional GetCpAsyncDstBufferData(const Stmt &stmt) { - Optional found = std::nullopt; - bool multiple = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (multiple) { - return; - } - const auto *call = node.as(); - if (!call) { - return; - } - if (!(call->op.same_as(builtin::ptx_cp_async()) || - call->op.same_as(tl::ptx_cp_async()))) { - return; - } - ICHECK(!call->args.empty()); - Optional current = AccessPtrBufferVar(call->args[0]); - if (!current.defined()) { - return; - } - if (!found.defined()) { - found = current; - } else if (found.value().get() != current.value().get()) { - multiple = true; - } - }); - if (multiple) { - return Optional(); - } - return found; - } -}; +static Stmt MakeParityWait(const Buffer &barrier_buf, PrimExpr barrier_id, + PrimExpr parity) { + auto ref = MakeBarrierRef(barrier_buf, std::move(barrier_id)); + return Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(), + {ref, std::move(parity)})); +} + +static Stmt MakeArriveBarrier(const Buffer &barrier_buf, PrimExpr barrier_id) { + auto ref = MakeBarrierRef(barrier_buf, std::move(barrier_id)); + return Evaluate( + Call(DataType::Handle(), builtin::ptx_arrive_barrier(), {ref})); +} // --------------------------------------------------------------------------- -// ThreadIdxRewriter (from warp_specialized_rewriter.cc) +// Convert tl.tileop.copy → tl.tileop.tma_copy with barrier annotation // --------------------------------------------------------------------------- -class PCThreadIdxRewriter : public StmtExprMutator { +/// Rewrite a `tl.tileop.copy` Call into a `tl.tileop.tma_copy` Call with +/// barrier reference. The args (src/dst regions) are preserved; only the op +/// and annotations change. +static PrimExpr RewriteCopyToTmaCopy(const Call ©_call, + const Buffer &barrier_buf, + PrimExpr barrier_id) { + static const Op &tma_copy_op = Op::Get("tl.tileop.tma_copy"); + auto new_annotations = copy_call->annotations; + new_annotations.Set("barrier", MakeBarrierRef(barrier_buf, barrier_id)); + new_annotations.Set("is_tma_copy", IntImm(DataType::Int(32), 1)); + return Call(copy_call->dtype, tma_copy_op, copy_call->args, new_annotations, + copy_call->span); +} + +/// Annotate SIMT producer statements so the enclosing transform owns cp.async +/// synchronization. +/// - ForNodes get `kParallelAsyncWithoutAsyncCommitWait = true` so +/// InjectPTXAsyncCopy does not emit commit_group + wait_group(0). +/// - Tile-op copy calls get `kAsyncCopyNoImplicitCommitWait` so copy.cc does +/// not emit its own implicit commit/wait either. +/// This allows the WS pass to emit its own commit_group + +/// cp_async_barrier_noinc, tying cp.async completion to the forward mbarrier. +class SimtProducerAnnotator : public StmtExprMutator { public: - static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced, - PrimExpr thread_extent, bool do_shuffle = false) { - auto rewriter = - PCThreadIdxRewriter(std::move(thread_var), std::move(replaced), - std::move(thread_extent), do_shuffle); - return rewriter(std::move(stmt)); + static Stmt Annotate(const Stmt &stmt, Target target) { + SimtProducerAnnotator a(std::move(target)); + return a.VisitStmt(stmt); } private: - PCThreadIdxRewriter(Var thread_var, PrimExpr replaced, PrimExpr thread_extent, - bool do_shuffle) - : thread_var_(std::move(thread_var)), replaced_(std::move(replaced)), - thread_extent_(std::move(thread_extent)), do_shuffle_(do_shuffle) {} + explicit SimtProducerAnnotator(Target target) : target_(std::move(target)) {} - PrimExpr VisitExpr_(const VarNode *var) final { - if (var == thread_var_.get()) { - return replaced_; - } - return StmtExprMutator::VisitExpr_(var); + Stmt VisitStmt_(const ForNode *op) final { + Stmt body = VisitStmt(op->body); + auto annotations = op->annotations; + annotations.Set(attr::kParallelAsyncWithoutAsyncCommitWait, Bool(true)); + return For(op->loop_var, op->min, op->extent, op->kind, body, + op->thread_binding, annotations, op->step, op->span); } - Stmt VisitStmt_(const IfThenElseNode *op) final { - auto f_uses_thread = [=](const tvm::tir::VarNode *v) { - return v == thread_var_.get(); - }; - maybe_thread_opt_ = false; - if (!op->else_case.defined() && op->condition.as() && - UsesVar(op->condition, f_uses_thread) && - !(UsesVar(op->then_case, f_uses_thread))) { - auto eq_op = Downcast(op->condition); - if (eq_op->a.as() == thread_var_.get() || - eq_op->b.as() == thread_var_.get()) { - maybe_thread_opt_ = true; - } - auto then_case = StmtExprMutator::VisitStmt(op->then_case); - maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_ && has_tma_op_; - has_tma_op_ = false; - if (maybe_thread_opt_) { - return IfThenElse( - Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}), - StmtExprMutator::VisitStmt(op->then_case), std::nullopt); - } + PrimExpr VisitExpr_(const CallNode *op) final { + static const Op ©_op = Op::Get("tl.tileop.copy"); + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (!call->op.same_as(copy_op) || !CanUsePipelineManagedCPAsyncCopy(call)) { + return call; } - return StmtExprMutator::VisitStmt_(op); + auto annotations = call->annotations; + annotations.Set(attr::kAsyncCopyNoImplicitCommitWait, + IntImm(DataType::Int(32), 1)); + return Call(call->dtype, call->op, call->args, annotations, call->span); } - PrimExpr VisitExpr_(const CallNode *op) final { - if (op->op.same_as(tl::tma_load()) || - op->op.same_as(tl::tma_load_im2col()) || - op->op.same_as(tl::tma_store()) || - op->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || - op->op.same_as(mbarrier_expect_tx())) { - has_tma_op_ = true; + bool CanUsePipelineManagedCPAsyncCopy(const Call &call) const { + auto tile_op = ParseOperator(call); + const auto *copy = tile_op.as(); + if (copy == nullptr) { + return false; } - return StmtExprMutator::VisitExpr_(op); + return copy->CheckPipelineManagedCPAsyncCopy(target_, &analyzer_); } - Var thread_var_; - PrimExpr replaced_; - PrimExpr thread_extent_; - bool maybe_thread_opt_ = false; - bool do_shuffle_; - bool has_tma_op_ = false; + Target target_; + mutable arith::Analyzer analyzer_; }; -// --------------------------------------------------------------------------- -// MbarrierInitRemover: removes barrier_init annotations and shared.barrier -// buffers from blocks outside the transformed block. -// --------------------------------------------------------------------------- - -/*! - * \brief Post-transform cleanup: remove barrier_init annotations and - * shared.barrier alloc_buffers that remain outside the transformed - * block. - * The new init is already emitted inside the block by the rewriter. - */ -class MbarrierInitRemover : public StmtExprMutator { +class TileOpMbarPhaseAnnotator : public StmtExprMutator { public: - static Stmt Remove(Stmt stmt) { - MbarrierInitRemover remover; - return remover(std::move(stmt)); + static Stmt Annotate(const Stmt &stmt, PrimExpr phase_expr) { + TileOpMbarPhaseAnnotator annotator(std::move(phase_expr)); + return annotator.VisitStmt(stmt); } private: - Stmt VisitStmt_(const BlockNode *op) final { - // Remove barrier_init annotation and shared.barrier buffers from - // blocks outside the transformed region. - bool has_barrier_init = op->annotations.count("barrier_init"); - bool has_barrier_bufs = false; - for (const auto &buf : op->alloc_buffers) { - if (buf.scope() == "shared.barrier") { - has_barrier_bufs = true; - break; + explicit TileOpMbarPhaseAnnotator(PrimExpr phase_expr) + : phase_expr_(std::move(phase_expr)) {} + + PrimExpr VisitExpr_(const CallNode *op) final { + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (!IsMbarPhaseConsumer(call)) { + return call; + } + if (call->annotations.count(attr::kPipelineMbarPhaseExpr)) { + return call; + } + auto annotations = call->annotations; + annotations.Set(attr::kPipelineMbarPhaseExpr, phase_expr_); + return Call(call->dtype, call->op, call->args, annotations, call->span); + } + + bool IsMbarPhaseConsumer(const Call &call) const { + auto tile_op = ParseOperator(call); + return tile_op.defined() && (tile_op.as() != nullptr || + tile_op.as() != nullptr || + tile_op.as() != nullptr || + tile_op.as() != nullptr); + } + + PrimExpr phase_expr_; +}; + +/// Annotate a tile-op Call (e.g., c2d_im2col) with a barrier reference. +/// The tile-op's Lower() is expected to check for the "barrier" annotation +/// and use it instead of allocating its own mbarrier. +static PrimExpr AnnotateTileOpBarrier(const Call &tile_call, + const Buffer &barrier_buf, + PrimExpr barrier_id) { + auto new_annotations = tile_call->annotations; + new_annotations.Set("barrier", MakeBarrierRef(barrier_buf, barrier_id)); + return Call(tile_call->dtype, tile_call->op, tile_call->args, new_annotations, + tile_call->span); +} + +struct BufferDataAccessInfo { + bool read{false}; + bool write{false}; + + bool HasAnyAccess() const { return read || write; } +}; + +struct PreludeTmaLoadPlan { + Stmt stmt; + const StmtNode *stmt_node{nullptr}; + int wait_pos{-1}; +}; + +static BufferDataAccessInfo +AnalyzeBufferDataAccess(const Stmt &stmt, const Var &buffer_data, + const BufferDataToBufferMap &buffer_map) { + class BufferDataAccessDetector : public StmtExprVisitor { + public: + BufferDataAccessDetector(const Var &buffer_data, + const BufferDataToBufferMap &buffer_map) + : buffer_data_(buffer_data), buffer_map_(buffer_map) {} + + BufferDataAccessInfo Result() const { return result_; } + + private: + void VisitExpr_(const BufferLoadNode *op) final { + if (op->buffer->data.same_as(buffer_data_)) { + result_.read = true; } + StmtExprVisitor::VisitExpr_(op); } - if (!has_barrier_init && !has_barrier_bufs) { - return StmtExprMutator::VisitStmt_(op); + void VisitStmt_(const BufferStoreNode *op) final { + if (op->buffer->data.same_as(buffer_data_)) { + result_.write = true; + } + StmtExprVisitor::VisitStmt_(op); } - Block block = GetRef(op); - auto block_ptr = block.CopyOnWrite(); + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tl::access_ptr())) { + ICHECK_EQ(op->args.size(), 3); + const auto *base_load = op->args[0].as(); + ICHECK(base_load); + if (base_load->buffer->data.same_as(buffer_data_)) { + MarkAccess(op->args[2]); + } + for (const auto &index : base_load->indices) { + VisitExpr(index); + } + VisitExpr(op->args[1]); + return; + } - if (has_barrier_init) { - Map new_annos; - for (const auto &[key, value] : block_ptr->annotations) { - if (key != "barrier_init") { - new_annos.Set(key, value); + if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5); + const auto *var = op->args[1].as(); + ICHECK(var); + auto it = buffer_map_.find(GetRef(var)); + if (it != buffer_map_.end() && it->second->data.same_as(buffer_data_)) { + MarkAccess(op->args[4]); } + VisitExpr(op->args[2]); + VisitExpr(op->args[3]); + return; } - block_ptr->annotations = new_annos; + + StmtExprVisitor::VisitExpr_(op); } - if (has_barrier_bufs) { - Array new_alloc_buffers; - for (const auto &buf : block_ptr->alloc_buffers) { - if (buf.scope() != "shared.barrier") { - new_alloc_buffers.push_back(buf); - } + void MarkAccess(const PrimExpr &rw_expr) { + int rw_mask = 3; + if (const int64_t *imm = as_const_int(rw_expr)) { + rw_mask = static_cast(*imm); + } + if (rw_mask & 1) { + result_.read = true; + } + if (rw_mask & 2) { + result_.write = true; + } + } + + Var buffer_data_; + const BufferDataToBufferMap &buffer_map_; + BufferDataAccessInfo result_; + }; + + BufferDataAccessDetector detector(buffer_data, buffer_map); + detector(stmt); + return detector.Result(); +} + +static bool CollectPreludeStmtsToPipelineLoop(const Stmt &stmt, + const ForNode *pipeline_loop, + Array *prelude_stmts) { + if (stmt.get() == pipeline_loop) { + return true; + } + if (const auto *seq = stmt.as()) { + for (int i = 0; i < static_cast(seq->seq.size()); ++i) { + Array nested_prelude; + if (CollectPreludeStmtsToPipelineLoop(seq->seq[i], pipeline_loop, + &nested_prelude)) { + for (int j = 0; j < i; ++j) { + prelude_stmts->push_back(seq->seq[j]); + } + prelude_stmts->insert(prelude_stmts->end(), nested_prelude.begin(), + nested_prelude.end()); + return true; } - block_ptr->alloc_buffers = new_alloc_buffers; } + return false; + } + if (const auto *let = stmt.as()) { + return CollectPreludeStmtsToPipelineLoop(let->body, pipeline_loop, + prelude_stmts); + } + if (const auto *realize = stmt.as()) { + return CollectPreludeStmtsToPipelineLoop(realize->block->body, + pipeline_loop, prelude_stmts); + } + if (const auto *block = stmt.as()) { + return CollectPreludeStmtsToPipelineLoop(block->body, pipeline_loop, + prelude_stmts); + } + if (const auto *attr = stmt.as()) { + return CollectPreludeStmtsToPipelineLoop(attr->body, pipeline_loop, + prelude_stmts); + } + return false; +} - block_ptr->body = VisitStmt(block_ptr->body); - return block; +static Optional ExtractProducerWriteBufferData(const Stmt &stmt) { + const auto *call = GetEvaluateCallInSimpleWrapper(stmt); + if (!call) { + return Optional(); + } + auto tile_op = ParseOperator(ffi::GetRef(call)); + if (!tile_op.defined()) { + return Optional(); + } + if (const auto *copy = tile_op.as()) { + if (IsSharedBuffer(copy->dst)) { + return copy->dst->data; + } + } + if (const auto *im2col = tile_op.as()) { + if (IsSharedBuffer(im2col->dst_)) { + return im2col->dst_->data; + } } + return Optional(); +} - // Stop recursion at BlockRealize — the new init is inside the block - // and we don't want to remove it. - Stmt VisitStmt_(const BlockRealizeNode *op) final { return GetRef(op); } -}; +static Stmt RewritePreludeTmaProducerStmt(const Stmt &stmt, + const Buffer &barrier_buf, + PrimExpr barrier_id) { + class PreludeTmaProducerRewriter : public StmtExprMutator { + public: + PreludeTmaProducerRewriter(Buffer barrier_buf, PrimExpr barrier_id) + : barrier_buf_(std::move(barrier_buf)), + barrier_id_(std::move(barrier_id)) {} + + Stmt Rewrite(const Stmt &stmt) { return VisitStmt(stmt); } + + private: + PrimExpr VisitExpr_(const CallNode *op) final { + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (rewritten_) { + return call; + } + auto tile_op = ParseOperator(call); + if (!tile_op.defined()) { + return call; + } + PrimExpr rewritten_call; + if (tile_op.as()) { + rewritten_call = RewriteCopyToTmaCopy(call, barrier_buf_, barrier_id_); + } else if (tile_op.as()) { + rewritten_call = AnnotateTileOpBarrier(call, barrier_buf_, barrier_id_); + } else { + return call; + } + Call new_call = Downcast(rewritten_call); + auto annotations = new_call->annotations; + annotations.Set("emit_arrive", IntImm(DataType::Int(32), 1)); + rewritten_ = true; + return Call(new_call->dtype, new_call->op, new_call->args, annotations, + new_call->span); + } + + Buffer barrier_buf_; + PrimExpr barrier_id_; + bool rewritten_{false}; + }; + + PreludeTmaProducerRewriter rewriter(barrier_buf, std::move(barrier_id)); + return rewriter.Rewrite(stmt); +} // --------------------------------------------------------------------------- -// ProducerConsumerWSRewriter — main pass +// Main rewriter // --------------------------------------------------------------------------- class ProducerConsumerWSRewriter : public StmtExprMutator { public: static PrimFunc Substitute(PrimFunc f) { - // Check thread tags - if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) { - LOG(WARNING) << "ProducerConsumerWS: disabled because program uses " - "thread tags other than threadIdx.x"; - return f; - } + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) + << "ProducerConsumerWS: target attribute is required"; ProducerConsumerWSRewriter T; + T.target_ = target.value(); f.CopyOnWrite()->body = T(f->body); - // TODO(lei): This should be refactored - // If WS was applied, remove any barrier_init annotations and - // shared.barrier buffers that remain OUTSIDE the block (e.g. at - // function body level from lower_tile_op). The new init is already - // inside the block. if (T.ws_transformed_) { - f.CopyOnWrite()->body = MbarrierInitRemover::Remove(f->body); + f = WithAttr(std::move(f), kTiledWSApplied, IntImm(DataType::Int(32), 1)); } - return f; } private: - // Locate the threadIdx.x binding + // --- Track threadIdx.x binding --- Stmt VisitStmt_(const AttrStmtNode *op) final { - if (op->attr_key == tir::attr::thread_extent && - Downcast(op->node)->thread_tag == "threadIdx.x") { - thread_iv_ = Downcast(op->node); - Optional old_num_threads = num_threads_; - num_threads_ = std::nullopt; - AttrStmt attr_stmt = Downcast(StmtExprMutator::VisitStmt_(op)); - if (num_threads_.defined()) { - PrimExpr num_threads = num_threads_.value(); - thread_iv_.CopyOnWrite()->dom = {0, num_threads}; - attr_stmt.CopyOnWrite()->node = thread_iv_; - attr_stmt.CopyOnWrite()->value = num_threads; - } - // clean up if we may have multiple threadIdx.x that - // need to be transformed - num_threads_ = old_num_threads; - thread_iv_ = {}; - return attr_stmt; + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + thread_iv_ = iv; + Optional old_num_threads = num_threads_; + num_threads_ = std::nullopt; + AttrStmt attr = Downcast(StmtExprMutator::VisitStmt_(op)); + if (num_threads_.defined()) { + PrimExpr nt = num_threads_.value(); + thread_iv_.CopyOnWrite()->dom = {0, nt}; + attr.CopyOnWrite()->node = thread_iv_; + attr.CopyOnWrite()->value = nt; + } + num_threads_ = old_num_threads; + thread_iv_ = {}; + return attr; + } } return StmtExprMutator::VisitStmt_(op); } + // --- Find the block containing the pipeline loop --- Stmt VisitStmt_(const BlockRealizeNode *op) final { if (!thread_iv_.defined()) return StmtExprMutator::VisitStmt_(op); const Block &orig_block = op->block; - // Find the explicitly pipelined loop for producer/consumer WS. - const ForNode *pipeline_loop = FindAnnotatedPipelineLoop(orig_block->body); + // Find the pipelined loop. + const ForNode *pipeline_loop = FindPipelineLoop(orig_block->body); if (!pipeline_loop) return StmtExprMutator::VisitStmt_(op); auto num_stages_anno = pipeline_loop->annotations.Get("num_stages"); - ICHECK(num_stages_anno); + if (!num_stages_anno) + return StmtExprMutator::VisitStmt_(op); int num_stages = static_cast(Downcast(num_stages_anno.value())->value); - ICHECK_GE(num_stages, 1); - - // Detect cluster barriers and compute cluster size from block annotations. - is_cluster_barrier_ = false; - cluster_size_ = 1; - for (const auto &buf : orig_block->alloc_buffers) { - if (buf.scope() == "shared.cluster_barrier") { - is_cluster_barrier_ = true; - break; - } - } - if (is_cluster_barrier_ && orig_block->annotations.count("cluster_dims")) { - if (auto arr = orig_block->annotations.Get("cluster_dims") - ->try_cast>()) { - int sz = 1; - for (auto d : arr.value()) - sz *= static_cast(d->value); - cluster_size_ = sz; - } - } + if (num_stages < 1) + return StmtExprMutator::VisitStmt_(op); - // Flatten the loop body + // Flatten the loop body. Array flat_stmts; - Stmt loop_body_root = pipeline_loop->body; - if (auto *realize = pipeline_loop->body.as()) { - loop_body_root = realize->block->body; - } - std::vector> loop_body_lets; - while (const auto *let_stmt = loop_body_root.as()) { - loop_body_lets.emplace_back(let_stmt->var, let_stmt->value); - loop_body_root = let_stmt->body; - } - FlattenSeqStmt(loop_body_root, &flat_stmts); - auto rewrap_loop_body_lets = [&](Stmt body) { - for (auto it = loop_body_lets.rbegin(); it != loop_body_lets.rend(); - ++it) { - body = LetStmt((*it).first, (*it).second, body); - } - return body; - }; - // Extract async producer blocks (TMA and cp.async) - AsyncCopyBlockExtractor extractor; - extractor.Extract(flat_stmts); - - if (extractor.blocks.empty()) { - // No TMA loads found — fall through to standard pipeline + Stmt loop_body = pipeline_loop->body; + if (auto *realize = loop_body.as()) { + loop_body = realize->block->body; + } + // Unwrap LetStmt chain that dominates the whole loop body. + std::vector> outer_let_bindings; + while (const auto *let = loop_body.as()) { + outer_let_bindings.emplace_back(let->var, let->value); + loop_body = let->body; + } + // Unwrap a single IfThenElse wrapper (no else branch) so that + // TMA producers inside conditional loop bodies can be classified. + // Keep LetStmt chains inside the conditional separate so they stay + // dominated by the original guard after rebuilding WS branches. + Optional loop_body_condition; + std::vector> inner_let_bindings; + if (const auto *if_stmt = loop_body.as()) { + if (!if_stmt->else_case.defined()) { + // Peel LetStmt chain from inside the conditional body. These + // bindings must remain inside the guarded region. + Stmt inner = if_stmt->then_case; + while (const auto *let = inner.as()) { + inner_let_bindings.emplace_back(let->var, let->value); + inner = let->body; + } + loop_body_condition = if_stmt->condition; + loop_body = inner; + } + } + FlattenSeqStmt(loop_body, &flat_stmts); + + // Classify statements into producer (TMA/SIMT copy) and consumer. + std::vector kinds; + int num_tma = 0; + int num_simt = 0; + for (const Stmt &s : flat_stmts) { + auto k = ClassifyStmt(s, target_); + kinds.push_back(k); + if (k == TileStmtKind::kTmaProducer) + ++num_tma; + if (k == TileStmtKind::kSimtProducer) + ++num_simt; + } + + // Require at least one TMA producer. + if (num_tma == 0) return StmtExprMutator::VisitStmt_(op); - } - - // Check if there are existing tl_pipeline_order/tl_pipeline_stage - // with -1 values (WS+TMA enabled markers) — if so, use those - auto order_anno = pipeline_loop->annotations.Get("tl_pipeline_order"); - auto stage_anno = pipeline_loop->annotations.Get("tl_pipeline_stage"); - if (order_anno && stage_anno) { - auto order_array = Downcast>(order_anno.value()); - for (const auto &val : order_array) { - if (val->value == -1) { - // Already has WS pipeline annotations — skip - return StmtExprMutator::VisitStmt_(op); - } - } - } - VarBindingMap saved_loop_guard_bindings = current_loop_guard_bindings_; - for (const auto &[var, value] : loop_body_lets) { - current_loop_guard_bindings_[var] = value; - } + // --- Build the WS transformation --- + return BuildWSBlock(op, orig_block, pipeline_loop, num_stages, flat_stmts, + kinds, outer_let_bindings, inner_let_bindings, + loop_body_condition); + } - BufferDataToBufferMap buffer_data_to_buffer = - BufferDataToBufferCollector::Collect(GetRef(op)); + Stmt + BuildWSBlock(const BlockRealizeNode *orig_realize, const Block &orig_block, + const ForNode *pipeline_loop, int num_stages, + const Array &flat_stmts, + const std::vector &kinds, + const std::vector> &outer_let_bindings, + const std::vector> &inner_let_bindings, + Optional loop_body_condition = Optional()) { + Var loop_var = pipeline_loop->loop_var; + PrimExpr loop_min = pipeline_loop->min; + PrimExpr loop_extent = pipeline_loop->extent; + PrimExpr linear_idx = loop_var - loop_min; - // --------------------------------------------------------------- - // Build producer and consumer loop bodies - // --------------------------------------------------------------- - PrimExpr consumer_thread_extent = thread_iv_->dom->extent; - consumer_thread_extent_ = - consumer_thread_extent; // Store for RebuildBlockBody - PrimExpr producer_thread_extent = IntImm(DataType::Int(32), 128); - producer_thread_extent_ = producer_thread_extent; + PrimExpr base_stage_expr = FloorMod(linear_idx, num_stages); + PrimExpr base_parity_expr = FloorMod(FloorDiv(linear_idx, num_stages), 2); - // Barrier layout has two modes: - // 1) Mixed TMA + cp.async: - // keep existing TMA forward ids, append cp.async forward ids, then - // append back-pressure ids. - // 2) Pure TMA: - // remap to [loop forward][back-pressure][preloop forward] so producer - // and consumer follow the same protocol as the hand-written WS kernels. - int num_existing_tma_fwd_barriers = 0; - int num_cp_async_groups = 0; - for (const auto &block : extractor.blocks) { - if (block.kind == AsyncProducerKind::kTma) { - ++num_existing_tma_fwd_barriers; - } else if (block.kind == AsyncProducerKind::kCpAsync) { - ++num_cp_async_groups; + // When the loop body is conditionally guarded, use PhaseCounters + // instead of the loop variable for barrier stage/parity. This + // ensures parity stays correct when iterations are skipped. + bool needs_phase_counter = loop_body_condition.defined(); + Optional producer_phase_counter; + Optional consumer_phase_counter; + PrimExpr p_stage_expr = base_stage_expr; + PrimExpr p_parity_expr = base_parity_expr; + PrimExpr c_stage_expr = base_stage_expr; + PrimExpr c_parity_expr = base_parity_expr; + if (needs_phase_counter) { + producer_phase_counter = PhaseCounter::Create("producer_phase_cnt"); + consumer_phase_counter = PhaseCounter::Create("consumer_phase_cnt"); + p_stage_expr = producer_phase_counter.value().StageExpr(num_stages); + p_parity_expr = producer_phase_counter.value().ParityExpr(num_stages); + c_stage_expr = consumer_phase_counter.value().StageExpr(num_stages); + c_parity_expr = consumer_phase_counter.value().ParityExpr(num_stages); + } + + PrimExpr consumer_extent = thread_iv_->dom->extent; + PrimExpr producer_extent = IntImm(DataType::Int(32), 128); + common_prelude_rewrites_.clear(); + + bool has_simt_producer = false; + bool has_cp_async_producer = false; + int num_producer_groups = 0; + for (auto k : kinds) { + if (k == TileStmtKind::kTmaProducer) + ++num_producer_groups; + if (k == TileStmtKind::kSimtProducer) + has_simt_producer = true; + if (k == TileStmtKind::kCpAsyncProducer) + has_cp_async_producer = true; + } + + // --- Barrier allocation --- + // Layout: [fwd_0..fwd_{G*S-1}] [bp_0..bp_{G*S-1}] + // where G = num_producer_groups (one per TMA copy), S = num_stages. + // When SIMT producers are present, all producer types share the same + // barrier group — the last forward arrive covers everything. + int num_fwd = num_producer_groups * num_stages; + int num_bp = num_producer_groups * num_stages; + + buffer_data_to_buffer_ = + BufferDataToBufferCollector::Collect(orig_block->body); + Array consumer_compute_stmts; + for (size_t i = 0; i < flat_stmts.size(); ++i) { + if (!IsProducer(kinds[i])) { + consumer_compute_stmts.push_back(flat_stmts[i]); + } + } + + Array prelude_stmts; + CollectPreludeStmtsToPipelineLoop(orig_block->body, pipeline_loop, + &prelude_stmts); + std::vector prelude_tma_plans; + for (const Stmt &stmt : prelude_stmts) { + if (ClassifyStmt(stmt, target_) != TileStmtKind::kTmaProducer) { + continue; } - } - std::vector wait_insert_pos(extractor.blocks.size(), 0); - std::vector arrive_insert_pos( - extractor.blocks.size(), - static_cast(extractor.compute_stmts.size())); - for (size_t ti = 0; ti < extractor.blocks.size(); ++ti) { - if (!extractor.blocks[ti].write_buffer_data.defined()) { + Optional write_buffer_data = ExtractProducerWriteBufferData(stmt); + if (!write_buffer_data.defined()) { continue; } - const Var &target = extractor.blocks[ti].write_buffer_data.value(); int first_read = -1; - int last_access = -1; - for (size_t ci = 0; ci < extractor.compute_stmts.size(); ++ci) { + for (size_t ci = 0; ci < consumer_compute_stmts.size(); ++ci) { BufferDataAccessInfo access = AnalyzeBufferDataAccess( - extractor.compute_stmts[ci], target, buffer_data_to_buffer); - if (access.read && first_read < 0) { + consumer_compute_stmts[ci], write_buffer_data.value(), + buffer_data_to_buffer_); + if (access.read) { first_read = static_cast(ci); - } - if (access.HasAnyAccess()) { - last_access = static_cast(ci); - } - } - if (first_read >= 0) { - wait_insert_pos[ti] = first_read; - arrive_insert_pos[ti] = last_access + 1; - } else if (last_access >= 0) { - // Write-only statements that touch the producer-written shared buffer - // do not need the producer result, so keep the forward wait at the - // loop head while still delaying back-pressure release until the last - // consumer-side access. - wait_insert_pos[ti] = 0; - arrive_insert_pos[ti] = last_access + 1; - } - } - int num_existing_loop_fwd_barriers = - num_existing_tma_fwd_barriers * num_stages; - int original_num_existing_loop_fwd_barriers = - num_existing_loop_fwd_barriers; - int inferred_existing_required = - InferMinRequiredBarrierCount(orig_block->body); - int required_preloop_tma_pairs = CountRewrittenPureTmaPreloopForwardPairs( - orig_block->body, pipeline_loop); - bool old_use_full_tma_forward_barrier_protocol = - use_full_tma_forward_barrier_protocol_; - bool old_remap_pure_tma_barriers = remap_pure_tma_barriers_; - int old_pure_tma_preloop_fwd_base = pure_tma_preloop_fwd_base_; - int old_pure_tma_preloop_fwd_count = pure_tma_preloop_fwd_count_; - int old_pure_tma_preloop_fwd_cursor = pure_tma_preloop_fwd_cursor_; - use_full_tma_forward_barrier_protocol_ = (num_cp_async_groups == 0); - remap_pure_tma_barriers_ = use_full_tma_forward_barrier_protocol_; - std::vector ws_producer_stmts(extractor.blocks.size()); - std::vector> ws_wait_stmts(extractor.blocks.size(), - std::nullopt); - std::vector> producer_issue_guards( - extractor.blocks.size(), std::nullopt); - std::vector> producer_issue_guard_sources( - extractor.blocks.size(), std::nullopt); - std::vector> protocol_guards(extractor.blocks.size(), - std::nullopt); - std::vector> protocol_guard_sources(extractor.blocks.size(), - std::nullopt); - for (size_t i = 0; i < extractor.blocks.size(); ++i) { - ws_producer_stmts[i] = extractor.blocks[i].producer_stmt; - ws_wait_stmts[i] = extractor.blocks[i].wait_stmt; - producer_issue_guards[i] = - ExtractNonThreadProducerGuard(extractor.blocks[i].producer_stmt); - if (producer_issue_guards[i].defined()) { - producer_issue_guard_sources[i] = extractor.blocks[i].producer_stmt; - protocol_guards[i] = producer_issue_guards[i]; - protocol_guard_sources[i] = extractor.blocks[i].producer_stmt; - if (arrive_insert_pos[i] > 0 && - arrive_insert_pos[i] <= - static_cast(extractor.compute_stmts.size())) { - const Stmt &arrive_source = - extractor.compute_stmts[arrive_insert_pos[i] - 1]; - Optional arrive_guard = - ExtractNonThreadProducerGuard(arrive_source); - if (arrive_guard.defined()) { - protocol_guards[i] = arrive_guard; - protocol_guard_sources[i] = arrive_source; - } - } - } - // NOTE: Previously, when the guard was a mask-like boolean expression - // (e.g. BlockMask[by, bx, k]), the producer would strip the guard and - // issue TMA loads unconditionally. This causes unnecessary memory - // traffic for sparse workloads, so we now keep the guard on the - // producer side and rely on phase-counter-based parity tracking to - // maintain barrier synchronisation. - } - - // --------------------------------------------------------------- - // Detect whether the pipeline loop needs counter-based phase - // tracking. This is necessary when the loop body is conditionally - // guarded (e.g. `if block_mask[k]`) so that skipped iterations do - // not desynchronise the mbarrier parity. - // --------------------------------------------------------------- - bool needs_phase_counter = false; - Optional uniform_phase_guard; - Optional uniform_phase_guard_source; - { - StructuralEqual eq; - for (size_t i = 0; i < extractor.blocks.size(); ++i) { - if (protocol_guards[i].defined()) { - if (!needs_phase_counter) { - needs_phase_counter = true; - uniform_phase_guard = protocol_guards[i]; - uniform_phase_guard_source = protocol_guard_sources[i]; - } else if (!eq(uniform_phase_guard.value(), - protocol_guards[i].value())) { - // Different guards on different blocks – fall back to - // original loop-variable parity (no counter). - needs_phase_counter = false; - break; - } + break; } } - // Only use counter when ALL blocks share the same guard. - if (needs_phase_counter) { - for (size_t i = 0; i < extractor.blocks.size(); ++i) { - if (!protocol_guards[i].defined()) { - needs_phase_counter = false; - break; - } - } + if (first_read < 0) { + continue; } + prelude_tma_plans.push_back({stmt, stmt.get(), first_read}); } - std::optional producer_phase_counter; - std::optional consumer_phase_counter; - if (needs_phase_counter) { - producer_phase_counter = PhaseCounter::Create("producer_phase_cnt"); - consumer_phase_counter = PhaseCounter::Create("consumer_phase_cnt"); - } + int total_barriers = num_fwd + num_bp + prelude_tma_plans.size(); + Buffer barrier_buf = + CreateMBarrierBuffer(injected_mbarrier_name_, total_barriers); + // arrive_counts are computed later (after producer_extent is finalized). - StructuralEqual equal; - auto same_optional_expr = [&](const Optional &guard_a, - const Optional &guard_b) { - if (guard_a.defined() != guard_b.defined()) { - return false; + std::vector wait_insert_pos(num_producer_groups, 0); + std::vector arrive_insert_pos( + num_producer_groups, static_cast(consumer_compute_stmts.size())); + int access_group_idx = 0; + for (size_t i = 0; i < flat_stmts.size(); ++i) { + if (kinds[i] != TileStmtKind::kTmaProducer) { + continue; } - return !guard_a.defined() || equal(guard_a.value(), guard_b.value()); - }; - auto same_guard = [&](size_t lhs, size_t rhs) { - return same_optional_expr(producer_issue_guards[lhs], - producer_issue_guards[rhs]) && - same_optional_expr(protocol_guards[lhs], protocol_guards[rhs]); - }; - std::vector block_group(extractor.blocks.size(), 0); - int num_block_groups = 0; - if (!extractor.blocks.empty()) { - int next_group = 0; - block_group[0] = next_group++; - bool current_group_has_tma = - extractor.blocks[0].kind == AsyncProducerKind::kTma; - for (size_t i = 1; i < extractor.blocks.size(); ++i) { - bool merge_with_prev = - wait_insert_pos[i] == wait_insert_pos[i - 1] && - arrive_insert_pos[i] == arrive_insert_pos[i - 1] && - same_guard(i - 1, i); - if (merge_with_prev && !remap_pure_tma_barriers_ && - current_group_has_tma && - extractor.blocks[i].kind == AsyncProducerKind::kTma) { - // Mixed groups can safely share one TMA barrier with cp.async - // arrive-on notifications, but keeping multiple TMA producers on the - // same preserved protocol would over-arrive the barrier. - merge_with_prev = false; + Optional write_buffer_data = + ExtractProducerWriteBufferData(flat_stmts[i]); + if (write_buffer_data.defined()) { + int first_read = -1; + int last_access = -1; + for (size_t ci = 0; ci < consumer_compute_stmts.size(); ++ci) { + BufferDataAccessInfo access = AnalyzeBufferDataAccess( + consumer_compute_stmts[ci], write_buffer_data.value(), + buffer_data_to_buffer_); + if (access.read && first_read < 0) { + first_read = static_cast(ci); + } + if (access.HasAnyAccess()) { + last_access = static_cast(ci); + } } - block_group[i] = merge_with_prev ? block_group[i - 1] : next_group++; - if (!merge_with_prev) { - current_group_has_tma = - extractor.blocks[i].kind == AsyncProducerKind::kTma; - } else if (extractor.blocks[i].kind == AsyncProducerKind::kTma) { - current_group_has_tma = true; + if (first_read >= 0) { + wait_insert_pos[access_group_idx] = first_read; + arrive_insert_pos[access_group_idx] = last_access + 1; + } else if (last_access >= 0) { + wait_insert_pos[access_group_idx] = 0; + arrive_insert_pos[access_group_idx] = last_access + 1; } } - num_block_groups = next_group; + ++access_group_idx; } - std::vector> producer_loop_prefix_stmts( - extractor.blocks.size()); - std::vector moved_compute_stmts(extractor.compute_stmts.size(), - false); - std::vector compute_stmt_summaries; - compute_stmt_summaries.reserve(extractor.compute_stmts.size()); - for (const auto &stmt : extractor.compute_stmts) { - compute_stmt_summaries.push_back( - LocalAccessCollector::Collect(stmt, buffer_data_to_buffer)); + // --- Determine if TMA barriers can be merged --- + // When all pure-TMA producers wait at the same consumer position and + // release at the same position, forward and back-pressure barriers can + // be shared across all TMA copies, reducing from 2*G*S to 2*S barriers. + bool can_merge_tma_barriers = (num_producer_groups > 1) && + !has_simt_producer && !has_cp_async_producer; + if (can_merge_tma_barriers) { + for (int g = 1; g < num_producer_groups; ++g) { + if (wait_insert_pos[g] != wait_insert_pos[0] || + arrive_insert_pos[g] != arrive_insert_pos[0]) { + can_merge_tma_barriers = false; + break; + } + } + } + if (can_merge_tma_barriers) { + // Re-compute barrier layout with a single merged group. + num_fwd = num_stages; + num_bp = num_stages; + total_barriers = num_fwd + num_bp + prelude_tma_plans.size(); + barrier_buf = + CreateMBarrierBuffer(injected_mbarrier_name_, total_barriers); } - std::vector prefix_begin(extractor.blocks.size(), -1); - std::vector prefix_end(extractor.blocks.size(), -1); - std::vector first_group_indices; + std::vector> producer_loop_prefix_stmts(num_producer_groups); + std::vector moved_compute_stmts(consumer_compute_stmts.size(), false); int compute_cursor = 0; - for (size_t ti = 0; ti < extractor.blocks.size(); ++ti) { - bool is_first_in_group = - ti == 0 || block_group[ti] != block_group[ti - 1]; - if (!is_first_in_group) { - continue; - } + for (int ti = 0; ti < num_producer_groups; ++ti) { int wait_pos = wait_insert_pos[ti]; - prefix_begin[ti] = compute_cursor; - prefix_end[ti] = wait_pos; - first_group_indices.push_back(static_cast(ti)); - compute_cursor = std::max(compute_cursor, wait_pos); - } - - // Move only the longest leading prefix that is producer-safe. We walk - // groups backwards so later slices can first mark which compute stmts stay - // in the consumer branch; the current slice may only hoist definitions - // that do not escape into that future consumer live set. - for (auto it = first_group_indices.rbegin(); - it != first_group_indices.rend(); ++it) { - int ti = *it; - int begin = prefix_begin[ti]; - int end = prefix_end[ti]; - if (begin < 0 || end <= begin) { + if (wait_pos <= compute_cursor) { + compute_cursor = std::max(compute_cursor, wait_pos); continue; } - - LocalLiveSet future_consumer_live; - for (int ci = end; ci < static_cast(extractor.compute_stmts.size()); - ++ci) { - if (!moved_compute_stmts[ci]) { - future_consumer_live.AddUses(compute_stmt_summaries[ci]); - } - } - - BufferSet prefix_defined_buffers; - int movable_end = begin; - for (int ci = begin; ci < end; ++ci) { - if (!IsProducerMovableLoopPrefixStmt( - extractor.compute_stmts[ci], compute_stmt_summaries[ci], - future_consumer_live, prefix_defined_buffers)) { + bool all_movable = true; + for (int ci = compute_cursor; ci < wait_pos; ++ci) { + if (!IsProducerMovableLoopPrefixStmt(consumer_compute_stmts[ci], + target_)) { + all_movable = false; break; } - for (const auto &buf : - compute_stmt_summaries[ci].branch_private_write_buffers) { - prefix_defined_buffers.insert(buf); + } + if (all_movable) { + for (int ci = compute_cursor; ci < wait_pos; ++ci) { + producer_loop_prefix_stmts[ti].push_back(consumer_compute_stmts[ci]); + moved_compute_stmts[ci] = true; } - movable_end = ci + 1; } + compute_cursor = wait_pos; + } - for (int ci = begin; ci < movable_end; ++ci) { - producer_loop_prefix_stmts[ti].push_back(extractor.compute_stmts[ci]); - moved_compute_stmts[ci] = true; + bool producer_needs_full_thread_extent = false; + for (size_t i = 0; + i < flat_stmts.size() && !producer_needs_full_thread_extent; ++i) { + if (kinds[i] == TileStmtKind::kSimtProducer || + IsSyncGlobalToSharedCopyLikeStmt(flat_stmts[i], target_)) { + producer_needs_full_thread_extent = true; } } - - auto stmt_has_lowered_simt_copy = [&](const Stmt &stmt) { - return ProducerSimtCopyDetector::HasSimtCopy(stmt, buffer_data_to_buffer); - }; - bool producer_needs_full_thread_extent = - std::any_of(ws_producer_stmts.begin(), ws_producer_stmts.end(), - stmt_has_lowered_simt_copy); if (!producer_needs_full_thread_extent) { for (const auto &prefix_stmts : producer_loop_prefix_stmts) { for (const auto &stmt : prefix_stmts) { - if (stmt_has_lowered_simt_copy(stmt)) { + if (IsSyncGlobalToSharedCopyLikeStmt(stmt, target_)) { producer_needs_full_thread_extent = true; break; } @@ -1357,344 +1388,275 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { } } if (producer_needs_full_thread_extent) { - // LowerTileOp may already have materialized SIMT global->shared copies. - // Those copies cannot be safely remapped onto a smaller producer warp - // partition, so keep the producer extent at the original thread extent. - producer_thread_extent = consumer_thread_extent; - } - producer_thread_extent_ = producer_thread_extent; - - std::vector group_has_tma(num_block_groups, false); - std::vector group_has_cp_async(num_block_groups, false); - for (size_t i = 0; i < extractor.blocks.size(); ++i) { - int group = block_group[i]; - if (extractor.blocks[i].kind == AsyncProducerKind::kTma) { - group_has_tma[group] = true; - } else if (extractor.blocks[i].kind == AsyncProducerKind::kCpAsync) { - group_has_cp_async[group] = true; - } - } - int num_tma_groups = 0; - int num_cp_async_only_groups = 0; - for (int group = 0; group < num_block_groups; ++group) { - if (group_has_tma[group]) { - ++num_tma_groups; - } else if (group_has_cp_async[group]) { - ++num_cp_async_only_groups; - } - } - num_existing_tma_fwd_barriers = num_tma_groups; - num_existing_loop_fwd_barriers = num_existing_tma_fwd_barriers * num_stages; - int num_new_cp_async_fwd_barriers = num_cp_async_only_groups * num_stages; - - int num_existing_barriers = 0; - int num_preloop_fwd_barriers = 0; - if (remap_pure_tma_barriers_) { - // Pure-TMA WS remaps pre-loop TMA prefixes to a dedicated barrier range. - // Some kernels reuse loop barrier ids for those prefixes in the original - // IR, so `inferred_existing_required` alone can undercount how many - // distinct pre-loop barriers the rewritten form needs. - num_preloop_fwd_barriers = - std::max(required_preloop_tma_pairs, - std::max(0, inferred_existing_required - - original_num_existing_loop_fwd_barriers)); - num_existing_barriers = - num_existing_loop_fwd_barriers + num_preloop_fwd_barriers; - } else { - // Mixed TMA/cp.async keeps any existing non-pipelined forward barriers - // at their original ids. `inferred_existing_required` already accounts - // for those explicit references, so avoid reserving an extra unused slot. - num_existing_barriers = - std::max(num_existing_loop_fwd_barriers, inferred_existing_required); - num_preloop_fwd_barriers = - num_existing_barriers - num_existing_loop_fwd_barriers; - } - int num_total_fwd_barriers = 0; - int num_bp_barriers = num_block_groups * num_stages; - int total_barriers = 0; - - std::vector fwd_bases(extractor.blocks.size(), -1); - std::vector bp_bases(extractor.blocks.size(), -1); - std::vector mixed_fwd_arrive_counts; - - if (remap_pure_tma_barriers_) { - // Pure-TMA layout: - // [0, loop_fwd) : loop forward barriers - // [loop_fwd, loop_fwd + bp) : back-pressure barriers - // [loop_fwd + bp, total_barriers) : preloop/prologue forward barriers - int next_loop_fwd_base = 0; - for (size_t i = 0; i < extractor.blocks.size(); ++i) { - if (i == 0 || block_group[i] != block_group[i - 1]) { - fwd_bases[i] = next_loop_fwd_base; - next_loop_fwd_base += num_stages; + // LowerTileOp will materialize these producer-side sync copies into + // explicit SIMT global->shared loops. Keep the producer partition at the + // original thread extent so the lowered thread mapping stays valid. + producer_extent = consumer_extent; + } + + // --- Compute arrive_counts (after producer_extent is finalized) --- + // Forward arrive_count: + // - Pure TMA (possibly merged): 1 (leader thread only) + // - Mixed TMA with SIMT/cp.async: producer_extent (all producer threads) + PrimExpr fwd_arrive_count = (can_merge_tma_barriers || + (!has_simt_producer && !has_cp_async_producer)) + ? IntImm(DataType::Int(32), 1) + : producer_extent; + Array arrive_counts; + for (int i = 0; i < num_fwd; ++i) { + arrive_counts.push_back(fwd_arrive_count); + } + for (int i = 0; i < num_bp; ++i) { + arrive_counts.push_back(consumer_extent); + } + for (size_t i = 0; i < prelude_tma_plans.size(); ++i) { + arrive_counts.push_back(IntImm(DataType::Int(32), 1)); + } + + std::vector> prelude_waits_before_consumer( + consumer_compute_stmts.size()); + PrimExpr prelude_wait_guard = + needs_phase_counter ? EQ(consumer_phase_counter.value().Load(), + IntImm(DataType::Int(32), 0)) + : EQ(loop_var, loop_min); + int prelude_barrier_base = num_fwd + num_bp; + for (size_t i = 0; i < prelude_tma_plans.size(); ++i) { + PrimExpr barrier_id = IntImm(DataType::Int(32), prelude_barrier_base + i); + common_prelude_rewrites_.emplace( + prelude_tma_plans[i].stmt_node, + RewritePreludeTmaProducerStmt(prelude_tma_plans[i].stmt, barrier_buf, + barrier_id)); + int wait_pos = prelude_tma_plans[i].wait_pos; + ICHECK_GE(wait_pos, 0); + ICHECK_LT(wait_pos, static_cast(consumer_compute_stmts.size())); + prelude_waits_before_consumer[wait_pos].push_back(IfThenElse( + prelude_wait_guard, MakeParityWait(barrier_buf, barrier_id, + IntImm(DataType::Int(32), 0)))); + } + + // --- Build producer body --- + // Producer structure (mixed TMA + SIMT/cp.async): + // bp_wait → SIMT copies (all threads, async) → TMA copies (leader) → + // commit + cp_async_barrier_noinc. + // SIMT copies are placed after bp_wait but before TMA so cp.async + // and TMA can overlap. + + // First pass: collect SIMT/cp.async producer stmts separately. + Array simt_producer_stmts; + for (size_t i = 0; i < flat_stmts.size(); ++i) { + if (kinds[i] == TileStmtKind::kSimtProducer) { + // Annotate ForNodes with kParallelAsyncWithoutAsyncCommitWait so + // InjectPTXAsyncCopy (called from LowerTileOp) does not insert + // commit+wait — the WS pass will emit its own commit+barrier_noinc. + simt_producer_stmts.push_back( + SimtProducerAnnotator::Annotate(flat_stmts[i], target_)); + } else if (kinds[i] == TileStmtKind::kCpAsyncProducer) { + simt_producer_stmts.push_back(flat_stmts[i]); + } + } + + // Second pass: build the producer body with correct ordering. + Array producer_stmts; + int tma_idx = 0; + int last_tma_idx = num_producer_groups - 1; + bool simt_stmts_emitted = false; + for (size_t i = 0; i < flat_stmts.size(); ++i) { + if (kinds[i] == TileStmtKind::kTmaProducer) { + int barrier_group = can_merge_tma_barriers ? 0 : tma_idx; + int fwd_base = barrier_group * num_stages; + int bp_base = num_fwd + barrier_group * num_stages; + PrimExpr fwd_id = IntImm(DataType::Int(32), fwd_base) + p_stage_expr; + PrimExpr bp_id = IntImm(DataType::Int(32), bp_base) + p_stage_expr; + + // Back-pressure wait (only once when barriers are merged) + if (!can_merge_tma_barriers || tma_idx == 0) { + producer_stmts.push_back(MakeParityWait( + barrier_buf, bp_id, + bitwise_xor(p_parity_expr, IntImm(DataType::Int(32), 1)))); + } + + // After the first bp_wait, emit all SIMT/cp.async producers + // followed immediately by commit_group so the hardware can start + // the async transfers as early as possible, overlapping with TMA. + if (!simt_stmts_emitted && !simt_producer_stmts.empty()) { + for (const auto &s : simt_producer_stmts) { + producer_stmts.push_back(s); + } + // Commit cp.async group right after issuing — the earlier the + // commit, the more overlap with subsequent TMA loads. + if (has_simt_producer || has_cp_async_producer) { + producer_stmts.push_back(Evaluate( + Call(DataType::Handle(), builtin::ptx_commit_group(), {}))); + } + simt_stmts_emitted = true; + } + + for (const auto &stmt : producer_loop_prefix_stmts[tma_idx]) { + producer_stmts.push_back(stmt); + } + // Convert copy → tma_copy with barrier, or annotate non-copy + // TMA tile-ops (e.g. c2d_im2col) with barrier reference. + const auto *eval = flat_stmts[i].as(); + ICHECK(eval); + Call tile_call = Downcast(eval->value); + auto tile_op = ParseOperator(tile_call); + PrimExpr tma_call; + // For pure TMA, tell LowerTileOp to emit arrive inside the same + // tl_shuffle_elect block (via emit_arrive annotation), producing + // arrive_and_expect_tx instead of separate expect_tx + arrive. + // When merged barriers, only the last TMA copy should arrive. + bool emit_arrive_on_this = + !has_simt_producer && !has_cp_async_producer && + (!can_merge_tma_barriers || tma_idx == last_tma_idx); + + if (tile_op.defined() && tile_op.as()) { + tma_call = RewriteCopyToTmaCopy(tile_call, barrier_buf, fwd_id); } else { - fwd_bases[i] = fwd_bases[i - 1]; - } - } - num_total_fwd_barriers = - num_existing_loop_fwd_barriers + num_preloop_fwd_barriers; - for (size_t i = 0; i < extractor.blocks.size(); ++i) { - bp_bases[i] = - num_existing_loop_fwd_barriers + block_group[i] * num_stages; - } - pure_tma_preloop_fwd_base_ = - num_existing_loop_fwd_barriers + num_bp_barriers; - pure_tma_preloop_fwd_count_ = num_preloop_fwd_barriers; - pure_tma_preloop_fwd_cursor_ = 0; - total_barriers = num_total_fwd_barriers + num_bp_barriers; - } else { - // Mixed path: - // [0, num_existing_barriers) : pre-existing forward barriers - // [existing, total_fwd) : new cp.async forward barriers - // [total_fwd, total) : back-pressure barriers - num_total_fwd_barriers = - num_existing_barriers + num_new_cp_async_fwd_barriers; - int next_existing_tma_fwd_base = 0; - int next_cp_async_fwd_base = num_existing_barriers; - std::vector group_fwd_bases(num_block_groups, -1); - mixed_fwd_arrive_counts.assign(num_total_fwd_barriers, - IntImm(DataType::Int(32), 1)); - for (int group = 0; group < num_block_groups; ++group) { - if (group_has_tma[group]) { - group_fwd_bases[group] = next_existing_tma_fwd_base; - next_existing_tma_fwd_base += num_stages; + // Non-copy TMA producer (e.g. Conv2DIm2ColOp): annotate with + // barrier so Lower() uses the WS barrier instead of its own. + tma_call = AnnotateTileOpBarrier(tile_call, barrier_buf, fwd_id); + } + if (emit_arrive_on_this) { + auto call = Downcast(tma_call); + auto annos = call->annotations; + annos.Set("emit_arrive", IntImm(DataType::Int(32), 1)); + tma_call = Call(call->dtype, call->op, call->args, annos, call->span); + } + producer_stmts.push_back(Evaluate(tma_call)); + ++tma_idx; + } + // SIMT/cp.async producers are handled above (after first bp_wait). + // Consumer/Other statements are skipped in producer. + } + // Fallback: if there were no TMA producers to anchor the bp_wait, + // emit SIMT stmts now (shouldn't happen in the mixed path). + if (!simt_stmts_emitted && !simt_producer_stmts.empty()) { + for (const auto &s : simt_producer_stmts) { + producer_stmts.push_back(s); + } + } + // When any producer-side work is not single-threaded pure-TMA, all + // producer threads arrive on all forward barriers after finishing it. + // SIMT copies (later lowered to cp.async by InjectPTXAsyncCopy) and + // explicit cp.async groups use commit_group + cp_async_barrier_noinc + // so the async copy completion drives the mbarrier arrival, allowing + // TMA and cp.async to overlap. Other groups use MakeArriveBarrier. + if (has_simt_producer || has_cp_async_producer) { + // Any SIMT producer will become cp.async after LowerTileOp. + bool group_has_async_copy = has_simt_producer || has_cp_async_producer; + for (int g = 0; g < num_producer_groups; ++g) { + int fwd_base = g * num_stages; + PrimExpr fwd_id = IntImm(DataType::Int(32), fwd_base) + p_stage_expr; + if (group_has_async_copy) { + // Tie cp.async completion to the forward mbarrier. + // commit_group was already emitted right after the cp.async + // instructions (before TMA) to maximize overlap. + producer_stmts.push_back(Evaluate( + Call(DataType::Handle(), tl::ptx_cp_async_barrier_noinc(), + {MakeBarrierRef(barrier_buf, fwd_id)}))); } else { - ICHECK(group_has_cp_async[group]); - group_fwd_bases[group] = next_cp_async_fwd_base; - next_cp_async_fwd_base += num_stages; - } - PrimExpr group_arrive_count = IntImm(DataType::Int(32), 1); - if (group_has_cp_async[group]) { - group_arrive_count = producer_thread_extent; + producer_stmts.push_back(MakeArriveBarrier(barrier_buf, fwd_id)); } - for (int stage = 0; stage < num_stages; ++stage) { - mixed_fwd_arrive_counts[group_fwd_bases[group] + stage] = - group_arrive_count; - } - } - for (size_t i = 0; i < extractor.blocks.size(); ++i) { - fwd_bases[i] = group_fwd_bases[block_group[i]]; - bp_bases[i] = num_total_fwd_barriers + block_group[i] * num_stages; } - total_barriers = num_total_fwd_barriers + num_bp_barriers; - pure_tma_preloop_fwd_base_ = -1; - pure_tma_preloop_fwd_count_ = 0; - pure_tma_preloop_fwd_cursor_ = 0; } - - // Defensive check: ensure back-pressure barriers do not overlap - // any existing (forward/prologue) barrier ids in the original IR. - if (num_bp_barriers > 0 && !remap_pure_tma_barriers_) { - int existing_last = inferred_existing_required - 1; - int bp_begin = bp_bases.front(); - int bp_last = bp_begin + num_bp_barriers - 1; - ICHECK(bp_begin > existing_last) - << "ProducerConsumerWS: barrier id overlap detected. " - << "existing_last=" << existing_last << ", bp_begin=" << bp_begin - << ", bp_last=" << bp_last; + // Phase counter increment at end of producer guarded iteration + if (needs_phase_counter) { + producer_stmts.push_back(producer_phase_counter.value().Increment()); } - // Create barrier buffer early so loop body builders can use it. - barrier_buf_ = - CreateMBarrierBuffer(injected_mbarrier_name_, total_barriers); - - Var loop_var = pipeline_loop->loop_var; - PrimExpr loop_extent = pipeline_loop->extent; - PrimExpr loop_min = pipeline_loop->min; - - // Compute stage and parity expressions. - // When needs_phase_counter is true, the loop body is conditionally - // guarded and we use a mutable counter instead of the loop variable - // to derive stage/parity. Producer and consumer have separate - // counters because they run on different warp partitions. - PrimExpr linear_idx = loop_var - loop_min; - PrimExpr base_stage_expr = FloorMod(linear_idx, num_stages); - PrimExpr base_parity_expr = FloorMod(FloorDiv(linear_idx, num_stages), 2); - - PrimExpr producer_stage_expr = - needs_phase_counter ? producer_phase_counter->StageExpr(num_stages) - : base_stage_expr; - PrimExpr producer_parity_expr = - needs_phase_counter ? producer_phase_counter->ParityExpr(num_stages) - : base_parity_expr; - PrimExpr consumer_stage_expr = - needs_phase_counter ? consumer_phase_counter->StageExpr(num_stages) - : base_stage_expr; - PrimExpr consumer_parity_expr = - needs_phase_counter ? consumer_phase_counter->ParityExpr(num_stages) - : base_parity_expr; - - // --- Build Producer Body --- - Array producer_body_stmts; - for (size_t ti = 0; ti < extractor.blocks.size(); ti++) { - const auto &tma = extractor.blocks[ti]; - int group = block_group[ti]; - bool is_first_in_group = - ti == 0 || block_group[ti] != block_group[ti - 1]; - bool is_last_in_group = ti + 1 == extractor.blocks.size() || - block_group[ti] != block_group[ti + 1]; - PrimExpr bp_id = - IntImm(DataType::Int(32), bp_bases[ti]) + producer_stage_expr; - - // Back-pressure wait: producer cannot reuse the stage buffer until the - // consumer releases it. xor(parity, 1) bootstraps the first iteration. - if (is_first_in_group) { - producer_body_stmts.push_back(WrapStmtWithGuardSource( - protocol_guard_sources[ti], protocol_guards[ti], - makeParityWait(barrier_buf_, bp_id, - bitwise_xor(producer_parity_expr, 1)))); - for (const auto &stmt : producer_loop_prefix_stmts[ti]) { - producer_body_stmts.push_back(stmt); - } + // --- Build consumer body --- + // When barriers are merged, iterate over a single effective group. + int consumer_barrier_groups = + can_merge_tma_barriers ? 1 : num_producer_groups; + Array consumer_stmts; + std::vector arrive_emitted(consumer_barrier_groups, false); + for (size_t ci = 0; ci < consumer_compute_stmts.size(); ++ci) { + for (const auto &stmt : prelude_waits_before_consumer[ci]) { + consumer_stmts.push_back(stmt); } - - Stmt producer_stmt = ws_producer_stmts[ti]; - if (tma.kind == AsyncProducerKind::kTma) { - ICHECK_GE(fwd_bases[ti], 0); - PrimExpr barrier_id = - IntImm(DataType::Int(32), fwd_bases[ti]) + producer_stage_expr; - if (use_full_tma_forward_barrier_protocol_) { - // Pure-TMA WS uses a full producer-side release protocol so the - // consumer waits on a barrier owned by the producer branch. - producer_stmt = RewriteTmaForwardProducerStmt( - producer_stmt, barrier_id, is_last_in_group); - } else { - // Mixed groups keep the original producer-side TMA protocol, but - // rebind grouped loads onto a shared forward barrier. If the group - // also contains cp.async, let cp.async.mbarrier.arrive.noinc own the - // arrival count so the shared forward barrier stays on the producer - // thread extent instead of adding an extra leader-only arrive. - producer_stmt = RewriteTmaStmtBarrierIdPreserveProtocol( - producer_stmt, barrier_id, group_has_cp_async[group]); - } - // Keep expect/load under the same elected lane when lowering has - // emitted them as adjacent identical IfThenElse wrappers. - producer_stmt = MergeAdjacentEquivalentIfs(producer_stmt); - } - - // Execute the producer statement. - producer_body_stmts.push_back(producer_stmt); - if (is_last_in_group && group_has_cp_async[group]) { - ICHECK_GE(fwd_bases[ti], 0); - PrimExpr fwd_id = - IntImm(DataType::Int(32), fwd_bases[ti]) + producer_stage_expr; - producer_body_stmts.push_back(WrapStmtWithGuardSource( - producer_issue_guard_sources[ti], producer_issue_guards[ti], - makeCpAsyncBarrierNoInc(barrier_buf_, fwd_id))); - } - // Phase counter increment – exactly once per guarded iteration, - // after ALL groups have issued their barrier ops. - // MergeAdjacentEquivalentIfs will fold this into the same guard. - if (needs_phase_counter && ti + 1 == extractor.blocks.size()) { - producer_body_stmts.push_back(WrapStmtWithGuardSource( - uniform_phase_guard_source, uniform_phase_guard, - producer_phase_counter->Increment())); - } - } - Stmt producer_loop_body = - MergeAdjacentEquivalentIfs(SeqStmt(producer_body_stmts)); - producer_loop_body = rewrap_loop_body_lets(producer_loop_body); - - // --- Build Consumer Body --- - Array consumer_body_stmts; - - // Place forward waits at first use and back-pressure arrives at last use. - // If we cannot prove the dependency, fall back to wait-at-head / - // arrive-at-tail. - std::vector arrive_emitted(extractor.blocks.size(), false); - std::vector normalized_waits; - normalized_waits.reserve(extractor.blocks.size()); - for (size_t ti = 0; ti < extractor.blocks.size(); ++ti) { - ICHECK_GE(fwd_bases[ti], 0); - PrimExpr fwd_id = - IntImm(DataType::Int(32), fwd_bases[ti]) + consumer_stage_expr; - if (ws_wait_stmts[ti].defined()) { - normalized_waits.push_back(RewriteWaitBarrier( - ws_wait_stmts[ti].value(), fwd_id, consumer_parity_expr)); - } else { - normalized_waits.push_back(WrapStmtWithGuardSource( - producer_issue_guard_sources[ti], producer_issue_guards[ti], - makeParityWait(barrier_buf_, fwd_id, consumer_parity_expr))); - } - } - // Emit waits / compute / arrives according to insertion points. - for (size_t ci = 0; ci < extractor.compute_stmts.size(); ++ci) { - for (size_t ti = 0; ti < extractor.blocks.size(); ++ti) { - bool is_first_in_group = - ti == 0 || block_group[ti] != block_group[ti - 1]; - if (is_first_in_group && wait_insert_pos[ti] == static_cast(ci)) { - consumer_body_stmts.push_back(normalized_waits[ti]); + for (int g = 0; g < consumer_barrier_groups; ++g) { + if (wait_insert_pos[g] == static_cast(ci)) { + int fwd_base = g * num_stages; + PrimExpr fwd_id = IntImm(DataType::Int(32), fwd_base) + c_stage_expr; + consumer_stmts.push_back( + MakeParityWait(barrier_buf, fwd_id, c_parity_expr)); } } if (!moved_compute_stmts[ci]) { - consumer_body_stmts.push_back(extractor.compute_stmts[ci]); - } - for (size_t ti = 0; ti < extractor.blocks.size(); ++ti) { - bool is_last_in_group = ti + 1 == extractor.blocks.size() || - block_group[ti] != block_group[ti + 1]; - if (is_last_in_group && - arrive_insert_pos[ti] == static_cast(ci + 1)) { - PrimExpr bp_id = - IntImm(DataType::Int(32), bp_bases[ti]) + consumer_stage_expr; - consumer_body_stmts.push_back(WrapStmtWithGuardSource( - protocol_guard_sources[ti], protocol_guards[ti], - makeArriveBarrier(barrier_buf_, bp_id))); - arrive_emitted[ti] = true; + consumer_stmts.push_back(consumer_compute_stmts[ci]); + } + for (int g = 0; g < consumer_barrier_groups; ++g) { + if (arrive_insert_pos[g] == static_cast(ci + 1)) { + int bp_base = num_fwd + g * num_stages; + PrimExpr bp_id = IntImm(DataType::Int(32), bp_base) + c_stage_expr; + consumer_stmts.push_back(MakeArriveBarrier(barrier_buf, bp_id)); + arrive_emitted[g] = true; } } } - - // Handle degenerate loops with no compute statements. - if (extractor.compute_stmts.empty()) { - for (size_t ti = 0; ti < extractor.blocks.size(); ++ti) { - bool is_first_in_group = - ti == 0 || block_group[ti] != block_group[ti - 1]; - if (is_first_in_group) { - consumer_body_stmts.push_back(normalized_waits[ti]); - } + if (consumer_compute_stmts.empty()) { + for (int g = 0; g < consumer_barrier_groups; ++g) { + int fwd_base = g * num_stages; + PrimExpr fwd_id = IntImm(DataType::Int(32), fwd_base) + c_stage_expr; + consumer_stmts.push_back( + MakeParityWait(barrier_buf, fwd_id, c_parity_expr)); } } - - // Emit loop-tail arrives (blocks with unknown deps or tail use). - for (size_t ti = 0; ti < extractor.blocks.size(); ti++) { - bool is_last_in_group = ti + 1 == extractor.blocks.size() || - block_group[ti] != block_group[ti + 1]; - if (is_last_in_group && !arrive_emitted[ti] && - arrive_insert_pos[ti] == - static_cast(extractor.compute_stmts.size())) { - PrimExpr bp_id = - IntImm(DataType::Int(32), bp_bases[ti]) + consumer_stage_expr; - consumer_body_stmts.push_back(WrapStmtWithGuardSource( - protocol_guard_sources[ti], protocol_guards[ti], - makeArriveBarrier(barrier_buf_, bp_id))); + for (int g = 0; g < consumer_barrier_groups; ++g) { + if (!arrive_emitted[g] && + arrive_insert_pos[g] == + static_cast(consumer_compute_stmts.size())) { + int bp_base = num_fwd + g * num_stages; + PrimExpr bp_id = IntImm(DataType::Int(32), bp_base) + c_stage_expr; + consumer_stmts.push_back(MakeArriveBarrier(barrier_buf, bp_id)); } } - // Phase counter increment for the consumer side. - if (needs_phase_counter) { - consumer_body_stmts.push_back(WrapStmtWithGuardSource( - uniform_phase_guard_source, uniform_phase_guard, - consumer_phase_counter->Increment())); - } - Stmt consumer_loop_body = - MergeAdjacentEquivalentIfs(SeqStmt(consumer_body_stmts)); - consumer_loop_body = rewrap_loop_body_lets(consumer_loop_body); - - // --- Replace shared-memory stage expressions with phase counters --- - // When the loop body is conditionally guarded, the barrier IDs already - // use phase-counter-based stage/parity, but the shared-memory buffer - // offsets still embed FloorMod(loop_var - loop_min, num_stages). - // Rewrite them so that buffer staging stays in sync with barriers. + // Phase counter increment at end of consumer guarded iteration if (needs_phase_counter) { - producer_loop_body = StageExprReplacer::Replace( - producer_loop_body, loop_var, loop_min, num_stages, - producer_phase_counter->StageExpr(num_stages)); - consumer_loop_body = StageExprReplacer::Replace( - consumer_loop_body, loop_var, loop_min, num_stages, - consumer_phase_counter->StageExpr(num_stages)); + consumer_stmts.push_back(consumer_phase_counter.value().Increment()); } - // --- Build the loops --- - // Remove pipeline annotations since WS handles overlap directly + // --- Wrap with let bindings and optional condition --- + auto wrap_lets = + [&](Stmt body, + const std::vector> &bindings) -> Stmt { + for (auto it = bindings.rbegin(); it != bindings.rend(); ++it) { + body = LetStmt(it->first, it->second, body); + } + return body; + }; + + Stmt producer_body = wrap_lets(SeqStmt(producer_stmts), inner_let_bindings); + Stmt consumer_body = wrap_lets(SeqStmt(consumer_stmts), inner_let_bindings); + + // Wrap in original condition if the loop body was guarded. + if (loop_body_condition.defined()) { + producer_body = IfThenElse(loop_body_condition.value(), producer_body); + consumer_body = IfThenElse(loop_body_condition.value(), consumer_body); + } + + producer_body = wrap_lets(producer_body, outer_let_bindings); + consumer_body = wrap_lets(consumer_body, outer_let_bindings); + + // Rewrite shared-buffer stage indices from loop-var-based to + // counter-based so they stay in sync with barrier parity. + if (needs_phase_counter) { + producer_body = StageExprReplacer::Replace( + producer_body, loop_var, loop_min, num_stages, + producer_phase_counter.value().StageExpr(num_stages)); + consumer_body = StageExprReplacer::Replace( + consumer_body, loop_var, loop_min, num_stages, + consumer_phase_counter.value().StageExpr(num_stages)); + } + producer_body = + TileOpMbarPhaseAnnotator::Annotate(producer_body, p_parity_expr); + consumer_body = + TileOpMbarPhaseAnnotator::Annotate(consumer_body, c_parity_expr); + + // --- Build loops (strip pipeline annotations) --- + // WS handles pipeline overlap via barriers, so strip all pipeline- + // related annotations to prevent PipelinePlanning / InjectSoftware- + // Pipeline from re-pipelining the already WS-transformed loops. Map loop_annos; for (const auto &[key, value] : pipeline_loop->annotations) { if (key != "num_stages" && key != "tl_pipeline_order" && @@ -1704,2331 +1666,734 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { } } - Stmt producer_loop = - For(loop_var, loop_min, loop_extent, ForKind::kSerial, - producer_loop_body, Optional(), loop_annos); - Stmt consumer_loop = - For(loop_var, loop_min, loop_extent, ForKind::kSerial, - consumer_loop_body, Optional(), loop_annos); + For producer_loop(loop_var, loop_min, loop_extent, ForKind::kSerial, + producer_body, Optional(), loop_annos); + For consumer_loop(loop_var, loop_min, loop_extent, ForKind::kSerial, + consumer_body, Optional(), loop_annos); // Wrap loops with phase counter allocation when needed. + Stmt final_producer_loop = producer_loop; + Stmt final_consumer_loop = consumer_loop; if (needs_phase_counter) { - producer_loop = producer_phase_counter->WrapLoopWithAlloc(producer_loop); - consumer_loop = consumer_phase_counter->WrapLoopWithAlloc(consumer_loop); - } - - // Rewrite threadIdx.x in producer: threadIdx.x -> threadIdx.x - - // consumer_threads Also converts `if (threadIdx.x == 0)` to `if - // (tl_shuffle_elect(extent))` - producer_loop = PCThreadIdxRewriter::Rewrite( - producer_loop, thread_iv_->var, - thread_iv_->var - consumer_thread_extent, producer_thread_extent, - /*do_shuffle=*/true); - consumer_loop = PCThreadIdxRewriter::Rewrite( - consumer_loop, thread_iv_->var, thread_iv_->var, consumer_thread_extent, - /*do_shuffle=*/true); - - // Wrap in IfThenElse: producer if threadIdx.x >= consumer_threads - Stmt ws_body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent), - producer_loop, consumer_loop); - - // Add warp specialization scope attribute - Array ws_partition = {Downcast(producer_thread_extent), - Downcast(consumer_thread_extent)}; - ws_body = - AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, ws_body); - - // Forward barriers are producer-owned; back-pressure barriers are released - // by the full consumer partition. - Array barrier_arrive_counts; - barrier_arrive_counts.reserve(total_barriers); - if (remap_pure_tma_barriers_) { - for (int i = 0; i < num_existing_loop_fwd_barriers; ++i) { - barrier_arrive_counts.push_back(IntImm(DataType::Int(32), 1)); - } - for (int i = 0; i < num_bp_barriers; ++i) { - barrier_arrive_counts.push_back(consumer_thread_extent); - } - for (int i = 0; i < num_preloop_fwd_barriers; ++i) { - barrier_arrive_counts.push_back(IntImm(DataType::Int(32), 1)); - } - } else { - ICHECK_EQ(mixed_fwd_arrive_counts.size(), - static_cast(num_total_fwd_barriers)); - for (const auto &count : mixed_fwd_arrive_counts) { - barrier_arrive_counts.push_back(count); - } - for (int i = 0; i < num_bp_barriers; i++) { - barrier_arrive_counts.push_back(consumer_thread_extent); - } - } - // barrier_arrive_counts will be used for the barrier_init annotation. - - LocalLiveSet producer_live_seed = - SeedLiveSetFromStmt(producer_loop_body, buffer_data_to_buffer); - LocalLiveSet consumer_live_seed = - SeedLiveSetFromStmt(consumer_loop_body, buffer_data_to_buffer); - // Pre-loop liveness assignment must also account for variables used only in - // the pipeline loop bounds. Otherwise scalar setup that feeds the loop - // extent/min can be misclassified as common code and hoisted outside the - // warp-specialized split. - producer_live_seed.AddUses( - LocalAccessCollector::CollectExpr(loop_min, buffer_data_to_buffer)); - producer_live_seed.AddUses( - LocalAccessCollector::CollectExpr(loop_extent, buffer_data_to_buffer)); - consumer_live_seed.AddUses( - LocalAccessCollector::CollectExpr(loop_min, buffer_data_to_buffer)); - consumer_live_seed.AddUses( - LocalAccessCollector::CollectExpr(loop_extent, buffer_data_to_buffer)); - - // Reconstruct block body: replace the pipeline loop with ws_body - // and remove old barrier_init annotations / shared.barrier buffers. - Stmt new_block_body = RebuildBlockBody( - orig_block->body, pipeline_loop, ws_body, buffer_data_to_buffer, - producer_live_seed, consumer_live_seed); - - // Update thread extent - num_threads_ = consumer_thread_extent + producer_thread_extent; - ws_transformed_ = true; - use_full_tma_forward_barrier_protocol_ = - old_use_full_tma_forward_barrier_protocol; - remap_pure_tma_barriers_ = old_remap_pure_tma_barriers; - pure_tma_preloop_fwd_base_ = old_pure_tma_preloop_fwd_base; - pure_tma_preloop_fwd_count_ = old_pure_tma_preloop_fwd_count; - pure_tma_preloop_fwd_cursor_ = old_pure_tma_preloop_fwd_cursor; - current_loop_guard_bindings_ = std::move(saved_loop_guard_bindings); - - // Build the new Block and BlockRealize. - // Add barrier buffer to alloc_buffers and barrier_init annotation. - Array new_alloc_buffers = orig_block->alloc_buffers; - // Remove any existing shared.barrier buffers from old approach + final_producer_loop = + producer_phase_counter.value().WrapLoopWithAlloc(producer_loop); + final_consumer_loop = + consumer_phase_counter.value().WrapLoopWithAlloc(consumer_loop); + } + + // --- Rewrite threadIdx.x for producer partition --- + // Producer: threadIdx.x - consumer_extent (maps to [0, producer_extent)) + Stmt rewritten_producer = PCThreadIdxRewriter::Rewrite( + final_producer_loop, thread_iv_->var, thread_iv_->var - consumer_extent, + producer_extent, false); + // Consumer: threadIdx.x stays, but extent is consumer_extent + Stmt rewritten_consumer = final_consumer_loop; + + producer_prelude_live_seed_ = {}; + consumer_prelude_live_seed_ = {}; + producer_prelude_live_seed_.AddUses(LocalAccessCollector::Collect( + rewritten_producer, buffer_data_to_buffer_)); + consumer_prelude_live_seed_.AddUses(LocalAccessCollector::Collect( + rewritten_consumer, buffer_data_to_buffer_)); + + // Move pre-loop branch-private initialization next to the branch that + // consumes it. Classification is based on downstream producer/consumer + // uses of the values defined by each prelude statement. + extracted_producer_init_ = {}; + extracted_consumer_init_ = {}; + + Array ws_partition = {Downcast(producer_extent), + Downcast(consumer_extent)}; + + // First pass: find and extract consumer-only pre-loop statements + // by doing a dry replacement that populates extracted_consumer_init_. + Stmt dummy_producer = rewritten_producer; + const Stmt &dummy_consumer = rewritten_consumer; + Stmt dummy_ws = IfThenElse(GE(thread_iv_->var, consumer_extent), + dummy_producer, dummy_consumer); + dummy_ws = + AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, dummy_ws); + ReplaceResult replaced = ReplacePipelineLoopInStmt( + orig_block->body, pipeline_loop, dummy_ws, consumer_extent); + + // Producer and consumer partitions cannot safely share the same block-level + // local/fragment buffers after tiled WS is introduced before + // LayoutInference: a single fragment layout cannot represent both thread + // ranges. Clone every branch-private buffer touched by the producer so + // LayoutInference can infer an independent producer-side thread range. + BufferNodeMap producer_buffer_remap; + Array producer_private_buffers; { - Array filtered; - for (const auto &buf : new_alloc_buffers) { - if (buf.scope() != "shared.barrier") { - filtered.push_back(buf); + std::unordered_set block_alloc_buffers; + for (const auto &buffer : orig_block->alloc_buffers) { + block_alloc_buffers.insert(buffer.get()); + } + LocalAccessSummary producer_access = LocalAccessCollector::Collect( + rewritten_producer, buffer_data_to_buffer_); + for (const auto &stmt : extracted_producer_init_) { + MergeLocalAccessSummary( + &producer_access, + LocalAccessCollector::Collect(stmt, buffer_data_to_buffer_)); + } + auto maybe_clone = [&](const Buffer &buffer) { + if (!buffer.defined() || + !(IsFragmentBuffer(buffer) || IsLocalBuffer(buffer, true)) || + !block_alloc_buffers.count(buffer.get()) || + producer_buffer_remap.count(buffer.get())) { + return; } - } - new_alloc_buffers = filtered; - } - new_alloc_buffers.push_back(barrier_buf_); - - Map new_annotations = orig_block->annotations; - // Remove any old barrier_init and build fresh - Map> barrier_init_map; - barrier_init_map.Set(barrier_buf_->data, barrier_arrive_counts); - new_annotations.Set("barrier_init", barrier_init_map); - - Block new_block(orig_block->iter_vars, orig_block->reads, - orig_block->writes, orig_block->name_hint, new_block_body, - orig_block->init, new_alloc_buffers, - orig_block->match_buffers, new_annotations); - return BlockRealize(op->iter_values, op->predicate, new_block); - } - - // Handle ForNode with thread bindings - Stmt VisitStmt_(const ForNode *op) final { - if (op->kind == ForKind::kThreadBinding && op->thread_binding.defined() && - op->thread_binding.value()->thread_tag == "threadIdx.x" && - !thread_iv_.defined()) { - thread_iv_ = op->thread_binding.value(); - Optional old_num_threads = num_threads_; - num_threads_ = std::nullopt; - For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); - if (num_threads_.defined()) { - PrimExpr num_threads = num_threads_.value(); - auto n = for_node.CopyOnWrite(); - n->extent = num_threads; - IterVar new_thread_iv = n->thread_binding.value(); - new_thread_iv.CopyOnWrite()->dom = - Range::FromMinExtent(Integer(0), num_threads); - n->thread_binding = new_thread_iv; - } - num_threads_ = old_num_threads; - thread_iv_ = {}; - return for_node; + Buffer cloned = CloneBranchPrivateBuffer(buffer, "_producer_ws"); + producer_buffer_remap.emplace(buffer.get(), cloned); + producer_private_buffers.push_back(cloned); + }; + for (const auto &buffer : producer_access.read_buffers) { + maybe_clone(buffer); + } + for (const auto &buffer : producer_access.write_buffers) { + maybe_clone(buffer); + } + } + if (!producer_buffer_remap.empty()) { + rewritten_producer = + BufferRemapper::Rewrite(rewritten_producer, producer_buffer_remap); + Array remapped_producer_init; + for (const auto &stmt : extracted_producer_init_) { + remapped_producer_init.push_back( + BufferRemapper::Rewrite(stmt, producer_buffer_remap)); + } + extracted_producer_init_ = remapped_producer_init; + } + + // If branch-local prelude init/copy was extracted, rebuild with it inside + // the corresponding WS branch so each branch initializes its own local + // state before entering the pipelined loop. + if (!extracted_producer_init_.empty() || + !extracted_consumer_init_.empty()) { + Stmt enriched_producer = rewritten_producer; + if (!extracted_producer_init_.empty()) { + Array producer_parts; + for (const auto &s : extracted_producer_init_) { + producer_parts.push_back(PCThreadIdxRewriter::Rewrite( + s, thread_iv_->var, thread_iv_->var - consumer_extent, + producer_extent, false)); + } + producer_parts.push_back(rewritten_producer); + enriched_producer = producer_parts.size() == 1 + ? producer_parts[0] + : SeqStmt(producer_parts); + } + Array consumer_parts; + for (const auto &s : extracted_consumer_init_) { + consumer_parts.push_back(s); + } + consumer_parts.push_back(rewritten_consumer); + Stmt enriched_consumer = consumer_parts.size() == 1 + ? consumer_parts[0] + : SeqStmt(consumer_parts); + Stmt scoped_producer = enriched_producer; + const Stmt &scoped_consumer = enriched_consumer; + Stmt ws_body = IfThenElse(GE(thread_iv_->var, consumer_extent), + scoped_producer, scoped_consumer); + ws_body = + AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, ws_body); + // Second pass: replace again with the enriched WS body. + // extracted_consumer_init_ is already empty (stmts were removed + // from the prelude in the first pass result). + // We need to replace in the ALREADY-modified body from pass 1. + // But ReplacePipelineLoopInStmt finds the pipeline_loop by + // pointer comparison, which won't match in the modified tree. + // Instead, just substitute the dummy_ws in the replaced result. + // Since dummy_ws appears exactly once in replaced.stmt, do a + // simple statement replacement on the full placeholder stmt. + class SubstWsBody : public StmtExprMutator { + public: + SubstWsBody(const Stmt &old_ws, const Stmt &new_ws) + : old_(old_ws), new_(new_ws) {} + Stmt VisitStmt(const Stmt &stmt) final { + if (stmt.same_as(old_)) { + return new_; + } + return StmtExprMutator::VisitStmt(stmt); + } + Stmt old_, new_; + }; + SubstWsBody subst(dummy_ws, ws_body); + replaced.stmt = subst(replaced.stmt); } + ICHECK(replaced.found) + << "ProducerConsumerWS: failed to replace pipeline loop"; + Stmt new_block_body = SinkGuardedConsumerPostlude::Rewrite( + replaced.stmt, thread_iv_->var, consumer_extent); - For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); - if (for_node->kind == ForKind::kThreadBinding && thread_iv_.defined()) { - ICHECK(for_node->thread_binding.defined()); - String thread_tag = for_node->thread_binding.value()->thread_tag; - if (thread_tag == "threadIdx.x") { - Var thread_v = Downcast(for_node->loop_var); - Stmt new_body = PCThreadIdxRewriter::Rewrite(for_node->body, thread_v, - thread_iv_->var, 0); - return new_body; - } + // --- Update block --- + Block new_block = orig_block; + auto *block_ptr = new_block.CopyOnWrite(); + block_ptr->body = new_block_body; + for (const auto &buffer : producer_private_buffers) { + block_ptr->alloc_buffers.push_back(buffer); } - return for_node; - } - // --------------------------------------------------------------------------- - // Utility methods - // --------------------------------------------------------------------------- + // Add barrier buffer to alloc_buffers. + block_ptr->alloc_buffers.push_back(barrier_buf); - void FlattenSeqStmt(const Stmt &s, Array *out) { - if (auto *seq = s.as()) { - for (const auto &sub : seq->seq) { - FlattenSeqStmt(sub, out); + // Add barrier_init annotation. + Map> barrier_init_map; + barrier_init_map.Set(barrier_buf->data, arrive_counts); + auto ann = block_ptr->annotations; + if (ann.count("barrier_init")) { + auto existing = + Downcast>>(ann.Get("barrier_init").value()); + for (auto [k, v] : existing) { + barrier_init_map.Set(k, v); } - } else { - out->push_back(s); } - } - - struct BufferDataAccessInfo { - bool read{false}; - bool write{false}; - - bool HasAnyAccess() const { return read || write; } - }; - - BufferDataAccessInfo - AnalyzeBufferDataAccess(const Stmt &stmt, const Var &buffer_data, - const BufferDataToBufferMap &buffer_map) const { - class BufferDataAccessDetector : public StmtExprVisitor { - public: - BufferDataAccessDetector(const Var &buffer_data, - const BufferDataToBufferMap &buffer_map) - : buffer_data_(buffer_data), buffer_map_(buffer_map) {} - - BufferDataAccessInfo Result() const { return result_; } - - private: - void VisitExpr_(const BufferLoadNode *op) final { - if (op->buffer->data.same_as(buffer_data_)) { - result_.read = true; - } - StmtExprVisitor::VisitExpr_(op); - } - - void VisitStmt_(const BufferStoreNode *op) final { - if (op->buffer->data.same_as(buffer_data_)) { - result_.write = true; - } - StmtExprVisitor::VisitStmt_(op); - } - - void VisitExpr_(const CallNode *op) final { - if (op->op.same_as(tl::access_ptr())) { - ICHECK_EQ(op->args.size(), 3); - const auto *base_load = op->args[0].as(); - ICHECK(base_load); - if (base_load->buffer->data.same_as(buffer_data_)) { - MarkAccess(op->args[2]); - } - for (const auto &index : base_load->indices) { - VisitExpr(index); - } - VisitExpr(op->args[1]); - return; - } - - if (op->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(op->args.size(), 5); - const auto *var = op->args[1].as(); - ICHECK(var); - auto it = buffer_map_.find(GetRef(var)); - if (it != buffer_map_.end() && - it->second->data.same_as(buffer_data_)) { - MarkAccess(op->args[4]); - } - VisitExpr(op->args[2]); - VisitExpr(op->args[3]); - return; - } - - StmtExprVisitor::VisitExpr_(op); - } + ann.Set("barrier_init", barrier_init_map); + block_ptr->annotations = std::move(ann); - void MarkAccess(const PrimExpr &rw_expr) { - int rw_mask = 3; - if (const auto *imm = rw_expr.as()) { - rw_mask = static_cast(imm->value); - } - if (rw_mask & 1) { - result_.read = true; - } - if (rw_mask & 2) { - result_.write = true; - } - } - - Var buffer_data_; - const BufferDataToBufferMap &buffer_map_; - BufferDataAccessInfo result_; - }; + // Update thread extent at the tiled WS level so LayoutInference sees + // the producer branch as live and can analyze explicit TMA copies. + num_threads_ = consumer_extent + producer_extent; + ws_transformed_ = true; - BufferDataAccessDetector detector(buffer_data, buffer_map); - detector(stmt); - return detector.Result(); + // Rebuild BlockRealize. + BlockRealize new_realize = GetRef(orig_realize); + new_realize.CopyOnWrite()->block = new_block; + return new_realize; } - const ForNode *FindAnnotatedPipelineLoop(const Stmt &stmt) { + // --- Find the first For loop with num_stages annotation --- + const ForNode *FindPipelineLoop(const Stmt &stmt) { if (auto *for_node = stmt.as()) { if (for_node->annotations.Get("num_stages")) { return for_node; } } + // Walk through SeqStmt, LetStmt, etc. if (auto *seq = stmt.as()) { - for (const auto &s : seq->seq) { - if (auto *result = FindAnnotatedPipelineLoop(s)) { + for (const Stmt &s : seq->seq) { + if (auto *result = FindPipelineLoop(s)) { return result; } } - return nullptr; } - if (auto *if_stmt = stmt.as()) { - if (auto *result = FindAnnotatedPipelineLoop(if_stmt->then_case)) { - return result; - } - if (if_stmt->else_case.defined()) { - return FindAnnotatedPipelineLoop(if_stmt->else_case.value()); - } - return nullptr; + if (auto *let = stmt.as()) { + return FindPipelineLoop(let->body); } if (auto *realize = stmt.as()) { - return FindAnnotatedPipelineLoop(realize->block->body); + return FindPipelineLoop(realize->block->body); } if (auto *block = stmt.as()) { - return FindAnnotatedPipelineLoop(block->body); + return FindPipelineLoop(block->body); } if (auto *attr = stmt.as()) { - return FindAnnotatedPipelineLoop(attr->body); - } - if (auto *let_s = stmt.as()) { - return FindAnnotatedPipelineLoop(let_s->body); + return FindPipelineLoop(attr->body); } return nullptr; } - // Infer how many mbarriers are already referenced by this block body. - // This prevents assigning back-pressure barriers that alias existing - // forward barriers (e.g. prologue TMA copy barriers outside the pipeline). - int InferMinRequiredBarrierCount(const Stmt &stmt) { - class GetMbarrierMaxIdxCollector : public StmtExprVisitor { - public: - int max_idx{-1}; - bool has_unbounded{false}; - - private: - void VisitStmt_(const ForNode *op) final { - // Bind loop variable range so expressions like (k + c) can be bounded. - analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); - StmtExprVisitor::VisitStmt_(op); - } - - void VisitExpr_(const BufferLoadNode *op) final { - if (op->buffer.scope() == "shared.barrier" && op->indices.size() == 1) { - auto bound = analyzer_.const_int_bound(op->indices[0]); - if (bound->max_value != arith::ConstIntBound::kPosInf && - bound->max_value != arith::ConstIntBound::kNegInf) { - max_idx = std::max(max_idx, static_cast(bound->max_value)); - } else { - has_unbounded = true; - } - } - StmtExprVisitor::VisitExpr_(op); - } - arith::Analyzer analyzer_; - }; - - GetMbarrierMaxIdxCollector collector; - collector(stmt); - ICHECK(!collector.has_unbounded) - << "ProducerConsumerWS: cannot infer finite upper bound for existing " - << "mbarrier id expressions. Refusing to allocate back-pressure " - << "barriers to avoid id overlap."; - return collector.max_idx + 1; - } + struct ReplaceResult { + Stmt stmt; + bool found{false}; + }; - int CountRewrittenPureTmaPreloopForwardPairs(const Stmt &stmt, - const ForNode *target_loop) { - if (stmt.as() == target_loop) { - return 0; + class SinkGuardedConsumerPostlude : public StmtExprMutator { + public: + static Stmt Rewrite(const Stmt &stmt, Var thread_var, + PrimExpr consumer_extent) { + SinkGuardedConsumerPostlude sinker(std::move(thread_var), + std::move(consumer_extent)); + return sinker.VisitStmt(stmt); } - if (auto *seq = stmt.as()) { - Array pre_loop_stmts; - bool found_loop = false; - int nested_count = 0; - for (const auto &s : seq->seq) { - if (!found_loop && ContainsLoop(s, target_loop)) { - nested_count = - CountRewrittenPureTmaPreloopForwardPairs(s, target_loop); - found_loop = true; - } else if (!found_loop) { - pre_loop_stmts.push_back(s); - } - } - if (!found_loop) { - return 0; - } - size_t movable_begin = pre_loop_stmts.size(); - while (movable_begin > 0 && - IsMovableConsumerPrefixStmt(pre_loop_stmts[movable_begin - 1])) { - --movable_begin; - } + private: + SinkGuardedConsumerPostlude(Var thread_var, PrimExpr consumer_extent) + : thread_var_(std::move(thread_var)), + consumer_extent_(std::move(consumer_extent)) {} - int local_count = 0; - for (size_t i = 0; i < movable_begin; ++i) { - if (ExtractTmaProducerWaitPair(pre_loop_stmts[i]).has_value()) { - ++local_count; - continue; - } - if (ExtractFlatTmaProducerClusterBeforeWait(pre_loop_stmts, - static_cast(i)) - .has_value()) { - ++local_count; - continue; - } - if (i + 1 < movable_begin && ContainsTmaLoadStmt(pre_loop_stmts[i]) && - IsMbarrierWaitParityStmt(pre_loop_stmts[i + 1])) { - ++local_count; - ++i; - } - } - return nested_count + local_count; - } - if (auto *if_stmt = stmt.as()) { - int then_count = CountRewrittenPureTmaPreloopForwardPairs( - if_stmt->then_case, target_loop); - if (if_stmt->else_case.defined()) { - return then_count + CountRewrittenPureTmaPreloopForwardPairs( - if_stmt->else_case.value(), target_loop); - } - return then_count; - } - if (auto *attr = stmt.as()) { - return CountRewrittenPureTmaPreloopForwardPairs(attr->body, target_loop); - } - if (auto *let_s = stmt.as()) { - return CountRewrittenPureTmaPreloopForwardPairs(let_s->body, target_loop); + static bool SameExpr(const PrimExpr &lhs, const PrimExpr &rhs) { + return ExprDeepEqual()(lhs, rhs); } - if (auto *realize = stmt.as()) { - return CountRewrittenPureTmaPreloopForwardPairs(realize->block->body, - target_loop); - } - if (auto *block = stmt.as()) { - return CountRewrittenPureTmaPreloopForwardPairs(block->body, target_loop); - } - return 0; - } - - // Single source of truth for barrier/TMA control-like calls that should not - // be moved across producer/consumer partition boundaries. - bool IsBarrierOrTmaControlCall(const CallNode *call) { - return call->op.same_as(mbarrier_wait_parity()) || - call->op.same_as(mbarrier_expect_tx()) || - call->op.same_as(builtin::ptx_arrive_barrier()) || - call->op.same_as(tl::ptx_arrive_cluster_barrier()) || - call->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || - call->op.same_as(builtin::ptx_cp_async_barrier()) || - call->op.same_as(tl::ptx_cp_async_barrier_noinc()) || - call->op.same_as(tma_load()) || - call->op.same_as(tma_load_im2col()) || - call->op.same_as(tma_store()) || - call->op.same_as(tma_store_arrive()) || - call->op.same_as(tma_store_wait()) || - call->op.same_as(builtin::tvm_storage_sync()); - } - - bool IsMovableConsumerPrefixStmt(const Stmt &stmt) { - bool has_disallowed = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (has_disallowed) { - return; - } - if (auto *call = node.as()) { - if (IsBarrierOrTmaControlCall(call)) { - has_disallowed = true; - return; - } - } - if (auto *ld = node.as()) { - // Only move pure local init into the consumer prefix. If a stmt reads - // global or shared memory, the producer may also depend on its result - // (for example a mask controlling which async copies to issue). - if (IsSharedBuffer(ld->buffer) || IsGlobalBuffer(ld->buffer)) { - has_disallowed = true; - return; - } - } - if (auto *st = node.as()) { - if (IsSharedBuffer(st->buffer) || IsGlobalBuffer(st->buffer)) { - has_disallowed = true; - return; - } - } - }); - return !has_disallowed; - } - bool - IsProducerMovableLoopPrefixStmt(const Stmt &stmt, - const LocalAccessSummary &summary, - const LocalLiveSet &future_consumer_live, - const BufferSet &prefix_defined_buffers) { - auto is_branch_private_buffer = [](const Buffer &buffer) { - return IsLocalBuffer(buffer, /*allow_var=*/true) || - IsFragmentBuffer(buffer); - }; - bool has_disallowed = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (has_disallowed) { - return; + bool IsWSBranchStmt(const Stmt &stmt, IfThenElse *branch) const { + const auto *if_node = stmt.as(); + if (!if_node || !if_node->else_case.defined()) { + return false; } - if (const auto *call = node.as()) { - if (call->op.same_as(builtin::tvm_storage_sync())) { - const auto *scope = call->args[0].as(); - if (scope && - (scope->value == "shared" || scope->value == "shared.dyn")) { - return; - } - } - if (IsBarrierOrTmaControlCall(call)) { - has_disallowed = true; - } + const auto *ge = if_node->condition.as(); + if (!ge) { + return false; } - }); - if (has_disallowed) { - return false; - } - - for (const auto &buf : summary.all_read_buffers) { - if (IsGlobalBuffer(buf)) { - continue; + const auto *lhs = ge->a.as(); + if (!lhs || lhs != thread_var_.get()) { + return false; } - if (is_branch_private_buffer(buf) && prefix_defined_buffers.count(buf)) { - continue; + if (!SameExpr(ge->b, consumer_extent_)) { + return false; } - return false; + *branch = GetRef(if_node); + return true; } - for (const auto &buf : summary.all_write_buffers) { - if (IsSharedBuffer(buf)) { - continue; + bool IsWSBranch(const Stmt &stmt, Stmt *container, + IfThenElse *branch) const { + if (IsWSBranchStmt(stmt, branch)) { + *container = stmt; + return true; } - if (is_branch_private_buffer(buf) && - !future_consumer_live.buffers.count(buf)) { - continue; + const auto *attr_node = stmt.as(); + if (!attr_node || attr_node->attr_key != attr::kWarpSpecializationScope) { + return false; } - return false; - } - - for (const auto &var : summary.def_vars) { - if (future_consumer_live.vars.count(var)) { + if (!IsWSBranchStmt(attr_node->body, branch)) { return false; } + *container = stmt; + return true; } - return true; - } - - Optional TryPrependToConsumerBranch(const Stmt &stmt, - const Stmt &prepend_stmt) { - if (auto *seq = stmt.as()) { - if (seq->seq.empty()) { - return std::nullopt; - } - Array new_seq = seq->seq; - auto nested = TryPrependToConsumerBranch(new_seq.back(), prepend_stmt); - if (nested.defined()) { - new_seq.Set(new_seq.size() - 1, nested.value()); - return SeqStmt(new_seq); + bool IsGuardedConsumerStmt(const Stmt &stmt, Stmt *body) const { + const auto *if_node = stmt.as(); + if (!if_node || if_node->else_case.defined()) { + return false; } - return std::nullopt; - } - if (auto *attr = stmt.as()) { - auto nested = TryPrependToConsumerBranch(attr->body, prepend_stmt); - if (nested.defined()) { - return AttrStmt(attr->node, attr->attr_key, attr->value, - nested.value()); + const auto *lt = if_node->condition.as(); + if (!lt) { + return false; } - return std::nullopt; - } - if (auto *let_s = stmt.as()) { - auto nested = TryPrependToConsumerBranch(let_s->body, prepend_stmt); - if (nested.defined()) { - return LetStmt(let_s->var, let_s->value, nested.value()); + const auto *lhs = lt->a.as(); + if (!lhs || lhs != thread_var_.get()) { + return false; } - return std::nullopt; - } - if (auto *realize = stmt.as()) { - auto nested = - TryPrependToConsumerBranch(realize->block->body, prepend_stmt); - if (nested.defined()) { - const Block &orig = realize->block; - Block new_block(orig->iter_vars, orig->reads, orig->writes, - orig->name_hint, nested.value(), orig->init, - orig->alloc_buffers, orig->match_buffers, - orig->annotations); - return BlockRealize(realize->iter_values, realize->predicate, - new_block); - } - return std::nullopt; - } - if (auto *block = stmt.as()) { - auto nested = TryPrependToConsumerBranch(block->body, prepend_stmt); - if (nested.defined()) { - return Block(block->iter_vars, block->reads, block->writes, - block->name_hint, nested.value(), block->init, - block->alloc_buffers, block->match_buffers, - block->annotations); - } - return std::nullopt; - } - if (auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - !IsThreadOnlyPredicate(if_stmt->condition)) { - auto nested_then = - TryPrependToConsumerBranch(if_stmt->then_case, prepend_stmt); - if (nested_then.defined()) { - return IfThenElse(if_stmt->condition, nested_then.value(), - if_stmt->else_case, if_stmt->span); - } - if (if_stmt->else_case.defined()) { - auto nested_else = TryPrependToConsumerBranch( - if_stmt->else_case.value(), prepend_stmt); - if (nested_else.defined()) { - return IfThenElse(if_stmt->condition, if_stmt->then_case, - nested_else.value(), if_stmt->span); - } - } - return std::nullopt; + if (!SameExpr(lt->b, consumer_extent_)) { + return false; } - Stmt new_else = SeqStmt({prepend_stmt, if_stmt->else_case.value()}); - return IfThenElse(if_stmt->condition, if_stmt->then_case, new_else, - if_stmt->span); + *body = if_node->then_case; + return true; } - return std::nullopt; - } - Optional TryPrependToProducerBranch(const Stmt &stmt, - const Stmt &prepend_stmt) { - if (auto *seq = stmt.as()) { - if (seq->seq.empty()) { - return std::nullopt; + static Stmt AppendToStmt(const Stmt &stmt, const Array &suffix) { + if (suffix.empty()) { + return stmt; + } + Array seq; + if (const auto *seq_stmt = stmt.as()) { + for (const auto &s : seq_stmt->seq) { + seq.push_back(s); + } + } else { + seq.push_back(stmt); } - Array new_seq = seq->seq; - auto nested = TryPrependToProducerBranch(new_seq.back(), prepend_stmt); - if (nested.defined()) { - new_seq.Set(new_seq.size() - 1, nested.value()); - return SeqStmt(new_seq); + for (const auto &s : suffix) { + seq.push_back(s); } - return std::nullopt; + return seq.size() == 1 ? seq[0] : SeqStmt(seq); } - if (auto *attr = stmt.as()) { - auto nested = TryPrependToProducerBranch(attr->body, prepend_stmt); - if (nested.defined()) { - return AttrStmt(attr->node, attr->attr_key, attr->value, - nested.value()); + + Stmt UpdateWSBranchContainer(const Stmt &container, + const IfThenElse &branch, + const Array &consumer_postlude) const { + auto *branch_ptr = const_cast(branch).CopyOnWrite(); + ICHECK(branch_ptr->else_case.defined()); + branch_ptr->else_case = + AppendToStmt(branch_ptr->else_case.value(), consumer_postlude); + if (container.same_as(branch)) { + return branch; } - return std::nullopt; + AttrStmt attr = Downcast(container); + attr.CopyOnWrite()->body = branch; + return attr; } - if (auto *let_s = stmt.as()) { - auto nested = TryPrependToProducerBranch(let_s->body, prepend_stmt); - if (nested.defined()) { - return LetStmt(let_s->var, let_s->value, nested.value()); + + Stmt VisitStmt_(const SeqStmtNode *op) final { + Array visited; + for (const auto &stmt : op->seq) { + visited.push_back(VisitStmt(stmt)); } - return std::nullopt; - } - if (auto *realize = stmt.as()) { - auto nested = - TryPrependToProducerBranch(realize->block->body, prepend_stmt); - if (nested.defined()) { - const Block &orig = realize->block; - Block new_block(orig->iter_vars, orig->reads, orig->writes, - orig->name_hint, nested.value(), orig->init, - orig->alloc_buffers, orig->match_buffers, - orig->annotations); - return BlockRealize(realize->iter_values, realize->predicate, - new_block); - } - return std::nullopt; - } - if (auto *block = stmt.as()) { - auto nested = TryPrependToProducerBranch(block->body, prepend_stmt); - if (nested.defined()) { - return Block(block->iter_vars, block->reads, block->writes, - block->name_hint, nested.value(), block->init, - block->alloc_buffers, block->match_buffers, - block->annotations); - } - return std::nullopt; - } - if (auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - !IsThreadOnlyPredicate(if_stmt->condition)) { - auto nested_then = - TryPrependToProducerBranch(if_stmt->then_case, prepend_stmt); - if (nested_then.defined()) { - return IfThenElse(if_stmt->condition, nested_then.value(), - if_stmt->else_case, if_stmt->span); + + Array rebuilt; + for (int i = 0; i < static_cast(visited.size()); ++i) { + Stmt ws_container; + IfThenElse ws_branch; + if (!IsWSBranch(visited[i], &ws_container, &ws_branch)) { + rebuilt.push_back(visited[i]); + continue; } - if (if_stmt->else_case.defined()) { - auto nested_else = TryPrependToProducerBranch( - if_stmt->else_case.value(), prepend_stmt); - if (nested_else.defined()) { - return IfThenElse(if_stmt->condition, if_stmt->then_case, - nested_else.value(), if_stmt->span); + + Array consumer_postlude; + int j = i + 1; + for (; j < static_cast(visited.size()); ++j) { + Stmt body; + if (!IsGuardedConsumerStmt(visited[j], &body)) { + break; } + consumer_postlude.push_back(body); + } + if (consumer_postlude.empty()) { + rebuilt.push_back(visited[i]); + continue; } - return std::nullopt; + + rebuilt.push_back(UpdateWSBranchContainer(ws_container, ws_branch, + consumer_postlude)); + i = j - 1; } - Stmt new_then = SeqStmt({prepend_stmt, if_stmt->then_case}); - return IfThenElse(if_stmt->condition, new_then, if_stmt->else_case, - if_stmt->span); + + return rebuilt.size() == 1 ? rebuilt[0] : SeqStmt(rebuilt); } - return std::nullopt; + + Var thread_var_; + PrimExpr consumer_extent_; + }; + + Stmt GuardConsumerOnly(const Stmt &stmt, PrimExpr consumer_extent) { + return IfThenElse(LT(thread_iv_->var, consumer_extent), stmt); } - Optional TryAppendToProducerBranch(const Stmt &stmt, - const Stmt &append_stmt) { - if (auto *seq = stmt.as()) { - if (seq->seq.empty()) { - return std::nullopt; - } - Array new_seq = seq->seq; - auto nested = TryAppendToProducerBranch(new_seq.back(), append_stmt); - if (nested.defined()) { - new_seq.Set(new_seq.size() - 1, nested.value()); - return SeqStmt(new_seq); - } - return std::nullopt; + ReplaceResult ReplacePipelineLoopInStmt(const Stmt &stmt, + const ForNode *pipeline_loop, + const Stmt &ws_body, + PrimExpr consumer_extent) { + if (stmt.get() == pipeline_loop) { + return {ws_body, true}; } - if (auto *attr = stmt.as()) { - auto nested = TryAppendToProducerBranch(attr->body, append_stmt); - if (nested.defined()) { - return AttrStmt(attr->node, attr->attr_key, attr->value, - nested.value()); + if (auto *seq = stmt.as()) { + Array new_seq; + bool found = false; + // First pass: find which child contains the pipeline loop. + int loop_idx = -1; + for (int i = 0; i < static_cast(seq->seq.size()); ++i) { + ReplaceResult probe = ReplacePipelineLoopInStmt( + seq->seq[i], pipeline_loop, ws_body, consumer_extent); + if (probe.found) { + loop_idx = i; + break; + } } - return std::nullopt; - } - if (auto *let_s = stmt.as()) { - auto nested = TryAppendToProducerBranch(let_s->body, append_stmt); - if (nested.defined()) { - return LetStmt(let_s->var, let_s->value, nested.value()); + if (loop_idx < 0) { + return {stmt, false}; } - return std::nullopt; - } - if (auto *realize = stmt.as()) { - auto nested = - TryAppendToProducerBranch(realize->block->body, append_stmt); - if (nested.defined()) { - const Block &orig = realize->block; - Block new_block(orig->iter_vars, orig->reads, orig->writes, - orig->name_hint, nested.value(), orig->init, - orig->alloc_buffers, orig->match_buffers, - orig->annotations); - return BlockRealize(realize->iter_values, realize->predicate, - new_block); - } - return std::nullopt; - } - if (auto *block = stmt.as()) { - auto nested = TryAppendToProducerBranch(block->body, append_stmt); - if (nested.defined()) { - return Block(block->iter_vars, block->reads, block->writes, - block->name_hint, nested.value(), block->init, - block->alloc_buffers, block->match_buffers, - block->annotations); - } - return std::nullopt; - } - if (auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - !IsThreadOnlyPredicate(if_stmt->condition)) { - auto nested_then = - TryAppendToProducerBranch(if_stmt->then_case, append_stmt); - if (nested_then.defined()) { - return IfThenElse(if_stmt->condition, nested_then.value(), - if_stmt->else_case, if_stmt->span); - } - if (if_stmt->else_case.defined()) { - auto nested_else = TryAppendToProducerBranch( - if_stmt->else_case.value(), append_stmt); - if (nested_else.defined()) { - return IfThenElse(if_stmt->condition, if_stmt->then_case, - nested_else.value(), if_stmt->span); + // Classify pre-loop statements using branch-private def/use sets. + // Shared-prelude statements stay in place; branch-private definitions + // move next to the branch that consumes them, or are duplicated when + // both producer and consumer need the same definition. + for (int i = 0; i < loop_idx; ++i) { + switch (ClassifyPreludeStmt(seq->seq[i], buffer_data_to_buffer_, + producer_prelude_live_seed_, + consumer_prelude_live_seed_)) { + case PreludeStmtPlacement::kProducerOnly: + extracted_producer_init_.push_back(seq->seq[i]); + break; + case PreludeStmtPlacement::kConsumerOnly: + extracted_consumer_init_.push_back(seq->seq[i]); + break; + case PreludeStmtPlacement::kDuplicateToBoth: + extracted_producer_init_.push_back(seq->seq[i]); + extracted_consumer_init_.push_back(seq->seq[i]); + break; + case PreludeStmtPlacement::kKeepSharedPrelude: + if (auto it = common_prelude_rewrites_.find(seq->seq[i].get()); + it != common_prelude_rewrites_.end()) { + new_seq.push_back(it->second); + } else { + new_seq.push_back(seq->seq[i]); } + break; } - return std::nullopt; - } - Stmt new_then = SeqStmt({if_stmt->then_case, append_stmt}); - return IfThenElse(if_stmt->condition, new_then, if_stmt->else_case, - if_stmt->span); - } - return std::nullopt; - } - - Optional TryAppendToConsumerBranch(const Stmt &stmt, - const Stmt &append_stmt) { - if (auto *seq = stmt.as()) { - if (seq->seq.empty()) { - return std::nullopt; } - Array new_seq = seq->seq; - auto nested = TryAppendToConsumerBranch(new_seq.back(), append_stmt); - if (nested.defined()) { - new_seq.Set(new_seq.size() - 1, nested.value()); - return SeqStmt(new_seq); + // Replace the pipeline loop itself. + ReplaceResult result = ReplacePipelineLoopInStmt( + seq->seq[loop_idx], pipeline_loop, ws_body, consumer_extent); + new_seq.push_back(result.stmt); + // Guard post-loop siblings. + for (int i = loop_idx + 1; i < static_cast(seq->seq.size()); ++i) { + new_seq.push_back(GuardConsumerOnly(seq->seq[i], consumer_extent)); } - return std::nullopt; + return {new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq), true}; } - if (auto *attr = stmt.as()) { - auto nested = TryAppendToConsumerBranch(attr->body, append_stmt); - if (nested.defined()) { - return AttrStmt(attr->node, attr->attr_key, attr->value, - nested.value()); - } - return std::nullopt; - } - if (auto *let_s = stmt.as()) { - auto nested = TryAppendToConsumerBranch(let_s->body, append_stmt); - if (nested.defined()) { - return LetStmt(let_s->var, let_s->value, nested.value()); + if (auto *let = stmt.as()) { + ReplaceResult result = ReplacePipelineLoopInStmt( + let->body, pipeline_loop, ws_body, consumer_extent); + if (!result.found) { + return {stmt, false}; } - return std::nullopt; + return {LetStmt(let->var, let->value, result.stmt), true}; } if (auto *realize = stmt.as()) { - auto nested = - TryAppendToConsumerBranch(realize->block->body, append_stmt); - if (nested.defined()) { - const Block &orig = realize->block; - Block new_block(orig->iter_vars, orig->reads, orig->writes, - orig->name_hint, nested.value(), orig->init, - orig->alloc_buffers, orig->match_buffers, - orig->annotations); - return BlockRealize(realize->iter_values, realize->predicate, - new_block); - } - return std::nullopt; + ReplaceResult result = ReplacePipelineLoopInStmt( + realize->block->body, pipeline_loop, ws_body, consumer_extent); + if (!result.found) { + return {stmt, false}; + } + Block block = realize->block; + block.CopyOnWrite()->body = result.stmt; + BlockRealize new_realize = GetRef(realize); + new_realize.CopyOnWrite()->block = block; + return {new_realize, true}; } if (auto *block = stmt.as()) { - auto nested = TryAppendToConsumerBranch(block->body, append_stmt); - if (nested.defined()) { - return Block(block->iter_vars, block->reads, block->writes, - block->name_hint, nested.value(), block->init, - block->alloc_buffers, block->match_buffers, - block->annotations); - } - return std::nullopt; - } - if (auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - !IsThreadOnlyPredicate(if_stmt->condition)) { - auto nested_then = - TryAppendToConsumerBranch(if_stmt->then_case, append_stmt); - if (nested_then.defined()) { - return IfThenElse(if_stmt->condition, nested_then.value(), - if_stmt->else_case, if_stmt->span); - } - if (if_stmt->else_case.defined()) { - auto nested_else = TryAppendToConsumerBranch( - if_stmt->else_case.value(), append_stmt); - if (nested_else.defined()) { - return IfThenElse(if_stmt->condition, if_stmt->then_case, - nested_else.value(), if_stmt->span); - } - } - return std::nullopt; + ReplaceResult result = ReplacePipelineLoopInStmt( + block->body, pipeline_loop, ws_body, consumer_extent); + if (!result.found) { + return {stmt, false}; + } + Block new_block = GetRef(block); + new_block.CopyOnWrite()->body = result.stmt; + return {new_block, true}; + } + if (auto *attr = stmt.as()) { + ReplaceResult result = ReplacePipelineLoopInStmt( + attr->body, pipeline_loop, ws_body, consumer_extent); + if (!result.found) { + return {stmt, false}; } - Stmt new_else = SeqStmt({if_stmt->else_case.value(), append_stmt}); - return IfThenElse(if_stmt->condition, if_stmt->then_case, new_else, - if_stmt->span); + AttrStmt new_attr = GetRef(attr); + new_attr.CopyOnWrite()->body = result.stmt; + return {new_attr, true}; } - return std::nullopt; + return {stmt, false}; } - bool IsMbarrierWaitParityStmt(const Stmt &stmt) { - return ExtractWaitBarrierId(stmt).defined(); - } + // --- PCThreadIdxRewriter (simplified for tile-op level) --- + class PCThreadIdxRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced, + PrimExpr thread_extent, bool do_shuffle) { + PCThreadIdxRewriter r(std::move(thread_var), std::move(replaced), + std::move(thread_extent)); + return r(std::move(stmt)); + } - Optional ExtractWaitBarrierId(const Stmt &stmt) { - auto extract_from_call = [](const CallNode *call) -> Optional { - if (!call || !call->op.same_as(mbarrier_wait_parity()) || - call->args.size() != 2) { - return std::nullopt; - } - // Check for BufferLoad on shared.barrier scope buffer - if (auto *bl = call->args[0].as()) { - if (bl->buffer.scope() == "shared.barrier" && bl->indices.size() == 1) { - return bl->indices[0]; - } - } - return std::nullopt; - }; + private: + PCThreadIdxRewriter(Var thread_var, PrimExpr replaced, + PrimExpr thread_extent) + : thread_var_(std::move(thread_var)), replaced_(std::move(replaced)), + thread_extent_(std::move(thread_extent)) {} - if (auto *eval = stmt.as()) { - return extract_from_call(eval->value.as()); - } - if (auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - IsTrivialNoOpStmt(if_stmt->else_case.value())) { - return ExtractWaitBarrierId(if_stmt->then_case); + PrimExpr VisitExpr_(const VarNode *var) final { + if (var == thread_var_.get()) { + return replaced_; } - return std::nullopt; - } - if (auto *attr = stmt.as()) { - return ExtractWaitBarrierId(attr->body); - } - if (auto *let_stmt = stmt.as()) { - return ExtractWaitBarrierId(let_stmt->body); + return StmtExprMutator::VisitExpr_(var); } - if (auto *seq = stmt.as()) { - if (seq->seq.size() == 1) { - return ExtractWaitBarrierId(seq->seq[0]); - } - return std::nullopt; - } - if (auto *block = stmt.as()) { - return ExtractWaitBarrierId(block->body); - } - if (auto *realize = stmt.as()) { - if (is_one(realize->predicate)) { - return ExtractWaitBarrierId(realize->block->body); - } - } - return std::nullopt; - } - struct TmaProducerWaitPair { - Stmt producer_stmt; - Stmt wait_stmt; + Var thread_var_; + PrimExpr replaced_; + PrimExpr thread_extent_; }; - std::optional - ExtractTmaProducerWaitPair(const Stmt &stmt) { - if (auto *seq = stmt.as()) { - if (seq->seq.size() == 1) { - return ExtractTmaProducerWaitPair(seq->seq[0]); - } - if (seq->seq.size() == 2 && ContainsTmaLoadStmt(seq->seq[0]) && - IsMbarrierWaitParityStmt(seq->seq[1])) { - return TmaProducerWaitPair{seq->seq[0], seq->seq[1]}; - } - return std::nullopt; - } - if (auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - IsTrivialNoOpStmt(if_stmt->else_case.value())) { - auto inner = ExtractTmaProducerWaitPair(if_stmt->then_case); - if (!inner.has_value()) { - return std::nullopt; - } - return TmaProducerWaitPair{ - IfThenElse(if_stmt->condition, inner->producer_stmt, std::nullopt, - if_stmt->span), - IfThenElse(if_stmt->condition, inner->wait_stmt, std::nullopt, - if_stmt->span)}; - } - return std::nullopt; - } - if (auto *attr = stmt.as()) { - auto inner = ExtractTmaProducerWaitPair(attr->body); - if (!inner.has_value()) { - return std::nullopt; - } - if (attr->attr_key == "tl.tma_copy_write_buffer") { - return TmaProducerWaitPair{AttrStmt(attr->node, attr->attr_key, - attr->value, inner->producer_stmt, - attr->span), - inner->wait_stmt}; - } - return TmaProducerWaitPair{ - AttrStmt(attr->node, attr->attr_key, attr->value, - inner->producer_stmt, attr->span), - AttrStmt(attr->node, attr->attr_key, attr->value, inner->wait_stmt, - attr->span)}; - } - if (auto *let_stmt = stmt.as()) { - auto inner = ExtractTmaProducerWaitPair(let_stmt->body); - if (!inner.has_value()) { - return std::nullopt; - } - return TmaProducerWaitPair{ - LetStmt(let_stmt->var, let_stmt->value, inner->producer_stmt), - LetStmt(let_stmt->var, let_stmt->value, inner->wait_stmt)}; - } - if (auto *block = stmt.as()) { - auto inner = ExtractTmaProducerWaitPair(block->body); - if (!inner.has_value()) { - return std::nullopt; - } - return TmaProducerWaitPair{ - Block(block->iter_vars, block->reads, block->writes, block->name_hint, - inner->producer_stmt, block->init, block->alloc_buffers, - block->match_buffers, block->annotations), - Block(block->iter_vars, block->reads, block->writes, block->name_hint, - inner->wait_stmt, block->init, block->alloc_buffers, - block->match_buffers, block->annotations)}; - } - if (auto *realize = stmt.as()) { - if (!is_one(realize->predicate)) { - return std::nullopt; - } - auto inner = ExtractTmaProducerWaitPair(realize->block->body); - if (!inner.has_value()) { - return std::nullopt; - } - const Block &orig = realize->block; - Block producer_block(orig->iter_vars, orig->reads, orig->writes, - orig->name_hint, inner->producer_stmt, orig->init, - orig->alloc_buffers, orig->match_buffers, - orig->annotations); - Block wait_block(orig->iter_vars, orig->reads, orig->writes, - orig->name_hint, inner->wait_stmt, orig->init, - orig->alloc_buffers, orig->match_buffers, - orig->annotations); - return TmaProducerWaitPair{ - BlockRealize(realize->iter_values, realize->predicate, - producer_block), - BlockRealize(realize->iter_values, realize->predicate, wait_block)}; - } - return std::nullopt; - } - - bool IsTmaProducerPrefixStmt(const Stmt &stmt) { - bool has_prefix_ops = false; - bool has_wait = false; - bool has_disallowed = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (const auto *attr = node.as()) { - if (attr->attr_key == "tl.tma_copy_write_buffer") { - has_prefix_ops = true; - } - return; - } - const auto *call = node.as(); - if (call == nullptr) { - return; - } - if (call->op.same_as(mbarrier_wait_parity())) { - has_wait = true; - return; - } - if (call->op.same_as(mbarrier_expect_tx()) || - call->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || - call->op.same_as(builtin::ptx_arrive_barrier()) || - call->op.same_as(tl::ptx_arrive_cluster_barrier()) || - call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { - has_prefix_ops = true; - return; - } - if (IsBarrierOrTmaControlCall(call)) { - has_disallowed = true; - } - }); - return has_prefix_ops && !has_wait && !has_disallowed; - } - - std::optional> - ExtractFlatTmaProducerClusterBeforeWait(const Array &stmts, - int wait_idx) { - if (wait_idx < 0 || wait_idx >= static_cast(stmts.size()) || - !IsMbarrierWaitParityStmt(stmts[wait_idx])) { - return std::nullopt; - } - int start = wait_idx; - while (start > 0 && IsTmaProducerPrefixStmt(stmts[start - 1])) { - --start; - } - if (start == wait_idx) { - return std::nullopt; - } - Array producer_parts; - producer_parts.reserve(wait_idx - start); - for (int i = start; i < wait_idx; ++i) { - producer_parts.push_back(stmts[i]); - } - Stmt producer_stmt = producer_parts.size() == 1 ? producer_parts[0] - : SeqStmt(producer_parts); - return std::make_pair(start, producer_stmt); - } - - Stmt NormalizeForwardWaitParity(const Stmt &wait_stmt, - const PrimExpr &normalized_parity) { - auto barrier_id = ExtractWaitBarrierId(wait_stmt); - if (!barrier_id.defined()) { - return wait_stmt; - } - return makeParityWait(barrier_buf_, barrier_id.value(), normalized_parity); - } - - bool ContainsTmaLoadStmt(const Stmt &stmt) { - bool found = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (auto *call = node.as()) { - if (call->op.same_as(tma_load()) || - call->op.same_as(tma_load_im2col())) { - found = true; - } - } - }); - return found; - } + // State + Target target_; + IterVar thread_iv_; + Optional num_threads_; // total (consumer + producer) + bool ws_transformed_{false}; + BufferDataToBufferMap buffer_data_to_buffer_; + std::unordered_map common_prelude_rewrites_; + LocalLiveSet producer_prelude_live_seed_; + LocalLiveSet consumer_prelude_live_seed_; + Array extracted_producer_init_; + Array extracted_consumer_init_; +}; - bool IsThreadOnlyPredicate(const PrimExpr &expr) const { - bool uses_thread = false; - PostOrderVisit(expr, [&](const ObjectRef &node) { - if (const auto *var = node.as()) { - if (thread_iv_.defined() && var == thread_iv_->var.get()) { - uses_thread = true; - } - } - }); - return uses_thread; - } +// --------------------------------------------------------------------------- +// Detect if manual WS is already present (skip if so) +// --------------------------------------------------------------------------- - Optional ExtractNonThreadProducerGuard(const Stmt &stmt) const { - if (const auto *attr = stmt.as()) { - return ExtractNonThreadProducerGuard(attr->body); - } - if (const auto *let_s = stmt.as()) { - return ExtractNonThreadProducerGuard(let_s->body); - } - if (const auto *realize = stmt.as()) { - return ExtractNonThreadProducerGuard(realize->block->body); - } - if (const auto *block = stmt.as()) { - return ExtractNonThreadProducerGuard(block->body); - } - if (const auto *seq = stmt.as()) { - for (const auto &s : seq->seq) { - auto guard = ExtractNonThreadProducerGuard(s); - if (guard.defined()) { - return guard; - } - } - return std::nullopt; - } - if (const auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - IsTrivialNoOpStmt(if_stmt->else_case.value())) { - if (!IsThreadOnlyPredicate(if_stmt->condition)) { - return if_stmt->condition; - } - return ExtractNonThreadProducerGuard(if_stmt->then_case); - } - } - return std::nullopt; +class ManualWSDetector : public StmtExprVisitor { +public: + static bool HasManualWS(const Stmt &stmt) { + ManualWSDetector d; + d(stmt); + return d.found_; } - PrimExpr ResolveGuardBinding(const PrimExpr &expr, - const VarBindingMap &bindings) const { - if (const auto *var = expr.as()) { - auto it = bindings.find(GetRef(var)); - if (it != bindings.end()) { - return ResolveGuardBinding(it->second, bindings); - } - } - if (const auto *cast = expr.as()) { - return ResolveGuardBinding(cast->value, bindings); +private: + void VisitStmt_(const AttrStmtNode *op) final { + // Detect both the T.ws() language-level attr ("warp_specialize") and + // the compiler-level attr (kWarpSpecializationScope). + if (op->attr_key == "warp_specialize" || + op->attr_key == attr::kWarpSpecializationScope) { + found_ = true; + return; } - return expr; + StmtExprVisitor::VisitStmt_(op); } - bool IsMaskLikeBooleanExpr(const PrimExpr &expr) const { - PrimExpr resolved = expr; - while (const auto *cast = resolved.as()) { - resolved = cast->value; - } - auto is_const_bool = [](const PrimExpr &value, bool expected) { - if (const auto *imm = value.as()) { - return static_cast(imm->value) == expected; - } - return false; - }; - if (const auto *load = resolved.as()) { - return load->buffer->dtype.is_bool(); - } - if (const auto *select = resolved.as()) { - if (is_const_bool(select->false_value, false)) { - return IsMaskLikeBooleanExpr(select->true_value); - } - if (is_const_bool(select->true_value, false)) { - return IsMaskLikeBooleanExpr(select->false_value); - } - } - if (const auto *call = resolved.as()) { - if (const auto *op = call->op.as()) { - if (op->name == "tl.any_of" || op->name == "tl.all_of") { - return true; - } - } - if (call->op.same_as(builtin::call_extern()) && !call->args.empty()) { - if (const auto *name = call->args[0].as()) { - if (name->value == "tl::Any" || name->value == "tl::All") { - return true; - } - } - } - if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) { - if (is_const_bool(call->args[2], false)) { - return IsMaskLikeBooleanExpr(call->args[1]); - } - if (is_const_bool(call->args[1], false)) { - return IsMaskLikeBooleanExpr(call->args[2]); - } - } - } - return false; - } + bool found_{false}; +}; - bool CanIssueProducerWithoutGuardImpl(const Stmt &stmt, - VarBindingMap *bindings) const { - if (const auto *attr = stmt.as()) { - return CanIssueProducerWithoutGuardImpl(attr->body, bindings); - } - if (const auto *let_s = stmt.as()) { - bindings->emplace(let_s->var, let_s->value); - bool result = CanIssueProducerWithoutGuardImpl(let_s->body, bindings); - bindings->erase(let_s->var); - return result; - } - if (const auto *realize = stmt.as()) { - return CanIssueProducerWithoutGuardImpl(realize->block->body, bindings); - } - if (const auto *block = stmt.as()) { - return CanIssueProducerWithoutGuardImpl(block->body, bindings); - } - if (const auto *seq = stmt.as()) { - for (const auto &s : seq->seq) { - if (CanIssueProducerWithoutGuardImpl(s, bindings)) { - return true; - } - } - return false; - } - if (const auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - IsTrivialNoOpStmt(if_stmt->else_case.value())) { - if (!IsThreadOnlyPredicate(if_stmt->condition)) { - if (const auto *var = if_stmt->condition.as()) { - Var cond_var = GetRef(var); - if (!UsesVar(if_stmt->then_case, [cond_var](const VarNode *vn) { - return vn == cond_var.get(); - })) { - return true; - } - } - PrimExpr resolved = - ResolveGuardBinding(if_stmt->condition, *bindings); - return IsMaskLikeBooleanExpr(resolved); - } - return CanIssueProducerWithoutGuardImpl(if_stmt->then_case, bindings); - } - } - return false; +/// Quick pre-scan: check if the function contains a pipelined loop (num_stages +/// >= 1) with at least one TMA load producer tile op and no manual layout +/// annotations (which are incompatible with early MVB expansion). +/// Check whether a layout annotation on a shared buffer is compatible with +/// TMA. TMA supports identity (linear) layouts and the three standard +/// swizzle modes (32B / 64B / 128B). Any other layout (e.g. padded, +/// Volta-style) cannot be used with TMA. +static bool IsTmaCompatibleLayout(const Layout &layout, const Buffer &buffer) { + // Recognised swizzle → TMA with swizzle. + if (DetectSwizzleMode(layout, buffer) != SwizzleMode::kNone) { + return true; } - - bool CanIssueProducerWithoutGuard(const Stmt &stmt) const { - VarBindingMap bindings = current_loop_guard_bindings_; - return CanIssueProducerWithoutGuardImpl(stmt, &bindings); + // Identity / row-major linear → TMA without swizzle. + if (StructuralEqual()(layout, makeLinearLayout(buffer->shape))) { + return true; } + return false; +} - Stmt StripNonThreadProducerGuard(const Stmt &stmt) const { - if (const auto *attr = stmt.as()) { - return AttrStmt(attr->node, attr->attr_key, attr->value, - StripNonThreadProducerGuard(attr->body), attr->span); - } - if (const auto *let_s = stmt.as()) { - return LetStmt(let_s->var, let_s->value, - StripNonThreadProducerGuard(let_s->body)); - } - if (const auto *realize = stmt.as()) { - const Block &orig = realize->block; - Block new_block(orig->iter_vars, orig->reads, orig->writes, - orig->name_hint, StripNonThreadProducerGuard(orig->body), - orig->init, orig->alloc_buffers, orig->match_buffers, - orig->annotations); - return BlockRealize(realize->iter_values, realize->predicate, new_block); - } - if (const auto *block = stmt.as()) { - return Block(block->iter_vars, block->reads, block->writes, - block->name_hint, StripNonThreadProducerGuard(block->body), - block->init, block->alloc_buffers, block->match_buffers, - block->annotations); - } - if (const auto *seq = stmt.as()) { - Array stripped; - stripped.reserve(seq->seq.size()); - for (const auto &s : seq->seq) { - stripped.push_back(StripNonThreadProducerGuard(s)); - } - return stripped.size() == 1 ? stripped[0] : SeqStmt(stripped, seq->span); - } - if (const auto *if_stmt = stmt.as()) { - if (!if_stmt->else_case.defined() || - IsTrivialNoOpStmt(if_stmt->else_case.value())) { - if (!IsThreadOnlyPredicate(if_stmt->condition)) { - return StripNonThreadProducerGuard(if_stmt->then_case); - } - return IfThenElse(if_stmt->condition, - StripNonThreadProducerGuard(if_stmt->then_case), - std::nullopt, if_stmt->span); - } - } - return stmt; +class TiledWSCandidate : public StmtExprVisitor { +public: + static bool Check(const Stmt &stmt, Target target) { + TiledWSCandidate c; + c.target_ = target; + c(stmt); + return c.has_pipeline_loop_ && c.has_tma_tile_op_; } - Stmt WrapStmtWithOptionalGuard(const Optional &guard, - const Stmt &stmt) const { - if (!guard.defined()) { - return stmt; - } - return IfThenElse(guard.value(), stmt, std::nullopt); - } - - Optional WrapStmtWithNonThreadGuardLike(const Stmt &source, - const Stmt &stmt) const { - if (const auto *attr = source.as()) { - Optional wrapped = WrapStmtWithNonThreadGuardLike(attr->body, stmt); - if (!wrapped.defined()) { - return std::nullopt; - } - return AttrStmt(attr->node, attr->attr_key, attr->value, wrapped.value(), - attr->span); - } - if (const auto *let_s = source.as()) { - Optional wrapped = - WrapStmtWithNonThreadGuardLike(let_s->body, stmt); - if (!wrapped.defined()) { - return std::nullopt; - } - return LetStmt(let_s->var, let_s->value, wrapped.value()); - } - if (const auto *realize = source.as()) { - Optional wrapped = - WrapStmtWithNonThreadGuardLike(realize->block->body, stmt); - if (!wrapped.defined()) { - return std::nullopt; - } - const Block &orig = realize->block; - Block new_block(orig->iter_vars, orig->reads, orig->writes, - orig->name_hint, wrapped.value(), orig->init, - orig->alloc_buffers, orig->match_buffers, - orig->annotations); - return BlockRealize(realize->iter_values, realize->predicate, new_block); - } - if (const auto *block = source.as()) { - Optional wrapped = - WrapStmtWithNonThreadGuardLike(block->body, stmt); - if (!wrapped.defined()) { - return std::nullopt; - } - return Block(block->iter_vars, block->reads, block->writes, - block->name_hint, wrapped.value(), block->init, - block->alloc_buffers, block->match_buffers, - block->annotations); - } - if (const auto *seq = source.as()) { - if (seq->seq.size() == 1) { - return WrapStmtWithNonThreadGuardLike(seq->seq[0], stmt); - } - return std::nullopt; - } - if (const auto *if_stmt = source.as()) { - if (!if_stmt->else_case.defined() || - IsTrivialNoOpStmt(if_stmt->else_case.value())) { - if (!IsThreadOnlyPredicate(if_stmt->condition)) { - return IfThenElse(if_stmt->condition, stmt, std::nullopt, - if_stmt->span); - } - Optional wrapped = - WrapStmtWithNonThreadGuardLike(if_stmt->then_case, stmt); - if (!wrapped.defined()) { - return std::nullopt; +private: + void VisitStmt_(const ForNode *op) final { + bool old = in_pipeline_; + if (auto anno = op->annotations.Get("num_stages")) { + if (auto *imm = anno->as()) { + if (imm->value >= 1) { + has_pipeline_loop_ = true; + in_pipeline_ = true; } - return IfThenElse(if_stmt->condition, wrapped.value(), std::nullopt, - if_stmt->span); - } - } - return std::nullopt; - } - - Stmt WrapStmtWithGuardSource(const Optional &guard_source, - const Optional &guard, - const Stmt &stmt) const { - if (guard_source.defined()) { - Optional wrapped = - WrapStmtWithNonThreadGuardLike(guard_source.value(), stmt); - if (wrapped.defined()) { - return wrapped.value(); } } - return WrapStmtWithOptionalGuard(guard, stmt); + StmtExprVisitor::VisitStmt_(op); + in_pipeline_ = old; } - Stmt RewriteWaitBarrier(const Stmt &wait_stmt, const PrimExpr &new_barrier_id, - Optional new_parity = std::nullopt) { - class WaitBarrierRewriter : public StmtExprMutator { - public: - WaitBarrierRewriter(const Buffer &barrier_buf, PrimExpr barrier_id, - Optional parity) - : barrier_buf_(barrier_buf), barrier_id_(std::move(barrier_id)), - parity_(std::move(parity)) {} - - PrimExpr VisitExpr_(const CallNode *op) final { - auto call = Downcast(StmtExprMutator::VisitExpr_(op)); - if (call->op.same_as(mbarrier_wait_parity()) && - call->args.size() == 2) { - PrimExpr parity = parity_.defined() ? parity_.value() : call->args[1]; - return Call(call->dtype, call->op, - {makeGetBarrier(barrier_buf_, barrier_id_), parity}, - call->annotations, call->span); - } - return call; - } - - private: - Buffer barrier_buf_; - PrimExpr barrier_id_; - Optional parity_; - }; - - return MergeAdjacentEquivalentIfs(WaitBarrierRewriter( - barrier_buf_, new_barrier_id, std::move(new_parity))(wait_stmt)); - } - - Stmt RewriteTmaStmtBarrierIdPreserveProtocol(const Stmt &stmt, - const PrimExpr &barrier_id, - bool drop_arrive = false) { - class TmaBarrierIdRewriter : public StmtExprMutator { - public: - TmaBarrierIdRewriter(const Buffer &barrier_buf, PrimExpr barrier_id, - bool drop_arrive, bool is_cluster_barrier, - int cluster_size) - : barrier_buf_(barrier_buf), barrier_id_(std::move(barrier_id)), - drop_arrive_(drop_arrive), is_cluster_barrier_(is_cluster_barrier), - cluster_size_(cluster_size) {} - - Stmt VisitStmt_(const EvaluateNode *op) final { - if (!is_cluster_barrier_) { - return StmtExprMutator::VisitStmt_(op); - } - // For cluster barriers, intercept mbarrier_expect_tx: multiply bytes - // by cluster_size and wrap in if (block_rank_in_cluster() == 0). - if (const auto *call = op->value.as()) { - if ((call->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || - call->op.same_as(mbarrier_expect_tx())) && - call->args.size() == 2) { - PrimExpr new_bytes = - call->args[1] * IntImm(DataType::Int(32), cluster_size_); - auto new_call = - Call(call->dtype, call->op, - {makeGetBarrier(barrier_buf_, barrier_id_), new_bytes}, - call->annotations, call->span); - PrimExpr rank = - Call(DataType::Int(32), tl::block_rank_in_cluster(), {}); - return IfThenElse(EQ(rank, IntImm(DataType::Int(32), 0)), - Evaluate(new_call), Stmt()); - } - } - return StmtExprMutator::VisitStmt_(op); - } - - PrimExpr VisitExpr_(const CallNode *op) final { - auto call = Downcast(StmtExprMutator::VisitExpr_(op)); - if ((call->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || - call->op.same_as(mbarrier_expect_tx())) && - call->args.size() == 2) { - // For non-cluster barriers, just rewrite the barrier arg. - // Cluster barriers are handled in VisitStmt_ above. - if (!is_cluster_barrier_) { - return Call( - call->dtype, call->op, - {makeGetBarrier(barrier_buf_, barrier_id_), call->args[1]}, - call->annotations, call->span); - } - return call; - } - if (call->op.same_as(tma_load()) || - call->op.same_as(tma_load_im2col())) { - bool is_1d_tma_load = false; - if (const auto *arg0 = call->args[0].as()) { - is_1d_tma_load = !arg0->op.same_as(create_tma_descriptor()) && - call->op.same_as(tma_load()); - } - auto new_call = call.CopyOnWrite(); - new_call->args.Set(is_1d_tma_load ? 2 : 1, - makeGetBarrier(barrier_buf_, barrier_id_)); - // For cluster barriers, add use_2cta annotation - if (is_cluster_barrier_) { - Map new_annotations = call->annotations; - new_annotations.Set("use_2cta", Bool(true)); - new_call->annotations = new_annotations; - } - return call; - } - if ((call->op.same_as(builtin::ptx_arrive_barrier()) || - call->op.same_as(tl::ptx_arrive_cluster_barrier())) && - !call->args.empty()) { - if (drop_arrive_) { - return IntImm(DataType::Int(32), 0); - } - auto new_call = call.CopyOnWrite(); - new_call->args.Set(0, makeGetBarrier(barrier_buf_, barrier_id_)); - return call; - } - return call; - } - - private: - Buffer barrier_buf_; - PrimExpr barrier_id_; - bool drop_arrive_; - bool is_cluster_barrier_; - int cluster_size_; - }; - - return MergeAdjacentEquivalentIfs( - TmaBarrierIdRewriter(barrier_buf_, barrier_id, drop_arrive, - is_cluster_barrier_, cluster_size_)(stmt)); - } - - Stmt MergeAdjacentEquivalentIfs(const Stmt &stmt) { - if (const auto *attr = stmt.as()) { - return AttrStmt(attr->node, attr->attr_key, attr->value, - MergeAdjacentEquivalentIfs(attr->body), attr->span); - } - if (const auto *let_stmt = stmt.as()) { - return LetStmt(let_stmt->var, let_stmt->value, - MergeAdjacentEquivalentIfs(let_stmt->body)); - } - if (const auto *block = stmt.as()) { - return Block(block->iter_vars, block->reads, block->writes, - block->name_hint, MergeAdjacentEquivalentIfs(block->body), - block->init, block->alloc_buffers, block->match_buffers, - block->annotations); - } - if (const auto *realize = stmt.as()) { - const Block &orig = realize->block; - Block new_block(orig->iter_vars, orig->reads, orig->writes, - orig->name_hint, MergeAdjacentEquivalentIfs(orig->body), - orig->init, orig->alloc_buffers, orig->match_buffers, - orig->annotations); - return BlockRealize(realize->iter_values, realize->predicate, new_block); - } - if (const auto *if_stmt = stmt.as()) { - Optional else_case = std::nullopt; - if (if_stmt->else_case.defined()) { - else_case = MergeAdjacentEquivalentIfs(if_stmt->else_case.value()); - } - return IfThenElse(if_stmt->condition, - MergeAdjacentEquivalentIfs(if_stmt->then_case), - else_case, if_stmt->span); - } - if (const auto *seq = stmt.as()) { - Array merged; - StructuralEqual equal; - for (size_t i = 0; i < seq->seq.size();) { - const auto *if0 = seq->seq[i].as(); - if (if0 && !if0->else_case.defined()) { - Array then_stmts; - then_stmts.push_back(if0->then_case); - size_t j = i + 1; - while (j < seq->seq.size()) { - const auto *ifj = seq->seq[j].as(); - if (!ifj || ifj->else_case.defined() || - !equal(if0->condition, ifj->condition)) { - break; - } - then_stmts.push_back(ifj->then_case); - ++j; - } - if (then_stmts.size() == 1) { - merged.push_back(seq->seq[i]); - } else { - Stmt merged_then = MergeAdjacentEquivalentIfs( - then_stmts.size() == 1 ? then_stmts[0] : SeqStmt(then_stmts)); - merged.push_back(IfThenElse(if0->condition, merged_then, - std::nullopt, if0->span)); + void VisitExpr_(const CallNode *op) final { + if (in_pipeline_ && !has_tma_tile_op_) { + auto tile_op = ParseOperator(ffi::GetRef(op)); + if (auto *copy = tile_op.as()) { + if (ClassifyCopy(copy, target_) == TileStmtKind::kTmaProducer) { + // If the destination buffer has a layout annotation, verify + // that the layout is TMA-compatible (swizzle or linear). + // Copies whose layout is incompatible with TMA cannot become + // TMA producers. + if (HasTmaCompatibleLayout(copy->dst)) { + has_tma_tile_op_ = true; } - i = j; - continue; } - merged.push_back(seq->seq[i]); - ++i; } - return merged.size() == 1 ? merged[0] : SeqStmt(merged, seq->span); } - return stmt; - } - - Stmt RewriteTmaForwardProducerStmt(const Stmt &stmt, - const PrimExpr &barrier_id, - bool append_arrive) { - class TmaForwardBarrierStmtRewriter : public StmtExprMutator { - public: - TmaForwardBarrierStmtRewriter(const Buffer &barrier_buf, - PrimExpr barrier_id, - bool is_cluster_barrier, int cluster_size) - : barrier_buf_(barrier_buf), barrier_id_(std::move(barrier_id)), - is_cluster_barrier_(is_cluster_barrier), - cluster_size_(cluster_size) {} - - Stmt VisitStmt_(const EvaluateNode *op) final { - if (!is_cluster_barrier_) { - return StmtExprMutator::VisitStmt_(op); - } - if (const auto *call = op->value.as()) { - if ((call->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || - call->op.same_as(mbarrier_expect_tx())) && - call->args.size() == 2) { - PrimExpr new_bytes = - call->args[1] * IntImm(DataType::Int(32), cluster_size_); - auto new_call = - Call(call->dtype, mbarrier_expect_tx(), - {makeGetBarrier(barrier_buf_, barrier_id_), new_bytes}, - call->annotations, call->span); - PrimExpr rank = - Call(DataType::Int(32), tl::block_rank_in_cluster(), {}); - return IfThenElse(EQ(rank, IntImm(DataType::Int(32), 0)), - Evaluate(new_call), Stmt()); - } - } - return StmtExprMutator::VisitStmt_(op); - } - - PrimExpr VisitExpr_(const CallNode *op) final { - auto call = Downcast(StmtExprMutator::VisitExpr_(op)); - if ((call->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || - call->op.same_as(mbarrier_expect_tx())) && - call->args.size() == 2) { - if (!is_cluster_barrier_) { - return Call( - call->dtype, mbarrier_expect_tx(), - {makeGetBarrier(barrier_buf_, barrier_id_), call->args[1]}, - call->annotations, call->span); - } - return call; - } - if (call->op.same_as(tma_load()) || - call->op.same_as(tma_load_im2col())) { - bool is_1d_tma_load = false; - if (const auto *arg0 = call->args[0].as()) { - is_1d_tma_load = !arg0->op.same_as(create_tma_descriptor()) && - call->op.same_as(tma_load()); - } - auto new_call = call.CopyOnWrite(); - new_call->args.Set(is_1d_tma_load ? 2 : 1, - makeGetBarrier(barrier_buf_, barrier_id_)); - if (is_cluster_barrier_) { - Map new_annotations = call->annotations; - new_annotations.Set("use_2cta", Bool(true)); - new_call->annotations = new_annotations; - } - return call; - } - if ((call->op.same_as(builtin::ptx_arrive_barrier()) || - call->op.same_as(tl::ptx_arrive_cluster_barrier())) && - !call->args.empty()) { - return IntImm(DataType::Int(32), 0); - } - return call; - } - - private: - Buffer barrier_buf_; - PrimExpr barrier_id_; - bool is_cluster_barrier_; - int cluster_size_; - }; - - // Rebind the producer-side barrier id and finish the stage with a normal - // barrier arrival. Pure-TMA pipelines do not need cp.async.mbarrier.arrive. - Stmt rewritten = MergeAdjacentEquivalentIfs(TmaForwardBarrierStmtRewriter( - barrier_buf_, barrier_id, is_cluster_barrier_, cluster_size_)(stmt)); - if (!append_arrive) { - return rewritten; - } - Optional guard = ExtractNonThreadProducerGuard(stmt); - Stmt elect_arrive = IfThenElse( - Call(DataType::Bool(), tl_shuffle_elect(), {producer_thread_extent_}), - makeArriveBarrier(barrier_buf_, barrier_id), std::nullopt); - elect_arrive = WrapStmtWithOptionalGuard(guard, elect_arrive); - return MergeAdjacentEquivalentIfs(SeqStmt({rewritten, elect_arrive})); - } - - Stmt RewritePureTmaForwardPairsWithFreshBarriers(const Stmt &stmt) { - class OutsideLoopPureTmaRewriter : public StmtExprMutator { - public: - explicit OutsideLoopPureTmaRewriter(ProducerConsumerWSRewriter *parent) - : parent_(parent) {} - - Stmt VisitStmt_(const SeqStmtNode *op) final { - Array new_seq; - bool changed = false; - for (size_t i = 0; i < op->seq.size(); ++i) { - if (auto pair = parent_->ExtractTmaProducerWaitPair(op->seq[i]); - pair.has_value()) { - ICHECK_GE(parent_->pure_tma_preloop_fwd_base_, 0); - ICHECK_LT(parent_->pure_tma_preloop_fwd_cursor_, - parent_->pure_tma_preloop_fwd_count_); - PrimExpr barrier_id = IntImm( - DataType::Int(32), parent_->pure_tma_preloop_fwd_base_ + - parent_->pure_tma_preloop_fwd_cursor_++); - Stmt producer_stmt = parent_->MergeAdjacentEquivalentIfs( - parent_->RewriteTmaStmtBarrierIdPreserveProtocol( - StripTmaCopyWriteBufferAttr(pair->producer_stmt), - barrier_id)); - Stmt wait_stmt = - parent_->RewriteWaitBarrier(pair->wait_stmt, barrier_id); - new_seq.push_back(producer_stmt); - new_seq.push_back(wait_stmt); - changed = true; - continue; - } - if (i + 1 < op->seq.size() && - parent_->ContainsTmaLoadStmt(op->seq[i]) && - parent_->IsMbarrierWaitParityStmt(op->seq[i + 1])) { - ICHECK_GE(parent_->pure_tma_preloop_fwd_base_, 0); - ICHECK_LT(parent_->pure_tma_preloop_fwd_cursor_, - parent_->pure_tma_preloop_fwd_count_); - PrimExpr barrier_id = IntImm( - DataType::Int(32), parent_->pure_tma_preloop_fwd_base_ + - parent_->pure_tma_preloop_fwd_cursor_++); - Stmt producer_stmt = parent_->MergeAdjacentEquivalentIfs( - parent_->RewriteTmaStmtBarrierIdPreserveProtocol( - StripTmaCopyWriteBufferAttr(op->seq[i]), barrier_id)); - Stmt wait_stmt = - parent_->RewriteWaitBarrier(op->seq[i + 1], barrier_id); - new_seq.push_back(producer_stmt); - new_seq.push_back(wait_stmt); - ++i; - changed = true; - continue; - } - Stmt visited = StmtExprMutator::VisitStmt(op->seq[i]); - new_seq.push_back(visited); - changed = changed || !visited.same_as(op->seq[i]); - } - if (!changed) { - return GetRef(op); - } - return new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq); - } - - private: - ProducerConsumerWSRewriter *parent_; - }; - - OutsideLoopPureTmaRewriter rewriter(this); - return rewriter(stmt); - } - - bool IsSharedDependentConsumerPreStmt(const Stmt &stmt) { - bool has_shared_access = false; - bool has_control_ops = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (has_control_ops) { - return; - } - if (auto *call = node.as()) { - if (IsBarrierOrTmaControlCall(call)) { - has_control_ops = true; - return; - } - } - if (auto *ld = node.as()) { - if (IsSharedBuffer(ld->buffer)) { - has_shared_access = true; - } - } - if (auto *st = node.as()) { - if (IsSharedBuffer(st->buffer)) { - has_shared_access = true; - } - } - }); - return has_shared_access && !has_control_ops; + StmtExprVisitor::VisitExpr_(op); } - bool IsBranchLocalPreStmtCandidate(const Stmt &stmt, - const LocalAccessSummary &summary) { - if (!summary.HasTrackedDefs()) { - return false; - } - bool has_disallowed = false; - PostOrderVisit(stmt, [&](const ObjectRef &node) { - if (has_disallowed) { - return; - } - if (const auto *call = node.as()) { - if (IsBarrierOrTmaControlCall(call)) { - has_disallowed = true; - return; - } - } - if (const auto *ld = node.as()) { - if (IsSharedBuffer(ld->buffer)) { - has_disallowed = true; - return; - } - } - if (const auto *st = node.as()) { - if (IsSharedBuffer(st->buffer) || IsGlobalBuffer(st->buffer)) { - has_disallowed = true; - return; - } - } - }); - return !has_disallowed; - } - - LocalLiveSet SeedLiveSetFromStmt(const Stmt &stmt, - const BufferDataToBufferMap &buffer_map) { - LocalLiveSet live; - live.AddUses(LocalAccessCollector::Collect(stmt, buffer_map)); - return live; - } - - /*! - * \brief Rebuild the block body, replacing the pipeline loop with - * ws_body and removing old barrier_init annotations / - * shared.barrier buffers. - * - * Statements after the pipeline loop (e.g. epilogue, store) should execute - * only on consumer threads. Prefer appending them into the consumer branch - * of the warp-specialized if/else to keep a single top-level partition. - * If that is not possible, fall back to an explicit consumer-thread guard. - */ - Stmt RebuildBlockBody(const Stmt &body, const ForNode *target_loop, - const Stmt &ws_body, - const BufferDataToBufferMap &buffer_data_to_buffer, - const LocalLiveSet &producer_live_seed, - const LocalLiveSet &consumer_live_seed) { - // If this IS the target loop, replace it - if (body.as() == target_loop) { - return ws_body; - } - - if (auto *seq = body.as()) { - Array new_seq; - Array pre_loop_stmts; - Array post_loop_stmts; - bool found_loop = false; - Optional rebuilt_loop = std::nullopt; - - for (const auto &s : seq->seq) { - if (!found_loop && ContainsLoop(s, target_loop)) { - // Replace the pipeline loop - rebuilt_loop = - RebuildBlockBody(s, target_loop, ws_body, buffer_data_to_buffer, - producer_live_seed, consumer_live_seed); - found_loop = true; - } else if (found_loop) { - // Collect statements after the pipeline loop - post_loop_stmts.push_back(s); - } else { - // Statements before the pipeline loop. - pre_loop_stmts.push_back(s); - } - } - - // Move a movable suffix of pre-loop statements into consumer branch - // (e.g. fragment initialization), keeping barriers/syncs outside. - size_t movable_begin = pre_loop_stmts.size(); - while (movable_begin > 0 && - IsMovableConsumerPrefixStmt(pre_loop_stmts[movable_begin - 1])) { - --movable_begin; - } - - // Split non-movable pre-loop statements into: - // common statements kept outside the WS split - // producer-side async issues - // consumer-side waits / shared-dependent setup - // branch-local prefix code that is assigned by actual downstream use - // (producer only / consumer only / duplicated). - // - // We drive the branch-local assignment with a backward liveness walk over - // local buffers / Let vars. This avoids duplicating consumer-only local - // initialization into the producer branch. - enum class PrefixRole : uint8_t { - kUnknown, - kSkip, - kCommon, - kProducer, - kConsumer, - kBoth, - kConsumerShared, - kSpecialTmaStart, - }; - - Array common_pre_stmts; - Array producer_prefix_ordered_stmts; - Array consumer_prefix_early_stmts; - Array consumer_wait_prefix_stmts; - Array consumer_shared_prefix_stmts; - std::vector prefix_roles(movable_begin, PrefixRole::kUnknown); - std::vector> rewritten_producer_prefix(movable_begin, - std::nullopt); - std::vector> rewritten_consumer_wait(movable_begin, - std::nullopt); - std::vector flat_tma_wait_for_start(movable_begin, -1); - std::vector flat_tma_cluster_member(movable_begin, false); - - for (size_t wait_idx = 0; wait_idx < movable_begin; ++wait_idx) { - if (ExtractTmaProducerWaitPair(pre_loop_stmts[wait_idx]).has_value()) { - continue; - } - auto cluster = ExtractFlatTmaProducerClusterBeforeWait( - pre_loop_stmts, static_cast(wait_idx)); - if (!cluster.has_value()) { - continue; - } - int start = cluster->first; - flat_tma_wait_for_start[start] = static_cast(wait_idx); - rewritten_producer_prefix[start] = cluster->second; - rewritten_consumer_wait[start] = pre_loop_stmts[wait_idx]; - for (int j = start + 1; j <= static_cast(wait_idx); ++j) { - flat_tma_cluster_member[j] = true; - } - } - - auto apply_to_live = [](LocalLiveSet *live, - const LocalAccessSummary &summary) { - live->KillDefs(summary); - live->AddUses(summary); - }; - - LocalLiveSet producer_live = producer_live_seed; - LocalLiveSet consumer_live = consumer_live_seed; - for (size_t j = movable_begin; j < pre_loop_stmts.size(); ++j) { - consumer_live.AddUses(LocalAccessCollector::Collect( - pre_loop_stmts[j], buffer_data_to_buffer)); - } - for (const auto &stmt : post_loop_stmts) { - consumer_live.AddUses( - LocalAccessCollector::Collect(stmt, buffer_data_to_buffer)); - } - - for (int i = static_cast(movable_begin) - 1; i >= 0; --i) { - if (flat_tma_cluster_member[i]) { - prefix_roles[i] = PrefixRole::kSkip; - continue; - } - - if (flat_tma_wait_for_start[i] >= 0) { - Stmt producer_prefix_stmt = - StripTmaCopyWriteBufferAttr(rewritten_producer_prefix[i].value()); - Stmt consumer_wait_stmt = rewritten_consumer_wait[i].value(); - if (remap_pure_tma_barriers_) { - ICHECK_GE(pure_tma_preloop_fwd_base_, 0); - ICHECK_LT(pure_tma_preloop_fwd_cursor_, - pure_tma_preloop_fwd_count_); - PrimExpr barrier_id = - IntImm(DataType::Int(32), pure_tma_preloop_fwd_base_ + - pure_tma_preloop_fwd_cursor_++); - producer_prefix_stmt = - RewriteTmaForwardProducerStmt(producer_prefix_stmt, barrier_id, - /*append_arrive=*/true); - consumer_wait_stmt = - RewriteWaitBarrier(consumer_wait_stmt, barrier_id); - } else if (use_full_tma_forward_barrier_protocol_) { - auto barrier_id = ExtractWaitBarrierId(consumer_wait_stmt); - ICHECK(barrier_id.defined()) - << "ProducerConsumerWS: failed to extract pre-loop TMA " - "forward barrier id"; - producer_prefix_stmt = RewriteTmaForwardProducerStmt( - producer_prefix_stmt, barrier_id.value(), - /*append_arrive=*/true); - } - producer_prefix_stmt = - MergeAdjacentEquivalentIfs(producer_prefix_stmt); - rewritten_producer_prefix[i] = producer_prefix_stmt; - rewritten_consumer_wait[i] = consumer_wait_stmt; - prefix_roles[i] = PrefixRole::kSpecialTmaStart; - apply_to_live(&producer_live, - LocalAccessCollector::Collect(producer_prefix_stmt, - buffer_data_to_buffer)); - apply_to_live(&consumer_live, - LocalAccessCollector::Collect(consumer_wait_stmt, - buffer_data_to_buffer)); - continue; - } - - if (i > 0 && ContainsTmaLoadStmt(pre_loop_stmts[i - 1]) && - IsMbarrierWaitParityStmt(pre_loop_stmts[i])) { - prefix_roles[i] = PrefixRole::kSkip; - continue; - } - - auto standalone_pair = ExtractTmaProducerWaitPair(pre_loop_stmts[i]); - if (standalone_pair.has_value() || - (static_cast(i + 1) < movable_begin && - ContainsTmaLoadStmt(pre_loop_stmts[i]) && - IsMbarrierWaitParityStmt(pre_loop_stmts[i + 1]))) { - Stmt producer_prefix_stmt = StripTmaCopyWriteBufferAttr( - standalone_pair.has_value() ? standalone_pair->producer_stmt - : pre_loop_stmts[i]); - Stmt consumer_wait_stmt = standalone_pair.has_value() - ? standalone_pair->wait_stmt - : pre_loop_stmts[i + 1]; - if (remap_pure_tma_barriers_) { - ICHECK_GE(pure_tma_preloop_fwd_base_, 0); - ICHECK_LT(pure_tma_preloop_fwd_cursor_, - pure_tma_preloop_fwd_count_); - PrimExpr barrier_id = - IntImm(DataType::Int(32), pure_tma_preloop_fwd_base_ + - pure_tma_preloop_fwd_cursor_++); - producer_prefix_stmt = - RewriteTmaForwardProducerStmt(producer_prefix_stmt, barrier_id, - /*append_arrive=*/true); - consumer_wait_stmt = - RewriteWaitBarrier(consumer_wait_stmt, barrier_id); - } else if (use_full_tma_forward_barrier_protocol_) { - auto barrier_id = ExtractWaitBarrierId(pre_loop_stmts[i + 1]); - ICHECK(barrier_id.defined()) - << "ProducerConsumerWS: failed to extract pre-loop TMA " - "forward barrier id"; - producer_prefix_stmt = RewriteTmaForwardProducerStmt( - producer_prefix_stmt, barrier_id.value(), - /*append_arrive=*/true); - } - producer_prefix_stmt = - MergeAdjacentEquivalentIfs(producer_prefix_stmt); - rewritten_producer_prefix[i] = producer_prefix_stmt; - rewritten_consumer_wait[i] = consumer_wait_stmt; - prefix_roles[i] = PrefixRole::kSpecialTmaStart; - if (!standalone_pair.has_value()) { - prefix_roles[i + 1] = PrefixRole::kSkip; - } - apply_to_live(&producer_live, - LocalAccessCollector::Collect(producer_prefix_stmt, - buffer_data_to_buffer)); - apply_to_live(&consumer_live, - LocalAccessCollector::Collect(consumer_wait_stmt, - buffer_data_to_buffer)); - continue; - } - - const Stmt &stmt = pre_loop_stmts[i]; - LocalAccessSummary summary = - LocalAccessCollector::Collect(stmt, buffer_data_to_buffer); - if (remap_pure_tma_barriers_ && - IsBranchLocalPreStmtCandidate(stmt, summary)) { - bool producer_needed = producer_live.NeedsAnyDef(summary); - bool consumer_needed = consumer_live.NeedsAnyDef(summary); - if (producer_needed && consumer_needed) { - prefix_roles[i] = PrefixRole::kBoth; - apply_to_live(&producer_live, summary); - apply_to_live(&consumer_live, summary); - } else if (producer_needed) { - prefix_roles[i] = PrefixRole::kProducer; - apply_to_live(&producer_live, summary); - } else if (consumer_needed) { - prefix_roles[i] = PrefixRole::kConsumer; - apply_to_live(&consumer_live, summary); - } else { - prefix_roles[i] = PrefixRole::kCommon; - apply_to_live(&producer_live, summary); - apply_to_live(&consumer_live, summary); - } - continue; - } - - if (IsSharedDependentConsumerPreStmt(stmt)) { - prefix_roles[i] = PrefixRole::kConsumerShared; - apply_to_live(&consumer_live, summary); - } else { - prefix_roles[i] = PrefixRole::kCommon; - apply_to_live(&producer_live, summary); - apply_to_live(&consumer_live, summary); - } - } - - for (size_t i = 0; i < movable_begin; ++i) { - switch (prefix_roles[i]) { - case PrefixRole::kSkip: - break; - case PrefixRole::kCommon: - common_pre_stmts.push_back(pre_loop_stmts[i]); - break; - case PrefixRole::kProducer: - producer_prefix_ordered_stmts.push_back(pre_loop_stmts[i]); - break; - case PrefixRole::kConsumer: - consumer_prefix_early_stmts.push_back(pre_loop_stmts[i]); - break; - case PrefixRole::kBoth: - producer_prefix_ordered_stmts.push_back(pre_loop_stmts[i]); - consumer_prefix_early_stmts.push_back(pre_loop_stmts[i]); - break; - case PrefixRole::kConsumerShared: - consumer_shared_prefix_stmts.push_back(pre_loop_stmts[i]); - break; - case PrefixRole::kSpecialTmaStart: - ICHECK(rewritten_producer_prefix[i].defined()); - ICHECK(rewritten_consumer_wait[i].defined()); - producer_prefix_ordered_stmts.push_back( - rewritten_producer_prefix[i].value()); - consumer_wait_prefix_stmts.push_back( - rewritten_consumer_wait[i].value()); - break; - case PrefixRole::kUnknown: - common_pre_stmts.push_back(pre_loop_stmts[i]); - break; - } - } - - for (const auto &s : common_pre_stmts) { - new_seq.push_back(s); - } - - auto MakeOptionalStmt = [](const Array &stmts) -> Optional { - if (stmts.empty()) { - return std::nullopt; - } - return stmts.size() == 1 ? Optional(stmts[0]) - : Optional(SeqStmt(stmts)); - }; - - Array consumer_prefix_stmts; - for (const auto &s : consumer_prefix_early_stmts) { - consumer_prefix_stmts.push_back(s); - } - // Keep pure local init before waits to delay blocking until needed. - for (size_t j = movable_begin; j < pre_loop_stmts.size(); ++j) { - consumer_prefix_stmts.push_back(pre_loop_stmts[j]); - } - for (const auto &s : consumer_wait_prefix_stmts) { - consumer_prefix_stmts.push_back(s); - } - for (const auto &s : consumer_shared_prefix_stmts) { - consumer_prefix_stmts.push_back(s); - } - Optional consumer_prefix = MakeOptionalStmt(consumer_prefix_stmts); - Optional producer_prefix = - MakeOptionalStmt(producer_prefix_ordered_stmts); - - Optional ws_stmt = rebuilt_loop; - Optional producer_guard = std::nullopt; - Optional pre_guard = std::nullopt; - Optional post_guard = std::nullopt; - - // Merge TMA-issue producer prefix into producer branch. - if (producer_prefix.defined()) { - ICHECK(thread_iv_.defined()); - Stmt rewritten = PCThreadIdxRewriter::Rewrite( - producer_prefix.value(), thread_iv_->var, - thread_iv_->var - consumer_thread_extent_, producer_thread_extent_, - /*do_shuffle=*/true); - if (ws_stmt.defined()) { - auto merged = TryPrependToProducerBranch(ws_stmt.value(), rewritten); - if (merged.defined()) { - ws_stmt = merged.value(); - } else { - producer_guard = IfThenElse( - GE(thread_iv_->var, consumer_thread_extent_), rewritten); - } - } else { - producer_guard = IfThenElse( - GE(thread_iv_->var, consumer_thread_extent_), rewritten); - } - } - - // Merge movable pre-loop suffix into consumer branch when possible. - if (consumer_prefix.defined()) { - if (ws_stmt.defined()) { - auto merged = TryPrependToConsumerBranch(ws_stmt.value(), - consumer_prefix.value()); - if (merged.defined()) { - ws_stmt = merged.value(); - } else { - ICHECK(thread_iv_.defined()); - pre_guard = IfThenElse(LT(thread_iv_->var, consumer_thread_extent_), - consumer_prefix.value()); - } - } else { - ICHECK(thread_iv_.defined()); - pre_guard = IfThenElse(LT(thread_iv_->var, consumer_thread_extent_), - consumer_prefix.value()); - } - } - - // Keep post-loop statements on consumer threads. - if (!post_loop_stmts.empty()) { - Stmt post_body = post_loop_stmts.size() == 1 ? post_loop_stmts[0] - : SeqStmt(post_loop_stmts); - if (remap_pure_tma_barriers_) { - // When the target loop remaps pure-TMA forward barriers to the WS - // layout, any remaining TMA forward pairs outside that loop need - // fresh ids as well. Otherwise a rewritten pre-loop pair can alias a - // later consumer-only TMA loop that still uses its original id. - post_body = RewritePureTmaForwardPairsWithFreshBarriers(post_body); - } - bool merged = false; - if (ws_stmt.defined()) { - auto merged_stmt = - TryAppendToConsumerBranch(ws_stmt.value(), post_body); - if (merged_stmt.defined()) { - ws_stmt = merged_stmt.value(); - merged = true; + void VisitStmt_(const BlockNode *op) final { + // Collect layout_map entries so we can cross-check TMA copy targets. + if (op->annotations.count("layout_map")) { + auto anno = op->annotations.Get("layout_map"); + if (auto gmap = anno->as>(); gmap.has_value()) { + for (const auto &[key, val] : gmap.value()) { + Layout layout; + if (auto l = val.as(); l.has_value()) + layout = l.value(); + if (auto buf = key.as(); buf.has_value()) { + layout_map_[buf.value()->data.get()] = {buf.value(), layout}; + } else if (auto var = key.as(); var.has_value()) { + for (const auto &buf : op->alloc_buffers) { + if (buf->data.same_as(var.value())) { + layout_map_[buf->data.get()] = {buf, layout}; + break; + } + } } } - if (!merged) { - ICHECK(thread_iv_.defined()); - post_guard = IfThenElse(LT(thread_iv_->var, consumer_thread_extent_), - post_body); - } - } - - if (producer_guard.defined()) { - new_seq.push_back(producer_guard.value()); - } - if (pre_guard.defined()) { - new_seq.push_back(pre_guard.value()); - } - if (ws_stmt.defined()) { - new_seq.push_back(ws_stmt.value()); - } - if (post_guard.defined()) { - new_seq.push_back(post_guard.value()); - } - - if (new_seq.size() == 1) - return new_seq[0]; - return SeqStmt(new_seq); - } - - // Walk through wrapper nodes - if (auto *if_stmt = body.as()) { - bool then_has_loop = ContainsLoop(if_stmt->then_case, target_loop); - bool else_has_loop = - if_stmt->else_case.defined() && - ContainsLoop(if_stmt->else_case.value(), target_loop); - if (then_has_loop || else_has_loop) { - Stmt new_then = if_stmt->then_case; - Optional new_else = if_stmt->else_case; - if (then_has_loop) { - new_then = RebuildBlockBody(if_stmt->then_case, target_loop, ws_body, - buffer_data_to_buffer, producer_live_seed, - consumer_live_seed); - } - if (else_has_loop) { - new_else = RebuildBlockBody(if_stmt->else_case.value(), target_loop, - ws_body, buffer_data_to_buffer, - producer_live_seed, consumer_live_seed); - } - return IfThenElse(if_stmt->condition, new_then, new_else, - if_stmt->span); - } - } - if (auto *attr = body.as()) { - if (ContainsLoop(attr->body, target_loop)) { - Stmt new_body = RebuildBlockBody( - attr->body, target_loop, ws_body, buffer_data_to_buffer, - producer_live_seed, consumer_live_seed); - return AttrStmt(attr->node, attr->attr_key, attr->value, new_body); } } - if (auto *let_s = body.as()) { - if (ContainsLoop(let_s->body, target_loop)) { - Stmt new_body = RebuildBlockBody( - let_s->body, target_loop, ws_body, buffer_data_to_buffer, - producer_live_seed, consumer_live_seed); - return LetStmt(let_s->var, let_s->value, new_body); - } - } - - // Fallback: return unchanged - return body; + StmtExprVisitor::VisitStmt_(op); } - bool ContainsLoop(const Stmt &stmt, const ForNode *target) { - if (stmt.as() == target) - return true; - if (auto *seq = stmt.as()) { - for (const auto &s : seq->seq) { - if (ContainsLoop(s, target)) - return true; - } - } - if (auto *if_stmt = stmt.as()) { - if (ContainsLoop(if_stmt->then_case, target)) { - return true; - } - if (if_stmt->else_case.defined()) { - return ContainsLoop(if_stmt->else_case.value(), target); - } - return false; - } - if (auto *attr = stmt.as()) { - return ContainsLoop(attr->body, target); - } - if (auto *let_s = stmt.as()) { - return ContainsLoop(let_s->body, target); + /// A copy destination is TMA-compatible if it has no layout annotation, + /// or its annotated layout is a recognised swizzle / linear layout. + bool HasTmaCompatibleLayout(const Buffer &dst) const { + auto it = layout_map_.find(dst->data.get()); + if (it == layout_map_.end()) { + return true; // no annotation → identity layout → TMA OK } - if (auto *realize = stmt.as()) { - return ContainsLoop(realize->block->body, target); - } - if (auto *block = stmt.as()) { - return ContainsLoop(block->body, target); + const auto &[buf, layout] = it->second; + if (!layout.defined()) { + return false; // annotation present but layout not parseable } - return false; + return IsTmaCompatibleLayout(layout, buf); } - IterVar thread_iv_; - PrimExpr - consumer_thread_extent_; // Original thread extent (consumer warp count) - PrimExpr producer_thread_extent_ = IntImm(DataType::Int(32), 128); - Buffer barrier_buf_; // shared.barrier scope buffer for mbarriers - Optional num_threads_; - bool ws_transformed_ = false; - bool use_full_tma_forward_barrier_protocol_ = false; - bool remap_pure_tma_barriers_ = false; - int pure_tma_preloop_fwd_base_ = -1; - int pure_tma_preloop_fwd_count_ = 0; - int pure_tma_preloop_fwd_cursor_ = 0; - VarBindingMap current_loop_guard_bindings_; - bool is_cluster_barrier_ = false; - int cluster_size_ = 1; + Target target_; + bool in_pipeline_{false}; + bool has_pipeline_loop_{false}; + bool has_tma_tile_op_{false}; + // Map from buffer data Var pointer → (Buffer, Layout) for layout_map entries. + std::unordered_map> layout_map_; }; +} // namespace + // --------------------------------------------------------------------------- // Pass registration // --------------------------------------------------------------------------- -using namespace tir::transform; - -// Check only for manual warp specialization ("warp_specialize" attr). -// Unlike WarpSpecializedDetector, we do NOT skip when TMA+mbarrier are -// both present, since that is the expected input pattern for this pass. -class ManualWSDetector : public StmtExprVisitor { -public: - static bool HasManualWS(const Stmt &stmt) { - ManualWSDetector d; - d.VisitStmt(stmt); - return d.has_manual_ws_; - } - -private: - void VisitStmt_(const AttrStmtNode *op) final { - if (op->attr_key == "warp_specialize" && - op->value.as()->value == 1) { - has_manual_ws_ = true; - } - StmtExprVisitor::VisitStmt_(op); - } - bool has_manual_ws_ = false; -}; - tvm::transform::Pass ProducerConsumerWarpSpecialized() { - auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { - bool disable_warp_specialized = - ctx->GetConfig(kDisableWarpSpecialized, Bool(false)).value(); - if (disable_warp_specialized) + using namespace tir::transform; + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + // Skip if disabled. + if (ctx->GetConfig(kDisableWarpSpecialized, Optional()) + .value_or(false)) { return f; - - // Skip if user has manual warp specialization - if (ManualWSDetector::HasManualWS(f->body)) + } + // Skip if the function already has manual WS. + if (ManualWSDetector::HasManualWS(f->body)) { return f; - - return ProducerConsumerWSRewriter::Substitute(f); + } + // Skip if TMA is not available. + auto target = f->GetAttr(tvm::attr::kTarget); + if (!target.defined() || !TargetHasBulkCopy(target.value())) { + return f; + } + // Only apply MVB + WS if the function is a tiled WS candidate. + if (!TiledWSCandidate::Check(f->body, target.value())) { + LOG(WARNING) << "[WS] skipped: no TMA copies in pipeline loop"; + return f; + } + LOG(WARNING) << "[WS] candidate found, applying MVB + WS"; + // Expand shared buffers for pipelining before the WS split. + // Keep the original so we can fall back if the WS rewriter doesn't fire + // (e.g. non-tile-op consumers in the loop body). + PrimFunc original_f = f; + f = ApplyMultiVersionBufferRewriter(std::move(f)); + PrimFunc result = ProducerConsumerWSRewriter::Substitute(std::move(f)); + if (!result->HasNonzeroAttr(kTiledWSApplied)) { + LOG(WARNING) << "[WS] rewriter did not fire, falling back"; + // The TMA kernel needs warp specialization for correct pipelined + // execution. Since the tiled rewriter could not apply WS (e.g. + // conditional loop body), strip pipeline annotations so that + // PipelinePlanning / InjectSoftwarePipeline do not generate + // broken non-WS TMA pipeline code. + class StripPipelineAnnotation : public tir::StmtExprMutator { + public: + tir::Stmt VisitStmt_(const tir::ForNode *op) final { + auto stmt = tir::StmtExprMutator::VisitStmt_(op); + const auto *for_node = stmt.as(); + ICHECK(for_node); + if (for_node->annotations.count("num_stages")) { + tir::For new_for = Downcast(stmt); + auto *n = new_for.CopyOnWrite(); + n->annotations.erase("num_stages"); + return std::move(new_for); + } + return stmt; + } + }; + StripPipelineAnnotation stripper; + auto stripped = stripper(original_f->body); + auto *fn = original_f.CopyOnWrite(); + fn->body = stripped; + return original_f; + } + LOG(WARNING) << "[WS] transformation applied successfully"; + return result; }; return CreatePrimFuncPass(pass_func, 0, "tl.ProducerConsumerWarpSpecialized", {}); @@ -4038,6 +2403,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.ProducerConsumerWarpSpecialized", ProducerConsumerWarpSpecialized); + refl::GlobalDef().def("tl.transform.ProducerConsumerWarpSpecializedTiled", + ProducerConsumerWarpSpecialized); } } // namespace tl diff --git a/src/transform/ptx_async_copy_injector.h b/src/transform/ptx_async_copy_injector.h index 0a5c0107be..80c642562c 100644 --- a/src/transform/ptx_async_copy_injector.h +++ b/src/transform/ptx_async_copy_injector.h @@ -5,15 +5,20 @@ namespace tvm { namespace tl { +struct PTXAsyncCopyInjectResult { + tvm::tir::Stmt stmt; + bool injected_ptx_async_copy{false}; +}; + /*! \brief Inject PTX cp.async lowering patterns into a statement. * * This is the statement-level entrypoint used by other transforms to apply the * same rewrite as the `tl.LowerPTXAsyncCopy` pass, but scoped to a region * (e.g., a lowered parallel loop) rather than the whole PrimFunc. */ -tvm::tir::Stmt InjectPTXAsyncCopy(const tvm::tir::Stmt &body, - bool enable_auto_async_copy, - bool async_without_async_commit_wait = false); +PTXAsyncCopyInjectResult +InjectPTXAsyncCopy(const tvm::tir::Stmt &body, bool enable_auto_async_copy, + bool async_without_async_commit_wait = false); } // namespace tl } // namespace tvm diff --git a/src/transform/reuse_local_descriptor_allocations.cc b/src/transform/reuse_local_descriptor_allocations.cc new file mode 100644 index 0000000000..a4e2dae3d6 --- /dev/null +++ b/src/transform/reuse_local_descriptor_allocations.cc @@ -0,0 +1,254 @@ +/*! + * \file reuse_local_descriptor_allocations.cc + * \brief Pool lexically-disjoint local descriptor allocations. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; +namespace refl = tvm::ffi::reflection; + +namespace { + +bool IsLocalDescriptorScope(const Var &buffer_var) { + std::string scope = GetPtrStorageScope(buffer_var); + return scope.rfind("local.descriptor.", 0) == 0; +} + +bool IsDescriptorHoistBoundary(const AttrStmtNode *op) { + return op->attr_key == tir::attr::thread_extent || + op->attr_key == tir::attr::virtual_thread || op->attr_key == "target"; +} + +bool IsReusableDescriptorAllocate(const AllocateNode *op) { + return IsLocalDescriptorScope(op->buffer_var) && is_one(op->condition) && + op->annotations.empty() && op->ConstantAllocationSize() > 0; +} + +std::string MakeDescriptorSignature(const AllocateNode *op) { + const DataType &dtype = op->dtype; + return GetPtrStorageScope(op->buffer_var) + "|" + + std::to_string(dtype.code()) + ":" + std::to_string(dtype.bits()) + + ":" + std::to_string(dtype.lanes()) + "|" + + std::to_string(op->ConstantAllocationSize()); +} + +struct AllocSite { + Var var; + DataType dtype; + ffi::Array extents; + ffi::Map annotations; + std::string signature; +}; + +class DescriptorAllocCollector : public StmtExprVisitor { +public: + static std::vector Collect(const Stmt &stmt) { + DescriptorAllocCollector collector; + collector(stmt); + return std::move(collector.allocs_); + } + +private: + void VisitStmt_(const AllocateNode *op) final { + if (IsReusableDescriptorAllocate(op)) { + allocs_.push_back(AllocSite{op->buffer_var, op->dtype, op->extents, + op->annotations, + MakeDescriptorSignature(op)}); + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AttrStmtNode *op) final { + if (IsDescriptorHoistBoundary(op)) { + return; + } + StmtExprVisitor::VisitStmt_(op); + } + + std::vector allocs_; +}; + +class DescriptorVarRemapper : public StmtExprMutator { +public: + DescriptorVarRemapper(std::unordered_map var_remap, + std::unordered_set removed_allocs) + : var_remap_(std::move(var_remap)), + removed_allocs_(std::move(removed_allocs)) {} + +private: + PrimExpr VisitExpr_(const VarNode *op) final { + if (auto it = var_remap_.find(op); it != var_remap_.end()) { + return it->second; + } + return tvm::ffi::GetRef(op); + } + + Stmt VisitStmt_(const AllocateNode *op) final { + if (removed_allocs_.count(op->buffer_var.get())) { + return VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const DeclBufferNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + Buffer new_buffer = RemapBuffer(node->buffer); + if (!new_buffer.same_as(node->buffer)) { + node.CopyOnWrite()->buffer = new_buffer; + } + return std::move(node); + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + Buffer new_buffer = RemapBuffer(node->buffer); + if (!new_buffer.same_as(node->buffer)) { + node.CopyOnWrite()->buffer = new_buffer; + } + return std::move(node); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + Buffer new_buffer = RemapBuffer(node->buffer); + if (!new_buffer.same_as(node->buffer)) { + node.CopyOnWrite()->buffer = new_buffer; + } + return std::move(node); + } + + Buffer RemapBuffer(Buffer buffer) const { + if (auto it = var_remap_.find(buffer->data.get()); it != var_remap_.end()) { + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->data = it->second; + return new_buffer; + } + return buffer; + } + + std::unordered_map var_remap_; + std::unordered_set removed_allocs_; +}; + +class ReuseLocalDescriptorAllocationsMutator : public StmtExprMutator { +public: + static PrimFunc Rewrite(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + ReuseLocalDescriptorAllocationsMutator rewriter; + fptr->body = rewriter(std::move(fptr->body)); + return func; + } + +private: + struct PoolSlot { + AllocSite canonical; + int use_count{0}; + }; + + Stmt VisitStmt_(const SeqStmtNode *op) final { + ffi::Array visited_children; + visited_children.reserve(op->seq.size()); + for (const Stmt &stmt : op->seq) { + visited_children.push_back(VisitStmt(stmt)); + } + + std::unordered_map> signature_slots; + std::unordered_map alloc_to_slot; + std::vector slots; + + for (const Stmt &stmt : visited_children) { + std::unordered_map local_slot_index; + for (const AllocSite &alloc : DescriptorAllocCollector::Collect(stmt)) { + int ordinal = local_slot_index[alloc.signature]++; + std::vector &sig_slots = signature_slots[alloc.signature]; + if (static_cast(sig_slots.size()) <= ordinal) { + sig_slots.push_back(static_cast(slots.size())); + slots.push_back(PoolSlot{alloc, 0}); + } + int slot_idx = sig_slots[ordinal]; + alloc_to_slot[alloc.var.get()] = slot_idx; + ++slots[slot_idx].use_count; + } + } + + std::unordered_map var_remap; + std::unordered_set removed_allocs; + std::vector hoisted_allocs; + hoisted_allocs.reserve(slots.size()); + + for (const PoolSlot &slot : slots) { + if (slot.use_count <= 1) { + continue; + } + removed_allocs.insert(slot.canonical.var.get()); + hoisted_allocs.push_back(slot.canonical); + } + + if (hoisted_allocs.empty()) { + return visited_children.size() == 1 ? visited_children[0] + : SeqStmt(visited_children); + } + + for (const auto &[var, slot_idx] : alloc_to_slot) { + if (slots[slot_idx].use_count <= 1) { + continue; + } + removed_allocs.insert(var); + const Var &canonical_var = slots[slot_idx].canonical.var; + if (var != canonical_var.get()) { + var_remap[var] = canonical_var; + } + } + + DescriptorVarRemapper rewriter(std::move(var_remap), + std::move(removed_allocs)); + ffi::Array rewritten_children; + rewritten_children.reserve(visited_children.size()); + for (const Stmt &stmt : visited_children) { + rewritten_children.push_back(rewriter(stmt)); + } + + Stmt body = rewritten_children.size() == 1 ? rewritten_children[0] + : SeqStmt(rewritten_children); + for (auto it = hoisted_allocs.rbegin(); it != hoisted_allocs.rend(); ++it) { + body = Allocate(it->var, it->dtype, it->extents, const_true(), + std::move(body), it->annotations); + } + return body; + } +}; + +} // namespace + +tir::transform::Pass ReuseLocalDescriptorAllocations() { + auto pass_func = [](PrimFunc func, IRModule mod, + tvm::transform::PassContext ctx) { + return ReuseLocalDescriptorAllocationsMutator::Rewrite(std::move(func)); + }; + return tir::transform::CreatePrimFuncPass( + pass_func, 0, "tl.ReuseLocalDescriptorAllocations", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + refl::GlobalDef().def("tl.transform.ReuseLocalDescriptorAllocations", + ReuseLocalDescriptorAllocations); +} + +} // namespace tl +} // namespace tvm diff --git a/testing/conftest.py b/testing/conftest.py index d7952ff4e6..72094aa135 100644 --- a/testing/conftest.py +++ b/testing/conftest.py @@ -56,16 +56,7 @@ def pytest_collection_modifyitems(config, items): def pytest_terminal_summary(terminalreporter, exitstatus, config): """Ensure that at least one test is collected. Error out if all tests are skipped.""" - known_types = { - "failed", - "passed", - "skipped", - "deselected", - "xfailed", - "xpassed", - "warnings", - "error", - } + known_types = {"failed", "passed", "skipped", "deselected", "xfailed", "xpassed", "warnings", "error"} executed_count = sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) if executed_count == 0 and getattr(config, "_perf_items_filtered", 0) > 0: terminalreporter.write_sep( diff --git a/testing/python/analysis/test_tilelang_nested_loop_checker.py b/testing/python/analysis/test_tilelang_nested_loop_checker.py index ef80cf3d64..0038f79f92 100644 --- a/testing/python/analysis/test_tilelang_nested_loop_checker.py +++ b/testing/python/analysis/test_tilelang_nested_loop_checker.py @@ -170,10 +170,7 @@ def run_gemm_nested_pipelines( kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = kernel.get_profiler() @@ -479,10 +476,7 @@ def run_gemm_mixed_pp(): kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = kernel.get_profiler() @@ -516,10 +510,7 @@ def ref_program(A, B): tilelang.compile( program1, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) @@ -604,10 +595,7 @@ def run_gemm_tiled_op_with_parallel(): kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = kernel.get_profiler() @@ -641,10 +629,7 @@ def ref_program(A, B): tilelang.compile( program1, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) diff --git a/testing/python/cache/test_tilelang_kernel_cache.py b/testing/python/cache/test_tilelang_kernel_cache.py index f7a0be19de..a3ca0eee33 100644 --- a/testing/python/cache/test_tilelang_kernel_cache.py +++ b/testing/python/cache/test_tilelang_kernel_cache.py @@ -43,6 +43,17 @@ def _get_target_from_backend(backend: str): return "cutedsl" if backend == "cutedsl" else "auto" +def _require_backend_available(backend: str) -> None: + if backend != "cutedsl": + return + try: + from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available + + check_cutedsl_available() + except ImportError as e: + pytest.skip(f"CuTeDSL backend unavailable: {e}") + + class PostProcCounter: """Track postproc callback invocations with a simple counter.""" @@ -102,6 +113,7 @@ def clean_cache_env(tmp_path, request): """ # This fixture should ONLY be used with @pytest.mark.parametrize("backend", ...) backend = request.node.callspec.params["backend"] # Will raise KeyError if missing + _require_backend_available(backend) cache_dir = tmp_path / "tilelang_cache" cache_dir.mkdir() diff --git a/testing/python/carver/test_tilelang_carver_recommend_hints.py b/testing/python/carver/test_tilelang_carver_recommend_hints.py index a096ec3b26..d62e397d3f 100644 --- a/testing/python/carver/test_tilelang_carver_recommend_hints.py +++ b/testing/python/carver/test_tilelang_carver_recommend_hints.py @@ -134,6 +134,7 @@ def run_fmha_recommend_hints( @tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(8, 0) def test_fmha_recommend_hints(): run_fmha_recommend_hints(4, 32, 512, 512, 128, T.float16, T.float16, T.float16) run_fmha_recommend_hints(4, 32, 512, 512, 128, T.int8, T.int32, T.int32) diff --git a/testing/python/components/test_tilelang_pass_config_disable_tma_lower.py b/testing/python/components/test_tilelang_pass_config_disable_tma_lower.py new file mode 100644 index 0000000000..8a53d99969 --- /dev/null +++ b/testing/python/components/test_tilelang_pass_config_disable_tma_lower.py @@ -0,0 +1,30 @@ +import warnings + +import tilelang +import tilelang.testing +from tilelang import tvm +from tilelang.jit.kernel import JITKernel + + +def test_disable_tma_lower_pass_context_compat(): + with tvm.transform.PassContext(config={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}): + assert bool(tvm.transform.PassContext.current().config["tl.disable_tma_lower"]) + + with tvm.transform.PassContext(config={"tl.disable_tma_lower": True}): + assert bool(tvm.transform.PassContext.current().config["tl.disable_tma_lower"]) + + +def test_disable_tma_lower_warns_in_jit_entry(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + JITKernel( + from_database=True, + target="c", + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) + + assert any("tl.disable_tma_lower" in str(item.message) and "v0.1.10" in str(item.message) for item in caught) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py b/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py index d599e581ac..44bf783650 100644 --- a/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py +++ b/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py @@ -84,10 +84,7 @@ def run_gemm( kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_warp_specialized, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_warp_specialized}, ) profiler = kernel.get_profiler() diff --git a/testing/python/cuda/test_cuda_f32x2_intrinsics.py b/testing/python/cuda/test_cuda_f32x2_intrinsics.py index 6dfb2d8099..aace622dbe 100644 --- a/testing/python/cuda/test_cuda_f32x2_intrinsics.py +++ b/testing/python/cuda/test_cuda_f32x2_intrinsics.py @@ -29,11 +29,7 @@ # Dtype helpers # --------------------------------------------------------------------------- -_DTYPE_MAP = { - "float32": (T.float32, torch.float32), - "bfloat16": (T.bfloat16, torch.bfloat16), - "float16": (T.float16, torch.float16), -} +_DTYPE_MAP = {"float32": (T.float32, torch.float32), "bfloat16": (T.bfloat16, torch.bfloat16), "float16": (T.float16, torch.float16)} # --------------------------------------------------------------------------- # Generic kernel builders using T.Ramp for packed x2 access @@ -108,11 +104,7 @@ def _lower_to_cuda_source(func, target: str = SM80_TARGET) -> str: # --------------------------------------------------------------------------- # Map from Python operator string to (lambda, tl_func_name) -_AUTO_VEC_OPS = { - "add": (lambda a, b: a + b, "add2"), - "sub": (lambda a, b: a - b, "sub2"), - "mul": (lambda a, b: a * b, "mul2"), -} +_AUTO_VEC_OPS = {"add": (lambda a, b: a + b, "add2"), "sub": (lambda a, b: a - b, "sub2"), "mul": (lambda a, b: a * b, "mul2")} def _make_auto_vec_binary_kernel(py_op, dtype_tl, width: int = 4): @@ -165,10 +157,7 @@ def main( _DTYPES = ["float32", "bfloat16", "float16"] # Native cast types expected in codegen for 16-bit packed types -_NATIVE_CAST_TYPE = { - "bfloat16": "__nv_bfloat162", - "float16": "__half2", -} +_NATIVE_CAST_TYPE = {"bfloat16": "__nv_bfloat162", "float16": "__half2"} # Torch reference functions _TORCH_REFS = { diff --git a/testing/python/issue/test_tilelang_issue_1001.py b/testing/python/issue/test_tilelang_issue_1001.py index d6a9ffe264..4e34ffc972 100644 --- a/testing/python/issue/test_tilelang_issue_1001.py +++ b/testing/python/issue/test_tilelang_issue_1001.py @@ -5,10 +5,7 @@ @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def _cumsum_view_infer_layout(hidden): num_tokens = T.dynamic("num_tokens") diff --git a/testing/python/issue/test_tilelang_issue_1008.py b/testing/python/issue/test_tilelang_issue_1008.py index 1b25e203cd..a09e2b8cd6 100644 --- a/testing/python/issue/test_tilelang_issue_1008.py +++ b/testing/python/issue/test_tilelang_issue_1008.py @@ -5,10 +5,7 @@ @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def _fill_with_static_region_kernel(): num_tokens = T.symbolic("num_tokens") @@ -22,10 +19,7 @@ def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821 @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def _fill_with_dynamic_region_kernel(): num_tokens = T.symbolic("num_tokens") diff --git a/testing/python/issue/test_tilelang_issue_1106.py b/testing/python/issue/test_tilelang_issue_1106.py index f41450c0c3..c5ae33b1aa 100644 --- a/testing/python/issue/test_tilelang_issue_1106.py +++ b/testing/python/issue/test_tilelang_issue_1106.py @@ -4,10 +4,7 @@ @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def get_kernel(m: int): dtype = "int32" diff --git a/testing/python/issue/test_tilelang_issue_1210.py b/testing/python/issue/test_tilelang_issue_1210.py index aa0ce2f0da..d9f2a3a63e 100644 --- a/testing/python/issue/test_tilelang_issue_1210.py +++ b/testing/python/issue/test_tilelang_issue_1210.py @@ -45,10 +45,7 @@ def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), T.int32)): def test_make_packed_api_no_free_loop_var(): func, func_if_cond = _make_kernel(4, 4), _make_kernel_if_cond(4, 4) # Keep warp-specialization/TMA disabled to match the original repro - cfg = { - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - } + cfg = {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} tilelang.compile(func, pass_configs=cfg) tilelang.compile(func_if_cond, pass_configs=cfg) diff --git a/testing/python/issue/test_tilelang_issue_1263.py b/testing/python/issue/test_tilelang_issue_1263.py index 418500a8fa..b0c1e80bde 100644 --- a/testing/python/issue/test_tilelang_issue_1263.py +++ b/testing/python/issue/test_tilelang_issue_1263.py @@ -51,18 +51,12 @@ def test_issue_1263_pipeline_no_consumer(): tilelang.compile(_test_kernel(1024, 1024)) tilelang.compile( _test_kernel(1024, 1024), - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) tilelang.compile(_test_kernel_if_cond(1024, 1024)) tilelang.compile( _test_kernel_if_cond(1024, 1024), - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) diff --git a/testing/python/issue/test_tilelang_issue_1744.py b/testing/python/issue/test_tilelang_issue_1744.py index 5610f5b0f4..2e3d7b8385 100644 --- a/testing/python/issue/test_tilelang_issue_1744.py +++ b/testing/python/issue/test_tilelang_issue_1744.py @@ -5,10 +5,7 @@ @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def _buggy_kernel(S: T.Tensor((8), T.bfloat16), D: T.Tensor((4, 64), T.bfloat16)): with T.Kernel(1, threads=128): diff --git a/testing/python/issue/test_tilelang_issue_tma_no_ws.py b/testing/python/issue/test_tilelang_issue_tma_no_ws.py index 0a6ae9f463..a0ab91da27 100644 --- a/testing/python/issue/test_tilelang_issue_tma_no_ws.py +++ b/testing/python/issue/test_tilelang_issue_tma_no_ws.py @@ -21,143 +21,9 @@ def _compile_tvm_ffi(func, pass_configs, **kwargs): tilelang.enable_cache() -@tilelang.testing.requires_cuda_compute_version(9, 0) -def test_tma_lower_no_warp_specialized_injects_mbarrier(): - """Regression for Hopper TMA lowering when warp specialization is disabled. - - When `tl.disable_tma_lower=False` but `tl.disable_warp_specialized=True`, the - optimization pipeline must still run the TMA barrier allocation/injection - passes so generated CUDA source defines and uses `mbarrier[...]` correctly. - """ - - M, K = 16, 128 - block_m, block_k = 4, 128 - threads = 32 - - @T.prim_func - def tma_copy(x: T.Tensor((M, K), T.float16)): - with T.Kernel(T.ceildiv(M, block_m), T.ceildiv(K, block_k), threads=threads) as ( - pid_m, - pid_k, - ): - x_shared = T.alloc_shared((block_m, block_k), dtype=T.float16) - T.fill(x_shared, 0) - T.copy( - x[ - pid_m * block_m : (pid_m + 1) * block_m, - pid_k * block_k : (pid_k + 1) * block_k, - ], - x_shared, - ) - - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } - kernel = _compile_tvm_ffi(tma_copy, pass_configs) - - src = kernel.get_kernel_source() - assert "tl::tma_load" in src - assert "mbarrier_mem" in src - assert "arrive_and_expect_tx" in src - assert "expect_transaction" not in src - assert ".arrive();" not in src - - x = torch.randn((M, K), device="cuda", dtype=torch.float16) - kernel(x) - torch.cuda.synchronize() - - -@tilelang.testing.requires_cuda_compute_version(9, 0) -def test_tma_lower_1d_no_warp_specialized(): - """Regression for issue #1842: 1D TMA load fails when warp specialization is disabled. - - A single-dimension tensor copy (global -> shared -> global) using 1D bulk - TMA must compile and produce correct results when - ``tl.disable_warp_specialized=True``. - """ - - length = 7168 - - @T.prim_func - def tma_copy_1d( - a: T.Tensor((length,), T.float32), - out: T.Tensor((length,), T.float32), - ): - with T.Kernel(1, threads=256): - a_shared = T.alloc_shared((length,), T.float32) - T.copy(a, a_shared) - T.copy(a_shared, out) - - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } - kernel = _compile_tvm_ffi(tma_copy_1d, pass_configs, out_idx=[1]) - - src = kernel.get_kernel_source() - assert "tl::tma_load" in src - assert "mbarrier_mem" in src - assert "tl::tma_store" in src - - t = torch.randn((length,), device="cuda", dtype=torch.float32) - out = kernel(t) - torch.testing.assert_close(out, t) - torch.cuda.synchronize() - - -@tilelang.testing.requires_cuda_compute_version(9, 0) -def test_tma_lower_no_warp_specialized_2d_descriptor_uses_args1_barrier(): - """Cover the 2D-descriptor TMA barrier rewrite path (barrier at args[1]).""" - - M, K = 16, 256 - block_m, block_k = 4, 128 - threads = 32 - - @T.prim_func - def tma_copy_2d_desc(x: T.Tensor((M, K), T.float16)): - with T.Kernel(T.ceildiv(M, block_m), T.ceildiv(K, block_k), threads=threads) as ( - pid_m, - pid_k, - ): - x_shared = T.alloc_shared((block_m, block_k), dtype=T.float16) - T.fill(x_shared, 0) - T.copy( - x[ - pid_m * block_m : (pid_m + 1) * block_m, - pid_k * block_k : (pid_k + 1) * block_k, - ], - x_shared, - ) - - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } - - kernel = _compile_tvm_ffi(tma_copy_2d_desc, pass_configs) - - src = kernel.get_kernel_source() - assert "CUtensorMap" in src - assert "tl::tma_load" in src - - flat_src = " ".join(src.split()) - pattern = r"tl::tma_load\([^,]+,\s*mbarrier\[[0-9]+\]" - assert re.search(pattern, flat_src), ( - f"Expected regex {pattern!r} to match flattened CUDA source. Generated source (truncated):\n{src[:1000]}" - ) - - x = torch.randn((M, K), device="cuda", dtype=torch.float16) - kernel(x) - torch.cuda.synchronize() - - @tilelang.testing.requires_cuda_compute_version(9, 0) def test_num_stages_zero_pure_tma_does_not_auto_warp_specialize(): - """num_stages=0 should keep pure TMA loops out of auto-WS.""" + """num_stages=0 should keep ordinary T.copy on the synchronous path.""" M, K = 8, 256 block_m, block_k = 4, 128 @@ -186,15 +52,11 @@ def copy_loop_num_stages_zero( ], ) - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, - } + pass_configs = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} kernel = _compile_tvm_ffi(copy_loop_num_stages_zero, pass_configs, out_idx=[1]) src = kernel.get_kernel_source() - assert "tl::tma_load" in src + assert "tl::tma_load" not in src assert "__launch_bounds__(160, 1)" not in src assert "if (32 <= ((int)threadIdx.x))" not in src @@ -235,11 +97,7 @@ def copy_loop_num_stages_one( ], ) - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, - } + pass_configs = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} kernel = _compile_tvm_ffi(copy_loop_num_stages_one, pass_configs, out_idx=[1]) src = kernel.get_kernel_source() @@ -278,11 +136,7 @@ def cp_async_only_num_stages_zero( for i in T.serial(bytes_per_copy): y[ko * bytes_per_copy + i] = x_shared[i] - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, - } + pass_configs = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} kernel = _compile_tvm_ffi(cp_async_only_num_stages_zero, pass_configs, out_idx=[1]) src = kernel.get_kernel_source() @@ -297,6 +151,46 @@ def cp_async_only_num_stages_zero( torch.cuda.synchronize() +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_num_stages_one_cp_async_only_keeps_non_ws_launch_shape(): + """Stage-1 cp.async-only loops should stay non-WS on Hopper.""" + + bytes_per_copy = 16 + threads = 32 + + @T.prim_func + def cp_async_only_num_stages_one( + x: T.Tensor((4 * bytes_per_copy,), T.uint8), + y: T.Tensor((4 * bytes_per_copy,), T.uint8), + ): + with T.Kernel(1, threads=threads): + x_shared = T.alloc_shared((bytes_per_copy,), dtype=T.uint8) + for ko in T.Pipelined(4, num_stages=1): + T.ptx_cp_async( + T.access_ptr(x_shared[0], "w", bytes_per_copy), + T.access_ptr(x[ko * bytes_per_copy], "r", bytes_per_copy), + bytes_per_copy, + ) + T.ptx_commit_group() + T.ptx_wait_group(0) + for i in T.serial(bytes_per_copy): + y[ko * bytes_per_copy + i] = x_shared[i] + + pass_configs = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} + kernel = _compile_tvm_ffi(cp_async_only_num_stages_one, pass_configs, out_idx=[1]) + + src = kernel.get_kernel_source() + assert "cp_async_gs<16>" in src + assert "__launch_bounds__(32, 1)" in src + assert "__launch_bounds__(160, 1)" not in src + assert "if (32 <= ((int)threadIdx.x))" not in src + + x = torch.randint(0, 256, (4 * bytes_per_copy,), device="cuda", dtype=torch.uint8) + y = kernel(x) + torch.testing.assert_close(y, x) + torch.cuda.synchronize() + + @tilelang.testing.requires_cuda_compute_version(9, 0) def test_num_stages_one_mixed_tma_cp_async_keeps_auto_ws(): """Mixed TMA+cp.async loops should auto-WS when num_stages is enabled.""" @@ -342,11 +236,7 @@ def mixed_async_num_stages_one( for i in T.serial(cp_async_bytes): meta_out[ko * cp_async_bytes + i] = meta_shared[i] - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, - } + pass_configs = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} kernel = _compile_tvm_ffi(mixed_async_num_stages_one, pass_configs, out_idx=[2, 3]) src = kernel.get_kernel_source() @@ -389,11 +279,7 @@ def mixed_gemm_shared_barrier( T.copy(C_local, C[by * block_m, bx * block_n]) - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, - } + pass_configs = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} kernel = _compile_tvm_ffi(mixed_gemm_shared_barrier, pass_configs, out_idx=[2]) src = kernel.get_kernel_source() @@ -406,7 +292,11 @@ def mixed_gemm_shared_barrier( assert ".expect_transaction(8192);" in src assert src.count(".init(128);") == 6 assert ".init(1);" not in src - assert "tl::mbarrier_cp_async_arrive_noinc(mbarrier[(ko % 3)])" in flat_src + # Mixed TMA+cp.async should reuse the same forward barrier set. Depending on + # when cp.async is lowered, this may appear either as an explicit + # noinc-arrive or as a regular arrive on the same forward barrier after the + # cp.async visibility sync. + assert "tl::mbarrier_cp_async_arrive_noinc(mbarrier[(ko % 3)])" in flat_src or "mbarrier[(ko % 3)].arrive();" in flat_src assert "tl::mbarrier_cp_async_arrive_noinc(mbarrier[((ko % 3) + 4)])" not in flat_src assert "mbarrier[((ko % 3) + 7)]" not in flat_src assert "mbarrier[((ko % 3) + 10)]" not in flat_src @@ -419,7 +309,7 @@ def mixed_gemm_shared_barrier( torch.testing.assert_close(c, ref, rtol=1e-2, atol=1e-2) -@tilelang.testing.requires_cuda_compute_version(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_sparse_ws_regular_metadata_copy_stays_in_producer(): """Ordinary global->shared metadata copies should stay in the producer.""" @@ -458,11 +348,7 @@ def sparse_tensorcore_metadata_copy( T.copy(C_local, C[by * block_m, bx * block_n]) - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, - } + pass_configs = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} kernel = _compile_tvm_ffi(sparse_tensorcore_metadata_copy, pass_configs, out_idx=[3]) src = kernel.get_kernel_source() @@ -557,11 +443,7 @@ def sparse_flash_attn( T.copy(acc_o, O_shared) T.copy(O_shared, Output[bz, by, bx * block_m : (bx + 1) * block_m, :]) - pass_configs = { - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False, - } + pass_configs = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} kernel = _compile_tvm_ffi(sparse_flash_attn, pass_configs, out_idx=[4]) src = kernel.get_kernel_source() diff --git a/testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py b/testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py index 299fbbe8b7..34b1b1dba1 100644 --- a/testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py +++ b/testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py @@ -20,7 +20,7 @@ def _compile_tvm_ffi(func, pass_configs=None): @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_ws_keeps_full_producer_extent_for_lowered_simt_copy(): M, N, K = 128, 64, 64 block_M, block_N, block_K = 64, 64, 32 @@ -66,7 +66,7 @@ def main( assert "__launch_bounds__(512, 1)" in src assert "if (256 <= ((int)threadIdx.x)) {" in flat_src - assert "tl::tl_shuffle_elect<256>()" in src + assert "tl::tl_shuffle_elect<256>()" in src or "if (((int)threadIdx.x) == 256) {" in src assert re.search(r"tl::__sync_thread_partial<\d+, 256>\(\);", src), src diff --git a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_tcgen5_ts.py b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_tcgen5_ts.py index d5f44d0f84..0231d1d67b 100644 --- a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_tcgen5_ts.py +++ b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_tcgen5_ts.py @@ -11,10 +11,7 @@ tilelang.testing.set_random_seed(0) -PASS_CFG = { - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -} +PASS_CFG = {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} def matmul_ss(M, N, K, bM, bN, bK, in_dtype, out_dtype, accum_dtype, threads): diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_batched.py b/testing/python/kernel/test_tilelang_kernel_gemm_batched.py index 9478f604c8..c73923e0e9 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_batched.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_batched.py @@ -88,10 +88,7 @@ def run_gemm_batched( kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = kernel.get_profiler() diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py b/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py index 1f76600325..8f5509f280 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py @@ -55,10 +55,7 @@ def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int, func, out_idx=[2], target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) # Create random input tensors on the GPU a = torch.randn(M, K, device="cuda", dtype=torch.float16) diff --git a/testing/python/kernel/test_tilelang_kernel_int8_gemm_tcgen5.py b/testing/python/kernel/test_tilelang_kernel_int8_gemm_tcgen5.py index 976c0c5b16..a8dbd5258d 100644 --- a/testing/python/kernel/test_tilelang_kernel_int8_gemm_tcgen5.py +++ b/testing/python/kernel/test_tilelang_kernel_int8_gemm_tcgen5.py @@ -44,10 +44,7 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_ func, out_idx=-1, target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) assert out_dtype in [T.int32], "Currently only int32 is supported" diff --git a/testing/python/language/test_tilelang_language_all_of.py b/testing/python/language/test_tilelang_language_all_of.py index db694d3376..83586d506a 100644 --- a/testing/python/language/test_tilelang_language_all_of.py +++ b/testing/python/language/test_tilelang_language_all_of.py @@ -230,10 +230,7 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi kernel = tilelang.compile( func, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) @@ -278,10 +275,7 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio kernel = tilelang.compile( func, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) diff --git a/testing/python/language/test_tilelang_language_annotate_safe_value.py b/testing/python/language/test_tilelang_language_annotate_safe_value.py index be5ef5fdf1..1207277e72 100644 --- a/testing/python/language/test_tilelang_language_annotate_safe_value.py +++ b/testing/python/language/test_tilelang_language_annotate_safe_value.py @@ -31,10 +31,7 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16, kernel = tilelang.compile( program, out_idx=[1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) diff --git a/testing/python/language/test_tilelang_language_any_of.py b/testing/python/language/test_tilelang_language_any_of.py index 74db94f7c2..46a834c698 100644 --- a/testing/python/language/test_tilelang_language_any_of.py +++ b/testing/python/language/test_tilelang_language_any_of.py @@ -230,10 +230,7 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi kernel = tilelang.compile( func, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) @@ -278,10 +275,7 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio kernel = tilelang.compile( func, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) diff --git a/testing/python/language/test_tilelang_language_chain_equal.py b/testing/python/language/test_tilelang_language_chain_equal.py index 083eefdcb4..65393cf8a1 100644 --- a/testing/python/language/test_tilelang_language_chain_equal.py +++ b/testing/python/language/test_tilelang_language_chain_equal.py @@ -5,10 +5,7 @@ @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def chain_equal(N, block_size, dtype=T.float32): @T.prim_func diff --git a/testing/python/language/test_tilelang_language_clear.py b/testing/python/language/test_tilelang_language_clear.py index 2e4c732fcf..c3e9df24e3 100644 --- a/testing/python/language/test_tilelang_language_clear.py +++ b/testing/python/language/test_tilelang_language_clear.py @@ -41,7 +41,7 @@ def main( def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) - kernel = tilelang.compile(program, out_idx=[2], pass_configs={"tl.disable_tma_lower": True}) + kernel = tilelang.compile(program, out_idx=[2]) import torch from tilelang.utils import map_torch_type diff --git a/testing/python/language/test_tilelang_language_composable_index.py b/testing/python/language/test_tilelang_language_composable_index.py index 09f9ad9c45..51fc21b873 100644 --- a/testing/python/language/test_tilelang_language_composable_index.py +++ b/testing/python/language/test_tilelang_language_composable_index.py @@ -30,10 +30,7 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype kernel = tilelang.compile( program, out_idx=[1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) diff --git a/testing/python/language/test_tilelang_language_copy.py b/testing/python/language/test_tilelang_language_copy.py index 1943999322..2efa2784af 100644 --- a/testing/python/language/test_tilelang_language_copy.py +++ b/testing/python/language/test_tilelang_language_copy.py @@ -29,7 +29,7 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16) kernel = tilelang.compile( program, out_idx=[1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) source = kernel.get_kernel_source() print(source) @@ -65,10 +65,7 @@ def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N= kernel = tilelang.compile( program, out_idx=[1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) if isinstance(NN, T.Var): NN = N * 2 @@ -102,7 +99,7 @@ def run_tilelang_copy_bufferload(num_tokens=128, dtype=T.float16): tilelang.compile( program, out_idx=[1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) @@ -129,7 +126,7 @@ def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, blo kernel = tilelang.compile( program, out_idx=[1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) @@ -158,7 +155,7 @@ def run_tilelang_copy_shape_mismatched(M=1024, N=1024, dtype=T.float16): kernel = tilelang.compile( program, out_idx=[1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) diff --git a/testing/python/language/test_tilelang_language_let_layout.py b/testing/python/language/test_tilelang_language_let_layout.py index fec30b914b..d413ea519f 100644 --- a/testing/python/language/test_tilelang_language_let_layout.py +++ b/testing/python/language/test_tilelang_language_let_layout.py @@ -112,10 +112,7 @@ def test_blocksparse_copy_cp_async(): N=1024, block_M=128, block_N=128, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) diff --git a/testing/python/language/test_tilelang_language_mask_op.py b/testing/python/language/test_tilelang_language_mask_op.py index cd899a606a..abc54cef14 100644 --- a/testing/python/language/test_tilelang_language_mask_op.py +++ b/testing/python/language/test_tilelang_language_mask_op.py @@ -28,7 +28,7 @@ def main( def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype) - kernel = tilelang.compile(program, out_idx=[1], pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}) + kernel = tilelang.compile(program, out_idx=[1], pass_configs={"tl.disable_warp_specialized": True}) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -62,7 +62,7 @@ def main( def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype) - kernel = tilelang.compile(program, out_idx=[1], pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}) + kernel = tilelang.compile(program, out_idx=[1], pass_configs={"tl.disable_warp_specialized": True}) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -97,7 +97,7 @@ def main( def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype) - kernel = tilelang.compile(program, out_idx=[1], pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}) + kernel = tilelang.compile(program, out_idx=[1], pass_configs={"tl.disable_warp_specialized": True}) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -131,7 +131,7 @@ def main( def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype) - kernel = tilelang.compile(program, out_idx=[1], pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}) + kernel = tilelang.compile(program, out_idx=[1], pass_configs={"tl.disable_warp_specialized": True}) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) diff --git a/testing/python/language/test_tilelang_language_pipeline.py b/testing/python/language/test_tilelang_language_pipeline.py index 8136e246f0..e0cb4612d4 100644 --- a/testing/python/language/test_tilelang_language_pipeline.py +++ b/testing/python/language/test_tilelang_language_pipeline.py @@ -86,10 +86,7 @@ def run_gemm( kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = kernel.get_profiler() @@ -121,10 +118,7 @@ def test_pipeline_order_stage(): @tilelang.jit( out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype=T.float16, accum_dtype=T.float32): block_mask_shape = (M // block_M, N // block_N, K // block_K) diff --git a/testing/python/language/test_tilelang_language_ptr.py b/testing/python/language/test_tilelang_language_ptr.py index 314f4fdcc2..41c6fa9f4d 100644 --- a/testing/python/language/test_tilelang_language_ptr.py +++ b/testing/python/language/test_tilelang_language_ptr.py @@ -166,7 +166,7 @@ def run_pointer_table_multi_copy(G, N, dtype=T.float16): def run_pointer_table_grouped_matmul(batch_sizes_list, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): program = pointer_table_grouped_matmul_test(batch_sizes_list, N, K, block_M, block_N, block_K, dtype, accum_dtype) - compile_kwargs = {"pass_configs": {"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}} + compile_kwargs = {"pass_configs": {"tl.disable_warp_specialized": True}} cython_jit_kernel = tl.compile(program, execution_backend="cython", **compile_kwargs) ffi_jit_kernel = tl.compile(program, execution_backend="tvm_ffi", **compile_kwargs) diff --git a/testing/python/language/test_tilelang_language_reshape.py b/testing/python/language/test_tilelang_language_reshape.py index 10c3d0ce87..c7ff50c145 100644 --- a/testing/python/language/test_tilelang_language_reshape.py +++ b/testing/python/language/test_tilelang_language_reshape.py @@ -25,10 +25,7 @@ def run_reshape(N, M, dtype): jit_kernel = tl.compile( program, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = jit_kernel.get_profiler() @@ -68,10 +65,7 @@ def run_reshape_smem_1d_2_2d(N, M, dtype): jit_kernel = tl.compile( program, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = jit_kernel.get_profiler() @@ -110,10 +104,7 @@ def run_reshape_smem_2d_2_1d(N, M, dtype): jit_kernel = tl.compile( program, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = jit_kernel.get_profiler() @@ -153,10 +144,7 @@ def run_reshape_fragment(N, M, dtype): jit_kernel = tl.compile( program, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = jit_kernel.get_profiler() @@ -199,10 +187,7 @@ def run_reshape_layout_transform_shared(N, M, dtype): jit_kernel = tl.compile( program, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = jit_kernel.get_profiler() @@ -242,10 +227,7 @@ def run_reduce_after_reshape(N, M, dtype): jit_kernel = tl.compile( program, out_idx=-1, - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = jit_kernel.get_profiler() diff --git a/testing/python/language/test_tilelang_language_tma_copy.py b/testing/python/language/test_tilelang_language_tma_copy.py index 5811b10daa..c81721fa2d 100644 --- a/testing/python/language/test_tilelang_language_tma_copy.py +++ b/testing/python/language/test_tilelang_language_tma_copy.py @@ -4,7 +4,7 @@ T.tma_copy() emits only expect_tx + tma_load (no arrive, no wait). The user must explicitly call T.barrier_arrive() and T.mbarrier_wait_parity(). This allows multiple tma_copy operations to share a single barrier arrive. - MultiVersionBuffer expands the barrier to num_stages versions automatically. + Pipeline buffer versioning expands the barrier to num_stages versions automatically. For TMA stores (shared -> global): T.tma_copy() emits tma_store + tma_store_arrive (no wait). @@ -84,9 +84,7 @@ def run_gemm_tma_copy(num_stages): kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) print(kernel.get_kernel_source()) profiler = kernel.get_profiler() @@ -183,9 +181,7 @@ def run_gemm_tma_copy_store(num_stages): kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) print(kernel.get_kernel_source()) profiler = kernel.get_profiler() @@ -213,3 +209,4 @@ def test_tma_copy_store_pipeline_3_stages(): if __name__ == "__main__": tilelang.testing.main() + # test_tma_copy_pipeline_2_stages() diff --git a/testing/python/language/test_tilelang_language_tma_store.py b/testing/python/language/test_tilelang_language_tma_store.py index de4d0b97af..f8b79143db 100644 --- a/testing/python/language/test_tilelang_language_tma_store.py +++ b/testing/python/language/test_tilelang_language_tma_store.py @@ -1,8 +1,12 @@ -"""Test T.tma_copy() for TMA store (shared -> global) with user-managed synchronization. +"""Tests for TMA store (shared -> global). -T.tma_copy(shared_buf, global_buf) emits tma_store + tma_store_arrive (no wait). -The user must explicitly call T.tma_store_wait() for synchronization. -No barrier argument is needed for stores. +Explicit T.tma_copy(shared_buf, global_buf) emits tma_store + tma_store_arrive +(no wait). The user must explicitly call T.tma_store_wait() for +synchronization. + +Plain T.copy(shared_buf, global_buf) may also auto-lower to tma_store when the +store-side TMA constraints are satisfied. In that case lowering emits both +tma_store_arrive and tma_store_wait automatically. """ from tilelang import tvm as tvm @@ -76,12 +80,11 @@ def run_gemm_tma_store(num_stages): kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) kernel_source = kernel.get_kernel_source() print(kernel_source) + # exit() # Verify that the generated kernel contains tma_store_arrive but NOT tma_store_wait # (the wait is issued separately by the user via T.tma_store_wait) assert "tma_store_arrive" in kernel_source, "Expected tma_store_arrive in kernel source" @@ -97,6 +100,44 @@ def ref_program(A, B): profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) +def auto_tma_store_copy(M, N, block_M, block_N, dtype, threads): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(A_shared, C[by * block_M, bx * block_N]) + + return main + + +def run_auto_tma_store_copy(): + M = N = 256 + block_M = block_N = 128 + dtype = T.float16 + threads = 128 + + program = auto_tma_store_copy(M, N, block_M, block_N, dtype, threads) + kernel = tilelang.compile( + program, + out_idx=[1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + ) + kernel_source = kernel.get_kernel_source() + assert "tma_store_arrive" in kernel_source, "Expected auto tma_store_arrive in kernel source" + assert "tma_store_wait" in kernel_source, "Expected auto tma_store_wait in kernel source" + + profiler = kernel.get_profiler() + + def ref_program(A): + return A + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_tma_store_2_stages(): @@ -109,5 +150,12 @@ def test_tma_store_3_stages(): run_gemm_tma_store(num_stages=3) +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_plain_copy_auto_tma_store(): + run_auto_tma_store_copy() + + if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + test_tma_store_2_stages() diff --git a/testing/python/language/test_tilelang_language_transpose.py b/testing/python/language/test_tilelang_language_transpose.py index 8d9b8bd34c..2bf1633525 100644 --- a/testing/python/language/test_tilelang_language_transpose.py +++ b/testing/python/language/test_tilelang_language_transpose.py @@ -46,10 +46,7 @@ def run_tilelang_transpose(M=128, N=128, block_M=128, block_N=128, dtype=T.float kernel = tilelang.compile( program, out_idx=[1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) @@ -88,10 +85,7 @@ def run_tilelang_transpose_square(M=256, block_M=128, dtype=T.float16): kernel = tilelang.compile( program, out_idx=[1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) a = torch.randn(M, M, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) diff --git a/testing/python/language/test_tilelang_language_wgmma_gemm.py b/testing/python/language/test_tilelang_language_wgmma_gemm.py index 2d575ef383..4f5c573837 100644 --- a/testing/python/language/test_tilelang_language_wgmma_gemm.py +++ b/testing/python/language/test_tilelang_language_wgmma_gemm.py @@ -28,7 +28,7 @@ def main( @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) @pytest.mark.parametrize( "gemm_api", [T.wgmma_gemm], @@ -41,7 +41,7 @@ def test_wgmma_gemm_has_no_implicit_wait(gemm_api): @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_wgmma_gemm_dispatch_has_no_implicit_wait(): kernel = tilelang.compile( _make_wgmma_kernel(lambda A, B, C: T.wgmma_gemm(A, B, C, clear_accum=True)), @@ -52,7 +52,7 @@ def test_wgmma_gemm_dispatch_has_no_implicit_wait(): @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_wgmma_gemm_rejects_mma_fallback(): @T.prim_func def main( diff --git a/testing/python/language/test_tilelang_memory_leak.py b/testing/python/language/test_tilelang_memory_leak.py index 7da187fa37..f58b884d76 100644 --- a/testing/python/language/test_tilelang_memory_leak.py +++ b/testing/python/language/test_tilelang_memory_leak.py @@ -9,10 +9,7 @@ def test_tilelang_globals_leak(): @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def get_dummy_kernel(): @T.prim_func diff --git a/testing/python/layout/test_tilelang_annotate_loop_layout.py b/testing/python/layout/test_tilelang_annotate_loop_layout.py index 52653a9d12..4a50b98ffd 100644 --- a/testing/python/layout/test_tilelang_annotate_loop_layout.py +++ b/testing/python/layout/test_tilelang_annotate_loop_layout.py @@ -32,6 +32,7 @@ def loop_layout_fn(i, j): @tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_lt(9, 0) def test_loop_layout_identity(): def loop_layout_fn(i, j): forward_thread = i diff --git a/testing/python/runtime/test_tilelang_runtime_tma_validation.py b/testing/python/runtime/test_tilelang_runtime_tma_validation.py index 140f13a2ff..a3c9da5138 100644 --- a/testing/python/runtime/test_tilelang_runtime_tma_validation.py +++ b/testing/python/runtime/test_tilelang_runtime_tma_validation.py @@ -79,14 +79,17 @@ def tma_copy_2d_desc( ): with T.Kernel(T.ceildiv(m, block_m), T.ceildiv(k, block_k), threads=threads) as (pid_m, pid_k): x_shared = T.alloc_shared((block_m, block_k), dtype=T.float16) - T.fill(x_shared, 0) - T.copy( + mbar = T.alloc_barrier(1) + T.tma_copy( x[ pid_m * block_m : (pid_m + 1) * block_m, pid_k * block_k : (pid_k + 1) * block_k, ], x_shared, + barrier=mbar, ) + T.barrier_arrive(mbar) + T.mbarrier_wait_parity(mbar, 0) T.copy( x_shared, y[ @@ -97,11 +100,7 @@ def tma_copy_2d_desc( kernel = _compile_tvm_ffi( tma_copy_2d_desc, - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) source = kernel.get_host_source() diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index e2a2175632..e2d7a5ee3e 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -86,10 +86,7 @@ def run_gemm_ss( kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) @@ -244,10 +241,7 @@ def run_gemm_rs( kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) @@ -401,10 +395,7 @@ def run_gemm_sr( kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) @@ -562,10 +553,7 @@ def run_gemm_rr( kernel = tilelang.compile( program, out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index d3ef1d5879..cfd0f75e0a 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -304,7 +304,7 @@ def run_gemm_sp_sm80( @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) @pytest.mark.parametrize( "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", [ diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py index 3f3273b9aa..6ec5718e8a 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -104,10 +104,7 @@ def run_gemm_ss( kernel = tilelang.compile( program, out_idx=[3], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) @@ -273,10 +270,7 @@ def run_gemm_rs( kernel = tilelang.compile( program, out_idx=[3], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") @@ -424,10 +418,7 @@ def run_gemm_sr( kernel = tilelang.compile( program, out_idx=[3], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") @@ -579,10 +570,7 @@ def run_gemm_rr( kernel = tilelang.compile( program, out_idx=[3], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") diff --git a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py index 4f64fc18f5..eb71215ccb 100644 --- a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py +++ b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py @@ -1,9 +1,31 @@ from tilelang import tvm as tvm import tilelang as tl import tilelang.language as T +from tilelang.layout import Layout import tilelang.testing from tvm.tir.stmt_functor import post_order_visit +_MVB_ATTR_KEYS = frozenset( + [ + "tl.pipeline_mvb_num_stages", + "tl.pipeline_mvb_stage_expr", + "tl.pipeline_mvb_parity_expr", + "tl.pipeline_context_num_stages", + ] +) + + +@tvm.tir.transform.prim_func_pass(opt_level=0) +def _strip_mvb_attrs(func, mod, ctx): + """Remove intermediate MVB attributes that are consumed by later passes.""" + + def _visit(stmt): + if isinstance(stmt, tvm.tir.AttrStmt) and str(stmt.attr_key) in _MVB_ATTR_KEYS: + return stmt.body + return None + + return func.with_body(tvm.tir.stmt_functor.ir_transform(func.body, None, _visit, ["tir.AttrStmt"])) + def _check(original, transformed): func = original @@ -12,6 +34,7 @@ def _check(original, transformed): mod = tl.transform.Simplify()(mod) mod = tl.transform.LowerOpaqueBlock()(mod) mod = tl.transform.Simplify()(mod) + mod = _strip_mvb_attrs(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) @@ -31,6 +54,93 @@ def _visit(node): return attr_count, call_count +def _collect_attr_values(func, attr_key): + values = [] + stmt = func.body if hasattr(func, "body") else func + + def _visit(node): + if isinstance(node, tvm.tir.AttrStmt) and str(node.attr_key) == attr_key: + value = node.value + if isinstance(value, tvm.tir.IntImm): + values.append(int(value.value)) + + post_order_visit(stmt, _visit) + return values + + +def _collect_attr_value_nodes(func, attr_key): + values = [] + + def _visit(node): + if isinstance(node, tvm.tir.AttrStmt) and str(node.attr_key) == attr_key: + values.append(node.value) + + post_order_visit(func.body, _visit) + return values + + +def _collect_wait_args(func): + wait_args = [] + stmt = func.body if hasattr(func, "body") else func + + def _visit(node): + if ( + isinstance(node, tvm.tir.Call) + and isinstance(node.op, tvm.ir.Op) + and str(node.op.name) == "tir.ptx_wait_group" + and len(node.args) == 1 + ): + arg = node.args[0] + if isinstance(arg, tvm.tir.IntImm): + wait_args.append(int(arg.value)) + + post_order_visit(stmt, _visit) + return wait_args + + +def _find_pipelined_loop(func): + loops = [] + + def _visit(node): + if isinstance(node, tvm.tir.For) and "tl_pipelined_num_stages" in node.annotations: + loops.append(node) + + post_order_visit(func.body, _visit) + assert loops, "Expected at least one loop annotated with tl_pipelined_num_stages" + return loops[0] + + +def _count_copy_calls_with_annotation(func, annotation_key): + annotated = 0 + total = 0 + + def _visit(node): + nonlocal annotated, total + if not isinstance(node, tvm.tir.Call) or not isinstance(node.op, tvm.ir.Op): + return + if str(node.op.name) not in {"tl.tileop.copy", "tl.tileop.async_copy"}: + return + total += 1 + value = node.annotations.get(annotation_key) if node.annotations else None + if isinstance(value, tvm.tir.IntImm) and int(value.value) != 0: + annotated += 1 + + post_order_visit(func.body, _visit) + return annotated, total + + +def _find_block_with_layout_map(func): + blocks = [] + + def _visit(node): + if isinstance(node, tvm.tir.Block) and "layout_map" in node.annotations: + blocks.append(node) + + post_order_visit(func.body, _visit) + assert blocks, "Expected at least one block with layout_map" + return blocks[0] + + def test_trival_pipeline(): @T.prim_func def before(A: T.Tensor((16, 1), T.float32), C: T.Tensor((16, 1), T.float32)): @@ -104,5 +214,289 @@ def before(A: T.Tensor((16,), T.uint8), B: T.Tensor((16,), T.uint8)): assert calls.get("tir.ptx_wait_group", 0) > 0 +def test_async_pipeline_groups_multiple_copy_producers(): + @T.prim_func + def before( + A: T.Tensor((16, 16), T.float32), + B: T.Tensor((16, 16), T.float32), + C: T.Tensor((16, 16), T.float32), + ): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 4, + annotations={ + "software_pipeline_stage": [0, 0, 1], + "software_pipeline_order": [0, 1, 2], + "software_pipeline_async_stages": [0], + "software_pipeline_async_producers": [1, 1, 0], + "software_pipeline_async_producer_groups": [0, 0, -1], + }, + ): + with T.block("compute"): + T.reads(A[tx, i], B[tx, i]) + T.writes(C[tx, i]) + A_shared = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared") + B_shared = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared") + with T.block("copy_a"): + T.reads(A[tx, i]) + T.writes(A_shared[tx, 0]) + A_shared[tx, 0] = A[tx, i] + with T.block("copy_b"): + T.reads(B[tx, i]) + T.writes(B_shared[tx, 0]) + B_shared[tx, 0] = B[tx, i] + with T.block("consume"): + T.reads(A_shared[tx, 0], B_shared[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tl.transform.Simplify()(mod) + mod = tl.transform.LowerOpaqueBlock()(mod) + mod = tl.transform.Simplify()(mod) + + attrs, calls = _count_attrs_and_calls(mod["main"]) + assert attrs.get("async_scope", 0) > 0 + assert attrs.get("async_commit_queue_scope", 0) == 0 + assert attrs.get("async_wait_queue_scope", 0) == 0 + assert attrs.get("async_wait_inflight_count", 0) == 0 + assert calls.get("tir.ptx_commit_group", 0) > 0 + assert 1 in _collect_wait_args(mod["main"]) + + +def test_async_pipeline_only_wraps_producer_statements_from_explicit_group_annotations(): + @T.prim_func + def before( + A: T.Tensor((16, 16), T.float32), + B: T.Tensor((16, 16), T.float32), + C: T.Tensor((16, 16), T.float32), + ): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 4, + annotations={ + "software_pipeline_stage": [0, 0, 0, 1], + "software_pipeline_order": [0, 1, 2, 3], + "software_pipeline_async_stages": [0], + "software_pipeline_async_producers": [0, 1, 1, 0], + "software_pipeline_async_producer_groups": [-1, 0, 0, -1], + }, + ): + with T.block("compute"): + T.reads(A[tx, i], B[tx, i]) + T.writes(C[tx, i]) + A_shared = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared") + B_shared = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared") + with T.block("fill"): + T.reads() + T.writes(A_shared[tx, 0]) + A_shared[tx, 0] = T.float32(0) + with T.block("copy_a"): + T.reads(A[tx, i]) + T.writes(A_shared[tx, 0]) + A_shared[tx, 0] = A[tx, i] + with T.block("copy_b"): + T.reads(B[tx, i]) + T.writes(B_shared[tx, 0]) + B_shared[tx, 0] = B[tx, i] + with T.block("consume"): + T.reads(A_shared[tx, 0], B_shared[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tl.transform.Simplify()(mod) + mod = tl.transform.LowerOpaqueBlock()(mod) + mod = tl.transform.Simplify()(mod) + + attrs, calls = _count_attrs_and_calls(mod["main"]) + # Dead prologue/epilogue producer clones are now dropped during injection, + # so only the live producer copies remain wrapped. + assert attrs.get("async_scope", 0) == 4 + assert attrs.get("async_commit_queue_scope", 0) == 0 + assert calls.get("tir.ptx_commit_group", 0) == 2 + + +def test_async_pipeline_marks_copy_ops_for_pipeline_managed_cp_async_sync(): + @T.prim_func + def before( + A: T.Tensor((16, 16), T.float32), + B: T.Tensor((16, 16), T.float32), + C: T.Tensor((16, 16), T.float32), + ): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 4, + annotations={ + "software_pipeline_stage": [0, 0, 1], + "software_pipeline_order": [0, 1, 2], + "software_pipeline_async_stages": [0], + "software_pipeline_async_producers": [1, 1, 0], + "software_pipeline_async_producer_groups": [0, 0, -1], + }, + ): + with T.block("compute"): + T.reads(A[tx, i], B[tx, i]) + T.writes(C[tx, i]) + A_shared = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared") + B_shared = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared") + T.copy(A[tx, i : i + 1], A_shared[tx, 0:1]) + T.copy(B[tx, i : i + 1], B_shared[tx, 0:1]) + C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tl.transform.InjectSoftwarePipeline()(mod) + + annotated, total = _count_copy_calls_with_annotation(mod["main"], "no_implicit_async_commit_wait") + assert total > 0 + assert annotated == total + + +def test_async_pipeline_does_not_mark_non_cp_async_compatible_copy(): + @T.prim_func + def before( + A: T.Tensor((16, 16), T.bfloat16), + C: T.Tensor((16, 16), T.float32), + ): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 4, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + "software_pipeline_async_stages": [0], + "software_pipeline_async_producers": [1, 0], + "software_pipeline_async_producer_groups": [0, -1], + }, + ): + with T.block("compute"): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + S = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared") + T.copy(A[tx, i : i + 1], S[tx, 0:1]) + C[tx, i] = S[tx, 0] + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tl.transform.InjectSoftwarePipeline()(mod) + + annotated, total = _count_copy_calls_with_annotation(mod["main"], "no_implicit_async_commit_wait") + assert total > 0 + assert annotated == 0 + + +def test_async_pipeline_relaxes_loop_wait_and_splits_trailing_drain(): + @T.prim_func + def before(A: T.Tensor((32,), T.uint8), B: T.Tensor((32,), T.uint8)): + S = T.alloc_buffer((4,), dtype=T.uint8, scope="shared") + for i in T.serial( + 0, + 4, + annotations={ + "software_pipeline_stage": [0, 2], + "software_pipeline_order": [0, 1], + "software_pipeline_async_stages": [0], + "software_pipeline_async_producers": [1, 0], + "software_pipeline_async_producer_groups": [0, -1], + }, + ): + with T.block("copy"): + T.reads(A[i * 4 : i * 4 + 4]) + T.writes(S[0:4]) + T.copy(A[i * 4 : i * 4 + 4], S[0:4]) + with T.block("consume"): + T.reads(S[0:4]) + T.writes(B[i * 4 : i * 4 + 4]) + for j in range(4): + B[i * 4 + j] = S[j] + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tl.transform.Simplify()(mod) + + func = mod["main"] + loop = _find_pipelined_loop(func) + loop_waits = _collect_wait_args(loop.body) + all_waits = _collect_wait_args(func) + + assert loop_waits == [2], f"Expected relaxed loop wait to keep two groups in flight, got {loop_waits}" + assert all_waits == [2, 2, 0], f"Expected trailing waits to split into retain+drain, got {all_waits}" + + +def test_degenerate_pipeline_with_single_stage_is_not_expanded(): + @T.prim_func + def before(B: T.Tensor((128,), T.float32)): + with T.Kernel(1, threads=128) as _: + frag = T.alloc_fragment((4, 128), T.float16) + split = T.alloc_fragment((128,), T.float32) + scale = T.alloc_fragment((128,), T.float32) + for k in T.serial( + 4, + annotations={"software_pipeline_stage": [2, 2], "software_pipeline_order": [0, 1], "tl_pipelined_num_stages": 2}, + ): + for i in T.Parallel(128): + split[i] = T.Cast("float32", frag[k, i]) + for i in T.Parallel(128): + scale[i] = split[i] + B[i] = scale[i] + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tl.transform.Simplify()(mod) + + func = mod["main"] + attrs, calls = _count_attrs_and_calls(func) + assert attrs.get("tl.pipeline_context_num_stages", 0) == 0 + assert attrs.get("tl.pipeline_mvb_num_stages", 0) == 0 + assert attrs.get("tl.pipeline_mvb_stage_expr", 0) == 0 + assert attrs.get("tl.pipeline_mvb_parity_expr", 0) == 0 + assert calls.get("tir.ptx_wait_group", 0) == 0 + assert "tl_pipelined_num_stages" not in func.script() + assert "frag[k, i]" in func.script() + assert "frag[2, i]" not in func.script() + + +def test_inject_software_pipeline_expands_annotated_layout(): + layout = Layout([8, 16], lambda i, j: i * 16 + j) + + @T.prim_func + def before(A: T.Tensor((4, 8, 16), T.float16), B: T.Tensor((4, 8, 16), T.float16)): + with T.block("root"): + shared = T.alloc_buffer((8, 16), T.float16, scope="shared.dyn") + T.annotate_layout({shared: layout}) + for k in T.serial( + 4, + annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}, + ): + with T.block("load"): + T.reads(A[k, 0:8, 0:16]) + T.writes(shared[0:8, 0:16]) + for i in T.serial(8): + for j in T.serial(16): + shared[i, j] = A[k, i, j] + with T.block("store"): + T.reads(shared[0:8, 0:16]) + T.writes(B[k, 0:8, 0:16]) + for i in T.serial(8): + for j in T.serial(16): + B[k, i, j] = shared[i, j] + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = tl.transform.InjectSoftwarePipeline()(mod) + + block = _find_block_with_layout_map(mod["main"]) + shared = next(buf for buf in block.alloc_buffers if buf.scope() == "shared.dyn") + layout_map = block.annotations["layout_map"] + + assert [int(dim) for dim in shared.shape] == [2, 8, 16] + assert list(layout_map[shared.data].get_input_shape()) == [2, 8, 16] + assert layout_map[shared.data].is_equal(layout.expand([2])) + + if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_lower_ptx_async_copy.py b/testing/python/transform/test_tilelang_transform_lower_ptx_async_copy.py index a461d3c204..71b4c1e74a 100644 --- a/testing/python/transform/test_tilelang_transform_lower_ptx_async_copy.py +++ b/testing/python/transform/test_tilelang_transform_lower_ptx_async_copy.py @@ -56,6 +56,32 @@ def before( assert calls.get("tir.ptx_wait_group", 0) > 0 +def test_lower_ptx_async_copy_respects_explicit_async_scope(): + """`async_scope` marks explicit async semantics, so implicit sync should not be added.""" + + @T.prim_func + def before( + A: T.Tensor((16,), T.float32), + B: T.Tensor((16,), T.float32), + ): + S = T.alloc_buffer((16,), dtype=T.float32, scope="shared") + with T.attr(0, "async_scope", 1): + for i in T.Parallel(16): + S[i] = A[i] + B[0] = S[0] + + target = tvm.target.Target("cuda -arch=sm_80") + func = before.with_attr("global_symbol", "main").with_attr("target", target) + mod = tvm.IRModule.from_expr(func) + + mod = tl.transform.LowerPTXAsyncCopy()(mod) + calls = _count_calls(mod["main"]) + + assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tir.ptx_commit_group", 0) == 0 + assert calls.get("tir.ptx_wait_group", 0) == 0 + + def test_lower_ptx_async_copy_supports_multi_dim_indices(): """LowerPTXAsyncCopy should handle N-D buffer indices (pre-FlattenBuffer).""" diff --git a/testing/python/transform/test_tilelang_transform_lower_shared_barrier.py b/testing/python/transform/test_tilelang_transform_lower_shared_barrier.py index 7b1b2648fc..ae0040aa8d 100644 --- a/testing/python/transform/test_tilelang_transform_lower_shared_barrier.py +++ b/testing/python/transform/test_tilelang_transform_lower_shared_barrier.py @@ -122,10 +122,7 @@ def func(): def test_plan_update_keeps_barrier_init_with_tcgen05_no_tma(): """Regression for tcgen05 no-TMA kernels after pass reordering.""" - pass_configs = { - tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } + pass_configs = {tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @T.prim_func def func( diff --git a/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/testing/python/transform/test_tilelang_transform_lower_tile_op.py new file mode 100644 index 0000000000..ae170a2663 --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_lower_tile_op.py @@ -0,0 +1,106 @@ +"""Tests for TileLang `LowerTileOp` copy annotations affecting cp.async sync.""" + +import tilelang as tl +import tilelang.language as T +from tilelang import tvm +from tvm.tir.stmt_functor import post_order_visit + + +def _count_calls(func: tvm.tir.PrimFunc): + counts = {} + + def _visit(node): + if isinstance(node, tvm.tir.Call) and isinstance(node.op, tvm.ir.Op): + name = str(node.op.name) + counts[name] = counts.get(name, 0) + 1 + + post_order_visit(func.body, _visit) + return counts + + +def test_lower_tile_op_respects_copy_annotation_for_pipeline_managed_cp_async(): + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def before( + A: T.Tensor((16,), T.float32), + B: T.Tensor((16,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 16) + S = T.alloc_buffer((16,), dtype=T.float32, scope="shared") + T.copy( + A[0:16], + S, + annotations={"no_implicit_async_commit_wait": T.int32(1)}, + ) + B[tx] = S[tx] + + mod = tvm.IRModule.from_expr(before) + with target: + mod = tl.transform.LowerTileOp()(mod) + calls = _count_calls(mod["main"]) + + assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tir.ptx_commit_group", 0) == 0 + assert calls.get("tir.ptx_wait_group", 0) == 0 + + +def test_lower_tile_op_respects_copy_annotation_for_explicit_async_copy(): + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def before( + A: T.Tensor((16,), T.float32), + B: T.Tensor((16,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 16) + S = T.alloc_buffer((16,), dtype=T.float32, scope="shared") + T.async_copy( + A[0:16], + S, + annotations={"no_implicit_async_commit_wait": T.int32(1)}, + ) + B[tx] = S[tx] + + mod = tvm.IRModule.from_expr(before) + with target: + mod = tl.transform.LowerTileOp()(mod) + calls = _count_calls(mod["main"]) + + assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tir.ptx_commit_group", 0) == 0 + assert calls.get("tir.ptx_wait_group", 0) == 0 + + +def test_lower_tile_op_respects_parallel_loop_async_annotation_without_pipeline_context(): + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def before( + A: T.Tensor((16,), T.float32), + B: T.Tensor((16,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 16) + S = T.alloc_buffer((16,), dtype=T.float32, scope="shared") + for i in T.parallel( + 16, + annotations={"parallel_async_without_async_commit_wait": T.bool(True)}, + ): + S[i] = A[i] + B[tx] = S[tx] + + mod = tvm.IRModule.from_expr(before) + with target: + mod = tl.transform.LayoutInference()(mod) + mod = tl.transform.LowerTileOp()(mod) + calls = _count_calls(mod["main"]) + + assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tir.ptx_commit_group", 0) == 0 + assert calls.get("tir.ptx_wait_group", 0) == 0 diff --git a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py b/testing/python/transform/test_tilelang_transform_multi_version_buffer.py deleted file mode 100644 index e85fd8db8d..0000000000 --- a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py +++ /dev/null @@ -1,144 +0,0 @@ -from tilelang import tvm as tvm -import tilelang as tl -from tilelang.utils.target import determine_target -import tilelang.language as T -import tilelang.testing -from tvm import tir - -auto_target = tvm.target.Target(determine_target("auto")) - - -def _check(original, transformed): - func = original - mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.MultiVersionBuffer()(mod) - mod = tir.transform.LowerOpaqueBlock()(mod) - transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) - transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) - transformed = tir.transform.LowerOpaqueBlock()(transformed) - - tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True) - - -M = 512 -N = 512 -K = 512 -dtype = T.float16 -block_M = 64 -block_N = 64 -block_K = 32 - - -def test_multi_version_buffer(): - @T.prim_func - def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): - bx = T.launch_thread("blockIdx.x", 8) - by = T.launch_thread("blockIdx.y", 8) - v = T.launch_thread("threadIdx.x", 128) - with T.block(""): - T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) - T.writes() - A_shared = T.alloc_buffer((1, 8, 256), T.float16, scope="shared.dyn") - B_shared = T.alloc_buffer((1, 4, 512), T.float16, scope="shared.dyn") - C_local = T.alloc_buffer((32,), scope="local") - for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): - for vec in T.vectorized(2): - C_local[i * 2 + vec] = T.float32(0) - for k in T.serial(16, annotations={"num_stages": T.int32(3)}): - if v == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), - 0, - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 2), - k * 32, - by * 64, - ) - if v == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), - 0, - T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 2), - bx * 64, - k * 32, - ) - T.call_extern( - "handle", - "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), - ) - - @T.prim_func - def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): - bx = T.launch_thread("blockIdx.x", 8) - by = T.launch_thread("blockIdx.y", 8) - v = T.launch_thread("threadIdx.x", 128) - with T.block(""): - T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) - T.writes() - A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn") - B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") - C_local = T.alloc_buffer((32,), scope="local") - for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): - for vec in T.vectorized(2): - C_local[i * 2 + vec] = T.float32(0) - for k in T.serial(16, annotations={"num_stages": T.int32(3)}): - if v == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), - 0, - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, - by * 64, - ) - if v == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), - 0, - T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2), - bx * 64, - k * 32, - ) - T.call_extern( - "handle", - "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), - ) - - _check(before, after) - - -def test_multi_version_buffer_with_let(): - @T.prim_func - def before(scales: T.Tensor((4,), T.float32)): - with T.block("root"): - shared = T.alloc_buffer((8,), T.float32, scope="shared.dyn") - accum = T.alloc_buffer((8,), T.float32, scope="local") - for k in T.serial(4, annotations={"num_stages": T.int32(2)}): - value = scales[k] - for i in T.serial(8): - shared[i] = value - for i in T.serial(8): - accum[i] = accum[i] + shared[i] - - @T.prim_func - def after(scales: T.Tensor((4,), T.float32)): - with T.block("root"): - shared = T.alloc_buffer((2, 8), T.float32, scope="shared.dyn") - accum = T.alloc_buffer((8,), T.float32, scope="local") - for k in T.serial(4, annotations={"num_stages": T.int32(2)}): - value = scales[k] - for i in T.serial(8): - shared[k % 2, i] = value - for i in T.serial(8): - accum[i] = accum[i] + shared[k % 2, i] - - _check(before, after) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_optimize_cp_async_sync.py b/testing/python/transform/test_tilelang_transform_optimize_cp_async_sync.py deleted file mode 100644 index d6ff40ca00..0000000000 --- a/testing/python/transform/test_tilelang_transform_optimize_cp_async_sync.py +++ /dev/null @@ -1,554 +0,0 @@ -from tilelang import tvm as tvm -import tilelang as tl -import tilelang.language as T -import tilelang.testing -from tvm.tir.stmt_functor import post_order_visit - - -def _count_calls(func): - call_count = {} - - def _visit(node): - if isinstance(node, tvm.tir.Call) and isinstance(node.op, tvm.ir.Op): - key = str(node.op.name) - call_count[key] = call_count.get(key, 0) + 1 - - post_order_visit(func.body, _visit) - return call_count - - -def _collect_wait_args(func): - wait_args = [] - - def _visit(node): - if ( - isinstance(node, tvm.tir.Call) - and isinstance(node.op, tvm.ir.Op) - and str(node.op.name) == "tir.ptx_wait_group" - and len(node.args) == 1 - ): - arg = node.args[0] - if isinstance(arg, tvm.tir.IntImm): - wait_args.append(int(arg.value)) - - post_order_visit(func.body, _visit) - return wait_args - - -def _run(mod): - mod = tl.transform.LowerOpaqueBlock()(mod) - mod = tl.transform.OptimizeCPAsyncSync()(mod) - mod = tl.transform.Simplify()(mod) - mod = tl.transform.OptimizeCPAsyncSync()(mod) - mod = tl.transform.Simplify()(mod) - return mod - - -def _find_pipelined_loop(func): - loops = [] - - def _visit(node): - if isinstance(node, tvm.tir.For) and "tl_pipelined_num_stages" in node.annotations: - loops.append(node) - - post_order_visit(func.body, _visit) - assert loops, "Expected at least one loop annotated with tl_pipelined_num_stages" - return loops[0] - - -def _count_commit_and_wait(stmt): - commit = 0 - waits = [] - - def _visit(node): - nonlocal commit, waits - if isinstance(node, tvm.tir.Call) and isinstance(node.op, tvm.ir.Op): - if node.op.name == "tir.ptx_commit_group": - commit += 1 - elif node.op.name == "tir.ptx_wait_group": - if node.args and isinstance(node.args[0], tvm.tir.IntImm): - waits.append(int(node.args[0])) - else: - waits.append(None) - - post_order_visit(stmt, _visit) - return commit, waits - - -def test_optimize_cp_async_sync_removes_redundant_commit(): - @T.prim_func - def before(A: T.Tensor((16,), T.uint8), B: T.Tensor((16,), T.uint8)): - S = T.alloc_buffer((16,), dtype=T.uint8, scope="shared") - for i in T.serial(0, 4): - T.ptx_cp_async( - T.access_ptr(S[i * 4], "w", 4), - T.access_ptr(A[i * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_commit_group() - T.ptx_wait_group(0) - B[i * 4] = S[i * 4] - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_commit_group", 0) == 1 - - -def test_optimize_cp_async_sync_removes_weaker_wait(): - @T.prim_func - def before(A: T.Tensor((16,), T.uint8), B: T.Tensor((16,), T.uint8)): - S = T.alloc_buffer((16,), dtype=T.uint8, scope="shared") - for i in T.serial(0, 4): - T.ptx_cp_async( - T.access_ptr(S[i * 4], "w", 4), - T.access_ptr(A[i * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_wait_group(0) - T.ptx_wait_group(1) - B[i * 4] = S[i * 4] - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_wait_group", 0) == 1 - - -def test_optimize_cp_async_sync_keeps_stricter_wait(): - @T.prim_func - def before(A: T.Tensor((16,), T.uint8), B: T.Tensor((16,), T.uint8)): - S = T.alloc_buffer((16,), dtype=T.uint8, scope="shared") - for i in T.serial(0, 4): - T.ptx_cp_async( - T.access_ptr(S[i * 4], "w", 4), - T.access_ptr(A[i * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_wait_group(1) - T.ptx_wait_group(0) - B[i * 4] = S[i * 4] - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_wait_group", 0) == 2 - - -def test_optimize_cp_async_sync_relaxes_loop_wait_with_prefetch(): - @T.prim_func - def before(A: T.Tensor((32,), T.uint8), B: T.Tensor((32,), T.uint8)): - S = T.alloc_buffer((32,), dtype=T.uint8, scope="shared") - # Prologue prefetch. - T.ptx_cp_async(T.access_ptr(S[0], "w", 4), T.access_ptr(A[0], "r", 4), 4) - T.ptx_commit_group() - for i in T.serial(0, 4): - T.ptx_cp_async( - T.access_ptr(S[(i + 1) * 4], "w", 4), - T.access_ptr(A[(i + 1) * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_wait_group(0) - B[i * 4] = S[i * 4] - T.ptx_wait_group(0) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - wait_args = _collect_wait_args(mod["main"]) - assert 1 in wait_args, f"Expected a relaxed wait_group(1), got wait args {wait_args}" - - -def test_optimize_cp_async_sync_merge_commit_groups_and_relax_wait(): - # Pattern inside a pipelined loop: - # cp_async(A); commit; cp_async(B); commit; wait_group(0) - # After OptimizeCPAsyncSync: - # cp_async(A); cp_async(B); commit; wait_group(1) (for num_stages=2) - @T.prim_func - def before(A: T.Tensor((16,), T.uint8), B: T.Tensor((16,), T.uint8)): - SA = T.alloc_buffer((16,), dtype=T.uint8, scope="shared") - SB = T.alloc_buffer((16,), dtype=T.uint8, scope="shared") - - for ko in T.serial(4, annotations={"tl_pipelined_num_stages": T.int32(2)}): - with T.block("copyA"): - T.reads(A[ko * 4 : ko * 4 + 4]) - T.writes(SA[ko * 4 : ko * 4 + 4]) - T.ptx_cp_async( - T.access_ptr(SA[ko * 4], "w", 4), - T.access_ptr(A[ko * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - with T.block("copyB"): - T.reads(A[ko * 4 : ko * 4 + 4]) - T.writes(SB[ko * 4 : ko * 4 + 4]) - T.ptx_cp_async( - T.access_ptr(SB[ko * 4], "w", 4), - T.access_ptr(A[ko * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_wait_group(0) - - # Consumer placeholder. - with T.block("consume"): - T.reads(SA[ko * 4], SB[ko * 4]) - T.writes(B[ko * 4]) - B[ko * 4] = SA[ko * 4] + SB[ko * 4] - - # Epilogue drain (typical pipeline lowering shape). - T.ptx_wait_group(0) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - - func = mod["main"] - loop = _find_pipelined_loop(func) - - loop_commit, loop_waits = _count_commit_and_wait(loop.body) - assert loop_commit == 1, f"Expected 1 commit_group in loop after merge, got {loop_commit}" - assert 1 in loop_waits, f"Expected wait_group(1) in loop after relaxation, got waits={loop_waits}" - assert 0 not in loop_waits, f"Expected wait_group(0) inside loop to be relaxed, got waits={loop_waits}" - - func_commit, func_waits = _count_commit_and_wait(func.body) - assert 0 in func_waits, "Expected at least one epilogue wait_group(0) to remain in the function" - - -def test_optimize_cp_async_sync_does_not_relax_wait_when_prefetch_is_conditional(): - # If cp.async prefetch + commit is guarded by a runtime predicate, the - # number of committed groups before a wait_group(0) is not guaranteed at - # runtime. Relaxing to wait_group(N>0) can become a no-op and break - # correctness (e.g. blocksparse kernels). - @T.prim_func - def before(A: T.Tensor((64,), T.uint8), B: T.Tensor((64,), T.uint8)): - SA = T.alloc_buffer((64,), dtype=T.uint8, scope="shared") - SB = T.alloc_buffer((64,), dtype=T.uint8, scope="shared") - - for k in T.serial(0, 4, annotations={"tl_pipelined_num_stages": T.int32(2)}): - if k < 3: - T.ptx_cp_async( - T.access_ptr(SA[(k + 1) * 4], "w", 4), - T.access_ptr(A[(k + 1) * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_cp_async( - T.access_ptr(SB[(k + 1) * 4], "w", 4), - T.access_ptr(A[(k + 1) * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - - # Consumer wait for current stage. - T.ptx_wait_group(0) - B[k * 4] = SA[k * 4] + SB[k * 4] - - # Epilogue drain. - T.ptx_wait_group(0) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - - func = mod["main"] - loop = _find_pipelined_loop(func) - _, loop_waits = _count_commit_and_wait(loop.body) - assert 0 in loop_waits, f"Expected wait_group(0) to remain in loop, got waits={loop_waits}" - assert 2 not in loop_waits, f"Did not expect wait_group(2) under conditional prefetch, got waits={loop_waits}" - - -def test_optimize_cp_async_sync_relaxes_loop_head_wait_with_non_async_prefix(): - # Regression case: - # The first statement in loop body is non-async, and the first async sync - # point is wait_group(0). This wait should still be relaxable. - @T.prim_func - def before(A: T.Tensor((64,), T.uint8), B: T.Tensor((64,), T.uint8)): - S = T.alloc_buffer((64,), dtype=T.uint8, scope="shared") - tmp = T.alloc_buffer((1,), dtype=T.uint8, scope="local") - - # Prologue: ensure there are committed groups before the loop. - T.ptx_cp_async(T.access_ptr(S[0], "w", 4), T.access_ptr(A[0], "r", 4), 4) - T.ptx_commit_group() - T.ptx_cp_async(T.access_ptr(S[4], "w", 4), T.access_ptr(A[4], "r", 4), 4) - T.ptx_commit_group() - - for k in T.serial(0, 4, annotations={"tl_pipelined_num_stages": T.int32(2)}): - # Non-async prefix before the first wait in this loop iteration. - tmp[0] = A[k * 4] - - T.ptx_wait_group(0) - - # Prefetch for a later tile. - T.ptx_cp_async( - T.access_ptr(S[(k + 2) * 4], "w", 4), - T.access_ptr(A[(k + 2) * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - - B[k * 4] = S[k * 4] + tmp[0] - - # Epilogue drain. - T.ptx_wait_group(0) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - - func = mod["main"] - loop = _find_pipelined_loop(func) - _, loop_waits = _count_commit_and_wait(loop.body) - assert 1 in loop_waits, f"Expected wait_group(1) in loop, got waits={loop_waits}" - assert 0 not in loop_waits, f"Expected loop wait_group(0) to be relaxed, got waits={loop_waits}" - - -def test_optimize_cp_async_sync_relaxes_multiple_waits_in_loop(): - # Two consumer waits in one pipelined loop should both be analyzed. - @T.prim_func - def before(A: T.Tensor((64,), T.uint8), B: T.Tensor((64,), T.uint8)): - SA = T.alloc_buffer((64,), dtype=T.uint8, scope="shared") - SB = T.alloc_buffer((64,), dtype=T.uint8, scope="shared") - - # Prologue: seed two committed groups before the loop. - T.ptx_cp_async(T.access_ptr(SA[0], "w", 4), T.access_ptr(A[0], "r", 4), 4) - T.ptx_commit_group() - T.ptx_cp_async(T.access_ptr(SB[0], "w", 4), T.access_ptr(A[32], "r", 4), 4) - T.ptx_commit_group() - - for k in T.serial(0, 4, annotations={"tl_pipelined_num_stages": T.int32(2)}): - # Wait for SA consumer. - T.ptx_wait_group(0) - B[k * 8] = SA[k * 4] - T.ptx_cp_async( - T.access_ptr(SA[(k + 2) * 4], "w", 4), - T.access_ptr(A[(k + 2) * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - - # Wait for SB consumer. - T.ptx_wait_group(0) - B[k * 8 + 1] = SB[k * 4] - T.ptx_cp_async( - T.access_ptr(SB[(k + 2) * 4], "w", 4), - T.access_ptr(A[32 + (k + 2) * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - - # Epilogue drain. - T.ptx_wait_group(0) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - - func = mod["main"] - loop = _find_pipelined_loop(func) - _, loop_waits = _count_commit_and_wait(loop.body) - assert loop_waits.count(1) >= 2, f"Expected two relaxed waits in loop, got waits={loop_waits}" - assert 0 not in loop_waits, f"Expected all loop wait_group(0) to be relaxed, got waits={loop_waits}" - - -def test_optimize_cp_async_sync_relaxes_unrolled_epilogue_wait_but_keeps_last_drain(): - @T.prim_func - def before(A: T.Tensor((64,), T.uint8), B: T.Tensor((64,), T.uint8)): - S = T.alloc_buffer((64,), dtype=T.uint8, scope="shared") - - # Steady-state pipelined loop. - for k in T.serial(0, 4, annotations={"tl_pipelined_num_stages": T.int32(2)}): - T.ptx_cp_async( - T.access_ptr(S[(k + 1) * 4], "w", 4), - T.access_ptr(A[(k + 1) * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_wait_group(1) - B[k * 4] = S[k * 4] - - # Epilogue consumer loop after software-pipeline expansion. - for k in T.unroll(2, annotations={"tl_pipelined_num_stages": T.int32(2)}): - T.ptx_wait_group(0) - B[16 + k] = S[16 + k] - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - - wait_args = _collect_wait_args(mod["main"]) - assert wait_args.count(1) >= 1, f"Expected a relaxed epilogue wait_group(1), got {wait_args}" - assert 0 in wait_args, f"Expected a final drain wait_group(0), got {wait_args}" - - -def test_optimize_cp_async_sync_relaxes_loop_head_wait_with_prefetch(): - @T.prim_func - def before(A: T.Tensor((32,), T.uint8), B: T.Tensor((32,), T.uint8)): - S = T.alloc_buffer((32,), dtype=T.uint8, scope="shared") - # Prologue prefetch: keep two committed groups in flight. - T.ptx_cp_async(T.access_ptr(S[0], "w", 4), T.access_ptr(A[0], "r", 4), 4) - T.ptx_commit_group() - T.ptx_cp_async(T.access_ptr(S[4], "w", 4), T.access_ptr(A[4], "r", 4), 4) - T.ptx_commit_group() - for i in T.serial(0, 4): - # Leading wait inserted by pipelining. - T.ptx_wait_group(0) - B[i * 4] = S[i * 4] - # Prefetch for i+2. - T.ptx_cp_async( - T.access_ptr(S[(i + 2) * 4], "w", 4), - T.access_ptr(A[(i + 2) * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_wait_group(0) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - wait_args = _collect_wait_args(mod["main"]) - assert 1 in wait_args, f"Expected a relaxed wait_group(1), got wait args {wait_args}" - - -def test_optimize_cp_async_sync_does_not_relax_loop_head_wait_without_prefetch(): - @T.prim_func - def before(A: T.Tensor((32,), T.uint8), B: T.Tensor((32,), T.uint8)): - S = T.alloc_buffer((32,), dtype=T.uint8, scope="shared") - # Only one committed group before the loop; relaxing would be unsafe. - T.ptx_cp_async(T.access_ptr(S[0], "w", 4), T.access_ptr(A[0], "r", 4), 4) - T.ptx_commit_group() - for i in T.serial(0, 4): - T.ptx_wait_group(0) - B[i * 4] = S[i * 4] - T.ptx_cp_async( - T.access_ptr(S[(i + 1) * 4], "w", 4), - T.access_ptr(A[(i + 1) * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_wait_group(0) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - wait_args = _collect_wait_args(mod["main"]) - assert 1 not in wait_args, f"Did not expect relaxed wait_group(1), got wait args {wait_args}" - - -def test_optimize_cp_async_sync_splits_epilogue_wait_between_two_consumer_phases(): - def _is_wait_stmt(stmt, wait_n: int) -> bool: - if not isinstance(stmt, tvm.tir.Evaluate): - return False - call = stmt.value - if not (isinstance(call, tvm.tir.Call) and isinstance(call.op, tvm.ir.Op)): - return False - if str(call.op.name) != "tir.ptx_wait_group" or len(call.args) != 1: - return False - arg = call.args[0] - return isinstance(arg, tvm.tir.IntImm) and int(arg.value) == wait_n - - def _is_shared_storage_sync(stmt) -> bool: - if not isinstance(stmt, tvm.tir.Evaluate): - return False - call = stmt.value - if not (isinstance(call, tvm.tir.Call) and isinstance(call.op, tvm.ir.Op)): - return False - if str(call.op.name) != "tir.tvm_storage_sync" or len(call.args) != 1: - return False - arg = call.args[0] - return isinstance(arg, tvm.tir.StringImm) and arg.value == "shared" - - @T.prim_func - def before(A: T.Tensor((32,), T.uint8), B: T.Tensor((32,), T.uint8)): - S = T.alloc_buffer((32,), dtype=T.uint8, scope="shared") - for i in T.serial(0, 2): - T.ptx_cp_async(T.access_ptr(S[i * 4], "w", 4), T.access_ptr(A[i * 4], "r", 4), 4) - T.ptx_commit_group() - # Epilogue drain inserted by pipelining. - T.ptx_wait_group(0) - T.tvm_storage_sync("shared") - # First epilogue consumer phase. - B[0] = S[0] - # Barrier between consumer phases. - T.tvm_storage_sync("shared") - # Second epilogue consumer phase. - B[1] = S[4] - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - - wait_args = _collect_wait_args(mod["main"]) - assert 1 in wait_args, f"Expected wait_group(1) after splitting epilogue, got {wait_args}" - assert 0 in wait_args, f"Expected an inserted wait_group(0) for the final drain, got {wait_args}" - - def _unwrap_to_seq(stmt): - while True: - if isinstance(stmt, tvm.tir.SeqStmt): - return stmt - if isinstance(stmt, tvm.tir.Allocate): - stmt = stmt.body - continue - if isinstance(stmt, tvm.tir.AllocateConst): - stmt = stmt.body - continue - if isinstance(stmt, tvm.tir.DeclBuffer): - stmt = stmt.body - continue - if isinstance(stmt, tvm.tir.LetStmt): - stmt = stmt.body - continue - if isinstance(stmt, tvm.tir.AttrStmt): - stmt = stmt.body - continue - if isinstance(stmt, tvm.tir.BlockRealize): - stmt = stmt.block.body - continue - if isinstance(stmt, tvm.tir.Block): - stmt = stmt.body - continue - return None - - top_seq = _unwrap_to_seq(mod["main"].body) - assert top_seq is not None, f"Expected a SeqStmt after unwrapping, got:\n{mod['main']}" - seq = list(top_seq.seq) - - # The original post-loop wait_group(0) should be relaxed to wait_group(1). - wait1_idx = next((i for i, s in enumerate(seq) if _is_wait_stmt(s, 1)), None) - assert wait1_idx is not None, f"Expected a top-level wait_group(1), got:\n{mod['main']}" - assert wait1_idx + 1 < len(seq) and _is_shared_storage_sync(seq[wait1_idx + 1]), ( - "Expected tvm_storage_sync('shared') immediately after relaxed wait_group(1)" - ) - - store_indices = [i for i, s in enumerate(seq) if isinstance(s, tvm.tir.BufferStore)] - store_indices = [i for i in store_indices if i > wait1_idx] - assert len(store_indices) >= 2, f"Expected two global BufferStore statements, got indices {store_indices}" - first_store, second_store = store_indices[0], store_indices[1] - - split_sync_idx = next( - (i for i in range(first_store + 1, second_store) if _is_shared_storage_sync(seq[i])), - None, - ) - assert split_sync_idx is not None, "Expected a shared barrier between the two global stores" - assert split_sync_idx - 1 >= 0 and _is_wait_stmt(seq[split_sync_idx - 1], 0), ( - "Expected an inserted wait_group(0) immediately before the barrier between epilogue blocks" - ) - - -def test_optimize_cp_async_sync_does_not_relax_wait_without_prefetch(): - @T.prim_func - def before(A: T.Tensor((16,), T.uint8), B: T.Tensor((16,), T.uint8)): - S = T.alloc_buffer((16,), dtype=T.uint8, scope="shared") - for i in T.serial(0, 4): - T.ptx_cp_async( - T.access_ptr(S[i * 4], "w", 4), - T.access_ptr(A[i * 4], "r", 4), - 4, - ) - T.ptx_commit_group() - T.ptx_wait_group(0) - B[i * 4] = S[i * 4] - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = _run(mod) - wait_args = _collect_wait_args(mod["main"]) - assert 1 not in wait_args, f"Did not expect wait_group(1) without prefetch, got wait args {wait_args}" - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_pipeline_barrier_ownership.py b/testing/python/transform/test_tilelang_transform_pipeline_barrier_ownership.py new file mode 100644 index 0000000000..60d04b6a3e --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_pipeline_barrier_ownership.py @@ -0,0 +1,145 @@ +"""Regression tests for pipeline barrier ownership. + +Plain pipelined T.copy should stay on the synchronous path in non-WS kernels. +Explicit TMA-style producers such as im2col still own pipeline barriers when +their lowering requires them. +""" + +import pytest +import tilelang +import tilelang.language as T +import tilelang.testing + + +def _check_hopper(): + """Return True if running on Hopper (sm_90).""" + try: + import torch + + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + return (props.major, props.minor) == (9, 0) + except Exception: + return False + + +@pytest.mark.skipif(not _check_hopper(), reason="Requires Hopper GPU (sm_90)") +def test_nonws_plain_copy_gemm_num_stages_3_stays_sync(): + """Non-WS pipelined GEMM should not auto-upgrade plain T.copy to TMA.""" + M, N, K = 512, 512, 512 + block_M, block_N, block_K = 128, 128, 32 + + @T.prim_func + def gemm( + A: T.Tensor((M, K), T.float16), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), T.float16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_s = T.alloc_shared((block_M, block_K), T.float16) + B_s = T.alloc_shared((block_K, block_N), T.float16) + C_l = T.alloc_fragment((block_M, block_N), T.float32) + T.clear(C_l) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, ko * block_K], A_s) + T.copy(B[ko * block_K, bx * block_N], B_s) + T.gemm(A_s, B_s, C_l) + T.copy(C_l, C[by * block_M, bx * block_N]) + + kernel = tilelang.compile( + gemm, + out_idx=-1, + execution_backend="tvm_ffi", + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + ) + src = kernel.get_kernel_source() + assert "tl::tma_load" not in src, "Non-WS plain T.copy should stay synchronous" + assert "pipeline_mbar_mem" not in src, "Non-WS plain T.copy should not allocate pipeline TMA barriers" + + +@pytest.mark.skipif(not _check_hopper(), reason="Requires Hopper GPU (sm_90)") +def test_nonws_plain_copy_gemm_num_stages_1_stays_sync(): + """num_stages=1 should also keep non-WS plain T.copy on the sync path.""" + M, N, K = 512, 512, 512 + block_M, block_N, block_K = 128, 128, 32 + + @T.prim_func + def gemm( + A: T.Tensor((M, K), T.float16), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), T.float16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_s = T.alloc_shared((block_M, block_K), T.float16) + B_s = T.alloc_shared((block_K, block_N), T.float16) + C_l = T.alloc_fragment((block_M, block_N), T.float32) + T.clear(C_l) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A[by * block_M, ko * block_K], A_s) + T.copy(B[ko * block_K, bx * block_N], B_s) + T.gemm(A_s, B_s, C_l) + T.copy(C_l, C[by * block_M, bx * block_N]) + + kernel = tilelang.compile( + gemm, + out_idx=-1, + execution_backend="tvm_ffi", + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + ) + src = kernel.get_kernel_source() + assert "tl::tma_load" not in src, "Non-WS plain T.copy should stay synchronous" + assert "pipeline_mbar_mem" not in src, "Non-WS plain T.copy should not allocate pipeline TMA barriers" + + +@pytest.mark.skipif(not _check_hopper(), reason="Requires Hopper GPU (sm_90)") +def test_nonws_im2col_tma_num_stages_3_uses_pipeline_barrier(): + """Non-WS pipelined im2col TMA with num_stages=3 must use pipeline_mbar[3].""" + N, C, H, W, F, K_size = 4, 64, 32, 32, 64, 3 + S, D, P = 1, 1, 1 + block_M, block_N, block_K = 64, 128, 32 + KH, KW = K_size, K_size + OH = (H + 2 * P - D * (K_size - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K_size - 1) - 1) // S + 1 + num_stages = 3 + + @T.prim_func + def conv( + data: T.Tensor((N, H, W, C), T.float16), + weight: T.Tensor((KH, KW, C, F), T.float16), + out: T.Tensor((N, OH, OW, F), T.float16), + ): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=256) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), T.float16) + weight_shared = T.alloc_shared((block_K, block_N), T.float16) + out_local = T.alloc_fragment((block_M, block_N), T.float32) + out_shared = T.alloc_shared((block_M, block_N), T.float16) + kernel_flat = T.Tensor((KH * KW * C, F), T.float16, weight.data) + out_flat = T.Tensor((N * OH * OW, F), T.float16, out.data) + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], weight_shared) + T.gemm(data_shared, weight_shared, out_local) + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + kernel = tilelang.compile( + conv, + out_idx=-1, + execution_backend="tvm_ffi", + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + ) + src = kernel.get_kernel_source() + assert f"pipeline_mbar_mem[{num_stages}]" in src, f"Expected pipeline_mbar_mem[{num_stages}] for non-WS im2col TMA pipeline" + # tma_load_im2col must appear (im2col was lowered through TMA path) + assert "tma_load_im2col" in src, "Expected tma_load_im2col in generated code" + # No fallback internal barriers for im2col + assert "mbarrier_1" not in src, "Should not have fallback mbarrier_1 when im2col uses pipeline barrier" + + +if __name__ == "__main__": + test_nonws_plain_copy_gemm_num_stages_3_stays_sync() + test_nonws_plain_copy_gemm_num_stages_1_stays_sync() + test_nonws_im2col_tma_num_stages_3_uses_pipeline_barrier() + print("All pipeline barrier ownership tests passed!") diff --git a/testing/python/transform/test_tilelang_transform_pipeline_planning.py b/testing/python/transform/test_tilelang_transform_pipeline_planning.py index 3220058a59..9ada565fba 100644 --- a/testing/python/transform/test_tilelang_transform_pipeline_planning.py +++ b/testing/python/transform/test_tilelang_transform_pipeline_planning.py @@ -7,6 +7,7 @@ from tvm.tir.stmt_functor import post_order_visit auto_target = tvm.target.Target(determine_target("auto")) +sm80_target = tvm.target.Target("cuda -arch=sm_80") def _check(original, transformed): @@ -31,6 +32,13 @@ def _visit(node): return annos +def _run_pipeline_planning(func, target=auto_target): + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tl.transform.PipelinePlanning()(mod) + return mod + + def test_simple_pipeline(): @T.prim_func def before(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)): @@ -49,8 +57,37 @@ def before(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32 T.copy(C_local, C[by * 128, bx * 128]) + func = before + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.PipelinePlanning()(mod) + mod = tl.transform.Simplify()(mod) + + annos = _collect_pipeline_loop_annotations(mod["main"]) + assert len(annos) == 1 + anno = annos[0] + assert "software_pipeline_stage" in anno + assert "software_pipeline_order" in anno + assert "tl_pipelined_num_stages" in anno + stages = [int(s) for s in anno["software_pipeline_stage"]] + orders = [int(o) for o in anno["software_pipeline_order"]] + assert stages == [0, 0, 2] + assert orders == [0, 1, 2] + assert int(anno["tl_pipelined_num_stages"]) == 3 + # tma_copies annotation depends on target TMA capability + if "software_pipeline_tma_copies" in anno: + tma_copies = [int(t) for t in anno["software_pipeline_tma_copies"]] + # On TMA-capable targets, copies are marked as TMA-eligible + assert tma_copies[2] == 0 # gemm is never TMA + + +def test_pipeline_planning_recognizes_parallel_bufferstore_copy_stages(): @T.prim_func - def after(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)): + def before( + A: T.Tensor((1024, 32), T.float32), + B: T.Tensor((32, 1024), T.float32), + C: T.Tensor((1024, 1024), T.float32), + ): with T.Kernel(8, 8, threads=128) as (bx, by): A_shared = T.alloc_shared((128, 32), T.float32) B_shared = T.alloc_shared((32, 128), T.float32) @@ -58,22 +95,59 @@ def after(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32) T.clear(C_local) - for ko in T.serial( - 32, - annotations={ - "software_pipeline_async_stages": [T.int32(0)], - "software_pipeline_order": [T.int32(0), T.int32(1), T.int32(2)], - "software_pipeline_stage": [T.int32(3), T.int32(3), T.int32(3)], - "tl_pipelined_num_stages": T.int32(3), - }, - ): + for ko in T.Pipelined(32, num_stages=3): + for i, k in T.Parallel(128, 32): + A_shared[i, k] = A[by * 128 + i, ko * 32 + k] + for k, j in T.Parallel(32, 128): + B_shared[k, j] = B[ko * 32 + k, bx * 128 + j] + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * 128, bx * 128]) + + mod = _run_pipeline_planning(before, sm80_target) + annos = _collect_pipeline_loop_annotations(mod["main"]) + assert annos, "Expected at least one loop annotated by PipelinePlanning" + anno = annos[0] + stages = [int(v) for v in anno["software_pipeline_stage"]] + orders = [int(v) for v in anno["software_pipeline_order"]] + async_producers = [int(v) for v in anno["software_pipeline_async_producers"]] + async_groups = [int(v) for v in anno["software_pipeline_async_producer_groups"]] + assert stages == [0, 0, 2] + assert orders == [0, 1, 2] + assert async_producers == [1, 1, 0] + assert async_groups == [0, 0, -1] + + +def test_pipeline_planning_marks_async_producers_per_statement(): + @T.prim_func + def before(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)): + with T.Kernel(8, 8, threads=128) as (bx, by): + A_shared = T.alloc_shared((128, 32), T.float32) + B_shared = T.alloc_shared((32, 128), T.float32) + C_local = T.alloc_fragment((128, 128), T.float32) + + T.clear(C_local) + + for ko in T.Pipelined(32, num_stages=3): T.copy(A[by * 128, ko * 32], A_shared) T.copy(B[ko * 32, bx * 128], B_shared) T.gemm(A_shared, B_shared, C_local) T.copy(C_local, C[by * 128, bx * 128]) - _check(before, after) + mod = _run_pipeline_planning(before, sm80_target) + annos = _collect_pipeline_loop_annotations(mod["main"]) + assert annos, "Expected at least one loop annotated by PipelinePlanning" + anno = annos[0] + assert "software_pipeline_async_producers" in anno + assert "software_pipeline_async_producer_groups" in anno + assert "software_pipeline_async_stages" in anno + async_producers = [int(v) for v in anno["software_pipeline_async_producers"]] + async_groups = [int(v) for v in anno["software_pipeline_async_producer_groups"]] + async_stages = [int(v) for v in anno["software_pipeline_async_stages"]] + assert async_producers == [1, 1, 0] + assert async_groups == [0, 0, -1] + assert async_stages == [0] def test_pipeline_planning_recognizes_explicit_cp_async_copy_stage(): @@ -101,6 +175,68 @@ def before(A: T.Tensor((16,), T.uint8), B: T.Tensor((16,), T.uint8)): assert 0 in stages, "Expected explicit cp.async producer to be recognized as stage-0 copy stage" +def test_pipeline_planning_does_not_mark_fill_as_async_producer_for_predicated_cp_async(): + @T.prim_func + def before(A: T.Tensor((16,), T.uint8), B: T.Tensor((16,), T.uint8)): + S = T.alloc_buffer((16,), dtype=T.uint8, scope="shared") + for i in T.Pipelined(4, num_stages=2): + with T.block(): + for j in T.serial(16): + S[j] = T.uint8(0) + with T.block(): + T.ptx_cp_async( + T.access_ptr(S[i * 4], "w", 4), + T.access_ptr(A[i * 4], "r", 4), + 4, + True, + ) + with T.block(): + T.ptx_commit_group() + with T.block(): + T.ptx_wait_group(0) + with T.block(): + B[i * 4] = S[i * 4] + + mod = _run_pipeline_planning(before, sm80_target) + annos = _collect_pipeline_loop_annotations(mod["main"]) + assert annos, "Expected at least one loop annotated by PipelinePlanning" + anno = annos[0] + assert "software_pipeline_async_producers" in anno + assert "software_pipeline_async_producer_groups" in anno + async_producers = [int(v) for v in anno["software_pipeline_async_producers"]] + async_groups = [int(v) for v in anno["software_pipeline_async_producer_groups"]] + assert async_producers == [0, 1, 0, 0, 0] + assert async_groups == [-1, 0, -1, -1, -1] + + +def test_pipeline_planning_keeps_plain_hopper_pipeline_copies_sync(): + hopper_target = tvm.target.Target("cuda -arch=sm_90a") + + @T.prim_func + def before( + A: T.Tensor((1024, 32), T.float32), + B: T.Tensor((32, 1024), T.float32), + C: T.Tensor((1024, 1024), T.float32), + ): + with T.Kernel(8, 8, threads=128) as (bx, by): + A_shared = T.alloc_shared((128, 32), T.float32) + B_shared = T.alloc_shared((32, 128), T.float32) + C_local = T.alloc_fragment((128, 128), T.float32) + T.clear(C_local) + for k in T.Pipelined(32, num_stages=2): + T.copy(A[by * 128, k * 32], A_shared) + T.copy(B[k * 32, bx * 128], B_shared) + T.gemm(A_shared, B_shared, C_local) + + mod = _run_pipeline_planning(before, hopper_target) + + annos = _collect_pipeline_loop_annotations(mod["main"]) + assert annos, "Expected at least one loop annotated by PipelinePlanning" + anno = annos[0] + tma_copies = [int(v) for v in anno["software_pipeline_tma_copies"]] + assert tma_copies[:2] == [0, 0] + + def test_pipeline_planning_binds_commit_to_cp_async_stage(): @T.prim_func def before(A: T.Tensor((16,), T.uint8), B: T.Tensor((16,), T.uint8)): diff --git a/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py b/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py index 83e71c9622..4bbfd52e36 100644 --- a/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py +++ b/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py @@ -1,492 +1,327 @@ -# ruff: noqa -from tilelang import tvm as tvm -import tilelang as tl +"""Tests for the warp-specialized producer/consumer pass.""" + +import tilelang import tilelang.language as T import tilelang.testing +from tilelang import tvm as tvm +from tilelang.layout import make_swizzled_layout from tilelang.utils.target import determine_target -from tvm import tir - -auto_target = tvm.target.Target(determine_target("auto")) +def matmul_pipelined(M, N, K, block_M, block_K, block_N, num_stages, dtype="float16", threads=128): + """A simple pipelined GEMM using T.copy + T.gemm tile ops.""" -def _collect_calls(stmt, op_name: str): - calls = [] - - def visitor(node): - if isinstance(node, tvm.tir.Call) and hasattr(node, "op") and hasattr(node.op, "name") and node.op.name == op_name: - calls.append(node) + @T.prim_func + def main( + A: T.Buffer((M, K), dtype), + B: T.Buffer((K, N), dtype), + C: T.Buffer((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as ( + bx, + by, + ): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), "float32") + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def matmul_windowed_pipelined( + M, + N, + K, + block_M, + block_K, + block_N, + num_stages, + window_tiles=2, + dtype="float16", + threads=128, +): + """A pipelined GEMM whose K-loop has a dynamic lower bound.""" - tvm.tir.stmt_functor.post_order_visit(stmt, visitor) - return calls + @T.prim_func + def main( + A: T.Buffer((M, K), dtype), + B: T.Buffer((K, N), dtype), + C: T.Buffer((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as ( + bx, + by, + ): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), "float32") + T.clear(C_local) -def _collect_ifs(stmt): - ifs = [] + start = T.max(0, bx - (window_tiles - 1)) + end = T.min(T.ceildiv(K, block_K), bx + 1) + for ko in T.Pipelined(start, end, num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) - def visitor(node): - if isinstance(node, tvm.tir.IfThenElse): - ifs.append(node) + T.copy(C_local, C[by * block_M, bx * block_N]) - tvm.tir.stmt_functor.post_order_visit(stmt, visitor) - return ifs + return main -def _stmt_contains_call(stmt, op_name: str) -> bool: - found = False +def prelude_tma_wait_sink(block=64, iters=2, dtype="float16", threads=128): + """A tiled-WS kernel with pre-loop TMA loads consumed at different points.""" - def visitor(node): - nonlocal found - if isinstance(node, tvm.tir.Call) and hasattr(node, "op") and hasattr(node.op, "name") and node.op.name == op_name: - found = True + @T.prim_func + def main( + Q: T.Buffer((iters * block, block), dtype), + K_in: T.Buffer((block, block), dtype), + V_in: T.Buffer((block, block), dtype), + O: T.Buffer((block, block), dtype), + ): + with T.Kernel(1, threads=threads) as _: + K_shared = T.alloc_shared((block, block), dtype) + V_shared = T.alloc_shared((block, block), dtype) + q = T.alloc_shared((block, block), dtype) + acc0 = T.alloc_fragment((block, block), "float32") + acc1 = T.alloc_fragment((block, block), "float32") + out = T.alloc_fragment((block, block), "float32") + + T.copy(K_in[0, 0], K_shared) + T.copy(V_in[0, 0], V_shared) + T.clear(out) + for ko in T.Pipelined(iters, num_stages=2): + T.copy(Q[ko * block, 0], q) + T.clear(acc0) + T.gemm(K_shared, q, acc0, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.clear(acc1) + T.gemm(V_shared, q, acc1, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block, block): + out[i, j] = acc0[i, j] + acc1[i, j] + + T.copy(out, O[0, 0]) + + return main + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tiled_ws_stage1_dynamic_loop_start(): + """Stage-1 tiled WS should handle dynamic pipeline loop bounds.""" + import torch + + M, N, K = 64, 128, 64 + block_M, block_K, block_N = 64, 32, 64 + func = matmul_windowed_pipelined( + M, + N, + K, + block_M, + block_K, + block_N, + num_stages=1, + window_tiles=2, + ) + target = determine_target() + kernel = tilelang.compile(func, target=target, out_idx=[2]) + source = kernel.get_kernel_source() + + assert "__launch_bounds__(256, 1)" in source + + A = torch.randn(M, K, dtype=torch.float16, device="cuda") + B = torch.randn(K, N, dtype=torch.float16, device="cuda") + C = kernel(A, B) + + ref = torch.zeros(M, N, dtype=torch.float32, device="cuda") + num_k_tiles = (K + block_K - 1) // block_K + num_n_tiles = (N + block_N - 1) // block_N + for bx in range(num_n_tiles): + start = max(0, bx - 1) + end = min(num_k_tiles, bx + 1) + n_slice = slice(bx * block_N, min((bx + 1) * block_N, N)) + acc = torch.zeros(M, n_slice.stop - n_slice.start, dtype=torch.float32, device="cuda") + for ko in range(start, end): + k_slice = slice(ko * block_K, min((ko + 1) * block_K, K)) + acc += A[:, k_slice].float() @ B[k_slice, n_slice].float() + ref[:, n_slice] = acc + + torch.testing.assert_close(C.float(), ref, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tiled_ws_correctness(): + """End-to-end correctness test: pipelined GEMM via tiled WS.""" + import torch + + M, N, K = 256, 256, 256 + func = matmul_pipelined(M, N, K, 64, 32, 64, num_stages=2) + target = determine_target() + kernel = tilelang.compile(func, target=target, out_idx=[2]) + + A = torch.randn(M, K, dtype=torch.float16, device="cuda") + B = torch.randn(K, N, dtype=torch.float16, device="cuda") + C = kernel(A, B) + + ref = A.float() @ B.float() + torch.testing.assert_close(C.float(), ref, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tiled_ws_stage3(): + """Pipelined GEMM with 3 stages.""" + import torch + + M, N, K = 512, 512, 512 + func = matmul_pipelined(M, N, K, 128, 64, 128, num_stages=3) + target = determine_target() + kernel = tilelang.compile(func, target=target, out_idx=[2]) + + A = torch.randn(M, K, dtype=torch.float16, device="cuda") + B = torch.randn(K, N, dtype=torch.float16, device="cuda") + C = kernel(A, B) + + ref = A.float() @ B.float() + torch.testing.assert_close(C.float(), ref, rtol=1e-2, atol=1e-2) + + +def _compile_tvm_ffi(func, pass_configs=None, **kwargs): + tilelang.disable_cache() + try: + return tilelang.compile( + func, + target="cuda", + execution_backend="tvm_ffi", + pass_configs=pass_configs or {}, + **kwargs, + ) + finally: + tilelang.enable_cache() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tiled_ws_swizzled_layout_allows_ws(): + """Swizzled layout on a TMA copy target should NOT block warp specialization. + + Swizzled layouts are valid TMA layouts (TMA supports 32B/64B/128B swizzle). + Layout::Expand correctly handles MVB expansion for swizzled layouts. + """ + import torch + + M, N, K = 256, 256, 256 + block_M, block_K, block_N = 64, 64, 64 - tvm.tir.stmt_functor.post_order_visit(stmt, visitor) - return found + @T.prim_func + def gemm_swizzled( + A: T.Buffer((M, K), "float16"), + B: T.Buffer((K, N), "float16"), + C: T.Buffer((M, N), "float16"), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), "float16") + B_shared = T.alloc_shared((block_K, block_N), "float16") + C_local = T.alloc_fragment((block_M, block_N), "float32") + + T.annotate_layout({A_shared: make_swizzled_layout(A_shared), B_shared: make_swizzled_layout(B_shared)}) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + + pass_configs = {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} + kernel = _compile_tvm_ffi(gemm_swizzled, pass_configs, out_idx=[2]) + src = kernel.get_kernel_source() + + # WS should be applied: launch bounds should include producer warp group + assert "__launch_bounds__(256, 1)" in src + # TMA loads should be present + assert "tl::tma_load" in src + + # Correctness check + A = torch.randn(M, K, dtype=torch.float16, device="cuda") + B = torch.randn(K, N, dtype=torch.float16, device="cuda") + C = kernel(A, B) + ref = A.float() @ B.float() + torch.testing.assert_close(C.float(), ref, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tiled_ws_incompatible_layout_blocks_ws(): + """A non-swizzle, non-linear layout on ALL TMA copy targets should block WS. + + If every copy that could be a TMA producer has an incompatible layout, + there are no real TMA candidates and WS should not apply. + """ + from tilelang.layout import Layout + + M, K = 16, 128 + block_m, block_k = 16, 128 + + # A padded layout: (i, j) -> i * (block_k + 8) + j + # This is neither a swizzle layout nor a linear layout (output shape != input shape). + padded_continuous = block_k + 8 + padded_layout = Layout([block_m, block_k], lambda i, j: i * padded_continuous + j) + @T.prim_func + def copy_with_padded_layout( + x: T.Tensor((M, K), "float16"), + y: T.Tensor((M, K), "float16"), + ): + with T.Kernel(T.ceildiv(M, block_m), threads=128) as pid_m: + x_shared = T.alloc_shared((block_m, block_k), "float16") -def _count_calls_in_stmt(stmt, op_name: str) -> int: - count = 0 + T.annotate_layout({x_shared: padded_layout}) - def visitor(node): - nonlocal count - if isinstance(node, tvm.tir.Call) and hasattr(node, "op") and hasattr(node.op, "name") and node.op.name == op_name: - count += 1 + for _ in T.Pipelined(1, num_stages=1): + T.copy(x[pid_m * block_m, 0], x_shared) + T.copy(x_shared, y[pid_m * block_m, 0]) - tvm.tir.stmt_functor.post_order_visit(stmt, visitor) - return count + pass_configs = {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} + kernel = _compile_tvm_ffi(copy_with_padded_layout, pass_configs, out_idx=[1]) + src = kernel.get_kernel_source() + # WS should NOT be applied: no producer/consumer split + assert "__launch_bounds__(256, 1)" not in src -def _collect_buffer_loads(stmt, scope: str): - """Collect BufferLoad nodes from buffers with the given scope.""" - loads = [] - def visitor(node): - if isinstance(node, tvm.tir.BufferLoad) and node.buffer.scope() == scope: - loads.append(node) +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tiled_ws_sinks_preloop_tma_waits_into_consumer(): + """Pre-loop TMA loads should not emit immediate waits in the common prelude.""" - tvm.tir.stmt_functor.post_order_visit(stmt, visitor) - return loads + pass_configs = {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} + kernel = _compile_tvm_ffi(prelude_tma_wait_sink(), pass_configs, out_idx=[3]) + src = kernel.get_kernel_source() + k_load = src.find("tl::tma_load(K_in_desc") + v_load = src.find("tl::tma_load(V_in_desc") + branch = src.find("if (128 <= ((int)threadIdx.x))") + first_wait = src.find(".wait(0)") -def test_producer_consumer_ws_pure_tma_does_not_reserve_unused_preloop_barrier(): - @T.prim_func - def before(A: T.Tensor((512, 512), T.float16), B: T.Tensor((512, 512), T.float16)): - bx = T.launch_thread("blockIdx.x", 8) - by = T.launch_thread("blockIdx.y", 8) - v = T.launch_thread("threadIdx.x", 128) - - with T.block(""): - T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) - T.writes() - - A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn") - B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") - C_local = T.alloc_buffer((32,), scope="local") - - mbarrier = T.alloc_barrier([128, 128, 128, 128, 128, 128]) - - for k in T.serial(16, annotations={"num_stages": T.int32(3)}): - if v == 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_expect_tx"), - mbarrier[k % 3], - 4096, - ) - if v == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), - mbarrier[k % 3], - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, - by * 64, - ) - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_wait_parity"), - mbarrier[k % 3], - k // 3 % 2, - ) - - if v == 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_expect_tx"), - mbarrier[k % 3 + 3], - 4096, - ) - if v == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), - mbarrier[k % 3 + 3], - T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2), - k * 32, - bx * 64, - ) - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_wait_parity"), - mbarrier[k % 3 + 3], - k // 3 % 2, - ) - - T.call_extern( - "handle", - "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), - ) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.ProducerConsumerWarpSpecialized()(mod) - mod = tir.transform.LowerOpaqueBlock()(mod) - - main_func = mod["main"] - # After the WS pass, the barrier buffer should still be present as shared.barrier - # scope BufferLoads in the output (the pass creates its own barrier buffer). - barrier_loads = _collect_buffer_loads(main_func.body, "shared.barrier") - assert len(barrier_loads) > 0, "Expected shared.barrier BufferLoad nodes in WS output" - - -def test_producer_consumer_ws_preserves_guarded_forward_wait(): - @T.prim_func - def before(A: T.Tensor((512, 512), T.float16)): - bx = T.launch_thread("blockIdx.x", 1) - by = T.launch_thread("blockIdx.y", 1) - tx = T.launch_thread("threadIdx.x", 128) - - with T.block(""): - T.reads(A[0:128, 0:64]) - T.writes() - - A_shared = T.alloc_buffer((2, 1, 8, 256), T.float16, scope="shared.dyn") - C_local = T.alloc_buffer((1,), "float32", scope="local") - - mbarrier = T.alloc_barrier([1, 1]) - - for k in T.serial(4, annotations={"num_stages": T.int32(2)}): - i_s: T.int32 = T.if_then_else(k < 2, 0, -1) - - if i_s >= 0: - T.attr(A_shared.data, "tl.tma_copy_write_buffer", 1) - if tx == 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_expect_tx"), - mbarrier[k % 2], - 4096, - ) - if tx == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), - mbarrier[k % 2], - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 2 * 2048, 2048, 2), - k * 32, - by * 64, - ) - if i_s >= 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_wait_parity"), - mbarrier[k % 2], - k // 2 % 2, - ) - if i_s >= 0: - C_local[0] = C_local[0] + T.float32(1) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.ProducerConsumerWarpSpecialized()(mod) - mod = tir.transform.LowerOpaqueBlock()(mod) - - main_func = mod["main"] - body_text = main_func.script() - assert 'threadIdx.x", 256' in body_text - - guarded_waits = [] - for if_stmt in _collect_ifs(main_func.body): - if _stmt_contains_call(if_stmt.then_case, "tl.mbarrier_wait_parity"): - guarded_waits.append(str(if_stmt.condition)) - - assert guarded_waits - assert any("i_s" in cond for cond in guarded_waits) - - -def test_producer_consumer_ws_preserves_guarded_producer_backpressure_wait(): - @T.prim_func - def before(A: T.Tensor((512, 512), T.float16)): - bx = T.launch_thread("blockIdx.x", 1) - by = T.launch_thread("blockIdx.y", 1) - tx = T.launch_thread("threadIdx.x", 128) - - with T.block(""): - T.reads(A[0:128, 0:64]) - T.writes() - - A_shared = T.alloc_buffer((2, 1, 8, 256), T.float16, scope="shared.dyn") - C_local = T.alloc_buffer((1,), "float32", scope="local") - - mbarrier = T.alloc_barrier([1, 1]) - - for k in T.serial(4, annotations={"num_stages": T.int32(2)}): - i_s: T.int32 = T.if_then_else(k < 2, 0, -1) - - if i_s >= 0: - T.attr(A_shared.data, "tl.tma_copy_write_buffer", 1) - if tx == 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_expect_tx"), - mbarrier[k % 2], - 4096, - ) - if tx == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), - mbarrier[k % 2], - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 2 * 2048, 2048, 2), - k * 32, - by * 64, - ) - if i_s >= 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_wait_parity"), - mbarrier[k % 2], - k // 2 % 2, - ) - if i_s >= 0: - C_local[0] = C_local[0] + T.float32(1) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.ProducerConsumerWarpSpecialized()(mod) - mod = tir.transform.LowerOpaqueBlock()(mod) - - main_func = mod["main"] - body_text = main_func.script() - assert 'threadIdx.x", 256' in body_text - - guarded_wait_count = 0 - guarded_arrive_count = 0 - for if_stmt in _collect_ifs(main_func.body): - if "i_s" not in str(if_stmt.condition): - continue - guarded_wait_count += _count_calls_in_stmt(if_stmt.then_case, "tl.mbarrier_wait_parity") - guarded_arrive_count += _count_calls_in_stmt(if_stmt.then_case, "tir.ptx_arrive_barrier") - - assert guarded_wait_count >= 2 - assert guarded_arrive_count >= 2 - - -def test_producer_consumer_ws_uses_consumer_guard_for_backpressure_protocol(): - @T.prim_func - def before(A: T.Tensor((512, 512), T.float16)): - bx = T.launch_thread("blockIdx.x", 1) - by = T.launch_thread("blockIdx.y", 1) - tx = T.launch_thread("threadIdx.x", 128) - - with T.block(""): - T.reads(A[0:128, 0:64]) - T.writes() - - A_shared = T.alloc_buffer((2, 1, 8, 256), T.float16, scope="shared.dyn") - C_local = T.alloc_buffer((1,), "float32", scope="local") - - mbarrier = T.alloc_barrier([1, 1]) - - for k in T.serial(4, annotations={"num_stages": T.int32(2)}): - i_s: T.int32 = T.if_then_else(k < 2, 0, -1) - - if i_s >= 0: - T.attr(A_shared.data, "tl.tma_copy_write_buffer", 1) - if tx == 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_expect_tx"), - mbarrier[k % 2], - 4096, - ) - if tx == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), - mbarrier[k % 2], - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 2 * 2048, 2048, 2), - k * 32, - by * 64, - ) - if i_s >= 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_wait_parity"), - mbarrier[k % 2], - k // 2 % 2, - ) - - use_block: T.int32 = T.if_then_else(i_s >= 0, 1, 0) - if use_block != 0: - C_local[0] = C_local[0] + T.Cast("float32", A_shared[k % 2, 0, 0, 0]) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.ProducerConsumerWarpSpecialized()(mod) - mod = tir.transform.LowerOpaqueBlock()(mod) - - main_func = mod["main"] - body_text = main_func.script() - assert 'threadIdx.x", 256' in body_text - - guarded_wait_count = 0 - guarded_arrive_count = 0 - for if_stmt in _collect_ifs(main_func.body): - if "use_block" not in str(if_stmt.condition): - continue - guarded_wait_count += _count_calls_in_stmt(if_stmt.then_case, "tl.mbarrier_wait_parity") - guarded_arrive_count += _count_calls_in_stmt(if_stmt.then_case, "tir.ptx_arrive_barrier") - - assert guarded_wait_count >= 1 - assert guarded_arrive_count >= 1 - - -def test_producer_consumer_ws_finds_pipeline_loop_under_if_wrapper(): - @T.prim_func - def before(A: T.Tensor((512, 512), T.float16)): - bx = T.launch_thread("blockIdx.x", 1) - by = T.launch_thread("blockIdx.y", 1) - tx = T.launch_thread("threadIdx.x", 128) - - with T.block(""): - T.reads(A[0:128, 0:64]) - T.writes() - - A_shared = T.alloc_buffer((2, 1, 8, 256), T.float16, scope="shared.dyn") - C_local = T.alloc_buffer((1,), "float32", scope="local") - - mbarrier = T.alloc_barrier([1, 1]) - - if bx < 1: - for k in T.serial(4, annotations={"num_stages": T.int32(2)}): - with T.attr(A_shared.data, "tl.tma_copy_write_buffer", 1): - if tx == 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_expect_tx"), - mbarrier[k % 2], - 4096, - ) - if tx == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), - mbarrier[k % 2], - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 2 * 2048, 2048, 2), - k * 32, - by * 64, - ) - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_wait_parity"), - mbarrier[k % 2], - k // 2 % 2, - ) - C_local[0] = C_local[0] + T.float32(1) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.ProducerConsumerWarpSpecialized()(mod) - mod = tir.transform.LowerOpaqueBlock()(mod) - - main_func = mod["main"] - body_text = main_func.script() - assert 'threadIdx.x", 256' in body_text - assert "num_stages" not in body_text - assert "software_pipeline_stage" not in body_text - assert "software_pipeline_order" not in body_text - - -def test_producer_consumer_ws_moves_preloop_tma_prefix_inside_wrapped_ws_split(): - @T.prim_func - def before(A: T.Tensor((512, 512), T.float16)): - bx = T.launch_thread("blockIdx.x", 1) - by = T.launch_thread("blockIdx.y", 1) - tx = T.launch_thread("threadIdx.x", 128) - - with T.block(""): - T.reads(A[0:128, 0:64]) - T.writes() - - A_shared = T.alloc_buffer((2, 1, 8, 256), T.float16, scope="shared.dyn") - C_local = T.alloc_buffer((1,), "float32", scope="local") - - mbarrier = T.alloc_barrier([1, 1, 1]) - - if bx < 1: - with T.attr(A_shared.data, "tl.tma_copy_write_buffer", 1): - if tx == 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_expect_tx"), - mbarrier[0], - 4096, - ) - if tx == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), - mbarrier[0], - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 2), - 0, - by * 64, - ) - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_wait_parity"), - mbarrier[0], - 0, - ) - - for k in T.serial(4, annotations={"num_stages": T.int32(2)}): - with T.attr(A_shared.data, "tl.tma_copy_write_buffer", 1): - if tx == 0: - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_expect_tx"), - mbarrier[k % 2 + 1], - 4096, - ) - if tx == 0: - T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), - mbarrier[k % 2 + 1], - T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 2 * 2048, 2048, 2), - k * 32, - by * 64, - ) - T.call_intrin( - "handle", - tir.op.Op.get("tl.mbarrier_wait_parity"), - mbarrier[k % 2 + 1], - k // 2 % 2, - ) - C_local[0] = C_local[0] + T.Cast("float32", A_shared[k % 2, 0, 0, 0]) - - mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.MultiVersionBuffer(barrier_only=False)(mod) - mod = tl.transform.ProducerConsumerWarpSpecialized()(mod) - mod = tir.transform.LowerOpaqueBlock()(mod) - - main_func = mod["main"] - body_text = main_func.script() - - scope_idx = body_text.find("kWarpSpecializationScope") - first_tma_idx = body_text.find("T.tma_load(") - first_producer_elect_idx = body_text.find("T.tl_shuffle_elect(128)") - - assert scope_idx != -1 - assert first_tma_idx > scope_idx - assert first_producer_elect_idx > scope_idx + assert min(k_load, v_load, branch, first_wait) >= 0 + assert k_load < v_load < branch < first_wait if __name__ == "__main__": - tilelang.testing.main() + test_tiled_ws_stage1_dynamic_loop_start() + test_tiled_ws_correctness() + test_tiled_ws_stage3() + test_tiled_ws_swizzled_layout_allows_ws() + test_tiled_ws_incompatible_layout_blocks_ws() + test_tiled_ws_sinks_preloop_tma_waits_into_consumer() diff --git a/testing/python/transform/test_tilelang_transform_reuse_local_descriptor_allocations.py b/testing/python/transform/test_tilelang_transform_reuse_local_descriptor_allocations.py new file mode 100644 index 0000000000..940baa88a1 --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_reuse_local_descriptor_allocations.py @@ -0,0 +1,105 @@ +# ruff: noqa +from tilelang import tvm as tvm +import tilelang as tl +from tilelang.utils.target import determine_target +import tilelang.language as T + + +auto_target = tvm.target.Target(determine_target("auto")) + + +def _check(original, transformed): + mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.ReuseLocalDescriptorAllocations()(mod) + + expected = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) + expected = tvm.tir.transform.BindTarget(auto_target)(expected) + + tvm.ir.assert_structural_equal(mod["main"], expected["main"], True) + + +def test_reuse_local_descriptor_allocations(): + @T.prim_func + def before(): + T.func_attr({"tir.noalias": True}) + with T.attr(0, "test.region", 0): + desc_a = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_b = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_a_buf = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") + desc_b_buf = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") + T.initialize_wgmma_descriptor(desc_a_buf[0], T.uint64(0), 1, 1, 64) + T.initialize_wgmma_descriptor(desc_b_buf[0], T.uint64(0), 1, 1, 64) + T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) + with T.attr(0, "test.region", 1): + desc_a_1 = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_b_1 = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_a_buf_1 = T.Buffer((1,), "uint64", data=desc_a_1, scope="local.descriptor.wgmma") + desc_b_buf_1 = T.Buffer((1,), "uint64", data=desc_b_1, scope="local.descriptor.wgmma") + T.initialize_wgmma_descriptor(desc_a_buf_1[0], T.uint64(1), 1, 1, 64) + T.initialize_wgmma_descriptor(desc_b_buf_1[0], T.uint64(1), 1, 1, 64) + T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a_1, desc_b_1)) + + @T.prim_func + def after(): + T.func_attr({"tir.noalias": True}) + desc_a = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_b = T.allocate([1], "uint64", "local.descriptor.wgmma") + with T.attr(0, "test.region", 0): + desc_a_buf = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") + desc_b_buf = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") + T.initialize_wgmma_descriptor(desc_a_buf[0], T.uint64(0), 1, 1, 64) + T.initialize_wgmma_descriptor(desc_b_buf[0], T.uint64(0), 1, 1, 64) + T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) + with T.attr(0, "test.region", 1): + desc_a_buf_1 = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") + desc_b_buf_1 = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") + T.initialize_wgmma_descriptor(desc_a_buf_1[0], T.uint64(1), 1, 1, 64) + T.initialize_wgmma_descriptor(desc_b_buf_1[0], T.uint64(1), 1, 1, 64) + T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) + + _check(before, after) + + +def test_reuse_local_descriptor_allocations_stays_inside_launch_thread(): + @T.prim_func + def before(): + T.func_attr({"tir.noalias": True}) + with T.launch_thread("blockIdx.x", 1): + with T.attr(0, "test.region", 0): + desc_a = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_b = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_a_buf = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") + desc_b_buf = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") + T.initialize_wgmma_descriptor(desc_a_buf[0], T.uint64(0), 1, 1, 64) + T.initialize_wgmma_descriptor(desc_b_buf[0], T.uint64(0), 1, 1, 64) + T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) + with T.attr(0, "test.region", 1): + desc_a_1 = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_b_1 = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_a_buf_1 = T.Buffer((1,), "uint64", data=desc_a_1, scope="local.descriptor.wgmma") + desc_b_buf_1 = T.Buffer((1,), "uint64", data=desc_b_1, scope="local.descriptor.wgmma") + T.initialize_wgmma_descriptor(desc_a_buf_1[0], T.uint64(1), 1, 1, 64) + T.initialize_wgmma_descriptor(desc_b_buf_1[0], T.uint64(1), 1, 1, 64) + T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a_1, desc_b_1)) + + @T.prim_func + def after(): + T.func_attr({"tir.noalias": True}) + with T.launch_thread("blockIdx.x", 1): + desc_a = T.allocate([1], "uint64", "local.descriptor.wgmma") + desc_b = T.allocate([1], "uint64", "local.descriptor.wgmma") + with T.attr(0, "test.region", 0): + desc_a_buf = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") + desc_b_buf = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") + T.initialize_wgmma_descriptor(desc_a_buf[0], T.uint64(0), 1, 1, 64) + T.initialize_wgmma_descriptor(desc_b_buf[0], T.uint64(0), 1, 1, 64) + T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) + with T.attr(0, "test.region", 1): + desc_a_buf_1 = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") + desc_b_buf_1 = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") + T.initialize_wgmma_descriptor(desc_a_buf_1[0], T.uint64(1), 1, 1, 64) + T.initialize_wgmma_descriptor(desc_b_buf_1[0], T.uint64(1), 1, 1, 64) + T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) + + _check(before, after) diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index aa5254b996..6c79b78d47 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -419,7 +419,7 @@ def save_to_disk(self, path: Path, verbose: bool = False): { "out_idx": list(self.func.attrs["tilelang_out_idx"]) if (self.func.attrs and "tilelang_out_idx" in self.func.attrs) - else None, + else None }, f, ), @@ -468,11 +468,14 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult with open(path / FUNCTION_PATH, "rb") as f: func = cloudpickle.load(f) - # load out idx - if verbose: - logger.debug(f"Loading out idx from file: {path / OUT_IDX_PATH}") - with open(path / OUT_IDX_PATH) as f: - out_idx_override = json.load(f)["out_idx"] + # load out idx (optional — older caches may not have this file) + out_idx_override = None + out_idx_file = path / OUT_IDX_PATH + if out_idx_file.exists(): + if verbose: + logger.debug(f"Loading out idx from file: {out_idx_file}") + with open(out_idx_file) as f: + out_idx_override = json.load(f)["out_idx"] # load latency if verbose: diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 6980f0faf7..3153db1d7c 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -175,9 +175,26 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.Simplify()(mod) # Set layouts for reducers mod = tilelang.transform.LayoutReducer()(mod) + # Tile-level warp specialization: runs before layout inference so that + # producer/consumer split happens at the high-level tile-op IR. + # The pass classifies copy ops as TMA/cp.async/sync inline (no prior + # InstructionAnnotation pass needed). Shared buffers are multi-versioned + # internally only for functions where the WS transformation actually + # applies. + if allow_warp_specialized(target=target): + mod = tilelang.transform.ProducerConsumerWarpSpecialized()(mod) # Lower 2SM TCGEN5MMA and related on Blackwell target (must run before # LayoutInference so that the use_2cta annotation is visible to infer_layout) mod = tilelang.transform.LowerBlackwell2SM()(mod) + # Run pipeline planning and software-pipeline rewriting before layout + # inference so inferred layouts see the final pipelined structure directly. + mod = tilelang.transform.PipelinePlanning()(mod) + # print("After pipeline planing") + # print(mod) + mod = tilelang.transform.InjectSoftwarePipeline()(mod) + # print("After InjectSoftwarePipeline") + # print(mod) + mod = tilelang.transform.Simplify()(mod) # Infer memory layouts for fragments and shared memory mod = tilelang.transform.LayoutInference()(mod) # Visualize the layout @@ -212,34 +229,19 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # which may be introduced by the LegalizeSafeMemoryAccess mod = tilelang.transform.IfStmtBinding()(mod) has_tma = module_has_tma(mod) - use_ws = has_tma and allow_warp_specialized(pass_ctx=pass_ctx, target=target) - if has_tma: - # In WS mode, version all buffers (barriers + data) because - # ProducerConsumerWarpSpecialized handles pipeline overlap and - # InjectSoftwarePipeline won't re-version data buffers. - # Without WS, only version barrier buffers for mbarrier parity - # rewriting; data buffer versioning is left to InjectSoftwarePipeline. - mod = tilelang.transform.MultiVersionBuffer(barrier_only=not use_ws)(mod) - if use_ws: - mod = tilelang.transform.ProducerConsumerWarpSpecialized()(mod) - else: - # Non-TMA: MultiVersionBuffer is not used, so buffer allocation - # locations must be planned explicitly. In TMA paths this is - # handled implicitly by MultiVersionBuffer (which runs LCA - # analysis to place versioned buffers). - mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod) + # Pipeline barriers are now created at final expanded size by + # InjectSoftwarePipeline, so no late MVB barrier fixup is needed. + # Buffer allocation placement is handled uniformly for both paths. + mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod) mod = tilelang.transform.LowerSharedBarrier()(mod) - mod = tilelang.transform.PipelinePlanning()(mod) - mod = tilelang.transform.InjectSoftwarePipeline()(mod) if has_tma: mod = tilelang.transform.FuseMBarrierArriveExpectTx()(mod) mod = tilelang.transform.HoistGlobalBufferAllocations()(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod) + mod = tilelang.transform.ReuseLocalDescriptorAllocations()(mod) if is_hopper(target): mod = tilelang.transform.RewriteWgmmaSync()(mod) mod = tilelang.transform.Simplify()(mod) - mod = tilelang.transform.OptimizeCPAsyncSync()(mod) - mod = tilelang.transform.Simplify()(mod) mod = tir.transform.NarrowDataType(32)(mod) mod = tilelang.transform.FlattenBuffer()(mod) # ConfigIndexBitwidth must be applied after FlattenBuffer diff --git a/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/intrinsics/mma_sp_macro_generator.py index 14852c5e3d..480a85601b 100644 --- a/tilelang/intrinsics/mma_sp_macro_generator.py +++ b/tilelang/intrinsics/mma_sp_macro_generator.py @@ -2,14 +2,15 @@ import tilelang.language as T from typing import Literal, Callable -from tvm import DataType -from tvm.tir import PrimExpr, IndexMap, Buffer, Var +from tvm import DataType, tir +from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad +from tvm.ir import Range from tvm.runtime import convert from .utils import ( mma_store_index_map, get_ldmatrix_offset, ) -from tilelang.utils import is_fragment +from tilelang.utils import is_fragment, get_buffer_region_from_load from tilelang.intrinsics.mma_sp_layout import ( shared_16x16_to_mma_sp_layout_sr_a, @@ -60,62 +61,14 @@ class SparseTensorCoreIntrinEmitter: } E_FACTOR_MAP = { # e_kdim = mma_kdim // e_factor - "float": { - "int16": 8, - "uint16": 8, - }, - "float32": { - "int16": 8, - "uint16": 8, - }, - "float16": { - "int8": 8, - "uint8": 8, - "int16": 16, - "uint16": 16, - "int32": 32, - "uint32": 32, - }, - "bfloat16": { - "int8": 8, - "uint8": 8, - "int16": 16, - "uint16": 16, - "int32": 32, - "uint32": 32, - }, - "int8": { - "int8": 8, - "uint8": 8, - "int16": 16, - "uint16": 16, - "int32": 32, - "uint32": 32, - }, - "uint8": { - "int8": 8, - "uint8": 8, - "int16": 16, - "uint16": 16, - "int32": 32, - "uint32": 32, - }, - "float8_e4m3": { - "int8": 8, - "uint8": 8, - "int16": 16, - "uint16": 16, - "int32": 32, - "uint32": 32, - }, - "float8_e5m2": { - "int8": 8, - "uint8": 8, - "int16": 16, - "uint16": 16, - "int32": 32, - "uint32": 32, - }, + "float": {"int16": 8, "uint16": 8}, + "float32": {"int16": 8, "uint16": 8}, + "float16": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, + "bfloat16": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, + "int8": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, + "uint8": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, + "float8_e4m3": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, + "float8_e5m2": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, } E_REPLICATE_FACTOR = { # metadata replicate every 4 consecutive threads @@ -285,7 +238,7 @@ def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = ) return lane_id, warp_n, warp_m - def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0): + def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion | BufferLoad, ki: PrimExpr, rk: PrimExpr = 0): warp_row_tiles = self.warp_row_tiles warp_rows = self.warp_rows warp_k = self.warp_k @@ -312,6 +265,13 @@ def mma_load_layout(i, j): thread_binding = self.get_thread_binding() + A_region = self._legalize_to_buffer_region(A_shared_buf) + A_buf = A_region.buffer + A_base0 = A_region.region[-2].min + A_base1 = A_region.region[-1].min + A_other = [r.min for r in A_region.region[:-2]] + A_stride_last = A_buf.shape[-1] + @T.macro def _warp_ldmatrix_a( A_local_buf, @@ -320,14 +280,18 @@ def _warp_ldmatrix_a( thread_binding, rk=0, ): - stride = A_shared_buf.shape[-1] + stride = A_stride_last tx, _, warp_m = self.extract_thread_binding(thread_binding) trans = self.a_transposed for i in T.serial(warp_rows): # Assign A_shared_buf_elem wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR - A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] + A_shared_buf_elem = ( + A_buf[tuple(A_other) + (A_base0 + wk, A_base1 + wi)] + if a_transposed + else A_buf[tuple(A_other) + (A_base0 + wi, A_base1 + wk)] + ) if ldmatrix_available: T.ptx_ldmatrix( @@ -344,12 +308,14 @@ def _warp_ldmatrix_a( for j in T.serial(local_size_a): mi, mk = mma_load_layout(tx, j) A_local_buf[i * local_size_a + j] = ( - A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk] + A_buf[tuple(A_other) + (A_base0 + wk + mk, A_base1 + wi + mi)] + if a_transposed + else A_buf[tuple(A_other) + (A_base0 + wi + mi, A_base1 + wk + mk)] ) - return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) + return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) - def ldmatrix_e(self, E_local_buf: Buffer, E_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0): + def ldmatrix_e(self, E_local_buf: Buffer, E_shared_buf: Buffer | BufferRegion | BufferLoad, ki: PrimExpr, rk: PrimExpr = 0): warp_row_tiles = self.warp_row_tiles warp_rows = self.warp_rows warp_k = self.warp_k @@ -395,6 +361,12 @@ def mma_load_layout(i, j): thread_binding = self.get_thread_binding() + E_region = self._legalize_to_buffer_region(E_shared_buf) + E_buf = E_region.buffer + E_base0 = E_region.region[-2].min + E_base1 = E_region.region[-1].min + E_other = [r.min for r in E_region.region[:-2]] + @T.macro def _warp_ldmatrix_e( E_local_buf, @@ -409,11 +381,15 @@ def _warp_ldmatrix_e( wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.e_factor for j in T.serial(local_size_e): mi, mk = mma_load_layout(tx, j) - E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk] + E_local_buf[i * local_size_e + j] = ( + E_buf[tuple(E_other) + (E_base0 + wk + mk, E_base1 + wi + mi)] + if trans + else E_buf[tuple(E_other) + (E_base0 + wi + mi, E_base1 + wk + mk)] + ) - return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk) + return _warp_ldmatrix_e(E_local_buf, E_region, ki, thread_binding, rk) - def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0): + def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion | BufferLoad, ki: PrimExpr, rk: PrimExpr = 0): warp_col_tiles = self.warp_col_tiles warp_cols = self.warp_cols warp_k = self.warp_k @@ -440,6 +416,13 @@ def mma_load_layout(i, j): else: raise ValueError(f"Unsupported dtype: {b_dtype}") + B_region = self._legalize_to_buffer_region(B_shared_buf) + B_buf = B_region.buffer + B_base0 = B_region.region[-2].min + B_base1 = B_region.region[-1].min + B_other = [r.min for r in B_region.region[:-2]] + B_stride_last = B_buf.shape[-1] + @T.macro def _warp_ldmatrix_b( B_local_buf, @@ -448,7 +431,7 @@ def _warp_ldmatrix_b( thread_binding, rk=0, ): - stride = B_shared_buf.shape[-1] + stride = B_stride_last tx, warp_n, _ = self.extract_thread_binding(thread_binding) trans = not b_transposed @@ -460,7 +443,11 @@ def _warp_ldmatrix_b( ) if ldmatrix_available: - B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi] + B_shared_buf_elem = ( + B_buf[tuple(B_other) + (B_base0 + wi, B_base1 + wk)] + if b_transposed + else B_buf[tuple(B_other) + (B_base0 + wk, B_base1 + wi)] + ) if replicate_b: T.ptx_ldmatrix( @@ -502,10 +489,30 @@ def _warp_ldmatrix_b( for j in T.serial(local_size_b): mi, mk = mma_load_layout(tx, j) B_local_buf[i * local_size_b + j] = ( - B_shared_buf[wi + mi, wk + mk] if b_transposed else B_shared_buf[wk + mk, wi + mi] + B_buf[tuple(B_other) + (B_base0 + wi + mi, B_base1 + wk + mk)] + if b_transposed + else B_buf[tuple(B_other) + (B_base0 + wk + mk, B_base1 + wi + mi)] ) - return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) + return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk) + + @staticmethod + def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion: + if isinstance(obj, BufferRegion): + return obj + if isinstance(obj, Buffer): + mins = [tir.IntImm("int32", 0) for _ in obj.shape] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] + return BufferRegion(obj, ranges) + if isinstance(obj, BufferLoad): + region = get_buffer_region_from_load(obj) + if region is not None: + return region + mins = [idx for idx in obj.indices] + ones = [tir.IntImm("int32", 1) for _ in obj.indices] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)] + return BufferRegion(obj.buffer, ranges) + raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}") def mma_sp(self, A_local_buf: Buffer, E_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr = 0): warp_rows = self.warp_rows diff --git a/tilelang/jit/adapter/cutedsl/wrapper.py b/tilelang/jit/adapter/cutedsl/wrapper.py index 5521a7fbed..95c8cf0733 100644 --- a/tilelang/jit/adapter/cutedsl/wrapper.py +++ b/tilelang/jit/adapter/cutedsl/wrapper.py @@ -621,8 +621,7 @@ def _generate_cubin_if_needed({cubin_gen_params}): "torch.uint8": cutlass.Uint8, "torch.int16": cutlass.Int16, "torch.uint16": cutlass.Uint16, - "torch.uchar": cutlass.Uint8, - }} + "torch.uchar": cutlass.Uint8}} {cubin_gen_code} diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index d2a594d28f..b63884e307 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -26,6 +26,7 @@ from tilelang.utils.target import determine_target from tilelang.contrib import nvcc as tl_nvcc from tilelang.transform import PassConfigKey +from tilelang.transform.pass_config import normalize_pass_configs import logging import os @@ -99,9 +100,7 @@ def __init__( self.target_host = target_host self.verbose = verbose - if pass_configs is None: - pass_configs = {} - self.pass_configs = pass_configs + self.pass_configs = normalize_pass_configs(pass_configs) self.compile_flags = [compile_flags] if isinstance(compile_flags, str) else compile_flags diff --git a/tilelang/language/dtypes.py b/tilelang/language/dtypes.py index 28a9d7cd4a..74d5aab3f5 100644 --- a/tilelang/language/dtypes.py +++ b/tilelang/language/dtypes.py @@ -30,11 +30,7 @@ def _is_any_dtype(obj: object) -> bool: return isinstance(obj, (ir.Type, str, type, torch.dtype, dtype)) -_PYTHON_DTYPE_TO_STR = { - bool: "bool", - int: "int32", - float: "float32", -} +_PYTHON_DTYPE_TO_STR = {bool: "bool", int: "int32", float: "float32"} _NUMPY_DTYPE_TO_STR = { np.bool_: "bool", diff --git a/tilelang/language/eager/ast.py b/tilelang/language/eager/ast.py index 96930922ca..f3a39eef6a 100644 --- a/tilelang/language/eager/ast.py +++ b/tilelang/language/eager/ast.py @@ -689,7 +689,7 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: # a = 123 # def foo(): # x = foo.__globals__ # OK, globals are maintained by python - # x = {**foo.__globals__, } # Not OK: globals are copied, and the original globals cannot be freed + # x = {**foo.__globals__} # Not OK: globals are copied, and the original globals cannot be freed # def bar(): x # return bar # ``` diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index 812d54638b..c579fd9511 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -971,8 +971,7 @@ def annotate_pass_configs(configs: dict[str, Any]) -> None: @tilelang.jit def kernel(A, B): T.annotate_pass_configs({ - PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + PassConfigKey.TL_ENABLE_FAST_MATH: True}) ... """ builder = Builder.current() diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index c6952afbc0..01a2f741cd 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -64,71 +64,30 @@ def _get_element_size(buffer_or_load_or_region: BufferLikeType) -> int: # Use a stable swizzled layout to ensure consistent memory access patterns. # Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Access) is applied. def make_swizzled_layout(buffer: BufferLikeType, k_major: bool = True, allow_pad: bool = True): - _, shape, _ = _get_buffer_info(buffer) - stride, continuous = _get_stride_continuous(buffer) - element_size = _get_element_size(buffer) - base = _ffi_api.make_swizzled_layout( - stride, - continuous, - element_size, - k_major, - allow_pad, - ) - return base.reshape(shape) + buf, _, _ = _get_buffer_info(buffer) + return _ffi_api.make_swizzled_layout(buf, k_major, allow_pad) # for Volta Intrinsics def make_volta_swizzled_layout(buffer: BufferLikeType, is_a: bool = True, k_inner: bool = True): - _, shape, _ = _get_buffer_info(buffer) - stride, continuous = _get_stride_continuous(buffer) - base = _ffi_api.make_volta_swizzled_layout( - stride, - continuous, - is_a, - k_inner, - ) - return base.reshape(shape) + buf, _, _ = _get_buffer_info(buffer) + return _ffi_api.make_volta_swizzled_layout(buf, is_a, k_inner) # for WGMMA Intrinsics def make_wgmma_swizzled_layout(buffer: BufferLikeType, continuity: int = None, k_major: bool = True): - _, shape, _ = _get_buffer_info(buffer) - stride, continuous = _get_stride_continuous(buffer) - element_size = _get_element_size(buffer) + buf, _, _ = _get_buffer_info(buffer) if continuity is None: - continuity = continuous - base = _ffi_api.make_wgmma_swizzled_layout( - stride, - continuous, - continuity, - element_size, - k_major, - ) - return base.reshape(shape) + continuity = -1 + return _ffi_api.make_wgmma_swizzled_layout(buf, continuity, k_major) # for TCGEN05MMA Intrinsics def make_tcgen05mma_swizzled_layout(buffer: BufferLikeType, continuity: int = None, k_major: bool = True): - buf, shape, _ = _get_buffer_info(buffer) - stride, continuous = _get_stride_continuous(buffer) - element_size = _get_element_size(buffer) + buf, _, _ = _get_buffer_info(buffer) if continuity is None: - continuity = continuous - try: - base = _ffi_api.make_tcgen05mma_swizzled_layout( - stride, - continuous, - continuity, - element_size, - k_major, - ) - return base.reshape(shape) - except TypeError as err: - # Keep Python sources compatible with older built libs that still expose - # the legacy FFI signature: (buffer, continuity, k_major). - if "Mismatched number of arguments" not in str(err): - raise - return _ffi_api.make_tcgen05mma_swizzled_layout(buf, continuity, k_major) + continuity = -1 + return _ffi_api.make_tcgen05mma_swizzled_layout(buf, continuity, k_major) # swizzle 128B diff --git a/tilelang/tileop/gemm/gemm_wmma.py b/tilelang/tileop/gemm/gemm_wmma.py index 29cd426c66..6c59a41331 100644 --- a/tilelang/tileop/gemm/gemm_wmma.py +++ b/tilelang/tileop/gemm/gemm_wmma.py @@ -1,5 +1,7 @@ """GEMM implementation using AMD RDNA WMMA instructions (gfx11/gfx12).""" +from __future__ import annotations + from .gemm_base import GemmBase from .inst import GemmInst from tilelang.layout import make_swizzled_layout @@ -66,7 +68,9 @@ def infer_layout(self, target: Target, thread_nums: int): else: raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") - def lower(self, layout_map: dict, target: Target, thread_bounds: Range, thread_var: tir.Var): + def lower( + self, layout_map: dict, target: Target, thread_bounds: Range, thread_var: tir.Var, mbar_phase_expr: tir.PrimExpr | None = None + ): thread_nums = thread_bounds.extent wmma_emitter = self._make_emitter(target, thread_nums, thread_var=thread_var) diff --git a/tilelang/tileop/gemm_sp/gemm_sp_base.py b/tilelang/tileop/gemm_sp/gemm_sp_base.py index 8226a06641..3e6ae7c8fc 100644 --- a/tilelang/tileop/gemm_sp/gemm_sp_base.py +++ b/tilelang/tileop/gemm_sp/gemm_sp_base.py @@ -84,19 +84,19 @@ def C(self) -> tir.Buffer: @property def ARegion(self) -> tir.PrimExpr: - return self.gemm_sp_node.ARegion + return self.gemm_sp_node.aRegion @property def ERegion(self) -> tir.PrimExpr: - return self.gemm_sp_node.ERegion + return self.gemm_sp_node.eRegion @property def BRegion(self) -> tir.PrimExpr: - return self.gemm_sp_node.BRegion + return self.gemm_sp_node.bRegion @property def CRegion(self) -> tir.PrimExpr: - return self.gemm_sp_node.CRegion + return self.gemm_sp_node.cRegion @property def stride_A(self) -> int: diff --git a/tilelang/tileop/gemm_sp/gemm_sp_mma.py b/tilelang/tileop/gemm_sp/gemm_sp_mma.py index 9f1a013baa..eabebddc42 100644 --- a/tilelang/tileop/gemm_sp/gemm_sp_mma.py +++ b/tilelang/tileop/gemm_sp/gemm_sp_mma.py @@ -85,9 +85,9 @@ def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): local_size_e = mma_emitter.local_size_e local_size_b = mma_emitter.local_size_b micro_size_k = mma_emitter.micro_size_k - A_shared = self.A - E_shared = self.E - B_shared = self.B + A_shared = self.ARegion + E_shared = self.ERegion + B_shared = self.BRegion C_local = self.C clear_accum = self.clear_accum assert micro_size_k <= self.K, f"K dimension {self.K} should be >= micro size k {micro_size_k}" diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index ed0a89218e..f82592548f 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -38,6 +38,22 @@ def PipelinePlanning(): return _ffi_api.PipelinePlanning() # type: ignore +def InstructionAnnotation(): + """Annotate tile operations with coarse-grained instruction kind. + + This pass runs before LayoutInference and LowerTileOp. It adds a + ``tl_instruction_kind`` annotation to each tile-op Call node indicating + the instruction category ("tma", "cp_async", "sync", "wgmma", etc.) + that will be selected during lowering. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InstructionAnnotation() # type: ignore + + def LayoutInference(): """LayoutInference @@ -71,17 +87,6 @@ def InjectSoftwarePipeline(): return _ffi_api.InjectSoftwarePipeline() # type: ignore -def OptimizeCPAsyncSync(): - """Optimize explicit cp.async commit/wait synchronization intrinsics. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.OptimizeCPAsyncSync() # type: ignore - - def FrontendLegalize(): """FrontendLegalize @@ -160,6 +165,17 @@ def RewriteWgmmaSync(): return _ffi_api.RewriteWgmmaSync() # type: ignore +def ReuseLocalDescriptorAllocations(): + """Pool lexically-disjoint local descriptor allocations. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ReuseLocalDescriptorAllocations() # type: ignore + + def ThreadSync(storage_scope: str): """Insert sync between parallel read/write of shared buffers. @@ -225,35 +241,33 @@ def LoopUnswitching(): return _ffi_api.LoopUnswitching() # type: ignore -def MultiVersionBuffer(barrier_only: bool = False): - """MultiVersionBuffer +def ProducerConsumerWarpSpecialized(): + """Producer-consumer warp specialization at the tile-op level. - Parameters - ---------- - barrier_only : bool - If True, only version barrier buffers (shared.barrier scope). - Data buffer versioning is left to InjectSoftwarePipeline. + This pass runs before LayoutInference and LowerTileOp. It rewrites + eligible pipelined tile-op loops into warp-specialized producer and + consumer branches with explicit barrier synchronization. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.MultiVersionBuffer(barrier_only) # type: ignore + return _ffi_api.ProducerConsumerWarpSpecialized() # type: ignore -def ProducerConsumerWarpSpecialized(): - """Producer-Consumer Warp Specialization for TMA pipelines. +def ProducerConsumerWarpSpecializedTiled(): + """Compatibility alias for ``ProducerConsumerWarpSpecialized``. - Splits pipelined loops with TMA loads into producer (TMA copy) and - consumer (compute) warp groups with mbarrier-based synchronization. + The tiled tile-op implementation is now the canonical + ``ProducerConsumerWarpSpecialized`` pass. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.ProducerConsumerWarpSpecialized() # type: ignore + return ProducerConsumerWarpSpecialized() def AnnotateWarpGroupRegAlloc(): diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index e720130159..8973c8bb27 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -1,6 +1,10 @@ +from __future__ import annotations + # TODO: Add more documentation for each pass config +import warnings from enum import Enum +from typing import Any class PassConfigKey(str, Enum): @@ -82,7 +86,12 @@ class PassConfigKey(str, Enum): """Bitwidth for configuration indices. Default: 32""" TL_DISABLE_TMA_LOWER = "tl.disable_tma_lower" - """Disable TMA (Tensor Memory Access) lowering. Default: False""" + """Deprecated compatibility-only flag for legacy kernels. + + This flag no longer has any effect in the current lowering pipeline and is + kept only so older kernels do not fail pass-config validation. It will be + removed in v0.1.10. + """ TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize" """Disable safe memory access optimization. Default: False""" @@ -260,3 +269,38 @@ class PassConfigKey(str, Enum): TL_DUMP_IR_DIR = "tl.dump_ir_path" """Path to the directory where IR will be dumped. Default: ./dump_ir/""" + + +_DEPRECATED_PASS_CONFIG_MESSAGES = { + PassConfigKey.TL_DISABLE_TMA_LOWER.value: ( + "`tl.disable_tma_lower` is deprecated, kept only for backward " + "compatibility, has no effect in the current lowering pipeline, and " + "will be removed in v0.1.10." + ), +} + + +def normalize_pass_configs(pass_configs: dict[str, Any] | None) -> dict[str, Any]: + """Canonicalize known pass-config keys and emit compatibility warnings.""" + if pass_configs is None: + return {} + + normalized: dict[str, Any] = {} + warned_keys: set[str] = set() + + for key, value in pass_configs.items(): + normalized_key = key + if isinstance(key, str): + try: + normalized_key = PassConfigKey(key) + except ValueError: + normalized_key = key + + normalized[normalized_key] = value + + warning_key = normalized_key.value if isinstance(normalized_key, PassConfigKey) else normalized_key + if warning_key in _DEPRECATED_PASS_CONFIG_MESSAGES and warning_key not in warned_keys: + warnings.warn(_DEPRECATED_PASS_CONFIG_MESSAGES[warning_key], DeprecationWarning, stacklevel=3) + warned_keys.add(warning_key) + + return normalized From 39adf6afe0ed3a1fce86465000072542fcb3c3dc Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 8 Apr 2026 16:53:22 +0800 Subject: [PATCH 021/156] remove wg_wait in gemm_auto_tcgen5mma.py --- examples/gemm_sm100/gemm_auto_tcgen5mma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gemm_sm100/gemm_auto_tcgen5mma.py b/examples/gemm_sm100/gemm_auto_tcgen5mma.py index aae88ed387..130d8c5b74 100644 --- a/examples/gemm_sm100/gemm_auto_tcgen5mma.py +++ b/examples/gemm_sm100/gemm_auto_tcgen5mma.py @@ -40,7 +40,7 @@ def main( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) # not trans_A T.copy(B[bx * block_N, k * block_K], B_shared) # trans_B - T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, wg_wait=-1, clear_accum=k == 0) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, clear_accum=k == 0) T.copy(C_tmem, C_local) T.copy(C_local, C_shared) From 04021209e4b711c0c7c942a60eb58d6f66820548 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 8 Apr 2026 16:54:10 +0800 Subject: [PATCH 022/156] support more than 2 warp groups --- src/transform/auto_schedule.cc | 55 +- src/transform/auto_schedule.h | 61 +- src/transform/auto_schedule/barrier.h | 16 +- .../auto_schedule/schedule_builder.cc | 21 +- .../auto_schedule/schedule_builder.h | 36 +- .../auto_schedule/warpgroup_partition.cc | 570 ++++-------------- .../auto_schedule/warpgroup_partition.h | 2 +- 7 files changed, 215 insertions(+), 546 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index b8a21712d0..5842f4c9e0 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -39,12 +39,11 @@ #include #include -#include - #include #include #include #include +#include #include #include #include @@ -567,8 +566,6 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { extractor(func->body); Stmt body_to_schedule; bool has_tilelang_root = false; - PrimExpr updated_thread_extent; // Will be set if warpgroup partition - // doubles thread extent IterVar thread_var; // Thread index variable for warpgroup partition if (extractor.body.defined()) { @@ -614,22 +611,15 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { // Build ScheduleUnits from IRStructure ScheduleUnitBuilder unit_builder; - // Get thread index variable for warpgroup partition - // First try to get from body_to_schedule, if not found, try from the entire - // function body - thread_var = ThreadTagChecker::GetThreadVar(body_to_schedule); - if (!thread_var.defined()) { - thread_var = ThreadTagChecker::GetThreadVar(func->body); - } if (thread_var.defined()) { unit_builder.SetThreadVar(thread_var); } else { LOG(FATAL) << "Could not find thread index variable, warpgroup " "partition will use default"; } - unit_builder.SetEnableWarpPartition(config.enable_warp_partition); - unit_builder.SetSharedMemoryLimit(config.shared_memory_limit); - bool double_thread = unit_builder.Build(ir_structure); + unit_builder.SetWarpSpeicializeConfig(config); + unit_builder.SetSharedMemoryLimit(GetSharedMemoryLimit(target)); + std::vector thread_count = unit_builder.Build(ir_structure); if (!config.enable_warpgroup_partition) { Stmt new_body = ConvertIRStructureToStmt(ir_structure.get(), enable_epi); @@ -655,28 +645,13 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { int next_barrier_id = 1; std::vector barrier_buffers; Map barrier_map; - // Determine thread count for barrier arrive_count calculations - PrimExpr thread_count[2]; - if (!config.enable_thread_extend) { - ICHECK(config.enable_warp_partition); - // sm_100: use fixed warp size (32) for both partitions - thread_count[0] = IntImm(DataType::Int(32), 32); - thread_count[1] = IntImm(DataType::Int(32), 32); - } else { - // sm_90: original behavior - thread_count[0] = thread_var->dom->extent; - thread_count[1] = double_thread ? thread_var->dom->extent - : IntImm(DataType::Int(32), - config.producer_thread_count); - } LoopNestingInfo loop_info; std::vector buffer_infos; - PrimExpr barrier_count = config.enable_thread_extend - ? thread_count[0] + thread_count[1] - : thread_var->dom->extent; + PrimExpr updated_thread_extent = std::accumulate( + thread_count.begin() + 1, thread_count.end(), thread_count[0]); Buffer neutral_sync_shared_barrier = - makeBarrierBuffer(barrier_count, "neutral_sync_shared_barrier", 1, - barrier_buffers, barrier_map); + makeBarrierBuffer(updated_thread_extent, "neutral_sync_shared_barrier", + 1, barrier_buffers, barrier_map); AnalyzeAndInsertBarriers( ir_structure.get(), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); @@ -687,19 +662,7 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { // Apply warpgroup partition to entire IRStructure Stmt new_body = ApplyWarpgroupPartitionToIRStructure( ir_structure.get(), thread_var, barrier_buffers, barrier_map, - enable_epi, thread_count, double_thread, config, - neutral_sync_shared_barrier); - - if (config.enable_thread_extend) { - // sm_90: may need to update thread extent - if (double_thread) { - updated_thread_extent = thread_var->dom->extent * 2; - } else { - updated_thread_extent = - thread_var->dom->extent + - IntImm(DataType::Int(32), config.producer_thread_count); - } - } + enable_epi, thread_count, config, neutral_sync_shared_barrier); // If we extracted from tilelang_root block, replace the body Stmt final_body; diff --git a/src/transform/auto_schedule.h b/src/transform/auto_schedule.h index 2e01b50640..e3beab40c6 100644 --- a/src/transform/auto_schedule.h +++ b/src/transform/auto_schedule.h @@ -85,58 +85,25 @@ struct ComponentInfo { bool uses_tensor_core_{false}; }; -// Warp specialization architecture enum -enum class WarpSpecializeArch : uint8_t { - kHopper = 0, - kBlackwell = 1, - kUnsupported = 2, -}; - -// Configuration for warp specialization -struct WarpSpecializeConfig { - WarpSpecializeArch arch = WarpSpecializeArch::kUnsupported; - int consumer_max_nreg = 0; - int producer_max_nreg = 0; - int producer_thread_count = 0; - bool enable_set_max_nreg = false; - bool enable_warpgroup_partition = false; - bool enable_thread_extend = false; - bool enable_warp_partition = false; - int shared_memory_limit = 0; -}; - // Factory function to get warp specialization configuration for a target inline WarpSpecializeConfig GetWarpSpecializeConfig(Target target) { if (TargetIsHopper(target)) { - return {WarpSpecializeArch::kHopper, - 240, - 24, - 128, - true, - true, - true, - false, - 228 * 1024}; + return {WarpSpecializeArch::kHopper, 240, 24, 128, true, true, true, false}; + } else if (TargetIsSm100(target)) { + return {WarpSpecializeArch::kBlackwell, 0, 0, 32, false, true, false, true}; + } else { + return { + WarpSpecializeArch::kUnsupported, 0, 0, 0, false, false, false, false}; + } +} + +inline int64_t GetSharedMemoryLimit(Target target) { + if (TargetIsHopper(target)) { + return 228 * 1024; } else if (TargetIsSm100(target)) { - return {WarpSpecializeArch::kBlackwell, - 0, - 0, - 32, - false, - true, - false, - true, - 228 * 1024}; + return 228 * 1024; } else { - return {WarpSpecializeArch::kUnsupported, - 0, - 0, - 0, - false, - false, - false, - false, - 0}; + return 48 * 1024; } } diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 08b2215f86..c6e7eafd7c 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -167,14 +167,15 @@ static void AnalyzeAndInsertBarriers(IRStructure *node, int &next_barrier_id, std::vector &barrier_buffers, Map &barrier_map, - PrimExpr thread_count[2], LoopNestingInfo &loop_info, + const std::vector &thread_count, + LoopNestingInfo &loop_info, std::vector &buffer_infos, Buffer neutral_sync_shared_barrier); static void AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, std::vector &barrier_buffers, Map &barrier_map, - PrimExpr thread_count[2], + const std::vector &thread_count, LoopNestingInfo &loop_info, std::vector &buffer_infos, Buffer neutral_sync_shared_barrier); @@ -182,7 +183,8 @@ static void AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, std::vector &barrier_buffers, Map &barrier_map, - PrimExpr thread_count[2], LoopNestingInfo &loop_info, + const std::vector &thread_count, + LoopNestingInfo &loop_info, std::vector &buffer_infos, Buffer neutral_sync_shared_barrier); @@ -543,7 +545,8 @@ static void AnalyzeAndInsertBarriers(IRStructure *node, int &next_barrier_id, std::vector &barrier_buffers, Map &barrier_map, - PrimExpr thread_count[2], LoopNestingInfo &loop_info, + const std::vector &thread_count, + LoopNestingInfo &loop_info, std::vector &buffer_infos, Buffer neutral_sync_shared_barrier) { if (!node) @@ -575,7 +578,7 @@ static void AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, std::vector &barrier_buffers, Map &barrier_map, - PrimExpr thread_count[2], + const std::vector &thread_count, LoopNestingInfo &loop_info, std::vector &buffer_infos, Buffer neutral_sync_shared_barrier) { @@ -786,7 +789,8 @@ static void AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, std::vector &barrier_buffers, Map &barrier_map, - PrimExpr thread_count[2], LoopNestingInfo &loop_info, + const std::vector &thread_count, + LoopNestingInfo &loop_info, std::vector &buffer_infos, Buffer neutral_sync_shared_barrier) { if (!ctrl || !ctrl->child) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 225a5fe65d..9b4178d0b0 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -237,7 +237,9 @@ void CollectSuffixTasks(IRStructure *root, } } -bool AssignWarpgroupIdsGlobal(IRStructure *root, bool enable_warp_partition) { +std::vector +AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, + PrimExpr thread_count) { if (!root) { LOG(FATAL) << "Empty root"; } @@ -269,7 +271,7 @@ bool AssignWarpgroupIdsGlobal(IRStructure *root, bool enable_warp_partition) { CollectPrefixTasks(root, prefix_tasks, prefix_valid); std::unordered_set suffix_tasks; - if (enable_warp_partition) { + if (config.enable_warp_partition) { CollectSuffixTasks(root, all_tasks, uf, suffix_tasks); } @@ -360,7 +362,12 @@ bool AssignWarpgroupIdsGlobal(IRStructure *root, bool enable_warp_partition) { } } } - return true; + if (config.enable_thread_extend) { + return {thread_count, thread_count}; + } else { + return {IntImm(DataType::Int(32), 32), IntImm(DataType::Int(32), 32), + thread_count - IntImm(DataType::Int(32), 64)}; + } } else { int64_t warpgroup0_latency = 0; int64_t warpgroup1_latency = 0; @@ -387,7 +394,13 @@ bool AssignWarpgroupIdsGlobal(IRStructure *root, bool enable_warp_partition) { } } } - return false; + if (config.enable_thread_extend) { + return {thread_count, + IntImm(DataType::Int(32), config.producer_thread_count)}; + } else { + return {IntImm(DataType::Int(32), 32), IntImm(DataType::Int(32), 32), + thread_count - IntImm(DataType::Int(32), 64)}; + } } } diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index 60f128311e..28fc4d9763 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -29,7 +29,28 @@ using namespace tir; class TaskUnionFind; struct ComponentInfo; -bool AssignWarpgroupIdsGlobal(IRStructure *root, bool enable_warp_partition); +// Warp specialization architecture enum +enum class WarpSpecializeArch : uint8_t { + kHopper = 0, + kBlackwell = 1, + kUnsupported = 2, +}; + +// Configuration for warp specialization +struct WarpSpecializeConfig { + WarpSpecializeArch arch = WarpSpecializeArch::kUnsupported; + int consumer_max_nreg = 0; + int producer_max_nreg = 0; + int producer_thread_count = 0; + bool enable_set_max_nreg = false; + bool enable_warpgroup_partition = false; + bool enable_thread_extend = false; + bool enable_warp_partition = false; +}; + +std::vector +AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, + PrimExpr thread_count); // Extract all sequential task nodes from the IR structure tree void GatherTaskNodes(const std::vector> &nodes, @@ -54,11 +75,12 @@ void CollectSuffixTasks(IRStructure *root, // Builder that collects ScheduleUnits from IRStructure class ScheduleUnitBuilder { public: - bool Build(std::shared_ptr &root) { + std::vector Build(std::shared_ptr &root) { ScheduleRecursive(root, {}); // Global warpgroup id assignment from the top level - return AssignWarpgroupIdsGlobal(root.get(), enable_warp_partition_); + return AssignWarpgroupIdsGlobal(root.get(), config_, + thread_var_->dom->extent); } // New recursive scheduling function that replaces Collect method @@ -589,14 +611,16 @@ class ScheduleUnitBuilder { void SetThreadVar(IterVar thread_var) { thread_var_ = thread_var; } // Set enable_warp_partition flag - void SetEnableWarpPartition(bool enable) { enable_warp_partition_ = enable; } + void SetWarpSpeicializeConfig(const WarpSpecializeConfig &config) { + config_ = config; + } // Set shared memory limit for pipeline (in bytes) void SetSharedMemoryLimit(int64_t bytes) { shared_memory_limit_ = bytes; } private: - IterVar thread_var_; // Thread index variable for warpgroup partition - bool enable_warp_partition_ = false; + IterVar thread_var_; // Thread index variable for warpgroup partition + WarpSpecializeConfig config_; // Configuration for warp specialization int64_t shared_memory_limit_ = 48 * 1024; // Check if two regions refer to the same buffer diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 0d73723e75..649ac132b7 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -48,6 +48,7 @@ #include #include #include +#include #include #include #include @@ -250,6 +251,7 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, return new_unit; } LOG(FATAL); + return nullptr; } // Entry point overload — creates a fresh var_remap per call @@ -716,7 +718,7 @@ Stmt ConvertIRStructureToStmt(IRStructure *root, const bool outer_enable_epi) { Stmt ApplyWarpgroupPartitionToIRStructure( IRStructure *root, IterVar thread_var, std::vector &barrier_buffers, Map &barrier_map, const bool outer_enable_epi, - PrimExpr thread_count[2], bool producer_consumer, + const std::vector &thread_count, const WarpSpecializeConfig &config, Buffer neutral_sync_shared_barrier) { if (!root) return Evaluate(0); @@ -727,8 +729,7 @@ Stmt ApplyWarpgroupPartitionToIRStructure( if (wrapper->child) { body = ApplyWarpgroupPartitionToIRStructure( wrapper->child.get(), thread_var, barrier_buffers, barrier_map, - outer_enable_epi, thread_count, producer_consumer, config, - neutral_sync_shared_barrier); + outer_enable_epi, thread_count, config, neutral_sync_shared_barrier); } if (const auto *let = wrapper->wrapper.as()) { return LetStmt(let->var, let->value, body); @@ -736,302 +737,12 @@ Stmt ApplyWarpgroupPartitionToIRStructure( return AttrStmt(attr->node, attr->attr_key, attr->value, body); } else { LOG(FATAL); - } - } - - // Check if there are tasks with mixed warpgroup ids - std::vector all_tasks; - CollectAllTaskNodesWithContext(root, all_tasks); - - bool has_warpgroup0 = false; - bool has_warpgroup1 = false; - bool has_warpgroup_neutral = false; - for (auto &task : all_tasks) { - int wg_id = task.task->GetWarpgroupId(); - if (wg_id == 0) - has_warpgroup0 = true; - else if (wg_id == 1) - has_warpgroup1 = true; - else if (wg_id == -1) - has_warpgroup_neutral = true; - } - - // Convert IRStructure to Stmt for IfThenElse - std::function irstructure_to_stmt; - irstructure_to_stmt = [&irstructure_to_stmt, - outer_enable_epi](IRStructure *structure) -> Stmt { - if (!structure) { return Evaluate(0); } - - if (structure->IsTask()) { - auto task = static_cast(structure); - if (task->stmts.empty()) { - return Evaluate(0); - } else if (task->stmts.size() == 1) { - return task->stmts[0]; - } else { - return SeqStmt(task->stmts); - } - } else if (structure->IsSequence()) { - auto seq = static_cast(structure); - std::vector stmts; - for (const auto &child : seq->children) { - auto unit = static_cast(child.get()); - for (auto &before : unit->before) { - for (auto &stmt : before) { - stmts.push_back(stmt); - } - } - Stmt child_stmt = irstructure_to_stmt(unit->child.get()); - stmts.push_back(child_stmt); - for (auto &after : unit->after) { - for (auto &stmt : after) { - stmts.push_back(stmt); - } - } - } - auto flattened = SeqStmt::Flatten(stmts); - return flattened; - } else if (structure->IsControl()) { - auto ctrl = static_cast(structure); - Var loop_var = ctrl->control->loop_var; - PrimExpr loop_start = ctrl->control->min; - PrimExpr loop_extent = ctrl->control->extent; - PrimExpr loop_step = ctrl->control->step.has_value() - ? ctrl->control->step.value() - : IntImm(DataType::Int(32), 1); - int min_stages = 100, max_stages = -1; - if (ctrl->child->IsSequence()) { - auto seq = static_cast(ctrl->child.get()); - for (auto &child : seq->children) { - auto unit = static_cast(child.get()); - min_stages = std::min(min_stages, unit->stage); - max_stages = std::max(max_stages, unit->stage); - } - } - if (!ctrl->hasPromote() || !ctrl->child->IsSequence() || - min_stages == max_stages) { - std::vector stmts; - if (ctrl->child->IsScheduleUnit()) { - auto unit = static_cast(ctrl->child.get()); - for (auto &before : unit->before) { - for (auto &stmt : before) { - stmts.push_back(stmt); - } - } - stmts.push_back(irstructure_to_stmt(unit->child.get())); - for (auto &after : unit->after) { - for (auto &stmt : after) { - stmts.push_back(stmt); - } - } - } else if (ctrl->child->IsSequence()) { - auto seq = static_cast(ctrl->child.get()); - for (auto &child : seq->children) { - ICHECK(child->IsScheduleUnit()); - auto unit = static_cast(child.get()); - for (auto &before : unit->before) { - for (auto &stmt : before) { - stmts.push_back(stmt); - } - } - stmts.push_back(irstructure_to_stmt(unit->child.get())); - for (auto &after : unit->after) { - for (auto &stmt : after) { - stmts.push_back(stmt); - } - } - } - } else { - LOG(FATAL); - } - Stmt body = SeqStmt::Flatten(stmts); - // Filter out "num_stages" annotation - Map filtered_annotations = ctrl->control->annotations; - filtered_annotations.erase("num_stages"); - return For(loop_var, loop_start, loop_extent, ctrl->control->kind, body, - ctrl->control->thread_binding, filtered_annotations); - } - auto seq = static_cast(ctrl->child.get()); - Stmt body = Evaluate(0); - std::vector> unit_stages; - unit_stages.resize(max_stages - min_stages + 1); - for (auto &child : seq->children) { - auto unit = static_cast(child.get()); - std::vector stmts; - for (auto &before : unit->before) { - for (auto &stmt : before) { - stmts.push_back(stmt); - } - } - stmts.push_back(irstructure_to_stmt(unit->child.get())); - for (auto &after : unit->after) { - for (auto &stmt : after) { - stmts.push_back(stmt); - } - } - unit_stages[unit->stage - min_stages].push_back( - SeqStmt::Flatten(stmts)); - } - // Check if any task in this control node contains loop_break - // If any task contains loop_break, disable prologue - std::function check_contains_loop_break; - check_contains_loop_break = - [&check_contains_loop_break](IRStructure *structure) -> bool { - if (!structure) - return false; - - if (structure->IsTask()) { - auto task = static_cast(structure); - return task->ContainsLoopBreak(); - } else if (structure->IsSequence()) { - auto seq = static_cast(structure); - for (const auto &child : seq->children) { - auto unit = static_cast(child.get()); - if (check_contains_loop_break(unit->child.get())) { - return true; - } - } - return false; - } else if (structure->IsScheduleUnit()) { - auto unit = static_cast(structure); - return check_contains_loop_break(unit->child.get()); - } else if (structure->IsControl()) { - auto ctrl = static_cast(structure); - return check_contains_loop_break(ctrl->child.get()); - } else if (structure->IsWrapper()) { - auto wrapper = static_cast(structure); - return check_contains_loop_break(wrapper->child.get()); - } - return false; - }; - - // Set enable_pro to true only if: - // 1. No task contains loop_break - // 2. Loop boundaries (min and extent) are constants - bool enable_pro = !check_contains_loop_break(ctrl->child.get()); - - // Check if loop boundaries are constants - bool loop_min_is_const = tir::is_const_int(loop_start); - bool loop_extent_is_const = tir::is_const_int(loop_extent); - - if (!loop_min_is_const || !loop_extent_is_const) { - enable_pro = false; - } - - bool enable_epi = outer_enable_epi && enable_pro; - std::vector steady; - - for (auto &child : seq->children) { - auto unit = static_cast(child.get()); - std::vector stmts; - for (auto &before : unit->before) { - for (auto &stmt : before) { - stmts.push_back(stmt); - } - } - stmts.push_back(irstructure_to_stmt(unit->child.get())); - for (auto &after : unit->after) { - for (auto &stmt : after) { - stmts.push_back(stmt); - } - } - Map substitution, substitution_cond; - substitution.Set(loop_var, - loop_var - loop_step * (max_stages - unit->stage)); - substitution_cond.Set( - loop_var, - Max(loop_start, - Min(loop_start + loop_extent - loop_step, - loop_var - loop_step * (max_stages - unit->stage)))); - if (IsLetDeclNode(unit->child.get())) { - Stmt stmt = SeqStmt::Flatten(stmts); - steady.push_back(Substitute(stmt, substitution_cond)); - } else { - PrimExpr condition = - And(loop_var < loop_start + loop_extent, loop_var >= loop_start); - if (unit->stage == min_stages) { - condition = loop_var >= loop_start; - } - if (unit->stage == max_stages) { - condition = loop_var < loop_start + loop_extent; - } - Stmt stmt = IfThenElse(condition, SeqStmt::Flatten(stmts)); - steady.push_back(Substitute(stmt, substitution)); - } - } - Stmt new_body = SeqStmt::Flatten(steady); - auto new_var = loop_var.copy_with_suffix(""); - // Filter out "num_stages" annotation - Map filtered_annotations = ctrl->control->annotations; - filtered_annotations.erase("num_stages"); - Map substitution; - substitution.Set(loop_var, new_var); - For for_op = - For(new_var, loop_start, - ctrl->control->extent + loop_step * (max_stages - min_stages), - ctrl->control->kind, Substitute(new_body, substitution), - ctrl->control->thread_binding, filtered_annotations); - - Stmt prologue = Evaluate(0); - if (enable_pro) { - Map sub; - For new_for = for_op; - auto pro = loop_var.copy_with_suffix("_prologue"); - sub.Set(new_var, pro); - new_for.CopyOnWrite()->loop_var = pro; - new_for.CopyOnWrite()->kind = ForKind::kUnrolled; - new_for.CopyOnWrite()->extent = - min(max_stages - min_stages, for_op.get()->extent); - for_op.CopyOnWrite()->min += loop_step * (max_stages - min_stages); - for_op.CopyOnWrite()->extent = - max(0, for_op.get()->extent - (max_stages - min_stages)); - prologue = Substitute(new_for, sub); - } - Stmt epilogue = Evaluate(0); - if (enable_epi) { - Map sub; - For new_for = for_op; - auto epi = loop_var.copy_with_suffix("_epilogue"); - sub.Set(new_var, epi); - new_for.CopyOnWrite()->loop_var = epi; - new_for.CopyOnWrite()->kind = ForKind::kUnrolled; - new_for.CopyOnWrite()->min = - for_op.get()->min + - loop_step * (for_op.get()->extent - (max_stages - min_stages)); - new_for.CopyOnWrite()->extent = - min(max_stages - min_stages, for_op.get()->extent); - for_op.CopyOnWrite()->extent = - max(0, for_op.get()->extent - (max_stages - min_stages)); - epilogue = Substitute(new_for, sub); - } - return SeqStmt({prologue, for_op, epilogue}); - } else if (structure->IsWrapper()) { - auto wrapper = static_cast(structure); - Stmt body = Evaluate(0); - if (wrapper->child) { - body = irstructure_to_stmt(wrapper->child.get()); - } - if (const auto *let = wrapper->wrapper.as()) { - return LetStmt(let->var, let->value, body); - } else if (const auto *attr = wrapper->wrapper.as()) { - return AttrStmt(attr->node, attr->attr_key, attr->value, body); - } else { - LOG(FATAL); - } - } - - LOG(FATAL) - << "Failed to convert IRStructure to Stmt, returning empty statement"; - return Evaluate(0); - }; - - // If all tasks belong to the same warpgroup, no partition needed - if (!(has_warpgroup0 && has_warpgroup1)) { - return irstructure_to_stmt(root); } + size_t num_wgs = thread_count.size(); + // Helper function to clone IRStructure filtering tasks with warpgroup_id == // -1 (neutral tasks) std::function(IRStructure *)> @@ -1078,9 +789,12 @@ Stmt ApplyWarpgroupPartitionToIRStructure( return nullptr; } LOG(FATAL); + return nullptr; }; auto has_actual_statements = [](IRStructure *node) -> bool { + if (!node) + return false; std::vector tasks; CollectAllTaskNodesWithContext(node, tasks); for (auto &task : tasks) { @@ -1143,6 +857,7 @@ Stmt ApplyWarpgroupPartitionToIRStructure( return nullptr; } LOG(FATAL); + return nullptr; }; int last_warpgroup_task_top_level_index = -1; @@ -1173,45 +888,41 @@ Stmt ApplyWarpgroupPartitionToIRStructure( return !is_epi_top_level_index(top_level_index); }; - auto wg_pro_neutral_structure = has_warpgroup_neutral - ? clone_neutral_filter_with_top_level( - root, is_pro_top_level_index, -1) - : nullptr; - auto wg_epi_neutral_structure = has_warpgroup_neutral - ? clone_neutral_filter_with_top_level( - root, is_epi_top_level_index, -1) - : nullptr; - - auto wg0_structure = - RemoveUnusedLetDecls(CloneIRStructureWithWarpgroupFilter(root, 0)); - auto wg1_structure = - RemoveUnusedLetDecls(CloneIRStructureWithWarpgroupFilter(root, 1)); + auto wg_pro_neutral_structure = + clone_neutral_filter_with_top_level(root, is_pro_top_level_index, -1); + auto wg_epi_neutral_structure = + clone_neutral_filter_with_top_level(root, is_epi_top_level_index, -1); + std::vector> wg_structures(num_wgs); + for (size_t i = 0; i < num_wgs; ++i) { + wg_structures[i] = + RemoveUnusedLetDecls(CloneIRStructureWithWarpgroupFilter(root, i)); + } + std::vector wg_conditions(num_wgs); + wg_conditions[0] = thread_count[0]; + for (size_t i = 1; i < num_wgs; ++i) { + wg_conditions[i] = wg_conditions[i - 1] + thread_count[i]; + } + for (auto &cond : wg_conditions) { + cond = thread_var->var < cond; + } bool wg_pro_neutral_has_stmts = - wg_pro_neutral_structure - ? has_actual_statements(wg_pro_neutral_structure.get()) - : false; + has_actual_statements(wg_pro_neutral_structure.get()); bool wg_epi_neutral_has_stmts = - wg_epi_neutral_structure - ? has_actual_statements(wg_epi_neutral_structure.get()) - : false; - bool wg0_has_stmts = has_actual_statements(wg0_structure.get()); - bool wg1_has_stmts = has_actual_statements(wg1_structure.get()); - - PrimExpr condition = thread_var->var < thread_count[0]; - PrimExpr wg1_condition = - thread_var->var < (thread_count[0] + thread_count[1]); + has_actual_statements(wg_epi_neutral_structure.get()); Stmt pro_neutral_body = wg_pro_neutral_has_stmts - ? irstructure_to_stmt(wg_pro_neutral_structure.get()) + ? ConvertIRStructureToStmt(wg_pro_neutral_structure.get(), + outer_enable_epi) : Evaluate(0); Stmt epi_neutral_body = wg_epi_neutral_has_stmts - ? irstructure_to_stmt(wg_epi_neutral_structure.get()) + ? ConvertIRStructureToStmt(wg_epi_neutral_structure.get(), + outer_enable_epi) : Evaluate(0); - // --- Segment the wg0/wg1 structures by ControlNode (for-loop) boundaries --- + // --- Segment the wg structures by ControlNode (for-loop) boundaries --- // This produces multiple IfThenElse blocks separated by liveness boundary // markers, so that the merge-shared-memory pass can reuse buffers across // segments whose lifetimes do not overlap. @@ -1250,7 +961,7 @@ Stmt ApplyWarpgroupPartitionToIRStructure( // Helper: wrap a list of ScheduleUnit children back into a temporary // SequenceNode and convert to Stmt. auto SegmentToStmt = - [&irstructure_to_stmt]( + [&outer_enable_epi]( const std::vector> &children) -> Stmt { if (children.empty()) return Evaluate(0); @@ -1258,23 +969,17 @@ Stmt ApplyWarpgroupPartitionToIRStructure( // ScheduleUnit before/after stmts are emitted correctly. auto tmp_seq = std::make_shared(); tmp_seq->children = children; - return irstructure_to_stmt(tmp_seq.get()); + return ConvertIRStructureToStmt(tmp_seq.get(), outer_enable_epi); }; // Helper: build a single IfThenElse (with wg1 nesting) from a pair of Stmts. - auto MakeWarpgroupIf = [&condition, &wg1_condition](Stmt wg0_stmt, - Stmt wg1_stmt) -> Stmt { - bool wg0_valid = !IsEvaluateZero(wg0_stmt); - bool wg1_valid = !IsEvaluateZero(wg1_stmt); - if (wg0_valid && wg1_valid) { - return IfThenElse(condition, wg0_stmt, - IfThenElse(wg1_condition, wg1_stmt, Evaluate(0))); - } else if (wg0_valid) { - return IfThenElse(condition, wg0_stmt); - } else if (wg1_valid) { - return IfThenElse(wg1_condition, wg1_stmt); + auto MakeWarpgroupIf = + [&wg_conditions](const std::vector &wg_stmts) -> Stmt { + Stmt if_then_else = Evaluate(0); + for (size_t i = wg_stmts.size(); i-- > 0;) { + if_then_else = IfThenElse(wg_conditions[i], wg_stmts[i], if_then_else); } - return Evaluate(0); + return if_then_else; }; // Helper: collect LetDecl {Var, PrimExpr} pairs from a segment's children. @@ -1369,110 +1074,102 @@ Stmt ApplyWarpgroupPartitionToIRStructure( }; Stmt if_then_else; - if (wg0_has_stmts && wg1_has_stmts) { - auto wg0_segments = SegmentSequenceChildren(wg0_structure.get()); - auto wg1_segments = SegmentSequenceChildren(wg1_structure.get()); - - // Only apply segmented splitting when both sides have matching segment - // counts (they originate from the same root, split at ControlNode - // boundaries, so they should match). Otherwise fall back to the - // single-IfThenElse path. - if (!wg0_segments.empty() && !wg1_segments.empty() && - wg0_segments.size() == wg1_segments.size() && wg0_segments.size() > 1) { - std::vector segmented_stmts; - bool has_simt_copy = false; - // Check for SIMT copy in any wg1 segment (needed for set_max_nreg - // decision). - { - Stmt full_wg1 = irstructure_to_stmt(wg1_structure.get()); - has_simt_copy = SimtCopyDetector::Detect(full_wg1); + std::vector>>> + wg_segments(num_wgs); + bool equal_segment_counts = true; + for (size_t i = 0; i < num_wgs; ++i) { + wg_segments[i] = SegmentSequenceChildren(wg_structures[i].get()); + equal_segment_counts &= (wg_segments[i].size() == wg_segments[0].size()); + } + + // Only apply segmented splitting when both sides have matching segment + // counts (they originate from the same root, split at ControlNode + // boundaries, so they should match). Otherwise fall back to the + // single-IfThenElse path. + if (equal_segment_counts && wg_segments[0].size() > 1) { + std::vector segmented_stmts; + bool has_simt_copy = false; + // Check for SIMT copy in any wg1 segment (needed for set_max_nreg + // decision). + if (num_wgs == 2) { + Stmt full_wg1 = + ConvertIRStructureToStmt(wg_structures[1].get(), outer_enable_epi); + has_simt_copy = SimtCopyDetector::Detect(full_wg1); + } + + // Accumulate LetDecl info from previous segments for variable renaming. + std::vector>> wg_accumulated_lets( + num_wgs); + + for (size_t si = 0; si < wg_segments[0].size(); ++si) { + // Insert liveness boundary between segments. + segmented_stmts.push_back(AttrStmt( + Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, Evaluate(0))); + + // Collect LetDecl info from current segment before converting to Stmt. + std::vector>> wg_lets(num_wgs); + for (size_t i = 0; i < num_wgs; ++i) { + wg_lets[i] = CollectLetDeclInfo(wg_segments[i][si]); + } + std::vector wg_seg_stmts(num_wgs); + for (size_t i = 0; i < num_wgs; ++i) { + wg_seg_stmts[i] = SegmentToStmt(wg_segments[i][si]); } - // Accumulate LetDecl info from previous segments for variable renaming. - std::vector> wg0_accumulated_lets; - std::vector> wg1_accumulated_lets; - - for (size_t si = 0; si < wg0_segments.size(); ++si) { - // Insert liveness boundary between segments. - segmented_stmts.push_back( - AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, - Evaluate(0))); - - // Collect LetDecl info from current segment before converting to Stmt. - auto wg0_lets = CollectLetDeclInfo(wg0_segments[si]); - auto wg1_lets = CollectLetDeclInfo(wg1_segments[si]); - - Stmt wg0_seg_stmt = SegmentToStmt(wg0_segments[si]); - Stmt wg1_seg_stmt = SegmentToStmt(wg1_segments[si]); - - // For segments after the first, wrap with renamed LetDecl bindings - // from all previous segments so that variables remain in scope. - if (si > 0) { - wg0_seg_stmt = - WrapWithRenamedLetDecls(wg0_seg_stmt, wg0_accumulated_lets); - wg1_seg_stmt = - WrapWithRenamedLetDecls(wg1_seg_stmt, wg1_accumulated_lets); + // For segments after the first, wrap with renamed LetDecl bindings + // from all previous segments so that variables remain in scope. + if (si > 0) { + for (size_t i = 0; i < num_wgs; ++i) { + wg_seg_stmts[i] = + WrapWithRenamedLetDecls(wg_seg_stmts[i], wg_accumulated_lets[i]); } + } - // Accumulate this segment's LetDecls for future segments. - wg0_accumulated_lets.insert(wg0_accumulated_lets.end(), - wg0_lets.begin(), wg0_lets.end()); - wg1_accumulated_lets.insert(wg1_accumulated_lets.end(), - wg1_lets.begin(), wg1_lets.end()); + // Accumulate this segment's LetDecls for future segments. + for (size_t i = 0; i < num_wgs; ++i) { + wg_accumulated_lets[i].insert(wg_accumulated_lets[i].end(), + wg_lets[i].begin(), wg_lets[i].end()); + } - // Prepend set_max_nreg only to the first segment. - if (si == 0 && !has_simt_copy && config.enable_set_max_nreg) { - wg0_seg_stmt = + // Prepend set_max_nreg only to the first segment. + if (si == 0 && !has_simt_copy && num_wgs == 2 && + config.enable_set_max_nreg) { + for (size_t i = 0; i < num_wgs; ++i) { + wg_seg_stmts[i] = SeqStmt({Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), - {config.consumer_max_nreg, 1})), - wg0_seg_stmt}); - wg1_seg_stmt = - SeqStmt({Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), - {config.producer_max_nreg, 0})), - wg1_seg_stmt}); + {i == 0 ? config.consumer_max_nreg + : config.producer_max_nreg, + static_cast(!i)})), + wg_seg_stmts[i]}); } - - segmented_stmts.push_back(MakeWarpgroupIf(wg0_seg_stmt, wg1_seg_stmt)); - } - if_then_else = SeqStmt::Flatten(segmented_stmts); - } else { - // Fallback: single IfThenElse (original logic). - Stmt then_body = irstructure_to_stmt(wg0_structure.get()); - Stmt else_body = irstructure_to_stmt(wg1_structure.get()); - bool has_simt_copy = SimtCopyDetector::Detect(else_body); - if (has_simt_copy || !config.enable_set_max_nreg) { - if_then_else = - IfThenElse(condition, then_body, - IfThenElse(wg1_condition, else_body, Evaluate(0))); - } else { - std::vector then_body_with_nreg{ - Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), - {config.consumer_max_nreg, 1})), - then_body}; - std::vector else_body_with_nreg{ - Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), - {config.producer_max_nreg, 0})), - else_body}; - if_then_else = - IfThenElse(condition, SeqStmt(then_body_with_nreg), - IfThenElse(wg1_condition, SeqStmt(else_body_with_nreg), - Evaluate(0))); } + + segmented_stmts.push_back(MakeWarpgroupIf(wg_seg_stmts)); } - } else if (wg0_has_stmts) { - // Only warpgroup 0 has statements, execute unconditionally - if_then_else = irstructure_to_stmt(wg0_structure.get()); - } else if (wg1_has_stmts) { - // Only warpgroup 1 has statements, execute unconditionally - if_then_else = irstructure_to_stmt(wg1_structure.get()); + if_then_else = SeqStmt::Flatten(segmented_stmts); } else { - // Neither warpgroup 0 nor 1 has statements - if_then_else = Evaluate(0); + // Fallback: single IfThenElse (original logic). + std::vector wg_stmts(num_wgs); + for (size_t i = 0; i < num_wgs; ++i) { + wg_stmts[i] = + ConvertIRStructureToStmt(wg_structures[i].get(), outer_enable_epi); + } + bool has_simt_copy = num_wgs == 2 && SimtCopyDetector::Detect(wg_stmts[1]); + if (!has_simt_copy && num_wgs == 2 && config.enable_set_max_nreg) { + for (size_t i = 0; i < num_wgs; ++i) { + wg_stmts[i] = + SeqStmt({Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), + {i == 0 ? config.consumer_max_nreg + : config.producer_max_nreg, + static_cast(!i)})), + wg_stmts[i]}); + } + } + if_then_else = MakeWarpgroupIf(wg_stmts); } - PrimExpr barrier_count = config.enable_thread_extend - ? thread_count[0] + thread_count[1] - : thread_var->dom->extent; + PrimExpr updated_thread_extent = std::accumulate( + thread_count.begin() + 1, thread_count.end(), thread_count[0]); Stmt pro_and_warpgroup_stmt; if (wg_pro_neutral_has_stmts) { @@ -1481,7 +1178,7 @@ Stmt ApplyWarpgroupPartitionToIRStructure( // synchronization pro_and_warpgroup_stmt = InsertBarriersForNeutralSync( pro_neutral_body, if_then_else, barrier_buffers, barrier_map, - barrier_count, neutral_sync_shared_barrier); + updated_thread_extent, neutral_sync_shared_barrier); } else if (!IsEvaluateZero(if_then_else) || !IsEvaluateZero(pro_neutral_body)) { // Only one has actual statements @@ -1508,15 +1205,14 @@ Stmt ApplyWarpgroupPartitionToIRStructure( bool need_shared_barrier_for_epi = false; bool need_tmem_barrier_for_epi = false; if (wg_epi_neutral_structure) { - for (const auto *warpgroup_structure : - {wg0_structure.get(), wg1_structure.get()}) { + for (const auto &warpgroup_structure : wg_structures) { need_shared_barrier_for_epi = need_shared_barrier_for_epi || - HasSharedWriteReadDependency(warpgroup_structure, + HasSharedWriteReadDependency(warpgroup_structure.get(), wg_epi_neutral_structure.get()); need_tmem_barrier_for_epi = need_tmem_barrier_for_epi || - HasTmemWriteReadDependency(warpgroup_structure, + HasTmemWriteReadDependency(warpgroup_structure.get(), wg_epi_neutral_structure.get()); } } @@ -1526,10 +1222,12 @@ Stmt ApplyWarpgroupPartitionToIRStructure( !IsEvaluateZero(epi_neutral_body)) { // Both have statements: insert barriers for warpgroup-to-epi_neutral // synchronization + // TODO: tensor core may not only in wg0? combined_stmt = InsertBarriersForNeutralSyncWithDependency( pro_and_warpgroup_stmt, epi_neutral_body, barrier_buffers, barrier_map, - barrier_count, need_shared_barrier_for_epi, need_tmem_barrier_for_epi, - Buffer(), thread_var->var, 0, thread_count[0]); + updated_thread_extent, need_shared_barrier_for_epi, + need_tmem_barrier_for_epi, Buffer(), thread_var->var, 0, + thread_count[0]); } else if (!IsEvaluateZero(epi_neutral_body)) { combined_stmt = epi_neutral_body; } else { diff --git a/src/transform/auto_schedule/warpgroup_partition.h b/src/transform/auto_schedule/warpgroup_partition.h index 8d3f5ae83a..7b47d89cff 100644 --- a/src/transform/auto_schedule/warpgroup_partition.h +++ b/src/transform/auto_schedule/warpgroup_partition.h @@ -49,7 +49,7 @@ Stmt ConvertIRStructureToStmt(IRStructure *root, const bool outer_enable_epi); Stmt ApplyWarpgroupPartitionToIRStructure( IRStructure *root, IterVar thread_var, std::vector &barrier_buffers, Map &barrier_map, const bool enable_epi, - PrimExpr thread_count[2], bool producer_consumer, + const std::vector &thread_count, const WarpSpecializeConfig &config, Buffer neutral_sync_shared_barrier); Stmt ReNestLetStmts(const Stmt &stmt); From b021c2afff1e4b084b1132261afef83b40ed1a34 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 8 Apr 2026 17:18:24 +0800 Subject: [PATCH 023/156] change `before` and `after` to map --- src/transform/auto_schedule/ir_structure.h | 29 +++++++---------- .../auto_schedule/warpgroup_partition.cc | 32 +++++++++---------- 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 9684a04be3..484bc0085c 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -569,15 +569,10 @@ class WrapperNode : public IRStructure { class ScheduleUnit : public IRStructure { public: int stage; - std::vector> before, after; + std::map> before, after; std::shared_ptr child; - ScheduleUnit() { - for (unsigned idx = 0; idx != 2; ++idx) { - before.emplace_back(); - after.emplace_back(); - } - } + ScheduleUnit() {} Kind GetKind() const override { return Kind::kSchedule; } @@ -611,12 +606,12 @@ class ScheduleUnit : public IRStructure { if (child) { child->SubstituteVar(old_var, new_var); } - for (auto &stmts : before) { + for (auto &[_, stmts] : before) { for (auto &stmt : stmts) { stmt = Substitute(stmt, {{old_var, new_var}}); } } - for (auto &stmts : after) { + for (auto &[_, stmts] : after) { for (auto &stmt : stmts) { stmt = Substitute(stmt, {{old_var, new_var}}); } @@ -925,13 +920,13 @@ inline void PrintAllStmts(const IRStructure *node, int indent = 0) { const ScheduleUnit *promote = static_cast(node); LOG(INFO) << indent_str << "ScheduleUnit:"; LOG(INFO) << indent_str << " Promote: " << promote->stage; - for (unsigned idx = 0; idx != promote->before.size(); ++idx) { - for (auto &stmt : promote->before[idx]) { + for (const auto &[idx, stmts] : promote->before) { + for (const auto &stmt : stmts) { LOG(INFO) << indent_str << " Before " << idx << " : " << stmt; } } - for (unsigned idx = 0; idx != promote->after.size(); ++idx) { - for (auto &stmt : promote->after[idx]) { + for (const auto &[idx, stmts] : promote->after) { + for (const auto &stmt : stmts) { LOG(INFO) << indent_str << " After " << idx << " : " << stmt; } } @@ -1021,13 +1016,13 @@ inline void PrintIRStructure(const IRStructure *node, int indent = 0) { const ScheduleUnit *promote = static_cast(node); LOG(INFO) << indent_str << "ScheduleUnit:"; LOG(INFO) << indent_str << " Promote: " << promote->stage; - for (unsigned idx = 0; idx != promote->before.size(); ++idx) { - for (auto &stmt : promote->before[idx]) { + for (const auto &[idx, stmts] : promote->before) { + for (const auto &stmt : stmts) { LOG(INFO) << indent_str << " Before " << idx << " : " << stmt; } } - for (unsigned idx = 0; idx != promote->after.size(); ++idx) { - for (auto &stmt : promote->after[idx]) { + for (const auto &[idx, stmts] : promote->after) { + for (const auto &stmt : stmts) { LOG(INFO) << indent_str << " After " << idx << " : " << stmt; } } diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 649ac132b7..ea76e34982 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -325,11 +325,11 @@ RemoveUnusedLetDecls(std::shared_ptr root) { auto unit = static_cast(node); collect(unit->child.get()); VarRefCollector collector; - for (const auto &stmts : unit->before) { + for (const auto &[_, stmts] : unit->before) { for (const auto &s : stmts) collector(s); } - for (const auto &stmts : unit->after) { + for (const auto &[_, stmts] : unit->after) { for (const auto &s : stmts) collector(s); } @@ -462,14 +462,14 @@ Stmt ConvertIRStructureToStmt(IRStructure *root, const bool outer_enable_epi) { std::vector stmts; for (const auto &child : seq->children) { auto unit = static_cast(child.get()); - for (auto &before : unit->before) { + for (auto &[_, before] : unit->before) { for (auto &stmt : before) { stmts.push_back(stmt); } } Stmt child_stmt = irstructure_to_stmt(unit->child.get()); stmts.push_back(child_stmt); - for (auto &after : unit->after) { + for (auto &[_, after] : unit->after) { for (auto &stmt : after) { stmts.push_back(stmt); } @@ -499,13 +499,13 @@ Stmt ConvertIRStructureToStmt(IRStructure *root, const bool outer_enable_epi) { std::vector stmts; if (ctrl->child->IsScheduleUnit()) { auto unit = static_cast(ctrl->child.get()); - for (auto &before : unit->before) { + for (auto &[_, before] : unit->before) { for (auto &stmt : before) { stmts.push_back(stmt); } } stmts.push_back(irstructure_to_stmt(unit->child.get())); - for (auto &after : unit->after) { + for (auto &[_, after] : unit->after) { for (auto &stmt : after) { stmts.push_back(stmt); } @@ -515,13 +515,13 @@ Stmt ConvertIRStructureToStmt(IRStructure *root, const bool outer_enable_epi) { for (auto &child : seq->children) { ICHECK(child->IsScheduleUnit()); auto unit = static_cast(child.get()); - for (auto &before : unit->before) { + for (auto &[_, before] : unit->before) { for (auto &stmt : before) { stmts.push_back(stmt); } } stmts.push_back(irstructure_to_stmt(unit->child.get())); - for (auto &after : unit->after) { + for (auto &[_, after] : unit->after) { for (auto &stmt : after) { stmts.push_back(stmt); } @@ -544,14 +544,14 @@ Stmt ConvertIRStructureToStmt(IRStructure *root, const bool outer_enable_epi) { for (auto &child : seq->children) { auto unit = static_cast(child.get()); std::vector stmts; - for (auto &before : unit->before) { - for (auto &stmt : before) { + for (const auto &[_, before] : unit->before) { + for (const auto &stmt : before) { stmts.push_back(stmt); } } stmts.push_back(irstructure_to_stmt(unit->child.get())); - for (auto &after : unit->after) { - for (auto &stmt : after) { + for (const auto &[_, after] : unit->after) { + for (const auto &stmt : after) { stmts.push_back(stmt); } } @@ -610,14 +610,14 @@ Stmt ConvertIRStructureToStmt(IRStructure *root, const bool outer_enable_epi) { for (auto &child : seq->children) { auto unit = static_cast(child.get()); std::vector stmts; - for (auto &before : unit->before) { - for (auto &stmt : before) { + for (const auto &[_, before] : unit->before) { + for (const auto &stmt : before) { stmts.push_back(stmt); } } stmts.push_back(irstructure_to_stmt(unit->child.get())); - for (auto &after : unit->after) { - for (auto &stmt : after) { + for (const auto &[_, after] : unit->after) { + for (const auto &stmt : after) { stmts.push_back(stmt); } } From 1cb8fdeaa4b92cb80db19c258dc03422ed59d4a3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 8 Apr 2026 17:19:43 +0800 Subject: [PATCH 024/156] Bump transformers from 4.53.0 to 5.0.0rc3 in /examples/bitnet-1.58b (#2021) Bumps [transformers](https://github.com/huggingface/transformers) from 4.53.0 to 5.0.0rc3. - [Release notes](https://github.com/huggingface/transformers/releases) - [Commits](https://github.com/huggingface/transformers/compare/v4.53.0...v5.0.0rc3) --- updated-dependencies: - dependency-name: transformers dependency-version: 5.0.0rc3 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- examples/bitnet-1.58b/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/bitnet-1.58b/requirements.txt b/examples/bitnet-1.58b/requirements.txt index 67357781e0..7660c28c6d 100644 --- a/examples/bitnet-1.58b/requirements.txt +++ b/examples/bitnet-1.58b/requirements.txt @@ -1,3 +1,3 @@ lm_eval==0.3.0 flash_attn -transformers==4.53.0 +transformers==5.0.0rc3 From 469a8479f6eb304f664ffdbe01e90796a279c1b7 Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Wed, 8 Apr 2026 17:21:32 +0800 Subject: [PATCH 025/156] pin apache-tvm-ffi<0.1.10 (derived_object regression) (#2020) --- pyproject.toml | 2 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 80aba41e32..601f2e35fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ # >=0.1.6 fixes a memory issue: tilelang#1502, but keep # requirement as wide as possible to be compatible with other libraries # pip will try to use latest version whenever possible. - "apache-tvm-ffi~=0.1.0,>=0.1.2", + "apache-tvm-ffi~=0.1.0,>=0.1.2,<0.1.10", # torch-c-dlpack-ext provides prebuilt torch extensions. # Without it, TVM FFI may require JIT compilation on first import. "torch-c-dlpack-ext; python_version < '3.14'", diff --git a/requirements-dev.txt b/requirements-dev.txt index f8dccdc871..a74959409a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ # Requirements to run local build with `--no-build-isolation` or other developments -apache-tvm-ffi~=0.1.0,>=0.1.2 +apache-tvm-ffi~=0.1.0,>=0.1.2,<0.1.10 build cmake>=3.26 cython>=3.1.0 diff --git a/requirements.txt b/requirements.txt index 2dbe070d9a..37023a758f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Runtime requirements -apache-tvm-ffi~=0.1.0,>=0.1.2 +apache-tvm-ffi~=0.1.0,>=0.1.2,<0.1.10 torch-c-dlpack-ext; python_version < '3.14' cloudpickle ml-dtypes From 86e37b7cd96d73604b005e55942c4158ce55c0d0 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 8 Apr 2026 17:27:19 +0800 Subject: [PATCH 026/156] Fix serial loop phase dtype mismatch in LowerTileOp (#2022) Fix int64 loop phase dtype handling in LowerTileOp --- src/transform/lower_tile_op.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index f955babfde..7267f17ae7 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -1130,11 +1130,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { } } PrimExpr phase_expr; + DataType loop_dtype = op->loop_var.dtype(); + PrimExpr two = make_const(loop_dtype, 2); if (num_stages > 1) { - phase_expr = FloorMod(FloorDiv(op->loop_var, num_stages), - IntImm(DataType::Int(32), 2)); + PrimExpr num_stages_expr = make_const(loop_dtype, num_stages); + phase_expr = FloorMod(FloorDiv(op->loop_var, num_stages_expr), two); } else { - phase_expr = FloorMod(op->loop_var, IntImm(DataType::Int(32), 2)); + phase_expr = FloorMod(op->loop_var, two); } loop_mbar_phase_stack_.push_back(analyzer_->Simplify(phase_expr)); pushed_loop_mbar_phase = true; From 9a15696c6a28cc0bf7cca1b6deacafa7c990628b Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Wed, 8 Apr 2026 17:39:48 +0800 Subject: [PATCH 027/156] Fix typo Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/transform/auto_schedule/schedule_builder.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index 28fc4d9763..433be3faad 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -611,7 +611,7 @@ class ScheduleUnitBuilder { void SetThreadVar(IterVar thread_var) { thread_var_ = thread_var; } // Set enable_warp_partition flag - void SetWarpSpeicializeConfig(const WarpSpecializeConfig &config) { + void SetWarpSpecializeConfig(const WarpSpecializeConfig &config) { config_ = config; } From d37fd4292a74b6a0dde56c54c636c05787e16818 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 9 Apr 2026 11:56:24 +0800 Subject: [PATCH 028/156] implement a naive ws on F3 --- .../auto_schedule/schedule_builder.cc | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 9b4178d0b0..8b50699b6d 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -324,6 +324,28 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, [](const ComponentInfo &a, const ComponentInfo &b) { return a.weighted_latency > b.weighted_latency; }); + + if (config.enable_warp_partition) { + for (const auto &comp : component_infos) { + int assigned_warpgroup = 0; + if (comp.uses_tensor_core_ && !comp.uses_tma_core_) { + assigned_warpgroup = 0; + } else if (!comp.uses_tensor_core_ && comp.uses_tma_core_) { + assigned_warpgroup = 1; + } else { + assigned_warpgroup = 3; + } + for (int idx : comp.task_indices) { + TaskNode *task = all_tasks[idx].task; + if (!task->ContainsLoopBreak()) { + task->SetWarpgroupId(assigned_warpgroup); + } + } + } + return {IntImm(DataType::Int(32), 32), IntImm(DataType::Int(32), 32), + IntImm(DataType::Int(32), 64), + thread_count - IntImm(DataType::Int(32), 128)}; + } int64_t warpgroup0_latency = 0; int64_t warpgroup1_latency = 0; @@ -362,12 +384,7 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, } } } - if (config.enable_thread_extend) { - return {thread_count, thread_count}; - } else { - return {IntImm(DataType::Int(32), 32), IntImm(DataType::Int(32), 32), - thread_count - IntImm(DataType::Int(32), 64)}; - } + return {thread_count, thread_count}; } else { int64_t warpgroup0_latency = 0; int64_t warpgroup1_latency = 0; @@ -394,13 +411,8 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, } } } - if (config.enable_thread_extend) { - return {thread_count, - IntImm(DataType::Int(32), config.producer_thread_count)}; - } else { - return {IntImm(DataType::Int(32), 32), IntImm(DataType::Int(32), 32), - thread_count - IntImm(DataType::Int(32), 64)}; - } + return {thread_count, + IntImm(DataType::Int(32), config.producer_thread_count)}; } } From 5c7352e02ca4c00d9cf1a932af7c72b840c46a0f Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 9 Apr 2026 13:11:26 +0800 Subject: [PATCH 029/156] fix typo and run format --- src/transform/auto_schedule.cc | 2 +- src/transform/auto_schedule/schedule_builder.cc | 2 +- src/transform/auto_schedule/schedule_builder.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index d16a0154e9..c5539dc256 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -617,7 +617,7 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { LOG(FATAL) << "Could not find thread index variable, warpgroup " "partition will use default"; } - unit_builder.SetWarpSpeicializeConfig(config); + unit_builder.SetWarpSpecializeConfig(config); unit_builder.SetSharedMemoryLimit(GetSharedMemoryLimit(target)); std::vector thread_count = unit_builder.Build(ir_structure); diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 64b1eda5e1..b08b2ca2fe 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -437,7 +437,7 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, [](const ComponentInfo &a, const ComponentInfo &b) { return a.weighted_latency > b.weighted_latency; }); - + if (config.enable_warp_partition) { for (const auto &comp : component_infos) { int assigned_warpgroup = 0; diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index 924d91a0bd..4d9bebeb0b 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -626,7 +626,7 @@ class ScheduleUnitBuilder { // Set thread index variable for warpgroup partition void SetThreadVar(IterVar thread_var) { thread_var_ = thread_var; } - // Set enable_warp_partition flag + // Set warp specialization configuration void SetWarpSpecializeConfig(const WarpSpecializeConfig &config) { config_ = config; } From 3b47b8714551f2d31d3650c44e03ec7d4630a6fe Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 9 Apr 2026 18:16:25 +0800 Subject: [PATCH 030/156] modify return value of NaiveBuild --- src/transform/auto_schedule/schedule_builder.cc | 16 ++++++++++++---- src/transform/auto_schedule/schedule_builder.h | 5 +++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 00c187ad1d..1031dc5825 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -684,7 +684,8 @@ void ScheduleUnitBuilder::ScheduleRecursive( // --- Naive scheduling implementation --- -bool NaiveAssignWarpgroupIds(IRStructure *root) { +std::vector NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, + PrimExpr thread_count) { if (!root) LOG(FATAL) << "Empty root"; @@ -730,7 +731,14 @@ bool NaiveAssignWarpgroupIds(IRStructure *root) { task->SetWarpgroupId(-1); } - return false; // no double_thread in naive mode + // no double_thread in naive mode + if (config.enable_thread_extend) { + return {thread_count, + IntImm(DataType::Int(32), config.producer_thread_count)}; + } else { + return {IntImm(DataType::Int(32), 32), IntImm(DataType::Int(32), 32), + thread_count - IntImm(DataType::Int(32), 64)}; + } } void ScheduleUnitBuilder::NaiveScheduleLoop(ControlNode *ctrl) { @@ -867,9 +875,9 @@ void ScheduleUnitBuilder::NaiveScheduleRecursive( } } -bool ScheduleUnitBuilder::NaiveBuild(std::shared_ptr &root) { +std::vector ScheduleUnitBuilder::NaiveBuild(std::shared_ptr &root) { NaiveScheduleRecursive(root); - return NaiveAssignWarpgroupIds(root.get()); + return NaiveAssignWarpgroupIds(root.get(), config_, thread_var_->dom->extent); } } // namespace tl diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index 0b6a1dd277..9ca566b9d0 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -53,7 +53,8 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, PrimExpr thread_count); // Naive warpgroup assignment: TMA→wg1, compute→wg0, neutral→-1 -bool NaiveAssignWarpgroupIds(IRStructure *root); +std::vector NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, + PrimExpr thread_count); // Extract all sequential task nodes from the IR structure tree void GatherTaskNodes(const std::vector> &nodes, @@ -105,7 +106,7 @@ class ScheduleUnitBuilder { // Naive build: preserve original order, assign pipeline stages based on // num_stages annotation, assign warpgroup IDs by resource type // (TMA→wg1, compute→wg0). No Z3 scheduling. - bool NaiveBuild(std::shared_ptr &root); + std::vector NaiveBuild(std::shared_ptr &root); // New recursive scheduling function that replaces Collect method // Directly schedules the entire IRStructure tree recursively in place From 101a672a963d6648e5efbf1e642e7cae5f11d5e2 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 10 Apr 2026 10:42:30 +0800 Subject: [PATCH 031/156] fix barrier logic --- src/transform/auto_schedule/barrier.h | 732 ++++++++---------- src/transform/auto_schedule/ir_structure.h | 3 + .../auto_schedule/schedule_builder.h | 22 +- 3 files changed, 366 insertions(+), 391 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index c6e7eafd7c..2236ac94e6 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -142,14 +142,6 @@ struct LoopNestingInfo { } return total_iter; } - - // Calculate parity expression considering all nested loops - PrimExpr CalculateParityExpr(PrimExpr iter_offset, int num_versions) const { - PrimExpr total_iter = indexdiv(CalculateIterationCount(), num_versions); - - // Add iteration offset and calculate parity - return indexmod(total_iter + iter_offset, 2); - } }; // Structure to store multi-version buffer information @@ -585,20 +577,6 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, if (!seq) return; - // Map from (buffer, warpgroup_id) to task of last access - std::unordered_map, ObjectPtrHash, - ObjectPtrEqual> - last_access_map[2]; - std::unordered_map>, - ObjectPtrHash, ObjectPtrEqual> - last_write_map; - std::unordered_map, ObjectPtrHash, - ObjectPtrEqual> - last_wgmma_map[2]; - std::unordered_map barrier_unit_map; - int wait_wgmma_id[2] = {}, total_wgmma[2] = {}; - - // Process tasks in sequence order for (auto &promote_child : seq->children) { auto task = static_cast(promote_child.get()); if (task->child->IsSequence() || task->child->IsControl()) { @@ -607,36 +585,99 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, task->child.get(), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); } + } - // Allocate barrier for TCGEN05MMA and rewrite gemm mbar argument + size_t num_wgs = thread_count.size(); + + // Insert wait_wgmma + std::vector> last_wgmma_map(num_wgs); + std::vector wait_wgmma_id(num_wgs, 0); + std::vector total_wgmma(num_wgs, 0); + + for (auto &promote_child : seq->children) { + auto task = static_cast(promote_child.get()); + if (task->isInnerTask() && task->UsesTensorCore()) { + auto child = static_cast(task->child.get()); + if (child->is_WGMMA()) { + bool found_wgmma = false; + for (const auto ®ion_access : task->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) + continue; + auto ®ion = region_access.region; + if (IsRegisterRegion(region)) { + Buffer buffer = region->buffer; + if (!found_wgmma) { + found_wgmma = true; + ++total_wgmma[wg_id]; + } + last_wgmma_map[wg_id][buffer] = total_wgmma[wg_id]; + } + } + } + } else { + for (const auto ®ion_access : task->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) + continue; + auto ®ion = region_access.region; + if (IsRegisterRegion(region)) { + Buffer buffer = region->buffer; + auto it = last_wgmma_map[wg_id].find(buffer); + if (it == last_wgmma_map[wg_id].end()) + continue; + if (it->second <= wait_wgmma_id[wg_id]) + continue; + wait_wgmma_id[wg_id] = it->second; + Stmt wait_stmt = + Evaluate(Call(DataType::Handle(), wait_wgmma(), + {total_wgmma[wg_id] - it->second})); + InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); + } + } + } + } + + std::map barrier_unit_map; + + // Allocate barriers for TCGEN05MMA + for (auto &promote_child : seq->children) { + auto task = static_cast(promote_child.get()); if (task->isInnerTask() && task->UsesTensorCore()) { auto child = static_cast(task->child.get()); if (child->is_TCGEN05()) { - int wg_id = child->GetWarpgroupId(); int barrier_id = next_barrier_id++; + // Create a single barrier buffer with shape (1,) Buffer barrier_buffer = makeBarrierBuffer( - 1, "tcgen05_barrier_" + std::to_string(barrier_id), 1, - barrier_buffers, barrier_map); + 1, "tcgen05_barrier_" + std::to_string(barrier_id), + 1, barrier_buffers, barrier_map); barrier_unit_map[task] = barrier_buffer; - PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); - RewriteGemmMbar(child, barrier_load); + // Rewrite the gemm call's mbar argument (arg[16]) to use + // BufferLoad(barrier_buffer, {0}) + PrimExpr mbar_expr = BufferLoad(barrier_buffer, {0}); + RewriteGemmMbar(child, mbar_expr); } } + } - // Allocate barrier for TMA + // Allocate barriers for TMA + for (auto &promote_child : seq->children) { + auto task = static_cast(promote_child.get()); if (task->isInnerTask() && task->UsesTMACore()) { auto child = static_cast(task->child.get()); if (child->HasTMALoad()) { int wg_id = child->GetWarpgroupId(); if (wg_id != -1) { int barrier_id = next_barrier_id++; - Buffer barrier_buffer = makeBarrierBuffer( - thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), - 1, barrier_buffers, barrier_map); + Buffer barrier_buffer = + makeBarrierBuffer(thread_count[wg_id], + "tma_barrier_" + std::to_string(barrier_id), + 1, barrier_buffers, barrier_map); barrier_unit_map[task] = barrier_buffer; - PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); + PrimExpr barrier_load = + BufferLoad(barrier_buffer, {0}); RewriteCopyMbar(child, barrier_load); Stmt arrive_stmt = makeBarrierArrive(barrier_load); InsertStatementIntoScheduleUnit(task, arrive_stmt, false, wg_id); @@ -646,139 +687,89 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, } } } + } - // Check regions for dependencies - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - auto ®ion = region_access.region; - if (IsRegisterRegion(region)) { - Buffer buffer = region->buffer; - auto it = last_wgmma_map[wg_id].find(buffer); - if (it == last_wgmma_map[wg_id].end()) + // Insert barriers for other dependencies + // First collect shared buffers + std::set shared_buffers; + for (const auto ®ion_access : seq->GetReadWriteRegions()) { + auto &buffer = region_access.region->buffer; + shared_buffers.emplace(buffer); + } + auto is_async_task = [](ScheduleUnit *task) { + return task->UsesTensorCore() || task->UsesTMACore(); + }; + for (const auto &buffer : shared_buffers) { + std::vector last_access_task(num_wgs, nullptr); + std::vector last_access(num_wgs, false); + ScheduleUnit *last_write_task = nullptr; + uint64_t waited_write_wgs = 0; + int last_write_wg_id = -1; + bool last_write = false; + // Process tasks in sequence order + for (auto &promote_child : seq->children) { + auto task = static_cast(promote_child.get()); + bool is_async = is_async_task(task); + for (const auto ®ion_access : task->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) continue; - if (it->second.second <= wait_wgmma_id[wg_id]) + if (region_access.region->buffer != buffer) continue; - wait_wgmma_id[wg_id] = it->second.second; - Stmt wait_stmt = - Evaluate(Call(DataType::Handle(), wait_wgmma(), - {total_wgmma[wg_id] - wait_wgmma_id[wg_id]})); - InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); - } else { - Buffer buffer = region->buffer; - bool need_barrier = false; - ScheduleUnit *last_access_task = nullptr; - int last_wg_id = -1; - bool is_async = task->UsesTensorCore() || task->UsesTMACore(); - if (!region_access.is_write) { - auto it = last_write_map.find(buffer); - if (it != last_write_map.end()) { - last_access_task = it->second.first; - last_wg_id = it->second.second.second; - if (last_wg_id == -1) - continue; - if (it->second.second.first & (1 << wg_id)) - continue; - bool last_async = last_access_task->UsesTensorCore() || - last_access_task->UsesTMACore(); - if (last_wg_id != wg_id || is_async || last_async) { - need_barrier = true; - } - } - } else { - auto it = last_access_map[!wg_id].find(buffer); - if (it != last_access_map[!wg_id].end()) { - last_access_task = it->second.first; - last_wg_id = it->second.second; - if (last_wg_id == -1) - continue; - if (last_wg_id != wg_id) { - need_barrier = true; - } - } - } - if (last_access_task == task) - continue; - // If warpgroup ids differ, insert barrier - if (need_barrier) { - if (barrier_unit_map.find(last_access_task) == - barrier_unit_map.end()) { - // Allocate a new barrier buffer (single stage for sequence) + auto insert_barrier = [&](ScheduleUnit *last_task, int last_wg_id) { + if (last_wg_id == -1) + return; + bool last_async = is_async_task(last_task); + if (last_wg_id == wg_id && !is_async && !last_async) + return; + if (barrier_unit_map.find(last_task) == barrier_unit_map.end()) { + // Allocate a new barrier buffer int barrier_id = next_barrier_id++; - Buffer barrier_buffer = - makeBarrierBuffer(thread_count[last_wg_id], - "barrier_" + std::to_string(barrier_id), 1, - barrier_buffers, barrier_map); - barrier_unit_map[last_access_task] = barrier_buffer; - // Create BufferLoad expression for barrier[0] - PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); - // Insert barrier_arrive at the end of last_access_task's - // statements + Buffer barrier_buffer = makeBarrierBuffer( + thread_count[last_wg_id], + "barrier_" + std::to_string(barrier_id), 1, + barrier_buffers, barrier_map); + barrier_unit_map[last_task] = barrier_buffer; + PrimExpr barrier_load = + BufferLoad(barrier_buffer, {0}); + // Insert barrier_arrive at the end of last_task's statements Stmt arrive_stmt = makeBarrierArrive(barrier_load); - InsertStatementIntoScheduleUnit(last_access_task, arrive_stmt, - false, last_wg_id); + InsertStatementIntoScheduleUnit(last_task, arrive_stmt, false, last_wg_id); } - PrimExpr barrier_load = - BufferLoad(barrier_unit_map[last_access_task], {0}); - - // Insert barrier_wait at the beginning of task's statements - Stmt wait_stmt = - makeBarrierWait(barrier_load, - 0); // parity = 0 for non-loop barriers + auto barrier_buffer = barrier_unit_map[last_task]; + PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); + Stmt wait_stmt = makeBarrierWait(barrier_load, 0); InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); - // Remove from map (as per user instruction) - if (!region_access.is_write) { - auto it = last_write_map.find(buffer); - it->second.second.first |= (1 << wg_id); - if (it->second.second.first == 3) { - last_write_map.erase(last_write_map.find(buffer)); - } - } else { - for (unsigned idx = 0; idx < 2; ++idx) { - auto it = last_access_map[idx].find(buffer); - if (it != last_access_map[idx].end()) { - last_access_map[idx].erase(it); - } - } - auto it = last_write_map.find(buffer); - if (it != last_write_map.end()) { - last_write_map.erase(it); - } + }; + + if (!region_access.is_write) { + if (last_write_task == nullptr) + continue; + if (waited_write_wgs >> wg_id & 1) + continue; + insert_barrier(last_write_task, last_write_wg_id); + waited_write_wgs |= (1 << wg_id); + } else { + for (int last_wg_id = 0; last_wg_id < num_wgs; ++last_wg_id) { + if (last_access_task[last_wg_id] == nullptr) + continue; + insert_barrier(last_access_task[last_wg_id], last_wg_id); + last_access_task[last_wg_id] = nullptr; } } } - } - - // Update regions - bool found_wgmma = false; - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - auto ®ion = region_access.region; - if (IsRegisterRegion(region)) { - if (!task->UsesTensorCore() || !region_access.is_write) + for (const auto ®ion_access : task->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) continue; - if (!task->isInnerTask()) + if (region_access.region->buffer != buffer) continue; - auto child = static_cast(task->child.get()); - if (child->is_WGMMA()) { - Buffer buffer = region->buffer; - if (!found_wgmma) { - found_wgmma = true; - ++total_wgmma[wg_id]; - } - last_wgmma_map[wg_id][buffer] = - std::make_pair(task, total_wgmma[wg_id]); - } - } else { - Buffer buffer = region->buffer; - last_access_map[wg_id][buffer] = std::make_pair(task, wg_id); + last_access_task[wg_id] = task; if (region_access.is_write) { - last_write_map[buffer] = - std::make_pair(task, std::make_pair(0, wg_id)); + last_write_task = task; + last_write_wg_id = wg_id; + waited_write_wgs = 0; } } } @@ -839,27 +830,12 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, ordered_tasks.begin(), ordered_tasks.end(), [](ScheduleUnit *a, ScheduleUnit *b) { return a->stage > b->stage; }); - // Map from (buffer, warpgroup_id) to task - std::unordered_map, ObjectPtrHash, - ObjectPtrEqual> - last_access_map[2]; - std::unordered_set - last_access_set[2]; - std::unordered_map>, - ObjectPtrHash, ObjectPtrEqual> - last_write_map; - std::unordered_set last_write_set; - std::unordered_map, ObjectPtrHash, - ObjectPtrEqual> - last_wgmma_map[2]; - std::map, Buffer> barrier_unit_map; - int wait_wgmma_id[2] = {}, total_wgmma[2] = {}; + // Rewrite multi-buffer auto num_stages = 1; auto num_stages_val = ctrl->control.get()->annotations.Get("num_stages"); if (num_stages_val.has_value()) { num_stages = num_stages_val.value().cast()->value; } - std::unordered_map multi_buffer; std::unordered_map @@ -868,7 +844,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, for (const auto &task : ordered_tasks) { for (const auto ®ion_access : task->GetReadWriteRegions()) { auto &buffer = region_access.region->buffer; - if (!IsSharedBuffer(buffer)) + if (!ctrl->multi_buffering_buffers.count(buffer)) continue; for (const auto &other_task : ordered_tasks) { if (task == other_task) @@ -894,7 +870,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, } for (auto ®ion : ctrl->GetWriteRegions()) { auto &buffer = region.get()->buffer; - if (!IsSharedBuffer(buffer)) + if (!ctrl->multi_buffering_buffers.count(buffer)) continue; if (multi_buffer.find(buffer) != multi_buffer.end()) continue; @@ -917,261 +893,237 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, RewriteTaskNodeBuffers(ctrl, multi_buffer, iteration); } - // Process tasks in the specified order + size_t num_wgs = thread_count.size(); + + // Insert wait_wgmma + std::vector> last_wgmma_map(num_wgs); + std::vector wait_wgmma_id(num_wgs, 0); + std::vector total_wgmma(num_wgs, 0); for (unsigned iter = 0; iter != 2; ++iter) { for (ScheduleUnit *task : ordered_tasks) { - int stage = task->GetStage(); - bool is_async = task->UsesTensorCore() || task->UsesTMACore(); - - if (iter == 0 && task->isInnerTask() && task->UsesTensorCore()) { + if (task->isInnerTask() && task->UsesTensorCore()) { + if (iter == 1) { + continue; + } auto child = static_cast(task->child.get()); - if (child->is_TCGEN05()) { - auto &&write_regions = child->GetWriteRegions(); - if (write_regions.size() == 1) { - Buffer buffer = write_regions[0]->buffer; - auto it = buffer_num_versions.find(buffer); - int num_versions = - it != buffer_num_versions.end() ? it->second : 1; - int wg_id = child->GetWarpgroupId(); - int barrier_id = next_barrier_id++; - // Create a single barrier buffer with shape (num_versions,) - Buffer barrier_buffer = makeBarrierBuffer( - 1, "tcgen05_barrier_" + std::to_string(barrier_id), - num_versions, barrier_buffers, barrier_map); - barrier_unit_map[std::make_pair(task, buffer)] = barrier_buffer; - - // Rewrite the gemm call's mbar argument (arg[16]) to use - // BufferLoad(barrier_buffer, {version_index}) - PrimExpr version_index = - indexmod(loop_info.CalculateIterationCount(), num_versions); - PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); - RewriteGemmMbar(child, mbar_expr); - } else { - LOG(FATAL) - << "TCGEN05MMA tasks must have exactly one write region"; + if (child->is_WGMMA()) { + bool found_wgmma = false; + for (const auto ®ion_access : task->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) + continue; + auto ®ion = region_access.region; + if (IsRegisterRegion(region)) { + Buffer buffer = region->buffer; + if (!found_wgmma) { + found_wgmma = true; + ++total_wgmma[wg_id]; + } + last_wgmma_map[wg_id][buffer] = total_wgmma[wg_id]; + } + } + } + } else { + for (const auto ®ion_access : task->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) + continue; + auto ®ion = region_access.region; + if (IsRegisterRegion(region)) { + Buffer buffer = region->buffer; + auto it = last_wgmma_map[wg_id].find(buffer); + if (it == last_wgmma_map[wg_id].end()) + continue; + if (it->second <= wait_wgmma_id[wg_id]) + continue; + wait_wgmma_id[wg_id] = it->second; + Stmt wait_stmt = + Evaluate(Call(DataType::Handle(), wait_wgmma(), + {total_wgmma[wg_id] - it->second})); + InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); } } } + } + } - // Allocate barrier for TMA - if (iter == 0 && task->isInnerTask() && task->UsesTMACore()) { - auto child = static_cast(task->child.get()); - if (child->HasTMALoad()) { - auto &&write_regions = child->GetWriteRegions(); - if (write_regions.size() == 1) { - Buffer buffer = write_regions[0]->buffer; - auto it = buffer_num_versions.find(buffer); - int num_versions = - it != buffer_num_versions.end() ? it->second : 1; - int wg_id = child->GetWarpgroupId(); - ICHECK(wg_id != -1) << "TMA loads must have valid warpgroup id"; - int barrier_id = next_barrier_id++; - Buffer barrier_buffer = - makeBarrierBuffer(thread_count[wg_id], - "tma_barrier_" + std::to_string(barrier_id), - num_versions, barrier_buffers, barrier_map); - barrier_unit_map[std::make_pair(task, buffer)] = barrier_buffer; - - PrimExpr version_index = - indexmod(loop_info.CalculateIterationCount(), num_versions); - PrimExpr barrier_load = - BufferLoad(barrier_buffer, {version_index}); - RewriteCopyMbar(child, barrier_load); - Stmt arrive_stmt = makeBarrierArrive(barrier_load); - InsertStatementIntoScheduleUnit(task, arrive_stmt, false, wg_id); - } else { - LOG(FATAL) << "TMA loads must have exactly one write region"; + std::map> barrier_unit_map; + + // Allocate barriers for TCGEN05MMA + for (ScheduleUnit *task : ordered_tasks) { + if (task->isInnerTask() && task->UsesTensorCore()) { + auto child = static_cast(task->child.get()); + if (child->is_TCGEN05()) { + int num_versions = 1; + for (const auto ®ion_access : child->GetReadWriteRegions()) { + auto &buffer = region_access.region->buffer; + auto it = buffer_num_versions.find(buffer); + if (it != buffer_num_versions.end()) { + num_versions = std::max(num_versions, it->second); } } + + int barrier_id = next_barrier_id++; + // Create a single barrier buffer with shape (num_versions,) + Buffer barrier_buffer = makeBarrierBuffer( + 1, "tcgen05_barrier_" + std::to_string(barrier_id), + num_versions, barrier_buffers, barrier_map); + barrier_unit_map[task] = std::make_pair(barrier_buffer, num_versions); + + // Rewrite the gemm call's mbar argument (arg[16]) to use + // BufferLoad(barrier_buffer, {version_index}) + PrimExpr version_index = + indexmod(loop_info.CalculateIterationCount(), num_versions); + PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); + RewriteGemmMbar(child, mbar_expr); } + } + } - // Check regions for dependencies - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - auto ®ion = region_access.region; - if (IsRegisterRegion(region)) { - if (task->UsesTensorCore()) - continue; - Buffer buffer = region->buffer; - auto it = last_wgmma_map[wg_id].find(buffer); - if (it == last_wgmma_map[wg_id].end()) - continue; - if (it->second.second <= wait_wgmma_id[wg_id]) - continue; - wait_wgmma_id[wg_id] = it->second.second; - Stmt wait_stmt = - Evaluate(Call(DataType::Handle(), wait_wgmma(), - {total_wgmma[wg_id] - wait_wgmma_id[wg_id]})); - InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); - } else { - Buffer buffer = region->buffer; + // Allocate barriers for TMA + for (ScheduleUnit *task : ordered_tasks) { + if (task->isInnerTask() && task->UsesTMACore()) { + auto child = static_cast(task->child.get()); + if (child->HasTMALoad()) { + int num_versions = 1; + for (const auto ®ion_access : child->GetReadWriteRegions()) { + auto &buffer = region_access.region->buffer; auto it = buffer_num_versions.find(buffer); - int num_versions = it != buffer_num_versions.end() ? it->second : 1; - bool need_barrier = false; - ScheduleUnit *last_access_task; - int last_wg_id; - int last_stage; - if (!region_access.is_write) { - if (iter == 1) { - if (last_write_set.find(buffer) != last_write_set.end()) { - continue; - } - last_write_set.insert(buffer); - } - auto it = last_write_map.find(buffer); - if (it != last_write_map.end()) { - last_access_task = it->second.first; - last_wg_id = it->second.second.second; - last_stage = last_access_task->GetStage(); - if (last_wg_id == -1) - continue; // Allow barriers involving neutral tasks - if (it->second.second.first & (1 << wg_id)) - continue; - - bool last_async = last_access_task->UsesTensorCore() || - last_access_task->UsesTMACore(); - // If warpgroup ids differ or promotion status differs, - // insert barrier - if (last_wg_id != wg_id || last_stage != stage || is_async || - last_async) { - need_barrier = true; - } - } - } else { - if (iter == 1) { - if (last_access_set[!wg_id].find(buffer) != - last_access_set[!wg_id].end()) { - continue; - } - last_access_set[!wg_id].insert(buffer); - } - auto it = last_access_map[!wg_id].find(buffer); - if (it != last_access_map[!wg_id].end()) { - last_access_task = it->second.first; - last_wg_id = it->second.second; - last_stage = last_access_task->GetStage(); - if (last_wg_id == -1) - continue; // Allow barriers involving neutral tasks - - // If warpgroup ids differ or promotion status differs, - // insert barrier - if (last_wg_id != wg_id || last_stage != stage) { - need_barrier = true; - } - } + if (it != buffer_num_versions.end()) { + num_versions = std::max(num_versions, it->second); } - if (last_access_task == task) + } + int wg_id = child->GetWarpgroupId(); + ICHECK(wg_id != -1) << "TMA loads must have valid warpgroup id"; + + int barrier_id = next_barrier_id++; + Buffer barrier_buffer = + makeBarrierBuffer(thread_count[wg_id], + "tma_barrier_" + std::to_string(barrier_id), + num_versions, barrier_buffers, barrier_map); + barrier_unit_map[task] = std::make_pair(barrier_buffer, num_versions); + + PrimExpr version_index = + indexmod(loop_info.CalculateIterationCount(), num_versions); + PrimExpr barrier_load = + BufferLoad(barrier_buffer, {version_index}); + RewriteCopyMbar(child, barrier_load); + Stmt arrive_stmt = makeBarrierArrive(barrier_load); + InsertStatementIntoScheduleUnit(task, arrive_stmt, false, wg_id); + } + } + } + + // Insert barriers for other dependencies + // First collect shared buffers + std::set, std::greater>> shared_buffers; + for (const auto ®ion_access : ctrl->GetReadWriteRegions()) { + auto &buffer = region_access.region->buffer; + if (IsSharedBuffer(buffer)) { + auto it = buffer_num_versions.find(buffer); + int num_versions = it != buffer_num_versions.end() ? it->second : 1; + shared_buffers.emplace(num_versions, buffer); + } + } + // Process buffers in order of decreasing number of versions to ensure correct barrier size + auto is_async_task = [](ScheduleUnit *task) { + return task->UsesTensorCore() || task->UsesTMACore(); + }; + for (const auto &[num_versions, buffer] : shared_buffers) { + std::vector last_access_task(num_wgs, nullptr); + std::vector last_access(num_wgs, false); + ScheduleUnit *last_write_task = nullptr; + uint64_t waited_write_wgs = 0; + int last_write_wg_id = -1; + bool last_write = false; + // Process tasks in the specified order + for (unsigned iter = 0; iter != 2; ++iter) { + for (ScheduleUnit *task : ordered_tasks) { + int stage = task->GetStage(); + bool is_async = is_async_task(task); + for (const auto ®ion_access : task->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) continue; - // If warpgroup ids differ or promotion status differs, insert - // barrier - if (need_barrier) { - // Calculate parity for barrier wait considering all nested - // loops Use loop_info to calculate parity expression: - // outer_var - // * inner_constant + inner_var - PrimExpr iter_offset = IntImm(DataType::Int(32), iter); - PrimExpr parity_expr = - loop_info.CalculateParityExpr(iter_offset, num_versions); - - if (barrier_unit_map.find(std::make_pair( - last_access_task, buffer)) == barrier_unit_map.end()) { - // Allocate a single barrier buffer with shape (num_versions,) + if (region_access.region->buffer != buffer) + continue; + + auto insert_barrier = [&](ScheduleUnit *last_task, int last_wg_id) { + if (last_wg_id == -1) + return; + int last_stage = last_task->GetStage(); + bool last_async = is_async_task(last_task); + if (last_wg_id == wg_id && last_stage == stage && !is_async && !last_async) + return; + if (barrier_unit_map.find(last_task) == barrier_unit_map.end()) { + // Allocate a new barrier buffer int barrier_id = next_barrier_id++; Buffer barrier_buffer = makeBarrierBuffer( thread_count[last_wg_id], "barrier_" + std::to_string(barrier_id), num_versions, barrier_buffers, barrier_map); - barrier_unit_map[std::make_pair(last_access_task, buffer)] = - barrier_buffer; - + barrier_unit_map[last_task] = std::make_pair(barrier_buffer, num_versions); // Create BufferLoad with version-indexed offset PrimExpr version_index = indexmod(loop_info.CalculateIterationCount(), num_versions); PrimExpr barrier_load = BufferLoad(barrier_buffer, {version_index}); - // Insert barrier_arrive at the end of last_access_task's - // statements + // Insert barrier_arrive at the end of last_task's statements Stmt arrive_stmt = makeBarrierArrive(barrier_load); - InsertStatementIntoScheduleUnit(last_access_task, arrive_stmt, - false, last_wg_id); + InsertStatementIntoScheduleUnit(last_task, arrive_stmt, false, last_wg_id); } - Buffer &barrier_buffer = - barrier_unit_map[std::make_pair(last_access_task, buffer)]; - PrimExpr version_index = - indexmod(loop_info.CalculateIterationCount(), num_versions); - PrimExpr barrier_load = - BufferLoad(barrier_buffer, {version_index}); - - // Insert barrier_wait at the beginning of task's statements + auto [barrier_buffer, barrier_versions] = barrier_unit_map[last_task]; + PrimExpr iteration = loop_info.CalculateIterationCount(); + if (iter == 1) { + // Calculate the real iteration to wait. + // "+ barrier_versions * 2" ensures positive iteration for division and modulo, and keeps the parity the same. + iteration += barrier_versions * 2 - num_versions; + } + PrimExpr version_index = indexmod(iteration, barrier_versions); + PrimExpr barrier_load = BufferLoad(barrier_buffer, {version_index}); + PrimExpr parity_expr = indexmod(indexdiv(iteration, barrier_versions), 2); Stmt wait_stmt = makeBarrierWait(barrier_load, parity_expr); - // if (iter == 1) { - // // Check if at least one loop is not at its start - // iteration - // // (not the first iteration of all nested loops) - // wait_stmt = - // IfThenElse(indexdiv(loop_info.CalculateIterationCount(), - // num_stages) != 0, - // wait_stmt); - // } InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); - // Remove from map (as per user instruction) - if (!region_access.is_write) { - auto it = last_write_map.find(buffer); - it->second.second.first |= (1 << wg_id); - if (it->second.second.first == 3) { - last_write_map.erase(last_write_map.find(buffer)); - } - } else { - for (unsigned idx = 0; idx < 2; ++idx) { - auto it = last_access_map[idx].find(buffer); - if (it != last_access_map[idx].end()) { - last_access_map[idx].erase(it); - } - } - auto it = last_write_map.find(buffer); - if (it != last_write_map.end()) { - last_write_map.erase(it); - } - } - bool last_async = last_access_task->UsesTensorCore() || - last_access_task->UsesTMACore(); - } - } - } + }; - if (iter == 0) { - // Update regions - bool found_wgmma = false; - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - auto ®ion = region_access.region; - if (IsRegisterRegion(region)) { - if (!task->UsesTensorCore() || !region_access.is_write) + if (!region_access.is_write) { + if (iter == 1) { + if (last_write) + continue; + last_write = true; + } + if (last_write_task == nullptr) continue; - if (!task->isInnerTask()) + if (waited_write_wgs >> wg_id & 1) continue; - auto child = static_cast(task->child.get()); - if (child->is_WGMMA()) { - Buffer buffer = region->buffer; - if (!found_wgmma) { - found_wgmma = true; - ++total_wgmma[wg_id]; + insert_barrier(last_write_task, last_write_wg_id); + waited_write_wgs |= (1 << wg_id); + } else { + for (int last_wg_id = 0; last_wg_id < num_wgs; ++last_wg_id) { + if (iter == 1) { + if (last_access[last_wg_id]) + continue; + last_access[last_wg_id] = true; } - last_wgmma_map[wg_id][buffer] = - std::make_pair(task, total_wgmma[wg_id]); + if (last_access_task[last_wg_id] == nullptr) + continue; + insert_barrier(last_access_task[last_wg_id], last_wg_id); + last_access_task[last_wg_id] = nullptr; } - } else { - if (iter == 1) + } + } + if (iter == 0) { + for (const auto ®ion_access : task->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) continue; - Buffer buffer = region->buffer; - last_access_map[wg_id][buffer] = std::make_pair(task, wg_id); + if (region_access.region->buffer != buffer) + continue; + last_access_task[wg_id] = task; if (region_access.is_write) { - last_write_map[buffer] = - std::make_pair(task, std::make_pair(0, wg_id)); + last_write_task = task; + last_write_wg_id = wg_id; + waited_write_wgs = 0; } } } diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index c067a9adeb..6b26872e7a 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -470,6 +470,9 @@ class ControlNode : public IRStructure { int64_t ii_{0}; // Initiation interval in cycles bool has_promote_{false}; int64_t ii_per_iter_{0}; + +public: + std::set multi_buffering_buffers; }; // Wrapper node: contains a Wrapper statement with variable, value, and child diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index 9ca566b9d0..d116b808ee 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -349,12 +349,30 @@ class ScheduleUnitBuilder { resource_flags.push_back(flags); } + // Helper function to check if a buffer is written before being read + auto check_buffer_write_first = [&nodes](const Buffer &buffer) { + for (const auto &node : nodes) { + for (const auto ®ion : node->GetReadRegions()) { + if (region->buffer.same_as(buffer)) { + return false; // read access found before any write + } + } + for (const auto ®ion : node->GetWriteRegions()) { + if (region->buffer.same_as(buffer)) { + return true; // write access found before any read + } + } + } + return false; + }; + // Collect all shared buffers // The negative number means we can use multi-buffering for this buffer, so // we need to create a variable for the number of versions for this buffer // in z3 scheduler. std::vector buffer_sizes; std::map buffer_to_num_versions; + std::set multi_buffering_buffers; int64_t memory_limit = shared_memory_limit_; for (const auto ®ion_access : ctrl->GetReadWriteRegions()) { const auto &buffer = region_access.region->buffer; @@ -364,12 +382,13 @@ class ScheduleUnitBuilder { if (buffer_to_num_versions.count(buffer)) { continue; } - if (used_buffers.count(buffer)) { + if (used_buffers.count(buffer) || !check_buffer_write_first(buffer)) { buffer_to_num_versions[buffer] = 1; memory_limit -= GetBufferSize(buffer); } else { buffer_sizes.push_back(GetBufferSize(buffer)); buffer_to_num_versions[buffer] = -(int64_t)buffer_sizes.size(); + multi_buffering_buffers.insert(buffer); } } @@ -635,6 +654,7 @@ class ScheduleUnitBuilder { ctrl->SetII(overall_latency); ctrl->SetLatency(overall_latency); ctrl->SetIIperIter(ii); + ctrl->multi_buffering_buffers = std::move(multi_buffering_buffers); } // Set thread index variable for warpgroup partition From 35f13d2d474252e037e76330ce453026e39e0ad3 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 10 Apr 2026 10:44:23 +0800 Subject: [PATCH 032/156] run format --- src/transform/auto_schedule/barrier.h | 74 ++++++++++--------- .../auto_schedule/schedule_builder.cc | 8 +- .../auto_schedule/schedule_builder.h | 5 +- .../auto_schedule/warpgroup_partition.cc | 11 ++- 4 files changed, 52 insertions(+), 46 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 2236ac94e6..6f237a5faf 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -629,9 +629,8 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, if (it->second <= wait_wgmma_id[wg_id]) continue; wait_wgmma_id[wg_id] = it->second; - Stmt wait_stmt = - Evaluate(Call(DataType::Handle(), wait_wgmma(), - {total_wgmma[wg_id] - it->second})); + Stmt wait_stmt = Evaluate(Call(DataType::Handle(), wait_wgmma(), + {total_wgmma[wg_id] - it->second})); InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); } } @@ -649,8 +648,8 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, int barrier_id = next_barrier_id++; // Create a single barrier buffer with shape (1,) Buffer barrier_buffer = makeBarrierBuffer( - 1, "tcgen05_barrier_" + std::to_string(barrier_id), - 1, barrier_buffers, barrier_map); + 1, "tcgen05_barrier_" + std::to_string(barrier_id), 1, + barrier_buffers, barrier_map); barrier_unit_map[task] = barrier_buffer; // Rewrite the gemm call's mbar argument (arg[16]) to use @@ -670,14 +669,12 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, int wg_id = child->GetWarpgroupId(); if (wg_id != -1) { int barrier_id = next_barrier_id++; - Buffer barrier_buffer = - makeBarrierBuffer(thread_count[wg_id], - "tma_barrier_" + std::to_string(barrier_id), - 1, barrier_buffers, barrier_map); + Buffer barrier_buffer = makeBarrierBuffer( + thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), + 1, barrier_buffers, barrier_map); barrier_unit_map[task] = barrier_buffer; - PrimExpr barrier_load = - BufferLoad(barrier_buffer, {0}); + PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); RewriteCopyMbar(child, barrier_load); Stmt arrive_stmt = makeBarrierArrive(barrier_load); InsertStatementIntoScheduleUnit(task, arrive_stmt, false, wg_id); @@ -726,16 +723,16 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, if (barrier_unit_map.find(last_task) == barrier_unit_map.end()) { // Allocate a new barrier buffer int barrier_id = next_barrier_id++; - Buffer barrier_buffer = makeBarrierBuffer( - thread_count[last_wg_id], - "barrier_" + std::to_string(barrier_id), 1, - barrier_buffers, barrier_map); + Buffer barrier_buffer = + makeBarrierBuffer(thread_count[last_wg_id], + "barrier_" + std::to_string(barrier_id), 1, + barrier_buffers, barrier_map); barrier_unit_map[last_task] = barrier_buffer; - PrimExpr barrier_load = - BufferLoad(barrier_buffer, {0}); + PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); // Insert barrier_arrive at the end of last_task's statements Stmt arrive_stmt = makeBarrierArrive(barrier_load); - InsertStatementIntoScheduleUnit(last_task, arrive_stmt, false, last_wg_id); + InsertStatementIntoScheduleUnit(last_task, arrive_stmt, false, + last_wg_id); } auto barrier_buffer = barrier_unit_map[last_task]; PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); @@ -966,8 +963,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, int barrier_id = next_barrier_id++; // Create a single barrier buffer with shape (num_versions,) Buffer barrier_buffer = makeBarrierBuffer( - 1, "tcgen05_barrier_" + std::to_string(barrier_id), - num_versions, barrier_buffers, barrier_map); + 1, "tcgen05_barrier_" + std::to_string(barrier_id), num_versions, + barrier_buffers, barrier_map); barrier_unit_map[task] = std::make_pair(barrier_buffer, num_versions); // Rewrite the gemm call's mbar argument (arg[16]) to use @@ -997,16 +994,14 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, ICHECK(wg_id != -1) << "TMA loads must have valid warpgroup id"; int barrier_id = next_barrier_id++; - Buffer barrier_buffer = - makeBarrierBuffer(thread_count[wg_id], - "tma_barrier_" + std::to_string(barrier_id), - num_versions, barrier_buffers, barrier_map); + Buffer barrier_buffer = makeBarrierBuffer( + thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), + num_versions, barrier_buffers, barrier_map); barrier_unit_map[task] = std::make_pair(barrier_buffer, num_versions); PrimExpr version_index = indexmod(loop_info.CalculateIterationCount(), num_versions); - PrimExpr barrier_load = - BufferLoad(barrier_buffer, {version_index}); + PrimExpr barrier_load = BufferLoad(barrier_buffer, {version_index}); RewriteCopyMbar(child, barrier_load); Stmt arrive_stmt = makeBarrierArrive(barrier_load); InsertStatementIntoScheduleUnit(task, arrive_stmt, false, wg_id); @@ -1016,7 +1011,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, // Insert barriers for other dependencies // First collect shared buffers - std::set, std::greater>> shared_buffers; + std::set, std::greater>> + shared_buffers; for (const auto ®ion_access : ctrl->GetReadWriteRegions()) { auto &buffer = region_access.region->buffer; if (IsSharedBuffer(buffer)) { @@ -1025,7 +1021,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, shared_buffers.emplace(num_versions, buffer); } } - // Process buffers in order of decreasing number of versions to ensure correct barrier size + // Process buffers in order of decreasing number of versions to ensure + // correct barrier size auto is_async_task = [](ScheduleUnit *task) { return task->UsesTensorCore() || task->UsesTMACore(); }; @@ -1053,7 +1050,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, return; int last_stage = last_task->GetStage(); bool last_async = is_async_task(last_task); - if (last_wg_id == wg_id && last_stage == stage && !is_async && !last_async) + if (last_wg_id == wg_id && last_stage == stage && !is_async && + !last_async) return; if (barrier_unit_map.find(last_task) == barrier_unit_map.end()) { // Allocate a new barrier buffer @@ -1062,7 +1060,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, thread_count[last_wg_id], "barrier_" + std::to_string(barrier_id), num_versions, barrier_buffers, barrier_map); - barrier_unit_map[last_task] = std::make_pair(barrier_buffer, num_versions); + barrier_unit_map[last_task] = + std::make_pair(barrier_buffer, num_versions); // Create BufferLoad with version-indexed offset PrimExpr version_index = indexmod(loop_info.CalculateIterationCount(), num_versions); @@ -1070,18 +1069,23 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, BufferLoad(barrier_buffer, {version_index}); // Insert barrier_arrive at the end of last_task's statements Stmt arrive_stmt = makeBarrierArrive(barrier_load); - InsertStatementIntoScheduleUnit(last_task, arrive_stmt, false, last_wg_id); + InsertStatementIntoScheduleUnit(last_task, arrive_stmt, false, + last_wg_id); } - auto [barrier_buffer, barrier_versions] = barrier_unit_map[last_task]; + auto [barrier_buffer, barrier_versions] = + barrier_unit_map[last_task]; PrimExpr iteration = loop_info.CalculateIterationCount(); if (iter == 1) { // Calculate the real iteration to wait. - // "+ barrier_versions * 2" ensures positive iteration for division and modulo, and keeps the parity the same. + // "+ barrier_versions * 2" ensures positive iteration for + // division and modulo, and keeps the parity the same. iteration += barrier_versions * 2 - num_versions; } PrimExpr version_index = indexmod(iteration, barrier_versions); - PrimExpr barrier_load = BufferLoad(barrier_buffer, {version_index}); - PrimExpr parity_expr = indexmod(indexdiv(iteration, barrier_versions), 2); + PrimExpr barrier_load = + BufferLoad(barrier_buffer, {version_index}); + PrimExpr parity_expr = + indexmod(indexdiv(iteration, barrier_versions), 2); Stmt wait_stmt = makeBarrierWait(barrier_load, parity_expr); InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); }; diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 1031dc5825..c0aaf64a2a 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -684,8 +684,9 @@ void ScheduleUnitBuilder::ScheduleRecursive( // --- Naive scheduling implementation --- -std::vector NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, - PrimExpr thread_count) { +std::vector +NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, + PrimExpr thread_count) { if (!root) LOG(FATAL) << "Empty root"; @@ -875,7 +876,8 @@ void ScheduleUnitBuilder::NaiveScheduleRecursive( } } -std::vector ScheduleUnitBuilder::NaiveBuild(std::shared_ptr &root) { +std::vector +ScheduleUnitBuilder::NaiveBuild(std::shared_ptr &root) { NaiveScheduleRecursive(root); return NaiveAssignWarpgroupIds(root.get(), config_, thread_var_->dom->extent); } diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index d116b808ee..7ca3af114e 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -53,8 +53,9 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, PrimExpr thread_count); // Naive warpgroup assignment: TMA→wg1, compute→wg0, neutral→-1 -std::vector NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, - PrimExpr thread_count); +std::vector +NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, + PrimExpr thread_count); // Extract all sequential task nodes from the IR structure tree void GatherTaskNodes(const std::vector> &nodes, diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index ca6a30f4fa..4995d19624 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -1146,9 +1146,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( pending.clear(); } else { // Non-Control segment: accumulate for merging - pending.insert(pending.end(), - std::make_move_iterator(seg.begin()), - std::make_move_iterator(seg.end())); + pending.insert(pending.end(), std::make_move_iterator(seg.begin()), + std::make_move_iterator(seg.end())); } } // Trailing non-Control segments: append to last merged segment @@ -1211,10 +1210,10 @@ Stmt ApplyWarpgroupPartitionToIRStructure( for (size_t i = 0; i < num_wgs; ++i) { wg_seg_stmts[i] = SeqStmt({Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), - {i == 0 ? config.consumer_max_nreg - : config.producer_max_nreg, + {i == 0 ? config.consumer_max_nreg + : config.producer_max_nreg, static_cast(!i)})), - wg_seg_stmts[i]}); + wg_seg_stmts[i]}); } } From a5e3f19c1e2a05b2f958599b70191badb171c95c Mon Sep 17 00:00:00 2001 From: LJC00118 <77378439+LJC00118@users.noreply.github.com> Date: Fri, 10 Apr 2026 12:39:21 +0800 Subject: [PATCH 033/156] Re-enable deprecated `TL_DISABLE_TMA_LOWER` pass config for TMA store (#2024) * Re-enable deprecated TL_DISABLE_TMA_LOWER pass config for TMA store * fix lint --------- Co-authored-by: wangxiangwen --- src/op/builtin.h | 4 ++-- src/op/copy.cc | 5 ++++- tilelang/transform/pass_config.py | 11 ++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/op/builtin.h b/src/op/builtin.h index 6268528fac..a4d0ba4125 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -79,8 +79,8 @@ static constexpr const char *kDisableSafeMemoryLegalize = static constexpr const char *kDisableWarpSpecialized = "tl.disable_warp_specialized"; static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth"; -// Deprecated compatibility-only pass config. It is no longer consumed by the -// lowering pipeline, but remains registered so legacy kernels keep working. +// Deprecated pass config, temporarily re-enabled. Prevents plain T.copy() +// from auto-lowering to TMA store. Will be removed in v0.1.10. static constexpr const char *kDisableTMALower = "tl.disable_tma_lower"; static constexpr const char *kEnableAggressiveSharedMemoryMerge = "tl.enable_aggressive_shared_memory_merge"; diff --git a/src/op/copy.cc b/src/op/copy.cc index 0669e5a37e..a887d6d3ca 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -947,7 +947,10 @@ CopyInst CopyNode::GetCopyInst(Target target, const LayoutMap &layout_map, // Plain T.copy does not auto-upgrade to TMA loads anymore. Store-side TMA // remains allowed because it is self-synchronized locally and does not // participate in pipeline producer scheduling. - if (!GetDisableTMA()) { + // Also honour the (deprecated) global pass config for backward compat. + if (!GetDisableTMA() && !tvm::transform::PassContext::Current() + ->GetConfig(kDisableTMALower, Bool(false)) + .value()) { bool is_cutedsl = TargetIsCuTeDSL(target); if (!is_cutedsl && !buffer_oob && CheckBulkStore1D(target, layout_map, analyzer)) { diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 8973c8bb27..872d0d3484 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -86,11 +86,10 @@ class PassConfigKey(str, Enum): """Bitwidth for configuration indices. Default: 32""" TL_DISABLE_TMA_LOWER = "tl.disable_tma_lower" - """Deprecated compatibility-only flag for legacy kernels. + """Deprecated flag — prevents plain T.copy() from auto-lowering to TMA store. - This flag no longer has any effect in the current lowering pipeline and is - kept only so older kernels do not fail pass-config validation. It will be - removed in v0.1.10. + Temporarily re-enabled for backward compatibility. Will be removed in + v0.1.10. """ TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize" @@ -273,9 +272,7 @@ class PassConfigKey(str, Enum): _DEPRECATED_PASS_CONFIG_MESSAGES = { PassConfigKey.TL_DISABLE_TMA_LOWER.value: ( - "`tl.disable_tma_lower` is deprecated, kept only for backward " - "compatibility, has no effect in the current lowering pipeline, and " - "will be removed in v0.1.10." + "`tl.disable_tma_lower` is deprecated and will be removed in v0.1.10. Use `T.copy(..., disable_tma=True)` per-copy instead." ), } From b1a88bf794e1cab4bee584b4c1930160eb39774e Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Fri, 10 Apr 2026 12:41:06 +0800 Subject: [PATCH 034/156] [Misc] Remove mistakenly introduced temp file (#2027) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete language\test_tilelang_language_tma_copy.py::test_tma_copy_pipeline_2_stages。 --- ...uage_tma_copy.py::test_tma_copy_pipeline_2_stages\343\200\202" | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 "language\\test_tilelang_language_tma_copy.py::test_tma_copy_pipeline_2_stages\343\200\202" diff --git "a/language\\test_tilelang_language_tma_copy.py::test_tma_copy_pipeline_2_stages\343\200\202" "b/language\\test_tilelang_language_tma_copy.py::test_tma_copy_pipeline_2_stages\343\200\202" deleted file mode 100644 index e69de29bb2..0000000000 From 9b4e0a2e511d00cdb37833c98a02181a0214f735 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 10 Apr 2026 14:09:06 +0800 Subject: [PATCH 035/156] Fix warpgroup partition --- examples/auto_schedule/flashmla_benchmark.py | 2 +- .../auto_schedule/warpgroup_partition.cc | 375 +++++++----------- .../auto_schedule/warpgroup_partition.h | 5 + 3 files changed, 148 insertions(+), 234 deletions(-) diff --git a/examples/auto_schedule/flashmla_benchmark.py b/examples/auto_schedule/flashmla_benchmark.py index d9273b585a..ac8973233e 100644 --- a/examples/auto_schedule/flashmla_benchmark.py +++ b/examples/auto_schedule/flashmla_benchmark.py @@ -587,7 +587,7 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64): configs = [ (flashattn_auto, "auto_schedule"), - (flashattn_manual, "manual"), + # (flashattn_manual, "manual"), # manual schedule is not needed (flashattn_warp_specialize, "warp_specialize"), ] diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 4995d19624..ccb6658ddf 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -261,6 +261,21 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id) { return CloneIRStructureWithWarpgroupFilter(node, warpgroup_id, var_remap); } +// For each child of a root SequenceNode, apply +// CloneIRStructureWithWarpgroupFilter individually. +std::vector> +CloneIRStructureChildrenWithWarpgroupFilter(SequenceNode *root_seq, + int warpgroup_id, + Map &var_remap) { + std::vector> result; + result.reserve(root_seq->children.size()); + for (const auto &child : root_seq->children) { + result.push_back(CloneIRStructureWithWarpgroupFilter( + child.get(), warpgroup_id, var_remap)); + } + return result; +} + std::shared_ptr RemoveUnusedLetDecls(std::shared_ptr root) { if (!root) @@ -928,11 +943,32 @@ Stmt ApplyWarpgroupPartitionToIRStructure( clone_neutral_filter_with_top_level(root, is_pro_top_level_index, -1); auto wg_epi_neutral_structure = clone_neutral_filter_with_top_level(root, is_epi_top_level_index, -1); + // wg_children[wg_id][child_index] = filtered IRStructure (nullptr if absent) + std::vector>> wg_children(num_wgs); std::vector> wg_structures(num_wgs); - for (size_t i = 0; i < num_wgs; ++i) { - wg_structures[i] = - RemoveUnusedLetDecls(CloneIRStructureWithWarpgroupFilter(root, i)); + if (root->IsSequence()) { + auto root_seq = static_cast(root); + for (size_t i = 0; i < num_wgs; ++i) { + Map var_remap; + wg_children[i] = + CloneIRStructureChildrenWithWarpgroupFilter(root_seq, i, var_remap); + } + for (size_t i = 0; i < num_wgs; ++i) { + // Rebuild from wg_children: wrap non-null children into a SequenceNode + auto rebuilt_seq = std::make_shared(); + for (const auto &child : wg_children[i]) { + if (child) + rebuilt_seq->children.push_back(child); + } + wg_structures[i] = rebuilt_seq->children.empty() ? nullptr : rebuilt_seq; + } + } else { + // Fallback for non-SequenceNode root: clone entire root per warpgroup + for (size_t i = 0; i < num_wgs; ++i) { + wg_structures[i] = CloneIRStructureWithWarpgroupFilter(root, i); + } } + std::vector wg_conditions(num_wgs); wg_conditions[0] = thread_count[0]; for (size_t i = 1; i < num_wgs; ++i) { @@ -958,57 +994,7 @@ Stmt ApplyWarpgroupPartitionToIRStructure( outer_enable_epi) : Evaluate(0); - // --- Segment the wg structures by ControlNode (for-loop) boundaries --- - // This produces multiple IfThenElse blocks separated by liveness boundary - // markers, so that the merge-shared-memory pass can reuse buffers across - // segments whose lifetimes do not overlap. - - // Helper: segment a top-level SequenceNode's children into groups separated - // by ControlNode boundaries. Each ControlNode becomes its own segment; - // consecutive non-ControlNode children are grouped together. - auto SegmentSequenceChildren = [](IRStructure *structure) - -> std::vector>> { - std::vector>> segments; - if (!structure || !structure->IsSequence()) { - return segments; - } - auto seq = static_cast(structure); - - std::vector> current; - for (auto &child : seq->children) { - auto unit = static_cast(child.get()); - if (unit->child && unit->child->IsControl()) { - if (!current.empty()) { - segments.push_back(std::move(current)); - current = {}; - } - segments.push_back({child}); - } else { - current.push_back(child); - } - } - if (!current.empty()) { - segments.push_back(std::move(current)); - } - - return segments; - }; - - // Helper: wrap a list of ScheduleUnit children back into a temporary - // SequenceNode and convert to Stmt. - auto SegmentToStmt = - [outer_enable_epi]( - const std::vector> &children) -> Stmt { - if (children.empty()) - return Evaluate(0); - // Even for a single child we go through the SequenceNode path so that - // ScheduleUnit before/after stmts are emitted correctly. - auto tmp_seq = std::make_shared(); - tmp_seq->children = children; - return ConvertIRStructureToStmt(tmp_seq.get(), outer_enable_epi); - }; - - // Helper: build a single IfThenElse (with wg1 nesting) from a pair of Stmts. + // Helper: build a single IfThenElse (with wg nesting) from per-wg Stmts. auto MakeWarpgroupIf = [&wg_conditions](const std::vector &wg_stmts) -> Stmt { Stmt if_then_else = Evaluate(0); @@ -1018,208 +1004,131 @@ Stmt ApplyWarpgroupPartitionToIRStructure( return if_then_else; }; - // Helper: collect LetDecl {Var, PrimExpr} pairs from a segment's children. - // Returns them in order of appearance, which is the order they must be - // nested. - auto CollectLetDeclInfo = - [](const std::vector> &children) - -> std::vector> { - std::vector> result; - for (auto &child : children) { - auto unit = static_cast(child.get()); - IRStructure *inner = unit->child.get(); - // Handle ScheduleUnit wrapping a TaskNode - if (inner && inner->IsTask()) { - auto task = static_cast(inner); - if (IsLetDeclTask(task)) { - const auto *let = task->stmts[0].as(); - result.push_back({let->var, let->value}); - } - } - } - return result; - }; + // Check for SIMT copy in wg1 (needed for set_max_nreg decision). + bool has_simt_copy = false; + if (num_wgs == 2 && wg_structures[1]) { + Stmt full_wg1 = + ConvertIRStructureToStmt(wg_structures[1].get(), outer_enable_epi); + has_simt_copy = SimtCopyDetector::Detect(full_wg1); + } - // Helper: given a Stmt and accumulated LetDecl pairs from previous segments, - // create fresh variables with copy_with_suffix, substitute all references - // in the Stmt, and wrap with LetStmt bindings. Variables that are not - // referenced in the body (or in kept variables' value expressions) are - // pruned to avoid dead declarations. - auto WrapWithRenamedLetDecls = - [](Stmt body, - const std::vector> &accumulated_lets) - -> Stmt { - if (accumulated_lets.empty()) - return body; - - // Build substitution map: old_var -> new_var - Map subst_map; - // Create fresh vars and accumulate them (in order) - std::vector> new_lets; - for (auto &[old_var, old_value] : accumulated_lets) { - auto new_var = old_var.copy_with_suffix(""); - subst_map.Set(old_var, new_var); - PrimExpr new_value = Substitute(old_value, subst_map); - new_lets.push_back({new_var, new_value}); - } + // --- Per-child construction --- + // Walk root SequenceNode's children. LetDecl children accumulate bindings; + // non-LetDecl children produce IfThenElse blocks wrapped with accumulated + // LetDecl scopes per warp group. - // Substitute all references in the body - body = Substitute(body, subst_map); - - // Determine which variables are actually used. Walk from innermost to - // outermost: a variable is "needed" if it appears in the body or in any - // already-needed variable's value expression. - std::vector needed(new_lets.size(), false); - // Start with variables used directly in the body. - for (size_t i = 0; i < new_lets.size(); ++i) { - const Var &v = new_lets[i].first; - if (UsesVar(body, - [&v](const VarNode *node) { return node == v.get(); })) { - needed[i] = true; - } - } - // Propagate: if variable j is needed and its value uses variable i, - // then i is also needed. Iterate until fixpoint. - bool changed = true; - while (changed) { - changed = false; - for (size_t j = 0; j < new_lets.size(); ++j) { - if (!needed[j]) - continue; - for (size_t i = 0; i < j; ++i) { - if (needed[i]) - continue; - const Var &vi = new_lets[i].first; - if (UsesVar(new_lets[j].second, [&vi](const VarNode *node) { - return node == vi.get(); - })) { - needed[i] = true; - changed = true; + Stmt if_then_else; + if (root->IsSequence()) { + auto root_seq = static_cast(root); + size_t num_children = root_seq->children.size(); + + // per-wg accumulated LetDecl {var, value} from earlier children + std::vector>> wg_accumulated_lets( + num_wgs); + + std::vector segmented_stmts; + bool first_non_let = true; + + for (size_t ci = 0; ci < num_children; ++ci) { + auto unit = static_cast(root_seq->children[ci].get()); + bool is_let_decl = IsLetDeclNode(unit->child.get()); + + if (is_let_decl) { + // Extract LetDecl {var, value} from each wg's filtered result + for (size_t i = 0; i < num_wgs; ++i) { + if (wg_children[i][ci]) { + // wg_children[i][ci] is a ScheduleUnit wrapping a TaskNode + IRStructure *inner = wg_children[i][ci].get(); + TaskNode *task = nullptr; + if (inner->IsScheduleUnit()) { + task = static_cast( + static_cast(inner)->child.get()); + } else if (inner->IsTask()) { + task = static_cast(inner); + } + if (task && !task->stmts.empty()) { + const auto *let = task->stmts[0].as(); + if (let) { + wg_accumulated_lets[i].push_back({let->var, let->value}); + } + } } } + continue; // LetDecl children don't produce IfThenElse } - } - // Wrap only needed LetStmt bindings (innermost first) - for (int i = static_cast(new_lets.size()) - 1; i >= 0; --i) { - if (needed[i]) { - body = LetStmt(new_lets[i].first, new_lets[i].second, body); + // Build per-wg Stmt for this child, wrapped with accumulated LetDecls + std::vector wg_stmts(num_wgs); + bool all_empty = true; + for (size_t i = 0; i < num_wgs; ++i) { + if (wg_children[i][ci]) { + auto tmp_seq = std::make_shared(); + tmp_seq->children.push_back(wg_children[i][ci]); + wg_stmts[i] = + ConvertIRStructureToStmt(tmp_seq.get(), outer_enable_epi); + } else { + wg_stmts[i] = Evaluate(0); + } + if (!IsEvaluateZero(wg_stmts[i])) { + all_empty = false; + } + // Wrap with accumulated LetDecl bindings (innermost first) + for (int j = static_cast(wg_accumulated_lets[i].size()) - 1; + j >= 0; --j) { + wg_stmts[i] = LetStmt(wg_accumulated_lets[i][j].first, + wg_accumulated_lets[i][j].second, wg_stmts[i]); + } } - } - return body; - }; - - Stmt if_then_else; - std::vector>>> - wg_segments(num_wgs); - bool equal_segment_counts = true; - for (size_t i = 0; i < num_wgs; ++i) { - wg_segments[i] = SegmentSequenceChildren(wg_structures[i].get()); - equal_segment_counts &= (wg_segments[i].size() == wg_segments[0].size()); - } - // Helper: extract the For loop_var pointer from a Control segment - // (a segment with exactly one ScheduleUnit whose child is a ControlNode). - auto GetControlLoopVar = - [](const std::vector> &seg) - -> const VarNode * { - if (seg.size() != 1) - return nullptr; - auto unit = static_cast(seg[0].get()); - if (!unit->child || !unit->child->IsControl()) - return nullptr; - auto ctrl = static_cast(unit->child.get()); - return ctrl->control->loop_var.get(); - }; + // Skip segments where all warpgroups produce empty statements + if (all_empty) + continue; - auto MergeNonControlSegments = - [&GetControlLoopVar]( - std::vector>> &segments) { - std::vector>> merged; - std::vector> pending; - for (auto &seg : segments) { - auto lv = GetControlLoopVar(seg); - if (lv) { - // Control segment: prepend any pending non-Control children - merged.push_back(std::move(pending)); - merged.push_back(std::move(seg)); - pending.clear(); - } else { - // Non-Control segment: accumulate for merging - pending.insert(pending.end(), std::make_move_iterator(seg.begin()), - std::make_move_iterator(seg.end())); - } + // Insert liveness boundary before each non-empty non-LetDecl child + segmented_stmts.push_back(AttrStmt( + Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, Evaluate(0))); + + // Prepend set_max_nreg only to the first non-LetDecl child + if (first_non_let && !has_simt_copy && num_wgs == 2 && + config.enable_set_max_nreg) { + for (size_t i = 0; i < num_wgs; ++i) { + wg_stmts[i] = + SeqStmt({Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), + {i == 0 ? config.consumer_max_nreg + : config.producer_max_nreg, + static_cast(!i)})), + wg_stmts[i]}); } - // Trailing non-Control segments: append to last merged segment - merged.push_back(std::move(pending)); - segments = std::move(merged); - }; - - for (auto &segments : wg_segments) { - MergeNonControlSegments(segments); - } - - // Apply segmented splitting - std::vector segmented_stmts; - bool has_simt_copy = false; - // Check for SIMT copy in any wg1 segment (needed for set_max_nreg - // decision). - if (num_wgs == 2) { - Stmt full_wg1 = - ConvertIRStructureToStmt(wg_structures[1].get(), outer_enable_epi); - has_simt_copy = SimtCopyDetector::Detect(full_wg1); - } - - // Accumulate LetDecl info from previous segments for variable renaming. - std::vector>> wg_accumulated_lets( - num_wgs); - - for (size_t si = 0; si < wg_segments[0].size(); ++si) { - // Insert liveness boundary between segments. - segmented_stmts.push_back(AttrStmt( - Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, Evaluate(0))); + } + first_non_let = false; - // Collect LetDecl info from current segment before converting to Stmt. - std::vector>> wg_lets(num_wgs); - for (size_t i = 0; i < num_wgs; ++i) { - wg_lets[i] = CollectLetDeclInfo(wg_segments[i][si]); + segmented_stmts.push_back(MakeWarpgroupIf(wg_stmts)); } - std::vector wg_seg_stmts(num_wgs); + if_then_else = SeqStmt::Flatten(segmented_stmts); + } else { + // Fallback for non-SequenceNode root: no boundary insertion, simple + // partition + std::vector wg_stmts(num_wgs); for (size_t i = 0; i < num_wgs; ++i) { - wg_seg_stmts[i] = SegmentToStmt(wg_segments[i][si]); - } - - // For segments after the first, wrap with renamed LetDecl bindings - // from all previous segments so that variables remain in scope. - if (si > 0) { - for (size_t i = 0; i < num_wgs; ++i) { - wg_seg_stmts[i] = - WrapWithRenamedLetDecls(wg_seg_stmts[i], wg_accumulated_lets[i]); + if (wg_structures[i]) { + wg_stmts[i] = + ConvertIRStructureToStmt(wg_structures[i].get(), outer_enable_epi); + } else { + wg_stmts[i] = Evaluate(0); } } - - // Accumulate this segment's LetDecls for future segments. - for (size_t i = 0; i < num_wgs; ++i) { - wg_accumulated_lets[i].insert(wg_accumulated_lets[i].end(), - wg_lets[i].begin(), wg_lets[i].end()); - } - - // Prepend set_max_nreg only to the first segment. - if (si == 0 && !has_simt_copy && num_wgs == 2 && - config.enable_set_max_nreg) { + if (!has_simt_copy && num_wgs == 2 && config.enable_set_max_nreg) { for (size_t i = 0; i < num_wgs; ++i) { - wg_seg_stmts[i] = + wg_stmts[i] = SeqStmt({Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), {i == 0 ? config.consumer_max_nreg : config.producer_max_nreg, static_cast(!i)})), - wg_seg_stmts[i]}); + wg_stmts[i]}); } } - - segmented_stmts.push_back(MakeWarpgroupIf(wg_seg_stmts)); + if_then_else = MakeWarpgroupIf(wg_stmts); } - if_then_else = SeqStmt::Flatten(segmented_stmts); PrimExpr updated_thread_extent = std::accumulate( thread_count.begin() + 1, thread_count.end(), thread_count[0]); diff --git a/src/transform/auto_schedule/warpgroup_partition.h b/src/transform/auto_schedule/warpgroup_partition.h index 7b47d89cff..a8ee663057 100644 --- a/src/transform/auto_schedule/warpgroup_partition.h +++ b/src/transform/auto_schedule/warpgroup_partition.h @@ -42,6 +42,11 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id); std::shared_ptr RemoveUnusedLetDecls(std::shared_ptr root); +std::vector> +CloneIRStructureChildrenWithWarpgroupFilter(SequenceNode *root_seq, + int warpgroup_id, + Map &var_remap); + class SimtCopyDetector; Stmt ConvertIRStructureToStmt(IRStructure *root, const bool outer_enable_epi); From b31aaa9f7ee6d370c2e07a4ac31878d18b8888ac Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 10 Apr 2026 16:55:26 +0800 Subject: [PATCH 036/156] fix barrier logic --- src/transform/auto_schedule/barrier.h | 90 ++++++++++++++++----------- 1 file changed, 54 insertions(+), 36 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 6f237a5faf..c2ee5b61db 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -637,7 +637,8 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, } } - std::map barrier_unit_map; + // Map (ScheduleUnit, warpgroup_id) to barrier buffer + std::map, Buffer> barrier_unit_map; // Allocate barriers for TCGEN05MMA for (auto &promote_child : seq->children) { @@ -645,17 +646,23 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, if (task->isInnerTask() && task->UsesTensorCore()) { auto child = static_cast(task->child.get()); if (child->is_TCGEN05()) { - int barrier_id = next_barrier_id++; - // Create a single barrier buffer with shape (1,) - Buffer barrier_buffer = makeBarrierBuffer( - 1, "tcgen05_barrier_" + std::to_string(barrier_id), 1, - barrier_buffers, barrier_map); - barrier_unit_map[task] = barrier_buffer; - - // Rewrite the gemm call's mbar argument (arg[16]) to use - // BufferLoad(barrier_buffer, {0}) - PrimExpr mbar_expr = BufferLoad(barrier_buffer, {0}); - RewriteGemmMbar(child, mbar_expr); + int wg_id = child->GetWarpgroupId(); + if (wg_id != -1) { + int barrier_id = next_barrier_id++; + // Create a single barrier buffer with shape (1,) + Buffer barrier_buffer = makeBarrierBuffer( + 1, "tcgen05_barrier_" + std::to_string(barrier_id), 1, + barrier_buffers, barrier_map); + barrier_unit_map[std::make_pair(task, wg_id)] = barrier_buffer; + + // Rewrite the gemm call's mbar argument (arg[16]) to use + // BufferLoad(barrier_buffer, {0}) + PrimExpr mbar_expr = BufferLoad(barrier_buffer, {0}); + RewriteGemmMbar(child, mbar_expr); + } else { + PrimExpr mbar_expr = BufferLoad(neutral_sync_shared_barrier, {0}); + RewriteGemmMbar(child, mbar_expr); + } } } } @@ -672,7 +679,7 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, Buffer barrier_buffer = makeBarrierBuffer( thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), 1, barrier_buffers, barrier_map); - barrier_unit_map[task] = barrier_buffer; + barrier_unit_map[std::make_pair(task, wg_id)] = barrier_buffer; PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); RewriteCopyMbar(child, barrier_load); @@ -687,16 +694,18 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, } // Insert barriers for other dependencies - // First collect shared buffers - std::set shared_buffers; + // First collect all buffers except register buffers + std::set buffers; for (const auto ®ion_access : seq->GetReadWriteRegions()) { auto &buffer = region_access.region->buffer; - shared_buffers.emplace(buffer); + if (!IsRegisterRegion(region_access.region)) { + buffers.emplace(buffer); + } } auto is_async_task = [](ScheduleUnit *task) { return task->UsesTensorCore() || task->UsesTMACore(); }; - for (const auto &buffer : shared_buffers) { + for (const auto &buffer : buffers) { std::vector last_access_task(num_wgs, nullptr); std::vector last_access(num_wgs, false); ScheduleUnit *last_write_task = nullptr; @@ -718,23 +727,26 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, if (last_wg_id == -1) return; bool last_async = is_async_task(last_task); - if (last_wg_id == wg_id && !is_async && !last_async) + if (last_wg_id == wg_id && !last_async) return; - if (barrier_unit_map.find(last_task) == barrier_unit_map.end()) { + if (barrier_unit_map.find(std::make_pair(last_task, last_wg_id)) == + barrier_unit_map.end()) { // Allocate a new barrier buffer int barrier_id = next_barrier_id++; Buffer barrier_buffer = makeBarrierBuffer(thread_count[last_wg_id], "barrier_" + std::to_string(barrier_id), 1, barrier_buffers, barrier_map); - barrier_unit_map[last_task] = barrier_buffer; + barrier_unit_map[std::make_pair(last_task, last_wg_id)] = + barrier_buffer; PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); // Insert barrier_arrive at the end of last_task's statements Stmt arrive_stmt = makeBarrierArrive(barrier_load); InsertStatementIntoScheduleUnit(last_task, arrive_stmt, false, last_wg_id); } - auto barrier_buffer = barrier_unit_map[last_task]; + auto barrier_buffer = + barrier_unit_map[std::make_pair(last_task, last_wg_id)]; PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); Stmt wait_stmt = makeBarrierWait(barrier_load, 0); InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); @@ -944,7 +956,9 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, } } - std::map> barrier_unit_map; + // Map (ScheduleUnit, warpgroup_id) to (barrier buffer, num_versions) + std::map, std::pair> + barrier_unit_map; // Allocate barriers for TCGEN05MMA for (ScheduleUnit *task : ordered_tasks) { @@ -959,13 +973,16 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, num_versions = std::max(num_versions, it->second); } } + int wg_id = child->GetWarpgroupId(); + ICHECK(wg_id != -1) << "TCGEN05MMA must have valid warpgroup id"; int barrier_id = next_barrier_id++; // Create a single barrier buffer with shape (num_versions,) Buffer barrier_buffer = makeBarrierBuffer( 1, "tcgen05_barrier_" + std::to_string(barrier_id), num_versions, barrier_buffers, barrier_map); - barrier_unit_map[task] = std::make_pair(barrier_buffer, num_versions); + barrier_unit_map[std::make_pair(task, wg_id)] = + std::make_pair(barrier_buffer, num_versions); // Rewrite the gemm call's mbar argument (arg[16]) to use // BufferLoad(barrier_buffer, {version_index}) @@ -997,7 +1014,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, Buffer barrier_buffer = makeBarrierBuffer( thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), num_versions, barrier_buffers, barrier_map); - barrier_unit_map[task] = std::make_pair(barrier_buffer, num_versions); + barrier_unit_map[std::make_pair(task, wg_id)] = + std::make_pair(barrier_buffer, num_versions); PrimExpr version_index = indexmod(loop_info.CalculateIterationCount(), num_versions); @@ -1010,15 +1028,15 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, } // Insert barriers for other dependencies - // First collect shared buffers + // First collect all buffers except register buffers std::set, std::greater>> - shared_buffers; + buffers; for (const auto ®ion_access : ctrl->GetReadWriteRegions()) { auto &buffer = region_access.region->buffer; - if (IsSharedBuffer(buffer)) { + if (!IsRegisterRegion(region_access.region)) { auto it = buffer_num_versions.find(buffer); int num_versions = it != buffer_num_versions.end() ? it->second : 1; - shared_buffers.emplace(num_versions, buffer); + buffers.emplace(num_versions, buffer); } } // Process buffers in order of decreasing number of versions to ensure @@ -1026,7 +1044,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, auto is_async_task = [](ScheduleUnit *task) { return task->UsesTensorCore() || task->UsesTMACore(); }; - for (const auto &[num_versions, buffer] : shared_buffers) { + for (const auto &[num_versions, buffer] : buffers) { std::vector last_access_task(num_wgs, nullptr); std::vector last_access(num_wgs, false); ScheduleUnit *last_write_task = nullptr; @@ -1036,7 +1054,6 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, // Process tasks in the specified order for (unsigned iter = 0; iter != 2; ++iter) { for (ScheduleUnit *task : ordered_tasks) { - int stage = task->GetStage(); bool is_async = is_async_task(task); for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; @@ -1048,19 +1065,20 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, auto insert_barrier = [&](ScheduleUnit *last_task, int last_wg_id) { if (last_wg_id == -1) return; - int last_stage = last_task->GetStage(); + if (last_task == task) // ??? + return; bool last_async = is_async_task(last_task); - if (last_wg_id == wg_id && last_stage == stage && !is_async && - !last_async) + if (last_wg_id == wg_id && !last_async) return; - if (barrier_unit_map.find(last_task) == barrier_unit_map.end()) { + if (barrier_unit_map.find(std::make_pair( + last_task, last_wg_id)) == barrier_unit_map.end()) { // Allocate a new barrier buffer int barrier_id = next_barrier_id++; Buffer barrier_buffer = makeBarrierBuffer( thread_count[last_wg_id], "barrier_" + std::to_string(barrier_id), num_versions, barrier_buffers, barrier_map); - barrier_unit_map[last_task] = + barrier_unit_map[std::make_pair(last_task, last_wg_id)] = std::make_pair(barrier_buffer, num_versions); // Create BufferLoad with version-indexed offset PrimExpr version_index = @@ -1073,7 +1091,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, last_wg_id); } auto [barrier_buffer, barrier_versions] = - barrier_unit_map[last_task]; + barrier_unit_map[std::make_pair(last_task, last_wg_id)]; PrimExpr iteration = loop_info.CalculateIterationCount(); if (iter == 1) { // Calculate the real iteration to wait. From 853e805b2ae8d1d06cee215367a1bdf775a83058 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sat, 11 Apr 2026 16:10:54 +0800 Subject: [PATCH 037/156] [Codegen] Add lexical_alloc_scope for scoped local variable lifetime (#2023) * [Codegen] Add lexical_alloc_scope for scoped local variable lifetime Introduce a `lexical_alloc_scope` AttrStmt that generates `{ ... }` in C/CUDA codegen, giving the underlying compiler accurate variable lifetime information for better register allocation. - Define `tl::attr::kLexicalAllocScope` constant - LowerOpaqueBlock wraps block-local allocations in the new AttrStmt - StorageRewrite treats it as a scope boundary with proper thread_scope_ save/restore so allocations are not hoisted past the boundary - CUDA and HIP codegen emit scoped `{ }` blocks - Add tests for IR insertion, StorageRewrite preservation, and codegen output * lint fix * [Codegen] Refine lexical_alloc_scope: skip top-level blocks and decouple from thread_scope Two improvements to the lexical_alloc_scope mechanism: 1. LowerOpaqueBlock: only insert lexical_alloc_scope for blocks inside loops (inside_loop_ > 0). Top-level blocks already have function-body lifetime, so the extra `{ }` in codegen is pointless. 2. StorageRewrite: introduce a separate `lexical_scope_` / `effective_scope()` instead of overriding `thread_scope_`. This avoids breaking PlanNewScope's toggle protocol when lexical_alloc_scope is nested inside thread_extent, fixing a `ICHECK(thread_scope_ == op)` crash. Co-Authored-By: Claude Opus 4.6 (1M context) * [Transform] Fix alloc-scope placement regressions * [Infra] Remove clang-tidy integration * Remove unused local descriptor allocation pass and related tests; update example to disable main execution and print kernel source for debugging. * Preserve lexical alloc scopes for nested register buffers * Limit lexical alloc scopes to local storage * refactor * Unify lexical alloc scope annotations * Clean up dead code and fix unsafe cast - Remove unused block_nesting_ member from OpaqueBlockLower - Use static_cast instead of reinterpret_cast in ResolveAllocationSite Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- .clang-tidy | 60 --- .github/workflows/ci.yml | 47 -- .github/workflows/pr-regression-test-bot.yml | 2 - .gitignore | 3 - format.sh | 59 --- requirements-lint.txt | 1 - src/op/builtin.h | 8 + src/op/copy.cc | 7 +- src/op/gemm_py.cc | 9 +- src/op/gemm_sp_py.cc | 9 +- src/target/codegen_cuda.cc | 11 +- src/target/codegen_hip.cc | 11 +- src/transform/lower_opaque_block.cc | 16 +- .../plan_update_buffer_allocation_location.cc | 94 ++++ .../reuse_local_descriptor_allocations.cc | 254 ----------- src/transform/storage_rewrite.cc | 60 ++- ..._tilelang_transform_lexical_alloc_scope.py | 429 ++++++++++++++++++ ..._plan_update_buffer_allocation_location.py | 71 +++ ...form_reuse_local_descriptor_allocations.py | 105 ----- tilelang/engine/phase.py | 1 - tilelang/transform/__init__.py | 11 - 21 files changed, 715 insertions(+), 553 deletions(-) delete mode 100644 .clang-tidy delete mode 100644 src/transform/reuse_local_descriptor_allocations.cc create mode 100644 testing/python/transform/test_tilelang_transform_lexical_alloc_scope.py create mode 100644 testing/python/transform/test_tilelang_transform_plan_update_buffer_allocation_location.py delete mode 100644 testing/python/transform/test_tilelang_transform_reuse_local_descriptor_allocations.py diff --git a/.clang-tidy b/.clang-tidy deleted file mode 100644 index f9b77bce8a..0000000000 --- a/.clang-tidy +++ /dev/null @@ -1,60 +0,0 @@ ---- -InheritParentConfig: true -ExtraArgs: [] -FormatStyle: file -UseColor: true -WarningsAsErrors: '*' -# FIXME: Use `ExcludeHeaderFilterRegex` instead when all maintainers upgraded their `clang-tidy` -HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*' -# ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$' - -# NOTE: there must be no spaces before the '-', so put the comma last. -Checks: >- - # 1. Retained categories: easier to find bugs/performance issues - clang-analyzer-*, - cppcoreguidelines-pro-type-static-cast-downcast, - cppcoreguidelines-pro-type-member-init, - cppcoreguidelines-pro-bounds-array-to-pointer-decay, - cppcoreguidelines-pro-bounds-pointer-arithmetic, - cppcoreguidelines-slicing, - cppcoreguidelines-narrowing-conversions, - performance-*, - - # 2. Readability: only keep useful rules - readability-braces-around-statements, - readability-container-size-empty, - readability-delete-null-pointer, - readability-redundant-member-init, - readability-redundant-smartptr-get, - readability-redundant-string-cstr, - - # 3. Disable all intrusive/style-breaking rules - -readability-identifier-length, - -readability-avoid-const-params-in-decls, - -readability-else-after-return, - -cppcoreguidelines-avoid-magic-numbers, - -modernize-use-trailing-return-type, - -modernize-use-nodiscard, - -modernize-use-auto, - -modernize-pass-by-value, - -modernize-return-braced-init-list, - -modernize-use-default-member-init, - -modernize-loop-convert, - -modernize-concat-nested-namespaces, - -llvm-include-order, - -bugprone-unused-return-value, - -clang-diagnostic-unused-result, - -cppcoreguidelines-special-member-functions, - -performance-noexcept-move-constructor, - -cppcoreguidelines-narrowing-conversions, - -clang-diagnostic-error, - -cppcoreguidelines-pro-type-member-init, - -clang-analyzer-optin.cplusplus.UninitializedObject, - -cppcoreguidelines-pro-type-static-cast-downcast, - -performance-unnecessary-value-param, - -performance-enum-size, - -cppcoreguidelines-pro-bounds-pointer-arithmetic, - -cppcoreguidelines-pro-bounds-array-to-pointer-decay, - -clang-analyzer-deadcode.DeadStores, - -clang-analyzer-optin.cplusplus.VirtualCall, - -clang-diagnostic-tautological-constant-compare, diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 24821b8cd2..f899ed473a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,6 @@ concurrency: cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: - CLANG_TIDY_CMAKE_OPTIONS: "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON" # to be updated PYTHONDEVMODE: "1" PYTHONUNBUFFERED: "1" PYTHONPATH: "" # explicit cleanup @@ -153,7 +152,6 @@ jobs: export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}" fi export UV_INDEX="${PIP_EXTRA_INDEX_URL}" - export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_CUDA=ON" echo "USE_CUDA=ON" | tee -a "${GITHUB_ENV}" echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" @@ -161,7 +159,6 @@ jobs: echo "CUDA_VERSION_MAJMIN_NODOT=${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" - echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" if [[ ! -x "$(command -v nvcc)" ]]; then export PATH="/usr/local/cuda/bin:${PATH}" @@ -189,7 +186,6 @@ jobs: export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/rocm${ROCM_VERSION_MAJMIN}" fi export UV_INDEX="${PIP_EXTRA_INDEX_URL}" - export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_ROCM=ON" echo "USE_ROCM=ON" | tee -a "${GITHUB_ENV}" echo "ROCM_VERSION=${ROCM_VERSION}" | tee -a "${GITHUB_ENV}" @@ -197,7 +193,6 @@ jobs: echo "ROCM_VERSION_MAJMIN_NODOT=${ROCM_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" - echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" if [[ ! -x "$(command -v hipcc)" ]]; then export PATH="/opt/rocm/bin:${PATH}" @@ -221,10 +216,8 @@ jobs: echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" fi - export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_METAL=ON" echo "USE_METAL=ON" | tee -a "${GITHUB_ENV}" - echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" - name: Setup Python and uv with caching id: setup-uv @@ -302,46 +295,6 @@ jobs: run: | uv pip install -v . - - name: Run clang-tidy - id: clang-tidy - if: runner.os == 'Linux' - run: | - echo "\$ $(command -v clang-tidy) --version" && clang-tidy --version - - # Download run-clang-tidy script - RCT_URL=https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py - echo "Downloading run-clang-tidy script from ${RCT_URL}" - echo "import urllib.request; url = '${RCT_URL}'.rstrip('/'); urllib.request.urlretrieve(url, url.split('/')[-1])" | uv run --no-project --script - - RUN_CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py) - - if [[ -x "$(command -v clang-apply-replacements)" ]]; then - echo "Using clang-apply-replacements from $(command -v clang-apply-replacements)" - RUN_CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)") - else - echo "::warning::clang-apply-replacements not found in PATH, automatic fixing disabled." - fi - - # Run cmake to create the build directory with compile_commands.json - cmake -S . -B cmake-build --fresh ${CLANG_TIDY_CMAKE_OPTIONS} # no quotes here - echo "::group::compile_commands.json" - ls -alh cmake-build/compile_commands.json - uv run --no-project -m -- json.tool --no-ensure-ascii cmake-build/compile_commands.json - echo "::endgroup::" - - CXX_FILES=$(find src -type f -iname "*.[ch]pp" -o -iname "*.cc" -o -iname "*.c" -o -iname "*.h") - rc=0 - echo "::group::run-clang-tidy" - "${RUN_CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \ - -exclude-header-filter='^(3rdparty|tvm)/.*$' \ - -p="cmake-build" ${CXX_FILES} || rc="$?" - echo "::endgroup::" - rm -rf cmake-build run-clang-tidy.py - if (( rc != 0 )); then - echo "::error::clang-tidy found issues (exit code: ${rc}). Please run 'clang-tidy --fix' locally to fix them." - git diff --color=always || true - exit "${rc}" - fi - - name: Clean up stale /tmp files (self-hosted runners) if: startsWith(matrix.runner.name, 'self-hosted') run: | diff --git a/.github/workflows/pr-regression-test-bot.yml b/.github/workflows/pr-regression-test-bot.yml index f8af310458..4e060c3766 100644 --- a/.github/workflows/pr-regression-test-bot.yml +++ b/.github/workflows/pr-regression-test-bot.yml @@ -127,7 +127,6 @@ jobs: export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}" fi export UV_INDEX="${PIP_EXTRA_INDEX_URL}" - export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_CUDA=ON" echo "USE_CUDA=ON" | tee -a "${GITHUB_ENV}" echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" @@ -135,7 +134,6 @@ jobs: echo "CUDA_VERSION_MAJMIN_NODOT=${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" - echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" if [[ ! -x "$(command -v nvcc)" ]]; then export PATH="/usr/local/cuda/bin:${PATH}" diff --git a/.gitignore b/.gitignore index a12077bf62..21ac39bc87 100644 --- a/.gitignore +++ b/.gitignore @@ -119,9 +119,6 @@ maint/host_checks/logs/* # csv *.csv -# clang-tidy -/run-clang-tidy.py - # perf regression test .perf_regression/ diff --git a/format.sh b/format.sh index 3cc4390dbe..ad8450900f 100755 --- a/format.sh +++ b/format.sh @@ -110,65 +110,6 @@ fi echo 'tile-lang pre-commit: Done' -echo 'tile-lang clang-tidy: Check Start' -# If clang-tidy is available, run it; otherwise, skip -if [[ -x "$(command -v run-clang-tidy)" ]]; then - # Check if clang-tidy is available - if [[ ! -x "$(command -v clang-tidy)" ]]; then - python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt" --user - fi - # Get clang-tidy version - CLANG_TIDY_VERSION="$(clang-tidy --version | head -n1 | awk '{print $4}')" - echo "Using clang-tidy version: ${CLANG_TIDY_VERSION}" - - # Check if build directory exists - if [[ ! -d "${ROOT}/build" ]]; then - echo "Build directory not found. Skipping clang-tidy checks." - else - # Run clang-tidy on specified files - clang_tidy_files() { - run-clang-tidy -j 64 "$@" -p build - } - - # Run clang-tidy on all C/C++ source files - clang_tidy_all() { - run-clang-tidy -j 64 src/*.cc -p build - } - - # Run clang-tidy on changed C/C++ files relative to main - clang_tidy_changed() { - # Get changed C/C++ files - CHANGED_FILES="$(git diff --name-only --diff-filter=ACM "${MERGE_BASE}" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' 2>/dev/null || true)" - - if [[ -n "${CHANGED_FILES}" ]]; then - echo "Running clang-tidy on changed files:" - echo "${CHANGED_FILES}" - # Convert newline-separated files to space-separated and run clang-tidy once - CHANGED_FILES_SPACE="$(echo "${CHANGED_FILES}" | tr '\n' ' ')" - run-clang-tidy -j 64 ${CHANGED_FILES_SPACE} -p build -fix - else - echo "No C/C++ files changed. Skipping clang-tidy." - fi - } - - if [[ -n "${ALL_FILES}" ]]; then - # If --all is given, run clang-tidy on all source files - clang_tidy_all - elif [[ -n "${ONLY_CHANGED}" ]]; then - # Otherwise, run clang-tidy only on changed C/C++ files - clang_tidy_changed - elif [[ "${#FILES[@]}" -gt 0 ]]; then - # If --files is given, run clang-tidy only on the provided files - clang_tidy_files "${FILES[@]}" - fi - fi - -else - echo "run-clang-tidy not found. Skipping clang-tidy checks." - echo "To install clang-tidy tools, you may need to install clang-tidy and run-clang-tidy." -fi -echo 'tile-lang clang-tidy: Done' - # Check if there are any uncommitted changes after all formatting steps. # If there are, ask the user to review and stage them. if ! git diff --quiet &>/dev/null; then diff --git a/requirements-lint.txt b/requirements-lint.txt index bbb167aa5f..873af12e1f 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -1,6 +1,5 @@ # Format and lint requirements pre-commit clang-format==21.1.8 -clang-tidy==21.1.6 codespell[toml]==2.4.1 ruff==0.14.14 diff --git a/src/op/builtin.h b/src/op/builtin.h index a4d0ba4125..f4da8bb7f9 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -54,6 +54,14 @@ static constexpr const char *kNonRestrictParams = "tl.non_restrict_params"; // argument of __launch_bounds__(maxThreads, minBlocksPerMultiprocessor). // Type: Integer static constexpr const char *kMinBlocksPerSM = "tl.min_blocks_per_sm"; +// lexical_alloc_scope may first appear as a Block annotation, requesting that +// LowerOpaqueBlock materialize a lexical scope boundary for that block subtree. +// After LowerOpaqueBlock, the same key appears as an AttrStmt marker that +// generates a C/CUDA lexical scope `{ ... }` in codegen. Allocations nested +// inside this scope cannot be hoisted past the boundary by StorageRewrite, +// giving the underlying compiler accurate variable lifetime information for +// register allocation. +static constexpr const char *kLexicalAllocScope = "lexical_alloc_scope"; } // namespace attr inline Optional diff --git a/src/op/copy.cc b/src/op/copy.cc index a887d6d3ca..c58f1296bd 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1690,7 +1690,12 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, } else if (StructuralEqual()(shared_layout, linear_layout)) { desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); } else { - ICHECK(shared_layout->InputDim() >= 2) << "Cannot detect TMA layout."; + if (shared_layout->InputDim() < 2) { + LOG(WARNING) << "TMA bulk copy cannot support shared layout with input " + << "dimension " << shared_layout->InputDim() + << ", fallback to normal copy."; + return LowerNormalCopy(T, analyzer); + } const int ndim = static_cast(shared_layout->InputDim()); auto stride = as_const_int(shared_layout->InputShape()[ndim - 2]); auto continuous = as_const_int(shared_layout->InputShape()[ndim - 1]); diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index 86858d7cc3..b5346c79b8 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -312,17 +312,24 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { { BlockNode *n = block.CopyOnWrite(); n->name_hint = global_symbol.value(); + n->annotations.Set(tl::attr::kLexicalAllocScope, + IntImm(DataType::Int(32), 1)); } return BlockRealize(block_realize->iter_values, block_realize->predicate, block); } // warp with block realize node + Map block_annotations; + block_annotations.Set(tl::attr::kLexicalAllocScope, + IntImm(DataType::Int(32), 1)); return BlockRealize( /*iter_values=*/Array(), /*predicate=*/const_true(), /*block=*/ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/global_symbol.value(), prim_func->body)); + /*name_hint=*/global_symbol.value(), prim_func->body, + /*init=*/Optional(), /*alloc_buffers=*/{}, + /*match_buffers=*/{}, /*annotations=*/block_annotations)); } else { LOG(FATAL) << "No lower function found for gemm_py"; return Stmt(); // This line will never be reached due to LOG(FATAL), but diff --git a/src/op/gemm_sp_py.cc b/src/op/gemm_sp_py.cc index 8546228fcb..571e81ec87 100644 --- a/src/op/gemm_sp_py.cc +++ b/src/op/gemm_sp_py.cc @@ -260,17 +260,24 @@ Stmt GemmSPPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { { BlockNode *n = block.CopyOnWrite(); n->name_hint = global_symbol.value(); + n->annotations.Set(tl::attr::kLexicalAllocScope, + IntImm(DataType::Int(32), 1)); } return BlockRealize(block_realize->iter_values, block_realize->predicate, block); } // warp with block realize node + Map block_annotations; + block_annotations.Set(tl::attr::kLexicalAllocScope, + IntImm(DataType::Int(32), 1)); return BlockRealize( /*iter_values=*/Array(), /*predicate=*/const_true(), /*block=*/ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/global_symbol.value(), prim_func->body)); + /*name_hint=*/global_symbol.value(), prim_func->body, + /*init=*/Optional(), /*alloc_buffers=*/{}, + /*match_buffers=*/{}, /*annotations=*/block_annotations)); } else { LOG(FATAL) << "No lower function found for gemm_sp_py"; } diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 1e4a11f722..4a11cc1b3c 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -3608,7 +3608,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) { - if (op->attr_key == tir::attr::fragment_shape) { + if (op->attr_key == tl::attr::kLexicalAllocScope) { + PrintIndent(); + stream << "{\n"; + int scope = BeginScope(); + PrintStmt(op->body); + EndScope(scope); + PrintIndent(); + stream << "}\n"; + return; + } else if (op->attr_key == tir::attr::fragment_shape) { const VarNode *buffer = op->node.as(); const StringImmNode *shape_str = op->value.as(); fragment_shapes[buffer] = shape_str->value; diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index 0e9afd3ee2..40c4cbd3e1 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -1177,7 +1177,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) { - if (op->attr_key == "threadblock_swizzle_pattern") { + if (op->attr_key == tl::attr::kLexicalAllocScope) { + PrintIndent(); + stream << "{\n"; + int scope = BeginScope(); + PrintStmt(op->body); + EndScope(scope); + PrintIndent(); + stream << "}\n"; + return; + } else if (op->attr_key == "threadblock_swizzle_pattern") { this->PrintIndent(); std::string func_name; int panel_size = 0; diff --git a/src/transform/lower_opaque_block.cc b/src/transform/lower_opaque_block.cc index 05a84f77fc..c019819bff 100644 --- a/src/transform/lower_opaque_block.cc +++ b/src/transform/lower_opaque_block.cc @@ -104,7 +104,16 @@ class OpaqueBlockLower : public StmtExprMutator { body = Allocate(buffer->data, buffer->dtype, allocation_shape, const_true(), std::move(body), allocate_annotations); } - // Step 5. Insert attribute statements converted from pragmas + // Step 5. Materialize a lexical scope boundary only for blocks that were + // explicitly marked by an earlier semantic lowering pass (for example + // gemm_py/gemm_sp_py). We intentionally avoid re-inferring this from the + // lowered alloc_buffers here because provenance has already been blurred by + // earlier allocation planning/hoisting passes. + if (HasLexicalAllocScopeAnnotation(new_block->annotations)) { + body = AttrStmt(Integer(0), tl::attr::kLexicalAllocScope, Integer(1), + std::move(body)); + } + // Step 6. Insert attribute statements converted from pragmas for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) { body = AttrStmt(Integer(0), it->first, it->second, std::move(body)); } @@ -276,6 +285,11 @@ class OpaqueBlockLower : public StmtExprMutator { return preserved_annotations; } + static bool + HasLexicalAllocScopeAnnotation(const Map &annotations) { + return annotations.find(tl::attr::kLexicalAllocScope) != annotations.end(); + } + Buffer ResolveLocalVarBuffer(const Array &alloc_buffers) const { for (const Buffer &buffer : alloc_buffers) { std::string scope = buffer.scope(); diff --git a/src/transform/plan_update_buffer_allocation_location.cc b/src/transform/plan_update_buffer_allocation_location.cc index 1d12aac6eb..5575f6785f 100644 --- a/src/transform/plan_update_buffer_allocation_location.cc +++ b/src/transform/plan_update_buffer_allocation_location.cc @@ -117,6 +117,72 @@ class BufferAllocateOrderCollector : public StmtExprVisitor { ffi::Array buffer_alloc_recorder_; }; +/*! \brief Collect scope parent links and buffer vars referenced in for headers. + * + * Allocations attached to a ForNode are injected into the loop body. If a + * buffer var is referenced by the loop header itself (e.g. in the extent), + * the allocation must therefore be placed at an outer scope instead. + */ +class ScopePlacementInfoCollector : public StmtExprVisitor { +public: + static ScopePlacementInfoCollector Collect(const PrimFunc &func) { + ScopePlacementInfoCollector collector; + collector.scope_stack_.push_back(nullptr); + collector(func->body); + return collector; + } + + std::unordered_map parent_scope_; + std::unordered_map> + for_header_vars_; + +private: + void VisitStmt_(const BlockRealizeNode *op) final { + parent_scope_[op->block.get()] = scope_stack_.back(); + scope_stack_.push_back(op->block.get()); + if (!is_one(op->predicate)) { + this->VisitExpr(op->predicate); + } + this->VisitStmt(op->block->body); + scope_stack_.pop_back(); + } + + void VisitStmt_(const ForNode *op) final { + parent_scope_[op] = scope_stack_.back(); + const ForNode *prev_for_header = current_for_header_; + current_for_header_ = op; + this->VisitExpr(op->min); + this->VisitExpr(op->extent); + for (const auto &kv : op->annotations) { + if (auto expr = kv.second.try_cast()) { + this->VisitExpr(expr.value()); + } + } + current_for_header_ = prev_for_header; + + scope_stack_.push_back(op); + this->VisitStmt(op->body); + scope_stack_.pop_back(); + } + + void VisitExpr_(const BufferLoadNode *op) final { + if (current_for_header_ != nullptr) { + for_header_vars_[current_for_header_].insert(op->buffer->data.get()); + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const VarNode *op) final { + if (current_for_header_ != nullptr) { + for_header_vars_[current_for_header_].insert(op); + } + StmtExprVisitor::VisitExpr_(op); + } + + std::vector scope_stack_; + const ForNode *current_for_header_{nullptr}; +}; + class BufferAllocationLocator : public StmtExprMutator { public: explicit BufferAllocationLocator(const PrimFunc &func) { @@ -132,8 +198,12 @@ class BufferAllocationLocator : public StmtExprMutator { BufferAllocateOrderCollector::Collect(func); std::unordered_set arg_buffer_vars; CollectManagedAllocations collector; + ScopePlacementInfoCollector scope_info = + ScopePlacementInfoCollector::Collect(func); collector(func->body); managed_allocations_ = collector.managed_allocations; + parent_scope_ = std::move(scope_info.parent_scope_); + for_header_vars_ = std::move(scope_info.for_header_vars_); for (const auto &kv : func->buffer_map) { const Buffer &buffer = kv.second; @@ -161,6 +231,7 @@ class BufferAllocationLocator : public StmtExprMutator { stmt = (*bit).second.get(); } } + stmt = ResolveAllocationSite(buffer->data.get(), stmt); if (stmt != nullptr || vit != var_lca.end()) { // Skip moving allocations that are already bound to func arguments. if (arg_buffer_vars.count(buffer->data.get())) { @@ -229,6 +300,24 @@ class BufferAllocationLocator : public StmtExprMutator { return out; } + const StmtNode *ResolveAllocationSite(const VarNode *buffer_var, + const StmtNode *stmt) const { + while (stmt != nullptr && stmt->IsInstance()) { + const auto *for_node = static_cast(stmt); + auto header_it = for_header_vars_.find(for_node); + if (header_it == for_header_vars_.end() || + !header_it->second.count(buffer_var)) { + break; + } + auto parent_it = parent_scope_.find(stmt); + if (parent_it == parent_scope_.end() || parent_it->second == nullptr) { + break; + } + stmt = parent_it->second; + } + return stmt; + } + Stmt VisitStmt_(const ForNode *op) final { auto it = alloc_buffers_.find(op); if (it == alloc_buffers_.end()) { @@ -346,6 +435,11 @@ class BufferAllocationLocator : public StmtExprMutator { /*! \brief The map from stmt to the buffers to be allocated under it. */ std::unordered_map> alloc_buffers_; + /*! \brief Parent scope for each For/Block stmt. */ + std::unordered_map parent_scope_; + /*! \brief Buffer vars referenced in the header of each For stmt. */ + std::unordered_map> + for_header_vars_; /*! \brief Stack of buffers per data var for scoping correctness. */ ffi::Map> buffer_data_to_buffers_; /*! \brief Buffers that are allocated within a BlockNode, and may be moved. */ diff --git a/src/transform/reuse_local_descriptor_allocations.cc b/src/transform/reuse_local_descriptor_allocations.cc deleted file mode 100644 index a4e2dae3d6..0000000000 --- a/src/transform/reuse_local_descriptor_allocations.cc +++ /dev/null @@ -1,254 +0,0 @@ -/*! - * \file reuse_local_descriptor_allocations.cc - * \brief Pool lexically-disjoint local descriptor allocations. - */ - -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "runtime/thread_storage_scope.h" -#include "tir/transforms/ir_utils.h" - -namespace tvm { -namespace tl { - -using namespace tir; -namespace refl = tvm::ffi::reflection; - -namespace { - -bool IsLocalDescriptorScope(const Var &buffer_var) { - std::string scope = GetPtrStorageScope(buffer_var); - return scope.rfind("local.descriptor.", 0) == 0; -} - -bool IsDescriptorHoistBoundary(const AttrStmtNode *op) { - return op->attr_key == tir::attr::thread_extent || - op->attr_key == tir::attr::virtual_thread || op->attr_key == "target"; -} - -bool IsReusableDescriptorAllocate(const AllocateNode *op) { - return IsLocalDescriptorScope(op->buffer_var) && is_one(op->condition) && - op->annotations.empty() && op->ConstantAllocationSize() > 0; -} - -std::string MakeDescriptorSignature(const AllocateNode *op) { - const DataType &dtype = op->dtype; - return GetPtrStorageScope(op->buffer_var) + "|" + - std::to_string(dtype.code()) + ":" + std::to_string(dtype.bits()) + - ":" + std::to_string(dtype.lanes()) + "|" + - std::to_string(op->ConstantAllocationSize()); -} - -struct AllocSite { - Var var; - DataType dtype; - ffi::Array extents; - ffi::Map annotations; - std::string signature; -}; - -class DescriptorAllocCollector : public StmtExprVisitor { -public: - static std::vector Collect(const Stmt &stmt) { - DescriptorAllocCollector collector; - collector(stmt); - return std::move(collector.allocs_); - } - -private: - void VisitStmt_(const AllocateNode *op) final { - if (IsReusableDescriptorAllocate(op)) { - allocs_.push_back(AllocSite{op->buffer_var, op->dtype, op->extents, - op->annotations, - MakeDescriptorSignature(op)}); - } - StmtExprVisitor::VisitStmt_(op); - } - - void VisitStmt_(const AttrStmtNode *op) final { - if (IsDescriptorHoistBoundary(op)) { - return; - } - StmtExprVisitor::VisitStmt_(op); - } - - std::vector allocs_; -}; - -class DescriptorVarRemapper : public StmtExprMutator { -public: - DescriptorVarRemapper(std::unordered_map var_remap, - std::unordered_set removed_allocs) - : var_remap_(std::move(var_remap)), - removed_allocs_(std::move(removed_allocs)) {} - -private: - PrimExpr VisitExpr_(const VarNode *op) final { - if (auto it = var_remap_.find(op); it != var_remap_.end()) { - return it->second; - } - return tvm::ffi::GetRef(op); - } - - Stmt VisitStmt_(const AllocateNode *op) final { - if (removed_allocs_.count(op->buffer_var.get())) { - return VisitStmt(op->body); - } - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const DeclBufferNode *op) final { - auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - Buffer new_buffer = RemapBuffer(node->buffer); - if (!new_buffer.same_as(node->buffer)) { - node.CopyOnWrite()->buffer = new_buffer; - } - return std::move(node); - } - - PrimExpr VisitExpr_(const BufferLoadNode *op) final { - auto node = Downcast(StmtExprMutator::VisitExpr_(op)); - Buffer new_buffer = RemapBuffer(node->buffer); - if (!new_buffer.same_as(node->buffer)) { - node.CopyOnWrite()->buffer = new_buffer; - } - return std::move(node); - } - - Stmt VisitStmt_(const BufferStoreNode *op) final { - auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - Buffer new_buffer = RemapBuffer(node->buffer); - if (!new_buffer.same_as(node->buffer)) { - node.CopyOnWrite()->buffer = new_buffer; - } - return std::move(node); - } - - Buffer RemapBuffer(Buffer buffer) const { - if (auto it = var_remap_.find(buffer->data.get()); it != var_remap_.end()) { - Buffer new_buffer = buffer; - new_buffer.CopyOnWrite()->data = it->second; - return new_buffer; - } - return buffer; - } - - std::unordered_map var_remap_; - std::unordered_set removed_allocs_; -}; - -class ReuseLocalDescriptorAllocationsMutator : public StmtExprMutator { -public: - static PrimFunc Rewrite(PrimFunc func) { - auto fptr = func.CopyOnWrite(); - ReuseLocalDescriptorAllocationsMutator rewriter; - fptr->body = rewriter(std::move(fptr->body)); - return func; - } - -private: - struct PoolSlot { - AllocSite canonical; - int use_count{0}; - }; - - Stmt VisitStmt_(const SeqStmtNode *op) final { - ffi::Array visited_children; - visited_children.reserve(op->seq.size()); - for (const Stmt &stmt : op->seq) { - visited_children.push_back(VisitStmt(stmt)); - } - - std::unordered_map> signature_slots; - std::unordered_map alloc_to_slot; - std::vector slots; - - for (const Stmt &stmt : visited_children) { - std::unordered_map local_slot_index; - for (const AllocSite &alloc : DescriptorAllocCollector::Collect(stmt)) { - int ordinal = local_slot_index[alloc.signature]++; - std::vector &sig_slots = signature_slots[alloc.signature]; - if (static_cast(sig_slots.size()) <= ordinal) { - sig_slots.push_back(static_cast(slots.size())); - slots.push_back(PoolSlot{alloc, 0}); - } - int slot_idx = sig_slots[ordinal]; - alloc_to_slot[alloc.var.get()] = slot_idx; - ++slots[slot_idx].use_count; - } - } - - std::unordered_map var_remap; - std::unordered_set removed_allocs; - std::vector hoisted_allocs; - hoisted_allocs.reserve(slots.size()); - - for (const PoolSlot &slot : slots) { - if (slot.use_count <= 1) { - continue; - } - removed_allocs.insert(slot.canonical.var.get()); - hoisted_allocs.push_back(slot.canonical); - } - - if (hoisted_allocs.empty()) { - return visited_children.size() == 1 ? visited_children[0] - : SeqStmt(visited_children); - } - - for (const auto &[var, slot_idx] : alloc_to_slot) { - if (slots[slot_idx].use_count <= 1) { - continue; - } - removed_allocs.insert(var); - const Var &canonical_var = slots[slot_idx].canonical.var; - if (var != canonical_var.get()) { - var_remap[var] = canonical_var; - } - } - - DescriptorVarRemapper rewriter(std::move(var_remap), - std::move(removed_allocs)); - ffi::Array rewritten_children; - rewritten_children.reserve(visited_children.size()); - for (const Stmt &stmt : visited_children) { - rewritten_children.push_back(rewriter(stmt)); - } - - Stmt body = rewritten_children.size() == 1 ? rewritten_children[0] - : SeqStmt(rewritten_children); - for (auto it = hoisted_allocs.rbegin(); it != hoisted_allocs.rend(); ++it) { - body = Allocate(it->var, it->dtype, it->extents, const_true(), - std::move(body), it->annotations); - } - return body; - } -}; - -} // namespace - -tir::transform::Pass ReuseLocalDescriptorAllocations() { - auto pass_func = [](PrimFunc func, IRModule mod, - tvm::transform::PassContext ctx) { - return ReuseLocalDescriptorAllocationsMutator::Rewrite(std::move(func)); - }; - return tir::transform::CreatePrimFuncPass( - pass_func, 0, "tl.ReuseLocalDescriptorAllocations", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - refl::GlobalDef().def("tl.transform.ReuseLocalDescriptorAllocations", - ReuseLocalDescriptorAllocations); -} - -} // namespace tl -} // namespace tvm diff --git a/src/transform/storage_rewrite.cc b/src/transform/storage_rewrite.cc index ee5554d593..d494fdac32 100644 --- a/src/transform/storage_rewrite.cc +++ b/src/transform/storage_rewrite.cc @@ -285,6 +285,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == tir::attr::virtual_thread) { VisitNewScope(op); + } else if (op->attr_key == tl::attr::kLexicalAllocScope) { + VisitNewScope(op); } else { StmtExprVisitor::VisitStmt_(op); } @@ -575,6 +577,7 @@ class StoragePlanRewriter : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode *op) final { if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread || + op->attr_key == tl::attr::kLexicalAllocScope || tir::attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. if (attach_map_.count(op)) { @@ -954,6 +957,21 @@ class StoragePlanRewriter : public StmtExprMutator { } } + /*! \brief Return the effective attach scope for the given storage scope. + * + * lexical_alloc_scope is intended to bound register/local-like allocations. + * Shared/global allocations should continue to follow thread_scope_ so we do + * not accidentally re-scope shared buffers nested inside a lexical block. + */ + const Object *effective_scope(const StorageScope &storage_scope) const { + if (lexical_scope_ != nullptr && + storage_scope.rank != StorageRank::kGlobal && + storage_scope.rank != StorageRank::kShared) { + return lexical_scope_; + } + return thread_scope_; + } + // Memory plan algorithm void PlanMemory(const std::vector &seq, @@ -990,7 +1008,8 @@ class StoragePlanRewriter : public StmtExprMutator { InplaceOpVerifier visitor; StorageEntry *src_entry = alloc_map_.at(src); if (src_entry->scope == storage_scope && - src_entry->attach_scope_ == thread_scope_ && + src_entry->attach_scope_ == + effective_scope(storage_scope) && src_entry->elem_type == alloc->dtype.element_of() && visitor.Check(s.stmt, var, src)) { uint64_t const_nbits = @@ -1007,9 +1026,10 @@ class StoragePlanRewriter : public StmtExprMutator { } } if (dst_entry == nullptr) { - dst_entry = FindAlloc(alloc, thread_scope_, storage_scope, - entry.num_physical_dimensions, enable_reuse, - reuse_require_exact_matched_dtype); + dst_entry = + FindAlloc(alloc, effective_scope(storage_scope), storage_scope, + entry.num_physical_dimensions, enable_reuse, + reuse_require_exact_matched_dtype); } dst_entry->allocs.emplace_back(alloc); alloc_map_[var] = dst_entry; @@ -1022,6 +1042,33 @@ class StoragePlanRewriter : public StmtExprMutator { op->attr_key == tir::attr::virtual_thread || tir::attr::IsPragmaKey(op->attr_key)) { PlanNewScope(op); + } else if (op->attr_key == tl::attr::kLexicalAllocScope) { + if (s.scope_pair_offset > 0) { + // Entering: redirect allocation attachment to this scope. + // thread_scope_ is NOT touched so PlanNewScope keeps working. + lexical_scope_stack_.push_back(lexical_scope_); + lexical_scope_ = op; + } else { + // Exiting: clear free lists for this scope and restore. + for (auto it = const_free_map_.begin(); + it != const_free_map_.end();) { + if (it->second->attach_scope_ == op) { + it = const_free_map_.erase(it); + } else { + ++it; + } + } + for (auto it = sym_free_list_.begin(); + it != sym_free_list_.end();) { + if ((*it)->attach_scope_ == op) { + it = sym_free_list_.erase(it); + } else { + ++it; + } + } + lexical_scope_ = lexical_scope_stack_.back(); + lexical_scope_stack_.pop_back(); + } } else { ICHECK(op->attr_key == tir::attr::extern_scope); } @@ -1179,6 +1226,11 @@ class StoragePlanRewriter : public StmtExprMutator { } // thread scope. const Object *thread_scope_{nullptr}; + // Current lexical scope (set by lexical_alloc_scope, independent of + // thread_scope_ so that PlanNewScope's toggle protocol is preserved). + const Object *lexical_scope_{nullptr}; + // Stack for nested lexical scopes. + std::vector lexical_scope_stack_; // whether enable inplace detection. bool detect_inplace_{false}; // Locations of free ops. diff --git a/testing/python/transform/test_tilelang_transform_lexical_alloc_scope.py b/testing/python/transform/test_tilelang_transform_lexical_alloc_scope.py new file mode 100644 index 0000000000..946ada4453 --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_lexical_alloc_scope.py @@ -0,0 +1,429 @@ +"""Tests for the lexical_alloc_scope feature. + +Verifies that: +1. LowerOpaqueBlock inserts AttrStmt("lexical_alloc_scope") for blocks + explicitly marked for lexical allocation scoping. +2. Unmarked blocks do NOT receive the marker by heuristic inference. +3. gemm_py-produced blocks are explicitly marked and survive as scopes. +4. StorageRewrite does not hoist allocations past the scope boundary. +5. CUDA codegen emits { ... } for the scoped block. +""" + +import tilelang as tl +import tilelang.language as T +from tilelang import tvm +from tilelang.engine.phase import LowerAndLegalize +from tvm.tir.stmt_functor import post_order_visit +import tilelang.testing + + +def _count_attrs(func, attr_key): + """Count occurrences of a specific AttrStmt key in the function body.""" + count = [0] + + def _visit(node): + if isinstance(node, tvm.tir.AttrStmt) and str(node.attr_key) == attr_key: + count[0] += 1 + + post_order_visit(func.body, _visit) + return count[0] + + +def _count_allocate_inside_attr(func, attr_key): + """Count Allocate nodes that are (transitively) nested inside the given AttrStmt.""" + count = [0] + inside = [False] + + def _visit(node): + if isinstance(node, tvm.tir.AttrStmt) and str(node.attr_key) == attr_key: + old = inside[0] + inside[0] = True + post_order_visit(node.body, _visit) + inside[0] = old + elif isinstance(node, tvm.tir.Allocate) and inside[0]: + count[0] += 1 + + post_order_visit(func.body, _visit) + return count[0] + + +def _apply_lower_opaque_pipeline(func, target, pass_configs=None): + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + pass_configs = pass_configs or {} + with target, tvm.transform.PassContext(config=pass_configs): + mod = LowerAndLegalize(mod, target) + mod = tl.transform.LowerSharedTmem()(mod) + mod = tl.transform.IfStmtBinding()(mod) + mod = tl.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tl.transform.LowerSharedBarrier()(mod) + mod = tl.transform.HoistGlobalBufferAllocations()(mod) + mod = tl.transform.LowerOpaqueBlock()(mod) + return mod + + +# --------------------------------------------------------------------------- +# Test 1: LowerOpaqueBlock inserts the lexical_alloc_scope marker for an +# explicitly annotated block. +# --------------------------------------------------------------------------- +def test_lower_opaque_block_inserts_lexical_alloc_scope_for_explicit_block(): + """An explicitly annotated block should produce a lexical_alloc_scope.""" + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def func( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + for _ in T.serial(4): + with T.block(): + T.block_attr({"lexical_alloc_scope": 1}) + S = T.alloc_buffer((128,), dtype=T.float32, scope="local") + S[tx] = A[tx] + B[tx] = S[tx] + + mod = tvm.IRModule.from_expr(func) + mod = tl.transform.LowerOpaqueBlock()(mod) + lowered = mod["main"] + + n = _count_attrs(lowered, "lexical_alloc_scope") + assert n >= 1, f"Expected at least 1 lexical_alloc_scope AttrStmt, got {n}" + + # The Allocate for S should be inside the scope + n_alloc = _count_allocate_inside_attr(lowered, "lexical_alloc_scope") + assert n_alloc >= 1, f"Expected Allocate inside lexical_alloc_scope, got {n_alloc}" + + +# --------------------------------------------------------------------------- +# Test 2: An unmarked block with local alloc should NOT get the marker +# --------------------------------------------------------------------------- +def test_lower_opaque_block_skips_unmarked_local_alloc(): + """An unmarked local-alloc block should not produce a lexical_alloc_scope.""" + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def func( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + for _ in T.serial(4): + with T.block(): + S = T.alloc_buffer((128,), dtype=T.float32, scope="local") + S[tx] = A[tx] + B[tx] = S[tx] + + mod = tvm.IRModule.from_expr(func) + mod = tl.transform.LowerOpaqueBlock()(mod) + lowered = mod["main"] + + n = _count_attrs(lowered, "lexical_alloc_scope") + assert n == 0, f"Expected 0 lexical_alloc_scope AttrStmt for unmarked local block, got {n}" + + +# --------------------------------------------------------------------------- +# Test 3: Block without alloc_buffers should NOT get the marker +# --------------------------------------------------------------------------- +def test_lower_opaque_block_skips_empty_alloc(): + """A block without alloc_buffers should not produce a lexical_alloc_scope.""" + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def func( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + for _ in T.serial(4): + with T.block(): + B[tx] = A[tx] + + mod = tvm.IRModule.from_expr(func) + mod = tl.transform.LowerOpaqueBlock()(mod) + lowered = mod["main"] + + n = _count_attrs(lowered, "lexical_alloc_scope") + assert n == 0, f"Expected 0 lexical_alloc_scope AttrStmt for empty block, got {n}" + + +# --------------------------------------------------------------------------- +# Test 4: GEMM descriptor allocs inside loop should get the marker +# --------------------------------------------------------------------------- +def test_lower_opaque_block_inserts_scope_for_gemm_descriptor_alloc(): + """Lowered WGMMA descriptor buffers inside a loop should trigger lexical_alloc_scope.""" + target = tvm.target.Target("cuda -arch=sm_90a") + + @T.prim_func + def func( + A: T.Tensor((64, 16), T.bfloat16), + B: T.Tensor((64, 16), T.bfloat16), + C: T.Tensor((64, 64), T.bfloat16), + ): + with T.Kernel(1, threads=128): + A_shared = T.alloc_shared((64, 16), T.bfloat16) + B_shared = T.alloc_shared((64, 16), T.bfloat16) + C_local = T.alloc_fragment((64, 64), T.float32) + T.clear(C_local) + for _ in T.serial(2): + T.copy(A[0, 0], A_shared) + T.copy(B[0, 0], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + T.copy(C_local, C[0:64, 0:64]) + + mod = _apply_lower_opaque_pipeline(func, target) + lowered = mod["main"] + + n = _count_attrs(lowered, "lexical_alloc_scope") + assert n >= 1, f"Expected lexical_alloc_scope for lowered GEMM descriptor alloc, got {n}" + + +# --------------------------------------------------------------------------- +# Test 5: local.var-only block inside loop should NOT get the marker +# --------------------------------------------------------------------------- +def test_lower_opaque_block_skips_local_var_only_alloc(): + """A block that allocates only local.var should not get lexical_alloc_scope.""" + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def func( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + for _ in T.serial(4): + with T.block(): + idx = T.alloc_var(T.int32) + idx = tx + B[tx] = A[idx] + + mod = tvm.IRModule.from_expr(func) + mod = tl.transform.LowerOpaqueBlock()(mod) + lowered = mod["main"] + + n = _count_attrs(lowered, "lexical_alloc_scope") + assert n == 0, f"Expected 0 lexical_alloc_scope for local.var-only block, got {n}" + + +# --------------------------------------------------------------------------- +# Test 6: top-level explicitly annotated local alloc should get the marker +# --------------------------------------------------------------------------- +def test_lower_opaque_block_marks_explicit_top_level_local_alloc(): + """A top-level explicitly annotated local alloc should get lexical_alloc_scope.""" + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def func( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + with T.block(): + T.block_attr({"lexical_alloc_scope": 1}) + S = T.alloc_buffer((128,), dtype=T.float32, scope="local") + S[tx] = A[tx] + B[tx] = S[tx] + + mod = tvm.IRModule.from_expr(func) + mod = tl.transform.LowerOpaqueBlock()(mod) + lowered = mod["main"] + + n = _count_attrs(lowered, "lexical_alloc_scope") + assert n >= 1, f"Expected lexical_alloc_scope for top-level local block, got {n}" + + +# --------------------------------------------------------------------------- +# Test 7: top-level fragment alloc should not force a lexical scope +# --------------------------------------------------------------------------- +def test_lower_opaque_block_skips_fragment_alloc(): + """A fragment alloc should not force lexical_alloc_scope by itself.""" + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def func( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + with T.block(): + S = T.alloc_buffer((128,), dtype=T.float32, scope="local.fragment") + S[tx] = A[tx] + B[tx] = S[tx] + + mod = tvm.IRModule.from_expr(func) + mod = tl.transform.LowerOpaqueBlock()(mod) + lowered = mod["main"] + + n = _count_attrs(lowered, "lexical_alloc_scope") + assert n == 0, f"Expected no lexical_alloc_scope for fragment-only block, got {n}" + + +# --------------------------------------------------------------------------- +# Test 8: disable-ws pipelined GEMM should not wrap the fragment root block +# --------------------------------------------------------------------------- +def test_lower_opaque_block_skips_fragment_root_in_disable_ws_pipeline(): + """A fragment root block should not force lexical_alloc_scope in disable-ws pipeline.""" + target = tvm.target.Target("cuda -arch=sm_90a") + pass_configs = {tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED.value: True} + + @T.prim_func + def func( + A: T.Tensor((256, 256), T.bfloat16), + B: T.Tensor((128, 256), T.bfloat16), + C: T.Tensor((256, 128), T.bfloat16), + ): + with T.Kernel(1, threads=256): + A_shared = T.alloc_shared((256, 128), T.bfloat16) + B_shared = T.alloc_shared((128, 128), T.bfloat16) + C_local = T.alloc_fragment((256, 128), T.float32) + C_shared = T.alloc_shared((256, 128), T.bfloat16) + T.clear(C_local) + for k in T.Pipelined(2, num_stages=2): + T.copy(A[0, k * 128], A_shared) + T.copy(B[0, k * 128], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + T.copy(C_local, C_shared) + T.copy(C_shared, C[0:256, 0:128]) + + mod = _apply_lower_opaque_pipeline(func, target, pass_configs=pass_configs) + lowered = mod["main"] + lowered_script = lowered.script(show_meta=False) + + assert 'T.attr(0, "lexical_alloc_scope", 1)\n C_local = T.decl_buffer' not in lowered_script, ( + "Unexpected top-level lexical_alloc_scope around fragment-backed C_local" + ) + assert lowered_script.count("lexical_alloc_scope") >= 2, "Expected inner GEMM lexical scopes to remain in the disable-ws pipeline" + + +# --------------------------------------------------------------------------- +# Test 9: StorageRewrite preserves lexical_alloc_scope +# --------------------------------------------------------------------------- +def test_storage_rewrite_preserves_scope(): + """lexical_alloc_scope should survive StorageRewrite without crashing.""" + target = tvm.target.Target("cuda -arch=sm_80") + + @T.prim_func + def func( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + ): + T.func_attr({"global_symbol": "main", "target": target}) + T.launch_thread("blockIdx.x", 1) + tx = T.launch_thread("threadIdx.x", 128) + for _ in T.serial(4): + with T.block(): + T.block_attr({"lexical_alloc_scope": 1}) + S = T.alloc_buffer((128,), dtype=T.float32, scope="local") + S[tx] = A[tx] + B[tx] = S[tx] + + mod = tvm.IRModule.from_expr(func) + mod = tl.transform.LowerOpaqueBlock()(mod) + mod = tl.transform.Simplify()(mod) + mod = tl.transform.FlattenBuffer()(mod) + mod = tl.transform.VectorizeLoop()(mod) + mod = tl.transform.StorageRewrite()(mod) + lowered = mod["main"] + + # The scope marker should still be present after StorageRewrite + n = _count_attrs(lowered, "lexical_alloc_scope") + assert n >= 1, f"Expected lexical_alloc_scope to survive StorageRewrite, got {n}" + + +# --------------------------------------------------------------------------- +# Test 10: CUDA codegen emits { } for the scope +# --------------------------------------------------------------------------- +@tilelang.testing.requires_cuda +def test_codegen_emits_braces(): + """The generated CUDA source should contain scoped { } blocks for explicitly marked allocs.""" + + @T.prim_func + def func( + A: T.Tensor((128, 4), T.float32), + B: T.Tensor((128, 4), T.float32), + ): + with T.Kernel(1, threads=128): + for k in T.serial(4): + with T.block(): + T.block_attr({"lexical_alloc_scope": 1}) + S = T.alloc_buffer((128,), dtype=T.float32, scope="local") + S[T.get_thread_binding()] = A[T.get_thread_binding(), k] + B[T.get_thread_binding(), k] = S[T.get_thread_binding()] + + kernel = tilelang.compile(func, out_idx=[1], target="cuda") + src = kernel.get_kernel_source() + print("=== lexical_alloc_scope codegen ===") + print(src) + import re + + standalone_open_braces = re.findall(r"^\s*\{\s*$", src, re.MULTILINE) + assert len(standalone_open_braces) >= 1, f"Expected at least 1 standalone '{{' for lexical scope, found {len(standalone_open_braces)}" + + +@tilelang.testing.requires_cuda +def test_codegen_skips_redundant_top_level_braces(): + """The outermost top-level lexical scope should not emit a redundant brace block.""" + + @T.prim_func + def func( + A: T.Tensor((128, 4), T.float32), + B: T.Tensor((128, 4), T.float32), + ): + with T.Kernel(1, threads=128): + C = T.alloc_fragment((128,), T.float32) + T.clear(C) + for k in T.serial(4): + with T.block(): + T.block_attr({"lexical_alloc_scope": 1}) + S = T.alloc_buffer((128,), dtype=T.float32, scope="local") + S[T.get_thread_binding()] = A[T.get_thread_binding(), k] + C[T.get_thread_binding()] = S[T.get_thread_binding()] + for k in T.serial(4): + B[T.get_thread_binding(), k] = C[T.get_thread_binding()] + + kernel = tilelang.compile(func, out_idx=[1], target="cuda") + src = kernel.get_kernel_source() + print("=== top-level lexical_alloc_scope codegen ===") + print(src) + import re + + assert re.search(r"^\s*float [A-Za-z_]\w*\[\d+\];\s*$", src, re.MULTILINE), ( + "Expected top-level fragment allocation to stay directly in function scope" + ) + kernel_match = re.search( + r'extern "C" __global__ void(?: __launch_bounds__\([^)]*\))? [A-Za-z_]\w*_kernel\(', + src, + ) + assert kernel_match, "Expected generated CUDA source to contain a kernel signature" + body_open_idx = src.find("{", kernel_match.start()) + assert body_open_idx >= 0, "Expected generated CUDA kernel body" + first_nonempty = next(line.strip() for line in src[body_open_idx + 1 :].splitlines() if line.strip()) + assert first_nonempty != "{", "Unexpected redundant top-level lexical scope brace" + standalone_open_braces = re.findall(r"^\s*\{\s*$", src, re.MULTILINE) + assert len(standalone_open_braces) >= 1, "Expected inner lexical scopes to still emit standalone braces" + + +if __name__ == "__main__": + test_lower_opaque_block_inserts_lexical_alloc_scope_for_explicit_block() + test_lower_opaque_block_skips_unmarked_local_alloc() + test_lower_opaque_block_skips_empty_alloc() + test_lower_opaque_block_inserts_scope_for_gemm_descriptor_alloc() + test_lower_opaque_block_skips_local_var_only_alloc() + test_lower_opaque_block_marks_explicit_top_level_local_alloc() + test_lower_opaque_block_skips_fragment_alloc() + test_lower_opaque_block_skips_fragment_root_in_disable_ws_pipeline() + test_storage_rewrite_preserves_scope() + test_codegen_emits_braces() + test_codegen_skips_redundant_top_level_braces() + print("All tests passed!") diff --git a/testing/python/transform/test_tilelang_transform_plan_update_buffer_allocation_location.py b/testing/python/transform/test_tilelang_transform_plan_update_buffer_allocation_location.py new file mode 100644 index 0000000000..e3d53b813f --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_plan_update_buffer_allocation_location.py @@ -0,0 +1,71 @@ +import tilelang as tl +import tilelang.language as T +import tilelang.testing +from tilelang import tvm +from tilelang.engine.phase import LowerAndLegalize + + +def _apply_plan_update(func: tvm.tir.PrimFunc) -> tvm.IRModule: + target = tvm.target.Target("cuda") + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + with target: + mod = LowerAndLegalize(mod, target) + mod = tl.transform.LowerSharedTmem()(mod) + mod = tl.transform.IfStmtBinding()(mod) + mod = tl.transform.PlanAndUpdateBufferAllocationLocation()(mod) + return mod + + +def _find_block(stmt: tvm.tir.Stmt, name_hint: str) -> tvm.tir.Block: + blocks = [] + + def _visit(node): + if isinstance(node, tvm.tir.Block) and str(node.name_hint) == name_hint: + blocks.append(node) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + assert len(blocks) == 1, f"Expected exactly one block named {name_hint}, got {len(blocks)}" + return blocks[0] + + +def _find_first_for(stmt: tvm.tir.Stmt) -> tvm.tir.For: + loops = [] + + def _visit(node): + if isinstance(node, tvm.tir.For): + loops.append(node) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + assert loops, "Expected at least one loop" + return loops[0] + + +def test_plan_update_keeps_loop_header_local_var_outside_loop_body(): + @T.prim_func + def func(x: T.Tensor((256,), "int64")): + with T.Kernel(256, threads=128): + a, b = T.alloc_var(T.int), T.alloc_var(T.int) + T.fill(x[a:b], 0) + + mod = _apply_plan_update(func) + main = mod["main"] + + tilelang_root = _find_block(main.body, "tilelang_root") + root_local_vars = {buf.name for buf in tilelang_root.alloc_buffers if buf.scope() == "local.var"} + assert {"a", "b"} <= root_local_vars + + loop = _find_first_for(main.body) + loop_body_local_vars = set() + + def _visit_loop_body(node): + if isinstance(node, tvm.tir.Block): + for buf in node.alloc_buffers: + if buf.scope() == "local.var": + loop_body_local_vars.add(buf.name) + + tvm.tir.stmt_functor.post_order_visit(loop.body, _visit_loop_body) + assert "b" not in loop_body_local_vars + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_reuse_local_descriptor_allocations.py b/testing/python/transform/test_tilelang_transform_reuse_local_descriptor_allocations.py deleted file mode 100644 index 940baa88a1..0000000000 --- a/testing/python/transform/test_tilelang_transform_reuse_local_descriptor_allocations.py +++ /dev/null @@ -1,105 +0,0 @@ -# ruff: noqa -from tilelang import tvm as tvm -import tilelang as tl -from tilelang.utils.target import determine_target -import tilelang.language as T - - -auto_target = tvm.target.Target(determine_target("auto")) - - -def _check(original, transformed): - mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) - mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.ReuseLocalDescriptorAllocations()(mod) - - expected = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) - expected = tvm.tir.transform.BindTarget(auto_target)(expected) - - tvm.ir.assert_structural_equal(mod["main"], expected["main"], True) - - -def test_reuse_local_descriptor_allocations(): - @T.prim_func - def before(): - T.func_attr({"tir.noalias": True}) - with T.attr(0, "test.region", 0): - desc_a = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_b = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_a_buf = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") - desc_b_buf = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") - T.initialize_wgmma_descriptor(desc_a_buf[0], T.uint64(0), 1, 1, 64) - T.initialize_wgmma_descriptor(desc_b_buf[0], T.uint64(0), 1, 1, 64) - T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) - with T.attr(0, "test.region", 1): - desc_a_1 = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_b_1 = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_a_buf_1 = T.Buffer((1,), "uint64", data=desc_a_1, scope="local.descriptor.wgmma") - desc_b_buf_1 = T.Buffer((1,), "uint64", data=desc_b_1, scope="local.descriptor.wgmma") - T.initialize_wgmma_descriptor(desc_a_buf_1[0], T.uint64(1), 1, 1, 64) - T.initialize_wgmma_descriptor(desc_b_buf_1[0], T.uint64(1), 1, 1, 64) - T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a_1, desc_b_1)) - - @T.prim_func - def after(): - T.func_attr({"tir.noalias": True}) - desc_a = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_b = T.allocate([1], "uint64", "local.descriptor.wgmma") - with T.attr(0, "test.region", 0): - desc_a_buf = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") - desc_b_buf = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") - T.initialize_wgmma_descriptor(desc_a_buf[0], T.uint64(0), 1, 1, 64) - T.initialize_wgmma_descriptor(desc_b_buf[0], T.uint64(0), 1, 1, 64) - T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) - with T.attr(0, "test.region", 1): - desc_a_buf_1 = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") - desc_b_buf_1 = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") - T.initialize_wgmma_descriptor(desc_a_buf_1[0], T.uint64(1), 1, 1, 64) - T.initialize_wgmma_descriptor(desc_b_buf_1[0], T.uint64(1), 1, 1, 64) - T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) - - _check(before, after) - - -def test_reuse_local_descriptor_allocations_stays_inside_launch_thread(): - @T.prim_func - def before(): - T.func_attr({"tir.noalias": True}) - with T.launch_thread("blockIdx.x", 1): - with T.attr(0, "test.region", 0): - desc_a = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_b = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_a_buf = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") - desc_b_buf = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") - T.initialize_wgmma_descriptor(desc_a_buf[0], T.uint64(0), 1, 1, 64) - T.initialize_wgmma_descriptor(desc_b_buf[0], T.uint64(0), 1, 1, 64) - T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) - with T.attr(0, "test.region", 1): - desc_a_1 = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_b_1 = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_a_buf_1 = T.Buffer((1,), "uint64", data=desc_a_1, scope="local.descriptor.wgmma") - desc_b_buf_1 = T.Buffer((1,), "uint64", data=desc_b_1, scope="local.descriptor.wgmma") - T.initialize_wgmma_descriptor(desc_a_buf_1[0], T.uint64(1), 1, 1, 64) - T.initialize_wgmma_descriptor(desc_b_buf_1[0], T.uint64(1), 1, 1, 64) - T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a_1, desc_b_1)) - - @T.prim_func - def after(): - T.func_attr({"tir.noalias": True}) - with T.launch_thread("blockIdx.x", 1): - desc_a = T.allocate([1], "uint64", "local.descriptor.wgmma") - desc_b = T.allocate([1], "uint64", "local.descriptor.wgmma") - with T.attr(0, "test.region", 0): - desc_a_buf = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") - desc_b_buf = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") - T.initialize_wgmma_descriptor(desc_a_buf[0], T.uint64(0), 1, 1, 64) - T.initialize_wgmma_descriptor(desc_b_buf[0], T.uint64(0), 1, 1, 64) - T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) - with T.attr(0, "test.region", 1): - desc_a_buf_1 = T.Buffer((1,), "uint64", data=desc_a, scope="local.descriptor.wgmma") - desc_b_buf_1 = T.Buffer((1,), "uint64", data=desc_b, scope="local.descriptor.wgmma") - T.initialize_wgmma_descriptor(desc_a_buf_1[0], T.uint64(1), 1, 1, 64) - T.initialize_wgmma_descriptor(desc_b_buf_1[0], T.uint64(1), 1, 1, 64) - T.evaluate(T.call_extern("handle", "use_desc_pair", desc_a, desc_b)) - - _check(before, after) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 3153db1d7c..ec47fa3075 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -238,7 +238,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.FuseMBarrierArriveExpectTx()(mod) mod = tilelang.transform.HoistGlobalBufferAllocations()(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod) - mod = tilelang.transform.ReuseLocalDescriptorAllocations()(mod) if is_hopper(target): mod = tilelang.transform.RewriteWgmmaSync()(mod) mod = tilelang.transform.Simplify()(mod) diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index f82592548f..75ad91d244 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -165,17 +165,6 @@ def RewriteWgmmaSync(): return _ffi_api.RewriteWgmmaSync() # type: ignore -def ReuseLocalDescriptorAllocations(): - """Pool lexically-disjoint local descriptor allocations. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.ReuseLocalDescriptorAllocations() # type: ignore - - def ThreadSync(storage_scope: str): """Insert sync between parallel read/write of shared buffers. From 90299d6831c76c4125c96bdcd6812b48230e3b56 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sat, 11 Apr 2026 23:11:25 +0800 Subject: [PATCH 038/156] [Bugfix] Fix incorrect sync hoist for fragment buffer conditions in ThreadSync (#2030) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The ConditionThreadPropertyChecker in ThreadSync incorrectly classified conditions derived from fragment (local-scope) buffer loads as non-block-uniform, solely based on storage scope. This caused the sync planner to hoist __syncthreads() from inside the if-body to before the if-statement, removing the write-before-read synchronization guarantee between shared memory writes and TMA store reads. Fragment buffers commonly hold block-uniform data when populated from block-uniform global addresses (e.g., T.copy(BlockMask[blockIdx.y, :], fragment)). The fix removes the scope-based heuristic and instead relies on the recursive visit of buffer load indices — if any index depends on threadIdx, VisitExpr_(VarNode*) will correctly mark the load as non-block-uniform. Before fix: __syncthreads(); // hoisted here (too early) if (a >= 0) { write_to_shared(); // all threads tma_store(); // elected thread — no sync protection! } After fix: __syncthreads(); // loop-carried sync if (a >= 0) { write_to_shared(); // all threads __syncthreads(); // correctly placed intra-iteration sync tma_store(); // elected thread } --- src/transform/thread_storage_sync.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index c06be1eb1f..628e10b887 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -665,9 +665,13 @@ class ConditionThreadPropertyChecker : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const BufferLoadNode *op) final { current_.depends_on_runtime = true; - if (IsThreadLocalScope(GetScope(op->buffer->data))) { - current_.is_block_uniform = false; - } + // Do not mark local-scope loads as non-block-uniform solely based on + // storage scope. Thread-local buffers (fragments) commonly hold + // block-uniform data when populated from block-uniform global addresses + // (e.g., T.copy(BlockMask[blockIdx.y, :], fragment)). If the load + // indices actually depend on threadIdx, the recursive visit of indices + // below (via IRMutatorWithAnalyzer::VisitExpr_) will correctly set + // is_block_uniform = false through VisitExpr_(VarNode*). return IRMutatorWithAnalyzer::VisitExpr_(op); } From d6191644ab39738da30213e9291977235f8a02cb Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Sat, 11 Apr 2026 23:34:00 +0800 Subject: [PATCH 039/156] add .agents/skills/build/SKILL.md for build conventions (#2019) * add .agents/skills/build/SKILL.md for build conventions * Add cmake+PYTHONPATH workflow and improve build docs in SKILL.md Document the recommended C++ development workflow (cmake + PYTHONPATH), clarify editable install guidance, and add specific test invocation examples. Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: LeiWang1999 Co-authored-by: Claude Opus 4.6 (1M context) --- .agents/skills/build/SKILL.md | 83 +++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 .agents/skills/build/SKILL.md diff --git a/.agents/skills/build/SKILL.md b/.agents/skills/build/SKILL.md new file mode 100644 index 0000000000..63dde07427 --- /dev/null +++ b/.agents/skills/build/SKILL.md @@ -0,0 +1,83 @@ +# Build & Install + +## Installing / Rebuilding tilelang + +The standard way to build and install: + +```bash +pip install . +``` + +Or with verbose output for debugging build issues: + +```bash +pip install . -v +``` + +`uv pip install .` also works if `uv` is available but is not required. + +Build dependencies are declared in `pyproject.toml` and resolved automatically during `pip install .`. + +If `ccache` is available, repeated builds only recompile changed C++ files. + +## Alternative: Development Build with `--no-build-isolation` + +If you need faster iteration (e.g. calling `cmake` directly to recompile C++ without re-running the full pip install), install build dependencies first: + +```bash +pip install -r requirements-dev.txt +pip install --no-build-isolation . +``` + +After this, you can invoke `cmake --build build` directly to recompile only changed C++ files. This is useful when iterating on C++ code. + +## Alternative: cmake + PYTHONPATH (recommended for C++ development) + +For the fastest C++ iteration, bypass pip entirely and drive cmake directly: + +```bash +# Configure (auto-detects CUDA; git submodules are initialised automatically) +cmake -S . -B build + +# Build +cmake --build build -j$(nproc) + +# Make the local tilelang package importable +export PYTHONPATH=$(pwd):$PYTHONPATH +``` + +After the initial configure, recompiling is just `cmake --build build -j$(nproc)`. The runtime automatically discovers native libraries from `build/lib/` when it detects a dev checkout (see `tilelang/env.py`). + +Useful cmake options: + +| Flag | Purpose | +|------|---------| +| `-DUSE_CUDA=ON/OFF` | Enable/disable CUDA backend (ON by default) | +| `-DUSE_ROCM=ON` | Enable ROCm/HIP backend | +| `-DUSE_METAL=ON` | Enable Metal backend (default on macOS) | +| `-DCMAKE_BUILD_TYPE=Debug` | Debug build with `TVM_LOG_DEBUG` enabled | + +## Editable Installs + +**Never use `pip install -e .`** (editable install). When running Python from the repo root, the local `./tilelang` directory is imported instead of the installed copy (because `.` is on `sys.path` by default). This makes editable installs unnecessary. Avoid `pip install -e .` as it can cause import confusion with this project's layout. + +## Running Tests + +Most tests require a GPU. + +```bash +python -m pytest testing/python/ -x +``` + +Run a specific test file or test case: + +```bash +python -m pytest testing/python/language/test_tilelang_language_copy.py -x +python -m pytest testing/python/language/test_tilelang_language_copy.py -x -k "test_name" +``` + +For Metal-specific tests (requires macOS with Apple Silicon): + +```bash +python -m pytest testing/python/metal/ -x +``` From 7a515b58e21cf90f1ba941f6608f9cf583dd7d47 Mon Sep 17 00:00:00 2001 From: Zhang Jason Date: Mon, 13 Apr 2026 01:27:09 +0800 Subject: [PATCH 040/156] [AMD][gfx950] Add gfx950 support for DeepGeem example (#2028) * add MI355 support and fix some rocm releated issues * add MI355 support and fix some rocm releated issues * update * add gfx950 support for deepgemm example * lint fix --------- Co-authored-by: LeiWang1999 --- docker/Dockerfile.rocm | 2 +- src/target/codegen_hip.cc | 100 +++++++++++++++++++- src/tl_templates/hip/common.h | 1 + tilelang/intrinsics/mfma_macro_generator.py | 12 ++- 4 files changed, 105 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f0dd2050d0..a658a510c3 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -17,7 +17,7 @@ ENV USE_ROCM=1 ENV USE_CUDA=0 ENV ROCM_HOME=/opt/rocm ENV HIP_PLATFORM=amd -ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx1201;gfx1100" +ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1201;gfx1100" RUN apt-get update && apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev && \ apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/* diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index 40c4cbd3e1..1955f492a9 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -231,6 +231,14 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*) ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4"; os << "ulonglong" << lanes / 2; + } else if (lanes == 16) { + // float32x16: GCC vector extension type used by MFMA accumulators. + os << "float32x16"; + return; + } else if (lanes == 32) { + // float32x32: GCC vector extension type used by MFMA accumulators. + os << "float32x32"; + return; } else { fail = true; } @@ -256,6 +264,10 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*) } else if (lanes <= 8) { ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; os << "uint" << lanes / 2; + } else if (lanes == 16) { + // bfloat16x16: struct { bfloat16_t data[16]; } used by MFMA accumulators. + os << "bfloat16x16"; + return; } else { fail = true; } @@ -469,6 +481,8 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t, static const char access[] = {'x', 'y', 'z', 'w'}; ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 + : (t.lanes() == 16) ? 16 + : (t.lanes() == 32) ? 32 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { @@ -479,6 +493,14 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t, std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; } + } else if ((t.lanes() == 16 || t.lanes() == 32) && t.bits() == 32 && + t.is_float()) { + // float32x16/float32x32: __attribute__((__vector_size__(...))) supports + // subscript. + os << vec << "[" << i << "]"; + } else if (t.lanes() == 16 && t.is_bfloat16()) { + // bfloat16x16: struct { bfloat16_t data[16]; } + os << vec << ".data[" << i << "]"; } else if (t.is_float16()) { os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; @@ -514,7 +536,10 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t, int i, const std::string &value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; + ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 + : (t.lanes() == 16) ? 16 + : (t.lanes() == 32) ? 32 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { @@ -530,6 +555,14 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t, } stream << "(" << value << " << " << i % 4 * 8 << ");\n"; } + } else if ((t.lanes() == 16 || t.lanes() == 32) && t.bits() == 32 && + t.is_float()) { + // float32x16/float32x32: __attribute__((__vector_size__(...))) supports + // subscript. + stream << vec << "[" << i << "] = " << value << ";\n"; + } else if (t.lanes() == 16 && t.is_bfloat16()) { + // bfloat16x16: struct { bfloat16_t data[16]; } + stream << vec << ".data[" << i << "] = " << value << ";\n"; } else if (t.is_float16()) { stream << "*((half_t*)(&(((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << "))) = " << value << ";\n"; @@ -974,7 +1007,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { {"float8_e5m2fnuzx8", "long"}, {"float8_e5m2x4", "fp8_e5_4_t"}, {"float8_e5m2x8", "long"}, - {"float32x16", "float32x16"}}; + {"float32x16", "float32x16"}, + {"float32x32", "float32x32"}}; std::string call_mfma_code = R"({ *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}), *((({B_dtype}*){b_ref}) + {b_bias}), @@ -1204,8 +1238,8 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) { ICHECK(!func_name.empty() && panel_size > 0) << "threadblock_swizzle_pattern: failed to extract func_name and " "panel_size"; - this->stream << "const dim3 blockIdx = tl::" << func_name << "(" - << panel_size << ");\n"; + this->stream << "const dim3 blockIdx = tl::" << func_name << "<" + << panel_size << ">();\n"; this->VisitStmt(op->body); return; } @@ -1304,16 +1338,48 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode *op, if (op->dtype.is_float() && op->dtype.bits() == 32 && op->dtype.lanes() == 8) { std::string v = PrintExpr(op->value); + // HIP does not allow taking the address of a temporary, so use a union + // to reinterpret float2 as unsigned long long without UB or temp-address. os << "make_ulonglong4("; for (int i = 0; i < 4; ++i) { if (i != 0) os << ", "; - os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")"; + os << "([&]{ union { float2 f; unsigned long long u; } _tmp;" + << " _tmp.f = make_float2(" << v << ", " << v + << "); return _tmp.u; }())"; } os << ')'; return; } + if (op->dtype.is_float() && op->dtype.bits() == 32 && + (op->dtype.lanes() == 16 || op->dtype.lanes() == 32)) { + // float32x16/float32x32: GCC vector extension — initialize with compound + // literal. + std::string v = PrintExpr(op->value); + os << "(float32x" << op->dtype.lanes() << "){"; + for (int i = 0; i < op->dtype.lanes(); ++i) { + if (i != 0) + os << ", "; + os << v; + } + os << "}"; + return; + } + + if (op->dtype.is_bfloat16() && op->dtype.lanes() == 16) { + // bfloat16x16: struct aggregate initializer. + std::string v = PrintExpr(op->value); + os << "bfloat16x16{"; + for (int i = 0; i < 16; ++i) { + if (i != 0) + os << ", "; + os << v; + } + os << "}"; + return; + } + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { bool fail = false; const int64_t *p = as_const_int(op->value); @@ -1473,7 +1539,7 @@ void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, return; } - if (t.is_bfloat16()) { + if (t.is_bfloat16() && t.lanes() != 16) { if (i == 0) { os << "make_"; PrintType(t, os); @@ -1492,6 +1558,30 @@ void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, return; } + if ((t.lanes() == 16 || t.lanes() == 32) && t.bits() == 32 && t.is_float()) { + // float32x16/float32x32: compound literal. + if (i == 0) + os << "(float32x" << t.lanes() << "){"; + os << value; + if (i != t.lanes() - 1) + os << ","; + else + os << "}"; + return; + } + + if (t.lanes() == 16 && t.is_bfloat16()) { + // bfloat16x16: struct aggregate initializer. + if (i == 0) + os << "bfloat16x16{"; + os << value; + if (i != t.lanes() - 1) + os << ","; + else + os << "}"; + return; + } + if (i == 0) { os << "make_"; PrintType(t, os); diff --git a/src/tl_templates/hip/common.h b/src/tl_templates/hip/common.h index 1711e2c7c0..c7041b4a18 100644 --- a/src/tl_templates/hip/common.h +++ b/src/tl_templates/hip/common.h @@ -92,6 +92,7 @@ typedef using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; +using float32x32 = __attribute__((__vector_size__(32 * sizeof(float)))) float; using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t; diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index c6069f869f..d482fe6a33 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -328,6 +328,8 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0) A_buf = A_region.buffer A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min + # Leading dimensions (e.g. pipeline stage axis) – empty for 2-D buffers + A_other = [r.min for r in A_region.region[:-2]] @T.macro def _warp_ldmatrix_a( @@ -343,13 +345,13 @@ def _warp_ldmatrix_a( for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[tuple(A_other) + (A_base0 + l + row, A_base1 + r + col)] else: for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[tuple(A_other) + (A_base0 + l + row, A_base1 + r + col)] return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -370,6 +372,8 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0) B_buf = B_region.buffer B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min + # Leading dimensions (e.g. pipeline stage axis) – empty for 2-D buffers + B_other = [r.min for r in B_region.region[:-2]] @T.macro def _warp_ldmatrix_b( @@ -388,7 +392,7 @@ def _warp_ldmatrix_b( warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k), ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[tuple(B_other) + (B_base0 + l + row, B_base1 + r + col)] else: for j in T.serial(warp_cols): @@ -398,7 +402,7 @@ def _warp_ldmatrix_b( rk * chunk + ki * (k_pack * micro_size_k), warp_n * warp_col_tiles + j * micro_size_y, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[tuple(B_other) + (B_base0 + l + row, B_base1 + r + col)] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) From 11dc3e66bde475e918f0049812555a92575aeb82 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Mon, 13 Apr 2026 12:25:04 +0800 Subject: [PATCH 041/156] Fix shared memory buffer reuse --- .../auto_schedule/warpgroup_partition.cc | 5 +- .../merge_shared_memory_allocations.cc | 516 +++++++++++++----- 2 files changed, 385 insertions(+), 136 deletions(-) diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index ccb6658ddf..b5b8b5f9a6 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -965,7 +965,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( } else { // Fallback for non-SequenceNode root: clone entire root per warpgroup for (size_t i = 0; i < num_wgs; ++i) { - wg_structures[i] = CloneIRStructureWithWarpgroupFilter(root, i); + wg_structures[i] = + RemoveUnusedLetDecls(CloneIRStructureWithWarpgroupFilter(root, i)); } } @@ -1104,6 +1105,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( segmented_stmts.push_back(MakeWarpgroupIf(wg_stmts)); } + segmented_stmts.push_back(AttrStmt( + Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, Evaluate(0))); if_then_else = SeqStmt::Flatten(segmented_stmts); } else { // Fallback for non-SequenceNode root: no boundary insertion, simple diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc index 999ae8763a..6d686a2519 100644 --- a/src/transform/merge_shared_memory_allocations.cc +++ b/src/transform/merge_shared_memory_allocations.cc @@ -23,6 +23,7 @@ * memory allocation. This pass merges multiple TIR-level dynamic or static * shared memory allocations into one allocation. */ +#include #include #include #include @@ -587,93 +588,6 @@ class SharedMemoryRewriter : public StmtExprMutator { return StmtMutator::VisitStmt_(op); } - Stmt VisitStmt_(const SeqStmtNode *op) final { - // Visit children first (strips boundaries, shared memory allocates, etc.) - Stmt visited = StmtExprMutator::VisitStmt_(op); - const auto *seq = visited.as(); - if (!seq) - return visited; - - // Helper: check if stmt is Evaluate(0) (remnant of stripped boundaries) - auto is_noop = [](const Stmt &s) -> bool { - const auto *e = s.as(); - return e && is_zero(e->value); - }; - - // Helper: peel off DeclBuffer layers, returning (buffers, inner stmt) - auto unwrap_decl_buffers = - [](Stmt s) -> std::pair, Stmt> { - std::vector bufs; - while (const auto *d = s.as()) { - bufs.push_back(d->buffer); - s = d->body; - } - return {bufs, s}; - }; - - // Filter out Evaluate(0) remnants - std::vector stmts; - for (const auto &s : seq->seq) { - if (!is_noop(s)) { - stmts.push_back(s); - } - } - - // Merge consecutive DeclBuffer*(IfThenElse(same_cond, ...)) entries - tir::ExprDeepEqual expr_equal; - std::vector merged; - size_t i = 0; - while (i < stmts.size()) { - auto [bufs_i, inner_i] = unwrap_decl_buffers(stmts[i]); - const auto *ite_i = inner_i.as(); - if (!ite_i || !ite_i->else_case.defined()) { - merged.push_back(stmts[i]); - ++i; - continue; - } - - // Start a merge group - std::vector all_bufs(bufs_i); - std::vector then_parts{ite_i->then_case}; - std::vector else_parts{ite_i->else_case.value()}; - PrimExpr cond = ite_i->condition; - - size_t j = i + 1; - while (j < stmts.size()) { - auto [bufs_j, inner_j] = unwrap_decl_buffers(stmts[j]); - const auto *ite_j = inner_j.as(); - if (!ite_j || !ite_j->else_case.defined()) - break; - if (!expr_equal(cond, ite_j->condition)) - break; - all_bufs.insert(all_bufs.end(), bufs_j.begin(), bufs_j.end()); - then_parts.push_back(ite_j->then_case); - else_parts.push_back(ite_j->else_case.value()); - ++j; - } - - if (j == i + 1) { - // No merge possible - merged.push_back(stmts[i]); - ++i; - } else { - // Build merged IfThenElse - Stmt body = IfThenElse(cond, SeqStmt::Flatten(then_parts), - SeqStmt::Flatten(else_parts)); - // Wrap with all DeclBuffers (innermost last) - for (int k = static_cast(all_bufs.size()) - 1; k >= 0; --k) { - body = DeclBuffer(all_bufs[k], body); - } - merged.push_back(body); - i = j; - } - } - - if (merged.size() == 1) - return merged[0]; - return SeqStmt(merged); - } - Stmt VisitStmt_(const AllocateNode *op) final { if (IsAppropriateSharedMemory(op->buffer_var)) { return StmtExprMutator::VisitStmt(op->body); @@ -868,66 +782,99 @@ class SharedMemoryRewriter : public StmtExprMutator { for (int i = 0, n = static_cast(blocks_.size()); i < n; ++i) { size_t aligned = AlignUpSize(blocks_[i].offset, alignment); size_t head = aligned - blocks_[i].offset; - if (head <= blocks_[i].size && (blocks_[i].size - head) >= need) { - size_t waste = blocks_[i].size - head - need; - if (waste < best_waste) { - best_waste = waste; - best = i; - } + if (head > blocks_[i].size) + continue; + size_t usable = blocks_[i].size - head; + if (usable < need) + continue; + size_t waste = blocks_[i].size - need; + if (waste < best_waste) { + best_waste = waste; + best = i; } } - if (best < 0) { + if (best < 0) return std::nullopt; - } - FreeBlock blk = blocks_[best]; - size_t aligned = AlignUpSize(blk.offset, alignment); + return CarveBlock(best, need, alignment); + } + + // Try to allocate from the free block whose end touches arena_top. + // The block may be smaller than need; the caller grows the arena to + // cover the deficit. Returns the aligned start offset on success. + std::optional AllocateFromTail(size_t need, size_t alignment, + size_t arena_top) { + if (blocks_.empty()) + return std::nullopt; + int tail_idx = static_cast(blocks_.size()) - 1; + if (blocks_[tail_idx].offset + blocks_[tail_idx].size != arena_top) + return std::nullopt; + + size_t aligned = AlignUpSize(blocks_[tail_idx].offset, alignment); + if (aligned >= arena_top) + return std::nullopt; + + FreeBlock blk = blocks_[tail_idx]; size_t head = aligned - blk.offset; - size_t tail = blk.size - head - need; - blocks_.erase(blocks_.begin() + best); + + blocks_.erase(blocks_.begin() + tail_idx); if (head) { - blocks_.push_back({blk.offset, head}); - } - if (tail) { - blocks_.push_back({aligned + need, tail}); + InsertBlock(blk.offset, head); } - Normalize(); return aligned; } void Free(size_t offset, size_t size) { if (size == 0) return; - blocks_.push_back({offset, size}); - Normalize(); + InsertBlock(offset, size); } private: - void Normalize() { - if (blocks_.empty()) - return; - std::sort(blocks_.begin(), blocks_.end(), - [](const FreeBlock &a, const FreeBlock &b) { - return a.offset < b.offset; - }); - std::vector merged; - merged.reserve(blocks_.size()); - for (const FreeBlock &blk : blocks_) { - if (merged.empty()) { - merged.push_back(blk); - continue; - } - FreeBlock &last = merged.back(); - size_t last_end = last.offset + last.size; - if (blk.offset <= last_end) { - size_t blk_end = blk.offset + blk.size; - if (blk_end > last_end) { - last.size = blk_end - last.offset; - } - } else { - merged.push_back(blk); + // Insert a block at the correct sorted position and merge with adjacent + // neighbours so the sorted-and-coalesced invariant is preserved. + void InsertBlock(size_t offset, size_t size) { + FreeBlock entry{offset, size}; + auto it = std::lower_bound( + blocks_.begin(), blocks_.end(), offset, + [](const FreeBlock &b, size_t off) { return b.offset < off; }); + it = blocks_.insert(it, entry); + + // Merge with the next neighbour. + auto next = std::next(it); + if (next != blocks_.end() && it->offset + it->size >= next->offset) { + size_t merged_end = + std::max(it->offset + it->size, next->offset + next->size); + it->size = merged_end - it->offset; + blocks_.erase(next); + } + // Merge with the previous neighbour. + if (it != blocks_.begin()) { + auto prev = std::prev(it); + if (prev->offset + prev->size >= it->offset) { + size_t merged_end = + std::max(prev->offset + prev->size, it->offset + it->size); + prev->size = merged_end - prev->offset; + blocks_.erase(it); } } - blocks_ = std::move(merged); + } + + // Remove blocks_[idx], allocate `need` bytes at the aligned offset + // within it, and return any head/tail fragments to the free list. + size_t CarveBlock(int idx, size_t need, size_t alignment) { + FreeBlock blk = blocks_[idx]; + blocks_.erase(blocks_.begin() + idx); + + size_t aligned = AlignUpSize(blk.offset, alignment); + size_t head = aligned - blk.offset; + size_t tail = blk.size - head - need; + + // Insert tail first so indices are not disturbed by head insertion. + if (tail) + InsertBlock(aligned + need, tail); + if (head) + InsertBlock(blk.offset, head); + return aligned; } std::vector blocks_; @@ -954,8 +901,6 @@ class SharedMemoryRewriter : public StmtExprMutator { if (lhs.size_bytes != rhs.size_bytes) { return lhs.size_bytes > rhs.size_bytes; } - // Use name comparison for deterministic ordering instead of - // pointer comparison return lhs.var->name_hint < rhs.var->name_hint; }); @@ -966,7 +911,6 @@ class SharedMemoryRewriter : public StmtExprMutator { size_t arena_top = 0; std::unordered_map offsets; - // Expire intervals that end before or at program counter `pc`. auto retire = [&](int pc) { while (!active.empty() && active.top().end <= pc) { const ActiveInterval top = active.top(); @@ -978,13 +922,23 @@ class SharedMemoryRewriter : public StmtExprMutator { for (const Interval &interval : intervals) { retire(interval.start); size_t offset = 0; - // Try to recycle previously freed memory first; fall back to bumping the - // arena. + // 1) Reuse a fully fitting free block (best-fit). + // 2) Extend the tail free block that touches arena_top. + // 3) Bump-allocate at arena_top (reclaim alignment gap). if (auto slot = freelist.Allocate(interval.size_bytes, interval.alignment)) { offset = slot.value(); + } else if (auto tail_slot = freelist.AllocateFromTail( + interval.size_bytes, interval.alignment, arena_top)) { + offset = tail_slot.value(); + arena_top = offset + interval.size_bytes; } else { offset = AlignUpSize(arena_top, interval.alignment); + // Reclaim the alignment gap [arena_top, offset) so future small + // allocations can reuse it. + if (offset > arena_top) { + freelist.Free(arena_top, offset - arena_top); + } arena_top = offset + interval.size_bytes; } active.push(ActiveInterval{interval.end, offset, interval.size_bytes, @@ -1562,6 +1516,295 @@ class SharedMemoryRewriter : public StmtExprMutator { std::unordered_map shmem_alignment_map_; }; +/*! + * \brief Post-pass that merges consecutive thread-partitioning IfThenElse + * nodes. Runs after SharedMemoryRewriter so that it operates on the + * already-rewritten IR without interfering with the analysis framework. + * + * For each IfThenElse whose condition partitions threadIdx.x, we extract + * cut points from the conditions (pattern-matching thread_var < N, + * thread_var >= N, etc.) and track the precise integer interval [lo, hi) + * for each leaf branch. Consecutive IfThenElse nodes whose interval sets + * are either identical or mutually non-overlapping are merged. + */ +class ThreadPartitionMerger : public StmtMutator { +public: + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iter_var = Downcast(op->node); + if (iter_var->thread_tag == "threadIdx.x") { + thread_var_ = iter_var->var; + if (const auto *imm = op->value.as()) { + thread_extent_ = imm->value; + } + } + } + return StmtMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const SeqStmtNode *op) final { + Stmt visited = StmtMutator::VisitStmt_(op); + const auto *seq = visited.as(); + if (!seq) + return visited; + + if (!thread_var_.defined() || thread_extent_ <= 1) + return visited; + + auto unwrap_decl_buffers = + [](Stmt s) -> std::pair, Stmt> { + std::vector bufs; + while (const auto *d = s.as()) { + bufs.push_back(d->buffer); + s = d->body; + } + return {bufs, s}; + }; + + std::vector stmts(seq->seq.begin(), seq->seq.end()); + + struct ThreadInterval { + int64_t lower; + int64_t upper; + bool operator==(const ThreadInterval &o) const { + return lower == o.lower && upper == o.upper; + } + }; + + struct CutInfo { + int64_t cut; + bool then_lower; + }; + + auto try_extract_cut = + [this](const PrimExpr &cond, + arith::Analyzer *analyzer) -> std::optional { + if (!thread_var_.defined()) + return std::nullopt; + const VarNode *tv = thread_var_.value().get(); + auto eval_const = + [analyzer](const PrimExpr &e) -> std::optional { + auto b = analyzer->const_int_bound(e); + return (b->min_value == b->max_value) + ? std::optional(b->min_value) + : std::nullopt; + }; + if (const auto *op = cond.as()) { + if (op->a.as() == tv) { + if (auto v = eval_const(op->b)) + return CutInfo{*v, true}; + } + if (op->b.as() == tv) { + if (auto v = eval_const(op->a)) + return CutInfo{*v + 1, false}; + } + } + if (const auto *op = cond.as()) { + if (op->a.as() == tv) { + if (auto v = eval_const(op->b)) + return CutInfo{*v + 1, true}; + } + if (op->b.as() == tv) { + if (auto v = eval_const(op->a)) + return CutInfo{*v, false}; + } + } + if (const auto *op = cond.as()) { + if (op->a.as() == tv) { + if (auto v = eval_const(op->b)) + return CutInfo{*v + 1, false}; + } + if (op->b.as() == tv) { + if (auto v = eval_const(op->a)) + return CutInfo{*v, true}; + } + } + if (const auto *op = cond.as()) { + if (op->a.as() == tv) { + if (auto v = eval_const(op->b)) + return CutInfo{*v, false}; + } + if (op->b.as() == tv) { + if (auto v = eval_const(op->a)) + return CutInfo{*v + 1, true}; + } + } + return std::nullopt; + }; + + using IntervalBranches = std::vector>; + + std::function + decompose_thread_partition; + decompose_thread_partition = + [this, &try_extract_cut, &decompose_thread_partition]( + const Stmt &stmt, arith::Analyzer *analyzer, int64_t lo, int64_t hi, + IntervalBranches &out) -> bool { + const auto *ite = stmt.as(); + if (!ite) { + out.push_back({{lo, hi}, stmt}); + return true; + } + auto cut_opt = try_extract_cut(ite->condition, analyzer); + if (!cut_opt.has_value()) { + out.push_back({{lo, hi}, stmt}); + return true; + } + int64_t cut = cut_opt->cut; + if (cut <= lo || cut >= hi) + return false; + + if (cut_opt->then_lower) { + if (!decompose_thread_partition(ite->then_case, analyzer, lo, cut, out)) + return false; + if (ite->else_case.defined()) + return decompose_thread_partition(ite->else_case.value(), analyzer, + cut, hi, out); + } else { + if (ite->else_case.defined()) { + if (!decompose_thread_partition(ite->else_case.value(), analyzer, lo, + cut, out)) + return false; + } + if (!decompose_thread_partition(ite->then_case, analyzer, cut, hi, out)) + return false; + } + return true; + }; + + auto try_decompose = + [this, &decompose_thread_partition]( + const Stmt &stmt) -> std::optional { + if (!thread_var_.defined() || thread_extent_ <= 1) + return std::nullopt; + const auto *ite = stmt.as(); + if (!ite) + return std::nullopt; + + arith::Analyzer analyzer; + IntervalBranches branches; + if (!decompose_thread_partition(stmt, &analyzer, 0, thread_extent_, + branches)) + return std::nullopt; + if (branches.empty()) + return std::nullopt; + return branches; + }; + + auto rebuild_from_intervals = + [this](const std::vector &intervals, + const std::vector &bodies) -> Stmt { + ICHECK_EQ(intervals.size(), bodies.size()); + Stmt result = Evaluate(0); + for (int i = static_cast(intervals.size()) - 1; i >= 0; --i) { + PrimExpr lower = + make_const(thread_var_.value().dtype(), intervals[i].lower); + PrimExpr upper = + make_const(thread_var_.value().dtype(), intervals[i].upper); + PrimExpr cond = + (lower <= thread_var_.value()) && (thread_var_.value() < upper); + result = IfThenElse(cond, bodies[i], result); + } + return result; + }; + + auto intervals_compatible = [](const IntervalBranches &a, + const IntervalBranches &b) -> bool { + for (const auto &[iv_a, _a] : a) { + for (const auto &[iv_b, _b] : b) { + if (iv_a == iv_b) + continue; + if (iv_a.upper <= iv_b.lower || iv_b.upper <= iv_a.lower) + continue; + return false; + } + } + return true; + }; + + auto merge_interval_branches = + [](const IntervalBranches &a, + const IntervalBranches &b) -> IntervalBranches { + IntervalBranches result = a; + for (const auto &[iv_b, body_b] : b) { + bool found = false; + for (auto &[iv_r, body_r] : result) { + if (iv_r == iv_b) { + body_r = SeqStmt::Flatten(std::vector{body_r, body_b}); + found = true; + break; + } + } + if (!found) { + result.push_back({iv_b, body_b}); + } + } + std::sort(result.begin(), result.end(), [](const auto &x, const auto &y) { + return x.first.lower < y.first.lower; + }); + return result; + }; + + // Main merge loop + std::vector merged; + size_t i = 0; + while (i < stmts.size()) { + auto [bufs_i, inner_i] = unwrap_decl_buffers(stmts[i]); + auto opt_i = try_decompose(inner_i); + if (!opt_i.has_value()) { + merged.push_back(stmts[i]); + ++i; + continue; + } + + IntervalBranches accumulated = opt_i.value(); + std::vector all_bufs = bufs_i; + + size_t j = i + 1; + while (j < stmts.size()) { + auto [bufs_j, inner_j] = unwrap_decl_buffers(stmts[j]); + auto opt_j = try_decompose(inner_j); + if (!opt_j.has_value()) + break; + if (!intervals_compatible(accumulated, opt_j.value())) + break; + accumulated = merge_interval_branches(accumulated, opt_j.value()); + all_bufs.insert(all_bufs.end(), bufs_j.begin(), bufs_j.end()); + ++j; + } + + if (j == i + 1) { + merged.push_back(stmts[i]); + ++i; + } else { + std::vector intervals; + std::vector bodies; + for (const auto &[iv, body] : accumulated) { + intervals.push_back(iv); + bodies.push_back(body); + } + + Stmt body = rebuild_from_intervals(intervals, bodies); + for (int k = static_cast(all_bufs.size()) - 1; k >= 0; --k) + body = DeclBuffer(all_bufs[k], body); + merged.push_back(body); + i = j; + } + } + + if (merged.size() == stmts.size()) + return visited; // nothing merged + if (merged.size() == 1) + return merged[0]; + return SeqStmt(merged); + } + +private: + Optional thread_var_; + int64_t thread_extent_{0}; +}; + Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem, bool enable_aggressive_merge, int align_bytes = 16, bool verbose = false) { @@ -1579,6 +1822,9 @@ Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem, rewriter.PlanReuse(stmt, false, enable_aggressive_merge); stmt = rewriter(std::move(stmt)); } + // Merge consecutive thread-partitioning IfThenElse nodes + ThreadPartitionMerger merger; + stmt = merger(std::move(stmt)); return stmt; } From 5d729eeebca3ea776373a2918e3945d667bd1c7d Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 13 Apr 2026 12:42:44 +0800 Subject: [PATCH 042/156] [Refactor] Remove GEMM v1 and promote gemm_py to be the canonical gemm op (#2033) * [Refactor] Remove GEMM v1 and merge gemm_py into the canonical gemm op GEMM v2 has been the default for a while and no tests or examples still depend on v1. This removes the legacy C++ GemmNode path entirely and promotes the former GemmPyNode (Python-delegated lowering) to be the one and only GEMM tile op. C++: - Merge src/op/gemm_py.{h,cc} into src/op/gemm.{h,cc}. GemmPyNode becomes GemmNode, GemmPy becomes Gemm, FFI type key "tl.GemmPy" becomes "tl.Gemm", and the op names "tl.tileop.{gemm_py,wgmma_gemm_py,tcgen05_gemm_py}" collapse to "tl.tileop.{gemm,wgmma_gemm,tcgen05_gemm}". The shared infrastructure (GemmWarpPolicyNode, GemmInst enum, GemmWarpPolicyNode::computeWarpPartition) is preserved in place so GemmSP/GemmSPPy keep compiling unchanged. - Delete the tl_gemm builtin, its codegen dispatch in codegen_{cuda,hip,cutedsl}.cc, and its matching branches in inject_fence_proxy.cc and merge_shared_memory_allocations.cc. tl_gemm_sp is kept as the remaining async-proxy gemm builtin. - Drop #include "gemm_py.h" and GemmPyNode/GemmPy references from inject_pipeline.cc, producer_consumer_ws.cc, and lower_blackwell_2sm.cc. Python: - tileop/gemm: rename GemmPy -> Gemm, register as "tl.Gemm", and rename the FFI callbacks tl.gemm_py.{lower,infer_layout} -> tl.gemm.{lower, infer_layout}. _ffi_api.GemmPyGemmInst becomes GemmGetGemmInst. - language/gemm_op.py: delete gemm_v1/gemm_v2 wrappers. gemm() now emits tl.tileop.gemm directly; wgmma_gemm/tcgen05_gemm drop the _py suffix. - env.py: remove TILELANG_USE_GEMM_V1 and the use_gemm_v1() helper. - ir.py: drop the empty v1-era `class Gemm(Node, Scriptable): ...` placeholder that was colliding with the real tileop.gemm.Gemm wrapper under the "tl.Gemm" FFI type key. - contrib/cutedsl: delete cutedsl/gemm_v1.py (runtime helper for the v1 cutedsl lowering path) and its re-export in cutedsl/__init__.py. Tests / maint: - testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py switches from T.gemm_v2 to T.gemm. - testing/python/transform/test_tilelang_transform_inject_fence_proxy.py:: test_lower_fence_proxy is rewritten to exercise tl.tl_gemm_sp instead of the removed tl.tl_gemm, preserving the fence-proxy assertion. - Delete maint/gemm_v2/ (v1-vs-v2 benchmarking scripts) and drop the v1-specific comment from maint/scripts/run_local_ci_test.sh. C++ headers under src/tl_templates/{cuda,hip,cpp}/gemm*.h are left in place as reference implementations; they are no longer reachable from the generated code path since the only consumer (tl_gemm builtin) is gone, but keeping them avoids a large, hard-to-reverse deletion. * [Fix] Restore cutedsl/gemm_v1.py (Gemm_SM80 / Gemm_SM90 helpers) The previous commit deleted cutedsl/gemm_v1.py on the assumption that it was the CuteDSL lowering helper for the removed T.gemm_v1 path. It is actually a runtime helper module that defines SM80/SM90 CUTE DSL primitives (make_smem_layout_AB, get_tma_atom, etc.) which are emitted into generated cutedsl kernel code via `import tilelang.contrib.cutedsl as tl; tl.Gemm_SM90.get_tma_atom(...)` (see jit/adapter/cutedsl/wrapper.py CUBIN_TMA_ATOM_INIT_TEMPLATE). The "v1" in the filename refers to the CuteDSL helper version, not to the deleted T.gemm_v1 op, so removing it broke the cutedsl code path for attention_sink, blocksparse_attention, blocksparse_gemm, and deepgemm examples with `module 'tilelang.contrib.cutedsl' has no attribute 'Gemm_SM90'`. Restore the file and its `from .gemm_v1 import *` wildcard re-export. * [Refactor] Restore tl_gemm builtin as a call_intrin path to tl:: templates The previous commit in this PR deleted the `tl_gemm` builtin entirely on the argument that no code emitted it anymore. That was too aggressive: keeping the builtin preserves a deliberate extension point for any Python-side gemm lowering backend (for example a future GemmLegacyExtern) to delegate into the existing C++ templates in src/tl_templates//gemm*.h via T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"), op_instance_str, A_ptr, B_ptr, C_ptr) without any extra C++ plumbing. It is also the only thing that keeps those tl_templates headers reachable rather than dead code. Restore: - tl_gemm builtin declaration in src/op/builtin.h (with an updated docstring describing the call_intrin entry point) and its TIR_DEFINE_TL_BUILTIN registration in src/op/builtin.cc. - Codegen dispatch branches in codegen_cuda.cc, codegen_hip.cc, and codegen_cutedsl.cc that turn a tl_gemm Call into a PrintCallExtern emission of `tl::gemm_xx<...>(A, B, C)`. - SMEM alignment and async-proxy fence handling for tl_gemm in merge_shared_memory_allocations.cc and inject_fence_proxy.cc. These passes are tied to the builtin's semantics (not to the removed v1 GemmNode) and are still correct the moment anything emits tl_gemm. - test_lower_fence_proxy in test_tilelang_transform_inject_fence_proxy reverts to exercising `tl.tl_gemm` instead of the temporary `tl.tl_gemm_sp` substitution introduced when the builtin was gone. Note that the current sole gemm lowering path (the merged GemmNode that delegates to Python via tl.gemm.lower) still emits inline mma/wgmma/ tcgen5mma via the Python macro generators, not tl_gemm. So restoring the builtin is a no-op for the default path - it only re-enables the "call into v1 templates from Python" option for anyone who wants it. * [Fix] Default HIP execution backend to tvm_ffi, not cython resolve_execution_backend() only promoted cuda and metal to tvm_ffi for the "auto" path, so hip fell into the else branch and picked cython. But allowed_backends_for_target() already lists tvm_ffi as an allowed backend for hip, and the cuda default is tvm_ffi for good reasons: - The cython SourceWrapper path in src/tilelang/jit/adapter/wrapper.py has known gaps (for example T.alloc_global is not handled in create_dispatch_func, which breaks any test that uses an internally allocated global scratch buffer). - libgen.py appends user-supplied compile_flags verbatim to the hipcc/nvcc command. nvcc accepts --use_fast_math, clang++/hipcc does not, so a CUDA-style flag leaks into HIP compilation. Neither of these leak through the tvm_ffi backend. Aligning hip's default with cuda makes the two paths symmetric and unbreaks test_alloc_global and test_all_attrs_together_lazy on ROCm without having to mark the tests skip or rewrite the cython wrapper. cython is still available on hip via execution_backend='cython' for anyone who explicitly wants it. * [Test] Gate CUDA-only language tests with requires_cuda Three language tests hardcode target='cuda' (or rely on CUDA-only codegen behavior) but lacked a CUDA availability marker, so they failed on the ROCm CI runner with "Cannot find global function target.build.tilelang_cuda" once the earlier pre-existing failures got past --maxfail=3. Add @tilelang.testing.requires_cuda to: - test_tilelang_if_range - test_annotate_min_blocks_per_sm_launch_bounds - test_tilelang_transpose and test_tilelang_transpose_square Also replace the hand-rolled __main__ block in test_tilelang_transpose with tilelang.testing.main() so the skip markers are honored when the file is run directly. * [Test] Delete tile library gemm test Remove testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py. --- docs/deeplearning_operators/matmul_sparse.md | 2 +- maint/gemm_v2/correctness_evaluation.py | 808 ------------------ maint/gemm_v2/correctness_evaluation_sm70.py | 350 -------- .../gemm_v2/correctness_evaluation_tcgen05.py | 231 ----- .../correctness_evaluation_tcgen05_2cta.py | 168 ---- maint/gemm_v2/latency.py | 98 --- maint/gemm_v2/latency_gemm.py | 98 --- maint/gemm_v2/latency_mha_fwd_bhsd.py | 228 ----- maint/scripts/run_local_ci_test.sh | 2 - src/op/builtin.h | 8 +- src/op/gemm.cc | 566 +++--------- src/op/gemm.h | 15 +- src/op/gemm_py.cc | 438 ---------- src/op/gemm_py.h | 99 --- src/op/gemm_sp_py.h | 2 +- src/transform/inject_pipeline.cc | 4 +- src/transform/lower_blackwell_2sm.cc | 16 +- src/transform/lower_opaque_block.cc | 2 +- src/transform/producer_consumer_ws.cc | 4 +- .../test_tilelang_language_if_range.py | 1 + ...est_tilelang_language_min_blocks_per_sm.py | 1 + .../test_tilelang_language_transpose.py | 5 +- .../test_tilelang_tilelibrary_gemm.py | 624 -------------- tilelang/env.py | 12 - tilelang/ir.py | 4 - tilelang/jit/execution_backend.py | 2 +- tilelang/language/__init__.py | 2 +- tilelang/language/gemm_op.py | 76 +- tilelang/tileop/__init__.py | 2 +- tilelang/tileop/gemm/__init__.py | 20 +- 30 files changed, 157 insertions(+), 3731 deletions(-) delete mode 100644 maint/gemm_v2/correctness_evaluation.py delete mode 100644 maint/gemm_v2/correctness_evaluation_sm70.py delete mode 100644 maint/gemm_v2/correctness_evaluation_tcgen05.py delete mode 100644 maint/gemm_v2/correctness_evaluation_tcgen05_2cta.py delete mode 100644 maint/gemm_v2/latency.py delete mode 100644 maint/gemm_v2/latency_gemm.py delete mode 100644 maint/gemm_v2/latency_mha_fwd_bhsd.py delete mode 100644 src/op/gemm_py.cc delete mode 100644 src/op/gemm_py.h delete mode 100644 testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py diff --git a/docs/deeplearning_operators/matmul_sparse.md b/docs/deeplearning_operators/matmul_sparse.md index 8caa6182f0..09dcc6460d 100644 --- a/docs/deeplearning_operators/matmul_sparse.md +++ b/docs/deeplearning_operators/matmul_sparse.md @@ -258,4 +258,4 @@ However, fixing a specific layout introduces several potential issues: 3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.) -`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout. +`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm`. It lowers directly to PTX, removing the need for a fixed metadata layout. diff --git a/maint/gemm_v2/correctness_evaluation.py b/maint/gemm_v2/correctness_evaluation.py deleted file mode 100644 index 000bdb9488..0000000000 --- a/maint/gemm_v2/correctness_evaluation.py +++ /dev/null @@ -1,808 +0,0 @@ -# pytest correctness_evaluation.py -n 32 -import pytest -from tilelang import tvm as tvm -import tilelang.testing -from tilelang import language as T -import torch - - -def matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def _compile_and_check( - program, - trans_A, - trans_B, - in_dtype, - out_dtype, -): - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, - }, - ) - - print(kernel.get_kernel_source()) - - profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - - def ref_program(A, B): - if trans_A: - A = A.T - if trans_B: - B = B.T - if in_dtype == T.float32: - A = (A.view(torch.int32) - 0x1000).view(torch.float32) - B = (B.view(torch.int32) - 0x1000).view(torch.float32) - C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(out_dtype)) - return C - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - print("assert_allclose") - - -def run_gemm( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=2, - num_threads=128, -): - if block_N >= 256 or block_M >= 256 or block_K >= 256: - num_stages = 0 - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) - - -def matmul_rs( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - A_frag_shape = A_shared_shape - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") - A_frag = T.alloc_fragment(A_frag_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.copy(A_shared, A_frag) - T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) - # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_rs( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=2, - num_threads=128, -): - if block_N >= 256 or block_M >= 256 or block_K >= 256: - num_stages = 0 - program = matmul_rs( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) - - -def matmul_sr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - B_frag_shape = B_shared_shape - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") - B_frag = T.alloc_fragment(B_frag_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.copy(B_shared, B_frag) - T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_sr( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=2, - num_threads=128, -): - if block_N >= 256 or block_M >= 256 or block_K >= 256: - num_stages = 0 - program = matmul_sr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) - - -def matmul_rr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - A_frag_shape = A_shared_shape - B_frag_shape = B_shared_shape - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") - A_frag = T.alloc_fragment(A_frag_shape, in_dtype) - B_frag = T.alloc_fragment(B_frag_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.copy(A_shared, A_frag) - T.copy(B_shared, B_frag) - T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_rr( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=2, - num_threads=128, -): - if block_N >= 256 or block_M >= 256 or block_K >= 256: - num_stages = 0 - program = matmul_rr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) - - -M_VALUES = [64, 128, 256] -N_VALUES = [16, 32, 64, 128, 256, 512] -K_VALUES = [16, 32, 64, 128] -K_VALUES_8Bit = [32, 64, 128] -NUM_THREADS_VALUES = [128, 256] - - -def _generate_dtype_cases(k_values, num_threads): - """Generate dtype test cases for given K values and num_threads.""" - return ( - [ - pytest.param( - k, - T.float16, - T.float16, - T.float16, - num_threads, - id=f"K{k}-float16-float16-float16-threads{num_threads}", - ) - for k in k_values - ] - + [ - pytest.param( - k, - T.int8, - T.int32, - T.int32, - num_threads, - id=f"K{k}-int8-int32-int32-threads{num_threads}", - ) - for k in K_VALUES_8Bit - ] - + [ - pytest.param( - k, - T.float8_e5m2, - T.float32, - T.float32, - num_threads, - id=f"K{k}-float8_e5m2-float32-float32-threads{num_threads}", - ) - for k in K_VALUES_8Bit - ] - + [ - pytest.param( - k, - T.float8_e4m3fn, - T.float32, - T.float32, - num_threads, - id=f"K{k}-float8_e4m3fn-float32-float32-threads{num_threads}", - ) - for k in K_VALUES_8Bit - ] - ) - - -# num_threads=128 can work with any N -FALSE_TRUE_CASES_128 = _generate_dtype_cases(K_VALUES, 128) -# num_threads=256 requires N >= 32 -FALSE_TRUE_CASES_256 = _generate_dtype_cases(K_VALUES, 256) -FALSE_TRUE_CASES = FALSE_TRUE_CASES_128 + FALSE_TRUE_CASES_256 - - -def _ensure_torch_dtypes(*dtype_names): - import torch - - for name in set(dtype_names): - if not hasattr(torch, name): - pytest.skip(f"Torch does not expose dtype {name}") - - -def _skip_if_threads_exceed_n(num_threads, n): - """Skip test if num_threads=256 and N < 32.""" - if num_threads == 256 and n < 32: - pytest.skip(f"num_threads=256 requires N >= 32, but N={n}") - - -def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads=128): - run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, num_threads=num_threads) - - -def run_gemm_rs_false_false(m, n, k, num_threads=128): - run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, num_threads=num_threads) - - -def run_gemm_rs_true_false(m, n, k, num_threads=128): - run_gemm_rs(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k, num_threads=num_threads) - - -def run_gemm_rs_true_true(m, n, k, num_threads=128): - run_gemm_rs(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k, num_threads=num_threads) - - -def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads=128): - run_gemm_sr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, num_threads=num_threads) - - -def run_gemm_sr_false_false(m, n, k, num_threads=128): - run_gemm_sr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, num_threads=num_threads) - - -def run_gemm_sr_true_false(m, n, k, num_threads=128): - run_gemm_sr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k, num_threads=num_threads) - - -def run_gemm_sr_true_true(m, n, k, num_threads=128): - run_gemm_sr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k, num_threads=num_threads) - - -def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads=128): - run_gemm_rr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, num_threads=num_threads) - - -def run_gemm_rr_false_false(m, n, k, num_threads=128): - run_gemm_rr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, num_threads=num_threads) - - -def run_gemm_rr_true_false(m, n, k, num_threads=128): - run_gemm_rr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k, num_threads=num_threads) - - -def run_gemm_rr_true_true(m, n, k, num_threads=128): - run_gemm_rr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k, num_threads=num_threads) - - -TRANS_CASES = [ - pytest.param(False, False, id="nn"), - pytest.param(False, True, id="nt"), - pytest.param(True, False, id="tn"), - pytest.param(True, True, id="tt"), -] - - -@pytest.fixture(scope="module", autouse=True) -def _setup_tilelang_environment(): - tilelang.disable_cache() - tilelang.testing.set_random_seed(42) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype,num_threads", FALSE_TRUE_CASES) -def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads): - import torch - - _skip_if_threads_exceed_n(num_threads, n) - - required_torch_attrs = { - in_dtype, - out_dtype, - accum_dtype, - } - for attr in required_torch_attrs: - if not hasattr(torch, attr): - pytest.skip(f"Torch does not expose dtype {attr}") - run_gemm( - m, - n, - k * 3, - False, - True, - in_dtype, - out_dtype, - accum_dtype, - m, - n, - k, - num_threads=num_threads, - ) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_false_false(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - run_gemm( - m, - n, - k * 3, - False, - False, - T.float16, - T.float16, - T.float16, - m, - n, - k, - num_threads=num_threads, - ) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_true_false(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - run_gemm( - m, - n, - k * 3, - True, - False, - T.float16, - T.float16, - T.float16, - m, - n, - k, - num_threads=num_threads, - ) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_true_true(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - run_gemm( - m, - n, - k * 3, - True, - True, - T.float16, - T.float16, - T.float16, - m, - n, - k, - num_threads=num_threads, - ) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype,num_threads", FALSE_TRUE_CASES) -def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) - run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_rs_false_false(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(T.float16) - run_gemm_rs_false_false(m, n, k, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_rs_true_false(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(T.float16) - run_gemm_rs_true_false(m, n, k, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_rs_true_true(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(T.float16) - run_gemm_rs_true_true(m, n, k, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype,num_threads", FALSE_TRUE_CASES) -def test_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) - run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_sr_false_false(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(T.float16) - run_gemm_sr_false_false(m, n, k, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_sr_true_false(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(T.float16) - run_gemm_sr_true_false(m, n, k, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_sr_true_true(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(T.float16) - run_gemm_sr_true_true(m, n, k, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype,num_threads", FALSE_TRUE_CASES) -def test_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) - run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_rr_false_false(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(T.float16) - run_gemm_rr_false_false(m, n, k, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_rr_true_false(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(T.float16) - run_gemm_rr_true_false(m, n, k, num_threads) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -@pytest.mark.parametrize("num_threads", NUM_THREADS_VALUES, ids=lambda v: f"threads{v}") -def test_gemm_rr_true_true(m, n, k, num_threads): - _skip_if_threads_exceed_n(num_threads, n) - _ensure_torch_dtypes(T.float16) - run_gemm_rr_true_true(m, n, k, num_threads) - - -if __name__ == "__main__": - run_gemm( - M=64, - N=192, - K=64, - trans_A=False, - trans_B=False, - in_dtype=T.bfloat16, - out_dtype=T.bfloat16, - dtypeAccum=T.float32, - block_M=64, - block_N=192, - block_K=64, - num_stages=0, - num_threads=256, - ) - - # # Test Pass - # for m in [64, 128, 256]: - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64, 128]: - # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) - # print(f"Test {m} {n} {k} Pass") - - # # Test Pass - # for m in [64, 128, 256]: - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64, 128]: - # print(f"======================= Test {m} {n} {k} False False =============================") - # run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) - # print(f"Test {m} {n} {k} Pass") - - # # Test Pass - # for m in [64, 128, 256]: - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64, 128]: - # print(f"======================= Test {m} {n} {k} True False =============================") - # run_gemm(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) - # print(f"Test {m}, {n} {k} Pass") - # print(f"Test {n} Pass") - - # # Test Pass - # for m in [64, 128, 256]: - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64, 128]: - # print(f"======================= Test {m} {n} {k} True True =============================") - # run_gemm(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) - # print(f"Test {m}, {n} {k} Pass") - # print(f"Test {n} Pass") - - # Test Pass - # for m in [64, 128, 256]: - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64, 128]: - # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm_rs(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) - # print(f"Test {m} {n} {k} Pass") - - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64, 128]: - # run_gemm_rs(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256) - # print(f"Test {64} {n} {k} Pass") - - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64, 128]: - # run_gemm(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256) - # print(f"Test {64} {n} {k} Pass") diff --git a/maint/gemm_v2/correctness_evaluation_sm70.py b/maint/gemm_v2/correctness_evaluation_sm70.py deleted file mode 100644 index 606d102611..0000000000 --- a/maint/gemm_v2/correctness_evaluation_sm70.py +++ /dev/null @@ -1,350 +0,0 @@ -# pytest maint/gemm_v2/correctness_evaluation_sm70.py -n 32 -import pytest -from tilelang import tvm as tvm -import tilelang.testing -from tilelang import language as T - - -def matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) - T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def _compile_and_check( - program, - trans_A, - trans_B, - in_dtype, - out_dtype, -): - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, - }, - ) - - print(kernel.get_kernel_source()) - - profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - - def ref_program(A, B): - import torch - - if trans_A: - A = A.T - if trans_B: - B = B.T - if in_dtype == T.float32: - A = (A.view(torch.int32) - 0x1000).view(torch.float32) - B = (B.view(torch.int32) - 0x1000).view(torch.float32) - C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(out_dtype)) - return C - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - print("assert_allclose") - - -def run_gemm( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) - - -def matmul_rs( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - A_frag_shape = A_shared_shape - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") - A_frag = T.alloc_fragment(A_frag_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.copy(A_shared, A_frag) - T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) - # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_rs( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmul_rs( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) - - -M_VALUES = [64, 128] -N_VALUES = [32, 64, 128] -K_VALUES = [16, 32, 64] -FALSE_TRUE_CASES = [ - pytest.param( - k, - T.float16, - T.float16, - T.float16, - id=f"K{k}-float16-float16-float16", - ) - for k in K_VALUES -] + [ - pytest.param( - k, - T.float16, - T.float16, - T.float32, - id=f"K{k}-float16-float16-float32", - ) - for k in K_VALUES -] - - -def _ensure_torch_dtypes(*dtype_names): - import torch - - for name in set(dtype_names): - if not hasattr(torch, name): - pytest.skip(f"Torch does not expose dtype {name}") - - -def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): - run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128) - - -def run_gemm_rs_false_false(m, n, k): - run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) - - -TRANS_CASES = [ - pytest.param(False, False, id="nn"), - pytest.param(False, True, id="nt"), - pytest.param(True, False, id="tn"), - pytest.param(True, True, id="tt"), -] - - -@pytest.fixture(scope="module", autouse=True) -def _setup_tilelang_environment(): - tilelang.disable_cache() - tilelang.testing.set_random_seed(42) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) -def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): - import torch - - required_torch_attrs = { - in_dtype, - out_dtype, - accum_dtype, - } - for attr in required_torch_attrs: - if not hasattr(torch, attr): - pytest.skip(f"Torch does not expose dtype {attr}") - run_gemm( - m, - n, - k * 3, - False, - True, - in_dtype, - out_dtype, - accum_dtype, - m, - n, - k, - 2, - 128, - ) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -def test_gemm_false_false(m, n, k): - run_gemm( - m, - n, - k * 3, - False, - False, - T.float16, - T.float16, - T.float16, - m, - n, - k, - 2, - 128, - ) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) -def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): - _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) - run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype) - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") -def test_gemm_rs_false_false(m, n, k): - _ensure_torch_dtypes(T.float16) - run_gemm_rs_false_false(m, n, k) - - -if __name__ == "__main__": - tilelang.testing.main() - - # # Test Pass - # for m in [64, 128]: - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64]: - # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) - # print(f"Test {m} {n} {k} Pass") - - # # Test Pass - # for m in [64, 128]: - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64]: - # print(f"======================= Test {m} {n} {k} False False =============================") - # run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) - # print(f"Test {m} {n} {k} Pass") diff --git a/maint/gemm_v2/correctness_evaluation_tcgen05.py b/maint/gemm_v2/correctness_evaluation_tcgen05.py deleted file mode 100644 index c516520fad..0000000000 --- a/maint/gemm_v2/correctness_evaluation_tcgen05.py +++ /dev/null @@ -1,231 +0,0 @@ -# pytest correctness_evaluation.py -n 32 -import pytest -from tilelang import tvm as tvm -import tilelang.testing -import tilelang.language as T - - -def matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) - mbar = T.alloc_barrier(1) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.tcgen05_gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, clear_accum=k == 0) - T.mbarrier_wait_parity(mbar, k % 2) - - T.copy(C_tmem, C_local) - T.copy(C_local, C_shared) - - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return main - - -def _compile_and_check( - program, - trans_A, - trans_B, - in_dtype, - out_dtype, -): - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, - ) - - print(kernel.get_kernel_source()) - - profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - - def ref_program(A, B): - import torch - - if trans_A: - A = A.T - if trans_B: - B = B.T - if in_dtype == T.float32: - A = (A.view(torch.int32) - 0x1000).view(torch.float32) - B = (B.view(torch.int32) - 0x1000).view(torch.float32) - C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(out_dtype)) - return C - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - print("assert_allclose") - - -def run_gemm( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=2, - num_threads=128, -): - if block_N >= 256 or block_M >= 256 or block_K >= 256: - num_stages = 0 - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) - - -M_VALUES = [32, 64, 128, 256] -N_VALUES = [64, 128, 256, 512] -K_VALUES = [16, 32, 64, 128] -K_VALUES_8Bit = [32, 64, 128] -FALSE_TRUE_CASES = ( - [ - pytest.param( - k, - T.float16, - T.float32, - T.float32, - id=f"K{k}-float16-float-float", - ) - for k in K_VALUES - ] - + [ - pytest.param( - k, - T.float8_e5m2, - T.float32, - T.float32, - id="K32-float8_e5m2-float32-float32", - ) - for k in K_VALUES_8Bit - ] - + [ - pytest.param( - k, - T.int8, - T.int32, - T.int32, - id="K32-int8-int32-int32", - ) - for k in K_VALUES_8Bit - ] -) - -TRANS_CASES = [ - pytest.param(False, True, id="nt"), -] - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) -def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): - import torch - - required_torch_attrs = { - in_dtype, - out_dtype, - accum_dtype, - } - for attr in required_torch_attrs: - if not hasattr(torch, attr): - pytest.skip(f"Torch does not expose dtype {attr}") - run_gemm( - m, - n, - k * 3, - False, - True, - in_dtype, - out_dtype, - accum_dtype, - m, - n, - k, - ) - - -if __name__ == "__main__": - tilelang.testing.main() - - # # Test Pass - # for m in [32, 64, 128, 256]: - # for n in [16, 32, 64, 128]: - # for k in [16, 32, 64, 128]: - # if m in [32, 64] and (n not in [64, 128, 256]): - # continue - # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 128) - # print(f"Test {m} {n} {k} Pass") - - # # Test Pass - # for m in [32, 64, 128, 256]: - # for n in [32, 64, 128]: - # for k in [16, 32, 64, 128]: - # if m in [32, 64] and (n not in [64, 128, 256]): - # continue - # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 256) - # print(f"Test {m} {n} {k} Pass") - - # # Test Pass - # for m in [32, 64, 128, 256]: - # for n in [16, 32, 64, 128]: - # for k in [32, 64, 128]: - # if m in [32, 64] and (n not in [64, 128, 256]): - # continue - # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, T.float8_e5m2, T.float, T.float, m, n, k, 2, 128) diff --git a/maint/gemm_v2/correctness_evaluation_tcgen05_2cta.py b/maint/gemm_v2/correctness_evaluation_tcgen05_2cta.py deleted file mode 100644 index 3d83b575ee..0000000000 --- a/maint/gemm_v2/correctness_evaluation_tcgen05_2cta.py +++ /dev/null @@ -1,168 +0,0 @@ -# pytest correctness_evaluation_tcgen05_2cta.py -n 32 -import pytest -from tilelang import tvm as tvm -import tilelang -import tilelang.testing -import tilelang.language as T - -tilelang.disable_cache() - - -def matmul_2cta( - M, - N, - K, - block_M, - block_N, - block_K, - in_dtype, - out_dtype, - accum_dtype, - num_stages, -): - @T.prim_func - def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((K, N), in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128, cluster_dims=2) as (bx, by): - A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) - B_shared = T.alloc_shared((num_stages, block_K, block_N // 2), in_dtype) # each CTA holds half of B - C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - loaded = T.alloc_cluster_barrier([32 * 2] * num_stages) - consumed = T.alloc_cluster_barrier([1] * num_stages) - tmem_full = T.alloc_barrier([1]) - - tx = T.get_thread_binding() - cta_id = T.block_rank_in_cluster() - T.assume(cta_id < 2) - - T.use_swizzle(16) - - if tx < 32: # warp 0: issue TMA loads - for k in T.serial(T.ceildiv(K, block_K)): - T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) - T.tma_copy( - A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], - A_shared[k % num_stages, :, :], - barrier=loaded[k % num_stages], - ) - T.tma_copy( - B[k * block_K : (k + 1) * block_K, (by * 2 + cta_id) * (block_N // 2) : (by * 2 + cta_id + 1) * (block_N // 2)], - B_shared[k % num_stages, :, :], - barrier=loaded[k % num_stages], - ) - T.mbarrier_arrive(loaded[k % num_stages], 0) # arrive on leader CTA's barrier - elif cta_id == 0 and tx < 64: # warp 1 on leader CTA: issue tcgen5 MMA - for k in T.serial(T.ceildiv(K, block_K)): - T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1) - T.tcgen05_gemm( - A_shared[k % num_stages, :, :], - B_shared[k % num_stages, :, :], - C_tmem, - mbar=consumed[k % num_stages], - clear_accum=k == 0, - use_2cta=True, - ) - T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) - - T.mbarrier_wait_parity(tmem_full, 0) - T.copy(C_tmem, C_local) - T.copy(C_local, C_shared) - T.copy(C_shared, C[bx * block_M, by * block_N]) - - return main - - -def _compile_and_check(program, out_dtype): - kernel = tilelang.compile(program, out_idx=[2]) - - print(kernel.get_kernel_source()) - - profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - - def ref_program(A, B): - import torch - - C = torch.matmul(A.to(torch.float), B.to(torch.float)) - return C.to(torch.__getattribute__(out_dtype)) - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - print("assert_allclose passed") - - -def run_gemm( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - block_M, - block_N, - block_K, - num_stages=4, -): - program = matmul_2cta( - M, - N, - K, - block_M, - block_N, - block_K, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - ) - _compile_and_check(program, out_dtype) - - -M_VALUES = [64, 128, 256] -N_VALUES = [64, 128, 256] - -# atom_k=16 for fp16/bf16 (K%16==0), atom_k=32 for fp8/int8 (K%32==0) -K_VALUES_16 = [16, 32, 64, 128] -K_VALUES_32 = [32, 64, 128] - -# Dtype cases: (block_K, in_dtype, out_dtype, accum_dtype) -FP16_CASES = [pytest.param(k, T.float16, T.float32, T.float32, id=f"K{k}-fp16-fp32-fp32") for k in K_VALUES_16] - -FP8_E5M2_CASES = [pytest.param(k, T.float8_e5m2, T.float32, T.float32, id=f"K{k}-fp8e5m2-fp32-fp32") for k in K_VALUES_32] - -INT8_CASES = [pytest.param(k, T.int8, T.int32, T.int32, id=f"K{k}-int8-int32-int32") for k in K_VALUES_32] - -ALL_DTYPE_CASES = FP16_CASES + FP8_E5M2_CASES + INT8_CASES - - -@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") -@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") -@pytest.mark.parametrize("block_k,in_dtype,out_dtype,accum_dtype", ALL_DTYPE_CASES) -def test_gemm_2cta(m, n, block_k, in_dtype, out_dtype, accum_dtype): - import torch - - for attr in {in_dtype, out_dtype, accum_dtype}: - if not hasattr(torch, attr): - pytest.skip(f"Torch does not expose dtype {attr}") - - # M = 2 * block_M so ceildiv(M, block_M) = 2 (cluster needs >= 2 tiles in M dim) - # K = 3 * block_K to exercise multi-iteration pipelining - k = block_k * 3 - run_gemm( - m * 2, - n, - k, - in_dtype, - out_dtype, - accum_dtype, - block_M=m, - block_N=n, - block_K=block_k, - ) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/maint/gemm_v2/latency.py b/maint/gemm_v2/latency.py deleted file mode 100644 index b7b2a2af95..0000000000 --- a/maint/gemm_v2/latency.py +++ /dev/null @@ -1,98 +0,0 @@ -import tilelang -import tilelang.language as T -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument("--use_v2", action="store_true") -args = parser.parse_args() - -use_v2 = args.use_v2 - - -# @tilelang.jit(target="cuda") -# target currently can be "cuda" or "hip" or "cpu". -# if not specified, it will be inferred from the input tensors during compile time -@tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): - @T.prim_func - def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - # Initialize Kernel Context - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Enable rasterization for better L2 cache locality (Optional) - # T.use_swizzle(panel_size=10, enable=True) - - # Clear local accumulation - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - # Copy tile of A - # This is a sugar syntax for parallelized copy - T.copy(A[by * block_M, ko * block_K], A_shared) - - # Copy tile of B - T.copy(B[ko * block_K, bx * block_N], B_shared) - - # Perform a tile-level GEMM on the shared buffers - # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs - if use_v2: - T.gemm_v2(A_shared, B_shared, C_local) - else: - T.gemm_v1(A_shared, B_shared, C_local) - - # relu - for i, j in T.Parallel(block_M, block_N): - C_local[i, j] = T.max(C_local[i, j], 0) - - # Copy result back to global memory - T.copy(C_local, C[by * block_M, bx * block_N]) - - return matmul_relu_kernel - - -M = 16384 # M = T.dynamic("m") if you want to use dynamic shape -N = 16384 -K = 16384 -block_M = 128 -block_N = 128 -block_K = 32 - -# 1. Define the kernel (matmul) and compile/lower it into an executable module -matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) - -# 3. Test the kernel in Python with PyTorch data -import torch - -# Create random input tensors on the GPU -a = torch.randn(M, K, device="cuda", dtype=torch.float16) -b = torch.randn(K, N, device="cuda", dtype=torch.float16) -c = torch.empty(M, N, device="cuda", dtype=torch.float16) - -# Run the kernel through the Profiler -matmul_relu_kernel(a, b, c) - -print(c) -# Reference multiplication using PyTorch -ref_c = torch.relu(a @ b) - -# Validate correctness -torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) -print("Kernel output matches PyTorch reference.") - -# 4. Retrieve and inspect the generated CUDA source (optional) -# cuda_source = jit_kernel.get_kernel_source() -# print("Generated CUDA kernel:\n", cuda_source) - -# 5.Profile latency with kernel -profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - -latency = profiler.do_bench() - -print(f"Latency: {latency} ms") diff --git a/maint/gemm_v2/latency_gemm.py b/maint/gemm_v2/latency_gemm.py deleted file mode 100644 index 5f0450e023..0000000000 --- a/maint/gemm_v2/latency_gemm.py +++ /dev/null @@ -1,98 +0,0 @@ -import tilelang -import tilelang.language as T -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument("--use_v2", action="store_true") -args = parser.parse_args() - -use_v2 = args.use_v2 - - -# @tilelang.jit(target="cuda") -# target currently can be "cuda" or "hip" or "cpu". -# if not specified, it will be inferred from the input tensors during compile time -@tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): - @T.prim_func - def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - # Initialize Kernel Context - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Enable rasterization for better L2 cache locality (Optional) - # T.use_swizzle(panel_size=10, enable=True) - - # Clear local accumulation - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - # Copy tile of A - # This is a sugar syntax for parallelized copy - T.copy(A[by * block_M, ko * block_K], A_shared) - - # Copy tile of B - T.copy(B[ko * block_K, bx * block_N], B_shared) - - # Perform a tile-level GEMM on the shared buffers - # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs - if use_v2: - T.gemm_v2(A_shared, B_shared, C_local) - else: - T.gemm_v1(A_shared, B_shared, C_local) - - # relu - for i, j in T.Parallel(block_M, block_N): - C_local[i, j] = T.max(C_local[i, j], 0) - - # Copy result back to global memory - T.copy(C_local, C[by * block_M, bx * block_N]) - - return matmul_relu_kernel - - -M = 16384 # M = T.dynamic("m") if you want to use dynamic shape -N = 16384 -K = 16384 -block_M = 128 -block_N = 128 -block_K = 64 - -# 1. Define the kernel (matmul) and compile/lower it into an executable module -matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) - -# 3. Test the kernel in Python with PyTorch data -import torch - -# Create random input tensors on the GPU -a = torch.randn(M, K, device="cuda", dtype=torch.float16) -b = torch.randn(K, N, device="cuda", dtype=torch.float16) -c = torch.empty(M, N, device="cuda", dtype=torch.float16) - -# Run the kernel through the Profiler -matmul_relu_kernel(a, b, c) - -print(c) -# Reference multiplication using PyTorch -ref_c = torch.relu(a @ b) - -# Validate correctness -torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) -print("Kernel output matches PyTorch reference.") - -# 4. Retrieve and inspect the generated CUDA source (optional) -# cuda_source = jit_kernel.get_kernel_source() -# print("Generated CUDA kernel:\n", cuda_source) - -# 5.Profile latency with kernel -profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - -latency = profiler.do_bench() - -print(f"Latency: {latency} ms") diff --git a/maint/gemm_v2/latency_mha_fwd_bhsd.py b/maint/gemm_v2/latency_mha_fwd_bhsd.py deleted file mode 100644 index 7a83d7cec8..0000000000 --- a/maint/gemm_v2/latency_mha_fwd_bhsd.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch -import torch.nn.functional as F -import tilelang -from tilelang.autotuner import * -import tilelang.language as T -import itertools -import argparse -from functools import partial - -parser = argparse.ArgumentParser() -parser.add_argument("--batch", type=int, default=128, help="batch size") -parser.add_argument("--heads", type=int, default=16, help="heads") -parser.add_argument("--seq_q", type=int, default=1024, help="query sequence length") -parser.add_argument("--seq_kv", type=int, default=1024, help="key/value sequence length") -parser.add_argument("--dim", type=int, default=256, help="dim") -parser.add_argument("--is_causal", action="store_true", help="causal") -parser.add_argument("--tune", action="store_true", help="tune configs") -parser.add_argument("--use_v2", action="store_true") - -args = parser.parse_args() - -use_v2 = args.use_v2 - - -def get_configs(): - iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) - return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] - - -@autotune(configs=get_configs(), warmup=10, rep=10) -@tilelang.jit( - out_idx=[3], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, -) -def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128): - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - q_shape = [batch, heads, seq_q, dim] - kv_shape = [batch, heads, seq_kv, dim] - dtype = T.float16 - accum_dtype = T.float32 - - past_len = seq_kv - seq_q - assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" - - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - if use_v2: - T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - else: - T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) - # T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if use_v2: - T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - else: - T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - @T.prim_func - def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - ): - with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([block_M, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = ( - T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - if is_causal - else T.ceildiv(seq_kv, block_N) - ) - - for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) - - return main - - -def ref_program(Q, K, V, is_causal): - dim = Q.size(-1) - scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) - scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) - if is_causal: - seq_q = Q.size(2) - seq_kv = K.size(2) - mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) - mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float("-inf")) - attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) - return output - - -def main( - batch: int = 1, - heads: int = 1, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 64, - is_causal: bool = False, - tune: bool = False, -): - flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim - total_flops = 2 * flops_per_matmul - if is_causal: - total_flops *= 0.5 - - if not tune: - kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128) - print(kernel.get_kernel_source()) - ref_program_processed = partial(ref_program, is_causal=is_causal) - - profiler = kernel.get_profiler() - profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) - print("All checks pass.") - latency = profiler.do_bench(ref_program_processed, warmup=500) - print(f"Ref: {latency:.2f} ms") - print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops") - latency = profiler.do_bench(warmup=500) - print(f"Tile-lang: {latency:.2f} ms") - print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops") - else: - kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) - best_latency = kernel.latency - best_config = kernel.config - ref_latency = kernel.ref_latency - print(f"Best latency: {best_latency}") - print(f"Best TFlops: {total_flops / best_latency * 1e-9}") - print(f"Best config: {best_config}") - print(f"Ref latency: {ref_latency}") - - -if __name__ == "__main__": - tilelang.disable_cache() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/maint/scripts/run_local_ci_test.sh b/maint/scripts/run_local_ci_test.sh index adbf9627a9..a1de0d1ec7 100755 --- a/maint/scripts/run_local_ci_test.sh +++ b/maint/scripts/run_local_ci_test.sh @@ -190,7 +190,5 @@ python -m pytest "${PYTEST_ARGS_DEVICE[@]}" . "${PYTEST_ARGS_COMMON[@]}" cd .. || exit 1 # Run pytest in parallel for all tests in the testing/python directory. -# CuTeDSL backend now defaults to GEMM v2 (WGMMA descriptor-based). -# Set TILELANG_USE_GEMM_V1=1 only if you need to debug V1-specific issues. cd testing/python || exit 1 python -m pytest "${PYTEST_ARGS_DEVICE[@]}" . "${PYTEST_ARGS_COMMON[@]}" diff --git a/src/op/builtin.h b/src/op/builtin.h index f4da8bb7f9..276800bf6c 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -696,7 +696,13 @@ TVM_DLL const Op &tvm_rdna_wmma_store(); /*! * \brief tilelang intrinsic for general matrix multiplication (GEMM). * - * This op is used to represent a generic GEMM operation in tilelang. + * This op wraps a templated `tl::gemm_*<...>` call into the generated device + * code. Python-side lowering backends that want to delegate to the C++ + * template implementations in `src/tl_templates//gemm*.h` can emit a + * call to this builtin directly via + * T.call_intrin("handle", "tl.tl_gemm", op_instance_str, A_ptr, B_ptr, + * C_ptr) where `op_instance_str` is the fully-instantiated `tl::gemm_ss` template string. */ TVM_DLL const Op &tl_gemm(); diff --git a/src/op/gemm.cc b/src/op/gemm.cc index ebe717bd44..2facb9d9f3 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -21,35 +21,21 @@ namespace tl { using namespace tir; /** - * @brief Construct a Gemm operator from serialized TL arguments and a buffer - * map. + * @brief Construct a Gemm operator from serialized TL arguments. * - * This constructor deserializes operator parameters from `args` and resolves - * buffer references via `vmap`, populating an internal GemmNode with: - * - device pointers for A, B, C and their corresponding Buffer objects, - * - transpose flags for A and B, - * - matrix dimensions M, N, K, - * - warp allocation policy and clear_accum flag, - * - strides and memory offsets for A and B, - * - optional kPack (must be 1 or 2) and optional internal wg_wait. - * - * The populated GemmNode is stored into the wrapper's internal `data_`. + * Deserializes operator parameters from `args` and resolves buffer references, + * populating an internal GemmNode with buffers, transpose flags, M/N/K, + * warp policy, clear_accum, strides, offsets, optional kPack/wg_wait, and + * optional mbarrier. * * @param args Positional serialized arguments produced by the TL frontend: * expected layout is: * [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), - * (optional) kPack (Int), (optional) internal wg_wait (Int)] - * - * @note If `kPack` is provided it must be 1; otherwise the constructor - * fails with an ICHECK (runtime assertion). No other validation is - * performed here. + * (optional) kPack (Int), (optional) internal wg_wait (Int), + * (optional) mbar (BufferLoad), cCoord_y (PrimExpr), cCoord_x (PrimExpr)] */ -// NormalizeToBufferRegion moved to src/op/utils.{h,cc} - -// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} - Gemm::Gemm(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); @@ -115,14 +101,6 @@ AccessRegions GemmNode::GetAccessRegions() const { return result; } -/** - * @brief Create a copy of this GemmNode as a TileOperator. - * - * Constructs a new GemmNode by copying the current node state and returns it - * wrapped in a Gemm TileOperator. - * - * @return TileOperator A Gemm operator that owns a copy of this node. - */ TileOperator GemmNode::Clone() const { auto op = tvm::ffi::make_object(*this); return Gemm(op); @@ -133,6 +111,7 @@ bool GemmNode::allowTcgen5Mma(Target target) const { IsSharedBuffer(b_) && c_.scope() == "shared.tmem"; if (!TargetIsSm100(target) || !scope_ok) return false; + // For TS variant (A from TMEM), use B's dtype as the input dtype DataType ab_dtype = (a_.scope() == "shared.tmem") ? b_->dtype : a_->dtype; return GetTCGEN5MMAMeta(m_, n_, k_, ab_dtype, c_->dtype).first; } @@ -172,9 +151,11 @@ GemmInst GemmNode::getGemmInst(int block_size, Target target) const { } return GemmInst::kTCGEN5MMA; } - if (allowTcgen5Mma(target)) { + bool allow_tcgen5mma = allowTcgen5Mma(target); + bool allow_wgmma = allowWgmma(block_size, target); + if (allow_tcgen5mma) { return GemmInst::kTCGEN5MMA; - } else if (allowWgmma(block_size, target)) { + } else if (allow_wgmma) { return GemmInst::kWGMMA; } else if (TargetIsCDNA(target)) { return GemmInst::kMFMA; @@ -182,8 +163,10 @@ GemmInst GemmNode::getGemmInst(int block_size, Target target) const { return GemmInst::kWMMA; } else if (TargetIsCuda(target)) { return GemmInst::kMMA; + } else if (TargetIsCPU(target)) { + return GemmInst::kScalar; } else { - ICHECK(0) << "Unsupported target for gemm: " << target; + ICHECK(0) << "Unsupported target for gemm: " << target->str(); return GemmInst::kMMA; } } @@ -379,32 +362,8 @@ std::pair GemmWarpPolicyNode::computeWarpPartition( /** * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM. * - * Evaluates device-memory placement, data-type combinations, transpose flags, - * and K divisibility constraints required for the Hopper WGMMA code path. - * - * The check returns true only when: - * - B resides in shared memory ("shared" or "shared.dyn"); and - * - (C, A, B) dtypes match one of the supported combinations below and K - * satisfies the required alignment; and - * - for combinations that require specific orientations, A is not transposed - * and B is transposed. - * - * Supported combinations and constraints: - * - C=float16: - * - A=float16, B=float16: K % 16 == 0 - * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % - * 32 == 0 - * - C=float32: - * - A=float16, B=float16: K % 16 == 0 - * - A=bfloat16, B=bfloat16: K % 16 == 0 - * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 - * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 - * - C=int32: - * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) - * and K % 32 == 0 - * - * @return true if WGMMA is supported for the current buffers, dtypes, and - * transpose/shape constraints; false otherwise. + * Returns true only when B resides in shared memory and the (C, A, B) dtype + * combination plus K alignment matches one of the supported WGMMA variants. */ bool GemmNode::checkWgmma() const { if (b_.scope() != "shared.dyn" && b_.scope() != "shared") { @@ -447,436 +406,71 @@ bool GemmNode::checkWgmma() const { } } -/** - * @brief Parse and return the numeric GPU architecture from a Target's "arch" - * attribute. - * - * Examines the target's "arch" string and, if it matches the pattern - * "sm_", returns as an int. If the attribute is present but does not - * match that pattern, returns 0. - * - * Preconditions: the target must have an "arch" attribute (this is checked via - * ICHECK). - * - * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if - * the arch string does not match "sm_". - */ -static int GetArchInt(Target target) { - int arch_int = 0; - auto s = target->GetAttr("arch"); - ICHECK(s.has_value()); - std::string arch = s.value(); - if (arch.rfind("sm_", 0) == 0) { - arch_int = std::stoi(arch.substr(3)); - } else { - arch_int = 0; - } - return arch_int; -} - -/** - * @brief Lower the GEMM operator to a TL TIR call expression. - * - * Constructs a tl::gemm call string parameterized by M, N, K, warp partition, - * transpose flags, accumulation clearing, target-specific stride/offset/kPack - * and optional workgroup wait value, then returns an Evaluate(call) node - * invoking tl::tl_gemm with the composed string and the A/B/C buffer handles. - * - * @param T Contains lowering context including thread bounds and target. - * @param analyzer Optional arithmetic analyzer used by lowering (may be - * nullptr). - * @return Stmt A TIR statement representing the evaluated TL GEMM call. - */ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { - auto block_size = *as_const_int(T.thread_bounds->extent); - GemmInst gemm_inst = getGemmInst(block_size, T.target); - auto [warp_m, warp_n] = - policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); - - // Build access pointers from regions locally - PrimExpr Aptr = - MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true); - PrimExpr Bptr = - MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true); - PrimExpr Cptr = - MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true); - - std::stringstream ss; - std::string op_name; - - if (gemm_inst == GemmInst::kTCGEN5MMA) { - DataType ab_dtype = (a_.scope() == "shared.tmem") ? b_->dtype : a_->dtype; - auto [can_use_tcgen5mma, meta] = - GetTCGEN5MMAMeta(m_, n_, k_, ab_dtype, c_->dtype); - ICHECK(can_use_tcgen5mma); - ICHECK(IsSharedBuffer(b_)); - ICHECK(c_.scope() == "shared.tmem"); - ICHECK(mbar_.defined()) << "mbar must be provided for TCGEN5MMA"; - if (a_.scope() == "shared.tmem") { - op_name = "tl::tcgen5mma_gemm_ts"; - } else if (IsSharedBuffer(a_)) { - op_name = "tl::tcgen5mma_gemm_ss"; - } else { - ICHECK(0) - << "Unsupported A scope for TCGEN5MMA: " - << a_.scope(); // If this is triggered, it means Tilelang has bugs. - } - ICHECK(wgWait_ == 0 || wgWait_ == -1) - << "TCGEN5MMA only accepts internal wg_wait values 0 or -1. " - "Public T.gemm() uses 0 and synchronization is still managed " - "manually via mbarrier."; - - std::string accum_dtype = ""; - if (c_->dtype.is_float()) { - if (c_->dtype.bits() == 32) { - accum_dtype = "float"; - } - } - ICHECK(!accum_dtype.empty()) - << "Unsupported C dtype for TCGEN5MMA: " << c_->dtype; - ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", "; - ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", "; - ss << transA_ << ", " << transB_ << ", "; - ss << accum_dtype; - ss << ">"; - - auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_; - Array new_args; - auto mbarPtr = MakeAccessPtrFromBufferLoad(mbar_, /*rw*/ 3); - new_args.push_back(StringImm(ss.str())); - new_args.push_back(Aptr); - new_args.push_back(Bptr); - new_args.push_back(BufferLoad(C_buffer, cCoords_)); - new_args.push_back(mbarPtr); - new_args.push_back(clearAccum_); - auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args); - - // Since TCGEN5MMA atoms provided by CUTLASS always have an internal - // `elect_one_sync()`, we check if we are calling it using full warps - constexpr int warp_size = 32; - ICHECK( - analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) && - analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size), - 0)) - << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) " - "and aligned to warps."; - Stmt tcgen5mma_call; - if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) { - // If the thread bounds is exactly one warp, we can use the original call - tcgen5mma_call = Evaluate(new_call); - } else { - // Add an if-else clause - tcgen5mma_call = IfThenElse(EQ(FloorDiv(T.thread_var, warp_size), - FloorDiv(T.thread_bounds->min, warp_size)), - Evaluate(new_call)); - } - if (isTcgen05_) { - return tcgen5mma_call; - } + if (const auto f = ffi::Function::GetGlobal("tl.gemm.lower")) { PrimExpr mbar_phase = T.mbar_phase_expr; if (auto explicit_phase = GetAnnotatedMbarPhaseExpr(annotations_)) { mbar_phase = explicit_phase.value(); } - Stmt wait_stmt = Evaluate( - Call(DataType::Handle(), mbarrier_wait_parity(), {mbar_, mbar_phase})); - return SeqStmt({tcgen5mma_call, wait_stmt}); - } - - if (IsFragmentBuffer(a_)) { - ICHECK(!IsFragmentBuffer(b_)); - ICHECK(!transA_) - << "gemm_rs requires the A operand to be in non-transposed layout."; - op_name = "tl::gemm_rs"; - } else if (IsFragmentBuffer(b_)) { - op_name = "tl::gemm_sr"; - } else { - op_name = "tl::gemm_ss"; - } - ICHECK(IsFragmentBuffer(c_)); - - ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", "; - ss << warp_m << ", " << warp_n << ", "; - ss << transA_ << ", " << transB_; - auto clear_accum_bool = clearAccum_.as(); - ICHECK(clear_accum_bool.has_value()) - << "clear_accum must be a constant Bool type, got " << clearAccum_; - ss << ", " << bool(clear_accum_bool.value()); - if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) { - ss << ", " << strideA_ << ", " << strideB_; - ss << ", " << offsetA_ << ", " << offsetB_; - } - if (TargetIsCDNA(T.target)) { - // for cdna gemm, we need to specify kPack - ss << ", " << kPack_; - } else if (TargetIsHopper(T.target)) { - ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false"); - } - - // Emit internal wg_wait only for Hopper WGMMA lowering. - if (TargetIsHopper(T.target)) { - if (wgWait_ != 0) { - ss << ", " << wgWait_; + // NOTE(wt): Decide GemmInst and compute warp partition on Python side + auto prim_func = Downcast( + (*f)(tvm::ffi::GetRef(this), T.layout_map, T.target, + T.thread_bounds, T.thread_var, mbar_phase)); + ICHECK(prim_func->attrs.defined()); + auto global_symbol = + prim_func->attrs.GetAttr("global_symbol"); + ICHECK(global_symbol.has_value()); + if (prim_func->body.as()) { + BlockRealize block_realize = Downcast(prim_func->body); + auto block = block_realize->block; + { + BlockNode *n = block.CopyOnWrite(); + n->name_hint = global_symbol.value(); + n->annotations.Set(tl::attr::kLexicalAllocScope, + IntImm(DataType::Int(32), 1)); + } + return BlockRealize(block_realize->iter_values, block_realize->predicate, + block); } - } else if (TargetIsSm100(T.target)) { - // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction - // but all threads need to wait, so we emit another statement for cases - // where wg_wait == 0. - ICHECK(wgWait_ == 0 || wgWait_ == -1) - << "wg_wait must be 0 or -1 for Sm100"; + // wrap with block realize node + Map block_annotations; + block_annotations.Set(tl::attr::kLexicalAllocScope, + IntImm(DataType::Int(32), 1)); + return BlockRealize( + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/global_symbol.value(), prim_func->body, + /*init=*/Optional(), /*alloc_buffers=*/{}, + /*match_buffers=*/{}, /*annotations=*/block_annotations)); } else { - ICHECK(wgWait_ == 0) - << "wg_wait must be 0 for non-Hopper and non-Sm100 targets"; + LOG(FATAL) << "No lower function found for gemm"; + return Stmt(); } - ss << ">"; - - auto new_call = Call(DataType::Handle(), tl::tl_gemm(), - Array{StringImm(ss.str()), Aptr, Bptr, Cptr}); - return Evaluate(new_call); } -/** - * @brief Infer and bind target-specific memory/layout mappings for A, B, and C. - * - * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM - * operator according to the target architecture, thread bounds, warp - * partitioning, data types, and transpose flags, then binds fragment layouts - * to the thread range when required. - * - * Preconditions: - * - C.scope() == "local.fragment" - * - * Side effects: - * - Marks layout inference as completed (sets completed_ = true). - * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or - * incompatible shape constraints. - * - * @param T Input layout-inference context (provides thread bounds and target). - * @return LayoutMap mapping A, B, and C to their inferred layouts. - */ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { if (completed_) return {}; LayoutMap results; - auto thread_range = T.thread_bounds; - auto block_size = *as_const_int(thread_range->extent); - GemmInst gemm_inst = getGemmInst(block_size, T.target); - auto [warp_m, warp_n] = - policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); - if (TargetIsVolta(T.target)) { - ICHECK(IsFragmentBuffer(c_)) - << "Volta gemm only supports C in local.fragment scope, got " - << c_.scope(); - auto fragment = makeGemmVoltaFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, - c_->dtype.bits()); - results.Set(c_, fragment->BindThreadRange(thread_range)); - if (IsSharedBuffer(a_)) { - int dim_A = a_->shape.size(); - auto layout = makeGemmVoltaABLayout(*as_const_int(a_->shape[dim_A - 2]), - *as_const_int(a_->shape[dim_A - 1]), - true, !transA_); - results.Set(a_, ExpandLayoutToMatchBuffer(layout, a_)); - } else if (IsFragmentBuffer(a_)) { - ICHECK(transA_ == false); - auto fragment = - makeGemmVoltaFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n); - results.Set(a_, fragment->BindThreadRange(thread_range)); - } else { - ICHECK(0); - } - ICHECK(IsSharedBuffer(b_)); - int dim_B = b_->shape.size(); - auto layout = makeGemmVoltaABLayout(*as_const_int(b_->shape[dim_B - 2]), - *as_const_int(b_->shape[dim_B - 1]), - false, transB_); - results.Set(b_, ExpandLayoutToMatchBuffer(layout, b_)); - } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) || - TargetIsSM120(T.target) || - (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) { - ICHECK(IsFragmentBuffer(c_)) - << "MMA only supports C in local.fragment scope, got " << c_.scope(); - - auto fragment = - makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits()); - results.Set(c_, fragment->BindThreadRange(thread_range)); - - if (IsSharedBuffer(a_)) { - int dim_A = a_->shape.size(); - const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]); - const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]); - auto layout = makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, - a_->dtype.bits(), !transA_); - results.Set(a_, ExpandLayoutToMatchBuffer(layout, a_)); - } else if (IsFragmentBuffer(a_)) { - auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n, - a_->dtype.bits(), transA_); - results.Set(a_, fragment->BindThreadRange(thread_range)); - } else { - ICHECK(0); - } - if (IsSharedBuffer(b_)) { - int dim_B = b_->shape.size(); - const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]); - const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]); - auto layout = makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, - b_->dtype.bits(), transB_); - results.Set(b_, ExpandLayoutToMatchBuffer(layout, b_)); - } else if (IsFragmentBuffer(b_)) { - auto fragment = - makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_); - results.Set(b_, fragment->BindThreadRange(thread_range)); - } else { - ICHECK(0); - } - } else if (TargetIsHopper(T.target)) { - ICHECK(IsFragmentBuffer(c_)) - << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ") - << "only supports C in local.fragment scope, got " << c_.scope(); - auto fragment = gemm_inst == GemmInst::kWGMMA - ? makeGemmFragmentCHopper(m_, n_, m_ / warp_m, - n_ / warp_n, c_->dtype.bits()) - : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, - c_->dtype.bits()); - results.Set(c_, fragment->BindThreadRange(thread_range)); - if (IsSharedBuffer(a_)) { - int dim_A = a_->shape.size(); - const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]); - const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]); - const int64_t continuity = - transA_ ? 4 * mat_continuous / warp_m : mat_continuous; - auto ABLayout = - gemm_inst == GemmInst::kWGMMA - ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, - a_->dtype.bits(), !transA_) - : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, - a_->dtype.bits(), !transA_); - results.Set(a_, ExpandLayoutToMatchBuffer(ABLayout, a_)); - } else { - auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n, - a_->dtype.bits(), transA_); - results.Set(a_, fragment->BindThreadRange(thread_range)); - } - if (IsSharedBuffer(b_)) { - int dim_B = b_->shape.size(); - const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]); - const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]); - const int64_t continuity = - transB_ ? mat_continuous : mat_continuous / warp_n; - - auto ABLayout = - gemm_inst == GemmInst::kWGMMA - ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, - b_->dtype.bits(), transB_) - : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, - b_->dtype.bits(), transB_); - results.Set(b_, ExpandLayoutToMatchBuffer(ABLayout, b_)); - } else { - auto fragment = - makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_); - results.Set(b_, fragment->BindThreadRange(thread_range)); - } - } else if (gemm_inst == GemmInst::kTCGEN5MMA) { - ICHECK(c_.scope() == "shared.tmem") - << "TCGEN5MMA only supports C in shared.tmem scope, got " << c_.scope(); - ICHECK(IsSharedBuffer(a_)) - << "Current TCGEN5MMA only supports A in shared.dyn scope"; - DataType ab_dtype = (a_.scope() == "shared.tmem") ? b_->dtype : a_->dtype; - auto [can_use_tcgen5mma, meta] = - GetTCGEN5MMAMeta(m_, n_, k_, ab_dtype, c_->dtype); - ICHECK(can_use_tcgen5mma); - { - int dim_A = a_->shape.size(); - const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]); - const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]); - auto layout = - makeGemmABLayoutSm100(mat_stride, mat_continuous, mat_continuous, - a_->dtype.bits(), transA_ ? 1 : 2); - results.Set(a_, ExpandLayoutToMatchBuffer(layout, a_)); - } - { - int dim_B = b_->shape.size(); - const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]); - const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]); - const int64_t continuity = mat_continuous; - auto layout = - makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity, - b_->dtype.bits(), transB_ ? 2 : 1); - results.Set(b_, ExpandLayoutToMatchBuffer(layout, b_)); - } - { - Layout res; - IterVar i = make_itervar("i", m_); - IterVar j = make_itervar("j", n_); - ICHECK(m_ % meta.atom_m == 0); - PrimExpr atom_idx = FloorDiv(i, meta.atom_m) + - FloorDiv(j, meta.atom_n) * (m_ / meta.atom_m); - PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i" - PrimExpr aj = FloorMod(j, meta.atom_n); - if (meta.atom_m == 128) { - // Layout D - // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d) - res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n}); - } else if (meta.atom_m == 64) { - // Layout E - // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e) - // since .ws variant is used About why we use .ws variant here, please - // refer to gemm_sm100.h - res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) + - FloorDiv(aj, meta.atom_n / 2) * 64, - FloorMod(aj, meta.atom_n / 2) + - atom_idx * (meta.atom_n / 2)}); - } else if (meta.atom_m == 32) { - // Layout G - // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g) - res = Layout( - Array{i, j}, - {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32, - FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)}); - } else { - ICHECK(0); + if (const auto f = ffi::Function::GetGlobal("tl.gemm.infer_layout")) { + results = Downcast( + (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds)); + // Bind all fragment layouts with the provided thread range + for (auto kv : results) { + const Buffer &buf = kv.first; + const Layout &layout = kv.second; + if (auto frag = layout.as()) { + results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds)); } - results.Set(c_, res); - } - } else if (TargetIsCDNA(T.target)) { - ICHECK(IsFragmentBuffer(c_)) - << "CDNA gemm (FMMA) only supports C in local.fragment scope, got " - << c_.scope(); - auto fragment = makeGemmFragmentCCDNA(m_, n_, m_ / warp_m, n_ / warp_n, - c_->dtype.bits()); - results.Set(c_, fragment->BindThreadRange(thread_range)); - - if (IsSharedBuffer(a_)) { - int dim_A = a_->shape.size(); - auto shared_layout = makeGemmABLayoutCDNA( - *as_const_int(a_->shape[dim_A - 2]), - *as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_); - results.Set(a_, ExpandLayoutToMatchBuffer(shared_layout, a_)); - } else if (IsFragmentBuffer(a_)) { - auto fragment = - makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m, n_ / warp_n, - a_->dtype.bits(), kPack_, transA_); - results.Set(a_, fragment->BindThreadRange(thread_range)); - } else { - ICHECK(0); - } - if (IsSharedBuffer(b_)) { - int dim_B = b_->shape.size(); - auto shared_layout = makeGemmABLayoutCDNA( - *as_const_int(b_->shape[dim_B - 2]), - *as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_); - - results.Set(b_, ExpandLayoutToMatchBuffer(shared_layout, b_)); - } else if (IsFragmentBuffer(b_)) { - auto fragment = - makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_); - results.Set(b_, fragment->BindThreadRange(thread_range)); - } else { - ICHECK(0); } } else { - ICHECK(0) << "Not supported " << T.target->str(); + LOG(FATAL) << "No infer layout function found for gemm"; } + completed_ = true; return results; } @@ -920,6 +514,9 @@ TVM_REGISTER_OP("tl.GemmWarpPolicy") TVM_FFI_STATIC_INIT_BLOCK() { GemmNode::RegisterReflection(); GemmWarpPolicyNode::RegisterReflection(); +} + +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition", [](GemmWarpPolicy policy, int M, int N, int block_size, @@ -927,6 +524,39 @@ TVM_FFI_STATIC_INIT_BLOCK() { policy->computeWarpPartition(M, N, block_size, target, gemm_inst); }); + refl::GlobalDef().def("tl.GemmGetGemmInst", + [](Gemm gemm, int block_size, Target target) { + return gemm->getGemmInst(block_size, target); + }); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tl.get_tcgen5_mma_meta", [](int M, int N, int K, DataType ab_dtype, + DataType c_dtype, bool disable_2cta) { + auto [success, meta] = + GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype, disable_2cta); + Array result; + if (success) { + result.push_back(Integer(meta.atom_m)); + result.push_back(Integer(meta.atom_n)); + result.push_back(Integer(meta.atom_k)); + result.push_back(Integer(meta.enable_ws)); + result.push_back(Integer(meta.enable_2cta)); + } + return result; + }); + refl::GlobalDef().def( + "tl.get_tcgen5_instr_desc", + [](int atom_m, int atom_n, int atom_k, DataType ab_dtype, + DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a, + int scale_in_b) { + uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype, + c_dtype, a_is_k_major, b_is_k_major, + scale_in_a, scale_in_b); + return Integer(static_cast(desc)); + }); } } // namespace tl diff --git a/src/op/gemm.h b/src/op/gemm.h index 523e8bafb3..2c78f54760 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -129,6 +129,8 @@ class GemmWarpPolicy : public ObjectRef { class GemmNode : public TileOperatorNode { public: bool checkWgmma() const; + bool allowTcgen5Mma(Target target) const; + bool allowWgmma(int block_size, Target target) const; tir::Buffer a_, b_, c_; // BufferRegion for A, B and C BufferRegion aRegion_, bRegion_, cRegion_; @@ -137,16 +139,17 @@ class GemmNode : public TileOperatorNode { int strideA_, strideB_; int offsetA_, offsetB_; PrimExpr clearAccum_ = const_false(); + tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA + Array cCoords_; // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack // only will be enabled under cdna mfma instructions int kPack_ = 1; int wgWait_ = 0; bool isWgmma_ = false; bool isTcgen05_ = false; - tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA - Array cCoords_; mutable GemmWarpPolicy policy_; Map annotations_; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode); static void RegisterReflection() { @@ -168,12 +171,12 @@ class GemmNode : public TileOperatorNode { .def_ro("offsetA", &GemmNode::offsetA_) .def_ro("offsetB", &GemmNode::offsetB_) .def_ro("clearAccum", &GemmNode::clearAccum_) + .def_ro("mbar", &GemmNode::mbar_) + .def_ro("cCoords", &GemmNode::cCoords_) .def_ro("kPack", &GemmNode::kPack_) .def_ro("wgWait", &GemmNode::wgWait_) .def_ro("isWgmma", &GemmNode::isWgmma_) .def_ro("isTcgen05", &GemmNode::isTcgen05_) - .def_ro("mbar", &GemmNode::mbar_) - .def_ro("cCoords", &GemmNode::cCoords_) .def_ro("policy", &GemmNode::policy_) .def_ro("annotations", &GemmNode::annotations_); } @@ -185,12 +188,10 @@ class GemmNode : public TileOperatorNode { TileOperator Clone() const; + // Target GEMM instruction GemmInst getGemmInst(int block_size, Target target) const; private: - bool allowTcgen5Mma(Target target) const; - bool allowWgmma(int block_size, Target target) const; - mutable bool completed_ = false; }; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc deleted file mode 100644 index b5346c79b8..0000000000 --- a/src/op/gemm_py.cc +++ /dev/null @@ -1,438 +0,0 @@ -/*! - * \file tl/op/gemm_py.cc - * \brief Implementation of General Matrix Multiplication (GEMM) operators - */ - -#include "gemm_py.h" - -#include "builtin.h" -#include -#include -#include -#include - -#include "../target/utils.h" -#include "tcgen5_meta.h" -#include "utils.h" - -namespace tvm { -namespace tl { - -using namespace tir; - -// NormalizeToBufferRegion moved to src/op/utils.{h,cc} - -// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} - -/** - * @brief Construct a Gemm operator from serialized TL arguments and a buffer - * map. - * - * This constructor deserializes operator parameters from `args` and resolves - * buffer references via `vmap`, populating an internal GemmPyNode with: - * - device pointers for A, B, C and their corresponding Buffer objects, - * - transpose flags for A and B, - * - matrix dimensions M, N, K, - * - warp allocation policy and clear_accum flag, - * - strides and memory offsets for A and B, - * - optional kPack (must be 1 or 2) and optional internal wg_wait. - * - * The populated GemmPyNode is stored into the wrapper's internal `data_`. - * - * @param args Positional serialized arguments produced by the TL frontend: - * expected layout is: - * [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), - * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), - * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), - * (optional) kPack (Int), (optional) internal wg_wait (Int)] - * - * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor - * fails with an ICHECK (runtime assertion). No other validation is - * performed here. - */ -GemmPy::GemmPy(Array args, Map annotations) { - ObjectPtr node = tvm::ffi::make_object(); - - auto a_access = NormalizeToAccessRegion(args[0], kAccessRead); - auto b_access = NormalizeToAccessRegion(args[1], kAccessRead); - auto c_access = NormalizeToAccessRegion(args[2], kAccessReadWrite); - - node->aRegion_ = a_access.region; - node->bRegion_ = b_access.region; - node->cRegion_ = c_access.region; - node->SetAccessRegions({a_access, b_access, c_access}); - - node->a_ = node->aRegion_->buffer; - node->b_ = node->bRegion_->buffer; - node->c_ = node->cRegion_->buffer; - node->transA_ = args[3].as().value(); - node->transB_ = args[4].as().value(); - node->m_ = args[5].as().value()->value; - node->n_ = args[6].as().value()->value; - node->k_ = args[7].as().value()->value; - node->policy_ = GemmWarpPolicy(args[8].as().value()->value); - node->clearAccum_ = args[9].as().value(); - node->strideA_ = args[10].as().value()->value; - node->strideB_ = args[11].as().value()->value; - node->offsetA_ = args[12].as().value()->value; - node->offsetB_ = args[13].as().value()->value; - if (args.size() > 14) { - node->kPack_ = args[14].as().value()->value; - if (node->kPack_ != 1 && node->kPack_ != 2) { - ICHECK(false) << "kPack must be 1 or 2"; - } - } - if (args.size() > 15) { - node->wgWait_ = args[15].as().value()->value; - } - if (auto val = annotations.Get("is_wgmma")) { - const auto *int_val = val->as(); - ICHECK(int_val) << "is_wgmma annotation must be IntImmNode"; - node->isWgmma_ = int_val->value != 0; - } - if (auto val = annotations.Get("is_tcgen05")) { - const auto *int_val = val->as(); - ICHECK(int_val) << "is_tcgen05 annotation must be IntImmNode"; - node->isTcgen05_ = int_val->value != 0; - } - if (args.size() > 16 && args[16]->IsInstance()) { - node->mbar_ = Downcast(args[16]); - } - node->cCoords_ = Array( - {args[17].as().value(), args[18].as().value()}); - node->annotations_ = annotations; - data_ = std::move(node); -} - -AccessRegions GemmPyNode::GetAccessRegions() const { - AccessRegions result; - result.reads.push_back(aRegion_); - result.reads.push_back(bRegion_); - if (!is_one(clearAccum_)) { - result.reads.push_back(cRegion_); - } - result.writes.push_back(cRegion_); - return result; -} - -/** - * @brief Create a copy of this GemmPyNode as a TileOperator. - * - * Constructs a new GemmPyNode by copying the current node state and returns it - * wrapped in a Gemm TileOperator. - * - * @return TileOperator A Gemm operator that owns a copy of this node. - */ -TileOperator GemmPyNode::Clone() const { - auto op = tvm::ffi::make_object(*this); - return GemmPy(op); -} - -bool GemmPyNode::allowTcgen5Mma(Target target) const { - bool scope_ok = (IsSharedBuffer(a_) || a_.scope() == "shared.tmem") && - IsSharedBuffer(b_) && c_.scope() == "shared.tmem"; - if (!TargetIsSm100(target) || !scope_ok) - return false; - // For TS variant (A from TMEM), use B's dtype as the input dtype - DataType ab_dtype = (a_.scope() == "shared.tmem") ? b_->dtype : a_->dtype; - return GetTCGEN5MMAMeta(m_, n_, k_, ab_dtype, c_->dtype).first; -} - -bool GemmPyNode::allowWgmma(int block_size, Target target) const { - tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); - - int warp_size = TargetGetWarpSize(target); - int num_warps = block_size / warp_size; - return !ctxt->GetConfig(kDisableWGMMA, Optional()).value_or(false) && - TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) && - checkWgmma(); -} - -GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const { - if (isWgmma_) { - if (!allowWgmma(block_size, target)) { - LOG(FATAL) << "T.wgmma_gemm() requires Hopper WGMMA lowering, but " - "constraints were not satisfied. Got target=" - << target << ", A(scope=" << a_.scope() - << ", dtype=" << a_->dtype << "), B(scope=" << b_.scope() - << ", dtype=" << b_->dtype << "), C(scope=" << c_.scope() - << ", dtype=" << c_->dtype << "), M=" << m_ << ", N=" << n_ - << ", K=" << k_ << "."; - } - return GemmInst::kWGMMA; - } - if (isTcgen05_) { - if (!allowTcgen5Mma(target)) { - LOG(FATAL) << "T.tcgen05_gemm() requires Blackwell TCGEN5MMA lowering, " - "but constraints were not satisfied. Got target=" - << target << ", A(scope=" << a_.scope() - << ", dtype=" << a_->dtype << "), B(scope=" << b_.scope() - << ", dtype=" << b_->dtype << "), C(scope=" << c_.scope() - << ", dtype=" << c_->dtype << "), M=" << m_ << ", N=" << n_ - << ", K=" << k_ << "."; - } - return GemmInst::kTCGEN5MMA; - } - bool allow_tcgen5mma = allowTcgen5Mma(target); - bool allow_wgmma = allowWgmma(block_size, target); - if (allow_tcgen5mma) { - return GemmInst::kTCGEN5MMA; - } else if (allow_wgmma) { - return GemmInst::kWGMMA; - } else if (TargetIsCDNA(target)) { - return GemmInst::kMFMA; - } else if (TargetIsRDNA(target)) { - return GemmInst::kWMMA; - } else if (TargetIsCuda(target)) { - return GemmInst::kMMA; - } else if (TargetIsCPU(target)) { - return GemmInst::kScalar; - } else { - ICHECK(0) << "Unsupported target for gemm: " << target->str(); - return GemmInst::kMMA; // This line will never be reached due to ICHECK, but - // satisfies compiler - } -} - -/** - * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM. - * - * Evaluates device-memory placement, data-type combinations, transpose flags, - * and K divisibility constraints required for the Hopper WGMMA code path. - * - * The check returns true only when: - * - B resides in shared memory ("shared" or "shared.dyn"); and - * - (C, A, B) dtypes match one of the supported combinations below and K - * satisfies the required alignment; and - * - for combinations that require specific orientations, A is not transposed - * and B is transposed. - * - * Supported combinations and constraints: - * - C=float16: - * - A=float16, B=float16: K % 16 == 0 - * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % - * 32 == 0 - * - C=float32: - * - A=float16, B=float16: K % 16 == 0 - * - A=bfloat16, B=bfloat16: K % 16 == 0 - * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 - * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 - * - C=int32: - * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) - * and K % 32 == 0 - * - * @return true if WGMMA is supported for the current buffers, dtypes, and - * transpose/shape constraints; false otherwise. - */ -bool GemmPyNode::checkWgmma() const { - if (b_.scope() != "shared.dyn" && b_.scope() != "shared") { - return false; - } - - if (c_->dtype == DataType::Float(16)) { - if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) - return k_ % 16 == 0; - else if (a_->dtype.is_float8() && b_->dtype.is_float8()) - return (!transA_) && transB_ && k_ % 32 == 0; - else - return false; - } else if (c_->dtype == DataType::Float(32)) { - if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) - return k_ % 16 == 0; - else if (a_->dtype == DataType::BFloat(16) && - b_->dtype == DataType::BFloat(16)) - return k_ % 16 == 0; - else if (a_->dtype == DataType::Float(32) && - b_->dtype == DataType::Float(32)) - return (!transA_) && transB_ && k_ % 8 == 0; - else if (a_->dtype.is_float8() && b_->dtype.is_float8()) - return (!transA_) && transB_ && k_ % 32 == 0; - else - return false; - } else if (c_->dtype == DataType::Int(32)) { - if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8)) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8)) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8)) - return (!transA_) && transB_ && k_ % 32 == 0; - else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8)) - return (!transA_) && transB_ && k_ % 32 == 0; - else - return false; - } else { - return false; - } -} - -/** - * @brief Parse and return the numeric GPU architecture from a Target's "arch" - * attribute. - * - * Examines the target's "arch" string and, if it matches the pattern - * "sm_", returns as an int. If the attribute is present but does not - * match that pattern, returns 0. - * - * Preconditions: the target must have an "arch" attribute (this is checked via - * ICHECK). - * - * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if - * the arch string does not match "sm_". - */ -static int GetArchInt(Target target) { - int arch_int = 0; - auto s = target->GetAttr("arch"); - ICHECK(s.has_value()); - std::string arch = s.value(); - if (arch.rfind("sm_", 0) == 0) { - arch_int = std::stoi(arch.substr(3)); - } else { - arch_int = 0; - } - return arch_int; -} - -Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { - if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { - PrimExpr mbar_phase = T.mbar_phase_expr; - if (auto explicit_phase = GetAnnotatedMbarPhaseExpr(annotations_)) { - mbar_phase = explicit_phase.value(); - } - // NOTE(wt): Decide GemmInst and compute warp partition on Python side - auto prim_func = Downcast( - (*f)(tvm::ffi::GetRef(this), T.layout_map, T.target, - T.thread_bounds, T.thread_var, mbar_phase)); - ICHECK(prim_func->attrs.defined()); - auto global_symbol = - prim_func->attrs.GetAttr("global_symbol"); - ICHECK(global_symbol.has_value()); - if (prim_func->body.as()) { - BlockRealize block_realize = Downcast(prim_func->body); - auto block = block_realize->block; - { - BlockNode *n = block.CopyOnWrite(); - n->name_hint = global_symbol.value(); - n->annotations.Set(tl::attr::kLexicalAllocScope, - IntImm(DataType::Int(32), 1)); - } - return BlockRealize(block_realize->iter_values, block_realize->predicate, - block); - } - // warp with block realize node - Map block_annotations; - block_annotations.Set(tl::attr::kLexicalAllocScope, - IntImm(DataType::Int(32), 1)); - return BlockRealize( - /*iter_values=*/Array(), - /*predicate=*/const_true(), - /*block=*/ - Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/global_symbol.value(), prim_func->body, - /*init=*/Optional(), /*alloc_buffers=*/{}, - /*match_buffers=*/{}, /*annotations=*/block_annotations)); - } else { - LOG(FATAL) << "No lower function found for gemm_py"; - return Stmt(); // This line will never be reached due to LOG(FATAL), but - // satisfies compiler - } -} - -LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { - if (completed_) - return {}; - LayoutMap results; - - if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) { - results = Downcast( - (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds)); - // Bind all fragment layouts with the provided thread range - for (auto kv : results) { - const Buffer &buf = kv.first; - const Layout &layout = kv.second; - if (auto frag = layout.as()) { - results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds)); - } - } - } else { - LOG(FATAL) << "No infer layout function found for gemm_py"; - } - - completed_ = true; - return results; -} - -TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py) - .set_num_inputs(5) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - -TVM_REGISTER_OP("tl.tileop.wgmma_gemm_py") - .set_attr("TScriptPrinterName", "wgmma_gemm_py") - .set_attr("TLOpBuilder", - [](Array args, - Map annotations) { - Map ann = annotations; - ann.Set("is_wgmma", - IntImm(DataType::Int(32), 1)); - return GemmPy(args, ann); - }) - .set_num_inputs(5) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - -TVM_REGISTER_OP("tl.tileop.tcgen05_gemm_py") - .set_attr("TScriptPrinterName", "tcgen05_gemm_py") - .set_attr("TLOpBuilder", - [](Array args, - Map annotations) { - Map ann = annotations; - ann.Set("is_tcgen05", - IntImm(DataType::Int(32), 1)); - return GemmPy(args, ann); - }) - .set_num_inputs(5) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - -TVM_FFI_STATIC_INIT_BLOCK() { GemmPyNode::RegisterReflection(); } - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tl.GemmPyGemmInst", - [](GemmPy gemm_py, int block_size, Target target) { - return gemm_py->getGemmInst(block_size, target); - }); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tl.get_tcgen5_mma_meta", [](int M, int N, int K, DataType ab_dtype, - DataType c_dtype, bool disable_2cta) { - auto [success, meta] = - GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype, disable_2cta); - Array result; - if (success) { - result.push_back(Integer(meta.atom_m)); - result.push_back(Integer(meta.atom_n)); - result.push_back(Integer(meta.atom_k)); - result.push_back(Integer(meta.enable_ws)); - result.push_back(Integer(meta.enable_2cta)); - } - return result; - }); - refl::GlobalDef().def( - "tl.get_tcgen5_instr_desc", - [](int atom_m, int atom_n, int atom_k, DataType ab_dtype, - DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a, - int scale_in_b) { - uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype, - c_dtype, a_is_k_major, b_is_k_major, - scale_in_a, scale_in_b); - return Integer(static_cast(desc)); - }); -} - -} // namespace tl -} // namespace tvm diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h deleted file mode 100644 index d8dcfa7554..0000000000 --- a/src/op/gemm_py.h +++ /dev/null @@ -1,99 +0,0 @@ -/*! - * \file tl/op/gemm_py.h - * \brief Define gemm operator. - * - */ - -#ifndef TVM_TL_OP_GEMM_PY_H_ -#define TVM_TL_OP_GEMM_PY_H_ - -#include "gemm.h" -#include "operator.h" - -namespace tvm { - -namespace tl { - -using namespace tir; - -class GemmPyNode : public TileOperatorNode { -public: - bool checkWgmma() const; - bool allowTcgen5Mma(Target target) const; - bool allowWgmma(int block_size, Target target) const; - tir::Buffer a_, b_, c_; - // BufferRegion for A, B and C - BufferRegion aRegion_, bRegion_, cRegion_; - bool transA_, transB_; - int m_, n_, k_; - int strideA_, strideB_; - int offsetA_, offsetB_; - PrimExpr clearAccum_ = const_false(); - tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA - Array cCoords_; - // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack - // only will be enabled under cdna mfma instructions - int kPack_ = 1; - int wgWait_ = 0; - bool isWgmma_ = false; - bool isTcgen05_ = false; - mutable GemmWarpPolicy policy_; - Map annotations_; - - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode); - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("a", &GemmPyNode::a_) - .def_ro("b", &GemmPyNode::b_) - .def_ro("c", &GemmPyNode::c_) - .def_ro("aRegion", &GemmPyNode::aRegion_) - .def_ro("bRegion", &GemmPyNode::bRegion_) - .def_ro("cRegion", &GemmPyNode::cRegion_) - .def_ro("transA", &GemmPyNode::transA_) - .def_ro("transB", &GemmPyNode::transB_) - .def_ro("m", &GemmPyNode::m_) - .def_ro("n", &GemmPyNode::n_) - .def_ro("k", &GemmPyNode::k_) - .def_ro("strideA", &GemmPyNode::strideA_) - .def_ro("strideB", &GemmPyNode::strideB_) - .def_ro("offsetA", &GemmPyNode::offsetA_) - .def_ro("offsetB", &GemmPyNode::offsetB_) - .def_ro("clearAccum", &GemmPyNode::clearAccum_) - .def_ro("mbar", &GemmPyNode::mbar_) - .def_ro("cCoords", &GemmPyNode::cCoords_) - .def_ro("kPack", &GemmPyNode::kPack_) - .def_ro("wgWait", &GemmPyNode::wgWait_) - .def_ro("isWgmma", &GemmPyNode::isWgmma_) - .def_ro("isTcgen05", &GemmPyNode::isTcgen05_) - .def_ro("policy", &GemmPyNode::policy_) - .def_ro("annotations", &GemmPyNode::annotations_); - } - - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - LayoutMap InferLayout(const LayoutInferArgs &T, - InferLevel level) const override; - AccessRegions GetAccessRegions() const override; - - TileOperator Clone() const; - - // Target GEMM instruction - GemmInst getGemmInst(int block_size, Target target) const; - -private: - mutable bool completed_ = false; -}; - -class GemmPy : public TileOperator { -public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode); - TVM_DLL GemmPy(Array args, - Map annotations = Map()); - static const Op &Get(); -}; - -} // namespace tl -} // namespace tvm - -#endif // TVM_TL_OP_GEMM_PY_H_ diff --git a/src/op/gemm_sp_py.h b/src/op/gemm_sp_py.h index 7ced4d5663..fecd14438a 100644 --- a/src/op/gemm_sp_py.h +++ b/src/op/gemm_sp_py.h @@ -4,7 +4,7 @@ * */ -// TODO: @botbw: remove redundant code with gemm_py.h +// TODO: @botbw: remove redundant code with gemm.h #ifndef TVM_TL_OP_GEMM_SP_PY_H_ #define TVM_TL_OP_GEMM_SP_PY_H_ diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index d2731e85e5..d0e832f94e 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -19,7 +19,6 @@ #include "../op/builtin.h" #include "../op/copy.h" #include "../op/gemm.h" -#include "../op/gemm_py.h" #include "../op/operator.h" #include "../op/region.h" #include "../op/utils.h" @@ -404,8 +403,7 @@ class TileOpMbarPhaseAnnotator : public StmtExprMutator { auto tile_op = ParseOperator(call); return tile_op.defined() && (tile_op.as() != nullptr || tile_op.as() != nullptr || - tile_op.as() != nullptr || - tile_op.as() != nullptr); + tile_op.as() != nullptr); } PrimExpr phase_expr_; diff --git a/src/transform/lower_blackwell_2sm.cc b/src/transform/lower_blackwell_2sm.cc index df99022107..21ef2277be 100644 --- a/src/transform/lower_blackwell_2sm.cc +++ b/src/transform/lower_blackwell_2sm.cc @@ -3,11 +3,8 @@ * \brief Lower 2SM TCGEN5MMA and related on Blackwell target * * This pass runs before LowerTileOp. At that point the IR still has T.gemm - * (tl_gemm / tl.tileop.gemm_py Call), not the lowered tl::tcgen5mma_gemm_ss/ts. - * We detect Gemm ops that will be lowered to TCGEN5MMA with use_2cta and set - * block attr. - * - * Tilelang gemm defaults to v2 (GemmPyNode); we only support v2, not v1 (Gemm). + * (tl.tileop.gemm Call), not the lowered tl::tcgen5mma_gemm_ss/ts. We detect + * Gemm ops that will be lowered to TCGEN5MMA with use_2cta and set block attr. */ // todo: consider mixture of 1cta/2cta tcgen5mma in the same kernel @@ -20,7 +17,7 @@ #include #include -#include "../op/gemm_py.h" +#include "../op/gemm.h" #include "../op/operator.h" #include "../target/utils.h" @@ -62,9 +59,8 @@ static bool HasValidClusterDimsFor2Cta(const Stmt &body) { /** * \brief Detect 2SM TCGEN5MMA in the kernel (before LowerTileOp). - * Looks for T.gemm (tl_gemm() Call); if it will be lowered to TCGEN5MMA with - * use_2cta, sets the flag for the mutator to add block attr. - * Only supports v2 (GemmPy); v1 (Gemm) is ignored. + * Looks for T.gemm (tl.tileop.gemm Call); if it will be lowered to TCGEN5MMA + * with use_2cta, sets the flag for the mutator to add block attr. */ class Tcgen5_2SmLower : public StmtExprMutator { public: @@ -76,7 +72,7 @@ class Tcgen5_2SmLower : public StmtExprMutator { Stmt VisitStmt_(const EvaluateNode *op) final { if (const CallNode *call = op->value.as()) { TileOperator tile_op = ParseOperator(ffi::GetRef(op)); - if (tile_op.defined() && tile_op.as()) { + if (tile_op.defined() && tile_op.as()) { // Check if the user explicitly requested 2CTA via the use_2cta // annotation on the Call node (set by T.tcgen05_gemm(use_2cta=True)). if (call->annotations.count(attr::kUse2Cta)) { diff --git a/src/transform/lower_opaque_block.cc b/src/transform/lower_opaque_block.cc index c019819bff..34097d015b 100644 --- a/src/transform/lower_opaque_block.cc +++ b/src/transform/lower_opaque_block.cc @@ -106,7 +106,7 @@ class OpaqueBlockLower : public StmtExprMutator { } // Step 5. Materialize a lexical scope boundary only for blocks that were // explicitly marked by an earlier semantic lowering pass (for example - // gemm_py/gemm_sp_py). We intentionally avoid re-inferring this from the + // gemm/gemm_sp_py). We intentionally avoid re-inferring this from the // lowered alloc_buffers here because provenance has already been blurred by // earlier allocation planning/hoisting passes. if (HasLexicalAllocScopeAnnotation(new_block->annotations)) { diff --git a/src/transform/producer_consumer_ws.cc b/src/transform/producer_consumer_ws.cc index e9080f0b0a..31c02d3ae5 100644 --- a/src/transform/producer_consumer_ws.cc +++ b/src/transform/producer_consumer_ws.cc @@ -31,7 +31,6 @@ #include "../op/copy.h" #include "../op/fill.h" #include "../op/gemm.h" -#include "../op/gemm_py.h" #include "../op/operator.h" #include "../op/region.h" #include "../op/utils.h" @@ -839,8 +838,7 @@ class TileOpMbarPhaseAnnotator : public StmtExprMutator { auto tile_op = ParseOperator(call); return tile_op.defined() && (tile_op.as() != nullptr || tile_op.as() != nullptr || - tile_op.as() != nullptr || - tile_op.as() != nullptr); + tile_op.as() != nullptr); } PrimExpr phase_expr_; diff --git a/testing/python/language/test_tilelang_language_if_range.py b/testing/python/language/test_tilelang_language_if_range.py index c81a241ba1..8da3f8b498 100644 --- a/testing/python/language/test_tilelang_language_if_range.py +++ b/testing/python/language/test_tilelang_language_if_range.py @@ -45,6 +45,7 @@ def run_tilelang_if_range(M=128, N=128, block_M=32, block_N=32, dtype=T.float16) torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2) +@tilelang.testing.requires_cuda def test_tilelang_if_range(): run_tilelang_if_range(M=128, N=128, block_M=32, block_N=32) diff --git a/testing/python/language/test_tilelang_language_min_blocks_per_sm.py b/testing/python/language/test_tilelang_language_min_blocks_per_sm.py index b7a560f462..4b98a44c13 100644 --- a/testing/python/language/test_tilelang_language_min_blocks_per_sm.py +++ b/testing/python/language/test_tilelang_language_min_blocks_per_sm.py @@ -21,6 +21,7 @@ def main( return main +@tilelang.testing.requires_cuda def test_annotate_min_blocks_per_sm_launch_bounds(): """Codegen should emit the second __launch_bounds__ argument from the annotation.""" src = _kernel_min_blocks_per_sm.get_kernel_source() diff --git a/testing/python/language/test_tilelang_language_transpose.py b/testing/python/language/test_tilelang_language_transpose.py index 2bf1633525..e01f04935a 100644 --- a/testing/python/language/test_tilelang_language_transpose.py +++ b/testing/python/language/test_tilelang_language_transpose.py @@ -94,12 +94,14 @@ def run_tilelang_transpose_square(M=256, block_M=128, dtype=T.float16): print(f"PASS: square transpose M={M}, block_M={block_M}") +@tilelang.testing.requires_cuda def test_tilelang_transpose(): run_tilelang_transpose(M=128, N=128, block_M=128, block_N=128) run_tilelang_transpose(M=256, N=256, block_M=128, block_N=128) run_tilelang_transpose(M=128, N=256, block_M=128, block_N=256) +@tilelang.testing.requires_cuda def test_tilelang_transpose_square(): run_tilelang_transpose_square(M=128, block_M=128) run_tilelang_transpose_square(M=256, block_M=128) @@ -107,5 +109,4 @@ def test_tilelang_transpose_square(): if __name__ == "__main__": - test_tilelang_transpose() - test_tilelang_transpose_square() + tilelang.testing.main() diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py deleted file mode 100644 index e2d7a5ee3e..0000000000 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ /dev/null @@ -1,624 +0,0 @@ -import tilelang.language as T -from tilelang import tvm as tvm -import tilelang.testing -from tilelang.utils import determine_fp8_type -import pytest - - -def matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B) - # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_ss( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, - ) - - profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - - def ref_program(A, B): - import torch - - if trans_A: - A = A.T - if trans_B: - B = B.T - C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(out_dtype)) - return C - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - - -@pytest.mark.skip(reason="Temporarily disabling until GEMM SS issues are resolved") -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 128, 32, 2, 128), - (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 128, 32, 2, 128), - (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 128, 32, 2, 128), - (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 128, 32, 2, 128), - (128, 16, 32, False, True, T.float16, T.float16, T.float32, 128, 16, 32, 0, 128), - (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -@pytest.mark.skip(reason="Temporarily disabling until GEMM SS issues are resolved") -@tilelang.testing.requires_cuda -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_ss_fp8_cuda(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -@pytest.mark.skip(reason="Temporarily disabling until GEMM SS issues are resolved") -@tilelang.testing.requires_rocm -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (128, 128, 128, True, True, determine_fp8_type("e5m2"), determine_fp8_type("e5m2"), T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, determine_fp8_type(), determine_fp8_type(), T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_ss_fp8_rocm(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -def matmul_rs( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - A_frag_shape = A_shared_shape - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - A_frag = T.alloc_fragment(A_frag_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - T.annotate_layout( - { - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - } - ) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.copy(A_shared, A_frag) - T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_rs( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmul_rs( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, - ) - profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - - def ref_program(A, B): - import torch - - if trans_A: - A = A.T - if trans_B: - B = B.T - C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(out_dtype)) - return C - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - - -@pytest.mark.skip(reason="Temporarily disabling until GEMM RS issues are resolved") -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (128, 16, 32, False, True, T.float16, T.float16, T.float32, 128, 16, 32, 0, 128), - (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -@pytest.mark.skip(reason="Temporarily disabling until GEMM RS issues are resolved") -@tilelang.testing.requires_cuda -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_rs_fp8_cuda(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -@pytest.mark.skip(reason="Temporarily disabling until GEMM RS issues are resolved") -@tilelang.testing.requires_rocm -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (128, 128, 128, True, True, T.float8_e5m2fnuz, T.float8_e5m2fnuz, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.float8_e4m3fnuz, T.float8_e4m3fnuz, T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_rs_fp8_rocm(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -def matmul_sr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - B_frag_shape = B_shared_shape - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - B_frag = T.alloc_fragment(B_frag_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - T.annotate_layout( - { - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - } - ) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.copy(B_shared, B_frag) - T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_sr( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmul_sr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, - ) - profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - - def ref_program(A, B): - import torch - - if trans_A: - A = A.T - if trans_B: - B = B.T - C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(out_dtype)) - return C - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - - -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (128, 16, 32, False, True, T.float16, T.float16, T.float32, 128, 16, 32, 0, 128), - (128, 128, 32, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 32, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 32, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 32, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -@tilelang.testing.requires_cuda -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_sr_fp8_cuda(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -@tilelang.testing.requires_rocm -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - # TODO: There is precision problem needs to repair - # (128, 128, 128, True, True, determine_fp8_type("e5m2"), determine_fp8_type("e5m2"), T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, determine_fp8_type(), determine_fp8_type(), T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_sr_fp8_rocm(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -def matmul_rr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - A_frag_shape = A_shared_shape - B_frag_shape = B_shared_shape - - import tilelang.language as T - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - A_frag = T.alloc_fragment(A_frag_shape, in_dtype) - B_frag = T.alloc_fragment(B_frag_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - T.annotate_layout( - { - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - } - ) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.copy(A_shared, A_frag) - T.copy(B_shared, B_frag) - T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_rr( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmul_rr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, - ) - profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) - - def ref_program(A, B): - import torch - - if trans_A: - A = A.T - if trans_B: - B = B.T - C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(out_dtype)) - return C - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - - -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float, 128, 256, 32, 2, 128), - # TODO: There is precision problem when num_stages=2 on ROCm - # (128, 16, 128, False, True, T.float16, T.float16, T.float32, 128, 16, 32, 2, 128) - # (128, 16, 128, False, True, T.int8, T.int8, T.int32, 128, 16, 32, 2, 128), - (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), - (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -@tilelang.testing.requires_cuda -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_rr_fp8_cuda(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -@tilelang.testing.requires_rocm -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - # TODO: There is precision problem needs to repair - # (128, 128, 128, True, True, T.float8_e5m2fnuz, T.float8_e5m2fnuz, T.float32, 128, 128, 32, 2, 128), - (128, 128, 128, True, True, determine_fp8_type(), determine_fp8_type(), T.float32, 128, 128, 32, 2, 128), - ], -) -def test_gemm_rr_fp8_rocm(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/tilelang/env.py b/tilelang/env.py index 3d316b8810..9a9e3e5e1c 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -272,10 +272,6 @@ class Environment: "TILELANG_CLEANUP_TEMP_FILES", "0" ) # cleanup temporary compiler files/dirs after compilation (default: keep for debugging) - # Kernel selection options - # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1 - TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "0") - # Auto-tuning settings TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0") TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9") # percent of CPUs used @@ -327,14 +323,6 @@ def is_print_on_compilation_enabled(self) -> bool: def should_cleanup_temp_files(self) -> bool: return str(self.TILELANG_CLEANUP_TEMP_FILES).lower() in ("1", "true", "yes", "on") - def use_gemm_v1(self) -> bool: - """Return True if GEMM v1 should be used based on env. - - Controlled by `TILELANG_USE_GEMM_V1`. Truthy values are one of - {"1", "true", "yes", "on"} (case-insensitive). - """ - return str(self.TILELANG_USE_GEMM_V1).lower() in ("1", "true", "yes", "on") - def get_default_target(self) -> str: """Get default compilation target from environment.""" return self.TILELANG_DEFAULT_TARGET diff --git a/tilelang/ir.py b/tilelang/ir.py index 5afe7d04c7..b4ae5fe7e4 100644 --- a/tilelang/ir.py +++ b/tilelang/ir.py @@ -45,10 +45,6 @@ def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target return self.m_warp, self.n_warp -@tvm_ffi.register_object("tl.Gemm") -class Gemm(Node, Scriptable): ... - - @tvm_ffi.register_object("tl.GemmSP") class GemmSP(Node, Scriptable): ... diff --git a/tilelang/jit/execution_backend.py b/tilelang/jit/execution_backend.py index 9855516424..7481dde7aa 100644 --- a/tilelang/jit/execution_backend.py +++ b/tilelang/jit/execution_backend.py @@ -80,7 +80,7 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str: if is_cutedsl_target(target): return "cutedsl" kind = _target_kind(target) - if kind == "cuda" or kind == "metal": + if kind == "cuda" or kind == "metal" or kind == "hip": choice = "tvm_ffi" else: choice = "cython" diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 8ce71e5f31..74ad185c41 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -56,7 +56,7 @@ from tvm.script.parser.tir import allocate as allocate # noqa: F401 from .copy_op import copy, async_copy, tma_copy, transpose, c2d_im2col # noqa: F401 from tilelang.tileop.base import GemmWarpPolicy # noqa: F401 -from .gemm_op import gemm, gemm_v1, gemm_v2, wgmma_gemm, tcgen05_gemm # noqa: F401 +from .gemm_op import gemm, wgmma_gemm, tcgen05_gemm # noqa: F401 from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401 from .fill_op import fill, clear # noqa: F401 from .reduce_op import ( diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py index 843243e998..ed26d2e677 100644 --- a/tilelang/language/gemm_op.py +++ b/tilelang/language/gemm_op.py @@ -16,7 +16,6 @@ from tilelang.language.utils import ( buffer_region_to_tile_region, ) -from tilelang.env import env as _env def _gemm_impl( @@ -146,62 +145,6 @@ def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType: ) -# Public wrappers -def gemm_v1( - A: BufferLikeType, - B: BufferLikeType, - C: BufferLikeType, - transpose_A: bool = False, - transpose_B: bool = False, - policy: GemmWarpPolicy = GemmWarpPolicy.Square, - clear_accum: bool = False, - k_pack: int = 1, - mbar: BarrierType | None = None, -) -> tir.PrimExpr: - """Synchronous GEMM v1: use op tl.gemm.""" - return _gemm_impl( - "tl.tileop.gemm", - A, - B, - C, - transpose_A, - transpose_B, - policy, - clear_accum, - k_pack, - 0, - mbar, - ) - - -# experimental currently, for fast compilation -def gemm_v2( - A: BufferLikeType, - B: BufferLikeType, - C: BufferLikeType, - transpose_A: bool = False, - transpose_B: bool = False, - policy: GemmWarpPolicy = GemmWarpPolicy.Square, - clear_accum: bool = False, - k_pack: int = 1, - mbar: BarrierType | None = None, -) -> tir.PrimExpr: - """Synchronous GEMM v2: use op tl.gemm_py.""" - return _gemm_impl( - "tl.tileop.gemm_py", - A, - B, - C, - transpose_A, - transpose_B, - policy, - clear_accum, - k_pack, - 0, - mbar, - ) - - def gemm( A: BufferLikeType, B: BufferLikeType, @@ -239,8 +182,19 @@ def gemm( Returns: tir.Call: A handle to the GEMM operation. """ - impl = gemm_v1 if _env.use_gemm_v1() else gemm_v2 - return impl(A, B, C, transpose_A, transpose_B, policy, clear_accum, k_pack, mbar) + return _gemm_impl( + "tl.tileop.gemm", + A, + B, + C, + transpose_A, + transpose_B, + policy, + clear_accum, + k_pack, + 0, + mbar, + ) def wgmma_gemm( @@ -264,7 +218,7 @@ def wgmma_gemm( """ return _gemm_impl( - "tl.tileop.wgmma_gemm_py", + "tl.tileop.wgmma_gemm", A, B, C, @@ -306,7 +260,7 @@ def tcgen05_gemm( ann = {"use_2cta": int(use_2cta)} if use_2cta else None return _gemm_impl( - "tl.tileop.tcgen05_gemm_py", + "tl.tileop.tcgen05_gemm", A, B, C, diff --git a/tilelang/tileop/__init__.py b/tilelang/tileop/__init__.py index 6e7798a051..849ba541e1 100644 --- a/tilelang/tileop/__init__.py +++ b/tilelang/tileop/__init__.py @@ -1,3 +1,3 @@ from .base import GemmWarpPolicy # noqa: F401 -from .gemm import GemmPy # noqa: F401 +from .gemm import Gemm # noqa: F401 from .gemm_sp import GemmSPPy # noqa: F401 diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 70dc84270e..a24545103c 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -17,15 +17,15 @@ from tilelang.utils.target import target_is_volta -@tvm_ffi.register_global_func("tl.gemm_py.infer_layout") -def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range): +@tvm_ffi.register_global_func("tl.gemm.infer_layout") +def gemm_infer_layout(gemm: GemmMMA, target: Target, thread_bounds: Range): thread_nums = thread_bounds.extent - return gemm_py.infer_layout(target, thread_nums) + return gemm.infer_layout(target, thread_nums) -@tvm_ffi.register_global_func("tl.gemm_py.lower") -def gemm_py_lower( - gemm_py: GemmMMA, +@tvm_ffi.register_global_func("tl.gemm.lower") +def gemm_lower( + gemm: GemmMMA, layout_map, target: Target, thread_bounds: Range, @@ -33,12 +33,12 @@ def gemm_py_lower( mbar_phase_expr: tir.PrimExpr, ): # We pass thread_bounds rather than thread_extents because tcgen5mma need to check this - stmt = gemm_py.lower(layout_map, target, thread_bounds, thread_var, mbar_phase_expr) + stmt = gemm.lower(layout_map, target, thread_bounds, thread_var, mbar_phase_expr) return stmt -@tvm_ffi.register_object("tl.GemmPy") -class GemmPy(Node, Scriptable): +@tvm_ffi.register_object("tl.Gemm") +class Gemm(Node, Scriptable): # FFI fields (LLVM/MLIR-style lowerCamel via reflection): # a, b, c, aPtr, bPtr, cPtr, m, n, k, transA, transB, # strideA, strideB, offsetA, offsetB, clearAccum, kPack, wgWait, policy @@ -159,7 +159,7 @@ def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst Returns: GemmInst: The selected GEMM instruction type """ - return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target)) + return GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target)) def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): """Get the appropriate implementation class for the given GEMM instruction. From 35d81393ad1fe17fd10c282adde5d4e4a000ba4e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 13 Apr 2026 12:43:01 +0800 Subject: [PATCH 043/156] [CI]: Bump actions/github-script from 8 to 9 (#2036) Bumps [actions/github-script](https://github.com/actions/github-script) from 8 to 9. - [Release notes](https://github.com/actions/github-script/releases) - [Commits](https://github.com/actions/github-script/compare/v8...v9) --- updated-dependencies: - dependency-name: actions/github-script dependency-version: '9' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/pr-regression-test-bot.yml | 4 ++-- .github/workflows/pr-reminder-bot.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pr-regression-test-bot.yml b/.github/workflows/pr-regression-test-bot.yml index 4e060c3766..4e2c89f66d 100644 --- a/.github/workflows/pr-regression-test-bot.yml +++ b/.github/workflows/pr-regression-test-bot.yml @@ -41,7 +41,7 @@ jobs: steps: - name: Get commenter permission id: perm - uses: actions/github-script@v8 + uses: actions/github-script@v9 with: script: | const username = context.payload.comment.user.login @@ -241,7 +241,7 @@ jobs: path: regression_result.png - name: Post test results as PR comment - uses: actions/github-script@v8 + uses: actions/github-script@v9 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/pr-reminder-bot.yml b/.github/workflows/pr-reminder-bot.yml index 67e12936c6..38fa5e460c 100644 --- a/.github/workflows/pr-reminder-bot.yml +++ b/.github/workflows/pr-reminder-bot.yml @@ -11,7 +11,7 @@ jobs: if: github.repository_owner == 'tile-ai' steps: - name: Remind - uses: actions/github-script@v8 + uses: actions/github-script@v9 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | From 8243f7e08922641a07258d82bc17bf78df9b8202 Mon Sep 17 00:00:00 2001 From: haoran35-jpg Date: Mon, 13 Apr 2026 03:18:24 -0500 Subject: [PATCH 044/156] Nan propagation option for bf16 and half16 (#1958) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * register op, and add to build file * update semantics of NaN for Max and Min ops in Reduce * lower Max op to cutlass fastmax in general case, _hmax for bf16 and _hmax_nan for hf16 * update semantics of NaN for Max and Min ops in Reduce in reduce.h * develope codes transforming ReduceOpNode in TIR and handling CallNode in CUDA code generator's runtime template * test file for reduce maxmin nan * refine test case message and explanation“ ” ‘ “ * invert the senamtics of global config TL_REDUCE_MAXMIN_NAN_PROPAGATE * lint fix * refactor reduce maxmin NaN propagate into per-call annotation Replace the global tl.reduce_maxmin_nan_propagate pass-config with a per-op nan_propagate flag carried on the ReduceOp annotation map. Frontend exposes it as a kwarg on T.reduce_max/min/absmax (default False, preserving prior behavior). The CUDA codegen MinNode/MaxNode visitors are reverted to plain __hmin/__hmax; only the new tl::max_nan /tl::min_nan CallNode handler emits __hmax_nan/__hmin_nan, so non-reduce Min/Max ops are no longer silently rewritten. ReduceOp::Lower now errors early with a clear message if nan_propagate=True is requested for fp16 /bf16 on a non-CUDA target, instead of emitting an undefined symbol on HIP/CPU. Tests rewritten to parametrize on the kwarg and assert runtime NaN-propagation behavior, not just generated source substrings. Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Haoran Co-authored-by: LeiWang1999 Co-authored-by: Claude Opus 4.6 (1M context) --- src/op/builtin.cc | 6 + src/op/builtin.h | 4 + src/op/reduce.cc | 36 +++++- src/op/reduce.h | 13 ++- src/target/codegen_cuda.cc | 23 ++++ src/tl_templates/cuda/reduce.h | 27 +++++ ...est_tilelang_language_reduce_maxmin_nan.py | 107 ++++++++++++++++++ tilelang/language/reduce_op.py | 34 ++++-- 8 files changed, 236 insertions(+), 14 deletions(-) create mode 100644 testing/python/language/test_tilelang_language_reduce_maxmin_nan.py diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 662f945878..c6cbfde130 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -88,6 +88,12 @@ TIR_DEFINE_TL_BUILTIN(__cos).set_num_inputs(1).set_attr( TIR_DEFINE_TL_BUILTIN(__sin).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(max_nan).set_num_inputs(2).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(min_nan).set_num_inputs(2).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + // high precision with IEEE-compliant TIR_DEFINE_TL_BUILTIN(ieee_add).set_num_inputs(3).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/src/op/builtin.h b/src/op/builtin.h index 276800bf6c..4c7d542f77 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -229,6 +229,10 @@ TVM_DLL const Op &__tan(); TVM_DLL const Op &__cos(); // __sin(x) - fast sine TVM_DLL const Op &__sin(); +// max_nan(x, y) - max with CUDA __hmax_nan semantics for fp16/bf16 +TVM_DLL const Op &max_nan(); +// min_nan(x, y) - min with CUDA __hmin_nan semantics for fp16/bf16 +TVM_DLL const Op &min_nan(); // high precision with IEEE-compliant. // ieee_add(x, y, rounding_mode) - IEEE-compliant addition diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 32fc93f74e..dea9dbc822 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -15,6 +15,7 @@ #include "../op/parallel.h" #include "../target/utils.h" #include "../transform/loop_partition.h" +#include "builtin.h" #include "tir/transforms/ir_utils.h" #include "tvm/ir/expr.h" #include "tvm/tir/expr.h" @@ -44,6 +45,15 @@ ReduceOp::ReduceOp(Array args, Map annotations) { node->dim = args[3].as().value()->value; node->type = ReduceType(reduce_type); node->clear = args[4].as().value(); + // Optional annotation: "nan_propagate" — for fp16/bf16 max/min/absmax, + // when true, lower to CUDA __hmax_nan/__hmin_nan so NaNs propagate. + if (auto opt = annotations.Get("nan_propagate")) { + if (auto b = opt.value().as()) { + node->nan_propagate = b.value(); + } else if (auto i = opt.value().as()) { + node->nan_propagate = i.value()->value != 0; + } + } data_ = std::move(node); } @@ -120,15 +130,26 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &acc, if (acc->dtype != rhs->dtype) { rhs = Cast(acc->dtype, rhs); } + const bool use_nan_op = + nan_propagate && (acc.dtype().is_float16() || acc.dtype().is_bfloat16()); if (type->isSum()) { return acc + rhs; } else if (type->isAbsSum()) { return acc + Max(rhs, -rhs); } else if (type->isMax()) { + if (use_nan_op) { + return Call(acc.dtype(), tl::max_nan(), {acc, rhs}); + } return Max(acc, rhs); } else if (type->isMin()) { + if (use_nan_op) { + return Call(acc.dtype(), tl::min_nan(), {acc, rhs}); + } return Min(acc, rhs); } else if (type->isAbsMax()) { + if (use_nan_op) { + return Call(acc.dtype(), tl::max_nan(), {acc, tvm::abs(rhs)}); + } return Max(acc, tvm::abs(rhs)); } else if (type->isBitAnd()) { return acc & rhs; @@ -142,16 +163,18 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &acc, } std::string ReduceOpNode::MakeCodegenReducer() const { + const bool use_nan_op = + nan_propagate && (dst->dtype.is_float16() || dst->dtype.is_bfloat16()); if (type->isSum()) { return "tl::SumOp"; } else if (type->isAbsSum()) { return "tl::SumOp"; } else if (type->isMax()) { - return "tl::MaxOp"; + return use_nan_op ? "tl::MaxOpNan" : "tl::MaxOp"; } else if (type->isMin()) { - return "tl::MinOp"; + return use_nan_op ? "tl::MinOpNan" : "tl::MinOp"; } else if (type->isAbsMax()) { - return "tl::MaxOp"; + return use_nan_op ? "tl::MaxOpNan" : "tl::MaxOp"; } else if (type->isBitAnd()) { return "tl::BitAndOp"; } else if (type->isBitOr()) { @@ -235,6 +258,13 @@ static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) { * @return Stmt Lowered TIR statement implementing the reduction. */ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + if (nan_propagate && (dst->dtype.is_float16() || dst->dtype.is_bfloat16()) && + !TargetIsCuda(T.target)) { + LOG(FATAL) << "ReduceOp: nan_propagate=True for fp16/bf16 max/min/absmax " + "is only supported on CUDA targets (requires " + "__hmax_nan/__hmin_nan intrinsics). Target was: " + << T.target->str(); + } auto get_buffer = [&](const Buffer &buf) { if (T.buffer_remap.count(buf)) return T.buffer_remap[buf]; diff --git a/src/op/reduce.h b/src/op/reduce.h index 7c9db0c431..5ddb3c64d1 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -84,9 +84,13 @@ class ReduceOpNode : public TileOperatorNode { tir::Buffer src, dst; ///< Source and destination buffers // Optional: keep the original regions used to construct this op BufferRegion srcRegion_, dstRegion_; - int dim; ///< Dimension to reduce along - ReduceType type; ///< Type of reduction operation - bool clear; ///< Whether to clear destination before reduction + int dim; ///< Dimension to reduce along + ReduceType type; ///< Type of reduction operation + bool clear; ///< Whether to clear destination before reduction + bool nan_propagate{false}; ///< For fp16/bf16 max/min/absmax: propagate NaN + ///< (use __hmax_nan/__hmin_nan) instead of the + ///< default __hmax/__hmin which return the + ///< non-NaN operand. TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode, TileOperatorNode); @@ -100,7 +104,8 @@ class ReduceOpNode : public TileOperatorNode { .def_ro("dstRegion", &ReduceOpNode::dstRegion_) .def_ro("dim", &ReduceOpNode::dim) .def_ro("type", &ReduceOpNode::type) - .def_ro("clear", &ReduceOpNode::clear); + .def_ro("clear", &ReduceOpNode::clear) + .def_ro("nan_propagate", &ReduceOpNode::nan_propagate); } /// Lower the operator to TIR statements diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 4a11cc1b3c..474cebe044 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1952,6 +1952,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->stream << ss.str(); this->stream << ");\n"; }; + if (op->op.same_as(tl::max_nan()) || op->op.same_as(tl::min_nan())) { + ICHECK_EQ(op->args.size(), 2); + const bool is_max = op->op.same_as(tl::max_nan()); + const DataType t = op->dtype; + const char *f16_intrin = is_max ? "__hmax_nan" : "__hmin_nan"; + const char *fallback = is_max ? "cutlass::fast_max" : "cutlass::fast_min"; + + if (t.is_bfloat16() && t.is_scalar()) { + os << "cutlass::bfloat16_t(" << f16_intrin << "(" + << "(" << PrintExpr(op->args[0]) << ").to_nv_bfloat16(), " + << "(" << PrintExpr(op->args[1]) << ").to_nv_bfloat16()))"; + return; + } + if (t.is_float16() && t.is_scalar()) { + os << "cutlass::half_t(" << f16_intrin << "(" + << "(" << PrintExpr(op->args[0]) << ").to_half(), " + << "(" << PrintExpr(op->args[1]) << ").to_half()))"; + return; + } + os << fallback << "(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + return; + } if (op->op.same_as(builtin::ptx_cp_async())) { // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, // args[3] = predicate (optional) diff --git a/src/tl_templates/cuda/reduce.h b/src/tl_templates/cuda/reduce.h index 5c5faa397c..539c246098 100644 --- a/src/tl_templates/cuda/reduce.h +++ b/src/tl_templates/cuda/reduce.h @@ -45,6 +45,19 @@ struct MaxOp { return half_t(__hmax(x.to_half(), y.to_half())); } }; +struct MaxOpNan { + template TL_DEVICE T operator()(T const &x, T const &y) { + return cutlass::fast_max(x, y); + } + + TL_DEVICE bfloat16_t operator()(bfloat16_t const &x, bfloat16_t const &y) { + return bfloat16_t(__hmax_nan(x.to_nv_bfloat16(), y.to_nv_bfloat16())); + } + + TL_DEVICE half_t operator()(half_t const &x, half_t const &y) { + return half_t(__hmax_nan(x.to_half(), y.to_half())); + } +}; struct MinOp { template TL_DEVICE T operator()(T const &x, T const &y) { @@ -61,6 +74,20 @@ struct MinOp { } }; +struct MinOpNan { + template TL_DEVICE T operator()(T const &x, T const &y) { + return cutlass::fast_min(x, y); + } + + TL_DEVICE bfloat16_t operator()(bfloat16_t const &x, bfloat16_t const &y) { + return bfloat16_t(__hmin_nan(x.to_nv_bfloat16(), y.to_nv_bfloat16())); + } + + TL_DEVICE half_t operator()(half_t const &x, half_t const &y) { + return half_t(__hmin_nan(x.to_half(), y.to_half())); + } +}; + struct BitAndOp { template TL_DEVICE T operator()(T const &x, T const &y) { return x & y; diff --git a/testing/python/language/test_tilelang_language_reduce_maxmin_nan.py b/testing/python/language/test_tilelang_language_reduce_maxmin_nan.py new file mode 100644 index 0000000000..94cd837d49 --- /dev/null +++ b/testing/python/language/test_tilelang_language_reduce_maxmin_nan.py @@ -0,0 +1,107 @@ +"""Tests for the per-call ``nan_propagate`` kwarg on T.reduce_max / reduce_min / +reduce_absmax for float16 and bfloat16 buffers (CUDA only).""" + +import math + +import torch + +import tilelang +import tilelang.testing +import tilelang.language as T + +_DTYPES = [("float16", T.float16, torch.float16), ("bfloat16", T.bfloat16, torch.bfloat16)] + + +def _compile(prim_func): + return tilelang.compile(prim_func, out_idx=-1, target="cuda") + + +def _make_reduce_kernel(reduce_fn, length, dtype, *, nan_propagate): + + @T.prim_func + def kernel(a: T.Tensor((length,), dtype), out: T.Tensor((1,), dtype)): + with T.Kernel(1, threads=32): + frag = T.alloc_fragment((length,), dtype) + out_frag = T.alloc_fragment((1,), dtype) + T.copy(a, frag) + reduce_fn(frag, out_frag, nan_propagate=nan_propagate) + T.copy(out_frag, out) + + return kernel + + +# --------------------------------------------------------------------------- +# Source-level checks: confirm the right reducer / intrinsic is emitted. +# --------------------------------------------------------------------------- + + +@tilelang.testing.requires_cuda +def test_reduce_max_default_uses_plain_op(): + k = _compile(_make_reduce_kernel(T.reduce_max, 64, T.float16, nan_propagate=False)) + src = k.get_kernel_source() + assert "tl::MaxOp" in src and "MaxOpNan" not in src + assert "__hmax(" in src and "__hmax_nan" not in src + + +@tilelang.testing.requires_cuda +def test_reduce_max_nan_propagate_uses_nan_op(): + k = _compile(_make_reduce_kernel(T.reduce_max, 64, T.float16, nan_propagate=True)) + src = k.get_kernel_source() + assert "tl::MaxOpNan" in src + assert "__hmax_nan" in src + + +@tilelang.testing.requires_cuda +def test_reduce_min_nan_propagate_uses_nan_op(): + k = _compile(_make_reduce_kernel(T.reduce_min, 64, T.bfloat16, nan_propagate=True)) + src = k.get_kernel_source() + assert "tl::MinOpNan" in src + assert "__hmin_nan" in src + + +@tilelang.testing.requires_cuda +def test_reduce_absmax_nan_propagate_uses_nan_op(): + k = _compile(_make_reduce_kernel(T.reduce_absmax, 64, T.float16, nan_propagate=True)) + src = k.get_kernel_source() + assert "tl::MaxOpNan" in src + assert "__hmax_nan" in src + + +# --------------------------------------------------------------------------- +# Runtime behavioral checks: NaN actually propagates only when requested. +# --------------------------------------------------------------------------- + + +@tilelang.testing.requires_cuda +def test_reduce_max_runtime_nan_behavior(): + for _, tl_dtype, torch_dtype in _DTYPES: + length = 64 + a = torch.arange(length, dtype=torch.float32).to(torch_dtype).cuda() + a[7] = float("nan") + + k_default = _compile(_make_reduce_kernel(T.reduce_max, length, tl_dtype, nan_propagate=False)) + k_nan = _compile(_make_reduce_kernel(T.reduce_max, length, tl_dtype, nan_propagate=True)) + + out_default = k_default(a) + out_nan = k_nan(a) + + assert not math.isnan(out_default.float().item()), f"{tl_dtype}: default reduce_max should ignore NaN, got {out_default}" + assert math.isnan(out_nan.float().item()), f"{tl_dtype}: nan_propagate reduce_max should return NaN, got {out_nan}" + + +@tilelang.testing.requires_cuda +def test_reduce_min_runtime_nan_behavior(): + for _, tl_dtype, torch_dtype in _DTYPES: + length = 64 + a = torch.arange(length, dtype=torch.float32).to(torch_dtype).cuda() + a[13] = float("nan") + + k_default = _compile(_make_reduce_kernel(T.reduce_min, length, tl_dtype, nan_propagate=False)) + k_nan = _compile(_make_reduce_kernel(T.reduce_min, length, tl_dtype, nan_propagate=True)) + + assert not math.isnan(k_default(a).float().item()) + assert math.isnan(k_nan(a).float().item()) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/language/reduce_op.py b/tilelang/language/reduce_op.py index 503886767f..e0118ea6cd 100644 --- a/tilelang/language/reduce_op.py +++ b/tilelang/language/reduce_op.py @@ -22,7 +22,7 @@ def _legalize_dim(buffer: tir.Buffer, dim: int): # NOTE(chaofan): T.reduce is implemented as a macro, so no return -def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: ReduceKind, dim: int, clear: bool) -> None: +def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: ReduceKind, dim: int, clear: bool, nan_propagate: bool = False) -> None: """Perform a reduction operation on a buffer along a specified dimension. Args: @@ -31,6 +31,10 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: ReduceKind, dim: in reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum') dim (int): Dimension along which to perform reduction clear (bool): Whether to initialize the output buffer before reduction + nan_propagate (bool): Only meaningful for max/min/absmax on + float16/bfloat16. When True, lower to CUDA __hmax_nan/__hmin_nan so + NaNs propagate through the reduction. When False (default), use + __hmax/__hmin which return the non-NaN operand. CUDA-only. """ # input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y] expected_shapes = [buffer.shape[:dim] + buffer.shape[dim + 1 :], buffer.shape[:dim] + [1] + buffer.shape[dim + 1 :]] @@ -41,6 +45,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: ReduceKind, dim: in f"output shape is {out.shape}, expected shapes are {expected_shapes_str}" ) + annotations = {"nan_propagate": True} if nan_propagate else None + @macro def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool) -> None: if is_shared(buffer) and is_shared(out): @@ -63,6 +69,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int reduce_type, dim, clear, + annotations=annotations, ) copy(red_frag_out, out) elif is_shared(buffer) and is_fragment(out): @@ -78,6 +85,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int reduce_type, dim, clear, + annotations=annotations, ) elif is_fragment(buffer) and is_shared(out): red_frag_out = alloc_fragment(out.shape, out.dtype) @@ -94,6 +102,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int reduce_type, dim, clear, + annotations=annotations, ) copy(red_frag_out, out) elif is_fragment(buffer) and is_fragment(out): @@ -105,6 +114,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int reduce_type, dim, clear, + annotations=annotations, ) else: raise ValueError(f"Invalid buffer scopes: {buffer.scope()} and {out.scope()}") @@ -112,7 +122,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int reduce_macro(buffer, out, reduce_type, dim, clear) -def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True) -> None: +def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True, nan_propagate: bool = False) -> None: """Perform reduce max on input buffer, store the result to output buffer Parameters @@ -125,15 +135,19 @@ def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = The dimension to perform reduce on clear : bool If set to True, the output buffer will first be initialized to -inf. + nan_propagate : bool + For float16/bfloat16 only. When True, NaN inputs propagate through the + reduction (CUDA __hmax_nan). When False (default), NaN inputs are + ignored in favor of the other operand (CUDA __hmax). CUDA-only. Returns ------- handle : PrimExpr """ dim = _legalize_dim(buffer, dim) - reduce(buffer, out, "max", dim, clear) + reduce(buffer, out, "max", dim, clear, nan_propagate) -def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True) -> None: +def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True, nan_propagate: bool = False) -> None: """Perform reduce min on input buffer, store the result to output buffer. Args: @@ -141,12 +155,15 @@ def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = out (tir.Buffer): The output buffer dim (int): The dimension to perform reduce on clear (bool, optional): If True, output buffer will be initialized to inf. Defaults to True. + nan_propagate (bool, optional): For float16/bfloat16 only. When True, + NaN inputs propagate (CUDA __hmin_nan). When False (default), NaNs + are ignored (CUDA __hmin). CUDA-only. Returns: tir.Call: Handle to the reduction operation """ dim = _legalize_dim(buffer, dim) - reduce(buffer, out, "min", dim, clear) + reduce(buffer, out, "min", dim, clear, nan_propagate) def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True) -> None: @@ -189,19 +206,22 @@ def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1) -> None: reduce(buffer, out, "abssum", dim, True) -def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True) -> None: +def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True, nan_propagate: bool = False) -> None: """Perform reduce absolute max on input buffer, store the result to output buffer. Args: buffer (tir.Buffer): The input buffer out (tir.Buffer): The output buffer dim (int): The dimension to perform reduce on + nan_propagate (bool, optional): For float16/bfloat16 only. When True, + NaN inputs propagate (CUDA __hmax_nan). When False (default), NaNs + are ignored. CUDA-only. Returns: tir.Call: Handle to the reduction operation """ dim = _legalize_dim(buffer, dim) - reduce(buffer, out, "absmax", dim, clear) + reduce(buffer, out, "absmax", dim, clear, nan_propagate) def reduce_bitand(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True) -> None: From fc5001ffb4dc3c817b7eabb23e863c7315033a66 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Mon, 13 Apr 2026 16:45:24 +0800 Subject: [PATCH 045/156] FIx naive loop var duplication bug --- .../auto_schedule/schedule_builder.cc | 66 +++++++++++++++++++ .../auto_schedule/warpgroup_partition.cc | 25 ++++--- 2 files changed, 77 insertions(+), 14 deletions(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index c0aaf64a2a..9066157a15 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -792,6 +792,61 @@ void ScheduleUnitBuilder::NaiveScheduleLoop(ControlNode *ctrl) { ctrl->SetIIperIter(1); + int n = static_cast(seq_body->children.size()); + auto IsVarDecl = [](IRStructure *node) -> bool { + if (!node || !node->IsTask()) + return false; + auto task = static_cast(node); + return task->stmts.size() == 1 && + task->stmts[0].as() != nullptr; + }; + auto SolveConflictVar = [&]() -> bool { + for (int i = 0; i < n; ++i) { + if (!IsVarDecl(seq_body->children[i].get())) + continue; + for (int j = 0; j < n; ++j) { + if (i == j) + continue; + auto node_i = seq_body->children[i].get(); + auto node_j = seq_body->children[j].get(); + if (!HasDependency(node_i, node_j)) + continue; + if (stage_map[node_j] == stage_map[node_i]) + continue; + + int rem_stage_j = stage_map[node_j]; + auto node_i_task = static_cast(node_i); + auto node_i_let_stmt = node_i_task->stmts[0].as(); + + auto cloned_let_stmt = + LetStmt(node_i_let_stmt->var.copy_with_suffix(""), + node_i_let_stmt->value, Evaluate(0)); + auto cloned_task = std::make_shared(); + cloned_task->stmts.push_back(cloned_let_stmt); + stage_map[cloned_task.get()] = rem_stage_j; + + for (int k = j; k < n; ++k) { + auto node_k = seq_body->children[k].get(); + if (rem_stage_j != stage_map[node_k]) + continue; + if (HasDependency(node_i, node_k)) { + node_k->SubstituteVar(node_i_let_stmt->var, cloned_let_stmt->var); + stage_map[node_k] = rem_stage_j; + } + } + + seq_body->children.insert(seq_body->children.begin() + j, + std::move(cloned_task)); + n += 1; + return true; + } + } + return false; + }; + int conflict_count = 0; + while (SolveConflictVar() && ++conflict_count < 100) + ; + // Estimate overall latency int64_t tripcount = 100; const ForNode *for_node = ctrl->control.get(); @@ -855,6 +910,17 @@ void ScheduleUnitBuilder::NaiveScheduleRecursive( auto ctrl = static_cast(node.get()); if (ctrl->child) { if (ctrl->child->IsSequence() || ctrl->child->IsWrapper()) { + std::vector> origin_children; + if (ctrl->child->IsSequence()) { + auto seq_body = static_cast(ctrl->child.get()); + GatherTaskNodes(seq_body->children, origin_children); + } else { + auto wrapper = static_cast(ctrl->child.get()); + GatherTaskNodes({wrapper->task, wrapper->child}, origin_children); + } + for (auto &child : origin_children) { + NaiveScheduleRecursive(child); + } NaiveScheduleLoop(ctrl); } else { NaiveScheduleRecursive(ctrl->child); diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index b5b8b5f9a6..d128038887 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -176,21 +176,18 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, return nullptr; auto ctrl = static_cast(node); auto new_ctrl = std::make_shared(); - // Apply var_remap to the For statement's min/extent/step so that renamed - // LetDecl variables are correctly referenced in the loop bounds. - if (!var_remap.empty()) { - For new_for = ctrl->control; - new_for.CopyOnWrite()->min = Substitute(ctrl->control->min, var_remap); - new_for.CopyOnWrite()->extent = - Substitute(ctrl->control->extent, var_remap); - if (ctrl->control->step.has_value()) { - new_for.CopyOnWrite()->step = - Substitute(ctrl->control->step.value(), var_remap); - } - new_ctrl->control = new_for; - } else { - new_ctrl->control = ctrl->control; + For new_for = ctrl->control; + auto new_loop_var = ctrl->control->loop_var.copy_with_suffix(""); + new_for.CopyOnWrite()->loop_var = new_loop_var; + var_remap.Set(ctrl->control->loop_var, new_loop_var); + new_for.CopyOnWrite()->min = Substitute(ctrl->control->min, var_remap); + new_for.CopyOnWrite()->extent = + Substitute(ctrl->control->extent, var_remap); + if (ctrl->control->step.has_value()) { + new_for.CopyOnWrite()->step = + Substitute(ctrl->control->step.value(), var_remap); } + new_ctrl->control = new_for; // Clone the task and apply var_remap so each warpgroup gets its own copy // with correctly renamed LetDecl variables. if (ctrl->task) { From 19236b44c07104f7453013081f6bc01ad3ca5aa5 Mon Sep 17 00:00:00 2001 From: Sepcnt <30561671+sepcnt@users.noreply.github.com> Date: Mon, 13 Apr 2026 23:35:03 +0800 Subject: [PATCH 046/156] [Feature] Add TIR builtins for warp-level vote and block-level predicate sync (#1858) * [Feature] Add TIR builtins for warp-level vote and block-level predicate sync * [Refactor] Lower warp-level intrinsics through tl.* ops Convert the warp-vote, warp-shuffle, predicated block-sync, and warp-match builtins from raw `tir.call_extern` wrappers to proper TIR ops registered under `tl.*` and lowered in `codegen_cuda` / `codegen_hip`. This removes Python-side `_IS_HIP_AVAILABLE` branching and pushes the CUDA/HIP split into codegen, where it belongs. * Register new tl ops: any_sync, all_sync, ballot_sync, ballot, activemask, syncthreads_count/and/or, shfl_sync, shfl_xor_sync, shfl_down_sync, shfl_up_sync, match_any_sync, match_all_sync. * Codegen lowering on both CUDA and HIP. uint32->uint64 zero-extension for ballot/activemask now happens in codegen. HIP drops the mask argument for shfl/vote and emits LOG(FATAL) for match_*_sync. * `__match_all_sync`'s int* pred output is hidden behind an immediately-invoked lambda so the wrapper stays expression-form. * Python wrappers normalize Python int masks to uint32 TIR consts via `_as_uint32_mask`, so the emitted source prints `(uint)0xFFFFFFFF` instead of `(int64_t)4294967295`. * Unify `shfl_xor/down/up` signatures to `(mask, value, delta, width=32)`, matching CUDA convention and `shfl_sync`. No in-tree callers existed. * Add tests for match_any_sync and match_all_sync. Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: LeiWang1999 Co-authored-by: Claude Opus 4.6 (1M context) --- docs/programming_guides/instructions.md | 27 ++ src/op/builtin.cc | 68 +++ src/op/builtin.h | 126 ++++++ src/target/codegen_cuda.cc | 70 +++ src/target/codegen_hip.cc | 55 +++ .../test_tilelang_language_warp_vote.py | 410 ++++++++++++++++++ tilelang/language/__init__.py | 8 + tilelang/language/builtin.py | 201 +++++++-- 8 files changed, 928 insertions(+), 37 deletions(-) create mode 100644 testing/python/language/test_tilelang_language_warp_vote.py diff --git a/docs/programming_guides/instructions.md b/docs/programming_guides/instructions.md index a260b915ae..874b8c020d 100644 --- a/docs/programming_guides/instructions.md +++ b/docs/programming_guides/instructions.md @@ -188,9 +188,36 @@ Annotation helpers - `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint. Synchronization helpers +- `T.sync_threads([barrier_id, arrive_count])`: Block-wide barrier (`__syncthreads()`). +- `T.sync_warp([mask])`: Warp-wide barrier (`__syncwarp([mask])`). +- `T.sync_grid()`: Cooperative grid barrier (requires cooperative launch). - `T.pdl_trigger()`: Signal programmatic launch completion for the current kernel. - `T.pdl_sync()`: Wait until kernel dependencies are satisfied. +Warp-vote / warp-ballot (CUDA ≥ 9 / HIP) +- `T.any_sync(mask, predicate)` → `int32`: Non-zero if ANY lane in `mask` has non-zero predicate (`__any_sync`). +- `T.all_sync(mask, predicate)` → `int32`: Non-zero if ALL lanes in `mask` have non-zero predicate (`__all_sync`). +- `T.ballot_sync(mask, predicate)` → `uint64`: Bitmask of lanes in `mask` with non-zero predicate. CUDA: `__ballot_sync` zero-extended to 64 bits; HIP: `__ballot` returns natively as `uint64`, covering all 64 wavefront lanes. +- `T.ballot(predicate)` → `uint64`: Full-warp/wavefront ballot (mask = `0xFFFFFFFF`). No truncation on HIP. +- `T.activemask()` → `uint64`: Bitmask of currently active lanes. CUDA: `__activemask` zero-extended to 64 bits; HIP: `__ballot(1)` as `uint64`. + +Block-wide predicated sync +- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). +- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). +- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). + +Warp-shuffle (intra-warp data exchange) +- `T.shfl_sync(mask, value, src_lane[, width])`: Broadcast value from `src_lane` to all lanes (`__shfl_sync`). +- `T.shfl_xor(value, offset[, mask, width])`: XOR-swap across lanes (`__shfl_xor_sync`). +- `T.shfl_down(value, offset[, mask, width])`: Shift down by `offset` lanes (`__shfl_down_sync`). +- `T.shfl_up(value, offset[, mask, width])`: Shift up by `offset` lanes (`__shfl_up_sync`). + +Warp-match (CUDA sm_70+, not supported on HIP) +- `T.match_any_sync(mask, value)` → `uint32`: Bitmask of lanes in `mask` whose `value` matches the calling lane's (`__match_any_sync`). +- `T.match_all_sync(mask, value)` → `uint32`: Returns `mask` if all lanes in `mask` agree on `value`, else 0 (`__match_all_sync`). The C-level `int*` predicate output is hidden; reconstruct it as `result != 0`. + +> **Note on HIP:** `any_sync`/`all_sync` ignore the mask and call `__any`/`__all` directly. `ballot_sync`, `ballot`, and `activemask` call `__ballot` which returns `uint64` natively on 64-thread wavefronts — no truncation occurs. Shuffle intrinsics lower to `__shfl`/`__shfl_xor`/`__shfl_down`/`__shfl_up` (mask ignored). `syncthreads_count/and/or` have identical signatures on both platforms. `match_any_sync` and `match_all_sync` have no HIP equivalent and will fail to codegen on HIP. + Atomics - `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`. - `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`. diff --git a/src/op/builtin.cc b/src/op/builtin.cc index c6cbfde130..b95ec04360 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -368,6 +368,74 @@ TIR_DEFINE_TL_BUILTIN(pdl_trigger) TIR_DEFINE_TL_BUILTIN(pdl_sync).set_num_inputs(0).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +// Warp-vote / warp-ballot intrinsics. These synchronize the warp, so they are +// marked opaque to prevent reordering across divergent control flow. +TIR_DEFINE_TL_BUILTIN(any_sync).set_num_inputs(2).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(all_sync).set_num_inputs(2).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ballot_sync) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ballot).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(activemask) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +// Block-wide predicated barriers. +TIR_DEFINE_TL_BUILTIN(syncthreads_count) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(syncthreads_and) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(syncthreads_or) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +// Warp-shuffle intrinsics. All four accept (mask, value, lane_or_offset, +// width) and are opaque because they involve inter-lane communication. +TIR_DEFINE_TL_BUILTIN(shfl_sync).set_num_inputs(4).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(shfl_xor_sync) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(shfl_down_sync) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(shfl_up_sync) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +// Warp match-any/match-all intrinsics (CUDA sm_70+). HIP lowering errors. +TIR_DEFINE_TL_BUILTIN(match_any_sync) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(match_all_sync) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(loop_break) .set_num_inputs(0) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 4c7d542f77..6e71b2446a 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -641,6 +641,132 @@ TVM_DLL const Op &pdl_trigger(); */ TVM_DLL const Op &pdl_sync(); +/*! + * \brief Warp-vote: non-zero if ANY active lane in the mask has a non-zero + * predicate. Lowers to `__any_sync(mask, predicate)` on CUDA and + * `__any(predicate)` on HIP (mask is ignored on HIP). + * + * int32 any_sync(mask, predicate) + */ +TVM_DLL const Op &any_sync(); + +/*! + * \brief Warp-vote: non-zero only if ALL active lanes in the mask have a + * non-zero predicate. Lowers to `__all_sync(mask, predicate)` on CUDA and + * `__all(predicate)` on HIP (mask is ignored on HIP). + * + * int32 all_sync(mask, predicate) + */ +TVM_DLL const Op &all_sync(); + +/*! + * \brief Warp-ballot: bitmask of lanes in the mask with non-zero predicate. + * + * CUDA: `__ballot_sync(mask, predicate)` returns `uint32`; the codegen + * zero-extends the result to `uint64`. + * HIP: `__ballot(predicate)` returns `uint64` natively, covering all 64 + * lanes of the wavefront. Mask is ignored on HIP. + * + * uint64 ballot_sync(mask, predicate) + */ +TVM_DLL const Op &ballot_sync(); + +/*! + * \brief Full-warp / full-wavefront ballot. Equivalent to + * `ballot_sync(0xFFFFFFFF, predicate)`. + * + * uint64 ballot(predicate) + */ +TVM_DLL const Op &ballot(); + +/*! + * \brief Bitmask of currently active (non-exited) lanes. Lowers to + * `__activemask()` (zero-extended to `uint64`) on CUDA and `__ballot(1)` on + * HIP. + * + * uint64 activemask() + */ +TVM_DLL const Op &activemask(); + +/*! + * \brief Block barrier that returns the number of threads whose predicate + * evaluates to non-zero. Lowers to `__syncthreads_count(predicate)` on both + * CUDA and HIP. + * + * int32 syncthreads_count(predicate) + */ +TVM_DLL const Op &syncthreads_count(); + +/*! + * \brief Block barrier that returns non-zero only if ALL threads have a + * non-zero predicate. Lowers to `__syncthreads_and(predicate)` on both + * CUDA and HIP. + * + * int32 syncthreads_and(predicate) + */ +TVM_DLL const Op &syncthreads_and(); + +/*! + * \brief Block barrier that returns non-zero if ANY thread has a non-zero + * predicate. Lowers to `__syncthreads_or(predicate)` on both CUDA and HIP. + * + * int32 syncthreads_or(predicate) + */ +TVM_DLL const Op &syncthreads_or(); + +/*! + * \brief Warp shuffle: broadcast `value` from `src_lane` within each subgroup + * of `width` lanes. Lowers to `__shfl_sync(mask, value, src_lane, width)` on + * CUDA and `__shfl(value, src_lane, width)` on HIP. The dtype of the result + * matches the dtype of `value`. + * + * T shfl_sync(mask, value, src_lane, width) + */ +TVM_DLL const Op &shfl_sync(); + +/*! + * \brief Warp shuffle (XOR-swap variant). Lowers to `__shfl_xor_sync` on CUDA + * and `__shfl_xor` on HIP. + * + * T shfl_xor_sync(mask, value, lane_mask, width) + */ +TVM_DLL const Op &shfl_xor_sync(); + +/*! + * \brief Warp shuffle (shift-down variant). Lowers to `__shfl_down_sync` on + * CUDA and `__shfl_down` on HIP. + * + * T shfl_down_sync(mask, value, delta, width) + */ +TVM_DLL const Op &shfl_down_sync(); + +/*! + * \brief Warp shuffle (shift-up variant). Lowers to `__shfl_up_sync` on CUDA + * and `__shfl_up` on HIP. + * + * T shfl_up_sync(mask, value, delta, width) + */ +TVM_DLL const Op &shfl_up_sync(); + +/*! + * \brief Warp match-any: returns a mask of lanes in `mask` whose `value` + * equals the calling lane's value. Lowers to `__match_any_sync` on CUDA + * (compute capability >= 7.0). Not supported on HIP. + * + * uint32 match_any_sync(mask, value) + */ +TVM_DLL const Op &match_any_sync(); + +/*! + * \brief Warp match-all: returns `mask` if all lanes in `mask` agree on + * `value`, else 0. Lowers to `__match_all_sync` on CUDA (compute capability + * >= 7.0, the trailing `int*` predicate output is discarded via an + * immediately-invoked lambda). Not supported on HIP. + * + * uint32 match_all_sync(mask, value) + */ +TVM_DLL const Op &match_all_sync(); + /*! * \brief tvm intrinsic for loop continue * diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 474cebe044..7ca86911f3 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -3270,6 +3270,76 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { enable_sparse_gemm_ = true; this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), op_instance->value, op->args, true, os); + } else if (op->op.same_as(tl::any_sync())) { + ICHECK_EQ(op->args.size(), 2U) << "tl.any_sync expects ."; + os << "__any_sync(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::all_sync())) { + ICHECK_EQ(op->args.size(), 2U) << "tl.all_sync expects ."; + os << "__all_sync(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::ballot_sync())) { + ICHECK_EQ(op->args.size(), 2U) + << "tl.ballot_sync expects ."; + // __ballot_sync returns unsigned int (32 bits); zero-extend to uint64. + os << "((unsigned long long)__ballot_sync(" << PrintExpr(op->args[0]) + << ", " << PrintExpr(op->args[1]) << "))"; + } else if (op->op.same_as(tl::ballot())) { + ICHECK_EQ(op->args.size(), 1U) << "tl.ballot expects ."; + os << "((unsigned long long)__ballot_sync(0xFFFFFFFFu, " + << PrintExpr(op->args[0]) << "))"; + } else if (op->op.same_as(tl::activemask())) { + ICHECK(op->args.empty()) << "tl.activemask takes no arguments."; + os << "((unsigned long long)__activemask())"; + } else if (op->op.same_as(tl::syncthreads_count())) { + ICHECK_EQ(op->args.size(), 1U) + << "tl.syncthreads_count expects ."; + os << "__syncthreads_count(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::syncthreads_and())) { + ICHECK_EQ(op->args.size(), 1U) << "tl.syncthreads_and expects ."; + os << "__syncthreads_and(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::syncthreads_or())) { + ICHECK_EQ(op->args.size(), 1U) << "tl.syncthreads_or expects ."; + os << "__syncthreads_or(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::shfl_sync())) { + ICHECK_EQ(op->args.size(), 4U) + << "tl.shfl_sync expects ."; + os << "__shfl_sync(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ", " + << PrintExpr(op->args[3]) << ")"; + } else if (op->op.same_as(tl::shfl_xor_sync())) { + ICHECK_EQ(op->args.size(), 4U) + << "tl.shfl_xor_sync expects ."; + os << "__shfl_xor_sync(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ", " + << PrintExpr(op->args[3]) << ")"; + } else if (op->op.same_as(tl::shfl_down_sync())) { + ICHECK_EQ(op->args.size(), 4U) + << "tl.shfl_down_sync expects ."; + os << "__shfl_down_sync(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ", " + << PrintExpr(op->args[3]) << ")"; + } else if (op->op.same_as(tl::shfl_up_sync())) { + ICHECK_EQ(op->args.size(), 4U) + << "tl.shfl_up_sync expects ."; + os << "__shfl_up_sync(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ", " + << PrintExpr(op->args[3]) << ")"; + } else if (op->op.same_as(tl::match_any_sync())) { + ICHECK_EQ(op->args.size(), 2U) + << "tl.match_any_sync expects ."; + os << "__match_any_sync(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::match_all_sync())) { + ICHECK_EQ(op->args.size(), 2U) + << "tl.match_all_sync expects ."; + // __match_all_sync writes a `pred` flag through its third argument. We + // hide the out-parameter behind an immediately-invoked lambda and + // discard pred (the returned mask already encodes whether all lanes + // agreed: a non-zero result implies pred == 1). + os << "([&]() -> unsigned { int _tl_pred = 0; return __match_all_sync(" + << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]) + << ", &_tl_pred); }())"; } else if (op->op.same_as(tl::get_lane_idx())) { ICHECK_LE(op->args.size(), 1) << "tl.get_lane_idx expects at most one argument ."; diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index 1955f492a9..6cc566b9a9 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -857,6 +857,61 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::pack_b16())) { os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::any_sync())) { + ICHECK_EQ(op->args.size(), 2U) << "tl.any_sync expects ."; + // HIP __any takes only the predicate; the mask is ignored because + // wavefront execution is always convergent across the full wave. + os << "__any(" << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::all_sync())) { + ICHECK_EQ(op->args.size(), 2U) << "tl.all_sync expects ."; + os << "__all(" << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::ballot_sync())) { + ICHECK_EQ(op->args.size(), 2U) + << "tl.ballot_sync expects ."; + // HIP __ballot returns uint64 natively, covering every lane of the + // wavefront; the CUDA-style mask argument is ignored. + os << "((unsigned long long)__ballot(" << PrintExpr(op->args[1]) << "))"; + } else if (op->op.same_as(tl::ballot())) { + ICHECK_EQ(op->args.size(), 1U) << "tl.ballot expects ."; + os << "((unsigned long long)__ballot(" << PrintExpr(op->args[0]) << "))"; + } else if (op->op.same_as(tl::activemask())) { + ICHECK(op->args.empty()) << "tl.activemask takes no arguments."; + os << "((unsigned long long)__ballot(1))"; + } else if (op->op.same_as(tl::syncthreads_count())) { + ICHECK_EQ(op->args.size(), 1U) + << "tl.syncthreads_count expects ."; + os << "__syncthreads_count(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::syncthreads_and())) { + ICHECK_EQ(op->args.size(), 1U) << "tl.syncthreads_and expects ."; + os << "__syncthreads_and(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::syncthreads_or())) { + ICHECK_EQ(op->args.size(), 1U) << "tl.syncthreads_or expects ."; + os << "__syncthreads_or(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::shfl_sync())) { + ICHECK_EQ(op->args.size(), 4U) + << "tl.shfl_sync expects ."; + // HIP __shfl takes only (value, src_lane, width); the mask is ignored. + os << "__shfl(" << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) + << ", " << PrintExpr(op->args[3]) << ")"; + } else if (op->op.same_as(tl::shfl_xor_sync())) { + ICHECK_EQ(op->args.size(), 4U) + << "tl.shfl_xor_sync expects ."; + os << "__shfl_xor(" << PrintExpr(op->args[1]) << ", " + << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ")"; + } else if (op->op.same_as(tl::shfl_down_sync())) { + ICHECK_EQ(op->args.size(), 4U) + << "tl.shfl_down_sync expects ."; + os << "__shfl_down(" << PrintExpr(op->args[1]) << ", " + << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ")"; + } else if (op->op.same_as(tl::shfl_up_sync())) { + ICHECK_EQ(op->args.size(), 4U) + << "tl.shfl_up_sync expects ."; + os << "__shfl_up(" << PrintExpr(op->args[1]) << ", " + << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ")"; + } else if (op->op.same_as(tl::match_any_sync()) || + op->op.same_as(tl::match_all_sync())) { + LOG(FATAL) << "tl." << op->op << " is not supported on HIP: the " + << "__match_{any,all}_sync primitives have no HIP equivalent."; } else if (op->op.same_as(tl::add2()) || op->op.same_as(tl::sub2()) || op->op.same_as(tl::mul2()) || op->op.same_as(tl::fma2()) || op->op.same_as(tl::max2()) || op->op.same_as(tl::min2()) || diff --git a/testing/python/language/test_tilelang_language_warp_vote.py b/testing/python/language/test_tilelang_language_warp_vote.py new file mode 100644 index 0000000000..c7a26c4adb --- /dev/null +++ b/testing/python/language/test_tilelang_language_warp_vote.py @@ -0,0 +1,410 @@ +"""Tests for warp-vote / warp-ballot / block-sync-with-predicate intrinsics. + +Covered intrinsics +------------------ +T.any_sync – __any_sync / __any (HIP) +T.all_sync – __all_sync / __all (HIP) +T.ballot_sync – __ballot_sync→uint64 (CUDA, zero-ext) / __ballot uint64 (HIP, all lanes) +T.ballot – full-warp ballot_sync / __ballot uint64 (HIP, all lanes) +T.activemask – __activemask→uint64 (CUDA, zero-ext) / __ballot(1) uint64 (HIP, all lanes) +T.syncthreads_count – __syncthreads_count +T.syncthreads_and – __syncthreads_and +T.syncthreads_or – __syncthreads_or +""" + +import tilelang +import tilelang.language as T +import torch +import tilelang.testing + + +# --------------------------------------------------------------------------- +# any_sync +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_any_sync(): + """Lane 0 has a non-zero predicate; any_sync should return non-zero for all lanes.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + val = T.any_sync(0xFFFFFFFF, tx == 0) + B[tx] = val + + return main + + +@tilelang.testing.requires_cuda +def test_any_sync(): + b = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel = kernel_any_sync() + src = kernel.get_kernel_source() + assert "__any_sync" in src or "__any" in src, f"Expected __any_sync/__any in source:\n{src}" + kernel(b) + # any lane (lane 0) has predicate==1 → result must be non-zero for all lanes + assert torch.all(b != 0), f"Expected all non-zero, got {b}" + + +# --------------------------------------------------------------------------- +# all_sync +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_all_sync(): + """All lanes always pass predicate 1 → all_sync should return non-zero.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + val = T.all_sync(0xFFFFFFFF, 1) + B[tx] = val + + return main + + +@tilelang.testing.requires_cuda +def test_all_sync(): + b = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel = kernel_all_sync() + src = kernel.get_kernel_source() + assert "__all_sync" in src or "__all" in src, f"Expected __all_sync/__all in source:\n{src}" + kernel(b) + assert torch.all(b != 0), f"Expected all non-zero, got {b}" + + +# --------------------------------------------------------------------------- +# ballot_sync +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_ballot_sync(): + """Only lane 0 has a non-zero predicate → bit 0 of ballot must be set.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int64"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + mask = T.ballot_sync(0xFFFFFFFF, tx == 0) + B[tx] = T.cast(mask, "int64") + + return main + + +@tilelang.testing.requires_cuda +def test_ballot_sync(): + b = torch.zeros((32,), device="cuda", dtype=torch.int64) + kernel = kernel_ballot_sync() + src = kernel.get_kernel_source() + assert "__ballot_sync" in src or "__ballot" in src, f"Expected __ballot_sync/__ballot in source:\n{src}" + kernel(b) + # All lanes read the same ballot value; bit 0 must be set (lane 0 had pred=1) + assert int(b[0]) & 1, f"Expected bit 0 set in ballot result, got {int(b[0]):#018x}" + # upper 32 bits must be zero on CUDA (32-wide warp) + assert (int(b[0]) >> 32) == 0, f"Expected upper 32 bits zero on CUDA, got {int(b[0]):#018x}" + + +# --------------------------------------------------------------------------- +# ballot (full-warp convenience wrapper) +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_ballot(): + """All lanes pass predicate 1 → lower 32 bits of ballot must all be set.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int64"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + mask = T.ballot(1) + B[tx] = T.cast(mask, "int64") + + return main + + +@tilelang.testing.requires_cuda +def test_ballot(): + b = torch.zeros((32,), device="cuda", dtype=torch.int64) + kernel = kernel_ballot() + src = kernel.get_kernel_source() + assert "__ballot_sync" in src or "__ballot" in src, f"Expected __ballot_sync/__ballot in source:\n{src}" + kernel(b) + # With predicate=1 for all 32 lanes the lower 32 bits should be 0xFFFFFFFF; + # upper 32 bits are 0 on CUDA (32-wide warp). + assert (int(b[0]) & 0xFFFFFFFF) == 0xFFFFFFFF, f"Expected lower 32 bits all set, got {int(b[0]):#018x}" + assert (int(b[0]) >> 32) == 0, f"Expected upper 32 bits zero on CUDA, got {int(b[0]):#018x}" + + +# --------------------------------------------------------------------------- +# activemask +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_activemask(): + """All 32 threads are active → lower 32 bits of activemask must all be set.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int64"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + mask = T.activemask() + B[tx] = T.cast(mask, "int64") + + return main + + +@tilelang.testing.requires_cuda +def test_activemask(): + b = torch.zeros((32,), device="cuda", dtype=torch.int64) + kernel = kernel_activemask() + src = kernel.get_kernel_source() + assert "__activemask" in src or "__ballot" in src, f"Expected __activemask/__ballot in source:\n{src}" + kernel(b) + # All 32 lanes active → lower 32 bits = 0xFFFFFFFF; upper 32 bits = 0 on CUDA. + assert (int(b[0]) & 0xFFFFFFFF) == 0xFFFFFFFF, f"Expected lower 32 bits all set, got {int(b[0]):#018x}" + assert (int(b[0]) >> 32) == 0, f"Expected upper 32 bits zero on CUDA, got {int(b[0]):#018x}" + + +# --------------------------------------------------------------------------- +# syncthreads_count +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_syncthreads_count(): + """Exactly half the threads (lanes 0–15) pass predicate 1.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + cnt = T.syncthreads_count(tx < 16) + B[tx] = cnt + + return main + + +@tilelang.testing.requires_cuda +def test_syncthreads_count(): + b = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel = kernel_syncthreads_count() + src = kernel.get_kernel_source() + assert "__syncthreads_count" in src, f"Expected __syncthreads_count in source:\n{src}" + kernel(b) + assert torch.all(b == 16), f"Expected all 16, got {b}" + + +# --------------------------------------------------------------------------- +# syncthreads_and +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_syncthreads_and_true(): + """All threads pass predicate 1 → syncthreads_and returns non-zero.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + result = T.syncthreads_and(1) + B[tx] = result + + return main + + +@tilelang.jit +def kernel_syncthreads_and_false(): + """Thread 0 passes predicate 0 → syncthreads_and returns 0.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + result = T.syncthreads_and(tx != 0) + B[tx] = result + + return main + + +@tilelang.testing.requires_cuda +def test_syncthreads_and(): + b = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel = kernel_syncthreads_and_true() + src = kernel.get_kernel_source() + assert "__syncthreads_and" in src, f"Expected __syncthreads_and in source:\n{src}" + kernel(b) + assert torch.all(b != 0), f"Expected all non-zero, got {b}" + + b2 = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel2 = kernel_syncthreads_and_false() + kernel2(b2) + assert torch.all(b2 == 0), f"Expected all 0, got {b2}" + + +# --------------------------------------------------------------------------- +# syncthreads_or +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_syncthreads_or_true(): + """At least one thread (lane 0) passes predicate 1 → syncthreads_or != 0.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + result = T.syncthreads_or(tx == 0) + B[tx] = result + + return main + + +@tilelang.jit +def kernel_syncthreads_or_false(): + """No thread passes predicate 1 → syncthreads_or returns 0.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + result = T.syncthreads_or(0) + B[tx] = result + + return main + + +@tilelang.testing.requires_cuda +def test_syncthreads_or(): + b = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel = kernel_syncthreads_or_true() + src = kernel.get_kernel_source() + assert "__syncthreads_or" in src, f"Expected __syncthreads_or in source:\n{src}" + kernel(b) + assert torch.all(b != 0), f"Expected all non-zero, got {b}" + + b2 = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel2 = kernel_syncthreads_or_false() + kernel2(b2) + assert torch.all(b2 == 0), f"Expected all 0, got {b2}" + + +# --------------------------------------------------------------------------- +# match_any_sync +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_match_any_sync(): + """Lanes 0-15 share value 1; lanes 16-31 share value 2. match_any_sync + should return 0x0000FFFF for the first half and 0xFFFF0000 for the + second half.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + value = T.if_then_else(tx < 16, 1, 2) + peers = T.match_any_sync(0xFFFFFFFF, value) + B[tx] = T.cast(peers, "int32") + + return main + + +@tilelang.testing.requires_cuda +def test_match_any_sync(): + b = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel = kernel_match_any_sync() + src = kernel.get_kernel_source() + assert "__match_any_sync" in src, f"Expected __match_any_sync in source:\n{src}" + kernel(b) + # Reinterpret the int32 buffer as uint32 to compare against bitmasks + # whose high bit is set (0xFFFF0000 overflows int32). + observed_u32 = b.to(torch.int64) & 0xFFFFFFFF + expected = torch.tensor([0x0000FFFF] * 16 + [0xFFFF0000] * 16, dtype=torch.int64, device="cuda") + assert torch.equal(observed_u32, expected), f"Expected {expected}, got {observed_u32}" + + +# --------------------------------------------------------------------------- +# match_all_sync +# --------------------------------------------------------------------------- + + +@tilelang.jit +def kernel_match_all_sync_true(): + """All lanes share value 7 → match_all_sync returns the full mask.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + result = T.match_all_sync(0xFFFFFFFF, 7) + B[tx] = T.cast(result, "int32") + + return main + + +@tilelang.jit +def kernel_match_all_sync_false(): + """Lanes disagree → match_all_sync returns 0.""" + + @T.prim_func + def main( + B: T.Tensor((32,), "int32"), + ): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + result = T.match_all_sync(0xFFFFFFFF, tx) + B[tx] = T.cast(result, "int32") + + return main + + +@tilelang.testing.requires_cuda +def test_match_all_sync(): + b = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel = kernel_match_all_sync_true() + src = kernel.get_kernel_source() + assert "__match_all_sync" in src, f"Expected __match_all_sync in source:\n{src}" + kernel(b) + assert torch.all(b == -1), f"Expected all 0xFFFFFFFF (sign-extended -1), got {b}" + + b2 = torch.zeros((32,), device="cuda", dtype=torch.int32) + kernel_match_all_sync_false()(b2) + assert torch.all(b2 == 0), f"Expected all 0, got {b2}" + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 74ad185c41..8308d762c2 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -103,6 +103,14 @@ from .builtin import stg64 as stg64 # noqa: F401 from .builtin import stg128 as stg128 # noqa: F401 from .builtin import stg256 as stg256 # noqa: F401 +from .builtin import any_sync as any_sync # noqa: F401 +from .builtin import all_sync as all_sync # noqa: F401 +from .builtin import ballot_sync as ballot_sync # noqa: F401 +from .builtin import ballot as ballot # noqa: F401 +from .builtin import activemask as activemask # noqa: F401 +from .builtin import syncthreads_count as syncthreads_count # noqa: F401 +from .builtin import syncthreads_and as syncthreads_and # noqa: F401 +from .builtin import syncthreads_or as syncthreads_or # noqa: F401 from .utils import index_to_coordinates # noqa: F401 diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 7ec535c5c9..f4ad89ea37 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -866,47 +866,48 @@ def barrier_arrive(mbarrier: BarrierType): return mbarrier_arrive(mbarrier) -def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): - """Perform a shuffle operation with XOR offset. +# Full-warp mask as a proper uint32 TIR constant so the emitted C/C++ source +# prints as `0xFFFFFFFFu` instead of `(int64_t)4294967295` after TIR widening. +_FULL_WARP_MASK = tir.const(0xFFFFFFFF, "uint32") +_DEFAULT_SHFL_WIDTH = 32 - Args: - value: Optional[int, PrimExpr] - The value to shuffle - offset: Optional[int, PrimExpr] - The offset for the shuffle operation - Returns: - tir.Call: A handle to the shuffle operation - """ - if _IS_HIP_AVAILABLE: - return tir.call_extern(value.dtype, "__shfl_xor", value, offset) - else: - return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xFFFFFFFF, value, offset) +def _as_uint32_mask(mask: int | PrimExpr) -> PrimExpr: + """Normalize a warp lane mask to a uint32 TIR expression. -def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): - """Perform a shuffle operation with down offset. + Python literals (e.g. ``0xFFFFFFFF``) would otherwise be widened to int64 + by TIR and printed as ``(int64_t)4294967295`` in the generated source. + """ + if isinstance(mask, int): + return tir.const(mask, "uint32") + return mask - Args: - value: Optional[int, PrimExpr] - The value to shuffle + +def shfl_xor( + mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH +): + """XOR-swap ``value`` across lanes (``__shfl_xor_sync`` on CUDA, + ``__shfl_xor`` on HIP — mask ignored on HIP). """ - if _IS_HIP_AVAILABLE: - return tir.call_extern(value.dtype, "__shfl_down", value, offset) - else: - return tir.call_extern(value.dtype, "__shfl_down_sync", 0xFFFFFFFF, value, offset) + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_xor_sync"), _as_uint32_mask(mask), value, delta, width) -def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): - """Perform a shuffle operation with up offset. +def shfl_down( + mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH +): + """Shift ``value`` down by ``delta`` lanes (``__shfl_down_sync`` on CUDA, + ``__shfl_down`` on HIP). + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_down_sync"), _as_uint32_mask(mask), value, delta, width) - Args: - value: Optional[int, PrimExpr] - The value to shuffle + +def shfl_up( + mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH +): + """Shift ``value`` up by ``delta`` lanes (``__shfl_up_sync`` on CUDA, + ``__shfl_up`` on HIP). """ - if _IS_HIP_AVAILABLE: - return tir.call_extern(value.dtype, "__shfl_up", value, offset) - else: - return tir.call_extern(value.dtype, "__shfl_up_sync", 0xFFFFFFFF, value, offset) + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_up_sync"), _as_uint32_mask(mask), value, delta, width) def sync_threads(barrier_id: int = None, arrive_count: int = None): @@ -926,11 +927,137 @@ def sync_warp(mask: int = None): return tir.call_intrin("void", tir.op.Op.get("tl.sync_warp")) -def shfl_sync(mask: int, value: int | PrimExpr, srcLane: int, width: int = None): - """Receives data from a thread in the same warp.""" - if width is None: - return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane) - return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane, width) +def shfl_sync(mask: int | PrimExpr, value: int | PrimExpr, srcLane: int | PrimExpr, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH): + """Broadcast ``value`` from ``srcLane`` to all lanes in the subgroup of + ``width`` lanes (``__shfl_sync`` on CUDA, ``__shfl`` on HIP — mask ignored + on HIP). + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.shfl_sync"), _as_uint32_mask(mask), value, srcLane, width) + + +# --------------------------------------------------------------------------- +# Warp-vote / warp-ballot intrinsics +# +# The CUDA/HIP split and the uint32->uint64 zero-extension for ballot_sync and +# activemask are handled in codegen (src/target/codegen_{cuda,hip}.cc). These +# Python wrappers simply emit the backend-agnostic tl.* ops. +# --------------------------------------------------------------------------- + + +def any_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: + """Non-zero if ANY active lane in ``mask`` has a non-zero ``predicate``. + + Lowers to ``__any_sync(mask, predicate)`` on CUDA and ``__any(predicate)`` + on HIP (the mask is ignored on HIP because the full wavefront is always + convergent). + + Args: + mask: Warp lane mask (e.g. ``0xFFFFFFFF`` for all 32 lanes). + predicate: Integer condition to test. + + Returns: + int32: Non-zero if any thread in the mask has a non-zero predicate. + """ + return tir.call_intrin("int32", tir.op.Op.get("tl.any_sync"), _as_uint32_mask(mask), predicate) + + +def all_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: + """Non-zero only if ALL active lanes in ``mask`` have a non-zero predicate. + + Lowers to ``__all_sync(mask, predicate)`` on CUDA and ``__all(predicate)`` + on HIP. + + Args: + mask: Warp lane mask (e.g. ``0xFFFFFFFF`` for all 32 lanes). + predicate: Integer condition to test. + + Returns: + int32: Non-zero if all threads in the mask have a non-zero predicate. + """ + return tir.call_intrin("int32", tir.op.Op.get("tl.all_sync"), _as_uint32_mask(mask), predicate) + + +def ballot_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: + """Return a ``uint64`` bitmask of lanes in ``mask`` whose predicate is set. + + CUDA: ``__ballot_sync(mask, predicate)`` returns ``unsigned int``; codegen + zero-extends it to ``uint64`` (upper 32 bits always zero for 32-wide warps). + HIP: ``__ballot(predicate)`` returns ``uint64`` natively, covering all + 64 wavefront lanes. The mask argument is ignored on HIP. + + Returns: + uint64: Bitmask with bit N set if lane N's predicate is non-zero. + """ + return tir.call_intrin("uint64", tir.op.Op.get("tl.ballot_sync"), _as_uint32_mask(mask), predicate) + + +def ballot(predicate: int | PrimExpr) -> PrimExpr: + """Full-warp / full-wavefront ballot. Equivalent to + ``ballot_sync(0xFFFFFFFF, predicate)``. + + Returns: + uint64: Bitmask with bit N set if lane N's predicate is non-zero. + """ + return tir.call_intrin("uint64", tir.op.Op.get("tl.ballot"), predicate) + + +def activemask() -> PrimExpr: + """Return a ``uint64`` bitmask of currently active (non-exited) lanes. + + Lowers to ``__activemask()`` (zero-extended to ``uint64``) on CUDA and + ``__ballot(1)`` on HIP. + """ + return tir.call_intrin("uint64", tir.op.Op.get("tl.activemask")) + + +# --------------------------------------------------------------------------- +# Thread-block synchronization with predicate (shared across CUDA & HIP) +# --------------------------------------------------------------------------- + + +def syncthreads_count(predicate: int | PrimExpr) -> PrimExpr: + """Block barrier that returns the number of threads whose ``predicate`` + evaluates to non-zero (``__syncthreads_count`` on CUDA and HIP). + """ + return tir.call_intrin("int32", tir.op.Op.get("tl.syncthreads_count"), predicate) + + +def syncthreads_and(predicate: int | PrimExpr) -> PrimExpr: + """Block barrier that returns non-zero only if ALL threads have a non-zero + ``predicate`` (``__syncthreads_and`` on CUDA and HIP). + """ + return tir.call_intrin("int32", tir.op.Op.get("tl.syncthreads_and"), predicate) + + +def syncthreads_or(predicate: int | PrimExpr) -> PrimExpr: + """Block barrier that returns non-zero if ANY thread has a non-zero + ``predicate`` (``__syncthreads_or`` on CUDA and HIP). + """ + return tir.call_intrin("int32", tir.op.Op.get("tl.syncthreads_or"), predicate) + + +# --------------------------------------------------------------------------- +# Warp match intrinsics (CUDA sm_70+; unsupported on HIP) +# --------------------------------------------------------------------------- + + +def match_any_sync(mask: int | PrimExpr, value: int | PrimExpr) -> PrimExpr: + """Return a ``uint32`` bitmask of lanes in ``mask`` whose ``value`` equals + the calling lane's value. Lowers to ``__match_any_sync`` on CUDA + (compute capability >= 7.0). Not supported on HIP. + """ + return tir.call_intrin("uint32", tir.op.Op.get("tl.match_any_sync"), _as_uint32_mask(mask), value) + + +def match_all_sync(mask: int | PrimExpr, value: int | PrimExpr) -> PrimExpr: + """Return ``mask`` if all lanes in ``mask`` agree on ``value``, else 0. + + Lowers to ``__match_all_sync`` on CUDA (compute capability >= 7.0); the + trailing ``int*`` predicate output is hidden in codegen and discarded. + Callers can reconstruct the predicate as ``result != 0``. Not supported + on HIP. + """ + return tir.call_intrin("uint32", tir.op.Op.get("tl.match_all_sync"), _as_uint32_mask(mask), value) def sync_global(): From b3d5981b001336150e69f5ba8bfb468b0948ea51 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 14 Apr 2026 01:05:54 +0800 Subject: [PATCH 047/156] [API] Default warp-lane mask to 0xFFFFFFFF for warp-sync builtins (#2039) Reorder the `mask` parameter to the end with a default of `_FULL_WARP_MASK` for shfl_xor/down/up/sync, any_sync, all_sync, ballot_sync, and match_any/all_sync. Full-warp is the overwhelmingly common case, so making mask an optional trailing kwarg removes boilerplate at call sites and lets old `(value, delta)`-style callers keep working. Update the test suite and docs/instructions.md to match the new signatures. --- docs/programming_guides/instructions.md | 22 ++++---- .../test_tilelang_language_warp_sync.py | 2 +- .../test_tilelang_language_warp_vote.py | 12 ++--- tilelang/language/builtin.py | 53 ++++++++++++++----- 4 files changed, 59 insertions(+), 30 deletions(-) diff --git a/docs/programming_guides/instructions.md b/docs/programming_guides/instructions.md index 874b8c020d..9fe92bb803 100644 --- a/docs/programming_guides/instructions.md +++ b/docs/programming_guides/instructions.md @@ -195,9 +195,9 @@ Synchronization helpers - `T.pdl_sync()`: Wait until kernel dependencies are satisfied. Warp-vote / warp-ballot (CUDA ≥ 9 / HIP) -- `T.any_sync(mask, predicate)` → `int32`: Non-zero if ANY lane in `mask` has non-zero predicate (`__any_sync`). -- `T.all_sync(mask, predicate)` → `int32`: Non-zero if ALL lanes in `mask` have non-zero predicate (`__all_sync`). -- `T.ballot_sync(mask, predicate)` → `uint64`: Bitmask of lanes in `mask` with non-zero predicate. CUDA: `__ballot_sync` zero-extended to 64 bits; HIP: `__ballot` returns natively as `uint64`, covering all 64 wavefront lanes. +- `T.any_sync(predicate[, mask])` → `int32`: Non-zero if ANY lane in `mask` has non-zero predicate (`__any_sync`). `mask` defaults to `0xFFFFFFFF`. +- `T.all_sync(predicate[, mask])` → `int32`: Non-zero if ALL lanes in `mask` have non-zero predicate (`__all_sync`). `mask` defaults to `0xFFFFFFFF`. +- `T.ballot_sync(predicate[, mask])` → `uint64`: Bitmask of lanes in `mask` with non-zero predicate. CUDA: `__ballot_sync` zero-extended to 64 bits; HIP: `__ballot` returns natively as `uint64`, covering all 64 wavefront lanes. `mask` defaults to `0xFFFFFFFF`. - `T.ballot(predicate)` → `uint64`: Full-warp/wavefront ballot (mask = `0xFFFFFFFF`). No truncation on HIP. - `T.activemask()` → `uint64`: Bitmask of currently active lanes. CUDA: `__activemask` zero-extended to 64 bits; HIP: `__ballot(1)` as `uint64`. @@ -206,15 +206,15 @@ Block-wide predicated sync - `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). - `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). -Warp-shuffle (intra-warp data exchange) -- `T.shfl_sync(mask, value, src_lane[, width])`: Broadcast value from `src_lane` to all lanes (`__shfl_sync`). -- `T.shfl_xor(value, offset[, mask, width])`: XOR-swap across lanes (`__shfl_xor_sync`). -- `T.shfl_down(value, offset[, mask, width])`: Shift down by `offset` lanes (`__shfl_down_sync`). -- `T.shfl_up(value, offset[, mask, width])`: Shift up by `offset` lanes (`__shfl_up_sync`). +Warp-shuffle (intra-warp data exchange). All accept a trailing `mask` kwarg that defaults to `0xFFFFFFFF`. +- `T.shfl_sync(value, src_lane[, width, mask])`: Broadcast value from `src_lane` to all lanes (`__shfl_sync`). +- `T.shfl_xor(value, delta[, width, mask])`: XOR-swap across lanes (`__shfl_xor_sync`). +- `T.shfl_down(value, delta[, width, mask])`: Shift down by `delta` lanes (`__shfl_down_sync`). +- `T.shfl_up(value, delta[, width, mask])`: Shift up by `delta` lanes (`__shfl_up_sync`). -Warp-match (CUDA sm_70+, not supported on HIP) -- `T.match_any_sync(mask, value)` → `uint32`: Bitmask of lanes in `mask` whose `value` matches the calling lane's (`__match_any_sync`). -- `T.match_all_sync(mask, value)` → `uint32`: Returns `mask` if all lanes in `mask` agree on `value`, else 0 (`__match_all_sync`). The C-level `int*` predicate output is hidden; reconstruct it as `result != 0`. +Warp-match (CUDA sm_70+, not supported on HIP). `mask` defaults to `0xFFFFFFFF`. +- `T.match_any_sync(value[, mask])` → `uint32`: Bitmask of lanes in `mask` whose `value` matches the calling lane's (`__match_any_sync`). +- `T.match_all_sync(value[, mask])` → `uint32`: Returns `mask` if all lanes in `mask` agree on `value`, else 0 (`__match_all_sync`). The C-level `int*` predicate output is hidden; reconstruct it as `result != 0`. > **Note on HIP:** `any_sync`/`all_sync` ignore the mask and call `__any`/`__all` directly. `ballot_sync`, `ballot`, and `activemask` call `__ballot` which returns `uint64` natively on 64-thread wavefronts — no truncation occurs. Shuffle intrinsics lower to `__shfl`/`__shfl_xor`/`__shfl_down`/`__shfl_up` (mask ignored). `syncthreads_count/and/or` have identical signatures on both platforms. `match_any_sync` and `match_all_sync` have no HIP equivalent and will fail to codegen on HIP. diff --git a/testing/python/language/test_tilelang_language_warp_sync.py b/testing/python/language/test_tilelang_language_warp_sync.py index 4c9aaff2a3..d113a43c0f 100644 --- a/testing/python/language/test_tilelang_language_warp_sync.py +++ b/testing/python/language/test_tilelang_language_warp_sync.py @@ -43,7 +43,7 @@ def main( with T.Kernel(1, threads=32): tx = T.get_thread_binding() val = tx * 10 - broadcast = T.shfl_sync(0xFFFFFFFF, val, 31) + broadcast = T.shfl_sync(val, 31) A[tx] = broadcast return main diff --git a/testing/python/language/test_tilelang_language_warp_vote.py b/testing/python/language/test_tilelang_language_warp_vote.py index c7a26c4adb..8a630ce6df 100644 --- a/testing/python/language/test_tilelang_language_warp_vote.py +++ b/testing/python/language/test_tilelang_language_warp_vote.py @@ -33,7 +33,7 @@ def main( ): with T.Kernel(1, threads=32): tx = T.get_thread_binding() - val = T.any_sync(0xFFFFFFFF, tx == 0) + val = T.any_sync(tx == 0) B[tx] = val return main @@ -65,7 +65,7 @@ def main( ): with T.Kernel(1, threads=32): tx = T.get_thread_binding() - val = T.all_sync(0xFFFFFFFF, 1) + val = T.all_sync(1) B[tx] = val return main @@ -96,7 +96,7 @@ def main( ): with T.Kernel(1, threads=32): tx = T.get_thread_binding() - mask = T.ballot_sync(0xFFFFFFFF, tx == 0) + mask = T.ballot_sync(tx == 0) B[tx] = T.cast(mask, "int64") return main @@ -335,7 +335,7 @@ def main( with T.Kernel(1, threads=32): tx = T.get_thread_binding() value = T.if_then_else(tx < 16, 1, 2) - peers = T.match_any_sync(0xFFFFFFFF, value) + peers = T.match_any_sync(value) B[tx] = T.cast(peers, "int32") return main @@ -370,7 +370,7 @@ def main( ): with T.Kernel(1, threads=32): tx = T.get_thread_binding() - result = T.match_all_sync(0xFFFFFFFF, 7) + result = T.match_all_sync(7) B[tx] = T.cast(result, "int32") return main @@ -386,7 +386,7 @@ def main( ): with T.Kernel(1, threads=32): tx = T.get_thread_binding() - result = T.match_all_sync(0xFFFFFFFF, tx) + result = T.match_all_sync(tx) B[tx] = T.cast(result, "int32") return main diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index f4ad89ea37..150d32a311 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -884,7 +884,10 @@ def _as_uint32_mask(mask: int | PrimExpr) -> PrimExpr: def shfl_xor( - mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH + value: int | PrimExpr | tir.Call, + delta: int | PrimExpr | tir.Call, + width: int | PrimExpr = _DEFAULT_SHFL_WIDTH, + mask: int | PrimExpr = _FULL_WARP_MASK, ): """XOR-swap ``value`` across lanes (``__shfl_xor_sync`` on CUDA, ``__shfl_xor`` on HIP — mask ignored on HIP). @@ -893,7 +896,10 @@ def shfl_xor( def shfl_down( - mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH + value: int | PrimExpr | tir.Call, + delta: int | PrimExpr | tir.Call, + width: int | PrimExpr = _DEFAULT_SHFL_WIDTH, + mask: int | PrimExpr = _FULL_WARP_MASK, ): """Shift ``value`` down by ``delta`` lanes (``__shfl_down_sync`` on CUDA, ``__shfl_down`` on HIP). @@ -902,7 +908,10 @@ def shfl_down( def shfl_up( - mask: int | PrimExpr, value: int | PrimExpr | tir.Call, delta: int | PrimExpr | tir.Call, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH + value: int | PrimExpr | tir.Call, + delta: int | PrimExpr | tir.Call, + width: int | PrimExpr = _DEFAULT_SHFL_WIDTH, + mask: int | PrimExpr = _FULL_WARP_MASK, ): """Shift ``value`` up by ``delta`` lanes (``__shfl_up_sync`` on CUDA, ``__shfl_up`` on HIP). @@ -927,7 +936,12 @@ def sync_warp(mask: int = None): return tir.call_intrin("void", tir.op.Op.get("tl.sync_warp")) -def shfl_sync(mask: int | PrimExpr, value: int | PrimExpr, srcLane: int | PrimExpr, width: int | PrimExpr = _DEFAULT_SHFL_WIDTH): +def shfl_sync( + value: int | PrimExpr, + srcLane: int | PrimExpr, + width: int | PrimExpr = _DEFAULT_SHFL_WIDTH, + mask: int | PrimExpr = _FULL_WARP_MASK, +): """Broadcast ``value`` from ``srcLane`` to all lanes in the subgroup of ``width`` lanes (``__shfl_sync`` on CUDA, ``__shfl`` on HIP — mask ignored on HIP). @@ -944,7 +958,10 @@ def shfl_sync(mask: int | PrimExpr, value: int | PrimExpr, srcLane: int | PrimEx # --------------------------------------------------------------------------- -def any_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: +def any_sync( + predicate: int | PrimExpr, + mask: int | PrimExpr = _FULL_WARP_MASK, +) -> PrimExpr: """Non-zero if ANY active lane in ``mask`` has a non-zero ``predicate``. Lowers to ``__any_sync(mask, predicate)`` on CUDA and ``__any(predicate)`` @@ -952,8 +969,8 @@ def any_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: convergent). Args: - mask: Warp lane mask (e.g. ``0xFFFFFFFF`` for all 32 lanes). predicate: Integer condition to test. + mask: Warp lane mask (defaults to ``0xFFFFFFFF``, i.e. all 32 lanes). Returns: int32: Non-zero if any thread in the mask has a non-zero predicate. @@ -961,15 +978,18 @@ def any_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.any_sync"), _as_uint32_mask(mask), predicate) -def all_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: +def all_sync( + predicate: int | PrimExpr, + mask: int | PrimExpr = _FULL_WARP_MASK, +) -> PrimExpr: """Non-zero only if ALL active lanes in ``mask`` have a non-zero predicate. Lowers to ``__all_sync(mask, predicate)`` on CUDA and ``__all(predicate)`` on HIP. Args: - mask: Warp lane mask (e.g. ``0xFFFFFFFF`` for all 32 lanes). predicate: Integer condition to test. + mask: Warp lane mask (defaults to ``0xFFFFFFFF``, i.e. all 32 lanes). Returns: int32: Non-zero if all threads in the mask have a non-zero predicate. @@ -977,7 +997,10 @@ def all_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.all_sync"), _as_uint32_mask(mask), predicate) -def ballot_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: +def ballot_sync( + predicate: int | PrimExpr, + mask: int | PrimExpr = _FULL_WARP_MASK, +) -> PrimExpr: """Return a ``uint64`` bitmask of lanes in ``mask`` whose predicate is set. CUDA: ``__ballot_sync(mask, predicate)`` returns ``unsigned int``; codegen @@ -993,7 +1016,7 @@ def ballot_sync(mask: int | PrimExpr, predicate: int | PrimExpr) -> PrimExpr: def ballot(predicate: int | PrimExpr) -> PrimExpr: """Full-warp / full-wavefront ballot. Equivalent to - ``ballot_sync(0xFFFFFFFF, predicate)``. + ``ballot_sync(predicate)`` (i.e. with the default full warp mask). Returns: uint64: Bitmask with bit N set if lane N's predicate is non-zero. @@ -1041,7 +1064,10 @@ def syncthreads_or(predicate: int | PrimExpr) -> PrimExpr: # --------------------------------------------------------------------------- -def match_any_sync(mask: int | PrimExpr, value: int | PrimExpr) -> PrimExpr: +def match_any_sync( + value: int | PrimExpr, + mask: int | PrimExpr = _FULL_WARP_MASK, +) -> PrimExpr: """Return a ``uint32`` bitmask of lanes in ``mask`` whose ``value`` equals the calling lane's value. Lowers to ``__match_any_sync`` on CUDA (compute capability >= 7.0). Not supported on HIP. @@ -1049,7 +1075,10 @@ def match_any_sync(mask: int | PrimExpr, value: int | PrimExpr) -> PrimExpr: return tir.call_intrin("uint32", tir.op.Op.get("tl.match_any_sync"), _as_uint32_mask(mask), value) -def match_all_sync(mask: int | PrimExpr, value: int | PrimExpr) -> PrimExpr: +def match_all_sync( + value: int | PrimExpr, + mask: int | PrimExpr = _FULL_WARP_MASK, +) -> PrimExpr: """Return ``mask`` if all lanes in ``mask`` agree on ``value``, else 0. Lowers to ``__match_all_sync`` on CUDA (compute capability >= 7.0); the From a8bafa619970f819e15ede6fbc33c75bcc302a16 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 14 Apr 2026 13:35:03 +0800 Subject: [PATCH 048/156] fix: suppress false positive conflict write warning when dst index depends on thread var (#2041) fix: suppress false positive conflict write warning when dst index depends on thread var (#2035) --- src/op/copy.cc | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index c58f1296bd..2caa948b48 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -18,6 +18,7 @@ #include "utils.h" #include "builtin.h" +#include #include #include #include @@ -1099,9 +1100,24 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, if (is_cpu_target || IsLocalBuffer(src) || IsLocalBuffer(dst)) { if (IsLocalBuffer(src) && !IsLocalBuffer(dst)) { - LOG(WARNING) << "Copy from local buffer `" << src->name << "` to " - << dst.scope() << " buffer `" << dst->name - << "` may cause conflicted write."; + // A conflict write only occurs when multiple threads write to the same + // global address. If any dst_range dimension's min depends on the thread + // variable, each thread targets a distinct location and there is no + // conflict. + bool dst_depends_on_thread = false; + for (const auto &range : dst_range) { + if (tir::UsesVar(range->min, [&](const VarNode *v) { + return v == T.thread_var.get(); + })) { + dst_depends_on_thread = true; + break; + } + } + if (!dst_depends_on_thread) { + LOG(WARNING) << "Copy from local buffer `" << src->name << "` to " + << dst.scope() << " buffer `" << dst->name + << "` may cause conflicted write."; + } } vectorized_thread_loop = VectorizeLoop(fused_loop, T.layout_map); return vectorized_thread_loop; From 74fc9806fbf1675563e2838c4541da4ba5921b23 Mon Sep 17 00:00:00 2001 From: LJC00118 <77378439+LJC00118@users.noreply.github.com> Date: Tue, 14 Apr 2026 15:24:33 +0800 Subject: [PATCH 049/156] [Refactor] Refactor `DecoupleTypeCast` Pass (#2026) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * draft * fix lint * fix lint * Improve pass comments * treat loop-invariant memory access as scalar in vectorization * Add unit test * fix lint * fix b[i] = a[i] + a[i+32] case * Minor fix * Tighten DecoupleTypeCast cast detection and add RMW regression test Skip BufferLoad/BufferStore indices when searching for Cast nodes so an index-type conversion does not spuriously trigger the decoupling transformation. Clarify the load-replacement table in visit_for_ — store entries must feed into it so RMW loads map to the store-side cast buffer — and cover the a[i] = a[i] + a[i+32] case with a regression test. Co-Authored-By: Claude Opus 4.6 (1M context) * Keep loop-invariant stores in memory bucket during vectorize planning Commit 03bb0706 diverted all loop-invariant global/shared accesses into the local_fragment bucket, but that bucket's constraint is dropped by the has_global_or_shared_buffer strategy. ComputeBufferVectorSize already returns 1 for a reduction-like store such as shared[tx] += a[...+j], yet that 1 was silently lost, so vectorization proceeded with vector_size=2 and emitted two scalar writes to the same shared[tx] — the second clobbered the first and dropped a lane of the accumulation. Only loop-invariant loads (genuine broadcast reads) are safe to divert; stores must stay in the memory bucket so their vector_size=1 constraint is honored and the loop is left scalar. Co-Authored-By: Claude Opus 4.6 (1M context) * Remove issue 1106 test Each thread only touches its own shared[tx] slot, so no __syncthreads is actually required and the test asserts an over-conservative behavior. Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: LeiWang1999 Co-authored-by: Claude Opus 4.6 (1M context) --- src/transform/loop_vectorize.cc | 25 +- .../python/issue/test_tilelang_issue_1106.py | 42 -- ...t_tilelang_transform_decouple_type_cast.py | 92 +++ tilelang/transform/decouple_type_cast.py | 587 ++++++++---------- 4 files changed, 387 insertions(+), 359 deletions(-) delete mode 100644 testing/python/issue/test_tilelang_issue_1106.py diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index d612208438..c2655212b8 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -250,8 +250,29 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { local_fragment_buffers.push_back(info); } else { // global, shared, shared.dyn - memory_min = arith::ZeroAwareGCD(memory_min, info.vector_size); - has_global_or_shared_buffer = true; + // If a *load*'s indices don't depend on loop var (e.g. b[0]), treat + // as local — it will become a scalar broadcast, not a vector memory + // access, and DecoupleTypeCast won't create a cast buffer for it. + // Stores must stay in the memory bucket: a loop-invariant store is a + // reduction-like pattern where ComputeBufferVectorSize has already + // returned 1 to disable vectorization, and that constraint must not + // be dropped (memory strategy ignores local_fragment_min). + bool depends_on_loop_var = + !info.indices.empty() && inner_for_ && + std::any_of(info.indices.begin(), info.indices.end(), + [&](const PrimExpr &idx) { + return UsesVar(idx, [&](const VarNode *v) { + return v == inner_for_->loop_var.get(); + }); + }); + if (depends_on_loop_var || info.is_store) { + memory_min = arith::ZeroAwareGCD(memory_min, info.vector_size); + has_global_or_shared_buffer = true; + } else { + local_fragment_min = + arith::ZeroAwareGCD(local_fragment_min, info.vector_size); + local_fragment_buffers.push_back(info); + } } } diff --git a/testing/python/issue/test_tilelang_issue_1106.py b/testing/python/issue/test_tilelang_issue_1106.py deleted file mode 100644 index c5ae33b1aa..0000000000 --- a/testing/python/issue/test_tilelang_issue_1106.py +++ /dev/null @@ -1,42 +0,0 @@ -import tilelang -import tilelang.testing -from tilelang import language as T - - -@tilelang.jit( - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, -) -def get_kernel(m: int): - dtype = "int32" - - @T.prim_func - def test_kernel(a: T.Tensor[(m,), dtype], b: T.Tensor[(m,), dtype]): - with T.Kernel(1, threads=64) as (bx): - shared = T.alloc_shared((64,), dtype) - tx = T.get_thread_binding(0) - tid = tx + bx * 64 - - for i in T.serial((m // 2 - tx) // 64 + 1): - for j in T.vectorized(2): - shared[tx] += a[(i * 64 + tid) * 2 + j] - - b[tid] = shared[tx] - - return test_kernel - - -def test_issue_1106(): - m = 200 - kernel = get_kernel(m) - source = kernel.get_kernel_source() - # Ensure __syncthreads is not inside the for loop - for_start = source.find("for (int i = 0;") - for_end = source.find("__syncthreads") - assert for_end > for_start, "__syncthreads should be after the for loop, not inside it" - # Check that __syncthreads appears after the closing brace of the outer for loop - assert source[for_end - 4 : for_end - 2] == "}\n", "__syncthreads should not be inside any for loop" - - -if __name__ == "__main__": - # tilelang.testing.main() - test_issue_1106() diff --git a/testing/python/transform/test_tilelang_transform_decouple_type_cast.py b/testing/python/transform/test_tilelang_transform_decouple_type_cast.py index 7a4fdaf081..0dcf9bf1de 100644 --- a/testing/python/transform/test_tilelang_transform_decouple_type_cast.py +++ b/testing/python/transform/test_tilelang_transform_decouple_type_cast.py @@ -120,6 +120,58 @@ def after(cond_buf: T.Tensor[(1,), T.int32]): _check(before, after) +def test_rmw_same_buffer_different_indices(): + """RMW with different indices into the same buffer: a[i] = a[i] + a[i+32]. + + Both loads and the store target the same buffer but at different index + expressions. Each unique (buffer, indices) pair should get its own cast + buffer, and the RMW load `a[i]` should read from the same cast buffer the + store writes to (so the read-side copy-from and the write-side copy-to + share that buffer). + """ + + @T.prim_func + def before(a: T.Tensor[(64,), T.float8_e4m3fn]): + for i in T.vectorized(32): + a[i] = T.cast( + T.cast(a[i], T.float32) + T.cast(a[i + 32], T.float32), + T.float8_e4m3fn, + ) + + mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + mod = DecoupleTypeCast()(mod) + + # Sanity checks: pass ran, two distinct cast buffers were created, and the + # RMW load site no longer references `a` directly in the compute body. + text = mod["main"].script() + assert "a_local_cast" in text, "Expected cast buffer for store-side of a[i]" + assert "a_local_cast_1" in text, "Expected second cast buffer for a[i+32]" + + +def test_local_to_memory_with_let_stmt(): + """Test local → memory transform still triggers through LetStmt-bound loads.""" + + @T.prim_func + def before(b: T.Tensor[(16,), T.float8_e4m3fn]): + a_frag = T.alloc_local((16,), T.float32) + scale = T.alloc_local((16,), T.float32) + for i in T.vectorized(16): + factor = scale[i] + b[i] = a_frag[i] * factor + + @T.prim_func + def after(b: T.Tensor[(16,), T.float8_e4m3fn]): + a_frag = T.alloc_local((16,), T.float32) + scale = T.alloc_local((16,), T.float32) + b_local_cast = T.decl_buffer((16,), T.float8_e4m3fn, scope="local") + for i in T.vectorized(16): + b_local_cast[i] = T.cast(a_frag[i] * scale[i], T.float8_e4m3fn) + for i_copy in T.vectorized(16): + b[i_copy] = b_local_cast[i_copy] + + _check(before, after) + + # ============================================================================= # CUDA Codegen Tests # ============================================================================= @@ -375,8 +427,48 @@ def main( ) +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9) +def test_e2e_scalar_load_no_cast_buffer(): + """Test that scalar memory load (b[0]) is not decoupled into a cast buffer. + + When a vectorized loop stores to global with a scalar memory load in the + expression (e.g. c[i] = a_local[i] * b[0]), the scalar load's index does + not depend on the loop variable. It should remain in the compute loop as + a broadcast, not be extracted into a local cast buffer. + + Previously this caused float32x32 codegen errors because both + VectorizePlanner and DecoupleTypeCast treated b[0] as a vector memory + access. + """ + + @tilelang.jit + def kernel_fn(): + @T.prim_func + def main( + a: T.Tensor[(32,), T.float8_e4m3fn], + b: T.Tensor[(1,), T.float32], + c: T.Tensor[(32,), T.float8_e4m3fn], + ): + with T.Kernel(1, threads=32): + a_local = T.alloc_local((32,), T.float8_e4m3fn) + T.copy(a, a_local) + + for i in T.vectorized(32): + c[i] = a_local[i] * b[0] + + return main + + kernel = kernel_fn() + source = kernel.get_kernel_source() + + assert "c_local_cast" in source, "Expected c_local_cast for store-side decoupling" + assert "b_local_cast" not in source, "Scalar load b[0] should not get a cast buffer" + + if __name__ == "__main__": test_no_transform_if_then_else_condition() + test_e2e_scalar_load_no_cast_buffer() test_e2e_bf16_global_to_frag() test_e2e_bf16_global_shared_frag() test_e2e_fp8_global_to_frag() diff --git a/tilelang/transform/decouple_type_cast.py b/tilelang/transform/decouple_type_cast.py index f69ab87e9b..eafd8b36f5 100644 --- a/tilelang/transform/decouple_type_cast.py +++ b/tilelang/transform/decouple_type_cast.py @@ -9,6 +9,8 @@ intermediate stage, allowing optimal vectorization for both computation and memory access. +Mixed-precision is detected by the presence of Cast nodes in the loop body. + Two cases are handled: Case 1: local → memory (store to memory with mixed types) @@ -38,6 +40,7 @@ from __future__ import annotations +from tvm import ir as tvm_ir from tvm import tir from tvm.ir import Op from tvm.tir import ( @@ -46,11 +49,13 @@ BufferLoad, BufferStore, Call, + Cast, DeclBuffer, For, ForKind, IfThenElse, IntImm, + LetStmt, PrimFunc, PyStmtExprVisitor, SeqStmt, @@ -80,142 +85,148 @@ def is_global_or_shared_buffer(buffer: Buffer) -> bool: return is_global(buffer) or is_shared(buffer) -def validate_buffer_scope(buffer: Buffer) -> None: - """Validate that buffer has a known scope. - - Raises: - ValueError: If buffer scope is unknown or empty. - """ - if buffer is None: - return - if not is_local_buffer(buffer) and not is_global_or_shared_buffer(buffer): - raise ValueError( - f"Unknown buffer scope '{buffer.scope()}' for buffer '{buffer.name}'. " - f"Expected one of: local, local.fragment, local.var, global, shared, shared.dyn" - ) +# --------------------------------------------------------------------------- +# Mixed-precision detection: check for Cast nodes in the statement tree +# --------------------------------------------------------------------------- @tir.functor.visitor -class MixedTypeChecker(PyStmtExprVisitor): - """Check if expression contains BufferLoads with different dtypes, skipping indices.""" +class _CastFinder(PyStmtExprVisitor): + """Find Cast nodes in a statement, skipping BufferLoad/BufferStore indices. - def __init__(self, target_dtype: str): + A Cast that only appears inside an index expression is not a mixed-precision + compute — it's just an index-type conversion — so it should not trigger the + decoupling transformation. + """ + + def __init__(self): super().__init__() - self.target_dtype = str(target_dtype) - self.found_different = False + self.found = False - def visit_buffer_load_(self, op: BufferLoad) -> None: - if str(op.buffer.dtype) != self.target_dtype: - self.found_different = True - # Skip indices traversal + def visit_cast_(self, op: Cast) -> None: + self.found = True + self.visit_expr(op.value) + def visit_buffer_store_(self, op: BufferStore) -> None: + self.visit_expr(op.value) -def has_mixed_types(expr: tir.PrimExpr, target_dtype: str) -> bool: - """Check if expression contains BufferLoads with different dtypes than target. + def visit_buffer_load_(self, op: BufferLoad) -> None: + pass - If any BufferLoad in the expression has a different dtype than the target - (store buffer's dtype), vectorization may be constrained by GCD of all dtypes. - """ - checker = MixedTypeChecker(target_dtype) - checker.visit_expr(expr) - return checker.found_different +def _has_cast(stmt: Stmt) -> bool: + """Check if a statement tree contains any Cast node outside of indices.""" + finder = _CastFinder() + finder.visit_stmt(stmt) + return finder.found -@tir.functor.visitor -class GlobalSharedBufferLoadCollector(PyStmtExprVisitor): - """Collect BufferLoads from global/shared buffers, skipping if_then_else conditions. - The condition part of if_then_else doesn't participate in type casting, - so we skip collecting BufferLoads from there. - """ +def _contains_seq_stmt(stmt: Stmt) -> bool: + """Check if statement contains SeqStmt (multiple statements). - def __init__(self, skip_if_then_else_cond: bool = False): - super().__init__() - self.result: list[BufferLoad] = [] - self.skip_if_then_else_cond = skip_if_then_else_cond + When the For body has SeqStmt, the transformation is more complex + and we skip the optimization for now. + """ + found = False - def visit_buffer_load_(self, op: BufferLoad) -> None: - if is_global_or_shared_buffer(op.buffer): - self.result.append(op) + def visitor(node) -> None: + nonlocal found + if isinstance(node, SeqStmt): + found = True - def visit_call_(self, op: Call) -> None: - if self.skip_if_then_else_cond and op.op.same_as(_IF_THEN_ELSE_OP): - # Skip condition (args[0]), only visit true/false values (args[1], args[2]) - self.visit_expr(op.args[1]) - self.visit_expr(op.args[2]) - else: - # Visit all arguments normally - for arg in op.args: - self.visit_expr(arg) + post_order_visit(stmt, visitor) + return found -def get_global_or_shared_buffer_loads(expr: tir.PrimExpr, skip_if_then_else_cond: bool = False) -> list[BufferLoad]: - """Get BufferLoads from global/shared buffers in the expression. +def _expr_depends_on_var(expr: tir.PrimExpr, var: Var) -> bool: + """Check if an expression references the given Var.""" + found = False - Args: - expr: The expression to search. - skip_if_then_else_cond: If True, skip BufferLoads in if_then_else conditions, - since they don't participate in type casting. - """ - collector = GlobalSharedBufferLoadCollector(skip_if_then_else_cond) - collector.visit_expr(expr) - return collector.result + def visitor(node) -> None: + nonlocal found + if isinstance(node, Var) and node.same_as(var): + found = True + post_order_visit(expr, visitor) + return found -def has_global_or_shared_load_with_different_dtype(expr: tir.PrimExpr, target_dtype: str) -> bool: - """Check if expression has global/shared BufferLoad with different dtype than target. - Used to detect memory→local cases where we need to insert cast buffer. - Skips if_then_else condition since it doesn't participate in type casting. - """ - target_dtype = str(target_dtype) - return any(str(load.buffer.dtype) != target_dtype for load in get_global_or_shared_buffer_loads(expr, skip_if_then_else_cond=True)) +# --------------------------------------------------------------------------- +# Collection: gather all shared/global BufferStores and BufferLoads +# --------------------------------------------------------------------------- @tir.functor.visitor -class StoreCollector(PyStmtExprVisitor): - """Collect BufferStore nodes that need transformation, skipping indices traversal. +class MemoryAccessCollector(PyStmtExprVisitor): + """Collect shared/global BufferStore and BufferLoad nodes. - This avoids visiting BufferLoad/BufferStore nodes inside indices, which don't - participate in the type casting transformation. + Skips indices traversal so that index expressions (which may contain + BufferLoads to index buffers) do not pollute the result. + + BufferLoads in if_then_else conditions are skipped because conditions + don't participate in the type-cast compute path. + + BufferLoads whose indices do not depend on ``loop_var`` are skipped + because they are scalar accesses (e.g. ``b[0]``) that should remain + in the compute loop as broadcasts. """ - def __init__(self): + def __init__(self, loop_var: Var): super().__init__() - self.local_to_memory: list[BufferStore] = [] - self.memory_to_local: list[BufferStore] = [] + self.loop_var = loop_var + self.stores: list[BufferStore] = [] + self.loads: list[BufferLoad] = [] def visit_buffer_store_(self, op: BufferStore) -> None: - validate_buffer_scope(op.buffer) - # Case 1: store to memory with mixed types - if is_global_or_shared_buffer(op.buffer) and has_mixed_types(op.value, op.buffer.dtype): - self.local_to_memory.append(op) - # Case 2: store to local with memory load of different dtype - elif is_local_buffer(op.buffer) and has_global_or_shared_load_with_different_dtype(op.value, op.buffer.dtype): - self.memory_to_local.append(op) - # Only visit value, skip indices + if is_global_or_shared_buffer(op.buffer): + self.stores.append(op) + # Visit value but skip indices self.visit_expr(op.value) def visit_buffer_load_(self, op: BufferLoad) -> None: - # Skip indices traversal for BufferLoad as well - pass + # Skip loads whose indices do not depend on loop_var (scalar access). + # Collect ALL qualifying loads (even from the same buffer with different + # indices, e.g. a[i] and a[i+32]) so each gets its own cast buffer. + if is_global_or_shared_buffer(op.buffer) and any(_expr_depends_on_var(idx, self.loop_var) for idx in op.indices): + self.loads.append(op) + # Skip indices traversal + def visit_call_(self, op: Call) -> None: + if op.op.same_as(_IF_THEN_ELSE_OP): + # Skip condition (args[0]), only visit true/false values + self.visit_expr(op.args[1]) + self.visit_expr(op.args[2]) + else: + for arg in op.args: + self.visit_expr(arg) -def contains_seq_stmt(stmt: Stmt) -> bool: - """Check if statement contains SeqStmt (multiple statements). - When the For body has SeqStmt, the transformation is more complex - and we skip the optimization for now. - """ - found = False +# --------------------------------------------------------------------------- +# Utilities +# --------------------------------------------------------------------------- - def visitor(node) -> None: - nonlocal found - if isinstance(node, SeqStmt): - found = True - post_order_visit(stmt, visitor) - return found +def inline_let_stmts(stmt: Stmt) -> Stmt: + """Inline all LetStmt bindings in *stmt* so that downstream visitors can + see the original BufferLoad nodes that were hidden behind Var references. + + Used before collecting memory accesses so that BufferLoads inside LetStmt + values are visible to ``MemoryAccessCollector``. + """ + if isinstance(stmt, LetStmt): + body = inline_let_stmts(stmt.body) + return substitute(body, {stmt.var: stmt.value}) + elif isinstance(stmt, IfThenElse): + then_case = inline_let_stmts(stmt.then_case) + else_case = inline_let_stmts(stmt.else_case) if stmt.else_case else None + if then_case is not stmt.then_case or else_case is not stmt.else_case: + return IfThenElse(stmt.condition, then_case, else_case) + return stmt + elif isinstance(stmt, SeqStmt): + new_seq = [inline_let_stmts(s) for s in stmt.seq] + return SeqStmt(new_seq) + else: + return stmt def extract_if_condition(stmt: Stmt) -> tuple[tir.PrimExpr | None, Stmt]: @@ -229,17 +240,49 @@ def extract_if_condition(stmt: Stmt) -> tuple[tir.PrimExpr | None, Stmt]: return None, stmt -# Type alias for cast buffer mapping -# Maps original buffer -> (cast buffer, original indices) -CastBufferMap = dict[Buffer, tuple[Buffer, list[tir.PrimExpr]]] +# Cast entry: (original buffer, original indices, cast buffer) +# Each unique (buffer, indices) pair gets its own entry, so that accesses +# like a[i] and a[i+32] from the same buffer are handled correctly. +CastEntry = tuple[Buffer, list[tir.PrimExpr], Buffer] + + +def _buf_indices_match( + buf_a: Buffer, + indices_a: list[tir.PrimExpr], + buf_b: Buffer, + indices_b: list[tir.PrimExpr], +) -> bool: + """Check if two (buffer, indices) pairs refer to the same access pattern.""" + if not buf_a.same_as(buf_b): + return False + if len(indices_a) != len(indices_b): + return False + return all(tvm_ir.structural_equal(a, b) for a, b in zip(indices_a, indices_b)) + + +def _find_cast_entry( + entries: list[CastEntry], + buffer: Buffer, + indices: list[tir.PrimExpr], +) -> Buffer | None: + """Find the cast buffer for a given (buffer, indices) pair, or None.""" + for orig_buf, orig_indices, cast_buf in entries: + if _buf_indices_match(orig_buf, orig_indices, buffer, indices): + return cast_buf + return None + + +# --------------------------------------------------------------------------- +# Mutator +# --------------------------------------------------------------------------- @tir.functor.mutator class DecoupleTypeCastMutator(tir.PyStmtExprMutator): """Mutator that decouples type cast vectorization constraints. - This mutator transforms vectorized loops that store to memory buffers - (global/shared) with mixed-precision expressions by inserting local + This mutator transforms vectorized loops that have mixed-precision + operations (detected by the presence of Cast nodes) by inserting local cache buffers as intermediate stages. """ @@ -268,169 +311,117 @@ def _make_for(self, original: For, new_body: Stmt) -> For: original.step, ) + # ----- entry point for each For loop ----- + def visit_for_(self, op: For) -> Stmt: """Visit For nodes, transforming vectorized loops with mixed-type stores.""" # Recursively visit body to handle nested loops new_body = self.visit_stmt(op.body) - # Only transform vectorized loops + # Only transform vectorized loops with static extent if op.kind != ForKind.VECTORIZED: return self._make_for(op, new_body) if new_body is not op.body else op - - # Skip transformation for complex cases with multiple statements - # Currently we only handle simple single BufferStore cases - if contains_seq_stmt(new_body): + if not isinstance(op.extent, IntImm): return self._make_for(op, new_body) if new_body is not op.body else op - # Collect stores that need transformation - local_to_memory, memory_to_local = self._collect_stores_to_transform(new_body) - if local_to_memory: - return self._transform_local_to_memory(op, local_to_memory) - elif memory_to_local: - return self._transform_memory_to_local(op, memory_to_local) - else: + # Check if the body has any Cast nodes + if not _has_cast(new_body): return self._make_for(op, new_body) if new_body is not op.body else op - def _collect_stores_to_transform(self, stmt: Stmt) -> tuple[list[BufferStore], list[BufferStore]]: - """Collect BufferStore nodes that need local cast buffer insertion. - - Returns two lists: - 1. local_to_memory: stores to memory buffer with mixed-type values - (compute → cast buffer → copy to memory) - 2. memory_to_local: stores to local buffer with memory buffer loads of different dtype - (copy from memory → cast buffer → compute) - - Note: Vectorized for is always the innermost loop, so no nested For handling needed. - """ - collector = StoreCollector() - collector.visit_stmt(stmt) - return collector.local_to_memory, collector.memory_to_local - - def _transform_local_to_memory(self, op: For, stores_to_transform: list[BufferStore]) -> Stmt: - """Transform local→memory: compute to cast buffer, then copy to memory. + # Skip SeqStmt (multiple statements) — not supported yet + if _contains_seq_stmt(new_body): + return self._make_for(op, new_body) if new_body is not op.body else op - Before: - b[i] = cast(a_frag[i], fp4) + # Inline LetStmts for analysis so BufferLoads behind Vars are visible + inlined_body = inline_let_stmts(new_body) - After: - cast_buf[i] = cast(a_frag[i], fp4) # compute to cast buffer - b[i] = cast_buf[i] # copy to memory - """ - # Skip dynamic extents - if not isinstance(op.extent, IntImm): - return op + # Collect all shared/global stores and loads + collector = MemoryAccessCollector(op.loop_var) + collector.visit_stmt(inlined_body) - # Extract condition if the body is wrapped in IfThenElse - condition, _ = extract_if_condition(op.body) + if not collector.stores and not collector.loads: + # Cast exists but no memory access → nothing to decouple + return self._make_for(op, new_body) if new_body is not op.body else op - # Create cast buffers for each unique target buffer (memory buffer) - cast_buffers = self._create_cast_buffers_for_stores(stores_to_transform, op.extent.value) + extent = op.extent.value + + # Extract condition (from inlined body for correctness) + condition, _ = extract_if_condition(inlined_body) + + # Create cast entries for stores and loads + store_entries = self._create_cast_entries(collector.stores, extent) + # For loads, skip those already covered by a store entry (read-modify-write) + # by matching (buffer, indices). Loads with different indices from the same + # buffer still get their own cast buffer. + uncovered_loads = [ld for ld in collector.loads if _find_cast_entry(store_entries, ld.buffer, list(ld.indices)) is None] + load_entries = self._create_cast_entries(uncovered_loads, extent) + + # Build copy-from-memory loops (before compute) + # For read-modify-write, reuse the store-side cast buffer for copy-from. + rmw_entries = [ + entry + for entry in store_entries + if any(_buf_indices_match(entry[0], entry[1], ld.buffer, list(ld.indices)) for ld in collector.loads) + ] + copy_from_loops = self._create_copy_loops( + op, + load_entries + rmw_entries, + direction="from_memory", + condition=condition, + ) - # Build compute loop (stores to local cast buffer) - compute_body = self._replace_stores_with_cast(op.body, cast_buffers, op.loop_var) + # Build compute loop: replace stores and loads in the *inlined* body + # so that indices match what the collector saw (LetStmt vars are expanded). + # For RMW (a load whose (buffer, indices) matches a store entry), the load + # must be rewritten to the *same* cast buffer the store writes to, so we + # feed both store and load entries into the load-replacement table. + load_replacement_entries = store_entries + load_entries + compute_body = inlined_body + if store_entries or load_entries: + compute_body = self._replace_access(compute_body, store_entries, load_replacement_entries, op.loop_var) compute_loop = self._make_vectorized_loop(op, compute_body) - # Build copy loops (transfer from cast buffer to memory, with condition if present) - copy_loops = self._create_copy_loops_to_memory(op, cast_buffers, condition) + # Build copy-to-memory loops (after compute) + copy_to_loops = self._create_copy_loops( + op, + store_entries, + direction="to_memory", + condition=condition, + ) - # Combine: compute → copy - all_stmts = [compute_loop] + copy_loops + # Combine: copy-from → compute → copy-to + all_stmts = copy_from_loops + [compute_loop] + copy_to_loops result: Stmt = SeqStmt(all_stmts) if len(all_stmts) > 1 else all_stmts[0] # Wrap with buffer declarations and allocations - result = self._wrap_with_allocations(result, cast_buffers) + result = self._wrap_with_allocations(result, store_entries + load_entries) return result - def _transform_memory_to_local(self, op: For, stores_to_transform: list[BufferStore]) -> Stmt: - """Transform memory→local: copy from memory to cast buffer, then compute. + # ----- helpers ----- - Before: - a_frag[i] = cast(b[i], fp32) + def _create_cast_entries(self, accesses: list[BufferStore | BufferLoad], extent: int) -> list[CastEntry]: + """Create local cast buffers for memory accesses. - After: - cast_buf[i] = b[i] # copy from memory to cast buffer - a_frag[i] = cast(cast_buf[i], fp32) # compute from cast buffer + Each unique (buffer, indices) pair gets its own cast buffer. """ - # Skip dynamic extents - if not isinstance(op.extent, IntImm): - return op - - # Extract condition if the body is wrapped in IfThenElse - condition, _ = extract_if_condition(op.body) - - # Collect memory buffer loads that need cast buffering - memory_loads = self._collect_memory_loads_to_cast(stores_to_transform) - if not memory_loads: - return op - - # Create cast buffers for each unique source buffer (memory buffer) - cast_buffers = self._create_cast_buffers_for_loads(memory_loads, op.extent.value) - - # Build copy loops (transfer from memory to cast buffer, with condition if present) - copy_loops = self._create_copy_loops_from_memory(op, cast_buffers, condition) - - # Build compute loop (replace memory loads with cast buffer loads) - compute_body = self._replace_loads_with_cast(op.body, cast_buffers, op.loop_var) - compute_loop = self._make_vectorized_loop(op, compute_body) + entries: list[CastEntry] = [] - # Combine: copy → compute - all_stmts = copy_loops + [compute_loop] - result: Stmt = SeqStmt(all_stmts) if len(all_stmts) > 1 else all_stmts[0] - - # Wrap with buffer declarations and allocations - result = self._wrap_with_allocations(result, cast_buffers) - - return result - - def _collect_memory_loads_to_cast(self, stores: list[BufferStore]) -> list[BufferLoad]: - """Collect memory BufferLoads from store values that need cast buffering.""" - result: list[BufferLoad] = [] - seen_buffers = set() - for store in stores: - for load in get_global_or_shared_buffer_loads(store.value, skip_if_then_else_cond=True): - if load.buffer not in seen_buffers: - result.append(load) - seen_buffers.add(load.buffer) - return result - - def _create_cast_buffers_for_stores(self, stores: list[BufferStore], extent: int) -> CastBufferMap: - """Create local cast buffers for store targets (memory buffers).""" - cast_buffers: CastBufferMap = {} - - for store in stores: - if store.buffer in cast_buffers: - continue - - cache_name = self._make_unique_name(f"{store.buffer.name}_local_cast") - cast_buffer = tir.decl_buffer( - shape=(extent,), - dtype=store.buffer.dtype, - name=cache_name, - scope="local", - ) - cast_buffers[store.buffer] = (cast_buffer, list(store.indices)) - - return cast_buffers - - def _create_cast_buffers_for_loads(self, loads: list[BufferLoad], extent: int) -> CastBufferMap: - """Create local cast buffers for load sources (memory buffers).""" - cast_buffers: CastBufferMap = {} - - for load in loads: - if load.buffer in cast_buffers: + for access in accesses: + indices = list(access.indices) + if _find_cast_entry(entries, access.buffer, indices) is not None: continue - cache_name = self._make_unique_name(f"{load.buffer.name}_local_cast") + cache_name = self._make_unique_name(f"{access.buffer.name}_local_cast") cast_buffer = tir.decl_buffer( shape=(extent,), - dtype=load.buffer.dtype, + dtype=access.buffer.dtype, name=cache_name, scope="local", ) - cast_buffers[load.buffer] = (cast_buffer, list(load.indices)) + entries.append((access.buffer, indices, cast_buffer)) - return cast_buffers + return entries def _make_vectorized_loop(self, original: For, body: Stmt) -> For: """Create a vectorized For loop based on the original.""" @@ -445,62 +436,41 @@ def _make_vectorized_loop(self, original: For, body: Stmt) -> For: original.step, ) - def _create_copy_loops_to_memory(self, op: For, cast_buffers: CastBufferMap, condition: tir.PrimExpr | None = None) -> list[For]: - """Create copy loops to transfer data from cast buffers to memory buffers.""" - copy_loops: list[For] = [] + def _create_copy_loops( + self, + op: For, + entries: list[CastEntry], + direction: str, + condition: tir.PrimExpr | None = None, + ) -> list[For]: + """Create vectorized copy loops between memory and cast buffers. - for orig_buffer, (cast_buffer, orig_indices) in cast_buffers.items(): - # vectorized loop only has one iteration variable, so we use the same name for the copy variable - copy_var = Var(f"{op.loop_var.name}_copy", op.loop_var.dtype) - - # Substitute loop_var with copy_var in original indices - new_indices = [substitute(idx, {op.loop_var: copy_var}) for idx in orig_indices] - - # cast buffer → memory - copy_store: Stmt = BufferStore( - orig_buffer, - BufferLoad(cast_buffer, [copy_var]), - new_indices, - ) - - # Wrap with condition if present (substitute loop_var with copy_var) - if condition is not None: - new_condition = substitute(condition, {op.loop_var: copy_var}) - copy_store = IfThenElse(new_condition, copy_store, None) - - copy_loop = For( - copy_var, - op.min, - op.extent, - ForKind.VECTORIZED, - copy_store, - op.thread_binding, - op.annotations, - op.step, - ) - copy_loops.append(copy_loop) - - return copy_loops - - def _create_copy_loops_from_memory(self, op: For, cast_buffers: CastBufferMap, condition: tir.PrimExpr | None = None) -> list[For]: - """Create copy loops to transfer data from memory buffers to cast buffers.""" + direction: "to_memory" (cast → memory) or "from_memory" (memory → cast). + """ copy_loops: list[For] = [] - for orig_buffer, (cast_buffer, orig_indices) in cast_buffers.items(): - # vectorized loop only has one iteration variable, so we use the same name for the copy variable + for orig_buffer, orig_indices, cast_buffer in entries: + # vectorized loop only has one iteration variable, + # so we use the same name for the copy variable copy_var = Var(f"{op.loop_var.name}_copy", op.loop_var.dtype) # Substitute loop_var with copy_var in original indices new_indices = [substitute(idx, {op.loop_var: copy_var}) for idx in orig_indices] - # memory → cast buffer - copy_store: Stmt = BufferStore( - cast_buffer, - BufferLoad(orig_buffer, new_indices), - [copy_var], - ) - - # Wrap with condition if present (substitute loop_var with copy_var) + if direction == "to_memory": + copy_store: Stmt = BufferStore( + orig_buffer, + BufferLoad(cast_buffer, [copy_var]), + new_indices, + ) + else: + copy_store = BufferStore( + cast_buffer, + BufferLoad(orig_buffer, new_indices), + [copy_var], + ) + + # Wrap with condition if present if condition is not None: new_condition = substitute(condition, {op.loop_var: copy_var}) copy_store = IfThenElse(new_condition, copy_store, None) @@ -519,10 +489,10 @@ def _create_copy_loops_from_memory(self, op: For, cast_buffers: CastBufferMap, c return copy_loops - def _wrap_with_allocations(self, body: Stmt, cast_buffers: CastBufferMap) -> Stmt: + def _wrap_with_allocations(self, body: Stmt, entries: list[CastEntry]) -> Stmt: """Wrap statement with buffer declarations and allocations.""" result = body - for cast_buffer, _ in cast_buffers.values(): + for _, _, cast_buffer in entries: result = DeclBuffer(cast_buffer, result) result = Allocate( cast_buffer.data, @@ -533,51 +503,39 @@ def _wrap_with_allocations(self, body: Stmt, cast_buffers: CastBufferMap) -> Stm ) return result - def _replace_stores_with_cast(self, stmt: Stmt, cast_buffers: CastBufferMap, loop_var: Var) -> Stmt: - """Replace stores to memory buffers with stores to cast buffers.""" - store_replacer = StoreReplacer(cast_buffers, loop_var) - return store_replacer.visit_stmt(stmt) - - def _replace_loads_with_cast(self, stmt: Stmt, cast_buffers: CastBufferMap, loop_var: Var) -> Stmt: - """Replace loads from memory buffers with loads from cast buffers. - - This method recursively processes the statement tree, replacing - BufferLoad nodes from cast buffers with loads from the cast buffer. - """ - # Create an expression mutator to replace BufferLoads - load_replacer = LoadReplacer(cast_buffers, loop_var) - return load_replacer.visit_stmt(stmt) + def _replace_access(self, stmt: Stmt, store_entries: list[CastEntry], load_entries: list[CastEntry], loop_var: Var) -> Stmt: + """Replace memory accesses with cast buffer accesses.""" + replacer = AccessReplacer(store_entries, load_entries, loop_var) + return replacer.visit_stmt(stmt) @tir.functor.mutator -class StoreReplacer(tir.PyStmtExprMutator): - """Mutator to replace memory BufferStores with cast buffer BufferStores.""" +class AccessReplacer(tir.PyStmtExprMutator): + """Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses. + + Matches by both buffer and indices (structural equality) so that accesses + like a[i] and a[i+32] from the same buffer map to different cast buffers. + """ - def __init__(self, cast_buffers: CastBufferMap, loop_var: Var): + def __init__(self, store_entries: list[CastEntry], load_entries: list[CastEntry], loop_var: Var): super().__init__() - self.cast_buffers = cast_buffers + self.store_entries = store_entries + self.load_entries = load_entries self.loop_var = loop_var def visit_buffer_store_(self, op: BufferStore) -> Stmt: - if op.buffer in self.cast_buffers: - cast_buffer, _ = self.cast_buffers[op.buffer] - return BufferStore(cast_buffer, op.value, [self.loop_var]) + new_value = self.visit_expr(op.value) + cast_buf = _find_cast_entry(self.store_entries, op.buffer, list(op.indices)) + if cast_buf is not None: + return BufferStore(cast_buf, new_value, [self.loop_var]) + if new_value is not op.value: + return BufferStore(op.buffer, new_value, list(op.indices)) return op - -@tir.functor.mutator -class LoadReplacer(tir.PyStmtExprMutator): - """Mutator to replace memory BufferLoads with cast buffer BufferLoads.""" - - def __init__(self, cast_buffers: CastBufferMap, loop_var: Var): - super().__init__() - self.cast_buffers = cast_buffers - self.loop_var = loop_var - def visit_buffer_load_(self, op: BufferLoad) -> tir.PrimExpr: - if op.buffer in self.cast_buffers: - cast_buffer, _ = self.cast_buffers[op.buffer] - return BufferLoad(cast_buffer, [self.loop_var]) + cast_buf = _find_cast_entry(self.load_entries, op.buffer, list(op.indices)) + if cast_buf is not None: + return BufferLoad(cast_buf, [self.loop_var]) return op @@ -585,8 +543,7 @@ def DecoupleTypeCast(): """Create a TVM pass that decouples type cast vectorization constraints. This pass inserts a local buffer as an intermediate stage for vectorized - stores to non-local buffers (global/shared) where the store value contains - expressions with different dtypes. + loops where the body contains Cast nodes (mixed-precision operations). This allows optimal vectorization for both computation and memory access. From f309d8147e15c74e643012652b518a1926d830bc Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 14 Apr 2026 15:54:28 +0800 Subject: [PATCH 050/156] [Bugfix][Subtype] Fix scalar fp4 store/load codegen for non-packed buffers (#2037) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Bugfix][Subtype] Fix scalar fp4 load/store codegen and dynamic stride inference for sub-byte types The scalar fp4 buffer access path in codegen_cuda.cc used GetBufferRef which applies a /2 index division (mapping two fp4 elements to one byte), then performed a plain byte assignment — destroying the neighboring nibble. This caused incorrect results whenever fp4 stores couldn't be vectorized (e.g. with dynamic/non-constant strides in StridedTensor). Fix: use tl_fp4_packed_load/tl_fp4_packed_store with (fp4_e2_2_t*) cast for scalar fp4 accesses to non-packed buffers, which correctly handles nibble-level read-modify-write via set_x/set_y. Also add stride_scale factor to _process_dynamic_symbolic in the cython and tvm_ffi adapters, so torch strides (in storage units) are correctly converted to logical element strides for sub-byte dtypes. Co-Authored-By: Claude Opus 4.6 (1M context) * [Test][Subtype] Add correctness tests for fp4 dynamic-stride store codegen Add regression test that scatters fp4 rows into a non-contiguous StridedTensor with dynamic strides and verifies byte-for-byte match against a static-stride reference. Parametrized over block_size. Co-Authored-By: Claude Opus 4.6 (1M context) * [Enhancement][Subtype] Inject stride divisibility assume for sub-byte dtypes and add scalar fp4 store codegen test - Extend InjectAssumes pass to emit `assume(stride % pack_factor == 0)` for non-last-dimension strides of sub-byte buffers (e.g. fp4 with pack_factor=2). This provides the TIR analyzer with divisibility information before LetStmt inlining. - Add `test_subtype_fp4_scalar_store_codegen`: scatter fp4 elements via indirection to force the scalar (non-vectorized) store codegen path and verify nibble-level correctness. - Convert new tests to eager-style @tilelang.jit. Co-Authored-By: Claude Opus 4.6 (1M context) * style: apply clang-format and ruff formatting Co-Authored-By: Claude Opus 4.6 (1M context) * fix lint error * fix: unpack 4-tuple in cython_wrapper dynamic_symbolic_map lookup --------- Co-authored-by: Claude Opus 4.6 (1M context) --- src/target/codegen_cuda.cc | 38 ++++++-- src/transform/inject_assumes.cc | 61 ++++++++++++ .../test_tilelang_language_subtype.py | 95 +++++++++++++++++++ tilelang/jit/adapter/cython/adapter.py | 14 ++- .../jit/adapter/cython/cython_wrapper.pyx | 6 +- tilelang/jit/adapter/tvm_ffi.py | 20 ++-- 6 files changed, 211 insertions(+), 23 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 7ca86911f3..77eb42c8c4 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -3928,8 +3928,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, int lanes = op->dtype.lanes(); // declare type. if (value_dtype.lanes() == element_dtype.lanes()) { - std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index); - HandleVolatileLoads(ref, op, os); + // For scalar fp4 loads from non-packed buffers, use tl_fp4_packed_load + // to correctly extract the nibble at the given index (the /2 in + // GetBufferRef maps two consecutive fp4 elements to the same byte, but + // reading that byte only returns the low nibble — the odd-indexed element + // is lost). + if (element_dtype.is_float4() && element_dtype.lanes() == 1) { + std::string idx_str = PrintExpr(index); + std::string vid = GetVarID(buffer_var.get()); + os << "tl_fp4_packed_load((fp4_e2_2_t*)" << vid << ", " << idx_str << ")"; + } else { + std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index); + HandleVolatileLoads(ref, op, os); + } } else { bool can_vector_load = false; arith::PVar base; @@ -4001,11 +4012,24 @@ void CodeGenTileLangCUDA::VisitStmt_(const BufferStoreNode *op) { } if (value_dtype.lanes() == element_dtype.lanes()) { - std::string value = this->PrintExpr(op->value); - std::string ref = - this->GetBufferRef(value_dtype, op->buffer.get(), index_expr); - this->PrintIndent(); - stream << ref << " = " << value << ";\n"; + // For scalar fp4 stores to non-packed buffers, use tl_fp4_packed_store + // to correctly handle nibble-level writes. The /2 in GetBufferRef maps two + // consecutive fp4 elements to the same byte, and a plain assignment + // overwrites the entire byte — destroying the neighboring nibble. + if (element_dtype.is_float4() && element_dtype.lanes() == 1) { + std::string idx_str = PrintExpr(index_expr); + std::string value = this->PrintExpr(op->value); + std::string vid = GetVarID(buffer_var.get()); + this->PrintIndent(); + stream << "tl_fp4_packed_store((fp4_e2_2_t*)" << vid << ", " << idx_str + << ", " << value << ");\n"; + } else { + std::string value = this->PrintExpr(op->value); + std::string ref = + this->GetBufferRef(value_dtype, op->buffer.get(), index_expr); + this->PrintIndent(); + stream << ref << " = " << value << ";\n"; + } } else { arith::PVar base; int ramp_lanes = value_dtype.lanes() / element_dtype.lanes(); diff --git a/src/transform/inject_assumes.cc b/src/transform/inject_assumes.cc index 40ad4378a5..7c6ab6d349 100644 --- a/src/transform/inject_assumes.cc +++ b/src/transform/inject_assumes.cc @@ -71,6 +71,46 @@ class AssumeInjector : public tvm::tir::StmtExprMutator { } } + // --- Stride divisibility for sub-byte dtypes --- + struct StrideDivisibilityItem { + PrimExpr stride; + int pack_factor; + std::vector buffers; + }; + std::vector stride_div_items; + std::unordered_map> stride_div_buckets; + + void addStrideExpr(PrimExpr stride, int pack_factor, Buffer buffer) { + size_t h = sh(stride); + auto &bucket = stride_div_buckets[h]; + auto it = std::find_if(bucket.begin(), bucket.end(), [&](size_t y) { + return se(stride, stride_div_items[y].stride, true); + }); + if (it == bucket.end()) { + auto index = stride_div_items.size(); + stride_div_items.push_back({stride, pack_factor, {buffer}}); + bucket.push_back(index); + } else { + auto &item = stride_div_items[*it]; + item.buffers.push_back(buffer); + // Use the largest pack_factor (strongest constraint) + item.pack_factor = std::max(item.pack_factor, pack_factor); + } + } + + void addBufferStrides(Buffer buf) { + int element_bits = buf->dtype.bits() * buf->dtype.lanes(); + if (element_bits >= 8 || buf->strides.empty()) + return; + int pack_factor = 8 / element_bits; + for (size_t k = 0; k + 1 < buf->strides.size(); ++k) { + auto stride = buf->strides[k]; + if (stride->IsInstance()) + continue; + addStrideExpr(stride, pack_factor, buf); + } + } + Stmt build(Stmt body) { auto analyzer = arith::Analyzer{}; for (const auto &e : items) { @@ -87,6 +127,23 @@ class AssumeInjector : public tvm::tir::StmtExprMutator { body = AttrStmt(simplified, tir::attr::tilelang_assume, StringImm(ss.str()), body); } + // Inject stride divisibility assumes for sub-byte dtypes. + // E.g. for fp4 (pack_factor=2), non-last-dim strides must be even. + for (const auto &e : stride_div_items) { + auto cond = + EQ(floormod(e.stride, make_const(e.stride.dtype(), e.pack_factor)), + make_zero(e.stride.dtype())); + std::stringstream ss; + ss << "Sub-byte buffer stride must be divisible by " << e.pack_factor + << ": stride `" << e.stride << "` from buffer "; + for (size_t i = 0; i < e.buffers.size(); i++) { + if (i) + ss << ", "; + ss << "`" << e.buffers[i]->name << "`"; + } + body = AttrStmt(cond, tir::attr::tilelang_assume, StringImm(ss.str()), + body); + } return body; } }; @@ -95,6 +152,7 @@ class AssumeInjector : public tvm::tir::StmtExprMutator { auto body = VisitStmt(op->body); AssumeCreator c; c.addBuffer(op->buffer); + c.addBufferStrides(op->buffer); return DeclBuffer(op->buffer, c.build(body), op->span); } @@ -153,13 +211,16 @@ class AssumeInjector : public tvm::tir::StmtExprMutator { if (IsHostMainBlock(op)) { for (auto item : f->buffer_map) { c.addBuffer(item.second); + c.addBufferStrides(item.second); } } for (auto item : op->alloc_buffers) { c.addBuffer(item); + c.addBufferStrides(item); } for (auto item : op->match_buffers) { c.addBuffer(item->buffer); + c.addBufferStrides(item->buffer); } return Block(op->iter_vars, op->reads, op->writes, op->name_hint, diff --git a/testing/python/language/test_tilelang_language_subtype.py b/testing/python/language/test_tilelang_language_subtype.py index e061e1832e..5aba3e96b3 100644 --- a/testing/python/language/test_tilelang_language_subtype.py +++ b/testing/python/language/test_tilelang_language_subtype.py @@ -301,5 +301,100 @@ def test_subtype_complex_expressions_various(m, n): complex_expr_kernel(x, y) +# --------------------------------------------------------------------------- +# Scalar fp4 store to StridedTensor with dynamic strides. +# Before the fix the codegen wrote full bytes instead of nibbles, so +# consecutive fp4 elements sharing a byte would overwrite each other. +# --------------------------------------------------------------------------- + + +@tilelang.testing.requires_cuda +@pytest.mark.parametrize("block_size", [8, 16]) +def test_subtype_fp4_dynamic_stride_store(block_size): + """fp4 store via StridedTensor: dynamic strides must match static strides.""" + num_blocks, n, padding = 10, 64, 4 + fp4_bytes = 64 # 128 fp4 elems packed into 64 bytes + jit_kw = dict(out_idx=None, target="cuda", pass_configs={"tl.disable_data_race_check": True}) + + def make_buf(): + row = fp4_bytes + padding + back = torch.zeros(num_blocks, block_size * row, dtype=torch.uint8, device="cuda") + fp4 = back[:, : block_size * fp4_bytes].view(num_blocks, block_size, fp4_bytes).view(torch.int8) + return back, fp4 + + torch.manual_seed(0) + src = torch.randint(0, 256, (n, 64), dtype=torch.uint8, device="cuda").view(torch.int8) + slots = torch.randperm(num_blocks * block_size, dtype=torch.int32, device="cuda")[:n] + + # static (reference) — strides known at compile time + back_s, fp4_s = make_buf() + s0 = fp4_s.stride(0) * 2 # byte stride → fp4-element stride + s1 = fp4_s.stride(1) * 2 + + @tilelang.jit(**jit_kw) + def static_kern(src, dst, slots): + nv = T.dynamic("n") + nb = T.dynamic("num_blocks") + src: T.Tensor[(nv, 128), T.float4_e2m1fn] + dst: T.StridedTensor[(nb, block_size, 128), (s0, s1, 1), T.float4_e2m1fn] + slots: T.Tensor[(nv,), T.int32] + with T.Kernel(nv, threads=32) as i: + for k in T.serial(128): + dst[slots[i] // block_size, slots[i] % block_size, k] = src[i, k] + + static_kern(src, fp4_s, slots) + + # dynamic — strides resolved at runtime + @tilelang.jit(**jit_kw) + def dynamic_kern(src, dst, slots): + nv = T.dynamic("n") + nb = T.dynamic("num_blocks") + ds0 = T.dynamic("ds0") + ds1 = T.dynamic("ds1") + src: T.Tensor[(nv, 128), T.float4_e2m1fn] + dst: T.StridedTensor[(nb, block_size, 128), (ds0, ds1, 1), T.float4_e2m1fn] + slots: T.Tensor[(nv,), T.int32] + with T.Kernel(nv, threads=32) as i: + for k in T.serial(128): + dst[slots[i] // block_size, slots[i] % block_size, k] = src[i, k] + + back_d, fp4_d = make_buf() + dynamic_kern(src, fp4_d, slots) + + assert torch.equal(back_s, back_d), ( + f"static vs dynamic stride mismatch: {(back_s != back_d).sum().item()}/{back_s.numel()} bytes differ" + ) + + +@tilelang.testing.requires_cuda +@pytest.mark.parametrize("n", [64, 128]) +def test_subtype_fp4_scalar_store_codegen(n): + """Scatter fp4 elements via indirection — forces scalar (non-vectorized) stores.""" + + @tilelang.jit(out_idx=None, pass_configs={"tl.disable_data_race_check": True}) + def scatter_kern(src, dst, perm): + nv = T.dynamic("n") + src: T.Tensor[(nv, 128), T.float4_e2m1fn] + dst: T.Tensor[(nv, 128), T.float4_e2m1fn] + perm: T.Tensor[(nv,), T.int32] + with T.Kernel(nv, threads=32) as i: + for k in T.serial(128): + dst[perm[i], k] = src[i, k] + + torch.manual_seed(42) + src = torch.randint(0, 256, (n, 64), dtype=torch.uint8, device="cuda").view(torch.int8) + dst = torch.zeros(n, 64, dtype=torch.uint8, device="cuda").view(torch.int8) + perm = torch.randperm(n, dtype=torch.int32, device="cuda") + + scatter_kern(src, dst, perm) + + # Invert the permutation and check src[inv[j]] == dst[j] at byte level + inv = torch.empty_like(perm) + inv[perm] = torch.arange(n, dtype=torch.int32, device="cuda") + expected = src.view(torch.uint8)[inv.long()] + actual = dst.view(torch.uint8) + assert torch.equal(expected, actual), f"scatter fp4 mismatch: {(expected != actual).sum().item()}/{actual.numel()} bytes differ" + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index acf33e532d..912eb07b38 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -213,12 +213,14 @@ def from_database( adapter._post_init() return adapter - def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]: + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int, int]]: """Extract information about dynamic shapes from the TIR function. - Maps symbolic variables to their corresponding (id, buffer_index, dimension) + Maps symbolic variables to their corresponding (id, buffer_index, dimension, stride_scale) for runtime shape resolution. - id represents shape or stride, 0 represents shape, 1 represents stride + id represents shape or stride, 0 represents shape, 1 represents stride. + stride_scale compensates for sub-byte dtypes (e.g. float4_e2m1fn) where torch strides + are in storage units but the kernel expects logical element strides. """ func = self.prim_func params = func.params @@ -229,13 +231,15 @@ def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]: buffer = buffer_map[param] for j, shape in enumerate(buffer.shape): if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): - dynamic_symbolic_map[shape] = (0, i, j) + dynamic_symbolic_map[shape] = (0, i, j, 1) for i, param in enumerate(params): if param in buffer_map: buffer = buffer_map[param] + element_bits = buffer.dtype.bits * buffer.dtype.lanes + stride_scale = 8 // element_bits if element_bits < 8 else 1 for j, stride in enumerate(buffer.strides): if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): - dynamic_symbolic_map[stride] = (1, i, j) + dynamic_symbolic_map[stride] = (1, i, j, stride_scale) return dynamic_symbolic_map def _process_buffer_dtype(self) -> dict[tir.Var, tuple[int, torch.dtype]]: diff --git a/tilelang/jit/adapter/cython/cython_wrapper.pyx b/tilelang/jit/adapter/cython/cython_wrapper.pyx index 38c1738f72..b4d51fc916 100644 --- a/tilelang/jit/adapter/cython/cython_wrapper.pyx +++ b/tilelang/jit/adapter/cython/cython_wrapper.pyx @@ -199,7 +199,7 @@ cdef class CythonKernelWrapper: if isinstance(s, tir.Var): for key in self.dynamic_symbolic_map: if(str(s) == str(key)): - ref_id, ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[key] + ref_id, ref_tensor_idx, ref_shape_idx, _stride_scale = self.dynamic_symbolic_map[key] shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx]) else: # Already converted to Python int during initialization shape.append(s) @@ -266,11 +266,11 @@ cdef class CythonKernelWrapper: self._check_static_contiguous(tensor_list) # Add dynamic dimension values to kernel arguments - for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items(): + for _, (ref_id, buffer_idx, shape_idx, stride_scale) in self.dynamic_symbolic_map.items(): if ref_id == 0: call_args.append(ctypes.c_int64(tensor_list[buffer_idx].shape[shape_idx])) else: - call_args.append(ctypes.c_int64(tensor_list[buffer_idx].stride(shape_idx))) + call_args.append(ctypes.c_int64(tensor_list[buffer_idx].stride(shape_idx) * stride_scale)) # Add CUDA stream to kernel arguments call_args.append(ctypes.c_void_p(stream)) diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 6314795444..3aff6bda21 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -110,12 +110,14 @@ def __init__( self._post_init() - def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]: + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int, int]]: """Extract information about dynamic shapes from the TIR function. - Maps symbolic variables to their corresponding (id, buffer_index, dimension) + Maps symbolic variables to their corresponding (id, buffer_index, dimension, stride_scale) for runtime shape resolution. - id represents shape or stride, 0 represents shape, 1 represents stride + id represents shape or stride, 0 represents shape, 1 represents stride, 2 represents scalar param. + stride_scale compensates for sub-byte dtypes (e.g. float4_e2m1fn) where torch strides + are in storage units but the kernel expects logical element strides. """ func = self.prim_func params = func.params @@ -123,19 +125,21 @@ def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]: dynamic_symbolic_map = {} for i, param in enumerate(params): if isinstance(param, tir.Var) and (param not in dynamic_symbolic_map): - dynamic_symbolic_map[param] = (2, i, -1) + dynamic_symbolic_map[param] = (2, i, -1, 1) for i, param in enumerate(params): if param in buffer_map: buffer = buffer_map[param] for j, shape in enumerate(buffer.shape): if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): - dynamic_symbolic_map[shape] = (0, i, j) + dynamic_symbolic_map[shape] = (0, i, j, 1) for i, param in enumerate(params): if param in buffer_map: buffer = buffer_map[param] + element_bits = buffer.dtype.bits * buffer.dtype.lanes + stride_scale = 8 // element_bits if element_bits < 8 else 1 for j, stride in enumerate(buffer.strides): if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): - dynamic_symbolic_map[stride] = (1, i, j) + dynamic_symbolic_map[stride] = (1, i, j, stride_scale) return dynamic_symbolic_map def _convert_torch_func(self) -> Callable[..., Any]: @@ -216,13 +220,13 @@ def func(*inputs: torch.Tensor | Any): if isinstance(s, tir.Var): for key in dynamic_symbolic_map: if str(s) == str(key): - ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key] + ref_id, ref_tensor_idx, ref_shape_idx, stride_scale = dynamic_symbolic_map[key] if ref_id == 2: shape.append(inputs[ref_tensor_idx]) elif ref_id == 0: shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx]) elif ref_id == 1: - shape.append(tensor_list[ref_tensor_idx].stride()[ref_shape_idx]) + shape.append(tensor_list[ref_tensor_idx].stride()[ref_shape_idx] * stride_scale) else: # Already converted to Python int during initialization shape.append(s) From 380fb5e2fcf0cd60ed91df111ffab9bb8d4f59d5 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Wed, 15 Apr 2026 10:48:45 +0800 Subject: [PATCH 051/156] Support local var fragment --- src/transform/auto_schedule.cc | 9 +++++++++ src/transform/auto_schedule/ir_structure.h | 6 +++--- src/transform/auto_schedule/memory_detector.h | 2 ++ src/transform/auto_schedule/warpgroup_partition.cc | 3 +++ tilelang/transform/z3_scheduler.py | 4 +--- 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 8ba5dbf3ba..faf11b338d 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -228,6 +228,15 @@ class IRStructureBuilder : public StmtVisitor { root_ = std::move(task_node); } + void VisitStmt_(const BufferStoreNode *op) override { + auto task_node = std::make_shared(); + task_node->stmts.push_back(GetRef(op)); + + AnalyzeResourceUsage(GetRef(op), task_node.get()); + + root_ = std::move(task_node); + } + void VisitStmt_(const IfThenElseNode *op) override { // If statement -> treat as TaskNode for now (could be refined later) auto task_node = std::make_shared(); diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 6b26872e7a..902a31fa54 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -461,7 +461,7 @@ class ControlNode : public IRStructure { std::shared_ptr Clone() const override; bool containWarpgroupId(int id) const override { - return child->containWarpgroupId(id); + return child && child->containWarpgroupId(id); } private: @@ -560,7 +560,7 @@ class WrapperNode : public IRStructure { std::shared_ptr Clone() const override; bool containWarpgroupId(int id) const override { - return child->containWarpgroupId(id); + return child && child->containWarpgroupId(id); } private: @@ -665,7 +665,7 @@ class ScheduleUnit : public IRStructure { std::shared_ptr Clone() const override; bool containWarpgroupId(int id) const override { - return child->containWarpgroupId(id); + return child && child->containWarpgroupId(id); } private: diff --git a/src/transform/auto_schedule/memory_detector.h b/src/transform/auto_schedule/memory_detector.h index 2f9127f82a..0d7b400681 100644 --- a/src/transform/auto_schedule/memory_detector.h +++ b/src/transform/auto_schedule/memory_detector.h @@ -391,6 +391,7 @@ class MemoryAccessDetector : public StmtExprVisitor { << "First argument of tl.tileop.region should be a BufferLoad"; } } + StmtExprVisitor::VisitExpr_(op); return; } @@ -408,6 +409,7 @@ class MemoryAccessDetector : public StmtExprVisitor { // Process second argument as write region ProcessBufferRegion(op->args[1], false); // is_read = false } + StmtExprVisitor::VisitExpr_(op); return; } diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index d128038887..c638feff69 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -540,6 +540,9 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, } } } + } else if (ctrl->child->IsTask()) { + auto task = static_cast(ctrl->child.get()); + stmts.push_back(ConvertIRStructureToStmt(task, outer_enable_epi)); } else { LOG(FATAL); } diff --git a/tilelang/transform/z3_scheduler.py b/tilelang/transform/z3_scheduler.py index fd066ae1f6..95ba51f07e 100644 --- a/tilelang/transform/z3_scheduler.py +++ b/tilelang/transform/z3_scheduler.py @@ -259,9 +259,7 @@ def z3_schedule_loop_python( # For small number of tasks, return trivial schedule if n <= 1: - if n == 1: - return [0], [0] - return [], [] + raise RuntimeError("Z3 loop scheduling failed: n too small") if verbose: print(f"[Python Z3 Loop] Starting scheduling for {n} tasks") From a16ff869a474202d84d3997b491b8293d0d8b9ab Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Wed, 15 Apr 2026 13:13:30 +0800 Subject: [PATCH 052/156] [Feature] autodd: add __freeze__ annotation to protect code regions from reduction (#2045) --- tilelang/autodd.py | 261 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 254 insertions(+), 7 deletions(-) diff --git a/tilelang/autodd.py b/tilelang/autodd.py index aea682fd20..7122282beb 100644 --- a/tilelang/autodd.py +++ b/tilelang/autodd.py @@ -20,6 +20,50 @@ import traceback +class _FreezeSentinel: + """No-op context manager and identity function used to mark frozen regions for autodd. + + Usage in the target script:: + + from tilelang.autodd import __freeze__ + + # Protect a statement block: + with __freeze__: + critical_call(args) + + # Protect a single expression: + result = __freeze__(critical_expr) + """ + + def __enter__(self): + return self + + def __exit__(self, *a): + pass + + def __call__(self, x=None): + return x + + +__freeze__ = _FreezeSentinel() + + +def _is_freeze_with(node: ast.AST) -> bool: + """Detect ``with __freeze__: body`` (no ``as`` clause).""" + return ( + isinstance(node, ast.With) + and len(node.items) == 1 + and node.items[0].optional_vars is None + and isinstance(node.items[0].context_expr, ast.Name) + and node.items[0].context_expr.id == "__freeze__" + ) + + +def _is_freeze_call(node: ast.AST) -> bool: + """Detect ``__freeze__(expr)``.""" + return isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "__freeze__" + + def ast_replace(node: ast.AST, **changes) -> ast.AST: node = copy(node) for field, value in changes.items(): @@ -374,6 +418,14 @@ def __init__(self, rewrites: list[ASTRewrite]): self.uid_counter = 0 self.rewrite_counter = 0 self.rewrite_names = Counter() + # Freeze propagation state: + # _frozen – True when we are currently inside a frozen subtree + # _stmt_stack – stack of enclosing ast.stmt nodes; used so that a + # __freeze__(expr) child can retroactively freeze its + # ancestor statement (preventing e.g. assign_rhs_1 from + # replacing the whole RHS and destroying the frozen expr). + self._frozen: bool = False + self._stmt_stack: list[ast.AST] = [] @override def visit(self, node: ast.AST, parent: "ast.AST | None", field: "str | None", inside_list: bool): @@ -381,13 +433,46 @@ def visit(self, node: ast.AST, parent: "ast.AST | None", field: "str | None", in node._dd_uid = self.uid_counter self.uid_counter += 1 node._dd_rewrites = [] - for r in self.rewrites: - if r.match(node, parent, field, inside_list): - lr = LabeledRewrite(self.rewrite_counter, r) - self.rewrite_counter += 1 - self.rewrite_names[lr.rewrite.get_name()] += 1 - node._dd_rewrites.append(lr) + + # A node is a freeze boundary if it is ``with __freeze__:`` or + # ``__freeze__(expr)``. Once we cross a boundary, every descendant + # is also frozen. + is_boundary = _is_freeze_with(node) or _is_freeze_call(node) + is_frozen = self._frozen or is_boundary + + # If this node is a freeze boundary, retroactively mark *all* + # enclosing statements as frozen. Marking only the directly + # enclosing statement is not enough: a parent ``if``/``for``/``while`` + # could be removed by stmt-remover, which would take the frozen + # subtree with it. + if is_boundary: + for stmt in self._stmt_stack: + stmt._dd_ancestor_frozen = True + + if not is_frozen: + for r in self.rewrites: + if r.match(node, parent, field, inside_list): + lr = LabeledRewrite(self.rewrite_counter, r) + self.rewrite_counter += 1 + self.rewrite_names[lr.rewrite.get_name()] += 1 + node._dd_rewrites.append(lr) + + is_stmt = isinstance(node, ast.stmt) + if is_stmt: + self._stmt_stack.append(node) + + old_frozen = self._frozen + self._frozen = is_frozen res = self.generic_visit(node) + self._frozen = old_frozen + + if is_stmt: + self._stmt_stack.pop() + # If a child __freeze__() call flagged this statement, wipe any + # rewrites that were attached before we discovered the frozen child. + if getattr(node, "_dd_ancestor_frozen", False): + node._dd_rewrites = [] + return res @@ -594,7 +679,10 @@ class LinePDD(TaskManager, PDD): def __init__(self, source: str, init_proba: float = 0.93): lines = [line for line in source.splitlines() if line.strip() != ""] self.lines = lines - all_labels = [i for i in range(len(lines))] + # Frozen lines are never candidates for removal: exclude them from + # all_labels entirely so PDD never generates tasks that delete them. + frozen = _find_frozen_line_set(source, lines) + all_labels = [i for i in range(len(lines)) if i not in frozen] super().__init__(all_labels, init_proba) @override @@ -878,6 +966,161 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.AST: return ast.unparse(new_tree) +def _has_freeze_import(source: str) -> bool: + """Return True if *source* already contains ``from tilelang.autodd import __freeze__`` + as an actual import statement (not inside a comment or string literal). + """ + try: + tree = ast.parse(source) + except SyntaxError: + return False + for node in ast.walk(tree): + if ( + isinstance(node, ast.ImportFrom) + and node.module == "tilelang.autodd" + and any(alias.name == "__freeze__" for alias in node.names) + ): + return True + return False + + +def _preprocess_freeze_comments(source: str) -> str: + """Convert ``# autodd: freeze`` comment annotations to ``with __freeze__:`` blocks. + + Supports two forms: + + **Block form** – wrap a group of statements:: + + # autodd: freeze-start + stmt1 + stmt2 + # autodd: end-freeze + + **Single-statement form** – end-of-line comment on any non-comment line:: + + stmt # autodd: freeze + + Both forms are converted in-place to ``with __freeze__:`` blocks so that + the freeze information survives ``ast.unparse`` round-trips. + + .. note:: + The single-statement form only works for *physically single-line* statements. + Placing ``# autodd: freeze`` on the last line of a multi-line expression (e.g. + the closing ``)`` of a parenthesised call) will produce a ``SyntaxError`` + because only that one line is wrapped. Use the block form instead. + + The block form prepends exactly 4 spaces to every non-empty line inside the + annotated region. This is correct for regular statements, but it will corrupt + **multi-line string literals** whose continuation lines start at column 0: those + lines will gain unintended leading spaces (or cause a ``SyntaxError`` if the + closing ``\"\"\"`` is shifted). Avoid using the block form around triple-quoted + string literals. + + If any substitution is made and ``from tilelang.autodd import __freeze__`` is not + already present in *source*, the import is automatically prepended so that the + generated ``with __freeze__:`` blocks remain valid Python when executed. + """ + lines = source.splitlines() + result: list[str] = [] + substituted = False + i = 0 + while i < len(lines): + line = lines[i] + stripped = line.strip() + indent = line[: len(line) - len(line.lstrip())] + + # Block form: standalone comment line "# autodd: freeze-start" + if stripped == "# autodd: freeze-start": + substituted = True + i += 1 + block: list[str] = [] + found_end = False + while i < len(lines): + if lines[i].strip() == "# autodd: end-freeze": + i += 1 + found_end = True + break + block.append(lines[i]) + i += 1 + if not found_end: + print( + "autodd WARNING: '# autodd: freeze-start' has no matching " + "'# autodd: end-freeze' — all remaining source is treated as frozen." + ) + result.append(f"{indent}with __freeze__:") + for bl in block: + # Prepend 4 spaces to preserve relative indentation inside the with block. + result.append(f" {bl}" if bl.strip() else bl) + + # Single-statement form: end-of-line "# autodd: freeze" on a non-comment line. + # Extract the comment text and verify it is exactly "# autodd: freeze" so that + # "# autodd: freeze-start" used as an inline comment is not misidentified here. + elif "# autodd: freeze" in line and not stripped.startswith("#"): + marker_idx = line.index("# autodd: freeze") + comment_text = line[marker_idx:].strip() + if comment_text != "# autodd: freeze": + # e.g. "# autodd: freeze-start" or "# autodd: freeze-end" as inline comment + result.append(line) + i += 1 + else: + substituted = True + code_part = line[:marker_idx].rstrip() + result.append(f"{indent}with __freeze__:") + result.append(f"{indent} {code_part.lstrip()}") + i += 1 + + else: + result.append(line) + i += 1 + + body = "\n".join(result) + + # If we made substitutions, ensure __freeze__ is importable in the generated code. + # Users who used only comment annotations may not have the explicit import in their + # script; without it every exec() call would raise NameError. + # We use an AST-level check rather than a plain substring search so that a + # commented-out import (e.g. "# from tilelang.autodd import __freeze__") is not + # mistaken for an active one. + if substituted and not _has_freeze_import(body): + body = "from tilelang.autodd import __freeze__\n" + body + + return body + + +def _find_frozen_line_set(source: str, nonempty_lines: list[str]) -> set[int]: + """Return the set of indices into *nonempty_lines* that belong to frozen regions. + + A line is considered frozen if it falls within the source span of a + ``with __freeze__:`` block or a ``__freeze__(expr)`` call. + """ + try: + tree = ast.parse(source) + except SyntaxError: + return set() + + # Collect 1-indexed source line numbers that are inside frozen regions. + frozen_linenos: set[int] = set() + for node in ast.walk(tree): + if _is_freeze_with(node) or _is_freeze_call(node): + start = getattr(node, "lineno", None) + end = getattr(node, "end_lineno", None) + if start is not None and end is not None: + frozen_linenos.update(range(start, end + 1)) + + if not frozen_linenos: + return set() + + # Map 1-indexed source line numbers → indices in nonempty_lines. + frozen_indices: set[int] = set() + nonempty_idx = 0 + for lineno_0, line in enumerate(source.splitlines()): + if line.strip(): # non-empty → has an entry in nonempty_lines + if (lineno_0 + 1) in frozen_linenos: + frozen_indices.add(nonempty_idx) + nonempty_idx += 1 + return frozen_indices + + JobBackend = Literal["subproc", "runner"] @@ -1119,6 +1362,10 @@ async def main(args: Args): ] + fast_reducers await manager.start_workers() + # One-time preprocessing: convert # autodd: freeze comments to + # ``with __freeze__:`` blocks so that freeze annotations survive + # ast.unparse round-trips throughout the reduction loop. + manager.text = _preprocess_freeze_comments(manager.text) manager.text = manager.post_proc(manager.text) try: while True: From d2e02e10c51422dcdfd7925ad9d16e96540461e4 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Wed, 15 Apr 2026 13:34:43 +0800 Subject: [PATCH 053/156] [BugFix] Skip MMA shared buffer layout inference when layout already exists (#2008) * [BugFix] Skip MMA shared buffer layout inference when layout already exists (#1997) When a shared memory buffer is consumed by multiple gemm operations with different transpose semantics, each gemm infers a different swizzle layout, causing a layout conflict error. For MMA instructions, the swizzle layout is only a bank conflict optimization, not a correctness requirement. Skip layout inference for shared buffers that already have a layout inferred by a prior operator. WGMMA/TCGEN5MMA/MFMA retain strict layout enforcement. Co-Authored-By: Claude Opus 4.6 * Rename raw_results to inferred_layouts --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: LeiWang1999 --- src/op/gemm.cc | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 2facb9d9f3..7ba7c01ce2 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -455,16 +455,27 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, if (completed_) return {}; LayoutMap results; - if (const auto f = ffi::Function::GetGlobal("tl.gemm.infer_layout")) { - results = Downcast( + auto inferred_layouts = Downcast( (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds)); - // Bind all fragment layouts with the provided thread range - for (auto kv : results) { + // For MMA instructions, skip shared buffer layouts that are already + // inferred by a prior operator to avoid layout conflicts when the same + // shared buffer is consumed by multiple gemm ops with different transpose + // semantics. WGMMA/TCGEN5MMA have strict shared memory layout requirements + // and must always set their layouts. + auto block_size = *as_const_int(T.thread_bounds->extent); + GemmInst gemm_inst = getGemmInst(block_size, T.target); + bool is_mma = (gemm_inst == GemmInst::kMMA); + for (auto kv : inferred_layouts) { const Buffer &buf = kv.first; const Layout &layout = kv.second; + if (is_mma && IsSharedBuffer(buf) && T.layout_map.count(buf)) { + continue; + } if (auto frag = layout.as()) { results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds)); + } else { + results.Set(buf, layout); } } } else { From c93778f7bd4c504c85c84a9f44e39ed774d4a56a Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 15 Apr 2026 14:29:08 +0800 Subject: [PATCH 054/156] support tcgen05_gemm --- src/transform/auto_schedule.cc | 7 +++++- src/transform/auto_schedule/barrier.h | 9 +++++++- src/transform/auto_schedule/ir_structure.cc | 16 ++++++++++++++ src/transform/auto_schedule/ir_structure.h | 24 +++++++++++++++++++++ 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index faf11b338d..70be1d5eba 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -359,6 +359,9 @@ class IRStructureBuilder : public StmtVisitor { static const auto gemm_op = Op::Get("tl.tileop.gemm"); static const auto wgmma_gemm_py_op = Op::Get("tl.tileop.wgmma_gemm_py"); static const auto wgmma_gemm_op = Op::Get("tl.tileop.wgmma_gemm"); + static const auto tcgen05_gemm_py_op = + Op::Get("tl.tileop.tcgen05_gemm_py"); + static const auto tcgen05_gemm_op = Op::Get("tl.tileop.tcgen05_gemm"); static const auto reduce_op = Op::Get("tl.tileop.reduce"); static const auto fill_op = Op::Get("tl.tileop.fill"); static const auto region_op = Op::Get("tl.tileop.region"); @@ -401,7 +404,9 @@ class IRStructureBuilder : public StmtVisitor { } } else if (op->op.same_as(gemm_py_op) || op->op.same_as(gemm_op) || op->op.same_as(wgmma_gemm_py_op) || - op->op.same_as(wgmma_gemm_op)) { + op->op.same_as(wgmma_gemm_op) || + op->op.same_as(tcgen05_gemm_py_op) || + op->op.same_as(tcgen05_gemm_op)) { found_tensor = true; int64_t m = op->args[5].as()->value; diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 5bf85689d0..81502880c7 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -461,6 +461,8 @@ static void RewriteGemmMbar(TaskNode *task, PrimExpr mbar_expr) { static const auto gemm_op = Op::Get("tl.tileop.gemm"); static const auto wgmma_gemm_py_op = Op::Get("tl.tileop.wgmma_gemm_py"); static const auto wgmma_gemm_op = Op::Get("tl.tileop.wgmma_gemm"); + static const auto tcgen05_gemm_py_op = Op::Get("tl.tileop.tcgen05_gemm_py"); + static const auto tcgen05_gemm_op = Op::Get("tl.tileop.tcgen05_gemm"); class GemmMbarRewriter : public StmtExprMutator { public: @@ -472,9 +474,14 @@ static void RewriteGemmMbar(TaskNode *task, PrimExpr mbar_expr) { static const auto gemm_op = Op::Get("tl.tileop.gemm"); static const auto wgmma_gemm_py_op = Op::Get("tl.tileop.wgmma_gemm_py"); static const auto wgmma_gemm_op = Op::Get("tl.tileop.wgmma_gemm"); + static const auto tcgen05_gemm_py_op = + Op::Get("tl.tileop.tcgen05_gemm_py"); + static const auto tcgen05_gemm_op = Op::Get("tl.tileop.tcgen05_gemm"); if ((op->op.same_as(gemm_py_op) || op->op.same_as(gemm_op) || - op->op.same_as(wgmma_gemm_py_op) || op->op.same_as(wgmma_gemm_op)) && + op->op.same_as(wgmma_gemm_py_op) || op->op.same_as(wgmma_gemm_op) || + op->op.same_as(tcgen05_gemm_py_op) || + op->op.same_as(tcgen05_gemm_op)) && op->args.size() > 16) { Array new_args; for (size_t i = 0; i < op->args.size(); ++i) { diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index 394e776307..fa0ecee278 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -88,6 +88,22 @@ bool SequenceNode::UsesTensorCore() const { return false; } +bool SequenceNode::HasWGMMA() const { + for (const auto &child : children) { + if (child && child->HasWGMMA()) + return true; + } + return false; +} + +bool SequenceNode::HasTCGEN05() const { + for (const auto &child : children) { + if (child && child->HasTCGEN05()) + return true; + } + return false; +} + std::vector SequenceNode::GetReadRegions() const { std::vector all_read_regions; for (const auto &child : children) { diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 902a31fa54..b159409cf0 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -80,6 +80,9 @@ class IRStructure { virtual bool UsesTMACore() const = 0; virtual bool UsesTensorCore() const = 0; + virtual bool HasWGMMA() const = 0; + virtual bool HasTCGEN05() const = 0; + // Memory access regions (collected during analysis) virtual std::vector GetReadRegions() const = 0; virtual std::vector GetWriteRegions() const = 0; @@ -246,6 +249,9 @@ class TaskNode : public IRStructure { return has_gemm_inst_ && gemm_inst_ == GemmInst::kTCGEN5MMA; } + bool HasWGMMA() const override { return is_WGMMA(); } + bool HasTCGEN05() const override { return is_TCGEN05(); } + // Get aggregated shape information for II estimation int64_t GetTotalTensorCoreOps() const { int64_t total_ops = 0; @@ -362,6 +368,11 @@ class ControlNode : public IRStructure { return child ? child->UsesTensorCore() : false; } + bool HasWGMMA() const override { return child ? child->HasWGMMA() : false; } + bool HasTCGEN05() const override { + return child ? child->HasTCGEN05() : false; + } + // Memory access regions (aggregate from child & task) std::vector GetReadRegions() const override { std::vector regions = @@ -496,6 +507,11 @@ class WrapperNode : public IRStructure { return child ? child->UsesTensorCore() : false; } + bool HasWGMMA() const override { return child ? child->HasWGMMA() : false; } + bool HasTCGEN05() const override { + return child ? child->HasTCGEN05() : false; + } + // Memory access regions (aggregate from child) std::vector GetReadRegions() const override { return child ? child->GetReadRegions() : std::vector{}; @@ -590,6 +606,11 @@ class ScheduleUnit : public IRStructure { return child ? child->UsesTensorCore() : false; } + bool HasWGMMA() const override { return child ? child->HasWGMMA() : false; } + bool HasTCGEN05() const override { + return child ? child->HasTCGEN05() : false; + } + // Memory access regions (aggregate from child) std::vector GetReadRegions() const override { return child ? child->GetReadRegions() : std::vector{}; @@ -686,6 +707,9 @@ class SequenceNode : public IRStructure { bool UsesTMACore() const override; bool UsesTensorCore() const override; + bool HasWGMMA() const override; + bool HasTCGEN05() const override; + // Memory access regions (aggregate from all children) std::vector GetReadRegions() const override; std::vector GetWriteRegions() const override; From 789367800ae0bd71e2c985cc6c8a98a3c6276d49 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:13:49 +0800 Subject: [PATCH 055/156] [Refactor] Remove obsolete RewriteWgmmaSync pass (#2046) The RewriteWgmmaSync pass is no longer needed. Remove the pass implementation, its Python wrapper, and the call site in the compilation phase. --- src/transform/wgmma_sync_rewriter.cc | 275 --------------------------- tilelang/engine/phase.py | 4 +- tilelang/transform/__init__.py | 11 -- 3 files changed, 1 insertion(+), 289 deletions(-) delete mode 100644 src/transform/wgmma_sync_rewriter.cc diff --git a/src/transform/wgmma_sync_rewriter.cc b/src/transform/wgmma_sync_rewriter.cc deleted file mode 100644 index 538b491107..0000000000 --- a/src/transform/wgmma_sync_rewriter.cc +++ /dev/null @@ -1,275 +0,0 @@ -/*! - * \file warp_specialized_pipeline.cc - * \brief Warp specialized Pipeline for cuda GPU (sm90+) - */ - -#include -#include -#include -#include -#include -#include - -#include - -#include "../op/builtin.h" - -namespace tvm { -namespace tl { - -using namespace tir; - -bool isGemm(const Stmt &stmt) { - bool is_gemm = false; - if (stmt.as()) { - auto call = Downcast(stmt)->value.as(); - if (call && call->op.same_as(Op::Get("tir.call_extern"))) { - if (call->args[0].as()) { - std::string name = Downcast(call->args[0])->value; - if (name.find("gemm") != std::string::npos) { - is_gemm = true; - } - } - } - } - return is_gemm; -} - -bool isGemmSync(const Stmt &stmt) { - bool is_gemm_sync = false; - if (stmt.as()) { - auto call = Downcast(stmt)->value.as(); - if (call && call->op.same_as(Op::Get("tir.call_extern"))) { - if (call->args[0].as()) { - std::string name = Downcast(call->args[0])->value; - if (name.find("warpgroup_wait") != std::string::npos) { - is_gemm_sync = true; - } - } - } - } - return is_gemm_sync; -} - -bool isArriveBarrier(const Stmt &stmt) { - bool is_arrive_barrier = false; - if (stmt.as()) { - auto call = Downcast(stmt)->value.as(); - if (call && call->op.same_as(Op::Get("tir.ptx_arrive_barrier"))) { - is_arrive_barrier = true; - } - } - return is_arrive_barrier; -} - -class WgmmaSyncRewriter : public StmtExprMutator { -public: - static PrimFunc Substitute(PrimFunc f) { - auto T = WgmmaSyncRewriter(); - T.buffer_lca_ = DetectBufferAccessLCA(f); - for (auto [buffer, _] : T.buffer_lca_) - T.buffer_data_to_buffer_.Set(buffer->data, buffer); - f.CopyOnWrite()->body = T(f->body); - return f; - } - -private: - void CollectWgmmaInfo(const SeqStmtNode *op) { - for (int i = 0; i < static_cast(op->seq.size()); i++) { - auto stmt = op->seq[i]; - if (isGemm(stmt)) { - gemm_stmts_.push_back(stmt); - gemm_stmt_ids_.push_back(i); - bool found_release = false; - for (int j = i + 1; j < static_cast(op->seq.size()); j++) { - auto release_stmt = op->seq[j]; - if (isArriveBarrier(release_stmt)) { - found_release = true; - gemm_release_stmts_.push_back(release_stmt); - break; - } - } - if (!found_release) { - gemm_release_stmts_.push_back(Evaluate(0)); - } - // ICHECK(op->seq.size() > i + 1); - // auto release_stmt = op->seq[i + 1]; - // auto next_call = - // Downcast(release_stmt)->value.as(); - // ICHECK(next_call); - // ICHECK(next_call->op.same_as(Op::Get("tir.ptx_arrive_barrier"))); - Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/"", - /*body*/ op->seq[i]); - auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); - std::set read_set, write_set; - for (auto region : access[0]) - read_set.insert(region->buffer.get()); - for (auto region : access[1]) - write_set.insert(region->buffer.get()); - gemm_read_buffers_.push_back(read_set); - gemm_write_buffers_.push_back(write_set); - } - } - } - - Stmt VisitStmt_(const ForNode *op) final { - auto order_anno = op->annotations.Get("tl_pipeline_order"); - if (!order_anno) { - return StmtExprMutator::VisitStmt_(op); - } - - CollectWgmmaInfo(op->body.as()); - auto stmt_node = (op->body).as(); - ICHECK(stmt_node); - - auto intersect_fn = [](const std::set &lhs, - const std::set &rhs) { - for (auto ptr : lhs) - if (rhs.count(ptr)) - return true; - return false; - }; - - for (int r = 0; r < static_cast(gemm_stmts_.size()); r++) { - bool found = false; - auto last_stmt = Stmt(); - for (int i = 0; i < static_cast(stmt_node->seq.size()); i++) { - if (stmt_node->seq[i].same_as(gemm_stmts_[r])) { - found = true; - last_stmt = stmt_node->seq[i]; - continue; - } - if (!found) - continue; - Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/"", - /*body*/ stmt_node->seq[i]); - auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); - std::set read_set, write_set; - for (auto region : access[0]) - read_set.insert(region->buffer.get()); - for (auto region : access[1]) - write_set.insert(region->buffer.get()); - if (intersect_fn(read_set, gemm_write_buffers_[r]) || - intersect_fn(write_set, gemm_read_buffers_[r]) || - intersect_fn(write_set, gemm_write_buffers_[r])) { - break; - } - last_stmt = stmt_node->seq[i]; - } - last_stmts_.push_back(last_stmt); - } - - auto new_seq = Array(); - for (int i = 0; i < static_cast(stmt_node->seq.size()); i++) { - bool remove_ = false; - for (int j = 0; j < static_cast(gemm_stmts_.size()); j++) { - if (stmt_node->seq[i].same_as(gemm_release_stmts_[j])) { - remove_ = true; - continue; - } - } - if (remove_) - continue; - auto stmt = stmt_node->seq[i]; - for (int j = 0; j < static_cast(gemm_stmts_.size()); j++) { - if (stmt_node->seq[i].same_as(gemm_stmts_[j])) { - auto call = Downcast(stmt)->value.as(); - ICHECK(call); - ICHECK(call->op.same_as(Op::Get("tir.call_extern"))); - ICHECK(call->args[0].as()); - std::string name = Downcast(call->args[0])->value; - std::string new_name = name.substr(0, name.size() - 1) + ", -1>"; - auto new_args = Array(); - new_args.push_back(StringImm(new_name)); - for (int k = 1; k < static_cast(call->args.size()); k++) { - new_args.push_back(call->args[k]); - } - stmt = Evaluate( - Call(DataType::Handle(), builtin::call_extern(), new_args)); - break; - } - } - - new_seq.push_back(stmt); - for (int j = 0; j < static_cast(gemm_stmts_.size()); j++) { - if (stmt_node->seq[i].same_as(last_stmts_[j])) { - Array new_args; - new_args.push_back(StringImm("cute::warpgroup_wait<0>")); - new_args.push_back(Integer(j)); - auto new_call = - Call(DataType::Handle(), builtin::call_extern(), new_args); - new_seq.push_back(Evaluate(new_call)); - if (std::count(gemm_release_stmts_.begin(), gemm_release_stmts_.end(), - gemm_release_stmts_[j]) == 1) { - new_seq.push_back(gemm_release_stmts_[j]); - } else { - gemm_release_stmts_[j] = Evaluate(0); - } - } - } - } - - int gemm_count = 0; - int max_sync_index = 0; - for (int i = 0; i < static_cast(new_seq.size()); i++) { - if (isGemm(new_seq[i])) { - gemm_count++; - } else if (isGemmSync(new_seq[i])) { - auto call = Downcast(new_seq[i])->value.as(); - auto sync_index = - static_cast(Downcast(call->args[1])->value); - auto wait_count = gemm_count - sync_index - 1; - if (sync_index > max_sync_index) - max_sync_index = sync_index; - if (sync_index < max_sync_index) { - // new_seq.erase(new_seq.begin() + i); - new_seq.Set(i, Evaluate(0)); - } else { - Array new_args; - std::string call_str = - "cute::warpgroup_wait<" + std::to_string(wait_count) + ">"; - new_args.push_back(StringImm(call_str)); - new_seq.Set(i, Evaluate(Call(DataType::Handle(), - builtin::call_extern(), new_args))); - } - } - } - auto new_for = - For(op->loop_var, op->min, op->extent, op->kind, - new_seq.size() == 1 ? new_seq[0] : SeqStmt(std::move(new_seq)), - op->thread_binding, op->annotations); - return new_for; - } - - WgmmaSyncRewriter() = default; - - Map> buffer_lca_; - Map buffer_data_to_buffer_; - std::vector> gemm_read_buffers_; - std::vector> gemm_write_buffers_; - std::vector gemm_stmts_; - std::vector gemm_release_stmts_; - std::vector last_stmts_; - - std::vector gemm_stmt_ids_; - friend class WgmmaReleaseCollector; -}; - -using namespace tir::transform; - -tvm::transform::Pass RewriteWgmmaSync() { - auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { - return WgmmaSyncRewriter::Substitute(std::move(f)); - }; - return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync); -} - -} // namespace tl -} // namespace tvm diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index ec47fa3075..683307415b 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -3,7 +3,7 @@ from tvm.target import Target import tilelang from tilelang.transform import PassContext -from tilelang.contrib.nvcc import have_tma, is_hopper, have_pdl +from tilelang.contrib.nvcc import have_tma, have_pdl def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: @@ -238,8 +238,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.FuseMBarrierArriveExpectTx()(mod) mod = tilelang.transform.HoistGlobalBufferAllocations()(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod) - if is_hopper(target): - mod = tilelang.transform.RewriteWgmmaSync()(mod) mod = tilelang.transform.Simplify()(mod) mod = tir.transform.NarrowDataType(32)(mod) mod = tilelang.transform.FlattenBuffer()(mod) diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 75ad91d244..ba851b5e03 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -154,17 +154,6 @@ def WarpSpecializedPipeline(): return _ffi_api.WarpSpecializedPipeline() # type: ignore -def RewriteWgmmaSync(): - """RewriteWgmmaSync - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.RewriteWgmmaSync() # type: ignore - - def ThreadSync(storage_scope: str): """Insert sync between parallel read/write of shared buffers. From e3d214dee6df9cbd986a65affb0cec67be590cf2 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Wed, 15 Apr 2026 16:51:33 +0800 Subject: [PATCH 056/156] Add if node --- src/transform/auto_schedule.cc | 37 +++- src/transform/auto_schedule/barrier.h | 28 ++- src/transform/auto_schedule/ir_structure.cc | 45 +++++ src/transform/auto_schedule/ir_structure.h | 190 +++++++++++++++++- src/transform/auto_schedule/memory_detector.h | 9 +- .../auto_schedule/schedule_builder.cc | 22 ++ .../auto_schedule/schedule_builder.h | 12 +- .../auto_schedule/warpgroup_partition.cc | 76 +++++++ tilelang/engine/phase.py | 19 +- tilelang/transform/z3_scheduler.py | 18 +- 10 files changed, 420 insertions(+), 36 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 70be1d5eba..63f6d835c2 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -238,20 +238,37 @@ class IRStructureBuilder : public StmtVisitor { } void VisitStmt_(const IfThenElseNode *op) override { - // If statement -> treat as TaskNode for now (could be refined later) - auto task_node = std::make_shared(); - task_node->stmts.push_back(GetRef(op)); - - AnalyzeMemoryExpr(op->condition, task_node.get()); - AnalyzeResourceUsage(Evaluate(op->condition), task_node.get(), true); + // If statement -> IfNode with independently schedulable branches + auto if_node = std::make_shared(); + if_node->condition = op->condition; + + // Create task for condition expression resource analysis + auto cond_task = std::make_shared(); + cond_task->stmts.push_back(Evaluate(op->condition)); + AnalyzeMemoryExpr(op->condition, cond_task.get()); + AnalyzeResourceUsage(Evaluate(op->condition), cond_task.get(), true); + if_node->task = std::move(cond_task); + + // Recursively build then branch + VisitStmt(op->then_case); + if (root_) { + if_node->then_child = std::move(root_); + } - // Analyze both branches for resource usage - AnalyzeResourceUsage(op->then_case, task_node.get()); + // Recursively build else branch (if present) if (op->else_case) { - AnalyzeResourceUsage(op->else_case.value(), task_node.get()); + VisitStmt(op->else_case.value()); + if (root_) { + if_node->else_child = std::move(root_); + } } - root_ = std::move(task_node); + // Latency = max of both branches + int64_t then_latency = if_node->then_child ? if_node->then_child->GetLatency() : 0; + int64_t else_latency = if_node->else_child ? if_node->else_child->GetLatency() : 0; + if_node->SetLatency(std::max(then_latency, else_latency)); + + root_ = std::move(if_node); } void VisitStmt_(const LetStmtNode *op) override { diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 81502880c7..2251bde05f 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -449,6 +449,12 @@ static void RewriteTaskNodeBuffers( } else if (node->IsScheduleUnit()) { auto unit = static_cast(node); RewriteTaskNodeBuffers(unit->child.get(), multi_buffer, iteration); + } else if (node->IsIf()) { + auto if_node = static_cast(node); + if (if_node->then_child) + RewriteTaskNodeBuffers(if_node->then_child.get(), multi_buffer, iteration); + if (if_node->else_child) + RewriteTaskNodeBuffers(if_node->else_child.get(), multi_buffer, iteration); } } @@ -571,6 +577,20 @@ AnalyzeAndInsertBarriers(IRStructure *node, int &next_barrier_id, AnalyzeAndInsertBarriers( wrapper->child.get(), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); + } else if (node->IsIf()) { + auto if_node = static_cast(node); + if (if_node->then_child) { + AnalyzeAndInsertBarriers( + if_node->then_child.get(), next_barrier_id, barrier_buffers, + barrier_map, thread_count, loop_info, buffer_infos, + neutral_sync_shared_barrier); + } + if (if_node->else_child) { + AnalyzeAndInsertBarriers( + if_node->else_child.get(), next_barrier_id, barrier_buffers, + barrier_map, thread_count, loop_info, buffer_infos, + neutral_sync_shared_barrier); + } } else if (node->IsTask()) { // For TaskNode, nothing to do at this level } else { @@ -591,8 +611,8 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, for (auto &promote_child : seq->children) { auto task = static_cast(promote_child.get()); - if (task->child->IsSequence() || task->child->IsControl()) { - // If child is SequenceNode or ControlNode, recursively analyze it + if (task->child->IsSequence() || task->child->IsControl() || task->child->IsIf()) { + // If child is SequenceNode, ControlNode, or IfNode, recursively analyze it AnalyzeAndInsertBarriers( task->child.get(), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); @@ -836,8 +856,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, std::vector ordered_tasks; for (auto &child : seq->children) { auto task = static_cast(child.get()); - if (task->child->IsSequence() || task->child->IsControl()) { - // If child is SequenceNode or ControlNode, recursively analyze it + if (task->child->IsSequence() || task->child->IsControl() || task->child->IsIf()) { + // If child is SequenceNode, ControlNode, or IfNode, recursively analyze it AnalyzeAndInsertBarriers( task->child.get(), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index fa0ecee278..5305f4bbd8 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -364,6 +364,24 @@ std::shared_ptr ScheduleUnit::Clone() const { return new_unit; } +std::shared_ptr IfNode::Clone() const { + auto new_if = std::make_shared(); + new_if->condition = condition; + if (then_child) { + new_if->then_child = then_child->Clone(); + } + if (else_child) { + new_if->else_child = else_child->Clone(); + } + if (task) { + new_if->task = std::static_pointer_cast(task->Clone()); + } + new_if->SetLatency(GetLatency()); + new_if->SetII(GetII()); + new_if->SetStartTime(GetStartTime()); + return new_if; +} + void ControlNode::CollectRegions( std::vector &result, std::set>> &visited) const { @@ -398,6 +416,20 @@ void SequenceNode::CollectRegions( } } +void IfNode::CollectRegions( + std::vector &result, + std::set>> &visited) const { + if (task) { + task->CollectRegions(result, visited); + } + if (then_child) { + then_child->CollectRegions(result, visited); + } + if (else_child) { + else_child->CollectRegions(result, visited); + } +} + // Helper function to collect all TaskNodes with context information void CollectAllTaskNodesWithContext(IRStructure *node, std::vector &all_tasks, @@ -466,6 +498,19 @@ void CollectAllTaskNodesWithContext(IRStructure *node, // Promote nodes don't change control context, just recurse into child CollectAllTaskNodesWithContext(promote->child.get(), all_tasks, current_control_node); + } else if (node->IsIf()) { + auto if_node = static_cast(node); + // Recurse into both branches + if (if_node->task) { + CollectAllTaskNodesWithContext(if_node->task.get(), all_tasks, + current_control_node); + } + CollectAllTaskNodesWithContext(if_node->then_child.get(), all_tasks, + current_control_node); + if (if_node->else_child) { + CollectAllTaskNodesWithContext(if_node->else_child.get(), all_tasks, + current_control_node); + } } else { LOG(FATAL); } diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index b159409cf0..7f5fe1b984 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -29,6 +29,7 @@ class TaskNode; class ControlNode; class SequenceNode; class WrapperNode; +class IfNode; // Structure to store region access information with warpgroup id struct RegionAccessInfo { @@ -62,7 +63,7 @@ inline bool RegionsEqual(const Region &a, const Region &b) { // Base class for all IR nodes in scheduling class IRStructure { public: - enum class Kind { kTask, kControl, kSequence, kWrapper, kSchedule }; + enum class Kind { kTask, kControl, kSequence, kWrapper, kSchedule, kIf }; virtual ~IRStructure() = default; virtual Kind GetKind() const = 0; @@ -74,6 +75,7 @@ class IRStructure { bool IsSequence() const { return GetKind() == Kind::kSequence; } bool IsWrapper() const { return GetKind() == Kind::kWrapper; } bool IsScheduleUnit() const { return GetKind() == Kind::kSchedule; } + bool IsIf() const { return GetKind() == Kind::kIf; } // Resource usage flags (accessible by all IR nodes) virtual bool UsesCUDACore() const = 0; @@ -585,6 +587,160 @@ class WrapperNode : public IRStructure { int64_t ii_{0}; // Initiation interval in cycles }; +// If node: represents an IfThenElse conditional with independently schedulable +// branches. Acts as an atomic unit for outer scheduling but allows recursive +// scheduling within each branch. +class IfNode : public IRStructure { +public: + PrimExpr condition; + std::shared_ptr then_child; + std::shared_ptr else_child; // optional + std::shared_ptr task; // resource analysis for the condition expr + + Kind GetKind() const override { return Kind::kIf; } + + // Resource usage flags (aggregate from both branches) + bool UsesCUDACore() const override { + bool result = false; + if (then_child) result |= then_child->UsesCUDACore(); + if (else_child) result |= else_child->UsesCUDACore(); + return result; + } + bool UsesTMACore() const override { + bool result = false; + if (then_child) result |= then_child->UsesTMACore(); + if (else_child) result |= else_child->UsesTMACore(); + return result; + } + bool UsesTensorCore() const override { + bool result = false; + if (then_child) result |= then_child->UsesTensorCore(); + if (else_child) result |= else_child->UsesTensorCore(); + return result; + } + + bool HasWGMMA() const override { + return (then_child && then_child->HasWGMMA()) || + (else_child && else_child->HasWGMMA()); + } + bool HasTCGEN05() const override { + return (then_child && then_child->HasTCGEN05()) || + (else_child && else_child->HasTCGEN05()); + } + + // Memory access regions (union of both branches + task) + std::vector GetReadRegions() const override { + std::vector regions; + if (task) { + auto task_regions = task->GetReadRegions(); + regions.insert(regions.end(), task_regions.begin(), task_regions.end()); + } + if (then_child) { + auto r = then_child->GetReadRegions(); + regions.insert(regions.end(), r.begin(), r.end()); + } + if (else_child) { + auto r = else_child->GetReadRegions(); + regions.insert(regions.end(), r.begin(), r.end()); + } + return regions; + } + std::vector GetWriteRegions() const override { + std::vector regions; + if (task) { + auto task_regions = task->GetWriteRegions(); + regions.insert(regions.end(), task_regions.begin(), task_regions.end()); + } + if (then_child) { + auto r = then_child->GetWriteRegions(); + regions.insert(regions.end(), r.begin(), r.end()); + } + if (else_child) { + auto r = else_child->GetWriteRegions(); + regions.insert(regions.end(), r.begin(), r.end()); + } + return regions; + } + + std::vector GetReadVars() const override { + std::vector vars; + if (task) { + auto v = task->GetReadVars(); + vars.insert(vars.end(), v.begin(), v.end()); + } + if (then_child) { + auto v = then_child->GetReadVars(); + vars.insert(vars.end(), v.begin(), v.end()); + } + if (else_child) { + auto v = else_child->GetReadVars(); + vars.insert(vars.end(), v.begin(), v.end()); + } + return vars; + } + std::vector GetWriteVars() const override { + std::vector vars; + if (task) { + auto v = task->GetWriteVars(); + vars.insert(vars.end(), v.begin(), v.end()); + } + if (then_child) { + auto v = then_child->GetWriteVars(); + vars.insert(vars.end(), v.begin(), v.end()); + } + if (else_child) { + auto v = else_child->GetWriteVars(); + vars.insert(vars.end(), v.begin(), v.end()); + } + return vars; + } + + void SubstituteVar(const Var &old_var, const Var &new_var) override { + condition = Substitute(condition, {{old_var, new_var}}); + if (then_child) then_child->SubstituteVar(old_var, new_var); + if (else_child) else_child->SubstituteVar(old_var, new_var); + if (task) task->SubstituteVar(old_var, new_var); + } + + // Latency = max of both branches + int64_t GetLatency() const override { return latency_; } + int64_t GetII() const override { return ii_; } + + // Setters (delegate to both branches) + void SetUsesCUDACore(bool value) override { + if (then_child) then_child->SetUsesCUDACore(value); + if (else_child) else_child->SetUsesCUDACore(value); + } + void SetUsesTMACore(bool value) override { + if (then_child) then_child->SetUsesTMACore(value); + if (else_child) else_child->SetUsesTMACore(value); + } + void SetUsesTensorCore(bool value) override { + if (then_child) then_child->SetUsesTensorCore(value); + if (else_child) else_child->SetUsesTensorCore(value); + } + void SetReadRegions(const std::vector ®ions) override {} + void SetWriteRegions(const std::vector ®ions) override {} + void SetLatency(int64_t latency) override { latency_ = latency; } + void SetII(int64_t ii) override { ii_ = ii; } + + void CollectRegions( + std::vector &result, + std::set>> &visited) const override; + + // Clone method + std::shared_ptr Clone() const override; + + bool containWarpgroupId(int id) const override { + return (then_child && then_child->containWarpgroupId(id)) || + (else_child && else_child->containWarpgroupId(id)); + } + +private: + int64_t latency_{0}; + int64_t ii_{0}; +}; + class ScheduleUnit : public IRStructure { public: int stage; @@ -875,6 +1031,7 @@ CollectTopLevelControlNodes(IRStructure *node, auto unit = static_cast(node); CollectTopLevelControlNodes(unit->child.get(), control_nodes); } + // IfNode is atomic — don't recurse into it } // Collect all TaskNodes at the top level (not inside any ControlNode) @@ -896,6 +1053,7 @@ inline void CollectTopLevelTaskNodes(IRStructure *node, auto unit = static_cast(node); CollectTopLevelTaskNodes(unit->child.get(), task_nodes); } + // IfNode is atomic — don't recurse into it } // Collect all top-level leaf items (TaskNodes and ControlNodes) in order. @@ -905,7 +1063,7 @@ inline void CollectTopLevelItems(IRStructure *node, std::vector &items) { if (!node) return; - if (node->IsTask() || node->IsControl()) { + if (node->IsTask() || node->IsControl() || node->IsIf()) { items.push_back(node); } else if (node->IsSequence()) { auto seq = static_cast(node); @@ -1032,6 +1190,18 @@ inline void PrintAllStmts(const IRStructure *node, int indent = 0) { LOG(INFO) << indent_str << " Promote body:"; PrintAllStmts(promote->child.get(), indent + 2); } + } else if (node->IsIf()) { + const IfNode *if_node = static_cast(node); + LOG(INFO) << indent_str << "IfNode:"; + LOG(INFO) << indent_str << " Condition: " << if_node->condition; + if (if_node->then_child) { + LOG(INFO) << indent_str << " Then:"; + PrintAllStmts(if_node->then_child.get(), indent + 2); + } + if (if_node->else_child) { + LOG(INFO) << indent_str << " Else:"; + PrintAllStmts(if_node->else_child.get(), indent + 2); + } } } @@ -1127,6 +1297,22 @@ inline void PrintIRStructure(const IRStructure *node, int indent = 0) { LOG(INFO) << indent_str << " Promote body:"; PrintAllStmts(promote->child.get(), indent + 2); } + } else if (node->IsIf()) { + const IfNode *if_node = static_cast(node); + LOG(INFO) << indent_str << "IfNode:"; + LOG(INFO) << indent_str << " Condition: " << if_node->condition; + if (if_node->task) { + LOG(INFO) << indent_str << " Task:"; + PrintIRStructure(if_node->task.get(), indent + 4); + } + if (if_node->then_child) { + LOG(INFO) << indent_str << " Then:"; + PrintIRStructure(if_node->then_child.get(), indent + 2); + } + if (if_node->else_child) { + LOG(INFO) << indent_str << " Else:"; + PrintIRStructure(if_node->else_child.get(), indent + 2); + } } } diff --git a/src/transform/auto_schedule/memory_detector.h b/src/transform/auto_schedule/memory_detector.h index 0d7b400681..3e509aca0b 100644 --- a/src/transform/auto_schedule/memory_detector.h +++ b/src/transform/auto_schedule/memory_detector.h @@ -386,12 +386,18 @@ class MemoryAccessDetector : public StmtExprVisitor { if (access_type == 2 || access_type == 3) { // write or read/write Update(&write_buffers_, &write_regions_, buffer, relaxed_region); } + + for (const auto &index : buffer_load->indices) { + VisitExpr(index); + } + for (size_t i = 2; i < op->args.size(); ++i) { + VisitExpr(op->args[i]); + } } else { LOG(FATAL) << "First argument of tl.tileop.region should be a BufferLoad"; } } - StmtExprVisitor::VisitExpr_(op); return; } @@ -409,7 +415,6 @@ class MemoryAccessDetector : public StmtExprVisitor { // Process second argument as write region ProcessBufferRegion(op->args[1], false); // is_read = false } - StmtExprVisitor::VisitExpr_(op); return; } diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 9066157a15..b52064c857 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -89,6 +89,9 @@ void GatherTaskNodes(const std::vector> &nodes, GatherTaskNodesSingle(wrapper->child, task_nodes); } else if (node->IsControl()) { task_nodes.emplace_back(node); + } else if (node->IsIf()) { + // IfNode is atomic — add as whole unit, don't decompose + task_nodes.emplace_back(node); } else { LOG(FATAL) << "Unknown node type in GatherTaskNodes"; } @@ -677,6 +680,16 @@ void ScheduleUnitBuilder::ScheduleRecursive( seq_node->children = ChildrenScheduleHelper(origin_children); node = seq_node; return; + } else if (node->IsIf()) { + // IfNode: recursively schedule both branches internally + auto if_node = static_cast(node.get()); + if (if_node->then_child) { + ScheduleRecursive(if_node->then_child, used_buffers); + } + if (if_node->else_child) { + ScheduleRecursive(if_node->else_child, used_buffers); + } + return; } LOG(FATAL) << "[ScheduleRecursive] Unknown IRStructure type" << node.get(); @@ -937,6 +950,15 @@ void ScheduleUnitBuilder::NaiveScheduleRecursive( WrapInScheduleUnits(origin_children); seq_node->children = origin_children; node = seq_node; + } else if (node->IsIf()) { + // IfNode: recursively schedule both branches internally + auto if_node = static_cast(node.get()); + if (if_node->then_child) { + NaiveScheduleRecursive(if_node->then_child); + } + if (if_node->else_child) { + NaiveScheduleRecursive(if_node->else_child); + } } else { LOG(FATAL) << "[NaiveScheduleRecursive] Unknown IRStructure type"; } diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index 7ca3af114e..cf7f43d702 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -125,11 +125,7 @@ class ScheduleUnitBuilder { size_t n = nodes.size(); if (n <= 1) { if (n == 1) { - // For TaskNode, set start time - if (nodes[0]->IsTask()) { - auto task = static_cast(nodes[0]); - task->SetStartTime(0); - } + nodes[0]->SetStartTime(0); } return nodes; } @@ -235,11 +231,7 @@ class ScheduleUnitBuilder { // Apply start times to nodes for (size_t i = 0; i < n; ++i) { - // Only TaskNode has SetStartTime method - if (nodes[i]->IsTask()) { - auto task = static_cast(nodes[i]); - task->SetStartTime(start_times[i]); - } + nodes[i]->SetStartTime(start_times[i]); } // Create sorted task list based on start_time (and original index as diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index c638feff69..615fa7630c 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -112,6 +112,12 @@ bool ContainsLetDecl(const IRStructure *node) { } else if (node->IsScheduleUnit()) { auto unit = static_cast(node); return ContainsLetDecl(unit->child.get()); + } else if (node->IsIf()) { + auto if_node = static_cast(node); + if (ContainsLetDecl(if_node->then_child.get())) + return true; + if (if_node->else_child && ContainsLetDecl(if_node->else_child.get())) + return true; } return false; } @@ -246,6 +252,34 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, } } return new_unit; + } else if (node->IsIf()) { + if (!node->containWarpgroupId(warpgroup_id) && !ContainsLetDecl(node)) + return nullptr; + auto if_node = static_cast(node); + auto new_if = std::make_shared(); + new_if->condition = var_remap.empty() + ? if_node->condition + : Substitute(if_node->condition, var_remap); + if (if_node->task) { + auto cloned_task = + std::static_pointer_cast(if_node->task->Clone()); + if (!var_remap.empty()) { + for (size_t i = 0; i < cloned_task->stmts.size(); ++i) { + cloned_task->stmts[i] = Substitute(cloned_task->stmts[i], var_remap); + } + } + new_if->task = std::move(cloned_task); + } + new_if->then_child = CloneIRStructureWithWarpgroupFilter( + if_node->then_child.get(), warpgroup_id, var_remap); + if (if_node->else_child) { + new_if->else_child = CloneIRStructureWithWarpgroupFilter( + if_node->else_child.get(), warpgroup_id, var_remap); + } + // Return nullptr if both branches are empty + if (!new_if->then_child && !new_if->else_child) + return nullptr; + return new_if; } LOG(FATAL); return nullptr; @@ -346,6 +380,18 @@ RemoveUnusedLetDecls(std::shared_ptr root) { collector(s); } referenced_vars.insert(collector.vars.begin(), collector.vars.end()); + } else if (node->IsIf()) { + auto if_node = static_cast(node); + collect(if_node->task.get()); + collect(if_node->then_child.get()); + if (if_node->else_child) { + collect(if_node->else_child.get()); + } + // Collect variable references from the condition + VarRefCollector cond_collector; + cond_collector(if_node->condition); + referenced_vars.insert(cond_collector.vars.begin(), + cond_collector.vars.end()); } }; collect(root.get()); @@ -424,6 +470,16 @@ RemoveUnusedLetDecls(std::shared_ptr root) { if (!new_unit->child) return nullptr; return new_unit; + } else if (node->IsIf()) { + auto if_node = static_cast(node.get()); + auto new_if = std::make_shared(); + new_if->condition = if_node->condition; + new_if->task = if_node->task; + new_if->then_child = filter_tree(if_node->then_child); + if (if_node->else_child) { + new_if->else_child = filter_tree(if_node->else_child); + } + return new_if; } return node; }; @@ -603,6 +659,12 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, } else if (structure->IsWrapper()) { auto wrapper = static_cast(structure); return check_contains_loop_break(wrapper->child.get()); + } else if (structure->IsIf()) { + auto if_node = static_cast(structure); + if (if_node->then_child && check_contains_loop_break(if_node->then_child.get())) + return true; + if (if_node->else_child && check_contains_loop_break(if_node->else_child.get())) + return true; } return false; }; @@ -720,6 +782,16 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, } else { LOG(FATAL); } + } else if (structure->IsIf()) { + auto if_node = static_cast(structure); + Stmt then_stmt = ConvertIRStructureToStmt(if_node->then_child.get(), + outer_enable_epi); + Optional else_stmt; + if (if_node->else_child) { + else_stmt = ConvertIRStructureToStmt(if_node->else_child.get(), + outer_enable_epi); + } + return IfThenElse(if_node->condition, then_stmt, else_stmt); } LOG(FATAL) @@ -800,6 +872,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( return nullptr; } else if (node->IsControl()) { return nullptr; + } else if (node->IsIf()) { + return nullptr; } LOG(FATAL); return nullptr; @@ -868,6 +942,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( return nullptr; } else if (node->IsControl()) { return nullptr; + } else if (node->IsIf()) { + return nullptr; } LOG(FATAL); return nullptr; diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 9ce8b28187..3fd503de66 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -39,10 +39,19 @@ def allow_vectorize(pass_ctx: PassContext | None = None) -> bool: return not disable_vectorize -def allow_autoschedule(pass_ctx: PassContext | None = None) -> bool: +def allow_autoschedule(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() enable_autoschedule = pass_ctx.config.get("tl.enable_auto_schedule", False) + if enable_autoschedule and target is not None: + # Auto-schedule only works on CUDA targets; skip on CPU + if target.kind.name != "cuda": + return False + # When TMA lowering is disabled, skip auto-schedule to avoid + # rewriting copies to tma_copy that cannot be lowered. + disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) + if disable_tma_lower: + return False return enable_autoschedule @@ -180,10 +189,16 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectAssumes()(mod) # Simplify the IR expressions mod = tilelang.transform.Simplify()(mod) - if allow_autoschedule(): + if allow_autoschedule(target=target): # Auto schedule for high-level operations mod = tilelang.transform.IfConditionExtract()(mod) mod = tilelang.transform.AutoSchedule(False)(mod) + import os + if os.environ.get("TILELANG_DUMP_AUTO_SCHEDULE"): + print("=" * 60) + print("IR after AutoSchedule:") + print("=" * 60) + print(mod) mod = tilelang.transform.Simplify()(mod) # Set layouts for reducers mod = tilelang.transform.LayoutReducer()(mod) diff --git a/tilelang/transform/z3_scheduler.py b/tilelang/transform/z3_scheduler.py index 95ba51f07e..ebbaf7f4e2 100644 --- a/tilelang/transform/z3_scheduler.py +++ b/tilelang/transform/z3_scheduler.py @@ -8,8 +8,12 @@ import os import json import time +import threading from pathlib import Path +# Global lock to serialize Z3 calls — Z3 is not thread-safe +_z3_lock = threading.Lock() + # Try to import z3, but handle missing installation gracefully try: import z3 @@ -206,8 +210,9 @@ def z3_schedule_ffi(latencies, iis, resource_flags, data_deps, resource_deps): if hasattr(resource_deps[i], "__len__") and len(resource_deps[i]) == 2: resource_deps_list.append((int(resource_deps[i][0]), int(resource_deps[i][1]))) - # Call the actual scheduler - start_times, _ = z3_schedule_python(latencies_list, iis_list, resource_flags_list, data_deps_list, resource_deps_list) + # Call the actual scheduler (Z3 is not thread-safe, serialize access) + with _z3_lock: + start_times, _ = z3_schedule_python(latencies_list, iis_list, resource_flags_list, data_deps_list, resource_deps_list) # Return only start_times, C++ side will sort by start_time return start_times @@ -664,10 +669,11 @@ def z3_schedule_loop_ffi(num_stages, latencies, iis, resource_flags, data_deps, if hasattr(resource_deps[i], "__len__") and len(resource_deps[i]) == 2: resource_deps_list.append((int(resource_deps[i][0]), int(resource_deps[i][1]))) - # Call the actual scheduler - start_times, stages, best_ii = z3_schedule_loop_python( - num_stages, latencies_list, iis_list, resource_flags_list, data_deps_list, resource_deps_list, buffer_sizes_list, memory_limit - ) + # Call the actual scheduler (Z3 is not thread-safe, serialize access) + with _z3_lock: + start_times, stages, best_ii = z3_schedule_loop_python( + num_stages, latencies_list, iis_list, resource_flags_list, data_deps_list, resource_deps_list, buffer_sizes_list, memory_limit + ) # Return start_times and promotes as separate arrays for easier FFI handling # C++ side expects a tuple of (start_times_array, promotes_array) From 3c7f4a057208b6eef0c8b5a3ce815d00338fd0c9 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 15 Apr 2026 17:07:22 +0800 Subject: [PATCH 057/156] remove debug info --- tilelang/engine/phase.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 3fd503de66..d8b5e4a61a 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -193,12 +193,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Auto schedule for high-level operations mod = tilelang.transform.IfConditionExtract()(mod) mod = tilelang.transform.AutoSchedule(False)(mod) - import os - if os.environ.get("TILELANG_DUMP_AUTO_SCHEDULE"): - print("=" * 60) - print("IR after AutoSchedule:") - print("=" * 60) - print(mod) mod = tilelang.transform.Simplify()(mod) # Set layouts for reducers mod = tilelang.transform.LayoutReducer()(mod) From db3d8591223e102d73dba61f801efed0a1a05fd1 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 15 Apr 2026 17:09:22 +0800 Subject: [PATCH 058/156] run format --- src/transform/auto_schedule.cc | 6 ++- src/transform/auto_schedule/ir_structure.h | 45 ++++++++++++------- .../auto_schedule/warpgroup_partition.cc | 14 +++--- tilelang/engine/phase.py | 5 +-- 4 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 63f6d835c2..3366af8b04 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -264,8 +264,10 @@ class IRStructureBuilder : public StmtVisitor { } // Latency = max of both branches - int64_t then_latency = if_node->then_child ? if_node->then_child->GetLatency() : 0; - int64_t else_latency = if_node->else_child ? if_node->else_child->GetLatency() : 0; + int64_t then_latency = + if_node->then_child ? if_node->then_child->GetLatency() : 0; + int64_t else_latency = + if_node->else_child ? if_node->else_child->GetLatency() : 0; if_node->SetLatency(std::max(then_latency, else_latency)); root_ = std::move(if_node); diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 7f5fe1b984..560797ad3c 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -602,20 +602,26 @@ class IfNode : public IRStructure { // Resource usage flags (aggregate from both branches) bool UsesCUDACore() const override { bool result = false; - if (then_child) result |= then_child->UsesCUDACore(); - if (else_child) result |= else_child->UsesCUDACore(); + if (then_child) + result |= then_child->UsesCUDACore(); + if (else_child) + result |= else_child->UsesCUDACore(); return result; } bool UsesTMACore() const override { bool result = false; - if (then_child) result |= then_child->UsesTMACore(); - if (else_child) result |= else_child->UsesTMACore(); + if (then_child) + result |= then_child->UsesTMACore(); + if (else_child) + result |= else_child->UsesTMACore(); return result; } bool UsesTensorCore() const override { bool result = false; - if (then_child) result |= then_child->UsesTensorCore(); - if (else_child) result |= else_child->UsesTensorCore(); + if (then_child) + result |= then_child->UsesTensorCore(); + if (else_child) + result |= else_child->UsesTensorCore(); return result; } @@ -697,9 +703,12 @@ class IfNode : public IRStructure { void SubstituteVar(const Var &old_var, const Var &new_var) override { condition = Substitute(condition, {{old_var, new_var}}); - if (then_child) then_child->SubstituteVar(old_var, new_var); - if (else_child) else_child->SubstituteVar(old_var, new_var); - if (task) task->SubstituteVar(old_var, new_var); + if (then_child) + then_child->SubstituteVar(old_var, new_var); + if (else_child) + else_child->SubstituteVar(old_var, new_var); + if (task) + task->SubstituteVar(old_var, new_var); } // Latency = max of both branches @@ -708,16 +717,22 @@ class IfNode : public IRStructure { // Setters (delegate to both branches) void SetUsesCUDACore(bool value) override { - if (then_child) then_child->SetUsesCUDACore(value); - if (else_child) else_child->SetUsesCUDACore(value); + if (then_child) + then_child->SetUsesCUDACore(value); + if (else_child) + else_child->SetUsesCUDACore(value); } void SetUsesTMACore(bool value) override { - if (then_child) then_child->SetUsesTMACore(value); - if (else_child) else_child->SetUsesTMACore(value); + if (then_child) + then_child->SetUsesTMACore(value); + if (else_child) + else_child->SetUsesTMACore(value); } void SetUsesTensorCore(bool value) override { - if (then_child) then_child->SetUsesTensorCore(value); - if (else_child) else_child->SetUsesTensorCore(value); + if (then_child) + then_child->SetUsesTensorCore(value); + if (else_child) + else_child->SetUsesTensorCore(value); } void SetReadRegions(const std::vector ®ions) override {} void SetWriteRegions(const std::vector ®ions) override {} diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 615fa7630c..3cb3e57039 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -661,9 +661,11 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, return check_contains_loop_break(wrapper->child.get()); } else if (structure->IsIf()) { auto if_node = static_cast(structure); - if (if_node->then_child && check_contains_loop_break(if_node->then_child.get())) + if (if_node->then_child && + check_contains_loop_break(if_node->then_child.get())) return true; - if (if_node->else_child && check_contains_loop_break(if_node->else_child.get())) + if (if_node->else_child && + check_contains_loop_break(if_node->else_child.get())) return true; } return false; @@ -784,12 +786,12 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, } } else if (structure->IsIf()) { auto if_node = static_cast(structure); - Stmt then_stmt = ConvertIRStructureToStmt(if_node->then_child.get(), - outer_enable_epi); + Stmt then_stmt = + ConvertIRStructureToStmt(if_node->then_child.get(), outer_enable_epi); Optional else_stmt; if (if_node->else_child) { - else_stmt = ConvertIRStructureToStmt(if_node->else_child.get(), - outer_enable_epi); + else_stmt = + ConvertIRStructureToStmt(if_node->else_child.get(), outer_enable_epi); } return IfThenElse(if_node->condition, then_stmt, else_stmt); } diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index d8b5e4a61a..db085eae0f 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -43,10 +43,9 @@ def allow_autoschedule(pass_ctx: PassContext | None = None, target: Target | Non if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() enable_autoschedule = pass_ctx.config.get("tl.enable_auto_schedule", False) - if enable_autoschedule and target is not None: + if enable_autoschedule and target is not None and target.kind.name != "cuda": # Auto-schedule only works on CUDA targets; skip on CPU - if target.kind.name != "cuda": - return False + return False # When TMA lowering is disabled, skip auto-schedule to avoid # rewriting copies to tma_copy that cannot be lowered. disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) From f12fb47c5a64b57b529e9261ff570f67ec30b994 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 15 Apr 2026 17:41:11 +0800 Subject: [PATCH 059/156] [Refactor] Move target gating into InjectFenceProxy pass entry (#2047) Check the PrimFunc's target inside the pass and no-op when the hardware lacks the TMA / async-proxy programming model (non-CUDA or pre-sm_90). Callers in engine/phase.py no longer need to branch on allow_fence_proxy / allow_warp_specialized before invoking the pass, so the redundant if/else is collapsed and the now-unused allow_fence_proxy helper is removed. --- src/transform/inject_fence_proxy.cc | 12 ++++++++++-- tilelang/engine/phase.py | 14 +++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/transform/inject_fence_proxy.cc b/src/transform/inject_fence_proxy.cc index 14f130e783..5ac24e7e91 100644 --- a/src/transform/inject_fence_proxy.cc +++ b/src/transform/inject_fence_proxy.cc @@ -19,6 +19,7 @@ #include "tir/transforms/ir_utils.h" #include "../op/builtin.h" +#include "../target/utils.h" namespace tvm { namespace tl { @@ -536,8 +537,15 @@ class ProxyFenceRewriter : public StmtExprMutator { tvm::transform::Pass InjectFenceProxy() { auto pass_func = [](PrimFunc f, const IRModule &, const PassContext &) { - f = ProxyFenceRewriter::Apply(f); - return f; + // fence.proxy.async is only meaningful on CUDA targets that expose the + // TMA / async-proxy programming model (sm_90+). On anything else the + // rewriter has no work to do, so skip it to keep the pipeline target- + // agnostic at its call sites. + auto target_opt = f->GetAttr(tvm::attr::kTarget); + if (!target_opt.defined() || !TargetHasBulkCopy(target_opt.value())) { + return f; + } + return ProxyFenceRewriter::Apply(f); }; return tir::transform::CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {}); diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 683307415b..9e74801565 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -28,10 +28,6 @@ def module_has_tma(mod: IRModule) -> bool: return any(func.attrs and func.attrs.get("tl.has_tma", False) for _, func in mod.functions.items()) -def allow_fence_proxy(target: Target | None = None) -> bool: - return have_tma(target) - - def allow_vectorize(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() @@ -284,13 +280,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # because the merged allocation site is at the beginning of each device function enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) - if allow_warp_specialized(pass_ctx=pass_ctx, target=target): - mod = tilelang.transform.InjectFenceProxy()(mod) - else: - if allow_fence_proxy(target=target): - # in hopper device, wgmma is an async proxy - # so we need to inject a fence proxy before it - mod = tilelang.transform.InjectFenceProxy()(mod) + # InjectFenceProxy is a no-op on targets that lack the TMA / async-proxy + # programming model; the pass itself checks the PrimFunc's target. + mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) mod = tilelang.transform.MergeIfStmt()(mod) From 7ae8d98114778387d637f05cac9e53e0c564958e Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Wed, 15 Apr 2026 18:00:18 +0800 Subject: [PATCH 060/156] fix loop break detection --- src/transform/auto_schedule.cc | 2 +- src/transform/auto_schedule/ir_structure.h | 34 +++++++++++++- .../auto_schedule/warpgroup_partition.cc | 45 +------------------ 3 files changed, 36 insertions(+), 45 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 3366af8b04..cd165b4991 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -642,7 +642,7 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { // Check if aggressive auto-schedule is enabled bool aggressive = - ctx->GetConfig(kEnableAggressiveAutoSchedule, Bool(true)).value(); + ctx->GetConfig(kEnableAggressiveAutoSchedule, Bool(false)).value(); // Build ScheduleUnits from IRStructure ScheduleUnitBuilder unit_builder; diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 560797ad3c..71d63e040a 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -126,6 +126,9 @@ class IRStructure { virtual bool containWarpgroupId(int id) const = 0; + // Check if this node (or any descendant) contains a loop_break call + virtual bool ContainsLoopBreak() const = 0; + // Start time for scheduling void SetStartTime(int64_t start_time) { start_time_ = start_time; } int64_t GetStartTime() const { return start_time_; } @@ -313,7 +316,7 @@ class TaskNode : public IRStructure { } // Check if this task contains loop_break call - bool ContainsLoopBreak() const; + bool ContainsLoopBreak() const override; private: // Resource usage flags @@ -477,6 +480,11 @@ class ControlNode : public IRStructure { return child && child->containWarpgroupId(id); } + bool ContainsLoopBreak() const override { + return (task && task->ContainsLoopBreak()) || + (child && child->ContainsLoopBreak()); + } + private: // Latency estimation int64_t latency_{0}; // Estimated latency in cycles @@ -581,6 +589,11 @@ class WrapperNode : public IRStructure { return child && child->containWarpgroupId(id); } + bool ContainsLoopBreak() const override { + return (task && task->ContainsLoopBreak()) || + (child && child->ContainsLoopBreak()); + } + private: // Latency estimation int64_t latency_{0}; // Estimated latency in cycles @@ -751,6 +764,12 @@ class IfNode : public IRStructure { (else_child && else_child->containWarpgroupId(id)); } + bool ContainsLoopBreak() const override { + return (task && task->ContainsLoopBreak()) || + (then_child && then_child->ContainsLoopBreak()) || + (else_child && else_child->ContainsLoopBreak()); + } + private: int64_t latency_{0}; int64_t ii_{0}; @@ -860,6 +879,10 @@ class ScheduleUnit : public IRStructure { return child && child->containWarpgroupId(id); } + bool ContainsLoopBreak() const override { + return child && child->ContainsLoopBreak(); + } + private: // Latency estimation int64_t latency_{0}; // Estimated latency in cycles @@ -923,6 +946,15 @@ class SequenceNode : public IRStructure { return false; } + bool ContainsLoopBreak() const override { + for (const auto &child : children) { + if (child->ContainsLoopBreak()) { + return true; + } + } + return false; + } + private: // Latency estimation int64_t latency_{0}; // Estimated latency in cycles diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 3cb3e57039..854ea86b1a 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -630,51 +630,10 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, } unit_stages[unit->stage - min_stages].push_back(SeqStmt::Flatten(stmts)); } - // Check if any task in this control node contains loop_break - // If any task contains loop_break, disable prologue - std::function check_contains_loop_break; - check_contains_loop_break = - [&check_contains_loop_break](IRStructure *structure) -> bool { - if (!structure) - return false; - - if (structure->IsTask()) { - auto task = static_cast(structure); - return task->ContainsLoopBreak(); - } else if (structure->IsSequence()) { - auto seq = static_cast(structure); - for (const auto &child : seq->children) { - auto unit = static_cast(child.get()); - if (check_contains_loop_break(unit->child.get())) { - return true; - } - } - return false; - } else if (structure->IsScheduleUnit()) { - auto unit = static_cast(structure); - return check_contains_loop_break(unit->child.get()); - } else if (structure->IsControl()) { - auto ctrl = static_cast(structure); - return check_contains_loop_break(ctrl->child.get()); - } else if (structure->IsWrapper()) { - auto wrapper = static_cast(structure); - return check_contains_loop_break(wrapper->child.get()); - } else if (structure->IsIf()) { - auto if_node = static_cast(structure); - if (if_node->then_child && - check_contains_loop_break(if_node->then_child.get())) - return true; - if (if_node->else_child && - check_contains_loop_break(if_node->else_child.get())) - return true; - } - return false; - }; - // Set enable_pro to true only if: - // 1. No task contains loop_break + // 1. No node contains loop_break // 2. Loop boundaries (min and extent) are constants - bool enable_pro = !check_contains_loop_break(ctrl->child.get()); + bool enable_pro = !(ctrl->child && ctrl->child->ContainsLoopBreak()); // Check if loop boundaries are constants bool loop_min_is_const = tir::is_const_int(loop_start); From dbadd7742e1be6f362843024004803cfb7283a48 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Wed, 15 Apr 2026 18:05:43 +0800 Subject: [PATCH 061/156] fix control node break & format --- src/transform/auto_schedule.cc | 3 +- src/transform/auto_schedule/barrier.h | 34 +++++++++++++--------- src/transform/auto_schedule/ir_structure.h | 3 +- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index cd165b4991..5f2b0e10f5 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -642,7 +642,8 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { // Check if aggressive auto-schedule is enabled bool aggressive = - ctx->GetConfig(kEnableAggressiveAutoSchedule, Bool(false)).value(); + ctx->GetConfig(kEnableAggressiveAutoSchedule, Bool(false)) + .value(); // Build ScheduleUnits from IRStructure ScheduleUnitBuilder unit_builder; diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 2251bde05f..de333c04cf 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -452,9 +452,11 @@ static void RewriteTaskNodeBuffers( } else if (node->IsIf()) { auto if_node = static_cast(node); if (if_node->then_child) - RewriteTaskNodeBuffers(if_node->then_child.get(), multi_buffer, iteration); + RewriteTaskNodeBuffers(if_node->then_child.get(), multi_buffer, + iteration); if (if_node->else_child) - RewriteTaskNodeBuffers(if_node->else_child.get(), multi_buffer, iteration); + RewriteTaskNodeBuffers(if_node->else_child.get(), multi_buffer, + iteration); } } @@ -580,16 +582,16 @@ AnalyzeAndInsertBarriers(IRStructure *node, int &next_barrier_id, } else if (node->IsIf()) { auto if_node = static_cast(node); if (if_node->then_child) { - AnalyzeAndInsertBarriers( - if_node->then_child.get(), next_barrier_id, barrier_buffers, - barrier_map, thread_count, loop_info, buffer_infos, - neutral_sync_shared_barrier); + AnalyzeAndInsertBarriers(if_node->then_child.get(), next_barrier_id, + barrier_buffers, barrier_map, thread_count, + loop_info, buffer_infos, + neutral_sync_shared_barrier); } if (if_node->else_child) { - AnalyzeAndInsertBarriers( - if_node->else_child.get(), next_barrier_id, barrier_buffers, - barrier_map, thread_count, loop_info, buffer_infos, - neutral_sync_shared_barrier); + AnalyzeAndInsertBarriers(if_node->else_child.get(), next_barrier_id, + barrier_buffers, barrier_map, thread_count, + loop_info, buffer_infos, + neutral_sync_shared_barrier); } } else if (node->IsTask()) { // For TaskNode, nothing to do at this level @@ -611,8 +613,10 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, for (auto &promote_child : seq->children) { auto task = static_cast(promote_child.get()); - if (task->child->IsSequence() || task->child->IsControl() || task->child->IsIf()) { - // If child is SequenceNode, ControlNode, or IfNode, recursively analyze it + if (task->child->IsSequence() || task->child->IsControl() || + task->child->IsIf()) { + // If child is SequenceNode, ControlNode, or IfNode, recursively analyze + // it AnalyzeAndInsertBarriers( task->child.get(), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); @@ -856,8 +860,10 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, std::vector ordered_tasks; for (auto &child : seq->children) { auto task = static_cast(child.get()); - if (task->child->IsSequence() || task->child->IsControl() || task->child->IsIf()) { - // If child is SequenceNode, ControlNode, or IfNode, recursively analyze it + if (task->child->IsSequence() || task->child->IsControl() || + task->child->IsIf()) { + // If child is SequenceNode, ControlNode, or IfNode, recursively analyze + // it AnalyzeAndInsertBarriers( task->child.get(), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 71d63e040a..0da1da9546 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -481,8 +481,7 @@ class ControlNode : public IRStructure { } bool ContainsLoopBreak() const override { - return (task && task->ContainsLoopBreak()) || - (child && child->ContainsLoopBreak()); + return false; // Loop does not contain loop break } private: From 7ed82669ec457fd07d43fdd15df41c7b9ba853e1 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Wed, 15 Apr 2026 18:49:37 +0800 Subject: [PATCH 062/156] Fix if task collection bug --- src/transform/auto_schedule.cc | 3 +-- src/transform/auto_schedule/ir_structure.cc | 4 ---- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 5f2b0e10f5..3366af8b04 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -642,8 +642,7 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { // Check if aggressive auto-schedule is enabled bool aggressive = - ctx->GetConfig(kEnableAggressiveAutoSchedule, Bool(false)) - .value(); + ctx->GetConfig(kEnableAggressiveAutoSchedule, Bool(true)).value(); // Build ScheduleUnits from IRStructure ScheduleUnitBuilder unit_builder; diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index 5305f4bbd8..e87460e59d 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -501,10 +501,6 @@ void CollectAllTaskNodesWithContext(IRStructure *node, } else if (node->IsIf()) { auto if_node = static_cast(node); // Recurse into both branches - if (if_node->task) { - CollectAllTaskNodesWithContext(if_node->task.get(), all_tasks, - current_control_node); - } CollectAllTaskNodesWithContext(if_node->then_child.get(), all_tasks, current_control_node); if (if_node->else_child) { From 235ad7eda4d64c1b37d8299bcceb705198782360 Mon Sep 17 00:00:00 2001 From: Nguyen Huy Hoang <24520554@gm.uit.edu.vn> Date: Wed, 15 Apr 2026 23:22:14 +0700 Subject: [PATCH 063/156] Add regression test for 1D TMA load compilation and execution (#1989) * refactor: add regression test for 1d tma load compilation and execution Signed-off-by: Nguyen Huy Hoang <181364121+huyhoang171106@users.noreply.github.com> * lint fix --------- Signed-off-by: Nguyen Huy Hoang <181364121+huyhoang171106@users.noreply.github.com> Co-authored-by: LeiWang1999 --- 3rdparty/tvm | 2 +- testing/python/tilelang/test_tma_load.py | 68 ++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 testing/python/tilelang/test_tma_load.py diff --git a/3rdparty/tvm b/3rdparty/tvm index 882a774844..fab43e41c0 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 882a774844993d103ae6e317ba3c7bbb5952b662 +Subproject commit fab43e41c004e888ded30d45df25ccc8e2612617 diff --git a/testing/python/tilelang/test_tma_load.py b/testing/python/tilelang/test_tma_load.py new file mode 100644 index 0000000000..02b82ee278 --- /dev/null +++ b/testing/python/tilelang/test_tma_load.py @@ -0,0 +1,68 @@ +import pytest +import torch + +import tilelang as tl +import tilelang.language as T + + +N = 4096 +BLOCK = 256 + + +def _get_device_capability() -> tuple[int, int]: + if not torch.cuda.is_available(): + return (0, 0) + return torch.cuda.get_device_capability() + + +def _extract_source(kernel) -> str: + if hasattr(kernel, "get_source"): + source = kernel.get_source() + if isinstance(source, str) and source: + return source + + module = getattr(kernel, "module", None) + if module is not None and hasattr(module, "imported_modules"): + imported = getattr(module, "imported_modules", []) + if imported: + source = imported[0].get_source() + if isinstance(source, str) and source: + return source + + runtime_mod = getattr(kernel, "rt_mod", None) + if runtime_mod is not None and hasattr(runtime_mod, "imported_modules"): + imported = getattr(runtime_mod, "imported_modules", []) + if imported: + source = imported[0].get_source() + if isinstance(source, str) and source: + return source + + raise RuntimeError("Unable to extract generated source from compiled kernel") + + +def _build_1d_tma_copy(): + @T.prim_func + def main(A: T.Buffer((N,), "float16"), B: T.Buffer((N,), "float16")): + with T.Kernel(T.ceildiv(N, BLOCK), threads=128) as bx: + A_shared = T.alloc_shared((BLOCK,), "float16") + T.copy(A[bx * BLOCK : (bx + 1) * BLOCK], A_shared) + T.copy(A_shared, B[bx * BLOCK : (bx + 1) * BLOCK]) + + return main + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.skipif(_get_device_capability()[0] < 9, reason="Hopper (sm90+) is required for TMA") +def test_tma_load_1d_compile_and_run_regression(): + program = _build_1d_tma_copy() + kernel = tl.compile(program, out_idx=[1], target="cuda -arch=sm_90a") + + source = _extract_source(kernel) + assert "cp.async.bulk.tensor" in source + assert ".1d" in source + + a = torch.randn((N,), device="cuda", dtype=torch.float16) + b = torch.empty_like(a) + + kernel(a, b) + torch.testing.assert_close(b, a, atol=0, rtol=0) From 891109ea879839bbfe2ad937db208f2d7c6f1ce0 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:53:22 +0800 Subject: [PATCH 064/156] [Transform] Add InjectTcgen05Fence pass (#2003) * [Transform] Add InjectTcgen05Fence pass for Blackwell (SM100+) On Blackwell GPUs, the tcgen05 accumulator (TMEM) resides in a separate address space that is not synchronized by regular thread barriers like __syncthreads() or mbarrier. Two PTX fence instructions are required to ensure cross-thread visibility of TMEM state: tcgen05.fence::before_thread_sync -- flush TMEM before barrier tcgen05.fence::after_thread_sync -- pull TMEM after barrier This commit introduces the `InjectTcgen05Fence` TIR pass that automatically wraps every `tvm_storage_sync("shared")` call with the fence pair when the target is SM100+ and the function uses tcgen05/TMEM operations. Changes: - Define two new TIR intrinsic Ops: `tcgen05_before_thread_sync` and `tcgen05_after_thread_sync` (builtin.h/cc) - Add codegen support to emit `tl::tcgen05_before_thread_sync()` and `tl::tcgen05_after_thread_sync()` (codegen_cuda.cc) - Implement the `InjectTcgen05Fence` pass (inject_tcgen05_fence.cc) - Register the pass in the Python transform module (__init__.py) - Insert the pass in OptimizeForTarget after ThreadSync (phase.py) * fix * Add conservative tcgen05 fence injection Extend InjectTcgen05Fence to cover shared storage sync and linear wait/use, use/arrive handoffs on SM100+ kernels, update the surrounding docs, and add transform coverage. Also lower tcgen05 ld/st copies through dedicated intrinsics instead of call_extern strings, with matching CUDA and CuTeDSL codegen support. * Use tcgen05 fence injection in SM100 examples Remove the hand-written tcgen05 before/after thread-sync fences from the ws and persistent GEMM examples now that InjectTcgen05Fence covers these linear wait/use and use/arrive handoffs. * lint fix * Drop tcgen05 call_extern fence compatibility Simplify InjectTcgen05Fence to recognize only the lowered tcgen05 intrinsics seen after LowerTileOp, and update the transform tests to construct tcgen05_ld directly instead of relying on legacy call_extern strings. --------- Co-authored-by: Freebase6912 --- examples/gemm_sm100/README.md | 10 + examples/gemm_sm100/gemm_tcgen5mma_ws.py | 4 - .../gemm_tcgen5mma_ws_persistent.py | 10 - src/op/builtin.cc | 18 +- src/op/builtin.h | 16 + src/op/copy.cc | 49 ++- src/target/codegen_cuda.cc | 27 ++ src/target/codegen_cutedsl.cc | 24 ++ src/transform/inject_tcgen05_fence.cc | 317 ++++++++++++++++++ ...tilelang_transform_inject_tcgen05_fence.py | 283 ++++++++++++++++ tilelang/engine/phase.py | 5 + tilelang/transform/__init__.py | 18 + 12 files changed, 746 insertions(+), 35 deletions(-) create mode 100644 src/transform/inject_tcgen05_fence.cc create mode 100644 testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py diff --git a/examples/gemm_sm100/README.md b/examples/gemm_sm100/README.md index e9490b8654..3ae66dde91 100644 --- a/examples/gemm_sm100/README.md +++ b/examples/gemm_sm100/README.md @@ -21,6 +21,16 @@ T.tcgen05_gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, clear_ac T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required ``` +TileLang now has a conservative `InjectTcgen05Fence` pass on SM100+ that can +insert `tcgen05_before_thread_sync()` / `tcgen05_after_thread_sync()` around: +- `tvm_storage_sync("shared"|"shared.dyn")` +- linear `mbarrier_wait_parity(...) -> tcgen05/TMEM use` regions +- linear `tcgen05/TMEM use -> mbarrier_arrive(...)` regions + +This does **not** eliminate the need to structure the mbarrier protocol +explicitly in user code, and the examples in this directory still keep manual +fences where they make the handoff points obvious. + ## Examples ### TCGEN5MMA Example (`gemm_tcgen5mma.py`) diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws.py b/examples/gemm_sm100/gemm_tcgen5mma_ws.py index ff147a2827..b8f2adf41a 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma_ws.py +++ b/examples/gemm_sm100/gemm_tcgen5mma_ws.py @@ -48,7 +48,6 @@ def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_ elif tx < 64: # warp 1: issue tcgen5 for k in T.serial(k_iters): T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1) - T.tcgen05_after_thread_sync() T.tcgen05_gemm( A_shared[k % num_stages, :, :], B_shared[k % num_stages, :, :], @@ -60,7 +59,6 @@ def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_ # Wait for all tcgen5 to finish T.mbarrier_wait_parity(tmem_full, 0) - T.tcgen05_after_thread_sync() T.copy(C_tmem, C_local) if use_tma_store: T.copy(C_local, C_shared) @@ -115,7 +113,6 @@ def gemm_2cta(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, elif cta_id == 0 and tx < 64: # Only warp 1 on leader cta issues tcgen5 for k in T.serial(k_iters): T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1) - T.tcgen05_after_thread_sync() T.tcgen05_gemm( A_shared[k % num_stages, :, :], B_shared[k % num_stages, :, :], @@ -128,7 +125,6 @@ def gemm_2cta(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, # Wait for all tcgen5 to finish T.mbarrier_wait_parity(tmem_full, 0) - T.tcgen05_after_thread_sync() T.copy(C_tmem, C_local) if use_tma_store: T.copy(C_local, C_shared) diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py b/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py index 82a82aaa1a..5a7d820220 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py +++ b/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py @@ -81,11 +81,9 @@ def gemm_persistent( if bx * block_M < M and by * block_N < N: T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1) - T.tcgen05_after_thread_sync() for k in T.serial(k_blocks): phase = w * k_blocks + k T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1) - T.tcgen05_after_thread_sync() if w & 1 == 0: T.tcgen05_gemm( A_shared[k % num_stages, :, :], @@ -116,13 +114,10 @@ def gemm_persistent( if bx * block_M < M and by * block_N < N: T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1) - T.tcgen05_after_thread_sync() - T.sync_threads(1, 128) if (w & 1) == 0: T.copy(C_tmem_0, C_local) else: T.copy(C_tmem_1, C_local) - T.tcgen05_before_thread_sync() T.mbarrier_arrive(tmem_empty[w & 1]) if use_tma_store: @@ -220,11 +215,9 @@ def gemm_persistent_2cta( if bx * block_M < M and by * block_N < N: T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1) - T.tcgen05_after_thread_sync() for k in T.serial(k_blocks): phase = w * k_blocks + k T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1) - T.tcgen05_after_thread_sync() if w & 1 == 0: T.tcgen05_gemm( A_shared[phase % num_stages, :, :], @@ -256,13 +249,10 @@ def gemm_persistent_2cta( if bx * block_M < M and by * block_N < N: T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1) - T.tcgen05_after_thread_sync() - T.sync_threads(1, 128) if (w & 1) == 0: T.copy(C_tmem_0, C_local) else: T.copy(C_tmem_1, C_local) - T.tcgen05_before_thread_sync() T.mbarrier_arrive(tmem_empty[w & 1], 0) if use_tma_store: diff --git a/src/op/builtin.cc b/src/op/builtin.cc index b95ec04360..1ccbf03f9e 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -262,6 +262,16 @@ TIR_DEFINE_TL_BUILTIN(fence_proxy_async) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(tcgen05_before_thread_sync) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tcgen05_after_thread_sync) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(tma_store_arrive) .set_num_inputs(0) .set_attr("TCallEffectKind", @@ -552,13 +562,13 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(tcgen05_before_thread_sync) - .set_num_inputs(0) +TIR_DEFINE_TL_BUILTIN(tcgen05_ld) + .set_num_inputs(6) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(tcgen05_after_thread_sync) - .set_num_inputs(0) +TIR_DEFINE_TL_BUILTIN(tcgen05_st) + .set_num_inputs(6) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/builtin.h b/src/op/builtin.h index 6e71b2446a..93f9dad82a 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -874,6 +874,22 @@ TVM_DLL const Op &initialize_tcgen05_descriptor(); */ TVM_DLL const Op &tcgen05_mma_arrive(); +/*! + * \brief tilelang intrinsic for lowered TCGEN05 tensor-memory load. + * + * Internal lowering op used by LowerTmemCopy to represent + * `tl::tcgen05_ld_*` calls without routing through `call_extern`. + */ +TVM_DLL const Op &tcgen05_ld(); + +/*! + * \brief tilelang intrinsic for lowered TCGEN05 tensor-memory store. + * + * Internal lowering op used by LowerTmemCopy to represent + * `tl::tcgen05_st_*` calls without routing through `call_extern`. + */ +TVM_DLL const Op &tcgen05_st(); + /*! * \brief TCGEN05 fence before a thread-block-wide sync (__syncthreads / * bar.sync). Matches PTX \c tcgen05.fence::before_thread_sync (DeepGEMM / diff --git a/src/op/copy.cc b/src/op/copy.cc index 2caa948b48..2c2eb24f53 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1486,7 +1486,6 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, // unpack::16b) so MMA TS reads correctly packed bf16 from TMEM columns. // For tcgen05_ld, pack::16b is still needed when reading unpacked data. bool use_pack_unpack_modifier = is_ld ? needs_pack_unpack : false; - const char *bool_str = use_pack_unpack_modifier ? "true" : "false"; int effective_chunks = needs_pack_unpack ? num_chunks_each_wg / 2 : num_chunks_each_wg; PrimExpr relative_wg_idx = @@ -1497,22 +1496,38 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, : relative_wg_idx * (effective_chunks * meta.width); have_succeeded = true; Array args; - args.push_back(StringImm(meta.intrinsics_name + "<" + - std::to_string(effective_chunks) + ", " + - bool_str + ">")); - args.push_back( - BufferLoad(tmem_buf, {(int)logical_row_min, - (int)logical_col_min})); // Will be translated - // later in - // lower_shared_tmem - // pass - args.push_back(col_offset); - int reg_access_mode = is_ld ? 2 : 1; - args.push_back(reg_buf.access_ptr(reg_access_mode, DataType::Handle(), 1, - 0, PrimExpr(tmem_phy_col_extent))); - - Stmt call = - Evaluate(Call(DataType::Handle(), builtin::call_extern(), args)); + Stmt call; + if (is_ld) { + args.push_back(IntImm(DataType::Int(32), meta.width * 32)); + args.push_back(IntImm(DataType::Int(32), effective_chunks)); + args.push_back(Bool(use_pack_unpack_modifier)); + args.push_back( + BufferLoad(tmem_buf, {(int)logical_row_min, + (int)logical_col_min})); // Will be translated + // later in + // lower_shared_tmem + // pass + args.push_back(col_offset); + args.push_back(reg_buf.access_ptr(/*access_mask=*/2, DataType::Handle(), + /*content_lanes=*/1, /*offset=*/0, + PrimExpr(tmem_phy_col_extent))); + call = Evaluate(Call(DataType::Handle(), tcgen05_ld(), args)); + } else { + args.push_back(IntImm(DataType::Int(32), meta.width * 32)); + args.push_back(IntImm(DataType::Int(32), effective_chunks)); + args.push_back(Bool(use_pack_unpack_modifier)); + args.push_back( + BufferLoad(tmem_buf, {(int)logical_row_min, + (int)logical_col_min})); // Will be translated + // later in + // lower_shared_tmem + // pass + args.push_back(col_offset); + int reg_access_mode = 1; + args.push_back(reg_buf.access_ptr(reg_access_mode, DataType::Handle(), + 1, 0, PrimExpr(tmem_phy_col_extent))); + call = Evaluate(Call(DataType::Handle(), tcgen05_st(), args)); + } if (num_useful_threads != num_threads) { body = IfThenElse(T.thread_var < T.thread_bounds->min + num_useful_threads, diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 77eb42c8c4..90eff80e00 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2767,6 +2767,33 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { replacer.register_rule("(mask3)", mask3); tcgen05_call = replacer.rewrite(tcgen05_call); this->stream << tcgen05_call; + } else if (op->op.same_as(tl::tcgen05_ld())) { + ICHECK_EQ(op->args.size(), 6U) << "tcgen05_ld expects 6 arguments"; + need_tcgen05_common_h_ = true; + int inst_bits = Downcast(op->args[0])->value; + int chunks = Downcast(op->args[1])->value; + bool pack16 = Downcast(op->args[2])->value; + std::string tmem_start_col = this->PrintExpr(op->args[3]); + std::string col_offset = this->PrintExpr(op->args[4]); + std::string dst_ptr = this->PrintExpr(op->args[5]); + this->PrintIndent(); + this->stream << "tl::tcgen05_ld_32dp" << inst_bits << "bNx<" << chunks + << ", " << (pack16 ? "true" : "false") << ">(" + << tmem_start_col << ", " << col_offset << ", " << dst_ptr + << ");\n"; + } else if (op->op.same_as(tl::tcgen05_st())) { + ICHECK_EQ(op->args.size(), 6U) << "tcgen05_st expects 6 arguments"; + int inst_bits = Downcast(op->args[0])->value; + int chunks = Downcast(op->args[1])->value; + bool unpack16 = Downcast(op->args[2])->value; + std::string tmem_start_col = this->PrintExpr(op->args[3]); + std::string col_offset = this->PrintExpr(op->args[4]); + std::string src_ptr = this->PrintExpr(op->args[5]); + this->PrintIndent(); + this->stream << "tl::tcgen05_st_32dp" << inst_bits << "bNx<" << chunks + << ", " << (unpack16 ? "true" : "false") << ">(" + << tmem_start_col << ", " << col_offset << ", " << src_ptr + << ");\n"; } else if (op->op.same_as(tl::tcgen05_mma_arrive())) { ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument"; need_tcgen05_common_h_ = true; diff --git a/src/target/codegen_cutedsl.cc b/src/target/codegen_cutedsl.cc index d0e44e089b..f495489bc4 100644 --- a/src/target/codegen_cutedsl.cc +++ b/src/target/codegen_cutedsl.cc @@ -899,6 +899,30 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op, << "[0] + " << c_offset << ", " << desc_val << ", " << scale_out << ", " << mask0 << ", " << mask1 << ", " << mask2 << ", " << mask3 << ")\n"; + } else if (op->op.same_as(tl::tcgen05_ld())) { + ICHECK_EQ(op->args.size(), 6U) << "tcgen05_ld expects 6 arguments"; + int inst_bits = Downcast(op->args[0])->value; + int chunks = Downcast(op->args[1])->value; + bool pack16 = Downcast(op->args[2])->value; + std::string tmem_start_col = PrintExpr_(op->args[3]); + std::string col_offset = PrintExpr_(op->args[4]); + std::string dst_ptr = PrintExpr_(op->args[5]); + PrintIndent(); + stream << "tl.tcgen05_ld_32dp" << inst_bits << "bNx(" << chunks << ", " + << (pack16 ? "True" : "False") << ", " << tmem_start_col << ", " + << col_offset << ", " << dst_ptr << ")\n"; + } else if (op->op.same_as(tl::tcgen05_st())) { + ICHECK_EQ(op->args.size(), 6U) << "tcgen05_st expects 6 arguments"; + int inst_bits = Downcast(op->args[0])->value; + int chunks = Downcast(op->args[1])->value; + bool unpack16 = Downcast(op->args[2])->value; + std::string tmem_start_col = PrintExpr_(op->args[3]); + std::string col_offset = PrintExpr_(op->args[4]); + std::string src_ptr = PrintExpr_(op->args[5]); + PrintIndent(); + stream << "tl.tcgen05_st_32dp" << inst_bits << "bNx(" << chunks << ", " + << (unpack16 ? "True" : "False") << ", " << tmem_start_col << ", " + << col_offset << ", " << src_ptr << ")\n"; } else if (op->op.same_as(tl::tcgen05_mma_arrive())) { ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument"; PrintIndent(); diff --git a/src/transform/inject_tcgen05_fence.cc b/src/transform/inject_tcgen05_fence.cc new file mode 100644 index 0000000000..0573378a8a --- /dev/null +++ b/src/transform/inject_tcgen05_fence.cc @@ -0,0 +1,317 @@ +/*! + * \file inject_tcgen05_fence.cc + * \brief Inject tcgen05.fence::before_thread_sync / after_thread_sync at + * conservative TCGEN05/TMEM synchronization boundaries on Blackwell + * (SM100+) targets. + * + * On Blackwell, the tcgen05 accumulator (TMEM) lives in its own address + * space. Regular thread synchronization barriers (__syncthreads, mbarrier) + * do NOT automatically make TMEM writes visible across threads. Two PTX + * fence instructions bridge this gap: + * + * tcgen05.fence::before_thread_sync — flush TMEM state before barrier + * tcgen05.fence::after_thread_sync — pull TMEM state after barrier + * + * This pass currently handles three patterns when the function targets SM100+ + * and contains tcgen05/TMEM operations: + * + * 1. Wrap every tvm_storage_sync("shared") / ("shared.dyn") with + * before+after fences. + * 2. Insert after_thread_sync after mbarrier_wait_parity when a linear + * scan of following statements reaches tcgen05/TMEM use before another + * synchronization boundary. + * 3. Insert before_thread_sync before ptx_arrive_barrier / + * ptx_arrive_cluster_barrier when a linear reverse scan reaches + * tcgen05/TMEM use before another synchronization boundary. + * + * It intentionally does not add an extra before_thread_sync around + * tcgen05_mma_arrive(), because the underlying tcgen05.commit.*.mbarrier + * already carries the producer-side ordering. + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/builtin.h" +#include "../target/utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using tvm::transform::PassContext; + +namespace { + +/*! + * \brief Check if a call is tvm_storage_sync("shared") or + * tvm_storage_sync("shared.dyn"). + */ +bool IsSharedStorageSync(const CallNode *call) { + if (!call || !call->op.same_as(builtin::tvm_storage_sync())) { + return false; + } + if (call->args.empty()) + return false; + const auto *scope = call->args[0].as(); + if (!scope) + return false; + return scope->value == "shared" || scope->value == "shared.dyn"; +} + +bool IsMbarrierWaitParity(const CallNode *call) { + return call && call->op.same_as(mbarrier_wait_parity()); +} + +bool IsPlainBarrierArrive(const CallNode *call) { + return call && (call->op.same_as(builtin::ptx_arrive_barrier()) || + call->op.same_as(ptx_arrive_cluster_barrier())); +} + +bool IsBeforeFenceCall(const CallNode *call) { + return call && call->op.same_as(tcgen05_before_thread_sync()); +} + +bool IsAfterFenceCall(const CallNode *call) { + return call && call->op.same_as(tcgen05_after_thread_sync()); +} + +const CallNode *GetEvaluateCall(const Stmt &stmt) { + if (const auto *eval = stmt.as()) { + return eval->value.as(); + } + return nullptr; +} + +bool IsTcgen05OrTmemCall(const CallNode *call) { + if (!call || IsBeforeFenceCall(call) || IsAfterFenceCall(call)) { + return false; + } + + return call->op.same_as(ptx_tcgen05_mma_ss()) || + call->op.same_as(ptx_tcgen05_mma_ts()) || + call->op.same_as(tcgen05_ld()) || call->op.same_as(tcgen05_st()) || + call->op.same_as(tcgen05_mma_arrive()) || + call->op.same_as(ptx_init_tensor_memory()) || + call->op.same_as(ptx_deallocate_tensor_memory()); +} + +bool StmtUsesTcgen05OrTmem(const Stmt &stmt) { + bool found = false; + PostOrderVisit(stmt, [&](const ObjectRef &node) { + if (found) { + return; + } + if (const auto *call = node.as()) { + found = IsTcgen05OrTmemCall(call); + } + }); + return found; +} + +bool IsBeforeFenceStmt(const Stmt &stmt) { + return IsBeforeFenceCall(GetEvaluateCall(stmt)); +} + +bool IsAfterFenceStmt(const Stmt &stmt) { + return IsAfterFenceCall(GetEvaluateCall(stmt)); +} + +bool IsFenceSyncBoundary(const CallNode *call) { + return IsSharedStorageSync(call) || IsMbarrierWaitParity(call) || + IsPlainBarrierArrive(call) || + (call && call->op.same_as(tcgen05_mma_arrive())); +} + +bool HasUpcomingTcgen05Use(const Array &seq, int start_index) { + for (int i = start_index + 1; i < static_cast(seq.size()); ++i) { + const Stmt &stmt = seq[i]; + if (IsAfterFenceStmt(stmt)) { + return false; + } + if (StmtUsesTcgen05OrTmem(stmt)) { + return true; + } + if (IsBeforeFenceStmt(stmt) || IsFenceSyncBoundary(GetEvaluateCall(stmt))) { + return false; + } + } + return false; +} + +bool HasPriorTcgen05Use(const Array &seq, int start_index) { + for (int i = start_index - 1; i >= 0; --i) { + const Stmt &stmt = seq[i]; + if (IsBeforeFenceStmt(stmt)) { + return false; + } + if (StmtUsesTcgen05OrTmem(stmt)) { + return true; + } + if (IsAfterFenceStmt(stmt) || IsFenceSyncBoundary(GetEvaluateCall(stmt))) { + return false; + } + } + return false; +} + +/*! + * \brief Check whether the function body contains any tcgen05 / TMEM + * operations that warrant fence insertion. + */ +bool HasTcgen05Operations(const Stmt &body) { + return StmtUsesTcgen05OrTmem(body); +} + +inline Stmt MakeBeforeFenceStmt() { + return Evaluate(Call(DataType::Void(), tcgen05_before_thread_sync(), {})); +} + +inline Stmt MakeAfterFenceStmt() { + return Evaluate(Call(DataType::Void(), tcgen05_after_thread_sync(), {})); +} + +inline void AppendFlattened(Array *out, const Stmt &stmt) { + if (!stmt.defined()) { + return; + } + if (const auto *seq = stmt.as()) { + for (const Stmt &inner : seq->seq) { + out->push_back(inner); + } + return; + } + out->push_back(stmt); +} + +/*! + * \brief Rewriter for conservative TCGEN05/TMEM handoff points. + * + * Supported rewrites: + * + * tcgen05_before_thread_sync(); + * __syncthreads(); // tvm_storage_sync("shared") + * tcgen05_after_thread_sync(); + * + * mbarrier_wait_parity(...); + * tcgen05_after_thread_sync(); // when the subsequent linear region uses + * // tcgen05/TMEM before another sync point + * + * tcgen05_before_thread_sync(); // when the prior linear region used + * ptx_arrive_barrier(...); // tcgen05/TMEM after the previous sync + */ +class Tcgen05FenceRewriter : public StmtExprMutator { +public: + Stmt VisitStmt_(const SeqStmtNode *op) final { + bool saved_in_seq = in_seq_rewrite_; + in_seq_rewrite_ = true; + + Array mutated_children; + for (const Stmt &stmt : op->seq) { + mutated_children.push_back(VisitStmt(stmt)); + } + + in_seq_rewrite_ = saved_in_seq; + + Array flat_seq; + for (const Stmt &stmt : mutated_children) { + AppendFlattened(&flat_seq, stmt); + } + + Array rewritten; + for (int i = 0; i < static_cast(flat_seq.size()); ++i) { + const Stmt &stmt = flat_seq[i]; + const CallNode *call = GetEvaluateCall(stmt); + + if (IsSharedStorageSync(call)) { + if (i == 0 || !IsBeforeFenceStmt(flat_seq[i - 1])) { + rewritten.push_back(MakeBeforeFenceStmt()); + } + rewritten.push_back(stmt); + if (i + 1 >= static_cast(flat_seq.size()) || + !IsAfterFenceStmt(flat_seq[i + 1])) { + rewritten.push_back(MakeAfterFenceStmt()); + } + continue; + } + + if (IsMbarrierWaitParity(call)) { + rewritten.push_back(stmt); + bool has_manual_after = i + 1 < static_cast(flat_seq.size()) && + IsAfterFenceStmt(flat_seq[i + 1]); + if (!has_manual_after && HasUpcomingTcgen05Use(flat_seq, i)) { + rewritten.push_back(MakeAfterFenceStmt()); + } + continue; + } + + if (IsPlainBarrierArrive(call)) { + bool has_manual_before = i > 0 && IsBeforeFenceStmt(flat_seq[i - 1]); + if (!has_manual_before && HasPriorTcgen05Use(flat_seq, i)) { + rewritten.push_back(MakeBeforeFenceStmt()); + } + rewritten.push_back(stmt); + continue; + } + + rewritten.push_back(stmt); + } + + if (rewritten.size() == 1) { + return rewritten[0]; + } + return SeqStmt(std::move(rewritten)); + } + + Stmt VisitStmt_(const EvaluateNode *op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + if (in_seq_rewrite_) { + return stmt; + } + const auto *call = GetEvaluateCall(stmt); + if (IsSharedStorageSync(call)) { + return SeqStmt( + {MakeBeforeFenceStmt(), std::move(stmt), MakeAfterFenceStmt()}); + } + return stmt; + } + +private: + bool in_seq_rewrite_{false}; +}; + +} // namespace + +tvm::transform::Pass InjectTcgen05Fence() { + auto pass_func = [](PrimFunc f, const IRModule &, const PassContext &) { + // Only apply on SM100+ (Blackwell) targets. + Optional opt_target = f->GetAttr(tvm::attr::kTarget); + if (!opt_target.defined() || + !TargetHasSMVersionGE(opt_target.value(), 100)) { + return f; + } + // Only apply if the function actually uses tcgen05 / TMEM operations. + if (!HasTcgen05Operations(f->body)) { + return f; + } + Tcgen05FenceRewriter rewriter; + f.CopyOnWrite()->body = rewriter(f->body); + return f; + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, + "tl.InjectTcgen05Fence", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.InjectTcgen05Fence", InjectTcgen05Fence); +} + +} // namespace tl +} // namespace tvm diff --git a/testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py b/testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py new file mode 100644 index 0000000000..41672084e1 --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py @@ -0,0 +1,283 @@ +# ruff: noqa +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +from tilelang.engine.phase import LowerAndLegalize +from tvm import tir + + +sm100_target = tvm.target.Target("cuda -arch=sm_100") +sm90_target = tvm.target.Target("cuda -arch=sm_90a") + + +def _apply(func, target=sm100_target): + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tl.transform.InjectTcgen05Fence()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + return mod + + +def _check(original, expected, target=sm100_target): + mod = _apply(original, target) + expected_mod = tvm.IRModule.from_expr(expected.with_attr("global_symbol", "main")) + expected_mod = tvm.tir.transform.BindTarget(target)(expected_mod) + expected_mod = tir.transform.LowerOpaqueBlock()(expected_mod) + tvm.ir.assert_structural_equal(mod["main"], expected_mod["main"], True) + + +def _count_calls(stmt, op_name: str): + count = 0 + + def visitor(node): + nonlocal count + if isinstance(node, tir.Call) and hasattr(node, "op") and hasattr(node.op, "name") and node.op.name == op_name: + count += 1 + + tir.stmt_functor.post_order_visit(stmt, visitor) + return count + + +def _count_extern_calls_with_prefix(stmt, prefix: str): + count = 0 + + def visitor(node): + nonlocal count + if not isinstance(node, tir.Call): + return + op = getattr(node, "op", None) + if getattr(op, "name", None) != "tir.call_extern": + return + if not node.args: + return + name = node.args[0] + if isinstance(name, tir.StringImm) and name.value.startswith(prefix): + count += 1 + + tir.stmt_functor.post_order_visit(stmt, visitor) + return count + + +def _tcgen05_ld_call(tmem_ref, local_buf): + return T.call_intrin( + "handle", + tir.op.Op.get("tl.tcgen05_ld"), + 32, + 128, + False, + tmem_ref, + 0, + T.tvm_access_ptr(T.type_annotation(T.float32), local_buf.data, 0, 128, 2), + ) + + +def test_storage_sync_is_wrapped_with_tcgen05_fences(): + @T.prim_func + def before(): + with T.Kernel(1): + C_tmem = T.decl_buffer((1,), T.uint32, scope="shared") + C_local = T.decl_buffer((128,), T.float32, scope="local") + T.tvm_storage_sync("shared") + T.evaluate(_tcgen05_ld_call(C_tmem[0], C_local)) + + @T.prim_func + def after(): + with T.Kernel(1): + C_tmem = T.decl_buffer((1,), T.uint32, scope="shared") + C_local = T.decl_buffer((128,), T.float32, scope="local") + T.tcgen05_before_thread_sync() + T.tvm_storage_sync("shared") + T.tcgen05_after_thread_sync() + T.evaluate(_tcgen05_ld_call(C_tmem[0], C_local)) + + _check(before, after) + + +def test_lower_tmem_copy_uses_tcgen05_ld_intrin(): + @T.prim_func + def func(X: T.Tensor((256, 256), T.float16), Y: T.Tensor((256, 256), T.float16)): + with T.Kernel(1, 1, threads=128) as (bx, by): + A_shared = T.alloc_shared((128, 128), T.float16) + B_shared = T.alloc_shared((128, 128), T.float16) + C_tmem = T.alloc_tmem([128, 128], T.float32) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((128, 128), T.float32) + T.copy(X[0, 0], A_shared) + T.copy(X[0, 0], B_shared) + T.tcgen05_gemm( + A_shared, + B_shared, + C_tmem, + transpose_B=True, + mbar=mbar, + clear_accum=True, + ) + T.mbarrier_wait_parity(mbar, 0) + T.copy(C_tmem, C_local) + T.copy(C_local, Y[0, 0]) + + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + with sm100_target: + mod = LowerAndLegalize(mod, sm100_target) + mod = tl.transform.LowerSharedTmem()(mod) + + body = mod["main"].body + assert _count_calls(body, "tl.tcgen05_ld") == 1 + assert _count_extern_calls_with_prefix(body, "tl::tcgen05_ld_") == 0 + + +def test_lower_tmem_copy_uses_tcgen05_st_intrin(): + @T.prim_func + def func(X: T.Tensor((256, 256), T.bfloat16)): + with T.Kernel(1, 1, threads=128) as (bx, by): + A_shared = T.alloc_shared((128, 128), T.bfloat16) + B_shared = T.alloc_shared((128, 128), T.bfloat16) + S_tmem = T.alloc_tmem([128, 128], T.float32) + mbar = T.alloc_barrier(1) + S_local = T.alloc_fragment((128, 128), T.float32) + P_local = T.alloc_fragment((128, 128), T.bfloat16) + P_tmem = T.alloc_tmem([128, 128], T.bfloat16) + B2_shared = T.alloc_shared((128, 128), T.bfloat16) + D_tmem = T.alloc_tmem([128, 128], T.float32) + mbar2 = T.alloc_barrier(1) + T.copy(X[0, 0], A_shared) + T.copy(X[0, 0], B_shared) + T.tcgen05_gemm( + A_shared, + B_shared, + S_tmem, + transpose_B=True, + mbar=mbar, + clear_accum=True, + ) + T.mbarrier_wait_parity(mbar, 0) + T.copy(S_tmem, S_local) + T.copy(S_local, P_local) + T.copy(P_local, P_tmem) + T.copy(X[0, 0], B2_shared) + T.tcgen05_gemm( + P_tmem, + B2_shared, + D_tmem, + transpose_B=True, + mbar=mbar2, + clear_accum=True, + ) + + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + with sm100_target: + mod = LowerAndLegalize(mod, sm100_target) + mod = tl.transform.LowerSharedTmem()(mod) + + body = mod["main"].body + assert _count_calls(body, "tl.tcgen05_st") == 1 + assert _count_extern_calls_with_prefix(body, "tl::tcgen05_st_") == 0 + + +def test_wait_and_arrive_are_rewritten_only_at_tmem_handoffs(): + @T.prim_func + def before(): + with T.Kernel(1): + mbarrier = T.decl_buffer((1,), T.uint64, scope="shared.barrier") + C_tmem = T.decl_buffer((1,), T.uint32, scope="shared") + C_local = T.decl_buffer((128,), T.float32, scope="local") + T.mbarrier_wait_parity(mbarrier[0], 0) + T.evaluate(_tcgen05_ld_call(C_tmem[0], C_local)) + T.ptx_arrive_barrier(mbarrier[0]) + + @T.prim_func + def after(): + with T.Kernel(1): + mbarrier = T.decl_buffer((1,), T.uint64, scope="shared.barrier") + C_tmem = T.decl_buffer((1,), T.uint32, scope="shared") + C_local = T.decl_buffer((128,), T.float32, scope="local") + T.mbarrier_wait_parity(mbarrier[0], 0) + T.tcgen05_after_thread_sync() + T.evaluate(_tcgen05_ld_call(C_tmem[0], C_local)) + T.tcgen05_before_thread_sync() + T.ptx_arrive_barrier(mbarrier[0]) + + _check(before, after) + + +def test_wait_and_arrive_scan_across_neutral_statements(): + @T.prim_func + def before(): + with T.Kernel(1): + mbarrier = T.decl_buffer((1,), T.uint64, scope="shared.barrier") + C_tmem = T.decl_buffer((1,), T.uint32, scope="shared") + C_local = T.decl_buffer((128,), T.float32, scope="local") + T.mbarrier_wait_parity(mbarrier[0], 0) + T.call_extern("handle", "generic_op") + T.evaluate(_tcgen05_ld_call(C_tmem[0], C_local)) + T.call_extern("handle", "generic_op") + T.ptx_arrive_barrier(mbarrier[0]) + + @T.prim_func + def after(): + with T.Kernel(1): + mbarrier = T.decl_buffer((1,), T.uint64, scope="shared.barrier") + C_tmem = T.decl_buffer((1,), T.uint32, scope="shared") + C_local = T.decl_buffer((128,), T.float32, scope="local") + T.mbarrier_wait_parity(mbarrier[0], 0) + T.tcgen05_after_thread_sync() + T.call_extern("handle", "generic_op") + T.evaluate(_tcgen05_ld_call(C_tmem[0], C_local)) + T.call_extern("handle", "generic_op") + T.tcgen05_before_thread_sync() + T.ptx_arrive_barrier(mbarrier[0]) + + _check(before, after) + + +def test_sync_boundary_stops_wait_lookahead(): + @T.prim_func + def func(): + with T.Kernel(1): + mbarrier = T.decl_buffer((1,), T.uint64, scope="shared.barrier") + C_tmem = T.decl_buffer((1,), T.uint32, scope="shared") + C_local = T.decl_buffer((128,), T.float32, scope="local") + T.mbarrier_wait_parity(mbarrier[0], 0) + T.call_extern("handle", "generic_op") + T.ptx_arrive_barrier(mbarrier[0]) + T.evaluate(_tcgen05_ld_call(C_tmem[0], C_local)) + + mod = _apply(func) + assert _count_calls(mod["main"].body, "tl.tcgen05_after_thread_sync") == 0 + + +def test_existing_manual_fences_are_not_duplicated(): + @T.prim_func + def func(): + with T.Kernel(1): + mbarrier = T.decl_buffer((1,), T.uint64, scope="shared.barrier") + C_tmem = T.decl_buffer((1,), T.uint32, scope="shared") + C_local = T.decl_buffer((128,), T.float32, scope="local") + T.mbarrier_wait_parity(mbarrier[0], 0) + T.tcgen05_after_thread_sync() + T.evaluate(_tcgen05_ld_call(C_tmem[0], C_local)) + T.tcgen05_before_thread_sync() + T.ptx_arrive_barrier(mbarrier[0]) + + mod = _apply(func) + body = mod["main"].body + assert _count_calls(body, "tl.tcgen05_after_thread_sync") == 1 + assert _count_calls(body, "tl.tcgen05_before_thread_sync") == 1 + + +def test_non_sm100_targets_are_left_untouched(): + @T.prim_func + def func(): + with T.Kernel(1): + C_tmem = T.decl_buffer((1,), T.uint32, scope="shared") + C_local = T.decl_buffer((128,), T.float32, scope="local") + T.tvm_storage_sync("shared") + T.evaluate(_tcgen05_ld_call(C_tmem[0], C_local)) + + mod = _apply(func, sm90_target) + assert _count_calls(mod["main"].body, "tl.tcgen05_before_thread_sync") == 0 + assert _count_calls(mod["main"].body, "tl.tcgen05_after_thread_sync") == 0 + + +if __name__ == "__main__": + tl.testing.main() diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 9e74801565..028de25991 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -285,6 +285,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) + # Inject conservative tcgen05 fences on Blackwell (SM100+). + # Must run after ThreadSync so that tvm_storage_sync calls are present. + # The pass handles shared syncs and simple linear wait/use, use/arrive + # handoffs, and is a no-op on non-SM100 targets or functions without TMEM. + mod = tilelang.transform.InjectTcgen05Fence()(mod) mod = tilelang.transform.MergeIfStmt()(mod) # NOTE: LowerPTXAsyncCopy is applied earlier (before PipelinePlanning). if allow_warp_specialized(pass_ctx=pass_ctx, target=target): diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index ba851b5e03..a7a9414d71 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -279,6 +279,24 @@ def InjectFenceProxy(): return _ffi_api.InjectFenceProxy() # type: ignore +def InjectTcgen05Fence(): + """Inject tcgen05.fence::before_thread_sync / after_thread_sync at + conservative TCGEN05/TMEM synchronization boundaries on Blackwell + (SM100+) targets. + + The current pass wraps CTA-wide shared-memory syncs and also inserts + fences around linear mbarrier wait/use and use/arrive handoff patterns. + It is intentionally conservative and does not try to infer arbitrary + barrier protocols. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectTcgen05Fence() # type: ignore + + def LegalizeVectorizedLoop(): """LegalizeLoopVectorize From 8f67446ea5d2502a13ce642ced85f7042d620a0a Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 16 Apr 2026 12:55:42 +0800 Subject: [PATCH 065/156] fix loop break --- .../auto_schedule/schedule_builder.cc | 28 ++++++------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index b52064c857..76e41f7831 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -111,16 +111,10 @@ bool SameBuffer(const BufferRegion &a, const BufferRegion &b) { bool SameVar(const Var &a, const Var &b) { return a.same_as(b); } bool HasDependency(const IRStructure *a, const IRStructure *b) { - if (a->IsTask()) { - const TaskNode *task_a = static_cast(a); - if (task_a->ContainsLoopBreak()) - return true; - } - if (b->IsTask()) { - const TaskNode *task_b = static_cast(b); - if (task_b->ContainsLoopBreak()) - return true; - } + if (a->ContainsLoopBreak()) + return true; + if (b->ContainsLoopBreak()) + return true; for (const auto &write_region_a : a->GetWriteRegions()) { for (const auto &read_region_b : b->GetReadRegions()) { if (SameBuffer(write_region_a, read_region_b)) @@ -147,16 +141,10 @@ bool HasDependency(const IRStructure *a, const IRStructure *b) { } bool HasRegisterDependency(const IRStructure *a, const IRStructure *b) { - if (a->IsTask()) { - const TaskNode *task_a = static_cast(a); - if (task_a->ContainsLoopBreak()) - return true; - } - if (b->IsTask()) { - const TaskNode *task_b = static_cast(b); - if (task_b->ContainsLoopBreak()) - return true; - } + if (a->ContainsLoopBreak()) + return true; + if (b->ContainsLoopBreak()) + return true; for (const auto &write_region_a : a->GetWriteRegions()) { if (IsSharedBuffer(write_region_a.get()->buffer)) continue; From 844d04e82b4ed557668f2524f2481d43c7a75578 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 16 Apr 2026 14:13:37 +0800 Subject: [PATCH 066/156] fix tma load detection --- src/transform/auto_schedule.cc | 50 ++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 3366af8b04..284b1d3551 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -54,6 +54,7 @@ #include #include "../op/builtin.h" +#include "../op/copy.h" #include "../op/gemm_py.h" #include "../target/utils.h" #include "./common/attr.h" @@ -392,34 +393,35 @@ class IRStructureBuilder : public StmtVisitor { } // Check if this is a TMA copy operation - if (op->op.same_as(copy_op)) { - bool found_global = false, found_shared = false; - int idx_global = -1, idx_shared = -1; - for (unsigned idx = 0; idx != 2; ++idx) { - auto region = Downcast(op->args[idx]); - if (const auto *buffer_load = - region->args[0].as()) { - Buffer buffer = buffer_load->buffer; - String scope = buffer.scope(); - MemoryType mem_type = GetMemoryTypeFromScope(scope); - if (mem_type == MemoryType::kGlobal) { - found_global = true; - idx_global = idx; - } - if (mem_type == MemoryType::kShared) { - found_shared = true; - idx_shared = idx; - } - } - } - found_tma = false; - if (found_global && found_shared) { - if (idx_global == 0 && idx_shared == 1) { + static const auto tma_copy_op = Op::Get("tl.tileop.tma_copy"); + static const auto async_copy_op = Op::Get("tl.tileop.async_copy"); + + bool is_copy_like = op->op.same_as(copy_op) || + op->op.same_as(tma_copy_op) || + op->op.same_as(async_copy_op); + + if (is_copy_like) { + Copy copy_obj(op->args, op->annotations); + const CopyNode *copy = copy_obj.get(); + + if (copy->GetIsAsyncCopy()) { + // T.async_copy() — cp.async path, never TMA. + } else if (copy->GetIsTmaCopy()) { + // Explicit T.tma_copy(): only valid global->shared TMA loads + // are producers; TMA stores stay on the consumer side. + arith::Analyzer ana; + if (copy->CheckBulkLoad(target, &ana, /*check_last_dim=*/false)) { found_tma = true; found_tma_load = true; } - if (idx_global == 1 && idx_shared == 0) + } else { + // Generic T.copy(): check if TMA is possible. + arith::Analyzer ana; + if (!copy->GetDisableTMA() && + copy->CheckBulkLoad(target, &ana, /*check_last_dim=*/true)) { found_tma = true; + found_tma_load = true; + } } } else if (op->op.same_as(gemm_py_op) || op->op.same_as(gemm_op) || op->op.same_as(wgmma_gemm_py_op) || From 45f5da4203aaae5686fa1c9880ad44194a0663df Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 16 Apr 2026 14:14:13 +0800 Subject: [PATCH 067/156] fix pro/epilogue identification --- src/transform/auto_schedule/barrier.h | 41 ++++++++++++++----- src/transform/auto_schedule/ir_structure.cc | 7 +++- src/transform/auto_schedule/ir_structure.h | 40 +++++++++++++++++- .../auto_schedule/schedule_builder.cc | 8 ++++ .../auto_schedule/warpgroup_partition.cc | 28 +++++++++---- 5 files changed, 101 insertions(+), 23 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index de333c04cf..0b8688cd18 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -638,7 +638,7 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, bool found_wgmma = false; for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1) + if (wg_id == -1 || region_access.schedule_phase != SchedulePhase::kBody) continue; auto ®ion = region_access.region; if (IsRegisterRegion(region)) { @@ -654,7 +654,7 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, } else { for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1) + if (wg_id == -1 || region_access.schedule_phase != SchedulePhase::kBody) continue; auto ®ion = region_access.region; if (IsRegisterRegion(region)) { @@ -683,7 +683,7 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, auto child = static_cast(task->child.get()); if (child->is_TCGEN05()) { int wg_id = child->GetWarpgroupId(); - if (wg_id != -1) { + if (!child->IsNeutralPhase()) { int barrier_id = next_barrier_id++; // Create a single barrier buffer with shape (1,) Buffer barrier_buffer = makeBarrierBuffer( @@ -710,7 +710,13 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, auto child = static_cast(task->child.get()); if (child->HasTMALoad()) { int wg_id = child->GetWarpgroupId(); - if (wg_id != -1) { + LOG(INFO) << "[DEBUG] AnalyzeSequenceNodeBarriers: TMA load found, wg_id=" + << wg_id << " has_tma_load=" << child->HasTMALoad() + << " phase=" << static_cast(child->GetSchedulePhase()); + for (auto &stmt : child->stmts) { + LOG(INFO) << "[DEBUG] stmt: " << stmt; + } + if (!child->IsNeutralPhase()) { int barrier_id = next_barrier_id++; Buffer barrier_buffer = makeBarrierBuffer( thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), @@ -722,6 +728,8 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, Stmt arrive_stmt = makeBarrierArrive(barrier_load); InsertStatementIntoScheduleUnit(task, arrive_stmt, false, wg_id); } else { + LOG(INFO) << "[DEBUG] AnalyzeSequenceNodeBarriers: TMA load with wg_id=-1 (NEUTRAL), " + << "reusing neutral_sync_shared_barrier for ALL neutral TMA loads!"; PrimExpr barrier_load = BufferLoad(neutral_sync_shared_barrier, {0}); RewriteCopyMbar(child, barrier_load); } @@ -754,13 +762,20 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, bool is_async = is_async_task(task); for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; + if (region_access.schedule_phase != SchedulePhase::kBody) { + LOG(INFO) << "[DEBUG] Barrier dep analysis: SKIPPING region with phase=" + << static_cast(region_access.schedule_phase) << " wg_id=" << wg_id + << " buffer=" << region_access.region->buffer->name + << " is_write=" << region_access.is_write; + continue; + } if (wg_id == -1) continue; if (region_access.region->buffer != buffer) continue; auto insert_barrier = [&](ScheduleUnit *last_task, int last_wg_id) { - if (last_wg_id == -1) + if (last_wg_id == -1 && last_task->IsNeutralPhase()) return; bool last_async = is_async_task(last_task); if (last_wg_id == wg_id && !last_async) @@ -806,6 +821,8 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, } for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; + if (region_access.schedule_phase != SchedulePhase::kBody) + continue; if (wg_id == -1) continue; if (region_access.region->buffer != buffer) @@ -957,7 +974,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, bool found_wgmma = false; for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1) + if (wg_id == -1 || region_access.schedule_phase != SchedulePhase::kBody) continue; auto ®ion = region_access.region; if (IsRegisterRegion(region)) { @@ -973,7 +990,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, } else { for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1) + if (wg_id == -1 || region_access.schedule_phase != SchedulePhase::kBody) continue; auto ®ion = region_access.region; if (IsRegisterRegion(region)) { @@ -1012,7 +1029,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, } } int wg_id = child->GetWarpgroupId(); - ICHECK(wg_id != -1) << "TCGEN05MMA must have valid warpgroup id"; + ICHECK(!child->IsNeutralPhase()) << "TCGEN05MMA must not be in prologue/epilogue"; int barrier_id = next_barrier_id++; // Create a single barrier buffer with shape (num_versions,) @@ -1046,7 +1063,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, } } int wg_id = child->GetWarpgroupId(); - ICHECK(wg_id != -1) << "TMA loads must have valid warpgroup id"; + ICHECK(!child->IsNeutralPhase()) << "TMA loads in pipeline must not be in prologue/epilogue"; int barrier_id = next_barrier_id++; Buffer barrier_buffer = makeBarrierBuffer( @@ -1095,13 +1112,15 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, bool is_async = is_async_task(task); for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; + if (region_access.schedule_phase != SchedulePhase::kBody) + continue; if (wg_id == -1) continue; if (region_access.region->buffer != buffer) continue; auto insert_barrier = [&](ScheduleUnit *last_task, int last_wg_id) { - if (last_wg_id == -1) + if (last_wg_id == -1 && last_task->IsNeutralPhase()) return; if (last_task == task) // ??? return; @@ -1175,6 +1194,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, if (iter == 0) { for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; + if (region_access.schedule_phase != SchedulePhase::kBody) + continue; if (wg_id == -1) continue; if (region_access.region->buffer != buffer) diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index e87460e59d..811bab9689 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -260,6 +260,8 @@ std::shared_ptr TaskNode::Clone() const { new_task->SetStartTime(GetStartTime()); // Copy warpgroup id new_task->SetWarpgroupId(GetWarpgroupId()); + // Copy scheduling phase + new_task->SetSchedulePhase(GetSchedulePhase()); // Copy loop_break cache new_task->contains_loop_break_cache_ = contains_loop_break_cache_; return new_task; @@ -269,12 +271,13 @@ void TaskNode::CollectRegions( std::vector &result, std::set>> &visited) const { int wg_id = GetWarpgroupId(); + SchedulePhase phase = GetSchedulePhase(); // Collect write regions for (const auto ®ion : GetWriteRegions()) { auto key = std::make_pair(region->buffer, std::make_pair(true, wg_id)); if (visited.find(key) == visited.end()) { visited.insert(key); - result.emplace_back(region, true, wg_id); + result.emplace_back(region, true, wg_id, phase); } } // Collect read regions @@ -282,7 +285,7 @@ void TaskNode::CollectRegions( auto key = std::make_pair(region->buffer, std::make_pair(false, wg_id)); if (visited.find(key) == visited.end()) { visited.insert(key); - result.emplace_back(region, false, wg_id); + result.emplace_back(region, false, wg_id, phase); } } } diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 0da1da9546..89957969ac 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -31,14 +31,24 @@ class SequenceNode; class WrapperNode; class IfNode; +// Scheduling phase: separates "when does this task run" from "which warpgroup" +enum class SchedulePhase : uint8_t { + kBody = 0, // Normal body task - participates in warpgroup partition + kPrologue = 1, // Runs on ALL threads BEFORE warpgroup-specific code + kEpilogue = 2, // Runs on ALL threads AFTER warpgroup-specific code +}; + // Structure to store region access information with warpgroup id struct RegionAccessInfo { BufferRegion region; bool is_write; // true for write, false for read int warpgroup_id; // warpgroup id of the innermost TaskNode + SchedulePhase schedule_phase{SchedulePhase::kBody}; // scheduling phase - RegionAccessInfo(BufferRegion region, bool is_write, int warpgroup_id) - : region(region), is_write(is_write), warpgroup_id(warpgroup_id) {} + RegionAccessInfo(BufferRegion region, bool is_write, int warpgroup_id, + SchedulePhase phase = SchedulePhase::kBody) + : region(region), is_write(is_write), warpgroup_id(warpgroup_id), + schedule_phase(phase) {} }; // Helper function to compare if two regions are equal @@ -124,6 +134,11 @@ class IRStructure { // Get warpgroup id for this node (-1 if not applicable) virtual int GetWarpgroupId() const { return -1; } + // Get scheduling phase for this node + virtual SchedulePhase GetSchedulePhase() const { return SchedulePhase::kBody; } + // Convenience: true if this node is prologue or epilogue (not body) + virtual bool IsNeutralPhase() const { return GetSchedulePhase() != SchedulePhase::kBody; } + virtual bool containWarpgroupId(int id) const = 0; // Check if this node (or any descendant) contains a loop_break call @@ -194,6 +209,13 @@ class TaskNode : public IRStructure { void SetWarpgroupId(int warpgroup_id) { warpgroup_id_ = warpgroup_id; } int GetWarpgroupId() const override { return warpgroup_id_; } + // Scheduling phase (prologue / body / epilogue) + void SetSchedulePhase(SchedulePhase phase) { schedule_phase_ = phase; } + SchedulePhase GetSchedulePhase() const override { return schedule_phase_; } + bool IsNeutralPhase() const override { + return schedule_phase_ != SchedulePhase::kBody; + } + // TMA load flag void SetHasTMALoad(bool value) { has_tma_load_ = value; } bool HasTMALoad() const { return has_tma_load_; } @@ -336,6 +358,8 @@ class TaskNode : public IRStructure { int64_t ii_{0}; // Initiation interval in cycles int warpgroup_id_{ -1}; // Warpgroup id for warpgroup specialization (-1 means unassigned) + SchedulePhase schedule_phase_{ + SchedulePhase::kBody}; // Scheduling phase (prologue/body/epilogue) // TMA information bool has_tma_load_{false}; @@ -870,6 +894,14 @@ class ScheduleUnit : public IRStructure { const TaskNode *task = static_cast(child.get()); return task->GetWarpgroupId(); } + SchedulePhase GetSchedulePhase() const override { + if (!isInnerTask()) + return SchedulePhase::kBody; + return static_cast(child.get())->GetSchedulePhase(); + } + bool IsNeutralPhase() const override { + return GetSchedulePhase() != SchedulePhase::kBody; + } // Clone method std::shared_ptr Clone() const override; @@ -1300,6 +1332,10 @@ inline void PrintIRStructure(const IRStructure *node, int indent = 0) { LOG(INFO) << indent_str << " latency: " << task->GetLatency() << " cycles"; LOG(INFO) << indent_str << " II: " << task->GetII() << " cycles"; LOG(INFO) << indent_str << " warpgroup_id: " << task->GetWarpgroupId(); + LOG(INFO) << indent_str << " schedule_phase: " << static_cast(task->GetSchedulePhase()) + << (task->GetSchedulePhase() == SchedulePhase::kPrologue ? " (prologue)" + : task->GetSchedulePhase() == SchedulePhase::kEpilogue ? " (epilogue)" + : " (body)"); } else if (node->IsControl()) { const ControlNode *control = static_cast(node); LOG(INFO) << indent_str << "ControlNode (For loop):"; diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 76e41f7831..eaa23cdedd 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -375,9 +375,15 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, std::unordered_set prefix_tasks; CollectPrefixTasks(root, prefix_tasks); + for (auto *task : prefix_tasks) { + task->SetSchedulePhase(SchedulePhase::kPrologue); + } std::unordered_set suffix_tasks; CollectSuffixTasks(root, all_tasks, uf, suffix_tasks); + for (auto *task : suffix_tasks) { + task->SetSchedulePhase(SchedulePhase::kEpilogue); + } std::unordered_map> components; for (int i = 0; i < n; i++) { @@ -716,6 +722,7 @@ NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, CollectPrefixTasks(root, prefix_tasks); for (auto *task : prefix_tasks) { task->SetWarpgroupId(-1); + task->SetSchedulePhase(SchedulePhase::kPrologue); } int n = all_tasks.size(); @@ -731,6 +738,7 @@ NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, CollectSuffixTasks(root, all_tasks, uf, suffix_tasks); for (auto *task : suffix_tasks) { task->SetWarpgroupId(-1); + task->SetSchedulePhase(SchedulePhase::kEpilogue); } // no double_thread in naive mode diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 854ea86b1a..ea4dec5a94 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -643,7 +643,17 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, enable_pro = false; } - bool enable_epi = outer_enable_epi && enable_pro; + bool enable_epi = enable_pro; + + // Read num_stages from loop annotation for prologue extent + int num_stages_annotation = max_stages - min_stages; + auto num_stages_val = ctrl->control.get()->annotations.Get("num_stages"); + if (num_stages_val.has_value()) { + num_stages_annotation = num_stages_val.value().cast()->value; + } + int prologue_extent = 2 * num_stages_annotation; + int epilogue_extent = max_stages - min_stages; + std::vector steady; for (auto &child : seq->children) { @@ -706,10 +716,10 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, new_for.CopyOnWrite()->loop_var = pro; new_for.CopyOnWrite()->kind = ForKind::kUnrolled; new_for.CopyOnWrite()->extent = - min(max_stages - min_stages, for_op.get()->extent); - for_op.CopyOnWrite()->min += loop_step * (max_stages - min_stages); + min(prologue_extent, for_op.get()->extent); + for_op.CopyOnWrite()->min += loop_step * prologue_extent; for_op.CopyOnWrite()->extent = - max(0, for_op.get()->extent - (max_stages - min_stages)); + max(0, for_op.get()->extent - prologue_extent); prologue = Substitute(new_for, sub); } Stmt epilogue = Evaluate(0); @@ -722,11 +732,11 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, new_for.CopyOnWrite()->kind = ForKind::kUnrolled; new_for.CopyOnWrite()->min = for_op.get()->min + - loop_step * (for_op.get()->extent - (max_stages - min_stages)); + loop_step * (for_op.get()->extent - epilogue_extent); new_for.CopyOnWrite()->extent = - min(max_stages - min_stages, for_op.get()->extent); + min(epilogue_extent, for_op.get()->extent); for_op.CopyOnWrite()->extent = - max(0, for_op.get()->extent - (max_stages - min_stages)); + max(0, for_op.get()->extent - epilogue_extent); epilogue = Substitute(new_for, sub); } return SeqStmt({prologue, for_op, epilogue}); @@ -801,7 +811,7 @@ Stmt ApplyWarpgroupPartitionToIRStructure( if (node->IsTask()) { auto task = static_cast(node); - if (task->GetWarpgroupId() == -1) { + if (task->IsNeutralPhase()) { return task->Clone(); } else { auto new_task = std::make_shared(); @@ -934,7 +944,7 @@ Stmt ApplyWarpgroupPartitionToIRStructure( CollectAllTaskNodesWithContext(unit->child.get(), child_tasks); auto &info = child_infos[i]; for (const auto &task : child_tasks) { - if (task.task->GetWarpgroupId() >= 0) { + if (!task.task->IsNeutralPhase()) { info.all_neutral = false; for (const auto &wr : task.task->GetWriteRegions()) wg_write_buffers.insert(wr->buffer.get()); From 9970ac9b21f6a139252d3a5faed710908349b801 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 16 Apr 2026 16:09:59 +0800 Subject: [PATCH 068/156] remove debug output & format --- src/transform/auto_schedule/barrier.h | 27 ++++++++-------------- src/transform/auto_schedule/ir_structure.h | 25 ++++++++++++-------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 0b8688cd18..f9ea1eae65 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -638,7 +638,8 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, bool found_wgmma = false; for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1 || region_access.schedule_phase != SchedulePhase::kBody) + if (wg_id == -1 || + region_access.schedule_phase != SchedulePhase::kBody) continue; auto ®ion = region_access.region; if (IsRegisterRegion(region)) { @@ -710,12 +711,6 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, auto child = static_cast(task->child.get()); if (child->HasTMALoad()) { int wg_id = child->GetWarpgroupId(); - LOG(INFO) << "[DEBUG] AnalyzeSequenceNodeBarriers: TMA load found, wg_id=" - << wg_id << " has_tma_load=" << child->HasTMALoad() - << " phase=" << static_cast(child->GetSchedulePhase()); - for (auto &stmt : child->stmts) { - LOG(INFO) << "[DEBUG] stmt: " << stmt; - } if (!child->IsNeutralPhase()) { int barrier_id = next_barrier_id++; Buffer barrier_buffer = makeBarrierBuffer( @@ -728,8 +723,6 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, Stmt arrive_stmt = makeBarrierArrive(barrier_load); InsertStatementIntoScheduleUnit(task, arrive_stmt, false, wg_id); } else { - LOG(INFO) << "[DEBUG] AnalyzeSequenceNodeBarriers: TMA load with wg_id=-1 (NEUTRAL), " - << "reusing neutral_sync_shared_barrier for ALL neutral TMA loads!"; PrimExpr barrier_load = BufferLoad(neutral_sync_shared_barrier, {0}); RewriteCopyMbar(child, barrier_load); } @@ -763,10 +756,6 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; if (region_access.schedule_phase != SchedulePhase::kBody) { - LOG(INFO) << "[DEBUG] Barrier dep analysis: SKIPPING region with phase=" - << static_cast(region_access.schedule_phase) << " wg_id=" << wg_id - << " buffer=" << region_access.region->buffer->name - << " is_write=" << region_access.is_write; continue; } if (wg_id == -1) @@ -974,7 +963,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, bool found_wgmma = false; for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1 || region_access.schedule_phase != SchedulePhase::kBody) + if (wg_id == -1 || + region_access.schedule_phase != SchedulePhase::kBody) continue; auto ®ion = region_access.region; if (IsRegisterRegion(region)) { @@ -990,7 +980,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, } else { for (const auto ®ion_access : task->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1 || region_access.schedule_phase != SchedulePhase::kBody) + if (wg_id == -1 || + region_access.schedule_phase != SchedulePhase::kBody) continue; auto ®ion = region_access.region; if (IsRegisterRegion(region)) { @@ -1029,7 +1020,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, } } int wg_id = child->GetWarpgroupId(); - ICHECK(!child->IsNeutralPhase()) << "TCGEN05MMA must not be in prologue/epilogue"; + ICHECK(!child->IsNeutralPhase()) + << "TCGEN05MMA must not be in prologue/epilogue"; int barrier_id = next_barrier_id++; // Create a single barrier buffer with shape (num_versions,) @@ -1063,7 +1055,8 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, } } int wg_id = child->GetWarpgroupId(); - ICHECK(!child->IsNeutralPhase()) << "TMA loads in pipeline must not be in prologue/epilogue"; + ICHECK(!child->IsNeutralPhase()) + << "TMA loads in pipeline must not be in prologue/epilogue"; int barrier_id = next_barrier_id++; Buffer barrier_buffer = makeBarrierBuffer( diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 89957969ac..24fc94c87f 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -33,9 +33,9 @@ class IfNode; // Scheduling phase: separates "when does this task run" from "which warpgroup" enum class SchedulePhase : uint8_t { - kBody = 0, // Normal body task - participates in warpgroup partition - kPrologue = 1, // Runs on ALL threads BEFORE warpgroup-specific code - kEpilogue = 2, // Runs on ALL threads AFTER warpgroup-specific code + kBody = 0, // Normal body task - participates in warpgroup partition + kPrologue = 1, // Runs on ALL threads BEFORE warpgroup-specific code + kEpilogue = 2, // Runs on ALL threads AFTER warpgroup-specific code }; // Structure to store region access information with warpgroup id @@ -135,9 +135,13 @@ class IRStructure { virtual int GetWarpgroupId() const { return -1; } // Get scheduling phase for this node - virtual SchedulePhase GetSchedulePhase() const { return SchedulePhase::kBody; } + virtual SchedulePhase GetSchedulePhase() const { + return SchedulePhase::kBody; + } // Convenience: true if this node is prologue or epilogue (not body) - virtual bool IsNeutralPhase() const { return GetSchedulePhase() != SchedulePhase::kBody; } + virtual bool IsNeutralPhase() const { + return GetSchedulePhase() != SchedulePhase::kBody; + } virtual bool containWarpgroupId(int id) const = 0; @@ -1332,10 +1336,13 @@ inline void PrintIRStructure(const IRStructure *node, int indent = 0) { LOG(INFO) << indent_str << " latency: " << task->GetLatency() << " cycles"; LOG(INFO) << indent_str << " II: " << task->GetII() << " cycles"; LOG(INFO) << indent_str << " warpgroup_id: " << task->GetWarpgroupId(); - LOG(INFO) << indent_str << " schedule_phase: " << static_cast(task->GetSchedulePhase()) - << (task->GetSchedulePhase() == SchedulePhase::kPrologue ? " (prologue)" - : task->GetSchedulePhase() == SchedulePhase::kEpilogue ? " (epilogue)" - : " (body)"); + LOG(INFO) << indent_str << " schedule_phase: " + << static_cast(task->GetSchedulePhase()) + << (task->GetSchedulePhase() == SchedulePhase::kPrologue + ? " (prologue)" + : task->GetSchedulePhase() == SchedulePhase::kEpilogue + ? " (epilogue)" + : " (body)"); } else if (node->IsControl()) { const ControlNode *control = static_cast(node); LOG(INFO) << indent_str << "ControlNode (For loop):"; From 1660c855ebc05a6857e3db1edd548821eb4028ec Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 16 Apr 2026 16:17:43 +0800 Subject: [PATCH 069/156] reimplement barrier logic --- src/transform/auto_schedule/barrier.h | 752 ++++++++++---------------- 1 file changed, 294 insertions(+), 458 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index de333c04cf..b610951ccf 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -121,18 +121,11 @@ struct LoopNestingInfo { } PrimExpr CalculateIterationCount() const { - ICHECK(!loop_vars.empty()); PrimExpr total_iter = IntImm(DataType::Int(32), 0); PrimExpr total_multiplier = IntImm(DataType::Int(32), 1); - - // Build expression: outer_var * inner_tripcount + inner_var - // For nested loops: (((outer_var * inner_tripcount) + inner_var) * - // innermost_step) + ... - for (int i = loop_vars.size() - 1; i >= 0; i--) { - // Calculate normalized iteration: (loop_var - start) / step + for (size_t i = loop_vars.size(); i-- > 0;) { PrimExpr normalized_iter = indexdiv(loop_vars[i] - loop_starts[i], loop_steps[i]); - if (i == static_cast(loop_vars.size()) - 1) { total_iter = normalized_iter; } else { @@ -195,15 +188,10 @@ static Stmt makeBarrierArrive(PrimExpr barrier_expr, int cta_id = -1, Call(DataType::Handle(), builtin::ptx_arrive_barrier(), args)); } -static Stmt makeTcgen05MmaArrive(Buffer barrier_buffer) { - Array access_ptr_args; - access_ptr_args.push_back(tir::TypeAnnotation(DataType::UInt(64))); - access_ptr_args.push_back(barrier_buffer->data); - access_ptr_args.push_back(barrier_buffer->elem_offset); - access_ptr_args.push_back(IntImm(DataType::Int(32), 1)); - access_ptr_args.push_back(IntImm(DataType::Int(32), 3)); - auto access_ptr = - Call(DataType::Handle(), builtin::tvm_access_ptr(), access_ptr_args); +static Stmt makeTcgen05MmaArrive(Buffer barrier_buffer, + PrimExpr offset = IntImm(DataType::Int(32), + 0)) { + auto access_ptr = barrier_buffer.access_ptr(3, DataType::Handle(), 1, offset); return Evaluate(Call(DataType::Handle(), tcgen05_mma_arrive(), {access_ptr})); } @@ -600,225 +588,317 @@ AnalyzeAndInsertBarriers(IRStructure *node, int &next_barrier_id, } } -static void -AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, - std::vector &barrier_buffers, - Map &barrier_map, - const std::vector &thread_count, - LoopNestingInfo &loop_info, - std::vector &buffer_infos, - Buffer neutral_sync_shared_barrier) { - if (!seq) - return; - - for (auto &promote_child : seq->children) { - auto task = static_cast(promote_child.get()); - if (task->child->IsSequence() || task->child->IsControl() || - task->child->IsIf()) { - // If child is SequenceNode, ControlNode, or IfNode, recursively analyze - // it - AnalyzeAndInsertBarriers( - task->child.get(), next_barrier_id, barrier_buffers, barrier_map, - thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); +static auto +GetSyncInfos(const std::vector &units, int num_wgs, + const std::unordered_map &buffer_num_versions = {}, + bool is_loop = false) { + std::set buffers; + for (auto *unit : units) { + for (const auto ®ion_access : unit->GetReadWriteRegions()) { + buffers.insert(region_access.region->buffer); } } - - size_t num_wgs = thread_count.size(); - - // Insert wait_wgmma - std::vector> last_wgmma_map(num_wgs); - std::vector wait_wgmma_id(num_wgs, 0); - std::vector total_wgmma(num_wgs, 0); - - for (auto &promote_child : seq->children) { - auto task = static_cast(promote_child.get()); - if (task->isInnerTask() && task->UsesTensorCore()) { - auto child = static_cast(task->child.get()); - if (child->is_WGMMA()) { - bool found_wgmma = false; - for (const auto ®ion_access : task->GetReadWriteRegions()) { + std::map, + std::pair, int>>> + sync_infos; + for (const auto &buffer : buffers) { + int num_versions = 1; + auto it = buffer_num_versions.find(buffer); + if (it != buffer_num_versions.end()) { + num_versions = it->second; + } + std::vector last_read_unit(num_wgs, nullptr); + ScheduleUnit *last_write_unit = nullptr; + int last_write_wg_id = -1; + std::vector waited_write_wgs(num_wgs, false); + for (int iter = 0; iter < (is_loop ? 2 : 1); ++iter) { + for (ScheduleUnit *unit : units) { + for (const auto ®ion_access : unit->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; if (wg_id == -1) continue; - auto ®ion = region_access.region; - if (IsRegisterRegion(region)) { - Buffer buffer = region->buffer; - if (!found_wgmma) { - found_wgmma = true; - ++total_wgmma[wg_id]; + if (region_access.region->buffer != buffer) + continue; + auto add_sync = [&](ScheduleUnit *wait_unit, int wait_wg_id) { + int distance = iter ? num_versions : 0; + auto &[barrier_versions, wait_map] = + sync_infos[{wait_unit, wait_wg_id}]; + barrier_versions = std::max(barrier_versions, num_versions); + auto it = wait_map.find({unit, wg_id}); + if (it == wait_map.end()) { + wait_map[{unit, wg_id}] = distance; + } else { + it->second = std::min(it->second, distance); + } + }; + if (!region_access.is_write) { + if (last_write_unit == nullptr) + continue; + if (waited_write_wgs[wg_id]) + continue; + add_sync(last_write_unit, last_write_wg_id); + } else { + for (int last_wg_id = 0; last_wg_id < num_wgs; ++last_wg_id) { + if (last_read_unit[last_wg_id] == nullptr) + continue; + add_sync(last_read_unit[last_wg_id], last_wg_id); } - last_wgmma_map[wg_id][buffer] = total_wgmma[wg_id]; } } - } - } else { - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - auto ®ion = region_access.region; - if (IsRegisterRegion(region)) { - Buffer buffer = region->buffer; - auto it = last_wgmma_map[wg_id].find(buffer); - if (it == last_wgmma_map[wg_id].end()) - continue; - if (it->second <= wait_wgmma_id[wg_id]) - continue; - wait_wgmma_id[wg_id] = it->second; - Stmt wait_stmt = Evaluate(Call(DataType::Handle(), wait_wgmma(), - {total_wgmma[wg_id] - it->second})); - InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); + if (iter == 0) { + for (const auto ®ion_access : unit->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) + continue; + if (region_access.region->buffer != buffer) + continue; + if (!region_access.is_write) { + waited_write_wgs[wg_id] = true; + } else { + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + last_read_unit[wg_id] = nullptr; + } + } + } + for (const auto ®ion_access : unit->GetReadWriteRegions()) { + int wg_id = region_access.warpgroup_id; + if (wg_id == -1) + continue; + if (region_access.region->buffer != buffer) + continue; + if (!region_access.is_write) { + last_read_unit[wg_id] = unit; + } else { + last_write_unit = unit; + last_write_wg_id = wg_id; + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + waited_write_wgs[wg_id] = false; + } + } + } } } } } + return sync_infos; +} - // Map (ScheduleUnit, warpgroup_id) to barrier buffer - std::map, Buffer> barrier_unit_map; - - // Allocate barriers for TCGEN05MMA - for (auto &promote_child : seq->children) { - auto task = static_cast(promote_child.get()); - if (task->isInnerTask() && task->UsesTensorCore()) { - auto child = static_cast(task->child.get()); - if (child->is_TCGEN05()) { - int wg_id = child->GetWarpgroupId(); - if (wg_id != -1) { - int barrier_id = next_barrier_id++; - // Create a single barrier buffer with shape (1,) - Buffer barrier_buffer = makeBarrierBuffer( - 1, "tcgen05_barrier_" + std::to_string(barrier_id), 1, - barrier_buffers, barrier_map); - barrier_unit_map[std::make_pair(task, wg_id)] = barrier_buffer; - - // Rewrite the gemm call's mbar argument (arg[16]) to use - // BufferLoad(barrier_buffer, {0}) - PrimExpr mbar_expr = BufferLoad(barrier_buffer, {0}); - RewriteGemmMbar(child, mbar_expr); - } else { - PrimExpr mbar_expr = BufferLoad(neutral_sync_shared_barrier, {0}); - RewriteGemmMbar(child, mbar_expr); - } +static void InsertSynchronization( + const std::vector &units, + const std::map< + std::pair, + std::pair, int>>> + &sync_infos, + int &next_barrier_id, std::vector &barrier_buffers, + Map &barrier_map, + const std::vector &thread_count, LoopNestingInfo &loop_info) { + std::map unit_to_order; + for (size_t i = 0; i < units.size(); ++i) { + unit_to_order[units[i]] = i; + } + // Initiate WGMMA tracking structures + int num_wgs = thread_count.size(); + std::vector wgmma_count(num_wgs, 0); + std::vector> wgmma_id(num_wgs); + for (auto unit : units) { + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + wgmma_id[wg_id][unit] = wgmma_count[wg_id]; + } + if (unit->HasWGMMA() && unit->isInnerTask()) { + int wg_id = static_cast(unit->child.get())->GetWarpgroupId(); + if (wg_id != -1) { + ++wgmma_count[wg_id]; + } else { + LOG(FATAL) << "WGMMA task without valid warpgroup id"; } } } - - // Allocate barriers for TMA - for (auto &promote_child : seq->children) { - auto task = static_cast(promote_child.get()); - if (task->isInnerTask() && task->UsesTMACore()) { - auto child = static_cast(task->child.get()); - if (child->HasTMALoad()) { - int wg_id = child->GetWarpgroupId(); - if (wg_id != -1) { + // Insert synchronization statements based on sync_infos + for (auto unit : units) { + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + auto sync_it = sync_infos.find({unit, wg_id}); + if (sync_it == sync_infos.end()) + continue; + const auto &wait_map = sync_it->second.second; + bool is_async = unit->UsesTMACore() || unit->UsesTensorCore(); + // Handle WGMMA synchronization + if (unit->HasWGMMA()) { + bool different_wg_id = false; + for (const auto &[waiting_unit_info, distance] : wait_map) { + auto [waiting_unit, waiting_wg_id] = waiting_unit_info; + if (waiting_wg_id != wg_id) { + different_wg_id = true; + break; + } + } + if (!different_wg_id) { + for (const auto &[waiting_unit_info, distance] : wait_map) { + auto [waiting_unit, waiting_wg_id] = waiting_unit_info; + // Error: wrong num_mma in prologue and epilogue. + // Cannot fix now. + int real_distance = distance + unit->stage - waiting_unit->stage; + int num_mma = wgmma_id[wg_id][waiting_unit] - wgmma_id[wg_id][unit]; + num_mma += real_distance * wgmma_count[wg_id]; + if (unit->isInnerTask()) { + --num_mma; + } + // Fallback to set num_mma to 0 to avoid error. + num_mma = 0; + Stmt wait_stmt = + Evaluate(Call(DataType::Handle(), wait_wgmma(), {num_mma})); + InsertStatementIntoScheduleUnit(waiting_unit, wait_stmt, true, + wg_id); + } + } else { + Stmt wait_stmt = + Evaluate(Call(DataType::Handle(), wait_wgmma(), {0})); + InsertStatementIntoScheduleUnit(unit, wait_stmt, false, wg_id); + } + // Even if different_wg_id is false, we already inserted the necessary + // wait_wgmma statements inside the warp group. Now we can consider the + // unit as synchronized unless it uses other asynchronous operations. + is_async = unit->UsesTMACore(); + } + int barrier_versions = std::max(sync_it->second.first, 1); + Buffer barrier_buffer; + // Handle single special task, such as TCGEN05 or TMA load, that requires + // a barrier for itself. + if (unit->isInnerTask()) { + auto task = static_cast(unit->child.get()); + int task_wg_id = task->GetWarpgroupId(); + if (task->is_TCGEN05() && task_wg_id == wg_id) { int barrier_id = next_barrier_id++; - Buffer barrier_buffer = makeBarrierBuffer( + barrier_buffer = makeBarrierBuffer( + 1, "tcgen05_barrier_" + std::to_string(barrier_id), + barrier_versions, barrier_buffers, barrier_map); + PrimExpr version_index = + indexmod(loop_info.CalculateIterationCount(), barrier_versions); + PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); + RewriteGemmMbar(task, mbar_expr); + } + if (task->HasTMALoad() && task_wg_id == wg_id) { + int barrier_id = next_barrier_id++; + barrier_buffer = makeBarrierBuffer( thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), - 1, barrier_buffers, barrier_map); - barrier_unit_map[std::make_pair(task, wg_id)] = barrier_buffer; - - PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); - RewriteCopyMbar(child, barrier_load); - Stmt arrive_stmt = makeBarrierArrive(barrier_load); - InsertStatementIntoScheduleUnit(task, arrive_stmt, false, wg_id); + barrier_versions, barrier_buffers, barrier_map); + PrimExpr version_index = + indexmod(loop_info.CalculateIterationCount(), barrier_versions); + PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); + RewriteCopyMbar(task, mbar_expr); + Stmt arrive_stmt = makeBarrierArrive(mbar_expr); + InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); + } + } + auto check_need_barrier = [&](ScheduleUnit *waiting_unit, + int waiting_wg_id) { + if (wg_id != waiting_wg_id) + return true; + if (!is_async) + return false; + if (unit->isInnerTask() && waiting_unit->isInnerTask() && + static_cast(unit->child.get())->is_TCGEN05() && + static_cast(waiting_unit->child.get())->is_TCGEN05()) { + return false; + } + return true; + }; + bool need_barrier = false; + for (const auto &[waiting_unit_info, distance] : wait_map) { + auto [waiting_unit, waiting_wg_id] = waiting_unit_info; + if (check_need_barrier(waiting_unit, waiting_wg_id)) { + need_barrier = true; + break; + } + } + if (!need_barrier) + continue; + if (!barrier_buffer.defined()) { + // Note: the logic here assumes that if there are TCGEN05 tasks, then + // all tasks are finished when all TCGEN05 tasks are finished. So we can + // use the TCGEN05 barrier for all tasks. If this assumption does not + // hold, we may need to implement a more complex logic to synchronize. + if (unit->HasTCGEN05()) { + int barrier_id = next_barrier_id++; + barrier_buffer = makeBarrierBuffer( + 1, "tcgen05_barrier_" + std::to_string(barrier_id), + barrier_versions, barrier_buffers, barrier_map); + PrimExpr version_index = + indexmod(loop_info.CalculateIterationCount(), barrier_versions); + Stmt arrive_stmt = + makeTcgen05MmaArrive(barrier_buffer, version_index); + InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); } else { - PrimExpr barrier_load = BufferLoad(neutral_sync_shared_barrier, {0}); - RewriteCopyMbar(child, barrier_load); + int barrier_id = next_barrier_id++; + barrier_buffer = makeBarrierBuffer( + thread_count[wg_id], "barrier_" + std::to_string(barrier_id), + barrier_versions, barrier_buffers, barrier_map); + PrimExpr version_index = + indexmod(loop_info.CalculateIterationCount(), barrier_versions); + PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); + Stmt arrive_stmt = makeBarrierArrive(mbar_expr); + InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); + } + } + // Add wait statements for all waiting units. + for (const auto &[waiting_unit_info, distance] : wait_map) { + auto [waiting_unit, waiting_wg_id] = waiting_unit_info; + if (check_need_barrier(waiting_unit, waiting_wg_id)) { + PrimExpr iteration = loop_info.CalculateIterationCount() - distance; + PrimExpr version_index = indexmod(iteration, barrier_versions); + PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); + PrimExpr parity_expr = + indexmod(indexdiv(iteration, barrier_versions), 2); + Stmt wait_stmt = makeBarrierWait(mbar_expr, parity_expr); + InsertStatementIntoScheduleUnit(waiting_unit, wait_stmt, true, + waiting_wg_id); } } } } +} - // Insert barriers for other dependencies - // First collect all buffers except register buffers - std::set buffers; - for (const auto ®ion_access : seq->GetReadWriteRegions()) { - auto &buffer = region_access.region->buffer; - if (!IsRegisterRegion(region_access.region)) { - buffers.emplace(buffer); +static void +AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, + std::vector &barrier_buffers, + Map &barrier_map, + const std::vector &thread_count, + LoopNestingInfo &loop_info, + std::vector &buffer_infos, + Buffer neutral_sync_shared_barrier) { + if (!seq) + return; + + // Collect all tasks from the sequence + std::vector tasks; + for (auto &child : seq->children) { + auto task = static_cast(child.get()); + if (task->child->IsSequence() || task->child->IsControl() || + task->child->IsIf()) { + // If child is SequenceNode, ControlNode, or IfNode, recursively analyze + // it + AnalyzeAndInsertBarriers( + task->child.get(), next_barrier_id, barrier_buffers, barrier_map, + thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); } + tasks.push_back(task); } - auto is_async_task = [](ScheduleUnit *task) { - return task->UsesTensorCore() || task->UsesTMACore(); - }; - for (const auto &buffer : buffers) { - std::vector last_access_task(num_wgs, nullptr); - std::vector last_access(num_wgs, false); - ScheduleUnit *last_write_task = nullptr; - uint64_t waited_write_wgs = 0; - int last_write_wg_id = -1; - bool last_write = false; - // Process tasks in sequence order - for (auto &promote_child : seq->children) { - auto task = static_cast(promote_child.get()); - bool is_async = is_async_task(task); - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - if (region_access.region->buffer != buffer) - continue; - auto insert_barrier = [&](ScheduleUnit *last_task, int last_wg_id) { - if (last_wg_id == -1) - return; - bool last_async = is_async_task(last_task); - if (last_wg_id == wg_id && !last_async) - return; - if (barrier_unit_map.find(std::make_pair(last_task, last_wg_id)) == - barrier_unit_map.end()) { - // Allocate a new barrier buffer - int barrier_id = next_barrier_id++; - Buffer barrier_buffer = - makeBarrierBuffer(thread_count[last_wg_id], - "barrier_" + std::to_string(barrier_id), 1, - barrier_buffers, barrier_map); - barrier_unit_map[std::make_pair(last_task, last_wg_id)] = - barrier_buffer; - PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); - // Insert barrier_arrive at the end of last_task's statements - Stmt arrive_stmt = makeBarrierArrive(barrier_load); - InsertStatementIntoScheduleUnit(last_task, arrive_stmt, false, - last_wg_id); - } - auto barrier_buffer = - barrier_unit_map[std::make_pair(last_task, last_wg_id)]; - PrimExpr barrier_load = BufferLoad(barrier_buffer, {0}); - Stmt wait_stmt = makeBarrierWait(barrier_load, 0); - InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); - }; - - if (!region_access.is_write) { - if (last_write_task == nullptr) - continue; - if (waited_write_wgs >> wg_id & 1) - continue; - insert_barrier(last_write_task, last_write_wg_id); - waited_write_wgs |= (1 << wg_id); - } else { - for (int last_wg_id = 0; last_wg_id < num_wgs; ++last_wg_id) { - if (last_access_task[last_wg_id] == nullptr) - continue; - insert_barrier(last_access_task[last_wg_id], last_wg_id); - last_access_task[last_wg_id] = nullptr; - } - } - } - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - if (region_access.region->buffer != buffer) - continue; - last_access_task[wg_id] = task; - if (region_access.is_write) { - last_write_task = task; - last_write_wg_id = wg_id; - waited_write_wgs = 0; - } + // Rewrite TMA load tasks to use tma_copy and neutral_sync_shared_barrier + for (auto task : tasks) { + if (task->isInnerTask() && task->UsesTMACore()) { + auto child = static_cast(task->child.get()); + if (child->HasTMALoad() && child->GetWarpgroupId() == -1) { + PrimExpr barrier_load = BufferLoad(neutral_sync_shared_barrier, {0}); + RewriteCopyMbar(child, barrier_load); } } } + + // Insert synchronization + auto sync_infos = GetSyncInfos(tasks, thread_count.size()); + InsertSynchronization(tasks, sync_infos, next_barrier_id, barrier_buffers, + barrier_map, thread_count, loop_info); } static void @@ -857,7 +937,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, auto seq = static_cast(ctrl->child.get()); // Collect all tasks from the sequence - std::vector ordered_tasks; + std::vector tasks; for (auto &child : seq->children) { auto task = static_cast(child.get()); if (task->child->IsSequence() || task->child->IsControl() || @@ -868,11 +948,12 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, task->child.get(), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); } - ordered_tasks.push_back(task); + tasks.push_back(task); } // Process in order: sort by stage // This matches the software pipelining order + auto ordered_tasks = tasks; std::stable_sort( ordered_tasks.begin(), ordered_tasks.end(), [](ScheduleUnit *a, ScheduleUnit *b) { return a->stage > b->stage; }); @@ -940,256 +1021,11 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, RewriteTaskNodeBuffers(ctrl, multi_buffer, iteration); } - size_t num_wgs = thread_count.size(); - - // Insert wait_wgmma - std::vector> last_wgmma_map(num_wgs); - std::vector wait_wgmma_id(num_wgs, 0); - std::vector total_wgmma(num_wgs, 0); - for (unsigned iter = 0; iter != 2; ++iter) { - for (ScheduleUnit *task : ordered_tasks) { - if (task->isInnerTask() && task->UsesTensorCore()) { - if (iter == 1) { - continue; - } - auto child = static_cast(task->child.get()); - if (child->is_WGMMA()) { - bool found_wgmma = false; - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - auto ®ion = region_access.region; - if (IsRegisterRegion(region)) { - Buffer buffer = region->buffer; - if (!found_wgmma) { - found_wgmma = true; - ++total_wgmma[wg_id]; - } - last_wgmma_map[wg_id][buffer] = total_wgmma[wg_id]; - } - } - } - } else { - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - auto ®ion = region_access.region; - if (IsRegisterRegion(region)) { - Buffer buffer = region->buffer; - auto it = last_wgmma_map[wg_id].find(buffer); - if (it == last_wgmma_map[wg_id].end()) - continue; - if (it->second <= wait_wgmma_id[wg_id]) - continue; - wait_wgmma_id[wg_id] = it->second; - Stmt wait_stmt = - Evaluate(Call(DataType::Handle(), wait_wgmma(), - {total_wgmma[wg_id] - it->second})); - InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); - } - } - } - } - } - - // Map (ScheduleUnit, warpgroup_id) to (barrier buffer, num_versions) - std::map, std::pair> - barrier_unit_map; - - // Allocate barriers for TCGEN05MMA - for (ScheduleUnit *task : ordered_tasks) { - if (task->isInnerTask() && task->UsesTensorCore()) { - auto child = static_cast(task->child.get()); - if (child->is_TCGEN05()) { - int num_versions = 1; - for (const auto ®ion_access : child->GetReadWriteRegions()) { - auto &buffer = region_access.region->buffer; - auto it = buffer_num_versions.find(buffer); - if (it != buffer_num_versions.end()) { - num_versions = std::max(num_versions, it->second); - } - } - int wg_id = child->GetWarpgroupId(); - ICHECK(wg_id != -1) << "TCGEN05MMA must have valid warpgroup id"; - - int barrier_id = next_barrier_id++; - // Create a single barrier buffer with shape (num_versions,) - Buffer barrier_buffer = makeBarrierBuffer( - 1, "tcgen05_barrier_" + std::to_string(barrier_id), num_versions, - barrier_buffers, barrier_map); - barrier_unit_map[std::make_pair(task, wg_id)] = - std::make_pair(barrier_buffer, num_versions); - - // Rewrite the gemm call's mbar argument (arg[16]) to use - // BufferLoad(barrier_buffer, {version_index}) - PrimExpr version_index = - indexmod(loop_info.CalculateIterationCount(), num_versions); - PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); - RewriteGemmMbar(child, mbar_expr); - } - } - } - - // Allocate barriers for TMA - for (ScheduleUnit *task : ordered_tasks) { - if (task->isInnerTask() && task->UsesTMACore()) { - auto child = static_cast(task->child.get()); - if (child->HasTMALoad()) { - int num_versions = 1; - for (const auto ®ion_access : child->GetReadWriteRegions()) { - auto &buffer = region_access.region->buffer; - auto it = buffer_num_versions.find(buffer); - if (it != buffer_num_versions.end()) { - num_versions = std::max(num_versions, it->second); - } - } - int wg_id = child->GetWarpgroupId(); - ICHECK(wg_id != -1) << "TMA loads must have valid warpgroup id"; - - int barrier_id = next_barrier_id++; - Buffer barrier_buffer = makeBarrierBuffer( - thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), - num_versions, barrier_buffers, barrier_map); - barrier_unit_map[std::make_pair(task, wg_id)] = - std::make_pair(barrier_buffer, num_versions); - - PrimExpr version_index = - indexmod(loop_info.CalculateIterationCount(), num_versions); - PrimExpr barrier_load = BufferLoad(barrier_buffer, {version_index}); - RewriteCopyMbar(child, barrier_load); - Stmt arrive_stmt = makeBarrierArrive(barrier_load); - InsertStatementIntoScheduleUnit(task, arrive_stmt, false, wg_id); - } - } - } - - // Insert barriers for other dependencies - // First collect all buffers except register buffers - std::set, std::greater>> - buffers; - for (const auto ®ion_access : ctrl->GetReadWriteRegions()) { - auto &buffer = region_access.region->buffer; - if (!IsRegisterRegion(region_access.region)) { - auto it = buffer_num_versions.find(buffer); - int num_versions = it != buffer_num_versions.end() ? it->second : 1; - buffers.emplace(num_versions, buffer); - } - } - // Process buffers in order of decreasing number of versions to ensure - // correct barrier size - auto is_async_task = [](ScheduleUnit *task) { - return task->UsesTensorCore() || task->UsesTMACore(); - }; - for (const auto &[num_versions, buffer] : buffers) { - std::vector last_access_task(num_wgs, nullptr); - std::vector last_access(num_wgs, false); - ScheduleUnit *last_write_task = nullptr; - uint64_t waited_write_wgs = 0; - int last_write_wg_id = -1; - bool last_write = false; - // Process tasks in the specified order - for (unsigned iter = 0; iter != 2; ++iter) { - for (ScheduleUnit *task : ordered_tasks) { - bool is_async = is_async_task(task); - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - if (region_access.region->buffer != buffer) - continue; - - auto insert_barrier = [&](ScheduleUnit *last_task, int last_wg_id) { - if (last_wg_id == -1) - return; - if (last_task == task) // ??? - return; - bool last_async = is_async_task(last_task); - if (last_wg_id == wg_id && !last_async) - return; - if (barrier_unit_map.find(std::make_pair( - last_task, last_wg_id)) == barrier_unit_map.end()) { - // Allocate a new barrier buffer - int barrier_id = next_barrier_id++; - Buffer barrier_buffer = makeBarrierBuffer( - thread_count[last_wg_id], - "barrier_" + std::to_string(barrier_id), num_versions, - barrier_buffers, barrier_map); - barrier_unit_map[std::make_pair(last_task, last_wg_id)] = - std::make_pair(barrier_buffer, num_versions); - // Create BufferLoad with version-indexed offset - PrimExpr version_index = - indexmod(loop_info.CalculateIterationCount(), num_versions); - PrimExpr barrier_load = - BufferLoad(barrier_buffer, {version_index}); - // Insert barrier_arrive at the end of last_task's statements - Stmt arrive_stmt = makeBarrierArrive(barrier_load); - InsertStatementIntoScheduleUnit(last_task, arrive_stmt, false, - last_wg_id); - } - auto [barrier_buffer, barrier_versions] = - barrier_unit_map[std::make_pair(last_task, last_wg_id)]; - PrimExpr iteration = loop_info.CalculateIterationCount(); - if (iter == 1) { - // Calculate the real iteration to wait. - // "+ barrier_versions * 2" ensures positive iteration for - // division and modulo, and keeps the parity the same. - iteration += barrier_versions * 2 - num_versions; - } - PrimExpr version_index = indexmod(iteration, barrier_versions); - PrimExpr barrier_load = - BufferLoad(barrier_buffer, {version_index}); - PrimExpr parity_expr = - indexmod(indexdiv(iteration, barrier_versions), 2); - Stmt wait_stmt = makeBarrierWait(barrier_load, parity_expr); - InsertStatementIntoScheduleUnit(task, wait_stmt, true, wg_id); - }; - - if (!region_access.is_write) { - if (iter == 1) { - if (last_write) - continue; - last_write = true; - } - if (last_write_task == nullptr) - continue; - if (waited_write_wgs >> wg_id & 1) - continue; - insert_barrier(last_write_task, last_write_wg_id); - waited_write_wgs |= (1 << wg_id); - } else { - for (int last_wg_id = 0; last_wg_id < num_wgs; ++last_wg_id) { - if (iter == 1) { - if (last_access[last_wg_id]) - continue; - last_access[last_wg_id] = true; - } - if (last_access_task[last_wg_id] == nullptr) - continue; - insert_barrier(last_access_task[last_wg_id], last_wg_id); - last_access_task[last_wg_id] = nullptr; - } - } - } - if (iter == 0) { - for (const auto ®ion_access : task->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (wg_id == -1) - continue; - if (region_access.region->buffer != buffer) - continue; - last_access_task[wg_id] = task; - if (region_access.is_write) { - last_write_task = task; - last_write_wg_id = wg_id; - waited_write_wgs = 0; - } - } - } - } - } - } + // Insert synchronization + auto sync_infos = GetSyncInfos(ordered_tasks, thread_count.size(), + buffer_num_versions, true); + InsertSynchronization(tasks, sync_infos, next_barrier_id, barrier_buffers, + barrier_map, thread_count, loop_info); } else { AnalyzeAndInsertBarriers( ctrl->child.get(), next_barrier_id, barrier_buffers, barrier_map, From 37f05eece2c67e4df691b37efe0ec50a96181fa5 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 16 Apr 2026 16:54:51 +0800 Subject: [PATCH 070/156] partly fix pro/epilogue logic for barrier --- src/transform/auto_schedule.cc | 15 +++++++++++---- src/transform/auto_schedule/barrier.h | 10 +++++----- src/transform/auto_schedule/ir_structure.h | 7 +++++-- src/transform/auto_schedule/schedule_builder.cc | 4 ++++ 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 284b1d3551..9f9bfb2614 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -417,12 +417,19 @@ class IRStructureBuilder : public StmtVisitor { } else { // Generic T.copy(): check if TMA is possible. arith::Analyzer ana; - if (!copy->GetDisableTMA() && - copy->CheckBulkLoad(target, &ana, /*check_last_dim=*/true)) { - found_tma = true; - found_tma_load = true; + if (!copy->GetDisableTMA()) { + if (copy->CheckBulkLoad(target, &ana, /*check_last_dim=*/true)) { + found_tma = true; + found_tma_load = true; + } + if (copy->CheckBulkStore(target, &ana, /*check_last_dim=*/true)) { + found_tma = true; + } } } + LOG(INFO) << "ResourceAnalyzer: Detected copy-like operation: " << op_name + << ", found_tma=" << found_tma + << ", found_tma_load=" << found_tma_load; } else if (op->op.same_as(gemm_py_op) || op->op.same_as(gemm_op) || op->op.same_as(wgmma_gemm_py_op) || op->op.same_as(wgmma_gemm_op) || diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 46e7dfb96a..bd97320aed 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -616,7 +616,7 @@ GetSyncInfos(const std::vector &units, int num_wgs, for (ScheduleUnit *unit : units) { for (const auto ®ion_access : unit->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1) + if (region_access.schedule_phase != SchedulePhase::kBody) continue; if (region_access.region->buffer != buffer) continue; @@ -649,7 +649,7 @@ GetSyncInfos(const std::vector &units, int num_wgs, if (iter == 0) { for (const auto ®ion_access : unit->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1) + if (region_access.schedule_phase != SchedulePhase::kBody) continue; if (region_access.region->buffer != buffer) continue; @@ -663,7 +663,7 @@ GetSyncInfos(const std::vector &units, int num_wgs, } for (const auto ®ion_access : unit->GetReadWriteRegions()) { int wg_id = region_access.warpgroup_id; - if (wg_id == -1) + if (region_access.schedule_phase != SchedulePhase::kBody) continue; if (region_access.region->buffer != buffer) continue; @@ -707,7 +707,7 @@ static void InsertSynchronization( } if (unit->HasWGMMA() && unit->isInnerTask()) { int wg_id = static_cast(unit->child.get())->GetWarpgroupId(); - if (wg_id != -1) { + if (unit->GetSchedulePhase() != SchedulePhase::kBody) { ++wgmma_count[wg_id]; } else { LOG(FATAL) << "WGMMA task without valid warpgroup id"; @@ -888,7 +888,7 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, for (auto task : tasks) { if (task->isInnerTask() && task->UsesTMACore()) { auto child = static_cast(task->child.get()); - if (child->HasTMALoad() && child->GetWarpgroupId() == -1) { + if (child->HasTMALoad() && child->GetSchedulePhase() == SchedulePhase::kPrologue) { PrimExpr barrier_load = BufferLoad(neutral_sync_shared_barrier, {0}); RewriteCopyMbar(child, barrier_load); } diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 24fc94c87f..3606f22f27 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -214,7 +214,10 @@ class TaskNode : public IRStructure { int GetWarpgroupId() const override { return warpgroup_id_; } // Scheduling phase (prologue / body / epilogue) - void SetSchedulePhase(SchedulePhase phase) { schedule_phase_ = phase; } + void SetSchedulePhase(SchedulePhase phase) { + schedule_phase_ = phase; + warpgroup_id_ = 0; + } SchedulePhase GetSchedulePhase() const override { return schedule_phase_; } bool IsNeutralPhase() const override { return schedule_phase_ != SchedulePhase::kBody; @@ -1384,7 +1387,7 @@ inline void PrintIRStructure(const IRStructure *node, int indent = 0) { } if (promote->child) { LOG(INFO) << indent_str << " Promote body:"; - PrintAllStmts(promote->child.get(), indent + 2); + PrintIRStructure(promote->child.get(), indent + 2); } } else if (node->IsIf()) { const IfNode *if_node = static_cast(node); diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index eaa23cdedd..49b11c35d7 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -268,6 +268,10 @@ void CollectSuffixTasks(IRStructure *root, std::unordered_set candidate_set; for (int i = static_cast(items.size()) - 1; i >= 0; --i) { auto *item = items[i]; + if (item->GetSchedulePhase() == SchedulePhase::kPrologue) { + rejected.push_back(item); + continue; + } if (item->IsControl()) { rejected.push_back(item); continue; From a4aa5f3f6e36fd1b924924932041aa5a78257d8d Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 16 Apr 2026 17:41:28 +0800 Subject: [PATCH 071/156] refactor ir structure clone --- src/transform/auto_schedule.cc | 3 - src/transform/auto_schedule/barrier.h | 5 +- .../auto_schedule/warpgroup_partition.cc | 155 ++---------------- 3 files changed, 20 insertions(+), 143 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 9f9bfb2614..fdefb6483d 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -427,9 +427,6 @@ class IRStructureBuilder : public StmtVisitor { } } } - LOG(INFO) << "ResourceAnalyzer: Detected copy-like operation: " << op_name - << ", found_tma=" << found_tma - << ", found_tma_load=" << found_tma_load; } else if (op->op.same_as(gemm_py_op) || op->op.same_as(gemm_op) || op->op.same_as(wgmma_gemm_py_op) || op->op.same_as(wgmma_gemm_op) || diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index bd97320aed..b9e356d38a 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -888,7 +888,8 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, for (auto task : tasks) { if (task->isInnerTask() && task->UsesTMACore()) { auto child = static_cast(task->child.get()); - if (child->HasTMALoad() && child->GetSchedulePhase() == SchedulePhase::kPrologue) { + if (child->HasTMALoad() && + child->GetSchedulePhase() == SchedulePhase::kPrologue) { PrimExpr barrier_load = BufferLoad(neutral_sync_shared_barrier, {0}); RewriteCopyMbar(child, barrier_load); } @@ -1036,4 +1037,4 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, loop_info.PopLoop(); } } // namespace tl -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index ea4dec5a94..665fb87da5 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -131,6 +131,10 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, if (node->IsTask()) { auto task = static_cast(node); + if (task->GetSchedulePhase() != SchedulePhase::kBody) { + auto new_task = std::make_shared(); + return new_task; + } // LetDecl tasks are always included in every warp group clone. // Create a fresh variable copy so the two warp groups use different names. @@ -799,57 +803,6 @@ Stmt ApplyWarpgroupPartitionToIRStructure( size_t num_wgs = thread_count.size(); - // Helper function to clone IRStructure filtering tasks with warpgroup_id == - // -1 (neutral tasks) - std::function(IRStructure *)> - clone_neutral_filter; - clone_neutral_filter = - [&clone_neutral_filter]( - IRStructure *node) -> std::shared_ptr { - if (!node) - return nullptr; - - if (node->IsTask()) { - auto task = static_cast(node); - if (task->IsNeutralPhase()) { - return task->Clone(); - } else { - auto new_task = std::make_shared(); - // Empty statements - return new_task; - } - } else if (node->IsSequence()) { - auto seq = static_cast(node); - auto new_seq = std::make_shared(); - for (const auto &child : seq->children) { - if (child) { - auto node = static_cast(child.get()); - auto new_node = clone_neutral_filter(node->child.get()); - if (new_node) { - auto new_unit = std::make_shared(); - new_unit->child = std::move(new_node); - new_seq->children.push_back(std::move(new_unit)); - } - } - } - return new_seq; - } else if (node->IsWrapper()) { - auto wrapper = static_cast(node); - auto new_wrapper = std::make_shared(); - new_wrapper->child = clone_neutral_filter(wrapper->child.get()); - if (new_wrapper->child) { - return new_wrapper; - } - return nullptr; - } else if (node->IsControl()) { - return nullptr; - } else if (node->IsIf()) { - return nullptr; - } - LOG(FATAL); - return nullptr; - }; - auto has_actual_statements = [](IRStructure *node) -> bool { if (!node) return false; @@ -863,50 +816,42 @@ Stmt ApplyWarpgroupPartitionToIRStructure( return false; }; - std::function( - IRStructure *, const std::function &, int)> - clone_neutral_filter_with_top_level; - clone_neutral_filter_with_top_level = - [&clone_neutral_filter_with_top_level, &clone_neutral_filter]( - IRStructure *node, const std::function &include_top_level, - int top_level_index) -> std::shared_ptr { + std::function(IRStructure *, SchedulePhase)> + clone_phase_filter; + clone_phase_filter = + [&clone_phase_filter]( + IRStructure *node, + SchedulePhase phase) -> std::shared_ptr { if (!node) return nullptr; if (node->IsTask()) { - if (include_top_level(top_level_index)) { - return clone_neutral_filter(node); + auto task = static_cast(node); + if (task->GetSchedulePhase() == phase) { + return task->Clone(); } else { auto new_task = std::make_shared(); - // Empty statements return new_task; } } else if (node->IsSequence()) { auto seq = static_cast(node); auto new_seq = std::make_shared(); - int child_index = 0; for (const auto &child : seq->children) { if (child) { auto schedule_unit = static_cast(child.get()); - int next_top_level_index = - top_level_index == -1 ? child_index : top_level_index; - auto new_node = clone_neutral_filter_with_top_level( - schedule_unit->child.get(), include_top_level, - next_top_level_index); + auto new_node = clone_phase_filter(schedule_unit->child.get(), phase); if (new_node) { auto new_unit = std::make_shared(); new_unit->child = std::move(new_node); new_seq->children.push_back(std::move(new_unit)); } } - child_index++; } return new_seq; } else if (node->IsWrapper()) { auto wrapper = static_cast(node); auto new_wrapper = std::make_shared(); - new_wrapper->child = clone_neutral_filter_with_top_level( - wrapper->child.get(), include_top_level, top_level_index); + new_wrapper->child = clone_phase_filter(wrapper->child.get(), phase); if (new_wrapper->child) { return new_wrapper; } @@ -920,76 +865,10 @@ Stmt ApplyWarpgroupPartitionToIRStructure( return nullptr; }; - // Determine which top-level neutral children should be epi (run after - // warpgroup-partitioned code). A neutral child is epi if it directly or - // transitively depends on warpgroup task output. - std::unordered_set wg_write_buffers; - std::unordered_set depends_on_wg_output; - // Per-child write buffers and read buffers for neutral children - struct ChildBufferInfo { - std::unordered_set read_bufs; - std::unordered_set write_bufs; - bool all_neutral = true; - }; - std::vector child_infos; - if (root->IsSequence()) { - auto seq = static_cast(root); - child_infos.resize(seq->children.size()); - for (size_t i = 0; i < seq->children.size(); ++i) { - const auto &child = seq->children[i]; - if (!child) - continue; - auto unit = static_cast(child.get()); - std::vector child_tasks; - CollectAllTaskNodesWithContext(unit->child.get(), child_tasks); - auto &info = child_infos[i]; - for (const auto &task : child_tasks) { - if (!task.task->IsNeutralPhase()) { - info.all_neutral = false; - for (const auto &wr : task.task->GetWriteRegions()) - wg_write_buffers.insert(wr->buffer.get()); - } - for (const auto &rd : task.task->GetReadRegions()) - info.read_bufs.insert(rd->buffer.get()); - for (const auto &wr : task.task->GetWriteRegions()) - info.write_bufs.insert(wr->buffer.get()); - } - } - // Transitive fixpoint: if a neutral child reads from wg_write_buffers, - // mark it as epi and add its write buffers to wg_write_buffers so that - // other neutral children that depend on it are also marked epi. - bool changed = true; - while (changed) { - changed = false; - for (size_t i = 0; i < child_infos.size(); ++i) { - if (!child_infos[i].all_neutral) - continue; - if (depends_on_wg_output.count(static_cast(i))) - continue; - for (const auto *buf : child_infos[i].read_bufs) { - if (wg_write_buffers.count(buf)) { - depends_on_wg_output.insert(static_cast(i)); - for (const auto *wb : child_infos[i].write_bufs) - wg_write_buffers.insert(wb); - changed = true; - break; - } - } - } - } - } - - auto is_epi_top_level_index = [&depends_on_wg_output](int top_level_index) { - return depends_on_wg_output.count(top_level_index) > 0; - }; - auto is_pro_top_level_index = [is_epi_top_level_index](int top_level_index) { - return !is_epi_top_level_index(top_level_index); - }; - auto wg_pro_neutral_structure = - clone_neutral_filter_with_top_level(root, is_pro_top_level_index, -1); + clone_phase_filter(root, SchedulePhase::kPrologue); auto wg_epi_neutral_structure = - clone_neutral_filter_with_top_level(root, is_epi_top_level_index, -1); + clone_phase_filter(root, SchedulePhase::kEpilogue); // wg_children[wg_id][child_index] = filtered IRStructure (nullptr if absent) std::vector>> wg_children(num_wgs); std::vector> wg_structures(num_wgs); From 33e5c366114891a1e659d22b6cb062442e8635c2 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 16 Apr 2026 18:02:23 +0800 Subject: [PATCH 072/156] fix bug --- src/transform/auto_schedule/barrier.h | 2 +- src/transform/auto_schedule/ir_structure.h | 5 +---- src/transform/auto_schedule/schedule_builder.cc | 6 ++++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index b9e356d38a..659b92b1e8 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -707,7 +707,7 @@ static void InsertSynchronization( } if (unit->HasWGMMA() && unit->isInnerTask()) { int wg_id = static_cast(unit->child.get())->GetWarpgroupId(); - if (unit->GetSchedulePhase() != SchedulePhase::kBody) { + if (unit->GetSchedulePhase() == SchedulePhase::kBody) { ++wgmma_count[wg_id]; } else { LOG(FATAL) << "WGMMA task without valid warpgroup id"; diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 3606f22f27..317904906b 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -214,10 +214,7 @@ class TaskNode : public IRStructure { int GetWarpgroupId() const override { return warpgroup_id_; } // Scheduling phase (prologue / body / epilogue) - void SetSchedulePhase(SchedulePhase phase) { - schedule_phase_ = phase; - warpgroup_id_ = 0; - } + void SetSchedulePhase(SchedulePhase phase) { schedule_phase_ = phase; } SchedulePhase GetSchedulePhase() const override { return schedule_phase_; } bool IsNeutralPhase() const override { return schedule_phase_ != SchedulePhase::kBody; diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 49b11c35d7..45d8cecb3a 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -381,12 +381,14 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, CollectPrefixTasks(root, prefix_tasks); for (auto *task : prefix_tasks) { task->SetSchedulePhase(SchedulePhase::kPrologue); + task->SetWarpgroupId(0); } std::unordered_set suffix_tasks; CollectSuffixTasks(root, all_tasks, uf, suffix_tasks); for (auto *task : suffix_tasks) { task->SetSchedulePhase(SchedulePhase::kEpilogue); + task->SetWarpgroupId(0); } std::unordered_map> components; @@ -725,8 +727,8 @@ NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, std::unordered_set prefix_tasks; CollectPrefixTasks(root, prefix_tasks); for (auto *task : prefix_tasks) { - task->SetWarpgroupId(-1); task->SetSchedulePhase(SchedulePhase::kPrologue); + task->SetWarpgroupId(0); } int n = all_tasks.size(); @@ -741,8 +743,8 @@ NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, std::unordered_set suffix_tasks; CollectSuffixTasks(root, all_tasks, uf, suffix_tasks); for (auto *task : suffix_tasks) { - task->SetWarpgroupId(-1); task->SetSchedulePhase(SchedulePhase::kEpilogue); + task->SetWarpgroupId(0); } // no double_thread in naive mode From 7d1e6e3bb26cf6ebb33d6abe7a85ee601d8312f0 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 16 Apr 2026 18:23:04 +0800 Subject: [PATCH 073/156] fix bug --- src/transform/auto_schedule/barrier.h | 6 ++-- .../auto_schedule/warpgroup_partition.cc | 30 +++---------------- 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 659b92b1e8..08575119f9 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -257,10 +257,6 @@ static Stmt InsertBarriersForNeutralSyncWithDependency( Buffer neutral_sync_shared_barrier = Buffer(), Var thread_var = Var(), PrimExpr tensor_core_wg_start = PrimExpr(), PrimExpr tensor_core_wg_end = PrimExpr()) { - if (IsEvaluateZero(producer_body) || IsEvaluateZero(consumer_body)) { - return SeqStmt({producer_body, consumer_body}); - } - if (!need_regular_barrier && !need_tmem_barrier) { return SeqStmt({producer_body, consumer_body}); } @@ -618,6 +614,7 @@ GetSyncInfos(const std::vector &units, int num_wgs, int wg_id = region_access.warpgroup_id; if (region_access.schedule_phase != SchedulePhase::kBody) continue; + ICHECK(0 <= wg_id && wg_id < num_wgs); if (region_access.region->buffer != buffer) continue; auto add_sync = [&](ScheduleUnit *wait_unit, int wait_wg_id) { @@ -708,6 +705,7 @@ static void InsertSynchronization( if (unit->HasWGMMA() && unit->isInnerTask()) { int wg_id = static_cast(unit->child.get())->GetWarpgroupId(); if (unit->GetSchedulePhase() == SchedulePhase::kBody) { + ICHECK(0 <= wg_id && wg_id < num_wgs); ++wgmma_count[wg_id]; } else { LOG(FATAL) << "WGMMA task without valid warpgroup id"; diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 665fb87da5..0c6579b4d0 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -1063,32 +1063,10 @@ Stmt ApplyWarpgroupPartitionToIRStructure( thread_count.begin() + 1, thread_count.end(), thread_count[0]); Stmt pro_and_warpgroup_stmt; - if (wg_pro_neutral_has_stmts) { - if (!IsEvaluateZero(if_then_else) && !IsEvaluateZero(pro_neutral_body)) { - // Both have statements: insert barriers for neutral-to-warpgroup - // synchronization - pro_and_warpgroup_stmt = InsertBarriersForNeutralSync( - pro_neutral_body, if_then_else, barrier_buffers, barrier_map, - updated_thread_extent, neutral_sync_shared_barrier); - } else if (!IsEvaluateZero(if_then_else) || - !IsEvaluateZero(pro_neutral_body)) { - // Only one has actual statements - std::vector stmts; - if (!IsEvaluateZero(pro_neutral_body)) { - stmts.push_back(pro_neutral_body); - } - if (!IsEvaluateZero(if_then_else)) { - stmts.push_back(if_then_else); - } - if (stmts.size() == 1) { - pro_and_warpgroup_stmt = stmts[0]; - } else { - pro_and_warpgroup_stmt = SeqStmt(stmts); - } - } else { - // Both are empty - pro_and_warpgroup_stmt = Evaluate(0); - } + if (wg_pro_neutral_has_stmts && !IsEvaluateZero(pro_neutral_body)) { + pro_and_warpgroup_stmt = InsertBarriersForNeutralSync( + pro_neutral_body, if_then_else, barrier_buffers, barrier_map, + updated_thread_extent, neutral_sync_shared_barrier); } else { pro_and_warpgroup_stmt = if_then_else; } From d098f51b8fe8e7bf1258388292583cf07a4513bf Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Fri, 17 Apr 2026 02:27:00 +0800 Subject: [PATCH 074/156] [Enhancement] Use atomic directory rename for cache writes (#1982) * [Enhancement] Use atomic directory rename for cache writes Replace per-file os.replace with write-to-staging + os.rename to ensure other processes never see an incomplete cache directory. Also add cleanup of stale staging dirs on init. * [Bugfix] Fix atomic cache save: abort on critical file failure and fix staging cleanup - Let kernel lib and params save failures propagate instead of swallowing them, so the staging directory is cleaned up rather than renamed as an incomplete cache entry. - Move autotuner staging directory to TILELANG_CACHE_DIR so that _cleanup_stale_staging_dirs() can find and remove stale entries. * Fix atomic cache save recovery and error handling * Remove unnecessary blank lines in _FakeAdapter and _FakeKernel classes in autotune and cache test files. * Tighten atomic cache completeness checks * Simplify atomic cache save error handling * submodule update * remove legacy test --------- Co-authored-by: Freebase6912 --- 3rdparty/tvm | 2 +- .../test_tilelang_autotune_atomic_save.py | 150 ++++++++ .../test_tilelang_kernel_cache_atomic_save.py | 118 ++++++ testing/python/tilelang/test_tma_load.py | 68 ---- tilelang/autotuner/param.py | 339 +++++++++++------- tilelang/cache/kernel_cache.py | 120 +++++-- tilelang/jit/adapter/nvrtc/kernel_cache.py | 3 + 7 files changed, 565 insertions(+), 235 deletions(-) create mode 100644 testing/python/autotune/test_tilelang_autotune_atomic_save.py create mode 100644 testing/python/cache/test_tilelang_kernel_cache_atomic_save.py delete mode 100644 testing/python/tilelang/test_tma_load.py diff --git a/3rdparty/tvm b/3rdparty/tvm index fab43e41c0..882a774844 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit fab43e41c004e888ded30d45df25ccc8e2612617 +Subproject commit 882a774844993d103ae6e317ba3c7bbb5952b662 diff --git a/testing/python/autotune/test_tilelang_autotune_atomic_save.py b/testing/python/autotune/test_tilelang_autotune_atomic_save.py new file mode 100644 index 0000000000..1369d76440 --- /dev/null +++ b/testing/python/autotune/test_tilelang_autotune_atomic_save.py @@ -0,0 +1,150 @@ +import errno + +import pytest + +from tilelang.autotuner import param as autotune_param +from tilelang.autotuner.param import ( + AutotuneResult, + BEST_CONFIG_PATH, + FUNCTION_PATH, + LATENCY_PATH, + DEVICE_KERNEL_PATH, + HOST_KERNEL_PATH, + KERNEL_CUBIN_PATH, + KERNEL_LIB_PATH, + KERNEL_PY_PATH, + PARAMS_PATH, +) +from tilelang.env import env + + +class _FakeAdapter: + def __init__(self, libpath: str): + self.libpath = libpath + + def get_kernel_source(self): + return "// host kernel" + + def get_host_source(self): + return "// host kernel" + + +class _FakeKernel: + def __init__(self, libpath: str, execution_backend: str = "cython"): + self.execution_backend = execution_backend + self.adapter = _FakeAdapter(libpath) + self.kernel_source = "// device kernel" + self.params = ["param"] + + +def _fake_func(): + return None + + +@pytest.fixture +def cache_dirs(tmp_path, monkeypatch): + cache_dir = tmp_path / "cache" + tmp_dir = tmp_path / "tmp" + cache_dir.mkdir() + tmp_dir.mkdir() + monkeypatch.setattr(env, "TILELANG_CACHE_DIR", str(cache_dir)) + monkeypatch.setattr(env, "TILELANG_TMP_DIR", str(tmp_dir)) + return cache_dir + + +def _make_result(tmp_path, execution_backend: str = "cython"): + if execution_backend == "nvrtc": + lib_path = tmp_path / "kernel.cubin" + lib_path.write_bytes(b"fake-cubin") + lib_path.with_suffix(".py").write_text("# fake launcher") + else: + lib_path = tmp_path / "kernel_lib.so" + lib_path.write_bytes(b"fake-so") + _fake_func.attrs = None + return AutotuneResult( + latency=1.0, + config={"threads": 128}, + ref_latency=2.0, + libcode="// libcode", + func=_fake_func, + kernel=_FakeKernel(str(lib_path), execution_backend=execution_backend), + ) + + +def test_autotune_save_rewrites_incomplete_cache_dir(cache_dirs, tmp_path): + result = _make_result(tmp_path) + path = cache_dirs / "autotune-entry" + path.mkdir() + (path / "stale.txt").write_text("partial") + + result.save_to_disk(path) + + for filename in ( + BEST_CONFIG_PATH, + FUNCTION_PATH, + LATENCY_PATH, + DEVICE_KERNEL_PATH, + HOST_KERNEL_PATH, + KERNEL_LIB_PATH, + PARAMS_PATH, + ): + assert (path / filename).exists() + assert not (path / "stale.txt").exists() + + +def test_autotune_save_logs_write_oserror_instead_of_treating_it_as_race(cache_dirs, tmp_path, monkeypatch): + result = _make_result(tmp_path) + path = cache_dirs / "autotune-error" + logged = [] + + def raise_write_error(self, *args, **kwargs): + raise OSError(errno.ENOSPC, "No space left on device") + + def record_exception(message, *args, **kwargs): + logged.append(message) + + monkeypatch.setattr(AutotuneResult, "_save_kernel_to_disk", raise_write_error) + monkeypatch.setattr(autotune_param.logger, "exception", record_exception) + + result.save_to_disk(path) + + assert not path.exists() + assert "Error during atomic autotune result save" in logged + assert not any(child.name.startswith(".staging_") for child in cache_dirs.iterdir()) + + +def test_autotune_save_does_not_publish_incomplete_dir_when_device_source_is_missing(cache_dirs, tmp_path, monkeypatch): + result = _make_result(tmp_path) + result.kernel.kernel_source = None + path = cache_dirs / "autotune-missing-device-source" + logged = [] + + def record_exception(message, *args, **kwargs): + logged.append(message) + + monkeypatch.setattr(autotune_param.logger, "exception", record_exception) + + result.save_to_disk(path) + + assert not path.exists() + assert "Error during atomic autotune result save" in logged + assert not any(child.name.startswith(".staging_") for child in cache_dirs.iterdir()) + + +def test_autotune_save_rewrites_nvrtc_dir_missing_launcher(cache_dirs, tmp_path): + result = _make_result(tmp_path, execution_backend="nvrtc") + path = cache_dirs / "autotune-nvrtc-entry" + path.mkdir() + (path / BEST_CONFIG_PATH).write_text("{}") + (path / FUNCTION_PATH).write_bytes(b"old-func") + (path / LATENCY_PATH).write_text('{"latency": 1.0, "ref_latency": 2.0}') + (path / DEVICE_KERNEL_PATH).write_text("// device kernel") + (path / HOST_KERNEL_PATH).write_text("// host kernel") + (path / KERNEL_CUBIN_PATH).write_bytes(b"old-cubin") + (path / PARAMS_PATH).write_bytes(b"old-params") + (path / "legacy.txt").write_text("stale") + + result.save_to_disk(path) + + assert (path / KERNEL_PY_PATH).exists() + assert not (path / "legacy.txt").exists() diff --git a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py new file mode 100644 index 0000000000..d287c81799 --- /dev/null +++ b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py @@ -0,0 +1,118 @@ +import errno +import pytest + +from tilelang.cache.kernel_cache import KernelCache +from tilelang.env import env +from tilelang.jit.adapter.nvrtc.kernel_cache import NVRTCKernelCache + + +class _FakeAdapter: + def __init__(self, libpath: str): + self.libpath = libpath + + def get_kernel_source(self): + return "// host kernel" + + +class _FakeKernel: + def __init__(self, libpath: str): + self.adapter = _FakeAdapter(libpath) + self.kernel_source = "// device kernel" + self.params = ["param"] + + +@pytest.fixture +def cache_dirs(tmp_path, monkeypatch): + cache_dir = tmp_path / "cache" + tmp_dir = tmp_path / "tmp" + cache_dir.mkdir() + tmp_dir.mkdir() + monkeypatch.setattr(env, "TILELANG_CACHE_DIR", str(cache_dir)) + monkeypatch.setattr(env, "TILELANG_TMP_DIR", str(tmp_dir)) + return cache_dir + + +def _make_fake_kernel(tmp_path): + lib_path = tmp_path / "kernel_lib.so" + lib_path.write_bytes(b"fake-so") + return _FakeKernel(str(lib_path)) + + +def _make_fake_nvrtc_kernel(tmp_path): + lib_path = tmp_path / "kernel.cubin" + lib_path.write_bytes(b"fake-cubin") + lib_path.with_suffix(".py").write_text("# fake launcher") + return _FakeKernel(str(lib_path)) + + +def test_kernel_cache_rewrites_incomplete_cache_dir(cache_dirs, tmp_path): + cache = KernelCache() + key = "atomic-repair" + cache_path = cache_dirs / key + cache_path.mkdir() + (cache_path / "stale.txt").write_text("partial") + + cache._save_kernel_to_disk(key, _make_fake_kernel(tmp_path)) + + assert (cache_path / cache.device_kernel_path).exists() + assert (cache_path / cache.host_kernel_path).exists() + assert (cache_path / cache.kernel_lib_path).exists() + assert (cache_path / cache.params_path).exists() + assert not (cache_path / "stale.txt").exists() + + +def test_kernel_cache_logs_write_oserror_instead_of_treating_it_as_race(cache_dirs, tmp_path, monkeypatch): + cache = KernelCache() + key = "atomic-write-error" + logged = [] + + def raise_write_error(*args, **kwargs): + raise OSError(errno.ENOSPC, "No space left on device") + + def record_exception(message, *args, **kwargs): + logged.append(message) + + monkeypatch.setattr(cache, "_save_so_cubin_to_disk", raise_write_error) + monkeypatch.setattr(cache.logger, "exception", record_exception) + + cache._save_kernel_to_disk(key, _make_fake_kernel(tmp_path)) + + assert f"{key}" not in {path.name for path in cache_dirs.iterdir()} + assert "Error during atomic cache save" in logged + assert not any(path.name.startswith(".staging_") for path in cache_dirs.iterdir()) + + +def test_kernel_cache_does_not_publish_incomplete_dir_when_device_source_is_missing(cache_dirs, tmp_path, monkeypatch): + cache = KernelCache() + key = "atomic-missing-device-source" + kernel = _make_fake_kernel(tmp_path) + kernel.kernel_source = None + logged = [] + + def record_exception(message, *args, **kwargs): + logged.append(message) + + monkeypatch.setattr(cache.logger, "exception", record_exception) + + cache._save_kernel_to_disk(key, kernel) + + assert f"{key}" not in {path.name for path in cache_dirs.iterdir()} + assert "Error during atomic cache save" in logged + assert not any(path.name.startswith(".staging_") for path in cache_dirs.iterdir()) + + +def test_nvrtc_kernel_cache_rewrites_dir_missing_launcher(cache_dirs, tmp_path): + cache = NVRTCKernelCache() + key = "nvrtc-atomic-repair" + cache_path = cache_dirs / key + cache_path.mkdir() + (cache_path / cache.device_kernel_path).write_text("// device kernel") + (cache_path / cache.host_kernel_path).write_text("// host kernel") + (cache_path / cache.kernel_lib_path).write_bytes(b"old-cubin") + (cache_path / cache.params_path).write_bytes(b"old-params") + (cache_path / "legacy.txt").write_text("stale") + + cache._save_kernel_to_disk(key, _make_fake_nvrtc_kernel(tmp_path)) + + assert (cache_path / cache.kernel_py_path).exists() + assert not (cache_path / "legacy.txt").exists() diff --git a/testing/python/tilelang/test_tma_load.py b/testing/python/tilelang/test_tma_load.py deleted file mode 100644 index 02b82ee278..0000000000 --- a/testing/python/tilelang/test_tma_load.py +++ /dev/null @@ -1,68 +0,0 @@ -import pytest -import torch - -import tilelang as tl -import tilelang.language as T - - -N = 4096 -BLOCK = 256 - - -def _get_device_capability() -> tuple[int, int]: - if not torch.cuda.is_available(): - return (0, 0) - return torch.cuda.get_device_capability() - - -def _extract_source(kernel) -> str: - if hasattr(kernel, "get_source"): - source = kernel.get_source() - if isinstance(source, str) and source: - return source - - module = getattr(kernel, "module", None) - if module is not None and hasattr(module, "imported_modules"): - imported = getattr(module, "imported_modules", []) - if imported: - source = imported[0].get_source() - if isinstance(source, str) and source: - return source - - runtime_mod = getattr(kernel, "rt_mod", None) - if runtime_mod is not None and hasattr(runtime_mod, "imported_modules"): - imported = getattr(runtime_mod, "imported_modules", []) - if imported: - source = imported[0].get_source() - if isinstance(source, str) and source: - return source - - raise RuntimeError("Unable to extract generated source from compiled kernel") - - -def _build_1d_tma_copy(): - @T.prim_func - def main(A: T.Buffer((N,), "float16"), B: T.Buffer((N,), "float16")): - with T.Kernel(T.ceildiv(N, BLOCK), threads=128) as bx: - A_shared = T.alloc_shared((BLOCK,), "float16") - T.copy(A[bx * BLOCK : (bx + 1) * BLOCK], A_shared) - T.copy(A_shared, B[bx * BLOCK : (bx + 1) * BLOCK]) - - return main - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") -@pytest.mark.skipif(_get_device_capability()[0] < 9, reason="Hopper (sm90+) is required for TMA") -def test_tma_load_1d_compile_and_run_regression(): - program = _build_1d_tma_copy() - kernel = tl.compile(program, out_idx=[1], target="cuda -arch=sm_90a") - - source = _extract_source(kernel) - assert "cp.async.bulk.tensor" in source - assert ".1d" in source - - a = torch.randn((N,), device="cuda", dtype=torch.float16) - b = torch.empty_like(a) - - kernel(a, b) - torch.testing.assert_close(b, a, atol=0, rtol=0) diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index 6c79b78d47..5e0152e6b9 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -9,10 +9,12 @@ from typing import Callable, Literal, Any from dataclasses import dataclass from pathlib import Path +import errno from tilelang.jit import JITKernel import cloudpickle import os +import shutil from tilelang.engine.param import KernelParam from tilelang import logger import json @@ -193,109 +195,88 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo - kernel_lib.so: The compiled kernel library - params.pkl: The serialized kernel parameters """ - os.makedirs(cache_path, exist_ok=True) # Ensure directory exists + os.makedirs(cache_path, exist_ok=True) # Ensure directory exists. # Save device kernel source code - try: - device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) - if verbose: - logger.debug(f"Saving kernel source code to file: {device_kernel_path}") - if kernel.kernel_source is not None: - self._safe_write_file(device_kernel_path, "w", lambda f: f.write(kernel.kernel_source)) - except Exception as e: - logger.error(f"Error saving kernel source code to disk: {e}") + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) + if verbose: + logger.debug(f"Saving kernel source code to file: {device_kernel_path}") + if kernel.kernel_source is not None: + self._safe_write_file(device_kernel_path, "w", lambda f: f.write(kernel.kernel_source)) # Save host kernel source code (wrapped) - try: - host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) - if verbose: - logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") - # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel - if kernel.execution_backend == "tvm_ffi": - self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_host_source())) - else: - self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_kernel_source())) - except Exception as e: - logger.error(f"Error saving wrapped kernel source code to disk: {e}") + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) + if verbose: + logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") + # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel + if kernel.execution_backend == "tvm_ffi": + self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_host_source())) + else: + self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_kernel_source())) # Save kernel library (backend-specific) - try: - if kernel.execution_backend == "nvrtc": - kernel_lib_file = KERNEL_CUBIN_PATH - elif kernel.execution_backend == "tvm_ffi": - kernel_lib_file = EXECUTABLE_PATH - elif kernel.execution_backend == "cutedsl": - # cutedsl only generates a Python source file as the "library", so save that instead of a .so - kernel_lib_file = KERNEL_PY_PATH - else: - kernel_lib_file = KERNEL_LIB_PATH + kernel_lib_file = self._get_kernel_lib_file(kernel.execution_backend) - kernel_lib_path = os.path.join(cache_path, kernel_lib_file) + kernel_lib_path = os.path.join(cache_path, kernel_lib_file) - if kernel.execution_backend == "nvrtc": - # Save cubin and python helper file - src_lib_path = kernel.adapter.libpath - kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH) - py_src_path = src_lib_path.replace(".cubin", ".py") - if verbose: - logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") - self._safe_write_file(kernel_py_path, "wb", lambda f: f.write(self._load_binary(py_src_path))) - if verbose: - logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) - elif kernel.execution_backend == "tvm_ffi": - if hasattr(kernel.adapter, "libpath") and kernel.adapter.libpath: - src_lib_path = kernel.adapter.libpath - if verbose: - logger.debug(f"Copying kernel library to file: {kernel_lib_path}") - self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) - else: - executable = kernel.adapter.executable - if verbose: - logger.debug(f"Saving kernel executable to file: {kernel_lib_path}") - self._safe_write_executable(executable, kernel_lib_path) - elif kernel.execution_backend == "cutedsl": - # Save the Python source file (CuTeDSL "library" is a .py, not a .so) + if kernel.execution_backend == "nvrtc": + # Save cubin and python helper file + src_lib_path = kernel.adapter.libpath + kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH) + py_src_path = src_lib_path.replace(".cubin", ".py") + if verbose: + logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") + self._safe_write_file(kernel_py_path, "wb", lambda f: f.write(self._load_binary(py_src_path))) + if verbose: + logger.debug(f"Saving kernel library to file: {kernel_lib_path}") + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) + elif kernel.execution_backend == "tvm_ffi": + if hasattr(kernel.adapter, "libpath") and kernel.adapter.libpath: src_lib_path = kernel.adapter.libpath if verbose: - logger.debug(f"Saving CuTeDSL kernel Python source to file: {kernel_lib_path}") + logger.debug(f"Copying kernel library to file: {kernel_lib_path}") self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) - - # Save launcher .so if present (compiled C++ launcher for TMA etc.) - lib_gen = kernel.adapter.lib_generator - launcher_src = getattr(lib_gen, "launcher_libpath", None) - if launcher_src and os.path.exists(launcher_src): - launcher_name = getattr(lib_gen, "launcher_libname", os.path.basename(launcher_src)) - dst_launcher = os.path.join(cache_path, launcher_name) - if verbose: - logger.debug(f"Saving CuTeDSL launcher library to file: {dst_launcher}") - self._safe_write_file(dst_launcher, "wb", lambda f: f.write(self._load_binary(launcher_src))) - - # Save cubin if already generated (generated during autotuning benchmark) - src_dir = os.path.dirname(src_lib_path) - src_cubin = os.path.join(src_dir, "kernel.cubin") - if os.path.exists(src_cubin): - dst_cubin = os.path.join(cache_path, KERNEL_CUBIN_PATH) - if verbose: - logger.debug(f"Saving CuTeDSL cubin to file: {dst_cubin}") - self._safe_write_file(dst_cubin, "wb", lambda f: f.write(self._load_binary(src_cubin))) else: - src_lib_path = kernel.adapter.libpath + executable = kernel.adapter.executable if verbose: - logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) - - except Exception as e: - logger.error(f"Error saving kernel library to disk: {e}") + logger.debug(f"Saving kernel executable to file: {kernel_lib_path}") + self._safe_write_executable(executable, kernel_lib_path) + elif kernel.execution_backend == "cutedsl": + # Save the Python source file (CuTeDSL "library" is a .py, not a .so) + src_lib_path = kernel.adapter.libpath + if verbose: + logger.debug(f"Saving CuTeDSL kernel Python source to file: {kernel_lib_path}") + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) + + # Save launcher .so if present (compiled C++ launcher for TMA etc.) + lib_gen = kernel.adapter.lib_generator + launcher_src = getattr(lib_gen, "launcher_libpath", None) + if launcher_src and os.path.exists(launcher_src): + launcher_name = getattr(lib_gen, "launcher_libname", os.path.basename(launcher_src)) + dst_launcher = os.path.join(cache_path, launcher_name) + if verbose: + logger.debug(f"Saving CuTeDSL launcher library to file: {dst_launcher}") + self._safe_write_file(dst_launcher, "wb", lambda f: f.write(self._load_binary(launcher_src))) + + # Save cubin if already generated (generated during autotuning benchmark) + src_dir = os.path.dirname(src_lib_path) + src_cubin = os.path.join(src_dir, "kernel.cubin") + if os.path.exists(src_cubin): + dst_cubin = os.path.join(cache_path, KERNEL_CUBIN_PATH) + if verbose: + logger.debug(f"Saving CuTeDSL cubin to file: {dst_cubin}") + self._safe_write_file(dst_cubin, "wb", lambda f: f.write(self._load_binary(src_cubin))) + else: + src_lib_path = kernel.adapter.libpath + if verbose: + logger.debug(f"Saving kernel library to file: {kernel_lib_path}") + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) # Save kernel parameters - try: - params_path = os.path.join(cache_path, PARAMS_PATH) - if verbose: - logger.debug(f"Saving kernel parameters to disk: {params_path}") - self._safe_write_file(params_path, "wb", lambda f: cloudpickle.dump(kernel.params, f)) - except Exception as e: - logger.error(f"Error saving kernel parameters to disk: {e}") + params_path = os.path.join(cache_path, PARAMS_PATH) + if verbose: + logger.debug(f"Saving kernel parameters to disk: {params_path}") + self._safe_write_file(params_path, "wb", lambda f: cloudpickle.dump(kernel.params, f)) def _load_kernel_from_disk( self, @@ -330,21 +311,15 @@ def _load_kernel_from_disk( return None # Resolve backend to pick correct file names - if execution_backend == "nvrtc": - kernel_lib_file = KERNEL_CUBIN_PATH - elif execution_backend == "tvm_ffi": - kernel_lib_file = EXECUTABLE_PATH - elif execution_backend == "cutedsl": - kernel_lib_file = KERNEL_PY_PATH - else: - kernel_lib_file = KERNEL_LIB_PATH + kernel_lib_file = self._get_kernel_lib_file(execution_backend) device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) kernel_lib_path = os.path.join(cache_path, kernel_lib_file) params_path = os.path.join(cache_path, PARAMS_PATH) - if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]): + required_files = [*self._get_required_kernel_files(Path(cache_path), execution_backend), Path(params_path)] + if not all(file.exists() for file in required_files): return None device_kernel_source: str | None = None @@ -396,52 +371,88 @@ def _load_kernel_from_disk( return None def save_to_disk(self, path: Path, verbose: bool = False): - if not os.path.exists(path): - os.makedirs(path) + """Persist autotune result to disk using atomic directory rename. - # save best config (atomic) - if verbose: - logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}") - self._safe_write_file(str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f)) + All files are written into a temporary staging directory next to the + final *path*. Once complete, the staging directory is atomically + renamed to *path* so that concurrent readers never see a half-written + result. + """ + # Already saved (e.g. another process won the race with a complete entry). + if self._is_complete_result_dir(path, self.kernel.execution_backend): + return - # save function (atomic) - if verbose: - logger.debug(f"Saving function to file: {path / FUNCTION_PATH}") - self._safe_write_file(str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f)) + # Staging dir lives under TILELANG_CACHE_DIR (not the autotuner subdir) so that + # KernelCache._cleanup_stale_staging_dirs() can find and clean up stale entries. + staging_path = Path(env.TILELANG_CACHE_DIR) / f".staging_{Path(path).name}_{os.getpid()}_{uuid.uuid4().hex[:8]}" + os.makedirs(staging_path) + # Ensure the parent of the final path exists (e.g. ~/.tilelang/cache/autotuner/) + os.makedirs(Path(path).parent, exist_ok=True) - # save out idx (atomic) - if verbose: - logger.debug(f"Saving out idx to file: {path / OUT_IDX_PATH}") - self._safe_write_file( - str(path / OUT_IDX_PATH), - "w", - lambda f: json.dump( - { - "out_idx": list(self.func.attrs["tilelang_out_idx"]) - if (self.func.attrs and "tilelang_out_idx" in self.func.attrs) - else None - }, - f, - ), - ) + try: + # save best config + if verbose: + logger.debug(f"Saving best config to file: {staging_path / BEST_CONFIG_PATH}") + self._safe_write_file(str(staging_path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f)) - # save ref latency (atomic) - if verbose: - logger.debug(f"Saving latency to file: {path / LATENCY_PATH}") - self._safe_write_file( - str(path / LATENCY_PATH), - "w", - lambda f: json.dump( - { - "latency": self.latency, - "ref_latency": self.ref_latency, - }, - f, - ), - ) + # save function + if verbose: + logger.debug(f"Saving function to file: {staging_path / FUNCTION_PATH}") + self._safe_write_file(str(staging_path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f)) + + # save out idx + if verbose: + logger.debug(f"Saving out idx to file: {staging_path / OUT_IDX_PATH}") + self._safe_write_file( + str(staging_path / OUT_IDX_PATH), + "w", + lambda f: json.dump( + { + "out_idx": list(self.func.attrs["tilelang_out_idx"]) + if (self.func.attrs and "tilelang_out_idx" in self.func.attrs) + else None, + }, + f, + ), + ) + + # save latency + if verbose: + logger.debug(f"Saving latency to file: {staging_path / LATENCY_PATH}") + self._safe_write_file( + str(staging_path / LATENCY_PATH), + "w", + lambda f: json.dump( + { + "latency": self.latency, + "ref_latency": self.ref_latency, + }, + f, + ), + ) - # save kernel - self._save_kernel_to_disk(path, self.kernel) + # save kernel + self._save_kernel_to_disk(staging_path, self.kernel, verbose) + + missing_files = self._get_missing_complete_result_files(staging_path, self.kernel.execution_backend) + if missing_files: + missing_names = ", ".join(path.name for path in missing_files) + raise RuntimeError(f"Incomplete autotune staging directory is missing required file(s): {missing_names}") + + # Repair stale/incomplete entries before making the new directory visible. + self._remove_incomplete_result_dir(path, self.kernel.execution_backend) + + # Atomic rename — directory becomes visible in one step. + try: + os.rename(str(staging_path), str(path)) + except OSError as exc: + if not self._is_rename_collision(exc): + raise + # Another process won the race with a complete cache entry. + shutil.rmtree(str(staging_path), ignore_errors=True) + except Exception: + shutil.rmtree(str(staging_path), ignore_errors=True) + logger.exception("Error during atomic autotune result save") @classmethod def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult: @@ -511,3 +522,55 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult ref_latency=ref_latency, ) return result + + @staticmethod + def _get_kernel_lib_file(execution_backend: str) -> str: + if execution_backend == "nvrtc": + return KERNEL_CUBIN_PATH + if execution_backend == "tvm_ffi": + return EXECUTABLE_PATH + if execution_backend == "cutedsl": + return KERNEL_PY_PATH + return KERNEL_LIB_PATH + + @classmethod + def _get_required_kernel_files(cls, path: Path, execution_backend: str) -> list[Path]: + files = [path / cls._get_kernel_lib_file(execution_backend)] + if execution_backend == "nvrtc": + files.append(path / KERNEL_PY_PATH) + return files + + @classmethod + def _get_complete_result_files(cls, path: Path, execution_backend: str) -> list[Path]: + return list( + dict.fromkeys( + [ + path / BEST_CONFIG_PATH, + path / FUNCTION_PATH, + path / LATENCY_PATH, + path / DEVICE_KERNEL_PATH, + path / HOST_KERNEL_PATH, + *cls._get_required_kernel_files(path, execution_backend), + path / PARAMS_PATH, + ] + ) + ) + + @classmethod + def _get_missing_complete_result_files(cls, path: Path, execution_backend: str) -> list[Path]: + return [file for file in cls._get_complete_result_files(path, execution_backend) if not file.exists()] + + @classmethod + def _is_complete_result_dir(cls, path: Path, execution_backend: str) -> bool: + return path.is_dir() and not cls._get_missing_complete_result_files(path, execution_backend) + + @classmethod + def _remove_incomplete_result_dir(cls, path: Path, execution_backend: str) -> bool: + if not path.is_dir() or cls._is_complete_result_dir(path, execution_backend): + return False + shutil.rmtree(path) + return True + + @staticmethod + def _is_rename_collision(exc: OSError) -> bool: + return exc.errno in {errno.EEXIST, errno.ENOTEMPTY} diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index ff4f3c1f0e..fc3dbc68f4 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -5,6 +5,7 @@ import functools import json import logging +import errno import os import shutil import threading @@ -131,6 +132,27 @@ def __new__(cls): def _create_dirs(): os.makedirs(env.TILELANG_CACHE_DIR, exist_ok=True) os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True) + KernelCache._cleanup_stale_staging_dirs() + + @staticmethod + def _cleanup_stale_staging_dirs(max_age_seconds: int = 3600): + """Remove staging directories older than *max_age_seconds* (default 1 h). + + These are left behind when a process crashes mid-save. + """ + import time + + try: + now = time.time() + for entry in os.scandir(env.TILELANG_CACHE_DIR): + if entry.name.startswith(".staging_") and entry.is_dir(follow_symlinks=False): + try: + if now - entry.stat().st_mtime > max_age_seconds: + shutil.rmtree(entry.path, ignore_errors=True) + except OSError: + pass + except OSError: + pass def _generate_key( self, @@ -334,52 +356,67 @@ def _safe_write_executable(cls, executable: Executable, path: str): def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None, verbose: bool = False): """ - Persists a compiled kernel to disk cache. + Persists a compiled kernel to disk cache using atomic directory rename. + + All files are first written into a temporary staging directory under + TILELANG_CACHE_DIR. Once every file is in place, the staging directory + is atomically renamed to the final cache path so that other processes + never observe an incomplete cache entry. Args: key (str): The hash key identifying the kernel. kernel (JITKernel): The compiled kernel to be saved. func (Callable, optional): The original function. verbose (bool): Enable verbose log messages. - - Note: - Saves the following files: - - kernel.cu: The compiled kernel source code - - wrapped_kernel.cu: The wrapped kernel source code - - kernel_lib.so: The compiled kernel library - - params.pkl: The serialized kernel parameters """ cache_path = self._get_cache_path(key) - os.makedirs(cache_path, exist_ok=True) # Ensure directory exists - # Save kernel source code - try: - self._save_kernel_source_code_to_disk(kernel, cache_path, verbose) - except Exception: - self.logger.exception("Error saving kernel source code to disk") + # Another process already wrote a complete entry — nothing to do. + if self._is_complete_cache_dir(cache_path): + return - # Save wrapped kernel source code - try: - self._save_wrapper_kernel_code_to_disk(kernel, cache_path, verbose) - except Exception: - self.logger.exception("Error saving host kernel source code to disk") + # Staging dir lives under CACHE_DIR (same filesystem) so os.rename works. + staging_path = os.path.join( + env.TILELANG_CACHE_DIR, + f".staging_{key}_{os.getpid()}_{uuid.uuid4().hex[:8]}", + ) + os.makedirs(staging_path) - # Save the kernel library try: - # Save CUBIN or SO file - self._save_so_cubin_to_disk(kernel, cache_path, verbose) + # Save kernel source code + self._save_kernel_source_code_to_disk(kernel, staging_path, verbose) - except Exception: - self.logger.exception("Error saving kernel library to disk") + # Save wrapped kernel source code + self._save_wrapper_kernel_code_to_disk(kernel, staging_path, verbose) - # Save kernel parameters - try: - params_path = os.path.join(cache_path, self.params_path) + # Save the kernel library + self._save_so_cubin_to_disk(kernel, staging_path, verbose) + + # Save kernel parameters + params_path = os.path.join(staging_path, self.params_path) if verbose: self.logger.debug(f"Saving kernel parameters to disk: {params_path}") KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) + + missing_files = self._get_missing_complete_cache_files(staging_path) + if missing_files: + missing_names = ", ".join(os.path.basename(path) for path in missing_files) + raise RuntimeError(f"Incomplete cache staging directory is missing required file(s): {missing_names}") + + # Repair stale/incomplete entries before making the new directory visible. + self._remove_incomplete_cache_dir(cache_path) + + # Atomic rename — makes the complete directory visible in one step. + try: + os.rename(staging_path, cache_path) + except OSError as exc: + if not self._is_rename_collision(exc): + raise + # Another process won the race with a complete cache entry. + shutil.rmtree(staging_path, ignore_errors=True) except Exception: - self.logger.exception("Error saving kernel parameters to disk") + shutil.rmtree(staging_path, ignore_errors=True) + self.logger.exception("Error during atomic cache save") def _load_kernel_from_disk( self, @@ -489,6 +526,33 @@ def _get_required_files(self, cache_path: str) -> list[str]: params_path = os.path.join(cache_path, self.params_path) return [kernel_lib_path, params_path] + def _get_complete_cache_files(self, cache_path: str) -> list[str]: + return list( + dict.fromkeys( + [ + os.path.join(cache_path, self.device_kernel_path), + os.path.join(cache_path, self.host_kernel_path), + *self._get_required_files(cache_path), + ] + ) + ) + + def _get_missing_complete_cache_files(self, cache_path: str) -> list[str]: + return [file for file in self._get_complete_cache_files(cache_path) if not os.path.exists(file)] + + def _is_complete_cache_dir(self, cache_path: str) -> bool: + return os.path.isdir(cache_path) and not self._get_missing_complete_cache_files(cache_path) + + def _remove_incomplete_cache_dir(self, cache_path: str) -> bool: + if not os.path.isdir(cache_path) or self._is_complete_cache_dir(cache_path): + return False + shutil.rmtree(cache_path) + return True + + @staticmethod + def _is_rename_collision(exc: OSError) -> bool: + return exc.errno in {errno.EEXIST, errno.ENOTEMPTY} + def _load_kernel_source(self, device_kernel_path: str, host_kernel_path: str, verbose: bool = False) -> tuple[str | None, str | None]: try: if verbose: diff --git a/tilelang/jit/adapter/nvrtc/kernel_cache.py b/tilelang/jit/adapter/nvrtc/kernel_cache.py index 754ab61479..fb71b76418 100644 --- a/tilelang/jit/adapter/nvrtc/kernel_cache.py +++ b/tilelang/jit/adapter/nvrtc/kernel_cache.py @@ -8,6 +8,9 @@ class NVRTCKernelCache(KernelCache): kernel_lib_path = "kernel.cubin" kernel_py_path = "kernel.py" + def _get_required_files(self, cache_path: str) -> list[str]: + return super()._get_required_files(cache_path) + [os.path.join(cache_path, self.kernel_py_path)] + def _save_so_cubin_to_disk(self, kernel: JITKernel, cache_path: str, verbose: bool = False): src_lib_path = kernel.adapter.libpath kernel_py_path = os.path.join(cache_path, self.kernel_py_path) From a70fa263ad322266a63f0c4a7f1c7fa948341496 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 17 Apr 2026 11:18:19 +0800 Subject: [PATCH 075/156] refactor shared memory buffer merge --- .../merge_shared_memory_allocations.cc | 118 +++++++++--------- 1 file changed, 60 insertions(+), 58 deletions(-) diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc index 6d686a2519..511f0400dd 100644 --- a/src/transform/merge_shared_memory_allocations.cc +++ b/src/transform/merge_shared_memory_allocations.cc @@ -272,74 +272,63 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; } - // Visit kAutoScheduleSharedMemoryBoundary bounded scopes. - // - // After ReNestLetStmts, the next boundary marker may be nested inside - // LetStmt / non-boundary AttrStmt chains rather than sitting as a direct - // sibling in a SeqStmt. We therefore recursively peel through such - // wrappers to find the innermost SeqStmt and locate the boundary there. - void VisitBoundedNewScopes(const AttrStmtNode *op) { + // Open a new boundary scope for a kAutoScheduleSharedMemoryBoundary marker. + // Pushes a scope sentinel onto linear_seq_ and records the begin index. + void OpenBoundaryScope(const AttrStmtNode *op) { scope_.push_back(StmtEntry()); StmtEntry e; e.stmt = op; UpdateStmtAttr(op, scope_level_); - int64_t begin_index = static_cast(linear_seq_.size()); - // before scope. + boundary_scope_begin_index_ = static_cast(linear_seq_.size()); linear_seq_.push_back(e); - bool has_tail_stmt = false; - const AttrStmtNode *tail_stmt = nullptr; - - // Recursively visit the body, peeling LetStmt / non-boundary AttrStmt - // wrappers. When a SeqStmt is reached, scan its children for the next - // boundary marker. Everything that is not the boundary is visited - // normally so that buffer accesses are recorded in this scope. - std::function VisitBodyFindBoundary = - [&](const Stmt &body) { - if (const auto *seq = body.as()) { - for (const auto &sub_stmt : seq->seq) { - if (const auto *attr = sub_stmt.as(); - attr && - attr->attr_key == attr::kAutoScheduleSharedMemoryBoundary) { - has_tail_stmt = true; - tail_stmt = attr; - } else { - StmtExprVisitor::VisitStmt(sub_stmt); - } - } - } else if (const auto *let = body.as()) { - // Record the let-binding variable/value, then recurse into body. - StmtExprVisitor::VisitExpr(let->value); - VisitBodyFindBoundary(let->body); - } else if (const auto *attr = body.as()) { - if (attr->attr_key == attr::kAutoScheduleSharedMemoryBoundary) { - // The body itself is a boundary — treat it as the tail. - has_tail_stmt = true; - tail_stmt = attr; - } else { - // Non-boundary AttrStmt wrapper — visit value and recurse. - StmtExprVisitor::VisitExpr(attr->value); - VisitBodyFindBoundary(attr->body); - } - } else { - // Any other statement — visit normally. - StmtExprVisitor::VisitStmt(body); - } - }; - - VisitBodyFindBoundary(op->body); + in_boundary_scope_ = true; + } - // after scope. + // Close the current boundary scope. Pops the scope, writes the end + // sentinel and patches up the scope_pair_offset links. + void CloseBoundaryScope(const AttrStmtNode *op) { + ICHECK(in_boundary_scope_); + StmtEntry e; + e.stmt = op; + UpdateStmtAttr(op, scope_level_); e.touched = std::move(scope_.back().touched); scope_.pop_back(); int64_t end_index = static_cast(linear_seq_.size()); - ICHECK_GT(end_index, begin_index); - e.scope_pair_offset = begin_index - end_index; + ICHECK_GT(end_index, boundary_scope_begin_index_); + e.scope_pair_offset = boundary_scope_begin_index_ - end_index; linear_seq_.push_back(e); ICHECK_NE(end_index, 0U); - linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; - // visit tail statement (the next boundary scope). - if (has_tail_stmt) { - StmtExprVisitor::VisitStmt_(tail_stmt); + linear_seq_[boundary_scope_begin_index_].scope_pair_offset = + end_index - boundary_scope_begin_index_; + in_boundary_scope_ = false; + } + + // Recursively visit the body of a boundary AttrStmt, peeling through + // LetStmt / non-boundary AttrStmt / SeqStmt wrappers. When a nested + // boundary marker is encountered it is dispatched back through + // VisitStmt_ which will close the current scope and open a new one. + void VisitBoundaryBody(const Stmt &body) { + if (const auto *seq = body.as()) { + for (const auto &sub_stmt : seq->seq) { + if (const auto *attr = sub_stmt.as(); + attr && + attr->attr_key == attr::kAutoScheduleSharedMemoryBoundary) { + this->VisitStmt_(attr); + } else { + StmtExprVisitor::VisitStmt(sub_stmt); + } + } + } else if (const auto *let = body.as()) { + StmtExprVisitor::VisitExpr(let->value); + VisitBoundaryBody(let->body); + } else if (const auto *attr = body.as()) { + if (attr->attr_key == attr::kAutoScheduleSharedMemoryBoundary) { + this->VisitStmt_(attr); + } else { + VisitBoundaryBody(attr->body); + } + } else { + StmtExprVisitor::VisitStmt(body); } } @@ -356,7 +345,16 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { } else if (op->attr_key == "kWarpSpecializationScope") { VisitWarpSpecializationBody(op->body); } else if (op->attr_key == "kAutoScheduleSharedMemoryBoundary") { - VisitBoundedNewScopes(op); + if (in_boundary_scope_) { + CloseBoundaryScope( + static_cast( + linear_seq_[boundary_scope_begin_index_].stmt)); + } + OpenBoundaryScope(op); + VisitBoundaryBody(op->body); + if (in_boundary_scope_) { + CloseBoundaryScope(op); + } } else { StmtExprVisitor::VisitStmt_(op); } @@ -437,6 +435,10 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { bool verbose_{false}; // Whether already in thread env. bool in_thread_env_{false}; + // Whether we are currently inside a boundary scope. + bool in_boundary_scope_{false}; + // The begin index in linear_seq_ of the current boundary scope. + int64_t boundary_scope_begin_index_{0}; // The scope stack. std::vector scope_; // The size of the scope. From 189c99fab76df241a44e8b101d282c8abd6b18a8 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 17 Apr 2026 12:42:02 +0800 Subject: [PATCH 076/156] upload latency & ii --- .../auto_schedule/schedule_builder.cc | 22 +++++++++++++++++-- tilelang/transform/z3_scheduler.py | 2 +- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 45d8cecb3a..dde0c242f1 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -90,7 +90,6 @@ void GatherTaskNodes(const std::vector> &nodes, } else if (node->IsControl()) { task_nodes.emplace_back(node); } else if (node->IsIf()) { - // IfNode is atomic — add as whole unit, don't decompose task_nodes.emplace_back(node); } else { LOG(FATAL) << "Unknown node type in GatherTaskNodes"; @@ -610,6 +609,12 @@ void ScheduleUnitBuilder::ScheduleRecursive( } seq->children = ChildrenScheduleHelper(origin_children); + int64_t overall_latency = 0; + for (const auto &child : seq->children) { + overall_latency += child->GetLatency(); + } + seq->SetLatency(overall_latency); + seq->SetII(overall_latency); return; } else if (node->IsControl()) { auto ctrl = static_cast(node.get()); @@ -655,6 +660,11 @@ void ScheduleUnitBuilder::ScheduleRecursive( Z3SchedulePythonLoop(ctrl, used_buffers); } else { ScheduleRecursive(ctrl->child, used_buffers); + auto old_child = ctrl->child; + auto seq_node = std::make_shared(); + seq_node->children = {old_child}; + ctrl->child = seq_node; + Z3SchedulePythonLoop(ctrl, used_buffers); } } return; @@ -678,10 +688,15 @@ void ScheduleUnitBuilder::ScheduleRecursive( } auto seq_node = std::make_shared(); seq_node->children = ChildrenScheduleHelper(origin_children); + int64_t overall_latency = 0; + for (const auto &child : seq_node->children) { + overall_latency += child->GetLatency(); + } + seq_node->SetLatency(overall_latency); + seq_node->SetII(overall_latency); node = seq_node; return; } else if (node->IsIf()) { - // IfNode: recursively schedule both branches internally auto if_node = static_cast(node.get()); if (if_node->then_child) { ScheduleRecursive(if_node->then_child, used_buffers); @@ -689,6 +704,9 @@ void ScheduleUnitBuilder::ScheduleRecursive( if (if_node->else_child) { ScheduleRecursive(if_node->else_child, used_buffers); } + if_node->SetLatency(std::max(if_node->then_child ? if_node->then_child->GetLatency() : 0, + if_node->else_child ? if_node->else_child->GetLatency() : 0)); + if_node->SetII(if_node->GetLatency()); return; } diff --git a/tilelang/transform/z3_scheduler.py b/tilelang/transform/z3_scheduler.py index ebbaf7f4e2..d515db1602 100644 --- a/tilelang/transform/z3_scheduler.py +++ b/tilelang/transform/z3_scheduler.py @@ -323,7 +323,7 @@ def z3_schedule_loop_python( for i in range(len(buffer_sizes)): solver.add(buffer_vars[i] >= 1) solver.add(buffer_vars[i] <= num_stages) - # solver.add(z3.Sum([buffer_vars[i] * buffer_sizes[i] for i in range(len(buffer_sizes))]) <= memory_limit) + solver.add(z3.Sum([buffer_vars[i] * buffer_sizes[i] for i in range(len(buffer_sizes))]) <= memory_limit) # Add data dependency constraints with distance for u, v, distance in data_deps: From cb00a60c37c3bfa6185c419c20c99e021a991c85 Mon Sep 17 00:00:00 2001 From: LJC00118 <77378439+LJC00118@users.noreply.github.com> Date: Fri, 17 Apr 2026 13:37:05 +0800 Subject: [PATCH 077/156] Replace syntactic loop-var checks with invariance checks (#2050) * Replace syntactic loop-var checks with invariance checks * Fix lint --- src/transform/loop_vectorize.cc | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index c2655212b8..6f6dd239e1 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -257,15 +257,17 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { // reduction-like pattern where ComputeBufferVectorSize has already // returned 1 to disable vectorization, and that constraint must not // be dropped (memory strategy ignores local_fragment_min). - bool depends_on_loop_var = - !info.indices.empty() && inner_for_ && - std::any_of(info.indices.begin(), info.indices.end(), - [&](const PrimExpr &idx) { - return UsesVar(idx, [&](const VarNode *v) { - return v == inner_for_->loop_var.get(); - }); - }); - if (depends_on_loop_var || info.is_store) { + bool depends_on_loop_var = true; + if (!info.indices.empty() && inner_for_) { + Array strides = GetBufferStrides(info.buffer); + PrimExpr elem_offset = 0; + for (size_t i = 0; i < info.indices.size(); ++i) { + elem_offset += info.indices[i] * strides[i]; + } + depends_on_loop_var = !IsExprInvariantInVectorBoundary( + elem_offset, inner_for_->loop_var, vector_size_, analyzer_); + } + if (depends_on_loop_var) { memory_min = arith::ZeroAwareGCD(memory_min, info.vector_size); has_global_or_shared_buffer = true; } else { From 27f1f81c61291b0fe27d28fbe81a1fa15ae8995e Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Fri, 17 Apr 2026 15:19:04 +0800 Subject: [PATCH 078/156] [Feature][Example] Introduce CLC tile schedule and add example for sm100 GEMM (#2029) * [Feature] Add Blackwell Cluster Launch Control (CLC) primitives Expose cluster launch control query operations (try_cancel, try_cancel_multicast, is_canceled, get_first_ctaid_{x,y,z}) in both the C++ codegen path and the Python tilelang.language API, plus a new mbarrier_arrive_expect_tx builtin used with CLC multicast completion. Enables writing persistent SM100 kernels that use the hardware scheduler to fetch the next tile coordinate. Co-Authored-By: Claude Opus 4.6 (1M context) * [Example] Add SM100 persistent 2-CTA GEMM example using CLC scheduling Adds examples/gemm_sm100/gemm_tcgen5mma_ws_clc.py, a warp-specialized persistent GEMM for Blackwell (sm_100) that uses tcgen05 2-CTA MMA with cluster launch control tile scheduling. The kernel is split into producer/MMA/scheduler/epilogue warp groups, with the scheduler warp running on both CTAs and feeding next-tile coordinates via the CLC hardware multicast into schedule_arrived directly (no intermediate barrier). Storage sync is disabled and all syncs are placed manually. Co-Authored-By: Claude Opus 4.6 (1M context) * lint * typo Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * remove unnecessary test * Add CLC query intrinsics and async fence * Refactor CLC response handling by introducing CLCResponseDecode struct. Update clc_get_first_ctaid_x/y/z and clc_is_canceled functions to utilize the new decode response method for improved clarity and maintainability. * Enhance CLC scheduling in persistent GEMM by introducing shared variables for schedule validity and tile ID. Update cancellation checks to utilize the new shared variables, improving clarity and maintainability of the scheduling logic. * test fix --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Zhiwen Mo Co-authored-by: LeiWang1999 --- examples/gemm_sm100/gemm_tcgen5mma_ws_clc.py | 212 ++++++++++++++++++ src/op/builtin.cc | 30 +++ src/op/builtin.h | 49 ++++ src/target/codegen_cuda.cc | 18 ++ src/tl_templates/cuda/cluster.h | 109 +++++++++ .../test_tilelang_language_cluster.py | 23 ++ tilelang/language/__init__.py | 6 + tilelang/language/builtin.py | 8 + tilelang/language/cluster.py | 71 ++++++ 9 files changed, 526 insertions(+) create mode 100644 examples/gemm_sm100/gemm_tcgen5mma_ws_clc.py diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws_clc.py b/examples/gemm_sm100/gemm_tcgen5mma_ws_clc.py new file mode 100644 index 0000000000..10ac66937a --- /dev/null +++ b/examples/gemm_sm100/gemm_tcgen5mma_ws_clc.py @@ -0,0 +1,212 @@ +# Introduce CLC tile schedule + +import torch +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench + + +def get_swizzled_block_idx(tile_id, group_size, m_clusters, cta_id): + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + return bx, by + + +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True}) +def gemm_clc_persistent_2cta( + A, + B, + block_M, + block_N, + store_block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + group_size=8, + use_tma_store=True, +): + M, N, K = T.const("M, N, K") + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[K, N], in_dtype] + C = T.empty((M, N), out_dtype) + + m_blocks = T.ceildiv(M, block_M) + m_clusters = m_blocks // 2 + n_blocks = T.ceildiv(N, block_N) + total_cluster_tiles = m_clusters * n_blocks + k_blocks = T.ceildiv(K, block_K) + assert n_blocks % (2 * group_size) == 0 + + with T.Kernel(total_cluster_tiles * 2, threads=256, cluster_dims=2) as block_id: + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared((num_stages, block_K, block_N // 2), in_dtype) + C_tmem_0 = T.alloc_tmem([block_M, block_N], accum_dtype) + C_tmem_1 = T.alloc_tmem([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) + loaded = T.alloc_cluster_barrier([32 * 2] * num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_cluster_barrier([1] * 2) + tmem_empty = T.alloc_cluster_barrier([128 * 2] * 2) + schedule_arrived = T.alloc_cluster_barrier([1]) + schedule_finished = T.alloc_cluster_barrier([7]) + clc_result = T.alloc_shared((4,), "uint32", scope="shared") + schedule_valid = T.alloc_shared((1,), "int32") + schedule_tile_id = T.alloc_shared((1,), "int32") + + tx = T.get_thread_binding() + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) + + if tx < 32: # Producer (TMA loads) + for work_iter in T.unroll(total_cluster_tiles): + if work_iter > 0: + T.mbarrier_wait_parity(schedule_arrived, (work_iter - 1) & 1) + if tx == 0: + T.mbarrier_arrive(schedule_finished, 0) + if schedule_valid[0] == 0: + break + + tile_id = T.if_then_else( + work_iter == 0, + block_id // 2, + schedule_tile_id[0], + ) + bx, by = get_swizzled_block_idx(tile_id, group_size, m_clusters, cta_id) + + for k in T.serial(k_blocks): + phase = work_iter * k_blocks + k + T.mbarrier_wait_parity(consumed[phase % num_stages], ((phase // num_stages) & 1) ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[phase % num_stages, :, :], + barrier=loaded[phase % num_stages], + ) + T.tma_copy( + B[k * block_K : (k + 1) * block_K, (by * 2 + cta_id) * block_N // 2 : (by * 2 + cta_id + 1) * block_N // 2], + B_shared[phase % num_stages, :, :], + barrier=loaded[phase % num_stages], + ) + T.mbarrier_arrive(loaded[phase % num_stages], 0) + + elif cta_id == 0 and tx < 64: # MMA (cta_id 0 only) + for work_iter in T.unroll(total_cluster_tiles): + if work_iter > 0: + T.mbarrier_wait_parity(schedule_arrived, (work_iter - 1) & 1) + if tx == 32: + T.mbarrier_arrive(schedule_finished, 0) + if schedule_valid[0] == 0: + break + + T.mbarrier_wait_parity(tmem_empty[work_iter & 1], ((work_iter // 2) & 1) ^ 1) + for k in T.serial(k_blocks): + phase = work_iter * k_blocks + k + T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1) + if work_iter & 1 == 0: + T.tcgen05_gemm( + A_shared[phase % num_stages, :, :], + B_shared[phase % num_stages, :, :], + C_tmem_0, + mbar=consumed[phase % num_stages], + clear_accum=k == 0, + use_2cta=True, + ) + else: + T.tcgen05_gemm( + A_shared[phase % num_stages, :, :], + B_shared[phase % num_stages, :, :], + C_tmem_1, + mbar=consumed[phase % num_stages], + clear_accum=k == 0, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full[work_iter & 1], arrive_2cta=True) + + elif 64 <= tx < 96: # CLC Scheduler (both CTAs) + for work_iter in T.unroll(total_cluster_tiles): + if tx == 64: + if cta_id == 0 and work_iter > 0: + T.mbarrier_wait_parity(schedule_finished, (work_iter - 1) & 1) + T.mbarrier_arrive_expect_tx(schedule_arrived, 16) + if cta_id == 0: + T.clc_try_cancel_multicast(clc_result, schedule_arrived) + T.mbarrier_wait_parity(schedule_arrived, work_iter & 1) + schedule_valid[0] = T.clc_is_canceled(clc_result) + schedule_tile_id[0] = T.cast(T.clc_get_first_ctaid_x(clc_result), "int32") // 2 + T.mbarrier_arrive(schedule_finished, 0) + if schedule_valid[0] == 0: + break + + elif 128 <= tx < 256: # Epilogue + for work_iter in T.unroll(total_cluster_tiles): + if work_iter > 0: + T.mbarrier_wait_parity(schedule_arrived, (work_iter - 1) & 1) + if tx == 128: + T.mbarrier_arrive(schedule_finished, 0) + if schedule_valid[0] == 0: + break + + tile_id = T.if_then_else( + work_iter == 0, + block_id // 2, + schedule_tile_id[0], + ) + bx, by = get_swizzled_block_idx(tile_id, group_size, m_clusters, cta_id) + + T.mbarrier_wait_parity(tmem_full[work_iter & 1], (work_iter // 2) & 1) + T.sync_threads(1, 128) + if work_iter & 1 == 0: + T.copy(C_tmem_0, C_local) + else: + T.copy(C_tmem_1, C_local) + T.mbarrier_arrive(tmem_empty[work_iter & 1], 0) + + if use_tma_store: + for i in T.unroll(T.ceildiv(block_N, store_block_N)): + T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared) + T.sync_threads(3, 128) + T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N]) + T.sync_threads(3, 128) + else: + T.copy(C_local, C_local_cast) + T.copy(C_local_cast, C[bx * block_M, by * block_N]) + + return C + + +def main(): + M, N, K = 8192, 8192, 8192 + block_M, block_N, block_K = 128, 256, 64 + store_block_N = 64 + in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float + num_stages = 6 + l2_swizzle_group_size = 8 + + kernel_args = (block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, l2_swizzle_group_size) + + # a = (torch.rand(M, K, device="cuda", dtype=torch.bfloat16) * 2 - 1) + # b = (torch.rand(K, N, device="cuda", dtype=torch.bfloat16) * 2 - 1) + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + print(gemm_clc_persistent_2cta.get_kernel_source(a, b, *kernel_args)) + c = gemm_clc_persistent_2cta(a, b, *kernel_args) + + ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16) + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All checks passed. ✅") + + tl_latency = do_bench(lambda: gemm_clc_persistent_2cta(a, b, *kernel_args), backend="cupti") + torch_latency = do_bench(lambda: a @ b, backend="cupti") + print(f"Tilelang latency: {tl_latency} ms") + print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS") + print(f"Torch latency: {torch_latency} ms") + print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS") + + +if __name__ == "__main__": + main() diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 1ccbf03f9e..7de31d5205 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -364,6 +364,36 @@ TIR_DEFINE_TL_BUILTIN(block_rank_in_cluster) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_TL_BUILTIN(clc_try_cancel) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(clc_try_cancel_multicast) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(clc_is_canceled) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(clc_get_first_ctaid_x) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(clc_get_first_ctaid_y) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(clc_get_first_ctaid_z) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + TIR_DEFINE_TL_BUILTIN(sync_grid).set_num_inputs(0).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/builtin.h b/src/op/builtin.h index 93f9dad82a..dae5503a68 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -609,6 +609,55 @@ TVM_DLL const Op &cluster_sync(); */ TVM_DLL const Op &block_rank_in_cluster(); +/*! + * \brief Issue a Blackwell cluster launch control query that writes a 16-byte + * response into shared memory and signals completion on the given mbarrier. + * + * clc_try_cancel(result_ptr, mbar_ptr) + * + */ +TVM_DLL const Op &clc_try_cancel(); + +/*! + * \brief Cluster-wide multicast variant of cluster launch control query. + * + * clc_try_cancel_multicast(result_ptr, mbar_ptr) + * + */ +TVM_DLL const Op &clc_try_cancel_multicast(); + +/*! + * \brief Return 1 when a CLC response represents a successful cancellation. + * + * int32 clc_is_canceled(result_ptr) + * + */ +TVM_DLL const Op &clc_is_canceled(); + +/*! + * \brief Return the x coordinate of the first CTA in a successful CLC response. + * + * uint32 clc_get_first_ctaid_x(result_ptr) + * + */ +TVM_DLL const Op &clc_get_first_ctaid_x(); + +/*! + * \brief Return the y coordinate of the first CTA in a successful CLC response. + * + * uint32 clc_get_first_ctaid_y(result_ptr) + * + */ +TVM_DLL const Op &clc_get_first_ctaid_y(); + +/*! + * \brief Return the z coordinate of the first CTA in a successful CLC response. + * + * uint32 clc_get_first_ctaid_z(result_ptr) + * + */ +TVM_DLL const Op &clc_get_first_ctaid_z(); + /*! * \brief Synchronize all threads in a grid * diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 90eff80e00..1ff51883bd 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2267,6 +2267,24 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::block_rank_in_cluster())) { need_cluster_h_ = true; os << "tl::block_rank_in_cluster()"; + } else if (op->op.same_as(tl::clc_try_cancel())) { + need_cluster_h_ = true; + print_extern_call_stmt("tl::clc_try_cancel"); + } else if (op->op.same_as(tl::clc_try_cancel_multicast())) { + need_cluster_h_ = true; + print_extern_call_stmt("tl::clc_try_cancel_multicast"); + } else if (op->op.same_as(tl::clc_is_canceled())) { + need_cluster_h_ = true; + os << "tl::clc_is_canceled(" << this->PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::clc_get_first_ctaid_x())) { + need_cluster_h_ = true; + os << "tl::clc_get_first_ctaid_x(" << this->PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::clc_get_first_ctaid_y())) { + need_cluster_h_ = true; + os << "tl::clc_get_first_ctaid_y(" << this->PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::clc_get_first_ctaid_z())) { + need_cluster_h_ = true; + os << "tl::clc_get_first_ctaid_z(" << this->PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::loop_break())) { this->PrintIndent(); this->stream << "break;\n"; diff --git a/src/tl_templates/cuda/cluster.h b/src/tl_templates/cuda/cluster.h index 0da4dc904d..b353f58be9 100644 --- a/src/tl_templates/cuda/cluster.h +++ b/src/tl_templates/cuda/cluster.h @@ -105,6 +105,115 @@ TL_DEVICE int block_rank_in_cluster() { #endif } +/* Cluster launch control for tile schedule (Available on sm100) */ + +TL_DEVICE void clc_try_cancel(void *result_ptr, void *mbar_ptr) { +#if defined(CUTLASS_ARCH_CLC_ENABLED) + uint32_t result_addr = smem_ptr_to_uint(result_ptr); + uint32_t mbar_addr = smem_ptr_to_uint(mbar_ptr); + asm volatile("{\n\t" + "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::" + "complete_tx::bytes.b128 [%0], [%1];\n\t" + "}\n" + : + : "r"(result_addr), "r"(mbar_addr)); +#else + TILELANG_UNREACHABLE("CUTLASS_ARCH_CLC_ENABLED is not defined"); +#endif +} + +TL_DEVICE void clc_try_cancel_multicast(void *result_ptr, void *mbar_ptr) { +#if defined(CUTLASS_ARCH_CLC_ENABLED) + uint32_t result_addr = smem_ptr_to_uint(result_ptr); + uint32_t mbar_addr = smem_ptr_to_uint(mbar_ptr); + asm volatile("{\n\t" + "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::" + "complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n\t" + "}\n" + : + : "r"(result_addr), "r"(mbar_addr)); +#else + TILELANG_UNREACHABLE("CUTLASS_ARCH_CLC_ENABLED is not defined"); +#endif +} + +// CLC query responses are produced through the async shared-memory proxy and +// must be fenced before normal shared-memory loads decode the 16-byte result. +TL_DEVICE void clc_fence_proxy_async_shared_cta() { +#if defined(CUTLASS_ARCH_CLC_ENABLED) + asm volatile("fence.proxy.async.shared::cta;" : : : "memory"); +#else + TILELANG_UNREACHABLE("CUTLASS_ARCH_CLC_ENABLED is not defined"); +#endif +} + +struct CLCResponseDecode { + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; + uint32_t is_canceled = 0; +}; + +TL_DEVICE CLCResponseDecode clc_decode_response(void const *result_ptr) { +#if defined(CUTLASS_ARCH_CLC_ENABLED) + uint32_t result_addr = smem_ptr_to_uint(result_ptr); + CLCResponseDecode decoded; + asm volatile( + "{\n\t" + ".reg .pred p1;\n\t" + ".reg .b128 clc_result;\n\t" + "ld.shared.b128 clc_result, [%4];\n\t" + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, " + "clc_result;\n\t" + "selp.u32 %3, 1, 0, p1;\n\t" + "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {%0, " + "%1, %2, _}, clc_result;\n\t" + "}\n" + : "=r"(decoded.x), "=r"(decoded.y), "=r"(decoded.z), + "=r"(decoded.is_canceled) + : "r"(result_addr) + : "memory"); + // Match CUTLASS's CLC decode path ordering: decode from shared first, then + // issue the async proxy fence before subsequent shared-memory consumers. + clc_fence_proxy_async_shared_cta(); + return decoded; +#else + TILELANG_UNREACHABLE("CUTLASS_ARCH_CLC_ENABLED is not defined"); +#endif +} + +TL_DEVICE int clc_is_canceled(void const *result_ptr) { +#if defined(CUTLASS_ARCH_CLC_ENABLED) + return static_cast(clc_decode_response(result_ptr).is_canceled); +#else + TILELANG_UNREACHABLE("CUTLASS_ARCH_CLC_ENABLED is not defined"); +#endif +} + +TL_DEVICE uint32_t clc_get_first_ctaid_x(void const *result_ptr) { +#if defined(CUTLASS_ARCH_CLC_ENABLED) + return clc_decode_response(result_ptr).x; +#else + TILELANG_UNREACHABLE("CUTLASS_ARCH_CLC_ENABLED is not defined"); +#endif +} + +TL_DEVICE uint32_t clc_get_first_ctaid_y(void const *result_ptr) { +#if defined(CUTLASS_ARCH_CLC_ENABLED) + return clc_decode_response(result_ptr).y; +#else + TILELANG_UNREACHABLE("CUTLASS_ARCH_CLC_ENABLED is not defined"); +#endif +} + +TL_DEVICE uint32_t clc_get_first_ctaid_z(void const *result_ptr) { +#if defined(CUTLASS_ARCH_CLC_ENABLED) + return clc_decode_response(result_ptr).z; +#else + TILELANG_UNREACHABLE("CUTLASS_ARCH_CLC_ENABLED is not defined"); +#endif +} + // Set the destination block-ID in cluster for a given SMEM Address TL_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) { #if defined(TILELANG_CLUSTER_ENABLED) diff --git a/testing/python/language/test_tilelang_language_cluster.py b/testing/python/language/test_tilelang_language_cluster.py index a45ff52d9b..0e3ee9d38c 100644 --- a/testing/python/language/test_tilelang_language_cluster.py +++ b/testing/python/language/test_tilelang_language_cluster.py @@ -59,6 +59,18 @@ def main(A: T.Tensor((128), T.int32)): return main +def _get_clc_query_codegen_source() -> str: + @T.prim_func + def main(A: T.Tensor((2,), T.int32)): + with T.Kernel(1, threads=1): + result = T.alloc_shared((4,), T.uint32) + A[0] = T.clc_is_canceled(result) + A[1] = T.Cast("int32", T.clc_get_first_ctaid_x(result)) + + artifact = tilelang.lower(main, target="cuda") + return artifact.kernel_source + + def run_cython_cluster_launch(): kernel = matmul(1024, 1024, 1024, 128, 128, 32) mod = tilelang.compile(kernel, execution_backend="cython") @@ -105,5 +117,16 @@ def test_cluster_barrier(): kernel() +@tilelang.testing.requires_cuda +def test_clc_query_codegen_includes_cluster_header(): + src = _get_clc_query_codegen_source() + print("=== clc query codegen ===") + print(src) + + assert "#include " in src + assert "tl::clc_is_canceled(" in src + assert "tl::clc_get_first_ctaid_x(" in src + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 8308d762c2..323ee7e633 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -141,6 +141,12 @@ cluster_wait, # noqa: F401 cluster_sync, # noqa: F401 block_rank_in_cluster, # noqa: F401 + clc_try_cancel, # noqa: F401 + clc_try_cancel_multicast, # noqa: F401 + clc_is_canceled, # noqa: F401 + clc_get_first_ctaid_x, # noqa: F401 + clc_get_first_ctaid_y, # noqa: F401 + clc_get_first_ctaid_z, # noqa: F401 ) diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 150d32a311..525e7c4320 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -496,6 +496,14 @@ def mbarrier_expect_tx(mbarrier: BarrierType, tx: int): return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), mbarrier, tx) +def mbarrier_arrive_expect_tx(mbarrier: BarrierType, tx: int): + """Arrive at a memory barrier and expect completion of async transactions.""" + from tilelang.language.tir.op import ptx_arrive_barrier_expect_tx + + mbarrier = _mbar_to_buffer_load(mbarrier) + return ptx_arrive_barrier_expect_tx(mbarrier, tx) + + def warpgroup_arrive(): """Signal warpgroup readiness for subsequent WGMMA operations. diff --git a/tilelang/language/cluster.py b/tilelang/language/cluster.py index d10a12c348..60fd36a630 100644 --- a/tilelang/language/cluster.py +++ b/tilelang/language/cluster.py @@ -1,4 +1,7 @@ from tvm import tir +from tvm.tir import BufferLoad + +from tilelang.utils.language import retrieve_ptr __all__ = [ "cluster_arrive_relaxed", @@ -6,9 +9,21 @@ "cluster_wait", "cluster_sync", "block_rank_in_cluster", + "clc_try_cancel", + "clc_try_cancel_multicast", + "clc_is_canceled", + "clc_get_first_ctaid_x", + "clc_get_first_ctaid_y", + "clc_get_first_ctaid_z", ] +def _to_ptr(value, access_type: str): + if isinstance(value, BufferLoad): + return retrieve_ptr(value, access_type=access_type) + return retrieve_ptr(value, access_type=access_type) + + def cluster_arrive_relaxed() -> tir.PrimExpr: """Issue barrier.cluster.arrive.relaxed.aligned.""" return tir.call_intrin("void", tir.op.Op.get("tl.cluster_arrive_relaxed")) @@ -32,3 +47,59 @@ def cluster_sync() -> tir.PrimExpr: def block_rank_in_cluster() -> tir.PrimExpr: """Return the 1-D rank of the calling CTA within its cluster (%%cluster_ctarank).""" return tir.call_intrin("int32", tir.op.Op.get("tl.block_rank_in_cluster")) + + +def clc_try_cancel(result, mbarrier) -> tir.PrimExpr: + """Issue a single-CTA cluster launch control query.""" + return tir.call_intrin( + "void", + tir.op.Op.get("tl.clc_try_cancel"), + _to_ptr(result, "w"), + _to_ptr(mbarrier, "rw"), + ) + + +def clc_try_cancel_multicast(result, mbarrier) -> tir.PrimExpr: + """Issue a cluster-wide multicast cluster launch control query.""" + return tir.call_intrin( + "void", + tir.op.Op.get("tl.clc_try_cancel_multicast"), + _to_ptr(result, "w"), + _to_ptr(mbarrier, "rw"), + ) + + +def clc_is_canceled(result) -> tir.PrimExpr: + """Return 1 when the CLC query successfully canceled a future launch.""" + return tir.call_intrin( + "int32", + tir.op.Op.get("tl.clc_is_canceled"), + _to_ptr(result, "r"), + ) + + +def clc_get_first_ctaid_x(result) -> tir.PrimExpr: + """Return the x coordinate of the first CTA in a successful CLC response.""" + return tir.call_intrin( + "uint32", + tir.op.Op.get("tl.clc_get_first_ctaid_x"), + _to_ptr(result, "r"), + ) + + +def clc_get_first_ctaid_y(result) -> tir.PrimExpr: + """Return the y coordinate of the first CTA in a successful CLC response.""" + return tir.call_intrin( + "uint32", + tir.op.Op.get("tl.clc_get_first_ctaid_y"), + _to_ptr(result, "r"), + ) + + +def clc_get_first_ctaid_z(result) -> tir.PrimExpr: + """Return the z coordinate of the first CTA in a successful CLC response.""" + return tir.call_intrin( + "uint32", + tir.op.Op.get("tl.clc_get_first_ctaid_z"), + _to_ptr(result, "r"), + ) From 77cbe6da87d214f27c71e456b2adc83df130858f Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Fri, 17 Apr 2026 16:48:28 +0800 Subject: [PATCH 079/156] [Feature] Introduce T.CUDASourceCodeKernel (#1970) * [WIP][Feature] Introduce CUDARunner * refactor * refactor to T.SourceCodeKernel api * remove streamk change * remove runner * remove prim func params * enhance API * add entry_name * add IsCodeBlockKey * remove misc file * cleanup code * move parse logic to python through callback * fix * comments --- src/target/codegen_c_host.cc | 20 ++- src/target/codegen_cuda.cc | 16 ++ src/target/rt_mod_cuda.cc | 45 ++++- src/transform/common/attr.h | 20 +++ src/transform/lower_device_kernel_launch.cc | 9 +- src/transform/lower_opaque_block.cc | 3 +- src/transform/make_packed_api.cc | 8 + src/transform/split_host_device.cc | 142 ++++++++++++++-- .../test_tilelang_language_source_kernel.py | 87 ++++++++++ tilelang/engine/lower.py | 51 +++++- tilelang/language/__init__.py | 1 + tilelang/language/kernel.py | 156 +++++++++++++++--- 12 files changed, 502 insertions(+), 56 deletions(-) create mode 100644 testing/python/language/test_tilelang_language_source_kernel.py diff --git a/src/target/codegen_c_host.cc b/src/target/codegen_c_host.cc index 7e30862008..6af6bb9725 100644 --- a/src/target/codegen_c_host.cc +++ b/src/target/codegen_c_host.cc @@ -103,15 +103,17 @@ void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar, << "CodeGenCHost: The entry func must have the global_symbol " "attribute, " << "but function " << gvar << " only has attributes " << func->attrs; - function_names_.push_back(tvm::ffi::symbol::tvm_ffi_main); - stream << "// CodegenC: NOTE: Auto-generated entry function\n"; - PrintFuncPrefix(stream); - PrintType(func->ret_type, stream); - stream << " " << tvm::ffi::symbol::tvm_ffi_main - << "(void* self, void* args,int num_args, void* result) {\n"; - stream << " return " << static_cast(global_symbol.value()) - << "(self, args, num_args, result);\n"; - stream << "}\n"; + if (global_symbol.value() != tvm::ffi::symbol::tvm_ffi_main) { + function_names_.push_back(tvm::ffi::symbol::tvm_ffi_main); + stream << "// CodegenC: NOTE: Auto-generated entry function\n"; + PrintFuncPrefix(stream); + PrintType(func->ret_type, stream); + stream << " " << tvm::ffi::symbol::tvm_ffi_main + << "(void* self, void* args,int num_args, void* result) {\n"; + stream << " return " << static_cast(global_symbol.value()) + << "(self, args, num_args, result);\n"; + stream << "}\n"; + } has_main_func_ = true; } } diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 1ff51883bd..28a74a2632 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -4647,6 +4647,22 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name, void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar, const PrimFunc &f) { + auto code_block_source = f->GetAttr(tl::attr::kCodeBlockSource); + if (code_block_source) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol) << "CodeGenTileLangCUDA: Expect PrimFunc to have the " + "global_symbol attribute"; + if (auto code_block_entry_name = + f->GetAttr(tl::attr::kCodeBlockEntryName)) { + ICHECK_EQ(static_cast(global_symbol.value()), + static_cast(code_block_entry_name.value())) + << "T.CUDASourceCodeKernel expects the lowered device global_symbol " + "to match entry_name"; + } + stream << static_cast(code_block_source.value()) << "\n\n"; + return; + } + // If the function has already been forward-declared, this is a // no-op. CodeGenC::DeclareFunction(gvar, f); diff --git a/src/target/rt_mod_cuda.cc b/src/target/rt_mod_cuda.cc index 2cc27eb8c8..37db80d6ac 100644 --- a/src/target/rt_mod_cuda.cc +++ b/src/target/rt_mod_cuda.cc @@ -9,6 +9,36 @@ namespace tvm { namespace codegen { +static std::string GetDeviceGlobalSymbol(const GlobalVar &gvar, + const tir::PrimFunc &f) { + if (auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol)) { + return static_cast(global_symbol.value()); + } + return gvar->name_hint; +} + +static void ValidateUniqueDeviceGlobalSymbols(const IRModule &mod) { + std::unordered_map symbol_to_gvar; + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + auto gvar = Downcast(kv.first); + auto f = Downcast(kv.second); + std::string global_symbol = GetDeviceGlobalSymbol(gvar, f); + + auto [it, inserted] = + symbol_to_gvar.emplace(global_symbol, gvar->name_hint); + ICHECK(inserted) + << "Duplicate CUDA kernel global_symbol `" << global_symbol + << "` found on PrimFuncs `" << it->second << "` and `" + << gvar->name_hint + << "`. T.CUDASourceCodeKernel emits raw CUDA source without " + "renaming, so CUDA entry names must be unique within the compiled " + "module."; + } +} + static std::unordered_map ExtractFuncInfo(const IRModule &mod) { std::unordered_map fmap; @@ -56,8 +86,7 @@ ExtractFuncInfo(const IRModule &mod) { } } } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - fmap[static_cast(global_symbol.value())] = info; + fmap[GetDeviceGlobalSymbol(Downcast(kv.first), f)] = info; } return fmap; } @@ -67,6 +96,12 @@ ffi::Module BuildTileLangCUDA(IRModule mod, Target target) { CodeGenTileLangCUDA cg; cg.Init(output_ssa); + ValidateUniqueDeviceGlobalSymbols(mod); + if (const auto f = + ffi::Function::GetGlobal("tilelang_callback_cuda_validate")) { + (*f)(mod); + } + for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenTileLangCUDA: Can only take PrimFunc"; @@ -103,6 +138,12 @@ ffi::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) { CodeGenTileLangCUDA cg; cg.Init(output_ssa); + ValidateUniqueDeviceGlobalSymbols(mod); + if (const auto f = + ffi::Function::GetGlobal("tilelang_callback_cuda_validate")) { + (*f)(mod); + } + for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenTileLangCUDA: Can only take PrimFunc"; diff --git a/src/transform/common/attr.h b/src/transform/common/attr.h index 66b68d7e35..c70cccb02e 100644 --- a/src/transform/common/attr.h +++ b/src/transform/common/attr.h @@ -3,7 +3,11 @@ * \brief Check attributes of the IR */ +#ifndef TVM_TL_TRANSFORM_COMMON_ATTR_H_ +#define TVM_TL_TRANSFORM_COMMON_ATTR_H_ + #include "tvm/tir/stmt.h" +#include namespace tvm { namespace tl { @@ -27,7 +31,23 @@ namespace attr { // Attributes to mark CUDA sync calls constexpr const char *kHasTriggerLaunch = "has_cuda_pdl_trigger"; constexpr const char *kHasGridSync = "has_cuda_pdl_sync"; + +// Attributes to implement SourceCodeBlock +constexpr const char *kCodeBlockSource = "code_block_source"; +constexpr const char *kCodeBlockEntryName = "code_block_entry_name"; + +/*! + * \brief Check if attr_key is a code block key extension + * \param attr_key The attr key to be compared + * \return true if it is a code block key + */ +inline bool IsCodeBlockKey(const std::string &attr_key) { + return attr_key.compare(0, 11, "code_block_") == 0; +} + } // namespace attr } // namespace tl } // namespace tvm + +#endif // TVM_TL_TRANSFORM_COMMON_ATTR_H_ diff --git a/src/transform/lower_device_kernel_launch.cc b/src/transform/lower_device_kernel_launch.cc index b7db51bac3..6bf82493ba 100644 --- a/src/transform/lower_device_kernel_launch.cc +++ b/src/transform/lower_device_kernel_launch.cc @@ -278,9 +278,12 @@ class DeviceKernelMutator : public StmtExprMutator { {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, {tvm::attr::kGlobalSymbol, info.global_symbol}}); } - // @lei: workaround as we may require c host codegen, so we need to set the - // global symbol for cpu backend. - func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); + // Preserve any global_symbol chosen earlier during device splitting. + // Source kernels rely on this to launch the external CUDA entry directly. + if (!func->GetAttr(tvm::attr::kGlobalSymbol)) { + func = + WithAttr(std::move(func), tvm::attr::kGlobalSymbol, gvar->name_hint); + } const auto &info = device_info_map_.at(gvar.get()); const auto &thread_extent = info.thread_extent; diff --git a/src/transform/lower_opaque_block.cc b/src/transform/lower_opaque_block.cc index 34097d015b..a56e370e4a 100644 --- a/src/transform/lower_opaque_block.cc +++ b/src/transform/lower_opaque_block.cc @@ -30,6 +30,7 @@ #include #include "../op/builtin.h" +#include "common/attr.h" #include "tir/transforms/ir_utils.h" namespace tvm { @@ -244,7 +245,7 @@ class OpaqueBlockLower : public StmtExprMutator { pragma_attrs->clear(); for (const auto &kv : annotations) { const String &key = kv.first; - if (tir::attr::IsPragmaKey(key)) { + if (tir::attr::IsPragmaKey(key) || tl::attr::IsCodeBlockKey(key)) { pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second)); } else if (key == tl::attr::kLocalVarInit) { if (auto local_init_map = kv.second.try_cast>()) { diff --git a/src/transform/make_packed_api.cc b/src/transform/make_packed_api.cc index 15db2d2354..be331d9736 100644 --- a/src/transform/make_packed_api.cc +++ b/src/transform/make_packed_api.cc @@ -39,6 +39,7 @@ #include "../op/builtin.h" #include "arg_binder.h" +#include "common/attr.h" #include "merge_if_stmt.h" #include "tir/transforms/ir_utils.h" @@ -204,6 +205,13 @@ Optional RequiresPackedAPI(const PrimFunc &func) { } } + // Source kernels must stay as direct GlobalVar calls until + // LowerDeviceKernelLaunch can turn them into device launches using the + // selected external CUDA entry symbol. + if (func->GetAttr(tl::attr::kCodeBlockSource)) { + return std::nullopt; + } + // Internal function calls do not need the PackedFunc API auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (!global_symbol) { diff --git a/src/transform/split_host_device.cc b/src/transform/split_host_device.cc index b82a3168ee..072ca0e663 100644 --- a/src/transform/split_host_device.cc +++ b/src/transform/split_host_device.cc @@ -33,8 +33,11 @@ #include #include +#include + #include "../op/builtin.h" #include "common/assume.h" +#include "common/attr.h" #include "tir/analysis/var_use_def_analysis.h" #include "tvm/node/cast.h" #include "tvm/runtime/logging.h" @@ -68,6 +71,11 @@ class HostDeviceSplitter : public tir::StmtMutator { cluster_dims_ = std::move(cluster_dims); } + void SetHostFuncSignature(const tir::PrimFunc &func) { + host_params_ = func->params; + host_buffer_map_ = func->buffer_map; + } + tir::Stmt VisitStmt_(const tir::AttrStmtNode *op) final { if (op->attr_key == tvm::attr::kTarget) { found_device_region_ = true; @@ -104,8 +112,105 @@ class HostDeviceSplitter : public tir::StmtMutator { private: bool found_device_region_{false}; + Array host_params_; + Map host_buffer_map_; Array non_restrict_params_; Optional> cluster_dims_{std::nullopt}; + Optional code_block_source_{std::nullopt}; + Optional code_block_entry_name_{std::nullopt}; + + static void SortDeviceParams(std::vector *params) { + std::sort(params->begin(), params->end(), + [](const tir::Var &a, const tir::Var &b) { + auto sort_key = [](const tir::Var &var) { + return std::tuple{ + !var->dtype.is_handle(), + var->name_hint, + }; + }; + return sort_key(a) < sort_key(b); + }); + } + + std::tuple, Array> + CollectSourceKernelSignature() const { + std::vector params; + std::unordered_set seen_vars; + + auto push = [&](const tir::Var &var) { + if (var.defined() && seen_vars.insert(var->name_hint).second) { + params.push_back(var); + } + }; + + Array buffers_to_declare; + for (const auto &kv : host_buffer_map_) { + const tir::Buffer &buf = kv.second; + push(buf->data); + buffers_to_declare.push_back(buf); + for (const PrimExpr &dim : buf->shape) { + if (const auto *var = dim.as()) { + push(GetRef(var)); + } + } + for (const PrimExpr &stride : buf->strides) { + if (const auto *var = stride.as()) { + push(GetRef(var)); + } + } + if (const auto *var = buf->elem_offset.as()) { + push(GetRef(var)); + } + } + + SortDeviceParams(¶ms); + return {Array(params.begin(), params.end()), buffers_to_declare}; + } + + class SourceKernelAttrExtractor : public tir::StmtMutator { + public: + static Stmt Extract(Stmt body, Optional *code_block_source, + Optional *code_block_entry_name) { + SourceKernelAttrExtractor extractor(code_block_source, + code_block_entry_name); + return extractor(std::move(body)); + } + + private: + explicit SourceKernelAttrExtractor(Optional *code_block_source, + Optional *code_block_entry_name) + : code_block_source_(code_block_source), + code_block_entry_name_(code_block_entry_name) {} + + Stmt VisitStmt_(const tir::AttrStmtNode *op) final { + if (op->attr_key == tl::attr::kCodeBlockSource) { + if (auto str = op->value.as()) { + *code_block_source_ = str->value; + } else { + LOG(FATAL) << "Expected `" << tl::attr::kCodeBlockSource + << "` AttrStmt to carry a StringImm value, but got " + << op->value->GetTypeKey(); + } + return VisitStmt(op->body); + } + + if (op->attr_key == tl::attr::kCodeBlockEntryName) { + if (auto str = op->value.as()) { + *code_block_entry_name_ = str->value; + } else { + LOG(FATAL) << "Expected `" << tl::attr::kCodeBlockEntryName + << "` AttrStmt to carry a StringImm value, but got " + << op->value->GetTypeKey(); + } + return VisitStmt(op->body); + } + + return tir::StmtMutator::VisitStmt_(op); + } + + Optional *code_block_source_; + Optional *code_block_entry_name_; + }; // Wrap body with assumes, substituting variables in assumes with the // corresponding variables in the device body based on name_hint matching. @@ -140,27 +245,30 @@ class HostDeviceSplitter : public tir::StmtMutator { } tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) { - // First, analyze undefined variables in the device body + code_block_source_ = std::nullopt; + code_block_entry_name_ = std::nullopt; + body = SourceKernelAttrExtractor::Extract( + std::move(body), &code_block_source_, &code_block_entry_name_); + + // Normal kernels infer device parameters from use-def of the device body. + // Source kernels have no meaningful DSL body, so their device signature + // must be reconstructed explicitly from the host PrimFunc signature and + // buffer metadata. auto [old_params, buffers_to_declare] = [&]() -> std::tuple, Array> { + if (code_block_source_) { + return CollectSourceKernelSignature(); + } + tir::VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/true); use_def(body); - // Sort first by variable type, then by variable name std::vector params{use_def.undefined_.begin(), use_def.undefined_.end()}; - std::sort(params.begin(), params.end(), - [](const tir::Var &a, const tir::Var &b) { - auto sort_key = [](const tir::Var &var) { - return std::tuple{ - !var->dtype.is_handle(), - var->name_hint, - }; - }; - return sort_key(a) < sort_key(b); - }); - return {params, use_def.undefined_buffers_}; + SortDeviceParams(¶ms); + return {Array(params.begin(), params.end()), + use_def.undefined_buffers_}; }(); // Create new parameter variables for the device function to avoid sharing @@ -246,9 +354,16 @@ class HostDeviceSplitter : public tir::StmtMutator { if (cluster_dims_.defined()) { device_attrs.Set("cluster_dims", cluster_dims_.value()); } + if (code_block_source_) { + device_attrs.Set(tl::attr::kCodeBlockSource, code_block_source_.value()); + } device_func = WithAttrs(std::move(device_func), device_attrs); GlobalVar kernel_symbol_global = var_supply_(); + if (code_block_entry_name_) { + kernel_symbol_global = GlobalVar(code_block_entry_name_.value()); + } + (*device_mod_)->Add(kernel_symbol_global, device_func); // Use old_params as call arguments (host-side variables) Array args = @@ -281,6 +396,7 @@ class HostDeviceSplitter : public tir::StmtMutator { tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule *device_mod, std::function var_supply) { HostDeviceSplitter splitter(device_mod, std::move(var_supply)); + splitter.SetHostFuncSignature(func); // Propagate non-restrict parameter list from host func to device kernels if (auto opt = func->GetAttr>(tl::attr::kNonRestrictParams)) { splitter.SetNonRestrictParams(opt.value()); diff --git a/testing/python/language/test_tilelang_language_source_kernel.py b/testing/python/language/test_tilelang_language_source_kernel.py new file mode 100644 index 0000000000..782ff5d205 --- /dev/null +++ b/testing/python/language/test_tilelang_language_source_kernel.py @@ -0,0 +1,87 @@ +import os +import re +import tempfile +from pathlib import Path + +import pytest +import tilelang +import tilelang.language as T +import tilelang.testing + +import torch + + +CUDA_SOURCE = """ +extern "C" __global__ void external_copy(float* A, float* B, int n) { + int i = (int)(blockIdx.x * blockDim.x + threadIdx.x); + if (i < n) { + B[i] = A[i]; + } +} +""" + + +def make_source_kernel(source_code_or_path: str | os.PathLike[str], entry_name: str): + N = T.dynamic("N") + + @T.prim_func + def main( + A: T.Tensor((N,), T.float32), + B: T.Tensor((N,), T.float32), + ): + T.CUDASourceCodeKernel(T.ceildiv(N, 128), threads=128, source_code_or_path=source_code_or_path, entry_name=entry_name) + + return main + + +def get_single_device_function_name(device_mod) -> str: + function_names = [g_var.name_hint for g_var in device_mod.functions] + assert len(function_names) == 1 + return function_names[0] + + +@tilelang.testing.requires_cuda +def test_source_kernel_inline_codegen(): + artifact = tilelang.lower(make_source_kernel(CUDA_SOURCE, entry_name="external_copy"), target="cuda") + function_name = get_single_device_function_name(artifact.device_mod) + + assert re.search( + rf"__global__\s+void\s+(?:__launch_bounds__\([^\)]*\)\s+)?{re.escape(function_name)}\s*\(", + artifact.kernel_source, + ) + assert "B[i] = A[i];" in artifact.kernel_source + + +@tilelang.testing.requires_cuda +def test_source_kernel_run(): + kernel = tilelang.compile(make_source_kernel(CUDA_SOURCE, entry_name="external_copy"), target="cuda") + print(kernel.get_kernel_source()) + print(kernel.get_host_source()) + a = torch.randn(128, dtype=torch.float32, device="cuda") + b = torch.empty_like(a) + kernel(a, b) + torch.testing.assert_close(b, a) + + +@tilelang.testing.requires_cuda +def test_source_kernel_loads_from_file(): + with tempfile.NamedTemporaryFile("w", suffix=".cu", delete=False, encoding="utf-8") as f: + f.write(CUDA_SOURCE) + source_path = f.name + + try: + artifact = tilelang.lower(make_source_kernel(Path(source_path), entry_name="external_copy"), target="cuda") + finally: + os.unlink(source_path) + + assert "B[i] = A[i];" in artifact.kernel_source + + +@tilelang.testing.requires_cuda +def test_source_kernel_invalid_entry_name_fails_in_lower(): + with pytest.raises(Exception, match=r"Available entries: external_copy"): + tilelang.lower(make_source_kernel(CUDA_SOURCE, entry_name="main_kernel"), target="cuda") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 40a757c7a2..d7b456386b 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from typing import Callable import tilelang.transform from tilelang import tvm as tvm @@ -55,6 +56,50 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: return lambda func: not get_device_call(is_device_c)(func) +_CUDA_GLOBAL_KERNEL_PATTERN = re.compile(r'(?:extern\s+"C"\s+)?__global__\s+void\s+(?:__launch_bounds__\([^\)]*\)\s+)?(\w+)') + + +def _collect_external_cuda_kernel_names(source: str) -> list[str]: + kernel_names: list[str] = [] + seen_names: set[str] = set() + for match in _CUDA_GLOBAL_KERNEL_PATTERN.finditer(source): + kernel_name = match.group(1) + if kernel_name not in seen_names: + kernel_names.append(kernel_name) + seen_names.add(kernel_name) + return kernel_names + + +@tvm_ffi.register_global_func("tilelang_callback_cuda_validate", override=True) +def tilelang_callback_cuda_validate(device_mod): + for _, base_func in device_mod.functions.items(): + if not isinstance(base_func, tir.PrimFunc) or not base_func.attrs: + continue + + code_block_source = base_func.attrs.get("code_block_source") + if code_block_source is None: + continue + + global_symbol = base_func.attrs.get("global_symbol") + if global_symbol is None: + raise ValueError("CodeGenTileLangCUDA expects source-kernel PrimFunc to have the global_symbol attribute") + + expected_name = str(global_symbol) + code_block_entry_name = base_func.attrs.get("code_block_entry_name") + if code_block_entry_name is not None and str(code_block_entry_name) != expected_name: + raise ValueError("T.CUDASourceCodeKernel expects the lowered device global_symbol to match entry_name") + + kernel_names = _collect_external_cuda_kernel_names(str(code_block_source)) + if not kernel_names: + raise ValueError("T.CUDASourceCodeKernel expects external CUDA source to declare at least one __global__ kernel") + if expected_name not in kernel_names: + raise ValueError( + "T.CUDASourceCodeKernel expected device global_symbol " + f"`{expected_name}` to match a __global__ kernel in the provided CUDA source. " + f"Available entries: {', '.join(kernel_names)}" + ) + + @tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True) def tilelang_callback_cuda_compile(code, target, pass_config=None): target_arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target)) @@ -265,12 +310,12 @@ def lower( host_mod = tir.transform.Filter(_is_host_call)(mod) device_mod = tir.transform.Filter(_is_device_call)(mod) - codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target) + kernel_source = codegen_mod.inspect_source() if enable_host_codegen: host_mod = host_codegen(host_mod, target_host, target=target) host_mod.import_module(codegen_mod) - return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod) + return CompiledArtifact(host_mod, device_mod, params, kernel_source, rt_mod=host_mod) - return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source()) + return CompiledArtifact(host_mod, device_mod, params, kernel_source) diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 323ee7e633..785b91c489 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -30,6 +30,7 @@ from .math_intrinsics import * # noqa: F401 from .kernel import ( Kernel, # noqa: F401 + CUDASourceCodeKernel, # noqa: F401 KernelLaunchFrame, # noqa: F401 get_thread_binding, # noqa: F401 get_thread_bindings, # noqa: F401 diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py index 9bc49dd49e..a1780a5ae1 100644 --- a/tilelang/language/kernel.py +++ b/tilelang/language/kernel.py @@ -2,8 +2,10 @@ from __future__ import annotations from collections import deque +import os from tvm import tir from tvm.tir import Var +from tvm.script.ir_builder.tir import evaluate as T_evaluate from tvm.script.ir_builder.tir.frame import TIRFrame, BlockFrame from tvm.ffi import register_object from tilelang import _ffi_api @@ -92,6 +94,41 @@ def _normalize_bindings(bindings: list[Var]) -> Var | list[Var]: return bindings +def _normalize_threads( + threads: int | list[int] | tuple | None, + *, + is_cpu: bool, +) -> list[int] | None: + if not is_cpu and threads is None: + threads = 128 # default thread number + + if isinstance(threads, int): + return [threads, 1, 1] + if isinstance(threads, list): + return threads + [1] * (3 - len(threads)) + if isinstance(threads, tuple): + return list(threads) + [1] * (3 - len(threads)) + + assert is_cpu, "threads must be an integer or a list of integers" + return None + + +def _normalize_cluster_dims( + cluster_dims: int | tuple[int, int, int] | list[int] | None, +) -> list[int] | None: + if cluster_dims is None: + return None + + if isinstance(cluster_dims, (list, tuple)): + cluster_dims = list(cluster_dims) + [1] * (3 - len(cluster_dims)) + elif isinstance(cluster_dims, int): + cluster_dims = [cluster_dims, 1, 1] + else: + raise ValueError("cluster_dims must be a list or tuple of integers") + + return None if cluster_dims == [1, 1, 1] else cluster_dims + + @register_object("tl.KernelLaunchFrame") class KernelLaunchFrame(TIRFrame): """ @@ -248,10 +285,6 @@ def Kernel( For example, use 2 or (2, 1, 1) to create 2-CTA clusters. When specified, the kernel will be launched using cudaLaunchKernelEx with cudaLaunchAttributeClusterDimension. - is_cpu : bool - Whether the kernel is running on CPU. - Thus we will not bind threadIdx.x, threadIdx.y, threadIdx.z. - and blockIdx.x, blockIdx.y, blockIdx.z. prelude : str The import c code of the kernel, will be injected before the generated kernel code. @@ -296,18 +329,7 @@ def Kernel( raise JITNoBuilderError("T.Kernel() can only be used inside @tilelang.jit or @T.prim_func context. No Builder is available.") attrs: dict = {} - - if not is_cpu and threads is None: - threads = 128 # default thread number - - if isinstance(threads, int): - threads = [threads, 1, 1] - elif isinstance(threads, list): - threads = threads + [1] * (3 - len(threads)) - elif isinstance(threads, tuple): - threads = list(threads) + [1] * (3 - len(threads)) - else: - assert is_cpu, "threads must be an integer or a list of integers" + threads = _normalize_threads(threads, is_cpu=is_cpu) if is_cpu: attrs["tilelang.is_cpu_kernel_frame"] = True @@ -315,20 +337,104 @@ def Kernel( if prelude is not None: attrs["pragma_import_c"] = prelude + cluster_dims = _normalize_cluster_dims(cluster_dims) if cluster_dims is not None: - if isinstance(cluster_dims, (list, tuple)): - cluster_dims = list(cluster_dims) + [1] * (3 - len(cluster_dims)) - elif isinstance(cluster_dims, int): - cluster_dims = [cluster_dims, 1, 1] - else: - raise ValueError("cluster_dims must be a list or tuple of integers") - - if cluster_dims != [1, 1, 1]: - attrs["cluster_dims"] = cluster_dims + attrs["cluster_dims"] = cluster_dims return _ffi_api.KernelLaunch(blocks, threads, attrs) +# For CUDA source kernels, we need to load the source code from a file or string. + + +def _load_cuda_source(source_code_or_path: str | os.PathLike[str]) -> str: + source = os.fspath(source_code_or_path) + if not isinstance(source, str) or not source.strip(): + raise ValueError("source_code_or_path must be a non-empty source string or source path") + + expanded = os.path.expanduser(source) + if os.path.isfile(expanded): + with open(expanded, encoding="utf-8") as f: + return f.read() + + source_markers = ("\n", "__global__", 'extern "C"', "#include") + if any(marker in source for marker in source_markers): + return source + + contains_path_sep = os.path.sep in source or (os.path.altsep is not None and os.path.altsep in source) + if contains_path_sep or source.endswith((".cu", ".cuh", ".cuda", ".cpp", ".cc", ".c")): + raise FileNotFoundError(f"CUDA source file not found: {source}") + + return source + + +def CUDASourceCodeKernel( + *blocks: int | tir.PrimExpr, + threads: int | list[int] | tuple | None = None, + source_code_or_path: str | os.PathLike[str], + entry_name: str = "main_kernel", + cluster_dims: int | tuple[int, int, int] | list[int] | None = None, + prelude: str | None = None, +) -> None: + """Launch a kernel from CUDA source code or a CUDA source file. + + The code must follows the following rules: + 1. The kernel source must be a valid CUDA kernel which can be correctly compiled under TileLang's context. + 2. The kernel source must either contains only one `__global__` function as an entry, or have a `__global__` entry function named `main_kernel`. + + Parameters + ---------- + source_code_or_path : str | os.PathLike[str] + Inline CUDA source code, or a path to a CUDA source file. + If the argument resolves to an existing file, the file contents are + loaded. Otherwise it is treated as inline CUDA source code. + blocks : int + A list of extent, can be 1-3 dimension, representing gridDim.(x|y|z) + entry_name : str | None + Optional name of the `__global__` CUDA entry function inside the + provided source. When specified, TileLang launches that external CUDA + entry directly. + threads : int + A integer representing blockDim.x + Or a list of integers representing blockDim.(x|y|z) + if the value is -1, we skip the threadIdx.x binding. + cluster_dims : int | tuple[int, int, int] | list[int] | None + The cluster dimensions for SM90+ cluster launch. + For example, use 2 or (2, 1, 1) to create 2-CTA clusters. + When specified, the kernel will be launched using cudaLaunchKernelEx + with cudaLaunchAttributeClusterDimension. + prelude : str + The import c code of the kernel, + will be injected before the generated kernel code. + """ + from tilelang.language.eager.builder import Builder + + if Builder.current() is None: + raise JITNoBuilderError( + "T.CUDASourceCodeKernel() can only be used inside @tilelang.jit or @T.prim_func context. No Builder is available." + ) + + source = _load_cuda_source(source_code_or_path) + if prelude is not None: + source = prelude + "\n" + source + + attrs: dict = {"code_block_source": source} + if not isinstance(entry_name, str) or not entry_name.strip(): + raise ValueError("entry_name must be a non-empty string when provided") + attrs["code_block_entry_name"] = entry_name + + threads = _normalize_threads(threads, is_cpu=False) + + cluster_dims = _normalize_cluster_dims(cluster_dims) + if cluster_dims is not None: + attrs["cluster_dims"] = cluster_dims + + with _ffi_api.KernelLaunch(blocks, threads, attrs): + # Keep the launch frame alive until SplitHostDevice can lift the + # external CUDA source pragma onto the device PrimFunc. + T_evaluate(tir.call_extern("int32", entry_name)) + + def get_thread_binding(dim: int = 0) -> Var: """Returns the thread binding for the given dimension.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" From aa0700bbb90e39e09cc2e592b71b2fa7790c06c5 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 17 Apr 2026 16:55:01 +0800 Subject: [PATCH 080/156] run format --- src/transform/merge_shared_memory_allocations.cc | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc index 511f0400dd..80b0eb3ac1 100644 --- a/src/transform/merge_shared_memory_allocations.cc +++ b/src/transform/merge_shared_memory_allocations.cc @@ -311,8 +311,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { if (const auto *seq = body.as()) { for (const auto &sub_stmt : seq->seq) { if (const auto *attr = sub_stmt.as(); - attr && - attr->attr_key == attr::kAutoScheduleSharedMemoryBoundary) { + attr && attr->attr_key == attr::kAutoScheduleSharedMemoryBoundary) { this->VisitStmt_(attr); } else { StmtExprVisitor::VisitStmt(sub_stmt); @@ -346,9 +345,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { VisitWarpSpecializationBody(op->body); } else if (op->attr_key == "kAutoScheduleSharedMemoryBoundary") { if (in_boundary_scope_) { - CloseBoundaryScope( - static_cast( - linear_seq_[boundary_scope_begin_index_].stmt)); + CloseBoundaryScope(static_cast( + linear_seq_[boundary_scope_begin_index_].stmt)); } OpenBoundaryScope(op); VisitBoundaryBody(op->body); From 72d7748d5ddcc3b8ce4017ef2aa23298daa2677c Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 17 Apr 2026 16:55:26 +0800 Subject: [PATCH 081/156] fix II of IfNode --- src/transform/auto_schedule/schedule_builder.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index dde0c242f1..eb559a126b 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -704,9 +704,12 @@ void ScheduleUnitBuilder::ScheduleRecursive( if (if_node->else_child) { ScheduleRecursive(if_node->else_child, used_buffers); } - if_node->SetLatency(std::max(if_node->then_child ? if_node->then_child->GetLatency() : 0, - if_node->else_child ? if_node->else_child->GetLatency() : 0)); - if_node->SetII(if_node->GetLatency()); + if_node->SetLatency( + std::max(if_node->then_child ? if_node->then_child->GetLatency() : 0, + if_node->else_child ? if_node->else_child->GetLatency() : 0)); + if_node->SetII( + std::max(if_node->then_child ? if_node->then_child->GetII() : 0, + if_node->else_child ? if_node->else_child->GetII() : 0)); return; } From 4bf8678be289821feb1e84ce638f95c7053c7472 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 17 Apr 2026 16:56:11 +0800 Subject: [PATCH 082/156] fix barrier --- src/transform/auto_schedule/barrier.h | 269 +++++++++++++------------- 1 file changed, 133 insertions(+), 136 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 08575119f9..99f0371c94 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -584,6 +584,16 @@ AnalyzeAndInsertBarriers(IRStructure *node, int &next_barrier_id, } } +static TaskNode *GetInnerTask(ScheduleUnit *unit) { + std::vector task_contexts; + CollectAllTaskNodesWithContext(unit, task_contexts); + if (task_contexts.size() == 1 && task_contexts[0].control_node == nullptr) { + return task_contexts[0].task; + } else { + return nullptr; + } +} + static auto GetSyncInfos(const std::vector &units, int num_wgs, const std::unordered_mapstage - waiting_unit->stage; - int num_mma = wgmma_id[wg_id][waiting_unit] - wgmma_id[wg_id][unit]; - num_mma += real_distance * wgmma_count[wg_id]; - if (unit->isInnerTask()) { - --num_mma; - } + // + // int real_distance = distance + unit->stage - waiting_unit->stage; + // int num_mma = wgmma_id[wg_id][waiting_unit] - + // wgmma_id[wg_id][unit]; num_mma += real_distance * + // wgmma_count[wg_id]; if (unit->isInnerTask()) { + // --num_mma; + // } + // // Fallback to set num_mma to 0 to avoid error. - num_mma = 0; + int num_mma = 0; Stmt wait_stmt = Evaluate(Call(DataType::Handle(), wait_wgmma(), {num_mma})); InsertStatementIntoScheduleUnit(waiting_unit, wait_stmt, true, @@ -762,8 +774,7 @@ static void InsertSynchronization( Buffer barrier_buffer; // Handle single special task, such as TCGEN05 or TMA load, that requires // a barrier for itself. - if (unit->isInnerTask()) { - auto task = static_cast(unit->child.get()); + if (auto task = GetInnerTask(unit)) { int task_wg_id = task->GetWarpgroupId(); if (task->is_TCGEN05() && task_wg_id == wg_id) { int barrier_id = next_barrier_id++; @@ -774,6 +785,9 @@ static void InsertSynchronization( indexmod(loop_info.CalculateIterationCount(), barrier_versions); PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); RewriteGemmMbar(task, mbar_expr); + // Stmt arrive_stmt = + // makeTcgen05MmaArrive(barrier_buffer, version_index); + // InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); } if (task->HasTMALoad() && task_wg_id == wg_id) { int barrier_id = next_barrier_id++; @@ -794,10 +808,12 @@ static void InsertSynchronization( return true; if (!is_async) return false; - if (unit->isInnerTask() && waiting_unit->isInnerTask() && - static_cast(unit->child.get())->is_TCGEN05() && - static_cast(waiting_unit->child.get())->is_TCGEN05()) { - return false; + if (auto task = GetInnerTask(unit)) { + if (auto waiting_task = GetInnerTask(waiting_unit)) { + if (task->is_TCGEN05() && waiting_task->is_TCGEN05()) { + return false; + } + } } return true; }; @@ -812,10 +828,9 @@ static void InsertSynchronization( if (!need_barrier) continue; if (!barrier_buffer.defined()) { - // Note: the logic here assumes that if there are TCGEN05 tasks, then - // all tasks are finished when all TCGEN05 tasks are finished. So we can - // use the TCGEN05 barrier for all tasks. If this assumption does not - // hold, we may need to implement a more complex logic to synchronize. + // Note: the logic here assumes that we DO NOT need to wait for TMA + // loads in this unit. If this assumption does not hold, we may need to + // implement a more complex logic to synchronize. if (unit->HasTCGEN05()) { int barrier_id = next_barrier_id++; barrier_buffer = makeBarrierBuffer( @@ -867,36 +882,40 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, if (!seq) return; - // Collect all tasks from the sequence - std::vector tasks; + // Collect all units from the sequence + std::vector units; for (auto &child : seq->children) { - auto task = static_cast(child.get()); - if (task->child->IsSequence() || task->child->IsControl() || - task->child->IsIf()) { + auto unit = static_cast(child.get()); + units.push_back(unit); + if (GetInnerTask(unit) != nullptr) { + // We will handle these units specially in InsertSynchronization, so we + // skip it here. + continue; + } + if (unit->child->IsSequence() || unit->child->IsControl() || + unit->child->IsIf()) { // If child is SequenceNode, ControlNode, or IfNode, recursively analyze // it AnalyzeAndInsertBarriers( - task->child.get(), next_barrier_id, barrier_buffers, barrier_map, + unit->child.get(), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); } - tasks.push_back(task); } - // Rewrite TMA load tasks to use tma_copy and neutral_sync_shared_barrier - for (auto task : tasks) { - if (task->isInnerTask() && task->UsesTMACore()) { - auto child = static_cast(task->child.get()); - if (child->HasTMALoad() && - child->GetSchedulePhase() == SchedulePhase::kPrologue) { + // Rewrite TMA load units to use tma_copy and neutral_sync_shared_barrier + for (auto unit : units) { + if (auto task = GetInnerTask(unit)) { + if (task->HasTMALoad() && + task->GetSchedulePhase() == SchedulePhase::kPrologue) { PrimExpr barrier_load = BufferLoad(neutral_sync_shared_barrier, {0}); - RewriteCopyMbar(child, barrier_load); + RewriteCopyMbar(task, barrier_load); } } } - // Insert synchronization - auto sync_infos = GetSyncInfos(tasks, thread_count.size()); - InsertSynchronization(tasks, sync_infos, next_barrier_id, barrier_buffers, + // Analyze dependencies and insert synchronization + auto sync_infos = GetSyncInfos(units, thread_count.size()); + InsertSynchronization(units, sync_infos, next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info); } @@ -916,120 +935,98 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, if (!for_node) return; - PrimExpr loop_var = for_node->loop_var; - PrimExpr loop_start = for_node->min; - PrimExpr loop_step = for_node->step.has_value() - ? for_node->step.value() - : IntImm(DataType::Int(32), 1); - PrimExpr loop_extent = for_node->extent; - bool has_promoted_tasks = ctrl->hasPromote(); - // Add this loop to nesting info loop_info.AddLoop(for_node); - // Check if inner loops have constant extents (if any) - // This check will be done when calculating parity expression - - // If child is a SequenceNode, we need special handling for - // promote/non-promote tasks - if (ctrl->child->IsSequence()) { - auto seq = static_cast(ctrl->child.get()); - - // Collect all tasks from the sequence - std::vector tasks; - for (auto &child : seq->children) { - auto task = static_cast(child.get()); - if (task->child->IsSequence() || task->child->IsControl() || - task->child->IsIf()) { - // If child is SequenceNode, ControlNode, or IfNode, recursively analyze - // it - AnalyzeAndInsertBarriers( - task->child.get(), next_barrier_id, barrier_buffers, barrier_map, - thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); - } - tasks.push_back(task); + // Collect all units from the sequence + ICHECK(ctrl->child->IsSequence()); + auto seq = static_cast(ctrl->child.get()); + std::vector units; + for (auto &child : seq->children) { + auto unit = static_cast(child.get()); + units.push_back(unit); + if (GetInnerTask(unit) != nullptr) { + // We will handle these units specially in InsertSynchronization, so we + // skip it here. + continue; } - - // Process in order: sort by stage - // This matches the software pipelining order - auto ordered_tasks = tasks; - std::stable_sort( - ordered_tasks.begin(), ordered_tasks.end(), - [](ScheduleUnit *a, ScheduleUnit *b) { return a->stage > b->stage; }); - - // Rewrite multi-buffer - auto num_stages = 1; - auto num_stages_val = ctrl->control.get()->annotations.Get("num_stages"); - if (num_stages_val.has_value()) { - num_stages = num_stages_val.value().cast()->value; + if (unit->child->IsSequence() || unit->child->IsControl() || + unit->child->IsIf()) { + // If child is SequenceNode, ControlNode, or IfNode, recursively analyze + // it + AnalyzeAndInsertBarriers( + unit->child.get(), next_barrier_id, barrier_buffers, barrier_map, + thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); } - std::unordered_map - multi_buffer; - std::unordered_map - buffer_num_versions; - if (num_stages != 1) { - for (const auto &task : ordered_tasks) { - for (const auto ®ion_access : task->GetReadWriteRegions()) { - auto &buffer = region_access.region->buffer; - if (!ctrl->multi_buffering_buffers.count(buffer)) + } + + // Sort units by stage + // This matches the software pipelining order + auto ordered_units = units; + std::stable_sort( + ordered_units.begin(), ordered_units.end(), + [](ScheduleUnit *a, ScheduleUnit *b) { return a->stage > b->stage; }); + + // Detect multi-version buffers and create new buffers for them + std::unordered_map + multi_buffer; + std::unordered_map + buffer_num_versions; + for (const auto &unit : ordered_units) { + for (const auto ®ion_access : unit->GetReadWriteRegions()) { + auto &buffer = region_access.region->buffer; + if (!ctrl->multi_buffering_buffers.count(buffer)) + continue; + for (const auto &other_unit : ordered_units) { + if (unit == other_unit) + continue; + int distance = unit->child->GetStartTime() + unit->child->GetLatency() - + other_unit->child->GetStartTime(); + if (distance <= 0) + continue; + distance = (distance - 1) / ctrl->GetIIperIter() + 1; + for (const auto &other_region_access : + other_unit->GetReadWriteRegions()) { + auto &other_buffer = other_region_access.region->buffer; + if (!buffer.same_as(other_buffer)) continue; - for (const auto &other_task : ordered_tasks) { - if (task == other_task) - continue; - int distance = task->child->GetStartTime() + - task->child->GetLatency() - - other_task->child->GetStartTime(); - if (distance <= 0) - continue; - distance = (distance - 1) / ctrl->GetIIperIter() + 1; - for (const auto &other_region_access : - other_task->GetReadWriteRegions()) { - auto &other_buffer = other_region_access.region->buffer; - if (!buffer.same_as(other_buffer)) - continue; - if (region_access.is_write || other_region_access.is_write) { - auto &num_versions = buffer_num_versions[buffer]; - num_versions = std::max(num_versions, distance); - } - } + if (region_access.is_write || other_region_access.is_write) { + auto &num_versions = buffer_num_versions[buffer]; + num_versions = std::max(num_versions, distance); } } } - for (auto ®ion : ctrl->GetWriteRegions()) { - auto &buffer = region.get()->buffer; - if (!ctrl->multi_buffering_buffers.count(buffer)) - continue; - if (multi_buffer.find(buffer) != multi_buffer.end()) - continue; - auto it = buffer_num_versions.find(buffer); - if (it == buffer_num_versions.end()) - continue; - int num_versions = it->second; - if (num_versions == 1) - continue; - auto new_buffer = RewriteAllocBuffer(buffer, num_versions); - multi_buffer[buffer] = new_buffer; - buffer_infos.emplace_back(buffer, num_versions, new_buffer); - } + } + } + for (auto ®ion : ctrl->GetWriteRegions()) { + auto &buffer = region.get()->buffer; + if (!ctrl->multi_buffering_buffers.count(buffer)) + continue; + if (multi_buffer.find(buffer) != multi_buffer.end()) + continue; + auto it = buffer_num_versions.find(buffer); + if (it == buffer_num_versions.end()) + continue; + int num_versions = it->second; + if (num_versions == 1) + continue; + auto new_buffer = RewriteAllocBuffer(buffer, num_versions); + multi_buffer[buffer] = new_buffer; + buffer_infos.emplace_back(buffer, num_versions, new_buffer); + } - // Rewrite BufferLoad/BufferStore in TaskNode stmts for multi-version - // buffers - PrimExpr iteration = loop_info.CalculateIterationCount(); + // Rewrite BufferLoad/BufferStore in TaskNode stmts for multi-version + // buffers + PrimExpr iteration = loop_info.CalculateIterationCount(); - // Recursively rewrite all TaskNode stmts - RewriteTaskNodeBuffers(ctrl, multi_buffer, iteration); - } + // Recursively rewrite all TaskNode stmts + RewriteTaskNodeBuffers(ctrl, multi_buffer, iteration); - // Insert synchronization - auto sync_infos = GetSyncInfos(ordered_tasks, thread_count.size(), - buffer_num_versions, true); - InsertSynchronization(tasks, sync_infos, next_barrier_id, barrier_buffers, - barrier_map, thread_count, loop_info); - } else { - AnalyzeAndInsertBarriers( - ctrl->child.get(), next_barrier_id, barrier_buffers, barrier_map, - thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); - } + // Analyze dependencies and insert synchronization + auto sync_infos = GetSyncInfos(ordered_units, thread_count.size(), + buffer_num_versions, true); + InsertSynchronization(units, sync_infos, next_barrier_id, barrier_buffers, + barrier_map, thread_count, loop_info); // Remove this loop from nesting info when exiting loop_info.PopLoop(); From 9c5fe44c7ab2b556f99aae324d74048235dd17bc Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Fri, 17 Apr 2026 17:22:03 +0800 Subject: [PATCH 083/156] [BugFix] Keep shared-prelude local vars in producer-consumer WS (#2055) * [BugFix] Keep shared-prelude local vars in producer-consumer WS Preserve prelude definitions needed by shared LetStmt values during warp-specialized pipeline rewriting, and avoid cloning local.var buffers into the producer branch. This keeps grouped GEMM batch index calculations shared and fixes incorrect TMA batch selection. * add test * lint --- .../grouped_gemm/test_example_grouped_gemm.py | 60 ++++++++ src/transform/producer_consumer_ws.cc | 50 +++++- ...tilelang_transform_producer_consumer_ws.py | 145 ++++++++++++++++++ 3 files changed, 251 insertions(+), 4 deletions(-) create mode 100644 examples/grouped_gemm/test_example_grouped_gemm.py diff --git a/examples/grouped_gemm/test_example_grouped_gemm.py b/examples/grouped_gemm/test_example_grouped_gemm.py new file mode 100644 index 0000000000..0cf03462a8 --- /dev/null +++ b/examples/grouped_gemm/test_example_grouped_gemm.py @@ -0,0 +1,60 @@ +import tilelang.testing + +import example_grouped_gemm_bwd +import example_grouped_gemm_fwd +import example_grouped_gemm_fwd_ptr + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_example_grouped_gemm_fwd_small(): + example_grouped_gemm_fwd.run_tilelang_grouped_gemm( + [5, 9, 13], + K=64, + M=96, + block_M=64, + block_N=64, + block_K=32, + trans_b=False, + num_stages=2, + threads=256, + profile=False, + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_example_grouped_gemm_fwd_ptr_small(): + example_grouped_gemm_fwd_ptr.run_tilelang_grouped_gemm_ptr( + [5, 9, 13], + K=64, + N=96, + block_M=64, + block_N=64, + block_K=32, + num_stages=1, + threads=256, + backend="tvm_ffi", + profile=False, + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_example_grouped_gemm_bwd_small(): + example_grouped_gemm_bwd.run_tilelang_grouped_gemm( + [5, 9, 13], + K=64, + M=96, + block_M=64, + block_N=64, + block_K=32, + trans_b=False, + num_stages=2, + threads=256, + profile=False, + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/src/transform/producer_consumer_ws.cc b/src/transform/producer_consumer_ws.cc index 31c02d3ae5..cb69bfa1c7 100644 --- a/src/transform/producer_consumer_ws.cc +++ b/src/transform/producer_consumer_ws.cc @@ -527,6 +527,7 @@ enum class PreludeStmtPlacement : uint8_t { static PreludeStmtPlacement ClassifyPreludeStmt(const Stmt &stmt, const BufferDataToBufferMap &buffer_map, + const LocalLiveSet &shared_live_seed, const LocalLiveSet &producer_live_seed, const LocalLiveSet &consumer_live_seed) { LocalAccessSummary summary = LocalAccessCollector::Collect(stmt, buffer_map); @@ -534,6 +535,10 @@ ClassifyPreludeStmt(const Stmt &stmt, const BufferDataToBufferMap &buffer_map, return PreludeStmtPlacement::kKeepSharedPrelude; } + if (shared_live_seed.NeedsAnyDef(summary)) { + return PreludeStmtPlacement::kKeepSharedPrelude; + } + bool producer_needs = producer_live_seed.NeedsAnyDef(summary); bool consumer_needs = consumer_live_seed.NeedsAnyDef(summary); if (producer_needs && consumer_needs) { @@ -1687,6 +1692,7 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { // Consumer: threadIdx.x stays, but extent is consumer_extent Stmt rewritten_consumer = final_consumer_loop; + shared_prelude_live_seed_ = {}; producer_prelude_live_seed_ = {}; consumer_prelude_live_seed_ = {}; producer_prelude_live_seed_.AddUses(LocalAccessCollector::Collect( @@ -1735,7 +1741,7 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { } auto maybe_clone = [&](const Buffer &buffer) { if (!buffer.defined() || - !(IsFragmentBuffer(buffer) || IsLocalBuffer(buffer, true)) || + !(IsFragmentBuffer(buffer) || IsLocalBuffer(buffer)) || !block_alloc_buffers.count(buffer.get()) || producer_buffer_remap.count(buffer.get())) { return; @@ -2070,14 +2076,37 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { if (loop_idx < 0) { return {stmt, false}; } + // Propagate liveness backwards through prelude statements so that + // transitive dependencies are captured. For example, if consumer + // needs `m_start` and `m_start` is defined by a prelude statement + // that reads `cur_batch_idx`, the loop defining `cur_batch_idx` + // must also be visible to the consumer. + { + LocalLiveSet producer_live = producer_prelude_live_seed_; + LocalLiveSet consumer_live = consumer_prelude_live_seed_; + for (int i = loop_idx - 1; i >= 0; --i) { + LocalAccessSummary summary = LocalAccessCollector::Collect( + seq->seq[i], buffer_data_to_buffer_); + if (!summary.HasTrackedDefs()) + continue; + if (producer_live.NeedsAnyDef(summary)) { + producer_live.AddUses(summary); + } + if (consumer_live.NeedsAnyDef(summary)) { + consumer_live.AddUses(summary); + } + } + producer_prelude_live_seed_ = producer_live; + consumer_prelude_live_seed_ = consumer_live; + } // Classify pre-loop statements using branch-private def/use sets. // Shared-prelude statements stay in place; branch-private definitions // move next to the branch that consumes them, or are duplicated when // both producer and consumer need the same definition. for (int i = 0; i < loop_idx; ++i) { - switch (ClassifyPreludeStmt(seq->seq[i], buffer_data_to_buffer_, - producer_prelude_live_seed_, - consumer_prelude_live_seed_)) { + switch (ClassifyPreludeStmt( + seq->seq[i], buffer_data_to_buffer_, shared_prelude_live_seed_, + producer_prelude_live_seed_, consumer_prelude_live_seed_)) { case PreludeStmtPlacement::kProducerOnly: extracted_producer_init_.push_back(seq->seq[i]); break; @@ -2109,6 +2138,18 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { return {new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq), true}; } if (auto *let = stmt.as()) { + // The LetStmt value is evaluated in the shared prelude (outside + // both producer and consumer branches). If it reads branch-private + // buffers or vars defined by a prelude statement, that definition + // must remain available in the shared scope. Propagate such uses + // into both live seeds before visiting the body so the upstream + // prelude-statement classifier sees them when classifying the + // surrounding SeqStmt. + { + LocalAccessSummary val_summary = LocalAccessCollector::Collect( + Evaluate(let->value), buffer_data_to_buffer_); + shared_prelude_live_seed_.AddUses(val_summary); + } ReplaceResult result = ReplacePipelineLoopInStmt( let->body, pipeline_loop, ws_body, consumer_extent); if (!result.found) { @@ -2186,6 +2227,7 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { bool ws_transformed_{false}; BufferDataToBufferMap buffer_data_to_buffer_; std::unordered_map common_prelude_rewrites_; + LocalLiveSet shared_prelude_live_seed_; LocalLiveSet producer_prelude_live_seed_; LocalLiveSet consumer_prelude_live_seed_; Array extracted_producer_init_; diff --git a/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py b/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py index 4bbfd52e36..eee80b2c49 100644 --- a/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py +++ b/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py @@ -114,6 +114,119 @@ def main( return main +def grouped_gemm_padded_pipelined( + batch_sizes, + K, + N, + block_M=64, + block_N=64, + block_K=32, + num_stages=2, + threads=256, + dtype="float16", +): + """Grouped GEMM with padded M tiles to exercise WS shared-prelude local vars.""" + + batch_sizes = tuple(batch_sizes) + batch_count = len(batch_sizes) + batch_sum = sum(batch_sizes) + total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes) + + @T.prim_func + def main( + A: T.Buffer((batch_sum, K), dtype), + B: T.Buffer((batch_count, K, N), dtype), + C: T.Buffer((batch_sum, N), dtype), + batch_sizes_buf: T.Buffer((batch_count,), "int32"), + batch_offsets: T.Buffer((batch_count,), "int32"), + batch_padded_offsets: T.Buffer((batch_count,), "int32"), + ): + with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), "float32") + cur_batch_idx = T.alloc_var("int32") + cur_batch_size = T.alloc_var("int32") + + m_start_padded = bx * block_M + for i in range(batch_count): + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] + cur_batch_idx = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx) + + cur_batch_size = batch_sizes_buf[cur_batch_idx] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx] + batch_offsets[cur_batch_idx] + actual_rows = T.max( + 0, + T.min(block_M, cur_batch_size + batch_padded_offsets[cur_batch_idx] - m_start_padded), + ) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[m_start, ko * block_K], A_shared) + T.copy(B[cur_batch_idx, ko * block_K, by * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + for i, j in T.Parallel(block_M, block_N): + if i < actual_rows: + C[m_start + i, by * block_N + j] = C_local[i, j] + + return main + + +def grouped_gemm_reference(A, B, batch_sizes): + import torch + + outputs = [] + start = 0 + for idx, size in enumerate(batch_sizes): + end = start + size + outputs.append(torch.mm(A[start:end], B[idx])) + start = end + return torch.cat(outputs, dim=0) + + +def grouped_gemm_inputs(batch_sizes, K, N, block_M, dtype="float16"): + import math + import torch + + batch_sizes = list(batch_sizes) + batch_offsets = [0] + batch_padded_offsets = [0] + for i in range(len(batch_sizes) - 1): + batch_offsets.append(batch_offsets[-1] + batch_sizes[i]) + batch_padded_offsets.append(batch_padded_offsets[-1] + math.ceil(batch_sizes[i] / block_M) * block_M) + + A = torch.randn(sum(batch_sizes), K, dtype=getattr(torch, dtype), device="cuda") + B = torch.randn(len(batch_sizes), K, N, dtype=getattr(torch, dtype), device="cuda") + batch_sizes_t = torch.tensor(batch_sizes, dtype=torch.int32, device="cuda") + batch_offsets_t = torch.tensor(batch_offsets, dtype=torch.int32, device="cuda") + batch_padded_offsets_t = torch.tensor(batch_padded_offsets, dtype=torch.int32, device="cuda") + return A, B, batch_sizes_t, batch_offsets_t, batch_padded_offsets_t + + +def _find_after(src, needle, start=0): + pos = src.find(needle, start) + assert pos >= 0, f"missing substring: {needle}" + return pos + + +def _compile_grouped_gemm_ws(batch_sizes=(63, 77), K=128, N=128, block_M=64, block_N=64, block_K=32): + pass_configs = {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} + func = grouped_gemm_padded_pipelined(batch_sizes, K, N, block_M, block_N, block_K) + kernel = _compile_tvm_ffi(func, pass_configs, out_idx=[2]) + return kernel, batch_sizes + + +def _run_grouped_gemm_ws(kernel, batch_sizes, K=128, N=128, block_M=64, dtype="float16"): + import torch + + A, B, batch_sizes_t, batch_offsets_t, batch_padded_offsets_t = grouped_gemm_inputs(batch_sizes, K, N, block_M, dtype) + out = kernel(A, B, batch_sizes_t, batch_offsets_t, batch_padded_offsets_t) + ref = grouped_gemm_reference(A.float(), B.float(), batch_sizes) + torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1e-2) + return out + + @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(9, 0) def test_tiled_ws_stage1_dynamic_loop_start(): @@ -318,6 +431,36 @@ def test_tiled_ws_sinks_preloop_tma_waits_into_consumer(): assert k_load < v_load < branch < first_wait +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tiled_ws_keeps_shared_prelude_local_vars_for_grouped_gemm(): + """Shared-prelude grouped-gemm indices must stay outside WS branches.""" + kernel, batch_sizes = _compile_grouped_gemm_ws() + src = kernel.get_kernel_source() + + branch = _find_after(src, "if (256 <= ((int)threadIdx.x))") + cur_batch_idx_loop = _find_after(src, "for (int i = 0; i < 2; ++i)") + m_start = _find_after(src, "int m_start =") + actual_rows = _find_after(src, "int actual_rows =") + + assert cur_batch_idx_loop < m_start < actual_rows < branch + _run_grouped_gemm_ws(kernel, batch_sizes) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tiled_ws_does_not_clone_local_var_into_producer_branch(): + """Producer branch should reuse shared local.var state instead of cloning it.""" + kernel, batch_sizes = _compile_grouped_gemm_ws() + src = kernel.get_kernel_source() + + assert "cur_batch_idx_producer_ws" not in src + assert "cur_batch_size_producer_ws" not in src + assert "tl::tma_load(B_desc" in src + assert "cur_batch_idx);" in src + _run_grouped_gemm_ws(kernel, batch_sizes) + + if __name__ == "__main__": test_tiled_ws_stage1_dynamic_loop_start() test_tiled_ws_correctness() @@ -325,3 +468,5 @@ def test_tiled_ws_sinks_preloop_tma_waits_into_consumer(): test_tiled_ws_swizzled_layout_allows_ws() test_tiled_ws_incompatible_layout_blocks_ws() test_tiled_ws_sinks_preloop_tma_waits_into_consumer() + test_tiled_ws_keeps_shared_prelude_local_vars_for_grouped_gemm() + test_tiled_ws_does_not_clone_local_var_into_producer_branch() From 04468a342549ed112c91a5265ca72517327c2f0e Mon Sep 17 00:00:00 2001 From: TerminusAkivili Date: Fri, 17 Apr 2026 17:22:54 +0800 Subject: [PATCH 084/156] [Bugfix] Fix stage-expanded annotated-layout aliases in LayoutInference (#2031) * Fix annotated layout inference for dtype-changing views * lint fix --------- Co-authored-by: LeiWang1999 --- src/transform/layout_inference.cc | 46 +++++++++++++------ .../language/test_tilelang_language_view.py | 27 +++++++++++ 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 7935391d35..6ee867fe8a 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -48,6 +48,30 @@ int64_t GetElementStorageBits(DataType dtype) { return static_cast(dtype.bits()) * dtype.lanes(); } +bool ShapesEqual(const Array &lhs, const Array &rhs, + arith::Analyzer *analyzer) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); ++i) { + if (!analyzer->CanProveEqual(lhs[i], rhs[i])) { + return false; + } + } + return true; +} + +Optional FindLayoutAnchorBuffer(const Array &buffers, + const Layout &layout, + arith::Analyzer *analyzer) { + for (const auto &buffer : buffers) { + if (ShapesEqual(layout->InputShape(), buffer->shape, analyzer)) { + return buffer; + } + } + return Optional(); +} + } // namespace /*! @@ -738,31 +762,25 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { << "buffer " << var << " is not found in the block"; const auto &buffers = buffer_data_to_buffers_[var]; ICHECK(!buffers.empty()) << "buffer list for " << var << " is empty"; + Optional anchor_buffer = + FindLayoutAnchorBuffer(buffers, layout, &analyzer_); + int64_t anchor_bits = + anchor_buffer.defined() + ? GetElementStorageBits(anchor_buffer.value()->dtype) + : GetElementStorageBits(buffers[0]->dtype); // Apply layout to all buffers associated with this var for (const auto &buffer : buffers) { // Reshape the layout to match the buffer's shape // Check if shapes are structurally equal bool shapes_equal = - layout->InputShape().size() == buffer->shape.size(); - if (shapes_equal) { - for (size_t i = 0; i < layout->InputShape().size(); ++i) { - if (!analyzer_.CanProveEqual(layout->InputShape()[i], - buffer->shape[i])) { - shapes_equal = false; - break; - } - } - } + ShapesEqual(layout->InputShape(), buffer->shape, &analyzer_); if (shapes_equal) { annotated_layout_map_.Set(buffer, layout); } else { - // Use the first buffer sharing this var as the base for dtype ratio - int64_t base_bits = GetElementStorageBits(buffers[0]->dtype); - auto reshaped_layout = - layout->Reshape(buffer->shape, &analyzer_, Integer(base_bits), + layout->Reshape(buffer->shape, &analyzer_, Integer(anchor_bits), Integer(GetElementStorageBits(buffer->dtype))); annotated_layout_map_.Set(buffer, reshaped_layout); } diff --git a/testing/python/language/test_tilelang_language_view.py b/testing/python/language/test_tilelang_language_view.py index 356e96d7b1..8f427d8757 100644 --- a/testing/python/language/test_tilelang_language_view.py +++ b/testing/python/language/test_tilelang_language_view.py @@ -124,5 +124,32 @@ def test_view_shared_fp4_to_uint8_compile(): assert output.dtype == torch.uint8 +def annotated_layout_on_dtype_changing_view_test(): + @T.prim_func + def main( + A: T.Tensor((64, 64), T.float16), + B: T.Tensor((64, 128), T.int8), + ): + with T.Kernel(1, threads=128) as _: + A_stage = T.alloc_shared((2, 64, 64), T.float16, scope="shared.dyn") + A_i8 = T.view(A_stage, (2, 64, 128), dtype=T.int8) + T.annotate_layout({A_i8: T.Layout((2, 64, 128), lambda s, i, j: [s, i, j])}) + + for i, j in T.Parallel(64, 64): + A_stage[0, i, j] = A[i, j] + + for i, j in T.Parallel(64, 128): + B[i, j] = A_i8[0, i, j] + + return main + + +@tilelang.testing.requires_cuda +def test_annotated_layout_on_dtype_changing_view_compile(): + program = annotated_layout_on_dtype_changing_view_test() + kernel = tl.compile(program, out_idx=-1) + assert kernel.get_kernel_source() + + if __name__ == "__main__": tilelang.testing.main() From 6364f5d2d2affa8536808c8d8c8669478b022a12 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 17 Apr 2026 17:37:36 +0800 Subject: [PATCH 085/156] fix pro/epilogue let stmt copy --- .../auto_schedule/schedule_builder.cc | 11 ++-- .../auto_schedule/warpgroup_partition.cc | 54 +++++++++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index eb559a126b..10c1fb60ef 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -531,6 +531,13 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, } } +/* + Recursively schedule root node, after scheduling IRStructure satisfies the +following properties: 1) For each SequenceNode, its children are reordered by Z3 +scheduler and wrapped in ScheduleUnits. 2) For each ControlNode, its child is a +SequenceNode with Z3-scheduled children wrapped in ScheduleUnits. 3) For each +IfNode, its then_child and else_child are recursively scheduled (if exist). +*/ void ScheduleUnitBuilder::ScheduleRecursive( std::shared_ptr &node, const std::set &used_buffers) { if (!node) @@ -707,9 +714,7 @@ void ScheduleUnitBuilder::ScheduleRecursive( if_node->SetLatency( std::max(if_node->then_child ? if_node->then_child->GetLatency() : 0, if_node->else_child ? if_node->else_child->GetLatency() : 0)); - if_node->SetII( - std::max(if_node->then_child ? if_node->then_child->GetII() : 0, - if_node->else_child ? if_node->else_child->GetII() : 0)); + if_node->SetII(if_node->GetLatency()); return; } diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 0c6579b4d0..492e08f45a 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -71,6 +71,58 @@ namespace tl { using namespace tir; using ffi::GetRef; +class LetStmtVarRenamer : public StmtExprMutator { +public: + explicit LetStmtVarRenamer(const std::string &suffix) : suffix_(suffix) {} + + Stmt Rename(Stmt stmt) { + CollectLetVars(stmt); + if (var_remap_.empty()) return stmt; + return VisitStmt(std::move(stmt)); + } + +private: + void CollectLetVars(const Stmt &stmt) { + class Collector : public StmtExprVisitor { + public: + explicit Collector(Map &remap, const std::string &suffix) + : remap_(remap), suffix_(suffix) {} + void VisitStmt_(const LetStmtNode *op) final { + remap_.Set(op->var, op->var.copy_with_suffix(suffix_)); + StmtExprVisitor::VisitStmt_(op); + } + Map &remap_; + const std::string &suffix_; + }; + Collector c(var_remap_, suffix_); + c(stmt); + } + + Stmt VisitStmt_(const LetStmtNode *op) final { + auto it = var_remap_.find(op->var); + Var new_var = it != var_remap_.end() ? (*it).second : op->var; + PrimExpr new_value = VisitExpr(op->value); + Stmt new_body = VisitStmt(op->body); + return LetStmt(new_var, new_value, new_body, op->span); + } + + PrimExpr VisitExpr_(const VarNode *op) final { + Var var = GetRef(op); + auto it = var_remap_.find(var); + if (it != var_remap_.end()) { + return (*it).second; + } + return StmtExprMutator::VisitExpr_(op); + } + + std::string suffix_; + Map var_remap_; +}; + +static Stmt RenameLetStmtVars(Stmt stmt, const std::string &suffix) { + return LetStmtVarRenamer(suffix).Rename(std::move(stmt)); +} + bool IsLetDeclTask(const TaskNode *task) { return task->stmts.size() == 1 && task->stmts[0].as() != nullptr; } @@ -725,6 +777,7 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, for_op.CopyOnWrite()->extent = max(0, for_op.get()->extent - prologue_extent); prologue = Substitute(new_for, sub); + prologue = RenameLetStmtVars(prologue, "_prologue"); } Stmt epilogue = Evaluate(0); if (enable_epi) { @@ -742,6 +795,7 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, for_op.CopyOnWrite()->extent = max(0, for_op.get()->extent - epilogue_extent); epilogue = Substitute(new_for, sub); + epilogue = RenameLetStmtVars(epilogue, "_epilogue"); } return SeqStmt({prologue, for_op, epilogue}); } else if (structure->IsWrapper()) { From 557bcc1ee8bd72fa8d38b3759fda246e7c1650d3 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 17 Apr 2026 17:37:55 +0800 Subject: [PATCH 086/156] fix z3 small n error --- tilelang/transform/z3_scheduler.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tilelang/transform/z3_scheduler.py b/tilelang/transform/z3_scheduler.py index d515db1602..3f24209d73 100644 --- a/tilelang/transform/z3_scheduler.py +++ b/tilelang/transform/z3_scheduler.py @@ -83,10 +83,8 @@ def z3_schedule_python( n = len(latencies) # For small number of tasks, return trivial schedule - if n <= 1: - if n == 1: - return [0], [0] - return [], [] + if n < 1: + raise RuntimeError("Z3 scheduling failed: n too small") if verbose: print(f"[Python Z3] Starting scheduling for {n} tasks") @@ -263,7 +261,7 @@ def z3_schedule_loop_python( n = len(latencies) # For small number of tasks, return trivial schedule - if n <= 1: + if n < 1: raise RuntimeError("Z3 loop scheduling failed: n too small") if verbose: From 3bc1c01c93903c3996077129630ed66184994c2d Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 17 Apr 2026 17:38:11 +0800 Subject: [PATCH 087/156] fix double kernel issue --- src/transform/auto_schedule.cc | 497 ++++++++++++++++++++++++--------- 1 file changed, 365 insertions(+), 132 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index fdefb6483d..c111270432 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -123,6 +123,105 @@ class TilelangRootBodyExtractor : public StmtVisitor { } }; +// Detect multiple kernel launches in the PrimFunc body. +// In tilelang, when multiple T.Kernel() blocks are used, the IR structure is: +// root block body: +// AttrStmt(tl.assume, ...) +// AttrStmt(tl.assume, ...) +// SeqStmt [ +// AttrStmt(blockIdx.x, thread_extent, ..., kernel1_subtree), +// AttrStmt(blockIdx.x, thread_extent, ..., kernel2_subtree), +// ] +// Each kernel subtree contains its own launch_threads and tilelang_root block. +// This class finds that SeqStmt and returns each child as a separate kernel. +class MultiKernelDetector { +public: + static bool Detect(const Stmt &func_body, std::vector &kernel_stmts, + Stmt &prefix_wrapper) { + std::vector stmts; + const Stmt *inner = &func_body; + + // Peel through root block -> BlockRealize + if (const auto *br = inner->as()) { + inner = &br->block->body; + } + + // Peel through AttrStmt(tl.assume, ...) chains + while (const auto *attr = inner->as()) { + if (attr->attr_key != "tl.assume") + break; + inner = &attr->body; + } + + // Check if we have a SeqStmt with multiple children that each contain + // a launch_thread (thread_extent) + const auto *seq = inner->as(); + if (!seq || seq->seq.size() < 2) + return false; + + int kernel_count = 0; + for (const auto &child : seq->seq) { + if (ContainsLaunchThread(child)) { + kernel_count++; + } + } + + if (kernel_count < 2) + return false; + + for (const auto &child : seq->seq) { + kernel_stmts.push_back(child); + } + return true; + } + +private: + static bool ContainsLaunchThread(const Stmt &stmt) { + if (const auto *attr = stmt.as()) { + if (attr->attr_key == tir::attr::thread_extent) { + return true; + } + } + return false; + } +}; + +// Mutator that replaces the inner SeqStmt (inside root block -> tl.assume +// chain) with a new body. Used to reassemble multi-kernel results. +class InnerSeqStmtReplacer : public StmtMutator { +public: + explicit InnerSeqStmtReplacer(Stmt new_inner) : new_inner_(new_inner) {} + + Stmt VisitStmt_(const BlockRealizeNode *op) override { + auto new_block_body = this->VisitStmt(op->block->body); + if (new_block_body.same_as(op->block->body)) + return GetRef(op); + auto new_block = + Block(op->block->iter_vars, op->block->reads, op->block->writes, + op->block->name_hint, new_block_body, op->block->init, + op->block->alloc_buffers, op->block->match_buffers, + op->block->annotations); + return BlockRealize(op->iter_values, op->predicate, new_block); + } + + Stmt VisitStmt_(const AttrStmtNode *op) override { + if (op->attr_key == "tl.assume") { + auto new_body = this->VisitStmt(op->body); + if (new_body.same_as(op->body)) + return GetRef(op); + return AttrStmt(op->node, op->attr_key, op->value, new_body); + } + return GetRef(op); + } + + Stmt VisitStmt_(const SeqStmtNode *op) override { + return new_inner_; + } + +private: + Stmt new_inner_; +}; + // Mutator to replace the body of tilelang_root block class TilelangRootBodyReplacer : public StmtMutator { public: @@ -590,6 +689,136 @@ class IRStructureBuilder : public StmtVisitor { Stmt ReNestLetStmts(const Stmt &stmt); +// Result of scheduling a single kernel segment +struct ScheduledKernelResult { + Stmt scheduled_body; + std::vector barrier_buffers; + Map barrier_map; + std::vector buffer_infos; + PrimExpr updated_thread_extent; + bool did_warpgroup_partition{false}; +}; + +// Schedule a single kernel body (the logic previously inlined in AutoSchedule). +// This handles IRStructure building, ScheduleUnit building, barrier analysis, +// and warpgroup partition for one kernel. +static ScheduledKernelResult ScheduleSingleKernel( + const Stmt &kernel_body, IterVar thread_var, Target target, + const WarpSpecializeConfig &config, bool aggressive, bool enable_epi) { + ScheduledKernelResult result; + + // Calculate thread count for latency estimation + int64_t latency_thread_count = 1; + if (thread_var.defined() && thread_var->dom.defined()) { + PrimExpr thread_extent = thread_var->dom->extent; + if (const int64_t *extent_ptr = as_const_int(thread_extent)) { + latency_thread_count = *extent_ptr; + if (latency_thread_count < 1) + latency_thread_count = 1; + } + } + + // Build IRStructure from the body to schedule + IRStructureBuilder builder; + auto ir_structure = + builder.Build(kernel_body, latency_thread_count, target); + + // Print the built IRStructure with all statements + ICHECK(ir_structure) << "IRStructure is null (empty body?)"; + + // Build ScheduleUnits from IRStructure + ScheduleUnitBuilder unit_builder; + if (thread_var.defined()) { + unit_builder.SetThreadVar(thread_var); + } else { + LOG(FATAL) << "Could not find thread index variable, warpgroup " + "partition will use default"; + } + unit_builder.SetWarpSpecializeConfig(config); + unit_builder.SetSharedMemoryLimit(GetSharedMemoryLimit(target)); + + std::vector thread_count; + if (!aggressive) { + thread_count = unit_builder.NaiveBuild(ir_structure); + } else { + thread_count = unit_builder.Build(ir_structure); + } + + if (!config.enable_warpgroup_partition) { + result.scheduled_body = + ConvertIRStructureToStmt(ir_structure.get(), enable_epi); + result.did_warpgroup_partition = false; + return result; + } + + // Print the modified summary view + // PrintIRStructure(ir_structure.get()); + + // Analyze buffer dependencies and insert barriers before warpgroup + // partition + int next_barrier_id = 1; + LoopNestingInfo loop_info; + PrimExpr updated_thread_extent = std::accumulate( + thread_count.begin() + 1, thread_count.end(), thread_count[0]); + result.updated_thread_extent = updated_thread_extent; + Buffer neutral_sync_shared_barrier = makeBarrierBuffer( + updated_thread_extent, "neutral_sync_shared_barrier", 1, + result.barrier_buffers, result.barrier_map); + AnalyzeAndInsertBarriers(ir_structure.get(), next_barrier_id, + result.barrier_buffers, result.barrier_map, + thread_count, loop_info, result.buffer_infos, + neutral_sync_shared_barrier); + + // Print the modified summary view + // PrintIRStructure(ir_structure.get()); + + // Apply warpgroup partition to entire IRStructure + result.scheduled_body = ApplyWarpgroupPartitionToIRStructure( + ir_structure.get(), thread_var, result.barrier_buffers, + result.barrier_map, enable_epi, thread_count, config, + neutral_sync_shared_barrier); + result.did_warpgroup_partition = true; + return result; +} + + +// Helper: add barrier buffers and barrier_map to the tilelang_root block +static Stmt AddBarrierBuffersToRoot( + const Stmt &body, const std::vector &barrier_buffers, + Map &barrier_map) { + class TilelangRootAllocBufferAdder : public StmtMutator { + public: + explicit TilelangRootAllocBufferAdder( + const std::vector &buffers_to_add, + Map &barrier_map) + : buffers_to_add_(buffers_to_add), barrier_map_(barrier_map) {} + + Stmt VisitStmt_(const BlockNode *op) override { + auto block = GetRef(op); + if (op->name_hint == "tilelang_root") { + // Combine existing alloc_buffers with new buffers + Array new_alloc_buffers = op->alloc_buffers; + for (const auto &buffer : buffers_to_add_) { + new_alloc_buffers.push_back(buffer); + } + auto new_annotations = op->annotations; + new_annotations.Set("barrier_init", barrier_map_); + // Create new block with updated alloc_buffers + return Block(op->iter_vars, op->reads, op->writes, op->name_hint, + op->body, op->init, new_alloc_buffers, + op->match_buffers, new_annotations); + } + return StmtMutator::VisitStmt_(op); + } + + private: + std::vector buffers_to_add_; + Map &barrier_map_; + }; + TilelangRootAllocBufferAdder adder(barrier_buffers, barrier_map); + return adder(body); +} + // The main pass function tvm::transform::Pass AutoSchedule(const bool enable_epi) { using namespace tir::transform; @@ -604,78 +833,67 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { } auto config = GetWarpSpecializeConfig(target); - // Extract the body of tilelang_root block if it exists - TilelangRootBodyExtractor extractor; - extractor(func->body); - Stmt body_to_schedule; - bool has_tilelang_root = false; - IterVar thread_var; // Thread index variable for warpgroup partition - - if (extractor.body.defined()) { - body_to_schedule = extractor.body; - has_tilelang_root = true; - } else { - LOG(FATAL); - body_to_schedule = func->body; - } - - // Get thread index variable for warpgroup partition - // First try to get from body_to_schedule, if not found, try from the entire - // function body - thread_var = ThreadTagChecker::GetThreadVar(body_to_schedule); - if (!thread_var.defined()) { - thread_var = ThreadTagChecker::GetThreadVar(func->body); - } - - // Calculate thread count for latency estimation - int64_t latency_thread_count = 1; - if (thread_var.defined() && thread_var->dom.defined()) { - PrimExpr thread_extent = thread_var->dom->extent; - if (const int64_t *extent_ptr = as_const_int(thread_extent)) { - latency_thread_count = *extent_ptr; - if (latency_thread_count < 1) - latency_thread_count = 1; - } - } - - // Build IRStructure from the body to schedule - IRStructureBuilder builder; - auto ir_structure = - builder.Build(body_to_schedule, latency_thread_count, target); - - // Print the built IRStructure with all statements - ICHECK(ir_structure) << "IRStructure is null (empty body?)"; - // Check if aggressive auto-schedule is enabled bool aggressive = ctx->GetConfig(kEnableAggressiveAutoSchedule, Bool(true)).value(); - // Build ScheduleUnits from IRStructure - ScheduleUnitBuilder unit_builder; - if (thread_var.defined()) { - unit_builder.SetThreadVar(thread_var); - } else { - LOG(FATAL) << "Could not find thread index variable, warpgroup " - "partition will use default"; - } - unit_builder.SetWarpSpecializeConfig(config); - unit_builder.SetSharedMemoryLimit(GetSharedMemoryLimit(target)); + // Detect multiple kernel launches in the PrimFunc body. + // When multiple T.Kernel() blocks are used, the IR has a SeqStmt + // containing separate kernel subtrees, each with its own tilelang_root. + std::vector kernel_stmts; + Stmt prefix_wrapper; + bool is_multi_kernel = + MultiKernelDetector::Detect(func->body, kernel_stmts, prefix_wrapper); + + if (!is_multi_kernel) { + // --- Single-kernel path (original behavior) --- + // Extract the body of tilelang_root block if it exists + TilelangRootBodyExtractor extractor; + extractor(func->body); + Stmt body_to_schedule; + + if (extractor.body.defined()) { + body_to_schedule = extractor.body; + } else { + LOG(FATAL); + body_to_schedule = func->body; + } - std::vector thread_count; - if (!aggressive) { - thread_count = unit_builder.NaiveBuild(ir_structure); - } else { - thread_count = unit_builder.Build(ir_structure); - } + // Get thread index variable for warpgroup partition + // First try to get from body_to_schedule, if not found, try from the + // entire function body + IterVar thread_var = + ThreadTagChecker::GetThreadVar(body_to_schedule); + if (!thread_var.defined()) { + thread_var = ThreadTagChecker::GetThreadVar(func->body); + } - if (!config.enable_warpgroup_partition) { - Stmt new_body = ConvertIRStructureToStmt(ir_structure.get(), enable_epi); + auto kr = ScheduleSingleKernel(body_to_schedule, thread_var, target, + config, aggressive, enable_epi); // If we extracted from tilelang_root block, replace the body Stmt final_body; - TilelangRootBodyReplacer replacer(new_body); + TilelangRootBodyReplacer replacer(kr.scheduled_body); final_body = replacer(func->body); + if (kr.did_warpgroup_partition) { + // Apply thread extent update if warpgroup partition was applied + // (sm_90 only) + if (config.enable_thread_extend) { + ThreadExtentUpdater extent_updater(kr.updated_thread_extent); + final_body = extent_updater(final_body); + } + // Add barrier buffers to tilelang_root block's alloc_buffers + if (!kr.barrier_buffers.empty()) { + final_body = AddBarrierBuffersToRoot( + final_body, kr.barrier_buffers, kr.barrier_map); + } + // Apply multi-version alloc_buffer rewrite if needed + if (!kr.buffer_infos.empty()) { + final_body = RewriteAllocBuffers(final_body, kr.buffer_infos); + } + } + final_body = ReNestLetStmts(final_body); // Create a new PrimFunc with the updated body @@ -684,84 +902,99 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { return new_func; } - // Print the modified summary view - // PrintIRStructure(ir_structure.get()); - - // Analyze buffer dependencies and insert barriers before warpgroup - // partition - int next_barrier_id = 1; - std::vector barrier_buffers; - Map barrier_map; - LoopNestingInfo loop_info; - std::vector buffer_infos; - PrimExpr updated_thread_extent = std::accumulate( - thread_count.begin() + 1, thread_count.end(), thread_count[0]); - Buffer neutral_sync_shared_barrier = - makeBarrierBuffer(updated_thread_extent, "neutral_sync_shared_barrier", - 1, barrier_buffers, barrier_map); - AnalyzeAndInsertBarriers( - ir_structure.get(), next_barrier_id, barrier_buffers, barrier_map, - thread_count, loop_info, buffer_infos, neutral_sync_shared_barrier); - - // Print the modified summary view - // PrintIRStructure(ir_structure.get()); - - // Apply warpgroup partition to entire IRStructure - Stmt new_body = ApplyWarpgroupPartitionToIRStructure( - ir_structure.get(), thread_var, barrier_buffers, barrier_map, - enable_epi, thread_count, config, neutral_sync_shared_barrier); - - // If we extracted from tilelang_root block, replace the body - Stmt final_body; - TilelangRootBodyReplacer replacer(new_body); - final_body = replacer(func->body); - // Apply thread extent update if warpgroup partition was applied (sm_90 - // only) - if (config.enable_thread_extend) { - ThreadExtentUpdater extent_updater(updated_thread_extent); - final_body = extent_updater(final_body); - } - // Add barrier buffers to tilelang_root block's alloc_buffers - if (!barrier_buffers.empty()) { - class TilelangRootAllocBufferAdder : public StmtMutator { - public: - explicit TilelangRootAllocBufferAdder( - const std::vector &buffers_to_add, - Map &barrier_map) - : buffers_to_add_(buffers_to_add), barrier_map_(barrier_map) {} - - Stmt VisitStmt_(const BlockNode *op) override { - auto block = GetRef(op); - if (op->name_hint == "tilelang_root") { - // Combine existing alloc_buffers with new buffers - Array new_alloc_buffers = op->alloc_buffers; - for (const auto &buffer : buffers_to_add_) { - new_alloc_buffers.push_back(buffer); - } - auto new_annotations = op->annotations; - new_annotations.Set("barrier_init", barrier_map_); - // Create new block with updated alloc_buffers - return Block(op->iter_vars, op->reads, op->writes, op->name_hint, - op->body, op->init, new_alloc_buffers, - op->match_buffers, new_annotations); - } - return StmtMutator::VisitStmt_(op); + // --- Multi-kernel path --- + // Each kernel_stmts[i] is a complete kernel subtree: + // AttrStmt(blockIdx.x) -> ... -> AttrStmt(threadIdx.x) -> + // BlockRealize("tilelang_root") -> body + // Schedule each independently and reassemble with shared memory + // boundary markers between them. + Array combined_stmts; + + for (size_t i = 0; i < kernel_stmts.size(); ++i) { + Stmt kernel_subtree = kernel_stmts[i]; + + // Extract the tilelang_root body from this kernel subtree + TilelangRootBodyExtractor extractor; + extractor(kernel_subtree); + + if (!extractor.body.defined()) { + // Not a schedulable kernel (no tilelang_root), pass through + if (!combined_stmts.empty()) { + combined_stmts.push_back( + AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, + 0, Evaluate(0))); } + combined_stmts.push_back(kernel_subtree); + continue; + } - private: - std::vector buffers_to_add_; - Map &barrier_map_; - }; + Stmt body_to_schedule = extractor.body; + + // Get thread index variable for this kernel + IterVar thread_var = + ThreadTagChecker::GetThreadVar(kernel_subtree); + if (!thread_var.defined()) { + // Fallback: pass through without scheduling + if (!combined_stmts.empty()) { + combined_stmts.push_back( + AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, + 0, Evaluate(0))); + } + combined_stmts.push_back(kernel_subtree); + continue; + } - TilelangRootAllocBufferAdder adder(barrier_buffers, barrier_map); - final_body = adder(final_body); + // Schedule this kernel independently + auto kr = ScheduleSingleKernel(body_to_schedule, thread_var, target, + config, aggressive, enable_epi); + + // Replace the tilelang_root body in this kernel subtree + Stmt scheduled_subtree; + { + TilelangRootBodyReplacer replacer(kr.scheduled_body); + scheduled_subtree = replacer(kernel_subtree); + } + + if (kr.did_warpgroup_partition) { + // Apply thread extent update if warpgroup partition was applied + // (sm_90 only) + if (config.enable_thread_extend) { + ThreadExtentUpdater extent_updater(kr.updated_thread_extent); + scheduled_subtree = extent_updater(scheduled_subtree); + } + // Add barrier buffers to this kernel's tilelang_root block + if (!kr.barrier_buffers.empty()) { + scheduled_subtree = AddBarrierBuffersToRoot( + scheduled_subtree, kr.barrier_buffers, kr.barrier_map); + } + // Apply multi-version alloc_buffer rewrite if needed + if (!kr.buffer_infos.empty()) { + scheduled_subtree = + RewriteAllocBuffers(scheduled_subtree, kr.buffer_infos); + } + } + + // Insert shared memory boundary between kernel segments + if (!combined_stmts.empty()) { + combined_stmts.push_back( + AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, + 0, Evaluate(0))); + } + combined_stmts.push_back(scheduled_subtree); } - // Apply multi-version alloc_buffer rewrite if needed - if (!buffer_infos.empty()) { - final_body = RewriteAllocBuffers(final_body, buffer_infos); + // Reassemble: replace the inner SeqStmt in the PrimFunc body with the + // new combined statements + Stmt new_inner; + if (combined_stmts.size() == 1) { + new_inner = combined_stmts[0]; + } else { + new_inner = SeqStmt(combined_stmts); } + InnerSeqStmtReplacer seq_replacer(new_inner); + Stmt final_body = seq_replacer(func->body); + final_body = ReNestLetStmts(final_body); // Create a new PrimFunc with the updated body From 105b73bb22d3c553a44541220997e81eaa252cbd Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 17 Apr 2026 17:42:09 +0800 Subject: [PATCH 088/156] format --- src/transform/auto_schedule.cc | 53 +++++++++---------- .../auto_schedule/warpgroup_partition.cc | 5 +- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index c111270432..eea299a4b2 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -214,9 +214,7 @@ class InnerSeqStmtReplacer : public StmtMutator { return GetRef(op); } - Stmt VisitStmt_(const SeqStmtNode *op) override { - return new_inner_; - } + Stmt VisitStmt_(const SeqStmtNode *op) override { return new_inner_; } private: Stmt new_inner_; @@ -702,9 +700,10 @@ struct ScheduledKernelResult { // Schedule a single kernel body (the logic previously inlined in AutoSchedule). // This handles IRStructure building, ScheduleUnit building, barrier analysis, // and warpgroup partition for one kernel. -static ScheduledKernelResult ScheduleSingleKernel( - const Stmt &kernel_body, IterVar thread_var, Target target, - const WarpSpecializeConfig &config, bool aggressive, bool enable_epi) { +static ScheduledKernelResult +ScheduleSingleKernel(const Stmt &kernel_body, IterVar thread_var, Target target, + const WarpSpecializeConfig &config, bool aggressive, + bool enable_epi) { ScheduledKernelResult result; // Calculate thread count for latency estimation @@ -720,8 +719,7 @@ static ScheduledKernelResult ScheduleSingleKernel( // Build IRStructure from the body to schedule IRStructureBuilder builder; - auto ir_structure = - builder.Build(kernel_body, latency_thread_count, target); + auto ir_structure = builder.Build(kernel_body, latency_thread_count, target); // Print the built IRStructure with all statements ICHECK(ir_structure) << "IRStructure is null (empty body?)"; @@ -761,9 +759,9 @@ static ScheduledKernelResult ScheduleSingleKernel( PrimExpr updated_thread_extent = std::accumulate( thread_count.begin() + 1, thread_count.end(), thread_count[0]); result.updated_thread_extent = updated_thread_extent; - Buffer neutral_sync_shared_barrier = makeBarrierBuffer( - updated_thread_extent, "neutral_sync_shared_barrier", 1, - result.barrier_buffers, result.barrier_map); + Buffer neutral_sync_shared_barrier = + makeBarrierBuffer(updated_thread_extent, "neutral_sync_shared_barrier", 1, + result.barrier_buffers, result.barrier_map); AnalyzeAndInsertBarriers(ir_structure.get(), next_barrier_id, result.barrier_buffers, result.barrier_map, thread_count, loop_info, result.buffer_infos, @@ -781,11 +779,10 @@ static ScheduledKernelResult ScheduleSingleKernel( return result; } - // Helper: add barrier buffers and barrier_map to the tilelang_root block -static Stmt AddBarrierBuffersToRoot( - const Stmt &body, const std::vector &barrier_buffers, - Map &barrier_map) { +static Stmt AddBarrierBuffersToRoot(const Stmt &body, + const std::vector &barrier_buffers, + Map &barrier_map) { class TilelangRootAllocBufferAdder : public StmtMutator { public: explicit TilelangRootAllocBufferAdder( @@ -805,8 +802,8 @@ static Stmt AddBarrierBuffersToRoot( new_annotations.Set("barrier_init", barrier_map_); // Create new block with updated alloc_buffers return Block(op->iter_vars, op->reads, op->writes, op->name_hint, - op->body, op->init, new_alloc_buffers, - op->match_buffers, new_annotations); + op->body, op->init, new_alloc_buffers, op->match_buffers, + new_annotations); } return StmtMutator::VisitStmt_(op); } @@ -862,8 +859,7 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { // Get thread index variable for warpgroup partition // First try to get from body_to_schedule, if not found, try from the // entire function body - IterVar thread_var = - ThreadTagChecker::GetThreadVar(body_to_schedule); + IterVar thread_var = ThreadTagChecker::GetThreadVar(body_to_schedule); if (!thread_var.defined()) { thread_var = ThreadTagChecker::GetThreadVar(func->body); } @@ -885,8 +881,8 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { } // Add barrier buffers to tilelang_root block's alloc_buffers if (!kr.barrier_buffers.empty()) { - final_body = AddBarrierBuffersToRoot( - final_body, kr.barrier_buffers, kr.barrier_map); + final_body = AddBarrierBuffersToRoot(final_body, kr.barrier_buffers, + kr.barrier_map); } // Apply multi-version alloc_buffer rewrite if needed if (!kr.buffer_infos.empty()) { @@ -921,8 +917,8 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { // Not a schedulable kernel (no tilelang_root), pass through if (!combined_stmts.empty()) { combined_stmts.push_back( - AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, - 0, Evaluate(0))); + AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, + Evaluate(0))); } combined_stmts.push_back(kernel_subtree); continue; @@ -931,14 +927,13 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { Stmt body_to_schedule = extractor.body; // Get thread index variable for this kernel - IterVar thread_var = - ThreadTagChecker::GetThreadVar(kernel_subtree); + IterVar thread_var = ThreadTagChecker::GetThreadVar(kernel_subtree); if (!thread_var.defined()) { // Fallback: pass through without scheduling if (!combined_stmts.empty()) { combined_stmts.push_back( - AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, - 0, Evaluate(0))); + AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, + Evaluate(0))); } combined_stmts.push_back(kernel_subtree); continue; @@ -977,8 +972,8 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { // Insert shared memory boundary between kernel segments if (!combined_stmts.empty()) { combined_stmts.push_back( - AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, - 0, Evaluate(0))); + AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, + Evaluate(0))); } combined_stmts.push_back(scheduled_subtree); } diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 492e08f45a..36a1a3225b 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -77,8 +77,9 @@ class LetStmtVarRenamer : public StmtExprMutator { Stmt Rename(Stmt stmt) { CollectLetVars(stmt); - if (var_remap_.empty()) return stmt; - return VisitStmt(std::move(stmt)); + if (var_remap_.empty()) + return stmt; + return VisitStmt(stmt); } private: From b18c60a246ca448cc8363b8c776ea97fc5d683e4 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 17 Apr 2026 17:48:31 +0800 Subject: [PATCH 089/156] undo failed merge --- src/transform/auto_schedule/schedule_builder.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 10c1fb60ef..263fff3b8e 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -714,7 +714,9 @@ void ScheduleUnitBuilder::ScheduleRecursive( if_node->SetLatency( std::max(if_node->then_child ? if_node->then_child->GetLatency() : 0, if_node->else_child ? if_node->else_child->GetLatency() : 0)); - if_node->SetII(if_node->GetLatency()); + if_node->SetII( + std::max(if_node->then_child ? if_node->then_child->GetII() : 0, + if_node->else_child ? if_node->else_child->GetII() : 0)); return; } From aa877ab698ec7c8eaa11dc00207e8c8061de97af Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 17 Apr 2026 17:53:54 +0800 Subject: [PATCH 090/156] check dependency in prologue --- src/transform/auto_schedule/schedule_builder.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 263fff3b8e..1fad07bb73 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -241,6 +241,12 @@ void CollectPrefixTasks(IRStructure *root, break; } } + for (auto *pre : prefix_tasks) { + if (HasDependency(task, pre)) { + has_dep = true; + break; + } + } if (has_dep) { rejected.push_back(task); } else { From c96dd9e1b16b2200db74a2a6d4320b71117b694e Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 17 Apr 2026 18:52:10 +0800 Subject: [PATCH 091/156] fix header missing --- src/transform/auto_schedule.cc | 14 ++++---------- src/transform/auto_schedule/barrier.h | 12 ++---------- src/transform/auto_schedule/schedule_builder.cc | 2 +- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index eea299a4b2..58e94127b9 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -55,7 +55,7 @@ #include "../op/builtin.h" #include "../op/copy.h" -#include "../op/gemm_py.h" +#include "../op/gemm.h" #include "../target/utils.h" #include "./common/attr.h" #include "./common/collector.h" @@ -472,12 +472,8 @@ class IRStructureBuilder : public StmtVisitor { void VisitExpr_(const CallNode *op) override { // Check for specific TileLang operations static const auto copy_op = Op::Get("tl.tileop.copy"); - static const auto gemm_py_op = Op::Get("tl.tileop.gemm_py"); static const auto gemm_op = Op::Get("tl.tileop.gemm"); - static const auto wgmma_gemm_py_op = Op::Get("tl.tileop.wgmma_gemm_py"); static const auto wgmma_gemm_op = Op::Get("tl.tileop.wgmma_gemm"); - static const auto tcgen05_gemm_py_op = - Op::Get("tl.tileop.tcgen05_gemm_py"); static const auto tcgen05_gemm_op = Op::Get("tl.tileop.tcgen05_gemm"); static const auto reduce_op = Op::Get("tl.tileop.reduce"); static const auto fill_op = Op::Get("tl.tileop.fill"); @@ -524,10 +520,8 @@ class IRStructureBuilder : public StmtVisitor { } } } - } else if (op->op.same_as(gemm_py_op) || op->op.same_as(gemm_op) || - op->op.same_as(wgmma_gemm_py_op) || + } else if (op->op.same_as(gemm_op) || op->op.same_as(wgmma_gemm_op) || - op->op.same_as(tcgen05_gemm_py_op) || op->op.same_as(tcgen05_gemm_op)) { found_tensor = true; @@ -538,9 +532,9 @@ class IRStructureBuilder : public StmtVisitor { // Determine the final GemmInst using GemmPyNode::getGemmInst if (target.defined()) { - GemmPy gemm_py(op->args); + Gemm gemm(op->args); GemmInst inst = - gemm_py->getGemmInst(static_cast(block_size), target); + gemm->getGemmInst(static_cast(block_size), target); ICHECK(!has_gemm_inst || gemm_inst == inst) << "All gemm operations in a task must use the same GemmInst, " << "but got " << GemmInstToString(gemm_inst) << " and " diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 99f0371c94..843b17090f 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -449,11 +449,8 @@ static void RewriteTaskNodeBuffers( // This is used for TCGEN05MMA where the gemm needs to reference the correct // mbarrier for synchronization. static void RewriteGemmMbar(TaskNode *task, PrimExpr mbar_expr) { - static const auto gemm_py_op = Op::Get("tl.tileop.gemm_py"); static const auto gemm_op = Op::Get("tl.tileop.gemm"); - static const auto wgmma_gemm_py_op = Op::Get("tl.tileop.wgmma_gemm_py"); static const auto wgmma_gemm_op = Op::Get("tl.tileop.wgmma_gemm"); - static const auto tcgen05_gemm_py_op = Op::Get("tl.tileop.tcgen05_gemm_py"); static const auto tcgen05_gemm_op = Op::Get("tl.tileop.tcgen05_gemm"); class GemmMbarRewriter : public StmtExprMutator { @@ -462,17 +459,12 @@ static void RewriteGemmMbar(TaskNode *task, PrimExpr mbar_expr) { private: PrimExpr VisitExpr_(const CallNode *op) override { - static const auto gemm_py_op = Op::Get("tl.tileop.gemm_py"); static const auto gemm_op = Op::Get("tl.tileop.gemm"); - static const auto wgmma_gemm_py_op = Op::Get("tl.tileop.wgmma_gemm_py"); static const auto wgmma_gemm_op = Op::Get("tl.tileop.wgmma_gemm"); - static const auto tcgen05_gemm_py_op = - Op::Get("tl.tileop.tcgen05_gemm_py"); static const auto tcgen05_gemm_op = Op::Get("tl.tileop.tcgen05_gemm"); - if ((op->op.same_as(gemm_py_op) || op->op.same_as(gemm_op) || - op->op.same_as(wgmma_gemm_py_op) || op->op.same_as(wgmma_gemm_op) || - op->op.same_as(tcgen05_gemm_py_op) || + if ((op->op.same_as(gemm_op) || + op->op.same_as(wgmma_gemm_op) || op->op.same_as(tcgen05_gemm_op)) && op->args.size() > 16) { Array new_args; diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 263fff3b8e..9b4f77bd4f 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -59,7 +59,7 @@ #include #include "../../op/builtin.h" -#include "../../op/gemm_py.h" +#include "../../op/gemm.h" #include "../../op/utils.h" #include "../../target/utils.h" #include "../common/attr.h" From e1d63880f6506d40938bc5adb4b38f96c48271cd Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sat, 18 Apr 2026 20:15:36 +0800 Subject: [PATCH 092/156] [Cache] Refactor cache namespace layout (#2057) * Refactor cache namespace layout * Update atomic cache tests for namespaced paths * Ensure namespaced cache dirs exist before save * Use auto backend in grouped gemm ptr example * Keep grouped gemm ptr example off cutedsl * Always use auto backend for grouped gemm ptr * Skip grouped gemm examples in cutedsl CI * Ignore grouped gemm examples in cutedsl workflow --- .github/workflows/ci.yml | 1 + .../example_grouped_gemm_fwd_ptr.py | 8 +- .../grouped_gemm/test_example_grouped_gemm.py | 1 - .../test_tilelang_autotune_atomic_save.py | 18 ++-- .../test_tilelang_kernel_cache_atomic_save.py | 21 +++-- tilelang/autotuner/param.py | 16 ++-- tilelang/autotuner/tuner.py | 11 ++- tilelang/cache/kernel_cache.py | 87 +++++++++++++++---- 8 files changed, 117 insertions(+), 46 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f899ed473a..fba08c54a6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -310,6 +310,7 @@ jobs: pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + --ignore=../examples/grouped_gemm/test_example_grouped_gemm.py \ ../examples # NVIDIA CUDA tests diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py b/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py index 4ce9e7320c..d57edcc6ca 100644 --- a/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py +++ b/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py @@ -125,15 +125,17 @@ def run_tilelang_grouped_gemm_ptr( block_K, num_stages=2, threads=128, - backend="tvm_ffi", profile=False, ): device = torch.device("cuda") dtype = torch.float16 program = grouped_gemm_ptr(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages, threads) + # The ptr-backed grouped GEMM example is intended to exercise the regular CUDA + # execution path; CuTeDSL does not support these handle tensors. kernel = tl.compile( program, - execution_backend=backend, + target="cuda", + execution_backend="auto", pass_configs={"tl.disable_warp_specialized": True}, ) a_list, b_list, c_list, a_ptrs, b_ptrs, c_ptrs, batch_tile_offsets = construct_inputs(batch_sizes_list, K, N, block_M, device, dtype) @@ -159,7 +161,6 @@ def test_grouped_gemm_ptr(): parser.add_argument("--batch_sizes", type=str, default="64,128,256", help="comma-separated per-group M sizes") parser.add_argument("--K", type=int, default=4096, help="reduce dim") parser.add_argument("--N", type=int, default=4096, help="output dim") - parser.add_argument("--backend", type=str, default="tvm_ffi", choices=["tvm_ffi", "cython"], help="execution backend") parser.add_argument("--profile", action="store_true", help="benchmark the kernel") args = parser.parse_args() @@ -180,7 +181,6 @@ def test_grouped_gemm_ptr(): block_K, num_stages=num_stages, threads=threads, - backend=args.backend, profile=args.profile, ) print(f"End-to-end: {time.time() - t0:.3f} s") diff --git a/examples/grouped_gemm/test_example_grouped_gemm.py b/examples/grouped_gemm/test_example_grouped_gemm.py index 0cf03462a8..dc0c945072 100644 --- a/examples/grouped_gemm/test_example_grouped_gemm.py +++ b/examples/grouped_gemm/test_example_grouped_gemm.py @@ -34,7 +34,6 @@ def test_example_grouped_gemm_fwd_ptr_small(): block_K=32, num_stages=1, threads=256, - backend="tvm_ffi", profile=False, ) diff --git a/testing/python/autotune/test_tilelang_autotune_atomic_save.py b/testing/python/autotune/test_tilelang_autotune_atomic_save.py index 1369d76440..5fa0931b47 100644 --- a/testing/python/autotune/test_tilelang_autotune_atomic_save.py +++ b/testing/python/autotune/test_tilelang_autotune_atomic_save.py @@ -73,8 +73,8 @@ def _make_result(tmp_path, execution_backend: str = "cython"): def test_autotune_save_rewrites_incomplete_cache_dir(cache_dirs, tmp_path): result = _make_result(tmp_path) - path = cache_dirs / "autotune-entry" - path.mkdir() + path = cache_dirs / "test-namespace" / "autotuner" / "autotune-entry" + path.mkdir(parents=True) (path / "stale.txt").write_text("partial") result.save_to_disk(path) @@ -94,8 +94,9 @@ def test_autotune_save_rewrites_incomplete_cache_dir(cache_dirs, tmp_path): def test_autotune_save_logs_write_oserror_instead_of_treating_it_as_race(cache_dirs, tmp_path, monkeypatch): result = _make_result(tmp_path) - path = cache_dirs / "autotune-error" + path = cache_dirs / "test-namespace" / "autotuner" / "autotune-error" logged = [] + staging_root = path.parent.parent / ".staging" def raise_write_error(self, *args, **kwargs): raise OSError(errno.ENOSPC, "No space left on device") @@ -110,14 +111,15 @@ def record_exception(message, *args, **kwargs): assert not path.exists() assert "Error during atomic autotune result save" in logged - assert not any(child.name.startswith(".staging_") for child in cache_dirs.iterdir()) + assert not staging_root.exists() or not any(staging_root.iterdir()) def test_autotune_save_does_not_publish_incomplete_dir_when_device_source_is_missing(cache_dirs, tmp_path, monkeypatch): result = _make_result(tmp_path) result.kernel.kernel_source = None - path = cache_dirs / "autotune-missing-device-source" + path = cache_dirs / "test-namespace" / "autotuner" / "autotune-missing-device-source" logged = [] + staging_root = path.parent.parent / ".staging" def record_exception(message, *args, **kwargs): logged.append(message) @@ -128,13 +130,13 @@ def record_exception(message, *args, **kwargs): assert not path.exists() assert "Error during atomic autotune result save" in logged - assert not any(child.name.startswith(".staging_") for child in cache_dirs.iterdir()) + assert not staging_root.exists() or not any(staging_root.iterdir()) def test_autotune_save_rewrites_nvrtc_dir_missing_launcher(cache_dirs, tmp_path): result = _make_result(tmp_path, execution_backend="nvrtc") - path = cache_dirs / "autotune-nvrtc-entry" - path.mkdir() + path = cache_dirs / "test-namespace" / "autotuner" / "autotune-nvrtc-entry" + path.mkdir(parents=True) (path / BEST_CONFIG_PATH).write_text("{}") (path / FUNCTION_PATH).write_bytes(b"old-func") (path / LATENCY_PATH).write_text('{"latency": 1.0, "ref_latency": 2.0}') diff --git a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py index d287c81799..3e87ceacaa 100644 --- a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py +++ b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py @@ -1,4 +1,5 @@ import errno +from pathlib import Path import pytest from tilelang.cache.kernel_cache import KernelCache @@ -48,8 +49,8 @@ def _make_fake_nvrtc_kernel(tmp_path): def test_kernel_cache_rewrites_incomplete_cache_dir(cache_dirs, tmp_path): cache = KernelCache() key = "atomic-repair" - cache_path = cache_dirs / key - cache_path.mkdir() + cache_path = Path(cache._get_cache_path(key)) + cache_path.mkdir(parents=True) (cache_path / "stale.txt").write_text("partial") cache._save_kernel_to_disk(key, _make_fake_kernel(tmp_path)) @@ -65,6 +66,8 @@ def test_kernel_cache_logs_write_oserror_instead_of_treating_it_as_race(cache_di cache = KernelCache() key = "atomic-write-error" logged = [] + cache_path = Path(cache._get_cache_path(key)) + staging_root = Path(cache._get_staging_root()) def raise_write_error(*args, **kwargs): raise OSError(errno.ENOSPC, "No space left on device") @@ -77,9 +80,9 @@ def record_exception(message, *args, **kwargs): cache._save_kernel_to_disk(key, _make_fake_kernel(tmp_path)) - assert f"{key}" not in {path.name for path in cache_dirs.iterdir()} + assert not cache_path.exists() assert "Error during atomic cache save" in logged - assert not any(path.name.startswith(".staging_") for path in cache_dirs.iterdir()) + assert not staging_root.exists() or not any(staging_root.iterdir()) def test_kernel_cache_does_not_publish_incomplete_dir_when_device_source_is_missing(cache_dirs, tmp_path, monkeypatch): @@ -88,6 +91,8 @@ def test_kernel_cache_does_not_publish_incomplete_dir_when_device_source_is_miss kernel = _make_fake_kernel(tmp_path) kernel.kernel_source = None logged = [] + cache_path = Path(cache._get_cache_path(key)) + staging_root = Path(cache._get_staging_root()) def record_exception(message, *args, **kwargs): logged.append(message) @@ -96,16 +101,16 @@ def record_exception(message, *args, **kwargs): cache._save_kernel_to_disk(key, kernel) - assert f"{key}" not in {path.name for path in cache_dirs.iterdir()} + assert not cache_path.exists() assert "Error during atomic cache save" in logged - assert not any(path.name.startswith(".staging_") for path in cache_dirs.iterdir()) + assert not staging_root.exists() or not any(staging_root.iterdir()) def test_nvrtc_kernel_cache_rewrites_dir_missing_launcher(cache_dirs, tmp_path): cache = NVRTCKernelCache() key = "nvrtc-atomic-repair" - cache_path = cache_dirs / key - cache_path.mkdir() + cache_path = Path(cache._get_cache_path(key)) + cache_path.mkdir(parents=True) (cache_path / cache.device_kernel_path).write_text("// device kernel") (cache_path / cache.host_kernel_path).write_text("// host kernel") (cache_path / cache.kernel_lib_path).write_bytes(b"old-cubin") diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index 5e0152e6b9..39cef841d1 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -373,20 +373,20 @@ def _load_kernel_from_disk( def save_to_disk(self, path: Path, verbose: bool = False): """Persist autotune result to disk using atomic directory rename. - All files are written into a temporary staging directory next to the - final *path*. Once complete, the staging directory is atomically - renamed to *path* so that concurrent readers never see a half-written - result. + All files are written into a temporary staging directory under the + shared namespace staging root. Once complete, the staging directory is + atomically renamed to *path* so that concurrent readers never see a + half-written result. """ # Already saved (e.g. another process won the race with a complete entry). if self._is_complete_result_dir(path, self.kernel.execution_backend): return - # Staging dir lives under TILELANG_CACHE_DIR (not the autotuner subdir) so that - # KernelCache._cleanup_stale_staging_dirs() can find and clean up stale entries. - staging_path = Path(env.TILELANG_CACHE_DIR) / f".staging_{Path(path).name}_{os.getpid()}_{uuid.uuid4().hex[:8]}" + # Keep autotuner staging under the shared namespace staging root so stale cleanup + # never needs to scan the full cache directory. + staging_path = path.parent.parent / ".staging" / f"{Path(path).name}_{os.getpid()}_{uuid.uuid4().hex[:8]}" os.makedirs(staging_path) - # Ensure the parent of the final path exists (e.g. ~/.tilelang/cache/autotuner/) + # Ensure the parent of the final path exists (e.g. ~/.tilelang/cache//autotuner/) os.makedirs(Path(path).parent, exist_ok=True) try: diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index f35828f2c4..413a66602b 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -132,7 +132,6 @@ class AutoTuner: _function_parameters: dict[str, Any] | None = None _lock = threading.Lock() # For thread safety _memory_cache = {} # In-memory cache dictionary - cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner" def __init__(self, fn: Callable, configs): self.fn = fn @@ -142,6 +141,16 @@ def __init__(self, fn: Callable, configs): self.ref_input_tensors = None self.jit_compile = None + @classmethod + def _get_cache_dir(cls) -> Path: + from tilelang.cache.kernel_cache import KernelCache + + return Path(KernelCache._get_namespace_root()) / "autotuner" + + @property + def cache_dir(self) -> Path: + return self._get_cache_dir() + @classmethod def from_kernel(cls, kernel: Callable, configs): """Create an AutoTuner instance from a kernel function. diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index fc3dbc68f4..f57d14e6e2 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -39,11 +39,15 @@ class KernelCache: _instance = None # For implementing singleton pattern _lock = threading.Lock() # For thread safety _memory_cache = {} # In-memory cache dictionary + _staging_cleanup_lock = threading.Lock() + _last_cleaned_staging_root: str | None = None execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi" device_kernel_path = "device_kernel.cu" host_kernel_path = "host_kernel.cu" kernel_lib_path = "kernel_lib.so" params_path = "params.pkl" + cache_root_dir = "kernels" + staging_root_dir = ".staging" @staticmethod @functools.cache @@ -110,6 +114,29 @@ def _get_base_key() -> dict: base["torch"] = torch.__version__ return base + @staticmethod + def _sanitize_path_component(component: str) -> str: + sanitized = "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in component) + sanitized = sanitized.strip("._-") + return sanitized or "unknown" + + @staticmethod + def _format_version_namespace(version: str) -> str: + public, sep, local = version.partition("+") + public = KernelCache._sanitize_path_component(public) + if not sep: + return public + local = "".join(ch if ch.isalnum() else "_" for ch in local).strip("_") + return f"{public}_{local}" if local else public + + @staticmethod + @functools.cache + def _get_cache_namespace() -> str: + base_key = KernelCache._get_base_key() + version = KernelCache._format_version_namespace(str(base_key.get("version", "unknown"))) + platform_name = KernelCache._sanitize_path_component(str(base_key.get("platform", "unknown"))) + return f"{version}-{platform_name}" + def __new__(cls): """ Implements singleton pattern for KernelCache class. @@ -132,11 +159,31 @@ def __new__(cls): def _create_dirs(): os.makedirs(env.TILELANG_CACHE_DIR, exist_ok=True) os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True) - KernelCache._cleanup_stale_staging_dirs() + os.makedirs(KernelCache._get_namespace_root(), exist_ok=True) + os.makedirs(KernelCache._get_cache_root(), exist_ok=True) + os.makedirs(KernelCache._get_staging_root(), exist_ok=True) + + staging_root = KernelCache._get_staging_root() + with KernelCache._staging_cleanup_lock: + if KernelCache._last_cleaned_staging_root != staging_root: + KernelCache._cleanup_stale_staging_dirs() + KernelCache._last_cleaned_staging_root = staging_root + + @staticmethod + def _get_namespace_root() -> str: + return os.path.join(env.TILELANG_CACHE_DIR, KernelCache._get_cache_namespace()) + + @staticmethod + def _get_cache_root() -> str: + return os.path.join(KernelCache._get_namespace_root(), KernelCache.cache_root_dir) + + @staticmethod + def _get_staging_root() -> str: + return os.path.join(KernelCache._get_namespace_root(), KernelCache.staging_root_dir) @staticmethod def _cleanup_stale_staging_dirs(max_age_seconds: int = 3600): - """Remove staging directories older than *max_age_seconds* (default 1 h). + """Remove stale entries from the dedicated staging root. These are left behind when a process crashes mid-save. """ @@ -144,8 +191,12 @@ def _cleanup_stale_staging_dirs(max_age_seconds: int = 3600): try: now = time.time() - for entry in os.scandir(env.TILELANG_CACHE_DIR): - if entry.name.startswith(".staging_") and entry.is_dir(follow_symlinks=False): + staging_root = KernelCache._get_staging_root() + if not os.path.isdir(staging_root): + return + + for entry in os.scandir(staging_root): + if entry.is_dir(follow_symlinks=False): try: if now - entry.stat().st_mtime > max_age_seconds: shutil.rmtree(entry.path, ignore_errors=True) @@ -330,7 +381,7 @@ def _get_cache_path(self, key: str) -> str: Returns: str: Absolute path to the cache directory for this kernel. """ - return os.path.join(env.TILELANG_CACHE_DIR, key) + return os.path.join(self._get_cache_root(), key) @staticmethod def _load_binary(path: str): @@ -358,10 +409,10 @@ def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = Non """ Persists a compiled kernel to disk cache using atomic directory rename. - All files are first written into a temporary staging directory under - TILELANG_CACHE_DIR. Once every file is in place, the staging directory - is atomically renamed to the final cache path so that other processes - never observe an incomplete cache entry. + All files are first written into a temporary staging directory under the + namespace staging root. Once every file is in place, the staging directory + is atomically renamed to the final cache path so that other processes never + observe an incomplete cache entry. Args: key (str): The hash key identifying the kernel. @@ -369,16 +420,21 @@ def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = Non func (Callable, optional): The original function. verbose (bool): Enable verbose log messages. """ + # Env-backed cache roots may change across tests or at runtime; recreate the + # namespace-specific directories lazily here so direct save helpers keep working + # even when the singleton instance is reused. + KernelCache._create_dirs() cache_path = self._get_cache_path(key) # Another process already wrote a complete entry — nothing to do. if self._is_complete_cache_dir(cache_path): return - # Staging dir lives under CACHE_DIR (same filesystem) so os.rename works. + # Staging dir lives under CACHE_DIR//.staging (same filesystem) so + # os.rename works without scanning the full cache root during stale cleanup. staging_path = os.path.join( - env.TILELANG_CACHE_DIR, - f".staging_{key}_{os.getpid()}_{uuid.uuid4().hex[:8]}", + self._get_staging_root(), + f"{key}_{os.getpid()}_{uuid.uuid4().hex[:8]}", ) os.makedirs(staging_path) @@ -489,14 +545,13 @@ def _clear_disk_cache(self): Removes all cached kernels from disk. Note: - This operation will delete the entire cache directory and recreate it empty. + This operation will delete the current kernel-cache namespace and recreate it empty. Use with caution as this operation cannot be undone. """ try: - # Delete the entire cache directory - shutil.rmtree(env.TILELANG_CACHE_DIR) + shutil.rmtree(self._get_cache_root(), ignore_errors=True) + shutil.rmtree(self._get_staging_root(), ignore_errors=True) - # Re-create the cache directory KernelCache._create_dirs() except Exception: self.logger.exception("Error clearing disk cache") From 0924dab6451b843d222c6ca7d41287d03ec40c42 Mon Sep 17 00:00:00 2001 From: Tao QU Date: Sat, 18 Apr 2026 22:38:11 +0800 Subject: [PATCH 093/156] =?UTF-8?q?[Bugfix]=20Use=20shared::cta=20instead?= =?UTF-8?q?=20of=20shared::cluster=20for=20non-cluster=20T=E2=80=A6=20(#20?= =?UTF-8?q?52)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [Bugfix] Use shared::cta instead of shared::cluster for non-cluster TMA loads On SM120 (RTX 5090), using `shared::cluster` in cp.async.bulk PTX instructions without an actual cluster launch causes the NVIDIA driver to allocate ~3.5 GB of extra device memory that persists for the CUDA context lifetime. Root cause: the driver provisions internal data structures for cluster coordination (DSMEM, cross-CTA barriers) when it encounters `shared::cluster` address space qualifiers, even when cluster_size=1. TileLang's tma_load always targets the local CTA's own shared memory (never cross-CTA DSMEM), so `shared::cta` is semantically correct on both SM90 (Hopper) and SM120 (Blackwell). The only function that genuinely needs `shared::cluster` is tma_load_multicast, which remains unchanged. Changes: - copy_sm90.h: replace `shared::cluster` with `shared::cta` in all 8 tma_load / tma_load_im2col overloads (raw bulk + 1D-5D tensor). tma_load_multicast retains `shared::cluster.multicast::cluster`. - ptx.cc: change PrintCpAsyncBulkAsm (used by tir.ptx_cp_async_bulk codegen path) from `shared::cluster` to `shared::cta`. Verified on RTX 5090 (SM120): - Memory overhead: 3506 MiB -> 0 MiB - Performance: unchanged (405 TFLOPS on FP8 blockwise GEMM 8192^3) - Correctness: unchanged (cosine diff ~3e-6) - SASS: eliminates __cuda_syscall_cp_async_bulk_tensor_2d_tile_unicast Minimal repro: https://gist.github.com/Harry-Chen/38c0f47ce3eff4469db4a310e763e949 Made-with: Cursor Co-authored-by: qutao --- src/target/ptx.cc | 2 +- src/tl_templates/cuda/copy_sm90.h | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/target/ptx.cc b/src/target/ptx.cc index 9bf7c5a2e7..8b806c783f 100644 --- a/src/target/ptx.cc +++ b/src/target/ptx.cc @@ -1429,7 +1429,7 @@ std::string PrintCpAsyncBulkAsm(const std::string &shared_ptr, unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr}); unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); __asm__ __volatile__( - "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" :: "r"(smem_addr_int), "l"({global_ptr}), "r"({bytes}), "r"(barrier_addr_int) : "memory" ); diff --git a/src/tl_templates/cuda/copy_sm90.h b/src/tl_templates/cuda/copy_sm90.h index 3d5b3f4145..86c845bbf7 100644 --- a/src/tl_templates/cuda/copy_sm90.h +++ b/src/tl_templates/cuda/copy_sm90.h @@ -20,7 +20,7 @@ TL_DEVICE void tma_load(void *smem_ptr, void const *gmem_ptr, uint32_t smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); - asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::" + asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::" "bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr), "l"((void const *)gmem_ptr), "r"(size), "r"(smem_int_mbar) :); @@ -50,7 +50,7 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); - asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::" + asm volatile("cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3}], [%2], %4;" : @@ -72,7 +72,7 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); - asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" + asm volatile("cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4}], [%2], %5;" : @@ -94,7 +94,7 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); - asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::" + asm volatile("cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5}], [%2], %6;" : @@ -116,7 +116,7 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); - asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::" + asm volatile("cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" : @@ -139,7 +139,7 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); - asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::" + asm volatile("cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" : @@ -161,7 +161,7 @@ tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar, uint32_t smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); - asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:" + asm volatile("cp.async.bulk.tensor.4d.shared::cta.global.im2col.mbarrier:" ":complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" : From b13cdf33b70bd01c4c819457cab4bfad62795866 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Mon, 20 Apr 2026 11:53:44 +0800 Subject: [PATCH 094/156] change the interface to support tasks with wg_id=-1 --- src/transform/auto_schedule/barrier.h | 52 ++++++------ src/transform/auto_schedule/ir_structure.cc | 77 +++++++++-------- src/transform/auto_schedule/ir_structure.h | 82 +++++++++++-------- .../auto_schedule/schedule_builder.h | 6 +- 4 files changed, 121 insertions(+), 96 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 99f0371c94..75bf140ea4 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -601,8 +601,9 @@ GetSyncInfos(const std::vector &units, int num_wgs, bool is_loop = false) { std::set buffers; for (auto *unit : units) { - for (const auto ®ion_access : unit->GetReadWriteRegions()) { - buffers.insert(region_access.region->buffer); + for (const auto &buffer_access : + unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { + buffers.insert(buffer_access.buffer); } } std::map, @@ -620,12 +621,11 @@ GetSyncInfos(const std::vector &units, int num_wgs, std::vector waited_write_wgs(num_wgs, false); for (int iter = 0; iter < (is_loop ? 2 : 1); ++iter) { for (ScheduleUnit *unit : units) { - for (const auto ®ion_access : unit->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (region_access.schedule_phase != SchedulePhase::kBody) - continue; + for (const auto &buffer_access : + unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { + int wg_id = buffer_access.warpgroup_id; ICHECK(0 <= wg_id && wg_id < num_wgs); - if (region_access.region->buffer != buffer) + if (buffer_access.buffer != buffer) continue; auto add_sync = [&](ScheduleUnit *wait_unit, int wait_wg_id) { int distance = iter ? num_versions : 0; @@ -639,7 +639,7 @@ GetSyncInfos(const std::vector &units, int num_wgs, it->second = std::min(it->second, distance); } }; - if (!region_access.is_write) { + if (!buffer_access.is_write) { if (last_write_unit == nullptr) continue; if (waited_write_wgs[wg_id]) @@ -654,13 +654,12 @@ GetSyncInfos(const std::vector &units, int num_wgs, } } if (iter == 0) { - for (const auto ®ion_access : unit->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (region_access.schedule_phase != SchedulePhase::kBody) - continue; - if (region_access.region->buffer != buffer) + for (const auto &buffer_access : + unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { + int wg_id = buffer_access.warpgroup_id; + if (buffer_access.buffer != buffer) continue; - if (!region_access.is_write) { + if (!buffer_access.is_write) { waited_write_wgs[wg_id] = true; } else { for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { @@ -668,13 +667,12 @@ GetSyncInfos(const std::vector &units, int num_wgs, } } } - for (const auto ®ion_access : unit->GetReadWriteRegions()) { - int wg_id = region_access.warpgroup_id; - if (region_access.schedule_phase != SchedulePhase::kBody) + for (const auto &buffer_access : + unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { + int wg_id = buffer_access.warpgroup_id; + if (buffer_access.buffer != buffer) continue; - if (region_access.region->buffer != buffer) - continue; - if (!region_access.is_write) { + if (!buffer_access.is_write) { last_read_unit[wg_id] = unit; } else { last_write_unit = unit; @@ -972,9 +970,11 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, multi_buffer; std::unordered_map buffer_num_versions; + int num_wgs = thread_count.size(); for (const auto &unit : ordered_units) { - for (const auto ®ion_access : unit->GetReadWriteRegions()) { - auto &buffer = region_access.region->buffer; + for (const auto &buffer_access : + unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { + auto &buffer = buffer_access.buffer; if (!ctrl->multi_buffering_buffers.count(buffer)) continue; for (const auto &other_unit : ordered_units) { @@ -985,12 +985,12 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, if (distance <= 0) continue; distance = (distance - 1) / ctrl->GetIIperIter() + 1; - for (const auto &other_region_access : - other_unit->GetReadWriteRegions()) { - auto &other_buffer = other_region_access.region->buffer; + for (const auto &other_buffer_access : + other_unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { + auto &other_buffer = other_buffer_access.buffer; if (!buffer.same_as(other_buffer)) continue; - if (region_access.is_write || other_region_access.is_write) { + if (buffer_access.is_write || other_buffer_access.is_write) { auto &num_versions = buffer_num_versions[buffer]; num_versions = std::max(num_versions, distance); } diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index 811bab9689..a3273b0e77 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -267,25 +267,31 @@ std::shared_ptr TaskNode::Clone() const { return new_task; } -void TaskNode::CollectRegions( - std::vector &result, - std::set>> &visited) const { +void TaskNode::CollectBufferAccessInfo( + int num_wgs, SchedulePhase phase, + std::set &result) const { int wg_id = GetWarpgroupId(); - SchedulePhase phase = GetSchedulePhase(); - // Collect write regions + if (GetSchedulePhase() != phase) { + return; + } + // Collect write buffers for (const auto ®ion : GetWriteRegions()) { - auto key = std::make_pair(region->buffer, std::make_pair(true, wg_id)); - if (visited.find(key) == visited.end()) { - visited.insert(key); - result.emplace_back(region, true, wg_id, phase); + if (wg_id != -1) { + result.emplace(region->buffer, true, wg_id, phase); + } else { + for (int i = 0; i < num_wgs; ++i) { + result.emplace(region->buffer, true, i, phase); + } } } - // Collect read regions + // Collect read buffers for (const auto ®ion : GetReadRegions()) { - auto key = std::make_pair(region->buffer, std::make_pair(false, wg_id)); - if (visited.find(key) == visited.end()) { - visited.insert(key); - result.emplace_back(region, false, wg_id, phase); + if (wg_id != -1) { + result.emplace(region->buffer, false, wg_id, phase); + } else { + for (int i = 0; i < num_wgs; ++i) { + result.emplace(region->buffer, false, i, phase); + } } } } @@ -385,51 +391,50 @@ std::shared_ptr IfNode::Clone() const { return new_if; } -void ControlNode::CollectRegions( - std::vector &result, - std::set>> &visited) const { +void ControlNode::CollectBufferAccessInfo( + int num_wgs, SchedulePhase phase, + std::set &result) const { if (child) { - child->CollectRegions(result, visited); + child->CollectBufferAccessInfo(num_wgs, phase, result); } } -void WrapperNode::CollectRegions( - std::vector &result, - std::set>> &visited) const { +void WrapperNode::CollectBufferAccessInfo( + int num_wgs, SchedulePhase phase, + std::set &result) const { if (child) { - child->CollectRegions(result, visited); + child->CollectBufferAccessInfo(num_wgs, phase, result); } } -void ScheduleUnit::CollectRegions( - std::vector &result, - std::set>> &visited) const { +void ScheduleUnit::CollectBufferAccessInfo( + int num_wgs, SchedulePhase phase, + std::set &result) const { if (child) { - child->CollectRegions(result, visited); + child->CollectBufferAccessInfo(num_wgs, phase, result); } } -void SequenceNode::CollectRegions( - std::vector &result, - std::set>> &visited) const { +void SequenceNode::CollectBufferAccessInfo( + int num_wgs, SchedulePhase phase, + std::set &result) const { for (const auto &child : children) { if (child) { - child->CollectRegions(result, visited); + child->CollectBufferAccessInfo(num_wgs, phase, result); } } } -void IfNode::CollectRegions( - std::vector &result, - std::set>> &visited) const { +void IfNode::CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, + std::set &result) const { if (task) { - task->CollectRegions(result, visited); + task->CollectBufferAccessInfo(num_wgs, phase, result); } if (then_child) { - then_child->CollectRegions(result, visited); + then_child->CollectBufferAccessInfo(num_wgs, phase, result); } if (else_child) { - else_child->CollectRegions(result, visited); + else_child->CollectBufferAccessInfo(num_wgs, phase, result); } } diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 317904906b..63b6b033db 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -38,17 +38,34 @@ enum class SchedulePhase : uint8_t { kEpilogue = 2, // Runs on ALL threads AFTER warpgroup-specific code }; -// Structure to store region access information with warpgroup id -struct RegionAccessInfo { - BufferRegion region; +// Structure to store buffer access information +struct BufferAccessInfo { + Buffer buffer; bool is_write; // true for write, false for read int warpgroup_id; // warpgroup id of the innermost TaskNode SchedulePhase schedule_phase{SchedulePhase::kBody}; // scheduling phase - RegionAccessInfo(BufferRegion region, bool is_write, int warpgroup_id, + BufferAccessInfo(Buffer buffer, bool is_write, int warpgroup_id, SchedulePhase phase = SchedulePhase::kBody) - : region(region), is_write(is_write), warpgroup_id(warpgroup_id), + : buffer(buffer), is_write(is_write), warpgroup_id(warpgroup_id), schedule_phase(phase) {} + + // Define operator< for set + bool operator<(const BufferAccessInfo &other) const { + if (buffer != other.buffer) { + return buffer.get() < other.buffer.get(); + } + if (is_write != other.is_write) { + return is_write < other.is_write; + } + if (warpgroup_id != other.warpgroup_id) { + return warpgroup_id < other.warpgroup_id; + } + if (schedule_phase != other.schedule_phase) { + return schedule_phase < other.schedule_phase; + } + return false; + } }; // Helper function to compare if two regions are equal @@ -116,16 +133,17 @@ class IRStructure { virtual void SetLatency(int64_t latency) = 0; virtual void SetII(int64_t ii) = 0; - // Recursive region collection method - virtual void CollectRegions( - std::vector &result, - std::set>> &visited) const = 0; + // Recursive buffer collection method + virtual void + CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, + std::set &result) const = 0; - std::vector GetReadWriteRegions() const { - std::vector result; - std::set>> visited; - CollectRegions(result, visited); - return result; + std::vector + GetBufferAccessInfo(int num_wgs = 1, + SchedulePhase phase = SchedulePhase::kBody) const { + std::set result; + CollectBufferAccessInfo(num_wgs, phase, result); + return std::vector(result.begin(), result.end()); } // Substitute a variable throughout this IR node @@ -333,9 +351,9 @@ class TaskNode : public IRStructure { void AddReadVar(const Var &var) { read_vars_.push_back(var); } void AddWriteVar(const Var &var) { write_vars_.push_back(var); } - void CollectRegions( - std::vector &result, - std::set>> &visited) const override; + void + CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, + std::set &result) const override; bool containWarpgroupId(int id) const override { return ContainsLoopBreak() || warpgroup_id_ == id; @@ -489,9 +507,9 @@ class ControlNode : public IRStructure { void SetLatency(int64_t latency) override { latency_ = latency; } void SetII(int64_t ii) override { ii_ = ii; } - void CollectRegions( - std::vector &result, - std::set>> &visited) const override; + void + CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, + std::set &result) const override; bool hasPromote() const { return has_promote_; } @@ -605,9 +623,9 @@ class WrapperNode : public IRStructure { void SetLatency(int64_t latency) override { latency_ = latency; } void SetII(int64_t ii) override { ii_ = ii; } - void CollectRegions( - std::vector &result, - std::set>> &visited) const override; + void + CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, + std::set &result) const override; // Clone method std::shared_ptr Clone() const override; @@ -779,9 +797,9 @@ class IfNode : public IRStructure { void SetLatency(int64_t latency) override { latency_ = latency; } void SetII(int64_t ii) override { ii_ = ii; } - void CollectRegions( - std::vector &result, - std::set>> &visited) const override; + void + CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, + std::set &result) const override; // Clone method std::shared_ptr Clone() const override; @@ -887,9 +905,9 @@ class ScheduleUnit : public IRStructure { void SetLatency(int64_t latency) override { latency_ = latency; } void SetII(int64_t ii) override { ii_ = ii; } - void CollectRegions( - std::vector &result, - std::set>> &visited) const override; + void + CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, + std::set &result) const override; int GetStage() const { return stage; } bool isInnerTask() const { return child->IsTask(); } @@ -965,9 +983,9 @@ class SequenceNode : public IRStructure { void SetLatency(int64_t latency) override; void SetII(int64_t ii) override; - void CollectRegions( - std::vector &result, - std::set>> &visited) const override; + void + CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, + std::set &result) const override; // Clone method std::shared_ptr Clone() const override; diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index cf7f43d702..975dddd696 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -367,14 +367,16 @@ class ScheduleUnitBuilder { std::map buffer_to_num_versions; std::set multi_buffering_buffers; int64_t memory_limit = shared_memory_limit_; - for (const auto ®ion_access : ctrl->GetReadWriteRegions()) { - const auto &buffer = region_access.region->buffer; + for (const auto &buffer_access : ctrl->GetBufferAccessInfo()) { + const auto &buffer = buffer_access.buffer; if (!IsSharedBuffer(buffer)) { continue; // Only consider shared buffers for multi-buffer } if (buffer_to_num_versions.count(buffer)) { continue; } + // If the buffer is used outside the loop or is read before being written, + // we cannot multi-buffer it if (used_buffers.count(buffer) || !check_buffer_write_first(buffer)) { buffer_to_num_versions[buffer] = 1; memory_limit -= GetBufferSize(buffer); From acadde02e04559239d70591ee041069ee695895b Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Mon, 20 Apr 2026 12:58:25 +0800 Subject: [PATCH 095/156] remove unused declarations --- src/transform/auto_schedule.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/transform/auto_schedule.h b/src/transform/auto_schedule.h index e3beab40c6..4ec2e70b4a 100644 --- a/src/transform/auto_schedule.h +++ b/src/transform/auto_schedule.h @@ -107,14 +107,6 @@ inline int64_t GetSharedMemoryLimit(Target target) { } } -// Global warpgroup id assignment - should be called from the top level -// Tasks that use the same register region must have the same warpgroup id -// Goal: balance weighted latency between two warpgroups (0 and 1) -// Weighted latency = latency * tripcount (tripcount = 100 for non-constant loop -// extent) -bool AssignWarpgroupIdsGlobal(IRStructure *root, - bool enable_warp_partition = false); - // Function to rewrite alloc_buffers for multi-version support Stmt RewriteAllocBuffers( const Stmt &stmt, const std::vector &buffer_infos); From f8e70593dde896c26eb690b4fcea33ac3ccf01d6 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Mon, 20 Apr 2026 13:04:31 +0800 Subject: [PATCH 096/156] fix: improve warning output in eager frontend (#2064) * fix: remove stack_info from logger.warning in eager builder * fix: add filename and lineno to tilelang logger formatter * fix: improve warning messages in eager builder for clarity * fix: address review comments - stacklevel and warning text alignment --- tilelang/__init__.py | 2 +- tilelang/language/eager/builder.py | 30 ++++++++++++++---------------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 15d9ed1093..14dbe6b1bf 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -88,7 +88,7 @@ def emit(self, record): handler = TqdmLoggingHandler() formatter = logging.Formatter( - fmt="%(asctime)s [TileLang:%(name)s:%(levelname)s]: %(message)s", + fmt="%(asctime)s [TileLang:%(name)s:%(levelname)s] (%(filename)s:%(lineno)d): %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) handler.setFormatter(formatter) diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index c579fd9511..8a9ad465e9 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -66,8 +66,7 @@ def unwrap_cond(expr): return bool(expr) else: logger.warning( - f"Python expression `{expr}` is used as condition in TileLang, \nthis is treated as a constant expression. ", - stack_info=True, + f"Python expression `{expr}` is used as condition in TileLang, this is treated as a constant expression.", stacklevel=3, ) return bool(expr) @@ -252,7 +251,7 @@ def enter_frame(self, frame: AbstractContextManager[Any]): def check_continue_break(self): idx = self.find_frame_idx(ContinueOrBreak) if idx is not None: - logger.warning("Writing code after continue/break may cause undefined behavior in tilelang.", stack_info=True, stacklevel=3) + logger.warning("Statements after continue/break have no effect and will be ignored.", stacklevel=3) @contextmanager def with_frame(self, frame: AbstractContextManager[Any] | None): @@ -295,9 +294,8 @@ def eval(self, val: Any): elif isinstance(val, tir.frame.IRBuilderFrame): if isinstance(val, tir.frame.ForFrame): logger.warning( - "Evaluating a for frame may cause undefined behavior in tilelang.", - stack_info=True, - stacklevel=1, + "A for-loop frame is being evaluated as a standalone expression. Did you mean to use it in a `for` statement?", + stacklevel=2, ) self.enter_frame(val) elif isinstance(val, PrimExpr): @@ -311,7 +309,7 @@ def eval(self, val: Any): elif isinstance(val, (Buffer, Var)): pass else: - logger.warning(f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2) + logger.warning(f"Return value `{val}` ({type(val)}) is unused and will be discarded.", stacklevel=2) def ctx_for(self, it): self.check_continue_break() @@ -327,7 +325,10 @@ def ctx_for(self, it): else: real_stop = tir.ceildiv(it.start - it.stop, -step_value) else: - logger.warning(f"Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang") + logger.warning( + f"Non-constant step `{it.step}` in serial range may produce unexpected results. Consider using a constant step if possible.", + stacklevel=2, + ) real_stop = tir.ceildiv(it.stop - it.start, it.step) if isinstance(it, UnrollForWithStep): real_frame = tir.unroll(real_stop, annotations=it.annotations) @@ -374,9 +375,8 @@ def ctx_while(self, cond): ) else: logger.warning( - "While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n", - f"Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n", - stack_info=True, + "While loop condition is always false; the loop body will be skipped.\n" + f"Condition: {cond_v} ({type(cond_v)}) => {cond_v_unwrap} ({type(cond_v_unwrap)})\n", stacklevel=2, ) with self.with_frame(tir.While(cond_v_unwrap)): @@ -471,8 +471,7 @@ def bind(self, name, value, annot=BaseBuilder.empty): assert frame is not None, f"Variable `{name}` is not defined inside any control flow." if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames: logger.warning( - f"Immutable value `{name}` is re-bound. If you want to modify its value, please use T.alloc_var to make it a variable!", - stack_info=True, + f"Immutable value `{name}` is re-bound; use T.alloc_var to create a mutable variable.", stacklevel=2, ) self.name_inside_frame[name] = self.frames[frame] @@ -527,7 +526,7 @@ def bind_immutable(self, name, value): def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty): self.check_continue_break() if annot is not self.empty: - logger.warning("Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2) + logger.warning("Type annotation on slice assignment is not supported and will be ignored.", stacklevel=2) if isinstance(lval, Buffer): tir.buffer_store(lval, value, sl) else: @@ -571,8 +570,7 @@ def aug_assign(self, op, target, aug_value, name: str | None = None): assert frame is not None, f"Variable `{name}` is not defined inside any control flow." if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames: logger.warning( - f"Immutable value `{name}` is re-bound. If you want to modify its value, please use T.alloc_var to make it a variable!", - stack_info=True, + f"Immutable value `{name}` is re-bound; use T.alloc_var to create a mutable variable.", stacklevel=2, ) self.name_inside_frame[name] = self.frames[frame] From f7c6f4353337ec35a069ea5a65e0f690ac42af0b Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Mon, 20 Apr 2026 15:06:42 +0800 Subject: [PATCH 097/156] fix read/write regions --- src/transform/auto_schedule/ir_structure.cc | 6 ++++ src/transform/auto_schedule/ir_structure.h | 34 +++++++++++++++++---- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index a3273b0e77..b4cd281d38 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -394,6 +394,9 @@ std::shared_ptr IfNode::Clone() const { void ControlNode::CollectBufferAccessInfo( int num_wgs, SchedulePhase phase, std::set &result) const { + if (task) { + task->CollectBufferAccessInfo(num_wgs, phase, result); + } if (child) { child->CollectBufferAccessInfo(num_wgs, phase, result); } @@ -402,6 +405,9 @@ void ControlNode::CollectBufferAccessInfo( void WrapperNode::CollectBufferAccessInfo( int num_wgs, SchedulePhase phase, std::set &result) const { + if (task) { + task->CollectBufferAccessInfo(num_wgs, phase, result); + } if (child) { child->CollectBufferAccessInfo(num_wgs, phase, result); } diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 63b6b033db..e9098fb9c9 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -567,20 +567,42 @@ class WrapperNode : public IRStructure { return child ? child->HasTCGEN05() : false; } - // Memory access regions (aggregate from child) + // Memory access regions (aggregate from child & task) std::vector GetReadRegions() const override { - return child ? child->GetReadRegions() : std::vector{}; + std::vector regions = + child ? child->GetReadRegions() : std::vector{}; + if (task) { + auto task_regions = task->GetReadRegions(); + regions.insert(regions.end(), task_regions.begin(), task_regions.end()); + } + return regions; } std::vector GetWriteRegions() const override { - return child ? child->GetWriteRegions() : std::vector{}; + std::vector regions = + child ? child->GetWriteRegions() : std::vector{}; + if (task) { + auto task_regions = task->GetWriteRegions(); + regions.insert(regions.end(), task_regions.begin(), task_regions.end()); + } + return regions; } - // Variable access (aggregate from child) + // Variable access (aggregate from child & task) std::vector GetReadVars() const override { - return child ? child->GetReadVars() : std::vector{}; + std::vector vars = child ? child->GetReadVars() : std::vector{}; + if (task) { + auto task_vars = task->GetReadVars(); + vars.insert(vars.end(), task_vars.begin(), task_vars.end()); + } + return vars; } std::vector GetWriteVars() const override { - return child ? child->GetWriteVars() : std::vector{}; + std::vector vars = child ? child->GetWriteVars() : std::vector{}; + if (task) { + auto task_vars = task->GetWriteVars(); + vars.insert(vars.end(), task_vars.begin(), task_vars.end()); + } + return vars; } void SubstituteVar(const Var &old_var, const Var &new_var) override { From 99ee74b3437611ce610f62804a71e0174fed0a44 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:35:43 +0800 Subject: [PATCH 098/156] [CUDA] Support int4 `T.gemm` (#2063) * Support int4 T.gemm and direct packed cp.async lowering * Keep cp.async vectorization as a late fallback * Change tl.ptx_cp_async to use num_elems semantics * fix test --- .../deepseek_mla/example_mla_decode_ws.py | 24 +- .../deepseek_v32/sparse_mla_fwd_pipelined.py | 12 +- .../deepseek_v32/sparse_mla_fwd_seesaw.py | 12 +- .../gemm_int4/example_tilelang_gemm_int4.py | 60 +++ src/op/builtin.h | 4 +- src/op/copy.cc | 133 ----- src/target/codegen_cuda.cc | 73 ++- src/target/codegen_cutedsl.cc | 73 ++- src/target/codegen_hip.cc | 86 +++- src/transform/loop_vectorize.cc | 85 +++- src/transform/lower_ptx_async_copy.cc | 97 ++-- .../merge_shared_memory_allocations.cc | 5 +- src/transform/vectorize_loop.cc | 79 ++- .../kernel/test_tilelang_kernel_int4_gemm.py | 40 ++ .../test_tilelang_kernel_int4_gemm_mma.py | 16 +- ...st_tilelang_language_access_ptr_codegen.py | 38 +- ...tilelang_transform_lower_ptx_async_copy.py | 24 +- .../test_tilelang_transform_lower_tile_op.py | 11 +- ...st_tilelang_transform_pipeline_planning.py | 7 +- tilelang/engine/phase.py | 4 - tilelang/intrinsics/mma_macro_generator.py | 473 ++++++++++-------- tilelang/language/dtypes.py | 3 + tilelang/language/tir/op.py | 18 +- tilelang/tileop/gemm/gemm_mma.py | 134 ++++- tilelang/utils/tensor.py | 2 + 25 files changed, 980 insertions(+), 533 deletions(-) create mode 100644 examples/gemm_int4/example_tilelang_gemm_int4.py create mode 100644 testing/python/kernel/test_tilelang_kernel_int4_gemm.py diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py index d77887c7a4..98657e381a 100644 --- a/examples/deepseek_mla/example_mla_decode_ws.py +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -219,17 +219,17 @@ def main_split( T.ptx_cp_async( T.access_ptr(KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), T.access_ptr(K_pe[bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.cp_async_barrier_noinc(bar_k_0_ready[0]) @@ -241,17 +241,17 @@ def main_split( T.ptx_cp_async( T.access_ptr(KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), T.access_ptr(K_pe[bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.cp_async_barrier_noinc(bar_k_1_ready[0]) @@ -467,17 +467,17 @@ def main_no_split( T.ptx_cp_async( T.access_ptr(KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), T.access_ptr(K_pe[bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.cp_async_barrier_noinc(bar_k_0_ready[0]) @@ -489,17 +489,17 @@ def main_no_split( T.ptx_cp_async( T.access_ptr(KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), T.access_ptr(K_pe[bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.cp_async_barrier_noinc(bar_k_1_ready[0]) diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 8f31d00b76..bff9f19b98 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -271,17 +271,17 @@ def main( T.ptx_cp_async( T.access_ptr(KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, indices_local, g_i, D + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.cp_async_barrier_noinc(bar_k_0_ready[0]) @@ -296,17 +296,17 @@ def main( T.ptx_cp_async( T.access_ptr(KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, indices_local, g_i, D + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.cp_async_barrier_noinc(bar_k_1_ready[0]) diff --git a/examples/deepseek_v32/sparse_mla_fwd_seesaw.py b/examples/deepseek_v32/sparse_mla_fwd_seesaw.py index 08999d5532..cdffac281a 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_seesaw.py +++ b/examples/deepseek_v32/sparse_mla_fwd_seesaw.py @@ -222,18 +222,18 @@ def main( T.ptx_cp_async( T.access_ptr(KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, index, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, index, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) # tail_dim (64) needs only one iter of 8 elems per 8 collaborating threads T.ptx_cp_async( T.access_ptr(K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, index, g_i, D + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.cp_async_barrier_noinc(bar_k_0_ready[0]) @@ -253,17 +253,17 @@ def main( T.ptx_cp_async( T.access_ptr(KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, index, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, index, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.ptx_cp_async( T.access_ptr(K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), T.access_ptr(KV[b_i, index, g_i, D + (tx - 256) % 8 * 8], "r", 8), - 16, + 8, ) T.cp_async_barrier_noinc(bar_k_1_ready[0]) diff --git a/examples/gemm_int4/example_tilelang_gemm_int4.py b/examples/gemm_int4/example_tilelang_gemm_int4.py new file mode 100644 index 0000000000..4ad0fca710 --- /dev/null +++ b/examples/gemm_int4/example_tilelang_gemm_int4.py @@ -0,0 +1,60 @@ +"""Frontend int4 GEMM example for the T.gemm int4 path. + +This file intentionally models the desired TileLang frontend API: +- A/B are declared as T.int4 tensors +- the matmul is expressed with T.gemm(...) + +The example compiles the kernel and prints the generated CUDA source. +""" + +import tilelang +import tilelang.language as T + +tilelang.disable_cache() + + +def matmul_nt_int4(M, N, K, block_M, block_N, block_K): + @T.prim_func + def main( + A: T.Tensor((M, K), T.int4), + B: T.Tensor((N, K), T.int4), + C: T.Tensor((M, N), T.int32), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), T.int4) + B_shared = T.alloc_shared((block_N, block_K), T.int4) + C_local = T.alloc_fragment((block_M, block_N), T.int32) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + # Frontend expectation: T.gemm should accept int4 operands directly. + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def compile_int4_gemm( + M=1024, + N=1024, + K=1024, + block_M=128, + block_N=128, + block_K=64, +): + func = matmul_nt_int4(M, N, K, block_M, block_N, block_K) + kernel = tilelang.compile(func, out_idx=-1) + print("Compilation succeeded.") + print(kernel.get_kernel_source()) + return func, kernel + + +def main(): + compile_int4_gemm() + + +if __name__ == "__main__": + main() diff --git a/src/op/builtin.h b/src/op/builtin.h index dae5503a68..43cf562963 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -443,8 +443,8 @@ TVM_DLL const Op &ptx_cp_async_barrier_noinc(); /*! * \brief TileLang intrinsic for PTX async copy from global to shared memory * - * ptx_cp_async(dst_access_ptr, src_access_ptr, bytes) - * ptx_cp_async(dst_access_ptr, src_access_ptr, bytes, predicate) + * ptx_cp_async(dst_access_ptr, src_access_ptr, num_elems) + * ptx_cp_async(dst_access_ptr, src_access_ptr, num_elems, predicate) * */ TVM_DLL const Op &ptx_cp_async(); diff --git a/src/op/copy.cc b/src/op/copy.cc index 2c2eb24f53..aeb057eb44 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -49,139 +49,6 @@ PrimExpr GetCopyMbarPhaseExpr(const Map &annotations, return phase; } -// Rewrite scalar global->shared stores into ptx_cp_async calls. -// This rewriter is applied before the global vectorize pass, so each generated -// cp.async call starts with element-wise bytes and can be widened later. -class CPAsyncStoreRewriter : public StmtMutator { -public: - Stmt Rewrite(const Stmt &stmt) { return VisitStmt(stmt); } - - bool RewriteSuccess() const { - return rewritten_any_store_ && !failed_on_shared_store_; - } - -private: - static bool IsZeroValue(const PrimExpr &e) { - if (auto *b = e.as()) { - return IsZeroValue(b->value); - } - if (auto *f = e.as()) { - return f->value == 0.0f; - } - if (auto *i = e.as()) { - return i->value == 0; - } - return false; - } - - static const BufferLoadNode * - MatchZeroFillBufferLoad(const PrimExpr &value, - Optional *predicate) { - if (const auto *load = value.as()) { - return load; - } - - const auto *call = value.as(); - if (!call || !call->op.same_as(builtin::if_then_else()) || - !IsZeroValue(call->args[2])) { - return nullptr; - } - - const BufferLoadNode *load = - MatchZeroFillBufferLoad(call->args[1], predicate); - if (load == nullptr) { - return nullptr; - } - - // Nested zero-fill guards only permit issuing cp.async when every guard - // on the path to the load is true. - *predicate = - predicate->defined() - ? Optional(And(call->args[0], predicate->value())) - : Optional(call->args[0]); - return load; - } - - Stmt VisitStmt_(const BufferStoreNode *op) final { - if (!IsSharedBuffer(op->buffer)) { - return StmtMutator::VisitStmt_(op); - } - - Optional predicate = std::nullopt; - // Accept either a direct load or a nested zero-fill guard chain: - // if_then_else(p1, if_then_else(p2, load, 0), 0). Nested predicates are - // combined so the generated cp.async is only issued when all guards hold. - const BufferLoadNode *load = MatchZeroFillBufferLoad(op->value, &predicate); - if (load == nullptr) { - failed_on_shared_store_ = true; - return StmtMutator::VisitStmt_(op); - } - - if (!IsGlobalBuffer(load->buffer)) { - failed_on_shared_store_ = true; - return StmtMutator::VisitStmt_(op); - } - int bytes = op->value.dtype().bytes(); - int vectorized_lanes = current_vectorized_lanes_; - - if (!IsValidCPAsyncTransferBytes(bytes * vectorized_lanes)) { - failed_on_shared_store_ = true; - return StmtMutator::VisitStmt_(op); - } - - // Keep pointer metadata in tl.access_ptr form for downstream analysis; - // LowerAccessPtr will translate it to tvm_access_ptr later. - PrimExpr dst_access_ptr = - Call(DataType::Handle(), tvm::tl::access_ptr(), - { - BufferLoad(op->buffer, op->indices), - IntImm(DataType::Int(32), 1), // extent - IntImm(DataType::Int(32), 2) // rw_mask: write - }); - PrimExpr src_access_ptr = - Call(DataType::Handle(), tvm::tl::access_ptr(), - { - BufferLoad(load->buffer, load->indices), - IntImm(DataType::Int(32), 1), // extent - IntImm(DataType::Int(32), 1) // rw_mask: read - }); - - Array args{dst_access_ptr, src_access_ptr, PrimExpr(bytes)}; - if (predicate.defined()) { - args.push_back(predicate.value()); - } - rewritten_any_store_ = true; - return Evaluate(Call(DataType::Handle(), builtin::ptx_cp_async(), args)); - } - - Stmt VisitStmt_(const ForNode *op) final { - int previous_vectorized_lanes = current_vectorized_lanes_; - if (op->kind == ForKind::kVectorized) { - // Assume vectorized access pattern is contiguous on the vectorized iter. - // This is guaranteed by tl.VectorizeLoop: if an access pattern is not - // vectorizable/contiguous for the chosen iter, it is scalarized instead - // of staying as ForKind::kVectorized. - const auto *extent_imm = op->extent.as(); - ICHECK(extent_imm) - << "Vectorized loops must have constant extent, but got " - << op->extent; - int lanes = static_cast(extent_imm->value); - if (lanes > 1 && current_vectorized_lanes_ <= - std::numeric_limits::max() / lanes) { - current_vectorized_lanes_ *= lanes; - } - } - - Stmt stmt = StmtMutator::VisitStmt_(op); - current_vectorized_lanes_ = previous_vectorized_lanes; - return stmt; - } - - bool rewritten_any_store_ = false; - bool failed_on_shared_store_ = false; - int current_vectorized_lanes_ = 1; -}; - } // namespace // Constructs a Copy operator node from call arguments and annotations. diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 28a74a2632..59f31e7297 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -9,6 +9,8 @@ #include #include +#include +#include #include #include #include @@ -26,6 +28,69 @@ using namespace ffi; namespace { +bool IsValidCPAsyncTransferBytes(int64_t bytes) { + return bytes == 4 || bytes == 8 || bytes == 16; +} + +std::optional GetAccessPtrElementType(const PrimExpr &expr) { + const auto *ptr_call = expr.as(); + if (ptr_call == nullptr) { + return std::nullopt; + } + if (ptr_call->op.same_as(builtin::address_of())) { + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "address_of arg must be BufferLoad"; + return buffer_load->buffer->dtype; + } + if (ptr_call->op.same_as(builtin::tvm_access_ptr())) { + ICHECK(!ptr_call->args.empty()); + return ptr_call->args[0].dtype(); + } + if (ptr_call->op.same_as(tl::access_ptr())) { + ICHECK_EQ(ptr_call->args.size(), 3U) + << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "tl.access_ptr arg0 must be BufferLoad"; + return buffer_load->buffer->dtype; + } + return std::nullopt; +} + +int GetTileLangCPAsyncTransferBytes(const CallNode *op) { + ICHECK(op->args.size() == 3 || op->args.size() == 4) + << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " + "src_access_ptr, num_elems, [predicate])"; + const auto *num_elems_imm = op->args[2].as(); + ICHECK(num_elems_imm) << "tl::ptx_cp_async num_elems must be IntImm, but got " + << op->args[2]; + int64_t num_elems = num_elems_imm->value; + ICHECK_GT(num_elems, 0); + + auto dst_elem_type = GetAccessPtrElementType(op->args[0]); + auto src_elem_type = GetAccessPtrElementType(op->args[1]); + ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + << "tl::ptx_cp_async expects address_of, tl.access_ptr, or " + "tvm_access_ptr operands"; + + int64_t dst_total_bits = + num_elems * dst_elem_type.value().bits() * dst_elem_type.value().lanes(); + int64_t src_total_bits = + num_elems * src_elem_type.value().bits() * src_elem_type.value().lanes(); + ICHECK_EQ(dst_total_bits, src_total_bits) + << "tl::ptx_cp_async requires src/dst transfer widths to match, but got " + << dst_total_bits << " vs " << src_total_bits << " bits"; + ICHECK_EQ(dst_total_bits % 8, 0) + << "tl::ptx_cp_async requires byte-aligned transfers, but got " + << dst_total_bits << " bits"; + + int64_t total_bytes = dst_total_bits / 8; + ICHECK(IsValidCPAsyncTransferBytes(total_bytes)) + << "tl::ptx_cp_async requires a final PTX byte width in {4, 8, 16}, but " + "got " + << total_bytes; + return static_cast(total_bytes); +} + bool CanEmitPackedX2Math(DataType t) { int lanes = t.lanes(); if (lanes < 2 || lanes % 2 != 0) { @@ -1999,14 +2064,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } } else if (op->op.same_as(tl::ptx_cp_async())) { // TileLang version: args[0] = dst_access_ptr, args[1] = src_access_ptr, - // args[2] = bytes, args[3] = predicate (optional) - ICHECK(op->args.size() == 3 || op->args.size() == 4) - << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " - "src_access_ptr, bytes, [predicate])"; + // args[2] = num_elems, args[3] = predicate (optional) + int total_bytes = GetTileLangCPAsyncTransferBytes(op); std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); - std::string size = this->PrintExpr(op->args[2]); + std::string size = std::to_string(total_bytes); this->PrintIndent(); if (op->args.size() == 3) { diff --git a/src/target/codegen_cutedsl.cc b/src/target/codegen_cutedsl.cc index f495489bc4..5aa72fadbd 100644 --- a/src/target/codegen_cutedsl.cc +++ b/src/target/codegen_cutedsl.cc @@ -13,6 +13,8 @@ #include #include +#include +#include #include #include #include @@ -64,6 +66,69 @@ void ReplaceAll(std::string &str, const std::string &from, } } +bool IsValidCPAsyncTransferBytes(int64_t bytes) { + return bytes == 4 || bytes == 8 || bytes == 16; +} + +std::optional GetAccessPtrElementType(const PrimExpr &expr) { + const auto *ptr_call = expr.as(); + if (ptr_call == nullptr) { + return std::nullopt; + } + if (ptr_call->op.same_as(builtin::address_of())) { + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "address_of arg must be BufferLoad"; + return buffer_load->buffer->dtype; + } + if (ptr_call->op.same_as(builtin::tvm_access_ptr())) { + ICHECK(!ptr_call->args.empty()); + return ptr_call->args[0].dtype(); + } + if (ptr_call->op.same_as(tl::access_ptr())) { + ICHECK_EQ(ptr_call->args.size(), 3U) + << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "tl.access_ptr arg0 must be BufferLoad"; + return buffer_load->buffer->dtype; + } + return std::nullopt; +} + +int GetTileLangCPAsyncTransferBytes(const CallNode *op) { + ICHECK(op->args.size() == 3 || op->args.size() == 4) + << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " + "src_access_ptr, num_elems, [predicate])"; + const auto *num_elems_imm = op->args[2].as(); + ICHECK(num_elems_imm) << "tl::ptx_cp_async num_elems must be IntImm, but got " + << op->args[2]; + int64_t num_elems = num_elems_imm->value; + ICHECK_GT(num_elems, 0); + + auto dst_elem_type = GetAccessPtrElementType(op->args[0]); + auto src_elem_type = GetAccessPtrElementType(op->args[1]); + ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + << "tl::ptx_cp_async expects address_of, tl.access_ptr, or " + "tvm_access_ptr operands"; + + int64_t dst_total_bits = + num_elems * dst_elem_type.value().bits() * dst_elem_type.value().lanes(); + int64_t src_total_bits = + num_elems * src_elem_type.value().bits() * src_elem_type.value().lanes(); + ICHECK_EQ(dst_total_bits, src_total_bits) + << "tl::ptx_cp_async requires src/dst transfer widths to match, but got " + << dst_total_bits << " vs " << src_total_bits << " bits"; + ICHECK_EQ(dst_total_bits % 8, 0) + << "tl::ptx_cp_async requires byte-aligned transfers, but got " + << dst_total_bits << " bits"; + + int64_t total_bytes = dst_total_bits / 8; + ICHECK(IsValidCPAsyncTransferBytes(total_bytes)) + << "tl::ptx_cp_async requires a final PTX byte width in {4, 8, 16}, but " + "got " + << total_bytes; + return static_cast(total_bytes); +} + } // namespace CodeGenTileLangCuTeDSL::CodeGenTileLangCuTeDSL() { @@ -429,14 +494,12 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op, } } else if (op->op.same_as(tl::ptx_cp_async())) { // TileLang version: args[0] = dst_access_ptr, args[1] = src_access_ptr, - // args[2] = bytes, args[3] = predicate (optional) - ICHECK(op->args.size() == 3 || op->args.size() == 4) - << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " - "src_access_ptr, bytes, [predicate])"; + // args[2] = num_elems, args[3] = predicate (optional) + int total_bytes = GetTileLangCPAsyncTransferBytes(op); std::string dst = PrintExpr_(op->args[0]); std::string src = PrintExpr_(op->args[1]); - std::string size = PrintExpr_(op->args[2]); + std::string size = std::to_string(total_bytes); if (op->args.size() == 3) { this->PrintIndent(); diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index 6cc566b9a9..f6503a8348 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -9,6 +9,8 @@ #include #include +#include +#include #include #include #include @@ -19,6 +21,73 @@ namespace tvm { namespace codegen { +namespace { + +bool IsValidCPAsyncTransferBytes(int64_t bytes) { + return bytes == 4 || bytes == 8 || bytes == 16; +} + +std::optional GetAccessPtrElementType(const PrimExpr &expr) { + const auto *ptr_call = expr.as(); + if (ptr_call == nullptr) { + return std::nullopt; + } + if (ptr_call->op.same_as(builtin::address_of())) { + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "address_of arg must be BufferLoad"; + return buffer_load->buffer->dtype; + } + if (ptr_call->op.same_as(builtin::tvm_access_ptr())) { + ICHECK(!ptr_call->args.empty()); + return ptr_call->args[0].dtype(); + } + if (ptr_call->op.same_as(tl::access_ptr())) { + ICHECK_EQ(ptr_call->args.size(), 3U) + << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "tl.access_ptr arg0 must be BufferLoad"; + return buffer_load->buffer->dtype; + } + return std::nullopt; +} + +int GetTileLangCPAsyncTransferBytes(const CallNode *op) { + ICHECK(op->args.size() == 3 || op->args.size() == 4) + << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " + "src_access_ptr, num_elems, [predicate])"; + const auto *num_elems_imm = op->args[2].as(); + ICHECK(num_elems_imm) << "tl::ptx_cp_async num_elems must be IntImm, but got " + << op->args[2]; + int64_t num_elems = num_elems_imm->value; + ICHECK_GT(num_elems, 0); + + auto dst_elem_type = GetAccessPtrElementType(op->args[0]); + auto src_elem_type = GetAccessPtrElementType(op->args[1]); + ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + << "tl::ptx_cp_async expects address_of, tl.access_ptr, or " + "tvm_access_ptr operands"; + + int64_t dst_total_bits = + num_elems * dst_elem_type.value().bits() * dst_elem_type.value().lanes(); + int64_t src_total_bits = + num_elems * src_elem_type.value().bits() * src_elem_type.value().lanes(); + ICHECK_EQ(dst_total_bits, src_total_bits) + << "tl::ptx_cp_async requires src/dst transfer widths to match, but got " + << dst_total_bits << " vs " << src_total_bits << " bits"; + ICHECK_EQ(dst_total_bits % 8, 0) + << "tl::ptx_cp_async requires byte-aligned transfers, but got " + << dst_total_bits << " bits"; + + int64_t total_bytes = dst_total_bits / 8; + ICHECK(IsValidCPAsyncTransferBytes(total_bytes)) + << "tl::ptx_cp_async requires a final PTX byte width in {4, 8, 16}, but " + "got " + << total_bytes; + return static_cast(total_bytes); +} + +} // namespace + static std::string GetFP8Type(DataType type) { std::stringstream stream; int32_t lanes = type.lanes(); @@ -798,8 +867,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } this->stream << ");\n"; }; - if (op->op.same_as(builtin::ptx_cp_async()) || - op->op.same_as(tl::ptx_cp_async())) { + if (op->op.same_as(builtin::ptx_cp_async())) { // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, // args[3] = predicate (optional) ICHECK(op->args.size() == 3 || op->args.size() == 4) @@ -819,6 +887,20 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << ", " << src << ", " << condition << ");\n"; } + } else if (op->op.same_as(tl::ptx_cp_async())) { + int total_bytes = GetTileLangCPAsyncTransferBytes(op); + std::string dst = this->PrintExpr(op->args[0]); + std::string src = this->PrintExpr(op->args[1]); + std::string size = std::to_string(total_bytes); + this->PrintIndent(); + if (op->args.size() == 3) { + this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src + << ");\n"; + } else { + std::string condition = this->PrintExpr(op->args[3]); + this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst + << ", " << src << ", " << condition << ");\n"; + } } else if (op->op.same_as(builtin::ptx_commit_group())) { print_extern_call_stmt("tl::cp_async_commit"); } else if (op->op.same_as(builtin::ptx_wait_group())) { diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 6f6dd239e1..d1c8e2feaf 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -33,6 +33,7 @@ #include "tvm/tir/analysis.h" #include "tvm/tir/var.h" #include +#include #include #include #include @@ -410,6 +411,69 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } + static std::optional GetAccessPtrElementBits(const PrimExpr &expr) { + const auto *ptr_call = expr.as(); + if (ptr_call == nullptr) { + return std::nullopt; + } + if (ptr_call->op.same_as(builtin::tvm_access_ptr())) { + ICHECK(!ptr_call->args.empty()); + DataType dtype = ptr_call->args[0].dtype(); + return dtype.bits() * dtype.lanes(); + } + if (ptr_call->op.same_as(tl::access_ptr())) { + ICHECK_EQ(ptr_call->args.size(), 3U) + << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "tl.access_ptr arg0 must be BufferLoad"; + DataType dtype = buffer_load->buffer->dtype; + return dtype.bits() * dtype.lanes(); + } + return std::nullopt; + } + + static std::optional GetCPAsyncBitsPerCall(const CallNode *node) { + ICHECK_GE(node->args.size(), 3U) + << "cp.async expects at least 3 arguments, but got " << node->args; + const auto *count_imm = node->args[2].as(); + ICHECK(count_imm) << "cp.async transfer count must be IntImm, but got " + << node->args[2]; + int count = static_cast(count_imm->value); + if (count <= 0) { + return std::nullopt; + } + if (node->op.same_as(builtin::ptx_cp_async())) { + return count * 8; + } + ICHECK(node->op.same_as(tl::ptx_cp_async())); + auto dst_elem_bits = GetAccessPtrElementBits(node->args[0]); + auto src_elem_bits = GetAccessPtrElementBits(node->args[1]); + if (!dst_elem_bits.has_value() || !src_elem_bits.has_value()) { + return std::nullopt; + } + int dst_total_bits = count * dst_elem_bits.value(); + int src_total_bits = count * src_elem_bits.value(); + ICHECK_EQ(dst_total_bits, src_total_bits) + << "tl.ptx_cp_async requires src/dst transfer widths to match, but got " + << dst_total_bits << " vs " << src_total_bits << " bits"; + return dst_total_bits; + } + + static int GetMaxCPAsyncVectorizeLength(int per_call_bits) { + if (per_call_bits <= 0) { + return 1; + } + int vectorize_length = 1; + for (int target_bytes : {16, 8, 4}) { + int target_bits = target_bytes * 8; + if (target_bits % per_call_bits == 0) { + vectorize_length = + std::max(vectorize_length, target_bits / per_call_bits); + } + } + return vectorize_length; + } + PrimExpr VisitExpr_(const CallNode *node) final { if (node->op == builtin::if_then_else()) { CheckConditionVectorized(node->args[0]); @@ -457,22 +521,11 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } else if (node->op.same_as(builtin::ptx_cp_async()) || node->op.same_as(tl::ptx_cp_async())) { - // cp.async supports byte sizes 4/8/16. For element-wise calls with small - // byte width (e.g., fp16 => 2 bytes), we rely on vectorization to fold - // multiple calls into one wider cp.async call. - int vectorize_length = 1; - ICHECK_GE(node->args.size(), 3U) - << "cp.async expects at least 3 arguments, but got " << node->args; - const auto *bytes_imm = node->args[2].as(); - ICHECK(bytes_imm) << "cp.async byte count must be IntImm, but got " - << node->args[2]; - int bytes = static_cast(bytes_imm->value); - for (int lanes : {16, 8, 4, 2, 1}) { - if (IsValidCPAsyncTransferBytes(bytes * lanes)) { - vectorize_length = lanes; - break; - } - } + // builtin::ptx_cp_async stores bytes, while tl::ptx_cp_async stores + // logical element counts. In both cases we pick the largest vector width + // whose eventual PTX payload is one of {4, 8, 16} bytes. + int vectorize_length = + GetMaxCPAsyncVectorizeLength(GetCPAsyncBitsPerCall(node).value_or(0)); buffer_vector_infos_.push_back({Buffer(), vectorize_length, false, {}}); return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } else if (node->op == builtin::address_of() || diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index 1324456af3..fc90b9ad4c 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -66,14 +66,15 @@ class PTXAsyncCopyInjector : public StmtMutator { } Stmt VisitStmt_(const ForNode *op) final { - // Track nested vectorized loop extents so we can rewrite element-wise - // copies (e.g. float16 stores) into `tir.ptx_cp_async` with element bytes, - // relying on the later `tl.VectorizeLoop` pass to widen: - // for v in T.vectorized(k): ptx_cp_async(dst, src, elem_bytes) - // => ptx_cp_async(dst_base, src_base, elem_bytes * k) + // Track nested vectorized loop extents so we can decide whether an + // element-wise copy has a legal final cp.async width after later loop + // vectorization: + // for v in T.vectorized(k): tl.ptx_cp_async(dst, src, elem_count) + // => tl.ptx_cp_async(dst_base, src_base, elem_count * k) // - // This mirrors the logic in `CPAsyncStoreRewriter` used by `T.copy` - // lowering, and avoids duplicating vectorize-loop collapse here. + // TileLang records logical element counts in tl.ptx_cp_async. The final + // PTX byte width is derived later from the access_ptr dtype, so subbyte + // dtypes such as int4/fp4/int2/int1 remain representable here. int previous_vectorized_lanes = current_vectorized_lanes_; bool pushed_vectorized_loop = false; if (op->kind == ForKind::kVectorized) { @@ -103,20 +104,13 @@ class PTXAsyncCopyInjector : public StmtMutator { const PrimExpr &predicate_value = PrimExpr()) { // Pipeline: // 1) Analyze source/destination indices and transfer width eligibility. - // 2) Validate pointer type metadata for access_ptr construction. - // 3) Build cp.async with scalar/vectorized offsets if representable. + // 2) Build tl.ptx_cp_async with scalar/vectorized base offsets when the + // eventual PTX byte width is representable. std::optional index_info = PrepareCopyIndexInfo(load, store); if (!index_info.has_value()) { return Optional(); } - std::optional ptr_info = - PreparePointerTypeInfo(load, store); - if (!ptr_info.has_value()) { - // Be conservative: if pointer metadata is missing, skip injection. - return Optional(); - } - if (index_info->index_lanes == 1) { if (current_vectorized_lanes_ > 1 && !HasContiguousVectorizedOffsets(index_info->src_index, @@ -124,10 +118,11 @@ class PTXAsyncCopyInjector : public StmtMutator { return Optional(); } return MakeCPAsyncStmtFromLoads( - store, ptr_info.value(), + store, /*dst_base_load=*/BufferLoad(store->buffer, store->indices), /*src_base_load=*/BufferLoad(load->buffer, load->indices), - /*bytes=*/index_info->transfer_bytes, predicated, predicate_value); + /*num_elems=*/index_info->per_access_num_elems, predicated, + predicate_value); } Optional> src_base_indices = @@ -144,10 +139,11 @@ class PTXAsyncCopyInjector : public StmtMutator { return Optional(); } return MakeCPAsyncStmtFromLoads( - store, ptr_info.value(), + store, /*dst_base_load=*/BufferLoad(store->buffer, dst_base_indices.value()), /*src_base_load=*/BufferLoad(load->buffer, src_base_indices.value()), - /*bytes=*/index_info->transfer_bytes, predicated, predicate_value); + /*num_elems=*/index_info->per_access_num_elems, predicated, + predicate_value); } Stmt VisitStmt_(const SeqStmtNode *op) final { @@ -301,13 +297,7 @@ class PTXAsyncCopyInjector : public StmtMutator { PrimExpr src_index; PrimExpr dst_index; int index_lanes{1}; - int transfer_bytes{0}; - }; - - // Pointer element type metadata extracted from buffer handle annotations. - struct PointerTypeInfo { - DataType dst_elem_type; - DataType src_elem_type; + int per_access_num_elems{0}; }; // Synchronization state for injected cp.async runs carried across statements. @@ -409,9 +399,16 @@ class PTXAsyncCopyInjector : public StmtMutator { } const int effective_lanes = std::max(value_lanes, index_lanes); - const int elem_bytes = effective_lanes * load->dtype.bytes(); - const int total_bytes = static_cast(elem_bytes) * - static_cast(current_vectorized_lanes_); + const int per_access_bits = effective_lanes * load->dtype.bits(); + const int total_bits = static_cast(per_access_bits) * + static_cast(current_vectorized_lanes_); + // PTX cp.async is byte-granular. `tl.ptx_cp_async` stores logical element + // counts, but we still need to know that the eventual vectorized transfer + // can map to a legal byte width without over-copying packed subbyte data. + if (total_bits % 8 != 0) { + return std::nullopt; + } + const int total_bytes = total_bits / 8; if (!IsValidCPAsyncTransferBytes(total_bytes)) { return std::nullopt; } @@ -420,21 +417,10 @@ class PTXAsyncCopyInjector : public StmtMutator { info.src_index = src_index; info.dst_index = dst_index; info.index_lanes = index_lanes; - info.transfer_bytes = elem_bytes; + info.per_access_num_elems = effective_lanes; return info; } - static std::optional - PreparePointerTypeInfo(const BufferLoadNode *load, - const BufferStoreNode *store) { - auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); - auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); - if (!dst_elem_type.has_value() || !src_elem_type.has_value()) { - return std::nullopt; - } - return PointerTypeInfo{dst_elem_type.value(), src_elem_type.value()}; - } - static PrimExpr ExtractVectorBase(const PrimExpr &index) { if (index.dtype().lanes() == 1) { return index; @@ -500,30 +486,25 @@ class PTXAsyncCopyInjector : public StmtMutator { IntImm(DataType::Int(32), rw_mask)}); } - static Optional MakeCPAsyncStmtFromLoads( - const BufferStoreNode *store, const PointerTypeInfo &ptr_info, - const BufferLoad &dst_base_load, const BufferLoad &src_base_load, - int bytes, bool predicated, const PrimExpr &predicate_value) { - int dst_elem_count = bytes / ptr_info.dst_elem_type.bytes(); - int src_elem_count = bytes / ptr_info.src_elem_type.bytes(); - if (dst_elem_count <= 0 || src_elem_count <= 0) { - return Optional(); - } - + static Optional + MakeCPAsyncStmtFromLoads(const BufferStoreNode *store, + const BufferLoad &dst_base_load, + const BufferLoad &src_base_load, int num_elems, + bool predicated, const PrimExpr &predicate_value) { PrimExpr dst_access_ptr = - MakeAccessPtrFromLoad(dst_base_load, dst_elem_count, /*rw_mask=*/2); + MakeAccessPtrFromLoad(dst_base_load, num_elems, /*rw_mask=*/2); PrimExpr src_access_ptr = - MakeAccessPtrFromLoad(src_base_load, src_elem_count, /*rw_mask=*/1); + MakeAccessPtrFromLoad(src_base_load, num_elems, /*rw_mask=*/1); ffi::Array cp_async_args; if (predicated) { - cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(bytes), + cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(num_elems), predicate_value}; } else { - cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(bytes)}; + cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(num_elems)}; } - return Evaluate(Call(store->buffer->dtype, - tvm::tir::builtin::ptx_cp_async(), cp_async_args)); + return Evaluate( + Call(store->buffer->dtype, tvm::tl::ptx_cp_async(), cp_async_args)); } static Stmt MakeCommitGroupStmt() { diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc index 5f77b121c8..9d2b46250f 100644 --- a/src/transform/merge_shared_memory_allocations.cc +++ b/src/transform/merge_shared_memory_allocations.cc @@ -589,10 +589,11 @@ class SharedMemoryRewriter : public StmtExprMutator { return Call(op->dtype, op->op, {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); - } else if (op->op.same_as(builtin::ptx_cp_async())) { + } else if (op->op.same_as(builtin::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async())) { ICHECK(op->args.size() == 3U || op->args.size() == 4U) << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " - "src_access_ptr, bytes[, predicate])"; + "src_access_ptr, count[, predicate])"; // Extract dst_access_ptr and check if it needs merging Call dst_access_ptr = Downcast(op->args[0]); diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc index 2eac249def..9735a9a5b1 100644 --- a/src/transform/vectorize_loop.cc +++ b/src/transform/vectorize_loop.cc @@ -32,6 +32,7 @@ #include #include +#include #include #include #include @@ -651,10 +652,58 @@ class TLVectorizer : public StmtMutator, return Call(op->dtype, GetVectorizedAtomicOp(vector_size), {dst, src}); } - // cp.async call vectorization. - // Pattern: - // for i in vectorized(k): ptx_cp_async(dst, src, elem_bytes) - // => ptx_cp_async(dst_base, src_base, elem_bytes * k) + static std::optional GetAccessPtrElementBits(const PrimExpr &expr) { + const auto *ptr_call = expr.as(); + if (ptr_call == nullptr) { + return std::nullopt; + } + if (ptr_call->op.same_as(builtin::tvm_access_ptr())) { + ICHECK(!ptr_call->args.empty()); + DataType dtype = ptr_call->args[0].dtype(); + return dtype.bits() * dtype.lanes(); + } + if (ptr_call->op.same_as(tl::access_ptr())) { + ICHECK_GE(ptr_call->args.size(), 3U); + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "tl.access_ptr arg0 must be BufferLoad"; + DataType dtype = buffer_load->buffer->dtype; + return dtype.bits() * dtype.lanes(); + } + return std::nullopt; + } + + static std::optional GetCPAsyncBitsPerCall(const CallNode *op, + const PrimExpr &count) { + const auto *count_imm = count.as(); + if (count_imm == nullptr) { + return std::nullopt; + } + int scalar_count = static_cast(count_imm->value); + if (scalar_count <= 0) { + return std::nullopt; + } + if (op->op.same_as(builtin::ptx_cp_async())) { + return scalar_count * 8; + } + ICHECK(op->op.same_as(tl::ptx_cp_async())); + auto dst_elem_bits = GetAccessPtrElementBits(op->args[0]); + auto src_elem_bits = GetAccessPtrElementBits(op->args[1]); + if (!dst_elem_bits.has_value() || !src_elem_bits.has_value()) { + return std::nullopt; + } + int dst_total_bits = scalar_count * dst_elem_bits.value(); + int src_total_bits = scalar_count * src_elem_bits.value(); + ICHECK_EQ(dst_total_bits, src_total_bits) + << "tl.ptx_cp_async requires src/dst transfer widths to match, but got " + << dst_total_bits << " vs " << src_total_bits << " bits"; + return dst_total_bits; + } + + // Vectorized cp.async widening. + // builtin::ptx_cp_async keeps the transfer width in bytes, while + // tl::ptx_cp_async keeps it in logical element counts. The generic + // vectorization pass widens either form by the vector lane count and lets + // the final codegen validate the derived PTX byte width. PrimExpr MutatePTXCPAsyncExpr_(const CallNode *op) { ICHECK(op->op.same_as(builtin::ptx_cp_async()) || op->op.same_as(tl::ptx_cp_async())); @@ -664,7 +713,7 @@ class TLVectorizer : public StmtMutator, PrimExpr dst = VisitExpr(op->args[0]); PrimExpr src = VisitExpr(op->args[1]); - PrimExpr bytes = VisitExpr(op->args[2]); + PrimExpr count = VisitExpr(op->args[2]); Optional predicate = std::nullopt; if (op->args.size() == 4) { auto pred = VisitExpr(op->args[3]); @@ -677,7 +726,7 @@ class TLVectorizer : public StmtMutator, auto lanes_ptr = as_const_int(var_lanes_); if (!lanes_ptr || *lanes_ptr <= 1) { - Array new_args{dst, src, bytes}; + Array new_args{dst, src, count}; if (predicate.defined()) { new_args.push_back(predicate.value()); } @@ -687,23 +736,33 @@ class TLVectorizer : public StmtMutator, return Call(op->dtype, op->op, new_args); } - const auto *bytes_imm = bytes.as(); - if (bytes_imm == nullptr) { + auto bits_per_call = GetCPAsyncBitsPerCall(op, count); + if (!bits_per_call.has_value()) { need_scalarize_ = true; return tvm::ffi::GetRef(op); } int vector_size = static_cast(*lanes_ptr); - int total_bytes = static_cast(bytes_imm->value) * vector_size; + int total_bits = bits_per_call.value() * vector_size; + if (total_bits % 8 != 0) { + need_scalarize_ = true; + return tvm::ffi::GetRef(op); + } + int total_bytes = total_bits / 8; if (!IsValidCPAsyncTransferBytes(total_bytes)) { need_scalarize_ = true; return tvm::ffi::GetRef(op); } - Array new_args{dst, src, IntImm(bytes_imm->dtype, total_bytes)}; + int total_count = + static_cast(Downcast(count)->value) * vector_size; + Array new_args{dst, src, IntImm(count.dtype(), total_count)}; if (predicate.defined()) { new_args.push_back(predicate.value()); } + if (new_args.same_as(op->args)) { + return tvm::ffi::GetRef(op); + } return Call(op->dtype, op->op, new_args); } diff --git a/testing/python/kernel/test_tilelang_kernel_int4_gemm.py b/testing/python/kernel/test_tilelang_kernel_int4_gemm.py new file mode 100644 index 0000000000..3804d6303b --- /dev/null +++ b/testing/python/kernel/test_tilelang_kernel_int4_gemm.py @@ -0,0 +1,40 @@ +import tilelang +import tilelang.testing +import tilelang.language as T + + +def matmul_nt_int4(M, N, K, block_M, block_N, block_K): + @T.prim_func + def main( + A: T.Tensor((M, K), T.int4), + B: T.Tensor((N, K), T.int4), + C: T.Tensor((M, N), T.int32), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), T.int4) + B_shared = T.alloc_shared((block_N, block_K), T.int4) + C_local = T.alloc_fragment((block_M, block_N), T.int32) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(8, 0) +def test_compile_int4_gemm_tgemm(): + func = matmul_nt_int4(1024, 1024, 1024, 128, 128, 64) + kernel = tilelang.compile(func, out_idx=-1) + src = kernel.get_kernel_source() + assert src is not None + assert "s4.s4.s32" in src or "int4" in src + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py index f80f3f1629..5bcd955488 100644 --- a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py @@ -8,8 +8,8 @@ ) from tilelang.transform import simplify_prim_func from tilelang.intrinsics.mma_macro_generator import ( - INT4TensorCoreIntrinEmitter, - INT4TensorCoreIntrinEmitterWithLadderTransform, + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform, ) tilelang.testing.set_random_seed(42) @@ -75,9 +75,9 @@ def tl_matmul( warp_cols = warp_col_tiles // micro_size_y # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = INT4TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=T.int4, + b_dtype=T.int4, accum_dtype=accum_dtype, a_transposed=False, b_transposed=True, @@ -266,9 +266,9 @@ def tl_matmul_weight_only_transform( warp_cols = warp_col_tiles // micro_size_y # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform( - a_dtype=in_dtype, - b_dtype=in_dtype, + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=T.int4, + b_dtype=T.int4, accum_dtype=accum_dtype, a_transposed=False, b_transposed=True, diff --git a/testing/python/language/test_tilelang_language_access_ptr_codegen.py b/testing/python/language/test_tilelang_language_access_ptr_codegen.py index 578d5255de..1d66b8b32b 100644 --- a/testing/python/language/test_tilelang_language_access_ptr_codegen.py +++ b/testing/python/language/test_tilelang_language_access_ptr_codegen.py @@ -32,8 +32,8 @@ def main( @tilelang.testing.requires_cuda -def test_vectorized_cp_async_bytes_codegen(): - """Check vectorized ptx_cp_async byte folding (elem_bytes * lanes).""" +def test_vectorized_cp_async_num_elems_codegen(): + """Check vectorized tl.ptx_cp_async widens logical element counts.""" @T.prim_func def main( @@ -46,7 +46,7 @@ def main( T.ptx_cp_async( T.access_ptr(S[i], "w", 1), T.access_ptr(A[i], "r", 1), - 2, + 1, ) T.ptx_commit_group() T.ptx_wait_group(0) @@ -56,8 +56,36 @@ def main( src = kernel.get_kernel_source() print("=== vectorized cp.async codegen ===") print(src) - assert "cp_async_gs<8>" in src, "Expected vectorized cp.async bytes to fold into cp_async_gs<8>" - assert "cp_async_gs<2>" not in src, "Did not expect scalar cp.async bytes in generated CUDA source" + assert "cp_async_gs<8>" in src, "Expected vectorized cp.async to fold 4 x fp16 elems into cp_async_gs<8>" + assert "cp_async_gs<2>" not in src, "Did not expect scalar cp.async width in generated CUDA source" + + +@tilelang.testing.requires_cuda +def test_vectorized_int4_cp_async_num_elems_codegen(): + """Check subbyte tl.ptx_cp_async derives PTX bytes from logical element counts.""" + + @T.prim_func + def main( + A: T.Tensor((128,), T.int4), + B: T.Tensor((128,), T.int4), + ): + with T.Kernel(1, threads=32): + S = T.alloc_shared((128,), T.int4) + for i in T.vectorized(32): + T.ptx_cp_async( + T.access_ptr(S[i], "w", 1), + T.access_ptr(A[i], "r", 1), + 1, + ) + T.ptx_commit_group() + T.ptx_wait_group(0) + B[0] = S[0] + + kernel = tilelang.compile(main, out_idx=[1], target="cuda") + src = kernel.get_kernel_source() + print("=== vectorized int4 cp.async codegen ===") + print(src) + assert "cp_async_gs<16>" in src, "Expected 32 x int4 elems to fold into cp_async_gs<16>" @tilelang.testing.requires_cuda diff --git a/testing/python/transform/test_tilelang_transform_lower_ptx_async_copy.py b/testing/python/transform/test_tilelang_transform_lower_ptx_async_copy.py index 71b4c1e74a..588a3170c3 100644 --- a/testing/python/transform/test_tilelang_transform_lower_ptx_async_copy.py +++ b/testing/python/transform/test_tilelang_transform_lower_ptx_async_copy.py @@ -51,7 +51,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) > 0 assert calls.get("tir.ptx_wait_group", 0) > 0 @@ -77,7 +77,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) == 0 assert calls.get("tir.ptx_wait_group", 0) == 0 @@ -102,7 +102,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) > 0 assert calls.get("tir.ptx_wait_group", 0) > 0 @@ -128,7 +128,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) > 0 assert calls.get("tir.ptx_wait_group", 0) > 0 @@ -155,7 +155,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) > 0 assert calls.get("tir.ptx_wait_group", 0) > 0 @@ -195,7 +195,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) == 0 + assert calls.get("tl.ptx_cp_async", 0) == 0 assert calls.get("tir.ptx_commit_group", 0) == 0 assert calls.get("tir.ptx_wait_group", 0) == 0 @@ -222,7 +222,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) == 1 assert calls.get("tir.ptx_wait_group", 0) == 1 @@ -248,7 +248,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) == 1 assert calls.get("tir.ptx_wait_group", 0) == 1 @@ -273,7 +273,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) > 0 assert calls.get("tir.ptx_wait_group", 0) > 0 @@ -321,7 +321,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 def test_lower_ptx_async_copy_skips_vectorized_broadcast_source(): @@ -344,7 +344,7 @@ def before( mod = tl.transform.LowerPTXAsyncCopy()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) == 0 + assert calls.get("tl.ptx_cp_async", 0) == 0 assert calls.get("tir.ptx_commit_group", 0) == 0 assert calls.get("tir.ptx_wait_group", 0) == 0 @@ -369,7 +369,7 @@ def before( print(mod) calls = _count_calls(mod["main"]) print(calls) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 if __name__ == "__main__": diff --git a/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/testing/python/transform/test_tilelang_transform_lower_tile_op.py index ae170a2663..c43199151f 100644 --- a/testing/python/transform/test_tilelang_transform_lower_tile_op.py +++ b/testing/python/transform/test_tilelang_transform_lower_tile_op.py @@ -2,6 +2,7 @@ import tilelang as tl import tilelang.language as T +import tilelang.testing from tilelang import tvm from tvm.tir.stmt_functor import post_order_visit @@ -42,7 +43,7 @@ def before( mod = tl.transform.LowerTileOp()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) == 0 assert calls.get("tir.ptx_wait_group", 0) == 0 @@ -71,7 +72,7 @@ def before( mod = tl.transform.LowerTileOp()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) == 0 assert calls.get("tir.ptx_wait_group", 0) == 0 @@ -101,6 +102,10 @@ def before( mod = tl.transform.LowerTileOp()(mod) calls = _count_calls(mod["main"]) - assert calls.get("tir.ptx_cp_async", 0) > 0 + assert calls.get("tl.ptx_cp_async", 0) > 0 assert calls.get("tir.ptx_commit_group", 0) == 0 assert calls.get("tir.ptx_wait_group", 0) == 0 + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_pipeline_planning.py b/testing/python/transform/test_tilelang_transform_pipeline_planning.py index 9ada565fba..75fc4d4899 100644 --- a/testing/python/transform/test_tilelang_transform_pipeline_planning.py +++ b/testing/python/transform/test_tilelang_transform_pipeline_planning.py @@ -461,13 +461,13 @@ def main( T.ptx_cp_async( T.access_ptr(S[0], "w", 16), T.access_ptr(A[0], "r", 16), - 16, + 8, True, ) T.ptx_cp_async( T.access_ptr(S[8], "w", 16), T.access_ptr(A[0], "r", 16), - 16, + 8, False, ) T.ptx_commit_group() @@ -487,4 +487,5 @@ def main( if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + test_pipeline_predicated_copy_preserves_shared_fill_correctness() diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 028de25991..b26ae38a6e 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -185,11 +185,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Run pipeline planning and software-pipeline rewriting before layout # inference so inferred layouts see the final pipelined structure directly. mod = tilelang.transform.PipelinePlanning()(mod) - # print("After pipeline planing") - # print(mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod) - # print("After InjectSoftwarePipeline") - # print(mod) mod = tilelang.transform.Simplify()(mod) # Infer memory layouts for fragments and shared memory mod = tilelang.transform.LayoutInference()(mod) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index f1932245d1..3ac11a465f 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -1,4 +1,5 @@ from __future__ import annotations +from dataclasses import dataclass import tilelang.language as T from typing import Literal, Callable from tilelang.common import TransformKind @@ -31,6 +32,160 @@ lift = convert +def _resolve_subbyte_local_offset(local_size: int, numerator: int, denominator: int) -> int: + if denominator <= 0: + raise ValueError(f"denominator must be positive, but got {denominator}") + scaled = local_size * numerator + if scaled % denominator != 0: + raise ValueError(f"Invalid subbyte MMA offset {numerator}/{denominator} for local_size={local_size}") + return scaled // denominator + + +def _infer_subbyte_storage_bits(logical_bits: int) -> int: + for storage_bits in (8, 16, 32): + if storage_bits >= logical_bits and storage_bits % logical_bits == 0: + return storage_bits + raise ValueError(f"Unsupported subbyte logical bit width: {logical_bits}") + + +def _infer_subbyte_storage_dtype(logical_dtype: str, logical_bits: int) -> str: + storage_bits = _infer_subbyte_storage_bits(logical_bits) + logical_dtype = str(logical_dtype) + if logical_dtype.startswith("uint"): + return f"uint{storage_bits}" + if logical_dtype.startswith("int"): + return f"int{storage_bits}" + # For non-integer subbyte dtypes such as future fp4, use an integer carrier dtype + # inside lowering. The logical dtype still drives the MMA opcode selection. + return f"int{storage_bits}" + + +@dataclass(frozen=True) +class SubByteTensorCoreMMAOp: + a_offset_num: int = 0 + a_offset_den: int = 1 + b_offset_num: int = 0 + b_offset_den: int = 1 + c_offset_num: int = 0 + c_offset_den: int = 1 + + def resolve_offsets(self, local_size_a: int, local_size_b: int, local_size_out: int) -> tuple[int, int, int]: + return ( + _resolve_subbyte_local_offset(local_size_a, self.a_offset_num, self.a_offset_den), + _resolve_subbyte_local_offset(local_size_b, self.b_offset_num, self.b_offset_den), + _resolve_subbyte_local_offset(local_size_out, self.c_offset_num, self.c_offset_den), + ) + + +@dataclass(frozen=True) +class SubByteTensorCoreMMASpec: + logical_a_dtype: str + logical_b_dtype: str + logical_a_bits: int + logical_b_bits: int + accum_dtype: str + mma_prefix: str + mma_a_dtype_abbrv: str + mma_b_dtype_abbrv: str + mma_ops: tuple[SubByteTensorCoreMMAOp, ...] + + def __post_init__(self): + self._validate_pack_factor(self.storage_a_dtype, self.logical_a_bits, "A") + self._validate_pack_factor(self.storage_b_dtype, self.logical_b_bits, "B") + + @property + def storage_a_dtype(self) -> str: + return _infer_subbyte_storage_dtype(self.logical_a_dtype, self.logical_a_bits) + + @property + def storage_b_dtype(self) -> str: + return _infer_subbyte_storage_dtype(self.logical_b_dtype, self.logical_b_bits) + + @staticmethod + def _validate_pack_factor(storage_dtype: str, logical_bits: int, matrix: str): + storage_bits = DataType(storage_dtype).bits + if storage_bits < logical_bits or storage_bits % logical_bits != 0: + raise ValueError( + f"Subbyte MMA spec expects {matrix} storage dtype {storage_dtype} to pack logical {logical_bits}-bit elements exactly" + ) + + @property + def a_pack_factor(self) -> int: + return DataType(self.storage_a_dtype).bits // self.logical_a_bits + + @property + def b_pack_factor(self) -> int: + return DataType(self.storage_b_dtype).bits // self.logical_b_bits + + def get_pack_factor(self, matrix: Literal["A", "B"]) -> int: + if matrix == "A": + return self.a_pack_factor + if matrix == "B": + return self.b_pack_factor + raise ValueError(f"Unsupported matrix kind: {matrix}") + + def get_storage_dtype(self, matrix: Literal["A", "B"]) -> str: + if matrix == "A": + return self.storage_a_dtype + if matrix == "B": + return self.storage_b_dtype + raise ValueError(f"Unsupported matrix kind: {matrix}") + + def get_logical_dtype(self, matrix: Literal["A", "B"]) -> str: + if matrix == "A": + return self.logical_a_dtype + if matrix == "B": + return self.logical_b_dtype + raise ValueError(f"Unsupported matrix kind: {matrix}") + + def pack_extent(self, extent: int, matrix: Literal["A", "B"]) -> int: + pack_factor = self.get_pack_factor(matrix) + if extent % pack_factor != 0: + raise ValueError(f"{self.get_logical_dtype(matrix)} expects extent divisible by {pack_factor}, but got {extent}") + return extent // pack_factor + + +INT4_TENSORCORE_MMA_SPEC = SubByteTensorCoreMMASpec( + logical_a_dtype="int4", + logical_b_dtype="int4", + logical_a_bits=4, + logical_b_bits=4, + accum_dtype="int32", + mma_prefix="m16n8k32", + mma_a_dtype_abbrv="int4", + mma_b_dtype_abbrv="int4", + mma_ops=( + SubByteTensorCoreMMAOp(), + SubByteTensorCoreMMAOp(b_offset_num=1, b_offset_den=2, c_offset_num=1, c_offset_den=2), + SubByteTensorCoreMMAOp(a_offset_num=1, a_offset_den=2, b_offset_num=1, b_offset_den=4), + SubByteTensorCoreMMAOp(a_offset_num=1, a_offset_den=2, b_offset_num=3, b_offset_den=4, c_offset_num=1, c_offset_den=2), + ), +) + +_SUBBYTE_TENSORCORE_MMA_SPECS = { + "int4": INT4_TENSORCORE_MMA_SPEC, +} + + +def get_subbyte_tensorcore_mma_spec(dtype: str) -> SubByteTensorCoreMMASpec | None: + return _SUBBYTE_TENSORCORE_MMA_SPECS.get(str(dtype)) + + +def infer_subbyte_tensorcore_mma_spec(a_dtype: str, b_dtype: str) -> SubByteTensorCoreMMASpec | None: + a_spec = get_subbyte_tensorcore_mma_spec(a_dtype) + b_spec = get_subbyte_tensorcore_mma_spec(b_dtype) + + if a_spec is None and b_spec is None: + return None + if a_spec is None or b_spec is None: + raise ValueError(f"Subbyte MMA requires both operands to be subbyte dtypes, but got a_dtype={a_dtype}, b_dtype={b_dtype}") + if not (str(a_dtype) == str(a_spec.logical_a_dtype) and str(b_dtype) == str(a_spec.logical_b_dtype)): + raise ValueError(f"Unsupported subbyte MMA operand dtypes: a_dtype={a_dtype}, b_dtype={b_dtype}") + if a_spec != b_spec: + raise ValueError(f"Mismatched subbyte MMA specs for operands: a_dtype={a_dtype}, b_dtype={b_dtype}") + return a_spec + + class TensorCoreIntrinEmitter: """ To eliminate Python syntax within TIR Macro. @@ -83,22 +238,35 @@ def __init__( self.accum_dtype = accum_dtype self.a_transposed = a_transposed self.b_transposed = b_transposed + self.subbyte_mma_spec = infer_subbyte_tensorcore_mma_spec(a_dtype, b_dtype) + if self.subbyte_mma_spec is not None and str(accum_dtype) != str(self.subbyte_mma_spec.accum_dtype): + raise ValueError( + f"Subbyte MMA dtypes ({a_dtype}, {b_dtype}) expect accum dtype {self.subbyte_mma_spec.accum_dtype}, but got {accum_dtype}" + ) # Hint Information self.block_row_warps = block_row_warps self.block_col_warps = block_col_warps self.warp_row_tiles = warp_row_tiles self.warp_col_tiles = warp_col_tiles self.chunk = chunk - self._initialize_k_dim(a_dtype) + a_storage_dtype = self._get_storage_dtype("A") + b_storage_dtype = self._get_storage_dtype("B") + self._initialize_k_dim(a_storage_dtype) # For FP64, MMA shape is m8n8k4; adjust instance dims early - if DataType(a_dtype).bits == 64: + if DataType(a_storage_dtype).bits == 64: # Override default M/N dims for fp64 MMA self.M_DIM = 8 # n_dim will be set to 8 in _initialize_micro_size via k_dim==4 - self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) self._initialize_micro_size(self.M_DIM, self.k_dim) self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE) - self._initialize_mma_prefix(self.k_dim) + if self.subbyte_mma_spec is None: + self._initialize_abbrev(a_storage_dtype, b_storage_dtype, accum_dtype) + self._initialize_mma_prefix(self.k_dim) + else: + self.a_dtype_abbrv = self.subbyte_mma_spec.mma_a_dtype_abbrv + self.b_dtype_abbrv = self.subbyte_mma_spec.mma_b_dtype_abbrv + self.accum_dtype_abbrv = str(self.subbyte_mma_spec.accum_dtype) + self.mma_prefix = self.subbyte_mma_spec.mma_prefix self._initialize_is_m_first(is_m_first) self.reduce_k = reduce_k @@ -111,6 +279,17 @@ def __init__( f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" ) + def _get_storage_dtype(self, matrix: Literal["A", "B"]) -> str: + if matrix == "A": + logical_dtype = self.a_dtype + elif matrix == "B": + logical_dtype = self.b_dtype + else: + raise ValueError(f"Unsupported matrix kind: {matrix}") + if self.subbyte_mma_spec is None: + return logical_dtype + return self.subbyte_mma_spec.get_storage_dtype(matrix) + def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): a_dtype = DataType(a_dtype) @@ -236,7 +415,8 @@ def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): # Fast path for fp64: no ldmatrix support, do direct per-lane loads - if DataType(self.a_dtype).bits == 64: + a_dtype = self._get_storage_dtype("A") + if DataType(a_dtype).bits == 64: warp_row_tiles = self.warp_row_tiles warp_rows = self.warp_rows chunk = self.chunk @@ -280,7 +460,6 @@ def _warp_ld_a_fp64( micro_size_x = self.micro_size_x micro_size_k = self.micro_size_k local_size_a = self.local_size_a - a_dtype = self.a_dtype a_transposed = self.a_transposed # ldmatrix cannot be used for int8 + trans case. ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed) @@ -352,7 +531,8 @@ def _warp_ldmatrix_a( def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): # Fast path for fp64: no ldmatrix support, do direct per-lane loads - if DataType(self.b_dtype).bits == 64: + b_dtype = self._get_storage_dtype("B") + if DataType(b_dtype).bits == 64: warp_col_tiles = self.warp_col_tiles warp_cols = self.warp_cols chunk = self.chunk @@ -396,7 +576,6 @@ def _warp_ld_b_fp64( micro_size_y = self.micro_size_y micro_size_k = self.micro_size_k local_size_b = self.local_size_b - b_dtype = self.b_dtype b_transposed = self.b_transposed thread_binding = self.get_thread_binding() @@ -474,6 +653,19 @@ def _warp_ldmatrix_b( return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): + if self.subbyte_mma_spec is not None: + return _emit_subbyte_tensorcore_mma( + self.subbyte_mma_spec, + self.warp_rows, + self.warp_cols, + self.local_size_a, + self.local_size_b, + self.local_size_out, + self.accum_dtype, + A_local_buf, + B_local_buf, + C_local_buf, + ) warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -613,7 +805,7 @@ def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" - dtype = self.a_dtype if matrix_is_a else self.b_dtype + dtype = self._get_storage_dtype(matrix) dtype_bits = DataType(dtype).bits transposed = self.a_transposed if matrix_is_a else self.b_transposed @@ -905,7 +1097,7 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0): micro_size_x = self.micro_size_x micro_size_k = self.micro_size_k local_size_a = self.local_size_a - a_dtype = self.a_dtype + a_dtype = self._get_storage_dtype("A") a_transposed = self.a_transposed transform_kind_a = self.transform_kind_a @@ -1012,7 +1204,7 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0): micro_size_y = self.micro_size_y micro_size_k = self.micro_size_k local_size_b = self.local_size_b - b_dtype = self.b_dtype + b_dtype = self._get_storage_dtype("B") transform_kind_b = self.transform_kind_b b_transposed = self.b_transposed num_elems_per_byte = self.num_elems_per_byte @@ -1119,6 +1311,19 @@ def _warp_ldmatrix_b( return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) def mma(self, A_local_buf, B_local_buf, C_local_buf): + if self.subbyte_mma_spec is not None: + return _emit_subbyte_tensorcore_mma( + self.subbyte_mma_spec, + self.warp_rows, + self.warp_cols, + self.local_size_a, + self.local_size_b, + self.local_size_out, + self.accum_dtype, + A_local_buf, + B_local_buf, + C_local_buf, + ) warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -1170,205 +1375,53 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): return _warp_mma(A_local_buf, B_local_buf, C_local_buf) -class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): - def mma(self, A_local_buf, B_local_buf, C_local_buf): - warp_rows = self.warp_rows - warp_cols = self.warp_cols - local_size_a = self.local_size_a - local_size_b = self.local_size_b - local_size_out = self.local_size_out - a_dtype_abbrv = "int4" - b_dtype_abbrv = "int4" - accum_dtype = self.accum_dtype - accum_dtype_abbrv = accum_dtype - mma_prefix = "m16n8k32" - - @T.macro - def _warp_mma(A_local_buf, B_local_buf, C_local_buf): - for i, j in T.grid(warp_rows, warp_cols): - """ - A[16, 32], B[16, 32], C[16, 16] - A_local_size -> 16 - B_local_size -> 16 - C_local_size -> 8 - For each m16n8k32 inst - For A: m16k32 consume 16 int4 elements -> 8 A_local_size - For A: n8k32 consume 8 int4 elements -> 4 B_local_size - For C: m16n8 consume 4 int32 elements -> 4 C_local_size - """ - - # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8] - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - i * local_size_a, - B_local_buf.data, - j * local_size_b, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out, - T.bool(False), - ) - - # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16] - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - i * local_size_a, - B_local_buf.data, - j * local_size_b + lift(local_size_b) // 2, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, - T.bool(False), - ) - - # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8] - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - i * local_size_a + lift(local_size_a) // 2, - B_local_buf.data, - j * local_size_b + lift(local_size_b) // 4, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out, - T.bool(False), - ) - - # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16] - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - i * local_size_a + lift(local_size_b) // 2, - B_local_buf.data, - j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, - T.bool(False), - ) - - return _warp_mma(A_local_buf, B_local_buf, C_local_buf) - - -class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform): - def mma(self, A_local_buf, B_local_buf, C_local_buf): - warp_rows = self.warp_rows - warp_cols = self.warp_cols - local_size_a = self.local_size_a - local_size_b = self.local_size_b - local_size_out = self.local_size_out - a_dtype_abbrv = "int4" - b_dtype_abbrv = "int4" - accum_dtype = self.accum_dtype - accum_dtype_abbrv = T.int32 - mma_prefix = "m16n8k32" - - @T.macro - def _warp_mma(A_local_buf, B_local_buf, C_local_buf): - for i, j in T.grid(warp_rows, warp_cols): - """ - A[16, 32], B[16, 32], C[16, 16] - A_local_size -> 16 - B_local_size -> 16 - C_local_size -> 8 - For each m16n8k32 inst - For A: m16k32 consume 16 int4 elements -> 8 A_local_size - For A: n8k32 consume 8 int4 elements -> 4 B_local_size - For C: m16n8 consume 4 int32 elements -> 4 C_local_size - """ - - # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8] - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - i * local_size_a, - B_local_buf.data, - j * local_size_b, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out, - T.bool(False), - ) - - # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16] - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - i * local_size_a, - B_local_buf.data, - j * local_size_b + lift(local_size_b) // 2, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, - T.bool(False), - ) +def _emit_subbyte_tensorcore_mma( + mma_spec: SubByteTensorCoreMMASpec, + warp_rows: int, + warp_cols: int, + local_size_a: int, + local_size_b: int, + local_size_out: int, + accum_dtype: str, + A_local_buf, + B_local_buf, + C_local_buf, +): + accum_dtype_abbrv = accum_dtype + mma_prefix = mma_spec.mma_prefix + a_dtype_abbrv = mma_spec.mma_a_dtype_abbrv + b_dtype_abbrv = mma_spec.mma_b_dtype_abbrv + mma_op_offsets = tuple(mma_op.resolve_offsets(local_size_a, local_size_b, local_size_out) for mma_op in mma_spec.mma_ops) + + @T.macro + def _emit_subbyte_mma_op(A_local_buf, B_local_buf, C_local_buf, i, j, a_offset, b_offset, c_offset): + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + a_offset, + B_local_buf.data, + j * local_size_b + b_offset, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + c_offset, + T.bool(False), + ) - # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8] - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - i * local_size_a + lift(local_size_a) // 2, - B_local_buf.data, - j * local_size_b + lift(local_size_b) // 4, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out, - T.bool(False), - ) + def _emit_subbyte_mma_ops(A_local_buf, B_local_buf, C_local_buf, i, j, op_index: int = 0): + if op_index >= len(mma_op_offsets): + return + a_offset, b_offset, c_offset = mma_op_offsets[op_index] + _emit_subbyte_mma_op(A_local_buf, B_local_buf, C_local_buf, i, j, a_offset, b_offset, c_offset) + _emit_subbyte_mma_ops(A_local_buf, B_local_buf, C_local_buf, i, j, op_index + 1) - # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16] - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - i * local_size_a + lift(local_size_b) // 2, - B_local_buf.data, - j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, - T.bool(False), - ) + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + _emit_subbyte_mma_ops(A_local_buf, B_local_buf, C_local_buf, i, j) - return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) diff --git a/tilelang/language/dtypes.py b/tilelang/language/dtypes.py index 74d5aab3f5..7b34e3fa8d 100644 --- a/tilelang/language/dtypes.py +++ b/tilelang/language/dtypes.py @@ -212,6 +212,9 @@ def __dtype_as_torch__(self: dtype) -> torch.dtype: elif dtype_str == "float4_e2m1fn": logger.info("torch doesn't support float4_e2m1fn, using float4_e2m1fnx2 as storage dtype.") return torch.float4_e2m1fn_x2 if hasattr(torch, "float4_e2m1fn_x2") else torch.int8 + elif dtype_str == "int4": + logger.info("torch doesn't support int4, using int8 as storage dtype.") + return torch.int8 elif dtype_str == "handle": return None elif dtype_str in _STR_TO_TORCH_DTYPE: diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index ffa00e030a..ebada05f9c 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1350,7 +1350,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset) -def ptx_cp_async(dst_access_ptr, src_access_ptr, bytes, predicate=None): +def ptx_cp_async(dst_access_ptr, src_access_ptr, num_elems, predicate=None): """TVM intrinsic for ptx async copy from global to shared memory using cp.async https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async @@ -1364,8 +1364,12 @@ def ptx_cp_async(dst_access_ptr, src_access_ptr, bytes, predicate=None): The source (global memory) access pointer created by tvm_access_ptr. Should include pointer, offset, extent, and read access flag (rw_mask=1). - bytes : int or PrimExpr - The number of bytes to copy (must be 4, 8, or 16). + num_elems : int or PrimExpr + The number of logical elements to copy. + + For TileLang's ``tl.ptx_cp_async`` frontend op, the final PTX byte width + is derived later from ``num_elems * element_bits(access_ptr)`` and must + eventually land on a legal ``cp.async`` width of 4, 8, or 16 bytes. predicate : PrimExpr, optional Optional predicate condition for conditional cp.async. When provided, the copy @@ -1379,11 +1383,11 @@ def ptx_cp_async(dst_access_ptr, src_access_ptr, bytes, predicate=None): Examples -------- - >>> # Copy 16 bytes from global to shared memory + >>> # Copy 16 uint8 elements (= 16 bytes) from global to shared memory >>> T.ptx_cp_async( ... T.tvm_access_ptr(T.type_annotation(T.uint8), A_shared.data, 0, 16, 2), # dst ... T.tvm_access_ptr(T.type_annotation(T.uint8), B_global.data, 0, 16, 1), # src - ... 16 # bytes + ... 16 # num_elems ... ) >>> >>> # Predicated cp.async (only copy if condition is true) @@ -1397,9 +1401,9 @@ def ptx_cp_async(dst_access_ptr, src_access_ptr, bytes, predicate=None): from tvm import tir if predicate is None: - return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes) + return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, num_elems) else: - return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes, predicate) + return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, num_elems, predicate) def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id): diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py index 1f2d1676e9..5b342d98f2 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -5,6 +5,8 @@ from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mma_macro_generator import ( TensorCoreIntrinEmitter, + SubByteTensorCoreMMASpec, + get_subbyte_tensorcore_mma_spec, ) from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm @@ -15,12 +17,98 @@ from tilelang.transform.simplify import _Simplify +class SubByteGemmOperandAdaptor: + def __init__(self, mma_spec: SubByteTensorCoreMMASpec): + self.mma_spec = mma_spec + + def get_storage_dtype(self, matrix: str) -> str: + return self.mma_spec.get_storage_dtype(matrix) + + def get_packed_chunk(self, logical_chunk: int, matrix: str = "A") -> int: + logical_chunk = int(logical_chunk) + return self.mma_spec.pack_extent(logical_chunk, matrix) + + def make_packed_buffer(self, buf: tir.Buffer, matrix: str) -> tir.Buffer: + shape = list(buf.shape) + if len(shape) < 2: + raise ValueError(f"{self.mma_spec.get_logical_dtype(matrix)} T.gemm expects at least 2D operands, but got shape={shape}") + packed_last_dim = int(shape[-1]) + pack_factor = self.mma_spec.get_pack_factor(matrix) + if packed_last_dim % pack_factor != 0: + raise ValueError( + f"{self.mma_spec.get_logical_dtype(matrix)} T.gemm expects an innermost K extent divisible by " + f"{pack_factor}, but got {packed_last_dim}" + ) + shape[-1] = packed_last_dim // pack_factor + return T.view(buf, tuple(shape), dtype=self.get_storage_dtype(matrix)) + + def make_packed_region(self, region: tir.BufferRegion, matrix: str) -> tir.BufferRegion: + packed_buf = self.make_packed_buffer(region.buffer, matrix) + pack_factor = self.mma_spec.get_pack_factor(matrix) + packed_ranges = list(region.region) + last_range = packed_ranges[-1] + packed_ranges[-1] = Range.from_min_extent(last_range.min // pack_factor, last_range.extent // pack_factor) + return tir.BufferRegion(packed_buf, packed_ranges) + + class GemmMMA(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): + def _get_subbyte_mma_spec(self) -> SubByteTensorCoreMMASpec | None: + return get_subbyte_tensorcore_mma_spec(self.in_dtype) + + def _get_subbyte_operand_adaptor(self) -> SubByteGemmOperandAdaptor | None: + mma_spec = self._get_subbyte_mma_spec() + if mma_spec is None: + return None + return SubByteGemmOperandAdaptor(mma_spec) + + def _validate_subbyte_mma_support(self, mma_spec: SubByteTensorCoreMMASpec): + chunk = int(self.chunk) + pack_factor_a = mma_spec.get_pack_factor("A") + pack_factor_b = mma_spec.get_pack_factor("B") + if not self.is_gemm_ss(): + raise ValueError(f"{self.in_dtype} T.gemm currently only supports shared/shared operands in the subbyte MMA path") + if self.trans_A or not self.trans_B: + raise ValueError( + f"{self.in_dtype} T.gemm currently only supports innermost-K packed layout (transpose_A=False, transpose_B=True)" + ) + if str(self.accum_dtype) != str(mma_spec.accum_dtype): + raise ValueError( + f"{self.in_dtype} T.gemm currently only supports {mma_spec.accum_dtype} accumulation, but got {self.accum_dtype}" + ) + if chunk % pack_factor_a != 0: + raise ValueError(f"{self.in_dtype} T.gemm expects the A K tile to be divisible by {pack_factor_a}, but got chunk={chunk}") + if chunk % pack_factor_b != 0: + raise ValueError(f"{self.in_dtype} T.gemm expects the B K tile to be divisible by {pack_factor_b}, but got chunk={chunk}") + + def _make_mma_emitter(self, target: Target, thread_nums: int, thread_var: tir.Var | None = None): m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.MMA) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) - mma_emitter = TensorCoreIntrinEmitter( + subbyte_mma_spec = self._get_subbyte_mma_spec() + subbyte_adaptor = self._get_subbyte_operand_adaptor() + if subbyte_mma_spec is not None and subbyte_adaptor is not None: + self._validate_subbyte_mma_support(subbyte_mma_spec) + packed_chunk_a = subbyte_adaptor.get_packed_chunk(self.chunk, matrix="A") + packed_chunk_b = subbyte_adaptor.get_packed_chunk(self.chunk, matrix="B") + if packed_chunk_a != packed_chunk_b: + raise ValueError( + f"Subbyte MMA currently expects A/B to use the same packed K tile, but got A={packed_chunk_a}, B={packed_chunk_b}" + ) + emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=packed_chunk_a, + thread_var=thread_var, + ) + return emitter, m_warp, n_warp + emitter = TensorCoreIntrinEmitter( a_dtype=self.in_dtype, b_dtype=self.in_dtype, accum_dtype=self.accum_dtype, @@ -31,7 +119,12 @@ def infer_layout(self, target: Target, thread_nums: int): warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=self.chunk, + thread_var=thread_var, ) + return emitter, m_warp, n_warp + + def infer_layout(self, target: Target, thread_nums: int): + mma_emitter, _, _ = self._make_mma_emitter(target, thread_nums) if self.is_gemm_ss(): return { self.A: make_swizzled_layout(self.A), @@ -68,24 +161,11 @@ def lower( mbar_phase_expr: tir.PrimExpr | None = None, ): thread_nums = thread_bounds.extent - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.MMA) - warp_row_tiles = int(self.M // m_warp) - warp_col_tiles = int(self.N // n_warp) - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=self.in_dtype, - b_dtype=self.in_dtype, - accum_dtype=self.accum_dtype, - a_transposed=self.trans_A, - b_transposed=self.trans_B, - block_row_warps=m_warp, - block_col_warps=n_warp, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=self.chunk, - thread_var=thread_var, - ) + mma_emitter, _, _ = self._make_mma_emitter(target, thread_nums, thread_var=thread_var) + subbyte_adaptor = self._get_subbyte_operand_adaptor() - in_dtype = self.in_dtype + a_local_dtype = subbyte_adaptor.get_storage_dtype("A") if subbyte_adaptor is not None else self.in_dtype + b_local_dtype = subbyte_adaptor.get_storage_dtype("B") if subbyte_adaptor is not None else self.in_dtype warp_rows = mma_emitter.warp_rows warp_cols = mma_emitter.warp_cols local_size_a = mma_emitter.local_size_a @@ -97,6 +177,9 @@ def lower( A_region = self.ARegion B_region = self.BRegion C_region = self.CRegion + if subbyte_adaptor is not None: + A_region = subbyte_adaptor.make_packed_region(A_region, "A") + B_region = subbyte_adaptor.make_packed_region(B_region, "B") A_buf = A_region.buffer B_buf = B_region.buffer @@ -117,8 +200,8 @@ def _gemm_ssr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), a_local_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), b_local_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): @@ -137,7 +220,10 @@ def _gemm_ssr() -> None: ) # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_buf, ki) + if subbyte_adaptor is not None: + mma_emitter.mma(A_local, B_local, C_buf) + else: + mma_emitter.mma(A_local, B_local, C_buf, ki) # Simplify to optimize the index computing # Must inline let statements to simplify the analysis @@ -152,7 +238,7 @@ def _gemm_srr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), a_local_dtype) for ki in T.serial(0, (block_K // micro_size_k)): if clear_accum: @@ -182,7 +268,7 @@ def _gemm_rsr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), b_local_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index 7aa46062b5..d42254590e 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -65,6 +65,8 @@ def map_torch_type(intype) -> torch.dtype: "torch.float4_e2m1fnx2 is not supported in this version of torchPlease upgrade torch >= 2.8.0" ) return torch.float4_e2m1fnx2 + elif intype == "int4": + return torch.int8 elif "float4" in intype: # PyTorch doesn't support float4, use int8 as storage type return torch.int8 From c797e41040c25e3c71d62f9d4cc648ec90821c78 Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Tue, 21 Apr 2026 01:21:45 +0800 Subject: [PATCH 099/156] [Bugfix] Correct index calculation in Software Pipeline pass (#2070) [Bugfix] Correct index calculation in inject_pipeline.cc Adjusted the index calculation in the inject_pipeline.cc file to account for the minimum value of the pipeline loop variable. This change ensures that the floormod operation correctly reflects the intended behavior when computing new indices for buffer access. --- src/transform/inject_pipeline.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index d0e832f94e..facf769f87 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -526,7 +526,9 @@ class PipelineBodyRewriter : public StmtExprMutator { } PrimExpr new_index = old_index + - floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; + floormod((pipeline_loop_->loop_var - pipeline_loop_->min), + new_buffer->shape[0]) * + offset; new_args.Set(i + 1, new_index); } } From 96c649fb68fb21ae7a50813aec37110956c664ad Mon Sep 17 00:00:00 2001 From: VitalyR Date: Tue, 21 Apr 2026 15:36:55 +0800 Subject: [PATCH 100/156] Add frontmatter for the build skill (#2068) * [Build] Add Agent Skills frontmatter to build skill The Agent Skills spec requires SKILL.md to contain YAML frontmatter followed by Markdown content, with name and description as required fields. Reference: https://agentskills.io/specification#skill-md-format * [Build] Rename repo skill to tilelang-build Rename the repository-local build skill directory from .agents/skills/build to .agents/skills/tilelang-build, and rename the skill itself to tilelang-build. This avoids collisions with the repository's build/ ignore rule and makes the skill's repository-specific scope explicit. --- .agents/skills/{build => tilelang-build}/SKILL.md | 5 +++++ 1 file changed, 5 insertions(+) rename .agents/skills/{build => tilelang-build}/SKILL.md (90%) diff --git a/.agents/skills/build/SKILL.md b/.agents/skills/tilelang-build/SKILL.md similarity index 90% rename from .agents/skills/build/SKILL.md rename to .agents/skills/tilelang-build/SKILL.md index 63dde07427..f474736fc1 100644 --- a/.agents/skills/build/SKILL.md +++ b/.agents/skills/tilelang-build/SKILL.md @@ -1,3 +1,8 @@ +--- +name: tilelang-build +description: Repository-specific build, rebuild, install, and test instructions for tilelang. Use when working in the tilelang repository and the correct commands are needed for building from source, reinstalling after changes, or running project tests. +--- + # Build & Install ## Installing / Rebuilding tilelang From b744da1ab8bb0ddcac9e19370550cb002db932e8 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 21 Apr 2026 17:00:19 +0800 Subject: [PATCH 101/156] Refactor ptx_ldmatrix to use tl.access_ptr with simplified signature (#2072) * Refactor ptx_ldmatrix to use tl.access_ptr with simplified signature Replace the legacy 8-parameter TVM-style ptx_ldmatrix API with a 4-parameter TileLang-style API that uses tl.access_ptr wrapping BufferLoad, preserving buffer metadata for analysis passes. - Change ptx_ldmatrix signature from (dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset) to (trans, num, src_access_ptr, dst_access_ptr) - Replace Buffer::access_ptr() with tl::access_ptr() in LowerLDSMCopy - Update get_ldmatrix_offset to return (row, col) coordinates instead of linear byte offsets - Update all 12 call sites across mma_macro_generator.py and mma_sp_macro_generator.py - Remove _dtype_forward wrapper from ptx_ldmatrix in tir/ir.py Co-Authored-By: Claude Opus 4.6 * Fix lint: remove unused variables and apply clang-format Co-Authored-By: Claude Opus 4.6 * Simplify src_indices assignment with ternary expressions Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 --- src/op/copy.cc | 13 ++- tilelang/intrinsics/mma_macro_generator.py | 109 ++++++------------ tilelang/intrinsics/mma_sp_layout.py | 8 +- tilelang/intrinsics/mma_sp_macro_generator.py | 68 +++++------ tilelang/intrinsics/utils.py | 12 +- tilelang/language/tir/ir.py | 2 +- tilelang/language/tir/op.py | 41 ++++--- 7 files changed, 109 insertions(+), 144 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index aeb057eb44..2eb8bef602 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1162,9 +1162,10 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, shared_coords = inv->Forward({local_index, thread_index}); } shared_coords.pop_back(); // remove rep - PrimExpr shared_addr = shared_tensor.access_ptr( - is_ldmatrix ? 1 : 2, DataType::Handle(), 1, - shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num)); + PrimExpr shared_addr = + Call(DataType::Handle(), tl::access_ptr(), + {BufferLoad(shared_tensor, shared_coords), PrimExpr(2 * num), + make_const(DataType::Int(32), is_ldmatrix ? 1 : 2)}); args.push_back(shared_addr); if (is_ldmatrix) { @@ -1174,8 +1175,10 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, // copy return LowerNormalCopy(T, analyzer); } - PrimExpr local_addr = local_tensor.access_ptr( - 2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num)); + PrimExpr local_addr = + Call(DataType::Handle(), tl::access_ptr(), + {BufferLoad(local_tensor, {local_iter * 2 * num}), + PrimExpr(2 * num), make_const(DataType::Int(32), 2)}); args.push_back(local_addr); } else { for (int i = 0; i < num; i++) { diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 3ac11a465f..ff6e427b65 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -500,24 +500,20 @@ def _warp_ldmatrix_a( trans = self.a_transposed for i in T.serial(warp_rows): - # Assign A_shared_buf_elem wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k - A_shared_buf_elem = ( - A_buf[tuple(A_other) + (A_base0 + wk, A_base1 + wi)] - if a_transposed - else A_buf[tuple(A_other) + (A_base0 + wi, A_base1 + wk)] - ) if ldmatrix_available: + row_off, col_off = get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed) + src_indices = ( + tuple(A_other) + (A_base0 + wk + row_off, A_base1 + wi + col_off) + if a_transposed + else tuple(A_other) + (A_base0 + wi + row_off, A_base1 + wk + col_off) + ) T.ptx_ldmatrix( - a_dtype, T.bool(trans), 4, - ".b16", - A_local_buf.data, - i * local_size_a, - T.access_ptr(A_shared_buf_elem, "r"), - get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + T.access_ptr(A_buf[src_indices], "r", extent=8), + T.access_ptr(A_local_buf[i * local_size_a], "w", extent=8), ) else: for j in T.serial(local_size_a): @@ -623,21 +619,18 @@ def _warp_ldmatrix_b( ) if ldmatrix_available: - B_shared_buf_elem = ( - B_buf[tuple(B_other) + (B_base0 + wi, B_base1 + wk)] + num = 4 if replicate_b else 2 + row_off, col_off = get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed) + src_indices = ( + tuple(B_other) + (B_base0 + wi + row_off, B_base1 + wk + col_off) if b_transposed - else B_buf[tuple(B_other) + (B_base0 + wk, B_base1 + wi)] + else tuple(B_other) + (B_base0 + wk + row_off, B_base1 + wi + col_off) ) - T.ptx_ldmatrix( - b_dtype, T.bool(trans), - 4 if replicate_b else 2, - ".b16", - B_local_buf.data, - i * local_size_b, - T.access_ptr(B_shared_buf_elem, "r"), - get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), + num, + T.access_ptr(B_buf[src_indices], "r", extent=2 * num), + T.access_ptr(B_local_buf[i * local_size_b], "w", extent=2 * num), ) else: @@ -1115,21 +1108,19 @@ def _warp_ldmatrix_a( tx, _, warp_m = self.extract_thread_binding(thread_binding) if transform_kind_a == TransformKind.NonTransform: for i in T.serial(warp_rows): + row_off, col_off = get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed) T.ptx_ldmatrix( - a_dtype, T.bool(False), 4, - ".b16", - A_local_buf.data, - i * local_size_a, T.access_ptr( A_shared_buf[ - warp_m * warp_row_tiles + i * micro_size_x, - rk * chunk + ki * micro_size_k, + warp_m * warp_row_tiles + i * micro_size_x + row_off, + rk * chunk + ki * micro_size_k + col_off, ], "r", + extent=8, ), - get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + T.access_ptr(A_local_buf[i * local_size_a], "w", extent=8), ) elif transform_kind_a == TransformKind.InterWarpTransform: for i in T.serial(warp_rows): @@ -1144,18 +1135,12 @@ def _warp_ldmatrix_a( (ri) % micro_size_x, (rj) % micro_size_k, ) - args = (ni, nj, nii, njj) if transform_kind_a > 0 else (ri, rj) - A_shared_elem = A_shared_buf[args] - + row_off, col_off = get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed) T.ptx_ldmatrix( - a_dtype, T.bool(False), 4, - ".b16", - A_local_buf.data, - i * local_size_a, - T.access_ptr(A_shared_elem, "r"), - get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + T.access_ptr(A_shared_buf[ni, nj, nii + row_off, njj + col_off], "r", extent=8), + T.access_ptr(A_local_buf[i * local_size_a], "w", extent=8), ) elif transform_kind_a == TransformKind.IntraWarpTransform: for i in T.serial(warp_rows): @@ -1170,17 +1155,13 @@ def _warp_ldmatrix_a( (ri) % micro_size_x, (rj) % micro_size_k, ) - A_shared_elem = A_shared_buf[ni, nj, nii, njj] - + row_off = (tx * local_size_a) // stride + col_off = (tx * local_size_a) % stride T.ptx_ldmatrix( - a_dtype, T.bool(False), 4, - ".b16", - A_local_buf.data, - i * local_size_a, - T.access_ptr(A_shared_elem, "r"), - tx * local_size_a, + T.access_ptr(A_shared_buf[ni, nj, nii + row_off, njj + col_off], "r", extent=8), + T.access_ptr(A_local_buf[i * local_size_a], "w", extent=8), ) elif transform_kind_a == TransformKind.LDMatrixTransform: for j in T.serial(warp_rows): @@ -1229,17 +1210,12 @@ def _warp_ldmatrix_b( warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * micro_size_k, ) - B_shared_elem = B_shared_buf[ri, rj] - + row_off, col_off = get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed) T.ptx_ldmatrix( - b_dtype, T.bool(False), 4, - ".b16", - B_local_buf.data, - j * local_size_b, - T.access_ptr(B_shared_elem, "r"), - get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), + T.access_ptr(B_shared_buf[ri + row_off, rj + col_off], "r", extent=8), + T.access_ptr(B_local_buf[j * local_size_b], "w", extent=8), ) elif transform_kind_b == TransformKind.InterWarpTransform: for j in T.serial(warp_cols): @@ -1254,17 +1230,12 @@ def _warp_ldmatrix_b( (ri) % micro_size_y, (rj) % micro_size_k, ) - B_shared_elem = B_shared_buf[ni, nj, nii, njj] - + row_off, col_off = get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed) T.ptx_ldmatrix( - b_dtype, T.bool(False), # TODO(lei): should be optimized 4, - ".b16", - B_local_buf.data, - j * local_size_b, - T.access_ptr(B_shared_elem, "r"), - get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), + T.access_ptr(B_shared_buf[ni, nj, nii + row_off, njj + col_off], "r", extent=8), + T.access_ptr(B_local_buf[j * local_size_b], "w", extent=8), ) elif transform_kind_b == TransformKind.IntraWarpTransform: for j in T.serial(warp_cols): @@ -1279,17 +1250,13 @@ def _warp_ldmatrix_b( (ri) % micro_size_y, (rj) % micro_size_k, ) - B_shared_elem = B_shared_buf[ni, nj, nii, njj] - + row_off = (tx * local_size_b) // stride + col_off = (tx * local_size_b) % stride T.ptx_ldmatrix( - b_dtype, T.bool(False), # TODO(lei): should be optimized 4, - ".b16", - B_local_buf.data, - j * local_size_b, - T.access_ptr(B_shared_elem, "r"), - tx * local_size_b, + T.access_ptr(B_shared_buf[ni, nj, nii + row_off, njj + col_off], "r", extent=8), + T.access_ptr(B_local_buf[j * local_size_b], "w", extent=8), ) elif transform_kind_b == TransformKind.LDMatrixTransform: local_size_dequantize = local_size_b // num_elems_per_byte diff --git a/tilelang/intrinsics/mma_sp_layout.py b/tilelang/intrinsics/mma_sp_layout.py index 58034e7fdb..73da1289ab 100644 --- a/tilelang/intrinsics/mma_sp_layout.py +++ b/tilelang/intrinsics/mma_sp_layout.py @@ -158,7 +158,7 @@ def get_ldmatrix_offset_b( if transposed: transform_func = ldmatrix_trans_32x8_to_shared_16x16_layout new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx else: raise ValueError("ldmatrix only supports B transposed for 32-bit dtype") elif dtype_bits == 16: @@ -166,15 +166,15 @@ def get_ldmatrix_offset_b( transform_func_trans = ldmatrix_trans_32x16_to_shared_16x32_layout if transposed: new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx else: new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx elif dtype_bits == 8: if transposed: transform_func = ldmatrix_trans_32x32_to_shared_shared_16x64_layout new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx else: raise ValueError("ldmatrix only supports B transposed for 8-bit dtype") else: diff --git a/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/intrinsics/mma_sp_macro_generator.py index 480a85601b..18a37b8e83 100644 --- a/tilelang/intrinsics/mma_sp_macro_generator.py +++ b/tilelang/intrinsics/mma_sp_macro_generator.py @@ -285,24 +285,20 @@ def _warp_ldmatrix_a( trans = self.a_transposed for i in T.serial(warp_rows): - # Assign A_shared_buf_elem wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR - A_shared_buf_elem = ( - A_buf[tuple(A_other) + (A_base0 + wk, A_base1 + wi)] - if a_transposed - else A_buf[tuple(A_other) + (A_base0 + wi, A_base1 + wk)] - ) if ldmatrix_available: + row_off, col_off = get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed) + src_indices = ( + tuple(A_other) + (A_base0 + wk + row_off, A_base1 + wi + col_off) + if a_transposed + else tuple(A_other) + (A_base0 + wi + row_off, A_base1 + wk + col_off) + ) T.ptx_ldmatrix( - a_dtype, T.bool(trans), 4, - ".b16", - A_local_buf.data, - i * local_size_a, - T.access_ptr(A_shared_buf_elem, "r"), - get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + T.access_ptr(A_buf[src_indices], "r", extent=8), + T.access_ptr(A_local_buf[i * local_size_a], "w", extent=8), ) else: for j in T.serial(local_size_a): @@ -443,44 +439,44 @@ def _warp_ldmatrix_b( ) if ldmatrix_available: - B_shared_buf_elem = ( - B_buf[tuple(B_other) + (B_base0 + wi, B_base1 + wk)] - if b_transposed - else B_buf[tuple(B_other) + (B_base0 + wk, B_base1 + wi)] - ) - if replicate_b: + row_off, col_off = get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed) + src_indices = ( + tuple(B_other) + (B_base0 + wi + row_off, B_base1 + wk + col_off) + if b_transposed + else tuple(B_other) + (B_base0 + wk + row_off, B_base1 + wi + col_off) + ) T.ptx_ldmatrix( - b_dtype, T.bool(trans), 4, - ".b16", - B_local_buf.data, - i * local_size_b, - T.access_ptr(B_shared_buf_elem, "r"), - get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed), + T.access_ptr(B_buf[src_indices], "r", extent=8), + T.access_ptr(B_local_buf[i * local_size_b], "w", extent=8), ) + row_off, col_off = get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed) + src_indices = ( + tuple(B_other) + (B_base0 + wi + row_off, B_base1 + wk + col_off) + if b_transposed + else tuple(B_other) + (B_base0 + wk + row_off, B_base1 + wi + col_off) + ) T.ptx_ldmatrix( - b_dtype, T.bool(trans), 4, - ".b16", - B_local_buf.data, - i * local_size_b + lift(local_size_b) // 2, - T.access_ptr(B_shared_buf_elem, "r"), - get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed), + T.access_ptr(B_buf[src_indices], "r", extent=8), + T.access_ptr(B_local_buf[i * local_size_b + lift(local_size_b) // 2], "w", extent=8), ) else: + row_off, col_off = get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed) + src_indices = ( + tuple(B_other) + (B_base0 + wi + row_off, B_base1 + wk + col_off) + if b_transposed + else tuple(B_other) + (B_base0 + wk + row_off, B_base1 + wi + col_off) + ) T.ptx_ldmatrix( - b_dtype, T.bool(trans), 4, - ".b16", - B_local_buf.data, - i * local_size_b, - T.access_ptr(B_shared_buf_elem, "r"), - get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed), + T.access_ptr(B_buf[src_indices], "r", extent=8), + T.access_ptr(B_local_buf[i * local_size_b], "w", extent=8), ) else: diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index fb24a4add2..128d9819e7 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -33,11 +33,11 @@ def get_ldmatrix_offset( if matrix == "B" and transposed: transform_func = ldmatrix_32x4_to_shared_16x8_layout_b new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx elif matrix == "A" and not transposed: transform_func = ldmatrix_32x4_to_shared_16x8_layout_a new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx else: raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") elif dtype_bits == 16: @@ -45,19 +45,19 @@ def get_ldmatrix_offset( transform_func_trans = ldmatrix_trans_32x8_to_shared_16x16_layout if transposed: new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx else: new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx elif dtype_bits == 8: if matrix == "B" and transposed: transform_func = ldmatrix_32x16_to_shared_16x32_layout_b new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx elif matrix == "A" and not transposed: transform_func = ldmatrix_32x16_to_shared_16x32_layout_a new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + return new_row_idx, new_col_idx else: raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") else: diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index 1e35f77b07..6fab09ae51 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -292,7 +292,7 @@ def wrapped(*args, **kwargs): ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss) ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts) -ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) +ptx_ldmatrix = _tir_op.ptx_ldmatrix ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) mma_store = _dtype_forward(_tir_op.mma_store) diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index ebada05f9c..ca70d63673 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1312,42 +1312,41 @@ def mma_fill(dtype, local_size, local_ptr, offset): return _tvm_op.mma_fill(dtype, local_size, local_ptr, offset) -def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset): - """TVM intrinsic for ptx load matrix from shared memory +def ptx_ldmatrix(trans, num, src_access_ptr, dst_access_ptr): + """TileLang intrinsic for ptx load matrix from shared memory + + Uses `tl.ptx_ldmatrix` which expects access pointers created via + `T.access_ptr` (i.e. `tl.access_ptr` wrapping a `BufferLoad`). + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix Parameters ---------- - dtype : str - The data type of the result. - trans : bool The matrix is loaded in column-major format. num : IntImm - The number of matrices. - - type : Literal[".b16"] - The data type of the matrices. - - local_ptr : Var - The local pointer variable. + The number of matrices (2 or 4). - local_offset : Expr - The offset of local pointer. - - smem_ptr : Var - The shared memory pointer variable. + src_access_ptr : PrimExpr + A `tl.access_ptr` pointing to the source (shared memory) buffer. - smem_offset : Expr - The offset of shared memort pointer. + dst_access_ptr : PrimExpr + A `tl.access_ptr` pointing to the destination (local/register) buffer. Returns ------- call : PrimExpr - The call expression. + The call expression (handle-typed). """ - return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset) + return tvm.tir.call_intrin( + "handle", + tvm.tir.op.Op.get("tl.ptx_ldmatrix"), + trans, + num, + src_access_ptr, + dst_access_ptr, + ) def ptx_cp_async(dst_access_ptr, src_access_ptr, num_elems, predicate=None): From 046b1bdc698623efef4d45afd9b0ce4d41ff0982 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 21 Apr 2026 18:09:24 +0800 Subject: [PATCH 102/156] [FFI] Remove upper version bound on apache-tvm-ffi (#2071) * ugrade tvm * Remove upper version bound on apache-tvm-ffi The <0.1.10 pin was introduced to avoid a derived_object regression, which has since been resolved. Removing the cap allows compatibility with newer versions of apache-tvm-ffi. * Enhance ptxas usage level handling in CUDA compilation Added type conversion for ptxas_usage_level in both tilelang_callback_cuda_compile and LibraryGenerator classes to ensure it is treated as an integer when specified. This change improves the robustness of the configuration handling for CUDA compilation. * Fix IR string representation issues from tvm-ffi upgrade The new tvm-ffi version changes str() output for IR nodes from compact script format to verbose repr format. This caused two issues: 1. T.call_packed("name") is now tir.tvm_call_packed with value="name" 2. tir.Var str() now produces "tir.Var(span=None, ...)" instead of name Co-Authored-By: Claude Opus 4.6 * Simplify TMA descriptor globalAddress handling in cutedsl wrapper Since global_address is now converted to string by pythonic_expr_func in parse_tma_descriptor_args, remove the unnecessary tir.Var type check and .name extraction. Co-Authored-By: Claude Opus 4.6 * Refactor CUDA test for vectorization and update IR string representation Updated the test for vectorization in CUDA by replacing the direct function definition with a JIT-compiled kernel. Additionally, modified the string representation of IR modules in multiple tests to use the script format instead of the default string format, ensuring consistency and clarity in assertions. --------- Co-authored-by: Claude Opus 4.6 --- 3rdparty/tvm | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- .../language/test_tilelang_language_let.py | 24 ++++++------- .../test_tilelang_transform_thread_sync.py | 36 +++++++++---------- tilelang/engine/lower.py | 2 ++ tilelang/jit/adapter/cutedsl/wrapper.py | 7 ++-- tilelang/jit/adapter/libgen.py | 2 ++ tilelang/jit/adapter/utils.py | 1 + tilelang/jit/adapter/wrapper.py | 5 ++- 11 files changed, 45 insertions(+), 40 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 882a774844..0e15b274bc 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 882a774844993d103ae6e317ba3c7bbb5952b662 +Subproject commit 0e15b274bce8b46f971abf5ac390e844aa6acee5 diff --git a/pyproject.toml b/pyproject.toml index 601f2e35fa..80aba41e32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ # >=0.1.6 fixes a memory issue: tilelang#1502, but keep # requirement as wide as possible to be compatible with other libraries # pip will try to use latest version whenever possible. - "apache-tvm-ffi~=0.1.0,>=0.1.2,<0.1.10", + "apache-tvm-ffi~=0.1.0,>=0.1.2", # torch-c-dlpack-ext provides prebuilt torch extensions. # Without it, TVM FFI may require JIT compilation on first import. "torch-c-dlpack-ext; python_version < '3.14'", diff --git a/requirements-dev.txt b/requirements-dev.txt index a74959409a..f8dccdc871 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ # Requirements to run local build with `--no-build-isolation` or other developments -apache-tvm-ffi~=0.1.0,>=0.1.2,<0.1.10 +apache-tvm-ffi~=0.1.0,>=0.1.2 build cmake>=3.26 cython>=3.1.0 diff --git a/requirements.txt b/requirements.txt index 37023a758f..2dbe070d9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Runtime requirements -apache-tvm-ffi~=0.1.0,>=0.1.2,<0.1.10 +apache-tvm-ffi~=0.1.0,>=0.1.2 torch-c-dlpack-ext; python_version < '3.14' cloudpickle ml-dtypes diff --git a/testing/python/language/test_tilelang_language_let.py b/testing/python/language/test_tilelang_language_let.py index e1f3f394b5..74ec0901c6 100644 --- a/testing/python/language/test_tilelang_language_let.py +++ b/testing/python/language/test_tilelang_language_let.py @@ -3,20 +3,20 @@ from tilelang import language as T -@tilelang.testing.requires_cuda -def test_let_vectorize_load(): - @T.prim_func - def main(A_ptr: T.handle): - A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16) +@tilelang.jit +def test_kernel( + A: T.Tensor((16, 16), dtype=T.float32), +): + for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): + b = A[0, 0:4] + A[0, 4:8] = b - for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): - for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): - b = A[0, 0:4] - A[0, 4:8] = b - mod = tvm.IRModule({"main": main}) - mod = tvm.compile(mod, target="cuda") - assert "float4 b" in mod.mod.imports[0].inspect_source() +@tilelang.testing.requires_cuda +def test_let_vectorize_load(): + kernel_source = test_kernel.get_kernel_source() + assert "float4 b" in kernel_source if __name__ == "__main__": diff --git a/testing/python/transform/test_tilelang_transform_thread_sync.py b/testing/python/transform/test_tilelang_transform_thread_sync.py index 25f9c2aab1..08ec54e451 100644 --- a/testing/python/transform/test_tilelang_transform_thread_sync.py +++ b/testing/python/transform/test_tilelang_transform_thread_sync.py @@ -52,7 +52,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync inserted for atomic ops:\n{s}" @@ -90,7 +90,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) - s = str(mod) + s = str(mod.script()) assert 'T.tvm_storage_sync("shared.dyn")' not in s, f"Unexpected sync inserted for single atomic op:\n{s}" @@ -115,7 +115,7 @@ def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) result_local[0] = result_local[0] + temp_shared[0] mod = run_passes(func) - assert "T.tvm_storage_sync" in str(mod) + assert "T.tvm_storage_sync" in str(mod.script()) @tilelang.testing.requires_cuda @@ -137,7 +137,7 @@ def func() -> None: result_local[0] = temp_shared[threadIdx_x] mod = run_passes(func) - assert "T.tvm_storage_sync" in str(mod) + assert "T.tvm_storage_sync" in str(mod.script()) @tilelang.testing.requires_cuda @@ -162,7 +162,7 @@ def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) result_local[0] = result_local[0] + temp_shared[0] * p1[1] mod = run_passes(func) - assert "T.tvm_storage_sync" in str(mod) + assert "T.tvm_storage_sync" in str(mod.script()) @tilelang.testing.requires_cuda @@ -326,7 +326,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) - s = str(mod) + s = str(mod.script()) assert 'T.tvm_storage_sync("shared.dyn")' in s # Ensure the sync appears before the unrolled loop assert s.index('T.tvm_storage_sync("shared.dyn")') < s.index("for i in T.unroll(8)") @@ -359,7 +359,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) # Should NOT have sync inside the loop since A[tx] in iteration i # does not conflict with A[tx] in iteration i+1 (they're different threads' data) # The key insight: same thread writes and reads its own location @@ -399,7 +399,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) # Should have sync because thread tx reads from thread (tx+127)%128's location # This is a WAR hazard across threads assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync for cross-thread dependency:\n{s}" @@ -433,7 +433,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) # Should NOT have sync inside loop due to modulo buffering analysis # Note: This test verifies the modulo analysis capability print(f"Modulo buffering result:\n{s}") @@ -467,7 +467,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) print(f"Different indices result:\n{s}") @@ -505,7 +505,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) # Sync should appear before the if statement assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" # The sync should be before the if, not inside it @@ -545,7 +545,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) # Sync should appear before the if statement assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" # The sync should be before the if that checks token_ids @@ -581,7 +581,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) # Should have sync (either inside or outside the if is fine for uniform condition) assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" @@ -605,7 +605,7 @@ def func(flags: T.Buffer((4,), "int32")): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) assert s.count('T.tvm_storage_sync("shared")') == 1, f"Expected exactly one sync:\n{s}" if_pos = s.index("if flags[bx] > 0") sync_pos = s.index('T.tvm_storage_sync("shared")') @@ -636,7 +636,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" # Sync should be before the outermost non-uniform if sync_pos = s.index('T.tvm_storage_sync("shared")') @@ -667,7 +667,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" # Sync should be before the if inside the loop, not inside the if # This ensures all threads can reach the sync point @@ -697,7 +697,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) # No sync needed - only local memory is accessed assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync:\n{s}" @@ -724,7 +724,7 @@ def func(): mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared")(mod) - s = str(mod) + s = str(mod.script()) assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}" # Sync should be before the if inside the loop, not inside the if sync_pos = s.index('T.tvm_storage_sync("shared")') diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index d7b456386b..b39a58c6fb 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -112,6 +112,8 @@ def tilelang_callback_cuda_compile(code, target, pass_config=None): enable_fast_math = bool(cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH, False)) ptxas_usage_level = cfg.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None) + if ptxas_usage_level is not None: + ptxas_usage_level = int(ptxas_usage_level) verbose_ptxas_output = bool(cfg.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)) options = [ diff --git a/tilelang/jit/adapter/cutedsl/wrapper.py b/tilelang/jit/adapter/cutedsl/wrapper.py index 95c8cf0733..da3f2dd98d 100644 --- a/tilelang/jit/adapter/cutedsl/wrapper.py +++ b/tilelang/jit/adapter/cutedsl/wrapper.py @@ -1287,11 +1287,8 @@ def _process_tma_descriptors(self, desc_names: list[str]) -> tuple[list[str], di for desc_name in desc_names: info = self.tma_desc_info[desc_name] - # Extract the base buffer variable name (must be a Var, not arbitrary expression) - global_addr = info["globalAddress"] - if not isinstance(global_addr, tvm.tir.Var): - raise ValueError(f"TMA globalAddress must be a buffer Var, got {type(global_addr)}: {global_addr}") - tensor_name = global_addr.name + # Extract the base buffer variable name + tensor_name = info["globalAddress"] if tensor_name not in tensor_args: tensor_args.append(tensor_name) diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 20c56bbd6b..761fc44248 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -62,6 +62,8 @@ def compile_lib(self, timeout: float = None): enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False) ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None) + if ptxas_usage_level is not None: + ptxas_usage_level = int(ptxas_usage_level) verbose_ptxas_output = self.pass_configs.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False) command = [ diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index d43adf840a..5c730b3790 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -432,6 +432,7 @@ def parse_tma_descriptor_args( if not isinstance(tensor_rank, int) or tensor_rank <= 0: raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer") + global_address = pythonic_expr_func(global_address) params = TMADescriptorParams(handle_name, dtype, tensor_rank, global_address, is_img2col) if not is_img2col: diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index 9cf5658206..3b15f2609a 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -513,7 +513,10 @@ def parse_source_information(self): host_code = str(func) for function_name in function_names: - index = host_code.index(f'T.call_packed("{function_name}"') + try: + index = host_code.index(f'T.call_packed("{function_name}"') + except ValueError: + index = host_code.index(f'value="{function_name}"') function_names_index[function_name] = index # sort function_names function_names = sorted(function_names, key=lambda x: function_names_index[x]) From 9c95a42dfb804faaa0f8c895c76f7467294a9f7b Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 21 Apr 2026 18:27:15 +0800 Subject: [PATCH 103/156] [Refactor] Phaseout legacy util `map_torch_type` with `T.dtype.as_torch` (#2075) * Remove redundant `map_torch_type` in favor of `T.dtype.as_torch()` `map_torch_type` was a standalone function that converted dtype strings or TVM DataType objects to `torch.dtype`. This was redundant since `T.dtype.as_torch()` (and `T.dtype(x).as_torch()`) already provides the same functionality with better handling (HIP/CUDA awareness, proper dtype lookup tables, etc.). Changes: - Remove `map_torch_type` from `tilelang/utils/tensor.py` and its re-export - Make `T.dtype()` idempotent for existing dtype objects (pass-through) - Add backward-compat for `"e4m3fnuz_float8"` TVM internal name in `as_torch()` - Replace all usages across production code, tests, and examples - Rebuild Cython extension to reflect the import change * Refactor output and accumulation data types in GEMM FP8 example Updated the `out_dtype` and `accum_dtype` variables in the `example_tilelang_gemm_fp8_intrinsic.py` file to use `T.float32` instead of string literals. This change enhances type consistency and aligns with recent updates in type handling. --- .../example_deepgemm_fp8_2xAcc.py | 7 ++-- .../example_tilelang_gemm_fp8_intrinsic.py | 9 +++-- .../example_tilelang_gemm_fp8_sm100.py | 5 ++- .../python/jit/test_tilelang_jit_cutedsl.py | 13 ++++--- .../jit/test_tilelang_jit_gemm_cython.py | 25 +++++++------ .../python/jit/test_tilelang_jit_nullptr.py | 7 ++-- testing/python/jit/test_tilelang_jit_nvrtc.py | 13 ++++--- .../python/jit/test_tilelang_jit_tvm_ffi.py | 13 ++++--- .../test_tilelang_kernel_bf16_gemm_mma.py | 7 ++-- .../kernel/test_tilelang_kernel_fp8_gemm.py | 7 ++-- .../test_tilelang_kernel_fp8_gemm_mma.py | 7 ++-- .../test_tilelang_kernel_fp8_gemv_simt.py | 7 ++-- ...test_tilelang_kernel_gemm_mma_intrinsic.py | 7 ++-- .../kernel/test_tilelang_kernel_gemv_simt.py | 7 ++-- .../language/test_tilelang_language_clear.py | 5 ++- .../language/test_tilelang_language_ptr.py | 25 +++++++------ .../test_tilelang_tilelibrary_gemm_sp.py | 10 +++--- .../test_tilelang_tilelibrary_gemm_sp_v2.py | 26 +++++++------- tilelang/jit/adapter/cython/adapter.py | 3 +- .../jit/adapter/cython/cython_wrapper.pyx | 1 - tilelang/language/dtypes.py | 4 ++- tilelang/utils/__init__.py | 2 +- tilelang/utils/tensor.py | 35 ------------------- 23 files changed, 97 insertions(+), 148 deletions(-) diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index 18467a8118..0d72ed3678 100644 --- a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -4,7 +4,6 @@ import tilelang.testing import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type tilelang.testing.set_random_seed(42) @@ -148,9 +147,9 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp # src_code is the generated cuda source assert src_code is not None - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - accum_dtype = map_torch_type(accum_dtype) + in_dtype = in_dtype.as_torch() + out_dtype = out_dtype.as_torch() + accum_dtype = accum_dtype.as_torch() A = torch.randn(M, K).to(torch.bfloat16).cuda() B = torch.randn(N, K).to(torch.bfloat16).cuda() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 17a606fc9e..d9f749d9f2 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -6,7 +6,6 @@ from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter -from tilelang.utils.tensor import map_torch_type from tilelang.utils import determine_fp8_type tilelang.testing.set_random_seed(0) @@ -195,9 +194,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - accum_dtype = map_torch_type(accum_dtype) + in_dtype = in_dtype.as_torch() + out_dtype = out_dtype.as_torch() + accum_dtype = accum_dtype.as_torch() if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() @@ -234,7 +233,7 @@ def main(): def run_regression_perf(): M, N, K = 4096, 4096, 4096 - out_dtype, accum_dtype = "float32", "float32" + out_dtype, accum_dtype = T.float32, T.float32 in_dtype = determine_fp8_type() kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py index cb42d921ef..72f09c2503 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -1,7 +1,6 @@ import torch import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type def matmul( @@ -74,8 +73,8 @@ def calc_diff(x, y): threads = 256 for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]: for tvm_acc_dtype in [T.float16, T.float32]: # , torch.float16]: - torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) - torch_acc_dtype = map_torch_type(tvm_acc_dtype) + torch_fp8_dtype = tvm_fp8_dtype.as_torch() + torch_acc_dtype = tvm_acc_dtype.as_torch() print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype diff --git a/testing/python/jit/test_tilelang_jit_cutedsl.py b/testing/python/jit/test_tilelang_jit_cutedsl.py index 0cacddc58e..75c9e9012e 100644 --- a/testing/python/jit/test_tilelang_jit_cutedsl.py +++ b/testing/python/jit/test_tilelang_jit_cutedsl.py @@ -4,7 +4,6 @@ import tilelang import torch import pytest -from tilelang.utils.tensor import map_torch_type def matmul( @@ -180,8 +179,8 @@ def run_gemm_jit_kernel( matmul_kernel = tilelang.compile(program, out_idx=-1, target="cutedsl") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() A = torch.randn(M, K, dtype=in_dtype).cuda() B = torch.randn(K, N, dtype=in_dtype).cuda() @@ -282,8 +281,8 @@ def run_cutedsl_kernel_multi_stream( ) matmul_kernel = tilelang.compile(program, target="cutedsl") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() @@ -332,8 +331,8 @@ def run_cutedsl_dynamic_shape( if isinstance(K, T.Var): K = 768 - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() diff --git a/testing/python/jit/test_tilelang_jit_gemm_cython.py b/testing/python/jit/test_tilelang_jit_gemm_cython.py index 4925e66c72..eb9d9c66cb 100644 --- a/testing/python/jit/test_tilelang_jit_gemm_cython.py +++ b/testing/python/jit/test_tilelang_jit_gemm_cython.py @@ -4,7 +4,6 @@ import tilelang import torch import pytest -from tilelang.utils.tensor import map_torch_type def matmul( @@ -134,8 +133,8 @@ def run_gemm_jit_kernel( matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() A = torch.randn(M, K, dtype=in_dtype).cuda() B = torch.randn(K, N, dtype=in_dtype).cuda() @@ -234,8 +233,8 @@ def run_cython_kernel_multi_stream( matmul_kernel = tilelang.compile(program, execution_backend="cython") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() @@ -284,8 +283,8 @@ def run_cython_dynamic_shape( if isinstance(K, T.Var): K = 192 - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() @@ -334,8 +333,8 @@ def run_cython_dynamic_shape_with_out_idx( if isinstance(K, T.Var): K = 192 - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() @@ -411,8 +410,8 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B ) matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() @@ -482,8 +481,8 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans ) matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() diff --git a/testing/python/jit/test_tilelang_jit_nullptr.py b/testing/python/jit/test_tilelang_jit_nullptr.py index a9edb5e930..33820608da 100644 --- a/testing/python/jit/test_tilelang_jit_nullptr.py +++ b/testing/python/jit/test_tilelang_jit_nullptr.py @@ -3,7 +3,6 @@ import tilelang.testing import tilelang as tl import tilelang.language as T -from tilelang.utils import map_torch_type @tl.jit @@ -39,9 +38,9 @@ def main( def run_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): - a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) - b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) - c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) + a = torch.randn(M, K, device="cuda", dtype=dtype.as_torch()) + b = torch.randn(N, K, device="cuda", dtype=dtype.as_torch()) + c = torch.zeros(M, N, device="cuda", dtype=accum_dtype.as_torch()) kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False) kernel(a, b, c, None) diff --git a/testing/python/jit/test_tilelang_jit_nvrtc.py b/testing/python/jit/test_tilelang_jit_nvrtc.py index d8a6eb5e09..4f4d2bbc4a 100644 --- a/testing/python/jit/test_tilelang_jit_nvrtc.py +++ b/testing/python/jit/test_tilelang_jit_nvrtc.py @@ -4,7 +4,6 @@ import tilelang import torch import pytest -from tilelang.utils.tensor import map_torch_type def matmul( @@ -132,8 +131,8 @@ def run_gemm_jit_kernel( matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() A = torch.randn(M, K, dtype=in_dtype).cuda() B = torch.randn(K, N, dtype=in_dtype).cuda() @@ -234,8 +233,8 @@ def run_nvrtc_kernel_multi_stream( ) matmul_kernel = tilelang.compile(program, execution_backend="nvrtc") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() @@ -284,8 +283,8 @@ def run_nvrtc_dynamic_shape( if isinstance(K, T.Var): K = 768 - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() diff --git a/testing/python/jit/test_tilelang_jit_tvm_ffi.py b/testing/python/jit/test_tilelang_jit_tvm_ffi.py index 38642d0911..5d9a8256d4 100644 --- a/testing/python/jit/test_tilelang_jit_tvm_ffi.py +++ b/testing/python/jit/test_tilelang_jit_tvm_ffi.py @@ -4,7 +4,6 @@ import tilelang import torch import pytest -from tilelang.utils.tensor import map_torch_type def matmul( @@ -132,8 +131,8 @@ def run_gemm_jit_kernel( matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() A = torch.randn(M, K, dtype=in_dtype).cuda() B = torch.randn(K, N, dtype=in_dtype).cuda() @@ -232,8 +231,8 @@ def run_tvm_ffi_kernel_multi_stream( ) matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() @@ -281,8 +280,8 @@ def run_tvm_ffi_dynamic_shape( if isinstance(K, T.Var): K = 768 - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() diff --git a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py index 97d050b730..33eef09a56 100644 --- a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py @@ -9,7 +9,6 @@ TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func -from tilelang.utils.tensor import map_torch_type tilelang.testing.set_random_seed(0) @@ -190,9 +189,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - accum_dtype = map_torch_type(accum_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() + accum_dtype = T.dtype(accum_dtype).as_torch() if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py index 276083b262..70d7a3f286 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py @@ -1,7 +1,6 @@ import torch import tilelang.testing import tilelang.language as T -from tilelang.utils.tensor import map_torch_type def calc_diff(x, y): @@ -38,12 +37,12 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_ func = matmul_nt(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype) kernel = tilelang.compile(func, out_idx=-1) - A = torch.randn(M, K).to(map_torch_type(in_dtype)).cuda() - B = torch.randn(N, K).to(map_torch_type(in_dtype)).cuda() + A = torch.randn(M, K).to(T.dtype(in_dtype).as_torch()).cuda() + B = torch.randn(N, K).to(T.dtype(in_dtype).as_torch()).cuda() C = kernel(A, B) - ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)), B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype)) + ref_c = torch.matmul(A.to(T.dtype(accum_dtype).as_torch()), B.T.to(T.dtype(accum_dtype).as_torch())).to(T.dtype(out_dtype).as_torch()) print(C) print(ref_c) diff = calc_diff(C, ref_c) diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py index 9ba369b6b9..f8793ba2e9 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py @@ -9,7 +9,6 @@ TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func -from tilelang.utils.tensor import map_torch_type tilelang.testing.set_random_seed(0) @@ -190,9 +189,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - accum_dtype = map_torch_type(accum_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() + accum_dtype = T.dtype(accum_dtype).as_torch() if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py index 1a7a5e460a..1f819231f9 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py @@ -6,7 +6,6 @@ import tilelang.language as T from tilelang import JITKernel from tilelang.transform.simplify import apply_simplify -from tilelang.utils.tensor import map_torch_type from typing import Optional tilelang.testing.set_random_seed(0) @@ -128,9 +127,9 @@ def evaluate_gemv_simt( kernel = JITKernel(program, target="cuda") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - accum_dtype = map_torch_type(accum_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() + accum_dtype = T.dtype(accum_dtype).as_torch() if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py index dd1b75ebc5..7f7f36c51d 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py @@ -9,7 +9,6 @@ TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func -from tilelang.utils.tensor import map_torch_type tilelang.testing.set_random_seed(0) @@ -190,9 +189,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - accum_dtype = map_torch_type(accum_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() + accum_dtype = T.dtype(accum_dtype).as_torch() if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() diff --git a/testing/python/kernel/test_tilelang_kernel_gemv_simt.py b/testing/python/kernel/test_tilelang_kernel_gemv_simt.py index d211488cdc..4669cee51f 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemv_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_gemv_simt.py @@ -6,7 +6,6 @@ import tilelang.language as T from tilelang import JITKernel from tilelang.transform.simplify import apply_simplify -from tilelang.utils.tensor import map_torch_type from typing import Optional tilelang.testing.set_random_seed(0) @@ -128,9 +127,9 @@ def evaluate_gemv_simt( kernel = JITKernel(program, target="cuda") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - accum_dtype = map_torch_type(accum_dtype) + in_dtype = T.dtype(in_dtype).as_torch() + out_dtype = T.dtype(out_dtype).as_torch() + accum_dtype = T.dtype(accum_dtype).as_torch() if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() diff --git a/testing/python/language/test_tilelang_language_clear.py b/testing/python/language/test_tilelang_language_clear.py index c3e9df24e3..b9e618bf30 100644 --- a/testing/python/language/test_tilelang_language_clear.py +++ b/testing/python/language/test_tilelang_language_clear.py @@ -43,10 +43,9 @@ def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype= program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) kernel = tilelang.compile(program, out_idx=[2]) import torch - from tilelang.utils import map_torch_type - a = torch.randn((M, K), dtype=map_torch_type(dtype)).cuda() - b = torch.randn((N, K), dtype=map_torch_type(dtype)).cuda() + a = torch.randn((M, K), dtype=dtype.as_torch()).cuda() + b = torch.randn((N, K), dtype=dtype.as_torch()).cuda() c = kernel(a, b) assert torch.allclose(c, torch.zeros_like(c)) diff --git a/testing/python/language/test_tilelang_language_ptr.py b/testing/python/language/test_tilelang_language_ptr.py index 41c6fa9f4d..9ddba5c21d 100644 --- a/testing/python/language/test_tilelang_language_ptr.py +++ b/testing/python/language/test_tilelang_language_ptr.py @@ -4,7 +4,6 @@ import tilelang.testing import tilelang as tl import tilelang.language as T -from tilelang.utils import map_torch_type def matmul_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @@ -118,10 +117,10 @@ def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype= def ref_program(a, b): return (a @ b.T).to(torch.float32) - a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) - b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) - ffi_c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) - cython_c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) + a = torch.randn(M, K, device="cuda", dtype=dtype.as_torch()) + b = torch.randn(N, K, device="cuda", dtype=dtype.as_torch()) + ffi_c = torch.zeros(M, N, device="cuda", dtype=accum_dtype.as_torch()) + cython_c = torch.zeros(M, N, device="cuda", dtype=accum_dtype.as_torch()) ffi_jit_kernel(a, b, ffi_c, M, N, K) cython_jit_kernel(a.data_ptr(), b.data_ptr(), cython_c.data_ptr(), M, N, K) @@ -133,10 +132,10 @@ def run_pointer_table_copy(N, dtype=T.float16): program = pointer_table_copy_test(N, dtype) cython_jit_kernel = tl.compile(program, execution_backend="cython") ffi_jit_kernel = tl.compile(program, execution_backend="tvm_ffi") - src = torch.randn(N, device="cuda", dtype=map_torch_type(dtype)) + src = torch.randn(N, device="cuda", dtype=dtype.as_torch()) src_ptrs = torch.tensor([src.data_ptr()], device="cuda", dtype=torch.int64) - ffi_out = torch.empty(N, device="cuda", dtype=map_torch_type(dtype)) - cython_out = torch.empty(N, device="cuda", dtype=map_torch_type(dtype)) + ffi_out = torch.empty(N, device="cuda", dtype=dtype.as_torch()) + cython_out = torch.empty(N, device="cuda", dtype=dtype.as_torch()) ffi_jit_kernel(src_ptrs, ffi_out) cython_jit_kernel(src_ptrs, cython_out) @@ -150,11 +149,11 @@ def run_pointer_table_multi_copy(G, N, dtype=T.float16): program = pointer_table_multi_copy_test(G, N, dtype) cython_jit_kernel = tl.compile(program, execution_backend="cython") ffi_jit_kernel = tl.compile(program, execution_backend="tvm_ffi") - srcs = [torch.randn(N, device="cuda", dtype=map_torch_type(dtype)) for _ in range(G)] + srcs = [torch.randn(N, device="cuda", dtype=dtype.as_torch()) for _ in range(G)] src_ptrs = torch.tensor([src.data_ptr() for src in srcs], device="cuda", dtype=torch.int64) ref = torch.stack(srcs, dim=0) - ffi_out = torch.empty((G, N), device="cuda", dtype=map_torch_type(dtype)) - cython_out = torch.empty((G, N), device="cuda", dtype=map_torch_type(dtype)) + ffi_out = torch.empty((G, N), device="cuda", dtype=dtype.as_torch()) + cython_out = torch.empty((G, N), device="cuda", dtype=dtype.as_torch()) ffi_jit_kernel(src_ptrs, ffi_out) cython_jit_kernel(src_ptrs, cython_out) @@ -171,8 +170,8 @@ def run_pointer_table_grouped_matmul(batch_sizes_list, N, K, block_M, block_N, b ffi_jit_kernel = tl.compile(program, execution_backend="tvm_ffi", **compile_kwargs) device = "cuda" - torch_dtype = map_torch_type(dtype) - torch_accum_dtype = map_torch_type(accum_dtype) + torch_dtype = dtype.as_torch() + torch_accum_dtype = accum_dtype.as_torch() max_M = max(batch_sizes_list) batch_tile_offsets = [0] for size in batch_sizes_list[:-1]: diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index cfd0f75e0a..8ffffd8ce0 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -6,7 +6,7 @@ from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse from tilelang.layout import make_cutlass_metadata_layout -from tilelang.utils.tensor import torch_assert_close, map_torch_type +from tilelang.utils.tensor import torch_assert_close from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter torch.backends.cuda.matmul.allow_tf32 = False @@ -21,11 +21,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): low, high = (0, 4) if is_unsigned else (-2, 2) else: low, high = (0, 128) if is_unsigned else (-64, 64) - A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) - B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda") + A = randint_semi_sparse(M, K, low=low, high=high, dtype=T.dtype(in_dtype).as_torch(), device="cuda", transposed=trans_A) + B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=T.dtype(in_dtype).as_torch(), device="cuda") else: - A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(map_torch_type(in_dtype)) - B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype)) + A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(T.dtype(in_dtype).as_torch()) + B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(T.dtype(in_dtype).as_torch()) return A, B diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py index 6ec5718e8a..32742a005f 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -1,7 +1,7 @@ import pytest from tilelang import tvm as tvm from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse -from tilelang.utils.tensor import torch_assert_close, map_torch_type +from tilelang.utils.tensor import torch_assert_close from tilelang.layout import make_cutlass_metadata_layout from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter @@ -123,8 +123,8 @@ def _matmul(A, B): C = _matmul(A, B) torch_assert_close( - C_sp.to(map_torch_type(out_dtype)).to(torch.float32), - C.to(map_torch_type(out_dtype)).to(torch.float32), + C_sp.to(T.dtype(out_dtype).as_torch()).to(torch.float32), + C.to(T.dtype(out_dtype).as_torch()).to(torch.float32), rtol=1e-3, atol=1e-3, base_name="tilelang_sp", @@ -142,11 +142,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): low, high = (0, 4) if is_unsigned else (-2, 2) else: low, high = (0, 128) if is_unsigned else (-64, 64) - A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) - B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda") + A = randint_semi_sparse(M, K, low=low, high=high, dtype=T.dtype(in_dtype).as_torch(), device="cuda", transposed=trans_A) + B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=T.dtype(in_dtype).as_torch(), device="cuda") else: - A = randn_semi_sparse(M, K, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) - B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype)) + A = randn_semi_sparse(M, K, dtype=T.dtype(in_dtype).as_torch(), device="cuda", transposed=trans_A) + B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(T.dtype(in_dtype).as_torch()) return A, B @@ -288,8 +288,8 @@ def _matmul(A, B): C = _matmul(A, B) torch_assert_close( - C_sp.to(map_torch_type(out_dtype)).to(torch.float32), - C.to(map_torch_type(out_dtype)).to(torch.float32), + C_sp.to(T.dtype(out_dtype).as_torch()).to(torch.float32), + C.to(T.dtype(out_dtype).as_torch()).to(torch.float32), rtol=1e-3, atol=1e-3, base_name="tilelang_sp", @@ -436,8 +436,8 @@ def _matmul(A, B): C = _matmul(A, B) torch_assert_close( - C_sp.to(map_torch_type(out_dtype)).to(torch.float32), - C.to(map_torch_type(out_dtype)).to(torch.float32), + C_sp.to(T.dtype(out_dtype).as_torch()).to(torch.float32), + C.to(T.dtype(out_dtype).as_torch()).to(torch.float32), rtol=1e-3, atol=1e-3, base_name="tilelang_sp", @@ -588,8 +588,8 @@ def _matmul(A, B): C = _matmul(A, B) torch_assert_close( - C_sp.to(map_torch_type(out_dtype)).to(torch.float32), - C.to(map_torch_type(out_dtype)).to(torch.float32), + C_sp.to(T.dtype(out_dtype).as_torch()).to(torch.float32), + C.to(T.dtype(out_dtype).as_torch()).to(torch.float32), rtol=1e-3, atol=1e-3, base_name="tilelang_sp", diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index 912eb07b38..62bb0a6c33 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -18,7 +18,6 @@ from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target, is_metal_target from tilelang.utils.target import determine_target from tilelang.utils.language import retrieve_func_from_module -from tilelang.utils.tensor import map_torch_type logger = logging.getLogger(__name__) @@ -255,7 +254,7 @@ def _process_buffer_dtype(self) -> dict[tir.Var, tuple[int, torch.dtype]]: if param in buffer_map: buffer = buffer_map[param] name, dtype = buffer.name, buffer.dtype - buffer_dtype_map[name] = (i, map_torch_type(dtype)) + buffer_dtype_map[name] = (i, dtype.as_torch()) return buffer_dtype_map def _process_ptr_map(self) -> dict[int, str]: diff --git a/tilelang/jit/adapter/cython/cython_wrapper.pyx b/tilelang/jit/adapter/cython/cython_wrapper.pyx index b4d51fc916..03238f7808 100644 --- a/tilelang/jit/adapter/cython/cython_wrapper.pyx +++ b/tilelang/jit/adapter/cython/cython_wrapper.pyx @@ -6,7 +6,6 @@ import ctypes from libc.stdint cimport int64_t, uintptr_t from libc.stdlib cimport malloc, free from tvm import tir -from tilelang.utils.tensor import map_torch_type cdef class CythonKernelWrapper: # Class attributes to store kernel configuration and library reference diff --git a/tilelang/language/dtypes.py b/tilelang/language/dtypes.py index 7b34e3fa8d..02b289ec49 100644 --- a/tilelang/language/dtypes.py +++ b/tilelang/language/dtypes.py @@ -194,7 +194,7 @@ def __dtype_as_torch__(self: dtype) -> torch.dtype: elif dtype_str == "float8_e5m2": assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torch. Please upgrade torch >= 2.1.0" return torch.float8_e5m2 - elif dtype_str == "float8_e4m3fnuz": + elif dtype_str in ("float8_e4m3fnuz", "e4m3fnuz_float8"): assert hasattr(torch, "float8_e4m3fnuz"), ( "torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0" ) @@ -227,6 +227,8 @@ def __dtype_as_torch__(self: dtype) -> torch.dtype: def __dtype_new__(cls, value: AnyDType) -> dtype: + if isinstance(value, dtype): + return value if isinstance(value, str): return __orig_dtype_new(cls, _CANONICAL_TO_DISPLAY_STR.get(value, value)) elif value in _DTYPE_TO_STR: diff --git a/tilelang/utils/__init__.py b/tilelang/utils/__init__.py index dacd29a95d..81398db262 100644 --- a/tilelang/utils/__init__.py +++ b/tilelang/utils/__init__.py @@ -5,7 +5,7 @@ determine_fp8_type, determine_torch_fp8_type, ) -from .tensor import TensorSupplyType, torch_assert_close, map_torch_type # noqa: F401 +from .tensor import TensorSupplyType, torch_assert_close # noqa: F401 from .language import ( is_global, # noqa: F401 is_shared, # noqa: F401 diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index d42254590e..27bac38e98 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -39,41 +39,6 @@ class TensorSupplyType(Enum): Auto = 7 -def map_torch_type(intype) -> torch.dtype: - # Convert to string if needed - if not isinstance(intype, str): - intype = str(intype) - - if intype == "float8_e4m3": - assert hasattr(torch, "float8_e4m3fn"), "torch.float8_e4m3fn is not supported in this version of torchPlease upgrade torch >= 2.1.0" - return torch.float8_e4m3fn - elif intype == "float8_e5m2": - assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torchPlease upgrade torch >= 2.1.0" - return torch.float8_e5m2 - elif intype == "e4m3fnuz_float8": - assert hasattr(torch, "float8_e4m3fnuz"), ( - "torch.float8_e4m3fnuz is not supported in this version of torchPlease upgrade torch >= 2.2.0" - ) - return torch.float8_e4m3fnuz - elif intype == "float8_e8m0fnu": - assert hasattr(torch, "float8_e8m0fnu"), ( - "torch.float8_e8m0fnu is not supported in this version of torchPlease upgrade torch >= 2.8.0" - ) - return torch.float8_e8m0fnu - elif intype == "float4_e2m1fnx2": - assert hasattr(torch, "float4_e2m1fnx2"), ( - "torch.float4_e2m1fnx2 is not supported in this version of torchPlease upgrade torch >= 2.8.0" - ) - return torch.float4_e2m1fnx2 - elif intype == "int4": - return torch.int8 - elif "float4" in intype: - # PyTorch doesn't support float4, use int8 as storage type - return torch.int8 - else: - return getattr(torch, intype) - - def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): from tilelang.engine.param import KernelParam from .device import get_current_device From 64bd7420714f14911f095b72add5d9df16f72744 Mon Sep 17 00:00:00 2001 From: Chenhao Xu <122071158+bucket-xv@users.noreply.github.com> Date: Tue, 21 Apr 2026 19:49:06 +0800 Subject: [PATCH 104/156] [Bugfix] Fix reduce layout (#2074) * fix: fix reduce layout * test: add targeting test --- src/layout/utils.cc | 4 +-- src/op/reduce.cc | 3 ++- .../test_tilelang_language_reshape.py | 26 +++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/layout/utils.cc b/src/layout/utils.cc index 92cda64f7b..73236d2e7c 100644 --- a/src/layout/utils.cc +++ b/src/layout/utils.cc @@ -325,8 +325,8 @@ std::pair CompressIterator(const PrimExpr &expr, collector.Collect({iter_sum}); IterMark mark; for (const IterMark &m : collector.visited_) { - ICHECK(m->source.as()) << "Not a normalized iterator: " << mark; - if (m->source.as().value().same_as(var)) { + auto v = m->source.as(); + if (v && v.value().same_as(var)) { mark = m; break; } diff --git a/src/op/reduce.cc b/src/op/reduce.cc index dea9dbc822..aff8ba1986 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -403,7 +403,8 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer); for (const auto &iter_split : iter_sum->args) { auto mark = iter_split->source->source.as(); - ICHECK(mark) << "Not a normalized iterator: " << iter_split->source; + if (!mark) + continue; if (mark.value().same_as(src_vars[this->dim]->var)) { // `scale` is the stride of participating threads in the thread index // space. When the thread-to-data mapping for the reduce dimension is diff --git a/testing/python/language/test_tilelang_language_reshape.py b/testing/python/language/test_tilelang_language_reshape.py index c7ff50c145..27388911b7 100644 --- a/testing/python/language/test_tilelang_language_reshape.py +++ b/testing/python/language/test_tilelang_language_reshape.py @@ -260,5 +260,31 @@ def test_reshape_shape_mismatch(): reshape_shape_mismatch_test(1024, 32, T.float32) +def test_reduce_absmax_after_reshape_3d(): + M, N, num_groups, num_per_channels = 2, 384, 3, 128 + threads = 128 + dtype = "int64" + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, num_groups), dtype), + ): + with T.Kernel(1, threads=threads) as _: + A_local = T.alloc_fragment((M, N), dtype) + A_reshaped = T.reshape(A_local, [M, num_groups, num_per_channels]) + B_local = T.alloc_fragment((M, num_groups), dtype) + T.copy(A, A_local) + T.reduce_absmax(A_reshaped, B_local, dim=2) + T.copy(B_local, B) + + jit_kernel = tl.compile(main, out_idx=-1) + A_torch = torch.randint(-100, 100, (M, N), dtype=getattr(torch, dtype)).cuda() + B_torch = jit_kernel(A_torch) + + ref = A_torch.abs().reshape(M, num_groups, num_per_channels).max(dim=2).values + torch.testing.assert_close(B_torch, ref) + + if __name__ == "__main__": tilelang.testing.main() From 948d38a2c4c7b1e36ab1a56c2d72c4c705bf05b1 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 21 Apr 2026 20:51:19 +0800 Subject: [PATCH 105/156] [Refactor] Disable unhelpful warning print (#2077) disable warning for swizzle mergeing --- src/transform/layout_inference.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 6ee867fe8a..4cfdb6bf82 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -257,8 +257,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { const Layout &existing = layout_map[buffer]; if (!layout.as() && !existing.as()) { if (auto merged = MergeSwizzleLayouts(existing, layout, buffer)) { - LOG(WARNING) << "Swizzle layout conflict for buffer " << buffer - << ", merging to smaller granularity"; + DLOG(WARNING) << "Swizzle layout conflict for buffer " << buffer + << ", merging to smaller granularity"; layout_map.Set(buffer, merged.value()); propagate_alias(buffer, merged.value()); continue; From 15309f5cba9b975838fe8f7c64e67a6d64e765a7 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 22 Apr 2026 00:09:01 +0800 Subject: [PATCH 106/156] [CUDA] Improve int4 GEMM lowering and packed codegen support (#2073) * [Cleanup] Remove redundant int4 MMA code paths * [CI] Apply format.sh updates * lint fix * fix * Refactor FP8 type determination to utilize `T` module constants Updated the `determine_fp8_type` function to return constants from the `T` module instead of string literals for FP8 formats. This change enhances consistency and aligns with recent updates in type handling. Adjusted return statements for both HIP and CUDA scenarios to use `T.float8_*` constants. * Refactor FP8 type usage in tests to directly call `determine_fp8_type` Updated test files to replace `T.dtype(determine_fp8_type())` with direct calls to `determine_fp8_type()` for improved clarity and consistency in FP8 type handling. This change aligns with recent refactoring efforts in type determination. * Remove print statements from `assert_tl_matmul_correctness` and `test_assert_tl_matmul` functions to clean up test output. This change enhances the clarity of test results by eliminating unnecessary console output. * Enhance CI workflow for CuTeDSL examples with Python 3.12 and CUDA-12.8 - Added a step to list generated files and clean up Python cache directories. - Introduced a new job for running CuTeDSL examples on self-hosted NVIDIA runners. - Set up environment variables for CUDA and Python, including caching configurations. - Implemented error handling to clear cache on setup failure. - Updated steps for installing dependencies and running examples, ensuring compatibility with the latest Python and CUDA versions. --- .github/workflows/ci.yml | 126 +++++++- .../gemm_int4/example_tilelang_gemm_int4.py | 69 +++- src/target/codegen_cuda.cc | 172 +++++++--- src/tl_templates/cuda/common.h | 38 +++ src/tl_templates/cuda/instruction/mma.h | 6 +- .../amd/test_tilelang_gemm_mfma_intrinsic.py | 12 +- .../amd/test_tilelang_gemm_mfma_preshuffle.py | 8 +- .../python/debug/test_tilelang_debug_print.py | 4 +- tilelang/intrinsics/mma_macro_generator.py | 303 ++---------------- tilelang/intrinsics/utils.py | 10 +- tilelang/tileop/gemm/gemm_mma.py | 114 +------ tilelang/utils/target.py | 15 +- 12 files changed, 410 insertions(+), 467 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fba08c54a6..4b319a7218 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -355,10 +355,128 @@ jobs: -k metal \ ./python - # CuTeDSL backend: run examples with TILELANG_TARGET=cutedsl - # Placed after core test steps so a CuTeDSL failure doesn't skip them. - - name: Run CuTeDSL examples with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) - if: ${{ !cancelled() && contains(matrix.runner.toolkit, 'CUDA') }} + - name: List generated files + if: ${{ !cancelled() }} + run: | + find . -type f -name '*.py[co]' -delete + find . -depth -type d -name "__pycache__" -exec rm -r "{}" + + if git status --ignored --porcelain | grep -qvE '/$'; then + ls -alh $(git status --ignored --porcelain | grep -vE '/$' | grep -oE '\S+$') + fi + + cutedsl: + name: CuTeDSL Examples for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia) + if: | + github.repository_owner == 'tile-ai' && + (github.event_name != 'pull_request' || !github.event.pull_request.draft) + needs: [tests] + runs-on: [self-hosted, nvidia] + timeout-minutes: 120 + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: recursive + + - name: Set environment (self-hosted runners) + run: | + # Hide sensitive data in logs for self-hosted runners + if [[ -n "${{ secrets.SECRET_PATH_PREFIXES }}" ]]; then + echo "::add-mask::${{ secrets.SECRET_PATH_PREFIXES }}" + # Colon separated list of secrets to mask + for secret in $(echo "${{ secrets.SECRET_PATH_PREFIXES }}" | tr ':' '\n'); do + echo "::add-mask::${secret}" + done + fi + + # Use runner tool_cache as cache root for self-hosted runners to avoid internet connection + # issues and to share cache between jobs. + export XDG_CACHE_HOME="${{ runner.tool_cache }}/.ci-cache-${{ github.workflow }}" + echo "XDG_CACHE_HOME=${XDG_CACHE_HOME}" | tee -a "${GITHUB_ENV}" + echo "PIP_CACHE_DIR=${XDG_CACHE_HOME}/pip" | tee -a "${GITHUB_ENV}" + echo "UV_CACHE_DIR=${XDG_CACHE_HOME}/uv" | tee -a "${GITHUB_ENV}" + echo "PRE_COMMIT_HOME=${XDG_CACHE_HOME}/pip/.pre-commit" | tee -a "${GITHUB_ENV}" + + - name: Set environment (CUDA) + run: | + TOOLKIT="CUDA-12.8" + CUDA_VERSION="${TOOLKIT##*-}" + CUDA_VERSION_MAJMIN="$(echo ${CUDA_VERSION} | cut -d '.' -f-2)" + CUDA_VERSION_MAJMIN_NODOT="${CUDA_VERSION_MAJMIN//./}" + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}" + export UV_INDEX="${PIP_EXTRA_INDEX_URL}" + + echo "USE_CUDA=ON" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN=${CUDA_VERSION_MAJMIN}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN_NODOT=${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" + echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" + echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" + + if [[ ! -x "$(command -v nvcc)" ]]; then + export PATH="/usr/local/cuda/bin:${PATH}" + export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + echo "PATH=${PATH}" | tee -a "${GITHUB_ENV}" + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" | tee -a "${GITHUB_ENV}" + fi + if [[ -x "$(command -v nvcc)" ]]; then + echo "\$ $(command -v nvcc) --version" && nvcc --version + else + echo "::warning::nvcc not found in PATH!" + fi + + - name: Setup Python and uv with caching + id: setup-uv + uses: astral-sh/setup-uv@v7 + with: + python-version: "3.12" + activate-environment: true + enable-cache: false + prune-cache: false + cache-local-path: ${{ env.UV_CACHE_DIR }} + ignore-nothing-to-cache: true + cache-suffix: uv-${{ runner.os }}-${{ runner.arch }}-3.12-self-hosted-nvidia-CUDA-12.8 + cache-dependency-glob: | + pyproject.toml + requirements*.txt + .pre-commit-config.yaml + + - name: Setup venv + id: setup-venv + run: | + set -o pipefail + + uv pip install --upgrade pip setuptools wheel + uv pip install -v -r requirements-test.txt + echo "import torch; print(f'torch: {torch.__version__}')" | uv run --no-project --script - + uv pip install --no-build-isolation-package=flash-attn -v -r requirements-test-cuda.txt + echo "import flash_attn; print(f'flash_attn: {flash_attn.__version__}')" | uv run --no-project --script - + echo "::group::torch.utils.collect_env" + uv run --no-project -m -- torch.utils.collect_env + echo "::endgroup::" + + - name: Clear uv cache for self-hosted runners (if setup failed) + if: >- + ${{ + failure() && + (steps.setup-uv.conclusion == 'failure' || steps.setup-venv.conclusion == 'failure') + }} + run: | + echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure." + uv cache clean + + - name: Install project (wheel form) + run: | + uv pip install -v . + + - name: Clean up stale /tmp files (self-hosted runners) + run: | + rm -f /tmp/tmp*.so /tmp/tmp*.cu /tmp/tmp*.cubin /tmp/tmp*.cpp + rm -rf /tmp/tvm-debug-mode-tempdirs /tmp/tilelang_cutedsl_* + + - name: Run CuTeDSL examples with Python 3.12 (CUDA-12.8) env: TILELANG_TARGET: cutedsl run: | diff --git a/examples/gemm_int4/example_tilelang_gemm_int4.py b/examples/gemm_int4/example_tilelang_gemm_int4.py index 4ad0fca710..3db000b616 100644 --- a/examples/gemm_int4/example_tilelang_gemm_int4.py +++ b/examples/gemm_int4/example_tilelang_gemm_int4.py @@ -4,23 +4,24 @@ - A/B are declared as T.int4 tensors - the matmul is expressed with T.gemm(...) -The example compiles the kernel and prints the generated CUDA source. +The example compiles the kernel, prints the generated CUDA source, and +checks correctness against a PyTorch reference. """ +import torch + import tilelang import tilelang.language as T -tilelang.disable_cache() - -def matmul_nt_int4(M, N, K, block_M, block_N, block_K): +def matmul_nt_int4(M, N, K, block_M, block_N, block_K, threads=128): @T.prim_func def main( A: T.Tensor((M, K), T.int4), B: T.Tensor((N, K), T.int4), C: T.Tensor((M, N), T.int32), ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), T.int4) B_shared = T.alloc_shared((block_N, block_K), T.int4) C_local = T.alloc_fragment((block_M, block_N), T.int32) @@ -44,16 +45,68 @@ def compile_int4_gemm( block_M=128, block_N=128, block_K=64, + threads=128, + print_cuda_source=True, ): - func = matmul_nt_int4(M, N, K, block_M, block_N, block_K) + func = matmul_nt_int4(M, N, K, block_M, block_N, block_K, threads) kernel = tilelang.compile(func, out_idx=-1) print("Compilation succeeded.") - print(kernel.get_kernel_source()) + if print_cuda_source: + print(kernel.get_kernel_source()) return func, kernel +def pack_int4(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype != torch.int8: + raise TypeError(f"Expected torch.int8 logical int4 tensor, but got {tensor.dtype}.") + if tensor.ndim == 0 or tensor.shape[-1] % 2 != 0: + raise ValueError("The last dimension of a logical int4 tensor must be even for int8 packing.") + + tensor_i16 = tensor.to(torch.int16) + packed = (tensor_i16[..., ::2] & 0x0F) | ((tensor_i16[..., 1::2] & 0x0F) << 4) + return packed.to(torch.int8).contiguous() + + +def check_int4_gemm_correctness( + M=1024, + N=1024, + K=1024, + block_M=128, + block_N=128, + block_K=64, + threads=128, +): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required to run the int4 GEMM example.") + + _, kernel = compile_int4_gemm( + M=M, + N=N, + K=K, + block_M=block_M, + block_N=block_N, + block_K=block_K, + threads=threads, + ) + + A_logical = torch.randint(-8, 8, (M, K), device="cuda", dtype=torch.int8) + B_logical = torch.randint(-8, 8, (N, K), device="cuda", dtype=torch.int8) + + A_packed = pack_int4(A_logical) + B_packed = pack_int4(B_logical) + C = kernel(A_packed, B_packed) + torch.cuda.synchronize() + + ref_c = torch.matmul(A_logical.cpu().to(torch.int32), B_logical.cpu().to(torch.int32).T) + torch.testing.assert_close(C.cpu(), ref_c, rtol=0, atol=0) + print("Correctness check passed.") + return C, ref_c + + def main(): - compile_int4_gemm() + # check_int4_gemm_correctness(M=16, N=16, K=32, block_M=16, block_N=16, block_K=32) + # check_int4_gemm_correctness(M=16, N=16, K=64, block_M=16, block_N=16, block_K=64) + check_int4_gemm_correctness() if __name__ == "__main__": diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 59f31e7297..317b01370e 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -813,7 +813,12 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) } case 4: { if (t.is_scalar()) { - os << "int"; + enable_int8_ = true; + if (!t.is_uint()) { + os << "signed char"; + } else { + os << "char"; + } return; } else if (t.lanes() == 4) { os << "int16_t"; @@ -1264,8 +1269,8 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t, ICHECK(i >= 0 && i < 256 / t.bits()); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.lanes() == 2 || t.lanes() == 3) { - stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value - << ");\n"; + stream << vec << '.' << access[i % t.lanes()] << "=" + << "(" << value << ");\n"; } else if (t.lanes() <= 16) { std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); stream << ac << "="; @@ -1724,15 +1729,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const MinNode *op, std::ostream &os) { // Standard min/max functions don't support bfloat16 or float16 if (t.is_bfloat16() && t.is_scalar()) { - os << "cutlass::bfloat16_t(__hmin(" << "(" << PrintExpr(op->a) - << ").to_nv_bfloat16(), " << "(" << PrintExpr(op->b) - << ").to_nv_bfloat16()))"; + os << "cutlass::bfloat16_t(__hmin(" + << "(" << PrintExpr(op->a) << ").to_nv_bfloat16(), " + << "(" << PrintExpr(op->b) << ").to_nv_bfloat16()))"; return; } if (t.is_float16() && t.is_scalar()) { - os << "cutlass::half_t(__hmin(" << "(" << PrintExpr(op->a) - << ").to_half(), " << "(" << PrintExpr(op->b) << ").to_half()))"; + os << "cutlass::half_t(__hmin(" + << "(" << PrintExpr(op->a) << ").to_half(), " + << "(" << PrintExpr(op->b) << ").to_half()))"; return; } @@ -1754,15 +1760,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const MaxNode *op, std::ostream &os) { // Standard min/max functions don't support bfloat16 or float16 if (t.is_bfloat16() && t.is_scalar()) { - os << "cutlass::bfloat16_t(__hmax(" << "(" << PrintExpr(op->a) - << ").to_nv_bfloat16(), " << "(" << PrintExpr(op->b) - << ").to_nv_bfloat16()))"; + os << "cutlass::bfloat16_t(__hmax(" + << "(" << PrintExpr(op->a) << ").to_nv_bfloat16(), " + << "(" << PrintExpr(op->b) << ").to_nv_bfloat16()))"; return; } if (t.is_float16() && t.is_scalar()) { - os << "cutlass::half_t(__hmax(" << "(" << PrintExpr(op->a) - << ").to_half(), " << "(" << PrintExpr(op->b) << ").to_half()))"; + os << "cutlass::half_t(__hmax(" + << "(" << PrintExpr(op->a) << ").to_half(), " + << "(" << PrintExpr(op->b) << ").to_half()))"; return; } @@ -1887,16 +1894,18 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, } std::string index_str = PrintExpr(index); if ((t.bits() == 4 && !t.is_float4()) || (t.bits() == 1 && t.is_int())) { - // This is a special case, because CodegenCUDA::PrintType() - // returns "int" for bool and for 4-bit integers. In most cases, - // we divide by the number of lanes to determine the index. - // However, the backing type for scalar int4 and scalar bool is - // int32. Therefore, we need to divide by the ratio of their - // sizes in that case. - int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes(); + // Scalar int4/uint4 storage is byte-packed (2 logical elements per byte). + // Vector int4 loads/stores reinterpret the underlying packed bytes as the + // requested vector type, so their index still advances by the vector lane + // count. Scalar int1 keeps the existing int32 backing. + int div_factor = t.lanes(); + if (t.lanes() == 1) { + div_factor = (t.bits() == 4) ? 2 : (32 / t.bits()); + } index_str = PrintExpr(arith::Analyzer().Simplify(truncdiv(index, div_factor))); - os << "*((" << ptr_cast(t) << vid << ")" << " + " << index_str << ")"; + os << "*((" << ptr_cast(t) << vid << ")" + << " + " << index_str << ")"; } else if (t == buffer_element_dtype) { int div_factor = 1; if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1) { @@ -2433,14 +2442,24 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { std::string B_dtype = Downcast(op->args[4])->value; std::string C_dtype = Downcast(op->args[5])->value; std::string a_ref = this->PrintExpr(op->args[6]); - std::string a_bias = this->PrintExpr(op->args[7]); std::string b_ref = this->PrintExpr(op->args[8]); - std::string b_bias = this->PrintExpr(op->args[9]); std::string c_ref = this->PrintExpr(op->args[10]); std::string c_bias = this->PrintExpr(op->args[11]); auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype); auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype); auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype); + PrimExpr a_bias_expr = op->args[7]; + PrimExpr b_bias_expr = op->args[9]; + if (dtype_a_enum == tl::codegen::ptx::DataType::kInt4 || + dtype_a_enum == tl::codegen::ptx::DataType::kUInt4) { + a_bias_expr = arith::Analyzer().Simplify(truncdiv(a_bias_expr, 2)); + } + if (dtype_b_enum == tl::codegen::ptx::DataType::kInt4 || + dtype_b_enum == tl::codegen::ptx::DataType::kUInt4) { + b_bias_expr = arith::Analyzer().Simplify(truncdiv(b_bias_expr, 2)); + } + std::string a_bias = this->PrintExpr(a_bias_expr); + std::string b_bias = this->PrintExpr(b_bias_expr); auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); need_mma_instruction_h_ = true; @@ -2470,7 +2489,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { if (BRegType == "float") { BRegType = "uint32_t"; } - replacer.register_rule("(AType)", AType); replacer.register_rule("(BType)", BType); replacer.register_rule("(CType)", @@ -2909,8 +2927,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { int num = Downcast(op->args[1])->value; std::string type = Downcast(op->args[2])->value; std::string local_ptr = this->PrintExpr(op->args[3]); - std::string local_elem_offset = this->PrintExpr(op->args[4]); + bool is_packed_int4 = + op->dtype.bits() == 4 && (op->dtype.is_int() || op->dtype.is_uint()); + PrimExpr local_elem_offset_expr = op->args[4]; + if (is_packed_int4) { + local_elem_offset_expr = + arith::Analyzer().Simplify(truncdiv(local_elem_offset_expr, 2)); + } + std::string local_elem_offset = this->PrintExpr(local_elem_offset_expr); std::string smem_ptr = this->PrintExpr(op->args[5]); + if (trans && op->dtype.bits() == 8) { // Since ldmatrix assumes that a matrix element is 16 bit, it cannot // properly transpose an int8 matrix. @@ -2924,7 +2950,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { " + threadIdx.x / 4 + (i / 8) * 8];\n"; os << "}\n"; } else { - std::string smem_elem_offset = this->PrintExpr(op->args[6]); + PrimExpr smem_elem_offset_expr = op->args[6]; + if (is_packed_int4) { + smem_elem_offset_expr = + arith::Analyzer().Simplify(truncdiv(smem_elem_offset_expr, 2)); + } + std::string smem_elem_offset = this->PrintExpr(smem_elem_offset_expr); std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); if (trans == 1) func_name += "_trans"; @@ -2993,8 +3024,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { if (op->dtype.bits() == 16) { os << "for (int local_id = 0; local_id < 8; local_id+=2) {\n"; os << "*((uint *)&" << dst << "[" + this->PrintExpr(dst_ind) + "])" - << " = " << "*((uint *)&" << src << "[" << src_offset - << " + local_id]);\n"; + << " = " + << "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n"; os << "}\n"; } else { os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; @@ -3086,7 +3117,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n"; this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n"; // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ; - stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n"; + stream << ": \"=f\"(" << reg << "[" << local_addr << "]" + << ")\n"; stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)" << guard << ")\n"; stream << ");\n"; @@ -3898,9 +3930,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { // For FP4 scalar local buffers, we use packed storage type, // so skip type declaration here (will be handled in the local scope section // below) - bool is_fp4_scalar_local = op->dtype.is_float4() && op->dtype.is_scalar() && - (scope == "local" || scope.empty()); - if (!is_fp4_scalar_local) { + bool is_fp4_scalar_local = + op->dtype.is_float4() && op->dtype.is_scalar() && scope == "local"; + bool is_int4_scalar_local = + (op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4)) && + op->dtype.is_scalar() && scope == "local"; + if (!is_fp4_scalar_local && !is_int4_scalar_local) { PrintStorageScope(scope, stream); PrintType(op->dtype, stream); } @@ -3916,10 +3951,11 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { if (scope.find("wmma.") == 0) { constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); } - if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) && + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4)) && scope == "shared") { - constant_size = constant_size / (32 / op->dtype.bits()); + constant_size = (constant_size + 1) / 2; + } else if (op->dtype == DataType::Int(1) && scope == "shared") { + constant_size = constant_size / 32; } if (scope == "shared") { stream << ' ' << vid << '[' << constant_size << "];\n"; @@ -3930,18 +3966,24 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { stream << "auto " << vid << " = reinterpret_cast<" << mbarrier_dtype_ << "*>(" << v_id_mem << ");\n"; } else if (scope == "local") { - // For FP4 types, use packed storage type to avoid wasting registers. - // fp4_e2_t uses int8 as storage but only needs 4 bits per element. - // By using fp4_e2_2_t (which stores 2 fp4 values in 1 byte), we halve the - // storage. - if (op->dtype.is_float4() && op->dtype.is_scalar()) { - auto vid_packed = vid + "_packed"; - stream << "fp4_e2_2_t " << vid_packed << '[' << (constant_size + 1) / 2 - << "];\n"; - // Record mapping from original buffer to packed buffer name - fp4_packed_buffers_[op->buffer_var.get()] = vid_packed; + if (op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4)) { + stream << "alignas(16) "; + PrintType(op->dtype, stream); + stream << ' ' << vid << '[' << (constant_size + 1) / 2 << "];\n"; } else { - stream << ' ' << vid << '[' << constant_size << "];\n"; + // For FP4 types, use packed storage type to avoid wasting registers. + // fp4_e2_t uses int8 as storage but only needs 4 bits per element. + // By using fp4_e2_2_t (which stores 2 fp4 values in 1 byte), we halve + // the storage. + if (op->dtype.is_float4() && op->dtype.is_scalar()) { + auto vid_packed = vid + "_packed"; + stream << "fp4_e2_2_t " << vid_packed << '[' + << (constant_size + 1) / 2 << "];\n"; + // Record mapping from original buffer to packed buffer name + fp4_packed_buffers_[op->buffer_var.get()] = vid_packed; + } else { + stream << ' ' << vid << '[' << constant_size << "];\n"; + } } } else if (scope == "local.var") { PrimExpr init = tir::make_const(op->dtype, 0); @@ -4005,8 +4047,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) { PrintType(op->dtype, os); os << "("; for (int i = 0; i < lanes; i++) { - os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) - << "*" << i << ")"; + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; if (i != lanes - 1) os << ", "; } @@ -4025,6 +4067,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, Var buffer_var = op->buffer->data; DataType element_dtype = op->buffer->dtype; + if ((element_dtype == DataType::Int(4) || + element_dtype == DataType::UInt(4)) && + element_dtype.is_scalar() && value_dtype.is_scalar()) { + std::string idx_str = PrintExpr(index); + std::string vid = GetVarID(buffer_var.get()); + if (element_dtype.is_uint()) { + os << "tl_uint4_packed_load((const unsigned char*)" << vid << ", " + << idx_str << ")"; + } else { + os << "tl_int4_packed_load((const signed char*)" << vid << ", " << idx_str + << ")"; + } + return; + } + // Check if this is a fp4 packed buffer access auto packed_it = fp4_packed_buffers_.find(buffer_var.get()); if (packed_it != fp4_packed_buffers_.end() && value_dtype.is_scalar()) { @@ -4032,7 +4089,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, os << "tl_fp4_packed_load(" << packed_it->second << ", " << idx_str << ")"; return; } - int lanes = op->dtype.lanes(); // declare type. if (value_dtype.lanes() == element_dtype.lanes()) { @@ -4108,6 +4164,23 @@ void CodeGenTileLangCUDA::VisitStmt_(const BufferStoreNode *op) { PrimExpr index_expr = op->indices[0]; Var buffer_var = op->buffer->data; + if ((element_dtype == DataType::Int(4) || + element_dtype == DataType::UInt(4)) && + element_dtype.is_scalar() && value_dtype.is_scalar()) { + std::string idx_str = PrintExpr(index_expr); + std::string value = this->PrintExpr(op->value); + std::string vid = GetVarID(buffer_var.get()); + this->PrintIndent(); + if (element_dtype.is_uint()) { + stream << "tl_uint4_packed_store((unsigned char*)" << vid << ", " + << idx_str << ", " << value << ");\n"; + } else { + stream << "tl_int4_packed_store((signed char*)" << vid << ", " << idx_str + << ", " << value << ");\n"; + } + return; + } + // Check if this is a fp4 packed buffer access auto packed_it = fp4_packed_buffers_.find(buffer_var.get()); if (packed_it != fp4_packed_buffers_.end() && value_dtype.is_scalar()) { @@ -4118,7 +4191,6 @@ void CodeGenTileLangCUDA::VisitStmt_(const BufferStoreNode *op) { << ", " << value << ");\n"; return; } - if (value_dtype.lanes() == element_dtype.lanes()) { // For scalar fp4 stores to non-packed buffers, use tl_fp4_packed_store // to correctly handle nibble-level writes. The /2 in GetBufferRef maps two diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index bb124b795f..57409178de 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -205,6 +205,44 @@ TL_DEVICE uint4 make_uint4(unsigned short x0, unsigned short x1, return result; } +// ============================================================================ +// Packed INT4 Buffer Access Helpers +// ============================================================================ +// TileLang lowers scalar int4/uint4 storage through byte-packed buffers, where +// each byte carries 2 logical 4-bit elements. + +TL_DEVICE int tl_int4_packed_load(const signed char *packed, int idx) { + unsigned char byte = static_cast(packed[idx >> 1]); + unsigned int shift = (idx & 1) * 4; + int value = static_cast((byte >> shift) & 0xF); + return (value << 28) >> 28; +} + +TL_DEVICE unsigned int tl_uint4_packed_load(const unsigned char *packed, + int idx) { + unsigned char byte = packed[idx >> 1]; + unsigned int shift = (idx & 1) * 4; + return (byte >> shift) & 0xF; +} + +TL_DEVICE void tl_int4_packed_store(signed char *packed, int idx, int val) { + unsigned int shift = (idx & 1) * 4; + unsigned char mask = static_cast(0xFu << shift); + unsigned char nibble = static_cast( + (static_cast(val) & 0xF) << shift); + unsigned char byte = static_cast(packed[idx >> 1]); + packed[idx >> 1] = static_cast((byte & ~mask) | nibble); +} + +TL_DEVICE void tl_uint4_packed_store(unsigned char *packed, int idx, + unsigned int val) { + unsigned int shift = (idx & 1) * 4; + unsigned char mask = static_cast(0xFu << shift); + unsigned char nibble = static_cast((val & 0xF) << shift); + packed[idx >> 1] = + static_cast((packed[idx >> 1] & ~mask) | nibble); +} + // Pack eight int values. TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0, int z1, int w0, int w1) { diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h index 869fa777bc..c4a276f3a9 100644 --- a/src/tl_templates/cuda/instruction/mma.h +++ b/src/tl_templates/cuda/instruction/mma.h @@ -105,11 +105,15 @@ TL_DEFINE_MMA_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, false, TL_DEFINE_MMA_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, false, cute::SM80_16x8x32_S32U8U8S32_TN) -// INT4 inputs (k32) +// INT4 inputs (k32, k64) TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 32, false, true, false, cute::SM80_16x8x32_S32S4S4S32_TN) +TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 64, false, true, false, + cute::SM80_16x8x64_S32S4S4S32_TN) TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 32, false, true, false, cute::SM80_16x8x32_S32U4U4S32_TN) +TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 64, false, true, false, + cute::SM80_16x8x64_S32U4U4S32_TN) // FP8 inputs (k32) TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 16, 8, 32, false, diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py index 2e131f7f54..3fe33aebf0 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py @@ -164,7 +164,6 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype=T.float32, a_transposed=False, b_transposed=True, k_pack=1): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack) - print(matmul) kernel = tilelang.compile(matmul) src_code = kernel.get_kernel_source() # src_code is the generated cuda source @@ -221,16 +220,15 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype=T.flo (128, 256, 256, T.int8, T.int32, T.int32, False, True, 2), (128, 256, 256, T.int8, T.int32, T.int32, False, False, 1), (128, 256, 256, T.int8, T.int32, T.int32, False, False, 2), - (128, 128, 128, T.dtype(determine_fp8_type()), T.float16, T.float32, False, True, 1), - (128, 256, 256, T.dtype(determine_fp8_type()), T.float32, T.float32, False, True, 1), - (128, 256, 256, T.dtype(determine_fp8_type()), T.float32, T.float32, False, True, 2), - (128, 256, 256, T.dtype(determine_fp8_type()), T.float32, T.float32, False, False, 1), - (128, 256, 256, T.dtype(determine_fp8_type()), T.float32, T.float32, False, False, 2), + (128, 128, 128, determine_fp8_type(), T.float16, T.float32, False, True, 1), + (128, 256, 256, determine_fp8_type(), T.float32, T.float32, False, True, 1), + (128, 256, 256, determine_fp8_type(), T.float32, T.float32, False, True, 2), + (128, 256, 256, determine_fp8_type(), T.float32, T.float32, False, False, 1), + (128, 256, 256, determine_fp8_type(), T.float32, T.float32, False, False, 2), ], ) @tilelang.testing.requires_rocm def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack): - print(f"in_dtype: {in_dtype}") assert_tl_matmul_correctness( M, N, diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py index 271bfdaf26..d4746c16d9 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py @@ -266,10 +266,10 @@ def assert_tl_matmul_correctness( (256, 256, 512, T.int8, T.int32, T.int32, False, False, 1, True, False), (256, 256, 512, T.int8, T.int32, T.int32, False, True, 2, True, False), (256, 256, 512, T.int8, T.int32, T.int32, False, False, 2, True, False), - (256, 256, 512, T.dtype(determine_fp8_type()), T.float32, T.float32, False, True, 1, True, False), - (256, 256, 512, T.dtype(determine_fp8_type()), T.float32, T.float32, False, False, 1, True, False), - (256, 256, 512, T.dtype(determine_fp8_type()), T.float32, T.float32, False, True, 2, True, False), - (256, 256, 512, T.dtype(determine_fp8_type()), T.float32, T.float32, False, False, 2, True, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, True, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, False, 1, True, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 2, True, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, False, 2, True, False), ], ) @tilelang.testing.requires_rocm diff --git a/testing/python/debug/test_tilelang_debug_print.py b/testing/python/debug/test_tilelang_debug_print.py index 4d59d3d17b..d7b7f7a7e4 100644 --- a/testing/python/debug/test_tilelang_debug_print.py +++ b/testing/python/debug/test_tilelang_debug_print.py @@ -33,8 +33,8 @@ def test_debug_print_buffer_cuda_fp8(): @tilelang.testing.requires_rocm def test_debug_print_buffer_rocm_fp8(): - debug_print_buffer(dtype=getattr(T, determine_fp8_type("e4m3"))) - debug_print_buffer(dtype=getattr(T, determine_fp8_type("e5m2"))) + debug_print_buffer(dtype=determine_fp8_type("e4m3")) + debug_print_buffer(dtype=determine_fp8_type("e5m2")) def debug_print_buffer_conditional(M=16, N=16): diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index ff6e427b65..bc10821734 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -1,5 +1,4 @@ from __future__ import annotations -from dataclasses import dataclass import tilelang.language as T from typing import Literal, Callable from tilelang.common import TransformKind @@ -32,160 +31,6 @@ lift = convert -def _resolve_subbyte_local_offset(local_size: int, numerator: int, denominator: int) -> int: - if denominator <= 0: - raise ValueError(f"denominator must be positive, but got {denominator}") - scaled = local_size * numerator - if scaled % denominator != 0: - raise ValueError(f"Invalid subbyte MMA offset {numerator}/{denominator} for local_size={local_size}") - return scaled // denominator - - -def _infer_subbyte_storage_bits(logical_bits: int) -> int: - for storage_bits in (8, 16, 32): - if storage_bits >= logical_bits and storage_bits % logical_bits == 0: - return storage_bits - raise ValueError(f"Unsupported subbyte logical bit width: {logical_bits}") - - -def _infer_subbyte_storage_dtype(logical_dtype: str, logical_bits: int) -> str: - storage_bits = _infer_subbyte_storage_bits(logical_bits) - logical_dtype = str(logical_dtype) - if logical_dtype.startswith("uint"): - return f"uint{storage_bits}" - if logical_dtype.startswith("int"): - return f"int{storage_bits}" - # For non-integer subbyte dtypes such as future fp4, use an integer carrier dtype - # inside lowering. The logical dtype still drives the MMA opcode selection. - return f"int{storage_bits}" - - -@dataclass(frozen=True) -class SubByteTensorCoreMMAOp: - a_offset_num: int = 0 - a_offset_den: int = 1 - b_offset_num: int = 0 - b_offset_den: int = 1 - c_offset_num: int = 0 - c_offset_den: int = 1 - - def resolve_offsets(self, local_size_a: int, local_size_b: int, local_size_out: int) -> tuple[int, int, int]: - return ( - _resolve_subbyte_local_offset(local_size_a, self.a_offset_num, self.a_offset_den), - _resolve_subbyte_local_offset(local_size_b, self.b_offset_num, self.b_offset_den), - _resolve_subbyte_local_offset(local_size_out, self.c_offset_num, self.c_offset_den), - ) - - -@dataclass(frozen=True) -class SubByteTensorCoreMMASpec: - logical_a_dtype: str - logical_b_dtype: str - logical_a_bits: int - logical_b_bits: int - accum_dtype: str - mma_prefix: str - mma_a_dtype_abbrv: str - mma_b_dtype_abbrv: str - mma_ops: tuple[SubByteTensorCoreMMAOp, ...] - - def __post_init__(self): - self._validate_pack_factor(self.storage_a_dtype, self.logical_a_bits, "A") - self._validate_pack_factor(self.storage_b_dtype, self.logical_b_bits, "B") - - @property - def storage_a_dtype(self) -> str: - return _infer_subbyte_storage_dtype(self.logical_a_dtype, self.logical_a_bits) - - @property - def storage_b_dtype(self) -> str: - return _infer_subbyte_storage_dtype(self.logical_b_dtype, self.logical_b_bits) - - @staticmethod - def _validate_pack_factor(storage_dtype: str, logical_bits: int, matrix: str): - storage_bits = DataType(storage_dtype).bits - if storage_bits < logical_bits or storage_bits % logical_bits != 0: - raise ValueError( - f"Subbyte MMA spec expects {matrix} storage dtype {storage_dtype} to pack logical {logical_bits}-bit elements exactly" - ) - - @property - def a_pack_factor(self) -> int: - return DataType(self.storage_a_dtype).bits // self.logical_a_bits - - @property - def b_pack_factor(self) -> int: - return DataType(self.storage_b_dtype).bits // self.logical_b_bits - - def get_pack_factor(self, matrix: Literal["A", "B"]) -> int: - if matrix == "A": - return self.a_pack_factor - if matrix == "B": - return self.b_pack_factor - raise ValueError(f"Unsupported matrix kind: {matrix}") - - def get_storage_dtype(self, matrix: Literal["A", "B"]) -> str: - if matrix == "A": - return self.storage_a_dtype - if matrix == "B": - return self.storage_b_dtype - raise ValueError(f"Unsupported matrix kind: {matrix}") - - def get_logical_dtype(self, matrix: Literal["A", "B"]) -> str: - if matrix == "A": - return self.logical_a_dtype - if matrix == "B": - return self.logical_b_dtype - raise ValueError(f"Unsupported matrix kind: {matrix}") - - def pack_extent(self, extent: int, matrix: Literal["A", "B"]) -> int: - pack_factor = self.get_pack_factor(matrix) - if extent % pack_factor != 0: - raise ValueError(f"{self.get_logical_dtype(matrix)} expects extent divisible by {pack_factor}, but got {extent}") - return extent // pack_factor - - -INT4_TENSORCORE_MMA_SPEC = SubByteTensorCoreMMASpec( - logical_a_dtype="int4", - logical_b_dtype="int4", - logical_a_bits=4, - logical_b_bits=4, - accum_dtype="int32", - mma_prefix="m16n8k32", - mma_a_dtype_abbrv="int4", - mma_b_dtype_abbrv="int4", - mma_ops=( - SubByteTensorCoreMMAOp(), - SubByteTensorCoreMMAOp(b_offset_num=1, b_offset_den=2, c_offset_num=1, c_offset_den=2), - SubByteTensorCoreMMAOp(a_offset_num=1, a_offset_den=2, b_offset_num=1, b_offset_den=4), - SubByteTensorCoreMMAOp(a_offset_num=1, a_offset_den=2, b_offset_num=3, b_offset_den=4, c_offset_num=1, c_offset_den=2), - ), -) - -_SUBBYTE_TENSORCORE_MMA_SPECS = { - "int4": INT4_TENSORCORE_MMA_SPEC, -} - - -def get_subbyte_tensorcore_mma_spec(dtype: str) -> SubByteTensorCoreMMASpec | None: - return _SUBBYTE_TENSORCORE_MMA_SPECS.get(str(dtype)) - - -def infer_subbyte_tensorcore_mma_spec(a_dtype: str, b_dtype: str) -> SubByteTensorCoreMMASpec | None: - a_spec = get_subbyte_tensorcore_mma_spec(a_dtype) - b_spec = get_subbyte_tensorcore_mma_spec(b_dtype) - - if a_spec is None and b_spec is None: - return None - if a_spec is None or b_spec is None: - raise ValueError(f"Subbyte MMA requires both operands to be subbyte dtypes, but got a_dtype={a_dtype}, b_dtype={b_dtype}") - if not (str(a_dtype) == str(a_spec.logical_a_dtype) and str(b_dtype) == str(a_spec.logical_b_dtype)): - raise ValueError(f"Unsupported subbyte MMA operand dtypes: a_dtype={a_dtype}, b_dtype={b_dtype}") - if a_spec != b_spec: - raise ValueError(f"Mismatched subbyte MMA specs for operands: a_dtype={a_dtype}, b_dtype={b_dtype}") - return a_spec - - class TensorCoreIntrinEmitter: """ To eliminate Python syntax within TIR Macro. @@ -201,6 +46,7 @@ class TensorCoreIntrinEmitter: "bfloat16": "bf16", "float32": "fp32", "float64": "fp64", + "int4": "int4", "int8": "int8", "uint8": "uint8", "int32": "int32", @@ -238,35 +84,22 @@ def __init__( self.accum_dtype = accum_dtype self.a_transposed = a_transposed self.b_transposed = b_transposed - self.subbyte_mma_spec = infer_subbyte_tensorcore_mma_spec(a_dtype, b_dtype) - if self.subbyte_mma_spec is not None and str(accum_dtype) != str(self.subbyte_mma_spec.accum_dtype): - raise ValueError( - f"Subbyte MMA dtypes ({a_dtype}, {b_dtype}) expect accum dtype {self.subbyte_mma_spec.accum_dtype}, but got {accum_dtype}" - ) # Hint Information self.block_row_warps = block_row_warps self.block_col_warps = block_col_warps self.warp_row_tiles = warp_row_tiles self.warp_col_tiles = warp_col_tiles self.chunk = chunk - a_storage_dtype = self._get_storage_dtype("A") - b_storage_dtype = self._get_storage_dtype("B") - self._initialize_k_dim(a_storage_dtype) + self._initialize_k_dim(self.a_dtype) # For FP64, MMA shape is m8n8k4; adjust instance dims early - if DataType(a_storage_dtype).bits == 64: + if DataType(self.a_dtype).bits == 64: # Override default M/N dims for fp64 MMA self.M_DIM = 8 # n_dim will be set to 8 in _initialize_micro_size via k_dim==4 self._initialize_micro_size(self.M_DIM, self.k_dim) self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE) - if self.subbyte_mma_spec is None: - self._initialize_abbrev(a_storage_dtype, b_storage_dtype, accum_dtype) - self._initialize_mma_prefix(self.k_dim) - else: - self.a_dtype_abbrv = self.subbyte_mma_spec.mma_a_dtype_abbrv - self.b_dtype_abbrv = self.subbyte_mma_spec.mma_b_dtype_abbrv - self.accum_dtype_abbrv = str(self.subbyte_mma_spec.accum_dtype) - self.mma_prefix = self.subbyte_mma_spec.mma_prefix + self._initialize_abbrev(self.a_dtype, self.b_dtype, accum_dtype) + self._initialize_mma_prefix(self.k_dim) self._initialize_is_m_first(is_m_first) self.reduce_k = reduce_k @@ -279,21 +112,10 @@ def __init__( f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" ) - def _get_storage_dtype(self, matrix: Literal["A", "B"]) -> str: - if matrix == "A": - logical_dtype = self.a_dtype - elif matrix == "B": - logical_dtype = self.b_dtype - else: - raise ValueError(f"Unsupported matrix kind: {matrix}") - if self.subbyte_mma_spec is None: - return logical_dtype - return self.subbyte_mma_spec.get_storage_dtype(matrix) - def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): a_dtype = DataType(a_dtype) - self.k_dim = 256 // a_dtype.bits + self.k_dim = min(256 // a_dtype.bits, self.chunk) def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size @@ -306,10 +128,9 @@ def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype) def _get_dtype_abbrv(self, dtype: str) -> str: - try: - return self.dtype_abbrv[dtype] - except KeyError as err: - raise ValueError(f"Unsupported dtype: {dtype}") from err + if dtype not in self.dtype_abbrv: + raise ValueError(f"Unsupported dtype: {dtype}") + return self.dtype_abbrv[dtype] def _initialize_mma_prefix(self, k_dim: int = 16): if k_dim == 4: @@ -323,9 +144,19 @@ def _initialize_mma_prefix(self, k_dim: int = 16): self.mma_prefix = "m16n8k16" elif k_dim == 32: # typically used for int8/fp8 + # sometimes int4/uint4 is also supported self.mma_prefix = "m16n8k32" + elif k_dim == 64: + # typically used for int4/uint4 + self.mma_prefix = "m16n8k64" + elif k_dim == 128: + # typically used for int2/uint2 + self.mma_prefix = "m16n8k128" + elif k_dim == 256: + # typically used for uint1 + self.mma_prefix = "m16n8k256" else: - raise ValueError("Unsupported k_dim") + raise ValueError(f"Unsupported k_dim {k_dim}") def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): warp_row_tiles = self.warp_row_tiles @@ -415,7 +246,7 @@ def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): # Fast path for fp64: no ldmatrix support, do direct per-lane loads - a_dtype = self._get_storage_dtype("A") + a_dtype = self.a_dtype if DataType(a_dtype).bits == 64: warp_row_tiles = self.warp_row_tiles warp_rows = self.warp_rows @@ -527,7 +358,7 @@ def _warp_ldmatrix_a( def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): # Fast path for fp64: no ldmatrix support, do direct per-lane loads - b_dtype = self._get_storage_dtype("B") + b_dtype = self.b_dtype if DataType(b_dtype).bits == 64: warp_col_tiles = self.warp_col_tiles warp_cols = self.warp_cols @@ -646,19 +477,6 @@ def _warp_ldmatrix_b( return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): - if self.subbyte_mma_spec is not None: - return _emit_subbyte_tensorcore_mma( - self.subbyte_mma_spec, - self.warp_rows, - self.warp_cols, - self.local_size_a, - self.local_size_b, - self.local_size_out, - self.accum_dtype, - A_local_buf, - B_local_buf, - C_local_buf, - ) warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -798,7 +616,7 @@ def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" - dtype = self._get_storage_dtype(matrix) + dtype = self.a_dtype if matrix_is_a else self.b_dtype dtype_bits = DataType(dtype).bits transposed = self.a_transposed if matrix_is_a else self.b_transposed @@ -1048,9 +866,9 @@ def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_out = (m_dim * n_dim) // warp_size def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): - self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] - self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] - self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] + self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype) + self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype) + self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype) def _initialize_mma_prefix(self, k_dim=16): if k_dim == 16: @@ -1090,7 +908,7 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0): micro_size_x = self.micro_size_x micro_size_k = self.micro_size_k local_size_a = self.local_size_a - a_dtype = self._get_storage_dtype("A") + a_dtype = self.a_dtype a_transposed = self.a_transposed transform_kind_a = self.transform_kind_a @@ -1185,7 +1003,7 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0): micro_size_y = self.micro_size_y micro_size_k = self.micro_size_k local_size_b = self.local_size_b - b_dtype = self._get_storage_dtype("B") + b_dtype = self.b_dtype transform_kind_b = self.transform_kind_b b_transposed = self.b_transposed num_elems_per_byte = self.num_elems_per_byte @@ -1278,19 +1096,6 @@ def _warp_ldmatrix_b( return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) def mma(self, A_local_buf, B_local_buf, C_local_buf): - if self.subbyte_mma_spec is not None: - return _emit_subbyte_tensorcore_mma( - self.subbyte_mma_spec, - self.warp_rows, - self.warp_cols, - self.local_size_a, - self.local_size_b, - self.local_size_out, - self.accum_dtype, - A_local_buf, - B_local_buf, - C_local_buf, - ) warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -1340,55 +1145,3 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): ) return _warp_mma(A_local_buf, B_local_buf, C_local_buf) - - -def _emit_subbyte_tensorcore_mma( - mma_spec: SubByteTensorCoreMMASpec, - warp_rows: int, - warp_cols: int, - local_size_a: int, - local_size_b: int, - local_size_out: int, - accum_dtype: str, - A_local_buf, - B_local_buf, - C_local_buf, -): - accum_dtype_abbrv = accum_dtype - mma_prefix = mma_spec.mma_prefix - a_dtype_abbrv = mma_spec.mma_a_dtype_abbrv - b_dtype_abbrv = mma_spec.mma_b_dtype_abbrv - mma_op_offsets = tuple(mma_op.resolve_offsets(local_size_a, local_size_b, local_size_out) for mma_op in mma_spec.mma_ops) - - @T.macro - def _emit_subbyte_mma_op(A_local_buf, B_local_buf, C_local_buf, i, j, a_offset, b_offset, c_offset): - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - i * local_size_a + a_offset, - B_local_buf.data, - j * local_size_b + b_offset, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + c_offset, - T.bool(False), - ) - - def _emit_subbyte_mma_ops(A_local_buf, B_local_buf, C_local_buf, i, j, op_index: int = 0): - if op_index >= len(mma_op_offsets): - return - a_offset, b_offset, c_offset = mma_op_offsets[op_index] - _emit_subbyte_mma_op(A_local_buf, B_local_buf, C_local_buf, i, j, a_offset, b_offset, c_offset) - _emit_subbyte_mma_ops(A_local_buf, B_local_buf, C_local_buf, i, j, op_index + 1) - - @T.macro - def _warp_mma(A_local_buf, B_local_buf, C_local_buf): - for i, j in T.grid(warp_rows, warp_cols): - _emit_subbyte_mma_ops(A_local_buf, B_local_buf, C_local_buf, i, j) - - return _warp_mma(A_local_buf, B_local_buf, C_local_buf) diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index 128d9819e7..724d3f94a2 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -24,7 +24,7 @@ def get_ldmatrix_offset( row_idx, col_idx, stride, - dtype: Literal["float16", "int8"] = "float16", + dtype: Literal["float16", "int8", "int4"] = "float16", transposed: bool = False, ): assert matrix in ["A", "B"], "matrix should be either A or B" @@ -49,15 +49,17 @@ def get_ldmatrix_offset( else: new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx, new_col_idx - elif dtype_bits == 8: + elif dtype_bits <= 8: if matrix == "B" and transposed: transform_func = ldmatrix_32x16_to_shared_16x32_layout_b new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx, new_col_idx + pack_factor = 8 // dtype_bits + return new_row_idx, new_col_idx * pack_factor elif matrix == "A" and not transposed: transform_func = ldmatrix_32x16_to_shared_16x32_layout_a new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx, new_col_idx + pack_factor = 8 // dtype_bits + return new_row_idx, new_col_idx * pack_factor else: raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") else: diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py index 5b342d98f2..ebbd3653d9 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -5,8 +5,6 @@ from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mma_macro_generator import ( TensorCoreIntrinEmitter, - SubByteTensorCoreMMASpec, - get_subbyte_tensorcore_mma_spec, ) from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm @@ -17,97 +15,11 @@ from tilelang.transform.simplify import _Simplify -class SubByteGemmOperandAdaptor: - def __init__(self, mma_spec: SubByteTensorCoreMMASpec): - self.mma_spec = mma_spec - - def get_storage_dtype(self, matrix: str) -> str: - return self.mma_spec.get_storage_dtype(matrix) - - def get_packed_chunk(self, logical_chunk: int, matrix: str = "A") -> int: - logical_chunk = int(logical_chunk) - return self.mma_spec.pack_extent(logical_chunk, matrix) - - def make_packed_buffer(self, buf: tir.Buffer, matrix: str) -> tir.Buffer: - shape = list(buf.shape) - if len(shape) < 2: - raise ValueError(f"{self.mma_spec.get_logical_dtype(matrix)} T.gemm expects at least 2D operands, but got shape={shape}") - packed_last_dim = int(shape[-1]) - pack_factor = self.mma_spec.get_pack_factor(matrix) - if packed_last_dim % pack_factor != 0: - raise ValueError( - f"{self.mma_spec.get_logical_dtype(matrix)} T.gemm expects an innermost K extent divisible by " - f"{pack_factor}, but got {packed_last_dim}" - ) - shape[-1] = packed_last_dim // pack_factor - return T.view(buf, tuple(shape), dtype=self.get_storage_dtype(matrix)) - - def make_packed_region(self, region: tir.BufferRegion, matrix: str) -> tir.BufferRegion: - packed_buf = self.make_packed_buffer(region.buffer, matrix) - pack_factor = self.mma_spec.get_pack_factor(matrix) - packed_ranges = list(region.region) - last_range = packed_ranges[-1] - packed_ranges[-1] = Range.from_min_extent(last_range.min // pack_factor, last_range.extent // pack_factor) - return tir.BufferRegion(packed_buf, packed_ranges) - - class GemmMMA(GemmBase): - def _get_subbyte_mma_spec(self) -> SubByteTensorCoreMMASpec | None: - return get_subbyte_tensorcore_mma_spec(self.in_dtype) - - def _get_subbyte_operand_adaptor(self) -> SubByteGemmOperandAdaptor | None: - mma_spec = self._get_subbyte_mma_spec() - if mma_spec is None: - return None - return SubByteGemmOperandAdaptor(mma_spec) - - def _validate_subbyte_mma_support(self, mma_spec: SubByteTensorCoreMMASpec): - chunk = int(self.chunk) - pack_factor_a = mma_spec.get_pack_factor("A") - pack_factor_b = mma_spec.get_pack_factor("B") - if not self.is_gemm_ss(): - raise ValueError(f"{self.in_dtype} T.gemm currently only supports shared/shared operands in the subbyte MMA path") - if self.trans_A or not self.trans_B: - raise ValueError( - f"{self.in_dtype} T.gemm currently only supports innermost-K packed layout (transpose_A=False, transpose_B=True)" - ) - if str(self.accum_dtype) != str(mma_spec.accum_dtype): - raise ValueError( - f"{self.in_dtype} T.gemm currently only supports {mma_spec.accum_dtype} accumulation, but got {self.accum_dtype}" - ) - if chunk % pack_factor_a != 0: - raise ValueError(f"{self.in_dtype} T.gemm expects the A K tile to be divisible by {pack_factor_a}, but got chunk={chunk}") - if chunk % pack_factor_b != 0: - raise ValueError(f"{self.in_dtype} T.gemm expects the B K tile to be divisible by {pack_factor_b}, but got chunk={chunk}") - def _make_mma_emitter(self, target: Target, thread_nums: int, thread_var: tir.Var | None = None): m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.MMA) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) - subbyte_mma_spec = self._get_subbyte_mma_spec() - subbyte_adaptor = self._get_subbyte_operand_adaptor() - if subbyte_mma_spec is not None and subbyte_adaptor is not None: - self._validate_subbyte_mma_support(subbyte_mma_spec) - packed_chunk_a = subbyte_adaptor.get_packed_chunk(self.chunk, matrix="A") - packed_chunk_b = subbyte_adaptor.get_packed_chunk(self.chunk, matrix="B") - if packed_chunk_a != packed_chunk_b: - raise ValueError( - f"Subbyte MMA currently expects A/B to use the same packed K tile, but got A={packed_chunk_a}, B={packed_chunk_b}" - ) - emitter = TensorCoreIntrinEmitter( - a_dtype=self.in_dtype, - b_dtype=self.in_dtype, - accum_dtype=self.accum_dtype, - a_transposed=self.trans_A, - b_transposed=self.trans_B, - block_row_warps=m_warp, - block_col_warps=n_warp, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=packed_chunk_a, - thread_var=thread_var, - ) - return emitter, m_warp, n_warp emitter = TensorCoreIntrinEmitter( a_dtype=self.in_dtype, b_dtype=self.in_dtype, @@ -121,10 +33,10 @@ def _make_mma_emitter(self, target: Target, thread_nums: int, thread_var: tir.Va chunk=self.chunk, thread_var=thread_var, ) - return emitter, m_warp, n_warp + return emitter def infer_layout(self, target: Target, thread_nums: int): - mma_emitter, _, _ = self._make_mma_emitter(target, thread_nums) + mma_emitter = self._make_mma_emitter(target, thread_nums) if self.is_gemm_ss(): return { self.A: make_swizzled_layout(self.A), @@ -161,11 +73,9 @@ def lower( mbar_phase_expr: tir.PrimExpr | None = None, ): thread_nums = thread_bounds.extent - mma_emitter, _, _ = self._make_mma_emitter(target, thread_nums, thread_var=thread_var) - subbyte_adaptor = self._get_subbyte_operand_adaptor() + mma_emitter = self._make_mma_emitter(target, thread_nums, thread_var=thread_var) - a_local_dtype = subbyte_adaptor.get_storage_dtype("A") if subbyte_adaptor is not None else self.in_dtype - b_local_dtype = subbyte_adaptor.get_storage_dtype("B") if subbyte_adaptor is not None else self.in_dtype + in_dtype = self.in_dtype warp_rows = mma_emitter.warp_rows warp_cols = mma_emitter.warp_cols local_size_a = mma_emitter.local_size_a @@ -177,9 +87,6 @@ def lower( A_region = self.ARegion B_region = self.BRegion C_region = self.CRegion - if subbyte_adaptor is not None: - A_region = subbyte_adaptor.make_packed_region(A_region, "A") - B_region = subbyte_adaptor.make_packed_region(B_region, "B") A_buf = A_region.buffer B_buf = B_region.buffer @@ -200,8 +107,8 @@ def _gemm_ssr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), a_local_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), b_local_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): @@ -220,10 +127,7 @@ def _gemm_ssr() -> None: ) # Perform Matrix Multiplication - if subbyte_adaptor is not None: - mma_emitter.mma(A_local, B_local, C_buf) - else: - mma_emitter.mma(A_local, B_local, C_buf, ki) + mma_emitter.mma(A_local, B_local, C_buf, ki) # Simplify to optimize the index computing # Must inline let statements to simplify the analysis @@ -238,7 +142,7 @@ def _gemm_srr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), a_local_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) for ki in T.serial(0, (block_K // micro_size_k)): if clear_accum: @@ -268,7 +172,7 @@ def _gemm_rsr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - B_local = T.alloc_local((warp_cols * local_size_b), b_local_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index b5338f51e4..94252a5f3d 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -6,6 +6,7 @@ from platform import mac_ver from typing import Literal from tilelang import tvm as tvm +from tilelang import language as T from tilelang import _ffi_api from tvm.target import Target from tvm.contrib import rocm @@ -73,18 +74,18 @@ def determine_fp8_type(fp8_format: Literal["e4m3", "e5m2"] = "e4m3") -> str: if fp8_format not in {"e4m3", "e5m2"}: raise ValueError(f"Unsupported FP8 format: {fp8_format}") if torch.version.hip is None: - return "float8_e4m3fn" if fp8_format == "e4m3" else "float8_e5m2" + return T.float8_e4m3fn if fp8_format == "e4m3" else T.float8_e5m2 if not torch.cuda.is_available(): - return "float8_e4m3fnuz" if fp8_format == "e4m3" else "float8_e5m2fnuz" + return T.float8_e4m3fnuz if fp8_format == "e4m3" else T.float8_e5m2fnuz props = torch.cuda.get_device_properties(0) gcn_arch = getattr(props, "gcnArchName", "") if fp8_format == "e4m3": if gcn_arch.startswith("gfx950"): - return "float8_e4m3fn" - return "float8_e4m3fnuz" - if gcn_arch.startswith("gfx950") and hasattr(torch, "float8_e5m2"): - return "float8_e5m2" - return "float8_e5m2fnuz" + return T.float8_e4m3fn + return T.float8_e4m3fnuz + if gcn_arch.startswith("gfx950") and hasattr(T, "float8_e5m2"): + return T.float8_e5m2 + return T.float8_e5m2fnuz def determine_torch_fp8_type(fp8_format: Literal["e4m3", "e5m2"] = "e4m3") -> torch.dtype: From 4e7d126fd308b2ef1fadcf0bc282a10066f3c567 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 22 Apr 2026 00:47:07 +0800 Subject: [PATCH 107/156] Bump pytest --numprocesses from 4 to 8 across all platforms (#2076) Increase the default pytest parallel workers from 4 to 8 for all test suites (CUDA, ROCm, Metal, and CuTeDSL examples) to better utilize available CPU resources on CI runners. Co-authored-by: Claude Opus 4.6 --- .github/workflows/ci.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4b319a7218..576eb5b6ca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -309,7 +309,7 @@ jobs: uv run --no-project -m -- pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) - "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + "${PYTEST[@]}" --maxfail=3 --numprocesses=8 \ --ignore=../examples/grouped_gemm/test_example_grouped_gemm.py \ ../examples @@ -323,7 +323,7 @@ jobs: uv run --no-project -m -- pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) - "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + "${PYTEST[@]}" --maxfail=3 --numprocesses=8 \ ./python # AMD ROCm tests @@ -337,7 +337,7 @@ jobs: uv run --no-project -m -- pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) - "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + "${PYTEST[@]}" --maxfail=3 --numprocesses=8 \ --ignore=./python/runtime --ignore=./python/transform \ ./python @@ -351,7 +351,7 @@ jobs: uv run --no-project -m -- pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) - "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + "${PYTEST[@]}" --maxfail=3 --numprocesses=8 \ -k metal \ ./python @@ -485,7 +485,7 @@ jobs: uv run --no-project -m -- pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) - "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + "${PYTEST[@]}" --maxfail=3 --numprocesses=8 \ ../examples - name: List generated files From 38cac969d81237edbf1c89a2b770d6aad42d42ce Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 22 Apr 2026 11:19:32 +0800 Subject: [PATCH 108/156] fix dead-lock bug --- src/transform/auto_schedule/barrier.h | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 75bf140ea4..ee07e91b89 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -802,17 +802,13 @@ static void InsertSynchronization( } auto check_need_barrier = [&](ScheduleUnit *waiting_unit, int waiting_wg_id) { + if (unit == waiting_unit) + // Note: the logic here need some assumption. + return false; if (wg_id != waiting_wg_id) return true; if (!is_async) return false; - if (auto task = GetInnerTask(unit)) { - if (auto waiting_task = GetInnerTask(waiting_unit)) { - if (task->is_TCGEN05() && waiting_task->is_TCGEN05()) { - return false; - } - } - } return true; }; bool need_barrier = false; From 4b3127a7fa15739c2215ebd3267be7844d0489b8 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 22 Apr 2026 12:46:45 +0800 Subject: [PATCH 109/156] [Enhancement] Enhance alloc_var function to handle _ptr_sentinel dtype (#2078) Enhance alloc_var function to handle _ptr_sentinel dtype Updated the alloc_var function to check for the _ptr_sentinel dtype and default to int64 if encountered. This change improves type handling and ensures compatibility with new tensor allocation scenarios. --- tilelang/language/allocate.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index a2d6b96c1d..15a76897f7 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -27,7 +27,7 @@ from . import dtypes as _dtypes from .dtypes import dtype as tl_dtype from .eager.builder import OutTensor -from .proxy import Tensor +from .proxy import Tensor, ptr as _ptr_sentinel def alloc_shared(shape: ShapeType, dtype: DType, scope="shared.dyn") -> Buffer: @@ -132,6 +132,9 @@ def alloc_var(dtype: DType, *args, scope: str = "local.var", init: PrimExpr | in if not isinstance(parsed_scope, str): raise TypeError("Scope must be a string in alloc_var.") + if dtype is _ptr_sentinel: + dtype = _dtypes.int64 + buffer = T.alloc_buffer([1], dtype, scope=parsed_scope) if parsed_init is not None: if isinstance(parsed_init, (int, float, IntImm, FloatImm)): From 6fee85002ec6a869456202b162ade2d4010b7dd2 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 22 Apr 2026 12:48:41 +0800 Subject: [PATCH 110/156] [Release] Bump version into 0.1.9 (#2060) * bump version into 0.1.9 * remove legacy test --- VERSION | 2 +- .../kernel/test_tilelang_kernel_int4_gemm.py | 40 -- .../test_tilelang_kernel_int4_gemm_mma.py | 404 ------------------ 3 files changed, 1 insertion(+), 445 deletions(-) delete mode 100644 testing/python/kernel/test_tilelang_kernel_int4_gemm.py delete mode 100644 testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py diff --git a/VERSION b/VERSION index 699c6c6d4e..1a030947e8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.8 +0.1.9 diff --git a/testing/python/kernel/test_tilelang_kernel_int4_gemm.py b/testing/python/kernel/test_tilelang_kernel_int4_gemm.py deleted file mode 100644 index 3804d6303b..0000000000 --- a/testing/python/kernel/test_tilelang_kernel_int4_gemm.py +++ /dev/null @@ -1,40 +0,0 @@ -import tilelang -import tilelang.testing -import tilelang.language as T - - -def matmul_nt_int4(M, N, K, block_M, block_N, block_K): - @T.prim_func - def main( - A: T.Tensor((M, K), T.int4), - B: T.Tensor((N, K), T.int4), - C: T.Tensor((M, N), T.int32), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), T.int4) - B_shared = T.alloc_shared((block_N, block_K), T.int4) - C_local = T.alloc_fragment((block_M, block_N), T.int32) - - T.clear(C_local) - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - T.copy(A[by * block_M, ko * block_K], A_shared) - T.copy(B[bx * block_N, ko * block_K], B_shared) - T.gemm(A_shared, B_shared, C_local, transpose_B=True) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_eq(8, 0) -def test_compile_int4_gemm_tgemm(): - func = matmul_nt_int4(1024, 1024, 1024, 128, 128, 64) - kernel = tilelang.compile(func, out_idx=-1) - src = kernel.get_kernel_source() - assert src is not None - assert "s4.s4.s32" in src or "int4" in src - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py deleted file mode 100644 index 5bcd955488..0000000000 --- a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py +++ /dev/null @@ -1,404 +0,0 @@ -import torch -import tilelang -from tilelang import tvm as tvm -import tilelang.testing -import tilelang.language as T -from tilelang.intrinsics import ( - make_mma_swizzle_layout as make_swizzle_layout, -) -from tilelang.transform import simplify_prim_func -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter, - TensorCoreIntrinEmitterWithLadderTransform, -) - -tilelang.testing.set_random_seed(42) - - -def tl_matmul( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, -): - assert in_dtype in [ - T.float16, - T.int8, - ], "Currently only float16 and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - K = K // 2 - - micro_size_x = micro_size_y = micro_size_k = 16 - - if accum_dtype == T.int32: - micro_size_k = 32 - - # This is a debug config - block_row_warps = 2 - block_col_warps = 2 - warp_row_tiles = 64 - warp_col_tiles = 64 - chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - # Pipeline Stage - stage = 2 - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) # int8 storage represents int4*2 - B_shape = (N, K) # int8 storage represents int4*2 - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=T.int4, - b_dtype=T.int4, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix( - C_local, - C_shared, - ) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main - - -def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - kernel = tilelang.compile( - matmul, - out_idx=[2], - ) - print(kernel.get_kernel_source()) - profiler = kernel.get_profiler() - - src_code = kernel.get_kernel_source() - # src_code is the generated cuda source - assert src_code is not None - - A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) - B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) - - compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) - compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - C = kernel(compressed_A, compressed_B) - print(C) - latency = profiler.do_bench() - print(latency) - # Ensure that the latency is not None - assert latency is not None - - # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) - - print(ref_c) - torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) - - -@tilelang.testing.requires_cuda -def test_assert_tl_matmul_correctness(): - assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32) - assert_tl_matmul_correctness(128, 128, 64, T.int8, T.int32, T.int32) - - -@simplify_prim_func -def tl_matmul_weight_only_transform( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, -): - K = K // 2 - assert in_dtype in [ - T.float16, - T.int8, - ], "Currently only float16 and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - micro_size_x = micro_size_y = micro_size_k = 16 - - if out_dtype == T.int32: - micro_size_k = 32 - - transform_b = 3 - - # This is a debug config - block_row_warps = 2 - block_col_warps = 2 - warp_row_tiles = 64 - warp_col_tiles = 64 - chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - # Pipeline Stage - stage = 2 - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) - B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) - A_shared_shape = ( - block_M, - block_K, - ) - B_shared_shape = ( - block_N // micro_size_y, - block_K // micro_size_k, - micro_size_y, - micro_size_k, - ) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( - a_dtype=T.int4, - b_dtype=T.int4, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - transform_kind_b=transform_b, - ) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k): - B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, ko * (block_K // micro_size_k) + k, jj, kk] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix( - C_local, - C_shared, - ) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main - - -def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): - import bitblas - - matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) - kernel = tilelang.compile(matmul, out_idx=[2]) - profiler = kernel.get_profiler() - - src_code = kernel.get_kernel_source() - # src_code is the generated cuda source - assert src_code is not None - transform_b = 3 - - A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) - B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) - compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) - compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - - ladder_permutate_config = bitblas.ops.LadderPermutateConfig( - M=N, - N=(K // 2), - datatype=T.int8, - storage_dtype=T.int8, - transform_kind=transform_b, - transpose_matrix=True, - ) - - ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - LB = ladder_permutate(compressed_B.cpu()).cuda() - C = kernel(compressed_A, LB) - - latency = profiler.do_bench() - print(f"Latency: {latency}") - # Ensure that the latency is not None - assert latency is not None - - # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) - print(C) - print(ref_c) - torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) - - -@tilelang.testing.requires_package("bitblas") -@tilelang.testing.requires_llvm -@tilelang.testing.requires_cuda -def test_assert_tl_matmul_weight_only_transform(): - assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, T.int8, T.int32, T.int32) - - -if __name__ == "__main__": - tilelang.testing.main() From 1f831f32646ac7be50141c3156d450eba154418c Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Wed, 22 Apr 2026 14:30:00 +0800 Subject: [PATCH 111/156] fix register fragment reuse --- src/transform/auto_schedule.cc | 27 +- src/transform/auto_schedule/barrier.h | 3 +- src/transform/auto_schedule/ir_structure.cc | 35 ++- src/transform/auto_schedule/ir_structure.h | 22 +- .../auto_schedule/schedule_builder.cc | 251 ++++++++++++++++- .../auto_schedule/schedule_builder.h | 11 +- .../auto_schedule/warpgroup_partition.cc | 258 ++++++++++++++++-- .../auto_schedule/warpgroup_partition.h | 13 +- src/transform/if_condition_extract.cc | 10 +- 9 files changed, 574 insertions(+), 56 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 58e94127b9..b68b65f634 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -280,6 +280,7 @@ class IRStructureBuilder : public StmtVisitor { auto control_node = std::make_shared(); control_node->control = GetRef(op); control_node->task = std::make_shared(); + control_node->task->SetWarpgroupId(kWarpgroupBroadcast); control_node->task->stmts.push_back( For(op->loop_var, op->min, op->extent, op->kind, Evaluate(0), op->thread_binding, op->annotations, op->step, op->span)); @@ -376,6 +377,7 @@ class IRStructureBuilder : public StmtVisitor { auto wrapper_node = std::make_shared(); wrapper_node->wrapper = GetRef(op); auto task_node = std::make_shared(); + task_node->SetWarpgroupId(kWarpgroupBroadcast); task_node->stmts.push_back(GetLetDecl(op)); AnalyzeResourceUsage(GetLetDecl(op), task_node.get()); wrapper_node->task = std::move(task_node); @@ -394,6 +396,7 @@ class IRStructureBuilder : public StmtVisitor { auto wrapper_node = std::make_shared(); wrapper_node->wrapper = GetRef(op); auto task_node = std::make_shared(); + task_node->SetWarpgroupId(kWarpgroupBroadcast); task_node->stmts.push_back(GetAttrDecl(op)); AnalyzeResourceUsage(GetAttrDecl(op), task_node.get()); wrapper_node->task = std::move(task_node); @@ -520,8 +523,7 @@ class IRStructureBuilder : public StmtVisitor { } } } - } else if (op->op.same_as(gemm_op) || - op->op.same_as(wgmma_gemm_op) || + } else if (op->op.same_as(gemm_op) || op->op.same_as(wgmma_gemm_op) || op->op.same_as(tcgen05_gemm_op)) { found_tensor = true; @@ -687,6 +689,7 @@ struct ScheduledKernelResult { std::vector barrier_buffers; Map barrier_map; std::vector buffer_infos; + std::vector duplicated_fragment_buffers; PrimExpr updated_thread_extent; bool did_warpgroup_partition{false}; }; @@ -768,7 +771,7 @@ ScheduleSingleKernel(const Stmt &kernel_body, IterVar thread_var, Target target, result.scheduled_body = ApplyWarpgroupPartitionToIRStructure( ir_structure.get(), thread_var, result.barrier_buffers, result.barrier_map, enable_epi, thread_count, config, - neutral_sync_shared_barrier); + neutral_sync_shared_barrier, result.duplicated_fragment_buffers); result.did_warpgroup_partition = true; return result; } @@ -874,8 +877,13 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { final_body = extent_updater(final_body); } // Add barrier buffers to tilelang_root block's alloc_buffers - if (!kr.barrier_buffers.empty()) { - final_body = AddBarrierBuffersToRoot(final_body, kr.barrier_buffers, + if (!kr.barrier_buffers.empty() || + !kr.duplicated_fragment_buffers.empty()) { + std::vector all_alloc_buffers = kr.barrier_buffers; + all_alloc_buffers.insert(all_alloc_buffers.end(), + kr.duplicated_fragment_buffers.begin(), + kr.duplicated_fragment_buffers.end()); + final_body = AddBarrierBuffersToRoot(final_body, all_alloc_buffers, kr.barrier_map); } // Apply multi-version alloc_buffer rewrite if needed @@ -952,9 +960,14 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { scheduled_subtree = extent_updater(scheduled_subtree); } // Add barrier buffers to this kernel's tilelang_root block - if (!kr.barrier_buffers.empty()) { + if (!kr.barrier_buffers.empty() || + !kr.duplicated_fragment_buffers.empty()) { + std::vector all_alloc_buffers = kr.barrier_buffers; + all_alloc_buffers.insert(all_alloc_buffers.end(), + kr.duplicated_fragment_buffers.begin(), + kr.duplicated_fragment_buffers.end()); scheduled_subtree = AddBarrierBuffersToRoot( - scheduled_subtree, kr.barrier_buffers, kr.barrier_map); + scheduled_subtree, all_alloc_buffers, kr.barrier_map); } // Apply multi-version alloc_buffer rewrite if needed if (!kr.buffer_infos.empty()) { diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index c1722e6e78..326331f730 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -463,8 +463,7 @@ static void RewriteGemmMbar(TaskNode *task, PrimExpr mbar_expr) { static const auto wgmma_gemm_op = Op::Get("tl.tileop.wgmma_gemm"); static const auto tcgen05_gemm_op = Op::Get("tl.tileop.tcgen05_gemm"); - if ((op->op.same_as(gemm_op) || - op->op.same_as(wgmma_gemm_op) || + if ((op->op.same_as(gemm_op) || op->op.same_as(wgmma_gemm_op) || op->op.same_as(tcgen05_gemm_op)) && op->args.size() > 16) { Array new_args; diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index b4cd281d38..9b7e987128 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -274,25 +274,36 @@ void TaskNode::CollectBufferAccessInfo( if (GetSchedulePhase() != phase) { return; } - // Collect write buffers - for (const auto ®ion : GetWriteRegions()) { - if (wg_id != -1) { - result.emplace(region->buffer, true, wg_id, phase); + + // Helper: emit buffer access for a single region. + auto emit_access = [&](const BufferRegion ®ion, bool is_write) { + if (wg_id >= 0) { + // Normal assigned warpgroup + result.emplace(region->buffer, is_write, wg_id, phase); + } else if (IsWarpgroupBroadcast(wg_id)) { + // Broadcast: skip register memory (each wg has its own copy) + if (IsRegisterRegion(region)) { + return; + } + // Shared/global memory is shared across wgs — emit for all + for (int i = 0; i < num_wgs; ++i) { + result.emplace(region->buffer, is_write, i, phase); + } } else { + // Unassigned (kWarpgroupUnassigned): expand to all wgs (legacy behavior) for (int i = 0; i < num_wgs; ++i) { - result.emplace(region->buffer, true, i, phase); + result.emplace(region->buffer, is_write, i, phase); } } + }; + + // Collect write buffers + for (const auto ®ion : GetWriteRegions()) { + emit_access(region, true); } // Collect read buffers for (const auto ®ion : GetReadRegions()) { - if (wg_id != -1) { - result.emplace(region->buffer, false, wg_id, phase); - } else { - for (int i = 0; i < num_wgs; ++i) { - result.emplace(region->buffer, false, i, phase); - } - } + emit_access(region, false); } } diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index e9098fb9c9..a1e28318b5 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -38,6 +38,19 @@ enum class SchedulePhase : uint8_t { kEpilogue = 2, // Runs on ALL threads AFTER warpgroup-specific code }; +// Special warpgroup id constants +constexpr int kWarpgroupUnassigned = -1; // Not yet assigned (initial state) +constexpr int kWarpgroupBroadcast = -2; // Broadcast: the statement is cloned + // into every warp group; each wg + // operates on its own register copies. + // No cross-wg sync is needed for + // register (local.fragment) buffers. + +// Helper: check if a warpgroup id represents a broadcast task +inline bool IsWarpgroupBroadcast(int wg_id) { + return wg_id == kWarpgroupBroadcast; +} + // Structure to store buffer access information struct BufferAccessInfo { Buffer buffer; @@ -149,8 +162,8 @@ class IRStructure { // Substitute a variable throughout this IR node virtual void SubstituteVar(const Var &old_var, const Var &new_var) = 0; - // Get warpgroup id for this node (-1 if not applicable) - virtual int GetWarpgroupId() const { return -1; } + // Get warpgroup id for this node (kWarpgroupUnassigned if not applicable) + virtual int GetWarpgroupId() const { return kWarpgroupUnassigned; } // Get scheduling phase for this node virtual SchedulePhase GetSchedulePhase() const { @@ -356,7 +369,8 @@ class TaskNode : public IRStructure { std::set &result) const override; bool containWarpgroupId(int id) const override { - return ContainsLoopBreak() || warpgroup_id_ == id; + return ContainsLoopBreak() || IsWarpgroupBroadcast(warpgroup_id_) || + warpgroup_id_ == id; } // Check if this task contains loop_break call @@ -379,7 +393,7 @@ class TaskNode : public IRStructure { int64_t latency_{0}; // Estimated latency in cycles int64_t ii_{0}; // Initiation interval in cycles int warpgroup_id_{ - -1}; // Warpgroup id for warpgroup specialization (-1 means unassigned) + kWarpgroupUnassigned}; // Warpgroup id for warpgroup specialization SchedulePhase schedule_phase_{ SchedulePhase::kBody}; // Scheduling phase (prologue/body/epilogue) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index a753112c85..260108c9f3 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -46,6 +46,7 @@ #include #include #include +#include #include #include #include @@ -197,6 +198,240 @@ bool HasRegisterRegion(const IRStructure *node) { return CountRegisterRegions(node) > 0; } +// Collect register buffers read by all broadcast tasks in the IR tree. +static void CollectBroadcastRegisterReads( + IRStructure *node, std::unordered_set ®_bufs) { + if (!node) + return; + auto collect_from_leaf_task = [&](const TaskNode *task) { + if (!task) + return; + int wg_id = task->GetWarpgroupId(); + if (!IsWarpgroupBroadcast(wg_id) && wg_id != kWarpgroupUnassigned) + return; + for (const auto ®ion : task->GetReadRegions()) { + if (IsRegisterRegion(region)) { + reg_bufs.insert(region->buffer.get()); + } + } + }; + auto collect_from_structural_task = [&](const TaskNode *task) { + if (!task) + return; + for (const auto ®ion : task->GetReadRegions()) { + if (IsRegisterRegion(region)) { + reg_bufs.insert(region->buffer.get()); + } + } + }; + + if (node->IsTask()) { + collect_from_leaf_task(static_cast(node)); + } else if (node->IsControl()) { + auto ctrl = static_cast(node); + if (ctrl->task) + collect_from_structural_task(ctrl->task.get()); + CollectBroadcastRegisterReads(ctrl->child.get(), reg_bufs); + } else if (node->IsWrapper()) { + auto wrapper = static_cast(node); + if (wrapper->task) + collect_from_structural_task(wrapper->task.get()); + CollectBroadcastRegisterReads(wrapper->child.get(), reg_bufs); + } else if (node->IsSequence()) { + auto seq = static_cast(node); + for (auto &child : seq->children) { + CollectBroadcastRegisterReads(child.get(), reg_bufs); + } + } else if (node->IsScheduleUnit()) { + auto unit = static_cast(node); + CollectBroadcastRegisterReads(unit->child.get(), reg_bufs); + } else if (node->IsIf()) { + auto if_node = static_cast(node); + if (if_node->task) + collect_from_structural_task(if_node->task.get()); + CollectBroadcastRegisterReads(if_node->then_child.get(), reg_bufs); + if (if_node->else_child) + CollectBroadcastRegisterReads(if_node->else_child.get(), reg_bufs); + } +} + +// Propagate broadcast: if a broadcast task reads a register buffer, +// any leaf task that writes that register buffer must also be broadcast +// (because each wg needs its own initialized copy). +void PropagateBroadcastWarpgroupId(IRStructure *root) { + std::vector all_tasks; + CollectAllTaskNodesWithContext(root, all_tasks); + + bool changed = true; + while (changed) { + changed = false; + // 1. Collect register buffers read by all current broadcast tasks + std::unordered_set broadcast_reg_reads; + CollectBroadcastRegisterReads(root, broadcast_reg_reads); + // Also collect from leaf broadcast tasks + for (auto &task_ctx : all_tasks) { + TaskNode *task = task_ctx.task; + if (!IsWarpgroupBroadcast(task->GetWarpgroupId())) + continue; + for (const auto ®ion : task->GetReadRegions()) { + if (IsRegisterRegion(region)) { + broadcast_reg_reads.insert(region->buffer.get()); + } + } + } + // 2. Mark leaf tasks that write these register buffers as broadcast + if (!broadcast_reg_reads.empty()) { + for (auto &task_ctx : all_tasks) { + TaskNode *task = task_ctx.task; + if (IsWarpgroupBroadcast(task->GetWarpgroupId())) + continue; + for (const auto ®ion : task->GetWriteRegions()) { + if (IsRegisterRegion(region) && + broadcast_reg_reads.count(region->buffer.get())) { + task->SetWarpgroupId(kWarpgroupBroadcast); + changed = true; + break; + } + } + } + } + + // 3. Detect cross-warpgroup register buffer / scalar-var accesses + constexpr int kReaderAnyWg = std::numeric_limits::min(); + std::unordered_map> + buffer_reader_wgs; + std::unordered_map> + buffer_writer_wgs; + std::unordered_map> var_reader_wgs; + std::unordered_map> var_writer_wgs; + + auto add_accesses = [&](const TaskNode *task, int reader_key, int wg_id) { + for (const auto ®ion : task->GetReadRegions()) { + if (IsRegisterRegion(region)) { + buffer_reader_wgs[region->buffer.get()].insert(reader_key); + } + } + if (wg_id >= 0) { + for (const auto ®ion : task->GetWriteRegions()) { + if (IsRegisterRegion(region)) { + buffer_writer_wgs[region->buffer.get()].insert(wg_id); + } + } + } + for (const auto &v : task->GetReadVars()) { + var_reader_wgs[v.get()].insert(reader_key); + } + if (wg_id >= 0) { + for (const auto &v : task->GetWriteVars()) { + var_writer_wgs[v.get()].insert(wg_id); + } + } + }; + + for (auto &task_ctx : all_tasks) { + TaskNode *task = task_ctx.task; + int wg_id = task->GetWarpgroupId(); + int reader_key = (wg_id >= 0) ? wg_id : kReaderAnyWg; + add_accesses(task, reader_key, wg_id); + } + // Also include structural tasks + std::function walk_structural = + [&](IRStructure *node) { + if (!node) + return; + auto add_struct = [&](const TaskNode *t) { + if (!t) + return; + int wg = t->GetWarpgroupId(); + int key = (wg >= 0) ? wg : kReaderAnyWg; + if (IsWarpgroupBroadcast(wg)) + key = kReaderAnyWg; + add_accesses(t, key, -1); + }; + if (node->IsControl()) { + auto *c = static_cast(node); + add_struct(c->task.get()); + walk_structural(c->child.get()); + } else if (node->IsWrapper()) { + auto *w = static_cast(node); + add_struct(w->task.get()); + walk_structural(w->child.get()); + } else if (node->IsIf()) { + auto *i = static_cast(node); + add_struct(i->task.get()); + walk_structural(i->then_child.get()); + walk_structural(i->else_child.get()); + } else if (node->IsSequence()) { + auto *s = static_cast(node); + for (auto &c : s->children) + walk_structural(c.get()); + } else if (node->IsScheduleUnit()) { + auto *u = static_cast(node); + walk_structural(u->child.get()); + } + // Leaf tasks already handled by the all_tasks loop. + }; + walk_structural(root); + + std::unordered_set cross_wg_buffers; + for (const auto &kv : buffer_reader_wgs) { + const auto *buf = kv.first; + const auto &reader_wgs = kv.second; + const auto &writer_wgs = buffer_writer_wgs[buf]; + for (int rwg : reader_wgs) { + if (writer_wgs.find(rwg) == writer_wgs.end()) { + cross_wg_buffers.insert(buf); + break; + } + } + } + + std::unordered_set cross_wg_vars; + for (const auto &kv : var_reader_wgs) { + const auto *v = kv.first; + const auto &reader_wgs = kv.second; + auto it_w = var_writer_wgs.find(v); + if (it_w == var_writer_wgs.end()) + continue; + const auto &writer_wgs = it_w->second; + for (int rwg : reader_wgs) { + if (writer_wgs.find(rwg) == writer_wgs.end()) { + cross_wg_vars.insert(v); + break; + } + } + } + + if (!cross_wg_buffers.empty() || !cross_wg_vars.empty()) { + for (auto &task_ctx : all_tasks) { + TaskNode *task = task_ctx.task; + if (IsWarpgroupBroadcast(task->GetWarpgroupId())) + continue; + bool should_broadcast = false; + for (const auto ®ion : task->GetWriteRegions()) { + if (IsRegisterRegion(region) && + cross_wg_buffers.count(region->buffer.get())) { + should_broadcast = true; + break; + } + } + if (!should_broadcast) { + for (const auto &v : task->GetWriteVars()) { + if (cross_wg_vars.count(v.get())) { + should_broadcast = true; + break; + } + } + } + if (should_broadcast) { + task->SetWarpgroupId(kWarpgroupBroadcast); + changed = true; + } + } + } + } +} + bool HasResourceDependency(const IRStructure *a, const IRStructure *b) { if (a->UsesTMACore() && b->UsesTMACore()) return true; @@ -370,7 +605,14 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, int n = all_tasks.size(); for (auto &task_ctx : all_tasks) { - task_ctx.task->SetWarpgroupId(-1); + task_ctx.task->SetWarpgroupId(kWarpgroupUnassigned); + } + + // Tasks with loop_break are broadcast to all warp groups + for (auto &task_ctx : all_tasks) { + if (task_ctx.task->ContainsLoopBreak()) { + task_ctx.task->SetWarpgroupId(kWarpgroupBroadcast); + } } TaskUnionFind uf(n); @@ -747,7 +989,7 @@ NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, for (auto &task_ctx : all_tasks) { TaskNode *task = task_ctx.task; if (task->ContainsLoopBreak()) { - task->SetWarpgroupId(-1); + task->SetWarpgroupId(kWarpgroupBroadcast); continue; } if (task->UsesTMACore() && !task->UsesTensorCore()) { @@ -1003,7 +1245,10 @@ void ScheduleUnitBuilder::NaiveScheduleRecursive( std::vector ScheduleUnitBuilder::NaiveBuild(std::shared_ptr &root) { NaiveScheduleRecursive(root); - return NaiveAssignWarpgroupIds(root.get(), config_, thread_var_->dom->extent); + auto result = + NaiveAssignWarpgroupIds(root.get(), config_, thread_var_->dom->extent); + PropagateBroadcastWarpgroupId(root.get()); + return result; } } // namespace tl diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index 975dddd696..1aef87d785 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -52,11 +52,14 @@ std::vector AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, PrimExpr thread_count); -// Naive warpgroup assignment: TMA→wg1, compute→wg0, neutral→-1 +// Naive warpgroup assignment: TMA→wg1, compute→wg0, +// broadcast→kWarpgroupBroadcast std::vector NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, PrimExpr thread_count); +void PropagateBroadcastWarpgroupId(IRStructure *root); + // Extract all sequential task nodes from the IR structure tree void GatherTaskNodes(const std::vector> &nodes, std::vector> &task_nodes); @@ -100,8 +103,10 @@ class ScheduleUnitBuilder { ScheduleRecursive(root, {}); // Global warpgroup id assignment from the top level - return AssignWarpgroupIdsGlobal(root.get(), config_, - thread_var_->dom->extent); + auto result = + AssignWarpgroupIdsGlobal(root.get(), config_, thread_var_->dom->extent); + PropagateBroadcastWarpgroupId(root.get()); + return result; } // Naive build: preserve original order, assign pipeline stages based on diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 36a1a3225b..58ebb731b9 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -124,6 +124,153 @@ static Stmt RenameLetStmtVars(Stmt stmt, const std::string &suffix) { return LetStmtVarRenamer(suffix).Rename(std::move(stmt)); } +// Mutator that replaces references to selected Buffers with their duplicates. +class BufferRemapMutator : public StmtExprMutator { +public: + explicit BufferRemapMutator(const Map &buffer_remap) + : buffer_remap_(buffer_remap) { + for (const auto &kv : buffer_remap_) { + var_to_new_buffer_.Set(kv.first->data, kv.second); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_remap_.find(load->buffer); + if (it == buffer_remap_.end()) + return std::move(load); + auto *n = load.CopyOnWrite(); + n->buffer = (*it).second; + return std::move(load); + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_remap_.find(store->buffer); + if (it == buffer_remap_.end()) + return std::move(store); + auto *n = store.CopyOnWrite(); + n->buffer = (*it).second; + return std::move(store); + } + + PrimExpr VisitExpr_(const VarNode *op) final { + Var var = GetRef(op); + auto it = var_to_new_buffer_.find(var); + if (it != var_to_new_buffer_.end()) { + return (*it).second->data; + } + return StmtExprMutator::VisitExpr_(op); + } + + BufferRegion RemapRegion(const BufferRegion ®ion) const { + auto it = buffer_remap_.find(region->buffer); + if (it == buffer_remap_.end()) + return region; + return BufferRegion((*it).second, region->region); + } + + Var RemapVar(const Var &var) const { + auto it = var_to_new_buffer_.find(var); + if (it == var_to_new_buffer_.end()) + return var; + return (*it).second->data; + } + +private: + const Map &buffer_remap_; + Map var_to_new_buffer_; +}; + +// Collect all local.fragment Buffers +static void CollectBroadcastFragmentBuffersImpl( + const IRStructure *node, std::unordered_set &seen, + std::vector &out) { + if (!node) + return; + if (node->IsTask()) { + auto task = static_cast(node); + if (IsWarpgroupBroadcast(task->GetWarpgroupId())) { + for (const auto ®ion : task->GetWriteRegions()) { + if (IsRegisterRegion(region) && + region->buffer.scope() == "local.fragment") { + if (!seen.count(region->buffer.get())) { + seen.insert(region->buffer.get()); + out.push_back(region->buffer); + } + } + } + } + } else if (node->IsSequence()) { + for (const auto &child : + static_cast(node)->children) { + CollectBroadcastFragmentBuffersImpl(child.get(), seen, out); + } + } else if (node->IsControl()) { + auto ctrl = static_cast(node); + CollectBroadcastFragmentBuffersImpl(ctrl->task.get(), seen, out); + CollectBroadcastFragmentBuffersImpl(ctrl->child.get(), seen, out); + } else if (node->IsWrapper()) { + auto wrapper = static_cast(node); + CollectBroadcastFragmentBuffersImpl(wrapper->task.get(), seen, out); + CollectBroadcastFragmentBuffersImpl(wrapper->child.get(), seen, out); + } else if (node->IsScheduleUnit()) { + CollectBroadcastFragmentBuffersImpl( + static_cast(node)->child.get(), seen, out); + } else if (node->IsIf()) { + auto if_node = static_cast(node); + CollectBroadcastFragmentBuffersImpl(if_node->task.get(), seen, out); + CollectBroadcastFragmentBuffersImpl(if_node->then_child.get(), seen, out); + CollectBroadcastFragmentBuffersImpl(if_node->else_child.get(), seen, out); + } +} + +static std::vector +CollectBroadcastFragmentBuffers(const IRStructure *root) { + std::vector result; + std::unordered_set seen; + CollectBroadcastFragmentBuffersImpl(root, seen, result); + return result; +} + +static Buffer DuplicateFragmentBuffer(const Buffer &buffer, + const std::string &suffix) { + Type new_type = buffer->data->type_annotation; + if (IsFragmentBuffer(buffer)) { + const auto *ptr_type = buffer->data->type_annotation.as(); + ICHECK(ptr_type); + new_type = PointerType(ptr_type->element_type, "local"); + } + Var new_var(buffer->data->name_hint + suffix, new_type); + return Buffer(new_var, buffer->dtype, buffer->shape, buffer->strides, + buffer->elem_offset, buffer->name + suffix, + buffer->data_alignment, buffer->offset_factor, + buffer->buffer_type); +} + +static void ApplyBufferRemapToTask(TaskNode *task, + BufferRemapMutator &mutator) { + for (size_t i = 0; i < task->stmts.size(); ++i) { + task->stmts[i] = mutator(task->stmts[i]); + } + auto read_regions = task->GetReadRegions(); + for (auto &r : read_regions) + r = mutator.RemapRegion(r); + task->SetReadRegions(read_regions); + auto write_regions = task->GetWriteRegions(); + for (auto &r : write_regions) + r = mutator.RemapRegion(r); + task->SetWriteRegions(write_regions); + auto read_vars = task->GetReadVars(); + for (auto &v : read_vars) + v = mutator.RemapVar(v); + task->SetReadVars(read_vars); + auto write_vars = task->GetWriteVars(); + for (auto &v : write_vars) + v = mutator.RemapVar(v); + task->SetWriteVars(write_vars); +} + bool IsLetDeclTask(const TaskNode *task) { return task->stmts.size() == 1 && task->stmts[0].as() != nullptr; } @@ -178,10 +325,30 @@ bool ContainsLetDecl(const IRStructure *node) { // Helper function to clone IRStructure with warpgroup filter. std::shared_ptr CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, - Map &var_remap) { + Map &var_remap, + Map &buffer_remap) { if (!node) return nullptr; + auto apply_buffer_remap_stmt = [&](Stmt s) -> Stmt { + if (buffer_remap.empty()) + return s; + BufferRemapMutator m(buffer_remap); + return m(std::move(s)); + }; + auto apply_buffer_remap_expr = [&](PrimExpr e) -> PrimExpr { + if (buffer_remap.empty()) + return e; + BufferRemapMutator m(buffer_remap); + return m(std::move(e)); + }; + auto apply_buffer_remap_task = [&](TaskNode *ct) { + if (buffer_remap.empty()) + return; + BufferRemapMutator m(buffer_remap); + ApplyBufferRemapToTask(ct, m); + }; + if (node->IsTask()) { auto task = static_cast(node); if (task->GetSchedulePhase() != SchedulePhase::kBody) { @@ -197,6 +364,7 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, // Substitute previously renamed variables in the value expression. PrimExpr new_value = var_remap.empty() ? let->value : Substitute(let->value, var_remap); + new_value = apply_buffer_remap_expr(new_value); var_remap.Set(let->var, new_var); auto new_task = std::make_shared(); new_task->stmts.push_back(LetStmt(new_var, new_value, Evaluate(0))); @@ -214,6 +382,7 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, ct->stmts[i] = Substitute(ct->stmts[i], var_remap); } } + apply_buffer_remap_task(static_cast(cloned.get())); return cloned; } else if (node->IsSequence()) { // A SequenceNode is included if it contains the target warp group @@ -224,7 +393,7 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, auto new_seq = std::make_shared(); for (const auto &child : seq->children) { auto new_child = CloneIRStructureWithWarpgroupFilter( - child.get(), warpgroup_id, var_remap); + child.get(), warpgroup_id, var_remap, buffer_remap); if (new_child) { new_seq->children.push_back(std::move(new_child)); } @@ -243,12 +412,13 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, auto new_loop_var = ctrl->control->loop_var.copy_with_suffix(""); new_for.CopyOnWrite()->loop_var = new_loop_var; var_remap.Set(ctrl->control->loop_var, new_loop_var); - new_for.CopyOnWrite()->min = Substitute(ctrl->control->min, var_remap); + new_for.CopyOnWrite()->min = + apply_buffer_remap_expr(Substitute(ctrl->control->min, var_remap)); new_for.CopyOnWrite()->extent = - Substitute(ctrl->control->extent, var_remap); + apply_buffer_remap_expr(Substitute(ctrl->control->extent, var_remap)); if (ctrl->control->step.has_value()) { - new_for.CopyOnWrite()->step = - Substitute(ctrl->control->step.value(), var_remap); + new_for.CopyOnWrite()->step = apply_buffer_remap_expr( + Substitute(ctrl->control->step.value(), var_remap)); } new_ctrl->control = new_for; // Clone the task and apply var_remap so each warpgroup gets its own copy @@ -261,11 +431,12 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, cloned_task->stmts[i] = Substitute(cloned_task->stmts[i], var_remap); } } + apply_buffer_remap_task(cloned_task.get()); new_ctrl->task = std::move(cloned_task); } new_ctrl->SetPromote(ctrl->hasPromote()); new_ctrl->child = CloneIRStructureWithWarpgroupFilter( - ctrl->child.get(), warpgroup_id, var_remap); + ctrl->child.get(), warpgroup_id, var_remap, buffer_remap); return new_ctrl; } else if (node->IsWrapper()) { if (!node->containWarpgroupId(warpgroup_id) && !ContainsLetDecl(node)) @@ -277,8 +448,9 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, new_wrapper->wrapper = var_remap.empty() ? wrapper->wrapper : Substitute(wrapper->wrapper, var_remap); + new_wrapper->wrapper = apply_buffer_remap_stmt(new_wrapper->wrapper); new_wrapper->child = CloneIRStructureWithWarpgroupFilter( - wrapper->child.get(), warpgroup_id, var_remap); + wrapper->child.get(), warpgroup_id, var_remap, buffer_remap); return new_wrapper; } else if (node->IsScheduleUnit()) { auto unit = static_cast(node); @@ -292,7 +464,7 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, auto new_unit = std::make_shared(); new_unit->stage = unit->stage; new_unit->child = CloneIRStructureWithWarpgroupFilter( - unit->child.get(), warpgroup_id, var_remap); + unit->child.get(), warpgroup_id, var_remap, buffer_remap); if (!child_is_let_decl) { // Copy before/after for the target warp group @@ -307,6 +479,12 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, s = Substitute(s, var_remap); } } + for (auto &s : new_unit->before[warpgroup_id]) { + s = apply_buffer_remap_stmt(s); + } + for (auto &s : new_unit->after[warpgroup_id]) { + s = apply_buffer_remap_stmt(s); + } } return new_unit; } else if (node->IsIf()) { @@ -317,6 +495,7 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, new_if->condition = var_remap.empty() ? if_node->condition : Substitute(if_node->condition, var_remap); + new_if->condition = apply_buffer_remap_expr(new_if->condition); if (if_node->task) { auto cloned_task = std::static_pointer_cast(if_node->task->Clone()); @@ -325,13 +504,14 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, cloned_task->stmts[i] = Substitute(cloned_task->stmts[i], var_remap); } } + apply_buffer_remap_task(cloned_task.get()); new_if->task = std::move(cloned_task); } new_if->then_child = CloneIRStructureWithWarpgroupFilter( - if_node->then_child.get(), warpgroup_id, var_remap); + if_node->then_child.get(), warpgroup_id, var_remap, buffer_remap); if (if_node->else_child) { new_if->else_child = CloneIRStructureWithWarpgroupFilter( - if_node->else_child.get(), warpgroup_id, var_remap); + if_node->else_child.get(), warpgroup_id, var_remap, buffer_remap); } // Return nullptr if both branches are empty if (!new_if->then_child && !new_if->else_child) @@ -342,11 +522,20 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, return nullptr; } -// Entry point overload — creates a fresh var_remap per call +std::shared_ptr +CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, + Map &var_remap) { + Map buffer_remap; + return CloneIRStructureWithWarpgroupFilter(node, warpgroup_id, var_remap, + buffer_remap); +} + std::shared_ptr CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id) { Map var_remap; - return CloneIRStructureWithWarpgroupFilter(node, warpgroup_id, var_remap); + Map buffer_remap; + return CloneIRStructureWithWarpgroupFilter(node, warpgroup_id, var_remap, + buffer_remap); } // For each child of a root SequenceNode, apply @@ -354,16 +543,26 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id) { std::vector> CloneIRStructureChildrenWithWarpgroupFilter(SequenceNode *root_seq, int warpgroup_id, - Map &var_remap) { + Map &var_remap, + Map &buffer_remap) { std::vector> result; result.reserve(root_seq->children.size()); for (const auto &child : root_seq->children) { result.push_back(CloneIRStructureWithWarpgroupFilter( - child.get(), warpgroup_id, var_remap)); + child.get(), warpgroup_id, var_remap, buffer_remap)); } return result; } +std::vector> +CloneIRStructureChildrenWithWarpgroupFilter(SequenceNode *root_seq, + int warpgroup_id, + Map &var_remap) { + Map buffer_remap; + return CloneIRStructureChildrenWithWarpgroupFilter(root_seq, warpgroup_id, + var_remap, buffer_remap); +} + std::shared_ptr RemoveUnusedLetDecls(std::shared_ptr root) { if (!root) @@ -834,7 +1033,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( IRStructure *root, IterVar thread_var, std::vector &barrier_buffers, Map &barrier_map, const bool outer_enable_epi, const std::vector &thread_count, - const WarpSpecializeConfig &config, Buffer neutral_sync_shared_barrier) { + const WarpSpecializeConfig &config, Buffer neutral_sync_shared_barrier, + std::vector &duplicated_fragment_buffers) { if (!root) return Evaluate(0); @@ -844,7 +1044,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( if (wrapper->child) { body = ApplyWarpgroupPartitionToIRStructure( wrapper->child.get(), thread_var, barrier_buffers, barrier_map, - outer_enable_epi, thread_count, config, neutral_sync_shared_barrier); + outer_enable_epi, thread_count, config, neutral_sync_shared_barrier, + duplicated_fragment_buffers); } if (const auto *let = wrapper->wrapper.as()) { return LetStmt(let->var, let->value, body); @@ -927,12 +1128,27 @@ Stmt ApplyWarpgroupPartitionToIRStructure( // wg_children[wg_id][child_index] = filtered IRStructure (nullptr if absent) std::vector>> wg_children(num_wgs); std::vector> wg_structures(num_wgs); + + std::vector broadcast_fragments = + CollectBroadcastFragmentBuffers(root); + std::vector> per_wg_buffer_remap(num_wgs); + if (!broadcast_fragments.empty()) { + for (size_t i = 1; i < num_wgs; ++i) { + std::string suffix = "_wg" + std::to_string(i); + for (const auto &buf : broadcast_fragments) { + Buffer new_buf = DuplicateFragmentBuffer(buf, suffix); + per_wg_buffer_remap[i].Set(buf, new_buf); + duplicated_fragment_buffers.push_back(new_buf); + } + } + } + if (root->IsSequence()) { auto root_seq = static_cast(root); for (size_t i = 0; i < num_wgs; ++i) { Map var_remap; - wg_children[i] = - CloneIRStructureChildrenWithWarpgroupFilter(root_seq, i, var_remap); + wg_children[i] = CloneIRStructureChildrenWithWarpgroupFilter( + root_seq, i, var_remap, per_wg_buffer_remap[i]); } for (size_t i = 0; i < num_wgs; ++i) { // Rebuild from wg_children: wrap non-null children into a SequenceNode @@ -946,8 +1162,10 @@ Stmt ApplyWarpgroupPartitionToIRStructure( } else { // Fallback for non-SequenceNode root: clone entire root per warpgroup for (size_t i = 0; i < num_wgs; ++i) { + Map var_remap; wg_structures[i] = - RemoveUnusedLetDecls(CloneIRStructureWithWarpgroupFilter(root, i)); + RemoveUnusedLetDecls(CloneIRStructureWithWarpgroupFilter( + root, i, var_remap, per_wg_buffer_remap[i])); } } diff --git a/src/transform/auto_schedule/warpgroup_partition.h b/src/transform/auto_schedule/warpgroup_partition.h index a8ee663057..fcf0d7e5fa 100644 --- a/src/transform/auto_schedule/warpgroup_partition.h +++ b/src/transform/auto_schedule/warpgroup_partition.h @@ -33,6 +33,10 @@ bool IsLetDeclTask(const TaskNode *task); bool IsLetDeclNode(const IRStructure *node); bool ContainsLetDecl(const IRStructure *node); +std::shared_ptr +CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, + Map &var_remap, + Map &buffer_remap); std::shared_ptr CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, Map &var_remap); @@ -42,6 +46,12 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id); std::shared_ptr RemoveUnusedLetDecls(std::shared_ptr root); +std::vector> +CloneIRStructureChildrenWithWarpgroupFilter(SequenceNode *root_seq, + int warpgroup_id, + Map &var_remap, + Map &buffer_remap); + std::vector> CloneIRStructureChildrenWithWarpgroupFilter(SequenceNode *root_seq, int warpgroup_id, @@ -55,7 +65,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( IRStructure *root, IterVar thread_var, std::vector &barrier_buffers, Map &barrier_map, const bool enable_epi, const std::vector &thread_count, - const WarpSpecializeConfig &config, Buffer neutral_sync_shared_barrier); + const WarpSpecializeConfig &config, Buffer neutral_sync_shared_barrier, + std::vector &duplicated_fragment_buffers); Stmt ReNestLetStmts(const Stmt &stmt); diff --git a/src/transform/if_condition_extract.cc b/src/transform/if_condition_extract.cc index bca573b348..de6d669995 100644 --- a/src/transform/if_condition_extract.cc +++ b/src/transform/if_condition_extract.cc @@ -56,9 +56,10 @@ class IfConditionExtractor : public StmtExprMutator { is_simple = false; } - auto bind_cond_var = [](const Stmt &sentence, const Var &cond) -> Stmt { + auto bind_cond_var = [](const Stmt &sentence, + const PrimExpr &cond) -> Stmt { if (auto if_sentence = sentence.as()) { - PrimExpr new_cond = cond & if_sentence->condition; + PrimExpr new_cond = cond && if_sentence->condition; return IfThenElse(new_cond, if_sentence->then_case, if_sentence->else_case); } else { @@ -67,7 +68,7 @@ class IfConditionExtractor : public StmtExprMutator { }; auto bind_cond_var_body = [&](const Optional &body, - const Var &cond) -> Stmt { + const PrimExpr &cond) -> Stmt { if (!body.defined()) { return Stmt(); } @@ -85,7 +86,8 @@ class IfConditionExtractor : public StmtExprMutator { Array new_seq; new_seq.insert(new_seq.end(), bind_cond_var_body(then_case, cond_var)); if (else_case.defined()) - new_seq.insert(new_seq.end(), bind_cond_var_body(else_case, cond_var)); + new_seq.insert(new_seq.end(), + bind_cond_var_body(else_case, Not(cond_var))); Stmt body = new_seq.empty() From 441c3b06acb23b09d68639532d3a21f427370ced Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 22 Apr 2026 16:22:01 +0800 Subject: [PATCH 112/156] [Refactor] Strip build machine paths from LOG messages in wheel releases (#2080) * [Refactor] Use TVM_LOG_CUSTOMIZE to strip build paths from LOG messages in release builds Enable TVM's custom logging hook (TVM_LOG_CUSTOMIZE) and provide our own LogMessageImpl/LogFatalImpl that conditionally omit source file paths. In wheel/release builds (detected via CIBUILDWHEEL or SKBUILD_STATE env vars), LOG(WARNING) and friends no longer leak CI machine paths, showing only the message. Local dev builds keep full paths for debugging. Co-Authored-By: Claude Opus 4.6 * [Refactor] Downgrade noisy LOG(WARNING) to DLOG(WARNING) These warnings about TMA/swizzle layout fallbacks and warp specialization status are development diagnostics, not actionable for end users. Use DLOG so they are compiled out in release (wheel) builds. Co-Authored-By: Claude Opus 4.6 * [Chore] Apply clang-format Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 --- CMakeLists.txt | 25 +++++++++++++++ src/op/atomic_add.cc | 17 ++++++----- src/runtime/logging.cc | 44 +++++++++++++++++++++++++++ src/transform/producer_consumer_ws.cc | 8 ++--- 4 files changed, 83 insertions(+), 11 deletions(-) create mode 100644 src/runtime/logging.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 65b06a7052..51fc79a079 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -426,9 +426,30 @@ if(USE_Z3 AND USE_PYPI_Z3) find_package(Z3 REQUIRED) endif() +# Enable custom logging so we control the output format (e.g. strip build paths +# from __FILE__ so wheel users don't see CI machine paths in warnings). +set(USE_CUSTOM_LOGGING ON CACHE BOOL "Use custom logging implementation" FORCE) + +# Detect release (wheel) builds: in CI (cibuildwheel) or scikit-build-core wheel builds, +# we strip source paths from LOG(WARNING) etc. for a cleaner user experience. +# Local dev builds keep full paths for debugging. +if(DEFINED ENV{CIBUILDWHEEL} OR "$ENV{SKBUILD_STATE}" STREQUAL "wheel") + set(TILELANG_RELEASE_BUILD_DEFAULT ON) +else() + set(TILELANG_RELEASE_BUILD_DEFAULT OFF) +endif() +option(TILELANG_RELEASE_BUILD "Strip source paths from log messages (for wheel releases)" ${TILELANG_RELEASE_BUILD_DEFAULT}) + # Include tvm after configs have been populated add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL) +# Provide the custom LogMessageImpl / LogFatalImpl implementation to TVM, +# since TVM_LOG_CUSTOMIZE=1 requires them to be supplied by the user. +target_sources(tvm_objs PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/logging.cc") +if(TILELANG_RELEASE_BUILD) + target_compile_definitions(tvm_objs PRIVATE TILELANG_RELEASE_BUILD=1) +endif() + # Resolve compile warnings in tvm add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=) @@ -442,6 +463,10 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug") endif() target_include_directories(tilelang_objs PRIVATE ${TILE_LANG_INCLUDES}) +target_compile_definitions(tilelang_objs PRIVATE TVM_LOG_CUSTOMIZE=1) +if(TILELANG_RELEASE_BUILD) + target_compile_definitions(tilelang_objs PRIVATE TILELANG_RELEASE_BUILD=1) +endif() add_library(tilelang SHARED $) target_link_libraries(tilelang PUBLIC tvm) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 816ce379b4..b4fa1fb6b0 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -456,12 +456,15 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { shared_layout, makeGemmABLayoutPadded(*stride, *continuous, shared_tensor->dtype.bits()))) { - LOG(WARNING) << "AtomicAdd TMA cannot support a padded layout for src: " - << src->name << ", dst: " << dst->name; + DLOG(WARNING) + << "AtomicAdd TMA cannot support a padded layout for src: " + << src->name << ", dst: " << dst->name + << " fallback to none swizzle"; desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); } else { - LOG(WARNING) << "AtomicAdd TMA unsupported swizzle layout for src: " - << src->name << ", dst: " << dst->name; + DLOG(WARNING) << "AtomicAdd TMA unsupported swizzle layout for src: " + << src->name << ", dst: " << dst->name + << " fallback to none swizzle"; desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); } } @@ -499,9 +502,9 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { }; for (const auto &check : swizzle_checks) { if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) { - LOG(WARNING) << "AtomicAdd TMA cannot support swizzled layout with " - "inner_box_dim_ > " - << check.max_dim; + DLOG(WARNING) << "AtomicAdd TMA cannot support swizzled layout with " + "inner_box_dim_ > " + << check.max_dim; } } diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc new file mode 100644 index 0000000000..2088459b30 --- /dev/null +++ b/src/runtime/logging.cc @@ -0,0 +1,44 @@ +#include + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace detail { + +namespace { +const char *level_strings[] = { + ": Debug: ", // TVM_LOG_LEVEL_DEBUG = 0 + ": ", // TVM_LOG_LEVEL_INFO = 1 + ": Warning: ", // TVM_LOG_LEVEL_WARNING = 2 + ": Error: ", // TVM_LOG_LEVEL_ERROR = 3 + ": Fatal: ", // TVM_LOG_LEVEL_FATAL = 4 +}; +} // namespace + +void LogMessageImpl(const std::string &file, int lineno, int level, + const std::string &message) { + std::time_t t = std::time(nullptr); + std::cerr << "[" << std::put_time(std::localtime(&t), "%H:%M:%S") << "] "; +#ifdef TILELANG_RELEASE_BUILD + // Release (wheel) builds: omit file path for a cleaner user experience. + std::cerr << level_strings[level] << message << std::endl; +#else + // Dev builds: include file path for debugging. + std::cerr << file << ":" << lineno << level_strings[level] << message + << std::endl; +#endif +} + +[[noreturn]] void LogFatalImpl(const std::string &file, int lineno, + const std::string &message) { + LogMessageImpl(file, lineno, TVM_LOG_LEVEL_FATAL, message); + throw InternalError(file, lineno, message); +} + +} // namespace detail +} // namespace runtime +} // namespace tvm diff --git a/src/transform/producer_consumer_ws.cc b/src/transform/producer_consumer_ws.cc index cb69bfa1c7..a09a69f1ba 100644 --- a/src/transform/producer_consumer_ws.cc +++ b/src/transform/producer_consumer_ws.cc @@ -2394,10 +2394,10 @@ tvm::transform::Pass ProducerConsumerWarpSpecialized() { } // Only apply MVB + WS if the function is a tiled WS candidate. if (!TiledWSCandidate::Check(f->body, target.value())) { - LOG(WARNING) << "[WS] skipped: no TMA copies in pipeline loop"; + DLOG(WARNING) << "[WS] skipped: no TMA copies in pipeline loop"; return f; } - LOG(WARNING) << "[WS] candidate found, applying MVB + WS"; + DLOG(WARNING) << "[WS] candidate found, applying MVB + WS"; // Expand shared buffers for pipelining before the WS split. // Keep the original so we can fall back if the WS rewriter doesn't fire // (e.g. non-tile-op consumers in the loop body). @@ -2405,7 +2405,7 @@ tvm::transform::Pass ProducerConsumerWarpSpecialized() { f = ApplyMultiVersionBufferRewriter(std::move(f)); PrimFunc result = ProducerConsumerWSRewriter::Substitute(std::move(f)); if (!result->HasNonzeroAttr(kTiledWSApplied)) { - LOG(WARNING) << "[WS] rewriter did not fire, falling back"; + DLOG(WARNING) << "[WS] rewriter did not fire, falling back"; // The TMA kernel needs warp specialization for correct pipelined // execution. Since the tiled rewriter could not apply WS (e.g. // conditional loop body), strip pipeline annotations so that @@ -2432,7 +2432,7 @@ tvm::transform::Pass ProducerConsumerWarpSpecialized() { fn->body = stripped; return original_f; } - LOG(WARNING) << "[WS] transformation applied successfully"; + DLOG(WARNING) << "[WS] transformation applied successfully"; return result; }; return CreatePrimFuncPass(pass_func, 0, "tl.ProducerConsumerWarpSpecialized", From a640a89e25de1d22cb80efc1a0dfaf38a5076d46 Mon Sep 17 00:00:00 2001 From: Jiawei Xiang Date: Wed, 22 Apr 2026 17:41:15 +0800 Subject: [PATCH 113/156] [AMD][Radeon] Add the Support of RDNA3/RDNA3.5(gfx11) WMMA (#2044) * feat: support RDNA3 wmma pattern * fix: support tuple-based threadblock swizzle parsing in HIP codegen * fix: improve some notes * fix: support leading stage dims in WMMA shared loads * fix: add transpose-aware RDNA WMMA A layout selection * fix: fix annotation errors * fix: use ffi helper for rdna generation detection * fix: fix format error * fix: fix lower call --- examples/amd/example_amd_flash_attn_bwd.py | 96 ++++-- src/target/codegen_hip.cc | 41 ++- src/target/codegen_hip.h | 4 + src/target/rt_mod_hip.cc | 2 + src/target/utils.cc | 15 + src/target/utils.h | 1 + tilelang/intrinsics/wmma_layout.py | 339 +++++++++++++++++++- tilelang/intrinsics/wmma_macro_generator.py | 170 +++++----- 8 files changed, 551 insertions(+), 117 deletions(-) diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py index 27986ce78d..0a9b7d26cb 100644 --- a/examples/amd/example_amd_flash_attn_bwd.py +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -1,3 +1,4 @@ +import sys import torch import torch.nn.functional as F import tilelang @@ -10,6 +11,15 @@ import time +def IsRDNA(): + if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name().strip() + return "Radeon" in gpu_name + else: + print("Error: GPU Device is not detected") + sys.exit(1) + + def ref_program(Q, K, V, is_causal, groups=1): assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" @@ -30,13 +40,24 @@ def ref_program(Q, K, V, is_causal, groups=1): def get_fwd_configs(): - block_M = [32, 64, 128, 256] - block_N = [32, 64, 128, 256] - threads = [128, 256, 512] - num_split_q = [64, 128, 256] - num_stages = [0, 1] + # Match the standalone forward example on RDNA. WMMA configs larger than + # 32x32 can trigger layout issues when bridging the softmax fragment into + # the second GEMM's A-layout. + if IsRDNA(): + block_M = [16, 32, 64] + block_N = [16, 32, 64] + threads = [32, 64] + num_split_q = [16, 32, 64] + num_stages = [0] + k_pack = [1] + else: + block_M = [32, 64, 128, 256] + block_N = [32, 64, 128, 256] + threads = [128, 256, 512] + num_split_q = [64, 128, 256] + num_stages = [0, 1] + k_pack = [2] enable_rasterization = [True] - k_pack = [2] panel_size = [7, 8, 9, 10] qk_coalesced_width = [8] v_coalesced_width = [4] @@ -46,6 +67,8 @@ def get_fwd_configs(): for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width ): + if IsRDNA() and m == 16 and n == 16 and t == 64: + continue valid_configs.append( { "block_M": m, @@ -127,6 +150,10 @@ def main( Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) + # Bridge the WMMA D-layout softmax fragment into the A-layout + # expected by GEMM 2 on RDNA GPUs. + if IsRDNA(): + P_shared = T.alloc_shared([block_M, block_N], dtype) acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -188,7 +215,12 @@ def main( for i in T.Parallel(block_M): l_i[i] += row_sum[i] - T.copy(acc_s, acc_s_cast) + if IsRDNA(): + for i, j in T.Parallel(block_M, block_N): + P_shared[i, j] = T.cast(acc_s[i, j], dtype) + T.copy(P_shared, acc_s_cast) + else: + T.copy(acc_s, acc_s_cast) T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow) @@ -211,15 +243,27 @@ def main( def get_bwd_configs(): - block_M = [16, 32, 64, 128, 256] - block_N = [16, 32, 64, 128, 256] - threads = [64, 128, 256, 512, 1024] - num_stages = [0, 1, 2] + # Keep the RDNA search space aligned with the WMMA-friendly tile sizes + # verified above. Larger tiles and some warp/block combinations are either + # unsupported or known to trigger invalid lowering on RDNA. + if IsRDNA(): + block_M = [16, 32] + block_N = [16, 32] + threads = [32, 64] + num_stages = [0] + panel_size = [7, 8] + else: + block_M = [16, 32, 64, 128, 256] + block_N = [16, 32, 64, 128, 256] + threads = [64, 128, 256, 512, 1024] + num_stages = [0, 1, 2] + panel_size = [7, 8, 9, 10] enable_rasterization = [True] - panel_size = [7, 8, 9, 10] configs = [] for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size): + if IsRDNA() and m == 16 and n == 16 and t == 64: + continue configs.append( { "block_M": m, @@ -305,6 +349,10 @@ def flash_bwd_kernel( lse_shared = T.alloc_shared([block_N], accum_dtype) delta_shared = T.alloc_shared([block_N], accum_dtype) ds_shared = T.alloc_shared([block_M, block_N], dtype) + if IsRDNA(): + # Bridge the WMMA D-layout fragment produced by GEMM/elementwise + # ops into the A-layout expected by the following GEMM. + p_shared = T.alloc_shared([block_M, block_N], dtype) p_cast = T.alloc_fragment([block_M, block_N], dtype) qkT = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -343,17 +391,29 @@ def flash_bwd_kernel( T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(P_acc, p_cast) + if IsRDNA(): + for i, j in T.Parallel(block_M, block_N): + p_shared[i, j] = T.cast(P_acc[i, j], dtype) + T.copy(p_shared, p_cast) + else: + T.copy(P_acc, p_cast) T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared) - for i, j in T.Parallel(block_M, block_N): - p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale - - T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow) - + if IsRDNA(): + for i, j in T.Parallel(block_M, block_N): + dP[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale + for i, j in T.Parallel(block_M, block_N): + p_shared[i, j] = T.cast(dP[i, j], dtype) + T.copy(p_shared, p_cast) + T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow) + else: + for i, j in T.Parallel(block_M, block_N): + p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale + T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow) T.copy(p_cast, ds_shared) + T.clear(dq) T.gemm(ds_shared, K_shared, dq, transpose_A=True) for i, j in T.Parallel(block_N, dim): diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index f6503a8348..dd3dd0aeac 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -17,6 +17,7 @@ #include "../op/builtin.h" #include "target/source/ptx.h" +#include "utils.h" namespace tvm { namespace codegen { @@ -1192,32 +1193,54 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { std::string c_ref = this->PrintExpr(op->args[10]); std::string c_bias = this->PrintExpr(op->args[11]); + // Get RDNA Generation + ICHECK(target_.defined()) << "CodeGenTileLangHIP target is not set"; + int rdna_gen = tvm::tl::TargetGetRDNAGeneration(target_); + ICHECK(rdna_gen == 11 || rdna_gen == 12) + << "Unsupported RDNA target for WMMA: gfx" << target_->str(); + // Determine wmma builtin name from shape // shape = "f32_16x16x16_f16_w32" -> // "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12" For gfx12 targets use - // the _gfx12 suffix variant. - std::string wmma_builtin = "__builtin_amdgcn_wmma_" + shape + "_gfx12"; + // the _gfx12 suffix variant, which gfx11 targets don't have. + std::string wmma_builtin = "__builtin_amdgcn_wmma_" + shape; + + int ab_half_elems = 16; + std::string ab_vec_typedef = "tl_v16f16"; + if (rdna_gen == 12) { + wmma_builtin += "_gfx12"; + ab_half_elems = 8; + ab_vec_typedef = "tl_v8f16"; + } // Emit the WMMA call. + // For gfx12: // Signature: v8f32 = wmma_builtin(v8f16 a, v8f16 b, v8f32 c) // where v8f16 = __fp16 x 8, v8f32 = float x 8. + // For gfx11: + // Signature: v8f32 = wmma_builtin(v16f16 a, v16f16 b, v8f32 c) + // where v16f16 = __fp16 x 16, v8f32 = float x 8. + // // A/B buffers hold half_t (fp16), C/D buffers hold float. - // Each element index accesses a packed vector of 8 elements. + // Each element index accesses a packed vector of 8/16 elements. // // Using typedef'd vector types for the cast: - // typedef __attribute__((__vector_size__(8 * sizeof(__fp16)))) __fp16 - // tl_v8f16; typedef __attribute__((__vector_size__(8 * sizeof(float)))) - // float tl_v8f32; + // typedef __attribute__((__vector_size__(8/16 * sizeof(__fp16)))) __fp16 + // tl_v8/16f16; typedef __attribute__((__vector_size__(8 * + // sizeof(float)))) float tl_v8f32; std::string call_wmma_code = R"({ - typedef __attribute__((__vector_size__(8 * sizeof(__fp16)))) __fp16 tl_v8f16; + typedef __attribute__((__vector_size__({ab_half_elems} * sizeof(__fp16)))) __fp16 {ab_vec_typedef}; typedef __attribute__((__vector_size__(8 * sizeof(float)))) float tl_v8f32; *((tl_v8f32*){c_ref} + {c_bias}) = {wmma_builtin}( - *((tl_v8f16*){a_ref} + {a_bias}), - *((tl_v8f16*){b_ref} + {b_bias}), + *(({ab_vec_typedef}*){a_ref} + {a_bias}), + *(({ab_vec_typedef}*){b_ref} + {b_bias}), *((tl_v8f32*){c_ref} + {c_bias})); })"; Replacer wmma_replacer; wmma_replacer.register_rule("{wmma_builtin}", wmma_builtin); + wmma_replacer.register_rule("{ab_half_elems}", + std::to_string(ab_half_elems)); + wmma_replacer.register_rule("{ab_vec_typedef}", ab_vec_typedef); wmma_replacer.register_rule("{a_ref}", a_ref); wmma_replacer.register_rule("{a_bias}", a_bias); wmma_replacer.register_rule("{b_ref}", b_ref); diff --git a/src/target/codegen_hip.h b/src/target/codegen_hip.h index 631050feb6..0dfef6d609 100644 --- a/src/target/codegen_hip.h +++ b/src/target/codegen_hip.h @@ -6,6 +6,7 @@ #define TVM_TL_TARGET_CODEGEN_HIP_H_ #include +#include #include #include @@ -21,6 +22,7 @@ class CodeGenTileLangHIP final : public CodeGenC { public: CodeGenTileLangHIP(); std::string Finish(); + void SetTarget(Target target) { target_ = std::move(target); } // override behavior void PrintFuncPrefix(std::ostream &os) final; void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final; @@ -88,6 +90,8 @@ class CodeGenTileLangHIP final : public CodeGenC { // The alignment of the barrier array in shared memory // Set to 16 to maintain minimum alignment requirements for async bulk copy const int barrier_alignment_bytes_ = 16; + // Target + Target target_; }; } // namespace codegen diff --git a/src/target/rt_mod_hip.cc b/src/target/rt_mod_hip.cc index 1e5c689c6e..63d7cea5b1 100644 --- a/src/target/rt_mod_hip.cc +++ b/src/target/rt_mod_hip.cc @@ -57,6 +57,7 @@ ffi::Module BuildTileLangHIP(IRModule mod, Target target) { bool output_ssa = false; CodeGenTileLangHIP cg; cg.Init(output_ssa); + cg.SetTarget(target); for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) @@ -93,6 +94,7 @@ ffi::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) { bool output_ssa = false; CodeGenTileLangHIP cg; cg.Init(output_ssa); + cg.SetTarget(target); for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) diff --git a/src/target/utils.cc b/src/target/utils.cc index 5fc32aa479..36a845c477 100644 --- a/src/target/utils.cc +++ b/src/target/utils.cc @@ -288,6 +288,19 @@ bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) { return false; } +int TargetGetRDNAGeneration(Target target) { + if (!TargetIsRDNA(target)) + return 0; + if (target->attrs.count("mcpu")) { + std::string mcpu = Downcast(target->attrs.at("mcpu")); + if (mcpu.rfind("gfx11", 0) == 0) + return 11; + if (mcpu.rfind("gfx12", 0) == 0) + return 12; + } + return 0; +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() @@ -321,6 +334,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](Target target) { return TargetHasStmatrix(target); }) .def("tl.TargetHasBulkCopy", [](Target target) { return TargetHasBulkCopy(target); }) + .def("tl.TargetGetRDNAGeneration", + [](Target target) { return TargetGetRDNAGeneration(target); }) .def("tl.TargetGetWarpSize", [](Target target) { return TargetGetWarpSize(target); }); } diff --git a/src/target/utils.h b/src/target/utils.h index d1741df463..424a5b7463 100644 --- a/src/target/utils.h +++ b/src/target/utils.h @@ -40,6 +40,7 @@ bool TargetHasSMVersionGE(Target target, int version); bool IsCudaVectorizableFP8(DataType dtype); bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty); +int TargetGetRDNAGeneration(Target target); } // namespace tl } // namespace tvm diff --git a/tilelang/intrinsics/wmma_layout.py b/tilelang/intrinsics/wmma_layout.py index fd6f7c5c33..65143332ac 100644 --- a/tilelang/intrinsics/wmma_layout.py +++ b/tilelang/intrinsics/wmma_layout.py @@ -22,95 +22,400 @@ Reverse: (thread, local) -> (M=(thread//16)*8+local, N=thread%16) Store: D[M=(t//16)*8+l][N=t%16] = d_vec[l] -NOTE: A and D have DIFFERENT layouts (A uses t%16 for M, D uses (t//16)*8+l for M). -This means they cannot be used interchangeably without a layout change. +EMPIRICALLY VERIFIED hardware layouts for wmma_f32_16x16x16_f16_w32 (gfx11): -local_size = 8 per thread + A[M=16][K=16]: + thread t, elem e -> A[M=t%16][K=e] + Forward: (M, K) -> (thread=M, local=K%16) [Mapping to tid=0~15] + Reverse: (thread, local) -> (M=thread%16, K=local) + Memory load: A[M=t%16][K=0..+15] -> CONTIGUOUS in K (vectorized) + + B[K=16][N=16] (non-transposed, K x N storage): + thread t, elem e -> B[K=e][N=t%16] + Forward: (K, N) -> (thread=N, local=K%16) [Mapping to tid=0~15] + Reverse: (thread, local) -> (K=local, N=thread%16) + + B_T[N=16][K=16] (transposed storage of B): + B_T[N=t%16][K=e] -> CONTIGUOUS in K (vectorized) + + D[M=16][N=16]: + thread t, elem l -> D[M=(t//16)+l*2][N=t%16] + Forward: (M, N) -> (thread=(M%2)*16+N, local=M//2) + Reverse: (thread, local) -> (M=(thread//16)+local*2, N=thread%16) + Store: D[M=(t//16)+l*2][N=t%16] = d_vec[l] + +NOTE: +1. A and D have DIFFERENT layouts (e.g. For gfx12, A uses t%16 for M, + D uses (t//16)*8+l for M). This means they cannot be used interchangeably + without a layout change. +2. For gfx11, lane 16~31 share the same A/B data as lane 0~15. + +local_size = 8 (gfx12) | 16 (gfx11) """ from tvm.runtime import convert # ────────────────────────────────────────────────────────────────────────────── -# A matrix: shared[M=16][K=16] +# gfx12 A matrix: shared[M=16][K=16] # A[M=t%16][K=(t//16)*8+l] -> vectorized load from row M=t%16, consecutive K # ────────────────────────────────────────────────────────────────────────────── -def shared_16x16_to_local_32x8_layout_A(i, j): +def shared_16x16_to_local_32x8_layout_A_gfx12(i, j): """Forward: A[i=M, j=K] -> (thread=(j//8)*16+i, local=j%8).""" thread_id = (j // 8) * 16 + i # (K//8)*16 + M local_id = j % 8 # K%8 return thread_id, local_id -def thread_id_shared_access_32x8_to_16x16_layout_A(thread_id, local_id): +def thread_id_shared_access_32x8_to_16x16_layout_A_gfx12(thread_id, local_id): """Reverse: (thread, local) -> (i=M=thread%16, j=K=(thread//16)*8+local).""" return thread_id % 16, (thread_id // 16) * 8 + local_id # ────────────────────────────────────────────────────────────────────────────── -# B matrix (non-transposed, K x N): shared[K=16][N=16] +# gfx12 A_T matrix (transposed storage, K x M): shared[K=16][M=16] +# A_T[K=(t//16)*8+l][M=t%16] +# ────────────────────────────────────────────────────────────────────────────── + + +def shared_16x16_to_local_32x8_layout_A_colmajor_gfx12(i, j): + """Forward: A_T[i=K, j=M] -> (thread=(i//8)*16+j, local=i%8).""" + thread_id = (i // 8) * 16 + j # (K//8)*16 + M + local_id = i % 8 # K%8 + return thread_id, local_id + + +def thread_id_shared_access_32x8_to_16x16_layout_A_colmajor_gfx12(thread_id, local_id): + """Reverse: (thread, local) -> (i=K=(thread//16)*8+local, j=M=thread%16).""" + return (thread_id // 16) * 8 + local_id, thread_id % 16 + + +# ────────────────────────────────────────────────────────────────────────────── +# gfx12 B matrix (non-transposed, K x N): shared[K=16][N=16] # B[K=(t//16)*8+l][N=t%16] # ────────────────────────────────────────────────────────────────────────────── -def shared_16x16_to_local_32x8_layout_B(i, j): +def shared_16x16_to_local_32x8_layout_B_gfx12(i, j): """Forward: B[i=K, j=N] -> (thread=(i//8)*16+j, local=i%8).""" thread_id = (i // 8) * 16 + j # (K//8)*16 + N local_id = i % 8 # K%8 return thread_id, local_id -def thread_id_shared_access_32x8_to_16x16_layout_B(thread_id, local_id): +def thread_id_shared_access_32x8_to_16x16_layout_B_gfx12(thread_id, local_id): """Reverse: (thread, local) -> (i=K=(thread//16)*8+local, j=N=thread%16).""" return (thread_id // 16) * 8 + local_id, thread_id % 16 # ────────────────────────────────────────────────────────────────────────────── -# B_T matrix (transposed storage, N x K): shared[N=16][K=16] +# gfx12 B_T matrix (transposed storage, N x K): shared[N=16][K=16] # B_T[N=t%16][K=(t//16)*8+l] -> vectorized load from row N=t%16, consecutive K # ────────────────────────────────────────────────────────────────────────────── -def shared_16x16_to_local_32x8_layout_B_colmajor(i, j): +def shared_16x16_to_local_32x8_layout_B_colmajor_gfx12(i, j): """Forward: B_T[i=N, j=K] -> (thread=(j//8)*16+i, local=j%8).""" thread_id = (j // 8) * 16 + i # (K//8)*16 + N local_id = j % 8 # K%8 return thread_id, local_id -def thread_id_shared_access_32x8_to_16x16_layout_B_colmajor(thread_id, local_id): +def thread_id_shared_access_32x8_to_16x16_layout_B_colmajor_gfx12(thread_id, local_id): """Reverse: (thread, local) -> (i=N=thread%16, j=K=(thread//16)*8+local).""" return thread_id % 16, (thread_id // 16) * 8 + local_id # ────────────────────────────────────────────────────────────────────────────── -# D/C output matrix: shared[M=16][N=16] fp32 +# gfx12 D/C output matrix: shared[M=16][N=16] fp32 # D[M=(t//16)*8+l][N=t%16] -- hardware native # ────────────────────────────────────────────────────────────────────────────── -def shared_16x16_to_local_32x8_layout_C(i, j): +def shared_16x16_to_local_32x8_layout_C_gfx12(i, j): """Forward: D[i=M, j=N] -> (thread=(i//8)*16+j, local=i%8).""" thread_id = (i // 8) * 16 + j # (M//8)*16 + N local_id = i % 8 # M%8 return thread_id, local_id -def thread_id_shared_access_32x8_to_16x16_layout_C(thread_id, local_id): +def thread_id_shared_access_32x8_to_16x16_layout_C_gfx12(thread_id, local_id): """Reverse: (thread, local) -> (i=M=(thread//16)*8+local, j=N=thread%16).""" return (thread_id // 16) * 8 + local_id, thread_id % 16 # ────────────────────────────────────────────────────────────────────────────── -# Store index map: (thread, local) -> (M, N) in D (hardware D layout) +# gfx12 store index map: (thread, local) -> (M, N) in D (hardware D layout) # D[M=(t//16)*8+local][N=t%16] -- affine, invertible # ────────────────────────────────────────────────────────────────────────────── -def wmma_store_index_map(thread_id, local_id): +def wmma_store_index_map_gfx12(thread_id, local_id): """(thread, local) -> (M, N) in D. Hardware D layout.""" i = (thread_id // 16) * 8 + local_id # M j = thread_id % 16 # N return convert([i, j]) + + +# ────────────────────────────────────────────────────────────────────────────── +# gfx11 A matrix: shared[M=16][K=16] +# A[M=t%16][K=l] -> vectorized load from row M=t%16, consecutive K +# ────────────────────────────────────────────────────────────────────────────── + + +def shared_16x16_to_local_32x16_layout_A_gfx11(i, j): + """ + Forward: A[i=M, j=K] -> (thread=i, local=j%16). + ATTN: Here we only reflect (i, j) to the lower-half-lane of threads in + a warp. + """ + thread_id = i + local_id = j % 16 + return thread_id, local_id + + +def thread_id_shared_access_32x16_to_16x16_layout_A_gfx11(thread_id, local_id): + """Reverse: (thread, local) -> (i=M=thread%16, j=K=local)""" + return thread_id % 16, local_id + + +# ────────────────────────────────────────────────────────────────────────────── +# gfx11 A_T matrix (transposed storage, K x M): shared[K=16][M=16] +# A_T[K=l][M=t%16] +# ────────────────────────────────────────────────────────────────────────────── + + +def shared_16x16_to_local_32x16_layout_A_colmajor_gfx11(i, j): + """ + Forward: A_T[i=K, j=M] -> (thread=M, local=K%16). + ATTN: Here we only reflect (i, j) to the lower-half-lane of threads in + a warp. + """ + thread_id = j + local_id = i % 16 + return thread_id, local_id + + +def thread_id_shared_access_32x16_to_16x16_layout_A_colmajor_gfx11(thread_id, local_id): + """Reverse: (thread, local) -> (i=K=local, j=M=thread%16)""" + return local_id, thread_id % 16 + + +# ────────────────────────────────────────────────────────────────────────────── +# gfx11 B matrix (non-transposed, K x N): shared[K=16][N=16] +# B[K=l][N=t%16] +# ────────────────────────────────────────────────────────────────────────────── + + +def shared_16x16_to_local_32x16_layout_B_gfx11(i, j): + """ + Forward: B[i=K, j=N] -> (thread=N, local=K%16). + ATTN: Here we only reflect (i, j) to the lower-half-lane of threads in + a warp. + """ + thread_id = j + local_id = i % 16 + return thread_id, local_id + + +def thread_id_shared_access_32x16_to_16x16_layout_B_gfx11(thread_id, local_id): + """Reverse: (thread, local) -> (i=K=local, j=N=thread%16)""" + return local_id, thread_id % 16 + + +# ────────────────────────────────────────────────────────────────────────────── +# gfx11 B_T matrix (transposed storage, N x K): shared[N=16][K=16] +# B_T[N=t%16][K=l] -> vectorized load from row N=t%16, consecutive K +# ────────────────────────────────────────────────────────────────────────────── + + +def shared_16x16_to_local_32x16_layout_B_colmajor_gfx11(i, j): + """ + Forward: B_T[i=N, j=K] -> (thread=i, local=j%16). + ATTN: Here we only reflect (i, j) to the lower-half-lane of threads in + a warp. + """ + thread_id = i + local_id = j % 16 + return thread_id, local_id + + +def thread_id_shared_access_32x16_to_16x16_layout_B_colmajor_gfx11(thread_id, local_id): + """Reverse: (thread, local) -> (j=K=local, i=N=thread%16)""" + return thread_id % 16, local_id + + +# ────────────────────────────────────────────────────────────────────────────── +# gfx11 D/C output matrix: shared[M=16][N=16] fp32 +# D[M=(t//16)+l*2][N=t%16] -- hardware native +# ────────────────────────────────────────────────────────────────────────────── + + +def shared_16x16_to_local_32x8_layout_C_gfx11(i, j): + """Forward: D[i=M, j=N] -> (thread=(i%2)*16+j, local=i//2).""" + thread_id = (i % 2) * 16 + j + local_id = i // 2 + return thread_id, local_id + + +def thread_id_shared_access_32x8_to_16x16_layout_C_gfx11(thread_id, local_id): + """Reverse: (thread, local) -> (i=M=(thread//16)+local*2, j=N=thread%16)""" + return (thread_id // 16) + local_id * 2, thread_id % 16 + + +# ────────────────────────────────────────────────────────────────────────────── +# gfx11 store index map: (thread, local) -> (M, N) in D (hardware D layout) +# D[M=(t//16)+local*2][N=t%16] -- affine, invertible +# ────────────────────────────────────────────────────────────────────────────── + + +def wmma_store_index_map_gfx11(thread_id, local_id): + """(thread, local) -> (M, N) in D. Hardware D layout.""" + i = (thread_id // 16) + local_id * 2 + j = thread_id % 16 + return convert([i, j]) + + +# ────────────────────────────────────────────────────────────────────────────── +# gfx11 fragment-forward helpers for duplicated half-wave ownership +# ────────────────────────────────────────────────────────────────────────────── + + +def fragment_forward_A_gfx11(i, j, rep): + """Replicated fragment forward map for gfx11 A. + + The canonical owner lives in the lower half-wave and `rep` selects whether + the logical element is materialized in the lower or upper half-wave copy. + """ + thread_id, local_id = shared_16x16_to_local_32x16_layout_A_gfx11(i, j) + return thread_id + 16 * rep, local_id + + +def fragment_forward_A_colmajor_gfx11(i, j, rep): + """Replicated fragment forward map for gfx11 transposed A.""" + thread_id, local_id = shared_16x16_to_local_32x16_layout_A_colmajor_gfx11(i, j) + return thread_id + 16 * rep, local_id + + +def fragment_forward_B_gfx11(i, j, rep): + """Replicated fragment forward map for gfx11 B.""" + thread_id, local_id = shared_16x16_to_local_32x16_layout_B_gfx11(i, j) + return thread_id + 16 * rep, local_id + + +def fragment_forward_B_colmajor_gfx11(i, j, rep): + """Replicated fragment forward map for gfx11 transposed B.""" + thread_id, local_id = shared_16x16_to_local_32x16_layout_B_colmajor_gfx11(i, j) + return thread_id + 16 * rep, local_id + + +# ────────────────────────────────────────────────────────────────────────────── +# Factory helpers +# ────────────────────────────────────────────────────────────────────────────── + + +def _unsupported_rdna_generation(rdna_gen: int): + raise ValueError(f"Unsupported RDNA generation for WMMA layout: {rdna_gen}") + + +def get_wmma_a_layout_funcs(rdna_gen: int, transposed: bool): + """Return (forward_map, reverse_map) for A layout.""" + if rdna_gen == 11: + if transposed: + return ( + shared_16x16_to_local_32x16_layout_A_colmajor_gfx11, + thread_id_shared_access_32x16_to_16x16_layout_A_colmajor_gfx11, + ) + return ( + shared_16x16_to_local_32x16_layout_A_gfx11, + thread_id_shared_access_32x16_to_16x16_layout_A_gfx11, + ) + if rdna_gen == 12: + if transposed: + return ( + shared_16x16_to_local_32x8_layout_A_colmajor_gfx12, + thread_id_shared_access_32x8_to_16x16_layout_A_colmajor_gfx12, + ) + return ( + shared_16x16_to_local_32x8_layout_A_gfx12, + thread_id_shared_access_32x8_to_16x16_layout_A_gfx12, + ) + _unsupported_rdna_generation(rdna_gen) + + +def get_wmma_b_layout_funcs(rdna_gen: int, transposed: bool): + """Return (forward_map, reverse_map) for B layout.""" + if rdna_gen == 11: + if transposed: + return ( + shared_16x16_to_local_32x16_layout_B_colmajor_gfx11, + thread_id_shared_access_32x16_to_16x16_layout_B_colmajor_gfx11, + ) + return ( + shared_16x16_to_local_32x16_layout_B_gfx11, + thread_id_shared_access_32x16_to_16x16_layout_B_gfx11, + ) + if rdna_gen == 12: + if transposed: + return ( + shared_16x16_to_local_32x8_layout_B_colmajor_gfx12, + thread_id_shared_access_32x8_to_16x16_layout_B_colmajor_gfx12, + ) + return ( + shared_16x16_to_local_32x8_layout_B_gfx12, + thread_id_shared_access_32x8_to_16x16_layout_B_gfx12, + ) + _unsupported_rdna_generation(rdna_gen) + + +def get_wmma_c_layout_funcs(rdna_gen: int): + """Return (forward_map, reverse_map) for C/D layout.""" + if rdna_gen == 11: + return ( + shared_16x16_to_local_32x8_layout_C_gfx11, + thread_id_shared_access_32x8_to_16x16_layout_C_gfx11, + ) + if rdna_gen == 12: + return ( + shared_16x16_to_local_32x8_layout_C_gfx12, + thread_id_shared_access_32x8_to_16x16_layout_C_gfx12, + ) + _unsupported_rdna_generation(rdna_gen) + + +def get_wmma_store_index_map_func(rdna_gen: int): + """Return the (thread_id, local_id) -> (row, col) store map.""" + if rdna_gen == 11: + return wmma_store_index_map_gfx11 + if rdna_gen == 12: + return wmma_store_index_map_gfx12 + _unsupported_rdna_generation(rdna_gen) + + +def get_wmma_a_fragment_forward_func(rdna_gen: int, transposed: bool): + """Return the fragment forward function for A layout.""" + if rdna_gen == 11: + return fragment_forward_A_colmajor_gfx11 if transposed else fragment_forward_A_gfx11 + if rdna_gen == 12: + return None + _unsupported_rdna_generation(rdna_gen) + + +def get_wmma_b_fragment_forward_func(rdna_gen: int, transposed: bool): + """Return the fragment forward function for B layout.""" + if rdna_gen == 11: + return fragment_forward_B_colmajor_gfx11 if transposed else fragment_forward_B_gfx11 + if rdna_gen == 12: + return None + _unsupported_rdna_generation(rdna_gen) + + +def get_wmma_fragment_replicate_count(rdna_gen: int): + """Return the fragment replicate count used for logical one-to-many owners.""" + if rdna_gen == 11: + return 2 + if rdna_gen == 12: + return 1 + _unsupported_rdna_generation(rdna_gen) diff --git a/tilelang/intrinsics/wmma_macro_generator.py b/tilelang/intrinsics/wmma_macro_generator.py index 840788b47d..811d3f1719 100644 --- a/tilelang/intrinsics/wmma_macro_generator.py +++ b/tilelang/intrinsics/wmma_macro_generator.py @@ -3,9 +3,13 @@ Only supports the f16->f32, 16x16x16 variant with warp-size=32. Thread-data mapping (per AMDGPU ISA): - A[16][K=16]: thread t holds A[t//2][(t%2)*8 : (t%2)*8+8] (8 fp16 = 4 f32 per thread) - B[K=16][16]: same mapping as A for the transposed dimension - C/D[16][16]: thread t holds D[t//2][(t%2)*8 : (t%2)*8+8] (8 f32 per thread) + gfx11: + - A/B: duplicated across the two half-waves, so each logical input fragment + is distributed over an effective wave size of 16 lanes. + - C/D: distributed over the full wave32 output layout. + gfx12: + - A/B: distributed over the full wave32 input layout. + - C/D: distributed over the full wave32 output layout. """ from __future__ import annotations @@ -13,6 +17,7 @@ from typing import Literal import tilelang.language as T +from tilelang import _ffi_api from tilelang import tvm as tvm from tvm import tir from tvm.ir import Range @@ -23,23 +28,23 @@ from tilelang.language.utils import get_buffer_region_from_load from tilelang.utils import is_fragment from .wmma_layout import ( - shared_16x16_to_local_32x8_layout_A, - shared_16x16_to_local_32x8_layout_B, - shared_16x16_to_local_32x8_layout_B_colmajor, - thread_id_shared_access_32x8_to_16x16_layout_A, - thread_id_shared_access_32x8_to_16x16_layout_B, - thread_id_shared_access_32x8_to_16x16_layout_B_colmajor, - wmma_store_index_map, + get_wmma_a_layout_funcs, + get_wmma_a_fragment_forward_func, + get_wmma_b_layout_funcs, + get_wmma_b_fragment_forward_func, + get_wmma_c_layout_funcs, + get_wmma_fragment_replicate_count, + get_wmma_store_index_map_func, ) lift = convert class WMMAIntrinEmitter: - """Intrinsic emitter for AMD RDNA WMMA (16×16×16, warp-size=32). + """Intrinsic emitter for AMD RDNA WMMA (16x16x16, warp-size=32). Supports: - - fp16 -> fp32 (f32_16x16x16_f16_w32 / _gfx12) + - fp16 -> fp32 (f32_16x16x16_f16_w32, with `_gfx12` codegen suffix on gfx12) """ M_DIM = 16 @@ -65,6 +70,7 @@ def __init__( ): assert a_dtype in ("float16", "bfloat16"), f"Unsupported a_dtype: {a_dtype}" assert accum_dtype == "float32", f"Unsupported accum_dtype: {accum_dtype}" + assert target is not None, "WMMAIntrinEmitter requires a HIP target to select WMMA layouts." self.a_dtype = a_dtype self.b_dtype = b_dtype @@ -79,20 +85,34 @@ def __init__( self.k_pack = k_pack self.thread_var = thread_var self.target = target + self.rdna_gen = _ffi_api.TargetGetRDNAGeneration(target) + if self.rdna_gen == 0: + raise ValueError(f"Invalid RDNA target for WMMA: {target}") self.micro_size_x = self.M_DIM self.micro_size_y = self.N_DIM self.micro_size_k = self.K_DIM - # Each thread holds 8 fp16 (A/B) or 8 fp32 (C/D) - self.local_size_a = (self.M_DIM * self.K_DIM) // self.WARP_SIZE # = 8 - self.local_size_b = (self.N_DIM * self.K_DIM) // self.WARP_SIZE # = 8 - self.local_size_out = (self.M_DIM * self.N_DIM) // self.WARP_SIZE # = 8 + # gfx11 duplicates A/B across half-waves, so the effective input fragment + # distribution uses 16 lanes instead of the full wave32 used by gfx12. + input_fragment_warp_size = (self.WARP_SIZE // 2) if self.rdna_gen == 11 else self.WARP_SIZE + self.local_size_a = (self.M_DIM * self.K_DIM) // input_fragment_warp_size + self.local_size_b = (self.N_DIM * self.K_DIM) // input_fragment_warp_size + # C/D outputs are distributed over the full wave32 layout on both gfx11 and gfx12. + self.local_size_out = (self.M_DIM * self.N_DIM) // self.WARP_SIZE self.warp_rows = warp_row_tiles // self.M_DIM self.warp_cols = warp_col_tiles // self.N_DIM self.threads = self.WARP_SIZE * block_row_warps * block_col_warps + self.a_forward_layout_fn, self.a_reverse_layout_fn = get_wmma_a_layout_funcs(self.rdna_gen, self.a_transposed) + self.a_fragment_forward_fn = get_wmma_a_fragment_forward_func(self.rdna_gen, self.a_transposed) + self.b_forward_layout_fn, self.b_reverse_layout_fn = get_wmma_b_layout_funcs(self.rdna_gen, self.b_transposed) + self.b_fragment_forward_fn = get_wmma_b_fragment_forward_func(self.rdna_gen, self.b_transposed) + self.c_forward_layout_fn, self.c_reverse_layout_fn = get_wmma_c_layout_funcs(self.rdna_gen) + self.fragment_replicate = get_wmma_fragment_replicate_count(self.rdna_gen) + self.store_index_map_fn = get_wmma_store_index_map_func(self.rdna_gen) + # Build the wmma shape string used by T.tvm_rdna_wmma # shape = "f32_16x16x16_f16_w32" (or _gfx12 suffix is handled in codegen) dtype_in_abbrv = {"float16": "f16", "bfloat16": "bf16"}[a_dtype] @@ -126,37 +146,13 @@ def extract_thread_binding(self, thread_id): def get_ldmatrix_index_map(self, is_b: bool = False): """Return (forward, reverse) index maps for shared→local loading. - For WMMA gfx12: - - A is stored row-major [M, K]. Thread t loads A[t%16][(t//16)*8+local]. - - B (non-transposed) is stored row-major [K, N]. - Thread t loads B[t%16][(t//16)*8+local] (same shape/pattern as A). - - B (transposed) is stored [N, K]. - Thread t loads B_T[t%16][(t//16)*8+local] (N-row, K-col). + The actual layout functions are chosen during __init__ based on rdna_gen: + - gfx11 uses half-wave duplicated A/B input layouts (32x16 naming). + - gfx12 uses full wave32 A/B input layouts (32x8 naming). """ - transposed = self.b_transposed if is_b else self.a_transposed if not is_b: - # A matrix [M, K] - if transposed: - # A stored as [K, M]: row=K, col=M → same mapping but rotated - # In this case row index runs over K, col over M - # thread t: K-row = t%16, M-col = (t//16)*8+local - index_map = shared_16x16_to_local_32x8_layout_A - reverse_index_map = thread_id_shared_access_32x8_to_16x16_layout_A - else: - # A stored as [M, K]: row=M, col=K - index_map = shared_16x16_to_local_32x8_layout_A - reverse_index_map = thread_id_shared_access_32x8_to_16x16_layout_A - else: - # B matrix - if transposed: - # B stored as [N, K]: thread t: N-row = t%16, K-col = (t//16)*8+local - index_map = shared_16x16_to_local_32x8_layout_B_colmajor - reverse_index_map = thread_id_shared_access_32x8_to_16x16_layout_B_colmajor - else: - # B stored as [K, N]: thread t: K-row = t%16, N-col = (t//16)*8+local - index_map = shared_16x16_to_local_32x8_layout_B - reverse_index_map = thread_id_shared_access_32x8_to_16x16_layout_B - return index_map, reverse_index_map + return self.a_forward_layout_fn, self.a_reverse_layout_fn + return self.b_forward_layout_fn, self.b_reverse_layout_fn def get_store_index_map(self, inverse: bool = False) -> IndexMap: """Return the store index map. @@ -166,7 +162,7 @@ def get_store_index_map(self, inverse: bool = False) -> IndexMap: """ warp_size, local_size_c = self.WARP_SIZE, self.local_size_out # forward: (thread_id, local_id) -> (row, col) - index_map = IndexMap.from_func(wmma_store_index_map, index_dtype=T.int32) + index_map = IndexMap.from_func(self.store_index_map_fn, index_dtype=T.int32) if not inverse: return index_map # inverse: (row, col) -> (thread_id, local_id) @@ -188,10 +184,13 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0): thread_binding = self.get_thread_binding() _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) + # legalize shared buffer to region A_region = self._legalize_to_buffer_region(A_shared_buf) A_buf = A_region.buffer A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min + # Leading dimensions (e.g. pipeline stage axis) – empty for 2-D buffers + A_other = [r.min for r in A_region.region[:-2]] @T.macro def _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk=0): @@ -201,13 +200,13 @@ def _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk=0): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[tuple(A_other) + (A_base0 + l + row, A_base1 + r + col)] else: for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[tuple(A_other) + (A_base0 + l + row, A_base1 + r + col)] return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -227,10 +226,13 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0): thread_binding = self.get_thread_binding() _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) + # legalize shared buffer to region B_region = self._legalize_to_buffer_region(B_shared_buf) B_buf = B_region.buffer B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min + # Leading dimensions (e.g. pipeline stage axis) – empty for 2-D buffers + B_other = [r.min for r in B_region.region[:-2]] @T.macro def _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk=0): @@ -240,13 +242,13 @@ def _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk=0): for local_id in T.vectorized(k_pack * local_size_b): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k)) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[tuple(B_other) + (B_base0 + l + row, B_base1 + r + col)] else: for j in T.serial(warp_cols): for local_id in T.vectorized(k_pack * local_size_b): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_n * warp_col_tiles + j * micro_size_y) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[tuple(B_other) + (B_base0 + l + row, B_base1 + r + col)] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) @@ -317,13 +319,14 @@ def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): M_DIM, N_DIM = self.M_DIM, self.N_DIM C_buf_dims = len(C_buf.shape) assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" + store_index_map = self.store_index_map_fn @T.macro def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding): tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) for i, j in T.grid(warp_rows, warp_cols): for local_id in T.vectorized(local_size_out): - row, col = T.meta_var(wmma_store_index_map(tx, local_id)) + row, col = T.meta_var(store_index_map(tx, local_id)) if C_buf_dims == 2: C_buf[ (warp_m * warp_rows + i) * M_DIM + row, @@ -339,7 +342,7 @@ def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) for i, j in T.grid(warp_rows, warp_cols): for local_id in T.vectorized(local_size_out): - row, col = T.meta_var(wmma_store_index_map(tx, local_id)) + row, col = T.meta_var(store_index_map(tx, local_id)) C_buf[ (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col, @@ -361,17 +364,6 @@ def make_wmma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = " matrix_is_a = matrix == "A" transposed = self.a_transposed if matrix_is_a else self.b_transposed - index_map, _ = self.get_ldmatrix_index_map(is_b=not matrix_is_a) - - inverse_load_layout = IndexMap.from_func(index_map, index_dtype=T.int32) - - def forward_thread(i, j): - lane_id, _ = inverse_load_layout.map_indices([i, j]) - return lane_id - - def forward_index(i, j): - _, local_id = inverse_load_layout.map_indices([i, j]) - return local_id micro_size_k = self.micro_size_k * self.k_pack if matrix_is_a: @@ -385,29 +377,61 @@ def forward_index(i, j): else: shape_atom = [micro_size_k, self.micro_size_y] - base_fragment = T.Fragment( - shape_atom, - forward_thread_fn=forward_thread, - forward_index_fn=forward_index, - ) + """ + gfx11 and gfx12 differ in how logical A/B fragment elements map to lanes. + + gfx11 duplicates each logical A/B element across the two half-waves + (lane t and lane t + 16). A single-owner forward_thread_fn cannot + faithfully represent this one-to-many ownership, so we model it with + T.Fragment(..., forward_fn=..., replicate=2), where `rep` selects the + lower/upper half-wave copy. + + gfx12 has a single unique owner for each logical element, so the + existing forward_thread_fn/forward_index_fn form is sufficient. + """ + if self.rdna_gen == 11: + fragment_forward = self.a_fragment_forward_fn if matrix_is_a else self.b_fragment_forward_fn + assert fragment_forward is not None + base_fragment = T.Fragment( + shape_atom, + forward_fn=fragment_forward, + replicate=self.fragment_replicate, + ) + else: + index_map, _ = self.get_ldmatrix_index_map(is_b=not matrix_is_a) + inverse_load_layout = IndexMap.from_func(index_map, index_dtype=T.int32) + + def forward_thread(i, j): + lane_id, _ = inverse_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i, j): + _, local_id = inverse_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + shape_atom, + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) warp_s = self.warp_rows if matrix_is_a else self.warp_cols warp_r = self.chunk // micro_size_k block_s = self.block_row_warps if matrix_is_a else self.block_col_warps - replicate = self.block_col_warps if matrix_is_a else self.block_row_warps + block_replicate = self.block_col_warps if matrix_is_a else self.block_row_warps if (matrix_is_a and not transposed) or (not matrix_is_a and transposed): warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) if matrix_is_a: - block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(block_replicate) else: - block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) + block_fragment = warp_fragment.replicate(block_replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) else: warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) if matrix_is_a: - block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(block_replicate) else: - block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) + block_fragment = warp_fragment.replicate(block_replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) return block_fragment From 3aeb962e9ec76ff6644cb304c9f6364f7c629a95 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Wed, 22 Apr 2026 17:52:21 +0800 Subject: [PATCH 114/156] fix let stmt clone bug --- src/transform/auto_schedule/schedule_builder.cc | 17 +++++++++++++++++ src/transform/auto_schedule/schedule_builder.h | 17 +++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 260108c9f3..e57797aef5 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -1114,6 +1114,23 @@ void ScheduleUnitBuilder::NaiveScheduleLoop(ControlNode *ctrl) { node_i_let_stmt->value, Evaluate(0)); auto cloned_task = std::make_shared(); cloned_task->stmts.push_back(cloned_let_stmt); + cloned_task->SetReadRegions(node_i_task->GetReadRegions()); + cloned_task->SetWriteRegions(node_i_task->GetWriteRegions()); + cloned_task->SetReadVars(node_i_task->GetReadVars()); + { + auto write_vars = node_i_task->GetWriteVars(); + for (auto &v : write_vars) { + if (v.same_as(node_i_let_stmt->var)) { + v = cloned_let_stmt->var; + } + } + cloned_task->SetWriteVars(write_vars); + } + cloned_task->SetLatency(node_i_task->GetLatency()); + cloned_task->SetII(node_i_task->GetII()); + cloned_task->SetUsesCUDACore(node_i_task->UsesCUDACore()); + cloned_task->SetUsesTMACore(node_i_task->UsesTMACore()); + cloned_task->SetUsesTensorCore(node_i_task->UsesTensorCore()); stage_map[cloned_task.get()] = rem_stage_j; for (int k = j; k < n; ++k) { diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index 1aef87d785..4d95344d82 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -581,6 +581,23 @@ class ScheduleUnitBuilder { Evaluate(0)); auto cloned_task = std::make_shared(); cloned_task->stmts.push_back(cloned_let_stmt); + cloned_task->SetReadRegions(node_i_task->GetReadRegions()); + cloned_task->SetWriteRegions(node_i_task->GetWriteRegions()); + cloned_task->SetReadVars(node_i_task->GetReadVars()); + { + auto write_vars = node_i_task->GetWriteVars(); + for (auto &v : write_vars) { + if (v.same_as(node_i_let_stmt->var)) { + v = cloned_let_stmt->var; + } + } + cloned_task->SetWriteVars(write_vars); + } + cloned_task->SetLatency(node_i_task->GetLatency()); + cloned_task->SetII(node_i_task->GetII()); + cloned_task->SetUsesCUDACore(node_i_task->UsesCUDACore()); + cloned_task->SetUsesTMACore(node_i_task->UsesTMACore()); + cloned_task->SetUsesTensorCore(node_i_task->UsesTensorCore()); stage_map[cloned_task.get()] = rem_stage_j; for (int k = j; k < n; ++k) { From 4eed399f0ba32f39ac0c813c7164b85ee43af3be Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 22 Apr 2026 17:56:41 +0800 Subject: [PATCH 115/156] add the innermost task to sync infos --- src/transform/auto_schedule/barrier.h | 138 ++++++++++++++------ src/transform/auto_schedule/ir_structure.cc | 6 +- src/transform/auto_schedule/ir_structure.h | 14 +- 3 files changed, 105 insertions(+), 53 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index a9f7d293ee..521c6f4a47 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -585,6 +585,37 @@ static TaskNode *GetInnerTask(ScheduleUnit *unit) { } } +struct SyncInfo { + int distance; // the distance of iterations + Buffer buffer; // the buffer that requires synchronization + const TaskNode *producer; // the innermost task that needs to be waited on + const TaskNode *consumer; // the innermost task that needs to wait + int buffer_versions; // the number of versions for the buffer (for calculating + // barrier slots) + + SyncInfo(int distance, Buffer buffer, const TaskNode *producer, + const TaskNode *consumer, int buffer_versions) + : distance(distance), buffer(buffer), producer(producer), + consumer(consumer), buffer_versions(buffer_versions) {} + + // Define operator< for set + bool operator<(const SyncInfo &other) const { + if (distance != other.distance) { + return distance < other.distance; + } + if (buffer->name != other.buffer->name) { + return buffer->name < other.buffer->name; + } + if (producer != other.producer) { + return producer < other.producer; + } + if (consumer != other.consumer) { + return consumer < other.consumer; + } + return buffer_versions < other.buffer_versions; + } +}; + static auto GetSyncInfos(const std::vector &units, int num_wgs, const std::unordered_map &units, int num_wgs, } } std::map, - std::pair, int>>> + std::map, std::set>> sync_infos; for (const auto &buffer : buffers) { int num_versions = 1; @@ -607,7 +638,9 @@ GetSyncInfos(const std::vector &units, int num_wgs, num_versions = it->second; } std::vector last_read_unit(num_wgs, nullptr); + std::vector> last_read_unit_tasks(num_wgs); ScheduleUnit *last_write_unit = nullptr; + std::set last_write_unit_tasks; int last_write_wg_id = -1; std::vector waited_write_wgs(num_wgs, false); for (int iter = 0; iter < (is_loop ? 2 : 1); ++iter) { @@ -618,16 +651,14 @@ GetSyncInfos(const std::vector &units, int num_wgs, ICHECK(0 <= wg_id && wg_id < num_wgs); if (buffer_access.buffer != buffer) continue; - auto add_sync = [&](ScheduleUnit *wait_unit, int wait_wg_id) { + auto add_sync = [&](ScheduleUnit *wait_unit, int wait_wg_id, + const std::set &wait_tasks) { int distance = iter ? num_versions : 0; - auto &[barrier_versions, wait_map] = - sync_infos[{wait_unit, wait_wg_id}]; - barrier_versions = std::max(barrier_versions, num_versions); + auto &wait_map = sync_infos[{wait_unit, wait_wg_id}]; auto it = wait_map.find({unit, wg_id}); - if (it == wait_map.end()) { - wait_map[{unit, wg_id}] = distance; - } else { - it->second = std::min(it->second, distance); + for (auto wait_task : wait_tasks) { + wait_map[{unit, wg_id}].emplace(distance, buffer, wait_task, + buffer_access.task, num_versions); } }; if (!buffer_access.is_write) { @@ -635,12 +666,13 @@ GetSyncInfos(const std::vector &units, int num_wgs, continue; if (waited_write_wgs[wg_id]) continue; - add_sync(last_write_unit, last_write_wg_id); + add_sync(last_write_unit, last_write_wg_id, last_write_unit_tasks); } else { for (int last_wg_id = 0; last_wg_id < num_wgs; ++last_wg_id) { if (last_read_unit[last_wg_id] == nullptr) continue; - add_sync(last_read_unit[last_wg_id], last_wg_id); + add_sync(last_read_unit[last_wg_id], last_wg_id, + last_read_unit_tasks[last_wg_id]); } } } @@ -652,10 +684,12 @@ GetSyncInfos(const std::vector &units, int num_wgs, continue; if (!buffer_access.is_write) { waited_write_wgs[wg_id] = true; + last_read_unit_tasks[wg_id].clear(); } else { for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { last_read_unit[wg_id] = nullptr; } + last_write_unit_tasks.clear(); } } for (const auto &buffer_access : @@ -665,8 +699,10 @@ GetSyncInfos(const std::vector &units, int num_wgs, continue; if (!buffer_access.is_write) { last_read_unit[wg_id] = unit; + last_read_unit_tasks[wg_id].insert(buffer_access.task); } else { last_write_unit = unit; + last_write_unit_tasks.insert(buffer_access.task); last_write_wg_id = wg_id; for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { waited_write_wgs[wg_id] = false; @@ -682,19 +718,15 @@ GetSyncInfos(const std::vector &units, int num_wgs, static void InsertSynchronization( const std::vector &units, - const std::map< - std::pair, - std::pair, int>>> + const std::map, + std::map, std::set>> &sync_infos, int &next_barrier_id, std::vector &barrier_buffers, Map &barrier_map, const std::vector &thread_count, LoopNestingInfo &loop_info) { - std::map unit_to_order; - for (size_t i = 0; i < units.size(); ++i) { - unit_to_order[units[i]] = i; - } - // Initiate WGMMA tracking structures int num_wgs = thread_count.size(); + // Initiate WGMMA tracking structures + /* std::vector wgmma_count(num_wgs, 0); std::vector> wgmma_id(num_wgs); for (auto unit : units) { @@ -711,18 +743,19 @@ static void InsertSynchronization( } } } + */ // Insert synchronization statements based on sync_infos for (auto unit : units) { for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { auto sync_it = sync_infos.find({unit, wg_id}); if (sync_it == sync_infos.end()) continue; - const auto &wait_map = sync_it->second.second; + const auto &wait_map = sync_it->second; bool is_async = unit->UsesTMACore() || unit->UsesTensorCore(); // Handle WGMMA synchronization if (unit->HasWGMMA()) { bool different_wg_id = false; - for (const auto &[waiting_unit_info, distance] : wait_map) { + for (const auto &[waiting_unit_info, _] : wait_map) { auto [waiting_unit, waiting_wg_id] = waiting_unit_info; if (waiting_wg_id != wg_id) { different_wg_id = true; @@ -730,19 +763,8 @@ static void InsertSynchronization( } } if (!different_wg_id) { - for (const auto &[waiting_unit_info, distance] : wait_map) { + for (const auto &[waiting_unit_info, _] : wait_map) { auto [waiting_unit, waiting_wg_id] = waiting_unit_info; - // Error: wrong num_mma in prologue and epilogue. - // Cannot fix now. - // - // int real_distance = distance + unit->stage - waiting_unit->stage; - // int num_mma = wgmma_id[wg_id][waiting_unit] - - // wgmma_id[wg_id][unit]; num_mma += real_distance * - // wgmma_count[wg_id]; if (unit->isInnerTask()) { - // --num_mma; - // } - // - // Fallback to set num_mma to 0 to avoid error. int num_mma = 0; Stmt wait_stmt = Evaluate(Call(DataType::Handle(), wait_wgmma(), {num_mma})); @@ -759,7 +781,13 @@ static void InsertSynchronization( // unit as synchronized unless it uses other asynchronous operations. is_async = unit->UsesTMACore(); } - int barrier_versions = std::max(sync_it->second.first, 1); + int barrier_versions = 1; + for (const auto &[waiting_unit_info, sync_infos] : wait_map) { + for (const auto &sync_info : sync_infos) { + barrier_versions = + std::max(barrier_versions, sync_info.buffer_versions); + } + } Buffer barrier_buffer; // Handle single special task, such as TCGEN05 or TMA load, that requires // a barrier for itself. @@ -774,9 +802,15 @@ static void InsertSynchronization( indexmod(loop_info.CalculateIterationCount(), barrier_versions); PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); RewriteGemmMbar(task, mbar_expr); - // Stmt arrive_stmt = - // makeTcgen05MmaArrive(barrier_buffer, version_index); - // InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); + // TODO: need to change the lower of tcgen05_gemm to check if there is + // already a arrive statement. Then we can manually insert the arrive + // statement to deal with the case where the tcgen05_gemm is inside an + // if condition. + /* + Stmt arrive_stmt = + makeTcgen05MmaArrive(barrier_buffer, version_index); + InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); + */ } if (task->HasTMALoad() && task_wg_id == wg_id) { int barrier_id = next_barrier_id++; @@ -792,7 +826,8 @@ static void InsertSynchronization( } } auto check_need_barrier = [&](ScheduleUnit *waiting_unit, - int waiting_wg_id) { + int waiting_wg_id, + const SyncInfo &sync_info) { if (unit == waiting_unit) // Note: the logic here need some assumption. return false; @@ -800,13 +835,24 @@ static void InsertSynchronization( return true; if (!is_async) return false; + if (!sync_info.producer->UsesTMACore() && + !sync_info.producer->UsesTensorCore()) + return false; + if (sync_info.producer->UsesTensorCore() && + sync_info.consumer->UsesTensorCore()) + return false; return true; }; bool need_barrier = false; - for (const auto &[waiting_unit_info, distance] : wait_map) { + for (const auto &[waiting_unit_info, sync_infos] : wait_map) { auto [waiting_unit, waiting_wg_id] = waiting_unit_info; - if (check_need_barrier(waiting_unit, waiting_wg_id)) { - need_barrier = true; + for (const auto &sync_info : sync_infos) { + if (check_need_barrier(waiting_unit, waiting_wg_id, sync_info)) { + need_barrier = true; + break; + } + } + if (need_barrier) { break; } } @@ -839,9 +885,15 @@ static void InsertSynchronization( } } // Add wait statements for all waiting units. - for (const auto &[waiting_unit_info, distance] : wait_map) { + for (const auto &[waiting_unit_info, sync_infos] : wait_map) { auto [waiting_unit, waiting_wg_id] = waiting_unit_info; - if (check_need_barrier(waiting_unit, waiting_wg_id)) { + int distance = 100; + for (const auto &sync_info : sync_infos) { + if (check_need_barrier(waiting_unit, waiting_wg_id, sync_info)) { + distance = std::min(distance, sync_info.distance); + } + } + if (distance < 100) { PrimExpr iteration = loop_info.CalculateIterationCount() - distance; PrimExpr version_index = indexmod(iteration, barrier_versions); PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index 9b7e987128..6d035fb5f5 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -279,7 +279,7 @@ void TaskNode::CollectBufferAccessInfo( auto emit_access = [&](const BufferRegion ®ion, bool is_write) { if (wg_id >= 0) { // Normal assigned warpgroup - result.emplace(region->buffer, is_write, wg_id, phase); + result.emplace(region->buffer, is_write, wg_id, this); } else if (IsWarpgroupBroadcast(wg_id)) { // Broadcast: skip register memory (each wg has its own copy) if (IsRegisterRegion(region)) { @@ -287,12 +287,12 @@ void TaskNode::CollectBufferAccessInfo( } // Shared/global memory is shared across wgs — emit for all for (int i = 0; i < num_wgs; ++i) { - result.emplace(region->buffer, is_write, i, phase); + result.emplace(region->buffer, is_write, i, this); } } else { // Unassigned (kWarpgroupUnassigned): expand to all wgs (legacy behavior) for (int i = 0; i < num_wgs; ++i) { - result.emplace(region->buffer, is_write, i, phase); + result.emplace(region->buffer, is_write, i, this); } } }; diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index a1e28318b5..82c2f8093f 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -54,14 +54,14 @@ inline bool IsWarpgroupBroadcast(int wg_id) { // Structure to store buffer access information struct BufferAccessInfo { Buffer buffer; - bool is_write; // true for write, false for read - int warpgroup_id; // warpgroup id of the innermost TaskNode - SchedulePhase schedule_phase{SchedulePhase::kBody}; // scheduling phase + bool is_write; // true for write, false for read + int warpgroup_id; // warpgroup id of the access + const TaskNode *task; // the innermost TaskNode BufferAccessInfo(Buffer buffer, bool is_write, int warpgroup_id, - SchedulePhase phase = SchedulePhase::kBody) + const TaskNode *task) : buffer(buffer), is_write(is_write), warpgroup_id(warpgroup_id), - schedule_phase(phase) {} + task(task) {} // Define operator< for set bool operator<(const BufferAccessInfo &other) const { @@ -74,8 +74,8 @@ struct BufferAccessInfo { if (warpgroup_id != other.warpgroup_id) { return warpgroup_id < other.warpgroup_id; } - if (schedule_phase != other.schedule_phase) { - return schedule_phase < other.schedule_phase; + if (task != other.task) { + return task < other.task; } return false; } From 55cf9c1c9d961c1c340ff0ab78d5bc092d062e5e Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 22 Apr 2026 20:45:27 +0800 Subject: [PATCH 116/156] [codex] Remove dead transform pass leftovers (#2083) * Remove dead transform pass leftovers * Fix formatting in transform cleanup --- ...align_dynamic_shared_memory_allocations.cc | 159 ------------------ .../annotate_warp_group_reg_alloc.cc | 13 +- .../eliminate_storage_sync_for_mbarrier.cc | 125 -------------- src/transform/producer_consumer_ws.cc | 21 +-- src/transform/warp_specialized_rewriter.h | 100 ----------- tilelang/transform/__init__.py | 57 ------- 6 files changed, 22 insertions(+), 453 deletions(-) delete mode 100644 src/transform/align_dynamic_shared_memory_allocations.cc delete mode 100644 src/transform/eliminate_storage_sync_for_mbarrier.cc delete mode 100644 src/transform/warp_specialized_rewriter.h diff --git a/src/transform/align_dynamic_shared_memory_allocations.cc b/src/transform/align_dynamic_shared_memory_allocations.cc deleted file mode 100644 index 1c2519df99..0000000000 --- a/src/transform/align_dynamic_shared_memory_allocations.cc +++ /dev/null @@ -1,159 +0,0 @@ -/*! - * \file align_dynamic_shared_memory_allocations.cc - * \brief align dynamic shared memory allocations - */ - -#include -#include -#include -#include -#include -#include - -#include "../op/builtin.h" -#include "arith/ir_mutator_with_analyzer.h" -#include "runtime/thread_storage_scope.h" -#include "tir/transforms/ir_utils.h" - -namespace tvm { -namespace tl { - -using namespace tir; - -class TileLangAlignDynamicSharedMemoryAllocations : public StmtExprMutator { -public: - explicit TileLangAlignDynamicSharedMemoryAllocations(int align_bytes) - : align_bytes_(align_bytes) {} - - static Stmt Substitute(int align_bytes, const Stmt &stmt) { - TileLangAlignDynamicSharedMemoryAllocations smem_rewriter(align_bytes); - return smem_rewriter.VisitStmt(stmt); - } - - Stmt VisitStmt_(const AllocateNode *op) final { - auto storage_scope = - runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kShared && - storage_scope.tag == ".dyn") { - auto new_extents = - MakeRoundRobinAlignment(op->extents, align_bytes_, op->dtype.bytes()); - if (!new_extents.same_as(op->extents)) { - auto new_allocate = Allocate(op->buffer_var, op->dtype, new_extents, - op->condition, op->body, op->annotations); - return StmtExprMutator::VisitStmt(new_allocate); - } - } - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const BlockNode *op) final { - Block block = tvm::ffi::GetRef(op); - Array alloc_buffers = op->alloc_buffers; - alloc_buffers.MutateByApply([this](Buffer buf) { - auto storage_scope = - runtime::StorageScope::Create(GetPtrStorageScope(buf->data)); - if (storage_scope.rank == runtime::StorageRank::kShared && - storage_scope.tag == ".dyn") { - auto new_shape = MakeRoundRobinAlignment(buf->shape, align_bytes_, - buf->dtype.bytes()); - if (!new_shape.same_as(buf->shape)) { - ObjectPtr new_buffer = - tvm::ffi::make_object(*(buf.get())); - new_buffer->shape = std::move(new_shape); - buffer_remap_.Set(buf, Buffer(new_buffer)); - return Buffer(new_buffer); - } - } - return buf; - }); - if (!alloc_buffers.same_as(op->alloc_buffers)) { - block.CopyOnWrite()->alloc_buffers = alloc_buffers; - } - return StmtExprMutator::VisitStmt_(block.get()); - } - - Stmt VisitStmt_(const BufferStoreNode *op) final { - auto store_node = tvm::ffi::GetRef(op); - Buffer buf = op->buffer; - if (buffer_remap_.count(buf)) { - buf = buffer_remap_[buf]; - return BufferStore(buf, op->value, op->indices); - } - return StmtExprMutator::VisitStmt_(store_node.get()); - } - - PrimExpr VisitExpr_(const BufferLoadNode *op) final { - auto load_node = tvm::ffi::GetRef(op); - Buffer buf = op->buffer; - if (buffer_remap_.count(buf)) { - buf = buffer_remap_[buf]; - return BufferLoad(buf, op->indices); - } - return StmtExprMutator::VisitExpr_(load_node.get()); - } - -private: - static Array MakeRoundRobinAlignment(Array extents, - int align_bytes, - int dtype_bytes) { - if (extents.empty()) - return extents; - // Calculate total number of elements - PrimExpr total_elems = make_const(extents[0].dtype(), 1); - for (auto extent : extents) { - total_elems = total_elems * extent; - } - // Calculate total bytes - PrimExpr total_bytes = total_elems * dtype_bytes; - // Check if already aligned - PrimExpr remainder = indexmod(total_bytes, align_bytes); - if (is_zero(remainder)) { - return extents; - } - // Need to pad the last dimension - Array adjusted; - for (size_t i = 0; i < extents.size(); ++i) { - adjusted.push_back(extents[i]); - } - // Calculate padded last dimension - // pad = ceil(total_bytes / align_bytes) * align_bytes - PrimExpr last_extent = extents.back(); - PrimExpr other_elems = make_const(extents[0].dtype(), 1); - for (size_t i = 0; i < extents.size() - 1; ++i) { - other_elems = other_elems * extents[i]; - } - // new_last_extent = ceil(total_bytes / align_bytes) * align_bytes / - // (other_elems * dtype_bytes) - PrimExpr padded_total_bytes = - floordiv(total_bytes + align_bytes - 1, align_bytes) * align_bytes; - PrimExpr new_last_extent = - floordiv(padded_total_bytes, other_elems * dtype_bytes); - adjusted.Set(adjusted.size() - 1, new_last_extent); - return adjusted; - } - - int align_bytes_; - Map buffer_remap_; -}; - -tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) { - using namespace tir::transform; - auto pass_func = [align_bytes](PrimFunc f, const IRModule &m, - const PassContext &ctx) { - auto *n = f.CopyOnWrite(); - n->body = TileLangAlignDynamicSharedMemoryAllocations::Substitute( - align_bytes, n->body); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, - "tl.AlignDynamicSharedMemoryAllocations", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations", - AlignDynamicSharedMemoryAllocations); -} - -} // namespace tl -} // namespace tvm diff --git a/src/transform/annotate_warp_group_reg_alloc.cc b/src/transform/annotate_warp_group_reg_alloc.cc index 0e7748593e..0c3d77141c 100644 --- a/src/transform/annotate_warp_group_reg_alloc.cc +++ b/src/transform/annotate_warp_group_reg_alloc.cc @@ -3,7 +3,16 @@ * \brief Annotate warp group reg alloc for warp specialization */ -#include "warp_specialized_rewriter.h" +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" #include #include #include @@ -24,7 +33,7 @@ Stmt RewriteWarpSpecializationBody(const Stmt &stmt, F &&rewrite_if, if (const auto *if_node = stmt.as()) { *rewrote = true; - return rewrite_if(GetRef(if_node)); + return rewrite_if(ffi::GetRef(if_node)); } if (const auto *seq = stmt.as()) { diff --git a/src/transform/eliminate_storage_sync_for_mbarrier.cc b/src/transform/eliminate_storage_sync_for_mbarrier.cc deleted file mode 100644 index 6face1fa45..0000000000 --- a/src/transform/eliminate_storage_sync_for_mbarrier.cc +++ /dev/null @@ -1,125 +0,0 @@ -/*! - * \file eliminate_storage_sync_for_mbarrier.cc - */ -#include "../op/builtin.h" -#include "arith/ir_mutator_with_analyzer.h" -#include "arith/ir_visitor_with_analyzer.h" -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace tl { - -using namespace tir; -using arith::IRMutatorWithAnalyzer; -using arith::IRVisitorWithAnalyzer; - -class Eliminator : public IRMutatorWithAnalyzer { -public: - static Stmt Substitute(const Stmt &stmt, bool skip_thread_partition = false) { - arith::Analyzer analyzer; - Eliminator transformer(&analyzer); - return transformer.VisitStmt(stmt); - } - - Eliminator(arith::Analyzer *analyzer) : IRMutatorWithAnalyzer(analyzer) { - im_mbarrier_for_ = false; - in_mbarrier_region_ = false; - } - - Stmt VisitStmt_(const AttrStmtNode *op) final { - if (op->attr_key == "thread_extent") { - if (const auto *var = op->node.as()) { - if (var->name_hint == "threadIdx.x") { - thread_extent_ = op; - } - } - } - return IRMutatorWithAnalyzer::VisitStmt_(op); - } - - Stmt VisitStmt_(const EvaluateNode *op) final { - const CallNode *call = nullptr; - if (op->value->IsInstance()) { - call = op->value.as(); - if (call->op.same_as(builtin::tvm_storage_sync())) { - // Skip storage sync if we're in a region with mbarrier operations - // and we're not in a for loop with mbarrier operations - if (in_mbarrier_region_ || im_mbarrier_for_) { - return Stmt(); - } - } else if (call->op.same_as(builtin::ptx_arrive_barrier()) || - call->op.same_as(tl::ptx_arrive_cluster_barrier()) || - call->op.same_as(builtin::ptx_wait_barrier())) { - in_mbarrier_region_ = true; - } - } - return IRMutatorWithAnalyzer::VisitStmt_(op); - } - - Stmt VisitStmt_(const IfThenElseNode *op) final { - bool old_in_mbarrier = in_mbarrier_region_; - Stmt then_case = VisitStmt(op->then_case); - - Stmt ret; - if (op->else_case.defined()) { - in_mbarrier_region_ = old_in_mbarrier; - Stmt else_case = VisitStmt(op->else_case.value()); - in_mbarrier_region_ = old_in_mbarrier || in_mbarrier_region_; - ret = IfThenElse(VisitExpr(op->condition), then_case, else_case); - } else { - in_mbarrier_region_ = old_in_mbarrier || in_mbarrier_region_; - ret = IfThenElse(VisitExpr(op->condition), then_case, Stmt()); - } - return ret; - } - - Stmt VisitStmt_(const ForNode *op) final { - PostOrderVisit(tvm::ffi::GetRef(op), [&](const ObjectRef &node) { - if (const auto *call = node.as()) { - if (call->op.same_as(builtin::ptx_init_barrier_thread_count()) || - call->op.same_as(mbarrier_wait_parity()) || - call->op.same_as(builtin::ptx_arrive_barrier()) || - call->op.same_as(builtin::ptx_cp_async_barrier())) { - im_mbarrier_for_ = true; - } - } - }); - auto stmt = IRMutatorWithAnalyzer::VisitStmt_(op); - im_mbarrier_for_ = false; - return stmt; - } - -private: - bool im_mbarrier_for_; - bool in_mbarrier_region_; - const AttrStmtNode *thread_extent_{nullptr}; -}; -using namespace tir::transform; - -namespace transform { - -tvm::transform::Pass EliminateStorageSyncForMBarrier() { - auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) { - auto *n = f.CopyOnWrite(); - n->body = Eliminator::Substitute(n->body); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tl.EliminateStorageSyncForMBarrier", - {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier", - EliminateStorageSyncForMBarrier); -} - -} // namespace transform -} // namespace tl -} // namespace tvm diff --git a/src/transform/producer_consumer_ws.cc b/src/transform/producer_consumer_ws.cc index a09a69f1ba..d4ba4f65a2 100644 --- a/src/transform/producer_consumer_ws.cc +++ b/src/transform/producer_consumer_ws.cc @@ -21,7 +21,9 @@ */ #include +#include #include +#include #include #include #include @@ -37,7 +39,6 @@ #include "../target/utils.h" #include "common/mbarrier.h" #include "multi_version_buffer_rewriter.h" -#include "warp_specialized_rewriter.h" namespace tvm { namespace tl { @@ -362,7 +363,7 @@ class BufferDataToBufferCollector : public StmtExprVisitor { } void VisitStmt_(const BlockNode *op) final { - CollectBuffers(GetRef(op)); + CollectBuffers(ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); } @@ -423,7 +424,7 @@ class LocalAccessCollector : public StmtExprVisitor { } void VisitExpr_(const VarNode *op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (bound_vars_.count(var) || buffer_data_to_buffer_.count(var)) { return; } @@ -487,7 +488,7 @@ class LocalAccessCollector : public StmtExprVisitor { ICHECK_EQ(op->args.size(), 5); const auto *var = op->args[1].as(); ICHECK(var); - auto it = buffer_data_to_buffer_.find(GetRef(var)); + auto it = buffer_data_to_buffer_.find(ffi::GetRef(var)); if (it != buffer_data_to_buffer_.end() && IsBranchPrivateBuffer(it->second)) { int rw_mask = GetConstAccessMask(op->args[4]); @@ -919,7 +920,7 @@ AnalyzeBufferDataAccess(const Stmt &stmt, const Var &buffer_data, ICHECK_EQ(op->args.size(), 5); const auto *var = op->args[1].as(); ICHECK(var); - auto it = buffer_map_.find(GetRef(var)); + auto it = buffer_map_.find(ffi::GetRef(var)); if (it != buffer_map_.end() && it->second->data.same_as(buffer_data_)) { MarkAccess(op->args[4]); } @@ -1860,7 +1861,7 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { ws_transformed_ = true; // Rebuild BlockRealize. - BlockRealize new_realize = GetRef(orig_realize); + BlockRealize new_realize = ffi::GetRef(orig_realize); new_realize.CopyOnWrite()->block = new_block; return new_realize; } @@ -1934,7 +1935,7 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { if (!SameExpr(ge->b, consumer_extent_)) { return false; } - *branch = GetRef(if_node); + *branch = ffi::GetRef(if_node); return true; } @@ -2165,7 +2166,7 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { } Block block = realize->block; block.CopyOnWrite()->body = result.stmt; - BlockRealize new_realize = GetRef(realize); + BlockRealize new_realize = ffi::GetRef(realize); new_realize.CopyOnWrite()->block = block; return {new_realize, true}; } @@ -2175,7 +2176,7 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { if (!result.found) { return {stmt, false}; } - Block new_block = GetRef(block); + Block new_block = ffi::GetRef(block); new_block.CopyOnWrite()->body = result.stmt; return {new_block, true}; } @@ -2185,7 +2186,7 @@ class ProducerConsumerWSRewriter : public StmtExprMutator { if (!result.found) { return {stmt, false}; } - AttrStmt new_attr = GetRef(attr); + AttrStmt new_attr = ffi::GetRef(attr); new_attr.CopyOnWrite()->body = result.stmt; return {new_attr, true}; } diff --git a/src/transform/warp_specialized_rewriter.h b/src/transform/warp_specialized_rewriter.h deleted file mode 100644 index e9a9bbf5c0..0000000000 --- a/src/transform/warp_specialized_rewriter.h +++ /dev/null @@ -1,100 +0,0 @@ -/*! - * \file warp_specialized_rewriter.h - * \brief tools for warp-specialized-related analysis and transformation - */ - -#pragma once - -#include "arith/ir_visitor_with_analyzer.h" -#include "tir/analysis/var_use_def_analysis.h" -#include -#include -#include -#include -#include -#include - -#include - -#include "../op/builtin.h" -#include "./common/collector.h" -#include "./common/mbarrier.h" -#include "runtime/thread_storage_scope.h" -#include "tir/transforms/ir_utils.h" - -namespace tvm { -namespace tl { - -using namespace tir; -using namespace runtime; -using arith::IRVisitorWithAnalyzer; - -class WarpSpecializedDetector : public IRVisitorWithAnalyzer { -public: - // return true means this aws will be disabled - static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) { - WarpSpecializedDetector detector; - detector.VisitStmt(stmt); - if (detector.has_warp_specialization_) { - LOG(WARNING) << "Auto warp specialization will be disabled because warp " - "specialization is manually enabled"; - return true; - } - if (detector.has_tma_op_ && detector.has_mbarrier_op_) { - LOG(WARNING) << "Auto warp specialization will be disabled because TMA " - "and mbarrier are both present"; - return true; - } - return false; - } - - WarpSpecializedDetector() { - has_tma_op_ = false; - has_mbarrier_op_ = false; - has_warp_specialization_ = false; - } - -private: - void VisitStmt_(const EvaluateNode *op) final { - if (const CallNode *call = op->value.as()) { - if (call->op.same_as(mbarrier_wait_parity()) || - call->op.same_as(builtin::ptx_arrive_barrier()) || - call->op.same_as(tl::ptx_arrive_cluster_barrier()) || - call->op.same_as(builtin::ptx_cp_async_barrier())) { - has_mbarrier_op_ = true; - } - } - IRVisitorWithAnalyzer::VisitStmt_(op); - } - - void VisitExpr_(const CallNode *op) final { - if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || - op->op.same_as(set_max_nreg())) { - has_tma_op_ = true; - } - IRVisitorWithAnalyzer::VisitExpr_(op); - } - - void VisitStmt_(const AttrStmtNode *op) final { - if (op->attr_key == "warp_specialize" && - op->value.as()->value == 1) { - has_warp_specialization_ = true; - } - if (op->attr_key == tir::attr::thread_extent) { - IterVar iv = Downcast(op->node); - if (iv->thread_tag == "threadIdx.x") { - ICHECK(iv->dom->extent.as()); - thread_var_ = iv; - } - } - IRVisitorWithAnalyzer::VisitStmt_(op); - } - - bool has_tma_op_{false}; - IterVar thread_var_; - bool has_mbarrier_op_{false}; - bool has_warp_specialization_{false}; -}; - -} // namespace tl -} // namespace tvm diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index a7a9414d71..677887bd49 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -87,17 +87,6 @@ def InjectSoftwarePipeline(): return _ffi_api.InjectSoftwarePipeline() # type: ignore -def FrontendLegalize(): - """FrontendLegalize - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.FrontendLegalize() # type: ignore - - def LegalizeNegativeIndex(): """Legalize negative indices in buffer loads. @@ -143,17 +132,6 @@ def LowerHopperIntrin(): return _ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore -def WarpSpecializedPipeline(): - """WarpSpecializedPipeline - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.WarpSpecializedPipeline() # type: ignore - - def ThreadSync(storage_scope: str): """Insert sync between parallel read/write of shared buffers. @@ -170,22 +148,6 @@ def ThreadSync(storage_scope: str): return _ffi_api.ThreadSync(storage_scope) # type: ignore -def ThreadPartialSync(storage_scope: str): - """Insert partial sync. - - Parameters - ---------- - storage_scope: str - The target storage scope. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.ThreadPartialSync(storage_scope) # type: ignore - - def IfStmtBinding(): """IfStmtBinding @@ -451,11 +413,6 @@ def FlattenBuffer(): return _ffi_api.FlattenBuffer() # type: ignore -def EliminateStorageSyncForMBarrier(): - """EliminateStorageSyncForMBarrier""" - return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore - - def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_bytes: int = 16): """MergeSharedMemoryAllocations @@ -482,20 +439,6 @@ def PersistThreadblock(): return _ffi_api.PersistThreadblock() # type: ignore -def AlignDynamicSharedMemoryAllocations(align_bytes: int = 16): - """AlignDynamicSharedMemoryAllocations - - Parameters - ---------- - align_bytes: int - The alignment bytes. - - Returns - ------- - """ - return _ffi_api.AlignDynamicSharedMemoryAllocations(align_bytes) # type: ignore - - def LowerSharedBarrier(): """LowerSharedBarrier""" return _ffi_api.LowerSharedBarrier() # type: ignore From 9aba41f026ddb7c0e03903c5a5924f8af0db6ab7 Mon Sep 17 00:00:00 2001 From: Liu Yunuo Date: Thu, 23 Apr 2026 00:36:53 +0800 Subject: [PATCH 117/156] [Bugfix] Enable `.shared::cta` in TMA copy paths only on CUDA 12.8+ (#2087) fix: enable in TMA copy paths only on CUDA 12.8+ --- src/target/ptx.cc | 8 ++++ src/tl_templates/cuda/copy_sm90.h | 76 +++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/src/target/ptx.cc b/src/target/ptx.cc index 8b806c783f..789183e277 100644 --- a/src/target/ptx.cc +++ b/src/target/ptx.cc @@ -1428,11 +1428,19 @@ std::string PrintCpAsyncBulkAsm(const std::string &shared_ptr, { unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr}); unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) __asm__ __volatile__( "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" :: "r"(smem_addr_int), "l"({global_ptr}), "r"({bytes}), "r"(barrier_addr_int) : "memory" ); +#else + __asm__ __volatile__( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" + :: "r"(smem_addr_int), "l"({global_ptr}), "r"({bytes}), "r"(barrier_addr_int) + : "memory" + ); +#endif } )"; diff --git a/src/tl_templates/cuda/copy_sm90.h b/src/tl_templates/cuda/copy_sm90.h index 86c845bbf7..c8e1794485 100644 --- a/src/tl_templates/cuda/copy_sm90.h +++ b/src/tl_templates/cuda/copy_sm90.h @@ -20,10 +20,18 @@ TL_DEVICE void tma_load(void *smem_ptr, void const *gmem_ptr, uint32_t smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::" "bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr), "l"((void const *)gmem_ptr), "r"(size), "r"(smem_int_mbar) :); +#else + asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::" + "bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr), + "l"((void const *)gmem_ptr), "r"(size), "r"(smem_int_mbar) + :); +#endif } TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr, @@ -50,6 +58,8 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) asm volatile("cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3}], [%2], %4;" @@ -57,6 +67,15 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "l"(cache_hint) : "memory"); +#else + asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#endif } template (&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) asm volatile("cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4}], [%2], %5;" @@ -79,6 +100,15 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "l"(cache_hint) : "memory"); +#else + asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#endif } template (&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) asm volatile("cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5}], [%2], %6;" @@ -101,6 +133,15 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) : "memory"); +#else + asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#endif } template @@ -116,6 +157,8 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) asm volatile("cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" @@ -123,6 +166,15 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) : "memory"); +#else + asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#endif } template (&smem_mbar)); } uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) asm volatile("cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::" "complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" @@ -147,6 +201,16 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) : "memory"); +#else + asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), + "l"(cache_hint) + : "memory"); +#endif } template (&smem_mbar)); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) asm volatile("cp.async.bulk.tensor.4d.shared::cta.global.im2col.mbarrier:" ":complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" @@ -169,6 +235,16 @@ tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar, "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), "h"(offset_w), "h"(offset_h), "l"(cache_hint) : "memory"); +#else + asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:" + ":complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "l"(cache_hint) + : "memory"); +#endif } template From b6e75b1ca48320ec98a9982c7672b8080c5f2456 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 23 Apr 2026 10:57:10 +0800 Subject: [PATCH 118/156] fix hopper neutral stage --- .../auto_schedule/schedule_builder.cc | 65 ++++++++++--------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index e57797aef5..1474c52abf 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -624,18 +624,19 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, } } - std::unordered_set prefix_tasks; - CollectPrefixTasks(root, prefix_tasks); - for (auto *task : prefix_tasks) { - task->SetSchedulePhase(SchedulePhase::kPrologue); - task->SetWarpgroupId(0); - } + std::unordered_set prefix_tasks, suffix_tasks; + if (config.producer_thread_count == 32) { + CollectPrefixTasks(root, prefix_tasks); + for (auto *task : prefix_tasks) { + task->SetSchedulePhase(SchedulePhase::kPrologue); + task->SetWarpgroupId(0); + } - std::unordered_set suffix_tasks; - CollectSuffixTasks(root, all_tasks, uf, suffix_tasks); - for (auto *task : suffix_tasks) { - task->SetSchedulePhase(SchedulePhase::kEpilogue); - task->SetWarpgroupId(0); + CollectSuffixTasks(root, all_tasks, uf, suffix_tasks); + for (auto *task : suffix_tasks) { + task->SetSchedulePhase(SchedulePhase::kEpilogue); + task->SetWarpgroupId(0); + } } std::unordered_map> components; @@ -999,28 +1000,30 @@ NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, } } - // Collect prefix/suffix tasks and reset them to neutral - std::unordered_set prefix_tasks; - CollectPrefixTasks(root, prefix_tasks); - for (auto *task : prefix_tasks) { - task->SetSchedulePhase(SchedulePhase::kPrologue); - task->SetWarpgroupId(0); - } - - int n = all_tasks.size(); - TaskUnionFind uf(n); - for (int i = 0; i < n; i++) { - for (int j = i + 1; j < n; j++) { - if (UseSameRegisterRegion(all_tasks[i].task, all_tasks[j].task)) { - uf.unite(i, j); + if (config.producer_thread_count == 32) { + // Collect prefix/suffix tasks and reset them to neutral + std::unordered_set prefix_tasks; + CollectPrefixTasks(root, prefix_tasks); + for (auto *task : prefix_tasks) { + task->SetSchedulePhase(SchedulePhase::kPrologue); + task->SetWarpgroupId(0); + } + + int n = all_tasks.size(); + TaskUnionFind uf(n); + for (int i = 0; i < n; i++) { + for (int j = i + 1; j < n; j++) { + if (UseSameRegisterRegion(all_tasks[i].task, all_tasks[j].task)) { + uf.unite(i, j); + } } } - } - std::unordered_set suffix_tasks; - CollectSuffixTasks(root, all_tasks, uf, suffix_tasks); - for (auto *task : suffix_tasks) { - task->SetSchedulePhase(SchedulePhase::kEpilogue); - task->SetWarpgroupId(0); + std::unordered_set suffix_tasks; + CollectSuffixTasks(root, all_tasks, uf, suffix_tasks); + for (auto *task : suffix_tasks) { + task->SetSchedulePhase(SchedulePhase::kEpilogue); + task->SetWarpgroupId(0); + } } // no double_thread in naive mode From 6e2fb565b41a9adb2405ed6c6be808a7c7db17f0 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 23 Apr 2026 13:11:34 +0800 Subject: [PATCH 119/156] change layout map & remove unused letstmt --- src/transform/auto_schedule.cc | 97 +++++++ .../auto_schedule/schedule_builder.cc | 21 ++ .../auto_schedule/warpgroup_partition.cc | 265 ++++++------------ .../auto_schedule/warpgroup_partition.h | 5 +- 4 files changed, 204 insertions(+), 184 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index b68b65f634..ee879b6c99 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -53,6 +53,7 @@ #include #include +#include "../layout/layout.h" #include "../op/builtin.h" #include "../op/copy.h" #include "../op/gemm.h" @@ -742,6 +743,7 @@ ScheduleSingleKernel(const Stmt &kernel_body, IterVar thread_var, Target target, if (!config.enable_warpgroup_partition) { result.scheduled_body = ConvertIRStructureToStmt(ir_structure.get(), enable_epi); + result.scheduled_body = StripUnusedLetStmts(result.scheduled_body); result.did_warpgroup_partition = false; return result; } @@ -772,6 +774,7 @@ ScheduleSingleKernel(const Stmt &kernel_body, IterVar thread_var, Target target, ir_structure.get(), thread_var, result.barrier_buffers, result.barrier_map, enable_epi, thread_count, config, neutral_sync_shared_barrier, result.duplicated_fragment_buffers); + result.scheduled_body = StripUnusedLetStmts(result.scheduled_body); result.did_warpgroup_partition = true; return result; } @@ -1152,6 +1155,96 @@ Stmt ReNestLetStmts(const Stmt &stmt) { } // StmtMutator to rewrite alloc_buffers in Block nodes +namespace { + +bool LayoutShapesEqual(const Array &lhs, const Array &rhs, + arith::Analyzer *analyzer) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); ++i) { + if (!analyzer->CanProveEqual(lhs[i], rhs[i])) { + return false; + } + } + return true; +} + +// Expand an annotated Layout so its InputShape matches the multi-versioned +// buffer shape by prepending the leading "num_versions" dim(s). +Layout ExpandAnnotatedLayoutForMultiVersionedBuffer(const Layout &layout, + const Buffer &old_buffer, + const Buffer &new_buffer) { + if (!layout.defined() || + new_buffer->shape.size() <= old_buffer->shape.size()) { + return Layout(); + } + + arith::Analyzer analyzer; + if (!LayoutShapesEqual(layout->InputShape(), old_buffer->shape, &analyzer)) { + return Layout(); + } + + size_t leading_ndim = new_buffer->shape.size() - old_buffer->shape.size(); + Array trailing_shape; + Array leading_shape; + for (size_t i = 0; i < leading_ndim; ++i) { + leading_shape.push_back(new_buffer->shape[i]); + } + for (size_t i = 0; i < old_buffer->shape.size(); ++i) { + trailing_shape.push_back(new_buffer->shape[leading_ndim + i]); + } + if (!LayoutShapesEqual(trailing_shape, old_buffer->shape, &analyzer)) { + return Layout(); + } + + return layout->Expand(leading_shape); +} + +// Walk the block's layout_map annotation and expand any entries whose buffer +// has been multi-versioned so downstream LayoutInference sees a matching shape. +bool UpdateExpandedLayoutMapForRemappedAllocs( + const std::vector> &remapped_allocs, + Map *annotations) { + if (remapped_allocs.empty() || !annotations->count(attr::kLayoutMap)) { + return false; + } + + auto layout_map_ref = annotations->Get(attr::kLayoutMap); + if (!layout_map_ref.has_value()) { + return false; + } + auto layout_map = layout_map_ref.value().as>(); + if (!layout_map.has_value()) { + return false; + } + + Map updated_layout_map = layout_map.value(); + std::unordered_set visited; + bool changed = false; + for (const auto &[old_buffer, new_buffer] : remapped_allocs) { + if (!visited.insert(old_buffer->data.get()).second || + !updated_layout_map.count(old_buffer->data)) { + continue; + } + Layout layout = updated_layout_map[old_buffer->data]; + Layout expanded = ExpandAnnotatedLayoutForMultiVersionedBuffer( + layout, old_buffer, new_buffer); + if (!expanded.defined()) { + continue; + } + updated_layout_map.Set(old_buffer->data, expanded); + changed = true; + } + + if (changed) { + annotations->Set(attr::kLayoutMap, updated_layout_map); + } + return changed; +} + +} // namespace + class AllocBufferRewriter : public StmtMutator { public: AllocBufferRewriter(const std::vector &buffer_infos) @@ -1169,11 +1262,13 @@ class AllocBufferRewriter : public StmtMutator { // Check if we need to update alloc_buffers bool needs_update = false; Array new_alloc_buffers; + std::vector> remapped_allocs; for (auto buffer : op->alloc_buffers) { auto it = buffer_remap_.find(buffer); if (it != buffer_remap_.end()) { new_alloc_buffers.push_back(it->second); + remapped_allocs.emplace_back(buffer, it->second); needs_update = true; } else { new_alloc_buffers.push_back(buffer); @@ -1184,6 +1279,8 @@ class AllocBufferRewriter : public StmtMutator { new_block->body = new_body; if (needs_update) { new_block->alloc_buffers = new_alloc_buffers; + UpdateExpandedLayoutMapForRemappedAllocs(remapped_allocs, + &new_block->annotations); } return Stmt(new_block); } diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 1474c52abf..fbaaf9574e 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -110,11 +110,26 @@ bool SameBuffer(const BufferRegion &a, const BufferRegion &b) { bool SameVar(const Var &a, const Var &b) { return a.same_as(b); } +bool IsAttrInTask(const IRStructure *a) { + if (!a->IsTask()) { + return false; + } + auto task = static_cast(a); + for (const auto &stmt : task->stmts) { + if (stmt.as()) { + return true; + } + } + return false; +} + bool HasDependency(const IRStructure *a, const IRStructure *b) { if (a->ContainsLoopBreak()) return true; if (b->ContainsLoopBreak()) return true; + if (IsAttrInTask(a) || IsAttrInTask(b)) + return true; for (const auto &write_region_a : a->GetWriteRegions()) { for (const auto &read_region_b : b->GetReadRegions()) { if (SameBuffer(write_region_a, read_region_b)) @@ -137,6 +152,12 @@ bool HasDependency(const IRStructure *a, const IRStructure *b) { return true; } } + for (const auto &read_var_a : a->GetReadVars()) { + for (const auto &write_var_b : b->GetWriteVars()) { + if (SameVar(read_var_a, write_var_b)) + return true; + } + } return false; } diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 58ebb731b9..918425bd94 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -563,184 +563,33 @@ CloneIRStructureChildrenWithWarpgroupFilter(SequenceNode *root_seq, var_remap, buffer_remap); } -std::shared_ptr -RemoveUnusedLetDecls(std::shared_ptr root) { - if (!root) - return nullptr; +// Post-pass for the finished auto-schedule Stmt: drop any LetStmt whose bound +// variable is not referenced in its body, provided the bound value expression +// is pure (no side-effects beyond `kReadState`). +class UnusedLetStmtStripper : public StmtExprMutator { +public: + Stmt VisitStmt_(const LetStmtNode *op) final { + Stmt new_body = this->VisitStmt(op->body); + PrimExpr new_value = this->VisitExpr(op->value); - // Phase 1: Collect LetDecl definitions and variable references from - // non-LetDecl nodes (task stmts and ScheduleUnit before/after). - struct LetDeclEntry { - const VarNode *var; - PrimExpr value; - }; - std::vector let_decls; - std::unordered_set referenced_vars; - - std::function collect = - [&](const IRStructure *node) { - if (!node) - return; - if (node->IsTask()) { - auto task = static_cast(node); - if (IsLetDeclTask(task)) { - const auto *let = task->stmts[0].as(); - let_decls.push_back({let->var.get(), let->value}); - } else { - VarRefCollector collector; - for (const auto &stmt : task->stmts) { - collector(stmt); - } - referenced_vars.insert(collector.vars.begin(), - collector.vars.end()); - } - } else if (node->IsSequence()) { - for (const auto &child : - static_cast(node)->children) { - collect(child.get()); - } - } else if (node->IsControl()) { - auto ctrl = static_cast(node); - collect(ctrl->task.get()); - collect(ctrl->child.get()); - // Also collect variable references from the For loop bounds - // (min, extent, step) so their LetDecls are not removed. - VarRefCollector for_collector; - for_collector(ctrl->control->min); - for_collector(ctrl->control->extent); - if (ctrl->control->step.has_value()) { - for_collector(ctrl->control->step.value()); - } - referenced_vars.insert(for_collector.vars.begin(), - for_collector.vars.end()); - } else if (node->IsWrapper()) { - auto wrapper = static_cast(node); - collect(wrapper->task.get()); - collect(wrapper->child.get()); - // Also collect variable references from the wrapper statement - // (LetStmt value / AttrStmt value) so their LetDecls are not removed. - VarRefCollector wrapper_collector; - wrapper_collector(wrapper->wrapper); - referenced_vars.insert(wrapper_collector.vars.begin(), - wrapper_collector.vars.end()); - } else if (node->IsScheduleUnit()) { - auto unit = static_cast(node); - collect(unit->child.get()); - VarRefCollector collector; - for (const auto &[_, stmts] : unit->before) { - for (const auto &s : stmts) - collector(s); - } - for (const auto &[_, stmts] : unit->after) { - for (const auto &s : stmts) - collector(s); - } - referenced_vars.insert(collector.vars.begin(), collector.vars.end()); - } else if (node->IsIf()) { - auto if_node = static_cast(node); - collect(if_node->task.get()); - collect(if_node->then_child.get()); - if (if_node->else_child) { - collect(if_node->else_child.get()); - } - // Collect variable references from the condition - VarRefCollector cond_collector; - cond_collector(if_node->condition); - referenced_vars.insert(cond_collector.vars.begin(), - cond_collector.vars.end()); - } - }; - collect(root.get()); - - // Phase 2: Transitive closure — if a LetDecl var is referenced, - // all vars in its value expression are transitively referenced too. - bool changed = true; - while (changed) { - changed = false; - for (const auto &entry : let_decls) { - if (referenced_vars.count(entry.var)) { - VarRefCollector collector; - collector(entry.value); - for (const auto *v : collector.vars) { - if (!referenced_vars.count(v)) { - referenced_vars.insert(v); - changed = true; - } - } - } - } - } + auto body_uses_var = UsesVar(new_body, [&](const VarNode *v) { + return v == op->var.get(); + }); + bool value_is_pure = SideEffect(new_value) <= CallEffectKind::kPure; - // Phase 3: Filter the tree — remove LetDecl tasks for unused vars. - std::function( - const std::shared_ptr &)> - filter_tree = [&](const std::shared_ptr &node) - -> std::shared_ptr { - if (!node) - return nullptr; - if (node->IsTask()) { - if (IsLetDeclTask(static_cast(node.get()))) { - const auto *let = static_cast(node.get()) - ->stmts[0] - .as(); - if (!referenced_vars.count(let->var.get())) { - return nullptr; // Remove unused LetDecl - } - } - return node; - } else if (node->IsSequence()) { - auto seq = static_cast(node.get()); - auto new_seq = std::make_shared(); - for (const auto &child : seq->children) { - auto filtered = filter_tree(child); - if (filtered) - new_seq->children.push_back(std::move(filtered)); - } - if (new_seq->children.empty()) - return nullptr; - return new_seq; - } else if (node->IsControl()) { - auto ctrl = static_cast(node.get()); - auto new_ctrl = std::make_shared(); - new_ctrl->control = ctrl->control; - new_ctrl->task = ctrl->task; - new_ctrl->SetPromote(ctrl->hasPromote()); - new_ctrl->child = filter_tree(ctrl->child); - if (!new_ctrl->child) - return nullptr; - return new_ctrl; - } else if (node->IsWrapper()) { - auto wrapper = static_cast(node.get()); - auto new_wrapper = std::make_shared(); - new_wrapper->wrapper = wrapper->wrapper; - new_wrapper->task = wrapper->task; - new_wrapper->child = filter_tree(wrapper->child); - return new_wrapper; - } else if (node->IsScheduleUnit()) { - auto unit = static_cast(node.get()); - auto new_unit = std::make_shared(); - new_unit->stage = unit->stage; - new_unit->before = unit->before; - new_unit->after = unit->after; - new_unit->child = filter_tree(unit->child); - if (!new_unit->child) - return nullptr; - return new_unit; - } else if (node->IsIf()) { - auto if_node = static_cast(node.get()); - auto new_if = std::make_shared(); - new_if->condition = if_node->condition; - new_if->task = if_node->task; - new_if->then_child = filter_tree(if_node->then_child); - if (if_node->else_child) { - new_if->else_child = filter_tree(if_node->else_child); - } - return new_if; + if (!body_uses_var && value_is_pure) { + return new_body; } - return node; - }; + if (new_body.same_as(op->body) && new_value.same_as(op->value)) { + return GetRef(op); + } + return LetStmt(op->var, new_value, new_body, op->span); + } +}; - return filter_tree(root); +Stmt StripUnusedLetStmts(const Stmt &stmt) { + UnusedLetStmtStripper stripper; + return stripper(stmt); } class SimtCopyDetector : public StmtExprVisitor { @@ -764,6 +613,44 @@ class SimtCopyDetector : public StmtExprVisitor { bool has_simt_copy_{false}; }; +// Detects whether `stmt` already carries any of the signals that downstream +// AnnotateWarpGroupRegAlloc (see src/transform/annotate_warp_group_reg_alloc.cc +// SetMaxNRegCollector) treats as "caller already decided register allocation". +// If any of these signals is present, the auto-schedule warp-group partitioner +// must skip its own (240/24-style) set_max_nreg injection so the inner choice +// is preserved. The three honored signals are: +// 1. an explicit tl::set_max_nreg() call, +// 2. an explicit tl::no_set_max_nreg() opt-out sentinel, +// 3. an AttrStmt with key attr::kCustomWarpSpecialization. +class InnerNRegDecisionDetector : public StmtExprVisitor { +public: + static bool Detect(const Stmt &stmt) { + InnerNRegDecisionDetector detector; + detector.VisitStmt(stmt); + return detector.has_decision_; + } + +private: + void VisitStmt_(const EvaluateNode *op) final { + if (const CallNode *call = op->value.as()) { + if (call->op.same_as(tl::set_max_nreg()) || + call->op.same_as(tl::no_set_max_nreg())) { + has_decision_ = true; + } + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == attr::kCustomWarpSpecialization) { + has_decision_ = true; + } + StmtExprVisitor::VisitStmt_(op); + } + + bool has_decision_{false}; +}; + Stmt ConvertIRStructureToStmt(IRStructure *structure, const bool outer_enable_epi) { if (!structure) { @@ -1160,12 +1047,11 @@ Stmt ApplyWarpgroupPartitionToIRStructure( wg_structures[i] = rebuilt_seq->children.empty() ? nullptr : rebuilt_seq; } } else { - // Fallback for non-SequenceNode root: clone entire root per warpgroup + // Fallback for non-SequenceNode root: clone entire root per warpgroup. for (size_t i = 0; i < num_wgs; ++i) { Map var_remap; - wg_structures[i] = - RemoveUnusedLetDecls(CloneIRStructureWithWarpgroupFilter( - root, i, var_remap, per_wg_buffer_remap[i])); + wg_structures[i] = CloneIRStructureWithWarpgroupFilter( + root, i, var_remap, per_wg_buffer_remap[i]); } } @@ -1212,6 +1098,22 @@ Stmt ApplyWarpgroupPartitionToIRStructure( has_simt_copy = SimtCopyDetector::Detect(full_wg1); } + // Check whether any inner pass already decided register allocation. + bool has_inner_nreg_decision = false; + if (num_wgs == 2 && config.enable_set_max_nreg) { + for (size_t i = 0; i < num_wgs; ++i) { + if (!wg_structures[i]) { + continue; + } + Stmt full_wg = + ConvertIRStructureToStmt(wg_structures[i].get(), outer_enable_epi); + if (InnerNRegDecisionDetector::Detect(full_wg)) { + has_inner_nreg_decision = true; + break; + } + } + } + // --- Per-child construction --- // Walk root SequenceNode's children. LetDecl children accumulate bindings; // non-LetDecl children produce IfThenElse blocks wrapped with accumulated @@ -1289,8 +1191,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, Evaluate(0))); // Prepend set_max_nreg only to the first non-LetDecl child - if (first_non_let && !has_simt_copy && num_wgs == 2 && - config.enable_set_max_nreg) { + if (first_non_let && !has_simt_copy && !has_inner_nreg_decision && + num_wgs == 2 && config.enable_set_max_nreg) { for (size_t i = 0; i < num_wgs; ++i) { wg_stmts[i] = SeqStmt({Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), @@ -1319,7 +1221,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( wg_stmts[i] = Evaluate(0); } } - if (!has_simt_copy && num_wgs == 2 && config.enable_set_max_nreg) { + if (!has_simt_copy && !has_inner_nreg_decision && num_wgs == 2 && + config.enable_set_max_nreg) { for (size_t i = 0; i < num_wgs; ++i) { wg_stmts[i] = SeqStmt({Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), diff --git a/src/transform/auto_schedule/warpgroup_partition.h b/src/transform/auto_schedule/warpgroup_partition.h index fcf0d7e5fa..25f3bd6a9e 100644 --- a/src/transform/auto_schedule/warpgroup_partition.h +++ b/src/transform/auto_schedule/warpgroup_partition.h @@ -43,9 +43,6 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, std::shared_ptr CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id); -std::shared_ptr -RemoveUnusedLetDecls(std::shared_ptr root); - std::vector> CloneIRStructureChildrenWithWarpgroupFilter(SequenceNode *root_seq, int warpgroup_id, @@ -70,5 +67,7 @@ Stmt ApplyWarpgroupPartitionToIRStructure( Stmt ReNestLetStmts(const Stmt &stmt); +Stmt StripUnusedLetStmts(const Stmt &stmt); + } // namespace tl } // namespace tvm From 89e68123776403e4eece4df8cd3961169048b58f Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 23 Apr 2026 13:55:49 +0800 Subject: [PATCH 120/156] disable auto scheduling when using thread vars --- tilelang/engine/phase.py | 60 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 33d294ede5..20860860e1 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -28,6 +28,64 @@ def module_has_tma(mod: IRModule) -> bool: return any(func.attrs and func.attrs.get("tl.has_tma", False) for _, func in mod.functions.items()) +def module_uses_thread_var(mod: IRModule) -> bool: + """Check whether any PrimFunc in ``mod`` references thread-index variables + inside its body. + """ + from tvm.tir import stmt_functor + + for _, func in mod.functions.items(): + if not isinstance(func, tir.PrimFunc): + continue + + thread_extent_vars: set = set() + explicit_thread_binding_loop: list[bool] = [False] + + def _collect(node): + if isinstance(node, tir.AttrStmt) and node.attr_key == "thread_extent": + iter_var = node.node + if isinstance(iter_var, tir.IterVar): + tag = getattr(iter_var, "thread_tag", "") or "" + if tag.startswith("threadIdx."): + thread_extent_vars.add(iter_var.var) + elif isinstance(node, tir.For) and node.kind == tir.ForKind.THREAD_BINDING: + tb = node.thread_binding + tag = getattr(tb, "thread_tag", "") if tb is not None else "" + if isinstance(tag, str) and tag.startswith("threadIdx."): + explicit_thread_binding_loop[0] = True + + stmt_functor.post_order_visit(func.body, _collect) + + if explicit_thread_binding_loop[0]: + return True + + if not thread_extent_vars: + continue + + uses_thread_var = [False] + + def _find_use(node): + if uses_thread_var[0]: + return + if isinstance(node, tir.Var) and node in thread_extent_vars: + uses_thread_var[0] = True + + def _walk(stmt): + if uses_thread_var[0]: + return + if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "thread_extent": + _walk(stmt.body) + return + stmt_functor.post_order_visit(stmt, _find_use) + + _walk(func.body) + + if uses_thread_var[0]: + return True + + return False + + def allow_vectorize(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() @@ -184,7 +242,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectAssumes()(mod) # Simplify the IR expressions mod = tilelang.transform.Simplify()(mod) - if allow_autoschedule(target=target): + if allow_autoschedule(target=target) and not module_uses_thread_var(mod): # Auto schedule for high-level operations mod = tilelang.transform.IfConditionExtract()(mod) mod = tilelang.transform.AutoSchedule(False)(mod) From f3f6e7436d6b3c220b398008f26c4164128c9e7f Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 23 Apr 2026 14:24:53 +0800 Subject: [PATCH 121/156] find first/last tasks of a buffer access and reduce syncs by checking the type of dependency --- src/transform/auto_schedule/barrier.h | 225 +++++++---- src/transform/auto_schedule/ir_structure.cc | 414 ++++++++++++++++++++ src/transform/auto_schedule/ir_structure.h | 86 ++++ 3 files changed, 642 insertions(+), 83 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 521c6f4a47..922bc4b10c 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -640,73 +640,95 @@ GetSyncInfos(const std::vector &units, int num_wgs, std::vector last_read_unit(num_wgs, nullptr); std::vector> last_read_unit_tasks(num_wgs); ScheduleUnit *last_write_unit = nullptr; - std::set last_write_unit_tasks; - int last_write_wg_id = -1; + std::vector>> + last_write_unit_wg_tasks; std::vector waited_write_wgs(num_wgs, false); for (int iter = 0; iter < (is_loop ? 2 : 1); ++iter) { for (ScheduleUnit *unit : units) { + // Add dependencies for RAW and WAR between this unit and the last + // writer/reader units + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + int distance = iter ? num_versions : 0; + // RAW: unit reads buffer, wait for last writer + auto first_reads = unit->GetFirstAccessTasks( + buffer, /*is_write=*/false, wg_id, SchedulePhase::kBody); + if (!first_reads.empty() && last_write_unit != nullptr && + !waited_write_wgs[wg_id]) { + for (auto *consumer : first_reads) { + for (const auto &[last_write_wg_id, last_write_unit_tasks] : + last_write_unit_wg_tasks) { + for (auto *producer : last_write_unit_tasks) { + sync_infos[{last_write_unit, last_write_wg_id}][{unit, wg_id}] + .emplace(distance, buffer, producer, consumer, + num_versions); + } + } + } + } + // WAR: unit writes buffer, wait for all last readers + auto first_writes = unit->GetFirstAccessTasks( + buffer, /*is_write=*/true, wg_id, SchedulePhase::kBody); + if (!first_writes.empty()) { + for (int last_wg = 0; last_wg < num_wgs; ++last_wg) { + if (last_read_unit[last_wg] == nullptr) + continue; + for (auto *consumer : first_writes) { + for (auto *producer : last_read_unit_tasks[last_wg]) { + sync_infos[{last_read_unit[last_wg], last_wg}][{unit, wg_id}] + .emplace(distance, buffer, producer, consumer, + num_versions); + } + } + } + } + } + // Set status to avoid redundant dependencies for subsequent units for (const auto &buffer_access : unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { int wg_id = buffer_access.warpgroup_id; - ICHECK(0 <= wg_id && wg_id < num_wgs); if (buffer_access.buffer != buffer) continue; - auto add_sync = [&](ScheduleUnit *wait_unit, int wait_wg_id, - const std::set &wait_tasks) { - int distance = iter ? num_versions : 0; - auto &wait_map = sync_infos[{wait_unit, wait_wg_id}]; - auto it = wait_map.find({unit, wg_id}); - for (auto wait_task : wait_tasks) { - wait_map[{unit, wg_id}].emplace(distance, buffer, wait_task, - buffer_access.task, num_versions); - } - }; if (!buffer_access.is_write) { - if (last_write_unit == nullptr) - continue; - if (waited_write_wgs[wg_id]) - continue; - add_sync(last_write_unit, last_write_wg_id, last_write_unit_tasks); + waited_write_wgs[wg_id] = true; } else { - for (int last_wg_id = 0; last_wg_id < num_wgs; ++last_wg_id) { - if (last_read_unit[last_wg_id] == nullptr) - continue; - add_sync(last_read_unit[last_wg_id], last_wg_id, - last_read_unit_tasks[last_wg_id]); + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + last_read_unit[wg_id] = nullptr; } } } if (iter == 0) { - for (const auto &buffer_access : - unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { - int wg_id = buffer_access.warpgroup_id; - if (buffer_access.buffer != buffer) - continue; - if (!buffer_access.is_write) { - waited_write_wgs[wg_id] = true; - last_read_unit_tasks[wg_id].clear(); - } else { - for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { - last_read_unit[wg_id] = nullptr; - } - last_write_unit_tasks.clear(); + // Update last_read info + for (int wg = 0; wg < num_wgs; ++wg) { + auto last_reads = unit->GetLastAccessTasks( + buffer, /*is_write=*/false, wg, SchedulePhase::kBody); + if (!last_reads.empty()) { + last_read_unit[wg] = unit; + last_read_unit_tasks[wg] = std::move(last_reads); } } - for (const auto &buffer_access : - unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { - int wg_id = buffer_access.warpgroup_id; - if (buffer_access.buffer != buffer) - continue; - if (!buffer_access.is_write) { - last_read_unit[wg_id] = unit; - last_read_unit_tasks[wg_id].insert(buffer_access.task); - } else { + // Update last_write info + { + std::vector write_wg_ids; + for (const auto &ba : + unit->GetBufferAccessInfo(num_wgs, SchedulePhase::kBody)) { + if (ba.buffer == buffer && ba.is_write) { + if (std::find(write_wg_ids.begin(), write_wg_ids.end(), + ba.warpgroup_id) == write_wg_ids.end()) + write_wg_ids.push_back(ba.warpgroup_id); + } + } + if (!write_wg_ids.empty()) { last_write_unit = unit; - last_write_unit_tasks.insert(buffer_access.task); - last_write_wg_id = wg_id; - for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { - waited_write_wgs[wg_id] = false; + last_write_unit_wg_tasks.clear(); + for (int wg : write_wg_ids) { + auto last_writes = unit->GetLastAccessTasks( + buffer, /*is_write=*/true, wg, SchedulePhase::kBody); + if (!last_writes.empty()) + last_write_unit_wg_tasks.emplace_back(wg, + std::move(last_writes)); } + for (int wg = 0; wg < num_wgs; ++wg) + waited_write_wgs[wg] = false; } } } @@ -751,35 +773,84 @@ static void InsertSynchronization( if (sync_it == sync_infos.end()) continue; const auto &wait_map = sync_it->second; - bool is_async = unit->UsesTMACore() || unit->UsesTensorCore(); + auto check_need_sync = [&](ScheduleUnit *waiting_unit, int waiting_wg_id, + const SyncInfo &sync_info) { + if (unit == waiting_unit) + // Note: the logic here need some assumption. + return false; + if (wg_id != waiting_wg_id) + return true; + if (!sync_info.producer->UsesTMACore() && + !sync_info.producer->UsesTensorCore()) + return false; + if (sync_info.producer->UsesTensorCore() && + sync_info.consumer->UsesTensorCore()) + return false; + return true; + }; // Handle WGMMA synchronization - if (unit->HasWGMMA()) { - bool different_wg_id = false; - for (const auto &[waiting_unit_info, _] : wait_map) { + { + auto check_need_wgmma_sync = [&](ScheduleUnit *waiting_unit, + int waiting_wg_id, + const SyncInfo &sync_info) { + return check_need_sync(waiting_unit, waiting_wg_id, sync_info) && + sync_info.producer->is_WGMMA(); + }; + bool has_wgmma_sync = false; + for (const auto &[waiting_unit_info, sync_infos] : wait_map) { auto [waiting_unit, waiting_wg_id] = waiting_unit_info; - if (waiting_wg_id != wg_id) { - different_wg_id = true; + for (const auto &sync_info : sync_infos) { + if (check_need_wgmma_sync(waiting_unit, waiting_wg_id, sync_info)) { + has_wgmma_sync = true; + break; + } + } + if (has_wgmma_sync) { break; } } - if (!different_wg_id) { - for (const auto &[waiting_unit_info, _] : wait_map) { + if (has_wgmma_sync) { + bool different_wg_id = false; + for (const auto &[waiting_unit_info, sync_infos] : wait_map) { auto [waiting_unit, waiting_wg_id] = waiting_unit_info; - int num_mma = 0; + if (wg_id == waiting_wg_id) { + continue; + } + for (const auto &sync_info : sync_infos) { + if (check_need_wgmma_sync(waiting_unit, waiting_wg_id, + sync_info)) { + different_wg_id = true; + break; + } + } + if (different_wg_id) { + break; + } + } + if (!different_wg_id) { + for (const auto &[waiting_unit_info, sync_infos] : wait_map) { + auto [waiting_unit, waiting_wg_id] = waiting_unit_info; + bool need_wait = false; + for (const auto &sync_info : sync_infos) { + if (check_need_wgmma_sync(waiting_unit, waiting_wg_id, + sync_info)) { + need_wait = true; + break; + } + } + if (need_wait) { + Stmt wait_stmt = + Evaluate(Call(DataType::Handle(), wait_wgmma(), {0})); + InsertStatementIntoScheduleUnit(waiting_unit, wait_stmt, true, + wg_id); + } + } + } else { Stmt wait_stmt = - Evaluate(Call(DataType::Handle(), wait_wgmma(), {num_mma})); - InsertStatementIntoScheduleUnit(waiting_unit, wait_stmt, true, - wg_id); + Evaluate(Call(DataType::Handle(), wait_wgmma(), {0})); + InsertStatementIntoScheduleUnit(unit, wait_stmt, false, wg_id); } - } else { - Stmt wait_stmt = - Evaluate(Call(DataType::Handle(), wait_wgmma(), {0})); - InsertStatementIntoScheduleUnit(unit, wait_stmt, false, wg_id); } - // Even if different_wg_id is false, we already inserted the necessary - // wait_wgmma statements inside the warp group. Now we can consider the - // unit as synchronized unless it uses other asynchronous operations. - is_async = unit->UsesTMACore(); } int barrier_versions = 1; for (const auto &[waiting_unit_info, sync_infos] : wait_map) { @@ -828,20 +899,8 @@ static void InsertSynchronization( auto check_need_barrier = [&](ScheduleUnit *waiting_unit, int waiting_wg_id, const SyncInfo &sync_info) { - if (unit == waiting_unit) - // Note: the logic here need some assumption. - return false; - if (wg_id != waiting_wg_id) - return true; - if (!is_async) - return false; - if (!sync_info.producer->UsesTMACore() && - !sync_info.producer->UsesTensorCore()) - return false; - if (sync_info.producer->UsesTensorCore() && - sync_info.consumer->UsesTensorCore()) - return false; - return true; + return check_need_sync(waiting_unit, waiting_wg_id, sync_info) && + (wg_id != waiting_wg_id || !sync_info.producer->is_WGMMA()); }; bool need_barrier = false; for (const auto &[waiting_unit_info, sync_infos] : wait_map) { diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index 6d035fb5f5..1994d681ca 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -537,5 +537,419 @@ void CollectAllTaskNodesWithContext(IRStructure *node, } } +// ============================================================================ +// CollectFirstAccessTasks / CollectLastAccessTasks implementations +// +// These methods return the set of TaskNode pointers that could possibly be the +// first (or last) to perform a specific buffer access (buffer, is_write, wg_id) +// within the IR subtree. The bool return value indicates whether the subtree +// is *guaranteed* to contain at least one matching access (must_have). +// ============================================================================ + +static const IfNode *TryGetIfNode(const IRStructure *node) { + if (!node) + return nullptr; + if (node->IsIf()) + return static_cast(node); + if (node->IsScheduleUnit()) { + auto *unit = static_cast(node); + if (unit->child && unit->child->IsIf()) + return static_cast(unit->child.get()); + } + return nullptr; +} + +static bool +SequenceCollectFirstAccessTasks(const std::vector &nodes, + const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) { + int n = static_cast(nodes.size()); + + // Find the first node that contains loop_break. + int break_idx = -1; + for (int j = 0; j < n; ++j) { + if (nodes[j]->ContainsLoopBreak()) { + break_idx = j; + break; + } + } + if (break_idx >= 0) { + if (break_idx > 0) { + std::vector before(nodes.begin(), + nodes.begin() + break_idx); + if (SequenceCollectFirstAccessTasks(before, buffer, is_write, wg_id, + phase, result)) + return true; + } + if (nodes[break_idx]->CollectFirstAccessTasks(buffer, is_write, wg_id, + phase, result)) + return true; + if (break_idx + 1 < n) { + std::vector after(nodes.begin() + break_idx + 1, + nodes.end()); + SequenceCollectFirstAccessTasks(after, buffer, is_write, wg_id, phase, + result); + } + return false; + } + + // No loop_break + int i = 0; + while (i < n) { + const IfNode *if_node = TryGetIfNode(nodes[i]); + if (if_node) { + int group_start = i; + while (i + 1 < n) { + const IfNode *next_if = TryGetIfNode(nodes[i + 1]); + if (!next_if) + break; + if (!StructuralEqual()(if_node->condition, next_if->condition)) + break; + ++i; + } + int group_end = i; + ++i; + if (if_node->task) { + if (if_node->task->CollectFirstAccessTasks(buffer, is_write, wg_id, + phase, result)) + return true; + } + std::vector then_children, else_children; + for (int j = group_start; j <= group_end; ++j) { + const IfNode *cur_if = TryGetIfNode(nodes[j]); + ICHECK(cur_if); + if (cur_if->then_child) + then_children.push_back(cur_if->then_child.get()); + if (cur_if->else_child) + else_children.push_back(cur_if->else_child.get()); + } + bool then_must = false, else_must = false; + if (!then_children.empty()) + then_must = SequenceCollectFirstAccessTasks( + then_children, buffer, is_write, wg_id, phase, result); + if (!else_children.empty()) + else_must = SequenceCollectFirstAccessTasks( + else_children, buffer, is_write, wg_id, phase, result); + if (then_must && else_must) + return true; + } else { + if (nodes[i]->CollectFirstAccessTasks(buffer, is_write, wg_id, phase, + result)) + return true; + ++i; + } + } + return false; +} + +static bool +SequenceCollectLastAccessTasks(const std::vector &nodes, + const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) { + int n = static_cast(nodes.size()); + + // Find the last node that contains loop_break. + int break_idx = -1; + for (int j = n - 1; j >= 0; --j) { + if (nodes[j]->ContainsLoopBreak()) { + break_idx = j; + break; + } + } + if (break_idx >= 0) { + if (break_idx + 1 < n) { + std::vector after(nodes.begin() + break_idx + 1, + nodes.end()); + SequenceCollectLastAccessTasks(after, buffer, is_write, wg_id, phase, + result); + } + nodes[break_idx]->CollectLastAccessTasks(buffer, is_write, wg_id, phase, + result); + if (break_idx > 0) { + std::vector before(nodes.begin(), + nodes.begin() + break_idx); + SequenceCollectLastAccessTasks(before, buffer, is_write, wg_id, phase, + result); + } + return false; + } + + // No loop_break + int i = n - 1; + while (i >= 0) { + const IfNode *if_node = TryGetIfNode(nodes[i]); + if (if_node) { + int group_end = i; + while (i - 1 >= 0) { + const IfNode *prev_if = TryGetIfNode(nodes[i - 1]); + if (!prev_if) + break; + if (!StructuralEqual()(if_node->condition, prev_if->condition)) + break; + --i; + } + int group_start = i; + --i; + std::vector then_children, else_children; + for (int j = group_start; j <= group_end; ++j) { + const IfNode *cur_if = TryGetIfNode(nodes[j]); + ICHECK(cur_if); + if (cur_if->then_child) + then_children.push_back(cur_if->then_child.get()); + if (cur_if->else_child) + else_children.push_back(cur_if->else_child.get()); + } + bool then_must = false, else_must = false; + if (!then_children.empty()) + then_must = SequenceCollectLastAccessTasks( + then_children, buffer, is_write, wg_id, phase, result); + if (!else_children.empty()) + else_must = SequenceCollectLastAccessTasks( + else_children, buffer, is_write, wg_id, phase, result); + if (then_must && else_must) + return true; + const IfNode *last_if = TryGetIfNode(nodes[group_end]); + ICHECK(last_if); + if (last_if->task) { + if (last_if->task->CollectLastAccessTasks(buffer, is_write, wg_id, + phase, result)) + return true; + } + } else { + if (nodes[i]->CollectLastAccessTasks(buffer, is_write, wg_id, phase, + result)) + return true; + --i; + } + } + return false; +} + +static bool TaskMatchesAccess(const TaskNode *task, const Buffer &buffer, + bool is_write, int wg_id, SchedulePhase phase) { + if (task->GetSchedulePhase() != phase) + return false; + int task_wg = task->GetWarpgroupId(); + const auto ®ions = + is_write ? task->GetWriteRegions() : task->GetReadRegions(); + for (const auto ®ion : regions) { + if (region->buffer == buffer && + (task_wg == wg_id || IsWarpgroupBroadcast(task_wg))) + return true; + } + return false; +} + +bool TaskNode::CollectFirstAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + if (TaskMatchesAccess(this, buffer, is_write, wg_id, phase)) { + result.insert(this); + return true; + } + return false; +} + +bool TaskNode::CollectLastAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + if (TaskMatchesAccess(this, buffer, is_write, wg_id, phase)) { + result.insert(this); + return true; + } + return false; +} + +bool ScheduleUnit::CollectFirstAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + if (child) + return child->CollectFirstAccessTasks(buffer, is_write, wg_id, phase, + result); + return false; +} + +bool ScheduleUnit::CollectLastAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + if (child) + return child->CollectLastAccessTasks(buffer, is_write, wg_id, phase, + result); + return false; +} + +bool WrapperNode::CollectFirstAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + if (task) { + if (task->CollectFirstAccessTasks(buffer, is_write, wg_id, phase, result)) + return true; + } + if (child) { + if (child->CollectFirstAccessTasks(buffer, is_write, wg_id, phase, result)) + return true; + } + return false; +} + +bool WrapperNode::CollectLastAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + if (child) { + if (child->CollectLastAccessTasks(buffer, is_write, wg_id, phase, result)) + return true; + } + if (task) { + if (task->CollectLastAccessTasks(buffer, is_write, wg_id, phase, result)) + return true; + } + return false; +} + +bool IfNode::CollectFirstAccessTasks(const Buffer &buffer, bool is_write, + int wg_id, SchedulePhase phase, + std::set &result) const { + if (task) { + if (task->CollectFirstAccessTasks(buffer, is_write, wg_id, phase, result)) + return true; + } + bool then_must = false, else_must = false; + if (then_child) + then_must = then_child->CollectFirstAccessTasks(buffer, is_write, wg_id, + phase, result); + if (else_child) + else_must = else_child->CollectFirstAccessTasks(buffer, is_write, wg_id, + phase, result); + return then_must && else_must; +} + +bool IfNode::CollectLastAccessTasks(const Buffer &buffer, bool is_write, + int wg_id, SchedulePhase phase, + std::set &result) const { + bool then_must = false, else_must = false; + if (then_child) + then_must = then_child->CollectLastAccessTasks(buffer, is_write, wg_id, + phase, result); + if (else_child) + else_must = else_child->CollectLastAccessTasks(buffer, is_write, wg_id, + phase, result); + if (then_must && else_must) + return true; + if (task) { + if (task->CollectLastAccessTasks(buffer, is_write, wg_id, phase, result)) + return true; + } + return false; +} + +static bool LoopMustExecute(const ControlNode *ctrl) { + if (!ctrl->control.defined()) + return false; + const ForNode *for_node = ctrl->control.get(); + const int64_t *extent_ptr = as_const_int(for_node->extent); + if (extent_ptr && *extent_ptr >= 1) { + if (for_node->step.has_value()) { + const int64_t *step_ptr = as_const_int(for_node->step.value()); + if (step_ptr && *step_ptr >= 1) + return true; + if (!step_ptr) + return false; + } else { + return true; + } + } + return false; +} + +bool ControlNode::CollectFirstAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + if (task) { + if (task->CollectFirstAccessTasks(buffer, is_write, wg_id, phase, result)) + return true; + } + bool body_must = false; + if (child && child->IsSequence()) { + auto *seq = static_cast(child.get()); + std::vector ordered; + ordered.reserve(seq->children.size()); + for (const auto &c : seq->children) + ordered.push_back(c.get()); + std::stable_sort(ordered.begin(), ordered.end(), + [](const IRStructure *a, const IRStructure *b) { + auto *ua = dynamic_cast(a); + auto *ub = dynamic_cast(b); + if (ua && ub) + return ua->stage > ub->stage; + return false; + }); + body_must = SequenceCollectFirstAccessTasks(ordered, buffer, is_write, + wg_id, phase, result); + } else if (child) { + body_must = + child->CollectFirstAccessTasks(buffer, is_write, wg_id, phase, result); + } + if (body_must && LoopMustExecute(this)) + return true; + return false; +} + +bool ControlNode::CollectLastAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + bool body_must = false; + if (child && child->IsSequence()) { + auto *seq = static_cast(child.get()); + std::vector ordered; + ordered.reserve(seq->children.size()); + for (const auto &c : seq->children) + ordered.push_back(c.get()); + std::stable_sort(ordered.begin(), ordered.end(), + [](const IRStructure *a, const IRStructure *b) { + auto *ua = dynamic_cast(a); + auto *ub = dynamic_cast(b); + if (ua && ub) + return ua->stage > ub->stage; + return false; + }); + body_must = SequenceCollectLastAccessTasks(ordered, buffer, is_write, wg_id, + phase, result); + } else if (child) { + body_must = + child->CollectLastAccessTasks(buffer, is_write, wg_id, phase, result); + } + if (body_must && LoopMustExecute(this)) + return true; + if (task) { + if (task->CollectLastAccessTasks(buffer, is_write, wg_id, phase, result)) + return true; + } + return false; +} + +bool SequenceNode::CollectFirstAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + std::vector ordered; + ordered.reserve(children.size()); + for (const auto &c : children) + ordered.push_back(c.get()); + return SequenceCollectFirstAccessTasks(ordered, buffer, is_write, wg_id, + phase, result); +} + +bool SequenceNode::CollectLastAccessTasks( + const Buffer &buffer, bool is_write, int wg_id, SchedulePhase phase, + std::set &result) const { + std::vector ordered; + ordered.reserve(children.size()); + for (const auto &c : children) + ordered.push_back(c.get()); + return SequenceCollectLastAccessTasks(ordered, buffer, is_write, wg_id, phase, + result); +} + } // namespace tl } // namespace tvm diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index 82c2f8093f..ec6206e447 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -159,6 +159,38 @@ class IRStructure { return std::vector(result.begin(), result.end()); } + // Collect tasks that could possibly be the first/last to access a specific + // (buffer, is_write, wg_id) within this IR subtree. + // The result set is populated with candidate tasks. + // Returns true if this subtree is guaranteed to contain at least one matching + // access (i.e., the access must happen unconditionally). + virtual bool + CollectFirstAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const = 0; + + virtual bool + CollectLastAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const = 0; + + // Convenience wrappers that return the result set directly. + std::set + GetFirstAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase = SchedulePhase::kBody) const { + std::set result; + CollectFirstAccessTasks(buffer, is_write, wg_id, phase, result); + return result; + } + + std::set + GetLastAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase = SchedulePhase::kBody) const { + std::set result; + CollectLastAccessTasks(buffer, is_write, wg_id, phase, result); + return result; + } + // Substitute a variable throughout this IR node virtual void SubstituteVar(const Var &old_var, const Var &new_var) = 0; @@ -368,6 +400,15 @@ class TaskNode : public IRStructure { CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, std::set &result) const override; + bool + CollectFirstAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + bool + CollectLastAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + bool containWarpgroupId(int id) const override { return ContainsLoopBreak() || IsWarpgroupBroadcast(warpgroup_id_) || warpgroup_id_ == id; @@ -525,6 +566,15 @@ class ControlNode : public IRStructure { CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, std::set &result) const override; + bool + CollectFirstAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + bool + CollectLastAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + bool hasPromote() const { return has_promote_; } void SetPromote(bool promote) { has_promote_ = promote; } @@ -663,6 +713,15 @@ class WrapperNode : public IRStructure { CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, std::set &result) const override; + bool + CollectFirstAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + bool + CollectLastAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + // Clone method std::shared_ptr Clone() const override; @@ -837,6 +896,15 @@ class IfNode : public IRStructure { CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, std::set &result) const override; + bool + CollectFirstAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + bool + CollectLastAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + // Clone method std::shared_ptr Clone() const override; @@ -945,6 +1013,15 @@ class ScheduleUnit : public IRStructure { CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, std::set &result) const override; + bool + CollectFirstAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + bool + CollectLastAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + int GetStage() const { return stage; } bool isInnerTask() const { return child->IsTask(); } int GetWarpgroupId() const override { @@ -1023,6 +1100,15 @@ class SequenceNode : public IRStructure { CollectBufferAccessInfo(int num_wgs, SchedulePhase phase, std::set &result) const override; + bool + CollectFirstAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + bool + CollectLastAccessTasks(const Buffer &buffer, bool is_write, int wg_id, + SchedulePhase phase, + std::set &result) const override; + // Clone method std::shared_ptr Clone() const override; From bdb20c2fd8391a3f446cf8eecad21668ff23b0de Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 23 Apr 2026 14:43:48 +0800 Subject: [PATCH 122/156] format --- .../auto_schedule/warpgroup_partition.cc | 5 ++- tilelang/engine/phase.py | 32 +++++++++++++------ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 918425bd94..0ce4b48ac3 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -572,9 +572,8 @@ class UnusedLetStmtStripper : public StmtExprMutator { Stmt new_body = this->VisitStmt(op->body); PrimExpr new_value = this->VisitExpr(op->value); - auto body_uses_var = UsesVar(new_body, [&](const VarNode *v) { - return v == op->var.get(); - }); + auto body_uses_var = + UsesVar(new_body, [&](const VarNode *v) { return v == op->var.get(); }); bool value_is_pure = SideEffect(new_value) <= CallEffectKind::kPure; if (!body_uses_var && value_is_pure) { diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 20860860e1..cb4e56bbe2 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -41,18 +41,22 @@ def module_uses_thread_var(mod: IRModule) -> bool: thread_extent_vars: set = set() explicit_thread_binding_loop: list[bool] = [False] - def _collect(node): + def _collect( + node, + _thread_extent_vars=thread_extent_vars, + _explicit_thread_binding_loop=explicit_thread_binding_loop, + ): if isinstance(node, tir.AttrStmt) and node.attr_key == "thread_extent": iter_var = node.node if isinstance(iter_var, tir.IterVar): tag = getattr(iter_var, "thread_tag", "") or "" if tag.startswith("threadIdx."): - thread_extent_vars.add(iter_var.var) + _thread_extent_vars.add(iter_var.var) elif isinstance(node, tir.For) and node.kind == tir.ForKind.THREAD_BINDING: tb = node.thread_binding tag = getattr(tb, "thread_tag", "") if tb is not None else "" if isinstance(tag, str) and tag.startswith("threadIdx."): - explicit_thread_binding_loop[0] = True + _explicit_thread_binding_loop[0] = True stmt_functor.post_order_visit(func.body, _collect) @@ -64,14 +68,22 @@ def _collect(node): uses_thread_var = [False] - def _find_use(node): - if uses_thread_var[0]: + def _find_use( + node, + _uses_thread_var=uses_thread_var, + _thread_extent_vars=thread_extent_vars, + ): + if _uses_thread_var[0]: return - if isinstance(node, tir.Var) and node in thread_extent_vars: - uses_thread_var[0] = True - - def _walk(stmt): - if uses_thread_var[0]: + if isinstance(node, tir.Var) and node in _thread_extent_vars: + _uses_thread_var[0] = True + + def _walk( + stmt, + _uses_thread_var=uses_thread_var, + _find_use=_find_use, + ): + if _uses_thread_var[0]: return if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "thread_extent": _walk(stmt.body) From ff067b06512496ae67f7ff3627041c767fbeddfc Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 23 Apr 2026 14:54:36 +0800 Subject: [PATCH 123/156] fix attr warp partition --- .../auto_schedule/warpgroup_partition.cc | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 0ce4b48ac3..2cbcc36cba 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -1185,10 +1185,39 @@ Stmt ApplyWarpgroupPartitionToIRStructure( if (all_empty) continue; + auto PeelLetsToInner = [](const Stmt &s) -> Stmt { + const Stmt *cur = &s; + while (const auto *let = cur->as()) { + cur = &let->body; + } + return *cur; + }; + bool is_shared_attr_segment = true; + Stmt shared_attr_stmt; + for (size_t i = 0; i < num_wgs; ++i) { + if (IsEvaluateZero(wg_stmts[i])) { + continue; + } + Stmt inner = PeelLetsToInner(wg_stmts[i]); + const auto *attr = inner.as(); + if (!attr || !IsEvaluateZero(attr->body)) { + is_shared_attr_segment = false; + break; + } + if (!shared_attr_stmt.defined()) { + shared_attr_stmt = wg_stmts[i]; + } + } + // Insert liveness boundary before each non-empty non-LetDecl child segmented_stmts.push_back(AttrStmt( Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, Evaluate(0))); + if (is_shared_attr_segment && shared_attr_stmt.defined()) { + segmented_stmts.push_back(shared_attr_stmt); + continue; + } + // Prepend set_max_nreg only to the first non-LetDecl child if (first_non_let && !has_simt_copy && !has_inner_nreg_decision && num_wgs == 2 && config.enable_set_max_nreg) { From b88bbcb5fa427b9964aba27d69ece29ba998713a Mon Sep 17 00:00:00 2001 From: Zhang Jason Date: Thu, 23 Apr 2026 14:56:13 +0800 Subject: [PATCH 124/156] [AMD][gfx950] Add ds_read_tr16_b64 / ds_read_tr8_b64 support for gfx950 LDS transpose reads (#2085) * add gfx950 ds_read_tr support * update with format checking --- src/op/builtin.cc | 14 ++ src/op/builtin.h | 20 +++ src/target/codegen_hip.cc | 8 ++ src/tl_templates/hip/ldsm.h | 43 ++++++ .../python/amd/test_tilelang_ds_read_tr.py | 125 ++++++++++++++++++ tilelang/language/__init__.py | 2 + tilelang/language/builtin.py | 42 ++++++ 7 files changed, 254 insertions(+) create mode 100644 testing/python/amd/test_tilelang_ds_read_tr.py diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 7de31d5205..0b1e20d4be 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -627,6 +627,20 @@ TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +// ds_read_tr16_b64(smem_ptr) -> uint32x2 +// gfx950 LDS transpose read: 64-bit, 16-element transpose (FP16/BF16 MFMA) +TIR_DEFINE_TL_BUILTIN(ds_read_tr16_b64) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +// ds_read_tr8_b64(smem_ptr) -> uint32x2 +// gfx950 LDS transpose read: 64-bit, 8-element transpose (FP32 MFMA) +TIR_DEFINE_TL_BUILTIN(ds_read_tr8_b64) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + // __ldg(BufferLoad | Buffer, idx?) -> value // Treat as a pure call that returns the loaded value. TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr( diff --git a/src/op/builtin.h b/src/op/builtin.h index 43cf562963..68c41843c3 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -824,6 +824,26 @@ TVM_DLL const Op &match_all_sync(); */ TVM_DLL const Op &loop_break(); +/*! + * \brief tilelang intrinsic for gfx950 LDS transpose read, 64-bit, 16-element. + * + * Reads 8 bytes from LDS with a 16-element transpose (FP16/BF16 MFMA B-load). + * Only available on gfx950 (MI350/MI355X). + * + * uint32x2 ds_read_tr16_b64(smem_access_ptr) + */ +TVM_DLL const Op &ds_read_tr16_b64(); + +/*! + * \brief tilelang intrinsic for gfx950 LDS transpose read, 64-bit, 8-element. + * + * Reads 8 bytes from LDS with an 8-element transpose (FP32 MFMA B-load). + * Only available on gfx950 (MI350/MI355X). + * + * uint32x2 ds_read_tr8_b64(smem_access_ptr) + */ +TVM_DLL const Op &ds_read_tr8_b64(); + /*! * \brief tvm intrinsic for amd matrix core mfma instructions. * diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index dd3dd0aeac..fafb3475ca 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -1023,6 +1023,14 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { os << ", " << PrintExpr(op->args[i]); } os << ")"; + } else if (op->op.same_as(tl::ds_read_tr16_b64())) { + ICHECK_EQ(op->args.size(), 1U) + << "tl.ds_read_tr16_b64 expects one argument (smem_access_ptr)."; + os << "tl::ds_read_tr16_b64(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::ds_read_tr8_b64())) { + ICHECK_EQ(op->args.size(), 1U) + << "tl.ds_read_tr8_b64 expects one argument (smem_access_ptr)."; + os << "tl::ds_read_tr8_b64(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::__ldg())) { // HIP fallback: regular load const BufferLoadNode *bl = op->args[0].as(); diff --git a/src/tl_templates/hip/ldsm.h b/src/tl_templates/hip/ldsm.h index 286b773242..1fbd797f22 100644 --- a/src/tl_templates/hip/ldsm.h +++ b/src/tl_templates/hip/ldsm.h @@ -1,3 +1,46 @@ #pragma once #include "common.h" + +namespace tl { + +#if defined(__gfx950__) + +// ds_read_tr16_b64: LDS transpose read, 64-bit, 16-element transpose. +// Reads 8 bytes from LDS with a transpose across 16 elements. +// Used for FP16/BF16 MFMA matrix B loads on gfx950 (MI350/MI355X). +// smem_ptr must point into __shared__ memory. +// +// Uses __builtin_amdgcn_ds_read_tr16_b64_v4f16 (LLVM builtin) instead of +// inline assembly because ROCm <= 7.2 assembler does not yet recognise the +// ds_read_tr16_b64 mnemonic even though the hardware supports it. +CK_TILE_DEVICE uint2 ds_read_tr16_b64(const void *smem_ptr) { + typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 fp16x4_t; + // C-style cast: void* → LDS fp16x4_t* (required by the LLVM builtin + // signature) + fp16x4_t v = __builtin_amdgcn_ds_read_tr16_b64_v4f16( + (__attribute__((address_space(3))) fp16x4_t *)(smem_ptr)); + uint2 result; + __builtin_memcpy(&result, &v, sizeof(result)); + return result; +} + +// ds_read_tr8_b64: LDS transpose read, 64-bit, 8-element transpose. +// Reads 8 bytes from LDS with a transpose across 8 elements. +// Used for FP32 MFMA matrix B loads on gfx950 (MI350/MI355X). +// smem_ptr must point into __shared__ memory. +// +// Uses __builtin_amdgcn_ds_read_tr8_b64_v2i32 (LLVM builtin) for the same +// reason as ds_read_tr16_b64 above. +CK_TILE_DEVICE uint2 ds_read_tr8_b64(const void *smem_ptr) { + typedef __attribute__((__vector_size__(2 * sizeof(int)))) int i32x2_t; + i32x2_t v = __builtin_amdgcn_ds_read_tr8_b64_v2i32( + (__attribute__((address_space(3))) i32x2_t *)(smem_ptr)); + uint2 result; + __builtin_memcpy(&result, &v, sizeof(result)); + return result; +} + +#endif // __gfx950__ + +} // namespace tl diff --git a/testing/python/amd/test_tilelang_ds_read_tr.py b/testing/python/amd/test_tilelang_ds_read_tr.py new file mode 100644 index 0000000000..65d465707b --- /dev/null +++ b/testing/python/amd/test_tilelang_ds_read_tr.py @@ -0,0 +1,125 @@ +"""Tests for ds_read_tr16_b64 and ds_read_tr8_b64 intrinsics on gfx950. + +Covers: + - Codegen: generated HIP source contains the correct tl:: call. + - Runtime: kernel compiles and executes on gfx950 without errors. + +ds_read_tr16_b64 – LDS transpose read, 64-bit, 16-element transpose. + Used for FP16/BF16 MFMA B-loads on MI350/MI355X (gfx950). +ds_read_tr8_b64 – LDS transpose read, 64-bit, 8-element transpose. + Used for FP32 MFMA B-loads on MI350/MI355X (gfx950). +""" + +import pytest +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +from tilelang.utils.target import target_is_gfx950, determine_target + + +def requires_gfx950(): + """Skip the test when the current ROCm target is not gfx950.""" + target = determine_target("auto", return_object=True) + if not target_is_gfx950(target): + pytest.skip("gfx950 (MI350/MI355X) not detected") + + +# --------------------------------------------------------------------------- +# Kernels +# --------------------------------------------------------------------------- + + +# ds_read_tr16_b64: each thread reads 2 fp16 elements from LDS with a +# 16-element transpose and stores the result (as float32x2) into a staging +# shared buffer, which is then copied to global memory. +@tilelang.jit(target="hip") +def _kernel_tr16(X, Out): + NV = T.const("NV") + X: T.Tensor[[NV], T.float16] + Out: T.Tensor[[NV // 2], T.float32] + + with T.Kernel(1, threads=NV // 2) as _: + smem = T.alloc_shared([NV], T.float16) + smem2 = T.alloc_shared([NV // 2], T.float32) + T.copy(X[:NV], smem[:NV]) + T.sync_threads() + for i in T.Parallel(NV // 2): + val = T.reinterpret(T.ds_read_tr16_b64(smem[i * 2]), T.float32x2) + smem2[i * 2 : i * 2 + 2] = val + T.sync_threads() + T.copy(smem2[: NV // 2], Out[: NV // 2]) + + +# ds_read_tr8_b64: same pattern but reads float32 elements. +@tilelang.jit(target="hip") +def _kernel_tr8(X, Out): + NV = T.const("NV") + X: T.Tensor[[NV], T.float32] + Out: T.Tensor[[NV // 2], T.float32] + + with T.Kernel(1, threads=NV // 2) as _: + smem = T.alloc_shared([NV], T.float32) + smem2 = T.alloc_shared([NV // 2], T.float32) + T.copy(X[:NV], smem[:NV]) + T.sync_threads() + for i in T.Parallel(NV // 2): + val = T.reinterpret(T.ds_read_tr8_b64(smem[i * 2]), T.float32x2) + smem2[i * 2 : i * 2 + 2] = val + T.sync_threads() + T.copy(smem2[: NV // 2], Out[: NV // 2]) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +N = 128 # number of fp16 elements in shared memory + + +@tilelang.testing.requires_rocm +def test_ds_read_tr16_b64_codegen(): + """Generated HIP source must contain tl::ds_read_tr16_b64(...).""" + requires_gfx950() + + src = _kernel_tr16.get_kernel_source(NV=N) + print("=== ds_read_tr16_b64 codegen ===") + print(src) + assert "ds_read_tr16_b64" in src, "Expected tl::ds_read_tr16_b64 call in generated HIP source" + + +@tilelang.testing.requires_rocm +def test_ds_read_tr8_b64_codegen(): + """Generated HIP source must contain tl::ds_read_tr8_b64(...).""" + requires_gfx950() + + src = _kernel_tr8.get_kernel_source(NV=N) + print("=== ds_read_tr8_b64 codegen ===") + print(src) + assert "ds_read_tr8_b64" in src, "Expected tl::ds_read_tr8_b64 call in generated HIP source" + + +@tilelang.testing.requires_rocm +def test_ds_read_tr16_b64_runtime(): + """ds_read_tr16_b64 kernel must execute without error on gfx950.""" + requires_gfx950() + + X = torch.randn(N, dtype=torch.float16, device="cuda") + Out = torch.empty(N // 2, dtype=torch.float32, device="cuda") + _kernel_tr16(X, Out) + torch.cuda.synchronize() + + +@tilelang.testing.requires_rocm +def test_ds_read_tr8_b64_runtime(): + """ds_read_tr8_b64 kernel must execute without error on gfx950.""" + requires_gfx950() + + X = torch.randn(N, dtype=torch.float32, device="cuda") + Out = torch.empty(N // 2, dtype=torch.float32, device="cuda") + _kernel_tr8(X, Out) + torch.cuda.synchronize() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 785b91c489..7dddc7a748 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -96,6 +96,8 @@ from .logical import any_of, all_of # noqa: F401 from .builtin import * # noqa: F401 from .builtin import __ldg as __ldg # noqa: F401 +from .builtin import ds_read_tr16_b64 as ds_read_tr16_b64 # noqa: F401 +from .builtin import ds_read_tr8_b64 as ds_read_tr8_b64 # noqa: F401 from .builtin import ldg32 as ldg32 # noqa: F401 from .builtin import ldg64 as ldg64 # noqa: F401 from .builtin import ldg128 as ldg128 # noqa: F401 diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 525e7c4320..da6bde0b37 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -1332,6 +1332,48 @@ def ptx_mma_sm70( ) +def ds_read_tr16_b64(src: BufferLikeType) -> PrimExpr: + """LDS transpose read, 64-bit, 16-element transpose (gfx950 only). + + Reads 8 bytes from LDS (__shared__ memory) with a 16-element transpose. + Used for FP16/BF16 MFMA matrix B-loads on MI350/MI355X (gfx950). + + Args: + src: A `Buffer`, `BufferRegion`, or `BufferLoad` in shared memory. + + Returns: + PrimExpr: The loaded 64-bit value as uint32x2. + + Example: + >>> val = T.ds_read_tr16_b64(smem[i]) + """ + if not isinstance(src, BufferLikeTypeTuple): + raise TypeError(f"T.ds_read_tr16_b64 expects Buffer, BufferRegion, or BufferLoad. Got {type(src)}: {src}") + ptr = retrieve_ptr(src, access_type="r") + return tir.call_intrin("uint32x2", tir.op.Op.get("tl.ds_read_tr16_b64"), ptr) + + +def ds_read_tr8_b64(src: BufferLikeType) -> PrimExpr: + """LDS transpose read, 64-bit, 8-element transpose (gfx950 only). + + Reads 8 bytes from LDS (__shared__ memory) with an 8-element transpose. + Used for FP32 MFMA matrix B-loads on MI350/MI355X (gfx950). + + Args: + src: A `Buffer`, `BufferRegion`, or `BufferLoad` in shared memory. + + Returns: + PrimExpr: The loaded 64-bit value as uint32x2. + + Example: + >>> val = T.ds_read_tr8_b64(smem[i]) + """ + if not isinstance(src, BufferLikeTypeTuple): + raise TypeError(f"T.ds_read_tr8_b64 expects Buffer, BufferRegion, or BufferLoad. Got {type(src)}: {src}") + ptr = retrieve_ptr(src, access_type="r") + return tir.call_intrin("uint32x2", tir.op.Op.get("tl.ds_read_tr8_b64"), ptr) + + def ldg32(src: BufferLikeType, pred: PrimExpr = None) -> PrimExpr: """Load 32 bits (4 bytes) from global memory using explicit PTX instructions. From 95f1d29f6f922e97bd385f3b9e7a1fa80773d209 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 23 Apr 2026 16:09:41 +0800 Subject: [PATCH 125/156] fix let missing bug --- src/transform/auto_schedule.cc | 6 ++---- src/transform/auto_schedule/schedule_builder.cc | 14 ++++++++++++-- src/transform/auto_schedule/schedule_builder.h | 14 ++++++++++++-- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index ee879b6c99..78c2b13866 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -766,15 +766,11 @@ ScheduleSingleKernel(const Stmt &kernel_body, IterVar thread_var, Target target, thread_count, loop_info, result.buffer_infos, neutral_sync_shared_barrier); - // Print the modified summary view - // PrintIRStructure(ir_structure.get()); - // Apply warpgroup partition to entire IRStructure result.scheduled_body = ApplyWarpgroupPartitionToIRStructure( ir_structure.get(), thread_var, result.barrier_buffers, result.barrier_map, enable_epi, thread_count, config, neutral_sync_shared_barrier, result.duplicated_fragment_buffers); - result.scheduled_body = StripUnusedLetStmts(result.scheduled_body); result.did_warpgroup_partition = true; return result; } @@ -896,6 +892,7 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { } final_body = ReNestLetStmts(final_body); + final_body = StripUnusedLetStmts(final_body); // Create a new PrimFunc with the updated body auto new_func = PrimFunc(func->params, final_body, func->ret_type, @@ -1001,6 +998,7 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { Stmt final_body = seq_replacer(func->body); final_body = ReNestLetStmts(final_body); + final_body = StripUnusedLetStmts(final_body); // Create a new PrimFunc with the updated body auto new_func = PrimFunc(func->params, final_body, func->ret_type, diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index fbaaf9574e..9f118c8e00 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -1116,6 +1116,16 @@ void ScheduleUnitBuilder::NaiveScheduleLoop(ControlNode *ctrl) { task->stmts[0].as() != nullptr; }; auto SolveConflictVar = [&]() -> bool { + auto HasVarRawDep = [](const IRStructure *producer, + const IRStructure *consumer) -> bool { + for (const auto &w : producer->GetWriteVars()) { + for (const auto &r : consumer->GetReadVars()) { + if (SameVar(w, r)) + return true; + } + } + return false; + }; for (int i = 0; i < n; ++i) { if (!IsVarDecl(seq_body->children[i].get())) continue; @@ -1124,7 +1134,7 @@ void ScheduleUnitBuilder::NaiveScheduleLoop(ControlNode *ctrl) { continue; auto node_i = seq_body->children[i].get(); auto node_j = seq_body->children[j].get(); - if (!HasDependency(node_i, node_j)) + if (!HasVarRawDep(node_i, node_j)) continue; if (stage_map[node_j] == stage_map[node_i]) continue; @@ -1161,7 +1171,7 @@ void ScheduleUnitBuilder::NaiveScheduleLoop(ControlNode *ctrl) { auto node_k = seq_body->children[k].get(); if (rem_stage_j != stage_map[node_k]) continue; - if (HasDependency(node_i, node_k)) { + if (HasVarRawDep(node_i, node_k)) { node_k->SubstituteVar(node_i_let_stmt->var, cloned_let_stmt->var); stage_map[node_k] = rem_stage_j; } diff --git a/src/transform/auto_schedule/schedule_builder.h b/src/transform/auto_schedule/schedule_builder.h index 4d95344d82..f095c2cddc 100644 --- a/src/transform/auto_schedule/schedule_builder.h +++ b/src/transform/auto_schedule/schedule_builder.h @@ -552,6 +552,16 @@ class ScheduleUnitBuilder { return false; }; auto SolveConflictVar = [&]() -> bool { + auto HasVarRawDep = [](const IRStructure *producer, + const IRStructure *consumer) -> bool { + for (const auto &w : producer->GetWriteVars()) { + for (const auto &r : consumer->GetReadVars()) { + if (SameVar(w, r)) + return true; + } + } + return false; + }; for (int i = 0; i < n; ++i) if (IsVarDecl(seq_body->children[i].get())) { for (int j = 0; j < n; ++j) { @@ -562,7 +572,7 @@ class ScheduleUnitBuilder { auto node_j = seq_body->children[j].get(); int rem_stage_j = stage_map[node_j]; - if (!HasDependency(node_i, node_j)) + if (!HasVarRawDep(node_i, node_j)) continue; if (stage_map[node_j] == stage_map[node_i]) @@ -604,7 +614,7 @@ class ScheduleUnitBuilder { auto node_k = seq_body->children[k].get(); if (rem_stage_j != stage_map[node_k]) continue; - if (HasDependency(node_i, node_k)) { + if (HasVarRawDep(node_i, node_k)) { node_k->SubstituteVar(node_i_let_stmt->var, cloned_let_stmt->var); stage_map[node_k] = rem_stage_j; From 6d0bffb9e72dde2b7ed2a79aef54b5fe09270174 Mon Sep 17 00:00:00 2001 From: Zhang Jason Date: Thu, 23 Apr 2026 17:51:54 +0800 Subject: [PATCH 126/156] [AMD][Gfx950] Add the support of 160K LDS and copy.async (#2058) * add gfx950 new feature, 160K LDS and compy async * update accoding to format check result * fix a typo error * add gfx950 copy async test codes --------- Co-authored-by: LeiWang1999 --- src/tl_templates/hip/copy.h | 38 +++ .../amd/test_tilelang_gfx950_copy_async.py | 230 ++++++++++++++++++ tilelang/carver/arch/cdna.py | 16 +- 3 files changed, 283 insertions(+), 1 deletion(-) create mode 100644 testing/python/amd/test_tilelang_gfx950_copy_async.py diff --git a/src/tl_templates/hip/copy.h b/src/tl_templates/hip/copy.h index 3f122d801f..c691017148 100644 --- a/src/tl_templates/hip/copy.h +++ b/src/tl_templates/hip/copy.h @@ -72,10 +72,35 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc, : "memory"); } +// gfx950 (CDNA4 / MI350): 128-bit direct-to-LDS async load. +// buffer_load_dwordx4 ... lds bypasses VGPRs entirely, giving 4x the +// bandwidth of the 32-bit path and overlapping with MFMA computation. +#if defined(__gfx950__) +CK_TILE_DEVICE void async_buffer_load_dwordx4_v(void *smem, int32x4_t rsrc, + index_t voffset) { + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(smem))); + asm volatile( + "s_mov_b32 m0, %0; \n\t" + "buffer_load_dwordx4 %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(voffset), "s"(rsrc) + : "memory"); +} +#endif // __gfx950__ + template TL_DEVICE void cp_async_gs(void *lds_base_ptr, void const *global_base_ptr) { if constexpr (N == 16) { +#if defined(__gfx950__) + // gfx950: use 128-bit direct-to-LDS async copy (buffer_load_dwordx4 lds) + async_buffer_load_dwordx4_v( + lds_base_ptr, + make_wave_buffer_resource(((const int32_t *)global_base_ptr) - + threadIdx.x), + threadIdx.x * N /*16 bytes*/); +#else *(uint4 *)lds_base_ptr = *(const uint4 *)global_base_ptr; +#endif } else if constexpr (N == 8) { *(uint2 *)lds_base_ptr = *(const uint2 *)global_base_ptr; } else if constexpr (N == 4) { @@ -91,8 +116,21 @@ template TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr, void const *global_base_ptr, bool cond) { if constexpr (N == 16) { +#if defined(__gfx950__) + // gfx950: use 128-bit direct-to-LDS async copy (buffer_load_dwordx4 lds) + if (cond) { + async_buffer_load_dwordx4_v( + lds_base_ptr, + make_wave_buffer_resource(((const int32_t *)global_base_ptr) - + threadIdx.x), + threadIdx.x * N /*16 bytes*/); + } else { + *(uint4 *)lds_base_ptr = make_uint4(0, 0, 0, 0); + } +#else *(uint4 *)lds_base_ptr = cond ? *(const uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0); +#endif } else if constexpr (N == 8) { *(uint2 *)lds_base_ptr = cond ? *(const uint2 *)global_base_ptr : make_uint2(0, 0); diff --git a/testing/python/amd/test_tilelang_gfx950_copy_async.py b/testing/python/amd/test_tilelang_gfx950_copy_async.py new file mode 100644 index 0000000000..78712ef11c --- /dev/null +++ b/testing/python/amd/test_tilelang_gfx950_copy_async.py @@ -0,0 +1,230 @@ +"""Tests for gfx950 (MI350) copy.async feature. + +Two new behaviours introduced in commit dfa63b10: + 1. cp_async_gs<16> on gfx950 lowers to buffer_load_dwordx4 ... lds + (128-bit direct-to-LDS, bypassing VGPRs) instead of a plain uint4 + scalar store. coalesced_width=8 (8 fp16 = 16 bytes) is required to + trigger the 16-byte path. + 2. CDNA arch helper reports smem_cap = 160 KB for gfx950, even when + an older driver reports the conservative 64 KB default. +""" + +import pytest +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _is_gfx950() -> bool: + try: + from tilelang import tvm + + mcpu = str(tvm.target.Target("rocm").attrs.get("mcpu", "")) + return "gfx950" in mcpu + except Exception: + return False + + +def _matmul_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads=128, + k_pack=1, + # coalesced_width=8 → cp_async_gs<16> (16 bytes, 8×fp16) + # coalesced_width=4 → cp_async_gs<8> (8 bytes, 4×fp16) + coalesced_width=4, +): + """Return a prim_func for pipelined GEMM using T.copy (global->shared).""" + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared, coalesced_width=coalesced_width) + else: + T.copy(A[by * block_M, k * block_K], A_shared, coalesced_width=coalesced_width) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=coalesced_width) + else: + T.copy(B[k * block_K, bx * block_N], B_shared, coalesced_width=coalesced_width) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B, k_pack=k_pack) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +# --------------------------------------------------------------------------- +# Test 1: codegen — cp_async_gs<16> is present in generated HIP source +# --------------------------------------------------------------------------- + + +@tilelang.testing.requires_rocm +def test_gfx950_cp_async_gs_16_in_codegen(): + """coalesced_width=8 (16 bytes) must emit cp_async_gs<16> in generated HIP source.""" + prog = _matmul_kernel( + 256, + 256, + 256, + 128, + 128, + 32, + False, + True, + T.float16, + T.float32, + T.float32, + num_stages=2, + coalesced_width=8, # 8 fp16 = 16 bytes → cp_async_gs<16> + ) + kernel = tl.compile(prog, out_idx=[2]) + src = kernel.get_kernel_source() + assert "cp_async_gs<16>" in src, "Expected cp_async_gs<16> in generated HIP source for 128-bit async copy path" + + +# --------------------------------------------------------------------------- +# Test 2: LDS capacity reported as 160 KB on gfx950 +# --------------------------------------------------------------------------- + + +@tilelang.testing.requires_rocm +def test_gfx950_smem_cap_160kb(): + """CDNA arch helper must report 160 KB LDS for gfx950.""" + from tilelang import tvm + from tilelang.carver.arch.cdna import CDNA, _GFX950_LDS_SIZE + + target = tvm.target.Target("rocm") + arch = CDNA(target) + + if _is_gfx950(): + assert arch.smem_cap == _GFX950_LDS_SIZE, f"Expected smem_cap={_GFX950_LDS_SIZE} for gfx950, got {arch.smem_cap}" + else: + # On non-gfx950 devices the override must NOT kick in + from tilelang import tvm as _tvm + + dev = _tvm.device("rocm", 0) + assert arch.smem_cap == dev.max_shared_memory_per_block + + +# --------------------------------------------------------------------------- +# Test 3: numerical correctness — pipelined copy.async GEMM (num_stages=2) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "trans_A, trans_B, k_pack", + [ + (False, False, 1), + (False, True, 1), + (True, True, 1), + (True, False, 1), + ], +) +@tilelang.testing.requires_rocm +def test_gfx950_copy_async_gemm_pipelined(trans_A, trans_B, k_pack): + """Pipelined GEMM (num_stages=2) with gfx950 copy.async must be numerically correct.""" + prog = _matmul_kernel( + 512, + 512, + 512, + 128, + 128, + 32, + trans_A, + trans_B, + T.float16, + T.float32, + T.float32, + num_stages=2, + threads=128, + k_pack=k_pack, + coalesced_width=4 * k_pack, + ) + kernel = tl.compile(prog, out_idx=[2]) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + a = A.T.float() if trans_A else A.float() + b = B.T.float() if trans_B else B.float() + return torch.matmul(a, b) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +# --------------------------------------------------------------------------- +# Test 4: non-pipelined baseline still correct (num_stages=0) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "trans_A, trans_B", + [ + (False, False), + (False, True), + (True, True), + (True, False), + ], +) +@tilelang.testing.requires_rocm +def test_gfx950_copy_async_gemm_no_pipeline(trans_A, trans_B): + """Non-pipelined GEMM (num_stages=0) must also produce correct results.""" + prog = _matmul_kernel( + 512, + 512, + 512, + 128, + 128, + 32, + trans_A, + trans_B, + T.float16, + T.float32, + T.float32, + num_stages=0, + threads=128, + coalesced_width=4, + ) + kernel = tl.compile(prog, out_idx=[2]) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + a = A.T.float() if trans_A else A.float() + b = B.T.float() if trans_B else B.float() + return torch.matmul(a, b) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/carver/arch/cdna.py b/tilelang/carver/arch/cdna.py index 5c2d4c4ed6..f27f21ca8f 100644 --- a/tilelang/carver/arch/cdna.py +++ b/tilelang/carver/arch/cdna.py @@ -3,6 +3,10 @@ from tvm.target import Target from .arch_base import TileDevice +# LDS size per CU for specific AMD GPU architectures (in bytes). +# gfx950 (CDNA4 / MI350): 160 KB — larger than the 64 KB default for gfx942. +_GFX950_LDS_SIZE = 160 * 1024 # 163840 bytes + def is_cdna_arch(arch: TileDevice) -> bool: return isinstance(arch, CDNA) @@ -18,7 +22,17 @@ def __init__(self, target: Target | str): raise RuntimeError("Cannot find HIP device 0.") self.device: tvm.runtime.Device = device self.platform: str = "CDNA" - self.smem_cap = device.max_shared_memory_per_block + + # TVM runtime should correctly report 160 KB (163840 B) for gfx950; the + # override is kept as a safety net in case an older driver reports the + # conservative 64 KB default. + mcpu = str(target.attrs.get("mcpu", "")) + reported = device.max_shared_memory_per_block + if "gfx950" in mcpu and reported < _GFX950_LDS_SIZE: + self.smem_cap = _GFX950_LDS_SIZE + else: + self.smem_cap = reported + self.compute_max_core = device.multi_processor_count self.warp_size = device.warp_size self.compute_capability = device.compute_version.replace(".", "") From 10b7f1f461d687969887382a6b65191bf13c5dc1 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 23 Apr 2026 18:25:12 +0800 Subject: [PATCH 127/156] add double-thread constraint Co-authored-by: Copilot --- src/transform/auto_schedule/schedule_builder.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index fbaaf9574e..d65607d971 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -748,7 +748,11 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, int64_t max_latency = std::max(warpgroup0_latency, warpgroup1_latency); int64_t min_latency = std::min(warpgroup0_latency, warpgroup1_latency); - if ((double)max_latency < 1.1 * min_latency) { + bool double_thread = (double)max_latency < 1.1 * min_latency; + if (auto thread_count_num = as_const_int(thread_count)) { + double_thread &= *thread_count_num <= 128; + } + if (double_thread) { int64_t warpgroup0_latency = 0; int64_t warpgroup1_latency = 0; From ef6a4313e5514e4e9ee006bf79be7fd47a0b4b15 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 23 Apr 2026 21:08:46 +0800 Subject: [PATCH 128/156] fix local var copy --- src/transform/auto_schedule/ir_structure.cc | 8 ++++---- src/transform/auto_schedule/warpgroup_partition.cc | 6 ++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index 1994d681ca..763ac19ce4 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -281,10 +281,10 @@ void TaskNode::CollectBufferAccessInfo( // Normal assigned warpgroup result.emplace(region->buffer, is_write, wg_id, this); } else if (IsWarpgroupBroadcast(wg_id)) { - // Broadcast: skip register memory (each wg has its own copy) - if (IsRegisterRegion(region)) { - return; - } + // Broadcast: skip register memory (each wg has its own copy), removed for barrier analysis. + // if (IsRegisterRegion(region)) { + // return; + // } // Shared/global memory is shared across wgs — emit for all for (int i = 0; i < num_wgs; ++i) { result.emplace(region->buffer, is_write, i, this); diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 2cbcc36cba..9a21f3f1f2 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -182,7 +182,8 @@ class BufferRemapMutator : public StmtExprMutator { Map var_to_new_buffer_; }; -// Collect all local.fragment Buffers +// Collect all local / local.var / local.fragment Buffers written by +// broadcast tasks so each warpgroup gets its own private copy. static void CollectBroadcastFragmentBuffersImpl( const IRStructure *node, std::unordered_set &seen, std::vector &out) { @@ -193,7 +194,8 @@ static void CollectBroadcastFragmentBuffersImpl( if (IsWarpgroupBroadcast(task->GetWarpgroupId())) { for (const auto ®ion : task->GetWriteRegions()) { if (IsRegisterRegion(region) && - region->buffer.scope() == "local.fragment") { + (IsFragmentBuffer(region->buffer) || + IsLocalBuffer(region->buffer, /*allow_var=*/true))) { if (!seen.count(region->buffer.get())) { seen.insert(region->buffer.get()); out.push_back(region->buffer); From 0f29f9c98672e2f615e86af482b5a9b0e34913a2 Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Fri, 24 Apr 2026 00:10:52 +0800 Subject: [PATCH 129/156] [BugFix] Relax loop wait and adjust trailing drain behavior in async pipeline tests (#2092) Updated the async pipeline logic to progressively relax loop waits and modify the trailing drain suffix. The test case was also adjusted to reflect changes in the expected behavior, ensuring that the pipeline maintains the correct number of groups in flight and descends through the drain suffix as intended. Co-authored-by: wutong.1109 --- src/transform/inject_pipeline.cc | 8 +++++++- ...st_tilelang_transform_Inject_software_pipeline.py | 12 ++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index facf769f87..47a9344886 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -1792,7 +1792,13 @@ class PipelineRewriter : public StmtExprMutator { for (size_t pos = 1; pos < suffix_wait_indices.size(); ++pos) { bool changed = false; int idx = suffix_wait_indices[pos]; - seq.Set(idx, RewriteFirstStaticWaitInWrapper(seq[idx], retain, &changed)); + // Tail consumers drain the final committed groups with no new commits in + // between. Relax them progressively from the end so the suffix becomes + // ..., wait<2>, wait<1>, wait<0> instead of rewriting every drain wait to + // the same retain count. + int new_wait_n = std::min(retain, static_cast(pos)); + seq.Set(idx, + RewriteFirstStaticWaitInWrapper(seq[idx], new_wait_n, &changed)); } return seq; } diff --git a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py index eb71215ccb..1ec14566f0 100644 --- a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py +++ b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py @@ -390,15 +390,15 @@ def before( assert annotated == 0 -def test_async_pipeline_relaxes_loop_wait_and_splits_trailing_drain(): +def test_async_pipeline_relaxes_loop_wait_and_descends_trailing_drain(): @T.prim_func - def before(A: T.Tensor((32,), T.uint8), B: T.Tensor((32,), T.uint8)): + def before(A: T.Tensor((40,), T.uint8), B: T.Tensor((40,), T.uint8)): S = T.alloc_buffer((4,), dtype=T.uint8, scope="shared") for i in T.serial( 0, - 4, + 5, annotations={ - "software_pipeline_stage": [0, 2], + "software_pipeline_stage": [0, 3], "software_pipeline_order": [0, 1], "software_pipeline_async_stages": [0], "software_pipeline_async_producers": [1, 0], @@ -424,8 +424,8 @@ def before(A: T.Tensor((32,), T.uint8), B: T.Tensor((32,), T.uint8)): loop_waits = _collect_wait_args(loop.body) all_waits = _collect_wait_args(func) - assert loop_waits == [2], f"Expected relaxed loop wait to keep two groups in flight, got {loop_waits}" - assert all_waits == [2, 2, 0], f"Expected trailing waits to split into retain+drain, got {all_waits}" + assert loop_waits == [3], f"Expected relaxed loop wait to keep three groups in flight, got {loop_waits}" + assert all_waits == [3, 2, 1, 0], f"Expected trailing waits to descend through the drain suffix, got {all_waits}" def test_degenerate_pipeline_with_single_stage_is_not_expanded(): From e99d35a53e16f02a10eb41d7070f895e50727c2b Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 24 Apr 2026 10:57:08 +0800 Subject: [PATCH 130/156] fix naive ir structure bug --- src/transform/auto_schedule/schedule_builder.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 170cd5766a..403841c72f 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -1270,6 +1270,10 @@ void ScheduleUnitBuilder::NaiveScheduleRecursive( NaiveScheduleLoop(ctrl); } else { NaiveScheduleRecursive(ctrl->child); + auto seq_node = std::make_shared(); + seq_node->children = {ctrl->child}; + WrapInScheduleUnits(seq_node->children); + ctrl->child = seq_node; } } } else if (node->IsWrapper()) { From 86989d711912a885d9e32da4d2fe7c2976069712 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 24 Apr 2026 11:48:36 +0800 Subject: [PATCH 131/156] format --- src/transform/auto_schedule/ir_structure.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index 763ac19ce4..0ef3abba3a 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -281,10 +281,13 @@ void TaskNode::CollectBufferAccessInfo( // Normal assigned warpgroup result.emplace(region->buffer, is_write, wg_id, this); } else if (IsWarpgroupBroadcast(wg_id)) { - // Broadcast: skip register memory (each wg has its own copy), removed for barrier analysis. - // if (IsRegisterRegion(region)) { - // return; - // } + // Broadcast: skip register memory (each wg has its own copy), removed for + // barrier analysis. + /* + if (IsRegisterRegion(region)) { + return; + } + */ // Shared/global memory is shared across wgs — emit for all for (int i = 0; i < num_wgs; ++i) { result.emplace(region->buffer, is_write, i, this); From a5529167ac4d05522237eda7bf6f1cc9d01d6cc6 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 24 Apr 2026 12:47:30 +0800 Subject: [PATCH 132/156] fix: remove unused let --- src/transform/auto_schedule/warpgroup_partition.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 9a21f3f1f2..a8a8243881 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -576,7 +576,7 @@ class UnusedLetStmtStripper : public StmtExprMutator { auto body_uses_var = UsesVar(new_body, [&](const VarNode *v) { return v == op->var.get(); }); - bool value_is_pure = SideEffect(new_value) <= CallEffectKind::kPure; + bool value_is_pure = SideEffect(new_value) <= CallEffectKind::kReadState; if (!body_uses_var && value_is_pure) { return new_body; From c18c6236326f4fb3eff33ab54d67a3d74fda7353 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 24 Apr 2026 12:48:44 +0800 Subject: [PATCH 133/156] remove redundant letstmts --- src/transform/auto_schedule/warpgroup_partition.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 9a21f3f1f2..a8a8243881 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -576,7 +576,7 @@ class UnusedLetStmtStripper : public StmtExprMutator { auto body_uses_var = UsesVar(new_body, [&](const VarNode *v) { return v == op->var.get(); }); - bool value_is_pure = SideEffect(new_value) <= CallEffectKind::kPure; + bool value_is_pure = SideEffect(new_value) <= CallEffectKind::kReadState; if (!body_uses_var && value_is_pure) { return new_body; From 0edb76ca789e76672b74814b62e36f3101a962a0 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 24 Apr 2026 12:51:52 +0800 Subject: [PATCH 134/156] move the rewrites forward Co-authored-by: Copilot --- src/transform/auto_schedule/barrier.h | 90 ++++++++++++++------------- 1 file changed, 46 insertions(+), 44 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 922bc4b10c..85367b72eb 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -770,6 +770,52 @@ static void InsertSynchronization( for (auto unit : units) { for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { auto sync_it = sync_infos.find({unit, wg_id}); + int barrier_versions = 1; + if (sync_it != sync_infos.end()) { + for (const auto &[waiting_unit_info, sync_infos] : sync_it->second) { + for (const auto &sync_info : sync_infos) { + barrier_versions = + std::max(barrier_versions, sync_info.buffer_versions); + } + } + } + Buffer barrier_buffer; + // Handle single special task, such as TCGEN05 or TMA load, that requires + // a barrier for itself. + if (auto task = GetInnerTask(unit)) { + int task_wg_id = task->GetWarpgroupId(); + if (task->is_TCGEN05() && task_wg_id == wg_id) { + int barrier_id = next_barrier_id++; + barrier_buffer = makeBarrierBuffer( + 1, "tcgen05_barrier_" + std::to_string(barrier_id), + barrier_versions, barrier_buffers, barrier_map); + PrimExpr version_index = + indexmod(loop_info.CalculateIterationCount(), barrier_versions); + PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); + RewriteGemmMbar(task, mbar_expr); + // TODO: need to change the lower of tcgen05_gemm to check if there is + // already a arrive statement. Then we can manually insert the arrive + // statement to deal with the case where the tcgen05_gemm is inside an + // if condition. + /* + Stmt arrive_stmt = + makeTcgen05MmaArrive(barrier_buffer, version_index); + InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); + */ + } + if (task->HasTMALoad() && task_wg_id == wg_id) { + int barrier_id = next_barrier_id++; + barrier_buffer = makeBarrierBuffer( + thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), + barrier_versions, barrier_buffers, barrier_map); + PrimExpr version_index = + indexmod(loop_info.CalculateIterationCount(), barrier_versions); + PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); + RewriteCopyMbar(task, mbar_expr); + Stmt arrive_stmt = makeBarrierArrive(mbar_expr); + InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); + } + } if (sync_it == sync_infos.end()) continue; const auto &wait_map = sync_it->second; @@ -852,50 +898,6 @@ static void InsertSynchronization( } } } - int barrier_versions = 1; - for (const auto &[waiting_unit_info, sync_infos] : wait_map) { - for (const auto &sync_info : sync_infos) { - barrier_versions = - std::max(barrier_versions, sync_info.buffer_versions); - } - } - Buffer barrier_buffer; - // Handle single special task, such as TCGEN05 or TMA load, that requires - // a barrier for itself. - if (auto task = GetInnerTask(unit)) { - int task_wg_id = task->GetWarpgroupId(); - if (task->is_TCGEN05() && task_wg_id == wg_id) { - int barrier_id = next_barrier_id++; - barrier_buffer = makeBarrierBuffer( - 1, "tcgen05_barrier_" + std::to_string(barrier_id), - barrier_versions, barrier_buffers, barrier_map); - PrimExpr version_index = - indexmod(loop_info.CalculateIterationCount(), barrier_versions); - PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); - RewriteGemmMbar(task, mbar_expr); - // TODO: need to change the lower of tcgen05_gemm to check if there is - // already a arrive statement. Then we can manually insert the arrive - // statement to deal with the case where the tcgen05_gemm is inside an - // if condition. - /* - Stmt arrive_stmt = - makeTcgen05MmaArrive(barrier_buffer, version_index); - InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); - */ - } - if (task->HasTMALoad() && task_wg_id == wg_id) { - int barrier_id = next_barrier_id++; - barrier_buffer = makeBarrierBuffer( - thread_count[wg_id], "tma_barrier_" + std::to_string(barrier_id), - barrier_versions, barrier_buffers, barrier_map); - PrimExpr version_index = - indexmod(loop_info.CalculateIterationCount(), barrier_versions); - PrimExpr mbar_expr = BufferLoad(barrier_buffer, {version_index}); - RewriteCopyMbar(task, mbar_expr); - Stmt arrive_stmt = makeBarrierArrive(mbar_expr); - InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); - } - } auto check_need_barrier = [&](ScheduleUnit *waiting_unit, int waiting_wg_id, const SyncInfo &sync_info) { From 09d907193703fd4b5f604ef311547b6bd815ef7a Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 24 Apr 2026 13:18:53 +0800 Subject: [PATCH 135/156] add WAW dependence & avoid duplicated dependence when iter=1 Co-authored-by: Copilot --- src/transform/auto_schedule/barrier.h | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 85367b72eb..6d4c07f236 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -665,13 +665,16 @@ GetSyncInfos(const std::vector &units, int num_wgs, } } } - // WAR: unit writes buffer, wait for all last readers + // WAR/WAW: unit writes buffer, wait for all last readers + // if there are no last readers, wait for last writer auto first_writes = unit->GetFirstAccessTasks( buffer, /*is_write=*/true, wg_id, SchedulePhase::kBody); if (!first_writes.empty()) { + bool has_last_read = false; for (int last_wg = 0; last_wg < num_wgs; ++last_wg) { if (last_read_unit[last_wg] == nullptr) continue; + has_last_read = true; for (auto *consumer : first_writes) { for (auto *producer : last_read_unit_tasks[last_wg]) { sync_infos[{last_read_unit[last_wg], last_wg}][{unit, wg_id}] @@ -680,6 +683,20 @@ GetSyncInfos(const std::vector &units, int num_wgs, } } } + if (!has_last_read && last_write_unit != nullptr && + !waited_write_wgs[wg_id]) { + for (auto *consumer : first_writes) { + for (const auto &[last_write_wg_id, last_write_unit_tasks] : + last_write_unit_wg_tasks) { + for (auto *producer : last_write_unit_tasks) { + sync_infos[{last_write_unit, last_write_wg_id}] + [{unit, wg_id}] + .emplace(distance, buffer, producer, consumer, + num_versions); + } + } + } + } } } // Set status to avoid redundant dependencies for subsequent units @@ -690,10 +707,12 @@ GetSyncInfos(const std::vector &units, int num_wgs, continue; if (!buffer_access.is_write) { waited_write_wgs[wg_id] = true; + last_read_unit[wg_id] = nullptr; } else { for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { last_read_unit[wg_id] = nullptr; } + last_write_unit = nullptr; } } if (iter == 0) { From 9897fe5702672bc3f25f81bd5ff785b8cca88b89 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 24 Apr 2026 14:44:28 +0800 Subject: [PATCH 136/156] fix barrier around let missing --- .../auto_schedule/warpgroup_partition.cc | 111 ++++++++++++------ 1 file changed, 73 insertions(+), 38 deletions(-) diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index a8a8243881..811ec942d6 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -468,26 +468,24 @@ CloneIRStructureWithWarpgroupFilter(IRStructure *node, int warpgroup_id, new_unit->child = CloneIRStructureWithWarpgroupFilter( unit->child.get(), warpgroup_id, var_remap, buffer_remap); - if (!child_is_let_decl) { - // Copy before/after for the target warp group - new_unit->before[warpgroup_id] = unit->before[warpgroup_id]; - new_unit->after[warpgroup_id] = unit->after[warpgroup_id]; - // Substitute renamed LetDecl variables in before/after stmts - if (!var_remap.empty()) { - for (auto &s : new_unit->before[warpgroup_id]) { - s = Substitute(s, var_remap); - } - for (auto &s : new_unit->after[warpgroup_id]) { - s = Substitute(s, var_remap); - } - } + // Copy before/after for the target warp group + new_unit->before[warpgroup_id] = unit->before[warpgroup_id]; + new_unit->after[warpgroup_id] = unit->after[warpgroup_id]; + // Substitute renamed LetDecl variables in before/after stmts + if (!var_remap.empty()) { for (auto &s : new_unit->before[warpgroup_id]) { - s = apply_buffer_remap_stmt(s); + s = Substitute(s, var_remap); } for (auto &s : new_unit->after[warpgroup_id]) { - s = apply_buffer_remap_stmt(s); + s = Substitute(s, var_remap); } } + for (auto &s : new_unit->before[warpgroup_id]) { + s = apply_buffer_remap_stmt(s); + } + for (auto &s : new_unit->after[warpgroup_id]) { + s = apply_buffer_remap_stmt(s); + } return new_unit; } else if (node->IsIf()) { if (!node->containWarpgroupId(warpgroup_id) && !ContainsLetDecl(node)) @@ -1116,18 +1114,27 @@ Stmt ApplyWarpgroupPartitionToIRStructure( } // --- Per-child construction --- - // Walk root SequenceNode's children. LetDecl children accumulate bindings; - // non-LetDecl children produce IfThenElse blocks wrapped with accumulated - // LetDecl scopes per warp group. + // Walk root SequenceNode's children. LetDecl children accumulate as + // (var, value, before_stmts, after_stmts) tuples so that the cloned + // before/after barriers on their ScheduleUnit are preserved. When wrapping + // a subsequent non-LetDecl child, each accumulated tuple is re-emitted as + // ; let var = value in ( ; body) + // so the barrier pair brackets the let binding while `var` stays in scope + // for the rest of the segment. Stmt if_then_else; if (root->IsSequence()) { auto root_seq = static_cast(root); size_t num_children = root_seq->children.size(); - // per-wg accumulated LetDecl {var, value} from earlier children - std::vector>> wg_accumulated_lets( - num_wgs); + struct AccumulatedLet { + Var var; + PrimExpr value; + std::vector before; + std::vector after; + }; + // per-wg accumulated LetDecl entries from earlier children + std::vector> wg_accumulated_lets(num_wgs); std::vector segmented_stmts; bool first_non_let = true; @@ -1137,23 +1144,36 @@ Stmt ApplyWarpgroupPartitionToIRStructure( bool is_let_decl = IsLetDeclNode(unit->child.get()); if (is_let_decl) { - // Extract LetDecl {var, value} from each wg's filtered result + // Extract LetDecl {var, value, before, after} from each wg's filtered + // result. The surrounding before/after live on the cloned + // ScheduleUnit wrapping the LetDecl task. for (size_t i = 0; i < num_wgs; ++i) { - if (wg_children[i][ci]) { - // wg_children[i][ci] is a ScheduleUnit wrapping a TaskNode - IRStructure *inner = wg_children[i][ci].get(); - TaskNode *task = nullptr; - if (inner->IsScheduleUnit()) { - task = static_cast( - static_cast(inner)->child.get()); - } else if (inner->IsTask()) { - task = static_cast(inner); + if (!wg_children[i][ci]) + continue; + IRStructure *inner = wg_children[i][ci].get(); + TaskNode *task = nullptr; + std::vector before_stmts; + std::vector after_stmts; + if (inner->IsScheduleUnit()) { + auto wg_unit = static_cast(inner); + task = static_cast(wg_unit->child.get()); + auto it_before = wg_unit->before.find(static_cast(i)); + if (it_before != wg_unit->before.end()) { + before_stmts = it_before->second; + } + auto it_after = wg_unit->after.find(static_cast(i)); + if (it_after != wg_unit->after.end()) { + after_stmts = it_after->second; } - if (task && !task->stmts.empty()) { - const auto *let = task->stmts[0].as(); - if (let) { - wg_accumulated_lets[i].push_back({let->var, let->value}); - } + } else if (inner->IsTask()) { + task = static_cast(inner); + } + if (task && !task->stmts.empty()) { + const auto *let = task->stmts[0].as(); + if (let) { + wg_accumulated_lets[i].push_back({let->var, let->value, + std::move(before_stmts), + std::move(after_stmts)}); } } } @@ -1178,8 +1198,23 @@ Stmt ApplyWarpgroupPartitionToIRStructure( // Wrap with accumulated LetDecl bindings (innermost first) for (int j = static_cast(wg_accumulated_lets[i].size()) - 1; j >= 0; --j) { - wg_stmts[i] = LetStmt(wg_accumulated_lets[i][j].first, - wg_accumulated_lets[i][j].second, wg_stmts[i]); + const AccumulatedLet &acc = wg_accumulated_lets[i][j]; + Stmt body = wg_stmts[i]; + if (!acc.after.empty()) { + std::vector tmp = acc.after; + if (!IsEvaluateZero(body)) { + tmp.push_back(body); + } + body = SeqStmt::Flatten(tmp); + } + Stmt let_stmt = LetStmt(acc.var, acc.value, body); + if (!acc.before.empty()) { + std::vector tmp = acc.before; + tmp.push_back(let_stmt); + wg_stmts[i] = SeqStmt::Flatten(tmp); + } else { + wg_stmts[i] = let_stmt; + } } } From 01bf798a2e1a6c8319735ce6167ba61e9338f6c8 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 24 Apr 2026 14:13:59 +0800 Subject: [PATCH 137/156] remove cross-warpgroup dependency for register buffers --- src/transform/auto_schedule/barrier.h | 12 ++++++++++++ src/transform/auto_schedule/ir_structure.cc | 9 +-------- src/transform/auto_schedule/ir_structure.h | 10 +++++++--- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 6d4c07f236..e4dcb55e6f 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -658,6 +658,10 @@ GetSyncInfos(const std::vector &units, int num_wgs, for (const auto &[last_write_wg_id, last_write_unit_tasks] : last_write_unit_wg_tasks) { for (auto *producer : last_write_unit_tasks) { + if (IsRegisterBuffer(buffer) && wg_id != last_write_wg_id) { + // Skip cross-warpgroup dependency for register buffers + continue; + } sync_infos[{last_write_unit, last_write_wg_id}][{unit, wg_id}] .emplace(distance, buffer, producer, consumer, num_versions); @@ -677,6 +681,10 @@ GetSyncInfos(const std::vector &units, int num_wgs, has_last_read = true; for (auto *consumer : first_writes) { for (auto *producer : last_read_unit_tasks[last_wg]) { + if (IsRegisterBuffer(buffer) && wg_id != last_wg) { + // Skip cross-warpgroup dependency for register buffers + continue; + } sync_infos[{last_read_unit[last_wg], last_wg}][{unit, wg_id}] .emplace(distance, buffer, producer, consumer, num_versions); @@ -689,6 +697,10 @@ GetSyncInfos(const std::vector &units, int num_wgs, for (const auto &[last_write_wg_id, last_write_unit_tasks] : last_write_unit_wg_tasks) { for (auto *producer : last_write_unit_tasks) { + if (IsRegisterBuffer(buffer) && wg_id != last_write_wg_id) { + // Skip cross-warpgroup dependency for register buffers + continue; + } sync_infos[{last_write_unit, last_write_wg_id}] [{unit, wg_id}] .emplace(distance, buffer, producer, consumer, diff --git a/src/transform/auto_schedule/ir_structure.cc b/src/transform/auto_schedule/ir_structure.cc index 0ef3abba3a..a9cb429dbb 100644 --- a/src/transform/auto_schedule/ir_structure.cc +++ b/src/transform/auto_schedule/ir_structure.cc @@ -281,14 +281,7 @@ void TaskNode::CollectBufferAccessInfo( // Normal assigned warpgroup result.emplace(region->buffer, is_write, wg_id, this); } else if (IsWarpgroupBroadcast(wg_id)) { - // Broadcast: skip register memory (each wg has its own copy), removed for - // barrier analysis. - /* - if (IsRegisterRegion(region)) { - return; - } - */ - // Shared/global memory is shared across wgs — emit for all + // Broadcast: shared across wgs — emit for all for (int i = 0; i < num_wgs; ++i) { result.emplace(region->buffer, is_write, i, this); } diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index ec6206e447..f948ba1686 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -1167,14 +1167,18 @@ inline MemoryType GetMemoryTypeFromScope(const String &scope) { return MemoryType::kUnknown; } -// Helper function to check if a buffer region is in register memory -inline bool IsRegisterRegion(const BufferRegion ®ion) { - const Buffer &buffer = region->buffer; +// Helper function to check if a buffer is in register memory +inline bool IsRegisterBuffer(const Buffer &buffer) { String scope = buffer.scope(); MemoryType mem_type = GetMemoryTypeFromScope(scope); return mem_type == MemoryType::kRegister; } +// Helper function to check if a buffer region is in register memory +inline bool IsRegisterRegion(const BufferRegion ®ion) { + return IsRegisterBuffer(region->buffer); +} + // Helper function to collect all register regions from an IRStructure node inline std::vector CollectRegisterRegions(const IRStructure *node) { From d255c0af9f4df692fa48b8e31987dfee20f59ee8 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 24 Apr 2026 17:17:50 +0800 Subject: [PATCH 138/156] fix reused buffer analysis --- .../auto_schedule/warpgroup_partition.cc | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 811ec942d6..3ba0f49eae 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -1138,6 +1138,7 @@ Stmt ApplyWarpgroupPartitionToIRStructure( std::vector segmented_stmts; bool first_non_let = true; + bool prev_was_loop = true; for (size_t ci = 0; ci < num_children; ++ci) { auto unit = static_cast(root_seq->children[ci].get()); @@ -1246,15 +1247,23 @@ Stmt ApplyWarpgroupPartitionToIRStructure( } } - // Insert liveness boundary before each non-empty non-LetDecl child - segmented_stmts.push_back(AttrStmt( - Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, Evaluate(0))); - if (is_shared_attr_segment && shared_attr_stmt.defined()) { segmented_stmts.push_back(shared_attr_stmt); continue; } + bool is_loop = unit->child->IsControl(); + + // Insert liveness boundary only before for-loop segments and + // before non-loop segments that follow a for-loop. Consecutive + // non-loop segments share a single boundary to avoid introducing + // spurious buffer reuse hints between them. + if (prev_was_loop || is_loop) { + segmented_stmts.push_back(AttrStmt( + Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, + Evaluate(0))); + } + // Prepend set_max_nreg only to the first non-LetDecl child if (first_non_let && !has_simt_copy && !has_inner_nreg_decision && num_wgs == 2 && config.enable_set_max_nreg) { @@ -1270,6 +1279,8 @@ Stmt ApplyWarpgroupPartitionToIRStructure( first_non_let = false; segmented_stmts.push_back(MakeWarpgroupIf(wg_stmts)); + + prev_was_loop = is_loop; } segmented_stmts.push_back(AttrStmt( Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, Evaluate(0))); From 3df8b465db5615322311628d13be0ff4b8e6c0ed Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Fri, 24 Apr 2026 17:36:57 +0800 Subject: [PATCH 139/156] check kernel using barrier & format --- .../auto_schedule/warpgroup_partition.cc | 6 +- tilelang/engine/phase.py | 73 ++++++++++++++++++- 2 files changed, 74 insertions(+), 5 deletions(-) diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 3ba0f49eae..9eb09b7ce4 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -1259,9 +1259,9 @@ Stmt ApplyWarpgroupPartitionToIRStructure( // non-loop segments share a single boundary to avoid introducing // spurious buffer reuse hints between them. if (prev_was_loop || is_loop) { - segmented_stmts.push_back(AttrStmt( - Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, - Evaluate(0))); + segmented_stmts.push_back( + AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, + Evaluate(0))); } // Prepend set_max_nreg only to the first non-LetDecl child diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index cb4e56bbe2..cfea714aa3 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -28,6 +28,67 @@ def module_has_tma(mod: IRModule) -> bool: return any(func.attrs and func.attrs.get("tl.has_tma", False) for _, func in mod.functions.items()) +def module_has_barrier(mod: IRModule) -> bool: + """Check whether any PrimFunc in ``mod`` allocates / initializes an mbarrier. + """ + from tvm.tir import stmt_functor + + for _, func in mod.functions.items(): + if not isinstance(func, tir.PrimFunc): + continue + + # Explicit allocations with a barrier storage scope. + for buf in func.buffer_map.values(): + scope = buf.scope() if hasattr(buf, "scope") else "" + if isinstance(scope, str) and scope.startswith("shared.barrier"): + return True + if isinstance(scope, str) and scope.startswith("shared.cluster_barrier"): + return True + + found = [False] + + def _check(node, _found=found): + if _found[0]: + return + # Buffer / BufferRealize allocations inside the body. + buffer = None + if isinstance(node, tir.BufferRealize): + buffer = node.buffer + elif isinstance(node, tir.Allocate): + # Allocate does not carry storage scope directly; rely on the + # associated AttrStmt "storage_scope" picked up below. + buffer = None + if buffer is not None: + scope = buffer.scope() if hasattr(buffer, "scope") else "" + if isinstance(scope, str) and ( + scope.startswith("shared.barrier") + or scope.startswith("shared.cluster_barrier") + ): + _found[0] = True + return + # Block-level "barrier_init" annotation produced by alloc_barrier. + if isinstance(node, tir.Block): + annotations = getattr(node, "annotations", None) + if annotations is not None and "barrier_init" in annotations: + _found[0] = True + return + # AttrStmt-level "storage_scope" carrying a barrier scope. + if isinstance(node, tir.AttrStmt) and node.attr_key == "storage_scope": + value = node.value + scope_str = value.value if hasattr(value, "value") else str(value) + if isinstance(scope_str, str) and ( + scope_str.startswith("shared.barrier") + or scope_str.startswith("shared.cluster_barrier") + ): + _found[0] = True + + stmt_functor.post_order_visit(func.body, _check) + if found[0]: + return True + + return False + + def module_uses_thread_var(mod: IRModule) -> bool: """Check whether any PrimFunc in ``mod`` references thread-index variables inside its body. @@ -254,8 +315,16 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectAssumes()(mod) # Simplify the IR expressions mod = tilelang.transform.Simplify()(mod) - if allow_autoschedule(target=target) and not module_uses_thread_var(mod): - # Auto schedule for high-level operations + if ( + allow_autoschedule(target=target) + and not module_uses_thread_var(mod) + and not module_has_barrier(mod) + ): + # Auto schedule for high-level operations. + # Skip when the kernel already manages explicit mbarriers + # (alloc_barrier / alloc_cluster_barrier), because reordering the + # rewrites breaks invariants that later barrier lowering and the + # WS / pipelined TMA copy pipeline rely on. mod = tilelang.transform.IfConditionExtract()(mod) mod = tilelang.transform.AutoSchedule(False)(mod) mod = tilelang.transform.Simplify()(mod) From e18b8e6621d53bd7311aee8586462869ecb8638f Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 24 Apr 2026 17:41:36 +0800 Subject: [PATCH 140/156] format --- tilelang/engine/phase.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index cfea714aa3..3ad6aea577 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -29,8 +29,7 @@ def module_has_tma(mod: IRModule) -> bool: def module_has_barrier(mod: IRModule) -> bool: - """Check whether any PrimFunc in ``mod`` allocates / initializes an mbarrier. - """ + """Check whether any PrimFunc in ``mod`` allocates / initializes an mbarrier.""" from tvm.tir import stmt_functor for _, func in mod.functions.items(): @@ -60,10 +59,7 @@ def _check(node, _found=found): buffer = None if buffer is not None: scope = buffer.scope() if hasattr(buffer, "scope") else "" - if isinstance(scope, str) and ( - scope.startswith("shared.barrier") - or scope.startswith("shared.cluster_barrier") - ): + if isinstance(scope, str) and (scope.startswith("shared.barrier") or scope.startswith("shared.cluster_barrier")): _found[0] = True return # Block-level "barrier_init" annotation produced by alloc_barrier. @@ -77,8 +73,7 @@ def _check(node, _found=found): value = node.value scope_str = value.value if hasattr(value, "value") else str(value) if isinstance(scope_str, str) and ( - scope_str.startswith("shared.barrier") - or scope_str.startswith("shared.cluster_barrier") + scope_str.startswith("shared.barrier") or scope_str.startswith("shared.cluster_barrier") ): _found[0] = True @@ -315,11 +310,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectAssumes()(mod) # Simplify the IR expressions mod = tilelang.transform.Simplify()(mod) - if ( - allow_autoschedule(target=target) - and not module_uses_thread_var(mod) - and not module_has_barrier(mod) - ): + if allow_autoschedule(target=target) and not module_uses_thread_var(mod) and not module_has_barrier(mod): # Auto schedule for high-level operations. # Skip when the kernel already manages explicit mbarriers # (alloc_barrier / alloc_cluster_barrier), because reordering the From 264efe2d36802e5e84a77ff11f59e46c86cbec11 Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Fri, 24 Apr 2026 23:58:20 +0800 Subject: [PATCH 141/156] [Feature] Block-scaled GEMM support for MXFP8 on Blackwell (#1945) * mxfp8 blockscaled gemm (squashed: 7 commits from wt/mxfp8) * fix: post-rebase compat with tcgen05 refactors - Remove duplicate tcgen05_before/after_thread_sync in tcgen_05.h (now in main) - Pass disable_2cta=True in tcgen05mma_blockscaled meta lookup (block-scaled does not support .ws/.cta_group::2 in this path) * fix(mxfp8): declare SFA/SFB reads and plumb mbar_phase_expr P0-1: GemmPyNode::GetAccessRegions now pushes sfaRegion_/sfbRegion_ into result.reads when they are defined (block-scaled GEMM), so pipeline planning and layout inference correctly track these buffers. Previously SFA/SFB were invisible to those passes, which only worked by accident because the current example schedules the SF pipeline by hand. P0-3: GemmTCGEN5._lower_blockscaled now accepts mbar_phase_expr for API consistency with _gemm_ss and the rest of the GemmPyNode.Lower chain. Block-scaled follows explicit-async TCGEN5MMA semantics (user / pipeline pass manages mbarrier_wait_parity), so the parameter is intentionally unused in the current lowering; the structural alignment means a future synchronous block-scaled path can add the auto-wait without another signature change. * feat(mxfp8): rename blockscaled API to tcgen05_gemm_blockscaled Expose block-scaled MXFP8 GEMM under T.tcgen05_gemm_blockscaled to make its explicit-async TCGEN05 semantics obvious and align its naming with T.tcgen05_gemm. Add a dedicated C++ op registration (tl.tileop.tcgen05_gemm_blockscaled_py) that tags the GemmPy node with is_tcgen05=1 and is_blockscaled=1, update the Python language surface to export only the new name, and switch the SM100 MXFP8 examples to call the renamed API. * fix(mxfp8): keep 1cta blockscaled tcgen05 on 1cta path * draft(mxfp8): wire up 2cta blockscaled path * Remove tilelang.disable_cache() calls from gemm examples for cleaner execution. * lint fix * upd 3PFLOPs mxfp8 gemm * lint * cleanup * more cleanup and clarify * rename ptx_tcgen05_cp to ptx_tcgen05_cp_warpx4 for clarity * lint --------- Co-authored-by: LeiWang1999 --- .../gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py | 703 ++++++++++++++++++ src/op/builtin.cc | 15 + src/op/builtin.h | 15 + src/op/copy.cc | 4 + src/op/gemm.cc | 28 + src/op/gemm.h | 8 +- src/op/tcgen5_meta.h | 60 ++ src/target/codegen_cuda.cc | 82 ++ .../cuda/instruction/tcgen05mma.h | 78 ++ src/tl_templates/cuda/tcgen_05.h | 51 ++ .../intrinsics/tcgen05_macro_generator.py | 240 ++++++ tilelang/language/__init__.py | 8 +- tilelang/language/ast/ir.py | 2 + tilelang/language/builtin.py | 79 +- tilelang/language/gemm_op.py | 210 +++++- tilelang/language/tir/ir.py | 1 + tilelang/language/tir/op.py | 72 ++ tilelang/tileop/gemm/__init__.py | 8 + tilelang/tileop/gemm/gemm_base.py | 20 + tilelang/tileop/gemm/gemm_tcgen05.py | 100 ++- tilelang/utils/language.py | 37 + 21 files changed, 1813 insertions(+), 8 deletions(-) create mode 100644 examples/gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py diff --git a/examples/gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py b/examples/gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py new file mode 100644 index 0000000000..c451a8856d --- /dev/null +++ b/examples/gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py @@ -0,0 +1,703 @@ +# MXFP8 Block-Scaled GEMM on SM100 +# Blockscale size: (M, N, K) = (1, 1, 128) + +import argparse +import torch +import tilelang +import tilelang.language as T +from tilelang.carver.arch import driver +from tilelang.profiler import do_bench + + +@tilelang.jit +def mxfp8_blockscaled_gemm( + A, + B, + SFA, + SFB, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k=128, +): + """1D-1D Block-scaled MXFP8 GEMM. + + A: [M, K] in FP8 (E4M3 or E5M2) + B: [K, N] in FP8 (E4M3 or E5M2) + SFA: [(K / sf_granularity_k) / 4) * M] in uint32 + Group-major packed E8M0 scale factors for A. + SFB: [(K / sf_granularity_k) / 4) * N] in uint32 + Group-major packed E8M0 scale factors for B. + """ + M, N, K = T.const("M, N, K") + + k_iters = T.ceildiv(K, block_K) + # Load 4 K-blocks of SF at once → load every 4 iterations + sf_load_period = sf_granularity_k * 4 // block_K + sf_k_groups = T.ceildiv(T.ceildiv(K, sf_granularity_k), 4) + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[K, N], in_dtype] + SFA: T.Tensor[[sf_k_groups * M], T.uint32] + SFB: T.Tensor[[sf_k_groups * N], T.uint32] + C = T.empty((M, N), out_dtype) + + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by): + # Data shared memory (pipelined) + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype) + + # Scale factor shared memory — one uint32 per row/column, packing 4 K-blocks. + SFA_shared = T.alloc_shared((num_stages, block_M), "uint32") + SFB_shared = T.alloc_shared((num_stages, block_N), "uint32") + + # Accumulator in tensor memory + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + + # Scale factors in tensor memory (TMEM has 128 rows / 32-bit cells) + SFA_tmem = T.alloc_tmem([block_M, block_M // 128 * 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, block_N // 128 * 4], "uint32") + + # Output buffers + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + # Barriers + loaded = T.alloc_barrier([32] * num_stages) + with_sf_full = T.alloc_barrier([32] * num_stages) + consumed = T.alloc_barrier([1] * num_stages) + tmem_full = T.alloc_barrier([1]) + + tx = T.get_thread_binding() + T.use_swizzle(8) + + if tx < 32: + # Warp 0: TMA load + for k in T.serial(k_iters): + T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[k % num_stages, :, :], + barrier=loaded[k % num_stages], + ) + T.tma_copy( + B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], + B_shared[k % num_stages, :, :], + barrier=loaded[k % num_stages], + ) + # Load one packed uint32 SF word every sf_load_period iterations. + if k % sf_load_period == 0: + sf_group_idx = k // sf_load_period + T.tma_copy( + SFA[sf_group_idx * M + bx * block_M : sf_group_idx * M + (bx + 1) * block_M], + SFA_shared[k % num_stages, :], + barrier=loaded[k % num_stages], + ) + T.tma_copy( + SFB[sf_group_idx * N + by * block_N : sf_group_idx * N + (by + 1) * block_N], + SFB_shared[k % num_stages, :], + barrier=loaded[k % num_stages], + ) + T.mbarrier_arrive(loaded[k % num_stages]) + + elif tx < 64: + # Warp 1: MMA issue + UTCCP + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], phase) + T.mbarrier_wait_parity(with_sf_full[stage], phase) + + if k % sf_load_period == 0: + T.tcgen05_cp_warpx4(SFA_shared[stage, :], SFA_tmem) + T.tcgen05_cp_warpx4(SFB_shared[stage, :], SFB_tmem) + + # sf_id selects which of the 4 packed E8M0 values to use + T.tcgen05_gemm_blockscaled( + A_shared[stage, :, :], + B_shared[stage, :, :], + C_tmem, + SFA_tmem, + SFB_tmem, + mbar=consumed[stage], + clear_accum=k == 0, + sf_a_id=k % sf_load_period, + sf_b_id=k % sf_load_period, + ) + + T.tcgen05_mma_arrive(tmem_full) + + elif tx < 96: + # Warp 2: scale-factor transpose + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], phase) + + if k % sf_load_period == 0: + T.tcgen05_sf_warp_transpose(SFA_shared[stage, :]) + T.tcgen05_sf_warp_transpose(SFB_shared[stage, :]) + T.fence_proxy_async() + T.mbarrier_arrive(with_sf_full[stage]) + + # Epilogue: all warps + T.mbarrier_wait_parity(tmem_full, 0) + T.sync_threads() + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + return C + + +@tilelang.jit +def mxfp8_blockscaled_gemm_2cta( + A, + B, + SFA, + SFB, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k=128, +): + M, N, K = T.const("M, N, K") + + assert block_M == 128 + assert block_N == 256 + assert block_K == 128 + assert sf_granularity_k == 128 + + half_N = block_N // 2 + k_iters = T.ceildiv(K, block_K) + sf_load_period = sf_granularity_k * 4 // block_K + sf_k_groups = T.ceildiv(T.ceildiv(K, sf_granularity_k), 4) + assert sf_load_period == 4 + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[K, N], in_dtype] + SFA: T.Tensor[[sf_k_groups * M], T.uint32] + SFB: T.Tensor[[sf_k_groups * N], T.uint32] + C = T.empty((M, N), out_dtype) + + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128, cluster_dims=2) as (bx, by): + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) + + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared((num_stages, block_K, half_N), in_dtype) + SFA_shared = T.alloc_shared((num_stages, block_M), "uint32") + SFB_shared = T.alloc_shared((num_stages, block_N), "uint32") + + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + SFA_tmem = T.alloc_tmem([block_M, 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, 8], "uint32") + + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + loaded = T.alloc_barrier([32] * num_stages) + with_sf_full = T.alloc_cluster_barrier([32 * 2] * num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_barrier([1]) + + tx = T.get_thread_binding() + warp_idx = tx // 32 + T.use_swizzle(16) + + if warp_idx == 0: + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(consumed[stage], phase ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[stage, :, :], + barrier=loaded[stage], + ) + T.tma_copy( + B[ + k * block_K : (k + 1) * block_K, + (by * block_N + cta_id * half_N) : (by * block_N + (cta_id + 1) * half_N), + ], + B_shared[stage, :, :], + barrier=loaded[stage], + ) + if k % sf_load_period == 0: + sf_group_idx = k // sf_load_period + T.tma_copy( + SFA[sf_group_idx * M + bx * block_M : sf_group_idx * M + (bx + 1) * block_M], + SFA_shared[stage, :], + barrier=loaded[stage], + ) + T.tma_copy( + SFB[sf_group_idx * N + by * block_N : sf_group_idx * N + (by + 1) * block_N], + SFB_shared[stage, :], + barrier=loaded[stage], + ) + T.mbarrier_arrive(loaded[stage]) + + elif warp_idx == 1 and cta_id == 0: + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(with_sf_full[stage], phase) + if k % sf_load_period == 0: + T.tcgen05_cp_warpx4(SFA_shared[stage, :], SFA_tmem, use_2cta=True) + T.tcgen05_cp_warpx4(SFB_shared[stage, :], SFB_tmem, use_2cta=True) + + T.tcgen05_gemm_blockscaled( + A_shared[stage, :, :], + B_shared[stage, :, :], + C_tmem, + SFA_tmem, + SFB_tmem, + mbar=consumed[stage], + clear_accum=k == 0, + sf_a_id=k % sf_load_period, + sf_b_id=k % sf_load_period, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) + + elif warp_idx == 2: + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], phase) + if k % sf_load_period == 0: + T.tcgen05_sf_warp_transpose(SFA_shared[stage, :]) + T.tcgen05_sf_warp_transpose(SFB_shared[stage, :]) + T.fence_proxy_async() + T.mbarrier_arrive(with_sf_full[stage], 0) + + T.mbarrier_wait_parity(tmem_full, 0) + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + return C + + +@tilelang.jit +def mxfp8_blockscaled_gemm_2cta_persistent( + A, + B, + SFA, + SFB, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k=128, + use_tma_store=True, + store_block_N=64, +): + M, N, K = T.const("M, N, K") + + half_N = block_N // 2 + k_iters = T.ceildiv(K, block_K) + sf_load_period = sf_granularity_k * 4 // block_K + sf_k_groups = T.ceildiv(T.ceildiv(K, sf_granularity_k), 4) + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[K, N], in_dtype] + SFA: T.Tensor[[sf_k_groups * M], T.uint32] + SFB: T.Tensor[[sf_k_groups * N], T.uint32] + C = T.empty((M, N), out_dtype) + + sm_num = driver.get_num_sms() + num_clusters = sm_num // 2 + m_blocks = T.ceildiv(M, block_M) + m_clusters = m_blocks // 2 + n_blocks = T.ceildiv(N, block_N) + assert K % (2 * block_K) == 0 # for simplicity + waves = T.ceildiv(m_blocks * n_blocks, sm_num) + group_size = 16 # in cluster + assert n_blocks % (2 * group_size) == 0 # Please adjust group_size if not satisfied + + with T.Kernel(sm_num, threads=256, cluster_dims=2) as (block_id): + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) + + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared((num_stages, block_K, half_N), in_dtype) + SFA_shared = T.alloc_shared((num_stages, block_M), "uint32") + SFB_shared = T.alloc_shared((num_stages, block_N), "uint32") + + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + SFA_tmem = T.alloc_tmem([block_M, block_M // 128 * 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, block_N // 128 * 4], "uint32") + + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) + + loaded = T.alloc_barrier([32] * num_stages) + with_sf_full = T.alloc_cluster_barrier([32 * 2] * num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_cluster_barrier([1]) + tmem_empty = T.alloc_cluster_barrier([128 * 2]) + + tx = T.get_thread_binding() + warp_idx = tx // 32 + + if warp_idx == 0: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + for k in T.serial(k_iters): + phase = w * k_iters + k + stage = phase % num_stages + parity = (phase // num_stages) & 1 + T.mbarrier_wait_parity(consumed[stage], parity ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[stage, :, :], + barrier=loaded[stage], + ) + T.tma_copy( + B[ + k * block_K : (k + 1) * block_K, + by * block_N + cta_id * half_N : by * block_N + (cta_id + 1) * half_N, + ], + B_shared[stage, :, :], + barrier=loaded[stage], + ) + if k % sf_load_period == 0: + sf_group_idx = k // sf_load_period + T.tma_copy( + SFA[sf_group_idx * M + bx * block_M : sf_group_idx * M + (bx + 1) * block_M], + SFA_shared[stage, :], + barrier=loaded[stage], + ) + T.tma_copy( + SFB[sf_group_idx * N + by * block_N : sf_group_idx * N + (by + 1) * block_N], + SFB_shared[stage, :], + barrier=loaded[stage], + ) + T.mbarrier_arrive(loaded[stage]) + + elif warp_idx == 1 and cta_id == 0: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + T.mbarrier_wait_parity(tmem_empty, (w & 1) ^ 1) + for k in T.serial(k_iters): + phase = w * k_iters + k + stage = phase % num_stages + parity = (phase // num_stages) & 1 + T.mbarrier_wait_parity(with_sf_full[stage], parity) + if k % sf_load_period == 0: + T.tcgen05_cp_warpx4(SFA_shared[stage, :], SFA_tmem, use_2cta=True) + T.tcgen05_cp_warpx4(SFB_shared[stage, :], SFB_tmem, use_2cta=True) + T.tcgen05_gemm_blockscaled( + A_shared[stage, :, :], + B_shared[stage, :, :], + C_tmem, + SFA_tmem, + SFB_tmem, + mbar=consumed[stage], + clear_accum=k == 0, + sf_a_id=k % sf_load_period, + sf_b_id=k % sf_load_period, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) + + elif warp_idx == 2: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + for k in T.serial(k_iters): + phase = w * k_iters + k + stage = phase % num_stages + parity = (phase // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], parity) + if k % sf_load_period == 0: + T.tcgen05_sf_warp_transpose(SFA_shared[stage, :]) + T.tcgen05_sf_warp_transpose(SFB_shared[stage, :]) + T.fence_proxy_async() + T.mbarrier_arrive(with_sf_full[stage], 0) + + elif 128 <= tx < 256: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + T.mbarrier_wait_parity(tmem_full, w & 1) + T.copy(C_tmem, C_local) + T.mbarrier_arrive(tmem_empty, 0) + + if use_tma_store: + for i in T.unroll(T.ceildiv(block_N, store_block_N)): + T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N]) + else: + T.copy(C_local, C_local_cast) + T.copy(C_local_cast, C[bx * block_M, by * block_N]) + return C + + +def unpack_sf_u32_1d(packed_sf, mn, sf_k_blocks): + sf_k_groups = (sf_k_blocks + 3) // 4 + packed_2d = packed_sf.view(sf_k_groups, mn).T.contiguous().to(torch.int64) + unpacked = torch.empty((mn, sf_k_groups * 4), device=packed_sf.device, dtype=torch.uint8) + for i in range(4): + unpacked[:, i::4] = ((packed_2d >> (8 * i)) & 0xFF).to(torch.uint8) + return unpacked[:, :sf_k_blocks].contiguous() + + +def pack_sf_u8_to_u32_1d(sf_u8): + assert sf_u8.dtype == torch.uint8 + assert sf_u8.dim() == 2 + mn, sf_k_padded = sf_u8.shape + assert sf_k_padded % 4 == 0 + words = sf_u8.to(torch.int64) + packed = (words[:, 0::4] | (words[:, 1::4] << 8) | (words[:, 2::4] << 16) | (words[:, 3::4] << 24)).to(torch.uint32) + return packed.T.contiguous().reshape(-1) + + +def quantize_fp8_with_packed_ue8m0(x, gran_k=128): + """DeepGEMM-style per-token FP8 quantization with UE8M0 scale factors. + + Returns: + x_fp8: [MN, K] in float8_e4m3fn + sf_packed_u32: flattened group-major packed uint32 scale factors + sf_u8: [MN, ceil(K / gran_k)] unpacked E8M0 exponents + """ + + def ceil_div_int(x, y): + return (x + y - 1) // y + + def align_up(x, y): + return ceil_div_int(x, y) * y + + def ceil_to_ue8m0(x): + bits = x.abs().float().view(torch.int32) + exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).ne(0).to(torch.int32) + return (exp.clamp(1, 254) << 23).view(torch.float32) + + assert x.dim() == 2 + mn, k = x.shape + padded_k = align_up(k, gran_k) + + x_padded = torch.zeros((mn, padded_k), device=x.device, dtype=x.dtype) + x_padded[:, :k] = x + x_view = x_padded.view(mn, padded_k // gran_k, gran_k) + + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + + x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn) + x_fp8 = x_fp8.view(mn, padded_k)[:, :k].contiguous() + + sf_u8 = (sf.contiguous().view(torch.int32) >> 23).to(torch.uint8) + sf_k_blocks = sf_u8.shape[1] + sf_k_padded = align_up(sf_k_blocks, 4) + if sf_k_padded != sf_k_blocks: + sf_u8_padded = torch.full((mn, sf_k_padded), 127, device=x.device, dtype=torch.uint8) + sf_u8_padded[:, :sf_k_blocks] = sf_u8 + else: + sf_u8_padded = sf_u8 + + sf_packed_u32 = pack_sf_u8_to_u32_1d(sf_u8_padded) + return x_fp8, sf_packed_u32, sf_u8 + + +def blockscaled_gemm_ref(a, b, sfa_packed, sfb_packed, sf_granularity_k=128): + """Torch reference for block-scaled MXFP8 GEMM. + + Args: + a: [M, K] FP8 tensor + b: [K, N] FP8 tensor + sfa_packed: [(sf_k_blocks / 4) * M] uint32 packed E8M0 scale factors for A + sfb_packed: [(sf_k_blocks / 4) * N] uint32 packed E8M0 scale factors for B + sf_granularity_k: number of K elements per scale factor block (default 128) + + Returns: + [M, N] float32 result + """ + M, K = a.shape + K2, N = b.shape + assert K == K2 + sf_k_blocks = (K + sf_granularity_k - 1) // sf_granularity_k + sfa_unpacked = unpack_sf_u32_1d(sfa_packed, M, sf_k_blocks) + sfb_unpacked = unpack_sf_u32_1d(sfb_packed, N, sf_k_blocks) + + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + + # E8M0 exponent to float scale: 2^(exp - 127) + sfa_scales = torch.pow(2.0, sfa_unpacked.to(torch.float32) - 127.0) # [M, sf_k_blocks] + sfb_scales = torch.pow(2.0, sfb_unpacked.to(torch.float32) - 127.0) # [N, sf_k_blocks] + + c = torch.zeros(M, N, device=a.device, dtype=torch.float32) + for bi in range(sf_k_blocks): + k_start = bi * sf_granularity_k + k_end = min(k_start + sf_granularity_k, K) + # Scale A block: [M, block_k] * [M, 1] + a_block = a_f32[:, k_start:k_end] * sfa_scales[:, bi : bi + 1] + # Scale B block: [block_k, N] * [1, N] (sfb is [N, blocks], transpose for broadcast) + b_block = b_f32[k_start:k_end, :] * sfb_scales[:, bi : bi + 1].T + c += a_block @ b_block + return c + + +def cosine_similarity(a, b): + a_flat = a.flatten().float() + b_flat = b.flatten().float() + return (a_flat @ b_flat) / (a_flat.norm() * b_flat.norm()) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--use-e2e-quant-path", action="store_true", default=True) + parser.add_argument("--persistent", action="store_true", default=True) + parser.add_argument("--enable-2cta", action="store_true", default=True) + return parser.parse_args() + + +def main(): + args = parse_args() + + M, N, K = 8192, 8192, 8192 + block_M, block_N, block_K = 128, 256, 128 + in_dtype, out_dtype, accum_dtype = T.float8_e4m3fn, T.bfloat16, T.float + use_e2e_quant_path = args.use_e2e_quant_path + persistent = args.persistent + enable_2cta = args.enable_2cta + num_stages = 6 if enable_2cta else 4 + if persistent: + assert enable_2cta + kernel = mxfp8_blockscaled_gemm_2cta_persistent + else: + kernel = mxfp8_blockscaled_gemm_2cta if enable_2cta else mxfp8_blockscaled_gemm + sf_granularity_k = 128 + assert sf_granularity_k == 128 + + if use_e2e_quant_path: + # End-to-end path: + # fp16/bf16 source tensors -> per-token FP8 quantization with UE8M0 SF + # -> pack 4 SF entries into one uint32 -> blockscaled GEMM + x = torch.randn(M, K, device="cuda", dtype=torch.float16) + w_nt = torch.randn(N, K, device="cuda", dtype=torch.float16) + + a, sfa, _ = quantize_fp8_with_packed_ue8m0(x, gran_k=sf_granularity_k) + b_nt, sfb, _ = quantize_fp8_with_packed_ue8m0(w_nt, gran_k=sf_granularity_k) + b = b_nt.T.contiguous() + else: + a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) + b = torch.randn(K, N, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) + + # E8M0 scale factors: one uint32 per row per 4 K-blocks. + sf_k_blocks = (K + sf_granularity_k - 1) // sf_granularity_k + + # Pad to multiple of 4 (UTCCP loads 4 K-blocks at a time) + sf_k_padded = ((sf_k_blocks + 3) // 4) * 4 + sfa_u8 = torch.randint(127 - 5, 127 + 5, (M, sf_k_padded), device="cuda", dtype=torch.uint8) + sfb_u8 = torch.randint(127 - 5, 127 + 5, (N, sf_k_padded), device="cuda", dtype=torch.uint8) + sfa = pack_sf_u8_to_u32_1d(sfa_u8) + sfb = pack_sf_u8_to_u32_1d(sfb_u8) + + c = kernel( + a, + b, + sfa, + sfb, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k, + ) + print( + kernel.get_kernel_source( + a, + b, + sfa, + sfb, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k, + ) + ) + + if use_e2e_quant_path: + # For the end-to-end quantization path, compare against the reference with bf16 gemm + ref_c = (x.float() @ w_nt.float().T).to(torch.bfloat16) + else: + ref_c = blockscaled_gemm_ref(a, b, sfa, sfb, sf_granularity_k).to(torch.bfloat16) + sim = cosine_similarity(c, ref_c) + + print(f"Output shape: {c.shape}, dtype: {c.dtype}") + print(f"E2E quant path: {use_e2e_quant_path}") + print(f"{c=}, {ref_c=}") + # print(f"Max abs error: {(c.float() - ref_c.float()).abs().max().item():.6f}") + print(f"Cosine similarity: {sim.item():.6f}") + if use_e2e_quant_path: + assert 1 - sim < 1e-3 # err tolerance from DeepGEMM + print("e2e check passed ✅") + + tl_latency = do_bench( + lambda: kernel( + a, + b, + sfa, + sfb, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k, + ), + backend="cupti", + ) + print(f"Tilelang MXFP8 latency: {tl_latency} ms") + print(f"TFLOPs: {2 * M * N * K / (tl_latency / 1e3) / 1e12:.2f}") + + +if __name__ == "__main__": + main() diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 0b1e20d4be..1d1aa368ce 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -217,6 +217,21 @@ TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_blockscaled_ss) + .set_num_inputs(16) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_cp_warpx4) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_sf_warp_transpose) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(deallocate_tmem) .set_num_inputs(1) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 68c41843c3..2d28eb75bb 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -379,6 +379,21 @@ TVM_DLL const Op &ptx_tcgen05_mma_ss(); */ TVM_DLL const Op &ptx_tcgen05_mma_ts(); +/*! + * \brief tvm intrinsic for tcgen05 block-scaled mma shared-shared instructions. + */ +TVM_DLL const Op &ptx_tcgen05_mma_blockscaled_ss(); + +/*! + * \brief tvm intrinsic for tcgen05 copy warpx4 (smem to tmem). + */ +TVM_DLL const Op &ptx_tcgen05_cp_warpx4(); + +/*! + * \brief tvm intrinsic for scale factor warp transpose in shared memory. + */ +TVM_DLL const Op &ptx_tcgen05_sf_warp_transpose(); + /*! * \brief Frontend TMEM deallocation marker. * diff --git a/src/op/copy.cc b/src/op/copy.cc index 2eb8bef602..4f8500684c 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1242,6 +1242,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, } // Currently tcgen05.cp is not supported // TODO (mzw) Support tcgen05.cp + + // NOTE(wt): For copying scaling factor from SMEM to TMEM, + // please use `T.tcgen05_cp_warpx4` instead, + // as blockscaled GEMM on SM100 requires 4 duplicated 32-row sf. ICHECK(!is_cp) << "Copy from shared memory to tensor memory is not supported yet"; // Extract loop variables and ranges diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 7ba7c01ce2..5472b33386 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -86,6 +86,18 @@ Gemm::Gemm(Array args, Map annotations) { } node->cCoords_ = Array( {args[17].as().value(), args[18].as().value()}); + if (args.size() > 19) { + node->sfaRegion_ = NormalizeToBufferRegion(args[19]); + } + if (args.size() > 20) { + node->sfbRegion_ = NormalizeToBufferRegion(args[20]); + } + if (args.size() > 21) { + node->sfAId_ = args[21].as().value(); + } + if (args.size() > 22) { + node->sfBId_ = args[22].as().value(); + } node->annotations_ = annotations; data_ = std::move(node); } @@ -97,6 +109,12 @@ AccessRegions GemmNode::GetAccessRegions() const { if (!is_one(clearAccum_)) { result.reads.push_back(cRegion_); } + if (sfaRegion_.defined()) { + result.reads.push_back(sfaRegion_); + } + if (sfbRegion_.defined()) { + result.reads.push_back(sfbRegion_); + } result.writes.push_back(cRegion_); return result; } @@ -568,6 +586,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { scale_in_a, scale_in_b); return Integer(static_cast(desc)); }); + refl::GlobalDef().def("tl.get_tcgen5_blockscaled_instr_desc", + [](int atom_m, int atom_n, DataType ab_dtype, + bool a_is_k_major, bool b_is_k_major, int scale_in_a, + int scale_in_b, int a_sf_id, int b_sf_id) { + uint32_t desc = GetTCGEN5BlockScaledInstrDesc( + atom_m, atom_n, ab_dtype, a_is_k_major, + b_is_k_major, scale_in_a, scale_in_b, a_sf_id, + b_sf_id); + return Integer(static_cast(desc)); + }); } } // namespace tl diff --git a/src/op/gemm.h b/src/op/gemm.h index 2c78f54760..26b6678402 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -149,6 +149,8 @@ class GemmNode : public TileOperatorNode { bool isTcgen05_ = false; mutable GemmWarpPolicy policy_; Map annotations_; + BufferRegion sfaRegion_, sfbRegion_; + PrimExpr sfAId_, sfBId_; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode); @@ -178,7 +180,11 @@ class GemmNode : public TileOperatorNode { .def_ro("isWgmma", &GemmNode::isWgmma_) .def_ro("isTcgen05", &GemmNode::isTcgen05_) .def_ro("policy", &GemmNode::policy_) - .def_ro("annotations", &GemmNode::annotations_); + .def_ro("annotations", &GemmNode::annotations_) + .def_ro("sfaRegion", &GemmNode::sfaRegion_) + .def_ro("sfbRegion", &GemmNode::sfbRegion_) + .def_ro("sfAId", &GemmNode::sfAId_) + .def_ro("sfBId", &GemmNode::sfBId_); } Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index 2f1b5790ff..9effa12f63 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -239,6 +239,66 @@ inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k, return desc; } +// Build block-scaled instruction descriptor for mxf8f6f4.block_scale +// Bit layout: InstrDescriptorBlockScaled (see CUTLASS mma_sm100_desc.hpp) +inline uint32_t GetTCGEN5BlockScaledInstrDesc(int atom_m, int atom_n, + DataType ab_dtype, + bool a_is_k_major, + bool b_is_k_major, int scale_in_a, + int scale_in_b, int a_sf_id, + int b_sf_id) { + ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16"; + ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8"; + ICHECK(scale_in_a == 1 || scale_in_a == -1); + ICHECK(scale_in_b == 1 || scale_in_b == -1); + + // a_format / b_format for MXF8F6F4: E4M3=0, E5M2=1 + auto encode_mxfp_dtype = [&](DataType dtype) -> uint32_t { + if (dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() || + dtype.is_float8_e4m3()) { + return 0u; // E4M3 + } else if (dtype.is_float8_e5m2fnuz() || dtype.is_float8_e5m2()) { + return 1u; // E5M2 + } + LOG(FATAL) << "Unsupported dtype for block-scaled descriptor: " << dtype; + return 0u; + }; + + auto set_bits = [](uint32_t value, int start, int width) -> uint32_t { + uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1); + return (value & mask) << start; + }; + + uint32_t a_format = encode_mxfp_dtype(ab_dtype); + uint32_t b_format = a_format; + uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u; + uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u; + uint32_t a_major = a_is_k_major ? 0u : 1u; + uint32_t b_major = b_is_k_major ? 0u : 1u; + uint32_t n_dim = static_cast(atom_n >> 3); + uint32_t m_dim = static_cast(atom_m >> 4); + + uint32_t desc = 0; + desc |= set_bits(0, 0, 2); // sparse_id2 + desc |= set_bits(0, 2, 1); // sparse_flag + // bit 3 reserved + desc |= set_bits(static_cast(b_sf_id), 4, 2); // b_sf_id + // bit 6 reserved + desc |= set_bits(a_format, 7, 3); // a_format + desc |= set_bits(b_format, 10, 3); // b_format + desc |= set_bits(a_neg, 13, 1); // a_negate + desc |= set_bits(b_neg, 14, 1); // b_negate + desc |= set_bits(a_major, 15, 1); // a_major + desc |= set_bits(b_major, 16, 1); // b_major + desc |= set_bits(n_dim, 17, 6); // n_dim + desc |= set_bits(1, 23, 1); // scale_format = 1 (E8M0) + desc |= set_bits(m_dim, 24, 5); // m_dim + desc |= set_bits(static_cast(a_sf_id), 29, 2); // a_sf_id + desc |= set_bits(0, 31, 1); // k_size = 0 (K32) + + return desc; +} + } // namespace tl } // namespace tvm diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 317b01370e..20a5fa5515 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2866,6 +2866,88 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { replacer.register_rule("(mask3)", mask3); tcgen05_call = replacer.rewrite(tcgen05_call); this->stream << tcgen05_call; + } else if (op->op.same_as(tl::ptx_tcgen05_mma_blockscaled_ss())) { + ICHECK_EQ(op->args.size(), 17U) + << "ptx_tcgen05_mma_blockscaled_ss expects 17 arguments"; + std::string kind_dtype = Downcast(op->args[0])->value; + std::string a_desc = this->PrintExpr(op->args[1]); + std::string A_offset = this->PrintExpr(op->args[2]); + std::string b_desc = this->PrintExpr(op->args[3]); + std::string B_offset = this->PrintExpr(op->args[4]); + std::string c_ref = this->PrintExpr(op->args[5]); + std::string c_offset = this->PrintExpr(op->args[6]); + PrimExpr desc_expr = op->args[7]; + std::string scale_out = this->PrintExpr(op->args[8]); + std::string sfa_ref = this->PrintExpr(op->args[9]); + std::string sfa_offset = this->PrintExpr(op->args[10]); + std::string sfb_ref = this->PrintExpr(op->args[11]); + std::string sfb_offset = this->PrintExpr(op->args[12]); + // args[13], [14] reserved for future mask/flags + bool enable_ws = Downcast(op->args[15])->value; + bool enable_2cta = Downcast(op->args[16])->value; + ICHECK(!(enable_ws && enable_2cta)) + << "Block-scaled TCGEN05 does not support combining .ws and 2CTA"; + + auto dtype_enum = tl::codegen::ptx::DTypeFromString(kind_dtype); + + need_tcgen05mma_instruction_h_ = true; + this->PrintIndent(); + std::string tcgen05_call = + "tl::(tcgen05_name)<(ABType), (USE_2CTA)>(uint64_t((desc_a) + " + "(A_offset)), " + "uint64_t((desc_b) + (B_offset)), (*reinterpret_cast((C))) " + "+ (C_offset), " + "(scale_out), static_cast((desc_val)), " + "(*reinterpret_cast((SFA))) + (SFA_offset), " + "(*reinterpret_cast((SFB))) + (SFB_offset));\n"; + tl::codegen::Replacer replacer; + replacer.register_rule("(ABType)", + tl::codegen::ptx::DTypeEnumToString(dtype_enum)); + replacer.register_rule("(USE_2CTA)", enable_2cta ? "true" : "false"); + replacer.register_rule("(desc_a)", a_desc); + replacer.register_rule("(A_offset)", A_offset); + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C)", c_ref); + replacer.register_rule("(C_offset)", c_offset); + replacer.register_rule("(tcgen05_name)", + enable_ws ? "tcgen05mma_blockscaled_ws_ss" + : "tcgen05mma_blockscaled_ss"); + replacer.register_rule("(scale_out)", scale_out); + replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr)); + replacer.register_rule("(SFA)", sfa_ref); + replacer.register_rule("(SFA_offset)", sfa_offset); + replacer.register_rule("(SFB)", sfb_ref); + replacer.register_rule("(SFB_offset)", sfb_offset); + tcgen05_call = replacer.rewrite(tcgen05_call); + this->stream << tcgen05_call; + } else if (op->op.same_as(tl::ptx_tcgen05_cp_warpx4())) { + ICHECK_EQ(op->args.size(), 3U) + << "ptx_tcgen05_cp_warpx4 expects 3 arguments"; + need_tcgen05_common_h_ = true; + // arg[0] = smem pointer, arg[1] = tmem data pointer, arg[2] = tmem column + // offset + std::string smem_ptr = this->PrintExpr(op->args[0]); + std::string tmem_ptr = this->PrintExpr(op->args[1]); + std::string tmem_col_offset = this->PrintExpr(op->args[2]); + bool use_2cta = false; + if (op->annotations.find("use_2cta") != op->annotations.end()) { + use_2cta = Downcast(op->annotations["use_2cta"])->value; + } + this->PrintIndent(); + this->stream << "tl::tcgen05_cp<" << (use_2cta ? "true" : "false") << ">(" + << "tl::make_sf_smem_desc(reinterpret_cast(" << smem_ptr + << ")), " + << "(*reinterpret_cast(" << tmem_ptr << ")) + " + << tmem_col_offset << ");\n"; + } else if (op->op.same_as(tl::ptx_tcgen05_sf_warp_transpose())) { + ICHECK_EQ(op->args.size(), 1U) + << "ptx_tcgen05_sf_warp_transpose expects 1 argument"; + need_tcgen05_common_h_ = true; + std::string smem_ptr = this->PrintExpr(op->args[0]); + this->PrintIndent(); + this->stream << "tl::tcgen05_sf_warp_transpose(reinterpret_cast(" + << smem_ptr << "));\n"; } else if (op->op.same_as(tl::tcgen05_ld())) { ICHECK_EQ(op->args.size(), 6U) << "tcgen05_ld expects 6 arguments"; need_tcgen05_common_h_ = true; diff --git a/src/tl_templates/cuda/instruction/tcgen05mma.h b/src/tl_templates/cuda/instruction/tcgen05mma.h index 2e1a8ec9da..7a8d831bfa 100644 --- a/src/tl_templates/cuda/instruction/tcgen05mma.h +++ b/src/tl_templates/cuda/instruction/tcgen05mma.h @@ -526,4 +526,82 @@ TL_DEVICE void tcgen05mma_ws_ss( desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3); } +// ============================================================================ +// Block-scaled MMA variants: +// tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale Used for MXFP8 +// block-scaled GEMM with scale factors in TMEM. +// ============================================================================ + +// Generic declaration: unsupported by default +template +TL_DEVICE void tcgen05mma_blockscaled_ss(uint64_t const & /*desc_a*/, + uint64_t const & /*desc_b*/, + uint32_t const & /*tmem_c*/, + uint32_t const & /*scalec*/, + uint32_t const & /*desc_val*/, + uint32_t const & /*tmem_sfa*/, + uint32_t const & /*tmem_sfb*/) { + static_assert( + always_false_v(C_type)>>, + "tl::tcgen05mma_blockscaled_ss: unsupported accumulator type"); +} + +// FP8 E4M3 block-scaled +template <> +TL_DEVICE void tcgen05mma_blockscaled_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, uint32_t const &tmem_sfa, + uint32_t const &tmem_sfb) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, " + "%3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +} + +template <> +TL_DEVICE void tcgen05mma_blockscaled_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, uint32_t const &tmem_sfa, + uint32_t const &tmem_sfb) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, " + "%3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +} + +// FP8 E5M2 maps to same instruction +template <> +TL_DEVICE void tcgen05mma_blockscaled_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, uint32_t const &tmem_sfa, + uint32_t const &tmem_sfb) { + tcgen05mma_blockscaled_ss( + desc_a, desc_b, tmem_c, scalec, desc_val, tmem_sfa, tmem_sfb); +} + +template <> +TL_DEVICE void tcgen05mma_blockscaled_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, uint32_t const &tmem_sfa, + uint32_t const &tmem_sfb) { + tcgen05mma_blockscaled_ss( + desc_a, desc_b, tmem_c, scalec, desc_val, tmem_sfa, tmem_sfb); +} + } // namespace tl diff --git a/src/tl_templates/cuda/tcgen_05.h b/src/tl_templates/cuda/tcgen_05.h index 700aa3ba7a..85e7e491db 100644 --- a/src/tl_templates/cuda/tcgen_05.h +++ b/src/tl_templates/cuda/tcgen_05.h @@ -86,4 +86,55 @@ TL_DEVICE void tcgen05_mma_arrive(void const *smem_ptr, } } +// UTCCP: Copy scale factors from shared memory to tensor memory. +// Must be called by one warp; only one elected thread issues the instruction. +template +TL_DEVICE void tcgen05_cp(uint64_t const &smem_desc, uint32_t const &tmem_col) { + if (cute::elect_one_sync()) { + if constexpr (use_2cta) { + asm volatile("tcgen05.cp.cta_group::2.32x128b.warpx4 [%0], %1;" + : + : "r"(tmem_col), "l"(smem_desc)); + } else { + asm volatile("tcgen05.cp.cta_group::1.32x128b.warpx4 [%0], %1;" + : + : "r"(tmem_col), "l"(smem_desc)); + } + } +} + +// Warp-level transpose of 128 uint32 elements in shared memory for UTCCP. +// Each warp (32 threads) transposes a 4x32 block in-place. +// Must be called by exactly one warp. Call __syncwarp() is embedded. +TL_DEVICE void tcgen05_sf_warp_transpose(uint32_t *smem_ptr) { + const uint32_t lane = threadIdx.x % 32; + uint32_t values[4]; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) + values[i] = smem_ptr[(i ^ (lane >> 3)) * 32 + lane]; + __syncwarp(); +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) + smem_ptr[lane * 4 + (i ^ (lane >> 3))] = values[i]; +} + +// Build a SMEM descriptor for UTCCP scale factor copy (no swizzle, K-major) +// SBO = 128 bytes (stride between atoms on MN), LBO = 0 (single K atom) +TL_DEVICE uint64_t make_sf_smem_desc(void *smem_ptr) { + uint32_t uint_ptr = smem_ptr_to_uint(smem_ptr); + // SmemDescriptor bit layout: + // [0,14): start_address >> 4 + // [16,30): leading_byte_offset >> 4 = 0 + // [32,46): stride_byte_offset >> 4 = 128/16 = 8 + // [46,48): version = 1 (SM100) + // [61,64): layout_type = 0 (SWIZZLE_NONE) + uint64_t desc = 0; + desc |= static_cast(uint_ptr >> 4) & 0x3FFFull; // start_address + // leading_byte_offset = 0 (bits [16,30)) + desc |= static_cast(8u) << 32; // stride_byte_offset >> 4 = 8 + desc |= static_cast(1u) << 46; // version = 1 + // layout_type = 0 (SWIZZLE_NONE), base_offset = 0, lbo_mode = 0 + return desc; +} + } // namespace tl diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index a8c4d79ce9..b275a6f76d 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -565,6 +565,246 @@ def _warp_mma_ts(a_data, B_buf, C_local_buf, mbar): return _warp_mma_ts(a_tmem_data, B_buf, C_local_buf, mbar) + def tcgen05mma_blockscaled( + self, + A_buf: Buffer, + B_buf: Buffer, + C_local_buf: Buffer, + SFA_tmem, + SFB_tmem, + mbar, + clear_accum: PrimExpr = False, + sf_a_id=0, + sf_b_id=0, + ): + """Emit a block-scaled TCGEN5MMA (SS variant with TMEM scale factors). + + Uses ``tcgen05.mma.cta_group::1|2.kind::mxf8f6f4.block_scale`` PTX instruction. + Scale factors must already reside in tensor memory. + """ + accum_dtype = self.accum_dtype + m_dim = self.block_row_warps * self.warp_row_tiles + micro_size_k = self.micro_size_k + k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles + scale_in_a = 1 + scale_in_b = 1 + + assert k_dim >= micro_size_k + + a_is_k_major = not self.a_transposed + b_is_k_major = self.b_transposed + a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout) + b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) + + elems_in_bits = DataType(self.a_dtype).bits + elems_in_bytes = elems_in_bits // 8 + accum_dtype_in_bits = DataType(accum_dtype).bits + + if len(self.meta) != 5: + self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim, disable_2cta=False) + if len(self.meta) != 5: + raise ValueError( + f"Unsupported TCGEN5MMA configuration for block-scaled: M={m_dim}, N={n_dim}, " + f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" + ) + atom_m, atom_n, atom_k, _enable_ws, enable_2cta = (int(x) for x in self.meta) + enable_ws = 0 + atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m + n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim + + a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes + b_swizzle_atom_elems = n_dim_per_cta if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * atom_m_per_cta * elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) + if not a_swizzle_mode.is_none(): + if a_is_k_major: + a_leading_byte_offset = 16 + a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size() + else: + a_m_axis_atoms = atom_m_per_cta // a_swizzle_atom_elems + a_leading_byte_offset = k_dim * a_swizzle_mode.swizzle_byte_size() if a_m_axis_atoms > 1 else 0 + a_stride_byte_offset = ( + 8 * elems_in_bytes * a_swizzle_atom_elems if a_m_axis_atoms > 1 else 8 * elems_in_bytes * atom_m_per_cta + ) + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim_per_cta * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim_per_cta == 8 else (8 * 8 * elems_in_bytes)) + if not b_swizzle_mode.is_none(): + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() + else: + b_n_axis_atoms = n_dim_per_cta // b_swizzle_atom_elems + b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size() * k_dim if b_n_axis_atoms > 1 else 0 + b_stride_byte_offset = ( + 8 * elems_in_bytes * b_swizzle_atom_elems if b_n_axis_atoms > 1 else 8 * elems_in_bytes * n_dim_per_cta + ) + + ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1) + bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) + + base_instr_desc = self.get_tcgen5_blockscaled_instr_desc( + atom_m, + atom_n, + a_is_k_major, + b_is_k_major, + scale_in_a, + scale_in_b, + 0, + 0, + ) + + a_dtype_abbrv = self.a_dtype_abbrv + num_inst_m = m_dim // atom_m_per_cta + num_inst_n = n_dim // atom_n + + def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"): + if isinstance(buffer_or_load_or_region, Buffer): + return buffer_or_load_or_region.access_ptr(access_type) + elif isinstance(buffer_or_load_or_region, BufferLoad): + buffer_load = buffer_or_load_or_region + offset, stride = 0, 1 + buffer = buffer_load.buffer + for i, shape in enumerate(reversed(buffer.shape)): + indice = buffer_load.indices[len(buffer_load.indices) - i - 1] + if isinstance(indice, tvm.tir.Ramp): + offset += indice.base * stride + elif isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)): + offset += indice * stride + else: + raise ValueError(f"Unsupported index type: {type(indice)}") + stride *= shape + return buffer.access_ptr(access_type, offset=offset) + elif isinstance(buffer_or_load_or_region, BufferRegion): + buffer_region = buffer_or_load_or_region + buffer = buffer_region.buffer + offset, stride = 0, 1 + for i, shape in enumerate(reversed(buffer.shape)): + offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride + stride *= shape + return buffer.access_ptr(access_type, offset=offset) + else: + raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}") + + if isinstance(SFA_tmem, BufferRegion): + sfa_data = SFA_tmem.buffer.data + elif isinstance(SFA_tmem, Buffer): + sfa_data = SFA_tmem.data + else: + raise ValueError(f"Unsupported SFA_tmem type: {type(SFA_tmem)}") + + if isinstance(SFB_tmem, BufferRegion): + sfb_data = SFB_tmem.buffer.data + elif isinstance(SFB_tmem, Buffer): + sfb_data = SFB_tmem.data + else: + raise ValueError(f"Unsupported SFB_tmem type: {type(SFB_tmem)}") + + @T.macro + def _warp_mma_blockscaled(A_buf, B_buf, C_local_buf, sfa_data, sfb_data, mbar): + desc_a = T.alloc_tcgen05_smem_desc() + desc_b = T.alloc_tcgen05_smem_desc() + A_ptr = access_ptr_from(A_buf, "r") + B_ptr = access_ptr_from(B_buf, "r") + + T.initialize_tcgen05_descriptor( + desc_a, + A_ptr, + int(a_leading_byte_offset >> 4), + int(a_stride_byte_offset >> 4), + 0, + False, + int(a_swizzle_mode), + ) + T.initialize_tcgen05_descriptor( + desc_b, + B_ptr, + int(b_leading_byte_offset >> 4), + int(b_stride_byte_offset >> 4), + 0, + False, + int(b_swizzle_mode), + ) + + tmem_col_step = atom_n // (128 // atom_m_per_cta) + _sf_a = tvm.tir.const(sf_a_id, "int32") if isinstance(sf_a_id, int) else sf_a_id + _sf_b = tvm.tir.const(sf_b_id, "int32") if isinstance(sf_b_id, int) else sf_b_id + runtime_instr_desc = base_instr_desc | (_sf_a << 29) | (_sf_b << 4) + for j in T.unroll(num_inst_n): + for i in T.unroll(num_inst_m): + for ki in T.unroll(0, (k_dim // micro_size_k)): + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + A_elem_offset = ( + (ki % ak_atom_size) * micro_size_k + + i * atom_m_per_cta * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + if a_is_k_major + else i * atom_m_per_cta * k_dim + ki * a_swizzle_atom_elems * micro_size_k + ) + B_elem_offset = ( + (ki // bk_atom_size) * n_dim_per_cta * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + j * atom_n * b_swizzle_atom_elems + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + + j * atom_n * (k_dim if n_dim_per_cta // b_swizzle_atom_elems > 1 else 1) + ) + ) + + A_byte_offset = A_elem_offset * elems_in_bytes + B_byte_offset = B_elem_offset * elems_in_bytes + C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32 + + T.ptx_tcgen05_mma_blockscaled_ss( + a_dtype_abbrv, + desc_a.data, + A_byte_offset, + desc_b.data, + B_byte_offset, + C_local_buf.data, + C_offset, + runtime_instr_desc, + scale_out, + sfa_data, + 0, + sfb_data, + 0, + 0, + 0, + enable_ws, + enable_2cta, + ) + T.tcgen05_mma_arrive(mbar, arrive_2cta=enable_2cta) + + return _warp_mma_blockscaled(A_buf, B_buf, C_local_buf, sfa_data, sfb_data, mbar) + + def get_tcgen5_blockscaled_instr_desc( + self, + atom_m: int, + atom_n: int, + a_is_k_major: bool, + b_is_k_major: bool, + scale_in_a: int, + scale_in_b: int, + a_sf_id: int, + b_sf_id: int, + ) -> PrimExpr: + """Build the block-scaled instruction descriptor via FFI.""" + desc = _ffi_api.get_tcgen5_blockscaled_instr_desc( + atom_m, + atom_n, + DataType(self.a_dtype), + a_is_k_major, + b_is_k_major, + scale_in_a, + scale_in_b, + a_sf_id, + b_sf_id, + ) + return lift(desc) + def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment: raise NotImplementedError diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 7dddc7a748..47b4ee037b 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -57,7 +57,13 @@ from tvm.script.parser.tir import allocate as allocate # noqa: F401 from .copy_op import copy, async_copy, tma_copy, transpose, c2d_im2col # noqa: F401 from tilelang.tileop.base import GemmWarpPolicy # noqa: F401 -from .gemm_op import gemm, wgmma_gemm, tcgen05_gemm # noqa: F401 +from .gemm_op import ( # noqa: F401 + gemm, + wgmma_gemm, + tcgen05_gemm, + tcgen05_gemm_blockscaled, + make_blockscaled_gemm_layout, +) from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401 from .fill_op import fill, clear # noqa: F401 from .reduce_op import ( diff --git a/tilelang/language/ast/ir.py b/tilelang/language/ast/ir.py index 49db2ca641..7865b74be7 100644 --- a/tilelang/language/ast/ir.py +++ b/tilelang/language/ast/ir.py @@ -1886,6 +1886,7 @@ def wrapped(*args, **kwargs): ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss) ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts) +ptx_tcgen05_mma_blockscaled_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_blockscaled_ss) ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) @@ -2138,6 +2139,7 @@ def wrapped(*args, **kwargs): "ptx_wgmma_ss", "ptx_wgmma_rs", "ptx_tcgen05_mma_ss", + "ptx_tcgen05_mma_blockscaled_ss", "ptx_ldmatrix", "ptx_cp_async", "ptx_cp_async_bulk", diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index da6bde0b37..21c8214b3a 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -2,15 +2,17 @@ from __future__ import annotations +import tvm.script.parser.tir as T from tilelang._typing import BufferLikeType, BufferLikeTypeTuple, BarrierType, DType from tilelang import tvm as tvm from tilelang.language import ptx_arrive_barrier, evaluate +from tilelang.language.eager.builder import macro from tilelang.language.kernel import get_thread_bindings, get_block_extents from tilelang.utils.target import check_hip_availability from tvm import DataType, tir from tvm.runtime import convert from tvm.tir import PrimExpr, Var, Call, BufferLoad, BufferRegion -from tilelang.utils.language import retrieve_ptr +from tilelang.utils.language import retrieve_ptr, get_buffer_region_from_load, retrieve_buffer_and_offset _IS_HIP_AVAILABLE = check_hip_availability() @@ -1233,6 +1235,81 @@ def tcgen05_after_thread_sync(): return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_after_thread_sync")) +def _tcgen05_num_smem_chunks(smem_src, chunk_elems: int): + if isinstance(smem_src, tir.Buffer): + shape = list(smem_src.shape) + elif isinstance(smem_src, tir.BufferRegion): + shape = [r.extent for r in smem_src.region] + elif isinstance(smem_src, tir.BufferLoad): + region = get_buffer_region_from_load(smem_src) + if region is None: + raise TypeError("T.tcgen05_cp_warpx4 requires Buffer/BufferRegion-like scale-factor sources.") + shape = [r.extent for r in region.region] + else: + raise TypeError(f"Unsupported scale-factor buffer type: {type(smem_src)}") + + total_elems = 1 + for extent in shape: + if not isinstance(extent, tir.IntImm): + raise ValueError("Packed scale-factor helpers require a static extent.") + total_elems *= extent.value + if total_elems % chunk_elems != 0: + raise ValueError(f"Packed scale-factor helpers require total extent to be a multiple of {chunk_elems}, got {total_elems}.") + return total_elems // chunk_elems + + +def tcgen05_cp_warpx4(smem_src, tmem_dst, tmem_col_offset=0, *, use_2cta: bool = False): + """Copy one or more packed scale-factor chunks from shared memory to tensor memory. + + The helper lowers to one or more ``tcgen05.cp.cta_group::{1,2}.32x128b.warpx4`` + instructions. For 1D packed ``uint32`` scale buffers, each 128-word chunk maps to + 4 TMEM columns and the column offset is advanced automatically. + """ + num_chunks = _tcgen05_num_smem_chunks(smem_src, 128) + if isinstance(tmem_dst, tir.Buffer): + tmem_ptr = tmem_dst.data + elif isinstance(tmem_dst, (BufferLoad, BufferRegion)): + tmem_ptr = tmem_dst.buffer.data + else: + tmem_ptr = tmem_dst + ann = {"use_2cta": 1} if use_2cta else None + buffer, base_offset = retrieve_buffer_and_offset(smem_src) + + @macro + def _tcgen05_cp_warpx4_chunked(buffer, tmem_ptr, tmem_col_offset, base_offset): + for i in T.unroll(num_chunks): + chunk_ptr = buffer.access_ptr("r", offset=base_offset + i * 128) + tir.call_intrin( + "void", + tir.op.Op.get("tl.ptx_tcgen05_cp_warpx4"), + chunk_ptr, + tmem_ptr, + tmem_col_offset + i * 4, + annotations=ann, + ) + + return _tcgen05_cp_warpx4_chunked(buffer, tmem_ptr, tmem_col_offset, base_offset) + + +def tcgen05_sf_warp_transpose(smem_src): + """Warp-level transpose for one or more packed scale-factor chunks in shared memory. + + For 1D packed ``uint32`` scale buffers, the helper automatically applies the + transpose to each 128-word chunk in order. + """ + num_chunks = _tcgen05_num_smem_chunks(smem_src, 128) + + buffer, base_offset = retrieve_buffer_and_offset(smem_src) + + @macro + def _tcgen05_sf_warp_transpose_chunked(buffer, base_offset): + for i in T.unroll(num_chunks): + chunk_ptr = buffer.access_ptr("rw", offset=base_offset + i * 128) + tir.call_intrin("void", tir.op.Op.get("tl.ptx_tcgen05_sf_warp_transpose"), chunk_ptr) + + return _tcgen05_sf_warp_transpose_chunked(buffer, base_offset) + + def ptx_mma_sm70( shape, A_layout, diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py index ed26d2e677..d5ec03728f 100644 --- a/tilelang/language/gemm_op.py +++ b/tilelang/language/gemm_op.py @@ -5,6 +5,7 @@ from tilelang._typing import BufferLikeType, BarrierType from tilelang.tileop.base import GemmWarpPolicy import tilelang.language as T +from tilelang.layout import Layout from tvm import tir from tilelang.utils.language import ( to_buffer_region, @@ -258,7 +259,9 @@ def tcgen05_gemm( compilation fails instead of silently falling back to another GEMM path. """ - ann = {"use_2cta": int(use_2cta)} if use_2cta else None + ann = {"is_tcgen05": 1} + if use_2cta: + ann["use_2cta"] = 1 return _gemm_impl( "tl.tileop.tcgen05_gemm", A, @@ -273,3 +276,208 @@ def tcgen05_gemm( mbar, annotations=ann, ) + + +def tcgen05_gemm_blockscaled( + A: BufferLikeType, + B: BufferLikeType, + C: BufferLikeType, + SFA_tmem: BufferLikeType, + SFB_tmem: BufferLikeType, + transpose_A: bool = False, + transpose_B: bool = False, + clear_accum=False, + wg_wait: int = 0, + mbar: BarrierType | None = None, + sf_a_id: int = 0, + sf_b_id: int = 0, + *, + use_2cta: bool = False, +) -> tir.PrimExpr: + """Explicit Blackwell TCGEN05 block-scaled GEMM without an implicit wait. + + This is the explicit asynchronous Blackwell TCGEN5MMA block-scaled + counterpart to `T.tcgen05_gemm(...)`. It never auto-emits an inlined + `mbarrier_wait_parity`, and compilation fails instead of silently falling + back if the requested ISA path is unavailable. + + With ``use_2cta=True``, this lowers to the true 2CTA block-scaled TCGEN05 + path only; there is no fallback or emulation. That mode requires + ``cluster_dims`` to be ``(2,1,1)`` or ``(1,2,1)``. + + A and B are FP8 (E4M3/E5M2) in shared memory, C is the accumulator in + tensor memory, and SFA/SFB are E8M0 scale factors already resident in + tensor memory. As with `T.tcgen05_gemm(...)`, this API is explicit-async: + it issues the MMA and leaves synchronization to the user schedule. + + Args: + A: FP8 input buffer A in shared memory. + B: FP8 input buffer B in shared memory. + C: Accumulator in tensor memory. + SFA_tmem: Scale factors for A in tensor memory. + SFB_tmem: Scale factors for B in tensor memory. + transpose_A: Whether A is MN-major. Default: False (K-major). + transpose_B: Whether B is K-major. Default: False (MN-major). + clear_accum: Whether to zero the accumulator. + wg_wait: Warp group wait identifier. + mbar: Mbarrier for MMA completion signaling. + sf_a_id: Scale factor ID for A (0-3). + sf_b_id: Scale factor ID for B (0-3). + use_2cta: Whether to request true ``cta_group::2`` lowering. + """ + + ann = {"use_2cta": int(use_2cta)} if use_2cta else None + + # Re-read normalized regions below after let legalization. + + def legalize(arg): + if isinstance(arg, tir.Var) and T.has_let_value(arg): + return T.get_let_value(arg).buffer + return arg + + A = legalize(A) + B = legalize(B) + C = legalize(C) + SFA_tmem = legalize(SFA_tmem) + SFB_tmem = legalize(SFB_tmem) + mbar = legalize(mbar) if mbar is not None else None + + A_region = to_buffer_region(A) + B_region = to_buffer_region(B) + C_region = to_buffer_region(C) + SFA_region = to_buffer_region(SFA_tmem) + SFB_region = to_buffer_region(SFB_tmem) + + A_shape = retrieve_shape(A_region) + B_shape = retrieve_shape(B_region) + C_shape = retrieve_shape(C_region) + + assert len(C_shape) == 2, "current only support C as a 2D tensor" + assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor" + assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" + + M, N = C_shape + M_A = A_shape[-1] if transpose_A else A_shape[-2] + N_B = B_shape[-2] if transpose_B else B_shape[-1] + K = A_shape[-2] if transpose_A else A_shape[-1] + K_B = B_shape[-1] if transpose_B else B_shape[-2] + assert prim_expr_equal(K, K_B), f"T.tcgen05_gemm_blockscaled K shape check failed: K_A = {K}, K_B = {K_B}" + if use_2cta: + assert prim_expr_equal(M_A, M) and prim_expr_equal(N_B * 2, N), ( + f"T.tcgen05_gemm_blockscaled 2CTA shape check failed: M_A = {M_A}, expected M_C = {M}; N_B = {N_B}, expected N_C / 2 = {N} / 2" + ) + else: + assert prim_expr_equal(N_B, N), f"T.tcgen05_gemm_blockscaled N shape check failed: N_B = {N_B}, N_C = {N}" + + A_stride = retrieve_stride(A_region) + B_stride = retrieve_stride(B_region) + stride_a = A_stride[-2] + stride_b = B_stride[-2] + + A_offset = retrieve_offset(A_region) + B_offset = retrieve_offset(B_region) + offset_a = A_offset[-1] + offset_b = B_offset[-1] + + if mbar is not None: + assert isinstance(mbar, (tir.Buffer, tir.BufferLoad)), ( + f"mbar for tcgen5mma must be a tir.Buffer or tir.BufferLoad, but got {type(mbar)}" + ) + mbar = to_buffer_region(mbar, access_type="rw") + + C_coords = [r.min for r in C_region.region] + + # Convert BufferRegion to tl.region calls for arguments + A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) + B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) + C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) + SFA_arg = buffer_region_to_tile_region(SFA_region, "r", list(retrieve_shape(SFA_region))) + SFB_arg = buffer_region_to_tile_region(SFB_region, "r", list(retrieve_shape(SFB_region))) + + assert mbar is not None, "mbar is required for tcgen05_gemm_blockscaled" + + # Ensure sf_a_id and sf_b_id are PrimExpr + if not isinstance(sf_a_id, tir.PrimExpr): + sf_a_id = tir.const(sf_a_id, dtype="int32") + if not isinstance(sf_b_id, tir.PrimExpr): + sf_b_id = tir.const(sf_b_id, dtype="int32") + + # Block-scaled always uses Square policy (1x1 warp partition) + policy = GemmWarpPolicy.Square + + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.gemm"), + A_arg, + B_arg, + C_arg, + transpose_A, + transpose_B, + M, + N, + K, + policy, + clear_accum, + stride_a, + stride_b, + offset_a, + offset_b, + 1, # k_pack + wg_wait, + mbar, + C_coords[0], + C_coords[1], + SFA_arg, # arg 19 + SFB_arg, # arg 20 + sf_a_id, # arg 21 + sf_b_id, # arg 22 + annotations=ann, + ) + + +def make_blockscaled_gemm_layout( + C: BufferLikeType, + A: BufferLikeType, + transpose_A: bool = False, +) -> Layout: + """Build the TMEM store layout for the C accumulator of a block-scaled GEMM. + + Users must call ``T.annotate_layout({C_tmem: layout})`` with the returned layout + so that subsequent ``T.copy(C_tmem, ...)`` can be lowered correctly. + + Args: + C: The TMEM accumulator buffer (block_M, block_N). + A: The FP8 operand A buffer (used to infer K and dtype). + transpose_A: Whether A is MN-major. + + Returns: + A Layout object for C's TMEM storage. + """ + from tilelang.intrinsics.tcgen05_macro_generator import TensorCoreIntrinEmitter + + C_region = to_buffer_region(C) + A_region = to_buffer_region(A) + + C_shape = retrieve_shape(C_region) + A_shape = retrieve_shape(A_region) + + M, N = int(C_shape[0]), int(C_shape[1]) + K = int(A_shape[-2] if transpose_A else A_shape[-1]) + a_dtype = str(A_region.buffer.dtype) + accum_dtype = str(C_region.buffer.dtype) + + emitter = TensorCoreIntrinEmitter( + a_dtype=a_dtype, + b_dtype=a_dtype, + accum_dtype=accum_dtype, + a_transposed=transpose_A, + b_transposed=False, + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=M, + warp_col_tiles=N, + chunk=K, + ) + + c_buf = C_region.buffer if isinstance(C_region, tir.BufferRegion) else C + return emitter.make_mma_store_layout(c_buf) diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index 6fab09ae51..384be21ccc 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -292,6 +292,7 @@ def wrapped(*args, **kwargs): ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss) ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts) +ptx_tcgen05_mma_blockscaled_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_blockscaled_ss) ptx_ldmatrix = _tir_op.ptx_ldmatrix ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index ca70d63673..0aee38da35 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1253,6 +1253,78 @@ def ptx_tcgen05_mma_ts( ) +def ptx_tcgen05_mma_blockscaled_ss( + kind_dtype, + desc_a, + A_offset, + desc_b, + B_offset, + C_ptr, + C_offset, + desc_val, + scale_out, + sfa_ptr, + sfa_offset, + sfb_ptr, + sfb_offset, + reserved0=0, + reserved1=0, + variant=False, + enable_2cta=False, +): + """TVM intrinsic for tcgen05.mma block-scaled (mxf8f6f4.block_scale) instructions. + + Block-scaled TCGEN05 is explicit-async and carries an explicit ``enable_2cta`` + flag, analogous to the regular SS/TS TCGEN05 intrinsics. There is no + fallback path if 2CTA is requested. + + Positional args: + kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset, + desc_val, scale_out, sfa_ptr, sfa_offset, sfb_ptr, sfb_offset, + reserved0, reserved1, enable_ws, enable_2cta. + """ + + if enable_2cta and isinstance(variant, str): + v_check = variant.lower() + if v_check in ("ws", "warp_specialized", "warp-specialized"): + raise ValueError("ptx_tcgen05_mma_blockscaled_ss: .ws and 2CTA cannot be combined") + elif enable_2cta and bool(variant): + raise ValueError("ptx_tcgen05_mma_blockscaled_ss: .ws and 2CTA cannot be combined") + + if isinstance(variant, str): + v = variant.lower() + if v in ("ws", "warp_specialized", "warp-specialized"): + enable_ws = True + elif v in ("default", "std", "ss"): + enable_ws = False + else: + raise ValueError(f"ptx_tcgen05_mma_blockscaled_ss: unknown variant: {variant}") + else: + enable_ws = bool(variant) + + return call_intrin( + "handle", + _tvm_op.Op.get("tl.ptx_tcgen05_mma_blockscaled_ss"), + kind_dtype, + desc_a, + A_offset, + desc_b, + B_offset, + C_ptr, + C_offset, + desc_val, + scale_out, + sfa_ptr, + sfa_offset, + sfb_ptr, + sfb_offset, + reserved0, + reserved1, + enable_ws, + enable_2cta, + ) + + def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): """TVM intrinsic for storing the result of PTX MMA into a destination pointer diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index a24545103c..a2290b0191 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -122,6 +122,14 @@ def wg_wait(self): def is_tcgen05(self): return getattr(self, "isTcgen05", False) + @property + def sf_a_id(self): + return self.sfAId + + @property + def sf_b_id(self): + return self.sfBId + def infer_layout(self, target: Target, thread_nums: int): """Infer the layout for the GEMM operation based on target architecture.""" gemm_inst = self._select_gemm_instruction(thread_nums, target) diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index 2e6d15b59b..95f69bbfa6 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -172,6 +172,26 @@ def C_coords(self): return [zero, zero] return [coords[i] for i in range(len(coords))] + @property + def SFARegion(self): + return getattr(self.gemm_node, "sfaRegion", None) + + @property + def SFBRegion(self): + return getattr(self.gemm_node, "sfbRegion", None) + + @property + def sf_a_id(self) -> PrimExpr: + return getattr(self.gemm_node, "sfAId", tvm.tir.const(0, T.int32)) + + @property + def sf_b_id(self) -> PrimExpr: + return getattr(self.gemm_node, "sfBId", tvm.tir.const(0, T.int32)) + + @property + def is_blockscaled(self) -> bool: + return self.SFARegion is not None and self.SFBRegion is not None + def get_region_base_offsets(self, region): """ Get the base offset (start index) for each dimension from a BufferRegion. diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index 4fa5815a51..f1e373ef9b 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -35,7 +35,8 @@ class GemmTCGEN5(GemmBase): """GEMM operator for Blackwell (SM100) TCGEN5MMA instructions. - Supports the SS (Shared-Shared) and TS (TensorMemory-Shared) variants. + Supports the SS (Shared-Shared) and TS (TensorMemory-Shared) variants, + as well as block-scaled MXFP8 GEMM when SFA/SFB scale factors are present. Layout inference and lowering are dispatched based on the memory scopes of operands A and B. """ @@ -57,8 +58,13 @@ def infer_layout(self, target: Target, thread_nums: int): For SS: both A and B get swizzled shared-memory layouts. For TS: A and C get TMEM store layouts, B gets a swizzled shared-memory layout. + For block-scaled: same as SS (A and B get swizzle, C gets TMEM store layout). """ - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.TCGEN5MMA) + # Block-scaled GEMM keeps a 1x1 warp partition even when using cta_group::2. + if self.is_blockscaled: + m_warp, n_warp = 1, 1 + else: + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.TCGEN5MMA) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -80,7 +86,7 @@ def infer_layout(self, target: Target, thread_nums: int): use_2cta = bool(annotations.get("use_2cta", 0)) mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K, disable_2cta=not use_2cta) - if self.is_gemm_ss(): + if self.is_blockscaled or self.is_gemm_ss(): a_continuity = self.K if a_is_k_major else self.M b_continuity = self.K if b_is_k_major else int(self.B.shape[-1]) # don't use N, as it may be for 2cta @@ -109,7 +115,11 @@ def lower( ): """Lower the GEMM tile-op into a TIR prim_func containing TCGEN5MMA calls.""" thread_nums = thread_bounds.extent - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.TCGEN5MMA) + # Block-scaled GEMM keeps a 1x1 warp partition even when using cta_group::2. + if self.is_blockscaled: + m_warp, n_warp = 1, 1 + else: + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.TCGEN5MMA) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -130,6 +140,9 @@ def lower( if self.B in layout_map: mma_emitter._assign_b_shared_layout(layout_map[self.B]) + if self.is_blockscaled: + return self._lower_blockscaled(mma_emitter, thread_bounds, thread_var, mbar_phase_expr) + if not (self.is_gemm_ss() or self.is_gemm_ts()): raise ValueError(f"TCGEN5MMA supports gemm_ss and gemm_ts, got A scope {self.A.scope()}, B scope {self.B.scope()}") @@ -196,3 +209,82 @@ def _gemm_ss() -> None: if analyzer.can_prove(thread_bounds.extent == warp_size) else _Simplify(_gemm_ss_cond, inline_let=True) ) + + def _lower_blockscaled(self, mma_emitter, thread_bounds, thread_var, mbar_phase_expr: tir.PrimExpr | None = None): + """Lower block-scaled MXFP8 GEMM to TIR. + + Block-scaled GEMM follows explicit-async TCGEN5MMA semantics: the MMA + issue posts completion to `mbar`, and the user (or pipeline pass) is + responsible for waiting on that barrier at the consumption point. We + therefore never auto-emit `mbarrier_wait_parity` here. This mirrors the + `is_tcgen05=True` branch of `_gemm_ss`. `mbar_phase_expr` is accepted + for API consistency with the rest of the `GemmPyNode.Lower` chain and + so that a future synchronous block-scaled path can use it without + needing another signature change. + """ + mbar = self.mbar + if mbar is None: + raise ValueError("Block-scaled GEMM requires a valid mbarrier") + mbarptr = retrieve_ptr(mbar, "rw") + + A_shared = self.ARegion + B_shared = self.BRegion + C_local = self.C + clear_accum = self.clear_accum + SFA_tmem = self.SFARegion.buffer + SFB_tmem = self.SFBRegion.buffer + sf_a_id = self.sf_a_id + sf_b_id = self.sf_b_id + # NOTE: mbar_phase_expr is intentionally unused in the current + # frontend, which always requests explicit-async semantics. Keep the + # parameter so the signature matches `_gemm_ss` and the call site in + # `lower()` does not need a special case. + del mbar_phase_expr + + annotations = getattr(self.gemm_node, "annotations", {}) + use_2cta = bool(annotations.get("use_2cta", 0)) + mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K, disable_2cta=not use_2cta) + _atom_m, _atom_n, _atom_k, _enable_ws, enable_2cta = (int(x) for x in mma_emitter.meta) + + analyzer = Analyzer() + warp_size = 32 + assert analyzer.can_prove(thread_bounds.min % warp_size == 0 and thread_bounds.extent % warp_size == 0), ( + "Block-scaled GEMM requires thread bounds aligned to warps." + ) + cluster_cond = not enable_2cta or T.block_rank_in_cluster() == 0 + + @T.prim_func + def _gemm_blockscaled_cond() -> None: + if cluster_cond and thread_var // 32 == thread_bounds.min // warp_size: + mma_emitter.tcgen05mma_blockscaled( + A_shared, + B_shared, + C_local, + SFA_tmem, + SFB_tmem, + mbarptr, + clear_accum, + sf_a_id, + sf_b_id, + ) + + @T.prim_func + def _gemm_blockscaled() -> None: + if cluster_cond: + mma_emitter.tcgen05mma_blockscaled( + A_shared, + B_shared, + C_local, + SFA_tmem, + SFB_tmem, + mbarptr, + clear_accum, + sf_a_id, + sf_b_id, + ) + + return ( + _Simplify(_gemm_blockscaled, inline_let=True) + if analyzer.can_prove(thread_bounds.extent == warp_size) + else _Simplify(_gemm_blockscaled_cond, inline_let=True) + ) diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 9a9fa217a6..53a5730ca2 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -335,6 +335,43 @@ def retrieve_ptr( raise ValueError(f"Unsupported retrieve_ptr argument type: {type(obj)} for object {obj}") +def retrieve_buffer_and_offset(obj: BufferLikeType) -> tuple[Buffer, PrimExpr | int]: + """ + Retrieve the underlying buffer together with its logical element offset. + + - Buffer -> (buffer, 0) + - BufferRegion -> (buffer, offset from region minima) + - BufferLoad -> (buffer, offset from indices or derived region minima) + + This is useful when callers need to build custom access patterns from a + common buffer base rather than materializing a full `access_ptr` directly. + """ + if isinstance(obj, tir.Buffer): + return obj, 0 + + if isinstance(obj, tir.BufferRegion): + buffer, region = obj.buffer, obj.region + strides = retrieve_stride(obj) + offset = 0 + for i, r in enumerate(region): + offset += r.min * strides[i] + return buffer, offset + + if isinstance(obj, tir.BufferLoad): + region = get_buffer_region_from_load(obj) + if region is not None: + return retrieve_buffer_and_offset(region) + + buffer = obj.buffer + strides = retrieve_stride(obj) + offset = 0 + for i, idx in enumerate(obj.indices): + offset += idx * strides[i] + return buffer, offset + + raise ValueError(f"Unsupported retrieve_buffer_and_offset argument type: {type(obj)} for object {obj}") + + def retrieve_offset(obj: BufferLikeType) -> list: """ Retrieve per-dimension minima offsets. From 057e5bad220e0e28f339c56574313eafb132344f Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Sat, 25 Apr 2026 14:11:24 +0800 Subject: [PATCH 142/156] [Host CodeGen][Refactor] Cleanup namespace and remove useless C templates (#2091) * [Host CodeGen][Refactor] Cleanup namespace and remove useless C templates * fix * fix cmake * fix log --- CMakeLists.txt | 4 +- src/target/{codegen_cpp.cc => codegen_c.cc} | 69 ++++++++++----------- src/target/{codegen_cpp.h => codegen_c.h} | 14 ++--- src/target/codegen_c_host.cc | 8 ++- src/target/codegen_c_host.h | 8 +-- src/target/{rt_mod_cpp.cc => rt_mod_c.cc} | 10 +-- src/tl_templates/cpp/gemm.h | 3 - tilelang/engine/lower.py | 4 +- 8 files changed, 60 insertions(+), 60 deletions(-) rename src/target/{codegen_cpp.cc => codegen_c.cc} (87%) rename src/target/{codegen_cpp.h => codegen_c.h} (94%) rename src/target/{rt_mod_cpp.cc => rt_mod_c.cc} (91%) delete mode 100644 src/tl_templates/cpp/gemm.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 51fc79a079..26af62a550 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -179,8 +179,8 @@ file(GLOB TILE_LANG_SRCS src/op/*.cc src/target/utils.cc src/target/codegen_c_host.cc - src/target/codegen_cpp.cc - src/target/rt_mod_cpp.cc + src/target/codegen_c.cc + src/target/rt_mod_c.cc # intrin_rule doesn't have system dependency src/target/intrin_rule*.cc ) diff --git a/src/target/codegen_cpp.cc b/src/target/codegen_c.cc similarity index 87% rename from src/target/codegen_cpp.cc rename to src/target/codegen_c.cc index 1fd6ec3026..e290d94144 100644 --- a/src/target/codegen_cpp.cc +++ b/src/target/codegen_c.cc @@ -18,9 +18,9 @@ */ /*! - * \file codegen_c_host.cc + * \file codegen_c.cc */ -#include "codegen_cpp.h" +#include "codegen_c.h" #include #include @@ -38,32 +38,31 @@ namespace tvm { namespace codegen { -CodeGenTileLangCPP::CodeGenTileLangCPP() { +CodeGenTileLangC::CodeGenTileLangC() { module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx"); } -void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts, - bool emit_fwd_func_decl, std::string target_str, - const std::unordered_set &devices) { +void CodeGenTileLangC::Init(bool output_ssa, bool emit_asserts, + bool emit_fwd_func_decl, std::string target_str, + const std::unordered_set &devices) { emit_asserts_ = emit_asserts; emit_fwd_func_decl_ = emit_fwd_func_decl; declared_globals_.clear(); decl_stream << "// tilelang target: " << target_str << "\n"; decl_stream << "#include \n"; - decl_stream << "#include \n"; decl_stream << "\n"; CodeGenC::Init(output_ssa); } -void CodeGenTileLangCPP::InitGlobalContext() { +void CodeGenTileLangC::InitGlobalContext() { decl_stream << "void* " << ffi::symbol::tvm_ffi_library_ctx << " = NULL;\n"; } -void CodeGenTileLangCPP::DefineModuleName() { +void CodeGenTileLangC::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } -void CodeGenTileLangCPP::GenerateForwardFunctionDeclarations( +void CodeGenTileLangC::GenerateForwardFunctionDeclarations( String global_symbol, const Array &arg_types, const Type &ret_type) { @@ -87,13 +86,13 @@ void CodeGenTileLangCPP::GenerateForwardFunctionDeclarations( fwd_decl_stream << ");\n"; } -void CodeGenTileLangCPP::PrintFuncPrefix(std::ostream &os) { // NOLINT(*) +void CodeGenTileLangC::PrintFuncPrefix(std::ostream &os) { // NOLINT(*) os << "#ifdef __cplusplus\n" << "extern \"C\"\n" << "#endif\n"; } -void CodeGenTileLangCPP::PrintType(DataType t, std::ostream &os) { // NOLINT(*) +void CodeGenTileLangC::PrintType(DataType t, std::ostream &os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { ICHECK_EQ(lanes, 1) << "does not support vector types"; @@ -164,8 +163,8 @@ void CodeGenTileLangCPP::PrintType(DataType t, std::ostream &os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to C type"; } -void CodeGenTileLangCPP::VisitExpr_(const BroadcastNode *op, - std::ostream &os) { // NOLINT(*) +void CodeGenTileLangC::VisitExpr_(const BroadcastNode *op, + std::ostream &os) { // NOLINT(*) std::string v = PrintExpr(op->value); int lanes = op->dtype.lanes(); os << "(("; @@ -179,7 +178,7 @@ void CodeGenTileLangCPP::VisitExpr_(const BroadcastNode *op, os << "))"; } -void CodeGenTileLangCPP::PrintGetFuncFromBackend( +void CodeGenTileLangC::PrintGetFuncFromBackend( const std::string &func_name, const std::string &packed_func_name) { this->PrintIndent(); this->stream << "if (" << packed_func_name << " == NULL) {\n"; @@ -199,8 +198,8 @@ void CodeGenTileLangCPP::PrintGetFuncFromBackend( this->stream << "}\n"; } -void CodeGenTileLangCPP::PrintFuncCall(const std::string &packed_func_name, - int num_args) { +void CodeGenTileLangC::PrintFuncCall(const std::string &packed_func_name, + int num_args) { this->PrintIndent(); std::string ret_val = name_supply_->FreshName("ret_val"); std::string ret_type_code = name_supply_->FreshName("ret_type_code"); @@ -223,9 +222,9 @@ void CodeGenTileLangCPP::PrintFuncCall(const std::string &packed_func_name, this->stream << "}\n"; } -void CodeGenTileLangCPP::PrintFuncCallC( - const std::string &packed_func_name, int num_args, - const std::string &resource_handle_name) { +void CodeGenTileLangC::PrintFuncCallC(const std::string &packed_func_name, + int num_args, + const std::string &resource_handle_name) { this->PrintIndent(); std::string ret_val = name_supply_->FreshName("ret_val"); std::string ret_type_code = name_supply_->FreshName("ret_type_code"); @@ -251,7 +250,7 @@ void CodeGenTileLangCPP::PrintFuncCallC( this->stream << "}\n"; } -void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) { +void CodeGenTileLangC::AddFunction(const PrimFunc &f) { // clear previous generated state. this->InitFuncState(f); // reserve keywords @@ -318,7 +317,7 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) { this->stream << "}\n\n"; } -std::string CodeGenTileLangCPP::GetPackedName(const CallNode *op) { +std::string CodeGenTileLangC::GetPackedName(const CallNode *op) { const StringImmNode *s = op->args[0].as(); ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; @@ -336,9 +335,9 @@ std::string CodeGenTileLangCPP::GetPackedName(const CallNode *op) { return unique_name; } -CodeGenTileLangCPP::FunctionInfo -CodeGenTileLangCPP::GetFunctionInfo(const CallNode *op, - bool has_resource_handle) { +CodeGenTileLangC::FunctionInfo +CodeGenTileLangC::GetFunctionInfo(const CallNode *op, + bool has_resource_handle) { const StringImmNode *s = op->args[0].as(); ICHECK(s != nullptr) << "tvm_call_[c]packed_lowered expects first argument as function name"; @@ -379,8 +378,8 @@ CodeGenTileLangCPP::GetFunctionInfo(const CallNode *op, return {func_name, num_args, "NULL"}; } -void CodeGenTileLangCPP::VisitExpr_(const CallNode *op, - std::ostream &os) { // NOLINT(*) +void CodeGenTileLangC::VisitExpr_(const CallNode *op, + std::ostream &os) { // NOLINT(*) if (op->op.same_as(builtin::tvm_stack_alloca())) { std::string stack_name = name_supply_->FreshName("stack"); const std::string &type = op->args[0].as()->value; @@ -425,7 +424,7 @@ void CodeGenTileLangCPP::VisitExpr_(const CallNode *op, } } -void CodeGenTileLangCPP::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) +void CodeGenTileLangC::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) if (emit_asserts_) { std::string cond = PrintExpr(op->condition); PrintIndent(); @@ -443,7 +442,7 @@ void CodeGenTileLangCPP::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) this->PrintStmt(op->body); } -void CodeGenTileLangCPP::VisitStmt_(const AllocateNode *op) { +void CodeGenTileLangC::VisitStmt_(const AllocateNode *op) { ICHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); @@ -461,20 +460,20 @@ void CodeGenTileLangCPP::VisitStmt_(const AllocateNode *op) { this->PrintStmt(op->body); } -void CodeGenTileLangCPP::VisitExpr_(const MinNode *op, - std::ostream &os) { // NOLINT(*) +void CodeGenTileLangC::VisitExpr_(const MinNode *op, + std::ostream &os) { // NOLINT(*) PrintTernaryCondExpr(op, "<", os); } -void CodeGenTileLangCPP::VisitExpr_(const MaxNode *op, - std::ostream &os) { // NOLINT(*) +void CodeGenTileLangC::VisitExpr_(const MaxNode *op, + std::ostream &os) { // NOLINT(*) PrintTernaryCondExpr(op, ">", os); } template inline void -CodeGenTileLangCPP::PrintTernaryCondExpr(const T *op, const char *compare, - std::ostream &os) { // NOLINT(*) +CodeGenTileLangC::PrintTernaryCondExpr(const T *op, const char *compare, + std::ostream &os) { // NOLINT(*) std::ostringstream temp_a; VisitExpr(op->a, temp_a); std::string a_id = SSAGetID(temp_a.str(), op->a.dtype()); diff --git a/src/target/codegen_cpp.h b/src/target/codegen_c.h similarity index 94% rename from src/target/codegen_cpp.h rename to src/target/codegen_c.h index 25bb115c82..64bff21b66 100644 --- a/src/target/codegen_cpp.h +++ b/src/target/codegen_c.h @@ -18,11 +18,11 @@ */ /*! - * \file codegen_c_host.h - * \brief Generate C host code. + * \file codegen_c.h + * \brief Generate C code when target is c (CPU). */ -#ifndef TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ -#define TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ +#ifndef TVM_TL_CODEGEN_C_H_ +#define TVM_TL_CODEGEN_C_H_ #include #include @@ -37,9 +37,9 @@ namespace tvm { namespace codegen { -class CodeGenTileLangCPP : public CodeGenC { +class CodeGenTileLangC : public CodeGenC { public: - CodeGenTileLangCPP(); + CodeGenTileLangC(); void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str, const std::unordered_set &devices); @@ -122,4 +122,4 @@ class CodeGenTileLangCPP : public CodeGenC { } // namespace codegen } // namespace tvm -#endif // TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ +#endif // TVM_TL_CODEGEN_C_H_ diff --git a/src/target/codegen_c_host.cc b/src/target/codegen_c_host.cc index 6af6bb9725..aaf1a452fe 100644 --- a/src/target/codegen_c_host.cc +++ b/src/target/codegen_c_host.cc @@ -503,6 +503,10 @@ using tvm::ffi::String; // Build function that mirrors TVM's host C codegen, registered under a // TileLang-specific name. +// NOTE(chaofan): Different from codegen_c / BuildTileLangC, this CodeGen class +// is only used to generate C host code for TileLang when the host codegen is +// enabled. If you use TileLang to generate CPU code (in this case, C is the +// device code) , it will be generated by BuildTileLangC. ::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod, ::tvm::Target target) { bool output_ssa = false; @@ -534,7 +538,7 @@ ::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod, std::vector> funcs; for (auto [gvar, base_func] : mod->functions) { ICHECK(base_func->IsInstance<::tvm::tir::PrimFuncNode>()) - << "CodegenCHost: Can only take PrimFunc"; + << "TileLangCodegenCHost: Can only take PrimFunc"; auto prim_func = ::tvm::Downcast<::tvm::tir::PrimFunc>(base_func); funcs.push_back({gvar, prim_func}); } @@ -565,7 +569,7 @@ ::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("target.build.tilelang_c", BuildTileLangCHost); + refl::GlobalDef().def("target.build.tilelang_c_host", BuildTileLangCHost); } } // namespace tl diff --git a/src/target/codegen_c_host.h b/src/target/codegen_c_host.h index 1644246d74..345caab88a 100644 --- a/src/target/codegen_c_host.h +++ b/src/target/codegen_c_host.h @@ -19,10 +19,10 @@ /*! * \file codegen_c_host.h - * \brief Generate C host code (TileLang copy). + * \brief Generate C host code with TVM FFI when Host CodeGen is enabled. */ -#ifndef TL_TARGET_SOURCE_CODEGEN_C_HOST_H_ -#define TL_TARGET_SOURCE_CODEGEN_C_HOST_H_ +#ifndef TVM_TL_CODEGEN_C_HOST_H_ +#define TVM_TL_CODEGEN_C_HOST_H_ #include #include @@ -130,4 +130,4 @@ class CodeGenCHost : public tvm::codegen::CodeGenC { } // namespace tl } // namespace tvm -#endif // TL_TARGET_SOURCE_CODEGEN_C_HOST_H_ +#endif // TVM_TL_CODEGEN_C_HOST_H_ diff --git a/src/target/rt_mod_cpp.cc b/src/target/rt_mod_c.cc similarity index 91% rename from src/target/rt_mod_cpp.cc rename to src/target/rt_mod_c.cc index 10e3d57b6a..bc652b1e2b 100644 --- a/src/target/rt_mod_cpp.cc +++ b/src/target/rt_mod_c.cc @@ -1,4 +1,4 @@ -#include "codegen_cpp.h" +#include "codegen_c.h" #include #include @@ -7,7 +7,7 @@ namespace tvm { namespace codegen { -ffi::Module BuildCPPHost(IRModule mod, Target target) { +ffi::Module BuildTileLangC(IRModule mod, Target target) { bool output_ssa = false; bool emit_asserts = false; bool emit_fwd_func_decl = true; @@ -21,7 +21,7 @@ ffi::Module BuildCPPHost(IRModule mod, Target target) { } } - CodeGenTileLangCPP cg; + CodeGenTileLangC cg; cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); cg.SetConstantsByteAlignment( target->GetAttr("constants-byte-alignment").value_or(16)); @@ -33,7 +33,7 @@ ffi::Module BuildCPPHost(IRModule mod, Target target) { std::vector> funcs; for (auto [gvar, base_func] : mod->functions) { ICHECK(base_func->IsInstance()) - << "CodegenCHost: Can only take PrimFunc"; + << "BuildTileLangC: Can only take PrimFunc"; auto prim_func = Downcast(base_func); funcs.push_back({gvar, prim_func}); } @@ -72,7 +72,7 @@ ffi::Module BuildCPPHost(IRModule mod, Target target) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost); + refl::GlobalDef().def("target.build.tilelang_c", BuildTileLangC); } } // namespace codegen diff --git a/src/tl_templates/cpp/gemm.h b/src/tl_templates/cpp/gemm.h deleted file mode 100644 index 1d8fbb7e2a..0000000000 --- a/src/tl_templates/cpp/gemm.h +++ /dev/null @@ -1,3 +0,0 @@ -#pragma once - -// Not Implemented diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index b39a58c6fb..e900890300 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -218,7 +218,7 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target, target: Target | N if target_host.kind.name == "llvm": host_mod = tvm.ffi.get_global_func("target.build.llvm")(host_mod, target_host) elif target_host.kind.name == "c": - host_mod = tvm.ffi.get_global_func("target.build.tilelang_c")(host_mod, target_host) + host_mod = tvm.ffi.get_global_func("target.build.tilelang_c_host")(host_mod, target_host) else: raise ValueError(f"Target host {target_host.kind.name} is not supported") return host_mod @@ -255,7 +255,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> elif target.kind.name == "hip": device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target) elif target.kind.name == "c": - device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.tilelang_c")(device_mod, target) elif target.kind.name == "llvm": device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target) elif target.kind.name == "webgpu": From 3f16e5043e26953aa6f198ec942bd37a70c39017 Mon Sep 17 00:00:00 2001 From: foraxe <73625538+foraxe@users.noreply.github.com> Date: Sat, 25 Apr 2026 14:51:30 +0800 Subject: [PATCH 143/156] Add opt-out for prelower semantic checks for DeepSeek V4 Flash on ARM64 (#2094) * Add opt-out for prelower semantic checks * unify code style --------- Co-authored-by: Foraxe Co-authored-by: SiriusNEO --- tilelang/engine/phase.py | 10 ++++++++++ tilelang/transform/pass_config.py | 3 +++ 2 files changed, 13 insertions(+) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index b26ae38a6e..5563845214 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -80,6 +80,13 @@ def should_enable_race_check(pass_ctx: PassContext | None = None) -> bool: return enabled +def should_enable_prelower_semantic_check(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + enabled = not pass_ctx.config.get(tilelang.PassConfigKey.TL_DISABLE_PRELOWER_SEMANTIC_CHECK, False) + return enabled + + def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() @@ -122,6 +129,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None: Note: This is a validation-only pipeline of passes and does not modify or return the module. """ + if not should_enable_prelower_semantic_check(): + return + # Print AST for debugging purpose if should_enable_ast_print(): tilelang.analysis.ASTPrinter()(mod) diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 872d0d3484..d50f155975 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -52,6 +52,9 @@ class PassConfigKey(str, Enum): TL_DISABLE_DATA_RACE_CHECK = "tl.disable_data_race_check" """Disable data race check in TileLang. Default: False""" + TL_DISABLE_PRELOWER_SEMANTIC_CHECK = "tl.disable_prelower_semantic_check" + """Disable Python-side pre-lower semantic checks. Default: False""" + TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized" """Disable warp specialization optimization. Default: False""" From 0ee634559ee87b06c7acc5426cf7ba8abe2abf18 Mon Sep 17 00:00:00 2001 From: Yufei Xu Date: Sat, 25 Apr 2026 20:44:50 +0800 Subject: [PATCH 144/156] [Example] Add HISA: hierarchical sparse attention indexer (#2069) * add hisa prefill kernels and pipeline * pass lint check * refactor into eagerjit style --------- Co-authored-by: yufeixu --- examples/dsa_hisa/README.md | 200 +++++++++++ examples/dsa_hisa/block_sparse_mqa_fp8.py | 269 +++++++++++++++ .../dsa_hisa/clean_and_maintain_logits.py | 121 +++++++ examples/dsa_hisa/fp8_block_mean_pooling.py | 146 ++++++++ examples/dsa_hisa/hisa.py | 240 +++++++++++++ examples/dsa_hisa/pool_mqa_fp8.py | 257 ++++++++++++++ examples/dsa_hisa/tilelang_utils.py | 314 ++++++++++++++++++ 7 files changed, 1547 insertions(+) create mode 100644 examples/dsa_hisa/README.md create mode 100644 examples/dsa_hisa/block_sparse_mqa_fp8.py create mode 100644 examples/dsa_hisa/clean_and_maintain_logits.py create mode 100644 examples/dsa_hisa/fp8_block_mean_pooling.py create mode 100644 examples/dsa_hisa/hisa.py create mode 100644 examples/dsa_hisa/pool_mqa_fp8.py create mode 100644 examples/dsa_hisa/tilelang_utils.py diff --git a/examples/dsa_hisa/README.md b/examples/dsa_hisa/README.md new file mode 100644 index 0000000000..5eb0f72a8a --- /dev/null +++ b/examples/dsa_hisa/README.md @@ -0,0 +1,200 @@ +# tilelang_kernels — hisa prefill pipeline + +Tilelang prefill implementation of **hisa** (HIerarchical Sparse Attention). +Paper: . + +## What is HISA? + +HISA optimizes DeepSeek sparse attention by a plug-and-play replacement +for the indexer that rewrites the search path from a flat token scan into +a two-stage hierarchical procedure. + +**Stage 1 — coarse block-level selection.** Group K tokens into pool blocks +of `k_block_size` tokens, mean-pool each block, then score each query +against all pool blocks and pick the top `block_topk` blocks per query. + +**Stage 2 — fine-grained token-level scoring.** For each query, run a +full-resolution MQA over the raw tokens inside its selected blocks, then +pick the top `topk_tokens` tokens per query. + +## Files + +| file | step | role | +|---|---|---| +| `fp8_block_mean_pooling.py` | 1.1 | Mean-pool raw K into pool blocks (fp8 + per-block f32 scale) | +| `pool_mqa_fp8.py` | 1.2 | fp8×fp8 score `Q · pooled_K` → one logit per (query, pool block) | +| `clean_and_maintain_logits.py` | 1.3 | In-place mask on stage-1 logits: -inf outside per-query range, +inf at first/last valid block | +| `block_sparse_mqa_fp8.py` | 2.1 | fp8×fp8 fine-grained score over the raw tokens of the `block_topk` selected blocks | +| `hisa.py` | — | End-to-end orchestration: all four kernels + the two `torch.topk` steps + the index-translation post-processing | + +Each per-kernel file has one `test_*` entry that (a) runs the kernel + +torch ref, (b) asserts via `torch.testing.assert_close`, (c) prints the +latency of the kernel. `hisa.py` has `test_hisa` that runs the full +pipeline, checks the output-index mask invariant, and prints end-to-end +latency. + +## Per-kernel reference + +### 1.1 `fp8_block_mean_pooling.py` + +**Function**: `fp8_native_block_mean_pooling` + +**Meaning**: flat per-block mean of the chunk's K tokens, re-quantized to +fp8 with a per-block f32 scale. Groups `N` K tokens into +`ceildiv(N, k_block_size)` pool blocks. + +**Interface**: +```python +blocked_k, blocked_k_scale = fp8_native_block_mean_pooling_interface( + k, # [N, D] fp8 + k_scale, # [N] f32 — per-token scale from indexer_k_quant_and_cache + k_block_size, +) +# blocked_k: [num_blocks, D] fp8 +# blocked_k_scale: [num_blocks] f32 +``` + +**What it does**: per pool block `b` of size `kb = k_block_size`, +1. dequantize each of the `kb` tokens: `k_f[i] = k_fp8[i] * k_scale[i]` +2. average across the block in f32: `mean = sum_i k_f[i] / kb` (or the + actual valid count for the ragged tail block) +3. re-quantize the f32 mean to fp8 with a per-block scale + `block_scale = max(max_abs(mean) / 448, 1e-10)`, writing + `blocked_k[b] = fp8(mean / block_scale)` and `blocked_k_scale[b] = block_scale`. + +### 1.2 `pool_mqa_fp8.py` + +**Function**: `pool_mqa_attn_return_logits_fp8` + +**Meaning**: coarse-grained fp8 multi-query attention over the **pooled** K +(one vector per pool block). Produces one logit per (query, pool-block). + +**Interface**: +```python +block_k_score = pool_mqa_attn_return_logits_fp8_interface( + q_fp8, # [M, H, D] fp8 + blocked_kv_fp8, # [Nb, D] fp8 (from step 1.1) + blocked_kv_scale, # [Nb] f32 (from step 1.1) + weights_f32, # [M, H] f32 + cu_seqlen_blocked_ks, # [M] int32 — per-query start in pool-block coords + cu_seqlen_blocked_ke, # [M] int32 — per-query end in pool-block coords +) +# block_k_score: [M, Nb] f32 +``` + +**What it does**: for each query `m` and each pool block `n` in +`[cu_seqlen_blocked_ks[m], cu_seqlen_blocked_ke[m])`, +``` +block_k_score[m, n] = sum_h ReLU(q[m, h] · blocked_k[n]) * blocked_k_scale[n] * weights[m, h] +``` +Uses tile-level fp8×fp8→f32 Tensor Core GEMM; the per-block scale is +applied post-GEMM. The kernel processes queries in tiles of size +`block_Q × block_N` and **writes the union of the tile's queries' visible +K ranges** — entries outside an individual query's range inside that +union still carry raw dot-product values (they will be masked by +step 1.3 next). Entries outside the tile union are left at their +zero-init value. + +### 1.3 `clean_and_maintain_logits.py` + +**Function**: `clean_and_maintain_logits_` + +**Meaning**: in-place post-kernel mask on the stage-1 logits. + +**Interface**: +```python +clean_and_maintain_logits_interface( + logits, # [M, Nb] f32 — stage-1 output; modified in place + cu_seqlen_ks, # [M] int32 — per-row start (inclusive) + cu_seqlen_ke, # [M] int32 — per-row end (exclusive) +) +``` + +**What it does**: for each row `m`, +- positions outside `[cu_seqlen_ks[m], cu_seqlen_ke[m])` → set to `-inf` + (so `torch.topk` ignores them), +- positions `cu_seqlen_ks[m]` and `cu_seqlen_ke[m] - 1` → set to `+inf` + (force-maintain the boundary blocks: they are always picked by the + subsequent top-block selection — a standard hisa trick to preserve + sink and local blocks). + +### 2.1 `block_sparse_mqa_fp8.py` + +**Function**: `fp8_native_block_sparse_mqa_attn_return_logits` + +**Meaning**: fine-grained fp8 MQA over only the **raw K tokens** inside the +top-`block_topk` pool blocks selected per query. Two kernel variants are +auto-dispatched by the factory: +- general (`kv_block_size > block_N`): pipelined sub-block inner loop +- small-pooling-size (`kv_block_size == block_N`): single pass, no pipeline + +**Interface**: +```python +block_sparse_logits = fp8_native_block_sparse_mqa_attn_return_logits_interface( + q, # [M, H, D] fp8 + k, # [N, D] fp8 + k_scale, # [N] f32 + topk_block_index, # [M, block_topk] int64 — from torch.topk over stage-1 scores + kv_block_size, # == k_block_size + weights, # [M, H] f32 + cu_seqlen_ks, # [M] int32 — per-query K start (absolute, in raw tokens) + cu_seqlen_ke, # [M] int32 — per-query K end +) +# block_sparse_logits: [M, block_topk * kv_block_size] f32 +``` + +**What it does**: for each query `m`, for each selected block +`t ∈ [0, block_topk)` with `blk = topk_block_index[m, t]`, for each +in-block offset `i ∈ [0, kv_block_size)`, +``` +k_abs = blk * kv_block_size + i +if k_abs ∉ [cu_seqlen_ks[m], cu_seqlen_ke[m]) or k_abs >= N: + block_sparse_logits[m, t * kv_block_size + i] = -inf +else: + block_sparse_logits[m, t * kv_block_size + i] = + sum_h ReLU(q[m, h] · k[k_abs]) * k_scale[k_abs] * weights[m, h] +``` +The out-of-range mask is written directly by this kernel — no separate +mask pass is needed here (unlike stage 1). + +### End-to-end `hisa.py` + +**Function**: `hisa_indexer` + +**Meaning**: single entry point that runs the full pipeline below. + +**Interface**: +```python +topk_indices = hisa_indexer( + q, # [M, H, D] fp8 + k, # [N, D] fp8 + k_scale, # [N] f32 + weights, # [M, H] f32 + cu_seqlen_ks, # [M] int32 — per-query K start + cu_seqlen_ke, # [M] int32 — per-query K end + *, + k_block_size, # pool block size (=128 in DeepSeek-V3.2) + block_topk, # number of top pool blocks kept per query + topk_tokens, # final top-k size handed to the sparse attention +) +# topk_indices: [M, topk_tokens] int32 — each row is the query's top-k K +# positions expressed as offsets within its own [cu_ks, cu_ke) window. +# Out-of-range slots are -1. +``` + +**Pipeline**: + +``` +(1.1) fp8_native_block_mean_pooling K, k_scale → blocked_k, blocked_k_scale +(1.2) pool_mqa_attn_return_logits_fp8 Q × blocked_k → block_k_score[M, Nb] +(1.3) clean_and_maintain_logits in-place mask (-inf/+inf) on block_k_score +(1.4) torch.topk(block_k_score.bfloat16(), → topk_block_indices[M, block_topk] int64 + k=block_topk, sorted=False) +(2.1) fp8_native_block_sparse_mqa_… Q × K[selected] → block_sparse_logits + [M, block_topk * k_block_size] +(2.2) torch.topk(block_sparse_logits, → relevant_topk_indices[M, topk_tokens] int64 + k=topk_tokens) +(2.3) (Python) gather topk_block_indices + → absolute K positions, then subtract + arith + subtract cu_seqlen_ks + mask cu_seqlen_ks for per-query-relative offsets + → topk_indices[M, topk_tokens] int32 +``` diff --git a/examples/dsa_hisa/block_sparse_mqa_fp8.py b/examples/dsa_hisa/block_sparse_mqa_fp8.py new file mode 100644 index 0000000000..10f95bb204 --- /dev/null +++ b/examples/dsa_hisa/block_sparse_mqa_fp8.py @@ -0,0 +1,269 @@ +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench +import torch + +from tilelang_utils import prepare_ks_ke_from_cu_seqlens + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def fp8_native_block_sparse_mqa_attn_return_logits( + IndexQ, + IndexK, + IndexKScale, + TopKBlockIndex, + Weights, + CuSeqLenKS, + CuSeqLenKE, + heads: int = 64, + index_dim: int = 128, + kv_block_size: int = 128, + topk: int = 64, + block_N: int = 128, + num_stages: int = 1, + threads: int = 256, +): + fp8_dtype = T.float8_e4m3fn + accum_dtype = T.float32 + index_dtype = T.int32 + topk_index_dtype = T.int64 + + seq_len, seq_len_kv = T.const("seq_len, seq_len_kv") + + H_per_block = heads + block_N = min(block_N, kv_block_size // 2) + assert kv_block_size % block_N == 0, "block_N must divide kv_block_size" + + IndexQ: T.Tensor[[seq_len * heads, index_dim], fp8_dtype] + IndexK: T.Tensor[[seq_len_kv, index_dim], fp8_dtype] + IndexKScale: T.Tensor[[seq_len_kv], accum_dtype] + TopKBlockIndex: T.Tensor[[seq_len, topk], topk_index_dtype] + Weights: T.Tensor[[seq_len, heads], accum_dtype] + CuSeqLenKS: T.Tensor[[seq_len], index_dtype] + CuSeqLenKE: T.Tensor[[seq_len], index_dtype] + + Logits = T.empty((seq_len, topk * kv_block_size), accum_dtype) + + with T.Kernel(seq_len, threads=threads) as bx: + index_q_shared = T.alloc_shared([H_per_block, index_dim], fp8_dtype) + index_k_shared = T.alloc_shared([block_N, index_dim], fp8_dtype) + # Shared (zero-init'd) — see note in the hisa source about serial-topk + # loop making shared slightly faster than fragment here. + scale_shared = T.alloc_shared([block_N], accum_dtype) + + s = T.alloc_fragment([block_N, H_per_block], accum_dtype) + s_reshaped = T.reshape(s, (block_N, H_per_block // heads, heads)) + logits = T.alloc_fragment([block_N, H_per_block // heads], accum_dtype) + weights = T.alloc_fragment([H_per_block // heads, heads], accum_dtype) + + seq_len_i = bx + + cu_k_s_min = CuSeqLenKS[seq_len_i] + cu_k_e_max = CuSeqLenKE[seq_len_i] + + T.copy(IndexQ[seq_len_i * heads : seq_len_i * heads + H_per_block, :], index_q_shared) + T.copy(Weights[seq_len_i, :], weights) + + for n_i in T.serial(topk): + topk_block_id = T.cast(TopKBlockIndex[seq_len_i, n_i], index_dtype) + block_s = topk_block_id * kv_block_size + for b_i in T.Pipelined(kv_block_size // block_N, num_stages=num_stages): + block_s_i = block_s + b_i * block_N + + T.copy(IndexK[block_s_i : block_s_i + block_N, :], index_k_shared) + for bn_i in T.Parallel(block_N): + scale_shared[bn_i] = IndexKScale[block_s_i + bn_i] + + T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullRow, + ) + + for bn_i, bq_i, h_i in T.Parallel(block_N, H_per_block // heads, heads): + s_reshaped[bn_i, bq_i, h_i] = T.max(s_reshaped[bn_i, bq_i, h_i] * scale_shared[bn_i], 0) * weights[bq_i, h_i] + + T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + + for i_i in T.Parallel(block_N): + k_i = block_s_i + i_i + if k_i < cu_k_s_min or k_i >= cu_k_e_max: + logits[i_i, 0] = -T.infinity(accum_dtype) + + for bn_i in T.Parallel(block_N): + Logits[seq_len_i, n_i * kv_block_size + b_i * block_N + bn_i] = logits[bn_i, 0] + + return Logits + + +def fp8_native_block_sparse_mqa_attn_return_logits_interface( + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + topk_block_index: torch.Tensor, + kv_block_size: int, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + seq_len, heads, index_dim = q.shape + topk = topk_block_index.shape[1] + logits = fp8_native_block_sparse_mqa_attn_return_logits( + q.view(seq_len * heads, index_dim), + k, + k_scale, + topk_block_index, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + heads=heads, + index_dim=index_dim, + kv_block_size=kv_block_size, + topk=topk, + ) + return logits + + +def ref_fp8_block_sparse_mqa( + q_fp8: torch.Tensor, + k_fp8: torch.Tensor, + k_scale: torch.Tensor, + topk_block_index: torch.Tensor, + kv_block_size: int, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + M, H, D = q_fp8.shape + N = k_fp8.shape[0] + topk = topk_block_index.shape[1] + + block_starts = topk_block_index.long() * kv_block_size # [M, topk] + pos_in_block = torch.arange(kv_block_size, device=q_fp8.device) + k_abs = block_starts[..., None] + pos_in_block[None, None, :] # [M, topk, B] + k_safe = k_abs.clamp(0, N - 1) + + q_f = q_fp8.float() + k_f = k_fp8.float() * k_scale[:, None] + gathered_k = k_f[k_safe.flatten()].reshape(M, topk, kv_block_size, D) + + s = torch.einsum("mhd,mtid->mtih", q_f, gathered_k) # [M, topk, B, H] + logits = (s.clamp(min=0) * weights[:, None, None, :]).sum(dim=-1) # [M, topk, B] + + in_range = (k_abs >= cu_seqlen_ks.long()[:, None, None]) & (k_abs < cu_seqlen_ke.long()[:, None, None]) & (k_abs < N) + logits = logits.masked_fill(~in_range, float("-inf")) + return logits.reshape(M, topk * kv_block_size) + + +def test_fp8_block_sparse_mqa( + M: int = 1024, + H: int = 64, + D: int = 128, + kv_block_size: int = 128, + topk: int = 64, + num_seqs: int = 1, +): + """Correctness + speed test packing `num_seqs` equal-length causal + sequences into the [M, H, D] Q and [M, D] K tensors. Each query sees + only the prefix of its own sequence (``cu_ks = start_of_seq``, + ``cu_ke = start_of_seq + position_in_seq + 1``). + + ``topk_block_index`` is drawn at random from [0, num_k_blocks) — some + picks will point to blocks outside the query's own sequence; those + positions get -inf via the kernel's built-in mask, and the torch ref + produces the same -inf. Comparison checks both the +/-inf mask + pattern (exact) and the finite values (fp8 tolerance).""" + torch.manual_seed(0) + assert M % num_seqs == 0, f"M ({M}) must be divisible by num_seqs ({num_seqs})" + N = M # causal self-attention prefill, packed + + per_seq = M // num_seqs + cu_seqlens = torch.arange(num_seqs + 1, device="cuda", dtype=torch.long) * per_seq + ks_long, ke_long = prepare_ks_ke_from_cu_seqlens(cu_seqlens) + cu_ks = ks_long.to(torch.int32).contiguous() + cu_ke = ke_long.to(torch.int32).contiguous() + + q_bf16 = torch.randn(M, H, D, device="cuda", dtype=torch.bfloat16) + q = q_bf16.to(torch.float8_e4m3fn) + k_bf16 = torch.randn(N, D, device="cuda", dtype=torch.bfloat16) + k = k_bf16.to(torch.float8_e4m3fn) + k_scale = (0.1 + 0.01 * torch.rand(N, device="cuda", dtype=torch.float32)).contiguous() + weights = torch.randn(M, H, device="cuda", dtype=torch.float32) + + # Random per-query top-k blocks (distinct indices drawn from [0, num_blocks)). + num_k_blocks = (N + kv_block_size - 1) // kv_block_size + topk = min(topk, num_k_blocks) + g = torch.Generator(device="cuda").manual_seed(42) + topk_block_index = torch.stack([torch.randperm(num_k_blocks, generator=g, device="cuda")[:topk] for _ in range(M)]).to(torch.int64) + + # Correctness. + got = fp8_native_block_sparse_mqa_attn_return_logits_interface( + q, + k, + k_scale, + topk_block_index, + kv_block_size, + weights, + cu_ks, + cu_ke, + ) + ref = ref_fp8_block_sparse_mqa( + q, + k, + k_scale, + topk_block_index, + kv_block_size, + weights, + cu_ks, + cu_ke, + ) + # The kernel marks out-of-range as -inf. Compare finite positions only — + # the -inf mask pattern must agree exactly, so we also check that. + finite = torch.isfinite(got) & torch.isfinite(ref) + assert torch.equal(torch.isposinf(got), torch.isposinf(ref)), "pos-inf mask differs" + assert torch.equal(torch.isneginf(got), torch.isneginf(ref)), "neg-inf mask differs" + torch.testing.assert_close(got[finite], ref[finite], rtol=1e-1, atol=2e-1) + print(f" correctness: PASS (M={M}, H={H}, D={D}, kv_block_size={kv_block_size}, topk={topk}, num_seqs={num_seqs}, per_seq={per_seq})") + + # Speed. + def fn(): + return fp8_native_block_sparse_mqa_attn_return_logits_interface( + q, + k, + k_scale, + topk_block_index, + kv_block_size, + weights, + cu_ks, + cu_ke, + ) + + ms = do_bench(fn, warmup=50, rep=200) + # FLOPs: M × topk × kv_block_size × H × D (fp8×fp8) × 2 (mul+add). + total_flops = 2 * M * topk * kv_block_size * H * D + tflops = total_flops / (ms * 1e-3) / 1e12 + print(f" latency: {ms:.4f} ms ({tflops:.2f} fp8 TFLOPS)") + + +if __name__ == "__main__": + # Ref path materialises [M, topk, B, D] fp32 gathered_k which is ~M GB at + # topk=64, kv_block_size=128, D=128. Keep M modest to avoid OOM. + # (M, H, D, kv_block_size, topk, num_seqs) + for cfg in [ + (1024, 64, 128, 128, 64, 1), + (4096, 64, 128, 128, 64, 1), + (4096, 64, 128, 128, 64, 4), + (8192, 64, 128, 128, 64, 1), + (8192, 64, 128, 128, 64, 8), + (8192, 64, 128, 64, 128, 8), + (8192, 64, 128, 256, 32, 8), + ]: + test_fp8_block_sparse_mqa(*cfg) + torch.cuda.empty_cache() diff --git a/examples/dsa_hisa/clean_and_maintain_logits.py b/examples/dsa_hisa/clean_and_maintain_logits.py new file mode 100644 index 0000000000..12ff8a4c10 --- /dev/null +++ b/examples/dsa_hisa/clean_and_maintain_logits.py @@ -0,0 +1,121 @@ +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench +import torch + +from tilelang_utils import prepare_ks_ke_from_cu_seqlens + + +@tilelang.jit +def clean_and_maintain_logits_( + Logits, + CuSeqLenKS, + CuSeqLenKE, + threads: int = 512, + block_K: int = 4096, +): + seq_len, seq_len_kv = T.const("seq_len, seq_len_kv") + + dtype = T.float + indices_dtype = T.int32 + + Logits: T.Tensor[[seq_len, seq_len_kv], dtype] + CuSeqLenKS: T.Tensor[[seq_len], indices_dtype] + CuSeqLenKE: T.Tensor[[seq_len], indices_dtype] + + with T.Kernel(seq_len, threads=threads) as bx: + tx = T.thread_binding(0, threads, thread="threadIdx.x") + cu_k_s = CuSeqLenKS[bx] + cu_k_e = CuSeqLenKE[bx] + + for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): + for k_i in T.serial(block_K // threads): + idx = n_i * block_K + k_i * threads + tx + if idx == cu_k_s or idx == cu_k_e - 1: + Logits[bx, idx] = T.infinity(dtype) + if idx < cu_k_s or idx >= cu_k_e: + Logits[bx, idx] = -T.infinity(dtype) + + +def clean_and_maintain_logits_interface( + logits: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + """In-place: applies +inf/-inf mask based on per-row [ks, ke).""" + clean_and_maintain_logits_(logits, cu_seqlen_ks, cu_seqlen_ke) + return logits + + +def ref_clean_and_maintain_logits( + logits: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Pure torch equivalent. Returns a new tensor (doesn't mutate the input).""" + M, N = logits.shape + out = logits.clone() + n = torch.arange(N, device=logits.device)[None, :] + mask_out = (n < cu_seqlen_ks.long()[:, None]) | (n >= cu_seqlen_ke.long()[:, None]) + out = out.masked_fill(mask_out, float("-inf")) + m_idx = torch.arange(M, device=logits.device) + out[m_idx, cu_seqlen_ks.long()] = float("inf") + out[m_idx, (cu_seqlen_ke - 1).clamp(min=0).long()] = float("inf") + return out + + +def test_clean_and_maintain_logits(M: int = 4096, N: int = 4096, num_seqs: int = 1): + """Correctness + speed test where `M` query rows are packed from + `num_seqs` equal-length causal sequences. Per-row ``cu_ks / cu_ke`` + is derived from ``prepare_ks_ke_from_cu_seqlens`` so each row sees + only the prefix of its own sequence (causal self-attention).""" + torch.manual_seed(0) + assert M % num_seqs == 0, f"M ({M}) must be divisible by num_seqs ({num_seqs})" + assert (M // num_seqs) <= N, "N must accommodate the longest sequence" + + per_seq = M // num_seqs + cu_seqlens = torch.arange(num_seqs + 1, device="cuda", dtype=torch.long) * per_seq + ks_long, ke_long = prepare_ks_ke_from_cu_seqlens(cu_seqlens) + cu_ks = ks_long.to(torch.int32).contiguous() + cu_ke = ke_long.to(torch.int32).clamp(max=N).contiguous() + + logits_init = torch.randn(M, N, device="cuda", dtype=torch.float32) + + # Run kernel in place on a copy. + got = logits_init.clone() + clean_and_maintain_logits_interface(got, cu_ks, cu_ke) + + # Ref. + ref = ref_clean_and_maintain_logits(logits_init, cu_ks, cu_ke) + + # Exact equality: this kernel only writes +/-inf, other positions untouched + # (ref clones the input and does the same). Compare directly. + assert torch.equal(torch.isposinf(got), torch.isposinf(ref)), "pos-inf mask differs" + assert torch.equal(torch.isneginf(got), torch.isneginf(ref)), "neg-inf mask differs" + finite = torch.isfinite(got) & torch.isfinite(ref) + torch.testing.assert_close(got[finite], ref[finite], rtol=0.0, atol=0.0) + print(f" correctness: PASS (M={M}, N={N}, num_seqs={num_seqs}, per_seq={per_seq})") + + # Speed. + def fn(): + logits = torch.randn(M, N, device="cuda", dtype=torch.float32) # fresh copy each iter + clean_and_maintain_logits_interface(logits, cu_ks, cu_ke) + return logits + + ms = do_bench(fn, warmup=50, rep=200) + # ~2 reads + 1 write of [M, N] f32, but mostly no-op except at mask boundaries. + bytes_moved = 2 * M * N * 4 + gbps = bytes_moved / (ms * 1e-3) / 1e9 + print(f" latency: {ms:.4f} ms ({gbps:.1f} GB/s)") + + +if __name__ == "__main__": + # (M, N, num_seqs) + for cfg in [ + (4096, 4096, 1), + (4096, 4096, 4), + (16384, 16384, 1), + (16384, 16384, 8), + (65536, 65536, 16), + ]: + test_clean_and_maintain_logits(*cfg) diff --git a/examples/dsa_hisa/fp8_block_mean_pooling.py b/examples/dsa_hisa/fp8_block_mean_pooling.py new file mode 100644 index 0000000000..1c9f90cc4c --- /dev/null +++ b/examples/dsa_hisa/fp8_block_mean_pooling.py @@ -0,0 +1,146 @@ +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench +import torch + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def fp8_native_block_mean_pooling( + K, + KScale, + dim: int = 128, + pooling_block_size: int = 128, + block_N: int = 64, + num_stages: int = 1, + threads: int = 256, +): + dtype = T.float8_e4m3fn + accum_dtype = T.float32 + FP8_MAX_INV = 1.0 / 448.0 + + seq_len_k = T.const("seq_len_k") + + K: T.Tensor[[seq_len_k, dim], dtype] + KScale: T.Tensor[[seq_len_k], accum_dtype] + + num_blocks = T.ceildiv(seq_len_k, pooling_block_size) + BlockedK = T.empty((num_blocks, dim), dtype) + BlockedKScale = T.empty((num_blocks,), accum_dtype) + + with T.Kernel(num_blocks, threads=threads) as bx: + index_k = T.alloc_fragment([block_N, dim], dtype) + scale = T.alloc_fragment([block_N], accum_dtype) + acc = T.alloc_fragment([dim], accum_dtype) + max_abs = T.alloc_fragment([1], accum_dtype) + T.fill(acc, 0.0) + + k_start = bx * pooling_block_size + k_end = T.min(k_start + pooling_block_size, seq_len_k) + cur_pooling_block_size = k_end - k_start + + for b_i in T.serial(T.ceildiv(cur_pooling_block_size, block_N)): + T.fill(index_k, 0.0) + + tl_block_s = k_start + b_i * block_N + tl_block_e = T.min(k_start + (b_i + 1) * block_N, k_end) + T.copy(K[tl_block_s : tl_block_s + block_N, :], index_k) + for bn_i in T.Parallel(block_N): + scale[bn_i] = KScale[tl_block_s + bn_i] + + for bn_i, d_i in T.Parallel(block_N, dim): + index_k[bn_i, d_i] = index_k[bn_i, d_i] * scale[bn_i] + + cur_tl_block_size = tl_block_e - tl_block_s + for n_i in T.parallel(block_N): + for d_i in T.parallel(dim): + if n_i >= cur_tl_block_size: + index_k[n_i, d_i] = T.cast(0, accum_dtype) + + T.reduce_sum(index_k, acc, dim=0, clear=False) + + inv_count = T.cast(1.0, accum_dtype) / T.cast(cur_pooling_block_size, accum_dtype) + for d_i in T.Parallel(dim): + acc[d_i] = acc[d_i] * inv_count + + # Re-quantize f32 mean to fp8 with a per-block scale. + T.reduce_absmax(acc, max_abs, dim=0, clear=True) + block_scale = T.max(max_abs[0] * T.cast(FP8_MAX_INV, accum_dtype), T.cast(1e-10, accum_dtype)) + inv_block_scale = T.cast(1.0, accum_dtype) / block_scale + + for d_i in T.Parallel(dim): + BlockedK[bx, d_i] = T.cast(acc[d_i] * inv_block_scale, dtype) + BlockedKScale[bx] = block_scale + + return BlockedK, BlockedKScale + + +def fp8_native_block_mean_pooling_interface(k: torch.Tensor, k_scale: torch.Tensor, k_block_size: int): + return fp8_native_block_mean_pooling(k, k_scale, dim=k.shape[1], pooling_block_size=k_block_size) + + +def ref_fp8_block_mean_pooling(k_fp8: torch.Tensor, k_scale: torch.Tensor, k_block_size: int) -> torch.Tensor: + """Spec: per-token dequant + per-block mean (dividing by actual valid count). + Returns the f32 mean (caller can compare against fp8*scale re-quant of the kernel).""" + N, D = k_fp8.shape + dequant = k_fp8.float() * k_scale[:, None] + num_blocks = (N + k_block_size - 1) // k_block_size + out = torch.empty(num_blocks, D, device=k_fp8.device, dtype=torch.float32) + for b in range(num_blocks): + s = b * k_block_size + e = min(s + k_block_size, N) + out[b] = dequant[s:e].sum(dim=0) / (e - s) + return out + + +def test_fp8_block_mean_pooling(N: int = 16384, D: int = 128, k_block_size: int = 128, num_seqs: int = 1): + """Correctness + speed test with `num_seqs` sequences of equal length + packed into the flat K buffer. + + NOTE: the flat mean-pool kernel is sequence-agnostic — it pools every + `k_block_size` consecutive tokens regardless of sequence boundaries. + `num_seqs` is accepted here for API consistency with the other kernels' + tests; it affects how `cu_seqlens` is laid out (shown for illustration) + but not the kernel's inputs / outputs. + """ + torch.manual_seed(0) + assert N % num_seqs == 0, f"N ({N}) must be divisible by num_seqs ({num_seqs})" + per_seq = N // num_seqs + + k_bf16 = torch.randn(N, D, device="cuda", dtype=torch.bfloat16) + k = k_bf16.to(torch.float8_e4m3fn) + k_scale = (0.1 + 0.01 * torch.rand(N, device="cuda", dtype=torch.float32)).contiguous() + + # Correctness. + blocked_k_fp8, blocked_k_scale = fp8_native_block_mean_pooling_interface(k, k_scale, k_block_size) + got = blocked_k_fp8.float() * blocked_k_scale[:, None] + ref = ref_fp8_block_mean_pooling(k, k_scale, k_block_size) + # fp8 re-quant: ~1/256 rel error on top of bf16-level precision. + torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-3) + print(f" correctness: PASS (N={N}, D={D}, k_block_size={k_block_size}, num_seqs={num_seqs}, per_seq={per_seq})") + + # Speed. + def fn(): + return fp8_native_block_mean_pooling_interface(k, k_scale, k_block_size) + + ms = do_bench(fn, warmup=50, rep=200) + num_blocks = (N + k_block_size - 1) // k_block_size + # Bytes moved: read N * D fp8 (K) + N * 4 f32 (scale) + write num_blocks * D fp8 + num_blocks * 4 f32. + bytes_moved = N * D + N * 4 + num_blocks * D + num_blocks * 4 + gbps = bytes_moved / (ms * 1e-3) / 1e9 + print(f" latency: {ms:.4f} ms ({gbps:.1f} GB/s)") + + +if __name__ == "__main__": + # (N, D, k_block_size, num_seqs) + for cfg in [ + (16384, 128, 128, 1), + (16384, 128, 128, 4), + (65536, 128, 128, 1), + (65536, 128, 128, 8), + (131072, 128, 128, 16), + ]: + test_fp8_block_mean_pooling(*cfg) diff --git a/examples/dsa_hisa/hisa.py b/examples/dsa_hisa/hisa.py new file mode 100644 index 0000000000..e9863874e1 --- /dev/null +++ b/examples/dsa_hisa/hisa.py @@ -0,0 +1,240 @@ +import torch +from tilelang.profiler import do_bench + +from fp8_block_mean_pooling import fp8_native_block_mean_pooling_interface +from pool_mqa_fp8 import pool_mqa_attn_return_logits_fp8_interface +from block_sparse_mqa_fp8 import fp8_native_block_sparse_mqa_attn_return_logits_interface +from clean_and_maintain_logits import clean_and_maintain_logits_interface +from tilelang_utils import prepare_ks_ke_from_cu_seqlens + + +def hisa_indexer( + q: torch.Tensor, # [M, H, D] fp8_e4m3fn + k: torch.Tensor, # [N, D] fp8_e4m3fn + k_scale: torch.Tensor, # [N] f32 + weights: torch.Tensor, # [M, H] f32 + cu_seqlen_ks: torch.Tensor, # [M] int32 — per-query K start (inclusive) + cu_seqlen_ke: torch.Tensor, # [M] int32 — per-query K end (exclusive) + *, + k_block_size: int, + block_topk: int, + topk_tokens: int, +) -> torch.Tensor: + """Run the full hisa prefill pipeline. + + Returns: ``[M, topk_tokens]`` int32 — each row is this query's top + ``topk_tokens`` K positions, expressed as offsets relative to + ``cu_seqlen_ks[m]`` (so ``0`` means the query's own K start). Slots + that fell outside ``[cu_seqlen_ks[m], cu_seqlen_ke[m])`` get ``-1``. + """ + # ------------------------------------------------------------------ + # Stage 0: fp8 mean-pool over K. Groups K into pool blocks of + # k_block_size tokens each; outputs one fp8 vector + f32 scale per + # pool block. Grid = (ceil(N/k_block_size),). + # ------------------------------------------------------------------ + blocked_k_fp8, blocked_k_scale = fp8_native_block_mean_pooling_interface( + k, + k_scale, + k_block_size, + ) # [Nb, D] fp8, [Nb] f32 + + # Translate the per-query K range from flat-token coords to + # pool-block coords (floor for start, ceil for end). + cu_seqlen_blocked_ks = cu_seqlen_ks // k_block_size + cu_seqlen_blocked_ke = (cu_seqlen_ke + k_block_size - 1) // k_block_size + + # ------------------------------------------------------------------ + # Stage 1: block-level Q·BlockedK score with ReLU + per-head weight + # reduction. Output is dense (kernel doesn't mask out-of-range). + # ------------------------------------------------------------------ + block_k_score = pool_mqa_attn_return_logits_fp8_interface( + q, + blocked_k_fp8, + blocked_k_scale, + weights, + cu_seqlen_blocked_ks, + cu_seqlen_blocked_ke, + ) # [M, Nb] f32 + + # Mask out-of-range entries to -inf and force +inf on first / last + # valid block so torch.topk picks the boundary blocks. + clean_and_maintain_logits_interface( + block_k_score, + cu_seqlen_blocked_ks, + cu_seqlen_blocked_ke, + ) + + # ------------------------------------------------------------------ + # Stage 1.5: top-block_topk selection. bfloat16 + sorted=False is + # ~40% faster than f32 and the downstream sparse_mqa doesn't rely + # on order. + # ------------------------------------------------------------------ + block_topk_eff = min(block_topk, block_k_score.shape[-1]) + topk_block_indices = torch.topk( + block_k_score.bfloat16(), + k=block_topk_eff, + dim=-1, + sorted=False, + ).indices # [M, block_topk_eff] int64 + + # ------------------------------------------------------------------ + # Stage 2: fp8 fine-grained Q·K MQA over only the selected + # blocks' raw tokens (block_topk_eff blocks × k_block_size tokens + # per query). The kernel writes -inf for positions outside + # [cu_seqlen_ks[m], cu_seqlen_ke[m]). + # ------------------------------------------------------------------ + block_sparse_logits = fp8_native_block_sparse_mqa_attn_return_logits_interface( + q, + k, + k_scale, + topk_block_indices, + k_block_size, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + ) # [M, block_topk_eff * k_block_size] f32 + + # ------------------------------------------------------------------ + # Stage 2.5: top-topk_tokens selection over the block_topk_eff + # × k_block_size candidate tokens. Gives per-query slot ids. + # ------------------------------------------------------------------ + topk_tokens_eff = min(topk_tokens, block_sparse_logits.shape[-1]) + relevant_topk_indices = torch.topk( + block_sparse_logits, + k=topk_tokens_eff, + dim=-1, + ).indices # [M, topk_tokens_eff] int64 + + # ------------------------------------------------------------------ + # Stage 3 (post, Python): translate slot ids → absolute K token + # position → per-query relative offset (matches vLLM indexer + # output buffer). Slots whose relative offset falls outside the + # query's visible range are set to -1. + # ------------------------------------------------------------------ + # slot = block_id_in_topk × k_block_size + offset_in_block + # where block_id_in_topk ∈ [0, block_topk_eff) + # absolute_k = topk_block_indices[m, block_id_in_topk] × k_block_size + offset_in_block + absolute_topk_block_indices = torch.gather( + topk_block_indices, + dim=-1, + index=(relevant_topk_indices // k_block_size), + ) + topk_indices = absolute_topk_block_indices * k_block_size + (relevant_topk_indices % k_block_size) + topk_indices = topk_indices.to(torch.int32) + + # Relative to this query's K start. + topk_indices -= cu_seqlen_ks[:, None] + mask_lo = topk_indices >= 0 + mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0 + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + + return topk_indices + + +def test_hisa( + M: int = 1024, + H: int = 64, + D: int = 128, + k_block_size: int = 128, + block_topk: int = 8, + topk_tokens: int = 256, + num_seqs: int = 1, +): + """End-to-end smoke + speed test packing `num_seqs` equal-length causal + sequences into the flat [M, H, D] Q and [N=M, D] K tensors. + + Per-token ``cu_ks / cu_ke`` are produced by + ``prepare_ks_ke_from_cu_seqlens`` so each query sees only the prefix + of its own sequence. Validity checks are done per-query (so each + sequence's tail queries have fewer valid candidate slots). + """ + torch.manual_seed(0) + assert M % num_seqs == 0, f"M ({M}) must be divisible by num_seqs ({num_seqs})" + per_seq = M // num_seqs + N = M # causal self-attention, packed + + cu_seqlens = torch.arange(num_seqs + 1, device="cuda", dtype=torch.long) * per_seq + ks_long, ke_long = prepare_ks_ke_from_cu_seqlens(cu_seqlens) + cu_ks = ks_long.to(torch.int32).contiguous() + cu_ke = ke_long.to(torch.int32).contiguous() + + q_bf16 = torch.randn(M, H, D, device="cuda", dtype=torch.bfloat16) + q = q_bf16.to(torch.float8_e4m3fn) + k_bf16 = torch.randn(N, D, device="cuda", dtype=torch.bfloat16) + k = k_bf16.to(torch.float8_e4m3fn) + k_scale = (0.1 + 0.01 * torch.rand(N, device="cuda", dtype=torch.float32)).contiguous() + weights = torch.randn(M, H, device="cuda", dtype=torch.float32) + + topk_indices = hisa_indexer( + q, + k, + k_scale, + weights, + cu_ks, + cu_ke, + k_block_size=k_block_size, + block_topk=block_topk, + topk_tokens=topk_tokens, + ) + + # Sanity checks. + assert topk_indices.shape == (M, topk_tokens), f"unexpected output shape {tuple(topk_indices.shape)}" + assert topk_indices.dtype == torch.int32 + + # Every non-(-1) offset must be within [0, cu_ke[m] - cu_ks[m]). + valid = topk_indices >= 0 + spans = (cu_ke - cu_ks)[:, None].expand_as(topk_indices) + in_range = topk_indices < spans + assert (valid == (valid & in_range)).all(), "some valid offset falls outside its query's K window" + + # Per-query expected number of valid slots = min(cu_ke[m] - cu_ks[m], + # topk_tokens) (clipped by K range and by block_topk × k_block_size). + expected_valid = torch.minimum( + (cu_ke - cu_ks).clamp(min=0), + torch.tensor(min(topk_tokens, block_topk * k_block_size), device=cu_ke.device), + ) + got_valid = valid.sum(dim=-1).to(torch.int32) + frac_match = (got_valid == expected_valid).float().mean().item() + print( + f" shape: {tuple(topk_indices.shape)} " + f"valid_frac: {valid.float().mean().item():.4f} " + f"per-query valid count match: {frac_match:.4f} " + f"(num_seqs={num_seqs}, per_seq={per_seq})" + ) + + # Speed. + def fn(): + return hisa_indexer( + q, + k, + k_scale, + weights, + cu_ks, + cu_ke, + k_block_size=k_block_size, + block_topk=block_topk, + topk_tokens=topk_tokens, + ) + + ms = do_bench(fn, warmup=20, rep=50) + print( + f" latency: {ms:.3f} ms " + f"(M={M}, H={H}, D={D}, k_block_size={k_block_size}, " + f"block_topk={block_topk}, topk_tokens={topk_tokens}, num_seqs={num_seqs})" + ) + + +if __name__ == "__main__": + # Ref path in block_sparse_mqa materialises [M, topk, kvB, D] fp32 so + # stay modest on M (reuse the sparse_mqa module's sizing intuition). + for cfg in [ + dict(M=1024, H=64, D=128, k_block_size=128, block_topk=16, topk_tokens=256, num_seqs=1), + dict(M=1024, H=64, D=128, k_block_size=128, block_topk=16, topk_tokens=256, num_seqs=4), + dict(M=4096, H=64, D=128, k_block_size=128, block_topk=32, topk_tokens=1024, num_seqs=1), + dict(M=4096, H=64, D=128, k_block_size=128, block_topk=32, topk_tokens=1024, num_seqs=4), + dict(M=8192, H=64, D=128, k_block_size=128, block_topk=64, topk_tokens=2048, num_seqs=1), + dict(M=8192, H=64, D=128, k_block_size=128, block_topk=64, topk_tokens=2048, num_seqs=8), + ]: + test_hisa(**cfg) + torch.cuda.empty_cache() diff --git a/examples/dsa_hisa/pool_mqa_fp8.py b/examples/dsa_hisa/pool_mqa_fp8.py new file mode 100644 index 0000000000..515b311ac4 --- /dev/null +++ b/examples/dsa_hisa/pool_mqa_fp8.py @@ -0,0 +1,257 @@ +"""Stage-1 kernel: prefill pool-MQA over pooled (blocked) K. + +Input: fp8 Q ``[M, H, D]`` + fp8 BlockedK ``[Nb, D]`` + per-block f32 scale +``[Nb]`` + f32 Weights ``[M, H]`` + per-query ``cu_seqlen_blocked_ks/ke [M]``. + +For each query ``m`` and pool block ``n`` in ``[cu_seqlen_blocked_ks[m], +cu_seqlen_blocked_ke[m])``: + ``logits[m, n] = sum_h ReLU(Q[m, h] . BlockedK[n]) * BlockedKScale[n] * Weights[m, h]`` + +Out-of-range entries in the raw kernel output are undefined — caller should +zero-init the buffer or apply a separate mask kernel. +""" + +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench +import torch + +from tilelang_utils import prepare_ks_ke_from_cu_seqlens +from clean_and_maintain_logits import ( + clean_and_maintain_logits_interface, + ref_clean_and_maintain_logits, +) + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def pool_mqa_attn_return_logits_fp8( + IndexQ, + IndexBlockedK, + IndexBlockedKScale, + Logits, + Weights, + CuSeqLenBlockedKS, + CuSeqLenBlockedKE, + heads: int = 64, + index_dim: int = 128, + block_N: int = 256, + num_stages: int = 3, + threads: int = 512, + block_Q: int = 0, +): + # block_Q is the tile size for queries; `0` means "derive from heads". + if block_Q == 0: + block_Q = 128 // heads + fp8_dtype = T.float8_e4m3fn + accum_dtype = T.float32 + index_dtype = T.int32 + + seq_len, seq_len_blocked_kv = T.const("seq_len, seq_len_blocked_kv") + + IndexQ: T.Tensor[[seq_len * heads, index_dim], fp8_dtype] + IndexBlockedK: T.Tensor[[seq_len_blocked_kv, index_dim], fp8_dtype] + IndexBlockedKScale: T.Tensor[[seq_len_blocked_kv], accum_dtype] + Logits: T.Tensor[[seq_len, seq_len_blocked_kv], accum_dtype] + Weights: T.Tensor[[seq_len, heads], accum_dtype] + CuSeqLenBlockedKS: T.Tensor[[seq_len], index_dtype] + CuSeqLenBlockedKE: T.Tensor[[seq_len], index_dtype] + + with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: + index_q_shared = T.alloc_shared([block_Q * heads, index_dim], fp8_dtype) + index_k_shared = T.alloc_shared([block_N, index_dim], fp8_dtype) + index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) + s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) + s_reshaped = T.reshape(s, (block_N, block_Q, heads)) + logits = T.alloc_fragment([block_N, block_Q], accum_dtype) + weights = T.alloc_fragment([block_Q, heads], accum_dtype) + + seq_len_i = bx * block_Q + + cu_k_s_min = T.alloc_var(index_dtype) + cu_k_e_max = T.alloc_var(index_dtype) + cu_k_s_min = 2147483647 + cu_k_e_max = -2147483648 + + for bq_i in T.serial(block_Q): + cu_k_s_min = T.min(cu_k_s_min, T.min(CuSeqLenBlockedKS[seq_len_i + bq_i], seq_len_blocked_kv)) + for bq_i in T.serial(block_Q): + cu_k_e_max = T.max(cu_k_e_max, T.min(CuSeqLenBlockedKE[seq_len_i + bq_i], seq_len_blocked_kv)) + + T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) + T.copy(Weights[seq_len_i, 0], weights) + + for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max - cu_k_s_min, block_N), num_stages=num_stages): + T.copy(IndexBlockedK[cu_k_s_min + nbn_i * block_N, 0], index_k_shared) + T.copy(IndexBlockedKScale[cu_k_s_min + nbn_i * block_N], index_k_scale_fragment) + + T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): + s_reshaped[bn_i, bq_i, h_i] = T.max(s_reshaped[bn_i, bq_i, h_i] * index_k_scale_fragment[bn_i], 0) * weights[bq_i, h_i] + + T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + + for bq_i, bn_i in T.Parallel(block_Q, block_N): + Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] + + +def pool_mqa_attn_return_logits_fp8_interface( + q_fp8: torch.Tensor, + blocked_kv_fp8: torch.Tensor, + blocked_kv_scale: torch.Tensor, + weights_f32: torch.Tensor, + cu_seqlen_blocked_ks: torch.Tensor, + cu_seqlen_blocked_ke: torch.Tensor, + block_N: int = 256, +): + """Raw kernel invocation; zero-inits logits so positions the kernel + doesn't touch are 0 (matches the ref).""" + seq_len, heads, index_dim = q_fp8.shape + seq_len_blocked_kv = blocked_kv_fp8.shape[0] + + logits = torch.zeros([seq_len, seq_len_blocked_kv], device=q_fp8.device, dtype=torch.float32) + pool_mqa_attn_return_logits_fp8( + q_fp8.view(seq_len * heads, index_dim), + blocked_kv_fp8, + blocked_kv_scale, + logits, + weights_f32, + cu_seqlen_blocked_ks, + cu_seqlen_blocked_ke, + heads=heads, + index_dim=index_dim, + block_N=block_N, + ) + return logits + + +def ref_pool_mqa_fp8( + q_fp8: torch.Tensor, + blocked_kv_fp8: torch.Tensor, + blocked_kv_scale: torch.Tensor, + weights_f32: torch.Tensor, +) -> torch.Tensor: + """Spec: for each (m, n), logits[m, n] = sum_h ReLU(q[m,h] . k[n] * k_scale[n]) * w[m,h]. + Computes the full dense [M, Nb] grid — caller is responsible for any masking.""" + q_f = q_fp8.float() + k_f = blocked_kv_fp8.float() * blocked_kv_scale[:, None] + # score[m, n, h] = q[m, h] . k[n] + s = torch.einsum("mhd,nd->mnh", q_f, k_f) # [M, Nb, H] + logits = (s.clamp(min=0) * weights_f32[:, None, :]).sum(dim=-1) # [M, Nb] + return logits + + +def test_pool_mqa_fp8( + M: int = 32768, + H: int = 64, + D: int = 128, + k_block_size: int = 128, + block_N: int = 256, + num_seqs: int = 1, +): + """Correctness + speed test packing `num_seqs` equal-length causal + sequences into the [M, H, D] Q tensor. + + Per-query ``cu_seqlen_blocked_ks/ke`` is derived from the raw-token + packed ``cu_ks / cu_ke`` produced by ``prepare_ks_ke_from_cu_seqlens`` + (floor-divide / ceil-divide by ``k_block_size`` respectively). + + The kernel writes the per-tile ``[cu_k_s_min, cu_k_e_max)`` union of + visible K ranges — entries inside this union but outside an + individual query's visible range carry raw (unmasked) dot-product + values. To make correctness well-defined, we apply the + ``clean_and_maintain_logits`` mask (-inf for out-of-range, +inf for + the first/last valid block) to both the kernel output and the torch + reference before comparing — this mirrors what the hisa pipeline + does right after this kernel. + """ + torch.manual_seed(0) + assert M % num_seqs == 0, f"M ({M}) must be divisible by num_seqs ({num_seqs})" + per_seq = M // num_seqs + N_blocked = (M + k_block_size - 1) // k_block_size + assert N_blocked % block_N == 0, ( + f"N_blocked ({N_blocked}) must be a multiple of block_N ({block_N}). Pick M such that ceildiv(M, k_block_size) % block_N == 0." + ) + + # Per-token packed ks/ke (causal within each sequence), then translate + # to pool-block coords. + cu_seqlens = torch.arange(num_seqs + 1, device="cuda", dtype=torch.long) * per_seq + ks_long, ke_long = prepare_ks_ke_from_cu_seqlens(cu_seqlens) + cu_ks_token = ks_long.to(torch.int32).contiguous() + cu_ke_token = ke_long.to(torch.int32).contiguous() + cu_blocked_ks = (cu_ks_token // k_block_size).contiguous() + cu_blocked_ke = ((cu_ke_token + k_block_size - 1) // k_block_size).contiguous() + + q_bf16 = torch.randn(M, H, D, device="cuda", dtype=torch.bfloat16) + q = q_bf16.to(torch.float8_e4m3fn) + blocked_k_bf16 = torch.randn(N_blocked, D, device="cuda", dtype=torch.bfloat16) + blocked_k = blocked_k_bf16.to(torch.float8_e4m3fn) + blocked_k_scale = (0.1 + 0.01 * torch.rand(N_blocked, device="cuda", dtype=torch.float32)).contiguous() + weights = torch.randn(M, H, device="cuda", dtype=torch.float32) + + # Correctness — kernel + post-mask. + got = pool_mqa_attn_return_logits_fp8_interface( + q, + blocked_k, + blocked_k_scale, + weights, + cu_blocked_ks, + cu_blocked_ke, + block_N=block_N, + ) + clean_and_maintain_logits_interface(got, cu_blocked_ks, cu_blocked_ke) + + ref = ref_pool_mqa_fp8(q, blocked_k, blocked_k_scale, weights) + ref = ref_clean_and_maintain_logits(ref, cu_blocked_ks, cu_blocked_ke) + + # After the mask, +/-inf positions must agree exactly. Compare the + # remaining finite values under an fp8×fp8 GEMM tolerance. + assert torch.equal(torch.isposinf(got), torch.isposinf(ref)), "pos-inf mask differs" + assert torch.equal(torch.isneginf(got), torch.isneginf(ref)), "neg-inf mask differs" + finite = torch.isfinite(got) & torch.isfinite(ref) + torch.testing.assert_close(got[finite], ref[finite], rtol=5e-2, atol=5e-2) + print(f" correctness: PASS (M={M}, H={H}, D={D}, N_blocked={N_blocked}, block_N={block_N}, num_seqs={num_seqs}, per_seq={per_seq})") + + # Speed (kernel only — excludes the post mask). + def fn(): + return pool_mqa_attn_return_logits_fp8_interface( + q, + blocked_k, + blocked_k_scale, + weights, + cu_blocked_ks, + cu_blocked_ke, + block_N=block_N, + ) + + ms = do_bench(fn, warmup=50, rep=200) + # FLOPs: fp8×fp8 GEMM dominates = 2 * M * H * Nb * D (mul+add). + total_flops = 2 * M * H * N_blocked * D + tflops = total_flops / (ms * 1e-3) / 1e12 + print(f" latency: {ms:.4f} ms ({tflops:.2f} fp8 TFLOPS)") + + +if __name__ == "__main__": + # M × k_block_size^-1 must be a multiple of block_N=256. + # With k_block_size=128 → N_blocked = M/128; need N_blocked % 256 == 0 + # → M % 32768 == 0. + # (M, H, D, k_block_size, block_N, num_seqs) + for cfg in [ + (32768, 64, 128, 128, 256, 1), + (32768, 64, 128, 128, 256, 4), + (65536, 64, 128, 128, 256, 1), + (65536, 64, 128, 128, 256, 8), + (131072, 64, 128, 128, 256, 16), + ]: + test_pool_mqa_fp8(*cfg) diff --git a/examples/dsa_hisa/tilelang_utils.py b/examples/dsa_hisa/tilelang_utils.py new file mode 100644 index 0000000000..80a1441c5b --- /dev/null +++ b/examples/dsa_hisa/tilelang_utils.py @@ -0,0 +1,314 @@ +import torch +import torch.nn.functional as F +import functools +from typing import Callable, Any, Tuple + + +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if ( + (last_args is not None and last_kwargs is not None) + and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) + and all(a is b for a, b in zip(args, last_args, strict=False)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_cu_seqlens_from_lens( + lens: torch.LongTensor, + dtype: torch.dtype | None = torch.int32, +) -> torch.LongTensor: + return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0)) + + +@tensor_cache +def prepare_lens_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()]) + + +@tensor_cache +def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1 + + +@tensor_cache +def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + position_ids = prepare_position_ids(cu_seqlens) + return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_cu_seqlens_from_position_ids( + position_ids: torch.LongTensor, + dtype: torch.dtype | None = torch.int32, +) -> torch.LongTensor: + starts = (position_ids == 0).nonzero(as_tuple=True)[0] + total_len = position_ids.new_tensor([position_ids.numel()]) + boundaries = torch.cat([starts, total_len]) + lens = torch.diff(boundaries) + cu_seqlens = prepare_cu_seqlens_from_lens(lens, dtype=dtype) + return cu_seqlens + + +@tensor_cache +def prepare_ks_ke_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> tuple[torch.LongTensor, torch.LongTensor]: + position_ids = prepare_position_ids(cu_seqlens) + sequence_ids = position_ids.eq(0).cumsum(0) - 1 + + ks = cu_seqlens[sequence_ids] + ke = ks + position_ids + 1 + + return ks, ke + + +@tensor_cache +def prepare_ks_ke_from_cu_seqlens_qk( + cu_seqlens_q: torch.LongTensor, + cu_seqlens_k: torch.LongTensor, +) -> tuple[torch.LongTensor, torch.LongTensor]: + position_ids_q = prepare_position_ids(cu_seqlens_q) + sequence_ids_q = position_ids_q.eq(0).cumsum(0) - 1 + + seqlens_q = prepare_lens(cu_seqlens_q) + seqlens_k = prepare_lens(cu_seqlens_k) + offset = seqlens_k - seqlens_q + + ks = cu_seqlens_k[sequence_ids_q] + ke = ks + position_ids_q + offset[sequence_ids_q] + 1 + + return ks, ke + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def get_abs_err(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + return (x - y).flatten().abs().max().item() + + +def get_err_ratio(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + err = (x - y).flatten().square().mean().sqrt().item() + base = (x).flatten().square().mean().sqrt().item() + return err / base + + +def calculate_tensor_similarity(x, y, name="tensor"): + """ + Calculate similarity between two tensors using a normalized dot product metric. + + Unlike torch.testing.assert_close which uses absolute/relative tolerance based on + element-wise differences, this function computes a global similarity score: + sim = 2 * / (||x||^2 + ||y||^2) + + This metric is scale-invariant and measures the cosine-like similarity normalized + by the magnitude of both tensors. It returns 1 for identical tensors and values + closer to 0 for dissimilar ones. This is particularly useful for comparing tensors + with varying magnitudes where relative errors matter more than absolute differences. + + Args: + x: First tensor to compare + y: Second tensor to compare + name: Name of the tensor for logging purposes + + Returns: + Similarity score in range [0, 1] where 1 means identical + """ + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print(f"\033[33mWARNING: {name} all zero\033[0m") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): + """ + Assert that two tensors are similar using a global similarity metric. + + Key differences from torch.testing.assert_close: + - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking + that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers + and requires all elements to satisfy the tolerance. + - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the + normalized dot product. It's more robust to outliers and focuses on overall + tensor similarity rather than element-wise precision. This is better suited for + comparing large tensors where a few outlier elements shouldn't fail the test. + + Args: + x: First tensor to compare + y: Second tensor to compare + eps: Maximum allowed difference (1 - similarity), default 1e-8 + name: Name of the tensor for error messages + raise_assert: Whether to raise assertion error on failure + """ + sim = calculate_tensor_similarity(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") + if raise_assert: + assert False # noqa: B011 + + +@tensor_cache +def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor: + seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device) + for i in range(len(cu_seqlens_qs)): + seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i + return seq_idx_for_q + + +@tensor_cache +def cal_cu_seqlen_ks_for_q( + cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int +) -> torch.IntTensor: + cu_seqlen_ks_for_each_q = torch.gather( + input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]), + dim=0, + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + return cu_seqlen_ks_for_each_q.int() + + +@tensor_cache +def cal_cu_seqlen_ke_for_q( + cu_seqlens_qs: torch.LongTensor, + cu_seqlens_qe: torch.LongTensor, + cu_seqlens_ks: torch.LongTensor, + cu_seqlens_ke: torch.LongTensor, + q_start_idxs: torch.LongTensor, + seq_len: int, + kv_stride: int, +) -> torch.IntTensor: + cu_seqlen_ke_for_each_q = torch.gather( + input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), + dim=0, + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device) + for i in range(len(cu_seqlens_qs)): + casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = ( + torch.arange( + q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device + ) + + 1 + ) // kv_stride + cu_seqlens_ks[i] + cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) + return cu_seqlen_ke_for_each_q.int() + + +def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512): + total_seqlen = per_cp_seqlen * cp_size + + cu_seqlens = torch.randint(0, average_q_len * 2, (total_seqlen // average_q_len * 2,)).cuda() + last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0] + cu_seqlens = cu_seqlens[:last_seq_id] + + if cu_seqlens.sum() < total_seqlen: + cu_seqlens = torch.cat([cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()]) + + cu_seqlens_cumsum = torch.cumsum(cu_seqlens, dim=0) + cu_seqlens_k_cumsum = torch.cumsum(cu_seqlens // kv_stride, dim=0) + cu_seqlens_qs = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_cumsum[:-1]]) + cu_seqlens_ks = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_k_cumsum[:-1]]) + cu_seqlens_qe = cu_seqlens_cumsum.clone() + cu_seqlens_ke = cu_seqlens_k_cumsum.clone() + + cu_seqlens_ks_for_each_q = cal_cu_seqlen_ks_for_q( + cu_seqlens_qs=cu_seqlens_qs, + cu_seqlens_qe=cu_seqlens_qe, + cu_seqlens_ks=cu_seqlens_ks, + seq_len=total_seqlen, + ) + cu_seqlens_ke_for_each_q = cal_cu_seqlen_ke_for_q( + cu_seqlens_qs=cu_seqlens_qs, + cu_seqlens_qe=cu_seqlens_qe, + cu_seqlens_ks=cu_seqlens_ks, + cu_seqlens_ke=cu_seqlens_ke, + q_start_idxs=torch.zeros_like(cu_seqlens_qs), + seq_len=total_seqlen, + kv_stride=kv_stride, + ) + + assert per_cp_seqlen % 2 == 0 + per_chunk_seqlen = per_cp_seqlen // 2 + slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen) + slice_long = slice( + total_seqlen - (cp_rank + 1) * per_chunk_seqlen, + total_seqlen - cp_rank * per_chunk_seqlen, + ) + ks = torch.cat( + [ + cu_seqlens_ks_for_each_q[slice_short], + cu_seqlens_ks_for_each_q[slice_long], + ] + ) + ke = torch.cat( + [ + cu_seqlens_ke_for_each_q[slice_short], + cu_seqlens_ke_for_each_q[slice_long], + ] + ) + assert len(ks) == len(ke) == per_cp_seqlen + return ks, ke From 8f4a08f56de7683162f5a84fdae7be3a5d98d8e2 Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Sat, 25 Apr 2026 21:49:28 +0800 Subject: [PATCH 145/156] [Language] Small cleanup and notes for alloc global (#2100) --- tilelang/language/__init__.py | 2 +- tilelang/language/allocate.py | 45 +++++++++++++++++++---------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 47b4ee037b..43c70563a2 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -43,6 +43,7 @@ alloc_local, # noqa: F401 alloc_shared, # noqa: F401 alloc_fragment, # noqa: F401 + alloc_global, # noqa: F401 alloc_barrier, # noqa: F401 alloc_cluster_barrier, # noqa: F401 alloc_tmem, # noqa: F401 @@ -52,7 +53,6 @@ alloc_tcgen05_smem_desc, # noqa: F401 alloc_tcgen05_instr_desc, # noqa: F401 empty, # noqa: F401 - alloc_global, # noqa: F401 ) from tvm.script.parser.tir import allocate as allocate # noqa: F401 from .copy_op import copy, async_copy, tma_copy, transpose, c2d_im2col # noqa: F401 diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 15a76897f7..47de8c44ce 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -9,6 +9,7 @@ - alloc_local: Allocates local memory buffers for thread-private storage - alloc_fragment: Allocates fragment memory buffers for specialized operations - alloc_var: Allocates single-element variable buffers + - alloc_global: Allocates global memory buffers as workspace Each function takes shape and dtype parameters and returns a TVM buffer object with the appropriate memory scope. @@ -144,6 +145,29 @@ def alloc_var(dtype: DType, *args, scope: str = "local.var", init: PrimExpr | in return buffer +def alloc_global(shape: ShapeType, dtype: DType, scope="global") -> Buffer: + """Allocate a global memory buffer as a global workspace. + + NOTE(chaofan): Memory allocated in this way doesn't go through torch allocator. Instead, + it's allocated directly by the corresponding backend APIs, like cudaMalloc. We + recommend allocating workspace in Torch side and pass it to the kernel via arguments, + which is managed under the hood by the framework. This API is mainly for testing + purposes and some specific purposes. + + NOTE(chaofan): This API may not be available in all backends (e.g. CuteDSL). + + Args: + shape (tuple): The shape of the buffer to allocate + dtype (str): The data type of the buffer (e.g., 'float32', 'int32') + scope (str, optional): The memory scope. Defaults to "global" + + Returns: + T.Buffer: A TVM buffer object allocated in global memory + """ + + return T.alloc_buffer(shape, dtype, scope=scope) + + def alloc_barrier(arrive_count: int | list[int]) -> Buffer: """Allocate a barrier buffer. @@ -330,24 +354,3 @@ def empty(*shape, dtype: DType = _dtypes.float32) -> Tensor: return OutTensor(shape, dtype) else: raise TypeError(f"Invalid shape {shape}") - - -def alloc_global(shape: ShapeType, dtype: DType, scope="global") -> Buffer: - """Allocate a global memory buffer as a global workspace. - - NOTE: Memory allocated in this way doesn't go through torch allocator. Instead, - it's allocated directly by the corresponding backend APIs, like cudaMalloc. We - recommend allocating workspace in Torch side and pass it to the kernel via arguments, - which is managed under the hood by the framework. This API is mainly for testing - purposes and some specific purposes. - - Args: - shape (tuple): The shape of the buffer to allocate - dtype (str): The data type of the buffer (e.g., 'float32', 'int32') - scope (str, optional): The memory scope. Defaults to "global" - - Returns: - T.Buffer: A TVM buffer object allocated in global memory - """ - - return T.alloc_buffer(shape, dtype, scope=scope) From 8e1215779aba7ffb6fed55083578ab55c1f84bd0 Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Mon, 27 Apr 2026 02:19:54 +0800 Subject: [PATCH 146/156] [Enhancement] Optimize hopper fp8 deepgemm tile size (#2103) optimize hopper fp8 deepgemm --- examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index 0d72ed3678..22ce27de18 100644 --- a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -27,7 +27,7 @@ def tl_gemm( ], "Currently only float16 and float32 are supported" group_size = 128 - block_M = 128 + block_M = 64 block_K = 128 A_shape = (M, K) @@ -50,7 +50,6 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype) - Scale_C_shared = T.alloc_shared((block_M), T.float32) C_local = T.alloc_fragment(C_shared_shape, accum_dtype) C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype) @@ -65,15 +64,12 @@ def main( T.copy(A[by * block_M, k * block_K], A_shared) # Load B into shared memory T.copy(B[bx * block_N, k * block_K], B_shared) - # Load scale into shared memory Scale_B = scales_b[bx * block_N // group_size, k] - for i in T.Parallel(block_M): - Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B T.gemm(A_shared, B_shared, C_local, transpose_B=True) # Promote to enable 2xAcc for i, j in T.Parallel(block_M, block_N): - C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + C_local_accum[i, j] += C_local[i, j] * (scales_a[by * block_M + i, k] * Scale_B) T.clear(C_local) # TMA store T.copy(C_local_accum, C_shared) From ffdf5148b371baaf2ce2787e0596993e48e64737 Mon Sep 17 00:00:00 2001 From: TerminusAkivili Date: Mon, 27 Apr 2026 02:23:14 +0800 Subject: [PATCH 147/156] [CUDA][SM100] Include cuda_fp6.h when emitting FP6 types (#2102) --- src/target/codegen_cuda.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 20a5fa5515..fddc1994a0 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -592,6 +592,9 @@ std::string CodeGenTileLangCUDA::Finish() { if (enable_fp8_) { decl_stream << "#include \n"; } + if (enable_fp6_) { + decl_stream << "#include \n"; + } if (enable_fp4_) { decl_stream << "#include \n"; } From 6a29c76aee815a087c8eb22e6a44757c8cff3c92 Mon Sep 17 00:00:00 2001 From: Jiaxing Ding <61589029+Paran0idy@users.noreply.github.com> Date: Mon, 27 Apr 2026 02:23:42 +0800 Subject: [PATCH 148/156] feat: support cdna4 v_mfma_i32_16x16x64_i8 & v_mfma_i32_32x32x32_i8 (#2097) Co-authored-by: Jiaxing Ding --- src/target/codegen_hip.cc | 2 + src/tl_templates/hip/common.h | 1 + .../amd/test_tilelang_gemm_mfma_preshuffle.py | 269 +++++++++++++----- tilelang/intrinsics/mfma_layout.py | 36 +++ tilelang/intrinsics/mfma_macro_generator.py | 113 ++++++-- tilelang/intrinsics/utils.py | 6 +- 6 files changed, 335 insertions(+), 92 deletions(-) diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index fafb3475ca..8c578fa8af 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -1136,7 +1136,9 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { {"int32", "int"}, {"int8x4", "int32_t"}, {"int8x8", "int64_t"}, + {"int8x16", "int32x4"}, {"int32x4", "int32x4"}, + {"int32x16", "int32x16"}, {"float16", "half"}, {"float32", "float"}, {"float64", "double"}, diff --git a/src/tl_templates/hip/common.h b/src/tl_templates/hip/common.h index c7041b4a18..6b2da95a37 100644 --- a/src/tl_templates/hip/common.h +++ b/src/tl_templates/hip/common.h @@ -90,6 +90,7 @@ typedef __attribute__((__vector_size__(8 * sizeof(short)))) short bfloat16x8_vec; using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; +using int32x16 = __attribute__((__vector_size__(16 * sizeof(int)))) int; using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; using float32x32 = __attribute__((__vector_size__(32 * sizeof(float)))) float; diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py index d4746c16d9..2df6b1962d 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py @@ -1,5 +1,6 @@ import pytest import torch +import tilelang import tilelang.testing from tilelang import tvm as tvm import tilelang.language as T @@ -22,27 +23,70 @@ def tl_matmul( a_transposed=False, b_transposed=True, k_pack=1, + a_preshuffle=False, b_preshuffle=False, b_g2l_load=False, + block_row_warps=None, + block_col_warps=None, + warp_row_tiles=None, + warp_col_tiles=None, + chunk=None, + num_stages=0, + panel_size=10, + mfma_shape=None, ): - micro_size_x = micro_size_y = micro_size_k = 16 - - if in_dtype.bits == 8: - micro_size_k = 32 - - block_row_warps = 2 - block_col_warps = 2 - warp_row_tiles = 32 - warp_col_tiles = 32 - - # for preshuffle_b, warp_layout = {1, 4} - if b_preshuffle: - block_row_warps = 1 - block_col_warps = 4 - warp_row_tiles = 64 - warp_col_tiles = 16 - - chunk = 256 * k_pack + """Build a TileLang MFMA kernel for ``A @ B^T`` (with optional preshuffle). + + The (block_*_warps, warp_*_tiles, chunk, num_stages, panel_size) parameters + expose the underlying CK-style template knobs so that an external autotuner + can pick per-shape tile / wave / pipeline configurations. + + ``mfma_shape`` selects which MFMA instruction to use, as an ``(M, N, K)`` + tuple. Supported int8 shapes on CDNA4 (gfx950): + (16, 16, 32) — default, ``v_mfma_i32_16x16x32_i8`` + (16, 16, 64) — doubled-K, ``v_mfma_i32_16x16x64_i8`` + (32, 32, 32) — doubled-MN, ``v_mfma_i32_32x32x32_i8`` + """ + if mfma_shape is not None: + micro_size_x, micro_size_y, micro_size_k = mfma_shape + else: + micro_size_x = micro_size_y = 16 + micro_size_k = 32 if in_dtype.bits == 8 else 16 + + if block_row_warps is None: + block_row_warps = 2 + if block_col_warps is None: + block_col_warps = 2 + if warp_row_tiles is None: + warp_row_tiles = max(32, micro_size_x) + if warp_col_tiles is None: + warp_col_tiles = max(32, micro_size_y) + + # Legacy heuristic: if the caller did not override any tile knob and we are + # in B-only preshuffle mode, keep the historical 1x4 warp grid. + _all_tile_defaults = ( + block_row_warps == 2 + and block_col_warps == 2 + and warp_row_tiles == max(32, micro_size_x) + and warp_col_tiles == max(32, micro_size_y) + ) + if _all_tile_defaults and b_preshuffle and not a_preshuffle: + block_row_warps, block_col_warps = 1, 4 + warp_row_tiles = max(64, micro_size_x) + warp_col_tiles = max(16, micro_size_y) + + if chunk is None: + chunk = 256 * k_pack + + # ---- structural validation (catch invalid configs early) ---- + assert warp_row_tiles % micro_size_x == 0, f"warp_row_tiles={warp_row_tiles} must be a multiple of micro_size_x={micro_size_x}" + assert warp_col_tiles % micro_size_y == 0, f"warp_col_tiles={warp_col_tiles} must be a multiple of micro_size_y={micro_size_y}" + assert chunk % (k_pack * micro_size_k) == 0, f"chunk={chunk} must be a multiple of k_pack*micro_size_k={k_pack * micro_size_k}" + block_M_check = block_row_warps * warp_row_tiles + block_N_check = block_col_warps * warp_col_tiles + assert M % block_M_check == 0, f"M={M} must be a multiple of block_M={block_M_check}" + assert N % block_N_check == 0, f"N={N} must be a multiple of block_N={block_N_check}" + assert K % chunk == 0, f"K={K} must be a multiple of chunk={chunk}" pack_size_k = micro_size_k * k_pack @@ -52,7 +96,14 @@ def tl_matmul( block_N = block_col_warps * warp_col_tiles block_K = chunk - A_shape = (K, M) if a_transposed else (M, K) + if a_preshuffle: + A_shape = ( + (K // pack_size_k, M // micro_size_x, pack_size_k, micro_size_x) + if a_transposed + else (M // micro_size_x, K // pack_size_k, micro_size_x, pack_size_k) + ) + else: + A_shape = (K, M) if a_transposed else (M, K) if b_preshuffle: B_shape = ( (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k) @@ -62,7 +113,14 @@ def tl_matmul( else: B_shape = (N, K) if b_transposed else (K, N) - A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) + if a_preshuffle: + A_shared_shape = ( + (block_K // pack_size_k, block_M // micro_size_x, pack_size_k, micro_size_x) + if a_transposed + else (block_M // micro_size_x, block_K // pack_size_k, micro_size_x, pack_size_k) + ) + else: + A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) if b_preshuffle: B_shared_shape = ( (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k) @@ -93,7 +151,9 @@ def tl_matmul( warp_col_tiles=warp_col_tiles, chunk=chunk, k_pack=k_pack, + a_preshuffle=a_preshuffle, b_preshuffle=b_preshuffle, + mfma_shape=mfma_shape, ) @T.prim_func @@ -109,35 +169,51 @@ def main( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - } - ) + layout_map = {} + if not a_preshuffle: + layout_map[A_shared] = make_swizzle_layout(A_shared) + if not b_preshuffle: + layout_map[B_shared] = make_swizzle_layout(B_shared) + if layout_map: + T.annotate_layout(layout_map) num_ko = K // block_K num_ki = block_K // (k_pack * micro_size_k) # Improve L2 Cache - T.use_swizzle(panel_size=10) + T.use_swizzle(panel_size=panel_size) T.clear(C_local) - for ko in T.Pipelined(num_ko, num_stages=0): + for ko in T.Pipelined(num_ko, num_stages=num_stages): # Load A into shared memory - if a_transposed: - T.copy(A[ko * block_K, by * block_M], A_shared) + if a_preshuffle: + if a_transposed: + for k, i, kk, ii in T.Parallel(block_K // pack_size_k, block_M // micro_size_x, pack_size_k, micro_size_x): + A_shared[k, i, kk, ii] = A[ko * block_K // pack_size_k + k, by * block_M // micro_size_x + i, kk, ii] + else: + for i, k, ii, kk in T.Parallel(block_M // micro_size_x, block_K // pack_size_k, micro_size_x, pack_size_k): + A_shared[i, k, ii, kk] = A[by * block_M // micro_size_x + i, ko * block_K // pack_size_k + k, ii, kk] else: - T.copy(A[by * block_M, ko * block_K], A_shared) + if a_transposed: + T.copy(A[ko * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, ko * block_K], A_shared) # Load B into shared memory if b_g2l_load is False: - if b_transposed: - for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k): - B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, ko * block_K // pack_size_k + k, jj, kk] + if b_preshuffle: + if b_transposed: + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k): + B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, ko * block_K // pack_size_k + k, jj, kk] + else: + for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y): + B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, bx * block_N // micro_size_y + j, kk, jj] else: - for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y): - B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, bx * block_N // micro_size_y + j, kk, jj] + if b_transposed: + T.copy(B[bx * block_N, ko * block_K], B_shared) + else: + T.copy(B[ko * block_K, bx * block_N], B_shared) for ki in T.serial(0, num_ki): # Load A S2L @@ -201,21 +277,35 @@ def assert_tl_matmul_correctness( a_transposed=False, b_transposed=True, k_pack=1, + a_preshuffle=False, b_preshuffle=False, b_g2l_load=False, + mfma_shape=None, ): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load) - print(matmul) + matmul = tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + a_transposed, + b_transposed, + k_pack, + a_preshuffle, + b_preshuffle, + b_g2l_load, + mfma_shape=mfma_shape, + ) kernel = tilelang.compile(matmul) - src_code = kernel.get_kernel_source() - # src_code is the generated cuda source - assert src_code is not None + assert kernel.get_kernel_source() is not None + A_shape = (K, M) if a_transposed else (M, K) B_shape = (N, K) if b_transposed else (K, N) if in_dtype == T.int8: A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8) B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8) - elif "float8" in str(in_dtype): # for T.float8_e4m3fnuz in gfx942 and T.float8_e4m3fn in gfx950 + elif "float8" in str(in_dtype): A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) else: @@ -224,52 +314,46 @@ def assert_tl_matmul_correctness( C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - B_preshuffle = B - if b_preshuffle: - B_preshuffle = shuffle_weight(B_preshuffle, k_pack=k_pack, is_transpose=b_transposed) - kernel(A, B_preshuffle, C) - else: - kernel(A, B, C) - - print(kernel.get_kernel_source()) - - profiler = kernel.get_profiler() - - latency = profiler.do_bench() - - # Ensure that the latency is not None - assert latency is not None + shuf_layout = (mfma_shape[0], mfma_shape[2]) if mfma_shape else (16, 32) + A_in = shuffle_weight(A, layout=shuf_layout, k_pack=k_pack, is_transpose=not a_transposed) if a_preshuffle else A + B_in = shuffle_weight(B, layout=shuf_layout, k_pack=k_pack, is_transpose=b_transposed) if b_preshuffle else B + kernel(A_in, B_in, C) if a_transposed and b_transposed: - # Get Reference Result ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) elif a_transposed and not b_transposed: - # Get Reference Result ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) elif not a_transposed and b_transposed: - # Get Reference Result ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) else: - # Get Reference Result ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) - print(C) - print(ref_c) - torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) @pytest.mark.parametrize( - "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load", + "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, a_preshuffle, b_preshuffle, b_g2l_load", [ - (256, 256, 512, T.int8, T.int32, T.int32, False, True, 1, True, False), - (256, 256, 512, T.int8, T.int32, T.int32, False, False, 1, True, False), - (256, 256, 512, T.int8, T.int32, T.int32, False, True, 2, True, False), - (256, 256, 512, T.int8, T.int32, T.int32, False, False, 2, True, False), - (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, True, False), - (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, False, 1, True, False), - (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 2, True, False), - (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, False, 2, True, False), + # B-only preshuffle + (256, 256, 512, T.int8, T.int32, T.int32, False, True, 1, False, True, False), + (256, 256, 512, T.int8, T.int32, T.int32, False, False, 1, False, True, False), + (256, 256, 512, T.int8, T.int32, T.int32, False, True, 2, False, True, False), + (256, 256, 512, T.int8, T.int32, T.int32, False, False, 2, False, True, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, False, True, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, False, 1, False, True, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 2, False, True, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, False, 2, False, True, False), + # No preshuffle + (256, 256, 512, T.int8, T.int32, T.int32, False, True, 1, False, False, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, False, False, False), + # A-only preshuffle + (256, 256, 512, T.int8, T.int32, T.int32, False, True, 1, True, False, False), + (256, 256, 512, T.int8, T.int32, T.int32, True, True, 1, True, False, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, True, False, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, True, True, 1, True, False, False), + # A+B preshuffle together (default 2x2 warp grid) + (256, 256, 512, T.int8, T.int32, T.int32, False, True, 1, True, True, False), + (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, True, True, False), ], ) @tilelang.testing.requires_rocm @@ -283,6 +367,7 @@ def test_assert_tl_matmul( a_transposed, b_transposed, k_pack, + a_preshuffle, b_preshuffle, b_g2l_load, ): @@ -296,10 +381,50 @@ def test_assert_tl_matmul( a_transposed=a_transposed, b_transposed=b_transposed, k_pack=k_pack, + a_preshuffle=a_preshuffle, b_preshuffle=b_preshuffle, b_g2l_load=b_g2l_load, ) +# ---- CDNA4 extended MFMA shapes: 16x16x64 and 32x32x32 for int8 ---- +@pytest.mark.parametrize( + "M, N, K, in_dtype, out_dtype, accum_dtype, b_transposed, k_pack, b_preshuffle, mfma_shape", + [ + # v_mfma_i32_16x16x64_i8 — doubled K throughput (kp=1 only, micro_k=64) + (256, 256, 512, T.int8, T.int32, T.int32, True, 1, False, (16, 16, 64)), + (256, 256, 512, T.int8, T.int32, T.int32, True, 1, True, (16, 16, 64)), + # v_mfma_i32_32x32x32_i8 — doubled MN throughput + (256, 256, 512, T.int8, T.int32, T.int32, True, 1, False, (32, 32, 32)), + (256, 256, 512, T.int8, T.int32, T.int32, True, 1, True, (32, 32, 32)), + ], +) +@tilelang.testing.requires_rocm +def test_assert_tl_matmul_extended_mfma( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + b_transposed, + k_pack, + b_preshuffle, + mfma_shape, +): + assert_tl_matmul_correctness( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype=accum_dtype, + b_transposed=b_transposed, + k_pack=k_pack, + b_preshuffle=b_preshuffle, + mfma_shape=mfma_shape, + ) + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/intrinsics/mfma_layout.py b/tilelang/intrinsics/mfma_layout.py index d8af979887..62eb0dd9e0 100644 --- a/tilelang/intrinsics/mfma_layout.py +++ b/tilelang/intrinsics/mfma_layout.py @@ -128,6 +128,42 @@ def shared_16x64_to_local_64x16_layout_B(i, j): return thread_id, local +def shared_32x32_to_local_64x16_layout_C(i, j): + thread_id = (i % 8 // 4) * 32 + j + local_id = (i // 8) * 4 + i % 4 + return thread_id, local_id + + +def thread_id_shared_access_64x16_to_32x32_layout_C_n_m(thread_id, local_id): + i = (thread_id // 32) * 4 + local_id % 4 + (local_id // 4) * 8 + j = thread_id % 32 + return i, j + + +def shared_32x32_to_local_64x16_layout_A(i, j): + thread_id = i + 32 * (j // 16) + local_id = j % 16 + return thread_id, local_id + + +def thread_id_shared_access_64x16_to_32x32_layout_A(thread_id, local_id): + i = thread_id % 32 + j = (thread_id // 32) * 16 + local_id + return i, j + + +def shared_32x32_to_local_64x16_layout_B(i, j): + thread_id = j + 32 * (i // 16) + local_id = i % 16 + return thread_id, local_id + + +def thread_id_shared_access_64x16_to_32x32_layout_B(thread_id, local_id): + i = (thread_id // 32) * 16 + local_id + j = thread_id % 32 + return i, j + + def make_mfma_swizzle_layout(shared_buf, vecSize=8): dtype = shared_buf.dtype shape = shared_buf.shape diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index d482fe6a33..737934d254 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -23,6 +23,8 @@ shared_16x32_to_local_64x8_layout_B, shared_16x64_to_local_64x16_layout_A, shared_16x64_to_local_64x16_layout_B, + shared_32x32_to_local_64x16_layout_A, + shared_32x32_to_local_64x16_layout_B, thread_id_shared_access_64x1_to_16x4_layout_A, thread_id_shared_access_64x1_to_4x16_layout_B, thread_id_shared_access_64x4_to_16x16_layout_A, @@ -31,6 +33,8 @@ thread_id_shared_access_64x8_to_16x32_layout_B, thread_id_shared_access_64x16_to_16x64_layout_A, thread_id_shared_access_64x16_to_16x64_layout_B, + thread_id_shared_access_64x16_to_32x32_layout_A, + thread_id_shared_access_64x16_to_32x32_layout_B, ) lift = convert @@ -41,8 +45,6 @@ class MatrixCoreIntrinEmitter: To eliminate Python syntax within TIR Macro. """ - M_DIM = 16 - N_DIM = 16 WARP_SIZE = 64 dtype_abbrv = { "float16": "fp16", @@ -83,6 +85,7 @@ def __init__( b_preshuffle: bool | None = False, thread_var: Var | None = None, target: Target | None = None, + mfma_shape: tuple[int, int, int] | None = None, ): self.a_dtype = a_dtype self.b_dtype = b_dtype @@ -100,7 +103,7 @@ def __init__( self.warp_col_tiles = warp_col_tiles self.chunk = chunk self._initialize_k_pack(k_pack) - self._initialize_k_dim(a_dtype) + self._initialize_mfma_shape(mfma_shape, a_dtype) self._normalize_gfx950_f16_bf16_kpack() self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) @@ -116,7 +119,21 @@ def __init__( self.num_elems_per_byte = num_elems_per_byte self.thread_var = thread_var - def _initialize_k_dim(self, a_dtype=T.float16): + def _initialize_mfma_shape(self, mfma_shape: tuple[int, int, int] | None, a_dtype): + """Set ``(M_DIM, N_DIM, k_dim)`` from an explicit shape or auto-detect. + + Supported shapes on CDNA4 (gfx950) for int8: + (16, 16, 32) — ``v_mfma_i32_16x16x32_i8`` (default, CDNA3-compatible) + (16, 16, 64) — ``v_mfma_i32_16x16x64_i8`` (doubled K throughput) + (32, 32, 32) — ``v_mfma_i32_32x32x32_i8`` (doubled MN throughput) + """ + if mfma_shape is not None: + self.M_DIM, self.N_DIM, self.k_dim = mfma_shape + return + + # Auto-detect: same logic as the old _initialize_k_dim, defaulting to 16x16. + self.M_DIM = 16 + self.N_DIM = 16 if isinstance(a_dtype, str): if a_dtype in ["float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", T.int8]: self.k_dim = 32 @@ -218,6 +235,12 @@ def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False): def get_ldmatrix_index_map(self, is_b=False): k_dim = self.k_dim * self.k_pack transposed = self.a_transposed if not is_b else self.b_transposed + mn_dim = self.N_DIM if is_b else self.M_DIM + + # 32x32 MFMA instructions use a different set of layout maps. + if mn_dim == 32: + return self._get_ldmatrix_index_map_32(is_b, k_dim, transposed) + if k_dim == 4: index_map = shared_4x16_to_local_64x1_layout_B if transposed else shared_16x4_to_local_64x1_layout_A reverse_index_map = ( @@ -266,9 +289,45 @@ def get_ldmatrix_index_map(self, is_b=False): return index_map, reverse_index_map + def _get_ldmatrix_index_map_32(self, is_b, k_dim, transposed): + """Index maps for 32x32xK MFMA instructions (M_DIM=N_DIM=32). + + For int8 with mfma_shape=(32,32,32): k_dim*k_pack=32, local_size=16 + so the tile is 32×32 → 64×16 thread/local layout. + """ + # For 32x32 MFMA, the A/B layouts have tile dim 32 on the MN side. + # The maps are symmetric: A[row=M, col=K] and B[row=K, col=N] use the + # same underlying 32xK↔64×local layout, just with axes swapped for transpose. + if k_dim != 32: + raise ValueError(f"32x32 MFMA with effective k_dim={k_dim} is not supported yet; only k_dim=32 (k_dim*k_pack) is implemented.") + # k_dim=32 → shared_32x32_to_local_64x16 + if not is_b: + # A: non-transposed = [M=32, K=32], transposed = [K=32, M=32] + if transposed: + index_map = shared_32x32_to_local_64x16_layout_B + reverse_index_map = thread_id_shared_access_64x16_to_32x32_layout_B + else: + index_map = shared_32x32_to_local_64x16_layout_A + reverse_index_map = thread_id_shared_access_64x16_to_32x32_layout_A + else: + # B: transposed = [N=32, K=32], non-transposed = [K=32, N=32] + if transposed: + index_map = shared_32x32_to_local_64x16_layout_A + reverse_index_map = thread_id_shared_access_64x16_to_32x32_layout_A + else: + index_map = shared_32x32_to_local_64x16_layout_B + reverse_index_map = thread_id_shared_access_64x16_to_32x32_layout_B + return index_map, reverse_index_map + def get_store_index_map(self, inverse: bool = False) -> IndexMap: warp_size, local_size_c = self.WARP_SIZE, self.local_size_out - index_map = IndexMap.from_func(mfma_store_index_map, index_dtype=T.int32) + if self.M_DIM == 32: + from .utils import mfma_store_index_map_32x32 + + map_func = mfma_store_index_map_32x32 + else: + map_func = mfma_store_index_map + index_map = IndexMap.from_func(map_func, index_dtype=T.int32) if not inverse: return index_map inverse_index_map = index_map.inverse([warp_size, local_size_c]) @@ -459,6 +518,13 @@ def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): C_buf_dims = len(C_buf.shape) assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" + if M_DIM == 32: + from .mfma_layout import thread_id_shared_access_64x16_to_32x32_layout_C_n_m + + _store_map = thread_id_shared_access_64x16_to_32x32_layout_C_n_m + else: + _store_map = mfma_store_index_map + # STS # MFMA Store must be in simulated instead of TVM Intrins # As TVM Intrins is like a hack that the threadIdx.x should be always @@ -468,7 +534,7 @@ def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding): tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) for i, j in T.grid(warp_rows, warp_cols): for local_id in T.vectorized(local_size_out): - row, col = T.meta_var(mfma_store_index_map(tx, local_id)) + row, col = T.meta_var(_store_map(tx, local_id)) if C_buf_dims == 2: C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[ i * (warp_cols * local_size_out) + j * local_size_out + local_id @@ -483,7 +549,7 @@ def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) for i, j in T.grid(warp_rows, warp_cols): for local_id in T.vectorized(local_size_out): - row, col = T.meta_var(mfma_store_index_map(tx, local_id)) + row, col = T.meta_var(_store_map(tx, local_id)) C_buf[ (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] @@ -531,8 +597,16 @@ def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = " transform_func_sr_b: Callable = None k_dim = self.k_dim * self.k_pack + mn_dim = self.M_DIM # M_DIM == N_DIM for all supported shapes - if k_dim == 4: + if mn_dim == 32: + if k_dim != 32: + raise ValueError( + f"make_mfma_load_layout: 32x32 MFMA with effective k_dim={k_dim} is not supported; only k_dim=32 is implemented." + ) + transform_func_sr_a = shared_32x32_to_local_64x16_layout_A + transform_func_sr_b = shared_32x32_to_local_64x16_layout_A + elif k_dim == 4: transform_func_sr_a = shared_16x4_to_local_64x1_layout_A transform_func_sr_b = shared_16x4_to_local_64x1_layout_A elif k_dim == 16: @@ -739,6 +813,7 @@ def __init__( b_preshuffle: bool | None = False, thread_var: Var | None = None, target: Target | None = None, + mfma_shape: tuple[int, int, int] | None = None, ): super().__init__( a_dtype=a_dtype, @@ -756,13 +831,14 @@ def __init__( k_pack=k_pack, is_m_first=is_m_first, thread_var=thread_var, + mfma_shape=mfma_shape, target=target, ) self._initialize_preshuffle(a_preshuffle, b_preshuffle) - def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool): - if a_preshuffle is not None: - self.a_preshuffle = a_preshuffle + def _initialize_preshuffle(self, a_preshuffle: bool | None, b_preshuffle: bool | None): + # Parent does not set a_preshuffle; default False when omitted. + self.a_preshuffle = False if a_preshuffle is None else a_preshuffle if b_preshuffle is not None: self.b_preshuffle = b_preshuffle @@ -773,10 +849,10 @@ def ldmatrix_a(self, A_local_buf, A_buf, ki, rk=0, pid_m=None, pid_n=None): local_size_a = self.local_size_a k_pack = self.k_pack is_transposed = self.a_transposed - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) - is_global = pid_m is not None and pid_n is not None + # A-side global load only depends on the row block id (pid_m) + is_global = pid_m is not None # no preshuffle, use the default implementation if self.a_preshuffle is False: @@ -828,7 +904,6 @@ def _warp_ldmatrix_a_shared( ) A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col] else: - print(self.a_preshuffle) for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) @@ -848,13 +923,13 @@ def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None): local_size_b = self.local_size_b k_pack = self.k_pack is_transposed = self.b_transposed - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) - is_global = pid_m is not None and pid_n is not None + # B-side global load only depends on the column block id (pid_n) + is_global = pid_n is not None if self.b_preshuffle is False: - return super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n) + return super().ldmatrix_b(B_local_buf, B_buf, ki, rk) @T.macro def _warp_ldmatrix_b_global( diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index 724d3f94a2..30556499f9 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -10,7 +10,7 @@ mma_store_32x8_to_shared_16x16_layout, mma_store_32x2_to_shared_8x8_layout_fp64, ) -from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m +from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m, thread_id_shared_access_64x16_to_32x32_layout_C_n_m from .mma_layout import get_swizzle_layout # noqa: F401 from .mma_layout import make_mma_swizzle_layout # noqa: F401 @@ -93,6 +93,10 @@ def mfma_store_index_map(thread_id, local_id): return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) +def mfma_store_index_map_32x32(thread_id, local_id): + return thread_id_shared_access_64x16_to_32x32_layout_C_n_m(thread_id, local_id) + + def get_mma_micro_size(dtype: Literal["float16", "int8"]): # TODO(lei): FP8 related precision support. # Basic Tensor Core Matrix Multiply operation Unit From 53a4c9866bfc44b70a687103270018489eaa7300 Mon Sep 17 00:00:00 2001 From: Zhang Jason Date: Mon, 27 Apr 2026 02:31:01 +0800 Subject: [PATCH 149/156] [AMD] [gfx950]Fix multiple HIP codegen bugs to support TileKernel (#2099) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix HIP codegen for sync_warp, sync_grid, and local.var initialisation * [AMD/HIP] Fix warp_reduce VGPR bug, ShuffleNode packing, and Pipelined LDS overflow Extends PR #2096 with three additional fixes for CDNA (MI350) targets: Fix 1 — src/tl_templates/hip/reduce.h: warp_reduce width=32 The old 6-step butterfly called __shfl_xor(value, 32) without a width argument. On CDNA (wave64) with 32 active threads, lanes 32-63 are inactive and hold uninitialised VGPRs, producing NaN in reduce_max / reduce_sum / AllReduce. Fix: remove the step-32 shuffle; pass width=32 to all remaining 5 steps so every shuffle stays within the [0,31] group. Fix 2 — src/target/codegen_hip.cc + src/tl_templates/hip/common.h: ShuffleNode bfloat16x2 / float16x2 packing CodeGenC emitted `uint1(a, b)` for bfloat16x2 construction, which is an invalid HIP constructor call. Fix: override VisitExpr_(ShuffleNode) in CodeGenTileLangHIP to emit `uint1{__pack_bfloat162(a, b)}` / `uint1{ __pack_half2(a, b)}` using aggregate initialisation. Also add five bfloat16x2 math overloads for uint1 carrier (abs2/max2/min2/add2/mul2). Fix 3 — src/transform/pipeline_planning.cc: skip T.Pipelined(num_stages>1) Double-buffering doubled LDS per loop-body buffer. On CDNA (≤128 KB LDS per workgroup), this caused hipModuleLaunchKernel EINVAL. Fix: when TargetIsRocm() && num_stages > 1, skip pipeline planning and fall back to a plain sequential loop with synchronous T.copy. Also: fix __habs(hip_bfloat16) and __habs(float16_t) in common.h to use __builtin_memcpy instead of reinterpret_cast to avoid strict-aliasing UB (as flagged by CodeRabbit on PR #2096). Tests: 19 new cases added to testing/python/amd/test_tilelang_hip_codegen.py covering all three fixes. All 42 tests pass on MI350 (gfx950). * [AMD/HIP] Merge test_tilelang_hip_bugfixes.py into test_tilelang_hip_codegen.py Consolidate all HIP regression tests into a single file. The merged file covers all six fixes with 32 tests total (previously split across two files with duplicated test cases for warp_reduce, pipelined GEMM, and ShuffleNode). Changes versus the two individual files: - Deduplicated test_warp_reduce_no_nan (identical in both files) - Deduplicated test_pipelined_no_lds_overflow / test_pipelined_shared_mem_no_launch_error - Deduplicated test_pipelined_multi_stage_fp16_gemm - Merged bfloat16 shuffle tests: source check + runtime correctness in one function - Kept PR #2096 source-inspection tests (alloc_var, sync_warp, sync_grid) - Added runtime tests from bugfixes: inf init, serial loop accumulation, float scalar readback, two-group wave64 reduce, float16 shuffle * fixup: correct LDS size comment — gfx950 has 160 KB, not 128 KB gfx942 (CDNA3 / MI300X) has 64 KB LDS per workgroup. gfx950 (CDNA4 / MI350) has 160 KB LDS per workgroup (see PR #2058). The old comment said '128 KB' which is wrong for both generations. Updated pipeline_planning.cc and the test docstrings to reflect the correct per-architecture limits. * update for format checking --- src/target/codegen_hip.cc | 65 +- src/target/codegen_hip.h | 3 + src/target/rt_mod_hip.cc | 4 + src/target/stubs/hip.cc | 12 + src/target/stubs/hip.h | 14 +- src/tl_templates/hip/common.h | 74 ++ src/tl_templates/hip/hip_fp8.h | 85 +- src/tl_templates/hip/reduce.h | 34 +- src/transform/pipeline_planning.cc | 17 + .../python/amd/test_tilelang_hip_codegen.py | 757 ++++++++++++++++++ tilelang/language/allocate.py | 12 +- 11 files changed, 1030 insertions(+), 47 deletions(-) create mode 100644 testing/python/amd/test_tilelang_hip_codegen.py diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index 8c578fa8af..fd22bab4f3 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -210,6 +210,10 @@ std::string CodeGenTileLangHIP::Finish() { decl_stream << "#include \n"; } + if (need_cooperative_groups_) { + decl_stream << "#include \n"; + } + decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; @@ -710,6 +714,32 @@ std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from, return os.str(); } +void CodeGenTileLangHIP::VisitExpr_(const ShuffleNode *op, + std::ostream &os) { // NOLINT(*) + // For bfloat16x2 / float16x2 construction from two scalar lanes, emit the + // HIP pack intrinsic instead of the invalid `uint1(a, b)` that CodeGenC + // would generate (HIP's uint1 has no two-argument constructor). + // `uint1{value}` aggregate initialisation is valid: uint1 is defined by ROCm + // as HIP_vector_type which has a single .x member. + DataType t = op->dtype; + bool is_bf16x2 = t.is_bfloat16() && t.lanes() == 2; + bool is_fp16x2 = t.is_float16() && t.lanes() == 2; + if ((is_bf16x2 || is_fp16x2) && op->vectors.size() == 2 && + op->vectors[0].dtype().lanes() == 1 && + op->vectors[1].dtype().lanes() == 1) { + std::string e0 = PrintExpr(op->vectors[0]); + std::string e1 = PrintExpr(op->vectors[1]); + if (is_bf16x2) { + os << "uint1{__pack_bfloat162(" << e0 << ", " << e1 << ")}"; + } else { + os << "uint1{__pack_half2(" << e0 << ", " << e1 << ")}"; + } + return; + } + // Default path for all other shuffle patterns. + CodeGenC::VisitExpr_(op, os); +} + void CodeGenTileLangHIP::VisitExpr_(const CastNode *op, std::ostream &os) { DataType from_ty = op->value.dtype(); DataType target_ty = op->dtype; @@ -836,6 +866,15 @@ std::string CodeGenTileLangHIP::GetBufferRef(DataType t, buffer_str = temp.str(); } + if (scope.empty()) { + scope = GetPtrStorageScope(buffer->data); + } + // local.var is a scalar — no indexing needed. + if (scope == "local.var") { + os << vid; + return os.str(); + } + std::string index_str = PrintExpr(index); if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { // This is a special case, because CodegenCUDA::PrintType() @@ -940,6 +979,14 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::pack_b16())) { os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::sync_grid())) { + this->need_cooperative_groups_ = true; + this->PrintIndent(); + this->stream << "cooperative_groups::this_grid().sync();\n"; + } else if (op->op.same_as(tl::sync_warp())) { + // AMD wavefronts execute in lockstep, so intra-wavefront convergence is + // guaranteed by the hardware. __syncwarp() has no HIP equivalent and is a + // no-op here. The mask argument (if present) is intentionally ignored. } else if (op->op.same_as(tl::any_sync())) { ICHECK_EQ(op->args.size(), 2U) << "tl.any_sync expects ."; // HIP __any takes only the predicate; the mask is ignored because @@ -1437,7 +1484,23 @@ void CodeGenTileLangHIP::VisitStmt_(const AllocateNode *op) { scope == "shared") { constant_size = constant_size / (32 / op->dtype.bits()); } - stream << ' ' << vid << '[' << constant_size << "];\n"; + + if (scope == "local.var") { + // Single-element variable: emit an initializer so the value is defined. + // Default to 0; respect the user-provided tl.local_var_init annotation. + PrimExpr init = tir::make_const(op->dtype, 0); + auto init_it = op->annotations.find(tl::attr::kLocalVarInit); + if (init_it != op->annotations.end()) { + PrimExpr user_init = Downcast((*init_it).second); + if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) { + user_init = tir::Cast(op->dtype, user_init); + } + init = user_init; + } + stream << ' ' << vid << " = " << PrintExpr(init) << ";\n"; + } else { + stream << ' ' << vid << '[' << constant_size << "];\n"; + } } RegisterHandleType(op->buffer_var.get(), op->dtype); diff --git a/src/target/codegen_hip.h b/src/target/codegen_hip.h index 0dfef6d609..1030352e95 100644 --- a/src/target/codegen_hip.h +++ b/src/target/codegen_hip.h @@ -49,6 +49,7 @@ class CodeGenTileLangHIP final : public CodeGenC { void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; void VisitExpr_(const CallNode *op, std::ostream &os) final; void VisitExpr_(const CastNode *op, std::ostream &os) final; + void VisitExpr_(const ShuffleNode *op, std::ostream &os) final; // NOLINT(*) void VisitStmt_(const AllocateNode *op) final; void VisitStmt_(const AttrStmtNode *op) final; @@ -73,6 +74,8 @@ class CodeGenTileLangHIP final : public CodeGenC { friend void PrintConst(const FloatImmNode *op, std::ostream &os, CodeGenTileLangHIP *p); + // whether need hip_cooperative_groups.h + bool need_cooperative_groups_{false}; // whether need math_constants.h bool need_math_constants_h_{false}; // whether need mfma.h diff --git a/src/target/rt_mod_hip.cc b/src/target/rt_mod_hip.cc index 63d7cea5b1..e4b45b5ddb 100644 --- a/src/target/rt_mod_hip.cc +++ b/src/target/rt_mod_hip.cc @@ -41,6 +41,10 @@ ExtractFuncInfo(const IRModule &mod) { dtype = DataType::Int(32); info.arg_types.push_back(dtype); } + if (f->HasNonzeroAttr("use_cooperative_groups")) { + info.launch_param_tags.push_back( + runtime::launch_param::kUseCooperativeLaunch); + } if (auto opt = f->GetAttr>( tir::attr::kKernelLaunchParams)) { for (const auto &tag : opt.value()) { diff --git a/src/target/stubs/hip.cc b/src/target/stubs/hip.cc index 4725e71c94..131b123f57 100644 --- a/src/target/stubs/hip.cc +++ b/src/target/stubs/hip.cc @@ -139,6 +139,7 @@ HIPDriverAPI CreateHIPDriverAPI() { LOOKUP(hipModuleGetFunction_, "hipModuleGetFunction") LOOKUP(hipModuleGetGlobal_, "hipModuleGetGlobal") LOOKUP(hipModuleLaunchKernel_, "hipModuleLaunchKernel") + LOOKUP(hipModuleLaunchCooperativeKernel_, "hipModuleLaunchCooperativeKernel") #undef LOOKUP return api; @@ -397,6 +398,17 @@ hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, sharedMemBytes, stream, kernelParams, extra); } +hipError_t hipModuleLaunchCooperativeKernel( + hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t stream, + void **kernelParams) { + // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) + return HIPDriverAPI::get()->hipModuleLaunchCooperativeKernel_( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, + sharedMemBytes, stream, kernelParams); +} + // --- Minimal HSA wrappers // ------------------------------------------------------- diff --git a/src/target/stubs/hip.h b/src/target/stubs/hip.h index 68208a1a8f..5030e2e7dc 100644 --- a/src/target/stubs/hip.h +++ b/src/target/stubs/hip.h @@ -80,7 +80,8 @@ _(hipModuleUnload) \ _(hipModuleGetFunction) \ _(hipModuleGetGlobal) \ - _(hipModuleLaunchKernel) + _(hipModuleLaunchKernel) \ + _(hipModuleLaunchCooperativeKernel) namespace tvm::tl::hip { @@ -132,6 +133,11 @@ struct TILELANG_HIP_STUB_API HIPDriverAPI { unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **); + hipError_t (*hipModuleLaunchCooperativeKernel_)(hipFunction_t, unsigned int, + unsigned int, unsigned int, + unsigned int, unsigned int, + unsigned int, unsigned int, + hipStream_t, void **); static HIPDriverAPI *get(); static bool is_available(); @@ -211,4 +217,10 @@ TILELANG_HIP_STUB_API hipError_t hipModuleLaunchKernel( unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t stream, void **kernelParams, void **extra); +TILELANG_HIP_STUB_API hipError_t hipModuleLaunchCooperativeKernel( + hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t stream, + void **kernelParams); + } // extern "C" diff --git a/src/tl_templates/hip/common.h b/src/tl_templates/hip/common.h index 6b2da95a37..49c5b6c1e4 100644 --- a/src/tl_templates/hip/common.h +++ b/src/tl_templates/hip/common.h @@ -111,6 +111,27 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) { return (v1 << 16) | v0; } +// __habs overloads for hip_bfloat16 and float16_t to resolve ambiguity on ROCm. +// hip_bfloat16 != __hip_bfloat16, and float16_t != __half, so the standard +// __habs overloads don't match exactly, causing ambiguous overload errors. +// Use __builtin_memcpy instead of reinterpret_cast to avoid strict-aliasing UB. +__device__ __forceinline__ hip_bfloat16 __habs(hip_bfloat16 a) { + uint16_t bits; + __builtin_memcpy(&bits, &a, sizeof(bits)); + bits &= 0x7FFFu; + hip_bfloat16 result; + __builtin_memcpy(&result, &bits, sizeof(result)); + return result; +} +__device__ __forceinline__ float16_t __habs(float16_t a) { + uint16_t bits; + __builtin_memcpy(&bits, &a, sizeof(bits)); + bits &= 0x7FFFu; + float16_t result; + __builtin_memcpy(&result, &bits, sizeof(result)); + return result; +} + namespace tl { // Packed x2 element-wise math helpers (HIP scalar fallbacks) @@ -167,6 +188,59 @@ TL_DEVICE float2 abs2(float2 a) { return out; } +// Packed bfloat16x2 overloads for uint1 carrier. +// On HIP, uint1 = HIP_vector_type (32-bit), already defined +// by ROCm via amd_hip_vector_types.h — no additional typedef needed. +// A packed bfloat16x2 word layout: +// bits [15: 0] = first bfloat16 (sign at bit 15) +// bits [31:16] = second bfloat16 (sign at bit 31) +// These overloads are required by the HIP codegen's ShuffleNode packing path +// (VisitExpr_ ShuffleNode emits uint1{__pack_bfloat162(a, b)}). +TL_DEVICE uint1 abs2(uint1 val) { + // Clear both sign bits simultaneously. + return uint1{val.x & 0x7FFF7FFFu}; +} +TL_DEVICE uint1 max2(uint1 a, uint1 b) { + bfloat16_t a0, a1, b0, b1; + __builtin_memcpy(&a0, &a.x, sizeof(a0)); + __builtin_memcpy(&a1, (char *)&a.x + 2, sizeof(a1)); + __builtin_memcpy(&b0, &b.x, sizeof(b0)); + __builtin_memcpy(&b1, (char *)&b.x + 2, sizeof(b1)); + bfloat16_t r0 = (float)a0 > (float)b0 ? a0 : b0; + bfloat16_t r1 = (float)a1 > (float)b1 ? a1 : b1; + return uint1{__pack_bfloat162(r0, r1)}; +} +TL_DEVICE uint1 min2(uint1 a, uint1 b) { + bfloat16_t a0, a1, b0, b1; + __builtin_memcpy(&a0, &a.x, sizeof(a0)); + __builtin_memcpy(&a1, (char *)&a.x + 2, sizeof(a1)); + __builtin_memcpy(&b0, &b.x, sizeof(b0)); + __builtin_memcpy(&b1, (char *)&b.x + 2, sizeof(b1)); + bfloat16_t r0 = (float)a0 < (float)b0 ? a0 : b0; + bfloat16_t r1 = (float)a1 < (float)b1 ? a1 : b1; + return uint1{__pack_bfloat162(r0, r1)}; +} +TL_DEVICE uint1 add2(uint1 a, uint1 b) { + bfloat16_t a0, a1, b0, b1; + __builtin_memcpy(&a0, &a.x, sizeof(a0)); + __builtin_memcpy(&a1, (char *)&a.x + 2, sizeof(a1)); + __builtin_memcpy(&b0, &b.x, sizeof(b0)); + __builtin_memcpy(&b1, (char *)&b.x + 2, sizeof(b1)); + bfloat16_t r0 = (bfloat16_t)((float)a0 + (float)b0); + bfloat16_t r1 = (bfloat16_t)((float)a1 + (float)b1); + return uint1{__pack_bfloat162(r0, r1)}; +} +TL_DEVICE uint1 mul2(uint1 a, uint1 b) { + bfloat16_t a0, a1, b0, b1; + __builtin_memcpy(&a0, &a.x, sizeof(a0)); + __builtin_memcpy(&a1, (char *)&a.x + 2, sizeof(a1)); + __builtin_memcpy(&b0, &b.x, sizeof(b0)); + __builtin_memcpy(&b1, (char *)&b.x + 2, sizeof(b1)); + bfloat16_t r0 = (bfloat16_t)((float)a0 * (float)b0); + bfloat16_t r1 = (bfloat16_t)((float)a1 * (float)b1); + return uint1{__pack_bfloat162(r0, r1)}; +} + // Any template TL_DEVICE bool Any(T *a, int size) { for (int i = 0; i < size; i++) { diff --git a/src/tl_templates/hip/hip_fp8.h b/src/tl_templates/hip/hip_fp8.h index 224f8ff59d..ee9e3237bf 100644 --- a/src/tl_templates/hip/hip_fp8.h +++ b/src/tl_templates/hip/hip_fp8.h @@ -199,28 +199,33 @@ struct __align__(16) fp8_e8_16_t { __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z, fp8_e4_t w) { - // reinterpret the 4 fp8_e4_t values to signed char value and shift - signed char x_char = *reinterpret_cast(&x); - signed char y_char = *reinterpret_cast(&y); - signed char z_char = *reinterpret_cast(&z); - signed char w_char = *reinterpret_cast(&w); - int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char; + // reinterpret the 4 fp8_e4_t values to unsigned char to avoid sign extension + // on shift + unsigned char x_char = *reinterpret_cast(&x); + unsigned char y_char = *reinterpret_cast(&y); + unsigned char z_char = *reinterpret_cast(&z); + unsigned char w_char = *reinterpret_cast(&w); + unsigned int res = ((unsigned int)w_char << 24) | + ((unsigned int)z_char << 16) | + ((unsigned int)y_char << 8) | (unsigned int)x_char; return *reinterpret_cast(&res); } __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z, fp8_e4_t w, fp8_e4_t v, fp8_e4_t u, fp8_e4_t t, fp8_e4_t s) { - signed char x_char = *reinterpret_cast(&x); - signed char y_char = *reinterpret_cast(&y); - signed char z_char = *reinterpret_cast(&z); - signed char w_char = *reinterpret_cast(&w); - signed char v_char = *reinterpret_cast(&v); - signed char u_char = *reinterpret_cast(&u); - signed char t_char = *reinterpret_cast(&t); - signed char s_char = *reinterpret_cast(&s); - int a = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char; - int b = (s_char << 24) | (t_char << 16) | (u_char << 8) | v_char; + unsigned char x_char = *reinterpret_cast(&x); + unsigned char y_char = *reinterpret_cast(&y); + unsigned char z_char = *reinterpret_cast(&z); + unsigned char w_char = *reinterpret_cast(&w); + unsigned char v_char = *reinterpret_cast(&v); + unsigned char u_char = *reinterpret_cast(&u); + unsigned char t_char = *reinterpret_cast(&t); + unsigned char s_char = *reinterpret_cast(&s); + unsigned int a = ((unsigned int)w_char << 24) | ((unsigned int)z_char << 16) | + ((unsigned int)y_char << 8) | (unsigned int)x_char; + unsigned int b = ((unsigned int)s_char << 24) | ((unsigned int)t_char << 16) | + ((unsigned int)u_char << 8) | (unsigned int)v_char; fp8_e4_8_t res; res.x = *reinterpret_cast(&a); res.y = *reinterpret_cast(&b); @@ -233,26 +238,34 @@ __device__ fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3, fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) { - signed char x0_char = *reinterpret_cast(&x0); - signed char x1_char = *reinterpret_cast(&x1); - signed char x2_char = *reinterpret_cast(&x2); - signed char x3_char = *reinterpret_cast(&x3); - signed char x4_char = *reinterpret_cast(&x4); - signed char x5_char = *reinterpret_cast(&x5); - signed char x6_char = *reinterpret_cast(&x6); - signed char x7_char = *reinterpret_cast(&x7); - signed char y0_char = *reinterpret_cast(&y0); - signed char y1_char = *reinterpret_cast(&y1); - signed char y2_char = *reinterpret_cast(&y2); - signed char y3_char = *reinterpret_cast(&y3); - signed char y4_char = *reinterpret_cast(&y4); - signed char y5_char = *reinterpret_cast(&y5); - signed char y6_char = *reinterpret_cast(&y6); - signed char y7_char = *reinterpret_cast(&y7); - int a = (x3_char << 24) | (x2_char << 16) | (x1_char << 8) | x0_char; - int b = (x7_char << 24) | (x6_char << 16) | (x5_char << 8) | x4_char; - int c = (y3_char << 24) | (y2_char << 16) | (y1_char << 8) | y0_char; - int d = (y7_char << 24) | (y6_char << 16) | (y5_char << 8) | y4_char; + unsigned char x0_char = *reinterpret_cast(&x0); + unsigned char x1_char = *reinterpret_cast(&x1); + unsigned char x2_char = *reinterpret_cast(&x2); + unsigned char x3_char = *reinterpret_cast(&x3); + unsigned char x4_char = *reinterpret_cast(&x4); + unsigned char x5_char = *reinterpret_cast(&x5); + unsigned char x6_char = *reinterpret_cast(&x6); + unsigned char x7_char = *reinterpret_cast(&x7); + unsigned char y0_char = *reinterpret_cast(&y0); + unsigned char y1_char = *reinterpret_cast(&y1); + unsigned char y2_char = *reinterpret_cast(&y2); + unsigned char y3_char = *reinterpret_cast(&y3); + unsigned char y4_char = *reinterpret_cast(&y4); + unsigned char y5_char = *reinterpret_cast(&y5); + unsigned char y6_char = *reinterpret_cast(&y6); + unsigned char y7_char = *reinterpret_cast(&y7); + unsigned int a = ((unsigned int)x3_char << 24) | + ((unsigned int)x2_char << 16) | + ((unsigned int)x1_char << 8) | (unsigned int)x0_char; + unsigned int b = ((unsigned int)x7_char << 24) | + ((unsigned int)x6_char << 16) | + ((unsigned int)x5_char << 8) | (unsigned int)x4_char; + unsigned int c = ((unsigned int)y3_char << 24) | + ((unsigned int)y2_char << 16) | + ((unsigned int)y1_char << 8) | (unsigned int)y0_char; + unsigned int d = ((unsigned int)y7_char << 24) | + ((unsigned int)y6_char << 16) | + ((unsigned int)y5_char << 8) | (unsigned int)y4_char; fp8_e4_8_t res_x; res_x.x = *reinterpret_cast(&a); res_x.y = *reinterpret_cast(&b); diff --git a/src/tl_templates/hip/reduce.h b/src/tl_templates/hip/reduce.h index 7185585ee9..eaf41b6a91 100644 --- a/src/tl_templates/hip/reduce.h +++ b/src/tl_templates/hip/reduce.h @@ -261,12 +261,34 @@ template struct CumSum2D { template TL_DEVICE T warp_reduce(T value, ReduceOp op) { - value = op(value, __shfl_xor(value, 32)); - value = op(value, __shfl_xor(value, 16)); - value = op(value, __shfl_xor(value, 8)); - value = op(value, __shfl_xor(value, 4)); - value = op(value, __shfl_xor(value, 2)); - value = op(value, __shfl_xor(value, 1)); + // 5-step butterfly reduction with width=32, matching CUDA's 32-lane warp + // semantics on CDNA (wave64) and RDNA (wave32) targets. + // + // On CDNA (wave64, 64-lane wavefronts) with threads=32 per block: + // Only lanes 0-31 are active; lanes 32-63 hold uninitialised VGPRs. + // The old step `__shfl_xor(value, 32)` without a width argument read + // from those uninitialised lanes, producing NaN or garbage. + // width=32 confines every shuffle to the [0,31] group, preventing this. + // + // On CDNA with threads=64 (one full wave64): + // width=32 splits the wavefront into two independent 32-lane groups. + // Lane 0 of each group holds the group partial sum, which is exactly what + // kernels that assume logical warp_size=32 expect for their inter-warp + // shared-memory reductions. + // + // On RDNA (wave32, 32-lane wavefronts): + // width=32 equals the wavefront size, so behaviour is identical to + // omitting the width argument. + // + // Note: this intentionally preserves 32-lane logical-warp semantics for + // backward compatibility. Full wave64 utilisation (6-step, width=64) would + // require restructuring inter-warp shared-memory communication in all + // kernels and is deferred to a separate optimisation pass. + value = op(value, __shfl_xor(value, 16, 32)); + value = op(value, __shfl_xor(value, 8, 32)); + value = op(value, __shfl_xor(value, 4, 32)); + value = op(value, __shfl_xor(value, 2, 32)); + value = op(value, __shfl_xor(value, 1, 32)); return value; } diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index 191f3a93ca..6911a73ce9 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -1074,6 +1074,23 @@ class PipelinePlanner : public StmtExprMutator { if (!num_stages_anno) return StmtExprMutator::VisitStmt_(loop); int num_stages = num_stages_anno->as()->value; + // On HIP/ROCM, skip software pipelining for multi-stage loops. + // Double-buffering (num_stages > 1) doubles the LDS allocation for every + // shared buffer declared inside the loop body, which easily exhausts the + // per-workgroup LDS limit and causes hipModuleLaunchKernel to return + // HIPERRORINVALIDVALUE. LDS limits by generation: + // gfx942 (CDNA3 / MI300X): 64 KB per workgroup + // gfx950 (CDNA4 / MI350): 160 KB per workgroup (see PR #2058) + // Even with the larger gfx950 budget, double-buffering a pair of large + // shared tiles (e.g. bM×bK + bK×bN in bf16) can exceed 160 KB, and the + // HIP async-copy infrastructure used by the CUDA pipeline planner has no + // equivalent on ROCM today. Setting num_stages = 1 inside the planner + // generates incorrect copy loops, so the safest fallback is a plain + // sequential loop where T.copy executes synchronously without any + // async-copy or double-buffering infrastructure. + if (TargetIsRocm(target_) && num_stages > 1) { + return StmtExprMutator::VisitStmt_(loop); + } Stmt pipeline_body_root{nullptr}; if (const auto *realize = loop->body.as()) { const auto &block = realize->block; diff --git a/testing/python/amd/test_tilelang_hip_codegen.py b/testing/python/amd/test_tilelang_hip_codegen.py new file mode 100644 index 0000000000..203f5f1653 --- /dev/null +++ b/testing/python/amd/test_tilelang_hip_codegen.py @@ -0,0 +1,757 @@ +""" +Regression tests for HIP/AMD codegen fixes in TileLang. + +Covers six bug fixes across five source files: + + Fix 1 (reduce.h) warp_reduce 5-step butterfly with width=32 + Fix 2 (codegen_hip.cc, ShuffleNode bfloat16x2/float16x2 packing; + common.h) uint1 bf16x2 math overloads + Fix 3 (allocate.py, T.alloc_var(init=) emits a correctly + codegen_hip.cc) initialised scalar on HIP + Fix 4 (codegen_hip.cc) T.sync_warp() lowered to no-op on HIP + Fix 5 (codegen_hip.cc, T.sync_grid() lowered to cooperative groups + rt_mod_hip.cc, grid barrier; runtime launch infrastructure + stubs/) added + Fix 6 (pipeline_planning.cc) T.Pipelined(num_stages>1) falls back to a + plain sequential loop on ROCM to avoid LDS + overflow (hipModuleLaunchKernel EINVAL) +""" + +import pytest +import torch +import tilelang +import tilelang.testing +import tilelang.language as T + + +# --------------------------------------------------------------------------- +# Fix 1 — src/tl_templates/hip/reduce.h +# warp_reduce: 5-step butterfly with explicit width=32 +# +# Symptom: On CDNA (wave64) with 32 active threads the old 6-step butterfly +# called __shfl_xor(value, 32) without a width argument, reading uninitialised +# VGPRs in lanes 32-63. This produced NaN or garbage in every reduction that +# went through warp_reduce (reduce_max, reduce_sum, AllReduce). +# +# Fix: remove the step-32 shuffle; add width=32 to every remaining step. +# __shfl_xor(v, N, 32) restricts the butterfly to the lower 32-lane group, +# matching CUDA warp semantics on CDNA wave64 and RDNA wave32 alike. +# With 64 threads and width=32 the wavefront splits into two independent +# 32-lane groups — correct for kernels that assume logical warp_size=32. +# --------------------------------------------------------------------------- + + +@tilelang.testing.requires_rocm +@pytest.mark.parametrize("n_tokens,n_experts", [(64, 8), (128, 16), (512, 32)]) +def test_warp_reduce_no_nan(n_tokens, n_experts): + """ + 32-thread-per-block reduce_max / reduce_sum must not produce NaN on CDNA. + + Old: __shfl_xor(v, 32) with 32 active threads reads uninit VGPRs → NaN. + New: 5-step with width=32 stays in [0,31] group → correct, no NaN. + """ + assert n_experts <= 32 + + @tilelang.jit + def gate_reduce(n_tok: int, n_exp: int): + @T.prim_func + def kernel( + logits: T.Tensor((n_tok, n_exp), T.float32), + out_max: T.Tensor((n_tok,), T.float32), + out_sum: T.Tensor((n_tok,), T.float32), + ) -> None: + with T.Kernel(n_tok, threads=32) as pid: + lf = T.alloc_fragment(n_exp, T.float32) + T.copy(logits[pid, 0], lf) + mx = T.alloc_fragment(1, T.float32) + T.reduce_max(lf, mx, dim=0) + sm = T.alloc_fragment(1, T.float32) + T.reduce_sum(lf, sm, dim=0) + if T.get_thread_binding() == 0: + out_max[pid] = mx[0] + out_sum[pid] = sm[0] + + return kernel + + logits = torch.randn(n_tokens, n_experts, dtype=torch.float32, device="cuda") + out_max = torch.zeros(n_tokens, dtype=torch.float32, device="cuda") + out_sum = torch.zeros(n_tokens, dtype=torch.float32, device="cuda") + + gate_reduce(n_tokens, n_experts)(logits, out_max, out_sum) + torch.cuda.synchronize() + + assert not out_max.isnan().any(), "reduce_max NaN — __shfl_xor(v,32) uninit VGPR bug" + assert not out_sum.isnan().any(), "reduce_sum NaN — __shfl_xor(v,32) uninit VGPR bug" + torch.testing.assert_close(out_max, logits.max(dim=1).values, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(out_sum, logits.sum(dim=1), atol=1e-4, rtol=1e-4) + + +@tilelang.testing.requires_rocm +def test_warp_reduce_correctness_32_threads(): + """ + 32-thread reduce_sum over 32 elements must return the exact sum on CDNA. + + Exercises the warp-level shuffle path directly. With the old step-32 + shuffle, uninitialised VGPR reads on CDNA produced garbage. + """ + N = 32 + + @tilelang.jit + def reduce_kernel(): + @T.prim_func + def kernel( + x: T.Tensor((N,), T.float32), + out: T.Tensor((1,), T.float32), + ) -> None: + with T.Kernel(1, threads=N) as _: + frag = T.alloc_fragment((N,), T.float32) + T.copy(x, frag) + s = T.alloc_fragment((1,), T.float32) + T.reduce_sum(frag, s, dim=0) + if T.get_thread_binding() == 0: + out[0] = s[0] + + return kernel + + x = torch.arange(1, N + 1, dtype=torch.float32, device="cuda") + out = torch.zeros(1, dtype=torch.float32, device="cuda") + reduce_kernel()(x, out) + torch.cuda.synchronize() + + assert not out[0].isnan(), "reduce_sum NaN — warp_reduce VGPR bug on CDNA" + torch.testing.assert_close(out[0], x.sum(), atol=1e-4, rtol=1e-4) + + +@tilelang.testing.requires_rocm +def test_warp_reduce_with_64_threads_two_groups(): + """ + With 64 threads and width=32 the wavefront splits into two independent + 32-lane groups — each group's lane 0 holds its partial sum. + + Old: __shfl_xor(v, 32) without width mixed the groups → wrong result. + New: width=32 confines each shuffle to its own 32-lane group → correct. + """ + N, n_exp = 64, 4 + + @tilelang.jit + def two_warp_reduce(): + @T.prim_func + def kernel( + x: T.Tensor((N, n_exp), T.float32), + out: T.Tensor((N,), T.float32), + ) -> None: + with T.Kernel(1, threads=N) as _: + tx = T.get_thread_binding() + frag = T.alloc_fragment(n_exp, T.float32) + T.copy(x[tx, 0], frag) + s = T.alloc_fragment(1, T.float32) + T.reduce_sum(frag, s, dim=0) + out[tx] = s[0] + + return kernel + + x = torch.ones(N, n_exp, dtype=torch.float32, device="cuda") + out = torch.zeros(N, dtype=torch.float32, device="cuda") + two_warp_reduce()(x, out) + torch.cuda.synchronize() + + assert not out.isnan().any(), "NaN in two-warp reduce — width=32 fix not applied" + + +# --------------------------------------------------------------------------- +# Fix 2 — src/target/codegen_hip.cc (VisitExpr_ ShuffleNode) +# src/tl_templates/hip/common.h (uint1 bfloat16x2 math overloads) +# +# Symptom: Packing two bfloat16 scalars into a bfloat16x2 ShuffleNode caused +# CodeGenC to emit `uint1(a, b)` — invalid HIP constructor → compile error. +# +# Fix (codegen_hip.cc): Override VisitExpr_(ShuffleNode) to emit +# `uint1{__pack_bfloat162(a, b)}` / `uint1{__pack_half2(a, b)}`. +# Fix (common.h): Add abs2/max2/min2/add2/mul2 overloads for uint1 as a +# packed bfloat16x2 carrier. +# --------------------------------------------------------------------------- + + +@tilelang.testing.requires_rocm +def test_bfloat16_shuffle_codegen_and_correctness(): + """ + bfloat16 fragment warp-reduction: source must use __pack_bfloat162 + (not invalid `uint1(a,b)`) and the result must be numerically correct. + """ + n_tok, n_exp = 16, 8 + + @tilelang.jit + def bf16_reduce(n_t: int, n_e: int): + @T.prim_func + def kernel( + x: T.Tensor((n_t, n_e), T.bfloat16), + out: T.Tensor((n_t,), T.float32), + ) -> None: + with T.Kernel(n_t, threads=32) as pid: + frag = T.alloc_fragment(n_e, T.bfloat16) + T.copy(x[pid, 0], frag) + frag_f32 = T.alloc_fragment(n_e, T.float32) + for i in T.Parallel(n_e): + frag_f32[i] = T.cast(frag[i], T.float32) + s = T.alloc_fragment(1, T.float32) + T.reduce_sum(frag_f32, s, dim=0) + if T.get_thread_binding() == 0: + out[pid] = s[0] + + return kernel + + kernel = bf16_reduce(n_tok, n_exp) + + # Source check: no invalid two-argument constructor + src = kernel.get_kernel_source() + assert "uint1(a" not in src and "uint1(b" not in src, "Old `uint1(a, b)` constructor found — ShuffleNode fix not applied" + + # Runtime correctness + x = torch.randn(n_tok, n_exp, dtype=torch.bfloat16, device="cuda") + out = torch.zeros(n_tok, dtype=torch.float32, device="cuda") + kernel(x, out) + torch.cuda.synchronize() + assert not out.isnan().any(), "bf16 ShuffleNode reduction NaN" + torch.testing.assert_close(out, x.float().sum(dim=1), atol=5e-2, rtol=1e-2) + + +@tilelang.testing.requires_rocm +def test_float16_shuffle_correctness(): + """ + float16 fragment warp-reduction exercises the __pack_half2 path. + Analogous to the bfloat16 test but for float16x2 packing. + """ + n_tok, n_exp = 64, 8 + + @tilelang.jit + def f16_reduce(n_t: int, n_e: int): + @T.prim_func + def kernel( + x: T.Tensor((n_t, n_e), T.float16), + out: T.Tensor((n_t,), T.float32), + ) -> None: + with T.Kernel(n_t, threads=32) as pid: + frag = T.alloc_fragment(n_e, T.float16) + T.copy(x[pid, 0], frag) + frag_f32 = T.alloc_fragment(n_e, T.float32) + for i in T.Parallel(n_e): + frag_f32[i] = T.cast(frag[i], T.float32) + s = T.alloc_fragment(1, T.float32) + T.reduce_sum(frag_f32, s, dim=0) + if T.get_thread_binding() == 0: + out[pid] = s[0] + + return kernel + + x = torch.randn(n_tok, n_exp, dtype=torch.float16, device="cuda") + out = torch.zeros(n_tok, dtype=torch.float32, device="cuda") + f16_reduce(n_tok, n_exp)(x, out) + torch.cuda.synchronize() + assert not out.isnan().any(), "float16 ShuffleNode reduction NaN" + torch.testing.assert_close(out, x.float().sum(dim=1), atol=1e-1, rtol=1e-2) + + +# --------------------------------------------------------------------------- +# Fix 3 — tilelang/language/allocate.py + src/target/codegen_hip.cc +# T.alloc_var(init=) initialisation on HIP; +# local.var scalar declaration and GetBufferRef bare-name return +# +# Symptom (allocate.py): int/float literals used block_attr("tl.local_var_init") +# which the HIP backend silently ignored → variable uninitialised at runtime. +# Symptom (codegen_hip.cc): AllocateNode emitted `type vid[1];` for local.var; +# alloc_storage_scope_ was not updated → GetBufferRef fell through to an +# invalid pointer-cast path → compile failure. +# +# Fix (allocate.py): always route init through T.buffer_store → explicit +# BufferStore TIR node → assignment statement in every backend. +# Fix (codegen_hip.cc): emit `type vid = init;`; register alloc_storage_scope_ +# so GetBufferRef returns the bare name `vid`. +# --------------------------------------------------------------------------- + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + } +) +def _kernel_alloc_var_init(): + """Kernel that initialises a local int32 variable to 7 and writes it out.""" + + @T.prim_func + def main(Out: T.Tensor((64,), "int32")): + with T.Kernel(1, threads=64): + tx = T.get_thread_binding() + counter = T.alloc_var(T.int32, init=7) + Out[tx] = counter + + return main + + +@tilelang.testing.requires_rocm +def test_alloc_var_init_in_hip_source(): + """Init value must appear as `= 7;` in the generated HIP source.""" + src = _kernel_alloc_var_init().get_kernel_source() + assert "= 7;" in src, ( + f"T.alloc_var(T.int32, init=7) should generate '= 7;' in HIP source, but it was not found.\nGenerated source:\n{src}" + ) + + +@tilelang.testing.requires_rocm +def test_alloc_var_init_no_array_subscript_in_hip_source(): + """local.var must be declared as a scalar (no `counter[` array syntax).""" + src = _kernel_alloc_var_init().get_kernel_source() + assert "counter[" not in src, ( + f"local.var should be emitted as a scalar (e.g. 'int counter = 7'), but array-style access was found:\n{src}" + ) + + +@tilelang.testing.requires_rocm +def test_alloc_var_init_correctness(): + """All output elements must equal 7 — the initialised value.""" + out = torch.zeros(64, dtype=torch.int32, device="cuda") + _kernel_alloc_var_init()(out) + assert torch.all(out == 7), f"Expected all 7, got: {out}" + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + } +) +def _kernel_multi_alloc_var_init(): + """Two local variables with different init values, summed into output.""" + + @T.prim_func + def main(Out: T.Tensor((32,), "int32")): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + a = T.alloc_var(T.int32, init=3) + b = T.alloc_var(T.int32, init=4) + Out[tx] = a + b + + return main + + +@tilelang.testing.requires_rocm +def test_multi_alloc_var_init_in_hip_source(): + """Both init values must appear in the HIP source.""" + src = _kernel_multi_alloc_var_init().get_kernel_source() + assert src.count("= 3;") >= 1, f"Init value 3 not found in HIP source:\n{src}" + assert src.count("= 4;") >= 1, f"Init value 4 not found in HIP source:\n{src}" + + +@tilelang.testing.requires_rocm +def test_multi_alloc_var_init_correctness(): + """Sum of two initialised local variables must equal 7 (3+4).""" + out = torch.zeros(32, dtype=torch.int32, device="cuda") + _kernel_multi_alloc_var_init()(out) + assert torch.all(out == 7), f"Expected all 7 (3+4), got: {out}" + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + } +) +def _kernel_alloc_var_count(): + """Counter initialised to 0, incremented 5 times in a loop.""" + + @T.prim_func + def main(Out: T.Tensor((32,), "int32")): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + count = T.alloc_var(T.int32, init=0) + for _ in T.unroll(5): + count += 1 + Out[tx] = count + + return main + + +@tilelang.testing.requires_rocm +def test_alloc_var_zero_init_correctness(): + """Variable initialised to 0 and incremented 5 times must equal 5.""" + out = torch.zeros(32, dtype=torch.int32, device="cuda") + _kernel_alloc_var_count()(out) + assert torch.all(out == 5), f"Expected all 5, got: {out}" + + +@tilelang.testing.requires_rocm +@pytest.mark.parametrize( + "init_val,dtype_str", + [ + (0, "int32"), + (7, "int32"), + (-3, "int32"), + (0.0, "float32"), + (1.0, "float32"), + (-0.5, "float32"), + ], +) +def test_alloc_var_literal_init_is_reliable(init_val, dtype_str): + """ + alloc_var with any literal init must produce that exact value on HIP. + + Old: int/float literals → block_attr (silently ignored) → uninit. + New: always T.buffer_store → `vid = init_val;` in generated HIP C. + """ + tl_dtype = T.int32 if dtype_str == "int32" else T.float32 + torch_dtype = torch.int32 if dtype_str == "int32" else torch.float32 + N = 32 + + @tilelang.jit + def var_init_kernel(iv, tld): + @T.prim_func + def kernel(out: T.Tensor((N,), tld)) -> None: + with T.Kernel(1, threads=N) as _: + v = T.alloc_var(tld, init=iv) + for i in T.Parallel(N): + out[i] = v + + return kernel + + out = torch.zeros(N, dtype=torch_dtype, device="cuda") + var_init_kernel(init_val, tl_dtype)(out) + torch.cuda.synchronize() + + expected = torch.full((N,), init_val, dtype=torch_dtype, device="cuda") + if dtype_str == "int32": + assert torch.equal(out, expected), f"alloc_var(init={init_val}) got {out[0].item()}, expected {init_val}" + else: + torch.testing.assert_close(out, expected, atol=0, rtol=0) + + +@tilelang.testing.requires_rocm +def test_alloc_var_inf_init(): + """ + alloc_var(init=-T.infinity(T.float32)) — the pattern used for top1_var / + top2_var in MoE topk gate kernels — must produce -inf on HIP. + """ + N = 32 + + @tilelang.jit + def inf_init_kernel(): + @T.prim_func + def kernel(out: T.Tensor((N,), T.float32)) -> None: + with T.Kernel(1, threads=N) as _: + v = T.alloc_var(T.float32, init=-T.infinity(T.float32)) + for i in T.Parallel(N): + out[i] = v + + return kernel + + out = torch.zeros(N, dtype=torch.float32, device="cuda") + inf_init_kernel()(out) + torch.cuda.synchronize() + assert out.isinf().all() and (out < 0).all(), f"alloc_var(init=-inf) got {out[0].item()}, expected -inf" + + +@tilelang.testing.requires_rocm +def test_alloc_var_init_zero_persists_across_serial_loop(): + """ + count_var = T.alloc_var(T.int32, init=0) must start at 0 and accumulate + correctly. This is the exact pattern used by count_var in MoE kernels. + """ + N = 8 + + @tilelang.jit + def serial_count_kernel(): + @T.prim_func + def kernel(out: T.Tensor((1,), T.int32)) -> None: + with T.Kernel(1, threads=1) as _: + count_var = T.alloc_var(T.int32, init=0) + for _ in T.serial(N): + count_var = count_var + 1 + out[0] = count_var + + return kernel + + out = torch.zeros(1, dtype=torch.int32, device="cuda") + serial_count_kernel()(out) + torch.cuda.synchronize() + assert out[0].item() == N, f"count_var: got {out[0].item()}, expected {N} — init=0 not applied (block_attr bug)" + + +@tilelang.testing.requires_rocm +def test_local_var_scalar_codegen(): + """ + local.var must be emitted and accessed as a plain scalar on HIP. + + Before: alloc_storage_scope_ not registered → GetBufferRef fell through to + an invalid pointer-cast path → compile failure. + After: `type vid = init;` emitted; GetBufferRef returns bare `vid`. + """ + N = 32 + + @tilelang.jit + def local_var_scalar(): + @T.prim_func + def kernel(out: T.Tensor((N,), T.int32)) -> None: + with T.Kernel(1, threads=N) as _: + v = T.alloc_var(T.int32, init=5) + if T.get_thread_binding() == 0: + v = v + 1 + for i in T.Parallel(N): + out[i] = v + + return kernel + + out = torch.zeros(N, dtype=torch.int32, device="cuda") + local_var_scalar()(out) + torch.cuda.synchronize() + assert out[0].item() == 6, f"local.var scalar: got {out[0].item()}, expected 6 (5+1)" + + +@tilelang.testing.requires_rocm +def test_local_var_float_init_readable(): + """ + local.var with float32 literal init must be readable on HIP. + Before the alloc_storage_scope_ fix, GetBufferRef emitted invalid code. + """ + + @tilelang.jit + def float_init_readback(): + @T.prim_func + def kernel(out: T.Tensor((1,), T.float32)) -> None: + with T.Kernel(1, threads=32) as _: + v = T.alloc_var(T.float32, init=3.14) + if T.get_thread_binding() == 0: + out[0] = v + + return kernel + + out = torch.zeros(1, dtype=torch.float32, device="cuda") + float_init_readback()(out) + torch.cuda.synchronize() + torch.testing.assert_close(out[0].item(), 3.14, atol=1e-5, rtol=0) + + +# --------------------------------------------------------------------------- +# Fix 4 — src/target/codegen_hip.cc +# T.sync_warp() → no-op on HIP +# +# Symptom: tl::sync_warp() had no handler → codegen assertion failure or +# undefined symbol at link time. +# Fix: emit an empty statement; AMD wavefronts execute in lockstep so +# intra-wavefront convergence is guaranteed by hardware. +# --------------------------------------------------------------------------- + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + } +) +def _kernel_sync_warp_codegen(): + """Minimal kernel that exercises T.sync_warp().""" + + @T.prim_func + def main(A: T.Tensor((32,), "float32"), B: T.Tensor((32,), "float32")): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + A_shared = T.alloc_shared((32,), "float32") + A_shared[tx] = A[tx] + T.sync_warp() + B[tx] = A_shared[tx] * 2.0 + + return main + + +@tilelang.testing.requires_rocm +def test_sync_warp_no_syncwarp_in_hip_source(): + """__syncwarp must NOT appear in the generated HIP source.""" + src = _kernel_sync_warp_codegen().get_kernel_source() + assert "__syncwarp" not in src, f"T.sync_warp() should be a no-op on HIP, but __syncwarp was found in the generated source:\n{src}" + + +@tilelang.testing.requires_rocm +def test_sync_warp_correctness(): + """Kernel using T.sync_warp() must produce correct results on HIP.""" + A = torch.arange(32, dtype=torch.float32, device="cuda") + B = torch.zeros(32, dtype=torch.float32, device="cuda") + _kernel_sync_warp_codegen()(A, B) + torch.testing.assert_close(B, A * 2.0) + + +@tilelang.testing.requires_rocm +def test_sync_warp_inside_conditional(): + """ + T.sync_warp() inside a conditional branch (pattern from moe/common.py + get_topk_group_idx). Verifies compilation and deterministic output. + """ + N, M = 32, 8 + + @tilelang.jit + def sync_warp_cond_kernel(): + @T.prim_func + def kernel( + x: T.Tensor((N,), T.float32), + out: T.Tensor((M,), T.float32), + ) -> None: + with T.Kernel(1, threads=N) as _: + shmem = T.alloc_shared((M,), T.float32) + tx = T.get_thread_binding() + if tx < M: + shmem[tx] = x[tx] + T.sync_warp() + for i in T.Parallel(M): + out[i] = shmem[i] + + return kernel + + x = torch.randn(N, dtype=torch.float32, device="cuda") + out = torch.zeros(M, dtype=torch.float32, device="cuda") + sync_warp_cond_kernel()(x, out) + torch.cuda.synchronize() + torch.testing.assert_close(out, x[:M]) + + +# --------------------------------------------------------------------------- +# Fix 5 — src/target/codegen_hip.cc, src/target/rt_mod_hip.cc, +# src/target/stubs/hip.cc, src/target/stubs/hip.h +# T.sync_grid() → cooperative_groups::this_grid().sync() +# +# Symptom: tl::sync_grid() had no handler → same assertion / link failure. +# Fix: emit cooperative_groups call; add need_cooperative_groups_ flag to +# conditionally include hip_cooperative_groups.h; add runtime launch +# infrastructure (hipModuleLaunchCooperativeKernel stubs). +# --------------------------------------------------------------------------- + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + } +) +def _kernel_sync_grid_codegen(): + """Kernel that calls T.sync_grid() to trigger cooperative groups codegen.""" + + @T.prim_func + def main(A: T.Tensor((32,), "float32")): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding() + T.sync_grid() + A[tx] = T.float32(tx) + + return main + + +@tilelang.testing.requires_rocm +def test_sync_grid_cooperative_groups_in_hip_source(): + """ + T.sync_grid() must emit cooperative_groups::this_grid().sync() and + include in the generated HIP source. + + Note: runtime execution requires hipModuleLaunchCooperativeKernel which + is added via the stub infrastructure; this test validates codegen only. + """ + src = _kernel_sync_grid_codegen().get_kernel_source() + assert "this_grid().sync()" in src, f"T.sync_grid() should generate 'this_grid().sync()' but not found:\n{src}" + assert "cooperative_groups" in src, f"T.sync_grid() should include cooperative_groups but not found:\n{src}" + + +# --------------------------------------------------------------------------- +# Fix 6 — src/transform/pipeline_planning.cc +# Skip T.Pipelined(num_stages>1) pipeline planning on ROCM +# +# Symptom: Double-buffering doubled the LDS allocation for every shared buffer +# inside the loop body, exhausting the per-workgroup LDS budget and causing +# hipModuleLaunchKernel to return HIPERRORINVALIDVALUE. LDS limits: +# gfx942 (CDNA3 / MI300X): 64 KB; gfx950 (CDNA4 / MI350): 160 KB (#2058) +# Even with gfx950's larger budget, double-buffering large shared tiles can +# still exceed 160 KB, and the HIP async-copy infrastructure has no ROCM +# equivalent, so the planner cannot safely pipeline on any ROCM target. +# +# Fix: when TargetIsRocm() && num_stages > 1, skip pipeline planning and fall +# back to a plain sequential loop with synchronous T.copy — always LDS-safe. +# --------------------------------------------------------------------------- + + +@tilelang.testing.requires_rocm +@pytest.mark.parametrize("num_stages", [1, 2, 3]) +def test_pipelined_no_lds_overflow(num_stages): + """ + T.Pipelined(num_stages=N) must not raise hipModuleLaunchKernel EINVAL and + must produce the correct result regardless of N. + + Old: num_stages=2 doubled LDS → EINVAL (64 KB on gfx942, 160 KB on gfx950). + New: multi-stage loops fall back to plain sequential on ROCM. + """ + M, K, blk = 32, 256, 64 + + @tilelang.jit + def pipelined_rowsum(n_stages: int): + @T.prim_func + def kernel( + x: T.Tensor((M, K), T.float32), + out: T.Tensor((M,), T.float32), + ) -> None: + with T.Kernel(M, threads=64) as pid: + acc = T.alloc_fragment((1,), T.float32) + T.clear(acc) + for k in T.Pipelined(K // blk, num_stages=n_stages): + xs = T.alloc_shared((blk,), T.float32) + xl = T.alloc_fragment((blk,), T.float32) + T.copy(x[pid, k * blk], xs, disable_tma=True) + T.copy(xs, xl, disable_tma=True) + s = T.alloc_fragment((1,), T.float32) + T.reduce_sum(xl, s, dim=0) + acc[0] = acc[0] + s[0] + out[pid] = acc[0] + + return kernel + + x = torch.ones(M, K, dtype=torch.float32, device="cuda") + out = torch.zeros(M, dtype=torch.float32, device="cuda") + pipelined_rowsum(num_stages)(x, out) + torch.cuda.synchronize() + torch.testing.assert_close(out, torch.full((M,), float(K), device="cuda"), atol=1e-4, rtol=0) + + +@tilelang.testing.requires_rocm +@pytest.mark.parametrize("num_stages", [2, 3]) +def test_pipelined_multi_stage_fp16_gemm(num_stages): + """ + FP16 GEMM with T.Pipelined(num_stages>1) must launch and produce correct + results on ROCM — the most common pattern that triggered the LDS overflow + (A_s bM×bK + B_s bK×bN doubled per pipeline stage). + """ + M, N, K = 128, 128, 128 + bM, bN, bK = 64, 64, 32 + + @tilelang.jit + def fp16_gemm(n_stages: int): + @T.prim_func + def kernel( + A: T.Tensor((M, K), T.float16), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), T.float32), + ) -> None: + with T.Kernel(T.ceildiv(N, bN), T.ceildiv(M, bM), threads=128) as (bx, by): + A_s = T.alloc_shared((bM, bK), T.float16) + B_s = T.alloc_shared((bK, bN), T.float16) + C_l = T.alloc_fragment((bM, bN), T.float32) + T.clear(C_l) + for k in T.Pipelined(K // bK, num_stages=n_stages): + T.copy(A[by * bM, k * bK], A_s) + T.copy(B[k * bK, bx * bN], B_s) + T.gemm(A_s, B_s, C_l) + T.copy(C_l, C[by * bM, bx * bN]) + + return kernel + + A = torch.randn(M, K, dtype=torch.float16, device="cuda") + B = torch.randn(K, N, dtype=torch.float16, device="cuda") + C = torch.zeros(M, N, dtype=torch.float32, device="cuda") + fp16_gemm(num_stages)(A, B, C) + torch.cuda.synchronize() + torch.testing.assert_close(C, A.float() @ B.float(), atol=1.0, rtol=5e-2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 47de8c44ce..19d1ced762 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -138,10 +138,16 @@ def alloc_var(dtype: DType, *args, scope: str = "local.var", init: PrimExpr | in buffer = T.alloc_buffer([1], dtype, scope=parsed_scope) if parsed_init is not None: + # Always use T.buffer_store for reliable initialisation across all + # backends. The block_attr("tl.local_var_init") path feeds into the + # flatten_buffer transform which does not reliably emit initialiser + # code on some backends (e.g. HIP codegen silently drops the + # annotation for integer/float literals, leaving the scalar + # uninitialised). T.buffer_store emits an explicit BufferStore TIR + # node that every backend lowers to an assignment statement. if isinstance(parsed_init, (int, float, IntImm, FloatImm)): - block_attr({"tl.local_var_init": {buffer.data: tl_dtype(dtype)(parsed_init)}}) - else: - T.buffer_store(buffer, parsed_init, 0) + parsed_init = tl_dtype(dtype)(parsed_init) + T.buffer_store(buffer, parsed_init, 0) return buffer From a2f6a455e51f6b5dcad0d22537b7bb4326159711 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Mon, 27 Apr 2026 18:38:44 +0800 Subject: [PATCH 150/156] fix T.make_tensor buffer missing --- tilelang/engine/phase.py | 45 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index e00d16d2eb..53d95da32e 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -154,6 +154,40 @@ def _walk( return False +def module_has_runtime_pointer_tensor(mod: IRModule) -> bool: + """Detect ``T.make_tensor(, ...)`` style base addresses.""" + from tvm.ir import PointerType + from tvm.tir import stmt_functor + + for _, func in mod.functions.items(): + if not isinstance(func, tir.PrimFunc): + continue + + found = [False] + + def _check(node, _found=found): + if _found[0]: + return + if not isinstance(node, tir.LetStmt): + return + var = node.var + ann = getattr(var, "type_annotation", None) + if not isinstance(ann, PointerType): + return + scope = getattr(ann, "storage_scope", "") or "" + if scope != "global": + return + value = node.value + if isinstance(value, tir.Call) and getattr(value.op, "name", "") == "tir.reinterpret": + _found[0] = True + + stmt_functor.post_order_visit(func.body, _check) + if found[0]: + return True + + return False + + def allow_vectorize(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() @@ -320,12 +354,21 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectAssumes()(mod) # Simplify the IR expressions mod = tilelang.transform.Simplify()(mod) - if allow_autoschedule(target=target) and not module_uses_thread_var(mod) and not module_has_barrier(mod): + if ( + allow_autoschedule(target=target) + and not module_uses_thread_var(mod) + and not module_has_barrier(mod) + and not module_has_runtime_pointer_tensor(mod) + ): # Auto schedule for high-level operations. # Skip when the kernel already manages explicit mbarriers # (alloc_barrier / alloc_cluster_barrier), because reordering the # rewrites breaks invariants that later barrier lowering and the # WS / pipelined TMA copy pipeline rely on. + # Also skip when the kernel uses ``T.make_tensor`` runtime-bound + # base addresses (ptr-backed grouped GEMM): promoting their copies + # to TMA would lift descriptor creation past the LetStmt that + # defines the base ``Var``, breaking ``MakePackedAPI``. mod = tilelang.transform.IfConditionExtract()(mod) mod = tilelang.transform.AutoSchedule(False)(mod) mod = tilelang.transform.Simplify()(mod) From 3490f3ac36b469fc8b24eb525d51106054249a12 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 6 May 2026 17:03:09 +0800 Subject: [PATCH 151/156] add constraints for warpgroup partition & always analyze and insert barriers --- src/transform/auto_schedule.cc | 92 ++++++++----------- .../auto_schedule/schedule_builder.cc | 48 +++++++++- 2 files changed, 86 insertions(+), 54 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index 78c2b13866..d547649028 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -692,7 +692,6 @@ struct ScheduledKernelResult { std::vector buffer_infos; std::vector duplicated_fragment_buffers; PrimExpr updated_thread_extent; - bool did_warpgroup_partition{false}; }; // Schedule a single kernel body (the logic previously inlined in AutoSchedule). @@ -740,14 +739,6 @@ ScheduleSingleKernel(const Stmt &kernel_body, IterVar thread_var, Target target, thread_count = unit_builder.Build(ir_structure); } - if (!config.enable_warpgroup_partition) { - result.scheduled_body = - ConvertIRStructureToStmt(ir_structure.get(), enable_epi); - result.scheduled_body = StripUnusedLetStmts(result.scheduled_body); - result.did_warpgroup_partition = false; - return result; - } - // Print the modified summary view // PrintIRStructure(ir_structure.get()); @@ -771,7 +762,6 @@ ScheduleSingleKernel(const Stmt &kernel_body, IterVar thread_var, Target target, ir_structure.get(), thread_var, result.barrier_buffers, result.barrier_map, enable_epi, thread_count, config, neutral_sync_shared_barrier, result.duplicated_fragment_buffers); - result.did_warpgroup_partition = true; return result; } @@ -868,27 +858,25 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { TilelangRootBodyReplacer replacer(kr.scheduled_body); final_body = replacer(func->body); - if (kr.did_warpgroup_partition) { - // Apply thread extent update if warpgroup partition was applied - // (sm_90 only) - if (config.enable_thread_extend) { - ThreadExtentUpdater extent_updater(kr.updated_thread_extent); - final_body = extent_updater(final_body); - } - // Add barrier buffers to tilelang_root block's alloc_buffers - if (!kr.barrier_buffers.empty() || - !kr.duplicated_fragment_buffers.empty()) { - std::vector all_alloc_buffers = kr.barrier_buffers; - all_alloc_buffers.insert(all_alloc_buffers.end(), - kr.duplicated_fragment_buffers.begin(), - kr.duplicated_fragment_buffers.end()); - final_body = AddBarrierBuffersToRoot(final_body, all_alloc_buffers, - kr.barrier_map); - } - // Apply multi-version alloc_buffer rewrite if needed - if (!kr.buffer_infos.empty()) { - final_body = RewriteAllocBuffers(final_body, kr.buffer_infos); - } + // Apply thread extent update if warpgroup partition was applied + // (sm_90 only) + if (config.enable_thread_extend) { + ThreadExtentUpdater extent_updater(kr.updated_thread_extent); + final_body = extent_updater(final_body); + } + // Add barrier buffers to tilelang_root block's alloc_buffers + if (!kr.barrier_buffers.empty() || + !kr.duplicated_fragment_buffers.empty()) { + std::vector all_alloc_buffers = kr.barrier_buffers; + all_alloc_buffers.insert(all_alloc_buffers.end(), + kr.duplicated_fragment_buffers.begin(), + kr.duplicated_fragment_buffers.end()); + final_body = AddBarrierBuffersToRoot(final_body, all_alloc_buffers, + kr.barrier_map); + } + // Apply multi-version alloc_buffer rewrite if needed + if (!kr.buffer_infos.empty()) { + final_body = RewriteAllocBuffers(final_body, kr.buffer_infos); } final_body = ReNestLetStmts(final_body); @@ -952,28 +940,26 @@ tvm::transform::Pass AutoSchedule(const bool enable_epi) { scheduled_subtree = replacer(kernel_subtree); } - if (kr.did_warpgroup_partition) { - // Apply thread extent update if warpgroup partition was applied - // (sm_90 only) - if (config.enable_thread_extend) { - ThreadExtentUpdater extent_updater(kr.updated_thread_extent); - scheduled_subtree = extent_updater(scheduled_subtree); - } - // Add barrier buffers to this kernel's tilelang_root block - if (!kr.barrier_buffers.empty() || - !kr.duplicated_fragment_buffers.empty()) { - std::vector all_alloc_buffers = kr.barrier_buffers; - all_alloc_buffers.insert(all_alloc_buffers.end(), - kr.duplicated_fragment_buffers.begin(), - kr.duplicated_fragment_buffers.end()); - scheduled_subtree = AddBarrierBuffersToRoot( - scheduled_subtree, all_alloc_buffers, kr.barrier_map); - } - // Apply multi-version alloc_buffer rewrite if needed - if (!kr.buffer_infos.empty()) { - scheduled_subtree = - RewriteAllocBuffers(scheduled_subtree, kr.buffer_infos); - } + // Apply thread extent update if warpgroup partition was applied + // (sm_90 only) + if (config.enable_thread_extend) { + ThreadExtentUpdater extent_updater(kr.updated_thread_extent); + scheduled_subtree = extent_updater(scheduled_subtree); + } + // Add barrier buffers to this kernel's tilelang_root block + if (!kr.barrier_buffers.empty() || + !kr.duplicated_fragment_buffers.empty()) { + std::vector all_alloc_buffers = kr.barrier_buffers; + all_alloc_buffers.insert(all_alloc_buffers.end(), + kr.duplicated_fragment_buffers.begin(), + kr.duplicated_fragment_buffers.end()); + scheduled_subtree = AddBarrierBuffersToRoot( + scheduled_subtree, all_alloc_buffers, kr.barrier_map); + } + // Apply multi-version alloc_buffer rewrite if needed + if (!kr.buffer_infos.empty()) { + scheduled_subtree = + RewriteAllocBuffers(scheduled_subtree, kr.buffer_infos); } // Insert shared memory boundary between kernel segments diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 403841c72f..40ebb1be18 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -623,6 +623,29 @@ AssignWarpgroupIdsGlobal(IRStructure *root, const WarpSpecializeConfig &config, LOG(FATAL) << "No task"; } + bool enable_partition = config.enable_warpgroup_partition; + if (auto thread_count_num = as_const_int(thread_count)) { + if (config.enable_warp_partition) { + enable_partition &= (*thread_count_num >= 64); + } else { + enable_partition &= (*thread_count_num % 32 == 0); + } + } else { + enable_partition = false; + } + + if (!enable_partition) { + for (auto &task_ctx : all_tasks) { + TaskNode *task = task_ctx.task; + if (task->ContainsLoopBreak()) { + task->SetWarpgroupId(kWarpgroupBroadcast); + } else { + task->SetWarpgroupId(0); + } + } + return {thread_count}; + } + int n = all_tasks.size(); for (auto &task_ctx : all_tasks) { @@ -1010,6 +1033,29 @@ NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, if (all_tasks.empty()) LOG(FATAL) << "No task"; + bool enable_partition = config.enable_warpgroup_partition; + if (auto thread_count_num = as_const_int(thread_count)) { + if (config.enable_warp_partition) { + enable_partition &= (*thread_count_num >= 64); + } else { + enable_partition &= (*thread_count_num % 32 == 0); + } + } else { + enable_partition = false; + } + + if (!enable_partition) { + for (auto &task_ctx : all_tasks) { + TaskNode *task = task_ctx.task; + if (task->ContainsLoopBreak()) { + task->SetWarpgroupId(kWarpgroupBroadcast); + } else { + task->SetWarpgroupId(0); + } + } + return {thread_count}; + } + // Simple producer/consumer assignment: // TMA tasks → wg1 (producer), compute tasks → wg0 (consumer) for (auto &task_ctx : all_tasks) { @@ -1025,7 +1071,7 @@ NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, } } - if (config.producer_thread_count == 32) { + if (config.enable_warp_partition) { // Collect prefix/suffix tasks and reset them to neutral std::unordered_set prefix_tasks; CollectPrefixTasks(root, prefix_tasks); From 2903cd1f89fd7961d5af792c404667271aade382 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Wed, 6 May 2026 17:03:57 +0800 Subject: [PATCH 152/156] assign tma store to consumer side --- src/transform/auto_schedule/schedule_builder.cc | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/transform/auto_schedule/schedule_builder.cc b/src/transform/auto_schedule/schedule_builder.cc index 40ebb1be18..fb11d97331 100644 --- a/src/transform/auto_schedule/schedule_builder.cc +++ b/src/transform/auto_schedule/schedule_builder.cc @@ -1064,7 +1064,7 @@ NaiveAssignWarpgroupIds(IRStructure *root, const WarpSpecializeConfig &config, task->SetWarpgroupId(kWarpgroupBroadcast); continue; } - if (task->UsesTMACore() && !task->UsesTensorCore()) { + if (task->HasTMALoad()) { task->SetWarpgroupId(1); // producer } else { task->SetWarpgroupId(0); // consumer @@ -1134,20 +1134,17 @@ void ScheduleUnitBuilder::NaiveScheduleLoop(ControlNode *ctrl) { // Assign pipeline stages and start times: // - TMA load → stage 0, start_time = 0 - // - Everything else → stage (num_stages - 1), start_time = num_stages + // - Everything else → stage (num_stages), start_time = num_stages // - All task latencies set to 0, IIperIter = 1 std::map stage_map; bool has_promoted = false; for (auto &child : seq_body->children) { IRStructure *node = child.get(); bool is_tma_load = - node->UsesTMACore() && !node->UsesTensorCore() && !node->UsesCUDACore(); - if (is_tma_load && node->IsTask()) { - is_tma_load = static_cast(node)->HasTMALoad(); - } + node->IsTask() && static_cast(node)->HasTMALoad(); int stage = !is_tma_load ? 0 : (num_stages); stage_map[node] = stage; - if (stage != num_stages) { + if (stage != 0) { has_promoted = true; } node->SetStartTime(is_tma_load ? 0 : num_stages); From 28aa5ad1f07da738945ad3547df146c76a872509 Mon Sep 17 00:00:00 2001 From: Denver Jin Date: Thu, 7 May 2026 17:30:43 +0800 Subject: [PATCH 153/156] Fix let & barrier bugs --- .../auto_schedule/warpgroup_partition.cc | 229 ++++++++++++++++-- 1 file changed, 204 insertions(+), 25 deletions(-) diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 9eb09b7ce4..01f874cc5f 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -800,19 +800,6 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, for (auto &child : seq->children) { auto unit = static_cast(child.get()); - std::vector stmts; - for (const auto &[_, before] : unit->before) { - for (const auto &stmt : before) { - stmts.push_back(stmt); - } - } - stmts.push_back( - ConvertIRStructureToStmt(unit->child.get(), outer_enable_epi)); - for (const auto &[_, after] : unit->after) { - for (const auto &stmt : after) { - stmts.push_back(stmt); - } - } Map substitution, substitution_cond; substitution.Set(loop_var, loop_var - loop_step * (max_stages - unit->stage)); @@ -820,17 +807,45 @@ Stmt ConvertIRStructureToStmt(IRStructure *structure, loop_var, Max(loop_start, Min(loop_start + loop_extent - loop_step, loop_var - loop_step * (max_stages - unit->stage)))); + PrimExpr condition = + And(loop_var < loop_start + loop_extent, loop_var >= loop_start); + if (unit->stage == min_stages) { + condition = loop_var >= loop_start; + } + if (unit->stage == max_stages) { + condition = loop_var < loop_start + loop_extent; + } if (IsLetDeclNode(unit->child.get())) { - Stmt stmt = SeqStmt::Flatten(stmts); - steady.push_back(Substitute(stmt, substitution_cond)); + std::vector guarded; + for (const auto &[_, before] : unit->before) { + for (const auto &stmt : before) { + guarded.push_back( + Substitute(IfThenElse(condition, stmt), substitution)); + } + } + guarded.push_back(Substitute( + ConvertIRStructureToStmt(unit->child.get(), outer_enable_epi), + substitution_cond)); + for (const auto &[_, after] : unit->after) { + for (const auto &stmt : after) { + guarded.push_back( + Substitute(IfThenElse(condition, stmt), substitution)); + } + } + steady.push_back(SeqStmt::Flatten(guarded)); } else { - PrimExpr condition = - And(loop_var < loop_start + loop_extent, loop_var >= loop_start); - if (unit->stage == min_stages) { - condition = loop_var >= loop_start; + std::vector stmts; + for (const auto &[_, before] : unit->before) { + for (const auto &stmt : before) { + stmts.push_back(stmt); + } } - if (unit->stage == max_stages) { - condition = loop_var < loop_start + loop_extent; + stmts.push_back( + ConvertIRStructureToStmt(unit->child.get(), outer_enable_epi)); + for (const auto &[_, after] : unit->after) { + for (const auto &stmt : after) { + stmts.push_back(stmt); + } } Stmt stmt = IfThenElse(condition, SeqStmt::Flatten(stmts)); steady.push_back(Substitute(stmt, substitution)); @@ -1114,10 +1129,11 @@ Stmt ApplyWarpgroupPartitionToIRStructure( } // --- Per-child construction --- - // Walk root SequenceNode's children. LetDecl children accumulate as - // (var, value, before_stmts, after_stmts) tuples so that the cloned - // before/after barriers on their ScheduleUnit are preserved. When wrapping - // a subsequent non-LetDecl child, each accumulated tuple is re-emitted as + // Walk root SequenceNode's children. LetDecl children are normally + // accumulated as (var, value, before_stmts, after_stmts) tuples so that the + // cloned before/after barriers on their ScheduleUnit are preserved. When + // wrapping a subsequent non-LetDecl child, each accumulated tuple is + // re-emitted as // ; let var = value in ( ; body) // so the barrier pair brackets the let binding while `var` stays in scope // for the rest of the segment. @@ -1127,6 +1143,86 @@ Stmt ApplyWarpgroupPartitionToIRStructure( auto root_seq = static_cast(root); size_t num_children = root_seq->children.size(); + std::vector> child_read_bufs( + num_children); + std::vector> child_write_bufs( + num_children); + std::vector> child_read_vars( + num_children); + std::vector> child_write_vars( + num_children); + for (size_t ci = 0; ci < num_children; ++ci) { + IRStructure *c = root_seq->children[ci].get(); + if (!c) + continue; + for (const auto &r : c->GetReadRegions()) { + child_read_bufs[ci].insert(r->buffer.get()); + } + for (const auto &r : c->GetWriteRegions()) { + child_write_bufs[ci].insert(r->buffer.get()); + } + for (const auto &v : c->GetReadVars()) { + child_read_vars[ci].insert(v.get()); + } + for (const auto &v : c->GetWriteVars()) { + child_write_vars[ci].insert(v.get()); + } + } + + std::vector anchor_end(num_children, -1); + for (size_t i = 0; i < num_children; ++i) { + auto unit_i = static_cast(root_seq->children[i].get()); + if (!unit_i || !IsLetDeclNode(unit_i->child.get())) + continue; + const auto &rb = child_read_bufs[i]; + const auto &rv = child_read_vars[i]; + int j_max = -1; + for (size_t j = i + 1; j < num_children; ++j) { + bool conflict = false; + for (const auto *wb : child_write_bufs[j]) { + if (rb.count(wb)) { + conflict = true; + break; + } + } + if (!conflict) { + for (const auto *wv : child_write_vars[j]) { + if (rv.count(wv)) { + conflict = true; + break; + } + } + } + if (conflict) + j_max = static_cast(j); + } + anchor_end[i] = j_max; + } + + std::vector cluster_end(num_children, -1); + { + size_t i = 0; + while (i < num_children) { + if (anchor_end[i] < 0) { + ++i; + continue; + } + int e = anchor_end[i]; + bool extended = true; + while (extended) { + extended = false; + for (int k = static_cast(i) + 1; k <= e; ++k) { + if (anchor_end[k] > e) { + e = anchor_end[k]; + extended = true; + } + } + } + cluster_end[i] = e; + i = static_cast(e) + 1; + } + } + struct AccumulatedLet { Var var; PrimExpr value; @@ -1144,6 +1240,89 @@ Stmt ApplyWarpgroupPartitionToIRStructure( auto unit = static_cast(root_seq->children[ci].get()); bool is_let_decl = IsLetDeclNode(unit->child.get()); + if (cluster_end[ci] >= 0) { + size_t end = static_cast(cluster_end[ci]); + + bool cluster_contains_loop = false; + for (size_t cj = ci; cj <= end; ++cj) { + auto u = static_cast(root_seq->children[cj].get()); + if (u && u->child && u->child->IsControl()) { + cluster_contains_loop = true; + break; + } + } + + std::vector> wg_stmt_seq(num_wgs); + for (size_t i = 0; i < num_wgs; ++i) { + for (size_t cj = ci; cj <= end; ++cj) { + if (!wg_children[i][cj]) + continue; + auto tmp_seq = std::make_shared(); + tmp_seq->children.push_back(wg_children[i][cj]); + Stmt s = ConvertIRStructureToStmt(tmp_seq.get(), outer_enable_epi); + if (!IsEvaluateZero(s)) { + wg_stmt_seq[i].push_back(s); + } + } + } + + std::vector wg_stmts(num_wgs); + bool all_empty = true; + for (size_t i = 0; i < num_wgs; ++i) { + if (wg_stmt_seq[i].empty()) { + wg_stmts[i] = Evaluate(0); + } else { + wg_stmts[i] = SeqStmt::Flatten(wg_stmt_seq[i]); + all_empty = false; + } + for (int j = static_cast(wg_accumulated_lets[i].size()) - 1; + j >= 0; --j) { + const AccumulatedLet &acc = wg_accumulated_lets[i][j]; + Stmt body = wg_stmts[i]; + if (!acc.after.empty()) { + std::vector tmp = acc.after; + if (!IsEvaluateZero(body)) { + tmp.push_back(body); + } + body = SeqStmt::Flatten(tmp); + } + Stmt let_stmt = LetStmt(acc.var, acc.value, body); + if (!acc.before.empty()) { + std::vector tmp = acc.before; + tmp.push_back(let_stmt); + wg_stmts[i] = SeqStmt::Flatten(tmp); + } else { + wg_stmts[i] = let_stmt; + } + } + } + + if (!all_empty) { + if (prev_was_loop || cluster_contains_loop) { + segmented_stmts.push_back( + AttrStmt(Integer(0), attr::kAutoScheduleSharedMemoryBoundary, 0, + Evaluate(0))); + } + if (first_non_let && !has_simt_copy && !has_inner_nreg_decision && + num_wgs == 2 && config.enable_set_max_nreg) { + for (size_t i = 0; i < num_wgs; ++i) { + wg_stmts[i] = SeqStmt( + {Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), + {i == 0 ? config.consumer_max_nreg + : config.producer_max_nreg, + static_cast(!i)})), + wg_stmts[i]}); + } + } + first_non_let = false; + segmented_stmts.push_back(MakeWarpgroupIf(wg_stmts)); + prev_was_loop = cluster_contains_loop; + } + + ci = end; + continue; + } + if (is_let_decl) { // Extract LetDecl {var, value, before, after} from each wg's filtered // result. The surrounding before/after live on the cloned From 10503bac79ff9c3e85de6fbce46c7536172bd100 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 7 May 2026 17:50:04 +0800 Subject: [PATCH 154/156] add top-level barriers --- src/transform/auto_schedule.cc | 2 +- src/transform/auto_schedule/barrier.h | 114 +++++++++++++++++++------- 2 files changed, 86 insertions(+), 30 deletions(-) diff --git a/src/transform/auto_schedule.cc b/src/transform/auto_schedule.cc index d547649028..de7f386aee 100644 --- a/src/transform/auto_schedule.cc +++ b/src/transform/auto_schedule.cc @@ -755,7 +755,7 @@ ScheduleSingleKernel(const Stmt &kernel_body, IterVar thread_var, Target target, AnalyzeAndInsertBarriers(ir_structure.get(), next_barrier_id, result.barrier_buffers, result.barrier_map, thread_count, loop_info, result.buffer_infos, - neutral_sync_shared_barrier); + neutral_sync_shared_barrier, /*is_root=*/true); // Apply warpgroup partition to entire IRStructure result.scheduled_body = ApplyWarpgroupPartitionToIRStructure( diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index e4dcb55e6f..13b4f6ddfd 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -148,30 +148,27 @@ struct MultiVersionBufferInfo { }; // Barrier dependency analysis function declarations -static void -AnalyzeAndInsertBarriers(IRStructure *node, int &next_barrier_id, - std::vector &barrier_buffers, - Map &barrier_map, - const std::vector &thread_count, - LoopNestingInfo &loop_info, - std::vector &buffer_infos, - Buffer neutral_sync_shared_barrier); -static void -AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, - std::vector &barrier_buffers, - Map &barrier_map, - const std::vector &thread_count, - LoopNestingInfo &loop_info, - std::vector &buffer_infos, - Buffer neutral_sync_shared_barrier); -static void -AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, - std::vector &barrier_buffers, - Map &barrier_map, - const std::vector &thread_count, - LoopNestingInfo &loop_info, - std::vector &buffer_infos, - Buffer neutral_sync_shared_barrier); +static void AnalyzeAndInsertBarriers( + IRStructure *node, int &next_barrier_id, + std::vector &barrier_buffers, + Map &barrier_map, + const std::vector &thread_count, LoopNestingInfo &loop_info, + std::vector &buffer_infos, + Buffer neutral_sync_shared_barrier, bool is_root = false); +static void AnalyzeSequenceNodeBarriers( + SequenceNode *seq, int &next_barrier_id, + std::vector &barrier_buffers, + Map &barrier_map, + const std::vector &thread_count, LoopNestingInfo &loop_info, + std::vector &buffer_infos, + Buffer neutral_sync_shared_barrier, bool is_root = false); +static void AnalyzeControlNodeBarriers( + ControlNode *ctrl, int &next_barrier_id, + std::vector &barrier_buffers, + Map &barrier_map, + const std::vector &thread_count, LoopNestingInfo &loop_info, + std::vector &buffer_infos, + Buffer neutral_sync_shared_barrier, bool is_root = false); // Create a barrier_arrive statement for the given barrier expression // Equivalent to T.barrier_arrive(barrier_expr) in Python @@ -535,7 +532,7 @@ AnalyzeAndInsertBarriers(IRStructure *node, int &next_barrier_id, const std::vector &thread_count, LoopNestingInfo &loop_info, std::vector &buffer_infos, - Buffer neutral_sync_shared_barrier) { + Buffer neutral_sync_shared_barrier, bool is_root) { if (!node) return; @@ -543,12 +540,12 @@ AnalyzeAndInsertBarriers(IRStructure *node, int &next_barrier_id, AnalyzeSequenceNodeBarriers(static_cast(node), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, - neutral_sync_shared_barrier); + neutral_sync_shared_barrier, is_root); } else if (node->IsControl()) { AnalyzeControlNodeBarriers(static_cast(node), next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info, buffer_infos, - neutral_sync_shared_barrier); + neutral_sync_shared_barrier, is_root); } else if (node->IsWrapper()) { auto wrapper = static_cast(node); AnalyzeAndInsertBarriers( @@ -1007,7 +1004,7 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, const std::vector &thread_count, LoopNestingInfo &loop_info, std::vector &buffer_infos, - Buffer neutral_sync_shared_barrier) { + Buffer neutral_sync_shared_barrier, bool is_root) { if (!seq) return; @@ -1046,6 +1043,65 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, auto sync_infos = GetSyncInfos(units, thread_count.size()); InsertSynchronization(units, sync_infos, next_barrier_id, barrier_buffers, barrier_map, thread_count, loop_info); + + // For the root, since we will insert kAutoScheduleSharedMemoryBoundary before + // and after for-loop segments, we + // naively add barriers at these positions to ensure synchronization. + if (is_root) { + int num_wgs = thread_count.size(); + for (const auto &unit : units) { + if (!unit->child->IsControl()) + continue; + { + std::vector barrier_buffer(num_wgs); + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + int barrier_id = next_barrier_id++; + barrier_buffer[wg_id] = makeBarrierBuffer( + thread_count[wg_id], "root_barrier_" + std::to_string(barrier_id), + 1, barrier_buffers, barrier_map); + } + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + for (int other_wg_id = 0; other_wg_id < num_wgs; ++other_wg_id) { + if (wg_id == other_wg_id) + continue; + PrimExpr mbar_expr = BufferLoad(barrier_buffer[wg_id], {0}); + PrimExpr parity_expr = IntImm(DataType::Int(32), 0); + Stmt wait_stmt = makeBarrierWait(mbar_expr, parity_expr); + InsertStatementIntoScheduleUnit(unit, wait_stmt, true, other_wg_id); + } + } + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + PrimExpr mbar_expr = BufferLoad(barrier_buffer[wg_id], {0}); + Stmt arrive_stmt = makeBarrierArrive(mbar_expr); + InsertStatementIntoScheduleUnit(unit, arrive_stmt, true, wg_id); + } + } + { + std::vector barrier_buffer(num_wgs); + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + int barrier_id = next_barrier_id++; + barrier_buffer[wg_id] = makeBarrierBuffer( + thread_count[wg_id], "root_barrier_" + std::to_string(barrier_id), + 1, barrier_buffers, barrier_map); + } + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + PrimExpr mbar_expr = BufferLoad(barrier_buffer[wg_id], {0}); + Stmt arrive_stmt = makeBarrierArrive(mbar_expr); + InsertStatementIntoScheduleUnit(unit, arrive_stmt, false, wg_id); + } + for (int wg_id = 0; wg_id < num_wgs; ++wg_id) { + for (int other_wg_id = 0; other_wg_id < num_wgs; ++other_wg_id) { + if (wg_id == other_wg_id) + continue; + PrimExpr mbar_expr = BufferLoad(barrier_buffer[wg_id], {0}); + PrimExpr parity_expr = IntImm(DataType::Int(32), 0); + Stmt wait_stmt = makeBarrierWait(mbar_expr, parity_expr); + InsertStatementIntoScheduleUnit(unit, wait_stmt, false, other_wg_id); + } + } + } + } + } } static void @@ -1055,7 +1111,7 @@ AnalyzeControlNodeBarriers(ControlNode *ctrl, int &next_barrier_id, const std::vector &thread_count, LoopNestingInfo &loop_info, std::vector &buffer_infos, - Buffer neutral_sync_shared_barrier) { + Buffer neutral_sync_shared_barrier, bool is_root) { if (!ctrl || !ctrl->child) return; From f62744a950cd0c1556c925bf6c14ee24ae6dfa8b Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Thu, 7 May 2026 17:51:32 +0800 Subject: [PATCH 155/156] format --- src/transform/auto_schedule/barrier.h | 3 ++- src/transform/auto_schedule/warpgroup_partition.cc | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transform/auto_schedule/barrier.h b/src/transform/auto_schedule/barrier.h index 13b4f6ddfd..bcef1f25ea 100644 --- a/src/transform/auto_schedule/barrier.h +++ b/src/transform/auto_schedule/barrier.h @@ -1096,7 +1096,8 @@ AnalyzeSequenceNodeBarriers(SequenceNode *seq, int &next_barrier_id, PrimExpr mbar_expr = BufferLoad(barrier_buffer[wg_id], {0}); PrimExpr parity_expr = IntImm(DataType::Int(32), 0); Stmt wait_stmt = makeBarrierWait(mbar_expr, parity_expr); - InsertStatementIntoScheduleUnit(unit, wait_stmt, false, other_wg_id); + InsertStatementIntoScheduleUnit(unit, wait_stmt, false, + other_wg_id); } } } diff --git a/src/transform/auto_schedule/warpgroup_partition.cc b/src/transform/auto_schedule/warpgroup_partition.cc index 01f874cc5f..4d67c89d10 100644 --- a/src/transform/auto_schedule/warpgroup_partition.cc +++ b/src/transform/auto_schedule/warpgroup_partition.cc @@ -1306,12 +1306,12 @@ Stmt ApplyWarpgroupPartitionToIRStructure( if (first_non_let && !has_simt_copy && !has_inner_nreg_decision && num_wgs == 2 && config.enable_set_max_nreg) { for (size_t i = 0; i < num_wgs; ++i) { - wg_stmts[i] = SeqStmt( - {Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), - {i == 0 ? config.consumer_max_nreg - : config.producer_max_nreg, - static_cast(!i)})), - wg_stmts[i]}); + wg_stmts[i] = + SeqStmt({Evaluate(Call(DataType::Handle(), tl::set_max_nreg(), + {i == 0 ? config.consumer_max_nreg + : config.producer_max_nreg, + static_cast(!i)})), + wg_stmts[i]}); } } first_non_let = false; From 4f8823bb6ec13997740cf052e0700f53d7b14527 Mon Sep 17 00:00:00 2001 From: AutumnKite Date: Fri, 8 May 2026 13:59:08 +0800 Subject: [PATCH 156/156] fix empty for bug --- src/transform/auto_schedule/ir_structure.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transform/auto_schedule/ir_structure.h b/src/transform/auto_schedule/ir_structure.h index f948ba1686..b3708a932e 100644 --- a/src/transform/auto_schedule/ir_structure.h +++ b/src/transform/auto_schedule/ir_structure.h @@ -1042,6 +1042,9 @@ class ScheduleUnit : public IRStructure { std::shared_ptr Clone() const override; bool containWarpgroupId(int id) const override { + if (before.count(id) > 0 || after.count(id) > 0) { + return true; + } return child && child->containWarpgroupId(id); }