
Commit aa47742

GregoryComer authored and facebook-github-bot committed
Remove no-op clones in xnnpack (pytorch#15884)
Summary: Pull Request resolved: pytorch#15884

Differential Revision: D87405074
1 parent 131d1f4 commit aa47742
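For context, the "no-op clones" targeted here are clone nodes that change neither dtype nor dim order. The sketch below is an illustrative assumption (the module and export calls are not part of this commit) showing how such a clone ends up in an exported graph:

import torch

class AddWithClone(torch.nn.Module):
    def forward(self, x, y):
        z = x + y
        return z.clone()  # identity clone: same dtype, layout, and dim order as z

ep = torch.export.export(AddWithClone(), (torch.randn(2, 3), torch.randn(2, 3)))
print(ep.graph_module.graph)  # the clone appears as an aten.clone.default node

With this commit, the XNNPACK lowering removes such clones via RemoveCloneOpsTransform, while a clone that copies a graph input directly to a graph output is preserved and serialized as a new XNNCopy node (see op_clone.py and XNNCompiler.cpp below).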


13 files changed (+264, -3 lines)


backends/transforms/remove_clone_ops.py

Lines changed: 20 additions & 1 deletion
@@ -25,8 +25,9 @@ class RemoveCloneOpsTransform(ExportPass):
         exir_ops.edge.dim_order_ops._clone_dim_order.default,
     }
 
-    def __init__(self) -> None:
+    def __init__(self, preserve_input_output_copies: bool = True) -> None:
         super().__init__()
+        self._preserve_input_output_copies = preserve_input_output_copies
 
     def _remove(self, graph_module: torch.fx.GraphModule) -> None:
         dequant_nodes = []
@@ -38,6 +39,11 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
             if self._is_non_identity_clone(n):
                 continue
 
+            # If preserve_input_output_copies is set, don't remove clones that directly
+            # copy from input to output.
+            if self._is_input_output_copy(n) and self._preserve_input_output_copies:
+                continue
+
             to_be_removed = n
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
@@ -76,3 +82,16 @@ def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
         )
 
         return False
+
+    def _is_input_output_copy(self, node: torch.fx.Node) -> bool:
+        """Return True if the node input is a graph input and output goes into an output node."""
+
+        input_node = node.args[0]
+        if input_node.op != "placeholder":
+            return False
+
+        for users in node.users:
+            if users.op == "output":
+                return True
+
+        return False
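A minimal sketch of invoking the transform on its own; the example module and the to_edge flow are assumptions for illustration, and the call convention (an ExportPass instance applied to a graph module returns a PassResult) is the standard one:

import torch
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import to_edge

class CloneOnly(torch.nn.Module):
    def forward(self, x):
        return x.clone()  # a direct graph-input -> graph-output copy

edge = to_edge(torch.export.export(CloneOnly(), (torch.randn(4),)))
gm = edge.exported_program().graph_module

# With preserve_input_output_copies=True (the default shown in the diff above), the
# clone is kept because it is the only node connecting the input placeholder to the
# output; interior no-op clones would be removed.
result = RemoveCloneOpsTransform(preserve_input_output_copies=True)(gm)
print(result.graph_module.graph)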

backends/xnnpack/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ runtime.python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/transforms:addmm_mm_to_linear",
+        "//executorch/backends/transforms:remove_clone_ops",
         "//executorch/backends/transforms:lib",
         "//executorch/backends/xnnpack/partition:partitioner_graphs",
         "//executorch/backends/xnnpack/serialization:xnnpack_schema",

backends/xnnpack/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,8 @@
 
 from typing import List, Optional, Type
 
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
+
 from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
 
 from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
@@ -58,6 +60,7 @@ def __init__(
         if not passes:
             # All the XNNPACK passes
             self.passes = [
+                RemoveCloneOpsTransform,
                 # TODO - remove this pass once we have a better support for dim_order ops lowering
                 DimOrderOpsRevertPass,
                 ConvertToUpsampleBilinear2d,

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
     op_cat,
     op_ceiling,
     op_clamp,
+    op_clone,
     op_conv2d,
     op_div,
     op_dynamic_dequantize_ops,
backends/xnnpack/operators/op_clone.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNCopy,
+    XNNGraph,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class CloneVisitor(NodeVisitor):
+    target = "aten.clone.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # Sanity check that the input and output dim order are the same. We don't
+        # handle dim order conversions yet.
+        dim_order = node.kwargs.get("dim_order", None)
+        input_meta = node.args[0].meta["val"]
+        assert dim_order is None or list(input_meta.dim_order()) == dim_order
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNCopy(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
     CatConfig,
     CeilConfig,
     ClampConfig,
+    CloneDimOrderConfig,
     ConstantPadConfig,
     DeQuantizedPerTensorConfig,
     DivConfig,
@@ -77,6 +78,7 @@
     BMMConfig,
     CatConfig,
     CeilConfig,
+    CloneDimOrderConfig,
     ConstantPadConfig,
     ConvolutionConfig,
     ClampConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 21 additions & 0 deletions
@@ -643,3 +643,24 @@ class SinConfig(GenericNodePartitionerConfig):
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
+
+class CloneDimOrderConfig(GenericNodePartitionerConfig):
+    target_name = "_clone_dim_order.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        # Only partition no-op _clone_dim_order nodes (output dim order = input).
+        # We can relax this in the future.
+        # This is also a conservative check and doesn't consider ambiguity.
+        dim_order = node.kwargs.get("dim_order", None)
+        input_meta = node.args[0].meta["val"]
+        if dim_order is not None and list(input_meta.dim_order()) != dim_order:
+            why(node, reason="Only dim-order preserving clones are supported.")
+            return False
+
+        return True
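To illustrate the constraint, the standalone example below is an assumption for illustration rather than code from this diff: a _clone_dim_order node is a no-op when the requested dim_order matches the dim order the input tensor already has.

import torch

t = torch.randn(1, 3, 8, 8)                 # contiguous NCHW tensor
print(list(t.dim_order()))                   # [0, 1, 2, 3]
print(list(t.dim_order()) == [0, 1, 2, 3])   # True  -> a clone to this dim order is a no-op
print(list(t.dim_order()) == [0, 2, 3, 1])   # False -> a channels-last request is rejected by the config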

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 29 additions & 0 deletions
@@ -1459,6 +1459,34 @@ Error defineBatchMatrixMultiplyNode(
   return Error::Ok;
 }
 
+/*
+ * Defines a copy node in the XNN subgraph.
+ */
+Error defineCopyNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNCopy();
+
+  xnn_status status = xnn_define_copy(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create copy node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Returns not Implemented Error code. This function is meant to be
 called when the compiler encountes a XNodeType from the flatbuffer
@@ -1763,6 +1791,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Concatenate5)
     _DEFINE(StaticSlice)
     _DEFINE(BatchMatrixMultiply)
+    _DEFINE(Copy)
     case fb_xnnpack::XNodeUnion::NONE:
     default: // Adding here as a catch all, just in case
       return &defineNotImplementedNode;

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions
@@ -157,6 +157,7 @@ union XNodeUnion {
   XNNTanh: _XNNNode1x1,
   XNNExp: _XNNNode1x1,
   XNNSin: _XNNNode1x1,
+  XNNCopy: _XNNNode1x1,
 }
 
 union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ union XNodeUnion {
  XNNTanh: _XNNNode1x1,
  XNNExp: _XNNNode1x1,
  XNNSin: _XNNNode1x1,
+ XNNCopy: _XNNNode1x1,
 }
 
 union XValueUnion {
