Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ rocm_venv/
venv/
dd_solver.dat
krylov_solver.dat
.vscode/
59 changes: 32 additions & 27 deletions examples/forward_problems/poisson/collocation_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
##################
# for reproducibility
##################
key = random.key(10)
key = random.PRNGKey(10)

##################
# file management
##################
mesh_file = find_mesh_file('mesh_quad4.g')
logger = Logger('pinn.log', log_every=250)
pp = PostProcessor(mesh_file, 'exodus')

##################
Expand All @@ -30,15 +29,13 @@ def bc_func(x, t, z):
x, y = x[0], x[1]
return x * (1. - x) * y * (1. - y) * z

physics = physics.update_dirichlet_bc_func(bc_func)

ics = [
]
essential_bcs = [
EssentialBC('nset_1', 0),
EssentialBC('nset_2', 0),
EssentialBC('nset_3', 0),
EssentialBC('nset_4', 0),
DirichletBC('nset_1', 0),
DirichletBC('nset_2', 0),
DirichletBC('nset_3', 0),
DirichletBC('nset_4', 0),
]
natural_bcs = [
]
Expand All @@ -51,31 +48,39 @@ def bc_func(x, t, z):
##################
# ML setup
##################
n_dims = domain.coords.shape[1]
field = MLP(n_dims + 1, physics.n_dofs, 50, 3, jax.nn.tanh, key)
params = FieldPhysicsPair(field, problem.physics)
params = Parameters(problem, key, dirichlet_bc_func=bc_func, network_type=ResNet)

def loss_function(params, problem, inputs, outputs):
field, physics, state = params
residuals = jax.vmap(physics.strong_form_residual, in_axes=(None, 0, 0))(
field, inputs[:, 0:2], inputs[:, 2:3]
)
return jnp.square(residuals - outputs).mean(), dict(nothing=0.0)

loss_function = UserDefinedLossFunction(loss_function)

loss_function = StrongFormResidualLoss()
opt = Adam(loss_function, learning_rate=1e-3, has_aux=True)
opt_st = opt.init(params)
opt, opt_st = opt.init(params)

for epoch in range(5000):
params, opt_st, loss = opt.step(params, problem, opt_st)
dataloader = CollocationDataLoader(problem.domain, num_fields=1)
for epoch in range(50000):
for inputs, outputs in dataloader.dataloader(512):
params, opt_st, loss = opt.step(params, opt_st, problem, inputs, outputs)

if epoch % 100 == 0:
print(epoch)
print(loss)

##################
# post-processing
##################
pp.init(problem, 'output.e',
node_variables=['field_values']
)
pp.write_outputs(params, problem)
pp.close()
# ##################
# # post-processing
# ##################
# pp.init(params, problem, 'output.e',
# node_variables=['field_values']
# )
# pp.write_outputs(params, problem)
# pp.close()

import pyvista as pv
exo = pv.read('output.e')[0][0]
exo.set_active_scalars('u')
exo.plot(show_axes=False, cpos='xy', show_edges=True)
# # import pyvista as pv
# # exo = pv.read('output.e')[0][0]
# # exo.set_active_scalars('u')
# # exo.plot(show_axes=False, cpos='xy', show_edges=True)
63 changes: 32 additions & 31 deletions examples/forward_problems/poisson/variational_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
##################
# for reproducibility
##################
key = random.key(100)
# key = random.key(100)
key = random.PRNGKey(100)

##################
# file management
##################
mesh_file = find_mesh_file('mesh_quad4.g')
logger = Logger('pinn.log', log_every=250)
# logger = Logger('pinn.log', log_every=250)
pp = PostProcessor(mesh_file, 'exodus')

##################
# domain setup
##################
times = jnp.linspace(0.0, 0.0, 1)
times = jnp.linspace(0.0, 1.0, 2)
domain = VariationalDomain(mesh_file, times)

