|
14 | 14 | import onnx |
15 | 15 | import torch |
16 | 16 |
|
| 17 | +from nncf.common.factory import ModelTransformerFactory |
17 | 18 | from nncf.common.factory import NNCFGraphFactory |
| 19 | +from nncf.common.graph.transformations.layout import TransformationLayout |
18 | 20 | from nncf.onnx.graph.node_utils import get_bias_value |
19 | 21 | from nncf.onnx.graph.node_utils import is_node_with_bias |
| 22 | +from nncf.onnx.graph.transformations.command_creation import ONNXCommandCreator |
20 | 23 | from nncf.quantization.algorithms.fast_bias_correction.onnx_backend import ONNXFastBiasCorrectionAlgoBackend |
21 | 24 | from tests.cross_fw.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm |
| 25 | +from tests.onnx.common import ModelBuilder |
22 | 26 |
|
23 | 27 |
|
24 | 28 | def get_data_from_node(model: onnx.ModelProto, node_name: str): |
@@ -71,3 +75,47 @@ def check_bias(model: onnx.ModelProto, ref_bias: list): |
71 | 75 | return |
72 | 76 | msg = "Not found node with bias" |
73 | 77 | raise ValueError(msg) |
| 78 | + |
| 79 | + |
| 80 | +def _build_matmul_add_model() -> onnx.ModelProto: |
| 81 | + mb = ModelBuilder() |
| 82 | + |
| 83 | + x = mb.add_input("X", (2, 3)) |
| 84 | + |
| 85 | + x = mb.add_matmul(x, (3, 3)) |
| 86 | + x = mb.add_add(x, mb.add_initializer(np.array([1, 1, 1], dtype=np.float32))) |
| 87 | + |
| 88 | + x = mb.add_matmul(x, (3, 3)) |
| 89 | + x = mb.add_add(x, mb.add_initializer(np.array([2, 2, 2], dtype=np.float32))) |
| 90 | + |
| 91 | + mb.add_output(x, (2, 3)) |
| 92 | + |
| 93 | + return mb.build(opset_version=19, ir_version=9) |
| 94 | + |
| 95 | + |
| 96 | +def test_update_bias_in_matmul_add(): |
| 97 | + """ |
| 98 | + Tests the ability to retrieve and update the value of the bias constant in a MatMul->Add subgraph, |
| 99 | + where the second input to the Add operation is a constant. |
| 100 | + """ |
| 101 | + model = _build_matmul_add_model() |
| 102 | + graph = NNCFGraphFactory.create(model) |
| 103 | + |
| 104 | + nodes = [node for node in graph.get_all_nodes() if is_node_with_bias(node)] |
| 105 | + assert [x.node_name for x in nodes] == ["MatMul_0", "MatMul_2"] |
| 106 | + |
| 107 | + for matmul, data in zip(nodes, [[1, 1, 1], [2, 2, 2]]): |
| 108 | + bias = get_bias_value(matmul, model) |
| 109 | + bias_ref = np.array(data, dtype=np.float32) |
| 110 | + assert np.all(np.isclose(bias, bias_ref, atol=0.0001)), f"{bias} != {bias_ref}" |
| 111 | + |
| 112 | + layout = TransformationLayout() |
| 113 | + for matmul, data in zip(nodes, [[2, 2, 2], [1, 1, 1]]): |
| 114 | + new_bias = np.array(data, dtype=np.float32) |
| 115 | + layout.register(ONNXCommandCreator.create_command_to_update_bias(matmul, new_bias, graph)) |
| 116 | + model = ModelTransformerFactory.create(model).transform(layout) |
| 117 | + |
| 118 | + for matmul, data in zip(nodes, [[2, 2, 2], [1, 1, 1]]): |
| 119 | + bias = get_bias_value(matmul, model) |
| 120 | + bias_ref = np.array(data, dtype=np.float32) |
| 121 | + assert np.all(np.isclose(bias, bias_ref, atol=0.0001)), f"{bias} != {bias_ref}" |
0 commit comments