[Metal] FP8 storage-only emulation (uchar storage + LUT decode helpers) #38
Open
apstenku123 wants to merge 1 commit into tile-ai:tilelang_main
apstenku123 added a commit to DatasunriseOU/cppmega_mlx that referenced this pull request on May 4, 2026:

…, #37/#38/#39)

Three parallel agents completed the supermodule/submodule split filing:

1. tilelang_metal_fp8 (storage-only FP8 emulation) split:
   - 0001-tilelang-metal-fp8-storage-only.patch: supermodule half (235 lines)
   - 0002-tvm-metal-fp8-storage-only.patch: TVM-mirror half (260 lines, prefix stripped)
   - PR tile-ai/tilelang#2144 (supermodule, stacks on PR #2130)
   - PR tile-ai/tvm#38 (TVM mirror, base tilelang_main @ 0e15b274)
2. tilelang_metal_fp8_vector (vector cast lanes 2/3/4) split:
   - 0001-tilelang-metal-fp8-vector-cast.patch: supermodule half (148 lines)
   - 0002-tvm-metal-fp8-vector-cast.patch: TVM-mirror half (151 lines)
   - PR tile-ai/tilelang#2145 (supermodule, depends on #2144)
   - PR tile-ai/tvm#39 (TVM mirror, depends on #38)
3. PR #2143 TVM-mirror companion:
   - PR tile-ai/tvm#37: already filed, README updated to link both halves

Total filed today: 11 PRs across 3 repos
- 1 ml-explore/mlx (#3476)
- 1 apache/tvm (#19504)
- 6 tile-ai/tilelang (#2139, #2140, #2141, #2142, #2143 super, #2144 super, #2145 super)
- 3 tile-ai/tvm (#37, #38, #39: TVM-mirror companions)

PR #2142 (T.fp8_scaled_matmul) needs no TVM-mirror companion; the patch was verified to touch only supermodule files.

All splits round-trip clean (apply forward + reverse) on their respective bases. README files in each docs/upstream/<dir>/ are updated with PR URLs and dependency-chain diagrams.

Note: TileLang/tvm redirects to tile-ai/tvm server-side (canonical org slug). All TVM-mirror PRs land at tile-ai/tvm/pull/N URLs.
Summary
This is the TVM-mirror half of a 2-PR pair adding storage-only FP8 emulation to the Metal codegen. The companion TileLang PR lives in `tile-ai/tilelang` (link below) and mirrors the same change to TileLang's `CodeGenTileLangMetal` specialisation at `src/target/codegen_metal.{cc,h}`.

Apple Silicon (M1-M4 / M5 NAX) has no native FP8 ALU support, MSL has no `float8` scalar type, and `simdgroup_matrix<uchar, ...>` fails the Metal stdlib element-type assertion. The only viable representation is "storage-only FP8": pack 8-bit values in `uchar`/`ucharN` buffers, dequantize on load to `half`, do the math in half (or float accumulator), and quantize back on store. This mirrors how mxfp8/nvfp4 are realised in MLX core, and how stock TVM's CUDA codegen handles pre-sm89 hardware (`src/target/source/codegen_cuda.cc:410-417`).

Without this patch, any `float8_e4m3`/`float8_e5m2`/`float8_e8m0fnu` dtype reaching `CodeGenMetal::PrintType` triggers `LOG(FATAL) << "Cannot convert type " << t << " to Metal type"`, which means sparse-MLA FP8, blockscaled, and mxfp8 lowering on the Metal target are unreachable. CUDA codegen has had this path for years; Metal had no equivalent.

What this PR adds
In `src/target/source/codegen_metal.{cc,h}`:

- `PrintType` FP8 case: when `t.is_float8()`, emit `uchar` for `lanes==1`, `ucharN` for `lanes∈[2,4]`, `uint2` for `lanes==8`, `uint4` for `lanes==16`. Sets `enable_fp8_=true` so the prelude is emitted. Mirrors the CUDA codegen's behaviour where FP8 vectors >4 are packed into wider integer storage.
- `PrintFP8Prelude`: inline MSL helpers `__tvm_fp8_e4m3_to_half`, `__tvm_fp8_e5m2_to_half`, `__tvm_half_to_fp8_e4m3`, `__tvm_half_to_fp8_e5m2`. Encodings follow the OCP "OFP8 Formats for Deep Learning" v1.0 spec. E4M3 uses the finite-only encoding (`S.1111.111` is NaN, no Inf); E5M2 uses IEEE-style with NaN/Inf. Both directions implement round-to-nearest-even on discarded mantissa bits.
- `VisitExpr_(CastNode)` override: when either side is FP8, scalar casts route through the helpers. Vector casts (`lanes>1`) raise a clear `LOG(FATAL)` directing the caller to scalarise. The TVM `tir.transform.legalize_fp8` pass already scalarises most user FP8 casts, so this branch is rarely hit.
- `Finish()` override: if `enable_fp8_` was set, splice the prelude right after `using namespace metal;` so the helpers see the MSL namespace.

The stock TVM Metal codegen path (`target.build.metal`) goes through `legalize_fp8` first, which expands FP8 ops into bit-shuffle code inline. With this patch its `PrintType` no longer faults, so the legalised output also compiles with `xcrun metal -c`.

Motivation: cppmega.mlx workaround
In the cppmega.mlx port (`cppmega_mlx/nn/_tilelang/fp8_msl_kernels.py`) we currently ship audiohacking-style FP8 MSL kernels as raw `mx.fast.metal_kernel(source=...)` strings to bypass this codegen FATAL. With this PR landed in the vendored TVM mirror and bumped into TileLang, those workaround kernels can be replaced by ordinary `T.cast(half, fp8_load)` chains lowered through `target.build.metal`.

Companion PR
`tile-ai/tilelang` PR: the same change applied to TileLang's `CodeGenTileLangMetal` specialisation, which duplicates the `PrintType` logic. See: [Metal] FP8 storage-only emulation (uchar storage + LUT decode helpers) tilelang#2144

Stack
This PR targets `tile-ai/tvm:tilelang_main` and is rebased on HEAD `0e15b274` (the SHA TileLang's `3rdparty/tvm` submodule pins).

Diff stat
Test plan
- `git apply --check` clean against `TileLang/tvm@0e15b274`
- `lower(prim_func, target=tvm.target.Target("metal"))`:
  - `fp8_e4m3 -> half`
  - `half -> fp8_e4m3`
  - `fp8_e5m2 -> half`
  - `half -> fp8_e5m2`
- `xcrun --sdk macosx metal -c` against any prim_func with FP8 dtype lowered to MSL with the inline helpers
- `torch.float8_e4m3fn` and `mlx.from_fp8` reference (full 256-byte e4m3 finite range)
- `float8_e8m0fnu` scale storage lowers correctly: it's a `device uchar*` buffer with no helper calls (just pass-through)
- `cmake -DUSE_METAL=ON && ninja` against the patched `src/target/source/codegen_metal.cc`

Risk
- Vector casts (`lanes>1`) raise a clear FATAL with guidance; callers route through the existing `legalize_fp8` scalarise pass.
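
For reviewers who want a host-side reference when checking the decode/encode helpers, the E4M3 and E5M2 encodings described under "What this PR adds" can be sketched in Python. This is a minimal model, not the MSL helpers from the patch: the function names `decode_e4m3`, `decode_e5m2`, `encode_e4m3` and `check_e4m3_round_trip` are hypothetical, and two behaviours are assumptions because the PR text does not state them: out-of-range magnitudes saturate to ±448, and NaN encodes as `0x7F`. Round-to-nearest-even is modelled by a brute-force nearest-value search rather than by bit manipulation.

```python
import math
import struct


def decode_e4m3(byte: int) -> float:
    """Decode one OCP E4M3 byte (bias 7, finite-only: S.1111.111 is NaN, no Inf)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:
        # Subnormal: (man / 8) * 2**(1 - 7) == man * 2**-9
        return sign * man * 2.0 ** -9
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def decode_e5m2(byte: int) -> float:
    """Decode one E5M2 byte; its layout matches the top byte of an IEEE half."""
    return struct.unpack("<e", struct.pack("<H", (byte & 0xFF) << 8))[0]


def encode_e4m3(x: float) -> int:
    """Encode a float to E4M3 with round-to-nearest-even.

    Assumptions (not stated in the PR): NaN maps to 0x7F and
    out-of-range magnitudes saturate to the largest finite value, 448.
    """
    if math.isnan(x):
        return 0x7F
    sign = 0x80 if math.copysign(1.0, x) < 0 else 0x00
    mag = min(abs(x), 448.0)
    # Search the 127 finite magnitudes (codes 0x00..0x7E). On an exact tie
    # the code with an even low bit wins, which is round-to-nearest-even
    # because adjacent magnitudes have consecutive codes.
    best = min(range(0x7F), key=lambda c: (abs(decode_e4m3(c) - mag), c & 1))
    return sign | best


def check_e4m3_round_trip() -> int:
    """Assert decode -> encode is the identity on every non-NaN code."""
    checked = 0
    for code in range(256):
        value = decode_e4m3(code)
        if math.isnan(value):
            continue
        assert encode_e4m3(value) == code, hex(code)
        checked += 1
    return checked
```

Calling `check_e4m3_round_trip()` walks all 256 byte values and skips the two NaN codes. A table built from `decode_e4m3` over `range(256)` could also serve as the expected values when comparing kernel output, under the same assumptions.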