diff --git a/Project.toml b/Project.toml index bcb43388..bfc65dc8 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" +KernelSpectralDensities = "027d52a2-76e5-4228-9bfe-bc7e0f5a8348" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -24,6 +25,7 @@ Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25" FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1" IrrationalConstants = "0.1, 0.2" KernelFunctions = "0.9, 0.10" +KernelSpectralDensities = "0.2.0" LinearAlgebra = "1" PDMats = "0.11" Random = "1" diff --git a/src/AbstractGPs.jl b/src/AbstractGPs.jl index b0cd6885..c24626c9 100644 --- a/src/AbstractGPs.jl +++ b/src/AbstractGPs.jl @@ -13,6 +13,7 @@ using RecipesBase using IrrationalConstants: log2π using KernelFunctions: ColVecs, RowVecs +using KernelSpectralDensities using ChainRulesCore: ChainRulesCore @@ -33,6 +34,7 @@ export rand!, posterior, update_posterior export ColVecs, RowVecs +export GPSampler, CholeskySampling, Conditional, Independent, RFFSampling, PathwiseSampling # Various bits of utility functionality. include("util/common_covmat_ops.jl") @@ -56,6 +58,9 @@ include("sparse_approximations.jl") # LatentGP and LatentFiniteGP objects to accommodate GPs with non-Gaussian likelihoods. include("latent_gp.jl") +# Different sampling methods +include("sampling.jl") + # Plotting utilities. include("util/plotting.jl") diff --git a/src/sampling.jl b/src/sampling.jl new file mode 100644 index 00000000..09422b7f --- /dev/null +++ b/src/sampling.jl @@ -0,0 +1,264 @@ +abstract type AbstractGPSamplingMethod end + +SeedableRNG = Union{Xoshiro,MersenneTwister} + +_rand(rng, d) = Random.rand(rng, d) +function _rand(rng::AbstractRNG, ::Type{T}) where {T<:SeedableRNG} + return T(Random.rand(rng, 1:typemax(Int))) +end + +# ## Interface + +struct GPSample{F,S} + fun::F + sample::S +end + +(gs::GPSample)(x::AbstractArray) = eval_at(gs.fun, gs.sample, x) + +# This may become more challenging once we extend to multi-input GPS +(gs::GPSample)(x::Number) = only(gs([x])) + +""" + GPSampler(gp::AbstractGPs.AbstractGP, method::AbstractGPSamplingMethod) +Creates a sampler for the given `gp` using the specified `method`. + +```jldoctest +julia> f = GP(Matern32Kernel()); + +julia> gps = GPSampler(f, CholeskySampling()); + +julia> rand(gps); +``` +""" +struct GPSampler{F,S} <: Random.Sampler{GPSample} + fun::F + sampler::S + + # Specify input types here, since it is a "public" interface + function GPSampler(gp::AbstractGPs.AbstractGP, method::AbstractGPSamplingMethod) + fun, sampler = method(gp) + return new{typeof(fun),typeof(sampler)}(fun, sampler) + end +end + +# Don't love the deepcopy here +# issue is "pass by sharing" and the mutable struct in CholeskySampling +function Random.rand(rng::AbstractRNG, gs::GPSampler) + return GPSample(deepcopy(gs.fun), _rand(rng, gs.sampler)) +end + +# ## Utils + +_get_prior(gp::AbstractGPs.GP) = gp +_get_prior(pgp::AbstractGPs.PosteriorGP) = pgp.prior + +function get_obs_variance(pgp::AbstractGPs.PosteriorGP) + σk = pgp.prior.kernel(0, 0) + v = diag(pgp.data.C.L * pgp.data.C.U) .- σk + return v +end + +function get_target_prior(pgp::AbstractGPs.PosteriorGP) + m = pgp.data.δ + σ2 = get_obs_variance(pgp) + return MvNormal(m, sqrt.(σ2)) +end + +######################### +# Function Space/ Cholesky + +""" + CholeskySampling(s=Conditional, generator=Xoshiro) +Sampling by using the standard way, via Cholesky decomposition. +Arguments: +- `s`: Sampling type, either `Conditional` or `Independent`. Default is `Conditional`. +- `generator`: Random number generator to use in each sample. Default is `Xoshiro`. +""" +struct CholeskySampling{M,G} <: AbstractGPSamplingMethod + function CholeskySampling(s=Conditional, generator=Xoshiro) + return new{s,generator}() + end +end + +function (cs::CholeskySampling{M,G})(gp) where {M,G} + return M(gp), G +end + +""" + Conditional +Generates a GP sample that conditions function samples on all previous samples. +""" +mutable struct Conditional{GPT<:AbstractGPs.AbstractGP} + gp::GPT +end + +function Conditional(gp::AbstractGPs.GP) + data = ( + α=Vector{Float64}(undef, 0), + C=Cholesky(UpperTriangular(Matrix{Float64}(undef, 0, 0))), + x=Vector{Float64}(undef, 0), + δ=Vector{Float64}(undef, 0), + ) + pgp = AbstractGPs.PosteriorGP(gp, data) + return Conditional(pgp) +end + +function eval_at(s::Conditional, rng, x::AbstractArray) + if isassigned(s.gp.data.x, 1) + pgp = s.gp + else + pgp = s.gp.prior + end + fgp = pgp(x) + y = rand(rng, fgp) + s.gp = posterior(fgp, y) + return y +end + +""" + Independent +Generates a GP sample that samples function samples independent from previous samples. +""" +struct Independent{GPT<:AbstractGPs.AbstractGP} + gp::GPT + function Independent(gp) + return new{typeof(gp)}(gp) + end +end + +function eval_at(s::Independent, rng, x::AbstractArray) + gp = s.gp + fgp = gp(x) + y = rand(rng, fgp) + return y +end + +# ## WeightSpace + +# ### Utils + +get_weight_distribution(::AbstractGPs.GP, rff) = MvNormal(ones(rff.l)) + +function get_weight_distribution(pgp::AbstractGPs.PosteriorGP, rff) + d = get_target_prior(pgp) + + P = rff.(pgp.data.x) + Pt = reduce(hcat, P) + Cp = Symmetric(Pt * (d.Σ \ Pt') + I) + C = cholesky(Cp) + + μ = C \ (Pt * (d.Σ \ d.μ)) + Σ = C \ I + return MvNormal(μ, Symmetric(Σ)) +end + +# ### Main + +""" + RFFSampling(l::Int, rff_type::Type{<:KernelSpectralDensities.AbstractRFF}=DoubleRFF) +Sampling by using Random Fourier Features. +Arguments: +- `l`: Number of random Fourier features to use. +- `rff_type`: Type of random Fourier features to use. Default is `DoubleRFF`. +""" +struct RFFSampling{RFF,RNG} <: AbstractGPSamplingMethod + l::Int + rng::RNG + function RFFSampling( + rng, l; rff_type::Type{<:KernelSpectralDensities.AbstractRFF}=DoubleRFF + ) + return new{rff_type,typeof(rng)}(l, rng) + end +end + +function RFFSampling(l; rff_type::Type{<:KernelSpectralDensities.AbstractRFF}=DoubleRFF) + return RFFSampling(Random.default_rng(), l; rff_type) +end + +function (rffs::RFFSampling{RFF})(gp) where {RFF} + prior = _get_prior(gp) + # for now, hardcoding "1", later expand for multi-input + S = SpectralDensity(prior.kernel, 1) + # ToDo: add rng to RFF + rff = RFF(rffs.rng, S, rffs.l) + + ws = get_weight_distribution(gp, rff) + + return rff, ws +end + +function eval_at(rff::KernelSpectralDensities.AbstractRFF, w, x::AbstractArray) + return dot.(rff.(x), Ref(w)) +end + +# ## PathwiseSampler + +# ### utils +struct KernelBasis{K} + ker::K + x::AbstractArray +end + +(kb::KernelBasis)(x) = kb.ker.(Ref(x), kb.x) + +function update_basis(pgp, cs::CholeskySampling) + ker = pgp.prior.kernel + x = pgp.data.x + return KernelBasis(ker, x) +end + +function update_basis(pgp, rffs::RFFSampling) + rff, _ = rffs(pgp) + + return rff +end + +# ### Main + +""" + PathwiseSampling(l::Int) +Sampling by using pathwise sampling, which uses RFF sampling for the prior and an update rule +based on the kernel. Takes as an input the number of random Fourier features `l` to use. +""" +struct PathwiseSampling{P,U} <: AbstractGPSamplingMethod + prior::P + update::U +end + +function PathwiseSampling(l::Int) + return PathwiseSampling(RFFSampling(l), CholeskySampling()) +end + +struct PathwiseSampler{PS,TS,D} + prior_sampler::PS + target_sampler::TS + data::D +end + +function (ps::PathwiseSampling)(pgp::AbstractGPs.PosteriorGP) + upd_fun = update_basis(pgp, ps.update) + + prior = pgp.prior + prior_sampler = GPSampler(prior, ps.prior) + + target_dist = get_target_prior(pgp) + + data = (C=pgp.data.C, x=pgp.data.x) + return upd_fun, PathwiseSampler(prior_sampler, target_dist, data) +end + +function _rand(rng::AbstractRNG, ps::PathwiseSampler) + prior = rand(rng, ps.prior_sampler) + f = prior(ps.data.x) + + ts = rand(rng, ps.target_sampler) + + v = ps.data.C \ (ts - f) + + return (prior=prior, v=v) +end + +function eval_at(basis::KernelBasis, s, x::AbstractArray) + return s.prior(x) .+ dot.(basis.(x), Ref(s.v)) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d5edac8d..415399f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -77,6 +77,10 @@ include("test_util.jl") println(" ") @info "Ran latent_gp tests" + include("sampling.jl") + println(" ") + @info "Ran sampling tests" + include("deprecations.jl") println(" ") @info "Ran deprecation tests" diff --git a/test/sampling.jl b/test/sampling.jl new file mode 100644 index 00000000..066b4c96 --- /dev/null +++ b/test/sampling.jl @@ -0,0 +1,137 @@ +@testset "Sampling" begin + rng = Xoshiro(1234) + + nx = 8 + x1 = collect(range(0, 2; length=nx)) + y1 = rand(nx) + + k = GaussianKernel() + g1 = GP(k) + g1x1 = g1(x1, 0.1) + pg1 = posterior(g1x1, y1) + + # # FunctionSpace + @testset "Constructors" begin + @testset "Cholesky" begin + @test CholeskySampling() isa CholeskySampling{Conditional,Xoshiro} + @test CholeskySampling(Conditional) isa CholeskySampling{Conditional,Xoshiro} + @test CholeskySampling(Conditional, Xoshiro) isa + CholeskySampling{Conditional,Xoshiro} + @test CholeskySampling(Conditional, MersenneTwister) isa + CholeskySampling{Conditional,MersenneTwister} + @test CholeskySampling(Independent) isa CholeskySampling{Independent,Xoshiro} + end + @testset "RFF" begin + @test RFFSampling(10) isa RFFSampling + # Testing other RFFs? Needs re-export from KernelSpectralDensities + end + @testset "Pathwise" begin + @test PathwiseSampling(RFFSampling(10), CholeskySampling()) isa + PathwiseSampling{<:RFFSampling,<:CholeskySampling} + @test PathwiseSampling(10) isa PathwiseSampling + end + end + + @testset "Basic Functional" begin + function test_basic_fun(gp, method) + gps = GPSampler(gp, method) + gps1 = rand(rng, gps) + @test gps1(0.4) isa Float64 + @test gps1([0.6, 0.7]) isa Vector{Float64} + end + @testset "Cholesky" begin + test_basic_fun(g1, CholeskySampling()) + test_basic_fun(pg1, CholeskySampling()) + end + + @testset "RFF" begin + test_basic_fun(g1, RFFSampling(20)) + test_basic_fun(pg1, RFFSampling(20)) + end + + @testset "Pathwise" begin + method = PathwiseSampling(RFFSampling(20), CholeskySampling()) + test_basic_fun(pg1, method) + end + end + + # ## Accuracy test + + # compute error between empirical and analytical version + function eval_res(x, gp, resv) + empres = cov(resv) + trueres = cov(gp, x) + return norm(empres .- trueres) + end + + # Evaluate sample all at once + function oneshot_error(x, gp, gps, n) + resv = [rand(rng, gps)(x) for _ in 1:n] + return eval_res(x, gp, resv) + end + + # Evaluate samples one by one + function onebyone(gpsampler, x) + gps = rand(rng, gpsampler) + y = [gps(xi) for xi in x] + return y + end + function onebyone_error(x, gp, gps, n) + resv = [onebyone(gps, x) for _ in 1:n] + return eval_res(x, gp, resv) + end + + function grid_test(gp, x, nv, method, evalfun) + gps = GPSampler(gp, method) + + res = [evalfun(x, gp, gps, n) for n in nv] + return all(diff(res) .< 0) + end + + x = collect(range(0, 2; length=9)) + nv = [10, 100, 1000] + + @testset "Correctness" begin + @testset "FunctionSpace" begin + @testset "Prior, FullMemory" begin + @test grid_test(g1, x, nv, CholeskySampling(Conditional), oneshot_error) + @test grid_test(g1, x, nv, CholeskySampling(Conditional), onebyone_error) + end + + @testset "Posterior, FullMemory" begin + @test grid_test(pg1, x, nv, CholeskySampling(Conditional), oneshot_error) + @test grid_test(pg1, x, nv, CholeskySampling(Conditional), onebyone_error) + end + + @testset "Prior, NoMemory" begin + @test grid_test(g1, x, nv, CholeskySampling(Independent), oneshot_error) + # @test grid_test(g1, x, nv, FunctionSpace(NoMemory), onebyone_error) + end + + @testset "Posterior, NoMemory" begin + @test grid_test(pg1, x, nv, CholeskySampling(Independent), oneshot_error) + # @test grid_test(g1, x, nv, FunctionSpace(NoMemory), onebyone_error) + end + end + + @testset "WeightSpace" begin + l = 80 + wsp = RFFSampling(rng, l) + @testset "Prior, DoubleRFF" begin + @test grid_test(g1, x, nv, wsp, oneshot_error) + @test grid_test(g1, x, nv, wsp, onebyone_error) + end + @testset "Posterior, DoubleRFF" begin + @test grid_test(pg1, x, nv, wsp, oneshot_error) + @test grid_test(pg1, x, nv, wsp, onebyone_error) + end + end + + l = 80 + @testset "Posterior, DoubleRFF" begin + method = PathwiseSampling(RFFSampling(l), CholeskySampling()) + @test grid_test(pg1, x, nv, method, oneshot_error) + @test grid_test(pg1, x, nv, method, onebyone_error) + end + end +end \ No newline at end of file