Commit 1aaa77d

[OMNIML-2244] Add support for auto quantizing a model (#571)
## What does this PR do?

**Type of change:** Example update

**Overview:**

- Added option to quantize a model with `mtq.auto_quantize()`

## Usage

```python
python torch_quant_to_onnx.py \
    --timm_model_name vit_small_patch16_224 \
    --quantize_mode auto \
    --onnx_save_path models/vit_auto_quant.onnx \
    --calibration_data_size 512 \
    --batch_size 8 \
    --auto_quantization_formats NVFP4_AWQ_LITE_CFG FP8_DEFAULT_CFG INT8_DEFAULT_CFG \
    --effective_bits 4.8 \
    --num_score_steps 128
```

## Testing

Able to auto quantize ViT model

```
AutoQuantize best recipe for patch_embed.proj: NONE(effective-bits: 16.0)
AutoQuantize best recipe for blocks.0.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.0.attn.proj: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.0.mlp.fc1: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.0.mlp.fc2: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.1.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.1.attn.proj: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.1.mlp.fc1: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.1.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.2.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.2.attn.proj: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.2.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.2.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.3.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.3.attn.proj: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.3.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.3.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.4.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.4.attn.proj: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.4.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.4.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.5.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.5.attn.proj: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.5.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.5.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.6.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.6.attn.proj: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.6.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.6.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.7.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.7.attn.proj: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.7.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.7.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.8.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.8.attn.proj: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.8.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.8.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.9.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.9.attn.proj: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.9.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.9.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.10.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.10.attn.proj: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.10.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.10.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.11.attn.qkv: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.11.attn.proj: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize best recipe for blocks.11.mlp.fc1: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for blocks.11.mlp.fc2: NVFP4_AWQ_LITE_CFG(effective-bits: 4.0)
AutoQuantize best recipe for head: FP8_DEFAULT_CFG(effective-bits: 8.0)
AutoQuantize effective bits from search: 4.80
```

Accuracy comparison for the ViT model

|                                                      | Top-1 accuracy | Top-5 accuracy |
|------------------------------------------------------|----------------|----------------|
| Original model (FP32)                                | 85.102%        | 97.526%        |
| Auto Quantized (FP8 + NVFP4, 4.78 effective bits)    | 84.726%        | 97.434%        |
| MXFP8 Quantized                                      | 85.02%         | 97.53%         |
| NVFP4 Quantized                                      | 84.558%        | 97.36%         |
| INT4 Quantized                                       | 84.23%         | 97.22%         |

## Before your PR is "*Ready for review*"
<!-- If you haven't finished some of the above items you can still open `Draft` PR. -->

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. -->

---------

Signed-off-by: ajrasane <[email protected]>
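
The log above assigns each layer a format and reports an overall figure of 4.80 effective bits. The commit does not show how ModelOpt derives that figure. One plausible reading, sketched here purely as an assumption, is a parameter-count-weighted average of the per-layer bit widths. The helper `effective_bits` and the layer shapes are hypothetical and chosen only for illustration.

```python
def effective_bits(layers):
    """Parameter-weighted average bit width.

    `layers` is an iterable of (num_params, bits) pairs. This is an assumed
    definition for illustration, not ModelOpt's documented formula.
    """
    layers = list(layers)
    total_params = sum(n for n, _ in layers)
    if total_params == 0:
        raise ValueError("at least one layer with parameters is required")
    return sum(n * bits for n, bits in layers) / total_params


# Hypothetical transformer block: qkv and MLP weights at 4 bits,
# the attention output projection at 8 bits.
block = [
    (384 * 1152, 4.0),  # attn.qkv
    (384 * 384, 8.0),   # attn.proj
    (384 * 1536, 4.0),  # mlp.fc1
    (1536 * 384, 4.0),  # mlp.fc2
]
print(f"{effective_bits(block):.2f}")
```

Under this assumption, a small layer kept at 8 bits raises the average less than a large one would.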
1 parent a703e22 commit 1aaa77d

1 file changed

+166
-29
lines changed

examples/onnx_ptq/torch_quant_to_onnx.py

Lines changed: 166 additions & 29 deletions
@@ -19,18 +19,20 @@
 import timm
 import torch
 import torch.multiprocessing as mp
+import torch.nn.functional as F
 from datasets import load_dataset
 from download_example_onnx import export_to_onnx
 from evaluation import evaluate

 import modelopt.torch.quantization as mtq

 """
-This script is used to quantize a timm model using dynamic quantization like MXFP8 or NVFP4.
+This script is used to quantize a timm model using dynamic quantization like MXFP8 or NVFP4,
+or using auto quantization for optimal per-layer quantization.

 The script will:
 1. Given the model name, create a timm torch model.
-2. Quantize the torch model in MXFP8 or NVFP4 mode.
+2. Quantize the torch model in MXFP8, NVFP4, INT4_AWQ, or AUTO mode.
 3. Export the quantized torch model to ONNX format.
 """

@@ -55,8 +57,17 @@ def filter_func(name):
     return pattern.match(name) is not None


-def load_calibration_data(model_name, data_size, batch_size, device):
-    """Load and prepare calibration data."""
+def load_calibration_data(model_name, data_size, batch_size, device, with_labels=False):
+    """Load and prepare calibration data.
+
+    Args:
+        model_name: Name of the timm model
+        data_size: Number of samples to load
+        batch_size: Batch size for data loader
+        device: Device to load data to
+        with_labels: If True, return dict with 'image' and 'label' keys (for auto_quantize)
+            If False, return just the images (for standard quantize)
+    """
     dataset = load_dataset("zh-plus/tiny-imagenet")
     model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
@@ -65,9 +76,18 @@ def load_calibration_data(model_name, data_size, batch_size, device):
     images = dataset["train"][:data_size]["image"]
     calib_tensor = [transforms(img) for img in images]
     calib_tensor = [t.to(device) for t in calib_tensor]
-    return torch.utils.data.DataLoader(
-        calib_tensor, batch_size=batch_size, shuffle=True, num_workers=4
-    )
+
+    if with_labels:
+        labels = dataset["train"][:data_size]["label"]
+        labels = torch.tensor(labels, device=device)
+        calib_dataset = [{"image": img, "label": lbl} for img, lbl in zip(calib_tensor, labels)]
+        return torch.utils.data.DataLoader(
+            calib_dataset, batch_size=batch_size, shuffle=True, num_workers=4
+        )
+    else:
+        return torch.utils.data.DataLoader(
+            calib_tensor, batch_size=batch_size, shuffle=True, num_workers=4
+        )


 def quantize_model(model, config, data_loader=None):
@@ -86,16 +106,80 @@ def forward_loop(model):
     return quantized_model


-def get_model_input_shape(model_name, batch_size):
+def forward_step(model, batch):
+    """Forward step function for auto_quantize scoring."""
+    return model(batch["image"])
+
+
+def loss_func(output, batch):
+    """Loss function for auto_quantize gradient computation."""
+    return F.cross_entropy(output, batch["label"])
+
+
+def auto_quantize_model(
+    model,
+    data_loader,
+    quantization_formats,
+    effective_bits=4.8,
+    num_calib_steps=512,
+    num_score_steps=128,
+):
+    """Auto-quantize the model using optimal per-layer quantization search.
+
+    Args:
+        model: PyTorch model to quantize
+        data_loader: DataLoader with image-label dict batches
+        quantization_formats: List of quantization format config names or dicts
+        effective_bits: Target effective bits constraint
+        num_calib_steps: Number of calibration steps
+        num_score_steps: Number of scoring steps for sensitivity analysis
+
+    Returns:
+        Tuple of (quantized_model, search_state_dict)
+    """
+    constraints = {"effective_bits": effective_bits}
+
+    # Convert string format names to actual config objects
+    format_configs = []
+    for fmt in quantization_formats:
+        if isinstance(fmt, str):
+            format_configs.append(getattr(mtq, fmt))
+        else:
+            format_configs.append(fmt)
+
+    print(f"Starting auto-quantization search with {len(format_configs)} formats...")
+    print(f"Effective bits constraint: {effective_bits}")
+    print(f"Calibration steps: {num_calib_steps}, Scoring steps: {num_score_steps}")
+
+    quantized_model, search_state = mtq.auto_quantize(
+        model,
+        constraints=constraints,
+        quantization_formats=format_configs,
+        data_loader=data_loader,
+        forward_step=forward_step,
+        loss_func=loss_func,
+        num_calib_steps=num_calib_steps,
+        num_score_steps=num_score_steps,
+        verbose=True,
+    )
+
+    # Disable quantization for specified layers
+    mtq.disable_quantizer(quantized_model, filter_func)
+
+    return quantized_model, search_state
+
+
+def get_model_input_shape(model):
     """Get the input shape from timm model configuration."""
-    model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
     input_size = data_config["input_size"]
-    return (batch_size, *tuple(input_size))  # Add batch dimension
+    return tuple(input_size)


 def main():
-    parser = argparse.ArgumentParser(description="Quantize timm models to MXFP8 or NVFP4")
+    parser = argparse.ArgumentParser(
+        description="Quantize timm models to FP8, MXFP8, INT8, NVFP4, INT4_AWQ, or use AUTO quantization"
+    )

     # Model hyperparameters
     parser.add_argument(
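
In the hunk above, `auto_quantize_model` turns format names into config objects with `getattr(mtq, fmt)`, which raises a bare `AttributeError` for a misspelled name. The sketch below shows the same lookup pattern with an explicit error message. `resolve_formats` and `fake_mtq` are hypothetical stand-ins, so the example runs without ModelOpt installed; the placeholder dicts are not real ModelOpt configs.

```python
from types import SimpleNamespace


def resolve_formats(formats, namespace):
    """Map format names to config objects, passing non-string entries through."""
    resolved = []
    for fmt in formats:
        if isinstance(fmt, str):
            if not hasattr(namespace, fmt):
                raise ValueError(f"Unknown quantization format: {fmt!r}")
            resolved.append(getattr(namespace, fmt))
        else:
            # Already a config object (for example a dict), so keep it as-is.
            resolved.append(fmt)
    return resolved


# Stand-in for `modelopt.torch.quantization`; the dict contents are placeholders.
fake_mtq = SimpleNamespace(
    FP8_DEFAULT_CFG={"name": "fp8"},
    NVFP4_AWQ_LITE_CFG={"name": "nvfp4_awq_lite"},
)

configs = resolve_formats(["FP8_DEFAULT_CFG", {"name": "custom"}], fake_mtq)
print(configs)
```

In the script itself the `choices` list on `--auto_quantization_formats` already rejects unknown names on the command line, so a check like this would mainly matter for callers that invoke the function directly.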
@@ -106,14 +190,14 @@ def main():
     )
     parser.add_argument(
         "--quantize_mode",
-        choices=["fp8", "mxfp8", "int8", "nvfp4", "int4_awq"],
+        choices=["fp8", "mxfp8", "int8", "nvfp4", "int4_awq", "auto"],
         default="mxfp8",
-        help="Type of quantization to apply (mxfp8, nvfp4, int4_awq)",
+        help="Type of quantization to apply. Default is MXFP8.",
     )
     parser.add_argument(
         "--onnx_save_path",
         required=True,
-        help="The path to save the ONNX model.",
+        help="The save path to save the ONNX model.",
         type=str,
     )
     parser.add_argument(
@@ -140,15 +224,43 @@ def main():
         help="Number of samples to use for evaluation. If None, use entire validation set.",
     )

-    args = parser.parse_args()
+    # Auto quantization specific arguments
+    parser.add_argument(
+        "--auto_quantization_formats",
+        nargs="+",
+        choices=[
+            "NVFP4_AWQ_LITE_CFG",
+            "FP8_DEFAULT_CFG",
+            "MXFP8_DEFAULT_CFG",
+            "INT8_DEFAULT_CFG",
+            "INT4_AWQ_CFG",
+        ],
+        default=["NVFP4_AWQ_LITE_CFG", "FP8_DEFAULT_CFG"],
+        help="Quantization formats to search from for auto mode (e.g., NVFP4_AWQ_LITE_CFG FP8_DEFAULT_CFG)",
+    )
+    parser.add_argument(
+        "--effective_bits",
+        type=float,
+        default=4.8,
+        help="Target effective bits for auto quantization constraint. Default is 4.8.",
+    )
+    parser.add_argument(
+        "--num_score_steps",
+        type=int,
+        default=128,
+        help="Number of scoring steps for auto quantization. Default is 128.",
+    )

-    # Get input shape from model config
-    input_shape = get_model_input_shape(args.timm_model_name, args.batch_size)
+    args = parser.parse_args()

     # Create model and move to appropriate device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)

+    # Get input shape from model config
+    input_size = get_model_input_shape(model)
+    input_shape = (args.batch_size, *input_size)
+
     # Evaluate base model if requested
     if args.evaluate:
         print("\n=== Evaluating Base Model ===")
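
The new `--auto_quantization_formats` flag combines `nargs="+"` with `choices`, so it accepts one or more space-separated names and argparse validates each against the list. The standalone sketch below reproduces only the three arguments added in this hunk to show how the command from the PR description would parse. `build_parser` is a hypothetical helper, not part of the script.

```python
import argparse


def build_parser():
    """Parser holding only the auto-quantization arguments added by this commit."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--auto_quantization_formats",
        nargs="+",  # one or more values, collected into a list
        choices=[
            "NVFP4_AWQ_LITE_CFG",
            "FP8_DEFAULT_CFG",
            "MXFP8_DEFAULT_CFG",
            "INT8_DEFAULT_CFG",
            "INT4_AWQ_CFG",
        ],
        default=["NVFP4_AWQ_LITE_CFG", "FP8_DEFAULT_CFG"],
    )
    parser.add_argument("--effective_bits", type=float, default=4.8)
    parser.add_argument("--num_score_steps", type=int, default=128)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is self-contained.
args = build_parser().parse_args(
    [
        "--auto_quantization_formats",
        "NVFP4_AWQ_LITE_CFG",
        "FP8_DEFAULT_CFG",
        "INT8_DEFAULT_CFG",
        "--effective_bits",
        "4.8",
        "--num_score_steps",
        "128",
    ]
)
print(args.auto_quantization_formats, args.effective_bits, args.num_score_steps)
```

A name outside `choices` makes argparse print a usage error and exit, which is why the script can pass the parsed names to `getattr(mtq, fmt)` without further checks.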
@@ -159,21 +271,44 @@ def main():
         )
         print(f"Base Model - Top-1 Accuracy: {top1:.2f}%, Top-5 Accuracy: {top5:.2f}%")

-    # Select quantization config
-    config = QUANT_CONFIG_DICT[args.quantize_mode]
-    data_loader = (
-        None
-        if args.quantize_mode == "mxfp8"
-        else load_calibration_data(
+    # Quantize model based on mode
+    if args.quantize_mode == "auto":
+        # Auto quantization requires labels for loss computation
+        data_loader = load_calibration_data(
             args.timm_model_name,
             args.calibration_data_size,
-            input_shape[0],  # batch size
+            args.batch_size,
             device,
+            with_labels=True,
         )
-    )

-    # Quantize model
-    quantized_model = quantize_model(model, config, data_loader)
+        quantized_model, _ = auto_quantize_model(
+            model,
+            data_loader,
+            args.auto_quantization_formats,
+            args.effective_bits,
+            args.calibration_data_size,
+            args.num_score_steps,
+        )
+    else:
+        # Standard quantization - only load calibration data if needed
+        config = QUANT_CONFIG_DICT[args.quantize_mode]
+        if args.quantize_mode == "mxfp8":
+            data_loader = None
+        else:
+            data_loader = load_calibration_data(
+                args.timm_model_name,
+                args.calibration_data_size,
+                args.batch_size,
+                device,
+                with_labels=False,
+            )
+
+        quantized_model = quantize_model(model, config, data_loader)
+
+    # Print quantization summary
+    print("\nQuantization Summary:")
+    mtq.print_quant_summary(quantized_model)

     # Evaluate quantized model if requested
     if args.evaluate:
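
The branching in the hunk above boils down to two questions per mode: is calibration data loaded at all, and does it carry labels? The sketch below restates that decision as a small pure function so it can be checked in isolation. `calibration_plan` is a hypothetical helper that mirrors the diff's logic and is not part of the script.

```python
def calibration_plan(quantize_mode):
    """Return (needs_calibration_data, with_labels) for a quantization mode.

    Mirrors the branches in main(): "auto" loads labelled batches for the
    loss function, "mxfp8" skips calibration data, every other mode loads
    images only.
    """
    if quantize_mode == "auto":
        return True, True
    if quantize_mode == "mxfp8":
        return False, False
    return True, False


for mode in ["fp8", "mxfp8", "int8", "nvfp4", "int4_awq", "auto"]:
    needs_data, with_labels = calibration_plan(mode)
    print(f"{mode}: needs_data={needs_data}, with_labels={with_labels}")
```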
@@ -188,8 +323,10 @@ def main():
         )
         print(f"Quantized Model - Top-1 Accuracy: {top1:.2f}%, Top-5 Accuracy: {top5:.2f}%")

-    if args.quantize_mode in ["fp8", "int8"]:
-        print(f"Exporting to {args.quantize_mode} ONNX model is not supported yet.")
+    if args.quantize_mode in ["fp8", "int8", "auto"]:
+        print(
+            f"The selected quantization mode {args.quantize_mode} is not supported for ONNX export yet."
+        )
         return

     # Export to ONNX
