diff --git a/scripts/run_linter.sh b/scripts/run_linter.sh index d270d5c..0564d1c 100755 --- a/scripts/run_linter.sh +++ b/scripts/run_linter.sh @@ -1 +1,2 @@ flake8 pancax +flake8 test diff --git a/test/fem/test_function_space.py b/test/fem/test_function_space.py index a528251..2f3b4cb 100644 --- a/test/fem/test_function_space.py +++ b/test/fem/test_function_space.py @@ -15,7 +15,9 @@ yRange = [0.0, 1.0] # targetDispGrad = jnp.array([[0.1, -0.2],[0.4, -0.1]]) -# mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x: jnp.dot(targetDispGrad, x)) +# mesh, U = \ +# create_mesh_and_disp(Nx, Ny, xRange, yRange, \ +# lambda x: jnp.dot(targetDispGrad, x)) @pytest.fixture @@ -67,17 +69,25 @@ def fspace_fixture_2(mesh_and_disp): return fs, fs_na, quadratureRule, state, props, dt -def test_element_volume_single_point_quadrature(fspace_fixture_1, mesh_and_disp): +def test_element_volume_single_point_quadrature( + fspace_fixture_1, + mesh_and_disp +): import jax.numpy as jnp fs, _, _, _, _, _ = fspace_fixture_1 mesh, _, _ = mesh_and_disp elementVols = jnp.sum(fs.vols, axis=1) nElements = mesh.num_elements - jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))) + jnp.array_equal( + elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1)) + ) -def test_element_volume_single_point_quadrature_na(fspace_fixture_1, mesh_and_disp): +def test_element_volume_single_point_quadrature_na( + fspace_fixture_1, + mesh_and_disp +): import jax import jax.numpy as jnp @@ -86,10 +96,15 @@ def test_element_volume_single_point_quadrature_na(fspace_fixture_1, mesh_and_di X_els = mesh.coords[mesh.conns, :] elementVols = jnp.sum(jax.vmap(fs_na.JxWs)(X_els), axis=1) nElements = mesh.num_elements - jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))) + jnp.array_equal( + elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1)) + ) -def test_linear_reproducing_single_point_quadrature(fspace_fixture_1, mesh_and_disp): +def test_linear_reproducing_single_point_quadrature( + fspace_fixture_1, + mesh_and_disp +): from pancax.fem.function_space import compute_field_gradient import jax.numpy as jnp @@ -102,7 +117,10 @@ def test_linear_reproducing_single_point_quadrature(fspace_fixture_1, mesh_and_d assert jnp.allclose(dispGrads, exact) -def test_linear_reproducing_single_point_quadrature_na(fspace_fixture_1, mesh_and_disp): +def test_linear_reproducing_single_point_quadrature_na( + fspace_fixture_1, + mesh_and_disp +): import jax import jax.numpy as jnp @@ -110,7 +128,9 @@ def test_linear_reproducing_single_point_quadrature_na(fspace_fixture_1, mesh_an mesh, U, targetDispGrad = mesh_and_disp X_els = mesh.coords[mesh.conns, :] U_els = U[mesh.conns, :] - dispGrads = jax.vmap(fs_na.compute_field_gradient, in_axes=(0, 0))(U_els, X_els) + dispGrads = jax.vmap( + fs_na.compute_field_gradient, in_axes=(0, 0) + )(U_els, X_els) nElements = mesh.num_elements npts = quadratureRule.xigauss.shape[0] exact = jnp.tile(targetDispGrad, (nElements, npts, 1, 1)) @@ -233,17 +253,24 @@ def test_integrate_linear_field_single_point_quadrature_na( jnp.isclose(Iy, expected) -def test_element_volume_multi_point_quadrature(fspace_fixture_2, mesh_and_disp): +def test_element_volume_multi_point_quadrature( + fspace_fixture_2, + mesh_and_disp +): import jax.numpy as jnp fs, _, _, _, _, _ = fspace_fixture_2 mesh, U, _ = mesh_and_disp elementVols = jnp.sum(fs.vols, axis=1) nElements = mesh.num_elements - jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))) + jnp.array_equal( + elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1)) + ) -def test_element_volume_multi_point_quadrature_na(fspace_fixture_2, mesh_and_disp): +def test_element_volume_multi_point_quadrature_na( + fspace_fixture_2, mesh_and_disp +): import jax import jax.numpy as jnp @@ -252,10 +279,15 @@ def test_element_volume_multi_point_quadrature_na(fspace_fixture_2, mesh_and_dis X_els = mesh.coords[mesh.conns, :] elementVols = jnp.sum(jax.vmap(fs_na.JxWs)(X_els), axis=1) nElements = mesh.num_elements - jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))) + jnp.array_equal( + elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1)) + ) -def test_linear_reproducing_multi_point_quadrature(fspace_fixture_2, mesh_and_disp): +def test_linear_reproducing_multi_point_quadrature( + fspace_fixture_2, + mesh_and_disp +): from pancax.fem.function_space import compute_field_gradient import jax.numpy as jnp @@ -268,7 +300,9 @@ def test_linear_reproducing_multi_point_quadrature(fspace_fixture_2, mesh_and_di assert jnp.allclose(dispGrads, exact) -def test_linear_reproducing_multi_point_quadrature_na(fspace_fixture_2, mesh_and_disp): +def test_linear_reproducing_multi_point_quadrature_na( + fspace_fixture_2, mesh_and_disp +): import jax import jax.numpy as jnp @@ -276,7 +310,9 @@ def test_linear_reproducing_multi_point_quadrature_na(fspace_fixture_2, mesh_and mesh, U, targetDispGrad = mesh_and_disp X_els = mesh.coords[mesh.conns, :] U_els = U[mesh.conns, :] - dispGrads = jax.vmap(fs_na.compute_field_gradient, in_axes=(0, 0))(U_els, X_els) + dispGrads = jax.vmap( + fs_na.compute_field_gradient, in_axes=(0, 0) + )(U_els, X_els) nElements = mesh.num_elements npts = quadratureRule.xigauss.shape[0] exact = jnp.tile(targetDispGrad, (nElements, npts, 1, 1)) @@ -327,7 +363,10 @@ def test_integrate_constant_field_multi_point_quadrature_na( # TODO add integration test/method for new na fspace -def test_integrate_linear_field_multi_point_quadrature(fspace_fixture_2, mesh_and_disp): +def test_integrate_linear_field_multi_point_quadrature( + fspace_fixture_2, + mesh_and_disp +): from pancax.fem.function_space import integrate_over_block import jax.numpy as jnp @@ -512,7 +551,7 @@ def test_jit_on_integration(fspace_fixture_2, mesh_and_disp): fs, _, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp integrate_jit = jax.jit(integrate_over_block, static_argnums=(6,)) - I = integrate_jit( + Ival = integrate_jit( fs, U, mesh.coords, @@ -522,7 +561,7 @@ def test_jit_on_integration(fspace_fixture_2, mesh_and_disp): lambda u, gradu, state, props, X, dt: 1.0, mesh.blocks["block"], ) - jnp.isclose(I, 1.0) + jnp.isclose(Ival, 1.0) def test_jit_on_integration_na(fspace_fixture_2, mesh_and_disp): @@ -534,10 +573,11 @@ def test_jit_on_integration_na(fspace_fixture_2, mesh_and_disp): integrate_jit = eqx.filter_jit(fs_na.integrate_on_elements) U_els = U[mesh.conns[mesh.blocks["block"]], :] X_els = U[mesh.conns[mesh.blocks["block"]], :] - I = integrate_jit( - U_els, X_els, state, props, dt, lambda u, gradu, state, props, X, dt: 1.0 + Ival = integrate_jit( + U_els, X_els, state, props, dt, + lambda u, gradu, state, props, X, dt: 1.0 ) - jnp.isclose(I, 1.0) + jnp.isclose(Ival, 1.0) def test_jit_and_jacrev_on_integration(fspace_fixture_2, mesh_and_disp): @@ -555,12 +595,17 @@ def test_jit_and_jacrev_on_integration(fspace_fixture_2, mesh_and_disp): state, props, dt, - lambda u, gradu, state, props, X, dt: 0.5 * jnp.tensordot(gradu, gradu), + lambda u, gradu, state, props, X, dt: + 0.5 * jnp.tensordot(gradu, gradu), mesh.blocks["block"], ) nNodes = mesh.coords.shape[0] - interiorNodeIds = jnp.setdiff1d(jnp.arange(nNodes), mesh.nodeSets["all_boundary"]) - jnp.array_equal(dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :])) + interiorNodeIds = jnp.setdiff1d( + jnp.arange(nNodes), mesh.nodeSets["all_boundary"] + ) + jnp.array_equal( + dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :]) + ) def test_jit_and_jacrev_on_integration_na(fspace_fixture_2, mesh_and_disp): @@ -578,8 +623,13 @@ def test_jit_and_jacrev_on_integration_na(fspace_fixture_2, mesh_and_disp): state, props, dt, - lambda u, gradu, state, props, X, dt: 0.5 * jnp.tensordot(gradu, gradu), + lambda u, gradu, state, props, X, dt: + 0.5 * jnp.tensordot(gradu, gradu), ) nNodes = mesh.coords.shape[0] - interiorNodeIds = jnp.setdiff1d(jnp.arange(nNodes), mesh.nodeSets["all_boundary"]) - jnp.array_equal(dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :])) + interiorNodeIds = jnp.setdiff1d( + jnp.arange(nNodes), mesh.nodeSets["all_boundary"] + ) + jnp.array_equal( + dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :]) + ) diff --git a/test/fem/test_mesh.py b/test/fem/test_mesh.py index ab962c0..c08f622 100644 --- a/test/fem/test_mesh.py +++ b/test/fem/test_mesh.py @@ -1,17 +1,8 @@ -from pancax import fem -from pancax import create_nodesets_from_sidesets -from .utils import create_mesh_and_disp -import jax.numpy as np -import numpy as onp - -Nx = 3 -Ny = 2 -xRange = [0.0, 1.0] -yRange = [0.0, 1.0] -targetDispGrad = np.array([[0.1, -0.2], [0.4, -0.1]]) +import pytest def triangle_inradius(tcoords): + import numpy as onp tcoords = onp.hstack((tcoords, onp.ones((tcoords.shape[0], 1)))) area = 0.5 * onp.cross(tcoords[1] - tcoords[0], tcoords[2] - tcoords[0])[2] @@ -23,26 +14,48 @@ def triangle_inradius(tcoords): return area / peri -mesh, U = create_mesh_and_disp( - Nx, Ny, xRange, yRange, lambda x: np.dot(targetDispGrad, x) -) +@pytest.fixture +def mesh_fix(): + from .utils import create_mesh_and_disp + import jax.numpy as jnp + Nx = 3 + Ny = 2 + xRange = [0.0, 1.0] + yRange = [0.0, 1.0] + targetDispGrad = jnp.array([[0.1, -0.2], [0.4, -0.1]]) + + mesh, U = create_mesh_and_disp( + Nx, Ny, xRange, yRange, lambda x: jnp.dot(targetDispGrad, x) + ) + return mesh, U -def test_create_nodesets_from_sidesets(): - # mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x: np.dot(targetDispGrad, x)) +def test_create_nodesets_from_sidesets(mesh_fix): + from pancax import create_nodesets_from_sidesets + import jax.numpy as jnp + mesh, _ = mesh_fix + # mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, + # lambda x: np.dot(targetDispGrad, x)) nodeSets = create_nodesets_from_sidesets(mesh) # this test relies on the fact that matching nodesets # and sidesets were created on the MeshFixture for key in mesh.sideSets: - assert np.array_equal(mesh.nodeSets[key], nodeSets[key]) + assert jnp.array_equal(mesh.nodeSets[key], nodeSets[key]) + +def test_edge_connectivities(mesh_fix): + from pancax import fem + import jax.numpy as jnp + import numpy as onp -def test_edge_connectivities(): + mesh, _ = mesh_fix edgeConns, _ = fem.create_edges(mesh.conns) - goldBoundaryEdgeConns = np.array([[0, 1], [1, 2], [2, 5], [5, 4], [4, 3], [3, 0]]) + goldBoundaryEdgeConns = jnp.array([ + [0, 1], [1, 2], [2, 5], [5, 4], [4, 3], [3, 0] + ]) # Check that every boundary edge has been found. # Boundary edges must appear with the same connectivity order, @@ -63,22 +76,31 @@ def test_edge_connectivities(): # sense the vertices should be ordered, so we check # for both permutations. - goldInteriorEdgeConns = np.array([[0, 4], [1, 4], [1, 5]]) + goldInteriorEdgeConns = jnp.array([[0, 4], [1, 4], [1, 5]]) nInteriorEdges = goldInteriorEdgeConns.shape[0] interiorEdgeFound = onp.full(nInteriorEdges, False) for i, e in enumerate(goldInteriorEdgeConns): foundWithSameSense = onp.any(onp.all(edgeConns == e, axis=1)) - foundWithOppositeSense = onp.any(onp.all(edgeConns == onp.flip(e), axis=1)) + foundWithOppositeSense = onp.any( + onp.all(edgeConns == onp.flip(e), axis=1) + ) interiorEdgeFound[i] = foundWithSameSense or foundWithOppositeSense assert onp.all(interiorEdgeFound) -def test_edge_to_neighbor_cells_data(): +def test_edge_to_neighbor_cells_data(mesh_fix): + from pancax import fem + import jax.numpy as jnp + import numpy as onp + + mesh, _ = mesh_fix edgeConns, edges = fem.create_edges(mesh.conns) - goldBoundaryEdgeConns = np.array([[0, 1], [1, 2], [2, 5], [5, 4], [4, 3], [3, 0]]) + goldBoundaryEdgeConns = jnp.array([ + [0, 1], [1, 2], [2, 5], [5, 4], [4, 3], [3, 0] + ]) goldBoundaryEdges = onp.array( [ @@ -92,15 +114,17 @@ def test_edge_to_neighbor_cells_data(): ) for be, bc in zip(goldBoundaryEdges, goldBoundaryEdgeConns): - i = np.where(onp.all(edgeConns == bc, axis=1)) - assert np.all(edges[i, :] == be) + i = jnp.where(onp.all(edgeConns == bc, axis=1)) + assert jnp.all(edges[i, :] == be) - goldInteriorEdgeConns = np.array([[0, 4], [1, 4], [5, 1]]) + goldInteriorEdgeConns = jnp.array([[0, 4], [1, 4], [5, 1]]) goldInteriorEdges = onp.array([[1, 0, 0, 2], [0, 1, 3, 2], [2, 2, 3, 0]]) for ie, ic in zip(goldInteriorEdges, goldInteriorEdgeConns): foundWithSameSense = onp.any(onp.all(edgeConns == ic, axis=1)) - foundWithOppositeSense = onp.any(onp.all(edgeConns == onp.flip(ic), axis=1)) + foundWithOppositeSense = onp.any( + onp.all(edgeConns == onp.flip(ic), axis=1) + ) edgeDataMatches = False if foundWithSameSense: i = onp.where(onp.all(edgeConns == ic, axis=1)) @@ -111,22 +135,30 @@ def test_edge_to_neighbor_cells_data(): else: # self.fail('edge not found with vertices ' + str(ic)) print("Need to raise an exception test here") - edgeDataMatches = np.all(edges[i, :] == edgeData) + edgeDataMatches = jnp.all(edges[i, :] == edgeData) assert edgeDataMatches -def test_conversion_to_quadratic_mesh_is_valid(): +def test_conversion_to_quadratic_mesh_is_valid(mesh_fix): + from pancax import fem + import jax.numpy as jnp + + mesh, _ = mesh_fix newMesh = fem.mesh.create_higher_order_mesh_from_simplex_mesh(mesh, 2) nNodes = newMesh.coords.shape[0] assert nNodes == 15 # make sure all of the newly created nodes got used in the connectivity - assert np.array_equal(np.unique(newMesh.conns.ravel()), np.arange(nNodes)) + assert jnp.array_equal( + jnp.unique(newMesh.conns.ravel()), jnp.arange(nNodes) + ) # check that all triangles are valid: - # compute inradius of each triangle and of the sub-triangle of the mid-edge nodes - # Both should be nonzero, and parent inradius should be 2x sub-triangle inradius + # compute inradius of each triangle + # and of the sub-triangle of the mid-edge nodes + # Both should be nonzero, and parent inradius + # should be 2x sub-triangle inradius master = newMesh.parentElement for t in newMesh.conns: elCoords = newMesh.coords[t, :] @@ -139,4 +171,4 @@ def test_conversion_to_quadratic_mesh_is_valid(): assert parentArea > 0.0 assert childArea > 0.0 - assert np.abs(parentArea - 2.0 * childArea) < 1e-10 + assert jnp.abs(parentArea - 2.0 * childArea) < 1e-10 diff --git a/test/fem/test_quadrature_rules.py b/test/fem/test_quadrature_rules.py index 0c0621f..b6e54d9 100644 --- a/test/fem/test_quadrature_rules.py +++ b/test/fem/test_quadrature_rules.py @@ -1,12 +1,9 @@ -from pancax.fem import QuadratureRule -from pancax.fem.elements import * -from scipy.special import binom -import jax.numpy as jnp import pytest # integrate x^n y^m on unit triangle def integrate_2D_monomial_on_triangle(n, m): + from scipy.special import binom p = n + m return 1.0 / ((p + 2) * (p + 1) * binom(p, n)) @@ -38,69 +35,92 @@ def is_inside_triangle(point): def is_inside_unit_interval(point): + import jax.numpy as jnp return jnp.all(point >= 0.0) and jnp.all(point <= 1.0) -elements_to_test = [] -q_degrees = [] -in_domain_methods = [] -for q in range(1, 2 + 1): - elements_to_test.append(Hex8Element()) - q_degrees.append(q) - in_domain_methods.append(is_inside_hex) - -for q in range(1, 25 + 1): - elements_to_test.append(LineElement(1)) - q_degrees.append(q) - in_domain_methods.append(is_inside_unit_interval) - -for q in range(1, 3 + 1): - elements_to_test.append(Quad4Element()) - elements_to_test.append(Quad9Element()) - q_degrees.append(q) - q_degrees.append(q) - in_domain_methods.append(is_inside_quad) - in_domain_methods.append(is_inside_quad) - -for q in range(1, 10 + 1): - elements_to_test.append(SimplexTriElement(1)) - q_degrees.append(q) - in_domain_methods.append(is_inside_unit_interval) - -for q in range(1, 2 + 1): - elements_to_test.append(Tet4Element()) - q_degrees.append(q) - in_domain_methods.append(is_inside_tet) - - -for q in range(1, 2 + 1): - elements_to_test.append(Tet10Element()) - q_degrees.append(q) - in_domain_methods.append(is_inside_tet) - - -@pytest.mark.parametrize("el, q", zip(elements_to_test, q_degrees)) -def test_are_postive_weights(el, q): - if type(el) == Tet4Element and q == 2: - pytest.skip("Not relevant for Tet4Element and q_degree = 2") - if type(el) == Tet10Element and q == 2: - pytest.skip("Not relevant for Tet10Element and q_degree = 2") - qr = QuadratureRule(el, q) - _, w = qr - assert jnp.all(w > 0.0) - - -@pytest.mark.parametrize( - "el, q, is_inside", zip(elements_to_test, q_degrees, in_domain_methods) -) -def test_are_inside_domain(el, q, is_inside): - qr = QuadratureRule(el, q) - for point in qr.xigauss: - assert is_inside(point) +@pytest.fixture +def quadratures_fix(): + from pancax.fem.elements import Hex8Element + from pancax.fem.elements import LineElement + from pancax.fem.elements import Quad4Element, Quad9Element + from pancax.fem.elements import SimplexTriElement + from pancax.fem.elements import Tet4Element, Tet10Element + + elements_to_test = [] + q_degrees = [] + in_domain_methods = [] + for q in range(1, 2 + 1): + elements_to_test.append(Hex8Element()) + q_degrees.append(q) + in_domain_methods.append(is_inside_hex) + + for q in range(1, 25 + 1): + elements_to_test.append(LineElement(1)) + q_degrees.append(q) + in_domain_methods.append(is_inside_unit_interval) + + for q in range(1, 3 + 1): + elements_to_test.append(Quad4Element()) + elements_to_test.append(Quad9Element()) + q_degrees.append(q) + q_degrees.append(q) + in_domain_methods.append(is_inside_quad) + in_domain_methods.append(is_inside_quad) + + for q in range(1, 10 + 1): + elements_to_test.append(SimplexTriElement(1)) + q_degrees.append(q) + in_domain_methods.append(is_inside_unit_interval) + + for q in range(1, 2 + 1): + elements_to_test.append(Tet4Element()) + q_degrees.append(q) + in_domain_methods.append(is_inside_tet) + + for q in range(1, 2 + 1): + elements_to_test.append(Tet10Element()) + q_degrees.append(q) + in_domain_methods.append(is_inside_tet) + + return elements_to_test, in_domain_methods, q_degrees + + +def test_are_postive_weights(quadratures_fix): + from pancax.fem import QuadratureRule + from pancax.fem.elements import Tet4Element, Tet10Element + import jax.numpy as jnp + + els, _, qrs = quadratures_fix + for el, q in zip(els, qrs): + if type(el) is Tet4Element and q == 2: + pytest.skip("Not relevant for Tet4Element and q_degree = 2") + if type(el) is Tet10Element and q == 2: + pytest.skip("Not relevant for Tet10Element and q_degree = 2") + qr = QuadratureRule(el, q) + _, w = qr + assert jnp.all(w > 0.0) + + +# @pytest.mark.parametrize( +# "el, q, is_inside", zip(elements_to_test, q_degrees, in_domain_methods) +# ) +# def test_are_inside_domain(el, q, is_inside): +def test_are_inside_domain(quadratures_fix): + from pancax.fem import QuadratureRule + els, is_insides, qrs = quadratures_fix + for el, q, is_inside in zip(els, qrs, is_insides): + qr = QuadratureRule(el, q) + for point in qr.xigauss: + assert is_inside(point) # TODO need general test method for other element formulations def test_triangle_quadrature_exactness(): + from pancax.fem import QuadratureRule + from pancax.fem.elements import SimplexTriElement + import jax.numpy as jnp + max_degree = 10 for degree in range(1, max_degree + 1): qr = QuadratureRule(SimplexTriElement(1), degree) @@ -113,6 +133,8 @@ def test_triangle_quadrature_exactness(): def test_len_method(): + from pancax.fem import QuadratureRule + from pancax.fem.elements import Hex8Element qr = QuadratureRule(Hex8Element(), 1) assert len(qr) == 1 qr = QuadratureRule(Hex8Element(), 2) @@ -121,19 +143,26 @@ def test_len_method(): # error checks def test_error_raise_on_bad_element(): + from pancax.fem import QuadratureRule with pytest.raises(TypeError): - qr = QuadratureRule(dict(), 1) + QuadratureRule(dict(), 1) def test_error_raise_on_bad_quadrature_degree(): + from pancax.fem import QuadratureRule + from pancax.fem.elements import Hex8Element + from pancax.fem.elements import Quad4Element + from pancax.fem.elements import SimplexTriElement + from pancax.fem.elements import Tet4Element + with pytest.raises(ValueError): - qr = QuadratureRule(Hex8Element(), 3) + QuadratureRule(Hex8Element(), 3) with pytest.raises(ValueError): - qr = QuadratureRule(Quad4Element(), 4) + QuadratureRule(Quad4Element(), 4) with pytest.raises(ValueError): - qr = QuadratureRule(SimplexTriElement(1), 11) + QuadratureRule(SimplexTriElement(1), 11) with pytest.raises(ValueError): - qr = QuadratureRule(Tet4Element(), 3) + QuadratureRule(Tet4Element(), 3) diff --git a/test/fem/test_read_exodus_mesh.py b/test/fem/test_read_exodus_mesh.py index 39f798a..6aef061 100644 --- a/test/fem/test_read_exodus_mesh.py +++ b/test/fem/test_read_exodus_mesh.py @@ -1,27 +1,43 @@ -from pancax.fem import read_exodus_mesh -import os - - def test_read_exodus_mesh_hex8(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_hex8.g") - mesh = read_exodus_mesh(f) + from pancax.fem import read_exodus_mesh + import os + f = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "mesh_hex8.g" + ) + read_exodus_mesh(f) def test_read_exodus_mesh_quad4(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_quad4.g") - mesh = read_exodus_mesh(f) + from pancax.fem import read_exodus_mesh + import os + f = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "mesh_quad4.g" + ) + read_exodus_mesh(f) def test_read_exodus_mesh_quad9(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_quad9.g") - mesh = read_exodus_mesh(f) + from pancax.fem import read_exodus_mesh + import os + f = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "mesh_quad9.g" + ) + read_exodus_mesh(f) def test_read_exodus_mesh_tri(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_no_ssets.g") - mesh = read_exodus_mesh(f) + from pancax.fem import read_exodus_mesh + import os + f = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "mesh_no_ssets.g" + ) + read_exodus_mesh(f) def test_read_exodus_mesh_with_ssets_tri(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_1x.g") - mesh = read_exodus_mesh(f) + from pancax.fem import read_exodus_mesh + import os + f = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "mesh_1x.g" + ) + read_exodus_mesh(f) diff --git a/test/fem/test_surface.py b/test/fem/test_surface.py index d34d6ad..5c120dc 100644 --- a/test/fem/test_surface.py +++ b/test/fem/test_surface.py @@ -1,43 +1,57 @@ -from pancax.fem import LineElement, QuadratureRule -from pancax.fem.surface import integrate_function_on_surface -from .utils import create_mesh_and_disp -import jax.numpy as jnp +import pytest -Nx = 4 -Ny = 4 -L = 1.2 -W = 1.5 -xRange = [0.0, L] -yRange = [0.0, W] +@pytest.fixture +def surf_mesh_fix(): + from pancax.fem import LineElement, QuadratureRule + from .utils import create_mesh_and_disp + import jax.numpy as jnp + Nx = 4 + Ny = 4 + L = 1.2 + W = 1.5 + xRange = [0.0, L] + yRange = [0.0, W] -targetDispGrad = jnp.zeros((2, 2)) + targetDispGrad = jnp.zeros((2, 2)) -mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x: targetDispGrad.dot(x)) - - -quadRule = QuadratureRule(LineElement(1), 2) + mesh, U = create_mesh_and_disp( + Nx, Ny, xRange, yRange, lambda x: targetDispGrad.dot(x) + ) + quad_rule = QuadratureRule(LineElement(1), 2) + return mesh, U, L, W, quad_rule -def test_integrate_perimeter(): - print(mesh) +def test_integrate_perimeter(surf_mesh_fix): + from pancax.fem.surface import integrate_function_on_surface + import jax.numpy as jnp + mesh, _, L, W, quad_rule = surf_mesh_fix p = integrate_function_on_surface( - quadRule, mesh.sideSets["all_boundary"], mesh, lambda x, n: 1.0 + quad_rule, mesh.sideSets["all_boundary"], mesh, lambda x, n: 1.0 ) # assertNear(p, 2*(L+W), 14) assert jnp.abs(p - 2 * (L + W)) < 1.0e-14 -def test_integrate_quadratic_fn_on_surface(): - I = integrate_function_on_surface( - quadRule, mesh.sideSets["top"], mesh, lambda x, n: x[0] ** 2 +def test_integrate_quadratic_fn_on_surface(surf_mesh_fix): + from pancax.fem.surface import integrate_function_on_surface + import jax.numpy as jnp + mesh, _, L, _, quad_rule = surf_mesh_fix + Ival = integrate_function_on_surface( + quad_rule, mesh.sideSets["top"], mesh, lambda x, n: x[0] ** 2 ) # assertNear(I, L**3/3.0, 14) - assert jnp.abs(I - L ** (3 / 3.0)) < 1.0e14 - - -def test_integrate_function_on_surface_that_uses_coords_and_normal(): - I = integrate_function_on_surface( - quadRule, mesh.sideSets["all_boundary"], mesh, lambda x, n: jnp.dot(x, n) + assert jnp.abs(Ival - L ** (3 / 3.0)) < 1.0e14 + + +def test_integrate_function_on_surface_that_uses_coords_and_normal( + surf_mesh_fix +): + from pancax.fem.surface import integrate_function_on_surface + import jax.numpy as jnp + mesh, _, L, W, quad_rule = surf_mesh_fix + Ival = integrate_function_on_surface( + quad_rule, mesh.sideSets["all_boundary"], mesh, + lambda x, n: jnp.dot(x, n) ) - assert jnp.abs(I - 2 * L * W) < 1e-14 + assert jnp.abs(Ival - 2 * L * W) < 1e-14 diff --git a/test/fem/utils.py b/test/fem/utils.py index fa00e12..ad3cf52 100644 --- a/test/fem/utils.py +++ b/test/fem/utils.py @@ -1,19 +1,27 @@ from jax import vmap -from pancax.fem import construct_mesh_from_basic_data, create_structured_mesh_data +from pancax.fem import ( + construct_mesh_from_basic_data, + create_structured_mesh_data +) from pancax.fem.surface import create_edges import jax.numpy as jnp -def create_mesh_and_disp(Nx, Ny, xRange, yRange, initial_disp_func, setNamePostFix=""): +def create_mesh_and_disp( + Nx, Ny, xRange, yRange, initial_disp_func, setNamePostFix="" +): coords, conns = create_structured_mesh_data(Nx, Ny, xRange, yRange) tol = 1e-7 nodeSets = {} - nodeSets["left" + setNamePostFix] = jnp.flatnonzero(coords[:, 0] < xRange[0] + tol) + nodeSets["left" + setNamePostFix] = \ + jnp.flatnonzero(coords[:, 0] < xRange[0] + tol) nodeSets["bottom" + setNamePostFix] = jnp.flatnonzero( coords[:, 1] < yRange[0] + tol ) - nodeSets["right" + setNamePostFix] = jnp.flatnonzero(coords[:, 0] > xRange[1] - tol) - nodeSets["top" + setNamePostFix] = jnp.flatnonzero(coords[:, 1] > yRange[1] - tol) + nodeSets["right" + setNamePostFix] = \ + jnp.flatnonzero(coords[:, 0] > xRange[1] - tol) + nodeSets["top" + setNamePostFix] = \ + jnp.flatnonzero(coords[:, 1] > yRange[1] - tol) nodeSets["all_boundary" + setNamePostFix] = jnp.flatnonzero( (coords[:, 0] < xRange[0] + tol) | (coords[:, 1] < yRange[0] + tol) @@ -34,15 +42,21 @@ def is_edge_on_top(xyOnEdge): return jnp.all(xyOnEdge[:, 1] > yRange[1] - tol) sideSets = {} - sideSets["left" + setNamePostFix] = create_edges(coords, conns, is_edge_on_left) - sideSets["bottom" + setNamePostFix] = create_edges(coords, conns, is_edge_on_bottom) - sideSets["right" + setNamePostFix] = create_edges(coords, conns, is_edge_on_right) - sideSets["top" + setNamePostFix] = create_edges(coords, conns, is_edge_on_top) + sideSets["left" + setNamePostFix] = \ + create_edges(coords, conns, is_edge_on_left) + sideSets["bottom" + setNamePostFix] = \ + create_edges(coords, conns, is_edge_on_bottom) + sideSets["right" + setNamePostFix] = \ + create_edges(coords, conns, is_edge_on_right) + sideSets["top" + setNamePostFix] = \ + create_edges(coords, conns, is_edge_on_top) print(sideSets.values()) allBoundaryEdges = jnp.vstack([s for s in sideSets.values()]) sideSets["all_boundary" + setNamePostFix] = allBoundaryEdges blocks = {"block" + setNamePostFix: jnp.arange(conns.shape[0])} - mesh = construct_mesh_from_basic_data(coords, conns, blocks, nodeSets, sideSets) + mesh = construct_mesh_from_basic_data( + coords, conns, blocks, nodeSets, sideSets + ) print(mesh) return mesh, vmap(initial_disp_func)(mesh.coords) diff --git a/test/math/test_math.py b/test/math/test_math.py index cac564c..a5ae837 100644 --- a/test/math/test_math.py +++ b/test/math/test_math.py @@ -5,15 +5,15 @@ def test_safe_sqrt(): f = math.safe_sqrt(4.0) - assert jnp.allclose(f, 2.) + assert jnp.allclose(f, 2.0) - df = jax.grad(math.safe_sqrt)(4.) + df = jax.grad(math.safe_sqrt)(4.0) assert jnp.allclose(df, 0.25) f = math.safe_sqrt(-4.0) assert jnp.isnan(f) - df = jax.grad(math.safe_sqrt)(-4.) + df = jax.grad(math.safe_sqrt)(-4.0) assert jnp.allclose(df, 0.0) diff --git a/test/math/test_tensor_math.py b/test/math/test_tensor_math.py index acc1bb4..b939b37 100644 --- a/test/math/test_tensor_math.py +++ b/test/math/test_tensor_math.py @@ -1,4 +1,3 @@ - from jax import numpy as np from jax.scipy import linalg from jax.test_util import check_grads @@ -9,50 +8,49 @@ R = Rotation.random(random_state=41).as_matrix() + def numerical_grad(f): def lam(A): - df = np.zeros((3,3)) + df = np.zeros((3, 3)) eps = 1e-7 ff = f(A) for i in range(3): for j in range(3): - Ap = A.at[i,j].add(eps) + Ap = A.at[i, j].add(eps) fp = f(Ap) - fprime = (fp-ff)/eps - df = df.at[i,j].add(fprime) + fprime = (fp - ff) / eps + df = df.at[i, j].add(fprime) return df + return lam def generate_n_random_symmetric_matrices(n, minval=0.0, maxval=1.0): key = jax.random.PRNGKey(0) - As = jax.random.uniform(key, (n,3,3), minval=minval, maxval=maxval) - return jax.vmap(lambda A: np.dot(A.T,A), (0,))(As) + As = jax.random.uniform(key, (n, 3, 3), minval=minval, maxval=maxval) + return jax.vmap(lambda A: np.dot(A.T, A), (0,))(As) + + +def log_squared(A): + return np.tensordot(tensor_math.log_sqrt(A), tensor_math.log_sqrt(A)) -log_squared = lambda A: np.tensordot(tensor_math.log_sqrt(A), tensor_math.log_sqrt(A)) sqrtm_jit = jax.jit(tensor_math.sqrtm) logm_iss_jit = jax.jit(tensor_math.logm_iss) def test_log_sqrt_tensor_jvp_0(): - A = np.array([[2.0, 0.0, 0.0], - [0.0, 1.2, 0.0], - [0.0, 0.0, 2.0]]) + A = np.array([[2.0, 0.0, 0.0], [0.0, 1.2, 0.0], [0.0, 0.0, 2.0]]) check_grads(log_squared, (A,), order=1) - - + + def test_log_sqrt_tensor_jvp_1(): - A = np.array([[2.0, 0.0, 0.0], - [0.0, 1.2, 0.0], - [0.0, 0.0, 3.0]]) + A = np.array([[2.0, 0.0, 0.0], [0.0, 1.2, 0.0], [0.0, 0.0, 3.0]]) check_grads(log_squared, (A,), order=1) - + def test_log_sqrt_tensor_jvp_2(): - A = np.array([[2.0, 0.0, 0.2], - [0.0, 1.2, 0.1], - [0.2, 0.1, 3.0]]) + A = np.array([[2.0, 0.0, 0.2], [0.0, 1.2, 0.1], [0.2, 0.1, 3.0]]) check_grads(log_squared, (A,), order=1) @@ -60,47 +58,53 @@ def test_log_sqrt_tensor_jvp_2(): # def test_log_sqrt_hessian_on_double_degenerate_eigenvalues(self): # eigvals = np.array([2., 0.5, 2.]) # C = R@np.diag(eigvals)@R.T -# check_grads(jax.jacrev(TensorMath.log_sqrt), (C,), order=1, modes=['fwd'], rtol=1e-9, atol=1e-9, eps=1e-5) +# check_grads(jax.jacrev( +# TensorMath.log_sqrt), (C,), order=1, +# modes=['fwd'], rtol=1e-9, atol=1e-9, eps=1e-5) def test_eigen_sym33_non_unit(): key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) - C = F.T@F - d,vecs = tensor_math.eigen_sym33_unit(C) + F = jax.random.uniform(key, (3, 3), minval=1e-8, maxval=10.0) + C = F.T @ F + d, vecs = tensor_math.eigen_sym33_unit(C) np.array_equal(C, vecs @ np.diag(d) @ vecs.T) np.array_equal(vecs @ vecs.T, np.identity(3)) - + def test_eigen_sym33_non_unit_degenerate_case(): - C = 5.0*np.identity(3) - d,vecs = tensor_math.eigen_sym33_unit(C) + C = 5.0 * np.identity(3) + d, vecs = tensor_math.eigen_sym33_unit(C) np.array_equal(C, vecs @ np.diag(d) @ vecs.T) np.array_equal(vecs @ vecs.T, np.identity(3)) - -### mtk_log_sqrt tests ### - - + + +# mtk_log_sqrt tests # + + def test_log_sqrt_scaled_identity(): val = 1.2 C = np.diag(np.array([val, val, val])) logSqrtVal = np.log(np.sqrt(val)) - np.array_equal(tensor_math.mtk_log_sqrt(C), np.diag(np.array([logSqrtVal, logSqrtVal, logSqrtVal]))) + np.array_equal( + tensor_math.mtk_log_sqrt(C), + np.diag(np.array([logSqrtVal, logSqrtVal, logSqrtVal])), + ) def test_log_sqrt_double_eigs(): val1 = 2.0 val2 = 0.5 - C = R@np.diag(np.array([val1, val2, val1]))@R.T + C = R @ np.diag(np.array([val1, val2, val1])) @ R.T logSqrt1 = np.log(np.sqrt(val1)) logSqrt2 = np.log(np.sqrt(val2)) diagLogSqrt = np.diag(np.array([logSqrt1, logSqrt2, logSqrt1])) - logSqrtCExpected = R@diagLogSqrt@R.T + logSqrtCExpected = R @ diagLogSqrt @ R.T np.array_equal(tensor_math.mtk_log_sqrt(C), logSqrtCExpected) - + def test_log_sqrt_squared_grad_scaled_identity(): val = 1.2 C = np.diag(np.array([val, val, val])) @@ -108,32 +112,35 @@ def test_log_sqrt_squared_grad_scaled_identity(): def log_squared(A): lg = tensor_math.mtk_log_sqrt(A) return np.tensordot(lg, lg) + check_grads(log_squared, (C,), order=1) - - + + def test_log_sqrt_squared_grad_double_eigs(): val1 = 2.0 val2 = 0.5 - C = R@np.diag(np.array([val1, val2, val1]))@R.T + C = R @ np.diag(np.array([val1, val2, val1])) @ R.T def log_squared(A): lg = tensor_math.mtk_log_sqrt(A) return np.tensordot(lg, lg) + check_grads(log_squared, (C,), order=1) - + def test_log_sqrt_squared_grad_rand(): key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) - C = F.T@F + F = jax.random.uniform(key, (3, 3), minval=1e-8, maxval=10.0) + C = F.T @ F def log_squared(A): lg = tensor_math.mtk_log_sqrt(A) return np.tensordot(lg, lg) + check_grads(log_squared, (C,), order=1) - - -### mtk_pow tests ### + + +# mtk_pow tests # def test_pow_scaled_identity(): @@ -142,24 +149,26 @@ def test_pow_scaled_identity(): C = np.diag(np.array([val, val, val])) powVal = np.power(val, m) - np.array_equal(tensor_math.mtk_pow(C,m), np.diag(np.array([powVal, powVal, powVal]))) + np.array_equal( + tensor_math.mtk_pow(C, m), np.diag(np.array([powVal, powVal, powVal])) + ) def test_pow_double_eigs(): m = 0.25 val1 = 2.1 val2 = 0.6 - C = R@np.diag(np.array([val1, val2, val1]))@R.T + C = R @ np.diag(np.array([val1, val2, val1])) @ R.T powVal1 = np.power(val1, m) powVal2 = np.power(val2, m) diagLogSqrt = np.diag(np.array([powVal1, powVal2, powVal1])) - logSqrtCExpected = R@diagLogSqrt@R.T + logSqrtCExpected = R @ diagLogSqrt @ R.T + + np.array_equal(tensor_math.mtk_pow(C, m), logSqrtCExpected) - np.array_equal(tensor_math.mtk_pow(C,m), logSqrtCExpected) - def test_pow_squared_grad_scaled_identity(): val = 1.2 C = np.diag(np.array([val, val, val])) @@ -168,41 +177,44 @@ def pow_squared(A): m = 0.25 lg = tensor_math.mtk_pow(A, m) return np.tensordot(lg, lg) + check_grads(pow_squared, (C,), order=1) def test_pow_squared_grad_double_eigs(): val1 = 2.0 val2 = 0.5 - C = R@np.diag(np.array([val1, val2, val1]))@R.T + C = R @ np.diag(np.array([val1, val2, val1])) @ R.T def pow_squared(A): - m=0.25 + m = 0.25 lg = tensor_math.mtk_pow(A, m) return np.tensordot(lg, lg) + check_grads(pow_squared, (C,), order=1) def test_pow_squared_grad_rand(): key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) - C = F.T@F + F = jax.random.uniform(key, (3, 3), minval=1e-8, maxval=10.0) + C = F.T @ F def pow_squared(A): - m=0.25 + m = 0.25 lg = tensor_math.mtk_pow(A, m) return np.tensordot(lg, lg) + check_grads(pow_squared, (C,), order=1) - - -### sqrtm ### + + +# sqrtm # def test_sqrtm_jit(): C = generate_n_random_symmetric_matrices(1)[0] sqrtC = sqrtm_jit(C) assert not np.isnan(sqrtC).any() - + def test_sqrtm(): mats = generate_n_random_symmetric_matrices(100) @@ -222,7 +234,7 @@ def test_sqrtm_rev_mode_derivative(): def test_sqrtm_on_degenerate_eigenvalues(): - C = R@np.diag(np.array([2., 0.5, 2]))@R.T + C = R @ np.diag(np.array([2.0, 0.5, 2])) @ R.T sqrtC = tensor_math.sqrtm(C) shouldBeC = np.dot(sqrtC, sqrtC) np.array_equal(shouldBeC, C) @@ -231,40 +243,42 @@ def test_sqrtm_on_degenerate_eigenvalues(): def test_sqrtm_on_10x10(): key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0) - C = F.T@F + F = jax.random.uniform(key, (10, 10), minval=1e-8, maxval=10.0) + C = F.T @ F sqrtC = tensor_math.sqrtm(C) - shouldBeC = np.dot(sqrtC,sqrtC) + shouldBeC = np.dot(sqrtC, sqrtC) np.array_equal(shouldBeC, C) def test_sqrtm_derivatives_on_10x10(): key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0) - C = F.T@F + F = jax.random.uniform(key, (10, 10), minval=1e-8, maxval=10.0) + C = F.T @ F check_grads(tensor_math.sqrtm, (C,), order=1, modes=["fwd", "rev"]) def test_logm_iss_on_matrix_near_identity(): key = jax.random.PRNGKey(0) - id_perturbation = 1.0 + jax.random.uniform(key, (3,), minval=1e-8, maxval=0.01) + id_perturbation = 1.0 + jax.random.uniform( + key, (3,), minval=1e-8, maxval=0.01 + ) A = np.diag(id_perturbation) logA = tensor_math.logm_iss(A) np.array_equal(logA, np.diag(np.log(id_perturbation))) def test_logm_iss_on_double_degenerate_eigenvalues(): - eigvals = np.array([2., 0.5, 2.]) - C = R@np.diag(eigvals)@R.T + eigvals = np.array([2.0, 0.5, 2.0]) + C = R @ np.diag(eigvals) @ R.T logC = tensor_math.logm_iss(C) - logCSpectral = R@np.diag(np.log(eigvals))@R.T + logCSpectral = R @ np.diag(np.log(eigvals)) @ R.T np.array_equal(logC, logCSpectral) def test_logm_iss_on_triple_degenerate_eigvalues(): - A = 4.0*np.identity(3) + A = 4.0 * np.identity(3) logA = tensor_math.logm_iss(A) - np.array_equal(logA, np.log(4.0)*np.identity(3)) + np.array_equal(logA, np.log(4.0) * np.identity(3)) def test_logm_iss_jit(): @@ -277,43 +291,51 @@ def test_logm_iss_on_full_3x3s(): mats = generate_n_random_symmetric_matrices(1000) logMats = jax.vmap(logm_iss_jit, (0,))(mats) shouldBeMats = jax.vmap(lambda A: linalg.expm(A), (0,))(logMats) - # self.assertArrayNear(shouldBeMats, mats, 7) - np.array_equal(shouldBeMats, mats) + # self.assertArrayNear(shouldBeMats, mats, 7) + np.array_equal(shouldBeMats, mats) + - def test_logm_iss_fwd_mode_derivative(): C = generate_n_random_symmetric_matrices(1)[0] - check_grads(logm_iss_jit, (C,), order=1, modes=['fwd']) + check_grads(logm_iss_jit, (C,), order=1, modes=["fwd"]) def test_logm_iss_rev_mode_derivative(): C = generate_n_random_symmetric_matrices(1)[0] - check_grads(logm_iss_jit, (C,), order=1, modes=['rev']) + check_grads(logm_iss_jit, (C,), order=1, modes=["rev"]) def test_logm_iss_hessian_on_double_degenerate_eigenvalues(): - eigvals = np.array([2., 0.5, 2.]) - C = R@np.diag(eigvals)@R.T - check_grads(jax.jacrev(tensor_math.logm_iss), (C,), order=1, modes=['fwd'], rtol=1e-9, atol=1e-9, eps=1e-5) + eigvals = np.array([2.0, 0.5, 2.0]) + C = R @ np.diag(eigvals) @ R.T + check_grads( + jax.jacrev(tensor_math.logm_iss), + (C,), + order=1, + modes=["fwd"], + rtol=1e-9, + atol=1e-9, + eps=1e-5, + ) def test_logm_iss_derivatives_on_double_degenerate_eigenvalues(): - eigvals = np.array([2., 0.5, 2.]) - C = R@np.diag(eigvals)@R.T - check_grads(tensor_math.logm_iss, (C,), order=1, modes=['fwd']) - check_grads(tensor_math.logm_iss, (C,), order=1, modes=['rev']) + eigvals = np.array([2.0, 0.5, 2.0]) + C = R @ np.diag(eigvals) @ R.T + check_grads(tensor_math.logm_iss, (C,), order=1, modes=["fwd"]) + check_grads(tensor_math.logm_iss, (C,), order=1, modes=["rev"]) def test_logm_iss_derivatives_on_triple_degenerate_eigenvalues(): - A = 4.0*np.identity(3) - check_grads(tensor_math.logm_iss, (A,), order=1, modes=['fwd']) - check_grads(tensor_math.logm_iss, (A,), order=1, modes=['rev']) + A = 4.0 * np.identity(3) + check_grads(tensor_math.logm_iss, (A,), order=1, modes=["fwd"]) + check_grads(tensor_math.logm_iss, (A,), order=1, modes=["rev"]) def test_logm_iss_on_10x10(): key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0) - C = F.T@F + F = jax.random.uniform(key, (10, 10), minval=1e-8, maxval=10.0) + C = F.T @ F logC = tensor_math.logm_iss(C) logCSpectral = tensor_math.logh(C) np.array_equal(logC, logCSpectral) @@ -322,13 +344,18 @@ def test_logm_iss_on_10x10(): def test_compute_deviatoric_tensor(): key = jax.random.PRNGKey(0) F = jax.random.uniform(key, (3, 3), minval=1e-8, maxval=10.0) - assert np.allclose(tensor_math.compute_deviatoric_tensor(F), F - (1. / 3.) * np.trace(F) * np.eye(3)) + assert np.allclose( + tensor_math.compute_deviatoric_tensor(F), + F - (1.0 / 3.0) * np.trace(F) * np.eye(3), + ) def test_tensor_norm(): key = jax.random.PRNGKey(0) F = jax.random.uniform(key, (3, 3), minval=1e-8, maxval=10.0) - assert np.allclose(tensor_math.tensor_norm(F), np.linalg.norm(F, ord='fro')) + assert np.allclose( + tensor_math.tensor_norm(F), np.linalg.norm(F, ord="fro") + ) def test_norm_of_deviator_squared(): @@ -356,8 +383,11 @@ def test_mises_equivalent_stress(): def test_triaxiality(): key = jax.random.PRNGKey(0) F = jax.random.uniform(key, (3, 3), minval=1e-8, maxval=10.0) - mean_normal = np.trace(F) / 3. - an = mean_normal / (tensor_math.mises_equivalent_stress(F) + np.finfo(np.dtype('float64')).eps) + mean_normal = np.trace(F) / 3.0 + an = mean_normal / ( + tensor_math.mises_equivalent_stress(F) + + np.finfo(np.dtype("float64")).eps + ) assert np.allclose(tensor_math.triaxiality(F), an) diff --git a/test/networks/test_base.py b/test/networks/test_base.py index 6dee3a1..35ec919 100644 --- a/test/networks/test_base.py +++ b/test/networks/test_base.py @@ -62,4 +62,4 @@ def test_uniform_init(models, init_funcs): for model in models: for init_func in init_funcs: - new_model = model.init(init_func, key=key) + model.init(init_func, key=key) diff --git a/test/test_post_processors.py b/test/test_post_processors.py index f6ff007..ce02273 100644 --- a/test/test_post_processors.py +++ b/test/test_post_processors.py @@ -1,5 +1,6 @@ # from jax import random -# from pancax import DirichletBC, VariationalDomain, NeoHookean, ThreeDimensional, SolidMechanics +# from pancax import DirichletBC, +# VariationalDomain, NeoHookean, ThreeDimensional, SolidMechanics # from pancax import FieldPhysicsPair, MLP # from pancax import PostProcessor, ForwardProblem # from pathlib import Path @@ -11,7 +12,10 @@ @pytest.fixture def problem(): - from pancax import DirichletBC, ForwardProblem, NeoHookean, SolidMechanics, ThreeDimensional, VariationalDomain + from pancax import ( + DirichletBC, ForwardProblem, + NeoHookean, SolidMechanics, ThreeDimensional, VariationalDomain + ) from pathlib import Path import jax.numpy as jnp import os @@ -19,10 +23,13 @@ def problem(): times = jnp.linspace(0., 1.0, 2) domain = VariationalDomain(mesh_file, times) - physics = SolidMechanics(NeoHookean(bulk_modulus=0.833, shear_modulus=0.3846), ThreeDimensional()) + physics = SolidMechanics( + NeoHookean(bulk_modulus=0.833, shear_modulus=0.3846), + ThreeDimensional() + ) ics = [ ] - essential_bc_func = lambda x, t, z: z + # essential_bc_func = lambda x, t, z: z essential_bcs = [ DirichletBC('nset_4', 0), DirichletBC('nset_4', 1), @@ -53,7 +60,9 @@ def test_post_processor(params, problem): import os mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') pp = PostProcessor(mesh_file) - pp.init(problem, 'output.e', + pp.init( + problem, + 'output.e', node_variables=[ 'field_values' ], @@ -72,7 +81,9 @@ def test_post_processor_bad_var_name(problem): mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') pp = PostProcessor(mesh_file) with pytest.raises(ValueError): - pp.init(problem, 'output.e', + pp.init( + problem, + 'output.e', node_variables=[ 'bad_var_name' ]