
Commit 0a4f0a8

[OMNIML-2244] Implement the ONNX quantization exporter for INT4 (#575)
## What does this PR do?

**Type of change:** New Feature

**Overview:**
- Created an abstract parent class, `ONNXQuantExporter`
- Created child classes for the individual precisions
- Implemented the `INT4QuantExporter`
- Removed `quantize_weights_to_int4`
- Added a method to quantize the weights of an ONNX model to low precision

## Testing

```
python torch_quant_to_onnx.py --quantize_mode=int4_awq \
    --onnx_save_path=<onnx_path> \
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?** Yes
- **Did you write any new necessary tests?** No
- **Did you add or update any necessary documentation?** Yes
- **Did you update the [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?** No

Signed-off-by: ajrasane <[email protected]>
1 parent 261858c commit 0a4f0a8

File tree

10 files changed: +595 -182 lines changed

modelopt/onnx/export/__init__.py

Lines changed: 32 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ONNX export utilities."""

__all__ = [
    "FP8QuantExporter",
    "INT4QuantExporter",
    "INT8QuantExporter",
    "MXFP8QuantExporter",
    "NVFP4QuantExporter",
    "ONNXQuantExporter",
]

from .base_exporter import ONNXQuantExporter
from .fp8_exporter import FP8QuantExporter
from .int4_exporter import INT4QuantExporter
from .int8_exporter import INT8QuantExporter
from .mxfp8_exporter import MXFP8QuantExporter
from .nvfp4_exporter import NVFP4QuantExporter
```
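
For context, a minimal usage sketch of the new API. `process_model` is the classmethod pipeline defined on `ONNXQuantExporter` (shown in the next file); the model file names here are hypothetical, and the exact CLI-facing entry point may differ (see `torch_quant_to_onnx.py` in the Testing section above).

```python
# Usage sketch only; file names are hypothetical.
import onnx

from modelopt.onnx.export import INT4QuantExporter

onnx_model = onnx.load("model_int4_awq.onnx")

# Runs the four-stage pipeline inherited from ONNXQuantExporter:
# pre_process -> compute_scales -> compress_weights -> post_process
onnx_model = INT4QuantExporter.process_model(onnx_model)

onnx.save(onnx_model, "model_int4_awq.compressed.onnx")
```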
modelopt/onnx/export/base_exporter.py

Lines changed: 53 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for ONNX quantizer exporters."""

from abc import ABC, abstractmethod

import onnx


class ONNXQuantExporter(ABC):
    """Base class for ONNX quantizer exporters."""

    @classmethod
    def process_model(cls, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Processes the ONNX model."""
        onnx_model = cls.pre_process(onnx_model)
        onnx_model = cls.compute_scales(onnx_model)
        onnx_model = cls.compress_weights(onnx_model)
        onnx_model = cls.post_process(onnx_model)
        return onnx_model

    @staticmethod
    @abstractmethod
    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Pre-processes the ONNX model. Converts all DQ -> * -> op patterns to DQ -> op."""

    @staticmethod
    @abstractmethod
    def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Computes the scales for the weights in the ONNX model."""

    @staticmethod
    @abstractmethod
    def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Compresses the weights in the ONNX model."""

    @staticmethod
    @abstractmethod
    def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Post-processes the ONNX model."""
```
modelopt/onnx/export/fp8_exporter.py

Lines changed: 41 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FP8 quantization exporter."""

import onnx

from .base_exporter import ONNXQuantExporter


# TODO: Implement the FP8QuantExporter
class FP8QuantExporter(ONNXQuantExporter):
    """Exporter for FP8 quantization."""

    @staticmethod
    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Pre-processes the ONNX model for FP8 quantization."""

    @staticmethod
    def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Computes the scales for the weights in the ONNX model for FP8 quantization."""

    @staticmethod
    def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Compresses the weights in the ONNX model for FP8 quantization."""

    @staticmethod
    def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        """Post-processes the ONNX model for FP8 quantization."""
```
