Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement metric scaling #733

Merged
merged 12 commits into from
Sep 16, 2024
9 changes: 7 additions & 2 deletions blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
Expand Down Expand Up @@ -129,8 +130,8 @@ def kernel(

"""

flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0]
momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
flat_inverse_scale = ravel_pytree(momentum_inverse_scale)[0]
momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean(
flat_inverse_scale**2
)

Expand Down Expand Up @@ -248,6 +249,10 @@ def as_top_level_api(
A PyTree of the same structure as the target PyTree (position) with the
values used for as a step size for each dimension of the target space in
the velocity verlet integrator.
momentum_inverse_scale
Pytree with the same structure as the targeted position variable
specifying the per dimension inverse scaling transformation applied
to the persistent momentum variable prior to the integration step.
alpha
The value defining the persistence of the momentum variable.
delta
Expand Down
186 changes: 123 additions & 63 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@
"""
from typing import Callable, NamedTuple, Optional, Protocol, Union

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
from jax.scipy import stats as sp_stats

from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_gaussian_noise
from blackjax.types import Array, ArrayLikeTree, ArrayTree, Numeric, PRNGKey
from blackjax.util import generate_gaussian_noise, linear_map

__all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"]


class KineticEnergy(Protocol):
def __call__(
self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
) -> Numeric:
...


Expand All @@ -60,10 +60,18 @@ def __call__(
...


class Scale(Protocol):
def __call__(
self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
) -> ArrayLikeTree:
...


class Metric(NamedTuple):
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
kinetic_energy: KineticEnergy
check_turning: CheckTurning
scale: Scale


MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]]
Expand Down Expand Up @@ -128,46 +136,19 @@ def gaussian_euclidean(
itself given the values of the momentum along the trajectory.

"""
ndim = jnp.ndim(inverse_mass_matrix) # type: ignore[arg-type]
shape = jnp.shape(inverse_mass_matrix)[:1] # type: ignore[arg-type]

if ndim == 1: # diagonal mass matrix
mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))
matmul = jnp.multiply

elif ndim == 2:
# inverse mass matrix can be factored into L*L.T. We want the cholesky
# factor (inverse of L.T) of the mass matrix.
L = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True)
identity = jnp.identity(shape[0])
mass_matrix_sqrt = jscipy.linalg.solve_triangular(
L, identity, lower=True, trans=True
)
# Note that mass_matrix_sqrt is a upper triangular matrix here, with
# jscipy.linalg.inv(mass_matrix_sqrt @ mass_matrix_sqrt.T)
# == inverse_mass_matrix
# An alternative is to compute directly the cholesky factor of the inverse mass
# matrix
# mass_matrix_sqrt = jscipy.linalg.cholesky(
# jscipy.linalg.inv(inverse_mass_matrix), lower=True)
# which the result would instead be a lower triangular matrix.
matmul = jnp.matmul

else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {ndim}."
)
mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance(
inverse_mass_matrix, is_inv=True
)

def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree:
return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
) -> Numeric:
del position
momentum, _ = ravel_pytree(momentum)
velocity = matmul(inverse_mass_matrix, momentum)
velocity = linear_map(inverse_mass_matrix, momentum)
kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum)
return kinetic_energy_val

Expand Down Expand Up @@ -196,39 +177,60 @@ def is_turning(
m_right, _ = ravel_pytree(momentum_right)
m_sum, _ = ravel_pytree(momentum_sum)

velocity_left = matmul(inverse_mass_matrix, m_left)
velocity_right = matmul(inverse_mass_matrix, m_right)
velocity_left = linear_map(inverse_mass_matrix, m_left)
velocity_right = linear_map(inverse_mass_matrix, m_right)

# rho = m_sum
rho = m_sum - (m_right + m_left) / 2
turning_at_left = jnp.dot(velocity_left, rho) <= 0
turning_at_right = jnp.dot(velocity_right, rho) <= 0
return turning_at_left | turning_at_right

return Metric(momentum_generator, kinetic_energy, is_turning)
def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
AdrienCorenflos marked this conversation as resolved.
Show resolved Hide resolved
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.

Parameters
----------
position
The current position. Not used in this metric.
elements
Elements to scale
invs
Whether to scale the elements by the inverse mass matrix or the mass matrix.
If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
Same pytree structure as `elements`.

Returns
-------
scaled_elements
The scaled elements.
"""

ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)
return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)


