backends/nxp/runtime/NeutronBackend.cpp (51 additions, 17 deletions)
@@ -95,6 +95,21 @@ void wait(uint32_t channel) {}
 static NeutronConfig neutronMemCopyConfig = {copy, wait};
 #endif
 
+// Extended dim_order check that handles 3D tensors (NCW → NWC).
+// The core is_channels_last_dim_order() only supports 4D/5D.
+bool is_channels_last_dim_order_extended(
+    const uint8_t* dim_order,
+    size_t dims) {
+  if (is_channels_last_dim_order(dim_order, dims)) {
+    return true;
+  }
+  // 3D: dim_order [0, 2, 1] means channel dim (1) is innermost → NWC layout.
+  if (dims == 3) {
+    return dim_order[0] == 0 && dim_order[1] == 2 && dim_order[2] == 1;
+  }
+  return false;
+}
+
 // Applied on outputs.
 template <typename T>
 void transposeToChannelFirst(
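For intuition: an ExecuTorch dim_order lists dimensions from outermost to innermost in memory, so a 3D tensor with logical shape [N, C, W] stored channels-last (NWC) has dim_order {0, 2, 1}, i.e. the channel dimension varies fastest. Below is a minimal standalone sketch of the new check; is_channels_last_dim_order() is stubbed here (4D case only, order {0, 2, 3, 1}), since the real one comes from ExecuTorch's dim-order utilities:

#include <cassert>
#include <cstddef>
#include <cstdint>

// Stub of the core helper: recognizes only the 4D channels-last order
// {0, 2, 3, 1} (the real helper also handles 5D).
static bool is_channels_last_dim_order(const uint8_t* dim_order, size_t dims) {
  static const uint8_t nhwc[4] = {0, 2, 3, 1};
  if (dims != 4) {
    return false;
  }
  for (size_t i = 0; i < 4; i++) {
    if (dim_order[i] != nhwc[i]) {
      return false;
    }
  }
  return true;
}

static bool is_channels_last_dim_order_extended(const uint8_t* dim_order,
                                                size_t dims) {
  if (is_channels_last_dim_order(dim_order, dims)) {
    return true;
  }
  if (dims == 3) {
    // [N, C, W] permuted to NWC: channel dim (1) is innermost.
    return dim_order[0] == 0 && dim_order[1] == 2 && dim_order[2] == 1;
  }
  return false;
}

int main() {
  const uint8_t nwc[3] = {0, 2, 1};      // 3D channels-last: accepted now.
  const uint8_t ncw[3] = {0, 1, 2};      // 3D contiguous: rejected.
  const uint8_t nhwc[4] = {0, 2, 3, 1};  // 4D channels-last: core path.
  assert(is_channels_last_dim_order_extended(nwc, 3));
  assert(!is_channels_last_dim_order_extended(ncw, 3));
  assert(is_channels_last_dim_order_extended(nhwc, 4));
  return 0;
}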
@@ -149,12 +164,21 @@ void transposeInput(
   if (length < 3) {
     return;
   }
-  size_t N = 1;
-  size_t C = sizes[length - 3];
-  size_t H = sizes[length - 2];
-  size_t W = sizes[length - 1];
-  for (size_t i = 0; i < length - 3; i++) {
-    N *= sizes[i];
+  size_t N, C, H, W;
+  if (length == 3) {
+    // 3D [N, C, W] → treat as [N, C, 1, W]
+    N = sizes[0];
+    C = sizes[1];
+    H = 1;
+    W = sizes[2];
+  } else {
+    N = 1;
+    C = sizes[length - 3];
+    H = sizes[length - 2];
+    W = sizes[length - 1];
+    for (size_t i = 0; i < length - 3; i++) {
+      N *= sizes[i];
+    }
   }
   switch (element_size) {
     case 1:
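The singleton-H trick works because inserting a size-1 dimension never changes the bytes in memory: a [N, C, W] buffer is identical to a [N, C, 1, W] one, and transposing the latter from NCHW to NHWC produces [N, 1, W, C], which is exactly the NWC layout of the original 3D tensor. A hedged sketch of that index arithmetic; transposeNCHWToNHWC is a hypothetical stand-in for the element-size-specialized transpose the switch above dispatches to:

#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for the templated transpose behind the
// element_size switch: copies src (NCHW order) into dst (NHWC order).
static void transposeNCHWToNHWC(const float* src, float* dst,
                                size_t N, size_t C, size_t H, size_t W) {
  for (size_t n = 0; n < N; n++) {
    for (size_t c = 0; c < C; c++) {
      for (size_t h = 0; h < H; h++) {
        for (size_t w = 0; w < W; w++) {
          dst[((n * H + h) * W + w) * C + c] =
              src[((n * C + c) * H + h) * W + w];
        }
      }
    }
  }
}

int main() {
  // Logical [N=1, C=2, W=3] tensor in NCW order, viewed as [1, 2, 1, 3].
  const float ncw[6] = {0, 1, 2, 10, 11, 12};  // channel 0, then channel 1
  float nwc[6];
  transposeNCHWToNHWC(ncw, nwc, /*N=*/1, /*C=*/2, /*H=*/1, /*W=*/3);
  for (size_t i = 0; i < 6; i++) {
    printf("%g ", nwc[i]);  // prints the interleaved NWC: 0 10 1 11 2 12
  }
  printf("\n");
  return 0;
}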
@@ -204,12 +228,21 @@ void transposeOutput(
   if (length < 3) {
     return;
   }
-  size_t N = 1;
-  size_t C = sizes[length - 3];
-  size_t H = sizes[length - 2];
-  size_t W = sizes[length - 1];
-  for (size_t i = 0; i < length - 3; i++) {
-    N *= sizes[i];
+  size_t N, C, H, W;
+  if (length == 3) {
+    // 3D [N, C, W] → treat as [N, C, 1, W]
+    N = sizes[0];
+    C = sizes[1];
+    H = 1;
+    W = sizes[2];
+  } else {
+    N = 1;
+    C = sizes[length - 3];
+    H = sizes[length - 2];
+    W = sizes[length - 1];
+    for (size_t i = 0; i < length - 3; i++) {
+      N *= sizes[i];
+    }
   }
   switch (element_size) {
     case 1:
@@ -252,7 +285,8 @@ bool multipleChannelsPresent(const ArrayRef<exec_aten::SizesType>& sizes) {
   if (length < 3) {
     return true;
   }
-  exec_aten::SizesType C = sizes[length - 3];
+  // For 3D [N, C, W], channel dim is sizes[1]. For 4D+, it's sizes[length-3].
+  exec_aten::SizesType C = (length == 3) ? sizes[1] : sizes[length - 3];
   return C != 1;
 }
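Why the predicate matters: with a single channel, the channels-first and channels-last layouts are byte-identical, so the transpose can be skipped. The old code read sizes[length - 3] even for 3D shapes, i.e. the batch dimension of [N, C, W], so a shape like [2, 1, 128] looked multi-channel. A minimal sketch of the fixed check, with plain size_t* standing in for ArrayRef<exec_aten::SizesType>:

#include <cassert>
#include <cstddef>

// Sketch of the fixed predicate under the assumptions above.
static bool multipleChannelsPresent(const size_t* sizes, size_t length) {
  if (length < 3) {
    return true;
  }
  size_t C = (length == 3) ? sizes[1] : sizes[length - 3];
  return C != 1;
}

int main() {
  const size_t mono3d[3] = {2, 1, 128};    // [N, C, W], C = 1
  const size_t stereo3d[3] = {2, 2, 128};  // [N, C, W], C = 2
  const size_t img4d[4] = {1, 3, 32, 32};  // [N, C, H, W], C = 3
  assert(!multipleChannelsPresent(mono3d, 3));  // old code saw sizes[0] == 2
  assert(multipleChannelsPresent(stereo3d, 3));
  assert(multipleChannelsPresent(img4d, 4));
  return 0;
}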

@@ -431,7 +465,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
        return Error::InvalidProgram;
      }
 
-      if (is_channels_last_dim_order(dim_order, arg.dim())) {
+      if (is_channels_last_dim_order_extended(dim_order, arg.dim())) {
        // The tensor is already permuted.
        ET_LOG(Debug, "Using channels last dim order for input %d.\n", i);
        cfg->dcfg.inputs[i] = arg.const_data_ptr();
@@ -481,7 +515,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
          multipleChannelsPresent(arg.sizes())) {
        // The output will have to be transposed.
 
-        if (is_channels_last_dim_order(dim_order, arg.dim())) {
+        if (is_channels_last_dim_order_extended(dim_order, arg.dim())) {
          // The tensor will already be correctly permuted. No transposition
          // needed.
          cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
@@ -539,7 +573,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
      }
 
      auto dim_order = arg.dim_order().data();
-      if (is_channels_last_dim_order(dim_order, arg.dim())) {
+      if (is_channels_last_dim_order_extended(dim_order, arg.dim())) {
        // The rest of the model expects the `channels_last` dim order, which
        // the data already matches.
        ET_LOG(Debug, "Using channels last dim order for output %d.\n", i);
@@ -587,4 +621,4 @@ static auto registered = register_backend(backend_id);
 } // namespace
 } // namespace neutron
 } // namespace executor
-} // namespace torch
+} // namespace torch