Skip to content

Commit 1b8ef9d

Browse files
authored
Cortex-M backend: Minimize scope of cmsis_nn dependency. (#20371)
For Ethos-U flows, you might want to run only the ReplaceQuantNodesPass pass which doesn't require the cmsis_nn dependency. Since the install is currently not trivial, we shouldn't force people to do it when not needed. Right now, all passes are imported when importing that pass, triggering importing cmsis_nn. - Only require cmsis_nn when cmsis_nn functions are used. - Do this by wrapping cmsis_nn, taking the chance to add typing. Tested by running from executorch.backends.cortex_m.passes import ReplaceQuantNodesPass Before the patch this triggers an error if cmsis_nn is not installed. After the patch, it doesn't. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani @psiddh @AdrianLundell Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent 66884b4 commit 1b8ef9d

10 files changed

Lines changed: 305 additions & 8 deletions

File tree

backends/cortex_m/TARGETS

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,23 @@ python_library(
2020
],
2121
)
2222

23+
python_library(
24+
name = "cmsis_nn",
25+
srcs = [
26+
"library/__init__.py",
27+
"library/cmsis_nn.py",
28+
],
29+
deps = [
30+
"fbsource//third-party/cmsis-nn:cmsis_nn_py",
31+
],
32+
)
33+
2334
python_library(
2435
name = "target_config",
2536
srcs = [
2637
"target_config.py",
2738
],
2839
deps = [
29-
"fbsource//third-party/cmsis-nn:cmsis_nn_py",
40+
":cmsis_nn",
3041
],
3142
)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from __future__ import annotations
8+
9+
from types import ModuleType
10+
from typing import Any, cast, ClassVar, Sequence, TYPE_CHECKING
11+
12+
_cmsis_nn: ModuleType | None = None
13+
_cmsis_nn_import_error: ModuleNotFoundError | None = None
14+
15+
16+
class _EnumValue:
17+
def __init__(self, enum_name: str, name: str, value: int) -> None:
18+
self._enum_name = enum_name
19+
self.name = name
20+
self.value = value
21+
22+
def __repr__(self) -> str:
23+
return f"<{self._enum_name}.{self.name}: {self.value}>"
24+
25+
def __str__(self) -> str:
26+
return f"{self._enum_name}.{self.name}"
27+
28+
29+
class Backend:
30+
MVE: ClassVar[Backend]
31+
DSP: ClassVar[Backend]
32+
SCALAR: ClassVar[Backend]
33+
34+
name: str
35+
value: int
36+
37+
38+
Backend.MVE = cast(Backend, _EnumValue("Backend", "MVE", 0))
39+
Backend.DSP = cast(Backend, _EnumValue("Backend", "DSP", 1))
40+
Backend.SCALAR = cast(Backend, _EnumValue("Backend", "SCALAR", 2))
41+
42+
43+
class CortexM:
44+
M0: ClassVar[CortexM]
45+
M0PLUS: ClassVar[CortexM]
46+
M3: ClassVar[CortexM]
47+
M4: ClassVar[CortexM]
48+
M7: ClassVar[CortexM]
49+
M23: ClassVar[CortexM]
50+
M33: ClassVar[CortexM]
51+
M35P: ClassVar[CortexM]
52+
M55: ClassVar[CortexM]
53+
M85: ClassVar[CortexM]
54+
55+
name: str
56+
value: int
57+
58+
59+
CortexM.M0 = cast(CortexM, _EnumValue("CortexM", "M0", 0))
60+
CortexM.M0PLUS = cast(CortexM, _EnumValue("CortexM", "M0PLUS", 1))
61+
CortexM.M3 = cast(CortexM, _EnumValue("CortexM", "M3", 2))
62+
CortexM.M4 = cast(CortexM, _EnumValue("CortexM", "M4", 3))
63+
CortexM.M7 = cast(CortexM, _EnumValue("CortexM", "M7", 4))
64+
CortexM.M23 = cast(CortexM, _EnumValue("CortexM", "M23", 5))
65+
CortexM.M33 = cast(CortexM, _EnumValue("CortexM", "M33", 6))
66+
CortexM.M35P = cast(CortexM, _EnumValue("CortexM", "M35P", 7))
67+
CortexM.M55 = cast(CortexM, _EnumValue("CortexM", "M55", 8))
68+
CortexM.M85 = cast(CortexM, _EnumValue("CortexM", "M85", 9))
69+
70+
71+
class DataType:
72+
A8W4: ClassVar[DataType]
73+
A8W8: ClassVar[DataType]
74+
A16W8: ClassVar[DataType]
75+
76+
name: str
77+
value: int
78+
79+
80+
DataType.A8W4 = cast(DataType, _EnumValue("DataType", "A8W4", 0))
81+
DataType.A8W8 = cast(DataType, _EnumValue("DataType", "A8W8", 1))
82+
DataType.A16W8 = cast(DataType, _EnumValue("DataType", "A16W8", 2))
83+
84+
85+
if not TYPE_CHECKING:
86+
try:
87+
import cmsis_nn as _real_cmsis_nn # type: ignore[import-not-found, import-untyped]
88+
except ModuleNotFoundError as exc:
89+
if exc.name != "cmsis_nn":
90+
raise
91+
_cmsis_nn_import_error = exc
92+
else:
93+
_cmsis_nn = _real_cmsis_nn
94+
Backend = _real_cmsis_nn.Backend
95+
CortexM = _real_cmsis_nn.CortexM
96+
DataType = _real_cmsis_nn.DataType
97+
98+
99+
def _missing_dependencies_error() -> ModuleNotFoundError:
100+
return ModuleNotFoundError(
101+
"Cortex-M backend dependencies are not installed. "
102+
"Install by running `examples/arm/setup.sh --i-agree-to-the-contained-eula`, "
103+
"or pip install from the CMSIS-NN repo."
104+
)
105+
106+
107+
def _require_cmsis_nn() -> ModuleType:
108+
if _cmsis_nn is None:
109+
raise _missing_dependencies_error() from _cmsis_nn_import_error
110+
return _cmsis_nn
111+
112+
113+
def resolve_backend(cpu: CortexM) -> Backend:
114+
return _require_cmsis_nn().resolve_backend(cpu)
115+
116+
117+
def convolve_wrapper_buffer_size(
118+
backend: Backend,
119+
data_type: DataType,
120+
*,
121+
input_nhwc: Sequence[int],
122+
filter_nhwc: Sequence[int],
123+
output_nhwc: Sequence[int],
124+
padding_hw: Sequence[int],
125+
stride_hw: Sequence[int],
126+
dilation_hw: Sequence[int],
127+
input_offset: int = 0,
128+
output_offset: int = 0,
129+
activation_min: int = -128,
130+
activation_max: int = 127,
131+
) -> int:
132+
return _require_cmsis_nn().convolve_wrapper_buffer_size(
133+
backend,
134+
data_type,
135+
input_nhwc=input_nhwc,
136+
filter_nhwc=filter_nhwc,
137+
output_nhwc=output_nhwc,
138+
padding_hw=padding_hw,
139+
stride_hw=stride_hw,
140+
dilation_hw=dilation_hw,
141+
input_offset=input_offset,
142+
output_offset=output_offset,
143+
activation_min=activation_min,
144+
activation_max=activation_max,
145+
)
146+
147+
148+
def depthwise_conv_wrapper_buffer_size(
149+
backend: Backend,
150+
data_type: DataType,
151+
*,
152+
input_nhwc: Sequence[int],
153+
filter_nhwc: Sequence[int],
154+
output_nhwc: Sequence[int],
155+
padding_hw: Sequence[int],
156+
stride_hw: Sequence[int],
157+
dilation_hw: Sequence[int],
158+
ch_mult: int,
159+
input_offset: int = 0,
160+
output_offset: int = 0,
161+
activation_min: int = -128,
162+
activation_max: int = 127,
163+
) -> int:
164+
return _require_cmsis_nn().depthwise_conv_wrapper_buffer_size(
165+
backend,
166+
data_type,
167+
input_nhwc=input_nhwc,
168+
filter_nhwc=filter_nhwc,
169+
output_nhwc=output_nhwc,
170+
padding_hw=padding_hw,
171+
stride_hw=stride_hw,
172+
dilation_hw=dilation_hw,
173+
ch_mult=ch_mult,
174+
input_offset=input_offset,
175+
output_offset=output_offset,
176+
activation_min=activation_min,
177+
activation_max=activation_max,
178+
)
179+
180+
181+
def fully_connected_buffer_size(
182+
backend: Backend,
183+
data_type: DataType,
184+
*,
185+
filter_nhwc: Sequence[int],
186+
) -> int:
187+
return _require_cmsis_nn().fully_connected_buffer_size(
188+
backend,
189+
data_type,
190+
filter_nhwc=filter_nhwc,
191+
)
192+
193+
194+
def transpose_conv_buffer_size(
195+
backend: Backend,
196+
data_type: DataType,
197+
*,
198+
input_nhwc: Sequence[int],
199+
filter_nhwc: Sequence[int],
200+
output_nhwc: Sequence[int],
201+
padding_hw: Sequence[int],
202+
stride_hw: Sequence[int],
203+
dilation_hw: Sequence[int],
204+
padding_offsets_hw: Sequence[int] = (0, 0),
205+
input_offset: int = 0,
206+
output_offset: int = 0,
207+
activation_min: int = -128,
208+
activation_max: int = 127,
209+
) -> int:
210+
return _require_cmsis_nn().transpose_conv_buffer_size(
211+
backend,
212+
data_type,
213+
input_nhwc=input_nhwc,
214+
filter_nhwc=filter_nhwc,
215+
output_nhwc=output_nhwc,
216+
padding_hw=padding_hw,
217+
stride_hw=stride_hw,
218+
dilation_hw=dilation_hw,
219+
padding_offsets_hw=padding_offsets_hw,
220+
input_offset=input_offset,
221+
output_offset=output_offset,
222+
activation_min=activation_min,
223+
activation_max=activation_max,
224+
)
225+
226+
227+
def transpose_conv_reverse_conv_buffer_size(
228+
backend: Backend,
229+
data_type: DataType,
230+
*,
231+
input_nhwc: Sequence[int],
232+
filter_nhwc: Sequence[int],
233+
padding_hw: Sequence[int],
234+
stride_hw: Sequence[int],
235+
dilation_hw: Sequence[int] = (1, 1),
236+
padding_offsets_hw: Sequence[int] = (0, 0),
237+
input_offset: int = 0,
238+
output_offset: int = 0,
239+
activation_min: int = -128,
240+
activation_max: int = 127,
241+
) -> int:
242+
return _require_cmsis_nn().transpose_conv_reverse_conv_buffer_size(
243+
backend,
244+
data_type,
245+
input_nhwc=input_nhwc,
246+
filter_nhwc=filter_nhwc,
247+
padding_hw=padding_hw,
248+
stride_hw=stride_hw,
249+
dilation_hw=dilation_hw,
250+
padding_offsets_hw=padding_offsets_hw,
251+
input_offset=input_offset,
252+
output_offset=output_offset,
253+
activation_min=activation_min,
254+
activation_max=activation_max,
255+
)
256+
257+
258+
def avgpool_buffer_size(
259+
backend: Backend,
260+
data_type: DataType,
261+
*,
262+
dim_dst_width: int,
263+
ch_src: int,
264+
) -> int:
265+
return _require_cmsis_nn().avgpool_buffer_size(
266+
backend,
267+
data_type,
268+
dim_dst_width=dim_dst_width,
269+
ch_src=ch_src,
270+
)
271+
272+
273+
def __getattr__(name: str) -> Any:
274+
return getattr(_require_cmsis_nn(), name)
275+
276+
277+
def __dir__() -> list[str]:
278+
cmsis_names = set() if _cmsis_nn is None else set(dir(_cmsis_nn))
279+
return sorted(set(globals()) | cmsis_names)

