diff --git a/tests/optimizers/test_optimizers.py b/tests/optimizers/test_optimizers.py
index a7549842f..8c8f27b03 100644
--- a/tests/optimizers/test_optimizers.py
+++ b/tests/optimizers/test_optimizers.py
@@ -1,4 +1,5 @@
 """Test optimizers."""
+
 import functools
 
 import chex
@@ -17,6 +18,7 @@
     lbfgs_inverse_hessian_formula_2,
     lbfgs_recover_alpha,
     minimize_lbfgs,
+    optax_lbfgs,
 )
 
 
@@ -154,5 +156,61 @@ def loss_fn(x):
     np.testing.assert_allclose(inv_hess_1, inv_hess_2, rtol=0.01)
 
 
+class TestOptaxLBFGS(chex.TestCase):
+    def test_optax_lbfgs(
+        self,
+        maxcor=6,
+        maxiter=1000,
+        ftol=1e-5,
+        gtol=1e-8,
+        maxls=1000,
+    ):
+        """Test optax_lbfgs for history consistency and convergence."""
+
+        def example_fun(w):
+            # 8-dimensional Rosenbrock function, with minimum at w = (1, ..., 1).
+            return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)
+
+        x0_example = jnp.zeros((8,))
+
+        (final_params, final_state), history = optax_lbfgs(
+            example_fun,
+            x0_example,
+            maxcor=maxcor,
+            maxiter=maxiter,
+            ftol=ftol,
+            gtol=gtol,
+            maxls=maxls,
+        )
+
+        # Check that the history is consistent across iterations: deleting the
+        # slot overwritten at iteration i must leave the rest of the circular
+        # buffer identical to the previous iteration's.
+        n_iter = history.iter.shape[0]
+
+        for i in range(1, n_iter):
+            last = history.last[i]
+            current_s = history.s[i]
+            sml = jnp.delete(current_s, last, axis=0)
+
+            previous_s = history.s[i - 1]
+            previous_sml = jnp.delete(previous_s, last, axis=0)
+
+            np.testing.assert_allclose(
+                previous_sml,
+                sml,
+                err_msg=f"i = {i}, last = {last}, previous_sml = {previous_sml}, sml = {sml}",
+            )
+
+        # Check convergence to the known minimum of the Rosenbrock function.
+        expected_solution = jnp.ones((8,))
+        np.testing.assert_allclose(
+            final_params,
+            expected_solution,
+            rtol=1e-2,
+            err_msg="Final parameters did not converge to expected solution.",
+        )
+
+
 if __name__ == "__main__":
     absltest.main()
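
For context on what the new test exercises: `optax_lbfgs` is introduced by this PR, so its internals aren't shown here; its signature and the `history.s` / `history.last` / `history.iter` fields are taken from the diff above. Below is a minimal sketch, assuming the wrapper drives optax's stock L-BFGS solver, of the bare optimization loop on the same Rosenbrock objective. Everything in it is standard optax API rather than this PR's code, and `maxcor` is assumed to correspond to optax's `memory_size`.

```python
import jax.numpy as jnp
import optax


def rosenbrock(w):
    # Same objective as the test; minimum at w = (1, ..., 1).
    return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)


# memory_size plays the role of maxcor: the number of (s, y) pairs retained.
solver = optax.lbfgs(memory_size=6)
params = jnp.zeros((8,))
opt_state = solver.init(params)

# Reuses the value/gradient already computed by the linesearch, when available.
value_and_grad = optax.value_and_grad_from_state(rosenbrock)

for _ in range(100):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=rosenbrock
    )
    params = optax.apply_updates(params, updates)

# params should approach the minimum at (1, ..., 1).
```

The `value`, `grad`, and `value_fn` keyword arguments let optax's zoom linesearch cache and reuse function evaluations between steps, which is why the loop threads them through `solver.update` instead of calling `jax.value_and_grad` directly; a wrapper like `optax_lbfgs` would additionally record the `(s, y)` memory each iteration to produce the `history` object the test checks.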