
[Metal] allow mixed-dtype T.gemm via scalar fallback#2139

Open
apstenku123 wants to merge 11 commits into tile-ai:main from
apstenku123:cppmega/metal-gemm-mixed-dtype

Conversation


@apstenku123 apstenku123 commented May 4, 2026

Summary

Allows mixed-dtype T.gemm patterns through the Metal scalar fallback path so chained MLA-style attention GEMMs lower correctly. Specifically:

  • fp16 × fp16 → fp32 (Q · K accumulation)
  • fp32 × fp16 → fp32 (P · V accumulation)

Without this, gemm_op._select_gemm_instruction rejected mixed input dtypes on Metal even though the scalar dequant-multiply-accumulate path handles them naturally.
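For concreteness, here is an illustrative pure-Python sketch of the relaxed dispatch decision (the names and simplified logic are illustrative only, not TileLang's actual `_select_gemm_instruction` implementation):

```python
# Illustrative sketch of the relaxed Metal GEMM dispatch (not the real
# TileLang code): mixed input dtypes now route to the scalar fallback
# instead of being rejected outright.

def select_gemm_instruction(a_dtype: str, b_dtype: str, target: str) -> str:
    mixed = a_dtype != b_dtype
    if target == "metal":
        # Previously: mixed dtypes raised. Now: the scalar fallback handles
        # per-element promotion, so mixed GEMMs simply route there.
        return "scalar_fallback" if mixed else "metal_simdgroup"
    if mixed:
        raise ValueError(f"mixed input dtypes unsupported on {target}: "
                         f"{a_dtype} x {b_dtype}")
    return "mma"

print(select_gemm_instruction("float16", "float16", "metal"))  # metal_simdgroup
print(select_gemm_instruction("float32", "float16", "metal"))  # scalar_fallback
```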

Why

Sparse-MLA / multi-latent-attention probes need a chained attention pattern that runs two GEMMs at different input dtypes:

S_local = T.gemm(Q_shared, K_shared, S_local, transpose_B=True)        # fp16 × fp16 → fp32
O_local = T.gemm(S_local, V_shared, O_local)                            # fp32 × fp16 → fp32

The first GEMM produces an fp32 accumulator that becomes the input of the second GEMM. The Metal scalar fallback codegen handles per-element promotion/demotion correctly; the dispatcher just needs to allow the mismatched input dtypes through.
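A plain-Python reference of the chained two-GEMM data flow (dtype promotion elided; this illustrates only how the first accumulator feeds the second GEMM, not the TileLang lowering):

```python
# Plain-Python reference for the chained pattern: S = Q @ K^T, then O = S @ V.
# The first GEMM's accumulator feeds the second GEMM directly, as in the
# MLA-style attention shape described above.

def gemm(A, B, transpose_B=False):
    if transpose_B:
        B = [list(col) for col in zip(*B)]
    n, k, m = len(A), len(A[0]), len(B[0])
    C = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            acc = 0.0  # fp32-style scalar accumulator
            for t in range(k):
                acc += A[i][t] * B[t][j]  # scalar multiply-accumulate
            C[i][j] = acc
    return C

Q = [[1.0, 2.0], [3.0, 4.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[2.0, 0.0], [0.0, 2.0]]
S = gemm(Q, K, transpose_B=True)  # stage 1 (fp16 x fp16 -> fp32 in the PR)
O = gemm(S, V)                    # stage 2 (fp32 x fp16 -> fp32 in the PR)
print(O)  # [[2.0, 4.0], [6.0, 8.0]]
```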

This patch is a foundational piece of the cppmega.mlx Apple Silicon kernel-port effort. The downstream consumer is local sparse-MLA forward kernels that need the chained attention shape to compile to MSL.

Stacking topology

This PR is based on jorgecurious/tilelang:metal-gemm-upstream-rebase (PR #2130) at HEAD 971c17b, which itself stacks on top of the upstream Apple Metal landing chain #1869 → #2118 → #2121.

Once #1869 + #2118 + #2121 + #2130 merge into tile-ai/tilelang:main, this PR can be retargeted to main directly. Until then, please review against #2130's branch as the base.

Test plan

The assert_attention_like_metal_codegen and test_mixed_shared_merge_metal_codegen test cases exercise the chained-attention dtype shape and verify the scalar-fallback MSL output. Run with:

cd /path/to/tilelang
pytest testing/python/metal/test_metal_codegen_linux.py -v

Expect both cases to pass on Apple Silicon (M-series) and on Linux when Metal codegen is built (the test name reflects historic Linux-CI categorization, not a platform restriction).

Caveats

  • This PR only relaxes the dispatcher and adds downstream codegen plumbing; no new MSL emission code is added (the scalar fallback already supported the dtype mix internally).
  • Performance: this is the scalar fallback path, not simdgroup MMA. Use it for correctness/portability; the simdgroup path from PR #1869 ([Metal] Add Metal GEMM support with simdgroup_matrix MMA) remains the perf-critical route for matched-dtype GEMMs.

Attribution

Co-developed with cppmega.mlx for Apple-Silicon Metal MLA kernel ports.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added comprehensive Metal/MPS backend support for kernel compilation and execution on macOS devices.
    • Implemented Metal-accelerated GEMM operations using simdgroup optimization for matrix multiplication.
    • Added Metal simdgroup-based copy and fill operations for improved data movement efficiency.
    • Introduced automatic MPS device fallback in JIT compilation when CUDA is unavailable.
    • Added support for quantized operations and GDN/attention computations on Metal backend.
  • Dependencies

    • Updated platform-specific constraints for Apache TVM FFI on macOS.
  • Tests

    • Added comprehensive Metal backend test coverage including GEMM, copy, fill, and quantization operations.

oraluben and others added 11 commits April 30, 2026 01:43
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations
(simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon
(M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with
  simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared×shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM
  accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
Copilot AI review requested due to automatic review settings May 4, 2026 08:52

github-actions Bot commented May 4, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Contributor

coderabbitai Bot commented May 4, 2026

📝 Walkthrough

Walkthrough

This pull request adds comprehensive Metal (MPS) backend support to TileLang, including a Metal code generator, simdgroup-based GEMM lowering, fragment-to-simdgroup IR transformation, register tile abstractions, quantization and GDN helpers, Metal-specific copy/fill operations, and extensive testing for codegen and runtime correctness. It also updates dependency constraints for macOS and integrates Metal device selection into the JIT adapter fallback.

Changes

Metal Backend Implementation

Layer / File(s) Summary
Data Structures & Enums
src/op/copy.h, src/op/gemm.h, tilelang/tileop/gemm/inst.py
New enum values added: CopyInst::kMetalSIMDGroup, GemmInst::kMetalSimdgroup (= 6), and GemmInst.is_metal_simdgroup() predicate for instruction dispatch.
Core Metal Code Generator
src/target/codegen_metal.h, src/target/codegen_metal.cc
Pure C++ Metal shader generator with type mapping, storage scope handling, simdgroup 8×8 matrix allocation, local.var scalar support, float constant emission with subnormal handling, and SIMD-group builtin call emission. Registers FFI entry point target.build.tilelang_metal.
Low-Level Operations
src/op/copy.cc, src/op/copy.h, src/op/fill.cc
Metal simdgroup-specific paths for copy and fill: CheckSIMDGroupCopy() validates shape/scope constraints; LowerSIMDGroupCopy() emits simdgroup store tiles; FillNode::Lower detects simdgroup scope and emits make_filled_simdgroup_matrix for 8×8 aligned regions.
GEMM Instruction Selection & Warp Policy
src/op/gemm.cc, src/op/gemm.h
GemmNode::getGemmInst() now returns kMetalSimdgroup for Metal targets; GemmWarpPolicyNode::computeWarpPartition() uses kMPerWarp = 8 on Metal (vs. 16 elsewhere).
IR Transformation Pass
tilelang/transform/metal_fragment_to_simdgroup.py, tilelang/transform/layout_inference.cc
New MetalFragmentToSimdgroup prim-func pass rewrites local.fragment accumulators to metal.simdgroup for same-dtype GEMMs, excluding mixed-precision cases. Layout inference now skips fragment-layout validation on Metal (opaque simdgroup matrices).
Parallel Op Layout Handling
src/op/parallel.cc, src/op/utils.h
Fragment layout retrieval guards with frag.has_value() check. New buffer-scope helpers: IsSIMDGroupBuffer() and IsRegisterBuffer() (fragment or simdgroup).
Build Configuration
src/backend/metal/CMakeLists.txt
Metal codegen source (src/target/codegen_metal.cc) now always compiled for cross-compilation support; non-Apple hosts early-exit after codegen-only setup.
Build Integration
tilelang/engine/lower.py, tilelang/engine/phase.py
Device codegen now calls target.build.tilelang_metal instead of generic Metal builder. Lowering pipeline inserts MetalFragmentToSimdgroup after software-pipeline injection.
Storage Access Lowering
src/transform/lower_device_storage_access_info.cc
Fragment-tagged allocations now excluded from device-storage lowering.
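The Metal-specific warp partitioning mentioned above (kMPerWarp = 8) can be sketched in pure Python. This mirrors the scoring loop visible in the Copilot review excerpt later in this thread; the real logic lives in C++ in GemmWarpPolicyNode::computeWarpPartition, so treat this as an illustrative approximation:

```python
# Simplified sketch of the warp-partition heuristic: pick an
# (m_warps, n_warps) factorization of num_warps whose aspect ratio best
# matches the tile's M/N ratio. m_per_warp is 8 on Metal (simdgroup 8x8
# tiles) versus 16 elsewhere, per the PR summary.

def compute_warp_partition(M, N, num_warps, m_per_warp=8):
    ideal = M / N if N > 0 else 1.0
    best, best_score = (num_warps, 1), float("inf")
    max_m = max(1, M // m_per_warp)
    for m in range(1, min(num_warps, max_m) + 1):
        if num_warps % m:
            continue
        n = num_warps // m
        score = abs(m / n - ideal)  # deviation from the ideal aspect ratio
        if score < best_score:
            best, best_score = (m, n), score
    return best

print(compute_warp_partition(64, 64, 4))  # square tile -> square split: (2, 2)
```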

Metal-Specific GEMM & Tiling Infrastructure

Layer / File(s) Summary
GEMM Base Classes
tilelang/tileop/gemm/gemm_base.py, tilelang/tileop/gemm/__init__.py
in_dtype no longer asserts A/B dtype equality; new has_mixed_input_dtype property detects mixed precision. Instruction selection dispatches to GemmMetal on Metal targets unless mixed-precision (fallback to scalar).
Metal GEMM Lowering
tilelang/tileop/gemm/gemm_metal.py
GemmMetal class validates M/N/chunk multiples of 8, computes warp partitions via Metal policy, enforces C scope compatibility, and emits simplified @T.prim_func kernels: SIMDGROUP-output mode (direct accumulation) or shared-output mode (intermediate simdgroup accumulator with final store).
Metal Intrinsics Emitter
tilelang/intrinsics/metal_macro_generator.py
MPSIntrinEmitter generates macro-based ldmatrix, mma, and simdgroup load/store sequences parameterized by operand dtypes, transposition, and warp/tile decomposition.
Register Tile Abstractions
tilelang/tileop/metal_simdgroup.py
Immutable RegisterTile and MMATile dataclasses with fragment backing, layout metadata, and RowVector for materialized scalars. Macros for fragment allocation, fill, load/store (with transpose), MMA accumulation, and row-wise tensor ops (max, sum, multiply, divide).

Specialized Kernels & Helpers

Layer / File(s) Summary
Quantization Helpers
tilelang/tileop/metal_quant.py
Shape selectors for simdgroup tiling (SMALL_TILE, LARGE_TILE). Decode functions for packed fp8/fp4/e8m0 formats into float32 with special-case handling for zero and subnormals.
GDN/Attention Kernels
tilelang/tileop/metal_gdn.py
Macros for KKT score computation (tile loads, MMA, gating with causal mask), W/U element/tile accumulation (strided and non-strided), and staged score materialization with per-block gate decay.
Type Cast Decoupling
tilelang/transform/decouple_type_cast.py, tilelang/utils/language.py
Updated is_local_buffer to treat Metal simdgroup buffers as local/register-level. Added is_metal_simdgroup() helper in language utilities.
JIT Runtime Integration
tilelang/jit/adapter/base.py, tilelang/jit/adapter/torch/metal.py
BaseKernelAdapter.get_current_device_functor() now falls back to MPS when CUDA unavailable. MetalKernelAdapter.get_kernel_source() exposes generated kernel text.
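The packed fp8 decode helpers summarized above live in tilelang/tileop/metal_quant.py and emit Metal code; as a format illustration only, the E4M3FN encoding (4 exponent bits with bias 7, 3 mantissa bits, no infinities, a single NaN pattern) decodes like this in pure Python:

```python
# Illustrative pure-Python decode of one fp8 E4M3FN byte to float.
# This shows the format semantics only; it is not the PR's Metal codegen.

def decode_e4m3fn(byte: int) -> float:
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF  # 4 exponent bits, bias 7
    mant = byte & 0x7        # 3 mantissa bits
    if exp == 0xF and mant == 0x7:
        return float("nan")  # e4m3fn: no inf, single NaN encoding
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6  # subnormal (and signed zero)
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

print(decode_e4m3fn(0x38))  # 1.0   (exp=7, mant=0)
print(decode_e4m3fn(0x00))  # 0.0
print(decode_e4m3fn(0xC0))  # -2.0  (sign set, exp=8, mant=0)
```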

Testing & Documentation

Layer / File(s) Summary
Device Selection Tests
testing/python/jit/test_tilelang_jit_adapter_mps.py
Three tests validate MPS device selection when CUDA unavailable or fails, with CPU fallback.
Codegen Validation
testing/python/metal/test_metal_codegen_linux.py, testing/python/metal/test_metal_gemm_v2_linux.py, testing/python/metal/test_metal_local_var.py
Cross-platform tests verify Metal shader generation: mixed-dtype GEMM chains, GEMM v2 with multiple tile/dtype configs, and local.var scalar codegen.
Runtime Correctness (MPS-Gated)
testing/python/metal/test_metal_gemm_v2.py, testing/python/metal/test_metal_simdgroup_store.py, testing/python/metal/test_metal_internal_scaffolding.py
Hardware tests run compiled kernels on MPS, compare against PyTorch/CPU reference implementations, and validate numerical correctness with configurable tolerances.
Internal Scaffolding
testing/python/metal/test_metal_internal_scaffolding.py
Comprehensive internal-only probes for register tiles, quantized matmul, GDN/KKT/W/U scoring, and mixed fp8/fp4 decoding. Tests source-boundary constraints, fail-closed native-dtype handling, and opt-in benchmarking.
Benchmark Script
benchmark/matmul_metal/benchmark_matmul_metal.py
Standalone Metal GEMM benchmarking script comparing TileLang simdgroup kernels against torch.mm reference with TFLOPS reporting and block-config sweep.
Coverage Documentation
testing/python/metal/metal_internal_runtime_coverage.md
Internal-only coverage specification for packed quant and GDN probes, fail-closed boundaries, and verification/benchmark commands.
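The TFLOPS figure reported by such a GEMM benchmark conventionally counts 2·M·N·K flops (one multiply plus one add per inner-product term); a minimal helper, assuming wall-clock seconds per kernel invocation:

```python
# Conventional GEMM flop accounting: an MxNxK matmul performs M*N*K
# multiply-adds, i.e. 2*M*N*K floating-point operations.

def gemm_tflops(M: int, N: int, K: int, seconds: float) -> float:
    return 2 * M * N * K / seconds / 1e12

print(round(gemm_tflops(4096, 4096, 4096, 0.01), 3))  # 13.744
```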

Dependency Updates

Layer / File(s) Summary
macOS-Specific Constraints
pyproject.toml, requirements.txt, requirements-dev.txt
Darwin-specific upper bound apache-tvm-ffi<0.1.8 added via environment markers alongside existing base constraints.

Sequence Diagram(s)

The PR introduces sufficient novelty and multi-component interaction (JIT compiler → IR transforms → code generator → Metal runtime) to benefit from a high-level sequence diagram:

sequenceDiagram
    participant User as User/JIT
    participant Lower as Lowering Pipeline
    participant Fragment as MetalFragmentToSimdgroup
    participant Codegen as CodeGenTileLangMetal
    participant Metal as Metal Runtime

    User->>Lower: tilelang.jit(kernel)
    activate Lower
    Lower->>Fragment: IR module with<br/>local.fragment C buffers
    activate Fragment
    Fragment->>Fragment: Detect same-dtype GEMM,<br/>rewrite to metal.simdgroup
    Fragment-->>Lower: Updated IR module
    deactivate Fragment
    
    Lower->>Codegen: PrimFunc with<br/>simdgroup allocations
    activate Codegen
    Codegen->>Codegen: Emit Metal kernel signature<br/>Type mapping<br/>simdgroup 8x8 matrices
    Codegen->>Codegen: Emit builtin calls:<br/>simdgroup_load/store/mma
    Codegen-->>Lower: Metal shader source
    deactivate Codegen
    
    Lower->>Lower: Register FFI entry<br/>target.build.tilelang_metal
    deactivate Lower
    
    User->>Metal: Execute compiled kernel<br/>on MPS device
    activate Metal
    Metal->>Metal: threadgroup barrier sync<br/>8x8 matrix operations
    Metal-->>User: Output tensor
    deactivate Metal

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~65 minutes


Suggested reviewers

  • LeiWang1999

Poem

🐰 Hops excitedly
Metal simdgroups bloom where fragments once did sleep,
Eight-by-eight matrices in register heaps,
From Darwin's shores to cross-compiled seas,
GEMM and quantization dance with MPS ease! ✨


apstenku123 added a commit to DatasunriseOU/cppmega_mlx that referenced this pull request May 4, 2026
Files documenting the actual PRs we just opened upstream:

- PR #1: ml-explore/mlx#3476 — from_dlpack Metal-aware consumer (against main, clean)
- PR #2: apache/tvm#19504 — TVM_METAL_STORAGE_MODE env opt-in (against main, clean)
- PR #3: tile-ai/tilelang#2139 — mixed-dtype T.gemm via scalar fallback (stacks on PR #2130)
- PR #4: tile-ai/tilelang#2140 — FP8-input T.gemm scalar fallback routing (stacks on PR #2130)
- PR #5: tile-ai/tilelang#2141 — T.Pipelined num_stages>1 3D buffer fix (stacks on PR #2130)
- PR #6: tile-ai/tilelang#2142 — T.fp8_scaled_matmul DSL intrinsic (stacks on PR #2130)

Deferred (split into companion PRs needed): tilelang_metal_fp8 and
tilelang_metal_fp8_vector each touch both tilelang supermodule and the
TileLang/tvm vendored submodule. These need 2 PRs each — one to
tile-ai/tilelang, one to TileLang/tvm — separate filing round.

PRs #3-#6 are independent of each other; each branches directly from
jorgecurious/tilelang:metal-gemm-upstream-rebase HEAD 971c17b, so they
can be reviewed in any order. They DO depend on the upstream 4-PR Apple
Metal landing chain (#1869, #2118, #2121, #2130) merging first; if any
of those land separately, ours can be retargeted at main.
Contributor

Copilot AI left a comment


Pull request overview

Enables mixed-dtype T.gemm lowering on the Metal backend by routing mixed-input GEMMs to the scalar fallback while keeping same-dtype GEMMs on the simdgroup MMA path, so chained attention-style GEMMs (e.g., fp16×fp16→fp32 followed by fp32×fp16→fp32) can compile to MSL.

Changes:

  • Add Metal simdgroup GEMM instruction selection and a Metal GEMM lowering path, while forcing mixed-input-dtype GEMMs onto the scalar fallback on Metal.
  • Introduce metal.simdgroup infrastructure across Python + C++ (scope detection, fragment→simdgroup rewrite, simdgroup fill/copy lowering, and Metal codegen plumbing).
  • Add/expand Metal codegen + runtime tests and internal scaffolding probes; add a Metal matmul benchmark script.

Reviewed changes

Copilot reviewed 38 out of 39 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
tilelang/utils/language.py Adds is_metal_simdgroup scope predicate.
tilelang/transform/metal_fragment_to_simdgroup.py New pass to rewrite select GEMM accumulators from local.fragment to metal.simdgroup on Metal.
tilelang/transform/decouple_type_cast.py Treats metal.simdgroup buffers as “local” for cast-decoupling logic.
tilelang/tileop/metal_simdgroup.py Adds internal RegisterTile/RowVector helpers and simdgroup macros for Metal kernels.
tilelang/tileop/metal_quant.py Adds internal fp8/fp4/e8m0 decode helpers and tile selectors for quant probes.
tilelang/tileop/metal_gdn.py Adds internal GDN/attention-style macros built on simdgroup helpers.
tilelang/tileop/gemm/inst.py Adds METAL_SIMDGROUP GEMM instruction kind.
tilelang/tileop/gemm/gemm_metal.py New Metal simdgroup GEMM lowering implementation.
tilelang/tileop/gemm/gemm_base.py Allows mixed A/B dtypes and adds a mixed-input-dtype predicate.
tilelang/tileop/gemm/__init__.py Metal-specific GEMM dispatch: mixed dtypes → scalar fallback; otherwise simdgroup.
tilelang/jit/adapter/torch/metal.py Exposes Metal kernel source and compiles via torch.mps.compile_shader.
tilelang/jit/adapter/base.py Updates device selection to prefer MPS when CUDA is unavailable/failed.
tilelang/intrinsics/metal_macro_generator.py Adds MPSIntrinEmitter for simdgroup load/store/MMA macro emission.
tilelang/engine/phase.py Inserts Metal fragment→simdgroup rewrite into the lowering pipeline.
tilelang/engine/lower.py Switches Metal build hook to target.build.tilelang_metal.
testing/python/metal/test_metal_simdgroup_store.py Adds tests for direct simdgroup-store-to-device GEMM path.
testing/python/metal/test_metal_local_var.py Adds focused tests for local.var scalar codegen/runtime on Metal.
testing/python/metal/test_metal_internal_scaffolding.py Adds internal-only source-boundary + runtime probes for Metal helpers/quant/GDN.
testing/python/metal/test_metal_gemm_v2.py Adds Metal GEMM v2 runtime correctness tests.
testing/python/metal/test_metal_gemm_v2_linux.py Adds cross-platform Metal GEMM v2 codegen tests.
testing/python/metal/test_metal_codegen_linux.py Adds chained attention mixed-dtype Metal codegen test coverage.
testing/python/metal/metal_internal_runtime_coverage.md Documents internal Metal runtime/source-boundary coverage.
testing/python/jit/test_tilelang_jit_adapter_mps.py Adds tests for MPS preference in JIT adapter device selection.
src/transform/lower_device_storage_access_info.cc Treats .fragment scope tag as a special case in storage access lowering.
src/transform/layout_inference.cc Skips fragment-layout assertion on Metal targets.
src/target/codegen_metal.h Adds TileLang Metal codegen class declaration.
src/target/codegen_metal.cc Adds TileLang Metal codegen implementation and registers build hook.
src/op/utils.h Adds helpers for detecting simdgroup/register buffers.
src/op/parallel.cc Makes fragment layout usage more defensive when layout info is absent.
src/op/gemm.h Adds Metal simdgroup GEMM enum value and string conversion.
src/op/gemm.cc Selects Metal simdgroup GEMM inst; adjusts warp partitioning for Metal.
src/op/fill.cc Adds metal.simdgroup fill lowering via make_filled_simdgroup_matrix.
src/op/copy.h Adds Metal simdgroup copy instruction and lowering declarations.
src/op/copy.cc Adds simdgroup store lowering for T.copy from simdgroup to shared/global.
src/backend/metal/CMakeLists.txt Always builds Metal codegen; enables codegen-only mode on non-Apple hosts.
requirements.txt Adds Darwin-specific apache-tvm-ffi upper bound.
requirements-dev.txt Adds Darwin-specific apache-tvm-ffi upper bound for dev installs.
pyproject.toml Adds Darwin-specific apache-tvm-ffi upper bound for packaging.
benchmark/matmul_metal/benchmark_matmul_metal.py Adds a Metal GEMM benchmark script using the simdgroup path.


Comment on lines +37 to +43
class CodeGenTileLangMetal final : public CodeGenC {
public:
explicit CodeGenTileLangMetal(Target target);
// override print thread tag.
void PrintArgUnionDecl();
void AddFunction(const GlobalVar &gvar, const PrimFunc &func) final;
void InitFuncState(const PrimFunc &f) final;
Comment thread src/op/copy.cc
Comment on lines +1096 to +1100
}
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
if (num_warps % m != 0)
Comment on lines 76 to 85
if torch.cuda.is_available():
try:
torch.cuda._lazy_init()
current_device = torch._C._cuda_getDevice
return lambda: torch.device("cuda", current_device())
except Exception:
return lambda: torch.device("cuda", torch.cuda.current_device())
pass
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
return lambda: torch.device("mps")
# CPU fallback
Comment on lines +122 to +156
def _rewrite_scope(body, var_map):
buf_map = {}

def _pre_order(stmt):
if isinstance(stmt, tir.Block):
new_alloc_bufs = []
changed = False
for buf in stmt.alloc_buffers:
new_buf = _remap_buffer(buf, var_map)
new_alloc_bufs.append(new_buf)
if not new_buf.same_as(buf):
buf_map[buf] = new_buf
changed = True
if changed:
new_body = tir.stmt_functor.substitute(stmt.body, var_map)
new_block = tir.Block(
stmt.iter_vars,
stmt.reads,
stmt.writes,
stmt.name_hint,
new_body,
stmt.init,
new_alloc_bufs,
stmt.match_buffers,
stmt.annotations,
)
return (
tir.BlockRealize(
stmt.iter_vars,
tir.const(True, "bool"),
new_block,
)
if False
else new_block
)
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tilelang/tileop/gemm/gemm_base.py (1)

80-110: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Don't silently reinterpret mixed inputs as A.dtype.

in_dtype no longer rejects A.dtype != B.dtype, but the typed backends still wire self.in_dtype into both a_dtype and b_dtype. That means a mixed GEMM outside the Metal scalar fallback can now be emitted as if both operands had A.dtype, which is a miscompile where the old code failed fast. I'd keep in_dtype strict and expose a separate mixed-dtype path/property for the scalar fallback instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/gemm_base.py` around lines 80 - 110, The in_dtype
property must remain strict: restore behavior so it does not silently pretend
both operands share A.dtype when A.dtype != B.dtype — if
is_tensor_memory(self.A) return self.B.dtype as before, otherwise ensure A.dtype
== B.dtype (raise or assert/emit an explicit error) instead of returning A.dtype
unconditionally; add a separate property (e.g., mixed_input_dtype or
scalar_fallback_in_dtype) that returns the appropriate dtype for scalar
fallbacks and keep has_mixed_input_dtype (or update it) to signal
mixed-precision so dispatchers (Gemm._select_gemm_instruction and backends that
currently read self.in_dtype for both a_dtype and b_dtype) can route to the
scalar fallback instead of emitting a single-typed MMA intrinsic.
src/backend/metal/CMakeLists.txt (1)

6-18: ⚠️ Potential issue | 🔴 Critical

Remove the non-existent runtime/metal/metal_module.h include or fix the reference—this will break compilation on all platforms.

The unconditional append of src/target/codegen_metal.cc to TILE_LANG_SRCS at lines 6–8 will immediately fail on every platform because line 35 of codegen_metal.cc includes "runtime/metal/metal_module.h", which does not exist anywhere in the repository. The early return for non-Apple platforms at line 18 cannot prevent this compilation error.

Either remove the include statement or ensure the header file is properly available before this change can be merged.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/backend/metal/CMakeLists.txt` around lines 6 - 18, The build
unconditionally appends src/target/codegen_metal.cc to TILE_LANG_SRCS but that
source includes "runtime/metal/metal_module.h" which doesn't exist, causing
compile failures; fix by making the append conditional (e.g., move or wrap the
list(APPEND TILE_LANG_SRCS src/target/codegen_metal.cc) behind the APPLE check
or add an if(APPLE) ... endif() around it) or alternatively remove/fix the
include inside src/target/codegen_metal.cc to point to an existing header or add
the missing runtime/metal/metal_module.h; reference TILE_LANG_SRCS,
src/target/codegen_metal.cc, and the include "runtime/metal/metal_module.h" when
making the change.
🧹 Nitpick comments (8)
tilelang/jit/adapter/base.py (1)

69-86: 💤 Low value

Docstring is outdated after MPS fallback addition.

The docstring at line 74 states the method "returns torch.device('cpu')" when CUDA is unavailable, but the implementation now returns MPS when available. Consider updating the docstring to reflect the new fallback chain: CUDA → MPS → CPU.

📝 Suggested docstring update
     `@staticmethod`
     def get_current_device_functor() -> Callable[[], torch.device]:
         """Return a callable that yields Torch's current device.

         Similar to the stream functor, we capture a callable that, when called,
-        fetches the current device according to PyTorch. On CPU or when CUDA is
-        unavailable, returns ``torch.device('cpu')``.
+        fetches the current device according to PyTorch. Falls back through
+        CUDA → MPS → CPU based on availability.
         """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/jit/adapter/base.py` around lines 69 - 86, Update the
get_current_device_functor docstring to reflect the actual fallback chain used
by the implementation: first try CUDA, then MPS, then CPU; replace the outdated
sentence that says "returns torch.device('cpu')" when CUDA is unavailable with a
concise description like "returns CUDA if available, otherwise MPS if available,
otherwise CPU" and keep the rest of the docstring behavior notes (callable that
yields Torch's current device).
requirements.txt (1)

4-4: 💤 Low value

Document the rationale for apache-tvm-ffi<0.1.8 on Darwin.

The constraint is applied consistently across requirements.txt, requirements-dev.txt, and pyproject.toml, but none of them explain why >=0.1.8 is incompatible on macOS. The pyproject.toml already has a precedent for this pattern (it documents why >=0.1.6 is needed via tilelang#1502). Without a similar comment here, it's unclear when this upper bound can safely be relaxed, and macOS users are silently prevented from picking up any bug/security fixes shipped in 0.1.8+.

💬 Suggested comment
-apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
+# <0.1.8 on macOS: 0.1.8 introduces an RFC Error-ABI change that breaks
+# TileLang's Metal FFI entrypoints (see <link-to-issue>).
+apache-tvm-ffi<0.1.8; platform_system == 'Darwin'

(Mirror the comment in pyproject.toml and requirements-dev.txt.)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@requirements.txt` at line 4, The package constraint "apache-tvm-ffi<0.1.8;
platform_system == 'Darwin'" lacks an explanatory comment; add a brief rationale
next to this requirement (and mirror it in requirements-dev and pyproject
entries) that states why versions >=0.1.8 are incompatible on macOS, references
the upstream issue or PR (e.g., tilelang#1502-style link or ticket number), and
describes the condition under which the upper bound can be relaxed (what
fix/version to wait for) so macOS users know when it's safe to upgrade; ensure
the comment appears immediately adjacent to the "apache-tvm-ffi<0.1.8" spec so
it's clear which constraint it documents.
src/transform/lower_device_storage_access_info.cc (1)

47-49: Add a brief comment explaining why .fragment scope is excluded.

The exclusion of scope.tag != ".fragment" is correct—fragment buffers must survive this pass on all targets, not just Metal, before being transformed by target-specific lowering (e.g., metal_fragment_to_simdgroup for Metal, or consumed directly for CUDA). However, CUDA kernels do use alloc_fragment() (confirmed in tests like test_tilelang_language_wgmma_gemm.py), so the condition converts a previous loud ICHECK failure into a silent skip. A comment documenting this intent preserves clarity for future maintainers:

Suggested inline comment
     scope.tag != ".barrier" && scope.tag != ".cluster_barrier" &&
+    // Fragment buffers must survive this pass for target-specific lowering phases.
     scope.tag != ".fragment" && scope.tag.find(".descriptor") != 0) {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/transform/lower_device_storage_access_info.cc` around lines 47 - 49, Add
a brief inline comment next to the condition that checks scope.tag (the if that
includes scope.tag != ".fragment") explaining that ".fragment" buffers are
intentionally excluded because fragment storage must survive this lowering pass
for all backends — Metal performs target-specific lowering (e.g.,
metal_fragment_to_simdgroup) later while CUDA may use alloc_fragment() and
consume fragments directly — so we skip lowering here rather than
asserting/failing; reference the scope.tag check and the ".fragment" exclusion
in the comment for future maintainers.
testing/python/metal/test_metal_codegen_linux.py (1)

112-119: ⚡ Quick win

Make this assert the mixed-dtype fallback path, not just successful codegen.

Right now the test only proves that some Metal kernel was emitted. A regression that routes T.gemm(S_local, V_shared, O_local) through the wrong lowering path would still pass. Please add one assertion that is specific to the scalar mixed-dtype fallback/cast behavior so this actually guards the dispatcher change.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_codegen_linux.py` around lines 112 - 119, The
test currently only checks that a Metal kernel was emitted; instead assert the
mixed-dtype scalar fallback path by (1) ensuring the generated kernel source
does NOT contain the high-level optimized GEMM lowering (e.g., the string
"T.gemm(S_local, V_shared, O_local)") and (2) asserting it contains the
scalar-cast/fallback marker used by the mixed-dtype path (e.g., a convert/cast
token such as "convert(" or "as_type(") so the test specifically verifies the
elementwise scalar cast/fallback lowering for attention_chain_mixed_dtype().
testing/python/metal/test_metal_internal_scaffolding.py (2)

425-453: ⚡ Quick win

These source assertions are overfit to the current pretty-printer.

Checks like float kkt_bias = 0.000000e+00f; and gate_state = 1.000000e+00f; depend on exact symbol names and float formatting. A benign SSA rename or constant-printing tweak will fail the tests without changing the lowering semantics. Regexing for the local-var initialization pattern would keep the intent while making the tests much less brittle.

Also applies to: 456-474

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_internal_scaffolding.py` around lines 425 -
453, The tests
test_flashqla_gdn_kkt_probe_combines_local_var_state_and_simdgroup_boundary and
test_scaled_packed_quant_and_gdn_probes_source_boundary_tokens are asserting
exact symbol names and float formatting (e.g., "float kkt_bias = 0.000000e+00f;"
and "gate_state = 1.000000e+00f;"), which is brittle; change those assertions to
use regex/token-pattern checks that verify a local float variable is
declared/initialized (e.g., match a "float <ident> = <float literal>;" pattern)
and that an assignment of a float literal to a gate/state variable occurs, and
similarly relax any exact vector type/format checks to look for type families
(e.g., no "float8"/"float4" present) or counts rather than exact formatting so
the tests validate semantics without depending on pretty-printer output.
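As an illustration of the suggested relaxation, a pattern check that survives SSA renames and float-formatting changes; both sample strings are hypothetical prints of the same lowering:

```python
import re

# Two pretty-printer variants of the same zero-initialized local;
# the exact strings are illustrative.
v1 = "float kkt_bias = 0.000000e+00f;"
v2 = "float kkt_bias_1 = 0.0f;"

# Match "a float local named kkt_bias* zero-initialized" rather than
# one exact formatting of the constant.
pat = re.compile(r"float\s+kkt_bias\w*\s*=\s*0(\.\d+)?(e[+-]?\d+)?f?\s*;")
assert pat.search(v1)
assert pat.search(v2)
```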

413-423: ⚡ Quick win

Narrow the native-fp8/fp4 negative checks.

The `assert "float8" not in ...` / `assert "float4" not in ...` checks are broader than the behavior you're trying to lock down. `float4` is a valid Metal vector type, and `simdgroup_float8x8` also contains the `float8` substring, so unrelated codegen changes can trip these tests. Match the unsupported dtype spellings (`float8_e4m3fn`, `float4_e2m1fn`, etc.) or the tensor boundary types instead.

Also applies to: 436-463
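To make the distinction concrete, a sketch of the narrower negative check; the dtype spellings beyond those named in the comment are assumptions:

```python
import re

# Legitimate Metal vector/simdgroup types that must NOT trip the check:
src_ok = "simdgroup_float8x8 acc; float4 v = float4(0.0f);"
# An unsupported narrow dtype that SHOULD trip it:
src_bad = 'B: T.Tensor((K, N), "float8_e4m3fn")'

unsupported = re.compile(r"float8_e4m3fn|float8_e5m2|float4_e2m1fn")
assert unsupported.search(src_ok) is None       # no false positive
assert unsupported.search(src_bad) is not None  # real hit still caught

# The broad substring checks the comment warns about would misfire here:
assert "float8" in src_ok and "float4" in src_ok
```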

testing/python/metal/test_metal_gemm_v2_linux.py (1)

14-19: ⚡ Quick win

Add a mixed-input-dtype coverage case for this PR path.

A and B are always instantiated with the same dtype here, so this suite never exercises the dispatcher relaxation for mixed-input GEMMs. A regression in the fp16×fp16→fp32 / fp32×fp16→fp32 scalar-fallback path would still pass all of these tests.

Also applies to: 39-78

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 14 - 19, The
current matmul_gemm_v2 test only instantiates A and B with the same dtype so it
never exercises mixed-input-dtype dispatcher relaxation; update the test
generation to include at least one mixed-dtype case (e.g., A float32 & B float16
and/or A float16 & B float32) while keeping accum_dtype=float32, by modifying
matmul_gemm_v2 (and its usages) to accept/parametrize separate input dtypes for
A and B (referencing the matmul_gemm_v2 function and its inner prim_func main
and tensor params A, B, C) and add a new test variant that constructs A and B
with differing dtypes to exercise the fp16×fp16→fp32 vs fp32×fp16→fp32
dispatcher paths.
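One way to see the coverage gap: a hypothetical parameter grid with separate A/B dtypes; names and the dtype strings mirror the comment, not the suite's real fixtures.

```python
# Hypothetical (a_dtype, b_dtype, accum_dtype) grid for matmul_gemm_v2.
cases = [
    ("float16", "float16", "float32"),  # same-dtype simdgroup baseline
    ("float32", "float16", "float32"),  # fp32 x fp16 -> fp32 (P·V style)
    ("float16", "float32", "float32"),  # reversed mixed operands
]
mixed = [(a, b) for a, b, accum in cases if a != b]
assert len(mixed) == 2  # the grid now exercises the relaxed dispatcher
assert all(accum == "float32" for _, _, accum in cases)
```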
testing/python/metal/test_metal_simdgroup_store.py (1)

72-76: ⚡ Quick win

Avoid pinning the codegen test to the C_local symbol name.

`C_local` is an emitter detail, not a stable API. A harmless rename in simplify/SSA will break this test even if the direct simdgroup_store lowering is still correct. Prefer asserting on structural patterns (presence of `simdgroup_store(`, no extra round-trip through `C`) instead of the temporary's exact name.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_simdgroup_store.py` around lines 72 - 76, The
test is brittle because it looks for the temporary name `C_local`; instead
change the assertions to check for the lowering pattern itself: assert that the
source contains at least one "simdgroup_store(" (to ensure the lowering emitted
the store) and assert there are no "simdgroup_load(" occurrences that indicate
an unwanted extra round-trip through shared memory; remove any checks that
specifically reference the `C_local` symbol and use the presence/absence of
"simdgroup_store(" and "simdgroup_load(" patterns to validate behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/transform/layout_inference.cc`:
- Around line 436-445: The current Metal-target bypass skips validating all
fragment buffers when TargetIsMetal(target_) is true, which also skips fragments
that still require inferred layouts (e.g., mixed-dtype/scalar-fallback GEMM
fragments); change the logic in the loop over use_list_ (symbols: use_list_,
IsFragmentBuffer, layout_map) to only skip validation for buffers that were
actually rewritten to simdgroup (detectable by their storage scope or a
metal.simdgroup marker, e.g., check buffer->storage_scope == "metal.simdgroup"
or use a new IsSimdgroupBuffer helper) instead of using TargetIsMetal(target_),
and ensure any remaining local.fragment buffers still fail the ICHECK if
layout_map.count(buffer) == 0 so their layouts are inferred/validated (this
preserves MetalFragmentToSimdgroup behavior while keeping validation for
fragments that remain).

In `@tilelang/tileop/gemm/__init__.py`:
- Around line 180-189: The current Metal branch unconditionally returns
GemmInst.METAL_SIMDGROUP for same-dtype Metal GEMMs, bypassing the normal
selector and legality checks; instead call the existing selector
GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target)) to get the
default instance, then only override to GemmInst.Scalar when
target_is_metal(target) and self._has_mixed_input_dtype() is true; otherwise
return the selector's result so non-simdgroup-legal cases can fall back normally
(refer to target_is_metal, _has_mixed_input_dtype, GemmInst.METAL_SIMDGROUP,
GemmInst.Scalar, GemmInst(_ffi_api.GemmGetGemmInst) and GemmMetal.lower).

In `@tilelang/transform/metal_fragment_to_simdgroup.py`:
- Around line 126-161: The current rewrite only substitutes stmt.body when
buffers are remapped, leaving stmt.init, stmt.reads, stmt.writes and
stmt.match_buffers still pointing to old buffer objects (buf_map is built but
unused). Update the tir.Block handling (the loop that builds new_alloc_bufs
using _remap_buffer, var_map and buf_map) to also remap the block signature:
apply substitutions or rebuild stmt.init, stmt.reads, stmt.writes and
stmt.match_buffers to reference new_bufs (using buf_map and
tir.stmt_functor.substitute where appropriate) and use those remapped values
when constructing new_block (and the BlockRealize branch) so the block metadata
is consistent with alloc_buffers.

---

Outside diff comments:
In `@src/backend/metal/CMakeLists.txt`:
- Around line 6-18: The build unconditionally appends
src/target/codegen_metal.cc to TILE_LANG_SRCS but that source includes
"runtime/metal/metal_module.h" which doesn't exist, causing compile failures;
fix by making the append conditional (e.g., move or wrap the list(APPEND
TILE_LANG_SRCS src/target/codegen_metal.cc) behind the APPLE check or add an
if(APPLE) ... endif() around it) or alternatively remove/fix the include inside
src/target/codegen_metal.cc to point to an existing header or add the missing
runtime/metal/metal_module.h; reference TILE_LANG_SRCS,
src/target/codegen_metal.cc, and the include "runtime/metal/metal_module.h" when
making the change.

In `@tilelang/tileop/gemm/gemm_base.py`:
- Around line 80-110: The in_dtype property must remain strict: restore behavior
so it does not silently pretend both operands share A.dtype when A.dtype !=
B.dtype — if is_tensor_memory(self.A) return self.B.dtype as before, otherwise
ensure A.dtype == B.dtype (raise or assert/emit an explicit error) instead of
returning A.dtype unconditionally; add a separate property (e.g.,
mixed_input_dtype or scalar_fallback_in_dtype) that returns the appropriate
dtype for scalar fallbacks and keep has_mixed_input_dtype (or update it) to
signal mixed-precision so dispatchers (Gemm._select_gemm_instruction and
backends that currently read self.in_dtype for both a_dtype and b_dtype) can
route to the scalar fallback instead of emitting a single-typed MMA intrinsic.

---

Nitpick comments:
In `@requirements.txt`:
- Line 4: The package constraint "apache-tvm-ffi<0.1.8; platform_system ==
'Darwin'" lacks an explanatory comment; add a brief rationale next to this
requirement (and mirror it in requirements-dev and pyproject entries) that
states why versions >=0.1.8 are incompatible on macOS, references the upstream
issue or PR (e.g., tilelang#1502-style link or ticket number), and describes the
condition under which the upper bound can be relaxed (what fix/version to wait
for) so macOS users know when it's safe to upgrade; ensure the comment appears
immediately adjacent to the "apache-tvm-ffi<0.1.8" spec so it's clear which
constraint it documents.

In `@src/transform/lower_device_storage_access_info.cc`:
- Around line 47-49: Add a brief inline comment next to the condition that
checks scope.tag (the if that includes scope.tag != ".fragment") explaining that
".fragment" buffers are intentionally excluded because fragment storage must
survive this lowering pass for all backends — Metal performs target-specific
lowering (e.g., metal_fragment_to_simdgroup) later while CUDA may use
alloc_fragment() and consume fragments directly — so we skip lowering here
rather than asserting/failing; reference the scope.tag check and the ".fragment"
exclusion in the comment for future maintainers.

In `@testing/python/metal/test_metal_codegen_linux.py`:
- Around line 112-119: The test currently only checks that a Metal kernel was
emitted; instead assert the mixed-dtype scalar fallback path by (1) ensuring the
generated kernel source does NOT contain the high-level optimized GEMM lowering
(e.g., the string "T.gemm(S_local, V_shared, O_local)") and (2) asserting it
contains the scalar-cast/fallback marker used by the mixed-dtype path (e.g., a
convert/cast token such as "convert(" or "as_type(") so the test specifically
verifies the elementwise scalar cast/fallback lowering for
attention_chain_mixed_dtype().

In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 14-19: The current matmul_gemm_v2 test only instantiates A and B
with the same dtype so it never exercises mixed-input-dtype dispatcher
relaxation; update the test generation to include at least one mixed-dtype case
(e.g., A float32 & B float16 and/or A float16 & B float32) while keeping
accum_dtype=float32, by modifying matmul_gemm_v2 (and its usages) to
accept/parametrize separate input dtypes for A and B (referencing the
matmul_gemm_v2 function and its inner prim_func main and tensor params A, B, C)
and add a new test variant that constructs A and B with differing dtypes to
exercise the fp16×fp16→fp32 vs fp32×fp16→fp32 dispatcher paths.

In `@testing/python/metal/test_metal_internal_scaffolding.py`:
- Around line 425-453: The tests
test_flashqla_gdn_kkt_probe_combines_local_var_state_and_simdgroup_boundary and
test_scaled_packed_quant_and_gdn_probes_source_boundary_tokens are asserting
exact symbol names and float formatting (e.g., "float kkt_bias = 0.000000e+00f;"
and "gate_state = 1.000000e+00f;"), which is brittle; change those assertions to
use regex/token-pattern checks that verify a local float variable is
declared/initialized (e.g., match a "float <ident> = <float literal>;" pattern)
and that an assignment of a float literal to a gate/state variable occurs, and
similarly relax any exact vector type/format checks to look for type families
(e.g., no "float8"/"float4" present) or counts rather than exact formatting so
the tests validate semantics without depending on pretty-printer output.

In `@testing/python/metal/test_metal_simdgroup_store.py`:
- Around line 72-76: The test is brittle because it looks for the temporary name
`C_local`; instead change the assertions to check for the lowering pattern
itself: assert that the source contains at least one "simdgroup_store(" (to
ensure the lowering emitted the store) and assert there are no "simdgroup_load("
occurrences that indicate an unwanted extra round-trip through shared memory;
remove any checks that specifically reference the `C_local` symbol and use the
presence/absence of "simdgroup_store(" and "simdgroup_load(" patterns to
validate behavior.

In `@tilelang/jit/adapter/base.py`:
- Around line 69-86: Update the get_current_device_functor docstring to reflect
the actual fallback chain used by the implementation: first try CUDA, then MPS,
then CPU; replace the outdated sentence that says "returns torch.device('cpu')"
when CUDA is unavailable with a concise description like "returns CUDA if
available, otherwise MPS if available, otherwise CPU" and keep the rest of the
docstring behavior notes (callable that yields Torch's current device).
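The corrected fallback chain can be summarized in a few lines; the boolean flags below stand in for `torch.cuda.is_available()` and `torch.backends.mps.is_available()` so the sketch stays self-contained:

```python
# CUDA -> MPS -> CPU, the chain the docstring should describe.
def current_device(cuda_available: bool, mps_available: bool) -> str:
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

assert current_device(True, True) == "cuda"
assert current_device(False, True) == "mps"
assert current_device(False, False) == "cpu"
```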

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0612ea9b-616a-434a-b3f2-4ade74952e79

📥 Commits

Reviewing files that changed from the base of the PR and between d135bd1 and 71fa920.

📒 Files selected for processing (39)
  • benchmark/matmul_metal/benchmark_matmul_metal.py
  • pyproject.toml
  • requirements-dev.txt
  • requirements.txt
  • src/backend/metal/CMakeLists.txt
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/fill.cc
  • src/op/gemm.cc
  • src/op/gemm.h
  • src/op/parallel.cc
  • src/op/utils.h
  • src/target/codegen_metal.cc
  • src/target/codegen_metal.h
  • src/transform/layout_inference.cc
  • src/transform/lower_device_storage_access_info.cc
  • testing/python/jit/test_tilelang_jit_adapter_mps.py
  • testing/python/metal/metal_internal_runtime_coverage.md
  • testing/python/metal/test_metal_codegen_linux.py
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • testing/python/metal/test_metal_internal_scaffolding.py
  • testing/python/metal/test_metal_local_var.py
  • testing/python/metal/test_metal_simdgroup_store.py
  • tilelang/engine/lower.py
  • tilelang/engine/phase.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/base.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_base.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/tileop/metal_gdn.py
  • tilelang/tileop/metal_quant.py
  • tilelang/tileop/metal_simdgroup.py
  • tilelang/transform/decouple_type_cast.py
  • tilelang/transform/metal_fragment_to_simdgroup.py
  • tilelang/utils/language.py

Comment on lines +436 to +445
```diff
   // Check that all local.fragment buffers have inferred layouts.
+  // On Metal targets, fragment buffers used as GEMM accumulators are
+  // lowered to opaque simdgroup matrices, so they have no explicit
+  // thread-level layout and can be safely skipped.
   for (const auto &[buffer, _] : use_list_) {
     if (IsFragmentBuffer(buffer)) {
-      ICHECK_NE(layout_map.count(buffer), 0)
-          << "The layout for fragment " << buffer
-          << " can not be inferred correctly.";
+      if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
+        ICHECK(false) << "The layout for fragment " << buffer
+                      << " can not be inferred correctly.";
+      }
```

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Keep validating the fragment buffers that remain on Metal.

use_list_ only contains local.fragment buffers. After MetalFragmentToSimdgroup, the buffers that were rewritten to metal.simdgroup are already gone from this list, so this target-wide bypass only suppresses validation for the fragment buffers that weren't rewritten — including the mixed-dtype/scalar-fallback GEMM fragments that still need inferred layouts.

Suggested fix
```diff
-      if (IsFragmentBuffer(buffer)) {
-        if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
-          ICHECK(false) << "The layout for fragment " << buffer
-                        << " can not be inferred correctly.";
-        }
-      }
+      if (IsFragmentBuffer(buffer) && layout_map.count(buffer) == 0) {
+        ICHECK(false) << "The layout for fragment " << buffer
+                      << " can not be inferred correctly.";
+      }
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
```diff
       // Check that all local.fragment buffers have inferred layouts.
       // On Metal targets, fragment buffers used as GEMM accumulators are
       // lowered to opaque simdgroup matrices, so they have no explicit
       // thread-level layout and can be safely skipped.
       for (const auto &[buffer, _] : use_list_) {
-        if (IsFragmentBuffer(buffer)) {
-          if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
-            ICHECK(false) << "The layout for fragment " << buffer
-                          << " can not be inferred correctly.";
-          }
-        }
+        if (IsFragmentBuffer(buffer) && layout_map.count(buffer) == 0) {
+          ICHECK(false) << "The layout for fragment " << buffer
+                        << " can not be inferred correctly.";
+        }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/transform/layout_inference.cc` around lines 436 - 445, The current
Metal-target bypass skips validating all fragment buffers when
TargetIsMetal(target_) is true, which also skips fragments that still require
inferred layouts (e.g., mixed-dtype/scalar-fallback GEMM fragments); change the
logic in the loop over use_list_ (symbols: use_list_, IsFragmentBuffer,
layout_map) to only skip validation for buffers that were actually rewritten to
simdgroup (detectable by their storage scope or a metal.simdgroup marker, e.g.,
check buffer->storage_scope == "metal.simdgroup" or use a new IsSimdgroupBuffer
helper) instead of using TargetIsMetal(target_), and ensure any remaining
local.fragment buffers still fail the ICHECK if layout_map.count(buffer) == 0 so
their layouts are inferred/validated (this preserves MetalFragmentToSimdgroup
behavior while keeping validation for fragments that remain).

Comment on lines +180 to 189
```python
        if target_is_metal(target):
            # GemmMetal (simdgroup) requires A.dtype == B.dtype because it
            # lowers to simdgroup_matrix_multiply. For chained
            # mixed-precision patterns we route to the scalar fallback
            # (GemmMetalScalar) which handles dtype mismatch via per-load
            # T.cast(..., accum_dtype).
            if self._has_mixed_input_dtype():
                return GemmInst.Scalar
            return GemmInst.METAL_SIMDGROUP
        return GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target))
```

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't bypass fallback selection for all same-dtype Metal GEMMs.

This now returns GemmInst.METAL_SIMDGROUP for every same-dtype Metal GEMM. That skips the normal legality checks, so cases that are no longer simdgroup-valid fail later in GemmMetal.lower() instead of falling back cleanly. The mixed-dtype override should be layered on top of the existing selector, not replace it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/__init__.py` around lines 180 - 189, The current Metal
branch unconditionally returns GemmInst.METAL_SIMDGROUP for same-dtype Metal
GEMMs, bypassing the normal selector and legality checks; instead call the
existing selector GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums),
target)) to get the default instance, then only override to GemmInst.Scalar when
target_is_metal(target) and self._has_mixed_input_dtype() is true; otherwise
return the selector's result so non-simdgroup-legal cases can fall back normally
(refer to target_is_metal, _has_mixed_input_dtype, GemmInst.METAL_SIMDGROUP,
GemmInst.Scalar, GemmInst(_ffi_api.GemmGetGemmInst) and GemmMetal.lower).

Comment on lines +126 to +161
```python
if isinstance(stmt, tir.Block):
    new_alloc_bufs = []
    changed = False
    for buf in stmt.alloc_buffers:
        new_buf = _remap_buffer(buf, var_map)
        new_alloc_bufs.append(new_buf)
        if not new_buf.same_as(buf):
            buf_map[buf] = new_buf
            changed = True
    if changed:
        new_body = tir.stmt_functor.substitute(stmt.body, var_map)
        new_block = tir.Block(
            stmt.iter_vars,
            stmt.reads,
            stmt.writes,
            stmt.name_hint,
            new_body,
            stmt.init,
            new_alloc_bufs,
            stmt.match_buffers,
            stmt.annotations,
        )
        return (
            tir.BlockRealize(
                stmt.iter_vars,
                tir.const(True, "bool"),
                new_block,
            )
            if False
            else new_block
        )
elif isinstance(stmt, tir.Allocate):
    new_var = var_map.get(stmt.buffer_var, None)
    if new_var is not None:
        new_body = tir.stmt_functor.substitute(stmt.body, var_map)
        return tir.Allocate(new_var, stmt.dtype, stmt.extents,
                            stmt.condition, new_body, stmt.annotations)
```

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Remap the whole block signature, not just the body.

When a buffer is promoted, only stmt.body is substituted. stmt.init, stmt.reads, stmt.writes, and stmt.match_buffers still point at the old buffer objects, so the block can advertise metal.simdgroup in alloc_buffers while its metadata still references local.fragment. buf_map being unused here is a good signal that the rewrite is incomplete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/metal_fragment_to_simdgroup.py` around lines 126 - 161,
The current rewrite only substitutes stmt.body when buffers are remapped,
leaving stmt.init, stmt.reads, stmt.writes and stmt.match_buffers still pointing
to old buffer objects (buf_map is built but unused). Update the tir.Block
handling (the loop that builds new_alloc_bufs using _remap_buffer, var_map and
buf_map) to also remap the block signature: apply substitutions or rebuild
stmt.init, stmt.reads, stmt.writes and stmt.match_buffers to reference new_bufs
(using buf_map and tir.stmt_functor.substitute where appropriate) and use those
remapped values when constructing new_block (and the BlockRealize branch) so the
block metadata is consistent with alloc_buffers.
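The missing piece is essentially a region remap applied to every part of the block signature. A pure-Python sketch, with a stand-in dataclass in place of `tir.BufferRegion` so it runs without TVM:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Region:
    """Stand-in for tir.BufferRegion: a buffer name plus its bounds."""
    buffer: str
    region: tuple

def remap_regions(regions, buf_map):
    # Point every read/write region at the remapped buffer so the block
    # signature agrees with the rewritten alloc_buffers.
    return [Region(buf_map.get(r.buffer, r.buffer), r.region)
            for r in regions]

buf_map = {"C_frag": "C_simdgroup"}
reads = [Region("A_shared", (0, 32)), Region("C_frag", (0, 8))]
out = remap_regions(reads, buf_map)
assert out[0].buffer == "A_shared"     # untouched buffers pass through
assert out[1].buffer == "C_simdgroup"  # fragment remapped everywhere
```

In the actual pass the same remap would be applied to `stmt.reads`, `stmt.writes`, `stmt.match_buffers`, and (via substitution) `stmt.init` before constructing the new block.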
