[Metal] allow mixed-dtype T.gemm via scalar fallback #2139

apstenku123 wants to merge 11 commits into tile-ai:main from
Conversation
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations (simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon (M1–M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: fork TVM Metal codegen into tilelang with simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared×shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: pass that rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment→device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
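As an illustration of what this enables, a tiled same-dtype GEMM written in the TileLang DSL can now compile for Metal. The sketch below follows the standard TileLang matmul idiom; the block sizes and the `num_stages=1` pipeline are illustrative choices, not values taken from this PR:

```python
import tilelang
import tilelang.language as T

@tilelang.jit
def matmul(M, N, K, block_M=64, block_N=64, block_K=32, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype),
             C: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            # Accumulator fragment; on Metal it is rewritten to metal.simdgroup scope.
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)  # lowered to simdgroup 8x8 MMA
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```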
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's pre-submission checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

This pull request adds comprehensive Metal (MPS) backend support to TileLang, including a Metal code generator, simdgroup-based GEMM lowering, a fragment-to-simdgroup IR transformation, register tile abstractions, quantization and GDN helpers, Metal-specific copy/fill operations, and extensive testing for codegen and runtime correctness. It also updates dependency constraints for macOS and integrates Metal device selection into the JIT adapter fallback.

Changes

Metal Backend Implementation
Metal-Specific GEMM & Tiling Infrastructure
Specialized Kernels & Helpers
Testing & Documentation
Dependency Updates
Sequence Diagram(s)

The PR introduces sufficient novelty and multi-component interaction (JIT compiler → IR transforms → code generator → Metal runtime) to benefit from a high-level sequence diagram:

```mermaid
sequenceDiagram
participant User as User/JIT
participant Lower as Lowering Pipeline
participant Fragment as MetalFragmentToSimdgroup
participant Codegen as CodeGenTileLangMetal
participant Metal as Metal Runtime
User->>Lower: tilelang.jit(kernel)
activate Lower
Lower->>Fragment: IR module with<br/>local.fragment C buffers
activate Fragment
Fragment->>Fragment: Detect same-dtype GEMM,<br/>rewrite to metal.simdgroup
Fragment-->>Lower: Updated IR module
deactivate Fragment
Lower->>Codegen: PrimFunc with<br/>simdgroup allocations
activate Codegen
Codegen->>Codegen: Emit Metal kernel signature<br/>Type mapping<br/>simdgroup 8x8 matrices
Codegen->>Codegen: Emit builtin calls:<br/>simdgroup_load/store/mma
Codegen-->>Lower: Metal shader source
deactivate Codegen
Lower->>Lower: Register FFI entry<br/>target.build.tilelang_metal
deactivate Lower
User->>Metal: Execute compiled kernel<br/>on MPS device
activate Metal
Metal->>Metal: threadgroup barrier sync<br/>8x8 matrix operations
Metal-->>User: Output tensor
    deactivate Metal
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~65 minutes
Files documenting the actual PRs we just opened upstream:

- PR #1: ml-explore/mlx#3476 — from_dlpack Metal-aware consumer (against main, clean)
- PR #2: apache/tvm#19504 — TVM_METAL_STORAGE_MODE env opt-in (against main, clean)
- PR #3: tile-ai/tilelang#2139 — mixed-dtype T.gemm via scalar fallback (stacks on PR #2130)
- PR #4: tile-ai/tilelang#2140 — FP8-input T.gemm scalar fallback routing (stacks on PR #2130)
- PR #5: tile-ai/tilelang#2141 — T.Pipelined num_stages>1 3D buffer fix (stacks on PR #2130)
- PR #6: tile-ai/tilelang#2142 — T.fp8_scaled_matmul DSL intrinsic (stacks on PR #2130)

Deferred (split into companion PRs needed): tilelang_metal_fp8 and tilelang_metal_fp8_vector each touch both the tilelang supermodule and the TileLang/tvm vendored submodule. These need 2 PRs each — one to tile-ai/tilelang, one to TileLang/tvm — in a separate filing round.

PRs #3–#6 are independent of each other; each branches directly from jorgecurious/tilelang:metal-gemm-upstream-rebase HEAD 971c17b, so they can be reviewed in any order. They DO depend on the upstream 4-PR Apple Metal landing chain (#1869, #2118, #2121, #2130) merging first; if any of those land separately, ours can be retargeted at main.
Pull request overview
Enables mixed-dtype T.gemm lowering on the Metal backend by routing mixed-input GEMMs to the scalar fallback while keeping same-dtype GEMMs on the simdgroup MMA path, so chained attention-style GEMMs (e.g., fp16×fp16→fp32 followed by fp32×fp16→fp32) can compile to MSL.
Changes:
- Add Metal simdgroup GEMM instruction selection and a Metal GEMM lowering path, while forcing mixed-input-dtype GEMMs onto the scalar fallback on Metal.
- Introduce `metal.simdgroup` infrastructure across Python + C++ (scope detection, fragment→simdgroup rewrite, simdgroup fill/copy lowering, and Metal codegen plumbing).
- Add/expand Metal codegen + runtime tests and internal scaffolding probes; add a Metal matmul benchmark script.
Reviewed changes
Copilot reviewed 38 out of 39 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| tilelang/utils/language.py | Adds is_metal_simdgroup scope predicate. |
| tilelang/transform/metal_fragment_to_simdgroup.py | New pass to rewrite select GEMM accumulators from local.fragment to metal.simdgroup on Metal. |
| tilelang/transform/decouple_type_cast.py | Treats metal.simdgroup buffers as “local” for cast-decoupling logic. |
| tilelang/tileop/metal_simdgroup.py | Adds internal RegisterTile/RowVector helpers and simdgroup macros for Metal kernels. |
| tilelang/tileop/metal_quant.py | Adds internal fp8/fp4/e8m0 decode helpers and tile selectors for quant probes. |
| tilelang/tileop/metal_gdn.py | Adds internal GDN/attention-style macros built on simdgroup helpers. |
| tilelang/tileop/gemm/inst.py | Adds METAL_SIMDGROUP GEMM instruction kind. |
| tilelang/tileop/gemm/gemm_metal.py | New Metal simdgroup GEMM lowering implementation. |
| tilelang/tileop/gemm/gemm_base.py | Allows mixed A/B dtypes and adds a mixed-input-dtype predicate. |
| tilelang/tileop/gemm/__init__.py | Metal-specific GEMM dispatch: mixed dtypes → scalar fallback; otherwise simdgroup. |
| tilelang/jit/adapter/torch/metal.py | Exposes Metal kernel source and compiles via torch.mps.compile_shader. |
| tilelang/jit/adapter/base.py | Updates device selection to prefer MPS when CUDA is unavailable/failed. |
| tilelang/intrinsics/metal_macro_generator.py | Adds MPSIntrinEmitter for simdgroup load/store/MMA macro emission. |
| tilelang/engine/phase.py | Inserts Metal fragment→simdgroup rewrite into the lowering pipeline. |
| tilelang/engine/lower.py | Switches Metal build hook to target.build.tilelang_metal. |
| testing/python/metal/test_metal_simdgroup_store.py | Adds tests for direct simdgroup-store-to-device GEMM path. |
| testing/python/metal/test_metal_local_var.py | Adds focused tests for local.var scalar codegen/runtime on Metal. |
| testing/python/metal/test_metal_internal_scaffolding.py | Adds internal-only source-boundary + runtime probes for Metal helpers/quant/GDN. |
| testing/python/metal/test_metal_gemm_v2.py | Adds Metal GEMM v2 runtime correctness tests. |
| testing/python/metal/test_metal_gemm_v2_linux.py | Adds cross-platform Metal GEMM v2 codegen tests. |
| testing/python/metal/test_metal_codegen_linux.py | Adds chained attention mixed-dtype Metal codegen test coverage. |
| testing/python/metal/metal_internal_runtime_coverage.md | Documents internal Metal runtime/source-boundary coverage. |
| testing/python/jit/test_tilelang_jit_adapter_mps.py | Adds tests for MPS preference in JIT adapter device selection. |
| src/transform/lower_device_storage_access_info.cc | Treats .fragment scope tag as a special case in storage access lowering. |
| src/transform/layout_inference.cc | Skips fragment-layout assertion on Metal targets. |
| src/target/codegen_metal.h | Adds TileLang Metal codegen class declaration. |
| src/target/codegen_metal.cc | Adds TileLang Metal codegen implementation and registers build hook. |
| src/op/utils.h | Adds helpers for detecting simdgroup/register buffers. |
| src/op/parallel.cc | Makes fragment layout usage more defensive when layout info is absent. |
| src/op/gemm.h | Adds Metal simdgroup GEMM enum value and string conversion. |
| src/op/gemm.cc | Selects Metal simdgroup GEMM inst; adjusts warp partitioning for Metal. |
| src/op/fill.cc | Adds metal.simdgroup fill lowering via make_filled_simdgroup_matrix. |
| src/op/copy.h | Adds Metal simdgroup copy instruction and lowering declarations. |
| src/op/copy.cc | Adds simdgroup store lowering for T.copy from simdgroup to shared/global. |
| src/backend/metal/CMakeLists.txt | Always builds Metal codegen; enables codegen-only mode on non-Apple hosts. |
| requirements.txt | Adds Darwin-specific apache-tvm-ffi upper bound. |
| requirements-dev.txt | Adds Darwin-specific apache-tvm-ffi upper bound for dev installs. |
| pyproject.toml | Adds Darwin-specific apache-tvm-ffi upper bound for packaging. |
| benchmark/matmul_metal/benchmark_matmul_metal.py | Adds a Metal GEMM benchmark script using the simdgroup path. |
```cpp
class CodeGenTileLangMetal final : public CodeGenC {
public:
  explicit CodeGenTileLangMetal(Target target);
  // override print thread tag.
  void PrintArgUnionDecl();
  void AddFunction(const GlobalVar &gvar, const PrimFunc &func) final;
  void InitFuncState(const PrimFunc &f) final;
```
```cpp
}
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
  if (num_warps % m != 0)
```
```python
if torch.cuda.is_available():
    try:
        torch.cuda._lazy_init()
        current_device = torch._C._cuda_getDevice
        return lambda: torch.device("cuda", current_device())
    except Exception:
        return lambda: torch.device("cuda", torch.cuda.current_device())
    pass
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    return lambda: torch.device("mps")
# CPU fallback
```
```python
def _rewrite_scope(body, var_map):
    buf_map = {}

    def _pre_order(stmt):
        if isinstance(stmt, tir.Block):
            new_alloc_bufs = []
            changed = False
            for buf in stmt.alloc_buffers:
                new_buf = _remap_buffer(buf, var_map)
                new_alloc_bufs.append(new_buf)
                if not new_buf.same_as(buf):
                    buf_map[buf] = new_buf
                    changed = True
            if changed:
                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
                new_block = tir.Block(
                    stmt.iter_vars,
                    stmt.reads,
                    stmt.writes,
                    stmt.name_hint,
                    new_body,
                    stmt.init,
                    new_alloc_bufs,
                    stmt.match_buffers,
                    stmt.annotations,
                )
                return (
                    tir.BlockRealize(
                        stmt.iter_vars,
                        tir.const(True, "bool"),
                        new_block,
                    )
                    if False
                    else new_block
                )
```
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/tileop/gemm/gemm_base.py (1)

80-110: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift: Don't silently reinterpret mixed inputs as `A.dtype`.

`in_dtype` no longer rejects `A.dtype != B.dtype`, but the typed backends still wire `self.in_dtype` into both `a_dtype` and `b_dtype`. That means a mixed GEMM outside the Metal scalar fallback can now be emitted as if both operands had `A.dtype`, which is a miscompile where the old code failed fast. I'd keep `in_dtype` strict and expose a separate mixed-dtype path/property for the scalar fallback instead.
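For concreteness, the strict shape being asked for might look like the sketch below; `scalar_fallback_in_dtype` and the use of `accum_dtype` are illustrative names suggested by this review, not existing API:

```python
@property
def in_dtype(self):
    # Keep the historical strictness: never pretend mixed operands share a dtype.
    if is_tensor_memory(self.A):
        return self.B.dtype
    assert self.A.dtype == self.B.dtype, \
        "mixed input dtypes must be routed through the scalar fallback"
    return self.A.dtype

@property
def scalar_fallback_in_dtype(self):
    # Illustrative: the dtype the scalar fallback promotes per-element loads to.
    return self.accum_dtype
```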
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/gemm/gemm_base.py` around lines 80 - 110, The in_dtype property must remain strict: restore behavior so it does not silently pretend both operands share A.dtype when A.dtype != B.dtype — if is_tensor_memory(self.A) return self.B.dtype as before, otherwise ensure A.dtype == B.dtype (raise or assert/emit an explicit error) instead of returning A.dtype unconditionally; add a separate property (e.g., mixed_input_dtype or scalar_fallback_in_dtype) that returns the appropriate dtype for scalar fallbacks and keep has_mixed_input_dtype (or update it) to signal mixed-precision so dispatchers (Gemm._select_gemm_instruction and backends that currently read self.in_dtype for both a_dtype and b_dtype) can route to the scalar fallback instead of emitting a single-typed MMA intrinsic.

src/backend/metal/CMakeLists.txt (1)
6-18: ⚠️ Potential issue | 🔴 Critical: Remove the non-existent `runtime/metal/metal_module.h` include or fix the reference; this will break compilation on all platforms.

The unconditional append of `src/target/codegen_metal.cc` to `TILE_LANG_SRCS` at lines 6–8 will immediately fail on every platform because line 35 of `codegen_metal.cc` includes `"runtime/metal/metal_module.h"`, which does not exist anywhere in the repository. The early return for non-Apple platforms at line 18 cannot prevent this compilation error.

Either remove the include statement or ensure the header file is properly available before this change can be merged.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/backend/metal/CMakeLists.txt` around lines 6 - 18, The build unconditionally appends src/target/codegen_metal.cc to TILE_LANG_SRCS but that source includes "runtime/metal/metal_module.h" which doesn't exist, causing compile failures; fix by making the append conditional (e.g., move or wrap the list(APPEND TILE_LANG_SRCS src/target/codegen_metal.cc) behind the APPLE check or add an if(APPLE) ... endif() around it) or alternatively remove/fix the include inside src/target/codegen_metal.cc to point to an existing header or add the missing runtime/metal/metal_module.h; reference TILE_LANG_SRCS, src/target/codegen_metal.cc, and the include "runtime/metal/metal_module.h" when making the change.
🧹 Nitpick comments (8)
tilelang/jit/adapter/base.py (1)
69-86: 💤 Low value: Docstring is outdated after MPS fallback addition.

The docstring at line 74 states the method "returns `torch.device('cpu')`" when CUDA is unavailable, but the implementation now returns MPS when available. Consider updating the docstring to reflect the new fallback chain: CUDA → MPS → CPU.

📝 Suggested docstring update

```diff
 @staticmethod
 def get_current_device_functor() -> Callable[[], torch.device]:
     """Return a callable that yields Torch's current device.

     Similar to the stream functor, we capture a callable that, when called,
-    fetches the current device according to PyTorch. On CPU or when CUDA is
-    unavailable, returns ``torch.device('cpu')``.
+    fetches the current device according to PyTorch. Falls back through
+    CUDA → MPS → CPU based on availability.
     """
```
Verify each finding against the current code and only fix it if needed. In `@tilelang/jit/adapter/base.py` around lines 69 - 86, Update the get_current_device_functor docstring to reflect the actual fallback chain used by the implementation: first try CUDA, then MPS, then CPU; replace the outdated sentence that says "returns torch.device('cpu')" when CUDA is unavailable with a concise description like "returns CUDA if available, otherwise MPS if available, otherwise CPU" and keep the rest of the docstring behavior notes (callable that yields Torch's current device).

requirements.txt (1)
4-4: 💤 Low value: Document the rationale for `apache-tvm-ffi<0.1.8` on Darwin.

The constraint is applied consistently across `requirements.txt`, `requirements-dev.txt`, and `pyproject.toml`, but none of them explain why `>=0.1.8` is incompatible on macOS. The `pyproject.toml` already has a precedent for this pattern (it documents why `>=0.1.6` is needed via tilelang#1502). Without a similar comment here, it's unclear when this upper bound can safely be relaxed, and macOS users are silently prevented from picking up any bug/security fixes shipped in `0.1.8+`.

💬 Suggested comment

```diff
-apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
+# <0.1.8 on macOS: 0.1.8 introduces an RFC Error-ABI change that breaks
+# TileLang's Metal FFI entrypoints (see <link-to-issue>).
+apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
```

(Mirror the comment in `pyproject.toml` and `requirements-dev.txt`.)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@requirements.txt` at line 4, The package constraint "apache-tvm-ffi<0.1.8; platform_system == 'Darwin'" lacks an explanatory comment; add a brief rationale next to this requirement (and mirror it in requirements-dev and pyproject entries) that states why versions >=0.1.8 are incompatible on macOS, references the upstream issue or PR (e.g., tilelang#1502-style link or ticket number), and describes the condition under which the upper bound can be relaxed (what fix/version to wait for) so macOS users know when it's safe to upgrade; ensure the comment appears immediately adjacent to the "apache-tvm-ffi<0.1.8" spec so it's clear which constraint it documents.

src/transform/lower_device_storage_access_info.cc (1)
47-49: Add a brief comment explaining why `.fragment` scope is excluded.

The exclusion of `scope.tag != ".fragment"` is correct—fragment buffers must survive this pass on all targets, not just Metal, before being transformed by target-specific lowering (e.g., `metal_fragment_to_simdgroup` for Metal, or consumed directly for CUDA). However, CUDA kernels do use `alloc_fragment()` (confirmed in tests like `test_tilelang_language_wgmma_gemm.py`), so the condition converts a previous loud ICHECK failure into a silent skip. A comment documenting this intent preserves clarity for future maintainers:

Suggested inline comment

```diff
     scope.tag != ".barrier" && scope.tag != ".cluster_barrier" &&
+    // Fragment buffers must survive this pass for target-specific lowering phases.
     scope.tag != ".fragment" && scope.tag.find(".descriptor") != 0) {
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/lower_device_storage_access_info.cc` around lines 47 - 49, Add a brief inline comment next to the condition that checks scope.tag (the if that includes scope.tag != ".fragment") explaining that ".fragment" buffers are intentionally excluded because fragment storage must survive this lowering pass for all backends — Metal performs target-specific lowering (e.g., metal_fragment_to_simdgroup) later while CUDA may use alloc_fragment() and consume fragments directly — so we skip lowering here rather than asserting/failing; reference the scope.tag check and the ".fragment" exclusion in the comment for future maintainers.

testing/python/metal/test_metal_codegen_linux.py (1)
112-119: ⚡ Quick win: Make this assert the mixed-dtype fallback path, not just successful codegen.

Right now the test only proves that some Metal kernel was emitted. A regression that routes `T.gemm(S_local, V_shared, O_local)` through the wrong lowering path would still pass. Please add one assertion that is specific to the scalar mixed-dtype fallback/cast behavior so this actually guards the dispatcher change.
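For instance, a check along these lines; the helper name and the cast tokens are assumptions about the scalar fallback's output and would need adjusting to the actual emitter:

```python
src = attention_chain_mixed_dtype_kernel.get_kernel_source()  # helper name assumed
# The same-dtype QK GEMM should still use the simdgroup MMA path...
assert "simdgroup_multiply_accumulate" in src
# ...while the mixed-dtype PV GEMM should go through per-element promotion.
# The exact cast spelling depends on the emitter; adjust the token as needed.
assert "(float)" in src or "float(" in src
```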
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_codegen_linux.py` around lines 112 - 119, The test currently only checks that a Metal kernel was emitted; instead assert the mixed-dtype scalar fallback path by (1) ensuring the generated kernel source does NOT contain the high-level optimized GEMM lowering (e.g., the string "T.gemm(S_local, V_shared, O_local)") and (2) asserting it contains the scalar-cast/fallback marker used by the mixed-dtype path (e.g., a convert/cast token such as "convert(" or "as_type(") so the test specifically verifies the elementwise scalar cast/fallback lowering for attention_chain_mixed_dtype().

testing/python/metal/test_metal_internal_scaffolding.py (2)
425-453: ⚡ Quick win: These source assertions are overfit to the current pretty-printer.

Checks like `float kkt_bias = 0.000000e+00f;` and `gate_state = 1.000000e+00f;` depend on exact symbol names and float formatting. A benign SSA rename or constant-printing tweak will fail the tests without changing the lowering semantics. Regexing for the local-var initialization pattern would keep the intent while making the tests much less brittle.

Also applies to: 456-474
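A regex along these lines would keep the intent without pinning the printer's exact output (sketch, assuming `src` holds the generated MSL):

```python
import re

# Match "float <name> = <float literal>f;" without pinning the symbol name
# or the printer's scientific-notation formatting.
local_init = re.compile(
    r"\bfloat\s+\w+\s*=\s*[0-9]+(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?f\s*;")
assert local_init.search(src), "expected a float local-var initialization"
```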
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_internal_scaffolding.py` around lines 425 - 453, The tests test_flashqla_gdn_kkt_probe_combines_local_var_state_and_simdgroup_boundary and test_scaled_packed_quant_and_gdn_probes_source_boundary_tokens are asserting exact symbol names and float formatting (e.g., "float kkt_bias = 0.000000e+00f;" and "gate_state = 1.000000e+00f;"), which is brittle; change those assertions to use regex/token-pattern checks that verify a local float variable is declared/initialized (e.g., match a "float <ident> = <float literal>;" pattern) and that an assignment of a float literal to a gate/state variable occurs, and similarly relax any exact vector type/format checks to look for type families (e.g., no "float8"/"float4" present) or counts rather than exact formatting so the tests validate semantics without depending on pretty-printer output.
413-423: ⚡ Quick win: Narrow the native-fp8/fp4 negative checks.

The `assert "float8" not in ...` / `assert "float4" not in ...` checks are broader than the behavior you're trying to lock down. `float4` is a valid Metal vector type, and `simdgroup_float8x8` also contains the `float8` substring, so unrelated codegen changes can trip these tests. Match the unsupported dtype spellings (`float8_e4m3fn`, `float4_e2m1fn`, etc.) or the tensor boundary types instead.

Also applies to: 436-463
testing/python/metal/test_metal_gemm_v2_linux.py (1)
14-19: ⚡ Quick win: Add a mixed-input-dtype coverage case for this PR path.

`A` and `B` are always instantiated with the same `dtype` here, so this suite never exercises the dispatcher relaxation for mixed-input GEMMs. A regression in the fp16×fp16→fp32 / fp32×fp16→fp32 scalar-fallback path would still pass all of these tests.

Also applies to: 39-78
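One way to cover it, sketched below; note that `matmul_gemm_v2` currently takes a single dtype, so the split `a_dtype`/`b_dtype` parameters are the suggested signature change, not the existing one:

```python
import pytest

@pytest.mark.parametrize(
    "a_dtype,b_dtype",
    [
        ("float16", "float16"),  # same dtype: simdgroup path
        ("float32", "float16"),  # mixed: must hit the scalar fallback
    ],
)
def test_gemm_v2_input_dtypes(a_dtype, b_dtype):
    kernel = matmul_gemm_v2(
        256, 256, 64, a_dtype=a_dtype, b_dtype=b_dtype, accum_dtype="float32")
    assert "kernel void" in kernel.get_kernel_source()
```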
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 14 - 19, The current matmul_gemm_v2 test only instantiates A and B with the same dtype so it never exercises mixed-input-dtype dispatcher relaxation; update the test generation to include at least one mixed-dtype case (e.g., A float32 & B float16 and/or A float16 & B float32) while keeping accum_dtype=float32, by modifying matmul_gemm_v2 (and its usages) to accept/parametrize separate input dtypes for A and B (referencing the matmul_gemm_v2 function and its inner prim_func main and tensor params A, B, C) and add a new test variant that constructs A and B with differing dtypes to exercise the fp16×fp16→fp32 vs fp32×fp16→fp32 dispatcher paths.testing/python/metal/test_metal_simdgroup_store.py (1)
72-76: ⚡ Quick win: Avoid pinning the codegen test to the `C_local` symbol name.

`C_local` is an emitter detail, not a stable API. A harmless rename in simplify/SSA will break this test even if the direct `simdgroup_store` lowering is still correct. Prefer asserting on structural patterns (`simdgroup_store(` present, no extra C round-trip) instead of the temporary's exact name.
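Concretely, a sketch of the structural check (assuming `kernel` exposes the generated source):

```python
# Assert on the lowering structure rather than the temporary's name.
src = kernel.get_kernel_source()
assert "simdgroup_store(" in src, "direct simdgroup store not emitted"
# A/B operands legitimately use simdgroup_load before the store; only a load
# *after* the final store would indicate an accumulator round-trip.
tail = src[src.rindex("simdgroup_store("):]
assert "simdgroup_load(" not in tail, "unexpected reload after the store"
```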
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_simdgroup_store.py` around lines 72 - 76, The test is brittle because it looks for the temporary name `C_local`; instead change the assertions to check for the lowering pattern itself: assert that the source contains at least one "simdgroup_store(" (to ensure the lowering emitted the store) and assert there are no "simdgroup_load(" occurrences that indicate an unwanted extra round-trip through shared memory; remove any checks that specifically reference the `C_local` symbol and use the presence/absence of "simdgroup_store(" and "simdgroup_load(" patterns to validate behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/transform/layout_inference.cc`:
- Around line 436-445: The current Metal-target bypass skips validating all
fragment buffers when TargetIsMetal(target_) is true, which also skips fragments
that still require inferred layouts (e.g., mixed-dtype/scalar-fallback GEMM
fragments); change the logic in the loop over use_list_ (symbols: use_list_,
IsFragmentBuffer, layout_map) to only skip validation for buffers that were
actually rewritten to simdgroup (detectable by their storage scope or a
metal.simdgroup marker, e.g., check buffer->storage_scope == "metal.simdgroup"
or use a new IsSimdgroupBuffer helper) instead of using TargetIsMetal(target_),
and ensure any remaining local.fragment buffers still fail the ICHECK if
layout_map.count(buffer) == 0 so their layouts are inferred/validated (this
preserves MetalFragmentToSimdgroup behavior while keeping validation for
fragments that remain).
In `@tilelang/tileop/gemm/__init__.py`:
- Around line 180-189: The current Metal branch unconditionally returns
GemmInst.METAL_SIMDGROUP for same-dtype Metal GEMMs, bypassing the normal
selector and legality checks; instead call the existing selector
GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target)) to get the
default instance, then only override to GemmInst.Scalar when
target_is_metal(target) and self._has_mixed_input_dtype() is true; otherwise
return the selector's result so non-simdgroup-legal cases can fall back normally
(refer to target_is_metal, _has_mixed_input_dtype, GemmInst.METAL_SIMDGROUP,
GemmInst.Scalar, GemmInst(_ffi_api.GemmGetGemmInst) and GemmMetal.lower).
In `@tilelang/transform/metal_fragment_to_simdgroup.py`:
- Around line 126-161: The current rewrite only substitutes stmt.body when
buffers are remapped, leaving stmt.init, stmt.reads, stmt.writes and
stmt.match_buffers still pointing to old buffer objects (buf_map is built but
unused). Update the tir.Block handling (the loop that builds new_alloc_bufs
using _remap_buffer, var_map and buf_map) to also remap the block signature:
apply substitutions or rebuild stmt.init, stmt.reads, stmt.writes and
stmt.match_buffers to reference new_bufs (using buf_map and
tir.stmt_functor.substitute where appropriate) and use those remapped values
when constructing new_block (and the BlockRealize branch) so the block metadata
is consistent with alloc_buffers.
---
Outside diff comments:
In `@src/backend/metal/CMakeLists.txt`:
- Around line 6-18: The build unconditionally appends
src/target/codegen_metal.cc to TILE_LANG_SRCS but that source includes
"runtime/metal/metal_module.h" which doesn't exist, causing compile failures;
fix by making the append conditional (e.g., move or wrap the list(APPEND
TILE_LANG_SRCS src/target/codegen_metal.cc) behind the APPLE check or add an
if(APPLE) ... endif() around it) or alternatively remove/fix the include inside
src/target/codegen_metal.cc to point to an existing header or add the missing
runtime/metal/metal_module.h; reference TILE_LANG_SRCS,
src/target/codegen_metal.cc, and the include "runtime/metal/metal_module.h" when
making the change.
In `@tilelang/tileop/gemm/gemm_base.py`:
- Around line 80-110: The in_dtype property must remain strict: restore behavior
so it does not silently pretend both operands share A.dtype when A.dtype !=
B.dtype — if is_tensor_memory(self.A) return self.B.dtype as before, otherwise
ensure A.dtype == B.dtype (raise or assert/emit an explicit error) instead of
returning A.dtype unconditionally; add a separate property (e.g.,
mixed_input_dtype or scalar_fallback_in_dtype) that returns the appropriate
dtype for scalar fallbacks and keep has_mixed_input_dtype (or update it) to
signal mixed-precision so dispatchers (Gemm._select_gemm_instruction and
backends that currently read self.in_dtype for both a_dtype and b_dtype) can
route to the scalar fallback instead of emitting a single-typed MMA intrinsic.
---
Nitpick comments:
In `@requirements.txt`:
- Line 4: The package constraint "apache-tvm-ffi<0.1.8; platform_system ==
'Darwin'" lacks an explanatory comment; add a brief rationale next to this
requirement (and mirror it in requirements-dev and pyproject entries) that
states why versions >=0.1.8 are incompatible on macOS, references the upstream
issue or PR (e.g., tilelang#1502-style link or ticket number), and describes the
condition under which the upper bound can be relaxed (what fix/version to wait
for) so macOS users know when it's safe to upgrade; ensure the comment appears
immediately adjacent to the "apache-tvm-ffi<0.1.8" spec so it's clear which
constraint it documents.
In `@src/transform/lower_device_storage_access_info.cc`:
- Around line 47-49: Add a brief inline comment next to the condition that
checks scope.tag (the if that includes scope.tag != ".fragment") explaining that
".fragment" buffers are intentionally excluded because fragment storage must
survive this lowering pass for all backends — Metal performs target-specific
lowering (e.g., metal_fragment_to_simdgroup) later while CUDA may use
alloc_fragment() and consume fragments directly — so we skip lowering here
rather than asserting/failing; reference the scope.tag check and the ".fragment"
exclusion in the comment for future maintainers.
In `@testing/python/metal/test_metal_codegen_linux.py`:
- Around line 112-119: The test currently only checks that a Metal kernel was
emitted; instead assert the mixed-dtype scalar fallback path by (1) ensuring the
generated kernel source does NOT contain the high-level optimized GEMM lowering
(e.g., the string "T.gemm(S_local, V_shared, O_local)") and (2) asserting it
contains the scalar-cast/fallback marker used by the mixed-dtype path (e.g., a
convert/cast token such as "convert(" or "as_type(") so the test specifically
verifies the elementwise scalar cast/fallback lowering for
attention_chain_mixed_dtype().
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 14-19: The current matmul_gemm_v2 test only instantiates A and B
with the same dtype so it never exercises mixed-input-dtype dispatcher
relaxation; update the test generation to include at least one mixed-dtype case
(e.g., A float32 & B float16 and/or A float16 & B float32) while keeping
accum_dtype=float32, by modifying matmul_gemm_v2 (and its usages) to
accept/parametrize separate input dtypes for A and B (referencing the
matmul_gemm_v2 function and its inner prim_func main and tensor params A, B, C)
and add a new test variant that constructs A and B with differing dtypes to
exercise the fp16×fp16→fp32 vs fp32×fp16→fp32 dispatcher paths.
In `@testing/python/metal/test_metal_internal_scaffolding.py`:
- Around line 425-453: The tests
test_flashqla_gdn_kkt_probe_combines_local_var_state_and_simdgroup_boundary and
test_scaled_packed_quant_and_gdn_probes_source_boundary_tokens are asserting
exact symbol names and float formatting (e.g., "float kkt_bias = 0.000000e+00f;"
and "gate_state = 1.000000e+00f;"), which is brittle; change those assertions to
use regex/token-pattern checks that verify a local float variable is
declared/initialized (e.g., match a "float <ident> = <float literal>;" pattern)
and that an assignment of a float literal to a gate/state variable occurs, and
similarly relax any exact vector type/format checks to look for type families
(e.g., no "float8"/"float4" present) or counts rather than exact formatting so
the tests validate semantics without depending on pretty-printer output.
In `@testing/python/metal/test_metal_simdgroup_store.py`:
- Around line 72-76: The test is brittle because it looks for the temporary name
`C_local`; instead change the assertions to check for the lowering pattern
itself: assert that the source contains at least one "simdgroup_store(" (to
ensure the lowering emitted the store) and assert there are no "simdgroup_load("
occurrences that indicate an unwanted extra round-trip through shared memory;
remove any checks that specifically reference the `C_local` symbol and use the
presence/absence of "simdgroup_store(" and "simdgroup_load(" patterns to
validate behavior.
In `@tilelang/jit/adapter/base.py`:
- Around line 69-86: Update the get_current_device_functor docstring to reflect
the actual fallback chain used by the implementation: first try CUDA, then MPS,
then CPU; replace the outdated sentence that says "returns torch.device('cpu')"
when CUDA is unavailable with a concise description like "returns CUDA if
available, otherwise MPS if available, otherwise CPU" and keep the rest of the
docstring behavior notes (callable that yields Torch's current device).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 0612ea9b-616a-434a-b3f2-4ade74952e79
📒 Files selected for processing (39)
- benchmark/matmul_metal/benchmark_matmul_metal.py
- pyproject.toml
- requirements-dev.txt
- requirements.txt
- src/backend/metal/CMakeLists.txt
- src/op/copy.cc
- src/op/copy.h
- src/op/fill.cc
- src/op/gemm.cc
- src/op/gemm.h
- src/op/parallel.cc
- src/op/utils.h
- src/target/codegen_metal.cc
- src/target/codegen_metal.h
- src/transform/layout_inference.cc
- src/transform/lower_device_storage_access_info.cc
- testing/python/jit/test_tilelang_jit_adapter_mps.py
- testing/python/metal/metal_internal_runtime_coverage.md
- testing/python/metal/test_metal_codegen_linux.py
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- testing/python/metal/test_metal_internal_scaffolding.py
- testing/python/metal/test_metal_local_var.py
- testing/python/metal/test_metal_simdgroup_store.py
- tilelang/engine/lower.py
- tilelang/engine/phase.py
- tilelang/intrinsics/metal_macro_generator.py
- tilelang/jit/adapter/base.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_base.py
- tilelang/tileop/gemm/gemm_metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/tileop/metal_gdn.py
- tilelang/tileop/metal_quant.py
- tilelang/tileop/metal_simdgroup.py
- tilelang/transform/decouple_type_cast.py
- tilelang/transform/metal_fragment_to_simdgroup.py
- tilelang/utils/language.py
```diff
 // Check that all local.fragment buffers have inferred layouts.
+// On Metal targets, fragment buffers used as GEMM accumulators are
+// lowered to opaque simdgroup matrices, so they have no explicit
+// thread-level layout and can be safely skipped.
 for (const auto &[buffer, _] : use_list_) {
   if (IsFragmentBuffer(buffer)) {
-    ICHECK_NE(layout_map.count(buffer), 0)
-        << "The layout for fragment " << buffer
-        << " can not be inferred correctly.";
+    if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
+      ICHECK(false) << "The layout for fragment " << buffer
+                    << " can not be inferred correctly.";
+    }
```
Keep validating the fragment buffers that remain on Metal.
use_list_ only contains local.fragment buffers. After MetalFragmentToSimdgroup, the buffers that were rewritten to metal.simdgroup are already gone from this list, so this target-wide bypass only suppresses validation for the fragment buffers that weren't rewritten — including the mixed-dtype/scalar-fallback GEMM fragments that still need inferred layouts.
Suggested fix

```diff
-      if (IsFragmentBuffer(buffer)) {
-        if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
-          ICHECK(false) << "The layout for fragment " << buffer
-                        << " can not be inferred correctly.";
-        }
-      }
+      if (IsFragmentBuffer(buffer) && layout_map.count(buffer) == 0) {
+        ICHECK(false) << "The layout for fragment " << buffer
+                      << " can not be inferred correctly.";
+      }
```
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
// Check that all local.fragment buffers have inferred layouts.
// On Metal targets, fragment buffers used as GEMM accumulators are
// lowered to opaque simdgroup matrices, so they have no explicit
// thread-level layout and can be safely skipped.
for (const auto &[buffer, _] : use_list_) {
  if (IsFragmentBuffer(buffer) && layout_map.count(buffer) == 0) {
    ICHECK(false) << "The layout for fragment " << buffer
                  << " can not be inferred correctly.";
  }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/transform/layout_inference.cc` around lines 436 - 445, The current
Metal-target bypass skips validating all fragment buffers when
TargetIsMetal(target_) is true, which also skips fragments that still require
inferred layouts (e.g., mixed-dtype/scalar-fallback GEMM fragments); change the
logic in the loop over use_list_ (symbols: use_list_, IsFragmentBuffer,
layout_map) to only skip validation for buffers that were actually rewritten to
simdgroup (detectable by their storage scope or a metal.simdgroup marker, e.g.,
check buffer->storage_scope == "metal.simdgroup" or use a new IsSimdgroupBuffer
helper) instead of using TargetIsMetal(target_), and ensure any remaining
local.fragment buffers still fail the ICHECK if layout_map.count(buffer) == 0 so
their layouts are inferred/validated (this preserves MetalFragmentToSimdgroup
behavior while keeping validation for fragments that remain).
```python
if target_is_metal(target):
    # GemmMetal (simdgroup) requires A.dtype == B.dtype because it
    # lowers to simdgroup_matrix_multiply. For chained
    # mixed-precision patterns we route to the scalar fallback
    # (GemmMetalScalar) which handles dtype mismatch via per-load
    # T.cast(..., accum_dtype).
    if self._has_mixed_input_dtype():
        return GemmInst.Scalar
    return GemmInst.METAL_SIMDGROUP
return GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target))
```
Don't bypass fallback selection for all same-dtype Metal GEMMs.
This now returns `GemmInst.METAL_SIMDGROUP` for every same-dtype Metal GEMM. That skips the normal legality checks, so cases that are no longer simdgroup-valid fail later in `GemmMetal.lower()` instead of falling back cleanly. The mixed-dtype override should be layered on top of the existing selector, not replace it.
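A sketch of the layered shape, assuming the C++ selector already understands Metal targets well enough to pick `METAL_SIMDGROUP` or a legal fallback on its own:

```python
def _select_gemm_instruction(self, thread_nums, target):
    # Apply the Metal-specific mixed-dtype override, then defer to the
    # default selector so its legality checks still run.
    if target_is_metal(target) and self._has_mixed_input_dtype():
        return GemmInst.Scalar
    return GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target))
```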
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/tileop/gemm/__init__.py` around lines 180 - 189, The current Metal
branch unconditionally returns GemmInst.METAL_SIMDGROUP for same-dtype Metal
GEMMs, bypassing the normal selector and legality checks; instead call the
existing selector GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums),
target)) to get the default instance, then only override to GemmInst.Scalar when
target_is_metal(target) and self._has_mixed_input_dtype() is true; otherwise
return the selector's result so non-simdgroup-legal cases can fall back normally
(refer to target_is_metal, _has_mixed_input_dtype, GemmInst.METAL_SIMDGROUP,
GemmInst.Scalar, GemmInst(_ffi_api.GemmGetGemmInst) and GemmMetal.lower).
```python
if isinstance(stmt, tir.Block):
    new_alloc_bufs = []
    changed = False
    for buf in stmt.alloc_buffers:
        new_buf = _remap_buffer(buf, var_map)
        new_alloc_bufs.append(new_buf)
        if not new_buf.same_as(buf):
            buf_map[buf] = new_buf
            changed = True
    if changed:
        new_body = tir.stmt_functor.substitute(stmt.body, var_map)
        new_block = tir.Block(
            stmt.iter_vars,
            stmt.reads,
            stmt.writes,
            stmt.name_hint,
            new_body,
            stmt.init,
            new_alloc_bufs,
            stmt.match_buffers,
            stmt.annotations,
        )
        return (
            tir.BlockRealize(
                stmt.iter_vars,
                tir.const(True, "bool"),
                new_block,
            )
            if False
            else new_block
        )
elif isinstance(stmt, tir.Allocate):
    new_var = var_map.get(stmt.buffer_var, None)
    if new_var is not None:
        new_body = tir.stmt_functor.substitute(stmt.body, var_map)
        return tir.Allocate(new_var, stmt.dtype, stmt.extents, stmt.condition, new_body, stmt.annotations)
```
Remap the whole block signature, not just the body.
When a buffer is promoted, only `stmt.body` is substituted. `stmt.init`, `stmt.reads`, `stmt.writes`, and `stmt.match_buffers` still point at the old buffer objects, so the block can advertise `metal.simdgroup` in `alloc_buffers` while its metadata still references `local.fragment`. `buf_map` being unused here is a good signal that the rewrite is incomplete.
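A sketch of the missing signature remap (`_remap_region` is an illustrative helper; it assumes TVM's `tir.BufferRegion` constructor):

```python
def _remap_region(region, buf_map):
    # Rebuild a read/write region so it points at the promoted buffer.
    new_buf = buf_map.get(region.buffer, region.buffer)
    return tir.BufferRegion(new_buf, region.region)

new_reads = [_remap_region(r, buf_map) for r in stmt.reads]
new_writes = [_remap_region(w, buf_map) for w in stmt.writes]
new_init = (tir.stmt_functor.substitute(stmt.init, var_map)
            if stmt.init is not None else None)
# ...then pass new_reads / new_writes / new_init (plus remapped
# match_buffers) into the tir.Block constructor instead of the stale
# stmt.* fields.
```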
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/transform/metal_fragment_to_simdgroup.py` around lines 126 - 161,
The current rewrite only substitutes stmt.body when buffers are remapped,
leaving stmt.init, stmt.reads, stmt.writes and stmt.match_buffers still pointing
to old buffer objects (buf_map is built but unused). Update the tir.Block
handling (the loop that builds new_alloc_bufs using _remap_buffer, var_map and
buf_map) to also remap the block signature: apply substitutions or rebuild
stmt.init, stmt.reads, stmt.writes and stmt.match_buffers to reference new_bufs
(using buf_map and tir.stmt_functor.substitute where appropriate) and use those
remapped values when constructing new_block (and the BlockRealize branch) so the
block metadata is consistent with alloc_buffers.
Summary
Allows mixed-dtype `T.gemm` patterns through the Metal scalar fallback path so chained MLA-style attention GEMMs lower correctly. Specifically:

- fp16 × fp16 → fp32 (Q · K accumulation)
- fp32 × fp16 → fp32 (P · V accumulation)

Without this, `gemm_op._select_gemm_instruction` rejected mixed input dtypes on Metal even though the scalar dequant-multiply-accumulate path handles them naturally.

Why
Sparse-MLA / multi-latent-attention probes need a chained attention pattern that runs two GEMMs at different input dtypes: the first GEMM produces an fp32 accumulator that becomes the input of the second GEMM. The Metal scalar fallback codegen handles per-element promotion/demotion correctly; the dispatcher just needs to allow the mismatched input dtypes through.

This patch is a foundational piece of the cppmega.mlx Apple Silicon kernel-port effort. The downstream consumer is local sparse-MLA forward kernels that need the chained attention shape to compile to MSL.
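Sketched in the DSL (buffer names mirror the tests; the fixed 64-element shapes and the exact `T.gemm` flags are assumptions for illustration):

```python
import tilelang.language as T

# The first GEMM stays on the simdgroup path; the second (fp32 x fp16) is
# routed to the scalar fallback by the dispatcher change in this PR.
@T.prim_func
def attention_chain(Q: T.Tensor((64, 64), "float16"),
                    K: T.Tensor((64, 64), "float16"),
                    V: T.Tensor((64, 64), "float16"),
                    O: T.Tensor((64, 64), "float32")):
    with T.Kernel(1, threads=32) as _:
        Q_shared = T.alloc_shared((64, 64), "float16")
        K_shared = T.alloc_shared((64, 64), "float16")
        V_shared = T.alloc_shared((64, 64), "float16")
        S_local = T.alloc_fragment((64, 64), "float32")
        O_local = T.alloc_fragment((64, 64), "float32")
        T.copy(Q, Q_shared)
        T.copy(K, K_shared)
        T.copy(V, V_shared)
        T.clear(S_local)
        T.gemm(Q_shared, K_shared, S_local, transpose_B=True)  # fp16 x fp16 -> fp32
        T.clear(O_local)
        T.gemm(S_local, V_shared, O_local)  # fp32 x fp16 -> fp32: scalar fallback
        T.copy(O_local, O)
```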
Stacking topology
This PR is based on `jorgecurious/tilelang:metal-gemm-upstream-rebase` (PR #2130) at HEAD 971c17b, which itself stacks on top of:

Once #1869 + #2118 + #2121 + #2130 merge into tile-ai/tilelang:main, this PR can be retargeted to main directly. Until then, please review against #2130's branch as the base.

Test plan
The `assert_attention_like_metal_codegen` and `test_mixed_shared_merge_metal_codegen` test cases exercise the chained-attention dtype shape and verify the scalar-fallback MSL output. Run with:

```
cd /path/to/tilelang
pytest testing/python/metal/test_metal_codegen_linux.py -v
```

Expect both cases to pass on Apple Silicon (M-series) and on Linux when Metal is built (the test name reflects historic Linux-CI categorization, not platform restriction).
Caveats
Attribution
Co-developed with cppmega.mlx for Apple-Silicon Metal MLA kernel ports.

Summary by CodeRabbit
Release Notes
New Features
Dependencies
Tests