Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use IR in onnxscript #1409

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions onnxscript/ir/convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,47 @@

from typing import Any, Mapping, Sequence

from onnxscript.ir import _core
from onnxscript.ir import _core, _protocols, _enums


def convert_attribute(
name: str,
attr: str
| int
| float
| Sequence[int]
| Sequence[float]
| Sequence[str]
| _protocols.TensorProtocol
| _core.Attr
| None,
attr_type: _enums.AttributeType | None= None,
) -> _core.Attr:
if attr is None:
if attr_type is None:
raise ValueError("attr_type must be provided when attr is None")

Check warning on line 29 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L29

Added line #L29 was not covered by tests
return _core.Attr(name, attr_type, None)
if isinstance(attr, int):
return _core.AttrInt64(name, attr)

Check warning on line 32 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L32

Added line #L32 was not covered by tests
if isinstance(attr, float):
return _core.AttrFloat32(name, attr)
if isinstance(attr, str):
return _core.AttrString(name, attr)
if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
return _core.AttrInt64s(name, attr) # type: ignore

Check warning on line 38 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L38

Added line #L38 was not covered by tests
if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
return _core.AttrFloat32s(name, attr) # type: ignore

Check warning on line 40 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L40

Added line #L40 was not covered by tests
if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
return _core.AttrStrings(name, attr) # type: ignore

Check warning on line 42 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L42

Added line #L42 was not covered by tests
if isinstance(attr, (_core.Tensor, _protocols.TensorProtocol)):
return _core.AttrTensor(name, attr)

Check warning on line 44 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L44

Added line #L44 was not covered by tests
if isinstance(attr, _core.Attr):
return attr
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")

Check warning on line 47 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L46-L47

Added lines #L46 - L47 were not covered by tests


def convert_attributes(attrs: Mapping[str, Any]) -> list[_core.Attr]:
attributes: list[_core.Attr] = []
for name, attr in attrs.items():
if isinstance(attr, int):
attributes.append(_core.AttrInt64(name, attr))
elif isinstance(attr, float):
attributes.append(_core.AttrFloat32(name, attr))
elif isinstance(attr, str):
attributes.append(_core.AttrString(name, attr))
elif isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
attributes.append(_core.AttrInt64s(name, attr))
elif isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
attributes.append(_core.AttrFloat32s(name, attr))
elif isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
attributes.append(_core.AttrStrings(name, attr))
elif isinstance(attr, _core.Attr):
attributes.append(attr)
else:
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
attributes.append(convert_attribute(name, attr))

Check warning on line 53 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L53

Added line #L53 was not covered by tests
return attributes
193 changes: 39 additions & 154 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@
import io
import logging
import warnings
from typing import Any, Optional, Protocol, Sequence, Union
from typing import Any, Mapping, Optional, Protocol, Sequence, Union

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning

Unused Protocol imported from typing (unused-import)
See unused-import. To disable, use # pylint: disable=unused-import

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

typing.Protocol imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import

import onnx
from onnx import ValueInfoProto, helper
from onnx.defs import onnx_opset_version

import onnxscript
from onnxscript import type_annotation as ta
from onnxscript import values
from onnxscript import values, ir, sourceinfo
from onnxscript._internal import version_utils
from onnxscript.onnx_types import ONNXType
from onnxscript.sourceinfo import SourceInfo
from onnxscript.ir import convenience as ir_convenience

# A simple IR (Function, Stmt, Attr, Var):

Expand All @@ -40,39 +40,14 @@
return helper.OP_SET_ID_VERSION_MAP[domain, version]


class IRType:
def __init__(self):
self.onnx_type = onnx.TypeProto()

def to_type_proto(self):
return self.onnx_type

def __repr__(self) -> str:
return "IRType()"


class IRTensorType(IRType):
def __init__(self, elem_type: onnx.TensorProto.DataType) -> None:
super().__init__()
self.onnx_type.tensor_type.elem_type = elem_type

def __repr__(self) -> str:
return f"IRTensorType({self.onnx_type.tensor_type.elem_type})"


class IRTypeLike(Protocol):
def to_type_proto(self) -> onnx.TypeProto:
"""Converts IR type representation to onnx.TypeProto"""


class IRVar:
"""A variable (representing a formal parameter)."""

