diff --git a/viabel/_utils.py b/viabel/_utils.py
index 75b18422..e02bc403 100644
--- a/viabel/_utils.py
+++ b/viabel/_utils.py
@@ -5,7 +5,7 @@
 from hashlib import md5
 
 import autograd.numpy as np
-import pystan
+import stan
 
 
 def vectorize_if_needed(f, a, axis=-1):
@@ -75,7 +75,7 @@ def StanModel_cache(model_code=None, model_name=None, **kwargs):
         with open(cache_file, 'rb') as f:
             sm = pickle.load(f)
     else:
-        sm = pystan.StanModel(model_code=model_code, model_name=model_name)
+        sm = stan.build(program_code=model_code)
         with open(cache_file, 'wb') as f:
             pickle.dump(sm, f)
diff --git a/viabel/models.py b/viabel/models.py
index cf8e6347..561defb2 100644
--- a/viabel/models.py
+++ b/viabel/models.py
@@ -77,13 +77,14 @@ def set_inverse_temperature(self, inverse_temp):
         raise NotImplementedError()
 
 
-def _make_stan_log_density(fitobj):
+def _make_stan_log_density(model):
     @primitive
     def log_density(x):
-        return vectorize_if_needed(fitobj.log_prob, x)
+        return vectorize_if_needed(lambda t: model.log_prob(t.tolist()), x)
 
     def log_density_vjp(ans, x):
-        return lambda g: ensure_2d(g) * vectorize_if_needed(fitobj.grad_log_prob, x)
+        return lambda g: (ensure_2d(g)
+                          * vectorize_if_needed(lambda t: model.grad_log_prob(t.tolist()), x))
     defvjp(log_density, log_density_vjp)
     return log_density
 
@@ -91,14 +92,14 @@ def log_density_vjp(ans, x):
 class StanModel(Model):
     """Class that encapsulates a PyStan model."""
 
-    def __init__(self, fit):
+    def __init__(self, model):
         """
         Parameters
         ----------
-        fit : `StanFit4model` object
+        model : `stan.Model` object
         """
-        self._fit = fit
-        super().__init__(_make_stan_log_density(fit))
+        self._model = model
+        super().__init__(_make_stan_log_density(model))
 
     def constrain(self, model_param):
-        return self._fit.constrain_pars(model_param)
+        return self._model.constrain_pars(model_param.tolist())
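Note on the `_make_stan_log_density` change above: it relies on autograd's `primitive`/`defvjp` mechanism so that a log density computed outside autograd (here by the Stan model) can still be differentiated through. Below is a minimal, self-contained sketch of that pattern. The stand-in functions `external_log_prob` and `external_grad_log_prob` are illustrative placeholders for `model.log_prob` and `model.grad_log_prob`, and the batching handled by `vectorize_if_needed` and `ensure_2d` is omitted.

    # Sketch of the @primitive / defvjp pattern, with a NumPy standard-normal log
    # density standing in for the Stan model.
    import autograd.numpy as np
    from autograd import grad
    from autograd.extend import defvjp, primitive


    def external_log_prob(x):
        # Placeholder for model.log_prob: an opaque scalar log density.
        return -0.5 * float(np.dot(x, x))


    def external_grad_log_prob(x):
        # Placeholder for model.grad_log_prob: gradient of the log density above.
        return -np.asarray(x)


    @primitive
    def log_density(x):
        # autograd treats this call as a black box; gradients come from the VJP below.
        return external_log_prob(x)


    def log_density_vjp(ans, x):
        # Chain rule: scale the externally supplied gradient by the incoming cotangent g.
        return lambda g: g * external_grad_log_prob(x)


    defvjp(log_density, log_density_vjp)

    print(grad(log_density)(np.array([1.0, -2.0])))  # [-1.  2.]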
diff --git a/viabel/objectives.py b/viabel/objectives.py
index 35ac5cd3..53e7124f 100644
--- a/viabel/objectives.py
+++ b/viabel/objectives.py
@@ -138,8 +138,9 @@ class DISInclusiveKL(StochasticVariationalObjective):
     """Inclusive Kullback-Leibler divergence using Distilled Importance Sampling."""
 
     def __init__(self, approx, model, num_mc_samples, ess_target,
-                 temper_prior, temper_prior_params, use_resampling=True,
-                 num_resampling_batches=1, w_clip_threshold=10):
+                 temper_model, temper_model_sampler, temper_fn, temper_eps_init,
+                 use_resampling=True, num_resampling_batches=1, w_clip_threshold=10,
+                 pretrain_batch_size=100):
         """
         Parameters
         ----------
@@ -151,35 +152,54 @@ def __init__(self, approx, model, num_mc_samples, ess_target,
         ess_target: `int`
             The ess target to adjust epsilon (M in the paper). It is also the number of
             samples in resampling.
-        temper_prior: `Model` object
-            A prior distribution to temper the model. Typically multivariate normal.
-        temper_prior_params: `numpy.ndarray` object
-            Parameters for the temper prior. Typically mean 0 and variance 1.
+        temper_model: `Model` object or `function`
+            The log density used to temper the model, evaluated at an array of samples.
+            Typically a multivariate normal.
+        temper_model_sampler: `function` or `None`
+            Returns n samples from the temper model. Only needed for pretraining.
+        temper_fn: `function`
+            A function that maps (log_model_density, log_temper_model_density, eps) to a
+            tempered log density.
+        temper_eps_init: `float`
+            The initial value of the tempering parameter eps.
         use_resampling: `bool`
             Whether to use resampling.
         num_resampling_batches: `int`
             Number of resampling batches. The resampling batch is
             `max(1, ess_target / num_resampling_batches)`.
         w_clip_threshold: `float`
             The maximum weight.
+        pretrain_batch_size: `int`
+            The batch size for pretraining.
         """
         self._ess_target = ess_target
         self._w_clip_threshold = w_clip_threshold
         self._max_bisection_its = 50
-        self._max_eps = self._eps = 1
+        self._max_eps = self._eps = temper_eps_init
         self._use_resampling = use_resampling
         self._num_resampling_batches = num_resampling_batches
         self._resampling_batch_size = max(1, self._ess_target // num_resampling_batches)
         self._objective_step = 0
-        self._tempered_model_log_pdf = lambda eps, samples, log_p_unnormalized: (
-            eps * temper_prior.log_density(temper_prior_params, samples)
-            + (1 - eps) * log_p_unnormalized)
+        self._model = model
+        self._temper_model = temper_model
+        self._temper_model_sampler = temper_model_sampler
+        self._temper_fn = temper_fn
+
+        self._pretraining = False
+        self._pretrain_batch_size = pretrain_batch_size
 
         super().__init__(approx, model, num_mc_samples)
 
+    def pretrain(self):
+        """Switch to the pretraining stage."""
+        self._pretraining = True
+
+    def regular_train(self):
+        """Switch back to the regular training stage."""
+        self._pretraining = False
+
     def _get_weights(self, eps, samples, log_p_unnormalized, log_q):
         """Calculates normalised importance sampling weights"""
-        logw = self._tempered_model_log_pdf(eps, samples, log_p_unnormalized) - log_q
+        logw = self._temper_fn(log_p_unnormalized, self._temper_model(samples), eps) - log_q
         max_logw = np.max(logw)
+        logw -= max_logw
         if max_logw == -np.inf:
             raise ValueError('All weights zero! '
                              + 'Suggests overflow in importance density.')
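The weighting scheme in `_get_weights`, together with the ESS-targeted bisection over eps in `_get_eps_and_weights` (next hunk), boils down to a few lines of arithmetic. The sketch below is not viabel code: the 1-D Gaussian densities and the geometric tempering function are illustrative choices. It shows the log-space self-normalization and the effective sample size, `1 / sum(w**2)` for normalized weights, that the bisection compares against `ess_target`.

    import numpy as np
    from scipy import stats


    def temper_fn(log_p, log_t, eps):
        # Geometric bridge: eps=1 is the temper model, eps=0 is the model itself.
        return eps * log_t + (1 - eps) * log_p


    rng = np.random.default_rng(0)
    samples = rng.normal(size=1000)               # draws from a stand-in approximation q
    log_q = stats.norm(0, 1).logpdf(samples)      # log q(samples)
    log_p = stats.norm(2, 0.5).logpdf(samples)    # stand-in unnormalized model log density
    log_t = stats.norm(0, 3).logpdf(samples)      # stand-in temper-model log density

    for eps in (1.0, 0.5, 0.0):
        logw = temper_fn(log_p, log_t, eps) - log_q
        logw -= np.max(logw)                      # stabilize before exponentiating
        w = np.exp(logw)
        w /= np.sum(w)                            # self-normalized weights
        ess = 1.0 / np.sum(w ** 2)                # effective sample size
        print(f'eps={eps:.1f}  ESS={ess:.1f}')

Smaller eps puts more weight on the model and typically lowers the ESS, which is why the bisection looks for roughly the smallest eps whose ESS still reaches the target.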
@@ -201,27 +221,27 @@ def _get_eps_and_weights(self, eps_guess, samples, log_p_unnormalized, log_q):
         Returns new epsilon value and corresponding ESS and normalised importance
         sampling weights.
         """
-        lower = 0.
-        upper = eps_guess
-        eps_guess = (lower + upper) / 2.
-        for i in range(self._max_bisection_its):
-            w = self._get_weights(eps_guess, samples, log_p_unnormalized, log_q)
-            ess = self._get_ess(w)
-            if ess > self._ess_target:
-                upper = eps_guess
-            else:
-                lower = eps_guess
+        if not self._pretraining:
+            lower = 0.
+            upper = eps_guess
             eps_guess = (lower + upper) / 2.
+            for i in range(self._max_bisection_its):
+                w = self._get_weights(eps_guess, samples, log_p_unnormalized, log_q)
+                ess = self._get_ess(w)
+                if ess > self._ess_target:
+                    upper = eps_guess
+                else:
+                    lower = eps_guess
+                eps_guess = (lower + upper) / 2.
+            # Consider returning extreme epsilon values if they are still endpoints
+            if lower == 0.:
+                eps_guess = 0.
+            if upper == self._max_eps:
+                eps_guess = self._max_eps
 
         w = self._get_weights(eps_guess, samples, log_p_unnormalized, log_q)
         ess = self._get_ess(w)
-        # Consider returning extreme epsilon values if they are still endpoints
-        if lower == 0.:
-            eps_guess = 0.
-        if upper == self._max_eps:
-            eps_guess = self._max_eps
-
         return eps_guess, ess, w
 
     def _clip_weights(self, w):
@@ -247,12 +267,21 @@ def _update_objective_and_grad(self):
         approx = self.approx
 
         def variational_objective(var_param):
+            if self._pretraining:
+                samples = self._temper_model_sampler(self._pretrain_batch_size)
+                log_q = approx.log_density(var_param, getval(samples))
+                _, self._ess, _ = self._get_eps_and_weights(self._max_eps, samples, 0, log_q)
+                return np.mean(-log_q)
+
             if not self._use_resampling or self._objective_step % self._num_resampling_batches == 0:
                 self._state_samples = getval(approx.sample(var_param, self.num_mc_samples))
                 self._state_log_q = approx.log_density(var_param, self._state_samples)
                 self._state_log_p_unnormalized = self.model(self._state_samples)
-                self._eps, ess, w = self._get_eps_and_weights(
+                # TODO double check
+                self._state_log_p_unnormalized -= np.max(self._state_log_p_unnormalized)
+
+                self._eps, self._ess, w = self._get_eps_and_weights(
                     self._eps, self._state_samples, self._state_log_p_unnormalized,
                     self._state_log_q)
                 self._state_w_clipped = self._clip_weights(w)
                 self._state_w_sum = np.sum(self._state_w_clipped)
@@ -261,7 +290,8 @@ def variational_objective(var_param):
             self._objective_step += 1
 
             if not self._use_resampling:
-                return -np.inner(getval(self._state_w_clipped), self._state_log_q) / self.num_mc_samples
+                return -np.inner(getval(self._state_w_clipped),
+                                 self._state_log_q) / self.num_mc_samples
             else:
                 indices = np.random.choice(self.num_mc_samples, size=self._resampling_batch_size,
                                            p=getval(self._state_w_normalized))
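The pretraining branch added to `variational_objective` amounts to maximum-likelihood fitting of the approximation to the temper model: the objective is just `mean(-log q)` over draws from `temper_model_sampler`, so the approximation is pulled toward the temper model before regular DIS training begins. A small stand-alone illustration of that idea follows; the 1-D Gaussian, the (mean, log-std) parameterization, and the use of `scipy.optimize.minimize` are assumptions of the sketch, not viabel's optimizer.

    import numpy as np
    from scipy import stats
    from scipy.optimize import minimize

    rng = np.random.default_rng(1)
    samples = rng.normal(loc=3.0, scale=2.0, size=5000)   # stand-in temper_model_sampler(n)


    def pretrain_objective(var_param):
        # mean(-log q) over temper-model draws, with q a Gaussian parameterized
        # by (mean, log standard deviation).
        mu, log_sigma = var_param
        return np.mean(-stats.norm(mu, np.exp(log_sigma)).logpdf(samples))


    result = minimize(pretrain_objective, x0=np.zeros(2))
    print(result.x[0], np.exp(result.x[1]))   # close to the sample mean and std below
    print(samples.mean(), samples.std())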
diff --git a/viabel/tests/test_models.py b/viabel/tests/test_models.py
index b65d7d75..b44b763c 100644
--- a/viabel/tests/test_models.py
+++ b/viabel/tests/test_models.py
@@ -2,8 +2,8 @@
 
 import autograd.numpy as anp
 import numpy as np
-import pystan
 import pytest
+import stan
 from autograd.scipy.stats import norm
 from autograd.test_util import check_vjp
@@ -55,25 +55,24 @@ def log_p(x):
 
 
 def test_StanModel():
+    np.random.seed(5039)
+    beta_gen = np.array([-2, 1])
+    N = 25
+    x = np.random.randn(N, 2).dot(np.array([[1, .75], [.75, 1]]))
+    y_raw = x.dot(beta_gen) + np.random.standard_t(40, N)
+    y = y_raw - np.mean(y_raw)
+    data = dict(N=N, x=x, y=y, df=40)
+
     compiled_model_file = 'robust_reg_model.pkl'
     try:
         with open(compiled_model_file, 'rb') as f:
             regression_model = pickle.load(f)
     except BaseException:  # pragma: no cover
-        regression_model = pystan.StanModel(model_code=test_model,
-                                            model_name='regression_model')
+        regression_model = stan.build(program_code=test_model, data=data)
         with open('robust_reg_model.pkl', 'wb') as f:
             pickle.dump(regression_model, f)
-    np.random.seed(5039)
-    beta_gen = np.array([-2, 1])
-    N = 25
-    x = np.random.randn(N, 2).dot(np.array([[1, .75], [.75, 1]]))
-    y_raw = x.dot(beta_gen) + np.random.standard_t(40, N)
-    y = y_raw - np.mean(y_raw)
-    data = dict(N=N, x=x, y=y, df=40)
-    fit = regression_model.sampling(data=data, iter=10, thin=1, chains=1)
-    model = models.StanModel(fit)
+    model = models.StanModel(regression_model)
     x = 4 * np.random.randn(10, 2)
     _test_model(model, x, False, dict(beta=x[0]))
diff --git a/viabel/tests/test_objectives.py b/viabel/tests/test_objectives.py
index ac035938..9e4ba613 100644
--- a/viabel/tests/test_objectives.py
+++ b/viabel/tests/test_objectives.py
@@ -1,5 +1,6 @@
 import autograd.numpy as anp
 import numpy as np
+import scipy.stats
 from autograd.scipy.stats import norm
 
 from viabel.approximations import MFGaussian, MFStudentT
@@ -37,12 +38,19 @@ def test_ExclusiveKL():
 def test_ExclusiveKL_path_deriv():
     _test_objective(ExclusiveKL, 100, use_path_deriv=True)
 
 
+# def __init__(self, approx, model, num_mc_samples, ess_target,
+#              temper_model, temper_model_sampler, temper_fn, temper_eps_init,
+#              use_resampling=True, num_resampling_batches=1, w_clip_threshold=10,
+#              pretrain_batch_size=100):
def test_DISInclusiveKL():
     dim = 2
+
     _test_objective(DISInclusiveKL, 100,
-                    temper_prior=MFGaussian(dim),
-                    temper_prior_params=np.concatenate([[0] * dim, [1] * dim]),
+                    temper_model=scipy.stats.multivariate_normal(
+                        mean=[0] * dim, cov=np.diag([1] * dim)).logpdf,
+                    temper_model_sampler=lambda n: np.random.multivariate_normal(
+                        [0] * dim, np.diag([1] * dim), size=n),
+                    temper_fn=lambda model_logpdf, temper_logpdf, eps:
+                        temper_logpdf * eps + model_logpdf * (1 - eps),
+                    temper_eps_init=1,
                     ess_target=50)
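As a quick sanity check of the tempering pieces that `test_DISInclusiveKL` now wires together: `temper_fn` should recover the model log density at eps=0 and the temper-model log density at eps=1, and the sampler should return n draws of dimension dim. The snippet below mirrors the test's choices (with the sampler written as a mean vector plus covariance matrix); the stand-in model log density is illustrative only.

    import numpy as np
    import scipy.stats

    dim = 2
    temper_model = scipy.stats.multivariate_normal(mean=[0] * dim, cov=np.diag([1] * dim)).logpdf
    temper_model_sampler = lambda n: np.random.multivariate_normal([0] * dim, np.diag([1] * dim), size=n)


    def temper_fn(model_logpdf, temper_logpdf, eps):
        return temper_logpdf * eps + model_logpdf * (1 - eps)


    x = temper_model_sampler(5)
    print(x.shape)                                                      # (5, 2)
    log_p = scipy.stats.multivariate_normal(mean=[1] * dim).logpdf(x)   # stand-in model log density
    log_t = temper_model(x)
    assert np.allclose(temper_fn(log_p, log_t, 0.0), log_p)             # eps=0: pure model
    assert np.allclose(temper_fn(log_p, log_t, 1.0), log_t)             # eps=1: pure temper model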