Skip to content

Commit

Permalink
Add TestOptaxLBFGS class to test optax_lbfgs function
Browse files Browse the repository at this point in the history
- Introduced a new test class  to verify the functionality and convergence of the  function.
- Implemented a test method  that checks the consistency of the history and convergence of the optimizer.
  • Loading branch information
aphc14 committed Oct 26, 2024
1 parent 17f9cb5 commit 3f3eb75
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion tests/optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test optimizers."""

import functools

import chex
Expand All @@ -17,6 +18,7 @@
lbfgs_inverse_hessian_formula_2,
lbfgs_recover_alpha,
minimize_lbfgs,
optax_lbfgs,
)


Expand Down Expand Up @@ -154,5 +156,59 @@ def loss_fn(x):
np.testing.assert_allclose(inv_hess_1, inv_hess_2, rtol=0.01)


class TestOptaxLBFGS(chex.TestCase):
def test_optax_lbfgs(
self,
maxcor=6,
maxiter=1000,
ftol=1e-5,
gtol=1e-8,
maxls=1000,
):
"""Test the optax_lbfgs function for consistency in history and convergence."""

def example_fun(w):
return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)

x0_example = jnp.zeros((8,))

(final_params, final_state), history = optax_lbfgs(
example_fun,
x0_example,
maxcor=maxcor,
maxiter=maxiter,
ftol=ftol,
gtol=gtol,
maxls=maxls,
)

# test that the history is correct
L = history.iter.shape[0]

for l in range(1, L):
last = history.last[l]
current_s = history.s[l]
sml = jnp.delete(current_s, last, axis=0)

previous_s = history.s[l - 1]
previous_sml = jnp.delete(previous_s, last, axis=0)

np.testing.assert_allclose(
previous_sml,
sml,
err_msg=f"l = {l}, last = {last}, previous_sml = {previous_sml}, sml = {sml}",
)

# additional checks for convergence
expected_solution = jnp.ones((8,))
np.testing.assert_allclose(
final_params,
expected_solution,
rtol=1e-2,
err_msg="Final parameters did not converge to expected solution.",
)


if __name__ == "__main__":
absltest.main()
# absltest.main()
TestOptaxLBFGS().test_optax_lbfgs()

0 comments on commit 3f3eb75

Please sign in to comment.