[Metal] thread stage dim through T.access_ptr for T.Pipelined num_stages>1 (#2141)
apstenku123 wants to merge 11 commits into tile-ai:main
Conversation
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations (simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon (M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: fork of the TVM Metal codegen into tilelang with simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared x shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: pass that rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment -> device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

This PR adds comprehensive Metal (Apple Silicon/MPS) target support to the TVM TileLang compiler, including a Metal code generator, GEMM operations via simdgroup matrix operations, register-tile abstractions, quantization/GDN attention macros, and extensive test coverage.

Changes

Metal Infrastructure, Device Selection & Code Generation
Metal GEMM & Warp-Level Operations
Metal Memory Operations & Register Tile Abstractions
Metal Advanced Optimization Features
Tests, Benchmarks & Validation
Sequence Diagram(s)

sequenceDiagram
participant User
participant LowerPhase as LowerAndLegalize<br/>(Phase)
participant MetalPass as MetalFragmentToSimdgroup<br/>(IR Pass)
participant LayoutInf as LayoutInference
participant LowerGEMM as GEMM Lowering<br/>(gemm_metal.py)
participant MPSEmitter as MPSIntrinEmitter
participant Codegen as CodeGenTileLangMetal
participant Metal as Metal Kernel<br/>(Source)
User->>LowerPhase: Lower TileLang Module
LowerPhase->>LowerPhase: Software Pipeline Injection
LowerPhase->>MetalPass: Apply MetalFragmentToSimdgroup
MetalPass->>MetalPass: Rewrite local.fragment → metal.simdgroup
LowerPhase->>LayoutInf: Run Layout Inference
LayoutInf->>LayoutInf: Skip fragment layout check for Metal
LowerPhase->>LowerGEMM: Lower GEMM Ops
LowerGEMM->>MPSEmitter: Instantiate with dtype/warp config
LowerGEMM->>MPSEmitter: ldmatrix_a, ldmatrix_b loads
LowerGEMM->>MPSEmitter: mma multiply-accumulate
LowerGEMM->>MPSEmitter: simdgroup_copy store result
LowerPhase->>Codegen: CodeGen to Metal
Codegen->>Codegen: Emit Metal kernel void
Codegen->>Codegen: Translate simdgroup_* ops to Metal
Codegen->>Codegen: Handle local.var scalars
Codegen->>Metal: Generate Metal source
Metal-->>User: Metal Kernel Ready
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Files documenting the actual PRs we just opened upstream:
- PR #1: ml-explore/mlx#3476 — from_dlpack Metal-aware consumer (against main, clean)
- PR #2: apache/tvm#19504 — TVM_METAL_STORAGE_MODE env opt-in (against main, clean)
- PR #3: tile-ai/tilelang#2139 — mixed-dtype T.gemm via scalar fallback (stacks on PR #2130)
- PR #4: tile-ai/tilelang#2140 — FP8-input T.gemm scalar fallback routing (stacks on PR #2130)
- PR #5: tile-ai/tilelang#2141 — T.Pipelined num_stages>1 3D buffer fix (stacks on PR #2130)
- PR #6: tile-ai/tilelang#2142 — T.fp8_scaled_matmul DSL intrinsic (stacks on PR #2130)

Deferred (companion PRs needed): tilelang_metal_fp8 and tilelang_metal_fp8_vector each touch both the tilelang supermodule and the TileLang/tvm vendored submodule. These need two PRs each — one to tile-ai/tilelang, one to TileLang/tvm — in a separate filing round.

PRs #3-#6 are independent of each other; each branches directly from jorgecurious/tilelang:metal-gemm-upstream-rebase HEAD 971c17b, so they can be reviewed in any order. They DO depend on the upstream 4-PR Apple Metal landing chain (#1869, #2118, #2121, #2130) merging first; if any of those land separately, ours can be retargeted at main.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/jit/adapter/base.py (1)
69-75: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update the stale docstring to reflect the MPS fallback.
The docstring states the function returns torch.device('cpu') when CUDA is unavailable, but after this change it now returns torch.device("mps") when MPS is available. The documented fallback chain should be CUDA → MPS → CPU.

📝 Proposed docstring update
- """Return a callable that yields Torch's current device.
-
- Similar to the stream functor, we capture a callable that, when called,
- fetches the current device according to PyTorch. On CPU or when CUDA is
- unavailable, returns ``torch.device('cpu')``.
- """
+ """Return a callable that yields Torch's current device.
+
+ Similar to the stream functor, we capture a callable that, when called,
+ fetches the current device according to PyTorch. Falls back through
+ CUDA → MPS → CPU: returns an MPS device functor when CUDA is unavailable
+ or CUDA initialisation fails and ``torch.backends.mps`` is available,
+ otherwise returns ``torch.device('cpu')``.
+ """

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/jit/adapter/base.py` around lines 69 - 75, The docstring for get_current_device_functor is outdated about fallback behavior; update it to document the actual device selection order (CUDA → MPS → CPU) and mention that when CUDA is unavailable but MPS is available the callable yields torch.device("mps"), otherwise torch.device("cpu"); keep the rest of the description about returning a callable that fetches the current torch device at call time.
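The CUDA → MPS → CPU chain described above can be sketched without torch at all; `pick_device` and its three probe callables below are illustrative stand-ins (not the adapter's real API) for `torch.cuda.is_available`, `torch.cuda._lazy_init`, and `torch.backends.mps.is_available`:

```python
def pick_device(cuda_available, cuda_init, mps_available):
    """Return a device name following the CUDA -> MPS -> CPU fallback chain.

    Each argument is a zero-arg callable standing in for the corresponding
    torch probe; cuda_init may raise RuntimeError to model failed CUDA init.
    """
    if cuda_available():
        try:
            cuda_init()  # CUDA can report available yet still fail to initialise
            return "cuda"
        except RuntimeError:
            pass  # fall through to the MPS probe
    if mps_available():
        return "mps"
    return "cpu"


# The three branches of the chain:
assert pick_device(lambda: True, lambda: None, lambda: False) == "cuda"
assert pick_device(lambda: False, lambda: None, lambda: True) == "mps"
assert pick_device(lambda: False, lambda: None, lambda: False) == "cpu"
```

The try/except around the init probe mirrors why the tests patch `torch.cuda._lazy_init` separately: availability and successful initialisation are distinct failure points.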
🧹 Nitpick comments (8)
requirements-dev.txt (1)
4-4: 💤 Low value
Consider adding an inline comment explaining the Darwin cap.

The <0.1.8 bound is non-obvious without context. A short comment (e.g., # 0.1.8+ breaks on macOS; see tilelang#XXXX) would help future readers understand why this constraint exists, especially since it sits on a separate line from the base range.

💡 Suggested change
-apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
+apache-tvm-ffi<0.1.8; platform_system == 'Darwin' # 0.1.8+ incompatible on macOS

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@requirements-dev.txt` at line 4, add a short inline comment next to the apache-tvm-ffi<0.1.8; platform_system == 'Darwin' entry in requirements-dev.txt explaining the macOS upper bound (e.g., "# 0.1.8+ breaks on macOS; see issue tilelang#XXXX or link to bug/PR"), so future readers understand why this Darwin-specific constraint exists; reference the package name apache-tvm-ffi when adding the comment.

pyproject.toml (1)
33-34: 💤 Low value
LGTM — Darwin-specific cap is valid and follows the established project pattern.

Each string in [project].dependencies represents a dependency and maps directly to a Requires-Dist entry, so two entries for apache-tvm-ffi with different markers are valid. The effective constraint on macOS becomes >=0.1.2, <0.1.8 (the base ~=0.1.0,>=0.1.2 intersected with <0.1.8), which is the intended behavior. The same two-entry pattern is already used for torch on lines 42-43.

Optionally, a brief inline comment (e.g., # 0.1.8+ incompatible on macOS) would give the Darwin cap the same documentation context that lines 30-32 give the base constraint.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pyproject.toml` around lines 33 - 34, add a brief inline comment explaining the Darwin cap on the second apache-tvm-ffi entry: locate the dependency string "apache-tvm-ffi<0.1.8; platform_system == 'Darwin'" and append a short comment (e.g., "# 0.1.8+ incompatible on macOS") to document why the macOS-specific upper bound exists, mirroring the explanatory style used for the base constraint and the torch entries.

testing/python/jit/test_tilelang_jit_adapter_mps.py (1)
10-47: ⚡ Quick win
Extract the duplicated MPS patching guard into a shared fixture.

The getattr/monkeypatch guard for torch.backends.mps is copy-pasted verbatim in all three tests (lines 13-16, 27-30, 40-43), differing only in the is_available return value. A parametrised fixture keeps the pattern DRY and makes future additions trivial.

♻️ Proposed refactor
+import pytest
+
+@pytest.fixture
+def patch_mps(monkeypatch):
+    """Return a helper that sets torch.backends.mps.is_available() to *available*."""
+    def _patch(available: bool):
+        if getattr(torch.backends, "mps", None) is None:
+            monkeypatch.setattr(
+                torch.backends, "mps",
+                SimpleNamespace(is_available=lambda: available),
+                raising=False,
+            )
+        else:
+            monkeypatch.setattr(torch.backends.mps, "is_available", lambda: available)
+    return _patch
+
-def test_current_device_functor_prefers_mps_when_cuda_unavailable(monkeypatch):
+def test_current_device_functor_prefers_mps_when_cuda_unavailable(monkeypatch, patch_mps):
     monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
-    if getattr(torch.backends, "mps", None) is None:
-        monkeypatch.setattr(torch.backends, "mps", SimpleNamespace(is_available=lambda: True), raising=False)
-    else:
-        monkeypatch.setattr(torch.backends.mps, "is_available", lambda: True)
+    patch_mps(True)
     device_functor = BaseKernelAdapter.get_current_device_functor()
     assert device_functor() == torch.device("mps")

-def test_current_device_functor_prefers_mps_when_cuda_init_fails(monkeypatch):
+def test_current_device_functor_prefers_mps_when_cuda_init_fails(monkeypatch, patch_mps):
     monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
     monkeypatch.setattr(torch.cuda, "_lazy_init", lambda: (_ for _ in ()).throw(RuntimeError("cuda init failed")))
-    if getattr(torch.backends, "mps", None) is None:
-        monkeypatch.setattr(torch.backends, "mps", SimpleNamespace(is_available=lambda: True), raising=False)
-    else:
-        monkeypatch.setattr(torch.backends.mps, "is_available", lambda: True)
+    patch_mps(True)
     device_functor = BaseKernelAdapter.get_current_device_functor()
     assert device_functor() == torch.device("mps")

-def test_current_device_functor_falls_back_to_cpu_without_cuda_or_mps(monkeypatch):
+def test_current_device_functor_falls_back_to_cpu_without_cuda_or_mps(monkeypatch, patch_mps):
     monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
-    if getattr(torch.backends, "mps", None) is None:
-        monkeypatch.setattr(torch.backends, "mps", SimpleNamespace(is_available=lambda: False), raising=False)
-    else:
-        monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False)
+    patch_mps(False)
     device_functor = BaseKernelAdapter.get_current_device_functor()
     assert device_functor() == torch.device("cpu")

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/jit/test_tilelang_jit_adapter_mps.py` around lines 10 - 47, These tests duplicate the same torch.backends.mps monkeypatch guard; create a parametrized pytest fixture (e.g., mps_availability) that accepts a boolean for is_available and performs the getattr/monkeypatch logic, then update test_current_device_functor_prefers_mps_when_cuda_unavailable, test_current_device_functor_prefers_mps_when_cuda_init_fails, and test_current_device_functor_falls_back_to_cpu_without_cuda_or_mps to depend on that fixture and pass True/False as needed; keep the existing monkeypatch for torch.cuda adjustments and continue calling BaseKernelAdapter.get_current_device_functor() and asserting the returned device.

src/transform/layout_inference.cc (1)
436-446: ⚡ Quick win
Comment is misleading; the guard may suppress real layout failures for non-accumulator fragment buffers.

MetalFragmentToSimdgroup runs before LayoutInference, converting GEMM accumulator local.fragment buffers to metal.simdgroup. Those buffers are therefore absent from use_list_ by the time this loop executes (since addToUseList only tracks local.fragment scope). Any local.fragment entry that does appear in use_list_ on Metal is a non-accumulator fragment buffer (e.g., one used in a T.Parallel loop).

The current guard silently skips the ICHECK for those non-accumulator buffers too, which could hide real layout-inference failures and make downstream codegen errors difficult to diagnose.

Consider adding a diagnostic for the unexpected case:

🛡️ Proposed improvement
-  // Check that all local.fragment buffers have inferred layouts.
-  // On Metal targets, fragment buffers used as GEMM accumulators are
-  // lowered to opaque simdgroup matrices, so they have no explicit
-  // thread-level layout and can be safely skipped.
+  // Check that all local.fragment buffers have inferred layouts.
+  // On Metal targets, MetalFragmentToSimdgroup converts GEMM accumulator
+  // local.fragment buffers to metal.simdgroup *before* this pass runs, so
+  // those buffers are already absent from use_list_. Any local.fragment
+  // that remains in use_list_ on Metal is a non-accumulator buffer; emit a
+  // warning instead of a hard ICHECK so Metal kernels that genuinely don't
+  // need thread-level fragment layouts can still compile.
   for (const auto &[buffer, _] : use_list_) {
     if (IsFragmentBuffer(buffer)) {
-      if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
-        ICHECK(false) << "The layout for fragment " << buffer
-                      << " can not be inferred correctly.";
+      if (layout_map.count(buffer) == 0) {
+        if (TargetIsMetal(target_)) {
+          LOG(WARNING) << "[LayoutInference] Metal: fragment buffer "
+                       << buffer << " has no inferred layout; "
+                       << "verify it is handled by the Metal codegen path.";
+        } else {
+          ICHECK(false) << "The layout for fragment " << buffer
+                        << " can not be inferred correctly.";
+        }
       }
     }
   }

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/layout_inference.cc` around lines 436 - 446, The loop currently skips layout presence checks for all fragment buffers on Metal, which can hide real failures; modify the loop that iterates use_list_ so that for every IsFragmentBuffer(buffer) you verify layout_map.count(buffer) != 0 and trigger an ICHECK (or a failing diagnostic) when missing; for the Metal target make the error message explicit (referencing TargetIsMetal(target_), IsFragmentBuffer, layout_map, use_list_) stating this is an unexpected non-accumulator fragment on Metal and suggesting investigation of MetalFragmentToSimdgroup/LayoutInference ordering (so replace the existing TargetIsMetal guard with a straight check and a clearer error message).

tilelang/jit/adapter/torch/metal.py (1)
56-57: kernel_only parameter is silently ignored.

The method always returns self.kernel_global_source or "" regardless of the kernel_only flag. While Metal may not distinguish between kernel-only and full source (having no separate preamble), this deviates from the BaseKernelAdapter contract, where kernel_only=False should return additional preamble/includes.

Document the lack of support or add conditional logic to match the interface:

💡 Proposed fix
  def get_kernel_source(self, kernel_only: bool = True) -> str:
+     """Return the Metal kernel source. `kernel_only` is unused on Metal
+     (there is no separate host-side preamble)."""
      return self.kernel_global_source or ""

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/jit/adapter/torch/metal.py` around lines 56 - 57, The get_kernel_source method currently ignores the kernel_only parameter and always returns self.kernel_global_source or "", so update get_kernel_source in the Metal adapter (method name: get_kernel_source) to honor the BaseKernelAdapter contract: if kernel_only is True return only the kernel body (e.g. self.kernel_source or self.kernel_global_source_kernel_part) and if kernel_only is False return the full source including any preamble/includes (or, if Metal genuinely has no preamble support, explicitly document that by returning the same string but add a comment/docstring clarifying that kernel_only is not applicable); ensure you reference and use the existing attributes (self.kernel_global_source, self.kernel_source, or similar) to construct the conditional return and update the docstring to state the behavior when kernel_only is unsupported.

tilelang/intrinsics/metal_macro_generator.py (1)
111-111: 💤 Low value
Prefer tuple unpacking over concatenation.

The static analysis suggests using (*leading, row_idx, col_idx) instead of leading + (row_idx, col_idx) for cleaner tuple construction.

♻️ Proposed fix
- ptr = T.access_ptr(buffer[leading + (row_idx, col_idx)], "r")
+ ptr = T.access_ptr(buffer[(*leading, row_idx, col_idx)], "r")
Also applies to: 145-145, 203-203
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/metal_macro_generator.py` at line 111, Replace tuple concatenation with tuple unpacking when indexing the buffer: change uses like buffer[leading + (row_idx, col_idx)] (seen when assigning ptr via T.access_ptr) to buffer[(*leading, row_idx, col_idx)] so the tuple is built cleanly; update the same pattern at the other occurrences noted (the other buffer index usages around the references at lines 145 and 203) to use (*leading, row_idx, col_idx) as well.

testing/python/metal/test_metal_gemm_v2.py (1)
29-29: ⚡ Quick win
Add a num_stages > 1 variant here.

Line 29 hardcodes num_stages=0, so this test never exercises the pipelined path that the PR is fixing. A num_stages=2/3 case would cover the stage-dimension lowering regression directly.
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2.py` at line 29, The test currently only uses a pipelined loop with num_stages=0 and therefore never exercises the pipelined lowering path; update the test in test_metal_gemm_v2 by adding a variant that iterates the same loop with T.Pipelined(..., num_stages=2) (and optionally num_stages=3) so the code path for staged pipelining is exercised. Concretely, either parametrize the test to run with num_stages in [0,2] (or [0,2,3]) or duplicate the loop that uses T.Pipelined(T.ceildiv(K, block_K), num_stages=0) and change the duplicate to num_stages=2 so the stage-dimension lowering regression is covered; keep the same surrounding setup and assertions.

benchmark/matmul_metal/benchmark_matmul_metal.py (1)
34-38: ⚡ Quick win
This benchmark never measures the pipelined path.

The inner loop is fixed at num_stages=0, so the script cannot quantify the stage-aware 3D-buffer fix or the num_stages>1 speedups called out in the PR. Please make num_stages configurable and include a staged default/sweep value so the benchmark reflects the feature being added here.
Verify each finding against the current code and only fix it if needed. In `@benchmark/matmul_metal/benchmark_matmul_metal.py` around lines 34 - 38, The pipelined inner loop currently hardcodes num_stages=0 in the T.Pipelined call (loop variable ko) so the benchmark never measures staged execution; make num_stages configurable in benchmark_matmul_metal.py (e.g., a CLI flag or config variable) and use that value in the T.Pipelined(..., num_stages=...) call, provide a sensible default that includes staged runs (e.g., 0 and >1) and add an optional sweep (iterate over multiple num_stages values) so the benchmark can run and report results for both non-pipelined and staged settings; update any driver/runner code that invokes the benchmark to accept and pass this parameter.
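One way the suggested sweep could look, sketched with stdlib argparse; the flag name, defaults, and help text are illustrative, not the benchmark's actual CLI:

```python
import argparse


def parse_args(argv=None):
    # Hypothetical CLI for the benchmark; only --num-stages is shown here.
    parser = argparse.ArgumentParser(description="Metal matmul benchmark")
    parser.add_argument(
        "--num-stages", type=int, nargs="+", default=[0, 2],
        help="pipeline stage counts to sweep (0 disables software pipelining)")
    return parser.parse_args(argv)


args = parse_args(["--num-stages", "0", "2", "3"])
for num_stages in args.num_stages:
    # each value would be threaded into T.Pipelined(..., num_stages=num_stages)
    # and the timing reported per stage count
    pass
```

Defaulting to `[0, 2]` means a plain run already reports both the unpipelined baseline and one staged configuration.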
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 28-34: The test currently only exercises the pipeline with
num_stages=0; add a new regression case that uses T.Pipelined(..., num_stages=2)
or num_stages=3 and uses stage-indexed shared-memory buffers so the 3D
shared-buffer lowering is exercised. Concretely, duplicate the existing loop
variant that calls T.Pipelined(T.ceildiv(K, block_K), num_stages=0) but set
num_stages=2 (or 3), allocate the shared buffers with an extra leading stage
dimension (i.e., stage-indexed shared buffers) and update the T.copy/T.gemm
calls to read/write via the current stage index so copies target
A_shared[stage,...] and B_shared[stage,...] and the compute reads those
stage-indexed buffers; this will force the stage>0 compile path to run and catch
regressions in the shared-buffer lowering.
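The stage-indexed addressing this regression case exercises boils down to a circular slot choice per iteration; a pure-Python sketch (`stage_of` is an illustrative helper, not TileLang API):

```python
def stage_of(ko, num_stages):
    """Circular stage slot used by a software pipeline at iteration ko."""
    return ko % num_stages if num_stages > 1 else 0


# With num_stages=2 the copy for iteration ko lands in slot ko % 2 while
# the compute for the previous iteration still reads the other slot, so a
# prefetch never overwrites data that is still live.
assert [stage_of(ko, 2) for ko in range(6)] == [0, 1, 0, 1, 0, 1]
assert [stage_of(ko, 3) for ko in range(6)] == [0, 1, 2, 0, 1, 2]
```

This is exactly why the shared buffers grow a leading stage dimension (`A_shared[stage, ...]`) once `num_stages > 1`: without the extra axis, every iteration would target the same storage.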
In `@tilelang/engine/lower.py`:
- Around line 263-264: The device_codegen_without_compile() path is still
invoking the compiled Metal builder (target.build.tilelang_metal) which triggers
tvm_callback_metal_compile in src/target/codegen_metal.cc; change the lower.py
branch so that when enable_device_compile is False it calls a new, compile-free
entry point (e.g., target.build.tilelang_metal_without_compile) or passes an
explicit compile flag through the Metal builder API; implement the new symbol
target.build.tilelang_metal_without_compile (or accept the flag) in the Metal
codegen and ensure src/target/codegen_metal.cc avoids calling
tvm_callback_metal_compile when the compile flag is false so tilelang.lower(...,
enable_device_compile=False) never requires the Metal compiler.
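The shape of that fix can be sketched with a toy global-function registry; the registry, decorator, and `compile_binary` flag below are illustrative stand-ins, not the real TVM FFI API:

```python
# Hypothetical registry mirroring tvm.ffi.get_global_func lookup.
_REGISTRY = {}


def register(name):
    def deco(fn):
        _REGISTRY[name] = fn
        return fn
    return deco


def get_global_func(name):
    return _REGISTRY[name]


@register("target.build.tilelang_metal")
def build_metal(mod, target, compile_binary=True):
    source = f"// metal source for {mod}"
    if compile_binary:
        # Stand-in for invoking tvm_callback_metal_compile on the source.
        source += "\n// metallib produced by the platform compiler"
    return source


@register("target.build.tilelang_metal_without_compile")
def build_metal_without_compile(mod, target):
    # Same codegen path, but the Metal compiler callback is never invoked,
    # so source-only lowering works on hosts without a Metal toolchain.
    return build_metal(mod, target, compile_binary=False)
```

Either variant (a second registered symbol, as shown, or a flag threaded through the one builder) keeps `enable_device_compile=False` host-independent.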
In `@tilelang/transform/metal_fragment_to_simdgroup.py`:
- Around line 57-65: The _remap_buffer function rebuilds a buffer but omits
important metadata (strides, elem_offset, buffer_type, axis_separators, etc.),
which changes indexing/layout; update the tir.decl_buffer call in _remap_buffer
to pass through all original buffer fields from buf (e.g., strides, elem_offset,
buffer_type, axis_separators and any other non-default attributes) along with
shape/dtype/name/data=new_data/data_alignment/offset_factor so the new buffer
preserves the original layout and semantic information.
- Around line 82-110: The pre-order rewrite must rebuild the tir.Block
completely and not leave metadata pointing at old buffers: apply
tir.stmt_functor.substitute (using var_map) to stmt.reads, stmt.writes and
stmt.init (in addition to stmt.body) when constructing new_block/BlockRealize so
BufferRegion and init expressions reference the new variables (e.g., when
creating new_block in the branch that handles tir.Block). Also avoid
short-circuiting traversal by returning None from the pre-order callback after
recording the replacement (or otherwise ensure the transform continues into
nested binders) so nested tir.Allocate nodes under the block are visited and
their buffer_var bindings are updated; reference tir.Block, tir.BlockRealize,
tir.Allocate, var_map, and tir.stmt_functor.ir_transform in your changes.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3abbf2a1-a65f-4d9f-ab27-98881f9770bb
📒 Files selected for processing (37)
- benchmark/matmul_metal/benchmark_matmul_metal.py
- pyproject.toml
- requirements-dev.txt
- requirements.txt
- src/backend/metal/CMakeLists.txt
- src/op/copy.cc
- src/op/copy.h
- src/op/fill.cc
- src/op/gemm.cc
- src/op/gemm.h
- src/op/parallel.cc
- src/op/utils.h
- src/target/codegen_metal.cc
- src/target/codegen_metal.h
- src/transform/layout_inference.cc
- src/transform/lower_device_storage_access_info.cc
- testing/python/jit/test_tilelang_jit_adapter_mps.py
- testing/python/metal/metal_internal_runtime_coverage.md
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- testing/python/metal/test_metal_internal_scaffolding.py
- testing/python/metal/test_metal_local_var.py
- testing/python/metal/test_metal_simdgroup_store.py
- tilelang/engine/lower.py
- tilelang/engine/phase.py
- tilelang/intrinsics/metal_macro_generator.py
- tilelang/jit/adapter/base.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/tileop/metal_gdn.py
- tilelang/tileop/metal_quant.py
- tilelang/tileop/metal_simdgroup.py
- tilelang/transform/decouple_type_cast.py
- tilelang/transform/metal_fragment_to_simdgroup.py
- tilelang/utils/language.py
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
    T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
    T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)
    T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2)
🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
Add a staged regression case here.
This helper hardcodes num_stages=0, so none of the tests in this file exercise the 3D shared-buffer lowering that this PR is fixing. Please add at least one T.Pipelined(..., num_stages=2/3) case with stage-indexed shared memory so the old stage>0 compile/overwrite failure cannot regress silently.
  elif target.kind.name == "metal":
-     device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
+     device_mod = tvm.ffi.get_global_func("target.build.tilelang_metal")(device_mod, target)
device_codegen_without_compile() still compiles Metal kernels.
This branch now calls the same builder as the compiled path, and src/target/codegen_metal.cc:547-574 still invokes tvm_callback_metal_compile whenever that callback is registered. So tilelang.lower(..., enable_device_compile=False) can still require a working Metal compiler on Apple hosts, which breaks the no-compile contract and makes source-only tests/environment behavior host-dependent. Please add a distinct target.build.tilelang_metal_without_compile entry point or thread an explicit compile flag through the builder.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/engine/lower.py` around lines 263 - 264, The
device_codegen_without_compile() path is still invoking the compiled Metal
builder (target.build.tilelang_metal) which triggers tvm_callback_metal_compile
in src/target/codegen_metal.cc; change the lower.py branch so that when
enable_device_compile is False it calls a new, compile-free entry point (e.g.,
target.build.tilelang_metal_without_compile) or passes an explicit compile flag
through the Metal builder API; implement the new symbol
target.build.tilelang_metal_without_compile (or accept the flag) in the Metal
codegen and ensure src/target/codegen_metal.cc avoids calling
tvm_callback_metal_compile when the compile flag is false so tilelang.lower(...,
enable_device_compile=False) never requires the Metal compiler.
```python
return tir.decl_buffer(
    buf.shape,
    buf.dtype,
    buf.name,
    data=new_data,
    scope="metal.simdgroup",
    data_alignment=buf.data_alignment,
    offset_factor=buf.offset_factor,
)
```
Preserve all buffer metadata when changing scope.
_remap_buffer() currently rebuilds the buffer with only shape/dtype/name/data_alignment/offset_factor. If the original buffer carries non-default strides, elem_offset, buffer_type, or axis_separators, this pass silently changes indexing semantics and buffer layout information. The strides field in particular is actively used downstream for element offset computation in buffer indexing operations.
Proposed fix

```diff
 return tir.decl_buffer(
     buf.shape,
     buf.dtype,
     buf.name,
     data=new_data,
+    strides=buf.strides,
+    elem_offset=buf.elem_offset,
     scope="metal.simdgroup",
     data_alignment=buf.data_alignment,
     offset_factor=buf.offset_factor,
+    buffer_type=buf.buffer_type,
+    axis_separators=buf.axis_separators,
 )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/transform/metal_fragment_to_simdgroup.py` around lines 57 - 65, The
_remap_buffer function rebuilds a buffer but omits important metadata (strides,
elem_offset, buffer_type, axis_separators, etc.), which changes indexing/layout;
update the tir.decl_buffer call in _remap_buffer to pass through all original
buffer fields from buf (e.g., strides, elem_offset, buffer_type, axis_separators
and any other non-default attributes) along with
shape/dtype/name/data=new_data/data_alignment/offset_factor so the new buffer
preserves the original layout and semantic information.
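The underlying pitfall is generic: rebuilding an object from a hand-picked subset of its fields silently resets the rest to defaults. A scope change should touch only the scope, which `dataclasses.replace` makes explicit; the toy `Buf` type below stands in for `tir.Buffer` and is not the real TVM class:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Buf:  # toy stand-in for tir.Buffer
    shape: tuple
    dtype: str
    scope: str
    strides: tuple = ()
    elem_offset: int = 0

buf = Buf(shape=(16, 16), dtype="float16", scope="local.fragment",
          strides=(32, 1), elem_offset=8)

# Rebuild changing only the scope; strides/elem_offset survive.
new_buf = replace(buf, scope="metal.simdgroup")
assert new_buf.strides == (32, 1) and new_buf.elem_offset == 8

# The bug pattern: re-constructing from a subset drops metadata.
lossy = Buf(shape=buf.shape, dtype=buf.dtype, scope="metal.simdgroup")
assert lossy.strides == () and lossy.elem_offset == 0
```

`tir.decl_buffer` has no `replace`-style helper, so the fix is to forward every field explicitly, as the proposed diff does.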
```python
            new_body = tir.stmt_functor.substitute(stmt.body, var_map)
            new_block = tir.Block(
                stmt.iter_vars,
                stmt.reads,
                stmt.writes,
                stmt.name_hint,
                new_body,
                stmt.init,
                new_alloc_bufs,
                stmt.match_buffers,
                stmt.annotations,
            )
            return (
                tir.BlockRealize(
                    stmt.iter_vars,
                    tir.const(True, "bool"),
                    new_block,
                )
                if False
                else new_block
            )
        elif isinstance(stmt, tir.Allocate):
            new_var = var_map.get(stmt.buffer_var, None)
            if new_var is not None:
                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
                return tir.Allocate(new_var, stmt.dtype, stmt.extents, stmt.condition, new_body, stmt.annotations)
        return None

    return tir.stmt_functor.ir_transform(body, _pre_order, None, ["tir.Block", "tir.Allocate"])
```
Rewrite the whole tir.Block consistently, including reads, writes, and init fields.
The block body is substituted at line 82, but the reconstructed block keeps the old reads, writes, and init fields without applying the variable substitution. Since these fields typically contain BufferRegion objects that reference buffer variables, this creates an IR inconsistency where the block's body uses the new metal.simdgroup variables while metadata fields reference the old local.fragment variables.
Additionally, when a modified Block is returned from the pre-order callback, nested tir.Allocate binders under it may not be visited, leaving old variable bindings in place while expressions already use the new variables. This can cause subsequent passes to encounter an internally inconsistent IR state.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/transform/metal_fragment_to_simdgroup.py` around lines 82 - 110, The
pre-order rewrite must rebuild the tir.Block completely and not leave metadata
pointing at old buffers: apply tir.stmt_functor.substitute (using var_map) to
stmt.reads, stmt.writes and stmt.init (in addition to stmt.body) when
constructing new_block/BlockRealize so BufferRegion and init expressions
reference the new variables (e.g., when creating new_block in the branch that
handles tir.Block). Also avoid short-circuiting traversal by returning None from
the pre-order callback after recording the replacement (or otherwise ensure the
transform continues into nested binders) so nested tir.Allocate nodes under the
block are visited and their buffer_var bindings are updated; reference
tir.Block, tir.BlockRealize, tir.Allocate, var_map, and
tir.stmt_functor.ir_transform in your changes.
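The consistency requirement can be shown with a toy "block" whose body and read/write metadata name the same buffers; substituting only the body leaves the metadata stale. This is plain Python standing in for the TVM IR types:

```python
def substitute(expr, var_map):
    """Toy substitution over nested tuples of variable names."""
    if isinstance(expr, str):
        return var_map.get(expr, expr)
    return tuple(substitute(e, var_map) for e in expr)

block = {
    "body": ("store", "C_frag", ("load", "A_sh")),
    "reads": ("A_sh",),
    "writes": ("C_frag",),
}
var_map = {"C_frag": "C_simdgroup"}

# Buggy rewrite: only the body is remapped; "writes" still names C_frag.
buggy = dict(block, body=substitute(block["body"], var_map))
assert buggy["body"] == ("store", "C_simdgroup", ("load", "A_sh"))
assert buggy["writes"] == ("C_frag",)           # stale metadata

# Consistent rewrite: every field that names buffers is remapped.
fixed = {k: substitute(v, var_map) for k, v in block.items()}
assert fixed["writes"] == ("C_simdgroup",)
```

In TVM terms, `reads`/`writes` hold `BufferRegion` objects and `init` is a statement, so all three need the same `var_map` substitution as the body when the block is reconstructed.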
Pull request overview
This PR extends TileLang’s Metal backend to correctly handle software-pipelined shared buffers (T.Pipelined(..., num_stages>1)) by threading the stage/version dimension through Metal T.access_ptr-based address generation, and adds broader Metal simdgroup support (GEMM lowering/codegen/copy/fill) plus tests/benchmarks around the new paths.
Changes:
- Fix Metal simdgroup load/store address calculation to include leading (e.g., pipeline stage) indices when shared buffers become 3D under software pipelining.
- Introduce/extend Metal simdgroup infrastructure: fragment→simdgroup rewrite pass, simdgroup-aware GEMM implementation, simdgroup fill/copy lowering, and Metal codegen plumbing.
- Add Metal-focused tests (codegen + hardware-gated correctness) and internal scaffolding probes; adjust Torch JIT adapter to prefer MPS when CUDA is unavailable.
Reviewed changes
Copilot reviewed 36 out of 37 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tilelang/utils/language.py | Adds is_metal_simdgroup scope predicate. |
| tilelang/transform/metal_fragment_to_simdgroup.py | New pass to rewrite GEMM accumulators from local.fragment to metal.simdgroup on Metal. |
| tilelang/transform/decouple_type_cast.py | Treats metal.simdgroup as a local/register buffer for cast decoupling. |
| tilelang/tileop/metal_simdgroup.py | Internal simdgroup RegisterTile helpers/macros (alloc/load/store/mma, plus scalar row ops). |
| tilelang/tileop/metal_quant.py | Internal packed-uint8 quant decode helpers for Metal probes. |
| tilelang/tileop/metal_gdn.py | Internal GDN/attention-style simdgroup tile macros. |
| tilelang/tileop/gemm/inst.py | Adds METAL_SIMDGROUP GEMM instruction enum value. |
| tilelang/tileop/gemm/gemm_metal.py | Implements Metal GEMM lowering via MPSIntrinEmitter simdgroup operations. |
| tilelang/tileop/gemm/__init__.py | Selects Metal GEMM implementation when target is Metal. |
| tilelang/jit/adapter/torch/metal.py | Exposes kernel source for Torch Metal adapter. |
| tilelang/jit/adapter/base.py | Updates device selection to prefer MPS when CUDA is unavailable/initialization fails. |
| tilelang/intrinsics/metal_macro_generator.py | Metal macro emitter updated to preserve leading indices (pipeline stage) in T.access_ptr. |
| tilelang/engine/phase.py | Inserts MetalFragmentToSimdgroup before layout inference. |
| tilelang/engine/lower.py | Switches Metal build hook to target.build.tilelang_metal. |
| testing/python/metal/test_metal_simdgroup_store.py | Tests simdgroup-accumulation path and direct simdgroup_store emission. |
| testing/python/metal/test_metal_local_var.py | Adds Metal codegen/runtime tests for local.var scalar lowering. |
| testing/python/metal/test_metal_internal_scaffolding.py | Adds internal-only Metal probes for simdgroup, quant decode, and GDN-style kernels. |
| testing/python/metal/test_metal_gemm_v2.py | Hardware-gated correctness tests for Metal T.gemm (gemm_v2). |
| testing/python/metal/test_metal_gemm_v2_linux.py | Cross-platform Metal source-generation tests for T.gemm (gemm_v2). |
| testing/python/metal/metal_internal_runtime_coverage.md | Documents Metal internal runtime/source-boundary coverage. |
| testing/python/jit/test_tilelang_jit_adapter_mps.py | Tests new MPS-preferred device selection behavior. |
| src/transform/lower_device_storage_access_info.cc | Excludes fragment scope tag from memory-info enforcement. |
| src/transform/layout_inference.cc | Skips fragment-layout completeness check for Metal targets. |
| src/target/codegen_metal.h | Declares CodeGenTileLangMetal code generator. |
| src/target/codegen_metal.cc | Implements Metal source codegen, simdgroup matrix lowering, and registers target.build.tilelang_metal. |
| src/op/utils.h | Adds IsSIMDGroupBuffer / IsRegisterBuffer helpers. |
| src/op/parallel.cc | Makes fragment-layout lookup conditional to avoid missing-layout crashes. |
| src/op/gemm.h | Adds Metal simdgroup GEMM instruction kind to C++ enum. |
| src/op/gemm.cc | Returns Metal simdgroup instruction for Metal targets; adjusts warp policy constants for Metal. |
| src/op/fill.cc | Adds simdgroup buffer fill lowering via make_filled_simdgroup_matrix. |
| src/op/copy.h | Adds Metal simdgroup copy instruction and lowering hooks. |
| src/op/copy.cc | Adds Metal simdgroup store lowering (LowerSIMDGroupCopy) and instruction selection. |
| src/backend/metal/CMakeLists.txt | Always compiles Metal codegen source for cross-platform codegen-only mode. |
| requirements.txt | Adds Darwin-only upper bound on apache-tvm-ffi. |
| requirements-dev.txt | Adds Darwin-only upper bound on apache-tvm-ffi for dev installs. |
| pyproject.toml | Adds Darwin-only upper bound on apache-tvm-ffi for packaging. |
| benchmark/matmul_metal/benchmark_matmul_metal.py | Adds a Metal GEMM benchmark script. |
```cpp
int kMPerWarp = 8;
int kNPerWarp = 8;
int m_warp = 1, n_warp = num_warps;
int max_m = M / kMPerWarp;
int max_n = N / kNPerWarp;
if (max_m <= 0 || max_n <= 0) {
  return LowerNormalCopy(T, analyzer);
}
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
  if (num_warps % m != 0)
    continue;
  int n = num_warps / m;
  if (n > max_n)
    continue;
  if (M % (m * kMPerWarp) != 0 || N % (n * kNPerWarp) != 0)
    continue;
  float m_per = static_cast<float>(M) / (m * kMPerWarp);
  float n_per = static_cast<float>(N) / (n * kNPerWarp);
  float score = std::abs(m_per / n_per - ideal);
  if (score < best_score) {
    best_score = score;
    m_warp = m;
    n_warp = n;
  }
}

if (best_score == std::numeric_limits<float>::max() || M < m_warp * 8 ||
    N < n_warp * 8) {
  return LowerNormalCopy(T, analyzer);
}
int warp_row_tiles = M / m_warp / 8;
int warp_col_tiles = N / n_warp / 8;
if (warp_row_tiles <= 0 || warp_col_tiles <= 0 ||
    warp_row_tiles * warp_col_tiles * 64 > total_elements) {
  return LowerNormalCopy(T, analyzer);
}

PrimExpr warp_m = FloorMod(warp_id, m_warp);
PrimExpr warp_n = FloorDiv(warp_id, m_warp);
```
```python
def _rewrite_scope(body, var_map):
    buf_map = {}

    def _pre_order(stmt):
        if isinstance(stmt, tir.Block):
            new_alloc_bufs = []
            changed = False
            for buf in stmt.alloc_buffers:
                new_buf = _remap_buffer(buf, var_map)
                new_alloc_bufs.append(new_buf)
                if not new_buf.same_as(buf):
                    buf_map[buf] = new_buf
                    changed = True
            if changed:
                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
                new_block = tir.Block(
                    stmt.iter_vars,
                    stmt.reads,
                    stmt.writes,
                    stmt.name_hint,
                    new_body,
                    stmt.init,
                    new_alloc_bufs,
                    stmt.match_buffers,
                    stmt.annotations,
                )
                return (
                    tir.BlockRealize(
                        stmt.iter_vars,
                        tir.const(True, "bool"),
                        new_block,
                    )
                    if False
                    else new_block
                )
        elif isinstance(stmt, tir.Allocate):
```
```diff
 // Check that all local.fragment buffers have inferred layouts.
 // On Metal targets, fragment buffers used as GEMM accumulators are
 // lowered to opaque simdgroup matrices, so they have no explicit
 // thread-level layout and can be safely skipped.
 for (const auto &[buffer, _] : use_list_) {
   if (IsFragmentBuffer(buffer)) {
-    ICHECK_NE(layout_map.count(buffer), 0)
-        << "The layout for fragment " << buffer
-        << " can not be inferred correctly.";
+    if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
+      ICHECK(false) << "The layout for fragment " << buffer
+                    << " can not be inferred correctly.";
+    }
   }
```
```python
if target_is_metal(target):
    return GemmInst.METAL_SIMDGROUP
```
```python
@staticmethod
def _parse_buffer_2d(buf):
    """Extract (buffer, row_offset, col_offset, stride, leading_indices) from Buffer or BufferRegion.

    ``leading_indices`` carries the .min of any region dims preceding the
    last two (row, col). The InjectSoftwarePipeline pass expands shared
    buffers from 2D to 3D by inserting a "version" dim at position 0
    (shape becomes ``[num_versions, M, N]`` and ``region[0]`` becomes the
    per-iteration version index). Accessors must therefore include those
    leading dims; otherwise the patched ``Buffer.__getitem__`` raises
    ``IndexError: Buffer X is 3-dimensional ... but 2 index(es) were provided``.
    See ``inject_pipeline.cc::RewritePipelineBufferRegion`` and
    ``mma_macro_generator.py`` (CUDA) for the same pattern.
    """
    if isinstance(buf, BufferRegion):
        buffer = buf.buffer
        off_row = buf.region[-2].min
        off_col = buf.region[-1].min
        leading = tuple(r.min for r in buf.region[:-2])
    else:
        buffer = buf
        off_row = 0
        off_col = 0
        leading = tuple(0 for _ in buf.shape[:-2])
    stride = buffer.strides[-2] if len(buffer.strides) == len(buffer.shape) else buffer.shape[-1]
    return buffer, off_row, off_col, stride, leading
```
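The leading-index extraction is easy to check in isolation with toy stand-ins for `Buffer`/`BufferRegion` (illustrative namedtuples only; the real types come from TVM's TIR):

```python
from collections import namedtuple

Range = namedtuple("Range", ["min", "extent"])
Buffer = namedtuple("Buffer", ["shape", "strides"])
BufferRegion = namedtuple("BufferRegion", ["buffer", "region"])

def parse_buffer_2d(buf):
    """Mirror of _parse_buffer_2d: the last two dims are (row, col);
    any dims before them (e.g. the pipeline version axis) are returned
    as leading indices instead of being dropped."""
    if isinstance(buf, BufferRegion):
        buffer = buf.buffer
        off_row = buf.region[-2].min
        off_col = buf.region[-1].min
        leading = tuple(r.min for r in buf.region[:-2])
    else:
        buffer = buf
        off_row = off_col = 0
        leading = tuple(0 for _ in buf.shape[:-2])
    stride = (buffer.strides[-2]
              if len(buffer.strides) == len(buffer.shape)
              else buffer.shape[-1])
    return buffer, off_row, off_col, stride, leading

# 3D shared buffer after pipelining: [num_stages, M, N], stage index 1.
shared = Buffer(shape=(2, 64, 32), strides=())
region = BufferRegion(shared, (Range(1, 1), Range(0, 64), Range(0, 32)))
_, r, c, stride, leading = parse_buffer_2d(region)
assert (r, c, stride, leading) == (0, 0, 32, (1,))
```

A plain 2D buffer yields an empty `leading` tuple, so the pre-pipelining behavior is unchanged; the stage index only appears once the buffer grows its version dim.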
**Summary**

Threads the pipeline-stage dimension through the Metal `T.access_ptr` rewrite so `T.Pipelined(K_iters, num_stages > 1)` lowers correctly on Metal. Without this, the metal_macro_generator's `T.access_ptr` pattern emits a 2D shape that doesn't index the per-stage shared buffers, producing either a compile error or a silent overwrite at stage > 0.

The fix is in `tilelang/intrinsics/metal_macro_generator.py`: when the buffer has 3D shape `(num_stages, R, C)`, the access path now folds the stage axis into the offset arithmetic instead of dropping it.

**Why**

Software pipelining is critical for hiding memory latency in MLA-style attention on Apple GPUs. Without `num_stages > 1` working, every K-iteration of the attention loop is fully serialized over shared memory. Local probes (Mamba3, sparse-MLA fwd) showed a 2-3× speedup with `num_stages=2` once the 3D buffer indexing was fixed.

This PR makes the existing Metal `T.Pipelined` path actually correct for `num_stages > 1` at 16×16 fragment tiles. (Larger fragment sizes, such as 32×32, still need follow-up work; this PR is the foundational fix that unblocks the small-fragment shape used by Mamba3 and partial sparse-MLA.)

**Stacking topology**

This PR is based on `jorgecurious/tilelang:metal-gemm-upstream-rebase` (PR #2130) at HEAD `971c17b`, which itself stacks on top of the `metal_macro_generator.py` changes. Once #1869 + #2118 + #2121 + #2130 merge into `tile-ai/tilelang:main`, this PR can be retargeted to main directly.

**Test plan**

The downstream probe that exercised this is `cppmega.mlx`'s `test_pipelined_probe.py`. Expect `num_stages=2` and `num_stages=3` to both lower successfully on the Metal target.

**Caveats**

**Attribution**

Co-developed with `cppmega.mlx` for Apple-Silicon Metal MLA kernel ports.

**Summary by CodeRabbit**

- New Features
- Documentation
- Tests
- Chores
  - `apache-tvm-ffi` dependency constraints for macOS compatibility.