Skip to content

Commit 9d0b328

Browse files
committed
Enable MiniMax-M3 MXFP4 (AttnFP8) on top of the BF16 M3 support
The in-tree MiniMax-M3 model already covers the BF16 checkpoint. This adds the small pieces the quantized amd/MiniMax-M3-MXFP4-AttnFP8 build needs, without disturbing the BF16 path. - config.py: register the minimax_m3_vl multimodal wrapper and parse its text sub-config (which declares no model_type) with the base PretrainedConfig so every field is retained and no deepseek/MLA defaults leak in; stamp model_type=minimax_m3 from the top-level type. The quark quantization_config (already propagated from the root) and the original architectures are preserved, so loading resolves to the existing MiniMaxM3Sparse model. The BF16 checkpoint keeps its direct minimax_m3 model_type and is unaffected. - linear.py: pad the MXFP4 Linear contraction dim to 256. The a4w4 asm GEMM reads K in 256-wide tiles, so an unaligned K (M3's shared-expert down_proj at TP=8, K=384) faults on GPU. LinearBase._pad_mxfp4_input_dim() zero-pads the fp4x2 weight, its e8m0 scale, and the activation up to 256-alignment; no-op when already aligned.
1 parent fc4d766 commit 9d0b328

2 files changed

Lines changed: 61 additions & 4 deletions

File tree

atom/config.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ def _remap_layer_name(name: str) -> list[str]:
586586
"qwen3_5": "text_config",
587587
"qwen3_5_moe": "text_config",
588588
"mistral3": "text_config",
589+
"minimax_m3_vl": "text_config",
589590
}
590591

591592
# multimodal models fully supported by plugin mode
@@ -630,10 +631,20 @@ def _get_hf_token() -> str | None:
630631
and "quantization_config" in config_dict
631632
):
632633
text_config_dict["quantization_config"] = config_dict["quantization_config"]
633-
text_model_type = text_config_dict.get("model_type", "deepseek_v3")
634-
mapped_type = _CONFIG_REGISTRY.get(text_model_type, text_model_type)
635-
config_class = AutoConfig.for_model(mapped_type)
636-
hf_config = config_class.from_dict(text_config_dict)
634+
if "model_type" not in text_config_dict:
635+
# The text sub-config declares no `model_type` of its own (e.g.
636+
# MiniMax-M3's minimax_m3_vl wrapper). Parse it with the base
637+
# config so every field is retained and no foreign (e.g. deepseek
638+
# MLA) defaults are injected; the model class reads its own fields
639+
# via getattr. Stamp the model_type from the top-level type so
640+
# downstream policy can key off it.
641+
hf_config = PretrainedConfig(**text_config_dict)
642+
hf_config.model_type = model_type.removesuffix("_vl")
643+
else:
644+
text_model_type = text_config_dict.get("model_type", "deepseek_v3")
645+
mapped_type = _CONFIG_REGISTRY.get(text_model_type, text_model_type)
646+
config_class = AutoConfig.for_model(mapped_type)
647+
hf_config = config_class.from_dict(text_config_dict)
637648
# Override architectures so that ATOM selects the correct model class
638649
# which can handle the multimodal weight prefix during loading.
639650
original_arch = config_dict.get("architectures", [])

atom/model_ops/linear.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,46 @@ def online_quantize_weight(self):
542542
"quant_dtype": str(online_quant_dtype),
543543
}
544544

545+
def _pad_mxfp4_input_dim(self):
546+
"""Zero-pad the MXFP4 (per-1x32 fp4x2) contraction dim up to 256.
547+
548+
The a4w4 asm GEMM's preshuffle + e8m0 scale layout reads the K
549+
dimension in 256-wide tiles. Per-rank shapes whose K is not a
550+
multiple of 256 (e.g. the TP=8 shared-expert down_proj with
551+
K=384) otherwise trigger an out-of-bounds GPU memory access
552+
fault. Padded weight bytes are zero so the extra K contributes
553+
nothing to the result. Mirrors FusedMoE's pad_align=256.
554+
"""
555+
self._mxfp4_in_pad = 0
556+
if not (
557+
self.quant_type == QuantType.per_1x32
558+
and self.params_dtype == dtypes.fp4x2
559+
and self.weight.dim() == 2
560+
and self.weight.data.dtype == dtypes.fp4x2
561+
):
562+
return
563+
weight_scale = getattr(self, "weight_scale", None)
564+
if weight_scale is None:
565+
return
566+
align = 256
567+
k = weight_scale.shape[-1] * MXFP4_QUANT_BLOCK_SIZE
568+
k_pad = (k + align - 1) // align * align
569+
if k_pad == k:
570+
return
571+
scale_pad = k_pad // MXFP4_QUANT_BLOCK_SIZE - weight_scale.shape[-1]
572+
# weight_scale is e8m0 (exponent-only); 0.0 is not representable, so pad
573+
# the raw bytes (0x00 -> 2^-127, harmless since padded weights are zero).
574+
scale_u8 = weight_scale.data.view(torch.uint8)
575+
self.weight_scale.data = torch.nn.functional.pad(scale_u8, (0, scale_pad)).view(
576+
weight_scale.data.dtype
577+
)
578+
weight_u8 = self.weight.data.view(torch.uint8)
579+
weight_pad = k_pad // 2 - weight_u8.shape[-1]
580+
self.weight.data = torch.nn.functional.pad(weight_u8, (0, weight_pad)).view(
581+
dtypes.fp4x2
582+
)
583+
self._mxfp4_in_pad = k_pad - k
584+
545585
def process_weights_after_loading(self):
546586
# Re-quantize before process_weights if online quantization is enabled
547587
if self.quant_config is not None and self.quant_config.online_quant:
@@ -580,6 +620,7 @@ def process_weights_after_loading(self):
580620
)
581621
self.weight.data = w_q
582622
self.weight_scale = atom_parameter(w_s)
623+
self._pad_mxfp4_input_dim()
583624
# Only quantized 2D GEMM weights use aiter's preshuffle layout.
584625
# Qwen3-Next/Qwen3.5 GDN conv1d expands its weight to 3D, so FP8/blocked
585626
# quantized models must keep that tensor unshuffled here.
@@ -591,6 +632,8 @@ def process_weights_after_loading(self):
591632
self.quant_type == QuantType.per_1x32
592633
and self.params_dtype == dtypes.fp4x2
593634
)
635+
if is_fp4_blockscale:
636+
self._pad_mxfp4_input_dim()
594637
need_shuffle = (
595638
self.quant_type == QuantType.per_Token
596639
and self.params_dtype == dtypes.fp8
@@ -688,6 +731,9 @@ def forward(
688731
if self.bias is not None:
689732
y += self.bias
690733
elif self.quant_type.value == QuantType.per_1x32.value:
734+
in_pad = getattr(self, "_mxfp4_in_pad", 0)
735+
if in_pad and x_scale is None:
736+
x = torch.nn.functional.pad(x, (0, in_pad))
691737
y = gemm_a4w4_quant(
692738
x,
693739
x_scale,

0 commit comments

Comments
 (0)