# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import types
from typing import Callable

import torch
from executorch.exir.dialects.backend._ops import (
    _BACKEND_OP_LIB,
    BackendOpOverloadPacket,
)
from executorch.exir.dialects.edge._ops import EdgeOpOverloadPacket
from torch._C import DispatchKey  # @manual
from torch.library import Library
from torchgen.model import FunctionSchema

_OPOVERLOAD_PACKET_CLS_MAPPING = {
    "edge": EdgeOpOverloadPacket,
    "backend": BackendOpOverloadPacket,
}

def bind_pattern_to_op(library: Library, schema_or_name: str):
    """Bind a pattern of ops to a backend op. A backend op should only appear when a
    user wants to replace a pattern of nodes with a custom op.

    The kernel registered for the backend op determines its decomposition behavior:
    * If the backend op is registered with a CompositeExplicitAutograd (or Meta)
      kernel, then once the graph is lowered (i.e., the pass replacing the pattern
      with the op has run), the op stays in the graph and the original pattern cannot
      be recovered, even by retracing.
    * Otherwise, the backend op supports retracing and can be "promoted" back to the
      original pattern through retracing.

    This decorator handles that complexity for users: apply it to the pattern
    function and the right decision is made automatically.

    Args:
        library (Library): torch library
        schema_or_name (str): schema string, e.g., "add.int(SymInt a, SymInt b) -> SymInt",
            or a qualified op name
    """

    def wrapper(f: Callable):
        if library.ns not in _BACKEND_OP_LIB:
            _BACKEND_OP_LIB.append(library.ns)
        no_namespace = schema_or_name.split("::")[-1]
        try:
            # can parse it into a FunctionSchema
            func = FunctionSchema.parse(no_namespace)
            name, overload_name = func.name.name.base, func.name.overload_name
            library.define(no_namespace)
        except AssertionError:
            if "." in no_namespace:
                name, overload_name = no_namespace.split(".")
            else:
                name, overload_name = no_namespace, None
        opname = name + ("." + overload_name if overload_name else "")
        overload_name = overload_name if overload_name else "default"
        torch_op = getattr(getattr(getattr(torch.ops, library.ns), name), overload_name)
        # We can't have both a CompositeExplicitAutograd and a CompositeImplicitAutograd
        # kernel, nor both a Meta and a CompositeImplicitAutograd kernel.
        keys = [
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.Meta,
        ]
        if not any(torch_op.has_kernel_for_dispatch_key(k) for k in keys):
            library.impl(opname, f, "CompositeImplicitAutograd")
        op = getattr(getattr(getattr(ops.backend, library.ns), name), overload_name)
        op._equivalent_callable = f
        return f

    return wrapper
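
# A minimal usage sketch (hypothetical library and op names, kept in comments so the
# module stays free of import-time side effects):
#
#   lib = Library("MY_TEST_ONLY_NS", "DEF")
#
#   @bind_pattern_to_op(lib, "add_relu(Tensor a, Tensor b) -> Tensor")
#   def add_relu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
#       return torch.relu(torch.add(a, b))
#
# Because no kernel is registered for the dispatch keys checked above, the pattern
# itself becomes the op's CompositeImplicitAutograd kernel, so retracing can
# decompose the backend op back into the original add + relu pattern.
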
class _OpNamespace(types.ModuleType):
    """
    EXIR dialect op namespace object. Contains ops and overloads registered into the
    PyTorch dispatcher.
    """

    def __init__(self, dialect, name):
        super().__init__(f"exir.ops.{dialect}.{name}")
        self._dialect = dialect
        if dialect == "backend" and name not in _BACKEND_OP_LIB:
            raise RuntimeError(f"{name} op library does not belong to backend ops.")
        self._name = name
        self._dir = []
        self._op_namespace = getattr(torch.ops, name)

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # "__file__" is not a valid op name; it is looked up when the module is inspected.
        if op_name == "__file__":
            return "exir.ops"
        if op_name in self.__dict__:
            return getattr(self, op_name)
        try:
            parent_packet = getattr(self._op_namespace, op_name)
        except AttributeError as e:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__).
            raise AttributeError(
                f"'_OpNamespace' '{self._dialect}.{self._name}' object has no attribute '{op_name}'"
            ) from e
        qualified_op_name = f"{self._name}::{op_name}"
        opoverload_packet_cls = _OPOVERLOAD_PACKET_CLS_MAPPING[self._dialect]
        opoverloadpacket = opoverload_packet_cls(
            qualified_op_name,
            op_name,
            parent_overload_packet=parent_packet,
        )
        opoverloadpacket.__module__ = self.__module__ + "." + self._name
        # Cache the opoverloadpacket so that each op corresponds to a unique
        # OpOverloadPacket object.
        setattr(self, op_name, opoverloadpacket)
        self._dir.append(op_name)
        return opoverloadpacket
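
# Lookup on an _OpNamespace is lazy: the first access of, say, "add" fetches
# torch.ops.<ns>.add, wraps it in the dialect's packet class from
# _OPOVERLOAD_PACKET_CLS_MAPPING, and caches it with setattr. A sketch (assuming the
# "aten" namespace and its "add" op are registered):
#
#   edge_aten = _OpNamespace("edge", "aten")
#   packet = edge_aten.add          # an EdgeOpOverloadPacket around torch.ops.aten.add
#   assert edge_aten.add is packet  # cached; __getattr__ is not invoked again
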
class _DialectNamespace(types.ModuleType):
    """
    Dialect namespace. Currently the dialects are:

    - ATen Dialect: core ATen ops and overloads, see torch._ops._OpNamespace
    - Edge Dialect: ATen ops with explicit Tensor dtype
    - Backend Dialect: backend ops only meaningful to the backend we are lowering into
    - Execution Dialect: memory-planning ready, all out-variants
    """

    def __init__(self, dialect_name):
        super().__init__("exir.ops" + "." + dialect_name)
        self._dialect_name = dialect_name
        self._dir = []

    def __getattr__(self, name):
        if name in self.__dict__:
            return getattr(self, name)
        # Here we are creating `exir.ops.<dialect_ns>.<my_namespace>`
        namespace = _OpNamespace(self._dialect_name, name)
        setattr(self, name, namespace)
        self._dir.append(name)
        return namespace
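
# A small sketch of the lazy composition (illustrative names): accessing an attribute
# on a dialect namespace creates and caches the per-library op namespace.
#
#   edge_ns = _DialectNamespace("edge")
#   aten_ns = edge_ns.aten          # builds _OpNamespace("edge", "aten") on first access
#   assert edge_ns.aten is aten_ns  # later lookups hit the cached attribute
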
class _Ops(types.ModuleType):
    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("exir.ops")
        self._dir = []

    def __getattr__(self, name):
        if name in self.__dict__:
            return getattr(self, name)
        dialect = _DialectNamespace(name)
        setattr(self, name, dialect)
        self._dir.append(name)
        return dialect

    def __iter__(self):
        return iter(self._dir)


# The ops "namespace"
ops = _Ops()
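
# Typical entry point (a usage sketch; assumes an environment where the referenced
# namespaces and overloads exist, e.g. "aten" with an "add.Tensor" overload):
#
#   from executorch.exir.dialects._ops import ops as exir_ops
#   add_packet = exir_ops.edge.aten.add         # EdgeOpOverloadPacket for aten::add
#   add_tensor = exir_ops.edge.aten.add.Tensor  # a specific edge-dialect overload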