Skip to content

Commit

Permalink
Fix AR2 modeling (pyro-ppl#1350)
Browse files Browse the repository at this point in the history
* fix ar2 modeling.

* update after feedback

* fix half normal
  • Loading branch information
hesenp authored Mar 10, 2022
1 parent 3de2797 commit 89e323e
Showing 1 changed file with 42 additions and 44 deletions.
86 changes: 42 additions & 44 deletions examples/ar2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
In this example we show how to use ``jax.lax.scan``
to avoid writing a (slow) Python for-loop. In this toy
example, with ``--num-data=1000``, the improvement is
of almost 10x.
of almost almost 3x.
To demonstrate, we will be implementing an AR2 process.
The idea is that we have some times series
Expand All @@ -34,41 +34,49 @@
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import jax
from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist

matplotlib.use("Agg")

def ar2_scan(y):
alpha_1 = numpyro.sample("alpha_1", dist.Normal(0, 1))
alpha_2 = numpyro.sample("alpha_2", dist.Normal(0, 1))
const = numpyro.sample("const", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.HalfNormal(1))

def transition(carry, _):
y_prev, y_prev_prev = carry
m_t = const + alpha_1 * y_prev + alpha_2 * y_prev_prev
y_t = numpyro.sample("y", dist.Normal(m_t, sigma))
carry = (y_t, y_prev)
return carry, None

timesteps = jnp.arange(y.shape[0] - 2)
init = (y[1], y[0])

with numpyro.handlers.condition(data={"y": y[2:]}):
scan(transition, init, timesteps)

def ar2(y, unroll_loop=False):

def ar2_for_loop(y):
alpha_1 = numpyro.sample("alpha_1", dist.Normal(0, 1))
alpha_2 = numpyro.sample("alpha_2", dist.Normal(0, 1))
const = numpyro.sample("const", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.Normal(0, 1))

def transition_fn(carry, y):
y_1, y_2 = carry
pred = const + alpha_1 * y_1 + alpha_2 * y_2
return (y, y_1), pred

if unroll_loop:
preds = []
for i in range(2, len(y)):
preds.append(const + alpha_1 * y[i - 1] + alpha_2 * y[i - 2])
preds = jnp.asarray(preds)
else:
_, preds = jax.lax.scan(transition_fn, (y[1], y[0]), y[2:])
sigma = numpyro.sample("sigma", dist.HalfNormal(1))

mu = numpyro.deterministic("mu", preds)
numpyro.sample("obs", dist.Normal(mu, sigma), obs=y[2:])
y_prev = y[1]
y_prev_prev = y[0]

for i in range(2, len(y)):
m_t = const + alpha_1 * y_prev + alpha_2 * y_prev_prev
y_t = numpyro.sample("y_{}".format(i), dist.Normal(m_t, sigma), obs=y[i])
y_prev_prev = y_prev
y_prev = y_t


def run_inference(model, args, rng_key, y):
Expand All @@ -81,7 +89,7 @@ def run_inference(model, args, rng_key, y):
num_chains=args.num_chains,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(rng_key, y=y, unroll_loop=args.unroll_loop)
mcmc.run(rng_key, y=y)
mcmc.print_summary()
print("\nMCMC elapsed time:", time.time() - start)
return mcmc.get_samples()
Expand All @@ -90,29 +98,19 @@ def run_inference(model, args, rng_key, y):
def main(args):
# generate artifical dataset
num_data = args.num_data
t = np.arange(0, num_data)
y = np.sin(t) + np.random.randn(num_data) * 0.1
rng_key = jax.random.PRNGKey(0)
t = jnp.arange(0, num_data)
y = jnp.sin(t) + random.normal(rng_key, (num_data,)) * 0.1

# do inference
rng_key, _ = random.split(random.PRNGKey(0))
samples = run_inference(ar2, args, rng_key, y)

# do prediction
mean_prediction = samples["mu"].mean(axis=0)

# make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

# plot training data
ax.plot(t, y, color="blue", label="True values")
# plot mean prediction
# note that we can't make predictions for the first two points,
# because they don't have lagged values to use for prediction.
ax.plot(t[2:], mean_prediction, color="orange", label="Mean predictions")
ax.set(xlabel="time", ylabel="y", title="AR2 process")
ax.legend()
if args.unroll_loop:
# slower
model = ar2_for_loop
else:
# faster
model = ar2_scan

plt.savefig("ar2_plot.pdf")
run_inference(model, args, rng_key, y)


if __name__ == "__main__":
Expand Down

0 comments on commit 89e323e

Please sign in to comment.