Add MiniMax-M3 (MXFP4/AttnFP8) model support #1317

Open: thpereir wants to merge 1 commit into ROCm:main from thpereir:thpereir/m3

thpereir (Contributor) commented Jun 22, 2026

The in-tree MiniMax-M3 model already covers the BF16 checkpoint. This
adds the small pieces the quantized amd/MiniMax-M3-MXFP4-AttnFP8 build
needs, without disturbing the BF16 path.

  • config.py: register the minimax_m3_vl multimodal wrapper and parse its
    text sub-config (which declares no model_type) with the base
    PretrainedConfig so every field is retained and no deepseek/MLA
    defaults leak in; stamp model_type=minimax_m3 from the top-level type.
    The quark quantization_config (already propagated from the root) and
    the original architectures are preserved, so loading resolves to the
    existing MiniMaxM3Sparse model. The BF16 checkpoint keeps its direct
    minimax_m3 model_type and is unaffected.
  • linear.py: pad the MXFP4 Linear contraction dim to 256. The a4w4 asm
    GEMM reads K in 256-wide tiles, so an unaligned K (M3's shared-expert
    down_proj at TP=8, K=384) faults on GPU. LinearBase._pad_mxfp4_input_dim()
    zero-pads the fp4x2 weight, its e8m0 scale, and the activation up to
    256-alignment; no-op when already aligned.
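
The config.py handling in the first bullet can be pictured with plain dictionaries. This is only a sketch under stated assumptions, not the code in this PR: `resolve_text_config` and `WRAPPER_TO_TEXT_TYPE` are hypothetical names, the sub-config is assumed to live under a `text_config` key, and the real change works on `PretrainedConfig` objects rather than dicts.

```python
from copy import deepcopy

# Hypothetical mapping from the multimodal wrapper type to the text model type.
WRAPPER_TO_TEXT_TYPE = {"minimax_m3_vl": "minimax_m3"}


def resolve_text_config(root: dict) -> dict:
    """Return the config dict the text model would be built from (illustrative)."""
    text_type = WRAPPER_TO_TEXT_TYPE.get(root.get("model_type"))
    if text_type is None:
        # Direct checkpoints (for example BF16 with model_type == "minimax_m3")
        # pass through untouched.
        return root

    # Copy every field of the sub-config verbatim, so no defaults from another
    # architecture are mixed in. The "text_config" key is an assumption.
    text = deepcopy(root["text_config"])

    # The sub-config declares no model_type, so stamp one derived from the root.
    text["model_type"] = text_type

    # Keep root-level metadata the loader needs, without overwriting values
    # that were already propagated into the sub-config.
    for key in ("quantization_config", "architectures"):
        if key in root and key not in text:
            text[key] = deepcopy(root[key])
    return text
```

Returning the root unchanged for any type that is not a registered wrapper is what keeps the BF16 path unaffected in this sketch.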

Validated: GSM8K 93.9% at TP=8 (full 1319-question test set, 5-shot), matching the
TP=1 baseline; lossless vs the aligned baseline (cosine 0.9934).
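
The linear.py padding should be numerically neutral, because zero-padded positions add nothing to the contraction. The pure-Python sketch below shows the size arithmetic for K=384 and that property. It is not `LinearBase._pad_mxfp4_input_dim()` itself: the helper names are hypothetical, and the per-row layout (two fp4 values per byte, one e8m0 scale per 32 elements as in the OCP MX format) is an assumption about how the tensors are stored.

```python
K_TILE = 256    # tile width the PR says the a4w4 asm GEMM reads along K
MX_BLOCK = 32   # elements sharing one e8m0 scale (assumed, per the OCP MX format)


def round_up(value: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= value."""
    return (value + multiple - 1) // multiple * multiple


def padded_layout(k: int) -> dict:
    """Per-row sizes before and after padding K up to a multiple of K_TILE."""
    k_pad = round_up(k, K_TILE)
    return {
        "k": k,
        "k_pad": k_pad,
        # fp4x2 packs two 4-bit values into each byte.
        "weight_bytes": (k // 2, k_pad // 2),
        # One shared scale per MX_BLOCK elements along K.
        "scale_count": (k // MX_BLOCK, k_pad // MX_BLOCK),
    }


def zero_pad(row: list, k_pad: int) -> list:
    """Append zeros so the row has k_pad elements (no-op if already that long)."""
    return row + [0] * (k_pad - len(row))


def dot(a: list, b: list) -> int:
    """Plain contraction over K, standing in for one GEMM output element."""
    return sum(x * y for x, y in zip(a, b))
```

For K=384 this gives `k_pad` = 512, with weight bytes per row going from 192 to 256 and scales per row from 12 to 16. Padding both operands with zeros leaves `dot` unchanged, and an already aligned K such as 256 needs no padding.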

Motivation

Run MiniMax-M3 MXFP4 correctly on ATOM at TP=8.

Test Result

Tested on ATOM at TP=8:

python -m atom.entrypoints.openai_server \
  --model amd/MiniMax-M3-MXFP4-AttnFP8 \
  --trust-remote-code \
  --tensor-parallel-size 8 \
  --block-size 128 \
  --server-port 8000

lm-eval

lm_eval \
  --model local-chat-completions \
  --model_args "model=amd/MiniMax-M3-MXFP4-AttnFP8,base_url=http://127.0.0.1:8000/v1/chat/completions,num_concurrent=32,max_gen_toks=16384" \
  --tasks gsm8k \
  --num_fewshot 5 \
  --batch_size 1 \
  --apply_chat_template \
  --fewshot_as_multiturn
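
Before starting the full lm-eval run, a single request can confirm that the server answers on the same endpoint. The snippet below is a minimal sketch using only the standard library. The request fields follow the usual OpenAI chat-completions shape, which the server is assumed to accept; `build_request` and `ask` are hypothetical helper names.

```python
import json
import urllib.request

# Endpoint and model name taken from the lm_eval command above.
BASE_URL = "http://127.0.0.1:8000/v1/chat/completions"
MODEL = "amd/MiniMax-M3-MXFP4-AttnFP8"


def build_request(prompt: str, max_tokens: int = 64) -> urllib.request.Request:
    """Build (but do not send) a chat-completions POST request."""
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }
    return urllib.request.Request(
        BASE_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def ask(prompt: str) -> str:
    """Send the request and return the first choice's text (needs a running server)."""
    with urllib.request.urlopen(build_request(prompt), timeout=120) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]
```

Calling `ask("What is 12 * 7?")` against the running server should return a short completion. A connection error at this point means the server is not up yet.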

Results at TP=8 (for reference, GSM8K at TP=1 gives ~0.9424):

| Build | flexible-extract | strict-match |
| --- | --- | --- |
| Baseline: origin/main (no fix) | 0.7665 ± 0.0117 | 0.7657 ± 0.0117 |
| Fixed: this branch | 0.9378 ± 0.0065 | 0.9386 ± 0.0065 |

Submission Checklist

thpereir force-pushed the thpereir/m3 branch 3 times, most recently from b93487e to cb66f66 (June 23, 2026 18:47)
thpereir marked this pull request as ready for review (June 23, 2026 18:51)