Skip to content

Commit 8cc5dcf

Browse files
rascaniclaude
andcommitted
Cortex-M backend: address review feedback on quantized_activation
Adrian's three review comments on #19792, plus SIMD acceleration of the LUT lookup (his comment asked for vector intrinsics and loop unrolling): * Drop the target -> string indirection in the activation lowering. `passes_utils._ACTIVATION_FNS` now keys directly on the edge op target (`exir_ops.edge.aten.{sigmoid,tanh,silu}.default`), and `ConvertToCortexMPass._get_activation_replacement` passes `node.target` straight into `build_activation_lut` -- no `_ACTIVATION_KINDS` dict and no string round-trip. * Replace the scalar LUT-lookup loop with three compile-gated paths: - M55/M85 (MVE): 16 lanes per iteration -- `vldrbq_u8` load, `vaddq_n_u8` to bias by 128, `vldrbq_gather_offset_s8` to gather the LUT result, `vstrbq_s8` to store. - M4/M7 (DSP, no MVE): 4 bytes per iteration -- fold four byte-loads into one word-load, batch the +128 bias with `__uadd8`, four LUT lookups (no M-class gather instruction exists), fold four byte-stores into one word-store. Uses `<arm_acle.h>` and local memcpy helpers rather than pulling in the heavyweight `arm_nnsupportfunctions.h`. - All other cores (M0+/M3): a 4x-unrolled scalar tail, which also handles the sub-vector remainder of the two SIMD paths. * Switch the source header to Meta's standard copyright block to match the other cortex_m op files. The three paths were cross-compiled for cortex-m0plus / m4 / m7 / m55; the M4 build emits `uadd8` and the M55 build emits the MVE gather. Runtime correctness on M4/M7 hardware/FVP is not yet exercised by CI -- the host unit tests cover the scalar path only. Co-authored-by: Claude <noreply@anthropic.com>
1 parent 82d9c15 commit 8cc5dcf

3 files changed

Lines changed: 95 additions & 21 deletions

File tree

backends/cortex_m/ops/op_quantized_activation.cpp

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,46 @@
11
/*
2-
* Copyright 2026 Arm Limited and/or its affiliates.
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
34
*
45
* This source code is licensed under the BSD-style license found in the
56
* LICENSE file in the root directory of this source tree.
67
*/
78

89
#include "cortex_m_ops_common.h"
910

