Skip to content

Commit a4c3e31

Browse files
committed
Update tests
Signed-off-by: ajrasane <[email protected]>
1 parent 1fc6e83 commit a4c3e31

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onn
10511051
return onnx_model
10521052

10531053

1054-
def _cast_initializer_to_dtype(
1054+
def cast_initializer_to_dtype(
10551055
node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
10561056
):
10571057
"""Casts the initializer to the given dtype."""

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,18 @@
3232
from torch.nn.parallel import DataParallel, DistributedDataParallel
3333

3434
from modelopt.onnx.autocast.convert import convert_to_f16
35-
from modelopt.onnx.quantization.qdq_utils import (
36-
fp4qdq_to_2dq,
37-
qdq_to_dq,
38-
quantize_weights_to_int4,
39-
quantize_weights_to_mxfp8,
40-
replace_zero_scale_with_smallest_nonzero,
41-
)
4235
from modelopt.onnx.export.quant_exporter import (
4336
INT4QuantExporter,
4437
MXFP8QuantExporter,
4538
NVFP4QuantExporter,
4639
ONNXQuantExporter,
4740
)
48-
from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq, qdq_to_dq, quantize_weights_to_mxfp8
41+
from modelopt.onnx.quantization.qdq_utils import (
42+
fp4qdq_to_2dq,
43+
qdq_to_dq,
44+
quantize_weights_to_mxfp8,
45+
replace_zero_scale_with_smallest_nonzero,
46+
)
4947
from modelopt.onnx.utils import (
5048
get_input_names,
5149
get_input_shapes,
@@ -368,6 +366,8 @@ def is_fp8_quantized(model: nn.Module) -> bool:
368366
):
369367
return True
370368
return False
369+
370+
371371
def quantize_weights(model: nn.Module, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
372372
"""Real quantizes the weights in the onnx model.
373373

tests/unit/onnx/test_qdq_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
import pytest
1818
from onnx import TensorProto, helper, numpy_helper
1919

20+
from modelopt.onnx.export.quant_exporter import INT4QuantExporter
2021
from modelopt.onnx.quantization.qdq_utils import (
2122
_cast_fp4,
2223
_cast_fp8,
2324
fp4qdq_to_2dq,
24-
quantize_weights_to_int4,
2525
quantize_weights_to_mxfp8,
2626
)
2727

@@ -337,7 +337,9 @@ def test_basic_quantization_with_reshape_transpose(self):
337337
model = create_test_model_with_int4_dq_reshape_transpose_matmul()
338338

339339
# Run quantization
340-
quantized_model = quantize_weights_to_int4(model)
340+
quantized_model = INT4QuantExporter.compute_scales(model)
341+
quantized_model = INT4QuantExporter.compress_weights(quantized_model)
342+
quantized_model = INT4QuantExporter.post_process(quantized_model)
341343

342344
# Verify weight is converted to INT4
343345
weight_tensor = next(
@@ -362,7 +364,9 @@ def test_quantization_with_constant_scale(self):
362364
model = create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale=True)
363365

364366
# Run quantization
365-
quantized_model = quantize_weights_to_int4(model)
367+
quantized_model = INT4QuantExporter.compute_scales(model)
368+
quantized_model = INT4QuantExporter.compress_weights(quantized_model)
369+
quantized_model = INT4QuantExporter.post_process(quantized_model)
366370

367371
# Verify Constant node is removed
368372
constant_nodes = [node for node in quantized_model.graph.node if node.op_type == "Constant"]
@@ -385,7 +389,9 @@ def test_projection_bias_and_scale_casting(self):
385389
model = create_test_model_with_proj_nodes()
386390

387391
# Run quantization
388-
quantized_model = quantize_weights_to_int4(model)
392+
quantized_model = INT4QuantExporter.compute_scales(model)
393+
quantized_model = INT4QuantExporter.compress_weights(quantized_model)
394+
quantized_model = INT4QuantExporter.post_process(quantized_model)
389395

390396
# Verify bias tensor is cast to float16
391397
bias_tensor = next(

0 commit comments

Comments
 (0)