
[Metal] add T.BlockScaledLayout.e8m0_k32 + T.e8m0_to_float (blockscaled FP8)#2147

Closed
apstenku123 wants to merge 12 commits into tile-ai:main from apstenku123:cppmega/metal-add-e8m0-blockscaled-layout


Conversation


apstenku123 commented May 4, 2026

Summary

Adds T.BlockScaledLayout.e8m0_k32() and T.e8m0_to_float() so T.fp8_scaled_matmul(...) accepts the e8m0 block-scale layout used by Sparse-MLA blockscaled QK reducers. Also extends T.fp8_scaled_matmul to dispatch on layout type (per-tensor / per-row / e8m0 block).

The contract

E8M0 (8-bit exponent-only) block-scale layout matches mxfp8 conventions:

  • A scale shape: (K / 32,)
  • B scale shape: (N, K / 32), with broadcast (K / 32,) accepted for local probes
  • Scale index: kb = k // 32 (one scale per 32 contracted-K values)
  • Decode: 0 and 0xFF map to zero; normal bytes decode as pow(2, byte - 127)
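The decode rule above can be stated as a tiny pure-Python reference (a sketch of the contract only; the in-kernel decode is emitted as MSL by the Metal lowering):

```python
def e8m0_decode(byte: int) -> float:
    """Reference decode for one E8M0 scale byte, per the contract above.

    0 and 0xFF are zero sentinels; any other byte is a biased
    power-of-two exponent with bias 127.
    """
    if byte == 0 or byte == 0xFF:
        return 0.0
    return 2.0 ** (byte - 127)
```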

Why

Path C tracker entry C in cppmega.mlx documents this gap. The Sparse-MLA blockscaled (mxfp8) Path C QK reducer in cppmega_mlx/nn/_tilelang/sparse_mla_blockscaled_path_c.py currently carries a partial e8m0 receipt at C/B 0.4364 vs Path B's hand-tuned MSL, but it cannot become a full Path C dispatch because the DSL has no first-class block-scale layout primitive. With this PR:

@T.prim_func
def blockscaled_attn(...):
    layout = T.BlockScaledLayout.e8m0_k32()
    T.fp8_scaled_matmul(A_fp8, A_scale_e8m0, B_fp8, B_scale_e8m0, C, layout=layout)

works directly, and the lowered MSL emits the pow(2, scale - 127) decode + zero-byte handling per-K-block.
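As a reference for what the lowered kernel computes, a plain-Python model of the e8m0 block-scaled contraction (names and list-of-lists inputs are illustrative; the real operands are FP8 tensors dequantized per element before scaling):

```python
def e8m0_decode(byte: int) -> float:
    # Contract: 0 and 0xFF decode to zero, else pow(2, byte - 127).
    return 0.0 if byte in (0, 0xFF) else 2.0 ** (byte - 127)

def blockscaled_matmul_ref(A, A_scale, B, B_scale):
    """C[i][j] = sum_k A[i][k] * B[k][j] * sa(kb) * sb(j, kb), kb = k // 32.

    A: M x K floats (already-dequantized FP8), A_scale: (K/32,) e8m0 bytes,
    B: K x N floats, B_scale: N x (K/32) e8m0 bytes.
    """
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                kb = k // 32  # one scale per 32 contracted-K values
                sa = e8m0_decode(A_scale[kb])
                sb = e8m0_decode(B_scale[j][kb])
                C[i][j] += A[i][k] * B[k][j] * sa * sb
    return C
```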

Files

  • tilelang/language/__init__.py — re-exports BlockScaledLayout and e8m0_to_float
  • tilelang/language/blockscaled_layout.py — new file, BlockScaledLayout.e8m0_k32() + e8m0_to_float helper
  • tilelang/language/fp8_op.py — extends T.fp8_scaled_matmul to accept layout objects
  • tilelang/tileop/metal_quant.py — Metal lowering for the e8m0 decode
  • testing/python/cpu/test_blockscaled_e8m0_layout.py — IR-level test coverage

Stacking

Stacks on PR #2142 (T.fp8_scaled_matmul DSL intrinsic + Metal lowering) which is itself stacked on PR #2130 (jorgecurious metal-gemm-upstream-rebase). The branch contains 2 commits:

  1. [prereq] tilelang: T.fp8_scaled_matmul DSL intrinsic + Metal lowering (exact PR #2142 content)
  2. [Metal] add T.BlockScaledLayout.e8m0_k32 + T.e8m0_to_float (blockscaled FP8) — this PR's contribution

Independent of PR #2146 (FP8 scaled-matmul fused scheduler) — the two address different gaps.

Test plan

cd tilelang
mkdir -p build && cd build
cmake .. -DTL_LLVM_VERSION=21
ninja -j8 tvm_runtime
cd ..
pytest testing/python/cpu/test_blockscaled_e8m0_layout.py testing/python/cpu/test_fp8_scaled_matmul_lowering.py -v

Local probe at cppmega.mlx/docs/upstream/tilelang_metal_blockscaled_e8m0/test_blockscaled_e8m0_probe.py validates 9/9 source-level invariants (DSL surface, sentinel decode, K/32 indexing, README contract).

Caveats

  • This PR adds the layout + decode primitive. The full Sparse-MLA blockscaled Path C dispatch in cppmega.mlx will mirror its BF16 partial-backward contract once this lands; that follow-up is downstream consumer work.
  • block_size=32 is the only e8m0 block size shipped in this PR; other block sizes can extend the same BlockScaledLayout factory pattern.
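The factory pattern the caveat refers to could be extended roughly like this (field names and the extra `e8m0_k16` factory are assumptions for illustration, not the shipped class in tilelang/language/blockscaled_layout.py):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BlockScaledLayout:
    # Illustrative fields; the shipped dataclass may differ.
    scale_format: str
    block_size: int

    @classmethod
    def e8m0_k32(cls) -> "BlockScaledLayout":
        # The only block size shipped in this PR.
        return cls(scale_format="e8m0", block_size=32)

    @classmethod
    def e8m0_k16(cls) -> "BlockScaledLayout":
        # Hypothetical follow-up factory for a smaller block size.
        return cls(scale_format="e8m0", block_size=16)

    def block_index(self, k: int) -> int:
        # kb = k // block_size, matching the K/32 indexing for e8m0_k32.
        return k // self.block_size
```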

Attribution

Co-developed with cppmega.mlx for Apple-Silicon Metal MLA kernel ports.

Summary by CodeRabbit

  • New Features

    • Metal (MPS) GEMM kernel support with float16 inputs and float32 accumulation.
    • FP8 scaled matrix multiplication with flexible scale distribution options.
    • Block-scaled layout support for quantized operations.
    • Metal code generation backend for custom kernel compilation.
    • Metal SIMD-group intrinsics for tile-based operations.
  • Tests

    • Added Metal GEMM, FP8 quantization, and block-scaled layout test suites.
    • Metal code generation and runtime validation tests.

oraluben and others added 12 commits April 30, 2026 01:43
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations
(simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon
(M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with
  simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared×shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM
  accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
Copilot AI review requested due to automatic review settings May 4, 2026 15:16

coderabbitai Bot commented May 4, 2026

Caution

Review failed

Pull request was closed or merged during review

📝 Walkthrough

Walkthrough

This PR introduces comprehensive Metal (MPS) backend support to TileLang, enabling GPU-accelerated GEMM, quantized matmul, and GDN operations on Apple Silicon. It adds Metal code generation, SIMD-group matrix abstractions, FP8 scaled quantization support, JIT compilation, and extensive testing infrastructure.

Changes

Metal Backend Infrastructure & GEMM Support

Layer / File(s) Summary
Core Metal Codegen
src/target/codegen_metal.{cc,h}
New CodeGenTileLangMetal class generates Metal Shading Language from TileLang kernels, handling kernel function signatures, threadgroup dimensions, address spaces, simdgroup matrix operations, and buffer management. Registers "target.build.tilelang_metal" FFI entry point.
Metal Build & Target Config
src/backend/metal/CMakeLists.txt
Early-return on non-Apple hosts when USE_METAL is enabled, preventing runtime Metal source discovery on unsupported platforms.
Compiler Pass Integration
tilelang/engine/phase.py, tilelang/engine/lower.py
Inserts MetalFragmentToSimdgroup pass before layout inference to rewrite fragment-scoped GEMM accumulators to metal.simdgroup. Routes Metal device codegen through new target.build.tilelang_metal FFI entry.
Fragment-to-Simdgroup Transform
tilelang/transform/metal_fragment_to_simdgroup.py
New prim-function pass that collects GEMM accumulator buffers in local.fragment scope and remaps them to metal.simdgroup scope for Metal targets.
Layout Inference Relaxation
src/transform/layout_inference.cc
Permits missing fragment buffer layouts on Metal targets (since fragments lower to opaque simdgroup matrices).
GEMM Instruction Selection
src/op/gemm.{cc,h}, tilelang/tileop/gemm/inst.py
New GemmInst::kMetalSimdgroup enum variant for Metal. Selects Metal GEMM path when TargetIsMetal(target) is true. Makes per-warp row tile size (kMPerWarp) target-dependent: 8 for Metal, 16 for others.
Metal GEMM Implementation
tilelang/tileop/gemm/gemm_metal.py, tilelang/tileop/gemm/__init__.py
New GemmMetal class for Metal GEMM lowering using SIMD-group intrinsics. Computes warp partitions, derives simdgroup tiling, and emits load/mma/store operations. Handles both simdgroup and shared-memory output paths.
Metal Memory Operations
src/op/copy.{cc,h}, src/op/fill.cc
Adds CopyInst::kMetalSIMDGroup path for simdgroup-scoped copies with warp-tile-based lowering. Fill operations now detect and generate make_filled_simdgroup_matrix calls for simdgroup buffers.
Utility Helpers
src/op/utils.h, src/op/parallel.cc, tilelang/utils/language.py, tilelang/transform/decouple_type_cast.py
Adds IsSIMDGroupBuffer, IsRegisterBuffer predicates. Updates fragment-cast logic to safely handle optional fragment extraction. Treats simdgroup buffers as local/register-level.
JIT & Runtime Integration
tilelang/jit/adapter/base.py, tilelang/jit/adapter/torch/metal.py
Updates device-selection fallback to prefer MPS when CUDA unavailable. Adds get_kernel_source() to expose generated Metal kernel source.
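The fallback order described for tilelang/jit/adapter/base.py can be sketched with availability passed in explicitly (the real adapter queries torch; this helper and its name are hypothetical):

```python
def select_device(cuda_available: bool, mps_available: bool) -> str:
    # Prefer CUDA; fall back to Apple MPS; otherwise CPU.
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```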

Metal SIMD-group Abstractions & Macros

Layer / File(s) Summary
Register Tile & Vector Abstractions
tilelang/tileop/metal_simdgroup.py
Defines RegisterTile (8×8 SIMD-group fragment-backed tile) and RowVector (materialized scalar row vector) dataclasses with layout metadata. Provides core macros for allocation, memory movement (load/store), MMA operations, and row-wise reductions.
Intrinsic Emitter
tilelang/intrinsics/metal_macro_generator.py
MPSIntrinEmitter class emits warp-scoped TileLang macros for SIMD-group operations: ldmatrix_a/b (with transpose), mma (multiply-accumulate), and simdgroup_copy (load/store).
Quantization Helpers
tilelang/tileop/metal_quant.py
Tile shape selection and FP8/FP4/E8M0 decoding primitives (fp8_e4m3fn_to_float, fp4_e2m1fn_to_float, e8m0_to_float) using TileLang bit-casting.
GDN/Attention Macros
tilelang/tileop/metal_gdn.py
Higher-level GDN KKT score tiling, gating, and W/U staged computation macros (kkt_score_tile, apply_kkt_gate_triangular_tile, wu_score_tiles).

FP8 Quantization Support

Layer / File(s) Summary
Block-Scale Layout Metadata
tilelang/language/blockscaled_layout.py
BlockScaledLayout frozen dataclass encapsulates E8M0 block-scale format (K/32 block size), validates scale tensor shapes, computes block indices, and provides byte decoding.
FP8 Scaled Matmul Macro
tilelang/language/fp8_op.py
Hygienic TileLang macro T.fp8_scaled_matmul performs per-element FP8→float32 dequant and fused-scale matmul with per-tensor/per-row/per-col scaling or optional block-scale layout. Supports transposed-B variant.
Language Exports
tilelang/language/__init__.py
Re-exports BlockScaledLayout, e8m0_to_float, and fp8_scaled_matmul for user-facing API.

Testing & Benchmarking

Layer / File(s) Summary
CPU Unit Tests
testing/python/cpu/test_blockscaled_e8m0_layout.py, testing/python/cpu/test_fp8_scaled_matmul_lowering.py
Validate block-scale layout metadata, FP8 dequant lowering, scale shape validation, and error handling without GPU requirements.
Metal Integration Tests
testing/python/metal/test_metal_gemm_v2.py, testing/python/metal/test_metal_gemm_v2_linux.py, testing/python/metal/test_metal_simdgroup_store.py, testing/python/metal/test_metal_local_var.py
Cross-platform (Linux codegen, macOS runtime) GEMM correctness, simdgroup store codegen validation, and local-variable scalar code generation.
FP8 Scaled Matmul on Metal
testing/python/metal/test_fp8_scaled_matmul_metal.py
End-to-end FP8 scaled matmul with offline MSL compilation, Torch parity vs synthesized FP32 references, optional audiohacking comparison, and benchmarks.
Internal Scaffolding & Coverage
testing/python/metal/test_metal_internal_scaffolding.py, testing/python/metal/metal_internal_runtime_coverage.md
Internal-only probe kernels for register tiles, row vectors, packed quantization, GDN operations; source-boundary validation; runtime parity vs CPU/Torch references; fail-closed native FP8/FP4 storage checks; opt-in micro-benchmarks.
JIT Adapter Tests
testing/python/jit/test_tilelang_jit_adapter_mps.py
Validate MPS device selection when CUDA unavailable.
Standalone Benchmark
benchmark/matmul_metal/benchmark_matmul_metal.py
Standalone Metal GEMM benchmark script with block-config sweep capability and TFLOPS reporting.
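For reference, TFLOPS reporting for a dense GEMM conventionally uses the 2·M·N·K op count (function name hypothetical, not the benchmark script's API):

```python
def gemm_tflops(m: int, n: int, k: int, seconds: float) -> float:
    # A dense GEMM performs 2*m*n*k FLOPs (one multiply + one add per MAC).
    return (2.0 * m * n * k) / seconds / 1e12
```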

Dependency & Configuration Updates

Layer / File(s) Summary
Platform-Specific Constraints
pyproject.toml, requirements.txt, requirements-dev.txt
Added Darwin/macOS-specific upper bound for apache-tvm-ffi<0.1.8 to address platform compatibility.

Sequence Diagrams

sequenceDiagram
    actor User as Application
    participant Compiler as TileLang Compiler
    participant MetalGen as Metal CodeGen
    participant JIT as JIT Adapter
    participant MPS as MPS Runtime

    User->>Compiler: Define kernel (T.prim_func + GEMM/FP8Matmul)
    activate Compiler
    
    Compiler->>Compiler: Lower phase<br/>(fragment→simdgroup transform)
    Compiler->>Compiler: Layout inference<br/>(Metal special-case)
    Compiler->>Compiler: Select Metal GEMM<br/>(GemmInst::kMetalSimdgroup)
    Compiler->>Compiler: Lower GEMM<br/>(GemmMetal→simdgroup ops)
    
    deactivate Compiler
    
    Compiler->>MetalGen: Invoke target.build.tilelang_metal<br/>(CodeGenTileLangMetal)
    activate MetalGen
    
    MetalGen->>MetalGen: Emit kernel void + args struct
    MetalGen->>MetalGen: Emit threadgroup dimensions
    MetalGen->>MetalGen: Generate simdgroup ops<br/>(load/store/mma)
    
    deactivate MetalGen
    MetalGen-->>Compiler: Metal Shading Language source
    
    Compiler-->>JIT: Lowered module + MSL source
    
    JIT->>JIT: Compile MSL via Xcode toolchain<br/>(xcrun metal -c)
    JIT->>JIT: Create MPS command buffer<br/>+ setup kernel arguments
    
    JIT-->>User: Return compiled JIT kernel
    
    User->>User: Prepare PyTorch input tensors<br/>(on device="mps")
    User->>MPS: Execute kernel(A, B, C)<br/>(via JIT adapter)
    
    activate MPS
    MPS->>MPS: SIMD-group threads load A/B<br/>(simdgroup_load)
    MPS->>MPS: Warp-level matmul<br/>(simdgroup_multiply_accumulate)
    MPS->>MPS: SIMD-group store C<br/>(simdgroup_store)
    MPS->>MPS: Synchronize
    deactivate MPS
    
    MPS-->>User: Output tensor C on MPS device
sequenceDiagram
    actor User as Application
    participant Compiler as TileLang Compiler
    participant FP8Macro as FP8 Scaled Matmul Macro
    participant Lowering as Metal Lowering
    participant MPS as MPS Runtime

    User->>Compiler: Define kernel with<br/>T.fp8_scaled_matmul(A_fp8, B_fp8, scales, C_out)
    activate Compiler
    
    Compiler->>FP8Macro: Validate FP8 dtypes + scale shapes
    activate FP8Macro
    
    FP8Macro->>FP8Macro: Determine scale indexing pattern<br/>(per-tensor, per-row, per-col,<br/>or block-scale)
    FP8Macro->>FP8Macro: Emit parallel (i,j) loop<br/>+ serial K loop
    FP8Macro->>FP8Macro: Emit per-element:<br/>dequant(A_fp8[i,k]),<br/>dequant(B_fp8[k,j]),<br/>fused_scale multiplication,<br/>C_out[i,j] accumulation
    
    deactivate FP8Macro
    FP8Macro-->>Compiler: Expanded TIR with float32 ops
    
    Compiler->>Lowering: Lower TIR to Metal backend<br/>(via GEMM path if applicable)
    activate Lowering
    
    Lowering->>Lowering: Emit simdgroup-compatible<br/>matmul kernels
    Lowering->>Lowering: Avoid simdgroup_multiply_accumulate<br/>for FP8 (use scalar accum instead)
    
    deactivate Lowering
    Lowering-->>Compiler: Metal source with<br/>FP8 dequant + matmul
    
    Compiler-->>User: Return compiled kernel
    
    User->>User: Prepare FP8 + scale tensors on MPS
    User->>MPS: Execute FP8 scaled matmul kernel
    
    activate MPS
    MPS->>MPS: Load FP8 elements
    MPS->>MPS: Dequantize to float32<br/>per-element
    MPS->>MPS: Apply broadcast scales
    MPS->>MPS: Accumulate into output
    deactivate MPS
    
    MPS-->>User: Return float32 output on MPS

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Suggested Reviewers

  • SiriusNEO
  • LeiWang1999

Poem

🐰 Hops excitedly, twitching whiskers

Shiny Metal for the Apple fold,
SIMD-group tiles of precious gold,
FP8 scales dance in quantized grace,
Fuzzy rabbits now have a speedy place!



github-actions Bot commented May 4, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

apstenku123 added a commit to DatasunriseOU/cppmega_mlx that referenced this pull request May 4, 2026
After re-authoring against the current PR #2142 macro shape (the previous probe-failed drafts targeted a non-existent tileop scheduler hierarchy), both patches now apply cleanly on the jorgecurious metal-gemm-upstream-rebase + #2142 prereq stack.

Filed:
- PR tile-ai/tilelang#2146 (Path C tracker B): fused FP8 scale broadcast
  into T.fp8_scaled_matmul K-loop. 16/8 LOC delta in tilelang/language/
  fp8_op.py. Closes the 3-6× audiohacking perf gap on FP8 scaled matmul
  per the cppmega.mlx Path C consumer at fp8_vecmat_path_c.py.
- PR tile-ai/tilelang#2147 (Path C tracker C): T.BlockScaledLayout.e8m0_k32
  + T.e8m0_to_float DSL primitive. 5 files touched (tilelang/language/
  blockscaled_layout.py new, fp8_op.py extended, __init__.py re-export,
  metal_quant.py Metal lowering, e8m0 layout test). Unblocks Sparse-MLA
  blockscaled Path C QK reducer.

Both stack on PR #2142 (T.fp8_scaled_matmul intrinsic) which stacks on
PR #2130 (jorgecurious base). Independent of each other — different gaps,
different files (B touches the macro body, C adds the layout primitive).

Receipt _filed_prs_2026_05_04.md updated with rows 13-14.

Total filed PRs: 14 (across ml-explore/mlx, apache/tvm, tile-ai/tilelang,
tile-ai/tvm). All OPEN.

Path C tracker A (pipelined_32x32) shipped in commit 3cb6457 + 6746ff9.
Path C tracker B (#2146) and C (#2147) now filed upstream. All three
Path C follow-up entries from docs/upstream/_path_c_blockers_tracker.md
have landing receipts.
apstenku123 (Author) commented:

Withdrawn by submitter: filed without local build/test verification of the lowered MSL output. Will re-verify end-to-end (TileLang ninja build + IR tests + probe parity) before re-filing.

@apstenku123 apstenku123 closed this May 4, 2026
@apstenku123 apstenku123 deleted the cppmega/metal-add-e8m0-blockscaled-layout branch May 4, 2026 15:19

Copilot AI left a comment


Pull request overview

This stacked PR adds the Metal-side scaffolding needed for simdgroup GEMM / FP8 lowering and then extends the TileLang FP8 DSL with a first-class e8m0 K/32 block-scale layout for blockscaled FP8 matmul.

Changes:

  • Adds Metal backend/codegen, simdgroup helpers, lowering passes, and JIT/runtime plumbing.
  • Adds T.fp8_scaled_matmul, T.BlockScaledLayout.e8m0_k32(), and T.e8m0_to_float().
  • Adds broad Metal/CPU test coverage, docs, and a Metal GEMM benchmark script.

Reviewed changes

Copilot reviewed 42 out of 43 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
tilelang/utils/language.py Adds Metal simdgroup scope helper.
tilelang/transform/metal_fragment_to_simdgroup.py New pass rewriting fragment accumulators to Metal simdgroup storage.
tilelang/transform/decouple_type_cast.py Treats metal.simdgroup as register-local for cast splitting.
tilelang/tileop/metal_simdgroup.py Adds internal Metal register-tile/simdgroup helper macros.
tilelang/tileop/metal_quant.py Adds packed quant decode helpers, including e8m0_to_float.
tilelang/tileop/metal_gdn.py Adds internal Metal GDN/attention helper macros.
tilelang/tileop/gemm/inst.py Adds Metal simdgroup GEMM instruction enum.
tilelang/tileop/gemm/gemm_metal.py New Metal GEMM lowering implementation.
tilelang/tileop/gemm/__init__.py Wires Metal GEMM selection into Python dispatch.
tilelang/language/fp8_op.py Adds FP8 scaled matmul macro and block-scale dispatch.
tilelang/language/blockscaled_layout.py New public block-scale layout metadata object.
tilelang/language/__init__.py Re-exports new FP8/blockscaled language APIs.
tilelang/jit/adapter/torch/metal.py Exposes Metal kernel source through the Torch adapter.
tilelang/jit/adapter/base.py Prefers MPS device selection when CUDA is unavailable/fails.
tilelang/intrinsics/metal_macro_generator.py Adds Metal simdgroup intrinsic emitter.
tilelang/engine/phase.py Inserts Metal fragment-to-simdgroup rewrite into lowering.
tilelang/engine/lower.py Switches Metal codegen entrypoint to TileLang-specific builder.
testing/python/metal/test_metal_simdgroup_store.py Adds simdgroup store codegen/runtime tests.
testing/python/metal/test_metal_local_var.py Adds focused Metal local.var tests.
testing/python/metal/test_metal_internal_scaffolding.py Adds broad internal Metal source/runtime scaffold tests.
testing/python/metal/test_metal_gemm_v2.py Adds Metal GEMM runtime correctness tests.
testing/python/metal/test_metal_gemm_v2_linux.py Adds cross-platform Metal GEMM codegen tests.
testing/python/metal/test_fp8_scaled_matmul_metal.py Adds Metal FP8 scaled matmul lowering/runtime/bench tests.
testing/python/metal/metal_internal_runtime_coverage.md Documents internal Metal runtime coverage.
testing/python/jit/test_tilelang_jit_adapter_mps.py Adds MPS device-selection tests for JIT adapter.
testing/python/cpu/test_fp8_scaled_matmul_lowering.py Adds CPU/IR tests for FP8 scaled matmul lowering.
testing/python/cpu/test_blockscaled_e8m0_layout.py Adds CPU contract tests for e8m0 block-scale layout.
src/transform/lower_device_storage_access_info.cc Exempts fragment scope from memory-info lookup.
src/transform/layout_inference.cc Relaxes fragment layout validation on Metal.
src/target/codegen_metal.h Declares TileLang Metal codegen class.
src/target/codegen_metal.cc Adds TileLang Metal source/codegen implementation.
src/op/utils.h Adds Metal simdgroup/register buffer helpers.
src/op/parallel.cc Guards fragment-layout lookup in parallel op inference.
src/op/gemm.h Adds Metal GEMM enum value/stringification.
src/op/gemm.cc Adds Metal GEMM instruction selection/partition tweaks.
src/op/fill.cc Adds simdgroup fill lowering.
src/op/copy.h Adds Metal simdgroup copy instruction/lowering hooks.
src/op/copy.cc Adds Metal simdgroup copy detection and lowering.
src/backend/metal/CMakeLists.txt Builds Metal codegen and adjusts non-Apple behavior.
requirements.txt Caps apache-tvm-ffi on Darwin.
requirements-dev.txt Caps apache-tvm-ffi on Darwin for dev installs.
pyproject.toml Adds Darwin-specific apache-tvm-ffi cap to package metadata.
benchmark/matmul_metal/benchmark_matmul_metal.py Adds Metal GEMM benchmark script.


Comment on lines +191 to +196
def _block_scale_value(scale, *, axis: str, col, k):
    # Path C E8M0 is explicitly contracted-K-block indexed: kb = k // 32.
    kb = k // 32
    if axis == "B" and len(getattr(scale, "shape", ())) == 2:
        return e8m0_to_float(scale[col, kb])
    return e8m0_to_float(scale[kb])
Comment on lines +173 to +174
if target_is_metal(target):
    return GemmInst.METAL_SIMDGROUP
Comment thread src/op/gemm.cc
} else if (TargetIsCPU(target)) {
return GemmInst::kScalar;
} else if (TargetIsMetal(target)) {
return GemmInst::kMetalSimdgroup;
Comment on lines +82 to +89
new_body = tir.stmt_functor.substitute(stmt.body, var_map)
new_block = tir.Block(
    stmt.iter_vars,
    stmt.reads,
    stmt.writes,
    stmt.name_hint,
    new_body,
    stmt.init,
Comment on lines 440 to +445
for (const auto &[buffer, _] : use_list_) {
if (IsFragmentBuffer(buffer)) {
ICHECK_NE(layout_map.count(buffer), 0)
<< "The layout for fragment " << buffer
<< " can not be inferred correctly.";
if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
ICHECK(false) << "The layout for fragment " << buffer
<< " can not be inferred correctly.";
}
Comment on lines +3 to +8
# Metal codegen is pure C++ and can generate Metal shader source on any
# platform. Always compile it so target.build.tilelang_metal is available for
# cross-compilation and source-level tests on non-Apple hosts.
list(APPEND TILE_LANG_SRCS
src/target/codegen_metal.cc
)


def _e8m0_to_float_cpu(bits: int) -> float:
    return 0.0 if bits == 255 else 2.0 ** (bits - 127)
Comment on lines +152 to +153
assert "byte == 0" in text
assert "byte == 0xFF" in text
Comment on lines +346 to +352
if block_scale_layout is not None:
    sa = _block_scale_value(A_scale, axis="A", col=j, k=k)
    sb = _block_scale_value(B_scale, axis="B", col=j, k=k)
else:
    sa = A_scale[0] if sa_size == 1 else A_scale[i]
    sb = B_scale[0] if sb_size == 1 else B_scale[j]
C_local[i, j] = C_local[i, j] + a_val * b_val * sa * sb