diff --git a/.gitignore b/.gitignore index 30d70f4..5161e01 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,6 @@ docs/build history*.csv jax_profile/ pinn*.log -venv/ sierra_batch_script_* +rocm_venv/ +venv/ diff --git a/examples/inverse_problems/mechanics/path-dependent/script.py b/examples/inverse_problems/mechanics/path-dependent/script.py index 397b48b..783b4b5 100644 --- a/examples/inverse_problems/mechanics/path-dependent/script.py +++ b/examples/inverse_problems/mechanics/path-dependent/script.py @@ -12,7 +12,7 @@ global_data_file = find_data_file('global_data.csv') mesh_file = find_mesh_file('mesh_quad4.g') # logger = Logger('pinn.log', log_every=250) -history = HistoryWriter('history.csv', log_every=250, write_every=250) +history = HistoryWriter('history_pd.csv', log_every=250, write_every=250) pp = PostProcessor(mesh_file) ################## @@ -27,15 +27,11 @@ plotting=True, interpolate=False ) -# print(global_data.times) times = global_data.times -print(times) + ################## # domain setup ################## -# times_1 = jnp.linspace(0., 1., 11) -# times_2 = jnp.linspace(1., 11., 11) -# times = jnp.hstack((times_1, times_2[1:])) domain = VariationalDomain(mesh_file, times, q_order=2) ################## @@ -144,13 +140,13 @@ def loss_function(params, problem, inputs, outputs): params, opt_st, loss = opt.step(params, problem, opt_st, inputs, outputs) # params, opt_st, loss = opt.step(params, problem, opt_st) - # if epoch % 10 == 0: + + history.write_data("epoch", epoch) + history.write_loss(loss) + history.write_data("shear_modulus", params.physics.constitutive_model.eq_model.shear_modulus) + print(epoch) print(loss) - print(params.physics.constitutive_model.eq_model.bulk_modulus) - print(params.physics.constitutive_model.eq_model.shear_modulus) - print(params.physics.constitutive_model.prony_series.moduli) - print(params.physics.constitutive_model.prony_series.relaxation_times) - print(params.physics.constitutive_model.shift_factor_model) - # print(params.physics.constitutive_model) - # print(params.physics.constitutive_model.Jm_parameter()) + + if epoch % 1000 == 0: + history.to_csv() diff --git a/pancax/__init__.py b/pancax/__init__.py index e950f53..708065a 100644 --- a/pancax/__init__.py +++ b/pancax/__init__.py @@ -60,7 +60,6 @@ Linear, \ MLP, \ MLPBasis, \ - Network, \ Parameters, \ ResNet from .optimizers import Adam, LBFGS @@ -170,7 +169,6 @@ "Linear", "MLP", "MLPBasis", - "Network", "Parameters", "ResNet", # optimizers diff --git a/pancax/networks/__init__.py b/pancax/networks/__init__.py index 66c699e..4e45d2d 100644 --- a/pancax/networks/__init__.py +++ b/pancax/networks/__init__.py @@ -1,6 +1,5 @@ from .fields import Field from .field_physics_pair import FieldPhysicsPair -from .initialization import trunc_init from .input_polyconvex_nn import InputPolyconvexNN from .ml_dirichlet_field import MLDirichletField from .mlp import Linear @@ -10,10 +9,6 @@ from .resnet import ResNet -def Network(network_type, *args, **kwargs): - return network_type(*args, **kwargs) - - __all__ = [ "Field", "FieldPhysicsPair", @@ -22,8 +17,6 @@ def Network(network_type, *args, **kwargs): "MLDirichletField", "MLP", "MLPBasis", - "Network", "Parameters", - "ResNet", - "trunc_init" + "ResNet" ] diff --git a/pancax/networks/base.py b/pancax/networks/base.py index 3dbaf38..fcc2393 100644 --- a/pancax/networks/base.py +++ b/pancax/networks/base.py @@ -1,13 +1,108 @@ +from jaxtyping import Float +from typing import Callable, Union import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as random -class BasePancaxModel(eqx.Module): +def _apply_init( + init_fn: Callable, + *args, + key: random.PRNGKey +): + return init_fn(*args, key=key) + +# TODO make this conform with new interface +# def box_init(layer: eqx.nn.Linear, key: jax.random.PRNGKey): +# in_size = layer.in_features +# out_size = layer.out_features +# k1, k2 = jax.random.split(key, 2) +# p = jax.random.uniform(k1, (out_size, in_size)) +# n = jax.random.normal(k2, (out_size, in_size)) + +# # normalize normals +# # for i in range(n.shape[0]): +# # n = n.at[i, :].set(n[i, :] / jnp.linalg.norm(n[i, :])) + +# # calculate p_max +# # p_max = jnp.max(jnp.sign(n)) +# # p_max = jnp.max(jnp.array([0.0, p_max])) + +# # setup A and b one vector at a time +# A = jnp.zeros((out_size, in_size)) +# b = jnp.zeros((out_size,)) +# for i in range(n.shape[0]): +# p_temp = p[i, :] +# n_temp = n[i, :] / jnp.linalg.norm(n[i, :]) +# p_max = jnp.max(jnp.array([0.0, jnp.max(jnp.sign(n_temp))])) +# k = 1.0 / jnp.sum((p_max - p_temp) * n_temp) +# A = A.at[i, :].set(k * n_temp) +# b = b.at[i].set(k * jnp.dot(p_temp, n_temp)) + +# # k = jnp.zeros((n.shape[0],)) +# # for i in range(n.shape[0]): +# # k = k.at[i].set(1. / jnp.sum((p_max - p[i, :]) * n[i, :])) + +# # A = jax.vmap(lambda k, n: k * n, in_axes=(0, 0))(k, n) +# # b = k * jax.vmap(lambda x: jnp.sum(x), in_axes=(1,))(n @ p.T) +# # print(A) +# # print(b) +# # assert False +# return A, b + + +def trunc_init(params: Float, key: random.PRNGKey): + stddev = jnp.sqrt(1 / params.shape[0]) + return stddev * jax.random.truncated_normal( + key, shape=params.shape, lower=-2, upper=2 + ) + + +def uniform_init(params: Float, key: random.PRNGKey): + k = 1. / params.shape[0] + return jax.random.uniform( + key=key, shape=params.shape, minval=-k, maxval=k + ) + + +def zero_init(params: Float, key: random.PRNGKey): + return jnp.zeros(params.shape) + + +class AbstractPancaxModel(eqx.Module): """ Base class for pancax model parameters. This includes a few helper methods """ + def deserialise(self, f_name): + self = eqx.tree_deserialise_leaves(f_name, self) + return self + + def init( + self, + init_fn: Callable, + filter_func: Union[None, Callable] = None, + *, + key: random.PRNGKey + ): + def get_leaves(m): + return jax.tree_util.tree_leaves(m, is_leaf=filter_func) + + leaves = get_leaves(self) + keys = random.split(key, len(leaves)) + new_leaves = [] + for key, leave in zip(keys, leaves): + if hasattr(leave, "shape"): + # case for arrays + new_leaves.append(_apply_init(init_fn, leave, key=key)) + else: + new_leaves.append(leave) + + return eqx.tree_at(get_leaves, self, new_leaves) + def serialise(self, base_name, epoch): file_name = f"{base_name}_{str(epoch).zfill(7)}.eqx" print(f"Serialising current parameters to {file_name}") diff --git a/pancax/networks/field_physics_pair.py b/pancax/networks/field_physics_pair.py index f72287b..43ad32a 100644 --- a/pancax/networks/field_physics_pair.py +++ b/pancax/networks/field_physics_pair.py @@ -1,9 +1,9 @@ -from .base import BasePancaxModel +from .base import AbstractPancaxModel import equinox as eqx import jax.tree_util as jtu -class FieldPhysicsPair(BasePancaxModel): +class FieldPhysicsPair(AbstractPancaxModel): """ Data structure for storing a set of field network parameters and a physics object diff --git a/pancax/networks/fields.py b/pancax/networks/fields.py index c90d334..1055636 100644 --- a/pancax/networks/fields.py +++ b/pancax/networks/fields.py @@ -1,11 +1,11 @@ -from .base import BasePancaxModel +from .base import AbstractPancaxModel from .mlp import MLP from typing import Callable, List, Optional, Union import equinox as eqx import jax -class Field(BasePancaxModel): +class Field(AbstractPancaxModel): networks: Union[eqx.Module, List[eqx.Module]] seperate_networks: bool diff --git a/pancax/networks/initialization.py b/pancax/networks/initialization.py deleted file mode 100644 index af5a7f5..0000000 --- a/pancax/networks/initialization.py +++ /dev/null @@ -1,144 +0,0 @@ -from jaxtyping import Array, Float -from typing import Callable -import equinox as eqx -import jax -import jax.numpy as jnp - - -def zero_init(key: jax.random.PRNGKey, shape) -> Float[Array, "no ni"]: - """ - :param weight: current weight array for sizing - :param key: rng key - :return: A new set of weights - """ - out, in_ = weight.shape - return jnp.zeros(shape, dtype=jnp.float64) - - -def trunc_init(key: jax.random.PRNGKey, shape) -> Float[Array, "no ni"]: - """ - :param weight: current weight array for sizing - :param key: rng key - :return: A new set of weights - """ - stddev = jnp.sqrt(1 / shape[0]) - return stddev * jax.random.truncated_normal( - key, shape=shape, lower=-2, upper=2 - ) - - -def init_linear_weight( - model: eqx.Module, init_fn: Callable, key: jax.random.PRNGKey -) -> eqx.Module: - """ - :param model: equinox model - :param init_fn: function to initialize weigth with - :param key: rng key - :return: a new equinox model - """ - # is_linear = lambda x: isinstance(x, eqx.nn.Linear) - def is_linear(x): - return isinstance(x, eqx.nn.Linear) - - # get_weights = lambda m: [ - # x.weight - # for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear) - # if is_linear(x) - # ] - def get_weights(m): - return [ - x.weight - for x in jax.tree_util.tree_leaves(m, is_leaf, is_linear) - if is_linear(x) - ] - - weights = get_weights(model) - new_weights = [ - init_fn(subkey, weight.shape) - for subkey, weight in zip(jax.random.split(key, len(weights)), weights) - ] - new_model = eqx.tree_at(get_weights, model, new_weights) - return new_model - - -def init_linear(model: eqx.Module, init_fn: Callable, key: jax.random.PRNGKey): - """ - :param model: equinox model - :param init_fn: function to initialize weigth with - :param key: rng key - :return: a new equinox model - """ - def is_linear(x): - return isinstance(x, eqx.nn.Linear) - - # def get_biases(m): - # return [ - # x.bias - # for x in jax.tree_util.tree_leaves(m, is_leaf, is_linear) - # if is_linear(x) - # ] - - # def get_weights(m): - # return [ - # x.weight - # for x in jax.tree_util.tree_leaves(m, is_leaf, is_linear) - # if is_linear(x) - # ] - - def get_layers(m): - return [ - x - for x in jax.tree_util.tree_leaves(m, is_leaf, is_linear) - if is_linear(x) - ] - - layers = get_layers(model) - # weights = get_weights(model) - # biases = get_biases(model) - new_layers = [ - init_fn(layer, subkey) - for layer, subkey in zip(layers, jax.random.split(key, len(layers))) - ] - new_weights = [x[0] for x in new_layers] - new_biases = [x[1] for x in new_layers] - new_model = eqx.tree_at(get_weights, model, new_weights) - new_model = eqx.tree_at(get_biases, model, new_biases) - return new_model - - -def box_init(layer: eqx.nn.Linear, key: jax.random.PRNGKey): - in_size = layer.in_features - out_size = layer.out_features - k1, k2 = jax.random.split(key, 2) - p = jax.random.uniform(k1, (out_size, in_size)) - n = jax.random.normal(k2, (out_size, in_size)) - - # normalize normals - # for i in range(n.shape[0]): - # n = n.at[i, :].set(n[i, :] / jnp.linalg.norm(n[i, :])) - - # calculate p_max - # p_max = jnp.max(jnp.sign(n)) - # p_max = jnp.max(jnp.array([0.0, p_max])) - - # setup A and b one vector at a time - A = jnp.zeros((out_size, in_size)) - b = jnp.zeros((out_size,)) - for i in range(n.shape[0]): - p_temp = p[i, :] - n_temp = n[i, :] / jnp.linalg.norm(n[i, :]) - p_max = jnp.max(jnp.array([0.0, jnp.max(jnp.sign(n_temp))])) - k = 1.0 / jnp.sum((p_max - p_temp) * n_temp) - A = A.at[i, :].set(k * n_temp) - b = b.at[i].set(k * jnp.dot(p_temp, n_temp)) - - # k = jnp.zeros((n.shape[0],)) - # for i in range(n.shape[0]): - # k = k.at[i].set(1. / jnp.sum((p_max - p[i, :]) * n[i, :])) - - # A = jax.vmap(lambda k, n: k * n, in_axes=(0, 0))(k, n) - # b = k * jax.vmap(lambda x: jnp.sum(x), in_axes=(1,))(n @ p.T) - # print(A) - # print(b) - # assert False - return A, b diff --git a/pancax/networks/mlp.py b/pancax/networks/mlp.py index f816d7d..bf0a2da 100644 --- a/pancax/networks/mlp.py +++ b/pancax/networks/mlp.py @@ -1,4 +1,4 @@ -from .initialization import trunc_init +from .base import AbstractPancaxModel from typing import Callable from typing import Optional import equinox as eqx @@ -31,8 +31,8 @@ def MLPBasis( n_neurons: int, n_layers: int, activation: Callable, - key: jax.random.PRNGKey, - init_func: Optional[Callable] = trunc_init, + key: jax.random.PRNGKey + # init_func: Optional[Callable] = trunc_init, ): return MLP( n_inputs, @@ -41,12 +41,12 @@ def MLPBasis( n_layers, activation=activation, use_final_bias=True, - key=key, - init_func=init_func, + key=key + # init_func=init_func, ) -class MLP(eqx.Module): +class MLP(AbstractPancaxModel): mlp: eqx.Module def __init__( @@ -58,7 +58,7 @@ def __init__( activation: Callable, key: jax.random.PRNGKey, use_final_bias: Optional[bool] = False, - init_func: Optional[Callable] = trunc_init, + # init_func: Optional[Callable] = trunc_init, ensure_positivity: Optional[bool] = False, ): if ensure_positivity: diff --git a/pancax/networks/parameters.py b/pancax/networks/parameters.py index 02402ff..d771eec 100644 --- a/pancax/networks/parameters.py +++ b/pancax/networks/parameters.py @@ -1,4 +1,4 @@ -from .base import BasePancaxModel +from .base import AbstractPancaxModel from .fields import Field from .mlp import MLP from ..domains import VariationalDomain @@ -13,7 +13,7 @@ State = Union[Float[Array, "nt ne nq ns"], eqx.Module, None] -class Parameters(BasePancaxModel): +class Parameters(AbstractPancaxModel): """ Data structure for storing all parameters needed for a model diff --git a/scripts/rocm-docker.sh b/scripts/rocm-docker.sh index c744906..7d381ae 100755 --- a/scripts/rocm-docker.sh +++ b/scripts/rocm-docker.sh @@ -1,6 +1,15 @@ -docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \ ---group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/home/temp_user/pancax \ ---name rocm_jax rocm/jax-community:rocm6.2.3-jax0.4.33-py3.12.6 /bin/bash +docker run -it \ + --network=host \ + --device=/dev/kfd \ + --device=/dev/dri \ + --ipc=host \ + --shm-size 64G \ + --group-add video --cap-add=SYS_PTRACE \ + --security-opt \ + seccomp=unconfined \ + -v $(pwd):/home/temp_user/pancax \ +pancax /bin/bash +# --name rocm_jax rocm/jax-community:rocm6.2.3-jax0.4.33-py3.12.6 /bin/bash -docker attach rocm_jax +#docker attach rocm_jax diff --git a/test/constitutive_models/test_base_constitutive_model.py b/test/constitutive_models/test_base_constitutive_model.py index bb1b0e1..d9a5feb 100644 --- a/test/constitutive_models/test_base_constitutive_model.py +++ b/test/constitutive_models/test_base_constitutive_model.py @@ -5,12 +5,9 @@ def test_jacobian(): from pancax import NeoHookean import jax.numpy as jnp + model = NeoHookean(bulk_modulus=K, shear_modulus=G) - F = jnp.array([ - [4., 0., 0.], - [0., 2., 0.], - [0., 0., 1.] - ]) + F = jnp.array([[4.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]) grad_u = F - jnp.eye(3) J = model.jacobian(grad_u) assert jnp.array_equal(J, jnp.linalg.det(F)) @@ -24,12 +21,9 @@ def test_jacobian(): def test_jacobian_bad_value(): from pancax import NeoHookean import jax.numpy as jnp + model = NeoHookean(bulk_modulus=K, shear_modulus=G) - F = jnp.array([ - [4., 0., 0.], - [0., 2., 0.], - [0., 0., -1.] - ]) + F = jnp.array([[4.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, -1.0]]) grad_u = F - jnp.eye(3) J = model.jacobian(grad_u) - assert jnp.array_equal(J, 1.e3) + assert jnp.array_equal(J, 1.0e3) diff --git a/test/constitutive_models/test_gent.py b/test/constitutive_models/test_gent.py index 36c7f0e..500e16c 100644 --- a/test/constitutive_models/test_gent.py +++ b/test/constitutive_models/test_gent.py @@ -3,127 +3,136 @@ K = 0.833 G = 0.3846 -Jm = 3. +Jm = 3.0 @pytest.fixture def gent_1(): - from pancax import Gent - return Gent( - bulk_modulus=K, - shear_modulus=G, - Jm_parameter=Jm - ) + from pancax import Gent + + return Gent(bulk_modulus=K, shear_modulus=G, Jm_parameter=Jm) @pytest.fixture def gent_2(): - from pancax import BoundedProperty, Gent - import jax - key = jax.random.key(0) - return Gent( - bulk_modulus=BoundedProperty(K, K, key), - shear_modulus=BoundedProperty(G, G, key), - Jm_parameter=BoundedProperty(Jm, Jm, key) - ) + from pancax import BoundedProperty, Gent + import jax + + key = jax.random.key(0) + return Gent( + bulk_modulus=BoundedProperty(K, K, key), + shear_modulus=BoundedProperty(G, G, key), + Jm_parameter=BoundedProperty(Jm, Jm, key), + ) def simple_shear_test(model): - from .utils import simple_shear - import jax - import jax.numpy as jnp - theta = 0. - state_old = jnp.zeros((100, 0)) - dt = 1. - gammas = jnp.linspace(0.0, 1., 100) - Fs = jax.vmap(simple_shear)(gammas) - grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) - Js = jax.vmap(model.jacobian)(grad_us) - I1_bars = jax.vmap(model.I1_bar)(grad_us) - psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - - def vmap_func(gamma, I1_bar, J): - psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \ - -0.5 * G * Jm * jnp.log(1. - (I1_bar - 3.) / Jm) - sigma_11_an = 2. / 3. * G * Jm * gamma**2 / (Jm - gamma**2) - sigma_22_an = -1. / 3. * G * Jm * gamma**2 / (Jm - gamma**2) - sigma_12_an = G * Jm * gamma / (Jm - gamma**2) - return psi_an, sigma_11_an, sigma_22_an, sigma_12_an - - psi_ans, sigma_11_ans, sigma_22_ans, sigma_12_ans = jax.vmap( - vmap_func, in_axes=(0, 0, 0) - )(gammas, I1_bars, Js) - - assert jnp.allclose(psis, psi_ans) - assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans) - assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans) - assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans) - # - assert jnp.allclose(sigmas[:, 0, 1], sigma_12_ans) - assert jnp.allclose(sigmas[:, 1, 2], 0.0) - assert jnp.allclose(sigmas[:, 2, 0], 0.0) - # # - assert jnp.allclose(sigmas[:, 1, 0], sigma_12_ans) - assert jnp.allclose(sigmas[:, 2, 1], 0.0) - assert jnp.allclose(sigmas[:, 0, 2], 0.0) + from .utils import simple_shear + import jax + import jax.numpy as jnp + + theta = 0.0 + state_old = jnp.zeros((100, 0)) + dt = 1.0 + gammas = jnp.linspace(0.0, 1.0, 100) + Fs = jax.vmap(simple_shear)(gammas) + grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) + Js = jax.vmap(model.jacobian)(grad_us) + I1_bars = jax.vmap(model.I1_bar)(grad_us) + psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + + def vmap_func(gamma, I1_bar, J): + psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \ + -0.5 * G * Jm * jnp.log( + 1.0 - (I1_bar - 3.0) / Jm + ) + sigma_11_an = 2.0 / 3.0 * G * Jm * gamma**2 / (Jm - gamma**2) + sigma_22_an = -1.0 / 3.0 * G * Jm * gamma**2 / (Jm - gamma**2) + sigma_12_an = G * Jm * gamma / (Jm - gamma**2) + return psi_an, sigma_11_an, sigma_22_an, sigma_12_an + + psi_ans, sigma_11_ans, sigma_22_ans, sigma_12_ans = jax.vmap( + vmap_func, in_axes=(0, 0, 0) + )(gammas, I1_bars, Js) + + assert jnp.allclose(psis, psi_ans) + assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans) + assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans) + assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans) + # + assert jnp.allclose(sigmas[:, 0, 1], sigma_12_ans) + assert jnp.allclose(sigmas[:, 1, 2], 0.0) + assert jnp.allclose(sigmas[:, 2, 0], 0.0) + # # + assert jnp.allclose(sigmas[:, 1, 0], sigma_12_ans) + assert jnp.allclose(sigmas[:, 2, 1], 0.0) + assert jnp.allclose(sigmas[:, 0, 2], 0.0) def uniaxial_strain_test(model): - from .utils import uniaxial_strain - import jax - import jax.numpy as jnp - theta = 0. - state_old = jnp.zeros((100, 0)) - dt = 1. - lambdas = jnp.linspace(1., 2., 100) - Fs = jax.vmap(uniaxial_strain)(lambdas) - grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) - Js = jax.vmap(model.jacobian)(grad_us) - I1_bars = jax.vmap(model.I1_bar)(grad_us) - psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - - def vmap_func(lambda_, I1_bar, J): - psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \ - -0.5 * G * Jm * jnp.log(1. - (I1_bar - 3.) / Jm) - sigma_11_an = 0.5 * K * (lambda_ - 1. / lambda_) - \ - 2. / 3. * G * Jm * (lambda_**2 - 1.) / (lambda_**3 - (Jm + 3) * lambda_**(5. / 3.) + 2. * lambda_) - sigma_22_an = 0.5 * K * (lambda_ - 1. / lambda_) + \ - 1. / 3. * G * Jm * (lambda_**2 - 1.) / (lambda_**3 - (Jm + 3) * lambda_**(5. / 3.) + 2. * lambda_) - return psi_an, sigma_11_an, sigma_22_an - - psi_ans, sigma_11_ans, sigma_22_ans = jax.vmap( - vmap_func, in_axes=(0, 0, 0) - )(lambdas, I1_bars, Js) - - assert jnp.allclose(psis, psi_ans) - assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans) - assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans) - assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans) - # - assert jnp.allclose(sigmas[:, 0, 1], 0.0) - assert jnp.allclose(sigmas[:, 1, 2], 0.0) - assert jnp.allclose(sigmas[:, 2, 0], 0.0) - # - assert jnp.allclose(sigmas[:, 1, 0], 0.0) - assert jnp.allclose(sigmas[:, 2, 1], 0.0) - assert jnp.allclose(sigmas[:, 0, 2], 0.0) + from .utils import uniaxial_strain + import jax + import jax.numpy as jnp + + theta = 0.0 + state_old = jnp.zeros((100, 0)) + dt = 1.0 + lambdas = jnp.linspace(1.0, 2.0, 100) + Fs = jax.vmap(uniaxial_strain)(lambdas) + grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) + Js = jax.vmap(model.jacobian)(grad_us) + I1_bars = jax.vmap(model.I1_bar)(grad_us) + psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + + def vmap_func(lambda_, I1_bar, J): + psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \ + -0.5 * G * Jm * jnp.log( + 1.0 - (I1_bar - 3.0) / Jm + ) + sigma_11_an = 0.5 * K * (lambda_ - 1.0 / lambda_) - \ + 2.0 / 3.0 * G * Jm * ( + lambda_**2 - 1.0 + ) / (lambda_**3 - (Jm + 3) * lambda_ ** (5.0 / 3.0) + 2.0 * lambda_) + sigma_22_an = 0.5 * K * (lambda_ - 1.0 / lambda_) + \ + 1.0 / 3.0 * G * Jm * ( + lambda_**2 - 1.0 + ) / (lambda_**3 - (Jm + 3) * lambda_ ** (5.0 / 3.0) + 2.0 * lambda_) + return psi_an, sigma_11_an, sigma_22_an + + psi_ans, sigma_11_ans, sigma_22_ans = jax.vmap( + vmap_func, in_axes=(0, 0, 0))( + lambdas, I1_bars, Js + ) + + assert jnp.allclose(psis, psi_ans) + assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans) + assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans) + assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans) + # + assert jnp.allclose(sigmas[:, 0, 1], 0.0) + assert jnp.allclose(sigmas[:, 1, 2], 0.0) + assert jnp.allclose(sigmas[:, 2, 0], 0.0) + # + assert jnp.allclose(sigmas[:, 1, 0], 0.0) + assert jnp.allclose(sigmas[:, 2, 1], 0.0) + assert jnp.allclose(sigmas[:, 0, 2], 0.0) def test_simple_shear(gent_1, gent_2): - simple_shear_test(gent_1) - simple_shear_test(gent_2) + simple_shear_test(gent_1) + simple_shear_test(gent_2) def test_uniaxial_strain(gent_1, gent_2): - uniaxial_strain_test(gent_1) - uniaxial_strain_test(gent_2) + uniaxial_strain_test(gent_1) + uniaxial_strain_test(gent_2) diff --git a/test/constitutive_models/test_hencky.py b/test/constitutive_models/test_hencky.py index 27e4a6f..f5060e0 100644 --- a/test/constitutive_models/test_hencky.py +++ b/test/constitutive_models/test_hencky.py @@ -4,89 +4,93 @@ K = 0.833 G = 0.3846 + @pytest.fixture def hencky_1(): - from pancax import Hencky - return Hencky( - bulk_modulus=K, - shear_modulus=G - ) + from pancax import Hencky + + return Hencky(bulk_modulus=K, shear_modulus=G) + @pytest.fixture def hencky_2(): - from pancax import BoundedProperty, Hencky - import jax - key = jax.random.key(0) - return Hencky( - bulk_modulus=BoundedProperty(K, K, key), - shear_modulus=BoundedProperty(G, G, key) - ) + from pancax import BoundedProperty, Hencky + import jax + + key = jax.random.key(0) + return Hencky( + bulk_modulus=BoundedProperty(K, K, key), + shear_modulus=BoundedProperty(G, G, key), + ) # TODO fix this test + def simple_shear_test(model): - from pancax.math import tensor_math - from .utils import simple_shear - import jax - import jax.numpy as jnp - theta = 0. - state_old = jnp.zeros((100, 0)) - dt = 1. - gammas = jnp.linspace(0.0, 1., 100) - Fs = jax.vmap(simple_shear)(gammas) - grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) - # Js = jax.vmap(model.jacobian)(grad_us) - # I1_bars = jax.vmap(model.I1_bar)(grad_us) - Js = jax.vmap(model.jacobian)(grad_us) - Es = jax.vmap(model.log_strain)(grad_us) - trEs = jax.vmap(jnp.trace)(Es) - - # Edevs = jax.vmap(lambda E: E - (1. / 3.) * jnp.trace(E) * jnp.eye(3))(Es) - Edevs = jax.vmap(tensor_math.dev)(Es) - psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - sigmas_an = jax.vmap( - lambda trE, devE, J: (K * trE * jnp.eye(3) + 2. * G * devE) / J, in_axes=(0, 0, 0) - )(trEs, Edevs, Js) - - # print(psis) - # print(Es[-1, :, :]) - # print(sigmas[-1, :, :]) - # print(sigmas_an[-1, :, :]) - # # print(sigmas - sigmas_an) - # assert jnp.allclose(sigmas, sigmas_an, atol=1e-8) - # assert False - # for (psi, sigma, gamma, I1_bar, J) in zip(psis, sigmas, gammas, I1_bars, Js): - # psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \ - # -0.5 * G * Jm * jnp.log(1. - (I1_bar - 3.) / Jm) - # sigma_11_an = 2. / 3. * G * Jm * gamma**2 / (Jm - gamma**2) - # sigma_22_an = -1. / 3. * G * Jm * gamma**2 / (Jm - gamma**2) - # sigma_12_an = G * Jm * gamma / (Jm - gamma**2) - # assert jnp.allclose(psi, psi_an) - # assert jnp.allclose(sigma[0, 0], sigma_11_an) - # assert jnp.allclose(sigma[1, 1], sigma_22_an) - # assert jnp.allclose(sigma[2, 2], sigma_22_an) - # # # - # assert jnp.allclose(sigma[0, 1], sigma_12_an) - # assert jnp.allclose(sigma[1, 2], 0.0) - # assert jnp.allclose(sigma[2, 0], 0.0) - # # # - # assert jnp.allclose(sigma[1, 0], sigma_12_an) - # assert jnp.allclose(sigma[2, 1], 0.0) - # assert jnp.allclose(sigma[0, 2], 0.0) + from pancax.math import tensor_math + from .utils import simple_shear + import jax + import jax.numpy as jnp + + theta = 0.0 + state_old = jnp.zeros((100, 0)) + dt = 1.0 + gammas = jnp.linspace(0.0, 1.0, 100) + Fs = jax.vmap(simple_shear)(gammas) + grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) + # Js = jax.vmap(model.jacobian)(grad_us) + # I1_bars = jax.vmap(model.I1_bar)(grad_us) + Js = jax.vmap(model.jacobian)(grad_us) + Es = jax.vmap(model.log_strain)(grad_us) + trEs = jax.vmap(jnp.trace)(Es) + + # Edevs = jax.vmap(lambda E: E - (1. / 3.) * jnp.trace(E) * jnp.eye(3))(Es) + Edevs = jax.vmap(tensor_math.dev)(Es) + psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + jax.vmap( + lambda trE, devE, J: (K * trE * jnp.eye(3) + 2.0 * G * devE) / J, + in_axes=(0, 0, 0), + )(trEs, Edevs, Js) + + # print(psis) + # print(Es[-1, :, :]) + # print(sigmas[-1, :, :]) + # print(sigmas_an[-1, :, :]) + # # print(sigmas - sigmas_an) + # assert jnp.allclose(sigmas, sigmas_an, atol=1e-8) + # assert False + # for (psi, sigma, gamma, I1_bar, J) in \ + # zip(psis, sigmas, gammas, I1_bars, Js): + # psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \ + # -0.5 * G * Jm * jnp.log(1. - (I1_bar - 3.) / Jm) + # sigma_11_an = 2. / 3. * G * Jm * gamma**2 / (Jm - gamma**2) + # sigma_22_an = -1. / 3. * G * Jm * gamma**2 / (Jm - gamma**2) + # sigma_12_an = G * Jm * gamma / (Jm - gamma**2) + # assert jnp.allclose(psi, psi_an) + # assert jnp.allclose(sigma[0, 0], sigma_11_an) + # assert jnp.allclose(sigma[1, 1], sigma_22_an) + # assert jnp.allclose(sigma[2, 2], sigma_22_an) + # # # + # assert jnp.allclose(sigma[0, 1], sigma_12_an) + # assert jnp.allclose(sigma[1, 2], 0.0) + # assert jnp.allclose(sigma[2, 0], 0.0) + # # # + # assert jnp.allclose(sigma[1, 0], sigma_12_an) + # assert jnp.allclose(sigma[2, 1], 0.0) + # assert jnp.allclose(sigma[0, 2], 0.0) def test_simple_shear(hencky_1, hencky_2): - simple_shear_test(hencky_1) - simple_shear_test(hencky_2) + simple_shear_test(hencky_1) + simple_shear_test(hencky_2) # def test_uniaxial_strain(gent_1, gent_2): # uniaxial_strain_test(gent_1) # uniaxial_strain_test(gent_2) - diff --git a/test/constitutive_models/test_hyperviscoelastic.py b/test/constitutive_models/test_hyperviscoelastic.py index 3ddb638..5f71e4a 100644 --- a/test/constitutive_models/test_hyperviscoelastic.py +++ b/test/constitutive_models/test_hyperviscoelastic.py @@ -3,174 +3,212 @@ @pytest.fixture def model(): - from pancax import NeoHookean, PronySeries, SimpleFeFv, WLF - return SimpleFeFv( - NeoHookean(bulk_modulus=1000., shear_modulus=0.855), - PronySeries( - moduli=[1., 2., 3.], - relaxation_times=[1., 10., 100.] - ), - WLF(C1=17.44, C2=51.6, theta_ref=60.), - ) + from pancax import NeoHookean, PronySeries, SimpleFeFv, WLF + + return SimpleFeFv( + NeoHookean(bulk_modulus=1000.0, shear_modulus=0.855), + PronySeries( + moduli=[1.0, 2.0, 3.0], + relaxation_times=[1.0, 10.0, 100.0] + ), + WLF(C1=17.44, C2=51.6, theta_ref=60.0), + ) def test_initial_state(model): - import jax.numpy as jnp - state = model.initial_state() - print(state) - assert jnp.allclose(state, jnp.array([ - 1., 0., 0., 0., 1., 0., 0., 0., 1., - 1., 0., 0., 0., 1., 0., 0., 0., 1., - 1., 0., 0., 0., 1., 0., 0., 0., 1. - ])) + import jax.numpy as jnp + + state = model.initial_state() + print(state) + assert jnp.allclose( + state, + jnp.array( + [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + ] + ), + ) def test_extract_tensor(model): - import jax.numpy as jnp - state = jnp.linspace(1., 27., 27) - Fv_1 = model.extract_tensor(state, 0) - Fv_2 = model.extract_tensor(state, 9) - Fv_3 = model.extract_tensor(state, 18) - - assert jnp.allclose(Fv_1, jnp.array([ - [1., 2., 3.], - [4., 5., 6.], - [7., 8., 9.] - ])) - assert jnp.allclose(Fv_2, jnp.array([ - [10., 11., 12.], - [13., 14., 15.], - [16., 17., 18.] - ])) - assert jnp.allclose(Fv_3, jnp.array([ - [19., 20., 21.], - [22., 23., 24.], - [25., 26., 27.] - ])) - - Fvs = state.reshape((model.num_prony_terms(), 3, 3)) - - assert jnp.allclose(Fvs[0, :, :], jnp.array([ - [1., 2., 3.], - [4., 5., 6.], - [7., 8., 9.] - ])) - assert jnp.allclose(Fvs[1, :, :], jnp.array([ - [10., 11., 12.], - [13., 14., 15.], - [16., 17., 18.] - ])) - assert jnp.allclose(Fvs[2, :, :], jnp.array([ - [19., 20., 21.], - [22., 23., 24.], - [25., 26., 27.] - ])) + import jax.numpy as jnp + + state = jnp.linspace(1.0, 27.0, 27) + Fv_1 = model.extract_tensor(state, 0) + Fv_2 = model.extract_tensor(state, 9) + Fv_3 = model.extract_tensor(state, 18) + + assert jnp.allclose( + Fv_1, jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + ) + assert jnp.allclose( + Fv_2, jnp.array([ + [10.0, 11.0, 12.0], + [13.0, 14.0, 15.0], + [16.0, 17.0, 18.0]] + ) + ) + assert jnp.allclose( + Fv_3, jnp.array([ + [19.0, 20.0, 21.0], + [22.0, 23.0, 24.0], + [25.0, 26.0, 27.0]] + ) + ) + + Fvs = state.reshape((model.num_prony_terms(), 3, 3)) + + assert jnp.allclose( + Fvs[0, :, :], jnp.array([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0]] + ) + ) + assert jnp.allclose( + Fvs[1, :, :], + jnp.array([ + [10.0, 11.0, 12.0], + [13.0, 14.0, 15.0], + [16.0, 17.0, 18.0]] + ), + ) + assert jnp.allclose( + Fvs[2, :, :], + jnp.array([ + [19.0, 20.0, 21.0], + [22.0, 23.0, 24.0], + [25.0, 26.0, 27.0]] + ), + ) def test_model(model): - from .utils import uniaxial_strain - import jax - import jax.numpy as jnp - - strain_rate = 1.e-2 - total_time = 100.0 - n_steps = 100 - times = jnp.linspace(0., total_time, n_steps) - # F = uniaxial_strain(1.1) - Fs = jax.vmap(lambda t: uniaxial_strain(jnp.exp(strain_rate * t)))(times) - grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) - # grad_u = F - jnp.eye(3) - theta = model.shift_factor_model.theta_ref - state_old = model.initial_state() - dt = total_time / n_steps - - # psi, state_new = model.energy(grad_u, theta, state_old, dt) - # print(psi) - # print(state_new) - # print(model) - - energies = jnp.zeros(n_steps) - states = jnp.zeros((n_steps, len(state_old))) - - energy_func = jax.jit(model.energy) - log_strain_func = jax.jit(model.log_strain) - - for n, grad_u in enumerate(grad_us): - psi, state_new = energy_func(grad_u, theta, state_old, dt) - state_old = state_new - energies = energies.at[n].set(psi) - states = states.at[n, :].set(state_new) - - print(energies) - print(states) - - # plt.figure(1) - # plt.plot(times, states[:, 0]) - # plt.savefig('state_evolution.png') - - Gs = model.prony_series.moduli - taus = model.prony_series.relaxation_times - etas = Gs * taus - - for n in range(3): - Fvs = jax.vmap(lambda Fv: Fv.at[9 * n:9 * (n + 1)].get().reshape((3, 3)))(states) - Fes = jax.vmap(lambda F, Fv: F @ jnp.linalg.inv(Fv), in_axes=(0, 0))(Fs, Fvs) - grad_uvs = jax.vmap(lambda F: F - jnp.eye(3))(Fvs) - grad_ues = jax.vmap(lambda F: F - jnp.eye(3))(Fes) - - Evs = jax.vmap(log_strain_func)(grad_uvs) - Ees = jax.vmap(log_strain_func)(grad_ues) - - # analytic solution - e_v_11 = (2. / 3.) * strain_rate * times - \ - (2. / 3.) * strain_rate * taus[n] * (1. - jnp.exp(-times / taus[n])) - - e_e_11 = strain_rate * times - e_v_11 - e_e_22 = 0.5 * e_v_11 - - # Ee_analytic = jax.vmap( - # lambda e_11, e_22: np.array( - # [[e_11, 0., 0.], - # [0., e_22, 0.], - # [0., 0., e_22]] - # ), in_axes=(0, 0) - # )(e_e_11, e_e_22) - - # Me_analytic = jax.vmap(lambda Ee: 2. * Gs[n] * deviator(Ee))(Ee_analytic) - # Dv_analytic = jax.vmap(lambda Me: 1. / (2. * self.etas[n]) * deviator(Me))(Me_analytic) - # dissipated_energies_analytic += jax.vmap(lambda Dv: dt * self.etas[n] * np.tensordot(deviator(Dv), deviator(Dv)) )(Dv_analytic) - - # test - assert jnp.isclose(Evs[:, 0, 0], e_v_11, atol=2.5e-3).all() - assert jnp.isclose(Ees[:, 0, 0], e_e_11, atol=2.5e-3).all() - assert jnp.isclose(Ees[:, 1, 1], e_e_22, atol=2.5e-3).all() + from .utils import uniaxial_strain + import jax + import jax.numpy as jnp + + strain_rate = 1.0e-2 + total_time = 100.0 + n_steps = 100 + times = jnp.linspace(0.0, total_time, n_steps) + # F = uniaxial_strain(1.1) + Fs = jax.vmap(lambda t: uniaxial_strain(jnp.exp(strain_rate * t)))(times) + grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) + # grad_u = F - jnp.eye(3) + theta = model.shift_factor_model.theta_ref + state_old = model.initial_state() + dt = total_time / n_steps + + # psi, state_new = model.energy(grad_u, theta, state_old, dt) + # print(psi) + # print(state_new) + # print(model) + + energies = jnp.zeros(n_steps) + states = jnp.zeros((n_steps, len(state_old))) + + # energy_func = jax.jit(model.energy) + # log_strain_func = jax.jit(model.log_strain) + energy_func = model.energy + log_strain_func = model.log_strain + + for n, grad_u in enumerate(grad_us): + psi, state_new = energy_func(grad_u, theta, state_old, dt) + state_old = state_new + energies = energies.at[n].set(psi) + states = states.at[n, :].set(state_new) + + print(energies) + print(states) + + # plt.figure(1) + # plt.plot(times, states[:, 0]) + # plt.savefig('state_evolution.png') + + taus = model.prony_series.relaxation_times + + for n in range(3): + Fvs = jax.vmap( + lambda Fv: Fv.at[9 * n:9 * (n + 1)].get().reshape((3, 3)))( + states + ) + Fes = jax.vmap( + lambda F, Fv: F @ jnp.linalg.inv(Fv), in_axes=(0, 0) + )(Fs, Fvs) + grad_uvs = jax.vmap(lambda F: F - jnp.eye(3))(Fvs) + grad_ues = jax.vmap(lambda F: F - jnp.eye(3))(Fes) + + Evs = jax.vmap(log_strain_func)(grad_uvs) + Ees = jax.vmap(log_strain_func)(grad_ues) + + # analytic solution + e_v_11 = (2.0 / 3.0) * strain_rate * times - \ + (2.0 / 3.0) * strain_rate * taus[ + n + ] * (1.0 - jnp.exp(-times / taus[n])) + + e_e_11 = strain_rate * times - e_v_11 + e_e_22 = 0.5 * e_v_11 + + # test + assert jnp.isclose(Evs[:, 0, 0], e_v_11, atol=2.5e-3).all() + assert jnp.isclose(Ees[:, 0, 0], e_e_11, atol=2.5e-3).all() + assert jnp.isclose(Ees[:, 1, 1], e_e_22, atol=2.5e-3).all() # NOTE this test is dumb... it's just checking vmap capabilities def test_with_vmap(model): - from .utils import uniaxial_strain - import jax - import jax.numpy as jnp - strain_rate = 1.e-2 - total_time = 100.0 - n_steps = 100 - times = jnp.linspace(0., total_time, n_steps) - # F = uniaxial_strain(1.1) - Fs = jax.vmap(lambda t: uniaxial_strain(jnp.exp(strain_rate * t)))(times) - grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) - # grad_u = F - jnp.eye(3) - theta = model.shift_factor_model.theta_ref - state_old = model.initial_state() - states_old = jnp.tile(state_old, (n_steps, 1)) - dt = total_time / n_steps - - energies = jnp.zeros(n_steps) - states = jnp.zeros((n_steps, len(state_old))) - - energy_func = jax.jit(jax.vmap(model.energy, in_axes=(0, None, 0, None))) - log_strain_func = jax.jit(model.log_strain) - - psis, states_new = energy_func(grad_us, theta, states_old, dt) - print(psis.shape) - print(states_new.shape) - # assert False + from .utils import uniaxial_strain + import jax + import jax.numpy as jnp + + strain_rate = 1.0e-2 + total_time = 100.0 + n_steps = 100 + times = jnp.linspace(0.0, total_time, n_steps) + # F = uniaxial_strain(1.1) + Fs = jax.vmap(lambda t: uniaxial_strain(jnp.exp(strain_rate * t)))(times) + grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) + # grad_u = F - jnp.eye(3) + theta = model.shift_factor_model.theta_ref + state_old = model.initial_state() + states_old = jnp.tile(state_old, (n_steps, 1)) + dt = total_time / n_steps + + # energies = jnp.zeros(n_steps) + # states = jnp.zeros((n_steps, len(state_old))) + + energy_func = jax.jit(jax.vmap(model.energy, in_axes=(0, None, 0, None))) + # log_strain_func = jax.jit(model.log_strain) + + psis, states_new = energy_func(grad_us, theta, states_old, dt) + print(psis.shape) + print(states_new.shape) + # assert False diff --git a/test/constitutive_models/test_neohookean.py b/test/constitutive_models/test_neohookean.py index 3d5943b..29ad43a 100644 --- a/test/constitutive_models/test_neohookean.py +++ b/test/constitutive_models/test_neohookean.py @@ -7,120 +7,124 @@ @pytest.fixture def neohookean_1(): - from pancax import NeoHookean - return NeoHookean( - bulk_modulus=K, - shear_modulus=G - ) + from pancax import NeoHookean + + return NeoHookean(bulk_modulus=K, shear_modulus=G) @pytest.fixture def neohookean_2(): - from pancax import BoundedProperty, NeoHookean - import jax - key = jax.random.key(0) - return NeoHookean( - bulk_modulus=BoundedProperty(K, K, key), - shear_modulus=BoundedProperty(G, G, key) - ) + from pancax import BoundedProperty, NeoHookean + import jax + + key = jax.random.key(0) + return NeoHookean( + bulk_modulus=BoundedProperty(K, K, key), + shear_modulus=BoundedProperty(G, G, key), + ) def simple_shear_test(model): - from .utils import simple_shear - import jax - import jax.numpy as jnp - theta = 0. - state_old = jnp.zeros((100, 0)) - dt = 1. - gammas = jnp.linspace(0.0, 1., 100) - Fs = jax.vmap(simple_shear)(gammas) - grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) - Js = jax.vmap(model.jacobian)(grad_us) - I1_bars = jax.vmap(model.I1_bar)(grad_us) - psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - - def vmap_func(gamma, I1_bar, J): - psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \ - 0.5 * G * (I1_bar - 3.) - sigma_11_an = 2. / 3. * G * gamma**2 - sigma_22_an = -1. / 3. * G * gamma**2 - sigma_12_an = G * gamma - return psi_an, sigma_11_an, sigma_22_an, sigma_12_an - - psi_ans, sigma_11_ans, sigma_22_ans, sigma_12_ans = jax.vmap( - vmap_func, in_axes=(0, 0, 0) - )(gammas, I1_bars, Js) - - assert jnp.allclose(psis, psi_ans) - assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans) - assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans) - assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans) - # - assert jnp.allclose(sigmas[:, 0, 1], sigma_12_ans) - assert jnp.allclose(sigmas[:, 1, 2], 0.0) - assert jnp.allclose(sigmas[:, 2, 0], 0.0) - # # - assert jnp.allclose(sigmas[:, 1, 0], sigma_12_ans) - assert jnp.allclose(sigmas[:, 2, 1], 0.0) - assert jnp.allclose(sigmas[:, 0, 2], 0.0) + from .utils import simple_shear + import jax + import jax.numpy as jnp + + theta = 0.0 + state_old = jnp.zeros((100, 0)) + dt = 1.0 + gammas = jnp.linspace(0.0, 1.0, 100) + Fs = jax.vmap(simple_shear)(gammas) + grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) + Js = jax.vmap(model.jacobian)(grad_us) + I1_bars = jax.vmap(model.I1_bar)(grad_us) + psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + + def vmap_func(gamma, I1_bar, J): + psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) +\ + 0.5 * G * (I1_bar - 3.0) + sigma_11_an = 2.0 / 3.0 * G * gamma**2 + sigma_22_an = -1.0 / 3.0 * G * gamma**2 + sigma_12_an = G * gamma + return psi_an, sigma_11_an, sigma_22_an, sigma_12_an + + psi_ans, sigma_11_ans, sigma_22_ans, sigma_12_ans = jax.vmap( + vmap_func, in_axes=(0, 0, 0) + )(gammas, I1_bars, Js) + + assert jnp.allclose(psis, psi_ans) + assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans) + assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans) + assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans) + # + assert jnp.allclose(sigmas[:, 0, 1], sigma_12_ans) + assert jnp.allclose(sigmas[:, 1, 2], 0.0) + assert jnp.allclose(sigmas[:, 2, 0], 0.0) + # # + assert jnp.allclose(sigmas[:, 1, 0], sigma_12_ans) + assert jnp.allclose(sigmas[:, 2, 1], 0.0) + assert jnp.allclose(sigmas[:, 0, 2], 0.0) def uniaxial_strain_test(model): - from .utils import uniaxial_strain - import jax - import jax.numpy as jnp - theta = 0. - state_old = jnp.zeros((100, 0)) - dt = 1. - lambdas = jnp.linspace(1., 4., 100) - Fs = jax.vmap(uniaxial_strain)(lambdas) - grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) - Js = jax.vmap(model.jacobian)(grad_us) - I1_bars = jax.vmap(model.I1_bar)(grad_us) - psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - - def vmap_func(lambda_, I1_bar, J): - psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \ - 0.5 * G * (I1_bar - 3.) - sigma_11_an = 0.5 * K * (lambda_ - 1. / lambda_) + \ - 2. / 3. * G * (lambda_**2 - 1.) * lambda_**(-5. / 3.) - sigma_22_an = 0.5 * K * (lambda_ - 1. / lambda_) - \ - 1. / 3. * G * (lambda_**2 - 1.) * lambda_**(-5. / 3.) - return psi_an, sigma_11_an, sigma_22_an - - psi_ans, sigma_11_ans, sigma_22_ans = jax.vmap( - vmap_func, in_axes=(0, 0, 0) - )(lambdas, I1_bars, Js) - - assert jnp.allclose(psis, psi_ans) - assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans) - assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans) - assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans) - # - assert jnp.allclose(sigmas[:, 0, 1], 0.0) - assert jnp.allclose(sigmas[:, 1, 2], 0.0) - assert jnp.allclose(sigmas[:, 2, 0], 0.0) - # - assert jnp.allclose(sigmas[:, 1, 0], 0.0) - assert jnp.allclose(sigmas[:, 2, 1], 0.0) - assert jnp.allclose(sigmas[:, 0, 2], 0.0) + from .utils import uniaxial_strain + import jax + import jax.numpy as jnp + + theta = 0.0 + state_old = jnp.zeros((100, 0)) + dt = 1.0 + lambdas = jnp.linspace(1.0, 4.0, 100) + Fs = jax.vmap(uniaxial_strain)(lambdas) + grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) + Js = jax.vmap(model.jacobian)(grad_us) + I1_bars = jax.vmap(model.I1_bar)(grad_us) + psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + + def vmap_func(lambda_, I1_bar, J): + psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) +\ + 0.5 * G * (I1_bar - 3.0) + sigma_11_an = 0.5 * K * (lambda_ - 1.0 / lambda_) + 2.0 / 3.0 * G * ( + lambda_**2 - 1.0 + ) * lambda_ ** (-5.0 / 3.0) + sigma_22_an = 0.5 * K * (lambda_ - 1.0 / lambda_) - 1.0 / 3.0 * G * ( + lambda_**2 - 1.0 + ) * lambda_ ** (-5.0 / 3.0) + return psi_an, sigma_11_an, sigma_22_an + + psi_ans, sigma_11_ans, sigma_22_ans = jax.vmap( + vmap_func, in_axes=(0, 0, 0))( + lambdas, I1_bars, Js + ) + + assert jnp.allclose(psis, psi_ans) + assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans) + assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans) + assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans) + # + assert jnp.allclose(sigmas[:, 0, 1], 0.0) + assert jnp.allclose(sigmas[:, 1, 2], 0.0) + assert jnp.allclose(sigmas[:, 2, 0], 0.0) + # + assert jnp.allclose(sigmas[:, 1, 0], 0.0) + assert jnp.allclose(sigmas[:, 2, 1], 0.0) + assert jnp.allclose(sigmas[:, 0, 2], 0.0) def test_simple_shear(neohookean_1, neohookean_2): - simple_shear_test(neohookean_1) - simple_shear_test(neohookean_2) + simple_shear_test(neohookean_1) + simple_shear_test(neohookean_2) def test_uniaxial_strain(neohookean_1, neohookean_2): - uniaxial_strain_test(neohookean_1) - uniaxial_strain_test(neohookean_2) + uniaxial_strain_test(neohookean_1) + uniaxial_strain_test(neohookean_2) diff --git a/test/constitutive_models/test_swanson.py b/test/constitutive_models/test_swanson.py index 579bc2e..9346285 100644 --- a/test/constitutive_models/test_swanson.py +++ b/test/constitutive_models/test_swanson.py @@ -1,11 +1,11 @@ # from pancax import BoundedProperty, Swanson # from .utils import * -# import jax +# import jax # import jax.numpy as jnp import pytest -K = 10. +K = 10.0 A1 = 0.93074321 P1 = -0.07673672 B1 = 0.0 @@ -16,138 +16,150 @@ @pytest.fixture def swanson_1(): - from pancax import Swanson - return Swanson( - bulk_modulus=K, - A1=A1, - P1=P1, - B1=B1, - Q1=Q1, - C1=C1, - R1=R1, - cutoff_strain=0.01 - ) + from pancax import Swanson + + return Swanson( + bulk_modulus=K, A1=A1, P1=P1, B1=B1, Q1=Q1, C1=C1, R1=R1, + cutoff_strain=0.01 + ) @pytest.fixture def swanson_2(): - from pancax import BoundedProperty, Swanson - import jax - key = jax.random.key(0) - return Swanson( - bulk_modulus=BoundedProperty(K, K, key), - A1=BoundedProperty(A1, A1, key), - P1=BoundedProperty(P1, P1, key), - B1=BoundedProperty(B1, B1, key), - Q1=BoundedProperty(Q1, Q1, key), - C1=BoundedProperty(C1, C1, key), - R1=BoundedProperty(R1, R1, key), - cutoff_strain=0.01 - ) + from pancax import BoundedProperty, Swanson + import jax + + key = jax.random.key(0) + return Swanson( + bulk_modulus=BoundedProperty(K, K, key), + A1=BoundedProperty(A1, A1, key), + P1=BoundedProperty(P1, P1, key), + B1=BoundedProperty(B1, B1, key), + Q1=BoundedProperty(Q1, Q1, key), + C1=BoundedProperty(C1, C1, key), + R1=BoundedProperty(R1, R1, key), + cutoff_strain=0.01, + ) def simple_shear_test(model): - from .utils import simple_shear - import jax - import jax.numpy as jnp - theta = 0. - state_old = jnp.zeros((100, 0)) - dt = 1. - gammas = jnp.linspace(0.05, 1., 100) - Fs = jax.vmap(simple_shear)(gammas) - grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) - B_bars = jax.vmap(lambda F: jnp.power(jnp.linalg.det(F), -2. / 3.) * F @ F.T)(Fs) - Js = jax.vmap(model.jacobian)(grad_us) - I1_bars = jax.vmap(model.I1_bar)(grad_us) - psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - - def vmap_func(B_bar, I1_bar, J): - psi_an = K * (J * jnp.log(J) - J + 1.) + \ - 1.5 * A1 / (P1 + 1.) * (I1_bar / 3. - 1.)**(P1 + 1.) + \ - 1.5 * C1 / (R1 + 1.) * (I1_bar / 3. - 1.)**(R1 + 1.) - dUdI1_bar = 0.5 * A1 * (I1_bar / 3. - 1.)**P1 + \ - 0.5 * C1 * (I1_bar / 3. - 1.)**R1 - B_bar_dev = B_bar - (1. / 3.) * jnp.trace(B_bar) * jnp.eye(3) - sigma_an = (2. / J) * dUdI1_bar * B_bar_dev + K * jnp.log(J) * jnp.eye(3) - return psi_an, sigma_an[0, 0], sigma_an[1, 1], sigma_an[0, 1] - - psi_ans, sigma_11_ans, sigma_22_ans, sigma_12_ans = jax.vmap( - vmap_func, in_axes=(0, 0, 0) - )(B_bars, I1_bars, Js) - - assert jnp.allclose(psis, psi_ans, atol=1e-3) - assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans, atol=1e-3) - assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans, atol=1e-3) - assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans, atol=1e-3) - # - assert jnp.allclose(sigmas[:, 0, 1], sigma_12_ans, atol=1e-3) - assert jnp.allclose(sigmas[:, 1, 2], 0.0) - assert jnp.allclose(sigmas[:, 2, 0], 0.0) - # # - assert jnp.allclose(sigmas[:, 1, 0], sigma_12_ans, atol=1e-3) - assert jnp.allclose(sigmas[:, 2, 1], 0.0) - assert jnp.allclose(sigmas[:, 0, 2], 0.0) + from .utils import simple_shear + import jax + import jax.numpy as jnp + + theta = 0.0 + state_old = jnp.zeros((100, 0)) + dt = 1.0 + gammas = jnp.linspace(0.05, 1.0, 100) + Fs = jax.vmap(simple_shear)(gammas) + grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) + B_bars = jax.vmap( + lambda F: jnp.power(jnp.linalg.det(F), -2.0 / 3.0) * F @ F.T + )(Fs) + Js = jax.vmap(model.jacobian)(grad_us) + I1_bars = jax.vmap(model.I1_bar)(grad_us) + psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + + def vmap_func(B_bar, I1_bar, J): + psi_an = ( + K * (J * jnp.log(J) - J + 1.0) + + 1.5 * A1 / (P1 + 1.0) * (I1_bar / 3.0 - 1.0) ** (P1 + 1.0) + + 1.5 * C1 / (R1 + 1.0) * (I1_bar / 3.0 - 1.0) ** (R1 + 1.0) + ) + dUdI1_bar = ( + 0.5 * A1 * (I1_bar / 3.0 - 1.0) ** P1 + + 0.5 * C1 * (I1_bar / 3.0 - 1.0) ** R1 + ) + B_bar_dev = B_bar - (1.0 / 3.0) * jnp.trace(B_bar) * jnp.eye(3) + sigma_an = (2.0 / J) * dUdI1_bar * B_bar_dev + \ + K * jnp.log(J) * jnp.eye(3) + return psi_an, sigma_an[0, 0], sigma_an[1, 1], sigma_an[0, 1] + + psi_ans, sigma_11_ans, sigma_22_ans, sigma_12_ans = jax.vmap( + vmap_func, in_axes=(0, 0, 0) + )(B_bars, I1_bars, Js) + + assert jnp.allclose(psis, psi_ans, atol=1e-3) + assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans, atol=1e-3) + assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans, atol=1e-3) + assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans, atol=1e-3) + # + assert jnp.allclose(sigmas[:, 0, 1], sigma_12_ans, atol=1e-3) + assert jnp.allclose(sigmas[:, 1, 2], 0.0) + assert jnp.allclose(sigmas[:, 2, 0], 0.0) + # # + assert jnp.allclose(sigmas[:, 1, 0], sigma_12_ans, atol=1e-3) + assert jnp.allclose(sigmas[:, 2, 1], 0.0) + assert jnp.allclose(sigmas[:, 0, 2], 0.0) def uniaxial_strain_test(model): - from .utils import uniaxial_strain - import jax - import jax.numpy as jnp - theta = 0. - state_old = jnp.zeros((100, 0)) - dt = 1. - lambdas = jnp.linspace(1.2, 4., 100) - Fs = jax.vmap(uniaxial_strain)(lambdas) - grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) - B_bars = jax.vmap(lambda F: jnp.power(jnp.linalg.det(F), -2. / 3.) * F @ F.T)(Fs) - Js = jax.vmap(model.jacobian)(grad_us) - I1_bars = jax.vmap(model.I1_bar)(grad_us) - psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( - grad_us, theta, state_old, dt - ) - - # for (psi, sigma, I1_bar, J, B_bar) in zip(psis, sigmas, I1_bars, Js, B_bars): - def vmap_func(B_bar, I1_bar, J): - psi_an = K * (J * jnp.log(J) - J + 1.) + \ - 1.5 * A1 / (P1 + 1.) * (I1_bar / 3. - 1.)**(P1 + 1.) + \ - 1.5 * C1 / (R1 + 1.) * (I1_bar / 3. - 1.)**(R1 + 1.) - dUdI1_bar = 0.5 * A1 * (I1_bar / 3. - 1.)**P1 + \ - 0.5 * C1 * (I1_bar / 3. - 1.)**R1 - B_bar_dev = B_bar - (1. / 3.) * jnp.trace(B_bar) * jnp.eye(3) - sigma_an = (2. / J) * dUdI1_bar * B_bar_dev + K * jnp.log(J) * jnp.eye(3) - return psi_an, sigma_an[0, 0], sigma_an[1, 1] - - psi_ans, sigma_11_ans, sigma_22_ans = jax.vmap( - vmap_func, in_axes=(0, 0, 0) - )(B_bars, I1_bars, Js) - - assert jnp.allclose(psis, psi_ans, atol=1e-3) - assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans, atol=1e-3) - assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans, atol=1e-3) - assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans, atol=1e-3) - # - assert jnp.allclose(sigmas[:, 0, 1], 0.0) - assert jnp.allclose(sigmas[:, 1, 2], 0.0) - assert jnp.allclose(sigmas[:, 2, 0], 0.0) - # - assert jnp.allclose(sigmas[:, 1, 0], 0.0) - assert jnp.allclose(sigmas[:, 2, 1], 0.0) - assert jnp.allclose(sigmas[:, 0, 2], 0.0) + from .utils import uniaxial_strain + import jax + import jax.numpy as jnp + + theta = 0.0 + state_old = jnp.zeros((100, 0)) + dt = 1.0 + lambdas = jnp.linspace(1.2, 4.0, 100) + Fs = jax.vmap(uniaxial_strain)(lambdas) + grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs) + B_bars = jax.vmap( + lambda F: jnp.power(jnp.linalg.det(F), -2.0 / 3.0) * F @ F.T + )(Fs) + Js = jax.vmap(model.jacobian)(grad_us) + I1_bars = jax.vmap(model.I1_bar)(grad_us) + psis, _ = jax.vmap(model.energy, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + sigmas, _ = jax.vmap(model.cauchy_stress, in_axes=(0, None, 0, None))( + grad_us, theta, state_old, dt + ) + + def vmap_func(B_bar, I1_bar, J): + psi_an = ( + K * (J * jnp.log(J) - J + 1.0) + + 1.5 * A1 / (P1 + 1.0) * (I1_bar / 3.0 - 1.0) ** (P1 + 1.0) + + 1.5 * C1 / (R1 + 1.0) * (I1_bar / 3.0 - 1.0) ** (R1 + 1.0) + ) + dUdI1_bar = ( + 0.5 * A1 * (I1_bar / 3.0 - 1.0) ** P1 + + 0.5 * C1 * (I1_bar / 3.0 - 1.0) ** R1 + ) + B_bar_dev = B_bar - (1.0 / 3.0) * jnp.trace(B_bar) * jnp.eye(3) + sigma_an = (2.0 / J) * dUdI1_bar * B_bar_dev +\ + K * jnp.log(J) * jnp.eye(3) + return psi_an, sigma_an[0, 0], sigma_an[1, 1] + + psi_ans, sigma_11_ans, sigma_22_ans = jax.vmap( + vmap_func, in_axes=(0, 0, 0))( + B_bars, I1_bars, Js + ) + + assert jnp.allclose(psis, psi_ans, atol=1e-3) + assert jnp.allclose(sigmas[:, 0, 0], sigma_11_ans, atol=1e-3) + assert jnp.allclose(sigmas[:, 1, 1], sigma_22_ans, atol=1e-3) + assert jnp.allclose(sigmas[:, 2, 2], sigma_22_ans, atol=1e-3) + # + assert jnp.allclose(sigmas[:, 0, 1], 0.0) + assert jnp.allclose(sigmas[:, 1, 2], 0.0) + assert jnp.allclose(sigmas[:, 2, 0], 0.0) + # + assert jnp.allclose(sigmas[:, 1, 0], 0.0) + assert jnp.allclose(sigmas[:, 2, 1], 0.0) + assert jnp.allclose(sigmas[:, 0, 2], 0.0) def test_simple_shear(swanson_1, swanson_2): - simple_shear_test(swanson_1) - simple_shear_test(swanson_2) + simple_shear_test(swanson_1) + simple_shear_test(swanson_2) def test_uniaxial_strain(swanson_1, swanson_2): - uniaxial_strain_test(swanson_1) - uniaxial_strain_test(swanson_2) + uniaxial_strain_test(swanson_1) + uniaxial_strain_test(swanson_2) diff --git a/test/constitutive_models/utils.py b/test/constitutive_models/utils.py index c690002..8369a55 100644 --- a/test/constitutive_models/utils.py +++ b/test/constitutive_models/utils.py @@ -2,16 +2,8 @@ def uniaxial_strain(lambda_: float): - return jnp.array([ - [lambda_, 0., 0.], - [0., 1., 0.], - [0., 0., 1.] - ]) + return jnp.array([[lambda_, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) def simple_shear(gamma: float): - return jnp.array([ - [1., gamma, 0.], - [0., 1., 0.], - [0., 0., 1.] - ]) + return jnp.array([[1.0, gamma, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) diff --git a/test/networks/test_base.py b/test/networks/test_base.py new file mode 100644 index 0000000..6dee3a1 --- /dev/null +++ b/test/networks/test_base.py @@ -0,0 +1,65 @@ +import pytest + + +@pytest.fixture +def mlp(): + from pancax import MLP + import jax + return MLP(2, 3, 2, 2, jax.nn.tanh, key=jax.random.PRNGKey(0)) + + +@pytest.fixture +def trunc_init(): + import pancax + return pancax.networks.base.trunc_init + + +@pytest.fixture +def uniform_init(): + import pancax + return pancax.networks.base.uniform_init + + +@pytest.fixture +def zero_init(): + import pancax + return pancax.networks.base.zero_init + + +@pytest.fixture +def models(mlp): + return [mlp] + + +@pytest.fixture +def init_funcs( + trunc_init, + uniform_init, + zero_init +): + return [ + trunc_init, + uniform_init, + zero_init + ] + + +def test_serde(models): + import os + eqx_file_base_name = "temp" + for model in models: + model.serialise(eqx_file_base_name, 1) + new_model = model.deserialise(f"{eqx_file_base_name}_0000001.eqx") + assert model == new_model + os.system(f"rm -f {eqx_file_base_name}_0000001.eqx") + + +def test_uniform_init(models, init_funcs): + import jax + import numpy as np + + key = jax.random.PRNGKey(np.random.randint(0, 1000)) + + for model in models: + for init_func in init_funcs: + new_model = model.init(init_func, key=key) diff --git a/test/networks/test_network.py b/test/networks/test_network.py deleted file mode 100644 index 7511da3..0000000 --- a/test/networks/test_network.py +++ /dev/null @@ -1,7 +0,0 @@ -def test_network(): - from pancax import MLP, Network - import jax - model = Network(MLP, 3, 2, 20, 3, jax.nn.tanh, key=jax.random.key(0)) - x = jax.numpy.ones(3) - y = model(x) - assert y.shape == (2,)