diff --git a/examples/hole_array/hole_array.py b/examples/hole_array/hole_array.py index 8172568b..8339f96c 100644 --- a/examples/hole_array/hole_array.py +++ b/examples/hole_array/hole_array.py @@ -30,7 +30,7 @@ ] dofManager = DofManager(func_space, 2, ebcs) - + print(dofManager) props = {'elastic modulus': 3. * 10.0 * (1. - 2. * 0.3), 'poisson ratio': 0.3, 'version': 'coupled'} @@ -53,11 +53,11 @@ def get_ubcs(p): V = np.zeros(mesh.coords.shape) index = (mesh.nodeSets['yplus_nodeset'], 1) V = V.at[index].set(yLoc) - return dofManager.get_bc_values(V) + return p.dof_manager.get_bc_values(V) def create_field(Uu, p): - return dofManager.create_field(Uu, get_ubcs(p)) + return p.dof_manager.create_field(Uu, get_ubcs(p)) def energy_function(Uu, p): @@ -120,7 +120,7 @@ def run(): Uu = dofManager.get_unknown_values(np.zeros(mesh.coords.shape)) disp = 0.0 ivs = mech_funcs.compute_initial_state() - p = Objective.Params(disp, ivs) + p = Objective.Params(disp, ivs, dof_manager=dofManager) precond_strategy = Objective.PrecondStrategy(assemble_sparse) objective = Objective.Objective(energy_function, Uu, p, precond_strategy) diff --git a/optimism/Objective.py b/optimism/Objective.py index 8f3d3f8b..082690e6 100644 --- a/optimism/Objective.py +++ b/optimism/Objective.py @@ -1,35 +1,89 @@ +from optimism.FunctionSpace import DofManager from optimism.JaxConfig import * from optimism.SparseCholesky import SparseCholesky -import numpy as onp from scipy.sparse import diags as sparse_diags from scipy.sparse import csc_matrix +from typing import Optional +import equinox as eqx +import numpy as onp + # static vs dynamics # differentiable vs undifferentiable -Params = namedtuple('Params', - ['bc_data', - 'state_data', - 'design_data', - 'app_data', - 'time', - 'dynamic_data'], - defaults=(None,None,None,None,None,None)) +# TODO fix some of these type hints for better clarity. +# maybe this will help formalize what's what when +class Params(eqx.Module): + bc_data: any + state_data: any + design_data: any + app_data: any + time: any + dynamic_data: any + # Need the eqx.field(static=True) since DofManager + # is composed of mainly og numpy arrays which leads + # to the error + #jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[x,x]) + dof_manager: DofManager = eqx.field(static=True) + + def __init__( + self, + bc_data = None, + state_data = None, + design_data = None, + app_data = None, + time = None, + dynamic_data = None, + dof_manager: Optional[DofManager] = None + ): + self.bc_data = bc_data + self.state_data = state_data + self.design_data = design_data + self.app_data = app_data + self.time = time + self.dynamic_data = dynamic_data + self.dof_manager = dof_manager + + def __getitem__(self, index): + if index == 0: + return self.bc_data + elif index == 1: + return self.state_data + elif index == 2: + return self.design_data + elif index == 3: + return self.app_data + elif index == 4: + return self.time + elif index == 5: + return self.dynamic_data + elif index == 6: + return self.dof_manager + else: + raise ValueError(f'Bad index value {index}') +# written for backwards compatability +# we can just use the eqx.tree_at syntax in simulations +# or we could write a single method bound to Params for this... def param_index_update(p, index, newParam): - if index==0: - return Params(newParam, p[1], p[2], p[3], p[4], p[5]) - if index==1: - return Params(p[0], newParam, p[2], p[3], p[4], p[5]) - if index==2: - return Params(p[0], p[1], newParam, p[3], p[4], p[5]) - if index==3: - return Params(p[0], p[1], p[2], newParam, p[4], p[5]) - if index==4: - return Params(p[0], p[1], p[2], p[3], newParam, p[5]) - if index==5: - return Params(p[0], p[1], p[2], p[3], p[4], newParam) - print('invalid index passed to param_index_update = ', index) + if index == 0: + p = eqx.tree_at(lambda x: x.bc_data, p, newParam) + elif index == 1: + p = eqx.tree_at(lambda x: x.state_data, p, newParam) + elif index == 2: + p = eqx.tree_at(lambda x: x.design_data, p, newParam) + elif index == 3: + p = eqx.tree_at(lambda x: x.app_data, p, newParam) + elif index == 4: + p = eqx.tree_at(lambda x: x.time, p, newParam) + elif index == 5: + p = eqx.tree_at(lambda x: x.dynamic_data, p, newParam) + elif index == 6: + p = eqx.tree_at(lambda x: x.dof_manager, p, newParam) + else: + raise ValueError(f'Bad index value {index}') + + return p class PrecondStrategy: