diff --git a/docs/source/contrib.rst b/docs/source/contrib.rst
index 741417163..b2dca2ba1 100644
--- a/docs/source/contrib.rst
+++ b/docs/source/contrib.rst
@@ -9,3 +9,66 @@ Nested Sampling
     :undoc-members:
     :show-inheritance:
     :member-order: bysource
+
+
+Stein Variational Inference
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Stein Variational Inference (SteinVI) is a family of VI techniques for approximate Bayesian inference based on
+Stein's method (see [1] for an overview). It is gaining popularity as it combines
+the scalability of traditional VI with the flexibility of non-parametric particle-based methods.
+
+Stein variational gradient descent (SVGD) [2] is a recent SteinVI technique which iteratively moves a set of
+particles :math:`\{z_i\}_{i=1}^N` to approximate a distribution :math:`p(z)`.
+As a particle-based method, SVGD is well suited for capturing correlations between latent variables.
+The technique preserves the scalability of traditional VI approaches while offering the flexibility and modeling scope
+of methods such as Markov chain Monte Carlo (MCMC). SVGD is good at capturing multi-modality [3][4].
+
+``numpyro.contrib.einstein`` is a framework for particle-based inference using the ELBO-within-Stein algorithm.
+The framework works on Stein mixtures, a restricted mixture of guide programs parameterized by Stein particles.
+As in SVGD, Stein mixtures approximate model posteriors by moving the Stein particles according
+to the Stein forces. Because the Stein particles parameterize a guide, they capture a neighborhood rather than a
+single point. This property means Stein mixtures significantly reduce the number of particles needed to represent
+high-dimensional models.
+
+``numpyro.contrib.einstein`` mimics the interface of ``numpyro.infer.svi``, so trying SteinVI requires minimal
+changes to the code of existing models inferred with SVI. For primary usage, see the
+`Bayesian neural network example `_.
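+
+The following sketch illustrates this SVI-like workflow. It assumes the ``SteinVI`` constructor takes the
+same model, guide, optimizer, and loss arguments as ``numpyro.infer.svi.SVI`` plus a Stein kernel; the
+``num_particles`` keyword and the exact ``run`` signature are assumptions and should be checked against the
+``SteinVI`` docstring.
+
+.. code-block:: python
+
+    import jax.numpy as jnp
+    from jax import random
+
+    import numpyro
+    import numpyro.distributions as dist
+    from numpyro.contrib.einstein.kernels import RBFKernel
+    from numpyro.contrib.einstein.steinvi import SteinVI
+    from numpyro.infer import Trace_ELBO
+    from numpyro.infer.autoguide import AutoNormal
+    from numpyro.optim import Adam
+
+    def model(x, y=None):
+        # simple Bayesian linear regression
+        w = numpyro.sample("w", dist.Normal(0.0, 1.0))
+        b = numpyro.sample("b", dist.Normal(0.0, 1.0))
+        with numpyro.plate("data", x.shape[0]):
+            numpyro.sample("y", dist.Normal(w * x + b, 0.1), obs=y)
+
+    x = jnp.linspace(-1.0, 1.0, 50)
+    y = 2.0 * x + 0.5
+
+    stein = SteinVI(
+        model,
+        AutoNormal(model),  # guide whose parameters become the Stein particles
+        Adam(0.01),
+        Trace_ELBO(),
+        RBFKernel(),        # Stein kernel; assumed fifth positional argument
+        num_particles=10,   # assumed keyword for the number of Stein particles
+    )
+    # run is assumed to mirror SVI.run: (rng_key, num_steps, *model_args)
+    result = stein.run(random.PRNGKey(0), 1000, x, y)
+    params = result.params  # Stein particles parameterizing the guide mixture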
+
+The framework currently supports several kernels, including:
+
+- ``RBFKernel``
+- ``LinearKernel``
+- ``RandomFeatureKernel``
+- ``MixtureKernel``
+- ``PrecondMatrixKernel``
+- ``HessianPrecondMatrix``
+- ``GraphicalKernel``
+
+For example usage, see:
+
+- The `Bayesian neural network example `_
+
+SteinVI Interface
+-----------------
+.. autoclass:: numpyro.contrib.einstein.steinvi.SteinVI
+
+SteinVI Kernels
+---------------
+.. autoclass:: numpyro.contrib.einstein.kernels.RBFKernel
+.. autoclass:: numpyro.contrib.einstein.kernels.LinearKernel
+.. autoclass:: numpyro.contrib.einstein.kernels.RandomFeatureKernel
+.. autoclass:: numpyro.contrib.einstein.kernels.MixtureKernel
+.. autoclass:: numpyro.contrib.einstein.kernels.PrecondMatrixKernel
+.. autoclass:: numpyro.contrib.einstein.kernels.GraphicalKernel
+
+References
+----------
+1. *Stein's Method Meets Statistics: A Review of Some Recent Developments* (2021)
+   Andreas Anastasiou, Alessandro Barp, François-Xavier Briol, Bruno Ebner,
+   Robert E. Gaunt, Fatemeh Ghaderinezhad, Jackson Gorham, Arthur Gretton,
+   Christophe Ley, Qiang Liu, Lester Mackey, Chris J. Oates, Gesine Reinert,
+   Yvik Swan. https://arxiv.org/abs/2105.03481
+2. *Stein variational gradient descent: A general-purpose Bayesian inference algorithm* (2016)
+   Qiang Liu, Dilin Wang. NeurIPS
+3. *Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models* (2019)
+   Dilin Wang, Qiang Liu. PMLR
diff --git a/numpyro/contrib/einstein/kernels.py b/numpyro/contrib/einstein/kernels.py
index 900867ae9..76fc88c25 100644
--- a/numpyro/contrib/einstein/kernels.py
+++ b/numpyro/contrib/einstein/kernels.py
@@ -196,10 +196,7 @@ def kernel(x, y):
 
 class RandomFeatureKernel(SteinKernel):
     """
-    Calculates the random kernel
-    :math:`k(x,y)= 1/m\\sum_{l=1}^{m}\\phi(x,w_l)\\phi(y,w_l),
-    from [1].
-
+    Calculates the random kernel :math:`k(x,y) = 1/m\\sum_{l=1}^{m}\\phi(x,w_l)\\phi(y,w_l)` from [1].
     ** References: **
 
     1. *Stein Variational Gradient Descent as Moment Matching* by Liu and Wang
@@ -209,7 +206,6 @@ class RandomFeatureKernel(SteinKernel):
     :param random_indices: The set of indices which to do random feature expansion on.
                            (default None, meaning all indices)
     :param bandwidth_factor: A multiplier to the bandwidth based on data size n (default 1/log(n))
-
     """
 
     def __init__(
@@ -416,9 +412,7 @@ def kernel(x, y):
 
 class GraphicalKernel(SteinKernel):
     """
-    Calculates graphical kernel
-    :math: `k(x,y) = diag({K^(l)(x,y)}_l)
-    from [1].
+    Calculates the graphical kernel :math:`k(x,y) = \\mathrm{diag}(\\{K^{(l)}(x,y)\\}_l)` from [1].
     ** References: **
 
     1. *Stein Variational Message Passing for Continuous Graphical Models* by Wang, Zheng and Liu
diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py
index 78f6fc7c5..cb5be0e2a 100644
--- a/numpyro/contrib/einstein/steinvi.py
+++ b/numpyro/contrib/einstein/steinvi.py
@@ -33,8 +33,7 @@ def _numel(shape):
 
 
 class SteinVI:
-    """
-    Stein Variational Gradient Descent for Non-parametric Inference.
+    """Stein variational inference for Stein mixtures.
 
     :param model: Python callable with Pyro primitives for the model.
     :param guide: Python callable with Pyro primitives for the guide
diff --git a/setup.py b/setup.py
index 03f32abe6..9babf2d3c 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
 from setuptools import find_packages, setup
 
 PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
-_jax_version_constraints = ">=0.2.13"
+_jax_version_constraints = ">=0.2.13,<0.2.28"
 _jaxlib_version_constraints = ">=0.1.65"
 
 # Find version