diff --git a/pancax/__init__.py b/pancax/__init__.py
index cc684ce..e950f53 100644
--- a/pancax/__init__.py
+++ b/pancax/__init__.py
@@ -56,6 +56,7 @@ from .networks import \
     Field, \
     FieldPhysicsPair, \
+    InputPolyconvexNN, \
     Linear, \
     MLP, \
     MLPBasis, \
@@ -165,6 +166,7 @@
     # networks
     "Field",
     "FieldPhysicsPair",
+    "InputPolyconvexNN",
     "Linear",
     "MLP",
     "MLPBasis",
diff --git a/pancax/loss_functions/data_loss_functions.py b/pancax/loss_functions/data_loss_functions.py
index b6e813f..74acbfd 100644
--- a/pancax/loss_functions/data_loss_functions.py
+++ b/pancax/loss_functions/data_loss_functions.py
@@ -12,7 +12,6 @@ def __init__(self, weight: Optional[float] = 1.0):
 
     def __call__(self, params, problem, inputs, outputs):
         field_network, _, _ = params
-        # n_dims = inputs.shape[-1]
         n_dims = problem.coords.shape[1]
         xs = inputs[:, 0:n_dims]
         ts = inputs[:, n_dims]
@@ -23,19 +22,3 @@ def __call__(self, params, problem, inputs, outputs):
         loss = jnp.square(u_pred - outputs).mean()
         aux = {"field_data_loss": loss}
         return self.weight * loss, aux
-
-    def __call__old(self, params, domain):
-        field_network, _, _ = params
-        n_dims = domain.coords.shape[1]
-        xs = domain.field_data.inputs[:, 0:n_dims]
-        # TODO need time normalization
-        ts = domain.field_data.inputs[:, n_dims]
-        # TODO below is currenlty the odd ball for the field_value API
-        u_pred = vmap(domain.physics.field_values, in_axes=(None, 0, 0))(
-            field_network, xs, ts
-        )
-
-        # TODO add output normalization
-        loss = jnp.square(u_pred - domain.field_data.outputs).mean()
-        aux = {"field_data_loss": loss}
-        return self.weight * loss, aux
diff --git a/pancax/networks/__init__.py b/pancax/networks/__init__.py
index fafa7e7..66c699e 100644
--- a/pancax/networks/__init__.py
+++ b/pancax/networks/__init__.py
@@ -1,6 +1,7 @@
 from .fields import Field
 from .field_physics_pair import FieldPhysicsPair
 from .initialization import trunc_init
+from .input_polyconvex_nn import InputPolyconvexNN
 from .ml_dirichlet_field import MLDirichletField
 from .mlp import Linear
 from .mlp import MLP
@@ -16,6 +17,7 @@ def Network(network_type, *args, **kwargs):
 __all__ = [
     "Field",
     "FieldPhysicsPair",
+    "InputPolyconvexNN",
     "Linear",
     "MLDirichletField",
     "MLP",
diff --git a/pancax/networks/input_polyconvex_nn.py b/pancax/networks/input_polyconvex_nn.py
new file mode 100644
index 0000000..d08161c
--- /dev/null
+++ b/pancax/networks/input_polyconvex_nn.py
@@ -0,0 +1,220 @@
+from typing import Callable, List, Optional
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+
+
+# TODO: weight positivity is not enforced automatically; call
+# parameter_enforcement() to clip weights of layers with "pos" in the name
+
+
+def _icnn_init(key, shape):
+    # uniform init in [-k, k] with k set by the leading dim of shape
+    in_features = shape[0]
+    k = 1. / in_features
+    return jax.random.uniform(key=key, shape=shape, minval=-k, maxval=k)
+
+
+class InputPolyconvexNN(eqx.Module):
+    n_convex: int
+    n_inputs: int
+    n_layers: int
+    x1_xx_pos: eqx.nn.Linear
+    x1_xy: eqx.nn.Linear
+    y1: eqx.nn.Linear
+    x_h_plus_1_xx_pos: List[eqx.nn.Linear]
+    x_h_plus_1_xy: List[eqx.nn.Linear]
+    y_h_plus_1: List[eqx.nn.Linear]
+    x_n_xx_pos: eqx.nn.Linear
+    x_n_xy: eqx.nn.Linear
+    activation_x: Callable
+    activation_y: Callable
+
+    def __init__(
+        self,
+        n_inputs: int,
+        n_outputs: int,
+        n_convex: int,
+        activation_x: Callable,
+        activation_y: Callable,
+        n_layers: Optional[int] = 3,
+        n_neurons_x: Optional[int] = 40,
+        n_neurons_y: Optional[int] = 30,
+        *,
+        key: jax.random.PRNGKey
+    ):
+        keys = jax.random.split(key, 3 + 3 * n_layers + 2)
+
+        x1_xx_pos = eqx.nn.Linear(
+            n_convex, n_neurons_x,
+            use_bias=True,
+            key=keys[0]
+        )
+        x1_xx_pos = eqx.tree_at(
+            lambda x: x.weight,
+            x1_xx_pos,
+            _icnn_init(key=keys[0], shape=x1_xx_pos.weight.shape)
+        )
+        x1_xx_pos = eqx.tree_at(
+            lambda x: x.bias,
+            x1_xx_pos,
+            _icnn_init(key=keys[0], shape=x1_xx_pos.bias.shape)
+        )
+        x1_xy = eqx.nn.Linear(
+            n_neurons_y, n_neurons_x,
+            use_bias=False,
+            key=keys[1]
+        )
+        x1_xy = eqx.tree_at(
+            lambda x: x.weight,
+            x1_xy,
+            _icnn_init(key=keys[1], shape=x1_xy.weight.shape)
+        )
+        y1 = eqx.nn.Linear(
+            n_inputs - n_convex, n_neurons_y,
+            use_bias=True,
+            key=keys[2]
+        )
+        y1 = eqx.tree_at(
+            lambda x: x.weight,
+            y1,
+            _icnn_init(key=keys[2], shape=y1.weight.shape)
+        )
+        y1 = eqx.tree_at(
+            lambda x: x.bias,
+            y1,
+            _icnn_init(key=keys[2], shape=y1.bias.shape)
+        )
+
+        x_h_plus_1_xx_pos = []
+        x_h_plus_1_xy = []
+        y_h_plus_1 = []
+
+        for n in range(n_layers):
+            x_h_plus_1_xx_pos.append(eqx.nn.Linear(
+                n_neurons_x, n_neurons_x,
+                use_bias=True,
+                key=keys[3 + 3 * n]
+            ))
+            x_h_plus_1_xx_pos[n] = eqx.tree_at(
+                lambda x: x.weight,
+                x_h_plus_1_xx_pos[n],
+                _icnn_init(
+                    key=keys[3 + 3 * n],
+                    shape=x_h_plus_1_xx_pos[n].weight.shape
+                )
+            )
+            x_h_plus_1_xx_pos[n] = eqx.tree_at(
+                lambda x: x.bias,
+                x_h_plus_1_xx_pos[n],
+                _icnn_init(
+                    key=keys[3 + 3 * n],
+                    shape=x_h_plus_1_xx_pos[n].bias.shape
+                )
+            )
+            x_h_plus_1_xy.append(eqx.nn.Linear(
+                n_neurons_y, n_neurons_x,
+                use_bias=False,
+                key=keys[3 + 3 * n + 1]
+            ))
+            x_h_plus_1_xy[n] = eqx.tree_at(
+                lambda x: x.weight,
+                x_h_plus_1_xy[n],
+                _icnn_init(
+                    key=keys[3 + 3 * n + 1],
+                    shape=x_h_plus_1_xy[n].weight.shape
+                )
+            )
+            y_h_plus_1.append(eqx.nn.Linear(
+                n_neurons_y, n_neurons_y,
+                use_bias=True,
+                key=keys[3 + 3 * n + 2]
+            ))
+            y_h_plus_1[n] = eqx.tree_at(
+                lambda x: x.weight,
+                y_h_plus_1[n],
+                _icnn_init(
+                    key=keys[3 + 3 * n + 2],
+                    shape=y_h_plus_1[n].weight.shape
+                )
+            )
+            y_h_plus_1[n] = eqx.tree_at(
+                lambda x: x.bias,
+                y_h_plus_1[n],
+                _icnn_init(
+                    key=keys[3 + 3 * n + 2],
+                    shape=y_h_plus_1[n].bias.shape
+                )
+            )
+
+        x_n_xx_pos = eqx.nn.Linear(
+            n_neurons_x, n_outputs,
+            use_bias=True,
+            key=keys[-2]
+        )
+        x_n_xx_pos = eqx.tree_at(
+            lambda x: x.weight,
+            x_n_xx_pos,
+            _icnn_init(key=keys[-2], shape=x_n_xx_pos.weight.shape)
+        )
+        x_n_xx_pos = eqx.tree_at(
+            lambda x: x.bias,
+            x_n_xx_pos,
+            _icnn_init(key=keys[-2], shape=x_n_xx_pos.bias.shape)
+        )
+        x_n_xy = eqx.nn.Linear(
+            n_neurons_y, n_outputs,
+            use_bias=False,
+            key=keys[-1]
+        )
+        x_n_xy = eqx.tree_at(
+            lambda x: x.weight,
+            x_n_xy,
+            _icnn_init(key=keys[-1], shape=x_n_xy.weight.shape)
+        )
+
+        # finally set fields
+        self.n_convex = n_convex
+        self.n_inputs = n_inputs
+        self.n_layers = n_layers
+        self.x1_xx_pos = x1_xx_pos
+        self.x1_xy = x1_xy
+        self.y1 = y1
+        self.x_h_plus_1_xx_pos = x_h_plus_1_xx_pos
+        self.x_h_plus_1_xy = x_h_plus_1_xy
+        self.y_h_plus_1 = y_h_plus_1
+        self.x_n_xx_pos = x_n_xx_pos
+        self.x_n_xy = x_n_xy
+        self.activation_x = activation_x
+        self.activation_y = activation_y
+
+    def __call__(self, x_in):
+        n_nonconvex = self.n_inputs - self.n_convex
+        y0 = x_in[0:n_nonconvex]
+        y = self.y1(y0)
+        x0 = x_in[n_nonconvex:]
+        x = self.x1_xx_pos(x0) + self.x1_xy(y)
+        x = self.activation_x(x)
+
+        for layer in range(self.n_layers):
+            y = self.y_h_plus_1[layer](y)
+            y = self.activation_y(y)
+            x = self.x_h_plus_1_xx_pos[layer](x) + self.x_h_plus_1_xy[layer](y)
+            x = self.activation_x(x)
+
+        z = self.x_n_xx_pos(x) + self.x_n_xy(y)
+        return z
+
+    def parameter_enforcement(self):
+        temp = jnp.clip(self.x1_xx_pos.weight, min=1e-3)
+        self = eqx.tree_at(lambda x: x.x1_xx_pos.weight, self, temp)
+
+        for n in range(self.n_layers):
+            temp = jnp.clip(self.x_h_plus_1_xx_pos[n].weight, min=1e-3)
+            self = eqx.tree_at(
+                lambda x: x.x_h_plus_1_xx_pos[n].weight, self, temp
+            )
+
+        temp = jnp.clip(self.x_n_xx_pos.weight, min=1e-3)
+        self = eqx.tree_at(lambda x: x.x_n_xx_pos.weight, self, temp)
+        return self
diff --git a/test/networks/test_field_physics_pair.py b/test/networks/test_field_physics_pair.py
index 72122e2..9cd8e32 100644
--- a/test/networks/test_field_physics_pair.py
+++ b/test/networks/test_field_physics_pair.py
@@ -19,7 +19,7 @@
 #     y = network(x)
 #     props = props()
 #     assert y.shape == (2,)
-#     assert props[0] >= 1. and props[0] <= 2. 
+#     assert props[0] >= 1. and props[0] <= 2.
 #     assert props[1] >= 2. and props[1] <= 3.
 
 
@@ -40,7 +40,7 @@
 #     model.serialise(os.path.join(Path(__file__).parent, 'checkpoint'), 0)
 
 #     model_loaded = eqx.tree_deserialise_leaves(
-#         os.path.join(Path(__file__).parent, 'checkpoint_0000000.eqx'), 
+#         os.path.join(Path(__file__).parent, 'checkpoint_0000000.eqx'),
 #         model
 #     )
 #     network, props = model_loaded
diff --git a/test/networks/test_input_polyconvex_nn.py b/test/networks/test_input_polyconvex_nn.py
new file mode 100644
index 0000000..7551297
--- /dev/null
+++ b/test/networks/test_input_polyconvex_nn.py
@@ -0,0 +1,70 @@
+def is_convex(network, x1, x2, lambda_val):
+    """Check if the network is convex between two points x1 and x2."""
+    # Calculate the convex combination
+    x_combined = lambda_val * x1 + (1 - lambda_val) * x2
+
+    # Evaluate the network at the points
+    f_x1 = network(x1)
+    f_x2 = network(x2)
+    f_combined = network(x_combined)
+
+    # Check the convexity condition
+    return f_combined <= lambda_val * f_x1 + (1 - lambda_val) * f_x2
+
+
+def is_polyconvex(network, fixed_inputs, dim, lambda_val):
+    """Check convexity along a single input dimension, others held fixed."""
+    import jax.random as random
+    # Create two input points with fixed values in other dimensions
+    x1 = fixed_inputs.copy()
+    x2 = fixed_inputs.copy()
+
+    # Vary the specified dimension
+    x1 = x1.at[dim].set(random.uniform(random.PRNGKey(0), ()))
+    x2 = x2.at[dim].set(random.uniform(random.PRNGKey(1), ()))
+
+    return is_convex(network, x1, x2, lambda_val)
+
+
+def test_icnn_all_convex():
+    from pancax import InputPolyconvexNN
+    import jax
+    model = InputPolyconvexNN(
+        3, 1, 3, jax.nn.softplus, jax.nn.softplus,
+        key=jax.random.key(0)
+    )
+    model = model.parameter_enforcement()
+
+    x1 = jax.random.uniform(key=jax.random.key(0), shape=(3,))
+    x2 = jax.random.uniform(key=jax.random.key(1), shape=(3,))
+    # y = model(x)
+
+    for dim in range(x1.shape[0]):
+        for lambda_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
+            assert is_polyconvex(model, x1, dim, lambda_val), \
+                f"Polyconvexity failed for dimension={dim}, \
+                lambda={lambda_val}, fixed_inputs={x1}"
+
+    for lambda_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
+        assert is_convex(model, x1, x2, lambda_val), f"Convexity failed for \
+            lambda={lambda_val}, x1={x1}, x2={x2}"
+
+
+def test_icnn_some_convex():
+    from pancax import InputPolyconvexNN
+    import jax
+    n_convex = 2
+    model = InputPolyconvexNN(
+        3, 1, n_convex, jax.nn.softplus, jax.nn.softplus,
+        key=jax.random.key(0)
+    )
+    model = model.parameter_enforcement()
+
+    x1 = jax.random.uniform(key=jax.random.key(0), shape=(3,))
+    # y = model(x)
+
+    for dim in range(3 - n_convex, 3):
+        for lambda_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
+            assert is_polyconvex(model, x1, dim, lambda_val), \
+                f"Polyconvexity failed for dimension={dim}, \
+                lambda={lambda_val}, fixed_inputs={x1}"
diff --git a/test/networks/test_network.py b/test/networks/test_network.py
index 01bedbb..7511da3 100644
--- a/test/networks/test_network.py
+++ b/test/networks/test_network.py
@@ -1,7 +1,7 @@
 def test_network():
     from pancax import MLP, Network
     import jax
-    model = Network(MLP, 3, 2, 20, 3, jax.nn.tanh, key=jax.random.key(0)) 
+    model = Network(MLP, 3, 2, 20, 3, jax.nn.tanh, key=jax.random.key(0))
     x = jax.numpy.ones(3)
     y = model(x)
     assert y.shape == (2,)
diff --git a/test/networks/test_parameters.py b/test/networks/test_parameters.py
index f20edf0..fb21058 100644
--- a/test/networks/test_parameters.py
+++ b/test/networks/test_parameters.py
@@ -15,5 +15,5 @@
 #         key=jax.random.key(0)
 #     )
 #     props = props()
-#     assert props[0] >= 1. and props[0] <= 2. 
+#     assert props[0] >= 1. and props[0] <= 2.
 #     assert props[1] >= 2. and props[1] <= 3.
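
Usage note: a minimal sketch of how the new network is meant to be driven, based on the constructor signature and the tests above. The input sizes, activations, and key seeds here are illustrative only, and parameter_enforcement() must be re-applied whenever the weights change (e.g. after each optimizer step), since positivity of the "pos" weights is not enforced automatically (see the TODO at the top of input_polyconvex_nn.py).

    import jax
    from pancax import InputPolyconvexNN

    # 4 inputs, scalar output; the trailing 3 inputs form the convex block,
    # the leading input passes through the unconstrained y-path
    model = InputPolyconvexNN(
        4, 1, 3, jax.nn.softplus, jax.nn.softplus,
        key=jax.random.key(0)
    )
    # clip the "pos" weights so the output is convex in the last 3 inputs
    model = model.parameter_enforcement()

    x = jax.random.uniform(key=jax.random.key(1), shape=(4,))
    z = model(x)  # shape (1,)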