diff --git a/lib/OrdinaryDiffEqExplicitRK/src/explicit_rk_interp.jl b/lib/OrdinaryDiffEqExplicitRK/src/explicit_rk_interp.jl new file mode 100644 index 0000000000..3824b31a91 --- /dev/null +++ b/lib/OrdinaryDiffEqExplicitRK/src/explicit_rk_interp.jl @@ -0,0 +1,45 @@ +""" +Generic interpolation for Runge-Kutta methods. +Arguments: +- Θ: interpolation parameter (0 ≤ Θ ≤ 1) +- dt: time step +- y₀: initial value +- k: stage derivatives (vector of vectors, one per component) +- tableau: coefficient matrix where each row contains polynomial coefficients for a stage + Each row i contains [a₀, a₁, a₂, ...] for polynomial aᵢ₀ + aᵢ₁*Θ + aᵢ₂*Θ² + ... +- idxs: indices (optional, for partial interpolation) +- order: 0 for value, 1 for derivative +""" +function generic_interpolant(Θ, dt, y₀, k, tableau; idxs=nothing, order=0) + # Determine the number of stages based on the tableau size + num_stages = size(tableau, 1) + num_coeffs = size(tableau, 2) + + # For each stage, evaluate the polynomial or its derivative + b = if order == 0 + # Use builtin evalpoly for polynomial evaluation: a₀ + a₁*Θ + a₂*Θ² + ... + [@evalpoly(Θ, tableau[i,:]...) for i in 1:num_stages] + else + # For derivative: d/dΘ [a₀ + a₁*Θ + a₂*Θ² + ...] = a₁ + 2*a₂*Θ + 3*a₃*Θ² + ... + [@evalpoly(Θ, [j * tableau[i, j+1] for j in 1:(num_coeffs-1)]...) for i in 1:num_stages] + end + + # Compute the interpolation sum + if isnothing(idxs) + # Full vector + interp_sum = sum(k[i] * b[i] for i in 1:num_stages) + if order == 0 + return y₀ + dt * interp_sum + else + return interp_sum + end + else + # Indexed + interp_sum = sum(k[i][idxs] * b[i] for i in 1:num_stages) + if order == 0 + return y₀[idxs] + dt * interp_sum + else + return interp_sum + end + end +end \ No newline at end of file diff --git a/lib/OrdinaryDiffEqExplicitRK/src/tsit5_matrix.jl b/lib/OrdinaryDiffEqExplicitRK/src/tsit5_matrix.jl new file mode 100644 index 0000000000..9fc00448f8 --- /dev/null +++ b/lib/OrdinaryDiffEqExplicitRK/src/tsit5_matrix.jl @@ -0,0 +1,293 @@ +# ============================================================================ +# Tsit5 Interpolation Coefficients in Matrix Form +# ============================================================================ + +""" + construct_tsit5_interp_matrix(T::Type = Float64) + +Constructs the interpolation coefficient matrix for Tsit5 method. +This converts the polynomial coefficients from the original Tsit5 implementation +into a matrix format for generic interpolation. + +The matrix B_interp has dimensions (7, 5) where: +- Row i contains coefficients for stage i's interpolation polynomial +- Column j contains coefficients for Θ^(j-1) term + +Each polynomial bᵢ(Θ) is defined as: + bᵢ(Θ) = bᵢ₀ + bᵢ₁*Θ + bᵢ₂*Θ² + bᵢ₃*Θ³ + bᵢ₄*Θ⁴ + +For Tsit5, the original formulation was: + b₁(Θ) = Θ * (r11 + r12*Θ + r13*Θ² + r14*Θ³) + = 0 + r11*Θ + r12*Θ² + r13*Θ³ + r14*Θ⁴ + + b₂(Θ) = Θ² * (r22 + r23*Θ + r24*Θ²) + = 0 + 0*Θ + r22*Θ² + r23*Θ³ + r24*Θ⁴ + + ... and so on for all 7 stages += +""" +function construct_tsit5_interp_matrix(T::Type = Float64) + # Original Tsit5 interpolation coefficients + # From OrdinaryDiffEqTsit5/src/tsit_tableaus.jl + + # Stage 1: b₁(Θ) = Θ * (r11 + r12*Θ + r13*Θ² + r14*Θ³) + r11 = convert(T, 1.0) + r12 = convert(T, -2.763706197274826) + r13 = convert(T, 2.9132554618219126) + r14 = convert(T, -1.0530884977290216) + + # Stage 2: b₂(Θ) = Θ² * (r22 + r23*Θ + r24*Θ²) + r22 = convert(T, 0.13169999999999998) + r23 = convert(T, -0.2234) + r24 = convert(T, 0.1017) + + # Stage 3: b₃(Θ) = Θ² * (r32 + r33*Θ + r34*Θ²) + r32 = convert(T, 3.9302962368947516) + r33 = convert(T, -5.941033872131505) + r34 = convert(T, 2.490627285651253) + + # Stage 4: b₄(Θ) = Θ² * (r42 + r43*Θ + r44*Θ²) + r42 = convert(T, -12.411077166933676) + r43 = convert(T, 30.33818863028232) + r44 = convert(T, -16.548102889244902) + + # Stage 5: b₅(Θ) = Θ² * (r52 + r53*Θ + r54*Θ²) + r52 = convert(T, 37.50931341651104) + r53 = convert(T, -88.1789048947664) + r54 = convert(T, 47.37952196281928) + + # Stage 6: b₆(Θ) = Θ² * (r62 + r63*Θ + r64*Θ²) + r62 = convert(T, -27.896526289197286) + r63 = convert(T, 65.09189467479366) + r64 = convert(T, -34.87065786149661) + + # Stage 7: b₇(Θ) = Θ² * (r72 + r73*Θ + r74*Θ²) + r72 = convert(T, 1.5) + r73 = convert(T, -4.0) + r74 = convert(T, 2.5) + + # Construct the interpolation matrix + # B_interp[i, j] = coefficient of Θ^(j-1) in bᵢ(Θ) + B_interp = zeros(T, 7, 5) + + # Stage 1: bᵢ(Θ) = 0 + r11*Θ + r12*Θ² + r13*Θ³ + r14*Θ⁴ + B_interp[1, :] = [0, r11, r12, r13, r14] + + # Stages 2-7: bᵢ(Θ) = 0 + 0*Θ + ri2*Θ² + ri3*Θ³ + ri4*Θ⁴ + B_interp[2, :] = [0, 0, r22, r23, r24] + B_interp[3, :] = [0, 0, r32, r33, r34] + B_interp[4, :] = [0, 0, r42, r43, r44] + B_interp[5, :] = [0, 0, r52, r53, r54] + B_interp[6, :] = [0, 0, r62, r63, r64] + B_interp[7, :] = [0, 0, r72, r73, r74] + + return B_interp +end + +""" + construct_tsit5_interp_matrix_highprecision(T::Type) + +High-precision version for BigFloat and other arbitrary-precision types. +We have not tested this +""" +function construct_tsit5_interp_matrix_highprecision(T::Type) + # Stage 1 + r11 = convert(T, big"0.999999999999999974283372471559910888475488471328") + r12 = convert(T, big"-2.763706197274825911336735930481400260916070804192") + r13 = convert(T, big"2.91325546182191274375068099306808") + r14 = convert(T, -1.0530884977290216) + + # Stage 2 + r22 = convert(T, big"0.13169999999999999727") + r23 = convert(T, big"-0.22339999999999999818") + r24 = convert(T, 0.1017) + + # Stage 3 + r32 = convert(T, big"3.93029623689475152850687446709813398") + r33 = convert(T, big"-5.94103387213150473470249202589458001") + r34 = convert(T, big"2.490627285651252793") + + # Stage 4 + r42 = convert(T, big"-12.411077166933676983734381540685453484102414134010752") + r43 = convert(T, big"30.3381886302823215981729903691836576") + r44 = convert(T, big"-16.54810288924490272") + + # Stage 5 + r52 = convert(T, big"37.50931341651103919496903965334519631242339792120440212") + r53 = convert(T, big"-88.1789048947664011014276693541209817") + r54 = convert(T, big"47.37952196281928122") + + # Stage 6 + r62 = convert(T, big"-27.896526289197287805948263144598643896") + r63 = convert(T, big"65.09189467479367152629021928716553658") + r64 = convert(T, big"-34.87065786149660974") + + # Stage 7 + r72 = convert(T, 1.5) + r73 = convert(T, -4.0) + r74 = convert(T, 2.5) + + # Construct matrix + B_interp = zeros(T, 7, 5) + B_interp[1, :] = [0, r11, r12, r13, r14] + B_interp[2, :] = [0, 0, r22, r23, r24] + B_interp[3, :] = [0, 0, r32, r33, r34] + B_interp[4, :] = [0, 0, r42, r43, r44] + B_interp[5, :] = [0, 0, r52, r53, r54] + B_interp[6, :] = [0, 0, r62, r63, r64] + B_interp[7, :] = [0, 0, r72, r73, r74] + + return B_interp +end + +""" + construct_tsit5_interp_matrix_auto(T::Type) + +Automatically selects appropriate precision based on type. +""" +function construct_tsit5_interp_matrix_auto(T::Type) + if T <: Union{Float32, Float64} + return construct_tsit5_interp_matrix(T) + else + return construct_tsit5_interp_matrix_highprecision(T) + end +end + +# Convert Tsit5 tableau to ExplicitRK format + +""" + constructTsit5ExplicitRK(T::Type = Float64) + +Constructs the Tsitouras 5/4 method in ExplicitRK tableau format. +This allows using Tsit5 with the generic ExplicitRK solver. + +Tsit5 is a 7-stage, 5th-order method with 4th-order embedded error estimate. +""" +function constructTsit5ExplicitRK(T::Type = Float64) + # Build the A matrix (Butcher tableau coefficients) + # 7 stages, lower triangular (explicit method) + A=[0 0 0 0 0 0 0 + 14//87 0 0 0 0 0 0 + -1//117 50//149 0 0 0 0 0 + 310//107 -407//64 301//69 0 0 0 0 + 474//89 -2479//211 817//109 -5//54 0 0 0 + 381//65 -491//38 563//69 -19//265 -3//106 0 0 + 8//83 1//100 107//223 131//95 -329//100 179//77 0] + # A = Float8.(A) + + # Time nodes (c vector) + c = [0; 161//1000; 327//1000; 9//10; + big".9800255409045096857298102862870245954942137979563024768854764293221195950761080302604"; + 1; 1] + + + # Solution weights (b vector) - 5th order + α = [ + big".9468075576583945807478876255758922856117527357724631226139574065785592789071067303271e-1", + big".9183565540343253096776363936645313759813746240984095238905939532922955247253608687270e-2", + big".4877705284247615707855642599631228241516691959761363774365216240304071651579571959813", + big"1.234297566930478985655109673884237654035539930748192848315425833500484878378061439761", + big"-2.707712349983525454881109975059321670689605166938197378763992255714444407154902012702", + big"1.866628418170587035753719399566211498666255505244122593996591602841258328965767580089", + 1//66 # = 0.015151515151515152 + ] + # Error estimate weights (b̂ vector) - 4th order + # Note: In Tsit5, btilde = b - b̂, so b̂ = b - btilde + btilde = [ + big"-1.780011052225771443378550607539534775944678804333659557637450799792588061629796e-03", + big"-8.164344596567469032236360633546862401862537590159047610940604670770447527463931e-04", + big"7.880878010261996010314727672526304238628733777103128603258129604952959142646516e-03", + big"-1.44711007173262907537165147972635116720922712343167677619514233896760819649515e-01", + big"5.823571654525552250199376106520421794260781239567387797673045438803694038950012e-01", + big"-4.580821059291869466616365188325542974428047279788398179474684434732070620889539e-01", + 1//66 + ] + + # Calculate b̂ = b - btilde for the embedded 4th-order method + αEEst = α .- btilde + + # Convert to requested type + A = map(T, A) + α = map(T, α) + αEEst = map(T, αEEst) + c = map(T, c) + + return DiffEqBase.ExplicitRKTableau(A, c, α, 5, + αEEst = αEEst, + adaptiveorder = 4, + fsal = true, + stability_size = 2.9) # Approximate stability region size +end + +""" + constructTsit5ExplicitRKSimple(T::Type = Float64) + +Simplified version using rational and decimal approximations. +Faster to construct but slightly less accurate than the full precision version. +""" +function constructTsit5ExplicitRKSimple(T::Type = Float64) + #Tested a few more variants and leaving them commented out here for future reference + # Build the A matrix with simpler rationals/decimals + # A = [0 0 0 0 0 0 0 + # 0.161 0 0 0 0 0 0 + # -0.00848 0.3355 0 0 0 0 0 + # 2.8972 -6.3594 4.3623 0 0 0 0 + # 5.3259 -11.7489 7.4955 -0.0925 0 0 0 + # 5.8615 -12.9210 8.1594 -0.0716 -0.0283 0 0 + # 0.09646 0.01 0.4799 1.3790 -3.2901 2.3247 0] +# A = [0 0 0 0 0 0 0 +# 161//1000 0 0 0 0 0 0 +# -8480655492356989//1000000000000000000 335480655492357//1000000000000000 0 0 0 0 0 +# 2897153057105493//1000000000000000 -6359448489975075//1000000000000000 4362295432869582//1000000000000000 0 0 0 0 +# 5325864828439257//1000000000000000 -11748883564062828//10000000000000000 7495539342889836//1000000000000000 -92495066361755//1000000000000000 0 0 0 +# 5861455442946420//1000000000000000 -12920969317847109//1000000000000000 8159367898576159//1000000000000000 -71584973281401//1000000000000000 -28269050394068//1000000000000000 0 0 +# 96460766818065//1000000000000000 1//100 479889650414500//1000000000000000 1379008574103742//1000000000000000 -3290069515436081//1000000000000000 2324710524099774//1000000000000000 0] + +# A = Float64.(A) +A=[0 0 0 0 0 0 0 + 14//87 0 0 0 0 0 0 + -1//117 50//149 0 0 0 0 0 + 310//107 -407//64 301//69 0 0 0 0 + 474//89 -2479//211 817//109 -5//54 0 0 0 + 381//65 -491//38 563//69 -19//265 -3//106 0 0 + 8//83 1//100 107//223 131//95 -329//100 179//77 0] +# A=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 +# 0.161 0.0 0.0 0.0 0.0 0.0 0.0 +# -0.008484 0.3354 0.0 0.0 0.0 0.0 0.0 +# 2.896 -6.36 4.363 0.0 0.0 0.0 0.0 +# 5.324 -1.175 7.496 -0.09247 0.0 0.0 0.0 +# 5.863 -12.92 8.16 -0.0716 -0.02827 0.0 0.0 +# 0.09644 0.01 0.48 1.379 -3.291 2.324 0.0] + + # Time nodes + c = [0, 0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0] + + + # Solution weights (5th order) + α = [0.09468075576583945, 0.009183565540343254, 0.4877705284247616, + 1.234297566930479, -2.7077123499835256, 1.866628418170587, + 0.015151515151515152] + + # Error estimate - computed from btilde + btilde = [-0.00178001105222577714, -0.0008164344596567469, 0.007880878010261995, + -0.1447110071732629, 0.5823571654525552, -0.45808210592918697, + 0.015151515151515152] + + αEEst = α .- btilde + + # Convert to requested type + A = map(T, A) + α = map(T, α) + αEEst = map(T, αEEst) + c = map(T, c) + + return DiffEqBase.ExplicitRKTableau(A, c, α, 5, + αEEst = αEEst, + adaptiveorder = 4, + fsal = true, + stability_size = 2.9) +end + +# Example usage: +# tableau = constructTsit5ExplicitRK() +# solve(prob, ExplicitRK(tableau = tableau)) \ No newline at end of file diff --git a/lib/OrdinaryDiffEqExplicitRK/test/benchmark_tests.jl b/lib/OrdinaryDiffEqExplicitRK/test/benchmark_tests.jl new file mode 100644 index 0000000000..159688a5f1 --- /dev/null +++ b/lib/OrdinaryDiffEqExplicitRK/test/benchmark_tests.jl @@ -0,0 +1,356 @@ +# Large System Benchmark: ExplicitRK vs Tsit5 +# Comparing performance on systems with 10,000+ ODEs + +using OrdinaryDiffEq +using BenchmarkTools +using Printf +using OrdinaryDiffEqExplicitRK +using DiffEqBase +using LinearAlgebra +using Plots + + +include("../src/tsit5_matrix.jl") + +# Import tableau construction functions +using OrdinaryDiffEqExplicitRK: constructDormandPrince +using OrdinaryDiffEqExplicitRK: constructTsit5ExplicitRKSimple + +using OrdinaryDiffEqExplicitRK: constructTsit5ExplicitRK +# ============================================================================ +# Problem 1: Lorenz-96 System (Atmospheric Model) +# ============================================================================ +""" + lorenz96!(du, u, p, t) + +The Lorenz-96 model - a chaotic dynamical system used in atmospheric science. +Scalable to any dimension N. + +du[i]/dt = (u[i+1] - u[i-2]) * u[i-1] - u[i] + F + +where F is a forcing parameter (typically F = 8) +""" +function lorenz96!(du, u, p, t) + F = p + N = length(u) + + # Periodic boundary conditions + @inbounds for i in 1:N + i_m2 = mod1(i - 2, N) + i_m1 = mod1(i - 1, N) + i_p1 = mod1(i + 1, N) + + du[i] = (u[i_p1] - u[i_m2]) * u[i_m1] - u[i] + F + end + return nothing +end + +function create_lorenz96_problem(N=10000; F=8.0, tspan=(0.0, 10.0)) + # Initial conditions: small random perturbations around F + u0 = F .+ 0.01 .* randn(N) + return ODEProblem(lorenz96!, u0, tspan, F) +end + +# ============================================================================ +# Problem 2: 1D Reaction-Diffusion System (PDE Discretization) +# ============================================================================ +""" + reaction_diffusion!(du, u, p, t) + +1D reaction-diffusion PDE discretized with finite differences: +∂u/∂t = D * ∂²u/∂x² + R(u) + +where R(u) = α*u*(1-u) is a reaction term (Fisher-KPP equation) +""" +function reaction_diffusion!(du, u, p, t) + D, α, dx = p + N = length(u) + inv_dx2 = 1.0 / (dx * dx) + + # Neumann boundary conditions (zero flux) + @inbounds du[1] = D * (u[2] - 2*u[1] + u[1]) * inv_dx2 + α * u[1] * (1 - u[1]) + + @inbounds for i in 2:N-1 + du[i] = D * (u[i+1] - 2*u[i] + u[i-1]) * inv_dx2 + α * u[i] * (1 - u[i]) + end + + @inbounds du[N] = D * (u[N] - 2*u[N] + u[N-1]) * inv_dx2 + α * u[N] * (1 - u[N]) + + return nothing +end + +function create_reaction_diffusion_problem(N=10000; D=0.01, α=1.0, tspan=(0.0, 5.0)) + # Spatial discretization + L = 100.0 # Domain length + dx = L / (N - 1) + + # Initial condition: step function + u0 = zeros(N) + u0[N÷4:3*N÷4] .= 1.0 + + return ODEProblem(reaction_diffusion!, u0, tspan, (D, α, dx)) +end + +# ============================================================================ +# Problem 3: Coupled Oscillators (Network Dynamics) +# ============================================================================ +""" + coupled_oscillators!(du, u, p, t) + +Network of coupled harmonic oscillators: +d²x[i]/dt² = -ω²*x[i] + K*Σ(x[j] - x[i]) for connected j + +Converted to first-order system: [x₁, v₁, x₂, v₂, ..., xₙ, vₙ] +""" +function coupled_oscillators!(du, u, p, t) + ω, K = p + N = length(u) ÷ 2 + + @inbounds for i in 1:N + x_idx = 2*i - 1 + v_idx = 2*i + + # Position derivative = velocity + du[x_idx] = u[v_idx] + + # Velocity derivative = acceleration + # Couple to neighbors (ring topology) + i_prev = mod1(i - 1, N) + i_next = mod1(i + 1, N) + + x_prev_idx = 2*i_prev - 1 + x_next_idx = 2*i_next - 1 + + coupling = K * ((u[x_prev_idx] - u[x_idx]) + (u[x_next_idx] - u[x_idx])) + du[v_idx] = -ω*ω*u[x_idx] + coupling + end + + return nothing +end + +function create_coupled_oscillators_problem(N=5000; ω=1.0, K=0.1, tspan=(0.0, 50.0)) + # N oscillators → 2N equations (position + velocity) + u0 = zeros(2*N) + # Random initial positions and velocities + u0[1:2:end] = randn(N) # positions + u0[2:2:end] = randn(N) # velocities + + return ODEProblem(coupled_oscillators!, u0, tspan, (ω, K)) +end + +# ============================================================================ +# Problem 4: Lotka Volterra System (prey-predator Model) +# ============================================================================ +""" + lotka_volterra_blockdiag!(du, u, p, t) + +Block-diagonal system of independent Lotka-Volterra pairs. +Each pair: du[2i-1] = α*u[2i-1] - γ*u[2i-1]*u[2i] + du[2i] = -β*u[2i] + δ*u[2i-1]*u[2i] +""" +function lotka_volterra_blockdiag!(du, u, p, t) + α, β, γ, δ, npairs = p + @inbounds for i in 1:npairs + xidx = 2i - 1 + yidx = 2i + x, y = u[xidx], u[yidx] + du[xidx] = α * x - γ * x * y + du[yidx] = -β * y + δ * x * y + end + return nothing +end + +function create_lotka_volterra_problem(N=10000; α=1.5, β=3.0, γ=1.0, δ=1.0, tspan=(0.0, 10.0)) + @assert N % 2 == 0 "N must be even (pairs of equations)" + npairs = N ÷ 2 + u0 = repeat([1.0, 1.0], npairs) + p = (α, β, γ, δ, npairs) + return ODEProblem(lotka_volterra_blockdiag!, u0, tspan, p) +end + +# ============================================================================ +# Benchmark Utilities +# ============================================================================ + +function benchmark_solver(prob, alg; name="Solver") + println("\n" * "="^70) + println("Benchmarking: $name") + println("="^70) + + # Warmup + sol = solve(prob, alg, saveat=1.0, abstol=1e-6, reltol=1e-3) + + # Benchmark + println("\nRunning benchmark (5 samples, 1 evaluation each)...") + b = @benchmark solve($prob, $alg, saveat=1.0, abstol=1e-6, reltol=1e-3) samples=5 evals=1 + + println("\nResults:") + println(" Minimum time: $(BenchmarkTools.prettytime(minimum(b.times)))") + println(" Median time: $(BenchmarkTools.prettytime(median(b.times)))") + println(" Mean time: $(BenchmarkTools.prettytime(mean(b.times)))") + println(" Memory: $(BenchmarkTools.prettymemory(b.memory))") + println(" Allocations: $(b.allocs)") + println("\nSolution stats:") + println(" Steps: $(sol.stats.naccept)") + println(" Function evals: $(sol.stats.nf)") + println(" Final time: $(sol.t[end])") + + return (benchmark=b, solution=sol) +end + +function compare_solvers(prob_name, prob; use_tsit5_tableau=false) + println("\n" * "#"^70) + println("# Problem: $prob_name") + println("# System size: $(length(prob.u0)) equations") + println("#"^70) + + # Benchmark ExplicitRK with chosen tableau + if use_tsit5_tableau + tableau = constructTsit5ExplicitRKSimple() + tableau_name = "Tsit5 Tableau" + else + tableau = constructDormandPrince() + tableau_name = "Dormand-Prince" + end + + result_explicit = benchmark_solver(prob, ExplicitRK(tableau=tableau), name="ExplicitRK ($tableau_name)") + result_tsit5 = benchmark_solver(prob, Tsit5(), name="Tsit5 (Specialized)") + b_explicit, sol_explicit = result_explicit.benchmark, result_explicit.solution + b_tsit5, sol_tsit5 = result_tsit5.benchmark, result_tsit5.solution + + # Compare + println("\n" * "="^70) + println("COMPARISON") + println("="^70) + + time_ratio = median(b_explicit.times) / median(b_tsit5.times) + mem_ratio = b_explicit.memory / max(b_tsit5.memory, 1) # Avoid div by 0 + alloc_ratio = b_explicit.allocs / max(b_tsit5.allocs, 1) + + @printf(" Speedup (Tsit5 vs ExplicitRK): %.2fx\n", time_ratio) + @printf(" Memory reduction: %.2fx\n", mem_ratio) + @printf(" Allocation reduction: %.2fx\n", alloc_ratio) + + return (explicit=(benchmark=b_explicit, solution=sol_explicit), tsit5=(benchmark=b_tsit5, solution=sol_tsit5)) +end + +# ============================================================================ +# Main Benchmark Suite +# ============================================================================ + +function run_large_system_benchmarks(; use_tsit5_tableau=false) + println("="^70) + println("LARGE ODE SYSTEM BENCHMARKS") + println("Comparing ExplicitRK vs Tsit5 on 10,000+ equation systems") + if use_tsit5_tableau + println("Using Tsit5 tableau for ExplicitRK") + else + println("Using Dormand-Prince tableau for ExplicitRK") + end + println("="^70) + + + # Test 1: Lorenz-96 (10,000 equations) + println("\n📊 Test 1: Lorenz-96 Atmospheric Model") + prob1 = create_lorenz96_problem(10000, tspan=(0.0, 5.0)) + results1 = compare_solvers("Lorenz-96 (N=10,000)", prob1, use_tsit5_tableau=use_tsit5_tableau) + + # Commented out these tests based on feedback + + # Test 2: Reaction-Diffusion (10,000 equations) + # println("\n📊 Test 2: Reaction-Diffusion PDE") + # prob2 = create_reaction_diffusion_problem(10000, tspan=(0.0, 2.0)) + # results2 = compare_solvers("Reaction-Diffusion 1D (N=10,000)", prob2, use_tsit5_tableau=use_tsit5_tableau) + + # # Test 3: Coupled Oscillators (10,000 equations = 5,000 oscillators) + # println("\n📊 Test 3: Coupled Oscillators Network") + # prob3 = create_coupled_oscillators_problem(5000, tspan=(0.0, 10.0)) + # results3 = compare_solvers("Coupled Oscillators (5,000 × 2 = 10,000 eqs)", prob3, use_tsit5_tableau=use_tsit5_tableau) + + # Test 4: Lotka-Volterra (10,000 equations = 5,000 pairs) + println("\n📊 Test 4: Lotka-Volterra Block-Diagonal System") + prob4 = create_lotka_volterra_problem(10000, tspan=(0.0, 10.0)) + results4 = compare_solvers("Lotka-Volterra (5,000 × 2 = 10,000 eqs)", prob4, use_tsit5_tableau=use_tsit5_tableau) + + println("\n" * "="^70) + println("BENCHMARK SUITE COMPLETE") + println("="^70) + + # return (lorenz96=results1, reaction_diffusion=results2, oscillators=results3, lotka_volterra=results4) + return (lorenz96=results1, lotka_volterra=results4) +end + +# ============================================================================ +# Quick Test (Smaller Systems for Development) +# ============================================================================ + +function quick_test() + println("\n🔬 Quick Test (smaller systems for development)") + println("="^70) + + # Small Lorenz-96 + prob = create_lorenz96_problem(100, tspan=(0.0, 1.0)) + + println("\nTesting ExplicitRKTsit5Tableau...") + # dp_tableau = constructDormandPrince() + dp_tableau = constructTsit5ExplicitRK() + @time sol1 = solve(prob, ExplicitRK(tableau=dp_tableau), abstol=1e-6, reltol=1e-5) + println(" Steps: $(sol1.stats.naccept), Function evals: $(sol1.stats.nf)") + + println("\nTesting Tsit5...") + @time sol2 = solve(prob, Tsit5(), abstol=1e-6, reltol=1e-5) + println(" Steps: $(sol2.stats.naccept), Function evals: $(sol2.stats.nf)") + + println("\n✅ Quick test passed!") +end + +function run_small_system_benchmarks(; use_tsit5_tableau=true) + println("\n" * "="^70) + println("SMALL ODE SYSTEM BENCHMARKS") + println("Comparing ExplicitRK vs Tsit5 on small systems") + println("="^70) + + # Test 1: Small Lorenz-96 (100 equations) + println("\n📊 Test 1: Lorenz-96 Atmospheric Model (N=100)") + prob1 = create_lorenz96_problem(100, tspan=(0.0, 5.0)) + results1 = compare_solvers("Lorenz-96 (N=100)", prob1, use_tsit5_tableau=use_tsit5_tableau) + + # Test 2: Small Reaction-Diffusion (100 equations) + println("\n📊 Test 2: Reaction-Diffusion PDE (N=100)") + prob2 = create_reaction_diffusion_problem(100, tspan=(0.0, 2.0)) + results2 = compare_solvers("Reaction-Diffusion 1D (N=100)", prob2, use_tsit5_tableau=use_tsit5_tableau) + + # Test 3: Small Coupled Oscillators (100 oscillators × 2 = 200 equations) + println("\n📊 Test 3: Coupled Oscillators Network (N=100 oscillators)") + prob3 = create_coupled_oscillators_problem(100, tspan=(0.0, 10.0)) + results3 = compare_solvers("Coupled Oscillators (100 × 2 = 200 eqs)", prob3, use_tsit5_tableau=use_tsit5_tableau) + + # Test 4: Small Lotka-Volterra (100 pairs × 2 = 200 equations) + println("\n📊 Test 4: Lotka-Volterra Block-Diagonal System (N=200)") + prob4 = create_lotka_volterra_problem(200, tspan=(0.0, 10.0)) + results4 = compare_solvers("Lotka-Volterra (100 × 2 = 200 eqs)", prob4, use_tsit5_tableau=use_tsit5_tableau) + + println("\n" * "="^70) + println("BENCHMARK SUITE COMPLETE") + println("="^70) + + return (lorenz96=results1, reaction_diffusion=results2, oscillators=results3, lotka_volterra=results4) +end + +# After running benchmarks +results = run_large_system_benchmarks(use_tsit5_tableau=true) +# results = run_small_system_benchmarks() +sol_explicit = results.lotka_volterra.explicit.solution +sol_tsit5 = results.lotka_volterra.tsit5.solution +# Compare final state +final_explicit = sol_explicit.u[end] +final_tsit5 = sol_tsit5.u[end] +diff_norm = norm(final_explicit .- final_tsit5) +println("Norm of difference between ExplicitRK and Tsit5 final values: $diff_norm") +# Compare all time points +all_diffs = [norm(sol_explicit.u[i] .- sol_tsit5.u[i]) for i in 1:length(sol_explicit.u)] +println("Max difference across all time points: $(maximum(all_diffs))") + +# Quick verification +plot(sol_explicit.t, sol_explicit.u, label="ExplicitRK") +plot!(sol_tsit5.t, sol_tsit5.u, label="Tsit5")