Skip to content

Commit 540696c

Browse files
authored
[API] Create stable APIs for PyTorch 2.5 (#1832)
Create stable APIs for PyTorch 2.5 so that it does not need to use any internal ONNX Script APIs. Created APIs are ``` "check_model", "convert_version", "get_torchlib_ops", "optimize", "save_model_with_external_data", ``` In pytorch, it is expected to write: ```python import onnxscript._framework_apis.torch_2_5 ``` Fixes #1827
1 parent e037aa0 commit 540696c

3 files changed

Lines changed: 164 additions & 1 deletion

File tree

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Semi-private stable APIs for framework-specific usage only."""
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Stable APIs for PyTorch 2.5."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"check_model",
9+
"convert_version",
10+
"get_torchlib_ops",
11+
"optimize",
12+
"save_model_with_external_data",
13+
]
14+
15+
import dataclasses
16+
import os
17+
import pathlib
18+
from typing import Callable
19+
20+
import onnx
21+
22+
from onnxscript import ir
23+
from onnxscript.function_libs.torch_lib import registration
24+
from onnxscript.ir import _external_data
25+
26+
# Internal flag. Will go away.
27+
_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
28+
os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") == "1"
29+
)
30+
31+
32+
@dataclasses.dataclass(frozen=True)
33+
class _OnnxFunctionMeta:
34+
"""A wrapper of onnx-script function with additional metadata.
35+
36+
qualified_name: The qualified name of the aten operator.
37+
function: The onnx-script function.
38+
domain: The domain of the function.
39+
name: The name of the function.
40+
is_complex: Whether the function is a complex function.
41+
"""
42+
43+
qualified_name: str
44+
function: Callable
45+
domain: str
46+
name: str
47+
is_complex: bool = False
48+
49+
50+
def optimize(model: ir.Model) -> ir.Model:
51+
"""Optimize the model."""
52+
53+
# TODO(justinchuby): Use the optimizer
54+
return model
55+
56+
57+
def convert_version(model: ir.Model, target_version: int) -> ir.Model:
58+
"""Convert the model to the specified ONNX opset version."""
59+
# model_version = model.opset_import.get("")
60+
# if model_version == target_version:
61+
# # No conversion needed
62+
# return model
63+
64+
# # FIXME(justinchuby): version_converter does not support functions
65+
# proto = ir.serde.serialize_model(model)
66+
# proto = onnx.version_converter.convert_version(proto, target_version)
67+
# return ir.serde.deserialize_model(proto)
68+
# TODO(justinchuby): This function needs to be carefully implemented
69+
# to handle large models. For now, we just return the model.
70+
del target_version # Unused
71+
return model
72+
73+
74+
def check_model(model: ir.Model) -> None:
75+
"""Check the model."""
76+
77+
del model # Unused yet
78+
79+
80+
def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
81+
"""Save the model with external data. The model is unchanged after saving."""
82+
83+
# TODO(#1835): Decide if we want to externalize large attributes as well
84+
if _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
85+
initializer_values = tuple(model.graph.initializers.values())
86+
tensors = [v.const_value for v in initializer_values]
87+
for tensor in tensors:
88+
if tensor is None:
89+
raise ValueError(
90+
"The model contains uninitialized initializer values. "
91+
"Please make sure all initializer values are initialized."
92+
)
93+
destination_path = pathlib.Path(model_path)
94+
base_dir = destination_path.parent
95+
data_path = f"{destination_path.name}.data"
96+
97+
external_tensors = _external_data.convert_tensors_to_external(
98+
tensors, # type: ignore[arg-type]
99+
base_dir,
100+
data_path,
101+
)
102+
103+
# Replace the initializer values with external tensors and save the model
104+
for initializer, external_tensor in zip(initializer_values, external_tensors):
105+
initializer.const_value = external_tensor
106+
ir.save(model, model_path)
107+
108+
# Restore the original initializer values so the model is unchanged
109+
for initializer, tensor in zip(initializer_values, tensors):
110+
initializer.const_value = tensor
111+
112+
else:
113+
destination_path = pathlib.Path(model_path)
114+
# Create the directory if it does not exist
115+
data_path = f"{destination_path.name}.data"
116+
proto = ir.serde.serialize_model(model)
117+
onnx.save_model(
118+
proto,
119+
model_path,
120+
save_as_external_data=True,
121+
location=data_path,
122+
)
123+
124+
125+
def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
126+
# Trigger op registration
127+
from onnxscript.function_libs.torch_lib import ( # pylint: disable=import-outside-toplevel
128+
ops,
129+
)
130+
131+
del ops # Unused
132+
133+
torchlib_registry = registration.default_registry
134+
function_metas = []
135+
136+
for qualified_name, aten_overloads_func in torchlib_registry.items():
137+
if qualified_name.startswith("internal::"):
138+
# Skip the custom defined internal functions
139+
continue
140+
141+
for overload_func in aten_overloads_func.overloads:
142+
function_meta = _OnnxFunctionMeta(
143+
qualified_name=qualified_name,
144+
function=overload_func,
145+
domain=overload_func.function_ir.domain,
146+
name=overload_func.name,
147+
is_complex=False,
148+
)
149+
function_metas.append(function_meta)
150+
for complex_func in aten_overloads_func.complex:
151+
function_meta = _OnnxFunctionMeta(
152+
qualified_name=qualified_name,
153+
function=complex_func,
154+
domain=complex_func.function_ir.domain,
155+
name=complex_func.name,
156+
is_complex=True,
157+
)
158+
function_metas.append(function_meta)
159+
160+
return function_metas

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ ignore-init-module-imports = true
210210
[tool.ruff.lint.per-file-ignores]
211211
"__init__.py" = ["TID252"] # Allow relative imports in init files
212212
"setup.py" = ["TID251"] # pathlib is allowed in supporting code
213-
"**/{examples,tests,docs,tools,utils,opgen}/*" = ["TID251"] # pathlib is allowed in supporting code
213+
"**/{examples,tests,docs,tools,utils,opgen,_framework_apis}/*" = ["TID251"] # pathlib is allowed in supporting code
214214
"**/*_test.py" = ["TID251"] # pathlib is allowed in tests
215215

216216
[tool.ruff.lint.flake8-tidy-imports]

0 commit comments

Comments
 (0)