Questions about TorchScript support #30

Open
YutackPark opened this issue Nov 24, 2024 · 5 comments

@YutackPark

Hi, I'm exploring the possibility of using cuEquivariance-Torch in a C++ environment, similar to how e3nn models can be exported via TorchScript. I have a few questions:

  1. Can cuEquivariance modules be exported using TorchScript?
  2. If it is technically feasible but not yet officially supported, are there any plans to support it?

I attempted to use both torch.jit.script and torch.jit.trace. The former raises errors, while the latter produces warnings. If this is unexpected, I can attach a minimal example to reproduce it. Before diving deeper into debugging, I wanted to confirm whether there are any related development plans or known limitations.

I'm aware that TorchScript may be deprecated in the future. However, its replacement, torch.export (https://pytorch.org/docs/stable/export.html), is still marked as unstable, so I have no other option.

Lastly, I've noticed that using torch.compile with static tensor shapes can nearly double the performance of e3nn. Could cuEquivariance achieve similar speedups with torch.compile, or is this approach less relevant given its use of optimized custom kernels?

Thanks in advance for your guidance and support!

@mitkotak
Contributor

mitkotak commented Nov 24, 2024

I had a similar question regarding torch.compile: does it make sense to put a torch.compile decorator over the FX fallback? Happy to put in a PR.

@borisfom
Collaborator

borisfom commented Dec 2, 2024

@YutackPark: cuEquivariance should be JIT-exportable. We have a jit.trace() unit test in place for the underlying cuequivariance_ops_torch, but not for this repo yet.
torch.jit.script will fail if you use it on cuEquivariance modules directly. Unfortunately, script()'s restrictions on the arguments of exported methods work against API usability, and we consider a cuEquivariance module being the top-level exported module an uncommon use case.
If you do wish to script() a cuEquivariance module rather than the enclosing model, please write a simple wrapper that does not use a variable number of arguments or keyword-only arguments with defaults, and script() that. If that fails, please let us know and submit a repro case.
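
For readers following along, the wrapper idea might look roughly like the sketch below. This is only an illustration under stated assumptions: `StandInOp` is a plain-PyTorch placeholder, not a cuEquivariance class, and `FixedSignatureWrapper` is a hypothetical name. Whether scripting succeeds for a real module also depends on that module's own internals being scriptable.

```python
import torch
from torch import nn, Tensor


class StandInOp(nn.Module):
    """Plain-PyTorch stand-in for the wrapped module (NOT a cuEquivariance class)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        # Bilinear-style contraction: (x @ W) * y, summed over the feature axis.
        return ((x @ self.weight) * y).sum(dim=-1)


class FixedSignatureWrapper(nn.Module):
    """Hypothetical wrapper exposing only fixed, positional Tensor arguments
    (no *args, no keyword-only arguments with defaults)."""

    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return self.inner(x, y)


torch.manual_seed(0)
wrapper = FixedSignatureWrapper(StandInOp(dim=4))
x, y = torch.randn(3, 4), torch.randn(3, 4)

eager_out = wrapper(x, y)

# Script the wrapper (compiles forward and, recursively, the inner module).
scripted = torch.jit.script(wrapper)

# Tracing records the ops executed for these example inputs instead.
traced = torch.jit.trace(wrapper, (x, y))
```

Either result can then be written out with `torch.jit.save(scripted, "wrapper.pt")` and loaded from C++ with `torch::jit::load`; the file name here is just an example.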

@borisfom
Collaborator

@YutackPark: actually, we were able to fix script() compatibility (with changes to the API).
Please check it out; it is already merged: #40

@Yangxinsix

Hi Boris,

I tried the newest version, which fixed torch.compile() and script() compatibility. torch.compile now works fine, but script() is still not working. The problem seems to come from API issues in cuequivariance_ops_torch, which is not included in this repo.

Is there a feasible way to avoid the variable number of arguments or keyword-only arguments in cuequivariance_ops_torch, or is a newer version available?

Thank you for your support!

@borisfom
Collaborator

@Yangxinsix: thanks for the input! Please watch for an updated cuequivariance_ops_torch package; it should come out soon.
