
Add MXFP8 grouped GEMM (E4M3) for routed-expert MoE training #31

Merged
hann-wang merged 18 commits into main from yue/mxfp8-grouped-gemm on Jun 26, 2026

@ysa2215 ysa2215 commented Jun 18, 2026

Collaborator

Description:

GPT-OSS-style MoE training needs a grouped GEMM that runs on MXFP8. Without one, the MXFP8 dispatch path raised NotImplementedError on routed-expert layers, which restricted MXFP8 to Linear targets. This PR adds the minimum viable kernel to get MoE training running.

V1 uses E4M3 for all operands across fwd/dgrad/wgrad. A single format keeps the kernel free of dtype dispatch and the autograd path free of a separate grad_output quantization step, which makes it the simplest variant to implement and validate.
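As a rough illustration of what "E4M3 for all operands" implies numerically, the sketch below computes one power-of-two shared scale per block, following the conversion recipe in the OCP Microscaling spec (shared exponent = floor(log2(amax)) minus the element format's maximum exponent). The block size of 32 comes from that spec and is assumed here rather than taken from this PR; the function names are hypothetical, and mantissa rounding to E4M3 is deliberately left out.

```python
import math

E4M3_MAX = 448.0   # largest finite E4M3 magnitude (1.75 * 2**8)
E4M3_EMAX = 8      # exponent of that largest value
BLOCK = 32         # MX block size per the OCP spec (assumed for this kernel)


def mx_block_scales(row):
    """Return one power-of-two scale per BLOCK-element slice of `row`."""
    scales = []
    for start in range(0, len(row), BLOCK):
        amax = max(abs(v) for v in row[start:start + BLOCK])
        if amax == 0.0:
            scales.append(1.0)  # any scale works for an all-zero block
            continue
        shared_exp = math.floor(math.log2(amax)) - E4M3_EMAX
        scales.append(2.0 ** shared_exp)
    return scales


def scale_and_clamp(row, scales):
    """Divide each element by its block scale and clamp to the E4M3 range.

    Rounding the mantissa to 3 bits is omitted in this sketch.
    """
    out = []
    for i, v in enumerate(row):
        s = scales[i // BLOCK]
        out.append(max(-E4M3_MAX, min(E4M3_MAX, v / s)))
    return out


row = [0.5] * 32 + [1000.0] * 32
scales = mx_block_scales(row)          # one scale per 32-element block
scaled = scale_and_clamp(row, scales)  # values now fit the E4M3 range
```

With a floor-based exponent the largest element of a block can land above 448 after scaling (1000 / 2 = 500 here), so the clamp is what keeps it representable.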

Adds an MXFP8 contiguous grouped GEMM kernel (forward + dgrad + wgrad) and wires it into the MXFP8 dispatch path so routed-expert MoE layers can run on MXFP8.

  • New kernel package alto/kernels/mxfp8/mxfp8_grouped_gemm/
    • cg_forward.py — Triton persistent forward grouped GEMM with super-grouping schedule and contiguous index routing, adapted from the MXFP4 scaffold with all K-packing removed (MXFP8 is one element per byte).
    • cg_backward.py — full backward (dgrad + wgrad) wrapped in an MXFP8GroupedGEMM(autograd.Function).
    • functional.py — user-facing mxfp8_grouped_gemm(...) plus the dispatch-layer entry _quantize_then_mxfp8_scaled_grouped_mm(...), which mirrors the MXFP4 _quantize_then_mxfp_scaled_grouped_mm contract (1-D cumulative offs, padded activation buffers, weights kept in [E, K, N] dispatch convention to avoid a transpose copy).
    • autotune.py, __init__.py.
  • Dispatch wiring (alto/kernels/dispatch/tensor.py): the MXFP8 weight wrapper's _grouped_mm path previously raised NotImplementedError; it now routes 2-D activation × 3-D weight + offsets calls to the new kernel, with guards limiting V1 to mxfp8_e4m3 and rejecting Hadamard/DGE/bias.
  • Tests (tests/unittest/mxfp8/): grouped GEMM correctness with SNR gates, use_2dblock coverage, and a toy-MoE end-to-end training sanity check (+ loss-curve plotting).
  • Docs: MXFP8_GROUPED_GEMM_PLAN.md design/plan doc; m355 test results; README entry for MXFP8 (linear + grouped GEMM).
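
To make the dispatch contract concrete, here is a minimal pure-Python reference of a contiguous grouped matmul. It assumes `offs` holds the cumulative row end position of each expert (the 1-D cumulative offs described above) and that weights are laid out as [E, K, N]. The function name and the zero-fill of padded rows are illustrative assumptions, not a statement about what the PR's kernel writes there, and quantization is omitted entirely.

```python
def grouped_mm_reference(x, w, offs):
    """Reference grouped matmul.

    x:    [M_total][K] activation rows, sorted so each expert's rows are contiguous
    w:    [E][K][N] per-expert weights
    offs: cumulative row end per expert; rows at or past offs[-1] are padding
    """
    m_total = len(x)
    n = len(w[0][0])
    out = [[0.0] * n for _ in range(m_total)]  # padded rows stay zero
    start = 0
    for e, end in enumerate(offs):
        for m in range(start, end):
            for j in range(n):
                out[m][j] = sum(x[m][k] * w[e][k][j] for k in range(len(x[m])))
        start = end
    return out


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.0, 9.0]]  # last row is padding
w = [
    [[1.0, 0.0], [0.0, 1.0]],  # expert 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # expert 1: doubles its input
]
offs = [1, 3]  # expert 0 owns row 0, expert 1 owns rows 1..2
y = grouped_mm_reference(x, w, offs)
```

Keeping weights as [E, K, N] means each expert's slice can be multiplied directly, which is the layout the description says avoids a transpose copy.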

Test plan

  • tests/unittest/mxfp8/test_mxfp8_grouped_gemm.py (SNR gates pass)
  • tests/unittest/mxfp8/test_e2e_moe.py toy-MoE training sanity
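
For readers unfamiliar with SNR gating, the following is a small sketch of how such a check is commonly written: signal power of the reference divided by the power of the error, in decibels. The helper name and the 25 dB threshold are hypothetical; the PR's actual helpers live in tests/unittest/mxfp8/utils.py and its thresholds are not stated here.

```python
import math


def snr_db(reference, test):
    """Signal-to-noise ratio in dB between a reference and a test vector."""
    signal = sum(r * r for r in reference)
    noise = sum((r - t) ** 2 for r, t in zip(reference, test))
    if noise == 0.0:
        return math.inf  # identical outputs
    return 10.0 * math.log10(signal / noise)


ref = [1.0, -2.0, 3.0, -4.0]
approx = [1.0, -2.0, 3.0, -3.9]   # one element is slightly off
value = snr_db(ref, approx)
passes = value >= 25.0            # hypothetical gate, not the PR's threshold
```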

@ysa2215 ysa2215 marked this pull request as ready for review June 23, 2026 04:02
Copilot AI review requested due to automatic review settings June 23, 2026 04:02

Copilot AI left a comment

Contributor

Pull request overview

This PR adds an MXFP8 (E4M3) contiguous grouped GEMM implementation (forward + full autograd backward) and wires it into the dispatch layer so routed-expert MoE training can use the MXFP8 path (previously limited to Linear).

Changes:

  • Introduces alto/kernels/mxfp8/mxfp8_grouped_gemm/ (Triton forward kernel, dgrad/wgrad kernels, autograd wrapper, dispatch-facing functional entry, and minimal autotune configs).
  • Connects the MXFP8 weight wrapper _grouped_mm dispatch path to the new offsets-based grouped GEMM entry (with V1 guards).
  • Adds unit tests for kernel correctness, autograd correctness, offsets/padded-buffer behavior, plus a toy MoE multi-step training sanity check and optional repro/plot scripts/docs.
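
The V1 guards mentioned above could look roughly like the sketch below. Everything here is hypothetical: the `GroupedMMRequest` container, the field names, and the choice of `NotImplementedError` for rejected features are assumptions for illustration, not the code in alto/kernels/dispatch/tensor.py.

```python
from dataclasses import dataclass


@dataclass
class GroupedMMRequest:  # hypothetical container, not the PR's type
    fmt: str
    act_ndim: int
    weight_ndim: int
    has_offs: bool
    use_hadamard: bool = False
    use_dge: bool = False
    has_bias: bool = False


def check_v1_supported(req):
    """Raise NotImplementedError for anything outside the V1 envelope."""
    if req.fmt != "mxfp8_e4m3":
        raise NotImplementedError(f"V1 supports mxfp8_e4m3 only, got {req.fmt}")
    if (req.act_ndim, req.weight_ndim) != (2, 3) or not req.has_offs:
        raise NotImplementedError("V1 expects 2-D activations, 3-D weights and offsets")
    unsupported = (
        (req.use_hadamard, "Hadamard"),
        (req.use_dge, "DGE"),
        (req.has_bias, "bias"),
    )
    for enabled, name in unsupported:
        if enabled:
            raise NotImplementedError(f"{name} is not supported in V1")


# A plain routed-expert call passes silently.
check_v1_supported(GroupedMMRequest("mxfp8_e4m3", act_ndim=2, weight_ndim=3, has_offs=True))
```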

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| tests/unittest/mxfp8/utils.py | Re-exports shared SNR/cossim helpers and adds deterministic contiguous-routing indices helper for tests. |
| tests/unittest/mxfp8/test_mxfp8_grouped_gemm.py | Adds forward, autograd, and offsets/padded-buffer correctness tests for MXFP8 grouped GEMM. |
| tests/unittest/mxfp8/test_e2e_moe.py | Adds a toy MoE training loop sanity test comparing MXFP8 vs BF16 trends. |
| tests/unittest/mxfp8/repro_mxfp8_dot_scaled.py | Standalone repro script for tl.dot_scaled multi-block accuracy behavior. |
| tests/unittest/mxfp8/repro_mxfp8_dot_scaled.md | Documentation for running/interpreting the tl.dot_scaled repro script. |
| tests/unittest/mxfp8/plot_e2e_moe_curve.py | Standalone script to plot toy MoE loss curves (not a pytest test). |
| tests/unittest/compare_grouped_gemm_toy_moe.py | Standalone comparison script across grouped GEMM precisions on the same toy task. |
| README.md | Documents MXFP8 support including grouped GEMM. |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/functional.py | Exposes mxfp8_grouped_gemm(...) and offsets-based dispatch entry _quantize_then_mxfp8_scaled_grouped_mm(...). |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/cg_forward.py | Triton forward grouped GEMM kernel + launcher wrapper (supports dot_scaled and dequant fallback). |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/cg_backward.py | Triton dgrad/wgrad kernels + wrappers + MXFP8GroupedGEMM autograd.Function wiring. |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/autotune.py | Minimal v1 autotune configs and ALIGN_SIZE_M definition. |
| alto/kernels/mxfp8/mxfp8_grouped_gemm/__init__.py | Exports public grouped GEMM API and dispatch entry. |
| alto/kernels/mxfp8/MXFP8_GROUPED_GEMM_PLAN.md | Design/plan doc and validation notes for the V1 grouped GEMM. |
| alto/kernels/dispatch/tensor.py | Routes MXFP8 _grouped_mm dispatch to the new offsets entry with V1 feature guards. |

Comment on lines +128 to +145
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for ki in range(k_tiles):
    offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_k_scale = ki * n_rep_k + tl.arange(0, n_rep_k)
    mask_k_scale = offs_k_scale < Ks

    mask_m = offs_m < M_TOTAL
    mask_n = offs_n < N
    mask_k = offs_k < K

    mask_a = mask_m[:, None] & mask_k[None, :]
    mask_b = mask_k[:, None] & mask_n[None, :]

    # Determine the expert group index and load expert ID
    group_idx = m_start // GROUP_SIZE_M
    expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)

Collaborator

Hi @ysa2215, could you please look into this?

The same issue might exist in both MXFP4 and NVFP4.

@ysa2215 ysa2215 Jun 25, 2026

Collaborator Author

About MXFP8:
Fixed. expert_idx only depends on m_start, so it's now loaded once outside the K loop instead of every K-iteration. Verified with test_mxfp8_grouped_gemm.py (52 passed).

About MXFP4/NVFP4 — I checked both:

  • MXFP4 forward had the same redundant load — fixed (same hoist as mxfp8).
  • MXFP4 backward already hoists expert_idx outside the inner N-loop, so no change needed.
  • NVFP4 doesn't hit this pattern — its grouped GEMM uses a per-expert torch GEMM loop (no indices_ptr load inside a Triton K-loop), so there's nothing to hoist.
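
The hoist described in this reply can be modeled in plain Python. The sketch below is not the Triton kernel: `CountingIndices` is a hypothetical stand-in for `indices_ptr` that counts reads, and the group size and row layout are made up for the demo. It only shows that a load depending solely on `m_start` returns the same value whether it sits inside or outside the K loop, while the number of reads drops from one per K tile to one per output tile.

```python
GROUP_SIZE_M = 4  # rows per expert-aligned group (value chosen for the demo)


class CountingIndices:
    """Stand-in for indices_ptr that counts how often it is read."""

    def __init__(self, values):
        self.values = values
        self.loads = 0

    def load(self, pos):
        self.loads += 1
        return self.values[pos]


def tile_before(indices, m_start, k_tiles):
    expert_idx = None
    for _ in range(k_tiles):
        group_idx = m_start // GROUP_SIZE_M
        expert_idx = indices.load(group_idx * GROUP_SIZE_M)  # reloaded every K step
        # K-dependent loads and the dot product would go here
    return expert_idx


def tile_after(indices, m_start, k_tiles):
    group_idx = m_start // GROUP_SIZE_M
    expert_idx = indices.load(group_idx * GROUP_SIZE_M)  # loaded once per tile
    for _ in range(k_tiles):
        pass  # K-dependent loads and the dot product would go here
    return expert_idx


rows = [0, 0, 0, 0, 2, 2, 2, 2]  # expert ID per row, constant within a group
before, after = CountingIndices(rows), CountingIndices(rows)
same_expert = tile_before(before, 4, 8) == tile_after(after, 4, 8)
```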

Comment thread alto/kernels/mxfp8/mxfp8_grouped_gemm/functional.py
Comment thread tests/unittest/mxfp8/test_e2e_moe.py Outdated
Copilot AI review requested due to automatic review settings June 25, 2026 08:47

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 16 out of 16 changed files in this pull request and generated 3 comments.

Comment thread tests/unittest/mxfp8/test_mxfp8_grouped_gemm.py
Comment thread tests/unittest/mxfp8/test_e2e_moe.py
Comment thread tests/unittest/mxfp8/repro_mxfp8_dot_scaled.py
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Copilot AI review requested due to automatic review settings June 26, 2026 03:29
ysa2215 and others added 2 commits June 26, 2026 11:30
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 16 out of 16 changed files in this pull request and generated 4 comments.

Comment on lines +18 to +21
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
Comment on lines +31 to +34
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
Comment on lines +1 to +3
import torch
import triton
import triton.language as tl
Comment thread tests/unittest/mxfp8/test_e2e_moe.py
Copilot AI review requested due to automatic review settings June 26, 2026 03:35

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 16 out of 16 changed files in this pull request and generated 1 comment.

Comment thread alto/kernels/dispatch/tensor.py
@hann-wang hann-wang merged commit 2543fc0 into main Jun 26, 2026
1 check passed
3 participants