[Metal] emit Metal builtins directly instead of CUDA-style threadIdx/blockIdx aliases#2143
apstenku123 wants to merge 11 commits into tile-ai:main
Conversation
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations (simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon (M1-M5) without requiring a TVM fork. Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared x shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
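For orientation, a minimal sketch of the kind of kernel this enables, written in the style of TileLang's GEMM tutorials; the shapes, tile sizes, and variable names here are illustrative and not taken from this PR's tests:

```python
import tilelang.language as T

M = N = K = 256
block_M, block_N, block_K = 64, 64, 32

@T.prim_func
def matmul(
    A: T.Tensor((M, K), "float16"),
    B: T.Tensor((K, N), "float16"),
    C: T.Tensor((M, N), "float16"),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), "float16")
        B_shared = T.alloc_shared((block_K, block_N), "float16")
        C_frag = T.alloc_fragment((block_M, block_N), "float16")  # rewritten to metal.simdgroup scope
        T.clear(C_frag)
        for k in T.Pipelined(T.ceildiv(K, block_K)):
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, bx * block_N], B_shared)
            # On Metal this lowers to simdgroup_load / simdgroup_multiply_accumulate.
            T.gemm(A_shared, B_shared, C_frag)
        # The accumulator is written back via simdgroup_store (LowerSIMDGroupCopy).
        T.copy(C_frag, C[by * block_M, bx * block_N])
```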
The Metal codegen previously named the kernel-launch parameters using
CUDA-style identifiers (`blockIdx`, `threadIdx`):
```cpp
kernel void smoke_kernel(
  device const half4* A [[ buffer(0) ]],
  device half4* C [[ buffer(1) ]],
  uint3 blockIdx [[threadgroup_position_in_grid]],
  uint3 threadIdx [[thread_position_in_threadgroup]]
) {
  C[((((int)threadIdx.x) * 4) / 4)] = A[((((int)threadIdx.x) * 4) / 4)];
}
```
These names mirror CUDA but aren't required on Metal, and force every
downstream MSL pass that needs to inline the kernel body into another
kernel (e.g. cppmega.mlx's Path C ports that splice the body into
`mx.fast.metal_kernel` `source=` strings) to first inject
`uint3 blockIdx = threadgroup_position_in_grid;` shims, then regex back to
the Metal builtin, then drop the now-dead alias decl. The whole chain is
pure overhead; see cppmega.mlx's
`cppmega_mlx/nn/_tilelang/_msl_transform.py::_canonicalize_tilelang_builtin_aliases`
and the four supporting helpers.
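For context, a sketch of what such a canonicalization shim has to do today; the function and pattern names below are illustrative stand-ins, not the actual cppmega.mlx helpers:

```python
import re

# CUDA-style alias -> Metal builtin, as described above.
_ALIASES = {
    "blockIdx": "threadgroup_position_in_grid",
    "threadIdx": "thread_position_in_threadgroup",
}

def canonicalize_builtin_aliases(msl: str) -> str:
    """Rewrite alias references back to Metal builtins and drop dead shim decls."""
    for alias, builtin in _ALIASES.items():
        # Drop an injected shim like `uint3 blockIdx = threadgroup_position_in_grid;`
        msl = re.sub(rf"uint3\s+{alias}\s*=\s*{builtin}\s*;\s*", "", msl)
        # Rewrite any remaining `blockIdx.x` / `threadIdx.y` style references.
        msl = re.sub(rf"\b{alias}\b", builtin, msl)
    return msl
```

With this PR the emitted MSL already uses the builtin names, so a pass like this reduces to a no-op.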
Emit the Metal builtin names directly:
```cpp
kernel void smoke_kernel(
  device const half4* A [[ buffer(0) ]],
  device half4* C [[ buffer(1) ]],
  uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
  uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]
) {
  C[((((int)thread_position_in_threadgroup.x) * 4) / 4)] = A[((((int)thread_position_in_threadgroup.x) * 4) / 4)];
}
```
`BindThreadIndex` translates the CUDA-style `IterVar::thread_tag`
(`"blockIdx.x"`, `"threadIdx.y"`, ...) to the matching Metal builtin
reference before recording it in `var_idmap_`. The legacy
`blockIdx`/`threadIdx` names remain reserved by `name_supply_` so the
rest of the kernel cannot collide with them, preserving the existing
contract on the reservation order.
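The actual translation lives in the C++ codegen, but the mapping it applies can be illustrated with a small Python table (illustration only, not code from the PR):

```python
# CUDA-style IterVar thread tags -> Metal builtin kernel parameters.
_METAL_BUILTIN = {
    "blockIdx": "threadgroup_position_in_grid",
    "threadIdx": "thread_position_in_threadgroup",
}

def metal_builtin_ref(thread_tag: str) -> str:
    base, axis = thread_tag.split(".")          # e.g. ("threadIdx", "y")
    return f"{_METAL_BUILTIN[base]}.{axis}"     # -> "thread_position_in_threadgroup.y"

assert metal_builtin_ref("blockIdx.x") == "threadgroup_position_in_grid.x"
assert metal_builtin_ref("threadIdx.y") == "thread_position_in_threadgroup.y"
```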
This change is limited to the `tilelang_metal` codegen path. Changes to the
parallel `target.build.metal` codegen in the vendored TVM submodule are
filed separately against TileLang/tvm.
Stacks on tile-ai#2130 (jorgecurious metal-gemm-upstream-rebase).
📝 Walkthrough

This PR introduces a comprehensive Metal backend for TileLang, adding a Metal code generator, SIMD-group-based GEMM lowering, specialized tiling/computation primitives for quantization and GDN operations, macOS dependency constraints, and extensive test coverage including benchmarking, correctness validation, and internal probes.

Changes
- Metal Backend & Core Infrastructure
- GEMM Metal Lowering
- Metal Intrinsics & High-Level Tiling
- Device Selection & Adapter Updates
- Dependency & Test Infrastructure
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant User
participant JIT as TileLang JIT
participant DeviceDetect as Device Detection
participant MetalCodegen as Metal Codegen
participant Runtime as Metal Runtime
participant MPS as PyTorch MPS
User->>JIT: Define GEMM kernel (T.prim_func)
JIT->>DeviceDetect: Detect device (CUDA→MPS→CPU)
DeviceDetect-->>JIT: device="mps"
JIT->>MetalCodegen: Lower IR to Metal
MetalCodegen->>MetalCodegen: Emit kernel signature<br/>(storage scope, thread mapping)
MetalCodegen->>MetalCodegen: Lower GEMM ops to<br/>simdgroup intrinsics
MetalCodegen->>MetalCodegen: Generate Metal source code
MetalCodegen-->>JIT: artifact (kernel_source)
JIT->>Runtime: Compile Metal source<br/>(optional FFI callback)
Runtime->>MPS: Create Metal library/kernel
Runtime-->>JIT: compiled kernel
User->>JIT: Launch kernel with tensors
JIT->>MPS: Execute on GPU
MPS-->>JIT: Results
JIT-->>User: Output tensor
```

```mermaid
sequenceDiagram
participant Lowering as Lowering Pipeline
participant FragmentPass as MetalFragmentToSimdgroup
participant LayoutInfer as Layout Inference
participant CopyOp as Copy Operation
participant GemmOp as GEMM Operation
participant MetalCodegen as Metal Codegen
Lowering->>Lowering: Apply software pipeline
Lowering->>FragmentPass: Rewrite local.fragment<br/>→ metal.simdgroup
FragmentPass->>FragmentPass: Collect accumulator vars
FragmentPass->>FragmentPass: Create metal.simdgroup<br/>pointer vars
FragmentPass->>FragmentPass: Rewrite function body
FragmentPass-->>Lowering: Updated IR
Lowering->>LayoutInfer: Infer buffer layouts
LayoutInfer->>LayoutInfer: Skip fragment layout<br/>validation for Metal
LayoutInfer-->>Lowering: Layout map
Lowering->>CopyOp: Lower copy ops
CopyOp->>CopyOp: Check SIMD-group<br/>eligibility
CopyOp-->>Lowering: simdgroup_store calls
Lowering->>GemmOp: Lower GEMM ops
GemmOp->>GemmOp: Select Metal Simdgroup<br/>instruction
GemmOp->>GemmOp: Instantiate<br/>MPSIntrinEmitter
GemmOp-->>Lowering: ldmatrix/mma/store<br/>macros
Lowering->>MetalCodegen: Codegen to Metal
MetalCodegen-->>Lowering: Metal shader source
```

Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

This PR introduces a complete, intricate Metal backend spanning architecture from low-level codegen (615 lines of code-generation logic), multiple IR lowering paths (copy/fill/gemm with Metal-specific validation and code emission), a complex set of tiling abstractions and intrinsics (530+ lines for register tiles and GDN helpers), quantization support, device selection fallback logic, and extensive test coverage with probes, benchmarks, and runtime validation. The heterogeneity of concerns (code generation, IR transformation, abstract machine models, device detection, multiple compute patterns) and the density of special-case handling (dtype conversion, storage scope mapping, thread index binding, simdgroup matrix operations, packed quantization decoding) demand sustained attention to correctness and architectural coherence across many interdependent layers.
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Two parallel agent investigations of follow-up Path C upstream candidates:

CANDIDATE #1 — closed as false alarm (`docs/upstream/tilelang_metal_inline_kernel_body/`)

The 'Apple MSL forbids threadgroup allocs in non-kernel functions' bug DOES NOT exist in jorgecurious/tilelang:metal-gemm-upstream-rebase. Metal codegen already emits prim_func bodies directly inside `kernel void` without any inline-void wrapper. `xcrun --sdk macosx metal -c` compiles T.alloc_shared and T.alloc_fragment+T.gemm prim_funcs cleanly without cppmega post-processing. Verified .air artifacts checked in. The cppmega `_inline_tilelang_kernel_body` workaround is real but solves a different problem: it adapts TileLang's complete `kernel void` MSL for MLX's mx.fast.metal_kernel API (MLX generates its own kernel signature and doesn't accept a pre-baked one). This is an MLX/TileLang integration concern, not a TileLang codegen bug. No upstream PR needed.

CANDIDATE #2 — PR FILED at tile-ai/tilelang#2143 (`docs/upstream/tilelang_metal_emit_metal_builtins/`)

The 'CUDA-style threadIdx/blockIdx aliases' bug IS real. TileLang's Metal codegen emitted `uint3 blockIdx [[threadgroup_position_in_grid]],` parameters and `((int)threadIdx.x)` references in the body. Fix: emit the Metal builtin names directly as kernel-launch parameters and body references. Patches:
- 0001-metal-emit-builtins-directly.patch (75 lines, +36/-7 in src/target/codegen_metal.cc)
- 0002-tvm-metal-emit-builtins-directly.patch (69 lines, +30/-7 in 3rdparty/tvm/src/target/source/codegen_metal.cc — for the TileLang/tvm submodule, separate companion PR not yet filed)

Verified: TileLang ninja -j8 builds clean. Smoke test on a T.copy prim_func shows the alias gone and the Metal builtin used directly. cppmega Path C tests (test_tilelang_mamba3_path_c.py, 11/11 pass) — regex helpers in _msl_transform.py become no-ops with this fix, output identical. PR: tile-ai/tilelang#2143
Pull request overview
This PR updates TileLang’s Metal backend to emit Metal builtin thread/block identifiers directly (instead of CUDA-style blockIdx/threadIdx aliases) and expands the Metal stack around simdgroup GEMM/register tiles, including new lowering/codegen paths and test coverage.
Changes:
- Switch Metal device codegen to a TileLang-specific Metal builder (`target.build.tilelang_metal`) and emit Metal builtin parameter names directly in generated MSL.
- Introduce `metal.simdgroup` scope support (rewrite pass, copy/fill lowering, Metal GEMM implementation, and Metal macro emitter).
- Add extensive Metal-focused tests and internal scaffolding/bench utilities (codegen-only and MPS runtime validation).
Reviewed changes
Copilot reviewed 36 out of 37 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| tilelang/utils/language.py | Add is_metal_simdgroup() scope predicate for analysis/transforms. |
| tilelang/transform/metal_fragment_to_simdgroup.py | New PrimFunc pass to rewrite GEMM accumulator fragments to metal.simdgroup. |
| tilelang/transform/decouple_type_cast.py | Treat metal.simdgroup buffers as “local/register-level” for cast decoupling. |
| tilelang/tileop/metal_simdgroup.py | New internal Metal simdgroup RegisterTile/RowVector helpers and macros. |
| tilelang/tileop/metal_quant.py | New packed-uint8 fp8/fp4/e8m0 decode helpers for Metal kernels. |
| tilelang/tileop/metal_gdn.py | New internal GDN/attention-style tile macros built on simdgroup helpers. |
| tilelang/tileop/gemm/inst.py | Add METAL_SIMDGROUP GEMM instruction kind. |
| tilelang/tileop/gemm/gemm_metal.py | New Metal GEMM lowering using simdgroup matrix intrinsics. |
| tilelang/tileop/gemm/__init__.py | Select Metal GEMM implementation when targeting Metal. |
| tilelang/jit/adapter/torch/metal.py | Expose kernel source getter for Torch Metal adapter. |
| tilelang/jit/adapter/base.py | Prefer MPS device when CUDA is unavailable/initialization fails. |
| tilelang/intrinsics/metal_macro_generator.py | New MPSIntrinEmitter for simdgroup load/mma/store generation. |
| tilelang/engine/phase.py | Run Metal fragment→simdgroup rewrite before layout inference. |
| tilelang/engine/lower.py | Route Metal codegen through target.build.tilelang_metal. |
| testing/python/metal/test_metal_simdgroup_store.py | New tests for direct simdgroup-store path (codegen + MPS runtime). |
| testing/python/metal/test_metal_local_var.py | New tests for local.var scalar lowering on Metal. |
| testing/python/metal/test_metal_internal_scaffolding.py | New internal-only Metal probes (source-boundary + runtime). |
| testing/python/metal/test_metal_gemm_v2.py | New MPS runtime correctness tests for T.gemm v2. |
| testing/python/metal/test_metal_gemm_v2_linux.py | New cross-platform Metal source generation tests for T.gemm v2. |
| testing/python/metal/metal_internal_runtime_coverage.md | New doc summarizing internal Metal runtime coverage. |
| testing/python/jit/test_tilelang_jit_adapter_mps.py | New tests for MPS preference in device selection. |
| src/transform/lower_device_storage_access_info.cc | Skip storage-access info lowering for .fragment scope tags. |
| src/transform/layout_inference.cc | Relax fragment-layout assertion behavior on Metal targets. |
| src/target/codegen_metal.h | New TileLang-specific Metal codegen header. |
| src/target/codegen_metal.cc | New TileLang-specific Metal codegen impl (builtins, simdgroup, local.var). |
| src/op/utils.h | Add helpers for detecting metal.simdgroup and “register” buffers. |
| src/op/parallel.cc | Make fragment layout inference more robust when layout entry is absent. |
| src/op/gemm.h | Add Metal simdgroup GEMM instruction enum value. |
| src/op/gemm.cc | Select Metal GEMM inst for Metal targets; adjust warp partition policy for Metal. |
| src/op/fill.cc | Add metal.simdgroup fill lowering via make_filled_simdgroup_matrix. |
| src/op/copy.h | Add Metal simdgroup copy instruction kind and lowering hooks. |
| src/op/copy.cc | Add metal.simdgroup store lowering via simdgroup_store. |
| src/backend/metal/CMakeLists.txt | Always compile Metal codegen sources for cross-compilation; gate runtime on Apple. |
| requirements.txt | Add a Darwin-only upper bound for apache-tvm-ffi. |
| requirements-dev.txt | Add a Darwin-only upper bound for apache-tvm-ffi. |
| pyproject.toml | Add a Darwin-only upper bound for apache-tvm-ffi. |
| benchmark/matmul_metal/benchmark_matmul_metal.py | New simdgroup GEMM benchmark script for MPS. |
```python
if torch.cuda.is_available():
    try:
        torch.cuda._lazy_init()
        current_device = torch._C._cuda_getDevice
        return lambda: torch.device("cuda", current_device())
    except Exception:
        return lambda: torch.device("cuda", torch.cuda.current_device())
        pass
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    return lambda: torch.device("mps")
# CPU fallback
return lambda: torch.device("cpu")
```
```cpp
for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
  Var v = func->params[i];
  if (!v.dtype().is_handle())
    break;
  this->stream << " ";
  std::string vid = AllocVarID(v.get());
  auto it = alloc_storage_scope_.find(v.get());
  if (it != alloc_storage_scope_.end()) {
    PrintStorageScope(it->second, this->stream);
  }
  PrintType(GetType(v), this->stream);
  // Register handle data type
  // TODO(tvm-team): consider simply keep type info in the
  // type annotation(via a normalizing rewriting).
  if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
    if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
      RegisterHandleType(v.get(), prim->dtype);
    }
  }
  this->stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
}
// Setup normal arguments.
size_t nargs = func->params.size() - num_buffer;
std::string varg = name_supply_->FreshName("arg");
if (nargs != 0) {
  std::string arg_buf_type =
      static_cast<std::string>(global_symbol.value()) + "_args_t";
  this->stream << " constant " << arg_buf_type << "& " << varg
               << " [[ buffer(" << num_buffer << ") ]],\n";
  // declare the struct
```
```python
def _rewrite_scope(body, var_map):
    buf_map = {}

    def _pre_order(stmt):
        if isinstance(stmt, tir.Block):
            new_alloc_bufs = []
            changed = False
            for buf in stmt.alloc_buffers:
                new_buf = _remap_buffer(buf, var_map)
                new_alloc_bufs.append(new_buf)
                if not new_buf.same_as(buf):
                    buf_map[buf] = new_buf
                    changed = True
            if changed:
                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
                new_block = tir.Block(
                    stmt.iter_vars,
                    stmt.reads,
                    stmt.writes,
                    stmt.name_hint,
                    new_body,
                    stmt.init,
                    new_alloc_bufs,
                    stmt.match_buffers,
                    stmt.annotations,
                )
                return (
                    tir.BlockRealize(
                        stmt.iter_vars,
                        tir.const(True, "bool"),
                        new_block,
                    )
                    if False
                    else new_block
                )
        elif isinstance(stmt, tir.Allocate):
```
```cpp
// Check that all local.fragment buffers have inferred layouts.
// On Metal targets, fragment buffers used as GEMM accumulators are
// lowered to opaque simdgroup matrices, so they have no explicit
// thread-level layout and can be safely skipped.
for (const auto &[buffer, _] : use_list_) {
  if (IsFragmentBuffer(buffer)) {
    ICHECK_NE(layout_map.count(buffer), 0)
        << "The layout for fragment " << buffer
        << " can not be inferred correctly.";
    if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
      ICHECK(false) << "The layout for fragment " << buffer
                    << " can not be inferred correctly.";
    }
```
Actionable comments posted: 12
🧹 Nitpick comments (1)
testing/python/metal/test_metal_internal_scaffolding.py (1)
43-66: ⚡ Quick win — Include the legacy CUDA aliases in the shared forbidden-token check.
These source-boundary probes are a good place to catch the exact regression this PR is fixing. Right now `_assert_clean_metal_source()` would still accept generated Metal that reintroduces blockIdx/threadIdx.
Suggested update
```diff
 _FORBIDDEN_EXTERNAL_TOKENS = (
+    "blockidx",
+    "threadidx",
     "cooperative",
     "mpp",
     "mpsgraph",
     "warpgroup",
     "cp.async",
```
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/op/copy.cc`:
- Around line 910-911: The SIMD-group copy path can be chosen even when
buffer_oob is true, causing LowerSIMDGroupCopy (which emits unconditional
simdgroup_store) to write out-of-bounds for tail tiles; modify the selection
logic in src/op/copy.cc where CheckSIMDGroupCopy(target) is checked so that it
also requires !buffer_oob before returning CopyInst::kMetalSIMDGroup (mirror the
same !buffer_oob guard used by the other specialized copy paths), ensuring
LowerSIMDGroupCopy is only used when the destination is guaranteed in-bounds.
In `@src/op/fill.cc`:
- Around line 167-198: The code currently computes simdgroup matrix slots from a
flattened element_offset (matrix_index_base = FloorDiv(element_offset, 64))
which allows misaligned regions and rejects tile-aligned ones; change the logic
to validate and compute matrix slots from per-dimension tile coordinates: for
each dimension i derive tile_size_i = 8 for the two fastest-varying dims (or the
dims that form an 8x8 tile), check
analyzer->CanProveEqual(FloorMod(region[i]->min, tile_size_i), 0) (or equivalent
per-dim checks) to ensure the region starts on an 8-element boundary in each
tile dimension, compute tile_index_i = FloorDiv(region[i]->min, tile_size_i),
compute matrix_index_base as the dot product of tile_index_i with tile_strides
(where tile_strides are derived from dst->shape/strides but in units of tiles,
not elements), and pass matrix_index_base into the existing call to
builtin::make_filled_simdgroup_matrix instead of FloorDiv(element_offset, 64);
update any related ICHECK/CanProveEqual assertions and remove the flat
64-element boundary checks (matrix_elements) to rely on the per-dimension tile
checks.
In `@src/target/codegen_metal.cc`:
- Around line 156-164: Move the Metal builtin name reservations so they run at
the start of AddFunction before any user-visible ids are allocated;
specifically, call name_supply_->FreshName("threadIdx"), ("blockIdx"),
("threadgroup_position_in_grid"), and ("thread_position_in_threadgroup") and
assert their expected values at the top of AddFunction (i.e., before any
AllocVarID() calls for buffer/POD params) to ensure collisions are avoided with
user kernel symbols.
In `@src/transform/layout_inference.cc`:
- Around line 436-445: The current check skips all missing fragment-buffer
layouts on Metal by gating on TargetIsMetal(target_), which hides real inference
failures; instead, only skip a fragment buffer if it was actually lowered to a
simdgroup matrix by MetalFragmentToSimdgroup. Replace the TargetIsMetal(target_)
gate in the loop over use_list_ (symbols: use_list_, IsFragmentBuffer,
layout_map, TargetIsMetal, MetalFragmentToSimdgroup) with a check that if
layout_map.count(buffer) == 0 then (a) if IsFragmentBuffer(buffer) and the
buffer has metadata/flag indicating it was lowered by MetalFragmentToSimdgroup
(e.g., WasLoweredToSimdgroup(buffer) or consult the metadata Map set by
MetalFragmentToSimdgroup) then skip, otherwise ICHECK(false) << "The layout for
fragment " << buffer << " can not be inferred correctly."; ensure you add or
read the existing metadata key emitted by MetalFragmentToSimdgroup rather than
relying on TargetIsMetal.
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 74-78: The test test_metal_gemm_v2_small_blocks currently calls
assert_metal_gemm_v2_codegen(16, 16, 16, 16, 16, 16, dtype=T.float16) which uses
a 16x16 block and therefore never exercises the multi-tile-per-warp path; change
the block size arguments to at least 32x32 (e.g.,
assert_metal_gemm_v2_codegen(32, 32, 16, 16, 16, 16, dtype=T.float16) or
similar) so warp_rows > 1 and warp_cols > 1 can occur, and update the test
docstring to reflect the new block size and intent to reproduce the
stride/swizzle regression.
In `@testing/python/metal/test_metal_local_var.py`:
- Around line 34-40: The test's zero-init assertion is too strict — the kernel
only default-initializes y while x is emitted as "= 3", so change the assertion
in test_metal_local_var.py that currently expects
len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2 to require at least one
zero-initialized local (e.g., >= 1) or assert specifically for "int y = 0;" so
the test checks the actual lowering of local.var and does not depend on
unrelated temporaries (also keep the existing checks for x, the "(w + 4)"
pattern, and absence of "local.var" and "thread int").
In `@testing/python/metal/test_metal_simdgroup_store.py`:
- Around line 59-76: The test helper assert_simdgroup_store_codegen currently
verifies simdgroup helpers but doesn’t assert Metal builtin usage; update this
function (and _make_simdgroup_gemm_func usage) to also assert that the lowered
kernel_source (src) contains Metal builtin identifiers (e.g.
"thread_position_in_threadgroup" and "threadgroup_position_in_grid" or other
target-specific Metal builtin names used in your codegen) and that it does NOT
contain CUDA-style identifiers like "blockIdx" or "threadIdx" so a regression
back to CUDA naming would fail the test.
In `@tilelang/tileop/gemm/gemm_metal.py`:
- Around line 38-50: The MPSIntrinEmitter is being constructed with the raw
thread_var which leaves warp-index math wrong when thread_bounds.min is
non-zero; normalize the injected thread binding before passing it into
MPSIntrinEmitter by subtracting the thread bounds minimum (i.e., use a
normalized_thread_var = thread_var - thread_bounds.min or equivalent) so
warp_id/wrap-index calculations inside MPSIntrinEmitter compute from a
zero-based thread index; update the call that creates MPSIntrinEmitter (the
constructor invocation using MPSIntrinEmitter(...) with parameter thread_var) to
pass the normalized thread variable instead.
In `@tilelang/tileop/metal_simdgroup.py`:
- Around line 385-396: The mma_tile macro only uses k=0 now; fix by iterating
over the K fragments and calling mma for each k to accumulate into acc.fragment.
Inside mma_tile (function symbol: mma_tile), add an inner loop like "for tile_k
in T.unroll(<K-fragments>, explicit=True)" using the MMATile K-dimension (e.g.,
acc.fragments_k or a.fragments_k / b.fragments_k) and replace the current
a.index and b.index calls with a.index(tile_m, tile_k) and b.index(tile_k,
tile_n), leaving acc.index(tile_m, tile_n) as the destination so each mma call
accumulates the contributions across K fragments.
In `@tilelang/transform/metal_fragment_to_simdgroup.py`:
- Around line 57-65: The cloned Buffer created in the return
tir.decl_buffer(...) call loses important metadata (strides, elem_offset,
buffer_type, axis_separators, span, etc.) from the original buf; update the call
in metal_fragment_to_simdgroup.py to pass through all of buf's metadata (e.g.,
strides=buf.strides, elem_offset=buf.elem_offset, buffer_type=buf.buffer_type,
axis_separators=buf.axis_separators, span=buf.span and any other non-default
fields) so the new buffer preserves original semantics while changing only data
and scope.
- Around line 71-93: In _pre_order when rebuilding a tir.Block after remapping
alloc_buffers, ensure you also substitute and rebuild stmt.init, stmt.reads,
stmt.writes (and any match_buffers/regions that may reference the old buffer)
using var_map so the new_block is consistent; specifically, after detecting
changed, call tir.stmt_functor.substitute on stmt.body, stmt.init, and on the
buffer region lists (reads/writes/match_buffers) as needed, then construct
new_block with the substituted body, init, reads, writes and new_alloc_bufs and
update buf_map as already done.
---
Nitpick comments:
In `@testing/python/metal/test_metal_internal_scaffolding.py`:
- Around line 43-66: The forbidden-token check in _assert_clean_metal_source
doesn't include legacy CUDA alias names (e.g., blockIdx, threadIdx and their
common variants), so generated Metal source could still contain CUDA-style
identifiers; update the _FORBIDDEN_EXTERNAL_TOKENS tuple to include the legacy
CUDA aliases (at minimum "blockidx", "threadidx", plus any common variants like
"blockidx.x", "blockidx.y", "threadidx.x", etc., normalized to lowercase) and
ensure _assert_clean_metal_source continues to lowercase the source before
checking so these aliases are detected; reference the symbols
_FORBIDDEN_EXTERNAL_TOKENS and _assert_clean_metal_source when making the
change.
📒 Files selected for processing (37)
- benchmark/matmul_metal/benchmark_matmul_metal.py
- pyproject.toml
- requirements-dev.txt
- requirements.txt
- src/backend/metal/CMakeLists.txt
- src/op/copy.cc
- src/op/copy.h
- src/op/fill.cc
- src/op/gemm.cc
- src/op/gemm.h
- src/op/parallel.cc
- src/op/utils.h
- src/target/codegen_metal.cc
- src/target/codegen_metal.h
- src/transform/layout_inference.cc
- src/transform/lower_device_storage_access_info.cc
- testing/python/jit/test_tilelang_jit_adapter_mps.py
- testing/python/metal/metal_internal_runtime_coverage.md
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- testing/python/metal/test_metal_internal_scaffolding.py
- testing/python/metal/test_metal_local_var.py
- testing/python/metal/test_metal_simdgroup_store.py
- tilelang/engine/lower.py
- tilelang/engine/phase.py
- tilelang/intrinsics/metal_macro_generator.py
- tilelang/jit/adapter/base.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/tileop/metal_gdn.py
- tilelang/tileop/metal_quant.py
- tilelang/tileop/metal_simdgroup.py
- tilelang/transform/decouple_type_cast.py
- tilelang/transform/metal_fragment_to_simdgroup.py
- tilelang/utils/language.py
```cpp
} else if (CheckSIMDGroupCopy(target)) {
  return CopyInst::kMetalSIMDGroup;
```
Gate Metal SIMD-group copy on buffer_oob.
This path can still be selected for edge tiles, but LowerSIMDGroupCopy() emits unconditional simdgroup_store calls and never predicates the destination. On non-divisible shapes that turns the tail tile into an out-of-bounds write. Please mirror the !buffer_oob guard used by the other specialized copy paths before returning kMetalSIMDGroup.
Suggested fix
```diff
-  } else if (CheckSIMDGroupCopy(target)) {
+  } else if (!buffer_oob && CheckSIMDGroupCopy(target)) {
     return CopyInst::kMetalSIMDGroup;
```
```cpp
ICHECK(total_elements % 64 == 0)
    << "simdgroup buffer size must be multiple of 64 (8x8), got "
    << total_elements;
int num_matrices = total_elements / 64;
PrimExpr fill_value = Cast(dst->dtype, value);
Array<PrimExpr> strides = dst->strides;
if (strides.empty()) {
  PrimExpr stride = 1;
  strides.resize(dst->shape.size());
  for (int i = static_cast<int>(dst->shape.size()) - 1; i >= 0; --i) {
    strides.Set(i, stride);
    stride *= dst->shape[i];
  }
}
ICHECK_EQ(strides.size(), dst->shape.size())
    << "simdgroup fill requires complete destination strides";
PrimExpr element_offset = 0;
for (size_t i = 0; i < region.size(); ++i) {
  element_offset += region[i]->min * strides[i];
}
PrimExpr matrix_elements = IntImm(element_offset.dtype(), 64);
ICHECK(
    analyzer->CanProveEqual(FloorMod(element_offset, matrix_elements), 0))
    << "simdgroup fill region must start on an 8x8 matrix boundary";
PrimExpr matrix_index_base = FloorDiv(element_offset, matrix_elements);
Array<Stmt> stmts;
for (int i = 0; i < num_matrices; i++) {
  stmts.push_back(Evaluate(
      Call(DataType::Handle(), builtin::make_filled_simdgroup_matrix(),
           {dst->data, matrix_index_base + IntImm(DataType::Int(32), i),
            fill_value, IntImm(DataType::Int(32), 8),
            IntImm(DataType::Int(32), 8)})));
```
Compute simdgroup matrix slots from tile coordinates, not element_offset / 64.
This path currently treats any 64 contiguous flattened elements as one 8x8 simdgroup matrix. That rejects valid tile-aligned subregions like a 16x16 buffer slice starting at (0, 8) because Line 189 sees element_offset == 8, and it would also accept non-tile-shaped regions like (64, 1) because only the flattened element count is checked. The lowering needs per-dimension 8x8 tile validation and matrix-slot derivation from tile coordinates, not a flat element offset.
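A small worked example of the tile-coordinate indexing the comment asks for (Python arithmetic for illustration only; the actual fix would live in the C++ lowering):

```python
# 16x16 accumulator split into 8x8 simdgroup matrices: a 2x2 grid of tiles,
# so tile strides are (2, 1) when counted in tiles rather than elements.
shape, tile = (16, 16), 8
tile_strides = (shape[1] // tile, 1)

def matrix_index(region_min):
    assert all(m % tile == 0 for m in region_min), "must start on an 8x8 tile boundary"
    return sum((m // tile) * s for m, s in zip(region_min, tile_strides))

# The slice starting at (0, 8) is tile-aligned (matrix slot 1), yet its flattened
# element offset is 8, which the current `element_offset % 64 == 0` check rejects.
assert matrix_index((0, 8)) == 1
assert matrix_index((8, 0)) == 2
```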
```cpp
// Reserve the CUDA-style alias names so user code or downstream passes
// cannot accidentally collide with them, even though the kernel itself
// emits Metal builtin names directly (no `blockIdx`/`threadIdx` aliases).
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
ICHECK_EQ(name_supply_->FreshName("threadgroup_position_in_grid"),
          "threadgroup_position_in_grid");
ICHECK_EQ(name_supply_->FreshName("thread_position_in_threadgroup"),
          "thread_position_in_threadgroup");
```
Reserve these builtin names before allocating any user-visible ids.
Right now the FreshName(...) checks run after buffer/POD params have already called AllocVarID(). If a legal kernel symbol is named threadIdx, blockIdx, or one of the Metal builtin identifiers, this now hard-fails instead of avoiding the collision. The reservation needs to happen at the top of AddFunction, before any ids are assigned.
```cpp
// Check that all local.fragment buffers have inferred layouts.
// On Metal targets, fragment buffers used as GEMM accumulators are
// lowered to opaque simdgroup matrices, so they have no explicit
// thread-level layout and can be safely skipped.
for (const auto &[buffer, _] : use_list_) {
  if (IsFragmentBuffer(buffer)) {
    ICHECK_NE(layout_map.count(buffer), 0)
        << "The layout for fragment " << buffer
        << " can not be inferred correctly.";
    if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
      ICHECK(false) << "The layout for fragment " << buffer
                    << " can not be inferred correctly.";
    }
```
Don’t suppress all missing fragment layouts on Metal.
On the TileLang Metal path, MetalFragmentToSimdgroup already runs before LayoutInference, so eligible GEMM accumulators should no longer be local.fragment at this point. Gating this entire check on TargetIsMetal(target_) will silently accept unrelated layout-inference failures for other fragment buffers.
```python
def test_metal_gemm_v2_small_blocks():
    """Test with blocks where warp_rows > 1 and warp_cols > 1, which previously
    produced incorrect results due to swizzle padding changing the stride.
    """
    assert_metal_gemm_v2_codegen(16, 16, 16, 16, 16, 16, dtype=T.float16)
```
This case doesn't hit the multi-tile-per-warp path it describes.
With a 16x16 block and 128 threads, each warp only covers one 8x8 tile, so warp_rows > 1 and warp_cols > 1 never both happen here. If the goal is to lock in the stride/swizzle regression, this should use at least a 32x32 block.
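For reference, the arithmetic behind that claim, assuming 32-thread simdgroups and 8x8 tiles (a sketch, not code from the test):

```python
def tiles_per_warp(block_m, block_n, threads=128, simd_width=32, tile=8):
    warps = threads // simd_width            # 128 threads -> 4 simdgroups ("warps")
    tiles = (block_m // tile) * (block_n // tile)
    return tiles // warps

assert tiles_per_warp(16, 16) == 1   # one 8x8 tile per warp: path not exercised
assert tiles_per_warp(32, 32) == 4   # 2x2 tiles per warp: warp_rows = warp_cols = 2
```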
```python
elif target.kind.name == "metal":
    device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
    device_mod = tvm.ffi.get_global_func("target.build.tilelang_metal")(device_mod, target)
```
device_codegen_without_compile() can still compile Metal kernels.
target.build.tilelang_metal switches to metallib whenever tvm_callback_metal_compile is registered, so this path no longer guarantees source-only output. That can break inspect-only flows or CI environments that intentionally skip the Metal toolchain.
```python
mps_emitter = MPSIntrinEmitter(
    a_dtype=self.in_dtype,
    b_dtype=self.in_dtype,
    accum_dtype=self.accum_dtype,
    a_transposed=self.trans_A,
    b_transposed=self.trans_B,
    block_row_warps=m_warp,
    block_col_warps=n_warp,
    warp_row_tiles=warp_row_tiles,
    warp_col_tiles=warp_col_tiles,
    chunk=self.chunk,
    thread_var=thread_var,
)
```
Normalize the injected thread binding before computing warp indices.
MPSIntrinEmitter gets the raw thread_var here, but its warp-id math is zero-based. If this GEMM is lowered with a non-zero thread_bounds.min, the warp mapping shifts and the per-warp tile coordinates become wrong. The SIMD-group copy path already subtracts the range min for the same reason.
Suggested fix
```diff
         mps_emitter = MPSIntrinEmitter(
             a_dtype=self.in_dtype,
             b_dtype=self.in_dtype,
             accum_dtype=self.accum_dtype,
             a_transposed=self.trans_A,
             b_transposed=self.trans_B,
             block_row_warps=m_warp,
             block_col_warps=n_warp,
             warp_row_tiles=warp_row_tiles,
             warp_col_tiles=warp_col_tiles,
             chunk=self.chunk,
-            thread_var=thread_var,
+            thread_var=thread_var - thread_bounds.min,
         )
```
```python
@T.macro
def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
    for tile_m in T.unroll(acc.fragments_m, explicit=True):
        for tile_n in T.unroll(acc.fragments_n, explicit=True):
            mma(
                acc.fragment,
                a.fragment,
                b.fragment,
                acc.index(tile_m, tile_n),
                a.index(tile_m, 0),
                b.index(0, tile_n),
            )
```
Accumulate across the K fragments in mma_tile.
Right now this always multiplies a.index(tile_m, 0) with b.index(0, tile_n), so any tile with a.fragments_n > 1 or b.fragments_m > 1 drops every contribution except k=0.
Proposed fix
```diff
 @T.macro
 def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
+    if a.fragments_n != b.fragments_m:
+        raise ValueError(
+            f"mma_tile requires matching K fragments, got {a.fragments_n} and {b.fragments_m}"
+        )
     for tile_m in T.unroll(acc.fragments_m, explicit=True):
         for tile_n in T.unroll(acc.fragments_n, explicit=True):
-            mma(
-                acc.fragment,
-                a.fragment,
-                b.fragment,
-                acc.index(tile_m, tile_n),
-                a.index(tile_m, 0),
-                b.index(0, tile_n),
-            )
+            for tile_k in T.unroll(a.fragments_n, explicit=True):
+                mma(
+                    acc.fragment,
+                    a.fragment,
+                    b.fragment,
+                    acc.index(tile_m, tile_n),
+                    a.index(tile_m, tile_k),
+                    b.index(tile_k, tile_n),
+                )
```
```python
return tir.decl_buffer(
    buf.shape,
    buf.dtype,
    buf.name,
    data=new_data,
    scope="metal.simdgroup",
    data_alignment=buf.data_alignment,
    offset_factor=buf.offset_factor,
)
```
Preserve the original Buffer metadata when cloning it.
tir.decl_buffer(...) here keeps only shape/dtype/name/alignment/offset_factor and drops other fields like strides, elem_offset, buffer_type, axis_separators, and span. Any fragment buffer that relies on non-default metadata changes semantics after this rewrite.
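A sketch of a metadata-preserving clone, using keyword arguments that TVM's `tir.decl_buffer` accepts; whether every field is needed depends on the surrounding pass:

```python
from tvm import tir

def clone_to_simdgroup(buf: tir.Buffer, new_data: tir.Var) -> tir.Buffer:
    # Change only the backing var and the storage scope; carry everything else over.
    return tir.decl_buffer(
        buf.shape,
        buf.dtype,
        buf.name,
        data=new_data,
        strides=buf.strides,
        elem_offset=buf.elem_offset,
        scope="metal.simdgroup",
        data_alignment=buf.data_alignment,
        offset_factor=buf.offset_factor,
        axis_separators=buf.axis_separators,
    )
```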
```python
def _pre_order(stmt):
    if isinstance(stmt, tir.Block):
        new_alloc_bufs = []
        changed = False
        for buf in stmt.alloc_buffers:
            new_buf = _remap_buffer(buf, var_map)
            new_alloc_bufs.append(new_buf)
            if not new_buf.same_as(buf):
                buf_map[buf] = new_buf
                changed = True
        if changed:
            new_body = tir.stmt_functor.substitute(stmt.body, var_map)
            new_block = tir.Block(
                stmt.iter_vars,
                stmt.reads,
                stmt.writes,
                stmt.name_hint,
                new_body,
                stmt.init,
                new_alloc_bufs,
                stmt.match_buffers,
                stmt.annotations,
            )
```
Rewrite the block init and buffer regions together with alloc_buffers.
This branch only substitutes stmt.body. If the accumulator also appears in stmt.init, reads, or writes, the rebuilt block still points at the old local.fragment buffer even though alloc_buffers now holds the new metal.simdgroup one.
Summary
TileLang's Metal codegen names the thread/block kernel-launch parameters
using CUDA-style identifiers (`blockIdx`, `threadIdx`):
```cpp
kernel void smoke_kernel(
device const half4* A [[ buffer(0) ]],
device half4* C [[ buffer(1) ]],
uint3 blockIdx [[threadgroup_position_in_grid]],
uint3 threadIdx [[thread_position_in_threadgroup]]
) {
C[((((int)threadIdx.x) * 4) / 4)] = A[((((int)threadIdx.x) * 4) / 4)];
}
```
These names mirror CUDA's `blockIdx.x`/`threadIdx.x` but aren't required
on Metal. Downstream MSL passes that inline the kernel body into another
`kernel void` (e.g. `mx.fast.metal_kernel` consumers) end up having to
canonicalize the alias back to the Metal builtin, resulting in pure
overhead.
This PR makes the Metal codegen emit Metal builtin names directly:
```cpp
kernel void smoke_kernel(
device const half4* A [[ buffer(0) ]],
device half4* C [[ buffer(1) ]],
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]
) {
C[((((int)thread_position_in_threadgroup.x) * 4) / 4)] = A[((((int)thread_position_in_threadgroup.x) * 4) / 4)];
}
```
`BindThreadIndex` translates the CUDA-style `IterVar::thread_tag`
(`"blockIdx.x"`, `"threadIdx.y"`, ...) to the matching Metal builtin
reference (`threadgroup_position_in_grid.x`, ...) before recording it in
`var_idmap_`. The legacy `blockIdx`/`threadIdx` names remain reserved
by `name_supply_` so the rest of the kernel cannot collide with them,
preserving the existing assertion-based contract on the reservation
order.
Evidence: cppmega.mlx workaround
`cppmega.mlx` is a downstream consumer that splices TileLang-emitted
`kernel void` bodies into MLX's `mx.fast.metal_kernel` `source=` strings.
Today it has to canonicalize the CUDA-style alias names back to the
Metal builtins. See
`cppmega_mlx/nn/_tilelang/_msl_transform.py` (the `_canonicalize_tilelang_builtin_aliases` helper and its supporting passes, including the one that drops the now-dead alias decls after rewriting).
These helpers are used by every Path C port (sparse_mla_path_c.py,
sparse_mla_blockscaled_path_c.py, sparse_mla_fp8_path_c.py,
fp8_vecmat_path_c.py, mamba3_path_c.py). With this PR they become no-ops
because the emitted MSL already uses the Metal builtin names. The
helpers are kept (idempotent) so the fallback works against unpatched
TileLang releases as well.
Stacking
This PR stacks on PR #2130
(`jorgecurious/tilelang:metal-gemm-upstream-rebase`) for the
simdgroup-store hardening that the same Path C ports rely on. The diff
applies cleanly on top of that branch.
TVM-submodule half
The vendored `3rdparty/tvm/src/target/source/codegen_metal.cc`
(`target.build.metal`) carries the same alias problem and gets the
parallel fix; the TVM half is filed separately to `TileLang/tvm` and is
NOT part of this PR. `target.build.tilelang_metal` (this PR) is the
codepath used by `lower(prim_func, target='metal')` so this PR alone is
sufficient to remove the alias from TileLang user code.
Test plan
- TileLang build: `ninja -j8` succeeds, no warnings on the touched file.
- Generated MSL for the smoke kernel contains no `blockIdx.x`/`threadIdx.x` references and no `int blockIdx_x = ...;` alias; `threadgroup_position_in_grid` and `thread_position_in_threadgroup` appear directly in the body and as parameter names.
- cppmega Path C tests still pass against patched TileLang (11/11) and against unpatched TileLang (11/11), confirming the regex helpers degrade to no-ops without regression.
- Existing Metal tests keep passing; none assert on the alias names.