Rebase Metal simdgroup GEMM support and runtime coverage #2130
jorgecurious wants to merge 10 commits
Conversation
Add `T.gemm` support for Apple Metal using `simdgroup_matrix` 8x8 operations (`simdgroup_load`/`store`/`multiply_accumulate`). Works on all Apple Silicon (M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: fork TVM Metal codegen to tilelang with simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared×shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: pass that rewrites `local.fragment` GEMM accumulators to `metal.simdgroup` scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device `simdgroup_store`

24 Metal tests (codegen cross-platform + correctness on device).
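For orientation, here is a minimal sketch of the user-facing path this PR enables. It follows the shared-memory plus fragment-accumulator structure the tests below exercise; the problem size, tile sizes, dtypes, and thread count are illustrative stand-ins, not values taken from the actual test files:

```python
import tilelang
import tilelang.language as T

M = N = K = 256          # illustrative problem size
block_M = block_N = 64   # illustrative tile sizes
block_K = 32

@T.prim_func
def matmul(A: T.Tensor((M, K), "float16"),
           B: T.Tensor((K, N), "float16"),
           C: T.Tensor((M, N), "float32")):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), "float16")
        B_shared = T.alloc_shared((block_K, block_N), "float16")
        # local.fragment accumulator: the new MetalFragmentToSimdgroup pass
        # rewrites this to metal.simdgroup scope before layout inference.
        C_local = T.alloc_fragment((block_M, block_N), "float32")
        T.clear(C_local)
        for k in T.serial(T.ceildiv(K, block_K)):
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, bx * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[by * block_M, bx * block_N])
```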
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Adds a Metal/MPS backend: Metal codegen, simdgroup register-tile intrinsics/macros, Metal-specific GEMM/copy/fill lowering, and a pass rewriting `local.fragment` GEMM accumulators to `metal.simdgroup` scope.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User/Test
    participant JIT as TileLang JIT/Adapter
    participant Engine as Compilation Engine
    participant Pass as MetalFragmentToSimdgroup
    participant Lower as Layout & Lowering
    participant Codegen as CodeGenTileLangMetal
    participant MPS as MPS Runtime
    User->>JIT: submit prim_func for Metal
    JIT->>Engine: Lower & Legalize
    Engine->>Pass: apply MetalFragmentToSimdgroup
    Pass-->>Engine: prim_func with metal.simdgroup buffers
    Engine->>Lower: run layout inference & lowering (copy/fill/gemm)
    Lower-->>Engine: lowered device IR
    Engine->>Codegen: call target.build.tilelang_metal
    Codegen-->>Engine: emit Metal kernel source / module
    JIT->>MPS: run compiled kernel on MPS via adapter
    MPS-->>User: execution result
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/jit/adapter/base.py (1)
76-86: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Extend the fallback into the CUDA-init failure path.

MPS is only considered when `torch.cuda.is_available()` is false. If `_lazy_init()` raises, this code still returns a CUDA device functor and skips the new MPS fallback, so broken CUDA setups can still fail instead of landing on `mps` or `cpu`.

Suggested fix

```diff
 if torch.cuda.is_available():
     try:
         torch.cuda._lazy_init()
         current_device = torch._C._cuda_getDevice
         return lambda: torch.device("cuda", current_device())
     except Exception:
-        return lambda: torch.device("cuda", torch.cuda.current_device())
+        pass
 if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
     return lambda: torch.device("mps")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/jit/adapter/base.py` around lines 76-86: The CUDA init try/except in the device chooser (around torch.cuda._lazy_init and the current_device lambda) currently returns a fallback CUDA lambda on exception, skipping the MPS and CPU fallbacks; change the except branch so that if torch.cuda._lazy_init() fails it does not return a CUDA device lambda but instead checks for MPS availability (torch.backends.mps.is_available()) and returns a mps lambda if available, otherwise returns the cpu lambda; i.e., move the MPS/CPU fallback logic into the except path (or call a shared helper) so broken CUDA setups fall back to MPS or CPU instead of a failing CUDA lambda.

src/backend/metal/CMakeLists.txt (1)
14-24: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Non-Apple “codegen-only” mode still includes Metal runtime sources.

At Line 18 you switch `USE_METAL` off for non-Apple, but Lines 21-24 still append `src/target/rt_mod_metal.cc`. That contradicts the stated codegen-only behavior and can break portability.

Proposed fix

```diff
 if(NOT APPLE)
   # On non-Apple platforms USE_METAL=ON enables only codegen (Metal source
   # generation) without requiring the Metal/Foundation frameworks.
   message(STATUS "Metal backend on non-Apple: enabling codegen-only mode (no Metal runtime)")
-  set(USE_METAL OFF)
+  return()
 endif()

 file(GLOB TILE_LANG_METAL_SRCS
   src/target/rt_mod_metal.cc
 )
 list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS})
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/backend/metal/CMakeLists.txt` around lines 14 - 24, The CMake script disables USE_METAL for non-Apple platforms but still unconditionally globs and appends the Metal runtime source rt_mod_metal.cc via TILE_LANG_METAL_SRCS/TILE_LANG_SRCS; wrap the file(GLOB TILE_LANG_METAL_SRCS ... ) and list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS}) in a conditional that only runs when USE_METAL is ON (or when APPLE), or alternatively split codegen-only sources from runtime sources and only append the runtime file rt_mod_metal.cc when USE_METAL is true so non-Apple codegen-only builds do not include Metal runtime sources.
🧹 Nitpick comments (1)
testing/python/metal/test_metal_gemm_v2.py (1)
23-35: ⚡ Quick win

Add a runtime case that keeps `C_local` as a fragment.

Using shared storage here exercises the shared-output lowering, but it skips the `local.fragment -> metal.simdgroup` rewrite that the public `T.gemm`/`alloc_fragment` path depends on. I'd add one hardware-backed case with `T.alloc_fragment` as well, so this file covers the public Metal GEMM path directly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2.py` around lines 23 - 35, Add a runtime case that allocates the output as a hardware-backed fragment instead of shared memory: create a variant where C_local is created with T.alloc_fragment (matching accum_dtype/shape) rather than T.alloc_shared, keep the same T.clear(C_local), loop with T.gemm(A_shared, B_shared, C_local) and the final T.copy(C_local, C[by * block_M, bx * block_N]), and ensure the test invokes both the existing shared-storage path and this new alloc_fragment path so the local.fragment -> metal.simdgroup rewrite is exercised; refer to symbols A_shared, B_shared, C_local, T.alloc_shared, T.alloc_fragment, and T.gemm to locate and implement the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmark/matmul_metal/benchmark_matmul_metal.py`:
- Around line 74-89: The script currently continues even when MPS is
unavailable; in the __main__ block (where parser/args, M/N/K are set and prints
occur) add an explicit availability check using
torch.backends.mps.is_available() and abort early (e.g., print an error and call
sys.exit(1) or raise RuntimeError) before any device allocations or use of
"mps"; ensure the check runs immediately after computing/printing MPS (or right
after args parsing) and include a clear message like "MPS backend not available,
aborting" referencing the symbols args, M, N, K and the
torch.backends.mps.is_available() call so downstream code never attempts to
allocate on "mps" when unavailable.
In `@src/op/copy.cc`:
- Line 1015: The code computes dst_stride from dst->shape[...] (dst_stride
variable) before calling simdgroup_store, which is wrong for strided
buffers/views; change those computations to read the actual destination stride
(e.g., the last entry of dst->strides or the buffer/view API that exposes
per-dimension strides) wherever dst_stride is used for simdgroup_store (the
dst_stride assignment at the dst_stride declaration and the similar occurrences
around the later simdgroup_store calls referenced in the comment, ~lines
1069-1072). Ensure you use the destination buffer/view stride accessor (falling
back only if no strides exist) so simdgroup_store writes use the correct memory
stride.
- Around line 802-807: CheckSIMDGroupCopy is too permissive: it only tests
target/scope but LowerSIMDGroupCopy assumes constant 2D extents and directly
dereferences IntImmNode and emits unconditional stores (risking OOB on edge
tiles). Tighten the legality check in CheckSIMDGroupCopy to verify the copy is a
2D constant-extent copy (both extents are IntImmNode constants and match the
expected SIMD group tile size/or divisibility) and that src/dst buffer kinds are
compatible (IsSIMDGroupBuffer with IsSharedBuffer/IsGlobalBuffer), and/or modify
LowerSIMDGroupCopy to detect non-constant or non-2D cases and fall back to the
generic copy path; additionally change LowerSIMDGroupCopy to emit
masked/conditional stores for edge tiles (or bounds checks) instead of
unconditional stores to prevent OOB writes when IntImm assumptions fail.
In `@src/op/fill.cc`:
- Around line 159-181: The simdgroup fill path (IsSIMDGroupBuffer) ignores
region offsets: the loop uses i as the matrix index for
make_filled_simdgroup_matrix without accounting for region[*].min, so fills
always start at matrix 0 and can overwrite when region is a slice; fix by
computing a start_matrix index from the region mins (e.g., sum/compute offset in
units of 64 elements from region[*].min) and use start_matrix + i when building
the Call, or if you want a minimal safe change, add an ICHECK that all
region[*].min are zero and return an error if any min != 0 to explicitly reject
non-zero mins before constructing stmts.
In `@testing/python/jit/test_tilelang_jit_adapter_mps.py`:
- Around line 10-33: Add a regression test that simulates CUDA appearing
available but failing during initialization so the except branch in
BaseKernelAdapter.get_current_device_functor() is executed: create a test (e.g.,
test_current_device_functor_handles_cuda_init_failure_prefers_mps) that
monkeypatches torch.cuda.is_available to return True, monkeypatches
torch.cuda._lazy_init (or the init function BaseKernelAdapter calls) to raise
RuntimeError, ensures torch.backends.mps.is_available returns True, then calls
BaseKernelAdapter.get_current_device_functor() and asserts the returned
device_functor() equals torch.device("mps"); use the same monkeypatch pattern as
the existing tests so cleanup is automatic.
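Sketched out, the suggested regression test could look like the following; the import path and method name follow the symbols named in the prompt and should be checked against the repo:

```python
import pytest
import torch

from tilelang.jit.adapter.base import BaseKernelAdapter  # path per the prompt; verify


def _raise_runtime_error():
    raise RuntimeError("simulated CUDA init failure")


def test_current_device_functor_handles_cuda_init_failure_prefers_mps(monkeypatch):
    if not torch.backends.mps.is_available():
        pytest.skip("MPS not available")
    # CUDA looks available but blows up during lazy init.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(torch.cuda, "_lazy_init", _raise_runtime_error)
    device_functor = BaseKernelAdapter.get_current_device_functor()
    assert device_functor() == torch.device("mps")
```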
In `@testing/python/metal/test_metal_internal_scaffolding.py`:
- Around line 488-497: The probe currently calls bad_kernel(32) which relies on
default JIT device selection and can miss Metal; instead, after obtaining the
returned PrimFunc from bad_kernel (the inner T.prim_func named main), explicitly
lower or compile that PrimFunc with target="metal" so the Metal lowering path is
exercised (e.g., invoke the lowering/compile API on the returned function with
target="metal"); update the probe where bad_kernel(32) is called (and the
similar call at 517-521) to perform this explicit lowering/compilation step.
In `@tilelang/jit/adapter/torch/metal.py`:
- Around line 56-57: The get_kernel_source method currently returns
self.kernel_global_source which may be None despite the method being annotated
to return str; update get_kernel_source to fail fast by checking if
self.kernel_global_source is None and either raise a clear ValueError (e.g.
"kernel_global_source is not initialized" including the kernel/class context) or
return an explicit empty string per your chosen contract, and ensure callers of
get_kernel_source (if any) expect the new behavior; reference get_kernel_source
and kernel_global_source when making the change.
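For the fail-fast variant, a sketch of the method body only, assuming the attribute name from the prompt:

```python
def get_kernel_source(self) -> str:
    # Fail fast instead of returning None from a str-annotated method.
    if self.kernel_global_source is None:
        raise ValueError(
            f"kernel_global_source is not initialized for {type(self).__name__}")
    return self.kernel_global_source
```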
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 24-27: Before emitting the simdgroup loops, validate that warp
tiling and K-micro tiling are exact multiples of 8: compute thread_nums =
thread_bounds.extent and obtain m_warp, n_warp from
self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
GemmInst.METAL_SIMDGROUP), then derive warp_row_tiles and warp_col_tiles and
check that (self.M % (m_warp*8) == 0), (self.N % (n_warp*8) == 0) and (block_K %
micro_size_k == 0 and (block_K // micro_size_k) % 8 == 0); if any check fails,
reject (raise/abort) before emitting simdgroup loops (the checks should be added
near where warp_row_tiles, warp_col_tiles, and block_K // micro_size_k are
computed to prevent silent tail dropping).
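A sketch of those guards, using the lowerer-scope names taken from the prompt above (`thread_bounds`, `self.policy`, `target`, `GemmInst`, `block_K`, `micro_size_k` are all assumed to exist as described):

```python
# Sketch only: all names below come from the prompt, not verified code.
thread_nums = thread_bounds.extent
m_warp, n_warp = self.policy.compute_warp_partition(
    self.M, self.N, thread_nums, target, GemmInst.METAL_SIMDGROUP)
if self.M % (m_warp * 8) != 0 or self.N % (n_warp * 8) != 0:
    raise ValueError(
        f"M ({self.M}) and N ({self.N}) must tile exactly across the "
        f"warp partition ({m_warp}, {n_warp}) in multiples of 8")
if block_K % micro_size_k != 0 or (block_K // micro_size_k) % 8 != 0:
    raise ValueError(
        f"block_K ({block_K}) must split into a multiple-of-8 count of "
        f"micro_size_k ({micro_size_k}) chunks")
```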
In `@tilelang/tileop/metal_simdgroup.py`:
- Around line 43-44: RegisterTile.index currently flattens tile_m/tile_n without
validation which can produce out-of-range fragment indices; add bounds checks in
index(self, tile_m: int, tile_n: int = 0) to verify 0 <= tile_m <
self.fragments_m and 0 <= tile_n < self.fragments_n and raise a clear exception
(e.g., ValueError or IndexError) if violated before computing tile_m *
self.fragments_n + tile_n so downstream simdgroup accesses cannot receive
invalid indices.
- Around line 468-470: In prefix_block_vector the store to writeback is
currently nested under "if writeback_guard" which causes writeback to be skipped
when writeback_guard is False; change the logic so that when writeback is not
None the store writeback[token, head] = value always executes (i.e., remove the
outer dependency on writeback_guard), and if masking/guarding per-lane is
required preserve writeback_guard only for computing/conditioning the stored
value rather than preventing the store entirely; update the code around
prefix_block_vector, writeback, writeback_guard and the writeback[token, head] =
value line accordingly.
📒 Files selected for processing (37)
benchmark/matmul_metal/benchmark_matmul_metal.py, pyproject.toml, requirements-dev.txt, requirements.txt, src/backend/metal/CMakeLists.txt, src/op/copy.cc, src/op/copy.h, src/op/fill.cc, src/op/gemm.cc, src/op/gemm.h, src/op/parallel.cc, src/op/utils.h, src/target/codegen_metal.cc, src/target/codegen_metal.h, src/transform/layout_inference.cc, src/transform/lower_device_storage_access_info.cc, testing/python/jit/test_tilelang_jit_adapter_mps.py, testing/python/metal/metal_internal_runtime_coverage.md, testing/python/metal/test_metal_gemm_v2.py, testing/python/metal/test_metal_gemm_v2_linux.py, testing/python/metal/test_metal_internal_scaffolding.py, testing/python/metal/test_metal_local_var.py, testing/python/metal/test_metal_simdgroup_store.py, tilelang/engine/lower.py, tilelang/engine/phase.py, tilelang/intrinsics/metal_macro_generator.py, tilelang/jit/adapter/base.py, tilelang/jit/adapter/torch/metal.py, tilelang/tileop/gemm/__init__.py, tilelang/tileop/gemm/gemm_metal.py, tilelang/tileop/gemm/inst.py, tilelang/tileop/metal_gdn.py, tilelang/tileop/metal_quant.py, tilelang/tileop/metal_simdgroup.py, tilelang/transform/decouple_type_cast.py, tilelang/transform/metal_fragment_to_simdgroup.py, tilelang/utils/language.py
Actionable comments posted: 1
♻️ Duplicate comments (3)
testing/python/metal/test_metal_internal_scaffolding.py (1)
484-493: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Compile the negative probe for Metal explicitly.

`bad_kernel(32)` still goes through default JIT device selection, so a host with CUDA available can miss the Metal path entirely and turn this backend-specific assertion into a flaky no-op. Return the `PrimFunc` directly here and lower/compile it with `target="metal"` inside the subprocess.

Proposed fix

```diff
-@tilelang.jit(out_idx=[-1])
 def bad_kernel(M):
     @T.prim_func
     def main(A: T.Tensor((M,), T.float32), B: T.Tensor((M,), T.{dtype_name})):
         with T.Kernel(T.ceildiv(M, 32), threads=32) as bx:
             for i in T.Parallel(32):
                 B[bx * 32 + i] = A[bx * 32 + i]
     return main

-bad_kernel(32)
+tilelang.lower(bad_kernel(32), target="metal")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_internal_scaffolding.py` around lines 484-493: The test currently calls bad_kernel(32) which triggers default JIT device selection and can skip Metal; instead have bad_kernel return the PrimFunc (the inner main defined under `@T.prim_func`) and in the subprocess explicitly lower/compile that PrimFunc with target="metal" so the Metal-specific compile path and assertion are exercised; update references to bad_kernel and the returned PrimFunc in the subprocess to call lowering/compilation APIs with target="metal" rather than relying on implicit JIT device selection.

tilelang/tileop/metal_simdgroup.py (2)
43-44: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate fragment coordinates before flattening them.

Every load/store/MMA helper eventually feeds this index into simdgroup operations. Without bounds checks here, an invalid `(tile_m, tile_n)` turns into an invalid fragment index instead of a clear exception.

Proposed fix

```diff
 def index(self, tile_m: int, tile_n: int = 0) -> int:
+    if not (0 <= tile_m < self.fragments_m and 0 <= tile_n < self.fragments_n):
+        raise IndexError(
+            f"RegisterTile index out of range: ({tile_m}, {tile_n}) "
+            f"for shape ({self.fragments_m}, {self.fragments_n})"
+        )
     return tile_m * self.fragments_n + tile_n
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/metal_simdgroup.py` around lines 43 - 44, The index(tile_m: int, tile_n: int = 0) method must validate that tile_m and tile_n are within bounds before computing the flat index; add checks that 0 <= tile_m < self.fragments_m and 0 <= tile_n < self.fragments_n and raise an IndexError (with a clear message including the invalid coordinates and allowed ranges) if either check fails, then return tile_m * self.fragments_n + tile_n as before; update the index method in metal_simdgroup.py accordingly.
468-469: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

`writeback_guard=False` should not disable writeback entirely.

The store is nested under `writeback_guard`, so turning the guard off suppresses every `writeback[token, head] = value` assignment. That makes the flag change semantics from “unguarded writeback” to “no writeback”.

Proposed fix

```diff
-if writeback is not None and writeback_guard:
+if writeback is not None:
     writeback[token, head] = value
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/metal_simdgroup.py` around lines 468 - 469, The current condition "if writeback is not None and writeback_guard:" suppresses all writebacks when writeback_guard is False; instead, make writeback occur whenever writeback is provided and only apply a guard when explicitly required—replace the combined condition with a simple "if writeback is not None:" so the assignment "writeback[token, head] = value" always happens for unguarded mode, and if guarded behavior is needed implement a separate explicit guard check (using writeback_guard) around the assignment rather than embedding it into the existence check.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/target/codegen_metal.cc`:
- Around line 392-418: Both the load and store paths for "local.var" (in
CodeGenTileLangMetal::VisitExpr_(const BufferLoadNode*) and
CodeGenTileLangMetal::VisitStmt_(const BufferStoreNode*)) ignore the index value
and thus silently accept non-zero indices; add a defensive check that rejects
any non-zero index. After the existing ICHECK_EQ(op->indices.size(), 1) and
scalar checks, validate that op->indices[0] is an IntImm (or constant integer)
and its value == 0, and fail fast (ICHECK / LOG(FATAL) with a clear message
referencing local.var index) if it is not zero so we do not silently read/write
the scalar for non-zero accesses.
📒 Files selected for processing (5)
src/target/codegen_metal.cc, src/target/codegen_metal.h, testing/python/metal/test_metal_internal_scaffolding.py, tilelang/tileop/metal_quant.py, tilelang/tileop/metal_simdgroup.py
✅ Files skipped from review due to trivial changes (1)
- tilelang/tileop/metal_quant.py
Actionable comments posted: 4
🧹 Nitpick comments (1)
tilelang/tileop/gemm/gemm_metal.py (1)
69-73: Consider replacing `assert`-based runtime guards with explicit exceptions across GEMM lowerers.

Lines 69-73 use `assert` for input validation. While this pattern appears in other GEMM lowerers (e.g., gemm_wmma.py), `assert` statements can be disabled with Python's `-O` optimization flag, silently removing these safeguards. For production lowering paths, explicit exceptions (e.g., `ValueError`) should be used to ensure validation always runs regardless of Python runtime mode. If addressed, apply consistently across all GEMM lowering modules that use this pattern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/gemm/gemm_metal.py` around lines 69 - 73, Replace the assert-based checks with explicit exceptions so validation always runs; e.g., instead of "assert block_K >= micro_size_k", raise ValueError with a clear message referencing block_K and micro_size_k; replace "assert is_full_region(C_region)" with a ValueError mentioning C_region and is_full_region; and replace the c_in_simdgroup_reg / is_shared(C_buf) assert with a ValueError that includes C_buf.scope() in the message. Apply this change where these symbols appear (block_K, micro_size_k, is_full_region, C_region, c_in_simdgroup_reg, is_shared, C_buf.scope()) so the Metal GEMM lowerer always raises explicit exceptions rather than relying on asserts.
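A minimal sketch of what those explicit guards could look like; `block_K`, `micro_size_k`, `is_full_region`, `C_region`, `c_in_simdgroup_reg`, `is_shared`, and `C_buf` are the lowerer-scope names referenced in the prompt, not verified code:

```python
# Sketch only: names come from the surrounding Metal GEMM lowerer.
if block_K < micro_size_k:
    raise ValueError(
        f"Metal GEMM requires block_K ({block_K}) >= micro_size_k ({micro_size_k})")
if not is_full_region(C_region):
    raise ValueError(f"Metal GEMM requires a full C region, got {C_region}")
if not (c_in_simdgroup_reg or is_shared(C_buf)):
    raise ValueError(f"Unsupported C buffer scope for Metal GEMM: {C_buf.scope()}")
```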
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmark/matmul_metal/benchmark_matmul_metal.py`:
- Around line 82-85: After calling parser.parse_args() and assigning M,N,K from
args (the block using parser.parse_args() and M, N, K = args.m, args.n, args.k),
add an early sanity check that ensures M, N, K and args.repeats are positive
integers (>0) and that any other numeric flags (e.g., args.batch, args.threads
if present) are non-negative; if any check fails, print a clear error to stderr
or via the existing logger and exit with a non-zero code. Use the parsed args
values for validation, raise or sys.exit(1) on invalid input, and include which
argument was invalid in the message so downstream throughput math and runtime
never run with bad values.
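As a sketch, assuming argparse fields named `m`, `n`, `k`, and `repeats` per the prompt above, the early validation could look like:

```python
import sys

# Sketch only: assumes args was produced by argparse with these field names.
for name in ("m", "n", "k", "repeats"):
    value = getattr(args, name, None)
    if not isinstance(value, int) or value <= 0:
        print(f"invalid --{name}: {value!r} (must be a positive integer)",
              file=sys.stderr)
        sys.exit(1)
```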
In `@src/op/copy.cc`:
- Around line 1057-1061: The ICHECKs in src/op/copy.cc that assert
dst_strides.size() == dst->shape.size() and
analyzer->CanProveEqual(dst_strides[1], 1) should not hard-fail for
non-contiguous destination columns; instead detect the failing condition and
fall back to the safe path (call LowerNormalCopy) when dst_strides.size()
mismatches or analyzer->CanProveEqual(dst_strides[1], 1) is false. Replace the
hard ICHECK on dst_strides[1] with a conditional that routes to LowerNormalCopy
for strided destinations, keeping the original check for complete stride vector
only if it's truly required, and ensure you reference dst_strides, dst->shape,
analyzer->CanProveEqual, and LowerNormalCopy when implementing the conditional
fallback.
- Around line 802-837: CheckSIMDGroupCopy currently allows non-zero
src_range[*].min which permits sliced metal.simdgroup sources but
LowerSIMDGroupCopy emits simdgroup_store(...) assuming tile indices start at 0;
update CopyNode::CheckSIMDGroupCopy to reject any simdgroup source with non-zero
src_range min values (both dimensions) by verifying src_range[i]->min is a
constant zero (or equivalent IntImmNode with value 0) for i in {0,1}; this
prevents sliced sources from being accepted until LowerSIMDGroupCopy is changed
to apply a source tile base.
- Around line 1064-1095: The code assumes constant thread extents and perfect
warp tiling by directly casting T.thread_bounds->extent to IntImmNode and later
ICHECK-ing that M >= m_warp*8 && N >= n_warp*8; change this to detect
non-constant or invalid extents and fall back to a safe normal-copy path instead
of crashing: check TargetGetWarpSize(T.target) and whether
T.thread_bounds->extent is an IntImmNode (and that block_size % warp_size == 0)
before computing num_warps and warp_id; if any check fails (non-IntImm extent,
partial warps, or computed m_warp/n_warp that exceed max_m/max_n), avoid the
warp-mapping loop and return/use the existing normal-copy fallback, and replace
the final ICHECK(M >= m_warp*8 && N >= n_warp*8) with a conditional that selects
the fallback when the condition is false so lowering does not abort.
📒 Files selected for processing (10)
benchmark/matmul_metal/benchmark_matmul_metal.py, src/op/copy.cc, src/op/fill.cc, testing/python/jit/test_tilelang_jit_adapter_mps.py, testing/python/metal/test_metal_internal_scaffolding.py, tilelang/intrinsics/metal_macro_generator.py, tilelang/jit/adapter/base.py, tilelang/jit/adapter/torch/metal.py, tilelang/tileop/gemm/gemm_metal.py, tilelang/tileop/metal_simdgroup.py
✅ Files skipped from review due to trivial changes (1)
- tilelang/tileop/metal_simdgroup.py
🚧 Files skipped from review as they are similar to previous changes (3)
- tilelang/jit/adapter/torch/metal.py
- tilelang/jit/adapter/base.py
- testing/python/jit/test_tilelang_jit_adapter_mps.py
Actionable comments posted: 3
♻️ Duplicate comments (1)
tilelang/tileop/gemm/gemm_metal.py (1)
69-70: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add an explicit divisibility guard for `block_K` and `micro_size_k`.

`block_K // micro_size_k` is floor division in both K loops. Without a `%` check, non-divisible configurations can silently skip tail K work.

Suggested patch

```diff
 if block_K < micro_size_k:
     raise ValueError(f"Metal GEMM requires block_K ({block_K}) to be >= micro_size_k ({micro_size_k})")
+if block_K % micro_size_k != 0:
+    raise ValueError(
+        f"Metal GEMM requires block_K ({block_K}) to be divisible by micro_size_k ({micro_size_k})"
+    )
```

Also applies to: 86-87, 104-105
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/gemm/gemm_metal.py` around lines 69 - 70, Add an explicit divisibility guard so block_K is a whole multiple of micro_size_k: in tilelang/tileop/gemm/gemm_metal.py, update the validation around the existing check that raises when block_K < micro_size_k (and the analogous checks at the other occurrences) to also raise a ValueError when block_K % micro_size_k != 0; reference the variables block_K and micro_size_k in the message so misconfigured plans fail fast instead of silently dropping tail K work (apply the same change to the other two validation sites mentioned).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/op/copy.cc`:
- Around line 1066-1076: The code computes warp_id from the absolute
T.thread_var which breaks when thread_bounds->min != 0; instead normalize the
thread index against T.thread_bounds->min before dividing by warp_size: compute
a local_thread = T.thread_var - T.thread_bounds->min (matching how other
specialized lowerings normalize), then set warp_id = FloorDiv(local_thread,
warp_size); keep the same guards around block_extent/warp_size and fall back to
LowerNormalCopy(T, analyzer) as before.
- Around line 1096-1120: Loop chooses m_warp/n_warp partitions but
warp_row_tiles and warp_col_tiles use truncating division and can drop partial
tiles; update the guard to require exact tiling by checking divisibility before
proceeding: ensure M % (m_warp * 8) == 0 and N % (n_warp * 8) == 0 (and keep the
existing checks for <=0 and total_elements), and only compute warp_row_tiles = M
/ (m_warp * 8) and warp_col_tiles = N / (n_warp * 8) after those divisibility
checks; if divisibility fails, return LowerNormalCopy(T, analyzer). This ensures
the warp partition exactly tiles the source and prevents silently dropping
tiles.
- Around line 829-837: The loop that validates src_range currently only rejects
non-zero mins and mismatched extents, but you must also reject partial source
tiles: inside the for-loop over i check that the source extent exactly equals
the full source buffer dimension used by LowerSIMDGroupCopy (e.g., compare
src_extent->value against the source buffer's shape/size for that dimension
instead of only comparing to dst_extent), and return false if the src tile is a
subrange (partial) so tile_idx remapping uses the full source stride; update the
condition that currently references src_range, src_extent and dst_extent to
include this equality against the source's full shape.
📒 Files selected for processing (5)
benchmark/matmul_metal/benchmark_matmul_metal.py, src/backend/metal/CMakeLists.txt, src/op/copy.cc, src/target/codegen_metal.cc, tilelang/tileop/gemm/gemm_metal.py
🧹 Nitpick comments (1)
src/op/copy.cc (1)
812-815: 💤 Low value

Consider adding an explicit check for `src->shape.size() == 2`.

The code validates `dst->shape.size() != 2` but not `src->shape.size()`. The subsequent loop at lines 829-839 accesses `src->shape[i]` for i ∈ {0, 1}, which would be undefined behavior if `src->shape` has fewer than 2 dimensions.

While `IsSIMDGroupBuffer(src)` likely guarantees 2D shape by construction, adding an explicit check makes the invariant clear and provides defense-in-depth.

💡 Suggested defensive check

```diff
 if (src_range.size() != 2 || dst_range.size() != 2 ||
-    dst->shape.size() != 2) {
+    src->shape.size() != 2 || dst->shape.size() != 2) {
   return false;
 }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/copy.cc` around lines 812 - 815, Add an explicit guard that verifies src->shape.size() == 2 before using src->shape[0/1]: the current block already checks src_range.size() and dst_range.size() and dst->shape.size(), but does not assert src->shape has two dims; update the same conditional (the one that returns false) to also check src->shape.size() == 2 (or add a separate early return) so that subsequent code in this function that indexes src->shape[0] and src->shape[1] is safe even if IsSIMDGroupBuffer(src) invariants change.
📒 Files selected for processing (2)
src/op/copy.cc, tilelang/tileop/gemm/gemm_metal.py
Summary
- Rebased onto tile-ai/main, including the new backend-local CMake layout.
- Add `local.var` lowering for Metal codegen and prefer MPS as the JIT fallback device when CUDA is unavailable.

Upstream Compatibility
- … `T.gemm`/`alloc_fragment` usage.
- … `tilelang.tileop`; no public `T.rt` or `T.rv` aliases are added.
- … `uint8` probes only.

Rebase Notes
- Reworked src/backend/metal/CMakeLists.txt so it follows the backend-local CMake structure now used by upstream.

Verification
```
cmake -S . -B build -DUSE_CUDA=OFF -DUSE_ROCM=OFF -DUSE_METAL=ON -DCMAKE_BUILD_TYPE=Release -DPython3_EXECUTABLE=/Applications/Xcode.app/Contents/Developer/usr/bin/python3 -DPython_EXECUTABLE=/Applications/Xcode.app/Contents/Developer/usr/bin/python3
cmake --build build -j8
python3 -m pytest testing/python/metal/test_metal_gemm_v2.py testing/python/metal/test_metal_gemm_v2_linux.py testing/python/metal/test_metal_simdgroup_store.py testing/python/metal/test_metal_local_var.py testing/python/jit/test_tilelang_jit_adapter_mps.py testing/python/metal/test_metal_internal_scaffolding.py -q
```

Summary by CodeRabbit
New Features
Tests
Documentation
Chores