Skip to content

Commit e829998

Browse files
committed
Update tests
Signed-off-by: ajrasane <[email protected]>
1 parent 4dde5eb commit e829998

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tests/unit/onnx/test_qdq_utils.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,11 @@
1717
import pytest
1818
from onnx import TensorProto, helper, numpy_helper
1919

20+
from modelopt.onnx.export.quant_exporter import INT4QuantExporter
2021
from modelopt.onnx.quantization.qdq_utils import (
2122
_cast_fp4,
2223
_cast_fp8,
2324
fp4qdq_to_2dq,
24-
quantize_weights_to_int4,
2525
quantize_weights_to_mxfp8,
2626
)
2727

@@ -337,7 +337,9 @@ def test_basic_quantization_with_reshape_transpose(self):
337337
model = create_test_model_with_int4_dq_reshape_transpose_matmul()
338338

339339
# Run quantization
340-
quantized_model = quantize_weights_to_int4(model)
340+
quantized_model = INT4QuantExporter.compute_scales(model)
341+
quantized_model = INT4QuantExporter.compress_weights(quantized_model)
342+
quantized_model = INT4QuantExporter.post_process(quantized_model)
341343

342344
# Verify weight is converted to INT4
343345
weight_tensor = next(
@@ -362,7 +364,9 @@ def test_quantization_with_constant_scale(self):
362364
model = create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale=True)
363365

364366
# Run quantization
365-
quantized_model = quantize_weights_to_int4(model)
367+
quantized_model = INT4QuantExporter.compute_scales(model)
368+
quantized_model = INT4QuantExporter.compress_weights(quantized_model)
369+
quantized_model = INT4QuantExporter.post_process(quantized_model)
366370

367371
# Verify Constant node is removed
368372
constant_nodes = [node for node in quantized_model.graph.node if node.op_type == "Constant"]
@@ -385,7 +389,9 @@ def test_projection_bias_and_scale_casting(self):
385389
model = create_test_model_with_proj_nodes()
386390

387391
# Run quantization
388-
quantized_model = quantize_weights_to_int4(model)
392+
quantized_model = INT4QuantExporter.compute_scales(model)
393+
quantized_model = INT4QuantExporter.compress_weights(quantized_model)
394+
quantized_model = INT4QuantExporter.post_process(quantized_model)
389395

390396
# Verify bias tensor is cast to float16
391397
bias_tensor = next(

0 commit comments

Comments (0)