Skip to content

Commit de6bb76

Browse files
committed
Merge branch 'a100' into lm_workload
2 parents 42d91ac + 4e564d5 commit de6bb76

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,15 @@ jax_cpu = [
108108
jax_gpu = [
109109
"jax[cuda12]==0.7.0",
110110
"algoperf[jax_core_deps]",
111-
"nvidia-cudnn-cu12==9.10.2.21", # temporary workaround for https://github.com/jax-ml/jax/issues/30663
112111
]
113112

114113
pytorch_cpu = [
115114
"torch==2.5.1",
116115
"torchvision==0.20.1"
117116
]
118117
pytorch_gpu = [
119-
"torch==2.5.1",
120-
"torchvision==0.20.1",
118+
"torch==2.9.0",
119+
"torchvision==0.24.0",
121120
] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA.
122121

123122
###############################################################################

0 commit comments

Comments
 (0)