Skip to content

Rebase Metal simdgroup GEMM support and runtime coverage#2130

Open
jorgecurious wants to merge 10 commits into
tile-ai:mainfrom
jorgecurious:metal-gemm-upstream-rebase
Open

Rebase Metal simdgroup GEMM support and runtime coverage#2130
jorgecurious wants to merge 10 commits into
tile-ai:mainfrom
jorgecurious:metal-gemm-upstream-rebase

Conversation

@jorgecurious
Copy link
Copy Markdown

@jorgecurious jorgecurious commented Apr 30, 2026

Summary

  • Rebase the Metal simdgroup GEMM backend support from [Metal] Add Metal GEMM support with simdgroup_matrix MMA #1869 onto current tile-ai/main, including the new backend-local CMake layout.
  • Add scalar local.var lowering for Metal codegen and prefer MPS as the JIT fallback device when CUDA is unavailable.
  • Add internal-only Metal source-boundary and runtime coverage for simdgroup register tiles, packed uint8 quant probes, and GDN/attention-style tiled kernels.

Upstream Compatibility

  • Keeps the public kernel surface target-agnostic through existing T.gemm / alloc_fragment usage.
  • Keeps RegisterTile/RowVector helpers internal under tilelang.tileop; no public T.rt or T.rv aliases are added.
  • Keeps native Metal fp8/fp4 storage fail-closed and uses packed uint8 probes only.
  • Does not add MPSGraph, MPP/cooperative tensor lowering, CUDA-specific scheduling, or model/checkpoint dependencies.

Rebase Notes

  • Integrated Metal codegen registration into src/backend/metal/CMakeLists.txt so it follows the backend-local CMake structure now used by upstream.
  • Metal shader source generation remains compiled on all hosts, while Metal runtime linkage remains gated through the Metal backend settings.

Verification

  • cmake -S . -B build -DUSE_CUDA=OFF -DUSE_ROCM=OFF -DUSE_METAL=ON -DCMAKE_BUILD_TYPE=Release -DPython3_EXECUTABLE=/Applications/Xcode.app/Contents/Developer/usr/bin/python3 -DPython_EXECUTABLE=/Applications/Xcode.app/Contents/Developer/usr/bin/python3
  • cmake --build build -j8
  • python3 -m pytest testing/python/metal/test_metal_gemm_v2.py testing/python/metal/test_metal_gemm_v2_linux.py testing/python/metal/test_metal_simdgroup_store.py testing/python/metal/test_metal_local_var.py testing/python/jit/test_tilelang_jit_adapter_mps.py testing/python/metal/test_metal_internal_scaffolding.py -q
  • Result: 39 passed, 3 skipped, 42 warnings in 30.33s

Summary by CodeRabbit

  • New Features

    • Broader Metal GPU support: simdgroup/register-tile primitives, Metal-target GEMM and attention-style kernels, in-kernel low-precision decode/quant, and API to retrieve generated kernel source.
    • Added a standalone Metal GEMM benchmarking script.
  • Tests

    • Large expansion of Metal-focused tests for codegen, correctness, runtime MPS validation, and optional benchmarks.
  • Documentation

    • Added Metal internal runtime coverage notes.
  • Chores

    • Build/config tweaks including macOS-specific dependency cap and Metal codegen plumbing.

