Skip to content

Commit

Permalink
Refactor logprior_fn in SMCLinearRegressionTestCase
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Oct 30, 2024
1 parent 8611f8f commit e63b057
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
13 changes: 6 additions & 7 deletions tests/smc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ def logdensity_by_observation(self, log_scale, coefs, preds, x):
scale = jnp.exp(log_scale)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
return log_scale + logpdf
return logpdf

def logdensity_fn(self, log_scale, coefs, preds, x):
"""Linear regression"""
logpdf = self.logdensity_by_observation(log_scale, coefs, preds, x)
return jnp.sum(logpdf)

def logprior_fn(self, log_scale, coefs):
return log_scale + stats.norm.logpdf(log_scale) + stats.norm.logpdf(coefs)

def observations(self):
num_particles = 100

Expand All @@ -27,9 +30,7 @@ def observations(self):
def particles_prior_loglikelihood(self):
observations, num_particles = self.observations()

logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf(
x["coefs"]
)
logprior_fn = lambda x: self.logprior_fn(**x)
loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations)

log_scale_init = np.random.randn(num_particles)
Expand All @@ -45,9 +46,7 @@ def partial_posterior_test_case(self):
y_data = 3 * x_data + np.random.normal(size=x_data.shape)
observations = {"x": x_data, "preds": y_data}

logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf(
x["coefs"]
)
logprior_fn = lambda x: self.logprior_fn(**x)

log_scale_init = np.random.randn(num_particles)
coeffs_init = np.random.randn(num_particles)
Expand Down
2 changes: 1 addition & 1 deletion tests/smc/test_inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def logdensity_fn(self, log_scale, coefs, preds, x):
scale = jnp.exp(log_scale)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
return log_scale + jnp.sum(logpdf)
return jnp.sum(logpdf)

def test_smc_inner_kernel_adaptive_tempered(self):
self.smc_inner_kernel_tuning_test_case(
Expand Down

0 comments on commit e63b057

Please sign in to comment.