|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +"""Stable APIs for PyTorch 2.5.""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +__all__ = [ |
| 8 | + "check_model", |
| 9 | + "convert_version", |
| 10 | + "get_torchlib_ops", |
| 11 | + "optimize", |
| 12 | + "save_model_with_external_data", |
| 13 | +] |
| 14 | + |
| 15 | +import dataclasses |
| 16 | +import os |
| 17 | +import pathlib |
| 18 | +from typing import Callable |
| 19 | + |
| 20 | +import onnx |
| 21 | + |
| 22 | +from onnxscript import ir |
| 23 | +from onnxscript.function_libs.torch_lib import registration |
| 24 | +from onnxscript.ir import _external_data |
| 25 | + |
# Internal flag. Will go away.
# True only when the TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR environment
# variable is exactly "1"; unset or any other value leaves it False.
# The value is read once, when this module is imported.
_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
    os.environ.get("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") == "1"
)
| 30 | + |
| 31 | + |
@dataclasses.dataclass(frozen=True)
class _OnnxFunctionMeta:
    """An onnx-script function bundled with the metadata that describes it.

    Instances are immutable (the dataclass is frozen).

    Attributes:
        qualified_name: The qualified name of the aten operator.
        function: The onnx-script function.
        domain: The domain of the function.
        name: The name of the function.
        is_complex: Whether the function is a complex function.
            Defaults to False.
    """

    qualified_name: str
    function: Callable
    domain: str
    name: str
    is_complex: bool = False
| 48 | + |
| 49 | + |
def optimize(model: ir.Model) -> ir.Model:
    """Optimize the model.

    This is currently a placeholder: no optimization is performed and the
    input is handed back untouched.

    Args:
        model: The model to optimize.

    Returns:
        The same model object that was passed in.
    """
    # TODO(justinchuby): Use the optimizer
    optimized = model
    return optimized
| 55 | + |
| 56 | + |
def convert_version(model: ir.Model, target_version: int) -> ir.Model:
    """Convert the model to the specified ONNX opset version.

    Conversion is not implemented yet; the model is returned as is and
    ``target_version`` is ignored.

    Args:
        model: The model to convert.
        target_version: The requested ONNX opset version (currently unused).

    Returns:
        The same model object that was passed in.
    """
    # TODO(justinchuby): This function needs to be carefully implemented
    # to handle large models. For now, we just return the model.
    #
    # FIXME(justinchuby): version_converter does not support functions.
    # The disabled sketch was: return early when model.opset_import.get("")
    # already equals target_version, otherwise serialize the model, run
    # onnx.version_converter.convert_version on the proto and deserialize
    # the result.
    _ = target_version  # Unused
    return model
| 72 | + |
| 73 | + |
def check_model(model: ir.Model) -> None:
    """Check the model.

    This is currently a placeholder: the model is accepted without any
    validation being performed.

    Args:
        model: The model to check (currently unused).
    """
    _ = model  # Unused yet
    return None
| 78 | + |
| 79 | + |
def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
    """Save the model with external data. The model is unchanged after saving.

    Tensor data is written to a file named ``<model file name>.data``; that
    name is passed as the external data location alongside the model file.
    This function does not itself create the destination directory.

    Which implementation is used depends on the internal
    ``_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR`` flag:

    * flag set: the initializers are converted to external tensors through
      the IR and the model is written with ``ir.save``;
    * flag unset: the model is serialized to a proto and written with
      ``onnx.save_model(..., save_as_external_data=True)``.

    Args:
        model: The model to save.
        model_path: The path the model file is written to.

    Raises:
        ValueError: If the IR-based path is used and any initializer has a
            ``const_value`` of None.
    """
    destination_path = pathlib.Path(model_path)
    data_path = f"{destination_path.name}.data"

    # TODO(#1835): Decide if we want to externalize large attributes as well
    if not _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
        proto = ir.serde.serialize_model(model)
        onnx.save_model(
            proto,
            model_path,
            save_as_external_data=True,
            location=data_path,
        )
        return

    initializer_values = tuple(model.graph.initializers.values())
    tensors = [v.const_value for v in initializer_values]
    if any(tensor is None for tensor in tensors):
        raise ValueError(
            "The model contains uninitialized initializer values. "
            "Please make sure all initializer values are initialized."
        )

    external_tensors = _external_data.convert_tensors_to_external(
        tensors,  # type: ignore[arg-type]
        destination_path.parent,
        data_path,
    )

    try:
        # Replace the initializer values with external tensors and save the model
        for initializer, external_tensor in zip(initializer_values, external_tensors):
            initializer.const_value = external_tensor
        ir.save(model, model_path)
    finally:
        # Restore the original initializer values so the model is unchanged,
        # even when saving fails part-way through.
        for initializer, tensor in zip(initializer_values, tensors):
            initializer.const_value = tensor
| 123 | + |
| 124 | + |
def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
    """Collect metadata for every function in the default torchlib registry.

    Registry entries whose qualified name starts with ``internal::`` are
    skipped. For each remaining entry, one ``_OnnxFunctionMeta`` is produced
    per function in ``overloads`` (``is_complex=False``), followed by one per
    function in ``complex`` (``is_complex=True``).

    Returns:
        The function metadata, in registry iteration order.
    """
    # Trigger op registration
    from onnxscript.function_libs.torch_lib import (  # pylint: disable=import-outside-toplevel
        ops,
    )

    del ops  # Unused

    collected: list[_OnnxFunctionMeta] = []

    for op_name, registered in registration.default_registry.items():
        # Skip the custom defined internal functions
        if op_name.startswith("internal::"):
            continue

        collected.extend(
            _OnnxFunctionMeta(
                qualified_name=op_name,
                function=real_func,
                domain=real_func.function_ir.domain,
                name=real_func.name,
                is_complex=False,
            )
            for real_func in registered.overloads
        )
        collected.extend(
            _OnnxFunctionMeta(
                qualified_name=op_name,
                function=complex_func,
                domain=complex_func.function_ir.domain,
                name=complex_func.name,
                is_complex=True,
            )
            for complex_func in registered.complex
        )

    return collected
0 commit comments