Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10"]
python-version:
- 3.11
- 3.12
- 3.13
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand Down
14 changes: 4 additions & 10 deletions optimism/NewtonSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from scipy.sparse.linalg import LinearOperator, gmres

from optimism.JaxConfig import *

from optimism.ScipyInterface import make_scipy_linear_function

Settings = namedtuple('Settings', ['relative_gmres_tol',
'max_gmres_iters'])
Expand All @@ -27,13 +27,9 @@ def compute_min_p(ps, bounds):
return min(max(quadMin, bounds[0]), bounds[1])


def newton_step(residual, linear_op, x, settings=Settings(1e-2,100), precond=None):
def newton_step(residual, residual_jvp, x, settings=Settings(1e-2,100), precond=None):
sz = x.size
# The call to onp.array copies the jax array output into a plain numpy
# array. The copy is necessary for safety, since as far as scipy knows,
# it is allowed to modify the output in place.
A = LinearOperator((sz,sz),
lambda v: onp.array(linear_op(v)))
A = LinearOperator((sz, sz), make_scipy_linear_function(residual_jvp))
r = onp.array(residual(x))

numIters = 0
Expand All @@ -45,9 +41,7 @@ def callback(xk):
maxIters = settings.max_gmres_iters

if precond is not None:
# Another copy to a plain numpy array, see comment for A above.
M = LinearOperator((sz,sz),
lambda v: onp.array(precond(v)))
M = LinearOperator((sz, sz), make_scipy_linear_function(precond))
else:
M = None

Expand Down
8 changes: 6 additions & 2 deletions optimism/ScalarRootFind.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,12 @@ def loop_body(carry):

x = np.where(converged, x, np.nan)

return x, SolutionInfo(converged=converged, function_calls=functionCalls,
iterations=iters, residual_norm=np.abs(F), correction_norm=np.abs(dx))
# BT 10/14/2025 As of Jax 0.4.34, the has_aux argument of custom_root is broken
# and cannot handle non-differentiable outputs.
# See https://github.com/jax-ml/jax/issues/24295
# return x, SolutionInfo(converged=converged, function_calls=functionCalls,
# iterations=iters, residual_norm=np.abs(F), correction_norm=np.abs(dx))
return x, None


def bisection_step(x, xl, xh, df, f):
Expand Down
16 changes: 16 additions & 0 deletions optimism/ScipyInterface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import jax.numpy as np
import numpy as onp

def make_scipy_linear_function(linear_function):
"""Transform a linear function of a jax array to one that can be used with scipy.linalg.LinearOperator."""
def linear_op(v):
# The v is going into a jax function (probably a jvp).
# Sometimes scipy passes in an array of dtype int, which breaks
# jax tracing and differentiation, so explicitly set type to
# something jax can handle.
jax_v = np.array(v, dtype=np.float64)
jax_Av = linear_function(jax_v)
# The result is going back into a scipy solver, so convert back
# to a standard numpy array.
return onp.array(jax_Av)
return linear_op
5 changes: 3 additions & 2 deletions optimism/inverse/test/test_Hyperelastic_gradient_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from optimism import Objective
from optimism import Mesh
from optimism.material import Neohookean
from optimism.ScipyInterface import make_scipy_linear_function

from .FiniteDifferenceFixture import FiniteDifferenceFixture

Expand Down Expand Up @@ -166,8 +167,8 @@ def total_work_gradient_with_adjoint(self, storedState, parameters):
n = self.dofManager.get_unknown_size()
self.objective.p = p # have to update parameters to get precond to work
self.objective.update_precond(Uu) # update preconditioner for use in cg (will converge in 1 iteration as long as the preconditioner is not approximate)
dRdu = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.hessian_vec(Uu, V)))
dRdu_decomp = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.apply_precond(V)))
dRdu = linalg.LinearOperator((n, n), make_scipy_linear_function(lambda V: self.objective.hessian_vec(Uu, V)))
dRdu_decomp = linalg.LinearOperator((n, n), make_scipy_linear_function(self.objective.apply_precond))
adjointVector = linalg.cg(dRdu, onp.array(adjointLoad, copy=False), rtol=1e-10, atol=0.0, M=dRdu_decomp)[0]

gradient += df_dx
Expand Down
9 changes: 5 additions & 4 deletions optimism/inverse/test/test_J2Plastic_gradient_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from optimism import Objective
from optimism import Mesh
from optimism.material import J2Plastic as J2
from optimism.ScipyInterface import make_scipy_linear_function

from .FiniteDifferenceFixture import FiniteDifferenceFixture

Expand Down Expand Up @@ -174,8 +175,8 @@ def total_work_gradient(self, storedState, parameters):

self.objective.p = p_objective
self.objective.update_precond(Uu) # update preconditioner for use in cg (will converge in 1 iteration as long as the preconditioner is not approximate)
dRdu = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.hessian_vec(Uu, V)))
dRdu_decomp = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.apply_precond(V)))
dRdu = linalg.LinearOperator((n, n), make_scipy_linear_function(lambda V: self.objective.hessian_vec(Uu, V)))
dRdu_decomp = linalg.LinearOperator((n, n), make_scipy_linear_function(self.objective.apply_precond))
adjointVector = linalg.cg(dRdu, onp.array(adjointLoad, copy=False), rtol=1e-10, atol=0.0, M=dRdu_decomp)[0]

