Skip to content

Commit 996258c

Browse files
committed
ready for test
1 parent 37b5f57 commit 996258c

File tree

1 file changed

+30
-69
lines changed

1 file changed

+30
-69
lines changed

blackjax/adaptation/adjusted_mclmc_adaptation.py

+30-69
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,13 @@
22
import jax.numpy as jnp
33
from jax.flatten_util import ravel_pytree
44

5-
import optax
6-
import blackjax
7-
from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState, handle_nans
5+
from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState
86
from blackjax.adaptation.step_size import (
97
DualAveragingAdaptationState,
108
dual_averaging_adaptation,
119
)
1210
from blackjax.diagnostics import effective_sample_size
13-
from blackjax.mcmc.adjusted_mclmc import rescale
14-
from blackjax.util import pytree_size, incremental_value_update, run_inference_algorithm
15-
16-
from blackjax.mcmc.integrators import (
17-
generate_euclidean_integrator,
18-
generate_isokinetic_integrator,
19-
mclachlan,
20-
yoshida,
21-
velocity_verlet,
22-
omelyan,
23-
isokinetic_mclachlan,
24-
isokinetic_velocity_verlet,
25-
isokinetic_yoshida,
26-
isokinetic_omelyan,
27-
)
11+
from blackjax.util import incremental_value_update, pytree_size
2812

