Skip to content

Commit 64a72da

Browse files
authored
Merge pull request #47 from TuringLang/kx/improve-interface
Improve NUTS interface
2 parents a9f0055 + 21389e5 commit 64a72da

File tree

6 files changed

+64
-39
lines changed

6 files changed

+64
-39
lines changed

src/adaptation.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ function update(h::Hamiltonian, prop::AbstractProposal, dpc::Adaptation.Abstract
1919
end
2020

2121
function update(h::Hamiltonian, prop::AbstractProposal, da::NesterovDualAveraging)
22-
return h, prop(getϵ(da))
22+
return h, prop(prop.integrator(getϵ(da)))
2323
end
2424

2525
function update(h::Hamiltonian, prop::AbstractProposal, ca::Adaptation.AbstractCompositeAdaptor)
26-
return h(getM⁻¹(ca.pc)), prop(getϵ(ca.ssa))
26+
return h(getM⁻¹(ca.pc)), prop(prop.integrator(getϵ(ca.ssa)))
2727
end

src/adaptation/stepsize.jl

+1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ function adapt_stepsize!(da::NesterovDualAveraging, α::AbstractFloat)
105105
ϵ = exp(x)
106106
DEBUG && @debug "Adapting step size..." "new ϵ = " "old ϵ = $(da.state.ϵ)"
107107

108+
# TODO: we might want to remove this when all other numerical issues are correctly handelled
108109
if isnan(ϵ) || isinf(ϵ)
109110
@warn "Incorrect ϵ = ; ϵ_previous = $(da.state.ϵ) is used instead."
110111
ϵ = da.state.ϵ

src/hamiltonian.jl

+11-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,17 @@ end
1616
∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric}, r::AbstractVector) = h.metric.M⁻¹ * r
1717

1818
function hamiltonian_energy(h::Hamiltonian, θ::AbstractVector, r::AbstractVector)
19-
return kinetic_energy(h, r, θ) + potential_energy(h, θ)
19+
K = kinetic_energy(h, r, θ)
20+
if isnan(K)
21+
K = Inf
22+
@warn "Kinetic energy is `NaN` and is set to `Inf`."
23+
end
24+
V = potential_energy(h, θ)
25+
if isnan(V)
26+
V = Inf
27+
@warn "Potential energy is `NaN` and is set to `Inf`."
28+
end
29+
return K + V
2030
end
2131

2232
potential_energy(h::Hamiltonian, θ::AbstractVector) = -h.logπ(θ)

src/proposal.jl

+20-21
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ struct TakeLastProposal{I<:AbstractIntegrator} <: StaticTrajectory{I}
77
n_steps :: Int
88
end
99

10-
# Create a `TakeLastProposal` with a new `ϵ`
11-
function (tlp::TakeLastProposal)(ϵ::AbstractFloat)
12-
return TakeLastProposal(tlp.integrator(ϵ), tlp.n_steps)
10+
# Create a `TakeLastProposal` with a new integrator
11+
function (tlp::TakeLastProposal)(integrator::AbstractIntegrator)
12+
return TakeLastProposal(integrator, tlp.n_steps)
1313
end
1414

15-
function propose(prop::TakeLastProposal, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T}) where {T<:Real}
15+
function transition(prop::TakeLastProposal, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T}) where {T<:Real}
1616
θ, r, _ = steps(prop.integrator, h, θ, r, prop.n_steps)
1717
return θ, -r
1818
end
@@ -21,21 +21,22 @@ abstract type DynamicTrajectory{I<:AbstractIntegrator} <: AbstractHamiltonianTra
2121
abstract type NoUTurnTrajectory{I<:AbstractIntegrator} <: DynamicTrajectory{I} end
2222
struct NUTS{I<:AbstractIntegrator} <: NoUTurnTrajectory{I}
2323
integrator :: I
24+
max_depth :: Int
25+
Δ_max :: AbstractFloat
2426
end
2527

26-
# Create a `NUTS` with a new `ϵ`
27-
function (snuts::NUTS)(ϵ::AbstractFloat)
28-
return NUTS(snuts.integrator(ϵ))
28+
# Helper function to use default values
29+
NUTS(integrator::AbstractIntegrator) = NUTS(integrator, 10, 1000.0)
30+
31+
# Create a `NUTS` with a new integrator
32+
function (snuts::NUTS)(integrator::AbstractIntegrator)
33+
return NUTS(integrator, snuts.max_depth, snuts.Δ_max)
2934
end
3035

3136
struct MultinomialNUTS{I<:AbstractIntegrator} <: NoUTurnTrajectory{I}
3237
integrator :: I
3338
end
3439

35-
function NUTS(h::Hamiltonian, θ::AbstractVector{T}) where {T<:Real}
36-
return NUTS(Leapfrog(find_good_eps(h, θ)))
37-
end
38-
3940
function find_good_eps(rng::AbstractRNG, h::Hamiltonian, θ::AbstractVector{T}; max_n_iters::Int=100) where {T<:Real}
4041
ϵ′ = ϵ = 0.1
4142
a_min, a_cross, a_max = 0.25, 0.5, 0.75 # minimal, crossing, maximal accept ratio
@@ -93,14 +94,14 @@ end
9394
find_good_eps(h::Hamiltonian, θ::AbstractVector{T}; max_n_iters::Int=100) where {T<:Real} = find_good_eps(GLOBAL_RNG, h, θ; max_n_iters=max_n_iters)
9495

