Skip to content

Commit 3a4b384

Browse files
JaimeRZPgithub-actions[bot]yebai
authoredJul 28, 2023
NUTS kernel options (#342)
* pass options * bump * Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * tests for kernel hyperparameters * format * test * bring back all tests * bug * make_init_params bug * more tests+ init_params bug * more tests+ init_params bug * tests for bug * format * catch HMC case * Typofix. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <hg344@cam.ac.uk>
1 parent 762e55f commit 3a4b384

File tree

3 files changed

+63
-6
lines changed

3 files changed

+63
-6
lines changed
 

‎Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.5.1"
3+
version = "0.5.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

‎src/abstractmcmc.jl

+10-3
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function AbstractMCMC.step(
117117

118118
# Define integration algorithm
119119
# Find good eps if not provided one
120-
init_params = make_init_params(spl, logdensity, init_params)
120+
init_params = make_init_params(rng, spl, logdensity, init_params)
121121
ϵ = make_step_size(rng, spl, hamiltonian, init_params)
122122
integrator = make_integrator(spl, ϵ)
123123

@@ -251,7 +251,12 @@ end
251251
#############
252252
### Utils ###
253253
#############
254-
function make_init_params(spl::AbstractHMCSampler, logdensity, init_params)
254+
function make_init_params(
255+
rng::AbstractRNG,
256+
spl::AbstractHMCSampler,
257+
logdensity,
258+
init_params,
259+
)
255260
T = sampler_eltype(spl)
256261
if init_params == nothing
257262
d = LogDensityProblems.dimension(logdensity)
@@ -354,7 +359,9 @@ end
354359
#########
355360

356361
function make_kernel(spl::NUTS, integrator::AbstractIntegrator)
357-
return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
362+
return HMCKernel(
363+
Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn(spl.max_depth, spl.Δ_max)),
364+
)
358365
end
359366

360367
function make_kernel(spl::HMC, integrator::AbstractIntegrator)

‎test/constructors.jl

+52-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
using AdvancedHMC, AbstractMCMC, Random
22
include("common.jl")
33

4+
get_kernel_hyperparams(spl::HMC, state) = state.κ.τ.termination_criterion.L
5+
get_kernel_hyperparams(spl::HMCDA, state) = state.κ.τ.termination_criterion.λ
6+
get_kernel_hyperparams(spl::NUTS, state) =
7+
state.κ.τ.termination_criterion.max_depth, state.κ.τ.termination_criterion.Δ_max
8+
9+
get_kernel_hyperparamsT(spl::HMC, state) = typeof(state.κ.τ.termination_criterion.L)
10+
get_kernel_hyperparamsT(spl::HMCDA, state) = typeof(state.κ.τ.termination_criterion.λ)
11+
get_kernel_hyperparamsT(spl::NUTS, state) = typeof(state.κ.τ.termination_criterion.Δ_max)
12+
413
@testset "Constructors" begin
514
d = 2
615
θ_init = randn(d)
16+
rng = Random.default_rng()
717
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
818

919
@testset "$T" for T in [Float32, Float64]
@@ -14,6 +24,7 @@ include("common.jl")
1424
adaptor_type = NoAdaptation,
1525
metric_type = DiagEuclideanMetric{T},
1626
integrator_type = Leapfrog{T},
27+
kernel_hp = 25,
1728
),
1829
),
1930
(
@@ -22,6 +33,7 @@ include("common.jl")
2233
adaptor_type = NoAdaptation,
2334
metric_type = DiagEuclideanMetric{T},
2435
integrator_type = Leapfrog{T},
36+
kernel_hp = 25,
2537
),
2638
),
2739
(
@@ -30,6 +42,7 @@ include("common.jl")
3042
adaptor_type = NoAdaptation,
3143
metric_type = DiagEuclideanMetric{T},
3244
integrator_type = Leapfrog{T},
45+
kernel_hp = 25,
3346
),
3447
),
3548
(
@@ -38,6 +51,7 @@ include("common.jl")
3851
adaptor_type = NoAdaptation,
3952
metric_type = UnitEuclideanMetric{T},
4053
integrator_type = Leapfrog{T},
54+
kernel_hp = 25,
4155
),
4256
),
4357
(
@@ -46,6 +60,7 @@ include("common.jl")
4660
adaptor_type = NoAdaptation,
4761
metric_type = DenseEuclideanMetric{T},
4862
integrator_type = Leapfrog{T},
63+
kernel_hp = 25,
4964
),
5065
),
5166
(
@@ -54,6 +69,7 @@ include("common.jl")
5469
adaptor_type = NesterovDualAveraging,
5570
metric_type = DiagEuclideanMetric{T},
5671
integrator_type = Leapfrog{T},
72+
kernel_hp = one(T),
5773
),
5874
),
5975
# This should perform the correct promotion for the 2nd argument.
@@ -63,14 +79,16 @@ include("common.jl")
6379
adaptor_type = NesterovDualAveraging,
6480
metric_type = DiagEuclideanMetric{T},
6581
integrator_type = Leapfrog{T},
82+
kernel_hp = one(T),
6683
),
6784
),
6885
(
69-
NUTS(T(0.8)),
86+
NUTS(T(0.8); max_depth = 20, Δ_max = T(2000.0)),
7087
(
7188
adaptor_type = StanHMCAdaptor,
7289
metric_type = DiagEuclideanMetric{T},
7390
integrator_type = Leapfrog{T},
91+
kernel_hp = (20, T(2000.0)),
7492
),
7593
),
7694
(
@@ -79,6 +97,7 @@ include("common.jl")
7997
adaptor_type = StanHMCAdaptor,
8098
metric_type = UnitEuclideanMetric{T},
8199
integrator_type = Leapfrog{T},
100+
kernel_hp = (10, T(1000.0)),
82101
),
83102
),
84103
(
@@ -87,6 +106,7 @@ include("common.jl")
87106
adaptor_type = StanHMCAdaptor,
88107
metric_type = DenseEuclideanMetric{T},
89108
integrator_type = Leapfrog{T},
109+
kernel_hp = (10, T(1000.0)),
90110
),
91111
),
92112
(
@@ -95,6 +115,7 @@ include("common.jl")
95115
adaptor_type = StanHMCAdaptor,
96116
metric_type = DiagEuclideanMetric{T},
97117
integrator_type = JitteredLeapfrog{T,T},
118+
kernel_hp = (10, T(1000.0)),
98119
),
99120
),
100121
(
@@ -103,14 +124,14 @@ include("common.jl")
103124
adaptor_type = StanHMCAdaptor,
104125
metric_type = DiagEuclideanMetric{T},
105126
integrator_type = TemperedLeapfrog{T,T},
127+
kernel_hp = (10, T(1000.0)),
106128
),
107129
),
108130
]
109131
# Make sure the sampler element type is preserved.
110132
@test AdvancedHMC.sampler_eltype(sampler) == T
111133

