diff --git a/examples/forward_problems/mechanics/hyperelasticity/example_incompressible_2d.py b/examples/forward_problems/mechanics/hyperelasticity/example_incompressible_2d.py index 4b70fa0..b36f8b0 100644 --- a/examples/forward_problems/mechanics/hyperelasticity/example_incompressible_2d.py +++ b/examples/forward_problems/mechanics/hyperelasticity/example_incompressible_2d.py @@ -29,8 +29,8 @@ bulk_modulus=1000.0, shear_modulus=1., ) -# physics = SolidMechanics(model, PlaneStrain()) -physics = SolidMechanics(model, PlaneStress()) +physics = SolidMechanics(model, PlaneStrain()) +# physics = SolidMechanics(model, PlaneStress()) ics = [ ] dirichlet_bcs = [ diff --git a/examples/forward_problems/mechanics/example_hyper_visco_2d.py b/examples/forward_problems/mechanics/hyperviscoelasticity/example_hyper_visco_2d.py similarity index 92% rename from examples/forward_problems/mechanics/example_hyper_visco_2d.py rename to examples/forward_problems/mechanics/hyperviscoelasticity/example_hyper_visco_2d.py index f8f5ffb..a8a2514 100644 --- a/examples/forward_problems/mechanics/example_hyper_visco_2d.py +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/example_hyper_visco_2d.py @@ -63,7 +63,7 @@ def dirichlet_bc_func(xs, t, nn): WLF(C1=17.44, C2=51.6, theta_ref=60.0), ) physics = SolidMechanics(model, PlaneStrain()) -physics = physics.update_dirichlet_bc_func(dirichlet_bc_func) +# physics = physics.update_dirichlet_bc_func(dirichlet_bc_func) ics = [] dirichlet_bcs = [ @@ -84,10 +84,14 @@ def dirichlet_bc_func(xs, t, nn): ################## # ML setup ################## -loss_function = PathDependentEnergyLoss() -# loss_function = EnergyLoss() +# loss_function = PathDependentEnergyLoss() +loss_function = EnergyLoss(is_path_dependent=True) -params = Parameters(problem, key, seperate_networks=True) +params = Parameters( + problem, key, + dirichlet_bc_func=dirichlet_bc_func, + seperate_networks=False +) print(params) ################## diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/2holes.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/2holes.g new file mode 120000 index 0000000..769d3d3 --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/2holes.g @@ -0,0 +1 @@ +../../mesh/2holes.g \ No newline at end of file diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_10x.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_10x.g new file mode 120000 index 0000000..713266e --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_10x.g @@ -0,0 +1 @@ +../../mesh/mesh_10x.g \ No newline at end of file diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_1x.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_1x.g new file mode 120000 index 0000000..029f646 --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_1x.g @@ -0,0 +1 @@ +../../mesh/mesh_1x.g \ No newline at end of file diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_hex8.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_hex8.g new file mode 120000 index 0000000..8a48d55 --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_hex8.g @@ -0,0 +1 @@ +../../mesh/mesh_hex8.g \ No newline at end of file diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_hex8_coarse.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_hex8_coarse.g new file mode 120000 index 0000000..acd50aa --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_hex8_coarse.g @@ -0,0 +1 @@ +../../mesh/mesh_hex8_coarse.g \ No newline at end of file diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_no_ssets.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_no_ssets.g new file mode 120000 index 0000000..c25e791 --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_no_ssets.g @@ -0,0 +1 @@ +../../mesh/mesh_no_ssets.g \ No newline at end of file diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_no_ssets_10x.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_no_ssets_10x.g new file mode 120000 index 0000000..7cad53e --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_no_ssets_10x.g @@ -0,0 +1 @@ +../../mesh/mesh_no_ssets_10x.g \ No newline at end of file diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_quad4.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_quad4.g new file mode 120000 index 0000000..97a03eb --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_quad4.g @@ -0,0 +1 @@ +../../mesh/mesh_quad4.g \ No newline at end of file diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_quad4_coarse.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_quad4_coarse.g new file mode 120000 index 0000000..9aafea2 --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_quad4_coarse.g @@ -0,0 +1 @@ +../../mesh/mesh_quad4_coarse.g \ No newline at end of file diff --git a/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_quad9.g b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_quad9.g new file mode 120000 index 0000000..74df382 --- /dev/null +++ b/examples/forward_problems/mechanics/hyperviscoelasticity/mesh/mesh_quad9.g @@ -0,0 +1 @@ +../../mesh/mesh_quad9.g \ No newline at end of file diff --git a/examples/inverse_problems/mechanics/model_free/script.py b/examples/inverse_problems/mechanics/model_free/script.py new file mode 100644 index 0000000..d77cc60 --- /dev/null +++ b/examples/inverse_problems/mechanics/model_free/script.py @@ -0,0 +1,128 @@ +from pancax import * + +################## +# for reproducibility +################## +key = random.PRNGKey(10) +# key = random.split(key, 8) + +################## +# file management +################## +full_field_data_file = find_data_file('full_field_data.csv') +global_data_file = find_data_file('global_data.csv') +# mesh_file = find_mesh_file('mesh.g') +# mesh_file = 'data/2holes.g' +mesh_file = os.path.join(Path(__file__).parent, "data", "2holes.g") + +history = HistoryWriter('history.csv', log_every=250, write_every=250) +pp = PostProcessor(mesh_file) + +################## +# data setup +################## +field_data = FullFieldData(full_field_data_file, ['x', 'y', 'z', 't'], ['displ_x', 'displ_y', 'displ_z']) +# the 4 below is for the nodeset id +global_data = GlobalData( + global_data_file, 'times', 'disps', 'forces', + mesh_file, 5, 'y', # these inputs specify where to measure reactions + n_time_steps=11, # the number of time steps for inverse problems is specified here + plotting=True +) + +################## +# domain setup +################## +times = jnp.linspace(0.0, 1.0, len(global_data.outputs)) +domain = VariationalDomain(mesh_file, times) + +################## +# physics setup +################## +model = NeoHookean( + bulk_modulus=10., + shear_modulus=BoundedProperty(0.01, 5., key=key) + # shear_modulus=0.855 +) +# model = InputPolyConvexPotential( +# bulk_modulus=10., key=key +# ) +physics = SolidMechanics(model, ThreeDimensional()) + +################## +# ics/bcs +################## +ics = [ +] +dirichlet_bc_func = UniaxialTensionLinearRamp( + final_displacement=jnp.max(field_data.outputs[:, 1]), + length=1.0, direction='y', n_dimensions=3 +) +dirichlet_bcs = [ + DirichletBC('nodeset_3', 0), # left edge fixed in x + DirichletBC('nodeset_3', 1), # left edge fixed in y + DirichletBC('nodeset_3', 2), + DirichletBC('nodeset_5', 0), # right edge prescribed in x + DirichletBC('nodeset_5', 1), # right edge fixed in y + DirichletBC('nodeset_5', 2) +] +neumann_bcs = [ +] + +################## +# problem setup +################## +problem = InverseProblem(domain, physics, field_data, global_data, ics, dirichlet_bcs, neumann_bcs) + +# print(problem) + +################## +# ML setup +################## +params = Parameters( + problem, key, + dirichlet_bc_func=dirichlet_bc_func + # seperate_networks=True, + # network_type=ResNet +) +print(params) +physics_and_global_loss = EnergyResidualAndReactionLoss( + residual_weight=50.e9, reaction_weight=250.e9 +) +full_field_data_loss = FullFieldDataLoss(weight=500.e9) + +def loss_function(params, problem, inputs, outputs): + loss_1, aux_1 = physics_and_global_loss(params, problem) + loss_2, aux_2 = full_field_data_loss(params, problem, inputs, outputs) + aux_1.update(aux_2) + return loss_1 + loss_2, aux_1 + +loss_function = UserDefinedLossFunction(loss_function) +# loss_function = EnergyLoss() + +opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, transition_steps=10000) + +################## +# Training +################## +opt, opt_st = opt.init(params) + +dataloader = FullFieldDataLoader(problem.field_data) +for epoch in range(100000): + for inputs, outputs in dataloader.dataloader(512): + params, opt_st, loss = opt.step(params, opt_st, problem, inputs, outputs) + + # # params.physics.network + # new_model = params.physics.constitutive_model.parameter_enforcement() + # # physics = SolidMechanics(model, PlaneStrain()) + # new_physics = eqx.tree_at(lambda x: x.constitutive_model, params.physics, new_model) + # params = eqx.tree_at(lambda x: x.physics, params, new_physics) + + if epoch % 10 == 0: + print(epoch) + print(loss) + # print(params.physics.constitutive_model) + # print(params.physics.constitutive_model.shear_modulus.prop_min) + # print(params.physics.constitutive_model.shear_modulus.prop_max) + + print(params.physics.constitutive_model.shear_modulus()) diff --git a/examples/inverse_problems/mechanics/path-dependent/script.py b/examples/inverse_problems/mechanics/path-dependent/script.py index 0790274..6022d93 100644 --- a/examples/inverse_problems/mechanics/path-dependent/script.py +++ b/examples/inverse_problems/mechanics/path-dependent/script.py @@ -83,7 +83,7 @@ def dirichlet_bc_func(xs, t, nn): return u_out # model = NeoHookean(bulk_modulus=10., shear_modulus=0.855) -physics = physics.update_dirichlet_bc_func(dirichlet_bc_func) +# physics = physics.update_dirichlet_bc_func(dirichlet_bc_func) dirichlet_bcs = [ DirichletBC('nset_1', 0), # left edge fixed in x DirichletBC('nset_1', 1), # left edge fixed in y @@ -103,9 +103,18 @@ def dirichlet_bc_func(xs, t, nn): ################## # ML setup ################## -params = Parameters(problem, key, seperate_networks=False, network_type=MLP) -physics_and_global_loss = PathDependentEnergyResidualAndReactionLoss( - residual_weight=250.e6, reaction_weight=250.e6 +params = Parameters( + problem, key, + dirichlet_bc_func=dirichlet_bc_func, + seperate_networks=False, + network_type=MLP +) +# physics_and_global_loss = PathDependentEnergyResidualAndReactionLoss( +# residual_weight=250.e6, reaction_weight=250.e6 +# ) +physics_and_global_loss = EnergyResidualAndReactionLoss( + residual_weight=250.e6, reaction_weight=250.e6, + is_path_dependent=True ) full_field_data_loss = FullFieldDataLoss(weight=10.e6) diff --git a/pancax/__init__.py b/pancax/__init__.py index 754390b..4b7eee9 100644 --- a/pancax/__init__.py +++ b/pancax/__init__.py @@ -51,8 +51,6 @@ EnergyAndResidualLoss, \ EnergyResidualAndReactionLoss, \ ResidualMSELoss, \ - PathDependentEnergyLoss, \ - PathDependentEnergyResidualAndReactionLoss, \ UserDefinedLossFunction from .networks import \ Field, \ @@ -159,8 +157,6 @@ "EnergyAndResidualLoss", "EnergyResidualAndReactionLoss", "ResidualMSELoss", - "PathDependentEnergyLoss", - "PathDependentEnergyResidualAndReactionLoss", "UserDefinedLossFunction", # networks "Field", diff --git a/pancax/loss_functions/__init__.py b/pancax/loss_functions/__init__.py index 0462cc2..b2b3b07 100644 --- a/pancax/loss_functions/__init__.py +++ b/pancax/loss_functions/__init__.py @@ -6,9 +6,6 @@ from .utils import CombineLossFunctions, UserDefinedLossFunction from .weak_form_loss_functions import EnergyLoss, EnergyAndResidualLoss from .weak_form_loss_functions import EnergyResidualAndReactionLoss -from .weak_form_loss_functions import PathDependentEnergyLoss -from .weak_form_loss_functions import \ - PathDependentEnergyResidualAndReactionLoss from .weak_form_loss_functions import ResidualMSELoss __all__ = [ @@ -22,7 +19,5 @@ "EnergyAndResidualLoss", "EnergyResidualAndReactionLoss", "ResidualMSELoss", - "PathDependentEnergyLoss", - "PathDependentEnergyResidualAndReactionLoss", "UserDefinedLossFunction" ] diff --git a/pancax/loss_functions/base_loss_function.py b/pancax/loss_functions/base_loss_function.py index 720a7be..abd0b01 100644 --- a/pancax/loss_functions/base_loss_function.py +++ b/pancax/loss_functions/base_loss_function.py @@ -1,5 +1,9 @@ from abc import abstractmethod +from jax import vmap +from typing import Optional import equinox as eqx +import jax +import jax.numpy as jnp class BaseLossFunction(eqx.Module): @@ -35,7 +39,40 @@ class PhysicsLossFunction(BaseLossFunction): type signature ``load_step(self, params, domain, t)`` """ - @abstractmethod def load_step(self, params, domain, t): pass + + def path_dependent_loop( + self, func, start, end, *args, + use_fori_loop: Optional[bool] = False + ): + if use_fori_loop: + def fori_loop_body(n, carry): + return func(n, carry) + + return jax.lax.fori_loop( + start, end, fori_loop_body, args + ) + else: + def scan_body(carry, n): + return func(n, carry), None + + return jax.lax.scan( + scan_body, + args, + jnp.arange(start, end) + )[0] + + def state_variable_init(self, problem): + ne = problem.domain.conns.shape[0] + nq = len(problem.domain.fspace.quadrature_rule) + + def _vmap_func(n): + return problem.physics.constitutive_model.\ + initial_state() + + state_old = vmap(vmap(_vmap_func))( + jnp.zeros((ne, nq)) + ) + return state_old diff --git a/pancax/loss_functions/weak_form_loss_functions.py b/pancax/loss_functions/weak_form_loss_functions.py index a2b7403..739ed30 100644 --- a/pancax/loss_functions/weak_form_loss_functions.py +++ b/pancax/loss_functions/weak_form_loss_functions.py @@ -17,33 +17,70 @@ class EnergyLoss(PhysicsLossFunction): :param weight: weight for this loss function """ - + is_path_dependent: bool + use_fori_loop: bool weight: float - def __init__(self, weight: Optional[float] = 1.0): + def __init__( + self, + is_path_dependent: Optional[bool] = False, + use_fori_loop: Optional[bool] = False, # for testing other looping + weight: Optional[float] = 1.0 + ): + self.is_path_dependent = is_path_dependent + self.use_fori_loop = use_fori_loop self.weight = weight def __call__(self, params, problem): - dt = problem.times[1] - problem.times[0] - energies = vmap(self.load_step, in_axes=(None, None, 0, None))( - params, problem, problem.times, dt - ) - energy = jnp.sum(energies) - loss = energy - return self.weight * loss, dict(energy=energy) + if self.is_path_dependent: + return self.path_dependent_call(params, problem) + else: + return self.path_independent_call(params, problem) - def load_step(self, params, problem, t, dt): + def load_step(self, params, problem, t, dt, state_old): field, physics, state = params - # hack for now, need a zero sized state var array - state_old = jnp.zeros(( - problem.domain.conns.shape[0], - problem.domain.fspace.num_quadrature_points, 0 - )) us = physics.vmap_field_values(field, problem.coords, t) pi, state_new = physics.potential_energy( physics, problem.domain, t, us, state_old, dt ) - return pi + return pi, state_new + + def path_dependent_call(self, params, problem): + state_old = self.state_variable_init(problem) + dt = problem.times[1] - problem.times[0] + pi = 0.0 + + def body(n, carry): + pi, state_old, dt = carry + t = problem.times[n] + pi_t, state_new = self.load_step(params, problem, t, dt, state_old) + pi = pi + pi_t + state_old = state_new + dt = problem.times[n] - problem.times[n - 1] + carry = pi, state_old, dt + return carry + + pi, state_old, dt = self.path_dependent_loop( + body, 1, len(problem.times), + pi, state_old, dt, + use_fori_loop=self.use_fori_loop + ) + + loss = pi + return self.weight * loss, dict(energy=pi) + + def path_independent_call(self, params, problem): + state_old = self.state_variable_init(problem) + dt = problem.times[1] - problem.times[0] + energies, state_news = vmap( + self.load_step, + in_axes=(None, None, 0, None, None) + )( + params, problem, problem.times, dt, state_old + ) + energy = jnp.sum(energies) + loss = energy + return self.weight * loss, dict(energy=energy) class ResidualMSELoss(PhysicsLossFunction): @@ -82,183 +119,125 @@ class EnergyAndResidualLoss(PhysicsLossFunction): """ energy_weight: float + is_path_dependent: bool residual_weight: float + use_fori_loop: bool def __init__( self, energy_weight: Optional[float] = 1.0, + is_path_dependent: Optional[bool] = False, residual_weight: Optional[float] = 1.0, + use_fori_loop: Optional[bool] = False ): self.energy_weight = energy_weight + self.is_path_dependent = is_path_dependent self.residual_weight = residual_weight + self.use_fori_loop = use_fori_loop def __call__(self, params, problem): - dt = problem.times[1] - problem.times[0] - (pis, state_new), Rs = vmap( - self.load_step, - in_axes=(None, None, 0, None))( - params, problem, problem.times, dt - ) - # pi, R = jnp.sum(pis), jnp.sum(Rs) - pi, R = jnp.sum(pis), jnp.mean(Rs) - loss = self.energy_weight * pi + self.residual_weight * R - return loss, dict(energy=pi, residual=R) - - def load_step(self, params, problem, t, dt): - field, physics, state = params - # hack for now, need a zero sized state var array - state_old = jnp.zeros(( - problem.domain.conns.shape[0], - problem.domain.fspace.num_quadrature_points, 0 - )) - us = physics.vmap_field_values(field, problem.coords, t) - return physics.potential_energy_and_residual( - params, problem.domain, t, us, state_old, dt - ) - - -class EnergyResidualAndReactionLoss(PhysicsLossFunction): - energy_weight: float - residual_weight: float - reaction_weight: float - - def __init__( - self, - energy_weight: Optional[float] = 1.0, - residual_weight: Optional[float] = 1.0, - reaction_weight: Optional[float] = 1.0, - ): - self.energy_weight = energy_weight - self.residual_weight = residual_weight - self.reaction_weight = reaction_weight - - def __call__(self, params, problem): - dt = problem.times[1] - problem.times[0] - (pis, states_new), Rs, reactions = vmap( - self.load_step, - in_axes=(None, None, 0, None))( - params, problem, problem.times, dt - ) - pi, R = jnp.sum(pis), jnp.sum(Rs) / len(problem.times) - reaction_loss = \ - jnp.square(reactions - problem.global_data.outputs).mean() - loss = ( - self.energy_weight * pi - + self.residual_weight * R - + self.reaction_weight * reaction_loss - ) - return loss, dict( - energy=pi, residual=R, - global_data_loss=reaction_loss, reactions=reactions - ) - - def load_step(self, params, problem, t, dt): - # field_network, props = params - field, physics, state = params - # us = domain.field_values(field_network, t) - state_old = jnp.zeros(( - problem.domain.conns.shape[0], - problem.domain.fspace.num_quadrature_points, 0 - )) - us = physics.vmap_field_values(field, problem.coords, t) - return physics.potential_energy_residual_and_reaction_force( - params, problem.domain, t, us, state_old, dt, - problem.global_data - ) - - -class PathDependentEnergyLoss(PhysicsLossFunction): - weight: float - - def __init__(self, weight: Optional[float] = 1.0) -> None: - self.weight = weight - - def __call__(self, params, problem): - - ne = problem.domain.conns.shape[0] - nq = len(problem.domain.fspace.quadrature_rule) - - def _vmap_func(n): - return problem.physics.constitutive_model.\ - initial_state() + if self.is_path_dependent: + assert False + else: + self.path_independent_call(params, problem) + def path_dependent_call(self, params, problem): + state_old = self.state_variable_init(problem) dt = problem.times[1] - problem.times[0] pi = 0.0 - state_old = vmap(vmap(_vmap_func))( - jnp.zeros((ne, nq)) - ) + R = 0.0 - def fori_loop_body(n, carry): - pi, state_old, dt = carry + def body(n, carry): + pi, state_old, dt, R = carry t = problem.times[n] - pi_t, state_new = self.load_step(params, problem, t, dt, state_old) + (pi_t, state_new), R_t = \ + self.load_step(params, problem, t, dt, state_old) pi = pi + pi_t + R = R + R_t state_old = state_new dt = problem.times[n] - problem.times[n - 1] - carry = pi, state_old, dt + carry = pi, state_old, dt, R return carry - def scan_body(carry, n): - pi, state_old, dt = carry - t = problem.times[n] - pi_t, state_new = self.load_step(params, problem, t, dt, state_old) - pi = pi + pi_t - state_old = state_new - dt = problem.times[n] - problem.times[n - 1] - carry = pi, state_old, dt - return carry, None - - # starting at 1 assuming time step 0 is initial condition - # pi, state_old, dt = jax.lax.fori_loop( - # 1, len(problem.times), fori_loop_body, (pi, state_old, dt) - # ) - (pi, state_old, dt), _ = jax.lax.scan( - scan_body, - (pi, state_old, dt), - jnp.arange(1, len(problem.times)) + if self.use_fori_loop: + def fori_loop_body(n, carry): + return body(n, carry) + + # starting at 1 assuming time step 0 is initial condition + pi, state_old, dt, R = jax.lax.fori_loop( + 1, len(problem.times), fori_loop_body, (pi, state_old, dt, R) + ) + else: + def scan_body(carry, n): + return body(n, carry), None + + (pi, state_old, dt, R), _ = jax.lax.scan( + scan_body, + (pi, state_old, dt, R), + jnp.arange(1, len(problem.times)) + ) + + loss = self.energy_weight * pi + \ + self.residual_weight * (R / len(problem.times)) + return loss, dict( + energy=pi, + residual=R / len(problem.times) ) - loss = pi - return self.weight * loss, dict(energy=pi) + + def path_independent_call(self, params, problem): + state_old = self.state_variable_init(problem) + dt = problem.times[1] - problem.times[0] + (pis, state_new), Rs = vmap( + self.load_step, + in_axes=(None, None, 0, None, None))( + params, problem, problem.times, dt, state_old + ) + # pi, R = jnp.sum(pis), jnp.sum(Rs) + pi, R = jnp.sum(pis), jnp.mean(Rs) + loss = self.energy_weight * pi + self.residual_weight * R + return loss, dict(energy=pi, residual=R) def load_step(self, params, problem, t, dt, state_old): field, physics, state = params us = physics.vmap_field_values(field, problem.coords, t) - pi, state_new = physics.potential_energy( - physics, problem.domain, t, us, state_old, dt + return physics.potential_energy_and_residual( + params, problem.domain, t, us, state_old, dt ) - return pi, state_new -class PathDependentEnergyResidualAndReactionLoss(PhysicsLossFunction): +class EnergyResidualAndReactionLoss(PhysicsLossFunction): energy_weight: float + is_path_dependent: bool residual_weight: float reaction_weight: float + use_fori_loop: bool def __init__( self, energy_weight: Optional[float] = 1.0, + is_path_dependent: Optional[bool] = False, residual_weight: Optional[float] = 1.0, reaction_weight: Optional[float] = 1.0, + use_fori_loop: Optional[bool] = False ): self.energy_weight = energy_weight + self.is_path_dependent = is_path_dependent self.residual_weight = residual_weight self.reaction_weight = reaction_weight + self.use_fori_loop = use_fori_loop def __call__(self, params, problem): - ne = problem.domain.conns.shape[0] - nq = len(problem.domain.fspace.quadrature_rule) - - def _vmap_func(n): - return problem.physics.constitutive_model.\ - initial_state() + if self.is_path_dependent: + return self.path_dependent_call(params, problem) + else: + return self.path_independent_call(params, problem) + def path_dependent_call(self, params, problem): + state_old = self.state_variable_init(problem) dt = problem.times[1] - problem.times[0] pi = 0.0 R = 0.0 reaction = 0.0 - state_old = vmap(vmap(_vmap_func))( - jnp.zeros((ne, nq)) - ) def body(n, carry): pi, state_old, dt, R, reaction = carry @@ -274,9 +253,12 @@ def body(n, carry): carry = pi, state_old, dt, R, reaction return carry - pi, state_old, dt, R, reaction = jax.lax.fori_loop( - 1, len(problem.times), body, (pi, state_old, dt, R, reaction) + pi, state_old, dt, R, reaction = self.path_dependent_loop( + body, 1, len(problem.times), + pi, state_old, dt, R, reaction, + use_fori_loop=self.use_fori_loop ) + loss = self.energy_weight * pi + \ self.residual_weight * (R / len(problem.times)) + \ self.reaction_weight * (reaction / len(problem.times)) @@ -286,6 +268,27 @@ def body(n, carry): global_data_loss=reaction / len(problem.times) ) + def path_independent_call(self, params, problem): + state_old = self.state_variable_init(problem) + dt = problem.times[1] - problem.times[0] + (pis, states_new), Rs, reactions = vmap( + self.load_step, + in_axes=(None, None, 0, None, None))( + params, problem, problem.times, dt, state_old + ) + pi, R = jnp.sum(pis), jnp.sum(Rs) / len(problem.times) + reaction_loss = \ + jnp.square(reactions - problem.global_data.outputs).mean() + loss = ( + self.energy_weight * pi + + self.residual_weight * R + + self.reaction_weight * reaction_loss + ) + return loss, dict( + energy=pi, residual=R, + global_data_loss=reaction_loss, reactions=reactions + ) + def load_step(self, params, problem, t, dt, state_old): field, physics, state = params us = physics.vmap_field_values(field, problem.coords, t)