
[Backend] Refactor gemm_sp #2048

Open
botbw wants to merge 41 commits into tile-ai:main from botbw:refactor_gemm_sp

Conversation

@botbw
Contributor

@botbw botbw commented Apr 16, 2026

As described in the title.

Summary by CodeRabbit

  • New Features

    • Added sparse GEMM compression with PyTorch reference implementation.
    • Enhanced sparse GEMM testing framework with flexible operand staging support.
  • Bug Fixes

    • Improved sparse tensor metadata handling and layout inference.
  • Documentation

    • Updated sparse GEMM documentation with clearer metadata format guidance and compression examples.
  • Refactor

    • Simplified sparse GEMM architecture by removing gemm_sp_v2 and consolidating to single gemm_sp interface.
    • Refactored compression utilities with new JIT-based implementation replacing CUDA extensions.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai Bot commented Apr 16, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR refactors the sparse GEMM operator stack by consolidating GemmSPPy into GemmSP, replacing static CUDA templates with FFI-based lowering, adding SM80/SM89 sparse MMA instruction primitives and SM90 WGMMA sparse support, introducing WGMMA-based sparse tensor-core emitters, and refactoring the compression utility to use TileLang JIT instead of CUDA extensions.

Changes

GemmSP Operator Definition and Schema Overhaul

• Updated GemmSP operator schema (src/op/gemm_sp.h, src/op/gemm_sp.cc, tilelang/tileop/gemm_sp/__init__.py): GemmSP public fields now use region-based layout (aRegion/eRegion/bRegion/cRegion), explicit transpose flags including trans_E, and a PrimExpr clear_accum for accumulator control. Constructors parse enriched serialized arguments and delegate lowering/layout inference to the global FFI functions tl.gemm_sp.lower and tl.gemm_sp.infer_layout instead of C++ backends.
• GemmSPPy removal and FFI migration (src/op/gemm_sp_py.h, src/op/gemm_sp_py.cc, tilelang/language/__init__.py, tilelang/ir.py): Deleted the entire GemmSPPy tile-operator class and its registration. Updated module exports to reference GemmSP directly and removed the re-export of gemm_sp_v2 from the language bindings.
• Warp policy refactoring (src/op/gemm_sp.h, src/op/gemm_sp.cc): GemmSPWarpPolicyNode now derives from Object and stores the policy type plus mutable warp-partition fields. Removed the bits parameter from computeWarpPartition and added helper predicates (isSquare, isFullRow, isFullCol, isFree).

Sparse MMA Instruction Primitives and CUDA Templates

• SM80 and SM89 sparse MMA instruction definitions (src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp, src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp): Added new sparse PTX instruction wrappers under the SM80::MMA::SPARSE and SM89::MMA::SPARSE namespaces for the FP16, FP32, BF16, TF32, INT8, and FP8 (E4M3/E5M2) data types, each parameterized by SparseSel and emitting conditional inline PTX with ordered-metadata variants for newer CUDA toolchains.
• Sparse MMA dispatcher and WGMMA sparse instructions (src/tl_templates/cuda/instruction/mma_sp.h, src/tl_templates/cuda/instruction/wgmma_sp.h): Added a tl::mma_sp_sync<...> template dispatcher selecting SM80/SM89 implementations, plus tl::wgmma_sp_ss<...> / tl::wgmma_sp_rs<...> public template functions for SM90 WGMMA sparse operations, instantiated across float/int32 data combinations including FP8 variants.
• Removed legacy static templates (src/tl_templates/cuda/gemm_sp.h, src/tl_templates/cuda/gemm_sp_sm80.h, src/tl_templates/cuda/gemm_sp_sm90.h): Deleted the dispatch and implementation headers that previously contained hardcoded CUTLASS- and CUTE-based SM80/SM90 sparse GEMM templates, replacing them with the new unified FFI-based lowering pathway.

WGMMA Sparse Tensor-Core Emitters

• GemmSPWGMMA implementation class (tilelang/tileop/gemm_sp/gemm_sp_wgmma.py): Added a GemmSPWGMMA tile-operator specialization with infer_layout, lower, and shared-memory layout selection (infer_shared_layout) methods. It dispatches to a WGSparseTensorCoreIntrinEmitter for WGMMA-based lowering, supporting shared-shared (ss) and fragment-shared (rs) operand scopes.
• WGMMA sparse emitter and macro generation (tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py): Introduced a WGSparseTensorCoreIntrinEmitter class that extends the sparse MMA emitter to generate SM90 WGMMA-based sparse multiply sequences. It implements wgmma_ss (shared A/E with descriptor B) and wgmma_rs (fragment A with shared E and descriptor B) paths, including descriptor initialization, per-iteration offset computation, PTX emission, and warpgroup synchronization logic.

Sparse Compression Utility Refactoring

• TileLang-based compression implementation (tilelang/utils/sparse.py, benchmark/matmul/benchmark_matmul_sp_compress.py): Replaced the CUDA-extension sparse compressor with a TileLang JIT kernel (_compress_fn) and added a pure-PyTorch reference implementation (torch_compress). Updated the public compress API to accept meta_dtype and compute block_m/block_k instead of exposing architecture-specific parameters. Refactored the semi-sparse generators to derive sparsity patterns from a GROUP_CONFIG dictionary indexed by TileLang dtype. Added a new compression benchmark script.
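The 2:4 structured-sparsity compression that torch_compress mirrors can be sketched in plain Python. This is a hypothetical illustration of the general scheme, not the actual tilelang implementation: the function name compress_2to4 and the nibble packing order are assumptions for illustration only.

```python
# Illustrative sketch of 2:4 structured-sparsity compression (hypothetical,
# not the actual tilelang torch_compress code): for every group of four
# values along K, keep the two largest-magnitude entries and record their
# positions as two 2-bit indices packed into a 4-bit metadata nibble.

def compress_2to4(row):
    """Compress one row; returns (kept_values, metadata_nibbles)."""
    assert len(row) % 4 == 0
    values, meta = [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        # indices of the two largest-magnitude elements, in ascending order
        kept = sorted(sorted(range(4), key=lambda i: -abs(group[i]))[:2])
        values.extend(group[i] for i in kept)
        # pack two 2-bit indices into one metadata nibble (low index first)
        meta.append(kept[0] | (kept[1] << 2))
    return values, meta

vals, meta = compress_2to4([0.0, 3.0, -5.0, 0.0])
# kept indices are 1 and 2 -> nibble 0b1001 == 9
```

The real compressor additionally has to tile this per-group packing into the metadata layout the sparse MMA/WGMMA instructions expect, which is what the TileLang JIT kernel handles.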

CUDA Code Generation and Instruction Lowering

• Sparse PTX intrinsic definitions and registration (src/op/builtin.h, src/op/builtin.cc, tilelang/language/ast/ir.py, tilelang/language/tir/ir.py, tilelang/language/tir/op.py): Registered two new sparse WGMMA PTX builtins (ptx_wgmma_sp_ss, ptx_wgmma_sp_rs) with opaque call effects (18 and 17 inputs respectively). Created corresponding Python wrapper functions forwarding descriptor/buffer operands, sparse selectors, and metadata operands.
• CUDA codegen for sparse instructions (src/backend/cuda/codegen/codegen_cuda.h, src/backend/cuda/codegen/codegen_cuda.cc): Updated CodeGenTileLangCUDA::VisitExpr_ to generate tl::mma_sp_sync<...> calls for ptx_mma_sp (deriving MetaType from the metadata buffer element dtype and mapping sparse_selector to SM80::MMA::SparseSel), and added lowering for the new ptx_wgmma_sp_ss / ptx_wgmma_sp_rs intrinsics. Added feature flags (need_mma_sp_instruction_h_, need_wgmma_sp_instruction_h_) to conditionally emit the corresponding CUDA template headers.
• CUDA instruction selection and warp partitioning (src/backend/cuda/op/gemm_sp.cc): Introduced a new instruction-selection pipeline choosing among cuda.mma.sp, cuda.wgmma.sp, and cuda.tcgen05.sp instances with eligibility predicates (WGMMA capability gates for dtype/transpose/scope, TCGEN5 gates for target/scope). Added helper methods to compute default and WGMMA warp-group partitions, and updated the operator's lowering to dispatch to the selected implementation classes (MMA vs. WGMMA) for Hopper GPUs.

Sparse MMA Layout and Macro Updates

• Sparse MMA layout mappings (tilelang/cuda/intrinsics/layout/mma_sp_layout.py, tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py): Updated the metadata layout mappings for 8-bit and 16-bit loads to use explicit logical_id-based row/col computation and introduced a metadata_8bit_load_32x4_to_shared_16x8_layout_8bit variant. Modified SparseTensorCoreIntrinEmitter to compute e_factor via get_e_factor(...) and e_replicate_factor via helper functions instead of hardcoded lookup tables, and added related assertions.
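The derivation that replaces the hardcoded lookup tables (described in a later commit as "4 metadata bits per sparsity group", with the group size and replication derived from the operand dtypes) can be sketched as follows. The GROUP_SIZE mapping and the exact K-coverage semantics here are assumptions for illustration, not the actual tilelang.utils.sparse_config code:

```python
# Hedged sketch of the e_factor derivation: the hardware stores 4 metadata
# bits per sparsity group, so one metadata element of `meta_bits` bits
# covers meta_bits // 4 groups, i.e. (meta_bits // 4) * group_size elements
# of A along K. GROUP_SIZE is an assumed mapping (2:4 sparsity for 16-bit
# dtypes, 4:8 for 8-bit dtypes).

GROUP_SIZE = {16: 4, 8: 8}

def get_e_factor(a_bits: int, meta_bits: int) -> int:
    """How many K elements of A one metadata element covers (assumed)."""
    groups_per_meta = meta_bits // 4  # 4 metadata bits per sparsity group
    return groups_per_meta * GROUP_SIZE[a_bits]

def get_e_replicate_factor(a_bits: int) -> int:
    """Metadata replication factor: 1 for <=8-bit A dtypes, else 2."""
    return 1 if a_bits <= 8 else 2
```

Deriving these values from dtype bit widths rather than per-dtype tables is what lets the emitter assert consistency instead of silently falling through on an unlisted dtype.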

Test Coverage and Example Updates

• Generalized sparse GEMM tests (testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py, testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py): Removed the architecture-gated (matmul_sp_sm90/sm80) kernels and introduced four unified kernel builders (matmul, matmul_rs, matmul_sr, matmul_rr) parameterized by explicit trans_A/trans_B, metadata_dtype, and E_factor. Standardized the runner implementations with deterministic seeding, compression via compress(..., meta_dtype=...), and unified output validation using torch_assert_close. Deleted the entire test_tilelang_tilelibrary_gemm_sp_v2.py file.
• Sparse example and documentation updates (examples/gemm_sp/example_gemm_sp.py, examples/gemm_sp/example_custom_compress.py, examples/gemm_sp/test_example_gemm_sp.py, docs/deeplearning_operators/matmul_sparse.md): Updated the example kernels to use an explicit e_dtype parameter (instead of architecture selection) and to call T.gemm_sp (not gemm_sp_v2) without metadata layout annotations. Updated the benchmarks to expose kernel configuration flags (--block_M, --block_N, --block_K, --num_stages, --thread_num, --e_dtype). Updated the documentation to emphasize that tilelang.utils.sparse.compress produces row-major metadata directly consumable by T.gemm_sp without additional layout annotation.
• Compression and sparse utility tests (testing/python/utils/test_compress_utils.py, testing/python/issue/test_tilelang_issue_tma_no_ws.py, testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py): Refactored the compression test to validate both the TileLang and Torch-based compressors against a shared matmul kernel, removing SM90-specific gating and make_cutlass_metadata_layout usage. Updated the sparse workspace tests to use T.int8 for metadata and removed the T.annotate_layout and T.disable_warp_group_reg_alloc calls.
• Operator registration updates (src/op/gemm.cc, src/transform/lower_opaque_block.cc): Updated GEMM and related operator registrations to accept variable input counts (set_num_inputs(-1)). Updated an inline comment example from gemm/gemm_sp_py to gemm/gemm_sp.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Lang as tilelang.language
    participant GemmSP as GemmSP Operator
    participant Backend as CUDA Backend
    participant Codegen as CUDA Codegen
    participant Templates as CUDA Templates

    User->>Lang: Call T.gemm_sp(A, E, B, C, ..., e_dtype)
    Lang->>GemmSP: Construct GemmSP with regions & trans_E
    GemmSP->>Backend: invoke tl.gemm_sp.lower()
    Backend->>Backend: Select instruction (MMA/WGMMA/TCGEN5)
    alt WGMMA Selected
        Backend->>Backend: Instantiate GemmSPWGMMA
        Backend->>Codegen: Lower to wgmma_ss/wgmma_rs PTX calls
        Codegen->>Templates: Generate wgmma_sp_ss/wgmma_sp_rs<...>
        Templates->>Templates: Emit SM90 sparse WGMMA PTX
    else MMA Selected
        Backend->>Backend: Instantiate GemmSPMMA
        Codegen->>Templates: Generate mma_sp_sync<...>
        Templates->>Templates: Emit SM80/SM89 sparse MMA PTX
    end
    Templates-->>Codegen: Compiled CUDA kernel
    Codegen-->>Lang: TIR PrimFunc
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

This PR involves substantial, multi-layered refactoring across operator schemas, CUDA instruction templates, codegen emission, sparse utilities, and test infrastructure. It introduces new sparse MMA instruction primitives for multiple GPU architectures, transitions from static CUDA templates to FFI-based lowering, overhauls the sparse compression implementation, and updates numerous examples and tests. The changes are complex in logic, span many affected files with heterogeneous patterns (operator definition, template headers, Python emitters, tests), and require careful review of instruction semantics, FFI wiring, and backward compatibility.

Possibly related PRs

  • tile-ai/tilelang#2033: Performs the same operator FFI refactor pattern (removing legacy "*_py" variants and promoting canonical FFI object names), directly related to GemmSP/GemmSPPy consolidation.
  • tile-ai/tilelang#2161: Introduces WGMMATensorCoreIntrinEmitter and moves to atom-level MMA/descriptor APIs for WGMMA, shares overlapping sparse emitter infrastructure.
  • tile-ai/tilelang#1949: Makes closely related changes to GEMM operator plumbing by introducing explicit WGMMA/TCGEN05 variants and lowering pathways.

Suggested reviewers

  • LeiWang1999

Poem

🐰 From template halls to FFI dreams,
sparse GEMM flows through cleaner seams.
Warpgroups dance with SM90's might,
compression blooms—TileLang bright!
Ops refactored, schema sleek,
a rabbit's work, complete and chic.


Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (3)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (3)

197-197: Redundant import inside function.

tilelang.language as T is already imported at module level (line 10). This redundant import also appears in matmul_sr (line 344) and matmul_rr (line 493).

♻️ Proposed fix

Remove the redundant imports on lines 197, 344, and 493:

-    import tilelang.language as T
-
     @T.prim_func

153-171: Minor inconsistency: T.float vs T.float32 for accumulator dtype.

In test_gemm_ss parameters, the first four cases use T.float while other tests (test_gemm_rs, test_gemm_sr, test_gemm_rr) consistently use T.float32. While functionally equivalent, using T.float32 consistently would improve readability.


114-132: Consider extracting the _matmul helper to module level.

The _matmul helper function is defined identically inside run_gemm_ss, run_gemm_rs, run_gemm_sr, and run_gemm_rr. Extracting it to module level would reduce duplication.

♻️ Proposed refactor

Add at module level (e.g., after generate_dense_input):

def _reference_matmul(A, B, trans_A, trans_B):
    if trans_A:
        A = A.T
    if trans_B:
        B = B.T
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    return torch.matmul(A, B)

Then replace inline definitions with calls to _reference_matmul(A, B, trans_A, trans_B).

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/deeplearning_operators/matmul_sparse.md`:
- Line 166: Replace the non-descriptive "here" link texts with explicit labels:
change the first link label to something like "PyTorch Ampere compressor
implementation" (pointing to the same URL) and the second link label to
"metadata permutation routine" (pointing to the permutation URL); update the
sentence so it reads e.g. "PyTorch provides an Ampere compressor (PyTorch Ampere
compressor implementation). Note that in this implementation, a permutation
(metadata permutation routine) is applied..." so the links are descriptive and
MD059-compliant while preserving the existing URLs and the instruction about
matching CUTLASS metadata layout and not using
_calculate_meta_reordering_scatter_offsets.

In `@tilelang/language/__init__.py`:
- Line 60: The public symbol gemm_sp_v2 was removed when replacing the import
with gemm_sp; restore backward compatibility by re-exporting a compatibility
alias named gemm_sp_v2 that points to the new gemm_sp implementation (i.e.,
import gemm_sp from .experimental.gemm_sp and assign gemm_sp_v2 = gemm_sp so
existing callers of T.gemm_sp_v2(...) continue to work).

In `@tilelang/tileop/__init__.py`:
- Line 3: The public export GemmSPPy was removed causing breakage for imports;
restore a deprecated alias by reintroducing GemmSPPy as an alias to GemmSP
(e.g., GemmSPPy = GemmSP) in tilelang.tileop.__init__ and emit a
DeprecationWarning via the warnings module when the alias is used or on import
to advise users to switch to GemmSP; also ensure GemmSPPy is present in the
module's public exports (e.g., __all__) so downstream code importing GemmSPPy
continues to work.

In `@tilelang/tileop/gemm_sp/__init__.py`:
- Around line 15-24: The FFI callbacks gemm_sp_infer_layout and gemm_sp_lower
are incorrectly annotated with GemmSPMMA; update both function signatures to use
GemmSP instead so the Python FFI type matches the C++ side (which uses
GetRef<GemmSP>); ensure the parameters and uses inside those functions (calls to
gemm_sp.infer_layout(...) and gemm_sp.lower(...)) remain unchanged so
infer_layout and lower are invoked on a GemmSP instance.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cf24b8d5-9f02-4a04-96fb-48da817b4d31

📥 Commits

Reviewing files that changed from the base of the PR and between 235ad7e and 0bce4d0.

📒 Files selected for processing (15)
  • benchmark/matmul/benchmark_matmul_sp.py
  • docs/deeplearning_operators/matmul_sparse.md
  • examples/gemm_sp/example_custom_compress.py
  • src/op/gemm_sp.cc
  • src/op/gemm_sp.h
  • src/op/gemm_sp_py.cc
  • src/op/gemm_sp_py.h
  • src/transform/lower_opaque_block.cc
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
  • tilelang/ir.py
  • tilelang/language/__init__.py
  • tilelang/language/experimental/gemm_sp.py
  • tilelang/tileop/__init__.py
  • tilelang/tileop/gemm_sp/__init__.py
💤 Files with no reviewable changes (4)
  • tilelang/ir.py
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
  • src/op/gemm_sp_py.h
  • src/op/gemm_sp_py.cc

Comment thread docs/deeplearning_operators/matmul_sparse.md Outdated
Comment thread tilelang/language/__init__.py Outdated
 from tilelang.tileop.base import GemmWarpPolicy  # noqa: F401
 from .gemm_op import gemm, wgmma_gemm, tcgen05_gemm  # noqa: F401
-from .experimental.gemm_sp import gemm_sp, gemm_sp_v2  # noqa: F401
+from .experimental.gemm_sp import gemm_sp  # noqa: F401
Contributor

⚠️ Potential issue | 🟠 Major

Preserve gemm_sp_v2 during the rename.

Line 60 removes a public frontend symbol. Existing kernels that import or call T.gemm_sp_v2(...) will break even though this PR is unifying the implementation rather than removing the feature.

Proposed compatibility shim
 from .experimental.gemm_sp import gemm_sp  # noqa: F401
+gemm_sp_v2 = gemm_sp  # backward-compatible alias
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 from .experimental.gemm_sp import gemm_sp  # noqa: F401
+gemm_sp_v2 = gemm_sp  # backward-compatible alias

 from .base import GemmWarpPolicy  # noqa: F401
 from .gemm import Gemm  # noqa: F401
-from .gemm_sp import GemmSPPy  # noqa: F401
+from .gemm_sp import GemmSP  # noqa: F401
Contributor

⚠️ Potential issue | 🟠 Major

Keep GemmSPPy as a deprecated alias.

Line 3 removes a public export outright. Any downstream from tilelang.tileop import GemmSPPy now fails immediately, even though this change is a rename rather than a semantic removal.

Proposed compatibility shim
 from .gemm_sp import GemmSP  # noqa: F401
+GemmSPPy = GemmSP  # backward-compatible alias
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 from .gemm_sp import GemmSP  # noqa: F401
+GemmSPPy = GemmSP  # backward-compatible alias

Comment thread tilelang/tileop/gemm_sp/__init__.py Outdated
@botbw botbw marked this pull request as draft April 17, 2026 06:54
@botbw botbw force-pushed the refactor_gemm_sp branch from e51ff0a to de32b3b Compare April 19, 2026 06:18
botbw and others added 14 commits May 8, 2026 16:23
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove unused imports (tvm, _TORCH_DTYPE_TO_STR)
- Remove print("pass") calls
- Normalize all parametrize tuples to compact aligned format

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Replace removed compress_sm90 with compress() in sparse_tensorcore
  example and test_compress_utils
- Fix example_gemm_sp.py: remove ARCH_INFO and make_cutlass_metadata_layout;
  compress() now produces natural-layout int16 metadata on all architectures
- Update matmul_sparse.md to reflect the simplified compress() API and
  remove the CUTLASS layout annotation requirement

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add tilelang/utils/sparse_config.py as single source of truth for
  SPARSE_PARAMS, E_FACTOR_MAP, E_REPLICATE_FACTOR, and get_e_factor()
- Remove inline E_FACTOR_MAP/E_REPLICATE_FACTOR from SparseTensorCoreIntrinEmitter;
  import from sparse_config.py with backward-compatible class aliases
- Update compress() API: remove transposed/block_k/arch params, use T.int16 metadata
- Delete example_custom_compress.py; move torch_compress reference to
  testing/python/utils/test_compress_utils.py with float8 .view() support
- Update examples (example_gemm_sp, tilelang_example_sparse_tensorcore) and
  test files to use new compress() API and get_e_factor()

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@botbw botbw marked this pull request as ready for review May 10, 2026 08:26
- Remove E_FACTOR_MAP and E_REPLICATE_FACTOR dicts from sparse_config.py
- get_e_factor: derived from SPARSE_PARAMS group size and meta_dtype.bits
  (hardware always packs 4 bits per sparsity group)
- get_e_replicate_factor: derived from a_dtype.bits (<=8 bit → 1, else → 2)
- SparseTensorCoreIntrinEmitter uses the two functions directly instead of
  class-level dict aliases

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 8

🧹 Nitpick comments (16)
src/op/builtin.h (1)

372-381: 💤 Low value

Add full signature comments for the new sparse WGMMA intrinsics.

ptx_wgmma_ss / ptx_wgmma_rs above (Lines 350-370) document the full argument list. The new ptx_wgmma_sp_ss (18 inputs in builtin.cc) and ptx_wgmma_sp_rs (17 inputs) only get a one-liner, so callers can’t tell from this header which args were added (E descriptor, E offset, sparse selector, etc.) or why the sp_ss/sp_rs counts diverge by one. Mirroring the existing doc style would prevent guesswork at codegen sites.

src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp (1)

25-201: ⚖️ Poor tradeoff

Consolidate the four SM89 sparse fp8 MMA structs using a macro or template.

The four structs (SM89_16x8x64_F32E4M3E4M3F32_TN, SM89_16x8x64_F32E4M3E5M2F32_TN, SM89_16x8x64_F32E5M2E4M3F32_TN, SM89_16x8x64_F32E5M2E5M2F32_TN) are identical except for their fp8 operand type suffix (e4m3.e4m3, e4m3.e5m2, e5m2.e4m3, e5m2.e5m2) in the PTX instruction strings and struct name in the diagnostic message. Folding them with a small X-macro or template dispatcher would reduce duplication by ~75%, eliminate accidental drift between variants, and make adding new fp8 combinations (e.g., future scaled fp8) a single-line change.

src/tl_templates/cuda/instruction/mma_sp.h (1)

93-101: 💤 Low value

Consider asserting that kDRegs == kCRegs in addition to type matching.

The static assertion confirms DReg and CReg types match, but the dispatcher then independently expands std::make_index_sequence<kDRegs> and <kCRegs> (lines 56–60 in call_fma_sp_impl). For accumulator-style MMA/FMA, kDRegs should equal kCRegs. A divergence would silently mis-call Impl::fma. A trivial extra static_assert(Traits::kDRegs == Traits::kCRegs, ...) would harden against future CUTE template additions.

tilelang/tileop/gemm_sp/__init__.py (2)

56-65: 💤 Low value

infer_layout and lower create a fresh implementation instance per call — minor inefficiency.

impl_class(self) is constructed twice (once for layout inference, once for lowering), and each call additionally re-runs _select_gemm_instruction (which crosses the FFI to the C++ GemmSPGetGemmSPInst). For a hot compilation path this is fine, but consider memoizing gemm_inst on self if compilation latency becomes a concern.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/tileop/gemm_sp/__init__.py` around lines 56 - 65, Both infer_layout
and lower call _select_gemm_instruction and instantiate impl_class(self)
separately, causing duplicate FFI calls and extra object creation; cache the
selected gemm_inst and/or the instantiated implementation on self so both
methods reuse the same value/instance: call
self._select_gemm_instruction(thread_nums, target) once (e.g. store as
self._cached_gemm_inst keyed by thread_nums/target or invalidate when inputs
change), use self._get_implementation_class(cached_inst, target) once and keep
the instantiated impl (instead of impl_class(self) twice) so infer_layout and
lower call the same impl object's infer_layout and lower methods.
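The memoization suggested above can be sketched as follows. This is a minimal illustration, not the real `GemmSP` class: the FFI crossing to the C++ `GemmSPGetGemmSPInst` is faked with a counter so the caching effect is observable, and the method signatures are simplified.

```python
# Sketch: cache the instruction selection per (thread_nums, target) on the
# instance so infer_layout and lower reuse one FFI result. All names here
# mirror the review comment but the bodies are stubs.
class GemmSP:
    def __init__(self):
        self.ffi_calls = 0   # stands in for crossings into C++ GemmSPGetGemmSPInst
        self._cached_gemm_inst = {}

    def _select_gemm_instruction(self, thread_nums, target):
        key = (thread_nums, target)
        if key not in self._cached_gemm_inst:
            self.ffi_calls += 1  # the expensive FFI call happens only once per key
            self._cached_gemm_inst[key] = f"mma@{target}/{thread_nums}"
        return self._cached_gemm_inst[key]

    def infer_layout(self, target, thread_nums):
        return self._select_gemm_instruction(thread_nums, target)

    def lower(self, target, thread_nums):
        return self._select_gemm_instruction(thread_nums, target)

op = GemmSP()
op.infer_layout("sm90", 128)
op.lower("sm90", 128)
assert op.ffi_calls == 1  # second call reused the cached selection
```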

44-54: ⚡ Quick win

Refactor to use free function wrappers for consistency with the gemm module pattern.

gemm_sp_infer_layout and gemm_sp_lower are instance methods registered via @tvm_ffi.register_global_func, whereas the sibling gemm module uses explicit free function wrappers (e.g., def gemm_infer_layout(gemm: GemmMMA, target, ...)). The C++ side calls these identically via (*f)(GetRef<GemmSP>(this), ...), so both patterns work, but instance method registration deviates from the established codebase convention. Align with the gemm pattern by extracting wrapper functions at module level.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/tileop/gemm_sp/__init__.py` around lines 44 - 54, Replace the
instance-method registrations with module-level free-function wrappers
consistent with the gemm pattern: add two top-level functions (e.g.,
gemm_sp_infer_layout(gemm_sp: GemmSP, target: Target, thread_bounds: Range) and
gemm_sp_lower(gemm_sp: GemmSP, target: Target, layout_map: dict, thread_bounds:
Range, thread_var: tir.Var)) that call the corresponding instance methods
(gemm_sp.infer_layout(...) and gemm_sp.lower(...)) and register those free
functions with tvm_ffi.register_global_func instead of the current
instance-bound functions; keep the same argument order and behavior (extract
thread_nums = thread_bounds.extent and forward to infer_layout/lower) so the C++
callers continue to work unchanged.
src/op/gemm_sp.cc (2)

99-115: 💤 Low value

Missing fallthrough return after ICHECK(0) may trigger compiler warnings.

GetGemmSPInst is non-void but has no return after the ICHECK(0) in the unsupported-target branch. TVM's ICHECK typically expands to a LOG(FATAL) which is [[noreturn]], so this is likely safe at runtime, but some compilers still warn -Wreturn-type here. Consider adding an explicit return GemmInst::kMMA; or LOG(FATAL) after the chain to be defensive and silence warnings, mirroring the same issue in Lower (line 178) and InferLayout (line 192).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/op/gemm_sp.cc` around lines 99 - 115, GetGemmSPInst lacks an explicit
return after the ICHECK(0) failure path which can trigger -Wreturn-type
warnings; update GemmSPNode::GetGemmSPInst to add a defensive return after
ICHECK(0) (e.g., return a sensible default like GemmInst::kMMA or mirror other
functions’ behavior) so the function always returns a GemmInst even if the
ICHECK is treated as non‑noreturn; modify the end of GetGemmSPInst to include
this explicit return to silence compiler warnings.

145-180: 💤 Low value

Lower's wrapping path may lose iter_values/predicate semantics.

When the FFI prim_func->body is a BlockRealize, the existing iter_values and predicate are preserved (lines 162–163) — good. However, when it's not a BlockRealize, you wrap the body in a synthetic Block with empty iter_vars/reads/writes and constant‑true predicate (lines 169–176). If the FFI lowering ever returns a body that needs nontrivial reads/writes for downstream passes (e.g., access region inference, buffer compaction), the empty reads/writes arrays could regress those analyses. Consider asserting the FFI always returns a BlockRealize, or populating reads/writes from T.layout_map/access regions.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/op/gemm_sp.cc` around lines 145 - 180, The wrapper path in
GemmSPNode::Lower can lose important iter_values/predicate and reads/writes info
when prim_func->body is not a BlockRealize; either require/assert that the FFI
always returns a BlockRealize or populate the synthetic Block's
iter_vars/reads/writes/predicate from the original lowering context. Update
GemmSPNode::Lower to (a) check prim_func->body and if not a BlockRealize, derive
and set appropriate iter_values and predicate and fill reads/writes (using
T.layout_map, T.thread_bounds/thread_var or any access-region information
available from prim_func or T) instead of using empty arrays/const_true, or add
a clear ICHECK/assert that the FFI must return a BlockRealize with the correct
metadata so downstream passes (access region inference, buffer compaction) are
not broken.
src/tl_templates/cuda/instruction/wgmma_sp.h (1)

44-48: 💤 Low value

Minor: redundant condition in static_assert for CReg size.

sizeof(uint32_t) == sizeof(float) == 4 on every supported CUDA target, so the || is tautological. Either simplify to sizeof(CReg) == sizeof(uint32_t) (matching the SS variant on line 20–21) or sizeof(CReg) == 4.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/tl_templates/cuda/instruction/wgmma_sp.h` around lines 44 - 48, The
static_assert for CReg in tl::wgmma_sp_rs is redundant because sizeof(uint32_t)
== sizeof(float) on supported CUDA targets; update the check to a single clear
condition (e.g., require sizeof(CReg) == sizeof(uint32_t) to match the SS
variant) by replacing the current "sizeof(CReg) == sizeof(uint32_t) ||
sizeof(CReg) == sizeof(float)" assertion with a single-size comparison
referencing CReg and the tl::wgmma_sp_rs context so the intent is unambiguous.
tilelang/intrinsics/wgmma_sp_macro_generator.py (3)

504-516: 💤 Low value

Use iterable unpacking for cleaner indexing (Ruff RUF005).

tuple(E_other) + (E_base0 + ..., E_base1 + ...) allocates an intermediate tuple twice; iterable unpacking is more idiomatic and slightly more efficient.

Proposed fix
                 E_local_buf[e_local_base + inst_i * local_size_e + j] = (
-                    E_shared_buf[tuple(E_other) + (E_base0 + wk + mk, E_base1 + wi + mi)]
+                    E_shared_buf[(*E_other, E_base0 + wk + mk, E_base1 + wi + mi)]
                     if trans
-                    else E_shared_buf[tuple(E_other) + (E_base0 + wi + mi, E_base1 + wk + mk)]
+                    else E_shared_buf[(*E_other, E_base0 + wi + mi, E_base1 + wk + mk)]
                 )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 504 - 516, The
code creates temporary tuples by doing tuple(E_other) + (E_base0 + wk + mk,
E_base1 + wi + mi) when indexing E_shared_buf; change both branches of the
conditional in the assignment to E_local_buf[...] to use iterable unpacking
instead (e.g., E_shared_buf[*E_other, E_base0 + wk + mk, E_base1 + wi + mi] and
E_shared_buf[*E_other, E_base0 + wi + mi, E_base1 + wk + mk]) so you avoid
allocating intermediate tuples; update the two places inside the loop that
reference E_other/E_shared_buf (the assignment to E_local_buf in the for j loop)
and keep the surrounding logic and return of _warp_ldmatrix_e(E_local_buf,
E_buf, inst_i, ki, thread_binding) intact.

237-243: 💤 Low value

Unused variable tx flagged by Ruff (RUF059).

tx is unpacked but never referenced in the _warp_mma macro for wgmma_ss. Rename to _tx to silence the lint and document intent.

Proposed fix
-            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            _tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 237 - 243, In
the _warp_mma macro, the local unpacking uses tx but that variable is unused;
change the unpacked name from tx to _tx to satisfy the linter and indicate
intentional unused value (e.g., replace "tx, warp_n, warp_m =
self.extract_thread_binding(thread_binding)" with "_tx, warp_n, warp_m =
self.extract_thread_binding(thread_binding)" inside the _warp_mma definition),
leaving all other logic and variables (k_blocks, e_stage_elems, E_local, etc.)
unchanged.

446-460: 💤 Low value

ldmatrix_available = False looks like an unfinished TODO worth tracking.

The hardcoded ldmatrix_available = False plus the comment "TODO: use ldmatrix when possible" means every metadata load currently goes through the slower per-element path even when ldmatrix would be valid. This is a perf opportunity, not a correctness bug — consider filing a follow-up issue or # FIXME(name) so it isn't lost.

Want me to open a follow-up issue tracking this performance TODO?

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 446 - 460, The
ldmatrix_available flag in ldmatrix_e is hardcoded False which leaves a TODO
untracked; update ldmatrix_e to either compute ldmatrix_available from the
current conditions (e.g., based on a_dtype/e_dtype and transposed state) or, if
you cannot implement the fast path now, replace the hardcoded assignment with a
clear tracked marker (e.g., set ldmatrix_available = False and add a
FIXME(author_name) comment) and create a follow-up issue referencing ldmatrix_e
and ldmatrix_available so the optimization isn't lost; ensure the change
mentions the constraint (int8 + transposed case) and the location (ldmatrix_e in
wgmma_sp_macro_generator.py) so future work can implement the fast path.
tilelang/tileop/gemm_sp/gemm_sp_wgmma.py (2)

138-146: 💤 Low value

Both branches return; trailing raise is unreachable for ss/rs and only triggered for sr/rr — consider an explicit early check.

is_gemm_ss() and is_gemm_rs() already cover the supported configurations and both branches return. The trailing raise ValueError(...) only fires for sr/rr cases, which is fine, but mirroring the structure of infer_layout (which uses an explicit else: raise) would make the control flow uniform and easier to follow.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/tileop/gemm_sp/gemm_sp_wgmma.py` around lines 138 - 146, The
trailing raise ValueError is only reachable for unsupported gemm types (sr/rr)
but the current structure returns inside the is_gemm_ss() and is_gemm_rs()
branches, making control flow less explicit; update the logic in the function
that defines _gemm_ssr/_gemm_rsr to first check supported cases (e.g., if not
(self.is_gemm_ss() or self.is_gemm_rs()): raise ValueError(...)) and then use
explicit if/elif for self.is_gemm_ss() and self.is_gemm_rs() to return
_Simplify(_gemm_ssr, inline_let=True) or _Simplify(_gemm_rsr, inline_let=True)
respectively, mirroring infer_layout’s else: raise pattern and keeping symbols
_gemm_ssr, _gemm_rsr, _Simplify, is_gemm_ss, is_gemm_rs unchanged.
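The explicit early-check control flow suggested above can be sketched as below. The class and methods are stubs standing in for the real `GemmSPWGMMA`; only the branching shape matters.

```python
# Sketch: reject unsupported configurations up front, then use explicit
# if/elif for the supported ones, mirroring infer_layout's `else: raise`.
class GemmSPWGMMA:
    def __init__(self, kind):
        self.kind = kind

    def is_gemm_ss(self):
        return self.kind == "ss"

    def is_gemm_rs(self):
        return self.kind == "rs"

    def lower_body(self):
        if not (self.is_gemm_ss() or self.is_gemm_rs()):
            raise ValueError(f"Unsupported gemm_sp configuration: {self.kind}")
        if self.is_gemm_ss():
            return "gemm_ssr"   # stands in for _Simplify(_gemm_ssr, inline_let=True)
        else:                   # is_gemm_rs()
            return "gemm_rsr"   # stands in for _Simplify(_gemm_rsr, inline_let=True)

assert GemmSPWGMMA("ss").lower_body() == "gemm_ssr"
assert GemmSPWGMMA("rs").lower_body() == "gemm_rsr"
```

With this shape, the sr/rr cases fail fast at the top of the function instead of falling through to a trailing raise after two return branches.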

73-80: 💤 Low value

Type hint and unused parameter on lower(...).

  • thread_nums: Range is misleading: the caller in tilelang/tileop/gemm_sp/__init__.py passes thread_bounds.extent (an int/PrimExpr), not a Range. Consider thread_nums: int | tir.PrimExpr.
  • mbar_phase_expr: tir.PrimExpr | None = None is declared but never referenced in the body. Either drop it or thread it through.
Proposed fix
-    def lower(
-        self,
-        layout_map: dict,
-        target: Target,
-        thread_nums: Range,
-        thread_var: tir.Var,
-        mbar_phase_expr: tir.PrimExpr | None = None,
-    ):
+    def lower(
+        self,
+        layout_map: dict,
+        target: Target,
+        thread_nums: int | tir.PrimExpr,
+        thread_var: tir.Var,
+    ):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/tileop/gemm_sp/gemm_sp_wgmma.py` around lines 73 - 80, The lower
method's signature is incorrect: change the type hint for thread_nums in
gemm_sp_wgmma.py lower(self, ...) from Range to int | tir.PrimExpr (or
tir.PrimExpr) to match callers that pass thread_bounds.extent, and either remove
the unused parameter mbar_phase_expr or propagate it through the function (use
it where barriers/phases are computed) so it is referenced; update the signature
and all internal references to use the new thread_nums type and delete
mbar_phase_expr if you choose to drop it, ensuring callers (e.g., in
tilelang/tileop/gemm_sp/__init__.py) remain compatible.
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (2)

958-973: 💤 Low value

Inconsistent signature: run_gemm_rr introduces defaults that the sibling runners deliberately require.

run_gemm_ss (line 66), run_gemm_rs (line 375), and run_gemm_sr (line 664) all take num_stages, num_threads, and meta_dtype as required positional arguments, while run_gemm_rr defaults them to 3, 128, and T.int16. All current callers pass these explicitly via pytest.mark.parametrize, so behavior is fine today, but the asymmetry is a maintenance hazard: a future caller that forgets to pass meta_dtype will silently fall through to T.int16 instead of failing fast.

♻️ Align signature with the other runners
-def run_gemm_rr(
-    M,
-    N,
-    K,
-    trans_A,
-    trans_B,
-    in_dtype,
-    out_dtype,
-    dtypeAccum,
-    block_M,
-    block_N,
-    block_K,
-    num_stages=3,
-    num_threads=128,
-    meta_dtype=T.int16,
-):
+def run_gemm_rr(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages,
+    num_threads,
+    meta_dtype,
+):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py` around lines
958 - 973, The function run_gemm_rr currently defines defaults for num_stages,
num_threads, and meta_dtype which is inconsistent with sibling runners
(run_gemm_ss, run_gemm_rs, run_gemm_sr) and can mask caller mistakes; update
run_gemm_rr's signature to require num_stages, num_threads, and meta_dtype as
positional parameters (remove the defaults =3, =128, =T.int16) so callers must
pass them explicitly, mirroring the other runner functions and preventing silent
fallback.

311-336: 💤 Low value

Drop the redundant nested import tilelang.language as T.

Module-level line 9 already exposes T; the per-kernel re-imports inside matmul_rs (line 336), matmul_sr (line 625), and matmul_rr (line 916) shadow it with the same binding and are dead noise.

♻️ Remove the duplicate import (one example, apply to all three)
     E_shared_shape = (block_M, block_K // E_factor) if not trans_A else (block_K // E_factor, block_M)
-
-    import tilelang.language as T

     @T.prim_func

Also applies to: 600-625, 890-916

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py` around lines
311 - 336, Remove the redundant nested import statement "import
tilelang.language as T" that appears inside the matmul_rs function (and
similarly inside matmul_sr and matmul_rr); since T is already imported at module
level, delete the per-function re-imports so the functions use the module-level
T binding and avoid shadowing/redundant imports.
src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp (1)

345-380: ⚡ Quick win

Collapse the SparseSel::One integer specializations via static_assert to remove duplication.

For the four integer structs (SM80_16x8x32_S32S8S8S32_TN, SM80_16x8x64_S32S8S8S32_TN, SM80_16x8x32_S32U8U8S32_TN, SM80_16x8x64_S32U8U8S32_TN), the SparseSel::One specialization duplicates ~15 register aliases plus a 13-parameter fma signature solely to call CUTE_INVALID_CONTROL_PATH. Since there are no actual instantiations of these integer types with SparseSel::One in the codebase, a static_assert in the primary template provides the same safety with better compile-time diagnostics and eliminates the four specialization blocks entirely.

Additionally, the named parameters in those specializations are unused and generate -Wunused-parameter warnings.

♻️ Option A — replace specialization with static_assert in primary template (preferred)
 template <SparseSel spsel = SparseSel::Zero> struct SM80_16x8x32_S32S8S8S32_TN {
   using DRegisters = uint32_t[4];
   using ARegisters = uint32_t[2];
   using BRegisters = uint32_t[2];
   using CRegisters = uint32_t[4];
   using ERegisters = uint32_t[1];

   CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2,
                                    uint32_t &d3, uint32_t const &a0,
                                    uint32_t const &a1, uint32_t const &b0,
                                    uint32_t const &b1, uint32_t const &c0,
                                    uint32_t const &c1, uint32_t const &c2,
                                    uint32_t const &c3, uint32_t const &e) {
+    static_assert(spsel == SparseSel::Zero,
+                  "Integer sparse MMA only supports SparseSel::Zero");
 #if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED)
     ...
-template <> struct SM80_16x8x32_S32S8S8S32_TN<SparseSel::One> {
-  using DRegisters = uint32_t[4];
-  ...
-  CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2,
-                                   ...
-                                   uint32_t const &c3, uint32_t const &e) {
-    CUTE_INVALID_CONTROL_PATH(
-        "SM80_16x8x32_S32S8S8S32_TN with SparseSel::One is invalid");
-  }
-};

Apply analogously to the other three integer structs (lines 442, 500, 561).

♻️ Option B — keep the specializations but silence unused-parameter warnings
-  CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2,
-                                   uint32_t &d3, uint32_t const &a0,
-                                   uint32_t const &a1, uint32_t const &b0,
-                                   uint32_t const &b1, uint32_t const &c0,
-                                   uint32_t const &c1, uint32_t const &c2,
-                                   uint32_t const &c3, uint32_t const &e) {
+  CUTE_HOST_DEVICE static void fma(uint32_t &, uint32_t &, uint32_t &,
+                                   uint32_t &, uint32_t const &,
+                                   uint32_t const &, uint32_t const &,
+                                   uint32_t const &, uint32_t const &,
+                                   uint32_t const &, uint32_t const &,
+                                   uint32_t const &, uint32_t const &) {
     CUTE_INVALID_CONTROL_PATH(
         "SM80_16x8x32_S32S8S8S32_TN with SparseSel::One is invalid");
   }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp` around
lines 345 - 380, The four integer MMA templates duplicate a SparseSel::One
specialization that only calls CUTE_INVALID_CONTROL_PATH and causes
unused-parameter warnings; replace those specializations by adding a
static_assert in each primary template (e.g., inside SM80_16x8x32_S32S8S8S32_TN,
SM80_16x8x64_S32S8S8S32_TN, SM80_16x8x32_S32U8U8S32_TN,
SM80_16x8x64_S32U8U8S32_TN) that forbids SparseSel::One (e.g.,
static_assert(spsel != SparseSel::One, "…")) to provide the compile-time
diagnostic and then remove the corresponding SparseSel::One specialization
blocks (which duplicate many register typedefs and the long fma(...) signature);
if you prefer to keep the specializations instead, silence -Wunused-parameter
for the fma parameters in those specialization definitions.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/target/codegen_cuda.cc`:
- Around line 2769-2771: In the sparse WGMMA code paths (ptx_wgmma_ss /
ptx_wgmma_rs) the accumulator offset is being added after casting to uint32_t*,
which treats C_offset as 32-bit words instead of accumulator elements; change
the string construction so C_offset is added to C_data before the
reinterpret_cast (i.e., apply (C_data) + (C_offset) inside the cast operand) for
the occurrences matching the snippet "reinterpret_cast<uint32_t*>((C_data)) +
(C_offset), (scale_out), *reinterpret_cast<uint32_t*>((e_data) + (E_offset))"
(also update the other occurrence around lines ~2838-2839) so the offset uses
accumulator element units prior to the uint32_t* conversion.
- Around line 2523-2534: The current code reads meta_tvm_dtype from
op->args[12]->dtype (the metadata pointer) which yields the handle/pointer
dtype; instead derive the element dtype of the metadata buffer (the element type
E of the metadata array) and base MetaType on that element dtype, then map
DataType::UInt(8/16/32/64) to "uint8_t"/"uint16_t"/"uint32_t"/"uint64_t" into
the MetaType string; update the logic that sets meta_tvm_dtype (used when
computing MetaType and currently named meta_tvm_dtype / MetaType) to inspect the
buffer/element dtype of op->args[12] rather than the pointer/handle dtype so
tl::mma_sp_sync is instantiated with the correct metadata type.

In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py`:
- Line 60: The T.gemm_sp calls incorrectly pass trans_A as the 7th positional
argument (which binds to the transpose_E parameter); update each invocation of
T.gemm_sp (the calls in this file including the one at the shown location) to
supply the correct transpose_E value explicitly—either by replacing the 7th
positional argument with the intended boolean (e.g., False) or by using a named
argument transpose_E=False—to ensure the transpose_E parameter is set correctly
instead of reusing trans_A.

In `@tilelang/intrinsics/wgmma_sp_macro_generator.py`:
- Around line 220-235: Replace the debug print(...) calls in
wgmma_sp_macro_generator.py (the prints that emit elems_in_bytes,
self.block_row_warps, self.block_col_warps, warp_row_tiles, warp_col_tiles,
warp_k, warp_rows, warp_cols, M_DIM, n_dim, k_dim, wgmma_prefix, micro_size_x,
micro_size_y, micro_size_k, a_swizzle_mode, b_swizzle_mode,
a_leading_byte_offset, a_stride_byte_offset, b_leading_byte_offset,
b_stride_byte_offset, ak_atom_size, bk_atom_size, num_inst_m, num_inst_n,
accum_regs, self.e_dtype, self.SPARSE_FACTOR, self.SPARSE_SELECTOR,
a_swizzle_atom_elems, b_swizzle_atom_elems) with proper logging calls (e.g.,
logger.debug(...)) or remove them; ensure you reference the module/class logger
(create one via logging.getLogger(__name__) if none exists) so these messages
respect log levels and do not pollute stdout during normal compilation.

In `@tilelang/language/ast/ir.py`:
- Around line 1887-1888: The current references to ptx_wgmma_sp_ss and
ptx_wgmma_sp_rs in tilelang/language/ast/ir.py use _tir_op from upstream tvm.tir
which doesn't expose those symbols; change the import so _tir_op is the local
module that defines them (mirror the pattern in tilelang/language/tir/ir.py by
importing tilelang.language.tir.op as _tir_op) so ptx_wgmma_sp_ss and
ptx_wgmma_sp_rs resolve to the local implementations.

In `@tilelang/tileop/gemm_sp/__init__.py`:
- Around line 44-65: Remove the stray debug print statements that currently log
inside gemm_sp_infer_layout (the print showing type(self.A)) and inside lower
(the print of gemm_inst, impl_class, target, layout_map, thread_nums,
thread_var); either delete those print(...) calls or replace them with a proper
debug-level logger call (e.g., use an existing logging facility or tvm's
logging) so they don't print to stdout in production, ensuring the rest of the
methods (gemm_sp_infer_layout, lower, infer_layout, and the calls to
impl_class(...).infer_layout/lower) remain unchanged.

In `@tilelang/utils/sparse.py`:
- Around line 39-40: Remove the two leftover debug print statements inside the
_compress_fn function in sparse.py so they don't run on every `@tilelang.jit`
invocation; locate the prints that output f"{D=} {elem=} ..." and f"{[S, D *
elem // group]=}..." and delete them (or replace with a debug-level logger call
if telemetry is required) to avoid polluting stdout in user code.
- Around line 86-103: Callsites still use removed helpers and deprecated
parameters (compress_sm90, and compress(..., transposed=, arch=, block_k=));
update every invocation to call the new compress(A, meta_dtype=None) signature:
replace compress_sm90(...) with compress(...) and remove deprecated keyword args
(transposed, arch, block_k), pass only the tensor A and optionally meta_dtype
(converted via dtype string helpers if needed), and update imports to import
compress from tilelang.utils.sparse instead of compress_sm90; verify tests and
example calls (those that previously passed meta/arch/block info) supply
meta_dtype when required.
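The callsite migration described above can be sketched as below. Both functions here are illustrative stubs, not the real `tilelang.utils.sparse` implementations; they only show the signature change from the deprecated keywords to the new `compress(A, meta_dtype=None)` form.

```python
# Stub for the new-style API described in the comment: tensor plus an
# optional metadata dtype, no transposed/arch/block_k keywords.
def compress(A, meta_dtype=None):
    return ("compressed", A, meta_dtype)

# Before (deprecated, illustrative):
#   compress_sm90(A, transposed=False, arch="sm90", block_k=64)
# After: deprecated kwargs dropped, meta_dtype passed only when required.
out = compress([1, 0, 2, 0], meta_dtype="int16")
assert out == ("compressed", [1, 0, 2, 0], "int16")
```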

---

Nitpick comments:
In `@src/op/builtin.h`:
- Around line 372-381: Add full signature-style comments for the two new
intrinsics ptx_wgmma_sp_ss() and ptx_wgmma_sp_rs() mirroring the existing
ptx_wgmma_ss/ptx_wgmma_rs documentation: enumerate each argument in order (types
and brief meaning), explicitly document the extra sparse-related parameters (E
descriptor, E offset, sparse selector) and state the input count
(ptx_wgmma_sp_ss: 18 inputs; ptx_wgmma_sp_rs: 17 inputs) and why sp_ss has one
more argument than sp_rs; place these comments directly above the TVM_DLL const
Op &ptx_wgmma_sp_ss() and TVM_DLL const Op &ptx_wgmma_sp_rs() declarations so
callers can see the full signature without inspecting builtin.cc.

In `@src/op/gemm_sp.cc`:
- Around line 99-115: GetGemmSPInst lacks an explicit return after the ICHECK(0)
failure path which can trigger -Wreturn-type warnings; update
GemmSPNode::GetGemmSPInst to add a defensive return after ICHECK(0) (e.g.,
return a sensible default like GemmInst::kMMA or mirror other functions’
behavior) so the function always returns a GemmInst even if the ICHECK is
treated as non‑noreturn; modify the end of GetGemmSPInst to include this
explicit return to silence compiler warnings.
- Around line 145-180: The wrapper path in GemmSPNode::Lower can lose important
iter_values/predicate and reads/writes info when prim_func->body is not a
BlockRealize; either require/assert that the FFI always returns a BlockRealize
or populate the synthetic Block's iter_vars/reads/writes/predicate from the
original lowering context. Update GemmSPNode::Lower to (a) check prim_func->body
and if not a BlockRealize, derive and set appropriate iter_values and predicate
and fill reads/writes (using T.layout_map, T.thread_bounds/thread_var or any
access-region information available from prim_func or T) instead of using empty
arrays/const_true, or add a clear ICHECK/assert that the FFI must return a
BlockRealize with the correct metadata so downstream passes (access region
inference, buffer compaction) are not broken.

In `@src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp`:
- Around line 345-380: The four integer MMA templates duplicate a SparseSel::One
specialization that only calls CUTE_INVALID_CONTROL_PATH and causes
unused-parameter warnings; replace those specializations by adding a
static_assert in each primary template (e.g., inside SM80_16x8x32_S32S8S8S32_TN,
SM80_16x8x64_S32S8S8S32_TN, SM80_16x8x32_S32U8U8S32_TN,
SM80_16x8x64_S32U8U8S32_TN) that forbids SparseSel::One (e.g.,
static_assert(spsel != SparseSel::One, "…")) to provide the compile-time
diagnostic and then remove the corresponding SparseSel::One specialization
blocks (which duplicate many register typedefs and the long fma(...) signature);
if you prefer to keep the specializations instead, silence -Wunused-parameter
for the fma parameters in those specialization definitions.

In `@src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp`:
- Around line 25-201: The four near-identical structs
SM89_16x8x64_F32E4M3E4M3F32_TN, SM89_16x8x64_F32E4M3E5M2F32_TN,
SM89_16x8x64_F32E5M2E4M3F32_TN, and SM89_16x8x64_F32E5M2E5M2F32_TN differ only
by the fp8 operand suffix in the inline-asm and the diagnostic string; replace
them with a single parametric generator (either an X-macro or a template
wrapper) that takes the PTX suffix string (e.g., "e4m3.e4m3", "e4m3.e5m2", etc.)
and a short name fragment and emits the struct and its CUTE_INVALID_CONTROL_PATH
message, then update/replace the duplicated fma implementations to use that
generator so the asm literal and the invalid-path message are produced from the
single parameter. Ensure the generator exposes the same struct type names (or
typedefs) used elsewhere (or provide forwarding aliases) and preserves the fma
signature and static_assert(spsel == SparseSel::Zero).

In `@src/tl_templates/cuda/instruction/mma_sp.h`:
- Around line 93-101: Add a second compile-time check to ensure the register
counts match: in the same scope where the existing type equality static_assert
is located (the static_assert comparing typename Traits::DReg and typename
Traits::CReg inside exec), add a static_assert(Traits::kDRegs == Traits::kCRegs,
"tl::mma_sp_sync requires matching accumulator/output register counts"); this
will prevent mismatched std::make_index_sequence expansions in call_fma_sp_impl
and ensure Traits::kDRegs and Traits::kCRegs are equal before calling
call_fma_sp/Impl::fma.

In `@src/tl_templates/cuda/instruction/wgmma_sp.h`:
- Around line 44-48: The static_assert for CReg in tl::wgmma_sp_rs is redundant
because sizeof(uint32_t) == sizeof(float) on supported CUDA targets; update the
check to a single clear condition (e.g., require sizeof(CReg) ==
sizeof(uint32_t) to match the SS variant) by replacing the current "sizeof(CReg)
== sizeof(uint32_t) || sizeof(CReg) == sizeof(float)" assertion with a
single-size comparison referencing CReg and the tl::wgmma_sp_rs context so the
intent is unambiguous.

In `@testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py`:
- Around line 958-973: The function run_gemm_rr currently defines defaults for
num_stages, num_threads, and meta_dtype which is inconsistent with sibling
runners (run_gemm_ss, run_gemm_rs, run_gemm_sr) and can mask caller mistakes;
update run_gemm_rr's signature to require num_stages, num_threads, and
meta_dtype as positional parameters (remove the defaults =3, =128, =T.int16) so
callers must pass them explicitly, mirroring the other runner functions and
preventing silent fallback.
- Around line 311-336: Remove the redundant nested import statement "import
tilelang.language as T" that appears inside the matmul_rs function (and
similarly inside matmul_sr and matmul_rr); since T is already imported at module
level, delete the per-function re-imports so the functions use the module-level
T binding and avoid shadowing/redundant imports.

In `@tilelang/intrinsics/wgmma_sp_macro_generator.py`:
- Around line 504-516: The code creates temporary tuples by doing tuple(E_other)
+ (E_base0 + wk + mk, E_base1 + wi + mi) when indexing E_shared_buf; change both
branches of the conditional in the assignment to E_local_buf[...] to use
iterable unpacking instead (e.g., E_shared_buf[*E_other, E_base0 + wk + mk,
E_base1 + wi + mi] and E_shared_buf[*E_other, E_base0 + wi + mi, E_base1 + wk +
mk]) so you avoid allocating intermediate tuples; update the two places inside
the loop that reference E_other/E_shared_buf (the assignment to E_local_buf in
the for j loop) and keep the surrounding logic and return of
_warp_ldmatrix_e(E_local_buf, E_buf, inst_i, ki, thread_binding) intact.
- Around line 237-243: In the _warp_mma macro, the local unpacking uses tx but
that variable is unused; change the unpacked name from tx to _tx to satisfy the
linter and indicate intentional unused value (e.g., replace "tx, warp_n, warp_m
= self.extract_thread_binding(thread_binding)" with "_tx, warp_n, warp_m =
self.extract_thread_binding(thread_binding)" inside the _warp_mma definition),
leaving all other logic and variables (k_blocks, e_stage_elems, E_local, etc.)
unchanged.
- Around line 446-460: The ldmatrix_available flag in ldmatrix_e is hardcoded
False which leaves a TODO untracked; update ldmatrix_e to either compute
ldmatrix_available from the current conditions (e.g., based on a_dtype/e_dtype
and transposed state) or, if you cannot implement the fast path now, replace the
hardcoded assignment with a clear tracked marker (e.g., set ldmatrix_available =
False and add a FIXME(author_name) comment) and create a follow-up issue
referencing ldmatrix_e and ldmatrix_available so the optimization isn't lost;
ensure the change mentions the constraint (int8 + transposed case) and the
location (ldmatrix_e in wgmma_sp_macro_generator.py) so future work can
implement the fast path.
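The tuple-allocation point in the E-indexing comment can be demonstrated in isolation. The starred-subscript form `E_shared_buf[*E_other, i, j]` needs Python 3.11+; the equivalent tuple-display form below works on older interpreters and builds the same key without the intermediate `tuple(E_other) + (...)` concatenation. A plain dict stands in for the TIR buffer:

```python
# Dict keyed by index tuples, standing in for E_shared_buf.
E_shared_buf = {(0, 2, 5): 42}
E_other = (0,)
E_base0, wk, mk = 1, 0, 1
E_base1, wi, mi = 2, 0, 3

# before: explicit tuple concatenation allocates an intermediate tuple
old = E_shared_buf[tuple(E_other) + (E_base0 + wk + mk, E_base1 + wi + mi)]
# after: iterable unpacking inside a tuple display builds the key directly
new = E_shared_buf[(*E_other, E_base0 + wk + mk, E_base1 + wi + mi)]
assert old == new == 42
```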

In `@tilelang/tileop/gemm_sp/__init__.py`:
- Around line 56-65: Both infer_layout and lower call _select_gemm_instruction
and instantiate impl_class(self) separately, causing duplicate FFI calls and
extra object creation; cache the selected gemm_inst and/or the instantiated
implementation on self so both methods reuse the same value/instance: call
self._select_gemm_instruction(thread_nums, target) once (e.g. store as
self._cached_gemm_inst keyed by thread_nums/target or invalidate when inputs
change), use self._get_implementation_class(cached_inst, target) once and keep
the instantiated impl (instead of impl_class(self) twice) so infer_layout and
lower call the same impl object's infer_layout and lower methods.
- Around line 44-54: Replace the instance-method registrations with module-level
free-function wrappers consistent with the gemm pattern: add two top-level
functions (e.g., gemm_sp_infer_layout(gemm_sp: GemmSP, target: Target,
thread_bounds: Range) and gemm_sp_lower(gemm_sp: GemmSP, target: Target,
layout_map: dict, thread_bounds: Range, thread_var: tir.Var)) that call the
corresponding instance methods (gemm_sp.infer_layout(...) and
gemm_sp.lower(...)) and register those free functions with
tvm_ffi.register_global_func instead of the current instance-bound functions;
keep the same argument order and behavior (extract thread_nums =
thread_bounds.extent and forward to infer_layout/lower) so the C++ callers
continue to work unchanged.
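The caching suggestion above can be sketched as a memoized selection keyed by `(thread_nums, target)`; the class and method names mirror the review comment but the bodies are placeholders, not the real FFI-backed implementation:

```python
# Hypothetical sketch: cache the selected instruction/implementation so
# infer_layout and lower reuse one object instead of re-querying per call.
class GemmSP:
    def __init__(self):
        self._impl_cache = {}   # keyed by (thread_nums, target)
        self.ffi_calls = 0      # instrumentation for this sketch only

    def _select_gemm_instruction(self, thread_nums, target):
        self.ffi_calls += 1     # stands in for the FFI round-trip
        return f"inst<{thread_nums},{target}>"

    def _get_impl(self, thread_nums, target):
        key = (thread_nums, target)
        if key not in self._impl_cache:
            inst = self._select_gemm_instruction(thread_nums, target)
            self._impl_cache[key] = ("impl", inst)  # impl_class(self) in real code
        return self._impl_cache[key]

    def infer_layout(self, thread_nums, target):
        return self._get_impl(thread_nums, target)

    def lower(self, thread_nums, target):
        return self._get_impl(thread_nums, target)

op = GemmSP()
assert op.infer_layout(128, "sm90") is op.lower(128, "sm90")
assert op.ffi_calls == 1  # one selection call, shared by both phases
```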

In `@tilelang/tileop/gemm_sp/gemm_sp_wgmma.py`:
- Around line 138-146: The trailing raise ValueError is only reachable for
unsupported gemm types (sr/rr) but the current structure returns inside the
is_gemm_ss() and is_gemm_rs() branches, making control flow less explicit;
update the logic in the function that defines _gemm_ssr/_gemm_rsr to first check
supported cases (e.g., if not (self.is_gemm_ss() or self.is_gemm_rs()): raise
ValueError(...)) and then use explicit if/elif for self.is_gemm_ss() and
self.is_gemm_rs() to return _Simplify(_gemm_ssr, inline_let=True) or
_Simplify(_gemm_rsr, inline_let=True) respectively, mirroring infer_layout’s
else: raise pattern and keeping symbols _gemm_ssr, _gemm_rsr, _Simplify,
is_gemm_ss, is_gemm_rs unchanged.
- Around line 73-80: The lower method's signature is incorrect: change the type
hint for thread_nums in gemm_sp_wgmma.py lower(self, ...) from Range to int |
tir.PrimExpr (or tir.PrimExpr) to match callers that pass thread_bounds.extent,
and either remove the unused parameter mbar_phase_expr or propagate it through
the function (use it where barriers/phases are computed) so it is referenced;
update the signature and all internal references to use the new thread_nums type
and delete mbar_phase_expr if you choose to drop it, ensuring callers (e.g., in
tilelang/tileop/gemm_sp/__init__.py) remain compatible.
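The guard-first control flow asked for above looks roughly like this (the class and returned strings are illustrative stand-ins, not the real GemmSP WGMMA lowering):

```python
# Sketch of the guard-first structure: reject unsupported sr/rr cases up
# front, then dispatch explicitly, mirroring infer_layout's else/raise.
class Lowerer:
    def __init__(self, kind):
        self.kind = kind

    def is_gemm_ss(self):
        return self.kind == "ss"

    def is_gemm_rs(self):
        return self.kind == "rs"

    def lower(self):
        if not (self.is_gemm_ss() or self.is_gemm_rs()):
            raise ValueError(f"unsupported gemm_sp kind: {self.kind}")
        if self.is_gemm_ss():
            return "gemm_ssr"   # _Simplify(_gemm_ssr, inline_let=True) in real code
        else:
            return "gemm_rsr"   # _Simplify(_gemm_rsr, inline_let=True) in real code

assert Lowerer("ss").lower() == "gemm_ssr"
assert Lowerer("rs").lower() == "gemm_rsr"
```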

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b1de2664-45b9-438b-9341-46deb38c38c9

📥 Commits

Reviewing files that changed from the base of the PR and between eb55efe and edcfc0f.

📒 Files selected for processing (21)
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/op/gemm_sp.cc
  • src/op/gemm_sp.h
  • src/target/codegen_cuda.cc
  • src/target/codegen_cuda.h
  • src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp
  • src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp
  • src/tl_templates/cuda/instruction/mma_sp.h
  • src/tl_templates/cuda/instruction/wgmma_sp.h
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
  • tilelang/intrinsics/mma_sp_layout.py
  • tilelang/intrinsics/mma_sp_macro_generator.py
  • tilelang/intrinsics/wgmma_sp_macro_generator.py
  • tilelang/language/ast/ir.py
  • tilelang/language/tir/ir.py
  • tilelang/language/tir/op.py
  • tilelang/tileop/gemm_sp/__init__.py
  • tilelang/tileop/gemm_sp/gemm_sp_mma.py
  • tilelang/tileop/gemm_sp/gemm_sp_wgmma.py
  • tilelang/utils/sparse.py

Comment thread src/backend/cuda/codegen/codegen_cuda.cc Outdated
Comment thread src/backend/cuda/codegen/codegen_cuda.cc
Comment thread testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py Outdated
Comment thread tilelang/intrinsics/wgmma_sp_macro_generator.py Outdated
Comment thread tilelang/language/ast/ir.py
Comment thread tilelang/tileop/gemm_sp/__init__.py
Comment thread tilelang/utils/sparse.py Outdated
Comment thread tilelang/utils/sparse.py
botbw and others added 8 commits May 11, 2026 04:50
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The constants (SPARSE_PARAMS, get_e_factor, get_e_replicate_factor) are
hardware-level parameters primarily consumed by mma_sp_macro_generator;
intrinsics/ is the natural home. utils/sparse.py re-exports them for
backward compatibility.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
"params" better reflects the content: hardware sparsity format parameters
(SPARSE_PARAMS, get_e_factor, get_e_replicate_factor), not runtime config.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
from __future__ import annotations makes all annotations lazy strings,
which breaks TIR prim_func's get_type_hints() on the inner kernel.
Use Optional[torch.dtype] instead of the union syntax to avoid needing it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 5

🧹 Nitpick comments (3)
tilelang/intrinsics/wgmma_sp_macro_generator.py (3)

89-102: 💤 Low value

Unused n_dim parameter in _initialize_wgmma_prefix.

The parameter n_dim is accepted (and passed from __init__ as self.n_dim at Line 79) but never read — inst_n is computed from self.warp_col_tiles. Either remove the parameter or actually drive inst_n from it; the current shape is misleading and risks future callers assuming the argument has effect.

Proposed change
-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
         inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)

…and update the call site at Line 79 accordingly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 89 - 102, The
method _initialize_wgmma_prefix currently ignores its n_dim parameter (inst_n is
computed from self.warp_col_tiles), so either remove the unused parameter and
stop passing self.n_dim at the call site in __init__, or make n_dim actually
drive inst_n (e.g., replace inst_n = gcd(self.warp_col_tiles, 256) with an
expression using n_dim such as inst_n = gcd(n_dim, 256) or inst_n = n_dim
followed by the same validation/asserts); update the __init__ call accordingly
to match the chosen approach and keep the existing validations for inst_n in
_initialize_wgmma_prefix.

38-39: 💤 Low value

Prefer Optional[Layout] and instance-level initialization.

a_shared_layout: Layout = None / b_shared_layout: Layout = None are declared at class scope with non-Optional type annotations bound to None. Type checkers will flag the mismatch, and the attributes effectively become shared class state until an instance assigns them. Initializing them in __init__ (or annotating as Optional[Layout] = None and assigning in __init__) avoids both pitfalls.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 38 - 39, The
attributes a_shared_layout and b_shared_layout are declared at class scope with
non-Optional type annotations but set to None, creating a type mismatch and
shared class state; update the class to either annotate them as Optional[Layout]
= None or (preferably) remove the class-level assignments and initialize
self.a_shared_layout and self.b_shared_layout inside the class __init__ method
(use the Layout type for annotations on instance attributes) so each instance
gets its own Layout fields and the type checker is satisfied.

437-469: 🏗️ Heavy lift

Track the ldmatrix TODO.

The `ldmatrix_available = False  # TODO: use ldmatrix when possible` shortcut forces the slow per-thread elementwise metadata load path for every sparse WGMMA call. Worth filing a follow-up so this doesn't get lost — sparse WGMMA E-loads are on the hot path for every K-block.

Want me to open a tracking issue with a short summary of the supported (e_dtype, a_dtype) permutations and the CUTLASS reference (include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h) for whoever picks this up?

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/wgmma_sp_macro_generator.py` around lines 437 - 469, Open
a tracking issue for the TODO that hardcodes ldmatrix_available = False in
wgmma_sp_macro_generator (variable name ldmatrix_available) so we don't lose
this perf regression; in the issue include a short summary of the supported
(e_dtype, a_dtype) permutations (matching the metadata_* helper layouts used
here: metadata_8bit_load_32x4_to_shared_16x8_layout_8bit,
metadata_8bit_load_32x4_to_shared_16x4_layout_16bit,
metadata_8bit_load_32x4_to_shared_16x4_layout_32bit,
metadata_16bit_load_32x2_to_shared_16x4_layout_8bit,
metadata_16bit_load_32x2_to_shared_16x2_layout_16bit,
metadata_16bit_load_32x2_to_shared_16x2_layout_32bit,
metadata_32bit_load_32x1_to_shared_16x2_layout_8bit) and reference the CUTLASS
header include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h; add the
issue number back into the source comment next to the TODO so future authors can
track progress.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/deeplearning_operators/matmul_sparse.md`:
- Around line 67-69: The code uses the non-existent
SparseTensorCoreIntrinEmitter.E_FACTOR_MAP to obtain E_factor; replace that with
a call to get_e_factor(in_dtype, metadata_dtype) and add an import for
get_e_factor from tilelang.intrinsics.sparse_params at the top of the example;
specifically change the assignment where E_factor is computed (currently
referencing E_FACTOR_MAP) to call get_e_factor with the local variables in_dtype
and metadata_dtype so the example uses the correct API.

In `@testing/python/utils/test_compress_utils.py`:
- Around line 135-136: The ad-hoc main entry calls the pytest-parametrized test
function test_compress() without providing required parameters (dtype,
meta_dtype), causing a TypeError when run directly; remove the "__main__" block
or replace it with a pytest invocation (e.g., call pytest.main([...])) so the
file is executed by pytest rather than directly calling test_compress, and
ensure test_compress is only run through pytest's runner; reference the test
function name test_compress to locate the faulty block.
- Around line 49-60: The call to T.gemm_sp currently passes transpose flags
positionally as (trans_A, trans_B, trans_A), which is error-prone; update the
call to use named parameters so the mapping is explicit (e.g.,
transpose_A=trans_A, transpose_B=trans_B, transpose_E=trans_A) in the T.gemm_sp
invocation to match the copy logic for A and E and improve readability and
maintainability.

In `@tilelang/utils/sparse.py`:
- Around line 63-68: The code writes nz_idx[nz_count] before checking the
device_assert, which can cause out-of-bounds writes; change the loop around
nz_count/nz_idx (the T.if_then_else stores that reference, the
dense_local[local_idx + i] check, and the T.device_assert) so the store is
guarded: only write nz_idx when nz_count < elem (e.g., use a conditional that
writes to nz_idx[min(nz_count, elem-1)] or wrap the store in an if/conditional
that checks nz_count < elem), and separately increment an overflow counter if
you still want to assert on excess nonzeros; apply the same fix to the second
occurrence referenced around lines 135-140 so both loops avoid indexing past
nz_idx.
- Around line 19-20: The default meta dtype is set to T.int16 which diverges
from torch_compress's behavior (torch treats int8 metadata as int32); change the
module-wide constant _DEFAULT_META_DTYPE from T.int16 to T.int32 and update any
fallback logic in the compress-related code paths (see references to
_DEFAULT_META_DTYPE and the compress function blocks around lines ~189-191 and
~233-243) so that omissions of meta_dtype use T.int32 by default, only falling
back to narrower types when explicitly required.
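The guarded-store fix for `nz_idx` can be sketched in plain Python (lists stand in for the TIR buffers; the helper name and 2:4 framing are illustrative):

```python
# Guard the store before indexing so nz_idx is never written past its
# capacity, and count overflow separately so bad input still trips an assert.
def collect_nonzeros(dense_local, elem):
    nz_idx = [0] * elem
    nz_count = 0
    overflow = 0
    for i, v in enumerate(dense_local):
        if v != 0:
            if nz_count < elem:       # guard before the store
                nz_idx[nz_count] = i
                nz_count += 1
            else:
                overflow += 1         # defer the failure to the assert below
    assert overflow == 0, "more nonzeros than the sparsity pattern allows"
    return nz_idx[:nz_count]

assert collect_nonzeros([0, 3, 0, 7], elem=2) == [1, 3]
```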

---

Nitpick comments:
In `@tilelang/intrinsics/wgmma_sp_macro_generator.py`:
- Around line 89-102: The method _initialize_wgmma_prefix currently ignores its
n_dim parameter (inst_n is computed from self.warp_col_tiles), so either remove
the unused parameter and stop passing self.n_dim at the call site in __init__,
or make n_dim actually drive inst_n (e.g., replace inst_n =
gcd(self.warp_col_tiles, 256) with an expression using n_dim such as inst_n =
gcd(n_dim, 256) or inst_n = n_dim followed by the same validation/asserts);
update the __init__ call accordingly to match the chosen approach and keep the
existing validations for inst_n in _initialize_wgmma_prefix.
- Around line 38-39: The attributes a_shared_layout and b_shared_layout are
declared at class scope with non-Optional type annotations but set to None,
creating a type mismatch and shared class state; update the class to either
annotate them as Optional[Layout] = None or (preferably) remove the class-level
assignments and initialize self.a_shared_layout and self.b_shared_layout inside
the class __init__ method (use the Layout type for annotations on instance
attributes) so each instance gets its own Layout fields and the type checker is
satisfied.
- Around line 437-469: Open a tracking issue for the TODO that hardcodes
ldmatrix_available = False in wgmma_sp_macro_generator (variable name
ldmatrix_available) so we don't lose this perf regression; in the issue include
a short summary of the supported (e_dtype, a_dtype) permutations (matching the
metadata_* helper layouts used here:
metadata_8bit_load_32x4_to_shared_16x8_layout_8bit,
metadata_8bit_load_32x4_to_shared_16x4_layout_16bit,
metadata_8bit_load_32x4_to_shared_16x4_layout_32bit,
metadata_16bit_load_32x2_to_shared_16x4_layout_8bit,
metadata_16bit_load_32x2_to_shared_16x2_layout_16bit,
metadata_16bit_load_32x2_to_shared_16x2_layout_32bit,
metadata_32bit_load_32x1_to_shared_16x2_layout_8bit) and reference the CUTLASS
header include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h; add the
issue number back into the source comment next to the TODO so future authors can
track progress.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ebf491c4-e1dc-42b8-8d46-bd736f0c693e

📥 Commits

Reviewing files that changed from the base of the PR and between edcfc0f and cd52cc5.

📒 Files selected for processing (15)
  • benchmark/matmul/benchmark_matmul_sp.py
  • benchmark/matmul/benchmark_matmul_sp_compress.py
  • docs/deeplearning_operators/matmul_sparse.md
  • examples/gemm_sp/example_custom_compress.py
  • examples/gemm_sp/example_gemm_sp.py
  • examples/gemm_sp/test_example_gemm_sp.py
  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
  • testing/python/utils/test_compress_utils.py
  • tilelang/intrinsics/mma_sp_macro_generator.py
  • tilelang/intrinsics/sparse_params.py
  • tilelang/intrinsics/wgmma_sp_macro_generator.py
  • tilelang/tileop/gemm_sp/__init__.py
  • tilelang/tileop/gemm_sp/gemm_sp_wgmma.py
  • tilelang/utils/sparse.py
💤 Files with no reviewable changes (2)
  • examples/gemm_sp/test_example_gemm_sp.py
  • examples/gemm_sp/example_custom_compress.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • tilelang/tileop/gemm_sp/gemm_sp_wgmma.py
  • tilelang/tileop/gemm_sp/__init__.py
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py

Comment thread docs/deeplearning_operators/matmul_sparse.md Outdated
Comment thread testing/python/utils/test_compress_utils.py Outdated
Comment thread testing/python/utils/test_compress_utils.py Outdated
Comment thread tilelang/utils/sparse.py
Comment thread tilelang/utils/sparse.py
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/backend/cuda/op/gemm_sp.cc`:
- Around line 119-138: When falling back in
policy.isFullRow()/policy.isFullCol(), the code must pick warp factors that are
divisors of num_warps rather than using truncated integer division; change the
logic that computes max_m_warps/max_n_warps so it searches for the largest
divisor d of num_warps satisfying d * kMPerWarp <= M (for m_warp) or d *
k_n_per_warp <= N (for n_warp), then set the other axis to num_warps / d (and
clamp to 1 if needed). Apply the same divisor-based selection in the symmetric
fallback block referenced around lines 196-213 as well; use the variables
m_warp, n_warp, num_warps, kMPerWarp and k_n_per_warp and the
policy.isFullRow()/policy.isFullCol() branches to locate the spots to change.
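The divisor-based selection the comment describes (the real fix belongs in the C++ fallback) can be illustrated with a short Python sketch; the helper name and the concrete numbers are examples, not code from the repo:

```python
def largest_valid_divisor(num_warps, per_warp, extent):
    """Largest divisor d of num_warps with d * per_warp <= extent (else 1)."""
    best = 1
    for d in range(1, num_warps + 1):
        if num_warps % d == 0 and d * per_warp <= extent:
            best = d
    return best

# 6 warps, 16 rows per warp, M = 80: truncated division would pick
# min(M // 16, 6) = 5, which does not divide 6, leaving m_warp * n_warp != 6.
num_warps, k_m_per_warp, M = 6, 16, 80
m_warp = largest_valid_divisor(num_warps, k_m_per_warp, M)
n_warp = num_warps // m_warp
assert (m_warp, n_warp) == (3, 2)
assert m_warp * n_warp == num_warps  # the invariant the divisor search preserves
```

The symmetric `isFullCol()` branch would run the same search along N with `k_n_per_warp`.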

In `@src/op/gemm_sp.cc`:
- Around line 256-267: The op is being marked as tcgen05 by setting
ann.Set("is_tcgen05", ...) in the TLOpBuilder for
TVM_REGISTER_OP("tl.tileop.tcgen05_gemm_sp"), but the CUDA selector still routes
to FatalTcgen5Unavailable; change the builder so it only sets the is_tcgen05
annotation (or registers the tcgen05 variant) when the backend actually supports
tcgen05 lowering: e.g., gate the ann.Set("is_tcgen05", ...) behind a
runtime/compile-time capability check (a new isTcgen05Supported() or existing
backend query) or avoid registering the tcgen05 name variant until support is
added; keep the GemmSP(...) call unchanged except for omitting the is_tcgen05
flag when unsupported to prevent hard failures during lowering.

In `@testing/python/utils/test_compress_utils.py`:
- Line 77: The ternary size expressions are redundant because both branches are
(N, N) (K == N), so simplify the torch.randint calls by removing the conditional
and using size=(N, N) directly; update occurrences that use the pattern (e.g.,
the B construction line referencing trans_B and the similar line at 80) to call
torch.randint(size=(N, N), low=low, high=high, dtype=in_dtype, device="cuda") to
eliminate the useless ternary.

In `@tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py`:
- Around line 148-149: The swizzle detection is being called with BufferRegion
objects (A_region, B_region) but _determinate_swizzle_mode builds comparison
layouts from a tir.Buffer, causing incorrect NONE fallbacks for swizzled shared
layouts; update the calls that compute a_swizzle_mode and b_swizzle_mode to pass
the underlying tir.Buffer (e.g., A_region.buffer or B_region.buffer or the
attribute that holds the Buffer) and the appropriate region/view information
into _determinate_swizzle_mode (or add an overload) so the helper receives the
actual tir.Buffer used to construct comparison layouts; apply the same change to
the other occurrence that invokes _determinate_swizzle_mode for B (the similar
call around the later occurrence).
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b325a80d-9693-4056-9d4a-97f07b1a3fdb

📥 Commits

Reviewing files that changed from the base of the PR and between cd52cc5 and bd11d73.

📒 Files selected for processing (16)
  • src/backend/cuda/codegen/codegen_cuda.cc
  • src/backend/cuda/codegen/codegen_cuda.h
  • src/backend/cuda/op/gemm_sp.cc
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/op/gemm.cc
  • src/op/gemm_sp.cc
  • src/op/gemm_sp.h
  • src/transform/lower_opaque_block.cc
  • testing/python/issue/test_tilelang_issue_tma_no_ws.py
  • testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
  • testing/python/utils/test_compress_utils.py
  • tilelang/cuda/intrinsics/layout/mma_sp_layout.py
  • tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py
  • tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py
💤 Files with no reviewable changes (1)
  • testing/python/issue/test_tilelang_issue_tma_no_ws.py
✅ Files skipped from review due to trivial changes (2)
  • src/backend/cuda/codegen/codegen_cuda.h
  • src/transform/lower_opaque_block.cc
🚧 Files skipped from review as they are similar to previous changes (3)
  • src/op/builtin.h
  • src/op/builtin.cc
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py

Comment thread src/backend/cuda/op/gemm_sp.cc
Comment thread src/op/gemm_sp.cc
Comment thread testing/python/utils/test_compress_utils.py
Comment on lines +148 to +149
a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Pass the underlying buffer into swizzle detection.

These calls hand _determinate_swizzle_mode() a BufferRegion, but the helper builds comparison layouts from a tir.Buffer. On swizzled shared layouts that can silently fall back to SwizzleMode.NONE and produce the wrong WGMMA descriptor.

Suggested fix
-        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
...
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)

Also applies to: 330-330

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py` around lines 148
- 149, The swizzle detection is being called with BufferRegion objects
(A_region, B_region) but _determinate_swizzle_mode builds comparison layouts
from a tir.Buffer, causing incorrect NONE fallbacks for swizzled shared layouts;
update the calls that compute a_swizzle_mode and b_swizzle_mode to pass the
underlying tir.Buffer (e.g., A_region.buffer or B_region.buffer or the attribute
that holds the Buffer) and the appropriate region/view information into
_determinate_swizzle_mode (or add an overload) so the helper receives the actual
tir.Buffer used to construct comparison layouts; apply the same change to the
other occurrence that invokes _determinate_swizzle_mode for B (the similar call
around the later occurrence).

@botbw botbw changed the title [Refactor] Refactor gemm_sp and update documentation [Backend] Refactor gemm_sp May 12, 2026
@botbw
Contributor Author

botbw commented May 12, 2026

@LeiWang1999 This refactoring is ready to be reviewed, thanks :p

@LeiWang1999 LeiWang1999 self-requested a review May 12, 2026 15:55