112134
# Step.
113-
rng = Random.default_rng()
114135
transition, state =
115136
AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init)
116137

@@ -126,6 +147,35 @@ include("common.jl")
126147
@test AdvancedHMC.getmetric(state) isa expected.metric_type
127148
@test AdvancedHMC.getintegrator(state) isa expected.integrator_type
128149
@test AdvancedHMC.getadaptor(state) isa expected.adaptor_type
150+
151+
# Verify that the kernel is receiving the hyperparameters
152+
@test get_kernel_hyperparams(sampler, state) == expected.kernel_hp
153+
if typeof(sampler) <: HMC
154+
@test get_kernel_hyperparamsT(sampler, state) == Int64
155+
else
156+
@test get_kernel_hyperparamsT(sampler, state) == T
157+
end
129158
end
130159
end
131160
end
161+
162+
@testset "Utils" begin
163+
@testset "init_params" begin
164+
d = 2
165+
θ_init = randn(d)
166+
rng = Random.default_rng()
167+
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
168+
logdensity = model.logdensity
169+
spl = NUTS(0.8)
170+
T = AdvancedHMC.sampler_eltype(spl)
171+
172+
metric = make_metric(spl, logdensity)
173+
hamiltonian = Hamiltonian(metric, model)
174+
175+
init_params1 = make_init_params(rng, spl, logdensity, nothing)
176+
@test typeof(init_params1) == Vector{T}
177+
@test length(init_params1) == d
178+
init_params2 = make_init_params(rng, spl, logdensity, θ_init)
179+
@test init_params2 === θ_init
180+
end
181+
end

0 commit comments

Comments
 (0)