4 changes: 2 additions & 2 deletions viabel/_utils.py
@@ -5,7 +5,7 @@
from hashlib import md5

import autograd.numpy as np
import pystan
import stan


def vectorize_if_needed(f, a, axis=-1):
@@ -75,7 +75,7 @@ def StanModel_cache(model_code=None, model_name=None, **kwargs):
with open(cache_file, 'rb') as f:
sm = pickle.load(f)
else:
sm = pystan.StanModel(model_code=model_code, model_name=model_name)
sm = stan.build(program_code=model_code)
with open(cache_file, 'wb') as f:
pickle.dump(sm, f)

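For reference (not part of the diff): PyStan 3 binds data to the model at build time rather than at sampling time, which is why the test further down passes `data` to `stan.build`. A minimal sketch of the new call pattern, using a placeholder model:

```python
import stan

program_code = """
data { int<lower=0> N; vector[N] y; }
parameters { real mu; }
model { y ~ normal(mu, 1); }
"""

# PyStan 3 compiles the program and binds the data in one step;
# draws then come from the returned model's .sample() method.
posterior = stan.build(program_code, data={"N": 3, "y": [0.1, -0.2, 0.3]})
fit = posterior.sample(num_chains=1, num_samples=100)
```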
17 changes: 9 additions & 8 deletions viabel/models.py
@@ -77,28 +77,29 @@ def set_inverse_temperature(self, inverse_temp):
raise NotImplementedError()


def _make_stan_log_density(fitobj):
def _make_stan_log_density(model):
@primitive
def log_density(x):
return vectorize_if_needed(fitobj.log_prob, x)
return vectorize_if_needed(lambda t: model.log_prob(t.tolist()), x)

def log_density_vjp(ans, x):
return lambda g: ensure_2d(g) * vectorize_if_needed(fitobj.grad_log_prob, x)
return lambda g: ensure_2d(
g) * vectorize_if_needed(lambda t: model.grad_log_prob(t.tolist()), x)
defvjp(log_density, log_density_vjp)
return log_density
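A quick sketch of what the `primitive`/`defvjp` pairing above buys: autograd can differentiate through the otherwise opaque Stan log density. Everything named below is an assumption for illustration; `model` stands for a `stan.build(...)` result exposing the `log_prob`/`grad_log_prob` methods this hunk relies on, and the dimension is taken to be 2.

```python
import autograd.numpy as np
from autograd import grad

log_density = _make_stan_log_density(model)   # `model` assumed to come from stan.build(...)

samples = np.zeros((5, 2))                    # one unconstrained parameter vector per row
values = log_density(samples)                 # forward pass calls model.log_prob row by row
grads = grad(lambda s: np.sum(log_density(s)))(samples)  # backward pass uses the registered VJP
```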


class StanModel(Model):
"""Class that encapsulates a PyStan model."""

def __init__(self, fit):
def __init__(self, model):
"""
Parameters
----------
fit : `StanFit4model` object
model : `stan.Model` object
"""
self._fit = fit
super().__init__(_make_stan_log_density(fit))
self._model = model
super().__init__(_make_stan_log_density(model))

def constrain(self, model_param):
return self._fit.constrain_pars(model_param)
return self._model.constrain_pars(model_param.tolist())
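Downstream, the wrapper is now constructed directly from the built model rather than from a fit object, as the test change below also shows. A minimal sketch, with the import path and the placeholder names (`program_code`, `data`, `unconstrained_samples`) assumed:

```python
import stan
from viabel import models   # import path assumed

posterior = stan.build(program_code=program_code, data=data)  # placeholders, as in the tests below
stan_model = models.StanModel(posterior)                      # wraps log_prob/grad_log_prob for autograd
log_p = stan_model(unconstrained_samples)                     # Model.__call__ evaluates the wrapped log density
```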
88 changes: 59 additions & 29 deletions viabel/objectives.py
@@ -138,8 +138,9 @@ class DISInclusiveKL(StochasticVariationalObjective):
"""Inclusive Kullback-Leibler divergence using Distilled Importance Sampling."""

def __init__(self, approx, model, num_mc_samples, ess_target,
temper_prior, temper_prior_params, use_resampling=True,
num_resampling_batches=1, w_clip_threshold=10):
temper_model, temper_model_sampler, temper_fn, temper_eps_init,
use_resampling=True, num_resampling_batches=1, w_clip_threshold=10,
pretrain_batch_size=100):
"""
Parameters
----------
@@ -151,35 +152,54 @@ def __init__(self, approx, model, num_mc_samples, ess_target,
ess_target: `int`
The ess target to adjust epsilon (M in the paper). It is also the number of
samples in resampling.
temper_prior: `Model` object
A prior distribution to temper the model. Typically multivariate normal.
temper_prior_params: `numpy.ndarray` object
Parameters for the temper prior. Typically mean 0 and variance 1.
temper_model: `Model` object
A distribution to temper the model. Typically multivariate normal.
temper_model_sampler: `function` or `None`
Returns n samples from the temper model. Needed for pretraining.
temper_fn: `function`
A function mapping (log_model_density, log_temper_model_density, eps) to a tempered log density.
temper_eps_init: `float`
The initial value of the tempering parameter eps.
use_resampling: `bool`
Whether to use resampling.
num_resampling_batches: `int`
Number of resampling batches. The resampling batch size is `max(1, ess_target // num_resampling_batches)`.
w_clip_threshold: `float`
The maximum weight.
pretrain_batch_size: `int`
The batch size for pretraining.
"""
self._ess_target = ess_target
self._w_clip_threshold = w_clip_threshold
self._max_bisection_its = 50
self._max_eps = self._eps = 1
self._max_eps = self._eps = temper_eps_init
self._use_resampling = use_resampling
self._num_resampling_batches = num_resampling_batches
self._resampling_batch_size = max(1, self._ess_target // num_resampling_batches)
self._objective_step = 0

self._tempered_model_log_pdf = lambda eps, samples, log_p_unnormalized: (
eps * temper_prior.log_density(temper_prior_params, samples)
+ (1 - eps) * log_p_unnormalized)
self._model = model
self._temper_model = temper_model
self._temper_model_sampler = temper_model_sampler
self._temper_fn = temper_fn

self._pretraining = False
self._pretrain_batch_size = pretrain_batch_size
super().__init__(approx, model, num_mc_samples)

def pretrain(self):
"""Run pretraining stage."""
self._pretraining = True

def regular_train(self):
"""Run training stage."""
self._pretraining = False

def _get_weights(self, eps, samples, log_p_unnormalized, log_q):
"""Calculates normalised importance sampling weights"""
logw = self._tempered_model_log_pdf(eps, samples, log_p_unnormalized) - log_q
logw = self._temper_fn(log_p_unnormalized, self._temper_model(samples), eps) - log_q
max_logw = np.max(logw)
logw -= max_logw
if max_logw == -np.inf:
raise ValueError('All weights zero! '
+ 'Suggests overflow in importance density.')
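The new tempering interface is easiest to read off a concrete instantiation. The sketch below mirrors the updated `test_DISInclusiveKL` at the bottom of this diff (standard-normal tempering with geometric mixing in log space); `approx` and `model` are assumed to already exist, so this is illustrative rather than prescribed usage.

```python
import numpy as np
import scipy.stats

from viabel.objectives import DISInclusiveKL   # import path assumed

dim = 2
temper_dist = scipy.stats.multivariate_normal(mean=np.zeros(dim), cov=np.eye(dim))

objective = DISInclusiveKL(
    approx, model, num_mc_samples=100, ess_target=50,
    temper_model=temper_dist.logpdf,                          # log density of the tempering distribution
    temper_model_sampler=lambda n: temper_dist.rvs(size=n),   # draws used during pretraining
    temper_fn=lambda log_p, log_t, eps: eps * log_t + (1 - eps) * log_p,  # geometric tempering
    temper_eps_init=1.0,
)
```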
@@ -201,27 +221,27 @@ def _get_eps_and_weights(self, eps_guess, samples, log_p_unnormalized, log_q):
Returns new epsilon value and corresponding ESS and normalised importance sampling weights.
"""

lower = 0.
upper = eps_guess
eps_guess = (lower + upper) / 2.
for i in range(self._max_bisection_its):
w = self._get_weights(eps_guess, samples, log_p_unnormalized, log_q)
ess = self._get_ess(w)
if ess > self._ess_target:
upper = eps_guess
else:
lower = eps_guess
if not self._pretraining:
lower = 0.
upper = eps_guess
eps_guess = (lower + upper) / 2.
for i in range(self._max_bisection_its):
w = self._get_weights(eps_guess, samples, log_p_unnormalized, log_q)
ess = self._get_ess(w)
if ess > self._ess_target:
upper = eps_guess
else:
lower = eps_guess
eps_guess = (lower + upper) / 2.
# Consider returning extreme epsilon values if they are still endpoints
if lower == 0.:
eps_guess = 0.
if upper == self._max_eps:
eps_guess = self._max_eps

w = self._get_weights(eps_guess, samples, log_p_unnormalized, log_q)
ess = self._get_ess(w)

# Consider returning extreme epsilon values if they are still endpoints
if lower == 0.:
eps_guess = 0.
if upper == self._max_eps:
eps_guess = self._max_eps

return eps_guess, ess, w

def _clip_weights(self, w):
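`_get_ess` itself is not part of this diff; the bisection above presumably targets the standard effective-sample-size estimate for (possibly unnormalised) importance weights, sketched here for reference:

```python
import autograd.numpy as np

def effective_sample_size(w):
    # Standard ESS estimate for importance weights; assumption: this matches
    # the _get_ess used by the bisection above, which is not shown in the diff.
    return np.sum(w) ** 2 / np.sum(w ** 2)
```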
Expand All @@ -247,12 +267,21 @@ def _update_objective_and_grad(self):
approx = self.approx

def variational_objective(var_param):
if self._pretraining:
samples = self._temper_model_sampler(self._pretrain_batch_size)
log_q = approx.log_density(var_param, getval(samples))
_, self._ess, _ = self._get_eps_and_weights(self._max_eps, samples, 0, log_q)
return np.mean(-log_q)

if not self._use_resampling or self._objective_step % self._num_resampling_batches == 0:
self._state_samples = getval(approx.sample(var_param, self.num_mc_samples))
self._state_log_q = approx.log_density(var_param, self._state_samples)
self._state_log_p_unnormalized = self.model(self._state_samples)

self._eps, ess, w = self._get_eps_and_weights(
# TODO double check
self._state_log_p_unnormalized -= np.max(self._state_log_p_unnormalized)

self._eps, self._ess, w = self._get_eps_and_weights(
self._eps, self._state_samples, self._state_log_p_unnormalized, self._state_log_q)
self._state_w_clipped = self._clip_weights(w)
self._state_w_sum = np.sum(self._state_w_clipped)
Expand All @@ -261,7 +290,8 @@ def variational_objective(var_param):
self._objective_step += 1

if not self._use_resampling:
return -np.inner(getval(self._state_w_clipped), self._state_log_q) / self.num_mc_samples
return -np.inner(getval(self._state_w_clipped),
self._state_log_q) / self.num_mc_samples
else:
indices = np.random.choice(self.num_mc_samples,
size=self._resampling_batch_size, p=getval(self._state_w_normalized))
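The pretraining flag implies a two-stage fit: first match the approximation to the temper model, then optimise the tempered importance-weighted objective. The optimiser loop is outside this diff, so `run_optimizer` and `init_param` below are purely placeholder names.

```python
# `objective` constructed as in the DISInclusiveKL sketch above.
objective.pretrain()        # objective reduces to mean(-log q) over temper-model draws
var_param = run_optimizer(objective, init_param, n_iters=500)    # placeholder optimizer call

objective.regular_train()   # switch to the DIS importance-weighted objective
var_param = run_optimizer(objective, var_param, n_iters=5000)    # placeholder optimizer call
```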
23 changes: 11 additions & 12 deletions viabel/tests/test_models.py
@@ -2,8 +2,8 @@

import autograd.numpy as anp
import numpy as np
import pystan
import pytest
import stan
from autograd.scipy.stats import norm
from autograd.test_util import check_vjp

@@ -55,25 +55,24 @@ def log_p(x):


def test_StanModel():
np.random.seed(5039)
beta_gen = np.array([-2, 1])
N = 25
x = np.random.randn(N, 2).dot(np.array([[1, .75], [.75, 1]]))
y_raw = x.dot(beta_gen) + np.random.standard_t(40, N)
y = y_raw - np.mean(y_raw)
data = dict(N=N, x=x, y=y, df=40)

compiled_model_file = 'robust_reg_model.pkl'
try:
with open(compiled_model_file, 'rb') as f:
regression_model = pickle.load(f)
except BaseException: # pragma: no cover
regression_model = pystan.StanModel(model_code=test_model,
model_name='regression_model')
regression_model = stan.build(program_code=test_model, data=data)
with open('robust_reg_model.pkl', 'wb') as f:
pickle.dump(regression_model, f)
np.random.seed(5039)
beta_gen = np.array([-2, 1])
N = 25
x = np.random.randn(N, 2).dot(np.array([[1, .75], [.75, 1]]))
y_raw = x.dot(beta_gen) + np.random.standard_t(40, N)
y = y_raw - np.mean(y_raw)

data = dict(N=N, x=x, y=y, df=40)
fit = regression_model.sampling(data=data, iter=10, thin=1, chains=1)
model = models.StanModel(fit)
model = models.StanModel(regression_model)

x = 4 * np.random.randn(10, 2)
_test_model(model, x, False, dict(beta=x[0]))
12 changes: 10 additions & 2 deletions viabel/tests/test_objectives.py
@@ -1,5 +1,6 @@
import autograd.numpy as anp
import numpy as np
import scipy.stats
from autograd.scipy.stats import norm

from viabel.approximations import MFGaussian, MFStudentT
@@ -37,12 +38,19 @@ def test_ExclusiveKL():
def test_ExclusiveKL_path_deriv():
_test_objective(ExclusiveKL, 100, use_path_deriv=True)

# def __init__(self, approx, model, num_mc_samples, ess_target,
# temper_model, temper_model_sampler, temper_fn, temper_eps_init,
# use_resampling=True, num_resampling_batches=1, w_clip_threshold=10,
# pretrain_batch_size=100):

def test_DISInclusiveKL():
dim = 2

_test_objective(DISInclusiveKL, 100,
temper_prior=MFGaussian(dim),
temper_prior_params=np.concatenate([[0] * dim, [1] * dim]),
temper_model=scipy.stats.multivariate_normal(mean=[0]*dim, cov=np.diag([1]*dim)).logpdf,
temper_model_sampler=lambda n: np.random.multivariate_normal([0] * dim, np.diag([1] * dim), size=n),
temper_fn=lambda model_logpdf, temper_logpdf, eps: temper_logpdf * eps + model_logpdf * (1 - eps),
temper_eps_init=1,
ess_target=50)

