Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use IR in onnxscript #1409

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions onnxscript/ir/convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,47 @@

from typing import Any, Mapping, Sequence

from onnxscript.ir import _core
from onnxscript.ir import _core, _protocols, _enums


def convert_attribute(
name: str,
attr: str
| int
| float
| Sequence[int]
| Sequence[float]
| Sequence[str]
| _protocols.TensorProtocol
| _core.Attr
| None,
attr_type: _enums.AttributeType | None= None,
) -> _core.Attr:
if attr is None:
if attr_type is None:
raise ValueError("attr_type must be provided when attr is None")

Check warning on line 29 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L29

Added line #L29 was not covered by tests
return _core.Attr(name, attr_type, None)
if isinstance(attr, int):
return _core.AttrInt64(name, attr)

Check warning on line 32 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L32

Added line #L32 was not covered by tests
if isinstance(attr, float):
return _core.AttrFloat32(name, attr)
if isinstance(attr, str):
return _core.AttrString(name, attr)
if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
return _core.AttrInt64s(name, attr) # type: ignore

Check warning on line 38 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L38

Added line #L38 was not covered by tests
if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
return _core.AttrFloat32s(name, attr) # type: ignore

Check warning on line 40 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L40

Added line #L40 was not covered by tests
if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
return _core.AttrStrings(name, attr) # type: ignore

Check warning on line 42 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L42

Added line #L42 was not covered by tests
if isinstance(attr, (_core.Tensor, _protocols.TensorProtocol)):
return _core.AttrTensor(name, attr)

Check warning on line 44 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L44

Added line #L44 was not covered by tests
if isinstance(attr, _core.Attr):
return attr
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")

Check warning on line 47 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L46-L47

Added lines #L46 - L47 were not covered by tests


def convert_attributes(attrs: Mapping[str, Any]) -> list[_core.Attr]:
attributes: list[_core.Attr] = []
for name, attr in attrs.items():
if isinstance(attr, int):
attributes.append(_core.AttrInt64(name, attr))
elif isinstance(attr, float):
attributes.append(_core.AttrFloat32(name, attr))
elif isinstance(attr, str):
attributes.append(_core.AttrString(name, attr))
elif isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
attributes.append(_core.AttrInt64s(name, attr))
elif isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
attributes.append(_core.AttrFloat32s(name, attr))
elif isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
attributes.append(_core.AttrStrings(name, attr))
elif isinstance(attr, _core.Attr):
attributes.append(attr)
else:
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
attributes.append(convert_attribute(name, attr))

Check warning on line 53 in onnxscript/ir/convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/convenience.py#L53

Added line #L53 was not covered by tests
return attributes
Loading
Loading