def gaussian_riemannian(
mass_matrix_fn: Callable,
) -> Metric:
def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree:
mass_matrix = mass_matrix_fn(position)
ndim = jnp.ndim(mass_matrix)
if ndim == 1:
mass_matrix_sqrt = jnp.sqrt(mass_matrix)
elif ndim == 2:
mass_matrix_sqrt = jscipy.linalg.cholesky(mass_matrix, lower=True)
else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {jnp.ndim(mass_matrix)}."
)
mass_matrix_sqrt, *_ = _format_covariance(mass_matrix, is_inv=False)

return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
) -> Numeric:
if position is None:
raise ValueError(
"A Reinmannian kinetic energy function must be called with the "
Expand All @@ -238,18 +240,11 @@ def kinetic_energy(

momentum, _ = ravel_pytree(momentum)
mass_matrix = mass_matrix_fn(position)
ndim = jnp.ndim(mass_matrix)
if ndim == 1:
return -jnp.sum(sp_stats.norm.logpdf(momentum, 0.0, jnp.sqrt(mass_matrix)))
elif ndim == 2:
return -sp_stats.multivariate_normal.logpdf(
momentum, jnp.zeros_like(momentum), mass_matrix
)
else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {jnp.ndim(mass_matrix)}."
)
sqrt_mass_matrix, inv_sqrt_mass_matrix, diag = _format_covariance(
mass_matrix, is_inv=False
)

return _energy(momentum, 0, sqrt_mass_matrix, inv_sqrt_mass_matrix.T, diag)

def is_turning(
momentum_left: ArrayLikeTree,
Expand Down Expand Up @@ -283,4 +278,69 @@ def is_turning(
# turning_at_right = jnp.dot(velocity_right, rho) <= 0
# return turning_at_left | turning_at_right

return Metric(momentum_generator, kinetic_energy, is_turning)
def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
AdrienCorenflos marked this conversation as resolved.
Show resolved Hide resolved
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.

Parameters
----------
position
The current position.

Returns
-------
scaled_elements
The scaled elements.
"""
mass_matrix = mass_matrix_fn(position)
mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance(
mass_matrix, is_inv=False
)
ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
junpenglao marked this conversation as resolved.
Show resolved Hide resolved
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)
return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)


def _format_covariance(cov: Array, is_inv):
ndim = jnp.ndim(cov)
if ndim == 1:
cov_sqrt = jnp.sqrt(cov)
inv_cov_sqrt = 1 / cov_sqrt
diag = lambda x: x
if is_inv:
inv_cov_sqrt, cov_sqrt = cov_sqrt, inv_cov_sqrt
elif ndim == 2:
identity = jnp.identity(cov.shape[0])
if is_inv:
inv_cov_sqrt = jscipy.linalg.cholesky(cov, lower=True)
cov_sqrt = jscipy.linalg.solve_triangular(
inv_cov_sqrt, identity, lower=True, trans=True
)
else:
cov_sqrt = jscipy.linalg.cholesky(cov, lower=False).T
inv_cov_sqrt = jscipy.linalg.solve_triangular(
cov_sqrt, identity, lower=True, trans=True
)

diag = lambda x: jnp.diag(x)

else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {jnp.ndim(cov)}."
)
return cov_sqrt, inv_cov_sqrt, diag


def _energy(x, mean, cov_sqrt, inv_cov_sqrt, diag):
d = x.shape[0]
z = linear_map(inv_cov_sqrt, x - mean)
const = jnp.sum(jnp.log(diag(cov_sqrt))) + d / 2 * jnp.log(2 * jnp.pi)
return 0.5 * jnp.sum(z**2) + const
2 changes: 1 addition & 1 deletion blackjax/mcmc/periodic_orbital.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def kernel(

"""

momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean(
inverse_mass_matrix
)
bijection_fn = bijection(logdensity_fn, kinetic_energy_fn)
Expand Down
4 changes: 4 additions & 0 deletions blackjax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ class WelfordAlgorithmState(NamedTuple):

#: JAX PRNGKey
PRNGKey = jax.Array

#: JAX Scalar types
Scalar = Union[float, int]
Numeric = Union[jax.Array, Scalar]
Loading
Loading