diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md
new file mode 100644
index 0000000000..2b7bf87f7e
--- /dev/null
+++ b/docs/programming_guides/cluster_tma.md
@@ -0,0 +1,353 @@
+# Cluster TMA: Multicast and SM-to-SM Copy
+
+This page describes two advanced data-movement features that are available on
+NVIDIA Hopper (SM90) and later: **TMA multicast** and **SM-to-SM cluster
+copy**. Both features are exposed through extensions to the existing `T.copy`
+operator and require a kernel launched with a thread block cluster, i.e., with
+`cluster_dims != (1, 1, 1)`.
+
+Requirements:
+
+- CUDA Compute Capability ≥ 9.0 (Hopper / Blackwell / RTX 5090)
+
+---
+
+## Background: Thread Block Clusters
+
+A *thread block cluster* is a group of CTAs that share a common virtual address
+space for their shared-memory regions and can communicate without going through
+global memory. Within a cluster, each CTA has a *block rank* (its 0-indexed
+position inside the cluster), and all CTAs can observe each other's shared
+memory via the `shared::cluster` address space.
+
+```python
+with T.Kernel(grid_x, grid_y, threads=128, cluster_dims=(4, 1, 1)) as (bx, by):
+    rank = T.block_rank_in_cluster()  # 0..3 within this cluster
+    T.cluster_sync()                  # barrier across all CTAs in cluster
+```
+
+---
+
+## Feature 1 — TMA Multicast (`cluster_mask`)
+
+### What it does
+
+Normally each CTA issues its own TMA load, fetching a tile from global memory
+into its private shared memory. With multicast, **a single TMA transaction
+broadcasts one global tile to every participating CTA simultaneously**, saving
+repeated DRAM traffic when multiple CTAs in a cluster need the same data (e.g.,
+the same K-panel in a split-K GEMM).
+
+```text
+Global memory ──TMA multicast──▶ shared memory (rank 0)
+                             └─▶ shared memory (rank 1)  (same tile, no extra DRAM read)
+              ──TMA load─────▶ shared memory (rank 2)   (independent tile)
+              ──TMA load─────▶ shared memory (rank 3)   (independent tile)
+```
+
+### API
+
+```python
+T.copy_cluster(src_global, dst_shared, cluster_mask=mask)
+```
+
+`cluster_mask` is a bitmask where each set bit identifies a CTA rank that
+participates in the multicast. The CTA whose rank is the index of the lowest
+set bit in the mask issues `cp.async.bulk.tensor … multicast::cluster`; every
+other CTA in the mask receives the data passively (no instruction issued).
+CTAs outside the mask perform a regular TMA load for their own tile.
+
+### Example
+
+```python
+import tilelang
+import tilelang.language as T
+
+def make_tma_multicast_kernel(M, N, block_M, block_N, cluster_mask):
+
+    @T.prim_func
+    def kernel(
+        A: T.Tensor((M, N), "float16"),
+        B: T.Tensor((M, N), "float16"),
+    ):
+        # 4 CTAs per cluster; ranks 0 and 1 share the same tile via multicast.
+        with T.Kernel(
+            T.ceildiv(N, block_N),
+            T.ceildiv(M, block_M),
+            threads=128,
+            cluster_dims=(4, 1, 1),
+        ) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_N), "float16")
+
+            # cluster_mask=0b0011: ranks 0 and 1 participate.
+            # Rank 0 issues tma_load_multicast; rank 1 receives passively.
+            # Ranks 2 and 3 each issue a regular tma_load.
+            T.copy_cluster(A[by * block_M, bx * block_N], A_shared,
+                           cluster_mask=cluster_mask)
+
+            T.copy(A_shared, B[by * block_M, bx * block_N])
+
+    return kernel
+```
+
+Running the kernel above with `cluster_mask = 0b0011`:
+
+| Rank | Action | `B` slice receives |
+|------|--------|--------------------|
+| 0 | issues multicast load | A tile at rank-0 address |
+| 1 | passively receives | **same** A tile as rank 0 |
+| 2 | regular TMA load | A tile at rank-2 address |
+| 3 | regular TMA load | A tile at rank-3 address |
+
+### Notes
+
+- The compiler lowers `cluster_mask != 0` to
+  `cp.async.bulk.tensor.Nd.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster`
+  for the issuing CTA; CTAs in the mask but not elected as issuer receive
+  passively, and only CTAs outside the mask issue a standard
+  `cp.async.bulk.tensor`.
+- Software pipelining (`T.Pipelined`) is fully supported; the warp-specialized
+  rewriter recognises `tma_load_multicast` as a producer operation.
+- `cluster_mask` is a compile-time constant; dynamic masks are not supported.
+
+---
+
+## Feature 2 — SM-to-SM Cluster Copy (`dst_block`)
+
+### What it does
+
+SM-to-SM copy lets one CTA **push data directly from its own shared memory
+into another CTA's shared memory** within the same cluster, without a round
+trip through global memory. This is useful for patterns such as:
+
+- Partial-result exchange (e.g., split-K partial sums across SM boundaries)
+- Producer–consumer pipelines where the producer fills a neighbor's buffer
+- All-to-all collective communication within a cluster
+
+### Lowering paths
+
+The compiler selects one of three paths depending on whether `remote_barrier`
+is provided and whether the copy region is contiguous:
+
+| Path | Condition | Hardware instruction | Arrive count |
+|------|-----------|---------------------|--------------|
+| **TMA fast path** | `remote_barrier` set + region is contiguous | one `tl::tma_store_cluster` | 1 |
+| **Multi-TMA path** | `remote_barrier` set + ND region is non-contiguous | one `tl::tma_store_cluster` per contiguous row | number of rows |
+| **SIMT fallback** | no `remote_barrier`, or non-decomposable region | `map_shared_rank` scalar stores by all threads | auto-injected arrive if `remote_barrier` is set |
+
+A copy region is *contiguous* when its innermost dimension spans the full
+buffer width (i.e. the copy region `[..., 0:N_tile]` satisfies
+`N_tile == buffer_shape[-1]`). If the innermost extent is shorter, the region
+is non-contiguous and the TMA fast path is unavailable. A short sketch of this
+rule follows the steps below.
+
+### TMA fast path — bulk async copy with mbarrier
+
+```python
+T.copy_cluster(src_shared, dst_shared, dst_block=dst_rank, remote_barrier=bar)
+```
+
+A single elected thread issues one `cp.async.bulk.shared::cluster` instruction.
+The hardware DMA engine transfers the entire tile asynchronously and signals
+the destination CTA's mbarrier on completion. The destination CTA waits with
+`T.mbarrier_wait_parity`.
+
+Steps:
+1. Both CTAs allocate the **same** shared memory layout so their mbarriers live
+   at the same offset.
+2. Every CTA initialises its own barrier for 1 arrival via `T.alloc_cluster_barrier([1])`.
+3. The source CTA (`pid == 0` below) calls `T.copy_cluster(..., dst_block=1, remote_barrier=...)`.
+4. The destination CTA (`pid == 1`) waits on its local barrier copy.
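+
+To make the contiguity rule concrete, here is a minimal sketch (hypothetical
+shapes; the 2-CTA cluster context, `dst_block`, and barrier setup mirror the
+full example below):
+
+```python
+# Sketch only — assumes both CTAs allocate s_src/s_dst identically and
+# bar = T.alloc_cluster_barrier([1]), as in the example that follows.
+s_src = T.alloc_shared((8, 64), "float32")  # row-major, inner dim = 64
+s_dst = T.alloc_shared((8, 64), "float32")
+bar = T.alloc_cluster_barrier([1])
+
+# Contiguous: inner extent spans the full buffer width (64 == 64),
+# so this lowers to a single tl::tma_store_cluster (TMA fast path).
+T.copy_cluster(s_src[0:8, 0:64], s_dst[0:8, 0:64],
+               dst_block=1, remote_barrier=bar[0])
+
+# Non-contiguous: inner extent 32 < 64, so the region is decomposed
+# into 8 per-row TMA calls (multi-TMA path; arrive_count becomes 8).
+T.copy_cluster(s_src[0:8, 0:32], s_dst[0:8, 0:32],
+               dst_block=1, remote_barrier=bar[0])
+```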
+ +```python +import tilelang +import tilelang.language as T + +@tilelang.jit(execution_backend="cython") +def make_cluster_copy_kernel(N: int): + @T.prim_func + def kernel( + A: T.Tensor((N,), "float32"), + B: T.Tensor((N,), "float32"), + ): + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: + s_src = T.alloc_shared((N,), "float32") + s_dst = T.alloc_shared((N,), "float32") + s_barrier = T.alloc_cluster_barrier([1]) + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + + T.cluster_sync() + + if pid == 0: + # Load A into local shared memory. + for i in T.Parallel(N): + s_src[i] = A[i] + + # Async-push s_src → s_dst in CTA 1, signal CTA 1's barrier. + T.copy_cluster(s_src, s_dst, dst_block=1, + remote_barrier=s_barrier[0]) + + if pid == 1: + # Wait until CTA 0 finishes writing. + T.mbarrier_wait_parity(s_barrier[0], 0) + + for i in T.Parallel(N): + B[i] = s_dst[i] + + return kernel +``` + +Generated producer code (single-thread guard, one PTX instruction): + +```cuda +if (((int)threadIdx.x) == 0) { + tl::tma_store_cluster(&s_dst[0], &s_src[0], 1, + (uint32_t)(N * 4), s_barrier[0]); +} +``` + +### Multi-TMA path — non-contiguous ND regions + +When `remote_barrier` is provided but the copy region is not fully contiguous +(e.g. copying a 2-D slice `[0:M, 0:N_tile]` from a buffer of shape +`[M, N_full]` where `N_tile < N_full`), the compiler automatically +**decomposes the ND region into individual contiguous rows**, emitting one +`tl::tma_store_cluster` call per row. The mbarrier `arrive_count` is updated +to the total number of rows so the destination CTA's `mbarrier_wait_parity` +completes only after all rows are transferred. + +```python +# 2-D non-contiguous copy: N_tile < N_full → compiler emits M TMA calls +s_src = T.alloc_shared((M, N_full), "float32") +s_dst = T.alloc_shared((M, N_full), "float32") +s_barrier = T.alloc_cluster_barrier([1]) # arrive_count updated to M at compile time + +T.copy_cluster( + s_src[0:M, 0:N_tile], + s_dst[0:M, 0:N_tile], + dst_block=1, + remote_barrier=s_barrier[0], +) +``` + +The decomposition is recursive: a 3-D region `[0:D, 0:M, 0:N_tile]` (with +`N_tile < N_full`) produces `D × M` TMA calls and sets `arrive_count = D * M`. +Static extents are unrolled at compile time; symbolic extents emit TIR `For` +loops. + +The API is identical to the fast path — no change is required in user code. + +### SIMT fallback — element-by-element stores + +Omit `remote_barrier` to always use the SIMT fallback: + +```python +T.copy_cluster(s_src, s_dst, dst_block=1) +``` + +This lowers to a SIMT parallel loop where every thread writes one (or a few) +elements into the remote CTA's shared memory via +`cooperative_groups::map_shared_rank`. Because `map_shared_rank` returns a +scalar pointer, vectorised writes are not possible. Use this path only when an +mbarrier is unavailable or when the tile is too small to justify barrier +overhead. + +When `remote_barrier` is provided but the region is neither contiguous nor +decomposable into TMA rows, the compiler falls back to SIMT stores and +**auto-injects a barrier arrive** (`__syncthreads()` + single-thread +`s_barrier.arrive(1u)`) so the destination CTA can still wait on the same +mbarrier without any API change. 
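+
+Whichever of the three barrier-based variants the compiler selects, the
+destination CTA's code is unchanged. A sketch reusing the names from the
+examples above:
+
+```python
+# Destination side (pid == 1) — identical for the TMA fast path, the
+# multi-TMA path, and the SIMT fallback with an auto-injected arrive.
+if pid == 1:
+    # Completes after the single bulk transaction (fast path), after all
+    # decomposed rows (multi-TMA), or after the auto-injected arrive (SIMT).
+    T.mbarrier_wait_parity(s_barrier[0], 0)
+    for i in T.Parallel(N):
+        B[i] = s_dst[i]
+```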
+ +### Synchronisation contract + +| | TMA fast path | Multi-TMA path | SIMT fallback | +|-|---------------|----------------|---------------| +| Source CTA | no wait needed; copy is async | no wait needed | effectively sync after the loop | +| Destination CTA | `T.mbarrier_wait_parity(barrier, parity)` | `T.mbarrier_wait_parity(barrier, parity)` | `T.cluster_sync()` (no barrier), or `T.mbarrier_wait_parity` if auto-arrived | + +### Notes + +- All paths require `src` and `dst` to be in `shared` or `shared.dyn` scope. +- The mbarrier must be allocated with `T.alloc_cluster_barrier([arrive_count])`. + The compiler updates `arrive_count` automatically for the multi-TMA path. +- `T.cluster_sync()` after allocation but before the copy is required to ensure + all CTAs have reached the barrier-init point before any data is pushed. +- `dst_block` may be a compile-time integer or a runtime `tir.PrimExpr`. +- `cluster_mask` and `dst_block` are mutually exclusive in a single + `T.copy_cluster` call. + +--- + +## Cluster Helper Builtins + +| Builtin | Return | Description | +|---------|--------|-------------| +| `T.block_rank_in_cluster()` | `int32` | Block rank (0-indexed) within the cluster | +| `T.cluster_sync()` | — | Barrier synchronisation across all cluster CTAs (arrive + wait) | +| `T.cluster_arrive()` | — | Signal cluster barrier arrival (aligned) | +| `T.cluster_arrive_relaxed()` | — | Signal cluster barrier arrival (relaxed) | +| `T.cluster_wait()` | — | Wait for all cluster CTAs to arrive | +| `T.alloc_cluster_barrier([count])` | `Buffer` | Allocate and initialise an mbarrier for `count` arrivals | +| `T.mbarrier_arrive(bar)` | — | Signal one arrival on an mbarrier | +| `T.mbarrier_wait_parity(bar, parity)` | — | Wait until `bar` flips to `parity` | + +--- + +## Putting It Together: Split-K Sketch + +A common pattern combining both features: multicast the shared K-panel to +all cluster CTAs (saving DRAM bandwidth), then reduce partial sums with +SM-to-SM copy (saving global-memory round trips). + +```python +@T.prim_func +def split_k_gemm(A, B, C): + with T.Kernel(grid_x, grid_y, threads=256, cluster_dims=(4, 1, 1)) as (bx, by): + rank = T.block_rank_in_cluster() + A_s = T.alloc_shared((BM, BK), "float16") + B_s = T.alloc_shared((BK, BN), "float16") + C_f = T.alloc_fragment((BM, BN), "float32") + C_s = T.alloc_shared((BM, BN), "float32") + barrier = T.alloc_cluster_barrier([3]) + T.clear(C_f) + + # Phase 1: each CTA loads its K-slice; A is multicast to rank 0 and 1. + for ko in T.Pipelined(T.ceildiv(K, BK * 4), num_stages=3): + k_off = (rank + ko * 4) * BK + T.copy_cluster(A[by * BM, k_off], A_s, cluster_mask=0b0011) + T.copy(B[k_off, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) + + # Phase 2: push each rank's partial sums to rank 0 for accumulation. + # + # Use a per-rank staging slot so every non-zero rank writes to a + # distinct destination region — avoiding both a destination race and + # an arrival-count mismatch. Each CTA stores its own partial into + # C_parts[rank]; non-zero ranks then push that slot to the matching + # slot in rank 0's shared memory. + # + # Arrival count must equal the number of producers: cluster_size - 1. + C_parts = T.alloc_shared((4, BM, BN), "float32") # one slot per rank + T.copy(C_f, C_parts[rank]) + + T.cluster_sync() + + if rank != 0: + # Push this rank's slot to the *same* slot index in rank 0's + # C_parts — different offsets, so no destination race. 
+                T.copy_cluster(C_parts[rank], C_parts[rank],
+                               dst_block=0, remote_barrier=barrier[0])
+
+        if rank == 0:
+            T.mbarrier_wait_parity(barrier[0], 0)  # wakes after all 3 arrivals
+            # C_parts[0..3] in rank 0's smem now hold all four partial sums.
+            # accumulate and store ...
+            T.copy(C_parts[0], C[by * BM, bx * BN])
+```
+
+---
+
+## See Also
+
+- `testing/python/cuda/test_tma_multicast_demo.py` — multicast validation
+- `testing/python/cuda/test_tma_dsmem.py` — SM-to-SM copy validation (fast
+  path, multi-TMA, and SIMT fallback)
+- Programming Guides → Instructions — complete `T.copy` parameter reference
+- Programming Guides → Control Flow — `T.Pipelined` and warp-specialized
+  pipelines
diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc
index 832a8fba90..a98ebdd195 100644
--- a/src/backend/cuda/codegen/codegen_cuda.cc
+++ b/src/backend/cuda/codegen/codegen_cuda.cc
@@ -2240,6 +2240,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     ss << ");\n";
     this->PrintIndent();
     this->stream << ss.str();
+  } else if (op->op.same_as(tl::tma_load_multicast())) {
+    std::ostringstream ss;
+    ICHECK_GE(op->args.size(), 5)
+        << "tma_load_multicast requires at least 5 args";
+    auto eviction_policy =
+        this->eviction_policy_names_
+            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
+    if (eviction_policy != "EVICT_NORMAL") {
+      ss << "tl::tma_load_multicast<CacheHintSm90::" << eviction_policy
+         << ">(";
+    } else {
+      ss << "tl::tma_load_multicast(";
+    }
+    ss << this->PrintExpr(op->args[0]) << ", ";
+    ss << this->PrintExpr(op->args[1]) << ", ";
+    ss << this->PrintExpr(op->args[2]) << ", ";
+    ss << "(uint16_t)(" << this->PrintExpr(op->args[3]) << ")";
+    for (size_t i = 4; i < op->args.size() - 1; i++) {
+      ss << ", " << this->PrintExpr(op->args[i]);
+    }
+    ss << ");\n";
+    this->PrintIndent();
+    this->stream << ss.str();
   } else if (op->op.same_as(tl::tma_load_im2col())) {
     std::stringstream ss;
     auto eviction_policy =
@@ -2537,6 +2560,40 @@
     replacer.register_rule("(C_ptr)", c_ref);
     replacer.register_rule("(C_offset)", c_bias);
     this->stream << replacer.rewrite(mma_call);
+  } else if (op->op.same_as(tl::tma_store_cluster())) {
+    ICHECK_EQ(op->args.size(), 5U) << "tma_store_cluster requires 5 args";
+    this->PrintIndent();
+    this->stream << "tl::tma_store_cluster(";
+    this->stream << this->PrintExpr(op->args[0]) << ", ";
+    this->stream << this->PrintExpr(op->args[1]) << ", ";
+    this->stream << "(int)(" << this->PrintExpr(op->args[2]) << "), ";
+    this->stream << "(uint32_t)(" << this->PrintExpr(op->args[3]) << "), ";
+    this->stream << this->PrintExpr(op->args[4]) << ");\n";
+
+  } else if (op->op.same_as(tl::ptx_cluster_store())) {
+    ICHECK_EQ(op->args.size(), 4U);
+    std::string buffer_var = this->PrintExpr(op->args[0]);
+    std::string value = this->PrintExpr(op->args[1]);
+    std::string dst_block = this->PrintExpr(op->args[2]);
+    std::string index = this->PrintExpr(op->args[3]);
+
+    this->need_cooperative_groups_ = true;
+    this->PrintIndent();
+    this->stream << "{\n";
+    int cluster_scope = this->BeginScope();
+    this->PrintIndent();
+    this->stream << "namespace cg = cooperative_groups;\n";
+    this->PrintIndent();
+    this->stream << "cg::cluster_group cluster = cg::this_cluster();\n";
+    this->PrintIndent();
+    this->stream << "auto* dst_ptr = cluster.map_shared_rank(&" << buffer_var
+                 << "[" << index << "], " << dst_block << ");\n";
+    this->PrintIndent();
+    this->stream << "*dst_ptr = " << value << ";\n";
+    this->EndScope(cluster_scope);
+    this->PrintIndent();
+    this->stream << "}\n";
   } else if (op->op.same_as(tl::ptx_mma_sm70())) {
     // arg 0: shape: mXnXkX
     // arg 1: A layout: row/col
diff --git a/src/backend/cuda/op/copy.cc b/src/backend/cuda/op/copy.cc
index 14cbd6348e..9ee98737bc 100644
--- a/src/backend/cuda/op/copy.cc
+++ b/src/backend/cuda/op/copy.cc
@@ -23,6 +23,7 @@
 #include
 #include
+#include
 #include

 namespace tvm {
@@ -92,6 +93,35 @@ int GetEvictionPolicy(const CopyNode &op) {
   return 0; // default: evict_normal
 }

+int64_t GetClusterMask(const CopyNode &op) {
+  if (auto val = op.annotations.Get("cluster_mask")) {
+    if (auto int_val = val->as<IntImmNode>()) {
+      return int_val->value;
+    }
+  }
+  return 0;
+}
+
+int MinRankInClusterMask(int64_t cluster_mask) {
+  ICHECK_GT(cluster_mask, 0);
+  uint64_t mask = static_cast<uint64_t>(cluster_mask);
+  int rank = 0;
+  while ((mask & 1U) == 0U) {
+    mask >>= 1;
+    ++rank;
+  }
+  return rank;
+}
+
+Optional<PrimExpr> GetBarrier(const CopyNode &op) {
+  if (auto val = op.annotations.Get("barrier")) {
+    if (val->as<BufferLoadNode>()) {
+      return Downcast<PrimExpr>(val.value());
+    }
+  }
+  return Optional<PrimExpr>();
+}
+
 bool GetIsAsyncCopy(const CopyNode &op) {
   if (GetBoolAnnotation(op, "is_async_copy")) {
     return true;
@@ -104,6 +134,128 @@ bool GetNoImplicitAsyncCommitWait(const CopyNode &op) {
   return GetBoolAnnotation(op, attr::kAsyncCopyNoImplicitCommitWait);
 }

+bool IsContiguousRegion(const Buffer &buf, const Array<Range> &ranges,
+                        arith::Analyzer *analyzer) {
+  ICHECK_EQ(buf->shape.size(), ranges.size())
+      << "IsContiguousRegion: buffer/range rank mismatch for " << buf->name;
+
+  int n = static_cast<int>(ranges.size());
+  int pivot = -1;
+  for (int i = 0; i < n; ++i) {
+    if (!analyzer->CanProveEqual(ranges[i]->extent, 1)) {
+      pivot = i;
+      break;
+    }
+  }
+  if (pivot == -1) {
+    return true;
+  }
+
+  for (int i = 0; i < pivot; ++i) {
+    ICHECK(analyzer->CanProveEqual(ranges[i]->extent, 1))
+        << "IsContiguousRegion: dim " << i << " precedes pivot " << pivot
+        << " but has non-unit extent " << ranges[i]->extent << " for buffer "
+        << buf->name;
+  }
+
+  for (int i = pivot + 1; i < n; ++i) {
+    if (!analyzer->CanProveEqual(ranges[i]->min, 0) ||
+        !analyzer->CanProveEqual(ranges[i]->extent, buf->shape[i])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+std::pair<Array<Stmt>, PrimExpr>
+MakeTMARows(const Buffer &src, const Array<Range> &src_ranges,
+            const Buffer &dst, const Array<Range> &dst_ranges,
+            PrimExpr dst_block, PrimExpr barrier_load,
+            arith::Analyzer *analyzer) {
+  int n = static_cast<int>(src_ranges.size());
+
+  auto linear_off = [](const Buffer &buf,
+                       const Array<Range> &ranges) -> PrimExpr {
+    int r = static_cast<int>(ranges.size());
+    PrimExpr off = 0, stride = 1;
+    for (int i = r - 1; i >= 0; --i) {
+      off = off + ranges[i]->min * stride;
+      if (i > 0) {
+        stride = stride * buf->shape[i];
+      }
+    }
+    return off;
+  };
+
+  if (IsContiguousRegion(src, src_ranges, analyzer) &&
+      IsContiguousRegion(dst, dst_ranges, analyzer)) {
+    PrimExpr total_elems = 1;
+    for (const auto &r : src_ranges) {
+      total_elems = total_elems * r->extent;
+    }
+    PrimExpr size_bytes =
+        cast(DataType::UInt(32), TMABytesFromElements(total_elems, src->dtype));
+    PrimExpr src_ptr = src.access_ptr(1, DataType::Handle(), 1,
+                                      linear_off(src, src_ranges), total_elems);
+    PrimExpr dst_ptr = dst.access_ptr(2, DataType::Handle(), 1,
+                                      linear_off(dst, dst_ranges), total_elems);
+    Stmt call =
+        Evaluate(Call(DataType::Handle(), tma_store_cluster(),
+                      {dst_ptr, src_ptr, dst_block, size_bytes, barrier_load}));
+    return {{call}, IntImm(DataType::Int(32), 1)};
+  }
+
+  int split_dim = -1;
+  for (int d = 0; d < n; ++d) {
+    if (!analyzer->CanProveEqual(src_ranges[d]->extent, 1)) {
+      split_dim = d;
+      break;
+    }
+  }
+  ICHECK(split_dim >= 0)
+      << "MakeTMARows: all dimensions are trivial yet region is not "
+         "contiguous";
+
+  PrimExpr extent = src_ranges[split_dim]->extent;
+  const auto *ext_imm = extent.as<IntImmNode>();
+
+  if (ext_imm) {
+    Array<Stmt> all_stmts;
+    PrimExpr total = IntImm(DataType::Int(32), 0);
+    for (int64_t k = 0; k < ext_imm->value; ++k) {
+      Array<Range> new_src = src_ranges;
+      Array<Range> new_dst = dst_ranges;
+      PrimExpr kexpr = IntImm(DataType::Int(32), k);
+      new_src.Set(split_dim,
+                  Range::FromMinExtent(src_ranges[split_dim]->min + kexpr, 1));
+      new_dst.Set(split_dim,
+                  Range::FromMinExtent(dst_ranges[split_dim]->min + kexpr, 1));
+      auto [stmts, cnt] = MakeTMARows(src, new_src, dst, new_dst, dst_block,
+                                      barrier_load, analyzer);
+      for (const auto &s : stmts) {
+        all_stmts.push_back(s);
+      }
+      total = total + cnt;
+    }
+    return {all_stmts, total};
+  }
+
+  Var k("k_tma_row", DataType::Int(32));
+  Array<Range> body_src = src_ranges;
+  Array<Range> body_dst = dst_ranges;
+  body_src.Set(split_dim,
+               Range::FromMinExtent(src_ranges[split_dim]->min + k, 1));
+  body_dst.Set(split_dim,
+               Range::FromMinExtent(dst_ranges[split_dim]->min + k, 1));
+  auto [body_stmts, body_cnt] = MakeTMARows(src, body_src, dst, body_dst,
+                                            dst_block, barrier_load, analyzer);
+  Stmt body = body_stmts.size() == 1 ? body_stmts[0]
+                                     : static_cast<Stmt>(SeqStmt(body_stmts));
+  Stmt for_loop =
+      For(k, IntImm(DataType::Int(32), 0), extent, ForKind::kSerial, body);
+  return {{for_loop}, extent * body_cnt};
+}
+
 } // namespace

 namespace cuda {
@@ -184,6 +336,9 @@ struct Copy {
   static Stmt LowerNormal(const CopyNode &op, const LowerArgs &T,
                           arith::Analyzer *analyzer);

+  static Stmt LowerCluster(const CopyNode &op, const LowerArgs &T,
+                           arith::Analyzer *analyzer);
+
   static Stmt LowerCPAsync(const CopyNode &op, const LowerArgs &T,
                            arith::Analyzer *analyzer);

@@ -431,6 +586,12 @@ Stmt Copy::Lower(const CopyNode &op, const LowerArgs &T,
                  arith::Analyzer *analyzer) {
   auto copy_inst = SelectInst(op, T.target, T.layout_map, analyzer,
                               /*buffer_oob=*/false);
+  if (op.dst_block.defined()) {
+    ICHECK(TargetHasBulkCopy(T.target))
+        << "T.copy with dst_block requires cluster-copy support (CUDA SM90+). "
+        << "Got target=" << T.target;
+    return LowerCluster(op, T, analyzer);
+  }
   if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
     auto tmem_copy = LowerTmem(op, T, analyzer);
     ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy";
@@ -537,6 +698,206 @@ Stmt Copy::LowerNormal(const CopyNode &op, const LowerArgs &T,
   return tl::LowerNormalCopy(op, T, analyzer);
 }

+Stmt Copy::LowerCluster(const CopyNode &op, const LowerArgs &T,
+                        arith::Analyzer *analyzer) {
+  const Buffer &src = op.src;
+  const Buffer &dst = op.dst;
+  const Array<Range> &src_range = op.src_range;
+  const Array<Range> &dst_range = op.dst_range;
+  ICHECK(op.dst_block.defined());
+  ICHECK(src.scope() == "shared" || src.scope() == "shared.dyn");
+  ICHECK(dst.scope() == "shared" || dst.scope() == "shared.dyn");
+
+  if (auto barrier_opt = GetBarrier(op)) {
+    bool src_contiguous = IsContiguousRegion(src, src_range, analyzer);
+    bool dst_contiguous = IsContiguousRegion(dst, dst_range, analyzer);
+
+    PrimExpr src_elements = 1;
+    for (auto r : src_range) {
+      src_elements = src_elements * r->extent;
+    }
+    PrimExpr dst_elements = 1;
+    for (auto r : dst_range) {
+      dst_elements = dst_elements * r->extent;
+    }
+    bool element_match = analyzer->CanProveEqual(src_elements, dst_elements);
+
+    if (src_contiguous && dst_contiguous && element_match) {
+      PrimExpr barrier_load = barrier_opt.value();
+
+      auto compute_linear_offset = [](const Buffer &buf,
+                                      const Array<Range> &ranges) -> PrimExpr {
+        PrimExpr offset = 0;
+        PrimExpr stride = 1;
+        for (int i = static_cast<int>(ranges.size()) - 1; i >= 0; --i) {
+          offset = offset + ranges[i]->min * stride;
+          if (i > 0) {
+            stride = stride * buf->shape[i];
+          }
+        }
+        return offset;
+      };
+
+      PrimExpr dst_offset = compute_linear_offset(dst, dst_range);
+      PrimExpr src_offset = compute_linear_offset(src, src_range);
+      PrimExpr total_elements = 1;
+      for (auto r : src_range) {
+        total_elements = total_elements * r->extent;
+      }
+      PrimExpr size_bytes = cast(
+          DataType::UInt(32), TMABytesFromElements(total_elements, src->dtype));
+
+      PrimExpr dst_ptr =
+          dst.access_ptr(2, DataType::Handle(), 1, dst_offset, total_elements);
+      PrimExpr src_ptr =
+          src.access_ptr(1, DataType::Handle(), 1, src_offset, total_elements);
+
+      Stmt bulk_copy = Evaluate(Call(
+          DataType::Handle(), tma_store_cluster(),
+          {dst_ptr, src_ptr, op.dst_block.value(), size_bytes, barrier_load}));
+
+      return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), bulk_copy);
+    }
+
+    bool same_shape = (src_range.size() == dst_range.size());
+    for (size_t d = 0; d < src_range.size() && same_shape; ++d) {
+      if (!analyzer->CanProveEqual(src_range[d]->extent,
+                                   dst_range[d]->extent)) {
+        same_shape = false;
+      }
+    }
+
+    if (element_match && same_shape) {
+      PrimExpr barrier_load = barrier_opt.value();
+      const auto *barrier_buf_load = barrier_load.as<BufferLoadNode>();
+      ICHECK(barrier_buf_load)
+          << "LowerCluster: expected BufferLoad for barrier annotation";
+      Var barrier_data_var = barrier_buf_load->buffer->data;
+
+      auto [tma_stmts, n_rows] =
+          MakeTMARows(src, src_range, dst, dst_range, op.dst_block.value(),
+                      barrier_load, analyzer);
+
+      if (T.UpdateBarrierArrive) {
+        T.UpdateBarrierArrive(barrier_data_var, n_rows);
+      }
+
+      Stmt seq = (tma_stmts.size() == 1)
+                     ? tma_stmts[0]
+                     : static_cast<Stmt>(SeqStmt(tma_stmts));
+      return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), seq);
+    }
+
+    LOG(WARNING)
+        << "Falling back to element-wise cluster copy: bulk cluster paths "
+           "require matching element counts and same per-dim extents between "
+           "src and dst. 
src=" + << src->name << ", dst=" << dst->name; + } + + auto simt_loop = op.MakeSIMTLoop(analyzer); + auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); + + std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, + InferLevel::kFree}; + auto par_op = ParallelOp(fused_loop); + for (auto level : levels) { + par_op->InferLayout({T.target, + T.thread_bounds, + T.layout_map, + analyzer, + false, + T.buffer_remap, + {}}, + level); + } + auto loop_layout = par_op->GetLoopLayout(); + auto thread_loop = + PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout); + auto vectorized_thread_loop = + VectorizeLoop(thread_loop, T.layout_map, /*vectorize_hint=*/1); + + class ClusterCopyReplacer : public StmtExprMutator { + public: + ClusterCopyReplacer(const Buffer &dst, PrimExpr dst_block, + const Buffer &target_dst, Optional dst_layout) + : dst_(dst), dst_block_(dst_block), target_dst_(target_dst), + dst_layout_(dst_layout) {} + + Stmt VisitStmt_(const BufferStoreNode *op) final { + if (op->buffer.same_as(dst_)) { + Array physical_indices = op->indices; + if (!target_dst_.same_as(dst_) && dst_layout_.defined()) { + physical_indices = dst_layout_.value()->Forward(op->indices); + } + + PrimExpr linearized_index = physical_indices[0]; + if (physical_indices.size() > 1) { + PrimExpr multiplier = 1; + linearized_index = 0; + for (int i = static_cast(physical_indices.size()) - 1; i >= 0; + --i) { + linearized_index = + linearized_index + physical_indices[i] * multiplier; + if (i > 0) { + multiplier = multiplier * target_dst_->shape[i]; + } + } + } + + Buffer target_buffer = target_dst_; + if (target_dst_.same_as(dst_)) { + target_buffer = op->buffer; + } + + PrimExpr total_elems = 1; + for (const PrimExpr &s : target_buffer->shape) { + total_elems = total_elems * s; + } + + Stmt remote_store = + Evaluate(Call(DataType::Handle(), ptx_cluster_store(), + {target_buffer.access_ptr(2), op->value, dst_block_, + linearized_index})); + + return IfThenElse(linearized_index < total_elems, remote_store, Stmt()); + } + return StmtExprMutator::VisitStmt_(op); + } + + private: + const Buffer &dst_; + PrimExpr dst_block_; + const Buffer &target_dst_; + Optional dst_layout_; + }; + + Buffer target_dst = dst; + if (T.buffer_remap.count(dst)) { + target_dst = T.buffer_remap[dst]; + } + + Optional dst_layout = std::nullopt; + if (T.layout_map.count(dst)) { + dst_layout = T.layout_map[dst]; + } + + Stmt simt_copy = ClusterCopyReplacer(dst, op.dst_block.value(), target_dst, + dst_layout)(vectorized_thread_loop); + + if (auto barrier_opt = GetBarrier(op)) { + Stmt sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm("shared")})); + Stmt arrive = + Evaluate(Call(DataType::Handle(), ptx_arrive_cluster_barrier(), + {barrier_opt.value(), op.dst_block.value()})); + Stmt guarded_arrive = + IfThenElse(EQ(T.thread_var, T.thread_bounds->min), arrive); + return SeqStmt({simt_copy, sync, guarded_arrive}); + } + return simt_copy; +} + Stmt Copy::LowerLDSM(const CopyNode &op, const LowerArgs &T, arith::Analyzer *analyzer, CopyInst copy_inst) { const Buffer &src = op.src; @@ -1116,6 +1477,9 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, Call create_descriptor = Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); + int64_t cluster_mask = GetClusterMask(op); + bool use_multicast = is_load && (cluster_mask > 0); + int barrier_base_id = -1; PrimExpr mbar_handle; bool is_cluster_barrier = false; @@ -1148,6 +1512,19 @@ Stmt Copy::LowerBulk(const 
CopyNode &op, const LowerArgs &T, for (auto e : desc.smem_box) total_elements *= e; + auto build_multicast_args = [&](const Array ®ular_args) { + Array mc_args; + mc_args.reserve(regular_args.size() + 1); + mc_args.push_back(regular_args[0]); // descriptor + mc_args.push_back(regular_args[1]); // mbarrier + mc_args.push_back(regular_args[2]); // shared memory pointer + mc_args.push_back(IntImm(DataType::Int(32), cluster_mask)); + for (size_t i = 3; i < regular_args.size(); ++i) { + mc_args.push_back(regular_args[i]); + } + return mc_args; + }; + if ((*inner_box_dim) != instruction_dim) { Var loop_var("i"); int loop_extent = (*inner_box_dim) / instruction_dim; @@ -1169,6 +1546,25 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, } tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, Evaluate(Call(DataType::Handle(), tma_op, args, ann_loop))); + + if (use_multicast) { + Array mc_args = build_multicast_args(args); + Stmt multicast_copy = For( + loop_var, 0, loop_extent, ForKind::kUnrolled, + Evaluate(Call(DataType::Handle(), tma_load_multicast(), mc_args))); + + int min_cta_rank = MinRankInClusterMask(cluster_mask); + PrimExpr block_rank = + Call(DataType::Int(32), block_rank_in_cluster(), {}); + PrimExpr mask_imm = IntImm(DataType::Int(32), cluster_mask); + PrimExpr not_in_mask = EQ(bitwise_and(right_shift(mask_imm, block_rank), + IntImm(DataType::Int(32), 1)), + IntImm(DataType::Int(32), 0)); + Stmt regular_or_noop = IfThenElse(not_in_mask, tma_copy, std::nullopt); + tma_copy = + IfThenElse(EQ(block_rank, IntImm(DataType::Int(32), min_cta_rank)), + multicast_copy, regular_or_noop); + } } else { PrimExpr shared_addr = shared_tensor.access_ptr( is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, total_elements); @@ -1186,6 +1582,24 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, ann.Set("use_2cta", IntImm(DataType::Int(32), 1)); } tma_copy = Evaluate(Call(DataType::Handle(), tma_op, args, ann)); + + if (use_multicast) { + Array mc_args = build_multicast_args(args); + Stmt multicast_copy = + Evaluate(Call(DataType::Handle(), tma_load_multicast(), mc_args)); + + int min_cta_rank = MinRankInClusterMask(cluster_mask); + PrimExpr block_rank = + Call(DataType::Int(32), block_rank_in_cluster(), {}); + PrimExpr mask_imm = IntImm(DataType::Int(32), cluster_mask); + PrimExpr not_in_mask = EQ(bitwise_and(right_shift(mask_imm, block_rank), + IntImm(DataType::Int(32), 1)), + IntImm(DataType::Int(32), 0)); + Stmt regular_or_noop = IfThenElse(not_in_mask, tma_copy, std::nullopt); + tma_copy = + IfThenElse(EQ(block_rank, IntImm(DataType::Int(32), min_cta_rank)), + multicast_copy, regular_or_noop); + } } if (!is_load) { @@ -1278,6 +1692,14 @@ Stmt Copy::LowerBulk1D(const CopyNode &op, const LowerArgs &T, ICHECK(copy_inst == CopyInst::kBulkLoad1D || copy_inst == CopyInst::kBulkStore1D); + int64_t cluster_mask = GetClusterMask(op); + ICHECK(cluster_mask == 0) + << "cluster_mask=0x" << std::hex << cluster_mask + << " requires descriptor-based TMA (kBulkLoad); the 1D bulk-copy path " + "does not support multicast. src=" + << src->name << " (scope=" << src.scope() << "), dst=" << dst->name + << " (scope=" << dst.scope() << ")."; + bool is_load = copy_inst == CopyInst::kBulkLoad1D; auto shared_range = is_load ? dst_range : src_range; auto global_range = is_load ? 
src_range : dst_range; diff --git a/src/backend/cuda/op/copy_analysis.cc b/src/backend/cuda/op/copy_analysis.cc index d43aa5539c..1c9b044977 100644 --- a/src/backend/cuda/op/copy_analysis.cc +++ b/src/backend/cuda/op/copy_analysis.cc @@ -55,6 +55,15 @@ bool GetIsTmaCopy(const CopyNode &op) { return GetBoolAnnotation(op, "is_tma_copy"); } +int64_t GetClusterMask(const CopyNode &op) { + if (auto val = op.annotations.Get("cluster_mask")) { + if (auto int_val = val->as()) { + return int_val->value; + } + } + return 0; +} + bool GetIsAsyncCopy(const CopyNode &op) { if (GetBoolAnnotation(op, "is_async_copy")) { return true; @@ -344,6 +353,7 @@ struct CopyFacts { bool explicit_cp_async = false; bool no_implicit_async_commit_wait = false; bool disable_tma = false; + int64_t cluster_mask = 0; bool can_bulk_load_1d = false; bool can_bulk_store_1d = false; bool can_bulk_load = false; @@ -456,6 +466,7 @@ CopyFacts AnalyzeCopyFacts(const CopyNode &op, const CopyAnalysisContext &ctx) { facts.explicit_cp_async = GetIsAsyncCopy(op); facts.no_implicit_async_commit_wait = GetNoImplicitAsyncCommitWait(op); facts.disable_tma = GetDisableTMA(op); + facts.cluster_mask = GetClusterMask(op); facts.tma_unavailable_reason = MakeTmaUnavailableReason(op); facts.async_unavailable_reason = MakeAsyncUnavailableReason(op, ctx.target); facts.pass_context_disables_tma = @@ -488,6 +499,9 @@ CopyFacts AnalyzeCopyFacts(const CopyNode &op, const CopyAnalysisContext &ctx) { if (facts.can_bulk_load_1d) { facts.can_bulk_load_ignore_last_dim = true; + facts.can_bulk_load = + CheckBulkLoad(op, ctx.target, analyzer, /*check_last_dim=*/true, + ctx.emit_diagnostics); } else { facts.can_bulk_load_ignore_last_dim = CheckBulkLoad(op, ctx.target, analyzer, /*check_last_dim=*/false, @@ -499,6 +513,9 @@ CopyFacts AnalyzeCopyFacts(const CopyNode &op, const CopyAnalysisContext &ctx) { if (facts.can_bulk_store_1d) { facts.can_bulk_store_ignore_last_dim = true; + facts.can_bulk_store = + CheckBulkStore(op, ctx.target, analyzer, /*check_last_dim=*/true, + ctx.emit_diagnostics); } else { facts.can_bulk_store_ignore_last_dim = CheckBulkStore(op, ctx.target, analyzer, /*check_last_dim=*/false, @@ -521,6 +538,19 @@ CopyFacts AnalyzeCopyFacts(const CopyNode &op, const CopyAnalysisContext &ctx) { CopyInstSelection SelectCopyInstForLowering(const CopyNode &op, const CopyAnalysisContext &ctx) { CopyFacts facts = AnalyzeCopyFacts(op, ctx); + if (facts.cluster_mask != 0) { + if (facts.can_bulk_load) { + return Supported(CopyInst::kBulkLoad); + } + std::ostringstream oss; + oss << "cluster_mask=0x" << std::hex << facts.cluster_mask + << " requires descriptor-based TMA (kBulkLoad), but the copy does not " + "meet TMA bulk-load constraints. src=" + << op.src->name << " (scope=" << op.src.scope() + << "), dst=" << op.dst->name << " (scope=" << op.dst.scope() << ")."; + return Unsupported(oss.str()); + } + if (facts.explicit_tma) { CopyInst inst = SelectTmaInst(facts, /*allow_load=*/true, /*allow_store=*/true, @@ -557,6 +587,10 @@ std::string ClassifyCopyForInstructionAnnotation(const CopyNode &op, return "sync"; } + if (facts.cluster_mask != 0) { + return facts.can_bulk_load ? "tma" : "sync"; + } + if (facts.explicit_tma) { CopyInst inst = SelectTmaInst(facts, /*allow_load=*/true, /*allow_store=*/true, @@ -585,6 +619,11 @@ CopyInstSelection ClassifyWarpSpecializedProducerCopy(const CopyNode &op, return Supported(CopyInst::kNormal); } + if (facts.cluster_mask != 0) { + return facts.can_bulk_load ? 
Supported(CopyInst::kBulkLoad) + : Unsupported(facts.tma_unavailable_reason); + } + if (facts.explicit_tma) { CopyInst inst = SelectTmaInst(facts, /*allow_load=*/true, /*allow_store=*/false, diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 508f90c654..e388a878bd 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -174,6 +174,11 @@ TIR_DEFINE_TL_BUILTIN(tma_load_im2col) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(tma_load_multicast) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -711,5 +716,15 @@ TIR_DEFINE_TL_BUILTIN(stg128).set_num_inputs(-1).set_attr( TIR_DEFINE_TL_BUILTIN(stg256).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_cluster_store) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tma_store_cluster) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + } // namespace tl } // namespace tvm diff --git a/src/op/builtin.h b/src/op/builtin.h index 69a65ab10a..829662d764 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -306,6 +306,14 @@ TVM_DLL const Op &tma_load(); */ TVM_DLL const Op &tma_load_im2col(); +/*! + * \brief TMA multicast load from a tensor descriptor to cluster shared memory. + * + * tma_load_multicast(descriptor, mbarrier, smem_data, multicast_mask, + * coord_0, coord_1, ..., eviction_policy) + */ +TVM_DLL const Op &tma_load_multicast(); + /*! * \brief tvm intrinsics for storing data from shared memory to global tensor * descriptor @@ -1240,6 +1248,18 @@ TVM_DLL const Op &stg128(); */ TVM_DLL const Op &stg256(); +/*! + * \brief Elementwise shared::cluster store via cooperative groups. + */ +TVM_DLL const Op &ptx_cluster_store(); + +/*! + * \brief Bulk async shared::cluster store to another CTA. + * + * tma_store_cluster(dst_ptr, src_ptr, dst_cta, size_bytes, bar_ref) + */ +TVM_DLL const Op &tma_store_cluster(); + } // namespace tl } // namespace tvm diff --git a/src/op/copy.cc b/src/op/copy.cc index b4e160515d..fb050d9e5e 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -168,6 +168,15 @@ Copy::Copy(Array args, Map annotations) { node->SetAccessRegions({src_access, dst_access}); // Copy annotations from the Call node node->annotations = annotations; + if (auto dst_block = node->annotations.Get("dst_block")) { + if (auto int_imm = dst_block->as()) { + if (int_imm->value != -1) { + node->dst_block = Integer(int_imm->value); + } + } else { + node->dst_block = Downcast(dst_block.value()); + } + } data_ = std::move(node); } diff --git a/src/op/copy.h b/src/op/copy.h index 03569b690e..b4bdc8235f 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -29,9 +29,11 @@ class CopyNode : public TileOperatorNode { public: Buffer src, dst; // Source and destination buffers Array src_range, dst_range; // Ranges for each dimension in src and dst + Optional dst_block; // Destination block index for cluster copy Map annotations; // Backend/pass-specific annotations. // Common SIMT annotation keys: // - "coalesced_width": IntImm, width for coalesced memory access. + // - "dst_block": PrimExpr, destination CTA rank for cluster copy. // - attr::kParallelLoopLayout ("parallel_loop_layout"): Fragment, loop // layout hint applied to the outermost generated parallel loop of this // copy's SIMT loop nest. 
@@ -47,6 +49,7 @@ class CopyNode : public TileOperatorNode { .def_ro("dst", &CopyNode::dst) .def_ro("src_range", &CopyNode::src_range) .def_ro("dst_range", &CopyNode::dst_range) + .def_ro("dst_block", &CopyNode::dst_block) .def_ro("annotations", &CopyNode::annotations); } diff --git a/src/op/operator.h b/src/op/operator.h index bbb6fdbf90..41b7d1a9ab 100644 --- a/src/op/operator.h +++ b/src/op/operator.h @@ -28,6 +28,7 @@ using namespace tir; using AddWorkspaceCallback = std::function; using AllocMBarrierCallback = std::function; +using UpdateBarrierArriveCallback = std::function; using LayoutMap = Map; using BufferMap = Map; @@ -87,6 +88,7 @@ struct LowerArgs { Var thread_var; AddWorkspaceCallback AddWorkspace; AllocMBarrierCallback AllocMBarrier; + UpdateBarrierArriveCallback UpdateBarrierArrive; LayoutMap layout_map; Map buffer_remap; // Map from LetStmt variable to its bound expression, for resolving diff --git a/src/tl_templates/cuda/copy_sm90.h b/src/tl_templates/cuda/copy_sm90.h index c8e1794485..beb1499077 100644 --- a/src/tl_templates/cuda/copy_sm90.h +++ b/src/tl_templates/cuda/copy_sm90.h @@ -46,6 +46,167 @@ TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr, :); } +// Generic SM-to-SM async bulk copy via cp.async.bulk.shared::cluster +template +TL_DEVICE void tma_store_cluster(void *dst, void *src, int dst_cta, + uint32_t size_bytes, BarrierType &bar) { + uint32_t mbarrier_ptr = static_cast( + __cvta_generic_to_shared(reinterpret_cast(&bar))); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); + + uint32_t neighbor_addr_dst; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_addr_dst) + : "r"(dst_ptr), "r"(dst_cta)); + + uint32_t neighbor_addr_mbarrier; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_addr_mbarrier) + : "r"(mbarrier_ptr), "r"(dst_cta)); + + // Arrive at the remote barrier and announce the expected TX byte count. + // This satisfies one arrival (matching the mbarrier_init count) and tells + // the barrier how many bytes the subsequent cp.async.bulk will transfer. 
+ asm volatile("mbarrier.arrive.expect_tx.shared::cluster.b64 _, [%0], %1;\n" + : + : "r"(neighbor_addr_mbarrier), "r"(size_bytes) + : "memory"); + + asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory"); + asm volatile("cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_" + "tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(neighbor_addr_dst), "r"(src_ptr), "r"(size_bytes), + "r"(neighbor_addr_mbarrier) + : "memory"); +} + +template +TL_DEVICE void +tma_load_multicast(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, uint16_t multicast_mask, + int32_t const &crd0) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4, %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), + "h"(multicast_mask), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void +tma_load_multicast(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, uint16_t multicast_mask, + int32_t const &crd0, int32_t const &crd1) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5, %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), + "r"(crd1), "h"(multicast_mask), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_load_multicast(const CUtensorMap &descriptor, + BarrierType &smem_mbar, + void const *const smem_ptr, + uint16_t multicast_mask, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), + "r"(crd1), "r"(crd2), "h"(multicast_mask), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void +tma_load_multicast(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, uint16_t multicast_mask, + int32_t const &crd0, int32_t const &crd1, + int32_t const &crd2, int32_t const &crd3) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); 
+ asm volatile( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), + "r"(crd1), "r"(crd2), "r"(crd3), "h"(multicast_mask), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_load_multicast(const CUtensorMap &descriptor, + BarrierType &smem_mbar, + void const *const smem_ptr, + uint16_t multicast_mask, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3, int32_t const &crd4) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), + "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "h"(multicast_mask), + "l"(cache_hint) + : "memory"); +} + template TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, diff --git a/src/transform/inject_fence_proxy.cc b/src/transform/inject_fence_proxy.cc index 5ac24e7e91..3c6745af56 100644 --- a/src/transform/inject_fence_proxy.cc +++ b/src/transform/inject_fence_proxy.cc @@ -95,8 +95,8 @@ bool IsAsyncIntrinsic(const CallNode *call) { // TileLang async intrinsics if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || - call->op.same_as(tma_store()) || call->op.same_as(ptx_wgmma_ss()) || - call->op.same_as(ptx_wgmma_rs()) || + call->op.same_as(tma_load_multicast()) || call->op.same_as(tma_store()) || + call->op.same_as(ptx_wgmma_ss()) || call->op.same_as(ptx_wgmma_rs()) || call->op.same_as(ptx_tcgen05_mma_ss()) || call->op.same_as(ptx_tcgen05_mma_ts())) { return true; diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index e9ea2cdbc4..400d6cd765 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -193,6 +193,30 @@ class LowerHopperIntrin : public StmtExprMutator { } } + Stmt VisitStmt_(const IfThenElseNode *op) final { + Stmt new_stmt = StmtExprMutator::VisitStmt_(op); + if (const auto *if_node = new_stmt.as()) { + if (IsNoOp(if_node->then_case) && (!if_node->else_case.defined() || + IsNoOp(if_node->else_case.value()))) { + return Evaluate(0); + } + } + return new_stmt; + } + + bool IsNoOp(const Stmt &stmt) { + if (const auto *eval = stmt.as()) { + return is_const_int(eval->value, 0); + } else if (const auto *seq = stmt.as()) { + for (const auto &s : seq->seq) { + if (!IsNoOp(s)) + return false; + } + return true; + } + return false; + } + private: struct DescInit { const VarNode *base_var; diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 7267f17ae7..8555e8d1e7 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -346,6 +346,33 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { } workspace_stack_.pop_back(); } + + // Apply arrive-count overrides before LowerSharedBarrier consumes them. 
+ if (!barrier_arrive_updates_.empty() && + block->annotations.count("barrier_init")) { + auto barrier_init_map = Downcast>>( + block->annotations.Get("barrier_init").value()); + bool updated = false; + for (auto it = barrier_arrive_updates_.begin(); + it != barrier_arrive_updates_.end();) { + if (barrier_init_map.count(it->first)) { + auto old_counts = barrier_init_map.at(it->first); + Array new_counts; + for (size_t i = 0; i < old_counts.size(); i++) { + new_counts.push_back(it->second); + } + barrier_init_map.Set(it->first, new_counts); + updated = true; + it = barrier_arrive_updates_.erase(it); + } else { + ++it; + } + } + if (updated) { + block_ptr->annotations.Set("barrier_init", barrier_init_map); + } + } + return block; } @@ -720,6 +747,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const tir::CallNode *op) final { if (op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_load_im2col()) || + op->op.same_as(tl::tma_load_multicast()) || op->op.same_as(tl::tma_store())) { // skip tma related calls, as they were transformed implicitly. has_tma_ = true; @@ -728,6 +756,12 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { in_tma_context_ = false; return call; } + if (op->op.same_as(tl::tma_store_cluster())) { + // SM-to-SM bulk async copy does not use a tensor-map descriptor, so + // shared-memory swizzle must still be reflected in pointer/index + // remapping. + return Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + } if (is_ptx_) { return Downcast(op); @@ -1068,15 +1102,20 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return id; }; - auto lowered = - tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var, - callback, mbarrier_callback, layout_map_, - buffer_remap_, let_var_to_expr, - loop_mbar_phase_stack_.empty() - ? PrimExpr(IntImm(DataType::Int(32), 0)) - : loop_mbar_phase_stack_.back(), - &mbarrier_buffer_, cluster_size_}, - analyzer_); + UpdateBarrierArriveCallback barrier_arrive_callback = [this](Var data_var, + PrimExpr n) { + barrier_arrive_updates_[data_var] = n; + }; + + auto lowered = tile_op->Lower( + LowerArgs{target_, thread_bounds, thread_var_->var, callback, + mbarrier_callback, barrier_arrive_callback, layout_map_, + buffer_remap_, let_var_to_expr, + loop_mbar_phase_stack_.empty() + ? PrimExpr(IntImm(DataType::Int(32), 0)) + : loop_mbar_phase_stack_.back(), + &mbarrier_buffer_, cluster_size_}, + analyzer_); return IRMutatorWithAnalyzer::VisitStmt(lowered); } @@ -1380,6 +1419,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { // without recomputing indices, since swizzle is encoded in TMA descriptor // parameters rather than in memory indices. bool in_tma_context_{false}; + // Pending barrier arrive-count overrides from multi-TMA cluster copies. 
+ std::unordered_map + barrier_arrive_updates_; }; namespace transform { diff --git a/src/transform/multi_version_buffer_rewriter.cc b/src/transform/multi_version_buffer_rewriter.cc index e73e2e1b93..b5dbfde86b 100644 --- a/src/transform/multi_version_buffer_rewriter.cc +++ b/src/transform/multi_version_buffer_rewriter.cc @@ -132,7 +132,8 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor { void VisitStmt_(const EvaluateNode *op) final { Role role = Role::kConsumer; if (auto call = op->value.as()) { - if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_load_multicast())) { role = Role::kProducer; has_bulk_copy_ = true; } diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index 3bb38f043d..43d3fd9541 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -452,10 +452,18 @@ class BufferRegionCollector : public StmtExprVisitor { buffer_reads->second.end()); } if (buffer_writes != chain_builder_.mbar_to_buffer_writes_.end()) { - writes_.insert( - writes_.end(), - chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).begin(), - chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).end()); + writes_.insert(writes_.end(), buffer_writes->second.begin(), + buffer_writes->second.end()); + } + } else { + // Handle-based mbarrier (e.g. get_mbarrier(id)): cannot resolve to a + // concrete Buffer. Conservatively attach all known async buffer + // dependencies so the wait is not treated as dependency-free. + for (const auto &[_, regions] : chain_builder_.mbar_to_buffer_reads_) { + reads_.insert(reads_.end(), regions.begin(), regions.end()); + } + for (const auto &[_, regions] : chain_builder_.mbar_to_buffer_writes_) { + writes_.insert(writes_.end(), regions.begin(), regions.end()); } } } else { diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index a9320b1502..80700ddb87 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1088,7 +1088,8 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { if (auto opt = op->op.as()) { const Op &call_op = opt.value(); return call_op.same_as(tl::tma_load()) || - call_op.same_as(tl::tma_load_im2col()); + call_op.same_as(tl::tma_load_im2col()) || + call_op.same_as(tl::tma_load_multicast()); } return false; }(); diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py new file mode 100644 index 0000000000..e9e5bd447a --- /dev/null +++ b/testing/python/cuda/test_tma_dsmem.py @@ -0,0 +1,293 @@ +"""Regression tests for SM-to-SM cluster copy.""" + +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +import numpy as np + + +def make_store_cluster_kernel(N: int): + @T.prim_func + def kernel( + A: T.Tensor((N,), "float32"), + B: T.Tensor((N,), "float32"), + ): + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: + s_src = T.alloc_shared((N,), "float32") + s_dst = T.alloc_shared((N,), "float32") + s_barrier = T.alloc_cluster_barrier([1]) + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for i in T.Parallel(N): + s_src[i] = A[i] + T.copy_cluster(s_src, s_dst, dst_block=1, remote_barrier=s_barrier[0]) + + if pid == 1: + T.mbarrier_wait_parity(s_barrier[0], 0) + for i in T.Parallel(N): + B[i] = s_dst[i] + + return kernel + + +def make_store_cluster_simt_no_barrier_kernel(N: int): + 
"""No remote_barrier -> SIMT fallback always taken; cluster_sync() orders stores.""" + + @T.prim_func + def kernel( + A: T.Tensor((N,), "float32"), + B: T.Tensor((N,), "float32"), + ): + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: + s_src = T.alloc_shared((N,), "float32") + s_dst = T.alloc_shared((N,), "float32") + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for i in T.Parallel(N): + s_src[i] = A[i] + # No remote_barrier: cluster copy lowering takes the SIMT path. + # All threads write into block 1's s_dst via map_shared_rank. + T.copy_cluster(s_src, s_dst, dst_block=1) + + # Full cluster barrier: ensures all map_shared_rank stores from + # block 0 are visible in block 1's address space before block 1 + # reads s_dst. + T.cluster_sync() + + if pid == 1: + for i in T.Parallel(N): + B[i] = s_dst[i] + + return kernel + + +def make_store_cluster_simt_barrier_kernel(M: int, N_full: int, N_tile: int): + """2-D slice copy that forces the SIMT fallback even though remote_barrier is set. + + s_src / s_dst are allocated with inner dimension N_full, but only the + first N_tile columns are copied. Because N_tile < N_full the + is_contiguous_region() check fails: the inner-dim extent of the copy + region (N_tile) does not equal the buffer shape (N_full). + + Cluster copy lowering falls back to map_shared_rank stores and, because + remote_barrier was supplied, automatically appends: + __syncthreads(); + if (threadIdx.x == 0) s_barrier[0].arrive(1u); + Block 1 therefore waits on the same mbarrier as in the fast-path API, + verifying that ptx_arrive_cluster_barrier is injected and functional. + """ + + @T.prim_func + def kernel( + A: T.Tensor((M, N_tile), "float32"), + B: T.Tensor((M, N_tile), "float32"), + ): + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: + # Deliberately wider buffer: N_full > N_tile so the slice + # [0:M, 0:N_tile] is non-contiguous in row-major storage. + s_src = T.alloc_shared((M, N_full), "float32") + s_dst = T.alloc_shared((M, N_full), "float32") + s_barrier = T.alloc_cluster_barrier([1]) + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for i, j in T.Parallel(M, N_tile): + s_src[i, j] = A[i, j] + + # [0:M, 0:N_tile] inner-dim extent N_tile != N_full + # contiguity check fails, so this uses the SIMT fallback. + # Compiler auto-injects: __syncthreads() + + # if (t == 0) s_barrier[0].arrive(1u); + T.copy_cluster( + s_src[0:M, 0:N_tile], + s_dst[0:M, 0:N_tile], + dst_block=1, + remote_barrier=s_barrier[0], + ) + + if pid == 1: + # Block 1 waits on the auto-injected ptx_arrive_cluster_barrier. + T.mbarrier_wait_parity(s_barrier[0], 0) + for i, j in T.Parallel(M, N_tile): + B[i, j] = s_dst[i, j] + + return kernel + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_tma_store_cluster(): + """Fast path: T.copy_cluster emits tl::tma_store_cluster.""" + N = 128 + prim_func = make_store_cluster_kernel(N) + mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") + + src = mod.get_kernel_source() + assert "tl::tma_store_cluster" in src, ( + "Expected tl::tma_store_cluster in generated kernel source; " + "T.copy_cluster(dst_block=..., remote_barrier=...) 
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_tma_store_cluster():
+    """Fast path: T.copy_cluster emits tl::tma_store_cluster."""
+    N = 128
+    prim_func = make_store_cluster_kernel(N)
+    mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython")
+
+    src = mod.get_kernel_source()
+    assert "tl::tma_store_cluster" in src, (
+        "Expected tl::tma_store_cluster in generated kernel source; "
+        "T.copy_cluster(dst_block=..., remote_barrier=...) may have regressed "
+        f"to the SIMT fallback.\nKernel source:\n{src}"
+    )
+
+    A = torch.arange(N, dtype=torch.float32, device="cuda")
+    B = mod(A)
+    np.testing.assert_allclose(
+        B.cpu().numpy(),
+        A.cpu().numpy(),
+        rtol=0,
+        atol=0,
+        err_msg="tma_store_cluster copy produced wrong result",
+    )
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_store_cluster_simt_no_barrier():
+    """SIMT fallback (no remote_barrier): map_shared_rank + cluster_sync ordering."""
+    N = 128
+    prim_func = make_store_cluster_simt_no_barrier_kernel(N)
+    mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython")
+
+    src = mod.get_kernel_source()
+    assert "map_shared_rank" in src, f"Expected map_shared_rank in generated source for no-barrier SIMT fallback.\nKernel source:\n{src}"
+    assert "tl::tma_store_cluster" not in src, f"No-barrier path must NOT emit tl::tma_store_cluster.\nKernel source:\n{src}"
+
+    A = torch.arange(N, dtype=torch.float32, device="cuda")
+    B = mod(A)
+    np.testing.assert_allclose(
+        B.cpu().numpy(),
+        A.cpu().numpy(),
+        rtol=0,
+        atol=0,
+        err_msg="SIMT no-barrier cluster copy produced wrong result",
+    )
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_store_cluster_multi_tma_barrier():
+    """Multi-TMA path: non-contiguous 2-D slice decomposed into M row-wise TMA calls.
+
+    The compiler decomposes a non-full-span 2-D slice (shape [M, N_tile] inside
+    a buffer of shape [M, N_full]) into M individual tma_store_cluster calls,
+    one per contiguous row. The mbarrier arrive_count is set to M so the
+    destination CTA's wait(0) completes only after all rows are transferred.
+    """
+    M, N_full, N_tile = 4, 64, 32  # M rows, each N_tile elements
+
+    prim_func = make_store_cluster_multi_tma_barrier_kernel(M, N_full, N_tile)
+    mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython")
+
+    src = mod.get_kernel_source()
+    # Multi-TMA path must emit M separate tma_store_cluster calls, not SIMT stores.
+    assert "tl::tma_store_cluster" in src, f"Expected tl::tma_store_cluster for multi-TMA row decomposition.\nKernel source:\n{src}"
+    assert "map_shared_rank" not in src, f"Multi-TMA path must NOT fall back to map_shared_rank.\nKernel source:\n{src}"
+    # The barrier must be initialised with arrive_count == M (one per TMA call).
+    assert f"s_barrier[0].init({M})" in src, f"Expected barrier arrive_count={M} for {M}-row decomposition.\nKernel source:\n{src}"
+    # Exactly M tma_store_cluster calls should appear in the source.
+    assert src.count("tl::tma_store_cluster") == M, (
+        f"Expected exactly {M} tma_store_cluster calls, got {src.count('tl::tma_store_cluster')}.\nKernel source:\n{src}"
+    )
+
+    A = torch.arange(M * N_tile, dtype=torch.float32, device="cuda").reshape(M, N_tile)
+    B = mod(A)
+    np.testing.assert_allclose(
+        B.cpu().numpy(),
+        A.cpu().numpy(),
+        rtol=0,
+        atol=0,
+        err_msg="Multi-TMA row-decomposed cluster copy produced wrong result",
+    )
+ """ + + @T.prim_func + def kernel( + A: T.Tensor((D, M, N_tile), "float32"), + B: T.Tensor((D, M, N_tile), "float32"), + ): + with T.Kernel(2, threads=D * M * N_tile, cluster_dims=(2, 1, 1)) as pid: + s_src = T.alloc_shared((D, M, N_full), "float32") + s_dst = T.alloc_shared((D, M, N_full), "float32") + s_barrier = T.alloc_cluster_barrier([1]) + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for d, i, j in T.Parallel(D, M, N_tile): + s_src[d, i, j] = A[d, i, j] + + T.copy_cluster( + s_src[0:D, 0:M, 0:N_tile], + s_dst[0:D, 0:M, 0:N_tile], + dst_block=1, + remote_barrier=s_barrier[0], + ) + + if pid == 1: + T.mbarrier_wait_parity(s_barrier[0], 0) + for d, i, j in T.Parallel(D, M, N_tile): + B[d, i, j] = s_dst[d, i, j] + + return kernel + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_store_cluster_3d_multi_tma(): + """3-D multi-TMA: recursive decomposition produces D*M separate TMA calls. + + With D=2 and M=4, the two-level recursion over dims 0 and 1 yields 8 + contiguous-row tma_store_cluster calls and initialises the barrier with + arrive_count=8. + """ + D, M, N_full, N_tile = 2, 4, 32, 16 # D*M*N_tile == 128 == thread count + + prim_func = make_store_cluster_3d_multi_tma_kernel(D, M, N_full, N_tile) + mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") + + src = mod.get_kernel_source() + n_expected = D * M + assert "tl::tma_store_cluster" in src, f"Expected tl::tma_store_cluster for 3-D multi-TMA.\nKernel source:\n{src}" + assert f"s_barrier[0].init({n_expected})" in src, ( + f"Expected barrier arrive_count={n_expected} for {D}x{M} decomposition.\nKernel source:\n{src}" + ) + assert src.count("tl::tma_store_cluster") == n_expected, ( + f"Expected exactly {n_expected} tma_store_cluster calls, got {src.count('tl::tma_store_cluster')}.\nKernel source:\n{src}" + ) + + A = torch.arange(D * M * N_tile, dtype=torch.float32, device="cuda").reshape(D, M, N_tile) + B = mod(A) + np.testing.assert_allclose( + B.cpu().numpy(), + A.cpu().numpy(), + rtol=0, + atol=0, + err_msg="3-D multi-TMA cluster copy produced wrong result", + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/cuda/test_tma_multicast_demo.py b/testing/python/cuda/test_tma_multicast_demo.py new file mode 100644 index 0000000000..395622297f --- /dev/null +++ b/testing/python/cuda/test_tma_multicast_demo.py @@ -0,0 +1,76 @@ +"""TMA multicast validation.""" + +import torch +import tilelang +import tilelang.language as T +import tilelang.testing + + +def make_tma_multicast_demo_kernel(M, N, block_M, block_N, cluster_mask): + @T.prim_func + def kernel( + A: T.Tensor((M, N), "float16"), + B: T.Tensor((M, N), "float16"), + ): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=128, + cluster_dims=(4, 1, 1), + ) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), "float16") + T.copy_cluster(A[by * block_M, bx * block_N], A_shared, cluster_mask=cluster_mask) + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return kernel + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_tma_multicast_demo(): + """Verify TMA multicast: rank 1's B region should equal rank 0's A region within the same cluster.""" + M, N = 1024, 1024 + block_M, block_N = 128, 64 + # mask=0b0011: rank 0 multicasts, rank 1 receives, ranks 2/3 each do regular tma_load + cluster_mask = 0b0011 + + kernel = make_tma_multicast_demo_kernel(M, N, 
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_tma_multicast_demo():
+    """Verify TMA multicast: rank 1's B region should equal rank 0's A region within the same cluster."""
+    M, N = 1024, 1024
+    block_M, block_N = 128, 64
+    # mask=0b0011: rank 0 multicasts, rank 1 receives, ranks 2/3 each do a regular tma_load
+    cluster_mask = 0b0011
+
+    kernel = make_tma_multicast_demo_kernel(M, N, block_M, block_N, cluster_mask)
+    mod = tilelang.compile(
+        kernel,
+        out_idx=[1],
+        execution_backend="cython",
+    )
+
+    A = torch.randn(M, N, device="cuda", dtype=torch.float16)
+    B = mod(A)
+
+    # Within the first cluster: the first 4 blocks in the grid are (0,0),(1,0),(2,0),(3,0) -> by=0, bx=0,1,2,3
+    # rank 0 -> bx=0: A[0:block_M, 0:block_N]                     -> B[0:block_M, 0:block_N]
+    # rank 1 -> bx=1: multicast receives A[0:block_M, 0:block_N]  -> B[0:block_M, block_N:2*block_N]
+    # rank 2 -> bx=2: A[0:block_M, 2*block_N:3*block_N]           -> B[0:block_M, 2*block_N:3*block_N]
+    # rank 3 -> bx=3: A[0:block_M, 3*block_N:4*block_N]           -> B[0:block_M, 3*block_N:4*block_N]
+
+    # Multicast check: rank 1's B region must equal rank 0's A region.
+    B_rank1 = B[0:block_M, block_N : 2 * block_N]
+    A_rank0 = A[0:block_M, 0:block_N]
+    torch.testing.assert_close(B_rank1, A_rank0, rtol=1e-2, atol=1e-2)
+
+    # Rank 0 itself: B should equal A.
+    torch.testing.assert_close(B[0:block_M, 0:block_N], A[0:block_M, 0:block_N], rtol=1e-2, atol=1e-2)
+    # Ranks 2 and 3: each B region equals its own A region.
+    torch.testing.assert_close(
+        B[0:block_M, 2 * block_N : 3 * block_N],
+        A[0:block_M, 2 * block_N : 3 * block_N],
+        rtol=1e-2,
+        atol=1e-2,
+    )
+    torch.testing.assert_close(
+        B[0:block_M, 3 * block_N : 4 * block_N],
+        A[0:block_M, 3 * block_N : 4 * block_N],
+        rtol=1e-2,
+        atol=1e-2,
+    )
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 43c70563a2..aeb63eef49 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -55,7 +55,7 @@
     empty,  # noqa: F401
 )
 from tvm.script.parser.tir import allocate as allocate  # noqa: F401
-from .copy_op import copy, async_copy, tma_copy, transpose, c2d_im2col  # noqa: F401
+from .copy_op import copy, async_copy, tma_copy, transpose, c2d_im2col, copy_cluster  # noqa: F401
 from tilelang.tileop.base import GemmWarpPolicy  # noqa: F401
 from .gemm_op import (  # noqa: F401
     gemm,
diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py
index 7e9cdddb08..0a4d7a2b0d 100644
--- a/tilelang/language/copy_op.py
+++ b/tilelang/language/copy_op.py
@@ -120,6 +120,59 @@ def copy(
     return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, annotations=ann if ann else None)
+
+
+def copy_cluster(
+    src: BufferLikeType,
+    dst: BufferLikeType,
+    *,
+    dst_block: int | tir.PrimExpr | None = None,
+    cluster_mask: int | None = None,
+    remote_barrier: tir.BufferLoad | None = None,
+    eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None,
+    coalesced_width: int | None = None,
+    loop_layout: Any | None = None,
+) -> tir.PrimExpr | tir.Stmt:
+    """Cluster-aware copy for TMA multicast or SM-to-SM shared-memory copy.
+
+    Args:
+        src: Source memory region.
+        dst: Destination memory region.
+        dst_block: Destination CTA rank in the cluster for SM-to-SM copy.
+        cluster_mask: Bitmask of CTAs that participate in TMA multicast.
+        remote_barrier: Shared-memory mbarrier for asynchronous SM-to-SM copy
+            completion signalling. The destination CTA should wait on its
+            local copy of this barrier.
+        eviction_policy: Cache eviction hint passed to the TMA instruction.
+            Only relevant for the TMA multicast path (``cluster_mask`` set).
+        coalesced_width: Vectorization width (in elements) for the SIMT loop
+            used on the SM-to-SM fallback path (``dst_block`` set, no fast
+            bulk-async route available).
+        loop_layout: Parallel loop layout hint (Fragment) for the SIMT loop on
+            the SM-to-SM fallback path. Incompatible with the TMA multicast
+            path (``cluster_mask`` set).
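+
+    Example:
+        Two illustrative call shapes, mirroring the cluster-TMA guide and the
+        regression tests (buffer names and shapes here are placeholders)::
+
+            # TMA multicast: ranks 0 and 1 of the cluster share one global tile.
+            T.copy_cluster(A[by * block_M, bx * block_N], A_shared,
+                           cluster_mask=0b0011)
+
+            # SM-to-SM copy: push s_src into block 1's s_dst and signal its
+            # mbarrier; block 1 waits with T.mbarrier_wait_parity(bar, 0).
+            T.copy_cluster(s_src, s_dst, dst_block=1,
+                           remote_barrier=s_barrier[0])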
+
+    Returns:
+        tir.Call: A handle to the copy operation.
+    """
+    src, dst = _normalize_copy_regions(src, dst)
+
+    ann: dict = {}
+    if dst_block is not None:
+        ann["dst_block"] = dst_block
+    if cluster_mask is not None:
+        ann["cluster_mask"] = cluster_mask
+    if remote_barrier is not None:
+        ann["barrier"] = remote_barrier
+    if eviction_policy is not None:
+        eviction_policy_map = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}
+        ann["eviction_policy"] = eviction_policy_map[eviction_policy]
+    if coalesced_width is not None:
+        ann["coalesced_width"] = coalesced_width
+    if loop_layout is not None:
+        ann["parallel_loop_layout"] = loop_layout
+
+    return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, annotations=ann if ann else None)
+
+
 def async_copy(
     src: BufferLikeType,
     dst: BufferLikeType,