Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
bulk_modulus=1000.0,
shear_modulus=1.,
)
# physics = SolidMechanics(model, PlaneStrain())
physics = SolidMechanics(model, PlaneStress())
physics = SolidMechanics(model, PlaneStrain())
# physics = SolidMechanics(model, PlaneStress())
ics = [
]
dirichlet_bcs = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def dirichlet_bc_func(xs, t, nn):
WLF(C1=17.44, C2=51.6, theta_ref=60.0),
)
physics = SolidMechanics(model, PlaneStrain())
physics = physics.update_dirichlet_bc_func(dirichlet_bc_func)
# physics = physics.update_dirichlet_bc_func(dirichlet_bc_func)

ics = []
dirichlet_bcs = [
Expand All @@ -84,10 +84,14 @@ def dirichlet_bc_func(xs, t, nn):
##################
# ML setup
##################
loss_function = PathDependentEnergyLoss()
# loss_function = EnergyLoss()
# loss_function = PathDependentEnergyLoss()
loss_function = EnergyLoss(is_path_dependent=True)

params = Parameters(problem, key, seperate_networks=True)
params = Parameters(
problem, key,
dirichlet_bc_func=dirichlet_bc_func,
seperate_networks=False
)
print(params)

##################
Expand Down
128 changes: 128 additions & 0 deletions examples/inverse_problems/mechanics/model_free/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from pancax import *

##################
# for reproducibility
##################
key = random.PRNGKey(10)
# key = random.split(key, 8)

##################
# file management
##################
full_field_data_file = find_data_file('full_field_data.csv')
global_data_file = find_data_file('global_data.csv')
# mesh_file = find_mesh_file('mesh.g')
# mesh_file = 'data/2holes.g'
mesh_file = os.path.join(Path(__file__).parent, "data", "2holes.g")

history = HistoryWriter('history.csv', log_every=250, write_every=250)
pp = PostProcessor(mesh_file)

##################
# data setup
##################
field_data = FullFieldData(full_field_data_file, ['x', 'y', 'z', 't'], ['displ_x', 'displ_y', 'displ_z'])
# the 4 below is for the nodeset id
global_data = GlobalData(
global_data_file, 'times', 'disps', 'forces',
mesh_file, 5, 'y', # these inputs specify where to measure reactions
n_time_steps=11, # the number of time steps for inverse problems is specified here
plotting=True
)

##################
# domain setup
##################
times = jnp.linspace(0.0, 1.0, len(global_data.outputs))
domain = VariationalDomain(mesh_file, times)

##################
# physics setup
##################
model = NeoHookean(
bulk_modulus=10.,
shear_modulus=BoundedProperty(0.01, 5., key=key)
# shear_modulus=0.855
)
# model = InputPolyConvexPotential(
# bulk_modulus=10., key=key
# )
physics = SolidMechanics(model, ThreeDimensional())

##################
# ics/bcs
##################
ics = [
]
dirichlet_bc_func = UniaxialTensionLinearRamp(
final_displacement=jnp.max(field_data.outputs[:, 1]),
length=1.0, direction='y', n_dimensions=3
)
dirichlet_bcs = [
DirichletBC('nodeset_3', 0), # left edge fixed in x
DirichletBC('nodeset_3', 1), # left edge fixed in y
DirichletBC('nodeset_3', 2),
DirichletBC('nodeset_5', 0), # right edge prescribed in x
DirichletBC('nodeset_5', 1), # right edge fixed in y
DirichletBC('nodeset_5', 2)
]
neumann_bcs = [
]

##################
# problem setup
##################
problem = InverseProblem(domain, physics, field_data, global_data, ics, dirichlet_bcs, neumann_bcs)

# print(problem)

##################
# ML setup
##################
params = Parameters(
problem, key,
dirichlet_bc_func=dirichlet_bc_func
# seperate_networks=True,
# network_type=ResNet
)
print(params)
physics_and_global_loss = EnergyResidualAndReactionLoss(
residual_weight=50.e9, reaction_weight=250.e9
)
full_field_data_loss = FullFieldDataLoss(weight=500.e9)

def loss_function(params, problem, inputs, outputs):
loss_1, aux_1 = physics_and_global_loss(params, problem)
loss_2, aux_2 = full_field_data_loss(params, problem, inputs, outputs)
aux_1.update(aux_2)
return loss_1 + loss_2, aux_1

loss_function = UserDefinedLossFunction(loss_function)
# loss_function = EnergyLoss()

opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, transition_steps=10000)

##################
# Training
##################
opt, opt_st = opt.init(params)

dataloader = FullFieldDataLoader(problem.field_data)
for epoch in range(100000):
for inputs, outputs in dataloader.dataloader(512):
params, opt_st, loss = opt.step(params, opt_st, problem, inputs, outputs)

