diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 8cd5fc759..6a0de3809 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -37,6 +37,7 @@ from .sgmcmc import sgnht as _sgnht from .smc import adaptive_tempered from .smc import inner_kernel_tuning as _inner_kernel_tuning +from .smc import partial_posteriors_path as _partial_posteriors_smc from .smc import tempered from .vi import meanfield_vi as _meanfield_vi from .vi import pathfinder as _pathfinder @@ -122,8 +123,9 @@ def generate_top_level_api_from(module): adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered) tempered_smc = generate_top_level_api_from(tempered) inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) +partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc) -smc_family = [tempered_smc, adaptive_tempered_smc] +smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc] "Step_fn returning state has a .particles attribute" # stochastic gradient mcmc diff --git a/blackjax/mcmc/barker.py b/blackjax/mcmc/barker.py index 9923bd5f3..7ae7d2463 100644 --- a/blackjax/mcmc/barker.py +++ b/blackjax/mcmc/barker.py @@ -18,11 +18,13 @@ import jax.numpy as jnp from jax.flatten_util import ravel_pytree from jax.scipy import stats -from jax.tree_util import tree_leaves, tree_map +import blackjax.mcmc.metrics as metrics from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.metrics import Metric from blackjax.mcmc.proposal import static_binomial_sampling -from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import ArrayLikeTree, ArrayTree, Numeric, PRNGKey +from blackjax.util import generate_gaussian_noise __all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"] @@ -81,44 +83,70 @@ def build_kernel(): """ def _compute_acceptance_probability( - state: BarkerState, - proposal: BarkerState, - ) -> float: + state: BarkerState, proposal: BarkerState, metric: Metric + ) -> Numeric: """Compute the acceptance probability of the Barker's proposal kernel.""" - def ratio_proposal_nd(y, x, log_y, log_x): - num = -_log1pexp(-log_y * (x - y)) - den = -_log1pexp(-log_x * (y - x)) + x = state.position + y = proposal.position + log_x = state.logdensity_grad + log_y = proposal.logdensity_grad - return jnp.sum(num - den) + y_minus_x = jax.tree_util.tree_map(lambda a, b: a - b, y, x) + x_minus_y = jax.tree_util.tree_map(lambda a: -a, y_minus_x) + z_tilde_x_to_y = metric.scale(x, y_minus_x, inv=True, trans=True) + z_tilde_y_to_x = metric.scale(y, x_minus_y, inv=True, trans=True) - ratios_proposals = tree_map( - ratio_proposal_nd, - proposal.position, - state.position, - proposal.logdensity_grad, - state.logdensity_grad, + c_x_to_y = metric.scale(x, log_x, inv=False, trans=True) + c_y_to_x = metric.scale(y, log_y, inv=False, trans=True) + + z_tilde_x_to_y_flat, _ = ravel_pytree(z_tilde_x_to_y) + z_tilde_y_to_x_flat, _ = ravel_pytree(z_tilde_y_to_x) + + c_x_to_y_flat, _ = ravel_pytree(c_x_to_y) + c_y_to_x_flat, _ = ravel_pytree(c_y_to_x) + + num = metric.kinetic_energy(x_minus_y, y) - _log1pexp( + -z_tilde_y_to_x_flat * c_y_to_x_flat ) - ratio_proposal = sum(tree_leaves(ratios_proposals)) + denom = metric.kinetic_energy(y_minus_x, x) - _log1pexp( + -z_tilde_x_to_y_flat * c_x_to_y_flat + ) + + ratio_proposal = jnp.sum(num - denom) + return proposal.logdensity - state.logdensity + ratio_proposal def kernel( - rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float + rng_key: PRNGKey, + 
state: BarkerState, + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: metrics.MetricTypes | None = None, ) -> tuple[BarkerState, BarkerInfo]: - """Generate a new sample with the MALA kernel.""" + """Generate a new sample with the Barker kernel.""" + if inverse_mass_matrix is None: + p, _ = ravel_pytree(state.position) + (m,) = p.shape + inverse_mass_matrix = jnp.ones((m,)) + metric = metrics.default_metric(inverse_mass_matrix) grad_fn = jax.value_and_grad(logdensity_fn) - key_sample, key_rmh = jax.random.split(rng_key) proposed_pos = _barker_sample( - key_sample, state.position, state.logdensity_grad, step_size + key_sample, + state.position, + state.logdensity_grad, + step_size, + metric, ) + proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos) proposed_state = BarkerState( proposed_pos, proposed_logdensity, proposed_logdensity_grad ) - log_p_accept = _compute_acceptance_probability(state, proposed_state) + log_p_accept = _compute_acceptance_probability(state, proposed_state, metric) accepted_state, info = static_binomial_sampling( key_rmh, log_p_accept, state, proposed_state ) @@ -131,6 +159,7 @@ def kernel( def as_top_level_api( logdensity_fn: Callable, step_size: float, + inverse_mass_matrix: metrics.MetricTypes | None = None, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a Gaussian base kernel. @@ -174,7 +203,9 @@ def as_top_level_api( logdensity_fn The log-density function we wish to draw samples from. step_size - The value to use for the step size in the symplectic integrator. + The value of the step size corresponding to the global scale of the proposal distribution. + inverse_mass_matrix + The inverse mass matrix to use for pre-conditioning (see Appendix G of :cite:p:`Livingstone2022Barker`). Returns ------- @@ -189,74 +220,55 @@ def init_fn(position: ArrayLikeTree, rng_key=None): return init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, state, logdensity_fn, step_size) + return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix) return SamplingAlgorithm(init_fn, step_fn) -def _barker_sample_nd(key, mean, a, scale): - """ - Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function: - - .. math:: - p(x; \\mu, a, \\sigma) = 2 \frac{N(x; \\mu, \\sigma^2)}{1 + \\exp(-a (x - \\mu)} +def _generate_bernoulli( + rng_key: PRNGKey, position: ArrayLikeTree, p: ArrayLikeTree +) -> ArrayTree: + pos, unravel_fn = ravel_pytree(position) + p_flat, _ = ravel_pytree(p) + sample = jax.random.bernoulli(rng_key, p=p_flat, shape=pos.shape) + return unravel_fn(sample) - where :math:`N(x; \\mu, \\sigma^2)` is the normal distribution with mean :math:`\\mu` and standard deviation :math:`\\sigma`. - The multivariate Barker's proposal distribution is the product of one-dimensional Barker's proposal distributions. +def _barker_sample(key, mean, a, scale, metric): + r""" + Sample from a multivariate Barker's proposal distribution for PyTrees. Parameters ---------- key A PRNG key. mean - The mean of the normal distribution, an Array. This corresponds to :math:`\\mu` in the equation above. + The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above. a - The parameter :math:`a` in the equation above, an Array. This is a skewness parameter. + The parameter :math:`a` in the equation above, the same PyTree as `mean`.
This is a skewness parameter. scale - The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\\sigma` in the equation above. + The global scale, a scalar. This corresponds to :math:`\\sigma` in the equation above. It encodes the step size of the proposal. - - Returns - ------- - A sample from the Barker's multidimensional proposal distribution. - + metric + A `metrics.MetricTypes` object encoding the mass matrix information. """ key1, key2 = jax.random.split(key) - z = scale * jax.random.normal(key1, shape=mean.shape) + + z = generate_gaussian_noise(key1, mean, sigma=scale) + c = metric.scale(mean, a, inv=False, trans=True) # Sample b=1 with probability p and 0 with probability 1 - p where # p = 1 / (1 + exp(-a * (z - mean))) - log_p = -_log1pexp(-a * z) - b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape) - - # return mean + z if b == 1 else mean - z - return mean + b * z - (1 - b) * z - + log_p = jax.tree_util.tree_map(lambda x, y: -_log1pexp(-x * y), c, z) + p = jax.tree_util.tree_map(lambda x: jnp.exp(x), log_p) + b = _generate_bernoulli(key2, mean, p=p) -def _barker_sample(key, mean, a, scale): - r""" - Sample from a multivariate Barker's proposal distribution for PyTrees. - - Parameters - ---------- - key - A PRNG key. - mean - The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above. - a - The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter. - scale - The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above. - It encodes the step size of the proposal. - - """ + bz = jax.tree_util.tree_map(lambda x, y: x * y - (1 - x) * y, b, z) - flat_mean, unravel_fn = ravel_pytree(mean) - flat_a, _ = ravel_pytree(a) - flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale) - return unravel_fn(flat_sample) + return jax.tree_util.tree_map( + lambda a, b: a + b, mean, metric.scale(mean, bz, inv=False, trans=False) + ) def _log1pexp(a): diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 4e079714b..f0720acf4 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -30,7 +30,6 @@ """ from typing import Callable, NamedTuple, Optional, Protocol, Union -import jax import jax.numpy as jnp import jax.scipy as jscipy from jax.flatten_util import ravel_pytree @@ -62,7 +61,12 @@ def __call__( class Scale(Protocol): def __call__( - self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + self, + position: ArrayLikeTree, + element: ArrayLikeTree, + *, + inv: bool, + trans: bool, ) -> ArrayLikeTree: ... @@ -187,7 +191,11 @@ def is_turning( return turning_at_left | turning_at_right def scale( - position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + position: ArrayLikeTree, + element: ArrayLikeTree, + *, + inv: bool, + trans: bool, ) -> ArrayLikeTree: """Scale elements by the mass matrix. @@ -197,10 +205,11 @@ def scale( The current position. Not used in this metric. elements Elements to scale - invs + inv Whether to scale the elements by the inverse mass matrix or the mass matrix. If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem. - Same pytree structure as `elements`. 
+ trans + Whether to transpose the mass matrix when scaling. Returns ------- @@ -209,11 +218,16 @@ """ ravelled_element, unravel_fn = ravel_pytree(element) - scaled = jax.lax.cond( - inv, - lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element), - lambda: linear_map(mass_matrix_sqrt, ravelled_element), - ) + + if inv: + left_hand_side_matrix = inv_mass_matrix_sqrt + else: + left_hand_side_matrix = mass_matrix_sqrt + if trans: + left_hand_side_matrix = left_hand_side_matrix.T + + scaled = linear_map(left_hand_side_matrix, ravelled_element) + return unravel_fn(scaled) return Metric(momentum_generator, kinetic_energy, is_turning, scale) @@ -279,7 +293,11 @@ def is_turning( # return turning_at_left | turning_at_right def scale( - position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + position: ArrayLikeTree, + element: ArrayLikeTree, + *, + inv: bool, + trans: bool, ) -> ArrayLikeTree: """Scale elements by the mass matrix. @@ -298,11 +316,16 @@ mass_matrix, is_inv=False ) ravelled_element, unravel_fn = ravel_pytree(element) - scaled = jax.lax.cond( - inv, - lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element), - lambda: linear_map(mass_matrix_sqrt, ravelled_element), - ) + + if inv: + left_hand_side_matrix = inv_mass_matrix_sqrt + else: + left_hand_side_matrix = mass_matrix_sqrt + if trans: + left_hand_side_matrix = left_hand_side_matrix.T + + scaled = linear_map(left_hand_side_matrix, ravelled_element) + return unravel_fn(scaled) return Metric(momentum_generator, kinetic_energy, is_turning, scale) diff --git a/blackjax/smc/__init__.py b/blackjax/smc/__init__.py index ef10b10e6..9670fcb6e 100644 --- a/blackjax/smc/__init__.py +++ b/blackjax/smc/__init__.py @@ -6,4 +6,5 @@ "tempered", "inner_kernel_tuning", "extend_params", + "partial_posteriors_path", ] diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 5093cf06b..56df7f010 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -156,3 +156,31 @@ """ return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params) + + +def update_and_take_last( + mcmc_init_fn, + tempered_logposterior_fn, + shared_mcmc_step_fn, + num_mcmc_steps, + n_particles, +): + """Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and + returns the last values, wasting the previous num_mcmc_steps-1 + samples per chain. + """ + + def mcmc_kernel(rng_key, position, step_parameters): + state = mcmc_init_fn(position, tempered_logposterior_fn) + + def body_fn(state, rng_key): + new_state, info = shared_mcmc_step_fn( + rng_key, state, tempered_logposterior_fn, **step_parameters + ) + return new_state, info + + keys = jax.random.split(rng_key, num_mcmc_steps) + last_state, info = jax.lax.scan(body_fn, state, keys) + return last_state.position, info + + return jax.vmap(mcmc_kernel), n_particles diff --git a/blackjax/smc/from_mcmc.py b/blackjax/smc/from_mcmc.py new file mode 100644 index 000000000..0e60b5968 --- /dev/null +++ b/blackjax/smc/from_mcmc.py @@ -0,0 +1,64 @@ +from functools import partial +from typing import Callable + +import jax + +from blackjax import smc +from blackjax.smc.base import SMCState, update_and_take_last +from blackjax.types import PRNGKey + + +def build_kernel( + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + update_strategy: Callable = update_and_take_last, +): + """SMC step from MCMC kernels. + Builds MCMC kernels from the input parameters, which may change across iterations.
+ Moreover, it defines the way such kernels are used to update the particles. This layer + adapts an API defined in terms of kernels (mcmc_step_fn and mcmc_init_fn) into an API + that depends on an update function over the set of particles. + Returns + ------- + A callable that takes a rng_key and a state with .particles and .weights and returns a base.SMCState + and base.SMCInfo pair. + + """ + + def step( + rng_key: PRNGKey, + state, + num_mcmc_steps: int, + mcmc_parameters: dict, + logposterior_fn: Callable, + log_weights_fn: Callable, + ) -> tuple[smc.base.SMCState, smc.base.SMCInfo]: + shared_mcmc_parameters = {} + unshared_mcmc_parameters = {} + for k, v in mcmc_parameters.items(): + if v.shape[0] == 1: + shared_mcmc_parameters[k] = v[0, ...] + else: + unshared_mcmc_parameters[k] = v + + shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + + update_fn, num_resampled = update_strategy( + mcmc_init_fn, + logposterior_fn, + shared_mcmc_step_fn, + n_particles=state.weights.shape[0], + num_mcmc_steps=num_mcmc_steps, + ) + + return smc.base.step( + rng_key, + SMCState(state.particles, state.weights, unshared_mcmc_parameters), + update_fn, + jax.vmap(log_weights_fn), + resampling_fn, + num_resampled, + ) + + return step diff --git a/blackjax/smc/partial_posteriors_path.py b/blackjax/smc/partial_posteriors_path.py new file mode 100644 index 000000000..81f19716d --- /dev/null +++ b/blackjax/smc/partial_posteriors_path.py @@ -0,0 +1,127 @@ +from typing import Callable, NamedTuple, Optional, Tuple + +import jax +import jax.numpy as jnp + +from blackjax import SamplingAlgorithm, smc +from blackjax.smc.base import update_and_take_last +from blackjax.smc.from_mcmc import build_kernel as smc_from_mcmc +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey + + +class PartialPosteriorsSMCState(NamedTuple): + """Current state for the partial posteriors SMC algorithm. + + particles: PyTree + The particles' positions. + weights: + Weights of the particles, so that they represent a probability distribution. + data_mask: + A 1D boolean array to indicate which datapoints to include + in the computation of the observed likelihood. + """ + + particles: ArrayTree + weights: Array + data_mask: Array + + +def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCState: + """num_datapoints is the number of observations that could potentially be + used in a partial posterior. Since the initial data_mask is all 0s, it + means that no likelihood term will be added (only prior). + """ + num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] + weights = jnp.ones(num_particles) / num_particles + return PartialPosteriorsSMCState(particles, weights, jnp.zeros(num_datapoints)) + + +def build_kernel( + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + num_mcmc_steps: Optional[int], + mcmc_parameters: ArrayTree, + partial_logposterior_factory: Callable[[Array], Callable], + update_strategy=update_and_take_last, +) -> Callable: + """Build the Partial Posteriors (data tempering) SMC kernel. + The sequence of distributions is built by increasingly adding more + datapoints to the likelihood. See Section 2.2 of https://arxiv.org/pdf/2007.11936 + Parameters + ---------- + mcmc_step_fn + The MCMC step function used to update the particles. + mcmc_init_fn + The MCMC init function used to build an MCMC state from a particle position.
+ resampling_fn + A random function that resamples generated particles based on weights. + num_mcmc_steps + Number of iterations in the MCMC chain. + mcmc_parameters + A dictionary of parameters to be used by the inner MCMC kernels + partial_logposterior_factory: + A callable that, given an array of 0s and 1s, returns a function logposterior(x). + The array represents which values to include in the logposterior calculation. The logposterior + must be JAX-compilable. + + Returns + ------- + A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for + the current and previous posteriors, and returns a data-tempered SMC state. + """ + delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy) + + def step( + key, state: PartialPosteriorsSMCState, data_mask: Array + ) -> Tuple[PartialPosteriorsSMCState, smc.base.SMCInfo]: + logposterior_fn = partial_logposterior_factory(data_mask) + + previous_logposterior_fn = partial_logposterior_factory(state.data_mask) + + def log_weights_fn(x): + return logposterior_fn(x) - previous_logposterior_fn(x) + + state, info = delegate( + key, state, num_mcmc_steps, mcmc_parameters, logposterior_fn, log_weights_fn + ) + + return ( + PartialPosteriorsSMCState(state.particles, state.weights, data_mask), + info, + ) + + return step + + +def as_top_level_api( + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + mcmc_parameters: dict, + resampling_fn: Callable, + num_mcmc_steps, + partial_logposterior_factory: Callable, + update_strategy=update_and_take_last, +) -> SamplingAlgorithm: + """A factory that wraps the kernel into a SamplingAlgorithm object. + See build_kernel for full documentation on the parameters. + """ + + kernel = build_kernel( + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + num_mcmc_steps, + mcmc_parameters, + partial_logposterior_factory, + update_strategy, + ) + + def init_fn(position: ArrayLikeTree, num_observations, rng_key=None): + del rng_key + return init(position, num_observations) + + def step(key: PRNGKey, state: PartialPosteriorsSMCState, data_mask: Array): + return kernel(key, state, data_mask) + + return SamplingAlgorithm(init_fn, step) # type: ignore[arg-type] diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 19de8afb7..88539deaa 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, NamedTuple, Optional import jax import jax.numpy as jnp import blackjax.smc as smc +import blackjax.smc.from_mcmc as smc_from_mcmc from blackjax.base import SamplingAlgorithm -from blackjax.smc.base import SMCState +from blackjax.smc.base import update_and_take_last from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["TemperedSMCState", "init", "build_kernel", "as_top_level_api"] @@ -48,35 +48,6 @@ def init(particles: ArrayLikeTree): return TemperedSMCState(particles, weights, 0.0) -def update_and_take_last( - mcmc_init_fn, - tempered_logposterior_fn, - shared_mcmc_step_fn, - num_mcmc_steps, - n_particles, -): - """ - Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and - returns the last values, waisting the previous num_mcmc_steps-1 - samples per chain.
- """ - - def mcmc_kernel(rng_key, position, step_parameters): - state = mcmc_init_fn(position, tempered_logposterior_fn) - - def body_fn(state, rng_key): - new_state, info = shared_mcmc_step_fn( - rng_key, state, tempered_logposterior_fn, **step_parameters - ) - return new_state, info - - keys = jax.random.split(rng_key, num_mcmc_steps) - last_state, info = jax.lax.scan(body_fn, state, keys) - return last_state.position, info - - return jax.vmap(mcmc_kernel), n_particles - - def build_kernel( logprior_fn: Callable, loglikelihood_fn: Callable, @@ -121,6 +92,9 @@ def build_kernel( information about the transition. """ + delegate = smc_from_mcmc.build_kernel( + mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy + ) def kernel( rng_key: PRNGKey, @@ -153,14 +127,6 @@ def kernel( """ delta = lmbda - state.lmbda - shared_mcmc_parameters = {} - unshared_mcmc_parameters = {} - for k, v in mcmc_parameters.items(): - if v.shape[0] == 1: - shared_mcmc_parameters[k] = v[0, ...] - else: - unshared_mcmc_parameters[k] = v - def log_weights_fn(position: ArrayLikeTree) -> float: return delta * loglikelihood_fn(position) @@ -169,23 +135,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood - shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) - - update_fn, num_resampled = update_strategy( - mcmc_init_fn, - tempered_logposterior_fn, - shared_mcmc_step_fn, - n_particles=state.weights.shape[0], - num_mcmc_steps=num_mcmc_steps, - ) - - smc_state, info = smc.base.step( + smc_state, info = delegate( rng_key, - SMCState(state.particles, state.weights, unshared_mcmc_parameters), - update_fn, - jax.vmap(log_weights_fn), - resampling_fn, - num_resampled, + state, + num_mcmc_steps, + mcmc_parameters, + tempered_logposterior_fn, + log_weights_fn, ) tempered_state = TemperedSMCState( diff --git a/docs/examples/howto_sample_multiple_chains.md b/docs/examples/howto_sample_multiple_chains.md index a5b6566f8..c2947e29f 100644 --- a/docs/examples/howto_sample_multiple_chains.md +++ b/docs/examples/howto_sample_multiple_chains.md @@ -57,8 +57,9 @@ observed = np.random.normal(loc, scale, size=1_000) def logdensity_fn(loc, log_scale, observed=observed): """Univariate Normal""" scale = jnp.exp(log_scale) + logjac = log_scale logpdf = stats.norm.logpdf(observed, loc, scale) - return jnp.sum(logpdf) + return logjac + jnp.sum(logpdf) def logdensity(x): diff --git a/docs/examples/quickstart.md b/docs/examples/quickstart.md index 870e5df9a..a290bfdad 100644 --- a/docs/examples/quickstart.md +++ b/docs/examples/quickstart.md @@ -48,8 +48,9 @@ observed = np.random.normal(loc, scale, size=1_000) def logdensity_fn(loc, log_scale, observed=observed): """Univariate Normal""" scale = jnp.exp(log_scale) + logjac = log_scale logpdf = stats.norm.logpdf(observed, loc, scale) - return jnp.sum(logpdf) + return logjac + jnp.sum(logpdf) logdensity = lambda x: logdensity_fn(**x) diff --git a/tests/mcmc/test_barker.py b/tests/mcmc/test_barker.py index 5c227c4cb..04a86d1d4 100644 --- a/tests/mcmc/test_barker.py +++ b/tests/mcmc/test_barker.py @@ -1,9 +1,16 @@ +import functools +import itertools + import chex import jax import jax.numpy as jnp +import jax.scipy.stats as stats from absl.testing import absltest, parameterized -from blackjax.mcmc.barker import _barker_pdf, _barker_sample_nd +import blackjax +from blackjax.mcmc import metrics +from blackjax.mcmc.barker import _barker_pdf, _barker_sample +from 
blackjax.util import run_inference_algorithm class BarkerSamplingTest(chex.TestCase): @parameterized.parameters([1234, 5678]) def test_nd(self, seed): 0.5, ) + metric = metrics.default_metric(jnp.eye(4)) keys = jax.random.split(key, n_samples) - samples = jax.vmap(lambda k: _barker_sample_nd(k, m, a, scale))(keys) + samples = jax.vmap(lambda k: _barker_sample(k, m, a, scale, metric))(keys) # Check that the emprical mean and the mean computed as sum(x * p(x) dx) are close _test_samples_vs_pdf(samples, lambda x: _barker_pdf(x, m, a, scale)) @@ -51,5 +59,121 @@ def _test_samples_vs_pdf(samples, pdf): ) +class BarkerPreconditioningTest(chex.TestCase): + @parameterized.parameters([1234, 5678]) + def test_preconditioning_matrix(self, seed): + """Test that two different ways of using a pre-conditioning matrix have exactly the same effect. + + We follow the discussion in Appendix G of the Barker 2020 paper. + """ + + key = jax.random.key(seed) + init_key, inference_key = jax.random.split(key, 2) + + # setup some 2D multivariate normal model + # setup sampling mean and cov + true_x = jnp.array([0.0, 1.0]) + data = jax.random.normal(init_key, shape=(1000,)) * true_x[1] + true_x[0] + assert data.shape == (1000,) + + # some non-diagonal positive-definite matrix for pre-conditioning + inv_mass_matrix = jnp.array([[1, 0.1], [0.1, 1]]) + metric = metrics.default_metric(inv_mass_matrix) + + # define barker kernel two ways + # non-scaled, use pre-conditioning + def logdensity(x, data): + mu_prior = stats.norm.logpdf(x[0], loc=0, scale=1) + sigma_prior = stats.uniform.logpdf(x[1], 0.0, 3.0) + return mu_prior + sigma_prior + jnp.sum(stats.norm.logcdf(data, x[0], x[1])) + + logposterior_fn1 = functools.partial(logdensity, data=data) + barker1 = blackjax.barker_proposal(logposterior_fn1, 1e-1, inv_mass_matrix) + state1 = barker1.init(true_x) + + # scaled, trivial pre-conditioning + def scaled_logdensity(x_scaled, data, metric): + x = metric.scale(x_scaled, x_scaled, inv=False, trans=False) + return logdensity(x, data) + + logposterior_fn2 = functools.partial( + scaled_logdensity, data=data, metric=metric + ) + barker2 = blackjax.barker_proposal(logposterior_fn2, 1e-1, jnp.eye(2)) + + true_x_trans = metric.scale(true_x, true_x, inv=True, trans=True) + state2 = barker2.init(true_x_trans) + + n_steps = 10 + _, states1 = run_inference_algorithm( + rng_key=inference_key, + initial_state=state1, + inference_algorithm=barker1, + transform=lambda state, info: state.position, + num_steps=n_steps, + ) + + _, states2 = run_inference_algorithm( + rng_key=inference_key, + initial_state=state2, + inference_algorithm=barker2, + transform=lambda state, info: state.position, + num_steps=n_steps, + ) + + # states should be the exact same with same random key after transforming + states2_trans = [] + for ii in range(n_steps): + s = states2[ii] + states2_trans.append(metric.scale(s, s, inv=False, trans=False)) + states2_trans = jnp.array(states2_trans) + assert jnp.allclose(states1, states2_trans) + + @parameterized.parameters( + itertools.product([1234, 5678], ["gaussian", "riemannian"]) + ) + def test_invariance(self, seed, metric): + logpdf = lambda x: -0.5 * jnp.sum(x**2) + + n_samples, m_steps = 10_000, 50 + + key = jax.random.key(seed) + init_key, inference_key = jax.random.split(key, 2) + inference_keys = jax.random.split(inference_key, n_samples) + if metric == "gaussian": + inv_mass_matrix = jnp.ones((2,)) + metric = metrics.default_metric(inv_mass_matrix) + else: + # bit of a random metric but we are testing invariance, not efficiency + metric = 
metrics.gaussian_riemannian( + lambda x: 1 / jnp.sum(1 + jnp.sum(x**2)) * jnp.eye(2) + ) + + barker = blackjax.barker_proposal(logpdf, 0.5, metric) + init_samples = jax.random.normal(init_key, shape=(n_samples, 2)) + + def loop(carry, key_): + state, accepted = carry + state, info = barker.step(key_, state) + accepted += info.is_accepted + return (state, accepted), None + + def get_samples(init_sample, key_): + init = (barker.init(init_sample), 0) + (out, n_accepted), _ = jax.lax.scan( + loop, init, jax.random.split(key_, m_steps) + ) + return out.position, n_accepted / m_steps + + samples, total_accepted = jax.vmap(get_samples)(init_samples, inference_keys) + # now we test the distance versus a Gaussian + chex.assert_trees_all_close( + jnp.mean(samples, 0), jnp.zeros((2,)), atol=1e-1, rtol=1e-1 + ) + chex.assert_trees_all_close( + jnp.cov(samples.T), jnp.eye(2), atol=1e-1, rtol=1e-1 + ) + + if __name__ == "__main__": absltest.main() diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py index 0791f3cb1..e6aa5879f 100644 --- a/tests/mcmc/test_metrics.py +++ b/tests/mcmc/test_metrics.py @@ -131,8 +131,12 @@ def test_gaussian_euclidean_dim_1(self): assert momentum_val == expected_momentum_val assert kinetic_energy_val == expected_kinetic_energy_val - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -164,8 +168,12 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val @@ -226,8 +234,12 @@ def test_gaussian_riemannian_dim_1(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -265,8 +277,12 @@ def test_gaussian_riemannian_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, 
momentum_val, inv=False, trans=False + ) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index bb83d87ff..74fdfd6fb 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -1,4 +1,5 @@ """Test the accuracy of the MCMC kernels.""" + import functools import itertools diff --git a/tests/smc/__init__.py b/tests/smc/__init__.py index 7a4e5c029..8f5a39d6f 100644 --- a/tests/smc/__init__.py +++ b/tests/smc/__init__.py @@ -5,23 +5,32 @@ class SMCLinearRegressionTestCase(chex.TestCase): - def logdensity_fn(self, log_scale, coefs, preds, x): - """Linear regression""" + def logdensity_by_observation(self, log_scale, coefs, preds, x): scale = jnp.exp(log_scale) y = jnp.dot(x, coefs) logpdf = stats.norm.logpdf(preds, y, scale) + return logpdf + + def logdensity_fn(self, log_scale, coefs, preds, x): + """Linear regression""" + logpdf = self.logdensity_by_observation(log_scale, coefs, preds, x) return jnp.sum(logpdf) - def particles_prior_loglikelihood(self): + def logprior_fn(self, log_scale, coefs): + return log_scale + stats.norm.logpdf(log_scale) + stats.norm.logpdf(coefs) + + def observations(self): num_particles = 100 x_data = np.random.normal(0, 1, size=(1000, 1)) y_data = 3 * x_data + np.random.normal(size=x_data.shape) observations = {"x": x_data, "preds": y_data} + return observations, num_particles - logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf( - x["coefs"] - ) + def particles_prior_loglikelihood(self): + observations, num_particles = self.observations() + + logprior_fn = lambda x: self.logprior_fn(**x) loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations) log_scale_init = np.random.randn(num_particles) @@ -30,6 +39,21 @@ def particles_prior_loglikelihood(self): return init_particles, logprior_fn, loglikelihood_fn + def partial_posterior_test_case(self): + num_particles = 100 + + x_data = np.random.normal(0, 1, size=(1000, 1)) + y_data = 3 * x_data + np.random.normal(size=x_data.shape) + observations = {"x": x_data, "preds": y_data} + + logprior_fn = lambda x: self.logprior_fn(**x) + + log_scale_init = np.random.randn(num_particles) + coeffs_init = np.random.randn(num_particles) + init_particles = {"log_scale": log_scale_init, "coefs": coeffs_init} + + return init_particles, logprior_fn, observations + def assert_linear_regression_test_case(self, result): np.testing.assert_allclose( np.mean(np.exp(result.particles["log_scale"])), 1.0, rtol=1e-1 diff --git a/tests/smc/test_partial_posteriors_smc.py b/tests/smc/test_partial_posteriors_smc.py new file mode 100644 index 000000000..78d57a934 --- /dev/null +++ b/tests/smc/test_partial_posteriors_smc.py @@ -0,0 +1,88 @@ +import chex +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest + +import blackjax +import blackjax.smc.resampling as resampling +from blackjax.smc import extend_params +from tests.smc import SMCLinearRegressionTestCase + + +class PartialPosteriorsSMCTest(SMCLinearRegressionTestCase): + """Test posterior mean estimate.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + @chex.variants(with_jit=True) + def test_partial_posteriors(self): + ( + init_particles, + logprior_fn, + observations, + ) = self.partial_posterior_test_case() + + hmc_init = blackjax.hmc.init + hmc_kernel = blackjax.hmc.build_kernel() + + hmc_parameters = extend_params( + { + "step_size": 
10e-3, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) + + dataset_size = 1000 + + def partial_logposterior_factory(data_mask): + def partial_logposterior(x): + lp = logprior_fn(x) + return lp + jnp.sum( + self.logdensity_by_observation(**x, **observations) + * data_mask.reshape(-1, 1) + ) + + return jax.jit(partial_logposterior) + + init, kernel = blackjax.partial_posteriors_smc( + hmc_kernel, + hmc_init, + hmc_parameters, + resampling.systematic, + 50, + partial_logposterior_factory=partial_logposterior_factory, + ) + + init_state = init(init_particles, 1000) + smc_kernel = self.variant(kernel) + + data_masks = jnp.array( + [ + jnp.concat( + [ + jnp.ones(datapoints_chosen), + jnp.zeros(dataset_size - datapoints_chosen), + ] + ) + for datapoints_chosen in np.arange(100, 1001, 50) + ] + ) + + def body_fn(carry, data_mask): + i, state = carry + subkey = jax.random.fold_in(self.key, i) + new_state, info = smc_kernel(subkey, state, data_mask) + return (i + 1, new_state), (new_state, info) + + (steps, result), it = jax.lax.scan(body_fn, (0, init_state), data_masks) + assert steps == 19 + + self.assert_linear_regression_test_case(result) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/vi/test_schrodinger_follmer.py b/tests/vi/test_schrodinger_follmer.py index fd58fed0a..79e7afedd 100644 --- a/tests/vi/test_schrodinger_follmer.py +++ b/tests/vi/test_schrodinger_follmer.py @@ -51,7 +51,7 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov): # Simulate the data observed = jax.random.multivariate_normal( - rng_key_observed, true_mu, true_cov, shape=(10_000,) + rng_key_observed, true_mu, true_cov, shape=(25,) ) logp_model = functools.partial(
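A minimal usage sketch of the preconditioned Barker proposal added above, mirroring the calls in `tests/mcmc/test_barker.py`; the 2D Gaussian target, the step size, and the inverse mass matrix values are illustrative assumptions, not values from the patch.

```python
import jax
import jax.numpy as jnp

import blackjax

# Illustrative target: a standard 2D Gaussian.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

# A dense (non-diagonal) inverse mass matrix used for pre-conditioning,
# passed as the new third argument of `barker_proposal`.
inv_mass_matrix = jnp.array([[1.0, 0.1], [0.1, 1.0]])
barker = blackjax.barker_proposal(logdensity_fn, 1e-1, inv_mass_matrix)

rng_key = jax.random.key(0)
state = barker.init(jnp.zeros(2))
for _ in range(100):
    rng_key, step_key = jax.random.split(rng_key)
    state, info = barker.step(step_key, state)
```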
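The reworked `Metric.scale` replaces the `jax.lax.cond` with a plain Python branch over two keyword-only flags: `inv` selects the square-root mass matrix or its inverse, and `trans` transposes the chosen factor before the linear map. A sketch of the four combinations, assuming the same illustrative matrix as above:

```python
import jax.numpy as jnp

from blackjax.mcmc import metrics

inv_mass_matrix = jnp.array([[1.0, 0.1], [0.1, 1.0]])  # assumed example
metric = metrics.default_metric(inv_mass_matrix)

position = jnp.zeros(2)  # ignored by the Euclidean metric
v = jnp.array([1.0, 2.0])

a = metric.scale(position, v, inv=False, trans=False)  # M^{1/2} v
b = metric.scale(position, v, inv=False, trans=True)   # (M^{1/2})^T v
c = metric.scale(position, v, inv=True, trans=False)   # (M^{1/2})^{-1} v
d = metric.scale(position, v, inv=True, trans=True)    # (M^{1/2})^{-T} v
```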
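Finally, an end-to-end sketch of the new `blackjax.partial_posteriors_smc` API, condensed from `tests/smc/test_partial_posteriors_smc.py`; it routes through the new `from_mcmc.build_kernel` layer with the default `update_and_take_last` strategy. The toy Gaussian-mean model, dataset size, and HMC parameters are assumptions, not part of the diff.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

import blackjax
import blackjax.smc.resampling as resampling
from blackjax.smc import extend_params

num_datapoints, num_particles = 100, 50
data = jax.random.normal(jax.random.key(0), (num_datapoints,)) + 2.0


def partial_logposterior_factory(data_mask):
    # Standard-normal prior on the (1D) mean; the likelihood terms are
    # switched on and off by the 0/1 entries of `data_mask`.
    def partial_logposterior(x):
        logprior = jnp.sum(stats.norm.logpdf(x))
        loglikelihood = stats.norm.logpdf(data, x[0], 1.0)
        return logprior + jnp.sum(loglikelihood * data_mask)

    return partial_logposterior


hmc_parameters = extend_params(
    {"step_size": 1e-2, "inverse_mass_matrix": jnp.eye(1), "num_integration_steps": 10}
)

init_fn, step_fn = blackjax.partial_posteriors_smc(
    blackjax.hmc.build_kernel(),
    blackjax.hmc.init,
    hmc_parameters,
    resampling.systematic,
    num_mcmc_steps=10,
    partial_logposterior_factory=partial_logposterior_factory,
)

# Start from the prior (all-zeros data mask), then anneal in more observations.
particles = jax.random.normal(jax.random.key(1), (num_particles, 1))
state = init_fn(particles, num_datapoints)

rng_key = jax.random.key(2)
for n in (10, 50, 100):
    data_mask = (jnp.arange(num_datapoints) < n).astype(jnp.float32)
    rng_key, subkey = jax.random.split(rng_key)
    state, info = step_fn(subkey, state, data_mask)
```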