Skip to content

Commit b10f3f9

Browse files
committed
Fixes based on review comments.
1 parent cd4b36c commit b10f3f9

5 files changed

Lines changed: 47 additions & 2 deletions

File tree

backends/nxp/backend/ir/converter/node_converters/ops_converters/addmm_converter.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ def _is_supported_in_IR(
5151
# As these cases seem rare, conversion is not implemented for the time being.
5252
return False
5353

54+
# The `aten.addmm` operator allows any bias shape, as long as it is broadcastable with the result of the matrix
55+
# multiplication. That means it supports 4 different shapes: [N, P], [1, P], [P], [1] (provided the MM result
56+
# has shape [N, P]). Out of these 4, Neutron IR allows only [1, P] and [P], both of which are supported on
57+
# Neutron.
58+
bias_shape = list(node.args[BIAS_IDX].meta["val"].shape)
59+
_, p = node.meta["val"].shape
60+
if bias_shape not in [[1, p], [p]]:
61+
return False
62+
5463
return True
5564

5665
@staticmethod

backends/nxp/tests/generic_tests/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024-2025 NXP
1+
# Copyright 2024-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.

backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,34 @@ def test__from_addmm(self, mocker, request, use_qat, input_shape: tuple[int, ...
8686
model = AddmmModule(input_shape[-1])
8787
self.assert_delegated(model, input_shape, mocker, request, use_qat=use_qat)
8888

89+
@pytest.mark.parametrize(
90+
"bias_shape",
91+
[
92+
(1, 7),
93+
(7,),
94+
],
95+
ids=lambda s: f"bias_shape = {s}",
96+
)
97+
def test__from_addmm__bias_shapes__supported(
98+
self, mocker, request, use_qat, bias_shape: tuple[int, ...]
99+
):
100+
input_shape = (3, 11)
101+
model = AddmmModule(input_shape[-1], bias_shape=bias_shape)
102+
self.assert_delegated(model, input_shape, mocker, request, use_qat=use_qat)
103+
104+
@pytest.mark.parametrize(
105+
"bias_shape",
106+
[
107+
(3, 7),
108+
(1,),
109+
],
110+
ids=lambda s: f"bias_shape = {s}",
111+
)
112+
def test__from_addmm__bias_shapes__unsupported(self, bias_shape: tuple[int, ...]):
113+
input_shape = (3, 11)
114+
model = AddmmModule(input_shape[-1], bias_shape=bias_shape)
115+
self.assert_not_delegated(model, input_shape)
116+
89117
def test__from_addmm__unsupported_alpha(self):
90118
input_shape = (1, 8)
91119
model = AddmmModule(input_shape[-1], alpha=0.42)

backends/nxp/tests/models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,13 @@ def __init__(
258258
out_channels: int = 7,
259259
alpha: float | None = None,
260260
beta: float | None = None,
261+
bias_shape=None,
261262
):
263+
if bias_shape is None:
264+
bias_shape = (out_channels,)
262265
super().__init__()
263266
self.weight = torch.nn.Parameter(torch.empty(in_channels, out_channels))
264-
self.bias = torch.nn.Parameter(torch.empty(out_channels))
267+
self.bias = torch.nn.Parameter(torch.empty(bias_shape))
265268
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
266269
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
267270
bound = 1 / math.sqrt(fan_in)

backends/nxp/tests/use_qat.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Copyright 2026 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
16
import pytest
27

38

0 commit comments

Comments
 (0)