[Metal] FP8 vector cast lanes 2/3/4 (extends storage-only FP8) #2145

Open
apstenku123 wants to merge 12 commits into tile-ai:main from
apstenku123:cppmega/metal-fp8-vector-cast

Conversation


@apstenku123 apstenku123 commented May 4, 2026

Summary

Extends the storage-only Metal FP8 codegen to handle vectorized casts at IR
lanes 2 / 3 / 4. Without this change, every T.Cast("float16x4", fp8_x4)
emitted by upstream TileLang DSL programs raises
LOG(FATAL): Vector FP8 casts (lanes=4) are not yet supported, forcing
callers to manually scalarise the cast and losing the IR-level vector
type for any subsequent pass (vectorize, fragment-to-simdgroup, etc.).

This PR is the TileLang supermodule half. The companion
TileLang/tvm submodule half is filed at
https://github.com/TileLang/tvm/pulls (search cppmega/metal-fp8-vector-cast).
Both halves only share helper names; they can land independently but
should be merged in tandem so the vendored 3rdparty/tvm checkout stays in
sync with this codepath.

What this changes

Adds an enable_fp8_vector_ codegen flag and a new
PrintFP8VectorPrelude(...) that emits inline MSL helpers that wrap the
existing scalar helpers (__tvm_fp8_e4m3_to_half, etc.) per lane:

inline half4 __tvm_fp8_e4m3_to_half_v4(uchar4 x) {
  return half4(__tvm_fp8_e4m3_to_half(x.x), __tvm_fp8_e4m3_to_half(x.y),
               __tvm_fp8_e4m3_to_half(x.z), __tvm_fp8_e4m3_to_half(x.w));
}

Mirrors are emitted for _v2 / _v3, plus the reverse direction
(half -> fp8) and the e5m2 variant. The compiler is free to scalarise
back into per-lane calls; the goal here is to preserve the IR-level
vector type so subsequent passes can keep their vector loads and stores
and the downstream MSL is uchar4-typed instead of uchar arrays.

Finish() is updated to splice the vector prelude after the scalar
prelude when at least one vector FP8 cast is encountered during codegen.

Wider lanes (8 / 16) keep the existing LOG(FATAL) with a sharper
message — those widths print as uint2 / uint4 packed storage and
need an out-pointer ABI to be wired through; callers should lower them
to scalar casts upstream.

Why Apple Silicon needs software FP8 emulation

Apple Silicon (M1 through M4 Max, including the M5 NAX which is
FP16/INT8 only) has no native FP8 ALU. FP8 is realised as uchar
storage with explicit dequantize-on-load / quantize-on-store; the
encoding mirrors the OCP "OFP8 Formats for Deep Learning" v1.0 spec
(E4M3 finite-only, E5M2 IEEE-style with NaN/Inf).
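For reference, the OCP OFP8 decode semantics described above can be sketched in Python (function names are illustrative, not the PR's helpers):

```python
import math

def decode_fp8(byte: int, exp_bits: int, man_bits: int, e4m3_style: bool) -> float:
    """Decode one OFP8 byte. E4M3 is finite-only (NaN only at all-ones
    exponent+mantissa); E5M2 is IEEE-style (Inf/NaN at all-ones exponent)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    bias = (1 << (exp_bits - 1)) - 1
    max_exp = (1 << exp_bits) - 1
    if e4m3_style:
        if exp == max_exp and man == (1 << man_bits) - 1:
            return math.nan                      # no infinities in E4M3
    elif exp == max_exp:
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:                                 # subnormal
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

def e4m3_to_float(b: int) -> float:  # 1 sign / 4 exp / 3 mantissa, bias 7
    return decode_fp8(b, 4, 3, True)

def e5m2_to_float(b: int) -> float:  # 1 sign / 5 exp / 2 mantissa, bias 15
    return decode_fp8(b, 5, 2, False)
```

E.g. 0x38 decodes to 1.0 in E4M3, 0x7E to the E4M3 max finite value 448, and 0x7C to +Inf in E5M2.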

The vector helpers in this PR are inline-trivial wrappers around the
scalar helpers that landed in the storage-only PR — no new conversion
math. Their value is purely codegen: the IR-level vector type is
preserved so the rest of the lowering pipeline can vectorise.

Path C consumer evidence (vector lanes matter)

The downstream cppmega.mlx project's Path C TileLang FP8 vecmat kernel
(cppmega_mlx/nn/_tilelang/fp8_vecmat_path_c.py) explicitly uses
T.alloc_local((4,), "float8_e4m3") and a T.vectorized(4) inner loop
over packed FP8 weights. Without this PR, that kernel cannot be lowered
on Metal — the FP8 cast inside the K-loop hits the lanes=4 FATAL.
With this PR, the cast lowers and the resulting MSL preserves
uchar4-typed loads through the K-loop.
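As context for why preserving lanes matters, here is a plain-Python model of that inner loop's shape: K is walked in 4-byte chunks, each chunk dequantized per lane and accumulated. This is not the TileLang kernel, and the decode below is a deliberate stand-in, not real E4M3 math:

```python
def decode_stub(byte: int) -> float:
    """Placeholder dequant (NOT the E4M3 formula); stands in for the decode."""
    return (byte - 128) / 16.0

def vecmat_row(packed_w: list[int], x: list[float], lanes: int = 4) -> float:
    """Dot product over packed FP8-style weights, lanes bytes at a time."""
    assert len(packed_w) == len(x) and len(x) % lanes == 0
    acc = 0.0
    for k0 in range(0, len(x), lanes):
        chunk = packed_w[k0:k0 + lanes]           # models one uchar4 load
        vals = [decode_stub(b) for b in chunk]    # models a _v4 cast helper
        acc += sum(v * xi for v, xi in zip(vals, x[k0:k0 + lanes]))
    return acc
```

Without the vector helpers, each `chunk` would have to be four independent byte loads and scalar casts, which is exactly the structure the lanes=4 FATAL forces today.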

Dependency

This PR stacks on two prereqs:

  1. tilelang_metal_fp8 storage-only patch (parallel
    [Metal] FP8 storage-only emulation (uchar storage + LUT decode helpers)
    PR being filed against this same repo). That patch adds
    PrintFP8Prelude, enable_fp8_, and the scalar
    __tvm_fp8_*_to_half / __tvm_half_to_fp8_* helpers that the
    vector helpers in this PR call. Reviewers will need that patch
    applied first; the branch in this PR includes it as the first commit
    [Metal] FP8 storage-only emulation ... [prereq] for self-contained
    review.
  2. PR #2130 (Rebase Metal simdgroup GEMM support and runtime coverage;
    jorgecurious's metal-gemm-upstream-rebase) at HEAD 971c17b. That
    branch in turn stacks on PR #1869 ([Metal] Add Metal GEMM support
    with simdgroup_matrix MMA), PR #2118 (Add Metal scalar fallback for
    T.gemm), and PR #2121 ([Refactor][CodeGen] Refactor CodeGen part for
    multi-backend decoupling).

When the storage-only PR merges, the prereq commit on this branch
should be rebased away. Until then, this branch is reviewable as two
stacked commits.

Test plan

  • git apply --check clean against
    jorgecurious/tilelang:metal-gemm-upstream-rebase @ 971c17b with the
    storage-only prereq applied first
  • git apply --reverse --check clean for both commits in sequence
    (round-trip verified)
  • xcrun --sdk macosx metal -c compiles the emitted MSL; any prim_func
    with a vector FP8 cast (lanes 2/3/4) lowers to MSL that uses the new
    vector helpers rather than the scalar fallback
  • Direct probe /tmp/test_fp8_vector_cast.py: lanes=4 cast lowers
    and the resulting MSL contains __tvm_fp8_e4m3_to_half_v4 with
    uchar4 typed loads
  • testing/python/metal/test_metal_codegen_linux.py: net +1 pass
    vs the storage-only baseline
    (test_t_gemm_metal_codegen_pipelined_float32 flips green)
  • cppmega.mlx tilelang test suite: 134 passed, 0 regressions
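The git apply round-trip check in the plan generalizes; a self-contained sketch using a throwaway repo (paths and patch names here are placeholders, not the PR's actual patches or bases):

```shell
set -e
tmp="$(mktemp -d)"
cd "$tmp"
git init -q demo
cd demo
printf 'line1\n' > f.txt
git add f.txt
git -c user.email=a@b -c user.name=demo commit -qm base
printf 'line2\n' >> f.txt
git diff > ../change.patch
git checkout -q -- f.txt                      # back to the clean base
git apply --check ../change.patch             # forward applies cleanly
git apply ../change.patch
git apply --reverse --check ../change.patch   # reverse applies cleanly
echo ROUNDTRIP_OK
```

`--check` validates applicability without touching the tree, so the forward and reverse checks together verify the patch round-trips against its base.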

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Metal (Apple Silicon) GPU acceleration support for GEMM operations with optimized simdgroup and local storage modes
    • Introduced quantization helpers for low-precision formats (fp8, fp4) in Metal kernels
    • Added Metal-optimized GDN and attention component macros for advanced operations
  • Tests

    • Added comprehensive Metal-specific test suites covering GEMM correctness, simdgroup operations, local variable handling, and internal scaffolding validation
    • Added device fallback tests for MPS (Metal Performance Shaders) availability
  • Chores

    • Updated dependencies to constrain Metal backend compatibility on macOS
    • Enhanced Metal build system for cross-compilation support

oraluben and others added 12 commits April 30, 2026 01:43
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations
(simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon
(M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with
  simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for sharedxshared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM
  accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
Copilot AI review requested due to automatic review settings May 4, 2026 10:09

github-actions Bot commented May 4, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

apstenku123 added a commit to DatasunriseOU/cppmega_mlx that referenced this pull request May 4, 2026
…, #37/#38/#39)

Three parallel agents completed the supermodule/submodule split filing:

1. tilelang_metal_fp8 (storage-only FP8 emulation) split:
   - 0001-tilelang-metal-fp8-storage-only.patch — supermodule half (235 lines)
   - 0002-tvm-metal-fp8-storage-only.patch — TVM-mirror half (260 lines, prefix stripped)
   - PR tile-ai/tilelang#2144 (supermodule, stacks on PR #2130)
   - PR tile-ai/tvm#38 (TVM mirror, base tilelang_main @ 0e15b274)

2. tilelang_metal_fp8_vector (vector cast lanes 2/3/4) split:
   - 0001-tilelang-metal-fp8-vector-cast.patch — supermodule half (148 lines)
   - 0002-tvm-metal-fp8-vector-cast.patch — TVM-mirror half (151 lines)
   - PR tile-ai/tilelang#2145 (supermodule, depends on #2144)
   - PR tile-ai/tvm#39 (TVM mirror, depends on #38)

3. PR #2143 TVM-mirror companion:
   - PR tile-ai/tvm#37 — already filed, README updated to link both halves

Total filed today: 11 PRs across 3 repos
- 1 ml-explore/mlx (#3476)
- 1 apache/tvm (#19504)
- 6 tile-ai/tilelang (#2139, #2140, #2141, #2142, #2143 super, #2144 super, #2145 super)
- 3 tile-ai/tvm (#37, #38, #39 — TVM-mirror companions)

PR #2142 (T.fp8_scaled_matmul) has no TVM-mirror companion needed —
verified the patch only touches supermodule files.

All splits round-trip clean (apply forward + reverse) on their respective
bases. README files in each docs/upstream/<dir>/ updated with PR URLs and
dependency-chain diagrams.

Note: TileLang/tvm redirects to tile-ai/tvm server-side (canonical org
slug). All TVM-mirror PRs land at tile-ai/tvm/pull/N URLs.
Contributor

coderabbitai Bot commented May 4, 2026

📝 Walkthrough

Walkthrough

Introduces comprehensive Metal (MPS) backend support for TileLang including native code generation, simdgroup-scoped GEMM and tensor operations, register-tile abstractions with multiply-accumulate macros, quantization and attention-pattern helpers, and extensive hardware/codegen testing. Updates build system, dependency constraints, device selection logic, and compiler phase ordering.

Changes

Metal Backend Implementation

Layer / File(s) Summary
Data Shape & Type Extensions
src/op/copy.h, src/op/gemm.h, tilelang/tileop/gemm/inst.py, src/op/utils.h, tilelang/utils/language.py
New enum members CopyInst::kMetalSIMDGroup and GemmInst::METAL_SIMDGROUP represent Metal-specific copy/GEMM operations. New buffer scope predicates IsSIMDGroupBuffer and is_metal_simdgroup identify "metal.simdgroup" buffers. IsRegisterBuffer expanded to include simdgroup buffers.
Metal Code Generation
src/target/codegen_metal.cc, src/target/codegen_metal.h
New CodeGenTileLangMetal class derives from CodeGenC and emits Metal Shading Language. Includes FP8 storage-only conversion helpers for scalar and vectorized casts, metal-specific type/scope/thread-binding printing, overrides for allocate/buffer-load/store to handle metal.simdgroup and local.var buffers, and Finish() conditional injection of MSL prelude code.
Core GEMM & Copy Operations
src/op/gemm.cc, src/op/gemm.h, src/op/copy.cc, src/op/copy.h, src/op/fill.cc, src/op/parallel.cc
Metal-specific GEMM inst selection and warp-partition computation with kMPerWarp=8 for Metal. New CopyNode::CheckSIMDGroupCopy validation and LowerSIMDGroupCopy lowering using builtin::simdgroup_store. New FillNode lowering path for simdgroup-scoped buffers with 8×8 matrix boundary alignment checks. ParallelOpNode::InferLayout optional-Fragment safety guard.
Metal Simdgroup Register Tiles
tilelang/tileop/metal_simdgroup.py, tilelang/intrinsics/metal_macro_generator.py
Register-tile abstractions (RegisterTile, MMATile, RowVector) with layout metadata and bounds-checked fragment indexing. Macros for allocate/fill/load/store/mma operations on 8×8 fragments. MPSIntrinEmitter for warp-partitioned ldmatrix/mma/store generation with optional transposition. Scalar-tile reduction and row-wise operations (row_max, row_sum, mul_row, div_row).
Metal Quantization & Attention Patterns
tilelang/tileop/metal_quant.py, tilelang/tileop/metal_gdn.py
QuantSimdgroupTile sizing and tile-selection helpers for mixed-precision kernels. FP8 e4m3fn, FP4 e2m1fn, and e8m0 decoding functions for packed-format conversion. GDN/KKT tile building blocks (kkt_score_tile, apply_kkt_gate_triangular_tile) and W/U projection accumulators (wu_score_tiles, wu_linear_element).
Transform & Compilation Pipeline
tilelang/transform/metal_fragment_to_simdgroup.py, tilelang/transform/decouple_type_cast.py, tilelang/transform/layout_inference.cc, tilelang/transform/lower_device_storage_access_info.cc, tilelang/engine/phase.py, tilelang/engine/lower.py
New MetalFragmentToSimdgroup pass rewrites local.fragment accumulators to metal.simdgroup scope for Metal GEMM. is_local_buffer predicate extended to recognize simdgroup buffers. Layout-inference Metal target check skips fragment-completeness requirement. Storage-access lowering excludes ".fragment" buffers. Phase ordering inserts MetalFragmentToSimdgroup after pipeline injection and before layout inference. Engine switches Metal device codegen to target.build.tilelang_metal.
Metal GEMM Implementation & Wiring
tilelang/tileop/gemm/gemm_metal.py, tilelang/tileop/gemm/__init__.py
GemmMetal lowering validates M/N multiples of 8, partitions into simdgroup-resident accumulator or shared-intermediate paths, uses MPSIntrinEmitter for K-loop ldmatrix/mma/store sequences, and enforces block_K divisibility constraints. Instruction selection and implementation class mapping wired in __init__.py.
CMake & Build Configuration
src/backend/metal/CMakeLists.txt
Metal codegen source (src/target/codegen_metal.cc) always appended to TILE_LANG_SRCS for cross-platform codegen availability. Non-Apple path replaced set(USE_METAL OFF) with early return() for codegen-only mode.
Device Runtime & Adapter
tilelang/jit/adapter/base.py, tilelang/jit/adapter/torch/metal.py
BaseKernelAdapter.get_current_device_functor() adds explicit MPS fallback when CUDA unavailable. MetalKernelAdapter.get_kernel_source() method returns stored kernel MSL source.
Dependencies
pyproject.toml, requirements.txt, requirements-dev.txt
Added platform-specific upper bound apache-tvm-ffi<0.1.8; platform_system == 'Darwin' alongside existing apache-tvm-ffi~=0.1.0,>=0.1.2 base requirement.
Testing: Hardware & Codegen
testing/python/metal/test_metal_gemm_v2.py, testing/python/metal/test_metal_gemm_v2_linux.py, testing/python/metal/test_metal_simdgroup_store.py, testing/python/metal/test_metal_local_var.py
GEMM v2 correctness on MPS and cross-platform codegen verification (simdgroup operations present). Simdgroup store direct-accumulator-to-memory path testing. Local scalar variable code generation and runtime validation. Coverage documentation for internal runtime probes.
Testing: Internal Scaffolding & Benchmarks
testing/python/metal/test_metal_internal_scaffolding.py, testing/python/jit/test_tilelang_jit_adapter_mps.py, benchmark/matmul_metal/benchmark_matmul_metal.py
Internal probe kernels for DeepSeek packed fp8/fp4 decode/matmul and FlashQla GDN KKT/WU/component patterns with MPS runtime correctness vs CPU reference. Fail-closed subprocess tests for native fp8/fp4 storage rejection. MPS device-selection fallback tests. Standalone MPS matmul benchmarking script with configurable block-size sweep.

Sequence Diagram(s)

sequenceDiagram
    participant Host as TVM Host (Python)
    participant LowerPipeline as Lower Pipeline
    participant MetalPass as MetalFragmentToSimdgroup Pass
    participant MetalGEMM as GemmMetal Lowering
    participant MetalCodegen as Metal Code Generator
    participant MPS as MPS Device

    Host->>LowerPipeline: Lower TileLang kernel
    LowerPipeline->>MetalPass: Apply MetalFragmentToSimdgroup
    MetalPass->>MetalPass: Rewrite local.fragment→metal.simdgroup
    MetalPass-->>LowerPipeline: Updated PrimFunc
    LowerPipeline->>MetalGEMM: Lower GEMM ops
    MetalGEMM->>MetalGEMM: Compute warp partition (8 M-width)
    MetalGEMM->>MetalGEMM: Generate K-loop with ldmatrix/mma
    MetalGEMM-->>LowerPipeline: Kernel PrimFunc
    LowerPipeline->>MetalCodegen: target.build.tilelang_metal
    MetalCodegen->>MetalCodegen: Print Metal types/scopes
    MetalCodegen->>MetalCodegen: Emit FP8 prelude (if needed)
    MetalCodegen->>MetalCodegen: Emit kernel signature & body
    MetalCodegen-->>Host: Metal Shading Language source
    Host->>MPS: Compile & launch kernel
    MPS-->>Host: Results

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Reasoning: This PR introduces a substantial new backend (Metal/MPS) spanning multiple interconnected subsystems: a full code generator with FP8 MSL prelude logic, new GEMM and copy lowering paths with complex constraint checking, register-tile abstractions with ~500 lines of macro definitions, quantization/attention-pattern helpers, a compiler pass for scope rewriting, build system changes, and 300+ lines of diverse testing. The changes are heterogeneous in nature (codegen, GEMM scheduling, IR transforms, testing infrastructure), require understanding of TVM/Metal/MPS semantics and TileLang's tiling abstractions, and involve intricate constraint logic (e.g., simdgroup store warp tiling search, FP8 vector lane handling). While there are repetitive patterns in test modules and helper generation, the density of novel logic and cross-cutting dependencies between layers demand careful scrutiny across multiple specialized domains.

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰 A simdgroup emerges from the Metal ether,
Eight-by-eight fragments gather in simdgroup tether,
FP8 whispers encoded in storage so tight,
GEMM warp partition dances through compile-time's night,
From Darwin to Linux, codegen takes flight! 🍎✨


Contributor

Copilot AI left a comment


Pull request overview

Extends TileLang’s Metal backend to better support low-precision and simdgroup-based lowering, including storage-only FP8 emulation with vectorized (lanes 2/3/4) cast helpers, and broader Metal GEMM/simdgroup infrastructure for codegen, lowering, and tests.

Changes:

  • Add Metal FP8 storage-only emulation support for vector casts (lanes 2/3/4) via emitted inline MSL helpers.
  • Introduce Metal simdgroup GEMM plumbing (IR transforms, intrinsics/macro emitter, copy/fill support, and new test coverage).
  • Adjust JIT adapter device selection to prefer MPS when CUDA is unavailable / initialization fails, and switch Metal codegen entrypoint to target.build.tilelang_metal.

Reviewed changes

Copilot reviewed 36 out of 37 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
tilelang/utils/language.py Add scope predicate for metal.simdgroup.
tilelang/transform/metal_fragment_to_simdgroup.py New Metal-only PrimFunc pass to rewrite GEMM accumulators from local.fragment to metal.simdgroup.
tilelang/transform/decouple_type_cast.py Treat metal.simdgroup buffers as “local/register-level” for cast decoupling.
tilelang/tileop/metal_simdgroup.py Add internal simdgroup tile helpers/macros (RegisterTile/RowVector, load/store/mma helpers).
tilelang/tileop/metal_quant.py Add packed-uint8 FP8/FP4/e8m0 decode helpers for Metal probes.
tilelang/tileop/metal_gdn.py Add internal GDN/attention-style tile macros built on simdgroup helpers.
tilelang/tileop/gemm/inst.py Add METAL_SIMDGROUP GemmInst selector.
tilelang/tileop/gemm/gemm_metal.py New Metal GEMM lowering using simdgroup_matrix intrinsics.
tilelang/tileop/gemm/__init__.py Wire Metal instruction selection and implementation class mapping.
tilelang/jit/adapter/torch/metal.py Expose kernel source getter for Metal torch adapter.
tilelang/jit/adapter/base.py Prefer MPS device when CUDA is unavailable/failed init.
tilelang/intrinsics/metal_macro_generator.py New MPS/simdgroup intrinsic emitter used by Metal GEMM lowering.
tilelang/engine/phase.py Insert Metal fragment→simdgroup rewrite before layout inference.
tilelang/engine/lower.py Switch Metal build entrypoint to target.build.tilelang_metal.
testing/python/metal/test_metal_simdgroup_store.py New tests for simdgroup-register accumulation and direct simdgroup_store to device memory.
testing/python/metal/test_metal_local_var.py New focused tests for local.var scalar lowering on Metal.
testing/python/metal/test_metal_internal_scaffolding.py Large internal scaffolding tests for simdgroup helpers + packed quant + GDN-style probes.
testing/python/metal/test_metal_gemm_v2.py Runtime GEMM correctness tests on Metal hardware.
testing/python/metal/test_metal_gemm_v2_linux.py Cross-platform (codegen-only) Metal GEMM source tests.
testing/python/metal/metal_internal_runtime_coverage.md Document internal Metal runtime/source-boundary coverage and opt-in benchmarks.
testing/python/jit/test_tilelang_jit_adapter_mps.py New tests validating device selection prefers MPS when CUDA is unavailable/fails.
src/transform/lower_device_storage_access_info.cc Treat fragment scope as special-case for memory info lowering.
src/transform/layout_inference.cc Skip fragment-layout completeness check on Metal targets.
src/target/codegen_metal.h Add TileLang Metal codegen class API, including FP8 prelude hooks.
src/target/codegen_metal.cc Implement TileLang Metal codegen, including FP8 scalar+vector preludes and simdgroup intrinsics emission.
src/op/utils.h Add metal.simdgroup buffer scope helpers.
src/op/parallel.cc Make fragment-layout use optional to avoid hard .value() assumptions.
src/op/gemm.h Add Metal simdgroup GEMM instruction enum value.
src/op/gemm.cc Select Metal GEMM inst on Metal targets; adjust warp partition heuristics.
src/op/fill.cc Add simdgroup-matrix-aware fill lowering via make_filled_simdgroup_matrix.
src/op/copy.h Add Metal simdgroup copy instruction and lowering hook.
src/op/copy.cc Implement simdgroup store lowering for Metal and bypass layout inference for that path.
src/backend/metal/CMakeLists.txt Always build Metal codegen source for cross-platform codegen-only mode.
requirements.txt Add Darwin-only cap for apache-tvm-ffi.
requirements-dev.txt Add Darwin-only cap for apache-tvm-ffi in dev requirements.
pyproject.toml Add Darwin-only cap for apache-tvm-ffi in project dependencies.
benchmark/matmul_metal/benchmark_matmul_metal.py Add Metal GEMM benchmark script using simdgroup GEMM lowering.


Comment on lines +68 to +103
def _rewrite_scope(body, var_map):
    buf_map = {}

    def _pre_order(stmt):
        if isinstance(stmt, tir.Block):
            new_alloc_bufs = []
            changed = False
            for buf in stmt.alloc_buffers:
                new_buf = _remap_buffer(buf, var_map)
                new_alloc_bufs.append(new_buf)
                if not new_buf.same_as(buf):
                    buf_map[buf] = new_buf
                    changed = True
            if changed:
                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
                new_block = tir.Block(
                    stmt.iter_vars,
                    stmt.reads,
                    stmt.writes,
                    stmt.name_hint,
                    new_body,
                    stmt.init,
                    new_alloc_bufs,
                    stmt.match_buffers,
                    stmt.annotations,
                )
                return (
                    tir.BlockRealize(
                        stmt.iter_vars,
                        tir.const(True, "bool"),
                        new_block,
                    )
                    if False
                    else new_block
                )
        elif isinstance(stmt, tir.Allocate):
Comment thread src/op/copy.cc
Comment on lines +1097 to +1100
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
if (num_warps % m != 0)
Comment on lines +25 to +33
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>

Comment on lines 76 to 86
 if torch.cuda.is_available():
     try:
         torch.cuda._lazy_init()
         current_device = torch._C._cuda_getDevice
         return lambda: torch.device("cuda", current_device())
     except Exception:
-        return lambda: torch.device("cuda", torch.cuda.current_device())
+        pass
 if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
     return lambda: torch.device("mps")
 # CPU fallback
 return lambda: torch.device("cpu")
Comment on lines +436 to +445
// Check that all local.fragment buffers have inferred layouts.
// On Metal targets, fragment buffers used as GEMM accumulators are
// lowered to opaque simdgroup matrices, so they have no explicit
// thread-level layout and can be safely skipped.
 for (const auto &[buffer, _] : use_list_) {
   if (IsFragmentBuffer(buffer)) {
-    ICHECK_NE(layout_map.count(buffer), 0)
-        << "The layout for fragment " << buffer
-        << " can not be inferred correctly.";
+    if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
+      ICHECK(false) << "The layout for fragment " << buffer
+                    << " can not be inferred correctly.";
+    }
Comment thread pyproject.toml
Comment on lines 33 to +34
"apache-tvm-ffi~=0.1.0,>=0.1.2",
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
Comment on lines +42 to +58
void CodeGenTileLangMetal::InitFuncState(const PrimFunc &f) {
  CodeGenC::InitFuncState(f);
  // analyze the data;
  for (Var arg : f->params) {
    if (arg.dtype().is_handle()) {
      alloc_storage_scope_[arg.get()] = "global";
    }
  }
}

CodeGenTileLangMetal::CodeGenTileLangMetal(Target target) : target_(target) {
  decl_stream << "#include <metal_stdlib>\n";
  decl_stream << "using namespace metal;\n\n";
  decl_stream << "union __TVMArgUnion {\n"
              << " int v_int[2];\n"
              << "};\n\n";
}
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/jit/adapter/base.py (1)

69-86: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Two issues with the CUDA exception path and stale docstring.

1. Silent CUDA→CPU fallback on exception (line 82):
When torch.cuda.is_available() is True but _lazy_init() raises (broken CUDA driver, etc.), the pass causes the adapter to return torch.device("cpu") on non-Mac hosts. The prior fallback lambda torch.cuda.current_device() would have raised an explicit CUDA error at call time; the new path silently dispatches kernels to CPU, which can produce wrong results without a clear diagnostic.

Consider restoring a CUDA-specific fallback in the except block so CUDA systems stay on CUDA even if the fast internal path is unavailable:

🛠️ Proposed fix
         except Exception:
-            pass
+            return lambda: torch.device("cuda", torch.cuda.current_device())
     if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
         return lambda: torch.device("mps")

2. Stale docstring (line 71–74):
The docstring says "On CPU or when CUDA is unavailable, returns torch.device('cpu')" — this is now incorrect; MPS is returned when available.

📝 Proposed docstring update
-        Similar to the stream functor, we capture a callable that, when called,
-        fetches the current device according to PyTorch. On CPU or when CUDA is
-        unavailable, returns ``torch.device('cpu')``.
+        Similar to the stream functor, we capture a callable that, when called,
+        fetches the current device according to PyTorch. Falls back to MPS when
+        CUDA is unavailable and MPS is available, otherwise returns
+        ``torch.device('cpu')``.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/jit/adapter/base.py` around lines 69 - 86, The
get_current_device_functor function silently falls back to CPU if
torch.cuda._lazy_init() or the fast cuda path fails; change the except block to
return a CUDA-producing callable that calls torch.cuda.current_device() (or
otherwise raises the proper CUDA error at call time) so CUDA hosts do not
silently dispatch to CPU (refer to get_current_device_functor,
torch.cuda._lazy_init, torch._C._cuda_getDevice and torch.cuda.current_device);
also update the function docstring to reflect that MPS may be returned when
available (mention torch.backends.mps) instead of claiming CUDA/unavailable CPU
only.
🧹 Nitpick comments (2)
src/op/copy.cc (1)

1089-1126: 💤 Low value

Warp partitioning score computation uses different ratio than GEMM.

The ideal ratio calculation ideal = N > 0 ? M / N : 1.f differs from GemmWarpPolicyNode::computeWarpPartition which computes score as abs(m_per_warp / n_per_warp - ideal). Here the score is abs(m_per / n_per - ideal) where m_per = M / (m * kMPerWarp) and n_per = N / (n * kNPerWarp).

Both approaches aim for balanced workloads but use slightly different metrics. This is acceptable since the copy operation may have different optimal tiling than GEMM, but consider documenting this difference or unifying the approach if consistency is desired.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 1089 - 1126, The warp-partition scoring in this
copy lowering loop uses ideal = N>0 ? static_cast<float>(M)/N : 1.f and computes
score from m_per = M/(m*kMPerWarp) and n_per = N/(n*kNPerWarp), which differs
from GemmWarpPolicyNode::computeWarpPartition's m_per_warp/n_per_warp ratio;
either make the metric consistent by changing the score computation to mirror
GemmWarpPolicyNode::computeWarpPartition (use the same per-warp definitions and
ideal) or add a concise comment above this block (referencing ideal, m_per,
n_per and GemmWarpPolicyNode::computeWarpPartition) explaining why the copy op
uses a different ratio so future maintainers understand the intentional
divergence.
testing/python/metal/test_metal_internal_scaffolding.py (1)

425-453: ⚡ Quick win

Relax the exact float-literal source assertions.

These checks are currently tied to the printer's exact spelling of zero/one literals, so a harmless formatting change will break the tests without changing the generated behavior. Assert on the declaration/assignment pattern instead of the full literal text.

Suggested refactor
-    assert "float kkt_bias = 0.000000e+00f;" in src
+    assert "float kkt_bias" in src
+    assert "kkt_bias =" in src
...
-    assert "gate_state = 1.000000e+00f;" in gdn_src
+    assert "gate_state =" in gdn_src
...
-    assert "gate_state = 1.000000e+00f;" in gdn_src
+    assert "gate_state =" in gdn_src

Also applies to: 465-473

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_internal_scaffolding.py` around lines 425 -
453, The tests
test_flashqla_gdn_kkt_probe_combines_local_var_state_and_simdgroup_boundary and
test_scaled_packed_quant_and_gdn_probes_source_boundary_tokens assert exact
float literal text (e.g. "float kkt_bias = 0.000000e+00f;" and "gate_state =
1.000000e+00f;") which is brittle; change those assertions to check the
declaration/assignment pattern instead (e.g. assert "float kkt_bias" in src and
assert re.search(r"\bkkt_bias\s*=\s*[-+]?\d*\.?\d+(e[-+]?\d+)?f?\b", src) or
similarly for "gate_state" so the test verifies presence of the variable and an
assignment to a numeric literal rather than an exact formatted literal; update
the other occurrence around lines 465-473 the same way.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmark/matmul_metal/benchmark_matmul_metal.py`:
- Around line 108-125: The loop currently leaves best_config as configs[0] even
if every bench_tilelang call fails; change best_config to None (or similar
sentinel) at initialization and only assign it inside the try block when a run
succeeds (e.g., when tl > best_tflops or when first success), and when
args.sweep is true print the summary only if best_config is not None (otherwise
skip or print "no successful configs"); update references to best_tflops,
best_config, configs and bench_tilelang accordingly.

In `@src/target/codegen_metal.cc`:
- Around line 501-515: The ICHECK message for the constant_size validation in
the simdgroup scope is misleading: constant_size is an element count, not bytes.
Update the error text used with the check in the block that references
constant_size, op->dtype, simdgroup_dtype_, PrintType and vid to say "elements"
(e.g., "Only 8x8 matrix is supported, but got <n> elements") so the message
accurately reflects the validated quantity.

In `@testing/python/metal/test_metal_gemm_v2.py`:
- Around line 59-62: The TileLang kernel launch via jit_kernel(a, b, c) is
asynchronous on MPS; insert a call to torch.mps.synchronize() immediately after
jit_kernel(a, b, c) and before computing ref and asserting, so the Metal kernel
finishes before reading c; update the test in test_metal_gemm_v2.py to call
torch.mps.synchronize() right after the jit_kernel(...) invocation.

In `@testing/python/metal/test_metal_simdgroup_store.py`:
- Around line 49-52: The test launches the Metal kernel via kernel(a, b, c) then
immediately reads tensor c; insert an explicit device synchronization call
(torch.mps.synchronize()) after kernel(a, b, c) and before computing ref =
a.to(torch_accum_dtype) @ b.to(torch_accum_dtype) and the assert so the Metal
command queue finishes and the comparison against c is deterministic; update the
test near kernel(a, b, c) to call torch.mps.synchronize() before using c in the
reference computation and allclose check.

In `@tilelang/tileop/gemm/__init__.py`:
- Around line 173-174: The unconditional selection of GemmInst.METAL_SIMDGROUP
in the target_is_metal branch forces simdgroup lowering for all Metal GEMMs and
bypasses GemmMetal.lower()'s alignment and partition predicates; change the
logic in the target_is_metal branch to perform the same checks used by
GemmMetal.lower() (e.g., 8-alignment of shapes and valid per-warp partition
predicates implemented in tilelang/tileop/gemm/gemm_metal.py) and only return
GemmInst.METAL_SIMDGROUP when those predicates pass, otherwise fall back to the
scalar-safe path so invalid cases are not forced into METAL_SIMDGROUP.

In `@tilelang/tileop/metal_simdgroup.py`:
- Around line 385-396: The mma_tile macro currently ignores K fragments beyond 0
which drops contributions; update mma_tile in tilelang/tileop/metal_simdgroup.py
to either (A) reduce across K fragments by adding an inner loop over the
K-fragment dimension and calling mma for each k-fragment using a.index(tile_m,
k) and b.index(k, tile_n) (accumulating into acc.fragment at acc.index(tile_m,
tile_n)), or (B) fail closed by asserting the input MMATile(s) have fragments_k
== 1 (raise/assert when a.fragments_k or b.fragments_k > 1) so incorrect cases
are rejected. Use the identifiers mma_tile, MMATile, mma, acc.index, a.index and
b.index to locate and implement the fix.

In `@tilelang/transform/metal_fragment_to_simdgroup.py`:
- Around line 94-102: The code contains an unreachable conditional that always
returns new_block because of "if False", making the tir.BlockRealize
construction dead; remove the conditional and return new_block directly (or, if
BlockRealize is desired, replace the conditional with the proper condition),
updating the return in the function that builds the block (reference
tir.BlockRealize, new_block, and stmt.iter_vars) so only the intended branch is
returned and no unreachable code remains.

---

Outside diff comments:
In `@tilelang/jit/adapter/base.py`:
- Around line 69-86: The get_current_device_functor function silently falls back
to CPU if torch.cuda._lazy_init() or the fast cuda path fails; change the except
block to return a CUDA-producing callable that calls torch.cuda.current_device()
(or otherwise raises the proper CUDA error at call time) so CUDA hosts do not
silently dispatch to CPU (refer to get_current_device_functor,
torch.cuda._lazy_init, torch._C._cuda_getDevice and torch.cuda.current_device);
also update the function docstring to reflect that MPS may be returned when
available (mention torch.backends.mps) instead of claiming CUDA/unavailable CPU
only.

---

Nitpick comments:
In `@src/op/copy.cc`:
- Around line 1089-1126: The warp-partition scoring in this copy lowering loop
uses ideal = N>0 ? static_cast<float>(M)/N : 1.f and computes score from m_per =
M/(m*kMPerWarp) and n_per = N/(n*kNPerWarp), which differs from
GemmWarpPolicyNode::computeWarpPartition's m_per_warp/n_per_warp ratio; either
make the metric consistent by changing the score computation to mirror
GemmWarpPolicyNode::computeWarpPartition (use the same per-warp definitions and
ideal) or add a concise comment above this block (referencing ideal, m_per,
n_per and GemmWarpPolicyNode::computeWarpPartition) explaining why the copy op
uses a different ratio so future maintainers understand the intentional
divergence.

In `@testing/python/metal/test_metal_internal_scaffolding.py`:
- Around line 425-453: The tests
test_flashqla_gdn_kkt_probe_combines_local_var_state_and_simdgroup_boundary and
test_scaled_packed_quant_and_gdn_probes_source_boundary_tokens assert exact
float literal text (e.g. "float kkt_bias = 0.000000e+00f;" and "gate_state =
1.000000e+00f;") which is brittle; change those assertions to check the
declaration/assignment pattern instead (e.g. assert "float kkt_bias" in src and
assert re.search(r"\bkkt_bias\s*=\s*[-+]?\d*\.?\d+(e[-+]?\d+)?f?\b", src) or
similarly for "gate_state" so the test verifies presence of the variable and an
assignment to a numeric literal rather than an exact formatted literal; update
the other occurrence around lines 465-473 the same way.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 01662649-98c3-4e80-8215-d6ab6b031580

📥 Commits

Reviewing files that changed from the base of the PR and between d135bd1 and 81d8d96.

📒 Files selected for processing (37)
  • benchmark/matmul_metal/benchmark_matmul_metal.py
  • pyproject.toml
  • requirements-dev.txt
  • requirements.txt
  • src/backend/metal/CMakeLists.txt
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/fill.cc
  • src/op/gemm.cc
  • src/op/gemm.h
  • src/op/parallel.cc
  • src/op/utils.h
  • src/target/codegen_metal.cc
  • src/target/codegen_metal.h
  • src/transform/layout_inference.cc
  • src/transform/lower_device_storage_access_info.cc
  • testing/python/jit/test_tilelang_jit_adapter_mps.py
  • testing/python/metal/metal_internal_runtime_coverage.md
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • testing/python/metal/test_metal_internal_scaffolding.py
  • testing/python/metal/test_metal_local_var.py
  • testing/python/metal/test_metal_simdgroup_store.py
  • tilelang/engine/lower.py
  • tilelang/engine/phase.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/base.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/tileop/metal_gdn.py
  • tilelang/tileop/metal_quant.py
  • tilelang/tileop/metal_simdgroup.py
  • tilelang/transform/decouple_type_cast.py
  • tilelang/transform/metal_fragment_to_simdgroup.py
  • tilelang/utils/language.py

Comment on lines +108 to +125
    best_tflops = 0.0
    best_config = configs[0]
    for bM, bN, bK in configs:
        try:
            tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
            ratio = tl / ref_tflops * 100
            tag = ""
            if tl > best_tflops:
                best_tflops = tl
                best_config = (bM, bN, bK)
            print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
        except Exception as e:
            print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")

    if args.sweep:
        print()
        print(f"Best config: {best_config}")
        print(f"Best TFlops: {best_tflops:.1f}")
⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Don't print a winner when every sweep config failed.

Lines 108-125 still report configs[0] as the best config if every benchmark attempt throws. That makes the summary misleading in exactly the case where the per-config error handling is supposed to help.

Suggested fix
-    best_tflops = 0.0
-    best_config = configs[0]
+    best_tflops = 0.0
+    best_config = None
     for bM, bN, bK in configs:
         try:
             tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
             ratio = tl / ref_tflops * 100
-            tag = ""
             if tl > best_tflops:
                 best_tflops = tl
                 best_config = (bM, bN, bK)
             print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
         except Exception as e:
             print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")

-    if args.sweep:
+    if args.sweep and best_config is not None:
         print()
         print(f"Best config: {best_config}")
         print(f"Best TFlops: {best_tflops:.1f}")
         print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")
+    elif args.sweep:
+        print()
+        print("No TileLang configuration completed successfully.")
🧰 Tools
🪛 Ruff (0.15.12)

[warning] 119-119: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmark/matmul_metal/benchmark_matmul_metal.py` around lines 108 - 125, The
loop currently leaves best_config as configs[0] even if every bench_tilelang
call fails; change best_config to None (or similar sentinel) at initialization
and only assign it inside the try block when a run succeeds (e.g., when tl >
best_tflops or when first success), and when args.sweep is true print the
summary only if best_config is not None (otherwise skip or print "no successful
configs"); update references to best_tflops, best_config, configs and
bench_tilelang accordingly.
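The fail-closed sentinel pattern described above, reduced to a minimal runnable sketch (the `bench` callable and config tuples stand in for `bench_tilelang` and the real sweep configs):

```python
def sweep(configs, bench):
    """Return (best_config, best_tflops); best_config stays None if every run fails."""
    best_tflops = 0.0
    best_config = None  # sentinel: no successful run yet
    for cfg in configs:
        try:
            tflops = bench(cfg)
        except Exception:
            continue  # a failed config must never become the "winner"
        if tflops > best_tflops:
            best_tflops = tflops
            best_config = cfg
    return best_config, best_tflops

# Every config fails: no winner is reported.
assert sweep([(64, 64, 32)], lambda cfg: 1 / 0) == (None, 0.0)
# One config succeeds: it is selected.
assert sweep([(64, 64, 32)], lambda cfg: 5.0) == ((64, 64, 32), 5.0)
```

With `best_config = configs[0]` as the initializer, the first assertion would instead report a config that never ran, which is the misleading summary the review flags.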

Comment on lines +501 to +515
  if (scope == "metal.simdgroup") {
    ICHECK(op->dtype == DataType::Float(16) ||
           op->dtype == DataType::Float(32) ||
           op->dtype == DataType::BFloat(16))
        << "Only float16, float32, and bfloat16 are supported, but got "
        << op->dtype;
    ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
                                    << constant_size << " bytes\n";

    std::ostringstream dtype_os;
    PrintType(op->dtype, dtype_os);
    std::string dtype_str = dtype_os.str();
    simdgroup_dtype_[op->buffer_var.get()] = dtype_str;
    stream << "simdgroup_" << dtype_str << "8x8 " << vid << '['
           << constant_size / 64 << "];\n";
⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Error message says "bytes" but validation checks element count.

The ICHECK message at line 508 mentions "bytes" but constant_size represents element count, not bytes. For an 8x8 matrix with 64 elements, this check is correct but the message is misleading.

📝 Proposed fix for error message
-    ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
-                                    << constant_size << " bytes\n";
+    ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
+                                    << constant_size << " elements\n";
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_metal.cc` around lines 501 - 515, The ICHECK message for
the constant_size validation in the simdgroup scope is misleading: constant_size
is an element count, not bytes. Update the error text used with the check in the
block that references constant_size, op->dtype, simdgroup_dtype_, PrintType and
vid to say "elements" (e.g., "Only 8x8 matrix is supported, but got <n>
elements") so the message accurately reflects the validated quantity.

Comment on lines +59 to +62
    jit_kernel(a, b, c)

    ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype)
    assert torch.allclose(ref, c, atol=atol), (
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find testing -name "test_metal_gemm_v2.py" -type f

Repository: tile-ai/tilelang

Length of output: 103


🏁 Script executed:

cat testing/python/metal/test_metal_gemm_v2.py

Repository: tile-ai/tilelang

Length of output: 2722


🏁 Script executed:

grep -r "torch.mps.synchronize" testing/ --include="*.py"

Repository: tile-ai/tilelang

Length of output: 1525


🏁 Script executed:

grep -A 3 "torch.mps.synchronize()" testing/python/metal/test_metal_internal_scaffolding.py | head -20

Repository: tile-ai/tilelang

Length of output: 684


Add torch.mps.synchronize() after the TileLang kernel launch.

Line 59 enqueues the Metal kernel asynchronously, and lines 61-62 immediately read and compare the output c. Metal operations on PyTorch MPS are asynchronous and require explicit synchronization before consuming results. All other Metal tests in this PR follow this pattern—synchronize after kernel launch, then compare. Without it, the assertion can race the async kernel and fail intermittently.

Suggested fix
     jit_kernel(a, b, c)
+    torch.mps.synchronize()
 
     ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype)
     assert torch.allclose(ref, c, atol=atol), (
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2.py` around lines 59 - 62, The
TileLang kernel launch via jit_kernel(a, b, c) is asynchronous on MPS; insert a
call to torch.mps.synchronize() immediately after jit_kernel(a, b, c) and before
computing ref and asserting, so the Metal kernel finishes before reading c;
update the test in test_metal_gemm_v2.py to call torch.mps.synchronize() right
after the jit_kernel(...) invocation.

Comment on lines +49 to +52
    kernel(a, b, c)

    ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype)
    assert torch.allclose(ref, c, atol=atol), (
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n testing/python/metal/test_metal_simdgroup_store.py | head -70

Repository: tile-ai/tilelang

Length of output: 3572


🏁 Script executed:

find . -name "test_metal_gemm_v2.py" -type f | head -5

Repository: tile-ai/tilelang

Length of output: 105


🏁 Script executed:

cat -n ./testing/python/metal/test_metal_gemm_v2.py | head -100

Repository: tile-ai/tilelang

Length of output: 3359


🏁 Script executed:

grep -n "torch.mps.synchronize" ./testing/python/metal/test_metal_gemm_v2.py ./testing/python/metal/test_metal_simdgroup_store.py

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

grep -r "torch.mps.synchronize" ./testing/python/metal/ --include="*.py"

Repository: tile-ai/tilelang

Length of output: 1559


🏁 Script executed:

cat -n ./testing/python/metal/test_metal_local_var.py | grep -A5 -B5 "torch.mps.synchronize"

Repository: tile-ai/tilelang

Length of output: 384


🏁 Script executed:

cat -n ./testing/python/metal/test_metal_internal_scaffolding.py | head -100

Repository: tile-ai/tilelang

Length of output: 4262


🏁 Script executed:

grep -B10 -A2 "torch.mps.synchronize()" ./testing/python/metal/test_metal_internal_scaffolding.py | head -50

Repository: tile-ai/tilelang

Length of output: 2290


🏁 Script executed:

cat -n ./testing/python/metal/test_metal_gemm_v2.py | sed -n '50,70p'

Repository: tile-ai/tilelang

Length of output: 957


Synchronize before comparing the MPS output tensor.

Line 49 launches the external kernel, and lines 51-52 immediately read c in PyTorch. Without an explicit torch.mps.synchronize(), this helper can race the Metal command queue and produce flaky correctness failures.

Suggested fix
    kernel(a, b, c)
+   torch.mps.synchronize()

    ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype)
    assert torch.allclose(ref, c, atol=atol), (
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_simdgroup_store.py` around lines 49 - 52, The
test launches the Metal kernel via kernel(a, b, c) then immediately reads tensor
c; insert an explicit device synchronization call (torch.mps.synchronize())
after kernel(a, b, c) and before computing ref = a.to(torch_accum_dtype) @
b.to(torch_accum_dtype) and the assert so the Metal command queue finishes and
the comparison against c is deterministic; update the test near kernel(a, b, c)
to call torch.mps.synchronize() before using c in the reference computation and
allclose check.

Comment on lines +173 to +174
        if target_is_metal(target):
            return GemmInst.METAL_SIMDGROUP
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don’t force METAL_SIMDGROUP for every Metal GEMM.

GemmMetal.lower() still rejects non-8-aligned shapes and invalid per-warp partitions (tilelang/tileop/gemm/gemm_metal.py:21-35). This unconditional return bypasses fallback selection and turns those cases into hard ValueErrors instead of using a scalar-safe path.

Suggested direction
-        if target_is_metal(target):
-            return GemmInst.METAL_SIMDGROUP
+        if target_is_metal(target) and GemmMetal(self).can_lower(target, thread_nums):
+            return GemmInst.METAL_SIMDGROUP
         return GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target))

Use the same predicates that GemmMetal.lower() enforces, and fall back when simdgroup lowering is not applicable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/gemm/__init__.py` around lines 173 - 174, The unconditional
selection of GemmInst.METAL_SIMDGROUP in the target_is_metal branch forces
simdgroup lowering for all Metal GEMMs and bypasses GemmMetal.lower()'s
alignment and partition predicates; change the logic in the target_is_metal
branch to perform the same checks used by GemmMetal.lower() (e.g., 8-alignment
of shapes and valid per-warp partition predicates implemented in
tilelang/tileop/gemm/gemm_metal.py) and only return GemmInst.METAL_SIMDGROUP
when those predicates pass, otherwise fall back to the scalar-safe path so
invalid cases are not forced into METAL_SIMDGROUP.

Comment on lines +385 to +396
@T.macro
def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
    for tile_m in T.unroll(acc.fragments_m, explicit=True):
        for tile_n in T.unroll(acc.fragments_n, explicit=True):
            mma(
                acc.fragment,
                a.fragment,
                b.fragment,
                acc.index(tile_m, tile_n),
                a.index(tile_m, 0),
                b.index(0, tile_n),
            )
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reduce across K fragments in mma_tile, or fail closed.

This only multiplies a.index(tile_m, 0) with b.index(0, tile_n). For multi-fragment K tiles, every slice after 0 is silently dropped, so the accumulator is wrong.

Suggested fix
 @T.macro
 def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
+    if a.fragments_n != b.fragments_m:
+        raise ValueError(
+            f"mma_tile requires matching K fragments, got {a.fragments_n} and {b.fragments_m}"
+        )
     for tile_m in T.unroll(acc.fragments_m, explicit=True):
         for tile_n in T.unroll(acc.fragments_n, explicit=True):
-            mma(
-                acc.fragment,
-                a.fragment,
-                b.fragment,
-                acc.index(tile_m, tile_n),
-                a.index(tile_m, 0),
-                b.index(0, tile_n),
-            )
+            for tile_k in T.unroll(a.fragments_n, explicit=True):
+                mma(
+                    acc.fragment,
+                    a.fragment,
+                    b.fragment,
+                    acc.index(tile_m, tile_n),
+                    a.index(tile_m, tile_k),
+                    b.index(tile_k, tile_n),
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/tileop/metal_simdgroup.py` around lines 385 - 396, The mma_tile
macro currently ignores K fragments beyond 0 which drops contributions; update
mma_tile in tilelang/tileop/metal_simdgroup.py to either (A) reduce across K
fragments by adding an inner loop over the K-fragment dimension and calling mma
for each k-fragment using a.index(tile_m, k) and b.index(k, tile_n)
(accumulating into acc.fragment at acc.index(tile_m, tile_n)), or (B) fail
closed by asserting the input MMATile(s) have fragments_k == 1 (raise/assert
when a.fragments_k or b.fragments_k > 1) so incorrect cases are rejected. Use
the identifiers mma_tile, MMATile, mma, acc.index, a.index and b.index to locate
and implement the fix.
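Why dropping the K fragments corrupts the accumulator can be seen with a plain-Python fragment matmul. Each "fragment" below is a scalar so the index bookkeeping stays visible; real simdgroup fragments are 8x8 matrices, but the loop structure is the same:

```python
def fragment_mma_tile(acc, a, b, frags_m, frags_n, frags_k, reduce_k=True):
    """Accumulate acc[m][n] += sum_k a[m][k] * b[k][n] over fragments.

    reduce_k=False models the buggy version that only reads K fragment 0.
    """
    for m in range(frags_m):
        for n in range(frags_n):
            k_range = range(frags_k) if reduce_k else range(1)
            for k in k_range:
                acc[m][n] += a[m][k] * b[k][n]
    return acc

a = [[1, 2]]      # 1 x 2 fragments along K
b = [[3], [4]]    # 2 x 1 fragments along K
full = fragment_mma_tile([[0]], a, b, 1, 1, 2)            # 1*3 + 2*4 = 11
only_k0 = fragment_mma_tile([[0]], a, b, 1, 1, 2, False)  # 1*3 = 3, K>0 dropped
assert full == [[11]] and only_k0 == [[3]]
```

The inner `k` loop is exactly the reduction the suggested fix adds via `T.unroll` over the K-fragment dimension; without it, every K fragment past 0 silently vanishes from the result.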

Comment on lines +94 to +102
                return (
                    tir.BlockRealize(
                        stmt.iter_vars,
                        tir.const(True, "bool"),
                        new_block,
                    )
                    if False
                    else new_block
                )
⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Dead code: if False branch is unreachable.

The conditional if False else new_block always evaluates to new_block, making the tir.BlockRealize construction unreachable. This appears to be leftover debugging or incomplete code.

🧹 Proposed fix to remove dead code
-                return (
-                    tir.BlockRealize(
-                        stmt.iter_vars,
-                        tir.const(True, "bool"),
-                        new_block,
-                    )
-                    if False
-                    else new_block
-                )
+                return new_block
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/transform/metal_fragment_to_simdgroup.py` around lines 94 - 102, The
code contains an unreachable conditional that always returns new_block because
of "if False", making the tir.BlockRealize construction dead; remove the
conditional and return new_block directly (or, if BlockRealize is desired,
replace the conditional with the proper condition), updating the return in the
function that builds the block (reference tir.BlockRealize, new_block, and
stmt.iter_vars) so only the intended branch is returned and no unreachable code
remains.
