Skip to content

Commit 871713f

Browse files
Fixing a bunch of tests
1 parent fd70221 commit 871713f

File tree

2 files changed

+122
-119
lines changed

2 files changed

+122
-119
lines changed

blackjax/mcmc/metrics.py

+50-65
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.
2929
3030
"""
31-
from typing import Callable, NamedTuple, Optional, Protocol, Tuple, Union
31+
from typing import Callable, NamedTuple, Optional, Protocol, Union
3232

33+
import jax
3334
import jax.numpy as jnp
3435
import jax.scipy as jscipy
3536
from chex import Numeric
@@ -64,7 +65,7 @@ class Metric(NamedTuple):
6465
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
6566
kinetic_energy: KineticEnergy
6667
check_turning: CheckTurning
67-
scale: Callable[[ArrayLikeTree, Tuple[Tuple[ArrayLikeTree, bool]]], ArrayLikeTree]
68+
scale: Callable[[ArrayLikeTree, ArrayLikeTree, bool], ArrayLikeTree]
6869

6970

7071
MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]]
@@ -129,8 +130,8 @@ def gaussian_euclidean(
129130
itself given the values of the momentum along the trajectory.
130131
131132
"""
132-
inv_mass_matrix_sqrt, mass_matrix_sqrt, diag = _format_covariance(
133-
inverse_mass_matrix, get_inv=True
133+
mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance(
134+
inverse_mass_matrix, is_inv=True
134135
)
135136

136137
def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree:
@@ -180,32 +181,34 @@ def is_turning(
180181
return turning_at_left | turning_at_right
181182

182183
def scale(
183-
position: ArrayLikeTree, elements: Tuple[Tuple[ArrayLikeTree, bool]]
184-
) -> Tuple[ArrayLikeTree]:
184+
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
185+
) -> ArrayLikeTree:
185186
"""Scale elements by the mass matrix.
186187
187188
Parameters
188189
----------
189190
position
190191
The current position. Not used in this metric.
191192
elements
192-
A tuple of (element, inv) pairs to scale.
193-
If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.
193+
Elements to scale
194+
invs
195+
Whether to scale the elements by the inverse mass matrix or the mass matrix.
196+
If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
197+
Same pytree structure as `elements`.
194198
195199
Returns
196200
-------
197201
scaled_elements
198202
The scaled elements.
199203
"""
200-
scaled_elements = []
201-
for element, inv in elements:
202-
ravelled_element, unravel_fn = ravel_pytree(element)
203-
if inv:
204-
ravelled_element = linear_map(inv_mass_matrix_sqrt, ravelled_element)
205-
else:
206-
ravelled_element = linear_map(mass_matrix_sqrt, ravelled_element)
207-
scaled_elements.append(unravel_fn(ravelled_element))
208-
return tuple(scaled_elements) # type: ignore
204+
205+
ravelled_element, unravel_fn = ravel_pytree(element)
206+
scaled = jax.lax.cond(
207+
inv,
208+
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
209+
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
210+
)
211+
return unravel_fn(scaled)
209212

210213
return Metric(momentum_generator, kinetic_energy, is_turning, scale)
211214

@@ -215,7 +218,7 @@ def gaussian_riemannian(
215218
) -> Metric:
216219
def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree:
217220
mass_matrix = mass_matrix_fn(position)
218-
mass_matrix_sqrt, *_ = _format_covariance(mass_matrix, get_inv=False)
221+
mass_matrix_sqrt, *_ = _format_covariance(mass_matrix, is_inv=False)
219222

220223
return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)
221224

