
Commit aa1f45d

Tool to modify torchlib overload names via libcst
ghstack-source-id: 420de64
Pull Request resolved: #920
1 parent 2d88d80 commit aa1f45d

File tree: 2 files changed (+180 −1 lines changed)

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
from __future__ import annotations

import enum
import os
import pathlib
from typing import Dict, List, Set, Tuple

import libcst as cst
from libcst import matchers
from libcst._nodes.statement import FunctionDef

from onnxscript.function_libs.torch_lib import registration


class _StatusEnum(enum.Enum):
    SUCCESS = enum.auto()
    """Success."""
    FAILURE_OVERLOAD_EXIST = enum.auto()
    """Failure: overload name already exists."""
    FAILURE_OVERLOAD_INVALID = enum.auto()
    """Failure: overload name is invalid."""
    FAILURE_OP_NOT_FOUND = enum.auto()
    """Failure: op not found."""
    FAILURE_OP_MULTIPLE_IMPL = enum.auto()
    """Failure: op has multiple implementations. Cannot decide which to add new overload name to."""


def _cst_arg_to_overload_names(arg: cst.Arg) -> Tuple[str, ...]:
    if matchers.matches(arg, matchers.Arg(value=matchers.SimpleString())):
        overload_names = (cst.ensure_type(arg.value, cst.SimpleString).value,)
    else:
        overload_names = tuple(
            cst.ensure_type(element.value, cst.SimpleString).value
            for element in cst.ensure_type(arg.value, cst.Tuple).elements
        )
    overload_names = tuple(name.replace('"', "") for name in overload_names)
    return overload_names


def _overload_names_to_namespace_op(overload_names: Tuple[str, ...]) -> str:
    match = registration._QUALIFIED_OPERATOR_NAME_REGEX.fullmatch(overload_names[0])
    assert match is not None
    namespace = match.group("namespace")
    name = match.group("name")
    return f"{namespace}::{name}"


class _TorchlibOpOverloadCollector(cst.CSTVisitor):
    def __init__(self):
        self._op_overloads: Dict[str, List[Tuple[str, List[str]]]] = {}
        self._stack: List[str] = []

    def visit_FunctionDef(self, node: FunctionDef) -> bool | None:
        self._stack.append(node.name.value)

    def leave_FunctionDef(self, node: FunctionDef) -> None:
        self._stack.pop()

    def visit_Call(self, node: cst.Call) -> None:
        if not matchers.matches(node.func, matchers.Name("torch_op")):
            return

        function_name = self._stack[-1]
        overload_names = _cst_arg_to_overload_names(node.args[0])
        namespace_op_name = _overload_names_to_namespace_op(overload_names)

        self._op_overloads.setdefault(namespace_op_name, [])
        self._op_overloads[namespace_op_name].append((function_name, list(overload_names)))


class _TorchlibOpOverloadAdder(cst.CSTTransformer):
    def __init__(
        self,
        overload_names: Dict[str, List[Tuple[str, List[str]]]],
        new_overload_names: Set[str],
    ):
        self._overload_names = overload_names
        self._results: Dict[str, _StatusEnum] = {}

        for new_overload_name in new_overload_names:
            match = registration._QUALIFIED_OPERATOR_NAME_REGEX.fullmatch(new_overload_name)
            if not match:
                self._results[new_overload_name] = _StatusEnum.FAILURE_OVERLOAD_INVALID
                continue
            overload = match.group("overload") or ""
            if overload == "default":
                overload = ""
            dot_overload = f".{overload}" if overload else ""
            op_name = match.group("name")
            namespace = match.group("namespace")
            namespace_op_name = f"{namespace}::{op_name}"
            qualified_name = f"{namespace_op_name}{dot_overload}"

            if namespace_op_name not in self._overload_names:
                self._results[new_overload_name] = _StatusEnum.FAILURE_OP_NOT_FOUND
                continue

            if len(self._overload_names[namespace_op_name]) > 1:
                self._results[new_overload_name] = _StatusEnum.FAILURE_OP_MULTIPLE_IMPL
                continue

            if qualified_name in self._overload_names[namespace_op_name][0][1]:
                self._results[new_overload_name] = _StatusEnum.FAILURE_OVERLOAD_EXIST
                continue

            self._overload_names[namespace_op_name][0][1].append(qualified_name)
            self._results[new_overload_name] = _StatusEnum.SUCCESS

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        if not matchers.matches(original_node.func, matchers.Name("torch_op")):
            return original_node

        original_overload_names = _cst_arg_to_overload_names(original_node.args[0])
        namespace_op_name = _overload_names_to_namespace_op(original_overload_names)
        overload_names = self._overload_names[namespace_op_name][0][1]
        if len(overload_names) == 1:
            return original_node
        return updated_node.with_changes(
            args=[
                cst.Arg(
                    value=cst.Tuple(
                        elements=[
                            cst.Element(cst.SimpleString(value=f'"{name}"'))
                            for name in overload_names
                        ]
                    )
                ),
                *original_node.args[1:],
            ],
        )


def add_overload_names(
    module_path: pathlib.Path, overload_names: Set[str]
) -> Dict[str, _StatusEnum]:
    """NOTE: This function assumes"""
    source_tree = cst.parse_module(module_path.read_text())
    op_overload_collector = _TorchlibOpOverloadCollector()
    source_tree.visit(op_overload_collector)
    transformer = _TorchlibOpOverloadAdder(op_overload_collector._op_overloads, overload_names)
    modified_tree = source_tree.visit(transformer)
    module_path.write_text(modified_tree.code)
    return transformer._results


def main():
    new_overload_names = {
        "aten::add.Tensor",
        "aten::clamp.Tensor",
        "aten::div.Tensor",
        "aten::eq.Scalar",
        "aten::eq.Tensor",
        "aten::fill.Tensor",
        "aten::ge.Scalar",
        "aten::ge.Tensor",
        "aten::gt.Scalar",
        "aten::le.Tensor",
        "aten::lt.Scalar",
        "aten::mul.Tensor",
        "aten::ne.Scalar",
        "aten::roll.default",
        "aten::rsub.Scalar",
        "aten::select.int",
        "aten::slice.Tensor",
        "aten::split.Tensor",
        "aten::sub.Tensor",
        "aten::transpose.int",
        "aten::unbind.int",
        "aten::where.self",
    }
    file_paths = [
        pathlib.Path(os.path.join(root, file))
        for root, dirs, files in os.walk("onnxscript/function_libs/torch_lib/ops")
        for file in files
    ]
    for file_path in file_paths:
        print(add_overload_names(file_path, new_overload_names))


if __name__ == "__main__":
    main()
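For context (an illustration, not part of the commit): the tool makes two passes over each ops module. The _TorchlibOpOverloadCollector records, for each "namespace::op", which function registers it and with which overload names; the _TorchlibOpOverloadAdder then rewrites the first argument of the matching torch_op(...) call into a tuple that also contains the requested names. Assuming a registration of the following shape (the function name and signature here are illustrative), requesting "aten::add.Tensor" would rewrite

@torch_op("aten::add")
def aten_add(self, other, alpha: float = 1.0):
    ...

into

@torch_op(("aten::add", "aten::add.Tensor"))
def aten_add(self, other, alpha: float = 1.0):
    ...

and add_overload_names would report {"aten::add.Tensor": _StatusEnum.SUCCESS} for that file. Names that cannot be placed (malformed, unknown op, multiple implementations, or already present) are reported with the corresponding FAILURE_* status and left out of the rewrite.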

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 
 # Regex that will match "<namespace>::<op_name>[.<overload>]"
 _QUALIFIED_OPERATOR_NAME_REGEX = re.compile(
-    r"^(?P<namespace>[a-zA-Z0-9_]+)::(?P<name>[a-zA-Z0-9_]+)(?P<overload>\.[a-zA-Z0-9._]+)?$"
+    r"^(?P<namespace>\w+)::(?P<name>\w+)(?:\.(?P<overload>\w+))?$"
 )
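For reference, a standalone sketch (not part of the commit) of what the updated regex captures: the dot now sits outside the named group, so "overload" comes back without a leading ".", which is why the tool above builds dot_overload itself.

import re

_QUALIFIED_OPERATOR_NAME_REGEX = re.compile(
    r"^(?P<namespace>\w+)::(?P<name>\w+)(?:\.(?P<overload>\w+))?$"
)

match = _QUALIFIED_OPERATOR_NAME_REGEX.fullmatch("aten::add.Tensor")
assert match is not None
# namespace='aten', name='add', overload='Tensor' (no leading dot)
print(match.group("namespace"), match.group("name"), match.group("overload"))

# With no overload part, the named group is simply None.
match = _QUALIFIED_OPERATOR_NAME_REGEX.fullmatch("aten::relu")
assert match is not None
print(match.group("overload"))  # None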

0 commit comments
