Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Working torch.jit.script() and torch.compile() support #44

Merged
merged 115 commits into from
Jan 6, 2025
Merged
Changes from 1 commit
Commits
Show all changes
115 commits
Select commit Hold shift + click to select a range
ba9580a
test and quick fix for zero batch
mariogeiger Nov 20, 2024
0bfada9
trigger uniform 1d in test
mariogeiger Nov 20, 2024
fd097c6
satisfy linter
mariogeiger Nov 21, 2024
251fc4d
from typing import
mariogeiger Nov 21, 2024
3498a32
determine math_dtype earlier
mariogeiger Nov 21, 2024
7f3cf05
warning with pip commands
mariogeiger Nov 21, 2024
2624335
remove unused argument
mariogeiger Nov 21, 2024
91f7fce
changelog
mariogeiger Nov 21, 2024
4401048
list of inputs
mariogeiger Nov 21, 2024
ad2db8d
add Fixed subtite
mariogeiger Nov 21, 2024
dca96a8
Merge branch 'zero-batch' into list-inputs
mariogeiger Nov 21, 2024
889051a
changelog
mariogeiger Nov 21, 2024
c23816a
Merge branch 'main' into list-inputs
mariogeiger Nov 21, 2024
0487d77
Merge branch 'main' into list-inputs
mariogeiger Dec 3, 2024
bc6b405
add test for torch.jit.script
mariogeiger Dec 3, 2024
c8de185
fix
mariogeiger Dec 3, 2024
5e00b37
Merge branch 'list-inputs' into jit-script
mariogeiger Dec 3, 2024
16e4450
remove keyword-only and import in the forward
mariogeiger Dec 3, 2024
e979b0f
Merge branch 'main' into jit-script
mariogeiger Dec 4, 2024
b2c4fbb
low lvl script tests
mariogeiger Dec 4, 2024
4669a86
TensorProduct working with script()
borisfom Dec 4, 2024
dc9d5b0
add 4 operands tests
mariogeiger Dec 4, 2024
334b460
Unit tests run
borisfom Dec 5, 2024
79e7c5f
Restoring debug logging
borisfom Dec 5, 2024
46a0478
Merge branch 'jit-script' into jit-script
borisfom Dec 5, 2024
401fd53
Merge branch 'jit-script' of github.com:NVIDIA/cuEquivariance into ji…
borisfom Dec 5, 2024
8fce54b
Merge remote-tracking branch 'b/jit-script' into jit-script
borisfom Dec 5, 2024
6c5cdb0
Parameterized script test
borisfom Dec 5, 2024
e21c45f
Fixed transpose for script(), script_test successful
borisfom Dec 5, 2024
779dd9c
Fixed input mutation
borisfom Dec 5, 2024
c315857
Fixed tests
borisfom Dec 6, 2024
ab590c8
format with black
mariogeiger Dec 6, 2024
ec1eb27
format with black
mariogeiger Dec 6, 2024
faf235e
fix tests
mariogeiger Dec 6, 2024
c476af9
fix missing parenthesis
mariogeiger Dec 6, 2024
994b8d9
fix tests: increase torch._dynamo.config.cache_size_limit
mariogeiger Dec 6, 2024
f240eb8
fix docstring tests
mariogeiger Dec 6, 2024
fbfb9d0
replace == by is
mariogeiger Dec 6, 2024
dc20be5
clean use_fallback conditions
mariogeiger Dec 6, 2024
4b201c3
fix
mariogeiger Dec 6, 2024
b5b59b8
fix
mariogeiger Dec 6, 2024
72baf17
Export test added, scripting fallback attempt
borisfom Dec 7, 2024
5a94b09
Merge remote-tracking branch 'b/jit-script' into jit-script
borisfom Dec 7, 2024
6bdf924
Merge branch 'main' into jit-script
mariogeiger Dec 9, 2024
8d31929
enable tests on cpu
mariogeiger Dec 9, 2024
8afa056
fix tests
mariogeiger Dec 9, 2024
09bbc8d
fix ruff
mariogeiger Dec 9, 2024
9c38168
fix
mariogeiger Dec 9, 2024
de9af8f
fix docstring tests
mariogeiger Dec 9, 2024
999a31d
add -x to tests
mariogeiger Dec 9, 2024
8c435fe
Working around torch_tensorrt bugs
borisfom Dec 11, 2024
8271c06
Merge remote-tracking branch 'origin/main' into torch_trt_war
borisfom Dec 11, 2024
ae0bff2
Fixing utils.py import
borisfom Dec 11, 2024
2bad3bc
Adding utils.py
borisfom Dec 11, 2024
7c82e20
Style
borisfom Dec 11, 2024
68a84f8
import nvidia_sphinx_theme
mariogeiger Dec 11, 2024
5ca4edc
spherical harmonics module
mariogeiger Dec 11, 2024
01914dd
fix tests
mariogeiger Dec 11, 2024
50b75dc
test SymmetricContraction export
mariogeiger Dec 11, 2024
bd16dbf
Fixed symmetric_contraction test
borisfom Dec 12, 2024
245e594
add device info
mariogeiger Dec 13, 2024
6f0e1b5
fix sh
mariogeiger Dec 13, 2024
e296279
fix
mariogeiger Dec 13, 2024
dc1f394
skip
mariogeiger Dec 13, 2024
ad91f38
torch._dynamo.config.cache_size_limit = 100
mariogeiger Dec 13, 2024
9d224b0
fix test
mariogeiger Dec 13, 2024
6c03f29
Script compatibility for fallback
borisfom Dec 16, 2024
21da7b0
style
borisfom Dec 16, 2024
d8a4336
Trying to make trace() work
borisfom Dec 16, 2024
6e518f6
Restoring integer cast
borisfom Dec 16, 2024
9410be6
Skipping failing tests
borisfom Dec 16, 2024
d4a0842
disabling cast for fallback
borisfom Dec 16, 2024
c4799c4
Merge remote-tracking branch 'origin/torch_trt_war' into torch_trt_war
borisfom Dec 16, 2024
a6856db
optimize_fallback=use_fallback
mariogeiger Dec 16, 2024
b8be9a2
Fixing the reinterpret cast
borisfom Dec 16, 2024
877d4c7
Merge remote-tracking branch 'b/torch_trt_war' into torch_trt_war
borisfom Dec 16, 2024
3538231
Fixing clone()
borisfom Dec 16, 2024
0c0d7e9
delete broadcast_shapes
mariogeiger Dec 17, 2024
97ca27a
delete _reshape
mariogeiger Dec 17, 2024
64bd41f
rename
mariogeiger Dec 17, 2024
d527020
Using alternative disable type change fixture
borisfom Dec 17, 2024
ad97ae9
Merge remote-tracking branch 'b/torch_trt_war' into torch_trt_war
borisfom Dec 17, 2024
3687da3
Restored assert
borisfom Dec 17, 2024
687dd53
try fix test
mariogeiger Dec 17, 2024
0576ad0
simplify symmetric_tensor_product_test to make test run faster
mariogeiger Dec 17, 2024
66aa108
try to fix some tests
mariogeiger Dec 17, 2024
17df143
Fixing disable_type_conv
borisfom Dec 17, 2024
1c459aa
try fix
mariogeiger Dec 17, 2024
3a7f051
Merge remote-tracking branch 'b/torch_trt_war' into torch_trt_war
borisfom Dec 17, 2024
ef5cd76
Merge remote-tracking branch 'b/torch_trt_war' into torch_trt_war
borisfom Dec 17, 2024
df17042
fix strange bug
mariogeiger Dec 17, 2024
73647fe
Merge remote-tracking branch 'b/torch_trt_war' into torch_trt_war
borisfom Dec 18, 2024
a8516aa
Script fixes for uniform
borisfom Dec 18, 2024
9a256c9
add test_script_tensor_product
mariogeiger Dec 18, 2024
0b37932
Moving all export tests, disabling torch_trt for now
borisfom Dec 18, 2024
a5262bd
more strict input shapes
mariogeiger Dec 18, 2024
59dd354
add back @pytest.mark.parametrize("mode", export_modes)
mariogeiger Dec 18, 2024
8a0f109
fix
mariogeiger Dec 18, 2024
e01a7a9
Fixing noconv bug
borisfom Dec 18, 2024
3fbadb5
Really fixing noconv
borisfom Dec 19, 2024
686782c
fix linear
mariogeiger Dec 19, 2024
d8c1dee
fix rotations
mariogeiger Dec 19, 2024
c837298
fix tpfc
mariogeiger Dec 19, 2024
ac2ed25
fix tpcw
mariogeiger Dec 19, 2024
64c0121
less test
mariogeiger Dec 19, 2024
4e5ddc0
Merge branch 'main' into torch_trt_war
mariogeiger Dec 19, 2024
0080283
remove unused mode in tensor_product_test
mariogeiger Dec 19, 2024
9c07cd2
typo
mariogeiger Dec 19, 2024
c860042
disable export
mariogeiger Dec 19, 2024
df212f5
Reduced export test modes list
borisfom Dec 20, 2024
ed5609d
Merge remote-tracking branch 'b/torch_trt_war' into torch_trt_war
borisfom Dec 20, 2024
25c2eef
Added unit tests to operations and the rest of the primitives
borisfom Dec 21, 2024
de0fa42
Fixing script() for non-internal weights
borisfom Dec 23, 2024
d297764
skip GPU tests when running on CPU
mariogeiger Jan 6, 2025
4af57d8
Fix: if use_fallback is None and cuda is not available => use fallback
mariogeiger Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Restoring debug logging
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed Dec 5, 2024
commit 79e7c5f750879c0b61f1d8b9e8b1ebdd15ada7ac
Original file line number Diff line number Diff line change
@@ -324,9 +324,10 @@ def forward(
i0 = i0.to(torch.int32)
x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u)
x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u)
# logger.debug(
# f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
# )
if not torch.jit.is_scripting():
logger.debug(
f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
)
out = self.f(x1, x0, i0)
out = out.reshape(out.shape[0], out.shape[1] * self.u)
return out