@@ -232,21 +232,29 @@ See the [docs](https://num.pyro.ai/en/latest/contrib.html#stein-variational-infe
232232
233233To install NumPyro with the latest CPU version of JAX, you can use pip:
234234
235- ```
235+ ``` bash
236236pip install numpyro
237237```
238238
239239In case of compatibility issues arise during execution of the above command, you can instead force the installation of a known
240240compatible CPU version of JAX with
241241
242- ```
242+ ``` bash
243243pip install ' numpyro[cpu]'
244244```
245245
246- To use ** NumPyro on the GPU** , you need to install CUDA first and then use the following pip command:
246+ To use ** NumPyro on the GPU** , you need to install CUDA first, and based on your CUDA version, you can use the following pip command:
247247
248+ For ** CUDA 12.x.y** :
249+
250+ ``` bash
251+ pip install ' numpyro[cuda12]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
248252```
249- pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
253+
254+ For ** CUDA 13.x.y** :
255+
256+ ``` bash
257+ pip install ' numpyro[cuda13]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
250258```
251259
252260If you need further guidance, please have a look at the [ JAX GPU installation instructions] ( https://github.com/jax-ml/jax#pip-installation-gpu-cuda ) .
@@ -261,7 +269,7 @@ you can install NumPyro using the `pip install numpyro` command.
261269
262270You can also install NumPyro from source:
263271
264- ```
272+ ``` bash
265273git clone https://github.com/pyro-ppl/numpyro.git
266274cd numpyro
267275# install jax/jaxlib first for CUDA support
@@ -270,7 +278,7 @@ pip install -e '.[dev]' # contains additional dependencies for NumPyro developm
270278
271279You can also install NumPyro with conda:
272280
273- ```
281+ ``` bash
274282conda install -c conda-forge numpyro
275283```
276284
0 commit comments