Add TMA tile::gather4 / tile::scatter4 support #2129
ighoshsubho wants to merge 4 commits into tile-ai:main
Conversation
📝 Walkthrough
This PR adds comprehensive TMA gather4/scatter4 support across the TileLang stack, including intrinsic registrations, copy operator infrastructure, CUDA codegen pathways, the Python front-end API, and test coverage for SM100+ hardware.
Sequence Diagram

```mermaid
sequenceDiagram
    participant PyAPI as Python API<br/>(tilelang.language)
    participant CopyLow as Copy Operator<br/>(src/op/copy.cc)
    participant Builtin as TMA Builtin<br/>(src/op/builtin.cc)
    participant Codegen as CUDA Codegen<br/>(src/target/codegen_cuda.cc)
    participant Templates as CUDA Templates<br/>(copy_sm100.h)
    participant Hardware as SM100 Hardware
    PyAPI->>PyAPI: tma_gather4(src, dst, col, rows,<br/>barrier, eviction_policy)
    PyAPI->>CopyLow: Annotation: is_gather4=true,<br/>gather4_rows, gather4_col
    CopyLow->>CopyLow: GetCopyInst() → kBulkLoadGather4
    CopyLow->>CopyLow: LowerBulkCopyGather4() creates<br/>TMADesc, barrier, PTX call
    CopyLow->>Builtin: Emit call to tl.tma_load_gather4
    Builtin->>Codegen: Route to CUDA codegen
    Codegen->>Codegen: Extract eviction_policy,<br/>map to CacheHintSm90
    Codegen->>Templates: Generate template<br/>instantiation
    Templates->>Hardware: cp.async.bulk.tensor.2d.shared::cta<br/>.global.tile::gather4 with<br/>CUtensorMap descriptor
    Hardware-->>Templates: Load 4 rows into shared memory
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/tl_templates/cuda/copy_sm90.h`:
- Around line 417-439: The tma_load_gather4 function is using the wrong
namespace and lacks a CUDA 12.8+ guard: wrap the Blackwell-only gather4
implementation with the same version guard used by other TMA loads (i.e., only
compile when CUDA toolkit >= 12.8 / Blackwell) and change the inline-asm
namespace from "shared::cluster.global" to "shared::cta.global" in the asm
string; update the guard to surround the function or the asm block for
tma_load_gather4 so older toolkits skip this code and compilation errors are
prevented.
In `@testing/python/language/test_tilelang_language_tma_gather_scatter.py`:
- Around line 43-75: Gate the test to run only when CUDA is available and the
device compute capability is sm_90+ by adding a runtime skip or pytest marker:
import torch and pytest, then either decorate test_gather_scatter_basic with
`@pytest.mark.skipif`(not (torch.cuda.is_available() and
torch.cuda.get_device_properties(0).major >= 9), reason="requires CUDA sm_90+"),
or add an early check at the top of run_gather_scatter that calls pytest.skip
with the same condition; reference the test function test_gather_scatter_basic
and the helper run_gather_scatter so the compile/kernel invocation and CUDA
tensor allocations are only executed on sm_90+ devices.
In `@tilelang/language/copy_op.py`:
- Around line 297-359: The tma_gather4 path currently only checks rank and rows
but must validate dtype and layout invariants before calling
create_tma_descriptor/_tma_load_gather4: ensure src.dtype == dst.dtype, ensure
dst.shape[0] == 4 (leading dimension is 4 for the shared tile), and ensure the
innermost global stride is 1 (check the computed g_stride_rev[0] or
src.strides[1] when strides exist); if any check fails raise a ValueError with a
clear message. Add the same set of validations to the analogous gather4/scatter4
helper (the block around lines 389-449) referencing create_tma_descriptor,
_tma_load_gather4/_tma_store_scatter4 and the variables src, dst, src.strides,
dst.shape so the descriptor is only built for packed 4xK_box tiles of matching
element type and packed global layout.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cee9ba02-9641-473e-8859-3a49f37bee6b
📒 Files selected for processing (8)
src/op/builtin.cc, src/op/builtin.h, src/target/codegen_cuda.cc, src/tl_templates/cuda/copy_sm90.h, testing/python/language/test_tilelang_language_tma_gather_scatter.py, tilelang/language/__init__.py, tilelang/language/builtin.py, tilelang/language/copy_op.py
Actionable comments posted: 2
Inline comments:
In `@tilelang/language/copy_op.py`:
- Around line 256-265: The _swizzle_str_to_int helper currently accepts
"32B"/"64B"/"128B" but we must reject non-none swizzle modes until shared-layout
support exists; update _swizzle_str_to_int to only accept None/"none"/0 and
return _CU_TENSOR_MAP_SWIZZLE_NONE, and change the other branches to raise
ValueError for any other input (including "32B","64B","128B",32,64,128), and
apply the same restrictive behavior to the other similar helper(s) referenced in
the file (the equivalent functions around the later blocks noted at lines
~342-356 and ~439-453) so the public API cannot encode swizzled modes
prematurely.
- Around line 297-300: The current checks in tma_gather4 (and similarly in
tma_scatter4 at the other site) only validate that src/dst are tir.Buffer
instances but do not assert their storage scopes; update the checks to also
verify the buffer storage scope strings so tma_gather4 enforces
src.storage_scope == "global" and dst.storage_scope == "shared" (and
tma_scatter4 enforces the opposite: src == "shared", dst == "global"). Locate
the checks around the tma_gather4 and tma_scatter4 helpers (the existing
isinstance checks for src and dst) and add explicit scope checks against the
buffer's storage scope property (e.g., buffer.scope or buffer.storage_scope API
used in this codebase), raising TypeError with a clear message if the scope is
wrong.
📒 Files selected for processing (3)
src/tl_templates/cuda/copy_sm90.h, testing/python/language/test_tilelang_language_tma_gather_scatter.py, tilelang/language/copy_op.py
🚧 Files skipped from review as they are similar to previous changes (1)
- src/tl_templates/cuda/copy_sm90.h
Actionable comments posted: 1
Inline comments:
In `@tilelang/language/copy_op.py`:
- Around line 296-297: The docstring for tma_gather4 conflicts with the runtime:
_swizzle_str_to_int no longer accepts "32B"/"64B"/"128B", so update
tma_gather4's swizzle parameter docs to reflect the actual contract (accepts
None or integer values 0, 32, 64, 128) and remove the quoted string variants;
mention that int 0/32/64/128 are the valid inputs and that other formats will be
rejected by _swizzle_str_to_int.
📒 Files selected for processing (1)
tilelang/language/copy_op.py
Force-push: 7efc451 → 9aace3e
Awesome contribution! This is an absolutely useful feature for kernel development on Blackwell. Left some minor suggestions :)
better to move these templates to copy_sm100.h, as they are introduced on sm100
Good catch, yeah, that cp.async.bulk.tensor.*.tile::gather4 / scatter4 was only introduced in sm100; I'll move both out of copy_sm90.h and into copy_sm100.h in the next push.
I think there's no need to define these two functions in builtin.py? They're kind of duplicated with those in copy_op.py, and those in copy_op.py can simply be lowered to tir.call_intrin(xxx).
Agreed, the tma_load_gather4 and tma_store_scatter4 wrappers are thin shims around tir.call_intrin. I'll drop them and inline the call_intrin directly in copy_op.py's tma_gather4 / tma_scatter4 so there's a single Python entry point per op.
Not sure, but is there any way the compiler can figure out the cuTensorMapDataType and swizzle mode for us, just like it previously did for the plain TMA copy?
Yes, that's the right direction. The reason this PR has explicit dtype/swizzle plumbing is to land a working baseline: tile::gather4 adds constraints that don't apply to plain TMA (rank-2 descriptor, leading box dim must be 4), so I went with manual encoding while sorting that out. Modular's Mojo wrapper exposes the same swizzle_mode knob, so it's not a hardware constraint.
The right follow-up, I guess, is to extend the same path the regular bulk copy already uses, CopyNode::LowerBulkCopy, which infers desc.data_type via to_CUtensorMapDataType(global_tensor->dtype) and picks desc.swizzle from the shared tile's SwizzleMode.
Happy to do it in this PR if you'd prefer it bundled, or as a follow-up; let me know. cc @Rachmanino
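The dtype half of that inference can be modeled as a simple lookup. A hedged sketch: the enum names follow CUDA's `CUtensorMapDataType`, but the mapping table and the function name `infer_desc_dtype` are illustrative, and the real `to_CUtensorMapDataType` lives in C++ and may cover more types.

```python
# Assumed mapping from TileLang dtype strings to CUtensorMapDataType enum names.
_DTYPE_TO_CUTENSORMAP = {
    "float16": "CU_TENSOR_MAP_DATA_TYPE_FLOAT16",
    "bfloat16": "CU_TENSOR_MAP_DATA_TYPE_BFLOAT16",
    "float32": "CU_TENSOR_MAP_DATA_TYPE_FLOAT32",
    "float64": "CU_TENSOR_MAP_DATA_TYPE_FLOAT64",
    "int8": "CU_TENSOR_MAP_DATA_TYPE_UINT8",  # 8-bit payloads move as raw bytes
}

def infer_desc_dtype(buffer_dtype: str) -> str:
    """Pick the descriptor data type from the global buffer's element type."""
    try:
        return _DTYPE_TO_CUTENSORMAP[buffer_dtype]
    except KeyError:
        raise ValueError(
            f"no CUtensorMap data type for dtype {buffer_dtype!r}") from None
```

With this in the lowering pass, the user-facing tma_gather4 / tma_scatter4 calls would no longer need an explicit dtype argument.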
Also feel free to quote me whenever you wanna run the CI routine.
Actionable comments posted: 2
Inline comments:
In `@src/layout/gemm_layouts.cc`:
- Around line 434-437: DetectSwizzleMode currently enforces info.stride % 8 == 0
before returning quarter/half/full swizzle modes, so newly-allowed stride==4
layouts (used by tile::gather4/scatter4) get classified as
CU_TENSOR_MAP_SWIZZLE_NONE; update DetectSwizzleMode to treat info.stride==4 as
a valid candidate for the 4-row swizzle pattern (and similarly handle any checks
that previously required stride%8==0), by adding a branch that accepts stride==4
and applies the same swizzle-detection logic used for 8-multiple strides
(returning quarter/half/full as appropriate based on the computed pattern),
ensuring references to stride/info.stride and the CU_TENSOR_MAP_SWIZZLE_NONE
return are adjusted accordingly.
In `@src/op/copy.cc`:
- Around line 1912-2012: LowerBulkCopyGather4 currently skips the generic TMA
legality checks and can build a TMADesc with misaligned box/strides; before
calling create_tma_descriptor() in LowerBulkCopyGather4, reapply the same
validations used by the regular TMA lowering: verify the shared box byte-size is
128-bit aligned (i.e., K_box * dtype.bytes() % 16 == 0), ensure each global
stride (after TMABytesFromElements conversion) is 16-byte aligned and within the
allowed range (< 2^40 bytes), and fail fast (ICHECK or return an error) if any
check fails; place these checks right after computing
desc.global_stride/desc.smem_box and before Call(DataType::Handle(),
create_tma_descriptor(), ...), referencing TMADesc, desc.global_stride,
desc.smem_box, create_tma_descriptor(), and LowerBulkCopyGather4 so the
specialized gather4/scatter4 path enforces the same constraints as the generic
TMA lowering.
📒 Files selected for processing (6)
src/layout/gemm_layouts.cc, src/op/copy.cc, src/op/copy.h, src/tl_templates/cuda/copy_sm100.h, testing/python/language/test_tilelang_language_tma_gather_scatter.py, tilelang/language/copy_op.py
Force-push: d8992e5 → c0ef91a
@Rachmanino re-pushed: descriptor build moved into LowerBulkGather4 (CUDA backend); dtype + swizzle are now compiler-inferred (DetectSwizzleMode extended for stride==4, plus the generic TMA legality checks per CodeRabbit). Templates moved to copy_sm100.h, builtin.py wrappers dropped. 4 gather4 tests + 21 gemm regressions green on B200; feel free to kick CI.
Force-push: c0ef91a → d4842b6
All merge conflicts resolved.
Some ROCm tests seem to be failing. I don't think it's related to this PR, but @Rachmanino let me know if it's something from my side and I'll fix it.
Yep, these ROCm failures are irrelevant. I have no further concern, cc @LeiWang1999
Adds first-class support for the Hopper+ TMA tile::gather4 (load) and tile::scatter4 (store) PTX modes. Previously TileLang's TMA path only covered rectangular tiled and im2col descriptors; sparse-row access patterns (MoE token gather, KV-cache compaction, embedding tile lookup, sparse attention) had to fall back to per-thread cp.async. This PR exposes the gather/scatter modes end-to-end: device PTX wrappers, TIR builtins, CUDA codegen, and the Python user API.
Major changes
- tile::gather4 (load) and tile::scatter4 (store) PTX modes for Hopper+ (verified on B200 / sm_100).
- tl.tma_load_gather4 / tl.tma_store_scatter4 builtins, CUDA codegen, and device-side templates in copy_sm90.h.
- Python API: T.tma_gather4 / T.tma_scatter4 (and T.tma_gather4_bytes for expect_tx).
- Reuses the create_tma_descriptor + LowerHopperIntrin hoisting pipeline; the descriptor is a 2D tiled CUtensorMap with the box dim along the gathered axis = 1.
Usage
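The access pattern these modes implement can be modeled in plain Python. This is an illustrative sketch of the semantics (gather 4 arbitrary rows into a 4 x K_box tile, then scatter them back), not the T.tma_gather4 / T.tma_scatter4 API itself:

```python
def gather4(global_buf, rows, col, k_box):
    """Load 4 independent rows (each k_box elements wide, starting at col)
    into a 4 x k_box tile, mirroring tile::gather4 semantics."""
    assert len(rows) == 4, "tile::gather4 always moves exactly 4 rows"
    return [list(global_buf[r][col:col + k_box]) for r in rows]

def scatter4(tile, global_buf, rows, col):
    """Store the 4 tile rows back to arbitrary row indices (tile::scatter4)."""
    assert len(tile) == 4 and len(rows) == 4
    for tile_row, r in zip(tile, rows):
        global_buf[r][col:col + len(tile_row)] = tile_row
```

The hardware version does the same movement in one descriptor-driven bulk copy instead of four strided loads, which is what makes it attractive for MoE token gather and KV-cache compaction.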
Test plan
- testing/python/language/test_tilelang_language_tma_gather_scatter.py round-trips 4 rows of a (64, 64) fp16 tensor through gather4 → smem → scatter4 and asserts equality.
- Inspected the generated CUDA for tl::tma_load_gather4 / tl::tma_store_scatter4.
Known follow-ups (not in this PR)
- The kernel currently needs manual T.reads/T.writes annotations because the descriptor build references raw buffer.data at parse time. Cleaner fix: a LowerBulkGather4 lowering pass (parallel to LowerBulkCopy) that builds the descriptor at lowering time. Happy to follow up.
Release Notes
New Features
Tests