
[Metal] thread stage dim through T.access_ptr for T.Pipelined num_stages>1#2141

Open
apstenku123 wants to merge 11 commits into tile-ai:main from apstenku123:cppmega/metal-pipelined-3d-buffer

Conversation


@apstenku123 apstenku123 commented May 4, 2026

Summary

Threads the pipeline-stage dimension through the Metal T.access_ptr rewrite so T.Pipelined(K_iters, num_stages > 1) lowers correctly on Metal. Without this, the metal_macro_generator's T.access_ptr pattern emits a 2D shape that never indexes the per-stage shared buffers, producing either a compile error or a silent overwrite at stage > 0.

The fix is in tilelang/intrinsics/metal_macro_generator.py: when the buffer has 3D shape (num_stages, R, C), the access path now folds the stage axis into the offset arithmetic instead of dropping it.
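
The idea can be illustrated with plain offset arithmetic (a hypothetical sketch of the indexing, not the emitter code itself): for a shared buffer of shape (num_stages, R, C), dropping the stage axis aliases every stage onto the same region, while folding it in adds a per-stage base offset of stage * R * C.

```python
# Hypothetical sketch of the stage-axis folding; these names do not come
# from tilelang/intrinsics/metal_macro_generator.py.

def flat_offset_2d(r, c, C):
    """Old behaviour: the stage axis is dropped, so every stage
    aliases the same R*C region (silent overwrite at stage > 0)."""
    return r * C + c

def flat_offset_3d(stage, r, c, R, C):
    """Fixed behaviour: the stage axis is folded into the offset
    arithmetic, so each pipeline stage addresses its own R*C slice."""
    return stage * R * C + r * C + c

R, C = 16, 16
# Under the 2D scheme, element (3, 5) of every stage collides at offset 53.
assert flat_offset_2d(3, 5, C) == 53
# Folding the stage axis in puts each stage in a disjoint slice.
assert flat_offset_3d(0, 3, 5, R, C) == 53
assert flat_offset_3d(1, 3, 5, R, C) == R * C + 53
```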

Why

Software pipelining is critical for hiding memory latency in MLA-style attention on Apple GPUs. Without num_stages > 1 working, every K-iteration of the attention loop is fully serialized over shared memory. Local probes (Mamba3, sparse-MLA fwd) showed a 2–3× speedup with num_stages=2 once the 3D buffer indexing was fixed.

This PR makes the existing Metal T.Pipelined path actually correct for num_stages>1 at 16×16 fragment tiles. (Larger fragment sizes — 32×32 — still need follow-up work; this PR is the foundational fix that unblocks the small-fragment shape used by Mamba3 and partial sparse-MLA.)

Stacking topology

This PR is based on jorgecurious/tilelang:metal-gemm-upstream-rebase (PR #2130) at HEAD 971c17b, which itself stacks on top of the upstream Apple Metal landing chain (#1869, #2118, #2121).

Once #1869 + #2118 + #2121 + #2130 merge into tile-ai/tilelang:main, this PR can be retargeted to main directly.

Test plan

The downstream probe that exercised this is cppmega.mlx's test_pipelined_probe.py:

cd /path/to/tilelang
# Build TileLang from this PR's branch
# Then run the probe (lives in cppmega.mlx)
.venv/bin/python /path/to/cppmega.mlx/docs/upstream/tilelang_metal_pipelined/test_pipelined_probe.py

Expect num_stages=2 and num_stages=3 to both lower successfully on the Metal target.
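
For context, the stage slot each K-iteration must read and write is simply the iteration index modulo num_stages; a tiny illustrative sketch (not TileLang code) of the buffer rotation the fixed lowering has to preserve:

```python
# Illustrative only: the buffer-slot rotation implied by a software
# pipeline that keeps num_stages copies in flight.

def stage_schedule(k_iters, num_stages):
    """Return the stage slot each K-iteration uses."""
    return [k % num_stages for k in range(k_iters)]

# With num_stages=2, consecutive iterations alternate slots, so the copy
# for iteration k+1 can overlap the compute on iteration k.
assert stage_schedule(6, 2) == [0, 1, 0, 1, 0, 1]
assert stage_schedule(6, 3) == [0, 1, 2, 0, 1, 2]
```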

Caveats

  • This fix is for 16×16 fragments. 32×32 fragment-size pipelining is a follow-up that will surface when sparse-MLA Path C extends to its production tile size.
  • No new codegen emitted; this is a fix to an existing lowering path.

Attribution

Co-developed with cppmega.mlx for Apple-Silicon Metal MLA kernel ports.

Summary by CodeRabbit

  • New Features

    • Added Metal/MPS backend compilation support for TileLang kernels on Apple devices.
    • Implemented Metal SIMDGROUP GEMM operations with configurable block tiling.
    • Added Metal intrinsics for matrix operations, copy, and fill instructions.
    • Added MPS device fallback when CUDA is unavailable.
    • Added Metal quantization and GDN pipeline utilities.
  • Documentation

    • Added Metal backend runtime coverage documentation.
  • Tests

    • Added comprehensive Metal backend test suite covering GEMM, simdgroup operations, and internal scaffolding.
  • Chores

    • Updated apache-tvm-ffi dependency constraints for macOS compatibility.

oraluben and others added 11 commits April 30, 2026 01:43
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations
(simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon
(M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with
  simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared×shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM
  accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
Copilot AI review requested due to automatic review settings May 4, 2026 08:53

github-actions Bot commented May 4, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai Bot commented May 4, 2026

📝 Walkthrough

This PR adds comprehensive Metal (Apple Silicon/MPS) target support to the TVM TileLang compiler, including a Metal code generator, GEMM operations via simdgroup matrix operations, register-tile abstractions, quantization/GDN attention macros, and extensive test coverage. It constrains apache-tvm-ffi on macOS and updates device selection to prefer MPS when CUDA is unavailable.

Changes

Metal Infrastructure, Device Selection & Code Generation

  • Dependencies & Device Fallback (pyproject.toml, requirements*.txt, tilelang/jit/adapter/base.py, tilelang/jit/adapter/torch/metal.py): Constrains apache-tvm-ffi<0.1.8 on macOS; get_current_device_functor() prefers MPS when CUDA is unavailable; Metal adapter exposes a kernel source accessor.
  • Build Configuration (src/backend/metal/CMakeLists.txt): Adds src/target/codegen_metal.cc to TILE_LANG_SRCS for cross-compilation; codegen-only mode returns immediately on non-Apple hosts instead of disabling Metal.
  • Metal Code Generator (src/target/codegen_metal.cc, src/target/codegen_metal.h): Emits Metal device code from TVM TIR: handles thread/threadgroup binding, storage scopes (global/shared/local/metal.simdgroup), simdgroup matrix ops, float16 packing, select/broadcast/reinterpret builtins, and Metal-specific type/constant printing; registers the target.build.tilelang_metal FFI builder.
  • Engine Integration (tilelang/engine/lower.py): Routes Metal device codegen to the new target.build.tilelang_metal builder instead of the generic Metal builder.
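
The CUDA → MPS → CPU preference described for get_current_device_functor() can be sketched as a plain preference chain (a simplified stand-in, not the adapter's actual code):

```python
# Simplified stand-in for the CUDA -> MPS -> CPU fallback described above;
# the real logic lives in BaseKernelAdapter.get_current_device_functor().

def pick_device(cuda_available: bool, mps_available: bool) -> str:
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

assert pick_device(True, True) == "cuda"    # CUDA wins when present
assert pick_device(False, True) == "mps"    # MPS fallback on Apple devices
assert pick_device(False, False) == "cpu"   # last-resort CPU
```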

Metal GEMM & Warp-Level Operations

  • GEMM Instruction Selection (src/op/gemm.cc, src/op/gemm.h, tilelang/tileop/gemm/inst.py, tilelang/tileop/gemm/__init__.py): Adds the GemmInst::kMetalSimdgroup enum value; getGemmInst() returns the Metal variant for Metal targets; computeWarpPartition() uses an 8-element warp M-tile for Metal vs 16 for others.
  • MPS Intrinsics Emitter (tilelang/intrinsics/metal_macro_generator.py): MPSIntrinEmitter generates TVM TIR macros for 8×8 simdgroup matrix ops: ldmatrix_a/ldmatrix_b load micro-tiles, mma performs multiply-accumulate across warp tiles, simdgroup_copy moves C between simdgroup and destination buffers.
  • Metal GEMM Implementation (tilelang/tileop/gemm/gemm_metal.py): GemmMetal lowers shared/shared GEMMs: validates block/warp tiling, instantiates MPSIntrinEmitter, emits a pipelined K-loop with load/MMA, optionally clears/loads/stores C via simdgroup buffers.
  • Fragment→Simdgroup IR Pass (tilelang/engine/phase.py, tilelang/transform/metal_fragment_to_simdgroup.py): MetalFragmentToSimdgroup rewrites GEMM accumulator buffers from local.fragment to metal.simdgroup storage scope before layout inference, allowing Metal-native opaque matrix operations.
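
Since Metal's simdgroup matrix ops work on fixed 8×8 tiles, a warp tile decomposes into a grid of 8×8 micro-tiles. A sketch of that enumeration (illustrative only, not the emitter's code):

```python
# Illustrative enumeration of 8x8 simdgroup micro-tiles covering a warp tile.

SG = 8  # Metal simdgroup_matrix is fixed at 8x8

def micro_tile_origins(warp_m, warp_n):
    """Top-left (row, col) of each 8x8 micro-tile in a warp_m x warp_n tile."""
    assert warp_m % SG == 0 and warp_n % SG == 0, "warp tile must be 8x8-aligned"
    return [(r, c) for r in range(0, warp_m, SG) for c in range(0, warp_n, SG)]

# A 16x16 warp tile is covered by four 8x8 micro-tiles.
assert micro_tile_origins(16, 16) == [(0, 0), (0, 8), (8, 0), (8, 8)]
# The 8-element Metal warp M-tile paired with N=16 needs only two.
assert len(micro_tile_origins(8, 16)) == 2
```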

Metal Memory Operations & Register Tile Abstractions

  • Copy Operations (src/op/copy.cc, src/op/copy.h): Adds CopyInst::kMetalSIMDGroup for Metal simdgroup transfers; CheckSIMDGroupCopy() validates buffer scopes/dtypes and 8×8 matrix alignment; LowerSIMDGroupCopy() emits simdgroup_store calls with warp-geometry validation, falling back to a normal copy on constraint mismatch.
  • Fill Operations (src/op/fill.cc): Handles simdgroup-scoped destinations: validates that the region size is a multiple of 64 and aligned to an 8×8 boundary, then emits make_filled_simdgroup_matrix calls per 8×8 tile covering the region.
  • Layout & Storage Inference (src/transform/layout_inference.cc, src/transform/lower_device_storage_access_info.cc, src/op/parallel.cc, src/op/utils.h): Fragment buffers on Metal skip the layout-inference requirement; ".fragment" allocations are excluded from device-storage lowering; ParallelOpNode::InferLayout safely guards optional fragment layout access; adds IsSIMDGroupBuffer() and IsRegisterBuffer() predicates.
  • Register Tile Abstractions (tilelang/tileop/metal_simdgroup.py, tilelang/utils/language.py): RegisterTile dataclass wraps a simdgroup fragment with layout metadata; alloc_rt, load_*, store_*, and mma_* macros provide tensor-like register-tile operations; RowVector provides scalar-materialized rows with reduction/scaling; is_metal_simdgroup() buffer predicate.
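
The fill-side constraints above (region size a multiple of 64, aligned to an 8×8 boundary) amount to a simple predicate checked before emitting one make_filled_simdgroup_matrix per tile. A hedged sketch of that validation (not the actual src/op/fill.cc logic):

```python
# Hypothetical sketch of the 8x8 alignment validation described for
# simdgroup fill; the real check lives in src/op/fill.cc.

SG = 8

def simdgroup_fill_ok(row0, col0, rows, cols):
    """A region qualifies when its origin and extent sit on 8x8 boundaries,
    which also makes its element count a multiple of 64."""
    aligned = all(v % SG == 0 for v in (row0, col0, rows, cols))
    return aligned and (rows * cols) % 64 == 0

assert simdgroup_fill_ok(0, 0, 16, 16)      # covered by four 8x8 tiles
assert not simdgroup_fill_ok(4, 0, 16, 16)  # misaligned origin -> fallback path
assert not simdgroup_fill_ok(0, 0, 12, 16)  # extent not 8-aligned
```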

Metal Advanced Optimization Features

  • Quantization Helpers (tilelang/tileop/metal_quant.py): QuantSimdgroupTile presets (small/large 8×8 register configs); tile-shape selectors based on M×N dimensions; decode routines for fp8 e4m3fn, fp4 e2m1fn, and e8m0 scales into float32 TileLang expressions.
  • GDN/Attention Macros (tilelang/tileop/metal_gdn.py): Implements KKT score computation, causal gating, and gate decay; linear W/U element computation; strided/contiguous A/K/V tile accumulation via mma_ab/mma_abt; simdgroup-based staged pipelines for GDN KKT/WU/component scoring.
  • Type Cast & Local Buffer Handling (tilelang/transform/decouple_type_cast.py): Extends is_local_buffer() to recognize metal.simdgroup buffers as "local" for type-cast decoupling purposes.
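
The fp8 e4m3fn decode mentioned above follows the standard bit layout (1 sign, 4 exponent, 3 mantissa bits, bias 7, no infinities, NaN at 0x7F/0xFF). A self-contained reference decode, offered as a sketch of the format rather than the module's actual code (which builds TileLang expressions, not Python floats):

```python
# Reference decode of fp8 e4m3fn; sketch only, not
# tilelang/tileop/metal_quant.py itself.

def decode_e4m3fn(byte: int) -> float:
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:      # e4m3fn has NaN but no infinities
        return float("nan")
    if exp == 0:                        # subnormal: mant * 2^-3 * 2^-6
        return sign * mant * 2.0 ** -9
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

assert decode_e4m3fn(0x38) == 1.0    # exp=7 (bias), mant=0
assert decode_e4m3fn(0x40) == 2.0    # exp=8, mant=0
assert decode_e4m3fn(0xC0) == -2.0   # sign bit set
```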

Tests, Benchmarks & Validation

  • Benchmark (benchmark/matmul_metal/benchmark_matmul_metal.py): Python script benchmarking TileLang Metal GEMM against PyTorch MPS torch.mm in fp16; supports configurable M/N/K, warmup/repeats, and a block-config sweep; reports TFLOPS and the best configuration.
  • Metal GEMM Correctness Tests (testing/python/metal/test_metal_gemm_v2.py, testing/python/metal/test_metal_gemm_v2_linux.py): JIT-compiled Metal GEMM kernels with MPS runtime validation (PyTorch reference) and Linux codegen checks (kernel-source assertions for simdgroup ops).
  • Register Tile & Simdgroup Tests (testing/python/metal/test_metal_simdgroup_store.py, testing/python/metal/test_metal_local_var.py): Validates simdgroup accumulator store paths, avoids redundant load/store pairs, and ensures local.var scalars compile to thread-local variables on Metal.
  • Internal Scaffolding & Runtime Coverage (testing/python/metal/test_metal_internal_scaffolding.py, testing/python/metal/metal_internal_runtime_coverage.md): Extensive internal-only probes for register-tile MMA, row-vector reductions, packed fp8/fp4 quantized matmul, and GDN KKT/WU/component pipelines; codegen/boundary token assertions; MPS runtime correctness checks; optional environment-gated benchmarks; documentation of coverage scope and fail-closed guarantees.
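
For reference, the TFLOPS figure such a GEMM benchmark reports is the standard 2·M·N·K flop count over wall time (the generic formula, not the script's exact code):

```python
# Standard GEMM throughput math: a matmul of shapes (M,K) x (K,N)
# performs 2*M*N*K floating-point operations (one multiply + one add
# per inner-product term).

def gemm_tflops(M: int, N: int, K: int, seconds: float) -> float:
    return (2.0 * M * N * K) / seconds / 1e12

# A 4096^3 GEMM finishing in 25 ms sustains roughly 5.5 TFLOPS.
assert abs(gemm_tflops(4096, 4096, 4096, 0.025) - 5.4975581) < 1e-3
```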

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant LowerPhase as LowerAndLegalize<br/>(Phase)
    participant MetalPass as MetalFragmentToSimdgroup<br/>(IR Pass)
    participant LayoutInf as LayoutInference
    participant LowerGEMM as GEMM Lowering<br/>(gemm_metal.py)
    participant MPSEmitter as MPSIntrinEmitter
    participant Codegen as CodeGenTileLangMetal
    participant Metal as Metal Kernel<br/>(Source)

    User->>LowerPhase: Lower TileLang Module
    LowerPhase->>LowerPhase: Software Pipeline Injection
    LowerPhase->>MetalPass: Apply MetalFragmentToSimdgroup
    MetalPass->>MetalPass: Rewrite local.fragment → metal.simdgroup
    LowerPhase->>LayoutInf: Run Layout Inference
    LayoutInf->>LayoutInf: Skip fragment layout check for Metal
    LowerPhase->>LowerGEMM: Lower GEMM Ops
    LowerGEMM->>MPSEmitter: Instantiate with dtype/warp config
    LowerGEMM->>MPSEmitter: ldmatrix_a, ldmatrix_b loads
    LowerGEMM->>MPSEmitter: mma multiply-accumulate
    LowerGEMM->>MPSEmitter: simdgroup_copy store result
    LowerPhase->>Codegen: CodeGen to Metal
    Codegen->>Codegen: Emit Metal kernel void
    Codegen->>Codegen: Translate simdgroup_* ops to Metal
    Codegen->>Codegen: Handle local.var scalars
    Codegen->>Metal: Generate Metal source
    Metal-->>User: Metal Kernel Ready

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰 Silicon hops in Metal's glow,
SIMD matrices steal the show,
Fragments dance as register tiles,
GDN macros go for miles,
Apple's GPU now compiles!


apstenku123 added a commit to DatasunriseOU/cppmega_mlx that referenced this pull request May 4, 2026
Files documenting the actual PRs we just opened upstream:

- PR #1: ml-explore/mlx#3476 — from_dlpack Metal-aware consumer (against main, clean)
- PR #2: apache/tvm#19504 — TVM_METAL_STORAGE_MODE env opt-in (against main, clean)
- PR #3: tile-ai/tilelang#2139 — mixed-dtype T.gemm via scalar fallback (stacks on PR #2130)
- PR #4: tile-ai/tilelang#2140 — FP8-input T.gemm scalar fallback routing (stacks on PR #2130)
- PR #5: tile-ai/tilelang#2141 — T.Pipelined num_stages>1 3D buffer fix (stacks on PR #2130)
- PR #6: tile-ai/tilelang#2142 — T.fp8_scaled_matmul DSL intrinsic (stacks on PR #2130)

Deferred (companion PRs needed): tilelang_metal_fp8 and
tilelang_metal_fp8_vector each touch both the tilelang supermodule and
the TileLang/tvm vendored submodule, so each needs two PRs — one to
tile-ai/tilelang, one to TileLang/tvm — in a separate filing round.

PRs #3-#6 are independent of each other; each branches directly from
jorgecurious/tilelang:metal-gemm-upstream-rebase HEAD 971c17b, so they
can be reviewed in any order. They DO depend on the upstream 4-PR Apple
Metal landing chain (#1869, #2118, #2121, #2130) merging first; if any
of those land separately, ours can be retargeted at main.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/jit/adapter/base.py (1)

69-75: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update the stale docstring to reflect the MPS fallback.

The docstring states the function returns torch.device('cpu') when CUDA is unavailable, but after this change it now returns torch.device("mps") when MPS is available. The documented fallback chain should be CUDA → MPS → CPU.

📝 Proposed docstring update
-    """Return a callable that yields Torch's current device.
-
-    Similar to the stream functor, we capture a callable that, when called,
-    fetches the current device according to PyTorch. On CPU or when CUDA is
-    unavailable, returns ``torch.device('cpu')``.
-    """
+    """Return a callable that yields Torch's current device.
+
+    Similar to the stream functor, we capture a callable that, when called,
+    fetches the current device according to PyTorch. Falls back through
+    CUDA → MPS → CPU: returns an MPS device functor when CUDA is unavailable
+    or CUDA initialisation fails and ``torch.backends.mps`` is available,
+    otherwise returns ``torch.device('cpu')``.
+    """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/jit/adapter/base.py` around lines 69 - 75, The docstring for
get_current_device_functor is outdated about fallback behavior; update it to
document the actual device selection order (CUDA → MPS → CPU) and mention that
when CUDA is unavailable but MPS is available the callable yields
torch.device("mps"), otherwise torch.device("cpu"); keep the rest of the
description about returning a callable that fetches the current torch device at
call time.
🧹 Nitpick comments (8)
requirements-dev.txt (1)

4-4: 💤 Low value

Consider adding an inline comment explaining the Darwin cap.

The <0.1.8 bound is non-obvious without context. A short comment (e.g., # 0.1.8+ breaks on macOS; see tilelang#XXXX) would help future readers understand why this constraint exists, especially since it sits on a separate line from the base range.

💡 Suggested change
-apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
+apache-tvm-ffi<0.1.8; platform_system == 'Darwin'  # 0.1.8+ incompatible on macOS
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@requirements-dev.txt` at line 4, Add a short inline comment next to the
apache-tvm-ffi<0.1.8; platform_system == 'Darwin' entry in requirements-dev.txt
explaining the macOS upper bound (e.g., "# 0.1.8+ breaks on macOS; see issue
tilelang#XXXX or link to bug/PR"), so future readers understand why this
Darwin-specific constraint exists; reference the package name apache-tvm-ffi
when adding the comment.
pyproject.toml (1)

33-34: 💤 Low value

LGTM — Darwin-specific cap is valid and follows the established project pattern.

Each string in [project].dependencies represents a dependency and maps directly to a Requires-Dist entry, so two entries for apache-tvm-ffi with different markers are valid. The effective constraint on macOS becomes >=0.1.2, <0.1.8 (the base ~=0.1.0,>=0.1.2 intersected with <0.1.8), which is the intended behavior. The same two-entry pattern is already used for torch on lines 42–43.

Optionally, a brief inline comment (e.g., # 0.1.8+ incompatible on macOS) would give the Darwin cap the same documentation context that lines 30–32 give the base constraint.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pyproject.toml` around lines 33 - 34, Add a brief inline comment explaining
the Darwin cap on the second apache-tvm-ffi entry: locate the dependency string
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'" and append a short comment
(e.g., "# 0.1.8+ incompatible on macOS") to document why the macOS-specific
upper bound exists, mirroring the explanatory style used for the base constraint
and the torch entries.
testing/python/jit/test_tilelang_jit_adapter_mps.py (1)

10-47: ⚡ Quick win

Extract the duplicated MPS patching guard into a shared fixture.

The getattr/monkeypatch guard for torch.backends.mps is copy-pasted verbatim in all three tests (Lines 13-16, 27-30, 40-43), differing only in the is_available return value. A parametrised fixture keeps the pattern DRY and makes future additions trivial.

♻️ Proposed refactor
+import pytest
+
+@pytest.fixture
+def patch_mps(monkeypatch):
+    """Return a helper that sets torch.backends.mps.is_available() to *available*."""
+    def _patch(available: bool):
+        if getattr(torch.backends, "mps", None) is None:
+            monkeypatch.setattr(
+                torch.backends, "mps",
+                SimpleNamespace(is_available=lambda: available),
+                raising=False,
+            )
+        else:
+            monkeypatch.setattr(torch.backends.mps, "is_available", lambda: available)
+    return _patch
+
+
-def test_current_device_functor_prefers_mps_when_cuda_unavailable(monkeypatch):
+def test_current_device_functor_prefers_mps_when_cuda_unavailable(monkeypatch, patch_mps):
     monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
-
-    if getattr(torch.backends, "mps", None) is None:
-        monkeypatch.setattr(torch.backends, "mps", SimpleNamespace(is_available=lambda: True), raising=False)
-    else:
-        monkeypatch.setattr(torch.backends.mps, "is_available", lambda: True)
+    patch_mps(True)
 
     device_functor = BaseKernelAdapter.get_current_device_functor()
     assert device_functor() == torch.device("mps")
 
 
-def test_current_device_functor_prefers_mps_when_cuda_init_fails(monkeypatch):
+def test_current_device_functor_prefers_mps_when_cuda_init_fails(monkeypatch, patch_mps):
     monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
     monkeypatch.setattr(torch.cuda, "_lazy_init", lambda: (_ for _ in ()).throw(RuntimeError("cuda init failed")))
-
-    if getattr(torch.backends, "mps", None) is None:
-        monkeypatch.setattr(torch.backends, "mps", SimpleNamespace(is_available=lambda: True), raising=False)
-    else:
-        monkeypatch.setattr(torch.backends.mps, "is_available", lambda: True)
+    patch_mps(True)
 
     device_functor = BaseKernelAdapter.get_current_device_functor()
     assert device_functor() == torch.device("mps")
 
 
-def test_current_device_functor_falls_back_to_cpu_without_cuda_or_mps(monkeypatch):
+def test_current_device_functor_falls_back_to_cpu_without_cuda_or_mps(monkeypatch, patch_mps):
     monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
-
-    if getattr(torch.backends, "mps", None) is None:
-        monkeypatch.setattr(torch.backends, "mps", SimpleNamespace(is_available=lambda: False), raising=False)
-    else:
-        monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False)
+    patch_mps(False)
 
     device_functor = BaseKernelAdapter.get_current_device_functor()
     assert device_functor() == torch.device("cpu")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/jit/test_tilelang_jit_adapter_mps.py` around lines 10 - 47,
These tests duplicate the same torch.backends.mps monkeypatch guard; create a
parametrized pytest fixture (e.g., mps_availability) that accepts a boolean for
is_available and performs the getattr/monkeypatch logic, then update
test_current_device_functor_prefers_mps_when_cuda_unavailable,
test_current_device_functor_prefers_mps_when_cuda_init_fails, and
test_current_device_functor_falls_back_to_cpu_without_cuda_or_mps to depend on
that fixture and pass True/False as needed; keep the existing monkeypatch for
torch.cuda adjustments and continue calling
BaseKernelAdapter.get_current_device_functor() and asserting the returned
device.
src/transform/layout_inference.cc (1)

436-446: ⚡ Quick win

Comment is misleading; the guard may suppress real layout failures for non-accumulator fragment buffers.

MetalFragmentToSimdgroup runs before LayoutInference, converting GEMM accumulator local.fragment buffers to metal.simdgroup. Those buffers are therefore absent from use_list_ by the time this loop executes (since addToUseList only tracks local.fragment scope). Any local.fragment entry that does appear in use_list_ on Metal is a non-accumulator fragment buffer (e.g., one used in a T.Parallel loop).

The current guard silently skips the ICHECK for those non-accumulator buffers too, which could hide real layout-inference failures and make downstream codegen errors difficult to diagnose.

Consider adding a diagnostic for the unexpected case:

🛡️ Proposed improvement
-    // Check that all local.fragment buffers have inferred layouts.
-    // On Metal targets, fragment buffers used as GEMM accumulators are
-    // lowered to opaque simdgroup matrices, so they have no explicit
-    // thread-level layout and can be safely skipped.
+    // Check that all local.fragment buffers have inferred layouts.
+    // On Metal targets, MetalFragmentToSimdgroup converts GEMM accumulator
+    // local.fragment buffers to metal.simdgroup *before* this pass runs, so
+    // those buffers are already absent from use_list_.  Any local.fragment
+    // that remains in use_list_ on Metal is a non-accumulator buffer; emit a
+    // warning instead of a hard ICHECK so Metal kernels that genuinely don't
+    // need thread-level fragment layouts can still compile.
     for (const auto &[buffer, _] : use_list_) {
       if (IsFragmentBuffer(buffer)) {
-        if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
-          ICHECK(false) << "The layout for fragment " << buffer
-                        << " can not be inferred correctly.";
+        if (layout_map.count(buffer) == 0) {
+          if (TargetIsMetal(target_)) {
+            LOG(WARNING) << "[LayoutInference] Metal: fragment buffer "
+                         << buffer << " has no inferred layout; "
+                         << "verify it is handled by the Metal codegen path.";
+          } else {
+            ICHECK(false) << "The layout for fragment " << buffer
+                          << " can not be inferred correctly.";
+          }
         }
       }
     }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/transform/layout_inference.cc` around lines 436 - 446, The loop currently
skips layout presence checks for all fragment buffers on Metal, which can hide
real failures; modify the loop that iterates use_list_ so that for every
IsFragmentBuffer(buffer) you verify layout_map.count(buffer) != 0 and trigger an
ICHECK (or a failing diagnostic) when missing; for the Metal target make the
error message explicit (referencing TargetIsMetal(target_), IsFragmentBuffer,
layout_map, use_list_) stating this is an unexpected non-accumulator fragment on
Metal and suggesting investigation of MetalFragmentToSimdgroup/LayoutInference
ordering (so replace the existing TargetIsMetal guard with a straight check and
a clearer error message).
tilelang/jit/adapter/torch/metal.py (1)

56-57: kernel_only parameter is silently ignored.

The method always returns self.kernel_global_source or "" regardless of the kernel_only flag. While Metal may not distinguish between kernel-only and full source (having no separate preamble), this deviates from the BaseKernelAdapter contract where kernel_only=False should return additional preamble/includes.

Document the lack of support or add conditional logic to match the interface:

💡 Proposed fix
    def get_kernel_source(self, kernel_only: bool = True) -> str:
+        """Return the Metal kernel source. `kernel_only` is unused on Metal
+        (there is no separate host-side preamble)."""
         return self.kernel_global_source or ""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/jit/adapter/torch/metal.py` around lines 56 - 57, The
get_kernel_source method currently ignores the kernel_only parameter and always
returns self.kernel_global_source or "", so update get_kernel_source in the
Metal adapter (method name: get_kernel_source) to honor the BaseKernelAdapter
contract: if kernel_only is True return only the kernel body (e.g.
self.kernel_source or self.kernel_global_source_kernel_part) and if kernel_only
is False return the full source including any preamble/includes (or, if Metal
genuinely has no preamble support, explicitly document that by returning the
same string but add a comment/docstring clarifying that kernel_only is not
applicable); ensure you reference and use the existing attributes
(self.kernel_global_source, self.kernel_source, or similar) to construct the
conditional return and update the docstring to state the behavior when
kernel_only is unsupported.
tilelang/intrinsics/metal_macro_generator.py (1)

111-111: 💤 Low value

Prefer tuple unpacking over concatenation.

The static analysis suggests using (*leading, row_idx, col_idx) instead of leading + (row_idx, col_idx) for cleaner tuple construction.

♻️ Proposed fix
-                ptr = T.access_ptr(buffer[leading + (row_idx, col_idx)], "r")
+                ptr = T.access_ptr(buffer[(*leading, row_idx, col_idx)], "r")

Apply similar changes at lines 145 and 203.

Also applies to: 145-145, 203-203

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/metal_macro_generator.py` at line 111, Replace tuple
concatenation with tuple unpacking when indexing the buffer: change uses like
buffer[leading + (row_idx, col_idx)] (seen when assigning ptr via T.access_ptr)
to buffer[(*leading, row_idx, col_idx)] so the tuple is built cleanly; update
the same pattern at the other occurrences noted (the other buffer index usages
around the references at lines 145 and 203) to use (*leading, row_idx, col_idx)
as well.
testing/python/metal/test_metal_gemm_v2.py (1)

29-29: ⚡ Quick win

Add a num_stages > 1 variant here.

Line 29 hardcodes num_stages=0, so this test never exercises the pipelined path that the PR is fixing. A num_stages=2/3 case would cover the stage-dimension lowering regression directly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` at line 29, The test currently
only uses a pipelined loop with num_stages=0 and therefore never exercises the
pipelined lowering path; update the test in test_metal_gemm_v2 by adding a
variant that iterates the same loop with T.Pipelined(..., num_stages=2) (and
optionally num_stages=3) so the code path for staged pipelining is exercised.
Concretely, either parametrize the test to run with num_stages in [0,2] (or
[0,2,3]) or duplicate the loop that uses T.Pipelined(T.ceildiv(K, block_K),
num_stages=0) and change the duplicate to num_stages=2 so the stage-dimension
lowering regression is covered; keep the same surrounding setup and assertions.
benchmark/matmul_metal/benchmark_matmul_metal.py (1)

34-38: ⚡ Quick win

This benchmark never measures the pipelined path.

The inner loop is fixed at num_stages=0, so the script cannot quantify the stage-aware 3D-buffer fix or the num_stages>1 speedups called out in the PR. Please make num_stages configurable and include a staged default/sweep value so the benchmark reflects the feature being added here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmark/matmul_metal/benchmark_matmul_metal.py` around lines 34 - 38, The
pipelined inner loop currently hardcodes num_stages=0 in the T.Pipelined call
(loop variable ko) so the benchmark never measures staged execution; make
num_stages configurable in benchmark_matmul_metal.py (e.g., a CLI flag or config
variable) and use that value in the T.Pipelined(..., num_stages=...) call,
provide a sensible default that includes staged runs (e.g., 0 and >1) and add an
optional sweep (iterate over multiple num_stages values) so the benchmark can
run and report results for both non-pipelined and staged settings; update any
driver/runner code that invokes the benchmark to accept and pass this parameter.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 28-34: The test currently only exercises the pipeline with
num_stages=0; add a new regression case that uses T.Pipelined(..., num_stages=2)
or num_stages=3 and uses stage-indexed shared-memory buffers so the 3D
shared-buffer lowering is exercised. Concretely, duplicate the existing loop
variant that calls T.Pipelined(T.ceildiv(K, block_K), num_stages=0) but set
num_stages=2 (or 3), allocate the shared buffers with an extra leading stage
dimension (i.e., stage-indexed shared buffers) and update the T.copy/T.gemm
calls to read/write via the current stage index so copies target
A_shared[stage,...] and B_shared[stage,...] and the compute reads those
stage-indexed buffers; this will force the stage>0 compile path to run and catch
regressions in the shared-buffer lowering.

In `@tilelang/engine/lower.py`:
- Around line 263-264: The device_codegen_without_compile() path is still
invoking the compiled Metal builder (target.build.tilelang_metal) which triggers
tvm_callback_metal_compile in src/target/codegen_metal.cc; change the lower.py
branch so that when enable_device_compile is False it calls a new, compile-free
entry point (e.g., target.build.tilelang_metal_without_compile) or passes an
explicit compile flag through the Metal builder API; implement the new symbol
target.build.tilelang_metal_without_compile (or accept the flag) in the Metal
codegen and ensure src/target/codegen_metal.cc avoids calling
tvm_callback_metal_compile when the compile flag is false so tilelang.lower(...,
enable_device_compile=False) never requires the Metal compiler.

In `@tilelang/transform/metal_fragment_to_simdgroup.py`:
- Around line 57-65: The _remap_buffer function rebuilds a buffer but omits
important metadata (strides, elem_offset, buffer_type, axis_separators, etc.),
which changes indexing/layout; update the tir.decl_buffer call in _remap_buffer
to pass through all original buffer fields from buf (e.g., strides, elem_offset,
buffer_type, axis_separators and any other non-default attributes) along with
shape/dtype/name/data=new_data/data_alignment/offset_factor so the new buffer
preserves the original layout and semantic information.
- Around line 82-110: The pre-order rewrite must rebuild the tir.Block
completely and not leave metadata pointing at old buffers: apply
tir.stmt_functor.substitute (using var_map) to stmt.reads, stmt.writes and
stmt.init (in addition to stmt.body) when constructing new_block/BlockRealize so
BufferRegion and init expressions reference the new variables (e.g., when
creating new_block in the branch that handles tir.Block). Also avoid
short-circuiting traversal by returning None from the pre-order callback after
recording the replacement (or otherwise ensure the transform continues into
nested binders) so nested tir.Allocate nodes under the block are visited and
their buffer_var bindings are updated; reference tir.Block, tir.BlockRealize,
tir.Allocate, var_map, and tir.stmt_functor.ir_transform in your changes.

---

Outside diff comments:
In `@tilelang/jit/adapter/base.py`:
- Around line 69-75: The docstring for get_current_device_functor is outdated
about fallback behavior; update it to document the actual device selection order
(CUDA → MPS → CPU) and mention that when CUDA is unavailable but MPS is
available the callable yields torch.device("mps"), otherwise
torch.device("cpu"); keep the rest of the description about returning a callable
that fetches the current torch device at call time.
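The selection order described above (CUDA → MPS → CPU) can be sketched as a small pure-Python helper; the function and parameter names here are illustrative, not the adapter's actual API:

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the device string following the CUDA -> MPS -> CPU order."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```

The real adapter returns a callable that evaluates availability at call time rather than once at construction, so a later CUDA initialization failure still falls through to MPS or CPU.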

---

Nitpick comments:
In `@benchmark/matmul_metal/benchmark_matmul_metal.py`:
- Around line 34-38: The pipelined inner loop currently hardcodes num_stages=0
in the T.Pipelined call (loop variable ko) so the benchmark never measures
staged execution; make num_stages configurable in benchmark_matmul_metal.py
(e.g., a CLI flag or config variable) and use that value in the T.Pipelined(...,
num_stages=...) call, provide a sensible default that includes staged runs
(e.g., 0 and >1) and add an optional sweep (iterate over multiple num_stages
values) so the benchmark can run and report results for both non-pipelined and
staged settings; update any driver/runner code that invokes the benchmark to
accept and pass this parameter.
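A minimal sketch of the suggested sweep flag, assuming a hypothetical `--num-stages` option name (the benchmark may spell it differently):

```python
import argparse


def parse_stage_args(argv):
    """Parse a --num-stages sweep for the benchmark driver.

    The flag name and default are illustrative; the default covers both
    a non-pipelined run (0) and a staged run (2), as the comment suggests.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-stages",
        type=int,
        nargs="+",
        default=[0, 2],
    )
    return parser.parse_args(argv).num_stages
```

The driver would then loop over the returned list, passing each value into `T.Pipelined(..., num_stages=...)` and reporting results per setting.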

In `@pyproject.toml`:
- Around line 33-34: Add a brief inline comment explaining the Darwin cap on the
second apache-tvm-ffi entry: locate the dependency string "apache-tvm-ffi<0.1.8;
platform_system == 'Darwin'" and append a short comment (e.g., "# 0.1.8+
incompatible on macOS") to document why the macOS-specific upper bound exists,
mirroring the explanatory style used for the base constraint and the torch
entries.

In `@requirements-dev.txt`:
- Line 4: Add a short inline comment next to the apache-tvm-ffi<0.1.8;
platform_system == 'Darwin' entry in requirements-dev.txt explaining the macOS
upper bound (e.g., "# 0.1.8+ breaks on macOS; see issue tilelang#XXXX or link to
bug/PR"), so future readers understand why this Darwin-specific constraint
exists; reference the package name apache-tvm-ffi when adding the comment.

In `@src/transform/layout_inference.cc`:
- Around line 436-446: The loop currently skips layout presence checks for all
fragment buffers on Metal, which can hide real failures; modify the loop that
iterates use_list_ so that for every IsFragmentBuffer(buffer) you verify
layout_map.count(buffer) != 0 and trigger an ICHECK (or a failing diagnostic)
when missing; for the Metal target make the error message explicit (referencing
TargetIsMetal(target_), IsFragmentBuffer, layout_map, use_list_) stating this is
an unexpected non-accumulator fragment on Metal and suggesting investigation of
MetalFragmentToSimdgroup/LayoutInference ordering (so replace the existing
TargetIsMetal guard with a straight check and a clearer error message).

In `@testing/python/jit/test_tilelang_jit_adapter_mps.py`:
- Around line 10-47: These tests duplicate the same torch.backends.mps
monkeypatch guard; create a parametrized pytest fixture (e.g., mps_availability)
that accepts a boolean for is_available and performs the getattr/monkeypatch
logic, then update
test_current_device_functor_prefers_mps_when_cuda_unavailable,
test_current_device_functor_prefers_mps_when_cuda_init_fails, and
test_current_device_functor_falls_back_to_cpu_without_cuda_or_mps to depend on
that fixture and pass True/False as needed; keep the existing monkeypatch for
torch.cuda adjustments and continue calling
BaseKernelAdapter.get_current_device_functor() and asserting the returned
device.

In `@testing/python/metal/test_metal_gemm_v2.py`:
- Line 29: The test currently only uses a pipelined loop with num_stages=0 and
therefore never exercises the pipelined lowering path; update the test in
test_metal_gemm_v2 by adding a variant that iterates the same loop with
T.Pipelined(..., num_stages=2) (and optionally num_stages=3) so the code path
for staged pipelining is exercised. Concretely, either parametrize the test to
run with num_stages in [0,2] (or [0,2,3]) or duplicate the loop that uses
T.Pipelined(T.ceildiv(K, block_K), num_stages=0) and change the duplicate to
num_stages=2 so the stage-dimension lowering regression is covered; keep the
same surrounding setup and assertions.

In `@tilelang/intrinsics/metal_macro_generator.py`:
- Line 111: Replace tuple concatenation with tuple unpacking when indexing the
buffer: change uses like buffer[leading + (row_idx, col_idx)] (seen when
assigning ptr via T.access_ptr) to buffer[(*leading, row_idx, col_idx)] so the
tuple is built cleanly; update the same pattern at the other occurrences noted
(the other buffer index usages around the references at lines 145 and 203) to
use (*leading, row_idx, col_idx) as well.
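The two index-building forms are equivalent; a minimal illustration of why the unpacked form is preferred:

```python
leading = (2,)          # e.g. the pipeline-stage index
row_idx, col_idx = 3, 4

# Concatenation and unpacking produce identical index tuples;
# the unpacked form reads more cleanly at the call site.
concat = leading + (row_idx, col_idx)
unpacked = (*leading, row_idx, col_idx)
assert concat == unpacked == (2, 3, 4)
```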

In `@tilelang/jit/adapter/torch/metal.py`:
- Around line 56-57: The get_kernel_source method currently ignores the
kernel_only parameter and always returns self.kernel_global_source or "", so
update get_kernel_source in the Metal adapter (method name: get_kernel_source)
to honor the BaseKernelAdapter contract: if kernel_only is True return only the
kernel body (e.g. self.kernel_source or self.kernel_global_source_kernel_part)
and if kernel_only is False return the full source including any
preamble/includes (or, if Metal genuinely has no preamble support, explicitly
document that by returning the same string but add a comment/docstring
clarifying that kernel_only is not applicable); ensure you reference and use the
existing attributes (self.kernel_global_source, self.kernel_source, or similar)
to construct the conditional return and update the docstring to state the
behavior when kernel_only is unsupported.
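The conditional return being asked for can be sketched as follows; the class and attribute names are placeholders, not the Metal adapter's real fields:

```python
class MetalAdapterSketch:
    """Illustrative shape of a kernel_only-aware get_kernel_source."""

    def __init__(self, kernel_body: str, preamble: str):
        self.kernel_source = kernel_body                     # kernel body only
        self.kernel_global_source = preamble + kernel_body   # full source

    def get_kernel_source(self, kernel_only: bool = False) -> str:
        # Honor the BaseKernelAdapter contract: body only vs. full source.
        if kernel_only:
            return self.kernel_source or ""
        return self.kernel_global_source or ""
```

If Metal genuinely has no preamble to strip, the two branches collapse to the same string, but the docstring should then say so explicitly.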
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3abbf2a1-a65f-4d9f-ab27-98881f9770bb

📥 Commits

Reviewing files that changed from the base of the PR and between d135bd1 and e99df5b.

📒 Files selected for processing (37)
  • benchmark/matmul_metal/benchmark_matmul_metal.py
  • pyproject.toml
  • requirements-dev.txt
  • requirements.txt
  • src/backend/metal/CMakeLists.txt
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/fill.cc
  • src/op/gemm.cc
  • src/op/gemm.h
  • src/op/parallel.cc
  • src/op/utils.h
  • src/target/codegen_metal.cc
  • src/target/codegen_metal.h
  • src/transform/layout_inference.cc
  • src/transform/lower_device_storage_access_info.cc
  • testing/python/jit/test_tilelang_jit_adapter_mps.py
  • testing/python/metal/metal_internal_runtime_coverage.md
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • testing/python/metal/test_metal_internal_scaffolding.py
  • testing/python/metal/test_metal_local_var.py
  • testing/python/metal/test_metal_simdgroup_store.py
  • tilelang/engine/lower.py
  • tilelang/engine/phase.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/base.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/tileop/metal_gdn.py
  • tilelang/tileop/metal_quant.py
  • tilelang/tileop/metal_simdgroup.py
  • tilelang/transform/decouple_type_cast.py
  • tilelang/transform/metal_fragment_to_simdgroup.py
  • tilelang/utils/language.py

Comment on lines +28 to +34
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)

T.gemm(A_shared, B_shared, C_local)

T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2)

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Add a staged regression case here.

This helper hardcodes num_stages=0, so none of the tests in this file exercise the 3D shared-buffer lowering that this PR is fixing. Please add at least one T.Pipelined(..., num_stages=2/3) case with stage-indexed shared memory so the old stage>0 compile/overwrite failure cannot regress silently.

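Why the staged regression case suggested above matters can be shown with a small pure-Python model of the stage rotation (no tilelang involved): the producer prefetches tile ko+1 while the consumer computes on tile ko, and with a single stage the prefetch overwrites the tile still being consumed.

```python
def run_pipeline(k_iters: int, num_stages: int):
    """Model a software pipeline with stage-indexed shared buffers.

    Returns the list of tile ids actually consumed. With num_stages=1
    the prefetch clobbers the tile the consumer is still reading.
    """
    slots = max(num_stages, 1)
    shared = [None] * slots
    consumed = []
    shared[0] = 0  # prologue: prefetch the first tile
    for ko in range(k_iters):
        if ko + 1 < k_iters:
            # prefetch the next tile into its stage slot
            # (overlaps the compute below)
            shared[(ko + 1) % slots] = ko + 1
        consumed.append(shared[ko % slots])
    return consumed
```

With `num_stages=2` every iteration consumes the correct tile; with one stage, iteration 0 already reads the overwritten value — the "silent overwrite at stage>0" failure mode the PR fixes.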

Comment thread tilelang/engine/lower.py
Comment on lines 263 to +264
elif target.kind.name == "metal":
device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
device_mod = tvm.ffi.get_global_func("target.build.tilelang_metal")(device_mod, target)

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

device_codegen_without_compile() still compiles Metal kernels.

This branch now calls the same builder as the compiled path, and src/target/codegen_metal.cc:547-574 still invokes tvm_callback_metal_compile whenever that callback is registered. So tilelang.lower(..., enable_device_compile=False) can still require a working Metal compiler on Apple hosts, which breaks the no-compile contract and makes source-only tests/environment behavior host-dependent. Please add a distinct target.build.tilelang_metal_without_compile entry point or thread an explicit compile flag through the builder.

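One way to thread the flag, sketched in pure Python: dispatch on `enable_device_compile` to pick the builder symbol. `target.build.tilelang_metal_without_compile` is the hypothetical entry-point name proposed in the review comment, not an existing TVM global.

```python
def select_metal_builder(builders: dict, enable_device_compile: bool) -> str:
    """Pick the Metal build entry point based on the compile flag.

    builders models the registry of global functions; in real code this
    would be looked up via tvm.ffi.get_global_func.
    """
    key = ("target.build.tilelang_metal"
           if enable_device_compile
           else "target.build.tilelang_metal_without_compile")
    if key not in builders:
        raise KeyError(f"Metal builder not registered: {key}")
    return key
```

The compile-free symbol would run the same codegen but skip the `tvm_callback_metal_compile` invocation, so source-only lowering never needs a working Metal toolchain.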

Comment on lines +57 to +65
return tir.decl_buffer(
buf.shape,
buf.dtype,
buf.name,
data=new_data,
scope="metal.simdgroup",
data_alignment=buf.data_alignment,
offset_factor=buf.offset_factor,
)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find and examine the file
find . -name "metal_fragment_to_simdgroup.py" -type f

Repository: tile-ai/tilelang

Length of output: 112


🏁 Script executed:

# Also search for the _remap_buffer function
rg "_remap_buffer" -A 20 --context 5

Repository: tile-ai/tilelang

Length of output: 1831


🏁 Script executed:

# Check the tir.decl_buffer API to see what parameters it accepts
rg "def decl_buffer" -A 30

Repository: tile-ai/tilelang

Length of output: 1669


🏁 Script executed:

# Read the full _remap_buffer function and surrounding context
sed -n '45,80p' tilelang/transform/metal_fragment_to_simdgroup.py

Repository: tile-ai/tilelang

Length of output: 1118


🏁 Script executed:

# Also check the Buffer class definition to see what fields it has
rg "class Buffer|def strides|def elem_offset|def buffer_type|def axis_separators" -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 4441


🏁 Script executed:

# Search for Buffer class definition
rg "class Buffer" -A 30 tilelang/

Repository: tile-ai/tilelang

Length of output: 3579


🏁 Script executed:

# Also check TVM's buffer definition to understand Buffer object
rg "\.strides|\.elem_offset|\.buffer_type|\.axis_separators" tilelang/transform/metal_fragment_to_simdgroup.py -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check how buf properties are used elsewhere in the file
rg "buf\." tilelang/transform/metal_fragment_to_simdgroup.py

Repository: tile-ai/tilelang

Length of output: 269


🏁 Script executed:

# Search for how buffer strides and elem_offset are used in the codebase
rg "\.strides\b|\.elem_offset\b|\.buffer_type\b|\.axis_separators\b" --type py -B 2 -A 2 | head -100

Repository: tile-ai/tilelang

Length of output: 7063


🏁 Script executed:

# Check if there are any uses of buffer properties in layout inference or similar passes
rg "layout_inference|LayoutInference" tilelang/ --type py

Repository: tile-ai/tilelang

Length of output: 793


🏁 Script executed:

# Look at what happens when a buffer is created to see if these fields have defaults
rg "def buffer\(" tilelang/ -A 50 | head -150

Repository: tile-ai/tilelang

Length of output: 5958


🏁 Script executed:

# Check if buf is a tir.Buffer object and verify it has readable strides/elem_offset/buffer_type/axis_separators properties
rg "from.*tir import|import.*tir" tilelang/transform/metal_fragment_to_simdgroup.py -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 206


🏁 Script executed:

# Look for any Buffer attribute access to confirm these are real properties
cat tilelang/language/ast/ir.py | grep -A 100 "def buffer(" | head -120

Repository: tile-ai/tilelang

Length of output: 2798


Preserve all buffer metadata when changing scope.

_remap_buffer() currently rebuilds the buffer with only shape/dtype/name/data_alignment/offset_factor. If the original buffer carries non-default strides, elem_offset, buffer_type, or axis_separators, this pass silently changes indexing semantics and buffer layout information. The strides field in particular is actively used downstream for element offset computation in buffer indexing operations.

Proposed fix
     return tir.decl_buffer(
         buf.shape,
         buf.dtype,
         buf.name,
         data=new_data,
+        strides=buf.strides,
+        elem_offset=buf.elem_offset,
         scope="metal.simdgroup",
         data_alignment=buf.data_alignment,
         offset_factor=buf.offset_factor,
+        buffer_type=buf.buffer_type,
+        axis_separators=buf.axis_separators,
     )

Comment on lines +82 to +110
new_body = tir.stmt_functor.substitute(stmt.body, var_map)
new_block = tir.Block(
stmt.iter_vars,
stmt.reads,
stmt.writes,
stmt.name_hint,
new_body,
stmt.init,
new_alloc_bufs,
stmt.match_buffers,
stmt.annotations,
)
return (
tir.BlockRealize(
stmt.iter_vars,
tir.const(True, "bool"),
new_block,
)
if False
else new_block
)
elif isinstance(stmt, tir.Allocate):
new_var = var_map.get(stmt.buffer_var, None)
if new_var is not None:
new_body = tir.stmt_functor.substitute(stmt.body, var_map)
return tir.Allocate(new_var, stmt.dtype, stmt.extents, stmt.condition, new_body, stmt.annotations)
return None

return tir.stmt_functor.ir_transform(body, _pre_order, None, ["tir.Block", "tir.Allocate"])

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -name "metal_fragment_to_simdgroup.py" -type f

Repository: tile-ai/tilelang

Length of output: 112


🏁 Script executed:

wc -l ./tilelang/transform/metal_fragment_to_simdgroup.py

Repository: tile-ai/tilelang

Length of output: 116


🏁 Script executed:

cat -n ./tilelang/transform/metal_fragment_to_simdgroup.py

Repository: tile-ai/tilelang

Length of output: 5319


🏁 Script executed:

# Search for similar patterns of ir_transform usage in the codebase
rg "ir_transform" --type py -B 3 -A 3 | head -100

Repository: tile-ai/tilelang

Length of output: 4632


🏁 Script executed:

# Search for tests related to metal_fragment_to_simdgroup
find . -name "*test*metal*" -type f 2>/dev/null

Repository: tile-ai/tilelang

Length of output: 440


🏁 Script executed:

# Search for references to metal_fragment_to_simdgroup in tests
rg "metal_fragment_to_simdgroup|MetalFragmentToSimdgroup" --type py

Repository: tile-ai/tilelang

Length of output: 549


🏁 Script executed:

# Search for how ir_transform is used in add_bufstore_wrapper to understand the pattern
cat -n ./tilelang/transform/add_bufstore_wrapper.py | head -100

Repository: tile-ai/tilelang

Length of output: 4199


🏁 Script executed:

# Look for TVM documentation or comments about ir_transform behavior
rg "pre.*order|post.*order" --type py tilelang/transform/ -A 2 -B 2 | head -80

Repository: tile-ai/tilelang

Length of output: 3565


🏁 Script executed:

# Check how reads/writes are handled in test_tilelang_transform_legalize_safe_memory_access.py
cat -n ./testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py | head -150

Repository: tile-ai/tilelang

Length of output: 5914


🏁 Script executed:

# Search for references to reads/writes substitution in TVM-related code
rg "reads|writes|init" tilelang/transform/metal_fragment_to_simdgroup.py

Repository: tile-ai/tilelang

Length of output: 156


🏁 Script executed:

# Search for how reads/writes are used in Block creation in TVM/tilelang
rg "BufferRegion|reads.*=|writes.*=" --type py tilelang/ | head -50

Repository: tile-ai/tilelang

Length of output: 6373


🏁 Script executed:

# Look for how TVM's ir_transform processes returns from pre-order callbacks
# Check if there are any comments or documentation about pre-order callback behavior
rg "ir_transform" tilelang/transform/ -B 5 -A 10 | grep -A 15 "def _pre_order\|def _pre_visit"

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check how ir_transform handles pre-order returns
cat -n ./tilelang/transform/add_bufstore_wrapper.py | sed -n '130,180p'

Repository: tile-ai/tilelang

Length of output: 1377


🏁 Script executed:

# Check if there are any test files for metal_fragment_to_simdgroup
rg "simdgroup|fragment" testing/python/metal/ --type py | head -30

Repository: tile-ai/tilelang

Length of output: 4027


🏁 Script executed:

# Let's look at the test file to understand what the transformation is supposed to do
cat -n ./testing/python/metal/test_metal_simdgroup_store.py | head -80

Repository: tile-ai/tilelang

Length of output: 4069


🏁 Script executed:

# Let's understand the structure better - check if reads/writes contain buffer references
rg "BufferRegion" tilelang/transform/ --type py -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 42


Rewrite the whole tir.Block consistently, including reads, writes, and init fields.

The block body is substituted at line 82, but the reconstructed block keeps the old reads, writes, and init fields without applying the variable substitution. Since these fields typically contain BufferRegion objects that reference buffer variables, this creates an IR inconsistency where the block's body uses the new metal.simdgroup variables while metadata fields reference the old local.fragment variables.

Additionally, when a modified Block is returned from the pre-order callback, nested tir.Allocate binders under it may not be visited, leaving old variable bindings in place while expressions already use the new variables. This can cause subsequent passes to encounter an internally inconsistent IR state.


Copilot AI left a comment
Pull request overview

This PR extends TileLang’s Metal backend to correctly handle software-pipelined shared buffers (T.Pipelined(..., num_stages>1)) by threading the stage/version dimension through Metal T.access_ptr-based address generation, and adds broader Metal simdgroup support (GEMM lowering/codegen/copy/fill) plus tests/benchmarks around the new paths.

Changes:

  • Fix Metal simdgroup load/store address calculation to include leading (e.g., pipeline stage) indices when shared buffers become 3D under software pipelining.
  • Introduce/extend Metal simdgroup infrastructure: fragment→simdgroup rewrite pass, simdgroup-aware GEMM implementation, simdgroup fill/copy lowering, and Metal codegen plumbing.
  • Add Metal-focused tests (codegen + hardware-gated correctness) and internal scaffolding probes; adjust Torch JIT adapter to prefer MPS when CUDA is unavailable.

Reviewed changes

Copilot reviewed 36 out of 37 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
tilelang/utils/language.py Adds is_metal_simdgroup scope predicate.
tilelang/transform/metal_fragment_to_simdgroup.py New pass to rewrite GEMM accumulators from local.fragment to metal.simdgroup on Metal.
tilelang/transform/decouple_type_cast.py Treats metal.simdgroup as a local/register buffer for cast decoupling.
tilelang/tileop/metal_simdgroup.py Internal simdgroup RegisterTile helpers/macros (alloc/load/store/mma, plus scalar row ops).
tilelang/tileop/metal_quant.py Internal packed-uint8 quant decode helpers for Metal probes.
tilelang/tileop/metal_gdn.py Internal GDN/attention-style simdgroup tile macros.
tilelang/tileop/gemm/inst.py Adds METAL_SIMDGROUP GEMM instruction enum value.
tilelang/tileop/gemm/gemm_metal.py Implements Metal GEMM lowering via MPSIntrinEmitter simdgroup operations.
tilelang/tileop/gemm/__init__.py Selects Metal GEMM implementation when target is Metal.
tilelang/jit/adapter/torch/metal.py Exposes kernel source for Torch Metal adapter.
tilelang/jit/adapter/base.py Updates device selection to prefer MPS when CUDA is unavailable/initialization fails.
tilelang/intrinsics/metal_macro_generator.py Metal macro emitter updated to preserve leading indices (pipeline stage) in T.access_ptr.
tilelang/engine/phase.py Inserts MetalFragmentToSimdgroup before layout inference.
tilelang/engine/lower.py Switches Metal build hook to target.build.tilelang_metal.
testing/python/metal/test_metal_simdgroup_store.py Tests simdgroup-accumulation path and direct simdgroup_store emission.
testing/python/metal/test_metal_local_var.py Adds Metal codegen/runtime tests for local.var scalar lowering.
testing/python/metal/test_metal_internal_scaffolding.py Adds internal-only Metal probes for simdgroup, quant decode, and GDN-style kernels.
testing/python/metal/test_metal_gemm_v2.py Hardware-gated correctness tests for Metal T.gemm (gemm_v2).
testing/python/metal/test_metal_gemm_v2_linux.py Cross-platform Metal source-generation tests for T.gemm (gemm_v2).
testing/python/metal/metal_internal_runtime_coverage.md Documents Metal internal runtime/source-boundary coverage.
testing/python/jit/test_tilelang_jit_adapter_mps.py Tests new MPS-preferred device selection behavior.
src/transform/lower_device_storage_access_info.cc Excludes fragment scope tag from memory-info enforcement.
src/transform/layout_inference.cc Skips fragment-layout completeness check for Metal targets.
src/target/codegen_metal.h Declares CodeGenTileLangMetal code generator.
src/target/codegen_metal.cc Implements Metal source codegen, simdgroup matrix lowering, and registers target.build.tilelang_metal.
src/op/utils.h Adds IsSIMDGroupBuffer / IsRegisterBuffer helpers.
src/op/parallel.cc Makes fragment-layout lookup conditional to avoid missing-layout crashes.
src/op/gemm.h Adds Metal simdgroup GEMM instruction kind to C++ enum.
src/op/gemm.cc Returns Metal simdgroup instruction for Metal targets; adjusts warp policy constants for Metal.
src/op/fill.cc Adds simdgroup buffer fill lowering via make_filled_simdgroup_matrix.
src/op/copy.h Adds Metal simdgroup copy instruction and lowering hooks.
src/op/copy.cc Adds Metal simdgroup store lowering (LowerSIMDGroupCopy) and instruction selection.
src/backend/metal/CMakeLists.txt Always compiles Metal codegen source for cross-platform codegen-only mode.
requirements.txt Adds Darwin-only upper bound on apache-tvm-ffi.
requirements-dev.txt Adds Darwin-only upper bound on apache-tvm-ffi for dev installs.
pyproject.toml Adds Darwin-only upper bound on apache-tvm-ffi for packaging.
benchmark/matmul_metal/benchmark_matmul_metal.py Adds a Metal GEMM benchmark script.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/op/copy.cc
Comment on lines +1097 to +1110
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
if (num_warps % m != 0)
continue;
int n = num_warps / m;
if (n > max_n)
continue;
if (M % (m * kMPerWarp) != 0 || N % (n * kNPerWarp) != 0)
continue;
float m_per = static_cast<float>(M) / (m * kMPerWarp);
float n_per = static_cast<float>(N) / (n * kNPerWarp);
float score = std::abs(m_per / n_per - ideal);
if (score < best_score) {
Comment thread src/op/copy.cc
Comment on lines +1089 to +1130
int kMPerWarp = 8;
int kNPerWarp = 8;
int m_warp = 1, n_warp = num_warps;
int max_m = M / kMPerWarp;
int max_n = N / kNPerWarp;
if (max_m <= 0 || max_n <= 0) {
return LowerNormalCopy(T, analyzer);
}
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
if (num_warps % m != 0)
continue;
int n = num_warps / m;
if (n > max_n)
continue;
if (M % (m * kMPerWarp) != 0 || N % (n * kNPerWarp) != 0)
continue;
float m_per = static_cast<float>(M) / (m * kMPerWarp);
float n_per = static_cast<float>(N) / (n * kNPerWarp);
float score = std::abs(m_per / n_per - ideal);
if (score < best_score) {
best_score = score;
m_warp = m;
n_warp = n;
}
}

if (best_score == std::numeric_limits<float>::max() || M < m_warp * 8 ||
N < n_warp * 8) {
return LowerNormalCopy(T, analyzer);
}
int warp_row_tiles = M / m_warp / 8;
int warp_col_tiles = N / n_warp / 8;
if (warp_row_tiles <= 0 || warp_col_tiles <= 0 ||
warp_row_tiles * warp_col_tiles * 64 > total_elements) {
return LowerNormalCopy(T, analyzer);
}

PrimExpr warp_m = FloorMod(warp_id, m_warp);
PrimExpr warp_n = FloorDiv(warp_id, m_warp);

Comment on lines +68 to +103
def _rewrite_scope(body, var_map):
buf_map = {}

def _pre_order(stmt):
if isinstance(stmt, tir.Block):
new_alloc_bufs = []
changed = False
for buf in stmt.alloc_buffers:
new_buf = _remap_buffer(buf, var_map)
new_alloc_bufs.append(new_buf)
if not new_buf.same_as(buf):
buf_map[buf] = new_buf
changed = True
if changed:
new_body = tir.stmt_functor.substitute(stmt.body, var_map)
new_block = tir.Block(
stmt.iter_vars,
stmt.reads,
stmt.writes,
stmt.name_hint,
new_body,
stmt.init,
new_alloc_bufs,
stmt.match_buffers,
stmt.annotations,
)
return (
tir.BlockRealize(
stmt.iter_vars,
tir.const(True, "bool"),
new_block,
)
if False
else new_block
)
elif isinstance(stmt, tir.Allocate):
Comment on lines +436 to 446
// Check that all local.fragment buffers have inferred layouts.
// On Metal targets, fragment buffers used as GEMM accumulators are
// lowered to opaque simdgroup matrices, so they have no explicit
// thread-level layout and can be safely skipped.
for (const auto &[buffer, _] : use_list_) {
if (IsFragmentBuffer(buffer)) {
ICHECK_NE(layout_map.count(buffer), 0)
<< "The layout for fragment " << buffer
<< " can not be inferred correctly.";
if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
ICHECK(false) << "The layout for fragment " << buffer
<< " can not be inferred correctly.";
}
}
Comment on lines +173 to +174
if target_is_metal(target):
return GemmInst.METAL_SIMDGROUP
Comment on lines +64 to +89
@staticmethod
def _parse_buffer_2d(buf):
"""Extract (buffer, row_offset, col_offset, stride, leading_indices) from Buffer or BufferRegion.

``leading_indices`` carries the .min of any region dims preceding the
last two (row, col). The InjectSoftwarePipeline pass expands shared
buffers from 2D to 3D by inserting a "version" dim at position 0
(shape becomes ``[num_versions, M, N]`` and ``region[0]`` becomes the
per-iteration version index). Accessors must therefore include those
leading dims; otherwise the patched ``Buffer.__getitem__`` raises
``IndexError: Buffer X is 3-dimensional ... but 2 index(es) were provided``.
See ``inject_pipeline.cc::RewritePipelineBufferRegion`` and
``mma_macro_generator.py`` (CUDA) for the same pattern.
"""
if isinstance(buf, BufferRegion):
buffer = buf.buffer
off_row = buf.region[-2].min
off_col = buf.region[-1].min
leading = tuple(r.min for r in buf.region[:-2])
else:
buffer = buf
off_row = 0
off_col = 0
leading = tuple(0 for _ in buf.shape[:-2])
stride = buffer.strides[-2] if len(buffer.strides) == len(buffer.shape) else buffer.shape[-1]
return buffer, off_row, off_col, stride, leading
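What "folding the stage axis into the offset" means numerically: for a buffer of shape `(num_stages, M, N)`, each leading index contributes a whole-tile displacement on top of the trailing row/col arithmetic. A pure-Python illustration of the arithmetic (not tilelang code; the helper name is ours):

```python
def flat_offset(shape, leading, row, col, stride=None):
    """Fold leading (e.g. pipeline-stage) indices into a flat element offset.

    shape:   full buffer shape, e.g. (num_stages, M, N)
    leading: indices for every dim before the last two, e.g. (stage,)
    stride:  row stride; defaults to the innermost extent shape[-1]
    """
    if stride is None:
        stride = shape[-1]
    offset = 0
    # row-major fold of the leading dims
    for extent, idx in zip(shape[:-2], leading):
        offset = offset * extent + idx
    offset *= shape[-2] * shape[-1]
    return offset + row * stride + col
```

Dropping `leading` (the pre-fix behavior) would compute the stage-0 address for every pipeline stage, which is exactly the silent-overwrite failure described in the PR summary.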

Labels: none yet
Projects: none yet
4 participants