Skip to content

Commit 23cfcf5

Browse files
committedJan 15, 2025·
added: separate function for each RungeKutta implementation
1 parent 4849eb2 commit 23cfcf5

File tree

1 file changed

+55
-39
lines changed

1 file changed

+55
-39
lines changed
 

‎src/model/solver.jl

+55-39
Original file line numberDiff line numberDiff line change
@@ -39,54 +39,70 @@ RungeKutta(order::Int=4; supersample::Int=1) = RungeKutta(order, supersample)
3939

4040
"Get the `f!` and `h!` functions for the explicit Runge-Kutta solvers."
4141
function get_solver_functions(NT::DataType, solver::RungeKutta, fc!, hc!, Ts, _ , nx, _ , _ )
42-
order = solver.order
43-
Ts_inner = Ts/solver.supersample
4442
Nc = nx + 2
43+
f! = if solver.order==4
44+
get_rk4_function(NT, solver, fc!, Ts, nx, Nc)
45+
elseif solver.order==1
46+
get_euler_function(NT, solver, fc!, Ts, nx, Nc)
47+
else
48+
throw(ArgumentError("only 1st and 4th order Runge-Kutta is supported."))
49+
end
50+
h! = hc!
51+
return f!, h!
52+
end
53+
54+
"Get the f! function for the 4th order explicit Runge-Kutta solver."
55+
function get_rk4_function(NT, solver, fc!, Ts, nx, Nc)
56+
Ts_inner = Ts/solver.supersample
4557
xcur_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
4658
k1_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
4759
k2_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
4860
k3_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
4961
k4_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
50-
if order==1
51-
f! = function euler_solver!(xnext, x, u, d, p)
52-
CT = promote_type(eltype(x), eltype(u), eltype(d))
53-
xcur = get_tmp(xcur_cache, CT)
54-
k1 = get_tmp(k1_cache, CT)
55-
xterm = xnext
56-
@. xcur = x
57-
for i=1:solver.supersample
58-
fc!(k1, xcur, u, d, p)
59-
@. xcur = xcur + k1 * Ts_inner
60-
end
61-
@. xnext = xcur
62-
return nothing
62+
f! = function rk4_solver!(xnext, x, u, d, p)
63+
CT = promote_type(eltype(x), eltype(u), eltype(d))
64+
xcur = get_tmp(xcur_cache, CT)
65+
k1 = get_tmp(k1_cache, CT)
66+
k2 = get_tmp(k2_cache, CT)
67+
k3 = get_tmp(k3_cache, CT)
68+
k4 = get_tmp(k4_cache, CT)
69+
xterm = xnext
70+
@. xcur = x
71+
for i=1:solver.supersample
72+
fc!(k1, xcur, u, d, p)
73+
@. xterm = xcur + k1 * Ts_inner/2
74+
fc!(k2, xterm, u, d, p)
75+
@. xterm = xcur + k2 * Ts_inner/2
76+
fc!(k3, xterm, u, d, p)
77+
@. xterm = xcur + k3 * Ts_inner
78+
fc!(k4, xterm, u, d, p)
79+
@. xcur = xcur + (k1 + 2k2 + 2k3 + k4)*Ts_inner/6
6380
end
64-
elseif order==4
65-
f! = function rk4_solver!(xnext, x, u, d, p)
66-
CT = promote_type(eltype(x), eltype(u), eltype(d))
67-
xcur = get_tmp(xcur_cache, CT)
68-
k1 = get_tmp(k1_cache, CT)
69-
k2 = get_tmp(k2_cache, CT)
70-
k3 = get_tmp(k3_cache, CT)
71-
k4 = get_tmp(k4_cache, CT)
72-
xterm = xnext
73-
@. xcur = x
74-
for i=1:solver.supersample
75-
fc!(k1, xcur, u, d, p)
76-
@. xterm = xcur + k1 * Ts_inner/2
77-
fc!(k2, xterm, u, d, p)
78-
@. xterm = xcur + k2 * Ts_inner/2
79-
fc!(k3, xterm, u, d, p)
80-
@. xterm = xcur + k3 * Ts_inner
81-
fc!(k4, xterm, u, d, p)
82-
@. xcur = xcur + (k1 + 2k2 + 2k3 + k4)*Ts_inner/6
83-
end
84-
@. xnext = xcur
85-
return nothing
81+
@. xnext = xcur
82+
return nothing
83+
end
84+
return f!
85+
end
86+
87+
"Get the f! function for the explicit Euler solver."
88+
function get_euler_function(NT, solver, fc!, Ts, nx, Nc)
89+
Ts_inner = Ts/solver.supersample
90+
xcur_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
91+
k_cache::DiffCache{Vector{NT}, Vector{NT}} = DiffCache(zeros(NT, nx), Nc)
92+
f! = function euler_solver!(xnext, x, u, d, p)
93+
CT = promote_type(eltype(x), eltype(u), eltype(d))
94+
xcur = get_tmp(xcur_cache, CT)
95+
k = get_tmp(k_cache, CT)
96+
xterm = xnext
97+
@. xcur = x
98+
for i=1:solver.supersample
99+
fc!(k, xcur, u, d, p)
100+
@. xcur = xcur + k * Ts_inner
86101
end
102+
@. xnext = xcur
103+
return nothing
87104
end
88-
h! = hc!
89-
return f!, h!
105+
return f!
90106
end
91107

92108
"""

0 commit comments

Comments
 (0)