-
Notifications
You must be signed in to change notification settings - Fork 0
Add MXFP8 grouped GEMM (E4M3) for routed-expert MoE training #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 13 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
815ce06
add dot_scaled repro script and md file
ysa2215 3f4d5df
add minimal MXFP8 dot_scaled multi-block repro
ysa2215 11ae961
fix MXFP8 dot_scaled multi-block repro signal
ysa2215 3e758a3
mxfp8: scaffold grouped GEMM package (Step 1)
ysa2215 fa05b99
mxfp8: implement forward grouped GEMM kernel (Step 2)
ysa2215 128b526
Add MXFP8 grouped GEMM forward guards and coverage
ysa2215 d077a80
mxfp8: implement grouped GEMM backward path
ysa2215 caccfc1
mxfp8: add SNR test gates and use_dot_scaled coverage for grouped GEMM
ysa2215 4643649
docs: add m355 MXFP8 grouped GEMM test result
ysa2215 52ffa43
mxfp8: add toy-MoE training sanity and cross-format comparison
ysa2215 b7fa4e1
mxfp8: add offsets entry + padded buffer for grouped GEMM, wire dispatch
ysa2215 54687cb
refactor: merge and simplify mxfp8 grouped gemm related tests, update…
ysa2215 c8becb3
Merge remote-tracking branch 'origin/main' into yue/mxfp8-grouped-gemm
ysa2215 d903f2d
fix: address mxfp8 grouped GEMM PR review comments
ysa2215 183e554
fix: hoist expert_idx load out of K loop in mxfp4 grouped GEMM forward
ysa2215 ad5c3e1
Potential fix for pull request finding
ysa2215 82dac0b
Potential fix for pull request finding
ysa2215 39aa9ee
Potential fix for pull request finding
ysa2215 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| # Copyright (c) 2026 Advanced Micro Devices, Inc. | ||
| # SPDX-License-Identifier: MIT | ||
|
|
||
| from alto.kernels.mxfp8.mxfp8_grouped_gemm.functional import ( | ||
| mxfp8_grouped_gemm, | ||
| _quantize_then_mxfp8_scaled_grouped_mm, | ||
| ) | ||
|
|
||
| __all__ = ["mxfp8_grouped_gemm", "_quantize_then_mxfp8_scaled_grouped_mm"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| # Copyright (c) 2026 Advanced Micro Devices, Inc. | ||
| # SPDX-License-Identifier: MIT | ||
| """Autotune configs for mxfp8 grouped GEMM. | ||
|
|
||
| v1 keeps a single conservative config: | ||
| - BLOCK_SIZE_K == QUANT_BLOCK_SIZE (=32) so each tl.dot_scaled call covers | ||
| exactly one mx scale group; this matches the numerical contract validated | ||
| by alto/kernels/mxfp8/mxfp8_linear.py. | ||
| - BSM=BSN=128 matches mxfp4 grouped GEMM's default tile. | ||
| Wider autotune is deferred to v2. | ||
| """ | ||
|
|
||
| import triton | ||
|
|
||
| ALIGN_SIZE_M = 128 # token routing alignment; tokens routed to the same expert must form contiguous blocks of this size | ||
|
|
||
| STANDARD_CONFIGS = [ | ||
| triton.Config( | ||
| { | ||
| "BLOCK_SIZE_M": 128, | ||
| "BLOCK_SIZE_N": 128, | ||
| "BLOCK_SIZE_K": 32, | ||
| }, | ||
| num_stages=2, | ||
| num_warps=4, | ||
| ), | ||
| ] | ||
|
|
||
| DGRAD_CONFIGS = [ | ||
| triton.Config( | ||
| { | ||
| "BLOCK_SIZE_M": 128, | ||
| "BLOCK_SIZE_N": 32, # dgrad reduces over N; keep one MX scale group per dot_scaled | ||
| "BLOCK_SIZE_K": 32, | ||
| }, | ||
| num_stages=2, | ||
| num_warps=4, | ||
| ), | ||
| ] | ||
|
|
||
| WGRAD_CONFIGS = [ | ||
| triton.Config( | ||
| { | ||
| "BLOCK_SIZE_M": 32, # wgrad reduces over M; keep one MX scale group per dot_scaled | ||
| "BLOCK_SIZE_N": 128, | ||
| "BLOCK_SIZE_K": 32, | ||
| }, | ||
| num_stages=2, | ||
| num_warps=4, | ||
| ), | ||
| ] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.