11+
#include <cstring>
12+
13+
#if defined(__ARM_FEATURE_MVE) && (__ARM_FEATURE_MVE & 1)
14+
#include <arm_mve.h>
15+
#define HAS_HELIUM_SIMD 1
16+
#endif
17+
18+
#if defined(ARM_MATH_DSP) && !defined(HAS_HELIUM_SIMD)
19+
#include <arm_acle.h>
20+
#define HAS_DSP_PACKED_LUT 1
21+
#endif
22+
1023
namespace cortex_m {
1124
namespace native {
1225

26+
#if defined(HAS_DSP_PACKED_LUT)
27+
// Local 4-byte read/write helpers. We deliberately don't include
28+
// `arm_nnsupportfunctions.h` for the equivalent CMSIS-NN `arm_nn_read_s8x4_ia`
29+
// / `arm_nn_write_s8x4_ia` -- the header is public but pulls in the entire
30+
// CMSIS-NN support surface (~1500 lines) just for two memcpy wrappers.
31+
static inline uint32_t read_u8x4_ia(const int8_t** in) {
32+
uint32_t val;
33+
std::memcpy(&val, *in, 4);
34+
*in += 4;
35+
return val;
36+
}
37+
38+
static inline void write_u8x4_ia(int8_t** out, uint32_t val) {
39+
std::memcpy(*out, &val, 4);
40+
*out += 4;
41+
}
42+
#endif
43+
1344
// cppcheck-suppress unusedFunction
1445
Tensor& quantized_activation_out(
1546
KernelRuntimeContext& context,
@@ -37,12 +68,59 @@ Tensor& quantized_activation_out(
3768
const int8_t* lut_data = lut.const_data_ptr<int8_t>();
3869
int8_t* out_data = out.mutable_data_ptr<int8_t>();
3970

40-
// Bias the signed int8 input by 128 to use it as an unsigned table index;
41-
// the LUT entries are precomputed AoT from the input/output qparams and the
71+
// The LUT is precomputed AoT from the input/output qparams and the
4272
// activation function (sigmoid / tanh / silu / ...), so the kernel does not
43-
// need to know which activation it is implementing.
73+
// need to know which activation it is implementing. The signed int8 input
74+
// is biased by 128 to use it as an unsigned [0, 255] table index.
4475
const int64_t n = input.numel();
45-
for (int64_t i = 0; i < n; ++i) {
76+
int64_t i = 0;
77+
78+
#if defined(HAS_HELIUM_SIMD)
79+
// M55/M85: 16 lanes per iteration. Reinterpret the int8 input as uint8
80+
// (bit-identical load), add 128 mod 256 to produce a uint8 LUT index, then
81+
// gather-load the int8 result from the LUT.
82+
for (; i + 15 < n; i += 16) {
83+
uint8x16_t in_u8 =
84+
vldrbq_u8(reinterpret_cast<const uint8_t*>(in_data + i));
85+
uint8x16_t idx = vaddq_n_u8(in_u8, 128);
86+
int8x16_t result = vldrbq_gather_offset_s8(lut_data, idx);
87+
vstrbq_s8(out_data + i, result);
88+
}
89+
#elif defined(HAS_DSP_PACKED_LUT)
90+
// M4/M7 (DSP, no MVE): process 4 bytes per iteration. The DSP win comes from
91+
// (a) folding 4 byte-loads into one word-load, (b) batching the +128 bias
92+
// with `__uadd8`, and (c) folding 4 byte-stores into one word-store. The
93+
// LUT lookups themselves still hit memory four times per word -- no DSP
94+
// gather instruction exists on M-class.
95+
const int8_t* in_ptr = in_data;
96+
int8_t* out_ptr = out_data;
97+
const int64_t word_iters = n >> 2;
98+
for (int64_t w = 0; w < word_iters; ++w) {
99+
const uint32_t in_word = read_u8x4_ia(&in_ptr);
100+
const uint32_t idx_word = __uadd8(in_word, 0x80808080u);
101+
const uint32_t out_word =
102+
static_cast<uint32_t>(static_cast<uint8_t>(lut_data[idx_word & 0xFFu])) |
103+
(static_cast<uint32_t>(static_cast<uint8_t>(lut_data[(idx_word >> 8) & 0xFFu]))
104+
<< 8) |
105+
(static_cast<uint32_t>(static_cast<uint8_t>(lut_data[(idx_word >> 16) & 0xFFu]))
106+
<< 16) |
107+
(static_cast<uint32_t>(static_cast<uint8_t>(lut_data[(idx_word >> 24) & 0xFFu]))
108+
<< 24);
109+
write_u8x4_ia(&out_ptr, out_word);
110+
}
111+
i = word_iters << 2;
112+
#endif
113+
114+
// 4x-unrolled scalar tail. On M-class cores without MVE or DSP the unroll
115+
// lets the compiler issue independent LUT loads; on the MVE / DSP paths
116+
// above this only runs for the < 16- (or < 4-) element remainder.
117+
for (; i + 3 < n; i += 4) {
118+
out_data[i + 0] = lut_data[static_cast<uint8_t>(in_data[i + 0] + 128)];
119+
out_data[i + 1] = lut_data[static_cast<uint8_t>(in_data[i + 1] + 128)];
120+
out_data[i + 2] = lut_data[static_cast<uint8_t>(in_data[i + 2] + 128)];
121+
out_data[i + 3] = lut_data[static_cast<uint8_t>(in_data[i + 3] + 128)];
122+
}
123+
for (; i < n; ++i) {
46124
out_data[i] = lut_data[static_cast<uint8_t>(in_data[i] + 128)];
47125
}
48126

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -486,12 +486,6 @@ def _get_bmm_replacement(self, node):
486486
)
487487
return exir_ops.edge.cortex_m.quantized_batch_matmul.default, args
488488

489-
_ACTIVATION_KINDS = {
490-
exir_ops.edge.aten.sigmoid.default: "sigmoid",
491-
exir_ops.edge.aten.tanh.default: "tanh",
492-
exir_ops.edge.aten.silu.default: "silu",
493-
}
494-
495489
def _get_activation_replacement(self, node):
496490
"""Lower a standalone quantized sigmoid / tanh / silu to a single
497491
cortex_m.quantized_activation call backed by an AoT-built 256-entry
@@ -500,9 +494,8 @@ def _get_activation_replacement(self, node):
500494
"""
501495
input_qparams = node.meta["input_qparams"][0]
502496
output_qparams = node.meta["output_qparams"][0]
503-
kind = self._ACTIVATION_KINDS[node.target]
504497
lut_tensor = build_activation_lut(
505-
kind,
498+
node.target,
506499
float(input_qparams.scale),
507500
int(input_qparams.zp),
508501
float(output_qparams.scale),

backends/cortex_m/passes/passes_utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,9 @@ def _stable_silu(x: float) -> float:
205205

206206

207207
_ACTIVATION_FNS = {
208-
"sigmoid": _stable_sigmoid,
209-
"tanh": math.tanh,
210-
"silu": _stable_silu,
208+
exir_ops.edge.aten.sigmoid.default: _stable_sigmoid,
209+
exir_ops.edge.aten.tanh.default: math.tanh,
210+
exir_ops.edge.aten.silu.default: _stable_silu,
211211
}
212212

213213

@@ -220,25 +220,28 @@ def _round_half_away_from_zero(x: float) -> int:
220220

221221

222222
def build_activation_lut(
223-
kind: str,
223+
target,
224224
input_scale: float,
225225
input_zp: int,
226226
output_scale: float,
227227
output_zp: int,
228228
) -> torch.Tensor:
229229
"""AoT-compute a 256-entry int8 lookup table for a quantized activation.
230230
231+
`target` is the edge-dialect op being lowered (e.g.
232+
`exir_ops.edge.aten.sigmoid.default`).
233+
231234
The LUT is indexed by the input byte value biased by 128: for any int8
232235
input `q_in`, the kernel reads `lut[q_in + 128]` to get the int8 output.
233236
Because the LUT is computed in float and quantized once per entry, the
234237
runtime kernel is a single memory-lookup with no requantization math.
235238
"""
236-
if kind not in _ACTIVATION_FNS:
239+
if target not in _ACTIVATION_FNS:
237240
raise ValueError(
238-
f"build_activation_lut: unknown activation '{kind}' "
239-
f"(supported: {sorted(_ACTIVATION_FNS)})"
241+
f"build_activation_lut: unsupported activation target {target!r} "
242+
f"(supported: {sorted(t.__name__ for t in _ACTIVATION_FNS)})"
240243
)
241-
f = _ACTIVATION_FNS[kind]
244+
f = _ACTIVATION_FNS[target]
242245
lut = torch.empty(256, dtype=torch.int8)
243246
for q in range(-128, 128):
244247
x = (q - input_zp) * input_scale

0 commit comments

Comments
 (0)