Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tripy] Add __len__ for tp.Shape and infer length statically when possible #92

Merged
merged 6 commits into from
Aug 20, 2024

Conversation

slyubomirsky
Copy link
Collaborator

For some cases, it is useful to know the length of a tp.Shape without executing the model. This PR adds a method infer_len that allows operators to specify how to statically infer the length of Shape outputs when possible (it is always optional). Test cases are added.

@pranavm-nvidia pranavm-nvidia added the tripy Pull request for the tripy project label Aug 13, 2024
@slyubomirsky
Copy link
Collaborator Author

I can confirm that the CI passes.

@slyubomirsky slyubomirsky force-pushed the tripy-infer-shape-len branch 2 times, most recently from 60d624a to bb41ae4 Compare August 14, 2024 04:40
@parthchadha
Copy link
Collaborator

@slyubomirsky can you rebase your branch? This will help run CI on github.

@slyubomirsky slyubomirsky force-pushed the tripy-infer-shape-len branch from 325ff59 to f65a6ea Compare August 19, 2024 19:51
@slyubomirsky slyubomirsky force-pushed the tripy-infer-shape-len branch from f65a6ea to 3683872 Compare August 20, 2024 03:44
@slyubomirsky slyubomirsky merged commit 9456a8b into NVIDIA:main Aug 20, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tripy Pull request for the tripy project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants