[Refactor] Refactor multiple TensorCoreIntrinEmitter to provide atom-level mma control interface#2161
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the repository's CodeRabbit settings.
📝 Walkthrough

This PR refactors tensor-core macro generators (MMA, TCGEN05, WGMMA) to separate descriptor parameter computation and initialization from atom-level instruction emission. It introduces frozen dataclasses for descriptor parameters, pure-Python helpers for offset/swizzle calculation, and reusable atom-level emitters; top-level macros now orchestrate descriptor setup once and delegate per-atom PTX emission. A comprehensive test module validates three GPU GEMM codegen paths (SM80+ MMA, SM90 WGMMA, SM100 TCGEN05).

Changes: Tensor Core Macro Refactoring
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/intrinsics/tcgen05_macro_generator.py (1)
332-357: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win: Non-swizzled A descriptors collapse to zero-width atoms

These paths compute `swizzle_byte_size() // elems_in_bytes` for A even when the mode is `SwizzleMode.NONE`. For fp16/bf16 that becomes `1 // 2 == 0`, which then propagates into `swizzle_atom_elems`, `k_atom_size`, and the A offset formulas used by the atom emitters. The default linear-layout A path will therefore generate bogus descriptor offsets. Please special-case the `NONE` path to a non-zero linear-layout atom size before constructing `TCGEN05DescriptorParams`.

Also applies to: 654-700
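A standalone sketch of the arithmetic behind this finding and the suggested clamp; the helper names are illustrative, not TileLang's API, and it assumes `SwizzleMode.NONE` reports a 1-byte swizzle size, per the `1 // 2 == 0` example above.

```python
def atom_elems(swizzle_byte_size: int, elem_bytes: int) -> int:
    # Current derivation: collapses to zero for SwizzleMode.NONE,
    # whose byte size (1) is smaller than an fp16/bf16 element (2).
    return swizzle_byte_size // elem_bytes

def atom_elems_clamped(swizzle_byte_size: int, elem_bytes: int) -> int:
    # Suggested guard: a linear (non-swizzled) layout never yields a
    # zero-width atom, so k_atom_size and offset math stay non-degenerate.
    return max(1, swizzle_byte_size // elem_bytes)

# fp16/bf16: 2-byte elements, NONE swizzle reports 1 byte
assert atom_elems(1, 2) == 0           # the bug: zero-width atom
assert atom_elems_clamped(1, 2) == 1   # clamped linear-layout atom
# a 128B swizzle with fp16 elements is unaffected by the clamp
assert atom_elems_clamped(128, 2) == 64
```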
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/intrinsics/tcgen05_macro_generator.py` around lines 332 - 357, The A descriptor path can produce zero-width atoms when a_swizzle_mode.is_none() because swizzle_byte_size() // elems_in_bytes can be 0 for types like fp16/bf16; change the computation of a_swizzle_atom_elems to ensure a non-zero linear-layout atom size (e.g. set a_swizzle_atom_elems = 1 or max(1, a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) when a_swizzle_mode.is_none()) before using it to compute a_leading_byte_offset, a_stride_byte_offset, k_atom_size and constructing TCGEN05DescriptorParams (and do the equivalent fix for the B-path where b_swizzle_atom_elems is computed); apply the same guard in the other occurrence referenced (lines ~654-700) so non-swizzled descriptors never collapse to zero-width atoms.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/intrinsics/tcgen05_macro_generator.py`:
- Around line 217-222: The code reads tcgen05_num_inst_m / tcgen05_num_inst_n /
compute_tcgen05_instr_desc() before ensuring TCGEN05 meta is initialized, which
can access tcgen05_meta_unpacked before self.meta exists; wrap the early uses
with the same lazy initialization used by tcgen05mma_blockscaled() by calling
get_tcgen5_mma_meta(...) (or the equivalent lazy guard) before any access to
tcgen05_num_inst_* or compute_tcgen05_instr_desc(), and apply the same guard to
the second occurrence around the block that currently calls those helpers at the
later location (the block around compute_tcgen05_instr_desc() at the other
site).
- Around line 600-652: In compute_tcgen05_b_desc_params replace the derivation
of elems_in_bytes (currently using DataType(self.a_dtype)) to use B's data type
(DataType(self.b_dtype)) so all B-related byte offsets and the computed
k_atom_size (which depends on swizzle_atom_elems and micro_size_k) are correct
when A and B dtypes differ; update the variable computation and any dependent
uses (elems_in_bytes, b_leading_byte_offset, b_stride_byte_offset, and
k_atom_size) accordingly in the compute_tcgen05_b_desc_params method.
In `@tilelang/intrinsics/wgmma_macro_generator.py`:
- Around line 362-405: compute_wgmma_a_desc_params currently computes
a_swizzle_atom_elems as a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
which yields 0 for SwizzleMode.NONE (fp16/bf16) and breaks downstream math;
change the logic in compute_wgmma_a_desc_params so that when
a_swizzle_mode.is_none() (or when swizzle_byte_size() // elems_in_bytes == 0)
you force a_swizzle_atom_elems to a non-zero value (e.g., 1) before it's used to
compute a_m_axis_atoms, k_atom_size, and the returned WGMMADescriptorParams;
ensure the final returned swizzle_atom_elems and k_atom_size use that clamped
value so the non-swizzled SS path addresses shared memory correctly.
- Around line 315-360: compute_wgmma_b_desc_params is calculating B swizzle atom
sizes and byte offsets using A's element width (elems_in_bytes derived from
self.a_dtype); change it to use B's data type (DataType(self.b_dtype).bits // 8)
so all B-specific calculations (b_swizzle_atom_elems, b_leading_byte_offset,
b_stride_byte_offset, and any downstream shifts) are computed from B’s element
size; update the variable name or comment to reflect it is B's elems_in_bytes
and ensure the return fields (swizzle_atom_elems, elems_in_bytes) reflect the B
dtype.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6425585c-b2a5-442e-8c19-31977d79ac0e
📒 Files selected for processing (4)
- tilelang/intrinsics/__init__.py
- tilelang/intrinsics/mma_macro_generator.py
- tilelang/intrinsics/tcgen05_macro_generator.py
- tilelang/intrinsics/wgmma_macro_generator.py
```python
num_inst_m = self.tcgen05_num_inst_m
num_inst_n = self.tcgen05_num_inst_n
num_k_atoms = self.tcgen05_num_k_atoms
a_params = self.compute_tcgen05_a_desc_params(A_buf)
b_params = self.compute_tcgen05_b_desc_params(B_buf)
instr_desc = self.compute_tcgen05_instr_desc()
```
Initialize TCGEN05 meta before using the new atom-count helpers.
Lines 217-222 and 264-268 now read tcgen05_num_inst_* / compute_tcgen05_instr_desc() before any local meta bootstrap. On a fresh emitter, that reaches tcgen05_meta_unpacked before self.meta exists, so tcgen05mma_ss() and tcgen05mma_ts() can fail before emitting any TIR. Please add the same lazy get_tcgen5_mma_meta(...) guard that tcgen05mma_blockscaled() already has.
Also applies to: 264-268
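The lazy guard this comment asks for can be sketched as follows; `Emitter`, `_compute_meta`, and the meta contents are toy stand-ins for `get_tcgen5_mma_meta(...)`, not the real TileLang classes.

```python
class Emitter:
    """Toy stand-in for the TCGEN05 macro emitter."""

    def __init__(self):
        self.meta = None  # fresh emitters start without TCGEN05 meta

    def _ensure_meta(self):
        # Mirrors the lazy bootstrap tcgen05mma_blockscaled() already does
        # before touching meta-derived helpers.
        if self.meta is None:
            self.meta = self._compute_meta()

    def _compute_meta(self):
        # Placeholder values; the real helper derives these from tile shapes.
        return {"num_inst_m": 2, "num_inst_n": 4}

    @property
    def num_inst_m(self):
        self._ensure_meta()  # guard BEFORE unpacking meta
        return self.meta["num_inst_m"]

emitter = Emitter()
assert emitter.num_inst_m == 2  # no error on a fresh emitter
```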
```python
def compute_tcgen05_b_desc_params(self, B_buf) -> TCGEN05DescriptorParams:
    """Compute B descriptor parameters from the B shared buffer.

    This is a pure-Python helper -- no TIR code is emitted.
    The returned ``TCGEN05DescriptorParams`` is passed to
    ``init_tcgen05_b_desc()`` and ``tcgen05_*_atom()`` methods.

    Parameters
    ----------
    B_buf : Buffer or BufferRegion
        The B operand in shared memory.
    """
    atom_m, atom_n, _, _, enable_2cta = self.tcgen05_meta_unpacked
    n_dim = self.block_col_warps * self.warp_col_tiles
    n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim
    k_dim = self.chunk
    micro_size_k = self.micro_size_k
    elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8
    b_is_k_major = self.b_transposed

    b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
    b_swizzle_atom_elems = (
        n_dim_per_cta if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
    )

    b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim_per_cta * elems_in_bytes)
    b_stride_byte_offset = (
        (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim_per_cta == 8 else (8 * 8 * elems_in_bytes))
    )
    if not b_swizzle_mode.is_none():
        if b_is_k_major:
            b_leading_byte_offset = 16
            b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
        else:
            b_n_axis_atoms = n_dim_per_cta // b_swizzle_atom_elems
            if b_n_axis_atoms <= 1:
                b_leading_byte_offset = 0
            else:
                b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
            if b_n_axis_atoms <= 1:
                b_stride_byte_offset = 8 * elems_in_bytes * n_dim_per_cta
            else:
                b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems

    return TCGEN05DescriptorParams(
        swizzle_mode=int(b_swizzle_mode),
        leading_byte_offset=int(b_leading_byte_offset >> 4),
        stride_byte_offset=int(b_stride_byte_offset >> 4),
        swizzle_atom_elems=b_swizzle_atom_elems,
        k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=elems_in_bytes,
        is_k_major=b_is_k_major,
    )
```
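One detail worth noting in this return path: `leading_byte_offset` and `stride_byte_offset` are stored shifted right by 4, i.e. in 16-byte units. A minimal sketch of that encoding (a reading of the code above; the alignment assert is my own illustrative assumption):

```python
def encode_offset(byte_offset: int) -> int:
    # Descriptor offsets are stored in 16-byte units, matching the
    # `>> 4` in the TCGEN05DescriptorParams construction above.
    assert byte_offset % 16 == 0, "offset must be 16-byte aligned"
    return byte_offset >> 4

# K-major fp16 leading offset: 8 * 8 elements * 2 bytes = 128 bytes -> 8 units
assert encode_offset(8 * 8 * 2) == 8
```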
Use `self.b_dtype` when computing B descriptor byte widths.

Line 617 derives `elems_in_bytes` from `self.a_dtype`, but every offset in this helper is for operand B. If A and B dtypes diverge, the B descriptor offsets and `k_atom_size` are wrong.
Suggested fix:

```diff
- elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8
+ elems_in_bytes = (DataType(self.b_dtype).bits + 7) // 8
```
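To see why deriving `elems_in_bytes` from A matters, a small numeric sketch (the dtype widths are illustrative; an 8-bit A paired with a 16-bit B is assumed purely for contrast):

```python
def elem_bytes(bits: int) -> int:
    # Same rounding the helper uses: (bits + 7) // 8
    return (bits + 7) // 8

a_bits, b_bits = 8, 16  # hypothetical mixed-input configuration

# K-major leading byte offset for an 8x8 B tile:
wrong = 8 * 8 * elem_bytes(a_bits)  # sized from A: 64 bytes
right = 8 * 8 * elem_bytes(b_bits)  # sized from B: 128 bytes
assert (wrong, right) == (64, 128)  # half the true stride -> bogus descriptor
```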
```python
def compute_wgmma_b_desc_params(self, B_region: BufferRegion) -> WGMMADescriptorParams:
    """Compute B descriptor parameters from the B shared buffer region.

    This is a pure-Python helper -- no TIR code is emitted.
    The returned ``WGMMADescriptorParams`` is passed to
    ``init_wgmma_b_desc()`` and ``wgmma_*_atom()`` methods.
    """
    n_dim = self.block_col_warps * self.warp_col_tiles
    k_dim = self.chunk
    micro_size_k = self.micro_size_k
    elems_in_bytes = DataType(self.a_dtype).bits // 8
    b_is_k_major = self.b_transposed

    a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
    b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
    b_swizzle_atom_elems = (
        n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
    )

    elems_in_bits = DataType(self.a_dtype).bits
    elems_in_bytes = elems_in_bits // 8
    b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
    b_stride_byte_offset = (
        (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
    )
    if not b_swizzle_mode.is_none():
        if b_is_k_major:
            b_leading_byte_offset = 16
            b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
        else:
            b_n_axis_atoms = n_dim // b_swizzle_atom_elems
            if b_n_axis_atoms <= 1:
                b_leading_byte_offset = 0
            else:
                b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
            if b_n_axis_atoms <= 1:
                b_stride_byte_offset = 8 * elems_in_bytes * n_dim
            else:
                b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems

    return WGMMADescriptorParams(
        swizzle_mode=int(b_swizzle_mode),
        leading_byte_offset=int(b_leading_byte_offset >> 4),
        stride_byte_offset=int(b_stride_byte_offset >> 4),
        swizzle_atom_elems=b_swizzle_atom_elems,
        k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=elems_in_bytes,
        is_k_major=b_is_k_major,
    )
```
B descriptor math is using A's element width.

Line 325 sizes B swizzle atoms and byte offsets from `self.a_dtype`. Since this emitter accepts distinct `a_dtype` / `b_dtype`, mixed-input configurations will miscompute the B descriptor.
Suggested fix:

```diff
- elems_in_bytes = DataType(self.a_dtype).bits // 8
+ elems_in_bytes = DataType(self.b_dtype).bits // 8
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def compute_wgmma_b_desc_params(self, B_region: BufferRegion) -> WGMMADescriptorParams:
    """Compute B descriptor parameters from the B shared buffer region.

    This is a pure-Python helper -- no TIR code is emitted.
    The returned ``WGMMADescriptorParams`` is passed to
    ``init_wgmma_b_desc()`` and ``wgmma_*_atom()`` methods.
    """
    n_dim = self.block_col_warps * self.warp_col_tiles
    k_dim = self.chunk
    micro_size_k = self.micro_size_k
    elems_in_bytes = DataType(self.b_dtype).bits // 8
    b_is_k_major = self.b_transposed

    b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
    b_swizzle_atom_elems = (
        n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
    )

    b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
    b_stride_byte_offset = (
        (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
    )
    if not b_swizzle_mode.is_none():
        if b_is_k_major:
            b_leading_byte_offset = 16
            b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
        else:
            b_n_axis_atoms = n_dim // b_swizzle_atom_elems
            if b_n_axis_atoms <= 1:
                b_leading_byte_offset = 0
            else:
                b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
            if b_n_axis_atoms <= 1:
                b_stride_byte_offset = 8 * elems_in_bytes * n_dim
            else:
                b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems

    return WGMMADescriptorParams(
        swizzle_mode=int(b_swizzle_mode),
        leading_byte_offset=int(b_leading_byte_offset >> 4),
        stride_byte_offset=int(b_stride_byte_offset >> 4),
        swizzle_atom_elems=b_swizzle_atom_elems,
        k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=elems_in_bytes,
        is_k_major=b_is_k_major,
    )
```
```python
def compute_wgmma_a_desc_params(self, A_region: BufferRegion) -> WGMMADescriptorParams:
    """Compute A descriptor parameters from the A shared buffer region (SS variant).

    This is a pure-Python helper -- no TIR code is emitted.
    The returned ``WGMMADescriptorParams`` is passed to
    ``init_wgmma_a_desc()`` and ``wgmma_ss_atom()`` methods.
    """
    m_dim = self.block_row_warps * self.warp_row_tiles
    k_dim = self.chunk
    micro_size_k = self.micro_size_k
    elems_in_bytes = DataType(self.a_dtype).bits // 8
    a_is_k_major = not self.a_transposed

    a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
    a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes

    # by default, we utilize non-swizzle layout offset
    a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
    a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)

    if not a_swizzle_mode.is_none():
        # swizzle mode doesn't require LBO/SBO to be 1
        # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
        if a_is_k_major:
            a_leading_byte_offset = 16
            a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
        else:
            # MN Major
            # LBO represents the distance between two atoms along the M dimension
            # SBO represents the distance between two atoms along the K dimension
            a_m_axis_atoms = m_dim // a_swizzle_atom_elems
            if a_m_axis_atoms <= 1:
                a_leading_byte_offset = 0
            else:
                a_leading_byte_offset = (
                    8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
                )
            if a_m_axis_atoms <= 1:
                a_stride_byte_offset = 8 * elems_in_bytes * m_dim
            else:
                a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems

    return WGMMADescriptorParams(
        swizzle_mode=int(a_swizzle_mode),
        leading_byte_offset=int(a_leading_byte_offset >> 4),
        stride_byte_offset=int(a_stride_byte_offset >> 4),
        swizzle_atom_elems=a_swizzle_atom_elems,
        k_atom_size=max(a_swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=elems_in_bytes,
        is_k_major=a_is_k_major,
    )
```
Linear A layouts currently produce `swizzle_atom_elems == 0`.

Line 376 uses `SwizzleMode.NONE.swizzle_byte_size() // elems_in_bytes`, which is 0 for fp16/bf16. That zero is then fed into `k_atom_size` and the A offset math consumed by `wgmma_ss_atom()`, so the non-swizzled SS path misaddresses shared memory. Please give the `NONE` case a non-zero linear-layout atom size.
@regression-perf
Actionable comments posted: 3
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atom_mma.py (1)
309-327: ⚡ Quick win: `test_tcgen05_atom_gemm` inlines its runner and has weaker source assertions than the WGMMA test

Two consistency gaps vs. the WGMMA path:

- No `_run_tcgen05_atom` helper: MMA and WGMMA both separate kernel construction (`make_*`) from execution (`_run_*`) and then call `_run_*` from the `test_*` function. TCGEN05 conflates all three in one function, making it harder to add future parametrized calls (e.g., different M/N/K) without duplication.
- Single source assertion (`"tcgen05.mma" in src`): the WGMMA test asserts four distinct codegen artifacts (`wgmma_ss`, `initialize_wgmma_descriptor`, `warpgroup_arrive`, `warpgroup_commit_batch`). TCGEN05 would benefit from similar assertions (e.g., presence of descriptor-init and barrier-arrive helpers).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@testing/python/language/test_tilelang_language_atom_mma.py` around lines 309 - 327, test_tcgen05_atom_gemm currently builds, runs, and asserts source inline; extract the execution part into a new helper _run_tcgen05_atom that takes (M,N,K,in_dtype,out_dtype,compute_dtype) and calls make_tcgen05_atom_kernel then runs the kernel (mirroring the MMA/WGMMA pattern), and update test_tcgen05_atom_gemm to call that helper; additionally strengthen the source assertions by checking for multiple TCGEN05 codegen artifacts (e.g., include checks for "tcgen05.mma", the descriptor initialization helper name, and warpgroup barrier helpers such as "warpgroup_arrive" and "warpgroup_commit_batch") so the test validates descriptor-init and barrier-arrive/commit codegen similarly to the WGMMA test.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@testing/python/language/test_tilelang_language_atom_mma.py`:
- Around line 141-142: The computation of warp_row_tiles and warp_col_tiles uses
integer division that silently truncates remainders; add assertions before those
assignments to ensure M % block_row_warps == 0 and N % block_col_warps == 0
(with clear messages referencing M, block_row_warps and N, block_col_warps) so
misconfigured callers fail loud and the derived warp_row_tiles and
warp_col_tiles values are guaranteed correct.
- Around line 40-42: The section header comment above the SM80+ tests
incorrectly says "Hopper" — update that comment to say "Ampere" so it matches
the decorator requires_cuda_compute_version_ge(8, 0) and the SM80 architecture;
locate the SM80+ MMA (atom-level) header comment and replace "Hopper" with
"Ampere" to keep the comments consistent with the
requires_cuda_compute_version_ge(8, 0) decorator.
- Around line 195-198: C_local is allocated via T.alloc_fragment but never
initialized before calls to emi.wgmma_ss_atom with T.bool(True), which if
representing scale_d=1 will read uninitialized accumulators; fix by either
inserting T.clear(C_local) before the nested loops that call emi.wgmma_ss_atom
(or immediately after allocation) so accumulation starts from zero, or if
T.bool(True) is instead intended as an "is-first-k" reset flag make the flag
conditional (e.g., pass True only when ki==0 and m==0 and n==0) and leave
C_local unmodified otherwise; locate uses of C_local, emi.wgmma_ss_atom, and any
T.bool(True) in the loop to implement the appropriate change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ad873b00-e695-4146-b782-40d46fd86569
📒 Files selected for processing (1)
testing/python/language/test_tilelang_language_atom_mma.py
```python
warp_row_tiles = M // block_row_warps
warp_col_tiles = N // block_col_warps
```
Silent truncation if M or N are not divisible by block_row_warps / block_col_warps

`warp_row_tiles = M // block_row_warps` and `warp_col_tiles = N // block_col_warps` silently drop any remainder. A misconfigured caller would get wrong tile counts and incorrect results without any diagnostic. An assertion at this point keeps the failure loud:
✏️ Proposed fix

```diff
+ assert M % block_row_warps == 0, f"M={M} not divisible by block_row_warps={block_row_warps}"
+ assert N % block_col_warps == 0, f"N={N} not divisible by block_col_warps={block_col_warps}"
  warp_row_tiles = M // block_row_warps
  warp_col_tiles = N // block_col_warps
```
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
assert M % block_row_warps == 0, f"M={M} not divisible by block_row_warps={block_row_warps}"
assert N % block_col_warps == 0, f"N={N} not divisible by block_col_warps={block_col_warps}"
warp_row_tiles = M // block_row_warps
warp_col_tiles = N // block_col_warps
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@testing/python/language/test_tilelang_language_atom_mma.py` around lines 141
- 142, The computation of warp_row_tiles and warp_col_tiles uses integer
division that silently truncates remainders; add assertions before those
assignments to ensure M % block_row_warps == 0 and N % block_col_warps == 0
(with clear messages referencing M, block_row_warps and N, block_col_warps) so
misconfigured callers fail loud and the derived warp_row_tiles and
warp_col_tiles values are guaranteed correct.
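The truncation hazard above is easy to reproduce in plain Python. The helper below is a hypothetical stand-in for the test's tile-splitting code — the names mirror the snippet's variables but are not the real TileLang API:

```python
def split_tiles(M: int, N: int, block_row_warps: int, block_col_warps: int):
    """Derive per-warp tile counts, failing loudly on non-divisible shapes."""
    assert M % block_row_warps == 0, f"M={M} not divisible by block_row_warps={block_row_warps}"
    assert N % block_col_warps == 0, f"N={N} not divisible by block_col_warps={block_col_warps}"
    return M // block_row_warps, N // block_col_warps

print(split_tiles(128, 256, 2, 4))  # (64, 64)
```

Without the assertions, `split_tiles(130, 256, 4, 4)` would quietly return `(32, 64)` and drop two rows of work; with them it raises immediately.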
```python
for n in T.unroll(num_inst_n):
    for m in T.unroll(num_inst_m):
        for ki in T.unroll(num_k_atoms):
            emi.wgmma_ss_atom(desc_a, desc_b, C_local, m, n, ki, a_params, b_params, T.bool(True))
```
Accumulator never initialized before wgmma_ss_atom — potential correctness bug
C_local is allocated with T.alloc_fragment (line 172) but never cleared. wgmma_fence_c is a fence/barrier primitive, not a zero-initializer. If T.bool(True) maps to WGMMA PTX scale_d=1 (accumulate), the very first call reads uninitialized memory and adds A*B to it — producing garbage.
In contrast, the MMA path correctly calls T.clear(C_local) (line 96) before any atom.
Confirm whether `T.bool(True)` here means:
- `scale_d=1` (accumulate-into-D) → add `T.clear(C_local)` before the loop, or
- "is-first-k" reset flag (True → init D to zero) → the flag should be conditional on the first iteration (e.g., `ki == 0 and m == 0 and n == 0`), not always-True.

🛡️ Fix if `T.bool(True)` means `scale_d=1`

```diff
+ T.clear(C_local)
+
  emi.wgmma_fence_c(C_local)
  emi.wgmma_arrive()
```
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@testing/python/language/test_tilelang_language_atom_mma.py` around lines 195
- 198, C_local is allocated via T.alloc_fragment but never initialized before
calls to emi.wgmma_ss_atom with T.bool(True), which if representing scale_d=1
will read uninitialized accumulators; fix by either inserting T.clear(C_local)
before the nested loops that call emi.wgmma_ss_atom (or immediately after
allocation) so accumulation starts from zero, or if T.bool(True) is instead
intended as an "is-first-k" reset flag make the flag conditional (e.g., pass
True only when ki==0 and m==0 and n==0) and leave C_local unmodified otherwise;
locate uses of C_local, emi.wgmma_ss_atom, and any T.bool(True) in the loop to
implement the appropriate change.
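For intuition, the two candidate semantics can be modelled in plain Python, with a sentinel value standing in for the uninitialized accumulator fragment (the function name and signature are illustrative, not the emitter's API):

```python
def wgmma_like(a, b, first_scale_d: int, garbage: float = 123.0) -> float:
    """Model a k-loop of D = D * scale_d + a[ki] * b[ki] over a dirty accumulator."""
    d = garbage  # "uninitialized" register fragment
    for ki in range(len(a)):
        scale_d = first_scale_d if ki == 0 else 1
        d = d * scale_d + a[ki] * b[ki]
    return d

# Always passing scale_d=1 folds the garbage into the result:
print(wgmma_like([1.0, 2.0], [3.0, 4.0], first_scale_d=1))  # 134.0
# scale_d=0 on the first atom (or a prior T.clear) discards it:
print(wgmma_like([1.0, 2.0], [3.0, 4.0], first_scale_d=0))  # 11.0
```

Either fix works: zero the accumulator once up front, or make the flag false-only-on-first-atom; always-true with a dirty fragment is the one combination that cannot be correct.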
…TCGEN05TensorCoreIntrinEmitter. Introduce descriptor parameter classes for WGMMA and TCGEN05, improving the initialization and atom offset computation processes. Update the TensorCoreIntrinEmitter to support atom-level interfaces and macro definitions for efficient matrix multiplication operations.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py (1)
318-354: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Clamp non-swizzled A atom size before using it in offset math.

Both SS paths can produce `a_swizzle_atom_elems == 0` for linear fp16/bf16 layouts. That value then feeds directly into the A-offset formulas, collapsing distinct atoms onto the same shared-memory region.

Suggested fix

```diff
- a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+ a_swizzle_atom_elems = max(a_swizzle_mode.swizzle_byte_size() // elems_in_bytes, 1)
```

Also applies to: 680-708
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py` around lines 318 - 354, The bug is that a_swizzle_atom_elems can be 0 for non-swizzled linear fp16/bf16 layouts and is used in offset math; clamp it to at least 1 before any use. Update the block computing a_swizzle_atom_elems in tcgen05_macro_generator.py to set a_swizzle_atom_elems = max(a_swizzle_atom_elems, 1) (or equivalent) before computing a_m_axis_atoms, a_leading_byte_offset/a_stride_byte_offset, and before passing swizzle_atom_elems/k_atom_size into TCGEN05DescriptorParams (also apply the same clamp in the other occurrence around lines 680-708 that constructs A params). Ensure k_atom_size still uses the clamped value when doing max(a_swizzle_atom_elems // micro_size_k, 1).
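The failure mode is pure arithmetic, so it can be shown without a GPU: with a linear (non-swizzled) layout the swizzle byte size is 0, and the unclamped floor division yields a zero-width atom. The byte sizes below are assumptions matching a common fp16 128B-swizzle configuration:

```python
def atom_elems(swizzle_byte_size: int, elem_bytes: int, clamp: bool) -> int:
    """Elements per swizzle atom; optionally clamped to a minimum width of 1."""
    elems = swizzle_byte_size // elem_bytes
    return max(elems, 1) if clamp else elems

# 128B swizzle atom over fp16 (2-byte) elements -> 64 elements per atom.
print(atom_elems(128, 2, clamp=True))   # 64
# Linear layout: swizzle_byte_size() == 0 collapses to width 0 unless clamped.
print(atom_elems(0, 2, clamp=False))    # 0  -> every atom offset becomes 0
print(atom_elems(0, 2, clamp=True))     # 1
```

With a width of 0, any `atom_index * atom_elems` offset evaluates to 0, so every atom reads the same shared-memory slice — exactly the collapse the review describes.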
♻️ Duplicate comments (4)
tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py (2)
217-222: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Initialize TCGEN05 meta before using the atom-count helpers.

Both entrypoints now read `tcgen05_num_inst_*` / `compute_tcgen05_instr_desc()` before any lazy meta bootstrap. On a fresh emitter this reaches `tcgen05_meta_unpacked` before `self.meta` exists, so `tcgen05mma_ss()` and `tcgen05mma_ts()` can fail before emitting TIR.

Suggested fix

```diff
+ m_dim = self.block_row_warps * self.warp_row_tiles
+ n_dim = self.block_col_warps * self.warp_col_tiles
+ if not hasattr(self, "meta") or len(self.meta) != 5:
+     self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim, disable_2cta=False)
+
  num_inst_m = self.tcgen05_num_inst_m
  num_inst_n = self.tcgen05_num_inst_n
  num_k_atoms = self.tcgen05_num_k_atoms
```

Also applies to: 264-268
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py` around lines 217 - 222, The code reads tcgen05_num_inst_m, tcgen05_num_inst_n and calls compute_tcgen05_instr_desc() before ensuring the TCGEN05 meta is unpacked; call tcgen05_meta_unpacked() (or otherwise initialize/unpack self.meta) at the start of the routine before accessing tcgen05_* fields or calling compute_tcgen05_instr_desc(), so that helpers like tcgen05mma_ss() and tcgen05mma_ts() won't run against an uninitialized self.meta.
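A minimal sketch of the suggested guard, with made-up attribute names standing in for the emitter's real meta fields:

```python
class EmitterSketch:
    """Toy emitter that bootstraps its meta tuple on first use."""

    def _ensure_meta(self) -> None:
        if not hasattr(self, "meta") or len(self.meta) != 5:
            # Stand-in for self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim, ...)
            self.meta = (2, 4, 1, 64, 256)

    @property
    def num_inst_m(self) -> int:
        self._ensure_meta()
        return self.meta[0]

e = EmitterSketch()   # fresh emitter: no .meta yet
print(e.num_inst_m)   # 2 -- safe, meta was bootstrapped lazily
```

Putting the guard inside the accessor (rather than in each entrypoint) keeps every code path safe regardless of call order.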
626-662: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use `self.b_dtype` when computing B descriptor byte widths.

This helper still derives B's element width from `self.a_dtype`, so mixed-input kernels will generate the wrong B descriptor offsets.

Suggested fix

```diff
- elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8
+ elems_in_bytes = (DataType(self.b_dtype).bits + 7) // 8
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py` around lines 626 - 662, The B-descriptor computation incorrectly uses self.a_dtype to compute elems_in_bytes, causing wrong byte offsets for mixed-input kernels; update the computation to use self.b_dtype (i.e., replace DataType(self.a_dtype) with DataType(self.b_dtype)) so b_leading_byte_offset, b_stride_byte_offset, b_swizzle_atom_elems, and any derived values (returned in TCGEN05DescriptorParams) are based on B's element size.
tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py (2)
371-399: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Clamp linear A `swizzle_atom_elems` to a non-zero value.

For `SwizzleMode.NONE`, `swizzle_byte_size() // elems_in_bytes` becomes 0 for fp16/bf16, so the A-offset math in `wgmma_ss_atom()` collapses and multiple atoms read the same shared-memory slice.

Suggested fix

```diff
- a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+ a_swizzle_atom_elems = max(a_swizzle_mode.swizzle_byte_size() // elems_in_bytes, 1)
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py` around lines 371 - 399, The a_swizzle_atom_elems calculation can be zero for SwizzleMode.NONE (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) which breaks downstream A-offset math (e.g. wgmma_ss_atom); clamp a_swizzle_atom_elems to a minimum of 1 after it is computed. Locate the variables a_swizzle_mode and a_swizzle_atom_elems in the function that returns WGMMADescriptorParams and replace the raw division result with a max(..., 1) or otherwise set a_swizzle_atom_elems = max(a_swizzle_atom_elems, 1) before using it in the subsequent offset logic and in the returned WGMMADescriptorParams.
325-355: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use `self.b_dtype` for B descriptor byte math.

`compute_wgmma_b_desc_params()` is still sizing B's swizzle atoms and byte offsets from `self.a_dtype`. Mixed-input kernels will build the B descriptor with the wrong element width and misaddress shared memory.

Suggested fix

```diff
- elems_in_bytes = DataType(self.a_dtype).bits // 8
+ elems_in_bytes = DataType(self.b_dtype).bits // 8
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py` around lines 325 - 355, compute_wgmma_b_desc_params() currently computes B's element byte size and all byte-offset math using DataType(self.a_dtype); change it to use DataType(self.b_dtype) so elems_in_bytes and any calculations derived from it (b_swizzle_atom_elems, b_leading_byte_offset, b_stride_byte_offset, and k_atom_size) are based on B's actual dtype. Update the assignment of elems_in_bytes and any expressions that multiply/divide by elems_in_bytes to reference the new value so the returned WGMMADescriptorParams fields (swizzle_atom_elems, leading_byte_offset, stride_byte_offset, k_atom_size, elems_in_bytes) are correct for mixed-input kernels.
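The consequence of sizing B with A's dtype is easiest to see numerically. The widths below are assumptions (fp16 A paired with an int8 B), and `elems_in_bytes` mirrors the expression quoted in the review:

```python
def elems_in_bytes(bits: int) -> int:
    """Bytes per element, rounding sub-byte widths up."""
    return (bits + 7) // 8

a_bits, b_bits = 16, 8       # e.g. float16 A, int8 B (illustrative mix)
stride_elems = 64            # elements per stride step (illustrative)

wrong = stride_elems * elems_in_bytes(a_bits)  # B stride sized with A's width
right = stride_elems * elems_in_bytes(b_bits)  # B stride sized with B's width
print(wrong, right)  # 128 64 -- the descriptor would step twice too far
```

For same-dtype kernels the two expressions coincide, which is why the bug only surfaces on mixed-input paths.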
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e6569594-ad75-493c-825a-0c72a25c7d1f
📒 Files selected for processing (5)
testing/python/language/test_tilelang_language_atom_mma.py
tilelang/cuda/intrinsics/macro/mma_macro_generator.py
tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py
tilelang/intrinsics/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tilelang/intrinsics/__init__.py
- testing/python/language/test_tilelang_language_atom_mma.py
Summary by CodeRabbit
New Features
Tests