@@ -22,6 +22,17 @@ def _singleNodePattern(op: str) -> gs.Graph:
2222 return graph
2323
2424
25+ def _isDepthwise (node : gs .Node ) -> bool :
26+ if node .op not in ["Conv" , "RequantizedConv" ]:
27+ return False
28+
29+ channels_first = node .attrs .get ("channels_first" , True )
30+ spatialDims = len (node .inputs [1 ].shape ) - 2
31+ shapeIn = node .inputs [0 ].shape
32+ chIn = shapeIn [- spatialDims - 1 ] if channels_first else shapeIn [- 1 ]
33+ return node .attrs .get ("group" , 1 ) == chIn
34+
35+
2536def _createReshape (tensorIn : gs .Tensor ,
2637 name : str ,
2738 newShape : Sequence [Union [int , str ]],
@@ -271,10 +282,8 @@ def __init__(self, default_channels_first: bool = True):
271282
272283def _NCWHtoNHWC_dw_fun (graph : gs .Graph , match : Match , name : str , default_channels_first : bool ) -> gs .Graph :
273284 node = next (iter ((match .nodes_map .values ())))
274- assert node .op in ["RequantizedConv" , "Conv" ]
275285
276- # Skip non-dw nodes
277- if node .attrs .get ("group" , 1 ) == 1 :
286+ if not _isDepthwise (node ):
278287 return graph
279288
280289 channels_first = node .attrs .get ("channels_first" , True )
@@ -315,8 +324,7 @@ def __init__(self, default_channels_first: bool = True):
315324def _PULP_NCHWtoNHWC_dw_fun (graph : gs .Graph , match : Match , name : str , default_channels_first : bool = True ):
316325 node = next (iter ((match .nodes_map .values ())))
317326
318- # Skip non-dw nodes
319- if node .attrs .get ("group" , 1 ) == 1 :
327+ if not _isDepthwise (node ):
320328 return graph
321329
322330 channels_first = node .attrs .get ("channels_first" , True )
0 commit comments