upgrade.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch.fx.node import Argument, Target
from torch.library import Library

lib = Library("aten", "FRAGMENT")
impl_lib = Library("aten", "IMPL")

log = logging.getLogger(__name__)


def get_target_version(versioned_upgrader_name: str) -> int:
    """div_Scalar_0_3 is the name of the upgrader, meaning it applies to div.Scalar of versions 0 through 3 and
    upgrades to version 4."""
    if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name):
        raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid")

    return int(versioned_upgrader_name.split("_")[-1]) + 1
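
# For example, following the naming convention above (inputs here are only illustrative):
#
#   get_target_version("div_Scalar_0_3")  # -> 4 (covers versions 0..3, upgrades to 4)
#   get_target_version("div_Scalar")      # raises RuntimeError (no _<from>_<to> suffix)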


def get_upgraders() -> Dict[str, Tuple[str, str]]:
    """Get the upgraders entry map and the operator version map and merge them into one dict."""
    upgraders = torch._C._get_upgraders_entry_map()
    op_version_map = torch._C._get_operator_version_map()
    output: Dict[str, Tuple[str, str]] = defaultdict(tuple)  # type: ignore[arg-type]
    for opname, entry_list in op_version_map.items():
        if not entry_list:
            raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
        entry = entry_list[0]
        old_schema = entry.old_schema
        upgrader_name = entry.upgrader_name
        upgrader_str = upgraders.get(upgrader_name, None)
        if not upgrader_str:
            raise RuntimeError(
                f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}"
            )
        output[upgrader_name] = (old_schema, upgrader_str)
    return output
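
# The merged result maps each upgrader name to (old schema, upgrader source). The shape is
# sketched below; the exact entries depend on the running torch build:
#
#   {
#       "div__Scalar_0_3": (
#           "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
#           "def div__Scalar_0_3(self: Tensor, other: number) -> Tensor: ...",
#       ),
#       ...
#   }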


class GraphModuleOpUpgrader:
    """This upgrader is able to upgrade old versions of ops in a given GraphModule, if all upgraders are available.
    To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass them into this upgrader. In
    __init__() it does the following:
        1. parse the upgrader list and reorder it for upgrading purposes.
        2. register old versions of operators as custom ops.
        3. prepare upgrader passes.

    The `upgrade()` API runs these upgrader passes.

    An example of op_upgraders input:

        {
            "aten::div__Scalar_0_3": (                       # versioned op name
                "div._Scalar(self: Tensor, other: Scalar)",  # old schema
                '''
                def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor:  # upgrader in literal string
                    if (self.is_floating_point() or isinstance(other, float)):
                        return self.true_divide_(other)
                    return self.divide_(other, rounding_mode='trunc')
                ''',
            ),
        },

    Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the
    original TorchScript upgrader).
    """

    class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse):
        def __init__(self, old_target: Target, new_target: Target):
            super().__init__()
            self.old_target = old_target
            self.new_target = new_target

        def call_operator(
            self,
            op,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
            meta: NodeMetadata,
        ) -> ProxyValue:
            if op == self.old_target:
                return super().call_operator(self.new_target, args, kwargs, meta)
            return super().call_operator(op, args, kwargs, meta)
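
    # For instance, a pass built with old_target=torch.ops.aten.div.Scalar and new_target set to a
    # registered old-version op such as torch.ops.aten.div__Scalar_0_3.default (a hypothetical pairing,
    # mirroring the docstring example, valid only after registration) rewrites every div.Scalar call it
    # visits to the old-version op and leaves all other calls untouched.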

    def __init__(
        self,
        compiler_opset_version: Optional[Dict[str, int]] = None,
        model_opset_version: Optional[Dict[str, int]] = None,
        op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None,
    ):
        self.op_upgraders: Dict[str, Tuple[str, str]] = (
            get_upgraders() if not op_upgraders else op_upgraders
        )
        self.compiler_opset_version = (
            compiler_opset_version if compiler_opset_version else {}
        )
        self.model_opset_version = model_opset_version if model_opset_version else {}
        self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = (
            GraphModuleOpUpgrader._populate_passes(
                self._parse_upgraders(self.op_upgraders)
            )
        )

    def _parse_upgraders(
        self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None
    ) -> List[Tuple[str, str]]:
        """Reorder op_upgraders by version number and return an ordered list of tuples, each containing the old op
        schema as well as the upgrader function string literal."""
        # TODO(larryliu0820): Add support for custom ops
        op_namespace = "aten"
        if (
            not op_upgraders
            or op_namespace not in self.model_opset_version
            or op_namespace not in self.compiler_opset_version
        ):
            return []
        model_ver = self.model_opset_version[op_namespace]
        curr_ver = self.compiler_opset_version[op_namespace]

        # key is the target version. div__Scalar_0_3 should have a key of 4.
        versioned_upgraders: Dict[int, Tuple[str, str]] = {
            get_target_version(name): v for name, v in op_upgraders.items()
        }
        target_upgraders: List[Tuple[str, str]] = []
        # we need all upgraders from model_ver + 1 to curr_ver, inclusive
        for ver in range(model_ver + 1, curr_ver + 1):
            if ver in versioned_upgraders:
                target_upgraders.append(versioned_upgraders[ver])
            else:
                # we may be able to get away with missing upgraders, if that operator is missing from the given
                # graph module.
                log.warning("Missing an upgrader to upgrade to version %s.", ver)

        return target_upgraders
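
    # As an illustration (the numbers are assumptions, not defaults): with
    # model_opset_version={"aten": 3} and compiler_opset_version={"aten": 4}, the loop above needs only
    # the upgrader keyed 4, i.e. one named *_0_3; a gap in versioned_upgraders produces the warning
    # above rather than an error.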

    @staticmethod
    def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]:
        """Given a list of upgraders, loop through it from lower version to higher version and create passes for all
        upgraders. Use the torch.Library API to register old ops. The op name will be
        <name>_<valid_from_ver>_<valid_till_ver>. Register upgraders as CompositeImplicitAutograd kernels. For example:

            lib = Library("aten", "FRAGMENT")
            lib.define(old_schema)

            impl_lib = Library("aten", "IMPL")
            impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd")

        @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the
        upgrader function literal text.
        @:return upgrader passes, order matters
        """
        upgrader_passes = []

        def register_old_op(name: str, schema: str, impl_str: str):
            """Registers an old-version operator, using `name` as the old op name."""
            lib.define(schema)
            try:
                exec(impl_str)
            except Exception as e:
                raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e
            impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd")

        for schema, upgrader_str in upgraders:
            upgrader_name = upgrader_str.split("(")[0].split(" ")[-1]
            op_name = schema.split("(")[0].split("::")[-1]
            schema = schema.replace(op_name, upgrader_name)
            try:
                register_old_op(
                    name=upgrader_name, schema=schema, impl_str=upgrader_str
                )
            except RuntimeError as e:
                if "with the same name and overload name multiple times" in str(e):
                    print(f"Registering {upgrader_name} multiple times")
                else:
                    raise RuntimeError from e
            old_op_target = getattr(torch.ops.aten, upgrader_name).default
            # for example, the operator instance of "aten::div" is torch.ops.aten.div.default. We need to append
            # "default" at the end.
            op_name, overload_name = (
                (op_name, "default")
                if "." not in op_name
                else tuple(op_name.split(".")[:2])
            )
            new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name)
            # Note that the graph uses the new op names, but the nodes actually carry old-version semantics, so
            # each pass rewrites the new target back to the registered old-version op.
            upgrader_passes.append(
                GraphModuleOpUpgrader.UpgraderPass(
                    old_target=new_op_target, new_target=old_op_target
                )
            )

        return upgrader_passes
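
    # Tracing the docstring example through the loop above (values derived from that example only):
    # upgrader_str "def div__Scalar_0_3(...)" yields upgrader_name "div__Scalar_0_3"; the old schema
    # "div._Scalar(self: Tensor, other: Scalar)" yields op_name "div._Scalar"; the schema registered on
    # the "aten" fragment becomes "div__Scalar_0_3(self: Tensor, other: Scalar)"; and op_name is split
    # at "." to pick the overload attribute on torch.ops.aten.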

    def upgrade(self, exported_program):
        return exported_program
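

# Example usage (a minimal sketch; the opset numbers and the exported program `ep` are assumptions,
# not part of this module):
#
#   upgrader = GraphModuleOpUpgrader(
#       compiler_opset_version={"aten": 4},
#       model_opset_version={"aten": 3},
#   )
#   gm = ep.graph_module
#   for upgrader_pass in upgrader.upgrader_passes:
#       gm = upgrader_pass(gm).graph_module  # each pass returns a PassResult wrapping the new graph module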