diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 80df83bbf..b9266ea1f 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -71,6 +71,9 @@ # Pass infrastructure "passes", "traversal", + # IO + "load", + "save", ] from onnxscript.ir import passes, serde, traversal @@ -114,6 +117,7 @@ AttributeType, DataType, ) +from onnxscript.ir._io import load, save from onnxscript.ir._protocols import ( ArrayCompatible, AttributeProtocol, diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index b5a29cdd4..f1f5c9350 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -475,6 +475,8 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable= Attributes: path: The path to the data file. This can be a relative path or an absolute path. + base_dir: The base directory for the external data. It is used to resolve relative paths. + At serialization, only the ``path`` is serialized into the "location" field of the TensorProto. offset: The offset in bytes from the start of the file. length: The length of the data in bytes. dtype: The data type of the tensor. 
@@ -509,8 +511,15 @@ def __init__( name: str, doc_string: str | None = None, metadata_props: dict[str, str] | None = None, + base_dir: os.PathLike | str = "", ) -> None: - self._path = path + if os.path.isabs(path): + self._base_dir = os.path.dirname(path) + self._path = os.path.basename(path) + else: + self._base_dir = base_dir + self._path = path + self._offset: int | None = offset self._length: int | None = length self._dtype: _enums.DataType = dtype @@ -528,6 +537,15 @@ def path(self) -> str | os.PathLike: # Immutable return self._path + @property + def base_dir(self) -> str | os.PathLike: + # Mutable + return self._base_dir + + @base_dir.setter + def base_dir(self, value: str | os.PathLike) -> None: + self._base_dir = value + @property def offset(self) -> int | None: # Immutable @@ -556,7 +574,8 @@ def _load(self): return # Map the whole file into the memory # TODO(justinchuby): Verify if this would exhaust the memory address space - with open(self._path, "rb") as f: + file_path = os.path.join(self._base_dir, self._path) + with open(file_path, "rb") as f: self.raw = mmap.mmap( f.fileno(), 0, @@ -599,7 +618,10 @@ def __dlpack_device__(self) -> tuple[int, int]: ) def __repr__(self) -> str: - return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})" + return ( + f"{self._repr_base()}(path='{self._path}', name={self.name!r}, " + f"offset={self._offset!r}, length={self._length!r}, base_dir={self._base_dir!r})" + ) def numpy(self) -> np.ndarray: """Return the tensor as a numpy array. 
@@ -2069,7 +2091,7 @@ def __init__( outputs: Sequence[Value], *, nodes: Iterable[Node], - initializers: Sequence[_protocols.TensorProtocol] = (), + initializers: Sequence[_protocols.ValueProtocol] = (), doc_string: str | None = None, opset_imports: dict[str, int] | None = None, name: str | None = None, diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 1fbbca692..c284fa365 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -243,6 +243,23 @@ def test_initialize(self): # Ensure repeated reads are consistent np.testing.assert_equal(tensor, self.data) + def test_initialize_with_relative_path(self): + external_tensor = self.model.graph.initializer[0] + external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) + tensor = _core.ExternalTensor( + path=external_info.location, + offset=external_info.offset, + length=external_info.length, + dtype=ir.DataType.FLOAT, + name="input", + shape=_core.Shape(external_tensor.dims), + base_dir=pathlib.Path(self.base_path), + ) + self.assertEqual(tensor.dtype, ir.DataType.FLOAT) + np.testing.assert_equal(tensor, self.data) + # Ensure repeated reads are consistent + np.testing.assert_equal(tensor, self.data) + def test_totypes_returns_correct_data_in(self): external_tensor = self.model.graph.initializer[0] external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) diff --git a/onnxscript/ir/_external_data.py b/onnxscript/ir/_external_data.py new file mode 100644 index 000000000..3d19bae5c --- /dev/null +++ b/onnxscript/ir/_external_data.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+"""External data related utilities.""" + +from __future__ import annotations + +__all__ = ["set_base_dir"] + +import os +from typing import Iterator + +from onnxscript.ir import _core, _enums, _protocols, traversal + + +def _all_tensors( + graph: _core.Graph | _core.GraphView, include_attributes: bool = False +) -> Iterator[_protocols.TensorProtocol]: + """Iterate over all tensors in the graph. + + Args: + graph: The graph to traverse tensors on. + include_attributes: Whether to include tensors in attributes. + + Yields: + Tensors in the graph. + """ + # Yield all tensors in initializers + for value in graph.initializers.values(): + if value.const_value is not None: + yield value.const_value + if not include_attributes: + return + # Look at constant attributes in nodes + for node in traversal.RecursiveGraphIterator(graph): + for attr in node.attributes.values(): + if isinstance(attr, _core.RefAttr): + continue + if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: + yield attr.value + elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: + yield from attr.value + + +def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None: + """Set the base directory for external data in a graph. + + Args: + graph: The graph to traverse tensors on. + base_dir: The base directory. This is the directory where the ONNX file is. + """ + for tensor in _all_tensors(graph, include_attributes=True): + if isinstance(tensor, _core.ExternalTensor): + tensor.base_dir = base_dir diff --git a/onnxscript/ir/_external_data_test.py b/onnxscript/ir/_external_data_test.py new file mode 100644 index 000000000..624f7e0a5 --- /dev/null +++ b/onnxscript/ir/_external_data_test.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import unittest + +import onnx +import onnx.external_data_helper + +from onnxscript import ir +from onnxscript.ir import _external_data + + +class ExternalDataTest(unittest.TestCase): + def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): + attr_tensor = onnx.helper.make_tensor( + name="test_constant", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=b"\x01\x00\x00\x00", + raw=True, + ) + graph = onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node( + "Constant", + [], + ["test"], + value=attr_tensor, + ) + ], + name="test", + inputs=[], + outputs=[], + initializer=[ + onnx.helper.make_tensor( + name="test_tensor", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=b"\x01\x00\x00\x00", + raw=True, + ), + ], + ) + model_proto = onnx.helper.make_model(graph) + onnx.external_data_helper.convert_model_to_external_data( + model_proto, location="tempdir", size_threshold=0, convert_attribute=True + ) + model = ir.serde.deserialize_model(model_proto) + expected_dir = "something_else" + _external_data.set_base_dir(model.graph, expected_dir) + + initializer_tensor = model.graph.initializers["test_tensor"].const_value + assert isinstance(initializer_tensor, ir.ExternalTensor) + self.assertEqual(initializer_tensor.base_dir, expected_dir) + attr_tensor = model.graph.node(0).attributes["value"].value + self.assertEqual(attr_tensor.base_dir, expected_dir) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py new file mode 100644 index 000000000..a9c867f3f --- /dev/null +++ b/onnxscript/ir/_io.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Load and save ONNX models.""" + +from __future__ import annotations + +__all__ = ["load", "save"] + +import os + +import onnx + +from onnxscript.ir import _core, _external_data, serde + + +def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: + """Load an ONNX model from a file. 
+ + Args: + path: The path to the ONNX file. + format: The format of the file (e.g. protobuf, textproto, json, etc.). + If None, the format is inferred from the file extension. + + Returns: + The loaded model. + """ + # Do not use ONNX to load external data because the IR handles external data + # by doing memory mapping directly. + proto = onnx.load(path, format=format, load_external_data=False) + model = serde.deserialize_model(proto) + base_dir = os.path.dirname(path) + # Set the base directory for external data to the directory of the ONNX file + # so that relative paths are resolved correctly. + _external_data.set_base_dir(model.graph, base_dir) + return model + + +def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None: + """Save an ONNX model to a file. + + Args: + model: The model to save. + path: The path to save the model to. + format: The format of the file (e.g. protobuf, textproto, json, etc.). + If None, the format is inferred from the file extension. 
+ """ + proto = serde.serialize_model(model) + onnx.save(proto, path, format=format) + # TODO(justinchuby): Handle external data when the relative path has changed + # TODO(justinchuby): Handle offloading external data to disk when saving diff --git a/onnxscript/ir/traversal.py b/onnxscript/ir/traversal.py index 5951506fe..5fa9a9acf 100644 --- a/onnxscript/ir/traversal.py +++ b/onnxscript/ir/traversal.py @@ -8,17 +8,19 @@ "RecursiveGraphIterator", ] -from typing import Callable, Iterator, Reversible +from typing import Callable, Iterator, Reversible, Union from typing_extensions import Self from onnxscript.ir import _core, _enums +GraphLike = Union[_core.Graph, _core.Function, _core.GraphView] + class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]): def __init__( self, - graph: _core.Graph | _core.Function | _core.GraphView, + graph_like: GraphLike, *, recursive: Callable[[_core.Node], bool] | None = None, reverse: bool = False, @@ -26,15 +28,15 @@ def __init__( """Iterate over the nodes in the graph, recursively visiting subgraphs. Args: - graph: The graph to traverse. + graph_like: The graph to traverse. recursive: A callback that determines whether to recursively visit the subgraphs contained in a node. If not provided, all nodes in subgraphs are visited. reverse: Whether to iterate in reverse order. """ - self._graph = graph + self._graph = graph_like self._recursive = recursive self._reverse = reverse - self._iterator = self._recursive_node_iter(graph) + self._iterator = self._recursive_node_iter(graph_like) def __iter__(self) -> Self: self._iterator = self._recursive_node_iter(self._graph)