diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 51c6d83502..898855edef 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -12,7 +12,6 @@ from __future__ import annotations -import abc import contextlib import dataclasses import heapq @@ -40,7 +39,7 @@ import ml_dtypes import numpy as np -from typing_extensions import TypeIs +from typing_extensions import Buffer, TypeIs import onnxscript from onnxscript.ir import ( @@ -95,7 +94,7 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]: return hasattr(obj, "__dlpack__") -class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable): +class TensorBase(Buffer, _protocols.TensorProtocol, _display.PrettyPrintable): """Convenience Shared methods for classes implementing TensorProtocol.""" __slots__ = () @@ -111,6 +110,13 @@ def _repr_base(self) -> str: """ return f"{self.__class__.__name__}<{self._printable_type_shape()}>" + def __buffer__(self, flags: int, /) -> memoryview: + """Return a memoryview of the tensor. + + This is used to support the buffer protocol. + """ + return self.tobytes().__buffer__(flags) + @property def size(self) -> int: """The number of elements in the tensor.""" @@ -408,6 +414,29 @@ def __dlpack_device__(self) -> tuple[int, int]: def __repr__(self) -> str: return f"{self._repr_base()}({self._raw!r}, name={self.name!r})" + def __buffer__(self, flags: int, /) -> memoryview: + """Return a memoryview of the tensor. + + This is used to support the buffer protocol. + """ + if self.dtype in { + _enums.DataType.INT4, + _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, + }: + # Packing is required. So we call tobytes() directly + return self.tobytes().__buffer__(flags) + + # Otherwise get the memoryview from the numpy array + array = self.numpy() + if not array.data.c_contiguous: + array = np.ascontiguousarray(array) + assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" + if not _IS_LITTLE_ENDIAN: + # Need to copy because we are returning the underlying data directly + array = array.view(array.dtype.newbyteorder("<")).copy() + return array.__buffer__(flags) + @property def dtype(self) -> _enums.DataType: """The data type of the tensor. Immutable.""" @@ -657,6 +686,19 @@ def __array__(self, dtype: Any = None) -> np.ndarray: assert self._array is not None return self._array.__array__(dtype) + def __buffer__(self, flags: int, /) -> memoryview: + """Return a memoryview of the tensor. + + This is used to support the buffer protocol. + """ + self._check_validity() + if self.raw is None: + self._load() + assert self.raw is not None + offset = self._offset or 0 + length = self._length or self.nbytes + return memoryview(self.raw)[offset : offset + length] + def __dlpack__(self, *, stream: Any = None) -> Any: raise NotImplementedError( "ExternalTensor does not support DLPack because it uses memory mapping. " @@ -953,6 +995,13 @@ def __dlpack_device__(self) -> tuple[int, int]: def __repr__(self) -> str: return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})" + def __buffer__(self, flags: int, /) -> memoryview: + """Return a memoryview of the tensor. + + This is used to support the buffer protocol. + """ + return self._evaluate().__buffer__(flags) + @property def raw(self) -> Callable[[], _protocols.TensorProtocol]: return self._func diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index fbc2c7c054..eaf7037635 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -133,6 +133,10 @@ def __array__(self, dtype: Any = None) -> np.ndarray: """Return the tensor as a numpy array, compatible with np.array.""" ... + def __buffer__(self, flags: int, /) -> memoryview: + """Return a view of the tensor data.""" + ... + def __dlpack__(self, *, stream: Any = ...) -> Any: """Return PyCapsule.""" ... diff --git a/onnxscript/ir/external_data.py b/onnxscript/ir/external_data.py index 4ca9ca5036..3adb580fc4 100644 --- a/onnxscript/ir/external_data.py +++ b/onnxscript/ir/external_data.py @@ -173,14 +173,14 @@ def _write_external_data( for tensor, tensor_info in zip(tensors, external_data_infos, strict=True): current_offset = tensor_info.offset assert tensor is not None - raw_data = tensor.tobytes() - if isinstance(tensor, _core.ExternalTensor): - tensor.release() # Pad file to required offset if needed file_size = data_file.tell() if current_offset > file_size: data_file.write(b"\0" * (current_offset - file_size)) - data_file.write(raw_data) + with memoryview(tensor) as view: + data_file.write(view) + if isinstance(tensor, _core.ExternalTensor): + tensor.release() def _create_external_tensor( diff --git a/onnxscript/ir/tensor_adapters.py b/onnxscript/ir/tensor_adapters.py index 0a74e0a74c..eba04e7e96 100644 --- a/onnxscript/ir/tensor_adapters.py +++ b/onnxscript/ir/tensor_adapters.py @@ -79,18 +79,22 @@ def __init__( def numpy(self) -> npt.NDArray: import torch - self.raw: torch.Tensor + # Calling .contiguous() is usually less costly than calling it on numpy arrays + # so we do it first for users assuming a contiguous array is needed for most usages + torch_tensor: torch.Tensor = self.raw + if not torch_tensor.is_contiguous(): + torch_tensor = torch_tensor.contiguous() if self.dtype == ir.DataType.BFLOAT16: - return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy()) + return torch_tensor.view(torch.uint16).numpy(force=True).view(self.dtype.numpy()) if self.dtype in { ir.DataType.FLOAT8E4M3FN, ir.DataType.FLOAT8E4M3FNUZ, ir.DataType.FLOAT8E5M2, ir.DataType.FLOAT8E5M2FNUZ, }: - return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy()) + return torch_tensor.view(torch.uint8).numpy(force=True).view(self.dtype.numpy()) - return self.raw.numpy(force=True) + return torch_tensor.numpy(force=True) def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray: del copy # Unused, but needed for the signature @@ -98,6 +102,25 @@ def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray: return self.numpy() return self.numpy().__array__(dtype) + def __buffer__(self, flags: int, /) -> memoryview: + """Return a memoryview of the tensor. + + This is used to support the buffer protocol. + """ + if self.dtype in { + ir.DataType.INT4, + ir.DataType.UINT4, + ir.DataType.FLOAT4E2M1, + }: + # Packing is required. So we call tobytes() directly + return self.tobytes().__buffer__(flags) + + # Otherwise get the memoryview from the numpy array + array = self.numpy() + assert array.data.c_contiguous, "Bug: The array should be contiguous" + assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" + return array.__buffer__(flags) + def tobytes(self) -> bytes: # Implement tobytes to support native PyTorch types so we can use types like bloat16 # Reading from memory directly is also more efficient because