##################
Expand All @@ -30,7 +31,7 @@ def bc_func(x, t, z):
x, y = x[0], x[1]
return x * (1. - x) * y * (1. - y) * z

physics = physics.update_dirichlet_bc_func(bc_func)
# physics = physics.update_dirichlet_bc_func(bc_func)

ics = [
]
Expand All @@ -52,44 +53,44 @@ def bc_func(x, t, z):
# ML setup
##################
loss_function = EnergyLoss()
params = Parameters(problem, key)
params = Parameters(problem, key, dirichlet_bc_func=bc_func)

# # pre-train with Adam
# opt = Adam(loss_function, learning_rate=1e-3, has_aux=True)
# opt_st = opt.init(params)
# for epoch in range(2500):
# params, opt_st, loss = opt.step(params, problem, opt_st)
# pre-train with Adam
opt = Adam(loss_function, learning_rate=1e-3, has_aux=True)
opt, opt_st = opt.init(params)
for epoch in range(2500):
params, opt_st, loss = opt.step(params, opt_st, problem)

# if epoch % 100 == 0:
# print(epoch)
# print(loss)
if epoch % 100 == 0:
print(epoch)
print(loss)



# switch to LBFGS
params, static = eqx.partition(params, eqx.is_inexact_array)
# # switch to LBFGS
# params, static = eqx.partition(params, eqx.is_inexact_array)

def loss_func(params):
params = eqx.combine(params, static)
loss, aux = loss_function(params, problem)
return loss
# def loss_func(params):
# params = eqx.combine(params, static)
# loss, aux = loss_function(params, problem)
# return loss

opt = optax.lbfgs(memory_size=1)
opt_st = opt.init(params)
value_and_grad = jax.jit(optax.value_and_grad_from_state(loss_func))
for _ in range(200):
value, grad = value_and_grad(params, state=opt_st)
updates, opt_st = opt.update(
grad, opt_st, params, value=value, grad=grad, value_fn=loss_func
)
params = optax.apply_updates(params, updates)
print('Objective function: {:.2E}'.format(value))
# opt = optax.lbfgs(memory_size=1)
# opt_st = opt.init(params)
# value_and_grad = jax.jit(optax.value_and_grad_from_state(loss_func))
# for _ in range(200):
# value, grad = value_and_grad(params, state=opt_st)
# updates, opt_st = opt.update(
# grad, opt_st, params, value=value, grad=grad, value_fn=loss_func
# )
# params = optax.apply_updates(params, updates)
# print('Objective function: {:.2E}'.format(value))

