28
28
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.
29
29
30
30
"""
31
- from typing import Callable , NamedTuple , Optional , Protocol , Tuple , Union
31
+ from typing import Callable , NamedTuple , Optional , Protocol , Union
32
32
33
+ import jax
33
34
import jax .numpy as jnp
34
35
import jax .scipy as jscipy
35
36
from chex import Numeric
@@ -64,7 +65,7 @@ class Metric(NamedTuple):
64
65
sample_momentum : Callable [[PRNGKey , ArrayLikeTree ], ArrayLikeTree ]
65
66
kinetic_energy : KineticEnergy
66
67
check_turning : CheckTurning
67
- scale : Callable [[ArrayLikeTree , Tuple [ Tuple [ ArrayLikeTree , bool ]] ], ArrayLikeTree ]
68
+ scale : Callable [[ArrayLikeTree , ArrayLikeTree , bool ], ArrayLikeTree ]
68
69
69
70
70
71
MetricTypes = Union [Metric , Array , Callable [[ArrayLikeTree ], Array ]]
@@ -129,8 +130,8 @@ def gaussian_euclidean(
129
130
itself given the values of the momentum along the trajectory.
130
131
131
132
"""
132
- inv_mass_matrix_sqrt , mass_matrix_sqrt , diag = _format_covariance (
133
- inverse_mass_matrix , get_inv = True
133
+ mass_matrix_sqrt , inv_mass_matrix_sqrt , diag = _format_covariance (
134
+ inverse_mass_matrix , is_inv = True
134
135
)
135
136
136
137
def momentum_generator (rng_key : PRNGKey , position : ArrayLikeTree ) -> ArrayTree :
@@ -180,32 +181,34 @@ def is_turning(
180
181
return turning_at_left | turning_at_right
181
182
182
183
def scale (
183
- position : ArrayLikeTree , elements : Tuple [ Tuple [ ArrayLikeTree , bool ]]
184
- ) -> Tuple [ ArrayLikeTree ] :
184
+ position : ArrayLikeTree , element : ArrayLikeTree , inv : ArrayLikeTree
185
+ ) -> ArrayLikeTree :
185
186
"""Scale elements by the mass matrix.
186
187
187
188
Parameters
188
189
----------
189
190
position
190
191
The current position. Not used in this metric.
191
192
elements
192
- A tuple of (element, inv) pairs to scale.
193
- If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.
193
+ Elements to scale
194
+ invs
195
+ Whether to scale the elements by the inverse mass matrix or the mass matrix.
196
+ If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
197
+ Same pytree structure as `elements`.
194
198
195
199
Returns
196
200
-------
197
201
scaled_elements
198
202
The scaled elements.
199
203
"""
200
- scaled_elements = []
201
- for element , inv in elements :
202
- ravelled_element , unravel_fn = ravel_pytree (element )
203
- if inv :
204
- ravelled_element = linear_map (inv_mass_matrix_sqrt , ravelled_element )
205
- else :
206
- ravelled_element = linear_map (mass_matrix_sqrt , ravelled_element )
207
- scaled_elements .append (unravel_fn (ravelled_element ))
208
- return tuple (scaled_elements ) # type: ignore
204
+
205
+ ravelled_element , unravel_fn = ravel_pytree (element )
206
+ scaled = jax .lax .cond (
207
+ inv ,
208
+ lambda : linear_map (inv_mass_matrix_sqrt , ravelled_element ),
209
+ lambda : linear_map (mass_matrix_sqrt , ravelled_element ),
210
+ )
211
+ return unravel_fn (scaled )
209
212
210
213
return Metric (momentum_generator , kinetic_energy , is_turning , scale )
211
214
@@ -215,7 +218,7 @@ def gaussian_riemannian(
215
218
) -> Metric :
216
219
def momentum_generator (rng_key : PRNGKey , position : ArrayLikeTree ) -> ArrayLikeTree :
217
220
mass_matrix = mass_matrix_fn (position )
218
- mass_matrix_sqrt , * _ = _format_covariance (mass_matrix , get_inv = False )
221
+ mass_matrix_sqrt , * _ = _format_covariance (mass_matrix , is_inv = False )
219
222
220
223
return generate_gaussian_noise (rng_key , position , sigma = mass_matrix_sqrt )
221
224
@@ -232,10 +235,10 @@ def kinetic_energy(
232
235
momentum , _ = ravel_pytree (momentum )
233
236
mass_matrix = mass_matrix_fn (position )
234
237
sqrt_mass_matrix , inv_sqrt_mass_matrix , diag = _format_covariance (
235
- mass_matrix , get_inv = True
238
+ mass_matrix , is_inv = False
236
239
)
237
240
238
- return _energy (momentum , 0 , sqrt_mass_matrix , inv_sqrt_mass_matrix , diag )
241
+ return _energy (momentum , 0 , sqrt_mass_matrix , inv_sqrt_mass_matrix . T , diag )
239
242
240
243
def is_turning (
241
244
momentum_left : ArrayLikeTree ,
@@ -270,76 +273,58 @@ def is_turning(
270
273
# return turning_at_left | turning_at_right
271
274
272
275
def scale (
273
- position : ArrayLikeTree , elements : Tuple [ Tuple [ ArrayLikeTree , bool ]]
274
- ) -> Tuple [ ArrayLikeTree ] :
276
+ position : ArrayLikeTree , element : ArrayLikeTree , inv : ArrayLikeTree
277
+ ) -> ArrayLikeTree :
275
278
"""Scale elements by the mass matrix.
276
279
277
280
Parameters
278
281
----------
279
282
position
280
283
The current position.
281
- elements
282
- A tuple of (element, inv) pairs to scale.
283
- If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.
284
284
285
285
Returns
286
286
-------
287
287
scaled_elements
288
288
The scaled elements.
289
289
"""
290
- scaled_elements = []
291
290
mass_matrix = mass_matrix_fn (position )
292
- # some small performance improvement: group by inv and only compute the inverse Cholesky if needed
293
-
294
- inv_elements = [
295
- (k , element ) for k , (element , inv ) in enumerate (elements ) if inv
296
- ]
297
- non_inv_elements = [
298
- (k , element ) for k , (element , inv ) in enumerate (elements ) if not inv
299
- ]
300
- argsort = [k for k , _ in non_inv_elements ] + [k for k , _ in inv_elements ]
301
-
302
- mass_matrix_sqrt , inv_sqrt_mass_matrix , diag = _format_covariance (
303
- mass_matrix , get_inv = bool (inv_elements )
291
+ mass_matrix_sqrt , inv_mass_matrix_sqrt , diag = _format_covariance (
292
+ mass_matrix , is_inv = False
304
293
)
305
-
306
- for _ , element in non_inv_elements :
307
- rav_element , unravel_fn = ravel_pytree (element )
308
- rav_element = linear_map (mass_matrix_sqrt , rav_element )
309
- scaled_elements .append (unravel_fn (rav_element ))
310
-
311
- if inv_elements :
312
- for _ , element in inv_elements :
313
- rav_element , unravel_fn = ravel_pytree (element )
314
- rav_element = linear_map (inv_sqrt_mass_matrix , rav_element )
315
- scaled_elements .append (unravel_fn (rav_element ))
316
-
317
- scaled_elements = [scaled_elements [k ] for k in argsort ]
318
-
319
- return tuple (scaled_elements ) # type: ignore
294
+ ravelled_element , unravel_fn = ravel_pytree (element )
295
+ scaled = jax .lax .cond (
296
+ inv ,
297
+ lambda : linear_map (inv_mass_matrix_sqrt , ravelled_element ),
298
+ lambda : linear_map (mass_matrix_sqrt , ravelled_element ),
299
+ )
300
+ return unravel_fn (scaled )
320
301
321
302
return Metric (momentum_generator , kinetic_energy , is_turning , scale )
322
303
323
304
324
- def _format_covariance (cov : Array , get_inv ):
305
+ def _format_covariance (cov : Array , is_inv ):
325
306
ndim = jnp .ndim (cov )
326
307
if ndim == 1 :
327
308
cov_sqrt = jnp .sqrt (cov )
309
+ inv_cov_sqrt = 1 / cov_sqrt
328
310
diag = lambda x : x
329
- if get_inv :
330
- inv_cov_sqrt = jnp .reciprocal (cov_sqrt )
331
- else :
332
- inv_cov_sqrt = None
311
+ if is_inv :
312
+ inv_cov_sqrt , cov_sqrt = cov_sqrt , inv_cov_sqrt
333
313
elif ndim == 2 :
334
- cov_sqrt = jscipy .linalg .cholesky (cov , lower = False )
335
- diag = lambda x : jnp .diag (x )
336
- if get_inv :
337
- identity = jnp .identity (cov .shape [0 ])
338
- inv_cov_sqrt = jscipy .linalg .solve_triangular (
339
- cov_sqrt , identity , lower = False
314
+ identity = jnp .identity (cov .shape [0 ])
315
+ if is_inv :
316
+ inv_cov_sqrt = jscipy .linalg .cholesky (cov , lower = True )
317
+ cov_sqrt = jscipy .linalg .solve_triangular (
318
+ inv_cov_sqrt , identity , lower = True , trans = True
340
319
)
341
320
else :
342
- inv_cov_sqrt = None
321
+ cov_sqrt = jscipy .linalg .cholesky (cov , lower = False ).T
322
+ inv_cov_sqrt = jscipy .linalg .solve_triangular (
323
+ cov_sqrt , identity , lower = True , trans = True
324
+ )
325
+
326
+ diag = lambda x : jnp .diag (x )
327
+
343
328
else :
344
329
raise ValueError (
345
330
"The mass matrix has the wrong number of dimensions:"
0 commit comments