diff --git a/test/data/test_full_field_data.py b/test/data/test_full_field_data.py index 67dca23..4a294f2 100644 --- a/test/data/test_full_field_data.py +++ b/test/data/test_full_field_data.py @@ -3,8 +3,8 @@ def test_full_field_data(): from pathlib import Path import os data_file = os.path.join(Path(__file__).parent, 'data_full_field.csv') - data = FullFieldData( - data_file, + FullFieldData( + data_file, input_keys=['x', 'y', 'z', 't'], output_keys=['u_x', 'u_y', 'u_z'] ) @@ -13,8 +13,8 @@ def test_full_field_data(): # def test_full_field_data_plot_registration(): # data_file = os.path.join(Path(__file__).parent, 'data_full_field.csv') # data = FullFieldData( -# data_file, +# data_file, # input_keys=['x', 'y', 'z', 't'], # output_keys=['u_x', 'u_y', 'u_z'] # ) -# data.plot_e \ No newline at end of file +# data.plot_e diff --git a/test/data/test_global_data.py b/test/data/test_global_data.py index a883749..40c6a80 100644 --- a/test/data/test_global_data.py +++ b/test/data/test_global_data.py @@ -10,8 +10,8 @@ def test_global_data(): import os data_file = os.path.join(Path(__file__).parent, 'data_global.csv') mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') - data = GlobalData( - data_file, + GlobalData( + data_file, times_key='t', disp_key='u_x', force_key='f_x', @@ -20,8 +20,8 @@ def test_global_data(): reaction_dof='x', n_time_steps=11 ) - data = GlobalData( - data_file, + GlobalData( + data_file, times_key='t', disp_key='u_x', force_key='f_x', @@ -30,8 +30,8 @@ def test_global_data(): reaction_dof='y', n_time_steps=11 ) - data = GlobalData( - data_file, + GlobalData( + data_file, times_key='t', disp_key='u_x', force_key='f_x', @@ -49,8 +49,8 @@ def test_global_data_with_plotting(): import os data_file = os.path.join(Path(__file__).parent, 'data_global.csv') mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') - data = GlobalData( - data_file, + GlobalData( + data_file, times_key='t', disp_key='u_x', force_key='f_x', @@ -69,8 +69,8 @@ def test_global_data_bad_reaction_dof(): data_file = os.path.join(Path(__file__).parent, 'data_global.csv') mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') with pytest.raises(ValueError): - data = GlobalData( - data_file, + GlobalData( + data_file, times_key='t', disp_key='u_x', force_key='f_x', @@ -87,11 +87,13 @@ def test_global_data_times_not_unique_exception(): GlobalDataTimesNotUniqueException from pathlib import Path import os - data_file = os.path.join(Path(__file__).parent, 'data_global_not_unique.csv') + data_file = os.path.join( + Path(__file__).parent, 'data_global_not_unique.csv' + ) mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') with pytest.raises(GlobalDataTimesNotUniqueException): - data = GlobalData( - data_file, + GlobalData( + data_file, times_key='t', disp_key='u_x', force_key='f_x', @@ -108,11 +110,13 @@ def test_global_data_times_not_strictly_increasing(): GlobalDataTimesNotStrictlyIncreasingException from pathlib import Path import os - data_file = os.path.join(Path(__file__).parent, 'data_global_not_strictly_increasing.csv') + data_file = os.path.join( + Path(__file__).parent, 'data_global_not_strictly_increasing.csv' + ) mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') with pytest.raises(GlobalDataTimesNotStrictlyIncreasingException): - data = GlobalData( - data_file, + GlobalData( + data_file, times_key='t', disp_key='u_x', force_key='f_x', diff --git a/test/domains/test_base.py b/test/domains/test_base.py index b7bc70f..335fdc8 100644 --- a/test/domains/test_base.py +++ b/test/domains/test_base.py @@ -8,7 +8,7 @@ def test_simulation_times_unique_exception(): import pytest mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') times = jnp.array([0., 0.]) - + with pytest.raises(SimulationTimesNotUniqueException): BaseDomain(mesh_file, times) diff --git a/test/domains/test_collocation_domain.py b/test/domains/test_collocation_domain.py index bc50f6e..bd6822e 100644 --- a/test/domains/test_collocation_domain.py +++ b/test/domains/test_collocation_domain.py @@ -5,7 +5,7 @@ def test_collocation_domain_3D(): import os mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') times = jnp.linspace(0., 1.0, 2) - domain = CollocationDomain(mesh_file, times) + CollocationDomain(mesh_file, times) def test_collocation_domain_tri3_p_order(): @@ -15,4 +15,4 @@ def test_collocation_domain_tri3_p_order(): import os mesh_file = os.path.join(Path(__file__).parent, 'mesh_10x.g') times = jnp.linspace(0., 1.0, 2) - domain = CollocationDomain(mesh_file, times, p_order=2) + CollocationDomain(mesh_file, times, p_order=2) diff --git a/test/domains/test_delta_pinn_domain.py b/test/domains/test_delta_pinn_domain.py index f0204ce..6d9552a 100644 --- a/test/domains/test_delta_pinn_domain.py +++ b/test/domains/test_delta_pinn_domain.py @@ -16,10 +16,11 @@ def test_forward_domain(): # natural_bcs = [ # ] # physics = SolidMechanics( - # mesh_file, essential_bc_func, + # mesh_file, essential_bc_func, # NeoHookean(), ThreeDimensional(), # use_delta_pinn=True # ) times = jnp.linspace(0., 1.0, 2) - # domain = DeltaPINNDomain(physics, essential_bcs, natural_bcs, mesh_file, times, 20) - domain = DeltaPINNDomain(mesh_file, times, n_eigen_values=100) \ No newline at end of file + # domain = DeltaPINNDomain( + # physics, essential_bcs, natural_bcs, mesh_file, times, 20) + DeltaPINNDomain(mesh_file, times, n_eigen_values=100) diff --git a/test/domains/test_inverse_domain.py b/test/domains/test_inverse_domain.py index 7b4933d..95cb356 100644 --- a/test/domains/test_inverse_domain.py +++ b/test/domains/test_inverse_domain.py @@ -7,7 +7,8 @@ # def test_inverse_domain(): -# field_data_file = os.path.join(Path(__file__).parent, 'data_full_field.csv') +# field_data_file = os.path.join( +# Path(__file__).parent, 'data_full_field.csv') # global_data_file = os.path.join(Path(__file__).parent, 'data_global.csv') # mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') # essential_bc_func = lambda x, t, z: z @@ -22,16 +23,16 @@ # natural_bcs = [ # ] # physics = SolidMechanics( -# mesh_file, essential_bc_func, +# mesh_file, essential_bc_func, # NeoHookean(), ThreeDimensional() # ) # full_field_data = FullFieldData( -# field_data_file, +# field_data_file, # input_keys=['x', 'y', 'z', 't'], # output_keys=['u_x', 'u_y', 'u_z'] # ) # global_data = GlobalData( -# global_data_file, +# global_data_file, # times_key='t', # disp_key='u_x', # force_key='f_x', diff --git a/test/domains/test_variational_domain.py b/test/domains/test_variational_domain.py index 635e6ed..3527e72 100644 --- a/test/domains/test_variational_domain.py +++ b/test/domains/test_variational_domain.py @@ -5,7 +5,7 @@ def test_forward_domain(): import os mesh_file = os.path.join(Path(__file__).parent, 'mesh.g') times = jnp.linspace(0., 1.0, 2) - domain = VariationalDomain(mesh_file, times) + VariationalDomain(mesh_file, times) def test_forward_domain_tri3_p_order(): @@ -15,4 +15,4 @@ def test_forward_domain_tri3_p_order(): import os mesh_file = os.path.join(Path(__file__).parent, 'mesh_10x.g') times = jnp.linspace(0., 1.0, 2) - domain = VariationalDomain(mesh_file, times, p_order=2) + VariationalDomain(mesh_file, times, p_order=2) diff --git a/test/fem/elements/test_elements.py b/test/fem/elements/test_elements.py index 98a7ca0..bb6cba6 100644 --- a/test/fem/elements/test_elements.py +++ b/test/fem/elements/test_elements.py @@ -1,280 +1,385 @@ -from pancax.fem import QuadratureRule -from pancax.fem.elements import * -import jax -import jax.numpy as jnp -import numpy as onp import pytest tol = 1e-14 def check_1D_interpolant_in_element(element): + import jax.numpy as jnp + import numpy as onp p = element.coordinates return jnp.all(p >= 0.0) and onp.all(p <= 1.0) def check_hex_interpolant_in_element(element): + import jax.numpy as jnp p = element.coordinates - return jnp.all(p[:, 0] >= -1.0) and \ - jnp.all(p[:, 0] <= 1.0) and \ - jnp.all(p[:, 1] >= -1.0) and \ - jnp.all(p[:, 1] <= 1.0) and \ - jnp.all(p[:, 2] >= -1.0) and \ - jnp.all(p[:, 2] <= 1.0) + return ( + jnp.all(p[:, 0] >= -1.0) + and jnp.all(p[:, 0] <= 1.0) + and jnp.all(p[:, 1] >= -1.0) + and jnp.all(p[:, 1] <= 1.0) + and jnp.all(p[:, 2] >= -1.0) + and jnp.all(p[:, 2] <= 1.0) + ) def check_quad_interpolant_in_element(element): + import jax.numpy as jnp p = element.coordinates - return jnp.all(p[:, 0] >= -1.0) and \ - jnp.all(p[:, 0] <= 1.0) and \ - jnp.all(p[:, 1] >= -1.0) and \ - jnp.all(p[:, 1] <= 1.0) + return ( + jnp.all(p[:, 0] >= -1.0) + and jnp.all(p[:, 0] <= 1.0) + and jnp.all(p[:, 1] >= -1.0) + and jnp.all(p[:, 1] <= 1.0) + ) def check_tet_interpolant_in_element(element): + import jax.numpy as jnp p = element.coordinates # x conditions - return jnp.all(p[:, 0] >= -tol) and \ - jnp.all(p[:, 0] <= 1.0 + tol) and \ - jnp.all(p[:, 1] >= -tol) and \ - jnp.all(p[:, 1] <= 1. - p[:, 0] + tol) and \ - jnp.all(p[:, 2] >= -tol) and \ - jnp.all(p[:, 2] <= 1. - p[:, 0] - p[:, 1] + tol) + return ( + jnp.all(p[:, 0] >= -tol) + and jnp.all(p[:, 0] <= 1.0 + tol) + and jnp.all(p[:, 1] >= -tol) + and jnp.all(p[:, 1] <= 1.0 - p[:, 0] + tol) + and jnp.all(p[:, 2] >= -tol) + and jnp.all(p[:, 2] <= 1.0 - p[:, 0] - p[:, 1] + tol) + ) def check_tri_interpolant_in_element(element): + import jax.numpy as jnp p = element.coordinates # x conditions - return jnp.all(p[:, 0] >= -tol) and \ - jnp.all(p[:, 0] <= 1.0 + tol) and \ - jnp.all(p[:, 1] >= -tol) and \ - jnp.all(p[:, 1] <= 1. - p[:,0] + tol) + return ( + jnp.all(p[:, 0] >= -tol) + and jnp.all(p[:, 0] <= 1.0 + tol) + and jnp.all(p[:, 1] >= -tol) + and jnp.all(p[:, 1] <= 1.0 - p[:, 0] + tol) + ) def generate_random_points_in_line(npts): + import jax + import numpy as onp key = jax.random.PRNGKey(2) x = jax.random.uniform(key, (npts,)) return onp.asarray(x) def generate_random_points_in_hex(npts): + import jax + import numpy as onp key = jax.random.PRNGKey(2) x = jax.random.uniform(key, (npts,)) y = jax.random.uniform(key, (npts,)) z = jax.random.uniform(key, (npts,)) points = jax.numpy.column_stack((x, y, z)) - return onp.asarray(points) + return onp.asarray(points) def generate_random_points_in_quad(npts): + import jax + import numpy as onp key = jax.random.PRNGKey(2) x = jax.random.uniform(key, (npts,)) y = jax.random.uniform(key, (npts,)) points = jax.numpy.column_stack((x, y)) - return onp.asarray(points) + return onp.asarray(points) def generate_random_points_in_tet(npts): + import jax + import numpy as onp key = jax.random.PRNGKey(2) x = jax.random.uniform(key, (npts,)) y = jax.numpy.zeros(npts) z = jax.numpy.zeros(npts) for i in range(npts): - key,subkey = jax.random.split(key) - y = y.at[i].set(jax.random.uniform(subkey, minval=0.0, maxval=1.0 - x[i])) - z = z.at[i].set(jax.random.uniform(subkey, minval=0.0, maxval=1.0 - x[i] - y[i])) + key, subkey = jax.random.split(key) + y = y.at[i].set(jax.random.uniform( + subkey, minval=0.0, maxval=1.0 - x[i]) + ) + z = z.at[i].set( + jax.random.uniform(subkey, minval=0.0, maxval=1.0 - x[i] - y[i]) + ) points = jax.numpy.column_stack((x, y)) - return onp.asarray(points) + return onp.asarray(points) def generate_random_points_in_triangle(npts): + import jax + import numpy as onp key = jax.random.PRNGKey(2) x = jax.random.uniform(key, (npts,)) y = jax.numpy.zeros(npts) for i in range(npts): - key,subkey = jax.random.split(key) - y = y.at[i].set(jax.random.uniform(subkey, minval=0.0, maxval=1.0-x[i])) + key, subkey = jax.random.split(key) + y = y.at[i].set( + jax.random.uniform(subkey, minval=0.0, maxval=1.0 - x[i]) + ) points = jax.numpy.column_stack((x, y)) - return onp.asarray(points) - - -q_rules = [] -elements = [] -check_interpolant_in_element_methods = [] -generate_random_points_method = [] -# Hex elements -for q in range(1, 2 + 1): - el = Hex8Element() - q_rules.append(QuadratureRule(el, q)) - elements.append(el) - check_interpolant_in_element_methods.append(check_hex_interpolant_in_element) - generate_random_points_method.append(generate_random_points_in_hex) - -# Line elements -for p in range(1, 25 + 1): - el = LineElement(p) - q_rules.append(QuadratureRule(el, p)) - elements.append(el) - check_interpolant_in_element_methods.append(check_1D_interpolant_in_element) - generate_random_points_method.append(generate_random_points_in_line) - -# Quad elements -for q in range(1, 2 + 1): - el = Quad4Element() - q_rules.append(QuadratureRule(el, q)) - elements.append(el) - check_interpolant_in_element_methods.append(check_quad_interpolant_in_element) - generate_random_points_method.append(generate_random_points_in_quad) - el = Quad9Element() - q_rules.append(QuadratureRule(el, q)) - elements.append(el) - check_interpolant_in_element_methods.append(check_quad_interpolant_in_element) - generate_random_points_method.append(generate_random_points_in_quad) - -for q in range(1, 2 + 1): - el = Tet4Element() - q_rules.append(QuadratureRule(el, q)) - elements.append(el) - check_interpolant_in_element_methods.append(check_tet_interpolant_in_element) - generate_random_points_method.append(generate_random_points_in_tet) - - -for q in range(1, 2 + 1): - el = Tet10Element() - q_rules.append(QuadratureRule(el, q)) - elements.append(el) - check_interpolant_in_element_methods.append(check_tet_interpolant_in_element) - generate_random_points_method.append(generate_random_points_in_tet) - - -# Tri elements -for p in range(1, 6 + 1): - el = SimplexTriElement(p) - q_rules.append(QuadratureRule(el, p)) - elements.append(el) - check_interpolant_in_element_methods.append(check_tri_interpolant_in_element) - generate_random_points_method.append(generate_random_points_in_triangle) - - -@pytest.mark.parametrize('el, check', zip(elements, check_interpolant_in_element_methods)) -def test_interpolant_points_in_element(el, check): - assert check(el) + return onp.asarray(points) + + +@pytest.fixture +def elements_fix(): + from pancax.fem import QuadratureRule + from pancax.fem.elements import Hex8Element + from pancax.fem.elements import LineElement + from pancax.fem.elements import Quad4Element + from pancax.fem.elements import Quad9Element + from pancax.fem.elements import SimplexTriElement + from pancax.fem.elements import Tet4Element + from pancax.fem.elements import Tet10Element + + q_rules = [] + elements = [] + check_interpolant_in_element_methods = [] + generate_random_points_method = [] + + for q in range(1, 2 + 1): + el = Hex8Element() + q_rules.append(QuadratureRule(el, q)) + elements.append(el) + check_interpolant_in_element_methods.append( + check_hex_interpolant_in_element + ) + generate_random_points_method.append(generate_random_points_in_hex) + + # Line elements + for p in range(1, 25 + 1): + el = LineElement(p) + q_rules.append(QuadratureRule(el, p)) + elements.append(el) + check_interpolant_in_element_methods.append( + check_1D_interpolant_in_element + ) + generate_random_points_method.append(generate_random_points_in_line) + + # Quad elements + for q in range(1, 2 + 1): + el = Quad4Element() + q_rules.append(QuadratureRule(el, q)) + elements.append(el) + check_interpolant_in_element_methods.append( + check_quad_interpolant_in_element + ) + generate_random_points_method.append(generate_random_points_in_quad) + el = Quad9Element() + q_rules.append(QuadratureRule(el, q)) + elements.append(el) + check_interpolant_in_element_methods.append( + check_quad_interpolant_in_element + ) + generate_random_points_method.append(generate_random_points_in_quad) + + # Tet elements + for q in range(1, 2 + 1): + el = Tet4Element() + q_rules.append(QuadratureRule(el, q)) + elements.append(el) + check_interpolant_in_element_methods.append( + check_tet_interpolant_in_element + ) + generate_random_points_method.append(generate_random_points_in_tet) + + for q in range(1, 2 + 1): + el = Tet10Element() + q_rules.append(QuadratureRule(el, q)) + elements.append(el) + check_interpolant_in_element_methods.append( + check_tet_interpolant_in_element + ) + generate_random_points_method.append(generate_random_points_in_tet) + + # Tri elements + for p in range(1, 6 + 1): + el = SimplexTriElement(p) + q_rules.append(QuadratureRule(el, p)) + elements.append(el) + check_interpolant_in_element_methods.append( + check_tri_interpolant_in_element + ) + generate_random_points_method.append( + generate_random_points_in_triangle + ) + + return \ + check_interpolant_in_element_methods, \ + elements, \ + generate_random_points_method, \ + q_rules + + +def test_interpolant_points_in_element(elements_fix): + checks, els, _, _ = elements_fix + for check, el in zip(checks, els): + check(el) # topology tests # TODO generalize these methods below -def test_1D_element_element_topological_nodesets(): - for element in elements: - if type(element) == LineElement: +def test_1D_element_element_topological_nodesets(elements_fix): + from pancax.fem.elements import LineElement + import jax.numpy as jnp + for element in elements_fix: + if type(element) is LineElement: p = element.coordinates jnp.isclose(p[element.vertexNodes[0]], 0.0) jnp.isclose(p[element.vertexNodes[1]], 1.0) - + if element.interiorNodes is not None: assert jnp.all(p[element.interiorNodes] > 0.0) assert jnp.all(p[element.interiorNodes] < 1.0) -def test_tri_element_element_topological_nodesets(): - for element in elements: - if type(element) == SimplexTriElement: +def test_tri_element_element_topological_nodesets(elements_fix): + from pancax.fem.elements import SimplexTriElement + import jax.numpy as jnp + import numpy as onp + for element in elements_fix: + if type(element) is SimplexTriElement: p = element.coordinates - jnp.array_equal(p[element.vertexNodes[0], :], onp.array([1.0, 0.0])) - jnp.array_equal(p[element.vertexNodes[1], :], onp.array([0.0, 1.0])) - jnp.array_equal(p[element.vertexNodes[2], :], onp.array([0.0, 0.0])) - + jnp.array_equal( + p[element.vertexNodes[0], :], onp.array([1.0, 0.0]) + ) + jnp.array_equal( + p[element.vertexNodes[1], :], onp.array([0.0, 1.0]) + ) + jnp.array_equal( + p[element.vertexNodes[2], :], onp.array([0.0, 0.0]) + ) + if element.interiorNodes.size > 0: k = element.interiorNodes - assert jnp.all(p[k,0] > -tol) - assert jnp.all(p[k,1] + p[k,0] - 1. < tol) + assert jnp.all(p[k, 0] > -tol) + assert jnp.all(p[k, 1] + p[k, 0] - 1.0 < tol) + + # TODO generalize these methods above # TODO generalize these methods below -def test_tri_face_nodes_match_1D_lobatto_nodes(): +def test_tri_face_nodes_match_1D_lobatto_nodes(elements_fix): + from pancax.fem.elements import LineElement, SimplexTriElement + import jax.numpy as jnp elements1d = [] elements2d = [] - for element in elements: - if type(element) == LineElement: + for element in elements_fix: + if type(element) is LineElement: elements1d.append(element) - if type(element) == SimplexTriElement: + if type(element) is SimplexTriElement: elements2d.append(element) - + for element1d, elementTri in zip(elements1d, elements2d): for faceNodeIds in elementTri.faceNodes: # get the triangle face node points directly - xf = elementTri.coordinates[faceNodeIds,:] + xf = elementTri.coordinates[faceNodeIds, :] # affine transformation of 1D node points to triangle face p = element1d.coordinates - x1d = jnp.outer(1.0 - p, xf[0,:]) + jnp.outer(p, xf[-1,:]) + x1d = jnp.outer(1.0 - p, xf[0, :]) + jnp.outer(p, xf[-1, :]) # make sure they are the same jnp.isclose(xf, x1d) -# TODO generalize these methods above -@pytest.mark.parametrize('el, qr', zip(elements, q_rules)) -def test_partition_of_unity(el, qr): - if type(el) == LineElement: - pytest.skip('LineElement failing for now') - shapes, _ = el.compute_shapes(el.coordinates, qr.xigauss) - assert jnp.allclose(jnp.sum(shapes, axis=1), jnp.ones(len(qr))) - - -@pytest.mark.parametrize('el, qr', zip(elements, q_rules)) -def test_gradient_partition_of_unity(el, qr): - if type(el) == LineElement: - pytest.skip('LineElement failing for now') - _, shapeGradients = el.compute_shapes(el.coordinates, qr.xigauss) - num_dim = qr.xigauss.shape[1] - assert jnp.allclose(jnp.sum(shapeGradients, axis=1), jnp.zeros((len(qr), num_dim))) - -@pytest.mark.parametrize('el', elements) -def test_kronecker_delta_property(el): - if type(el) == LineElement: - pytest.skip('LineElement failing for now') - shapeAtNodes, _ = el.compute_shapes(el.coordinates, el.coordinates) - nNodes = el.coordinates.shape[0] - assert jnp.allclose(shapeAtNodes, jnp.identity(nNodes)) +# TODO generalize these methods above +def test_partition_of_unity(elements_fix): + from pancax.fem.elements import LineElement + import jax.numpy as jnp + _, els, _, qrs = elements_fix + for el, qr in zip(els, qrs): + if type(el) is LineElement: + continue + shapes, _ = el.compute_shapes(el.coordinates, qr.xigauss) + assert jnp.allclose(jnp.sum(shapes, axis=1), jnp.ones(len(qr))) + + +def test_gradient_partition_of_unity(elements_fix): + from pancax.fem.elements import LineElement + import jax.numpy as jnp + _, els, _, qrs = elements_fix + for el, qr in zip(els, qrs): + if type(el) is LineElement: + continue + _, shapeGradients = el.compute_shapes(el.coordinates, qr.xigauss) + num_dim = qr.xigauss.shape[1] + assert jnp.allclose( + jnp.sum(shapeGradients, axis=1), jnp.zeros((len(qr), num_dim)) + ) + + +def test_kronecker_delta_property(elements_fix): + from pancax.fem.elements import LineElement + import jax.numpy as jnp + for _, el, _, _ in zip(*elements_fix): + if type(el) is LineElement: + continue + shapeAtNodes, _ = el.compute_shapes(el.coordinates, el.coordinates) + nNodes = el.coordinates.shape[0] + assert jnp.allclose(shapeAtNodes, jnp.identity(nNodes)) # TODO generalize these methods below -def test_interpolation(): +def test_interpolation(elements_fix): + from pancax.fem.elements import SimplexTriElement + import jax.numpy as jnp + import numpy as onp x = generate_random_points_in_triangle(1) - for element in elements: - if type(element) == SimplexTriElement: + for element in elements_fix: + if type(element) is SimplexTriElement: degree = element.degree - polyCoeffs = onp.fliplr(onp.triu(onp.ones((degree+1,degree+1)))) - expected = onp.polynomial.polynomial.polyval2d(x[:,0], x[:,1], polyCoeffs) + polyCoeffs = onp.fliplr( + onp.triu(onp.ones((degree + 1, degree + 1))) + ) + expected = onp.polynomial.polynomial.polyval2d( + x[:, 0], x[:, 1], polyCoeffs + ) shape, _ = element.compute_shapes(element.coordinates, x) - fn = onp.polynomial.polynomial.polyval2d(element.coordinates[:,0], - element.coordinates[:,1], - polyCoeffs) + fn = onp.polynomial.polynomial.polyval2d( + element.coordinates[:, 0], element.coordinates[:, 1], + polyCoeffs + ) fInterpolated = onp.dot(shape, fn) jnp.array_equal(expected, fInterpolated) -def test_grad_interpolation(): +def test_grad_interpolation(elements_fix): + from pancax.fem.elements import SimplexTriElement + import jax.numpy as jnp + import numpy as onp x = generate_random_points_in_triangle(1) - for element in elements: - if type(element) == SimplexTriElement: + for element in elements_fix: + if type(element) is SimplexTriElement: degree = element.degree - poly = onp.fliplr(onp.triu(onp.ones((degree+1,degree+1)))) + poly = onp.fliplr(onp.triu(onp.ones((degree + 1, degree + 1)))) _, dShape = element.compute_shapes(element.coordinates, x) - fn = onp.polynomial.polynomial.polyval2d(element.coordinates[:,0], - element.coordinates[:,1], - poly) - dfInterpolated = onp.einsum('qai,a->qi',dShape, fn) + fn = onp.polynomial.polynomial.polyval2d( + element.coordinates[:, 0], element.coordinates[:, 1], poly + ) + dfInterpolated = onp.einsum("qai,a->qi", dShape, fn) # exact x derivative direction = 0 - DPoly = onp.polynomial.polynomial.polyder(poly, 1, scl=1, axis=direction) - expected0 = onp.polynomial.polynomial.polyval2d(x[:,0], x[:,1], DPoly) + DPoly = onp.polynomial.polynomial.polyder( + poly, 1, scl=1, axis=direction + ) + expected0 = onp.polynomial.polynomial.polyval2d( + x[:, 0], x[:, 1], DPoly + ) - jnp.array_equal(expected0, dfInterpolated[:,0]) + jnp.array_equal(expected0, dfInterpolated[:, 0]) direction = 1 - DPoly = onp.polynomial.polynomial.polyder(poly, 1, scl=1, axis=direction) - expected1 = onp.polynomial.polynomial.polyval2d(x[:,0], x[:,1], DPoly) + DPoly = onp.polynomial.polynomial.polyder( + poly, 1, scl=1, axis=direction + ) + expected1 = onp.polynomial.polynomial.polyval2d( + x[:, 0], x[:, 1], DPoly + ) + + jnp.array_equal(expected1, dfInterpolated[:, 1]) + - jnp.array_equal(expected1, dfInterpolated[:,1]) # TODO generalize these methods above diff --git a/test/fem/test_dof_manager.py b/test/fem/test_dof_manager.py index 67a17a9..faf68a8 100644 --- a/test/fem/test_dof_manager.py +++ b/test/fem/test_dof_manager.py @@ -14,20 +14,22 @@ def dof_manager_test_fixture(): from .utils import create_mesh_and_disp import jax.numpy as jnp - xRange = [0., 1.] - yRange = [0., 1.] + xRange = [0.0, 1.0] + yRange = [0.0, 1.0] - mesh, _ = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x : 0*x) + mesh, _ = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x: 0 * x) - ebcs = [DirichletBC(nodeSet='top', component=0), - DirichletBC(nodeSet='right', component=1)] + ebcs = [ + DirichletBC(nodeSet="top", component=0), + DirichletBC(nodeSet="right", component=1), + ] dofManager = DofManager(mesh, nFields, ebcs) U = jnp.zeros((nNodes, nFields)) - U = U.at[:,1].set(1.0) - U = U.at[mesh.nodeSets['top'],0].set(2.0) - U = U.at[mesh.nodeSets['right'],1].set(3.0) + U = U.at[:, 1].set(1.0) + U = U.at[mesh.nodeSets["top"], 0].set(2.0) + U = U.at[mesh.nodeSets["right"], 1].set(3.0) return dofManager, U @@ -52,17 +54,19 @@ def test_get_unknown_size(dof_manager_test_fixture): def test_slice_unknowns_with_dof_indices(dof_manager_test_fixture): import jax.numpy as jnp + dofManager, U = dof_manager_test_fixture Uu = dofManager.get_unknown_values(U) - Uu_x = dofManager.slice_unknowns_with_dof_indices(Uu, (slice(None),0) ) - jnp.array_equal(Uu_x, jnp.zeros(Nx*(Ny-1))) - Uu_y = dofManager.slice_unknowns_with_dof_indices(Uu, (slice(None),1) ) - jnp.array_equal(Uu_y, jnp.ones(Ny*(Nx-1))) + Uu_x = dofManager.slice_unknowns_with_dof_indices(Uu, (slice(None), 0)) + jnp.array_equal(Uu_x, jnp.zeros(Nx * (Ny - 1))) + Uu_y = dofManager.slice_unknowns_with_dof_indices(Uu, (slice(None), 1)) + jnp.array_equal(Uu_y, jnp.ones(Ny * (Nx - 1))) def test_create_field_and_get_bc_values(dof_manager_test_fixture): import jax.numpy as jnp + dofManager, U = dof_manager_test_fixture Uu = jnp.zeros(dofManager.get_unknown_size()) diff --git a/test/fem/test_function_space.py b/test/fem/test_function_space.py index eef246f..a528251 100644 --- a/test/fem/test_function_space.py +++ b/test/fem/test_function_space.py @@ -11,26 +11,32 @@ # mesh Nx = 7 Ny = 7 -xRange = [0.,1.] -yRange = [0.,1.] -# targetDispGrad = jnp.array([[0.1, -0.2],[0.4, -0.1]]) +xRange = [0.0, 1.0] +yRange = [0.0, 1.0] +# targetDispGrad = jnp.array([[0.1, -0.2],[0.4, -0.1]]) # mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x: jnp.dot(targetDispGrad, x)) + @pytest.fixture def mesh_and_disp(): from .utils import create_mesh_and_disp import jax.numpy as jnp - targetDispGrad = jnp.array([[0.1, -0.2],[0.4, -0.1]]) - mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x: jnp.dot(targetDispGrad, x)) + + targetDispGrad = jnp.array([[0.1, -0.2], [0.4, -0.1]]) + mesh, U = create_mesh_and_disp( + Nx, Ny, xRange, yRange, lambda x: jnp.dot(targetDispGrad, x) + ) return mesh, U, targetDispGrad + # function space @pytest.fixture def fspace_fixture_1(mesh_and_disp): from pancax.fem import NonAllocatedFunctionSpace, QuadratureRule from pancax.fem import construct_function_space import jax.numpy as jnp + mesh, _, _ = mesh_and_disp quadratureRule = QuadratureRule(mesh.parentElement, 1) fs = construct_function_space(mesh, quadratureRule) @@ -48,6 +54,7 @@ def fspace_fixture_2(mesh_and_disp): from pancax.fem import NonAllocatedFunctionSpace, QuadratureRule from pancax.fem import construct_function_space import jax.numpy as jnp + mesh, _, _ = mesh_and_disp quadratureRule = QuadratureRule(mesh.parentElement, 1) fs = construct_function_space(mesh, quadratureRule) @@ -62,27 +69,30 @@ def fspace_fixture_2(mesh_and_disp): def test_element_volume_single_point_quadrature(fspace_fixture_1, mesh_and_disp): import jax.numpy as jnp + fs, _, _, _, _, _ = fspace_fixture_1 mesh, _, _ = mesh_and_disp elementVols = jnp.sum(fs.vols, axis=1) nElements = mesh.num_elements - jnp.array_equal(elementVols, jnp.ones(nElements)*0.5/((Nx-1)*(Ny-1))) + jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))) def test_element_volume_single_point_quadrature_na(fspace_fixture_1, mesh_and_disp): import jax import jax.numpy as jnp + _, fs_na, _, _, _, _ = fspace_fixture_1 mesh, _, _ = mesh_and_disp X_els = mesh.coords[mesh.conns, :] elementVols = jnp.sum(jax.vmap(fs_na.JxWs)(X_els), axis=1) nElements = mesh.num_elements - jnp.array_equal(elementVols, jnp.ones(nElements)*0.5/((Nx-1)*(Ny-1))) + jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))) def test_linear_reproducing_single_point_quadrature(fspace_fixture_1, mesh_and_disp): from pancax.fem.function_space import compute_field_gradient import jax.numpy as jnp + fs, _, quadratureRule, _, _, _ = fspace_fixture_1 mesh, U, targetDispGrad = mesh_and_disp dispGrads = compute_field_gradient(fs, U, mesh.coords) @@ -95,6 +105,7 @@ def test_linear_reproducing_single_point_quadrature(fspace_fixture_1, mesh_and_d def test_linear_reproducing_single_point_quadrature_na(fspace_fixture_1, mesh_and_disp): import jax import jax.numpy as jnp + _, fs_na, quadratureRule, _, _, _ = fspace_fixture_1 mesh, U, targetDispGrad = mesh_and_disp X_els = mesh.coords[mesh.conns, :] @@ -106,29 +117,37 @@ def test_linear_reproducing_single_point_quadrature_na(fspace_fixture_1, mesh_an assert jnp.allclose(dispGrads, exact) -def test_integrate_constant_field_single_point_quadrature(fspace_fixture_1, mesh_and_disp): +def test_integrate_constant_field_single_point_quadrature( + fspace_fixture_1, mesh_and_disp +): from pancax.fem.function_space import integrate_over_block import jax.numpy as jnp + fs, _, _, state, props, dt = fspace_fixture_1 mesh, U, _ = mesh_and_disp - integralOfOne = integrate_over_block(fs, - U, - mesh.coords, - state, - props, - dt, - lambda u, gradu, state, props, X, dt: 1.0, - mesh.blocks['block']) + integralOfOne = integrate_over_block( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: 1.0, + mesh.blocks["block"], + ) jnp.isclose(integralOfOne, 1.0) -def test_integrate_constant_field_single_point_quadrature_na(fspace_fixture_1, mesh_and_disp): +def test_integrate_constant_field_single_point_quadrature_na( + fspace_fixture_1, mesh_and_disp +): import jax.numpy as jnp + _, fs_na, _, state, props, dt = fspace_fixture_1 mesh, U, _ = mesh_and_disp - U_els = U[mesh.conns[mesh.blocks['block']], :] - X_els = U[mesh.conns[mesh.blocks['block']], :] + U_els = U[mesh.conns[mesh.blocks["block"]], :] + X_els = U[mesh.conns[mesh.blocks["block"]], :] integralOfOne = fs_na.integrate_on_elements( U_els, X_els, @@ -140,43 +159,53 @@ def test_integrate_constant_field_single_point_quadrature_na(fspace_fixture_1, m jnp.isclose(integralOfOne, 1.0) -def test_integrate_linear_field_single_point_quadrature(fspace_fixture_1, mesh_and_disp): +def test_integrate_linear_field_single_point_quadrature( + fspace_fixture_1, mesh_and_disp +): from pancax.fem.function_space import integrate_over_block import jax.numpy as jnp + fs, _, _, state, props, dt = fspace_fixture_1 mesh, U, _ = mesh_and_disp - Ix = integrate_over_block(fs, - U, - mesh.coords, - state, - props, - dt, - lambda u, gradu, state, props, X, dt: gradu[0,0], - mesh.blocks['block']) + Ix = integrate_over_block( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: gradu[0, 0], + mesh.blocks["block"], + ) # displacement at x=1 should match integral - idx = jnp.argmax(mesh.coords[:,0]) - expected = U[idx,0]*(yRange[1] - yRange[0]) + idx = jnp.argmax(mesh.coords[:, 0]) + expected = U[idx, 0] * (yRange[1] - yRange[0]) jnp.isclose(Ix, expected) - - Iy = integrate_over_block(fs, - U, - mesh.coords, - state, - props, - dt, - lambda u, gradu, state, props, X, dt: gradu[1, 1], - mesh.blocks['block']) - idx = jnp.argmax(mesh.coords[:,1]) - expected = U[idx,1]*(xRange[1] - xRange[0]) + + Iy = integrate_over_block( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: gradu[1, 1], + mesh.blocks["block"], + ) + idx = jnp.argmax(mesh.coords[:, 1]) + expected = U[idx, 1] * (xRange[1] - xRange[0]) jnp.isclose(Iy, expected) -def test_integrate_linear_field_single_point_quadrature_na(fspace_fixture_1, mesh_and_disp): +def test_integrate_linear_field_single_point_quadrature_na( + fspace_fixture_1, mesh_and_disp +): import jax.numpy as jnp + _, fs_na, _, state, props, dt = fspace_fixture_1 mesh, U, _ = mesh_and_disp - U_els = U[mesh.conns[mesh.blocks['block']], :] - X_els = U[mesh.conns[mesh.blocks['block']], :] + U_els = U[mesh.conns[mesh.blocks["block"]], :] + X_els = U[mesh.conns[mesh.blocks["block"]], :] Ix = fs_na.integrate_on_elements( U_els, @@ -186,8 +215,8 @@ def test_integrate_linear_field_single_point_quadrature_na(fspace_fixture_1, mes dt, lambda u, gradu, state, props, X, dt: gradu[0, 0], ) - idx = jnp.argmax(mesh.coords[:,0]) - expected = U[idx,0]*(yRange[1] - yRange[0]) + idx = jnp.argmax(mesh.coords[:, 0]) + expected = U[idx, 0] * (yRange[1] - yRange[0]) jnp.isclose(Ix, expected) Iy = fs_na.integrate_on_elements( @@ -199,34 +228,37 @@ def test_integrate_linear_field_single_point_quadrature_na(fspace_fixture_1, mes lambda u, gradu, state, props, X, dt: gradu[1, 1], ) - idx = jnp.argmax(mesh.coords[:,1]) - expected = U[idx,1]*(xRange[1] - xRange[0]) + idx = jnp.argmax(mesh.coords[:, 1]) + expected = U[idx, 1] * (xRange[1] - xRange[0]) jnp.isclose(Iy, expected) def test_element_volume_multi_point_quadrature(fspace_fixture_2, mesh_and_disp): import jax.numpy as jnp + fs, _, _, _, _, _ = fspace_fixture_2 mesh, U, _ = mesh_and_disp elementVols = jnp.sum(fs.vols, axis=1) nElements = mesh.num_elements - jnp.array_equal(elementVols, jnp.ones(nElements)*0.5/((Nx-1)*(Ny-1))) + jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))) def test_element_volume_multi_point_quadrature_na(fspace_fixture_2, mesh_and_disp): import jax import jax.numpy as jnp + _, fs_na, _, _, _, _ = fspace_fixture_2 mesh, U, _ = mesh_and_disp X_els = mesh.coords[mesh.conns, :] elementVols = jnp.sum(jax.vmap(fs_na.JxWs)(X_els), axis=1) nElements = mesh.num_elements - jnp.array_equal(elementVols, jnp.ones(nElements)*0.5/((Nx-1)*(Ny-1))) + jnp.array_equal(elementVols, jnp.ones(nElements) * 0.5 / ((Nx - 1) * (Ny - 1))) def test_linear_reproducing_multi_point_quadrature(fspace_fixture_2, mesh_and_disp): from pancax.fem.function_space import compute_field_gradient import jax.numpy as jnp + fs, _, quadratureRule, _, _, _ = fspace_fixture_2 mesh, U, targetDispGrad = mesh_and_disp dispGrads = compute_field_gradient(fs, U, mesh.coords) @@ -239,6 +271,7 @@ def test_linear_reproducing_multi_point_quadrature(fspace_fixture_2, mesh_and_di def test_linear_reproducing_multi_point_quadrature_na(fspace_fixture_2, mesh_and_disp): import jax import jax.numpy as jnp + _, fs_na, quadratureRule, _, _, _ = fspace_fixture_2 mesh, U, targetDispGrad = mesh_and_disp X_els = mesh.coords[mesh.conns, :] @@ -250,28 +283,36 @@ def test_linear_reproducing_multi_point_quadrature_na(fspace_fixture_2, mesh_and assert jnp.allclose(dispGrads, exact) -def test_integrate_constant_field_multi_point_point_quadrature(fspace_fixture_2, mesh_and_disp): +def test_integrate_constant_field_multi_point_point_quadrature( + fspace_fixture_2, mesh_and_disp +): from pancax.fem.function_space import integrate_over_block import jax.numpy as jnp + fs, _, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp - integralOfOne = integrate_over_block(fs, - U, - mesh.coords, - state, - props, - dt, - lambda u, gradu, state, props, X, dt: 1.0, - mesh.blocks['block']) + integralOfOne = integrate_over_block( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: 1.0, + mesh.blocks["block"], + ) jnp.isclose(integralOfOne, 1.0) -def test_integrate_constant_field_multi_point_quadrature_na(fspace_fixture_2, mesh_and_disp): +def test_integrate_constant_field_multi_point_quadrature_na( + fspace_fixture_2, mesh_and_disp +): import jax.numpy as jnp + _, fs_na, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp - U_els = U[mesh.conns[mesh.blocks['block']], :] - X_els = U[mesh.conns[mesh.blocks['block']], :] + U_els = U[mesh.conns[mesh.blocks["block"]], :] + X_els = U[mesh.conns[mesh.blocks["block"]], :] integralOfOne = fs_na.integrate_on_elements( U_els, X_els, @@ -283,46 +324,53 @@ def test_integrate_constant_field_multi_point_quadrature_na(fspace_fixture_2, me jnp.isclose(integralOfOne, 1.0) - # TODO add integration test/method for new na fspace def test_integrate_linear_field_multi_point_quadrature(fspace_fixture_2, mesh_and_disp): from pancax.fem.function_space import integrate_over_block import jax.numpy as jnp + fs, _, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp - Ix = integrate_over_block(fs, - U, - mesh.coords, - state, - props, - dt, - lambda u, gradu, state, props, X, dt: gradu[0,0], - mesh.blocks['block']) - idx = jnp.argmax(mesh.coords[:,0]) - expected = U[idx,0]*(yRange[1] - yRange[0]) + Ix = integrate_over_block( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: gradu[0, 0], + mesh.blocks["block"], + ) + idx = jnp.argmax(mesh.coords[:, 0]) + expected = U[idx, 0] * (yRange[1] - yRange[0]) jnp.isclose(Ix, expected) - - Iy = integrate_over_block(fs, - U, - mesh.coords, - state, - props, - dt, - lambda u, gradu, state, props, X, dt: gradu[1,1], - mesh.blocks['block']) - idx = jnp.argmax(mesh.coords[:,1]) - expected = U[idx,1]*(xRange[1] - xRange[0]) + + Iy = integrate_over_block( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: gradu[1, 1], + mesh.blocks["block"], + ) + idx = jnp.argmax(mesh.coords[:, 1]) + expected = U[idx, 1] * (xRange[1] - xRange[0]) jnp.isclose(Iy, expected) -def test_integrate_linear_field_multi_point_quadrature_na(fspace_fixture_2, mesh_and_disp): +def test_integrate_linear_field_multi_point_quadrature_na( + fspace_fixture_2, mesh_and_disp +): import jax.numpy as jnp + _, fs_na, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp - U_els = U[mesh.conns[mesh.blocks['block']], :] - X_els = U[mesh.conns[mesh.blocks['block']], :] + U_els = U[mesh.conns[mesh.blocks["block"]], :] + X_els = U[mesh.conns[mesh.blocks["block"]], :] Ix = fs_na.integrate_on_elements( U_els, @@ -332,8 +380,8 @@ def test_integrate_linear_field_multi_point_quadrature_na(fspace_fixture_2, mesh dt, lambda u, gradu, state, props, X, dt: gradu[0, 0], ) - idx = jnp.argmax(mesh.coords[:,0]) - expected = U[idx,0]*(yRange[1] - yRange[0]) + idx = jnp.argmax(mesh.coords[:, 0]) + expected = U[idx, 0] * (yRange[1] - yRange[0]) jnp.isclose(Ix, expected) Iy = fs_na.integrate_on_elements( @@ -345,14 +393,15 @@ def test_integrate_linear_field_multi_point_quadrature_na(fspace_fixture_2, mesh lambda u, gradu, state, props, X, dt: gradu[1, 1], ) - idx = jnp.argmax(mesh.coords[:,1]) - expected = U[idx,1]*(xRange[1] - xRange[0]) + idx = jnp.argmax(mesh.coords[:, 1]) + expected = U[idx, 1] * (xRange[1] - xRange[0]) jnp.isclose(Iy, expected) def test_integrate_over_half_block(fspace_fixture_2, mesh_and_disp): from pancax.fem.function_space import integrate_over_block import jax.numpy as jnp + mesh, U, _ = mesh_and_disp fs, _, _, state, props, dt = fspace_fixture_2 nElements = mesh.num_elements @@ -360,21 +409,24 @@ def test_integrate_over_half_block(fspace_fixture_2, mesh_and_disp): # put this in so that if test is modified to odd number, # we understand why it fails assert nElements % 2 == 0 - - blockWithHalfTheVolume = slice(0,nElements//2) - integral = integrate_over_block(fs, - U, - mesh.coords, - state, - props, - dt, - lambda u, gradu, state, props, X, dt: 1.0, - blockWithHalfTheVolume) - jnp.isclose(integral, 1.0/2.0) + + blockWithHalfTheVolume = slice(0, nElements // 2) + integral = integrate_over_block( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: 1.0, + blockWithHalfTheVolume, + ) + jnp.isclose(integral, 1.0 / 2.0) def test_integrate_over_half_block_na(fspace_fixture_2, mesh_and_disp): import jax.numpy as jnp + _, fs_na, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp nElements = mesh.num_elements @@ -382,7 +434,7 @@ def test_integrate_over_half_block_na(fspace_fixture_2, mesh_and_disp): # put this in so that if test is modified to odd number, # we understand why it fails assert nElements % 2 == 0 - blockWithHalfTheVolume = slice(0,nElements//2) + blockWithHalfTheVolume = slice(0, nElements // 2) U_els = U[mesh.conns[blockWithHalfTheVolume], :] X_els = U[mesh.conns[blockWithHalfTheVolume], :] @@ -395,12 +447,13 @@ def test_integrate_over_half_block_na(fspace_fixture_2, mesh_and_disp): dt, lambda u, gradu, state, props, X, dt: 1.0, ) - jnp.isclose(integral, 1.0/2.0) + jnp.isclose(integral, 1.0 / 2.0) def test_integrate_over_half_block_indices(fspace_fixture_2, mesh_and_disp): from pancax.fem.function_space import integrate_over_block import jax.numpy as jnp + mesh, U, _ = mesh_and_disp fs, _, _, state, props, dt = fspace_fixture_2 nElements = mesh.num_elements @@ -408,22 +461,25 @@ def test_integrate_over_half_block_indices(fspace_fixture_2, mesh_and_disp): # put this in so that if test is modified to odd number, # we understand why it fails assert nElements % 2 == 0 - - blockWithHalfTheVolume = jnp.arange(nElements//2) - - integral = integrate_over_block(fs, - U, - mesh.coords, - state, - props, - dt, - lambda u, gradu, state, props, X, dt: 1.0, - blockWithHalfTheVolume) - jnp.isclose(integral, 1.0/2.0) - + + blockWithHalfTheVolume = jnp.arange(nElements // 2) + + integral = integrate_over_block( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: 1.0, + blockWithHalfTheVolume, + ) + jnp.isclose(integral, 1.0 / 2.0) + def test_integrate_over_half_block_indices_na(fspace_fixture_2, mesh_and_disp): import jax.numpy as jnp + _, fs_na, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp nElements = mesh.num_elements @@ -431,8 +487,8 @@ def test_integrate_over_half_block_indices_na(fspace_fixture_2, mesh_and_disp): # put this in so that if test is modified to odd number, # we understand why it fails assert nElements % 2 == 0 - - blockWithHalfTheVolume = jnp.arange(nElements//2) + + blockWithHalfTheVolume = jnp.arange(nElements // 2) U_els = U[mesh.conns[blockWithHalfTheVolume], :] X_els = U[mesh.conns[blockWithHalfTheVolume], :] @@ -445,29 +501,42 @@ def test_integrate_over_half_block_indices_na(fspace_fixture_2, mesh_and_disp): dt, lambda u, gradu, state, props, X, dt: 1.0, ) - jnp.isclose(integral, 1.0/2.0) + jnp.isclose(integral, 1.0 / 2.0) def test_jit_on_integration(fspace_fixture_2, mesh_and_disp): from pancax.fem.function_space import integrate_over_block import jax import jax.numpy as jnp + fs, _, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp integrate_jit = jax.jit(integrate_over_block, static_argnums=(6,)) - I = integrate_jit(fs, U, mesh.coords, state, props, dt, lambda u, gradu, state, props, X, dt: 1.0, mesh.blocks['block']) + I = integrate_jit( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: 1.0, + mesh.blocks["block"], + ) jnp.isclose(I, 1.0) def test_jit_on_integration_na(fspace_fixture_2, mesh_and_disp): import equinox as eqx import jax.numpy as jnp + _, fs_na, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp integrate_jit = eqx.filter_jit(fs_na.integrate_on_elements) - U_els = U[mesh.conns[mesh.blocks['block']], :] - X_els = U[mesh.conns[mesh.blocks['block']], :] - I = integrate_jit(U_els, X_els, state, props, dt, lambda u, gradu, state, props, X, dt: 1.0) + U_els = U[mesh.conns[mesh.blocks["block"]], :] + X_els = U[mesh.conns[mesh.blocks["block"]], :] + I = integrate_jit( + U_els, X_els, state, props, dt, lambda u, gradu, state, props, X, dt: 1.0 + ) jnp.isclose(I, 1.0) @@ -475,24 +544,42 @@ def test_jit_and_jacrev_on_integration(fspace_fixture_2, mesh_and_disp): from pancax.fem.function_space import integrate_over_block import jax import jax.numpy as jnp + fs, _, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp F = jax.jit(jax.jacrev(integrate_over_block, 1), static_argnums=(6,)) - dI = F(fs, U, mesh.coords, state, props, dt, lambda u, gradu, state, props, X, dt: 0.5*jnp.tensordot(gradu, gradu), mesh.blocks['block']) + dI = F( + fs, + U, + mesh.coords, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: 0.5 * jnp.tensordot(gradu, gradu), + mesh.blocks["block"], + ) nNodes = mesh.coords.shape[0] - interiorNodeIds = jnp.setdiff1d(jnp.arange(nNodes), mesh.nodeSets['all_boundary']) - jnp.array_equal(dI[interiorNodeIds,:], jnp.zeros_like(U[interiorNodeIds,:])) + interiorNodeIds = jnp.setdiff1d(jnp.arange(nNodes), mesh.nodeSets["all_boundary"]) + jnp.array_equal(dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :])) def test_jit_and_jacrev_on_integration_na(fspace_fixture_2, mesh_and_disp): import equinox as eqx import jax.numpy as jnp + _, fs_na, _, state, props, dt = fspace_fixture_2 mesh, U, _ = mesh_and_disp F = eqx.filter_jit(eqx.filter_jacrev(fs_na.integrate_on_elements)) - U_els = U[mesh.conns[mesh.blocks['block']], :] - X_els = U[mesh.conns[mesh.blocks['block']], :] - dI = F(U_els, X_els, state, props, dt, lambda u, gradu, state, props, X, dt: 0.5*jnp.tensordot(gradu, gradu)) + U_els = U[mesh.conns[mesh.blocks["block"]], :] + X_els = U[mesh.conns[mesh.blocks["block"]], :] + dI = F( + U_els, + X_els, + state, + props, + dt, + lambda u, gradu, state, props, X, dt: 0.5 * jnp.tensordot(gradu, gradu), + ) nNodes = mesh.coords.shape[0] - interiorNodeIds = jnp.setdiff1d(jnp.arange(nNodes), mesh.nodeSets['all_boundary']) - jnp.array_equal(dI[interiorNodeIds,:], jnp.zeros_like(U[interiorNodeIds,:])) + interiorNodeIds = jnp.setdiff1d(jnp.arange(nNodes), mesh.nodeSets["all_boundary"]) + jnp.array_equal(dI[interiorNodeIds, :], jnp.zeros_like(U[interiorNodeIds, :])) diff --git a/test/fem/test_mesh.py b/test/fem/test_mesh.py index cfd4f0c..ab962c0 100644 --- a/test/fem/test_mesh.py +++ b/test/fem/test_mesh.py @@ -6,21 +6,26 @@ Nx = 3 Ny = 2 -xRange = [0.,1.] -yRange = [0.,1.] -targetDispGrad = np.array([[0.1, -0.2],[0.4, -0.1]]) +xRange = [0.0, 1.0] +yRange = [0.0, 1.0] +targetDispGrad = np.array([[0.1, -0.2], [0.4, -0.1]]) def triangle_inradius(tcoords): tcoords = onp.hstack((tcoords, onp.ones((tcoords.shape[0], 1)))) - area = 0.5*onp.cross(tcoords[1]-tcoords[0], tcoords[2]-tcoords[0])[2] - peri = (onp.linalg.norm(tcoords[1]-tcoords[0]) - + onp.linalg.norm(tcoords[2]-tcoords[1]) - + onp.linalg.norm(tcoords[0]-tcoords[2])) - return area/peri + area = 0.5 * onp.cross(tcoords[1] - tcoords[0], tcoords[2] - tcoords[0])[2] + peri = ( + onp.linalg.norm(tcoords[1] - tcoords[0]) + + onp.linalg.norm(tcoords[2] - tcoords[1]) + + onp.linalg.norm(tcoords[0] - tcoords[2]) + ) + return area / peri -mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x: np.dot(targetDispGrad, x)) + +mesh, U = create_mesh_and_disp( + Nx, Ny, xRange, yRange, lambda x: np.dot(targetDispGrad, x) +) def test_create_nodesets_from_sidesets(): @@ -29,7 +34,7 @@ def test_create_nodesets_from_sidesets(): # this test relies on the fact that matching nodesets # and sidesets were created on the MeshFixture - + for key in mesh.sideSets: assert np.array_equal(mesh.nodeSets[key], nodeSets[key]) @@ -37,12 +42,7 @@ def test_create_nodesets_from_sidesets(): def test_edge_connectivities(): edgeConns, _ = fem.create_edges(mesh.conns) - goldBoundaryEdgeConns = np.array([[0, 1], - [1, 2], - [2, 5], - [5, 4], - [4, 3], - [3, 0]]) + goldBoundaryEdgeConns = np.array([[0, 1], [1, 2], [2, 5], [5, 4], [4, 3], [3, 0]]) # Check that every boundary edge has been found. # Boundary edges must appear with the same connectivity order, @@ -63,9 +63,7 @@ def test_edge_connectivities(): # sense the vertices should be ordered, so we check # for both permutations. - goldInteriorEdgeConns = np.array([[0, 4], - [1, 4], - [1, 5]]) + goldInteriorEdgeConns = np.array([[0, 4], [1, 4], [1, 5]]) nInteriorEdges = goldInteriorEdgeConns.shape[0] interiorEdgeFound = onp.full(nInteriorEdges, False) @@ -80,30 +78,25 @@ def test_edge_connectivities(): def test_edge_to_neighbor_cells_data(): edgeConns, edges = fem.create_edges(mesh.conns) - goldBoundaryEdgeConns = np.array([[0, 1], - [1, 2], - [2, 5], - [5, 4], - [4, 3], - [3, 0]]) + goldBoundaryEdgeConns = np.array([[0, 1], [1, 2], [2, 5], [5, 4], [4, 3], [3, 0]]) - goldBoundaryEdges = onp.array([[0, 0, -1, -1], - [2, 0, -1, -1], - [2, 1, -1, -1], - [3, 1, -1, -1], - [1, 1, -1, -1], - [1, 2, -1, -1]]) + goldBoundaryEdges = onp.array( + [ + [0, 0, -1, -1], + [2, 0, -1, -1], + [2, 1, -1, -1], + [3, 1, -1, -1], + [1, 1, -1, -1], + [1, 2, -1, -1], + ] + ) for be, bc in zip(goldBoundaryEdges, goldBoundaryEdgeConns): i = np.where(onp.all(edgeConns == bc, axis=1)) assert np.all(edges[i, :] == be) - goldInteriorEdgeConns = np.array([[0, 4], - [1, 4], - [5, 1]]) - goldInteriorEdges = onp.array([[1, 0, 0, 2], - [0, 1, 3, 2], - [2, 2, 3, 0]]) + goldInteriorEdgeConns = np.array([[0, 4], [1, 4], [5, 1]]) + goldInteriorEdges = onp.array([[1, 0, 0, 2], [0, 1, 3, 2], [2, 2, 3, 0]]) for ie, ic in zip(goldInteriorEdges, goldInteriorEdgeConns): foundWithSameSense = onp.any(onp.all(edgeConns == ic, axis=1)) @@ -117,8 +110,8 @@ def test_edge_to_neighbor_cells_data(): edgeData = ie[[2, 3, 0, 1]] else: # self.fail('edge not found with vertices ' + str(ic)) - print('Need to raise an exception test here') - edgeDataMatches = np.all(edges[i,:] == edgeData) + print("Need to raise an exception test here") + edgeDataMatches = np.all(edges[i, :] == edgeData) assert edgeDataMatches @@ -143,7 +136,7 @@ def test_conversion_to_quadratic_mesh_is_valid(): parentArea = triangle_inradius(parentCoords) childArea = triangle_inradius(childCoords) - + assert parentArea > 0.0 assert childArea > 0.0 assert np.abs(parentArea - 2.0 * childArea) < 1e-10 diff --git a/test/fem/test_quadrature_rules.py b/test/fem/test_quadrature_rules.py index ad23130..0c0621f 100644 --- a/test/fem/test_quadrature_rules.py +++ b/test/fem/test_quadrature_rules.py @@ -8,32 +8,32 @@ # integrate x^n y^m on unit triangle def integrate_2D_monomial_on_triangle(n, m): p = n + m - return 1.0/((p + 2)*(p + 1)*binom(p, n)) + return 1.0 / ((p + 2) * (p + 1) * binom(p, n)) def is_inside_hex(point): - x_condition = (point[0] >= -1.) and (point[0] <= 1.) - y_condition = (point[1] >= -1.) and (point[1] <= 1.) - z_condition = (point[2] >= -1.) and (point[2] <= 1.) + x_condition = (point[0] >= -1.0) and (point[0] <= 1.0) + y_condition = (point[1] >= -1.0) and (point[1] <= 1.0) + z_condition = (point[2] >= -1.0) and (point[2] <= 1.0) return x_condition and y_condition and z_condition def is_inside_quad(point): - x_condition = (point[0] >= -1.) and (point[0] <= 1.) - y_condition = (point[1] >= -1.) and (point[1] <= 1.) + x_condition = (point[0] >= -1.0) and (point[0] <= 1.0) + y_condition = (point[1] >= -1.0) and (point[1] <= 1.0) return x_condition and y_condition def is_inside_tet(point): - x_condition = (point[0] >= 0.) and (point[0] <= 1.) - y_condition = (point[1] >= 0.) and (point[1] <= 1. - point[0]) - z_condition = (point[2] >= 0.) and (point[2] <= 1. - point[0] - point[1]) + x_condition = (point[0] >= 0.0) and (point[0] <= 1.0) + y_condition = (point[1] >= 0.0) and (point[1] <= 1.0 - point[0]) + z_condition = (point[2] >= 0.0) and (point[2] <= 1.0 - point[0] - point[1]) return x_condition and y_condition and z_condition def is_inside_triangle(point): - x_condition = (point[0] >= 0.) and (point[0] <= 1.) - y_condition = (point[1] >= 0.) and (point[1] <= 1. - point[0]) + x_condition = (point[0] >= 0.0) and (point[0] <= 1.0) + y_condition = (point[1] >= 0.0) and (point[1] <= 1.0 - point[0]) return x_condition and y_condition @@ -79,18 +79,20 @@ def is_inside_unit_interval(point): in_domain_methods.append(is_inside_tet) -@pytest.mark.parametrize('el, q', zip(elements_to_test, q_degrees)) +@pytest.mark.parametrize("el, q", zip(elements_to_test, q_degrees)) def test_are_postive_weights(el, q): if type(el) == Tet4Element and q == 2: - pytest.skip('Not relevant for Tet4Element and q_degree = 2') + pytest.skip("Not relevant for Tet4Element and q_degree = 2") if type(el) == Tet10Element and q == 2: - pytest.skip('Not relevant for Tet10Element and q_degree = 2') + pytest.skip("Not relevant for Tet10Element and q_degree = 2") qr = QuadratureRule(el, q) _, w = qr - assert jnp.all(w > 0.) + assert jnp.all(w > 0.0) -@pytest.mark.parametrize('el, q, is_inside', zip(elements_to_test, q_degrees, in_domain_methods)) +@pytest.mark.parametrize( + "el, q, is_inside", zip(elements_to_test, q_degrees, in_domain_methods) +) def test_are_inside_domain(el, q, is_inside): qr = QuadratureRule(el, q) for point in qr.xigauss: @@ -104,13 +106,12 @@ def test_triangle_quadrature_exactness(): qr = QuadratureRule(SimplexTriElement(1), degree) for i in range(degree + 1): for j in range(degree + 1 - i): - monomial = qr.xigauss[:,0]**i * qr.xigauss[:,1]**j + monomial = qr.xigauss[:, 0] ** i * qr.xigauss[:, 1] ** j quadratureAnswer = jnp.sum(monomial * qr.wgauss) exactAnswer = integrate_2D_monomial_on_triangle(i, j) assert jnp.abs(quadratureAnswer - exactAnswer) < 1e-14 - def test_len_method(): qr = QuadratureRule(Hex8Element(), 1) assert len(qr) == 1 diff --git a/test/fem/test_read_exodus_mesh.py b/test/fem/test_read_exodus_mesh.py index 1f1a5bd..39f798a 100644 --- a/test/fem/test_read_exodus_mesh.py +++ b/test/fem/test_read_exodus_mesh.py @@ -3,25 +3,25 @@ def test_read_exodus_mesh_hex8(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mesh_hex8.g') + f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_hex8.g") mesh = read_exodus_mesh(f) def test_read_exodus_mesh_quad4(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mesh_quad4.g') + f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_quad4.g") mesh = read_exodus_mesh(f) def test_read_exodus_mesh_quad9(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mesh_quad9.g') + f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_quad9.g") mesh = read_exodus_mesh(f) def test_read_exodus_mesh_tri(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mesh_no_ssets.g') + f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_no_ssets.g") mesh = read_exodus_mesh(f) def test_read_exodus_mesh_with_ssets_tri(): - f = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mesh_1x.g') + f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mesh_1x.g") mesh = read_exodus_mesh(f) diff --git a/test/fem/test_surface.py b/test/fem/test_surface.py index dfb9ef4..d34d6ad 100644 --- a/test/fem/test_surface.py +++ b/test/fem/test_surface.py @@ -8,40 +8,36 @@ Ny = 4 L = 1.2 W = 1.5 -xRange = [0., L] -yRange = [0., W] +xRange = [0.0, L] +yRange = [0.0, W] -targetDispGrad = jnp.zeros((2,2)) +targetDispGrad = jnp.zeros((2, 2)) -mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, - lambda x : targetDispGrad.dot(x)) +mesh, U = create_mesh_and_disp(Nx, Ny, xRange, yRange, lambda x: targetDispGrad.dot(x)) quadRule = QuadratureRule(LineElement(1), 2) - + def test_integrate_perimeter(): print(mesh) - p = integrate_function_on_surface(quadRule, - mesh.sideSets['all_boundary'], - mesh, - lambda x, n: 1.0) + p = integrate_function_on_surface( + quadRule, mesh.sideSets["all_boundary"], mesh, lambda x, n: 1.0 + ) # assertNear(p, 2*(L+W), 14) - assert jnp.abs(p - 2 * (L + W)) < 1.e-14 + assert jnp.abs(p - 2 * (L + W)) < 1.0e-14 + - def test_integrate_quadratic_fn_on_surface(): - I = integrate_function_on_surface(quadRule, - mesh.sideSets['top'], - mesh, - lambda x, n: x[0]**2) + I = integrate_function_on_surface( + quadRule, mesh.sideSets["top"], mesh, lambda x, n: x[0] ** 2 + ) # assertNear(I, L**3/3.0, 14) - assert jnp.abs(I - L**(3/3.0)) < 1.e14 + assert jnp.abs(I - L ** (3 / 3.0)) < 1.0e14 + - def test_integrate_function_on_surface_that_uses_coords_and_normal(): - I = integrate_function_on_surface(quadRule, - mesh.sideSets['all_boundary'], - mesh, - lambda x, n: jnp.dot(x,n)) - assert jnp.abs(I - 2*L*W) < 1e-14 + I = integrate_function_on_surface( + quadRule, mesh.sideSets["all_boundary"], mesh, lambda x, n: jnp.dot(x, n) + ) + assert jnp.abs(I - 2 * L * W) < 1e-14 diff --git a/test/fem/utils.py b/test/fem/utils.py index 4a0a034..fa00e12 100644 --- a/test/fem/utils.py +++ b/test/fem/utils.py @@ -4,43 +4,45 @@ import jax.numpy as jnp -def create_mesh_and_disp(Nx, Ny, xRange, yRange, initial_disp_func, setNamePostFix=''): +def create_mesh_and_disp(Nx, Ny, xRange, yRange, initial_disp_func, setNamePostFix=""): coords, conns = create_structured_mesh_data(Nx, Ny, xRange, yRange) tol = 1e-7 nodeSets = {} - nodeSets['left'+setNamePostFix] = jnp.flatnonzero(coords[:,0] < xRange[0] + tol) - nodeSets['bottom'+setNamePostFix] = jnp.flatnonzero(coords[:,1] < yRange[0] + tol) - nodeSets['right'+setNamePostFix] = jnp.flatnonzero(coords[:,0] > xRange[1] - tol) - nodeSets['top'+setNamePostFix] = jnp.flatnonzero(coords[:,1] > yRange[1] - tol) - nodeSets['all_boundary'+setNamePostFix] = jnp.flatnonzero( - (coords[:,0] < xRange[0] + tol) | - (coords[:,1] < yRange[0] + tol) | - (coords[:,0] > xRange[1] - tol) | - (coords[:,1] > yRange[1] - tol) + nodeSets["left" + setNamePostFix] = jnp.flatnonzero(coords[:, 0] < xRange[0] + tol) + nodeSets["bottom" + setNamePostFix] = jnp.flatnonzero( + coords[:, 1] < yRange[0] + tol ) - + nodeSets["right" + setNamePostFix] = jnp.flatnonzero(coords[:, 0] > xRange[1] - tol) + nodeSets["top" + setNamePostFix] = jnp.flatnonzero(coords[:, 1] > yRange[1] - tol) + nodeSets["all_boundary" + setNamePostFix] = jnp.flatnonzero( + (coords[:, 0] < xRange[0] + tol) + | (coords[:, 1] < yRange[0] + tol) + | (coords[:, 0] > xRange[1] - tol) + | (coords[:, 1] > yRange[1] - tol) + ) + def is_edge_on_left(xyOnEdge): - return jnp.all( xyOnEdge[:,0] < xRange[0] + tol ) + return jnp.all(xyOnEdge[:, 0] < xRange[0] + tol) def is_edge_on_bottom(xyOnEdge): - return jnp.all( xyOnEdge[:,1] < yRange[0] + tol ) + return jnp.all(xyOnEdge[:, 1] < yRange[0] + tol) def is_edge_on_right(xyOnEdge): - return jnp.all( xyOnEdge[:,0] > xRange[1] - tol ) - + return jnp.all(xyOnEdge[:, 0] > xRange[1] - tol) + def is_edge_on_top(xyOnEdge): - return jnp.all( xyOnEdge[:,1] > yRange[1] - tol ) + return jnp.all(xyOnEdge[:, 1] > yRange[1] - tol) sideSets = {} - sideSets['left'+setNamePostFix] = create_edges(coords, conns, is_edge_on_left) - sideSets['bottom'+setNamePostFix] = create_edges(coords, conns, is_edge_on_bottom) - sideSets['right'+setNamePostFix] = create_edges(coords, conns, is_edge_on_right) - sideSets['top'+setNamePostFix] = create_edges(coords, conns, is_edge_on_top) + sideSets["left" + setNamePostFix] = create_edges(coords, conns, is_edge_on_left) + sideSets["bottom" + setNamePostFix] = create_edges(coords, conns, is_edge_on_bottom) + sideSets["right" + setNamePostFix] = create_edges(coords, conns, is_edge_on_right) + sideSets["top" + setNamePostFix] = create_edges(coords, conns, is_edge_on_top) print(sideSets.values()) allBoundaryEdges = jnp.vstack([s for s in sideSets.values()]) - sideSets['all_boundary'+setNamePostFix] = allBoundaryEdges + sideSets["all_boundary" + setNamePostFix] = allBoundaryEdges - blocks = {'block'+setNamePostFix: jnp.arange(conns.shape[0])} + blocks = {"block" + setNamePostFix: jnp.arange(conns.shape[0])} mesh = construct_mesh_from_basic_data(coords, conns, blocks, nodeSets, sideSets) print(mesh) return mesh, vmap(initial_disp_func)(mesh.coords)