From 95f6e2f15bb468aa4c772701c598c1b93e70e665 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Fri, 4 Nov 2022 15:17:39 -0700 Subject: [PATCH 1/3] Create a protocol class for tensor-like objects. Update the code to use the protocol instead of torch.Tensor directly. --- test/test_torchlike.py | 61 +++++++++++++++++++++++++++++++ torchtyping/__init__.py | 2 +- torchtyping/tensor_details.py | 68 ++++++++++++++++++++++++++++------- torchtyping/tensor_type.py | 3 +- torchtyping/typechecker.py | 12 +++---- 5 files changed, 125 insertions(+), 21 deletions(-) create mode 100644 test/test_torchlike.py diff --git a/test/test_torchlike.py b/test/test_torchlike.py new file mode 100644 index 0000000..bcf85e1 --- /dev/null +++ b/test/test_torchlike.py @@ -0,0 +1,61 @@ +# Test ability to type-check user defined classes which have a "torch-like" interface +# The required interface is defined as the protocol TensorLike in tensor_details.py + +from __future__ import annotations +import pytest +import torch +from torch import rand + +from torchtyping import TensorType, TensorTypeMixin +from typeguard import typechecked + + +# New class that supports the tensor-like interface +class MyTensor: + def __init__(self, tensor: torch.Tensor = torch.zeros(2, 3)): + self.tensor = tensor + self.dtype = self.tensor.dtype + self.layout = "something special" + self.names = self.tensor.names + self.shape = self.tensor.shape + + def is_floating_point(self) -> bool: + return self.dtype == torch.float32 + + # Add tensors and take the mean over the last dimension + # Output drops the last dimension + def __add__(self, o: torch.Tensor) -> MyTensor: + res = self.tensor + o + res_reduced = torch.mean(res, -1) + res_myt = MyTensor(res_reduced) + return res_myt + + +# Create a type corresponding to the new class +class MyTensorType(MyTensor, TensorTypeMixin): + base_cls = MyTensor + + +def test_my_tensor1(): + @typechecked + def func(x: MyTensorType["x", "y"], y: TensorType["x", "y"]) -> MyTensorType["x"]: + return x + y + + @typechecked + def bad_func_spec(x: MyTensorType["x", "y"], y: TensorType["x", "y"]) -> MyTensorType["x", "y"]: + return x + y + + my_t: MyTensor = MyTensor() + func(my_t, rand((2, 3))) + + # Incorrect input dimensions for x + with pytest.raises(TypeError): + func(MyTensor(rand(1)), rand((2, 3))) + + # Incorrect input dimensions for y + with pytest.raises(TypeError): + func(my_t, rand(1)) + + # Incorrect spec for return dimensions + with pytest.raises(TypeError): + bad_func_spec(my_t, rand((2, 3))) diff --git a/torchtyping/__init__.py b/torchtyping/__init__.py index 635ccfa..f17d06f 100644 --- a/torchtyping/__init__.py +++ b/torchtyping/__init__.py @@ -7,7 +7,7 @@ TensorDetail, ) -from .tensor_type import TensorType +from .tensor_type import TensorType, TensorTypeMixin from .typechecker import patch_typeguard __version__ = "0.1.4" diff --git a/torchtyping/tensor_details.py b/torchtyping/tensor_details.py index 56040a2..5a9a9ce 100644 --- a/torchtyping/tensor_details.py +++ b/torchtyping/tensor_details.py @@ -4,24 +4,66 @@ import collections import torch -from typing import Optional, Union +from typing import Optional, Union, runtime_checkable, Protocol, Tuple, Any ellipsis = type(...) +# Define a Protocol (PEP 544) class to represent "tensor-like" objects +# These are objects which support the interface given below +@runtime_checkable +class TensorLike(Protocol): + # We assume the class has a default constructor + def __init__(self): + pass + @property + def dtype(self) -> torch.dtype: + pass + + # leave the layout definition open because tensor-like classes are likely + # to extend it with new storage types + @property + def layout(self) -> Any: + pass + + @property + def names(self) -> Tuple[str, ...]: + pass + + @property + def shape(self) -> Tuple[int, ...]: + pass + + def is_floating_point(self) -> bool: + pass + + +class MyTensor: + def __init__(self): + self.dtype = torch.float32 + self.layout = "very special" + self.names = (None, None) + self.shape = (1, 1) + + def is_floating_point(self): + return self.dtype == torch.float32 + + + + class TensorDetail(metaclass=abc.ABCMeta): @abc.abstractmethod def __repr__(self) -> str: raise NotImplementedError @abc.abstractmethod - def check(self, tensor: torch.Tensor) -> bool: + def check(self, tensor: TensorLike) -> bool: raise NotImplementedError @classmethod @abc.abstractmethod - def tensor_repr(cls, tensor: torch.Tensor) -> str: + def tensor_repr(cls, tensor: TensorLike) -> str: raise NotImplementedError @@ -69,7 +111,7 @@ def __repr__(self) -> str: out += ", is_named" return out - def check(self, tensor: torch.Tensor) -> bool: + def check(self, tensor: TensorLike) -> bool: self_names = [self_dim.name for self_dim in self.dims] self_shape = [self_dim.size for self_dim in self.dims] @@ -103,7 +145,7 @@ def check(self, tensor: torch.Tensor) -> bool: return True @classmethod - def tensor_repr(cls, tensor: torch.Tensor) -> str: + def tensor_repr(cls, tensor: TensorLike) -> str: dims = [] check_names = any(name is not None for name in tensor.names) for name, size in zip(tensor.names, tensor.shape): @@ -133,11 +175,11 @@ def __init__(self, *, dtype, **kwargs) -> None: def __repr__(self) -> str: return repr(self.dtype) - def check(self, tensor: torch.Tensor) -> bool: + def check(self, tensor: TensorLike) -> bool: return self.dtype == tensor.dtype @classmethod - def tensor_repr(cls, tensor: torch.Tensor) -> str: + def tensor_repr(cls, tensor: TensorLike) -> str: return repr(cls(dtype=tensor.dtype)) @@ -149,11 +191,11 @@ def __init__(self, *, layout, **kwargs) -> None: def __repr__(self) -> str: return repr(self.layout) - def check(self, tensor: torch.Tensor) -> bool: + def check(self, tensor: TensorLike) -> bool: return self.layout == tensor.layout @classmethod - def tensor_repr(cls, tensor: torch.Tensor) -> str: + def tensor_repr(cls, tensor: TensorLike) -> str: return repr(cls(layout=tensor.layout)) @@ -161,11 +203,11 @@ class _FloatDetail(TensorDetail): def __repr__(self) -> str: return "is_float" - def check(self, tensor: torch.Tensor) -> bool: + def check(self, tensor: TensorLike) -> bool: return tensor.is_floating_point() @classmethod - def tensor_repr(cls, tensor: torch.Tensor) -> str: + def tensor_repr(cls, tensor: TensorLike) -> str: return "is_float" if tensor.is_floating_point() else "" @@ -177,11 +219,11 @@ class _NamedTensorDetail(TensorDetail): def __repr__(self) -> str: raise RuntimeError - def check(self, tensor: torch.Tensor) -> bool: + def check(self, tensor: TensorLike) -> bool: raise RuntimeError @classmethod - def tensor_repr(cls, tensor: torch.Tensor) -> str: + def tensor_repr(cls, tensor: TensorLike) -> str: raise RuntimeError diff --git a/torchtyping/tensor_type.py b/torchtyping/tensor_type.py index e5e9289..064e463 100644 --- a/torchtyping/tensor_type.py +++ b/torchtyping/tensor_type.py @@ -11,6 +11,7 @@ LayoutDetail, ShapeDetail, TensorDetail, + TensorLike, ) from .utils import frozendict @@ -25,7 +26,7 @@ from typing_extensions import Annotated # Not Type[Annotated...] as we want to use this in instance checks. -_AnnotatedType = type(Annotated[torch.Tensor, ...]) +_AnnotatedType = type(Annotated[TensorLike, ...]) # For use when we have a plain TensorType, without any []. diff --git a/torchtyping/typechecker.py b/torchtyping/typechecker.py index 84d6705..79ffeac 100644 --- a/torchtyping/typechecker.py +++ b/torchtyping/typechecker.py @@ -3,7 +3,7 @@ import torch import typeguard -from .tensor_details import _Dim, _no_name, ShapeDetail +from .tensor_details import _Dim, _no_name, ShapeDetail, TensorLike from .tensor_type import _AnnotatedType from typing import Any, Dict, List, Tuple @@ -60,7 +60,7 @@ def _to_string(name, detail_reprs: List[str]) -> str: def _check_tensor( - argname: str, value: Any, origin: Type[torch.Tensor], metadata: Dict[str, Any] + argname: str, value: Any, origin: TensorLike, metadata: Dict[str, Any] ): details = metadata["details"] if not isinstance(value, origin) or any( @@ -69,7 +69,7 @@ def _check_tensor( expected_string = _to_string( metadata["cls_name"], [repr(detail) for detail in details] ) - if isinstance(value, torch.Tensor): + if isinstance(value, TensorLike): given_string = _to_string( metadata["cls_name"], [detail.tensor_repr(value) for detail in details] ) @@ -253,7 +253,7 @@ def _check_memo(memo): dims.append(_Dim(name=dim.name, size=size)) detail = detail.update(dims=tuple(dims)) _check_tensor( - argname, value, torch.Tensor, {"cls_name": cls_name, "details": [detail]} + argname, value, TensorLike, {"cls_name": cls_name, "details": [detail]} ) @@ -274,7 +274,7 @@ class _CallMemo(typeguard._CallMemo): "name_to_size", "name_to_shape", ) - value_info: List[Tuple[str, torch.Tensor, str, Dict[str, Any]]] + value_info: List[Tuple[str, TensorLike, str, Dict[str, Any]]] name_to_size: Dict[str, int] name_to_shape: Dict[str, Tuple[int]] @@ -301,7 +301,7 @@ def check_type(*args, **kwargs): # Now check if it's annotating a tensor if is_torchtyping_annotation: base_cls, *all_metadata = get_args(expected_type) - if not issubclass(base_cls, torch.Tensor): + if not isinstance(base_cls(), TensorLike): is_torchtyping_annotation = False # Now check if the annotation's metadata is our metadata if is_torchtyping_annotation: From 7cca5574da7b39f43c14ff8c804400ccee7535e8 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Fri, 4 Nov 2022 15:29:05 -0700 Subject: [PATCH 2/3] black reformatting --- test/test_torchlike.py | 8 +++++++- torchtyping/tensor_details.py | 3 +-- torchtyping/typechecker.py | 5 ++--- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/test/test_torchlike.py b/test/test_torchlike.py index bcf85e1..bf5d3e2 100644 --- a/test/test_torchlike.py +++ b/test/test_torchlike.py @@ -36,13 +36,19 @@ class MyTensorType(MyTensor, TensorTypeMixin): base_cls = MyTensor +# make flake8 happy +x = y = None + + def test_my_tensor1(): @typechecked def func(x: MyTensorType["x", "y"], y: TensorType["x", "y"]) -> MyTensorType["x"]: return x + y @typechecked - def bad_func_spec(x: MyTensorType["x", "y"], y: TensorType["x", "y"]) -> MyTensorType["x", "y"]: + def bad_func_spec( + x: MyTensorType["x", "y"], y: TensorType["x", "y"] + ) -> MyTensorType["x", "y"]: return x + y my_t: MyTensor = MyTensor() diff --git a/torchtyping/tensor_details.py b/torchtyping/tensor_details.py index 5a9a9ce..932509e 100644 --- a/torchtyping/tensor_details.py +++ b/torchtyping/tensor_details.py @@ -17,6 +17,7 @@ class TensorLike(Protocol): # We assume the class has a default constructor def __init__(self): pass + @property def dtype(self) -> torch.dtype: pass @@ -50,8 +51,6 @@ def is_floating_point(self): return self.dtype == torch.float32 - - class TensorDetail(metaclass=abc.ABCMeta): @abc.abstractmethod def __repr__(self) -> str: diff --git a/torchtyping/typechecker.py b/torchtyping/typechecker.py index 79ffeac..2afa1aa 100644 --- a/torchtyping/typechecker.py +++ b/torchtyping/typechecker.py @@ -1,6 +1,5 @@ import inspect import sys -import torch import typeguard from .tensor_details import _Dim, _no_name, ShapeDetail, TensorLike @@ -11,9 +10,9 @@ # get_args is available in python version 3.8 # get_type_hints with include_extras parameter is available in 3.9 PEP 593. if sys.version_info >= (3, 9): - from typing import get_type_hints, get_args, Type + from typing import get_type_hints, get_args else: - from typing_extensions import get_type_hints, get_args, Type + from typing_extensions import get_type_hints, get_args # TYPEGUARD PATCHER From ed773ba889dfefd066d219d0f6f6ee45cc641122 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Fri, 4 Nov 2022 16:39:08 -0700 Subject: [PATCH 3/3] Remove default construction requirement. --- torchtyping/tensor_details.py | 4 ---- torchtyping/typechecker.py | 4 +--- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/torchtyping/tensor_details.py b/torchtyping/tensor_details.py index 932509e..e1153d3 100644 --- a/torchtyping/tensor_details.py +++ b/torchtyping/tensor_details.py @@ -14,10 +14,6 @@ # These are objects which support the interface given below @runtime_checkable class TensorLike(Protocol): - # We assume the class has a default constructor - def __init__(self): - pass - @property def dtype(self) -> torch.dtype: pass diff --git a/torchtyping/typechecker.py b/torchtyping/typechecker.py index 2afa1aa..7aa08c1 100644 --- a/torchtyping/typechecker.py +++ b/torchtyping/typechecker.py @@ -297,11 +297,9 @@ def check_type(*args, **kwargs): and hasattr(memo, "value_info") and isinstance(expected_type, _AnnotatedType) ) - # Now check if it's annotating a tensor + # Grab the base class if is_torchtyping_annotation: base_cls, *all_metadata = get_args(expected_type) - if not isinstance(base_cls(), TensorLike): - is_torchtyping_annotation = False # Now check if the annotation's metadata is our metadata if is_torchtyping_annotation: for metadata in all_metadata: