
tilelang: T.fp8_scaled_matmul DSL intrinsic + Metal lowering #2142

Open
apstenku123 wants to merge 11 commits into tile-ai:main from apstenku123:cppmega/fp8-scaled-matmul-intrinsic

Conversation


@apstenku123 apstenku123 commented May 4, 2026

Summary

Adds the T.fp8_scaled_matmul(A_fp8, A_scale, B_fp8, B_scale, C_out) DSL intrinsic — a hygienic @T.macro mirroring the audiohacking/fp8-mps-metal scaled-matmul kernel surface (Apache 2.0). Provides Metal lowering via the GemmMetalScalar path with per-tensor and (single-block) per-row scale dispatch resolved at macro-expansion time.

This is the "real" implementation, not a stub: the macro expands to a working scaled-FP8 GEMM whose results are near-bit-exact against the audiohacking reference MSL kernel. 25/25 IR + Metal e2e parity tests pass on M-series hardware.
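As a plain-Python reference for what the intrinsic computes (not the TileLang lowering itself), per-tensor scaling is C = sa · sb · (A @ B) once the FP8 operands are decoded to float32. A minimal numpy sketch, with already-decoded float arrays standing in for the FP8 tensors:

```python
import numpy as np

def fp8_scaled_matmul_ref(a_deq, a_scale, b_deq, b_scale):
    """Per-tensor-scale reference: the scalar scales multiply the fp32 accumulation.

    a_deq / b_deq stand in for A_fp8 / B_fp8 after decode to float32;
    a_scale / b_scale are scalar per-tensor scale factors.
    """
    acc = a_deq.astype(np.float32) @ b_deq.astype(np.float32)
    return a_scale * b_scale * acc

A = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
B = np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32)
C = fp8_scaled_matmul_ref(A, 0.5, B, 2.0)  # 0.5 * 2.0 == 1.0, so C == A @ B
```

This is only the semantic contract the macro's IR must satisfy; the actual lowering goes through T.gemm plus a scale broadcast.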

Files

  • tilelang/language/__init__.py — re-exports T.fp8_scaled_matmul
  • tilelang/language/fp8_op.py — 379 LOC, the macro implementation
  • testing/python/cpu/test_fp8_scaled_matmul_lowering.py — 8 IR-level lowering tests
  • testing/python/metal/test_fp8_scaled_matmul_metal.py — 17 Metal codegen + xcrun + e2e parity tests vs mlx.core ground truth
  • TVM-mirror codegen_metal e4m3 subnormal bugfix (exponent term e + 8 → e + 7) — required for parity with torch.float8_e4m3fn / mx.from_fp8 / audiohacking LUT decoder. Without the fix, bytes 0x01..0x07 / 0x81..0x87 dequantize to 2× the correct value. NOTE: this hunk is for the 3rdparty/tvm vendored fork (TileLang/tvm); the code change is in this PR's history but the actual landing happens via a companion PR to that submodule. Maintainers can either merge this PR with the bugfix included (it lives in the patch text) or split it off.
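For reference, a scalar e4m3fn decoder matching the torch.float8_e4m3fn convention (bias 7, no infinities); the subnormal branch is the one the exponent fix affects — with the wrong exponent term, bytes 0x01..0x07 come out doubled:

```python
def decode_e4m3fn(byte: int) -> float:
    """Decode one float8_e4m3fn byte: 1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0:
        # Subnormal: (man / 8) * 2**(1 - bias) = (man / 8) * 2**-6.
        # The bug this PR fixes produced 2x this value for bytes 0x01..0x07.
        return sign * (man / 8.0) * 2.0 ** -6
    if exp == 0xF and man == 0x7:
        return float("nan")  # e4m3fn has no infinities; S.1111.111 is NaN
    # Normal: (1 + man / 8) * 2**(exp - bias)
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)
```

Spot checks: 0x38 decodes to 1.0, 0x01 to 2⁻⁹ (the smallest positive subnormal), and 0x7E to 448.0 (the e4m3fn maximum).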

Why

cppmega.mlx sparse-MLA-FP8 attention probes need a fused FP8 scaled GEMM path. Path B (hand-written MSL via mx.fast.metal_kernel calling the audiohacking kernel pre-compiled) works today; this PR adds the equivalent surface in TileLang DSL so T.fp8_scaled_matmul(...) works inside @T.prim_func.

Bench

Local Apple M4 Max:

| Shape | TileLang T.fp8_scaled_matmul | audiohacking MSL reference | Ratio |
| --- | --- | --- | --- |
| 128×128×128 e4m3 per-tensor matmul | 0.555 ms (0.008 TFLOPS) | 0.172 ms (0.024 TFLOPS) | 3.16× slower |
| M=1, N=K=4096 vecmat per-tensor | 1.098 ms (0.031 TFLOPS) | 0.182 ms (0.184 TFLOPS) | 6.01× slower |

The gap to audiohacking is expected: the MSL reference is a hand-tuned kernel with simd_sum reduction for the M=1 case; the DSL macro lowers via T.gemm + post-loop scale broadcast. A follow-up scheduler pass that fuses per-load scale into the K-loop would close most of the matmul gap; the vecmat case additionally needs a T.simdgroup_reduce_sum primitive.
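The proposed per-load fusion is algebraically just distributivity: for per-tensor scales, sa·sb·Σₖ aₖbₖ = Σₖ (sa·aₖ)(sb·bₖ), so folding the scale into the K-loop changes scheduling, not results. A quick numpy check of the identity:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8)).astype(np.float32)
B = rng.standard_normal((8, 4)).astype(np.float32)
sa, sb = np.float32(0.25), np.float32(3.0)

post_loop = sa * sb * (A @ B)   # current macro: scale applied after T.gemm
fused = (sa * A) @ (sb * B)     # hypothetical fused-in-K-loop form
```

The two differ only by float rounding order, which is why a scheduler pass can fuse the scale without changing the op's contract.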

Stacking topology

This PR is based on jorgecurious/tilelang:metal-gemm-upstream-rebase (PR #2130) at HEAD 971c17b, which itself stacks on top of the upstream Apple Metal landing chain (#1869, #2118, #2121).

Once #1869 + #2118 + #2121 + #2130 merge into tile-ai/tilelang:main, this PR can be retargeted to main directly.

Test plan

cd /path/to/tilelang
pytest testing/python/cpu/test_fp8_scaled_matmul_lowering.py testing/python/metal/test_fp8_scaled_matmul_metal.py

Expect 25/25 pass.

Limitations / follow-ups

  • Per-tensor scale and single-block per-row scale work end-to-end. Multi-block per-row scale (where scale.shape[0] > 1 and BM != M) needs a slice at the call site or a follow-up macro extension. Documented as TODO in the macro.
  • CUDA path: macro emits the same TIR; T.cast(fp8, fp32) lowers via __nv_fp8_e4m3_to_half etc. on CUDA. Tensor-core FP8 path (T.tcgen05_gemm_blockscaled) is intentionally not auto-dispatched here because the e8m0 block-scale layout doesn't match this op's per-tensor / per-row semantics.
  • Performance: see Bench above. Native simdgroup FP8 doesn't exist on Apple Silicon (M1-M4); this is a software path.
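For the multi-block per-row case above, the call-site slice amounts to handing each block only its rows of the scale vector. A numpy sketch of the intended semantics (M, BM, and the loop structure are illustrative, not the macro's actual tiling):

```python
import numpy as np

M, K, N, BM = 8, 4, 4, 2
rng = np.random.default_rng(1)
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)
row_scale = rng.uniform(0.5, 2.0, size=M).astype(np.float32)

full = row_scale[:, None] * (A @ B)    # whole-matrix per-row scaling

blocked = np.empty_like(full)
for bm in range(M // BM):
    rows = slice(bm * BM, (bm + 1) * BM)
    # Each block sees only its slice of the per-row scale vector,
    # which is what a call-site slice would provide to the macro.
    blocked[rows] = row_scale[rows, None] * (A[rows] @ B)
```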

Attribution

The audiohacking/fp8-mps-metal MSL kernel (Apache 2.0) is the algorithmic reference. The TileLang DSL macro replicates the per-tensor scale broadcast + K-loop unrolling pattern in TIR. Co-developed with cppmega.mlx.

Summary by CodeRabbit

  • New Features

    • Added Metal GPU backend support with GEMM and copy operations for Apple devices
    • Introduced T.fp8_scaled_matmul intrinsic for FP8 quantized matrix multiplication
    • Enabled Apple MPS (Metal Performance Shaders) device detection and selection
  • Bug Fixes

    • Improved device detection to prioritize MPS when CUDA is unavailable
  • Tests

    • Added comprehensive Metal backend test suites covering GEMM, copy ops, FP8 operations, and internal scaffolding
  • Chores

    • Updated macOS-specific dependency constraints for apache-tvm-ffi

oraluben and others added 11 commits April 30, 2026 01:43
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations
(simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon
(M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with
  simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared×shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM
  accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
Copilot AI review requested due to automatic review settings May 4, 2026 08:53

github-actions Bot commented May 4, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai Bot commented May 4, 2026

📝 Walkthrough

This PR introduces complete Metal backend support for TileLang, including a Metal code generator, Metal-specific GEMM and copy operations, a new FP8 scaled matmul intrinsic, and comprehensive testing infrastructure for macOS MPS execution.

Changes

Metal Backend Infrastructure & Operations

  • Metal Code Generation Foundation (src/target/codegen_metal.h, src/target/codegen_metal.cc, src/backend/metal/CMakeLists.txt) — New CodeGenTileLangMetal class emits Metal kernel code with type conversion, thread binding, storage scope handling, and builtin dispatch (simdgroup ops, reinterpret). CMake always compiles codegen source and skips Metal runtime collection on non-Apple platforms.
  • Core IR Lowering (src/op/gemm.h, src/op/gemm.cc, src/op/copy.h, src/op/copy.cc, src/op/fill.cc, src/op/utils.h, src/op/parallel.cc) — New GemmInst::kMetalSimdgroup and warp-partition targeting (M-per-warp = 8 for Metal). New CopyInst::kMetalSIMDGroup with CheckSIMDGroupCopy and LowerSIMDGroupCopy for simdgroup store operations. SIMD-group fill via make_filled_simdgroup_matrix. Register-buffer classification includes SIMD-group buffers.
  • TileOp Abstractions & Macros (tilelang/tileop/gemm/inst.py, tilelang/tileop/gemm/gemm_metal.py, tilelang/tileop/gemm/__init__.py, tilelang/intrinsics/metal_macro_generator.py, tilelang/tileop/metal_simdgroup.py, tilelang/tileop/metal_quant.py, tilelang/tileop/metal_gdn.py) — New GemmMetal class lowers shared-memory GEMM to Metal with warp-level tiling. MPSIntrinEmitter generates ldmatrix/mma/simdgroup-store macros. Register-tile (RegisterTile, MMATile, RowVector) abstractions with load/store/mma/reduction macros. Quantization decoders (fp8/fp4/e8m0). GDN KKT/scoring macros.
  • FP8 Scaled MatMul Intrinsic (tilelang/language/fp8_op.py, tilelang/language/__init__.py) — New T.fp8_scaled_matmul(A_fp8, A_scale, B_fp8, B_scale, C_out, transpose_B, accum_dtype) macro validates dtypes/shapes, selects per-tensor vs per-row scale indexing, and emits a fused dequant-multiply-accumulate loop.
  • Transform Passes & Integration (tilelang/transform/metal_fragment_to_simdgroup.py, tilelang/transform/decouple_type_cast.py, src/transform/layout_inference.cc, src/transform/lower_device_storage_access_info.cc, tilelang/utils/language.py, tilelang/engine/lower.py, tilelang/engine/phase.py, tilelang/jit/adapter/base.py, tilelang/jit/adapter/torch/metal.py) — New MetalFragmentToSimdgroup pass rewrites fragment accumulators to metal.simdgroup before layout inference. Fragment buffers skip strict layout checks on Metal. Storage scope filtering excludes fragments. Build uses target.build.tilelang_metal. Device functor prefers MPS when CUDA unavailable.
  • Benchmarks & Tests (benchmark/matmul_metal/benchmark_matmul_metal.py, testing/python/metal/test_fp8_scaled_matmul_metal.py, testing/python/metal/test_metal_gemm_v2.py, testing/python/metal/test_metal_gemm_v2_linux.py, testing/python/metal/test_metal_simdgroup_store.py, testing/python/metal/test_metal_local_var.py, testing/python/metal/test_metal_internal_scaffolding.py, testing/python/cpu/test_fp8_scaled_matmul_lowering.py, testing/python/jit/test_tilelang_jit_adapter_mps.py, testing/python/metal/metal_internal_runtime_coverage.md) — E2E Metal execution tests for GEMM/FP8 matmul, codegen validation, parity checks vs PyTorch/audiohacking references, internal scaffolding probes, and MPS device-selection tests. Includes benchmark/xcrun compilation tests and documentation.
  • Dependency Constraints (pyproject.toml, requirements.txt, requirements-dev.txt) — Added macOS-specific upper bound (<0.1.8) for apache-tvm-ffi.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • tzj-fxz

Poem

🐰 Metal shines bright on Apple's shores,
Fragments morph to SIMDGROUP stores,
Quantized bytes dance in threads,
Register tiles in matrices spread,
Woosh! GPU speeds, FP8 threads—
TileLang leaps forward, ears spread! 🎉

apstenku123 added a commit to DatasunriseOU/cppmega_mlx that referenced this pull request May 4, 2026
Files documenting the actual PRs we just opened upstream:

- PR #1: ml-explore/mlx#3476 — from_dlpack Metal-aware consumer (against main, clean)
- PR #2: apache/tvm#19504 — TVM_METAL_STORAGE_MODE env opt-in (against main, clean)
- PR #3: tile-ai/tilelang#2139 — mixed-dtype T.gemm via scalar fallback (stacks on PR #2130)
- PR #4: tile-ai/tilelang#2140 — FP8-input T.gemm scalar fallback routing (stacks on PR #2130)
- PR #5: tile-ai/tilelang#2141 — T.Pipelined num_stages>1 3D buffer fix (stacks on PR #2130)
- PR #6: tile-ai/tilelang#2142 — T.fp8_scaled_matmul DSL intrinsic (stacks on PR #2130)

Deferred (split into companion PRs needed): tilelang_metal_fp8 and
tilelang_metal_fp8_vector each touch both tilelang supermodule and the
TileLang/tvm vendored submodule. These need 2 PRs each — one to
tile-ai/tilelang, one to TileLang/tvm — separate filing round.

PRs #3-#6 are independent of each other; each branches directly from
jorgecurious/tilelang:metal-gemm-upstream-rebase HEAD 971c17b, so they
can be reviewed in any order. They DO depend on the upstream 4-PR Apple
Metal landing chain (#1869, #2118, #2121, #2130) merging first; if any
of those land separately, ours can be retargeted at main.

Copilot AI left a comment


Pull request overview

Adds a new FP8 scaled matmul intrinsic to the TileLang language surface and expands the Metal backend to support simdgroup-matrix GEMM lowering (plus supporting codegen/runtime plumbing and tests).

Changes:

  • Introduces T.fp8_scaled_matmul(A_fp8, A_scale, B_fp8, B_scale, C_out) as a hygienic macro and re-exports it from tilelang.language.
  • Adds Metal simdgroup GEMM lowering/codegen support (new Metal codegen target builder, simdgroup copy/store lowering, and a fragment→simdgroup rewrite pass).
  • Adds extensive Metal and IR-level tests/coverage docs, plus a Metal GEMM benchmark script and minor JIT adapter behavior changes.

Reviewed changes

Copilot reviewed 40 out of 41 changed files in this pull request and generated 6 comments.

Summary per file:

  • tilelang/utils/language.py — Adds helper is_metal_simdgroup() for scope checks.
  • tilelang/transform/metal_fragment_to_simdgroup.py — New Metal-only pass rewriting GEMM accumulators from local.fragment to metal.simdgroup.
  • tilelang/transform/decouple_type_cast.py — Treats metal.simdgroup buffers as "local" for cast-decoupling logic.
  • tilelang/tileop/metal_simdgroup.py — Internal simdgroup register-tile utilities/macros.
  • tilelang/tileop/metal_quant.py — Internal packed-quant decode helpers used by Metal scaffolding tests.
  • tilelang/tileop/metal_gdn.py — Internal GDN/attention-style tile macros using simdgroup helper utilities.
  • tilelang/tileop/gemm/inst.py — Adds METAL_SIMDGROUP GEMM instruction enum value.
  • tilelang/tileop/gemm/gemm_metal.py — Implements Metal simdgroup-matrix GEMM lowering path.
  • tilelang/tileop/gemm/__init__.py — Selects Metal simdgroup GEMM impl on Metal targets.
  • tilelang/language/fp8_op.py — Adds fp8_scaled_matmul macro implementation and validation.
  • tilelang/language/__init__.py — Re-exports fp8_scaled_matmul on the public language surface.
  • tilelang/jit/adapter/torch/metal.py — Adds/overrides get_kernel_source() for Metal kernels.
  • tilelang/jit/adapter/base.py — Adjusts device selection to prefer MPS when CUDA init/lookup fails.
  • tilelang/intrinsics/metal_macro_generator.py — Adds MPSIntrinEmitter for simdgroup load/store/MMA TIR macro emission.
  • tilelang/engine/phase.py — Inserts Metal fragment→simdgroup rewrite before layout inference.
  • tilelang/engine/lower.py — Switches Metal build entrypoint to target.build.tilelang_metal.
  • testing/python/metal/test_metal_simdgroup_store.py — New tests for direct simdgroup store to device memory.
  • testing/python/metal/test_metal_local_var.py — New focused tests for local.var scalar codegen/runtime on Metal.
  • testing/python/metal/test_metal_internal_scaffolding.py — Large internal scaffolding suite for Metal source-boundary + runtime probes.
  • testing/python/metal/test_metal_gemm_v2_linux.py — Cross-platform (no-Metal-runtime) Metal GEMM source-generation checks.
  • testing/python/metal/test_metal_gemm_v2.py — On-device Metal GEMM correctness checks.
  • testing/python/metal/test_fp8_scaled_matmul_metal.py — Metal lowering + offline compile + runtime parity tests for fp8_scaled_matmul.
  • testing/python/metal/metal_internal_runtime_coverage.md — Documents internal Metal runtime coverage and constraints.
  • testing/python/jit/test_tilelang_jit_adapter_mps.py — Adds tests covering MPS device preference in the JIT adapter.
  • testing/python/cpu/test_fp8_scaled_matmul_lowering.py — IR/source-level lowering tests for fp8_scaled_matmul (no GPU required).
  • src/transform/lower_device_storage_access_info.cc — Treats .fragment scope as exempt from storage access info lowering.
  • src/transform/layout_inference.cc — Skips fragment-layout requirement on Metal targets for GEMM accumulators.
  • src/target/codegen_metal.h — Adds TileLang-specific Metal codegen class declaration.
  • src/target/codegen_metal.cc — Implements TileLang Metal codegen and registers target.build.tilelang_metal.
  • src/op/utils.h — Adds helpers for metal.simdgroup buffers and a combined "register buffer" predicate.
  • src/op/parallel.cc — Makes fragment layout inference more defensive when layout info is absent.
  • src/op/gemm.h — Adds Metal simdgroup GEMM instruction enum value in C++ core.
  • src/op/gemm.cc — Selects Metal GEMM instruction on Metal targets; tweaks warp policy for Metal.
  • src/op/fill.cc — Adds Fill lowering for metal.simdgroup buffers via make_filled_simdgroup_matrix.
  • src/op/copy.h — Adds Metal simdgroup copy instruction kind and lowering hooks.
  • src/op/copy.cc — Implements simdgroup store lowering from metal.simdgroup to shared/global.
  • src/backend/metal/CMakeLists.txt — Ensures Metal codegen builds in "codegen-only" mode cross-platform.
  • requirements.txt — Adds Darwin-only upper bound for apache-tvm-ffi.
  • requirements-dev.txt — Adds Darwin-only upper bound for apache-tvm-ffi (dev).
  • pyproject.toml — Adds Darwin-only upper bound for apache-tvm-ffi (packaging).
  • benchmark/matmul_metal/benchmark_matmul_metal.py — Adds a Metal GEMM benchmark script for simdgroup GEMM.


Comment on lines +113 to +124
# Storage-level FP8 dtype tags accepted by this intrinsic. Any other dtype
# in the A / B operands raises a TypeError at parse time. ``float8_e8m0fnu``
# is the block-scale-factor format and is intentionally excluded — it is
# carried by the sf_a / sf_b operands of the block-scaled GEMM, not by A / B.
FP8_DTYPES: tuple[str, ...] = ("float8_e4m3", "float8_e5m2", "float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2fnuz")


def _is_fp8_dtype(dt) -> bool:
"""Return True if a dtype string / object names an FP8 storage variant."""
s = str(dt or "")
return any(s.startswith(t) for t in ("float8", "fp8"))

Comment on lines +261 to +266
M_dim, K_dim = A_fp8.shape
K_dim_b, N_dim = B_fp8.shape
sa_size = A_scale.shape[0]
sb_size = B_scale.shape[0]

# The accumulation matches the audiohacking ``fp8_scaled_matmul_kernel``
Comment on lines +273 to +274
a_val = T.cast(A_fp8[i, k], "float32")
b_val = T.cast(B_fp8[k, j], "float32")
                return lambda: torch.device("cuda", current_device())
            except Exception:
-                return lambda: torch.device("cuda", torch.cuda.current_device())
+                pass
Comment on lines +56 to +58
def get_kernel_source(self, kernel_only: bool = True) -> str:
return self.kernel_global_source or ""

Comment on lines +69 to +102
buf_map = {}

def _pre_order(stmt):
if isinstance(stmt, tir.Block):
new_alloc_bufs = []
changed = False
for buf in stmt.alloc_buffers:
new_buf = _remap_buffer(buf, var_map)
new_alloc_bufs.append(new_buf)
if not new_buf.same_as(buf):
buf_map[buf] = new_buf
changed = True
if changed:
new_body = tir.stmt_functor.substitute(stmt.body, var_map)
new_block = tir.Block(
stmt.iter_vars,
stmt.reads,
stmt.writes,
stmt.name_hint,
new_body,
stmt.init,
new_alloc_bufs,
stmt.match_buffers,
stmt.annotations,
)
return (
tir.BlockRealize(
stmt.iter_vars,
tir.const(True, "bool"),
new_block,
)
if False
else new_block
)

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tilelang/jit/adapter/base.py (1)

77-86: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

pass in the except block silently drops the CUDA fallback for partial-init failures

The single try block catches two distinct failure modes:

  1. torch.cuda._lazy_init() raises → CUDA is genuinely unusable → falling through to MPS/CPU is correct.
  2. torch._C._cuda_getDevice attribute access fails → only the fast path is unavailable; CUDA is still usable → the old code returned lambda: torch.device("cuda", torch.cuda.current_device()) here, which was correct.

With pass, Scenario 2 silently routes a fully functional CUDA machine to MPS (if accidentally present) or CPU. That would silently allocate output tensors on the wrong device (see tvm_ffi.py line 242), likely causing a device mismatch at kernel launch time rather than an obvious error.

Splitting the try preserves the intended fall-through for Scenario 1 while restoring the safe CUDA fallback for Scenario 2:

🛡️ Proposed fix
         if torch.cuda.is_available():
-            try:
-                torch.cuda._lazy_init()
-                current_device = torch._C._cuda_getDevice
-                return lambda: torch.device("cuda", current_device())
-            except Exception:
-                pass
+            try:
+                torch.cuda._lazy_init()
+            except Exception:
+                pass  # CUDA init failed entirely; fall through to MPS/CPU
+            else:
+                try:
+                    current_device = torch._C._cuda_getDevice
+                    return lambda: torch.device("cuda", current_device())
+                except Exception:
+                    # Fast-path C handle unavailable; fall back to public API
+                    return lambda: torch.device("cuda", torch.cuda.current_device())

Also, the docstring on line 74 ("On CPU or when CUDA is unavailable, returns torch.device('cpu')") no longer reflects reality — MPS is now also a possible return value.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/jit/adapter/base.py` around lines 77 - 86, The try/except currently
swallows attribute-access failures and incorrectly falls back away from CUDA;
change the logic so torch.cuda._lazy_init() is attempted in its own try and only
if it raises should we fall through to MPS/CPU, while failures when accessing
torch._C._cuda_getDevice (or AttributeError) should use the safe fallback used
previously (i.e., return a lambda that calls torch.device("cuda",
torch.cuda.current_device()) or the retrieved current_device()), and keep other
exceptions bubbling or handled appropriately; also update the function docstring
(the string around line 74) to reflect that the function may return MPS or CPU
when CUDA is unavailable.
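The split-try shape CodeRabbit proposes can be sketched independently of torch with stand-in callables (select_device and its parameters are hypothetical names, not the adapter's API):

```python
def select_device(cuda_available, lazy_init, fast_get_device, slow_get_device):
    """Mimic the proposed device-selection order with injectable callables.

    Scenario 1: lazy_init raises -> CUDA is unusable, fall through to CPU/MPS.
    Scenario 2: only fast_get_device fails -> fall back to the public CUDA API.
    """
    if cuda_available():
        try:
            lazy_init()
        except Exception:
            pass  # CUDA init failed entirely; fall through
        else:
            try:
                return f"cuda:{fast_get_device()}"   # fast C-handle path
            except Exception:
                return f"cuda:{slow_get_device()}"   # safe public-API fallback
    return "cpu"
```

With this shape, a broken fast path on a working CUDA machine still yields a CUDA device, while a failed init falls through as intended.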
src/op/parallel.cc (1)

569-575: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

ValidateCandidateAgainstFragments has the same unsafe .as<Fragment>().value() pattern that was just fixed at lines 381-383.

ValidateCandidateAgainstFragments iterates indice_map_ with only a T.layout_map.count(buffer) pre-check (line 567) — the same setup as the unfixed code that motivated the change at 381-383. If a non-Fragment layout appears for a buffer in indice_map_, line 572 aborts unconditionally. This function is called from at least two sites: line 475 and ChooseBestCandidate (lines 779-780). The same issue exists at line 726 inside BuildReplicationGuardsIfNeeded.

🛡️ Proposed defensive fix for `ValidateCandidateAgainstFragments` (line 572) and `BuildReplicationGuardsIfNeeded` (line 726)
// ValidateCandidateAgainstFragments, line 572
-    auto fragment = T.layout_map[buffer].as<Fragment>().value();
+    auto fragment_opt = T.layout_map[buffer].as<Fragment>();
+    if (!fragment_opt.has_value())
+      continue;
+    auto fragment = fragment_opt.value();
// BuildReplicationGuardsIfNeeded, line 726
-      auto fragment_layout = T.layout_map[fragment].as<Fragment>().value();
+      auto fragment_layout_opt = T.layout_map[fragment].as<Fragment>();
+      if (!fragment_layout_opt.has_value())
+        continue;
+      auto fragment_layout = fragment_layout_opt.value();
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/parallel.cc` around lines 569 - 575, ValidateCandidateAgainstFragments
(and BuildReplicationGuardsIfNeeded) use the unsafe pattern
T.layout_map[buffer].as<Fragment>().value() which will abort if the layout isn't
a Fragment; replace this with a defensive check that the layout exists and is a
Fragment before accessing it. For example, retrieve the optional layout via auto
layout_opt = T.layout_map[buffer]; if (!layout_opt.has_value() ||
!layout_opt->is<Fragment>()) continue (or otherwise skip/handle non-Fragment
cases), then safely extract the Fragment via layout_opt->as<Fragment>().value();
apply the same guard in BuildReplicationGuardsIfNeeded and any other call sites
(e.g., ChooseBestCandidate) that assume Fragment layouts.
tilelang/tileop/gemm/__init__.py (1)

154-175: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Docstring priority ordering is misleading — Metal is checked first, not at position 6.

The docstring lists METAL_SIMDGROUP as priority 6 (after MFMA, WMMA, MMA), but the implementation returns early for Metal before any C++ FFI call. A reader following the docstring's priority list would incorrectly assume Metal goes through the same selection chain as CUDA/AMD targets.

📝 Proposed fix: reorder the docstring to match implementation
-        The selection logic follows this priority:
-        1. TCGEN5MMA for Blackwell architecture
-        2. WGMMA for Hopper architecture with sufficient matrix size and warp count
-        3. MFMA for CDNA (AMD) architecture
-        4. WMMA for RDNA (AMD) architecture
-        5. MMA for CUDA architecture
-        6. METAL_SIMDGROUP for Metal target (simdgroup_matrix)
-        7. Scalar for CPU target (scalar fallback)
+        The selection logic follows this priority:
+        1. METAL_SIMDGROUP for Metal target (short-circuit before C++ FFI dispatch)
+        2. TCGEN5MMA for Blackwell architecture
+        3. WGMMA for Hopper architecture with sufficient matrix size and warp count
+        4. MFMA for CDNA (AMD) architecture
+        5. WMMA for RDNA (AMD) architecture
+        6. MMA for CUDA architecture
+        7. Scalar for CPU target (scalar fallback)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/__init__.py` around lines 154 - 175, Update the
_select_gemm_instruction docstring to reflect the actual implementation order:
move METAL_SIMDGROUP (target_is_metal check returning GemmInst.METAL_SIMDGROUP)
to be evaluated before the FFI selection, and clarify that all other targets are
resolved via _ffi_api.GemmGetGemmInst(self, int(thread_nums), target); reference
the function name _select_gemm_instruction, the target_is_metal check,
GemmInst.METAL_SIMDGROUP, and _ffi_api.GemmGetGemmInst so readers see the
docstring matches the code path.
🧹 Nitpick comments (8)
testing/python/jit/test_tilelang_jit_adapter_mps.py (1)

10-47: ⚡ Quick win

Tests look correct — LGTM for the three covered scenarios

All three test functions correctly isolate the CUDA/MPS availability flags and assert the expected device. The SimpleNamespace approach for platforms where torch.backends.mps may not exist is a clean workaround.

One gap worth adding (relates directly to the regression flagged in base.py): there is currently no test for the case where torch.cuda.is_available() is True, _lazy_init() succeeds, but torch._C._cuda_getDevice fails. With the proposed split-try fix in base.py, that path should return a CUDA device (not MPS/CPU), and a test like the following would pin that behavior:

def test_current_device_functor_falls_back_to_cuda_when_c_handle_fails(monkeypatch):
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(torch.cuda, "_lazy_init", lambda: None)  # succeeds
    monkeypatch.setattr(torch._C, "_cuda_getDevice", None, raising=False)  # remove handle

    device_functor = BaseKernelAdapter.get_current_device_functor()

    assert device_functor().type == "cuda"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/jit/test_tilelang_jit_adapter_mps.py` around lines 10 - 47,
Add a test that covers the regression where torch.cuda.is_available() is True
and torch.cuda._lazy_init() succeeds but the CUDA C handle lookup fails; create
a new test (e.g.,
test_current_device_functor_falls_back_to_cuda_when_c_handle_fails) that
monkeypatches torch.cuda.is_available to return True, monkeypatches
torch.cuda._lazy_init to a no-op (succeeds), and removes or sets
torch._C._cuda_getDevice to None/raise (using monkeypatch.setattr with
raising=False) then calls BaseKernelAdapter.get_current_device_functor() and
asserts the returned device has type "cuda" to ensure the split-try behavior in
BaseKernelAdapter.get_current_device_functor is preserved.
tilelang/intrinsics/metal_macro_generator.py (1)

46-52: 💤 Low value

Consider using a more robust assertion for production code.

Line 49 uses a bare assert which can be stripped when Python runs with -O. For kernel generation code, consider raising a ValueError or using a proper check that survives optimization.

♻️ Proposed fix
     def get_thread_binding(self):
         if self.thread_var is None:
             current_frame = T.KernelLaunchFrame.Current()
-            assert current_frame is not None, "Must be called in a T.Kernel Frame"
+            if current_frame is None:
+                raise RuntimeError("MPSIntrinEmitter: Must be called within a T.Kernel Frame")
             return current_frame.get_thread_binding()
         else:
             return self.thread_var
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/metal_macro_generator.py` around lines 46 - 52, In
get_thread_binding, replace the bare assert on current_frame (from
T.KernelLaunchFrame.Current()) with an explicit runtime check that raises a
suitable exception (e.g., ValueError or RuntimeError) so the failure isn't
stripped with -O; locate the method get_thread_binding and the variables
thread_var and current_frame and change the control flow to raise the new
exception with a clear message instead of using assert before returning
current_frame.get_thread_binding().
tilelang/language/fp8_op.py (1)

261-264: ⚡ Quick win

Prefix unused unpacked variable with underscore.

K_dim_b is extracted but never used. Python convention and linters (RUF059) expect unused variables to be prefixed with _ to signal intentional discard.

♻️ Proposed fix
-    M_dim, K_dim = A_fp8.shape
-    K_dim_b, N_dim = B_fp8.shape
+    M_dim, K_dim = A_fp8.shape
+    _K_dim_b, N_dim = B_fp8.shape

Apply the same fix to line 284 in _fp8_scaled_matmul_macro_trans_b:

-    N_dim, K_dim_b = B_fp8.shape
+    N_dim, _K_dim_b = B_fp8.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/fp8_op.py` around lines 261 - 264, The unpacked variable
K_dim_b in the function _fp8_scaled_matmul_macro_trans_a is unused; rename it to
_K_dim_b (prefix with underscore) to satisfy linting (RUF059) and indicate
intentional discard, and make the equivalent change in
_fp8_scaled_matmul_macro_trans_b (the similar unpack on line ~284) so both
places use _K_dim_b instead of K_dim_b.
pyproject.toml (1)

33-34: 💤 Low value

Add a comment explaining the Darwin-specific version constraint.

The <0.1.8 upper bound on Darwin lacks context. Future maintainers won't know why this limit exists or when it's safe to remove. Consider adding a brief inline comment documenting the incompatibility (e.g., a specific bug or breaking change in 0.1.8+).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pyproject.toml` around lines 33 - 34, Add an inline comment next to the
Darwin-specific constraint for "apache-tvm-ffi<0.1.8; platform_system ==
'Darwin'" in pyproject.toml that briefly explains why versions >=0.1.8 are
excluded (referencing the specific bug/PR/commit or the observed breaking
behavior), include a link or identifier for the upstream issue if available, and
note the condition under which the upper bound can be removed (e.g., fixed
upstream version or date).
tilelang/language/__init__.py (1)

117-117: 💤 Low value

Consider grouping with other operation imports.

The fp8_scaled_matmul import is placed in the middle of builtin re-exports (between ldg256 and ballot_sync), breaking the logical grouping. Consider moving it near other operation imports (e.g., after gemm_op imports at line 66 or after fill_op at line 68) for better organization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/__init__.py` at line 117, Move the fp8_scaled_matmul
re-export so it stays with operation imports: locate the current import "from
.fp8_op import fp8_scaled_matmul" (currently between ldg256 and ballot_sync
among builtin re-exports) and cut/paste it into the block of other op re-exports
(e.g., immediately after the gemm_op imports or after fill_op) so all operation
functions (including fp8_scaled_matmul) are grouped together for clearer
organization.
tilelang/engine/phase.py (1)

200-204: The MetalFragmentToSimdgroup pass IS internally guarded — it checks the bound target attribute and returns early for non-Metal targets, so there's no risk of corrupting CUDA/HIP local.fragment buffers.

However, the unconditional invocation at the call site is still worth clarifying. Consider either:

  • Adding an explicit if target.kind.name == "metal" guard at the call site, or
  • Updating the comment to note that the pass has an internal Metal-only guard

This improves readability since readers shouldn't have to inspect the pass implementation to understand that it only affects Metal targets.
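As a sketch, the explicit-guard option could take roughly this shape (pure-python stand-ins: `mod`, `target_kind`, and `metal_pass` are placeholders for the real IRModule, `target.kind.name`, and the MetalFragmentToSimdgroup pass):

```python
# Hypothetical call-site guard; mod / target_kind / metal_pass stand in for
# the real IRModule, target.kind.name, and MetalFragmentToSimdgroup.
def run_metal_only_pass(mod, target_kind, metal_pass):
    # Explicit guard: readers see at the call site that the pass is
    # Metal-only, without inspecting the pass implementation.
    if target_kind == "metal":
        return metal_pass(mod)
    return mod

result_cuda = run_metal_only_pass("mod", "cuda", lambda m: m + "+simdgroup")
result_metal = run_metal_only_pass("mod", "metal", lambda m: m + "+simdgroup")
print(result_cuda, result_metal)  # mod mod+simdgroup
```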

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/engine/phase.py` around lines 200 - 204, Call site currently
unconditionally constructs and applies MetalFragmentToSimdgroup(mod); either add
an explicit target-kind guard around that call (e.g., check target.kind.name ==
"metal" before importing/instantiating MetalFragmentToSimdgroup and assigning
mod = MetalFragmentToSimdgroup(mod)) so the intent is obvious at the call site,
or update the existing comment above the import to state that
MetalFragmentToSimdgroup already checks the bound target and returns early for
non-Metal targets; reference the MetalFragmentToSimdgroup class and the mod
variable in your change so readers can quickly find the pass and its use.
tilelang/transform/metal_fragment_to_simdgroup.py (1)

94-102: 💤 Low value

Dead code: the if False branch is never taken.

The ternary expression tir.BlockRealize(...) if False else new_block always returns new_block. This looks like a debugging remnant that should be cleaned up.

Proposed cleanup
-                return (
-                    tir.BlockRealize(
-                        stmt.iter_vars,
-                        tir.const(True, "bool"),
-                        new_block,
-                    )
-                    if False
-                    else new_block
-                )
+                return new_block
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/metal_fragment_to_simdgroup.py` around lines 94 - 102,
remove the dead ternary that always selects the else branch; replace the
expression "tir.BlockRealize(...) if False else new_block" with just
"new_block" so the function returns new_block directly, deleting the unused
tir.BlockRealize(stmt.iter_vars, tir.const(True, "bool"), new_block) debug
remnant surrounding new_block.
testing/python/metal/test_metal_gemm_v2.py (1)

84-86: 💤 Low value

Consider documenting or tightening the loose tolerance for 1024×1024×1024.

atol=1.0 is a very loose absolute tolerance. For fp16 inputs accumulated in fp32, typical differences should be much smaller. This may mask real numerical issues. Consider:

  1. Adding a comment explaining why this tolerance is needed
  2. Using relative tolerance (rtol) in addition to atol
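For reference, the combined check behaves like the `atol + rtol * |expected|` rule used by numpy.allclose and torch.allclose; a minimal sketch (the values are illustrative, not taken from the test):

```python
def within_tol(actual, expected, rtol=1e-2, atol=1e-2):
    # Combined tolerance: an absolute floor plus a term that scales with
    # the magnitude of the expected value (same shape as numpy.allclose).
    return abs(actual - expected) <= atol + rtol * abs(expected)

# A 0.5 absolute error on a large accumulated value passes via rtol...
print(within_tol(100.5, 100.0))  # True
# ...while the same error near zero is caught, unlike a bare atol=1.0.
print(within_tol(0.5, 0.0))      # False
```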
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` around lines 84 - 86, The test
test_gemm_v2_1024 uses a very loose absolute tolerance (atol=1.0); update the
test to (1) add a brief inline comment in test_gemm_v2_1024 explaining why a
relaxed tolerance is required for 1024×1024×1024 fp16 inputs (e.g., fp16 inputs
with fp32 accumulation and known nondeterminism/quantization artifacts), and (2)
tighten the check by replacing the single atol with a combination of a smaller
atol and an rtol (e.g., rtol=1e-2 and atol=1e-2) when calling assert_gemm_v2 so
the test still allows small fp16 rounding differences but will catch larger
numerical regressions; refer to the test function name test_gemm_v2_1024 and the
assertion helper assert_gemm_v2 when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmark/matmul_metal/benchmark_matmul_metal.py`:
- Around line 58-71: The benchmark is inconsistent: bench_torch_mps allocates
the output implicitly each iteration while bench_tilelang reuses c, biasing the
ratio; change bench_torch_mps to preallocate a torch.zeros output (matching
dtype/device/shape of c) and use torch.mm(a, b, out=preallocated_c) or otherwise
call the kernel with that preallocated tensor so both paths reuse the same
output allocation (refer to bench_torch_mps, bench_tilelang, variables a/b/c and
function matmul_simdgroup/_bench).

In `@src/op/gemm.cc`:
- Around line 204-207: Comments referencing a hardcoded "16" are now stale after
introducing kMPerWarp and the TargetIsMetal path; update the two comments in the
non-WGMMA warp-partition logic to refer to kMPerWarp (or to a neutral phrasing
like "m_warp * kMPerWarp" / "kMPerWarp elements") instead of the literal "16".
Locate the block that defines int kMPerWarp = 16; if (TargetIsMetal(target)) {
kMPerWarp = 8; } and replace the comment lines that read "// If M cannot be
evenly divided by m_warp*16" and "// Each warp needs at least 16 elements in M"
with wording that uses kMPerWarp (or "kMPerWarp" spelled out) or a generic
description ("m_warp * kMPerWarp" and "kMPerWarp elements per warp") so the
comments stay correct for both Metal and non-Metal targets.

In `@src/target/codegen_metal.cc`:
- Around line 55-57: The generated union __TVMArgUnion only declares v_int[2]
but kernel arg accessors emit fields like v_half, v_bool, v_char, etc.; update
the union definition (symbol: __TVMArgUnion) to include matching members for all
sub-32/64-bit POD types used by the accessor code (e.g., v_half, v_bool, v_char,
v_int8, v_int16) or alter the accessor emission logic so it does not emit
v_<type> field accesses for these types; ensure you make the same change in both
places where the union is emitted and where accessors are generated (the two
corresponding emit sites that produce the union and arg.<name>.v_* accesses).

In `@testing/python/metal/test_metal_local_var.py`:
- Line 36: The test's assertion uses the regex "\bint\s+\w+\s*=\s*0;" and
expects at least 2 matches, but only one variable (`y`) is initialized to 0;
update the assertion in testing/python/metal/test_metal_local_var.py to reflect
the actual output by either changing the numeric expectation from ">= 2" to ">=
1" or broaden the regex to "\bint\s+\w+\s*=" so it matches any int
initialization; adjust the assertion line containing the regex to use the chosen
fix.

In `@tilelang/tileop/metal_simdgroup.py`:
- Around line 208-231: The helpers load_tile and store_tile currently only
operate on tile.index(0, 0), silently dropping other fragments for MMATile
instances with fragments_m/fragments_n > 1; update both functions (load_tile and
the similar store_tile block) to either iterate over all fragment coordinates
(loop m in range(tile.fragments_m) and n in range(tile.fragments_n) and call
load/store with tile.index(m, n)) or, if you prefer a fail-fast behavior, assert
or raise an exception when tile.fragments_m != 1 or tile.fragments_n != 1 so
partial materialization cannot occur; locate references to tile.index(0, 0) in
load_tile and the corresponding store_tile implementation and apply the chosen
fix consistently.
- Around line 385-396: mma_tile currently only uses a.index(tile_m, 0) and
b.index(0, tile_n), dropping additional K fragments; add a K-fragment loop (or
reject non-unit K tiling). Concretely, inside mma_tile (the T.macro using
MMATile and mma) iterate tile_k with T.unroll over the tile/K fragment count
(e.g., acc.fragments_k or the appropriate MMATile.fragments_k property), call
mma(acc.fragment, a.fragment, b.fragment, acc.index(tile_m,tile_n),
a.index(tile_m,tile_k), b.index(tile_k,tile_n)) for each tile_k to accumulate
partial products, and add shape checks/assertions that a.fragments_n ==
acc.fragments_k == b.fragments_m (or throw if fragments_k == 1 is required) so
non-unit K tiling is handled safely.

---

Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 569-575: ValidateCandidateAgainstFragments (and
BuildReplicationGuardsIfNeeded) use the unsafe pattern
T.layout_map[buffer].as<Fragment>().value() which will abort if the layout isn't
a Fragment; replace this with a defensive check that the layout exists and is a
Fragment before accessing it. For example, retrieve the optional layout via auto
layout_opt = T.layout_map[buffer]; if (!layout_opt.has_value() ||
!layout_opt->is<Fragment>()) continue (or otherwise skip/handle non-Fragment
cases), then safely extract the Fragment via layout_opt->as<Fragment>().value();
apply the same guard in BuildReplicationGuardsIfNeeded and any other call sites
(e.g., ChooseBestCandidate) that assume Fragment layouts.

In `@tilelang/jit/adapter/base.py`:
- Around line 77-86: The try/except currently swallows attribute-access failures
and incorrectly falls back away from CUDA; change the logic so
torch.cuda._lazy_init() is attempted in its own try and only if it raises should
we fall through to MPS/CPU, while failures when accessing
torch._C._cuda_getDevice (or AttributeError) should use the safe fallback used
previously (i.e., return a lambda that calls torch.device("cuda",
torch.cuda.current_device()) or the retrieved current_device()), and keep other
exceptions bubbling or handled appropriately; also update the function docstring
(the string around line 74) to reflect that the function may return MPS or CPU
when CUDA is unavailable.

In `@tilelang/tileop/gemm/__init__.py`:
- Around line 154-175: Update the _select_gemm_instruction docstring to reflect
the actual implementation order: move METAL_SIMDGROUP (target_is_metal check
returning GemmInst.METAL_SIMDGROUP) to be evaluated before the FFI selection,
and clarify that all other targets are resolved via
_ffi_api.GemmGetGemmInst(self, int(thread_nums), target); reference the function
name _select_gemm_instruction, the target_is_metal check,
GemmInst.METAL_SIMDGROUP, and _ffi_api.GemmGetGemmInst so readers see the
docstring matches the code path.

---

Nitpick comments:
In `@testing/python/jit/test_tilelang_jit_adapter_mps.py`:
- Around line 10-47: Add a test that covers the regression where
torch.cuda.is_available() is True and torch.cuda._lazy_init() succeeds but the
CUDA C handle lookup fails; create a new test (e.g.,
test_current_device_functor_falls_back_to_cuda_when_c_handle_fails) that
monkeypatches torch.cuda.is_available to return True, monkeypatches
torch.cuda._lazy_init to a no-op (succeeds), and removes or sets
torch._C._cuda_getDevice to None/raise (using monkeypatch.setattr with
raising=False) then calls BaseKernelAdapter.get_current_device_functor() and
asserts the returned device has type "cuda" to ensure the split-try behavior in
BaseKernelAdapter.get_current_device_functor is preserved.

In `@tilelang/intrinsics/metal_macro_generator.py`:
- Around line 46-52: In get_thread_binding, replace the bare assert on
current_frame (from T.KernelLaunchFrame.Current()) with an explicit runtime
check that raises a suitable exception (e.g., ValueError or RuntimeError) so the
failure isn't stripped with -O; locate the method get_thread_binding and the
variables thread_var and current_frame and change the control flow to raise the
new exception with a clear message instead of using assert before returning
current_frame.get_thread_binding().
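
The `-O` behavior this comment refers to can be demonstrated standalone (the snippet below is illustrative; it is not the actual get_thread_binding code):

```python
import subprocess
import sys

# Python's -O flag strips assert statements at compile time, so a bare
# assert on current_frame silently disappears; a raise does not.
snippet = "assert False, 'unreachable'\nprint('assert was stripped')"
out = subprocess.run(
    [sys.executable, "-O", "-c", snippet],
    capture_output=True, text=True,
)
print(out.stdout.strip())  # assert was stripped
```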

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6c526c4b-0de1-40b4-9566-88d13550d371

📥 Commits

Reviewing files that changed from the base of the PR and between d135bd1 and 1984db9.

📒 Files selected for processing (41)
  • benchmark/matmul_metal/benchmark_matmul_metal.py
  • pyproject.toml
  • requirements-dev.txt
  • requirements.txt
  • src/backend/metal/CMakeLists.txt
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/fill.cc
  • src/op/gemm.cc
  • src/op/gemm.h
  • src/op/parallel.cc
  • src/op/utils.h
  • src/target/codegen_metal.cc
  • src/target/codegen_metal.h
  • src/transform/layout_inference.cc
  • src/transform/lower_device_storage_access_info.cc
  • testing/python/cpu/test_fp8_scaled_matmul_lowering.py
  • testing/python/jit/test_tilelang_jit_adapter_mps.py
  • testing/python/metal/metal_internal_runtime_coverage.md
  • testing/python/metal/test_fp8_scaled_matmul_metal.py
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • testing/python/metal/test_metal_internal_scaffolding.py
  • testing/python/metal/test_metal_local_var.py
  • testing/python/metal/test_metal_simdgroup_store.py
  • tilelang/engine/lower.py
  • tilelang/engine/phase.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/base.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/language/__init__.py
  • tilelang/language/fp8_op.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/tileop/metal_gdn.py
  • tilelang/tileop/metal_quant.py
  • tilelang/tileop/metal_simdgroup.py
  • tilelang/transform/decouple_type_cast.py
  • tilelang/transform/metal_fragment_to_simdgroup.py
  • tilelang/utils/language.py

Comment on lines +58 to +71
def bench_torch_mps(M, N, K, warmup, repeats):
    a = torch.randn(M, K, dtype=torch.float16, device="mps")
    b = torch.randn(K, N, dtype=torch.float16, device="mps")
    avg_s = _bench(lambda: torch.mm(a, b), warmup, repeats)
    return _tflops(M, N, K, avg_s)


def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats):
    kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K)
    a = torch.randn(M, K, dtype=torch.float16, device="mps")
    b = torch.randn(K, N, dtype=torch.float16, device="mps")
    c = torch.zeros(M, N, dtype=torch.float32, device="mps")
    avg_s = _bench(lambda: kernel(a, b, c), warmup, repeats)
    return _tflops(M, N, K, avg_s)
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Benchmark both paths with the same output-allocation policy.

bench_torch_mps measures torch.mm(a, b) with a fresh output tensor every iteration, but bench_tilelang reuses c. That bakes allocator cost into the reference only, so the printed TileLang/PyTorch ratio is artificially optimistic. Preallocate the reference output too, or include allocation cost on both sides.
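
A pure-python stand-in for the suggested fix (torch.mm does accept an out= tensor; the list-based matmuls here only illustrate the allocate-fresh vs write-into-preallocated distinction):

```python
def mm_fresh(a, b):
    # Allocates a new output on every call, like torch.mm(a, b).
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def mm_into(a, b, out):
    # Writes into a caller-provided buffer, like torch.mm(a, b, out=c).
    for i, row in enumerate(a):
        for j, col in enumerate(zip(*b)):
            out[i][j] = sum(x * y for x, y in zip(row, col))
    return out

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c = [[0.0, 0.0], [0.0, 0.0]]
same = mm_fresh(a, b) == mm_into(a, b, c)
print(same)  # True
```

Benchmarking both paths with the write-into form charges neither side the allocator cost.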

Comment thread src/op/gemm.cc
Comment on lines +204 to +207
int kMPerWarp = 16; // Rows processed by a single warp
if (TargetIsMetal(target)) {
  kMPerWarp = 8;
}
⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Stale "16" in comments at lines 304 and 334 after introducing kMPerWarp = 8 for Metal.

Setting kMPerWarp = 8 for Metal is correct (simdgroup_matrix uses 8×8 tiles), but the existing comments in the non-WGMMA warp-partition logic still hardcode "16", which now misleads Metal readers:

  • Line 304: // If M cannot be evenly divided by m_warp*16
  • Line 334: // Each warp needs at least 16 elements in M

The code uses the kMPerWarp variable (correct), but the comments are now stale.

📝 Proposed comment fixes
-    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
+    // If M cannot be evenly divided by m_warp*kMPerWarp, try to split remaining
     // to N
-    int max_m_warps =
-        M / kMPerWarp; // Each warp needs at least 16 elements in M
+    int max_m_warps =
+        M / kMPerWarp; // Each warp needs at least kMPerWarp elements in M (16 for CUDA, 8 for Metal)

Comment on lines +55 to +57
decl_stream << "union __TVMArgUnion {\n"
            << " int v_int[2];\n"
            << "};\n\n";
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Handle sub-32-bit scalar kernel args without bogus union fields.

For non-32/64-bit POD params, the emitted accessor becomes arg.<name>.v_half / v_bool / v_char / etc., but __TVMArgUnion only declares v_int[2]. Any kernel with a half/bool/int8-style launch arg will therefore generate invalid MSL and fail compilation. Either add matching union members or stop emitting v_<type> field accesses here.

Proposed fix
 union __TVMArgUnion {
   int v_int[2];
+  bool v_bool;
+  char v_char;
+  uchar v_uchar;
+  short v_short;
+  ushort v_ushort;
+  half v_half;
+  bfloat v_bfloat;
 };

Also applies to: 145-149


# local.var should lower to scalar declarations/stores rather than arrays or
# an unsupported storage scope.
assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2, src
⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Assertion may be incorrect: only one variable initializes to 0.

The regex \bint\s+\w+\s*=\s*0; expects >= 2 matches, but based on the kernel:

  • x = T.alloc_var(T.int32, init=3) → emits int x = 3;
  • y = T.alloc_var(T.int32) → emits int y = 0; (default)

Only y would match the pattern. Consider changing to >= 1 or updating the regex to match any initialization (e.g., \bint\s+\w+\s*=).

Proposed fix
-    assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2, src
+    assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 1, src
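
A standalone check of the two regex variants against the codegen output the comment describes (the two-line `src` is a hypothetical reduction of the real kernel source):

```python
import re

# Expected declarations from x = alloc_var(init=3) and y = alloc_var().
src = "int x = 3;\nint y = 0;\n"

strict = re.findall(r"\bint\s+\w+\s*=\s*0;", src)  # only `int y = 0;`
broad = re.findall(r"\bint\s+\w+\s*=", src)        # both declarations
print(len(strict), len(broad))  # 1 2
```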

Comment on lines +208 to +231
def load_tile(
    tile: MMATile,
    dtype,
    data,
    offset,
    extent,
    stride,
    *,
    rows: int = 8,
    cols: int = 8,
    transpose: bool = False,
) -> None:
    load(
        tile.fragment,
        tile.index(0, 0),
        dtype,
        data,
        offset,
        extent,
        stride,
        rows,
        cols,
        transpose,
    )
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

load_tile / store_tile silently drop every fragment except (0, 0).

Both helpers accept an MMATile with fragments_m/fragments_n, but they hardcode tile.index(0, 0). A multi-fragment tile will therefore be only partially loaded/materialized, which is silent data corruption. Either iterate all fragments, or fail fast unless the tile is 1x1.

Proposed fix
 @T.macro
 def load_tile(
     tile: MMATile,
@@
 ) -> None:
+    if tile.fragments_m != 1 or tile.fragments_n != 1:
+        raise ValueError("load_tile currently supports only 1x1 MMATile")
     load(
         tile.fragment,
         tile.index(0, 0),
@@
 @T.macro
 def store_tile(
     tile: MMATile,
@@
 ) -> None:
+    if tile.fragments_m != 1 or tile.fragments_n != 1:
+        raise ValueError("store_tile currently supports only 1x1 MMATile")
     store(
         tile.fragment,
         tile.index(0, 0),

Also applies to: 296-320

Comment on lines +385 to +396
@T.macro
def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
    for tile_m in T.unroll(acc.fragments_m, explicit=True):
        for tile_n in T.unroll(acc.fragments_n, explicit=True):
            mma(
                acc.fragment,
                a.fragment,
                b.fragment,
                acc.index(tile_m, tile_n),
                a.index(tile_m, 0),
                b.index(0, tile_n),
            )
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

mma_tile misses the K-fragment reduction.

The implementation only multiplies a.index(tile_m, 0) by b.index(0, tile_n). If a.fragments_n / b.fragments_m is greater than 1, every partial product after k=0 is dropped and the result is wrong. Add a tile_k loop (plus shape checks), or reject non-unit K tiling here.

Proposed fix
 @T.macro
 def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
+    if a.fragments_n != b.fragments_m:
+        raise ValueError(
+            f"incompatible tile K dimensions: {a.fragments_n} vs {b.fragments_m}"
+        )
     for tile_m in T.unroll(acc.fragments_m, explicit=True):
         for tile_n in T.unroll(acc.fragments_n, explicit=True):
-            mma(
-                acc.fragment,
-                a.fragment,
-                b.fragment,
-                acc.index(tile_m, tile_n),
-                a.index(tile_m, 0),
-                b.index(0, tile_n),
-            )
+            for tile_k in T.unroll(a.fragments_n, explicit=True):
+                mma(
+                    acc.fragment,
+                    a.fragment,
+                    b.fragment,
+                    acc.index(tile_m, tile_n),
+                    a.index(tile_m, tile_k),
+                    b.index(tile_k, tile_n),
+                )
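
The dropped partial products can be shown with a pure-python blocked matmul (plain nested lists stand in for MMATile fragments; fragments_k = 2):

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(c, d):
    return [[x + y for x, y in zip(r, s)] for r, s in zip(c, d)]

# A split into [A0 | A1] along K, B split into [[B0], [B1]].
a0, a1 = [[1, 2], [3, 4]], [[5, 6], [7, 8]]
b0, b1 = [[1, 0], [0, 1]], [[2, 0], [0, 2]]

full = matmul([r0 + r1 for r0, r1 in zip(a0, a1)], b0 + b1)
acc_all = add(matmul(a0, b0), matmul(a1, b1))  # accumulate over every tile_k
acc_k0 = matmul(a0, b0)                        # only tile_k == 0, as in mma_tile
print(acc_all == full, acc_k0 == full)  # True False
```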
