Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Implement save/load functions in IR and handle external data properly #1801

Merged
merged 17 commits into from
Aug 13, 2024
4 changes: 4 additions & 0 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
# Pass infrastructure
"passes",
"traversal",
# IO
"load",
"save",
]

from onnxscript.ir import passes, serde, traversal
Expand Down Expand Up @@ -114,6 +117,7 @@
AttributeType,
DataType,
)
from onnxscript.ir._io import load, save
from onnxscript.ir._protocols import (
ArrayCompatible,
AttributeProtocol,
Expand Down
30 changes: 26 additions & 4 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@

Attributes:
path: The path to the data file. This can be a relative path or an absolute path.
base_dir: The base directory for the external data. It is used to resolve relative paths.
At serialization, only the ``path`` is serialized into the "location" field of the TensorProto.
offset: The offset in bytes from the start of the file.
length: The length of the data in bytes.
dtype: The data type of the tensor.
Expand Down Expand Up @@ -509,8 +511,15 @@
name: str,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
base_dir: os.PathLike | str = "",
) -> None:
self._path = path
if os.path.isabs(path):
self._base_dir = os.path.dirname(path)
self._path = os.path.basename(path)
else:
self._base_dir = base_dir
self._path = path

self._offset: int | None = offset
self._length: int | None = length
self._dtype: _enums.DataType = dtype
Expand All @@ -528,6 +537,15 @@
# Immutable
return self._path

@property
def base_dir(self) -> str | os.PathLike:
    """The base directory stored for this tensor's external data."""
    # Mutable
    return self._base_dir

@base_dir.setter
def base_dir(self, value: str | os.PathLike) -> None:
    """Replace the stored base directory; the value is kept as given, without validation."""
    self._base_dir = value

@property
def offset(self) -> int | None:
# Immutable
Expand Down Expand Up @@ -556,7 +574,8 @@
return
# Map the whole file into the memory
# TODO(justinchuby): Verify if this would exhaust the memory address space
with open(self._path, "rb") as f:
file_path = os.path.join(self._base_dir, self._path)
with open(file_path, "rb") as f:
self.raw = mmap.mmap(
f.fileno(),
0,
Expand Down Expand Up @@ -599,7 +618,10 @@
)

def __repr__(self) -> str:
return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})"
return (

Check warning on line 621 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L621

Added line #L621 was not covered by tests
f"{self._repr_base()}(path='{self._path}', name={self.name!r}, "
f"offset={self._offset!r}, length={self._length!r}, base_dir={self._base_dir!r})"
)

def numpy(self) -> np.ndarray:
"""Return the tensor as a numpy array.
Expand Down Expand Up @@ -2069,7 +2091,7 @@
outputs: Sequence[Value],
*,
nodes: Iterable[Node],
initializers: Sequence[_protocols.TensorProtocol] = (),
initializers: Sequence[_protocols.ValueProtocol] = (),
doc_string: str | None = None,
opset_imports: dict[str, int] | None = None,
name: str | None = None,
Expand Down
17 changes: 17 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,23 @@ def test_initialize(self):
# Ensure repeated reads are consistent
np.testing.assert_equal(tensor, self.data)

def test_initialize_with_relative_path(self):
    """An ExternalTensor given a location plus a separate ``base_dir`` yields the expected data."""
    proto_tensor = self.model.graph.initializer[0]
    info = onnx.external_data_helper.ExternalDataInfo(proto_tensor)
    # The directory is passed on its own instead of being part of ``path``.
    data_dir = pathlib.Path(self.base_path)
    tensor = _core.ExternalTensor(
        path=info.location,
        offset=info.offset,
        length=info.length,
        dtype=ir.DataType.FLOAT,
        name="input",
        shape=_core.Shape(proto_tensor.dims),
        base_dir=data_dir,
    )
    self.assertEqual(tensor.dtype, ir.DataType.FLOAT)
    # Compare twice so that a repeated read is checked against the same data.
    for _ in range(2):
        np.testing.assert_equal(tensor, self.data)

def test_totypes_returns_correct_data_in(self):
external_tensor = self.model.graph.initializer[0]
external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
Expand Down
53 changes: 53 additions & 0 deletions onnxscript/ir/_external_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""External data related utilities."""

from __future__ import annotations

__all__ = ["set_base_dir"]

import os
from typing import Iterator

from onnxscript.ir import _core, _enums, _protocols, traversal


def _all_tensors(
    graph: _core.Graph | _core.GraphView, include_attributes: bool = False
) -> Iterator[_protocols.TensorProtocol]:
    """Yield the tensors held by a graph.

    Initializer tensors are produced first. Tensors stored in node attributes
    follow, but only when ``include_attributes`` is true.

    Args:
        graph: The graph whose tensors are enumerated.
        include_attributes: When true, also yield the values of ``TENSOR`` and
            ``TENSORS`` attributes of every node visited by
            ``traversal.RecursiveGraphIterator``.

    Yields:
        Each tensor found, initializers before attribute tensors.
    """
    # NOTE(review): only ``graph.initializers`` of the graph passed in is read
    # here; whether initializers of nested subgraphs need visiting as well
    # should be confirmed against the callers.
    yield from (
        value.const_value
        for value in graph.initializers.values()
        if value.const_value is not None
    )

    if include_attributes:
        for node in traversal.RecursiveGraphIterator(graph):
            for attr in node.attributes.values():
                if isinstance(attr, _core.RefAttr):
                    # Reference attributes are skipped entirely.
                    continue
                attr_type = attr.type
                if attr_type == _enums.AttributeType.TENSOR:
                    if attr.value is not None:
                        yield attr.value
                elif attr_type == _enums.AttributeType.TENSORS:
                    if attr.value is not None:
                        yield from attr.value

