Skip to content

Commit 1a79957

Browse files
JakeStevensfacebook-github-bot
authored andcommitted
Add 3D (NCW) channels-last dim_order support to Neutron runtime
Summary: Although 3dim dim order is not supported pytorch, it actually *is* properly supported within the NXP backend's implementation. The only issue is within the runtime, which forces the check to four dimensions. This PR relaxes and supports 3D. Differential Revision: D102862305
1 parent cdcc915 commit 1a79957

1 file changed

Lines changed: 51 additions & 17 deletions

File tree

backends/nxp/runtime/NeutronBackend.cpp

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,21 @@ void wait(uint32_t channel) {}
9595
static NeutronConfig neutronMemCopyConfig = {copy, wait};
9696
#endif
9797

98+
// Extended dim_order check that handles 3D tensors (NCW → NWC).
99+
// The core is_channels_last_dim_order() only supports 4D/5D.
100+
bool is_channels_last_dim_order_extended(
101+
const uint8_t* dim_order,
102+
size_t dims) {
103+
if (is_channels_last_dim_order(dim_order, dims)) {
104+
return true;
105+
}
106+
// 3D: dim_order [0, 2, 1] means channel dim (1) is innermost → NWC layout.
107+
if (dims == 3) {
108+
return dim_order[0] == 0 && dim_order[1] == 2 && dim_order[2] == 1;
109+
}
110+
return false;
111+
}
112+
98113
// Applied on outputs.
99114
template <typename T>
100115
void transposeToChannelFirst(
@@ -149,12 +164,21 @@ void transposeInput(
149164
if (length < 3) {
150165
return;
151166
}
152-
size_t N = 1;
153-
size_t C = sizes[length - 3];
154-
size_t H = sizes[length - 2];
155-
size_t W = sizes[length - 1];
156-
for (size_t i = 0; i < length - 3; i++) {
157-
N *= sizes[i];
167+
size_t N, C, H, W;
168+
if (length == 3) {
169+
// 3D [N, C, W] → treat as [N, C, 1, W]
170+
N = sizes[0];
171+
C = sizes[1];
172+
H = 1;
173+
W = sizes[2];
174+
} else {
175+
N = 1;
176+
C = sizes[length - 3];
177+
H = sizes[length - 2];
178+
W = sizes[length - 1];
179+
for (size_t i = 0; i < length - 3; i++) {
180+
N *= sizes[i];
181+
}
158182
}
159183
switch (element_size) {
160184
case 1:
@@ -204,12 +228,21 @@ void transposeOutput(
204228
if (length < 3) {
205229
return;
206230
}
207-
size_t N = 1;
208-
size_t C = sizes[length - 3];
209-
size_t H = sizes[length - 2];
210-
size_t W = sizes[length - 1];
211-
for (size_t i = 0; i < length - 3; i++) {
212-
N *= sizes[i];
231+
size_t N, C, H, W;
232+
if (length == 3) {
233+
// 3D [N, C, W] → treat as [N, C, 1, W]
234+
N = sizes[0];
235+
C = sizes[1];
236+
H = 1;
237+
W = sizes[2];
238+
} else {
239+
N = 1;
240+
C = sizes[length - 3];
241+
H = sizes[length - 2];
242+
W = sizes[length - 1];
243+
for (size_t i = 0; i < length - 3; i++) {
244+
N *= sizes[i];
245+
}
213246
}
214247
switch (element_size) {
215248
case 1:
@@ -252,7 +285,8 @@ bool multipleChannelsPresent(const ArrayRef<exec_aten::SizesType>& sizes) {
252285
if (length < 3) {
253286
return true;
254287
}
255-
exec_aten::SizesType C = sizes[length - 3];
288+
// For 3D [N, C, W], channel dim is sizes[1]. For 4D+, it's sizes[length-3].
289+
exec_aten::SizesType C = (length == 3) ? sizes[1] : sizes[length - 3];
256290
return C != 1;
257291
}
258292

@@ -431,7 +465,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
431465
return Error::InvalidProgram;
432466
}
433467

434-
if (is_channels_last_dim_order(dim_order, arg.dim())) {
468+
if (is_channels_last_dim_order_extended(dim_order, arg.dim())) {
435469
// The tensor is already permuted.
436470
ET_LOG(Debug, "Using channels last dim order for input %d.\n", i);
437471
cfg->dcfg.inputs[i] = arg.const_data_ptr();
@@ -481,7 +515,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
481515
multipleChannelsPresent(arg.sizes())) {
482516
// The output will have to be transposed.
483517

484-
if (is_channels_last_dim_order(dim_order, arg.dim())) {
518+
if (is_channels_last_dim_order_extended(dim_order, arg.dim())) {
485519
// The tensor will already be correctly permuted. No transposition
486520
// needed.
487521
cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
@@ -539,7 +573,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
539573
}
540574

541575
auto dim_order = arg.dim_order().data();
542-
if (is_channels_last_dim_order(dim_order, arg.dim())) {
576+
if (is_channels_last_dim_order_extended(dim_order, arg.dim())) {
543577
// The rest of the model expects the `channels_last` dim order, which
544578
// the data already matches.
545579
ET_LOG(Debug, "Using channels last dim order for output %d.\n", i);
@@ -587,4 +621,4 @@ static auto registered = register_backend(backend_id);
587621
} // namespace
588622
} // namespace neutron
589623
} // namespace executor
590-
} // namespace torch
624+
} // namespace torch

0 commit comments

Comments
 (0)