Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ docs/build
history*.csv
jax_profile/
pinn*.log
venv/
sierra_batch_script_*
rocm_venv/
venv/
24 changes: 10 additions & 14 deletions examples/inverse_problems/mechanics/path-dependent/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
global_data_file = find_data_file('global_data.csv')
mesh_file = find_mesh_file('mesh_quad4.g')
# logger = Logger('pinn.log', log_every=250)
history = HistoryWriter('history.csv', log_every=250, write_every=250)
history = HistoryWriter('history_pd.csv', log_every=250, write_every=250)
pp = PostProcessor(mesh_file)

##################
Expand All @@ -27,15 +27,11 @@
plotting=True,
interpolate=False
)
# print(global_data.times)
times = global_data.times
print(times)

##################
# domain setup
##################
# times_1 = jnp.linspace(0., 1., 11)
# times_2 = jnp.linspace(1., 11., 11)
# times = jnp.hstack((times_1, times_2[1:]))
domain = VariationalDomain(mesh_file, times, q_order=2)

##################
Expand Down Expand Up @@ -144,13 +140,13 @@ def loss_function(params, problem, inputs, outputs):
params, opt_st, loss = opt.step(params, problem, opt_st, inputs, outputs)
# params, opt_st, loss = opt.step(params, problem, opt_st)

# if epoch % 10 == 0:

history.write_data("epoch", epoch)
history.write_loss(loss)
history.write_data("shear_modulus", params.physics.constitutive_model.eq_model.shear_modulus)

print(epoch)
print(loss)
print(params.physics.constitutive_model.eq_model.bulk_modulus)
print(params.physics.constitutive_model.eq_model.shear_modulus)
print(params.physics.constitutive_model.prony_series.moduli)
print(params.physics.constitutive_model.prony_series.relaxation_times)
print(params.physics.constitutive_model.shift_factor_model)
# print(params.physics.constitutive_model)
# print(params.physics.constitutive_model.Jm_parameter())

if epoch % 1000 == 0:
history.to_csv()
2 changes: 0 additions & 2 deletions pancax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
Linear, \
MLP, \
MLPBasis, \
Network, \
Parameters, \
ResNet
from .optimizers import Adam, LBFGS
Expand Down Expand Up @@ -170,7 +169,6 @@
"Linear",
"MLP",
"MLPBasis",
"Network",
"Parameters",
"ResNet",
# optimizers
Expand Down
9 changes: 1 addition & 8 deletions pancax/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .fields import Field
from .field_physics_pair import FieldPhysicsPair
from .initialization import trunc_init
from .input_polyconvex_nn import InputPolyconvexNN
from .ml_dirichlet_field import MLDirichletField
from .mlp import Linear
Expand All @@ -10,10 +9,6 @@
from .resnet import ResNet


def Network(network_type, *args, **kwargs):
    """Factory shim: build and return an instance of *network_type*.

    All positional and keyword arguments are forwarded verbatim to the
    constructor, so ``Network(MLP, ...)`` is identical to ``MLP(...)``.
    """
    factory = network_type
    return factory(*args, **kwargs)


__all__ = [
"Field",
"FieldPhysicsPair",
Expand All @@ -22,8 +17,6 @@ def Network(network_type, *args, **kwargs):
"MLDirichletField",
"MLP",
"MLPBasis",
"Network",
"Parameters",
"ResNet",
"trunc_init"
"ResNet"
]
97 changes: 96 additions & 1 deletion pancax/networks/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,108 @@
from jaxtyping import Float
from typing import Callable, Union
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as random


class BasePancaxModel(eqx.Module):
def _apply_init(init_fn: Callable, *args, key: random.PRNGKey):
    """Invoke an initializer on *args*, forwarding the PRNG key by keyword.

    Thin indirection so callers can treat every initializer uniformly:
    all of them accept ``key`` as a keyword-only argument.
    """
    return init_fn(*args, key=key)

# TODO make this conform with new interface
# def box_init(layer: eqx.nn.Linear, key: jax.random.PRNGKey):
# in_size = layer.in_features
# out_size = layer.out_features
# k1, k2 = jax.random.split(key, 2)
# p = jax.random.uniform(k1, (out_size, in_size))
# n = jax.random.normal(k2, (out_size, in_size))

# # normalize normals
# # for i in range(n.shape[0]):
# # n = n.at[i, :].set(n[i, :] / jnp.linalg.norm(n[i, :]))

# # calculate p_max
# # p_max = jnp.max(jnp.sign(n))
# # p_max = jnp.max(jnp.array([0.0, p_max]))

