Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/run_linter.sh
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
flake8 pancax
flake8 test
104 changes: 77 additions & 27 deletions test/fem/test_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
yRange = [0.0, 1.0]
# targetDispGrad = jnp.array([[0.1, -0.2],[0.4, -0.1]])

# mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x: jnp.dot(targetDispGrad, x))
# mesh, U = \
# create_mesh_and_disp(Nx, Ny, xRange, yRange, \
# lambda x: jnp.dot(targetDispGrad, x))


@pytest.fixture
Expand Down Expand Up @@ -67,17 +69,25 @@ def fspace_fixture_2(mesh_and_disp):
return fs, fs_na, quadratureRule, state, props, dt


def test_element_volume_single_point_quadrature(fspace_fixture_1, mesh_and_disp):
def test_element_volume_single_point_quadrature(
fspace_fixture_1,
mesh_and_disp
):
import jax.numpy as jnp

fs, _, _, _, _, _ = fspace_fixture_1
mesh, _, _ = mesh_and_disp
elementVols = jnp.sum(fs.vols, axis=1)
nElements = mesh.num_elements
jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1)))
jnp.array_equal(
elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))
)


def test_element_volume_single_point_quadrature_na(fspace_fixture_1, mesh_and_disp):
def test_element_volume_single_point_quadrature_na(
fspace_fixture_1,
mesh_and_disp
):
import jax
import jax.numpy as jnp

Expand All @@ -86,10 +96,15 @@ def test_element_volume_single_point_quadrature_na(fspace_fixture_1, mesh_and_di
X_els = mesh.coords[mesh.conns, :]
elementVols = jnp.sum(jax.vmap(fs_na.JxWs)(X_els), axis=1)
nElements = mesh.num_elements
jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1)))
jnp.array_equal(
elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))
)


def test_linear_reproducing_single_point_quadrature(fspace_fixture_1, mesh_and_disp):
def test_linear_reproducing_single_point_quadrature(
fspace_fixture_1,
mesh_and_disp
):
from pancax.fem.function_space import compute_field_gradient
import jax.numpy as jnp

Expand All @@ -102,15 +117,20 @@ def test_linear_reproducing_single_point_quadrature(fspace_fixture_1, mesh_and_d
assert jnp.allclose(dispGrads, exact)


def test_linear_reproducing_single_point_quadrature_na(fspace_fixture_1, mesh_and_disp):
def test_linear_reproducing_single_point_quadrature_na(
fspace_fixture_1,
mesh_and_disp
):
import jax
import jax.numpy as jnp

_, fs_na, quadratureRule, _, _, _ = fspace_fixture_1
mesh, U, targetDispGrad = mesh_and_disp
X_els = mesh.coords[mesh.conns, :]
U_els = U[mesh.conns, :]
dispGrads = jax.vmap(fs_na.compute_field_gradient, in_axes=(0, 0))(U_els, X_els)
dispGrads = jax.vmap(
fs_na.compute_field_gradient, in_axes=(0, 0)
)(U_els, X_els)
nElements = mesh.num_elements
npts = quadratureRule.xigauss.shape[0]
exact = jnp.tile(targetDispGrad, (nElements, npts, 1, 1))
Expand Down Expand Up @@ -233,17 +253,24 @@ def test_integrate_linear_field_single_point_quadrature_na(
jnp.isclose(Iy, expected)


def test_element_volume_multi_point_quadrature(fspace_fixture_2, mesh_and_disp):
def test_element_volume_multi_point_quadrature(
fspace_fixture_2,
mesh_and_disp
):
import jax.numpy as jnp

fs, _, _, _, _, _ = fspace_fixture_2
mesh, U, _ = mesh_and_disp
elementVols = jnp.sum(fs.vols, axis=1)
nElements = mesh.num_elements
jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1)))
jnp.array_equal(
elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))
)


def test_element_volume_multi_point_quadrature_na(fspace_fixture_2, mesh_and_disp):
def test_element_volume_multi_point_quadrature_na(
fspace_fixture_2, mesh_and_disp
):
import jax
import jax.numpy as jnp

Expand All @@ -252,10 +279,15 @@ def test_element_volume_multi_point_quadrature_na(fspace_fixture_2, mesh_and_dis
X_els = mesh.coords[mesh.conns, :]
elementVols = jnp.sum(jax.vmap(fs_na.JxWs)(X_els), axis=1)
nElements = mesh.num_elements
jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1)))
jnp.array_equal(
elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))
)


def test_linear_reproducing_multi_point_quadrature(fspace_fixture_2, mesh_and_disp):
def test_linear_reproducing_multi_point_quadrature(
fspace_fixture_2,
mesh_and_disp
):
from pancax.fem.function_space import compute_field_gradient
import jax.numpy as jnp

