Added a function to utilise optax for lbfgs optimisation.
Updated lbfgs to truncate history to reduce inverse hessian computation.
aphc14 committed Oct 26, 2024
1 parent b107f9f commit 17f9cb5
Showing 2 changed files with 171 additions and 13 deletions.
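As a quick orientation, the snippet below sketches how the new optax_lbfgs helper might be called; the toy quadratic objective, starting point, and hyperparameter values are illustrative placeholders, while the function name, arguments, and return structure follow the diff below.

import jax.numpy as jnp
from blackjax.optimizers.lbfgs import optax_lbfgs

def fun(x):
    # toy quadratic with its minimum at (1.0, 2.0); illustrative only
    return jnp.sum((x - jnp.array([1.0, 2.0])) ** 2)

x0 = jnp.zeros(2)
(final_params, final_state), history = optax_lbfgs(
    fun, x0, maxiter=30, maxcor=6, gtol=1e-8, ftol=1e-5, maxls=1000
)
# every field of `history` (x, f, g, alpha, update_mask, not_converged, ...) carries
# a leading dimension of maxiter + 1: the initial point plus one entry per iteration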
171 changes: 165 additions & 6 deletions blackjax/optimizers/lbfgs.py
@@ -17,6 +17,8 @@
import jax.numpy as jnp
import jax.random
import jaxopt
import optax
import optax.tree_utils as otu
from jax import lax
from jax.flatten_util import ravel_pytree
from jaxopt._src.lbfgs import LbfgsState
@@ -37,6 +39,22 @@
MIN_STEP_SIZE = 1e-3


class _OptaxLBFGSHistory(NamedTuple):
x: Array
f: Array
g: Array
alpha: Array
update_mask: Array
# store intermediate values to perform checks
not_converged: Array
s: Array
z: Array
s_l: Array
z_l: Array
last: Array
iter: Array


class LBFGSHistory(NamedTuple):
"""Container for the optimization path of a L-BFGS run
@@ -60,6 +78,7 @@ class LBFGSHistory(NamedTuple):
g: Array
alpha: Array
update_mask: Array
    not_converged: Array  # used to truncate the history, shortening inverse-Hessian computation and BFGS sampling


def minimize_lbfgs(
@@ -148,6 +167,143 @@ def minimize_lbfgs(
return last_step, history


def optax_lbfgs(
fun: Callable,
x0: Array,
maxiter: int,
    maxcor: int,
gtol: float,
ftol: float,
maxls: int,
# **lbfgs_kwargs, # TODO: insert kwargs to optax.scale_by_zoom_linesearch and optax.value_and_grad_from_state
):
linesearch = optax.scale_by_zoom_linesearch(
max_linesearch_steps=maxls,
verbose=True,
)
solver = optax.lbfgs(
memory_size=maxcor,
linesearch=linesearch,
)
value_and_grad_fun = optax.value_and_grad_from_state(fun)

def lbfgs_one_step(carry, i):
        # the chained optax solver state is a 3-element tuple: index 0 holds the
        # L-BFGS memory, index 2 the line-search value/grad/info
(params, state), previous_history = carry
value, grad = value_and_grad_fun(params, state=state)
updates, next_state = solver.update(
grad, state, params, value=value, grad=grad, value_fn=fun
)

        # cast num_linesearch_steps to int32 so its dtype stays consistent across
        # the cond branches and scan iterations
info = next_state[2].info._replace(
num_linesearch_steps=jnp.asarray(
next_state[2].info.num_linesearch_steps, dtype=jnp.int32
)
)

next_state = (next_state[0], next_state[1], next_state[2]._replace(info=info))

        # L-BFGS keeps a rolling history; recover the index of the most recent update.
        iter = state[0].count
        # `last` is the slot in the rolling buffers where the latest update was written
        last = jnp.max(jnp.array([iter - 1, 0], dtype=jnp.int32)) % maxcor
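        # Illustrative: the memory buffers roll over with size maxcor, e.g. with
        # maxcor=5 and iter=7 the most recently written slot is (7 - 1) % 5 = 1.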
next_params = optax.apply_updates(params, updates)

# Recover alpha and update mask
s_l = next_state[0].diff_params_memory[last]
z_l = next_state[0].diff_updates_memory[last]
alpha_lm1 = previous_history.alpha
alpha_l, mask_l = lbfgs_recover_alpha(alpha_lm1, s_l, z_l)

# TODO: check correct calc for g
# g = next_state[2].grad
# g = state[2].grad
# g = grad
# g = previous_history.g
# g = previous_history.g + z_l
# g = state[2].grad + z_l

not_converged = check_convergence(state, next_state, iter)
history = _OptaxLBFGSHistory(
x=next_params,
f=next_state[2].value,
g=next_state[2].grad,
alpha=alpha_l,
update_mask=mask_l,
not_converged=not_converged,
s=next_state[0].diff_params_memory,
z=next_state[0].diff_updates_memory,
s_l=s_l,
z_l=z_l,
last=jnp.asarray(last, dtype=jnp.int32),
iter=jnp.asarray(iter, dtype=jnp.int32),
)
return ((next_params, next_state), history), not_converged

def check_convergence(state, next_state, iter):
f_delta = (
jnp.abs(state[2].value - next_state[2].value)
/ jnp.asarray(
[jnp.abs(state[2].value), jnp.abs(next_state[2].value), 1.0]
).max()
)
next_state_grad = otu.tree_get(next_state[2], "grad")
error = otu.tree_l2_norm(next_state_grad)
        # always take the first step (iter == 0); afterwards continue only while
        # not yet converged
        return jnp.array(
            (iter == 0) | ((error > gtol) & (f_delta > ftol) & (iter < maxiter)),
            dtype=bool,
        )

    def non_op(carry, i):
        # identity branch taken once converged; it mirrors lbfgs_one_step's carry
        # structure and dtypes so both branches are compatible under jax.lax.cond
(params, state), previous_history = carry

info = state[2].info._replace(
num_linesearch_steps=jnp.asarray(
state[2].info.num_linesearch_steps, dtype=jnp.int32
)
)
state = (state[0], state[1], state[2]._replace(info=info))

return ((params, state), previous_history), jnp.array(False, dtype=bool)

def scan_body(tup, i):
carry, not_converged = tup
next_tup = jax.lax.cond(not_converged, lbfgs_one_step, non_op, carry, i)
        # next_tup = (((params, state), history), not_converged); emit the history
        # entry as the scan output and carry the rest forward
        return next_tup, next_tup[0][-1]

    init_state = solver.init(x0)
init_history = _OptaxLBFGSHistory(
x=init_state[0].params,
f=init_state[2].value,
g=init_state[2].grad,
alpha=jnp.ones_like(x0),
update_mask=jnp.zeros_like(x0, dtype=bool),
not_converged=jnp.array(True, dtype=bool),
s=init_state[0].diff_params_memory,
z=init_state[0].diff_updates_memory,
s_l=jnp.zeros_like(x0),
z_l=jnp.zeros_like(x0),
last=jnp.asarray(-1, dtype=jnp.int32),
iter=jnp.asarray(-1, dtype=jnp.int32),
)

# Use lax.scan to accumulate history
(((final_params, final_state), _), _), history = jax.lax.scan(
scan_body,
(((x0, init_state), init_history), True),
jnp.arange(maxiter),
length=maxiter,
)

history = jax.tree.map(
lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
init_history,
history,
)
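    # After concatenation each leaf of `history` has leading dimension maxiter + 1:
    # the initial point followed by one entry per scan iteration.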
return (final_params, final_state), history


def _minimize_lbfgs(
fun: Callable,
x0: Array,
@@ -181,19 +337,21 @@ def lbfgs_one_step(carry, i):
alpha_l, mask_l = lbfgs_recover_alpha(alpha_lm1, s_l, z_l)

current_grad = previous_history.g + z_l

# check convergence
f_delta = (
jnp.abs(state.value - next_state.value)
/ jnp.asarray([jnp.abs(state.value), jnp.abs(next_state.value), 1.0]).max()
)
not_converged = (next_state.error > gtol) & (f_delta > ftol) & (i < maxiter)
history = LBFGSHistory(
x=next_params,
f=next_state.value,
g=current_grad,
alpha=alpha_l,
update_mask=mask_l,
not_converged=jnp.array(not_converged, dtype=bool),
)
return (OptStep(params=next_params, state=next_state), history), not_converged

def non_op(carry, it):
@@ -224,6 +382,7 @@ def scan_body(tup, it):
g=grad0,
alpha=jnp.ones_like(x0),
update_mask=jnp.zeros_like(x0, dtype=bool),
not_converged=jnp.array(True, dtype=bool),
)

((last_step, _), _), history = lax.scan(
13 changes: 6 additions & 7 deletions blackjax/vi/pathfinder.py
@@ -138,7 +138,11 @@ def approximate(
**lbfgs_kwargs,
)

    # Get positions and gradients of the optimization path (including the starting point).
# get the index where lbfgs converged
lbfgs_converged_idx = history.not_converged.sum()
# truncate history to the point of convergence
history = jax.tree.map(lambda x: x[:lbfgs_converged_idx], history)
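    # Illustrative: if history.not_converged were [True, True, True, False, False],
    # the sum is 3 and only the first three points of the optimization path are kept.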

position = history.x
grad_position = history.g
alpha = history.alpha
Expand Down Expand Up @@ -172,19 +176,14 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad):
    # Index and reshape S and Z into a sliding-window view of shape (path_size,
    # maxcor, param_dim), so we can vmap over all the iterations.
    # This is in effect numpy.lib.stride_tricks.sliding_window_view
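    # Illustrative: with maxcor=3 and path_size=4, the `index` matrix below is
    #   [[0, 1, 2],
    #    [1, 2, 3],
    #    [2, 3, 4],
    #    [3, 4, 5]]
    # i.e. each row gathers maxcor consecutive entries of the padded s/z arrays,
    # which is the sliding-window view described above.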
path_size = maxiter + 1
path_size = lbfgs_converged_idx
index = jnp.arange(path_size)[:, None] + jnp.arange(maxcor)[None, :]
s_j = s_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
z_j = z_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
rng_keys = jax.random.split(rng_key, path_size)
elbo, beta, gamma = jax.vmap(path_finder_body_fn)(
rng_keys, s_j, z_j, alpha, position, grad_position
)
elbo = jnp.where(
(jnp.arange(path_size) < (status.iter_num)) & jnp.isfinite(elbo),
elbo,
-jnp.inf,
)

unravel_fn_mapped = jax.vmap(unravel_fn)
pathfinder_result = PathfinderState(
