Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions example/intelli_enabled.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Intellisense Enabled\n",
"\n",
"``` python\n",
"class TensorType(torch.Tensor, Generic[Unpack[Ts]], metaclass=_TensorTypeMeta):\n",
"```\n",
"\n",
"`Generic[Unpack[Ts]]` is added as a dummy.\\\n",
"It doesn't have any runtime functionality, but TensorType can provide more informative hover due to the dummy.\n",
"\n",
"See the example code, and the images (from VSCode IDE)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example Code"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torchtyping import TensorType\n",
"\n",
"batch = int\n",
"channel = int\n",
"width = int\n",
"height = int\n",
"\n",
"tensor: TensorType[batch, channel, width, height]\n",
"\n",
"\n",
"def forward(input: TensorType[batch, channel, width, height]):\n",
" pass\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hover Images\n",
"\n",
"**Fig. Hover1**\n",
"\n",
"![hover1](./static/hover1.png)\n",
"\n",
"**Fig. Hover2**\n",
"\n",
"![hover1](./static/hover2.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Advanced Usage\n",
"\n",
"This application doesn't require any additional implementation.\n",
"\n",
"Define a new class as a custom type with docstring. Optionally add `Generic[T]` for the same purpose with `Generic[Unpack[Ts]]`.\n",
"\n",
"Check the example code, and hover images.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example Code"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Generic, TypeVar\n",
"T = TypeVar('T')\n",
"\n",
"class Batch_(int, Generic[T]):\n",
" \"\"\"\n",
" Docstring dedicated to a parameter.<br>\n",
" It is extremely useful when some parameters are reused everywhere.\n",
" \"\"\"\n",
"\n",
"\n",
"\n",
"\n",
"def forward(input: TensorType[Batch_[int], channel, width, height]):\n",
" pass\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hover Images\n",
"\n",
"**Fig. Advanced1**\n",
"\n",
"![advanced1](./static/advanced1.png)\n",
"\n",
"\n",
"**Fig. Advanced2**\n",
"\n",
"![advanced2](./static/advanced2.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reference\n",
"\n",
"Check more examples\n",
"\n",
"- [IntelliType Readme](https://github.com/crimson206/intelli-type/tree/main?tab=readme-ov-file#intellitype).\n",
"- [DeepLearning Example](https://github.com/crimson206/intelli-type/blob/main/example/fusion_block_edit.py)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file added example/static/advanced1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/static/advanced2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/static/hover1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/static/hover2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 4 additions & 2 deletions torchtyping/tensor_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
)
from .utils import frozendict

from typing import Any, NoReturn
from typing import Any, NoReturn, Generic, TypeVarTuple, Unpack

Ts = TypeVarTuple("Ts")

# Annotated is available in python version 3.9 (PEP 593)
if sys.version_info >= (3, 9):
Expand All @@ -36,7 +38,7 @@ def __instancecheck__(cls, obj: Any) -> bool:

# Inherit from torch.Tensor so that IDEs are happy to find methods on functions
# annotated as TensorTypes.
class TensorType(torch.Tensor, metaclass=_TensorTypeMeta):
class TensorType(torch.Tensor, Generic[Unpack[Ts]], metaclass=_TensorTypeMeta):
base_cls = torch.Tensor

def __new__(cls, *args, **kwargs) -> NoReturn:
Expand Down