9596
# TODO: implement a more efficient way to build the balance tree
96-
function build_tree(rng::AbstractRNG, nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T}, logu::AbstractFloat, v::Int, j::Int, H::AbstractFloat;
97-
Δ_max::AbstractFloat=1000.0) where {I<:AbstractIntegrator,T<:Real}
97+
function build_tree(rng::AbstractRNG, nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T},
98+
logu::AbstractFloat, v::Int, j::Int, H::AbstractFloat) where {I<:AbstractIntegrator,T<:Real}
9899
if j == 0
99100
# Base case - take one leapfrog step in the direction v.
100101
θ′, r′, _is_valid = step(nt.integrator, h, θ, r)
101102
H′ = _is_valid ? hamiltonian_energy(h, θ′, r′) : Inf
102103
n′ = (logu <= -H′) ? 1 : 0
103-
s′ = (logu < Δ_max + -H′) ? 1 : 0
104+
s′ = (logu < nt.Δ_max + -H′) ? 1 : 0
104105
α′ = exp(min(0, H - H′))
105106

106107
return θ′, r′, θ′, r′, θ′, r′, n′, s′, α′, 1
@@ -128,18 +129,17 @@ function build_tree(rng::AbstractRNG, nt::NoUTurnTrajectory{I}, h::Hamiltonian,
128129
end
129130
end
130131

131-
build_tree(nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T}, logu::AbstractFloat, v::Int, j::Int, H::AbstractFloat;
132-
Δ_max::AbstractFloat=1000.0) where {I<:AbstractIntegrator,T<:Real} = build_tree(GLOBAL_RNG, nt, h, θ, r, logu, v, j, H; Δ_max=Δ_max)
132+
build_tree(nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T},
133+
logu::AbstractFloat, v::Int, j::Int, H::AbstractFloat) where {I<:AbstractIntegrator,T<:Real} = build_tree(GLOBAL_RNG, nt, h, θ, r, logu, v, j, H)
133134

134-
function propose(rng::AbstractRNG, nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T};
135-
j_max::Int=10) where {I<:AbstractIntegrator,T<:Real}
135+
function transition(rng::AbstractRNG, nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T}) where {I<:AbstractIntegrator,T<:Real}
136136
H = hamiltonian_energy(h, θ, r)
137137
logu = log(rand(rng)) - H
138138

139139
θm = θ; θp = θ; rm = r; rp = r; j = 0; θ_new = θ; r_new = r; n = 1; s = 1
140140

141141
local α, nα
142-
while s == 1 && j <= j_max
142+
while s == 1 && j <= nt.max_depth
143143
v = rand(rng, [-1, 1])
144144
if v == -1
145145
θm, rm, _, _, θ′, r′,n′, s′, α, nα = build_tree(rng, nt, h, θm, rm, logu, v, j, H)
@@ -162,8 +162,7 @@ function propose(rng::AbstractRNG, nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ:
162162
return θ_new, r_new, α /
163163
end
164164

165-
propose(nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T};
166-
j_max::Int=10) where {I<:AbstractIntegrator,T<:Real} = propose(GLOBAL_RNG, nt, h, θ, r; j_max=j_max)
165+
transition(nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T}) where {I<:AbstractIntegrator,T<:Real} = transition(GLOBAL_RNG, nt, h, θ, r)
167166

168167
function MultinomialNUTS(h::Hamiltonian, θ::AbstractVector{T}) where {T<:Real}
169168
return MultinomialNUTS(Leapfrog(find_good_eps(h, θ)))

src/sampler.jl

+22-9
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,51 @@ function mh_accept(rng::AbstractRNG, H::AbstractFloat, H_new::AbstractFloat)
44
end
55
mh_accept(H::AbstractFloat, H_new::AbstractFloat) = mh_accept(GLOBAL_RNG, logα)
66

7-
function sample(h::Hamiltonian, prop::AbstractProposal, θ::AbstractVector{T}, n_samples::Int; verbose::Bool=true) where {T<:Real}
7+
sample(h::Hamiltonian, prop::AbstractProposal, θ::AbstractVector{T}, n_samples::Int; verbose::Bool=true) where {T<:Real} =
8+
sample(GLOBAL_RNG, h, prop, θ, n_samples; verbose=verbose)
9+
10+
function sample(rng::AbstractRNG, h::Hamiltonian, prop::AbstractProposal, θ::AbstractVector{T}, n_samples::Int; verbose::Bool=true) where {T<:Real}
811
θs = Vector{Vector{T}}(undef, n_samples)
912
Hs = Vector{T}(undef, n_samples)
1013
αs = Vector{T}(undef, n_samples)
1114
time = @elapsed for i = 1:n_samples
12-
θs[i], Hs[i], αs[i] = step(h, prop, i == 1 ? θ : θs[i-1])
15+
θs[i], Hs[i], αs[i] = step(rng, h, prop, i == 1 ? θ : θs[i-1])
1316
end
14-
verbose && @info "Finished sampling with $time (s)" typeof(h) typeof(prop) EBFMI(Hs) mean(αs)
17+
verbose && @info "Finished sampling with $time (s)" typeof(h.metric) typeof(prop) EBFMI(Hs) mean(αs)
1518
return θs
1619
end
1720

