[Metal] FP8 vector cast lanes 2/3/4 (extends storage-only FP8)#2145
apstenku123 wants to merge 12 commits into tile-ai:main
Conversation
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations (simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon (M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared x shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
…, #37/#38/#39) Three parallel agents completed the supermodule/submodule split filing:

1. tilelang_metal_fp8 (storage-only FP8 emulation) split:
   - 0001-tilelang-metal-fp8-storage-only.patch — supermodule half (235 lines)
   - 0002-tvm-metal-fp8-storage-only.patch — TVM-mirror half (260 lines, prefix stripped)
   - PR tile-ai/tilelang#2144 (supermodule, stacks on PR #2130)
   - PR tile-ai/tvm#38 (TVM mirror, base tilelang_main @ 0e15b274)
2. tilelang_metal_fp8_vector (vector cast lanes 2/3/4) split:
   - 0001-tilelang-metal-fp8-vector-cast.patch — supermodule half (148 lines)
   - 0002-tvm-metal-fp8-vector-cast.patch — TVM-mirror half (151 lines)
   - PR tile-ai/tilelang#2145 (supermodule, depends on #2144)
   - PR tile-ai/tvm#39 (TVM mirror, depends on #38)
3. PR #2143 TVM-mirror companion:
   - PR tile-ai/tvm#37 — already filed, README updated to link both halves

Total filed today: 11 PRs across 3 repos
- 1 ml-explore/mlx (#3476)
- 1 apache/tvm (#19504)
- 6 tile-ai/tilelang (#2139, #2140, #2141, #2142, #2143 super, #2144 super, #2145 super)
- 3 tile-ai/tvm (#37, #38, #39 — TVM-mirror companions)

PR #2142 (T.fp8_scaled_matmul) has no TVM-mirror companion needed — verified the patch only touches supermodule files. All splits round-trip clean (apply forward + reverse) on their respective bases. README files in each docs/upstream/<dir>/ updated with PR URLs and dependency-chain diagrams.

Note: TileLang/tvm redirects to tile-ai/tvm server-side (canonical org slug). All TVM-mirror PRs land at tile-ai/tvm/pull/N URLs.
📝 Walkthrough

Introduces comprehensive Metal (MPS) backend support for TileLang including native code generation, simdgroup-scoped GEMM and tensor operations, register-tile abstractions with multiply-accumulate macros, quantization and attention-pattern helpers, and extensive hardware/codegen testing. Updates build system, dependency constraints, device selection logic, and compiler phase ordering.

Changes: Metal Backend Implementation
Sequence Diagram(s)

sequenceDiagram
participant Host as TVM Host (Python)
participant LowerPipeline as Lower Pipeline
participant MetalPass as MetalFragmentToSimdgroup Pass
participant MetalGEMM as GemmMetal Lowering
participant MetalCodegen as Metal Code Generator
participant MPS as MPS Device
Host->>LowerPipeline: Lower TileLang kernel
LowerPipeline->>MetalPass: Apply MetalFragmentToSimdgroup
MetalPass->>MetalPass: Rewrite local.fragment→metal.simdgroup
MetalPass-->>LowerPipeline: Updated PrimFunc
LowerPipeline->>MetalGEMM: Lower GEMM ops
MetalGEMM->>MetalGEMM: Compute warp partition (8 M-width)
MetalGEMM->>MetalGEMM: Generate K-loop with ldmatrix/mma
MetalGEMM-->>LowerPipeline: Kernel PrimFunc
LowerPipeline->>MetalCodegen: target.build.tilelang_metal
MetalCodegen->>MetalCodegen: Print Metal types/scopes
MetalCodegen->>MetalCodegen: Emit FP8 prelude (if needed)
MetalCodegen->>MetalCodegen: Emit kernel signature & body
MetalCodegen-->>Host: Metal Shading Language source
Host->>MPS: Compile & launch kernel
MPS-->>Host: Results
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

Reasoning: This PR introduces a substantial new backend (Metal/MPS) spanning multiple interconnected subsystems: a full code generator with FP8 MSL prelude logic, new GEMM and copy lowering paths with complex constraint checking, register-tile abstractions with ~500 lines of macro definitions, quantization/attention-pattern helpers, a compiler pass for scope rewriting, build system changes, and 300+ lines of diverse testing. The changes are heterogeneous in nature (codegen, GEMM scheduling, IR transforms, testing infrastructure), require understanding of TVM/Metal/MPS semantics and TileLang's tiling abstractions, and involve intricate constraint logic (e.g., simdgroup store warp tiling search, FP8 vector lane handling). While there are repetitive patterns in test modules and helper generation, the density of novel logic and cross-cutting dependencies between layers demand careful scrutiny across multiple specialized domains.
Pull request overview
Extends TileLang’s Metal backend to better support low-precision and simdgroup-based lowering, including storage-only FP8 emulation with vectorized (lanes 2/3/4) cast helpers, and broader Metal GEMM/simdgroup infrastructure for codegen, lowering, and tests.
Changes:
- Add Metal FP8 storage-only emulation support for vector casts (lanes 2/3/4) via emitted inline MSL helpers.
- Introduce Metal simdgroup GEMM plumbing (IR transforms, intrinsics/macro emitter, copy/fill support, and new test coverage).
- Adjust JIT adapter device selection to prefer MPS when CUDA is unavailable or initialization fails, and switch the Metal codegen entrypoint to `target.build.tilelang_metal`.
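The storage-only FP8 emulation in the first bullet comes down to decoding FP8 bytes into wider floats. The sketch below is a pure-Python reference of what an e4m3 cast helper computes, plus a 4-lane packed variant mirroring the lanes-4 vector case; the function names and the little-endian packing are illustrative assumptions, not the emitted MSL code.

```python
def fp8_e4m3_to_float(byte: int) -> float:
    # e4m3 layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0:
        # Subnormal: 2^(1-7) * (man / 8) == man * 2^-9
        return sign * man * 2.0 ** -9
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def fp8x4_to_floats(packed: int) -> list[float]:
    # Lane-wise decode of four e4m3 bytes packed little-endian in a uint32,
    # one scalar decode per component, as a lanes-4 vector cast helper would do.
    return [fp8_e4m3_to_float((packed >> (8 * i)) & 0xFF) for i in range(4)]
```

The lanes-2 and lanes-3 variants are the same per-component decode over fewer bytes.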
Reviewed changes
Copilot reviewed 36 out of 37 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| tilelang/utils/language.py | Add scope predicate for metal.simdgroup. |
| tilelang/transform/metal_fragment_to_simdgroup.py | New Metal-only PrimFunc pass to rewrite GEMM accumulators from local.fragment to metal.simdgroup. |
| tilelang/transform/decouple_type_cast.py | Treat metal.simdgroup buffers as “local/register-level” for cast decoupling. |
| tilelang/tileop/metal_simdgroup.py | Add internal simdgroup tile helpers/macros (RegisterTile/RowVector, load/store/mma helpers). |
| tilelang/tileop/metal_quant.py | Add packed-uint8 FP8/FP4/e8m0 decode helpers for Metal probes. |
| tilelang/tileop/metal_gdn.py | Add internal GDN/attention-style tile macros built on simdgroup helpers. |
| tilelang/tileop/gemm/inst.py | Add METAL_SIMDGROUP GemmInst selector. |
| tilelang/tileop/gemm/gemm_metal.py | New Metal GEMM lowering using simdgroup_matrix intrinsics. |
| tilelang/tileop/gemm/__init__.py | Wire Metal instruction selection and implementation class mapping. |
| tilelang/jit/adapter/torch/metal.py | Expose kernel source getter for Metal torch adapter. |
| tilelang/jit/adapter/base.py | Prefer MPS device when CUDA is unavailable/failed init. |
| tilelang/intrinsics/metal_macro_generator.py | New MPS/simdgroup intrinsic emitter used by Metal GEMM lowering. |
| tilelang/engine/phase.py | Insert Metal fragment→simdgroup rewrite before layout inference. |
| tilelang/engine/lower.py | Switch Metal build entrypoint to target.build.tilelang_metal. |
| testing/python/metal/test_metal_simdgroup_store.py | New tests for simdgroup-register accumulation and direct simdgroup_store to device memory. |
| testing/python/metal/test_metal_local_var.py | New focused tests for local.var scalar lowering on Metal. |
| testing/python/metal/test_metal_internal_scaffolding.py | Large internal scaffolding tests for simdgroup helpers + packed quant + GDN-style probes. |
| testing/python/metal/test_metal_gemm_v2.py | Runtime GEMM correctness tests on Metal hardware. |
| testing/python/metal/test_metal_gemm_v2_linux.py | Cross-platform (codegen-only) Metal GEMM source tests. |
| testing/python/metal/metal_internal_runtime_coverage.md | Document internal Metal runtime/source-boundary coverage and opt-in benchmarks. |
| testing/python/jit/test_tilelang_jit_adapter_mps.py | New tests validating device selection prefers MPS when CUDA is unavailable/fails. |
| src/transform/lower_device_storage_access_info.cc | Treat fragment scope as special-case for memory info lowering. |
| src/transform/layout_inference.cc | Skip fragment-layout completeness check on Metal targets. |
| src/target/codegen_metal.h | Add TileLang Metal codegen class API, including FP8 prelude hooks. |
| src/target/codegen_metal.cc | Implement TileLang Metal codegen, including FP8 scalar+vector preludes and simdgroup intrinsics emission. |
| src/op/utils.h | Add metal.simdgroup buffer scope helpers. |
| src/op/parallel.cc | Make fragment-layout use optional to avoid hard .value() assumptions. |
| src/op/gemm.h | Add Metal simdgroup GEMM instruction enum value. |
| src/op/gemm.cc | Select Metal GEMM inst on Metal targets; adjust warp partition heuristics. |
| src/op/fill.cc | Add simdgroup-matrix-aware fill lowering via make_filled_simdgroup_matrix. |
| src/op/copy.h | Add Metal simdgroup copy instruction and lowering hook. |
| src/op/copy.cc | Implement simdgroup store lowering for Metal and bypass layout inference for that path. |
| src/backend/metal/CMakeLists.txt | Always build Metal codegen source for cross-platform codegen-only mode. |
| requirements.txt | Add Darwin-only cap for apache-tvm-ffi. |
| requirements-dev.txt | Add Darwin-only cap for apache-tvm-ffi in dev requirements. |
| pyproject.toml | Add Darwin-only cap for apache-tvm-ffi in project dependencies. |
| benchmark/matmul_metal/benchmark_matmul_metal.py | Add Metal GEMM benchmark script using simdgroup GEMM lowering. |
```python
def _rewrite_scope(body, var_map):
    buf_map = {}

    def _pre_order(stmt):
        if isinstance(stmt, tir.Block):
            new_alloc_bufs = []
            changed = False
            for buf in stmt.alloc_buffers:
                new_buf = _remap_buffer(buf, var_map)
                new_alloc_bufs.append(new_buf)
                if not new_buf.same_as(buf):
                    buf_map[buf] = new_buf
                    changed = True
            if changed:
                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
                new_block = tir.Block(
                    stmt.iter_vars,
                    stmt.reads,
                    stmt.writes,
                    stmt.name_hint,
                    new_body,
                    stmt.init,
                    new_alloc_bufs,
                    stmt.match_buffers,
                    stmt.annotations,
                )
                return (
                    tir.BlockRealize(
                        stmt.iter_vars,
                        tir.const(True, "bool"),
                        new_block,
                    )
                    if False
                    else new_block
                )
        elif isinstance(stmt, tir.Allocate):
```
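The `if False` tail of the return in the snippet above is statically dead: a conditional expression with a constant-false predicate always yields its else branch, so the `tir.BlockRealize` construction can never be produced. A minimal illustration of the pattern:

```python
def build_result(new_block, block_realize):
    # Same shape as the flagged return: the predicate is the constant False,
    # so this expression always evaluates to `new_block`.
    return block_realize if False else new_block
```

Returning `new_block` directly is behaviorally identical and removes the unreachable branch.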
```cpp
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
  if (num_warps % m != 0)
```
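The search above can be modeled in pure Python. This is a sketch under assumed 8x8 per-warp tiles, with the `max_m` cap and tie-breaking details omitted; it is not the C++ implementation.

```python
def best_warp_partition(M, N, num_warps, m_per_warp=8, n_per_warp=8):
    # Try every factorization num_warps = m * n and score how close the
    # per-warp tile ratio lands to the ideal M/N ratio.
    ideal = M / N if N > 0 else 1.0
    best, best_score = (1, num_warps), float("inf")
    for m in range(1, num_warps + 1):
        if num_warps % m != 0:
            continue
        n = num_warps // m
        m_per = M / (m * m_per_warp)
        n_per = N / (n * n_per_warp)
        score = abs(m_per / n_per - ideal)
        if score < best_score:
            best, best_score = (m, n), score
    return best
```

For a square 64x64 tile with 4 warps this picks the balanced (2, 2) split.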
```cpp
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
```
```diff
 if torch.cuda.is_available():
     try:
         torch.cuda._lazy_init()
         current_device = torch._C._cuda_getDevice
         return lambda: torch.device("cuda", current_device())
     except Exception:
-        return lambda: torch.device("cuda", torch.cuda.current_device())
+        pass
 if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
     return lambda: torch.device("mps")
 # CPU fallback
 return lambda: torch.device("cpu")
```
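The fallback order in the snippet above can be modeled as a small pure function; this is a hypothetical sketch, not the adapter's actual API. Note that, as written, a CUDA host whose initialization raises falls through to MPS/CPU, which is what the review comment on this file flags.

```python
def pick_device(cuda_available: bool, cuda_init_ok: bool, mps_available: bool) -> str:
    # Model of the selection order shown above.
    if cuda_available and cuda_init_ok:
        return "cuda"
    # A failed CUDA init currently falls through rather than staying on CUDA.
    if mps_available:
        return "mps"
    return "cpu"
```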
```diff
 // Check that all local.fragment buffers have inferred layouts.
 // On Metal targets, fragment buffers used as GEMM accumulators are
 // lowered to opaque simdgroup matrices, so they have no explicit
 // thread-level layout and can be safely skipped.
 for (const auto &[buffer, _] : use_list_) {
   if (IsFragmentBuffer(buffer)) {
-    ICHECK_NE(layout_map.count(buffer), 0)
-        << "The layout for fragment " << buffer
-        << " can not be inferred correctly.";
+    if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
+      ICHECK(false) << "The layout for fragment " << buffer
+                    << " can not be inferred correctly.";
+    }
```
```
"apache-tvm-ffi~=0.1.0,>=0.1.2",
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
```
```cpp
void CodeGenTileLangMetal::InitFuncState(const PrimFunc &f) {
  CodeGenC::InitFuncState(f);
  // analyze the data;
  for (Var arg : f->params) {
    if (arg.dtype().is_handle()) {
      alloc_storage_scope_[arg.get()] = "global";
    }
  }
}

CodeGenTileLangMetal::CodeGenTileLangMetal(Target target) : target_(target) {
  decl_stream << "#include <metal_stdlib>\n";
  decl_stream << "using namespace metal;\n\n";
  decl_stream << "union __TVMArgUnion {\n"
              << "  int v_int[2];\n"
              << "};\n\n";
}
```
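The constructor above unconditionally emits the MSL header; the FP8 prelude, by contrast, is emitted only when a kernel actually needs it. A rough Python model of that conditional-prelude idea (names are illustrative, not the real codegen API):

```python
def build_prelude(needs_fp8: bool) -> str:
    # Always emit the baseline Metal header, as the constructor above does.
    lines = ["#include <metal_stdlib>", "using namespace metal;", ""]
    if needs_fp8:
        # Placeholder marker: the real codegen appends inline FP8 cast helpers here.
        lines.append("// fp8 storage-only cast helpers")
    return "\n".join(lines)
```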
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/jit/adapter/base.py (1)
69-86: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Two issues with the CUDA exception path and stale docstring.

1. Silent CUDA→CPU fallback on exception (line 82):
When `torch.cuda.is_available()` is `True` but `_lazy_init()` raises (broken CUDA driver, etc.), the `pass` causes the adapter to return `torch.device("cpu")` on non-Mac hosts. The prior fallback lambda `torch.cuda.current_device()` would have raised an explicit CUDA error at call time; the new path silently dispatches kernels to CPU, which can produce wrong results without a clear diagnostic.

Consider restoring a CUDA-specific fallback in the `except` block so CUDA systems stay on CUDA even if the fast internal path is unavailable:

🛠️ Proposed fix

```diff
 except Exception:
-    pass
+    return lambda: torch.device("cuda", torch.cuda.current_device())
 if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
     return lambda: torch.device("mps")
```

2. Stale docstring (lines 71–74):
The docstring says "On CPU or when CUDA is unavailable, returns `torch.device('cpu')`" — this is now incorrect; MPS is returned when available.

📝 Proposed docstring update

```diff
-    Similar to the stream functor, we capture a callable that, when called,
-    fetches the current device according to PyTorch. On CPU or when CUDA is
-    unavailable, returns ``torch.device('cpu')``.
+    Similar to the stream functor, we capture a callable that, when called,
+    fetches the current device according to PyTorch. Falls back to MPS when
+    CUDA is unavailable and MPS is available, otherwise returns
+    ``torch.device('cpu')``.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/jit/adapter/base.py` around lines 69 - 86, The get_current_device_functor function silently falls back to CPU if torch.cuda._lazy_init() or the fast cuda path fails; change the except block to return a CUDA-producing callable that calls torch.cuda.current_device() (or otherwise raises the proper CUDA error at call time) so CUDA hosts do not silently dispatch to CPU (refer to get_current_device_functor, torch.cuda._lazy_init, torch._C._cuda_getDevice and torch.cuda.current_device); also update the function docstring to reflect that MPS may be returned when available (mention torch.backends.mps) instead of claiming CUDA/unavailable CPU only.
🧹 Nitpick comments (2)
src/op/copy.cc (1)
1089-1126: 💤 Low value

Warp partitioning score computation uses a different ratio than GEMM.

The ideal ratio calculation `ideal = N > 0 ? M / N : 1.f` differs from `GemmWarpPolicyNode::computeWarpPartition`, which computes the score as `abs(m_per_warp / n_per_warp - ideal)`. Here the score is `abs(m_per / n_per - ideal)` where `m_per = M / (m * kMPerWarp)` and `n_per = N / (n * kNPerWarp)`.

Both approaches aim for balanced workloads but use slightly different metrics. This is acceptable since the copy operation may have a different optimal tiling than GEMM, but consider documenting this difference or unifying the approach if consistency is desired.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/copy.cc` around lines 1089 - 1126, The warp-partition scoring in this copy lowering loop uses ideal = N>0 ? static_cast<float>(M)/N : 1.f and computes score from m_per = M/(m*kMPerWarp) and n_per = N/(n*kNPerWarp), which differs from GemmWarpPolicyNode::computeWarpPartition's m_per_warp/n_per_warp ratio; either make the metric consistent by changing the score computation to mirror GemmWarpPolicyNode::computeWarpPartition (use the same per-warp definitions and ideal) or add a concise comment above this block (referencing ideal, m_per, n_per and GemmWarpPolicyNode::computeWarpPartition) explaining why the copy op uses a different ratio so future maintainers understand the intentional divergence.

testing/python/metal/test_metal_internal_scaffolding.py (1)
425-453: ⚡ Quick win

Relax the exact float-literal source assertions.
These checks are currently tied to the printer's exact spelling of zero/one literals, so a harmless formatting change will break the tests without changing the generated behavior. Assert on the declaration/assignment pattern instead of the full literal text.
Suggested refactor
```diff
-    assert "float kkt_bias = 0.000000e+00f;" in src
+    assert "float kkt_bias" in src
+    assert "kkt_bias =" in src
 ...
-    assert "gate_state = 1.000000e+00f;" in gdn_src
+    assert "gate_state =" in gdn_src
 ...
-    assert "gate_state = 1.000000e+00f;" in gdn_src
+    assert "gate_state =" in gdn_src
```

Also applies to: 465-473
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_internal_scaffolding.py` around lines 425 - 453, The tests test_flashqla_gdn_kkt_probe_combines_local_var_state_and_simdgroup_boundary and test_scaled_packed_quant_and_gdn_probes_source_boundary_tokens assert exact float literal text (e.g. "float kkt_bias = 0.000000e+00f;" and "gate_state = 1.000000e+00f;") which is brittle; change those assertions to check the declaration/assignment pattern instead (e.g. assert "float kkt_bias" in src and assert re.search(r"\bkkt_bias\s*=\s*[-+]?\d*\.?\d+(e[-+]?\d+)?f?\b", src) or similarly for "gate_state" so the test verifies presence of the variable and an assignment to a numeric literal rather than an exact formatted literal; update the other occurrence around lines 465-473 the same way.
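A tolerant matcher along the lines the comment suggests could look like the following; the helper name is hypothetical and this is one possible regex, not the test suite's actual code.

```python
import re

def has_float_assignment(src: str, name: str) -> bool:
    # Match `<name> = <numeric literal>` regardless of how the printer
    # spells the literal (0.0, 0.000000e+00f, 1.f, ...).
    pat = rf"\b{re.escape(name)}\s*=\s*[-+]?(\d+\.?\d*|\.\d+)(e[-+]?\d+)?f?"
    return re.search(pat, src) is not None
```

This keeps the test pinned to the variable and the fact of a numeric assignment while staying robust to printer formatting changes.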
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmark/matmul_metal/benchmark_matmul_metal.py`:
- Around line 108-125: The loop currently leaves best_config as configs[0] even
if every bench_tilelang call fails; change best_config to None (or similar
sentinel) at initialization and only assign it inside the try block when a run
succeeds (e.g., when tl > best_tflops or when first success), and when
args.sweep is true print the summary only if best_config is not None (otherwise
skip or print "no successful configs"); update references to best_tflops,
best_config, configs and bench_tilelang accordingly.
In `@src/target/codegen_metal.cc`:
- Around line 501-515: The ICHECK message for the constant_size validation in
the simdgroup scope is misleading: constant_size is an element count, not bytes.
Update the error text used with the check in the block that references
constant_size, op->dtype, simdgroup_dtype_, PrintType and vid to say "elements"
(e.g., "Only 8x8 matrix is supported, but got <n> elements") so the message
accurately reflects the validated quantity.
In `@testing/python/metal/test_metal_gemm_v2.py`:
- Around line 59-62: The TileLang kernel launch via jit_kernel(a, b, c) is
asynchronous on MPS; insert a call to torch.mps.synchronize() immediately after
jit_kernel(a, b, c) and before computing ref and asserting, so the Metal kernel
finishes before reading c; update the test in test_metal_gemm_v2.py to call
torch.mps.synchronize() right after the jit_kernel(...) invocation.
In `@testing/python/metal/test_metal_simdgroup_store.py`:
- Around line 49-52: The test launches the Metal kernel via kernel(a, b, c) then
immediately reads tensor c; insert an explicit device synchronization call
(torch.mps.synchronize()) after kernel(a, b, c) and before computing ref =
a.to(torch_accum_dtype) @ b.to(torch_accum_dtype) and the assert so the Metal
command queue finishes and the comparison against c is deterministic; update the
test near kernel(a, b, c) to call torch.mps.synchronize() before using c in the
reference computation and allclose check.
In `@tilelang/tileop/gemm/__init__.py`:
- Around line 173-174: The unconditional selection of GemmInst.METAL_SIMDGROUP
in the target_is_metal branch forces simdgroup lowering for all Metal GEMMs and
bypasses GemmMetal.lower()'s alignment and partition predicates; change the
logic in the target_is_metal branch to perform the same checks used by
GemmMetal.lower() (e.g., 8-alignment of shapes and valid per-warp partition
predicates implemented in tilelang/tileop/gemm/gemm_metal.py) and only return
GemmInst.METAL_SIMDGROUP when those predicates pass, otherwise fall back to the
scalar-safe path so invalid cases are not forced into METAL_SIMDGROUP.
In `@tilelang/tileop/metal_simdgroup.py`:
- Around line 385-396: The mma_tile macro currently ignores K fragments beyond 0
which drops contributions; update mma_tile in tilelang/tileop/metal_simdgroup.py
to either (A) reduce across K fragments by adding an inner loop over the
K-fragment dimension and calling mma for each k-fragment using a.index(tile_m,
k) and b.index(k, tile_n) (accumulating into acc.fragment at acc.index(tile_m,
tile_n)), or (B) fail closed by asserting the input MMATile(s) have fragments_k
== 1 (raise/assert when a.fragments_k or b.fragments_k > 1) so incorrect cases
are rejected. Use the identifiers mma_tile, MMATile, mma, acc.index, a.index and
b.index to locate and implement the fix.
In `@tilelang/transform/metal_fragment_to_simdgroup.py`:
- Around line 94-102: The code contains an unreachable conditional that always
returns new_block because of "if False", making the tir.BlockRealize
construction dead; remove the conditional and return new_block directly (or, if
BlockRealize is desired, replace the conditional with the proper condition),
updating the return in the function that builds the block (reference
tir.BlockRealize, new_block, and stmt.iter_vars) so only the intended branch is
returned and no unreachable code remains.
---
Outside diff comments:
In `@tilelang/jit/adapter/base.py`:
- Around line 69-86: The get_current_device_functor function silently falls back
to CPU if torch.cuda._lazy_init() or the fast cuda path fails; change the except
block to return a CUDA-producing callable that calls torch.cuda.current_device()
(or otherwise raises the proper CUDA error at call time) so CUDA hosts do not
silently dispatch to CPU (refer to get_current_device_functor,
torch.cuda._lazy_init, torch._C._cuda_getDevice and torch.cuda.current_device);
also update the function docstring to reflect that MPS may be returned when
available (mention torch.backends.mps) instead of claiming CUDA/unavailable CPU
only.
---
Nitpick comments:
In `@src/op/copy.cc`:
- Around line 1089-1126: The warp-partition scoring in this copy lowering loop
uses ideal = N>0 ? static_cast<float>(M)/N : 1.f and computes score from m_per =
M/(m*kMPerWarp) and n_per = N/(n*kNPerWarp), which differs from
GemmWarpPolicyNode::computeWarpPartition's m_per_warp/n_per_warp ratio; either
make the metric consistent by changing the score computation to mirror
GemmWarpPolicyNode::computeWarpPartition (use the same per-warp definitions and
ideal) or add a concise comment above this block (referencing ideal, m_per,
n_per and GemmWarpPolicyNode::computeWarpPartition) explaining why the copy op
uses a different ratio so future maintainers understand the intentional
divergence.
In `@testing/python/metal/test_metal_internal_scaffolding.py`:
- Around line 425-453: The tests
test_flashqla_gdn_kkt_probe_combines_local_var_state_and_simdgroup_boundary and
test_scaled_packed_quant_and_gdn_probes_source_boundary_tokens assert exact
float literal text (e.g. "float kkt_bias = 0.000000e+00f;" and "gate_state =
1.000000e+00f;") which is brittle; change those assertions to check the
declaration/assignment pattern instead (e.g. assert "float kkt_bias" in src and
assert re.search(r"\bkkt_bias\s*=\s*[-+]?\d*\.?\d+(e[-+]?\d+)?f?\b", src) or
similarly for "gate_state" so the test verifies presence of the variable and an
assignment to a numeric literal rather than an exact formatted literal; update
the other occurrence around lines 465-473 the same way.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 01662649-98c3-4e80-8215-d6ab6b031580
📒 Files selected for processing (37)
- benchmark/matmul_metal/benchmark_matmul_metal.py
- pyproject.toml
- requirements-dev.txt
- requirements.txt
- src/backend/metal/CMakeLists.txt
- src/op/copy.cc
- src/op/copy.h
- src/op/fill.cc
- src/op/gemm.cc
- src/op/gemm.h
- src/op/parallel.cc
- src/op/utils.h
- src/target/codegen_metal.cc
- src/target/codegen_metal.h
- src/transform/layout_inference.cc
- src/transform/lower_device_storage_access_info.cc
- testing/python/jit/test_tilelang_jit_adapter_mps.py
- testing/python/metal/metal_internal_runtime_coverage.md
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- testing/python/metal/test_metal_internal_scaffolding.py
- testing/python/metal/test_metal_local_var.py
- testing/python/metal/test_metal_simdgroup_store.py
- tilelang/engine/lower.py
- tilelang/engine/phase.py
- tilelang/intrinsics/metal_macro_generator.py
- tilelang/jit/adapter/base.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/tileop/metal_gdn.py
- tilelang/tileop/metal_quant.py
- tilelang/tileop/metal_simdgroup.py
- tilelang/transform/decouple_type_cast.py
- tilelang/transform/metal_fragment_to_simdgroup.py
- tilelang/utils/language.py
```python
best_tflops = 0.0
best_config = configs[0]
for bM, bN, bK in configs:
    try:
        tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
        ratio = tl / ref_tflops * 100
        tag = ""
        if tl > best_tflops:
            best_tflops = tl
            best_config = (bM, bN, bK)
        print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
    except Exception as e:
        print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")

if args.sweep:
    print()
    print(f"Best config: {best_config}")
    print(f"Best TFlops: {best_tflops:.1f}")
```
Don't print a winner when every sweep config failed.
Lines 108-125 still report configs[0] as the best config if every benchmark attempt throws. That makes the summary misleading in exactly the case where the per-config error handling is supposed to help.
Suggested fix

```diff
-    best_tflops = 0.0
-    best_config = configs[0]
+    best_tflops = 0.0
+    best_config = None
     for bM, bN, bK in configs:
         try:
             tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
             ratio = tl / ref_tflops * 100
-            tag = ""
             if tl > best_tflops:
                 best_tflops = tl
                 best_config = (bM, bN, bK)
             print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
         except Exception as e:
             print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")
-    if args.sweep:
+    if args.sweep and best_config is not None:
         print()
         print(f"Best config: {best_config}")
         print(f"Best TFlops: {best_tflops:.1f}")
         print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")
+    elif args.sweep:
+        print()
+        print("No TileLang configuration completed successfully.")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
best_tflops = 0.0
best_config = None
for bM, bN, bK in configs:
    try:
        tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
        ratio = tl / ref_tflops * 100
        if tl > best_tflops:
            best_tflops = tl
            best_config = (bM, bN, bK)
        print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
    except Exception as e:
        print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")
if args.sweep and best_config is not None:
    print()
    print(f"Best config: {best_config}")
    print(f"Best TFlops: {best_tflops:.1f}")
    print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")
elif args.sweep:
    print()
    print("No TileLang configuration completed successfully.")
```
🧰 Tools
🪛 Ruff (0.15.12)
[warning] 119-119: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmark/matmul_metal/benchmark_matmul_metal.py` around lines 108 - 125, The
loop currently leaves best_config as configs[0] even if every bench_tilelang
call fails; change best_config to None (or similar sentinel) at initialization
and only assign it inside the try block when a run succeeds (e.g., when tl >
best_tflops or when first success), and when args.sweep is true print the
summary only if best_config is not None (otherwise skip or print "no successful
configs"); update references to best_tflops, best_config, configs and
bench_tilelang accordingly.
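The sentinel pattern being requested can be distilled to a small helper; this is an illustrative sketch, not the benchmark script's actual structure, and `bench` stands in for the per-config benchmark call.

```python
def run_sweep(configs, bench):
    # Only report a winner when at least one configuration actually succeeded.
    best_tflops, best_config = 0.0, None
    for cfg in configs:
        try:
            tflops = bench(cfg)
        except Exception:
            continue  # a failed config never becomes the "best"
        if tflops > best_tflops:
            best_tflops, best_config = tflops, cfg
    return best_config, best_tflops
```

Callers then print the summary only when the returned config is not `None`.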
```cpp
if (scope == "metal.simdgroup") {
  ICHECK(op->dtype == DataType::Float(16) ||
         op->dtype == DataType::Float(32) ||
         op->dtype == DataType::BFloat(16))
      << "Only float16, float32, and bfloat16 are supported, but got "
      << op->dtype;
  ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
                                  << constant_size << " bytes\n";

  std::ostringstream dtype_os;
  PrintType(op->dtype, dtype_os);
  std::string dtype_str = dtype_os.str();
  simdgroup_dtype_[op->buffer_var.get()] = dtype_str;
  stream << "simdgroup_" << dtype_str << "8x8 " << vid << '['
         << constant_size / 64 << "];\n";
```
Error message says "bytes" but validation checks element count.
The ICHECK message at line 508 mentions "bytes" but constant_size represents element count, not bytes. For an 8x8 matrix with 64 elements, this check is correct but the message is misleading.
📝 Proposed fix for error message
```diff
-  ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
-                                  << constant_size << " bytes\n";
+  ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
+                                  << constant_size << " elements\n";
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/target/codegen_metal.cc` around lines 501-515, the ICHECK message for
the constant_size validation in the simdgroup scope is misleading: constant_size
is an element count, not bytes. Update the error text used with the check in the
block that references constant_size, op->dtype, simdgroup_dtype_, PrintType and
vid to say "elements" (e.g., "Only 8x8 matrix is supported, but got <n>
elements") so the message accurately reflects the validated quantity.
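As an illustration of the invariant being checked (the helper below is hypothetical, not the codegen's API): `constant_size` counts elements, so an allocation of 8x8 simdgroup fragments must hold a multiple of 64 elements, and the byte size is a distinct quantity that also depends on the dtype width.

```python
# Hypothetical helper mirroring the ICHECK: constant_size is an element
# count; bytes = elements * dtype width, so "bytes" in the message is wrong.
def simdgroup_fragment_count(constant_size_elems: int, dtype_bytes: int):
    if constant_size_elems % 64 != 0:
        raise ValueError(
            f"Only 8x8 matrix is supported, but got {constant_size_elems} elements")
    num_fragments = constant_size_elems // 64
    total_bytes = constant_size_elems * dtype_bytes
    return num_fragments, total_bytes
```

For two float16 fragments, `simdgroup_fragment_count(128, 2)` yields `(2, 256)`: 128 elements but 256 bytes, which is why the message should say elements.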
```python
jit_kernel(a, b, c)

ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype)
assert torch.allclose(ref, c, atol=atol), (
```
Add torch.mps.synchronize() after the TileLang kernel launch.
Line 59 enqueues the Metal kernel asynchronously, and lines 61-62 immediately read and compare the output c. Metal operations on PyTorch MPS are asynchronous and require explicit synchronization before consuming results. All other Metal tests in this PR follow this pattern—synchronize after kernel launch, then compare. Without it, the assertion can race the async kernel and fail intermittently.
Suggested fix

```diff
     jit_kernel(a, b, c)
+    torch.mps.synchronize()
     ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype)
     assert torch.allclose(ref, c, atol=atol), (
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@testing/python/metal/test_metal_gemm_v2.py` around lines 59-62, the
TileLang kernel launch via jit_kernel(a, b, c) is asynchronous on MPS. Insert a
call to torch.mps.synchronize() immediately after jit_kernel(a, b, c) and before
computing ref and asserting, so the Metal kernel finishes before c is read.
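The race the review describes can be sketched with a plain-thread stand-in for the Metal command queue (`FakeQueue` below is illustrative, not a PyTorch or Metal API): launching only enqueues work, and reading the output before a synchronize can observe a partially finished "kernel".

```python
import threading
import time

class FakeQueue:
    """Illustrative stand-in for an asynchronous command queue."""
    def __init__(self):
        self._thread = None

    def launch(self, fn):
        # like jit_kernel(a, b, c): work is enqueued, not finished
        self._thread = threading.Thread(target=fn)
        self._thread.start()

    def synchronize(self):
        # like torch.mps.synchronize(): block until the work is done
        self._thread.join()

out = []
q = FakeQueue()
q.launch(lambda: (time.sleep(0.05), out.append(42)))
q.synchronize()  # without this, the assert below races the "kernel"
assert out == [42]
```

The same discipline applies to every test in this PR that reads an MPS tensor right after a kernel launch: synchronize first, compare second.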
```python
kernel(a, b, c)

ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype)
assert torch.allclose(ref, c, atol=atol), (
```
Synchronize before comparing the MPS output tensor.
Line 49 launches the external kernel, and lines 51-52 immediately read c in PyTorch. Without an explicit torch.mps.synchronize(), this helper can race the Metal command queue and produce flaky correctness failures.
Suggested fix

```diff
     kernel(a, b, c)
+    torch.mps.synchronize()
     ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype)
     assert torch.allclose(ref, c, atol=atol), (
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@testing/python/metal/test_metal_simdgroup_store.py` around lines 49-52, the
test launches the Metal kernel via kernel(a, b, c) and then immediately reads
tensor c. Insert an explicit device synchronization call
(torch.mps.synchronize()) after kernel(a, b, c) and before computing
ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype) and the assert, so the
Metal command queue finishes and the comparison against c is deterministic.
```python
if target_is_metal(target):
    return GemmInst.METAL_SIMDGROUP
```
Don’t force METAL_SIMDGROUP for every Metal GEMM.
GemmMetal.lower() still rejects non-8-aligned shapes and invalid per-warp partitions (tilelang/tileop/gemm/gemm_metal.py:21-35). This unconditional return bypasses fallback selection and turns those cases into hard ValueErrors instead of using a scalar-safe path.
Suggested direction

```diff
-        if target_is_metal(target):
-            return GemmInst.METAL_SIMDGROUP
+        if target_is_metal(target) and GemmMetal(self).can_lower(target, thread_nums):
+            return GemmInst.METAL_SIMDGROUP
         return GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target))
```

Use the same predicates that GemmMetal.lower() enforces, and fall back when simdgroup lowering is not applicable.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/tileop/gemm/__init__.py` around lines 173-174, the unconditional
selection of GemmInst.METAL_SIMDGROUP in the target_is_metal branch forces
simdgroup lowering for all Metal GEMMs and bypasses GemmMetal.lower()'s
alignment and partition predicates. Change the branch to perform the same
checks GemmMetal.lower() uses (8-alignment of shapes and valid per-warp
partitions, implemented in tilelang/tileop/gemm/gemm_metal.py) and return
GemmInst.METAL_SIMDGROUP only when those predicates pass; otherwise fall back
to the scalar-safe path so invalid cases are not forced into METAL_SIMDGROUP.
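A hedged sketch of the guarded selection being asked for; `can_use_simdgroup` and its 8-alignment/warp-partition checks below are illustrative assumptions, not the actual GemmMetal.lower() predicates:

```python
# Hypothetical predicates approximating what GemmMetal.lower() enforces:
# 8-aligned tile shapes (simdgroup fragments are 8x8) and tiles that divide
# evenly across the per-warp partition.
def can_use_simdgroup(m: int, n: int, k: int, warps: int = 4) -> bool:
    if any(dim % 8 != 0 for dim in (m, n, k)):
        return False
    return ((m // 8) * (n // 8)) % warps == 0

def pick_gemm_inst(is_metal: bool, m: int, n: int, k: int) -> str:
    if is_metal and can_use_simdgroup(m, n, k):
        return "METAL_SIMDGROUP"
    return "SCALAR_FALLBACK"  # scalar-safe path instead of a hard ValueError
```

The design point is that the selector and the lowering share one predicate, so a shape that would make lower() raise never selects the simdgroup path in the first place.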
```python
@T.macro
def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
    for tile_m in T.unroll(acc.fragments_m, explicit=True):
        for tile_n in T.unroll(acc.fragments_n, explicit=True):
            mma(
                acc.fragment,
                a.fragment,
                b.fragment,
                acc.index(tile_m, tile_n),
                a.index(tile_m, 0),
                b.index(0, tile_n),
            )
```
Reduce across K fragments in mma_tile, or fail closed.
This only multiplies a.index(tile_m, 0) with b.index(0, tile_n). For multi-fragment K tiles, every slice after 0 is silently dropped, so the accumulator is wrong.
Suggested fix

```diff
 @T.macro
 def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
+    if a.fragments_n != b.fragments_m:
+        raise ValueError(
+            f"mma_tile requires matching K fragments, got {a.fragments_n} and {b.fragments_m}"
+        )
     for tile_m in T.unroll(acc.fragments_m, explicit=True):
         for tile_n in T.unroll(acc.fragments_n, explicit=True):
-            mma(
-                acc.fragment,
-                a.fragment,
-                b.fragment,
-                acc.index(tile_m, tile_n),
-                a.index(tile_m, 0),
-                b.index(0, tile_n),
-            )
+            for tile_k in T.unroll(a.fragments_n, explicit=True):
+                mma(
+                    acc.fragment,
+                    a.fragment,
+                    b.fragment,
+                    acc.index(tile_m, tile_n),
+                    a.index(tile_m, tile_k),
+                    b.index(tile_k, tile_n),
+                )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/tileop/metal_simdgroup.py` around lines 385-396, the mma_tile
macro currently ignores K fragments beyond index 0, which drops their
contributions. Either (A) reduce across K fragments by adding an inner loop
over the K-fragment dimension, calling mma for each k-fragment with
a.index(tile_m, k) and b.index(k, tile_n) and accumulating into acc.fragment at
acc.index(tile_m, tile_n), or (B) fail closed by asserting the input MMATile(s)
have a single K fragment and raising when they do not, so incorrect cases are
rejected rather than silently miscomputed.
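The dropped-reduction bug can be demonstrated numerically with NumPy (the fragment sizes and counts below are arbitrary; this models the macro's tiling, not the MMATile API): accumulating over every K fragment reproduces the full matmul, while stopping at k == 0 does not.

```python
import numpy as np

FRAG = 8                      # simdgroup fragments are 8x8
fm, fk, fn = 2, 3, 2          # fragment counts along M, K, N
rng = np.random.default_rng(0)
a = rng.standard_normal((fm * FRAG, fk * FRAG)).astype(np.float32)
b = rng.standard_normal((fk * FRAG, fn * FRAG)).astype(np.float32)

def tiled_matmul(k_fragments):
    acc = np.zeros((fm * FRAG, fn * FRAG), dtype=np.float32)
    for tm in range(fm):
        for tn in range(fn):
            for tk in range(k_fragments):   # the reduction loop the macro is missing
                acc[tm*FRAG:(tm+1)*FRAG, tn*FRAG:(tn+1)*FRAG] += (
                    a[tm*FRAG:(tm+1)*FRAG, tk*FRAG:(tk+1)*FRAG]
                    @ b[tk*FRAG:(tk+1)*FRAG, tn*FRAG:(tn+1)*FRAG])
    return acc

assert np.allclose(tiled_matmul(fk), a @ b, atol=1e-4)      # full K reduction
assert not np.allclose(tiled_matmul(1), a @ b, atol=1e-4)   # only k == 0
```

With fk > 1, the k == 0 variant simply omits two thirds of the dot products, which is why the accumulator comes out wrong rather than merely imprecise.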
```python
return (
    tir.BlockRealize(
        stmt.iter_vars,
        tir.const(True, "bool"),
        new_block,
    )
    if False
    else new_block
)
```
Dead code: if False branch is unreachable.
The conditional if False else new_block always evaluates to new_block, making the tir.BlockRealize construction unreachable. This appears to be leftover debugging or incomplete code.
🧹 Proposed fix to remove dead code

```diff
-        return (
-            tir.BlockRealize(
-                stmt.iter_vars,
-                tir.const(True, "bool"),
-                new_block,
-            )
-            if False
-            else new_block
-        )
+        return new_block
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/transform/metal_fragment_to_simdgroup.py` around lines 94-102,
the code contains an unreachable conditional that always returns new_block
because of "if False", leaving the tir.BlockRealize construction dead. Remove
the conditional and return new_block directly (or, if BlockRealize is intended,
replace "if False" with the proper condition) so only the intended branch
remains and no unreachable code is left.
Summary
Extends the storage-only Metal FP8 codegen to handle vectorized casts at IR
lanes 2/3/4. Without this change, every `T.Cast("float16x4", fp8_x4)` emitted
by upstream TileLang DSL programs raises
`LOG(FATAL): Vector FP8 casts (lanes=4) are not yet supported`, forcing
callers to manually scalarise the cast and giving up the IR-level vector
type for any subsequent pass (vectorize, fragment-to-simdgroup, etc.).
This PR is the TileLang supermodule half. The companion
TileLang/tvm submodule half is filed at
https://github.com/TileLang/tvm/pulls (search
`cppmega/metal-fp8-vector-cast`). Both halves only share helper names; they can
land independently but should be merged in tandem so the vendored
`3rdparty/tvm` checkout stays in sync with this codepath.
What this changes
Adds an `enable_fp8_vector_codegen` flag and a new `PrintFP8VectorPrelude(...)`
that emits inline MSL helpers wrapping the existing scalar helpers
(`__tvm_fp8_e4m3_to_half`, etc.) per lane. Mirrors are emitted for `_v2`/`_v3`,
plus the reverse direction (`half -> fp8`) and the e5m2 variant. The compiler
is free to scalarise back into per-lane calls; the goal here is to preserve the
IR-level vector type so subsequent passes can keep their vector loads and
stores and the downstream MSL is `uchar4`-typed instead of `uchar` arrays.
`Finish()` is updated to splice the vector prelude after the scalar prelude
when at least one vector FP8 cast is encountered during codegen.
Wider lanes (8/16) keep the existing `LOG(FATAL)` with a sharper message:
those widths print as `uint2`/`uint4` packed storage and need an out-pointer
ABI to be wired through; callers should lower them to scalar casts upstream.
Why Apple Silicon needs software FP8 emulation
Apple Silicon (M1 through M4 Max, including the M5 NAX which is
FP16/INT8 only) has no native FP8 ALU. FP8 is realised as `uchar`
storage with explicit dequantize-on-load / quantize-on-store; the
encoding mirrors the OCP "OFP8 Formats for Deep Learning" v1.0 spec
(E4M3 finite-only, E5M2 IEEE-style with NaN/Inf).
The vector helpers in this PR are inline-trivial wrappers around the
scalar helpers that landed in the storage-only PR, with no new conversion
math. Their value is purely codegen: the IR-level vector type is
preserved so the rest of the lowering pipeline can vectorise.
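As a rough Python model of that dequantize-on-load path and of what the scalar-to-vector wrapping amounts to (the functions below are illustrative, not the actual MSL helpers), the vector "_v4" flavor is nothing but a per-lane application of the scalar OCP E4M3 decode:

```python
def fp8_e4m3_to_float(byte: int) -> float:
    # OCP OFP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa bits;
    # finite-only, with S.1111.111 as the single NaN encoding.
    s = (byte >> 7) & 0x1
    e = (byte >> 3) & 0xF
    m = byte & 0x7
    sign = -1.0 if s else 1.0
    if e == 0xF and m == 0x7:
        return float("nan")
    if e == 0:                                # subnormal
        return sign * (m / 8.0) * 2.0 ** (1 - 7)
    return sign * (1.0 + m / 8.0) * 2.0 ** (e - 7)

def fp8x4_e4m3_to_floats(packed: bytes) -> list:
    # vector "helper": the scalar decode applied lane by lane, mirroring
    # how the MSL _v2/_v3/_v4 wrappers delegate to the scalar helpers
    return [fp8_e4m3_to_float(b) for b in packed]
```

For example, `fp8x4_e4m3_to_floats(bytes([0x38, 0x40, 0xB8, 0x00]))` decodes to `[1.0, 2.0, -1.0, 0.0]`; the value of the real helpers is not new math but keeping the `uchar4` load vectorized.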
Path C consumer evidence (vector lanes matter)
The downstream cppmega.mlx project's Path C TileLang FP8 vecmat kernel
(`cppmega_mlx/nn/_tilelang/fp8_vecmat_path_c.py`) explicitly uses
`T.alloc_local((4,), "float8_e4m3")` and a `T.vectorized(4)` inner loop
over packed FP8 weights. Without this PR, that kernel cannot be lowered
on Metal: the FP8 cast inside the K-loop hits the `lanes=4` FATAL.
With this PR, the cast lowers and the resulting MSL preserves
`uchar4`-typed loads through the K-loop.

Dependency
This PR stacks on two prereqs:
1. The `tilelang_metal_fp8` storage-only patch (the parallel
"[Metal] FP8 storage-only emulation (uchar storage + LUT decode helpers)"
PR being filed against this same repo). That patch adds `PrintFP8Prelude`,
`enable_fp8_`, and the scalar `__tvm_fp8_*_to_half` / `__tvm_half_to_fp8_*`
helpers that the vector helpers in this PR call. Reviewers will need that
patch applied first; the branch in this PR includes it as the first commit
("[Metal] FP8 storage-only emulation ... [prereq]") for self-contained
review.
2. The rebase branch (`metal-gemm-upstream-rebase`) at HEAD 971c17b. That
branch in turn stacks on PRs "[Metal] Add Metal GEMM support with
simdgroup_matrix MMA" #1869, "Add Metal scalar fallback for T.gemm" #2118,
and "[Refactor][CodeGen] Refactor CodeGen part for multi-backend
decoupling" #2121.
When the storage-only PR merges, the prereq commit on this branch
should be rebased away. Before that, this branch is reviewable as
2 commits stacked.
Test plan
- `git apply --check` clean against
  `jorgecurious/tilelang:metal-gemm-upstream-rebase` @ 971c17b with the
  storage-only prereq applied first
- `git apply --reverse --check` clean for both commits in sequence
  (round-trip verified)
- `xcrun --sdk macosx metal -c` compile of any prim_func with vector
  FP8 cast (lanes 2/3/4) lowers to MSL using the new vector helpers,
  not scalar fallback
- `/tmp/test_fp8_vector_cast.py`: `lanes=4` cast lowers and the resulting
  MSL contains `__tvm_fp8_e4m3_to_half_v4` with `uchar4`-typed loads
- `testing/python/metal/test_metal_codegen_linux.py`: net +1 pass
  vs storage-only-only baseline
  (`test_t_gemm_metal_codegen_pipelined_float32` flips green)