Skip to content

Commit

Permalink
Add a note to annotation examples on how to use the mcmc class with a…
Browse files Browse the repository at this point in the history
…rviz (pyro-ppl#1273)

* add annotation to readme

* add a note to annotation example on how to merge discrete samples into mcmc class
  • Loading branch information
fehiepsi authored Jan 3, 2022
1 parent 28fca48 commit f757183
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,14 @@ For some more examples on specifying models and doing inference in NumPyro:

- [Bayesian Regression in NumPyro](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/source/bayesian_regression.ipynb) - Start here to get acquainted with writing a simple model in NumPyro, MCMC inference API, effect handlers and writing custom inference utilities.
- [Time Series Forecasting](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/source/time_series_forecasting.ipynb) - Illustrates how to convert for loops in the model to JAX's `lax.scan` primitive for fast inference.
- [Annotation examples](https://num.pyro.ai/en/stable/examples/annotation.html) - Illustrates how to utilize the enumeration mechanism to perform inference for models with discrete latent variables.
- [Baseball example](https://github.com/pyro-ppl/numpyro/blob/master/examples/baseball.py) - Using NUTS for a simple hierarchical model. Compare this with the baseball example in [Pyro](https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py).
- [Hidden Markov Model](https://github.com/pyro-ppl/numpyro/blob/master/examples/hmm.py) in NumPyro as compared to [Stan](https://mc-stan.org/docs/2_19/stan-users-guide/hmms-section.html).
- [Variational Autoencoder](https://github.com/pyro-ppl/numpyro/blob/master/examples/vae.py) - As a simple example that uses Variational Inference with neural networks. [Pyro implementation](https://github.com/pyro-ppl/pyro/blob/dev/examples/vae/vae.py) for comparison.
- [Gaussian Process](https://github.com/pyro-ppl/numpyro/blob/master/examples/gp.py) - Provides a simple example to use NUTS to sample from the posterior over the hyper-parameters of a Gaussian Process.
- [Horseshoe Regression](https://github.com/pyro-ppl/numpyro/blob/master/examples/horseshoe_regression.py) - Shows how to implemement generalized linear models equipped with a Horseshoe prior for both binary-valued and real-valued outputs.
- [Statistical Rethinking with NumPyro](https://github.com/fehiepsi/rethinking-numpyro) - [Notebooks](https://nbviewer.jupyter.org/github/fehiepsi/rethinking-numpyro/tree/master/notebooks/) containing translation of the code in Richard McElreath's [Statistical Rethinking](https://xcelab.net/rm/statistical-rethinking/) book second version, to NumPyro.
- Other model examples can be found in the [examples](https://github.com/pyro-ppl/numpyro/tree/master/examples) folder.
- Other model examples can be found in the [examples](https://num.pyro.ai/en/stable/) site.

Pyro users will note that the API for model specification and inference is largely the same as Pyro, including the distributions API, by design. However, there are some important core differences (reflected in the internals) that users should be aware of. e.g. in NumPyro, there is no global parameter store or random state, to make it possible for us to leverage JAX's JIT compilation. Also, users may need to write their models in a more *functional* style that works better with JAX. Refer to [FAQs](#frequently-asked-questions) for a list of differences.

Expand Down
18 changes: 18 additions & 0 deletions examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,24 @@ def main(args):
print(row_format.format(f"item[{i}]", *row))


# %%
# .. note::
# In the above inference code, we marginalized the discrete latent variables `c`
# hence `mcmc.get_samples(...)` does not include samples of `c`. We then utilize
# `Predictive(..., infer_discrete=True)` to get posterior samples for `c`, which
# is stored in `discrete_samples`. To merge those discrete samples into the `mcmc`
# instance, we can use the following pattern::
#
# chain_discrete_samples = jax.tree_util.tree_map(
# lambda x: x.reshape((args.num_chains, args.num_samples) + x.shape[1:]),
# discrete_samples)
# mcmc.get_samples().update(discrete_samples)
# mcmc.get_samples(group_by_chain=True).update(chain_discrete_samples)
#
# This is useful when we want to pass the `mcmc` instance to `arviz` through
# `arviz.from_numpyro(mcmc)`.


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.8.0")
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
Expand Down

0 comments on commit f757183

Please sign in to comment.