[Backend] Refactor gemm_sp #2048
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

This PR refactors the sparse GEMM operator stack by consolidating the implementation.

Changes
- GemmSP Operator Definition and Schema Overhaul
- Sparse MMA Instruction Primitives and CUDA Templates
- WGMMA Sparse Tensor-Core Emitters
- Sparse Compression Utility Refactoring
- CUDA Code Generation and Instruction Lowering
- Sparse MMA Layout and Macro Updates
- Test Coverage and Example Updates
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Lang as tilelang.language
    participant GemmSP as GemmSP Operator
    participant Backend as CUDA Backend
    participant Codegen as CUDA Codegen
    participant Templates as CUDA Templates
    User->>Lang: Call T.gemm_sp(A, E, B, C, ..., e_dtype)
    Lang->>GemmSP: Construct GemmSP with regions & trans_E
    GemmSP->>Backend: invoke tl.gemm_sp.lower()
    Backend->>Backend: Select instruction (MMA/WGMMA/TCGEN5)
    alt WGMMA Selected
        Backend->>Backend: Instantiate GemmSPWGMMA
        Backend->>Codegen: Lower to wgmma_ss/wgmma_rs PTX calls
        Codegen->>Templates: Generate wgmma_sp_ss/wgmma_sp_rs<...>
        Templates->>Templates: Emit SM90 sparse WGMMA PTX
    else MMA Selected
        Backend->>Backend: Instantiate GemmSPMMA
        Codegen->>Templates: Generate mma_sp_sync<...>
        Templates->>Templates: Emit SM80/SM89 sparse MMA PTX
    end
    Templates-->>Codegen: Compiled CUDA kernel
    Codegen-->>Lang: TIR PrimFunc
```
Estimated code review effort
🎯 5 (Critical) | ⏱️ ~120 minutes

This PR involves substantial, multi-layered refactoring across operator schemas, CUDA instruction templates, codegen emission, sparse utilities, and test infrastructure. It introduces new sparse MMA instruction primitives for multiple GPU architectures, transitions from static CUDA templates to FFI-based lowering, overhauls the sparse compression implementation, and updates numerous examples and tests. The changes are complex in logic, span many affected files with heterogeneous patterns (operator definition, template headers, Python emitters, tests), and require careful review of instruction semantics, FFI wiring, and backward compatibility.

Possibly related PRs
Suggested reviewers
Poem
✨ Finishing Touches
Actionable comments posted: 4
🧹 Nitpick comments (3)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (3)
197-197: Redundant import inside function.

`tilelang.language as T` is already imported at module level (line 10). This redundant import also appears in `matmul_sr` (line 344) and `matmul_rr` (line 493).

♻️ Proposed fix
Remove the redundant imports on lines 197, 344, and 493:
```diff
- import tilelang.language as T
  @T.prim_func
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py` at line 197, Remove the redundant local imports of tilelang.language as T inside the file-level functions; since tilelang.language is already imported at module scope, delete the in-function import statements found inside the functions matmul_sr and matmul_rr (and the extra import at the earlier location near the test setup) so those functions use the module-level T symbol instead.
153-171: Minor inconsistency: `T.float` vs `T.float32` for accumulator dtype.

In `test_gemm_ss` parameters, the first four cases use `T.float` while other tests (`test_gemm_rs`, `test_gemm_sr`, `test_gemm_rr`) consistently use `T.float32`. While functionally equivalent, using `T.float32` consistently would improve readability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py` around lines 153 - 171, Replace the inconsistent accumulator dtype T.float with T.float32 in the test_gemm_ss parameterization: update the first four tuples where dtypeAccum is T.float to use T.float32 so they match the other tests (refer to the test function name test_gemm_ss and the runner call run_gemm_ss to locate the parameter list).
114-132: Consider extracting the `_matmul` helper to module level.

The `_matmul` helper function is defined identically inside `run_gemm_ss`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr`. Extracting it to module level would reduce duplication.

♻️ Proposed refactor

Add at module level (e.g., after `generate_dense_input`):

```python
def _reference_matmul(A, B, trans_A, trans_B):
    if trans_A:
        A = A.T
    if trans_B:
        B = B.T
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    return torch.matmul(A, B)
```

Then replace inline definitions with calls to `_reference_matmul(A, B, trans_A, trans_B)`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py` around lines 114 - 132, The helper _matmul is duplicated inside run_gemm_ss/run_gemm_rs/run_gemm_sr/run_gemm_rr; extract it to module level as a single function (e.g., _reference_matmul(A, B, trans_A, trans_B)) placed near generate_dense_input, implement the same behavior (apply transposes based on trans_A/trans_B, cast to torch.float32, then torch.matmul) and update each run_gemm_* to call _reference_matmul(A, B, trans_A, trans_B) instead of defining its own _matmul to remove duplication and keep behavior identical.
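The proposed refactor can be exercised outside the TileLang test harness. The sketch below is illustrative only: it uses NumPy arrays as a stand-in for the torch tensors in the real tests, so the `astype`/`np.matmul` calls replace the original `.to(torch.float32)`/`torch.matmul`, and the input values are made up:

```python
import numpy as np

def _reference_matmul(A, B, trans_A, trans_B):
    """Module-level reference matmul shared by all run_gemm_* variants."""
    if trans_A:
        A = A.T
    if trans_B:
        B = B.T
    # Accumulate in float32, mirroring the float32 cast in the original helper.
    return np.matmul(A.astype(np.float32), B.astype(np.float32))

# Each runner now calls the shared helper instead of defining its own _matmul.
A = np.arange(6, dtype=np.float16).reshape(2, 3)
B = np.arange(12, dtype=np.float16).reshape(3, 4)
C = _reference_matmul(A, B, trans_A=False, trans_B=False)
# Passing a pre-transposed A with trans_A=True yields the same product.
C_t = _reference_matmul(A.T, B, trans_A=True, trans_B=False)
```

Because the transpose flags are resolved inside the helper, the four `run_gemm_*` variants only differ in which flags they pass.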
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/deeplearning_operators/matmul_sparse.md`:
- Line 166: Replace the non-descriptive "here" link texts with explicit labels:
change the first link label to something like "PyTorch Ampere compressor
implementation" (pointing to the same URL) and the second link label to
"metadata permutation routine" (pointing to the permutation URL); update the
sentence so it reads e.g. "PyTorch provides an Ampere compressor (PyTorch Ampere
compressor implementation). Note that in this implementation, a permutation
(metadata permutation routine) is applied..." so the links are descriptive and
MD059-compliant while preserving the existing URLs and the instruction about
matching CUTLASS metadata layout and not using
_calculate_meta_reordering_scatter_offsets.
In `@tilelang/language/__init__.py`:
- Line 60: The public symbol gemm_sp_v2 was removed when replacing the import
with gemm_sp; restore backward compatibility by re-exporting a compatibility
alias named gemm_sp_v2 that points to the new gemm_sp implementation (i.e.,
import gemm_sp from .experimental.gemm_sp and assign gemm_sp_v2 = gemm_sp so
existing callers of T.gemm_sp_v2(...) continue to work).
In `@tilelang/tileop/__init__.py`:
- Line 3: The public export GemmSPPy was removed causing breakage for imports;
restore a deprecated alias by reintroducing GemmSPPy as an alias to GemmSP
(e.g., GemmSPPy = GemmSP) in tilelang.tileop.__init__ and emit a
DeprecationWarning via the warnings module when the alias is used or on import
to advise users to switch to GemmSP; also ensure GemmSPPy is present in the
module's public exports (e.g., __all__) so downstream code importing GemmSPPy
continues to work.
In `@tilelang/tileop/gemm_sp/__init__.py`:
- Around line 15-24: The FFI callbacks gemm_sp_infer_layout and gemm_sp_lower
are incorrectly annotated with GemmSPMMA; update both function signatures to use
GemmSP instead so the Python FFI type matches the C++ side (which uses
GetRef<GemmSP>); ensure the parameters and uses inside those functions (calls to
gemm_sp.infer_layout(...) and gemm_sp.lower(...)) remain unchanged so
infer_layout and lower are invoked on a GemmSP instance.
---
Nitpick comments:
In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py`:
- Line 197: Remove the redundant local imports of tilelang.language as T inside
the file-level functions; since tilelang.language is already imported at module
scope, delete the in-function import statements found inside the functions
matmul_sr and matmul_rr (and the extra import at the earlier location near the
test setup) so those functions use the module-level T symbol instead.
- Around line 153-171: Replace the inconsistent accumulator dtype T.float with
T.float32 in the test_gemm_ss parameterization: update the first four tuples
where dtypeAccum is T.float to use T.float32 so they match the other tests
(refer to the test function name test_gemm_ss and the runner call run_gemm_ss to
locate the parameter list).
- Around line 114-132: The helper _matmul is duplicated inside
run_gemm_ss/run_gemm_rs/run_gemm_sr/run_gemm_rr; extract it to module level as a
single function (e.g., _reference_matmul(A, B, trans_A, trans_B)) placed near
generate_dense_input, implement the same behavior (apply transposes based on
trans_A/trans_B, cast to torch.float32, then torch.matmul) and update each
run_gemm_* to call _reference_matmul(A, B, trans_A, trans_B) instead of defining
its own _matmul to remove duplication and keep behavior identical.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cf24b8d5-9f02-4a04-96fb-48da817b4d31
📒 Files selected for processing (15)
- benchmark/matmul/benchmark_matmul_sp.py
- docs/deeplearning_operators/matmul_sparse.md
- examples/gemm_sp/example_custom_compress.py
- src/op/gemm_sp.cc
- src/op/gemm_sp.h
- src/op/gemm_sp_py.cc
- src/op/gemm_sp_py.h
- src/transform/lower_opaque_block.cc
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
- tilelang/ir.py
- tilelang/language/__init__.py
- tilelang/language/experimental/gemm_sp.py
- tilelang/tileop/__init__.py
- tilelang/tileop/gemm_sp/__init__.py
💤 Files with no reviewable changes (4)
- tilelang/ir.py
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
- src/op/gemm_sp_py.h
- src/op/gemm_sp_py.cc
```diff
 from tilelang.tileop.base import GemmWarpPolicy  # noqa: F401
 from .gemm_op import gemm, wgmma_gemm, tcgen05_gemm  # noqa: F401
-from .experimental.gemm_sp import gemm_sp, gemm_sp_v2  # noqa: F401
+from .experimental.gemm_sp import gemm_sp  # noqa: F401
```
Preserve gemm_sp_v2 during the rename.
Line 60 removes a public frontend symbol. Existing kernels that import or call T.gemm_sp_v2(...) will break even though this PR is unifying the implementation rather than removing the feature.
Proposed compatibility shim
```diff
 from .experimental.gemm_sp import gemm_sp  # noqa: F401
+gemm_sp_v2 = gemm_sp  # backward-compatible alias
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
from .experimental.gemm_sp import gemm_sp  # noqa: F401
gemm_sp_v2 = gemm_sp  # backward-compatible alias
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/language/__init__.py` at line 60, The public symbol gemm_sp_v2 was
removed when replacing the import with gemm_sp; restore backward compatibility
by re-exporting a compatibility alias named gemm_sp_v2 that points to the new
gemm_sp implementation (i.e., import gemm_sp from .experimental.gemm_sp and
assign gemm_sp_v2 = gemm_sp so existing callers of T.gemm_sp_v2(...) continue to
work).
```diff
 from .base import GemmWarpPolicy  # noqa: F401
 from .gemm import Gemm  # noqa: F401
-from .gemm_sp import GemmSPPy  # noqa: F401
+from .gemm_sp import GemmSP  # noqa: F401
```
Keep GemmSPPy as a deprecated alias.
Line 3 removes a public export outright. Any downstream from tilelang.tileop import GemmSPPy now fails immediately, even though this change is a rename rather than a semantic removal.
Proposed compatibility shim

```diff
 from .gemm_sp import GemmSP  # noqa: F401
+GemmSPPy = GemmSP  # backward-compatible alias
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
from .gemm_sp import GemmSP  # noqa: F401
GemmSPPy = GemmSP  # backward-compatible alias
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/tileop/__init__.py` at line 3, The public export GemmSPPy was
removed causing breakage for imports; restore a deprecated alias by
reintroducing GemmSPPy as an alias to GemmSP (e.g., GemmSPPy = GemmSP) in
tilelang.tileop.__init__ and emit a DeprecationWarning via the warnings module
when the alias is used or on import to advise users to switch to GemmSP; also
ensure GemmSPPy is present in the module's public exports (e.g., __all__) so
downstream code importing GemmSPPy continues to work.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove unused imports (tvm, _TORCH_DTYPE_TO_STR)
- Remove print("pass") calls
- Normalize all parametrize tuples to compact aligned format
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Replace removed compress_sm90 with compress() in sparse_tensorcore example and test_compress_utils
- Fix example_gemm_sp.py: remove ARCH_INFO and make_cutlass_metadata_layout; compress() now produces natural-layout int16 metadata on all architectures
- Update matmul_sparse.md to reflect the simplified compress() API and remove the CUTLASS layout annotation requirement

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add tilelang/utils/sparse_config.py as single source of truth for SPARSE_PARAMS, E_FACTOR_MAP, E_REPLICATE_FACTOR, and get_e_factor()
- Remove inline E_FACTOR_MAP/E_REPLICATE_FACTOR from SparseTensorCoreIntrinEmitter; import from sparse_config.py with backward-compatible class aliases
- Update compress() API: remove transposed/block_k/arch params, use T.int16 metadata
- Delete example_custom_compress.py; move torch_compress reference to testing/python/utils/test_compress_utils.py with float8 .view() support
- Update examples (example_gemm_sp, tilelang_example_sparse_tensorcore) and test files to use new compress() API and get_e_factor()

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove E_FACTOR_MAP and E_REPLICATE_FACTOR dicts from sparse_config.py
- get_e_factor: derived from SPARSE_PARAMS group size and meta_dtype.bits (hardware always packs 4 bits per sparsity group)
- get_e_replicate_factor: derived from a_dtype.bits (<=8 bit → 1, else → 2)
- SparseTensorCoreIntrinEmitter uses the two functions directly instead of class-level dict aliases

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Actionable comments posted: 8
🧹 Nitpick comments (16)
src/op/builtin.h (1)

372-381: 💤 Low value

Add full signature comments for the new sparse WGMMA intrinsics.

`ptx_wgmma_ss`/`ptx_wgmma_rs` above (Lines 350-370) document the full argument list. The new `ptx_wgmma_sp_ss` (18 inputs in `builtin.cc`) and `ptx_wgmma_sp_rs` (17 inputs) only get a one-liner, so callers can't tell from this header which args were added (E descriptor, E offset, sparse selector, etc.) or why the sp_ss/sp_rs counts diverge by one. Mirroring the existing doc style would prevent guesswork at codegen sites.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/op/builtin.h` around lines 372 - 381, Add full signature-style comments for the two new intrinsics ptx_wgmma_sp_ss() and ptx_wgmma_sp_rs() mirroring the existing ptx_wgmma_ss/ptx_wgmma_rs documentation: enumerate each argument in order (types and brief meaning), explicitly document the extra sparse-related parameters (E descriptor, E offset, sparse selector) and state the input count (ptx_wgmma_sp_ss: 18 inputs; ptx_wgmma_sp_rs: 17 inputs) and why sp_ss has one more argument than sp_rs; place these comments directly above the TVM_DLL const Op &ptx_wgmma_sp_ss() and TVM_DLL const Op &ptx_wgmma_sp_rs() declarations so callers can see the full signature without inspecting builtin.cc.

src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp (1)
25-201: ⚖️ Poor tradeoff

Consolidate the four SM89 sparse fp8 MMA structs using a macro or template.

The four structs (`SM89_16x8x64_F32E4M3E4M3F32_TN`, `SM89_16x8x64_F32E4M3E5M2F32_TN`, `SM89_16x8x64_F32E5M2E4M3F32_TN`, `SM89_16x8x64_F32E5M2E5M2F32_TN`) are identical except for their fp8 operand type suffix (`e4m3.e4m3`, `e4m3.e5m2`, `e5m2.e4m3`, `e5m2.e5m2`) in the PTX instruction strings and struct name in the diagnostic message. Folding them with a small X-macro or template dispatcher would reduce duplication by ~75%, eliminate accidental drift between variants, and make adding new fp8 combinations (e.g., future scaled fp8) a single-line change.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp` around lines 25 - 201, The four near-identical structs SM89_16x8x64_F32E4M3E4M3F32_TN, SM89_16x8x64_F32E4M3E5M2F32_TN, SM89_16x8x64_F32E5M2E4M3F32_TN, and SM89_16x8x64_F32E5M2E5M2F32_TN differ only by the fp8 operand suffix in the inline-asm and the diagnostic string; replace them with a single parametric generator (either an X-macro or a template wrapper) that takes the PTX suffix string (e.g., "e4m3.e4m3", "e4m3.e5m2", etc.) and a short name fragment and emits the struct and its CUTE_INVALID_CONTROL_PATH message, then update/replace the duplicated fma implementations to use that generator so the asm literal and the invalid-path message are produced from the single parameter. Ensure the generator exposes the same struct type names (or typedefs) used elsewhere (or provide forwarding aliases) and preserves the fma signature and static_assert(spsel == SparseSel::Zero).

src/tl_templates/cuda/instruction/mma_sp.h (1)
93-101: 💤 Low value

Consider asserting that `kDRegs == kCRegs` in addition to type matching.

The static assertion confirms `DReg` and `CReg` types match, but the dispatcher then independently expands `std::make_index_sequence<kDRegs>` and `<kCRegs>` (lines 56–60 in `call_fma_sp_impl`). For accumulator-style MMA/FMA, `kDRegs` should equal `kCRegs`. A divergence would silently mis-call `Impl::fma`. A trivial extra `static_assert(Traits::kDRegs == Traits::kCRegs, ...)` would harden against future CUTE template additions.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/tl_templates/cuda/instruction/mma_sp.h` around lines 93 - 101, Add a second compile-time check to ensure the register counts match: in the same scope where the existing type equality static_assert is located (the static_assert comparing typename Traits::DReg and typename Traits::CReg inside exec), add a static_assert(Traits::kDRegs == Traits::kCRegs, "tl::mma_sp_sync requires matching accumulator/output register counts"); this will prevent mismatched std::make_index_sequence expansions in call_fma_sp_impl and ensure Traits::kDRegs and Traits::kCRegs are equal before calling call_fma_sp/Impl::fma.

tilelang/tileop/gemm_sp/__init__.py (2)
56-65: 💤 Low value

`infer_layout` and `lower` create a fresh implementation instance per call — minor inefficiency.

`impl_class(self)` is constructed twice (once for layout inference, once for lowering), and each call additionally re-runs `_select_gemm_instruction` (which crosses the FFI to the C++ `GemmSP` `GetGemmSPInst`). For a hot compilation path this is fine, but consider memoizing `gemm_inst` on `self` if compilation latency becomes a concern.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/tileop/gemm_sp/__init__.py` around lines 56 - 65, Both infer_layout and lower call _select_gemm_instruction and instantiate impl_class(self) separately, causing duplicate FFI calls and extra object creation; cache the selected gemm_inst and/or the instantiated implementation on self so both methods reuse the same value/instance: call self._select_gemm_instruction(thread_nums, target) once (e.g. store as self._cached_gemm_inst keyed by thread_nums/target or invalidate when inputs change), use self._get_implementation_class(cached_inst, target) once and keep the instantiated impl (instead of impl_class(self) twice) so infer_layout and lower call the same impl object's infer_layout and lower methods.
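The memoization the comment suggests can be sketched in isolation. Everything below is a toy model, not the TileLang API: the FFI crossing is replaced by a counter so the caching is observable, and the class, method names, and target strings are made up:

```python
class GemmSPDemo:
    """Toy operator that caches its instruction selection per (thread_nums, target)."""

    def __init__(self):
        self._inst_cache = {}
        self.ffi_calls = 0  # counts simulated crossings into C++

    def _select_gemm_instruction(self, thread_nums, target):
        key = (thread_nums, target)
        if key not in self._inst_cache:
            self.ffi_calls += 1  # stand-in for the expensive C++ call
            self._inst_cache[key] = "wgmma" if target == "sm90" else "mma"
        return self._inst_cache[key]

    def infer_layout(self, thread_nums, target):
        return ("layout", self._select_gemm_instruction(thread_nums, target))

    def lower(self, thread_nums, target):
        return ("lowered", self._select_gemm_instruction(thread_nums, target))
```

Keying the cache on the selection inputs keeps the memo correct if the same operator is compiled for more than one target.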
44-54: ⚡ Quick win

Refactor to use free function wrappers for consistency with the `gemm` module pattern.

`gemm_sp_infer_layout` and `gemm_sp_lower` are instance methods registered via `@tvm_ffi.register_global_func`, whereas the sibling `gemm` module uses explicit free function wrappers (e.g., `def gemm_infer_layout(gemm: GemmMMA, target, ...)`). The C++ side calls these identically via `(*f)(GetRef<GemmSP>(this), ...)`, so both patterns work, but instance method registration deviates from the established codebase convention. Align with the `gemm` pattern by extracting wrapper functions at module level.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/tileop/gemm_sp/__init__.py` around lines 44 - 54, Replace the instance-method registrations with module-level free-function wrappers consistent with the gemm pattern: add two top-level functions (e.g., gemm_sp_infer_layout(gemm_sp: GemmSP, target: Target, thread_bounds: Range) and gemm_sp_lower(gemm_sp: GemmSP, target: Target, layout_map: dict, thread_bounds: Range, thread_var: tir.Var)) that call the corresponding instance methods (gemm_sp.infer_layout(...) and gemm_sp.lower(...)) and register those free functions with tvm_ffi.register_global_func instead of the current instance-bound functions; keep the same argument order and behavior (extract thread_nums = thread_bounds.extent and forward to infer_layout/lower) so the C++ callers continue to work unchanged.

src/op/gemm_sp.cc (2)
99-115: 💤 Low value

Missing fallthrough return after `ICHECK(0)` may trigger compiler warnings.

`GetGemmSPInst` is non-void but has no return after the `ICHECK(0)` in the unsupported-target branch. TVM's `ICHECK` typically expands to a `LOG(FATAL)` which is `[[noreturn]]`, so this is likely safe at runtime, but some compilers still warn `-Wreturn-type` here. Consider adding an explicit `return GemmInst::kMMA;` or `LOG(FATAL)` after the chain to be defensive and silence warnings, mirroring the same issue in `Lower` (line 178) and `InferLayout` (line 192).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/op/gemm_sp.cc` around lines 99 - 115, GetGemmSPInst lacks an explicit return after the ICHECK(0) failure path which can trigger -Wreturn-type warnings; update GemmSPNode::GetGemmSPInst to add a defensive return after ICHECK(0) (e.g., return a sensible default like GemmInst::kMMA or mirror other functions’ behavior) so the function always returns a GemmInst even if the ICHECK is treated as non‑noreturn; modify the end of GetGemmSPInst to include this explicit return to silence compiler warnings.
145-180: 💤 Low value

Lower's wrapping path may lose `iter_values`/`predicate` semantics.

When the FFI `prim_func->body` is a `BlockRealize`, the existing iter_values and predicate are preserved (lines 162–163) — good. However, when it's not a `BlockRealize`, you wrap the body in a synthetic `Block` with empty `iter_vars`/`reads`/`writes` and constant‑true predicate (lines 169–176). If the FFI lowering ever returns a body that needs nontrivial reads/writes for downstream passes (e.g., access region inference, buffer compaction), the empty `reads`/`writes` arrays could regress those analyses. Consider asserting the FFI always returns a `BlockRealize`, or populating reads/writes from `T.layout_map`/access regions.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/op/gemm_sp.cc` around lines 145 - 180, The wrapper path in GemmSPNode::Lower can lose important iter_values/predicate and reads/writes info when prim_func->body is not a BlockRealize; either require/assert that the FFI always returns a BlockRealize or populate the synthetic Block's iter_vars/reads/writes/predicate from the original lowering context. Update GemmSPNode::Lower to (a) check prim_func->body and if not a BlockRealize, derive and set appropriate iter_values and predicate and fill reads/writes (using T.layout_map, T.thread_bounds/thread_var or any access-region information available from prim_func or T) instead of using empty arrays/const_true, or add a clear ICHECK/assert that the FFI must return a BlockRealize with the correct metadata so downstream passes (access region inference, buffer compaction) are not broken.

src/tl_templates/cuda/instruction/wgmma_sp.h (1)
44-48: 💤 Low value

Minor: redundant condition in static_assert for `CReg` size.

`sizeof(uint32_t) == sizeof(float) == 4` on every supported CUDA target, so the `||` is tautological. Either simplify to `sizeof(CReg) == sizeof(uint32_t)` (matching the SS variant on line 20–21) or `sizeof(CReg) == 4`.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/tl_templates/cuda/instruction/wgmma_sp.h` around lines 44 - 48, The static_assert for CReg in tl::wgmma_sp_rs is redundant because sizeof(uint32_t) == sizeof(float) on supported CUDA targets; update the check to a single clear condition (e.g., require sizeof(CReg) == sizeof(uint32_t) to match the SS variant) by replacing the current "sizeof(CReg) == sizeof(uint32_t) || sizeof(CReg) == sizeof(float)" assertion with a single-size comparison referencing CReg and the tl::wgmma_sp_rs context so the intent is unambiguous.

tilelang/intrinsics/wgmma_sp_macro_generator.py (3)
504-516: 💤 Low value

Use iterable unpacking for cleaner indexing (Ruff RUF005).

`tuple(E_other) + (E_base0 + ..., E_base1 + ...)` allocates an intermediate tuple twice; iterable unpacking is more idiomatic and slightly more efficient.

Proposed fix

```diff
 E_local_buf[e_local_base + inst_i * local_size_e + j] = (
-    E_shared_buf[tuple(E_other) + (E_base0 + wk + mk, E_base1 + wi + mi)]
+    E_shared_buf[(*E_other, E_base0 + wk + mk, E_base1 + wi + mi)]
     if trans
-    else E_shared_buf[tuple(E_other) + (E_base0 + wi + mi, E_base1 + wk + mk)]
+    else E_shared_buf[(*E_other, E_base0 + wi + mi, E_base1 + wk + mk)]
 )
```
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 504 - 516, The code creates temporary tuples by doing tuple(E_other) + (E_base0 + wk + mk, E_base1 + wi + mi) when indexing E_shared_buf; change both branches of the conditional in the assignment to E_local_buf[...] to use iterable unpacking instead (e.g., E_shared_buf[*E_other, E_base0 + wk + mk, E_base1 + wi + mi] and E_shared_buf[*E_other, E_base0 + wi + mi, E_base1 + wk + mk]) so you avoid allocating intermediate tuples; update the two places inside the loop that reference E_other/E_shared_buf (the assignment to E_local_buf in the for j loop) and keep the surrounding logic and return of _warp_ldmatrix_e(E_local_buf, E_buf, inst_i, ki, thread_binding) intact.
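The equivalence behind the RUF005 suggestion is easy to check in isolation; the index values below are made up for illustration:

```python
E_other = (0, 2)          # leading index components, already a tuple
E_base0, E_base1 = 4, 8   # hypothetical base offsets
wk, mk, wi, mi = 1, 2, 3, 4

# Concatenation allocates an intermediate tuple before the final one...
concat_index = tuple(E_other) + (E_base0 + wk + mk, E_base1 + wi + mi)
# ...while unpacking builds the full index tuple in one step.
unpacked_index = (*E_other, E_base0 + wk + mk, E_base1 + wi + mi)
```

Both forms produce the same tuple, so the buffer subscript they feed is unchanged; only the construction differs.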
237-243: 💤 Low value

Unused variable `tx` flagged by Ruff (RUF059).

`tx` is unpacked but never referenced in the `_warp_mma` macro for `wgmma_ss`. Rename to `_tx` to silence the lint and document intent.

Proposed fix

```diff
-        tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+        _tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
```
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 237 - 243, In the _warp_mma macro, the local unpacking uses tx but that variable is unused; change the unpacked name from tx to _tx to satisfy the linter and indicate intentional unused value (e.g., replace "tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)" with "_tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)" inside the _warp_mma definition), leaving all other logic and variables (k_blocks, e_stage_elems, E_local, etc.) unchanged.
446-460: 💤 Low value

`ldmatrix_available = False` looks like an unfinished TODO worth tracking.

The hardcoded `ldmatrix_available = False` plus the comment "TODO: use ldmatrix when possible" means every metadata load currently goes through the slower per-element path even when ldmatrix would be valid. This is a perf opportunity, not a correctness bug — consider filing a follow-up issue or `# FIXME(name)` so it isn't lost.

Want me to open a follow-up issue tracking this performance TODO?
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 446 - 460, The ldmatrix_available flag in ldmatrix_e is hardcoded False which leaves a TODO untracked; update ldmatrix_e to either compute ldmatrix_available from the current conditions (e.g., based on a_dtype/e_dtype and transposed state) or, if you cannot implement the fast path now, replace the hardcoded assignment with a clear tracked marker (e.g., set ldmatrix_available = False and add a FIXME(author_name) comment) and create a follow-up issue referencing ldmatrix_e and ldmatrix_available so the optimization isn't lost; ensure the change mentions the constraint (int8 + transposed case) and the location (ldmatrix_e in wgmma_sp_macro_generator.py) so future work can implement the fast path.

tilelang/tileop/gemm_sp/gemm_sp_wgmma.py (2)
138-146: 💤 Low value

Both branches return; trailing `raise` is unreachable for `ss`/`rs` and only triggered for `sr`/`rr` — consider an explicit early check.

`is_gemm_ss()` and `is_gemm_rs()` already cover the supported configurations and both branches `return`. The trailing `raise ValueError(...)` only fires for `sr`/`rr` cases, which is fine, but mirroring the structure of `infer_layout` (which uses an explicit `else: raise`) would make the control flow uniform and easier to follow.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/tileop/gemm_sp/gemm_sp_wgmma.py` around lines 138 - 146, The trailing raise ValueError is only reachable for unsupported gemm types (sr/rr) but the current structure returns inside the is_gemm_ss() and is_gemm_rs() branches, making control flow less explicit; update the logic in the function that defines _gemm_ssr/_gemm_rsr to first check supported cases (e.g., if not (self.is_gemm_ss() or self.is_gemm_rs()): raise ValueError(...)) and then use explicit if/elif for self.is_gemm_ss() and self.is_gemm_rs() to return _Simplify(_gemm_ssr, inline_let=True) or _Simplify(_gemm_rsr, inline_let=True) respectively, mirroring infer_layout’s else: raise pattern and keeping symbols _gemm_ssr, _gemm_rsr, _Simplify, is_gemm_ss, is_gemm_rs unchanged.
73-80: 💤 Low value

Type hint and unused parameter on `lower(...)`.

`thread_nums: Range` is misleading: the caller in `tilelang/tileop/gemm_sp/__init__.py` passes `thread_bounds.extent` (an `int`/`PrimExpr`), not a `Range`. Consider `thread_nums: int | tir.PrimExpr`. `mbar_phase_expr: tir.PrimExpr | None = None` is declared but never referenced in the body. Either drop it or thread it through.

Proposed fix

```diff
-    def lower(
-        self,
-        layout_map: dict,
-        target: Target,
-        thread_nums: Range,
-        thread_var: tir.Var,
-        mbar_phase_expr: tir.PrimExpr | None = None,
-    ):
+    def lower(
+        self,
+        layout_map: dict,
+        target: Target,
+        thread_nums: int | tir.PrimExpr,
+        thread_var: tir.Var,
+    ):
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/tileop/gemm_sp/gemm_sp_wgmma.py` around lines 73 - 80, The lower method's signature is incorrect: change the type hint for thread_nums in gemm_sp_wgmma.py lower(self, ...) from Range to int | tir.PrimExpr (or tir.PrimExpr) to match callers that pass thread_bounds.extent, and either remove the unused parameter mbar_phase_expr or propagate it through the function (use it where barriers/phases are computed) so it is referenced; update the signature and all internal references to use the new thread_nums type and delete mbar_phase_expr if you choose to drop it, ensuring callers (e.g., in tilelang/tileop/gemm_sp/__init__.py) remain compatible.

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (2)
958-973: 💤 Low value

Inconsistent signature: `run_gemm_rr` introduces defaults that the sibling runners deliberately require.

`run_gemm_ss` (line 66), `run_gemm_rs` (line 375), and `run_gemm_sr` (line 664) all take `num_stages`, `num_threads`, and `meta_dtype` as required positional arguments, while `run_gemm_rr` defaults them to `3`, `128`, and `T.int16`. All current callers pass these explicitly via `pytest.mark.parametrize`, so behavior is fine today, but the asymmetry is a maintenance hazard: a future caller that forgets to pass `meta_dtype` will silently fall through to `T.int16` instead of failing fast.

♻️ Align signature with the other runners

```diff
-def run_gemm_rr(
-    M,
-    N,
-    K,
-    trans_A,
-    trans_B,
-    in_dtype,
-    out_dtype,
-    dtypeAccum,
-    block_M,
-    block_N,
-    block_K,
-    num_stages=3,
-    num_threads=128,
-    meta_dtype=T.int16,
-):
+def run_gemm_rr(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages,
+    num_threads,
+    meta_dtype,
+):
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py` around lines 958 - 973, The function run_gemm_rr currently defines defaults for num_stages, num_threads, and meta_dtype which is inconsistent with sibling runners (run_gemm_ss, run_gemm_rs, run_gemm_sr) and can mask caller mistakes; update run_gemm_rr's signature to require num_stages, num_threads, and meta_dtype as positional parameters (remove the defaults =3, =128, =T.int16) so callers must pass them explicitly, mirroring the other runner functions and preventing silent fallback.
311-336: 💤 Low value

Drop the redundant nested `import tilelang.language as T`.

Module-level line 9 already exposes `T`; the per-kernel re-imports inside `matmul_rs` (line 336), `matmul_sr` (line 625), and `matmul_rr` (line 916) shadow it with the same binding and are dead noise.

♻️ Remove the duplicate import (one example, apply to all three)

```diff
     E_shared_shape = (block_M, block_K // E_factor) if not trans_A else (block_K // E_factor, block_M)
-
-    import tilelang.language as T

     @T.prim_func
```

Also applies to: 600-625, 890-916
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py` around lines 311 - 336, Remove the redundant nested import statement "import tilelang.language as T" that appears inside the matmul_rs function (and similarly inside matmul_sr and matmul_rr); since T is already imported at module level, delete the per-function re-imports so the functions use the module-level T binding and avoid shadowing/redundant imports.

src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp (1)
345-380: ⚡ Quick win

Collapse the `SparseSel::One` integer specializations via `static_assert` to remove duplication.

For the four integer structs (`SM80_16x8x32_S32S8S8S32_TN`, `SM80_16x8x64_S32S8S8S32_TN`, `SM80_16x8x32_S32U8U8S32_TN`, `SM80_16x8x64_S32U8U8S32_TN`), the `SparseSel::One` specialization duplicates ~15 register aliases plus a 13-parameter `fma` signature solely to call `CUTE_INVALID_CONTROL_PATH`. Since there are no actual instantiations of these integer types with `SparseSel::One` in the codebase, a `static_assert` in the primary template provides the same safety with better compile-time diagnostics and eliminates the four specialization blocks entirely.

Additionally, the named parameters in those specializations are unused and generate `-Wunused-parameter` warnings.

♻️ Option A — replace specialization with static_assert in primary template (preferred)

```diff
 template <SparseSel spsel = SparseSel::Zero> struct SM80_16x8x32_S32S8S8S32_TN {
   using DRegisters = uint32_t[4];
   using ARegisters = uint32_t[2];
   using BRegisters = uint32_t[2];
   using CRegisters = uint32_t[4];
   using ERegisters = uint32_t[1];

   CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2,
                                    uint32_t &d3, uint32_t const &a0,
                                    uint32_t const &a1, uint32_t const &b0,
                                    uint32_t const &b1, uint32_t const &c0,
                                    uint32_t const &c1, uint32_t const &c2,
                                    uint32_t const &c3, uint32_t const &e) {
+    static_assert(spsel == SparseSel::Zero,
+                  "Integer sparse MMA only supports SparseSel::Zero");
 #if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED)
     ...
-template <> struct SM80_16x8x32_S32S8S8S32_TN<SparseSel::One> {
-  using DRegisters = uint32_t[4];
-  ...
-  CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2,
-                                   ...
-                                   uint32_t const &c3, uint32_t const &e) {
-    CUTE_INVALID_CONTROL_PATH(
-        "SM80_16x8x32_S32S8S8S32_TN with SparseSel::One is invalid");
-  }
-};
```

Apply analogously to the other three integer structs (lines 442, 500, 561).

♻️ Option B — keep the specializations but silence unused-parameter warnings

```diff
-  CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2,
-                                   uint32_t &d3, uint32_t const &a0,
-                                   uint32_t const &a1, uint32_t const &b0,
-                                   uint32_t const &b1, uint32_t const &c0,
-                                   uint32_t const &c1, uint32_t const &c2,
-                                   uint32_t const &c3, uint32_t const &e) {
+  CUTE_HOST_DEVICE static void fma(uint32_t &, uint32_t &, uint32_t &,
+                                   uint32_t &, uint32_t const &,
+                                   uint32_t const &, uint32_t const &,
+                                   uint32_t const &, uint32_t const &,
+                                   uint32_t const &, uint32_t const &,
+                                   uint32_t const &, uint32_t const &) {
     CUTE_INVALID_CONTROL_PATH(
         "SM80_16x8x32_S32S8S8S32_TN with SparseSel::One is invalid");
   }
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp` around lines 345 - 380, The four integer MMA templates duplicate a SparseSel::One specialization that only calls CUTE_INVALID_CONTROL_PATH and causes unused-parameter warnings; replace those specializations by adding a static_assert in each primary template (e.g., inside SM80_16x8x32_S32S8S8S32_TN, SM80_16x8x64_S32S8S8S32_TN, SM80_16x8x32_S32U8U8S32_TN, SM80_16x8x64_S32U8U8S32_TN) that forbids SparseSel::One (e.g., static_assert(spsel != SparseSel::One, "…")) to provide the compile-time diagnostic and then remove the corresponding SparseSel::One specialization blocks (which duplicate many register typedefs and the long fma(...) signature); if you prefer to keep the specializations instead, silence -Wunused-parameter for the fma parameters in those specialization definitions.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/target/codegen_cuda.cc`:
- Around line 2769-2771: In the sparse WGMMA code paths (ptx_wgmma_ss /
ptx_wgmma_rs) the accumulator offset is being added after casting to uint32_t*,
which treats C_offset as 32-bit words instead of accumulator elements; change
the string construction so C_offset is added to C_data before the
reinterpret_cast (i.e., apply (C_data) + (C_offset) inside the cast operand) for
the occurrences matching the snippet "reinterpret_cast<uint32_t*>((C_data)) +
(C_offset), (scale_out), *reinterpret_cast<uint32_t*>((e_data) + (E_offset))"
(also update the other occurrence around lines ~2838-2839) so the offset uses
accumulator element units prior to the uint32_t* conversion.
- Around line 2523-2534: The current code reads meta_tvm_dtype from
op->args[12]->dtype (the metadata pointer) which yields the handle/pointer
dtype; instead derive the element dtype of the metadata buffer (the element type
E of the metadata array) and base MetaType on that element dtype, then map
DataType::UInt(8/16/32/64) to "uint8_t"/"uint16_t"/"uint32_t"/"uint64_t" into
the MetaType string; update the logic that sets meta_tvm_dtype (used when
computing MetaType and currently named meta_tvm_dtype / MetaType) to inspect the
buffer/element dtype of op->args[12] rather than the pointer/handle dtype so
tl::mma_sp_sync is instantiated with the correct metadata type.
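The dtype-to-C-type mapping that comment describes can be sketched in Python (the real fix lives in C++ in codegen_cuda.cc; the dtype spellings and function name below are assumptions for illustration):

```python
# Sketch of mapping a metadata *element* dtype (not the pointer/handle
# dtype) to the C type name spliced into the tl::mma_sp_sync template
# instantiation string.
_META_TYPE_NAMES = {
    "uint8": "uint8_t",
    "uint16": "uint16_t",
    "uint32": "uint32_t",
    "uint64": "uint64_t",
}

def meta_type_name(element_dtype: str) -> str:
    try:
        return _META_TYPE_NAMES[element_dtype]
    except KeyError:
        raise ValueError(f"unsupported metadata dtype: {element_dtype}") from None

assert meta_type_name("uint16") == "uint16_t"
```

The key point is that the lookup is driven by the element dtype of the metadata buffer, so a handle/pointer dtype never reaches the mapping.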
In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py`:
- Line 60: The T.gemm_sp calls incorrectly pass trans_A as the 7th positional
argument (which binds to the transpose_E parameter); update each invocation of
T.gemm_sp (the calls in this file including the one at the shown location) to
supply the correct transpose_E value explicitly—either by replacing the 7th
positional argument with the intended boolean (e.g., False) or by using a named
argument transpose_E=False—to ensure the transpose_E parameter is set correctly
instead of reusing trans_A.
In `@tilelang/intrinsics/wgmma_sp_macro_generator.py`:
- Around line 220-235: Replace the debug print(...) calls in
wgmma_sp_macro_generator.py (the prints that emit elems_in_bytes,
self.block_row_warps, self.block_col_warps, warp_row_tiles, warp_col_tiles,
warp_k, warp_rows, warp_cols, M_DIM, n_dim, k_dim, wgmma_prefix, micro_size_x,
micro_size_y, micro_size_k, a_swizzle_mode, b_swizzle_mode,
a_leading_byte_offset, a_stride_byte_offset, b_leading_byte_offset,
b_stride_byte_offset, ak_atom_size, bk_atom_size, num_inst_m, num_inst_n,
accum_regs, self.e_dtype, self.SPARSE_FACTOR, self.SPARSE_SELECTOR,
a_swizzle_atom_elems, b_swizzle_atom_elems) with proper logging calls (e.g.,
logger.debug(...)) or remove them; ensure you reference the module/class logger
(create one via logging.getLogger(__name__) if none exists) so these messages
respect log levels and do not pollute stdout during normal compilation.
In `@tilelang/language/ast/ir.py`:
- Around line 1887-1888: The current references to ptx_wgmma_sp_ss and
ptx_wgmma_sp_rs in tilelang/language/ast/ir.py use _tir_op from upstream tvm.tir
which doesn't expose those symbols; change the import so _tir_op is the local
module that defines them (mirror the pattern in tilelang/language/tir/ir.py by
importing tilelang.language.tir.op as _tir_op) so ptx_wgmma_sp_ss and
ptx_wgmma_sp_rs resolve to the local implementations.
In `@tilelang/tileop/gemm_sp/__init__.py`:
- Around line 44-65: Remove the stray debug print statements that currently log
inside gemm_sp_infer_layout (the print showing type(self.A)) and inside lower
(the print of gemm_inst, impl_class, target, layout_map, thread_nums,
thread_var); either delete those print(...) calls or replace them with a proper
debug-level logger call (e.g., use an existing logging facility or tvm's
logging) so they don't print to stdout in production, ensuring the rest of the
methods (gemm_sp_infer_layout, lower, infer_layout, and the calls to
impl_class(...).infer_layout/lower) remain unchanged.
In `@tilelang/utils/sparse.py`:
- Around line 39-40: Remove the two leftover debug print statements inside the
_compress_fn function in sparse.py so they don't run on every `@tilelang.jit`
invocation; locate the prints that output f"{D=} {elem=} ..." and f"{[S, D *
elem // group]=}..." and delete them (or replace with a debug-level logger call
if telemetry is required) to avoid polluting stdout in user code.
- Around line 86-103: Callsites still use removed helpers and deprecated
parameters (compress_sm90, and compress(..., transposed=, arch=, block_k=));
update every invocation to call the new compress(A, meta_dtype=None) signature:
replace compress_sm90(...) with compress(...) and remove deprecated keyword args
(transposed, arch, block_k), pass only the tensor A and optionally meta_dtype
(converted via dtype string helpers if needed), and update imports to import
compress from tilelang.utils.sparse instead of compress_sm90; verify tests and
example calls (those that previously passed meta/arch/block info) supply
meta_dtype when required.
---
Nitpick comments:
In `@src/op/builtin.h`:
- Around line 372-381: Add full signature-style comments for the two new
intrinsics ptx_wgmma_sp_ss() and ptx_wgmma_sp_rs() mirroring the existing
ptx_wgmma_ss/ptx_wgmma_rs documentation: enumerate each argument in order (types
and brief meaning), explicitly document the extra sparse-related parameters (E
descriptor, E offset, sparse selector) and state the input count
(ptx_wgmma_sp_ss: 18 inputs; ptx_wgmma_sp_rs: 17 inputs) and why sp_ss has one
more argument than sp_rs; place these comments directly above the TVM_DLL const
Op &ptx_wgmma_sp_ss() and TVM_DLL const Op &ptx_wgmma_sp_rs() declarations so
callers can see the full signature without inspecting builtin.cc.
In `@src/op/gemm_sp.cc`:
- Around line 99-115: GetGemmSPInst lacks an explicit return after the ICHECK(0)
failure path which can trigger -Wreturn-type warnings; update
GemmSPNode::GetGemmSPInst to add a defensive return after ICHECK(0) (e.g.,
return a sensible default like GemmInst::kMMA or mirror other functions’
behavior) so the function always returns a GemmInst even if the ICHECK is
treated as non‑noreturn; modify the end of GetGemmSPInst to include this
explicit return to silence compiler warnings.
- Around line 145-180: The wrapper path in GemmSPNode::Lower can lose important
iter_values/predicate and reads/writes info when prim_func->body is not a
BlockRealize; either require/assert that the FFI always returns a BlockRealize
or populate the synthetic Block's iter_vars/reads/writes/predicate from the
original lowering context. Update GemmSPNode::Lower to (a) check prim_func->body
and if not a BlockRealize, derive and set appropriate iter_values and predicate
and fill reads/writes (using T.layout_map, T.thread_bounds/thread_var or any
access-region information available from prim_func or T) instead of using empty
arrays/const_true, or add a clear ICHECK/assert that the FFI must return a
BlockRealize with the correct metadata so downstream passes (access region
inference, buffer compaction) are not broken.
In `@src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp`:
- Around line 345-380: The four integer MMA templates duplicate a SparseSel::One
specialization that only calls CUTE_INVALID_CONTROL_PATH and causes
unused-parameter warnings; replace those specializations by adding a
static_assert in each primary template (e.g., inside SM80_16x8x32_S32S8S8S32_TN,
SM80_16x8x64_S32S8S8S32_TN, SM80_16x8x32_S32U8U8S32_TN,
SM80_16x8x64_S32U8U8S32_TN) that forbids SparseSel::One (e.g.,
static_assert(spsel != SparseSel::One, "…")) to provide the compile-time
diagnostic and then remove the corresponding SparseSel::One specialization
blocks (which duplicate many register typedefs and the long fma(...) signature);
if you prefer to keep the specializations instead, silence -Wunused-parameter
for the fma parameters in those specialization definitions.
In `@src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp`:
- Around line 25-201: The four near-identical structs
SM89_16x8x64_F32E4M3E4M3F32_TN, SM89_16x8x64_F32E4M3E5M2F32_TN,
SM89_16x8x64_F32E5M2E4M3F32_TN, and SM89_16x8x64_F32E5M2E5M2F32_TN differ only
by the fp8 operand suffix in the inline-asm and the diagnostic string; replace
them with a single parametric generator (either an X-macro or a template
wrapper) that takes the PTX suffix string (e.g., "e4m3.e4m3", "e4m3.e5m2", etc.)
and a short name fragment and emits the struct and its CUTE_INVALID_CONTROL_PATH
message, then update/replace the duplicated fma implementations to use that
generator so the asm literal and the invalid-path message are produced from the
single parameter. Ensure the generator exposes the same struct type names (or
typedefs) used elsewhere (or provide forwarding aliases) and preserves the fma
signature and static_assert(spsel == SparseSel::Zero).
In `@src/tl_templates/cuda/instruction/mma_sp.h`:
- Around line 93-101: Add a second compile-time check to ensure the register
counts match: in the same scope where the existing type equality static_assert
is located (the static_assert comparing typename Traits::DReg and typename
Traits::CReg inside exec), add a static_assert(Traits::kDRegs == Traits::kCRegs,
"tl::mma_sp_sync requires matching accumulator/output register counts"); this
will prevent mismatched std::make_index_sequence expansions in call_fma_sp_impl
and ensure Traits::kDRegs and Traits::kCRegs are equal before calling
call_fma_sp/Impl::fma.
In `@src/tl_templates/cuda/instruction/wgmma_sp.h`:
- Around line 44-48: The static_assert for CReg in tl::wgmma_sp_rs is redundant
because sizeof(uint32_t) == sizeof(float) on supported CUDA targets; update the
check to a single clear condition (e.g., require sizeof(CReg) ==
sizeof(uint32_t) to match the SS variant) by replacing the current "sizeof(CReg)
== sizeof(uint32_t) || sizeof(CReg) == sizeof(float)" assertion with a
single-size comparison referencing CReg and the tl::wgmma_sp_rs context so the
intent is unambiguous.
In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py`:
- Around line 958-973: The function run_gemm_rr currently defines defaults for
num_stages, num_threads, and meta_dtype which is inconsistent with sibling
runners (run_gemm_ss, run_gemm_rs, run_gemm_sr) and can mask caller mistakes;
update run_gemm_rr's signature to require num_stages, num_threads, and
meta_dtype as positional parameters (remove the defaults =3, =128, =T.int16) so
callers must pass them explicitly, mirroring the other runner functions and
preventing silent fallback.
- Around line 311-336: Remove the redundant nested import statement "import
tilelang.language as T" that appears inside the matmul_rs function (and
similarly inside matmul_sr and matmul_rr); since T is already imported at module
level, delete the per-function re-imports so the functions use the module-level
T binding and avoid shadowing/redundant imports.
In `@tilelang/intrinsics/wgmma_sp_macro_generator.py`:
- Around line 504-516: The code creates temporary tuples by doing tuple(E_other)
+ (E_base0 + wk + mk, E_base1 + wi + mi) when indexing E_shared_buf; change both
branches of the conditional in the assignment to E_local_buf[...] to use
iterable unpacking instead (e.g., E_shared_buf[*E_other, E_base0 + wk + mk,
E_base1 + wi + mi] and E_shared_buf[*E_other, E_base0 + wi + mi, E_base1 + wk +
mk]) so you avoid allocating intermediate tuples; update the two places inside
the loop that reference E_other/E_shared_buf (the assignment to E_local_buf in
the for j loop) and keep the surrounding logic and return of
_warp_ldmatrix_e(E_local_buf, E_buf, inst_i, ki, thread_binding) intact.
- Around line 237-243: In the _warp_mma macro, the local unpacking uses tx but
that variable is unused; change the unpacked name from tx to _tx to satisfy the
linter and indicate intentional unused value (e.g., replace "tx, warp_n, warp_m
= self.extract_thread_binding(thread_binding)" with "_tx, warp_n, warp_m =
self.extract_thread_binding(thread_binding)" inside the _warp_mma definition),
leaving all other logic and variables (k_blocks, e_stage_elems, E_local, etc.)
unchanged.
- Around line 446-460: The ldmatrix_available flag in ldmatrix_e is hardcoded
False which leaves a TODO untracked; update ldmatrix_e to either compute
ldmatrix_available from the current conditions (e.g., based on a_dtype/e_dtype
and transposed state) or, if you cannot implement the fast path now, replace the
hardcoded assignment with a clear tracked marker (e.g., set ldmatrix_available =
False and add a FIXME(author_name) comment) and create a follow-up issue
referencing ldmatrix_e and ldmatrix_available so the optimization isn't lost;
ensure the change mentions the constraint (int8 + transposed case) and the
location (ldmatrix_e in wgmma_sp_macro_generator.py) so future work can
implement the fast path.
In `@tilelang/tileop/gemm_sp/__init__.py`:
- Around line 56-65: Both infer_layout and lower call _select_gemm_instruction
and instantiate impl_class(self) separately, causing duplicate FFI calls and
extra object creation; cache the selected gemm_inst and/or the instantiated
implementation on self so both methods reuse the same value/instance: call
self._select_gemm_instruction(thread_nums, target) once (e.g. store as
self._cached_gemm_inst keyed by thread_nums/target or invalidate when inputs
change), use self._get_implementation_class(cached_inst, target) once and keep
the instantiated impl (instead of impl_class(self) twice) so infer_layout and
lower call the same impl object's infer_layout and lower methods.
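The caching idea can be sketched as instance-level memoization keyed on the selection inputs; the class, method names, and counter below are illustrative, not the real GemmSP API:

```python
# Memoize the instruction selection (standing in for the FFI call) so
# infer_layout and lower reuse one result instead of each re-selecting.
class GemmSP:
    def __init__(self):
        self._inst_cache = {}
        self.ffi_calls = 0  # instrumentation for this example only

    def _select_gemm_instruction(self, thread_nums, target):
        key = (thread_nums, target)
        if key not in self._inst_cache:
            self.ffi_calls += 1  # the expensive FFI call would happen here
            self._inst_cache[key] = f"inst({target},{thread_nums})"
        return self._inst_cache[key]

    def infer_layout(self, target, thread_nums):
        return self._select_gemm_instruction(thread_nums, target)

    def lower(self, target, thread_nums):
        return self._select_gemm_instruction(thread_nums, target)

op = GemmSP()
op.infer_layout("cuda", 128)
op.lower("cuda", 128)
assert op.ffi_calls == 1  # selection ran once, second call hit the cache
```

Keying on `(thread_nums, target)` keeps the cache correct if either input changes between calls, which is the invalidation concern the comment raises.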
- Around line 44-54: Replace the instance-method registrations with module-level
free-function wrappers consistent with the gemm pattern: add two top-level
functions (e.g., gemm_sp_infer_layout(gemm_sp: GemmSP, target: Target,
thread_bounds: Range) and gemm_sp_lower(gemm_sp: GemmSP, target: Target,
layout_map: dict, thread_bounds: Range, thread_var: tir.Var)) that call the
corresponding instance methods (gemm_sp.infer_layout(...) and
gemm_sp.lower(...)) and register those free functions with
tvm_ffi.register_global_func instead of the current instance-bound functions;
keep the same argument order and behavior (extract thread_nums =
thread_bounds.extent and forward to infer_layout/lower) so the C++ callers
continue to work unchanged.
In `@tilelang/tileop/gemm_sp/gemm_sp_wgmma.py`:
- Around line 138-146: The trailing raise ValueError is only reachable for
unsupported gemm types (sr/rr) but the current structure returns inside the
is_gemm_ss() and is_gemm_rs() branches, making control flow less explicit;
update the logic in the function that defines _gemm_ssr/_gemm_rsr to first check
supported cases (e.g., if not (self.is_gemm_ss() or self.is_gemm_rs()): raise
ValueError(...)) and then use explicit if/elif for self.is_gemm_ss() and
self.is_gemm_rs() to return _Simplify(_gemm_ssr, inline_let=True) or
_Simplify(_gemm_rsr, inline_let=True) respectively, mirroring infer_layout’s
else: raise pattern and keeping symbols _gemm_ssr, _gemm_rsr, _Simplify,
is_gemm_ss, is_gemm_rs unchanged.
- Around line 73-80: The lower method's signature is incorrect: change the type
hint for thread_nums in gemm_sp_wgmma.py lower(self, ...) from Range to int |
tir.PrimExpr (or tir.PrimExpr) to match callers that pass thread_bounds.extent,
and either remove the unused parameter mbar_phase_expr or propagate it through
the function (use it where barriers/phases are computed) so it is referenced;
update the signature and all internal references to use the new thread_nums type
and delete mbar_phase_expr if you choose to drop it, ensuring callers (e.g., in
tilelang/tileop/gemm_sp/__init__.py) remain compatible.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b1de2664-45b9-438b-9341-46deb38c38c9
📒 Files selected for processing (21)
src/op/builtin.cc
src/op/builtin.h
src/op/gemm_sp.cc
src/op/gemm_sp.h
src/target/codegen_cuda.cc
src/target/codegen_cuda.h
src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp
src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp
src/tl_templates/cuda/instruction/mma_sp.h
src/tl_templates/cuda/instruction/wgmma_sp.h
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
tilelang/intrinsics/mma_sp_layout.py
tilelang/intrinsics/mma_sp_macro_generator.py
tilelang/intrinsics/wgmma_sp_macro_generator.py
tilelang/language/ast/ir.py
tilelang/language/tir/ir.py
tilelang/language/tir/op.py
tilelang/tileop/gemm_sp/__init__.py
tilelang/tileop/gemm_sp/gemm_sp_mma.py
tilelang/tileop/gemm_sp/gemm_sp_wgmma.py
tilelang/utils/sparse.py
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The constants (SPARSE_PARAMS, get_e_factor, get_e_replicate_factor) are hardware-level parameters primarily consumed by mma_sp_macro_generator; intrinsics/ is the natural home. utils/sparse.py re-exports them for backward compatibility. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
"params" better reflects the content: hardware sparsity format parameters (SPARSE_PARAMS, get_e_factor, get_e_replicate_factor), not runtime config. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
from __future__ import annotations makes all annotations lazy strings, which breaks TIR prim_func's get_type_hints() on the inner kernel. Use Optional[torch.dtype] instead of the union syntax to avoid needing it. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Actionable comments posted: 5
🧹 Nitpick comments (3)
tilelang/intrinsics/wgmma_sp_macro_generator.py (3)
89-102: 💤 Low value

Unused `n_dim` parameter in `_initialize_wgmma_prefix`.

The parameter `n_dim` is accepted (and passed from `__init__` as `self.n_dim` at Line 79) but never read — `inst_n` is computed from `self.warp_col_tiles`. Either remove the parameter or actually drive `inst_n` from it; the current shape is misleading and risks future callers assuming the argument has effect.

Proposed change

```diff
-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
         inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
```

…and update the call site at Line 79 accordingly.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 89 - 102, The method _initialize_wgmma_prefix currently ignores its n_dim parameter (inst_n is computed from self.warp_col_tiles), so either remove the unused parameter and stop passing self.n_dim at the call site in __init__, or make n_dim actually drive inst_n (e.g., replace inst_n = gcd(self.warp_col_tiles, 256) with an expression using n_dim such as inst_n = gcd(n_dim, 256) or inst_n = n_dim followed by the same validation/asserts); update the __init__ call accordingly to match the chosen approach and keep the existing validations for inst_n in _initialize_wgmma_prefix.
38-39: 💤 Low value

Prefer `Optional[Layout]` and instance-level initialization.

`a_shared_layout: Layout = None` / `b_shared_layout: Layout = None` are declared at class scope with non-`Optional` type annotations bound to `None`. Type checkers will flag the mismatch, and the attributes effectively become shared class state until an instance assigns them. Initializing them in `__init__` (or annotating as `Optional[Layout] = None` and assigning in `__init__`) avoids both pitfalls.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 38 - 39, The attributes a_shared_layout and b_shared_layout are declared at class scope with non-Optional type annotations but set to None, creating a type mismatch and shared class state; update the class to either annotate them as Optional[Layout] = None or (preferably) remove the class-level assignments and initialize self.a_shared_layout and self.b_shared_layout inside the class __init__ method (use the Layout type for annotations on instance attributes) so each instance gets its own Layout fields and the type checker is satisfied.
437-469: 🏗️ Heavy lift

Track the `ldmatrix` TODO.

The `ldmatrix_available = False  # TODO: use ldmatrix when possible` shortcut forces the slow per-thread elementwise metadata load path for every sparse WGMMA call. Worth filing a follow-up so this doesn't get lost — sparse WGMMA E-loads are on the hot path for every K-block.

Want me to open a tracking issue with a short summary of the supported `(e_dtype, a_dtype)` permutations and the CUTLASS reference (include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h) for whoever picks this up?

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 437 - 469, Open a tracking issue for the TODO that hardcodes ldmatrix_available = False in wgmma_sp_macro_generator (variable name ldmatrix_available) so we don't lose this perf regresssion; in the issue include a short summary of the supported (e_dtype, a_dtype) permutations (matching the metadata_* helper layouts used here: metadata_8bit_load_32x4_to_shared_16x8_layout_8bit, metadata_8bit_load_32x4_to_shared_16x4_layout_16bit, metadata_8bit_load_32x4_to_shared_16x4_layout_32bit, metadata_16bit_load_32x2_to_shared_16x4_layout_8bit, metadata_16bit_load_32x2_to_shared_16x2_layout_16bit, metadata_16bit_load_32x2_to_shared_16x2_layout_32bit, metadata_32bit_load_32x1_to_shared_16x2_layout_8bit) and reference the CUTLASS header include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h; add the issue number back into the source comment next to the TODO so future authors can track progress.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/deeplearning_operators/matmul_sparse.md`:
- Around line 67-69: The code uses the non-existent
SparseTensorCoreIntrinEmitter.E_FACTOR_MAP to obtain E_factor; replace that with
a call to get_e_factor(in_dtype, metadata_dtype) and add an import for
get_e_factor from tilelang.intrinsics.sparse_params at the top of the example;
specifically change the assignment where E_factor is computed (currently
referencing E_FACTOR_MAP) to call get_e_factor with the local variables in_dtype
and metadata_dtype so the example uses the correct API.
In `@testing/python/utils/test_compress_utils.py`:
- Around line 135-136: The ad-hoc main entry calls the pytest-parametrized test
function test_compress() without providing required parameters (dtype,
meta_dtype), causing a TypeError when run directly; remove the "__main__" block
or replace it with a pytest invocation (e.g., call pytest.main([...])) so the
file is executed by pytest rather than directly calling test_compress, and
ensure test_compress is only run through pytest's runner; reference the test
function name test_compress to locate the faulty block.
- Around line 49-60: The call to T.gemm_sp currently passes transpose flags
positionally as (trans_A, trans_B, trans_A), which is error-prone; update the
call to use named parameters so the mapping is explicit (e.g.,
transpose_A=trans_A, transpose_B=trans_B, transpose_E=trans_A) in the T.gemm_sp
invocation to match the copy logic for A and E and improve readability and
maintainability.
In `@tilelang/utils/sparse.py`:
- Around line 63-68: The code writes nz_idx[nz_count] before checking the
device_assert, which can cause out-of-bounds writes; change the loop around
nz_count/nz_idx (the T.if_then_else stores that reference, the
dense_local[local_idx + i] check, and the T.device_assert) so the store is
guarded: only write nz_idx when nz_count < elem (e.g., use a conditional that
writes to nz_idx[min(nz_count, elem-1)] or wrap the store in an if/conditional
that checks nz_count < elem), and separately increment an overflow counter if
you still want to assert on excess nonzeros; apply the same fix to the second
occurrence referenced around lines 135-140 so both loops avoid indexing past
nz_idx.
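In plain Python, the guarded-store pattern the comment describes might look like this (illustrative names; the real code uses TileLang's T.if_then_else and T.device_assert on device):

```python
def collect_nonzero_indices(dense_local, elem):
    """Collect up to `elem` nonzero indices; count overflow instead of
    writing past the end of nz_idx."""
    nz_idx = [0] * elem
    nz_count = 0
    overflow = 0
    for i, v in enumerate(dense_local):
        if v != 0:
            if nz_count < elem:   # guard the store: never index past elem-1
                nz_idx[nz_count] = i
                nz_count += 1
            else:
                overflow += 1     # assert on this afterwards if desired
    return nz_idx, nz_count, overflow
```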
- Around line 19-20: The default meta dtype is set to T.int16 which diverges
from torch_compress's behavior (torch treats int8 metadata as int32); change the
module-wide constant _DEFAULT_META_DTYPE from T.int16 to T.int32 and update any
fallback logic in the compress-related code paths (see references to
_DEFAULT_META_DTYPE and the compress function blocks around lines ~189-191 and
~233-243) so that omissions of meta_dtype use T.int32 by default, only falling
back to narrower types when explicitly required.
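A sketch of the suggested default change, with dtype names as plain strings (the real code uses TileLang dtype objects such as T.int32):

```python
# Wide default to match torch_compress, which treats int8 metadata as int32.
_DEFAULT_META_DTYPE = "int32"  # was "int16"

def resolve_meta_dtype(meta_dtype=None):
    """Use the wide default whenever the caller omits meta_dtype."""
    return meta_dtype if meta_dtype is not None else _DEFAULT_META_DTYPE
```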
---
Nitpick comments:
In `@tilelang/intrinsics/wgmma_sp_macro_generator.py`:
- Around line 89-102: The method _initialize_wgmma_prefix currently ignores its
n_dim parameter (inst_n is computed from self.warp_col_tiles), so either remove
the unused parameter and stop passing self.n_dim at the call site in __init__,
or make n_dim actually drive inst_n (e.g., replace inst_n =
gcd(self.warp_col_tiles, 256) with an expression using n_dim such as inst_n =
gcd(n_dim, 256) or inst_n = n_dim followed by the same validation/asserts);
update the __init__ call accordingly to match the chosen approach and keep the
existing validations for inst_n in _initialize_wgmma_prefix.
- Around line 38-39: The attributes a_shared_layout and b_shared_layout are
declared at class scope with non-Optional type annotations but set to None,
creating a type mismatch and shared class state; update the class to either
annotate them as Optional[Layout] = None or (preferably) remove the class-level
assignments and initialize self.a_shared_layout and self.b_shared_layout inside
the class __init__ method (use the Layout type for annotations on instance
attributes) so each instance gets its own Layout fields and the type checker is
satisfied.
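A minimal sketch of the preferred fix, with `Layout` as a stand-in class:

```python
from typing import Optional

class Layout:  # stand-in for tilelang's Layout type
    pass

class EmitterSketch:
    """Per-instance layout fields instead of shared class-level None state."""

    def __init__(self) -> None:
        self.a_shared_layout: Optional[Layout] = None
        self.b_shared_layout: Optional[Layout] = None
```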
- Around line 437-469: Open a tracking issue for the TODO that hardcodes
ldmatrix_available = False in wgmma_sp_macro_generator (variable name
ldmatrix_available) so we don't lose this perf regression; in the issue include
a short summary of the supported (e_dtype, a_dtype) permutations (matching the
metadata_* helper layouts used here:
metadata_8bit_load_32x4_to_shared_16x8_layout_8bit,
metadata_8bit_load_32x4_to_shared_16x4_layout_16bit,
metadata_8bit_load_32x4_to_shared_16x4_layout_32bit,
metadata_16bit_load_32x2_to_shared_16x4_layout_8bit,
metadata_16bit_load_32x2_to_shared_16x2_layout_16bit,
metadata_16bit_load_32x2_to_shared_16x2_layout_32bit,
metadata_32bit_load_32x1_to_shared_16x2_layout_8bit) and reference the CUTLASS
header include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h; add the
issue number back into the source comment next to the TODO so future authors can
track progress.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ebf491c4-e1dc-42b8-8d46-bd736f0c693e
📒 Files selected for processing (15)
- benchmark/matmul/benchmark_matmul_sp.py
- benchmark/matmul/benchmark_matmul_sp_compress.py
- docs/deeplearning_operators/matmul_sparse.md
- examples/gemm_sp/example_custom_compress.py
- examples/gemm_sp/example_gemm_sp.py
- examples/gemm_sp/test_example_gemm_sp.py
- examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
- testing/python/utils/test_compress_utils.py
- tilelang/intrinsics/mma_sp_macro_generator.py
- tilelang/intrinsics/sparse_params.py
- tilelang/intrinsics/wgmma_sp_macro_generator.py
- tilelang/tileop/gemm_sp/__init__.py
- tilelang/tileop/gemm_sp/gemm_sp_wgmma.py
- tilelang/utils/sparse.py
💤 Files with no reviewable changes (2)
- examples/gemm_sp/test_example_gemm_sp.py
- examples/gemm_sp/example_custom_compress.py
🚧 Files skipped from review as they are similar to previous changes (3)
- tilelang/tileop/gemm_sp/gemm_sp_wgmma.py
- tilelang/tileop/gemm_sp/__init__.py
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/backend/cuda/op/gemm_sp.cc`:
- Around line 119-138: When falling back in
policy.isFullRow()/policy.isFullCol(), the code must pick warp factors that are
divisors of num_warps rather than using truncated integer division; change the
logic that computes max_m_warps/max_n_warps so it searches for the largest
divisor d of num_warps satisfying d * kMPerWarp <= M (for m_warp) or d *
k_n_per_warp <= N (for n_warp), then set the other axis to num_warps / d (and
clamp to 1 if needed). Apply the same divisor-based selection in the symmetric
fallback block referenced around lines 196-213 as well; use the variables
m_warp, n_warp, num_warps, kMPerWarp and k_n_per_warp and the
policy.isFullRow()/policy.isFullCol() branches to locate the spots to change.
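The divisor search this comment describes can be sketched in Python (variable names follow the comment; the actual fix belongs in the C++ of gemm_sp.cc):

```python
def pick_warp_partition(num_warps, M, N, k_m_per_warp, k_n_per_warp, full_row):
    """Choose (m_warp, n_warp) as divisors of num_warps.

    Unlike truncated division (e.g. M // k_m_per_warp, which may not divide
    num_warps), pick the largest divisor d of num_warps whose tiles fit the
    preferred axis, and give num_warps // d warps to the other axis.
    """
    divisors = [d for d in range(1, num_warps + 1) if num_warps % d == 0]
    if full_row:  # policy.isFullRow(): favor the M axis
        m_warp = max((d for d in divisors if d * k_m_per_warp <= M), default=1)
        n_warp = num_warps // m_warp
    else:         # policy.isFullCol(): favor the N axis
        n_warp = max((d for d in divisors if d * k_n_per_warp <= N), default=1)
        m_warp = num_warps // n_warp
    return m_warp, n_warp
```

For example, with num_warps = 8, M = 48, and k_m_per_warp = 16, truncated division gives 3, which does not divide 8; the divisor search instead yields m_warp = 2, n_warp = 4.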
In `@src/op/gemm_sp.cc`:
- Around line 256-267: The op is being marked as tcgen05 by setting
ann.Set("is_tcgen05", ...) in the TLOpBuilder for
TVM_REGISTER_OP("tl.tileop.tcgen05_gemm_sp"), but the CUDA selector still routes
to FatalTcgen5Unavailable; change the builder so it only sets the is_tcgen05
annotation (or registers the tcgen05 variant) when the backend actually supports
tcgen05 lowering: e.g., gate the ann.Set("is_tcgen05", ...) behind a
runtime/compile-time capability check (a new isTcgen05Supported() or existing
backend query) or avoid registering the tcgen05 name variant until support is
added; keep the GemmSP(...) call unchanged except for omitting the is_tcgen05
flag when unsupported to prevent hard failures during lowering.
In `@testing/python/utils/test_compress_utils.py`:
- Line 77: The ternary size expressions are redundant because both branches are
(N, N) (K == N), so simplify the torch.randint calls by removing the conditional
and using size=(N, N) directly; update occurrences that use the pattern (e.g.,
the B construction line referencing trans_B and the similar line at 80) to call
torch.randint(size=(N, N), low=low, high=high, dtype=in_dtype, device="cuda") to
eliminate the useless ternary.
In `@tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py`:
- Around line 148-149: The swizzle detection is being called with BufferRegion
objects (A_region, B_region) but _determinate_swizzle_mode builds comparison
layouts from a tir.Buffer, causing incorrect NONE fallbacks for swizzled shared
layouts; update the calls that compute a_swizzle_mode and b_swizzle_mode to pass
the underlying tir.Buffer (e.g., A_region.buffer or B_region.buffer or the
attribute that holds the Buffer) and the appropriate region/view information
into _determinate_swizzle_mode (or add an overload) so the helper receives the
actual tir.Buffer used to construct comparison layouts; apply the same change to
the other occurrence that invokes _determinate_swizzle_mode for B (the similar
call around the later occurrence).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b325a80d-9693-4056-9d4a-97f07b1a3fdb
📒 Files selected for processing (16)
- src/backend/cuda/codegen/codegen_cuda.cc
- src/backend/cuda/codegen/codegen_cuda.h
- src/backend/cuda/op/gemm_sp.cc
- src/op/builtin.cc
- src/op/builtin.h
- src/op/gemm.cc
- src/op/gemm_sp.cc
- src/op/gemm_sp.h
- src/transform/lower_opaque_block.cc
- testing/python/issue/test_tilelang_issue_tma_no_ws.py
- testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
- testing/python/utils/test_compress_utils.py
- tilelang/cuda/intrinsics/layout/mma_sp_layout.py
- tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py
- tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py
💤 Files with no reviewable changes (1)
- testing/python/issue/test_tilelang_issue_tma_no_ws.py
✅ Files skipped from review due to trivial changes (2)
- src/backend/cuda/codegen/codegen_cuda.h
- src/transform/lower_opaque_block.cc
🚧 Files skipped from review as they are similar to previous changes (3)
- src/op/builtin.h
- src/op/builtin.cc
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
Pass the underlying buffer into swizzle detection.
These calls hand _determinate_swizzle_mode() a BufferRegion, but the helper builds comparison layouts from a tir.Buffer. On swizzled shared layouts that can silently fall back to SwizzleMode.NONE and produce the wrong WGMMA descriptor.
Suggested fix
- a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
- b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+ a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+ b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
...
- b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+ b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
Also applies to: 330-330
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py` around lines 148
- 149, The swizzle detection is being called with BufferRegion objects
(A_region, B_region) but _determinate_swizzle_mode builds comparison layouts
from a tir.Buffer, causing incorrect NONE fallbacks for swizzled shared layouts;
update the calls that compute a_swizzle_mode and b_swizzle_mode to pass the
underlying tir.Buffer (e.g., A_region.buffer or B_region.buffer or the attribute
that holds the Buffer) and the appropriate region/view information into
_determinate_swizzle_mode (or add an overload) so the helper receives the actual
tir.Buffer used to construct comparison layouts; apply the same change to the
other occurrence that invokes _determinate_swizzle_mode for B (the similar call
around the later occurrence).
@LeiWang1999 This refactoring is ready to be reviewed, thanks :p
As described in the title.
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Refactor
Removed `gemm_sp_v2`, consolidating to a single `gemm_sp` interface.