2
2
import jax .numpy as jnp
3
3
from jax .flatten_util import ravel_pytree
4
4
5
- import optax
6
- import blackjax
7
- from blackjax .adaptation .mclmc_adaptation import MCLMCAdaptationState , handle_nans
5
+ from blackjax .adaptation .mclmc_adaptation import MCLMCAdaptationState
8
6
from blackjax .adaptation .step_size import (
9
7
DualAveragingAdaptationState ,
10
8
dual_averaging_adaptation ,
11
9
)
12
10
from blackjax .diagnostics import effective_sample_size
13
- from blackjax .mcmc .adjusted_mclmc import rescale
14
- from blackjax .util import pytree_size , incremental_value_update , run_inference_algorithm
15
-
16
- from blackjax .mcmc .integrators import (
17
- generate_euclidean_integrator ,
18
- generate_isokinetic_integrator ,
19
- mclachlan ,
20
- yoshida ,
21
- velocity_verlet ,
22
- omelyan ,
23
- isokinetic_mclachlan ,
24
- isokinetic_velocity_verlet ,
25
- isokinetic_yoshida ,
26
- isokinetic_omelyan ,
27
- )
11
+ from blackjax .util import incremental_value_update , pytree_size
28
12
29
13
Lratio_lowerbound = 0.0
30
14
Lratio_upperbound = 2.0
@@ -92,12 +76,7 @@ def adjusted_mclmc_find_L_and_step_size(
92
76
93
77
for i in range (num_windows ):
94
78
window_key = jax .random .fold_in (part1_key , i )
95
- (
96
- state ,
97
- params ,
98
- eigenvector
99
-
100
- ) = adjusted_mclmc_make_L_step_size_adaptation (
79
+ (state , params , eigenvector ) = adjusted_mclmc_make_L_step_size_adaptation (
101
80
kernel = mclmc_kernel ,
102
81
dim = dim ,
103
82
frac_tune1 = frac_tune1 ,
@@ -106,25 +85,22 @@ def adjusted_mclmc_find_L_and_step_size(
106
85
diagonal_preconditioning = diagonal_preconditioning ,
107
86
max = max ,
108
87
tuning_factor = tuning_factor ,
109
- )(
110
- state , params , num_steps , window_key
111
- )
88
+ )(state , params , num_steps , window_key )
112
89
113
90
if frac_tune3 != 0 :
114
91
for i in range (2 ):
115
92
part2_key = jax .random .fold_in (part2_key , i )
116
93
part2_key1 , part2_key2 = jax .random .split (part2_key , 2 )
117
94
118
95
state , params = adjusted_mclmc_make_adaptation_L (
119
- mclmc_kernel , frac = frac_tune3 , Lfactor = 0.5 , max = max , eigenvector = eigenvector ,
96
+ mclmc_kernel ,
97
+ frac = frac_tune3 ,
98
+ Lfactor = 0.5 ,
99
+ max = max ,
100
+ eigenvector = eigenvector ,
120
101
)(state , params , num_steps , part2_key1 )
121
102
122
- (
123
- state ,
124
- params ,
125
- _
126
-
127
- ) = adjusted_mclmc_make_L_step_size_adaptation (
103
+ (state , params , _ ) = adjusted_mclmc_make_L_step_size_adaptation (
128
104
kernel = mclmc_kernel ,
129
105
dim = dim ,
130
106
frac_tune1 = frac_tune1 ,
@@ -134,12 +110,7 @@ def adjusted_mclmc_find_L_and_step_size(
134
110
diagonal_preconditioning = diagonal_preconditioning ,
135
111
max = max ,
136
112
tuning_factor = tuning_factor ,
137
- )(
138
- state , params , num_steps , part2_key2
139
- )
140
-
141
-
142
-
113
+ )(state , params , num_steps , part2_key2 )
143
114
144
115
return state , params
145
116
@@ -152,7 +123,7 @@ def adjusted_mclmc_make_L_step_size_adaptation(
152
123
target ,
153
124
diagonal_preconditioning ,
154
125
fix_L_first_da = False ,
155
- max = ' avg' ,
126
+ max = " avg" ,
156
127
tuning_factor = 1.0 ,
157
128
):
158
129
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""
@@ -161,8 +132,6 @@ def dual_avg_step(fix_L, update_da):
161
132
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
162
133
163
134
def step (iteration_state , weight_and_key ):
164
-
165
-
166
135
mask , rng_key = weight_and_key
167
136
(
168
137
previous_state ,
@@ -180,7 +149,6 @@ def step(iteration_state, weight_and_key):
180
149
step_size = params .step_size ,
181
150
sqrt_diag_cov = params .sqrt_diag_cov ,
182
151
)
183
-
184
152
185
153
# step updating
186
154
success , state , step_size_max , energy_change = handle_nans (
@@ -231,7 +199,7 @@ def step(iteration_state, weight_and_key):
231
199
+ (1 - mask ) * params .L ,
232
200
)
233
201
234
- if max != ' max_svd' :
202
+ if max != " max_svd" :
235
203
state_position = None
236
204
else :
237
205
state_position = state .position
@@ -259,9 +227,6 @@ def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da
259
227
),
260
228
xs = (mask , keys ),
261
229
)
262
-
263
-
264
-
265
230
266
231
def L_step_size_adaptation (state , params , num_steps , rng_key ):
267
232
num_steps1 , num_steps2 = int (num_steps * frac_tune1 ), int (
@@ -294,26 +259,21 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
294
259
update_da = update_da ,
295
260
)
296
261
297
-
298
262
final_stepsize = final_da (dual_avg_state )
299
263
params = params ._replace (step_size = final_stepsize )
300
264
301
-
302
-
303
-
304
265
# determine L
305
- eigenvector = None
266
+ eigenvector = None
306
267
if num_steps2 != 0.0 :
307
268
x_average , x_squared_average = average [0 ], average [1 ]
308
269
variances = x_squared_average - jnp .square (x_average )
309
270
310
- if max == ' max' :
311
- contract = lambda x : jnp .sqrt (jnp .max (x )* dim )* tuning_factor
271
+ if max == " max" :
272
+ contract = lambda x : jnp .sqrt (jnp .max (x ) * dim ) * tuning_factor
312
273
274
+ elif max == "avg" :
275
+ contract = lambda x : jnp .sqrt (jnp .sum (x )) * tuning_factor
313
276
314
- elif max == 'avg' :
315
- contract = lambda x : jnp .sqrt (jnp .sum (x ))* tuning_factor
316
-
317
277
else :
318
278
raise ValueError ("max should be either 'max' or 'avg'" )
319
279
@@ -346,13 +306,14 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
346
306
347
307
params = params ._replace (step_size = final_da (dual_avg_state ))
348
308
349
-
350
309
return state , params , eigenvector
351
310
352
311
return L_step_size_adaptation
353
312
354
313
355
- def adjusted_mclmc_make_adaptation_L (kernel , frac , Lfactor , max = 'avg' , eigenvector = None ):
314
+ def adjusted_mclmc_make_adaptation_L (
315
+ kernel , frac , Lfactor , max = "avg" , eigenvector = None
316
+ ):
356
317
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""
357
318
358
319
def adaptation_L (state , params , num_steps , key ):
@@ -375,29 +336,29 @@ def step(state, key):
375
336
xs = adaptation_L_keys ,
376
337
)
377
338
378
- if max == ' max' :
339
+ if max == " max" :
379
340
contract = jnp .min
380
341
else :
381
342
contract = jnp .mean
382
343
383
344
flat_samples = jax .vmap (lambda x : ravel_pytree (x )[0 ])(samples )
384
345
385
-
386
346
if eigenvector is not None :
387
-
388
- flat_samples = jnp .expand_dims (jnp .einsum ('ij,j' , flat_samples , eigenvector ),1 )
347
+ flat_samples = jnp .expand_dims (
348
+ jnp .einsum ("ij,j" , flat_samples , eigenvector ), 1
349
+ )
389
350
390
351
# number of effective samples per 1 actual sample
391
- ess = contract (effective_sample_size (flat_samples [None , ...]))/ num_steps
352
+ ess = contract (effective_sample_size (flat_samples [None , ...])) / num_steps
392
353
393
- return state , params ._replace (L = jnp .clip (Lfactor * params .L / jnp .mean (ess ), max = params .L * 2 ))
354
+ return state , params ._replace (
355
+ L = jnp .clip (Lfactor * params .L / jnp .mean (ess ), max = params .L * 2 )
356
+ )
394
357
395
358
return adaptation_L
396
359
397
360
398
- def handle_nans (
399
- previous_state , next_state , step_size , step_size_max , kinetic_change
400
- ):
361
+ def handle_nans (previous_state , next_state , step_size , step_size_max , kinetic_change ):
401
362
"""if there are nans, let's reduce the stepsize, and not update the state. The
402
363
function returns the old state in this case."""
403
364
0 commit comments