gradient += residualInverseFuncs.residual_jac_coords_vjp(Uu, p, ivs_prev, parameters, adjointVector)
Expand Down Expand Up @@ -255,8 +256,8 @@ def target_curve_gradient(self, storedState, parameters):
p_objective = Objective.Params(bc_data=p.bc_data, state_data=p_prev.state_data, prop_data=self.props) # remember R is a function of ivs_prev
self.objective.p = p_objective
self.objective.update_precond(Uu) # update preconditioner for use in cg (will converge in 1 iteration as long as the preconditioner is not approximate)
dRdu = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.hessian_vec(Uu, V)))
dRdu_decomp = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.apply_precond(V)))
dRdu = linalg.LinearOperator((n, n), make_scipy_linear_function(lambda V: self.objective.hessian_vec(Uu, V)))
dRdu_decomp = linalg.LinearOperator((n, n), make_scipy_linear_function(self.objective.apply_precond))
adjointVector = linalg.cg(dRdu, onp.array(adjointLoad, copy=False), rtol=1e-10, atol=0.0, M=dRdu_decomp)[0]

gradient += residualInverseFuncs.residual_jac_coords_vjp(Uu, p, ivs_prev, parameters, adjointVector)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from optimism import Objective
from optimism import Mesh
from optimism.material import J2Plastic as J2
from optimism.ScipyInterface import make_scipy_linear_function

from .FiniteDifferenceFixture import FiniteDifferenceFixture

Expand All @@ -24,6 +25,7 @@
['energy_function_coords',
'compute_dissipation'])


class J2GlobalMeshAdjointSolveFixture(FiniteDifferenceFixture):
def setUp(self):
dispGrad0 = np.array([[0.4, -0.2],
Expand Down Expand Up @@ -172,8 +174,8 @@ def dissipated_energy_gradient(self, storedState, parameters):
p_objective = Objective.Params(bc_data=p.bc_data, state_data=ivs_prev, prop_data=self.props) # remember R is a function of ivs_prev
self.objective.p = p_objective
self.objective.update_precond(Uu) # update preconditioner for use in cg (will converge in 1 iteration as long as the preconditioner is not approximate)
dRdu = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.hessian_vec(Uu, V)))
dRdu_decomp = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.apply_precond(V)))
dRdu = linalg.LinearOperator((n, n), make_scipy_linear_function(lambda V: self.objective.hessian_vec(Uu, V)))
dRdu_decomp = linalg.LinearOperator((n, n), make_scipy_linear_function(self.objective.apply_precond))
adjointVector = linalg.cg(dRdu, onp.array(adjointLoad, copy=False), rtol=1e-10, atol=0.0, M=dRdu_decomp)[0]

gradient += residualInverseFuncs.residual_jac_coords_vjp(Uu, p, ivs_prev, parameters, adjointVector)
Expand Down
2 changes: 1 addition & 1 deletion optimism/test/test_LinAlg.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_sqrtm_on_10x10(self):
C = F.T@F
sqrtC = LinAlg.sqrtm(C)
shouldBeC = np.dot(sqrtC,sqrtC)
self.assertArrayNear(shouldBeC, C, 11)
self.assertArrayNear(shouldBeC, C, 10)


def test_sqrtm_derivatives_on_10x10(self):
Expand Down
6 changes: 5 additions & 1 deletion optimism/test/test_Mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,12 @@ def test_conversion_to_quadratic_mesh_is_valid(self):
# plt.show()


def cross_prod_2d(v, w):
return v[0]*w[1] - w[0]*v[1]


def triangle_inradius(tcoords):
area = 0.5*onp.cross(tcoords[1]-tcoords[0], tcoords[2]-tcoords[0])
area = 0.5*cross_prod_2d(tcoords[1]-tcoords[0], tcoords[2]-tcoords[0])
peri = (onp.linalg.norm(tcoords[1]-tcoords[0])
+ onp.linalg.norm(tcoords[2]-tcoords[1])
+ onp.linalg.norm(tcoords[0]-tcoords[2]))
Expand Down
30 changes: 19 additions & 11 deletions optimism/test/test_ScalarRootFinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def f(x): return x**3 - 4.0
class ScalarRootFindTestFixture(TestFixture.TestFixture):

def setUp(self):
self.settings = ScalarRootFind.get_settings()
self.settings = ScalarRootFind.get_settings(r_tol=1e-12, x_tol=0)

self.rootGuess = 1e-5
self.rootExpected = np.cbrt(4.0)
Expand All @@ -33,16 +33,19 @@ def setUp(self):

def test_find_root(self):
rootBracket = np.array([float_info.epsilon, 100.0])
root, status = ScalarRootFind.find_root(f, self.rootGuess, rootBracket, self.settings)
self.assertTrue(status.converged)
root, _ = ScalarRootFind.find_root(f, self.rootGuess, rootBracket, self.settings)
#self.assertTrue(status.converged)
converged = np.abs(f(root)) <= self.settings.r_tol
self.assertTrue(converged)
self.assertNear(root, self.rootExpected, 13)


