I want to type-check a class (a TypedDict in the example below) with the contract that, at any instant, all shape variables (e.g. B, T, H, W) are the same for all tensors in the class, while transforms such as crop_sample are still allowed to modify it. Is there a way to rebind H and W after such shape-altering operations?
from typing import TypedDict

import torch
from jaxtyping import Float32, jaxtyped
from typeguard import typechecked


@jaxtyped(typechecker=typechecked)
class MyDict(TypedDict, total=False):
    foo1: Float32[torch.Tensor, "B T 3 H W"]
    foo2: Float32[torch.Tensor, "B T 3 H W"]
    baz: Float32[torch.Tensor, "B 1 4 4"]


@jaxtyped(typechecker=typechecked)
def crop_sample(sample: MyDict) -> MyDict:
    # Apply the same crop to every tensor: drop the first 50 rows and columns,
    # so H and W go from 100 to 50 in the example below
    h_start, w_start = 50, 50
    foo1 = sample["foo1"][:, :, :, h_start:, w_start:]
    foo2 = sample["foo2"][:, :, :, h_start:, w_start:]
    sample["foo1"] = foo1
    sample["foo2"] = foo2
    return sample


if __name__ == "__main__":
    my_dict = MyDict(foo1=torch.randn(1, 1, 3, 100, 100), foo2=torch.randn(1, 1, 3, 100, 100))
    print(my_dict["foo1"].shape)
    my_dict = crop_sample(my_dict)
    print(my_dict["foo1"].shape)
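One possible workaround, offered as a hedged sketch rather than a documented jaxtyping feature: jaxtyping binds dimension names such as H and W per call of a @jaxtyped function, so instead of annotating crop_sample's return value with the same MyDict, the cropped result could be validated in a separate call, which starts with fresh bindings. The dependency-free code below models that idea on plain shape tuples so it runs without torch; SPEC, check_sample and crop_shapes are hypothetical names, not part of jaxtyping or typeguard.

```python
from typing import Dict, Tuple

# Hypothetical spec mirroring MyDict's annotations; digits are fixed sizes.
SPEC: Dict[str, str] = {
    "foo1": "B T 3 H W",
    "foo2": "B T 3 H W",
    "baz": "B 1 4 4",
}


def check_sample(shapes: Dict[str, Tuple[int, ...]]) -> Dict[str, int]:
    """Check all present fields against SPEC using ONE shared set of bindings.

    The bindings dict is created fresh on every call, so H and W may take
    different values in different calls (e.g. before and after a crop).
    """
    bindings: Dict[str, int] = {}
    for key, shape in shapes.items():
        names = SPEC[key].split()
        if len(names) != len(shape):
            raise ValueError(f"{key}: expected {len(names)} dims, got {len(shape)}")
        for name, size in zip(names, shape):
            if name.isdigit():
                # Fixed-size dimension such as the 3 in "B T 3 H W".
                if int(name) != size:
                    raise ValueError(f"{key}: expected size {name}, got {size}")
            elif bindings.setdefault(name, size) != size:
                # Named dimension already bound to a different size in this call.
                raise ValueError(
                    f"{key}: {name}={size} conflicts with {name}={bindings[name]}"
                )
    return bindings


def crop_shapes(
    shapes: Dict[str, Tuple[int, ...]], h_start: int, w_start: int
) -> Dict[str, Tuple[int, ...]]:
    """Shape effect of x[..., h_start:, w_start:] (non-negative starts only)."""
    out: Dict[str, Tuple[int, ...]] = {}
    for key, shape in shapes.items():
        if SPEC[key].endswith("H W"):
            *lead, h, w = shape
            out[key] = (*lead, max(h - h_start, 0), max(w - w_start, 0))
        else:
            out[key] = shape
    return out


before = {"foo1": (1, 1, 3, 100, 100), "foo2": (1, 1, 3, 100, 100)}
print(check_sample(before))  # H and W bound to 100
after = crop_shapes(before, 50, 50)
print(check_sample(after))  # fresh call: H and W bound to 50
```

Because each check_sample call starts from an empty bindings dict, H and W simply rebind to 50 after the crop, while a dict that mixes 100- and 50-pixel tensors still fails. In the torch version, the analogous step would be a separate @jaxtyped(typechecker=typechecked) function that takes a MyDict argument and is called on the output of crop_sample.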