diff --git a/Project.toml b/Project.toml index 8c2a99f7..352e6b41 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,9 @@ version = "0.4.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AutoPreallocation = "e7028de2-df94-4053-9fdc-99272086b8d4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Cassette = "7057c7e9-c182-5462-911a-8362d720325c" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" diff --git a/src/CMBLensing.jl b/src/CMBLensing.jl index 392ab8aa..1aac47de 100644 --- a/src/CMBLensing.jl +++ b/src/CMBLensing.jl @@ -1,12 +1,14 @@ module CMBLensing using Adapt +using AutoPreallocation using Base.Broadcast: AbstractArrayStyle, ArrayStyle, Broadcasted, broadcasted, DefaultArrayStyle, preprocess_args, Style, result_style using Base.Iterators: flatten, product, repeated, cycle, countfrom using Base.Threads using Base: @kwdef, @propagate_inbounds, Bottom, OneTo, showarg, show_datatype, show_default, show_vector, typed_vcat +using Cassette using Combinatorics using DataStructures using DelimitedFiles diff --git a/src/flat_generic.jl b/src/flat_generic.jl index e94f8a28..0544d3ab 100644 --- a/src/flat_generic.jl +++ b/src/flat_generic.jl @@ -134,3 +134,28 @@ are not statistically the same. """ fixed_white_noise(rng, F::Type{<:FlatFieldFourier}) = exp.(im .* angle.(basis(F)(white_noise(rng,F)))) .* fieldinfo(F).Nside + + + +# optimization needed for AutoPreallocation, which otherwise really +# barfs trying to go through these `similar` calls down to the +# underlying `Array` or `CuArray` call +@inline function Cassette.overdub( + ctx :: AutoPreallocation.RecordingCtx, + :: typeof(similar), + bc :: Broadcasted{<:Union{FlatS0Style,FieldTupleStyle}}, + args... +) + ret = similar(bc, args...) + AutoPreallocation.record_alloc!(ctx, ret) + return ret +end +@inline function Cassette.overdub( + ctx :: AutoPreallocation.ReplayCtx, + :: typeof(similar), + bc :: Broadcasted{<:Union{FlatS0Style,FieldTupleStyle}}, + args... +) + scheduled = AutoPreallocation.next_scheduled_alloc!(ctx) + return scheduled +end diff --git a/src/numerical_algorithms.jl b/src/numerical_algorithms.jl index ef279b7c..57f868a1 100644 --- a/src/numerical_algorithms.jl +++ b/src/numerical_algorithms.jl @@ -19,7 +19,7 @@ function RK4Solver(F!::Function, y₀, t₀, t₁, nsteps) h, h½, h⅙ = (t₁-t₀)/nsteps ./ (1,2,6) y = copy(y₀) k₁, k₂, k₃, k₄, y′ = @repeated(similar(y₀),5) - for t in range(t₀,t₁,length=nsteps+1)[1:end-1] + @no_prealloc for t in range(t₀,t₁,length=nsteps+1)[1:end-1] @! k₁ = F(t, y) @! k₂ = F(t + h½, (@. y′ = y + h½*k₁)) @! k₃ = F(t + h½, (@. y′ = y + h½*k₂)) @@ -75,21 +75,31 @@ Info from the iterations of the solver can be returned if `hist` is specified. `histmod` can be used to include every N-th iteration only in `hist`. """ -function conjugate_gradient(M, A, b, x=0*b; nsteps=length(b), tol=sqrt(eps()), progress=false, callback=nothing, hist=nothing, histmod=1) +function conjugate_gradient( + M, A, b, x=zero(b); + nsteps = length(b), + tol = sqrt(eps(real(eltype(b)))), + progress = false, + callback = nothing, + hist = nothing, + histmod = 1, + prealloc = false +) + gethist() = hist == nothing ? nothing : NamedTuple{hist}(getindex.(Ref(@dict(i,x,p,r,res,t)),hist)) t₀ = time() i = 1 r = b - A*x z = M \ r p = z - bestres = res = res₀ = dot(r,z) + res = res₀ = dot(r,z) @assert !isnan(res) - bestx = x t = time() - t₀ _hist = [gethist()] prog = Progress(100, (progress!=false ? progress : Inf), "Conjugate Gradient: ") - for outer i = 2:nsteps + + function cg_iteration() Ap = A * p α = res / dot(p,Ap) x = x + α * p @@ -99,20 +109,16 @@ function conjugate_gradient(M, A, b, x=0*b; nsteps=length(b), tol=sqrt(eps()), p p = z + (res′ / res) * p res = res′ t = time() - t₀ - - if all(res