Skip to content

Commit 8d7fe0b

Browse files
committed
minor update
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent fbd5417 commit 8d7fe0b

File tree

2 files changed

+2
-17
lines changed

2 files changed

+2
-17
lines changed

modelopt/torch/export/quant_utils.py

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -776,36 +776,23 @@ def to_quantized_weight(
776776

777777
if quantization == QUANTIZATION_FP8_PC_PT:
778778
if weight.dim() == 3:
779-
# for MOE stacked weights
780-
# For standard MoE: weight (num_experts, output_dim, input_dim)
781-
# scale (num_experts, output_dim)
782-
# For BMM-style transposed experts: weight (num_experts, output_dim, input_dim)
783-
# scale (num_experts, input_dim)
784-
785779
# Handle different scale tensor shapes
786780
if weights_scaling_factor.dim() == 1:
787781
# Per-expert scaling only: (num_experts,) -> (num_experts, 1, 1)
788782
return (weight / weights_scaling_factor[:, None, None]).to(torch.float8_e4m3fn)
789783
elif weights_scaling_factor.dim() == 2:
790784
# Per-channel scaling: check which dimension matches
791785
if weights_scaling_factor.shape[-1] == weight.shape[-1]:
792-
# Scale matches last dim (input_dim) - BMM-style transposed case
793-
# (num_experts, input_dim) -> (num_experts, 1, input_dim)
786+
# (num_experts, input_dim) -> (num_experts, 1, input_dim), BMM-style
794787
return (weight / weights_scaling_factor.unsqueeze(-2)).to(torch.float8_e4m3fn)
795788
elif weights_scaling_factor.shape[-1] == weight.shape[-2]:
796-
# Scale matches second-to-last dim (output_dim) - standard MoE case
797-
# (num_experts, output_dim) -> (num_experts, output_dim, 1)
789+
# (num_experts, output_dim) -> (num_experts, output_dim, 1), Standard MoE case
798790
return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
799791
else:
800-
# Shape mismatch - try to infer correct broadcasting
801792
raise ValueError(
802793
f"Cannot determine correct unsqueeze dimension for FP8_PC_PT quantization. "
803794
f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}"
804795
)
805-
else:
806-
raise ValueError(
807-
f"Unexpected scaling factor dimension for 3D weight: {weights_scaling_factor.dim()}"
808-
)
809796
return (weight / weights_scaling_factor[:, None]).to(torch.float8_e4m3fn)
810797

811798
if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:

modelopt/torch/export/unified_export_hf.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -355,8 +355,6 @@ def _export_quantized_weight(
355355
)
356356
elif quantization_format == QUANTIZATION_FP8_PC_PT and is_bmm_expert_weight:
357357
# For FP8_PC_PT with BMM-style experts, transpose only the weight (not weight_scale)
358-
# Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
359-
# weight_scale remains (num_experts, output_dim) for per-channel quantization
360358
weight, _ = maybe_transpose_expert_weight_dimensions(
361359
weight, is_bmm_expert_weight=is_bmm_expert_weight
362360
)

0 commit comments

Comments (0)