@@ -776,36 +776,23 @@ def to_quantized_weight(
776776
777777 if quantization == QUANTIZATION_FP8_PC_PT :
778778 if weight .dim () == 3 :
779- # for MOE stacked weights
780- # For standard MoE: weight (num_experts, output_dim, input_dim)
781- # scale (num_experts, output_dim)
782- # For BMM-style transposed experts: weight (num_experts, output_dim, input_dim)
783- # scale (num_experts, input_dim)
784-
785779 # Handle different scale tensor shapes
786780 if weights_scaling_factor .dim () == 1 :
787781 # Per-expert scaling only: (num_experts,) -> (num_experts, 1, 1)
788782 return (weight / weights_scaling_factor [:, None , None ]).to (torch .float8_e4m3fn )
789783 elif weights_scaling_factor .dim () == 2 :
790784 # Per-channel scaling: check which dimension matches
791785 if weights_scaling_factor .shape [- 1 ] == weight .shape [- 1 ]:
792- # Scale matches last dim (input_dim) - BMM-style transposed case
793- # (num_experts, input_dim) -> (num_experts, 1, input_dim)
786+ # (num_experts, input_dim) -> (num_experts, 1, input_dim), BMM-style
794787 return (weight / weights_scaling_factor .unsqueeze (- 2 )).to (torch .float8_e4m3fn )
795788 elif weights_scaling_factor .shape [- 1 ] == weight .shape [- 2 ]:
796- # Scale matches second-to-last dim (output_dim) - standard MoE case
797- # (num_experts, output_dim) -> (num_experts, output_dim, 1)
789+ # (num_experts, output_dim) -> (num_experts, output_dim, 1), Standard MoE case
798790 return (weight / weights_scaling_factor .unsqueeze (- 1 )).to (torch .float8_e4m3fn )
799791 else :
800- # Shape mismatch - try to infer correct broadcasting
801792 raise ValueError (
802793 f"Cannot determine correct unsqueeze dimension for FP8_PC_PT quantization. "
803794 f"weight shape: { weight .shape } , scale shape: { weights_scaling_factor .shape } "
804795 )
805- else :
806- raise ValueError (
807- f"Unexpected scaling factor dimension for 3D weight: { weights_scaling_factor .dim ()} "
808- )
809796 return (weight / weights_scaling_factor [:, None ]).to (torch .float8_e4m3fn )
810797
811798 if quantization in [QUANTIZATION_INT4_AWQ , QUANTIZATION_W4A8_AWQ ]:
0 commit comments