Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/run_unit_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
HIP_VISIBLE_DEVICES=0 julia +1.12 --project=@. -e 'using Pkg; Pkg.test()'
6 changes: 3 additions & 3 deletions src/Parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ function Parameters(
@assert isa(properties, NamedTuple)
end

# state_old = Array{Float64, 3}[]
state_old = Array{Float64, 3}[]
# properties = []
state_old = L2QuadratureField[]
# state_old = L2QuadratureField[]
for (key, val) in pairs(physics)
# create properties for this block physics
# TODO specialize to allow for element level properties
Expand All @@ -99,7 +99,7 @@ function Parameters(
state_old_temp[:, q, e] = create_initial_state(val)
end
end
state_old_temp = L2QuadratureField(state_old_temp)
# state_old_temp = L2QuadratureField(state_old_temp)

push!(state_old, state_old_temp)
end
Expand Down
18 changes: 10 additions & 8 deletions src/assemblers/Assemblers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ function _assemble_element!(global_val::H1Field, local_val, conn)
n_dofs = size(global_val, 1)
for i in axes(conn, 1)
for d in 1:n_dofs
# n = 2 * i + d
global_id = n_dofs * (conn[i] - 1) + d
local_id = n_dofs * (i - 1) + d
global_val[global_id] += local_val[local_id]
Expand Down Expand Up @@ -167,13 +166,16 @@ end
"""
$(TYPEDSIGNATURES)
"""
function _quadrature_level_state(state::L2QuadratureField, q::Int, e::Int)
NS = size(state, 1)
if NS > 0
state_q = @views SVector{size(state, 1), eltype(state)}(state[:, q, e])
else
state_q = SVector{0, eltype(state)}()
end
function _quadrature_level_state(state::AbstractArray{<:Number, 3}, q::Int, e::Int)
# NS = size(state, 1)
# if NS > 0
# state_q = @views SVector{size(state, 1), eltype(state)}(state[:, q, e])
# else
# state_q = SVector{0, eltype(state)}()
# end
# return state_q
# NS = size(state, 1)
state_q = view(state, :, q, e)
return state_q
end

Expand Down
18 changes: 7 additions & 11 deletions src/assemblers/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ function _assemble_block_matrix!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}

for e in axes(conns, 2)
Expand All @@ -90,11 +90,9 @@ function _assemble_block_matrix!(
for q in 1:num_quadrature_points(ref_fe)
interps = _cell_interpolants(ref_fe, q)
state_old_q = _quadrature_level_state(state_old, q, e)
K_q, state_new_q = func(physics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, props_el)
state_new_q = _quadrature_level_state(state_new, q, e)
K_q = func(physics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, state_new_q, props_el)
K_el = K_el + K_q
for s in 1:length(state_old)
state_new[s, q, e] = state_new_q[s]
end
end
_assemble_element!(field, K_el, e, block_start_index, block_el_level_size)
end
Expand All @@ -115,7 +113,7 @@ KA.@kernel function _assemble_block_matrix_kernel!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
E = KA.@index(Global)

Expand All @@ -127,11 +125,9 @@ KA.@kernel function _assemble_block_matrix_kernel!(
for q in 1:num_quadrature_points(ref_fe)
interps = _cell_interpolants(ref_fe, q)
state_old_q = _quadrature_level_state(state_old, q, E)
K_q, state_new_q = func(physics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, props_el)
state_new_q = _quadrature_level_state(state_new, q, E)
K_q = func(physics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, state_new_q, props_el)
K_el = K_el + K_q
for s in 1:length(state_old)
state_new[s, q, E] = state_new_q[s]
end
end

# leaving here just in case
Expand Down Expand Up @@ -167,7 +163,7 @@ function _assemble_block_matrix!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
kernel! = _assemble_block_matrix_kernel!(backend)
kernel!(
Expand Down
20 changes: 7 additions & 13 deletions src/assemblers/MatrixAction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function _assemble_block_matrix_action!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
for e in axes(conns, 2)
x_el = _element_level_fields_flat(X, ref_fe, conns, e)
Expand All @@ -81,12 +81,9 @@ function _assemble_block_matrix_action!(
for q in 1:num_quadrature_points(ref_fe)
interps = _cell_interpolants(ref_fe, q)
state_old_q = _quadrature_level_state(state_old, q, e)
K_q, state_new_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, props_el)
state_new_q = _quadrature_level_state(state_new, q, e)
K_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, state_new_q, props_el)
K_el = K_el + K_q
# update state here
for s in 1:length(state_old)
state_new[s, q, e] = state_new_q[s]
end
end
Kv_el = K_el * v_el
@views _assemble_element!(field, Kv_el, conns[:, e])
Expand All @@ -113,7 +110,7 @@ KA.@kernel function _assemble_block_matrix_action_kernel!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
E = KA.@index(Global)

Expand All @@ -126,12 +123,9 @@ KA.@kernel function _assemble_block_matrix_action_kernel!(
for q in 1:num_quadrature_points(ref_fe)
interps = _cell_interpolants(ref_fe, q)
state_old_q = _quadrature_level_state(state_old, q, E)
K_q, state_new_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, props_el)
state_new_q = _quadrature_level_state(state_new, q, E)
K_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, state_new_q, props_el)
K_el = K_el + K_q
# update state here
for s in 1:length(state_old)
state_new[s, q, E] = state_new_q[s]
end
end
Kv_el = K_el * v_el

Expand Down Expand Up @@ -166,7 +160,7 @@ function _assemble_block_matrix_action!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
kernel! = _assemble_block_matrix_action_kernel!(backend)
kernel!(
Expand Down
19 changes: 7 additions & 12 deletions src/assemblers/QuadratureQuantity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function _assemble_block_quadrature_quantity!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
for e in axes(conns, 2)
x_el = _element_level_fields_flat(X, ref_fe, conns, e)
Expand All @@ -80,12 +80,9 @@ function _assemble_block_quadrature_quantity!(
for q in 1:num_quadrature_points(ref_fe)
interps = _cell_interpolants(ref_fe, q)
state_old_q = _quadrature_level_state(state_old, q, e)
e_q, state_new_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, props_el)
state_new_q = _quadrature_level_state(state_new, q, e)
e_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, state_new_q, props_el)
field[q, e] = e_q
# update state here
for s in 1:length(state_old)
state_new[s, q, e] = state_new_q[s]
end
end
end
end
Expand All @@ -108,7 +105,7 @@ KA.@kernel function _assemble_block_quadrature_quantity_kernel!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
# Q, E = KA.@index(Global, NTuple)
E = KA.@index(Global)
Expand All @@ -120,11 +117,9 @@ KA.@kernel function _assemble_block_quadrature_quantity_kernel!(
KA.Extras.@unroll for q in 1:num_quadrature_points(ref_fe)
interps = _cell_interpolants(ref_fe, q)
state_old_q = _quadrature_level_state(state_old, q, E)
e_q, state_new_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, props_el)
state_new_q = _quadrature_level_state(state_new, q, E)
e_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, state_new_q, props_el)
@inbounds field[q, E] = e_q
for s in 1:length(state_old)
@inbounds state_new[s, q, E] = state_new_q[s]
end
end
end
# COV_EXCL_STOP
Expand All @@ -148,7 +143,7 @@ function _assemble_block_quadrature_quantity!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
kernel! = _assemble_block_quadrature_quantity_kernel!(backend)
kernel!(
Expand Down
20 changes: 7 additions & 13 deletions src/assemblers/Vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ function _assemble_block_vector!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
for e in axes(conns, 2)
x_el = _element_level_fields_flat(X, ref_fe, conns, e)
Expand All @@ -82,12 +82,9 @@ function _assemble_block_vector!(
for q in 1:num_quadrature_points(ref_fe)
interps = _cell_interpolants(ref_fe, q)
state_old_q = _quadrature_level_state(state_old, q, e)
R_q, state_new_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, props_el)
state_new_q = _quadrature_level_state(state_new, q, e)
R_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, state_new_q, props_el)
R_el = R_el + R_q
# update state here
for s in 1:length(state_old)
state_new[s, q, e] = state_new_q[s]
end
end

@views _assemble_element!(field, R_el, conns[:, e])
Expand All @@ -114,7 +111,7 @@ KA.@kernel function _assemble_block_vector_kernel!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
E = KA.@index(Global)

Expand All @@ -127,12 +124,9 @@ KA.@kernel function _assemble_block_vector_kernel!(
for q in 1:num_quadrature_points(ref_fe)
interps = _cell_interpolants(ref_fe, q)
state_old_q = _quadrature_level_state(state_old, q, E)
R_q, state_new_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, props_el)
state_new_q = _quadrature_level_state(state_new, q, E)
R_q = func(physics, interps, x_el, t, Δt, u_el, u_el_old, state_old_q, state_new_q, props_el)
R_el = R_el + R_q
# update state here
for s in 1:length(state_old)
state_new[s, q, E] = state_new_q[s]
end
end

# now assemble atomically
Expand Down Expand Up @@ -166,7 +160,7 @@ function _assemble_block_vector!(
) where {
T <: Number,
Solution <: AbstractField,
S <: L2QuadratureField
S #<: L2QuadratureField
}
kernel! = _assemble_block_vector_kernel!(backend)
kernel!(
Expand Down
18 changes: 9 additions & 9 deletions src/integrals/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ function ScalarIntegral(asm, integrand)
field = L2ElementField(zeros(Float64, NQ, NE))
push!(cache, field)
end
cache = NamedTuple{keys(fspace.ref_fes)}(tuple(scalar_quadarature_storage...))
cache = NamedTuple{keys(fspace.ref_fes)}(tuple(cache...))
return ScalarIntegral(asm, cache, integrand)
end

# function gradient(integral::ScalarIntegral)
# # func(physics, interps, x, t, dt, u, u_n, state_old, props) = ForwardDiff.gradient(
# # z -> integral.integrand(physics, interps, x, t, dt, z, u_n, state_old, props)[1],
# # )
# function integrand_grad(physics, interps, x, t, dt, u, u_n, state_old, props)
# return ForwardDiff.gradient()
# # return VectorIntegral
# end
function gradient(integral::ScalarIntegral)
func(physics, interps, x, t, dt, u, u_n, state_old, state_new, props) = ForwardDiff.gradient(
z -> integral.integrand(physics, interps, x, t, dt, z, u_n, state_old, state_new, props),
u
)
cache = create_field(integral.assembler)
return VectorIntegral(integral.assembler, cache, func)
end

function integrate(integral::ScalarIntegral, U, p)
cache, dof = integral.cache, integral.assembler.dof
Expand Down
15 changes: 15 additions & 0 deletions test/TestIntegrals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
function test_scalar_integral()
mesh_file = "./poisson/poisson.g"
mesh = UnstructuredMesh(mesh_file)
V = FunctionSpace(mesh, H1Field, Lagrange)
u = ScalarFunction(V, :u)
asm = SparseMatrixAssembler(u)
f(X, _) = 2. * π^2 * sin(π * X[1]) * sin(π * X[2])
physics = Poisson(f)
integral = FiniteElementContainers.ScalarIntegral(asm, energy)
grad_integral = FiniteElementContainers.gradient(integral)
end

@testset "Integrals" begin
test_scalar_integral()
end
12 changes: 6 additions & 6 deletions test/mechanics/TestMechanicsCommon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end
end

@inline function FiniteElementContainers.energy(
physics::Mechanics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, props_el
physics::Mechanics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, state_new_q, props_el
)
interps = map_interpolants(interps, x_el)
(; X_q, N, ∇N_X, JxW) = interps
Expand All @@ -45,12 +45,12 @@ end
∇u_q = modify_field_gradients(physics.formulation, ∇u_q)
# constitutive
ψ_q = strain_energy(∇u_q, state_old_q, props_el, dt)
return JxW * ψ_q, state_old_q
return JxW * ψ_q
end

# note for CUDA things crash without inline
@inline function FiniteElementContainers.residual(
physics::Mechanics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, props_el
physics::Mechanics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, state_new_q, props_el
)
interps = map_interpolants(interps, x_el)
(; X_q, N, ∇N_X, JxW) = interps
Expand All @@ -64,11 +64,11 @@ end
P_q = extract_stress(physics.formulation, P_q)
G_q = discrete_gradient(physics.formulation, ∇N_X)
f_q = G_q * P_q
return JxW * f_q[:], state_old_q
return JxW * f_q[:]
end

@inline function FiniteElementContainers.stiffness(
physics::Mechanics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, props_el
physics::Mechanics, interps, x_el, t, dt, u_el, u_el_old, state_old_q, state_new_q, props_el
)
interps = map_interpolants(interps, x_el)
(; X_q, N, ∇N_X, JxW) = interps
Expand All @@ -83,7 +83,7 @@ end
# turn into voigt notation
K_q = extract_stiffness(physics.formulation, K_q)
G_q = discrete_gradient(physics.formulation, ∇N_X)
return JxW * G_q * K_q * G_q', state_old_q
return JxW * G_q * K_q * G_q'
# K_q = ForwardDiff.hessian(z -> energy(physics, interps, z, x_el, state_old_q, props_el, t, dt)[1], u_el)
# @show K_q
# return K_q, state_old_q
Expand Down
Loading
Loading