Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/forward_problems/mechanics/example_hyper_visco_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
##################
# for reproducibility
##################
key = random.key(10)
key = random.PRNGKey(10)

##################
# file management
Expand Down Expand Up @@ -93,9 +93,9 @@ def dirichlet_bc_func(xs, t, nn):
# train network
##################
opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, clip_gradients=False)
opt_st = opt.init(params)
opt, opt_st = opt.init(params)
for epoch in range(100000):
params, opt_st, loss = opt.step(params, problem, opt_st)
params, opt_st, loss = opt.step(params, opt_st, problem)
# logger.log_loss(loss, epoch)
if epoch % 100 == 0:
print(epoch, flush=True)
Expand All @@ -105,6 +105,7 @@ def dirichlet_bc_func(xs, t, nn):

if epoch % 10000 == 0:
pp.init(
params,
problem,
f"output_{str(epoch).zfill(6)}.e",
node_variables=[
Expand Down
15 changes: 8 additions & 7 deletions examples/forward_problems/mechanics/example_incompressible_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
##################
# for reproducibility
##################
key = random.key(10)
key = random.PRNGKey(10)
key = random.split(key, 8) # comment this to not use an ensemble

##################
# file management
Expand All @@ -16,6 +17,7 @@
##################
times = jnp.linspace(0.0, 1.0, 11)
domain = VariationalDomain(mesh_file, times, q_order=2)
# domain = DeltaPINNDomain(mesh_file, times, n_eigen_values=20, q_order=2)

##################
# physics setup
Expand Down Expand Up @@ -49,26 +51,25 @@
# ML setup
##################
loss_function = EnergyLoss()
# loss_function = EnergyAndResidualLoss(residual_weight=1.0e9)
# params = Parameters(problem, key, seperate_networks=True, network_type=ResNet)
params = Parameters(problem, key, seperate_networks=True)
params = Parameters(problem, key, seperate_networks=False)
print(params)

##################
# train network
##################
opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, clip_gradients=False)
opt_st = opt.init(params)
opt, opt_st = opt.init(params)

for epoch in range(25000):
params, opt_st, loss = opt.step(params, problem, opt_st)
params, opt_st, loss = opt.step(params, opt_st, problem)
if epoch % 100 == 0:
print(epoch, flush=True)
print(loss, flush=True)

##################
# post-processing
##################
pp.init(problem, 'output.e',
pp.init(params, problem, 'output.e',
node_variables=[
'field_values',
# 'internal_force'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
##################
# for reproducibility
##################
key = random.key(10)
key = random.PRNGKey(10)

##################
# file management
Expand Down Expand Up @@ -58,17 +58,17 @@
# train network
##################
opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, clip_gradients=False)
opt_st = opt.init(params)
opt, opt_st = opt.init(params)
for epoch in range(25000):
params, opt_st, loss = opt.step(params, problem, opt_st)
params, opt_st, loss = opt.step(params, opt_st, problem)
if epoch % 100 == 0:
print(epoch, flush=True)
print(loss, flush=True)

##################
# post-processing
##################
pp.init(problem, 'output.e',
pp.init(params, problem, 'output.e',
node_variables=[
'field_values'
# 'displacement',
Expand Down
24 changes: 14 additions & 10 deletions examples/inverse_problems/mechanics/vanilla/example_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
##################
# for reproducibility
##################
key = random.key(10)
# key = random.key(10)
key = random.PRNGKey(10)
key = random.split(key, 8)

##################
# file management
Expand Down Expand Up @@ -35,11 +37,15 @@
##################
# physics setup
##################
# @eqx.filter_vmap
# def models(key):
model = NeoHookean(
bulk_modulus=0.833,
# bulk_modulus=BoundedProperty(0.01, 5., key=key),
shear_modulus=BoundedProperty(0.01, 5., key=key)
)
# return model

physics = SolidMechanics(model, PlaneStrain())

##################
Expand Down Expand Up @@ -71,7 +77,9 @@
##################
# ML setup
##################
params = Parameters(problem, key, seperate_networks=True, network_type=ResNet)
params = Parameters(problem, key)#, seperate_networks=True, network_type=ResNet)
print(params)
# assert False
physics_and_global_loss = EnergyResidualAndReactionLoss(
residual_weight=250.e9, reaction_weight=250.e9
)
Expand All @@ -90,19 +98,15 @@ def loss_function(params, problem, inputs, outputs):
##################
# Training
##################
opt_st = opt.init(params)

opt, opt_st = opt.init(params)

# testing stuff
dataloader = FullFieldDataLoader(problem.field_data)

for epoch in range(10000):
for inputs, outputs in dataloader.dataloader(1024):
params, opt_st, loss = opt.step(params, problem, opt_st, inputs, outputs)
params, opt_st, loss = opt.step(params, opt_st, problem, inputs, outputs)

if epoch % 10 == 0:
print(epoch)
print(loss)
print(params.physics.constitutive_model.bulk_modulus)
print(params.physics.constitutive_model.shear_modulus)
# print(params.physics.constitutive_model.Jm_parameter())
print(params.physics.constitutive_model)
# print(params.physics.constitutive_model.shear_modulus)
4 changes: 2 additions & 2 deletions pancax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
MLPBasis, \
Parameters, \
ResNet
from .optimizers import Adam, LBFGS
from .optimizers import Adam
from .physics_kernels import \
BasePhysics, \
BaseEnergyFormPhysics, \
Expand Down Expand Up @@ -174,7 +174,7 @@
"ResNet",
# optimizers
"Adam",
"LBFGS",
# "LBFGS",
# physics
"BasePhysics",
"BaseEnergyFormPhysics",
Expand Down
17 changes: 2 additions & 15 deletions pancax/constitutive_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,21 @@
class ConstitutiveModel(eqx.Module):
def __repr__(self):
prop_names = self.properties()
# props = asdict(self)
# print(props)
string = f"{type(self)}:\n"

# for prop_name, prop in zip(prop_names, props):
# string = string + f" {prop_name} = {prop}\n"
# for k, v in props.items():
max_str_length = max(map(len, prop_names))
print(max_str_length)

for prop_name in prop_names:
v = getattr(self, prop_name)
string = string + f" {prop_name}"
# string = string.rjust(max_str_length)
# string = string + " = "
string = string.ljust(max_str_length)
string = string + " = "
if type(v) is float:
string = string + f"{v}\n"
# elif type(v) is ConstitutiveModel:
elif issubclass(ConstitutiveModel, type(v)):
string = string + f"{v.__repr__()}\n"
else:
string = string + f"{v()}\n"
# if type(v) is float:
# string = string + f" {prop_name} = {v}\n"
# else:
# string = string + f" {prop_name} = {v()}\n"
string = string + f"{v.__repr__()}\n"

return string

def properties(self):
Expand Down
31 changes: 27 additions & 4 deletions pancax/constitutive_models/properties.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from jaxtyping import Array
from jaxtyping import Array, Float
from typing import Union
import equinox as eqx
import jax
import jax.numpy as jnp


# TODO patch up error check in a good way
# probably just make a method to check type on other
class BoundedProperty(eqx.Module):
prop_min: float = eqx.field(static=True)
prop_max: float = eqx.field(static=True)
prop_val: float
prop_val: Float[Array, "n"]
# TODO
# activation: Callable

Expand All @@ -18,7 +19,16 @@ def __init__(
) -> None:
self.prop_min = prop_min
self.prop_max = prop_max
self.prop_val = jax.random.uniform(key, 1)

if len(key.shape) == 1:
self.prop_val = self._sample(key, prop_min, prop_max)
elif len(key.shape) == 2:
@eqx.filter_vmap
def vmap_func(key):
return self._sample(key, prop_min, prop_max)
self.prop_val = vmap_func(key)
else:
assert False, f"This shouldn't happen key = {key}"

def __call__(self):
return (
Expand All @@ -28,7 +38,12 @@ def __call__(self):
)

def __repr__(self):
return str(self.__call__())
# return str(self.__call__())
return str((
self.prop_min +
(self.prop_max - self.prop_min) *
jax.nn.sigmoid(self.prop_val)
))

def __add__(self, other):
self._check_other_type(other, "+")
Expand Down Expand Up @@ -79,6 +94,14 @@ def _check_other_type(self, other, op_str):
{op_str} with BoundingProperty"
)

def _sample(self, key, lb, ub):
if lb == ub:
return jnp.zeros(1)

p_actual = jax.random.uniform(key, 1, minval=lb, maxval=ub)
y = (p_actual - lb) / (ub - lb)
return jnp.log(y / (1. - y))


FixedProperty = float
Property = Union[BoundedProperty, FixedProperty]
13 changes: 10 additions & 3 deletions pancax/domains/delta_pinn_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ def __init__(
self.physics = physics.update_var_name_to_method()
self.eigen_modes = self.solve_eigen_problem()

# def __pos

def solve_eigen_problem(self):
# physics = LaplaceBeltrami()
dof_manager = DofManager(self.mesh, 1, [])
Expand All @@ -58,9 +56,18 @@ def solve_eigen_problem(self):
for n in range(len(lambdas)):
print(f" Eigen mode {n} = {1. / lambdas[n]}")

# dummy params
class DummyParams:
is_ensemble = False
n_ensemble = 1

params = DummyParams()

with Timer("post-processing"):
pp = PostProcessor(self.mesh_file)
pp.init(self, "output-eigen.e", node_variables=["field_values"])
pp.init(
params, self, "output-eigen.e", node_variables=["field_values"]
)

with nc.Dataset(pp.pp.output_file, "a") as out:
for n in range(len(lambdas)):
Expand Down
Loading
Loading