18-
function sample(h::Hamiltonian, prop::AbstractProposal, θ::AbstractVector{T}, n_samples::Int, adaptor::Adaptation.AbstractAdaptor,
21+
sample(h::Hamiltonian, prop::AbstractProposal, θ::AbstractVector{T}, n_samples::Int, adaptor::Adaptation.AbstractAdaptor,
22+
n_adapts::Int=min(div(n_samples, 10), 1_000); verbose::Bool=true) where {T<:Real} =
23+
sample(GLOBAL_RNG, h, prop, θ, n_samples, adaptor, n_adapts; verbose=verbose)
24+
25+
function sample(rng::AbstractRNG, h::Hamiltonian, prop::AbstractProposal, θ::AbstractVector{T}, n_samples::Int, adaptor::Adaptation.AbstractAdaptor,
1926
n_adapts::Int=min(div(n_samples, 10), 1_000); verbose::Bool=true) where {T<:Real}
2027
θs = Vector{Vector{T}}(undef, n_samples)
2128
Hs = Vector{T}(undef, n_samples)
2229
αs = Vector{T}(undef, n_samples)
2330
time = @elapsed for i = 1:n_samples
24-
θs[i], Hs[i], αs[i] = step(h, prop, i == 1 ? θ : θs[i-1])
31+
θs[i], Hs[i], αs[i] = step(rng, h, prop, i == 1 ? θ : θs[i-1])
2532
if i <= n_adapts
2633
adapt!(adaptor, θs[i], αs[i])
2734
h, prop = update(h, prop, adaptor)
28-
verbose && i == n_adapts && @info "Finished $n_adapts adapation steps" typeof(adaptor) prop.integrator.ϵ h.metric
35+
if verbose
36+
if i == n_adapts
37+
@info "Finished $n_adapts adapation steps" typeof(adaptor) prop.integrator.ϵ h.metric
38+
elseif i % Int(n_adapts / 10) == 0
39+
@info "Adapting $i of $n_adapts steps" typeof(adaptor) prop.integrator.ϵ h.metric
40+
end
41+
end
2942
end
3043
end
31-
verbose && @info "Finished $n_samples sampling steps in $time (s)" typeof(h) typeof(prop) EBFMI(Hs) mean(αs)
44+
verbose && @info "Finished $n_samples sampling steps in $time (s)" typeof(h.metric) typeof(prop) EBFMI(Hs) mean(αs)
3245
return θs
3346
end
3447

3548
function step(rng::AbstractRNG, h::Hamiltonian, prop::TakeLastProposal{I}, θ::AbstractVector{T}) where {T<:Real,I<:AbstractIntegrator}
3649
r = rand_momentum(rng, h)
3750
H = hamiltonian_energy(h, θ, r)
38-
θ_new, r_new = propose(prop, h, θ, r)
51+
θ_new, r_new = transition(prop, h, θ, r)
3952
H_new = hamiltonian_energy(h, θ_new, r_new)
4053
# Accept via MH criteria
4154
is_accept, α = mh_accept(rng, H, H_new)
@@ -47,7 +60,7 @@ end
4760

4861
function step(rng::AbstractRNG, h::Hamiltonian, prop::NUTS{I}, θ::AbstractVector{T}) where {T<:Real,I<:AbstractIntegrator}
4962
r = rand_momentum(rng, h)
50-
θ_new, r_new, α = propose(rng, prop, h, θ, r)
63+
θ_new, r_new, α = transition(rng, prop, h, θ, r)
5164
H_new = hamiltonian_energy(h, θ_new, r_new)
5265
# We always accept in NUTS
5366
return θ_new, H_new, α

test/proposal.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ prop = NUTS(Leapfrog(find_good_eps(h, θ_init)))
1111
r_init = AdvancedHMC.rand_momentum(h)
1212

1313
@testset "Passing random number generator" begin
14-
rng = MersenneTwister(1234)
15-
θ1, r1 = AdvancedHMC.propose(rng, prop, h, θ_init, r_init)
14+
for seed in [1234, 5678, 90]
15+
rng = MersenneTwister(seed)
16+
θ1, r1 = AdvancedHMC.transition(rng, prop, h, θ_init, r_init)
1617

17-
rng = MersenneTwister(1234)
18-
θ2, r2 = AdvancedHMC.propose(rng, prop, h, θ_init, r_init)
18+
rng = MersenneTwister(seed)
19+
θ2, r2 = AdvancedHMC.transition(rng, prop, h, θ_init, r_init)
1920

20-
@test θ1 == θ2
21-
@test r1 == r2
21+
@test θ1 == θ2
22+
@test r1 == r2
23+
end
2224
end

0 commit comments

Comments
 (0)