Skip to content

Commit 93f5bbf

Browse files
authored
[OMNIML-3015]Add per tensor/per channel MSE calibrator (#540)
## What does this PR do? **Type of change:** ? <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> new feature **Overview:** ? Add per tensor/per channel MSE calibrator. ## Usage Can be enabled with "algorithm" field in quantization configs. ``` "algorithm": {"method": "mse", "num_steps": 20, "stop_multiplier": 8.0}, ``` ## Testing <!-- Mention how have you tested your change if applicable. --> Unit test for the MseCalibrator, E2E test with NVFP4 and INT8, **results: ** start_multiplier=0.25 stop_multiplier=4.0 num_steps=20 **Qwen3-8B MMLU:** **BF16 baseline: 72.94** | Calib Algo | NVFP4 | FP8 | INT8 | | ------ | ------ | ------ | ------ | | MSE | 70.88 | 72.65 | 55.46 | | MAX | 70.83 | 72.7 | 24.52 | ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information TODO: for the follow up PR: - [ ] TP sync for HF models - [ ] Calculate weight quantizer only once <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added MSE-based quantization calibration supporting per-tensor and per-channel optimization with configurable parameters (step count, multiplier ranges). * **Tests** * Added comprehensive test coverage for MSE calibration functionality. * **Documentation** * Updated changelog to reflect MSE calibrator support. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Fridah-nv <[email protected]> Signed-off-by: Frida Hou <[email protected]>
1 parent 01e24fd commit 93f5bbf

File tree

10 files changed

+950
-25
lines changed

10 files changed

+950
-25
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Model Optimizer Changelog (Linux)
1616
- Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
1717
- Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
1818
- Add support for PyTorch Geometric quantization.
19+
- Add per tensor and per channel MSE calibrator support.
1920

2021
**Documentation**
2122

modelopt/torch/quantization/calib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@
2323
from .calibrator import *
2424
from .histogram import *
2525
from .max import *
26+
from .mse import *
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Calibrator that returns the MSE amax of all collected tensors."""
17+
18+
from collections.abc import Callable
19+
20+
import torch
21+
import torch.nn.functional as F
22+
23+
from .. import utils as quant_utils
24+
from .calibrator import _Calibrator
25+
26+
__all__ = ["MseCalibrator"]
27+
28+
29+
class MseCalibrator(_Calibrator):
30+
"""Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x."""
31+
32+
def __init__(
33+
self,
34+
amax: torch.Tensor,
35+
axis: int | tuple | list | None = None,
36+
num_steps: int = 10,
37+
start_multiplier: float = 0.25,
38+
stop_multiplier: float = 4.0,
39+
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
40+
error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
41+
):
42+
"""Initialize MSE calibrator.
43+
44+
Args:
45+
amax: Initial amax value (required).
46+
axis: Quantization axis. None means per-tensor quantization.
47+
num_steps: Number of amax candidates to try.
48+
start_multiplier: Starting multiplier for amax search.
49+
stop_multiplier: Ending multiplier for amax search.
50+
quant_func: Function that quantizes input tensor given an amax value.
51+
Should have signature: quant_func(x, amax) -> quantized_x.
52+
error_func: Function to compute error between x and xq.
53+
Default is F.mse_loss(x, xq, reduction='none').
54+
"""
55+
super().__init__(num_bits=None, axis=axis, unsigned=None)
56+
self._initial_amax = amax
57+
self._num_steps = num_steps
58+
self._start_multiplier = start_multiplier
59+
self._stop_multiplier = stop_multiplier
60+
self._quant_func = quant_func
61+
self._error_func = error_func
62+
self._losses_sum = [None] * num_steps
63+
self._candidate_amaxs = [None] * num_steps
64+
65+
self._amax = None
66+
67+
@torch.no_grad()
68+
def collect(self, x: torch.Tensor):
69+
"""Collect input tensor statistics and compute losses for MSE calibration.
70+
71+
Args:
72+
x: Input tensor.
73+
"""
74+
if self._quant_func is None:
75+
raise RuntimeError(
76+
"Quantization function not set. Msecalibrator requires a quant_func to be provided."
77+
)
78+
79+
x = x.detach().to(dtype=torch.float32)
80+
81+
device = x.device
82+
multipliers = torch.linspace(
83+
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
84+
)
85+
86+
# Get reduce axis for per-channel quantization
87+
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)
88+
89+
for step, multiplier in enumerate(multipliers):
90+
candidate_amax = self._initial_amax * multiplier
91+
xq = self._quant_func(x, candidate_amax)
92+
93+
if self._error_func is not None:
94+
error = self._error_func(x, xq)
95+
else:
96+
error = F.mse_loss(x, xq, reduction="none")
97+
98+
loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False)
99+
100+
if self._candidate_amaxs[step] is None:
101+
self._candidate_amaxs[step] = candidate_amax
102+
103+
if self._losses_sum[step] is None:
104+
self._losses_sum[step] = loss.clone()
105+
else:
106+
self._losses_sum[step] += loss
107+
108+
def reset(self):
109+
"""Reset the stored losses and amax value."""
110+
self._losses_sum = [None] * self._num_steps
111+
self._candidate_amaxs = [None] * self._num_steps
112+
self._amax = None
113+
114+
@torch.no_grad()
115+
def compute_amax(self, verbose: bool = False):
116+
"""Return the amax value that minimizes quantization error.
117+
118+
Args:
119+
verbose: If True, print the ratio of best_amax to initial_amax.
120+
"""
121+
if not any(loss_sum is not None for loss_sum in self._losses_sum):
122+
return None
123+
124+
# Check if this is per-tensor or per-channel based on the first loss
125+
first_loss_sum = None
126+
for loss_sum in self._losses_sum:
127+
if loss_sum is not None:
128+
first_loss_sum = loss_sum
129+
break
130+
131+
if first_loss_sum is None:
132+
return None
133+
134+
# Collect losses for all steps
135+
losses_per_step = []
136+
for step in range(self._num_steps):
137+
if self._losses_sum[step] is not None:
138+
losses_per_step.append(self._losses_sum[step])
139+
# No data for this step, use inf
140+
elif first_loss_sum.ndim == 0:
141+
losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device))
142+
else:
143+
losses_per_step.append(torch.full_like(first_loss_sum, float("inf")))
144+
145+
# Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel
146+
losses_per_step = torch.stack(losses_per_step)
147+
148+
# Find best step(s): scalar for per-tensor, [num_channels] for per-channel
149+
best_steps = torch.argmin(losses_per_step, dim=0)
150+
151+
# Stack candidate amaxs and select based on best_steps
152+
candidate_amaxs = torch.stack(self._candidate_amaxs)
153+
154+
if first_loss_sum.ndim == 0:
155+
# Per-tensor case: best_steps is a scalar
156+
self._amax = self._candidate_amaxs[best_steps.item()]
157+
else:
158+
# Per-channel case: best_steps is a tensor
159+
num_channels = best_steps.shape[0]
160+
self._amax = candidate_amaxs[
161+
best_steps, torch.arange(num_channels, device=best_steps.device)
162+
]
163+
self._amax = self._amax.reshape(self._initial_amax.shape)
164+
165+
if verbose:
166+
ratio = self._amax / self._initial_amax
167+
if ratio.ndim == 0:
168+
print(f"MSE Calibrator: best_amax/initial_amax ratio = {ratio.item():.4f}")
169+
else:
170+
print(
171+
f"MSE Calibrator: best_amax/initial_amax ratio - "
172+
f"mean: {ratio.mean().item():.4f}, "
173+
f"min: {ratio.min().item():.4f}, "
174+
f"max: {ratio.max().item():.4f}"
175+
)
176+
177+
return self._amax

modelopt/torch/quantization/config.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,45 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
981981
)
982982

983983

984+
class MseCalibConfig(QuantizeAlgorithmConfig):
985+
"""Configuration for per-tensor MSE calibration.
986+
987+
Finds a scale s (via amax a, with s = a / q_max) that minimizes the
988+
reconstruction error of a tensor after uniform Q→DQ:
989+
990+
s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}
991+
"""
992+
993+
method: Literal["mse"] = ModeloptField("mse")
994+
995+
num_steps: int | None = ModeloptField(
996+
default=10,
997+
ge=1,
998+
title="Number of amax candidates to try.",
999+
description="Number of amax candidates to search over for MSE minimization.",
1000+
)
1001+
1002+
start_multiplier: float | None = ModeloptField(
1003+
default=0.25,
1004+
gt=0.0,
1005+
title="Starting multiplier for amax search.",
1006+
description="Starting multiplier for amax search range (multiplies initial amax).",
1007+
)
1008+
1009+
stop_multiplier: float | None = ModeloptField(
1010+
default=4.0,
1011+
gt=0.0,
1012+
title="Ending multiplier for amax search.",
1013+
description="Ending multiplier for amax search range (multiplies initial amax).",
1014+
)
1015+
1016+
distributed_sync: bool | None = ModeloptField(
1017+
default=True,
1018+
title="Whether to sync the amax across the distributed processes.",
1019+
description="If True, the amax will be synced across the distributed processes.",
1020+
)
1021+
1022+
9841023
class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
9851024
"""The config for ``smoothquant`` algorithm (SmoothQuant).
9861025

