[Feature] Add T.copy_cluster to support TMA multicast and SM-to-SM cluster copy #1908
Status: Open. He-Jingkai wants to merge 82 commits into `tile-ai:main` from `He-Jingkai:t_copy_extend`.
Commits (82):

- `80e409b` [Feature] Extend t.copy to support TMA multicast and SM-to-SM cluster… (He-Jingkai)
- `01c73b3` [Docs] Add programming guide for Cluster TMA features (He-Jingkai)
- `9ab0119` Merge remote-tracking branch 'upstream/main' into cluster-launch (He-Jingkai)
- `e8b3bd0` change cluster_size -> cluster_dims (He-Jingkai)
- `deab402` fix merge conflict (He-Jingkai)
- `2ef8eba` fix TILELANG CHECK bug (He-Jingkai)
- `29a7971` unify cluster function (He-Jingkai)
- `5d57c11` fix pre-commit errors (He-Jingkai)
- `c2efb88` docs(cluster_tma): clarify multicast non-issuer behavior in lowering … (sigmoidsee)
- `258c948` fix(copy): gate dst_block cluster lowering on target capability (sigmoidsee)
- `5a61902` Merge branch 'tile-ai:main' into t_copy_extend (He-Jingkai)
- `d2d7f1e` fix(cuda/cluster): hard-trap when cluster intrinsics are unavailable … (sigmoidsee)
- `2c67281` fix inject_tma & lower_hopper & lower_tile bugs (sigmoidsee)
- `cbc15d7` Merge remote-tracking branch 'pr/t_copy_extend' into t_copy_extend (sigmoidsee)
- `db56e92` [warp specialize]Fix mbarrier retarget for tma_load_multicast in warp… (sigmoidsee)
- `8af6bd0` fix(tma-barrier): correct arrive thread count under equality-guarded … (sigmoidsee)
- `f2cb665` fix: overlapped TMA barrier IDs (He-Jingkai)
- `febfd78` fix: multicast checks (He-Jingkai)
- `bf38cfa` minor fix (He-Jingkai)
- `4008456` minor fix (He-Jingkai)
- `cb5305f` fix: tma_load_multicast() in the stage-local mbarrier rewrite (He-Jingkai)
- `310ccbb` minor fix (He-Jingkai)
- `c50f426` format (He-Jingkai)
- `add3de9` minor fix (He-Jingkai)
- `531c51f` fix: Track allocation from every hoisted barrier init. (He-Jingkai)
- `1086f24` fix: Scope the 128-byte alignment to kernels that actually use TMA. (He-Jingkai)
- `1996788` fix: Assert the tl::tma_store_cluster lowering, not just the output. (He-Jingkai)
- `57fd797` format fix (He-Jingkai)
- `a39604d` fix: Don't assume both if arms transfer the same number of bytes. (He-Jingkai)
- `d1c7bb8` fix: Don't wrap the kept pipelined loop in a SeqStmt before the For p… (He-Jingkai)
- `4244e07` rm mbarrier_init and use alloc_cluster_barrier (He-Jingkai)
- `e5b00c5` fix: Don't make cluster copy a global exemption for unrelated TMA loads. (He-Jingkai)
- `25f0aa1` fix: Don't drop dependency modeling for handle-based mbarrier_wait_pa… (He-Jingkai)
- `f0d4169` fix: Don't hoist extracted barrier-init if statements to the block root. (He-Jingkai)
- `ccd9c1f` Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t… (sigmoidsee)
- `8a947bb` fix(copy): honor remapped dst layout in cluster-copy slow path (sigmoidsee)
- `6ca3e73` fix(copy): gate cluster TMA fast path on provable contiguity (sigmoidsee)
- `989044e` fix(copy): add barrier completion for SIMT cluster-copy fallback (sigmoidsee)
- `2a26453` fix: remove dup cluster_sync (He-Jingkai)
- `4a982ab` Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t… (He-Jingkai)
- `c07a79c` fix: remove not used _get_mbarrier (He-Jingkai)
- `a2a1fc6` fix: remove mbarrier related code in codegen_cuda.cc and barrier.h (He-Jingkai)
- `85cebfb` fix: codegen_cuda.cc: reuse functions in cluster.h and remove depende… (He-Jingkai)
- `a5da2a3` fix testting (He-Jingkai)
- `99c800c` fix(transform): prevent ragged_prefix free var from breaking MakePack… (sigmoidsee)
- `831beb4` minor fix (He-Jingkai)
- `b30e926` format (He-Jingkai)
- `a8753c8` fix inject tma barrier: fix crash and correct mbarrier thread-count i… (sigmoidsee)
- `8a3f99d` Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t… (sigmoidsee)
- `a312dd1` Merge remote-tracking branch 'upstream/main' into t_copy_extend (He-Jingkai)
- `9abd5ec` revert inject tma (sigmoidsee)
- `b9ab3f3` Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t… (sigmoidsee)
- `0ec4374` fix(inject_tma_barrier): avoid SIGSEGV and correctly infer barrier ar… (sigmoidsee)
- `79cfee7` fix: add tma_load_multicast same as tma_load (sigmoidsee)
- `00ea64b` T.copy_cluster (He-Jingkai)
- `c54bcec` fix: renaming cluster mbarrier_arrive to ptx_arrive_cluster_barerier (sigmoidsee)
- `c6f1f6d` fix: remove mbarrier_arrive in builtin.cc (sigmoidsee)
- `b49462f` test_tma_dsmem: fallback test (He-Jingkai)
- `aeffe21` Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t… (He-Jingkai)
- `0828868` Merge branch 'main' of github.com:tile-ai/tilelang into t_copy_extend (He-Jingkai)
- `5d045aa` [Feature] Multi-TMA fallback for non-contiguous T.copy_cluster regions (He-Jingkai)
- `317e139` Merge remote-tracking branch 'upstream/main' into t_copy_extend (sigmoidsee)
- `3feb182` Merge branch 'main' of github.com:tile-ai/tilelang into t_copy_extend (sigmoidsee)
- `f760b53` Merge branch 'main' of github.com:tile-ai/tilelang into t_copy_extend (sigmoidsee)
- `c1cfa9a` Merge remote-tracking branch 'upstream/main' into t_copy_extend (He-Jingkai)
- `6f7344c` Merge branch 'main' of github.com:tile-ai/tilelang into t_copy_extend (He-Jingkai)
- `1a747cf` Merge branch 'main' of github.com:tile-ai/tilelang into t_copy_extend (He-Jingkai)
- `5ec44e6` bug fix (He-Jingkai)
- `537e330` fix: unified continuous check (He-Jingkai)
- `94a4ad0` fix doc (He-Jingkai)
- `7679d51` fix: remote redundant annotation converting (He-Jingkai)
- `3263423` fix: remove unused code (He-Jingkai)
- `813d830` fix: remove unused code (He-Jingkai)
- `91b318d` rm unused comments (He-Jingkai)
- `7288889` minor fix (He-Jingkai)
- `07130ea` fix: rm unused helper func (He-Jingkai)
- `0a0e436` minor (He-Jingkai)
- `465d96b` minor (He-Jingkai)
- `53299af` Merge branch 'main' of github.com:tile-ai/tilelang into t_copy_extend (He-Jingkai)
- `e1f6ec1` Merge remote-tracking branch 'upstream/main' into t_copy_extend (He-Jingkai)
- `b176d92` Clean up copy cluster merge formatting (He-Jingkai)
- `a6959e3` Reduce copy cluster PR churn (He-Jingkai)
# Cluster TMA: Multicast and SM-to-SM Copy

This page describes two advanced data-movement features that are available on NVIDIA Hopper (SM90) and later: **TMA multicast** and **SM-to-SM cluster copy**. Both features are exposed through extensions to the existing `T.copy` operator and require a kernel launched with a thread block cluster, i.e., with `cluster_dims != (1, 1, 1)`.

Requirements:

- CUDA Compute Capability ≥ 9.0 (Hopper / Blackwell / RTX 5090)

---
## Background: Thread Block Clusters

A *thread block cluster* is a group of CTAs that share a common virtual address space for their shared-memory regions and can communicate without going through global memory. Within a cluster, each CTA has a *block rank* (its 0-indexed position inside the cluster), and all CTAs can observe each other's shared memory via the `shared::cluster` address space.

```python
with T.Kernel(grid_x, grid_y, threads=128, cluster_dims=(4, 1, 1)) as (bx, by):
    rank = T.block_rank_in_cluster()    # 0..3 within this cluster
    cid = T.get_cluster_id()            # which cluster this CTA belongs to
    nctas = T.get_cluster_block_nums()  # total CTAs per cluster (4 here)
    T.cluster_sync()                    # barrier across all CTAs in the cluster
```

---
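As a mental model, the rank and cluster id above can be derived from the linear block index. The sketch below is plain Python, not tilelang API; it assumes `cluster_dims = (4, 1, 1)` laid out along the x grid dimension:

```python
# Illustrative model (not tilelang API): how block rank and cluster id
# relate to the linear block index when cluster_dims = (4, 1, 1).
def rank_and_cluster(bx: int, cluster_x: int = 4):
    """Return (block_rank_in_cluster, cluster_id) for block index bx."""
    return bx % cluster_x, bx // cluster_x

# Blocks 0..3 form cluster 0, blocks 4..7 form cluster 1, and so on.
```

Under this model, `T.block_rank_in_cluster()` for block 5 would be 1, inside cluster 1.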
## Feature 1 — TMA Multicast (`cluster_mask`)

### What it does

Normally each CTA issues its own TMA load, fetching a tile from global memory into its private shared memory. With multicast, **a single TMA transaction broadcasts one global tile to every participating CTA simultaneously**, saving repeated DRAM traffic when multiple CTAs in a cluster need the same data (e.g., the same K-panel in a split-K GEMM).

```text
Global memory ──TMA multicast──▶ shared memory (rank 0)
                             └─▶ shared memory (rank 1)  (same tile, no extra DRAM read)
              ────TMA load─────▶ shared memory (rank 2)  (independent tile)
              ────TMA load─────▶ shared memory (rank 3)  (independent tile)
```

### API

```python
T.copy(src_global, dst_shared, cluster_mask=<int>)
```

`cluster_mask` is a bitmask in which each set bit identifies a CTA rank that participates in the multicast. The CTA whose rank equals the lowest set bit in the mask issues `cp.async.bulk.tensor … multicast::cluster`; every other CTA in the mask receives the data passively (no instruction issued). CTAs outside the mask perform a regular TMA load for their own tile.
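The election rule above can be sketched in plain Python (an illustrative model of the lowering, not part of the tilelang API):

```python
# Illustrative model of cluster_mask semantics: the lowest set bit elects
# the issuing CTA; other ranks in the mask receive passively; ranks outside
# the mask fall back to a regular TMA load.
def classify_rank(rank: int, cluster_mask: int) -> str:
    if cluster_mask & (1 << rank):
        # Isolate the lowest set bit to find the elected issuer rank.
        issuer = (cluster_mask & -cluster_mask).bit_length() - 1
        return "issues multicast" if rank == issuer else "receives passively"
    return "regular tma_load"
```

For `cluster_mask = 0b0011`, rank 0 issues the multicast, rank 1 receives passively, and ranks 2 and 3 each perform a regular load.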
### Example

```python
import tilelang
import tilelang.language as T


def make_tma_multicast_kernel(M, N, block_M, block_N, cluster_mask):
    @T.prim_func
    def kernel(
        A: T.Tensor((M, N), "float16"),
        B: T.Tensor((M, N), "float16"),
    ):
        # 4 CTAs per cluster; ranks 0 and 1 share the same tile via multicast.
        with T.Kernel(
            T.ceildiv(N, block_N),
            T.ceildiv(M, block_M),
            threads=128,
            cluster_dims=(4, 1, 1),
        ) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), "float16")

            # cluster_mask=0b0011: ranks 0 and 1 participate.
            # Rank 0 issues tma_load_multicast; rank 1 receives passively.
            # Ranks 2 and 3 each issue a regular tma_load.
            T.copy(A[by * block_M, bx * block_N], A_shared,
                   cluster_mask=cluster_mask)

            T.copy(A_shared, B[by * block_M, bx * block_N])

    return kernel
```

Running the kernel above with `cluster_mask = 0b0011`:

| Rank | Action | `B` slice receives |
|------|--------|--------------------|
| 0 | issues multicast load | A tile at rank-0 address |
| 1 | passively receives | **same** A tile as rank 0 |
| 2 | regular TMA load | A tile at rank-2 address |
| 3 | regular TMA load | A tile at rank-3 address |
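The bandwidth benefit can be estimated with back-of-the-envelope arithmetic. The helper below is hypothetical (plain Python, not a tilelang function); it counts DRAM tile fetches per cluster iteration: one multicast fetch serves all masked CTAs, and each unmasked CTA fetches its own tile:

```python
# Illustrative DRAM-traffic model for one cluster of n_ctas CTAs:
# the CTAs in `mask` share a single multicast fetch; the rest each
# issue their own TMA load.
def dram_bytes(block_m, block_n, elem_bytes, n_ctas, mask):
    shared = bin(mask).count("1")               # CTAs served by multicast
    fetches = (1 if shared else 0) + (n_ctas - shared)
    return fetches * block_m * block_n * elem_bytes
```

With 128x64 fp16 tiles and `mask = 0b0011`, the cluster fetches 3 tiles from DRAM instead of 4, saving 25% of the A-tile traffic in this example.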
### Notes

- The compiler lowers `cluster_mask != 0` to `cp.async.bulk.tensor.Nd.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster` for the issuing CTA; CTAs in the mask but not elected as issuer receive passively, and only CTAs outside the mask issue a standard `cp.async.bulk.tensor`.
- Software pipelining (`T.Pipelined`) is fully supported; the warp-specialized rewriter recognises `tma_load_multicast` as a producer operation.
- `cluster_mask` must be a compile-time constant; dynamic masks are not supported.

---
## Feature 2 — SM-to-SM Cluster Copy (`dst_block`)

### What it does

SM-to-SM copy lets one CTA **push data directly from its own shared memory into another CTA's shared memory** within the same cluster, without a round trip through global memory. This is useful for patterns such as:

- Partial-result exchange (e.g., split-K partial sums across SM boundaries)
- Producer–consumer pipelines where the producer fills a neighbor's buffer
- All-to-all collective communication within a cluster

Two sub-variants are provided, depending on whether an mbarrier is supplied:

| Variant | Parameters | Hardware instruction | Threads used |
|---------|------------|----------------------|--------------|
| **Fast path** | `dst_block` + `remote_barrier` | `cp.async.bulk.shared::cluster` | 1 (async DMA) |
| **Slow path** | `dst_block` only | `map_shared_rank` + scalar stores | all (SIMT loop) |
### Fast path — bulk async copy with mbarrier

```python
T.copy(src_shared, dst_shared, dst_block=<rank>, remote_barrier=<mbarrier>)
```

A single elected thread issues one `cp.async.bulk.shared::cluster` instruction. The hardware DMA engine transfers the entire tile asynchronously and signals the destination CTA's mbarrier on completion. The destination CTA waits with `T.mbarrier_wait_parity`.

Steps:

1. Both CTAs allocate the **same** shared-memory layout so their mbarriers live at the same offset.
2. Every CTA initialises its own barrier for 1 arrival.
3. The source CTA (`pid == 0` below) calls `T.copy(..., dst_block=1, remote_barrier=...)`.
4. The destination CTA (`pid == 1`) waits on its local copy of the barrier.
```python
import tilelang
import tilelang.language as T


@tilelang.jit(verbose=True, execution_backend="cython")
def make_cluster_copy_kernel(N: int):
    @T.prim_func
    def kernel(
        A: T.Tensor((N,), "float32"),
        B: T.Tensor((N,), "float32"),
    ):
        with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid:
            s_src = T.alloc_shared((N,), "float32")
            s_dst = T.alloc_shared((N,), "float32")
            s_barrier = T.alloc_shared((1,), "uint64")

            T.fill(s_src, 0.0)
            T.fill(s_dst, 0.0)

            # Each CTA initialises its own barrier: 1 expected arrival.
            if T.get_thread_binding() == 0:
                T.mbarrier_init(s_barrier[0], 1)

            T.cluster_sync()

            if pid == 0:
                # Load A into local shared memory.
                for i in T.Parallel(N):
                    s_src[i] = A[i]

                # Async-push s_src → s_dst in CTA 1, signalling CTA 1's barrier.
                T.copy(s_src, s_dst, dst_block=1,
                       remote_barrier=s_barrier[0])

            if pid == 1:
                # Wait until CTA 0 finishes writing.
                T.mbarrier_wait_parity(s_barrier[0], 0)

                for i in T.Parallel(N):
                    B[i] = s_dst[i]

    return kernel
```

Generated producer code (single-thread guard, one PTX instruction):

```cuda
if (((int)threadIdx.x) == 0) {
  tl::tma_store_cluster(&s_dst[0], &s_src[0], 1,
                        (uint32_t)(N * 4), s_barrier[0]);
}
```
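The parity protocol that `T.mbarrier_wait_parity` relies on can be modeled in a few lines of plain Python (an illustrative sketch of mbarrier phase semantics, not the hardware or tilelang API):

```python
# Minimal model of the mbarrier parity protocol: a barrier tracks an
# expected-arrival count and a phase bit; once all expected arrivals land,
# the phase flips and the count re-arms for the next phase.
class MBarrier:
    def __init__(self, expected: int):
        self.expected = expected
        self.pending = expected
        self.phase = 0                  # current phase bit

    def arrive(self, count: int = 1):
        self.pending -= count
        if self.pending == 0:           # all arrivals in: flip and re-arm
            self.phase ^= 1
            self.pending = self.expected

    def try_wait_parity(self, parity: int) -> bool:
        # wait_parity(bar, p) unblocks once the barrier leaves phase p.
        return self.phase != parity
```

In the example above, the DMA completion is the single expected arrival: the consumer's `T.mbarrier_wait_parity(s_barrier[0], 0)` spins while the barrier is still in phase 0 and unblocks when the transfer's arrival flips it.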
### Slow path — element-by-element SIMT fallback

Omit `remote_barrier` to use the slow path:

```python
T.copy(s_src, s_dst, dst_block=1)
```

This lowers to a SIMT parallel loop in which every thread writes one (or a few) elements into the remote CTA's shared memory via `cooperative_groups::map_shared_rank`. Because `map_shared_rank` returns a scalar pointer, vectorised writes are not possible. Use this path only when an mbarrier is unavailable or when the tile is too small to justify the barrier overhead.
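Functionally, the slow path amounts to resolving the destination CTA's copy of a shared buffer and storing into it one scalar at a time. A plain-Python model (illustrative only; `cluster_smem` and `simt_cluster_copy` are hypothetical names, not CUDA or tilelang API):

```python
# Illustrative model: each CTA in a cluster owns one shared buffer; the
# slow path resolves the destination CTA's buffer (as map_shared_rank
# does for a pointer) and copies element by element with scalar stores.
def simt_cluster_copy(cluster_smem, src_rank, dst_rank, n):
    dst = cluster_smem[dst_rank]    # ~ map_shared_rank(smem, dst_rank)
    src = cluster_smem[src_rank]
    for i in range(n):              # one element per "thread", no vectorisation
        dst[i] = src[i]

smem = [[1, 2, 3, 4], [0, 0, 0, 0]]   # two CTAs' shared buffers
simt_cluster_copy(smem, src_rank=0, dst_rank=1, n=4)
```

After the call, CTA 1's buffer holds CTA 0's data; in the real kernel the destination still needs a `T.cluster_sync()` (or equivalent) before it can safely read.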
### Synchronisation contract

| | Fast path | Slow path |
|-|-----------|-----------|
| Source CTA | no wait needed; the copy is async | effectively synchronous after the loop |
| Destination CTA | `T.mbarrier_wait_parity(barrier, parity)` | external `T.cluster_sync()` or equivalent |
### Notes

- Both paths require `src` and `dst` to be in `shared` or `shared.dyn` scope.
- The mbarrier must be allocated with `T.alloc_shared((count,), "uint64")` and initialised with `T.mbarrier_init` before use.
- A `T.cluster_sync()` after barrier initialisation but before the copy is required, to ensure every CTA has initialised its barrier before any data is pushed.
- `dst_block` may be a compile-time integer or a runtime `tir.PrimExpr`.

---
## Cluster Helper Builtins

| Builtin | Return | Description |
|---------|--------|-------------|
| `T.get_cluster_id()` | `int32` | Index of this cluster in the grid |
| `T.block_rank_in_cluster()` | `int32` | Block rank (0-indexed) within the cluster |
| `T.get_cluster_block_nums()` | `int32` | Total number of CTAs in the cluster |
| `T.cluster_sync()` | — | Barrier synchronisation across all CTAs in the cluster |
| `T.mbarrier_init(bar, count)` | — | Initialise an mbarrier for `count` arrivals |
| `T.mbarrier_arrive(bar)` | — | Signal one arrival on an mbarrier |
| `T.mbarrier_wait_parity(bar, parity)` | — | Wait until `bar` flips past phase `parity` |

---
## Putting It Together: Split-K Sketch

A common pattern combines both features: multicast the shared K-panel to all cluster CTAs (saving DRAM bandwidth), then reduce partial sums with SM-to-SM copies (saving global-memory round trips).

```python
@T.prim_func
def split_k_gemm(A, B, C):
    with T.Kernel(grid_x, grid_y, threads=256, cluster_dims=(4, 1, 1)) as (bx, by):
        rank = T.block_rank_in_cluster()
        A_s = T.alloc_shared((BM, BK), "float16")
        B_s = T.alloc_shared((BK, BN), "float16")
        C_f = T.alloc_fragment((BM, BN), "float32")
        C_s = T.alloc_shared((BM, BN), "float32")
        barrier = T.alloc_shared((1,), "uint64")
        T.clear(C_f)

        # Phase 1: each CTA loads its K-slice; A is multicast to ranks 0 and 1.
        for ko in T.Pipelined(T.ceildiv(K, BK * 4), num_stages=3):
            k_off = (rank + ko * 4) * BK
            T.copy(A[by * BM, k_off], A_s, cluster_mask=0b0011)
            T.copy(B[k_off, bx * BN], B_s)
            T.gemm(A_s, B_s, C_f)

        # Phase 2: push partial sums to rank 0 via SM-to-SM copy.
        T.copy(C_f, C_s)
        if T.get_thread_binding() == 0:
            # Rank 0 expects one arrival per pushing rank (ranks 1-3).
            T.mbarrier_init(barrier[0], 3)
        T.cluster_sync()

        if rank != 0:
            # Schematic: a real kernel needs a separate staging buffer per
            # remote rank so the three pushes do not overwrite each other.
            T.copy(C_s, C_s, dst_block=0, remote_barrier=barrier[0])
        if rank == 0:
            T.mbarrier_wait_parity(barrier[0], 0)
            # accumulate and store ...
            T.copy(C_s, C[by * BM, bx * BN])
```
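What phase 2 computes can be checked with a tiny plain-Python model (illustrative, not tilelang): rank 0 accumulates the partial C tiles produced over each rank's K-slice, so the result equals a single-CTA GEMM over the full K range.

```python
# Illustrative model of the phase-2 reduction: each rank holds a partial
# C tile over its K-slice; rank 0 sums the partials elementwise.
def split_k_reduce(partials):
    total = [0.0] * len(partials[0])
    for tile in partials:               # rank 0 receives each remote C_s
        for i, v in enumerate(tile):
            total[i] += v
    return total

# Four ranks, each contributing a partial sum for the same 2-element tile.
partials = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
assert split_k_reduce(partials) == [16.0, 20.0]
```

This is exactly the elided "accumulate and store" step in the sketch; the SM-to-SM copies merely move each rank's `C_s` to rank 0 without touching global memory.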
## See Also

- `testing/python/cuda/test_tma_multicast_demo.py` — multicast validation
- `testing/python/cuda/test_tma_dsmem.py` — SM-to-SM copy validation
- Programming Guides → Instructions — complete `T.copy` parameter reference
- Programming Guides → Control Flow — `T.Pipelined` and warp-specialized pipelines