##################
# post-processing
##################
params = eqx.combine(params, static)
pp.init(problem, 'output.e',
# params = eqx.combine(params, static)
pp.init(params, problem, 'output.e',
node_variables=['field_values']
)
pp.write_outputs(params, problem)
Expand Down
7 changes: 6 additions & 1 deletion pancax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
SimpleFeFv, \
WLF
from .data import FullFieldData, FullFieldDataLoader, GlobalData
from .domains import CollocationDomain, DeltaPINNDomain, VariationalDomain
from .domains import \
CollocationDataLoader, \
CollocationDomain, \
DeltaPINNDomain, \
VariationalDomain
from .fem import \
DofManager, \
FunctionSpace, \
Expand Down Expand Up @@ -118,6 +122,7 @@
"FullFieldDataLoader",
"GlobalData",
# domains
"CollocationDataLoader",
"CollocationDomain",
"DeltaPINNDomain",
"VariationalDomain",
Expand Down
4 changes: 2 additions & 2 deletions pancax/data/full_field_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jaxtyping import Array, Float
from jaxtyping import Array
from typing import List
import equinox as eqx
import jax.numpy as jnp
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(self, data: FullFieldData) -> None:
def __len__(self):
return len(self.data)

def dataloader(self, batch_size: int) -> Float[Array, "bs d"]:
def dataloader(self, batch_size: int):
perm = np.random.permutation(self.indices)
start = 0
end = batch_size
Expand Down
3 changes: 2 additions & 1 deletion pancax/domains/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .base import BaseDomain
from .collocation_domain import CollocationDomain
from .collocation_domain import CollocationDataLoader, CollocationDomain
from .delta_pinn_domain import DeltaPINNDomain
from .variational_domain import VariationalDomain

__all__ = [
"BaseDomain",
"CollocationDataLoader",
"CollocationDomain",
"DeltaPINNDomain",
"VariationalDomain"
Expand Down
47 changes: 47 additions & 0 deletions pancax/domains/collocation_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from jaxtyping import Array, Float
from pancax.fem import Mesh
from typing import Optional, Union
import equinox as eqx
import jax.numpy as jnp
import numpy as np


class CollocationDomain(BaseDomain):
Expand All @@ -15,3 +18,47 @@ def __init__(
p_order: Optional[int] = 1
) -> None:
super().__init__(mesh_file, times, p_order=p_order)


class CollocationDataLoader(eqx.Module):
indices: np.ndarray
inputs: Float[Array, "bs ni"]
outputs: Float[Array, "bs no"]

def __init__(
self,
domain: CollocationDomain,
num_fields: int
) -> None:
inputs = []

# For now, just a simple collection of mesh coordinates
# TODO add sampling strategies
coords = domain.coords
ones = jnp.ones((coords.shape[0], 1))
for time in domain.times:
times = time * ones
temp = jnp.hstack((coords, times))
inputs.append(temp)

inputs = jnp.vstack(inputs)
outputs = jnp.zeros((inputs.shape[0], num_fields))

indices = np.arange(len(inputs))

self.indices = indices
self.inputs = inputs
self.outputs = outputs

def __len__(self):
return len(self.inputs)

def dataloader(self, batch_size: int):
perm = np.random.permutation(self.indices)
start = 0
end = batch_size
while end <= len(self):
batch_perm = perm[start:end]
yield self.inputs[batch_perm], self.outputs[batch_perm]
start = end
end = start + batch_size
9 changes: 6 additions & 3 deletions pancax/loss_functions/base_loss_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def _vmap_func(n):
return problem.physics.constitutive_model.\
initial_state()

state_old = vmap(vmap(_vmap_func))(
jnp.zeros((ne, nq))
)
if hasattr(problem.physics, "constitutive_model"):
state_old = vmap(vmap(_vmap_func))(
jnp.zeros((ne, nq))
)
else:
state_old = jnp.zeros((ne, nq, 0))
return state_old
21 changes: 18 additions & 3 deletions pancax/networks/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class Field(AbstractPancaxModel):
dirichlet_bc_func: Callable
networks: Union[eqx.Module, List[eqx.Module]]
normalize_time: bool # to help with static problems
seperate_networks: bool
t_min: Float[Array, "1"]
t_max: Float[Array, "1"]
Expand Down Expand Up @@ -69,10 +70,26 @@ def init(k):
self.x_mins = jnp.min(problem.coords, axis=0)
self.x_maxs = jnp.max(problem.coords, axis=0)

if jnp.allclose(self.t_min, self.t_max):
self.normalize_time = False
else:
self.normalize_time = True

# def __call__(self, x):
def __call__(self, x, t):
x_norm = (x - self.x_mins) / (self.x_maxs - self.x_mins)
t_norm = (t - self.t_min) / (self.t_max - self.t_min)

# if self.normalize_time:
# t_norm = (t - self.t_min) / (self.t_max - self.t_min)
# else:
# t_norm = t
t_norm = jax.lax.cond(
self.normalize_time,
lambda z: (z - self.t_min) / (self.t_max - self.t_min),
lambda z: z,
t
)

inputs = jnp.hstack((x_norm, t_norm))

if self.seperate_networks:
Expand All @@ -86,6 +103,4 @@ def func(params, x):
# z = self.networks(x)
z = self.networks(inputs)

# TODO call dirichlet bc func

return self.dirichlet_bc_func(x, t, z)
Loading
Loading