1
1
using AdvancedHMC, AbstractMCMC, Random
2
2
include (" common.jl" )
3
3
4
+ get_kernel_hyperparams (spl:: HMC , state) = state. κ. τ. termination_criterion. L
5
+ get_kernel_hyperparams (spl:: HMCDA , state) = state. κ. τ. termination_criterion. λ
6
+ get_kernel_hyperparams (spl:: NUTS , state) =
7
+ state. κ. τ. termination_criterion. max_depth, state. κ. τ. termination_criterion. Δ_max
8
+
9
+ get_kernel_hyperparamsT (spl:: HMC , state) = typeof (state. κ. τ. termination_criterion. L)
10
+ get_kernel_hyperparamsT (spl:: HMCDA , state) = typeof (state. κ. τ. termination_criterion. λ)
11
+ get_kernel_hyperparamsT (spl:: NUTS , state) = typeof (state. κ. τ. termination_criterion. Δ_max)
12
+
4
13
@testset " Constructors" begin
5
14
d = 2
6
15
θ_init = randn (d)
16
+ rng = Random. default_rng ()
7
17
model = AbstractMCMC. LogDensityModel (ℓπ_gdemo)
8
18
9
19
@testset " $T " for T in [Float32, Float64]
@@ -14,6 +24,7 @@ include("common.jl")
14
24
adaptor_type = NoAdaptation,
15
25
metric_type = DiagEuclideanMetric{T},
16
26
integrator_type = Leapfrog{T},
27
+ kernel_hp = 25 ,
17
28
),
18
29
),
19
30
(
@@ -22,6 +33,7 @@ include("common.jl")
22
33
adaptor_type = NoAdaptation,
23
34
metric_type = DiagEuclideanMetric{T},
24
35
integrator_type = Leapfrog{T},
36
+ kernel_hp = 25 ,
25
37
),
26
38
),
27
39
(
@@ -30,6 +42,7 @@ include("common.jl")
30
42
adaptor_type = NoAdaptation,
31
43
metric_type = DiagEuclideanMetric{T},
32
44
integrator_type = Leapfrog{T},
45
+ kernel_hp = 25 ,
33
46
),
34
47
),
35
48
(
@@ -38,6 +51,7 @@ include("common.jl")
38
51
adaptor_type = NoAdaptation,
39
52
metric_type = UnitEuclideanMetric{T},
40
53
integrator_type = Leapfrog{T},
54
+ kernel_hp = 25 ,
41
55
),
42
56
),
43
57
(
@@ -46,6 +60,7 @@ include("common.jl")
46
60
adaptor_type = NoAdaptation,
47
61
metric_type = DenseEuclideanMetric{T},
48
62
integrator_type = Leapfrog{T},
63
+ kernel_hp = 25 ,
49
64
),
50
65
),
51
66
(
@@ -54,6 +69,7 @@ include("common.jl")
54
69
adaptor_type = NesterovDualAveraging,
55
70
metric_type = DiagEuclideanMetric{T},
56
71
integrator_type = Leapfrog{T},
72
+ kernel_hp = one (T),
57
73
),
58
74
),
59
75
# This should perform the correct promotion for the 2nd argument.
@@ -63,14 +79,16 @@ include("common.jl")
63
79
adaptor_type = NesterovDualAveraging,
64
80
metric_type = DiagEuclideanMetric{T},
65
81
integrator_type = Leapfrog{T},
82
+ kernel_hp = one (T),
66
83
),
67
84
),
68
85
(
69
- NUTS (T (0.8 )),
86
+ NUTS (T (0.8 ); max_depth = 20 , Δ_max = T ( 2000.0 ) ),
70
87
(
71
88
adaptor_type = StanHMCAdaptor,
72
89
metric_type = DiagEuclideanMetric{T},
73
90
integrator_type = Leapfrog{T},
91
+ kernel_hp = (20 , T (2000.0 )),
74
92
),
75
93
),
76
94
(
@@ -79,6 +97,7 @@ include("common.jl")
79
97
adaptor_type = StanHMCAdaptor,
80
98
metric_type = UnitEuclideanMetric{T},
81
99
integrator_type = Leapfrog{T},
100
+ kernel_hp = (10 , T (1000.0 )),
82
101
),
83
102
),
84
103
(
@@ -87,6 +106,7 @@ include("common.jl")
87
106
adaptor_type = StanHMCAdaptor,
88
107
metric_type = DenseEuclideanMetric{T},
89
108
integrator_type = Leapfrog{T},
109
+ kernel_hp = (10 , T (1000.0 )),
90
110
),
91
111
),
92
112
(
@@ -95,6 +115,7 @@ include("common.jl")
95
115
adaptor_type = StanHMCAdaptor,
96
116
metric_type = DiagEuclideanMetric{T},
97
117
integrator_type = JitteredLeapfrog{T,T},
118
+ kernel_hp = (10 , T (1000.0 )),
98
119
),
99
120
),
100
121
(
@@ -103,14 +124,14 @@ include("common.jl")
103
124
adaptor_type = StanHMCAdaptor,
104
125
metric_type = DiagEuclideanMetric{T},
105
126
integrator_type = TemperedLeapfrog{T,T},
127
+ kernel_hp = (10 , T (1000.0 )),
106
128
),
107
129
),
108
130
]
109
131
# Make sure the sampler element type is preserved.
110
132
@test AdvancedHMC. sampler_eltype (sampler) == T
111
133
112
134
# Step.
113
- rng = Random. default_rng ()
114
135
transition, state =
115
136
AbstractMCMC. step (rng, model, sampler; n_adapts = 0 , init_params = θ_init)
116
137
@@ -126,6 +147,35 @@ include("common.jl")
126
147
@test AdvancedHMC. getmetric (state) isa expected. metric_type
127
148
@test AdvancedHMC. getintegrator (state) isa expected. integrator_type
128
149
@test AdvancedHMC. getadaptor (state) isa expected. adaptor_type
150
+
151
+ # Verify that the kernel is receiving the hyperparameters
152
+ @test get_kernel_hyperparams (sampler, state) == expected. kernel_hp
153
+ if typeof (sampler) <: HMC
154
+ @test get_kernel_hyperparamsT (sampler, state) == Int64
155
+ else
156
+ @test get_kernel_hyperparamsT (sampler, state) == T
157
+ end
129
158
end
130
159
end
131
160
end
161
+
162
+ @testset " Utils" begin
163
+ @testset " init_params" begin
164
+ d = 2
165
+ θ_init = randn (d)
166
+ rng = Random. default_rng ()
167
+ model = AbstractMCMC. LogDensityModel (ℓπ_gdemo)
168
+ logdensity = model. logdensity
169
+ spl = NUTS (0.8 )
170
+ T = AdvancedHMC. sampler_eltype (spl)
171
+
172
+ metric = make_metric (spl, logdensity)
173
+ hamiltonian = Hamiltonian (metric, model)
174
+
175
+ init_params1 = make_init_params (rng, spl, logdensity, nothing )
176
+ @test typeof (init_params1) == Vector{T}
177
+ @test length (init_params1) == d
178
+ init_params2 = make_init_params (rng, spl, logdensity, θ_init)
179
+ @test init_params2 === θ_init
180
+ end
181
+ end
0 commit comments