diff --git a/backends/nxp/runtime/NeutronBackend.cpp b/backends/nxp/runtime/NeutronBackend.cpp index 3ea973b7c5b..02c9e0003be 100644 --- a/backends/nxp/runtime/NeutronBackend.cpp +++ b/backends/nxp/runtime/NeutronBackend.cpp @@ -95,6 +95,21 @@ void wait(uint32_t channel) {} static NeutronConfig neutronMemCopyConfig = {copy, wait}; #endif +// Extended dim_order check that handles 3D tensors (NCW → NWC). +// The core is_channels_last_dim_order() only supports 4D/5D. +bool is_channels_last_dim_order_extended( + const uint8_t* dim_order, + size_t dims) { + if (is_channels_last_dim_order(dim_order, dims)) { + return true; + } + // 3D: dim_order [0, 2, 1] means channel dim (1) is innermost → NWC layout. + if (dims == 3) { + return dim_order[0] == 0 && dim_order[1] == 2 && dim_order[2] == 1; + } + return false; +} + // Applied on outputs. template void transposeToChannelFirst( @@ -149,12 +164,21 @@ void transposeInput( if (length < 3) { return; } - size_t N = 1; - size_t C = sizes[length - 3]; - size_t H = sizes[length - 2]; - size_t W = sizes[length - 1]; - for (size_t i = 0; i < length - 3; i++) { - N *= sizes[i]; + size_t N, C, H, W; + if (length == 3) { + // 3D [N, C, W] → treat as [N, C, 1, W] + N = sizes[0]; + C = sizes[1]; + H = 1; + W = sizes[2]; + } else { + N = 1; + C = sizes[length - 3]; + H = sizes[length - 2]; + W = sizes[length - 1]; + for (size_t i = 0; i < length - 3; i++) { + N *= sizes[i]; + } } switch (element_size) { case 1: @@ -204,12 +228,21 @@ void transposeOutput( if (length < 3) { return; } - size_t N = 1; - size_t C = sizes[length - 3]; - size_t H = sizes[length - 2]; - size_t W = sizes[length - 1]; - for (size_t i = 0; i < length - 3; i++) { - N *= sizes[i]; + size_t N, C, H, W; + if (length == 3) { + // 3D [N, C, W] → treat as [N, C, 1, W] + N = sizes[0]; + C = sizes[1]; + H = 1; + W = sizes[2]; + } else { + N = 1; + C = sizes[length - 3]; + H = sizes[length - 2]; + W = sizes[length - 1]; + for (size_t i = 0; i < length - 3; i++) { + N *= sizes[i]; + } } switch (element_size) { case 1: @@ -252,7 +285,8 @@ bool multipleChannelsPresent(const ArrayRef& sizes) { if (length < 3) { return true; } - exec_aten::SizesType C = sizes[length - 3]; + // For 3D [N, C, W], channel dim is sizes[1]. For 4D+, it's sizes[length-3]. + exec_aten::SizesType C = (length == 3) ? sizes[1] : sizes[length - 3]; return C != 1; } @@ -431,7 +465,7 @@ class NeutronBackend final : public PyTorchBackendInterface { return Error::InvalidProgram; } - if (is_channels_last_dim_order(dim_order, arg.dim())) { + if (is_channels_last_dim_order_extended(dim_order, arg.dim())) { // The tensor is already permuted. ET_LOG(Debug, "Using channels last dim order for input %d.\n", i); cfg->dcfg.inputs[i] = arg.const_data_ptr(); @@ -481,7 +515,7 @@ class NeutronBackend final : public PyTorchBackendInterface { multipleChannelsPresent(arg.sizes())) { // The output will have to be transposed. - if (is_channels_last_dim_order(dim_order, arg.dim())) { + if (is_channels_last_dim_order_extended(dim_order, arg.dim())) { // The tensor will already be correctly permuted. No transposition // needed. cfg->dcfg.outputs[i] = arg.mutable_data_ptr(); @@ -539,7 +573,7 @@ class NeutronBackend final : public PyTorchBackendInterface { } auto dim_order = arg.dim_order().data(); - if (is_channels_last_dim_order(dim_order, arg.dim())) { + if (is_channels_last_dim_order_extended(dim_order, arg.dim())) { // The rest of the model expects the `channels_last` dim order, which // the data already matches. ET_LOG(Debug, "Using channels last dim order for output %d.\n", i); @@ -587,4 +621,4 @@ static auto registered = register_backend(backend_id); } // namespace } // namespace neutron } // namespace executor -} // namespace torch \ No newline at end of file +} // namespace torch