backends/cortex_m/passes/BUCK

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
33
# All rights reserved.
4+
# Copyright 2026 Arm Limited and/or its affiliates.
45
#
56
# This source code is licensed under the BSD-style license found in the
67
# LICENSE file in the root directory of this source tree.
@@ -40,6 +41,7 @@ fbcode_target(_kind = runtime.python_library,
4041
deps=[
4142
"//caffe2:torch",
4243
"//executorch/backends/arm/_passes:passes",
44+
"//executorch/backends/cortex_m:cmsis_nn",
4345
"//executorch/backends/cortex_m:target_config",
4446
"//executorch/backends/cortex_m/ops:ops",
4547
"//executorch/backends/cortex_m/passes:passes_utils",

backends/cortex_m/passes/aten_to_cortex_m_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
import math
99
from typing import cast, Optional
1010

11-
import cmsis_nn # type: ignore[import-not-found, import-untyped]
1211
import executorch.backends.cortex_m.ops.operators # noqa
1312
import executorch.exir as exir
1413
import torch
1514
import torch.fx
1615
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
16+
from executorch.backends.cortex_m.library import cmsis_nn
1717

1818
from executorch.backends.cortex_m.passes.passes_utils import (
1919
build_activation_lut,

backends/cortex_m/passes/scratch_buffer_sizes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from collections.abc import Callable
77
from typing import Any, cast
88

9-
import cmsis_nn # type: ignore[import-not-found, import-untyped]
109
import executorch.backends.cortex_m.ops.operators # noqa
1110

1211
import torch
1312
import torch.fx
13+
from executorch.backends.cortex_m.library import cmsis_nn
1414

1515
from executorch.exir.dialects._ops import ops as exir_ops
1616

backends/cortex_m/target_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -10,7 +11,7 @@
1011
from enum import auto, Enum
1112
from typing import Optional
1213

13-
import cmsis_nn # type: ignore[import-not-found, import-untyped]
14+
from executorch.backends.cortex_m.library import cmsis_nn
1415

1516

1617
class CortexM(Enum):

backends/cortex_m/test/misc/test_cmsis_pybind.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2026 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -11,7 +10,7 @@
1110

1211
def _import_cmsis_nn():
1312
try:
14-
return importlib.import_module("cmsis_nn")
13+
return importlib.import_module("executorch.backends.cortex_m.library.cmsis_nn")
1514
except Exception as exc:
1615
pytest.fail(f"Failed to resolve cmsis_nn: {exc}")
1716

backends/cortex_m/test/misc/test_target_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
67

7-
import cmsis_nn # type: ignore[import-not-found, import-untyped]
88
import pytest
99

10+
from executorch.backends.cortex_m.library import cmsis_nn
1011
from executorch.backends.cortex_m.target_config import CortexM, CortexMTargetConfig
1112

1213

0 commit comments

Comments
 (0)