Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/nncf/onnx/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
OPERATIONS_WITH_BIAS_REDUCED = [
onnx_metatypes.ONNXConvolutionMetatype,
onnx_metatypes.ONNXGemmMetatype,
# TODO: Need to add MatMul with the separate bias support (CVS-135433)
onnx_metatypes.ONNXMatMulMetatype,
]

OPERATIONS_WITH_BIAS = [
Expand Down
1 change: 0 additions & 1 deletion src/nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ class ONNXMatMulMetatype(ONNXOpWithWeightsMetatype):
op_names = ["MatMul"]
hw_config_names = [HWConfigOpName.MATMUL]
weight_channel_axis = -1 # For port_id=1
bias_port_id = 2
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `bias_port_id` attribute does not apply to the ONNX MatMul operation, as MatMul does not accept a bias input. Reference: https://onnx.ai/onnx/operators/onnx__MatMul.html

possible_weight_ports = [0, 1]
output_channel_axis = -1

Expand Down
44 changes: 42 additions & 2 deletions src/nncf/onnx/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS
from nncf.onnx.graph.metatypes.groups import POSSIBLE_WEIGHT_LAYER_METATYPES
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import get_metatype
Expand Down Expand Up @@ -186,17 +187,54 @@ def _get_bias_attr(
node: onnx.NodeProto,
model: onnx.ModelProto,
parents_node_mapping: dict[str, onnx.NodeProto],
children_node_mapping: dict[str, onnx.NodeProto],
) -> dict[str, str]:
"""
Returns bias tensor attributes.

:param node: ONNX node.
:param model: ONNX model.
:param parents_node_mapping: Mapping from edge name to node which outputs this edge.
:param children_node_mapping: mapping from edge name to nodes which consume this edge as an input.
:return: Bias tensor attributes.
"""
bias_attrs = {}
metatype = get_metatype(model, node)

if metatype == ONNXMatMulMetatype:
weight_port_ids = _get_weight_port_ids(node, model, parents_node_mapping)

if not weight_port_ids:
# `node` is a MatMul without weights, so return empty attributes
return {}

# Retrieve all nodes that consume the output of the MatMul operation.
# The MatMul operation has only one output.
y = node.output[0]
consumers = children_node_mapping[y]

if len(consumers) != 1 or consumers[0].op_type != "Add":
return {}

# Here, we are certain that after a `MatMul` operation, there is only
# the `Add` operation.
add_node = consumers[0]

# Find the input of `add_node` that is not equal to `y`.
tensor_name = None
port_id = None
for i, name in enumerate(add_node.input):
if name != y:
tensor_name = name
port_id = i
break

# Ensure that `tensor_name` is the output of a `Constant` node or an initializer.
initializer = {x.name: x for x in model.graph.initializer}
if tensor_name in initializer or parents_node_mapping[tensor_name].op_type == "Constant":
return {"node": add_node.name, "name": tensor_name, "port_id": port_id}
return {}

bias_attrs = {}
if _is_node_with_bias(node, model):
bias_tensor_port_id = get_bias_tensor_port_id(metatype)
bias_edge_name = get_tensor_edge_name(model, node, bias_tensor_port_id, parents_node_mapping)
Expand Down Expand Up @@ -348,6 +386,7 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
"""
onnx_model = GraphConverter._replace_empty_node_name(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)

edge_info_mapping = get_edge_info_mapping(onnx_model)
children_node_mapping = get_children_node_mapping(onnx_model)
parents_node_mapping = get_parents_node_mapping(onnx_model)
Expand All @@ -358,7 +397,8 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
is_shared = None
weight_attrs = {}
node_attrs = _get_node_attrs(node, onnx_model)
bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping)
bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping, children_node_mapping)

if weight_port_ids: # If node has weight
weight_edge_names = []
for weight_port_id in weight_port_ids:
Expand Down
1 change: 1 addition & 0 deletions src/nncf/onnx/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def get_bias_value(node_with_bias: NNCFNode, model: onnx.ModelProto) -> np.ndarr
:return: The bias value that is applied to the output tensor of the node's operation.
"""
assert node_with_bias.layer_attributes.has_bias()
# TODO(andrey-churkin): Support Add + Constant case
bias_name = node_with_bias.layer_attributes.bias_attrs["name"]
return get_tensor_value(model, bias_name)

Expand Down
9 changes: 7 additions & 2 deletions src/nncf/onnx/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ def create_bias_correction_command(node: NNCFNode, bias_value: np.ndarray) -> ON
:param bias_value: The new bias value that will be set.
:return: The `ONNXInitializerUpdateCommand` command to update bias.
"""
bias_port_id = node.metatype.bias_port_id
target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id)
node_name = node.layer_attributes.bias_attrs.get("node")
if node_name:
port_id = node.layer_attributes.bias_attrs["port_id"]
target_point = ONNXTargetPoint(TargetType.LAYER, node_name, port_id)
else:
bias_port_id = node.metatype.bias_port_id
target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id)
return ONNXInitializerUpdateCommand(target_point, bias_value)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,11 @@ def apply(

output_channel_axis = node.metatype.output_channel_axis
input_channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port_id, input_shape)
if bias_value.ndim > 1:
# Make index positive
output_channel_axis = range(bias_value.ndim)[output_channel_axis]
input_channel_axis = range(bias_value.ndim)[input_channel_axis]
Comment on lines -178 to -181
Copy link
Contributor Author

@andrey-churkin andrey-churkin Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conversion of the channel axis to a positive index was moved into `create_input_data()`, because that method does not work properly with a negative `channel_axis`. The `output_channel_axis` is converted to a positive index inside the `mean_per_channel()` method.


input_blob = self._backend_entity.create_input_data(
input_shape, input_fp, sub_input_name, input_channel_axis
)

bias_shift = self._get_bias_shift(
model=extracted_model,
input_blob=input_blob,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def get_sub_input_output_names(subgraph: onnx.ModelProto) -> tuple[str, str]:
def create_input_data(
shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int
) -> dict[str, np.array]:
channel_axis = range(len(shape))[channel_axis]
blob = np.zeros(shape, dtype=data[0].data.dtype)
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_sub_input_output_names(subgraph: ov.Model) -> tuple[str, str]:
def create_input_data(
shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int
) -> dict[str, np.ndarray]:
channel_axis = range(len(shape))[channel_axis]
blob = np.zeros(shape, dtype=data[0].data.dtype)
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_sub_input_output_names(subgraph: NNCFNetwork) -> tuple[Optional[str], Op

@staticmethod
def create_input_data(shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int) -> torch.Tensor:
channel_axis = range(len(shape))[channel_axis]
blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device)
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def get_sub_input_output_names(subgraph: torch.fx.GraphModule) -> tuple[Optional

@staticmethod
def create_input_data(shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int) -> torch.Tensor:
channel_axis = range(len(shape))[channel_axis]
blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device)
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
Expand Down
9 changes: 9 additions & 0 deletions tests/onnx/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ def add_mul(self, input_a: str, input_b: str, output: Optional[str] = None) -> s
)
return output

def add_add(self, input_a: str, input_b: str, output: Optional[str] = None) -> str:
    """
    Appends an ONNX `Add` node to the list of nodes collected by the builder.

    The node is named `Add_<i>`, where `<i>` is the number of nodes collected
    before this call.

    :param input_a: Name of the first input edge of the `Add` node.
    :param input_b: Name of the second input edge of the `Add` node.
    :param output: Name of the output edge. When `None`, the name
        `Add_<i>_output` is generated.
    :return: Name of the output edge of the created `Add` node.
    """
    i = len(self._nodes)

    if output is None:
        output = f"Add_{i}_output"

    add_node = onnx.helper.make_node(
        op_type="Add",
        inputs=[input_a, input_b],
        outputs=[output],
        name=f"Add_{i}",
    )
    self._nodes.append(add_node)
    return output

def add_relu(self, input: str, output: Optional[str] = None) -> str:
i = len(self._nodes)

Expand Down
48 changes: 48 additions & 0 deletions tests/onnx/quantization/test_fast_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
import onnx
import torch

from nncf.common.factory import ModelTransformerFactory
from nncf.common.factory import NNCFGraphFactory
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.onnx.graph.node_utils import get_bias_value
from nncf.onnx.graph.node_utils import is_node_with_bias
from nncf.onnx.graph.transformations.command_creation import ONNXCommandCreator
from nncf.quantization.algorithms.fast_bias_correction.onnx_backend import ONNXFastBiasCorrectionAlgoBackend
from tests.cross_fw.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm
from tests.onnx.common import ModelBuilder


def get_data_from_node(model: onnx.ModelProto, node_name: str):
Expand Down Expand Up @@ -71,3 +75,47 @@ def check_bias(model: onnx.ModelProto, ref_bias: list):
return
msg = "Not found node with bias"
raise ValueError(msg)


def _build_matmul_add_model() -> onnx.ModelProto:
    """
    Builds a test model made of two consecutive MatMul -> Add pairs.

    The second input of each `Add` is an initializer: `[1, 1, 1]` for the first
    pair and `[2, 2, 2]` for the second one.

    :return: The model produced by `ModelBuilder.build()` with opset version 19 and IR version 9.
    """
    builder = ModelBuilder()

    model_input = builder.add_input("X", (2, 3))

    # First MatMul -> Add pair; the Add constant is filled with ones.
    first_matmul = builder.add_matmul(model_input, (3, 3))
    first_bias = builder.add_initializer(np.array([1, 1, 1], dtype=np.float32))
    first_add = builder.add_add(first_matmul, first_bias)

    # Second MatMul -> Add pair; the Add constant is filled with twos.
    second_matmul = builder.add_matmul(first_add, (3, 3))
    second_bias = builder.add_initializer(np.array([2, 2, 2], dtype=np.float32))
    second_add = builder.add_add(second_matmul, second_bias)

    builder.add_output(second_add, (2, 3))

    return builder.build(opset_version=19, ir_version=9)


def test_update_bias_in_matmul_add():
    """
    Checks that the bias constant of a MatMul->Add subgraph can be read and updated
    when the second input of the Add operation is a constant.

    The test reads the initial bias values, swaps them between the two MatMul nodes
    through bias update commands, and verifies the values read from the transformed model.
    """
    initial_biases = [[1, 1, 1], [2, 2, 2]]
    swapped_biases = [[2, 2, 2], [1, 1, 1]]

    model = _build_matmul_add_model()
    graph = NNCFGraphFactory.create(model)

    nodes = [node for node in graph.get_all_nodes() if is_node_with_bias(node)]
    assert [node.node_name for node in nodes] == ["MatMul_0", "MatMul_2"]

    def _check_biases(current_model, expected_values):
        # Compares the bias of every collected node against the expected values, in order.
        for matmul, data in zip(nodes, expected_values):
            bias = get_bias_value(matmul, current_model)
            bias_ref = np.array(data, dtype=np.float32)
            assert np.all(np.isclose(bias, bias_ref, atol=0.0001)), f"{bias} != {bias_ref}"

    _check_biases(model, initial_biases)

    # Register one bias update command per node, then apply all of them at once.
    layout = TransformationLayout()
    for matmul, data in zip(nodes, swapped_biases):
        new_bias = np.array(data, dtype=np.float32)
        command = ONNXCommandCreator.create_command_to_update_bias(matmul, new_bias, graph)
        layout.register(command)
    model = ModelTransformerFactory.create(model).transform(layout)

    _check_biases(model, swapped_biases)
2 changes: 1 addition & 1 deletion tests/post_training/data/ptq_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ timm/deit3_small_patch16_224_backend_CUDA_TORCH:
timm/deit3_small_patch16_224_backend_FP32:
metric_value: 0.81358
timm/deit3_small_patch16_224_backend_ONNX:
metric_value: 0.81116
metric_value: 0.81156
timm/deit3_small_patch16_224_backend_OV:
metric_value: 0.81276
timm/deit3_small_patch16_224_backend_TORCH:
Expand Down