Skip to content

Commit 77dffdd

Browse files
committed
Fix depthwise check
1 parent 6c4866c commit 77dffdd

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

Deeploy/CommonExtensions/OptimizationPasses/TopologyOptimizationPasses/LoweringOptimizationPasses.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ def _singleNodePattern(op: str) -> gs.Graph:
2222
return graph
2323

2424

25+
def _isDepthwise(node: gs.Node) -> bool:
26+
if node.op not in ["Conv", "RequantizedConv"]:
27+
return False
28+
29+
channels_first = node.attrs.get("channels_first", True)
30+
spatialDims = len(node.inputs[1].shape) - 2
31+
shapeIn = node.inputs[0].shape
32+
chIn = shapeIn[-spatialDims - 1] if channels_first else shapeIn[-1]
33+
return node.attrs.get("group", 1) == chIn
34+
35+
2536
def _createReshape(tensorIn: gs.Tensor,
2637
name: str,
2738
newShape: Sequence[Union[int, str]],
@@ -271,10 +282,8 @@ def __init__(self, default_channels_first: bool = True):
271282

272283
def _NCWHtoNHWC_dw_fun(graph: gs.Graph, match: Match, name: str, default_channels_first: bool) -> gs.Graph:
273284
node = next(iter((match.nodes_map.values())))
274-
assert node.op in ["RequantizedConv", "Conv"]
275285

276-
# Skip non-dw nodes
277-
if node.attrs.get("group", 1) == 1:
286+
if not _isDepthwise(node):
278287
return graph
279288

280289
channels_first = node.attrs.get("channels_first", True)
@@ -315,8 +324,7 @@ def __init__(self, default_channels_first: bool = True):
315324
def _PULP_NCHWtoNHWC_dw_fun(graph: gs.Graph, match: Match, name: str, default_channels_first: bool = True):
316325
node = next(iter((match.nodes_map.values())))
317326

318-
# Skip non-dw nodes
319-
if node.attrs.get("group", 1) == 1:
327+
if not _isDepthwise(node):
320328
return graph
321329

322330
channels_first = node.attrs.get("channels_first", True)

0 commit comments

Comments
 (0)