Skip to content

Commit 4e564d5

Browse files
committed
update pytorch
1 parent fa946d8 commit 4e564d5

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,15 @@ jax_cpu = [
105105
jax_gpu = [
106106
"jax[cuda12]==0.7.0",
107107
"algoperf[jax_core_deps]",
108-
"nvidia-cudnn-cu12==9.10.2.21", # temporary workaround for https://github.com/jax-ml/jax/issues/30663
109108
]
110109

111110
pytorch_cpu = [
112111
"torch==2.5.1",
113112
"torchvision==0.20.1"
114113
]
115114
pytorch_gpu = [
116-
"torch==2.5.1",
117-
"torchvision==0.20.1",
115+
"torch==2.9.0",
116+
"torchvision==0.24.0",
118117
] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA.
119118

120119
###############################################################################

0 commit comments

Comments
 (0)