diff --git a/examples/forward_problems/mechanics/example_hyper_visco_2d.py b/examples/forward_problems/mechanics/example_hyper_visco_2d.py index f0b2602..cc17b01 100644 --- a/examples/forward_problems/mechanics/example_hyper_visco_2d.py +++ b/examples/forward_problems/mechanics/example_hyper_visco_2d.py @@ -3,7 +3,7 @@ ################## # for reproducibility ################## -key = random.key(10) +key = random.PRNGKey(10) ################## # file management @@ -93,9 +93,9 @@ def dirichlet_bc_func(xs, t, nn): # train network ################## opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, clip_gradients=False) -opt_st = opt.init(params) +opt, opt_st = opt.init(params) for epoch in range(100000): - params, opt_st, loss = opt.step(params, problem, opt_st) + params, opt_st, loss = opt.step(params, opt_st, problem) # logger.log_loss(loss, epoch) if epoch % 100 == 0: print(epoch, flush=True) @@ -105,6 +105,7 @@ def dirichlet_bc_func(xs, t, nn): if epoch % 10000 == 0: pp.init( + params, problem, f"output_{str(epoch).zfill(6)}.e", node_variables=[ diff --git a/examples/forward_problems/mechanics/example_incompressible_2d.py b/examples/forward_problems/mechanics/example_incompressible_2d.py index 3627d6d..4382695 100644 --- a/examples/forward_problems/mechanics/example_incompressible_2d.py +++ b/examples/forward_problems/mechanics/example_incompressible_2d.py @@ -3,7 +3,8 @@ ################## # for reproducibility ################## -key = random.key(10) +key = random.PRNGKey(10) +key = random.split(key, 8) # comment this to not use an ensemble ################## # file management @@ -16,6 +17,7 @@ ################## times = jnp.linspace(0.0, 1.0, 11) domain = VariationalDomain(mesh_file, times, q_order=2) +# domain = DeltaPINNDomain(mesh_file, times, n_eigen_values=20, q_order=2) ################## # physics setup @@ -49,18 +51,17 @@ # ML setup ################## loss_function = EnergyLoss() -# loss_function = EnergyAndResidualLoss(residual_weight=1.0e9) -# params = Parameters(problem, key, seperate_networks=True, network_type=ResNet) -params = Parameters(problem, key, seperate_networks=True) +params = Parameters(problem, key, seperate_networks=False) print(params) ################## # train network ################## opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, clip_gradients=False) -opt_st = opt.init(params) +opt, opt_st = opt.init(params) + for epoch in range(25000): - params, opt_st, loss = opt.step(params, problem, opt_st) + params, opt_st, loss = opt.step(params, opt_st, problem) if epoch % 100 == 0: print(epoch, flush=True) print(loss, flush=True) @@ -68,7 +69,7 @@ ################## # post-processing ################## -pp.init(problem, 'output.e', +pp.init(params, problem, 'output.e', node_variables=[ 'field_values', # 'internal_force' diff --git a/examples/forward_problems/mechanics/example_incompressible_3d.py b/examples/forward_problems/mechanics/example_incompressible_3d.py index 7e7b0fc..3508e0b 100644 --- a/examples/forward_problems/mechanics/example_incompressible_3d.py +++ b/examples/forward_problems/mechanics/example_incompressible_3d.py @@ -3,7 +3,7 @@ ################## # for reproducibility ################## -key = random.key(10) +key = random.PRNGKey(10) ################## # file management @@ -58,9 +58,9 @@ # train network ################## opt = Adam(loss_function, learning_rate=1.0e-3, has_aux=True, clip_gradients=False) -opt_st = opt.init(params) +opt, opt_st = opt.init(params) for epoch in range(25000): - params, opt_st, loss = opt.step(params, problem, opt_st) + params, opt_st, loss = opt.step(params, opt_st, problem) if epoch % 100 == 0: print(epoch, flush=True) print(loss, flush=True) @@ -68,7 +68,7 @@ ################## # post-processing ################## -pp.init(problem, 'output.e', +pp.init(params, problem, 'output.e', node_variables=[ 'field_values' # 'displacement', diff --git a/examples/inverse_problems/mechanics/vanilla/example_v2.py b/examples/inverse_problems/mechanics/vanilla/example_v2.py index 4b06c2f..24b7ffb 100644 --- a/examples/inverse_problems/mechanics/vanilla/example_v2.py +++ b/examples/inverse_problems/mechanics/vanilla/example_v2.py @@ -3,7 +3,9 @@ ################## # for reproducibility ################## -key = random.key(10) +# key = random.key(10) +key = random.PRNGKey(10) +key = random.split(key, 8) ################## # file management @@ -35,11 +37,15 @@ ################## # physics setup ################## +# @eqx.filter_vmap +# def models(key): model = NeoHookean( bulk_modulus=0.833, # bulk_modulus=BoundedProperty(0.01, 5., key=key), shear_modulus=BoundedProperty(0.01, 5., key=key) ) + # return model + physics = SolidMechanics(model, PlaneStrain()) ################## @@ -71,7 +77,9 @@ ################## # ML setup ################## -params = Parameters(problem, key, seperate_networks=True, network_type=ResNet) +params = Parameters(problem, key)#, seperate_networks=True, network_type=ResNet) +print(params) +# assert False physics_and_global_loss = EnergyResidualAndReactionLoss( residual_weight=250.e9, reaction_weight=250.e9 ) @@ -90,19 +98,15 @@ def loss_function(params, problem, inputs, outputs): ################## # Training ################## -opt_st = opt.init(params) - +opt, opt_st = opt.init(params) -# testing stuff dataloader = FullFieldDataLoader(problem.field_data) - for epoch in range(10000): for inputs, outputs in dataloader.dataloader(1024): - params, opt_st, loss = opt.step(params, problem, opt_st, inputs, outputs) + params, opt_st, loss = opt.step(params, opt_st, problem, inputs, outputs) if epoch % 10 == 0: print(epoch) print(loss) - print(params.physics.constitutive_model.bulk_modulus) - print(params.physics.constitutive_model.shear_modulus) - # print(params.physics.constitutive_model.Jm_parameter()) + print(params.physics.constitutive_model) + # print(params.physics.constitutive_model.shear_modulus) diff --git a/pancax/__init__.py b/pancax/__init__.py index dd32f2b..1c21b75 100644 --- a/pancax/__init__.py +++ b/pancax/__init__.py @@ -62,7 +62,7 @@ MLPBasis, \ Parameters, \ ResNet -from .optimizers import Adam, LBFGS +from .optimizers import Adam from .physics_kernels import \ BasePhysics, \ BaseEnergyFormPhysics, \ @@ -174,7 +174,7 @@ "ResNet", # optimizers "Adam", - "LBFGS", + # "LBFGS", # physics "BasePhysics", "BaseEnergyFormPhysics", diff --git a/pancax/constitutive_models/base.py b/pancax/constitutive_models/base.py index de4c756..0ba68d4 100644 --- a/pancax/constitutive_models/base.py +++ b/pancax/constitutive_models/base.py @@ -11,34 +11,21 @@ class ConstitutiveModel(eqx.Module): def __repr__(self): prop_names = self.properties() - # props = asdict(self) - # print(props) string = f"{type(self)}:\n" - - # for prop_name, prop in zip(prop_names, props): - # string = string + f" {prop_name} = {prop}\n" - # for k, v in props.items(): max_str_length = max(map(len, prop_names)) - print(max_str_length) for prop_name in prop_names: v = getattr(self, prop_name) string = string + f" {prop_name}" - # string = string.rjust(max_str_length) - # string = string + " = " string = string.ljust(max_str_length) string = string + " = " if type(v) is float: string = string + f"{v}\n" - # elif type(v) is ConstitutiveModel: elif issubclass(ConstitutiveModel, type(v)): string = string + f"{v.__repr__()}\n" else: - string = string + f"{v()}\n" - # if type(v) is float: - # string = string + f" {prop_name} = {v}\n" - # else: - # string = string + f" {prop_name} = {v()}\n" + string = string + f"{v.__repr__()}\n" + return string def properties(self): diff --git a/pancax/constitutive_models/properties.py b/pancax/constitutive_models/properties.py index 6adf1e4..6c996f4 100644 --- a/pancax/constitutive_models/properties.py +++ b/pancax/constitutive_models/properties.py @@ -1,7 +1,8 @@ -from jaxtyping import Array +from jaxtyping import Array, Float from typing import Union import equinox as eqx import jax +import jax.numpy as jnp # TODO patch up error check in a good way @@ -9,7 +10,7 @@ class BoundedProperty(eqx.Module): prop_min: float = eqx.field(static=True) prop_max: float = eqx.field(static=True) - prop_val: float + prop_val: Float[Array, "n"] # TODO # activation: Callable @@ -18,7 +19,16 @@ def __init__( ) -> None: self.prop_min = prop_min self.prop_max = prop_max - self.prop_val = jax.random.uniform(key, 1) + + if len(key.shape) == 1: + self.prop_val = self._sample(key, prop_min, prop_max) + elif len(key.shape) == 2: + @eqx.filter_vmap + def vmap_func(key): + return self._sample(key, prop_min, prop_max) + self.prop_val = vmap_func(key) + else: + assert False, f"This shouldn't happen key = {key}" def __call__(self): return ( @@ -28,7 +38,12 @@ def __call__(self): ) def __repr__(self): - return str(self.__call__()) + # return str(self.__call__()) + return str(( + self.prop_min + + (self.prop_max - self.prop_min) * + jax.nn.sigmoid(self.prop_val) + )) def __add__(self, other): self._check_other_type(other, "+") @@ -79,6 +94,14 @@ def _check_other_type(self, other, op_str): {op_str} with BoundingProperty" ) + def _sample(self, key, lb, ub): + if lb == ub: + return jnp.zeros(1) + + p_actual = jax.random.uniform(key, 1, minval=lb, maxval=ub) + y = (p_actual - lb) / (ub - lb) + return jnp.log(y / (1. - y)) + FixedProperty = float Property = Union[BoundedProperty, FixedProperty] diff --git a/pancax/domains/delta_pinn_domain.py b/pancax/domains/delta_pinn_domain.py index fec89aa..8ab8886 100644 --- a/pancax/domains/delta_pinn_domain.py +++ b/pancax/domains/delta_pinn_domain.py @@ -31,8 +31,6 @@ def __init__( self.physics = physics.update_var_name_to_method() self.eigen_modes = self.solve_eigen_problem() - # def __pos - def solve_eigen_problem(self): # physics = LaplaceBeltrami() dof_manager = DofManager(self.mesh, 1, []) @@ -58,9 +56,18 @@ def solve_eigen_problem(self): for n in range(len(lambdas)): print(f" Eigen mode {n} = {1. / lambdas[n]}") + # dummy params + class DummyParams: + is_ensemble = False + n_ensemble = 1 + + params = DummyParams() + with Timer("post-processing"): pp = PostProcessor(self.mesh_file) - pp.init(self, "output-eigen.e", node_variables=["field_values"]) + pp.init( + params, self, "output-eigen.e", node_variables=["field_values"] + ) with nc.Dataset(pp.pp.output_file, "a") as out: for n in range(len(lambdas)): diff --git a/pancax/networks/parameters.py b/pancax/networks/parameters.py index 1204f12..b6d17b0 100644 --- a/pancax/networks/parameters.py +++ b/pancax/networks/parameters.py @@ -10,7 +10,7 @@ State = Union[Float[Array, "nt ne nq ns"], eqx.Module, None] -class Parameters(AbstractPancaxModel): +class _Parameters(AbstractPancaxModel): """ Data structure for storing all parameters needed for a model @@ -20,13 +20,14 @@ class Parameters(AbstractPancaxModel): :param state: state object (can be parameters or jax array) """ - fields: eqx.Module + fields: Field physics: eqx.Module state: State def __init__( self, problem, + constitutive_model, key, network_type: Optional[eqx.Module] = MLP, seperate_networks: Optional[bool] = False @@ -36,62 +37,124 @@ def __init__( network_type=network_type, seperate_networks=seperate_networks ) - # state = self._create_state_array(problem) + # TODO what do we do with this guy? state = None + physics = eqx.tree_at( + lambda x: x.constitutive_model, problem.physics, constitutive_model + ) self.fields = fields - self.physics = problem.physics + self.physics = physics self.state = state def __iter__(self): return iter((self.fields, self.physics, self.state)) - # TODO - # make helper filter methods so there's - # less code duplication + +class Parameters(AbstractPancaxModel): + is_ensemble: bool = eqx.field(static=True) + n_ensemble: int = eqx.field(static=True) + parameters: _Parameters + + def __init__( + self, + problem, + key, + network_type: Optional[eqx.Module] = MLP, + seperate_networks: Optional[bool] = False + ) -> None: + if len(key.shape) == 1: + is_ensemble = False + n_ensemble = 1 + parameters = _Parameters( + problem, problem.physics.constitutive_model, key, + network_type=network_type, + seperate_networks=seperate_networks + ) + elif len(key.shape) == 2: + is_ensemble = True + n_ensemble = key.shape[0] + + @eqx.filter_vmap + def vmap_func(key, constitutive_model): + return _Parameters( + problem, constitutive_model, key, + network_type=network_type, + seperate_networks=seperate_networks + ) + + parameters = vmap_func(key, problem.physics.constitutive_model) + else: + raise ValueError( + f"Invalid shape for key {key} with shape {key.shape}" + ) + + self.is_ensemble = is_ensemble + self.n_ensemble = n_ensemble + self.parameters = parameters + + def __iter__(self): + return iter((self.fields, self.physics, self.state)) + + @property + def fields(self): + return self.parameters.fields + + @property + def physics(self): + return self.parameters.physics + + @property + def state(self): + return self.parameters.state + def freeze_fields_filter(self): filter_spec = jtu.tree_map(lambda _: True, self) - fields_filter = jtu.tree_map(lambda _: False, self.fields) + fields_filter = jtu.tree_map(lambda _: False, self.parameters.fields) filter_spec = eqx.tree_at( - lambda x: x.fields, filter_spec, fields_filter + lambda x: x.parameters.fields, filter_spec, fields_filter ) # freeze normalization filter_spec = eqx.tree_at( - lambda x: x.physics.x_mins, filter_spec, replace=False + lambda x: x.parameters.physics.x_mins, filter_spec, replace=False ) filter_spec = eqx.tree_at( - lambda x: x.physics.x_maxs, filter_spec, replace=False + lambda x: x.parameters.physics.x_maxs, filter_spec, replace=False ) filter_spec = eqx.tree_at( - lambda x: x.physics.t_min, filter_spec, replace=False + lambda x: x.parameters.physics.t_min, filter_spec, replace=False ) filter_spec = eqx.tree_at( - lambda x: x.physics.t_max, filter_spec, replace=False + lambda x: x.parameters.physics.t_max, filter_spec, replace=False ) return filter_spec # Move some of below to actually network implementation def freeze_physics_filter(self): filter_spec = jtu.tree_map(lambda _: True, self) - physics_filter = jtu.tree_map(lambda _: False, self.physics) + physics_filter = jtu.tree_map(lambda _: False, self.parameters.physics) filter_spec = eqx.tree_at( - lambda x: x.physics, filter_spec, physics_filter + lambda x: x.parameters.physics, filter_spec, physics_filter ) return filter_spec def freeze_physics_normalization_filter(self): filter_spec = jtu.tree_map(lambda _: True, self) filter_spec = eqx.tree_at( - lambda tree: tree.physics.x_mins, filter_spec, replace=False + lambda tree: tree.parameters.physics.x_mins, filter_spec, + replace=False ) filter_spec = eqx.tree_at( - lambda tree: tree.physics.x_maxs, filter_spec, replace=False + lambda tree: tree.parameters.physics.x_maxs, filter_spec, + replace=False ) filter_spec = eqx.tree_at( - lambda tree: tree.physics.t_min, filter_spec, replace=False + lambda tree: tree.parameters.physics.t_min, filter_spec, + replace=False ) filter_spec = eqx.tree_at( - lambda tree: tree.physics.t_max, filter_spec, replace=False + lambda tree: tree.parameters.physics.t_max, filter_spec, + replace=False ) return filter_spec diff --git a/pancax/optimizers/__init__.py b/pancax/optimizers/__init__.py index 34af5e1..79bba73 100644 --- a/pancax/optimizers/__init__.py +++ b/pancax/optimizers/__init__.py @@ -1,7 +1,9 @@ from .adam import Adam -from .lbfgs import LBFGS -from .base import Optimizer +# from .lbfgs import LBFGS # from .utils import * -__all__ = ["Adam", "LBFGS", "Optimizer"] +# __all__ = ["Adam", "LBFGS", "Optimizer"] +__all__ = [ + "Adam" +] diff --git a/pancax/optimizers/adam.py b/pancax/optimizers/adam.py index da88ff4..a90a929 100644 --- a/pancax/optimizers/adam.py +++ b/pancax/optimizers/adam.py @@ -1,28 +1,43 @@ -from .base import Optimizer -from typing import Callable -from typing import Optional +from .base import AbstractOptimizer +from typing import Callable, Optional, Union import equinox as eqx import optax -class Adam(Optimizer): +class Adam(AbstractOptimizer): + epoch: int + has_aux: bool + jit: bool + loss_function: Callable # TODO further type me + loss_and_grads: Callable + step: Union[None, Callable] + # + opt: any + def __init__( self, loss_function: Callable, + learning_rate: float, + # TODO this option should probably be on the base class + clip_gradients: Optional[bool] = False, + decay_rate: Optional[float] = 0.99, has_aux: Optional[bool] = False, jit: Optional[bool] = True, - learning_rate: Optional[float] = 1.0e-3, - transition_steps: Optional[int] = 500, - decay_rate: Optional[float] = 0.99, - clip_gradients: Optional[bool] = False, - filter_spec: Optional[Callable] = None, + transition_steps: Optional[int] = 500 ) -> None: - super().__init__(loss_function, has_aux, jit) + super().__init__( + loss_function, + has_aux=has_aux, + jit=jit + ) + + # TODO figure out how best to handle scheduler scheduler = optax.exponential_decay( init_value=learning_rate, transition_steps=transition_steps, decay_rate=decay_rate, ) + if clip_gradients: self.opt = optax.chain( optax.clip_by_global_norm(1.0), @@ -37,36 +52,13 @@ def __init__( optax.scale(-1.0), ) - self.loss_and_grads = eqx.filter_value_and_grad( - self.loss_function.filtered_loss, has_aux=self.has_aux - ) - self.filter_spec = filter_spec - - def make_step_method(self): - # if self.filter_spec is None: - # def step(params, domain, opt_st): - # loss, grads = self.loss_and_grads(params, domain) - # updates, opt_st = self.opt.update(grads, opt_st) - # params = eqx.apply_updates(params, updates) - # # add grad props to output - # # TODO what to do about below? - # # loss[1].update({'dprops': grads.properties.prop_params}) - # return params, opt_st, loss - # else: - def step(params, domain, opt_st, *args): - if self.filter_spec is None: - filter_spec = params.freeze_physics_normalization_filter() - else: - filter_spec = self.filter_spec - - diff_params, static_params = eqx.partition(params, filter_spec) + def make_step_method(self, filter_spec: Callable) -> Callable: + def step(params, opt_st, *args): + diff_params, static_params = \ + eqx.partition(params, filter_spec) loss, grads = \ - self.loss_and_grads(diff_params, static_params, domain, *args) + self.loss_and_grads(diff_params, static_params, *args) updates, opt_st = self.opt.update(grads, opt_st) params = eqx.apply_updates(params, updates) - - # add grad props to output - # loss[1].update({'dprops': grads.properties()}) return params, opt_st, loss - return step diff --git a/pancax/optimizers/base.py b/pancax/optimizers/base.py index d4c5d45..b060a50 100644 --- a/pancax/optimizers/base.py +++ b/pancax/optimizers/base.py @@ -1,101 +1,87 @@ -from abc import ABC from abc import abstractmethod -from typing import Callable -from typing import Optional +from typing import Callable, Optional, Union import equinox as eqx -class Optimizer(ABC): +class AbstractOptimizer(eqx.Module): + epoch: int + has_aux: bool + jit: bool + loss_function: Callable # TODO further type me + loss_and_grads: Callable # TODO not sure if this type is right? + step: Union[None, Callable] + def __init__( self, loss_function: Callable, + *, has_aux: Optional[bool] = False, - jit: Optional[bool] = True, + jit: Optional[bool] = True ) -> None: - self.loss_function = loss_function + self.epoch = 0 self.has_aux = has_aux self.jit = jit + self.loss_function = loss_function + self.loss_and_grads = eqx.filter_value_and_grad( + self.loss_function.filtered_loss, has_aux=has_aux + ) self.step = None - self.epoch = 0 @abstractmethod - def make_step_method(self, params): + def make_step_method(self, filter_spec: Callable): pass - def ensemble_init(self, params): - self.step = self.make_step_method() - # if self.jit: - # self.step = eqx.filter_jit(self.step) + def _ensemble_init(self, params, filter_spec): + filter_spec = self._init_filter(params, filter_spec=filter_spec) + step = self.make_step_method(filter_spec) + + def ensemble_step(params, opt_st, *args): + in_axes = (eqx.if_array(0), eqx.if_array(0)) + in_axes = in_axes + len(args) * (None,) - # need to now make an ensemble wrapper our self.step - # but make sure not to jit it until after the vmap - def ensemble_step(params, domain, opt_st): - params, opt_st, loss = eqx.filter_vmap( - self.step, in_axes=(eqx.if_array(0), None, eqx.if_array(0)) - )(params, domain, opt_st) - return params, opt_st, loss + @eqx.filter_vmap(in_axes=in_axes) + def vmap_func(params, opt_st, *args): + return step(params, opt_st, *args) + + return vmap_func(params, opt_st, *args) if self.jit: - self.ensemble_step = eqx.filter_jit(ensemble_step) + ensemble_step = eqx.filter_jit(ensemble_step) + + self = eqx.tree_at( + lambda x: x.step, self, ensemble_step, + is_leaf=lambda x: x is None + ) + @eqx.filter_vmap(in_axes=(eqx.if_array(0),)) def vmap_func(p): return self.opt.init(eqx.filter(p, eqx.is_array)) - opt_st = eqx.filter_vmap(vmap_func, in_axes=(eqx.if_array(0),))(params) - return opt_st + opt_st = vmap_func(params) - def ensemble_step_old(self, params, domain, opt_st): - params, opt_st, loss = eqx.filter_vmap( - self.step, in_axes=(eqx.if_array(0), None, eqx.if_array(0)) - )(params, domain, opt_st) - return params, opt_st, loss + return self, opt_st - def init(self, params): - self.step = self.make_step_method() + def _init(self, params, filter_spec): + filter_spec = self._init_filter(params, filter_spec=filter_spec) + step = self.make_step_method(filter_spec) if self.jit: - self.step = eqx.filter_jit(self.step) + step = eqx.filter_jit(step) + self = eqx.tree_at( + lambda x: x.step, self, step, + is_leaf=lambda x: x is None + ) opt_st = self.opt.init(eqx.filter(params, eqx.is_array)) - # opt_st = self.opt.init(eqx.filter(params, filter_spec)) - return opt_st - - # def train(self, params, opt_st, domain, n_epochs, log_every): - # for n in range(n_epochs): - # params, opt_st, loss = self.step(params, domain, opt_st) - # log_loss(loss, self.epoch, log_every) - # self.epoch = self.epoch + 1 - # return params, opt_st - def train( - self, - params, - domain, - times, - opt, - logger, - history, - pp, - n_epochs, - log_every: Optional[int] = 100, - serialise_every: Optional[int] = 10000, - postprocess_every: Optional[int] = 10000, - ): - opt_st = opt.init(params) - for epoch in range(int(n_epochs)): - params, opt_st, loss = opt.step(params, domain, opt_st) - logger.log_loss(loss, epoch, log_every) - history.write_loss(loss, epoch) - - if epoch % serialise_every == 0: - params.serialise("checkpoint", epoch) - - if epoch % postprocess_every == 0: - pp.init(params, domain, f"output_{str(epoch).zfill(6)}.e") - pp.write_outputs( - params, - domain, - times, - [ - "displacement", - ], - ) - pp.close() + return self, opt_st + + def init(self, params, filter_spec: Optional[Callable] = None): + if params.is_ensemble: + return self._ensemble_init(params, filter_spec) + else: + return self._init(params, filter_spec) + + def _init_filter(self, params, filter_spec): + if filter_spec is None: + filter_spec = params.freeze_physics_normalization_filter() + + return filter_spec diff --git a/pancax/optimizers/utils.py b/pancax/optimizers/utils.py deleted file mode 100644 index 8caae40..0000000 --- a/pancax/optimizers/utils.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Optional -import equinox as eqx -import jax - - -def trainable_filter( - params: any, - # freeze_fields: Optional[bool] = False, - freeze_properties: Optional[bool] = True, - freeze_basis: Optional[bool] = False, - # freeze_linear_layer: Optional[bool] = False -): - # TODO there's some logic to work out here - filter_spec = jax.tree_util.tree_map(lambda _: True, params) - - # filter_spec = eqx.tree_at( - # lambda tree: tree.fields, - # filter_spec, - # replace=not freeze_fields - # ) - filter_spec = eqx.tree_at( - lambda tree: tree.properties.prop_params, - filter_spec, - replace=not freeze_properties, - ) - - if freeze_basis: - try: - filter_spec = eqx.tree_at( - lambda tree: ( - tree.fields.basis.weight, - tree.fields.basis.bias - ), - filter_spec, - replace=(not freeze_basis, not freeze_basis), - ) - filter_spec = eqx.tree_at( - lambda tree: ( - tree.fields.linear.weight, - tree.fields.linear.bias - ), - filter_spec, - replace=(True, True), - ) - except AttributeError: - raise AttributeError("This network does not have a basis!") - - return filter_spec diff --git a/pancax/physics_kernels/base.py b/pancax/physics_kernels/base.py index 2412a1b..1baf7b7 100644 --- a/pancax/physics_kernels/base.py +++ b/pancax/physics_kernels/base.py @@ -138,8 +138,13 @@ class BasePhysics(eqx.Module): x_maxs: Float[Array, "nd"] # = jnp.zeros(3) t_min: Float[Array, "1"] t_max: Float[Array, "1"] + # + is_delta_pinn: bool - def __init__(self, field_value_names: tuple[str, ...]) -> None: + def __init__( + self, + field_value_names: tuple[str, ...] + ) -> None: self.field_value_names = field_value_names self.var_name_to_method = {} self.dirichlet_bc_func = lambda x, t, z: z @@ -149,6 +154,8 @@ def __init__(self, field_value_names: tuple[str, ...]) -> None: self.t_min = jnp.zeros(1) self.t_max = jnp.ones(1) + self.is_delta_pinn = False + # TODO need to modify for delta pinn def field_values(self, field, x, t, *args): # x = (x - stop_gradient(self.x_mins)) / diff --git a/pancax/post_processor.py b/pancax/post_processor.py index be63c8d..36a1442 100644 --- a/pancax/post_processor.py +++ b/pancax/post_processor.py @@ -1,7 +1,10 @@ from abc import abstractmethod - -# from pancax.physics import incompressible_internal_force, internal_force from typing import List +import equinox as eqx +import jax +import jax.numpy as jnp +import netCDF4 as nc +import numpy as onp try: import vtk @@ -11,60 +14,43 @@ "You'll need to use another form of output" ) -import jax -import jax.numpy as jnp -import os -import netCDF4 as nc -import numpy as onp - -# import vtk - class BasePostProcessor: mesh_file: str = None node_variables: List[str] = None element_variables: List[str] = None + output_file: str = None - # def __init__( - # self, - # mesh_file: str, - # node_variables: List[str], - # element_variables: List[str] - # ) -> None: - # self.mesh_file = mesh_file - # self.node_variables = node_variables - # self.element_variables = element_variables - - def check_variable_names(self, domain, variables) -> None: + def check_variable_names(self, problem, variables) -> None: for var in variables: if var == "internal_force" or \ var == "incompressible_internal_force": continue - if var not in domain.physics.var_name_to_method.keys(): + if var not in problem.physics.var_name_to_method.keys(): str = f"Unsupported variable requested for output {var}.\n" str += "Supported variables include:\n" - for v in domain.physics.var_name_to_method.keys(): + for v in problem.physics.var_name_to_method.keys(): str += f" {v}\n" raise ValueError(str) - def get_node_variable_number(self, domain, variables) -> int: + def get_node_variable_number(self, problem, variables) -> int: n = 0 for var in variables: if var == "internal_force" or \ var == "incompressible_internal_force": - n = n + len(domain.physics.field_value_names) + n = n + len(problem.physics.field_value_names) else: - n = n + len(domain.physics.var_name_to_method[var]["names"]) + n = n + len(problem.physics.var_name_to_method[var]["names"]) return n - def get_element_variable_number(self, domain, variables) -> int: + def get_element_variable_number(self, problem, variables) -> int: n = 0 for var in variables: - n = n + len(domain.physics.var_name_to_method[var]["names"]) + n = n + len(problem.physics.var_name_to_method[var]["names"]) if n > 0: - q_points = len(domain.fspace.quadrature_rule) + q_points = len(problem.fspace.quadrature_rule) return n * q_points else: return 0 @@ -74,9 +60,10 @@ def close(self) -> None: pass @abstractmethod - def init( + def _init( self, - domain, + params, + problem, output_file: str, node_variables: List[str], element_variables: List[str], @@ -84,9 +71,64 @@ def init( pass @abstractmethod - def write_outputs(self, params, domain): + def _write_outputs(self, params, problem): pass + def init( + self, + params, + problem, + output_file: str, + node_variables: List[str], + element_variables: List[str], + ) -> None: + self.output_file = output_file + + if params.is_ensemble: + # parts = self.output_file.split('.') + # base_name = parts[0] + # ext = parts[1] + # # n_ensemble = jnp.arange(params.n_ensemble) + # fnames = [] + # for n in range(params.n_ensemble): + # fnames.append(f"{base_name}_{n}.{ext}") + # fnames = jnp.array(fnames) + + # @eqx.filter_vmap + # def vmap_func(params, fname): + # # output_file = f"{base_name}_{n}.{ext}" + # self._init( + # params, problem, + # fname, node_variables, element_variables + # ) + # # print(n_ensemble) + # vmap_func(params, fnames) + print( + "WARNING: post-processing is " + "currently unsupported for ensembles" + ) + else: + self._init( + params, problem, + output_file, node_variables, element_variables + ) + + def write_outputs(self, params, problem): + if params.is_ensemble: + parts = self.output_file.split('.') + base_name = parts[0] + ext = parts[1] + n_ensemble = jnp.arange(params.n_ensemble) + + @eqx.filter_vmap + def vmap_func(params, n): + output_file = f"{base_name}_{n}.{ext}" + self._write_outputs(params, problem, output_file) + + vmap_func(params, n_ensemble) + else: + self._write_outputs(params, problem, self.output_file) + class ExodusPostProcessor(BasePostProcessor): def __init__(self, mesh_file: str) -> None: @@ -96,16 +138,18 @@ def __init__(self, mesh_file: str) -> None: def close(self) -> None: pass - def init( + def _init( self, - domain, + # domain, + params, + problem, output_file: str, node_variables: List[str], element_variables: List[str], ) -> None: self.output_file = output_file - self.check_variable_names(domain, node_variables) - self.check_variable_names(domain, element_variables) + self.check_variable_names(problem, node_variables) + self.check_variable_names(problem, element_variables) self.node_variables = node_variables self.element_variables = element_variables @@ -145,7 +189,7 @@ def init( # get total number of node variables num_node_vars = 0 for var in node_variables: - for v in domain.physics.\ + for v in problem.physics.\ var_name_to_method[var]["names"]: num_node_vars = num_node_vars + 1 @@ -156,7 +200,7 @@ def init( n = 0 for var in node_variables: - for v in domain.\ + for v in problem.\ physics.var_name_to_method[var]["names"]: name = v.ljust(max_str_len)[:max_str_len] # print(name) @@ -168,11 +212,11 @@ def init( n = n + 1 if len(element_variables) > 0: - q_points = len(domain.fspace.quadrature_rule) + q_points = len(problem.fspace.quadrature_rule) # get total number of node variables num_elem_vars = 0 for var in element_variables: - for v in domain.\ + for v in problem.\ physics.var_name_to_method[var]["names"]: for _ in range(q_points): num_elem_vars = num_elem_vars + 1 @@ -184,7 +228,7 @@ def init( n = 0 for var in element_variables: - for v in domain.\ + for v in problem.\ physics.var_name_to_method[var]["names"]: for q in range(q_points): name = f"{v}_{q + 1}" @@ -198,11 +242,25 @@ def init( ) n = n + 1 - def write_outputs(self, params, problem): + # def write_outputs(self, params, problem): + # if params.is_ensemble: + # parts = self.output_file.split('.') + # base_name = parts[0] + # ext = parts[1] + # n_emsemble = jnp.arange(params.n_ensemble) + + # @eqx.filter_vmap() + # def vmap_func(params, n): + # output_file = f"{base_name}_{n}.e" + # assert False, "Need to implement for ensemble" + # else: + # self._write_outputs(params, problem, self.output_file) + + def _write_outputs(self, params, problem, output_file): physics = problem.physics times = problem.times - with nc.Dataset(self.output_file, "a") as dataset: + with nc.Dataset(output_file, "a") as dataset: ne = problem.domain.conns.shape[0] nq = len(problem.domain.fspace.quadrature_rule) @@ -316,161 +374,7 @@ def _vmap_func(n): state_old = state_new -class ExodusPostProcessor_old: - def __init__(self, mesh_file: str) -> None: - self.mesh_file = mesh_file - self.exo = None - self.node_variables = None - self.element_variables = None - - def check_variable_names(self, domain, variables) -> None: - for var in variables: - if var == "internal_force" or \ - var == "incompressible_internal_force": - continue - - if var not in domain.physics.var_name_to_method.keys(): - str = f"Unsupported variable requested for output {var}.\n" - str += "Supported variables include:\n" - for v in domain.physics.var_name_to_method.keys(): - str += f" {v}\n" - raise ValueError(str) - - def close(self) -> None: - self.exo.close() - - def copy_mesh(self, output_file: str) -> None: - if os.path.isfile(output_file): - os.remove(output_file) - - exo_temp = exodus.copy_mesh(self.mesh_file, output_file) - exo_temp.close() - self.exo = exodus.exodus(output_file, mode="a", array_type="numpy") - - def get_node_variable_number(self, domain, variables) -> int: - n = 0 - for var in variables: - if var == "internal_force" or \ - var == "incompressible_internal_force": - n = n + len(domain.physics.field_value_names) - else: - n = n + len(domain.physics.var_name_to_method[var]["names"]) - - return n - - def get_element_variable_number(self, domain, variables) -> int: - n = 0 - for var in variables: - n = n + len(domain.physics.var_name_to_method[var]["names"]) - if n > 0: - q_points = len(domain.fspace.quadrature_rule) - return n * q_points - else: - return 0 - - def init( - self, - domain, - output_file: str, - node_variables: List[str], - element_variables: List[str], - ) -> None: - self.copy_mesh(output_file) - self.check_variable_names(domain, node_variables) - self.check_variable_names(domain, element_variables) - self.node_variables = node_variables - self.element_variables = element_variables - self.exo.set_node_variable_number( - self.get_node_variable_number(domain, node_variables) - ) - self.exo.set_element_variable_number( - self.get_element_variable_number(domain, element_variables) - ) - n = 1 - for var in self.node_variables: - if var == "internal_force" or \ - var == "incompressible_internal_force": - self.exo.put_node_variable_name("internal_force_x", n) - self.exo.put_node_variable_name("internal_force_y", n + 1) - self.exo.put_node_variable_name("internal_force_z", n + 2) - n = n + 3 - else: - for v in domain.physics.var_name_to_method[var]["names"]: - self.exo.put_node_variable_name(v, n) - n = n + 1 - - if len(element_variables) > 0: - q_points = len(domain.fspace.quadrature_rule) - n = 1 - for var in self.element_variables: - for v in domain.physics.var_name_to_method[var]["names"]: - for q in range(q_points): - name = f"{v}_{q + 1}" - self.exo.put_element_variable_name(name, n) - n = n + 1 - - def index_to_component(self, index): - if index == 0: - string = "x" - elif index == 1: - string = "y" - elif index == 2: - string = "z" - else: - raise ValueError("Should be 0, 1, or 2") - return string - - def write_outputs(self, params, domain) -> None: - physics = domain.physics - times = domain.times - for n, time in enumerate(times): - self.exo.put_time(n + 1, time) - - for var in self.node_variables: - if var == "internal_force" or \ - var == "incompressible_internal_force": - us = jax.vmap( - physics.field_values, in_axes=(None, 0, None) - )(params.fields, domain.coords, time) - fs = onp.array( - internal_force(domain, us, params.properties()) - ) - for i in range(fs.shape[1]): - self.exo.put_node_variable_values( - f"internal_force_{self.index_to_component(i)}", - n + 1, - fs[:, i], - ) - else: - output = physics.var_name_to_method[var] - pred = onp.array(output["method"](params, domain, time)) - if len(pred.shape) > 2: - for i in range(pred.shape[1]): - for j in range(pred.shape[2]): - k = pred.shape[1] * i + j - self.exo.put_node_variable_values( - output["names"][k], n + 1, pred[:, i, j] - ) - else: - for i in range(pred.shape[1]): - self.exo.put_node_variable_values( - output["names"][i], n + 1, pred[:, i] - ) - - if len(self.element_variables) > 0: - n_q_points = len(domain.fspace.quadrature_rule) - for var in self.element_variables: - output = physics.var_name_to_method[var] - pred = onp.array(output["method"](params, domain, time)) - for q in range(n_q_points): - for i in range(pred.shape[2]): - name = f'{output["names"][i]}_{q + 1}' - # NOTE this will only work on a single block - self.exo.put_element_variable_values( - 1, name, n + 1, pred[:, q, i] - ) - - +# TODO fix this class VtkPostProcessor: def __init__(self, mesh_file: str) -> None: self.mesh_file = mesh_file @@ -520,7 +424,7 @@ def init( # writer.SetInputData(poly_data) # writer.Write() - def write_outputs(self, params, domain) -> None: + def _write_outputs(self, params, domain, output_file) -> None: physics = domain.physics times = domain.times @@ -689,8 +593,14 @@ def __init__(self, mesh_file: str, mesh_type="exodus") -> None: def close(self): self.pp.close() - def init(self, domain, output_file, node_variables, element_variables=[]): - self.pp.init(domain, output_file, node_variables, element_variables) + def init( + self, params, problem, output_file, node_variables, + element_variables=[] + ): + self.pp.init( + params, problem, output_file, + node_variables, element_variables + ) - def write_outputs(self, params, domain): - self.pp.write_outputs(params, domain) + def write_outputs(self, params, problem): + self.pp.write_outputs(params, problem) diff --git a/pancax/problems/forward_problem.py b/pancax/problems/forward_problem.py index cb6ee85..67f6539 100644 --- a/pancax/problems/forward_problem.py +++ b/pancax/problems/forward_problem.py @@ -1,5 +1,10 @@ from ..bcs import DirichletBC, NeumannBC -from ..domains import BaseDomain, CollocationDomain, VariationalDomain +from ..domains import ( + BaseDomain, + CollocationDomain, + DeltaPINNDomain, + VariationalDomain +) from ..physics_kernels import ( BasePhysics, BaseStrongFormPhysics, @@ -8,6 +13,7 @@ ) from typing import Callable, List, Optional import equinox as eqx +import jax.numpy as jnp class DomainPhysicsCompatabilityError(Exception): @@ -22,6 +28,7 @@ class ForwardProblem(eqx.Module): ics: List[Callable] dirichlet_bcs: List[DirichletBC] neumann_bcs: List[NeumannBC] + is_delta_pinn: bool def __init__( self, @@ -40,7 +47,8 @@ def __init__( f"Got domain of type = {type(domain)}\n" f"Got physics of type = {type(physics)}" ) - elif type(domain) is VariationalDomain: + elif type(domain) is VariationalDomain \ + or type(domain) is DeltaPINNDomain: # TODO also need a weak form catch here # TODO or just maybe make a base variational physics class if not issubclass( @@ -56,15 +64,35 @@ def __init__( else: assert False, "wtf is this domain" - if type(domain) is VariationalDomain: + if type(domain) is VariationalDomain or \ + type(domain) is DeltaPINNDomain: domain = domain.update_dof_manager(dirichlet_bcs, physics.n_dofs) - self.domain = domain + # setup physics physics = physics.update_normalization(domain) - self.physics = physics.update_var_name_to_method() + physics = physics.update_var_name_to_method() + + if type(domain) is DeltaPINNDomain: + is_delta_pinn = True + physics = eqx.tree_at( + lambda x: x.x_mins, physics, + jnp.min(domain.eigen_modes, axis=0) + ) + physics = eqx.tree_at( + lambda x: x.x_maxs, physics, + jnp.max(domain.eigen_modes, axis=0) + ) + else: + is_delta_pinn = False + + self.domain = domain + # physics = physics.update_normalization(domain) + # self.physics = physics.update_var_name_to_method() + self.physics = physics self.ics = ics self.dirichlet_bcs = dirichlet_bcs self.neumann_bcs = neumann_bcs + self.is_delta_pinn = is_delta_pinn # TODO a lot of these below are for some backwards # compatability during a transition period to diff --git a/test/constitutive_models/test_gent.py b/test/constitutive_models/test_gent.py index 500e16c..ee6c2df 100644 --- a/test/constitutive_models/test_gent.py +++ b/test/constitutive_models/test_gent.py @@ -18,7 +18,7 @@ def gent_2(): from pancax import BoundedProperty, Gent import jax - key = jax.random.key(0) + key = jax.random.PRNGKey(0) return Gent( bulk_modulus=BoundedProperty(K, K, key), shear_modulus=BoundedProperty(G, G, key), diff --git a/test/constitutive_models/test_hencky.py b/test/constitutive_models/test_hencky.py index f5060e0..2395c89 100644 --- a/test/constitutive_models/test_hencky.py +++ b/test/constitutive_models/test_hencky.py @@ -17,7 +17,7 @@ def hencky_2(): from pancax import BoundedProperty, Hencky import jax - key = jax.random.key(0) + key = jax.random.PRNGKey(0) return Hencky( bulk_modulus=BoundedProperty(K, K, key), shear_modulus=BoundedProperty(G, G, key), diff --git a/test/constitutive_models/test_neohookean.py b/test/constitutive_models/test_neohookean.py index 29ad43a..60e0d5d 100644 --- a/test/constitutive_models/test_neohookean.py +++ b/test/constitutive_models/test_neohookean.py @@ -17,7 +17,7 @@ def neohookean_2(): from pancax import BoundedProperty, NeoHookean import jax - key = jax.random.key(0) + key = jax.random.PRNGKey(0) return NeoHookean( bulk_modulus=BoundedProperty(K, K, key), shear_modulus=BoundedProperty(G, G, key), diff --git a/test/constitutive_models/test_swanson.py b/test/constitutive_models/test_swanson.py index 9346285..73204b5 100644 --- a/test/constitutive_models/test_swanson.py +++ b/test/constitutive_models/test_swanson.py @@ -29,7 +29,7 @@ def swanson_2(): from pancax import BoundedProperty, Swanson import jax - key = jax.random.key(0) + key = jax.random.PRNGKey(0) return Swanson( bulk_modulus=BoundedProperty(K, K, key), A1=BoundedProperty(A1, A1, key), diff --git a/test/test_post_processors.py b/test/test_post_processors.py index 3ba1c0a..b6135b9 100644 --- a/test/test_post_processors.py +++ b/test/test_post_processors.py @@ -1,6 +1,11 @@ import pytest +class DummyParams: + is_ensemble = False + n_ensemble = 1 + + @pytest.fixture def problem(): from pancax import ( @@ -39,7 +44,7 @@ def problem(): def params(problem): from jax import random from pancax import Parameters - key = random.key(10) + key = random.PRNGKey(10) return Parameters(problem, key) @@ -49,7 +54,9 @@ def test_post_processor(params, problem): import os mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') pp = PostProcessor(mesh_file) + pp.init( + DummyParams(), problem, 'output.e', node_variables=[ @@ -71,6 +78,7 @@ def test_post_processor_bad_var_name(problem): pp = PostProcessor(mesh_file) with pytest.raises(ValueError): pp.init( + DummyParams(), problem, 'output.e', node_variables=[