From ba9605d291f474376b6e6144670ab2d53b221cd2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 17:36:01 +0000 Subject: [PATCH 01/17] [IR] Implement efficient save/load --- onnxscript/ir/__init__.py | 4 ++++ onnxscript/ir/_io.py | 46 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 onnxscript/ir/_io.py diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 80df83bbf..9fcf9d6db 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -71,6 +71,9 @@ # Pass infrastructure "passes", "traversal", + # IO + "load", + "save", ] from onnxscript.ir import passes, serde, traversal @@ -134,6 +137,7 @@ ValueProtocol, ) from onnxscript.ir.serde import TensorProtoTensor, from_proto, to_proto +from onnxscript.ir._io import load, save def __set_module() -> None: diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py new file mode 100644 index 000000000..9200a6103 --- /dev/null +++ b/onnxscript/ir/_io.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Load and save ONNX models.""" + +from __future__ import annotations + +import os + +import onnx + +from onnxscript.ir import _core, serde + + +def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: + """Load an ONNX model from a file. + + Args: + path: The path to the ONNX file. + format: The format of the file (e.g. protobuf, textproto, json, etc.). + If None, the format is inferred from the file extension. + + Returns: + The loaded model. + """ + # Do not use ONNX to load external data because the IR handles external data + # by doing memory mapping directly. + proto = onnx.load(path, format=format, load_external_data=False) + model = serde.deserialize_model(proto) + base_dir = os.path.dirname(path) + # Set the base directory for external data to the directory of the ONNX file + # so that relative paths are resolved correctly. + _external_data.set_base_dir(model, base_dir) + return model + + +def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None: + """Save an ONNX model to a file. + + Args: + model: The model to save. + path: The path to save the model to. + format: The format of the file (e.g. protobuf, textproto, json, etc.). + If None, the format is inferred from the file extension. + """ + onnx_model = serde.serialize_model(model) + onnx.save(onnx_model, path, format=format) From 4272a1f36e73dfa050534f4ba0e89d07bc23aa8d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 17:36:44 +0000 Subject: [PATCH 02/17] lint --- onnxscript/ir/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 9fcf9d6db..b9266ea1f 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -117,6 +117,7 @@ AttributeType, DataType, ) +from onnxscript.ir._io import load, save from onnxscript.ir._protocols import ( ArrayCompatible, AttributeProtocol, @@ -137,7 +138,6 @@ ValueProtocol, ) from onnxscript.ir.serde import TensorProtoTensor, from_proto, to_proto -from onnxscript.ir._io import load, save def __set_module() -> None: From 2b2325404fa6763c81b644e2d3940a049ff1bc74 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 18:20:04 +0000 Subject: [PATCH 03/17] Save --- onnxscript/ir/_core.py | 21 +++++++++++++++++++-- onnxscript/ir/_io.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index b5a29cdd4..0fdbc7037 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -475,6 +475,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable= Attributes: path: The path to the data file. This can be a relative path or an absolute path. + base_dir: The base directory for the external data. It is used to resolve relative paths. offset: The offset in bytes from the start of the file. length: The length of the data in bytes. dtype: The data type of the tensor. @@ -509,8 +510,15 @@ def __init__( name: str, doc_string: str | None = None, metadata_props: dict[str, str] | None = None, + base_dir: os.PathLike | str = "", ) -> None: - self._path = path + if os.path.isabs(path): + self._base_dir = os.path.dirname(path) + self._path = os.path.basename(path) + else: + self._base_dir = base_dir + self._path = path + self._offset: int | None = offset self._length: int | None = length self._dtype: _enums.DataType = dtype @@ -528,6 +536,15 @@ def path(self) -> str | os.PathLike: # Immutable return self._path + @property + def base_dir(self) -> str | os.PathLike: + # Mutable + return self._base_dir + + @base_dir.setter + def base_dir(self, value: str | os.PathLike) -> None: + self._base_dir = value + @property def offset(self) -> int | None: # Immutable @@ -2069,7 +2086,7 @@ def __init__( outputs: Sequence[Value], *, nodes: Iterable[Node], - initializers: Sequence[_protocols.TensorProtocol] = (), + initializers: Sequence[_protocols.ValueProtocol] = (), doc_string: str | None = None, opset_imports: dict[str, int] | None = None, name: str | None = None, diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py index 9200a6103..567829690 100644 --- a/onnxscript/ir/_io.py +++ b/onnxscript/ir/_io.py @@ -5,10 +5,34 @@ from __future__ import annotations import os +from typing import Iterator import onnx -from onnxscript.ir import _core, serde +from onnxscript.ir import _core, _enums, _protocols, serde, traversal + + +def _all_tensors( + graph: _core.Graph | _core.GraphView, include_constants: bool = False +) -> Iterator[_protocols.TensorProtocol]: + """Iterate over all tensors in the graph.""" + + # Yield all tensors in initializers + for value in graph.initializers.values(): + if value.const_value is not None: + yield value.const_value + if not include_constants: + return + # Look at constant attributes in nodes + for node in traversal.RecursiveGraphIterator(graph): + for attr in node.attributes.values(): + if isinstance(attr, _core.RefAttr): + continue + if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: + yield attr.value + elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: + for value in attr.value: + yield value def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: @@ -29,7 +53,9 @@ def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: base_dir = os.path.dirname(path) # Set the base directory for external data to the directory of the ONNX file # so that relative paths are resolved correctly. - _external_data.set_base_dir(model, base_dir) + for tensor in _all_tensors(model.graph, include_constants=True): + if isinstance(tensor, _core.ExternalTensor): + tensor.base_dir = base_dir return model From b973bef71b80877f33bc642944f8b0cd22e20f85 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 18:20:25 +0000 Subject: [PATCH 04/17] all --- onnxscript/ir/_io.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py index 567829690..3e401d6a1 100644 --- a/onnxscript/ir/_io.py +++ b/onnxscript/ir/_io.py @@ -4,6 +4,8 @@ from __future__ import annotations +__all__ = ["load", "save"] + import os from typing import Iterator From c1c50f9a8e2f6e1840b8c474e6e1954b0a10ecc7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 18:20:54 +0000 Subject: [PATCH 05/17] proto --- onnxscript/ir/_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py index 3e401d6a1..509328957 100644 --- a/onnxscript/ir/_io.py +++ b/onnxscript/ir/_io.py @@ -70,5 +70,5 @@ def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) format: The format of the file (e.g. protobuf, textproto, json, etc.). If None, the format is inferred from the file extension. """ - onnx_model = serde.serialize_model(model) - onnx.save(onnx_model, path, format=format) + proto = serde.serialize_model(model) + onnx.save(proto, path, format=format) From 10e229edeff8ec4a4e7525e12c058f2c13c85bfa Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 18:29:00 +0000 Subject: [PATCH 06/17] load --- onnxscript/ir/_core.py | 9 +++++++-- onnxscript/ir/_io.py | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 0fdbc7037..f1f5c9350 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -476,6 +476,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable= Attributes: path: The path to the data file. This can be a relative path or an absolute path. base_dir: The base directory for the external data. It is used to resolve relative paths. + At serialization, only the ``path`` is serialized into the "location" field of the TensorProto. offset: The offset in bytes from the start of the file. length: The length of the data in bytes. dtype: The data type of the tensor. @@ -573,7 +574,8 @@ def _load(self): return # Map the whole file into the memory # TODO(justinchuby): Verify if this would exhaust the memory address space - with open(self._path, "rb") as f: + file_path = os.path.join(self._base_dir, self._path) + with open(file_path, "rb") as f: self.raw = mmap.mmap( f.fileno(), 0, @@ -616,7 +618,10 @@ def __dlpack_device__(self) -> tuple[int, int]: ) def __repr__(self) -> str: - return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})" + return ( + f"{self._repr_base()}(path='{self._path}', name={self.name!r}, " + f"offset={self._offset!r}, length={self._length!r}, base_dir={self._base_dir!r})" + ) def numpy(self) -> np.ndarray: """Return the tensor as a numpy array. diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py index 509328957..0333a09ac 100644 --- a/onnxscript/ir/_io.py +++ b/onnxscript/ir/_io.py @@ -72,3 +72,4 @@ def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) """ proto = serde.serialize_model(model) onnx.save(proto, path, format=format) + # TODO(justinchuby): Handle external data From 6ff6d2a3c55573ad108ce249da91e4ebcffc4977 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 18:41:14 +0000 Subject: [PATCH 07/17] fix --- onnxscript/ir/_io.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py index 0333a09ac..33a4b9224 100644 --- a/onnxscript/ir/_io.py +++ b/onnxscript/ir/_io.py @@ -33,8 +33,7 @@ def _all_tensors( if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: yield attr.value elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: - for value in attr.value: - yield value + yield from attr.value def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: From c121639cacb1877a70fd4f16f2d8c56004be0cbf Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 18:45:04 +0000 Subject: [PATCH 08/17] Refactor --- onnxscript/ir/_external_data.py | 49 +++++++++++++++++++++++++++++++++ onnxscript/ir/_io.py | 29 ++----------------- 2 files changed, 51 insertions(+), 27 deletions(-) create mode 100644 onnxscript/ir/_external_data.py diff --git a/onnxscript/ir/_external_data.py b/onnxscript/ir/_external_data.py new file mode 100644 index 000000000..c75f40872 --- /dev/null +++ b/onnxscript/ir/_external_data.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""External data related utilities.""" + +from __future__ import annotations + + +import os +from typing import Iterator + + + +import onnx + +from onnxscript.ir import _core, _enums, _protocols, serde, traversal + + +def _all_tensors( + graph: _core.Graph | _core.GraphView, include_constants: bool = False +) -> Iterator[_protocols.TensorProtocol]: + """Iterate over all tensors in the graph.""" + + # Yield all tensors in initializers + for value in graph.initializers.values(): + if value.const_value is not None: + yield value.const_value + if not include_constants: + return + # Look at constant attributes in nodes + for node in traversal.RecursiveGraphIterator(graph): + for attr in node.attributes.values(): + if isinstance(attr, _core.RefAttr): + continue + if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: + yield attr.value + elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: + yield from attr.value + + +def set_base_dir(graph: _core.Graph, base_dir: str | os.PathLike) -> None: + """Set the base directory for external data in a model. + + Args: + model: The model. + base_dir: The base directory. + """ + for tensor in _all_tensors(graph, include_constants=True): + if isinstance(tensor, _core.ExternalTensor): + tensor.base_dir = base_dir diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py index 33a4b9224..ce196d60a 100644 --- a/onnxscript/ir/_io.py +++ b/onnxscript/ir/_io.py @@ -7,33 +7,10 @@ __all__ = ["load", "save"] import os -from typing import Iterator import onnx -from onnxscript.ir import _core, _enums, _protocols, serde, traversal - - -def _all_tensors( - graph: _core.Graph | _core.GraphView, include_constants: bool = False -) -> Iterator[_protocols.TensorProtocol]: - """Iterate over all tensors in the graph.""" - - # Yield all tensors in initializers - for value in graph.initializers.values(): - if value.const_value is not None: - yield value.const_value - if not include_constants: - return - # Look at constant attributes in nodes - for node in traversal.RecursiveGraphIterator(graph): - for attr in node.attributes.values(): - if isinstance(attr, _core.RefAttr): - continue - if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: - yield attr.value - elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: - yield from attr.value +from onnxscript.ir import _core, _external_data, serde def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: @@ -54,9 +31,7 @@ def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: base_dir = os.path.dirname(path) # Set the base directory for external data to the directory of the ONNX file # so that relative paths are resolved correctly. - for tensor in _all_tensors(model.graph, include_constants=True): - if isinstance(tensor, _core.ExternalTensor): - tensor.base_dir = base_dir + _external_data.set_base_dir(model.graph, base_dir) return model From 8624f24c663d5164a63ffd065d1b4604de79fe4b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 18:47:18 +0000 Subject: [PATCH 09/17] format --- onnxscript/ir/_external_data.py | 13 ++++--------- onnxscript/ir/traversal.py | 12 +++++++----- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/onnxscript/ir/_external_data.py b/onnxscript/ir/_external_data.py index c75f40872..bee78a53d 100644 --- a/onnxscript/ir/_external_data.py +++ b/onnxscript/ir/_external_data.py @@ -4,15 +4,10 @@ from __future__ import annotations - import os from typing import Iterator - - -import onnx - -from onnxscript.ir import _core, _enums, _protocols, serde, traversal +from onnxscript.ir import _core, _enums, _protocols, traversal def _all_tensors( @@ -37,11 +32,11 @@ def _all_tensors( yield from attr.value -def set_base_dir(graph: _core.Graph, base_dir: str | os.PathLike) -> None: - """Set the base directory for external data in a model. +def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None: + """Set the base directory for external data in a graph. Args: - model: The model. + graph: The graph. base_dir: The base directory. """ for tensor in _all_tensors(graph, include_constants=True): diff --git a/onnxscript/ir/traversal.py b/onnxscript/ir/traversal.py index 5951506fe..5fa9a9acf 100644 --- a/onnxscript/ir/traversal.py +++ b/onnxscript/ir/traversal.py @@ -8,17 +8,19 @@ "RecursiveGraphIterator", ] -from typing import Callable, Iterator, Reversible +from typing import Callable, Iterator, Reversible, Union from typing_extensions import Self from onnxscript.ir import _core, _enums +GraphLike = Union[_core.Graph, _core.Function, _core.GraphView] + class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]): def __init__( self, - graph: _core.Graph | _core.Function | _core.GraphView, + graph_like: GraphLike, *, recursive: Callable[[_core.Node], bool] | None = None, reverse: bool = False, @@ -26,15 +28,15 @@ def __init__( """Iterate over the nodes in the graph, recursively visiting subgraphs. Args: - graph: The graph to traverse. + graph_like: The graph to traverse. recursive: A callback that determines whether to recursively visit the subgraphs contained in a node. If not provided, all nodes in subgraphs are visited. reverse: Whether to iterate in reverse order. """ - self._graph = graph + self._graph = graph_like self._recursive = recursive self._reverse = reverse - self._iterator = self._recursive_node_iter(graph) + self._iterator = self._recursive_node_iter(graph_like) def __iter__(self) -> Self: self._iterator = self._recursive_node_iter(self._graph) From 62a075c87c6f86f7c97af981fedd5f234a939050 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 18:48:21 +0000 Subject: [PATCH 10/17] docs --- onnxscript/ir/_external_data.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_external_data.py b/onnxscript/ir/_external_data.py index bee78a53d..d5779fb6f 100644 --- a/onnxscript/ir/_external_data.py +++ b/onnxscript/ir/_external_data.py @@ -4,6 +4,8 @@ from __future__ import annotations +__all__ = ["set_base_dir"] + import os from typing import Iterator @@ -36,8 +38,8 @@ def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLi """Set the base directory for external data in a graph. Args: - graph: The graph. - base_dir: The base directory. + graph: The graph to traverse tensors on. + base_dir: The base directory. This is the directory where the ONNX file is. """ for tensor in _all_tensors(graph, include_constants=True): if isinstance(tensor, _core.ExternalTensor): From 997384fa056b3df547b728b409e237ef39358f45 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 18:50:05 +0000 Subject: [PATCH 11/17] docs --- onnxscript/ir/_external_data.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/onnxscript/ir/_external_data.py b/onnxscript/ir/_external_data.py index d5779fb6f..3d19bae5c 100644 --- a/onnxscript/ir/_external_data.py +++ b/onnxscript/ir/_external_data.py @@ -13,15 +13,22 @@ def _all_tensors( - graph: _core.Graph | _core.GraphView, include_constants: bool = False + graph: _core.Graph | _core.GraphView, include_attributes: bool = False ) -> Iterator[_protocols.TensorProtocol]: - """Iterate over all tensors in the graph.""" + """Iterate over all tensors in the graph. + Args: + graph: The graph to traverse tensors on. + include_attributes: Whether to include tensors in attributes. + + Yields: + Tensors in the graph. + """ # Yield all tensors in initializers for value in graph.initializers.values(): if value.const_value is not None: yield value.const_value - if not include_constants: + if not include_attributes: return # Look at constant attributes in nodes for node in traversal.RecursiveGraphIterator(graph): @@ -41,6 +48,6 @@ def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLi graph: The graph to traverse tensors on. base_dir: The base directory. This is the directory where the ONNX file is. """ - for tensor in _all_tensors(graph, include_constants=True): + for tensor in _all_tensors(graph, include_attributes=True): if isinstance(tensor, _core.ExternalTensor): tensor.base_dir = base_dir From b69c3147fcb5fb5242a187413df06d7e4ed2bc70 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 19:04:15 +0000 Subject: [PATCH 12/17] basedir --- onnxscript/ir/_core_test.py | 17 +++++++++++++++++ onnxscript/ir/_io.py | 3 ++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 1fbbca692..c284fa365 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -243,6 +243,23 @@ def test_initialize(self): # Ensure repeated reads are consistent np.testing.assert_equal(tensor, self.data) + def test_initialize_with_relative_path(self): + external_tensor = self.model.graph.initializer[0] + external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) + tensor = _core.ExternalTensor( + path=external_info.location, + offset=external_info.offset, + length=external_info.length, + dtype=ir.DataType.FLOAT, + name="input", + shape=_core.Shape(external_tensor.dims), + base_dir=pathlib.Path(self.base_path), + ) + self.assertEqual(tensor.dtype, ir.DataType.FLOAT) + np.testing.assert_equal(tensor, self.data) + # Ensure repeated reads are consistent + np.testing.assert_equal(tensor, self.data) + def test_totypes_returns_correct_data_in(self): external_tensor = self.model.graph.initializer[0] external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py index ce196d60a..a9c867f3f 100644 --- a/onnxscript/ir/_io.py +++ b/onnxscript/ir/_io.py @@ -46,4 +46,5 @@ def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) """ proto = serde.serialize_model(model) onnx.save(proto, path, format=format) - # TODO(justinchuby): Handle external data + # TODO(justinchuby): Handle external data when the relative path has changed + # TODO(justinchuby): Handle off loading external data to disk when saving From e1cbc008ffab14f0a0827d88c67a851257b7a1e7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 19:27:49 +0000 Subject: [PATCH 13/17] data test --- onnxscript/ir/_external_data_test.py | 60 ++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 onnxscript/ir/_external_data_test.py diff --git a/onnxscript/ir/_external_data_test.py b/onnxscript/ir/_external_data_test.py new file mode 100644 index 000000000..2f80a2d2b --- /dev/null +++ b/onnxscript/ir/_external_data_test.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np +import onnx +import onnx.external_data_helper + +from onnxscript import ir +from onnxscript.ir import _external_data +import tempfile + + +class ExternalDataTest(unittest.TestCase): + def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): + attr_tensor = onnx.helper.make_tensor( + name="test_constant", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=b"\x01\x00\x00\x00", + raw=True, + ) + graph = onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node( + "Constant", + [], + ["test"], + value=attr_tensor, + ) + ], + name="test", + inputs=[], + outputs=[], + initializer=[ + tensor := onnx.helper.make_tensor( + name="test_tensor", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=b"\x01\x00\x00\x00", + raw=True, + ), + ], + ) + model_proto = onnx.helper.make_model(graph) + onnx.external_data_helper.convert_model_to_external_data(model_proto, location="tempdir", size_threshold=0, convert_attribute=True) + model = ir.serde.deserialize_model(model_proto) + expected_dir = "something_else" + _external_data.set_base_dir(model.graph, expected_dir) + assert isinstance( + model.graph.initializers["test_tensor"].const_value, ir.ExternalTensor + ) + self.assertEqual( + model.graph.initializers["test_tensor"].const_value.base_dir, expected_dir + ) + self.assertEqual(model.graph.node(0).attributes["value"].value.base_dir, expected_dir) + + +if __name__ == "__main__": + unittest.main() From d01bb257c724ccfd713c7e42f8ad14da047cfa24 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 19:27:55 +0000 Subject: [PATCH 14/17] lint --- onnxscript/ir/_external_data_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_external_data_test.py b/onnxscript/ir/_external_data_test.py index 2f80a2d2b..62e510407 100644 --- a/onnxscript/ir/_external_data_test.py +++ b/onnxscript/ir/_external_data_test.py @@ -2,13 +2,11 @@ # Licensed under the MIT License. import unittest -import numpy as np import onnx import onnx.external_data_helper from onnxscript import ir from onnxscript.ir import _external_data -import tempfile class ExternalDataTest(unittest.TestCase): @@ -43,7 +41,9 @@ def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): ], ) model_proto = onnx.helper.make_model(graph) - onnx.external_data_helper.convert_model_to_external_data(model_proto, location="tempdir", size_threshold=0, convert_attribute=True) + onnx.external_data_helper.convert_model_to_external_data( + model_proto, location="tempdir", size_threshold=0, convert_attribute=True + ) model = ir.serde.deserialize_model(model_proto) expected_dir = "something_else" _external_data.set_base_dir(model.graph, expected_dir) From fcbec0aa50c06a3c6fa0cd2a6f8c7fc2c29e63be Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 19:28:16 +0000 Subject: [PATCH 15/17] lint --- onnxscript/ir/_external_data_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_external_data_test.py b/onnxscript/ir/_external_data_test.py index 62e510407..9978a3f4b 100644 --- a/onnxscript/ir/_external_data_test.py +++ b/onnxscript/ir/_external_data_test.py @@ -31,7 +31,7 @@ def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): inputs=[], outputs=[], initializer=[ - tensor := onnx.helper.make_tensor( + onnx.helper.make_tensor( name="test_tensor", data_type=onnx.TensorProto.FLOAT, dims=[1], From 7887b9eb06d5b6e041b98c4f0187bbb74d5c0be7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 19:32:17 +0000 Subject: [PATCH 16/17] test --- onnxscript/ir/_external_data_test.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_external_data_test.py b/onnxscript/ir/_external_data_test.py index 9978a3f4b..6a6500c24 100644 --- a/onnxscript/ir/_external_data_test.py +++ b/onnxscript/ir/_external_data_test.py @@ -47,13 +47,16 @@ def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): model = ir.serde.deserialize_model(model_proto) expected_dir = "something_else" _external_data.set_base_dir(model.graph, expected_dir) + + initializer_tensor = model.graph.initializers["test_tensor"].const_value assert isinstance( - model.graph.initializers["test_tensor"].const_value, ir.ExternalTensor + initializer_tensor, ir.ExternalTensor ) self.assertEqual( - model.graph.initializers["test_tensor"].const_value.base_dir, expected_dir + initializer_tensor.base_dir, expected_dir ) - self.assertEqual(model.graph.node(0).attributes["value"].value.base_dir, expected_dir) + attr_tensor = model.graph.node(0).attributes["value"].value + self.assertEqual(attr_tensor.base_dir, expected_dir) if __name__ == "__main__": From fb65fa4bed07fbc071cc4d16ddd3ac00747b4b5f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Aug 2024 21:14:48 +0000 Subject: [PATCH 17/17] lint --- onnxscript/ir/_external_data_test.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/ir/_external_data_test.py b/onnxscript/ir/_external_data_test.py index 6a6500c24..624f7e0a5 100644 --- a/onnxscript/ir/_external_data_test.py +++ b/onnxscript/ir/_external_data_test.py @@ -49,12 +49,8 @@ def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): _external_data.set_base_dir(model.graph, expected_dir) initializer_tensor = model.graph.initializers["test_tensor"].const_value - assert isinstance( - initializer_tensor, ir.ExternalTensor - ) - self.assertEqual( - initializer_tensor.base_dir, expected_dir - ) + assert isinstance(initializer_tensor, ir.ExternalTensor) + self.assertEqual(initializer_tensor.base_dir, expected_dir) attr_tensor = model.graph.node(0).attributes["value"].value self.assertEqual(attr_tensor.base_dir, expected_dir)