Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Implement save/load functions in IR and handle external data properly #1801

Merged
merged 17 commits into from
Aug 13, 2024
4 changes: 4 additions & 0 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
# Pass infrastructure
"passes",
"traversal",
# IO
"load",
"save",
]

from onnxscript.ir import passes, serde, traversal
Expand Down Expand Up @@ -114,6 +117,7 @@
AttributeType,
DataType,
)
from onnxscript.ir._io import load, save
from onnxscript.ir._protocols import (
ArrayCompatible,
AttributeProtocol,
Expand Down
30 changes: 26 additions & 4 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=

Attributes:
path: The path to the data file. This can be a relative path or an absolute path.
base_dir: The base directory for the external data. It is used to resolve relative paths.
At serialization, only the ``path`` is serialized into the "location" field of the TensorProto.
offset: The offset in bytes from the start of the file.
length: The length of the data in bytes.
dtype: The data type of the tensor.
Expand Down Expand Up @@ -509,8 +511,15 @@ def __init__(
name: str,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
base_dir: os.PathLike | str = "",
) -> None:
self._path = path
if os.path.isabs(path):
self._base_dir = os.path.dirname(path)
self._path = os.path.basename(path)
else:
self._base_dir = base_dir
self._path = path

self._offset: int | None = offset
self._length: int | None = length
self._dtype: _enums.DataType = dtype
Expand All @@ -528,6 +537,15 @@ def path(self) -> str | os.PathLike:
# Immutable
return self._path

@property
def base_dir(self) -> str | os.PathLike:
    """Base directory used to resolve a relative ``path``.

    Unlike ``path``, this is not serialized into the TensorProto's
    "location" field, so it may be changed freely (e.g. after moving
    the model file).
    """
    # Mutable
    return self._base_dir

@base_dir.setter
def base_dir(self, value: str | os.PathLike) -> None:
    self._base_dir = value

@property
def offset(self) -> int | None:
# Immutable
Expand Down Expand Up @@ -556,7 +574,8 @@ def _load(self):
return
# Map the whole file into the memory
# TODO(justinchuby): Verify if this would exhaust the memory address space
with open(self._path, "rb") as f:
file_path = os.path.join(self._base_dir, self._path)
with open(file_path, "rb") as f:
self.raw = mmap.mmap(
f.fileno(),
0,
Expand Down Expand Up @@ -599,7 +618,10 @@ def __dlpack_device__(self) -> tuple[int, int]:
)

def __repr__(self) -> str:
    """Return a debug representation showing path, name, offset, length and base_dir."""
    return "{}(path='{}', name={!r}, offset={!r}, length={!r}, base_dir={!r})".format(
        self._repr_base(),
        self._path,
        self.name,
        self._offset,
        self._length,
        self._base_dir,
    )

def numpy(self) -> np.ndarray:
"""Return the tensor as a numpy array.
Expand Down Expand Up @@ -2069,7 +2091,7 @@ def __init__(
outputs: Sequence[Value],
*,
nodes: Iterable[Node],
initializers: Sequence[_protocols.TensorProtocol] = (),
initializers: Sequence[_protocols.ValueProtocol] = (),
doc_string: str | None = None,
opset_imports: dict[str, int] | None = None,
name: str | None = None,
Expand Down
75 changes: 75 additions & 0 deletions onnxscript/ir/_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Load and save ONNX models."""

from __future__ import annotations

__all__ = ["load", "save"]

import os
from typing import Iterator

import onnx

from onnxscript.ir import _core, _enums, _protocols, serde, traversal


def _all_tensors(
    graph: _core.Graph | _core.GraphView, include_constants: bool = False
) -> Iterator[_protocols.TensorProtocol]:
    """Iterate over all tensors in the graph.

    Args:
        graph: The graph whose tensors to iterate over.
        include_constants: Whether to also yield tensors stored in node
            attributes (``TENSOR`` / ``TENSORS``), found by walking the
            graph (and its subgraphs) recursively.

    Yields:
        Tensors bound to initializers and, optionally, tensor-valued
        node attributes.
    """
    # Yield all tensors in initializers
    for value in graph.initializers.values():
        if value.const_value is not None:
            yield value.const_value
    if not include_constants:
        return
    # Look at constant attributes in nodes
    for node in traversal.RecursiveGraphIterator(graph):
        for attr in node.attributes.values():
            # Reference attributes carry no concrete value to yield.
            if isinstance(attr, _core.RefAttr):
                continue
            if attr.type == _enums.AttributeType.TENSOR and attr.value is not None:
                yield attr.value
            elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None:
                yield from attr.value


def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
    """Load an ONNX model from a file.

    External data is not loaded by ONNX itself; the IR memory-maps it
    directly, so only the base directory needs to be recorded on each
    external tensor.

    Args:
        path: The path to the ONNX file.
        format: The format of the file (e.g. protobuf, textproto, json, etc.).
            If None, the format is inferred from the file extension.

    Returns:
        The loaded model.
    """
    proto = onnx.load(path, format=format, load_external_data=False)
    model = serde.deserialize_model(proto)
    # Relative external-data locations resolve against the directory
    # containing the model file.
    model_dir = os.path.dirname(path)
    for tensor in _all_tensors(model.graph, include_constants=True):
        if isinstance(tensor, _core.ExternalTensor):
            tensor.base_dir = model_dir
    return model


def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None:
    """Save an ONNX model to a file.

    Args:
        model: The model to save.
        path: The path to save the model to.
        format: The format of the file (e.g. protobuf, textproto, json, etc.).
            If None, the format is inferred from the file extension.
    """
    # TODO(justinchuby): Handle external data
    serialized = serde.serialize_model(model)
    onnx.save(serialized, path, format=format)
Loading