Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add support for PyTorch Geometric quantization.
- Add per tensor and per channel MSE calibrator support.
- Added support for PTQ/QAT checkpoint export and loading for running fakequant evaluation in vLLM. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
- Added support for mixed precision quantization and ONNX export. See `examples/onnx_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/onnx_ptq#mixed-precision-quantization-auto-mode>`_ for more details.

**Documentation**

Expand Down
24 changes: 24 additions & 0 deletions examples/onnx_ptq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,30 @@ python evaluate.py \
--model_name=vit_base_patch16_224
```

### Mixed Precision Quantization (Auto Mode)

The `auto` mode enables mixed precision quantization by searching for the optimal quantization format per layer. This approach balances model accuracy and compression by assigning different precision formats (e.g., NVFP4, FP8) to different layers based on their sensitivity.

#### How it works

1. **Sensitivity Analysis**: Computes per-layer sensitivity scores using gradient-based analysis
2. **Format Search**: Searches across specified quantization formats for each layer
3. **Constraint Optimization**: Finds the optimal format assignment that satisfies the effective bits constraint while minimizing accuracy loss

#### Usage

```bash
python torch_quant_to_onnx.py \
--timm_model_name=vit_base_patch16_224 \
--quantize_mode=auto \
--auto_quantization_formats NVFP4_AWQ_LITE_CFG FP8_DEFAULT_CFG \
--effective_bits=4.8 \
--num_score_steps=128 \
--calibration_data_size=512 \
--evaluate \
--onnx_save_path=vit_base_patch16_224.auto_quant.onnx
```

### ONNX Export Supported LLM Models

| Model | FP16 | INT4 | FP8 | NVFP4 |
Expand Down
6 changes: 0 additions & 6 deletions examples/onnx_ptq/torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,12 +323,6 @@ def main():
)
print(f"Quantized Model - Top-1 Accuracy: {top1:.2f}%, Top-5 Accuracy: {top5:.2f}%")

if args.quantize_mode in ["auto"]:
print(
f"The selected quantization mode {args.quantize_mode} is not supported for ONNX export yet."
)
return

# Export to ONNX
export_to_onnx(
quantized_model,
Expand Down
68 changes: 66 additions & 2 deletions modelopt/onnx/export/fp8_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,91 @@

"""FP8 quantization exporter."""

import time

import onnx
import onnx_graphsurgeon as gs
import torch
from onnx_graphsurgeon.ir.tensor import LazyValues

from .base_exporter import ONNXQuantExporter


# TODO: Implement the FP8QuantExporter
class FP8QuantExporter(ONNXQuantExporter):
"""Exporter for FP8 quantization."""

@staticmethod
def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Pre-processes the ONNX model for FP8 quantization."""
return onnx_model

@staticmethod
def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Computes the scales for the weights in the ONNX model for FP8 quantization."""
return onnx_model

@staticmethod
def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Compresses the weights in the ONNX model for FP8 quantization."""
"""Compresses FP32/FP16 weights to FP8 by folding QDQ nodes to DQ only.

Even though modelopt supports FP8 onnx export, the weights are represented in fp32 + QDQ.
The storage is therefore very bad. In this function,
Q nodes will get removed from the weights and have only DQ nodes with those converted FP8
weights in the output model.

Parameters:
onnx_model: ONNX model with FP32/FP16 weights and QDQ nodes.

Returns:
ONNX model with FP8 weights and only DQ nodes for weights (QDQ preserved for activations).
"""
start_time = time.time()
print("Replacing all (fp32 weights + fp8 QDQ) with (fp8 weights + DQ)...")

graph = gs.import_onnx(onnx_model)
# Fold constants is required since the scale is not constant yet.
graph.cleanup().toposort().fold_constants().cleanup()

for node in graph.nodes:
if node.op == "TRT_FP8QuantizeLinear":
# Should not remove input QDQ
if not isinstance(node.inputs[0], gs.Constant):
continue

weights = node.inputs[0]
scale = node.inputs[1]
torch_weights = torch.from_numpy(weights.values)
torch_scale = torch.from_numpy(scale.values)
quantizer_name = scale.name.rsplit("/", 1)[0]
dq_op = node.outputs[0].outputs[0]
assert dq_op.op == "TRT_FP8DequantizeLinear", (
f"QDQ does not occur in pairs. You reached {dq_op.op}"
)

# Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
numpy_weights = (
(torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
)
tensor = onnx.TensorProto()
tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
tensor.dims.extend(numpy_weights.shape)
tensor.raw_data = numpy_weights.tobytes()
values = LazyValues(tensor)
onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values)

node.outputs.clear()
# DQ Op is separated out
dq_op.inputs[0] = onnx_weights_fp8
dq_op.op = "DequantizeLinear"
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype

graph.cleanup().toposort()
end_time = time.time()
print(f"fp8 qdq replaced with only dq completed in {end_time - start_time}s.")

return gs.export_onnx(graph)

@staticmethod
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Post-processes the ONNX model for FP8 quantization."""
return onnx_model
4 changes: 4 additions & 0 deletions modelopt/onnx/export/int8_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ class INT8QuantExporter(ONNXQuantExporter):
@staticmethod
def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Pre-processes the ONNX model for INT8 quantization."""
return onnx_model

@staticmethod
def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Computes the scales for the weights in the ONNX model for INT8 quantization."""
return onnx_model

@staticmethod
def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Compresses the weights in the ONNX model for INT8 quantization."""
return onnx_model

@staticmethod
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Post-processes the ONNX model for INT8 quantization."""
return onnx_model
4 changes: 4 additions & 0 deletions modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,8 @@ def is_int8_quantized(model: nn.Module) -> bool:
if (
hasattr(module, "weight_quantizer")
and hasattr(module, "input_quantizer")
and module.weight_quantizer.is_enabled
and module.input_quantizer.is_enabled
and module.weight_quantizer._num_bits == 8
and module.input_quantizer._num_bits == 8
):
Expand All @@ -358,6 +360,8 @@ def is_fp8_quantized(model: nn.Module) -> bool:
if (
hasattr(module, "weight_quantizer")
and hasattr(module, "input_quantizer")
and module.weight_quantizer.is_enabled
and module.input_quantizer.is_enabled
and module.weight_quantizer._num_bits == (4, 3)
and module.input_quantizer._num_bits == (4, 3)
# Exclude MXFP8 which also uses (4,3) but has block_sizes with scale_bits
Expand Down
14 changes: 9 additions & 5 deletions tests/examples/onnx_ptq/test_torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@

# TODO: Add accuracy evaluation after we upgrade TRT version to 10.12
@pytest.mark.parametrize(
("quantize_mode", "onnx_save_path", "calib_size"),
("quantize_mode", "onnx_save_path", "calib_size", "num_score_steps"),
[
("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1"),
("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1"),
("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1"),
("fp8", "vit_base_patch16_224.fp8.onnx", "1", "1"),
("int8", "vit_base_patch16_224.int8.onnx", "1", "1"),
("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1", "1"),
("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1", "1"),
("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1", "1"),
("auto", "vit_base_patch16_224.auto.onnx", "1", "1"),
],
)
def test_torch_onnx(quantize_mode, onnx_save_path, calib_size):
def test_torch_onnx(quantize_mode, onnx_save_path, calib_size, num_score_steps):
cmd_parts = extend_cmd_parts(
["python", "torch_quant_to_onnx.py"],
quantize_mode=quantize_mode,
onnx_save_path=onnx_save_path,
calibration_data_size=calib_size,
num_score_steps=num_score_steps,
)
run_example_command(cmd_parts, "onnx_ptq")