@@ -7,12 +7,12 @@ struct TakeLastProposal{I<:AbstractIntegrator} <: StaticTrajectory{I}
7
7
n_steps :: Int
8
8
end
9
9
10
- # Create a `TakeLastProposal` with a new `ϵ`
11
- function (tlp:: TakeLastProposal )(ϵ :: AbstractFloat )
12
- return TakeLastProposal (tlp . integrator (ϵ) , tlp. n_steps)
10
+ # Create a `TakeLastProposal` with a new integrator
11
+ function (tlp:: TakeLastProposal )(integrator :: AbstractIntegrator )
12
+ return TakeLastProposal (integrator, tlp. n_steps)
13
13
end
14
14
15
- function propose (prop:: TakeLastProposal , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} ) where {T<: Real }
15
+ function transition (prop:: TakeLastProposal , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} ) where {T<: Real }
16
16
θ, r, _ = steps (prop. integrator, h, θ, r, prop. n_steps)
17
17
return θ, - r
18
18
end
@@ -21,21 +21,22 @@ abstract type DynamicTrajectory{I<:AbstractIntegrator} <: AbstractHamiltonianTra
21
21
abstract type NoUTurnTrajectory{I<: AbstractIntegrator } <: DynamicTrajectory{I} end
22
22
struct NUTS{I<: AbstractIntegrator } <: NoUTurnTrajectory{I}
23
23
integrator :: I
24
+ max_depth :: Int
25
+ Δ_max :: AbstractFloat
24
26
end
25
27
26
- # Create a `NUTS` with a new `ϵ`
27
- function (snuts:: NUTS )(ϵ:: AbstractFloat )
28
- return NUTS (snuts. integrator (ϵ))
28
+ # Helper function to use default values
29
+ NUTS (integrator:: AbstractIntegrator ) = NUTS (integrator, 10 , 1000.0 )
30
+
31
+ # Create a `NUTS` with a new integrator
32
+ function (snuts:: NUTS )(integrator:: AbstractIntegrator )
33
+ return NUTS (integrator, snuts. max_depth, snuts. Δ_max)
29
34
end
30
35
31
36
struct MultinomialNUTS{I<: AbstractIntegrator } <: NoUTurnTrajectory{I}
32
37
integrator :: I
33
38
end
34
39
35
- function NUTS (h:: Hamiltonian , θ:: AbstractVector{T} ) where {T<: Real }
36
- return NUTS (Leapfrog (find_good_eps (h, θ)))
37
- end
38
-
39
40
function find_good_eps (rng:: AbstractRNG , h:: Hamiltonian , θ:: AbstractVector{T} ; max_n_iters:: Int = 100 ) where {T<: Real }
40
41
ϵ′ = ϵ = 0.1
41
42
a_min, a_cross, a_max = 0.25 , 0.5 , 0.75 # minimal, crossing, maximal accept ratio
93
94
find_good_eps (h:: Hamiltonian , θ:: AbstractVector{T} ; max_n_iters:: Int = 100 ) where {T<: Real } = find_good_eps (GLOBAL_RNG, h, θ; max_n_iters= max_n_iters)
94
95
95
96
# TODO : implement a more efficient way to build the balance tree
96
- function build_tree (rng:: AbstractRNG , nt:: NoUTurnTrajectory{I} , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} , logu :: AbstractFloat , v :: Int , j :: Int , H :: AbstractFloat ;
97
- Δ_max :: AbstractFloat = 1000.0 ) where {I<: AbstractIntegrator ,T<: Real }
97
+ function build_tree (rng:: AbstractRNG , nt:: NoUTurnTrajectory{I} , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} ,
98
+ logu :: AbstractFloat , v :: Int , j :: Int , H :: AbstractFloat ) where {I<: AbstractIntegrator ,T<: Real }
98
99
if j == 0
99
100
# Base case - take one leapfrog step in the direction v.
100
101
θ′, r′, _is_valid = step (nt. integrator, h, θ, r)
101
102
H′ = _is_valid ? hamiltonian_energy (h, θ′, r′) : Inf
102
103
n′ = (logu <= - H′) ? 1 : 0
103
- s′ = (logu < Δ_max + - H′) ? 1 : 0
104
+ s′ = (logu < nt . Δ_max + - H′) ? 1 : 0
104
105
α′ = exp (min (0 , H - H′))
105
106
106
107
return θ′, r′, θ′, r′, θ′, r′, n′, s′, α′, 1
@@ -128,18 +129,17 @@ function build_tree(rng::AbstractRNG, nt::NoUTurnTrajectory{I}, h::Hamiltonian,
128
129
end
129
130
end
130
131
131
- build_tree (nt:: NoUTurnTrajectory{I} , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} , logu :: AbstractFloat , v :: Int , j :: Int , H :: AbstractFloat ;
132
- Δ_max :: AbstractFloat = 1000.0 ) where {I<: AbstractIntegrator ,T<: Real } = build_tree (GLOBAL_RNG, nt, h, θ, r, logu, v, j, H; Δ_max = Δ_max )
132
+ build_tree (nt:: NoUTurnTrajectory{I} , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} ,
133
+ logu :: AbstractFloat , v :: Int , j :: Int , H :: AbstractFloat ) where {I<: AbstractIntegrator ,T<: Real } = build_tree (GLOBAL_RNG, nt, h, θ, r, logu, v, j, H)
133
134
134
- function propose (rng:: AbstractRNG , nt:: NoUTurnTrajectory{I} , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} ;
135
- j_max:: Int = 10 ) where {I<: AbstractIntegrator ,T<: Real }
135
+ function transition (rng:: AbstractRNG , nt:: NoUTurnTrajectory{I} , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} ) where {I<: AbstractIntegrator ,T<: Real }
136
136
H = hamiltonian_energy (h, θ, r)
137
137
logu = log (rand (rng)) - H
138
138
139
139
θm = θ; θp = θ; rm = r; rp = r; j = 0 ; θ_new = θ; r_new = r; n = 1 ; s = 1
140
140
141
141
local α, nα
142
- while s == 1 && j <= j_max
142
+ while s == 1 && j <= nt . max_depth
143
143
v = rand (rng, [- 1 , 1 ])
144
144
if v == - 1
145
145
θm, rm, _, _, θ′, r′,n′, s′, α, nα = build_tree (rng, nt, h, θm, rm, logu, v, j, H)
@@ -162,8 +162,7 @@ function propose(rng::AbstractRNG, nt::NoUTurnTrajectory{I}, h::Hamiltonian, θ:
162
162
return θ_new, r_new, α / nα
163
163
end
164
164
165
- propose (nt:: NoUTurnTrajectory{I} , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} ;
166
- j_max:: Int = 10 ) where {I<: AbstractIntegrator ,T<: Real } = propose (GLOBAL_RNG, nt, h, θ, r; j_max= j_max)
165
+ transition (nt:: NoUTurnTrajectory{I} , h:: Hamiltonian , θ:: AbstractVector{T} , r:: AbstractVector{T} ) where {I<: AbstractIntegrator ,T<: Real } = transition (GLOBAL_RNG, nt, h, θ, r)
167
166
168
167
function MultinomialNUTS (h:: Hamiltonian , θ:: AbstractVector{T} ) where {T<: Real }
169
168
return MultinomialNUTS (Leapfrog (find_good_eps (h, θ)))
0 commit comments