diff --git a/docs/source/conf.py b/docs/source/conf.py index f9cbcc18e..216658ff8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -37,9 +37,9 @@ # -- Project information ----------------------------------------------------- -project = u"NumPyro" -copyright = u"2019, Uber Technologies, Inc" -author = u"Uber AI Labs" +project = "NumPyro" +copyright = "2019, Uber Technologies, Inc" +author = "Uber AI Labs" version = "" @@ -280,14 +280,14 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, "NumPyro.tex", u"NumPyro Documentation", u"Uber AI Labs", "manual") + (master_doc, "NumPyro.tex", "NumPyro Documentation", "Uber AI Labs", "manual") ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [(master_doc, "NumPyro", u"NumPyro Documentation", [author], 1)] +man_pages = [(master_doc, "NumPyro", "NumPyro Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -298,7 +298,7 @@ ( master_doc, "NumPyro", - u"NumPyro Documentation", + "NumPyro Documentation", author, "NumPyro", "Pyro PPL on Numpy", diff --git a/docs/source/index.rst b/docs/source/index.rst index a047a0f15..4dd395d4c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -77,8 +77,9 @@ NumPyro documentation :caption: Other Inference Algorithms :name: other-inference-algorithms - examples/hmcecs examples/covtype + examples/hmcecs + examples/stein_bnn Indices and tables diff --git a/examples/hmm_enum.py b/examples/hmm_enum.py index 6ec0e3f69..cc10e78a6 100644 --- a/examples/hmm_enum.py +++ b/examples/hmm_enum.py @@ -166,7 +166,7 @@ def transition_fn(carry, y): # target hidden dimension. def model_3(sequences, lengths, args, include_prior=True): num_sequences, max_length, data_dim = sequences.shape - hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x + hidden_dim = int(args.hidden_dim**0.5) # split between w and x with mask(mask=include_prior): probs_w = numpyro.sample( "probs_w", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1) @@ -220,7 +220,7 @@ def transition_fn(carry, y): # Factorial HMM, but this model has more parameters. 
 def model_4(sequences, lengths, args, include_prior=True):
     num_sequences, max_length, data_dim = sequences.shape
-    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x
+    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
     with mask(mask=include_prior):
         probs_w = numpyro.sample(
             "probs_w", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
         )
diff --git a/examples/hsgp.py b/examples/hsgp.py
index bda4447fe..1cff5f69e 100644
--- a/examples/hsgp.py
+++ b/examples/hsgp.py
@@ -161,7 +161,7 @@ def make_birthdays_data_dict(data):
 # --- Modelling utility functions --- #
 def spectral_density(w, alpha, length):
     c = alpha * jnp.sqrt(2 * jnp.pi) * length
-    e = jnp.exp(-0.5 * (length ** 2) * (w ** 2))
+    e = jnp.exp(-0.5 * (length**2) * (w**2))
     return c * e
@@ -197,7 +197,7 @@ def diag_spectral_density_periodic(alpha, length, M):
     a = length ** (-2)
     J = jnp.arange(0, M)
     c = jnp.where(J > 0, 2, 1)
-    q2 = (c * alpha ** 2 / jnp.exp(a)) * modified_bessel_first_kind(J, a)
+    q2 = (c * alpha**2 / jnp.exp(a)) * modified_bessel_first_kind(J, a)
     return q2
diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py
index ab735e6a1..31ae532f6 100644
--- a/examples/sparse_regression.py
+++ b/examples/sparse_regression.py
@@ -73,7 +73,7 @@ def model(X, Y, hypers):
     # compute kernel
     kX = kappa * X
-    k = kernel(kX, kX, eta1, eta2, hypers["c"]) + sigma ** 2 * jnp.eye(N)
+    k = kernel(kX, kX, eta1, eta2, hypers["c"]) + sigma**2 * jnp.eye(N)
     assert k.shape == (N, N)
     # sample Y according to the standard gaussian process formula
@@ -99,7 +99,7 @@ def compute_singleton_mean_variance(X, Y, dimension, msq, lam, eta1, xisq, c, si
     kX = kappa * X
     kprobe = kappa * probe
-    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
+    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
     k_xx_inv = jnp.linalg.inv(k_xx)
     k_probeX = kernel(kprobe, kX, eta1, eta2, c)
     k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
@@ -130,7 +130,7 @@ def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, si
     kX = kappa * X
     kprobe = kappa * probe
-    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
+    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
     k_xx_inv = jnp.linalg.inv(k_xx)
     k_probeX = kernel(kprobe, kX, eta1, eta2, c)
     k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
@@ -188,7 +188,7 @@ def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, sigma):
     kX = kappa * X
     kprobe = kappa * probe
-    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
+    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
     L = cho_factor(k_xx, lower=True)[0]
     k_probeX = kernel(kprobe, kX, eta1, eta2, c)
     k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
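Most of the mechanical churn in this diff appears to come from the power-operator hugging rule added in Black 22.x: spaces around `**` are dropped when both operands are simple names, literals, or attribute lookups, and kept otherwise. A tiny runnable illustration with toy values (not taken from the code base):

```python
import jax.numpy as jnp

length, w, t, kappa = 2.0, 3.0, 4.0, 0.5

# hugged: both operands of ** are simple names or literals
e = jnp.exp(-0.5 * (length**2) * (w**2))  # was: (length ** 2) * (w ** 2)

# kept: a parenthesized exponent is not "simple", so the original spacing stays,
# as with `length ** (-2)` in hsgp.py and `t ** (-kappa)` in hmc_util.py above/below
weight_t = t ** (-kappa)

print(e, weight_t)
```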
diff --git a/examples/stein_bnn.py b/examples/stein_bnn.py
new file mode 100644
index 000000000..9282e1304
--- /dev/null
+++ b/examples/stein_bnn.py
@@ -0,0 +1,184 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Example: Bayesian Neural Network with SteinVI
+=============================================
+We demonstrate how to use SteinVI to predict housing prices with a BNN trained on
+the Boston Housing dataset from the UCI regression benchmarks.
+"""
+
+import argparse
+from collections import namedtuple
+import datetime
+from functools import partial
+from time import time
+
+from sklearn.model_selection import train_test_split
+
+from jax import random
+import jax.numpy as jnp
+
+import numpyro
+from numpyro.contrib.einstein import RBFKernel, SteinVI
+from numpyro.distributions import Gamma, Normal
+from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset
+from numpyro.infer import Predictive, Trace_ELBO, init_to_uniform
+from numpyro.infer.autoguide import AutoDelta
+from numpyro.optim import Adagrad
+
+DataState = namedtuple("data", ["xtr", "xte", "ytr", "yte"])
+
+
+def load_data() -> DataState:
+    _, fetch = load_dataset(BOSTON_HOUSING, shuffle=False)
+    x, y = fetch()
+    xtr, xte, ytr, yte = train_test_split(x, y, train_size=0.90)
+
+    return DataState(*map(partial(jnp.array, dtype=float), (xtr, xte, ytr, yte)))
+
+
+def normalize(val, mean=None, std=None):
+    """Normalize data to zero mean, unit variance"""
+    if mean is None and std is None:
+        # Only use training data to estimate mean and std.
+        std = jnp.std(val, 0, keepdims=True)
+        std = jnp.where(std == 0, 1.0, std)
+        mean = jnp.mean(val, 0, keepdims=True)
+    return (val - mean) / std, mean, std
+
+
+def model(x, y=None, hidden_dim=50, subsample_size=100):
+    """BNN described in section 5 of [1].
+
+    **References:**
+        1. *Stein variational gradient descent: A general purpose Bayesian inference algorithm*,
+           Qiang Liu and Dilin Wang (2016).
+    """
+
+    prec_nn = numpyro.sample(
+        "prec_nn", Gamma(1.0, 0.1)
+    )  # hyper prior for precision of nn weights and biases
+
+    n, m = x.shape
+
+    with numpyro.plate("l1_hidden", hidden_dim, dim=-1):
+        # prior l1 bias term
+        b1 = numpyro.sample(
+            "nn_b1",
+            Normal(
+                0.0,
+                1.0 / jnp.sqrt(prec_nn),
+            ),
+        )
+        assert b1.shape == (hidden_dim,)
+
+        with numpyro.plate("l1_feat", m, dim=-2):
+            w1 = numpyro.sample(
+                "nn_w1", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
+            )  # prior on l1 weights
+            assert w1.shape == (m, hidden_dim)
+
+    with numpyro.plate("l2_hidden", hidden_dim, dim=-1):
+        w2 = numpyro.sample(
+            "nn_w2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
+        )  # prior on output weights
+
+    b2 = numpyro.sample(
+        "nn_b2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
+    )  # prior on output bias term
+
+    # precision prior on observations
+    prec_obs = numpyro.sample("prec_obs", Gamma(1.0, 0.1))
+    with numpyro.plate(
+        "data",
+        x.shape[0],
+        subsample_size=subsample_size,
+        dim=-1,
+    ):
+        batch_x = numpyro.subsample(x, event_dim=1)
+        if y is not None:
+            batch_y = numpyro.subsample(y, event_dim=0)
+        else:
+            batch_y = y
+
+        numpyro.sample(
+            "y",
+            Normal(
+                jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2, 1.0 / jnp.sqrt(prec_obs)
+            ),  # 1 hidden layer with ReLU activation
+            obs=batch_y,
+        )
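Before the inference driver that follows, a standalone shape check of the likelihood mean used in `model` may help; the sizes below (13 features, 50 hidden units, a 4-row batch) are illustrative stand-ins for the Boston data, not values taken from the example:

```python
import jax.numpy as jnp
from jax import random

n, m, hidden_dim = 4, 13, 50                 # batch rows, features, hidden units (illustrative)
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
x = random.normal(k1, (n, m))
w1 = random.normal(k2, (m, hidden_dim))
b1 = jnp.zeros(hidden_dim)
w2 = random.normal(k3, (hidden_dim,))
b2 = 0.0

loc = jnp.maximum(x @ w1 + b1, 0) @ w2 + b2  # ReLU hidden layer, linear readout
assert loc.shape == (n,)                     # one Normal mean per (sub)sampled row
```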
+
+
+def main(args):
+    data = load_data()
+
+    inf_key, pred_key, data_key = random.split(random.PRNGKey(args.rng_key), 3)
+    # normalize data and labels to zero mean unit variance!
+    x, xtr_mean, xtr_std = normalize(data.xtr)
+    y, ytr_mean, ytr_std = normalize(data.ytr)
+
+    rng_key, inf_key = random.split(inf_key)
+
+    stein = SteinVI(
+        model,
+        AutoDelta(model, init_loc_fn=partial(init_to_uniform, radius=0.1)),
+        Adagrad(0.05),
+        Trace_ELBO(20),  # estimate elbo with 20 particles (not stein particles!)
+        RBFKernel(),
+        repulsion_temperature=args.repulsion,
+        num_particles=args.num_particles,
+    )
+    start = time()
+
+    # use keyword arguments for static parameters (shapes etc.)!
+    result = stein.run(
+        rng_key,
+        args.max_iter,
+        x,
+        y,
+        hidden_dim=args.hidden_dim,
+        subsample_size=args.subsample_size,
+        progress_bar=args.progress_bar,
+    )
+    time_taken = time() - start
+
+    pred = Predictive(
+        model,
+        guide=stein.guide,
+        params=stein.get_params(result.state),
+        num_samples=1,
+        batch_ndims=1,  # stein particle dimension
+    )
+    xte, _, _ = normalize(
+        data.xte, xtr_mean, xtr_std
+    )  # use train data statistics when assessing generalization
+    preds = pred(pred_key, xte, subsample_size=xte.shape[0])["y"].reshape(
+        -1, xte.shape[0]
+    )
+
+    y_pred = jnp.mean(preds, 0) * ytr_std + ytr_mean
+    rmse = jnp.sqrt(jnp.mean((y_pred - data.yte) ** 2))
+
+    print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}")
+    print(rf"RMSE: {rmse:.2f}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--subsample-size", type=int, default=100)
+    parser.add_argument("--max-iter", type=int, default=1000)
+    parser.add_argument("--repulsion", type=float, default=1.0)
+    parser.add_argument("--verbose", type=bool, default=True)
+    parser.add_argument("--num-particles", type=int, default=100)
+    parser.add_argument("--progress-bar", type=bool, default=True)
+    parser.add_argument("--rng-key", type=int, default=142)
+    parser.add_argument("--device", default="cpu", choices=["gpu", "cpu"])
+    parser.add_argument("--hidden-dim", default=50, type=int)
+
+    args = parser.parse_args()
+
+    numpyro.set_platform(args.device)
+
+    main(args)
diff --git a/examples/thompson_sampling.py b/examples/thompson_sampling.py
index 16c972a12..41d109e3f 100644
--- a/examples/thompson_sampling.py
+++ b/examples/thompson_sampling.py
@@ -34,7 +34,7 @@
 # the function to be minimized. At y=0 to get a 1D cut at the origin
 def ackley_1d(x, y=0):
     out = (
-        -20 * jnp.exp(-0.2 * jnp.sqrt(0.5 * (x ** 2 + y ** 2)))
+        -20 * jnp.exp(-0.2 * jnp.sqrt(0.5 * (x**2 + y**2)))
         - jnp.exp(0.5 * (jnp.cos(2 * jnp.pi * x) + jnp.cos(2 * jnp.pi * y)))
         + jnp.e
         + 20
@@ -45,7 +45,7 @@ def ackley_1d(x, y=0):
 # matern kernel with nu = 5/2
 def matern52_kernel(X, Z, var=1.0, length=0.5, jitter=1.0e-6):
     d = jnp.sqrt(0.5) * jnp.sqrt(jnp.power((X[:, None] - Z), 2.0)) / length
-    k = var * (1 + d + (d ** 2) / 3) * jnp.exp(-d)
+    k = var * (1 + d + (d**2) / 3) * jnp.exp(-d)
     if jitter:
         # we are assuming a noise free process, but add a small jitter for numerical stability
         k += jitter * jnp.eye(X.shape[0])
@@ -125,10 +125,10 @@ def predict(self, X, return_std=False):
         if return_std:
             return (
                 (mean * self.y_std) + self.y_mean,
-                jnp.sqrt(jnp.diag(K * self.y_std ** 2)),
+                jnp.sqrt(jnp.diag(K * self.y_std**2)),
             )
         else:
-            return (mean * self.y_std) + self.y_mean, K * self.y_std ** 2
+            return (mean * self.y_std) + self.y_mean, K * self.y_std**2
 
     def sample_y(self, rng_key, X):
         # get posterior mean and covariance
diff --git a/notebooks/source/conf.py b/notebooks/source/conf.py
index 9d221f450..39567245b 100644
--- a/notebooks/source/conf.py
+++ b/notebooks/source/conf.py
@@ -66,9 +66,9 @@
 master_doc = "index"
 
 # General information about the project.
-project = u"NumPyro Tutorials" -copyright = u"2019, Uber Technologies, Inc" -author = u"Uber AI Labs" +project = "NumPyro Tutorials" +copyright = "2019, Uber Technologies, Inc" +author = "Uber AI Labs" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -164,8 +164,8 @@ ( master_doc, "NumPyroTutorials.tex", - u"Numpyro Examples and Tutorials", - u"Uber AI Labs", + "Numpyro Examples and Tutorials", + "Uber AI Labs", "manual", ) ] @@ -176,7 +176,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, "NumPyroTutorials", u"Numpyro Examples and Tutorials", [author], 1) + (master_doc, "NumPyroTutorials", "Numpyro Examples and Tutorials", [author], 1) ] @@ -189,7 +189,7 @@ ( master_doc, "NumPyroTutorials", - u"NumPyro Examples and Tutorials", + "NumPyro Examples and Tutorials", author, "NumPyroTutorials", "One line description of project.", diff --git a/notebooks/source/time_series_forecasting.ipynb b/notebooks/source/time_series_forecasting.ipynb index 3e67a19cc..b03a525be 100644 --- a/notebooks/source/time_series_forecasting.ipynb +++ b/notebooks/source/time_series_forecasting.ipynb @@ -204,8 +204,8 @@ "\n", " def transition_fn(carry, t):\n", " level, s, moving_sum = carry\n", - " season = s[0] * level ** pow_season\n", - " exp_val = level + coef_trend * level ** pow_trend + season\n", + " season = s[0] * level**pow_season\n", + " exp_val = level + coef_trend * level**pow_trend + season\n", " exp_val = jnp.clip(exp_val, a_min=0)\n", " # use expected vale when forecasting\n", " y_t = jnp.where(t >= N, exp_val, y[t])\n", @@ -222,7 +222,7 @@ " new_s = jnp.where(t >= N, s[0], new_s)\n", " s = jnp.concatenate([s[1:], new_s[None]], axis=0)\n", "\n", - " omega = sigma * exp_val ** powx + offset_sigma\n", + " omega = sigma * exp_val**powx + offset_sigma\n", " y_ = numpyro.sample(\"y\", dist.StudentT(nu, exp_val, omega))\n", "\n", " return (level, s, moving_sum), y_\n", diff --git a/numpyro/compat/optim.py b/numpyro/compat/optim.py index db0889f7e..f57972ac4 100644 --- a/numpyro/compat/optim.py +++ b/numpyro/compat/optim.py @@ -21,7 +21,7 @@ def ClippedAdam(kwargs): if lrd is not None: def step_size(i): - return init_lr * lrd ** i + return init_lr * lrd**i return optim.ClippedAdam( step_size=step_size, b1=b1, b2=b2, eps=eps, clip_norm=clip_norm diff --git a/numpyro/contrib/einstein/kernels.py b/numpyro/contrib/einstein/kernels.py index 73b865e3b..900867ae9 100644 --- a/numpyro/contrib/einstein/kernels.py +++ b/numpyro/contrib/einstein/kernels.py @@ -115,7 +115,7 @@ def compute(self, particles, particle_info, loss_fn): def kernel(x, y): diff = safe_norm(x - y, ord=2) if self._normed() and x.ndim >= 1 else x - y - kernel_res = jnp.exp(-(diff ** 2) / bandwidth) + kernel_res = jnp.exp(-(diff**2) / bandwidth) if self._mode == "matrix": if self.matrix_mode == "norm_diag": return kernel_res * jnp.identity(x.shape[0]) @@ -161,7 +161,7 @@ def mode(self): def compute(self, particles, particle_info, loss_fn): def kernel(x, y): diff = safe_norm(x - y, ord=2, axis=-1) if self._mode == "norm" else x - y - return (self.const ** 2 + diff ** 2) ** self.expon + return (self.const**2 + diff**2) ** self.expon return kernel diff --git a/numpyro/contrib/einstein/util.py b/numpyro/contrib/einstein/util.py index f9eba8b62..97ec6e093 100644 --- a/numpyro/contrib/einstein/util.py +++ b/numpyro/contrib/einstein/util.py @@ -34,7 +34,7 @@ 
def sqrth_and_inv_sqrth(m): mvec_t = jnp.swapaxes(mvec, -2, -1) mlambdasqrt = jnp.maximum(mlambda, 1e-5) ** 0.5 msqrt = (mvec * jnp.expand_dims(mlambdasqrt, -2)) @ mvec_t - mlambdasqrt_inv = jnp.maximum(1 / mlambdasqrt, 1e-5 ** 0.5) + mlambdasqrt_inv = jnp.maximum(1 / mlambdasqrt, 1e-5**0.5) minv_sqrt = (mvec * jnp.expand_dims(mlambdasqrt_inv, -2)) @ mvec_t minv = minv_sqrt @ jnp.swapaxes(minv_sqrt, -2, -1) return msqrt, minv, minv_sqrt @@ -47,7 +47,7 @@ def safe_norm(a, ord=2, axis=None): else: is_zero = jnp.ones_like(a, dtype="bool") norm = jnp.linalg.norm( - a + jnp.where(is_zero, jnp.ones_like(a) * 1e-5 ** norm_corr, jnp.zeros_like(a)), + a + jnp.where(is_zero, jnp.ones_like(a) * 1e-5**norm_corr, jnp.zeros_like(a)), ord=ord, axis=axis, ) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 917f71daf..abe0c9363 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -104,7 +104,7 @@ def mean(self): @property def variance(self): total = self.concentration1 + self.concentration0 - return self.concentration1 * self.concentration0 / (total ** 2 * (total + 1)) + return self.concentration1 * self.concentration0 / (total**2 * (total + 1)) def cdf(self, value): return betainc(self.concentration1, self.concentration0, value) @@ -209,7 +209,7 @@ def mean(self): def variance(self): con0 = jnp.sum(self.concentration, axis=-1, keepdims=True) return ( - self.concentration * (con0 - self.concentration) / (con0 ** 2 * (con0 + 1)) + self.concentration * (con0 - self.concentration) / (con0**2 * (con0 + 1)) ) @staticmethod @@ -246,7 +246,7 @@ def mean(self): @property def variance(self): - return jnp.reciprocal(self.rate ** 2) + return jnp.reciprocal(self.rate**2) def cdf(self, value): return -jnp.expm1(-self.rate * value) @@ -422,7 +422,7 @@ def mean(self): @property def variance(self): - return (1 - 2 / jnp.pi) * self.scale ** 2 + return (1 - 2 / jnp.pi) * self.scale**2 class InverseGamma(TransformedDistribution): @@ -499,7 +499,7 @@ def mean(self): @property def variance(self): - return jnp.broadcast_to(jnp.pi ** 2 / 6.0 * self.scale ** 2, self.batch_shape) + return jnp.broadcast_to(jnp.pi**2 / 6.0 * self.scale**2, self.batch_shape) def cdf(self, value): return jnp.exp(-jnp.exp((self.loc - value) / self.scale)) @@ -550,7 +550,7 @@ def log_prob(self, value): normalize_term = jnp.log(self.concentration0) + jnp.log(self.concentration1) return ( xlogy(self.concentration1 - 1, value) - + xlog1py(self.concentration0 - 1, -(value ** self.concentration1)) + + xlog1py(self.concentration0 - 1, -(value**self.concentration1)) + normalize_term ) @@ -599,7 +599,7 @@ def mean(self): @property def variance(self): - return jnp.broadcast_to(2 * self.scale ** 2, self.batch_shape) + return jnp.broadcast_to(2 * self.scale**2, self.batch_shape) def cdf(self, value): scaled = (value - self.loc) / self.scale @@ -914,11 +914,11 @@ def __init__(self, loc=0.0, scale=1.0, validate_args=None): @property def mean(self): - return jnp.exp(self.loc + self.scale ** 2 / 2) + return jnp.exp(self.loc + self.scale**2 / 2) @property def variance(self): - return (jnp.exp(self.scale ** 2) - 1) * jnp.exp(2 * self.loc + self.scale ** 2) + return (jnp.exp(self.scale**2) - 1) * jnp.exp(2 * self.loc + self.scale**2) def tree_flatten(self): return super(TransformedDistribution, self).tree_flatten() @@ -956,7 +956,7 @@ def mean(self): @property def variance(self): - var = (self.scale ** 2) * (jnp.pi ** 2) / 3 + var = (self.scale**2) * (jnp.pi**2) / 3 return 
jnp.broadcast_to(var, self.batch_shape) def cdf(self, value): @@ -1000,7 +1000,7 @@ def _batch_mahalanobis(bL, bx): # permute to (i, 1, n, -1) xt = jnp.moveaxis(xt, 0, -1) solve_bL_bx = solve_triangular(bL, xt, lower=True) # shape: (i, 1, n, -1) - M = jnp.sum(solve_bL_bx ** 2, axis=-2) # shape: (i, 1, -1) + M = jnp.sum(solve_bL_bx**2, axis=-2) # shape: (i, 1, -1) # permute back to (-1, i, 1) M = jnp.moveaxis(M, -1, 0) # reshape back to (..., 1, j, i, 1) @@ -1102,7 +1102,7 @@ def mean(self): @property def variance(self): return jnp.broadcast_to( - jnp.sum(self.scale_tril ** 2, axis=-1), self.batch_shape + self.event_shape + jnp.sum(self.scale_tril**2, axis=-1), self.batch_shape + self.event_shape ) def tree_flatten(self): @@ -1428,7 +1428,7 @@ def sample(self, key, sample_shape=()): def log_prob(self, value): normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale) value_scaled = (value - self.loc) / self.scale - return -0.5 * value_scaled ** 2 - normalize_term + return -0.5 * value_scaled**2 - normalize_term def cdf(self, value): scaled = (value - self.loc) / self.scale @@ -1443,7 +1443,7 @@ def mean(self): @property def variance(self): - return jnp.broadcast_to(self.scale ** 2, self.batch_shape) + return jnp.broadcast_to(self.scale**2, self.batch_shape) class Pareto(TransformedDistribution): @@ -1470,7 +1470,7 @@ def mean(self): def variance(self): # var is inf for alpha <= 2 a = jnp.divide( - (self.scale ** 2) * self.alpha, (self.alpha - 1) ** 2 * (self.alpha - 2) + (self.scale**2) * self.alpha, (self.alpha - 1) ** 2 * (self.alpha - 2) ) return jnp.where(self.alpha <= 2, jnp.inf, a) @@ -1606,7 +1606,7 @@ def log_prob(self, value): + gammaln(0.5 * self.df) - gammaln(0.5 * (self.df + 1.0)) ) - return -0.5 * (self.df + 1.0) * jnp.log1p(y ** 2.0 / self.df) - z + return -0.5 * (self.df + 1.0) * jnp.log1p(y**2.0 / self.df) - z @property def mean(self): @@ -1618,7 +1618,7 @@ def mean(self): @property def variance(self): var = jnp.where( - self.df > 2, jnp.divide(self.scale ** 2 * self.df, self.df - 2.0), jnp.inf + self.df > 2, jnp.divide(self.scale**2 * self.df, self.df - 2.0), jnp.inf ) var = jnp.where(self.df <= 1, jnp.nan, var) return jnp.broadcast_to(var, self.batch_shape) @@ -1744,7 +1744,7 @@ def mean(self): @property def variance(self): - return self.scale ** 2 * ( + return self.scale**2 * ( jnp.exp(gammaln(1.0 + 2.0 / self.concentration)) - jnp.exp(gammaln(1.0 + 1.0 / self.concentration)) ** 2 ) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index e966c27a1..2d068b85b 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -420,7 +420,7 @@ def sample(self, key, sample_shape=()): corr = self.correlation conc = jnp.stack((self.phi_concentration, self.psi_concentration)) - eig = 0.5 * (conc[0] - corr ** 2 / conc[1]) + eig = 0.5 * (conc[0] - corr**2 / conc[1]) eig = jnp.stack((jnp.zeros_like(eig), eig)) eigmin = jnp.where(eig[1] < 0, eig[1], jnp.zeros_like(eig[1], dtype=eig.dtype)) eig = eig - eigmin @@ -487,7 +487,7 @@ def update_fn(curr): assert lf.shape == shape lg_inv = ( - 1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x ** 2).sum(1, keepdims=True)) + 1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x**2).sum(1, keepdims=True)) ) assert lg_inv.shape == lf.shape @@ -627,7 +627,7 @@ def _dot(x, y): # Integrate[x/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}] # = (t + Sqrt[2/Pi]/E^(t^2/2) + t Erf[t/Sqrt[2]])/2 para_part = jnp.log( - (jnp.exp((-0.5) * t2) * ((2 / math.pi) ** 0.5) + t * (1 + erf(t * 0.5 ** 0.5))) + 
(jnp.exp((-0.5) * t2) * ((2 / math.pi) ** 0.5) + t * (1 + erf(t * 0.5**0.5))) / 2 ) @@ -651,7 +651,7 @@ def _dot(x, y): # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) (1 + Erf[t/Sqrt[2]]))/2 para_part = jnp.log( t * jnp.exp((-0.5) * t2) / (2 * math.pi) ** 0.5 - + (1 + t2) * (1 + erf(t * 0.5 ** 0.5)) / 2 + + (1 + t2) * (1 + erf(t * 0.5**0.5)) / 2 ) return para_part + perp_part diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index ac3fa635e..1997b8654 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -758,8 +758,8 @@ def mean(self): @lazy_property def variance(self): return (1 - self.gate) * ( - self.base_dist.mean ** 2 + self.base_dist.variance - ) - self.mean ** 2 + self.base_dist.mean**2 + self.base_dist.variance + ) - self.mean**2 @property def has_enumerate_support(self): diff --git a/numpyro/distributions/gof.py b/numpyro/distributions/gof.py index fd186c690..7ef80b987 100644 --- a/numpyro/distributions/gof.py +++ b/numpyro/distributions/gof.py @@ -205,7 +205,7 @@ def density_goodness_of_fit(samples, probs, plot=False): def volume_of_sphere(dim, radius): - return radius ** dim * math.pi ** (0.5 * dim) / math.gamma(0.5 * dim + 1) + return radius**dim * math.pi ** (0.5 * dim) / math.gamma(0.5 * dim + 1) def get_nearest_neighbor_distances(samples): diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index e195f39de..7f87e0ad7 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -104,7 +104,7 @@ def tree_unflatten(cls, aux_data, params): def mean(self): if isinstance(self.base_dist, Normal): low_prob = jnp.exp(self.log_prob(self.low)) - return self.base_dist.loc + low_prob * self.base_dist.scale ** 2 + return self.base_dist.loc + low_prob * self.base_dist.scale**2 elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) else: @@ -114,7 +114,7 @@ def mean(self): def var(self): if isinstance(self.base_dist, Normal): low_prob = jnp.exp(self.log_prob(self.low)) - return (self.base_dist.scale ** 2) * ( + return (self.base_dist.scale**2) * ( 1 + (self.low - self.base_dist.loc) * low_prob - (low_prob * self.base_dist.scale) ** 2 @@ -189,7 +189,7 @@ def tree_unflatten(cls, aux_data, params): def mean(self): if isinstance(self.base_dist, Normal): high_prob = jnp.exp(self.log_prob(self.high)) - return self.base_dist.loc - high_prob * self.base_dist.scale ** 2 + return self.base_dist.loc - high_prob * self.base_dist.scale**2 elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) else: @@ -199,7 +199,7 @@ def mean(self): def var(self): if isinstance(self.base_dist, Normal): high_prob = jnp.exp(self.log_prob(self.high)) - return (self.base_dist.scale ** 2) * ( + return (self.base_dist.scale**2) * ( 1 - (self.high - self.base_dist.loc) * high_prob - (high_prob * self.base_dist.scale) ** 2 @@ -312,7 +312,7 @@ def mean(self): low_prob = jnp.exp(self.log_prob(self.low)) high_prob = jnp.exp(self.log_prob(self.high)) return ( - self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale ** 2 + self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2 ) elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) @@ -324,7 +324,7 @@ def var(self): if isinstance(self.base_dist, Normal): low_prob = jnp.exp(self.log_prob(self.low)) high_prob = jnp.exp(self.log_prob(self.high)) - return (self.base_dist.scale ** 2) * ( + return (self.base_dist.scale**2) * ( 1 + (self.low - 
self.base_dist.loc) * low_prob - (self.high - self.base_dist.loc) * high_prob @@ -398,7 +398,7 @@ def sample(self, key, sample_shape=()): key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,)) ) x = jnp.sum(x / denom, axis=-1) - return jnp.clip(x * (0.5 / jnp.pi ** 2), a_max=self.truncation_point) + return jnp.clip(x * (0.5 / jnp.pi**2), a_max=self.truncation_point) @validate_sample def log_prob(self, value): diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index c24c6bd96..e418e9ac6 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -372,7 +372,7 @@ def signed_stick_breaking_tril(t): # apply stick-breaking on the squared values; # we omit the step of computing s = z * z_cumprod by using the fact: # y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod) - z = r ** 2 + z = r**2 z1m_cumprod_sqrt = jnp.cumprod(jnp.sqrt(1 - z), axis=-1) pad_width = [(0, 0)] * z.ndim @@ -453,9 +453,9 @@ def _von_mises_centered(key, concentration, shape, dtype): } s_cutoff = s_cutoff_map.get(dtype) - r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2) + r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration**2) rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration) - s_exact = (1.0 + rho ** 2) / (2.0 * rho) + s_exact = (1.0 + rho**2) / (2.0 * rho) s_approximate = 1.0 / concentration diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py index 126f37c38..7567bac94 100644 --- a/numpyro/examples/datasets.py +++ b/numpyro/examples/datasets.py @@ -31,6 +31,11 @@ "baseball", ["https://d2hg8soec8ck9v.cloudfront.net/datasets/EfronMorrisBB.txt"] ) +BOSTON_HOUSING = dset( + "boston_housing", + ["https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data"], +) + COVTYPE = dset( "covtype", ["https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip"] ) @@ -107,6 +112,13 @@ def train_test_split(file): return {"train": (train, player_names), "test": (test, player_names)} +def _load_boston_housing(): + _download(BOSTON_HOUSING) + file_path = os.path.join(DATA_DIR, "housing.data") + data = np.loadtxt(file_path) + return {"train": (data[:, :-1], data[:, -1])} + + def _load_covtype(): _download(COVTYPE) @@ -295,6 +307,8 @@ def _load_9mers(): def _load(dset, num_datapoints=-1): if dset == BASEBALL: return _load_baseball() + elif dset == BOSTON_HOUSING: + return _load_boston_housing() elif dset == COVTYPE: return _load_covtype() elif dset == DIPPER_VOLE: diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 7932424c7..0b2c414bc 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -834,7 +834,7 @@ def scan_body(carry, eps_beta): p_grad = beta * grad(log_density)(z_half) v_hat = v_prev + eta * (q_grad + p_grad) z = z_half + v_hat * eta * inv_mass_matrix - v = gamma * v_hat + jnp.sqrt(1 - gamma ** 2) * eps + v = gamma * v_hat + jnp.sqrt(1 - gamma**2) * eps delta_ke = momentum_dist.log_prob(v_prev) - momentum_dist.log_prob(v_hat) log_factor = log_factor + delta_ke return (z, v, log_factor), None @@ -1048,7 +1048,7 @@ def __init__( ) def _get_posterior(self, *args, **kwargs): - rank = int(round(self.latent_dim ** 0.5)) if self.rank is None else self.rank + rank = int(round(self.latent_dim**0.5)) if self.rank is None else self.rank loc = numpyro.param("{}_loc".format(self.prefix), self._init_latent) cov_factor = numpyro.param( "{}_cov_factor".format(self.prefix), jnp.zeros((self.latent_dim, rank)) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 
8a9009e13..f14ae4919 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -730,7 +730,7 @@ def estimator(likelihoods, params, gibbs_state): diff = subsample_log_lik - proxy_value_subsample[name] unbiased_log_lik = proxy_value_all[name] + n * jnp.mean(diff) - variance = n ** 2 / m * jnp.var(diff) + variance = n**2 / m * jnp.var(diff) log_lik_sum += unbiased_log_lik - 0.5 * variance return log_lik_sum diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index 6e5836e03..c9e7d7624 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -121,7 +121,7 @@ def update_fn(g, state): # x_t = argmin{ g_avg . x + loc_t . |x - x0|^2 }, # hence x_t = x0 - g_avg / (2 * loc_t), # where loc_t := beta_t / t, beta_t := (gamma/2) * sqrt(t). - x_t = prox_center - (t ** 0.5) / gamma * g_avg + x_t = prox_center - (t**0.5) / gamma * g_avg # weight for the new x_t weight_t = t ** (-kappa) x_avg = (1 - weight_t) * x_avg + weight_t * x_t @@ -352,7 +352,7 @@ def _body_fn(state): # Note that the direction is -1 if delta_energy is `NaN`, which may be the # case for a diverging trajectory (e.g. in the case of evaluating log prob # of a value simulated using a large step size for a constrained sample site). - step_size = (2.0 ** direction) * step_size + step_size = (2.0**direction) * step_size r = momentum_generator(z, inverse_mass_matrix, rng_key_momentum) _, r_new, potential_energy_new, _ = vv_update( step_size, inverse_mass_matrix, (z, r, potential_energy, z_grad) @@ -992,7 +992,7 @@ def _iterative_build_subtree( r_ckpts, r_sum_ckpts, ): - max_num_proposals = 2 ** prototype_tree.depth + max_num_proposals = 2**prototype_tree.depth def _cond_fn(state): tree, turning, _, _, _ = state diff --git a/numpyro/infer/reparam.py b/numpyro/infer/reparam.py index ab4dfec29..7cc713404 100644 --- a/numpyro/infer/reparam.py +++ b/numpyro/infer/reparam.py @@ -103,7 +103,7 @@ def __call__(self, name, fn, obs): constraint=constraints.unit_interval, ) params["loc"] = fn.loc * centered - params["scale"] = fn.scale ** centered + params["scale"] = fn.scale**centered decentered_fn = self._wrap(type(fn)(**params), expand_shape, event_dim) # Draw decentered noise. 
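For context on the `reparam.py` hunk above: `centered` is constrained to the unit interval, and `loc = fn.loc * centered`, `scale = fn.scale**centered` interpolate between a fully non-centered and a fully centered parameterization in `LocScaleReparam`. A toy check of the endpoints (illustrative numbers, not from the library):

```python
# Endpoints of the loc/scale interpolation used by the reparameterizer (toy values).
loc, scale = 2.0, 3.0
for centered in (0.0, 1.0):
    print(centered, loc * centered, scale**centered)
# centered=0.0 -> (0.0, 1.0): fully non-centered; for a Normal site the base noise is standard Normal
# centered=1.0 -> (2.0, 3.0): fully centered; the base distribution keeps the original loc and scale
```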
diff --git a/numpyro/infer/sa.py b/numpyro/infer/sa.py index b63898932..9f0816f41 100644 --- a/numpyro/infer/sa.py +++ b/numpyro/infer/sa.py @@ -25,12 +25,12 @@ def _get_proposal_loc_and_scale(samples, loc, scale, new_sample): new_scale = cholesky_update(scale, new_sample - loc, weight) proposal_scale = cholesky_update(new_scale, samples - loc, -weight) proposal_scale = cholesky_update( - proposal_scale, new_sample - samples, -(weight ** 2) + proposal_scale, new_sample - samples, -(weight**2) ) else: var = jnp.square(scale) + weight * jnp.square(new_sample - loc) proposal_var = var - weight * jnp.square(samples - loc) - proposal_var = proposal_var - weight ** 2 * jnp.square(new_sample - samples) + proposal_var = proposal_var - weight**2 * jnp.square(new_sample - samples) proposal_scale = jnp.sqrt(proposal_var) proposal_loc = loc + weight * (new_sample - samples) diff --git a/test/contrib/einstein/test_einstein_util.py b/test/contrib/einstein/test_einstein_util.py index 6ed3e51aa..41cb2eff4 100644 --- a/test/contrib/einstein/test_einstein_util.py +++ b/test/contrib/einstein/test_einstein_util.py @@ -67,7 +67,7 @@ def test_safe_norm(axis, ord): assert_allclose( safe_norm(m, axis=axis), jnp.linalg.norm( - m + (1e-5 ** ord if axis is None and ord is not None else 0.0), + m + (1e-5**ord if axis is None and ord is not None else 0.0), ord=ord, axis=axis, ), diff --git a/test/contrib/test_nested_sampling.py b/test/contrib/test_nested_sampling.py index a45f2c067..2426d63e9 100644 --- a/test/contrib/test_nested_sampling.py +++ b/test/contrib/test_nested_sampling.py @@ -22,8 +22,8 @@ def get_moments(x): xxx = x * xx xxxx = xx * xx m2 = jnp.mean(xx, axis=0) - m3 = jnp.mean(xxx, axis=0) / m2 ** 1.5 - m4 = jnp.mean(xxxx, axis=0) / m2 ** 2 + m3 = jnp.mean(xxx, axis=0) / m2**1.5 + m4 = jnp.mean(xxxx, axis=0) / m2**2 return jnp.stack([m1, m2, m3, m4]) diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index ab2dcdebe..64641199d 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -362,12 +362,12 @@ def test_laplace_approximation_warning(): def model(x, y): a = numpyro.sample("a", dist.Normal(0, 10)) b = numpyro.sample("b", dist.Normal(0, 10).expand([3]).to_event()) - mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3 + mu = a + b[0] * x + b[1] * x**2 + b[2] * x**3 with numpyro.plate("N", len(x)): numpyro.sample("y", dist.Normal(mu, 0.001), obs=y) x = random.normal(random.PRNGKey(0), (3,)) - y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3 + y = 1 + 2 * x + 3 * x**2 + 4 * x**3 guide = AutoLaplaceApproximation(model) svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y) init_state = svi.init(random.PRNGKey(0)) diff --git a/test/infer/test_hmc_gibbs.py b/test/infer/test_hmc_gibbs.py index 94752d389..e7b67f524 100644 --- a/test/infer/test_hmc_gibbs.py +++ b/test/infer/test_hmc_gibbs.py @@ -206,7 +206,7 @@ def test_discrete_gibbs_enum(kernel, inner_kernel, kwargs): def model(): numpyro.sample("x", dist.Bernoulli(0.7), infer={"enumerate": "parallel"}) y = numpyro.sample("y", dist.Binomial(10, 0.3)) - numpyro.deterministic("y2", y ** 2) + numpyro.deterministic("y2", y**2) sampler = kernel(inner_kernel(model), **kwargs) mcmc = MCMC(sampler, num_warmup=1000, num_samples=10000, progress_bar=False) diff --git a/test/infer/test_hmc_util.py b/test/infer/test_hmc_util.py index a507287a1..762e31b34 100644 --- a/test/infer/test_hmc_util.py +++ b/test/infer/test_hmc_util.py @@ -158,7 +158,7 @@ class CircularPlanetaryMotion(object): @staticmethod def kinetic_fn(m_inv, p): z 
= jnp.stack([p["x"], p["y"]], axis=-1) - return 0.5 * jnp.dot(m_inv, z ** 2) + return 0.5 * jnp.dot(m_inv, z**2) @staticmethod def potential_fn(q): @@ -231,10 +231,10 @@ def get_final_state(model, step_size, num_steps, q_i, p_i): @pytest.mark.parametrize("init_step_size", [0.1, 10.0]) def test_find_reasonable_step_size(jitted, init_step_size): def kinetic_fn(m_inv, p): - return 0.5 * jnp.sum(m_inv * p ** 2) + return 0.5 * jnp.sum(m_inv * p**2) def potential_fn(q): - return 0.5 * q ** 2 + return 0.5 * q**2 p_generator = lambda prototype, m_inv, rng_key: 1.0 # noqa: E731 q = 0.0 @@ -404,10 +404,10 @@ def test_is_iterative_turning(ckpt_idxs, expected_turning): @pytest.mark.parametrize("step_size", [0.01, 1.0, 100.0]) def test_build_tree(step_size): def kinetic_fn(m_inv, p): - return 0.5 * jnp.sum(m_inv * p ** 2) + return 0.5 * jnp.sum(m_inv * p**2) def potential_fn(q): - return 0.5 * q ** 2 + return 0.5 * q**2 vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn) vv_state = vv_init(0.0, 1.0) diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index 2fb1fc50a..0bfefa12c 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -44,7 +44,7 @@ def model(data=None): with numpyro.plate("dim", 2): beta = numpyro.sample("beta", dist.Beta(1.0, 1.0)) with numpyro.plate("plate", N, dim=-2): - numpyro.deterministic("beta_sq", beta ** 2) + numpyro.deterministic("beta_sq", beta**2) with numpyro.plate("dim", 2): numpyro.sample("obs", dist.Bernoulli(beta), obs=data) @@ -78,7 +78,7 @@ def test_predictive_with_guide(): def model(data): f = numpyro.sample("beta", dist.Beta(1.0, 1.0)) with numpyro.plate("plate", 10): - numpyro.deterministic("beta_sq", f ** 2) + numpyro.deterministic("beta_sq", f**2) numpyro.sample("obs", dist.Bernoulli(f), obs=data) def guide(data): diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 39c1ac3ff..5cc78d21c 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -81,7 +81,7 @@ def potential_fn(z): mcmc.run(random.PRNGKey(0), init_params=init_params) samples = mcmc.get_samples() assert_allclose(jnp.mean(samples), true_mean, atol=0.02) - assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D ** 2 < 0.02 + assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02 @pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH]) diff --git a/test/infer/test_reparam.py b/test/infer/test_reparam.py index 2a061b2c4..f6e7135e2 100644 --- a/test/infer/test_reparam.py +++ b/test/infer/test_reparam.py @@ -33,8 +33,8 @@ def get_moments(x): xxx = x * xx xxxx = xx * xx m2 = jnp.mean(xx, axis=0) - m3 = jnp.mean(xxx, axis=0) / m2 ** 1.5 - m4 = jnp.mean(xxxx, axis=0) / m2 ** 2 + m3 = jnp.mean(xxx, axis=0) / m2**1.5 + m4 = jnp.mean(xxxx, axis=0) / m2**2 return jnp.stack([m1, m2, m3, m4]) diff --git a/test/test_distributions.py b/test/test_distributions.py index 05b289f7a..b46c539c1 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -937,11 +937,11 @@ def test_pathwise_gradient(jax_dist, params): def f(params): z = jax_dist(*params).sample(key=rng_key, sample_shape=(N,)) - return (z + z ** 2).mean(0) + return (z + z**2).mean(0) def g(params): d = jax_dist(*params) - return d.mean + d.variance + d.mean ** 2 + return d.mean + d.variance + d.mean**2 actual_grad = grad(f)(params) expected_grad = grad(g)(params) @@ -1131,7 +1131,7 @@ def test_independent_shape(jax_dist, sp_dist, params): def _tril_cholesky_to_tril_corr(x): w = vec_to_tril_matrix(x, diagonal=-1) - diag = jnp.sqrt(1 - 
jnp.sum(w ** 2, axis=-1)) + diag = jnp.sqrt(1 - jnp.sum(w**2, axis=-1)) cholesky = w + jnp.expand_dims(diag, axis=-1) * jnp.identity(w.shape[-1]) corr = jnp.matmul(cholesky, cholesky.T) return matrix_to_tril_vec(corr, diagonal=-1) @@ -1429,7 +1429,7 @@ def test_mean_var(jax_dist, sp_dist, params): # circular variance x, y = jnp.mean(jnp.cos(samples), 0), jnp.mean(jnp.sin(samples), 0) - expected_variance = 1 - jnp.sqrt(x ** 2 + y ** 2) + expected_variance = 1 - jnp.sqrt(x**2 + y**2) assert_allclose(d_jax.variance, expected_variance, rtol=0.05, atol=1e-2) elif jax_dist in [dist.SineBivariateVonMises]: phi_loc = _circ_mean(samples[..., 0]) diff --git a/test/test_examples.py b/test/test_examples.py index 64bd94606..59c300fd5 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -54,6 +54,7 @@ ), "sparse_regression.py --num-samples 10 --num-warmup 10 --num-data 10 --num-dimensions 10", "ssbvm_mixture.py --num-samples 10 --num-warmup 10", + "stein_bnn.py --max-iter 10 --subsample-size 10 --num-particles 5", "stochastic_volatility.py --num-samples 100 --num-warmup 100", "ucbadmit.py --num-chains 2", "vae.py -n 1", diff --git a/test/test_handlers.py b/test/test_handlers.py index 6d7e73a60..bf0c98378 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -224,7 +224,7 @@ def model_nested_plates_0(): with numpyro.plate("inner", 5): y = numpyro.sample("x", dist.Normal(0.0, 1.0)) assert y.shape == (5, 10) - z = numpyro.deterministic("z", x ** 2) + z = numpyro.deterministic("z", x**2) assert z.shape == (10,) @@ -235,7 +235,7 @@ def model_nested_plates_1(): with numpyro.plate("inner", 5): y = numpyro.sample("x", dist.Normal(0.0, 1.0)) assert y.shape == (10, 5) - z = numpyro.deterministic("z", x ** 2) + z = numpyro.deterministic("z", x**2) assert z.shape == (10, 1) @@ -248,7 +248,7 @@ def model_nested_plates_2(): with inner: y = numpyro.sample("y", dist.Normal(0.0, 1.0)) assert y.shape == (5, 1, 1) - z = numpyro.deterministic("z", x ** 2) + z = numpyro.deterministic("z", x**2) assert z.shape == (10,) with outer, inner: @@ -275,7 +275,7 @@ def model_dist_batch_shape(): with inner: y = numpyro.sample("y", dist.Normal(0.0, jnp.ones(10))) assert y.shape == (5, 1, 10) - z = numpyro.deterministic("z", x ** 2) + z = numpyro.deterministic("z", x**2) assert z.shape == (10,) with outer, inner: @@ -292,7 +292,7 @@ def model_subsample_1(): with inner: y = numpyro.sample("y", dist.Normal(0.0, 1.0)) assert y.shape == (5, 1, 1) - z = numpyro.deterministic("z", x ** 2) + z = numpyro.deterministic("z", x**2) assert z.shape == (10,) with outer, inner: @@ -310,7 +310,7 @@ def model_subsample_2(): with inner: y = numpyro.sample("y", dist.Normal(0.0, 1.0)) assert y.shape == (5, 1, 1) - z = numpyro.deterministic("z", x ** 2) + z = numpyro.deterministic("z", x**2) assert z.shape == (10,) with outer, inner:
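The new `BOSTON_HOUSING` entry in `numpyro/examples/datasets.py` is what `examples/stein_bnn.py` consumes, and the `test_examples.py` line above exercises the script with small settings. A minimal sketch of loading the dataset directly, mirroring `load_data()` in the example (exact shapes depend on the downloaded UCI file; the last column is the target):

```python
from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset

# mirrors examples/stein_bnn.py: fetch() returns the feature columns and the
# target column (median house value) from housing.data
_, fetch = load_dataset(BOSTON_HOUSING, shuffle=False)
x, y = fetch()
print(x.shape, y.shape)

# the smoke test added to test_examples.py runs the example itself as:
#   stein_bnn.py --max-iter 10 --subsample-size 10 --num-particles 5
```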