# # setup A and b one vector at a time
# A = jnp.zeros((out_size, in_size))
# b = jnp.zeros((out_size,))
# for i in range(n.shape[0]):
# p_temp = p[i, :]
# n_temp = n[i, :] / jnp.linalg.norm(n[i, :])
# p_max = jnp.max(jnp.array([0.0, jnp.max(jnp.sign(n_temp))]))
# k = 1.0 / jnp.sum((p_max - p_temp) * n_temp)
# A = A.at[i, :].set(k * n_temp)
# b = b.at[i].set(k * jnp.dot(p_temp, n_temp))

# # k = jnp.zeros((n.shape[0],))
# # for i in range(n.shape[0]):
# # k = k.at[i].set(1. / jnp.sum((p_max - p[i, :]) * n[i, :]))

# # A = jax.vmap(lambda k, n: k * n, in_axes=(0, 0))(k, n)
# # b = k * jax.vmap(lambda x: jnp.sum(x), in_axes=(1,))(n @ p.T)
# # print(A)
# # print(b)
# # assert False
# return A, b


def trunc_init(params: jax.Array, key: jax.Array) -> jax.Array:
    """Truncated-normal initializer scaled by fan-in.

    Draws samples from a standard normal truncated to [-2, 2] and scales
    them by ``sqrt(1 / fan_in)``, where fan-in is ``params.shape[0]``.
    Returns a new array with the same shape as *params*.
    """
    fan_in = params.shape[0]
    scale = jnp.sqrt(1.0 / fan_in)
    sample = jax.random.truncated_normal(
        key, shape=params.shape, lower=-2, upper=2
    )
    return scale * sample


def uniform_init(params: jax.Array, key: jax.Array) -> jax.Array:
    """Uniform initializer on ``[-k, k)`` with ``k = 1 / fan_in``.

    Fan-in is taken as ``params.shape[0]``; returns a new array with the
    same shape as *params*.
    """
    bound = 1.0 / params.shape[0]
    return jax.random.uniform(
        key=key, shape=params.shape, minval=-bound, maxval=bound
    )


def zero_init(params: jax.Array, key: jax.Array) -> jax.Array:
    """Zero initializer: return an all-zeros array matching *params*.

    Uses ``zeros_like`` so the result preserves the shape AND dtype of
    *params* (the previous ``jnp.zeros(params.shape)`` silently reset the
    dtype to the JAX default float type).

    The *key* argument is unused; it is kept so every initializer shares
    the same ``(params, key)`` call signature.
    """
    del key  # unused by design — uniform initializer interface
    return jnp.zeros_like(params)


class AbstractPancaxModel(eqx.Module):
"""
Base class for pancax model parameters.

This includes a few helper methods
"""

def deserialise(self, f_name):
self = eqx.tree_deserialise_leaves(f_name, self)
return self

def init(
self,
init_fn: Callable,
filter_func: Union[None, Callable] = None,
*,
key: random.PRNGKey
):
def get_leaves(m):
return jax.tree_util.tree_leaves(m, is_leaf=filter_func)

leaves = get_leaves(self)
keys = random.split(key, len(leaves))
new_leaves = []
for key, leave in zip(keys, leaves):
if hasattr(leave, "shape"):
# case for arrays
new_leaves.append(_apply_init(init_fn, leave, key=key))
else:
new_leaves.append(leave)

return eqx.tree_at(get_leaves, self, new_leaves)

def serialise(self, base_name, epoch):
file_name = f"{base_name}_{str(epoch).zfill(7)}.eqx"
print(f"Serialising current parameters to {file_name}")
Expand Down
4 changes: 2 additions & 2 deletions pancax/networks/field_physics_pair.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .base import BasePancaxModel
from .base import AbstractPancaxModel
import equinox as eqx
import jax.tree_util as jtu


class FieldPhysicsPair(BasePancaxModel):
class FieldPhysicsPair(AbstractPancaxModel):
"""
Data structure for storing a set of field network
parameters and a physics object
Expand Down
4 changes: 2 additions & 2 deletions pancax/networks/fields.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .base import BasePancaxModel
from .base import AbstractPancaxModel
from .mlp import MLP
from typing import Callable, List, Optional, Union
import equinox as eqx
import jax


class Field(BasePancaxModel):
class Field(AbstractPancaxModel):
networks: Union[eqx.Module, List[eqx.Module]]
seperate_networks: bool

Expand Down
144 changes: 0 additions & 144 deletions pancax/networks/initialization.py

This file was deleted.

Loading
Loading