# # params.physics.network
# new_model = params.physics.constitutive_model.parameter_enforcement()
# # physics = SolidMechanics(model, PlaneStrain())
# new_physics = eqx.tree_at(lambda x: x.constitutive_model, params.physics, new_model)
# params = eqx.tree_at(lambda x: x.physics, params, new_physics)

if epoch % 10 == 0:
print(epoch)
print(loss)
# print(params.physics.constitutive_model)
# print(params.physics.constitutive_model.shear_modulus.prop_min)
# print(params.physics.constitutive_model.shear_modulus.prop_max)

print(params.physics.constitutive_model.shear_modulus())
17 changes: 13 additions & 4 deletions examples/inverse_problems/mechanics/path-dependent/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def dirichlet_bc_func(xs, t, nn):
return u_out

# model = NeoHookean(bulk_modulus=10., shear_modulus=0.855)
physics = physics.update_dirichlet_bc_func(dirichlet_bc_func)
# physics = physics.update_dirichlet_bc_func(dirichlet_bc_func)
dirichlet_bcs = [
DirichletBC('nset_1', 0), # left edge fixed in x
DirichletBC('nset_1', 1), # left edge fixed in y
Expand All @@ -103,9 +103,18 @@ def dirichlet_bc_func(xs, t, nn):
##################
# ML setup
##################
params = Parameters(problem, key, seperate_networks=False, network_type=MLP)
physics_and_global_loss = PathDependentEnergyResidualAndReactionLoss(
residual_weight=250.e6, reaction_weight=250.e6
params = Parameters(
problem, key,
dirichlet_bc_func=dirichlet_bc_func,
seperate_networks=False,
network_type=MLP
)
# physics_and_global_loss = PathDependentEnergyResidualAndReactionLoss(
# residual_weight=250.e6, reaction_weight=250.e6
# )
physics_and_global_loss = EnergyResidualAndReactionLoss(
residual_weight=250.e6, reaction_weight=250.e6,
is_path_dependent=True
)
full_field_data_loss = FullFieldDataLoss(weight=10.e6)

Expand Down
4 changes: 0 additions & 4 deletions pancax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@
EnergyAndResidualLoss, \
EnergyResidualAndReactionLoss, \
ResidualMSELoss, \
PathDependentEnergyLoss, \
PathDependentEnergyResidualAndReactionLoss, \
UserDefinedLossFunction
from .networks import \
Field, \
Expand Down Expand Up @@ -159,8 +157,6 @@
"EnergyAndResidualLoss",
"EnergyResidualAndReactionLoss",
"ResidualMSELoss",
"PathDependentEnergyLoss",
"PathDependentEnergyResidualAndReactionLoss",
"UserDefinedLossFunction",
# networks
"Field",
Expand Down
5 changes: 0 additions & 5 deletions pancax/loss_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from .utils import CombineLossFunctions, UserDefinedLossFunction
from .weak_form_loss_functions import EnergyLoss, EnergyAndResidualLoss
from .weak_form_loss_functions import EnergyResidualAndReactionLoss
from .weak_form_loss_functions import PathDependentEnergyLoss
from .weak_form_loss_functions import \
PathDependentEnergyResidualAndReactionLoss
from .weak_form_loss_functions import ResidualMSELoss

__all__ = [
Expand All @@ -22,7 +19,5 @@
"EnergyAndResidualLoss",
"EnergyResidualAndReactionLoss",
"ResidualMSELoss",
"PathDependentEnergyLoss",
"PathDependentEnergyResidualAndReactionLoss",
"UserDefinedLossFunction"
]
39 changes: 38 additions & 1 deletion pancax/loss_functions/base_loss_function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from abc import abstractmethod
from jax import vmap
from typing import Optional
import equinox as eqx
import jax
import jax.numpy as jnp


class BaseLossFunction(eqx.Module):
Expand Down Expand Up @@ -35,7 +39,40 @@ class PhysicsLossFunction(BaseLossFunction):
type signature
``load_step(self, params, domain, t)``
"""

@abstractmethod
def load_step(self, params, domain, t):
pass

def path_dependent_loop(
self, func, start, end, *args,
use_fori_loop: Optional[bool] = False
):
if use_fori_loop:
def fori_loop_body(n, carry):
return func(n, carry)

return jax.lax.fori_loop(
start, end, fori_loop_body, args
)
else:
def scan_body(carry, n):
return func(n, carry), None

return jax.lax.scan(
scan_body,
args,
jnp.arange(start, end)
)[0]

def state_variable_init(self, problem):
ne = problem.domain.conns.shape[0]
nq = len(problem.domain.fspace.quadrature_rule)

def _vmap_func(n):
return problem.physics.constitutive_model.\
initial_state()

state_old = vmap(vmap(_vmap_func))(
jnp.zeros((ne, nq))
)
return state_old
Loading
Loading