2 parents 42d91ac + 4e564d5 commit de6bb76
pyproject.toml
@@ -108,16 +108,15 @@ jax_cpu = [
 jax_gpu = [
   "jax[cuda12]==0.7.0",
   "algoperf[jax_core_deps]",
-  "nvidia-cudnn-cu12==9.10.2.21", # temporary workaround for https://github.com/jax-ml/jax/issues/30663
 ]
 
 pytorch_cpu = [
   "torch==2.5.1",
   "torchvision==0.20.1"
 ]
 
 pytorch_gpu = [
-  "torch==2.5.1",
-  "torchvision==0.20.1",
+  "torch==2.9.0",
+  "torchvision==0.24.0",
 ] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA.
 
 ###############################################################################
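The note on the pytorch_gpu extra means that, because the pinned torch/torchvision wheels carry no CUDA suffix, the resolved wheel will rely on the locally installed CUDA. A minimal sanity check, assuming the extra is installed with something like `pip install -e ".[pytorch_gpu]"` (extra name taken from the diff; the exact install command is an assumption, not part of this commit):

# Verify that the installed torch wheel sees the local CUDA setup
# (assumes the pytorch_gpu extra above has already been installed).
import torch

print("torch version:", torch.__version__)              # 2.9.0 after this commit
print("wheel built against CUDA:", torch.version.cuda)  # CUDA version the wheel was built with
print("CUDA available:", torch.cuda.is_available())     # True if the local driver/toolkit is usable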