Merged
2 changes: 2 additions & 0 deletions pancax/__init__.py
@@ -56,6 +56,7 @@
from .networks import \
    Field, \
    FieldPhysicsPair, \
+    InputPolyconvexNN, \
    Linear, \
    MLP, \
    MLPBasis, \
@@ -165,6 +166,7 @@
    # networks
    "Field",
    "FieldPhysicsPair",
+    "InputPolyconvexNN",
    "Linear",
    "MLP",
    "MLPBasis",
17 changes: 0 additions & 17 deletions pancax/loss_functions/data_loss_functions.py
@@ -12,7 +12,6 @@ def __init__(self, weight: Optional[float] = 1.0):

    def __call__(self, params, problem, inputs, outputs):
        field_network, _, _ = params
-        # n_dims = inputs.shape[-1]
        n_dims = problem.coords.shape[1]
        xs = inputs[:, 0:n_dims]
        ts = inputs[:, n_dims]
@@ -23,19 +22,3 @@ def __call__(self, params, problem, inputs, outputs):
        loss = jnp.square(u_pred - outputs).mean()
        aux = {"field_data_loss": loss}
        return self.weight * loss, aux
-
-    def __call__old(self, params, domain):
-        field_network, _, _ = params
-        n_dims = domain.coords.shape[1]
-        xs = domain.field_data.inputs[:, 0:n_dims]
-        # TODO need time normalization
-        ts = domain.field_data.inputs[:, n_dims]
-        # TODO below is currenlty the odd ball for the field_value API
-        u_pred = vmap(domain.physics.field_values, in_axes=(None, 0, 0))(
-            field_network, xs, ts
-        )
-
-        # TODO add output normalization
-        loss = jnp.square(u_pred - domain.field_data.outputs).mean()
-        aux = {"field_data_loss": loss}
-        return self.weight * loss, aux
2 changes: 2 additions & 0 deletions pancax/networks/__init__.py
@@ -1,6 +1,7 @@
from .fields import Field
from .field_physics_pair import FieldPhysicsPair
from .initialization import trunc_init
+from .input_polyconvex_nn import InputPolyconvexNN
from .ml_dirichlet_field import MLDirichletField
from .mlp import Linear
from .mlp import MLP
@@ -16,6 +17,7 @@ def Network(network_type, *args, **kwargs):
__all__ = [
"Field",
"FieldPhysicsPair",
"InputPolyconvexNN",
"Linear",
"MLDirichletField",
"MLP",
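For reference, the new class can also be built through the Network factory exported here. A minimal sketch (the constructor arguments follow input_polyconvex_nn.py below; the softplus activations and sizes are illustrative, not part of this PR):

from pancax import InputPolyconvexNN, Network
import jax

# n_inputs=3, n_outputs=1, n_convex=3, then the convex-path and
# nonconvex-path activations; the PRNG key is keyword-only
model = Network(
    InputPolyconvexNN, 3, 1, 3, jax.nn.softplus, jax.nn.softplus,
    key=jax.random.key(0)
)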
220 changes: 220 additions & 0 deletions pancax/networks/input_polyconvex_nn.py
@@ -0,0 +1,220 @@
from typing import Callable, List, Optional
import equinox as eqx
import jax
import jax.numpy as jnp


# TODO still need to apply weight enforcement manually on weights
# of layers with pos in the name
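# (parameter_enforcement() below performs that clipping; the caller must
# invoke it explicitly, e.g. after each optimizer step)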


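# Small symmetric uniform initializer: U(-k, k) with k = 1 / shape[0].
# Note that eqx.nn.Linear stores weights as (out_features, in_features),
# so for weight matrices shape[0] is the output width.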
def _icnn_init(key, shape):
    in_features = shape[0]
    k = 1. / in_features
    return jax.random.uniform(key=key, shape=shape, minval=-k, maxval=k)


class InputPolyconvexNN(eqx.Module):
    n_convex: int
    n_inputs: int
    n_layers: int
    x1_xx_pos: eqx.nn.Linear
    x1_xy: eqx.nn.Linear
    y1: eqx.nn.Linear
    x_h_plus_1_xx_pos: List[eqx.nn.Linear]
    x_h_plus_1_xy: List[eqx.nn.Linear]
    y_h_plus_1: List[eqx.nn.Linear]
    x_n_xx_pos: eqx.nn.Linear
    x_n_xy: eqx.nn.Linear
    activation_x: Callable
    activation_y: Callable

    def __init__(
        self,
        n_inputs: int,
        n_outputs: int,
        n_convex: int,
        activation_x: Callable,
        activation_y: Callable,
        # key: jax.random.PRNGKey,
        n_layers: Optional[int] = 3,
        n_neurons_x: Optional[int] = 40,
        n_neurons_y: Optional[int] = 30,
        *,
        key: jax.random.PRNGKey
    ):
        keys = jax.random.split(key, 3 + 3 * n_layers + 2)

        x1_xx_pos = eqx.nn.Linear(
            n_convex, n_neurons_x,
            use_bias=True,
            key=keys[0]
        )
        x1_xx_pos = eqx.tree_at(
            lambda x: x.weight,
            x1_xx_pos,
            _icnn_init(key=keys[0], shape=x1_xx_pos.weight.shape)
        )
        x1_xx_pos = eqx.tree_at(
            lambda x: x.bias,
            x1_xx_pos,
            _icnn_init(key=keys[0], shape=x1_xx_pos.bias.shape)
        )
        x1_xy = eqx.nn.Linear(
            n_neurons_y, n_neurons_x,
            use_bias=False,
            key=keys[1]
        )
        x1_xy = eqx.tree_at(
            lambda x: x.weight,
            x1_xy,
            _icnn_init(key=keys[1], shape=x1_xy.weight.shape)
        )
        y1 = eqx.nn.Linear(
            n_inputs - n_convex, n_neurons_y,
            use_bias=True,
            key=keys[2]
        )
        y1 = eqx.tree_at(
            lambda x: x.weight,
            y1,
            _icnn_init(key=keys[2], shape=y1.weight.shape)
        )
        y1 = eqx.tree_at(
            lambda x: x.bias,
            y1,
            _icnn_init(key=keys[2], shape=y1.bias.shape)
        )

        x_h_plus_1_xx_pos = []
        x_h_plus_1_xy = []
        y_h_plus_1 = []

        for n in range(n_layers):
            x_h_plus_1_xx_pos.append(eqx.nn.Linear(
                n_neurons_x, n_neurons_x,
                use_bias=True,
                key=keys[3 + 3 * n]
            ))
            x_h_plus_1_xx_pos[n] = eqx.tree_at(
                lambda x: x.weight,
                x_h_plus_1_xx_pos[n],
                _icnn_init(
                    key=keys[3 + 3 * n],
                    shape=x_h_plus_1_xx_pos[n].weight.shape
                )
            )
            x_h_plus_1_xx_pos[n] = eqx.tree_at(
                lambda x: x.bias,
                x_h_plus_1_xx_pos[n],
                _icnn_init(
                    key=keys[3 + 3 * n],
                    shape=x_h_plus_1_xx_pos[n].bias.shape
                )
            )
            x_h_plus_1_xy.append(eqx.nn.Linear(
                n_neurons_y, n_neurons_x,
                use_bias=False,
                key=keys[3 + 3 * n + 1]
            ))
            x_h_plus_1_xy[n] = eqx.tree_at(
                lambda x: x.weight,
                x_h_plus_1_xy[n],
                _icnn_init(
                    key=keys[3 + 3 * n + 1],
                    shape=x_h_plus_1_xy[n].weight.shape
                )
            )
            y_h_plus_1.append(eqx.nn.Linear(
                n_neurons_y, n_neurons_y,
                use_bias=True,
                key=keys[3 + 3 * n + 2]
            ))
            y_h_plus_1[n] = eqx.tree_at(
                lambda x: x.weight,
                y_h_plus_1[n],
                _icnn_init(
                    key=keys[3 + 3 * n + 2],
                    shape=y_h_plus_1[n].weight.shape
                )
            )
            y_h_plus_1[n] = eqx.tree_at(
                lambda x: x.bias,
                y_h_plus_1[n],
                _icnn_init(
                    key=keys[3 + 3 * n + 2],
                    shape=y_h_plus_1[n].bias.shape
                )
            )

        # the output layers use the two keys allocated past the hidden
        # layers; indexing with the leaked loop variable n here would
        # reuse the last hidden layer's keys and fail for n_layers == 0
        x_n_xx_pos = eqx.nn.Linear(
            n_neurons_x, n_outputs,
            use_bias=True,
            key=keys[3 + 3 * n_layers]
        )
        x_n_xx_pos = eqx.tree_at(
            lambda x: x.weight,
            x_n_xx_pos,
            _icnn_init(
                key=keys[3 + 3 * n_layers],
                shape=x_n_xx_pos.weight.shape
            )
        )
        x_n_xx_pos = eqx.tree_at(
            lambda x: x.bias,
            x_n_xx_pos,
            _icnn_init(
                key=keys[3 + 3 * n_layers],
                shape=x_n_xx_pos.bias.shape
            )
        )
        x_n_xy = eqx.nn.Linear(
            n_neurons_y, n_outputs,
            use_bias=False,
            key=keys[3 + 3 * n_layers + 1]
        )
        x_n_xy = eqx.tree_at(
            lambda x: x.weight,
            x_n_xy,
            _icnn_init(
                key=keys[3 + 3 * n_layers + 1],
                shape=x_n_xy.weight.shape
            )
        )

        # finally set fields
        self.n_convex = n_convex
        self.n_inputs = n_inputs
        self.n_layers = n_layers
        self.x1_xx_pos = x1_xx_pos
        self.x1_xy = x1_xy
        self.y1 = y1
        self.x_h_plus_1_xx_pos = x_h_plus_1_xx_pos
        self.x_h_plus_1_xy = x_h_plus_1_xy
        self.y_h_plus_1 = y_h_plus_1
        self.x_n_xx_pos = x_n_xx_pos
        self.x_n_xy = x_n_xy
        self.activation_x = activation_x
        self.activation_y = activation_y

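    # The forward pass follows the (partially) input-convex recursion:
    # the trailing n_convex entries of x_in drive the convex x-path,
    # whose "*_pos" linear maps are meant to carry nonnegative weights,
    # while the leading entries drive the unconstrained y-path that
    # modulates it.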
    def __call__(self, x_in):
        n_nonconvex = self.n_inputs - self.n_convex
        y0 = x_in[0:n_nonconvex]
        y = self.y1(y0)
        x0 = x_in[n_nonconvex:]
        x = self.x1_xx_pos(x0) + self.x1_xy(y)
        x = self.activation_x(x)

        for layer in range(self.n_layers):
            y = self.y_h_plus_1[layer](y)
            y = self.activation_y(y)
            x = self.x_h_plus_1_xx_pos[layer](x) + self.x_h_plus_1_xy[layer](y)
            x = self.activation_x(x)

        z = self.x_n_xx_pos(x) + self.x_n_xy(y)
        return z

    def parameter_enforcement(self):
        temp = jnp.clip(self.x1_xx_pos.weight, min=1e-3)
        self = eqx.tree_at(lambda x: x.x1_xx_pos.weight, self, temp)

        for n in range(self.n_layers):
            temp = jnp.clip(self.x_h_plus_1_xx_pos[n].weight, min=1e-3)
            self = eqx.tree_at(
                lambda x: x.x_h_plus_1_xx_pos[n].weight, self, temp
            )

        temp = jnp.clip(self.x_n_xx_pos.weight, min=1e-3)
        self = eqx.tree_at(lambda x: x.x_n_xx_pos.weight, self, temp)
        return self
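Because parameter_enforcement() is not applied automatically (see the TODO at the top of this file), a training loop has to re-project the positive weights itself after each update. A minimal sketch of one way to wire that up, assuming an Equinox + optax loop; the names x_batch and y_batch are hypothetical placeholders, not part of this PR:

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

from pancax import InputPolyconvexNN

model = InputPolyconvexNN(
    3, 1, 3, jax.nn.softplus, jax.nn.softplus, key=jax.random.key(0)
)
optimizer = optax.adam(1.0e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))


def loss_fn(m, x, y):
    # batched mean squared error
    return jnp.mean((jax.vmap(m)(x) - y) ** 2)


@eqx.filter_jit
def step(model, opt_state, x, y):
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    # clip the "*_pos" weights back to positive values so the
    # convexity guarantee keeps holding during training
    model = model.parameter_enforcement()
    return model, opt_state, loss

# usage: model, opt_state, loss = step(model, opt_state, x_batch, y_batch)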
4 changes: 2 additions & 2 deletions test/networks/test_field_physics_pair.py
@@ -19,7 +19,7 @@
# y = network(x)
# props = props()
# assert y.shape == (2,)
-# assert props[0] >= 1. and props[0] <= 2.
+# assert props[0] >= 1. and props[0] <= 2.
# assert props[1] >= 2. and props[1] <= 3.


@@ -40,7 +40,7 @@
# model.serialise(os.path.join(Path(__file__).parent, 'checkpoint'), 0)

# model_loaded = eqx.tree_deserialise_leaves(
-# os.path.join(Path(__file__).parent, 'checkpoint_0000000.eqx'),
+# os.path.join(Path(__file__).parent, 'checkpoint_0000000.eqx'),
# model
# )
# network, props = model_loaded
70 changes: 70 additions & 0 deletions test/networks/test_input_polyconvex_nn.py
@@ -0,0 +1,70 @@
def is_convex(network, x1, x2, lambda_val):
    """Check if the network is convex between two points x1 and x2."""
    # Calculate the convex combination
    x_combined = lambda_val * x1 + (1 - lambda_val) * x2

    # Evaluate the network at the points
    f_x1 = network(x1)
    f_x2 = network(x2)
    f_combined = network(x_combined)

    # Check the convexity condition
    return f_combined <= lambda_val * f_x1 + (1 - lambda_val) * f_x2


def is_polyconvex(network, fixed_inputs, dim, lambda_val):
    """Check convexity of the network along a single input dimension,
    holding the remaining dimensions fixed."""
    import jax.random as random

    # Create two input points with fixed values in other dimensions
    x1 = fixed_inputs.copy()
    x2 = fixed_inputs.copy()

    # Vary the specified dimension
    x1 = x1.at[dim].set(random.uniform(random.PRNGKey(0), ()))
    x2 = x2.at[dim].set(random.uniform(random.PRNGKey(1), ()))

    return is_convex(network, x1, x2, lambda_val)


def test_icnn_all_convex():
    from pancax import InputPolyconvexNN
    import jax
    model = InputPolyconvexNN(
        3, 1, 3, jax.nn.softplus, jax.nn.softplus,
        key=jax.random.key(0)
    )
    model = model.parameter_enforcement()

    x1 = jax.random.uniform(key=jax.random.key(0), shape=(3,))
    x2 = jax.random.uniform(key=jax.random.key(1), shape=(3,))

    for dim in range(x1.shape[0]):
        for lambda_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
            assert is_polyconvex(model, x1, dim, lambda_val), (
                f"Polyconvexity failed for dimension={dim}, "
                f"lambda={lambda_val}, fixed_inputs={x1}"
            )

    for lambda_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
        assert is_convex(model, x1, x2, lambda_val), (
            f"Convexity failed for lambda={lambda_val}, x1={x1}, x2={x2}"
        )


def test_icnn_some_convex():
    from pancax import InputPolyconvexNN
    import jax
    n_convex = 2
    model = InputPolyconvexNN(
        3, 1, n_convex, jax.nn.softplus, jax.nn.softplus,
        key=jax.random.key(0)
    )
    model = model.parameter_enforcement()

    x1 = jax.random.uniform(key=jax.random.key(0), shape=(3,))

    # the convex inputs are the trailing n_convex entries of x_in
    # (see InputPolyconvexNN.__call__), so check those dimensions
    for dim in range(3 - n_convex, 3):
        for lambda_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
            assert is_polyconvex(model, x1, dim, lambda_val), (
                f"Polyconvexity failed for dimension={dim}, "
                f"lambda={lambda_val}, fixed_inputs={x1}"
            )
2 changes: 1 addition & 1 deletion test/networks/test_network.py
@@ -1,7 +1,7 @@
def test_network():
    from pancax import MLP, Network
    import jax
-    model = Network(MLP, 3, 2, 20, 3, jax.nn.tanh, key=jax.random.key(0))
+    model = Network(MLP, 3, 2, 20, 3, jax.nn.tanh, key=jax.random.key(0))
    x = jax.numpy.ones(3)
    y = model(x)
    assert y.shape == (2,)
2 changes: 1 addition & 1 deletion test/networks/test_parameters.py
@@ -15,5 +15,5 @@
# key=jax.random.key(0)
# )
# props = props()
-# assert props[0] >= 1. and props[0] <= 2.
+# assert props[0] >= 1. and props[0] <= 2.
# assert props[1] >= 2. and props[1] <= 3.