diff --git a/examples/forward_problems/mechanics/hyperelasticity/example_incompressible_2d.py b/examples/forward_problems/mechanics/hyperelasticity/example_incompressible_2d.py index 131d289..539072d 100644 --- a/examples/forward_problems/mechanics/hyperelasticity/example_incompressible_2d.py +++ b/examples/forward_problems/mechanics/hyperelasticity/example_incompressible_2d.py @@ -64,7 +64,7 @@ opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, clip_gradients=False) opt, opt_st = opt.init(params) -for epoch in range(25000): +for epoch in range(2500): params, opt_st, loss = opt.step(params, opt_st, problem) if epoch % 100 == 0: print(epoch, flush=True) @@ -79,8 +79,9 @@ 'internal_force' ], element_variables=[ - # 'deformation_gradient', - # 'I1_bar' + 'deformation_gradient', + 'I1_bar', + 'pk1_stress' ] ) pp.write_outputs(params, problem) diff --git a/pancax/physics_kernels/base.py b/pancax/physics_kernels/base.py index 8aa03e3..ed03970 100644 --- a/pancax/physics_kernels/base.py +++ b/pancax/physics_kernels/base.py @@ -29,7 +29,9 @@ def element_pp( is_state_method=False, jit=True ): - def constitutive_method(func, params, domain, t, us, state_old, dt, *args): + def constitutive_method( + func, params, domain, t, us, theta, state_old, dt, *args + ): coords, conns, fspace = domain.coords, domain.conns, domain.fspace us = us[conns, :] xs = coords[conns, :] @@ -51,8 +53,17 @@ def _vmap_func(x, u): # vmap(func, in_axes=in_axes_2), in_axes=in_axes_1)( # params, xs, t, us, grad_us, state_old, dt # ) - grad_us = vmap(vmap(physics.formulation.modify_field_gradient))( - grad_us + if hasattr(physics, "constitutive_model"): + constitutive_model = physics.constitutive_model + else: + constitutive_model = None + + grad_us = vmap(vmap( + physics.formulation.modify_field_gradient, + in_axes=(None, 0, None, 0, None)), + in_axes=(None, 0, None, 0, None) + )( + constitutive_model, grad_us, theta, state_old, dt ) theta = 60. vals, _ = vmap( @@ -60,7 +71,9 @@ def _vmap_func(x, u): )(grad_us, theta, state_old, dt) return vals - def kinematic_method(func, params, domain, t, us, state_old, dt, *args): + def kinematic_method( + func, params, domain, t, us, theta, state_old, dt, *args + ): coords, conns, fspace = domain.coords, domain.conns, domain.fspace us = us[conns, :] xs = coords[conns, :] @@ -75,14 +88,24 @@ def _vmap_func(x, u): return xs, us, grad_us, JxWs xs, us, grad_us, JxWs = vmap(_vmap_func, in_axes=(0, 0))(xs, us) - grad_us = vmap(vmap(physics.formulation.modify_field_gradient))( - grad_us + + if hasattr(physics, "constitutive_model"): + constitutive_model = physics.constitutive_model + else: + constitutive_model = None + + grad_us = vmap(vmap( + physics.formulation.modify_field_gradient, + in_axes=(None, 0, None, 0, None)), + in_axes=(None, 0, None, 0, None) + )( + constitutive_model, grad_us, theta, state_old, dt ) vals = vmap(vmap(func))(grad_us) return vals - def state_method(func, params, domain, t, us, state_old, dt, *args): + def state_method(func, params, domain, t, us, theta, state_old, dt, *args): coords, conns, fspace = domain.coords, domain.conns, domain.fspace us = us[conns, :] xs = coords[conns, :] @@ -107,24 +130,24 @@ def _vmap_func(x, u): return state_news if is_constitutive_method: - def new_func(p, d, t, u, s, dt, *args): + def new_func(p, d, t, u, theta, s, dt, *args): return constitutive_method( # physics.constitutive_model.deformation_gradient, func, - p, d, t, u, s, dt, *args + p, d, t, u, theta, s, dt, *args ) elif is_kinematic_method: - def new_func(p, d, t, u, s, dt, *args): + def new_func(p, d, t, u, theta, s, dt, *args): return kinematic_method( # physics.constitutive_model.deformation_gradient, func, - p, d, t, u, s, dt, *args + p, d, t, u, theta, s, dt, *args ) elif is_state_method: - def new_func(p, d, t, u, s, dt, *args): + def new_func(p, d, t, u, theta, s, dt, *args): return state_method( physics.energy, - p, d, t, u, s, dt, *args + p, d, t, u, theta, s, dt, *args ) else: assert False, 'Only kinematic methods are currently supported' diff --git a/pancax/physics_kernels/solid_mechanics.py b/pancax/physics_kernels/solid_mechanics.py index bc1a8e4..e863baa 100644 --- a/pancax/physics_kernels/solid_mechanics.py +++ b/pancax/physics_kernels/solid_mechanics.py @@ -10,7 +10,9 @@ class BaseMechanicsFormulation(eqx.Module): n_dimensions: int = eqx.field(static=True) # does this need to be static? @abstractmethod - def modify_field_gradient(self, grad_u): + def modify_field_gradient( + self, constitutive_model, grad_u, theta, state_old, dt + ): pass @@ -31,7 +33,9 @@ def deformation_gradient(self, grad_u): F = F.at[2, 2].set(1.0 / jnp.linalg.det(grad_u + jnp.eye(2))) return F - def modify_field_gradient(self, grad_u): + def modify_field_gradient( + self, constitutive_model, grad_u, theta, state_old, dt + ): F = self.deformation_gradient(grad_u) return F - jnp.eye(3) @@ -42,14 +46,18 @@ class PlaneStrain(BaseMechanicsFormulation): def extract_stress(self, P): return P[0:2, 0:2] - def modify_field_gradient(self, grad_u): + def modify_field_gradient( + self, constitutive_model, grad_u, theta, state_old, dt + ): return tensor_2D_to_3D(grad_u) class ThreeDimensional(BaseMechanicsFormulation): n_dimensions: int = 3 - def modify_field_gradient(self, grad_u): + def modify_field_gradient( + self, constitutive_model, grad_u, theta, state_old, dt + ): return grad_u @@ -71,7 +79,9 @@ def __init__(self, constitutive_model, formulation) -> None: def energy(self, params, x, t, u, grad_u, state_old, dt, *args): theta = 60. - grad_u = self.formulation.modify_field_gradient(grad_u) + grad_u = self.formulation.modify_field_gradient( + self.constitutive_model, grad_u, theta, state_old, dt + ) return self.constitutive_model.energy(grad_u, theta, state_old, dt) @property diff --git a/pancax/post_processor.py b/pancax/post_processor.py index 4a69f34..916bed3 100644 --- a/pancax/post_processor.py +++ b/pancax/post_processor.py @@ -302,9 +302,12 @@ def _write_step_outputs( elem_var_num = 0 for var in self.element_variables: output = physics.var_name_to_method[var] + # TODO total hack for now. Need to bind temperature\ + # somewhere somehow + theta = 60. pred = onp.array( output["method"]( - params, problem, time, us, state_old, dt + params, problem, time, us, theta, state_old, dt ) )