Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement metric scaling #733

Merged
merged 12 commits into from
Sep 16, 2024
175 changes: 123 additions & 52 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.

"""
from typing import Callable, NamedTuple, Optional, Protocol, Union
from typing import Callable, NamedTuple, Optional, Protocol, Union, Tuple, List

import jax.numpy as jnp
import jax.scipy as jscipy
Expand All @@ -43,19 +43,19 @@

class KineticEnergy(Protocol):
def __call__(
self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
...


class CheckTurning(Protocol):
def __call__(
self,
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
self,
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
...

Expand All @@ -64,6 +64,7 @@ class Metric(NamedTuple):
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
kinetic_energy: KineticEnergy
check_turning: CheckTurning
scale: Callable[[ArrayLikeTree, Tuple[Tuple[ArrayLikeTree, bool]]], ArrayLikeTree]


MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]]
Expand Down Expand Up @@ -94,7 +95,7 @@ def default_metric(metric: MetricTypes) -> Metric:


def gaussian_euclidean(
inverse_mass_matrix: Array,
inverse_mass_matrix: Array,
) -> Metric:
r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum
:cite:p:`betancourt2013general`.
Expand Down Expand Up @@ -128,42 +129,14 @@ def gaussian_euclidean(
itself given the values of the momentum along the trajectory.

"""
ndim = jnp.ndim(inverse_mass_matrix) # type: ignore[arg-type]
shape = jnp.shape(inverse_mass_matrix)[:1] # type: ignore[arg-type]
inv_mass_matrix_sqrt, mass_matrix_sqrt, matmul = _format_covariance(inverse_mass_matrix, get_inv=True)

if ndim == 1: # diagonal mass matrix
mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))
matmul = jnp.multiply

elif ndim == 2:
# inverse mass matrix can be factored into L*L.T. We want the cholesky
# factor (inverse of L.T) of the mass matrix.
L = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True)
identity = jnp.identity(shape[0])
mass_matrix_sqrt = jscipy.linalg.solve_triangular(
L, identity, lower=True, trans=True
)
# Note that mass_matrix_sqrt is a upper triangular matrix here, with
# jscipy.linalg.inv(mass_matrix_sqrt @ mass_matrix_sqrt.T)
# == inverse_mass_matrix
# An alternative is to compute directly the cholesky factor of the inverse mass
# matrix
# mass_matrix_sqrt = jscipy.linalg.cholesky(
# jscipy.linalg.inv(inverse_mass_matrix), lower=True)
# which the result would instead be a lower triangular matrix.
matmul = jnp.matmul

else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {ndim}."
)

def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree:
return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
del position
momentum, _ = ravel_pytree(momentum)
Expand All @@ -172,11 +145,11 @@ def kinetic_energy(
return kinetic_energy_val

def is_turning(
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
"""Generalized U-turn criterion :cite:p:`betancourt2013generalizing,nuts_uturn`.

Expand Down Expand Up @@ -205,12 +178,43 @@ def is_turning(
turning_at_right = jnp.dot(velocity_right, rho) <= 0
return turning_at_left | turning_at_right

return Metric(momentum_generator, kinetic_energy, is_turning)
def scale(position: ArrayLikeTree, elements: Tuple[Tuple[ArrayLikeTree, bool]]) -> Tuple[ArrayLikeTree]:
"""Scale elements by the mass matrix.

Parameters
----------
position
The current position. Not used in this metric.
elements
A tuple of (element, inv) pairs to scale.
If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.

Returns
-------
scaled_elements
The scaled elements.
"""
scaled_elements = []
for element, inv in elements:
ravelled_element, unravel_fn = ravel_pytree(element)
if inv:
ravelled_element = matmul(inv_mass_matrix_sqrt, ravelled_element)
else:
ravelled_element = matmul(mass_matrix_sqrt, ravelled_element)
scaled_elements.append(unravel_fn(ravelled_element))
return tuple(scaled_elements)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)


def gaussian_riemannian(
mass_matrix_fn: Callable,
mass_matrix_fn: Callable,
) -> Metric:





def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree:
mass_matrix = mass_matrix_fn(position)
ndim = jnp.ndim(mass_matrix)
Expand All @@ -227,7 +231,7 @@ def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTr
return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
if position is None:
raise ValueError(
Expand All @@ -252,11 +256,11 @@ def kinetic_energy(
)

def is_turning(
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
del momentum_left, momentum_right, momentum_sum, position_left, position_right
raise NotImplementedError(
Expand All @@ -283,4 +287,71 @@ def is_turning(
# turning_at_right = jnp.dot(velocity_right, rho) <= 0
# return turning_at_left | turning_at_right

def scale(position: ArrayLikeTree, elements: Tuple[Tuple[ArrayLikeTree, bool]]) -> Tuple[ArrayLikeTree]:
"""Scale elements by the mass matrix.

Parameters
----------
position
The current position.
elements
A tuple of (element, inv) pairs to scale.
If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.

Returns
-------
scaled_elements
The scaled elements.
"""
scaled_elements = []
mass_matrix = mass_matrix_fn(position)
# some small performance improvement: group by inv and only compute the inverse Cholesky if needed

inv_elements = [(k, element) for k, (element, inv) in enumerate(elements) if inv]
non_inv_elements = [(k, element) for k, (element, inv) in enumerate(elements) if not inv]
argsort = [k for k, _ in non_inv_elements] + [k for k, _ in inv_elements]

mass_matrix_sqrt, inv_sqrt_mass_matrix, matmul = _format_covariance(mass_matrix, get_inv=bool(inv_elements))

for _, element in non_inv_elements:
rav_element, unravel_fn = ravel_pytree(element)
rav_element = matmul(mass_matrix_sqrt, rav_element)
scaled_elements.append(unravel_fn(rav_element))

if inv_elements:
for _, element in inv_elements:
rav_element, unravel_fn = ravel_pytree(element)
rav_element = matmul(inv_sqrt_mass_matrix, rav_element)
scaled_elements.append(unravel_fn(rav_element))

scaled_elements = [scaled_elements[k] for k in argsort]

return tuple(scaled_elements)

return Metric(momentum_generator, kinetic_energy, is_turning)
AdrienCorenflos marked this conversation as resolved.
Show resolved Hide resolved

def _format_covariance(mass_matrix: Array, get_inv):
ndim = jnp.ndim(mass_matrix)
if ndim == 1:
mass_matrix_sqrt = jnp.sqrt(mass_matrix)
matmul = jnp.multiply
if get_inv:
inv_mass_matrix_sqrt = jnp.reciprocal(mass_matrix_sqrt)
else:
inv_mass_matrix_sqrt = None
elif ndim == 2:
mass_matrix_sqrt = jscipy.linalg.cholesky(mass_matrix, lower=True)
matmul = jnp.matmul
if get_inv:
identity = jnp.identity(mass_matrix.shape[0])
inv_mass_matrix_sqrt = jscipy.linalg.solve_triangular(
mass_matrix_sqrt, identity, lower=True
)
else:
inv_mass_matrix_sqrt = None
else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {jnp.ndim(mass_matrix)}."
)
return mass_matrix_sqrt, inv_mass_matrix_sqrt, matmul
Loading