Tool to modify torchlib overload names via libcst #920

Open · wants to merge 6 commits into main

Changes from 4 commits
181 changes: 181 additions & 0 deletions onnxscript/function_libs/tools/torch_lib/modify_overload_names.py
@@ -0,0 +1,181 @@
from __future__ import annotations

import enum
import os
import pathlib
from typing import Dict, List, Set, Tuple

import libcst as cst
from libcst import matchers

from onnxscript.function_libs.torch_lib import registration


class _StatusEnum(enum.Enum):
SUCCESS = enum.auto()
"""Success."""
FAILURE_OVERLOAD_EXIST = enum.auto()
"""Failure: overload name already exists."""
FAILURE_OVERLOAD_INVALID = enum.auto()
"""Failure: overload name is invalid."""
FAILURE_OP_NOT_FOUND = enum.auto()
"""Failure: op not found."""
FAILURE_OP_MULTIPLE_IMPL = enum.auto()
"""Failure: op has multiple implementations. Cannot decide which to add new overload name to."""


def _cst_arg_to_overload_names(arg: cst.Arg) -> Tuple[str, ...]:
    """Extracts the overload names from the first argument of a ``torch_op`` call."""
    if matchers.matches(arg, matchers.Arg(value=matchers.SimpleString())):
        # A single string literal, e.g. torch_op("aten::add").
        overload_names = (cst.ensure_type(arg.value, cst.SimpleString).value,)
    else:
        # A tuple of string literals, e.g. torch_op(("aten::add", "aten::add.Tensor")).
        overload_names = tuple(
            cst.ensure_type(element.value, cst.SimpleString).value
            for element in cst.ensure_type(arg.value, cst.Tuple).elements
        )
    # SimpleString.value includes the surrounding quotes; strip them.
    overload_names = tuple(name.replace('"', "") for name in overload_names)
    return overload_names


def _overload_names_to_namespace_op(overload_names: Tuple[str, ...]) -> str:
    """Returns the qualified "namespace::name" of the first overload name."""
    match = registration._QUALIFIED_OPERATOR_NAME_REGEX.fullmatch(overload_names[0])
    assert match is not None
    namespace = match.group("namespace")
    name = match.group("name")
    return f"{namespace}::{name}"


class _TorchlibOpOverloadCollector(cst.CSTVisitor):
    """Collects every op registered via ``torch_op`` and its overload names."""

    def __init__(self):
        # Maps "namespace::op" to a list of (function_name, overload_names) pairs.
        self._op_overloads: Dict[str, List[Tuple[str, List[str]]]] = {}
        self._stack: List[str] = []

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self._stack.append(node.name.value)

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        self._stack.pop()

def visit_Call(self, node: cst.Call) -> None:
Collaborator:
So the function def is visited before the decorator call?

Contributor (author):
I think so; at the CST tree level, decorators are part of the function.
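A minimal probe (not part of the PR) that checks this traversal order; it assumes only that libcst is installed:

import libcst as cst

class _OrderProbe(cst.CSTVisitor):
    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        print("FunctionDef:", node.name.value)

    def visit_Call(self, node: cst.Call) -> None:
        print("Call:", cst.ensure_type(node.func, cst.Name).value)

code = '@torch_op("aten::add")\ndef aten_add(self, other): ...\n'
cst.parse_module(code).visit(_OrderProbe())
# Prints "FunctionDef: aten_add" before "Call: torch_op": the decorator's
# Call node is a child of the FunctionDef node, so the visitor reaches the
# function first and the decorator call afterwards.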

if not matchers.matches(node.func, matchers.Name("torch_op")):
return

function_name = self._stack[-1]
overload_names = _cst_arg_to_overload_names(node.args[0])
namespace_op_name = _overload_names_to_namespace_op(overload_names)

self._op_overloads.setdefault(namespace_op_name, [])
self._op_overloads[namespace_op_name].append((function_name, list(overload_names)))


class _TorchlibOpOverloadAdder(cst.CSTTransformer):
    """Rewrites ``torch_op`` decorator arguments to include new overload names."""

    def __init__(
        self,
        overload_names: Dict[str, List[Tuple[str, List[str]]]],
        new_overload_names: Set[str],
    ):
        self._overload_names = overload_names
        # Maps each requested overload name to the outcome of the attempt to add it.
        self._results: Dict[str, _StatusEnum] = {}

for new_overload_name in new_overload_names:
match = registration._QUALIFIED_OPERATOR_NAME_REGEX.fullmatch(new_overload_name)
if not match:
self._results[new_overload_name] = _StatusEnum.FAILURE_OVERLOAD_INVALID
continue
overload = match.group("overload") or ""
if overload == "default":
overload = ""
dot_overload = f".{overload}" if overload else ""
op_name = match.group("name")
namespace = match.group("namespace")
namespace_op_name = f"{namespace}::{op_name}"
qualified_name = f"{namespace_op_name}{dot_overload}"

if namespace_op_name not in self._overload_names:
self._results[new_overload_name] = _StatusEnum.FAILURE_OP_NOT_FOUND
continue

if len(self._overload_names[namespace_op_name]) > 1:
self._results[new_overload_name] = _StatusEnum.FAILURE_OP_MULTIPLE_IMPL
continue

if qualified_name in self._overload_names[namespace_op_name][0][1]:
self._results[new_overload_name] = _StatusEnum.FAILURE_OVERLOAD_EXIST
continue

self._overload_names[namespace_op_name][0][1].append(qualified_name)
self._results[new_overload_name] = _StatusEnum.SUCCESS

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        if not matchers.matches(original_node.func, matchers.Name("torch_op")):
            return original_node

        original_overload_names = _cst_arg_to_overload_names(original_node.args[0])
        namespace_op_name = _overload_names_to_namespace_op(original_overload_names)
        overload_names = self._overload_names[namespace_op_name][0][1]
        if len(overload_names) == 1:
            # No overloads were added for this op; leave the call untouched.
            return original_node
return updated_node.with_changes(
args=[
cst.Arg(
value=cst.Tuple(
elements=[
cst.Element(cst.SimpleString(value=f'"{name}"'))
for name in overload_names
]
)
),
*original_node.args[1:],
],
)


def add_overload_names(
module_path: pathlib.Path, overload_names: Set[str]
) -> Dict[str, _StatusEnum]:
"""NOTE: This function assumes"""
source_tree = cst.parse_module(module_path.read_text())
op_overload_collector = _TorchlibOpOverloadCollector()
source_tree.visit(op_overload_collector)
transformer = _TorchlibOpOverloadAdder(op_overload_collector._op_overloads, overload_names)
modified_tree = source_tree.visit(transformer)
module_path.write_text(modified_tree.code)
return transformer._results


def main():
new_overload_names = {
"aten::add.Tensor",
Collaborator:
Possible to add all overloads for these ops? I imagine x for x in torch.ops.aten.add.overloads() if "out" not in x, etc.

Contributor (author):
Yes, the tool allows such usage.

Frankly, this PR is mostly a libcst exercise. I think the only truly useful scenario for this tool is when we have a large batch of missing overloads that we know would work, typically from a model benchmark.
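For reference, a hedged sketch of the collection the reviewer describes; non_out_overloads is a hypothetical helper, and it assumes PyTorch is installed:

import torch

def non_out_overloads(op_name: str) -> set[str]:
    # torch.ops.aten.<op> is an OpOverloadPacket; .overloads() lists the
    # overload names registered for it ("Tensor", "Scalar", "default", ...).
    packet = getattr(torch.ops.aten, op_name)
    return {
        f"aten::{op_name}.{overload}"
        for overload in packet.overloads()
        if "out" not in overload
    }

# e.g. non_out_overloads("add") could feed the new_overload_names set below.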

"aten::clamp.Tensor",
"aten::div.Tensor",
"aten::eq.Scalar",
"aten::eq.Tensor",
"aten::fill.Tensor",
"aten::ge.Scalar",
"aten::ge.Tensor",
"aten::gt.Scalar",
"aten::le.Tensor",
"aten::lt.Scalar",
"aten::mul.Tensor",
"aten::ne.Scalar",
"aten::roll.default",
"aten::rsub.Scalar",
"aten::select.int",
"aten::slice.Tensor",
"aten::split.Tensor",
"aten::sub.Tensor",
"aten::transpose.int",
"aten::unbind.int",
"aten::where.self",
}
    file_paths = [
        pathlib.Path(os.path.join(root, file))
        for root, _dirs, files in os.walk("onnxscript/function_libs/torch_lib/ops")
        for file in files
        if file.endswith(".py")  # skip non-Python files such as __pycache__ contents
    ]
for file_path in file_paths:
print(add_overload_names(file_path, new_overload_names))


if __name__ == "__main__":
main()
40 changes: 20 additions & 20 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -81,7 +81,7 @@ def aten_acosh(self: TFloat) -> TFloat:
return op.Acosh(self)


@torch_op("aten::add")
@torch_op(("aten::add", "aten::add.Tensor"))
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# TODO(microsoft/onnxruntime#15977): Improve fp16 precision
@@ -1235,7 +1235,7 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
return op.SplitToSequence(self, list_split, axis=dim)


@torch_op("aten::clamp", trace_only=True)
@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True)
def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal:
"""clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
clamped = self
@@ -2184,7 +2184,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
raise NotImplementedError()


@torch_op("aten::div")
@torch_op(("aten::div", "aten::div.Tensor"))
def aten_div(self: TFloat, other: TFloat) -> TFloat:
"""div.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -2353,7 +2353,7 @@ def aten_empty_strided(
return op.Expand(zero, size)


@torch_op("aten::eq")
@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar"))
def aten_eq(self: TTensor, other: TTensor) -> BOOL:
"""eq.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -2563,7 +2563,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
raise NotImplementedError()


@torch_op("aten::fill")
@torch_op(("aten::fill", "aten::fill.Tensor"))
def aten_fill(self: TTensor, value: TTensor) -> TTensor:
"""fill.Tensor(Tensor self, Tensor value) -> Tensor"""

@@ -2748,7 +2748,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::ge")
@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar"))
def aten_ge(self: TReal, other: TReal) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -2905,7 +2905,7 @@ def aten_gru_cell(
raise NotImplementedError()


@torch_op("aten::gt")
@torch_op(("aten::gt", "aten::gt.Scalar"))
def aten_gt(self: TReal, other: TReal) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -3595,7 +3595,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::le")
@torch_op(("aten::le", "aten::le.Tensor"))
def aten_le(self: TReal, other: TReal) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -3884,7 +3884,7 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


@torch_op("aten::lt")
@torch_op(("aten::lt", "aten::lt.Scalar"))
def aten_lt(self: TReal, other: TReal) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -4462,15 +4462,15 @@ def aten_msort(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::mul")
@torch_op(("aten::mul", "aten::mul.Tensor"))
def aten_mul(self: TReal, other: TReal) -> TReal:
"""mul.Tensor(Tensor self, Tensor other) -> Tensor"""
# FIXME(titaiwang): get rid of this when we have type_promotion
other = op.CastLike(other, self)
return op.Mul(self, other)


@torch_op("aten::mul")
@torch_op(("aten::mul", "aten::mul.Tensor"))
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

@@ -4883,7 +4883,7 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
raise NotImplementedError()


@torch_op("aten::ne")
@torch_op(("aten::ne", "aten::ne.Scalar"))
def aten_ne(self: TReal, other: TReal) -> BOOL:
"""ne.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -5756,7 +5756,7 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
return op.Reciprocal(op.Sqrt(self))


@torch_op("aten::rsub")
@torch_op(("aten::rsub", "aten::rsub.Scalar"))
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# FIXME(titaiwang): get rid of this when we have type_promotion
@@ -5855,7 +5855,7 @@ def aten_segment_reduce(
raise NotImplementedError()


@torch_op("aten::select")
@torch_op(("aten::select", "aten::select.int"))
def aten_select(self: TTensor, dim: int, index: int) -> TTensor:
"""select(Tensor self, int dim, int index) -> Tensor"""

@@ -5935,7 +5935,7 @@ def aten_sinh(self: TFloat) -> TFloat:
return op.Sinh(self)


@torch_op("aten::slice", trace_only=True)
@torch_op(("aten::slice", "aten::slice.Tensor"), trace_only=True)
def aten_slice(
self: TTensor,
dim: int = 0,
@@ -6081,7 +6081,7 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::split")
@torch_op(("aten::split", "aten::split.Tensor"))
def aten_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor:
"""split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]"""

@@ -6309,7 +6309,7 @@ def aten_stft(
return result


@torch_op("aten::sub")
@torch_op(("aten::sub", "aten::sub.Tensor"))
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
alpha = op.CastLike(alpha, other)
@@ -6634,7 +6634,7 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
raise NotImplementedError()


@torch_op("aten::transpose", trace_only=True)
@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True)
def aten_transpose(self, dim0: int, dim1: int):
"""transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"""

@@ -6729,7 +6729,7 @@ def aten_type_as(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::unbind")
@torch_op(("aten::unbind", "aten::unbind.int"))
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""

@@ -7082,7 +7082,7 @@ def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
return op.ConcatFromSequence(tensors, axis=0)


@torch_op("aten::where")
@torch_op(("aten::where", "aten::where.self"))
def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
"""where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""

2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/registration.py
@@ -10,7 +10,7 @@

# Regex that will match "<namespace>::<op_name>[.<overload>]"
_QUALIFIED_OPERATOR_NAME_REGEX = re.compile(
r"^(?P<namespace>[a-zA-Z0-9_]+)::(?P<name>[a-zA-Z0-9_]+)(?P<overload>\.[a-zA-Z0-9._]+)?$"
r"^(?P<namespace>\w+)::(?P<name>\w+)(?:\.(?P<overload>\w+))?$"
)
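For illustration, the updated pattern's behavior; the key change is that the overload group no longer includes the leading dot and is None when absent:

import re

pattern = re.compile(r"^(?P<namespace>\w+)::(?P<name>\w+)(?:\.(?P<overload>\w+))?$")

m = pattern.fullmatch("aten::add.Tensor")
assert m is not None
assert m.group("namespace") == "aten"
assert m.group("name") == "add"
assert m.group("overload") == "Tensor"  # no leading dot, unlike the old pattern
assert pattern.fullmatch("aten::add").group("overload") is None  # overload is optional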