def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -> None:
def __init__(self, varname: str, typeinfo: ir.TypeProtocol, source_info: sourceinfo.SourceInfo) -> None:
if not isinstance(varname, str):
raise TypeError(f"varname must be a string not {type(varname)!r}.")
self.name = varname
self.info = sourceinfo
self.info = source_info
self.typeinfo = typeinfo

def __str__(self):
Expand All @@ -81,128 +56,31 @@
def __repr__(self):
return f"{self.__class__.__name__}({self.name!r}, {self.typeinfo!r})"

def typed_str(self):
return f"{self.name} : {self.typeinfo}"

def to_value_info(self, use_default_type: bool = True):
"""Converts the content of this class into :class:`onnx.ValueInfoProto`.

Args:
use_default_type: if True, use a default type if an explicit type
is not known. Otherwise, returns a ValueInfoProto without type.

Returns:
an instance of :class:`onnx.ValueInfoProto`
"""
if self.name is None:
raise ValueError(self.info.msg("name cannot be None."))
value_info_proto = ValueInfoProto()
value_info_proto.name = self.name
if self.typeinfo is not None:
value_info_proto.type.CopyFrom(self.typeinfo.to_type_proto())
elif use_default_type:
value_info_proto.type.CopyFrom(IRType().to_type_proto())
return value_info_proto


def _opt_var_to_str(x):
return "" if x is None else str(x)


class IRAttributeValue:
"""An attribute value (representing an actual parameter).

Attributes:
name: The name of the attribute.
type: The type of the attribute.
attr_proto: The attribute proto.
"""

def __init__(self, attrproto: onnx.AttributeProto) -> None:
self.attr_proto = attrproto

def __str__(self):
if self.attr_proto.HasField("ref_attr_name"):
return f"{self.attr_proto.name} = @{self.attr_proto.ref_attr_name}"
# self.name + " = " + self.value
return helper.printable_attribute(self.attr_proto)

@property
def name(self) -> str:
return self.attr_proto.name

@property
def type(self) -> onnx.AttributeProto.AttributeType:
return self.attr_proto.type


@dataclasses.dataclass(frozen=True)
class IRAttributeParameter:
"""An attribute parameter (representing a formal parameter).

It may or may not carry a default value.

Attributes:
name: The name of the attribute.
type: The type of the attribute.
default_value: The default value of the attribute.
has_default: Whether the attribute has a default value.
attr_proto: The attribute proto.
"""

name: str
type: onnx.AttributeProto.AttributeType
default_value: str | int | float | None = None

# TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType.

def __str__(self):
if self.has_default:
return helper.printable_attribute(self.attr_proto)
# TODO(justinchuby): Include a readable type name.
return self.name

@property
def has_default(self):
return self.default_value is not None

@property
def attr_proto(self) -> onnx.AttributeProto:
if not self.has_default:
raise ValueError(
"Attribute has no default value. Only attributes with default "
"values can be converted to AttributeProto."
)
if version_utils.onnx_older_than("1.15"):
# TODO(after 1.14 is deprecated): Remove this branch.
# Argument 'attr_type' was added after version 1.14.
return helper.make_attribute(self.name, self.default_value)
# pylint: disable=unexpected-keyword-arg
return helper.make_attribute(self.name, self.default_value, attr_type=self.type) # type: ignore[call-arg]
# pylint: enable=unexpected-keyword-arg


class IRStmt:
def __init__(
self,
result: Sequence[str],
callee: values.Op,
args: Sequence[Optional[str]],
attrs: Sequence[IRAttributeValue],
args: Sequence[str],
attrs: Sequence[ir.Attr | ir.RefAttr],
sub_functions=None,
) -> None:
if not isinstance(callee, values.Op):
raise TypeError(f"Unexpected type {type(callee)} for callee.")
self.result = result
self._output_names = result
self.callee = callee
self.args = args
self.attrs = attrs
self.functions = sub_functions or {}

def __str__(self):
if isinstance(self.result, str):
logger.debug("unexpected str type for self.result where type(self)=%r", type(self))
lhs = ", ".join(self.result)
lhs = ", ".join(self._output_names)

Check warning on line 83 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L83

Added line #L83 was not covered by tests
attrs = ""
if self.attrs:
attrs = _format(self.attrs, "<", ", ", ">")
Expand All @@ -217,22 +95,29 @@
if logger.isEnabledFor(logging.DEBUG):
logger.debug("%s: %s", type(self), str(self))