Expand All @@ -268,15 +300,19 @@ def test_linear_reproducing_multi_point_quadrature(fspace_fixture_2, mesh_and_di
assert jnp.allclose(dispGrads, exact)


def test_linear_reproducing_multi_point_quadrature_na(fspace_fixture_2, mesh_and_disp):
def test_linear_reproducing_multi_point_quadrature_na(
fspace_fixture_2, mesh_and_disp
):
import jax
import jax.numpy as jnp

_, fs_na, quadratureRule, _, _, _ = fspace_fixture_2
mesh, U, targetDispGrad = mesh_and_disp
X_els = mesh.coords[mesh.conns, :]
U_els = U[mesh.conns, :]
dispGrads = jax.vmap(fs_na.compute_field_gradient, in_axes=(0, 0))(U_els, X_els)
dispGrads = jax.vmap(
fs_na.compute_field_gradient, in_axes=(0, 0)
)(U_els, X_els)
nElements = mesh.num_elements
npts = quadratureRule.xigauss.shape[0]
exact = jnp.tile(targetDispGrad, (nElements, npts, 1, 1))
Expand Down Expand Up @@ -327,7 +363,10 @@ def test_integrate_constant_field_multi_point_quadrature_na(
# TODO add integration test/method for new na fspace


def test_integrate_linear_field_multi_point_quadrature(fspace_fixture_2, mesh_and_disp):
def test_integrate_linear_field_multi_point_quadrature(
fspace_fixture_2,
mesh_and_disp
):
from pancax.fem.function_space import integrate_over_block
import jax.numpy as jnp

Expand Down Expand Up @@ -512,7 +551,7 @@ def test_jit_on_integration(fspace_fixture_2, mesh_and_disp):
fs, _, _, state, props, dt = fspace_fixture_2
mesh, U, _ = mesh_and_disp
integrate_jit = jax.jit(integrate_over_block, static_argnums=(6,))
I = integrate_jit(
Ival = integrate_jit(
fs,
U,
mesh.coords,
Expand All @@ -522,7 +561,7 @@ def test_jit_on_integration(fspace_fixture_2, mesh_and_disp):
lambda u, gradu, state, props, X, dt: 1.0,
mesh.blocks["block"],
)
jnp.isclose(I, 1.0)
jnp.isclose(Ival, 1.0)


def test_jit_on_integration_na(fspace_fixture_2, mesh_and_disp):
Expand All @@ -534,10 +573,11 @@ def test_jit_on_integration_na(fspace_fixture_2, mesh_and_disp):
integrate_jit = eqx.filter_jit(fs_na.integrate_on_elements)
U_els = U[mesh.conns[mesh.blocks["block"]], :]
X_els = U[mesh.conns[mesh.blocks["block"]], :]
I = integrate_jit(
U_els, X_els, state, props, dt, lambda u, gradu, state, props, X, dt: 1.0
Ival = integrate_jit(
U_els, X_els, state, props, dt,
lambda u, gradu, state, props, X, dt: 1.0
)
jnp.isclose(I, 1.0)
jnp.isclose(Ival, 1.0)


def test_jit_and_jacrev_on_integration(fspace_fixture_2, mesh_and_disp):
Expand All @@ -555,12 +595,17 @@ def test_jit_and_jacrev_on_integration(fspace_fixture_2, mesh_and_disp):
state,
props,
dt,
lambda u, gradu, state, props, X, dt: 0.5 * jnp.tensordot(gradu, gradu),
lambda u, gradu, state, props, X, dt:
0.5 * jnp.tensordot(gradu, gradu),
mesh.blocks["block"],
)
nNodes = mesh.coords.shape[0]
interiorNodeIds = jnp.setdiff1d(jnp.arange(nNodes), mesh.nodeSets["all_boundary"])
jnp.array_equal(dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :]))
interiorNodeIds = jnp.setdiff1d(
jnp.arange(nNodes), mesh.nodeSets["all_boundary"]
)
jnp.array_equal(
dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :])
)


def test_jit_and_jacrev_on_integration_na(fspace_fixture_2, mesh_and_disp):
Expand All @@ -578,8 +623,13 @@ def test_jit_and_jacrev_on_integration_na(fspace_fixture_2, mesh_and_disp):
state,
props,
dt,
lambda u, gradu, state, props, X, dt: 0.5 * jnp.tensordot(gradu, gradu),
lambda u, gradu, state, props, X, dt:
0.5 * jnp.tensordot(gradu, gradu),
)
nNodes = mesh.coords.shape[0]
interiorNodeIds = jnp.setdiff1d(jnp.arange(nNodes), mesh.nodeSets["all_boundary"])
jnp.array_equal(dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :]))
interiorNodeIds = jnp.setdiff1d(
jnp.arange(nNodes), mesh.nodeSets["all_boundary"]
)
jnp.array_equal(
dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :])
)
Loading
Loading