Jaxtyping a class with mutable shapes #303

Open
SConsul opened this issue Feb 28, 2025 · 1 comment
Labels: question (User queries)


SConsul commented Feb 28, 2025

I want to type-check a class (a `TypedDict` in the example below) under the contract that, at any instant, each shape variable (e.g. `B`, `T`, `H`, `W`) has the same value across all tensors in the class, while transforms such as `crop_sample` are allowed to change those values. Is there a way to rebind `H` and `W` after such a shape-altering operation?

```python
from typing import TypedDict

import torch
from jaxtyping import Float32, jaxtyped
from typeguard import typechecked


@jaxtyped(typechecker=typechecked)
class MyDict(TypedDict, total=False):
    # All tensors share B (and the video tensors share T, H, W).
    foo1: Float32[torch.Tensor, "B T 3 H W"]
    foo2: Float32[torch.Tensor, "B T 3 H W"]
    baz: Float32[torch.Tensor, "B 1 4 4"]


@jaxtyped(typechecker=typechecked)
def crop_sample(sample: MyDict) -> MyDict:
    # Ensure we modify all tensors with the same crop
    h_start, w_start = 50, 50
    sample["foo1"] = sample["foo1"][:, :, :, h_start:, w_start:]
    sample["foo2"] = sample["foo2"][:, :, :, h_start:, w_start:]
    # The return value is checked against the H, W bound from the input (100),
    # so the cropped 50x50 tensors fail the check.
    return sample


if __name__ == "__main__":
    my_dict = MyDict(foo1=torch.randn(1, 1, 3, 100, 100), foo2=torch.randn(1, 1, 3, 100, 100))
    print(my_dict["foo1"].shape)
    my_dict = crop_sample(my_dict)
    print(my_dict["foo1"].shape)
```

patrick-kidger (Owner) commented:

I'm afraid there isn't. The intended interpretation of this is to check that the shapes are the same.

patrick-kidger added the question (User queries) label on Feb 28, 2025