@@ -25,8 +25,9 @@ class RemoveCloneOpsTransform(ExportPass):
2525 exir_ops .edge .dim_order_ops ._clone_dim_order .default ,
2626 }
2727
28- def __init__ (self ) -> None :
28+ def __init__ (self , preserve_input_output_copies : bool = True ) -> None :
2929 super ().__init__ ()
30+ self ._preserve_input_output_copies = preserve_input_output_copies
3031
3132 def _remove (self , graph_module : torch .fx .GraphModule ) -> None :
3233 dequant_nodes = []
@@ -38,6 +39,11 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
3839 if self ._is_non_identity_clone (n ):
3940 continue
4041
42+ # If preserve_input_output_copies is set, don't remove clones that directly
43+ # copy from input to output.
44+ if self ._is_input_output_copy (n ) and self ._preserve_input_output_copies :
45+ continue
46+
4147 to_be_removed = n
4248 for user_n in list (n .users .keys ()):
4349 user_n .replace_input_with (n , n .args [0 ])
@@ -76,3 +82,16 @@ def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
7682 )
7783
7884 return False
85+
86+ def _is_input_output_copy (self , node : torch .fx .Node ) -> bool :
87+ """Return True if the node input is a graph input and output goes into an output node."""
88+
89+ input_node = node .args [0 ]
90+ if input_node .op != "placeholder" :
91+ return False
92+
93+ for users in node .users :
94+ if users .op == "output" :
95+ return True
96+
97+ return False
0 commit comments