Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Training-oriented kernels and schemes include:

- **[Blockwise FP8](alto/kernels/blockwise_fp8)** — linear, grouped GEMM, and FlashAttention.
- **[MXFP4](alto/kernels/fp4/mxfp4)** — linear, grouped GEMM, and FlashAttention.
- **[MXFP8](alto/kernels/mxfp8)** — linear and grouped GEMM (block-scaled E4M3, with E5M2 reserved for gradients).
- **[NVFP4](alto/kernels/fp4/nvfp4)** — linear and grouped GEMM, using an E4M3 inner-block scale with an optional two-level (tensorwise) outer scale.

Techniques used to narrow the gap versus BF16 include:
Expand Down
32 changes: 30 additions & 2 deletions alto/kernels/dispatch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from alto.kernels.fp4.nvfp4.nvfp_grouped_gemm.functional import (
_quantize_then_nvfp4_scaled_grouped_mm,)
from alto.kernels.mxfp8.mxfp8_linear import _to_mxfp8_then_scaled_mm
from alto.kernels.mxfp8.mxfp8_grouped_gemm import _quantize_then_mxfp8_scaled_grouped_mm
from .config import TrainingOpConfig

aten = torch.ops.aten
Expand Down Expand Up @@ -406,8 +407,35 @@ class MXFP8TrainingWeightWrapperTensor(TrainingWeightWrapperBaseTensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
if func.__name__ == "_grouped_mm":
Comment thread
hann-wang marked this conversation as resolved.
raise NotImplementedError("MXFP8 _grouped_mm is not supported by this dispatch path; "
"restrict MXFP8 schemes to Linear targets.")
# Routed-expert MoE path: 2d activations x 3d weights with offsets.
A, B = args[0], args[1]
bias = kwargs.get("bias", None)
offs = kwargs.get("offs", None)

assert not isinstance(A, cls), f"A should not be a {cls.__name__}"
assert isinstance(B, cls), f"B should be a {cls.__name__}"
assert A.ndim == 2 and B.ndim == 3 and offs is not None, (
"Only 2d x 3d with offsets is supported for MXFP8 grouped_mm"
)
assert bias is None, "Bias is not supported for grouped_mm"

config = B.config
assert config.precision == "mxfp8_e4m3", (
"MXFP8 grouped_mm V1 supports only mxfp8_e4m3; "
f"got {config.precision} (e5m2 grouped path is not yet validated)"
)
assert not config.use_hadamard and not config.use_dge, (
"MXFP8 grouped_mm V1 does not support Hadamard or DGE options."
)

return _quantize_then_mxfp8_scaled_grouped_mm(
A,
B,
offs=offs,
use_2dblock_x=config.use_2dblock_x,
use_2dblock_w=config.use_2dblock_w,
use_sr_grad=config.use_sr_grad,
)

if func.__name__ in gemm_ops:
trans_b = func.__name__ == "linear"
Expand Down
413 changes: 413 additions & 0 deletions alto/kernels/mxfp8/MXFP8_GROUPED_GEMM_PLAN.md

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions alto/kernels/mxfp8/mxfp8_grouped_gemm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) 2026 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

from alto.kernels.mxfp8.mxfp8_grouped_gemm.functional import (
mxfp8_grouped_gemm,
_quantize_then_mxfp8_scaled_grouped_mm,
)

__all__ = ["mxfp8_grouped_gemm", "_quantize_then_mxfp8_scaled_grouped_mm"]
51 changes: 51 additions & 0 deletions alto/kernels/mxfp8/mxfp8_grouped_gemm/autotune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2026 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT
"""Autotune configs for mxfp8 grouped GEMM.

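v1 keeps a single conservative config per config list:
- BLOCK_SIZE_K == QUANT_BLOCK_SIZE (=32) so each tl.dot_scaled call covers
exactly one MX scale group; this matches the numerical contract validated
by alto/kernels/mxfp8/mxfp8_linear.py.
- STANDARD_CONFIGS uses BSM=BSN=128, matching mxfp4 grouped GEMM's default tile.
- DGRAD_CONFIGS and WGRAD_CONFIGS instead shrink BLOCK_SIZE_N and BLOCK_SIZE_M
to 32, respectively, because dgrad reduces over N and wgrad reduces over M.
Wider autotune is deferred to v2.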
"""

import triton

# Token routing alignment: tokens routed to the same expert must form
# contiguous blocks of this size.
ALIGN_SIZE_M = 128

# Config list for the standard kernel. It holds exactly one entry:
# a 128 x 128 (M x N) tile with a K step of 32.
STANDARD_CONFIGS = [
    triton.Config(
        dict(BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32),
        num_warps=4,
        num_stages=2,
    ),
]

# Config list for the dgrad kernel. It holds exactly one entry.
# dgrad reduces over N, so BLOCK_SIZE_N is held at 32 to keep one MX scale
# group per dot_scaled.
DGRAD_CONFIGS = [
    triton.Config(
        dict(BLOCK_SIZE_M=128, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32),
        num_warps=4,
        num_stages=2,
    ),
]

# Config list for the wgrad kernel. It holds exactly one entry.
# wgrad reduces over M, so BLOCK_SIZE_M is held at 32 to keep one MX scale
# group per dot_scaled.
WGRAD_CONFIGS = [
    triton.Config(
        dict(BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32),
        num_warps=4,
        num_stages=2,
    ),
]
Loading