Skip to content

Commit 47335b8

Browse files
authored
feat: add support for CUDA 12 and CUDA 13 (#2094)
1 parent 05d19d1 commit 47335b8

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

README.md

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,21 +232,29 @@ See the [docs](https://num.pyro.ai/en/latest/contrib.html#stein-variational-infe
232232
233233
To install NumPyro with the latest CPU version of JAX, you can use pip:
234234

235-
```
235+
```bash
236236
pip install numpyro
237237
```
238238

239239
In case of compatibility issues arise during execution of the above command, you can instead force the installation of a known
240240
compatible CPU version of JAX with
241241

242-
```
242+
```bash
243243
pip install 'numpyro[cpu]'
244244
```
245245

246-
To use **NumPyro on the GPU**, you need to install CUDA first and then use the following pip command:
246+
To use **NumPyro on the GPU**, you need to install CUDA first, and based on your CUDA version, you can use the following pip command:
247247

248+
For **CUDA 12.x.y**:
249+
250+
```bash
251+
pip install 'numpyro[cuda12]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
248252
```
249-
pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
253+
254+
For **CUDA 13.x.y**:
255+
256+
```bash
257+
pip install 'numpyro[cuda13]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
250258
```
251259

252260
If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/jax-ml/jax#pip-installation-gpu-cuda).
@@ -261,7 +269,7 @@ you can install NumPyro using the `pip install numpyro` command.
261269
262270
You can also install NumPyro from source:
263271

264-
```
272+
```bash
265273
git clone https://github.com/pyro-ppl/numpyro.git
266274
cd numpyro
267275
# install jax/jaxlib first for CUDA support
@@ -270,7 +278,7 @@ pip install -e '.[dev]' # contains additional dependencies for NumPyro developm
270278

271279
You can also install NumPyro with conda:
272280

273-
```
281+
```bash
274282
conda install -c conda-forge numpyro
275283
```
276284

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@
8787
# TPU and CUDA installations, currently require to add package repository URL, i.e.,
8888
# pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_releases.html
8989
"tpu": f"jax[tpu]{_jax_version_constraints}",
90-
"cuda": f"jax[cuda]{_jax_version_constraints}",
90+
"cuda12": f"jax[cuda12]{_jax_version_constraints}",
91+
"cuda13": f"jax[cuda13]{_jax_version_constraints}",
9192
},
9293
python_requires=">=3.9",
9394
long_description=long_description,

0 commit comments

Comments
 (0)