2929class ONNXQuantExporter (ABC ):
3030 """Base class for ONNX quantizer exporters."""
3131
@classmethod
def process_model(cls, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Runs the complete export pipeline on the ONNX model.

    The stages are invoked on ``cls`` in a fixed order -- ``pre_process``,
    ``compute_scales``, ``compress_weights``, ``post_process`` -- and each
    stage receives whatever the previous stage returned.

    Args:
        onnx_model: The ONNX model to process.

    Returns:
        The value returned by the final ``post_process`` stage.
    """
    pipeline = ("pre_process", "compute_scales", "compress_weights", "post_process")
    for stage_name in pipeline:
        # Resolve each stage right before calling it, one stage at a time.
        onnx_model = getattr(cls, stage_name)(onnx_model)
    return onnx_model
40+
@staticmethod
@abstractmethod
def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Pre-processes the ONNX model.

    Implementations convert every DQ -> * -> op pattern into DQ -> op.

    Args:
        onnx_model: The ONNX model to pre-process.

    Returns:
        The pre-processed ONNX model.
    """
45+
3246 @staticmethod
3347 @abstractmethod
3448 def compute_scales (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
@@ -49,6 +63,10 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
4963class MXFP8QuantExporter (ONNXQuantExporter ):
5064 """Exporter for MXFP8 quantization."""
5165
@staticmethod
def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Pre-processes the ONNX model for MXFP8 quantization.

    No graph changes are made in this stage for MXFP8; the model is handed
    back as-is.

    Args:
        onnx_model: The ONNX model to pre-process.

    Returns:
        The same ``onnx_model`` object, unmodified.
    """
    # A docstring-only body implicitly returns None, which
    # ``ONNXQuantExporter.process_model`` would then pass on to
    # ``compute_scales``; return the model to honour the annotation.
    return onnx_model
69+
5270 @staticmethod
5371 def compute_scales (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
5472 """Computes the scales for the weights in the ONNX model for MXFP8 quantization."""
@@ -66,6 +84,10 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
6684class FP8QuantExporter (ONNXQuantExporter ):
6785 """Exporter for FP8 quantization."""
6886
@staticmethod
def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Pre-processes the ONNX model for FP8 quantization.

    No graph changes are made in this stage for FP8; the model is handed
    back as-is.

    Args:
        onnx_model: The ONNX model to pre-process.

    Returns:
        The same ``onnx_model`` object, unmodified.
    """
    # A docstring-only body implicitly returns None, which
    # ``ONNXQuantExporter.process_model`` would then pass on to
    # ``compute_scales``; return the model to honour the annotation.
    return onnx_model
90+
6991 @staticmethod
7092 def compute_scales (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
7193 """Computes the scales for the weights in the ONNX model for FP8 quantization."""
@@ -83,6 +105,10 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
83105class INT8QuantExporter (ONNXQuantExporter ):
84106 """Exporter for INT8 quantization."""
85107
@staticmethod
def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Pre-processes the ONNX model for INT8 quantization.

    No graph changes are made in this stage for INT8; the model is handed
    back as-is.

    Args:
        onnx_model: The ONNX model to pre-process.

    Returns:
        The same ``onnx_model`` object, unmodified.
    """
    # A docstring-only body implicitly returns None, which
    # ``ONNXQuantExporter.process_model`` would then pass on to
    # ``compute_scales``; return the model to honour the annotation.
    return onnx_model
111+
86112 @staticmethod
87113 def compute_scales (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
88114 """Computes the scales for the weights in the ONNX model for INT8 quantization."""
@@ -100,31 +126,17 @@ class INT4QuantExporter(ONNXQuantExporter):
100126 """Exporter for INT4 quantization."""
101127
102128 @staticmethod
103- def compute_scales (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
104- """Computes the scales for the weights in the ONNX model for INT4 quantization."""
129+ def pre_process (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
130+ """Pre-processes the ONNX model for INT4 quantization."""
105131 graph = onnx_model .graph
106- initializer_map = {initializer .name : initializer for initializer in graph .initializer }
107132 value_info_map = {value_info .name : value_info for value_info in graph .value_info }
108133 weight_dq_nodes = [node for node in graph .node if node .op_type == "DequantizeLinear" ]
109134 tensor_producer_map = get_tensor_producer_nodes (graph )
110135
111136 nodes_to_remove = []
112137 for node in weight_dq_nodes :
113138 weight_name = node .input [0 ]
114- scale_name = node .input [1 ]
115- logger .debug (f"Processing INT4 conversion for weight { weight_name } " )
116- weight = numpy_helper .to_array (initializer_map [weight_name ])
117- if scale_name in initializer_map :
118- scale = numpy_helper .to_array (initializer_map [scale_name ])
119- else :
120- scale_constant_node = tensor_producer_map [scale_name ]
121- for attr in scale_constant_node .attribute :
122- if attr .name == "value" :
123- tensor = attr .t
124- scale = numpy_helper .to_array (tensor )
125-
126- weight = weight / scale
127- block_size = weight .shape [- 1 ]
139+ logger .debug (f"Restructuring graph for weight { weight_name } " )
128140
129141 ## Convert DequantizeLinear -> Reshape -> Transpose -> MatMul/Gemm to DequantizeLinear -> Matmul/Gemm
130142 dq_child_nodes = [n for n in graph .node if node .output [0 ] in n .input ]
@@ -137,7 +149,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
137149 shape_constant_name = next (input for input in reshape_node .input if "Constant" in input )
138150 nodes_to_remove .append (tensor_producer_map [shape_constant_name ].name )
139151
140- # Get the shape of the output of the reshape node
152+ # Get the shape of the output of the reshape node - store for compute_scales
141153 reshape_output_value_info = value_info_map .get (reshape_node_output )
142154 if reshape_output_value_info is not None :
143155 weight_shape = [
@@ -146,13 +158,11 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
146158 else :
147159 raise ValueError (f"Unable to determine shape of weight tensor { weight_name } " )
148160
149- # Reshape weights and scales
150- weight = weight .reshape (weight_shape )
151- assert weight_shape [- 1 ] % block_size == 0 , (
152- f"Block size { block_size } is not divisible by { weight_shape [- 1 ]} "
153- )
154- scale_shape = [* weight_shape [:- 1 ], weight_shape [- 1 ] // block_size ]
155- scale = scale .reshape (scale_shape )
161+ # Store target shape as attribute on DequantizeLinear node
162+ target_shape_attr = node .attribute .add ()
163+ target_shape_attr .name = "_target_shape"
164+ target_shape_attr .ints .extend (weight_shape )
165+
156166 reshape_child_nodes = [n for n in graph .node if reshape_node .output [0 ] in n .input ]
157167 assert len (reshape_child_nodes ) == 1 , f"Expected exactly one child node for { node .name } "
158168
@@ -165,7 +175,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
165175 cast_child_nodes = [n for n in graph .node if cast_node .output [0 ] in n .input ]
166176 next_node = cast_child_nodes [0 ]
167177
168- # Transpose weights and scales if present
178+ # Store transpose permutation if present
169179 if next_node .op_type == "Transpose" :
170180 transpose_node = next_node
171181 nodes_to_remove .append (transpose_node .name )
@@ -177,26 +187,90 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
177187 if attr .name == "perm" :
178188 perm = [x for x in attr .ints ] # noqa: C416
179189 assert perm is not None , f"Permutation not found for { node .name } "
180- weight = weight .transpose (perm )
181- scale = scale .transpose (perm )
190+
191+ # Store permutation as attribute on DequantizeLinear node
192+ perm_attr = node .attribute .add ()
193+ perm_attr .name = "_transpose_perm"
194+ perm_attr .ints .extend (perm )
195+
182196 transpose_child_nodes = [
183197 n for n in graph .node if transpose_node .output [0 ] in n .input
184198 ]
185- # transpose_node.input = []
186199 assert len (transpose_child_nodes ) == 1 , (
187200 f"Expected exactly one matmul node for { node .name } "
188201 )
189202 matmul_node = transpose_child_nodes [0 ]
190203 else :
191204 matmul_node = next_node
205+
192206 assert matmul_node .op_type in ["MatMul" , "Gemm" ], (
193207 f"Expected MatMul or Gemm node for { node .name } "
194208 )
209+ # Rewire MatMul to use DequantizeLinear output directly
195210 matmul_node .input [1 ] = node .output [0 ]
196211
212+ # Remove transpose, reshape, and constant nodes
213+ new_nodes = [node for node in graph .node if node .name not in nodes_to_remove ]
214+ del graph .node [:]
215+ graph .node .extend (new_nodes )
216+
217+ return onnx_model
218+
219+ @staticmethod
220+ def compute_scales (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
221+ """Computes the scales for the weights in the ONNX model for INT4 quantization."""
222+ graph = onnx_model .graph
223+ initializer_map = {initializer .name : initializer for initializer in graph .initializer }
224+ weight_dq_nodes = [node for node in graph .node if node .op_type == "DequantizeLinear" ]
225+ tensor_producer_map = get_tensor_producer_nodes (graph )
226+
227+ for node in weight_dq_nodes :
228+ weight_name = node .input [0 ]
229+ scale_name = node .input [1 ]
230+ logger .debug (f"Computing scales for weight { weight_name } " )
231+
232+ # Load weight and scale tensors
233+ weight = numpy_helper .to_array (initializer_map [weight_name ])
234+ if scale_name in initializer_map :
235+ scale = numpy_helper .to_array (initializer_map [scale_name ])
236+ else :
237+ scale_constant_node = tensor_producer_map [scale_name ]
238+ for attr in scale_constant_node .attribute :
239+ if attr .name == "value" :
240+ tensor = attr .t
241+ scale = numpy_helper .to_array (tensor )
242+
243+ # Dequantize weight
244+ weight = weight / scale
245+ block_size = weight .shape [- 1 ]
246+
247+ # Get target shape from metadata stored in pre_process
248+ target_shape = None
249+ transpose_perm = None
250+ for attr in node .attribute :
251+ if attr .name == "_target_shape" :
252+ target_shape = list (attr .ints )
253+ elif attr .name == "_transpose_perm" :
254+ transpose_perm = list (attr .ints )
255+
256+ assert target_shape is not None , f"Target shape not found for { node .name } "
257+
258+ # Reshape weights and scales
259+ weight = weight .reshape (target_shape )
260+ assert target_shape [- 1 ] % block_size == 0 , (
261+ f"Block size { block_size } is not divisible by { target_shape [- 1 ]} "
262+ )
263+ scale_shape = [* target_shape [:- 1 ], target_shape [- 1 ] // block_size ]
264+ scale = scale .reshape (scale_shape )
265+
266+ # Transpose weights and scales if permutation was stored
267+ if transpose_perm is not None :
268+ weight = weight .transpose (transpose_perm )
269+ scale = scale .transpose (transpose_perm )
270+
271+ # Handle scale tensor creation/update
197272 if scale_name not in initializer_map :
198273 # Remove scale producer if it's a Constant node
199- scale_name = node .input [1 ]
200274 scale_producer = tensor_producer_map [scale_name ]
201275 if scale_producer .op_type == "Constant" :
202276 graph .node .remove (scale_producer )
@@ -210,14 +284,17 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
210284 scale_tensor = onnx .numpy_helper .from_array (scale , scale_name )
211285 initializer_map [scale_name ].CopyFrom (scale_tensor )
212286
213- weight = numpy_helper .from_array (weight , weight_name )
214- initializer_map [weight_name ].CopyFrom (weight )
287+ # Update weight tensor
288+ weight_tensor = numpy_helper .from_array (weight , weight_name )
289+ initializer_map [weight_name ].CopyFrom (weight_tensor )
290+
215291 logger .debug (f"Computed scales for weight { weight_name } for INT4 quantization" )
216292
217- # Remove transpose and reshape nodes
218- new_nodes = [node for node in graph .node if node .name not in nodes_to_remove ]
219- del graph .node [:]
220- graph .node .extend (new_nodes )
293+ # Clean up metadata attributes from DequantizeLinear nodes
294+ for node in weight_dq_nodes :
295+ attrs_to_keep = [attr for attr in node .attribute if not attr .name .startswith ("_" )]
296+ del node .attribute [:]
297+ node .attribute .extend (attrs_to_keep )
221298
222299 return onnx_model
223300
@@ -306,6 +383,10 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
306383class NVFP4QuantExporter (ONNXQuantExporter ):
307384 """Exporter for NVFP4 quantization."""
308385
@staticmethod
def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Pre-processes the ONNX model for NVFP4 quantization.

    No graph changes are made in this stage for NVFP4; the model is handed
    back as-is.

    Args:
        onnx_model: The ONNX model to pre-process.

    Returns:
        The same ``onnx_model`` object, unmodified.
    """
    # A docstring-only body implicitly returns None, which
    # ``ONNXQuantExporter.process_model`` would then pass on to
    # ``compute_scales``; return the model to honour the annotation.
    return onnx_model
389+
309390 @staticmethod
310391 def compute_scales (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
311392 """Computes the scales for the weights in the ONNX model for NVFP4 quantization."""
0 commit comments