[Refactor] Move backend-specific GEMM implementations and transforms into backend directories#2165
Conversation
…into backend directories Restructure the codebase so that each backend (cpu, cuda, rocm) owns its GEMM implementations, sparse GEMM implementations, and transform passes under a consistent op/ subdirectory layout. The shared GEMM registry and base classes remain in tileop/ as the platform-agnostic dispatch layer. - Move gemm/gemm_sp registries from backend/ into tileop/ as registry.py - Move CUDA GEMM impls (mma, mma_sm70, wgmma, tcgen05) into backend/cuda/op/gemm/ - Move CUDA sparse GEMM impl into backend/cuda/op/gemm_sp/ - Move CPU GEMM impl (scalar) into backend/cpu/op/gemm/ - Move ROCm GEMM impls (mfma, wmma) into backend/rocm/op/gemm/ - Move CUDA-specific transform passes from src/transform/ into src/backend/cuda/transform/ - Move CUDA runtime sources from src/runtime/ into src/backend/cuda/ - Remove dead backend-importing wrappers from transform/__init__.py - Update phase.py and tests to import CUDA transforms from their canonical location Each backend now has a symmetric op/ directory structure. Adding a new backend no longer requires modifying shared transform or tileop modules.
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis PR consolidates CUDA and ROCm intrinsics into backend-specific namespaces, adds CUDA TMA descriptor and L2 cache FFI functions, migrates transform passes to the CUDA backend package, reorganizes GEMM registry from backend to tileop, and updates 60+ files to route imports through backend-specific modules. ChangesCUDA TMA and L2 Cache FFI Functions
CUDA Transform Passes Backend Migration
Backend Package Hierarchy and GEMM Registry Reorganization
Intrinsics and Layout Backend Consolidation
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
|
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
tilelang/backend/rocm/op/gemm/gemm_mfma.py (1)
220-233:⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
self.k_packused inconsistently in theis_gemm_rrbranch — should use the emitter-derived localk_pack.All other branches (SS, SR, RS at lines 136, 172, 201) use
k_pack(the local captured frommfma_emitter.k_pack). Theis_gemm_rrbranch at line 227 usesself.k_packinstead. IfMatrixCoreIntrinEmitternormalises or clamps the k_pack value, the assertion on line 116 will pass using the emitter's value while the loop bound will use the rawself.k_pack, producing a wrong iteration count.🐛 Proposed fix
- for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + for ki in T.serial(0, (block_K // (micro_size_k * k_pack))):🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/backend/rocm/op/gemm/gemm_mfma.py` around lines 220 - 233, The loop in _gemm_rsr incorrectly uses self.k_pack instead of the emitter-normalized local k_pack; change the code to capture mfma_emitter.k_pack (same local name used in other branches) and use that local k_pack in the loop bound inside _gemm_rsr (replace self.k_pack with k_pack) so iteration count matches the emitter's normalized value.tilelang/backend/rocm/op/gemm/gemm_wmma.py (1)
143-152:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winSame
self.k_pack/k_packinconsistency as ingemm_mfma.py.All preceding branches use the emitter-derived local
k_pack; only theis_gemm_rrbranch reaches forself.k_pack. IfWMMAIntrinEmitteradjusts the value, the loop count can diverge from the assertion on line 97.🐛 Proposed fix
- for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): + for ki in T.serial(0, (block_K // (micro_size_k * k_pack))):🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/backend/rocm/op/gemm/gemm_wmma.py` around lines 143 - 152, The is_gemm_rr branch incorrectly uses self.k_pack instead of the emitter-derived local k_pack, causing a potential mismatch with wmma_emitter's configuration; update the loop bound in the _gemm_rrr prim_func to use the same local k_pack variable used elsewhere (the one produced/used by WMMAIntrinEmitter) so the iteration count (for ki over block_K // (micro_size_k * k_pack)) matches the assertion and wmma_emitter.wmma usage; locate this in the is_gemm_rr branch around the _gemm_rrr definition and replace self.k_pack with the local k_pack identifier.src/backend/cuda/runtime.cc (3)
875-876:⚠️ Potential issue | 🟠 Major | ⚡ Quick winAssign the requested window size, not the L2 limit value, to
accessPolicyWindow.num_bytes.Line 875 incorrectly assigns
l2_limit_bytestonum_bytes. According to the CUDA Driver API,accessPolicyWindow.num_bytesspecifies the policy window extent, which should come from thenum_bytesparameter (args[1], line ~817), while L2 cache limits are a separate concern. Usingl2_limit_byteshere changes the window semantics and can unintentionally extend or constrain persistence coverage outside the requested range.Fix: Change line 875 to assign
num_bytesinstead ofl2_limit_bytes.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/backend/cuda/runtime.cc` around lines 875 - 876, Change the assignment so the accessPolicyWindow uses the requested window size rather than the L2 limit: replace the value assigned to stream_attribute.accessPolicyWindow.num_bytes (currently using l2_limit_bytes) with the num_bytes parameter used to specify the policy window extent; leave stream_attribute.accessPolicyWindow.hitRatio and any L2-related handling using l2_limit_bytes unchanged.
819-823:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winValidate
hit_ratiorange before callingcuStreamSetAttribute.Line 819–823 accepts arbitrary floating values; CUDA expects a valid hit ratio range. Add a bound check (typically
[0.0, 1.0]) to fail fast with a clear message instead of deferring to driver error paths.What valid value range does CUDA Driver API require for CUaccessPolicyWindow.hitRatio?🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/backend/cuda/runtime.cc` around lines 819 - 823, The code currently assigns args[2] into hit_ratio without validation; validate that hit_ratio (the variable used to populate CUaccessPolicyWindow.hitRatio before calling cuStreamSetAttribute) is within [0.0f, 1.0f] and fail fast with a clear error (throw or return error/log) if out of range; specifically check the parsed value from args[2].cast<double>() (or args when size>=3) and either clamp to [0.0f,1.0f] or prefer returning an error with a message like "hit_ratio must be in [0.0,1.0]" before proceeding to cuStreamSetAttribute. Ensure the validation happens right after setting hit_ratio and references the hit_ratio variable and CUaccessPolicyWindow.hitRatio usage so the caller sees a clear failure instead of a driver error.
337-340:⚠️ Potential issue | 🟠 Major | ⚡ Quick winValidate bounds before narrowing 64-bit arguments to 32-bit descriptor fields.
At lines 337, 340, 592, and 600–601, values are cast from 64-bit and directly assigned to
cuuint32_tfields without pre-checks. Oversized inputs silently wrap before validation, producing incorrect descriptors. For example, a 64-bit value of 2³² + 256 wraps to 256 after assignment, then passes the validation check for<= 256.Validate bounds using the 64-bit temporary before narrowing:
- Line 337:
boxDim[i]- Line 340:
elementStrides[i]- Line 592:
elementStrides[i]- Lines 600–601:
smem_box_pixelandsmem_box_channelExample fix
- T.boxDim[i] = args[idx++].cast<cuuint64_t>(); + uint64_t box_dim = args[idx++].cast<cuuint64_t>(); + ICHECK_LE(box_dim, std::numeric_limits<cuuint32_t>::max()) + << "boxDim[" << i << "] overflows cuuint32_t"; + T.boxDim[i] = static_cast<cuuint32_t>(box_dim);🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/backend/cuda/runtime.cc` around lines 337 - 340, The code narrows 64-bit values to cuuint32_t fields without checking for overflow, causing wrapped values to pass later validation; fix by reading each 64-bit argument into a 64-bit temporary (e.g., cuuint64_t tmp = args[idx++].cast<cuuint64_t>()), validate tmp against the appropriate upper bound (at minimum numeric_limits<cuuint32_t>::max() and any domain-specific limits already used later), and only then assign static_cast<cuuint32_t>(tmp) to T.boxDim[i], T.elementStrides[i], and the smem fields (smem_box_pixel, smem_box_channel); apply the same pattern to the other occurrences noted (the two elementStrides loops and the smem assignments) so no narrowing happens before bounds validation.
🧹 Nitpick comments (1)
tilelang/engine/phase.py (1)
5-5: ⚖️ Poor tradeoffTop-level CUDA import in the backend-agnostic engine couples all targets to CUDA at import time.
phase.pyis the shared lowering pipeline for all backends (CPU, ROCm, CUDA). Importingcuda_transformunconditionally at module load means every target triggers the CUDA backend initialisation chain (backend/cuda/__init__.py→op→ GEMM registrations). While functionally tolerable today, it prevents clean CPU/ROCm-only builds and makes the dependency graph harder to reason about.Consider a lazy/deferred import:
♻️ Suggested refactor
-from tilelang.backend.cuda import transform as cuda_transform + +def _cuda_transform(): + from tilelang.backend.cuda import transform # lazy import + return transformThen use
_cuda_transform().LowerL2Persistent()(mod)etc., so the import only fires when a CUDA lowering step is actually needed.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/engine/phase.py` at line 5, phase.py currently imports cuda_transform at module import time, coupling all backends to CUDA; change this to a lazy import by adding a helper function (e.g., _cuda_transform) that imports tilelang.backend.cuda.transform inside the function and returns it, then update all call sites in this module that use cuda_transform (for example calls to LowerL2Persistent(), LowerXXX() etc.) to call _cuda_transform().LowerL2Persistent()(...) (or the corresponding method invocation) so the CUDA package is only imported when a CUDA-specific lowering is actually executed.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/backend/cuda/transform/__init__.py`:
- Around line 13-20: LowerL2Persistent and PersistThreadblock call into _ffi_api
unguarded and will raise AttributeError when the CUDA FFI symbols are missing;
mirror the pattern used by LowerHopperIntrin: check hasattr(_ffi_api,
"LowerL2Persistent") and hasattr(_ffi_api, "PersistThreadblock") and return the
FFI call only if present, otherwise return None (or an appropriate no-op
fallback) so callers in engine/phase.py don't crash when USE_CUDA/FFI symbols
are absent.
---
Outside diff comments:
In `@src/backend/cuda/runtime.cc`:
- Around line 875-876: Change the assignment so the accessPolicyWindow uses the
requested window size rather than the L2 limit: replace the value assigned to
stream_attribute.accessPolicyWindow.num_bytes (currently using l2_limit_bytes)
with the num_bytes parameter used to specify the policy window extent; leave
stream_attribute.accessPolicyWindow.hitRatio and any L2-related handling using
l2_limit_bytes unchanged.
- Around line 819-823: The code currently assigns args[2] into hit_ratio without
validation; validate that hit_ratio (the variable used to populate
CUaccessPolicyWindow.hitRatio before calling cuStreamSetAttribute) is within
[0.0f, 1.0f] and fail fast with a clear error (throw or return error/log) if out
of range; specifically check the parsed value from args[2].cast<double>() (or
args when size>=3) and either clamp to [0.0f,1.0f] or prefer returning an error
with a message like "hit_ratio must be in [0.0,1.0]" before proceeding to
cuStreamSetAttribute. Ensure the validation happens right after setting
hit_ratio and references the hit_ratio variable and
CUaccessPolicyWindow.hitRatio usage so the caller sees a clear failure instead
of a driver error.
- Around line 337-340: The code narrows 64-bit values to cuuint32_t fields
without checking for overflow, causing wrapped values to pass later validation;
fix by reading each 64-bit argument into a 64-bit temporary (e.g., cuuint64_t
tmp = args[idx++].cast<cuuint64_t>()), validate tmp against the appropriate
upper bound (at minimum numeric_limits<cuuint32_t>::max() and any
domain-specific limits already used later), and only then assign
static_cast<cuuint32_t>(tmp) to T.boxDim[i], T.elementStrides[i], and the smem
fields (smem_box_pixel, smem_box_channel); apply the same pattern to the other
occurrences noted (the two elementStrides loops and the smem assignments) so no
narrowing happens before bounds validation.
In `@tilelang/backend/rocm/op/gemm/gemm_mfma.py`:
- Around line 220-233: The loop in _gemm_rsr incorrectly uses self.k_pack
instead of the emitter-normalized local k_pack; change the code to capture
mfma_emitter.k_pack (same local name used in other branches) and use that local
k_pack in the loop bound inside _gemm_rsr (replace self.k_pack with k_pack) so
iteration count matches the emitter's normalized value.
In `@tilelang/backend/rocm/op/gemm/gemm_wmma.py`:
- Around line 143-152: The is_gemm_rr branch incorrectly uses self.k_pack
instead of the emitter-derived local k_pack, causing a potential mismatch with
wmma_emitter's configuration; update the loop bound in the _gemm_rrr prim_func
to use the same local k_pack variable used elsewhere (the one produced/used by
WMMAIntrinEmitter) so the iteration count (for ki over block_K // (micro_size_k
* k_pack)) matches the assertion and wmma_emitter.wmma usage; locate this in the
is_gemm_rr branch around the _gemm_rrr definition and replace self.k_pack with
the local k_pack identifier.
---
Nitpick comments:
In `@tilelang/engine/phase.py`:
- Line 5: phase.py currently imports cuda_transform at module import time,
coupling all backends to CUDA; change this to a lazy import by adding a helper
function (e.g., _cuda_transform) that imports tilelang.backend.cuda.transform
inside the function and returns it, then update all call sites in this module
that use cuda_transform (for example calls to LowerL2Persistent(), LowerXXX()
etc.) to call _cuda_transform().LowerL2Persistent()(...) (or the corresponding
method invocation) so the CUDA package is only imported when a CUDA-specific
lowering is actually executed.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c0ae3bfc-278a-4519-91db-312c0558cfa1
📒 Files selected for processing (33)
src/backend/cuda/CMakeLists.txtsrc/backend/cuda/runtime.ccsrc/backend/cuda/runtime.hsrc/backend/cuda/transform/lower_hopper_intrin.ccsrc/backend/cuda/transform/lower_l2_persistent_annotation.ccsrc/backend/cuda/transform/persist_threadblock.cctesting/python/transform/test_tilelang_transform_lower_hopper_intrin.pytilelang/backend/__init__.pytilelang/backend/cpu/__init__.pytilelang/backend/cpu/op/__init__.pytilelang/backend/cpu/op/gemm/__init__.pytilelang/backend/cpu/op/gemm/gemm_scalar.pytilelang/backend/cuda/__init__.pytilelang/backend/cuda/op/__init__.pytilelang/backend/cuda/op/gemm/__init__.pytilelang/backend/cuda/op/gemm/gemm_mma.pytilelang/backend/cuda/op/gemm/gemm_mma_sm70.pytilelang/backend/cuda/op/gemm/gemm_tcgen05.pytilelang/backend/cuda/op/gemm/gemm_wgmma.pytilelang/backend/cuda/op/gemm_sp/__init__.pytilelang/backend/cuda/op/gemm_sp/gemm_sp_mma.pytilelang/backend/cuda/transform/__init__.pytilelang/backend/rocm/__init__.pytilelang/backend/rocm/op/__init__.pytilelang/backend/rocm/op/gemm/__init__.pytilelang/backend/rocm/op/gemm/gemm_mfma.pytilelang/backend/rocm/op/gemm/gemm_wmma.pytilelang/engine/phase.pytilelang/tileop/gemm/__init__.pytilelang/tileop/gemm/registry.pytilelang/tileop/gemm_sp/__init__.pytilelang/tileop/gemm_sp/registry.pytilelang/transform/__init__.py
💤 Files with no reviewable changes (1)
- tilelang/transform/init.py
| def LowerL2Persistent(): | ||
| """LowerL2Persistent""" | ||
| return _ffi_api.LowerL2Persistent() # type: ignore | ||
|
|
||
|
|
||
| def PersistThreadblock(): | ||
| """PersistThreadblock""" | ||
| return _ffi_api.PersistThreadblock() # type: ignore |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check for conditional compilation guards in the L2-persistent and PersistThreadblock C++ files
fd -e cc "lower_l2_persistent" src/
fd -e cc "persist_threadblock" src/
# Search for `#if` guards around TVM_FFI_STATIC_INIT_BLOCK in those files
rg -n "#if|CUDA_MAJOR|TVM_FFI_STATIC_INIT_BLOCK" src/backend/cuda/transform/Repository: tile-ai/tilelang
Length of output: 608
🏁 Script executed:
#!/bin/bash
# Check the calls in engine/phase.py at lines 208 and 309
echo "=== Checking engine/phase.py calls ==="
fd -e py "phase.py" src/ | head -5Repository: tile-ai/tilelang
Length of output: 99
🏁 Script executed:
#!/bin/bash
# Search for LowerHopperIntrin in the Python wrapper to confirm hasattr guards
echo "=== Checking LowerHopperIntrin in Python ==="
rg -n "LowerHopperIntrin" --type py src/Repository: tile-ai/tilelang
Length of output: 105
🏁 Script executed:
#!/bin/bash
# Check the full content of lower_hopper_intrin.py to see if it has hasattr
fd -e py "*hopper*" src/backend/cuda/Repository: tile-ai/tilelang
Length of output: 348
🏁 Script executed:
#!/bin/bash
# Find phase.py using git ls-files
echo "=== Looking for phase.py ==="
git ls-files | grep phase
# Find the transform __init__.py file
echo "=== Looking for transform __init__.py ==="
git ls-files | grep "backend/cuda/transform/__init__"Repository: tile-ai/tilelang
Length of output: 200
🏁 Script executed:
#!/bin/bash
# List all Python files in backend/cuda/transform/
echo "=== Python files in backend/cuda/transform/ ==="
git ls-files src/backend/cuda/transform/ | grep ".py$"Repository: tile-ai/tilelang
Length of output: 108
🏁 Script executed:
#!/bin/bash
# Check the actual __init__.py file
echo "=== Content of transform/__init__.py ==="
cat -n src/tilelang/backend/cuda/transform/__init__.pyRepository: tile-ai/tilelang
Length of output: 199
🏁 Script executed:
#!/bin/bash
# Read the transform __init__.py file with correct path
cat -n tilelang/backend/cuda/transform/__init__.pyRepository: tile-ai/tilelang
Length of output: 838
🏁 Script executed:
#!/bin/bash
# Check engine/phase.py lines around 208 and 309
echo "=== Lines 200-215 of engine/phase.py ==="
sed -n '200,215p' tilelang/engine/phase.py | cat -n
echo ""
echo "=== Lines 300-315 of engine/phase.py ==="
sed -n '300,315p' tilelang/engine/phase.py | cat -nRepository: tile-ai/tilelang
Length of output: 1681
🏁 Script executed:
#!/bin/bash
# Check the C++ files for full content around the FFI registration
echo "=== lower_l2_persistent_annotation.cc FFI block ==="
sed -n '95,110p' tilelang/backend/cuda/transform/lower_l2_persistent_annotation.cc | cat -n
echo ""
echo "=== persist_threadblock.cc FFI block ==="
sed -n '55,70p' tilelang/backend/cuda/transform/persist_threadblock.cc | cat -nRepository: tile-ai/tilelang
Length of output: 379
🏁 Script executed:
#!/bin/bash
# Check the C++ files with src/ prefix
echo "=== lower_l2_persistent_annotation.cc FFI block ==="
sed -n '95,110p' src/backend/cuda/transform/lower_l2_persistent_annotation.cc | cat -n
echo ""
echo "=== persist_threadblock.cc FFI block ==="
sed -n '55,70p' src/backend/cuda/transform/persist_threadblock.cc | cat -n
echo ""
echo "=== lower_hopper_intrin.cc around FFI (for comparison) ==="
sed -n '245,253p' src/backend/cuda/transform/lower_hopper_intrin.cc | cat -nRepository: tile-ai/tilelang
Length of output: 1563
🏁 Script executed:
#!/bin/bash
# Check if lower_l2_persistent_annotation.cc has any `#if` guards
echo "=== Checking lower_l2_persistent_annotation.cc for `#if` guards ==="
rg -n "^#if|CUDA_MAJOR|ROCM|CPU" src/backend/cuda/transform/lower_l2_persistent_annotation.ccRepository: tile-ai/tilelang
Length of output: 126
🏁 Script executed:
#!/bin/bash
# Check if persist_threadblock.cc has any `#if` guards
echo "=== Checking persist_threadblock.cc for `#if` guards ==="
rg -n "^#if|CUDA_MAJOR|ROCM|CPU" src/backend/cuda/transform/persist_threadblock.ccRepository: tile-ai/tilelang
Length of output: 115
🏁 Script executed:
#!/bin/bash
# Check CMakeLists.txt for how these files are included
echo "=== Looking for CMakeLists.txt files ==="
find . -name CMakeLists.txt | grep -E "cuda|transform"Repository: tile-ai/tilelang
Length of output: 135
🏁 Script executed:
#!/bin/bash
# Check if there are any build guards around these C++ files
echo "=== Check full build directory structure ==="
git ls-files | grep -E "CMakeLists|cmake" | head -20Repository: tile-ai/tilelang
Length of output: 385
🏁 Script executed:
#!/bin/bash
# Search for how persist_threadblock and lower_l2_persistent are included in build
echo "=== Searching CMakeLists for these files ==="
rg -l "persist_threadblock|lower_l2_persistent" --type cmakeRepository: tile-ai/tilelang
Length of output: 105
🏁 Script executed:
#!/bin/bash
# Read the CUDA CMakeLists.txt to see how files are included
cat src/backend/cuda/CMakeLists.txtRepository: tile-ai/tilelang
Length of output: 8123
🏁 Script executed:
#!/bin/bash
# Search for these filenames in any CMake files
echo "=== Searching all CMake files ==="
rg "persist_threadblock|lower_l2_persistent" . --type cmakeRepository: tile-ai/tilelang
Length of output: 94
Add hasattr guards to LowerL2Persistent and PersistThreadblock — they lack fallback protection when FFI symbols are absent.
LowerHopperIntrin has a hasattr guard (line 8-10) because its C++ implementation is gated by #if (CUDA_MAJOR_VERSION >= 12). While LowerL2Persistent and PersistThreadblock lack similar C++ guards and are unconditionally compiled into the CUDA backend, they will only be registered if the CUDA backend is built (controlled by USE_CUDA). When CUDA is disabled or unavailable, these symbols won't exist, causing AttributeError at lines 209 and 310 in engine/phase.py.
Add the same fallback pattern as LowerHopperIntrin:
Proposed fix
def LowerL2Persistent():
"""LowerL2Persistent"""
- return _ffi_api.LowerL2Persistent() # type: ignore
+ if hasattr(_ffi_api, "LowerL2Persistent"):
+ return _ffi_api.LowerL2Persistent() # type: ignore
+ return lambda f: f
def PersistThreadblock():
"""PersistThreadblock"""
- return _ffi_api.PersistThreadblock() # type: ignore
+ if hasattr(_ffi_api, "PersistThreadblock"):
+ return _ffi_api.PersistThreadblock() # type: ignore
+ return lambda f: f🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/backend/cuda/transform/__init__.py` around lines 13 - 20,
LowerL2Persistent and PersistThreadblock call into _ffi_api unguarded and will
raise AttributeError when the CUDA FFI symbols are missing; mirror the pattern
used by LowerHopperIntrin: check hasattr(_ffi_api, "LowerL2Persistent") and
hasattr(_ffi_api, "PersistThreadblock") and return the FFI call only if present,
otherwise return None (or an appropriate no-op fallback) so callers in
engine/phase.py don't crash when USE_CUDA/FFI symbols are absent.
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
benchmark/matmul/benchmark_matmul_intrinsic.py (1)
285-286:⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
with_roller = Trueon Line 286 unconditionally overrides the CLI argument — dead user-facing flag.Line 285 correctly assigns from
args.with_roller, but line 286 immediately clobbers it with a hardcodedTrue. The--with_rollerargument (line 277) is therefore inoperable and the non-roller config path is unreachable from__main__. This looks like a leftover debug override.🐛 Proposed fix
with_roller = args.with_roller - with_roller = True🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@benchmark/matmul/benchmark_matmul_intrinsic.py` around lines 285 - 286, The CLI flag is being clobbered: the variable with_roller is correctly set from args.with_roller but immediately overwritten by the hardcoded statement with_roller = True; remove that hardcoded assignment (or change it to only set a default when args.with_roller is None) so that the value from args.with_roller is respected in the __main__ flow and the non-roller code path can be reached.examples/plot_layout/README.md (1)
39-44:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winREADME example imports non-existent function names from
mma_layout.The imports reference
shared_16x16_to_mma_32x8_layout_sr,shared_16x16_to_mma_32x8_layout_rs,shared_16x32_to_mma_32x16_layout, andshared_32x16_to_mma_32x16_layout, but none of these functions exist in the module. The current API uses an operand-based naming scheme with_aand_bsuffixes (e.g.,shared_16x16_to_mma_a_32x8_layout,shared_16x16_to_mma_b_32x8_layout). Copying this README snippet will result in anImportError.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/plot_layout/README.md` around lines 39 - 44, The README imports refer to non-existent names in tilelang.backend.cuda.intrinsics.mma_layout; update the import list to use the operand-suffixed API (replace shared_16x16_to_mma_32x8_layout_sr/shared_16x16_to_mma_32x8_layout_rs with shared_16x16_to_mma_a_32x8_layout and shared_16x16_to_mma_b_32x8_layout, and replace shared_16x32_to_mma_32x16_layout/shared_32x16_to_mma_32x16_layout with the operand-specific shared_16x32_to_mma_a_32x16_layout and shared_32x16_to_mma_b_32x16_layout) so the README imports match the module's actual function names.
🧹 Nitpick comments (1)
tilelang/intrinsics/__init__.py (1)
14-23: Consider lazy-loading backend-specific intrinsics instead of eager imports at package initialization.
tilelang.intrinsicseagerly imports bothtilelang.backend.cuda.intrinsics.mma_macro_generatorandtilelang.backend.rocm.intrinsics.mfma_layoutat package load time. Any consumer ofimport tilelang.intrinsicswill load both backend subpackages unconditionally.While HIP/CUDA stub libraries mitigate runtime failures on CPU-only systems by lazy-loading hardware libraries at the C++ level, eager Python imports remain architecturally suboptimal. Consider using
TYPE_CHECKINGor module-level__getattr__to defer backend imports until they're actually used.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/intrinsics/__init__.py` around lines 14 - 23, The package is eagerly importing backend-specific modules; change __init__ to lazy-load those symbols instead: remove direct imports of TensorCoreIntrinEmitter, TensorCoreIntrinEmitterWithLadderTransform, make_mma_swizzle_layout and make_mfma_swizzle_layout and implement deferred resolution (either via typing.TYPE_CHECKING guards to keep imports only for type checking, and/or a module-level __getattr__ that imports and returns these names on first access); ensure __all__ exports these symbol names so consumers still get them via from tilelang.intrinsics import <name> while avoiding importing tilelang.backend.cuda.* and tilelang.backend.rocm.* at package import time.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py`:
- Around line 5-7: The current import aliases make_swizzle_layout to the
CUDA-specific make_mma_swizzle_layout and it's applied unconditionally to
A_shared/B_shared (see use around A_shared/B_shared, lines ~115–120) which
breaks ROCm; change the import/selection to use the ROCm-specific swizzle when
is_hip is true: import or reference both make_mma_swizzle_layout and
make_mfma_swizzle_layout (or use tilelang.intrinsics' exported
make_mfma_swizzle_layout) and branch on is_hip when choosing the swizzle
function (e.g., swizzle_fn = make_mfma_swizzle_layout if is_hip else
make_mma_swizzle_layout) before applying it to A_shared and B_shared so ROCm
uses the correct MFMA swizzle layout.
---
Outside diff comments:
In `@benchmark/matmul/benchmark_matmul_intrinsic.py`:
- Around line 285-286: The CLI flag is being clobbered: the variable with_roller
is correctly set from args.with_roller but immediately overwritten by the
hardcoded statement with_roller = True; remove that hardcoded assignment (or
change it to only set a default when args.with_roller is None) so that the value
from args.with_roller is respected in the __main__ flow and the non-roller code
path can be reached.
In `@examples/plot_layout/README.md`:
- Around line 39-44: The README imports refer to non-existent names in
tilelang.backend.cuda.intrinsics.mma_layout; update the import list to use the
operand-suffixed API (replace
shared_16x16_to_mma_32x8_layout_sr/shared_16x16_to_mma_32x8_layout_rs with
shared_16x16_to_mma_a_32x8_layout and shared_16x16_to_mma_b_32x8_layout, and
replace shared_16x32_to_mma_32x16_layout/shared_32x16_to_mma_32x16_layout with
the operand-specific shared_16x32_to_mma_a_32x16_layout and
shared_32x16_to_mma_b_32x16_layout) so the README imports match the module's
actual function names.
---
Nitpick comments:
In `@tilelang/intrinsics/__init__.py`:
- Around line 14-23: The package is eagerly importing backend-specific modules;
change __init__ to lazy-load those symbols instead: remove direct imports of
TensorCoreIntrinEmitter, TensorCoreIntrinEmitterWithLadderTransform,
make_mma_swizzle_layout and make_mfma_swizzle_layout and implement deferred
resolution (either via typing.TYPE_CHECKING guards to keep imports only for type
checking, and/or a module-level __getattr__ that imports and returns these names
on first access); ensure __all__ exports these symbol names so consumers still
get them via from tilelang.intrinsics import <name> while avoiding importing
tilelang.backend.cuda.* and tilelang.backend.rocm.* at package import time.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f736efea-89d6-46df-b36c-d545a6233b26
📒 Files selected for processing (43)
benchmark/matmul/benchmark_matmul_intrinsic.pyexamples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.pyexamples/dequantize_gemm/example_dequant_gemm_fine_grained.pyexamples/gemm/example_gemm_intrinsics.pyexamples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.pyexamples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.pyexamples/plot_layout/README.mdexamples/plot_layout/fragment_mfma_load_a.pyexamples/plot_layout/fragment_mma_load_a.pysrc/backend/cuda/CMakeLists.txtsrc/backend/cuda/runtime.cctesting/python/amd/test_tilelang_gemm_mfma_intrinsic.pytesting/python/amd/test_tilelang_gemm_mfma_preshuffle.pytesting/python/kernel/test_tilelang_kernel_bf16_gemm_mma.pytesting/python/kernel/test_tilelang_kernel_fp8_gemm_mma.pytesting/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.pytesting/python/kernel/test_tilelang_kernel_gemm_simt.pytesting/python/language/test_tilelang_language_reshape.pytesting/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.pytesting/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.pytilelang/backend/cuda/intrinsics/__init__.pytilelang/backend/cuda/intrinsics/mma_layout.pytilelang/backend/cuda/intrinsics/mma_macro_generator.pytilelang/backend/cuda/intrinsics/mma_sm70_layout.pytilelang/backend/cuda/intrinsics/mma_sm70_macro_generator.pytilelang/backend/cuda/intrinsics/mma_sp_layout.pytilelang/backend/cuda/intrinsics/mma_sp_macro_generator.pytilelang/backend/cuda/intrinsics/tcgen05_macro_generator.pytilelang/backend/cuda/intrinsics/wgmma_macro_generator.pytilelang/backend/cuda/op/gemm/gemm_mma.pytilelang/backend/cuda/op/gemm/gemm_mma_sm70.pytilelang/backend/cuda/op/gemm/gemm_tcgen05.pytilelang/backend/cuda/op/gemm/gemm_wgmma.pytilelang/backend/cuda/op/gemm_sp/gemm_sp_mma.pytilelang/backend/rocm/intrinsics/__init__.pytilelang/backend/rocm/intrinsics/mfma_layout.pytilelang/backend/rocm/intrinsics/mfma_macro_generator.pytilelang/backend/rocm/intrinsics/wmma_layout.pytilelang/backend/rocm/intrinsics/wmma_macro_generator.pytilelang/backend/rocm/op/gemm/gemm_mfma.pytilelang/backend/rocm/op/gemm/gemm_wmma.pytilelang/intrinsics/__init__.pytilelang/language/gemm_op.py
💤 Files with no reviewable changes (1)
- testing/python/kernel/test_tilelang_kernel_gemm_simt.py
✅ Files skipped from review due to trivial changes (25)
- tilelang/backend/rocm/intrinsics/init.py
- tilelang/backend/cuda/intrinsics/init.py
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
- tilelang/backend/cuda/op/gemm/gemm_wgmma.py
- tilelang/backend/cuda/op/gemm/gemm_mma.py
- testing/python/language/test_tilelang_language_reshape.py
- examples/gemm/example_gemm_intrinsics.py
- tilelang/backend/rocm/op/gemm/gemm_mfma.py
- tilelang/backend/rocm/op/gemm/gemm_wmma.py
- examples/plot_layout/fragment_mfma_load_a.py
- testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
- tilelang/backend/cuda/op/gemm/gemm_mma_sm70.py
- tilelang/backend/cuda/op/gemm_sp/gemm_sp_mma.py
- tilelang/backend/cuda/intrinsics/wgmma_macro_generator.py
- examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
- testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
- tilelang/backend/cuda/intrinsics/mma_sp_macro_generator.py
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
- examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py
- tilelang/backend/cuda/intrinsics/mma_macro_generator.py
- testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py
- testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
- tilelang/backend/rocm/intrinsics/mfma_macro_generator.py
- testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py
- tilelang/backend/cuda/intrinsics/mma_sm70_macro_generator.py
🚧 Files skipped from review as they are similar to previous changes (3)
- src/backend/cuda/CMakeLists.txt
- tilelang/backend/cuda/op/gemm/gemm_tcgen05.py
- src/backend/cuda/runtime.cc
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tilelang/intrinsics/__init__.py (1)
16-33: ⚡ Quick winConsider declaring
__all__for API clarity and IDE discoverability.While no wildcard imports currently exist in the codebase, adding
__all__is a Python best practice that improves IDE autocomplete,help(tilelang.intrinsics)output, and makes the public API contract explicit. The lazy__getattr__facade already works correctly for explicit imports, but__all__provides complementary documentation value.♻️ Proposed patch
from .utils import ( mma_store_index_map, # noqa: F401 get_ldmatrix_offset, # noqa: F401 get_mma_micro_size, # noqa: F401 ) +__all__ = [ + "mma_store_index_map", + "get_ldmatrix_offset", + "get_mma_micro_size", + "TensorCoreIntrinEmitter", + "TensorCoreIntrinEmitterWithLadderTransform", + "make_mma_swizzle_layout", + "make_mfma_swizzle_layout", +] + def __getattr__(name):🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/intrinsics/__init__.py` around lines 16 - 33, Add an explicit public API list __all__ = ["TensorCoreIntrinEmitter", "TensorCoreIntrinEmitterWithLadderTransform", "make_mma_swizzle_layout", "make_mfma_swizzle_layout"] to this module (e.g., near the top of tilelang.intrinsics.__init__) so IDEs and help() can discover these lazily-exported symbols; keep the existing __getattr__ implementation intact so the lazy imports (TensorCoreIntrinEmitter, TensorCoreIntrinEmitterWithLadderTransform, make_mma_swizzle_layout, make_mfma_swizzle_layout) continue to be resolved on access.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tilelang/intrinsics/__init__.py`:
- Around line 16-33: Add an explicit public API list __all__ =
["TensorCoreIntrinEmitter", "TensorCoreIntrinEmitterWithLadderTransform",
"make_mma_swizzle_layout", "make_mfma_swizzle_layout"] to this module (e.g.,
near the top of tilelang.intrinsics.__init__) so IDEs and help() can discover
these lazily-exported symbols; keep the existing __getattr__ implementation
intact so the lazy imports (TensorCoreIntrinEmitter,
TensorCoreIntrinEmitterWithLadderTransform, make_mma_swizzle_layout,
make_mfma_swizzle_layout) continue to be resolved on access.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5c3903ad-a515-4168-b33a-2321a631606f
📒 Files selected for processing (3)
examples/gemm/example_gemm_intrinsics.pytilelang/intrinsics/__init__.pytilelang/intrinsics/utils.py
✅ Files skipped from review due to trivial changes (1)
- examples/gemm/example_gemm_intrinsics.py
6c1b178 to
595e7e2
Compare
…tor/move-backend-impls-to-op
…rix multiplication and element-wise addition. This includes the deletion of files for `benchmark_matmul_intrinsic.py`, `example_tilelang_gemm_fp8_intrinsic.py`, and associated test files, streamlining the codebase by eliminating unused components.
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/cuda/op/gemm/gemm_wgmma.py (1)
152-163:⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
_gemm_rsrdocstring incorrectly describes the RS operand layout.The docstring says "loads data from shared buffers A_shared and B_shared" but in the RS case (
is_gemm_rs()),Ais a fragment (register), not shared memory. This is a copy-paste from_gemm_ssr's docstring.📝 Proposed fix
`@T.prim_func` def _gemm_rsr() -> None: """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. + The inner macro that uses register fragment A and shared buffer + B_shared, then issues Tensor Core wgmma ops, accumulating into + C_local. """ mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/op/gemm/gemm_wgmma.py` around lines 152 - 163, The docstring for function _gemm_rsr incorrectly states that both operands are loaded from shared buffers; update the _gemm_rsr docstring to accurately describe the RS layout (A is a fragment/register, B is loaded from shared memory B_shared) instead of saying "A_shared and B_shared", mirroring the correct wording pattern used in _gemm_ssr; locate the docstring inside the _gemm_rsr prim_func and change the sentence about loading data to clearly state that A is a fragment in registers while B is taken from shared memory.
♻️ Duplicate comments (1)
tilelang/cuda/transform/__init__.py (1)
13-20:⚠️ Potential issue | 🟠 Major | ⚡ Quick win
LowerL2PersistentandPersistThreadblockare missinghasattrguards —AttributeErrorwhen CUDA is not built.
LowerHopperIntrin(lines 8–10) is correctly guarded withhasattr(_ffi_api, "LowerHopperIntrin"), but the same protection is absent for the two functions below. WhenUSE_CUDA=OFFthe FFI symbols won't be registered, causing anAttributeErrorat the call sites inengine/phase.py.Proposed fix
def LowerL2Persistent(): """LowerL2Persistent""" - return _ffi_api.LowerL2Persistent() # type: ignore + if hasattr(_ffi_api, "LowerL2Persistent"): + return _ffi_api.LowerL2Persistent() # type: ignore + return lambda f: f def PersistThreadblock(): """PersistThreadblock""" - return _ffi_api.PersistThreadblock() # type: ignore + if hasattr(_ffi_api, "PersistThreadblock"): + return _ffi_api.PersistThreadblock() # type: ignore + return lambda f: f🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/transform/__init__.py` around lines 13 - 20, LowerL2Persistent and PersistThreadblock call _ffi_api unguarded which raises AttributeError when CUDA is not built; wrap each export with a hasattr(_ffi_api, "<symbol>") guard like LowerHopperIntrin does. Specifically, add checks for hasattr(_ffi_api, "LowerL2Persistent") and hasattr(_ffi_api, "PersistThreadblock") and only return the ffi_api call if present; otherwise provide a safe fallback (e.g., no-op or raising a clearer error) so callers in engine/phase.py do not hit AttributeError at import time.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/rocm/intrinsics/utils.py`:
- Around line 18-23: The function get_mma_micro_size has a too-narrow type hint
(Literal["float16", "int8"]) but the implementation checks for "float8_e4m3" and
"float8_e5m2"; update the dtype annotation to include "float8_e4m3" and
"float8_e5m2" (e.g., Literal["float16","int8","float8_e4m3","float8_e5m2"]) so
the signature matches the logic. Make the same change for the corresponding
get_mma_micro_size definition in the CUDA layout utils module so both
implementations and their type hints are consistent. Ensure imports for Literal
remain valid.
---
Outside diff comments:
In `@tilelang/cuda/op/gemm/gemm_wgmma.py`:
- Around line 152-163: The docstring for function _gemm_rsr incorrectly states
that both operands are loaded from shared buffers; update the _gemm_rsr
docstring to accurately describe the RS layout (A is a fragment/register, B is
loaded from shared memory B_shared) instead of saying "A_shared and B_shared",
mirroring the correct wording pattern used in _gemm_ssr; locate the docstring
inside the _gemm_rsr prim_func and change the sentence about loading data to
clearly state that A is a fragment in registers while B is taken from shared
memory.
---
Duplicate comments:
In `@tilelang/cuda/transform/__init__.py`:
- Around line 13-20: LowerL2Persistent and PersistThreadblock call _ffi_api
unguarded which raises AttributeError when CUDA is not built; wrap each export
with a hasattr(_ffi_api, "<symbol>") guard like LowerHopperIntrin does.
Specifically, add checks for hasattr(_ffi_api, "LowerL2Persistent") and
hasattr(_ffi_api, "PersistThreadblock") and only return the ffi_api call if
present; otherwise provide a safe fallback (e.g., no-op or raising a clearer
error) so callers in engine/phase.py do not hit AttributeError at import time.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 56840068-6f96-45e4-ba75-ee493184c567
📒 Files selected for processing (79)
benchmark/matmul/benchmark_matmul_intrinsic.pydocs/deeplearning_operators/matmul.mdexamples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.pyexamples/dequantize_gemm/example_dequant_gemm_fine_grained.pyexamples/gemm/README.mdexamples/gemm/example_gemm_intrinsics.pyexamples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.pyexamples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.pyexamples/gemm_fp8/regression_example_gemm_fp8.pyexamples/gemm_fp8/test_example_gemm_fp8.pyexamples/hadamard_transform/example_hadamard.pyexamples/plot_layout/README.mdexamples/plot_layout/fragment_mfma_load_a.pyexamples/plot_layout/fragment_mma_load_a.pysrc/backend/cuda/CMakeLists.txtsrc/backend/cuda/runtime.ccsrc/backend/cuda/runtime.hsrc/backend/cuda/transform/lower_hopper_intrin.ccsrc/backend/cuda/transform/lower_l2_persistent_annotation.ccsrc/backend/cuda/transform/persist_threadblock.cctesting/python/amd/test_tilelang_gemm_mfma_intrinsic.pytesting/python/amd/test_tilelang_gemm_mfma_preshuffle.pytesting/python/kernel/test_tilelang_kernel_bf16_gemm_mma.pytesting/python/kernel/test_tilelang_kernel_element_wise_add.pytesting/python/kernel/test_tilelang_kernel_fp8_gemm_mma.pytesting/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.pytesting/python/kernel/test_tilelang_kernel_gemm_simt.pytesting/python/language/test_tilelang_language_reshape.pytesting/python/language/test_tilelang_language_vectorize.pytesting/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.pytesting/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.pytesting/python/transform/test_tilelang_transform_lower_hopper_intrin.pytilelang/__init__.pytilelang/backend/__init__.pytilelang/cpu/__init__.pytilelang/cpu/op/__init__.pytilelang/cpu/op/gemm/__init__.pytilelang/cpu/op/gemm/gemm_scalar.pytilelang/cuda/__init__.pytilelang/cuda/intrinsics/__init__.pytilelang/cuda/intrinsics/layout/__init__.pytilelang/cuda/intrinsics/layout/mma_layout.pytilelang/cuda/intrinsics/layout/mma_sm70_layout.pytilelang/cuda/intrinsics/layout/mma_sp_layout.pytilelang/cuda/intrinsics/layout/utils.pytilelang/cuda/intrinsics/macro/__init__.pytilelang/cuda/intrinsics/macro/mma_macro_generator.pytilelang/cuda/intrinsics/macro/mma_sm70_macro_generator.pytilelang/cuda/intrinsics/macro/mma_sp_macro_generator.pytilelang/cuda/intrinsics/macro/tcgen05_macro_generator.pytilelang/cuda/intrinsics/macro/wgmma_macro_generator.pytilelang/cuda/op/__init__.pytilelang/cuda/op/gemm/__init__.pytilelang/cuda/op/gemm/gemm_mma.pytilelang/cuda/op/gemm/gemm_mma_sm70.pytilelang/cuda/op/gemm/gemm_tcgen05.pytilelang/cuda/op/gemm/gemm_wgmma.pytilelang/cuda/op/gemm_sp/__init__.pytilelang/cuda/op/gemm_sp/gemm_sp_mma.pytilelang/cuda/transform/__init__.pytilelang/engine/phase.pytilelang/intrinsics/__init__.pytilelang/language/gemm_op.pytilelang/rocm/__init__.pytilelang/rocm/intrinsics/__init__.pytilelang/rocm/intrinsics/mfma_layout.pytilelang/rocm/intrinsics/mfma_macro_generator.pytilelang/rocm/intrinsics/utils.pytilelang/rocm/intrinsics/wmma_layout.pytilelang/rocm/intrinsics/wmma_macro_generator.pytilelang/rocm/op/__init__.pytilelang/rocm/op/gemm/__init__.pytilelang/rocm/op/gemm/gemm_mfma.pytilelang/rocm/op/gemm/gemm_wmma.pytilelang/tileop/gemm/__init__.pytilelang/tileop/gemm/registry.pytilelang/tileop/gemm_sp/__init__.pytilelang/tileop/gemm_sp/registry.pytilelang/transform/__init__.py
💤 Files with no reviewable changes (10)
- examples/gemm_fp8/regression_example_gemm_fp8.py
- tilelang/transform/init.py
- testing/python/kernel/test_tilelang_kernel_element_wise_add.py
- tilelang/backend/init.py
- tilelang/cuda/intrinsics/layout/utils.py
- benchmark/matmul/benchmark_matmul_intrinsic.py
- testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py
- examples/gemm_fp8/test_example_gemm_fp8.py
- examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
- testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py
✅ Files skipped from review due to trivial changes (37)
- tilelang/cpu/init.py
- testing/python/language/test_tilelang_language_vectorize.py
- tilelang/rocm/init.py
- tilelang/cuda/init.py
- examples/hadamard_transform/example_hadamard.py
- tilelang/cuda/intrinsics/macro/mma_sm70_macro_generator.py
- tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py
- tilelang/cuda/op/gemm/gemm_mma_sm70.py
- testing/python/kernel/test_tilelang_kernel_gemm_simt.py
- tilelang/cuda/intrinsics/macro/init.py
- tilelang/cuda/op/gemm_sp/init.py
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
- tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py
- docs/deeplearning_operators/matmul.md
- tilelang/cuda/op/gemm/init.py
- tilelang/cuda/intrinsics/layout/mma_sp_layout.py
- tilelang/rocm/op/gemm/init.py
- tilelang/cuda/op/gemm_sp/gemm_sp_mma.py
- tilelang/tileop/gemm_sp/init.py
- examples/gemm/example_gemm_intrinsics.py
- examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
- tilelang/tileop/gemm/init.py
- tilelang/cuda/op/init.py
- tilelang/rocm/op/gemm/gemm_wmma.py
- examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py
- tilelang/cuda/op/gemm/gemm_mma.py
- examples/plot_layout/fragment_mfma_load_a.py
- src/backend/cuda/transform/lower_l2_persistent_annotation.cc
- tilelang/rocm/op/gemm/gemm_mfma.py
- examples/dequantize_gemm/example_dequant_gemm_fine_grained.py
- tilelang/rocm/intrinsics/init.py
- tilelang/cuda/intrinsics/macro/mma_macro_generator.py
- src/backend/cuda/transform/lower_hopper_intrin.cc
- tilelang/cuda/intrinsics/layout/init.py
- tilelang/cuda/op/gemm/gemm_tcgen05.py
- src/backend/cuda/runtime.h
- examples/gemm/README.md
🚧 Files skipped from review as they are similar to previous changes (10)
- testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
- testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
- tilelang/language/gemm_op.py
- testing/python/language/test_tilelang_language_reshape.py
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
- examples/plot_layout/README.md
- examples/plot_layout/fragment_mma_load_a.py
- tilelang/engine/phase.py
- src/backend/cuda/runtime.cc
- src/backend/cuda/transform/persist_threadblock.cc
| def get_mma_micro_size(dtype: Literal["float16", "int8"]): | ||
| micro_size_x = micro_size_y = 16 | ||
| micro_size_k = 16 | ||
| if dtype in {"float8_e4m3", "float8_e5m2", "int8"}: | ||
| micro_size_k = 32 | ||
| return micro_size_x, micro_size_y, micro_size_k |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
rg -nP '\bget_mma_micro_size\s*\(\s*["'\''](float8_e4m3|float8_e5m2|float16|int8)["'\'']\s*\)'Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
cat -n tilelang/rocm/intrinsics/utils.py | sed -n '15,30p'Repository: tile-ai/tilelang
Length of output: 461
🏁 Script executed:
rg 'get_mma_micro_size' -nRepository: tile-ai/tilelang
Length of output: 1119
🏁 Script executed:
cat -n examples/plot_layout/fragment_mma_load_a.py | sed -n '60,70p'Repository: tile-ai/tilelang
Length of output: 644
🏁 Script executed:
cat -n examples/plot_layout/fragment_mfma_load_a.py | sed -n '65,75p'Repository: tile-ai/tilelang
Length of output: 644
🏁 Script executed:
cat -n examples/plot_layout/README.md | sed -n '65,75p'Repository: tile-ai/tilelang
Length of output: 521
🏁 Script executed:
cat -n examples/plot_layout/fragment_mma_load_a.py | sed -n '1,70p' | grep -A 20 "^.*def "Repository: tile-ai/tilelang
Length of output: 905
🏁 Script executed:
cat -n examples/plot_layout/fragment_mfma_load_a.py | sed -n '1,75p' | grep -A 20 "^.*def "Repository: tile-ai/tilelang
Length of output: 893
🏁 Script executed:
rg -i 'float8_e4m3|float8_e5m2' -nRepository: tile-ai/tilelang
Length of output: 50372
🏁 Script executed:
rg 'make_mma_load_base_layout\s*\(' -A 3 -B 1Repository: tile-ai/tilelang
Length of output: 1700
🏁 Script executed:
cat -n tilelang/cuda/intrinsics/layout/utils.py | sed -n '88,115p'Repository: tile-ai/tilelang
Length of output: 1194
Update type hint to include float8 variants that are explicitly handled in the implementation.
Lines 18-23 (ROCm) and the corresponding CUDA function at tilelang/cuda/intrinsics/layout/utils.py:90 declare dtype: Literal["float16", "int8"], but both implementations check if dtype in {"float8_e4m3", "float8_e5m2", "int8"}:. The CUDA version even documents these float8 types in its docstring (line 101), confirming they are intentionally supported. The type annotation should be corrected to match:
Suggested fix
-def get_mma_micro_size(dtype: Literal["float16", "int8"]):
+def get_mma_micro_size(
+ dtype: Literal["float16", "int8", "float8_e4m3", "float8_e5m2"]
+):Apply the same fix to tilelang/cuda/intrinsics/layout/utils.py:90.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/rocm/intrinsics/utils.py` around lines 18 - 23, The function
get_mma_micro_size has a too-narrow type hint (Literal["float16", "int8"]) but
the implementation checks for "float8_e4m3" and "float8_e5m2"; update the dtype
annotation to include "float8_e4m3" and "float8_e5m2" (e.g.,
Literal["float16","int8","float8_e4m3","float8_e5m2"]) so the signature matches
the logic. Make the same change for the corresponding get_mma_micro_size
definition in the CUDA layout utils module so both implementations and their
type hints are consistent. Ensure imports for Literal remain valid.
…into backend directories (tile-ai#2165) * [Refactor] Move backend-specific GEMM implementations and transforms into backend directories Restructure the codebase so that each backend (cpu, cuda, rocm) owns its GEMM implementations, sparse GEMM implementations, and transform passes under a consistent op/ subdirectory layout. The shared GEMM registry and base classes remain in tileop/ as the platform-agnostic dispatch layer. - Move gemm/gemm_sp registries from backend/ into tileop/ as registry.py - Move CUDA GEMM impls (mma, mma_sm70, wgmma, tcgen05) into backend/cuda/op/gemm/ - Move CUDA sparse GEMM impl into backend/cuda/op/gemm_sp/ - Move CPU GEMM impl (scalar) into backend/cpu/op/gemm/ - Move ROCm GEMM impls (mfma, wmma) into backend/rocm/op/gemm/ - Move CUDA-specific transform passes from src/transform/ into src/backend/cuda/transform/ - Move CUDA runtime sources from src/runtime/ into src/backend/cuda/ - Remove dead backend-importing wrappers from transform/__init__.py - Update phase.py and tests to import CUDA transforms from their canonical location Each backend now has a symmetric op/ directory structure. Adding a new backend no longer requires modifying shared transform or tileop modules. * Refactor Python backend package layout * Remove deprecated intrinsic implementations and related tests for matrix multiplication and element-wise addition. This includes the deletion of files for `benchmark_matmul_intrinsic.py`, `example_tilelang_gemm_fp8_intrinsic.py`, and associated test files, streamlining the codebase by eliminating unused components. * Move CUDA transform passes back to common transform
Summary
Restructure the codebase so that each backend (cpu, cuda, rocm) owns its GEMM implementations, sparse GEMM implementations, and transform passes under a consistent
op/subdirectory layout. The shared GEMM registry and base classes remain intileop/as the platform-agnostic dispatch layer.This makes backend ownership explicit: adding a new backend no longer requires modifying shared
transformortileopmodules, and all backend-specific code follows the same symmetric directory structure.Changes
Registry extraction
backend/gemm.py→tileop/gemm/registry.py(GEMM register/resolve)backend/gemm_sp.py→tileop/gemm_sp/registry.py(sparse GEMM register/resolve)CUDA backend
tileop/gemm/gemm_mma.py→backend/cuda/op/gemm/gemm_mma.pytileop/gemm/gemm_mma_sm70.py→backend/cuda/op/gemm/gemm_mma_sm70.pytileop/gemm/gemm_wgmma.py→backend/cuda/op/gemm/gemm_wgmma.pytileop/gemm/gemm_tcgen05.py→backend/cuda/op/gemm/gemm_tcgen05.pytileop/gemm_sp/gemm_sp_mma.py→backend/cuda/op/gemm_sp/gemm_sp_mma.py__init__.pyCPU backend
tileop/gemm/gemm_scalar.py→backend/cpu/op/gemm/gemm_scalar.pyROCm backend
tileop/gemm/gemm_mfma.py→backend/rocm/op/gemm/gemm_mfma.pytileop/gemm/gemm_wmma.py→backend/rocm/op/gemm/gemm_wmma.pyTransform passes
src/transform/intosrc/backend/cuda/transform/src/runtime/intosrc/backend/cuda/backend/cuda/transform/__init__.pyto wrap FFI callstransform/__init__.py(LowerHopperIntrin,LowerL2Persistent,PersistThreadblock)Callers
engine/phase.pyand the Hopper intrinsics test to import CUDA transforms frombackend.cuda.transformdirectlyFinal structure
Test plan
pytest testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py)bash format.sh)Summary by CodeRabbit
Release Notes
New Features
Refactor
Chores