Example: Bayesian NN with SteinVI (pyro-ppl#1297)
* added stein example

* added test case

* added stein bnn to docs.

* moved stein_bnn to other inference algorithms in docs

* Added correct plating for model in `stein_bnn.py`. Works with latest pyro-ppl#833.

* Add some doctests to transforms  (pyro-ppl#1300)

* add some doctest to transforms

* make format
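
For a flavor of what those doctests exercise, a transform round-trip looks roughly like this (a minimal sketch; AffineTransform is an existing numpyro transform, the values are illustrative):

    from numpyro.distributions.transforms import AffineTransform

    t = AffineTransform(loc=1.0, scale=2.0)
    y = t(0.5)    # forward: loc + scale * x = 2.0
    x = t.inv(y)  # inverse transform recovers 0.5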

* Tutorial for truncated distributions (pyro-ppl#1272)

* WIP Do not merge. Tutorial for truncated distributions

* WIP: Completed a few todos and fixed a few typos

* WIP: Completed main sections. References and part 5 still pending

* Added section on built in distributions and folded distributions

* Draft ready

* Remove M1-related warning from cell output

* Truncated distributions tutorial added to index

* Wrap latex equations in double dollar sign

* Fix broken markdown equations

* Added more details on folded distributions. Rearranged sections.

* Test: Change title level.

* Links now point to the docs instead of the source code.
Fixed some broken formatting of the titles.
Use different seeds for Prior/Inference/Prediction.
Changed models for inferring the truncation.
Fixed minor typos.

* Install numpyro and upgrade jax, jaxlib and matplotlib
Copy jax arrays before passing to matplotlib functions

* Clarified statement about the log_prob method in the
TruncatedDistribution class.

* Changed intro sentence to include folded distributions.

* Remove command for installing jax.
Use np.unique instead of jnp.unique
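
In miniature, the tutorial's two main ingredients look like this (a sketch; parameter values are illustrative):

    from jax import random
    import numpyro.distributions as dist

    # Truncate a Normal to [-1, inf); low/high set the truncation bounds.
    d = dist.TruncatedNormal(loc=0.0, scale=1.0, low=-1.0)
    samples = d.sample(random.PRNGKey(0), (1000,))
    assert samples.min() >= -1.0

    # Fold a Normal: the distribution of |x| for x ~ Normal(1, 1).
    folded = dist.FoldedDistribution(dist.Normal(1.0, 1.0))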

* Cast rate parameter to float (pyro-ppl#1301)

* Make potential_fn_gen and postprocess_fn_gen picklable (pyro-ppl#1302)

* add wrapper

* Make potential_fn_gen and postprocess_fn_gen picklable
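
The practical upshot, sketched under the assumption that this change lets a fitted MCMC object round-trip through pickle:

    import pickle

    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    def model():
        numpyro.sample("x", dist.Normal(0.0, 1.0))

    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0))
    # the closed-over potential_fn/postprocess_fn previously made this fail
    restored = pickle.loads(pickle.dumps(mcmc))
    print(restored.get_samples()["x"].shape)  # (100,)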

* Stein based inference (pyro-ppl#833)

* Added stein interface.

* Fixed style and removed from VI baseclass.

* Added reinit_guide.py

* Added license.

* added examples

* Added examples.

* Fixed some linting and LDA example; need to refactor wrapped_guide.

* Param sites now also get rng_keys; this should be reworked!

* Removed datasets and got LDA running.

* Fixed dimensionality bug for simplex support.

* Added code from refactor/einstein

* Fixed notebooks; todo: comment notebook.

* Factored initialization of `kernels.RandomFeatureKernel` into `Stein.init` and updated `test_kernels.test_kernel_forward` accordingly.

* Started testing.

* Removed assert from test_init_strategy.

* Skeleton test_stein.py

* Updated `test_stein/test_init`

* Added test_params and likelihood computation to lda.

* Fix init in MixtureKernel

* Notebook fixes

* debugging log likelihood

* WIP, move benchmarks to datasets

* trace guide to compute likelihood in lda.

* Debugging LDA

* Removed test_vi.py (will use test_stein.py), added `test_stein.test_update_evaluate`

* Cleaned test covered by `test_get_params`.

* Added skeleton and finished _param_size test.

* Fix LR example

* IRIS LR

* Fix Toy examples

* Added pinfo test.

* moved stein/test_kernels.py into stein/test_stein.py; updated `test_stein.test_apply_kernel`

* Ran black and removed lambdas from KERNEL_TEST_CASE.

* Added `test_stein.test_sp_mcmc` and removed calls to jnp.random.shuffle (deprecated).

* Added skeleton test for test_score_sp_mcmc.

* Fixed overwriting kval in `test_stein.test_apply_kernel`

* Fixed lint

* Fixed lint.

* Added stein_loss test.

* Factored vi source and test_vi out of einstein.

* updated with black.

* Figured out likelihood for LDA (need to change to compute likelihood instead of ELBO)

* Added perplexity to LDA.

* Fixed log position for perplexity.

* Refactored callbacks and added `test_checkpoint`.

* Fixed imports

* Reverted LDA to working version.

* Added callback tests

* Return loss history for `stein.run`

* Added visual to LDA.

* Fixed return for `run_lda`

* Added missing topic num 20.

* Added todo

* Cleaned 1d_mm stein notebook.

* Updated 2d gaussian notebook.

* Add description to SVGD.

* Updated `RBF_kernel` to work with one particle and added kernels notebook.

* Fixed bug in bandwidth of RBF_kernel

* SVI reproducing result from SVGD paper.

* Better learning rate for SVI.

* larger network

* Updated predictive to allow for particle methods.

* Removed TODO and fixed learning rate.

* EinStein outperforms SVGD

* Latest working.

* Fixed VI without progressbar.

* Fixed mini batching for VI.

* Added kernel visualization.

* Init to sample for Bayesian networks.

* TODO predict shape.

* Added scaling to plate primitive.

* Fixed enumeration in Stein and added subsample_scale to funsor.plate.

* Debugging LDA

* Debugging lda

* Debugging merge.

* Updated jacobian computation in Stein.

* Fixed issue with nested parameters for stein grad.

* Fixed issue with nested parameters for stein grad.

* Added NLL to DMM and predictions.

* renaming and removing benchmark code

* Cleaning branch from benchmarking.

* Removed prediction from DMM.

* Changed to syntax compatible with older Python

* Fixed lint.

* Fixed reinit warning for `init_to_uniform`.

* updated to use black[jupyter]

* Added licenses.

* Added a smoke test for SteinVI

* Factored out Stein point MCMC

* Factored out VI from EinStein.

* Updated stein_kernels.ipynb and removed debugging pred_prey.

* Removed `mixture.py`; use `mixtures.py` instead.

* Fixed lint.

* Added examples to docs build.

* Fixed stale import in `hmc.py` docstring.

* Removed stein point test cases.

* Changed Predictive to only check for guided models with particles.

* Fixed lint.

* Changed `reinit_guide` to add rng_keys for reinitialization.

* Added boston pricing dataset. Commented stein bnn example. Added `stein_bnn.py` to `test_examples.py`.

* Removed empty line from `test_examples.py`.

* Fixed lint.

* Changed `stein_mixture_dmm.py` to use new signature and run method.

* Added some comments and fixed `stein_mixture_dmm.py` to use new signature.

* Fixed `event_shape` and `support` for `Sine`, `DoubleBanana`, and `Star` distributions in `stein_2d_dists.py`.

* Removed notebooks from initial PR and updated stein_2d_toy.py to new run signature.

* Parameterize `gru_dim` in `stein_mixture_dmm.py`.

* Fixed steinvi to use pyro-ppl#1263; TODO: update examples.

* Removed init_with_noise.

* removed stein_bnn and changed `examples/datasets.py` to upstream

* removed examples

* Removed stein examples from docs.

* renamed einstein.utils to einstein.util.

* updated testing

* Changed test to use auto_guide `init_loc_fn`.

* removed `numpyro/util/ravel_pytree`

* removed unused imports in numpyro/util

* Added initialization to kernels in `test_einstein_kernels.py`

* Changed kernel test to use np.arrays at global level.

* changed jnp arrays to np arrays in tests; reverted subsample scale.

* added docstring to `einstein/util/batch_ravel_pytree`

Co-authored-by: Ahmad Salim Al-Sibahi <[email protected]>
Co-authored-by: einsteinvi <[email protected]>
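
Distilled from the stein_bnn.py example added below, the new interface in miniature (the toy model and hyperparameters here are illustrative only):

    import jax.numpy as jnp
    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.einstein import RBFKernel, SteinVI
    from numpyro.infer import Trace_ELBO
    from numpyro.infer.autoguide import AutoDelta
    from numpyro.optim import Adagrad

    def model(x, y=None):
        w = numpyro.sample("w", dist.Normal(0.0, 1.0))
        with numpyro.plate("data", x.shape[0]):
            numpyro.sample("y", dist.Normal(w * x, 1.0), obs=y)

    x = jnp.linspace(-1.0, 1.0, 20)
    y = 0.5 * x

    stein = SteinVI(
        model,
        AutoDelta(model),  # each Stein particle is a point estimate of the latents
        Adagrad(0.05),
        Trace_ELBO(),
        RBFKernel(),
        num_particles=10,
    )
    result = stein.run(random.PRNGKey(0), 100, x, y)  # returns final state and loss history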

* Improve subsample warning keys (pyro-ppl#1303)

* Add ProvenanceArray to infer relational structure in a model (pyro-ppl#1248)

* Add provenance array

* Add tests for provenance

* run make format

* Workaround for not being able to eval_shape a distribution

* Make license

* add a clearer guide for rendering a model with scan

* fix bugs surfaced by the recent jax release

* Fix further failing tests

* Make sure to be able to render ImproperUniform and randomly initialized params

* port get_dependencies to numpyro
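
A minimal sketch of the ported helper (assuming it lands as numpyro.infer.inspect.get_dependencies, mirroring Pyro's):

    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer.inspect import get_dependencies

    def model():
        a = numpyro.sample("a", dist.Normal(0.0, 1.0))
        numpyro.sample("b", dist.Normal(a, 1.0))

    # provenance tracking infers, e.g., that "b" depends on "a" in the prior
    deps = get_dependencies(model)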

* tighten test_improper_normal bound (pyro-ppl#1307)

* Fix HMCECS multiple plates (pyro-ppl#1305)
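
For context, HMCECS composes with subsampled plates along these lines (a sketch with a single plate; the fix above concerns models with several):

    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import HMCECS, MCMC, NUTS

    def model(data):
        mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
        # energy-conserving subsampling draws minibatches inside the plate
        with numpyro.plate("n", data.shape[0], subsample_size=100):
            batch = numpyro.subsample(data, event_dim=0)
            numpyro.sample("obs", dist.Normal(mu, 1.0), obs=batch)

    data = random.normal(random.PRNGKey(0), (1000,)) + 1.0
    mcmc = MCMC(HMCECS(NUTS(model), num_blocks=10), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(1), data)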

* Add Kumaraswamy and relaxed Bernoulli distributions (pyro-ppl#1283)

* Add kumaraswamy and relaxed bernoulli distributions

* clean up the flag

* Require logits to be keyword argument

* make relaxed bernoulli have the same signature as Pyro

* fix docs build

* Fix rsample bug

* add a simpler test for Kumaraswamy
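
In brief, the two additions (a sketch assuming the Pyro-style constructors named in the commits):

    from jax import random
    import numpyro.distributions as dist

    # Kumaraswamy: a Beta-like distribution on (0, 1) with a closed-form CDF.
    k = dist.Kumaraswamy(concentration1=2.0, concentration0=5.0)
    x = k.sample(random.PRNGKey(0), (3,))
    print(k.log_prob(x))

    # RelaxedBernoulli: temperature comes first; logits must be a keyword argument.
    rb = dist.RelaxedBernoulli(temperature=0.5, logits=1.0)
    print(rb.sample(random.PRNGKey(1)))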

* Add various KL divergences for Gamma/Beta families (pyro-ppl#1284)

* Add new distributions and kl

* Add kumaraswamy and relaxed bernoulli distributions

* clean up the flag

* Require logits to be keyword argument

* make relaxed bernoulli have the same signature as Pyro

* fix docs build

* Fix rsample bug

* move the flag to the Kumaraswamy class for convenience
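
For example, the Gamma-Gamma divergence now has a closed form (a sketch; kl_divergence is assumed importable from numpyro.distributions.kl):

    import numpyro.distributions as dist
    from numpyro.distributions.kl import kl_divergence

    p, q = dist.Gamma(2.0, 1.0), dist.Gamma(3.0, 2.0)
    print(kl_divergence(p, q))  # closed-form KL(p || q) for the Gamma family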

* Add loose strategy for missing plates in MCMC (pyro-ppl#1304)

* Add loose strategy for MCMC

* merge svi and mcmc plate warning strategies

* fix failing tests

* validate model across ELBOs

* update vae example

* fix typos

* Fix failing tests

* skip prodlda test on CI

* Bump to 0.9.0 (pyro-ppl#1310)

* Add loose strategy for MCMC

* merge svi and mcmc plate warning strategies

* fix failing tests

* validate model across ELBOs

* update vae example

* fix typos

* Bump to version 0.9.0

* Fix failing tests

* Fix warnings in tests/examples

* relax funsor requirement

* Move optax_to_numpyro to optim

* skip prodlda test on CI

* added dimensions to plate and sqrt precision.

* fixed/added comments in stein_bnn.py and removed lr datasets.

* added comment to stein_bnn.py

* formatted files to black==22.1.0

Co-authored-by: Wataru Hashimoto <[email protected]>
Co-authored-by: Omar Sosa Rodríguez <[email protected]>
Co-authored-by: Vedran Hadziosmanovic <[email protected]>
Co-authored-by: Du Phan <[email protected]>
Co-authored-by: Ahmad Salim Al-Sibahi <[email protected]>
Co-authored-by: einsteinvi <[email protected]>
Co-authored-by: austereantelope <[email protected]>
8 people authored Jan 31, 2022
1 parent 7084aaa commit 2690521
Showing 35 changed files with 304 additions and 104 deletions.
12 changes: 6 additions & 6 deletions docs/source/conf.py
@@ -37,9 +37,9 @@

# -- Project information -----------------------------------------------------

project = u"NumPyro"
copyright = u"2019, Uber Technologies, Inc"
author = u"Uber AI Labs"
project = "NumPyro"
copyright = "2019, Uber Technologies, Inc"
author = "Uber AI Labs"

version = ""

@@ -280,14 +280,14 @@
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, "NumPyro.tex", u"NumPyro Documentation", u"Uber AI Labs", "manual")
(master_doc, "NumPyro.tex", "NumPyro Documentation", "Uber AI Labs", "manual")
]

# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "NumPyro", u"NumPyro Documentation", [author], 1)]
man_pages = [(master_doc, "NumPyro", "NumPyro Documentation", [author], 1)]

# -- Options for Texinfo output ----------------------------------------------

@@ -298,7 +298,7 @@
(
master_doc,
"NumPyro",
u"NumPyro Documentation",
"NumPyro Documentation",
author,
"NumPyro",
"Pyro PPL on Numpy",
3 changes: 2 additions & 1 deletion docs/source/index.rst
@@ -77,8 +77,9 @@ NumPyro documentation
:caption: Other Inference Algorithms
:name: other-inference-algorithms

examples/hmcecs
examples/covtype
examples/hmcecs
examples/stein_bnn


Indices and tables
4 changes: 2 additions & 2 deletions examples/hmm_enum.py
@@ -166,7 +166,7 @@ def transition_fn(carry, y):
# target hidden dimension.
def model_3(sequences, lengths, args, include_prior=True):
num_sequences, max_length, data_dim = sequences.shape
hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x
hidden_dim = int(args.hidden_dim**0.5) # split between w and x
with mask(mask=include_prior):
probs_w = numpyro.sample(
"probs_w", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
@@ -220,7 +220,7 @@ def transition_fn(carry, y):
# Factorial HMM, but this model has more parameters.
def model_4(sequences, lengths, args, include_prior=True):
num_sequences, max_length, data_dim = sequences.shape
hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x
hidden_dim = int(args.hidden_dim**0.5) # split between w and x
with mask(mask=include_prior):
probs_w = numpyro.sample(
"probs_w", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
4 changes: 2 additions & 2 deletions examples/hsgp.py
@@ -161,7 +161,7 @@ def make_birthdays_data_dict(data):
# --- Modelling utility functions --- #
def spectral_density(w, alpha, length):
c = alpha * jnp.sqrt(2 * jnp.pi) * length
e = jnp.exp(-0.5 * (length ** 2) * (w ** 2))
e = jnp.exp(-0.5 * (length**2) * (w**2))
return c * e


@@ -197,7 +197,7 @@ def diag_spectral_density_periodic(alpha, length, M):
a = length ** (-2)
J = jnp.arange(0, M)
c = jnp.where(J > 0, 2, 1)
q2 = (c * alpha ** 2 / jnp.exp(a)) * modified_bessel_first_kind(J, a)
q2 = (c * alpha**2 / jnp.exp(a)) * modified_bessel_first_kind(J, a)
return q2


8 changes: 4 additions & 4 deletions examples/sparse_regression.py
@@ -73,7 +73,7 @@ def model(X, Y, hypers):

# compute kernel
kX = kappa * X
k = kernel(kX, kX, eta1, eta2, hypers["c"]) + sigma ** 2 * jnp.eye(N)
k = kernel(kX, kX, eta1, eta2, hypers["c"]) + sigma**2 * jnp.eye(N)
assert k.shape == (N, N)

# sample Y according to the standard gaussian process formula
@@ -99,7 +99,7 @@ def compute_singleton_mean_variance(X, Y, dimension, msq, lam, eta1, xisq, c, si
kX = kappa * X
kprobe = kappa * probe

k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
k_xx_inv = jnp.linalg.inv(k_xx)
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
@@ -130,7 +130,7 @@ def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, si
kX = kappa * X
kprobe = kappa * probe

k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
k_xx_inv = jnp.linalg.inv(k_xx)
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
@@ -188,7 +188,7 @@ def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, sigma):
kX = kappa * X
kprobe = kappa * probe

k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
L = cho_factor(k_xx, lower=True)[0]
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
184 changes: 184 additions & 0 deletions examples/stein_bnn.py
@@ -0,0 +1,184 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Example: Bayesian Neural Network with SteinVI
=============================================
We demonstrate how to use SteinVI to predict housing prices with a Bayesian neural network (BNN)
on the Boston Housing dataset from the UCI regression benchmarks.
"""

import argparse
from collections import namedtuple
import datetime
from functools import partial
from time import time

from sklearn.model_selection import train_test_split

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro.distributions import Gamma, Normal
from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset
from numpyro.infer import Predictive, Trace_ELBO, init_to_uniform
from numpyro.infer.autoguide import AutoDelta
from numpyro.optim import Adagrad

DataState = namedtuple("data", ["xtr", "xte", "ytr", "yte"])


def load_data() -> DataState:
_, fetch = load_dataset(BOSTON_HOUSING, shuffle=False)
x, y = fetch()
xtr, xte, ytr, yte = train_test_split(x, y, train_size=0.90)

return DataState(*map(partial(jnp.array, dtype=float), (xtr, xte, ytr, yte)))


def normalize(val, mean=None, std=None):
"""Normalize data to zero mean, unit variance"""
if mean is None and std is None:
# Only use training data to estimate mean and std.
std = jnp.std(val, 0, keepdims=True)
std = jnp.where(std == 0, 1.0, std)
mean = jnp.mean(val, 0, keepdims=True)
return (val - mean) / std, mean, std


def model(x, y=None, hidden_dim=50, subsample_size=100):
"""BNN described in section 5 of [1].
**References:**
1. *Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm*
Qiang Liu and Dilin Wang (2016).
"""

prec_nn = numpyro.sample(
"prec_nn", Gamma(1.0, 0.1)
) # hyper prior for precision of nn weights and biases

n, m = x.shape

with numpyro.plate("l1_hidden", hidden_dim, dim=-1):
# prior l1 bias term
b1 = numpyro.sample(
"nn_b1",
Normal(
0.0,
1.0 / jnp.sqrt(prec_nn),
),
)
assert b1.shape == (hidden_dim,)

with numpyro.plate("l1_feat", m, dim=-2):
w1 = numpyro.sample(
"nn_w1", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
) # prior on l1 weights
assert w1.shape == (m, hidden_dim)

with numpyro.plate("l2_hidden", hidden_dim, dim=-1):
w2 = numpyro.sample(
"nn_w2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
) # prior on output weights

b2 = numpyro.sample(
"nn_b2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
) # prior on output bias term

# precision prior on observations
prec_obs = numpyro.sample("prec_obs", Gamma(1.0, 0.1))
with numpyro.plate(
"data",
x.shape[0],
subsample_size=subsample_size,
dim=-1,
):
batch_x = numpyro.subsample(x, event_dim=1)
if y is not None:
batch_y = numpyro.subsample(y, event_dim=0)
else:
batch_y = y

numpyro.sample(
"y",
Normal(
jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2, 1.0 / jnp.sqrt(prec_obs)
), # 1 hidden layer with ReLU activation
obs=batch_y,
)


def main(args):
data = load_data()

inf_key, pred_key, data_key = random.split(random.PRNGKey(args.rng_key), 3)
# normalize data and labels to zero mean unit variance!
x, xtr_mean, xtr_std = normalize(data.xtr)
y, ytr_mean, ytr_std = normalize(data.ytr)

rng_key, inf_key = random.split(inf_key)

stein = SteinVI(
model,
AutoDelta(model, init_loc_fn=partial(init_to_uniform, radius=0.1)),
Adagrad(0.05),
Trace_ELBO(20), # estimate elbo with 20 particles (not stein particles!)
RBFKernel(),
repulsion_temperature=args.repulsion,
num_particles=args.num_particles,
)
start = time()

# pass static values (shapes etc.) as keyword arguments!
result = stein.run(
rng_key,
args.max_iter,
x,
y,
hidden_dim=args.hidden_dim,
subsample_size=args.subsample_size,
progress_bar=args.progress_bar,
)
time_taken = time() - start

pred = Predictive(
model,
guide=stein.guide,
params=stein.get_params(result.state),
num_samples=1,
batch_ndims=1, # stein particle dimension
)
xte, _, _ = normalize(
data.xte, xtr_mean, xtr_std
)  # use train data statistics when assessing generalization
preds = pred(pred_key, xte, subsample_size=xte.shape[0])["y"].reshape(
-1, xte.shape[0]
)

y_pred = jnp.mean(preds, 0) * ytr_std + ytr_mean
rmse = jnp.sqrt(jnp.mean((y_pred - data.yte) ** 2))

print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}")
print(rf"RMSE: {rmse:.2f}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--subsample-size", type=int, default=100)
parser.add_argument("--max-iter", type=int, default=1000)
parser.add_argument("--repulsion", type=float, default=1.0)
parser.add_argument("--verbose", type=bool, default=True)
parser.add_argument("--num-particles", type=int, default=100)
parser.add_argument("--progress-bar", type=bool, default=True)
parser.add_argument("--rng-key", type=int, default=142)
parser.add_argument("--device", default="cpu", choices=["gpu", "cpu"])
parser.add_argument("--hidden-dim", default=50, type=int)

args = parser.parse_args()

numpyro.set_platform(args.device)

main(args)
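
For anyone who wants to try the new example, the argparse flags above suggest an invocation along the lines of python examples/stein_bnn.py --max-iter 1000 --num-particles 100 (assuming a source checkout with scikit-learn installed for the train/test split).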
8 changes: 4 additions & 4 deletions examples/thompson_sampling.py
@@ -34,7 +34,7 @@
# the function to be minimized. At y=0 to get a 1D cut at the origin
def ackley_1d(x, y=0):
out = (
-20 * jnp.exp(-0.2 * jnp.sqrt(0.5 * (x ** 2 + y ** 2)))
-20 * jnp.exp(-0.2 * jnp.sqrt(0.5 * (x**2 + y**2)))
- jnp.exp(0.5 * (jnp.cos(2 * jnp.pi * x) + jnp.cos(2 * jnp.pi * y)))
+ jnp.e
+ 20
@@ -45,7 +45,7 @@ def ackley_1d(x, y=0):
# matern kernel with nu = 5/2
def matern52_kernel(X, Z, var=1.0, length=0.5, jitter=1.0e-6):
d = jnp.sqrt(0.5) * jnp.sqrt(jnp.power((X[:, None] - Z), 2.0)) / length
k = var * (1 + d + (d ** 2) / 3) * jnp.exp(-d)
k = var * (1 + d + (d**2) / 3) * jnp.exp(-d)
if jitter:
# we are assuming a noise free process, but add a small jitter for numerical stability
k += jitter * jnp.eye(X.shape[0])
@@ -125,10 +125,10 @@ def predict(self, X, return_std=False):
if return_std:
return (
(mean * self.y_std) + self.y_mean,
jnp.sqrt(jnp.diag(K * self.y_std ** 2)),
jnp.sqrt(jnp.diag(K * self.y_std**2)),
)
else:
return (mean * self.y_std) + self.y_mean, K * self.y_std ** 2
return (mean * self.y_std) + self.y_mean, K * self.y_std**2

def sample_y(self, rng_key, X):
# get posterior mean and covariance
14 changes: 7 additions & 7 deletions notebooks/source/conf.py
@@ -66,9 +66,9 @@
master_doc = "index"

# General information about the project.
project = u"NumPyro Tutorials"
copyright = u"2019, Uber Technologies, Inc"
author = u"Uber AI Labs"
project = "NumPyro Tutorials"
copyright = "2019, Uber Technologies, Inc"
author = "Uber AI Labs"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@@ -164,8 +164,8 @@
(
master_doc,
"NumPyroTutorials.tex",
u"Numpyro Examples and Tutorials",
u"Uber AI Labs",
"Numpyro Examples and Tutorials",
"Uber AI Labs",
"manual",
)
]
@@ -176,7 +176,7 @@
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, "NumPyroTutorials", u"Numpyro Examples and Tutorials", [author], 1)
(master_doc, "NumPyroTutorials", "Numpyro Examples and Tutorials", [author], 1)
]


@@ -189,7 +189,7 @@
(
master_doc,
"NumPyroTutorials",
u"NumPyro Examples and Tutorials",
"NumPyro Examples and Tutorials",
author,
"NumPyroTutorials",
"One line description of project.",
6 changes: 3 additions & 3 deletions notebooks/source/time_series_forecasting.ipynb
@@ -204,8 +204,8 @@
"\n",
" def transition_fn(carry, t):\n",
" level, s, moving_sum = carry\n",
" season = s[0] * level ** pow_season\n",
" exp_val = level + coef_trend * level ** pow_trend + season\n",
" season = s[0] * level**pow_season\n",
" exp_val = level + coef_trend * level**pow_trend + season\n",
" exp_val = jnp.clip(exp_val, a_min=0)\n",
" # use expected vale when forecasting\n",
" y_t = jnp.where(t >= N, exp_val, y[t])\n",
@@ -222,7 +222,7 @@
" new_s = jnp.where(t >= N, s[0], new_s)\n",
" s = jnp.concatenate([s[1:], new_s[None]], axis=0)\n",
"\n",
" omega = sigma * exp_val ** powx + offset_sigma\n",
" omega = sigma * exp_val**powx + offset_sigma\n",
" y_ = numpyro.sample(\"y\", dist.StudentT(nu, exp_val, omega))\n",
"\n",
" return (level, s, moving_sum), y_\n",