diff --git a/pancax/domains/base.py b/pancax/domains/base.py index fc10077..1c248a8 100644 --- a/pancax/domains/base.py +++ b/pancax/domains/base.py @@ -21,7 +21,6 @@ class BaseDomain(eqx.Module): mesh: Mesh coords: Float[Array, "nn nd"] times: Union[Float[Array, "nt"], Float[Array, "nn 1"]] - dt: float # TODO assumes times were generated with linspace def __init__( self, mesh_file: str, times: Float[Array, "nt"], @@ -49,17 +48,7 @@ def __init__( if times[i] >= times[i + 1]: raise SimulationTimesNotStrictlyIncreasingException() - if len(times) > 1: - dt = times[1] - times[0] - else: - print( - "WARNING: setting dt = 0. since this is a \ - static problem. Is this what you want?" - ) - dt = 0.0 - self.mesh_file = mesh_file self.mesh = mesh self.coords = jnp.array(mesh.coords) self.times = times - self.dt = dt diff --git a/pancax/loss_functions/weak_form_loss_functions.py b/pancax/loss_functions/weak_form_loss_functions.py index 3e24954..6445254 100644 --- a/pancax/loss_functions/weak_form_loss_functions.py +++ b/pancax/loss_functions/weak_form_loss_functions.py @@ -173,34 +173,6 @@ class PathDependentEnergyLoss(PhysicsLossFunction): def __init__(self, weight: Optional[float] = 1.0) -> None: self.weight = weight - def __call__old(self, params, problem): - field, physics, state = params - - ne = problem.domain.conns.shape[0] - nq = len(problem.domain.fspace.quadrature_rule) - - def _vmap_func(n): - return problem.physics.constitutive_model.\ - initial_state() - - state_old = vmap(vmap(_vmap_func))( - jnp.zeros((ne, nq)) - ) - - # TODO not adaptive - # dumb implementation below - dt = problem.times[1] - problem.times[0] - pi = 0.0 - for n in range(problem.times.shape[0]): - t = problem.times[n] - pi_t, state_new = self.load_step(params, problem, t, dt, state_old) - - state_old = state_new - pi = pi + pi_t - - loss = pi - return self.weight * loss, dict(energy=pi) - def __call__(self, params, problem): ne = problem.domain.conns.shape[0]