Commit f4e6f50

Create a function to pre-process the ONNX model

Signed-off-by: ajrasane <[email protected]>
1 parent: a4c3e31
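
This commit introduces a pre_process stage across the ONNXQuantExporter hierarchy and a process_model template method that runs the whole pipeline (pre_process -> compute_scales -> compress_weights -> post_process), so call sites no longer chain the stages by hand. For INT4, the graph restructuring formerly done inside compute_scales moves into pre_process, which records the target weight shape and transpose permutation as temporary _target_shape / _transpose_perm attributes on each DequantizeLinear node; compute_scales later consumes and strips them. Hedged sketches of these mechanisms follow the diffs below.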

File tree: 3 files changed (+119, -42)

modelopt/onnx/export/quant_exporter.py (+117, -36)
@@ -29,6 +29,20 @@
 class ONNXQuantExporter(ABC):
     """Base class for ONNX quantizer exporters."""

+    @classmethod
+    def process_model(cls, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Processes the ONNX model."""
+        onnx_model = cls.pre_process(onnx_model)
+        onnx_model = cls.compute_scales(onnx_model)
+        onnx_model = cls.compress_weights(onnx_model)
+        onnx_model = cls.post_process(onnx_model)
+        return onnx_model
+
+    @staticmethod
+    @abstractmethod
+    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Pre-processes the ONNX model. Converts all DQ -> * -> op patterns to DQ -> op."""
+
     @staticmethod
     @abstractmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:

@@ -49,6 +63,10 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
 class MXFP8QuantExporter(ONNXQuantExporter):
     """Exporter for MXFP8 quantization."""

+    @staticmethod
+    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Pre-processes the ONNX model for MXFP8 quantization."""
+
     @staticmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Computes the scales for the weights in the ONNX model for MXFP8 quantization."""

@@ -66,6 +84,10 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
 class FP8QuantExporter(ONNXQuantExporter):
     """Exporter for FP8 quantization."""

+    @staticmethod
+    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Pre-processes the ONNX model for FP8 quantization."""
+
     @staticmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Computes the scales for the weights in the ONNX model for FP8 quantization."""

@@ -83,6 +105,10 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
 class INT8QuantExporter(ONNXQuantExporter):
     """Exporter for INT8 quantization."""

+    @staticmethod
+    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Pre-processes the ONNX model for INT8 quantization."""
+
     @staticmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Computes the scales for the weights in the ONNX model for INT8 quantization."""

@@ -100,31 +126,17 @@ class INT4QuantExporter(ONNXQuantExporter):
     """Exporter for INT4 quantization."""

     @staticmethod
-    def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Computes the scales for the weights in the ONNX model for INT4 quantization."""
+    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Pre-processes the ONNX model for INT4 quantization."""
         graph = onnx_model.graph
-        initializer_map = {initializer.name: initializer for initializer in graph.initializer}
         value_info_map = {value_info.name: value_info for value_info in graph.value_info}
         weight_dq_nodes = [node for node in graph.node if node.op_type == "DequantizeLinear"]
         tensor_producer_map = get_tensor_producer_nodes(graph)

         nodes_to_remove = []
         for node in weight_dq_nodes:
             weight_name = node.input[0]
-            scale_name = node.input[1]
-            logger.debug(f"Processing INT4 conversion for weight {weight_name}")
-            weight = numpy_helper.to_array(initializer_map[weight_name])
-            if scale_name in initializer_map:
-                scale = numpy_helper.to_array(initializer_map[scale_name])
-            else:
-                scale_constant_node = tensor_producer_map[scale_name]
-                for attr in scale_constant_node.attribute:
-                    if attr.name == "value":
-                        tensor = attr.t
-                        scale = numpy_helper.to_array(tensor)
-
-            weight = weight / scale
-            block_size = weight.shape[-1]
+            logger.debug(f"Restructuring graph for weight {weight_name}")

             ## Convert DequantizeLinear -> Reshape -> Transpose -> MatMul/Gemm to DequantizeLinear -> MatMul/Gemm
             dq_child_nodes = [n for n in graph.node if node.output[0] in n.input]

@@ -137,7 +149,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
             shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
             nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)

-            # Get the shape of the output of the reshape node
+            # Get the shape of the output of the reshape node - store for compute_scales
             reshape_output_value_info = value_info_map.get(reshape_node_output)
             if reshape_output_value_info is not None:
                 weight_shape = [

@@ -146,13 +158,11 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
             else:
                 raise ValueError(f"Unable to determine shape of weight tensor {weight_name}")

-            # Reshape weights and scales
-            weight = weight.reshape(weight_shape)
-            assert weight_shape[-1] % block_size == 0, (
-                f"Block size {block_size} is not divisible by {weight_shape[-1]}"
-            )
-            scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
-            scale = scale.reshape(scale_shape)
+            # Store target shape as attribute on DequantizeLinear node
+            target_shape_attr = node.attribute.add()
+            target_shape_attr.name = "_target_shape"
+            target_shape_attr.ints.extend(weight_shape)
+
             reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
             assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}"

@@ -165,7 +175,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                 cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
                 next_node = cast_child_nodes[0]

-            # Transpose weights and scales if present
+            # Store transpose permutation if present
             if next_node.op_type == "Transpose":
                 transpose_node = next_node
                 nodes_to_remove.append(transpose_node.name)

@@ -177,26 +187,90 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                     if attr.name == "perm":
                         perm = [x for x in attr.ints]  # noqa: C416
                 assert perm is not None, f"Permutation not found for {node.name}"
-                weight = weight.transpose(perm)
-                scale = scale.transpose(perm)
+
+                # Store permutation as attribute on DequantizeLinear node
+                perm_attr = node.attribute.add()
+                perm_attr.name = "_transpose_perm"
+                perm_attr.ints.extend(perm)
+
                 transpose_child_nodes = [
                     n for n in graph.node if transpose_node.output[0] in n.input
                 ]
-                # transpose_node.input = []
                 assert len(transpose_child_nodes) == 1, (
                     f"Expected exactly one matmul node for {node.name}"
                 )
                 matmul_node = transpose_child_nodes[0]
             else:
                 matmul_node = next_node
+
             assert matmul_node.op_type in ["MatMul", "Gemm"], (
                 f"Expected MatMul or Gemm node for {node.name}"
             )
+            # Rewire MatMul to use DequantizeLinear output directly
             matmul_node.input[1] = node.output[0]

+        # Remove transpose, reshape, and constant nodes
+        new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
+        del graph.node[:]
+        graph.node.extend(new_nodes)
+
+        return onnx_model
+
+    @staticmethod
+    def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Computes the scales for the weights in the ONNX model for INT4 quantization."""
+        graph = onnx_model.graph
+        initializer_map = {initializer.name: initializer for initializer in graph.initializer}
+        weight_dq_nodes = [node for node in graph.node if node.op_type == "DequantizeLinear"]
+        tensor_producer_map = get_tensor_producer_nodes(graph)
+
+        for node in weight_dq_nodes:
+            weight_name = node.input[0]
+            scale_name = node.input[1]
+            logger.debug(f"Computing scales for weight {weight_name}")
+
+            # Load weight and scale tensors
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            if scale_name in initializer_map:
+                scale = numpy_helper.to_array(initializer_map[scale_name])
+            else:
+                scale_constant_node = tensor_producer_map[scale_name]
+                for attr in scale_constant_node.attribute:
+                    if attr.name == "value":
+                        tensor = attr.t
+                        scale = numpy_helper.to_array(tensor)
+
+            # Dequantize weight
+            weight = weight / scale
+            block_size = weight.shape[-1]
+
+            # Get target shape from metadata stored in pre_process
+            target_shape = None
+            transpose_perm = None
+            for attr in node.attribute:
+                if attr.name == "_target_shape":
+                    target_shape = list(attr.ints)
+                elif attr.name == "_transpose_perm":
+                    transpose_perm = list(attr.ints)
+
+            assert target_shape is not None, f"Target shape not found for {node.name}"
+
+            # Reshape weights and scales
+            weight = weight.reshape(target_shape)
+            assert target_shape[-1] % block_size == 0, (
+                f"Block size {block_size} is not divisible by {target_shape[-1]}"
+            )
+            scale_shape = [*target_shape[:-1], target_shape[-1] // block_size]
+            scale = scale.reshape(scale_shape)
+
+            # Transpose weights and scales if permutation was stored
+            if transpose_perm is not None:
+                weight = weight.transpose(transpose_perm)
+                scale = scale.transpose(transpose_perm)
+
+            # Handle scale tensor creation/update
             if scale_name not in initializer_map:
                 # Remove scale producer if it's a Constant node
-                scale_name = node.input[1]
                 scale_producer = tensor_producer_map[scale_name]
                 if scale_producer.op_type == "Constant":
                     graph.node.remove(scale_producer)

@@ -210,14 +284,17 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                 scale_tensor = onnx.numpy_helper.from_array(scale, scale_name)
                 initializer_map[scale_name].CopyFrom(scale_tensor)

-            weight = numpy_helper.from_array(weight, weight_name)
-            initializer_map[weight_name].CopyFrom(weight)
+            # Update weight tensor
+            weight_tensor = numpy_helper.from_array(weight, weight_name)
+            initializer_map[weight_name].CopyFrom(weight_tensor)
+
             logger.debug(f"Computed scales for weight {weight_name} for INT4 quantization")

-        # Remove transpose and reshape nodes
-        new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
-        del graph.node[:]
-        graph.node.extend(new_nodes)
+        # Clean up metadata attributes from DequantizeLinear nodes
+        for node in weight_dq_nodes:
+            attrs_to_keep = [attr for attr in node.attribute if not attr.name.startswith("_")]
+            del node.attribute[:]
+            node.attribute.extend(attrs_to_keep)

         return onnx_model

@@ -306,6 +383,10 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
 class NVFP4QuantExporter(ONNXQuantExporter):
     """Exporter for NVFP4 quantization."""

+    @staticmethod
+    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        """Pre-processes the ONNX model for NVFP4 quantization."""
+
     @staticmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Computes the scales for the weights in the ONNX model for NVFP4 quantization."""

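To make the pre_process/compute_scales handoff above concrete, here is a minimal, self-contained sketch (not from the commit) of the metadata-stashing pattern: a temporary underscore-prefixed attribute is appended to a NodeProto through the protobuf API, read back later, and stripped before export. The node and tensor names are hypothetical; only the onnx package is assumed.

import onnx
from onnx import helper

# A stand-in DequantizeLinear node; names are hypothetical.
node = helper.make_node(
    "DequantizeLinear",
    inputs=["weight_q", "weight_scale"],
    outputs=["weight_dq"],
    name="dq_weight",
)

# Stash metadata as a temporary attribute, as pre_process does.
# Unlike the diff, this also sets the attribute type explicitly,
# which keeps the protobuf well-formed for checkers.
attr = node.attribute.add()
attr.name = "_target_shape"
attr.type = onnx.AttributeProto.INTS
attr.ints.extend([4096, 4096])

# Read the metadata back later, as compute_scales does.
target_shape = None
for a in node.attribute:
    if a.name == "_target_shape":
        target_shape = list(a.ints)
print(target_shape)  # [4096, 4096]

# Strip underscore-prefixed metadata, as the cleanup loop does.
keep = [a for a in node.attribute if not a.name.startswith("_")]
del node.attribute[:]
node.attribute.extend(keep)

One design note: stashing shapes on the node lets pre_process delete the Reshape/Transpose nodes immediately while still handing compute_scales everything it needs, at the cost of temporarily carrying non-standard attributes that must be cleaned up before the model is exported.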
modelopt/torch/_deploy/utils/torch_onnx.py (+1, -3)
@@ -404,9 +404,7 @@ def quantize_weights(model: nn.Module, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         onnx_exporters.append(MXFP8QuantExporter)

     for onnx_exporter in onnx_exporters:
-        onnx_model = onnx_exporter.compute_scales(onnx_model)
-        onnx_model = onnx_exporter.compress_weights(onnx_model)
-        onnx_model = onnx_exporter.post_process(onnx_model)
+        onnx_model = onnx_exporter.process_model(onnx_model)

     return onnx_model

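With process_model as the single entry point, each exporter only implements the stages and the base class fixes their order, which is why the call site above collapses to one line. A sketch of how a new exporter would slot in, assuming modelopt is installed and that the base class declares all four stages abstract, as process_model implies (the module path comes from the diff; MyQuantExporter and the file names are hypothetical):

import onnx

from modelopt.onnx.export.quant_exporter import ONNXQuantExporter


class MyQuantExporter(ONNXQuantExporter):
    """Hypothetical exporter: implement the four stages, inherit the pipeline."""

    @staticmethod
    def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        # Restructure the graph, e.g. collapse DQ -> * -> MatMul patterns.
        return onnx_model

    @staticmethod
    def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        # Derive scales for the weight tensors.
        return onnx_model

    @staticmethod
    def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        # Replace float weights with their quantized representation.
        return onnx_model

    @staticmethod
    def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        # Final fix-ups before export.
        return onnx_model


model = onnx.load("model.onnx")  # hypothetical input path
model = MyQuantExporter.process_model(model)  # runs all four stages in order
onnx.save(model, "model.quant.onnx")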
tests/unit/onnx/test_qdq_utils.py (+1, -3)
@@ -389,9 +389,7 @@ def test_projection_bias_and_scale_casting(self):
         model = create_test_model_with_proj_nodes()

         # Run quantization
-        quantized_model = INT4QuantExporter.compute_scales(model)
-        quantized_model = INT4QuantExporter.compress_weights(quantized_model)
-        quantized_model = INT4QuantExporter.post_process(quantized_model)
+        quantized_model = INT4QuantExporter.process_model(model)

         # Verify bias tensor is cast to float16
         bias_tensor = next(

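Because compute_scales strips the metadata it consumed, a test could also assert that no temporary attributes survive process_model. A hedged sketch reusing the helper from the test file above (the import path for the helper is illustrative, not taken from the commit):

from modelopt.onnx.export.quant_exporter import INT4QuantExporter

model = INT4QuantExporter.process_model(create_test_model_with_proj_nodes())

# The cleanup loop in compute_scales removes every underscore-prefixed
# attribute stashed on the DequantizeLinear nodes during pre_process.
for node in model.graph.node:
    assert not any(attr.name.startswith("_") for attr in node.attribute), node.name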