Check warning on line 41 in onnxscript/ir/_external_data.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_external_data.py#L41

Added line #L41 was not covered by tests


def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None:
    """Assign a base directory to every external tensor found in a graph.

    Tensors are collected from the initializers and from the node attributes;
    only instances of ``ExternalTensor`` are modified.

    Args:
        graph: The graph whose tensors are updated.
        base_dir: The base directory to assign. This is the directory where the
            ONNX file is.
    """
    external_tensors = (
        candidate
        for candidate in _all_tensors(graph, include_attributes=True)
        if isinstance(candidate, _core.ExternalTensor)
    )
    for external in external_tensors:
        external.base_dir = base_dir
63 changes: 63 additions & 0 deletions onnxscript/ir/_external_data_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
import unittest

import onnx
import onnx.external_data_helper

from onnxscript import ir
from onnxscript.ir import _external_data


class ExternalDataTest(unittest.TestCase):
    """Tests for the helpers in ``_external_data``."""

    def test_set_base_dir_sets_base_dir_for_all_external_tensors(self):
        """``set_base_dir`` updates the initializer tensor and the attribute tensor."""
        # One tensor lives in a Constant node attribute, the other in an initializer.
        constant_value = onnx.helper.make_tensor(
            name="test_constant",
            data_type=onnx.TensorProto.FLOAT,
            dims=[1],
            vals=b"\x01\x00\x00\x00",
            raw=True,
        )
        constant_node = onnx.helper.make_node(
            "Constant",
            [],
            ["test"],
            value=constant_value,
        )
        initializer_proto = onnx.helper.make_tensor(
            name="test_tensor",
            data_type=onnx.TensorProto.FLOAT,
            dims=[1],
            vals=b"\x01\x00\x00\x00",
            raw=True,
        )
        graph = onnx.helper.make_graph(
            nodes=[constant_node],
            name="test",
            inputs=[],
            outputs=[],
            initializer=[initializer_proto],
        )
        model_proto = onnx.helper.make_model(graph)
        # Mark both tensors as external data before deserializing into the IR.
        onnx.external_data_helper.convert_model_to_external_data(
            model_proto, location="tempdir", size_threshold=0, convert_attribute=True
        )
        model = ir.serde.deserialize_model(model_proto)

        expected_dir = "something_else"
        _external_data.set_base_dir(model.graph, expected_dir)

        initializer_tensor = model.graph.initializers["test_tensor"].const_value
        assert isinstance(initializer_tensor, ir.ExternalTensor)
        self.assertEqual(initializer_tensor.base_dir, expected_dir)

        node_tensor = model.graph.node(0).attributes["value"].value
        self.assertEqual(node_tensor.base_dir, expected_dir)


if __name__ == "__main__":
    unittest.main()

Check warning on line 63 in onnxscript/ir/_external_data_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_external_data_test.py#L63

Added line #L63 was not covered by tests
50 changes: 50 additions & 0 deletions onnxscript/ir/_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Load and save ONNX models."""

from __future__ import annotations

__all__ = ["load", "save"]

import os

import onnx

from onnxscript.ir import _core, _external_data, serde


def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
    """Read an ONNX file and return it as an IR model.

    Args:
        path: The path to the ONNX file.
        format: The format of the file (e.g. protobuf, textproto, json, etc.).
            If None, the format is inferred from the file extension.

    Returns:
        The loaded model.
    """
    # ONNX is told not to load external data: the IR handles external data
    # itself by memory mapping the files directly.
    model = serde.deserialize_model(
        onnx.load(path, format=format, load_external_data=False)
    )
    # Use the directory of the ONNX file as the base directory for external
    # data so that relative paths are resolved correctly.
    _external_data.set_base_dir(model.graph, os.path.dirname(path))
    return model

Check warning on line 35 in onnxscript/ir/_io.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_io.py#L34-L35

Added lines #L34 - L35 were not covered by tests


def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None:
    """Serialize an IR model and write it to a file.

    Args:
        model: The model to save.
        path: The path to save the model to.
        format: The format of the file (e.g. protobuf, textproto, json, etc.).
            If None, the format is inferred from the file extension.
    """
    onnx.save(serde.serialize_model(model), path, format=format)
    # TODO(justinchuby): Handle external data when the relative path has changed
    # TODO(justinchuby): Handle offloading external data to disk when saving
12 changes: 7 additions & 5 deletions onnxscript/ir/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,35 @@
"RecursiveGraphIterator",
]

from typing import Callable, Iterator, Reversible
from typing import Callable, Iterator, Reversible, Union

from typing_extensions import Self

from onnxscript.ir import _core, _enums

GraphLike = Union[_core.Graph, _core.Function, _core.GraphView]


class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
def __init__(
self,
graph: _core.Graph | _core.Function | _core.GraphView,
graph_like: GraphLike,
*,
recursive: Callable[[_core.Node], bool] | None = None,
reverse: bool = False,
):
"""Iterate over the nodes in the graph, recursively visiting subgraphs.

Args:
graph: The graph to traverse.
graph_like: The graph to traverse.
recursive: A callback that determines whether to recursively visit the subgraphs
contained in a node. If not provided, all nodes in subgraphs are visited.
reverse: Whether to iterate in reverse order.
"""
self._graph = graph
self._graph = graph_like
self._recursive = recursive
self._reverse = reverse
self._iterator = self._recursive_node_iter(graph)
self._iterator = self._recursive_node_iter(graph_like)

def __iter__(self) -> Self:
self._iterator = self._recursive_node_iter(self._graph)
Expand Down
Loading