2913
Lratio_lowerbound = 0.0
3014
Lratio_upperbound = 2.0
@@ -92,12 +76,7 @@ def adjusted_mclmc_find_L_and_step_size(
9276

9377
for i in range(num_windows):
9478
window_key = jax.random.fold_in(part1_key, i)
95-
(
96-
state,
97-
params,
98-
eigenvector
99-
100-
) = adjusted_mclmc_make_L_step_size_adaptation(
79+
(state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation(
10180
kernel=mclmc_kernel,
10281
dim=dim,
10382
frac_tune1=frac_tune1,
@@ -106,25 +85,22 @@ def adjusted_mclmc_find_L_and_step_size(
10685
diagonal_preconditioning=diagonal_preconditioning,
10786
max=max,
10887
tuning_factor=tuning_factor,
109-
)(
110-
state, params, num_steps, window_key
111-
)
88+
)(state, params, num_steps, window_key)
11289

11390
if frac_tune3 != 0:
11491
for i in range(2):
11592
part2_key = jax.random.fold_in(part2_key, i)
11693
part2_key1, part2_key2 = jax.random.split(part2_key, 2)
11794

11895
state, params = adjusted_mclmc_make_adaptation_L(
119-
mclmc_kernel, frac=frac_tune3, Lfactor=0.5, max=max, eigenvector=eigenvector,
96+
mclmc_kernel,
97+
frac=frac_tune3,
98+
Lfactor=0.5,
99+
max=max,
100+
eigenvector=eigenvector,
120101
)(state, params, num_steps, part2_key1)
121102

122-
(
123-
state,
124-
params,
125-
_
126-
127-
) = adjusted_mclmc_make_L_step_size_adaptation(
103+
(state, params, _) = adjusted_mclmc_make_L_step_size_adaptation(
128104
kernel=mclmc_kernel,
129105
dim=dim,
130106
frac_tune1=frac_tune1,
@@ -134,12 +110,7 @@ def adjusted_mclmc_find_L_and_step_size(
134110
diagonal_preconditioning=diagonal_preconditioning,
135111
max=max,
136112
tuning_factor=tuning_factor,
137-
)(
138-
state, params, num_steps, part2_key2
139-
)
140-
141-
142-
113+
)(state, params, num_steps, part2_key2)
143114

144115
return state, params
145116

@@ -152,7 +123,7 @@ def adjusted_mclmc_make_L_step_size_adaptation(
152123
target,
153124
diagonal_preconditioning,
154125
fix_L_first_da=False,
155-
max='avg',
126+
max="avg",
156127
tuning_factor=1.0,
157128
):
158129
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""
@@ -161,8 +132,6 @@ def dual_avg_step(fix_L, update_da):
161132
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
162133

163134
def step(iteration_state, weight_and_key):
164-
165-
166135
mask, rng_key = weight_and_key
167136
(
168137
previous_state,
@@ -180,7 +149,6 @@ def step(iteration_state, weight_and_key):
180149
step_size=params.step_size,
181150
sqrt_diag_cov=params.sqrt_diag_cov,
182151
)
183-
184152

185153
# step updating
186154
success, state, step_size_max, energy_change = handle_nans(
@@ -231,7 +199,7 @@ def step(iteration_state, weight_and_key):
231199
+ (1 - mask) * params.L,
232200
)
233201

234-
if max!='max_svd':
202+
if max != "max_svd":
235203
state_position = None
236204
else:
237205
state_position = state.position
@@ -259,9 +227,6 @@ def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da
259227
),
260228
xs=(mask, keys),
261229
)
262-
263-
264-
265230

266231
def L_step_size_adaptation(state, params, num_steps, rng_key):
267232
num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
@@ -294,26 +259,21 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
294259
update_da=update_da,
295260
)
296261

297-
298262
final_stepsize = final_da(dual_avg_state)
299263
params = params._replace(step_size=final_stepsize)
300264

301-
302-
303-
304265
# determine L
305-
eigenvector = None
266+
eigenvector = None
306267
if num_steps2 != 0.0:
307268
x_average, x_squared_average = average[0], average[1]
308269
variances = x_squared_average - jnp.square(x_average)
309270

310-
if max=='max':
311-
contract = lambda x: jnp.sqrt(jnp.max(x)*dim)*tuning_factor
271+
if max == "max":
272+
contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor
312273

274+
elif max == "avg":
275+
contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor
313276

314-
elif max=='avg':
315-
contract = lambda x: jnp.sqrt(jnp.sum(x))*tuning_factor
316-
317277
else:
318278
raise ValueError("max should be either 'max' or 'avg'")
319279

@@ -346,13 +306,14 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
346306

347307
params = params._replace(step_size=final_da(dual_avg_state))
348308

349-
350309
return state, params, eigenvector
351310

352311
return L_step_size_adaptation
353312

354313

355-
def adjusted_mclmc_make_adaptation_L(kernel, frac, Lfactor, max='avg', eigenvector=None):
314+
def adjusted_mclmc_make_adaptation_L(
315+
kernel, frac, Lfactor, max="avg", eigenvector=None
316+
):
356317
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""
357318

358319
def adaptation_L(state, params, num_steps, key):
@@ -375,29 +336,29 @@ def step(state, key):
375336
xs=adaptation_L_keys,
376337
)
377338

378-
if max=='max':
339+
if max == "max":
379340
contract = jnp.min
380341
else:
381342
contract = jnp.mean
382343

383344
flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)
384345

385-
386346
if eigenvector is not None:
387-
388-
flat_samples = jnp.expand_dims(jnp.einsum('ij,j', flat_samples, eigenvector),1)
347+
flat_samples = jnp.expand_dims(
348+
jnp.einsum("ij,j", flat_samples, eigenvector), 1
349+
)
389350

390351
# number of effective samples per 1 actual sample
391-
ess = contract(effective_sample_size(flat_samples[None, ...]))/num_steps
352+
ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps
392353

393-
return state, params._replace(L=jnp.clip(Lfactor * params.L / jnp.mean(ess), max=params.L*2))
354+
return state, params._replace(
355+
L=jnp.clip(Lfactor * params.L / jnp.mean(ess), max=params.L * 2)
356+
)
394357

395358
return adaptation_L
396359

397360

398-
def handle_nans(
399-
previous_state, next_state, step_size, step_size_max, kinetic_change
400-
):
361+
def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
401362
"""if there are nans, let's reduce the stepsize, and not update the state. The
402363
function returns the old state in this case."""
403364

0 commit comments

Comments
 (0)