
[Refactor] Refactor multiple TensorCoreIntrinEmitter to provide atom-level mma control interface#2161

Merged

LeiWang1999 merged 6 commits into tile-ai:main from Rachmanino:mma-atom on May 9, 2026

Conversation

@Rachmanino (Collaborator) commented May 7, 2026

  • Enhance the intrinsics modules by adding WGMMATensorCoreIntrinEmitter and TCGEN05TensorCoreIntrinEmitter.
  • Introduce descriptor parameter classes for WGMMA and TCGEN05, improving initialization and atom-offset computation.
  • Update TensorCoreIntrinEmitter to support atom-level interfaces and macro definitions for efficient matrix-multiplication operations.
  • Provide examples that use MMA atoms to write GEMM kernels (see the sketch after this list).
  • Pass CI and regression tests.
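
For orientation, here is a minimal sketch of the kind of atom-level loop this interface enables, pieced together from the WGMMA calls quoted in the reviews below. `emi` stands for an assumed WGMMATensorCoreIntrinEmitter instance, and A_shared, B_shared, C_local, desc_a, desc_b, num_inst_m, num_inst_n, and num_k_atoms are assumed to be set up as in the test kernel; the argument order follows the quoted test code and is illustrative, not a verified final API:

    # inside the kernel body, after the shared-memory tiles are loaded
    a_params = emi.compute_wgmma_a_desc_params(A_shared)  # pure Python, no TIR emitted
    b_params = emi.compute_wgmma_b_desc_params(B_shared)
    T.clear(C_local)                # zero the accumulator fragment once
    emi.wgmma_fence_c(C_local)      # warpgroup fence on the accumulator
    emi.wgmma_arrive()
    for n in T.unroll(num_inst_n):
        for m in T.unroll(num_inst_m):
            for ki in T.unroll(num_k_atoms):
                # one WGMMA atom per (m, n, ki) tile
                emi.wgmma_ss_atom(desc_a, desc_b, C_local, m, n, ki,
                                  a_params, b_params, T.bool(True))
    emi.wgmma_commit()
    emi.wgmma_wait()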

Summary by CodeRabbit

  • New Features

    • Exposed new tensor core emitter variants and descriptor parameters for SM90 and SM100 GPU architectures.
    • Added support for Hopper (SM90) and Blackwell (SM100) tensor core instructions alongside existing Ampere support.
  • Tests

    • Added comprehensive test coverage for tensor core GEMM kernels across SM80+, SM90, and SM100 architectures, validating kernel correctness against reference implementations.

@github-actions Bot commented May 7, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai Bot (Contributor) commented May 7, 2026

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR refactors tensor-core macro generators (MMA, TCGEN05, WGMMA) to separate descriptor parameter computation and initialization from atom-level instruction emission. It introduces frozen dataclasses for descriptor parameters, pure-Python helpers for offset/swizzle calculation, and reusable atom-level emitters; top-level macros now orchestrate descriptor setup once and delegate per-atom PTX emission. A comprehensive test module validates three GPU GEMM codegen paths (SM80+ MMA, SM90 WGMMA, SM100 TCGEN05).
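
To make the descriptor-dataclass side of this split concrete, below is a sketch of the WGMMA parameter container implied by the constructor calls quoted later in this review. The field names match those calls; the type annotations and comments are assumptions:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class WGMMADescriptorParams:
        """Precomputed, immutable settings for one WGMMA shared-memory descriptor."""
        swizzle_mode: int          # encoded swizzle mode of the operand layout
        leading_byte_offset: int   # LBO, stored already shifted right by 4 (16-byte units)
        stride_byte_offset: int    # SBO, stored already shifted right by 4 (16-byte units)
        swizzle_atom_elems: int    # elements covered by one swizzle atom
        k_atom_size: int           # K extent of one atom, in micro-k units
        elems_in_bytes: int        # operand element width in bytes
        is_k_major: bool           # whether the operand is K-major in shared memory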

Changes

Tensor Core Macro Refactoring

  • Descriptor Dataclasses — tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py, tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py: Introduce TCGEN05DescriptorParams and WGMMADescriptorParams frozen dataclasses capturing precomputed descriptor settings (swizzle mode, byte offsets, K-major flag, atom/element sizing).
  • Atom Geometry Properties — tilelang/cuda/intrinsics/macro/mma_macro_generator.py, tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py, tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py: Expose atom-grid properties: mma_num_inst_m/n, tcgen05_meta_unpacked, tcgen05_num_inst_m/n/k_atoms, wgmma_num_inst_m/n/k_atoms/a_regs/accum_regs. Add a _access_ptr_from static resolver for Buffer/BufferLoad/BufferRegion inputs.
  • TCGEN05 Descriptor Compute & Init — tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py: Pure-Python helpers compute_tcgen05_a/b_desc_params calculate descriptor parameters from buffers/meta; TIR-emitting init_tcgen05_a/b_desc initialize pre-allocated descriptors; compute_tcgen05_instr_desc derives the 64-bit instruction descriptor; tcgen05_atom_arrive wraps arrive sequencing.
  • TCGEN05 Atom Emitters — tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py: Introduce tcgen05_ss_atom, tcgen05_ts_atom, and tcgen05_blockscaled_atom, which compute per-atom byte offsets from precomputed descriptor params and issue T.ptx_tcgen05_mma_* instructions with scale-out masking.
  • TCGEN05 Macro Refactoring — tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py: Refactor tcgen05mma_ss, tcgen05mma_ts, and tcgen05mma_blockscaled to precompute descriptor params and instruction descriptors once, initialize descriptors via helpers, then emit via atom-level emitters in nested unroll loops instead of inlining offset/descriptor math.
  • WGMMA Dataclass Import — tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py: Add from dataclasses import dataclass to support the WGMMADescriptorParams definition.
  • WGMMA Descriptor Compute & Init — tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py: Pure-Python helpers compute_wgmma_a/b_desc_params calculate descriptor parameters with K-major/non-K-major branching; TIR-emitting init_wgmma_a/b_desc initialize pre-allocated descriptors; warpgroup sync wrappers (wgmma_fence_a/c, wgmma_arrive, wgmma_commit, wgmma_wait) emit TIR calls parameterized by register counts.
  • WGMMA Atom Emitters — tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py: Introduce wgmma_ss_atom and wgmma_rs_atom, which compute per-atom A/B/C offsets (including swizzle-dependent branching), derive scale-out behavior, and emit T.ptx_wgmma_ss or T.ptx_wgmma_rs instructions.
  • WGMMA Macro Refactoring — tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py: Refactor the wgmma (SS) and wgmma_rs (RS) macros to precompute descriptor params once, initialize descriptors, perform warpgroup fence/arrive/commit/wait sequencing, and delegate per-atom emission to the atom emitters instead of inlining offset/descriptor setup.
  • MMA Atomization — tilelang/cuda/intrinsics/macro/mma_macro_generator.py: Refactor mma() from direct PTX emission into a warp-level macro looping over (warp_rows, warp_cols) and calling mma_atom for each (i, j) pair; move the replicate_b logic inside mma_atom, where the second T.ptx_mma offsets are adjusted.
  • Package-Level Exports — tilelang/intrinsics/__init__.py: Re-export emitter aliases (WGMMATensorCoreIntrinEmitter, TCGEN05TensorCoreIntrinEmitter) and descriptor types (WGMMADescriptorParams, TCGEN05DescriptorParams) from the macro generators for the public API surface.
  • Tests: Atom-level GEMM Kernels — testing/python/language/test_tilelang_language_atom_mma.py: Add a comprehensive test module with shared-layout helpers (make_swizzle_layout, infer_wgmma_shared_layout), kernel builders (make_mma_atom_kernel, make_wgmma_atom_kernel, make_tcgen05_atom_kernel), and test functions (test_mma_atom_gemm, test_wgmma_atom_gemm, test_tcgen05_atom_gemm) that compile kernels, assert generated-source instruction strings, and validate numerical correctness against float GEMM references.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • tile-ai/tilelang#1945: Both PRs refactor tensor-core macro emitters (TensorCoreIntrinEmitter) and TCGEN05/WGMMA macro generators, introducing descriptor params and atom-level emission helpers.
  • tile-ai/tilelang#1323: Both PRs touch TCGEN05/WGMMA tensor-core codepaths and descriptor/metadata handling in macro generators.
  • tile-ai/tilelang#2165: Both PRs introduce WGMMA/TCGEN05 emitter re-exports and new atom-level emitter APIs to the public surface.

Suggested reviewers

  • LeiWang1999

Poem

🐰 Atoms dance in ordered rows,
Descriptors precomputed, parameters flow,
Where WGMMA meets TCGEN and MMA's song,
Each macro delegates, atoms march along!
Tests validate three paths on GPU shores,

🚥 Pre-merge checks | ✅ 4 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 54.07%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Linked Issues check — ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check — ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Title check — ✅ Passed: The title accurately describes the main objective: refactoring TensorCoreIntrinEmitter classes to provide atom-level MMA control interfaces, which aligns with all file changes.
  • Description check — ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.


@coderabbitai Bot left a comment (Contributor)

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/intrinsics/tcgen05_macro_generator.py (1)

332-357: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Non-swizzled A descriptors collapse to zero-width atoms.

These paths compute swizzle_byte_size() // elems_in_bytes for A even when SwizzleMode.NONE. For fp16/bf16 that becomes 1 // 2 == 0, which then propagates into swizzle_atom_elems, k_atom_size, and the A offset formulas used by the atom emitters. The default linear-layout A path will therefore generate bogus descriptor offsets. Please special-case the NONE path to a non-zero linear-layout atom size before constructing TCGEN05DescriptorParams.

Also applies to: 654-700

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/tcgen05_macro_generator.py` around lines 332 - 357, The A
descriptor path can produce zero-width atoms when a_swizzle_mode.is_none()
because swizzle_byte_size() // elems_in_bytes can be 0 for types like fp16/bf16;
change the computation of a_swizzle_atom_elems to ensure a non-zero
linear-layout atom size (e.g. set a_swizzle_atom_elems = 1 or max(1,
a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) when
a_swizzle_mode.is_none()) before using it to compute a_leading_byte_offset,
a_stride_byte_offset, k_atom_size and constructing TCGEN05DescriptorParams (and
do the equivalent fix for the B-path where b_swizzle_atom_elems is computed);
apply the same guard in the other occurrence referenced (lines ~654-700) so
non-swizzled descriptors never collapse to zero-width atoms.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6425585c-b2a5-442e-8c19-31977d79ac0e

📥 Commits

Reviewing files that changed from the base of the PR and between 40acae4 and f602c71.

📒 Files selected for processing (4)
  • tilelang/intrinsics/__init__.py
  • tilelang/intrinsics/mma_macro_generator.py
  • tilelang/intrinsics/tcgen05_macro_generator.py
  • tilelang/intrinsics/wgmma_macro_generator.py

Comment on lines +217 to +222
num_inst_m = self.tcgen05_num_inst_m
num_inst_n = self.tcgen05_num_inst_n
num_k_atoms = self.tcgen05_num_k_atoms
a_params = self.compute_tcgen05_a_desc_params(A_buf)
b_params = self.compute_tcgen05_b_desc_params(B_buf)
instr_desc = self.compute_tcgen05_instr_desc()

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Initialize TCGEN05 meta before using the new atom-count helpers.

Lines 217-222 and 264-268 now read tcgen05_num_inst_* / compute_tcgen05_instr_desc() before any local meta bootstrap. On a fresh emitter, that reaches tcgen05_meta_unpacked before self.meta exists, so tcgen05mma_ss() and tcgen05mma_ts() can fail before emitting any TIR. Please add the same lazy get_tcgen5_mma_meta(...) guard that tcgen05mma_blockscaled() already has.

Also applies to: 264-268

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/tcgen05_macro_generator.py` around lines 217 - 222, The
code reads tcgen05_num_inst_m / tcgen05_num_inst_n /
compute_tcgen05_instr_desc() before ensuring TCGEN05 meta is initialized, which
can access tcgen05_meta_unpacked before self.meta exists; wrap the early uses
with the same lazy initialization used by tcgen05mma_blockscaled() by calling
get_tcgen5_mma_meta(...) (or the equivalent lazy guard) before any access to
tcgen05_num_inst_* or compute_tcgen05_instr_desc(), and apply the same guard to
the second occurrence around the block that currently calls those helpers at the
later location (the block around compute_tcgen05_instr_desc() at the other
site).

Comment on lines +600 to +652
def compute_tcgen05_b_desc_params(self, B_buf) -> TCGEN05DescriptorParams:
    """Compute B descriptor parameters from the B shared buffer.

    This is a pure-Python helper -- no TIR code is emitted.
    The returned ``TCGEN05DescriptorParams`` is passed to
    ``init_tcgen05_b_desc()`` and ``tcgen05_*_atom()`` methods.

    Parameters
    ----------
    B_buf : Buffer or BufferRegion
        The B operand in shared memory.
    """
    atom_m, atom_n, _, _, enable_2cta = self.tcgen05_meta_unpacked
    n_dim = self.block_col_warps * self.warp_col_tiles
    n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim
    k_dim = self.chunk
    micro_size_k = self.micro_size_k
    elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8
    b_is_k_major = self.b_transposed

    b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
    b_swizzle_atom_elems = (
        n_dim_per_cta if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
    )

    b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim_per_cta * elems_in_bytes)
    b_stride_byte_offset = (
        (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim_per_cta == 8 else (8 * 8 * elems_in_bytes))
    )
    if not b_swizzle_mode.is_none():
        if b_is_k_major:
            b_leading_byte_offset = 16
            b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
        else:
            b_n_axis_atoms = n_dim_per_cta // b_swizzle_atom_elems
            if b_n_axis_atoms <= 1:
                b_leading_byte_offset = 0
            else:
                b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
            if b_n_axis_atoms <= 1:
                b_stride_byte_offset = 8 * elems_in_bytes * n_dim_per_cta
            else:
                b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems

    return TCGEN05DescriptorParams(
        swizzle_mode=int(b_swizzle_mode),
        leading_byte_offset=int(b_leading_byte_offset >> 4),
        stride_byte_offset=int(b_stride_byte_offset >> 4),
        swizzle_atom_elems=b_swizzle_atom_elems,
        k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=elems_in_bytes,
        is_k_major=b_is_k_major,
    )

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use self.b_dtype when computing B descriptor byte widths.

Line 617 derives elems_in_bytes from self.a_dtype, but every offset in this helper is for operand B. If A and B dtypes diverge, the B descriptor offsets and k_atom_size are wrong.

Suggested fix
-        elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8
+        elems_in_bytes = (DataType(self.b_dtype).bits + 7) // 8
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/tcgen05_macro_generator.py` around lines 600 - 652, In
compute_tcgen05_b_desc_params replace the derivation of elems_in_bytes
(currently using DataType(self.a_dtype)) to use B's data type
(DataType(self.b_dtype)) so all B-related byte offsets and the computed
k_atom_size (which depends on swizzle_atom_elems and micro_size_k) are correct
when A and B dtypes differ; update the variable computation and any dependent
uses (elems_in_bytes, b_leading_byte_offset, b_stride_byte_offset, and
k_atom_size) accordingly in the compute_tcgen05_b_desc_params method.

Comment on lines +315 to +360
def compute_wgmma_b_desc_params(self, B_region: BufferRegion) -> WGMMADescriptorParams:
    """Compute B descriptor parameters from the B shared buffer region.

    This is a pure-Python helper -- no TIR code is emitted.
    The returned ``WGMMADescriptorParams`` is passed to
    ``init_wgmma_b_desc()`` and ``wgmma_*_atom()`` methods.
    """
    n_dim = self.block_col_warps * self.warp_col_tiles
    k_dim = self.chunk
    micro_size_k = self.micro_size_k
    elems_in_bytes = DataType(self.a_dtype).bits // 8
    b_is_k_major = self.b_transposed

    a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
    b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
    b_swizzle_atom_elems = (
        n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
    )

    elems_in_bits = DataType(self.a_dtype).bits
    elems_in_bytes = elems_in_bits // 8
    b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
    b_stride_byte_offset = (
        (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
    )
    if not b_swizzle_mode.is_none():
        if b_is_k_major:
            b_leading_byte_offset = 16
            b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
        else:
            b_n_axis_atoms = n_dim // b_swizzle_atom_elems
            if b_n_axis_atoms <= 1:
                b_leading_byte_offset = 0
            else:
                b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
            if b_n_axis_atoms <= 1:
                b_stride_byte_offset = 8 * elems_in_bytes * n_dim
            else:
                b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems

    return WGMMADescriptorParams(
        swizzle_mode=int(b_swizzle_mode),
        leading_byte_offset=int(b_leading_byte_offset >> 4),
        stride_byte_offset=int(b_stride_byte_offset >> 4),
        swizzle_atom_elems=b_swizzle_atom_elems,
        k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=elems_in_bytes,
        is_k_major=b_is_k_major,
    )

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

B descriptor math is using A’s element width.

Line 325 sizes B swizzle atoms and byte offsets from self.a_dtype. Since this emitter accepts distinct a_dtype / b_dtype, mixed-input configurations will miscompute the B descriptor.

Suggested fix
-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        elems_in_bytes = DataType(self.b_dtype).bits // 8
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def compute_wgmma_b_desc_params(self, B_region: BufferRegion) -> WGMMADescriptorParams:
    """Compute B descriptor parameters from the B shared buffer region.

    This is a pure-Python helper -- no TIR code is emitted.
    The returned ``WGMMADescriptorParams`` is passed to
    ``init_wgmma_b_desc()`` and ``wgmma_*_atom()`` methods.
    """
    n_dim = self.block_col_warps * self.warp_col_tiles
    k_dim = self.chunk
    micro_size_k = self.micro_size_k
    elems_in_bytes = DataType(self.b_dtype).bits // 8
    b_is_k_major = self.b_transposed

    b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
    b_swizzle_atom_elems = (
        n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
    )

    b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
    b_stride_byte_offset = (
        (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
    )
    if not b_swizzle_mode.is_none():
        if b_is_k_major:
            b_leading_byte_offset = 16
            b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
        else:
            b_n_axis_atoms = n_dim // b_swizzle_atom_elems
            if b_n_axis_atoms <= 1:
                b_leading_byte_offset = 0
            else:
                b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
            if b_n_axis_atoms <= 1:
                b_stride_byte_offset = 8 * elems_in_bytes * n_dim
            else:
                b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems

    return WGMMADescriptorParams(
        swizzle_mode=int(b_swizzle_mode),
        leading_byte_offset=int(b_leading_byte_offset >> 4),
        stride_byte_offset=int(b_stride_byte_offset >> 4),
        swizzle_atom_elems=b_swizzle_atom_elems,
        k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=elems_in_bytes,
        is_k_major=b_is_k_major,
    )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/wgmma_macro_generator.py` around lines 315 - 360,
compute_wgmma_b_desc_params is calculating B swizzle atom sizes and byte offsets
using A's element width (elems_in_bytes derived from self.a_dtype); change it to
use B's data type (DataType(self.b_dtype).bits // 8) so all B-specific
calculations (b_swizzle_atom_elems, b_leading_byte_offset, b_stride_byte_offset,
and any downstream shifts) are computed from B’s element size; update the
variable name or comment to reflect it is B's elems_in_bytes and ensure the
return fields (swizzle_atom_elems, elems_in_bytes) reflect the B dtype.

Comment on lines +362 to +405
def compute_wgmma_a_desc_params(self, A_region: BufferRegion) -> WGMMADescriptorParams:
    """Compute A descriptor parameters from the A shared buffer region (SS variant).

    This is a pure-Python helper -- no TIR code is emitted.
    The returned ``WGMMADescriptorParams`` is passed to
    ``init_wgmma_a_desc()`` and ``wgmma_ss_atom()`` methods.
    """
    m_dim = self.block_row_warps * self.warp_row_tiles
    k_dim = self.chunk
    micro_size_k = self.micro_size_k
    elems_in_bytes = DataType(self.a_dtype).bits // 8
    a_is_k_major = not self.a_transposed

    a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
    a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
    b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
    accum_bits = DataType(accum_dtype).bits
    accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32

    # by default, we utilize non-swizzle layout offset
    a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
    a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)

    if not a_swizzle_mode.is_none():
        # swizzle mode doesn't require LBO/SBO to be 1
        # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
        if a_is_k_major:
            a_leading_byte_offset = 16
            a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
        else:
            # MN Major
            # LBO represents the distance between two atoms along the M dimension
            # SBO represents the distance between two atoms along the K dimension
            a_m_axis_atoms = m_dim // a_swizzle_atom_elems
            if a_m_axis_atoms <= 1:
                a_leading_byte_offset = 0
            else:
                a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)

            a_leading_byte_offset = (
                8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
            )
            if a_m_axis_atoms <= 1:
                a_stride_byte_offset = 8 * elems_in_bytes * m_dim
            else:
                a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems

    b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
    b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
    if not b_swizzle_mode.is_none():
        # swizzle mode doesn't require LBO/SBO to be 1
        # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
        if b_is_k_major:
            b_leading_byte_offset = 16
            b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
        else:
            # MN Major, K * N
            # LBO represents the distance between two atoms along the N dimension
            # SBO represents the distance between two atoms along the K dimension
            b_n_axis_atoms = n_dim // b_swizzle_atom_elems
            if b_n_axis_atoms <= 1:
                b_leading_byte_offset = 0
            else:
                b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
            if b_n_axis_atoms <= 1:
                b_stride_byte_offset = 8 * elems_in_bytes * n_dim
            else:
                b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
    return WGMMADescriptorParams(
        swizzle_mode=int(a_swizzle_mode),
        leading_byte_offset=int(a_leading_byte_offset >> 4),
        stride_byte_offset=int(a_stride_byte_offset >> 4),
        swizzle_atom_elems=a_swizzle_atom_elems,
        k_atom_size=max(a_swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=elems_in_bytes,
        is_k_major=a_is_k_major,
    )

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Linear A layouts currently produce swizzle_atom_elems == 0.

Line 376 uses SwizzleMode.NONE.swizzle_byte_size() // elems_in_bytes, which is 0 for fp16/bf16. That zero is then fed into k_atom_size and the A offset math consumed by wgmma_ss_atom(), so the non-swizzled SS path misaddresses shared memory. Please give the NONE case a non-zero linear-layout atom size.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/intrinsics/wgmma_macro_generator.py` around lines 362 - 405,
compute_wgmma_a_desc_params currently computes a_swizzle_atom_elems as
a_swizzle_mode.swizzle_byte_size() // elems_in_bytes which yields 0 for
SwizzleMode.NONE (fp16/bf16) and breaks downstream math; change the logic in
compute_wgmma_a_desc_params so that when a_swizzle_mode.is_none() (or when
swizzle_byte_size() // elems_in_bytes == 0) you force a_swizzle_atom_elems to a
non-zero value (e.g., 1) before it's used to compute a_m_axis_atoms,
k_atom_size, and the returned WGMMADescriptorParams; ensure the final returned
swizzle_atom_elems and k_atom_size use that clamped value so the non-swizzled SS
path addresses shared memory correctly.

@Rachmanino (Collaborator, Author) commented:

@regression-perf

@coderabbitai Bot left a comment (Contributor)

Actionable comments posted: 3

🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atom_mma.py (1)

309-327: ⚡ Quick win

test_tcgen05_atom_gemm inlines its runner and has weaker source assertions than the WGMMA test

Two consistency gaps vs. the WGMMA path:

  1. No _run_tcgen05_atom helper: MMA and WGMMA both separate kernel construction (make_*) from execution (_run_*) and then call _run_* from the test_* function. TCGEN05 conflates all three in one function, making it harder to add future parametrized calls (e.g., different M/N/K) without duplication.

  2. Single source assertion ("tcgen05.mma" in src): The WGMMA test asserts four distinct codegen artifacts (wgmma_ss, initialize_wgmma_descriptor, warpgroup_arrive, warpgroup_commit_batch). TCGEN05 would benefit from similar assertions (e.g., presence of descriptor-init and barrier-arrive helpers).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@testing/python/language/test_tilelang_language_atom_mma.py` around lines 309
- 327, test_tcgen05_atom_gemm currently builds, runs, and asserts source inline;
extract the execution part into a new helper _run_tcgen05_atom that takes
(M,N,K,in_dtype,out_dtype,compute_dtype) and calls make_tcgen05_atom_kernel then
runs the kernel (mirroring the MMA/WGMMA pattern), and update
test_tcgen05_atom_gemm to call that helper; additionally strengthen the source
assertions by checking for multiple TCGEN05 codegen artifacts (e.g., include
checks for "tcgen05.mma", the descriptor initialization helper name, and
warpgroup barrier helpers such as "warpgroup_arrive" and
"warpgroup_commit_batch") so the test validates descriptor-init and
barrier-arrive/commit codegen similarly to the WGMMA test.
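
A sketch of the refactor suggested in the nitpick above. The kernel-builder name and runner signature follow the nitpick; the source-retrieval call, tensor shapes, kernel invocation convention, descriptor/barrier assertion placeholders, and tolerances are assumptions for illustration, not the code shipped in the PR:

    import torch

    def _run_tcgen05_atom(M, N, K, in_dtype, out_dtype, compute_dtype):
        kernel = make_tcgen05_atom_kernel(M, N, K, in_dtype, out_dtype, compute_dtype)
        src = kernel.get_kernel_source()
        # mirror the WGMMA test: assert several codegen artifacts, not just one
        assert "tcgen05.mma" in src
        # also assert the descriptor-init and barrier/arrive helper names here
        # (exact strings depend on the TCGEN05 emitter's generated source)
        a = torch.randn(M, K, device="cuda", dtype=getattr(torch, in_dtype))
        b = torch.randn(K, N, device="cuda", dtype=getattr(torch, in_dtype))
        c = kernel(a, b)
        ref = (a.float() @ b.float()).to(c.dtype)
        torch.testing.assert_close(c, ref, rtol=1e-2, atol=1e-2)

    def test_tcgen05_atom_gemm():
        _run_tcgen05_atom(256, 256, 64, "float16", "float16", "float32")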
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@testing/python/language/test_tilelang_language_atom_mma.py`:
- Around line 141-142: The computation of warp_row_tiles and warp_col_tiles uses
integer division that silently truncates remainders; add assertions before those
assignments to ensure M % block_row_warps == 0 and N % block_col_warps == 0
(with clear messages referencing M, block_row_warps and N, block_col_warps) so
misconfigured callers fail loud and the derived warp_row_tiles and
warp_col_tiles values are guaranteed correct.
- Around line 40-42: The section header comment above the SM80+ tests
incorrectly says "Hopper" — update that comment to say "Ampere" so it matches
the decorator requires_cuda_compute_version_ge(8, 0) and the SM80 architecture;
locate the SM80+ MMA (atom-level) header comment and replace "Hopper" with
"Ampere" to keep the comments consistent with the
requires_cuda_compute_version_ge(8, 0) decorator.
- Around line 195-198: C_local is allocated via T.alloc_fragment but never
initialized before calls to emi.wgmma_ss_atom with T.bool(True), which if
representing scale_d=1 will read uninitialized accumulators; fix by either
inserting T.clear(C_local) before the nested loops that call emi.wgmma_ss_atom
(or immediately after allocation) so accumulation starts from zero, or if
T.bool(True) is instead intended as an "is-first-k" reset flag make the flag
conditional (e.g., pass True only when ki==0 and m==0 and n==0) and leave
C_local unmodified otherwise; locate uses of C_local, emi.wgmma_ss_atom, and any
T.bool(True) in the loop to implement the appropriate change.

---

Nitpick comments:
In `@testing/python/language/test_tilelang_language_atom_mma.py`:
- Around line 309-327: test_tcgen05_atom_gemm currently builds, runs, and
asserts source inline; extract the execution part into a new helper
_run_tcgen05_atom that takes (M,N,K,in_dtype,out_dtype,compute_dtype) and calls
make_tcgen05_atom_kernel then runs the kernel (mirroring the MMA/WGMMA pattern),
and update test_tcgen05_atom_gemm to call that helper; additionally strengthen
the source assertions by checking for multiple TCGEN05 codegen artifacts (e.g.,
include checks for "tcgen05.mma", the descriptor initialization helper name, and
warpgroup barrier helpers such as "warpgroup_arrive" and
"warpgroup_commit_batch") so the test validates descriptor-init and
barrier-arrive/commit codegen similarly to the WGMMA test.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ad873b00-e695-4146-b782-40d46fd86569

📥 Commits

Reviewing files that changed from the base of the PR and between e80bfd7 and 2b882fa.

📒 Files selected for processing (1)
  • testing/python/language/test_tilelang_language_atom_mma.py

Comment thread: testing/python/language/test_tilelang_language_atom_mma.py
Comment on lines +141 to +142
warp_row_tiles = M // block_row_warps
warp_col_tiles = N // block_col_warps

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Silent truncation if M or N are not divisible by block_row_warps/block_col_warps

warp_row_tiles = M // block_row_warps and warp_col_tiles = N // block_col_warps silently drop any remainder. A misconfigured caller would get wrong tile counts and incorrect results without any diagnostic. An assertion at this point keeps the failure loud:

✏️ Proposed fix
+    assert M % block_row_warps == 0, f"M={M} not divisible by block_row_warps={block_row_warps}"
+    assert N % block_col_warps == 0, f"N={N} not divisible by block_col_warps={block_col_warps}"
     warp_row_tiles = M // block_row_warps
     warp_col_tiles = N // block_col_warps
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    assert M % block_row_warps == 0, f"M={M} not divisible by block_row_warps={block_row_warps}"
    assert N % block_col_warps == 0, f"N={N} not divisible by block_col_warps={block_col_warps}"
    warp_row_tiles = M // block_row_warps
    warp_col_tiles = N // block_col_warps
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@testing/python/language/test_tilelang_language_atom_mma.py` around lines 141
- 142, The computation of warp_row_tiles and warp_col_tiles uses integer
division that silently truncates remainders; add assertions before those
assignments to ensure M % block_row_warps == 0 and N % block_col_warps == 0
(with clear messages referencing M, block_row_warps and N, block_col_warps) so
misconfigured callers fail loud and the derived warp_row_tiles and
warp_col_tiles values are guaranteed correct.

Comment on lines +195 to +198
for n in T.unroll(num_inst_n):
    for m in T.unroll(num_inst_m):
        for ki in T.unroll(num_k_atoms):
            emi.wgmma_ss_atom(desc_a, desc_b, C_local, m, n, ki, a_params, b_params, T.bool(True))

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Accumulator never initialized before wgmma_ss_atom — potential correctness bug

C_local is allocated with T.alloc_fragment (line 172) but never cleared. wgmma_fence_c is a fence/barrier primitive, not a zero-initializer. If T.bool(True) maps to WGMMA PTX scale_d=1 (accumulate), the very first call reads uninitialized memory and adds A*B to it — producing garbage.

In contrast, the MMA path correctly calls T.clear(C_local) (line 96) before any atom.

Confirm whether T.bool(True) here means:

  • scale_d=1 (accumulate-into-D) → add T.clear(C_local) before the loop, or
  • "is-first-k" reset flag (True → init D to zero) → the flag should be conditional on the first iteration (e.g., ki == 0 and m == 0 and n == 0), not always-True.
🛡️ Fix if `T.bool(True)` means `scale_d=1`
+            T.clear(C_local)
+
             emi.wgmma_fence_c(C_local)
             emi.wgmma_arrive()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@testing/python/language/test_tilelang_language_atom_mma.py` around lines 195
- 198, C_local is allocated via T.alloc_fragment but never initialized before
calls to emi.wgmma_ss_atom with T.bool(True), which if representing scale_d=1
will read uninitialized accumulators; fix by either inserting T.clear(C_local)
before the nested loops that call emi.wgmma_ss_atom (or immediately after
allocation) so accumulation starts from zero, or if T.bool(True) is instead
intended as an "is-first-k" reset flag make the flag conditional (e.g., pass
True only when ki==0 and m==0 and n==0) and leave C_local unmodified otherwise;
locate uses of C_local, emi.wgmma_ss_atom, and any T.bool(True) in the loop to
implement the appropriate change.

Rachmanino and others added 5 commits May 8, 2026 20:48
…TCGEN05TensorCoreIntrinEmitter. Introduce descriptor parameter classes for WGMMA and TCGEN05, improving the initialization and atom offset computation processes. Update the TensorCoreIntrinEmitter to support atom-level interfaces and macro definitions for efficient matrix multiplication operations.
@Rachmanino changed the title from "[WIP][Refactor] Refactor multiple TensorCoreIntrinEmitter to provide atom-level mma control interface" to "[Refactor] Refactor multiple TensorCoreIntrinEmitter to provide atom-level mma control interface" on May 8, 2026
@coderabbitai Bot left a comment (Contributor)

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py (1)

318-354: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Clamp non-swizzled A atom size before using it in offset math.

Both SS paths can produce a_swizzle_atom_elems == 0 for linear fp16/bf16 layouts. That value then feeds directly into the A-offset formulas, collapsing distinct atoms onto the same shared-memory region.

Suggested fix
-        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        a_swizzle_atom_elems = max(a_swizzle_mode.swizzle_byte_size() // elems_in_bytes, 1)

Also applies to: 680-708

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py` around lines 318 -
354, The bug is that a_swizzle_atom_elems can be 0 for non-swizzled linear
fp16/bf16 layouts and is used in offset math; clamp it to at least 1 before any
use. Update the block computing a_swizzle_atom_elems in
tcgen05_macro_generator.py to set a_swizzle_atom_elems =
max(a_swizzle_atom_elems, 1) (or equivalent) before computing a_m_axis_atoms,
a_leading_byte_offset/a_stride_byte_offset, and before passing
swizzle_atom_elems/k_atom_size into TCGEN05DescriptorParams (also apply the same
clamp in the other occurrence around lines 680-708 that constructs A params).
Ensure k_atom_size still uses the clamped value when doing
max(a_swizzle_atom_elems // micro_size_k, 1).
♻️ Duplicate comments (4)
tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py (2)

217-222: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Initialize TCGEN05 meta before using the atom-count helpers.

Both entrypoints now read tcgen05_num_inst_* / compute_tcgen05_instr_desc() before any lazy meta bootstrap. On a fresh emitter that reaches tcgen05_meta_unpacked before self.meta exists, so tcgen05mma_ss() and tcgen05mma_ts() can fail before emitting TIR.

Suggested fix
+        m_dim = self.block_row_warps * self.warp_row_tiles
+        n_dim = self.block_col_warps * self.warp_col_tiles
+        if not hasattr(self, "meta") or len(self.meta) != 5:
+            self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim, disable_2cta=False)
+
         num_inst_m = self.tcgen05_num_inst_m
         num_inst_n = self.tcgen05_num_inst_n
         num_k_atoms = self.tcgen05_num_k_atoms

Also applies to: 264-268

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py` around lines 217 -
222, The code reads tcgen05_num_inst_m, tcgen05_num_inst_n and calls
compute_tcgen05_instr_desc() before ensuring the TCGEN05 meta is unpacked; call
tcgen05_meta_unpacked() (or otherwise initialize/unpack self.meta) at the start
of the routine before accessing tcgen05_* fields or calling
compute_tcgen05_instr_desc(), so that helpers like tcgen05mma_ss() and
tcgen05mma_ts() won't run against an uninitialized self.meta.

626-662: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use self.b_dtype when computing B descriptor byte widths.

This helper still derives B’s element width from self.a_dtype, so mixed-input kernels will generate the wrong B descriptor offsets.

Suggested fix
-        elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8
+        elems_in_bytes = (DataType(self.b_dtype).bits + 7) // 8
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py` around lines 626 -
662, The B-descriptor computation incorrectly uses self.a_dtype to compute
elems_in_bytes, causing wrong byte offsets for mixed-input kernels; update the
computation to use self.b_dtype (i.e., replace DataType(self.a_dtype) with
DataType(self.b_dtype)) so b_leading_byte_offset, b_stride_byte_offset,
b_swizzle_atom_elems, and any derived values (returned in
TCGEN05DescriptorParams) are based on B's element size.
tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py (2)

371-399: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Clamp linear A swizzle_atom_elems to a non-zero value.

For SwizzleMode.NONE, swizzle_byte_size() // elems_in_bytes becomes 0 for fp16/bf16, so the A-offset math in wgmma_ss_atom() collapses and multiple atoms read the same shared-memory slice.

Suggested fix
-        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        a_swizzle_atom_elems = max(a_swizzle_mode.swizzle_byte_size() // elems_in_bytes, 1)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py` around lines 371 -
399, The a_swizzle_atom_elems calculation can be zero for SwizzleMode.NONE
(a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) which breaks downstream
A-offset math (e.g. wgmma_ss_atom); clamp a_swizzle_atom_elems to a minimum of 1
after it is computed. Locate the variables a_swizzle_mode and
a_swizzle_atom_elems in the function that returns WGMMADescriptorParams and
replace the raw division result with a max(..., 1) or otherwise set
a_swizzle_atom_elems = max(a_swizzle_atom_elems, 1) before using it in the
subsequent offset logic and in the returned WGMMADescriptorParams.

325-355: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use self.b_dtype for B descriptor byte math.

compute_wgmma_b_desc_params() is still sizing B’s swizzle atoms and byte offsets from self.a_dtype. Mixed-input kernels will build the B descriptor with the wrong element width and misaddress shared memory.

Suggested fix
-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        elems_in_bytes = DataType(self.b_dtype).bits // 8
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py` around lines 325 -
355, compute_wgmma_b_desc_params() currently computes B's element byte size and
all byte-offset math using DataType(self.a_dtype); change it to use
DataType(self.b_dtype) so elems_in_bytes and any calculations derived from it
(b_swizzle_atom_elems, b_leading_byte_offset, b_stride_byte_offset, and
k_atom_size) are based on B's actual dtype. Update the assignment of
elems_in_bytes and any expressions that multiply/divide by elems_in_bytes to
reference the new value so the returned WGMMADescriptorParams fields
(swizzle_atom_elems, leading_byte_offset, stride_byte_offset, k_atom_size,
elems_in_bytes) are correct for mixed-input kernels.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e6569594-ad75-493c-825a-0c72a25c7d1f

📥 Commits

Reviewing files that changed from the base of the PR and between fac9b5c and c1d00a4.

📒 Files selected for processing (5)
  • testing/python/language/test_tilelang_language_atom_mma.py
  • tilelang/cuda/intrinsics/macro/mma_macro_generator.py
  • tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
  • tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py
  • tilelang/intrinsics/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tilelang/intrinsics/__init__.py
  • testing/python/language/test_tilelang_language_atom_mma.py

@Rachmanino requested a review from LeiWang1999 on May 9, 2026 at 04:26
@LeiWang1999 merged commit 1fbd994 into tile-ai:main on May 9, 2026
6 of 7 checks passed