From aa8ac0f906b06f7576de8208571bc71623116947 Mon Sep 17 00:00:00 2001 From: "Craig M. Hamel" Date: Wed, 19 Nov 2025 13:10:46 -0700 Subject: [PATCH] adding a dataloader for residual collocation points and patching some poisson examples. --- .gitignore | 1 + .../poisson/collocation_example.py | 59 +++++++++-------- .../poisson/variational_example.py | 63 ++++++++++--------- pancax/__init__.py | 7 ++- pancax/data/full_field_data.py | 4 +- pancax/domains/__init__.py | 3 +- pancax/domains/collocation_domain.py | 47 ++++++++++++++ pancax/loss_functions/base_loss_function.py | 9 ++- pancax/networks/fields.py | 21 ++++++- pancax/networks/parameters.py | 20 ++++-- pancax/post_processor.py | 37 +++++++---- 11 files changed, 185 insertions(+), 86 deletions(-) diff --git a/.gitignore b/.gitignore index 3902c04..4a9c160 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ rocm_venv/ venv/ dd_solver.dat krylov_solver.dat +.vscode/ diff --git a/examples/forward_problems/poisson/collocation_example.py b/examples/forward_problems/poisson/collocation_example.py index c75e493..fd46581 100644 --- a/examples/forward_problems/poisson/collocation_example.py +++ b/examples/forward_problems/poisson/collocation_example.py @@ -3,13 +3,12 @@ ################## # for reproducibility ################## -key = random.key(10) +key = random.PRNGKey(10) ################## # file management ################## mesh_file = find_mesh_file('mesh_quad4.g') -logger = Logger('pinn.log', log_every=250) pp = PostProcessor(mesh_file, 'exodus') ################## @@ -30,15 +29,13 @@ def bc_func(x, t, z): x, y = x[0], x[1] return x * (1. - x) * y * (1. - y) * z -physics = physics.update_dirichlet_bc_func(bc_func) - ics = [ ] essential_bcs = [ - EssentialBC('nset_1', 0), - EssentialBC('nset_2', 0), - EssentialBC('nset_3', 0), - EssentialBC('nset_4', 0), + DirichletBC('nset_1', 0), + DirichletBC('nset_2', 0), + DirichletBC('nset_3', 0), + DirichletBC('nset_4', 0), ] natural_bcs = [ ] @@ -51,31 +48,39 @@ def bc_func(x, t, z): ################## # ML setup ################## -n_dims = domain.coords.shape[1] -field = MLP(n_dims + 1, physics.n_dofs, 50, 3, jax.nn.tanh, key) -params = FieldPhysicsPair(field, problem.physics) +params = Parameters(problem, key, dirichlet_bc_func=bc_func, network_type=ResNet) + +def loss_function(params, problem, inputs, outputs): + field, physics, state = params + residuals = jax.vmap(physics.strong_form_residual, in_axes=(None, 0, 0))( + field, inputs[:, 0:2], inputs[:, 2:3] + ) + return jnp.square(residuals - outputs).mean(), dict(nothing=0.0) + +loss_function = UserDefinedLossFunction(loss_function) -loss_function = StrongFormResidualLoss() opt = Adam(loss_function, learning_rate=1e-3, has_aux=True) -opt_st = opt.init(params) +opt, opt_st = opt.init(params) -for epoch in range(5000): - params, opt_st, loss = opt.step(params, problem, opt_st) +dataloader = CollocationDataLoader(problem.domain, num_fields=1) +for epoch in range(50000): + for inputs, outputs in dataloader.dataloader(512): + params, opt_st, loss = opt.step(params, opt_st, problem, inputs, outputs) if epoch % 100 == 0: print(epoch) print(loss) -################## -# post-processing -################## -pp.init(problem, 'output.e', - node_variables=['field_values'] -) -pp.write_outputs(params, problem) -pp.close() +# ################## +# # post-processing +# ################## +# pp.init(params, problem, 'output.e', +# node_variables=['field_values'] +# ) +# pp.write_outputs(params, problem) +# pp.close() -import pyvista as pv -exo = pv.read('output.e')[0][0] -exo.set_active_scalars('u') -exo.plot(show_axes=False, cpos='xy', show_edges=True) \ No newline at end of file +# # import pyvista as pv +# # exo = pv.read('output.e')[0][0] +# # exo.set_active_scalars('u') +# # exo.plot(show_axes=False, cpos='xy', show_edges=True) \ No newline at end of file diff --git a/examples/forward_problems/poisson/variational_example.py b/examples/forward_problems/poisson/variational_example.py index 99e619a..d583a24 100644 --- a/examples/forward_problems/poisson/variational_example.py +++ b/examples/forward_problems/poisson/variational_example.py @@ -3,19 +3,20 @@ ################## # for reproducibility ################## -key = random.key(100) +# key = random.key(100) +key = random.PRNGKey(100) ################## # file management ################## mesh_file = find_mesh_file('mesh_quad4.g') -logger = Logger('pinn.log', log_every=250) +# logger = Logger('pinn.log', log_every=250) pp = PostProcessor(mesh_file, 'exodus') ################## # domain setup ################## -times = jnp.linspace(0.0, 0.0, 1) +times = jnp.linspace(0.0, 1.0, 2) domain = VariationalDomain(mesh_file, times) ################## @@ -30,7 +31,7 @@ def bc_func(x, t, z): x, y = x[0], x[1] return x * (1. - x) * y * (1. - y) * z -physics = physics.update_dirichlet_bc_func(bc_func) +# physics = physics.update_dirichlet_bc_func(bc_func) ics = [ ] @@ -52,44 +53,44 @@ def bc_func(x, t, z): # ML setup ################## loss_function = EnergyLoss() -params = Parameters(problem, key) +params = Parameters(problem, key, dirichlet_bc_func=bc_func) -# # pre-train with Adam -# opt = Adam(loss_function, learning_rate=1e-3, has_aux=True) -# opt_st = opt.init(params) -# for epoch in range(2500): -# params, opt_st, loss = opt.step(params, problem, opt_st) +# pre-train with Adam +opt = Adam(loss_function, learning_rate=1e-3, has_aux=True) +opt, opt_st = opt.init(params) +for epoch in range(2500): + params, opt_st, loss = opt.step(params, opt_st, problem) -# if epoch % 100 == 0: -# print(epoch) -# print(loss) + if epoch % 100 == 0: + print(epoch) + print(loss) -# switch to LBFGS -params, static = eqx.partition(params, eqx.is_inexact_array) +# # switch to LBFGS +# params, static = eqx.partition(params, eqx.is_inexact_array) -def loss_func(params): - params = eqx.combine(params, static) - loss, aux = loss_function(params, problem) - return loss +# def loss_func(params): +# params = eqx.combine(params, static) +# loss, aux = loss_function(params, problem) +# return loss -opt = optax.lbfgs(memory_size=1) -opt_st = opt.init(params) -value_and_grad = jax.jit(optax.value_and_grad_from_state(loss_func)) -for _ in range(200): - value, grad = value_and_grad(params, state=opt_st) - updates, opt_st = opt.update( - grad, opt_st, params, value=value, grad=grad, value_fn=loss_func - ) - params = optax.apply_updates(params, updates) - print('Objective function: {:.2E}'.format(value)) +# opt = optax.lbfgs(memory_size=1) +# opt_st = opt.init(params) +# value_and_grad = jax.jit(optax.value_and_grad_from_state(loss_func)) +# for _ in range(200): +# value, grad = value_and_grad(params, state=opt_st) +# updates, opt_st = opt.update( +# grad, opt_st, params, value=value, grad=grad, value_fn=loss_func +# ) +# params = optax.apply_updates(params, updates) +# print('Objective function: {:.2E}'.format(value)) ################## # post-processing ################## -params = eqx.combine(params, static) -pp.init(problem, 'output.e', +# params = eqx.combine(params, static) +pp.init(params, problem, 'output.e', node_variables=['field_values'] ) pp.write_outputs(params, problem) diff --git a/pancax/__init__.py b/pancax/__init__.py index 4b7eee9..8137ec9 100644 --- a/pancax/__init__.py +++ b/pancax/__init__.py @@ -17,7 +17,11 @@ SimpleFeFv, \ WLF from .data import FullFieldData, FullFieldDataLoader, GlobalData -from .domains import CollocationDomain, DeltaPINNDomain, VariationalDomain +from .domains import \ + CollocationDataLoader, \ + CollocationDomain, \ + DeltaPINNDomain, \ + VariationalDomain from .fem import \ DofManager, \ FunctionSpace, \ @@ -118,6 +122,7 @@ "FullFieldDataLoader", "GlobalData", # domains + "CollocationDataLoader", "CollocationDomain", "DeltaPINNDomain", "VariationalDomain", diff --git a/pancax/data/full_field_data.py b/pancax/data/full_field_data.py index 6bb17b4..78a816e 100644 --- a/pancax/data/full_field_data.py +++ b/pancax/data/full_field_data.py @@ -1,4 +1,4 @@ -from jaxtyping import Array, Float +from jaxtyping import Array from typing import List import equinox as eqx import jax.numpy as jnp @@ -93,7 +93,7 @@ def __init__(self, data: FullFieldData) -> None: def __len__(self): return len(self.data) - def dataloader(self, batch_size: int) -> Float[Array, "bs d"]: + def dataloader(self, batch_size: int): perm = np.random.permutation(self.indices) start = 0 end = batch_size diff --git a/pancax/domains/__init__.py b/pancax/domains/__init__.py index 10ace7a..0c4232c 100644 --- a/pancax/domains/__init__.py +++ b/pancax/domains/__init__.py @@ -1,10 +1,11 @@ from .base import BaseDomain -from .collocation_domain import CollocationDomain +from .collocation_domain import CollocationDataLoader, CollocationDomain from .delta_pinn_domain import DeltaPINNDomain from .variational_domain import VariationalDomain __all__ = [ "BaseDomain", + "CollocationDataLoader", "CollocationDomain", "DeltaPINNDomain", "VariationalDomain" diff --git a/pancax/domains/collocation_domain.py b/pancax/domains/collocation_domain.py index 9d744e1..871e736 100644 --- a/pancax/domains/collocation_domain.py +++ b/pancax/domains/collocation_domain.py @@ -2,6 +2,9 @@ from jaxtyping import Array, Float from pancax.fem import Mesh from typing import Optional, Union +import equinox as eqx +import jax.numpy as jnp +import numpy as np class CollocationDomain(BaseDomain): @@ -15,3 +18,47 @@ def __init__( p_order: Optional[int] = 1 ) -> None: super().__init__(mesh_file, times, p_order=p_order) + + +class CollocationDataLoader(eqx.Module): + indices: np.ndarray + inputs: Float[Array, "bs ni"] + outputs: Float[Array, "bs no"] + + def __init__( + self, + domain: CollocationDomain, + num_fields: int + ) -> None: + inputs = [] + + # For now, just a simple collection of mesh coordinates + # TODO add sampling strategies + coords = domain.coords + ones = jnp.ones((coords.shape[0], 1)) + for time in domain.times: + times = time * ones + temp = jnp.hstack((coords, times)) + inputs.append(temp) + + inputs = jnp.vstack(inputs) + outputs = jnp.zeros((inputs.shape[0], num_fields)) + + indices = np.arange(len(inputs)) + + self.indices = indices + self.inputs = inputs + self.outputs = outputs + + def __len__(self): + return len(self.inputs) + + def dataloader(self, batch_size: int): + perm = np.random.permutation(self.indices) + start = 0 + end = batch_size + while end <= len(self): + batch_perm = perm[start:end] + yield self.inputs[batch_perm], self.outputs[batch_perm] + start = end + end = start + batch_size diff --git a/pancax/loss_functions/base_loss_function.py b/pancax/loss_functions/base_loss_function.py index abd0b01..522ae14 100644 --- a/pancax/loss_functions/base_loss_function.py +++ b/pancax/loss_functions/base_loss_function.py @@ -72,7 +72,10 @@ def _vmap_func(n): return problem.physics.constitutive_model.\ initial_state() - state_old = vmap(vmap(_vmap_func))( - jnp.zeros((ne, nq)) - ) + if hasattr(problem.physics, "constitutive_model"): + state_old = vmap(vmap(_vmap_func))( + jnp.zeros((ne, nq)) + ) + else: + state_old = jnp.zeros((ne, nq, 0)) return state_old diff --git a/pancax/networks/fields.py b/pancax/networks/fields.py index c82c3be..43efebe 100644 --- a/pancax/networks/fields.py +++ b/pancax/networks/fields.py @@ -10,6 +10,7 @@ class Field(AbstractPancaxModel): dirichlet_bc_func: Callable networks: Union[eqx.Module, List[eqx.Module]] + normalize_time: bool # to help with static problems seperate_networks: bool t_min: Float[Array, "1"] t_max: Float[Array, "1"] @@ -69,10 +70,26 @@ def init(k): self.x_mins = jnp.min(problem.coords, axis=0) self.x_maxs = jnp.max(problem.coords, axis=0) + if jnp.allclose(self.t_min, self.t_max): + self.normalize_time = False + else: + self.normalize_time = True + # def __call__(self, x): def __call__(self, x, t): x_norm = (x - self.x_mins) / (self.x_maxs - self.x_mins) - t_norm = (t - self.t_min) / (self.t_max - self.t_min) + + # if self.normalize_time: + # t_norm = (t - self.t_min) / (self.t_max - self.t_min) + # else: + # t_norm = t + t_norm = jax.lax.cond( + self.normalize_time, + lambda z: (z - self.t_min) / (self.t_max - self.t_min), + lambda z: z, + t + ) + inputs = jnp.hstack((x_norm, t_norm)) if self.seperate_networks: @@ -86,6 +103,4 @@ def func(params, x): # z = self.networks(x) z = self.networks(inputs) - # TODO call dirichlet bc func - return self.dirichlet_bc_func(x, t, z) diff --git a/pancax/networks/parameters.py b/pancax/networks/parameters.py index 5170183..d541f42 100644 --- a/pancax/networks/parameters.py +++ b/pancax/networks/parameters.py @@ -42,9 +42,14 @@ def __init__( # TODO what do we do with this guy? state = None - physics = eqx.tree_at( - lambda x: x.constitutive_model, problem.physics, constitutive_model - ) + if hasattr(problem.physics, "constitutive_model"): + physics = eqx.tree_at( + lambda x: x.constitutive_model, problem.physics, + constitutive_model + ) + else: + physics = problem.physics + self.fields = fields self.physics = physics self.state = state @@ -66,11 +71,16 @@ def __init__( network_type: Optional[eqx.Module] = MLP, seperate_networks: Optional[bool] = False ) -> None: + if hasattr(problem.physics, "constitutive_model"): + constitutive_model = problem.physics.constitutive_model + else: + constitutive_model = None + if len(key.shape) == 1: is_ensemble = False n_ensemble = 1 parameters = _Parameters( - problem, problem.physics.constitutive_model, key, + problem, constitutive_model, key, dirichlet_bc_func=dirichlet_bc_func, network_type=network_type, seperate_networks=seperate_networks @@ -88,7 +98,7 @@ def vmap_func(key, constitutive_model): seperate_networks=seperate_networks ) - parameters = vmap_func(key, problem.physics.constitutive_model) + parameters = vmap_func(key, constitutive_model) else: raise ValueError( f"Invalid shape for key {key} with shape {key.shape}" diff --git a/pancax/post_processor.py b/pancax/post_processor.py index 916bed3..0ae40ad 100644 --- a/pancax/post_processor.py +++ b/pancax/post_processor.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from pancax.domains.variational_domain import VariationalDomain from typing import List import equinox as eqx import jax @@ -259,9 +260,12 @@ def _write_step_outputs( ) # calculate something with state update at least once to update # state later - _, state_new = physics.potential_energy( - params, problem.domain, time, us, state_old, dt - ) + if type(problem.domain) is VariationalDomain: + _, state_new = physics.potential_energy( + params, problem.domain, time, us, state_old, dt + ) + else: + state_new = None node_var_num = 0 for var in self.node_variables: @@ -354,18 +358,25 @@ def _write_outputs(self, params, problem, output_file): times = problem.times with nc.Dataset(output_file, "a") as dataset: - ne = problem.domain.conns.shape[0] - nq = len(problem.domain.fspace.quadrature_rule) + try: + ne = problem.domain.conns.shape[0] + nq = len(problem.domain.fspace.quadrature_rule) - def _vmap_func(n): - return problem.physics.constitutive_model.\ - initial_state() + def _vmap_func(n): + return problem.physics.constitutive_model.\ + initial_state() - # TODO assumes constantly spaced timesteps - dt = problem.times[1] - problem.times[0] - state_old = jax.vmap(jax.vmap(_vmap_func))( - jnp.zeros((ne, nq)) - ) + # TODO assumes constantly spaced timesteps + dt = problem.times[1] - problem.times[0] + + if hasattr(problem.physics, "constitutive_model"): + state_old = jax.vmap(jax.vmap(_vmap_func))( + jnp.zeros((ne, nq)) + ) + else: + state_old = jnp.zeros((ne, nq, 0)) + except AttributeError: + state_old = None for n, time in enumerate(times): if n == 0: