Skip to content

Commit babda33

Browse files
add tests
1 parent 79d3973 commit babda33

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

tests/onnx/common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,15 @@ def add_mul(self, input_a: str, input_b: str, output: Optional[str] = None) -> s
155155
)
156156
return output
157157

158+
def add_add(self, input_a: str, input_b: str, output: Optional[str] = None) -> str:
159+
i = len(self._nodes)
160+
161+
output = f"Add_{i}_output" if output is None else output
162+
self._nodes.append(
163+
onnx.helper.make_node(op_type="Add", inputs=[input_a, input_b], outputs=[output], name=f"Add_{i}")
164+
)
165+
return output
166+
158167
def add_relu(self, input: str, output: Optional[str] = None) -> str:
159168
i = len(self._nodes)
160169

tests/onnx/quantization/test_fast_bias_correction.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
import onnx
1515
import torch
1616

17+
from nncf.common.factory import ModelTransformerFactory
1718
from nncf.common.factory import NNCFGraphFactory
19+
from nncf.common.graph.transformations.layout import TransformationLayout
1820
from nncf.onnx.graph.node_utils import get_bias_value
1921
from nncf.onnx.graph.node_utils import is_node_with_bias
22+
from nncf.onnx.graph.transformations.command_creation import ONNXCommandCreator
2023
from nncf.quantization.algorithms.fast_bias_correction.onnx_backend import ONNXFastBiasCorrectionAlgoBackend
2124
from tests.cross_fw.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm
25+
from tests.onnx.common import ModelBuilder
2226

2327

2428
def get_data_from_node(model: onnx.ModelProto, node_name: str):
@@ -71,3 +75,47 @@ def check_bias(model: onnx.ModelProto, ref_bias: list):
7175
return
7276
msg = "Not found node with bias"
7377
raise ValueError(msg)
78+
79+
80+
def _build_matmul_add_model() -> onnx.ModelProto:
81+
mb = ModelBuilder()
82+
83+
x = mb.add_input("X", (2, 3))
84+
85+
x = mb.add_matmul(x, (3, 3))
86+
x = mb.add_add(x, mb.add_initializer(np.array([1, 1, 1], dtype=np.float32)))
87+
88+
x = mb.add_matmul(x, (3, 3))
89+
x = mb.add_add(x, mb.add_initializer(np.array([2, 2, 2], dtype=np.float32)))
90+
91+
mb.add_output(x, (2, 3))
92+
93+
return mb.build(opset_version=19, ir_version=9)
94+
95+
96+
def test_update_bias_in_matmul_add():
97+
"""
98+
Tests the ability to retrieve and update the value of the bias constant in a MatMul->Add subgraph,
99+
where the second input to the Add operation is a constant.
100+
"""
101+
model = _build_matmul_add_model()
102+
graph = NNCFGraphFactory.create(model)
103+
104+
nodes = [node for node in graph.get_all_nodes() if is_node_with_bias(node)]
105+
assert [x.node_name for x in nodes] == ["MatMul_0", "MatMul_2"]
106+
107+
for matmul, data in zip(nodes, [[1, 1, 1], [2, 2, 2]]):
108+
bias = get_bias_value(matmul, model)
109+
bias_ref = np.array(data, dtype=np.float32)
110+
assert np.all(np.isclose(bias, bias_ref, atol=0.0001)), f"{bias} != {bias_ref}"
111+
112+
layout = TransformationLayout()
113+
for matmul, data in zip(nodes, [[2, 2, 2], [1, 1, 1]]):
114+
new_bias = np.array(data, dtype=np.float32)
115+
layout.register(ONNXCommandCreator.create_command_to_update_bias(matmul, new_bias, graph))
116+
model = ModelTransformerFactory.create(model).transform(layout)
117+
118+
for matmul, data in zip(nodes, [[2, 2, 2], [1, 1, 1]]):
119+
bias = get_bias_value(matmul, model)
120+
bias_ref = np.array(data, dtype=np.float32)
121+
assert np.all(np.isclose(bias, bias_ref, atol=0.0001)), f"{bias} != {bias_ref}"

0 commit comments

Comments
 (0)