modelopt/torch/quantization/mode.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
AWQLiteCalibConfig,
3939
CompressConfig,
4040
MaxCalibConfig,
41+
MseCalibConfig,
4142
QuantizeAlgoCfgType,
4243
QuantizeAlgorithmConfig,
4344
QuantizeConfig,
@@ -54,7 +55,7 @@
5455
restore_svdquant_model,
5556
update_quantize_metadata,
5657
)
57-
from .model_calib import awq, max_calibrate, smoothquant, svdquant
58+
from .model_calib import awq, max_calibrate, mse_calibrate, smoothquant, svdquant
5859

5960
__all__ = ["BaseCalibrateModeDescriptor"]
6061

@@ -363,6 +364,18 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
363364
_calib_func = max_calibrate
364365

365366

367+
@CalibrateModeRegistry.register_mode
368+
class MseCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
369+
"""Mode for mse calibration algorithm."""
370+
371+
@property
372+
def config_class(self) -> type[QuantizeAlgorithmConfig]:
373+
"""Specifies the config class for the mode."""
374+
return MseCalibConfig
375+
376+
_calib_func = mse_calibrate
377+
378+
366379
@CalibrateModeRegistry.register_mode
367380
class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
368381
"""Mode for smoothquant calibration algorithm."""

0 commit comments

Comments
 (0)