oraluben and others added 6 commits April 30, 2026 01:43
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations
(simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon
(M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with
  simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for sharedxshared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM
  accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 30, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds a Metal/MPS backend: Metal codegen, simdgroup register-tile intrinsics/macros, Metal-specific GEMM/copy/fill lowering, a pass rewriting local.fragmentmetal.simdgroup, JIT/MPS runtime updates, extensive Metal tests, and an MPS matmul benchmark script.

Changes

Cohort / File(s) Summary
Metal Codegen & Build
src/target/codegen_metal.cc, src/target/codegen_metal.h, src/backend/metal/CMakeLists.txt
Add CodeGenTileLangMetal to emit Metal kernel void sources and handle simdgroup ops/storage/sync; ensure codegen source is included in build generation and provide metallib compile path.
GEMM Integration & Insts
src/op/gemm.cc, src/op/gemm.h, tilelang/tileop/gemm/gemm_metal.py, tilelang/tileop/gemm/inst.py, tilelang/tileop/gemm/__init__.py
Introduce Metal GEMM instruction (enum) and GemmMetal lowering; adjust warp partitioning and select Metal simdgroup GEMM path.
Copy / Fill Lowering
src/op/copy.cc, src/op/copy.h, src/op/fill.cc
Add kMetalSIMDGroup copy kind with legality checks and lowering emitting 8×8 simdgroup_store tiles; add SIMD-group-aware fill lowering with 64-element chunking and alignment checks.
Simdgroup Primitives & Macros
tilelang/intrinsics/metal_macro_generator.py, tilelang/tileop/metal_simdgroup.py, tilelang/tileop/metal_gdn.py, tilelang/tileop/metal_quant.py
Add MPSIntrinEmitter, register-tile types and macros (alloc/load/store/mma), GDN/KKT/WU macros, fp8/fp4/e8m0 decode helpers, and tile-selection utilities.
Compiler Passes & Lowering Pipeline
tilelang/transform/metal_fragment_to_simdgroup.py, tilelang/transform/decouple_type_cast.py, tilelang/engine/phase.py, tilelang/engine/lower.py
New MetalFragmentToSimdgroup pass rewrites local.fragmentmetal.simdgroup; mark simdgroup as local; insert pass into LowerAndLegalize; route Metal codegen to target.build.tilelang_metal.
Layout / Storage / Utils
src/transform/layout_inference.cc, src/transform/lower_device_storage_access_info.cc, src/op/parallel.cc, src/op/utils.h, tilelang/utils/language.py
Relax fragment-layout enforcement for Metal, skip storage-info for .fragment, safer optional fragment access, and add IsSIMDGroupBuffer/IsRegisterBuffer/is_metal_simdgroup helpers.
JIT / Runtime / Adapter
tilelang/jit/adapter/base.py, tilelang/jit/adapter/torch/metal.py
Add MPS fallback device selection after CUDA init failure; expose MetalKernelAdapter.get_kernel_source.
Tests & Benchmarks
testing/python/metal/*, testing/python/jit/test_tilelang_jit_adapter_mps.py, benchmark/matmul_metal/benchmark_matmul_metal.py
Add extensive Metal-focused tests (GEMM, simdgroup store, local.var, internal scaffolding, quant/packed tests) and an MPS matmul benchmark comparing TileLang vs torch.mm with config sweep.
Dev Dependencies
pyproject.toml, requirements.txt, requirements-dev.txt
Add Darwin-specific upper bound for apache-tvm-ffi <0.1.8 while preserving baseline constraints.
Docs
testing/python/metal/metal_internal_runtime_coverage.md
Add internal runtime coverage doc enumerating validated Metal backend components and test entry points.

Sequence Diagram

sequenceDiagram
    participant User as User/Test
    participant JIT as TileLang JIT/Adapter
    participant Engine as Compilation Engine
    participant Pass as MetalFragmentToSimdgroup
    participant Lower as Layout & Lowering
    participant Codegen as CodeGenTileLangMetal
    participant MPS as MPS Runtime

    User->>JIT: submit prim_func for Metal
    JIT->>Engine: Lower & Legalize
    Engine->>Pass: apply MetalFragmentToSimdgroup
    Pass-->>Engine: prim_func with `metal.simdgroup` buffers
    Engine->>Lower: run layout inference & lowering (copy/fill/gemm)
    Lower-->>Engine: lowered device IR
    Engine->>Codegen: call `target.build.tilelang_metal`
    Codegen-->>Engine: emit Metal kernel source / module
    JIT->>MPS: run compiled kernel on MPS via adapter
    MPS-->>User: execution result
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

"I hopped through fragments, soft and bright,
packed simdgroup tiles in morning light,
kernels stitched in half‑precision cheer,
MPS hums — the metal spring is near,
rabbit dances, benchmarks appear!"

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 20.60% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Rebase Metal simdgroup GEMM support and runtime coverage' accurately and concisely summarizes the main objective of the PR, which is to rebase Metal simdgroup GEMM backend support and add runtime coverage validation.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
Review rate limit: 7/8 reviews remaining, refill in 7 minutes and 30 seconds.

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tilelang/jit/adapter/base.py (1)

76-86: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Extend the fallback into the CUDA-init failure path.

MPS is only considered when torch.cuda.is_available() is false. If _lazy_init() raises, this code still returns a CUDA device functor and skips the new MPS fallback, so broken CUDA setups can still fail instead of landing on mps or cpu.

Suggested fix
         if torch.cuda.is_available():
             try:
                 torch.cuda._lazy_init()
                 current_device = torch._C._cuda_getDevice
                 return lambda: torch.device("cuda", current_device())
             except Exception:
-                return lambda: torch.device("cuda", torch.cuda.current_device())
+                pass
         if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
             return lambda: torch.device("mps")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/jit/adapter/base.py` around lines 76 - 86, The CUDA init try/except
in the device chooser (around torch.cuda._lazy_init and the current_device
lambda) currently returns a fallback CUDA lambda on exception, skipping the MPS
and CPU fallbacks; change the except branch so that if torch.cuda._lazy_init()
fails it does not return a CUDA device lambda but instead checks for MPS
availability (torch.backends.mps.is_available()) and returns a mps lambda if
available, otherwise returns the cpu lambda—i.e., move the MPS/CPU fallback
logic into the except path (or call a shared helper) so broken CUDA setups fall
back to MPS or CPU instead of a failing CUDA lambda.
src/backend/metal/CMakeLists.txt (1)

14-24: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Non-Apple “codegen-only” mode still includes Metal runtime sources.

At Line 18 you switch USE_METAL off for non-Apple, but Lines 21-24 still append src/target/rt_mod_metal.cc. That contradicts the stated codegen-only behavior and can break portability.

Proposed fix
 if(NOT APPLE)
   # On non-Apple platforms USE_METAL=ON enables only codegen (Metal source
   # generation) without requiring the Metal/Foundation frameworks.
   message(STATUS "Metal backend on non-Apple: enabling codegen-only mode (no Metal runtime)")
-  set(USE_METAL OFF)
+  return()
 endif()

 file(GLOB TILE_LANG_METAL_SRCS
   src/target/rt_mod_metal.cc
 )
 list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS})
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/backend/metal/CMakeLists.txt` around lines 14 - 24, The CMake script
disables USE_METAL for non-Apple platforms but still unconditionally globs and
appends the Metal runtime source rt_mod_metal.cc via
TILE_LANG_METAL_SRCS/TILE_LANG_SRCS; wrap the file(GLOB TILE_LANG_METAL_SRCS ...
) and list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS}) in a conditional that
only runs when USE_METAL is ON (or when APPLE), or alternatively split
codegen-only sources from runtime sources and only append the runtime file
rt_mod_metal.cc when USE_METAL is true so non-Apple codegen-only builds do not
include Metal runtime sources.
🧹 Nitpick comments (1)
testing/python/metal/test_metal_gemm_v2.py (1)

23-35: ⚡ Quick win

Add a runtime case that keeps C_local as a fragment.

Using shared storage here exercises the shared-output lowering, but it skips the local.fragment -> metal.simdgroup rewrite that the public T.gemm/alloc_fragment path depends on. I'd add one hardware-backed case with T.alloc_fragment as well, so this file covers the public Metal GEMM path directly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` around lines 23 - 35, Add a
runtime case that allocates the output as a hardware-backed fragment instead of
shared memory: create a variant where C_local is created with T.alloc_fragment
(matching accum_dtype/shape) rather than T.alloc_shared, keep the same
T.clear(C_local), loop with T.gemm(A_shared, B_shared, C_local) and the final
T.copy(C_local, C[by * block_M, bx * block_N]), and ensure the test invokes both
the existing shared-storage path and this new alloc_fragment path so the
local.fragment -> metal.simdgroup rewrite is exercised; refer to symbols
A_shared, B_shared, C_local, T.alloc_shared, T.alloc_fragment, and T.gemm to
locate and implement the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmark/matmul_metal/benchmark_matmul_metal.py`:
- Around line 74-89: The script currently continues even when MPS is
unavailable; in the __main__ block (where parser/args, M/N/K are set and prints
occur) add an explicit availability check using
torch.backends.mps.is_available() and abort early (e.g., print an error and call
sys.exit(1) or raise RuntimeError) before any device allocations or use of
"mps"; ensure the check runs immediately after computing/printing MPS (or right
after args parsing) and include a clear message like "MPS backend not available,
aborting" referencing the symbols args, M, N, K and the
torch.backends.mps.is_available() call so downstream code never attempts to
allocate on "mps" when unavailable.

In `@src/op/copy.cc`:
- Line 1015: The code computes dst_stride from dst->shape[...] (dst_stride
variable) before calling simdgroup_store, which is wrong for strided
buffers/views; change those computations to read the actual destination stride
(e.g., the last entry of dst->strides or the buffer/view API that exposes
per-dimension strides) wherever dst_stride is used for simdgroup_store (the
dst_stride assignment at the dst_stride declaration and the similar occurrences
around the later simdgroup_store calls referenced in the comment, ~lines
1069-1072). Ensure you use the destination buffer/view stride accessor (falling
back only if no strides exist) so simdgroup_store writes use the correct memory
stride.
- Around line 802-807: CheckSIMDGroupCopy is too permissive: it only tests
target/scope but LowerSIMDGroupCopy assumes constant 2D extents and directly
dereferences IntImmNode and emits unconditional stores (risking OOB on edge
tiles). Tighten the legality check in CheckSIMDGroupCopy to verify the copy is a
2D constant-extent copy (both extents are IntImmNode constants and match the
expected SIMD group tile size/or divisibility) and that src/dst buffer kinds are
compatible (IsSIMDGroupBuffer with IsSharedBuffer/IsGlobalBuffer), and/or modify
LowerSIMDGroupCopy to detect non-constant or non-2D cases and fall back to the
generic copy path; additionally change LowerSIMDGroupCopy to emit
masked/conditional stores for edge tiles (or bounds checks) instead of
unconditional stores to prevent OOB writes when IntImm assumptions fail.

In `@src/op/fill.cc`:
- Around line 159-181: The simdgroup fill path (IsSIMDGroupBuffer) ignores
region offsets: the loop uses i as the matrix index for
make_filled_simdgroup_matrix without accounting for region[*].min, so fills
always start at matrix 0 and can overwrite when region is a slice; fix by
computing a start_matrix index from the region mins (e.g., sum/compute offset in
units of 64 elements from region[*].min) and use start_matrix + i when building
the Call, or if you want a minimal safe change, add an ICHECK that all
region[*].min are zero and return an error if any min != 0 to explicitly reject
non-zero mins before constructing stmts.

In `@testing/python/jit/test_tilelang_jit_adapter_mps.py`:
- Around line 10-33: Add a regression test that simulates CUDA appearing
available but failing during initialization so the except branch in
BaseKernelAdapter.get_current_device_functor() is executed: create a test (e.g.,
test_current_device_functor_handles_cuda_init_failure_prefers_mps) that
monkeypatches torch.cuda.is_available to return True, monkeypatches
torch.cuda._lazy_init (or the init function BaseKernelAdapter calls) to raise
RuntimeError, ensures torch.backends.mps.is_available returns True, then calls
BaseKernelAdapter.get_current_device_functor() and asserts the returned
device_functor() equals torch.device("mps"); use the same monkeypatch pattern as
the existing tests so cleanup is automatic.

In `@testing/python/metal/test_metal_internal_scaffolding.py`:
- Around line 488-497: The probe currently calls bad_kernel(32) which relies on
default JIT device selection and can miss Metal; instead, after obtaining the
returned PrimFunc from bad_kernel (the inner T.prim_func named main), explicitly
lower or compile that PrimFunc with target="metal" so the Metal lowering path is
exercised (e.g., invoke the lowering/compile API on the returned function with
target="metal"); update the probe where bad_kernel(32) is called (and the
similar call at 517-521) to perform this explicit lowering/compilation step.

In `@tilelang/jit/adapter/torch/metal.py`:
- Around line 56-57: The get_kernel_source method currently returns
self.kernel_global_source which may be None despite the method being annotated
to return str; update get_kernel_source to fail fast by checking if
self.kernel_global_source is None and either raise a clear ValueError (e.g.
"kernel_global_source is not initialized" including the kernel/class context) or
return an explicit empty string per your chosen contract, and ensure callers of
get_kernel_source (if any) expect the new behavior; reference get_kernel_source
and kernel_global_source when making the change.

In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 24-27: Before emitting the simdgroup loops, validate that warp
tiling and K-micro tiling are exact multiples of 8: compute thread_nums =
thread_bounds.extent and obtain m_warp, n_warp from
self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
GemmInst.METAL_SIMDGROUP), then derive warp_row_tiles and warp_col_tiles and
check that (self.M % (m_warp*8) == 0), (self.N % (n_warp*8) == 0) and (block_K %
micro_size_k == 0 and (block_K // micro_size_k) % 8 == 0); if any check fails,
reject (raise/abort) before emitting simdgroup loops (the checks should be added
near where warp_row_tiles, warp_col_tiles, and block_K // micro_size_k are
computed to prevent silent tail dropping).

In `@tilelang/tileop/metal_simdgroup.py`:
- Around line 43-44: RegisterTile.index currently flattens tile_m/tile_n without
validation which can produce out-of-range fragment indices; add bounds checks in
index(self, tile_m: int, tile_n: int = 0) to verify 0 <= tile_m <
self.fragments_m and 0 <= tile_n < self.fragments_n and raise a clear exception
(e.g., ValueError or IndexError) if violated before computing tile_m *
self.fragments_n + tile_n so downstream simdgroup accesses cannot receive
invalid indices.
- Around line 468-470: In prefix_block_vector the store to writeback is
currently nested under "if writeback_guard" which causes writeback to be skipped
when writeback_guard is False; change the logic so that when writeback is not
None the store writeback[token, head] = value always executes (i.e., remove the
outer dependency on writeback_guard), and if masking/guarding per-lane is
required preserve writeback_guard only for computing/conditioning the stored
value rather than preventing the store entirely; update the code around
prefix_block_vector, writeback, writeback_guard and the writeback[token, head] =
value line accordingly.

---

Outside diff comments:
In `@src/backend/metal/CMakeLists.txt`:
- Around line 14-24: The CMake script disables USE_METAL for non-Apple platforms
but still unconditionally globs and appends the Metal runtime source
rt_mod_metal.cc via TILE_LANG_METAL_SRCS/TILE_LANG_SRCS; wrap the file(GLOB
TILE_LANG_METAL_SRCS ... ) and list(APPEND TILE_LANG_SRCS
${TILE_LANG_METAL_SRCS}) in a conditional that only runs when USE_METAL is ON
(or when APPLE), or alternatively split codegen-only sources from runtime
sources and only append the runtime file rt_mod_metal.cc when USE_METAL is true
so non-Apple codegen-only builds do not include Metal runtime sources.

In `@tilelang/jit/adapter/base.py`:
- Around line 76-86: The CUDA init try/except in the device chooser (around
torch.cuda._lazy_init and the current_device lambda) currently returns a
fallback CUDA lambda on exception, skipping the MPS and CPU fallbacks; change
the except branch so that if torch.cuda._lazy_init() fails it does not return a
CUDA device lambda but instead checks for MPS availability
(torch.backends.mps.is_available()) and returns a mps lambda if available,
otherwise returns the cpu lambda—i.e., move the MPS/CPU fallback logic into the
except path (or call a shared helper) so broken CUDA setups fall back to MPS or
CPU instead of a failing CUDA lambda.

---

Nitpick comments:
In `@testing/python/metal/test_metal_gemm_v2.py`:
- Around line 23-35: Add a runtime case that allocates the output as a
hardware-backed fragment instead of shared memory: create a variant where
C_local is created with T.alloc_fragment (matching accum_dtype/shape) rather
than T.alloc_shared, keep the same T.clear(C_local), loop with T.gemm(A_shared,
B_shared, C_local) and the final T.copy(C_local, C[by * block_M, bx * block_N]),
and ensure the test invokes both the existing shared-storage path and this new
alloc_fragment path so the local.fragment -> metal.simdgroup rewrite is
exercised; refer to symbols A_shared, B_shared, C_local, T.alloc_shared,
T.alloc_fragment, and T.gemm to locate and implement the change.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2326baa8-17c2-4baa-bb12-933ff83d2473

📥 Commits

Reviewing files that changed from the base of the PR and between 936ae92 and 3ee8b7f.

📒 Files selected for processing (37)
  • benchmark/matmul_metal/benchmark_matmul_metal.py
  • pyproject.toml
  • requirements-dev.txt
  • requirements.txt
  • src/backend/metal/CMakeLists.txt
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/fill.cc
  • src/op/gemm.cc
  • src/op/gemm.h
  • src/op/parallel.cc
  • src/op/utils.h
  • src/target/codegen_metal.cc
  • src/target/codegen_metal.h
  • src/transform/layout_inference.cc
  • src/transform/lower_device_storage_access_info.cc
  • testing/python/jit/test_tilelang_jit_adapter_mps.py
  • testing/python/metal/metal_internal_runtime_coverage.md
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • testing/python/metal/test_metal_internal_scaffolding.py
  • testing/python/metal/test_metal_local_var.py
  • testing/python/metal/test_metal_simdgroup_store.py
  • tilelang/engine/lower.py
  • tilelang/engine/phase.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/base.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/tileop/metal_gdn.py
  • tilelang/tileop/metal_quant.py
  • tilelang/tileop/metal_simdgroup.py
  • tilelang/transform/decouple_type_cast.py
  • tilelang/transform/metal_fragment_to_simdgroup.py
  • tilelang/utils/language.py

Comment thread benchmark/matmul_metal/benchmark_matmul_metal.py
Comment thread src/op/copy.cc
Comment thread src/op/copy.cc Outdated
Comment thread src/op/fill.cc
Comment thread testing/python/jit/test_tilelang_jit_adapter_mps.py
Comment thread testing/python/metal/test_metal_internal_scaffolding.py Outdated
Comment thread tilelang/jit/adapter/torch/metal.py Outdated
Comment thread tilelang/tileop/gemm/gemm_metal.py
Comment thread tilelang/tileop/metal_simdgroup.py
Comment thread tilelang/tileop/metal_simdgroup.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (3)
testing/python/metal/test_metal_internal_scaffolding.py (1)

484-493: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Compile the negative probe for Metal explicitly.

bad_kernel(32) still goes through default JIT device selection, so a host with CUDA available can miss the Metal path entirely and turn this backend-specific assertion into a flaky no-op. Return the PrimFunc directly here and lower/compile it with target="metal" inside the subprocess.

Proposed fix
-            `@tilelang.jit`(out_idx=[-1])
             def bad_kernel(M):
                 `@T.prim_func`
                 def main(A: T.Tensor((M,), T.float32), B: T.Tensor((M,), T.{dtype_name})):
                     with T.Kernel(T.ceildiv(M, 32), threads=32) as bx:
                         for i in T.Parallel(32):
                             B[bx * 32 + i] = A[bx * 32 + i]
                 return main
 
-            bad_kernel(32)
+            tilelang.lower(bad_kernel(32), target="metal")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_internal_scaffolding.py` around lines 484 -
493, The test currently calls bad_kernel(32) which triggers default JIT device
selection and can skip Metal; instead have bad_kernel return the PrimFunc (the
inner main defined under `@T.prim_func`) and in the subprocess explicitly
lower/compile that PrimFunc with target="metal" so the Metal-specific compile
path and assertion are exercised; update references to bad_kernel and the
returned PrimFunc in the subprocess to call lowering/compilation APIs with
target="metal" rather than relying on implicit JIT device selection.
tilelang/tileop/metal_simdgroup.py (2)

43-44: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate fragment coordinates before flattening them.

Every load/store/MMA helper eventually feeds this index into simdgroup operations. Without bounds checks here, an invalid (tile_m, tile_n) turns into an invalid fragment index instead of a clear exception.

Proposed fix
     def index(self, tile_m: int, tile_n: int = 0) -> int:
+        if not (0 <= tile_m < self.fragments_m and 0 <= tile_n < self.fragments_n):
+            raise IndexError(
+                f"RegisterTile index out of range: ({tile_m}, {tile_n}) "
+                f"for shape ({self.fragments_m}, {self.fragments_n})"
+            )
         return tile_m * self.fragments_n + tile_n
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/metal_simdgroup.py` around lines 43 - 44, The index(tile_m:
int, tile_n: int = 0) method must validate that tile_m and tile_n are within
bounds before computing the flat index; add checks that 0 <= tile_m <
self.fragments_m and 0 <= tile_n < self.fragments_n and raise an IndexError
(with a clear message including the invalid coordinates and allowed ranges) if
either check fails, then return tile_m * self.fragments_n + tile_n as before;
update the index method in metal_simdgroup.py accordingly.

468-469: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

writeback_guard=False should not disable writeback entirely.

The store is nested under writeback_guard, so turning the guard off suppresses every writeback[token, head] = value assignment. That makes the flag change semantics from “unguarded writeback” to “no writeback”.

Proposed fix
-            if writeback is not None and writeback_guard:
+            if writeback is not None:
                 writeback[token, head] = value
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/metal_simdgroup.py` around lines 468 - 469, The current
condition "if writeback is not None and writeback_guard:" suppresses all
writebacks when writeback_guard is False; instead, make writeback occur whenever
writeback is provided and only apply a guard when explicitly required—replace
the combined condition with a simple "if writeback is not None:" so the
assignment "writeback[token, head] = value" always happens for unguarded mode,
and if guarded behavior is needed implement a separate explicit guard check
(using writeback_guard) around the assignment rather than embedding it into the
existence check.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/target/codegen_metal.cc`:
- Around line 392-418: Both the load and store paths for "local.var" (in
CodeGenTileLangMetal::VisitExpr_(const BufferLoadNode*) and
CodeGenTileLangMetal::VisitStmt_(const BufferStoreNode*)) ignore the index value
and thus silently accept non-zero indices; add a defensive check that rejects
any non-zero index. After the existing ICHECK_EQ(op->indices.size(), 1) and
scalar checks, validate that op->indices[0] is an IntImm (or constant integer)
and its value == 0, and fail fast (ICHECK / LOG(FATAL) with a clear message
referencing local.var index) if it is not zero so we do not silently read/write
the scalar for non-zero accesses.

---

Duplicate comments:
In `@testing/python/metal/test_metal_internal_scaffolding.py`:
- Around line 484-493: The test currently calls bad_kernel(32) which triggers
default JIT device selection and can skip Metal; instead have bad_kernel return
the PrimFunc (the inner main defined under `@T.prim_func`) and in the subprocess
explicitly lower/compile that PrimFunc with target="metal" so the Metal-specific
compile path and assertion are exercised; update references to bad_kernel and
the returned PrimFunc in the subprocess to call lowering/compilation APIs with
target="metal" rather than relying on implicit JIT device selection.

In `@tilelang/tileop/metal_simdgroup.py`:
- Around line 43-44: The index(tile_m: int, tile_n: int = 0) method must
validate that tile_m and tile_n are within bounds before computing the flat
index; add checks that 0 <= tile_m < self.fragments_m and 0 <= tile_n <
self.fragments_n and raise an IndexError (with a clear message including the
invalid coordinates and allowed ranges) if either check fails, then return
tile_m * self.fragments_n + tile_n as before; update the index method in
metal_simdgroup.py accordingly.
- Around line 468-469: The current condition "if writeback is not None and
writeback_guard:" suppresses all writebacks when writeback_guard is False;
instead, make writeback occur whenever writeback is provided and only apply a
guard when explicitly required—replace the combined condition with a simple "if
writeback is not None:" so the assignment "writeback[token, head] = value"
always happens for unguarded mode, and if guarded behavior is needed implement a
separate explicit guard check (using writeback_guard) around the assignment
rather than embedding it into the existence check.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7f436436-7c81-41f0-af47-b4905cde4908

📥 Commits

Reviewing files that changed from the base of the PR and between 3ee8b7f and 79158bd.

📒 Files selected for processing (5)
  • src/target/codegen_metal.cc
  • src/target/codegen_metal.h
  • testing/python/metal/test_metal_internal_scaffolding.py
  • tilelang/tileop/metal_quant.py
  • tilelang/tileop/metal_simdgroup.py
✅ Files skipped from review due to trivial changes (1)
  • tilelang/tileop/metal_quant.py

Comment thread src/target/codegen_metal.cc
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🧹 Nitpick comments (1)
tilelang/tileop/gemm/gemm_metal.py (1)

69-73: Consider replacing assert-based runtime guards with explicit exceptions across GEMM lowerers.

Lines 69–73 use assert for input validation. While this pattern appears in other GEMM lowerers (e.g., gemm_wmma.py), assert statements can be disabled with Python's -O optimization flag, silently removing these safeguards. For production lowering paths, explicit exceptions (e.g., ValueError) should be used to ensure validation always runs regardless of Python runtime mode. If addressed, apply consistently across all GEMM lowering modules that use this pattern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/gemm_metal.py` around lines 69 - 73, Replace the
assert-based checks with explicit exceptions so validation always runs; e.g.,
instead of "assert block_K >= micro_size_k", raise ValueError with a clear
message referencing block_K and micro_size_k; replace "assert
is_full_region(C_region)" with a ValueError mentioning C_region and
is_full_region; and replace the c_in_simdgroup_reg / is_shared(C_buf) assert
with a ValueError that includes C_buf.scope() in the message. Apply this change
where these symbols appear (block_K, micro_size_k, is_full_region, C_region,
c_in_simdgroup_reg, is_shared, C_buf.scope()) so the Metal GEMM lowerer always
raises explicit exceptions rather than relying on asserts.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmark/matmul_metal/benchmark_matmul_metal.py`:
- Around line 82-85: After calling parser.parse_args() and assigning M,N,K from
args (the block using parser.parse_args() and M, N, K = args.m, args.n, args.k),
add an early sanity check that ensures M, N, K and args.repeats are positive
integers (>0) and that any other numeric flags (e.g., args.batch, args.threads
if present) are non-negative; if any check fails, print a clear error to stderr
or via the existing logger and exit with a non-zero code. Use the parsed args
values for validation, raise or sys.exit(1) on invalid input, and include which
argument was invalid in the message so downstream throughput math and runtime
never run with bad values.

In `@src/op/copy.cc`:
- Around line 1057-1061: The ICHECKs in src/op/copy.cc that assert
dst_strides.size() == dst->shape.size() and
analyzer->CanProveEqual(dst_strides[1], 1) should not hard-fail for
non-contiguous destination columns; instead detect the failing condition and
fall back to the safe path (call LowerNormalCopy) when dst_strides.size()
mismatches or analyzer->CanProveEqual(dst_strides[1], 1) is false. Replace the
hard ICHECK on dst_strides[1] with a conditional that routes to LowerNormalCopy
for strided destinations, keeping the original check for complete stride vector
only if it's truly required, and ensure you reference dst_strides, dst->shape,
analyzer->CanProveEqual, and LowerNormalCopy when implementing the conditional
fallback.
- Around line 802-837: CheckSIMDGroupCopy currently allows non-zero
src_range[*].min which permits sliced metal.simdgroup sources but
LowerSIMDGroupCopy emits simdgroup_store(...) assuming tile indices start at 0;
update CopyNode::CheckSIMDGroupCopy to reject any simdgroup source with non-zero
src_range min values (both dimensions) by verifying src_range[i]->min is a
constant zero (or equivalent IntImmNode with value 0) for i in {0,1}; this
prevents sliced sources from being accepted until LowerSIMDGroupCopy is changed
to apply a source tile base.
- Around line 1064-1095: The code assumes constant thread extents and perfect
warp tiling by directly casting T.thread_bounds->extent to IntImmNode and later
ICHECK-ing that M >= m_warp*8 && N >= n_warp*8; change this to detect
non-constant or invalid extents and fall back to a safe normal-copy path instead
of crashing: check TargetGetWarpSize(T.target) and whether
T.thread_bounds->extent is an IntImmNode (and that block_size % warp_size == 0)
before computing num_warps and warp_id; if any check fails (non-IntImm extent,
partial warps, or computed m_warp/n_warp that exceed max_m/max_n), avoid the
warp-mapping loop and return/use the existing normal-copy fallback, and replace
the final ICHECK(M >= m_warp*8 && N >= n_warp*8) with a conditional that selects
the fallback when the condition is false so lowering does not abort.

---

Nitpick comments:
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 69-73: Replace the assert-based checks with explicit exceptions so
validation always runs; e.g., instead of "assert block_K >= micro_size_k", raise
ValueError with a clear message referencing block_K and micro_size_k; replace
"assert is_full_region(C_region)" with a ValueError mentioning C_region and
is_full_region; and replace the c_in_simdgroup_reg / is_shared(C_buf) assert
with a ValueError that includes C_buf.scope() in the message. Apply this change
where these symbols appear (block_K, micro_size_k, is_full_region, C_region,
c_in_simdgroup_reg, is_shared, C_buf.scope()) so the Metal GEMM lowerer always
raises explicit exceptions rather than relying on asserts.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5da34eab-ba80-48ba-bd78-9561d8a53452

📥 Commits

Reviewing files that changed from the base of the PR and between 79158bd and d4fb922.

📒 Files selected for processing (10)
  • benchmark/matmul_metal/benchmark_matmul_metal.py
  • src/op/copy.cc
  • src/op/fill.cc
  • testing/python/jit/test_tilelang_jit_adapter_mps.py
  • testing/python/metal/test_metal_internal_scaffolding.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/base.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/metal_simdgroup.py
✅ Files skipped from review due to trivial changes (1)
  • tilelang/tileop/metal_simdgroup.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/jit/adapter/base.py
  • testing/python/jit/test_tilelang_jit_adapter_mps.py

Comment thread benchmark/matmul_metal/benchmark_matmul_metal.py
Comment thread src/op/copy.cc
Comment thread src/op/copy.cc Outdated
Comment thread src/op/copy.cc Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (1)
tilelang/tileop/gemm/gemm_metal.py (1)

69-70: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add an explicit divisibility guard for block_K and micro_size_k.

block_K // micro_size_k is floor division in both K loops. Without a % check, non-divisible configurations can silently skip tail K work.

Suggested patch
         if block_K < micro_size_k:
             raise ValueError(f"Metal GEMM requires block_K ({block_K}) to be >= micro_size_k ({micro_size_k})")
+        if block_K % micro_size_k != 0:
+            raise ValueError(
+                f"Metal GEMM requires block_K ({block_K}) to be divisible by micro_size_k ({micro_size_k})"
+            )

Also applies to: 86-87, 104-105

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/gemm_metal.py` around lines 69 - 70, Add an explicit
divisibility guard so block_K is a whole multiple of micro_size_k: in
tilelang/tileop/gemm/gemm_metal.py, update the validation around the existing
check that raises when block_K < micro_size_k (and the analogous checks at the
other occurrences) to also raise a ValueError when block_K % micro_size_k != 0;
reference the variables block_K and micro_size_k in the message so misconfigured
plans fail fast instead of silently dropping tail K work (apply the same change
to the other two validation sites mentioned).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/op/copy.cc`:
- Around line 1066-1076: The code computes warp_id from the absolute
T.thread_var which breaks when thread_bounds->min != 0; instead normalize the
thread index against T.thread_bounds->min before dividing by warp_size: compute
a local_thread = T.thread_var - T.thread_bounds->min (matching how other
specialized lowerings normalize), then set warp_id = FloorDiv(local_thread,
warp_size); keep the same guards around block_extent/warp_size and fall back to
LowerNormalCopy(T, analyzer) as before.
- Around line 1096-1120: Loop chooses m_warp/n_warp partitions but
warp_row_tiles and warp_col_tiles use truncating division and can drop partial
tiles; update the guard to require exact tiling by checking divisibility before
proceeding: ensure M % (m_warp * 8) == 0 and N % (n_warp * 8) == 0 (and keep the
existing checks for <=0 and total_elements), and only compute warp_row_tiles = M
/ (m_warp * 8) and warp_col_tiles = N / (n_warp * 8) after those divisibility
checks; if divisibility fails, return LowerNormalCopy(T, analyzer). This ensures
the warp partition exactly tiles the source and prevents silently dropping
tiles.
- Around line 829-837: The loop that validates src_range currently only rejects
non-zero mins and mismatched extents, but you must also reject partial source
tiles: inside the for-loop over i check that the source extent exactly equals
the full source buffer dimension used by LowerSIMDGroupCopy (e.g., compare
src_extent->value against the source buffer's shape/size for that dimension
instead of only comparing to dst_extent), and return false if the src tile is a
subrange (partial) so tile_idx remapping uses the full source stride; update the
condition that currently references src_range, src_extent and dst_extent to
include this equality against the source's full shape.

---

Duplicate comments:
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 69-70: Add an explicit divisibility guard so block_K is a whole
multiple of micro_size_k: in tilelang/tileop/gemm/gemm_metal.py, update the
validation around the existing check that raises when block_K < micro_size_k
(and the analogous checks at the other occurrences) to also raise a ValueError
when block_K % micro_size_k != 0; reference the variables block_K and
micro_size_k in the message so misconfigured plans fail fast instead of silently
dropping tail K work (apply the same change to the other two validation sites
mentioned).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5e67df19-49aa-47c3-a42a-c12fef080a02

📥 Commits

Reviewing files that changed from the base of the PR and between d4fb922 and 7f948ec.

📒 Files selected for processing (5)
  • benchmark/matmul_metal/benchmark_matmul_metal.py
  • src/backend/metal/CMakeLists.txt
  • src/op/copy.cc
  • src/target/codegen_metal.cc
  • tilelang/tileop/gemm/gemm_metal.py

Comment thread src/op/copy.cc
Comment thread src/op/copy.cc Outdated
Comment thread src/op/copy.cc
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
src/op/copy.cc (1)

812-815: 💤 Low value

Consider adding explicit check for src->shape.size() == 2.

The code validates dst->shape.size() != 2 but not src->shape.size(). The subsequent loop at lines 829-839 accesses src->shape[i] for i ∈ {0, 1}, which would be undefined behavior if src->shape has fewer than 2 dimensions.

While IsSIMDGroupBuffer(src) likely guarantees 2D shape by construction, adding an explicit check makes the invariant clear and provides defense-in-depth.

💡 Suggested defensive check
   if (src_range.size() != 2 || dst_range.size() != 2 ||
-      dst->shape.size() != 2) {
+      src->shape.size() != 2 || dst->shape.size() != 2) {
     return false;
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 812 - 815, Add an explicit guard that verifies
src->shape.size() == 2 before using src->shape[0/1]: the current block already
checks src_range.size() and dst_range.size() and dst->shape.size(), but does not
assert src->shape has two dims; update the same conditional (the one that
returns false) to also check src->shape.size() == 2 (or add a separate early
return) so that subsequent code in this function that indexes src->shape[0] and
src->shape[1] is safe even if IsSIMDGroupBuffer(src) invariants change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@src/op/copy.cc`:
- Around line 812-815: Add an explicit guard that verifies src->shape.size() ==
2 before using src->shape[0/1]: the current block already checks
src_range.size() and dst_range.size() and dst->shape.size(), but does not assert
src->shape has two dims; update the same conditional (the one that returns
false) to also check src->shape.size() == 2 (or add a separate early return) so
that subsequent code in this function that indexes src->shape[0] and
src->shape[1] is safe even if IsSIMDGroupBuffer(src) invariants change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9afc7693-7180-495c-a751-c68f9a613d47

📥 Commits

Reviewing files that changed from the base of the PR and between 7f948ec and 971c17b.

📒 Files selected for processing (2)
  • src/op/copy.cc
  • tilelang/tileop/gemm/gemm_metal.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants