[Metal] add T.BlockScaledLayout.e8m0_k32 + T.e8m0_to_float (blockscaled FP8)#2147
apstenku123 wants to merge 12 commits into
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations (simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon (M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for sharedxshared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
|
Caution: Review failed. Pull request was closed or merged during review.

Walkthrough
This PR introduces comprehensive Metal (MPS) backend support to TileLang, enabling GPU-accelerated GEMM, quantized matmul, and GDN operations on Apple Silicon. It adds Metal code generation, SIMD-group matrix abstractions, FP8 scaled quantization support, JIT compilation, and extensive testing infrastructure.

Changes
- Metal Backend Infrastructure & GEMM Support
- Metal SIMD-group Abstractions & Macros
- FP8 Quantization Support
- Testing & Benchmarking
- Dependency & Configuration Updates

Sequence Diagrams
sequenceDiagram
actor User as Application
participant Compiler as TileLang Compiler
participant MetalGen as Metal CodeGen
participant JIT as JIT Adapter
participant MPS as MPS Runtime
User->>Compiler: Define kernel (T.prim_func + GEMM/FP8Matmul)
activate Compiler
Compiler->>Compiler: Lower phase<br/>(fragment→simdgroup transform)
Compiler->>Compiler: Layout inference<br/>(Metal special-case)
Compiler->>Compiler: Select Metal GEMM<br/>(GemmInst::kMetalSimdgroup)
Compiler->>Compiler: Lower GEMM<br/>(GemmMetal→simdgroup ops)
deactivate Compiler
Compiler->>MetalGen: Invoke target.build.tilelang_metal<br/>(CodeGenTileLangMetal)
activate MetalGen
MetalGen->>MetalGen: Emit kernel void + args struct
MetalGen->>MetalGen: Emit threadgroup dimensions
MetalGen->>MetalGen: Generate simdgroup ops<br/>(load/store/mma)
deactivate MetalGen
MetalGen-->>Compiler: Metal Shading Language source
Compiler-->>JIT: Lowered module + MSL source
JIT->>JIT: Compile MSL via Xcode toolchain<br/>(xcrun metal -c)
JIT->>JIT: Create MPS command buffer<br/>+ setup kernel arguments
JIT-->>User: Return compiled JIT kernel
User->>User: Prepare PyTorch input tensors<br/>(on device="mps")
User->>MPS: Execute kernel(A, B, C)<br/>(via JIT adapter)
activate MPS
MPS->>MPS: SIMD-group threads load A/B<br/>(simdgroup_load)
MPS->>MPS: Warp-level matmul<br/>(simdgroup_multiply_accumulate)
MPS->>MPS: SIMD-group store C<br/>(simdgroup_store)
MPS->>MPS: Synchronize
deactivate MPS
MPS-->>User: Output tensor C on MPS device
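The load / multiply-accumulate / store sequence in the diagram above can be illustrated on the CPU. The following is only a minimal pure-Python sketch of a GEMM decomposed into 8x8 tile operations; the helper names (`load_tile`, `tile_mma`, `tiled_gemm`) are hypothetical stand-ins and are not TileLang or Metal APIs, and the real simdgroup operations are executed cooperatively by the threads of a SIMD-group rather than by nested loops.

```python
# Sketch only: 8x8 tiled GEMM mirroring load -> multiply-accumulate -> store.
# All names here are illustrative assumptions, not part of TileLang or Metal.
TILE = 8


def load_tile(M, r0, c0):
    """Copy the 8x8 tile whose top-left corner is (r0, c0); stands in for simdgroup_load."""
    return [row[c0:c0 + TILE] for row in M[r0:r0 + TILE]]


def tile_mma(acc, a, b):
    """acc += a @ b on 8x8 tiles; stands in for simdgroup_multiply_accumulate."""
    for i in range(TILE):
        for j in range(TILE):
            acc[i][j] += sum(a[i][k] * b[k][j] for k in range(TILE))


def tiled_gemm(A, B):
    """Compute A @ B for dimensions that are multiples of 8."""
    m, kdim, n = len(A), len(B), len(B[0])
    assert len(A[0]) == kdim, "inner dimensions must match"
    assert m % TILE == 0 and kdim % TILE == 0 and n % TILE == 0
    C = [[0.0] * n for _ in range(m)]
    for i0 in range(0, m, TILE):
        for j0 in range(0, n, TILE):
            acc = [[0.0] * TILE for _ in range(TILE)]  # accumulator tile
            for k0 in range(0, kdim, TILE):
                tile_mma(acc, load_tile(A, i0, k0), load_tile(B, k0, j0))
            for i in range(TILE):  # stands in for simdgroup_store
                C[i0 + i][j0:j0 + TILE] = acc[i]
    return C


def naive_gemm(A, B):
    """Straightforward triple loop used as a correctness reference."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]
```

Comparing `tiled_gemm` against `naive_gemm` on small integer-valued inputs is a cheap way to check that the tile bookkeeping is right before looking at generated MSL.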
sequenceDiagram
actor User as Application
participant Compiler as TileLang Compiler
participant FP8Macro as FP8 Scaled Matmul Macro
participant Lowering as Metal Lowering
participant MPS as MPS Runtime
User->>Compiler: Define kernel with<br/>T.fp8_scaled_matmul(A_fp8, B_fp8, scales, C_out)
activate Compiler
Compiler->>FP8Macro: Validate FP8 dtypes + scale shapes
activate FP8Macro
FP8Macro->>FP8Macro: Determine scale indexing pattern<br/>(per-tensor, per-row, per-col,<br/>or block-scale)
FP8Macro->>FP8Macro: Emit parallel (i,j) loop<br/>+ serial K loop
FP8Macro->>FP8Macro: Emit per-element:<br/>dequant(A_fp8[i,k]),<br/>dequant(B_fp8[k,j]),<br/>fused_scale multiplication,<br/>C_out[i,j] accumulation
deactivate FP8Macro
FP8Macro-->>Compiler: Expanded TIR with float32 ops
Compiler->>Lowering: Lower TIR to Metal backend<br/>(via GEMM path if applicable)
activate Lowering
Lowering->>Lowering: Emit simdgroup-compatible<br/>matmul kernels
Lowering->>Lowering: Avoid simdgroup_multiply_accumulate<br/>for FP8 (use scalar accum instead)
deactivate Lowering
Lowering-->>Compiler: Metal source with<br/>FP8 dequant + matmul
Compiler-->>User: Return compiled kernel
User->>User: Prepare FP8 + scale tensors on MPS
User->>MPS: Execute FP8 scaled matmul kernel
activate MPS
MPS->>MPS: Load FP8 elements
MPS->>MPS: Dequantize to float32<br/>per-element
MPS->>MPS: Apply broadcast scales
MPS->>MPS: Accumulate into output
deactivate MPS
MPS-->>User: Return float32 output on MPS
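The dequantize / apply-scales / accumulate steps in the diagram above correspond to a simple reference computation. The sketch below is a pure-Python illustration under stated assumptions: inputs are already-dequantized floats (FP8 decoding is omitted), the A scale is per-tensor (length 1) or per-row (length M), and the B scale is per-tensor or per-column (length N). The function name `scaled_matmul_ref` is hypothetical and not part of TileLang.

```python
# Sketch only: reference for a scaled matmul with broadcast scales.
# Inputs are plain floats standing in for dequantized FP8 values.
def scaled_matmul_ref(A, B, a_scale, b_scale):
    """C[i][j] = sum_k A[i][k] * B[k][j] * sa * sb.

    a_scale: length 1 (per-tensor) or M (per-row)
    b_scale: length 1 (per-tensor) or N (per-column)
    """
    m, kdim, n = len(A), len(B), len(B[0])
    C = [[0.0] * n for _ in range(m)]
    for i in range(m):
        sa = a_scale[0] if len(a_scale) == 1 else a_scale[i]
        for j in range(n):
            sb = b_scale[0] if len(b_scale) == 1 else b_scale[j]
            acc = 0.0
            for k in range(kdim):
                # scales are applied per element, inside the K loop
                acc += A[i][k] * B[k][j] * sa * sb
            C[i][j] = acc
    return C
```

Because these scales do not depend on `k`, they could equally be applied once after the K loop; keeping them inside the loop matches the per-element form shown in the diagram.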
Estimated Code Review Effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
|
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
After re-author against current PR #2142 macro shape (the previous probe-failed drafts targeted a non-existent tileop scheduler hierarchy), both patches now apply cleanly on jorgecurious metal-gemm-upstream-rebase + #2142 prereq stack.

Filed:
- PR tile-ai/tilelang#2146 (Path C tracker B): fused FP8 scale broadcast into T.fp8_scaled_matmul K-loop. 16/8 LOC delta in tilelang/language/fp8_op.py. Closes the 3-6× audiohacking perf gap on FP8 scaled matmul per the cppmega.mlx Path C consumer at fp8_vecmat_path_c.py.
- PR tile-ai/tilelang#2147 (Path C tracker C): T.BlockScaledLayout.e8m0_k32 + T.e8m0_to_float DSL primitive. 5 files touched (tilelang/language/blockscaled_layout.py new, fp8_op.py extended, __init__.py re-export, metal_quant.py Metal lowering, e8m0 layout test). Unblocks Sparse-MLA blockscaled Path C QK reducer.

Both stack on PR #2142 (T.fp8_scaled_matmul intrinsic) which stacks on PR #2130 (jorgecurious base). Independent of each other: different gaps, different files (B touches the macro body, C adds the layout primitive).

Receipt _filed_prs_2026_05_04.md updated with rows 13-14. Total filed PRs: 14 (across ml-explore/mlx, apache/tvm, tile-ai/tilelang, tile-ai/tvm). All OPEN.

Path C tracker A (pipelined_32x32) shipped in commit 3cb6457 + 6746ff9. Path C tracker B (#2146) and C (#2147) now filed upstream. All three Path C follow-up entries from docs/upstream/_path_c_blockers_tracker.md have landing receipts.
|
Withdrawn by submitter: filed without local build/test verification of the lowered MSL output. Will re-verify end-to-end (TileLang ninja build + IR tests + probe parity) before re-filing. |
Pull request overview
This stacked PR adds the Metal-side scaffolding needed for simdgroup GEMM / FP8 lowering and then extends the TileLang FP8 DSL with a first-class e8m0 K/32 block-scale layout for blockscaled FP8 matmul.
Changes:
- Adds Metal backend/codegen, simdgroup helpers, lowering passes, and JIT/runtime plumbing.
- Adds T.fp8_scaled_matmul, T.BlockScaledLayout.e8m0_k32(), and T.e8m0_to_float().
- Adds broad Metal/CPU test coverage, docs, and a Metal GEMM benchmark script.
Reviewed changes
Copilot reviewed 42 out of 43 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| tilelang/utils/language.py | Adds Metal simdgroup scope helper. |
| tilelang/transform/metal_fragment_to_simdgroup.py | New pass rewriting fragment accumulators to Metal simdgroup storage. |
| tilelang/transform/decouple_type_cast.py | Treats metal.simdgroup as register-local for cast splitting. |
| tilelang/tileop/metal_simdgroup.py | Adds internal Metal register-tile/simdgroup helper macros. |
| tilelang/tileop/metal_quant.py | Adds packed quant decode helpers, including e8m0_to_float. |
| tilelang/tileop/metal_gdn.py | Adds internal Metal GDN/attention helper macros. |
| tilelang/tileop/gemm/inst.py | Adds Metal simdgroup GEMM instruction enum. |
| tilelang/tileop/gemm/gemm_metal.py | New Metal GEMM lowering implementation. |
| tilelang/tileop/gemm/__init__.py | Wires Metal GEMM selection into Python dispatch. |
| tilelang/language/fp8_op.py | Adds FP8 scaled matmul macro and block-scale dispatch. |
| tilelang/language/blockscaled_layout.py | New public block-scale layout metadata object. |
| tilelang/language/__init__.py | Re-exports new FP8/blockscaled language APIs. |
| tilelang/jit/adapter/torch/metal.py | Exposes Metal kernel source through the Torch adapter. |
| tilelang/jit/adapter/base.py | Prefers MPS device selection when CUDA is unavailable/fails. |
| tilelang/intrinsics/metal_macro_generator.py | Adds Metal simdgroup intrinsic emitter. |
| tilelang/engine/phase.py | Inserts Metal fragment-to-simdgroup rewrite into lowering. |
| tilelang/engine/lower.py | Switches Metal codegen entrypoint to TileLang-specific builder. |
| testing/python/metal/test_metal_simdgroup_store.py | Adds simdgroup store codegen/runtime tests. |
| testing/python/metal/test_metal_local_var.py | Adds focused Metal local.var tests. |
| testing/python/metal/test_metal_internal_scaffolding.py | Adds broad internal Metal source/runtime scaffold tests. |
| testing/python/metal/test_metal_gemm_v2.py | Adds Metal GEMM runtime correctness tests. |
| testing/python/metal/test_metal_gemm_v2_linux.py | Adds cross-platform Metal GEMM codegen tests. |
| testing/python/metal/test_fp8_scaled_matmul_metal.py | Adds Metal FP8 scaled matmul lowering/runtime/bench tests. |
| testing/python/metal/metal_internal_runtime_coverage.md | Documents internal Metal runtime coverage. |
| testing/python/jit/test_tilelang_jit_adapter_mps.py | Adds MPS device-selection tests for JIT adapter. |
| testing/python/cpu/test_fp8_scaled_matmul_lowering.py | Adds CPU/IR tests for FP8 scaled matmul lowering. |
| testing/python/cpu/test_blockscaled_e8m0_layout.py | Adds CPU contract tests for e8m0 block-scale layout. |
| src/transform/lower_device_storage_access_info.cc | Exempts fragment scope from memory-info lookup. |
| src/transform/layout_inference.cc | Relaxes fragment layout validation on Metal. |
| src/target/codegen_metal.h | Declares TileLang Metal codegen class. |
| src/target/codegen_metal.cc | Adds TileLang Metal source/codegen implementation. |
| src/op/utils.h | Adds Metal simdgroup/register buffer helpers. |
| src/op/parallel.cc | Guards fragment-layout lookup in parallel op inference. |
| src/op/gemm.h | Adds Metal GEMM enum value/stringification. |
| src/op/gemm.cc | Adds Metal GEMM instruction selection/partition tweaks. |
| src/op/fill.cc | Adds simdgroup fill lowering. |
| src/op/copy.h | Adds Metal simdgroup copy instruction/lowering hooks. |
| src/op/copy.cc | Adds Metal simdgroup copy detection and lowering. |
| src/backend/metal/CMakeLists.txt | Builds Metal codegen and adjusts non-Apple behavior. |
| requirements.txt | Caps apache-tvm-ffi on Darwin. |
| requirements-dev.txt | Caps apache-tvm-ffi on Darwin for dev installs. |
| pyproject.toml | Adds Darwin-specific apache-tvm-ffi cap to package metadata. |
| benchmark/matmul_metal/benchmark_matmul_metal.py | Adds Metal GEMM benchmark script. |
| def _block_scale_value(scale, *, axis: str, col, k): | ||
| # Path C E8M0 is explicitly contracted-K-block indexed: kb = k // 32. | ||
| kb = k // 32 | ||
| if axis == "B" and len(getattr(scale, "shape", ())) == 2: | ||
| return e8m0_to_float(scale[col, kb]) | ||
| return e8m0_to_float(scale[kb]) |
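The K-block indexing in the snippet above can be checked in isolation. This is a minimal sketch that assumes scales are stored as plain Python lists: a 1-D list of length K/32, or for operand B a 2-D list of shape (N, K/32). It returns the raw scale byte without decoding it, and the names `k_block_index` and `select_scale_byte` are hypothetical.

```python
# Sketch only: which scale byte applies to contracted index k.
BLOCK = 32  # one scale per 32 contracted-K values


def k_block_index(k: int) -> int:
    """Map a contracted-K index to its block index kb = k // 32."""
    return k // BLOCK


def select_scale_byte(scale, axis: str, col: int, k: int) -> int:
    """Pick the raw scale byte for element k (and output column col for B)."""
    kb = k_block_index(k)
    is_2d = len(scale) > 0 and isinstance(scale[0], (list, tuple))
    if axis == "B" and is_2d:
        return scale[col][kb]  # per-column scales, shape (N, K/32)
    return scale[kb]           # broadcast scales, shape (K/32,)
```

For example, indices 0 through 31 share block 0 and index 32 starts block 1, so a K of 64 needs exactly two scale bytes per row of scales.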
| if target_is_metal(target): | ||
| return GemmInst.METAL_SIMDGROUP |
| } else if (TargetIsCPU(target)) { | ||
| return GemmInst::kScalar; | ||
| } else if (TargetIsMetal(target)) { | ||
| return GemmInst::kMetalSimdgroup; |
| new_body = tir.stmt_functor.substitute(stmt.body, var_map) | ||
| new_block = tir.Block( | ||
| stmt.iter_vars, | ||
| stmt.reads, | ||
| stmt.writes, | ||
| stmt.name_hint, | ||
| new_body, | ||
| stmt.init, |
| for (const auto &[buffer, _] : use_list_) { | ||
| if (IsFragmentBuffer(buffer)) { | ||
| ICHECK_NE(layout_map.count(buffer), 0) | ||
| << "The layout for fragment " << buffer | ||
| << " can not be inferred correctly."; | ||
| if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) { | ||
| ICHECK(false) << "The layout for fragment " << buffer | ||
| << " can not be inferred correctly."; | ||
| } |
| # Metal codegen is pure C++ and can generate Metal shader source on any | ||
| # platform. Always compile it so target.build.tilelang_metal is available for | ||
| # cross-compilation and source-level tests on non-Apple hosts. | ||
| list(APPEND TILE_LANG_SRCS | ||
| src/target/codegen_metal.cc | ||
| ) |
| def _e8m0_to_float_cpu(bits: int) -> float: | ||
| return 0.0 if bits == 255 else 2.0 ** (bits - 127) |
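For comparison, here is a minimal sketch of the decode as the PR description states it, where both byte 0 and byte 0xFF map to zero and every other byte decodes as a power of two with bias 127. The CPU reference quoted above treats only 255 as zero, so the two disagree at byte 0. The function name `e8m0_to_float_sketch` is hypothetical, and this follows the PR's stated contract rather than any external format specification.

```python
# Sketch only: e8m0 decode following the contract in the PR description.
def e8m0_to_float_sketch(byte: int) -> float:
    """Decode one e8m0 scale byte to a float."""
    if not 0 <= byte <= 0xFF:
        raise ValueError(f"e8m0 byte out of range: {byte}")
    if byte == 0 or byte == 0xFF:
        return 0.0  # sentinel bytes map to zero under this contract
    return 2.0 ** (byte - 127)  # exponent-only value with bias 127
```

A test that sweeps all 256 byte values against whichever reference is chosen would make the byte-0 behavior explicit.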
| assert "byte == 0" in text | ||
| assert "byte == 0xFF" in text |
| if block_scale_layout is not None: | ||
| sa = _block_scale_value(A_scale, axis="A", col=j, k=k) | ||
| sb = _block_scale_value(B_scale, axis="B", col=j, k=k) | ||
| else: | ||
| sa = A_scale[0] if sa_size == 1 else A_scale[i] | ||
| sb = B_scale[0] if sb_size == 1 else B_scale[j] | ||
| C_local[i, j] = C_local[i, j] + a_val * b_val * sa * sb |
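The block-scaled branch above can be mirrored by a small CPU reference. This sketch makes several assumptions: A (M x K) and B (K x N) hold already-dequantized floats, the A scale is a list of K/32 bytes, the B scale is an (N, K/32) list of bytes, K is a multiple of 32, and the decode follows the PR description (0 and 0xFF map to zero). The names `decode_e8m0` and `blockscaled_matmul_ref` are hypothetical.

```python
# Sketch only: reference for an e8m0 block-scaled matmul with K/32 blocks.
BLOCK = 32


def decode_e8m0(byte: int) -> float:
    """Decode per the PR description: 0 and 0xFF -> 0, else 2**(byte - 127)."""
    return 0.0 if byte in (0, 0xFF) else 2.0 ** (byte - 127)


def blockscaled_matmul_ref(A, B, a_scale, b_scale):
    """C[i][j] = sum_k A[i][k] * B[k][j] * sa[kb] * sb[j][kb], with kb = k // 32."""
    m, kdim, n = len(A), len(B), len(B[0])
    assert kdim % BLOCK == 0, "this sketch assumes K is a multiple of 32"
    C = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for k in range(kdim):
                kb = k // BLOCK
                sa = decode_e8m0(a_scale[kb])
                sb = decode_e8m0(b_scale[j][kb])
                acc += A[i][k] * B[k][j] * sa * sb
            C[i][j] = acc
    return C
```

Unlike per-tensor or per-row scales, these scales vary with `k`, so they cannot be hoisted out of the K loop as a whole; they can only be factored out per 32-element block.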
Summary
Adds T.BlockScaledLayout.e8m0_k32() and T.e8m0_to_float() so T.fp8_scaled_matmul(...) accepts the e8m0 block-scale layout used by Sparse-MLA blockscaled QK reducers. Also extends T.fp8_scaled_matmul to dispatch on layout type (per-tensor / per-row / e8m0 block).

The contract
E8M0 (8-bit exponent-only) block-scale layout matches mxfp8 conventions:
- A scale shape: (K / 32,)
- B scale shape: (N, K / 32), with broadcast (K / 32,) accepted for local probes
- Block index: kb = k // 32 (one scale per 32 contracted-K values)
- Decode: 0 and 0xFF map to zero; normal bytes decode as pow(2, byte - 127)

Why
Path C tracker entry C in cppmega.mlx documents this gap. The Sparse-MLA blockscaled (mxfp8) Path C QK reducer in cppmega_mlx/nn/_tilelang/sparse_mla_blockscaled_path_c.py currently has a partial e8m0 receipt at C/B 0.4364 vs Path B's hand-tuned MSL, but it can't be a full Path C dispatch because the DSL has no first-class block-scale layout primitive. With this PR, that dispatch works directly, and the lowered MSL emits the pow(2, scale - 127) decode + zero-byte handling per K-block.

Files
- tilelang/language/__init__.py: re-exports BlockScaledLayout and e8m0_to_float
- tilelang/language/blockscaled_layout.py: new file, BlockScaledLayout.e8m0_k32() + e8m0_to_float helper
- tilelang/language/fp8_op.py: extends T.fp8_scaled_matmul to accept layout objects
- tilelang/tileop/metal_quant.py: Metal lowering for the e8m0 decode
- testing/python/cpu/test_blockscaled_e8m0_layout.py: IR-level test coverage

Stacking
Stacks on PR #2142 (T.fp8_scaled_matmul DSL intrinsic + Metal lowering) which is itself stacked on PR #2130 (jorgecurious metal-gemm-upstream-rebase). The branch contains 2 commits:
- [prereq] tilelang: T.fp8_scaled_matmul DSL intrinsic + Metal lowering: exact PR #2142 content
- [Metal] add T.BlockScaledLayout.e8m0_k32 + T.e8m0_to_float (blockscaled FP8): this PR's contribution

Independent of PR #2146 (FP8 scaled-matmul fused scheduler); the two address different gaps.

Test plan
Local probe at cppmega.mlx/docs/upstream/tilelang_metal_blockscaled_e8m0/test_blockscaled_e8m0_probe.py validates 9/9 source-level invariants (DSL surface, sentinel decode, K/32 indexing, README contract).

Caveats
block_size=32 is the only e8m0 block size shipped in this PR; other block sizes can extend the same BlockScaledLayout factory pattern.

Attribution
Co-developed with cppmega.mlx for Apple-Silicon Metal MLA kernel ports.
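To make the factory pattern mentioned under Caveats concrete, here is a purely hypothetical sketch of what a block-scale layout metadata object could look like. It is not the BlockScaledLayout class from this PR: the class name `BlockScaledLayoutSketch`, its fields, and its methods are assumptions, as is the requirement that K be a multiple of the block size.

```python
# Sketch only: a hypothetical layout-metadata object with a factory method.
# This is NOT the PR's BlockScaledLayout; names and fields are assumptions.
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockScaledLayoutSketch:
    scale_format: str  # e.g. "e8m0"
    block_size: int    # contracted-K values covered by one scale

    @classmethod
    def e8m0_k32(cls) -> "BlockScaledLayoutSketch":
        """Factory for the one configuration described in this PR."""
        return cls(scale_format="e8m0", block_size=32)

    def num_blocks(self, k: int) -> int:
        """Number of scale entries along K (assumes K divides evenly)."""
        if k % self.block_size != 0:
            raise ValueError(f"K={k} is not a multiple of {self.block_size}")
        return k // self.block_size

    def a_scale_shape(self, k: int) -> tuple:
        return (self.num_blocks(k),)

    def b_scale_shape(self, n: int, k: int) -> tuple:
        return (n, self.num_blocks(k))
```

Under this shape of design, supporting another block size would mean adding another factory classmethod rather than changing the callers that consume the layout.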