def to_node_proto(self, node_name: str) -> onnx.NodeProto:
n = helper.make_node(
self.callee.name,
[_opt_var_to_str(x) for x in self.args],
[str(x) for x in self.result],
def to_node(self, node_name: str, values: Mapping[str, ir.Value]) -> ir.Node:

Check warning

Code scanning / lintrunner

PYLINT/W0621 Warning

Redefining name 'values' from outer scope (line 19) (redefined-outer-name)
See redefined-outer-name. To disable, use # pylint: disable=redefined-outer-name
"""
Converts this statement into a node in the IR.

Args:
node_name: The name of the node.
values: A dictionary mapping value names to values.
"""
node = ir.Node(
domain=self.callee.opset.domain,
op_type=self.callee.name,
inputs=[values[x] if x != "" else None for x in self.args],
name=node_name,
attributes=self.attrs,
)
for a in self.attrs:
n.attribute.append(a.attr_proto)
return n
for name, output in zip(self._output_names, node.outputs):
output.name = name
return node

Check warning on line 115 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L114-L115

Added lines #L114 - L115 were not covered by tests

@property
def output_names(self) -> Sequence[str]:
"""Returns the list of variables assigned to by this statement."""
return [str(x) for x in self.result]
return self._output_names


class IRFunction:
Expand All @@ -248,7 +133,7 @@
# a dictionary of nested function-definitions
self.nested_functions: dict[str, IRFunction] = {}
self.outer_scope_variables: dict[Any, Any] = {}
self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = []
self.ordered_inputs_and_attrs: list[Union[IRVar, ir.Attr]] = []

@property
def assigned_names(self) -> Sequence[str]:
Expand All @@ -260,11 +145,11 @@
return [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)]

@property
def attrs(self) -> Sequence[IRAttributeParameter]:
def attrs(self) -> Sequence[ir.Attr]:
return [
attr
for attr in self.ordered_inputs_and_attrs
if isinstance(attr, IRAttributeParameter)
if isinstance(attr, ir.Attr)
]

def __str__(self):
Expand All @@ -286,7 +171,7 @@
def append_output(self, name: IRVar) -> None:
self.outputs.append(name)

def add_attr_parameter(self, attr: IRAttributeParameter) -> None:
def add_attr_parameter(self, attr: ir.Attr) -> None:
self.ordered_inputs_and_attrs.append(attr)

def debug_print(self):
Expand Down Expand Up @@ -500,14 +385,14 @@
results: Sequence[str],
callee: values.Op,
args: Sequence[Optional[str]],
attrs: Sequence[IRAttributeValue],
attrs: Sequence[ir.Attr | ir.RefAttr],
sub_functions=None,
) -> None:
stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions)
fn.append_stmt(stmt)

def add_input(
self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo
self, fn: IRFunction, varname: str, type: ir.TypeProtocol, info: sourceinfo.SourceInfo
) -> None:
var = IRVar(varname, type, info)
fn.append_input(var)
Expand All @@ -516,23 +401,23 @@
self,
fn: IRFunction,
varname: str,
attribute_type: onnx.AttributeProto.AttributeType,
attribute_type: ir.AttributeType,
default_value: int | float | str | None,
) -> None:
fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value))
fn.add_attr_parameter(ir_convenience.convert_attribute(varname, default_value, attribute_type))

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "add_attr_parameter" of "IRFunction" has incompatible type "Attr | RefAttr"; expected "Attr" To disable, use # type: ignore[arg-type]

def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None:
var = IRVar(varname, typeinfo, sourceinfo)
def add_output(self, fn: IRFunction, varname: str, typeinfo, source_info) -> None:
var = IRVar(varname, typeinfo, source_info)
fn.append_output(var)

def make_attr(self, attrproto: onnx.AttributeProto) -> IRAttributeValue:
return IRAttributeValue(attrproto)
def make_attr(self, attrproto: onnx.AttributeProto) -> ir.Attr | ir.RefAttr:
return ir.Attr | ir.RefAttr(attrproto)
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> ir.Attr | ir.RefAttr:
proto = onnx.AttributeProto()
proto.name = attrname
proto.ref_attr_name = refname
attr_type = ta.pytype_to_attrtype(pytype)
assert attr_type is not None
proto.type = attr_type
return IRAttributeValue(proto)
return ir.Attr | ir.RefAttr(proto)
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Loading
Loading