@@ -232,10 +235,10 @@ def kinetic_energy(
232235
momentum, _ = ravel_pytree(momentum)
233236
mass_matrix = mass_matrix_fn(position)
234237
sqrt_mass_matrix, inv_sqrt_mass_matrix, diag = _format_covariance(
235-
mass_matrix, get_inv=True
238+
mass_matrix, is_inv=False
236239
)
237240

238-
return _energy(momentum, 0, sqrt_mass_matrix, inv_sqrt_mass_matrix, diag)
241+
return _energy(momentum, 0, sqrt_mass_matrix, inv_sqrt_mass_matrix.T, diag)
239242

240243
def is_turning(
241244
momentum_left: ArrayLikeTree,
@@ -270,76 +273,58 @@ def is_turning(
270273
# return turning_at_left | turning_at_right
271274

272275
def scale(
273-
position: ArrayLikeTree, elements: Tuple[Tuple[ArrayLikeTree, bool]]
274-
) -> Tuple[ArrayLikeTree]:
276+
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
277+
) -> ArrayLikeTree:
275278
"""Scale elements by the mass matrix.
276279
277280
Parameters
278281
----------
279282
position
280283
The current position.
281-
elements
282-
A tuple of (element, inv) pairs to scale.
283-
If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.
284284
285285
Returns
286286
-------
287287
scaled_elements
288288
The scaled elements.
289289
"""
290-
scaled_elements = []
291290
mass_matrix = mass_matrix_fn(position)
292-
# some small performance improvement: group by inv and only compute the inverse Cholesky if needed
293-
294-
inv_elements = [
295-
(k, element) for k, (element, inv) in enumerate(elements) if inv
296-
]
297-
non_inv_elements = [
298-
(k, element) for k, (element, inv) in enumerate(elements) if not inv
299-
]
300-
argsort = [k for k, _ in non_inv_elements] + [k for k, _ in inv_elements]
301-
302-
mass_matrix_sqrt, inv_sqrt_mass_matrix, diag = _format_covariance(
303-
mass_matrix, get_inv=bool(inv_elements)
291+
mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance(
292+
mass_matrix, is_inv=False
304293
)
305-
306-
for _, element in non_inv_elements:
307-
rav_element, unravel_fn = ravel_pytree(element)
308-
rav_element = linear_map(mass_matrix_sqrt, rav_element)
309-
scaled_elements.append(unravel_fn(rav_element))
310-
311-
if inv_elements:
312-
for _, element in inv_elements:
313-
rav_element, unravel_fn = ravel_pytree(element)
314-
rav_element = linear_map(inv_sqrt_mass_matrix, rav_element)
315-
scaled_elements.append(unravel_fn(rav_element))
316-
317-
scaled_elements = [scaled_elements[k] for k in argsort]
318-
319-
return tuple(scaled_elements) # type: ignore
294+
ravelled_element, unravel_fn = ravel_pytree(element)
295+
scaled = jax.lax.cond(
296+
inv,
297+
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
298+
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
299+
)
300+
return unravel_fn(scaled)
320301

321302
return Metric(momentum_generator, kinetic_energy, is_turning, scale)
322303

323304

324-
def _format_covariance(cov: Array, get_inv):
305+
def _format_covariance(cov: Array, is_inv):
325306
ndim = jnp.ndim(cov)
326307
if ndim == 1:
327308
cov_sqrt = jnp.sqrt(cov)
309+
inv_cov_sqrt = 1 / cov_sqrt
328310
diag = lambda x: x
329-
if get_inv:
330-
inv_cov_sqrt = jnp.reciprocal(cov_sqrt)
331-
else:
332-
inv_cov_sqrt = None
311+
if is_inv:
312+
inv_cov_sqrt, cov_sqrt = cov_sqrt, inv_cov_sqrt
333313
elif ndim == 2:
334-
cov_sqrt = jscipy.linalg.cholesky(cov, lower=False)
335-
diag = lambda x: jnp.diag(x)
336-
if get_inv:
337-
identity = jnp.identity(cov.shape[0])
338-
inv_cov_sqrt = jscipy.linalg.solve_triangular(
339-
cov_sqrt, identity, lower=False
314+
identity = jnp.identity(cov.shape[0])
315+
if is_inv:
316+
inv_cov_sqrt = jscipy.linalg.cholesky(cov, lower=True)
317+
cov_sqrt = jscipy.linalg.solve_triangular(
318+
inv_cov_sqrt, identity, lower=True, trans=True
340319
)
341320
else:
342-
inv_cov_sqrt = None
321+
cov_sqrt = jscipy.linalg.cholesky(cov, lower=False).T
322+
inv_cov_sqrt = jscipy.linalg.solve_triangular(
323+
cov_sqrt, identity, lower=True, trans=True
324+
)
325+
326+
diag = lambda x: jnp.diag(x)
327+
343328
else:
344329
raise ValueError(
345330
"The mass matrix has the wrong number of dimensions:"

0 commit comments

Comments
 (0)