From 17f9cb5cb52f1c14a68b9137e8c5848745111d4f Mon Sep 17 00:00:00 2001
From: aphc14 <177544929+aphc14@users.noreply.github.com>
Date: Sat, 26 Oct 2024 19:10:33 +1100
Subject: [PATCH] Added a function to utilise optax for L-BFGS optimisation.
 Updated lbfgs to truncate the history, reducing the inverse-Hessian
 computation.

---
 blackjax/optimizers/lbfgs.py | 171 +++++++++++++++++++++++++++++++++--
 blackjax/vi/pathfinder.py    |  13 ++-
 2 files changed, 171 insertions(+), 13 deletions(-)

diff --git a/blackjax/optimizers/lbfgs.py b/blackjax/optimizers/lbfgs.py
index 0dd59f003..fd8bd1241 100644
--- a/blackjax/optimizers/lbfgs.py
+++ b/blackjax/optimizers/lbfgs.py
@@ -17,6 +17,8 @@
 import jax.numpy as jnp
 import jax.random
 import jaxopt
+import optax
+import optax.tree_utils as otu
 from jax import lax
 from jax.flatten_util import ravel_pytree
 from jaxopt._src.lbfgs import LbfgsState
@@ -37,6 +39,22 @@
 MIN_STEP_SIZE = 1e-3
 
 
+class _OptaxLBFGSHistory(NamedTuple):
+    x: Array
+    f: Array
+    g: Array
+    alpha: Array
+    update_mask: Array
+    # store intermediate values to perform checks
+    not_converged: Array
+    s: Array
+    z: Array
+    s_l: Array
+    z_l: Array
+    last: Array
+    iter: Array
+
+
 class LBFGSHistory(NamedTuple):
     """Container for the optimization path of a L-BFGS run
 
@@ -60,6 +78,7 @@ class LBFGSHistory(NamedTuple):
     g: Array
     alpha: Array
     update_mask: Array
+    not_converged: Array  # for truncating history, shortening inverse-Hessian calcs and BFGS sampling
 
 
 def minimize_lbfgs(
@@ -148,6 +167,143 @@ def minimize_lbfgs(
     return last_step, history
 
 
+def optax_lbfgs(
+    fun: Callable,
+    x0: Array,
+    maxiter: int,
+    maxcor: int,
+    gtol: float,
+    ftol: float,
+    maxls: int,
+    # **lbfgs_kwargs,  # TODO: forward kwargs to optax.scale_by_zoom_linesearch and optax.value_and_grad_from_state
+):
+    linesearch = optax.scale_by_zoom_linesearch(
+        max_linesearch_steps=maxls,
+        verbose=True,
+    )
+    solver = optax.lbfgs(
+        memory_size=maxcor,
+        linesearch=linesearch,
+    )
+    value_and_grad_fun = optax.value_and_grad_from_state(fun)
+
+    def lbfgs_one_step(carry, i):
+        # the optax L-BFGS state is a 3-element tuple
+        (params, state), previous_history = carry
+        value, grad = value_and_grad_fun(params, state=state)
+        updates, next_state = solver.update(
+            grad, state, params, value=value, grad=grad, value_fn=fun
+        )
+
+        # cast num_linesearch_steps so its dtype is consistent across lax.cond branches
+        info = next_state[2].info._replace(
+            num_linesearch_steps=jnp.asarray(
+                next_state[2].info.num_linesearch_steps, dtype=jnp.int32
+            )
+        )
+
+        next_state = (next_state[0], next_state[1], next_state[2]._replace(info=info))
+
+        # L-BFGS uses a rolling history; find the memory slot of the latest update.
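+        # `count` in the optax solver state is the number of updates taken so
+        # far; the (s, z) memory is reused cyclically, so the slot written by
+        # the latest update is (count - 1) mod maxcor, clamped at 0 before the
+        # first pair has been stored.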
+        iter = state[0].count
+        # `last` indexes the memory slot that received the most recent update
+        last = jnp.max(jnp.array([iter - 1, 0], dtype=jnp.int32)) % maxcor
+        next_params = optax.apply_updates(params, updates)
+
+        # Recover alpha and update mask
+        s_l = next_state[0].diff_params_memory[last]
+        z_l = next_state[0].diff_updates_memory[last]
+        alpha_lm1 = previous_history.alpha
+        alpha_l, mask_l = lbfgs_recover_alpha(alpha_lm1, s_l, z_l)
+
+        # TODO: check correct calc for g
+        # g = next_state[2].grad
+        # g = state[2].grad
+        # g = grad
+        # g = previous_history.g
+        # g = previous_history.g + z_l
+        # g = state[2].grad + z_l
+
+        not_converged = check_convergence(state, next_state, iter)
+        history = _OptaxLBFGSHistory(
+            x=next_params,
+            f=next_state[2].value,
+            g=next_state[2].grad,
+            alpha=alpha_l,
+            update_mask=mask_l,
+            not_converged=not_converged,
+            s=next_state[0].diff_params_memory,
+            z=next_state[0].diff_updates_memory,
+            s_l=s_l,
+            z_l=z_l,
+            last=jnp.asarray(last, dtype=jnp.int32),
+            iter=jnp.asarray(iter, dtype=jnp.int32),
+        )
+        return ((next_params, next_state), history), not_converged
+
+    def check_convergence(state, next_state, iter):
+        f_delta = (
+            jnp.abs(state[2].value - next_state[2].value)
+            / jnp.asarray(
+                [jnp.abs(state[2].value), jnp.abs(next_state[2].value), 1.0]
+            ).max()
+        )
+        next_state_grad = otu.tree_get(next_state[2], "grad")
+        error = otu.tree_l2_norm(next_state_grad)
+        return jnp.array(
+            (iter == 0) | ((error > gtol) & (f_delta > ftol) & (iter < maxiter)),
+            dtype=bool,
+        )
+
+    def non_op(carry, i):
+        (params, state), previous_history = carry
+
+        info = state[2].info._replace(
+            num_linesearch_steps=jnp.asarray(
+                state[2].info.num_linesearch_steps, dtype=jnp.int32
+            )
+        )
+        state = (state[0], state[1], state[2]._replace(info=info))
+
+        return ((params, state), previous_history), jnp.array(False, dtype=bool)
+
+    def scan_body(tup, i):
+        carry, not_converged = tup
+        next_tup = jax.lax.cond(not_converged, lbfgs_one_step, non_op, carry, i)
+        return next_tup, next_tup[0][-1]
+
+    init_state = solver.init(x0)
+    init_history = _OptaxLBFGSHistory(
+        x=init_state[0].params,
+        f=init_state[2].value,
+        g=init_state[2].grad,
+        alpha=jnp.ones_like(x0),
+        update_mask=jnp.zeros_like(x0, dtype=bool),
+        not_converged=jnp.array(True, dtype=bool),
+        s=init_state[0].diff_params_memory,
+        z=init_state[0].diff_updates_memory,
+        s_l=jnp.zeros_like(x0),
+        z_l=jnp.zeros_like(x0),
+        last=jnp.asarray(-1, dtype=jnp.int32),
+        iter=jnp.asarray(-1, dtype=jnp.int32),
+    )
+
+    # Use lax.scan to accumulate history
+    (((final_params, final_state), _), _), history = jax.lax.scan(
+        scan_body,
+        (((x0, init_state), init_history), jnp.array(True, dtype=bool)),
+        jnp.arange(maxiter),
+        length=maxiter,
+    )
+
+    history = jax.tree.map(
+        lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
+        init_history,
+        history,
+    )
+    return (final_params, final_state), history
+
+
 def _minimize_lbfgs(
     fun: Callable,
     x0: Array,
@@ -181,19 +337,21 @@ def lbfgs_one_step(carry, i):
         alpha_l, mask_l = lbfgs_recover_alpha(alpha_lm1, s_l, z_l)
         current_grad = previous_history.g + z_l
+
+        # check convergence
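+        # f_delta is the relative decrease of the objective, normalized by
+        # max(|f_k|, |f_{k+1}|, 1) so that the ftol test is scale-invariant
+        # (the same form of criterion L-BFGS-B uses for its ftol stopping rule)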
+        f_delta = (
+            jnp.abs(state.value - next_state.value)
+            / jnp.asarray([jnp.abs(state.value), jnp.abs(next_state.value), 1.0]).max()
+        )
+        not_converged = (next_state.error > gtol) & (f_delta > ftol) & (i < maxiter)
         history = LBFGSHistory(
             x=next_params,
             f=next_state.value,
             g=current_grad,
             alpha=alpha_l,
             update_mask=mask_l,
+            not_converged=jnp.array(not_converged, dtype=bool),
         )
-        # check convergence
-        f_delta = (
-            jnp.abs(state.value - next_state.value)
-            / jnp.asarray([jnp.abs(state.value), jnp.abs(next_state.value), 1.0]).max()
-        )
-        not_converged = (next_state.error > gtol) & (f_delta > ftol) & (i < maxiter)
         return (OptStep(params=next_params, state=next_state), history), not_converged
 
     def non_op(carry, it):
@@ -224,6 +382,7 @@ def scan_body(tup, it):
         g=grad0,
         alpha=jnp.ones_like(x0),
         update_mask=jnp.zeros_like(x0, dtype=bool),
+        not_converged=jnp.array(True, dtype=bool),
     )
 
     ((last_step, _), _), history = lax.scan(
diff --git a/blackjax/vi/pathfinder.py b/blackjax/vi/pathfinder.py
index c1b7dc113..62981594a 100644
--- a/blackjax/vi/pathfinder.py
+++ b/blackjax/vi/pathfinder.py
@@ -138,7 +138,11 @@ def approximate(
         **lbfgs_kwargs,
     )
 
-    # Get postions and gradients of the optimization path (including the starting point).
+    # Get the index of the step at which L-BFGS converged.
+    lbfgs_converged_idx = history.not_converged.sum()
+    # Truncate the history at the point of convergence.
+    history = jax.tree.map(lambda x: x[:lbfgs_converged_idx], history)
+
     position = history.x
     grad_position = history.g
     alpha = history.alpha
@@ -172,7 +176,7 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad):
     # Index and reshape S and Z to be sliding window view shape=(maxiter,
     # maxcor, param_dim), so we can vmap over all the iterations.
     # This is in effect numpy.lib.stride_tricks.sliding_window_view
-    path_size = maxiter + 1
+    path_size = lbfgs_converged_idx
     index = jnp.arange(path_size)[:, None] + jnp.arange(maxcor)[None, :]
     s_j = s_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
     z_j = z_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
@@ -180,11 +184,6 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad):
     elbo, beta, gamma = jax.vmap(path_finder_body_fn)(
         rng_keys, s_j, z_j, alpha, position, grad_position
     )
-    elbo = jnp.where(
-        (jnp.arange(path_size) < (status.iter_num)) & jnp.isfinite(elbo),
-        elbo,
-        -jnp.inf,
-    )
 
     unravel_fn_mapped = jax.vmap(unravel_fn)
     pathfinder_result = PathfinderState(
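
Note: for reference, a minimal sketch of how the new optax_lbfgs entry point
could be exercised once this patch is applied. The quadratic objective and the
parameter values below are illustrative assumptions, not part of the patch:

    import jax.numpy as jnp
    from blackjax.optimizers.lbfgs import optax_lbfgs

    def fun(x):
        # Any differentiable scalar-valued objective works here.
        return jnp.sum((x - 2.0) ** 2)

    x0 = jnp.zeros(3)
    (final_params, final_state), history = optax_lbfgs(
        fun, x0, maxiter=30, maxcor=6, gtol=1e-08, ftol=1e-05, maxls=1000
    )

    # The history fields carry a leading axis of length maxiter + 1: the
    # initial state plus one slot per scan step. Steps after convergence are
    # no-ops, so the number of genuine L-BFGS steps can be recovered the same
    # way pathfinder.py does it:
    n_steps = history.not_converged.sum()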