Add MXFP8 grouped GEMM (E4M3) for routed-expert MoE training#31
Conversation
… design and feature doc
# Conflicts: # README.md # alto/kernels/dispatch/tensor.py
There was a problem hiding this comment.
Pull request overview
This PR adds an MXFP8 (E4M3) contiguous grouped GEMM implementation (forward + full autograd backward) and wires it into the dispatch layer so routed-expert MoE training can use the MXFP8 path (previously limited to Linear).
Changes:
- Introduces
alto/kernels/mxfp8/mxfp8_grouped_gemm/(Triton forward kernel, dgrad/wgrad kernels, autograd wrapper, dispatch-facing functional entry, and minimal autotune configs). - Connects the MXFP8 weight wrapper
_grouped_mmdispatch path to the new offsets-based grouped GEMM entry (with V1 guards). - Adds unit tests for kernel correctness, autograd correctness, offsets/padded-buffer behavior, plus a toy MoE multi-step training sanity check and optional repro/plot scripts/docs.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/unittest/mxfp8/utils.py | Re-exports shared SNR/cossim helpers and adds deterministic contiguous-routing indices helper for tests. |
| tests/unittest/mxfp8/test_mxfp8_grouped_gemm.py | Adds forward, autograd, and offsets/padded-buffer correctness tests for MXFP8 grouped GEMM. |
| tests/unittest/mxfp8/test_e2e_moe.py | Adds a toy MoE training loop sanity test comparing MXFP8 vs BF16 trends. |
| tests/unittest/mxfp8/repro_mxfp8_dot_scaled.py | Standalone repro script for tl.dot_scaled multi-block accuracy behavior. |
| tests/unittest/mxfp8/repro_mxfp8_dot_scaled.md | Documentation for running/interpreting the tl.dot_scaled repro script. |
| tests/unittest/mxfp8/plot_e2e_moe_curve.py | Standalone script to plot toy MoE loss curves (not a pytest test). |
| tests/unittest/compare_grouped_gemm_toy_moe.py | Standalone comparison script across grouped GEMM precisions on the same toy task. |
| README.md | Documents MXFP8 support including grouped GEMM. |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/functional.py | Exposes mxfp8_grouped_gemm(...) and offsets-based dispatch entry _quantize_then_mxfp8_scaled_grouped_mm(...). |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/cg_forward.py | Triton forward grouped GEMM kernel + launcher wrapper (supports dot_scaled and dequant fallback). |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/cg_backward.py | Triton dgrad/wgrad kernels + wrappers + MXFP8GroupedGEMM autograd.Function wiring. |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/autotune.py | Minimal v1 autotune configs and ALIGN_SIZE_M definition. |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/init.py | Exports public grouped GEMM API and dispatch entry. |
| alto/kernels/mxfp8/MXFP8_GROUPED_GEMM_PLAN.md | Design/plan doc and validation notes for the V1 grouped GEMM. |
| alto/kernels/dispatch/tensor.py | Routes MXFP8 _grouped_mm dispatch to the new offsets entry with V1 feature guards. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) | ||
| for ki in range(k_tiles): | ||
|
|
||
| offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) | ||
| offs_k_scale = ki * n_rep_k + tl.arange(0, n_rep_k) | ||
| mask_k_scale = offs_k_scale < Ks | ||
|
|
||
| mask_m = offs_m < M_TOTAL | ||
| mask_n = offs_n < N | ||
| mask_k = offs_k < K | ||
|
|
||
| mask_a = mask_m[:, None] & mask_k[None, :] | ||
| mask_b = mask_k[:, None] & mask_n[None, :] | ||
|
|
||
| # Determine the expert group index and load expert ID | ||
| group_idx = m_start // GROUP_SIZE_M | ||
| expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M) | ||
|
|
There was a problem hiding this comment.
Hi @ysa2215, could you please look into this?
The same issue might exist in both MXFP4 and NVFP4.
There was a problem hiding this comment.
About MXFP8:
Fixed. expert_idx only depends on m_start, so it's now loaded once outside the K loop instead of every K-iteration. Verified with test_mxfp8_grouped_gemm.py (52 passed).
About MXFP4/NVFP4 — I checked both:
- MXFP4 forward had the same redundant load — fixed (same hoist as mxfp8).
- MXFP4 backward already hoists expert_idx outside the inner N-loop, so no change needed.
- NVFP4 doesn't hit this pattern — its grouped GEMM uses a per-expert torch GEMM loop (no indices_ptr load inside a Triton K-loop), so there's nothing to hoist.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
| import matplotlib | ||
| matplotlib.use("Agg") | ||
| import matplotlib.pyplot as plt | ||
| import torch |
| import matplotlib | ||
| matplotlib.use("Agg") | ||
| import matplotlib.pyplot as plt | ||
| import torch |
| import torch | ||
| import triton | ||
| import triton.language as tl |
Description:
GPT-OSS-style MoE training needs a grouped GEMM to run on MXFP8 — without it the MXFP8 dispatch path errored out on routed-expert layers, restricting MXFP8 to Linear targets. This is the minimum viable kernel to get MoE training running.
V1 uses E4M3 for all operands across fwd/dgrad/wgrad. A single format keeps the kernel free of dtype dispatch and the autograd free of a separate grad_output quantization path — simplest to implement and validate.
Adds an MXFP8 contiguous grouped GEMM kernel (forward + dgrad + wgrad) and wires it into the MXFP8 dispatch path so routed-expert MoE layers can run on MXFP8.
alto/kernels/mxfp8/mxfp8_grouped_gemm/cg_forward.py— Triton persistent forward grouped GEMM with super-grouping schedule and contiguous index routing, adapted from the MXFP4 scaffold with all K-packing removed (MXFP8 is one element per byte).cg_backward.py— full backward (dgrad + wgrad) wrapped in anMXFP8GroupedGEMM(autograd.Function).functional.py— user-facingmxfp8_grouped_gemm(...)plus the dispatch-layer entry_quantize_then_mxfp8_scaled_grouped_mm(...), which mirrors the MXFP4_quantize_then_mxfp_scaled_grouped_mmcontract (1-D cumulativeoffs, padded activation buffers, weights kept in[E, K, N]dispatch convention to avoid a transpose copy).autotune.py,__init__.py.alto/kernels/dispatch/tensor.py): the MXFP8 weight wrapper's_grouped_mmpath previously raisedNotImplementedError; it now routes 2d-activation × 3d-weight + offsets calls to the new kernel, with guards limitingV1 to
mxfp8_e4m3and rejecting Hadamard/DGE/bias.tests/unittest/mxfp8/): grouped GEMM correctness with SNR gates,use_2dblockcoverage, and a toy-MoE end-to-end training sanity check (+ loss-curve plotting).MXFP8_GROUPED_GEMM_PLAN.mddesign/plan doc; m355 test results; README entry for MXFP8 (linear + grouped GEMM).Test plan
tests/unittest/mxfp8/test_mxfp8_grouped_gemm.py(SNR gates pass)tests/unittest/mxfp8/test_e2e_moe.pytoy-MoE training sanity