33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ from collections .abc import Callable
7+
68import executorch .backends .arm .tosa .dialect # noqa: F401
79from executorch .backends .arm ._passes .aten_to_tosa_activation_functions import (
810 get_activation_replacement ,
911)
10- from executorch .backends .arm ._passes .aten_to_tosa_tensor_operators import rewrite_argmax
12+ from executorch .backends .arm ._passes .aten_to_tosa_tensor_operators import (
13+ rewrite_argmax ,
14+ rewrite_binary_operator ,
15+ )
1116from executorch .backends .transforms .aten_to_dialect_pass import (
1217 AtenToDialectPass ,
1318 DialectNodeSpec ,
19+ SubstitutionFn ,
1420)
1521from executorch .exir .dialects ._ops import ops as exir_ops
1622from torch .fx import Node
23+ from torch .fx .node import Target
1724
1825
1926class ExirToTosaPass (AtenToDialectPass ):
@@ -25,17 +32,55 @@ class ExirToTosaPass(AtenToDialectPass):
2532 """
2633
2734
35+ def register_dialect_substitutions (
36+ * targets : Target ,
37+ ) -> Callable [[SubstitutionFn ], SubstitutionFn ]:
38+ def decorator (func : SubstitutionFn ) -> SubstitutionFn :
39+ for target in targets :
40+ ExirToTosaPass .register_dialect_substitution (target )(func )
41+ return func
42+
43+ return decorator
44+
45+
2846@ExirToTosaPass .register_dialect_substitution (exir_ops .edge .aten .argmax .default )
2947def _get_tensor_operators_replacement (
3048 node : Node , pass_ : AtenToDialectPass
3149) -> DialectNodeSpec :
3250 return rewrite_argmax (node , pass_ )
3351
3452
35- @ExirToTosaPass .register_dialect_substitution (exir_ops .edge .aten .clamp .default )
36- @ExirToTosaPass .register_dialect_substitution (exir_ops .edge .aten .erf .default )
37- @ExirToTosaPass .register_dialect_substitution (exir_ops .edge .aten .sigmoid .default )
38- @ExirToTosaPass .register_dialect_substitution (exir_ops .edge .aten .tanh .default )
53+ @register_dialect_substitutions (
54+ exir_ops .edge .aten .add .Tensor ,
55+ exir_ops .edge .aten .bitwise_and .Tensor ,
56+ exir_ops .edge .aten .bitwise_left_shift .Tensor ,
57+ exir_ops .edge .aten .bitwise_or .Tensor ,
58+ exir_ops .edge .aten .bitwise_right_shift .Tensor ,
59+ exir_ops .edge .aten .bitwise_xor .Tensor ,
60+ exir_ops .edge .aten .eq .Tensor ,
61+ exir_ops .edge .aten .ge .Tensor ,
62+ exir_ops .edge .aten .gt .Tensor ,
63+ exir_ops .edge .aten .logical_and .default ,
64+ exir_ops .edge .aten .logical_or .default ,
65+ exir_ops .edge .aten .logical_xor .default ,
66+ exir_ops .edge .aten .maximum .default ,
67+ exir_ops .edge .aten .minimum .default ,
68+ exir_ops .edge .aten .mul .Tensor ,
69+ exir_ops .edge .aten .pow .Tensor_Tensor ,
70+ exir_ops .edge .aten .sub .Tensor ,
71+ )
72+ def _get_binary_operator_replacement (
73+ node : Node , pass_ : AtenToDialectPass
74+ ) -> DialectNodeSpec | None :
75+ return rewrite_binary_operator (node , pass_ )
76+
77+
78+ @register_dialect_substitutions (
79+ exir_ops .edge .aten .clamp .default ,
80+ exir_ops .edge .aten .erf .default ,
81+ exir_ops .edge .aten .sigmoid .default ,
82+ exir_ops .edge .aten .tanh .default ,
83+ )
3984def _get_activation_replacement (
4085 node : Node , pass_ : AtenToDialectPass
4186) -> DialectNodeSpec | None :
0 commit comments