-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Use IR in onnxscript #1409
base: main
Are you sure you want to change the base?
Changes from 4 commits
73b375b
5714445
54ce3b4
5d753ed
060009e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,18 +8,18 @@ | |
import io | ||
import logging | ||
import warnings | ||
from typing import Any, Optional, Protocol, Sequence, Union | ||
from typing import Any, Mapping, Optional, Protocol, Sequence, Union | ||
Check warning Code scanning / lintrunner RUFF/F401 Warning
typing.Protocol imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import |
||
|
||
import onnx | ||
from onnx import ValueInfoProto, helper | ||
from onnx.defs import onnx_opset_version | ||
|
||
import onnxscript | ||
from onnxscript import type_annotation as ta | ||
from onnxscript import values | ||
from onnxscript import values, ir, sourceinfo | ||
from onnxscript._internal import version_utils | ||
from onnxscript.onnx_types import ONNXType | ||
from onnxscript.sourceinfo import SourceInfo | ||
from onnxscript.ir import _convenience as ir_convenience | ||
|
||
# A simple IR (Function, Stmt, Attr, Var): | ||
|
||
|
@@ -40,39 +40,14 @@ | |
return helper.OP_SET_ID_VERSION_MAP[domain, version] | ||
|
||
|
||
class IRType: | ||
def __init__(self): | ||
self.onnx_type = onnx.TypeProto() | ||
|
||
def to_type_proto(self): | ||
return self.onnx_type | ||
|
||
def __repr__(self) -> str: | ||
return "IRType()" | ||
|
||
|
||
class IRTensorType(IRType): | ||
def __init__(self, elem_type: onnx.TensorProto.DataType) -> None: | ||
super().__init__() | ||
self.onnx_type.tensor_type.elem_type = elem_type | ||
|
||
def __repr__(self) -> str: | ||
return f"IRTensorType({self.onnx_type.tensor_type.elem_type})" | ||
|
||
|
||
class IRTypeLike(Protocol): | ||
def to_type_proto(self) -> onnx.TypeProto: | ||
"""Converts IR type representation to onnx.TypeProto""" | ||
|
||
|
||
class IRVar: | ||
"""A variable (representing a formal parameter).""" | ||
|
||
def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -> None: | ||
def __init__(self, varname: str, typeinfo: ir.TypeProtocol, source_info: sourceinfo.SourceInfo) -> None: | ||
if not isinstance(varname, str): | ||
raise TypeError(f"varname must be a string not {type(varname)!r}.") | ||
self.name = varname | ||
self.info = sourceinfo | ||
self.info = source_info | ||
self.typeinfo = typeinfo | ||
|
||
def __str__(self): | ||
|
@@ -81,128 +56,34 @@ | |
def __repr__(self): | ||
return f"{self.__class__.__name__}({self.name!r}, {self.typeinfo!r})" | ||
|
||
def typed_str(self): | ||
return f"{self.name} : {self.typeinfo}" | ||
|
||
def to_value_info(self, use_default_type: bool = True): | ||
"""Converts the content of this class into :class:`onnx.ValueInfoProto`. | ||
def typed_str(self) -> str: | ||
return f"{self.name}: {self.typeinfo}" | ||
|
||
Args: | ||
use_default_type: if True, use a default type if an explicit type | ||
is not known. Otherwise, returns a ValueInfoProto without type. | ||
|
||
Returns: | ||
an instance of :class:`onnx.ValueInfoProto` | ||
""" | ||
if self.name is None: | ||
raise ValueError(self.info.msg("name cannot be None.")) | ||
value_info_proto = ValueInfoProto() | ||
value_info_proto.name = self.name | ||
if self.typeinfo is not None: | ||
value_info_proto.type.CopyFrom(self.typeinfo.to_type_proto()) | ||
elif use_default_type: | ||
value_info_proto.type.CopyFrom(IRType().to_type_proto()) | ||
return value_info_proto | ||
|
||
|
||
def _opt_var_to_str(x): | ||
return "" if x is None else str(x) | ||
|
||
|
||
class IRAttributeValue: | ||
"""An attribute value (representing an actual parameter). | ||
|
||
Attributes: | ||
name: The name of the attribute. | ||
type: The type of the attribute. | ||
attr_proto: The attribute proto. | ||
""" | ||
|
||
def __init__(self, attrproto: onnx.AttributeProto) -> None: | ||
self.attr_proto = attrproto | ||
|
||
def __str__(self): | ||
if self.attr_proto.HasField("ref_attr_name"): | ||
return f"{self.attr_proto.name} = @{self.attr_proto.ref_attr_name}" | ||
# self.name + " = " + self.value | ||
return helper.printable_attribute(self.attr_proto) | ||
|
||
@property | ||
def name(self) -> str: | ||
return self.attr_proto.name | ||
|
||
@property | ||
def type(self) -> onnx.AttributeProto.AttributeType: | ||
return self.attr_proto.type | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class IRAttributeParameter: | ||
"""An attribute parameter (representing a formal parameter). | ||
|
||
It may or may not carry a default value. | ||
|
||
Attributes: | ||
name: The name of the attribute. | ||
type: The type of the attribute. | ||
default_value: The default value of the attribute. | ||
has_default: Whether the attribute has a default value. | ||
attr_proto: The attribute proto. | ||
""" | ||
|
||
name: str | ||
type: onnx.AttributeProto.AttributeType | ||
default_value: str | int | float | None = None | ||
|
||
# TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType. | ||
|
||
def __str__(self): | ||
if self.has_default: | ||
return helper.printable_attribute(self.attr_proto) | ||
# TODO(justinchuby): Include a readable type name. | ||
return self.name | ||
|
||
@property | ||
def has_default(self): | ||
return self.default_value is not None | ||
|
||
@property | ||
def attr_proto(self) -> onnx.AttributeProto: | ||
if not self.has_default: | ||
raise ValueError( | ||
"Attribute has no default value. Only attributes with default " | ||
"values can be converted to AttributeProto." | ||
) | ||
if version_utils.onnx_older_than("1.15"): | ||
# TODO(after 1.14 is deprecated): Remove this branch. | ||
# Argument 'attr_type' was added after version 1.14. | ||
return helper.make_attribute(self.name, self.default_value) | ||
# pylint: disable=unexpected-keyword-arg | ||
return helper.make_attribute(self.name, self.default_value, attr_type=self.type) # type: ignore[call-arg] | ||
# pylint: enable=unexpected-keyword-arg | ||
|
||
|
||
class IRStmt: | ||
def __init__( | ||
self, | ||
result: Sequence[str], | ||
callee: values.Op, | ||
args: Sequence[Optional[str]], | ||
attrs: Sequence[IRAttributeValue], | ||
args: Sequence[str], | ||
attrs: Sequence[ir.Attr | ir.RefAttr], | ||
sub_functions=None, | ||
) -> None: | ||
if not isinstance(callee, values.Op): | ||
raise TypeError(f"Unexpected type {type(callee)} for callee.") | ||
self.result = result | ||
self._output_names = result | ||
self.callee = callee | ||
self.args = args | ||
self.attrs = attrs | ||
self.functions = sub_functions or {} | ||
|
||
def __str__(self): | ||
if isinstance(self.result, str): | ||
logger.debug("unexpected str type for self.result where type(self)=%r", type(self)) | ||
lhs = ", ".join(self.result) | ||
lhs = ", ".join(self._output_names) | ||
attrs = "" | ||
if self.attrs: | ||
attrs = _format(self.attrs, "<", ", ", ">") | ||
|
@@ -217,22 +98,29 @@ | |
if logger.isEnabledFor(logging.DEBUG): | ||
logger.debug("%s: %s", type(self), str(self)) | ||
|
||
def to_node_proto(self, node_name: str) -> onnx.NodeProto: | ||
n = helper.make_node( | ||
self.callee.name, | ||
[_opt_var_to_str(x) for x in self.args], | ||
[str(x) for x in self.result], | ||
def to_node(self, node_name: str, values: Mapping[str, ir.Value]) -> ir.Node: | ||
Check warning Code scanning / lintrunner PYLINT/W0621 Warning
Redefining name 'values' from outer scope (line 19) (redefined-outer-name)
See redefined-outer-name. To disable, use # pylint: disable=redefined-outer-name |
||
""" | ||
Converts this statement into a node in the IR. | ||
|
||
Args: | ||
node_name: The name of the node. | ||
values: A dictionary mapping value names to values. | ||
""" | ||
node = ir.Node( | ||
domain=self.callee.opset.domain, | ||
op_type=self.callee.name, | ||
inputs=[values[x] if x != "" else None for x in self.args], | ||
name=node_name, | ||
attributes=self.attrs, | ||
) | ||
for a in self.attrs: | ||
n.attribute.append(a.attr_proto) | ||
return n | ||
for name, output in zip(self._output_names, node.outputs): | ||
output.name = name | ||
return node | ||
|
||
@property | ||
def output_names(self) -> Sequence[str]: | ||
"""Returns the list of variables assigned to by this statement.""" | ||
return [str(x) for x in self.result] | ||
return self._output_names | ||
|
||
|
||
class IRFunction: | ||
|
@@ -248,7 +136,7 @@ | |
# a dictionary of nested function-definitions | ||
self.nested_functions: dict[str, IRFunction] = {} | ||
self.outer_scope_variables: dict[Any, Any] = {} | ||
self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = [] | ||
self.ordered_inputs_and_attrs: list[Union[IRVar, ir.Attr]] = [] | ||
|
||
@property | ||
def assigned_names(self) -> Sequence[str]: | ||
|
@@ -260,11 +148,11 @@ | |
return [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)] | ||
|
||
@property | ||
def attrs(self) -> Sequence[IRAttributeParameter]: | ||
def attrs(self) -> Sequence[ir.Attr]: | ||
return [ | ||
attr | ||
for attr in self.ordered_inputs_and_attrs | ||
if isinstance(attr, IRAttributeParameter) | ||
if isinstance(attr, ir.Attr) | ||
] | ||
|
||
def __str__(self): | ||
|
@@ -286,18 +174,9 @@ | |
def append_output(self, name: IRVar) -> None: | ||
self.outputs.append(name) | ||
|
||
def add_attr_parameter(self, attr: IRAttributeParameter) -> None: | ||
def add_attr_parameter(self, attr: ir.Attr) -> None: | ||
self.ordered_inputs_and_attrs.append(attr) | ||
|
||
def debug_print(self): | ||
if logger.isEnabledFor(logging.DEBUG): | ||
st = io.StringIO() | ||
for s in self.stmts: | ||
for attr in s.attrs: | ||
if attr.attr_proto.HasField("g"): | ||
st.write(helper.printable_graph(attr.attr_proto.g)) | ||
st.write("\n") | ||
|
||
def add_called_function(self, fun: values.OnnxFunction) -> None: | ||
for name, fct in fun.function_ir.called_functions.items(): | ||
if name in self.called_functions: | ||
|
@@ -440,40 +319,33 @@ | |
) | ||
return func_opset_imports | ||
|
||
def to_function_proto(self) -> onnx.FunctionProto: | ||
"""Converts this instance into a `onnx.FunctionProto`. | ||
|
||
Note: Default values for attributes are an experimental feature in ONNX. | ||
Conversion ignores default values for attributes if the ONNX version installed | ||
doesn't support it. | ||
""" | ||
def to_ir_function(self) -> ir.Function: | ||
"""Converts this instance into a `ir.Function`.""" | ||
opsets = self.get_opset_import() | ||
nodes = [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)] | ||
for n in nodes: | ||
if n.domain not in opsets: | ||
opsets[n.domain] = 1 # TODO: how to get n.version? | ||
opset_imports = [ | ||
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() | ||
] | ||
|
||
attribute_names = [attr.name for attr in self.attrs if not attr.has_default] | ||
|
||
f = helper.make_function( | ||
self.domain, | ||
self.name, | ||
inputs=[x.name for x in self.inputs], | ||
outputs=[y.name for y in self.outputs], | ||
values = {} | ||
Check warning Code scanning / lintrunner PYLINT/W0621 Warning
Redefining name 'values' from outer scope (line 19) (redefined-outer-name)
See redefined-outer-name. To disable, use # pylint: disable=redefined-outer-name Check failure Code scanning / lintrunner MYPY/var-annotated Error
Need type annotation for "values" (hint: "values: Dict[, ] = ...")
To disable, use # type: ignore[var-annotated]
|
||
nodes = [] | ||
function_outputs: dict[str, ir.Value | None] = {x.name: None for x in self.outputs} | ||
for i, s in enumerate(self.stmts): | ||
node = s.to_node(f"n{i}", values) | ||
nodes.append(node) | ||
if node.domain not in opsets: | ||
# FIXME(justinchuby): Node version | ||
assert s.version is not None | ||
Check failure Code scanning / lintrunner MYPY/attr-defined Error
"IRStmt" has no attribute "version"
To disable, use # type: ignore[attr-defined]
|
||
opsets[node.domain] = s.version | ||
Check failure Code scanning / lintrunner MYPY/attr-defined Error
"IRStmt" has no attribute "version"
To disable, use # type: ignore[attr-defined]
|
||
for output in node.outputs: | ||
values[output.name] = output | ||
if output.name in function_outputs: | ||
function_outputs[output.name] = output | ||
inputs = [ir.Input(input.name) for input in self.inputs] | ||
for name, output in function_outputs.items(): | ||
Check failure Code scanning / lintrunner MYPY/assignment Error
Incompatible types in assignment (expression has type "Value | None", variable has type "Value")
To disable, use # type: ignore[assignment]
|
||
assert output is not None, f"Output {name!r} is an output of any node is the function." | ||
graph = ir.Graph( | ||
inputs=inputs, | ||
outputs=function_outputs.values(), # type: ignore | ||
nodes=nodes, | ||
opset_imports=opset_imports, # TODO | ||
attributes=attribute_names, | ||
doc_string=self.docstring, | ||
opset_imports=opsets, | ||
) | ||
# In protobuf 4.x fields aren't defined as class attribute so it should check instance attribute instead | ||
if hasattr(f, "attribute_proto"): | ||
f.attribute_proto.extend( | ||
[attr.attr_proto for attr in self.attrs if attr.has_default] | ||
) | ||
return f | ||
return ir.Function(domain=self.domain, name=self.name, graph=graph, attributes=self.attrs) | ||
|
||
|
||
# IRBuilder: abstracts out details of the IR in the python-to-IR converter | ||
|
@@ -500,14 +372,15 @@ | |
results: Sequence[str], | ||
callee: values.Op, | ||
args: Sequence[Optional[str]], | ||
attrs: Sequence[IRAttributeValue], | ||
attrs: Sequence[ir.Attr | ir.RefAttr], | ||
sub_functions=None, | ||
) -> None: | ||
# TODO(justinchuby): Capture opset version here | ||
stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions) | ||
fn.append_stmt(stmt) | ||
|
||
def add_input( | ||
self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo | ||
self, fn: IRFunction, varname: str, type: ir.TypeProtocol, info: sourceinfo.SourceInfo | ||
) -> None: | ||
var = IRVar(varname, type, info) | ||
fn.append_input(var) | ||
|
@@ -516,23 +389,16 @@ | |
self, | ||
fn: IRFunction, | ||
varname: str, | ||
attribute_type: onnx.AttributeProto.AttributeType, | ||
attribute_type: ir.AttributeType, | ||
default_value: int | float | str | None, | ||
) -> None: | ||
fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value)) | ||
fn.add_attr_parameter(ir_convenience.convert_attribute(varname, default_value, attribute_type)) | ||
Check failure Code scanning / lintrunner MYPY/arg-type Error
Argument 1 to "add_attr_parameter" of "IRFunction" has incompatible type "Attr | RefAttr"; expected "Attr"
To disable, use # type: ignore[arg-type]
|
||
|
||
def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None: | ||
var = IRVar(varname, typeinfo, sourceinfo) | ||
def add_output(self, fn: IRFunction, varname: str, typeinfo, source_info) -> None: | ||
var = IRVar(varname, typeinfo, source_info) | ||
fn.append_output(var) | ||
|
||
def make_attr(self, attrproto: onnx.AttributeProto) -> IRAttributeValue: | ||
return IRAttributeValue(attrproto) | ||
|
||
def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue: | ||
proto = onnx.AttributeProto() | ||
proto.name = attrname | ||
proto.ref_attr_name = refname | ||
attr_type = ta.pytype_to_attrtype(pytype) | ||
assert attr_type is not None | ||
proto.type = attr_type | ||
return IRAttributeValue(proto) | ||
def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> ir.RefAttr: | ||
return ir.RefAttr( | ||
attrname, refname, ta.pytype_to_attrtype(pytype) | ||
Check failure Code scanning / lintrunner MYPY/arg-type Error
Argument 3 to "RefAttr" has incompatible type "onnx.onnx_ml_pb2.AttributeProto.AttributeType | None"; expected "onnxscript.ir._enums.AttributeType"
To disable, use # type: ignore[arg-type]
|
||
) |
Check warning
Code scanning / lintrunner
PYLINT/W0611 Warning