From 87d7c4fbd59dc2398a232d959aa5f2e9df3707a2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 14:34:44 -0700 Subject: [PATCH] [IR] Implement save/load functions in IR and handle external data properly (#1801) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement efficient save/load and handle loading external data properly in the IR. Before this change, when a ModelProto containing external data was converted to IR, the external tensor objects would load the data from a path relative to the working directory rather than relative to the directory of the ONNX file. This was because we did not store the ONNX file path and thus had no way to look for the external data file. With this change, a `base_dir` property is added to ExternalTensor that we can set, in a separate pass when the directory is available, so the object has full information to find the data file on disk. The base_dir is not serialized to the proto to maintain a relative path in the "location" field in TensorProto. 
https://github.com/microsoft/onnxscript/issues/1701, https://github.com/microsoft/onnxscript/issues/1792 Example: ``` >>> m.graph.initializers["model.model.decoder.layers.2.encoder_attn.v_proj.weight"].const_value.display() ExternalTensor(path='model.onnx.data', name='model.model.decoder.layers.2.encoder_attn.v_proj.weight', offset=245864448, length=1048576, base_dir='/home/justinchu/dev/ONNXConverter/docker/dump_bash_bench/BlenderbotSmallForConditionalGeneration-torch -onnx-detailed-cpu-') Min: -0.08586505800485611, Max: 0.09103105217218399, NaN count: 0, Inf count: 0 Sparsity (abs<1e-06): 0.00 Histogram: 11504 ┼ 10226 ┤ ╭───────╮ 8948 ┤ ╭─╯ ╰─╮ 7670 ┤ ╭─╯ ╰─╮ 6392 ┤ ╭─╯ ╰─╮ 5113 ┤ ╭─╯ ╰─╮ 3835 ┤ ╭─╯ ╰─╮ 2557 ┤ ╭──╯ ╰─╮ 1279 ┤ ╭────╯ ╰────╮ 1 ┼────────────────╯ ╰─────────────────── -0.0859 -0.0682 -0.0505 -0.0306 -0.0129 0.0070 0.0225 0.0402 0.0557 0.0733 0.0910 ``` --- onnxscript/ir/__init__.py | 4 ++ onnxscript/ir/_core.py | 30 ++++++++++++-- onnxscript/ir/_core_test.py | 17 ++++++++ onnxscript/ir/_external_data.py | 53 +++++++++++++++++++++++++ onnxscript/ir/_external_data_test.py | 59 ++++++++++++++++++++++++++++ onnxscript/ir/_io.py | 50 +++++++++++++++++++++++ onnxscript/ir/traversal.py | 12 +++--- 7 files changed, 216 insertions(+), 9 deletions(-) create mode 100644 onnxscript/ir/_external_data.py create mode 100644 onnxscript/ir/_external_data_test.py create mode 100644 onnxscript/ir/_io.py diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 80df83bbf..b9266ea1f 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -71,6 +71,9 @@ # Pass infrastructure "passes", "traversal", + # IO + "load", + "save", ] from onnxscript.ir import passes, serde, traversal @@ -114,6 +117,7 @@ AttributeType, DataType, ) +from onnxscript.ir._io import load, save from onnxscript.ir._protocols import ( ArrayCompatible, AttributeProtocol, diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index b5a29cdd4..f1f5c9350 100644 --- 
a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -475,6 +475,8 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable= Attributes: path: The path to the data file. This can be a relative path or an absolute path. + base_dir: The base directory for the external data. It is used to resolve relative paths. + At serialization, only the ``path`` is serialized into the "location" field of the TensorProto. offset: The offset in bytes from the start of the file. length: The length of the data in bytes. dtype: The data type of the tensor. @@ -509,8 +511,15 @@ def __init__( name: str, doc_string: str | None = None, metadata_props: dict[str, str] | None = None, + base_dir: os.PathLike | str = "", ) -> None: - self._path = path + if os.path.isabs(path): + self._base_dir = os.path.dirname(path) + self._path = os.path.basename(path) + else: + self._base_dir = base_dir + self._path = path + self._offset: int | None = offset self._length: int | None = length self._dtype: _enums.DataType = dtype @@ -528,6 +537,15 @@ def path(self) -> str | os.PathLike: # Immutable return self._path + @property + def base_dir(self) -> str | os.PathLike: + # Mutable + return self._base_dir + + @base_dir.setter + def base_dir(self, value: str | os.PathLike) -> None: + self._base_dir = value + @property def offset(self) -> int | None: # Immutable @@ -556,7 +574,8 @@ def _load(self): return # Map the whole file into the memory # TODO(justinchuby): Verify if this would exhaust the memory address space - with open(self._path, "rb") as f: + file_path = os.path.join(self._base_dir, self._path) + with open(file_path, "rb") as f: self.raw = mmap.mmap( f.fileno(), 0, @@ -599,7 +618,10 @@ def __dlpack_device__(self) -> tuple[int, int]: ) def __repr__(self) -> str: - return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})" + return ( + f"{self._repr_base()}(path='{self._path}', name={self.name!r}, " + 
f"offset={self._offset!r}, length={self._length!r}, base_dir={self._base_dir!r})" + ) def numpy(self) -> np.ndarray: """Return the tensor as a numpy array. @@ -2069,7 +2091,7 @@ def __init__( outputs: Sequence[Value], *, nodes: Iterable[Node], - initializers: Sequence[_protocols.TensorProtocol] = (), + initializers: Sequence[_protocols.ValueProtocol] = (), doc_string: str | None = None, opset_imports: dict[str, int] | None = None, name: str | None = None, diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 1fbbca692..c284fa365 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -243,6 +243,23 @@ def test_initialize(self): # Ensure repeated reads are consistent np.testing.assert_equal(tensor, self.data) + def test_initialize_with_relative_path(self): + external_tensor = self.model.graph.initializer[0] + external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) + tensor = _core.ExternalTensor( + path=external_info.location, + offset=external_info.offset, + length=external_info.length, + dtype=ir.DataType.FLOAT, + name="input", + shape=_core.Shape(external_tensor.dims), + base_dir=pathlib.Path(self.base_path), + ) + self.assertEqual(tensor.dtype, ir.DataType.FLOAT) + np.testing.assert_equal(tensor, self.data) + # Ensure repeated reads are consistent + np.testing.assert_equal(tensor, self.data) + def test_totypes_returns_correct_data_in(self): external_tensor = self.model.graph.initializer[0] external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) diff --git a/onnxscript/ir/_external_data.py b/onnxscript/ir/_external_data.py new file mode 100644 index 000000000..3d19bae5c --- /dev/null +++ b/onnxscript/ir/_external_data.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+"""External data related utilities.""" + +from __future__ import annotations + +__all__ = ["set_base_dir"] + +import os +from typing import Iterator + +from onnxscript.ir import _core, _enums, _protocols, traversal + + +def _all_tensors( + graph: _core.Graph | _core.GraphView, include_attributes: bool = False +) -> Iterator[_protocols.TensorProtocol]: + """Iterate over all tensors in the graph. + + Args: + graph: The graph to traverse tensors on. + include_attributes: Whether to include tensors in attributes. + + Yields: + Tensors in the graph. + """ + # Yield all tensors in initializers + for value in graph.initializers.values(): + if value.const_value is not None: + yield value.const_value + if not include_attributes: + return + # Look at constant attributes in nodes + for node in traversal.RecursiveGraphIterator(graph): + for attr in node.attributes.values(): + if isinstance(attr, _core.RefAttr): + continue + if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: + yield attr.value + elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: + yield from attr.value + + +def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None: + """Set the base directory for external data in a graph. + + Args: + graph: The graph to traverse tensors on. + base_dir: The base directory. This is the directory where the ONNX file is. + """ + for tensor in _all_tensors(graph, include_attributes=True): + if isinstance(tensor, _core.ExternalTensor): + tensor.base_dir = base_dir diff --git a/onnxscript/ir/_external_data_test.py b/onnxscript/ir/_external_data_test.py new file mode 100644 index 000000000..624f7e0a5 --- /dev/null +++ b/onnxscript/ir/_external_data_test.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import unittest + +import onnx +import onnx.external_data_helper + +from onnxscript import ir +from onnxscript.ir import _external_data + + +class ExternalDataTest(unittest.TestCase): + def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): + attr_tensor = onnx.helper.make_tensor( + name="test_constant", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=b"\x01\x00\x00\x00", + raw=True, + ) + graph = onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node( + "Constant", + [], + ["test"], + value=attr_tensor, + ) + ], + name="test", + inputs=[], + outputs=[], + initializer=[ + onnx.helper.make_tensor( + name="test_tensor", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=b"\x01\x00\x00\x00", + raw=True, + ), + ], + ) + model_proto = onnx.helper.make_model(graph) + onnx.external_data_helper.convert_model_to_external_data( + model_proto, location="tempdir", size_threshold=0, convert_attribute=True + ) + model = ir.serde.deserialize_model(model_proto) + expected_dir = "something_else" + _external_data.set_base_dir(model.graph, expected_dir) + + initializer_tensor = model.graph.initializers["test_tensor"].const_value + assert isinstance(initializer_tensor, ir.ExternalTensor) + self.assertEqual(initializer_tensor.base_dir, expected_dir) + attr_tensor = model.graph.node(0).attributes["value"].value + self.assertEqual(attr_tensor.base_dir, expected_dir) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py new file mode 100644 index 000000000..a9c867f3f --- /dev/null +++ b/onnxscript/ir/_io.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Load and save ONNX models.""" + +from __future__ import annotations + +__all__ = ["load", "save"] + +import os + +import onnx + +from onnxscript.ir import _core, _external_data, serde + + +def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: + """Load an ONNX model from a file. 
+ + Args: + path: The path to the ONNX file. + format: The format of the file (e.g. protobuf, textproto, json, etc.). + If None, the format is inferred from the file extension. + + Returns: + The loaded model. + """ + # Do not use ONNX to load external data because the IR handles external data + # by doing memory mapping directly. + proto = onnx.load(path, format=format, load_external_data=False) + model = serde.deserialize_model(proto) + base_dir = os.path.dirname(path) + # Set the base directory for external data to the directory of the ONNX file + # so that relative paths are resolved correctly. + _external_data.set_base_dir(model.graph, base_dir) + return model + + +def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None: + """Save an ONNX model to a file. + + Args: + model: The model to save. + path: The path to save the model to. + format: The format of the file (e.g. protobuf, textproto, json, etc.). + If None, the format is inferred from the file extension. 
+ """ + proto = serde.serialize_model(model) + onnx.save(proto, path, format=format) + # TODO(justinchuby): Handle external data when the relative path has changed + # TODO(justinchuby): Handle off loading external data to disk when saving diff --git a/onnxscript/ir/traversal.py b/onnxscript/ir/traversal.py index 5951506fe..5fa9a9acf 100644 --- a/onnxscript/ir/traversal.py +++ b/onnxscript/ir/traversal.py @@ -8,17 +8,19 @@ "RecursiveGraphIterator", ] -from typing import Callable, Iterator, Reversible +from typing import Callable, Iterator, Reversible, Union from typing_extensions import Self from onnxscript.ir import _core, _enums +GraphLike = Union[_core.Graph, _core.Function, _core.GraphView] + class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]): def __init__( self, - graph: _core.Graph | _core.Function | _core.GraphView, + graph_like: GraphLike, *, recursive: Callable[[_core.Node], bool] | None = None, reverse: bool = False, @@ -26,15 +28,15 @@ def __init__( """Iterate over the nodes in the graph, recursively visiting subgraphs. Args: - graph: The graph to traverse. + graph_like: The graph to traverse. recursive: A callback that determines whether to recursively visit the subgraphs contained in a node. If not provided, all nodes in subgraphs are visited. reverse: Whether to iterate in reverse order. """ - self._graph = graph + self._graph = graph_like self._recursive = recursive self._reverse = reverse - self._iterator = self._recursive_node_iter(graph) + self._iterator = self._recursive_node_iter(graph_like) def __iter__(self) -> Self: self._iterator = self._recursive_node_iter(self._graph)