[Metal] route FP8-input T.gemm to scalar fallback #2140
apstenku123 wants to merge 11 commits into tile-ai:main from
Conversation
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations (simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon (M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: fork TVM Metal codegen into tilelang with simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared x shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: pass that rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment -> device simdgroup_store

24 Metal tests (cross-platform codegen plus correctness on device).
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

This pull request adds comprehensive Metal SIMD group support for TileLang GEMM operations, including a new Metal code generator, a GEMM lowering path, copy and fill operations with SIMD group backing, high-level tile abstraction macros, intrinsic emitters, and extensive tests covering codegen and runtime correctness on Apple MPS hardware.

Changes
- Metal SIMD Group GEMM & Infrastructure
- Metal SIMD Group Abstractions & High-Level Macros
- JIT & Host Integration
- Tests, Benchmarks & Documentation
- Platform Dependencies
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as Application
    participant JIT as TileLang JIT
    participant Lower as Lowering Pipeline
    participant Pass as FragmentToSimdgroup Pass
    participant Codegen as Metal CodeGen
    participant Target as Metal Target
    User->>JIT: Define `@T.prim_func` GEMM<br/>(shared A, shared B, fragment C)
    JIT->>JIT: Compile with tilelang.jit()
    JIT->>Lower: Lower to IR
    Lower->>Pass: Apply MetalFragmentToSimdgroup<br/>Rewrite fragment→simdgroup
    Pass->>Pass: Detect GEMM ops,<br/>exclude FP8 cases,<br/>remap allocations
    Pass-->>Lower: Return transformed mod
    Lower->>Lower: LayoutInference<br/>(no fragment layout needed on Metal)
    Lower->>Codegen: Invoke CodeGenTileLangMetal
    Codegen->>Codegen: Generate kernel:<br/>- Bind thread indices<br/>- Allocate 8x8 simdgroup tiles<br/>- Emit simdgroup ops<br/>- Handle local.var scalars
    Codegen->>Target: Optional: Compile to metallib<br/>via tvm_callback_metal_compile
    Target-->>Codegen: Return Metal binary
    Codegen-->>JIT: Return MetalModule
    JIT-->>User: Return callable kernel
    User->>JIT: Call kernel(A_mps, B_mps, C_mps)
    JIT->>Target: Launch on Apple MPS
    Target->>Target: Execute simdgroup matrix ops,<br/>copy results
    Target-->>User: Return results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Suggested reviewers
Files documenting the actual PRs we just opened upstream:
- PR #1: ml-explore/mlx#3476 — from_dlpack Metal-aware consumer (against main, clean)
- PR #2: apache/tvm#19504 — TVM_METAL_STORAGE_MODE env opt-in (against main, clean)
- PR #3: tile-ai/tilelang#2139 — mixed-dtype T.gemm via scalar fallback (stacks on PR #2130)
- PR #4: tile-ai/tilelang#2140 — FP8-input T.gemm scalar fallback routing (stacks on PR #2130)
- PR #5: tile-ai/tilelang#2141 — T.Pipelined num_stages>1 3D buffer fix (stacks on PR #2130)
- PR #6: tile-ai/tilelang#2142 — T.fp8_scaled_matmul DSL intrinsic (stacks on PR #2130)

Deferred (companion PRs needed): tilelang_metal_fp8 and tilelang_metal_fp8_vector each touch both the tilelang supermodule and the TileLang/tvm vendored submodule. These need two PRs each — one to tile-ai/tilelang, one to TileLang/tvm — in a separate filing round.

PRs #3-#6 are independent of each other; each branches directly from jorgecurious/tilelang:metal-gemm-upstream-rebase HEAD 971c17b, so they can be reviewed in any order. They DO depend on the upstream 4-PR Apple Metal landing chain (#1869, #2118, #2121, #2130) merging first; if any of those land separately, ours can be retargeted at main.
Pull request overview
This PR improves Metal backend lowering for T.gemm, including routing FP8-input GEMMs away from simdgroup MMA (which can’t legally represent FP8 on Apple Silicon today) and adding the supporting Metal codegen/scaffolding needed for simdgroup accumulators and stores.
Changes:
- Add Metal-specific GEMM dispatch (`GemmInst.METAL_SIMDGROUP`) with FP8-input detection that routes Metal FP8 GEMMs to the scalar fallback.
- Introduce a Metal pass to rewrite GEMM accumulators from `local.fragment` to `metal.simdgroup` (and skip that rewrite for FP8-input GEMMs).
- Add/extend Metal codegen, lowering, and tests to support `metal.simdgroup` allocation/fill/copy and MPS-oriented runtime coverage.
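The routing decision in the first bullet reduces to a dtype predicate at dispatch time. The sketch below shows the shape of that decision in plain Python; `_is_fp8_dtype` and `select_gemm_inst` are illustrative names, not the PR's actual API, and the dtype strings are assumptions:

```python
def _is_fp8_dtype(dtype: str) -> bool:
    """FP8 is storage-only on Apple Silicon, so simdgroup MMA cannot consume it."""
    return dtype.startswith("float8_")


def select_gemm_inst(a_dtype: str, b_dtype: str, target: str) -> str:
    """Pick a GEMM lowering: simdgroup MMA when legal, scalar fallback otherwise."""
    if target == "metal":
        if _is_fp8_dtype(a_dtype) or _is_fp8_dtype(b_dtype):
            return "scalar"           # FP8 inputs route to the scalar fallback
        return "metal_simdgroup"      # native simdgroup_multiply_accumulate path
    return "mma"                      # non-Metal targets keep their usual dispatch


assert select_gemm_inst("float8_e4m3", "float16", "metal") == "scalar"
assert select_gemm_inst("float16", "float16", "metal") == "metal_simdgroup"
```

Either operand being FP8 is enough to trigger the fallback, which matches the "FP8-input detection" wording above.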
Reviewed changes
Copilot reviewed 36 out of 37 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| tilelang/utils/language.py | Add is_metal_simdgroup() scope helper. |
| tilelang/transform/metal_fragment_to_simdgroup.py | New Metal pass rewriting GEMM accumulators to metal.simdgroup, skipping FP8-input GEMMs. |
| tilelang/transform/decouple_type_cast.py | Treat metal.simdgroup buffers as “local” for cast-decoupling decisions. |
| tilelang/tileop/metal_simdgroup.py | New internal simdgroup helper macros/types (RegisterTile/RowVector/etc.). |
| tilelang/tileop/metal_quant.py | Add packed-uint8 FP8/FP4/e8m0 decode helpers for Metal kernels. |
| tilelang/tileop/metal_gdn.py | Add internal GDN/attention-style simdgroup tile macros. |
| tilelang/tileop/gemm/inst.py | Extend GemmInst with METAL_SIMDGROUP. |
| tilelang/tileop/gemm/gemm_metal.py | New Metal simdgroup GEMM lowering implementation. |
| tilelang/tileop/gemm/__init__.py | Metal GEMM instruction selection + FP8-input routing to scalar; hook up GemmMetal. |
| tilelang/jit/adapter/torch/metal.py | Expose Metal kernel source via adapter. |
| tilelang/jit/adapter/base.py | Prefer MPS device selection when CUDA is unavailable/initialization fails. |
| tilelang/intrinsics/metal_macro_generator.py | New MPSIntrinEmitter to generate simdgroup load/store/MMA macros. |
| tilelang/engine/phase.py | Insert Metal fragment→simdgroup rewrite pass before layout inference. |
| tilelang/engine/lower.py | Switch Metal device codegen to target.build.tilelang_metal. |
| testing/python/metal/test_metal_simdgroup_store.py | Tests for simdgroup accumulator path + direct simdgroup_store to device memory. |
| testing/python/metal/test_metal_local_var.py | Tests for local.var scalar lowering on Metal. |
| testing/python/metal/test_metal_internal_scaffolding.py | Broad internal Metal scaffolding/runtime/source-boundary probes. |
| testing/python/metal/test_metal_gemm_v2_linux.py | Cross-platform (no Metal runtime) codegen-only tests for Metal GEMM v2. |
| testing/python/metal/test_metal_gemm_v2.py | Runtime correctness tests for Metal GEMM v2 on Metal hardware. |
| testing/python/metal/metal_internal_runtime_coverage.md | Document internal Metal runtime coverage expectations. |
| testing/python/jit/test_tilelang_jit_adapter_mps.py | Tests for MPS device selection behavior in the JIT adapter. |
| src/transform/lower_device_storage_access_info.cc | Allow .fragment scope tags in storage-access-info lowering. |
| src/transform/layout_inference.cc | Metal-specific relaxation for fragment layout completeness checks. |
| src/target/codegen_metal.h | Add TileLang Metal codegen declaration (CodeGenTileLangMetal). |
| src/target/codegen_metal.cc | Implement TileLang Metal codegen including metal.simdgroup + local.var support; register target.build.tilelang_metal. |
| src/op/utils.h | Add IsSIMDGroupBuffer/IsRegisterBuffer helpers. |
| src/op/parallel.cc | Guard fragment layout access when layout may be absent. |
| src/op/gemm.h | Add kMetalSimdgroup GEMM inst enum value. |
| src/op/gemm.cc | Select Metal simdgroup GEMM inst for Metal targets; adjust warp partition heuristics for Metal. |
| src/op/fill.cc | Add fill lowering for metal.simdgroup buffers via make_filled_simdgroup_matrix. |
| src/op/copy.h | Add Metal simdgroup copy inst + lowering hooks. |
| src/op/copy.cc | Implement simdgroup-store lowering from metal.simdgroup → shared/global. |
| src/backend/metal/CMakeLists.txt | Always build Metal codegen source; gate runtime pieces on Apple/USE_METAL. |
| requirements.txt | Constrain apache-tvm-ffi on Darwin. |
| requirements-dev.txt | Constrain apache-tvm-ffi on Darwin (dev). |
| pyproject.toml | Constrain apache-tvm-ffi on Darwin (wheel/install). |
| benchmark/matmul_metal/benchmark_matmul_metal.py | Add a simple Metal GEMM benchmark script. |
```python
def _rewrite_scope(body, var_map):
    buf_map = {}

    def _pre_order(stmt):
        if isinstance(stmt, tir.Block):
            new_alloc_bufs = []
            changed = False
            for buf in stmt.alloc_buffers:
                new_buf = _remap_buffer(buf, var_map)
                new_alloc_bufs.append(new_buf)
                if not new_buf.same_as(buf):
                    buf_map[buf] = new_buf
                    changed = True
            if changed:
                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
                new_block = tir.Block(
                    stmt.iter_vars,
                    stmt.reads,
                    stmt.writes,
                    stmt.name_hint,
                    new_body,
                    stmt.init,
                    new_alloc_bufs,
                    stmt.match_buffers,
                    stmt.annotations,
                )
                return (
                    tir.BlockRealize(
                        stmt.iter_vars,
                        tir.const(True, "bool"),
                        new_block,
                    )
                    if False
                    else new_block
                )
```
```python
def _collect_fragment_gemm_accum_vars(body: tir.Stmt) -> set:
    """Walk the body and return fragment vars safe to rewrite to simdgroup.

    GEMM accumulators backed by ``local.fragment`` are eligible for the
    rewrite to ``metal.simdgroup``, which the Metal simdgroup MMA path
    needs. We exclude FP8-input GEMMs because the dispatcher routes them
    to the scalar fallback (Apple has no native FP8 ALU through M5; the
    per-element T.cast invokes the storage-only decode helpers from the
    FP8 prelude -- see audiohacking fp8_scaled_matmul_kernel for the
    analogous pattern). For those GEMMs the accumulator must stay in
    ``local.fragment`` so the scalar fallback can perform its
    per-element T.cast(..., accum_dtype) arithmetic without tripping the
    Metal codegen's check that ``metal.simdgroup`` allocations are
    scalar 8x8 blocks.
    """
    accum_vars: set = set()
    gemm_ops = _get_gemm_ops()

    def _visitor(stmt):
        if isinstance(stmt, tir.Evaluate) and isinstance(stmt.value, tir.Call):
            call = stmt.value
            if call.op in gemm_ops and len(call.args) >= 3:
                # FP8 inputs (storage-only on Metal) route to the scalar
                # fallback; exclude their accumulators from the simdgroup
                # rewrite so the codegen does not allocate a
                # ``metal.simdgroup`` buffer for them.
                a_buf = _extract_buffer_from_region(call.args[0])
                b_buf = _extract_buffer_from_region(call.args[1])
                fp8_inputs = (a_buf is not None and _is_fp8_dtype(a_buf.dtype)) or \
                    (b_buf is not None and _is_fp8_dtype(b_buf.dtype))
                if fp8_inputs:
                    return
```
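Stripped of the TVM IR machinery, the visitor above is a filter over GEMM call sites: keep the accumulator unless either input is FP8. This standalone sketch mirrors that effect with a mock call record (`GemmCall` and the helper names are illustrative, not the pass's real types):

```python
from dataclasses import dataclass


@dataclass
class GemmCall:
    """Mock of a T.gemm call site: input dtypes plus the accumulator var name."""
    a_dtype: str
    b_dtype: str
    accum_var: str


def _is_fp8_dtype(dtype: str) -> bool:
    return dtype.startswith("float8_")


def collect_simdgroup_accum_vars(calls):
    """Return accumulators eligible for the metal.simdgroup rewrite.

    FP8-input GEMMs are skipped so their accumulators stay in local.fragment
    for the scalar fallback, mirroring _collect_fragment_gemm_accum_vars.
    """
    eligible = set()
    for call in calls:
        if _is_fp8_dtype(call.a_dtype) or _is_fp8_dtype(call.b_dtype):
            continue  # scalar fallback keeps this accumulator in local.fragment
        eligible.add(call.accum_var)
    return eligible


calls = [
    GemmCall("float16", "float16", "C_frag"),
    GemmCall("float8_e4m3", "float16", "C_fp8"),
]
assert collect_simdgroup_accum_vars(calls) == {"C_frag"}
```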
```cpp
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
```
```python
# helpers. The runtime mapping of GemmInst.Scalar to
# GemmMetalScalar on Metal targets is provided by the
# tilelang_gemm_mixed_dtype companion patch (PR #2118 stack);
# without it the resulting kernel will not lower correctly,
# but the routing decision is the load-bearing change here.
```
```python
            return lambda: torch.device("cuda", current_device())
    except Exception:
        return lambda: torch.device("cuda", torch.cuda.current_device())
        pass
```
```diff
   // Check that all local.fragment buffers have inferred layouts.
   // On Metal targets, fragment buffers used as GEMM accumulators are
   // lowered to opaque simdgroup matrices, so they have no explicit
   // thread-level layout and can be safely skipped.
   for (const auto &[buffer, _] : use_list_) {
     if (IsFragmentBuffer(buffer)) {
-      ICHECK_NE(layout_map.count(buffer), 0)
-          << "The layout for fragment " << buffer
-          << " can not be inferred correctly.";
+      if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
+        ICHECK(false) << "The layout for fragment " << buffer
+                      << " can not be inferred correctly.";
+      }
```
```cpp
        << "Only float16, float32, and bfloat16 are supported, but got "
        << op->dtype;
    ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
                                    << constant_size << " bytes\n";
```
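The hunk above amounts to two checks on a `metal.simdgroup` allocation: a supported element type and a constant size that is a whole number of 8x8 tiles (64 elements each). A standalone sketch of the same rule, with an illustrative helper name:

```python
def validate_simdgroup_alloc(dtype: str, constant_size: int) -> None:
    """Mirror the Metal codegen's simdgroup allocation checks: a supported
    element type, and a size that is a whole number of 8x8 (64-element) tiles."""
    if dtype not in ("float16", "float32", "bfloat16"):
        raise ValueError(
            f"Only float16, float32, and bfloat16 are supported, but got {dtype}")
    if constant_size % 64 != 0:
        raise ValueError(f"Only 8x8 matrix is supported, but got {constant_size}")


validate_simdgroup_alloc("float16", 128)  # two 8x8 matrices: accepted
try:
    validate_simdgroup_alloc("float16", 100)  # not a whole number of 8x8 tiles
except ValueError:
    pass
```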
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/parallel.cc (1)
362-383: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard the earlier fragment dereference too.

This branch now treats `T.layout_map[buffer]` as optional, but `IsBufferCompletelyReplicated(buffer, T.layout_map)` at lines 362-363 still does `as<Fragment>().value()`. On the Metal path you described, that means we can still crash before reaching the new `frag.has_value()` check. Make the helper return `false` for a missing fragment, or check `has_value()` before calling it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/parallel.cc` around lines 362 - 383, IsBufferCompletelyReplicated(buffer, T.layout_map) may dereference a missing Fragment; update the call site or the helper to avoid crashing: either modify IsBufferCompletelyReplicated to check T.layout_map[buffer].as<Fragment>() and return false when no fragment exists, or check that T.layout_map[buffer].as<Fragment>().has_value() before calling .value() inside the helper so it never dereferences an absent Fragment. Ensure this change keeps the later logic that uses frag.has_value() and the source_buffer/read_source_buffer flow intact.
🧹 Nitpick comments (2)
tilelang/engine/phase.py (1)
200-204: ⚡ Quick win

Gate this rewrite on `target.kind.name == "metal"`.

The comment says this is a Metal-only rewrite, but the pass currently runs for every target. Guarding it here keeps non-Metal pipelines unchanged and avoids depending on `MetalFragmentToSimdgroup` being a perfect no-op everywhere else.

♻️ Suggested change

```diff
-    # On Metal, rewrite local.fragment GEMM accumulators to metal.simdgroup
-    # before layout inference (which would otherwise require a layout for them)
-    from tilelang.transform.metal_fragment_to_simdgroup import MetalFragmentToSimdgroup
-
-    mod = MetalFragmentToSimdgroup(mod)
+    # On Metal, rewrite local.fragment GEMM accumulators to metal.simdgroup
+    # before layout inference (which would otherwise require a layout for them)
+    if target.kind.name == "metal":
+        from tilelang.transform.metal_fragment_to_simdgroup import MetalFragmentToSimdgroup
+
+        mod = MetalFragmentToSimdgroup(mod)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/engine/phase.py` around lines 200 - 204, The Metal-specific rewrite MetalFragmentToSimdgroup is being applied unconditionally; gate it so it only runs when the compilation target is Metal by checking target.kind.name == "metal" before constructing/applying MetalFragmentToSimdgroup to mod. Locate the code around the phase where mod is transformed (the MetalFragmentToSimdgroup(mod) call) and wrap that construction/application in a conditional using the existing target variable (e.g., if target.kind.name == "metal": mod = MetalFragmentToSimdgroup(mod)).

testing/python/metal/test_metal_gemm_v2_linux.py (1)
62-79: ⚡ Quick win

Add an FP8 regression case for the new routing behavior.

This suite only locks down the non-FP8 simdgroup path, but the PR's actual behavior change is "FP8 inputs lower successfully without simdgroup MMA." Please add a codegen-only FP8 case here that asserts Metal lowering succeeds and that `simdgroup_multiply_accumulate` is absent; otherwise the new dispatcher/pass interaction can regress silently.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 62 - 79, Add a codegen-only FP8 regression test in this file that calls assert_metal_gemm_v2_codegen with FP8 input dtype(s) (matching your FP8 type symbol) and appropriate shapes (e.g., small 16/32 blocks) to ensure Metal lowering succeeds; then assert the generated IR/source does not contain the symbol simdgroup_multiply_accumulate (i.e., verify absence of that call). Name the new test something like test_metal_gemm_v2_fp8_no_simdgroup and place it alongside the other test_* functions so it runs in CI, using the existing assert_metal_gemm_v2_codegen helper to drive codegen-only validation.
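The core assertion the suggested test needs is a simple source-string check. A minimal sketch, assuming the real test would obtain `src` from the suite's existing codegen helper (the string below is a stand-in, not actual generated Metal source):

```python
def assert_fp8_routes_to_scalar(src: str) -> None:
    """Codegen-only regression check: an FP8-input GEMM's Metal source must
    not contain simdgroup MMA, proving the scalar fallback was selected."""
    assert "simdgroup_multiply_accumulate" not in src, \
        "FP8 GEMM unexpectedly took the simdgroup MMA path"


# Stand-in for source produced by an FP8 GEMM lowering; a real test would call
# the suite's assert_metal_gemm_v2_codegen-style helper to obtain this.
fake_scalar_src = "for (int k = 0; k < K; ++k) { acc += decode_fp8(A[k]) * B[k]; }"
assert_fp8_routes_to_scalar(fake_scalar_src)
```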
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/backend/metal/CMakeLists.txt`:
- Around line 3-8: The CMakeLists currently unconditionally appends
src/target/codegen_metal.cc to TILE_LANG_SRCS causing build failures on
non-Apple/Non-Metal hosts because runtime/metal/metal_module.h is missing; fix
by moving the list(APPEND TILE_LANG_SRCS src/target/codegen_metal.cc) so it is
executed only inside the existing guards (USE_METAL and APPLE) in
CMakeLists.txt, or alternatively refactor the codegen_metal.cc file to split
Metal runtime includes from pure codegen logic (e.g., create a
codegen_metal_core.cc without runtime/metal/metal_module.h and keep
runtime-dependent bits in a separate file) and update TILE_LANG_SRCS to only add
the runtime-dependent file under the USE_METAL/APPLE guards so non-Metal builds
no longer attempt to compile runtime headers.
In `@src/op/copy.cc`:
- Around line 802-840: CopyNode::CheckSIMDGroupCopy currently accepts Metal
SIMD-group copies without proving dst_range is fully in-bounds, which allows
LowerSIMDGroupCopy to emit unpredicated simdgroup_store and OOB stores; modify
CheckSIMDGroupCopy to consult the buffer_oob/analyzer-based in-bounds check (the
same predicate used elsewhere) for dst (and if relevant src) and only return
true when the analyzer proves dst_range is fully inside the dst buffer bounds;
thread the analyzer/buffer_oob check into this path (or call the existing helper
used by other copy legality checks) so that CheckSIMDGroupCopy refuses
SIMD-group selection for boundary tiles that might be OOB.
In `@src/op/fill.cc`:
- Around line 183-202: Before emitting consecutive matrix fills, prove the
region maps to a dense whole-matrix span: compute the region's last element
offset (e.g., last_element = sum((region[i]->min + region[i]->extent - 1) *
strides[i])) and use analyzer to check both that FloorMod(element_offset,
matrix_elements) == 0 (already present) and that FloorMod(last_element + 1,
matrix_elements) == 0 and FloorDiv(last_element, matrix_elements) ==
matrix_index_base + IntImm(DataType::Int(32), num_matrices - 1); if any of these
cannot be proven, do not run the consecutive Call loop that uses
matrix_index_base/num_matrices (the stmts loop) and instead fall back to the
non-consecutive/dense-element emission path. Reference variables: region,
strides, element_offset, last_element, matrix_elements, matrix_index_base,
num_matrices, analyzer->CanProveEqual, and the stmts/Call emission block.
In `@testing/python/metal/test_metal_local_var.py`:
- Around line 34-40: The zero-init count assertion is too strict for
_make_local_var_func() because it only guarantees one default-initialized local
(y) while x is explicitly initialized to 3; update the assertion in
test_metal_local_var.py that currently uses "assert
len(re.findall(r\"\\bint\\s+\\w+\\s*=\\s*0;\", src)) >= 2" to accept a single
zero-initialized local (e.g., change >= 2 to >= 1) so the test verifies the
presence of default-initialized locals without failing on valid codegen that
doesn't produce incidental temporaries.
In `@tilelang/jit/adapter/base.py`:
- Around line 82-84: The device-selection try/except currently only wraps
_lazy_init() at construction time, letting exceptions from _cuda_getDevice()
escape when the returned lambda is later invoked; move the CUDA and MPS
device-query logic inside the returned thunk (the lambda returned by the device
functor) so the try/except covers the actual runtime call to
torch._C._cuda_getDevice() and fallback to MPS/CPU; update both the device
functor and get_current_stream_functor to defer queries into their returned
closures and keep _lazy_init() where needed. Also add a regression test that
monkeypatches torch._C._cuda_getDevice to raise when the thunk is executed and
assert the device_functor()/get_current_stream_functor() fall back to MPS or CPU
as expected.
---
Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 362-383: IsBufferCompletelyReplicated(buffer, T.layout_map) may
dereference a missing Fragment; update the call site or the helper to avoid
crashing: either modify IsBufferCompletelyReplicated to check
T.layout_map[buffer].as<Fragment>() and return false when no fragment exists, or
check that T.layout_map[buffer].as<Fragment>().has_value() before calling
.value() inside the helper so it never dereferences an absent Fragment. Ensure
this change keeps the later logic that uses frag.has_value() and the
source_buffer/read_source_buffer flow intact.
---
Nitpick comments:
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 62-79: Add a codegen-only FP8 regression test in this file that
calls assert_metal_gemm_v2_codegen with FP8 input dtype(s) (matching your FP8
type symbol) and appropriate shapes (e.g., small 16/32 blocks) to ensure Metal
lowering succeeds; then assert the generated IR/source does not contain the
symbol simdgroup_multiply_accumulate (i.e., verify absence of that call). Name
the new test something like test_metal_gemm_v2_fp8_no_simdgroup and place it
alongside the other test_* functions so it runs in CI, using the existing
assert_metal_gemm_v2_codegen helper to drive codegen-only validation.
In `@tilelang/engine/phase.py`:
- Around line 200-204: The Metal-specific rewrite MetalFragmentToSimdgroup is
being applied unconditionally; gate it so it only runs when the compilation
target is Metal by checking target.kind.name == "metal" before
constructing/applying MetalFragmentToSimdgroup to mod. Locate the code around
the phase where mod is transformed (the MetalFragmentToSimdgroup(mod) call) and
wrap that construction/application in a conditional using the existing target
variable (e.g., if target.kind.name == "metal": mod =
MetalFragmentToSimdgroup(mod)).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 81b49bed-fc4f-498a-bbce-19185f4b4703
📒 Files selected for processing (37)
- benchmark/matmul_metal/benchmark_matmul_metal.py
- pyproject.toml
- requirements-dev.txt
- requirements.txt
- src/backend/metal/CMakeLists.txt
- src/op/copy.cc
- src/op/copy.h
- src/op/fill.cc
- src/op/gemm.cc
- src/op/gemm.h
- src/op/parallel.cc
- src/op/utils.h
- src/target/codegen_metal.cc
- src/target/codegen_metal.h
- src/transform/layout_inference.cc
- src/transform/lower_device_storage_access_info.cc
- testing/python/jit/test_tilelang_jit_adapter_mps.py
- testing/python/metal/metal_internal_runtime_coverage.md
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- testing/python/metal/test_metal_internal_scaffolding.py
- testing/python/metal/test_metal_local_var.py
- testing/python/metal/test_metal_simdgroup_store.py
- tilelang/engine/lower.py
- tilelang/engine/phase.py
- tilelang/intrinsics/metal_macro_generator.py
- tilelang/jit/adapter/base.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/tileop/metal_gdn.py
- tilelang/tileop/metal_quant.py
- tilelang/tileop/metal_simdgroup.py
- tilelang/transform/decouple_type_cast.py
- tilelang/transform/metal_fragment_to_simdgroup.py
- tilelang/utils/language.py
```cmake
# Metal codegen is pure C++ and can generate Metal shader source on any
# platform. Always compile it so target.build.tilelang_metal is available for
# cross-compilation and source-level tests on non-Apple hosts.
list(APPEND TILE_LANG_SRCS
  src/target/codegen_metal.cc
)
```
🧩 Analysis chain
🏁 Script executed:
```shell
#!/bin/bash
set -euo pipefail
echo "== CMake block =="
sed -n '1,40p' src/backend/metal/CMakeLists.txt
echo
echo "== codegen_metal.cc includes =="
sed -n '1,50p' src/target/codegen_metal.cc | grep '^#include'
echo
echo "== metal_module.h presence in this checkout =="
fd -a 'metal_module\.h$' . || true
```

Repository: tile-ai/tilelang
Length of output: 1331
Unconditionally compiling src/target/codegen_metal.cc breaks non-Metal builds.
The CMakeLists.txt appends src/target/codegen_metal.cc to TILE_LANG_SRCS before the USE_METAL and APPLE guards take effect. However, this file includes runtime/metal/metal_module.h, which is not available on non-Metal or non-Apple platforms. This causes compilation to fail before the early return() statements can skip the Metal-specific configuration.
To fix: either move the source append behind the appropriate guards, or refactor codegen_metal.cc to isolate the Metal runtime dependencies from the codegen-only logic.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/backend/metal/CMakeLists.txt` around lines 3 - 8, The CMakeLists
currently unconditionally appends src/target/codegen_metal.cc to TILE_LANG_SRCS
causing build failures on non-Apple/Non-Metal hosts because
runtime/metal/metal_module.h is missing; fix by moving the list(APPEND
TILE_LANG_SRCS src/target/codegen_metal.cc) so it is executed only inside the
existing guards (USE_METAL and APPLE) in CMakeLists.txt, or alternatively
refactor the codegen_metal.cc file to split Metal runtime includes from pure
codegen logic (e.g., create a codegen_metal_core.cc without
runtime/metal/metal_module.h and keep runtime-dependent bits in a separate file)
and update TILE_LANG_SRCS to only add the runtime-dependent file under the
USE_METAL/APPLE guards so non-Metal builds no longer attempt to compile runtime
headers.
```cpp
bool CopyNode::CheckSIMDGroupCopy(Target target) const {
  if (!TargetIsMetal(target) || !IsSIMDGroupBuffer(src)) {
    return false;
  }
  if (!IsSharedBuffer(dst) && !IsGlobalBuffer(dst)) {
    return false;
  }
  if (src->dtype != dst->dtype) {
    return false;
  }
  if (src_range.size() != 2 || dst_range.size() != 2 ||
      dst->shape.size() != 2) {
    return false;
  }

  int total_elements = 1;
  for (auto extent : src->shape) {
    auto imm = extent.as<IntImmNode>();
    if (!imm) {
      return false;
    }
    total_elements *= imm->value;
  }
  if (total_elements % 64 != 0) {
    return false;
  }

  for (int i = 0; i < 2; ++i) {
    auto src_shape = src->shape[i].as<IntImmNode>();
    auto src_min = src_range[i]->min.as<IntImmNode>();
    auto src_extent = src_range[i]->extent.as<IntImmNode>();
    auto dst_extent = dst_range[i]->extent.as<IntImmNode>();
    if (!src_shape || !src_min || src_min->value != 0 || !src_extent ||
        !dst_extent || src_extent->value != src_shape->value ||
        src_extent->value != dst_extent->value || src_extent->value % 8 != 0) {
      return false;
    }
  }
  return true;
```
Guard Metal SIMD-group copy against edge-tile OOB stores.
This legality check never proves that dst_range is fully in-bounds, but LowerSIMDGroupCopy later emits unpredicated simdgroup_store calls. For boundary tiles on non-divisible shapes, this can select kMetalSIMDGroup and write past dst where LowerNormalCopy would have kept the bounds predicate. Please thread buffer_oob/analyzer-based in-bounds checks through this path before returning true.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/op/copy.cc` around lines 802 - 840, CopyNode::CheckSIMDGroupCopy
currently accepts Metal SIMD-group copies without proving dst_range is fully
in-bounds, which allows LowerSIMDGroupCopy to emit unpredicated simdgroup_store
and OOB stores; modify CheckSIMDGroupCopy to consult the
buffer_oob/analyzer-based in-bounds check (the same predicate used elsewhere)
for dst (and if relevant src) and only return true when the analyzer proves
dst_range is fully inside the dst buffer bounds; thread the analyzer/buffer_oob
check into this path (or call the existing helper used by other copy legality
checks) so that CheckSIMDGroupCopy refuses SIMD-group selection for boundary
tiles that might be OOB.
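The hazard in this finding is plain index arithmetic: a tile whose origin plus extent exceeds the buffer dimension writes out of bounds when stored without a predicate. A sketch under that framing (`tile_in_bounds` is an illustrative name, not a helper from the PR):

```python
def tile_in_bounds(tile_origin: int, tile_extent: int, dim_size: int) -> bool:
    """True iff a tile starting at tile_origin fits entirely inside the
    buffer dimension; the legality check should require this proof before
    selecting an unpredicated simdgroup_store."""
    return tile_origin + tile_extent <= dim_size


# A 100-row output tiled by 64: the second tile spans rows 64..127.
assert tile_in_bounds(0, 64, 100)        # first tile fits
assert not tile_in_bounds(64, 64, 100)   # unpredicated store would touch rows 100..127
```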
```cpp
  PrimExpr element_offset = 0;
  for (size_t i = 0; i < region.size(); ++i) {
    element_offset += region[i]->min * strides[i];
  }
  PrimExpr matrix_elements = IntImm(element_offset.dtype(), 64);
  ICHECK(
      analyzer->CanProveEqual(FloorMod(element_offset, matrix_elements), 0))
      << "simdgroup fill region must start on an 8x8 matrix boundary";
  PrimExpr matrix_index_base = FloorDiv(element_offset, matrix_elements);
  Array<Stmt> stmts;
  for (int i = 0; i < num_matrices; i++) {
    stmts.push_back(Evaluate(
        Call(DataType::Handle(), builtin::make_filled_simdgroup_matrix(),
             {dst->data, matrix_index_base + IntImm(DataType::Int(32), i),
              fill_value, IntImm(DataType::Int(32), 8),
              IntImm(DataType::Int(32), 8)})));
  }
  if (stmts.size() == 1)
    return stmts[0];
  return SeqStmt(stmts);
```
Reject non-contiguous regions before emitting consecutive matrix fills.
This path derives one matrix_index_base from element_offset and then fills num_matrices consecutive 8x8 matrices, but it never proves that region itself is contiguous in matrix order. A strided slice can pass the current % 64 == 0 and alignment checks yet still target non-consecutive elements, so this will overwrite the wrong matrices instead of the requested region. Please add a contiguity check here and fall back when the region is not a dense whole-matrix span.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/op/fill.cc` around lines 183 - 202, Before emitting consecutive matrix
fills, prove the region maps to a dense whole-matrix span: compute the region's
last element offset (e.g., last_element = sum((region[i]->min +
region[i]->extent - 1) * strides[i])) and use analyzer to check both that
FloorMod(element_offset, matrix_elements) == 0 (already present) and that
FloorMod(last_element + 1, matrix_elements) == 0 and FloorDiv(last_element,
matrix_elements) == matrix_index_base + IntImm(DataType::Int(32), num_matrices -
1); if any of these cannot be proven, do not run the consecutive Call loop that
uses matrix_index_base/num_matrices (the stmts loop) and instead fall back to
the non-consecutive/dense-element emission path. Reference variables: region,
strides, element_offset, last_element, matrix_elements, matrix_index_base,
num_matrices, analyzer->CanProveEqual, and the stmts/Call emission block.
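The proof obligations listed in the prompt above can be modeled outside the TVM C++ API. Below is a minimal Python sketch; the helper name `is_dense_matrix_span` and the plain-integer offsets are assumptions for illustration (the real code works on `PrimExpr`s through the analyzer):

```python
def is_dense_matrix_span(region_min, region_extent, strides, matrix_elements=64):
    """Prove a region covers a dense span of whole 8x8 (64-element) matrices.

    region_min / region_extent / strides are per-dimension lists in elements.
    Mirrors the checks the review asks for: boundary alignment at both ends
    plus contiguity (no gaps between first and last element).
    """
    first = sum(m * s for m, s in zip(region_min, strides))
    last = sum((m + e - 1) * s
               for m, e, s in zip(region_min, region_extent, strides))
    num_elements = 1
    for e in region_extent:
        num_elements *= e
    return (first % matrix_elements == 0            # starts on a matrix boundary
            and (last + 1) % matrix_elements == 0   # ends on a matrix boundary
            and last - first + 1 == num_elements)   # dense: no strided gaps

# A full 8x8 tile of an 8-wide buffer is one dense matrix: accepted.
assert is_dense_matrix_span([0, 0], [8, 8], [8, 1])
# The same 8x8 slice of a 16-wide buffer is strided: rejected.
assert not is_dense_matrix_span([0, 0], [8, 8], [16, 1])
```

Only when all three conditions can be proven would the consecutive `make_filled_simdgroup_matrix` loop be safe; otherwise the fill should take the non-consecutive fallback path.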
# local.var should lower to scalar declarations/stores rather than arrays or
# an unsupported storage scope.
assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2, src
assert re.search(r"\w+\s*=\s*3;", src), src
assert re.search(r"\w+\s*=\s*\(\w+ \+ 4\);", src), src
assert "local.var" not in src
assert "thread int" not in src
Loosen the zero-init assertion to match the test kernel.
_make_local_var_func() only guarantees one default-initialized local (y); x is explicitly initialized to 3. Requiring >= 2 therefore depends on incidental temporaries rather than the local.var behavior the test is meant to lock down, so it can fail on otherwise-valid codegen changes.
Suggested adjustment
- assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2, src
+ assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 1, src

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@testing/python/metal/test_metal_local_var.py` around lines 34 - 40, The
zero-init count assertion is too strict for _make_local_var_func() because it
only guarantees one default-initialized local (y) while x is explicitly
initialized to 3; update the assertion in test_metal_local_var.py that currently
uses "assert len(re.findall(r\"\\bint\\s+\\w+\\s*=\\s*0;\", src)) >= 2" to
accept a single zero-initialized local (e.g., change >= 2 to >= 1) so the test
verifies the presence of default-initialized locals without failing on valid
codegen that doesn't produce incidental temporaries.
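To see why >= 1 is the right bound, the assertions can be exercised against a hand-written stand-in for the generated Metal source. The kernel below is illustrative only, not actual codegen output:

```python
import re

# Hand-written stand-in for what _make_local_var_func() is expected to emit:
# x explicitly initialized to 3, y default-initialized to 0.
src = """
kernel void main_kernel(device int* A [[buffer(0)]]) {
  int x = 3;
  int y = 0;
  y = (x + 4);
  A[0] = y;
}
"""

zero_inits = re.findall(r"\bint\s+\w+\s*=\s*0;", src)
assert len(zero_inits) == 1          # only y is zero-initialized
assert len(zero_inits) >= 1          # loosened assertion passes
assert not len(zero_inits) >= 2      # original assertion would fail here
assert re.search(r"\w+\s*=\s*3;", src)
assert re.search(r"\w+\s*=\s*\(\w+ \+ 4\);", src)
```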
pass
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    return lambda: torch.device("mps")
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "Inspecting deferred device lookup in BaseKernelAdapter:"
sed -n '76,86p' tilelang/jit/adapter/base.py
echo
echo "Searching for test coverage of _lazy_init vs _cuda_getDevice failure:"
rg -n -C2 '_lazy_init|_cuda_getDevice' tilelang/jit/adapter/base.py testing/python/jit/test_tilelang_jit_adapter_mps.py

Repository: tile-ai/tilelang
Length of output: 2182
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Checking all tests in test_tilelang_jit_adapter_mps.py ==="
cat testing/python/jit/test_tilelang_jit_adapter_mps.py
echo
echo "=== Checking where device_functor() is called ==="
rg -n "device_functor\(\)" tilelang/jit/adapter/
echo
echo "=== Checking similar pattern in get_current_stream_functor ==="
sed -n '51,66p' tilelang/jit/adapter/base.py
echo
echo "=== Search for _cuda_getDevice usage beyond base.py ==="
rg -n "_cuda_getDevice" --type py

Repository: tile-ai/tilelang
Length of output: 3471
Move CUDA/MPS device-query logic into the returned thunk to ensure fallback handles deferred failures.
The except handler only wraps _lazy_init() during functor construction, but _cuda_getDevice() is only called when the returned lambda executes later. If the device query fails at that time, the exception bypasses the fallback logic entirely.
Suggested fix
`@staticmethod`
def get_current_device_functor() -> Callable[[], torch.device]:
"""Return a callable that yields Torch's current device.
Similar to the stream functor, we capture a callable that, when called,
fetches the current device according to PyTorch. On CPU or when CUDA is
unavailable, returns ``torch.device('cpu')``.
"""
- if torch.cuda.is_available():
- try:
- torch.cuda._lazy_init()
- current_device = torch._C._cuda_getDevice
- return lambda: torch.device("cuda", current_device())
- except Exception:
- pass
- if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
- return lambda: torch.device("mps")
- # CPU fallback
- return lambda: torch.device("cpu")
+ def _current_device() -> torch.device:
+ if torch.cuda.is_available():
+ try:
+ torch.cuda._lazy_init()
+ return torch.device("cuda", torch._C._cuda_getDevice())
+ except Exception:
+ pass
+ if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
+ return torch.device("mps")
+ return torch.device("cpu")
+
+ return _current_device

Add a regression test that monkeypatches torch._C._cuda_getDevice to raise when device_functor() is invoked, verifying fallback to MPS/CPU. Note: get_current_stream_functor has an identical structural issue.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/jit/adapter/base.py` around lines 82 - 84, The device-selection
try/except currently only wraps _lazy_init() at construction time, letting
exceptions from _cuda_getDevice() escape when the returned lambda is later
invoked; move the CUDA and MPS device-query logic inside the returned thunk (the
lambda returned by the device functor) so the try/except covers the actual
runtime call to torch._C._cuda_getDevice() and fallback to MPS/CPU; update both
the device functor and get_current_stream_functor to defer queries into their
returned closures and keep _lazy_init() where needed. Also add a regression test
that monkeypatches torch._C._cuda_getDevice to raise when the thunk is executed
and assert the device_functor()/get_current_stream_functor() fall back to MPS or
CPU as expected.
Summary
Routes `T.gemm` with FP8 input operands on the Metal target to the scalar dequant-multiply-accumulate fallback. Apple Silicon (M1-M4) has no native FP8 ALU, so simdgroup_multiply_accumulate cannot lower FP8 inputs directly; the scalar path correctly emits per-element dequant via `__tvm_fp8_e4m3_to_half` (or `__tvm_fp8_e5m2_to_half`) helpers and accumulates in fp32.

Two changes:

- `tilelang/tileop/gemm/__init__.py::Gemm._select_gemm_instruction` — the Metal branch detects FP8 input via the new `_has_fp8_input_dtype` helper and routes to `GemmInst.Scalar` instead of `GemmInst.METAL_SIMDGROUP`.
- `tilelang/transform/metal_fragment_to_simdgroup.py` — the visitor short-circuits on FP8 inputs so the simdgroup-fragment rewrite skips FP8 GEMMs (they stay scalar).
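The routing in the first change can be sketched as plain Python. This is a hedged illustration, not the PR's code: the enum values and the `_has_fp8_input_dtype` name come from the PR text, but the dtype representation (strings) and the function signatures are assumptions:

```python
# Illustrative sketch of the Metal FP8 routing; dtype strings are assumed.
FP8_DTYPES = {"float8_e4m3", "float8_e5m2"}

def _has_fp8_input_dtype(a_dtype: str, b_dtype: str) -> bool:
    # FP8 on either input operand forces the scalar path on Metal.
    return a_dtype in FP8_DTYPES or b_dtype in FP8_DTYPES

def select_gemm_instruction(target: str, a_dtype: str, b_dtype: str) -> str:
    if target == "metal":
        if _has_fp8_input_dtype(a_dtype, b_dtype):
            return "Scalar"           # no FP8 ALU: dequant-MAC fallback
        return "METAL_SIMDGROUP"      # native simdgroup 8x8 path
    return "MMA"                      # non-Metal targets are untouched here

assert select_gemm_instruction("metal", "float8_e4m3", "float16") == "Scalar"
assert select_gemm_instruction("metal", "float16", "float16") == "METAL_SIMDGROUP"
```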
Why

cppmega.mlx sparse-MLA-FP8 attention probes need an FP8 GEMM that lowers correctly on Metal. Without this routing, `T.gemm(A_fp8, B_fp8, C_fp32)` either produces invalid MSL (simdgroup path) or fails dispatch outright. The scalar fallback, paired with the storage-only FP8 emulation patches (separate PRs), gives a software FP8 GEMM that's slower than CUDA's native FP8 tensor cores but functionally correct on Apple Silicon today. M5/M6 will revisit native FP8.
Stacking topology

This PR is based on `jorgecurious/tilelang:metal-gemm-upstream-rebase` (PR #2130) at HEAD `971c17b`, which itself stacks on top of:

Runtime prereq (NOT applied here): the storage-only FP8 emulation patch that introduces the `__tvm_fp8_e4m3_to_half` / `__tvm_fp8_e5m2_to_half` helper functions on Metal. Without that, this dispatcher routing has nothing to dequant through. That patch lives at `cppmega.mlx:docs/upstream/tilelang_metal_fp8/` and will be filed once the TileLang/tvm-mirror split is settled.
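For reference, the arithmetic such a dequant helper performs can be sketched in Python. The real `__tvm_fp8_e4m3_to_half` is an MSL function living in the separate patch; this sketch only follows the OCP E4M3(FN) layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities):

```python
def fp8_e4m3_to_float(byte: int) -> float:
    """Decode one OCP float8 E4M3FN byte to a Python float.

    Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
    E4M3FN has no infinities; exponent 0xF with mantissa 0x7 is NaN.
    """
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")
    if exp == 0:                                  # subnormal
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

assert fp8_e4m3_to_float(0x38) == 1.0   # exponent field 7 (the bias), mantissa 0
assert fp8_e4m3_to_float(0x40) == 2.0
assert fp8_e4m3_to_float(0xC0) == -2.0
```

The scalar GEMM fallback applies this kind of decode per element before the fp32 multiply-accumulate, which is why it is functionally correct but much slower than a native FP8 tensor-core path.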
Test plan

Test the dispatcher path with an FP8-input `@T.prim_func`: `tilelang.engine.lower.lower(fp8_gemm, target="metal")` should succeed (with the storage-only FP8 patch as runtime prereq).

Caveats
Attribution
Co-developed with cppmega.mlx for Apple-Silicon Metal MLA kernel ports.

Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Tests
Documentation