@@ -95,6 +95,21 @@ void wait(uint32_t channel) {}
9595static NeutronConfig neutronMemCopyConfig = {copy, wait};
9696#endif
9797
98+ // Extended dim_order check that handles 3D tensors (NCW → NWC).
99+ // The core is_channels_last_dim_order() only supports 4D/5D.
100+ bool is_channels_last_dim_order_extended (
101+ const uint8_t * dim_order,
102+ size_t dims) {
103+ if (is_channels_last_dim_order (dim_order, dims)) {
104+ return true ;
105+ }
106+ // 3D: dim_order [0, 2, 1] means channel dim (1) is innermost → NWC layout.
107+ if (dims == 3 ) {
108+ return dim_order[0 ] == 0 && dim_order[1 ] == 2 && dim_order[2 ] == 1 ;
109+ }
110+ return false ;
111+ }
112+
98113// Applied on outputs.
99114template <typename T>
100115void transposeToChannelFirst (
@@ -149,12 +164,21 @@ void transposeInput(
149164 if (length < 3 ) {
150165 return ;
151166 }
152- size_t N = 1 ;
153- size_t C = sizes[length - 3 ];
154- size_t H = sizes[length - 2 ];
155- size_t W = sizes[length - 1 ];
156- for (size_t i = 0 ; i < length - 3 ; i++) {
157- N *= sizes[i];
167+ size_t N, C, H, W;
168+ if (length == 3 ) {
169+ // 3D [N, C, W] → treat as [N, C, 1, W]
170+ N = sizes[0 ];
171+ C = sizes[1 ];
172+ H = 1 ;
173+ W = sizes[2 ];
174+ } else {
175+ N = 1 ;
176+ C = sizes[length - 3 ];
177+ H = sizes[length - 2 ];
178+ W = sizes[length - 1 ];
179+ for (size_t i = 0 ; i < length - 3 ; i++) {
180+ N *= sizes[i];
181+ }
158182 }
159183 switch (element_size) {
160184 case 1 :
@@ -204,12 +228,21 @@ void transposeOutput(
204228 if (length < 3 ) {
205229 return ;
206230 }
207- size_t N = 1 ;
208- size_t C = sizes[length - 3 ];
209- size_t H = sizes[length - 2 ];
210- size_t W = sizes[length - 1 ];
211- for (size_t i = 0 ; i < length - 3 ; i++) {
212- N *= sizes[i];
231+ size_t N, C, H, W;
232+ if (length == 3 ) {
233+ // 3D [N, C, W] → treat as [N, C, 1, W]
234+ N = sizes[0 ];
235+ C = sizes[1 ];
236+ H = 1 ;
237+ W = sizes[2 ];
238+ } else {
239+ N = 1 ;
240+ C = sizes[length - 3 ];
241+ H = sizes[length - 2 ];
242+ W = sizes[length - 1 ];
243+ for (size_t i = 0 ; i < length - 3 ; i++) {
244+ N *= sizes[i];
245+ }
213246 }
214247 switch (element_size) {
215248 case 1 :
@@ -252,7 +285,8 @@ bool multipleChannelsPresent(const ArrayRef<exec_aten::SizesType>& sizes) {
252285 if (length < 3 ) {
253286 return true ;
254287 }
255- exec_aten::SizesType C = sizes[length - 3 ];
288+ // For 3D [N, C, W], channel dim is sizes[1]. For 4D+, it's sizes[length-3].
289+ exec_aten::SizesType C = (length == 3 ) ? sizes[1 ] : sizes[length - 3 ];
256290 return C != 1 ;
257291}
258292
@@ -431,7 +465,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
431465 return Error::InvalidProgram;
432466 }
433467
434- if (is_channels_last_dim_order (dim_order, arg.dim ())) {
468+ if (is_channels_last_dim_order_extended (dim_order, arg.dim ())) {
435469 // The tensor is already permuted.
436470 ET_LOG (Debug, " Using channels last dim order for input %d.\n " , i);
437471 cfg->dcfg .inputs [i] = arg.const_data_ptr ();
@@ -481,7 +515,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
481515 multipleChannelsPresent (arg.sizes ())) {
482516 // The output will have to be transposed.
483517
484- if (is_channels_last_dim_order (dim_order, arg.dim ())) {
518+ if (is_channels_last_dim_order_extended (dim_order, arg.dim ())) {
485519 // The tensor will already be correctly permuted. No transposition
486520 // needed.
487521 cfg->dcfg .outputs [i] = arg.mutable_data_ptr ();
@@ -539,7 +573,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
539573 }
540574
541575 auto dim_order = arg.dim_order ().data ();
542- if (is_channels_last_dim_order (dim_order, arg.dim ())) {
576+ if (is_channels_last_dim_order_extended (dim_order, arg.dim ())) {
543577 // The rest of the model expects the `channels_last` dim order, which
544578 // the data already matches.
545579 ET_LOG (Debug, " Using channels last dim order for output %d.\n " , i);
@@ -587,4 +621,4 @@ static auto registered = register_backend(backend_id);
587621} // namespace
588622} // namespace neutron
589623} // namespace executor
590- } // namespace torch
624+ } // namespace torch
0 commit comments