Stein VI docs (pyro-ppl#1314)
* Added Stein VI to contrib.rst

* added text to steinVI docs

* Added SteinVI to docs.

* Fixed formatting for einstein

* formatted steinvi

* updated example url in stein docs (contrib.rst)

* added version constraint on jax
OlaRonning authored Feb 2, 2022
1 parent 2690521 commit a69344f
Showing 4 changed files with 67 additions and 11 deletions.
63 changes: 63 additions & 0 deletions docs/source/contrib.rst
@@ -9,3 +9,66 @@ Nested Sampling
:undoc-members:
:show-inheritance:
:member-order: bysource


Stein Variational Inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Stein Variational Inference (SteinVI) is a family of VI techniques for approximate Bayesian inference based on
Stein’s method (see [1] for an overview). It is gaining popularity as it combines
the scalability of traditional VI with the flexibility of non-parametric particle-based methods.

Stein variational gradient descent (SVGD) [2] is a recent SteinVI technique that iteratively moves a set of
particles :math:`\{z_i\}_{i=1}^N` to approximate a distribution :math:`p(z)`.
As a particle-based method, SVGD is well suited for capturing correlations between latent variables.
The technique preserves the scalability of traditional VI approaches while offering the flexibility and modeling scope
of methods such as Markov chain Monte Carlo (MCMC). SVGD is also good at capturing multi-modality [3][4].
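
The particle update described above can be sketched in a few lines of plain Python. This is a minimal illustration of the SVGD step from [2] for scalar particles, not the ``numpyro.contrib.einstein`` implementation: the fixed RBF bandwidth ``h``, the step size, the Gaussian target, and all function names are assumptions chosen for the example.

```python
import math
import random


def rbf(x, y, h):
    """RBF kernel k(x, y) = exp(-(x - y)^2 / h) with a fixed bandwidth h (assumed)."""
    return math.exp(-((x - y) ** 2) / h)


def svgd_step(particles, grad_log_p, step_size=0.1, h=1.0):
    """One SVGD update for scalar particles.

    phi(z_i) = 1/N * sum_j [k(z_j, z_i) * grad_log_p(z_j) + d/dz_j k(z_j, z_i)]
    """
    n = len(particles)
    updated = []
    for zi in particles:
        phi = 0.0
        for zj in particles:
            k = rbf(zj, zi, h)
            # Attraction: pulls z_i toward regions of high target density.
            attraction = k * grad_log_p(zj)
            # Repulsion: derivative of the kernel w.r.t. z_j keeps particles apart.
            repulsion = -2.0 * (zj - zi) / h * k
            phi += attraction + repulsion
        updated.append(zi + step_size * phi / n)
    return updated


def grad_log_normal(z, loc=2.0, scale=1.0):
    """Gradient of log N(z | loc, scale); the target is an illustrative choice."""
    return -(z - loc) / scale**2


random.seed(0)
particles = [random.uniform(-1.0, 1.0) for _ in range(20)]
for _ in range(500):
    particles = svgd_step(particles, grad_log_normal)

mean = sum(particles) / len(particles)
spread = max(particles) - min(particles)
```

After the loop, ``mean`` should sit near the target location while ``spread`` stays clearly positive, because the repulsive term prevents the particles from collapsing onto the mode.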

``numpyro.contrib.einstein`` is a framework for particle-based inference using the ELBO-within-Stein algorithm.
The framework works on Stein mixtures, a restricted mixture of guide programs parameterized by Stein particles.
Similarly to how SVGD works, Stein mixtures can approximate model posteriors by moving the Stein particles according
to the Stein forces. Because the Stein particles parameterize a guide, they capture a neighborhood rather than a
single point. As a result, Stein mixtures can significantly reduce the number of particles needed to represent
high-dimensional models.
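
To make the "neighborhood rather than a single point" idea concrete, here is a small sketch in plain Python. It assumes, purely for illustration, that each particle holds the location and scale of a one-dimensional Gaussian guide and that the mixture weights are uniform; the function names are hypothetical and this is not how ``numpyro.contrib.einstein`` represents guides internally.

```python
import math


def normal_pdf(z, loc, scale):
    """Density of a univariate Gaussian guide."""
    return math.exp(-0.5 * ((z - loc) / scale) ** 2) / (scale * math.sqrt(2.0 * math.pi))


def stein_mixture_pdf(z, particles):
    """Uniformly weighted mixture of the guides parameterized by each particle."""
    return sum(normal_pdf(z, loc, scale) for loc, scale in particles) / len(particles)


# Each particle is a (loc, scale) pair, so it covers a region around loc.
particles = [(-2.0, 0.5), (0.0, 0.5), (2.0, 0.5)]

# The mixture assigns non-zero density between particle locations,
# which a set of point particles would not.
density_between = stein_mixture_pdf(1.0, particles)

# Riemann-sum check that the mixture is a proper density.
step = 0.01
total = sum(stein_mixture_pdf(-8.0 + i * step, particles) for i in range(1601)) * step
```

Three particles here describe a density over the whole range between -2 and 2, whereas three point particles would describe only three locations.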

``numpyro.contrib.einstein`` mimics the interface from ``numpyro.infer.svi``, so trying SteinVI requires minimal
change to the code for existing models inferred with SVI. For basic usage, see the
`Bayesian neural network example <https://num.pyro.ai/en/latest/examples/stein_bnn.html>`_.

The framework currently supports several kernels, including:

- `RBFKernel`
- `LinearKernel`
- `RandomFeatureKernel`
- `MixtureKernel`
- `PrecondMatrixKernel`
- `HessianPrecondMatrix`
- `GraphicalKernel`
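
As a rough intuition for what such kernels compute, the sketch below writes out textbook forms of an RBF kernel, a linear kernel, and a weighted mixture of kernels in plain Python. The formulas, the fixed bandwidth, and the function names are assumptions for illustration; the classes listed above may use different parameterizations (for example, a data-dependent bandwidth).

```python
import math


def rbf_kernel(x, y, bandwidth=1.0):
    """k(x, y) = exp(-||x - y||^2 / bandwidth) for equal-length vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / bandwidth)


def linear_kernel(x, y):
    """k(x, y) = <x, y> + 1."""
    return sum(a * b for a, b in zip(x, y)) + 1.0


def mixture_kernel(x, y, weights, kernels):
    """Weighted sum of kernels; a non-negative combination of kernels is a kernel."""
    return sum(w * k(x, y) for w, k in zip(weights, kernels))


x, y = [0.0, 1.0], [1.0, 1.0]
k_rbf = rbf_kernel(x, y)
k_lin = linear_kernel(x, y)
k_mix = mixture_kernel(x, y, [0.5, 0.5], [rbf_kernel, linear_kernel])
```

The RBF kernel decays with distance, so it produces local interactions between particles, while the linear kernel grows with the inner product; mixing them trades off the two behaviors.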

For example usage, see:

- The `Bayesian neural network example <https://num.pyro.ai/en/latest/examples/stein_bnn.html>`_

SteinVI Interface
-----------------
.. autoclass:: numpyro.contrib.einstein.steinvi.SteinVI

SteinVI Kernels
---------------
.. autoclass:: numpyro.contrib.einstein.kernels.RBFKernel
.. autoclass:: numpyro.contrib.einstein.kernels.LinearKernel
.. autoclass:: numpyro.contrib.einstein.kernels.RandomFeatureKernel
.. autoclass:: numpyro.contrib.einstein.kernels.MixtureKernel
.. autoclass:: numpyro.contrib.einstein.kernels.PrecondMatrixKernel
.. autoclass:: numpyro.contrib.einstein.kernels.GraphicalKernel

References
----------
1. *Stein's Method Meets Statistics: A Review of Some Recent Developments* (2021)
   Andreas Anastasiou, Alessandro Barp, François-Xavier Briol, Bruno Ebner,
   Robert E. Gaunt, Fatemeh Ghaderinezhad, Jackson Gorham, Arthur Gretton,
   Christophe Ley, Qiang Liu, Lester Mackey, Chris. J. Oates, Gesine Reinert,
   Yvik Swan. https://arxiv.org/abs/2105.03481
2. *Stein variational gradient descent: A general-purpose Bayesian inference algorithm* (2016)
   Qiang Liu, Dilin Wang. NeurIPS
3. *Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models* (2019)
   Dilin Wang, Qiang Liu. PMLR
10 changes: 2 additions & 8 deletions numpyro/contrib/einstein/kernels.py
@@ -196,10 +196,7 @@ def kernel(x, y):

class RandomFeatureKernel(SteinKernel):
"""
-Calculates the random kernel
-:math:`k(x,y)= 1/m\\sum_{l=1}^{m}\\phi(x,w_l)\\phi(y,w_l),
-from [1].
+Calculates the random kernel :math:`k(x,y)= 1/m\\sum_{l=1}^{m}\\phi(x,w_l)\\phi(y,w_l)` from [1].
** References: **
1. *Stein Variational Gradient Descent as Moment Matching* by Liu and Wang
@@ -209,7 +206,6 @@ class RandomFeatureKernel(SteinKernel):
:param random_indices: The set of indices which to do random feature expansion on.
(default None, meaning all indices)
:param bandwidth_factor: A multiplier to the bandwidth based on data size n (default 1/log(n))
"""

def __init__(
@@ -416,9 +412,7 @@ def kernel(x, y):

class GraphicalKernel(SteinKernel):
"""
-Calculates graphical kernel
-:math: `k(x,y) = diag({K^(l)(x,y)}_l)
-from [1].
+Calculates graphical kernel :math: `k(x,y) = diag({K^(l)(x,y)}_l)` from [1].
** References: **
1. *Stein Variational Message Passing for Continuous Graphical Models* by Wang, Zheng and Liu
3 changes: 1 addition & 2 deletions numpyro/contrib/einstein/steinvi.py
@@ -33,8 +33,7 @@ def _numel(shape):


class SteinVI:
-"""
-Stein Variational Gradient Descent for Non-parametric Inference.
+"""Stein variational inference for stein mixtures.
:param model: Python callable with Pyro primitives for the model.
:param guide: Python callable with Pyro primitives for the guide
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@
from setuptools import find_packages, setup

PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
-_jax_version_constraints = ">=0.2.13"
+_jax_version_constraints = ">=0.2.13,<0.2.28"
_jaxlib_version_constraints = ">=0.1.65"

# Find version