Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Support safetensors in tensor adapters #1933

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions onnxscript/ir/tensor_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
# pylint: disable=import-outside-toplevel

# NOTE: DO NOT import any framework-specific modules here in the global namespace.
# NOTE: We use ir.DataType instead of _enums.DataType to show users how they
# should create custom tensor adapters. This is fine and will not create
# circular imports because the ir.DataType's are not used in the global namespace.

from __future__ import annotations

Expand All @@ -38,13 +41,20 @@
import numpy.typing as npt

from onnxscript import ir
from onnxscript.ir import _core, _enums

if TYPE_CHECKING:
import torch


class TorchTensor(ir.Tensor):
def __init__(self, tensor: torch.Tensor, name: str | None = None):
class TorchTensor(_core.Tensor):
def __init__(
self,
tensor: torch.Tensor,
name: str | None = None,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
) -> None:
# Pass the tensor as the raw data to ir.Tensor's constructor
import torch

Expand All @@ -69,7 +79,13 @@
torch.uint32: ir.DataType.UINT32,
torch.uint64: ir.DataType.UINT64,
}
super().__init__(tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name)
super().__init__(
tensor,
dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype],
name=name,
doc_string=doc_string,
metadata_props=metadata_props,
)

def numpy(self) -> npt.NDArray:
import torch
Expand Down Expand Up @@ -115,3 +131,49 @@
tensor.data_ptr()
)
)


class SafeTensorsTensor(_core.Tensor):
    """An :class:`ir.Tensor` backed by a single tensor stored in a SafeTensors file.

    The tensor data is read once at construction time via ``safetensors.safe_open``.
    Because SafeTensors memory-maps the file, loading eagerly here does not
    consume extra memory beyond the mapping itself.
    """

    def __init__(
        self,
        path: str,
        tensor_name: str,
        /,
        dtype: _enums.DataType | None = None,
        *,
        shape: _core.Shape | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Create a tensor from a tensor stored in a SafeTensors file.

        Args:
            path: The path to the SafeTensors file.
            tensor_name: The name of the tensor in the SafeTensors file.
            dtype: The data type of the tensor. It can be specified if the value
                is not of a standard NumPy dtype.
            shape: The shape of the tensor. It can be specified if the value
                is not of a standard NumPy dtype.
            name: The name of the ONNX tensor.
            doc_string: The documentation string for the tensor.
            metadata_props: The metadata properties for the tensor.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            Exception: Propagated from safetensors if ``tensor_name`` is not
                present in the file or the file is malformed.
        """
        # Imported lazily so that safetensors is only required when this
        # adapter is actually used (see module-level NOTE on imports).
        import safetensors

        # Kept for debuggability / provenance; not used after __init__.
        self._path = path
        self._tensor_name = tensor_name

        # safe_open implements the context-manager protocol at runtime, but
        # its published type stubs do not declare __enter__/__exit__, which
        # trips mypy's attr-defined check — hence the targeted ignore.
        with safetensors.safe_open(  # type: ignore[attr-defined]
            path, framework="numpy"
        ) as f:
            # The tensor is mmap'ed in memory so we might as well load it
            # at initialization time since it does not take up any extra memory
            array = f.get_tensor(tensor_name)

        super().__init__(
            array,
            dtype=dtype,
            shape=shape,
            name=name,
            doc_string=doc_string,
            metadata_props=metadata_props,
        )
Loading