We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fa946d8 · commit 4e564d5 — Copy full SHA for 4e564d5
pyproject.toml
@@ -105,16 +105,15 @@ jax_cpu = [
105
jax_gpu = [
106
"jax[cuda12]==0.7.0",
107
"algoperf[jax_core_deps]",
108
- "nvidia-cudnn-cu12==9.10.2.21", # temporary workaround for https://github.com/jax-ml/jax/issues/30663
109
]
110
111
pytorch_cpu = [
112
"torch==2.5.1",
113
"torchvision==0.20.1",
114
]
115
pytorch_gpu = [
116
- "torch==2.5.1",
117
- "torchvision==0.20.1",
+ "torch==2.9.0",
+ "torchvision==0.24.0",
118
] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA.
119
120
###############################################################################
0 commit comments