def test_find_root_with_jit(self):
rtsafe_jit = jax.jit(ScalarRootFind.find_root, static_argnums=(0,3))
rootBracket = np.array([float_info.epsilon, 100.0])
root, status = rtsafe_jit(f, self.rootGuess, rootBracket, self.settings)
self.assertTrue(status.converged)
root, _ = rtsafe_jit(f, self.rootGuess, rootBracket, self.settings)
converged = np.abs(f(root)) <= self.settings.r_tol
self.assertTrue(converged)
self.assertNear(root, self.rootExpected, 13)


Expand All @@ -56,8 +59,9 @@ def test_find_root_converges_on_hard_function(self):
g = lambda x: np.sin(x) + x
rootBracket = np.array([-3.0, 20.0])
x0 = 19.0
root, status = ScalarRootFind.find_root(f, x0, rootBracket, self.settings)
self.assertTrue(status.converged)
root, _ = ScalarRootFind.find_root(f, x0, rootBracket, self.settings)
converged = np.abs(f(root)) <= self.settings.r_tol
self.assertTrue(converged)
self.assertNear(root, self.rootExpected, 13)


Expand Down Expand Up @@ -106,16 +110,20 @@ def my_sqrt(a):
def test_solves_when_left_bracket_is_solution(self):
rootBracket = np.array([0.0, 1.0])
guess = 3.0
root, status = ScalarRootFind.find_root(lambda x: x*(x**2 - 10.0), guess, rootBracket, self.settings)
self.assertTrue(status.converged)
f = lambda x: x*(x**2 - 10.0)
root, _ = ScalarRootFind.find_root(f, guess, rootBracket, self.settings)
converged = np.abs(f(root)) <= self.settings.r_tol
self.assertTrue(converged)
self.assertNear(root, 0.0, 12)


def test_solves_when_right_bracket_is_solution(self):
rootBracket = np.array([-1.0, 0.0])
guess = 3.0
root, status = ScalarRootFind.find_root(lambda x: x*(x**2 - 10.0), guess, rootBracket, self.settings)
self.assertTrue(status.converged)
f = lambda x: x*(x**2 - 10.0)
root, _ = ScalarRootFind.find_root(f, guess, rootBracket, self.settings)
converged = np.abs(f(root)) <= self.settings.r_tol
self.assertTrue(converged)
self.assertNear(root, 0.0, 12)

if __name__ == '__main__':
Expand Down
7 changes: 3 additions & 4 deletions optimism/test/test_TensorMath.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def lam(A):
class TensorMathFixture(TestFixture):

def setUp(self):
key = jax.random.PRNGKey(1)
key = jax.random.PRNGKey(0)
self.R = jax.random.orthogonal(key, 3)
self.assertGreater(np.linalg.det(self.R), 0) # make sure this is a rotation and not a reflection
self.log_squared = lambda A: np.tensordot(TensorMath.log_sqrt(A), TensorMath.log_sqrt(A))
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_exp_symm_gradient_distinct_eigenvalues(self):

def test_sqrt_symm_gradient_almost_double_degenerate(self):
C = [email protected](np.array([2.1, 2.1 + 1e-8, 3.0]))@self.R.T
check_grads(TensorMath.exp_symm, (C,), order=1, eps=1e-10)
check_grads(TensorMath.exp_symm, (C,), order=1, eps=1e-8, rtol=5e-5)

# pow_symm tests

Expand Down Expand Up @@ -219,11 +219,10 @@ def test_pow_symm_gradient_distinct_eigenvalues(self):
m = 0.25
check_grads(lambda A: TensorMath.pow_symm(C, m), (C,), order=1)

@unittest.expectedFailure
def test_pow_symm_gradient_almost_double_degenerate(self):
C = [email protected](np.array([2.1, 2.1 + 1e-8, 3.0]))@self.R.T
m = 0.25
check_grads(lambda A: TensorMath.pow_symm(A, 0.25), (C,), order=1, atol=1e-16, eps=1e-10)
check_grads(lambda A: TensorMath.pow_symm(A, 0.25), (C,), order=1, atol=1e-16, eps=1e-6)


def test_determinant(self):
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
author="Michael Tupek and Brandon Talamini",
author_email='[email protected]', # todo: make an email list
install_requires=['equinox',
'jax[cpu]==0.4.28',
'jax[cpu]',
'jaxtyping',
'matplotlib', # this is not strictly necessary
'netcdf4',
'scipy<1.15.0'],
'scipy'],
#tests_require=[], # could put chex and pytest here
extras_require={'sparse': ['scikit-sparse'],
'test': ['pytest', 'pytest-cov', 'pytest-xdist'],
'docs': ['sphinx', 'sphinx-copybutton', 'sphinx-rtd-theme', 'sphinxcontrib-bibtex', 'sphinxcontrib-napoleon']},
python_requires='>=3.7',
version='0.0.1',
python_requires='>=3.11',
version='0.0.2',
license='MIT',
url='https://github.com/sandialabs/optimism'
)