diff --git a/tests/smc/__init__.py b/tests/smc/__init__.py index 71d59e529..8f5a39d6f 100644 --- a/tests/smc/__init__.py +++ b/tests/smc/__init__.py @@ -9,13 +9,16 @@ def logdensity_by_observation(self, log_scale, coefs, preds, x): scale = jnp.exp(log_scale) y = jnp.dot(x, coefs) logpdf = stats.norm.logpdf(preds, y, scale) - return log_scale + logpdf + return logpdf def logdensity_fn(self, log_scale, coefs, preds, x): """Linear regression""" logpdf = self.logdensity_by_observation(log_scale, coefs, preds, x) return jnp.sum(logpdf) + def logprior_fn(self, log_scale, coefs): + return log_scale + stats.norm.logpdf(log_scale) + stats.norm.logpdf(coefs) + def observations(self): num_particles = 100 @@ -27,9 +30,7 @@ def observations(self): def particles_prior_loglikelihood(self): observations, num_particles = self.observations() - logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf( - x["coefs"] - ) + logprior_fn = lambda x: self.logprior_fn(**x) loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations) log_scale_init = np.random.randn(num_particles) @@ -45,9 +46,7 @@ def partial_posterior_test_case(self): y_data = 3 * x_data + np.random.normal(size=x_data.shape) observations = {"x": x_data, "preds": y_data} - logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf( - x["coefs"] - ) + logprior_fn = lambda x: self.logprior_fn(**x) log_scale_init = np.random.randn(num_particles) coeffs_init = np.random.randn(num_particles) diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 080a02749..7d6190af5 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -69,7 +69,7 @@ def logdensity_fn(self, log_scale, coefs, preds, x): scale = jnp.exp(log_scale) y = jnp.dot(x, coefs) logpdf = stats.norm.logpdf(preds, y, scale) - return log_scale + jnp.sum(logpdf) + return jnp.sum(logpdf) def test_smc_inner_kernel_adaptive_tempered(self): self.smc_inner_kernel_tuning_test_case(