Skip to content

Commit 75fb7ec

Browse files
devmotiongithub-actions[bot]penelopeysm
authored
Fix fields with abstract types (#399)
* Fix fields with abstract types * Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix tests * Fix typo * Update version number --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Penelope Yong <[email protected]>
1 parent 35e6a01 commit 75fb7ec

File tree

5 files changed

+27
-25
lines changed

5 files changed

+27
-25
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.6.4"
3+
version = "0.7.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/adaptation/stepsize.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ References
7171
Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn Sampler: adaptively setting path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research, 15(1), 1593-1623.
7272
Nesterov, Y. (2009). Primal-dual subgradient methods for convex problems. Mathematical programming, 120(1), 221-259.
7373
"""
74-
struct NesterovDualAveraging{T<:AbstractFloat} <: StepSizeAdaptor
74+
struct NesterovDualAveraging{T<:AbstractFloat,S<:AbstractScalarOrVec{T}} <: StepSizeAdaptor
7575
γ::T
7676
t_0::T
7777
κ::T
7878
δ::T
79-
state::DAState{<:AbstractScalarOrVec{T}}
79+
state::DAState{S}
8080
end
8181
Base.show(io::IO, a::NesterovDualAveraging) = print(
8282
io,

src/constructors.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -144,15 +144,16 @@ For more information, please view the following paper ([arXiv link](https://arxi
144144
setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning
145145
Research 15, no. 1 (2014): 1593-1623.
146146
"""
147-
struct HMCDA{T<:Real} <: AbstractHMCSampler
147+
struct HMCDA{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
148+
AbstractHMCSampler
148149
"Target acceptance rate for dual averaging."
149150
δ::T
150151
"Target leapfrog length."
151152
λ::T
152153
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
153-
integrator::Union{Symbol,AbstractIntegrator}
154+
integrator::I
154155
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
155-
metric::Union{Symbol,AbstractMetric}
156+
metric::M
156157
end
157158

158159
function HMCDA(δ, λ; integrator = :leapfrog, metric = :diagonal)

src/trajectory.jl

+11-11
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ It contains the slice variable and the number of acceptable condidates in the tr
9999
100100
$(TYPEDFIELDS)
101101
"""
102-
struct SliceTS{F<:AbstractFloat} <: AbstractTrajectorySampler
102+
struct SliceTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler
103103
"Sampled candidate `PhasePoint`."
104-
zcand::PhasePoint
104+
zcand::P
105105
"Slice variable in log-space."
106106
ℓu::F
107107
"Number of acceptable candidates, i.e. those with probability larger than slice variable `u`."
@@ -120,9 +120,9 @@ It contains the weight of the tree, defined as the total probabilities of the le
120120
121121
$(TYPEDFIELDS)
122122
"""
123-
struct MultinomialTS{F<:AbstractFloat} <: AbstractTrajectorySampler
123+
struct MultinomialTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler
124124
"Sampled candidate `PhasePoint`."
125-
zcand::PhasePoint
125+
zcand::P
126126
"Total energy for the given tree, i.e. the sum of energies of all leaves."
127127
ℓw::F
128128
end
@@ -499,13 +499,13 @@ end
499499
"""
500500
A full binary tree trajectory with only necessary leaves and information stored.
501501
"""
502-
struct BinaryTree
503-
zleft::Any # left most leaf node
504-
zright::Any # right most leaf node
505-
ts::Any # turn statistics
506-
sum_α::Any # MH stats, i.e. sum of MH accept prob for all leapfrog steps
507-
::Any # total # of leap frog steps, i.e. phase points in a trajectory
508-
ΔH_max::Any # energy in tree with largest absolute different from initial energy
502+
struct BinaryTree{T<:Real,P<:PhasePoint,TS<:TurnStatistic}
503+
zleft::P # left most leaf node
504+
zright::P # right most leaf node
505+
ts::TS # turn statistics
506+
sum_α::T # MH stats, i.e. sum of MH accept prob for all leapfrog steps
507+
::Int # total # of leap frog steps, i.e. phase points in a trajectory
508+
ΔH_max::T # energy in tree with largest absolute different from initial energy
509509
end
510510

511511
"""

test/trajectory.jl

+9-8
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ ahmc_isturn(h, z0, z1, rho, v = 1) =
7272
AdvancedHMC.isterminated(
7373
ClassicNoUTurn(),
7474
h,
75-
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(), 0, 0, 0.0),
75+
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(), 0.0, 0, 0.0),
7676
).dynamic
7777

7878
function hand_isturn_generalised(z0, z1, rho, v = 1)
@@ -84,16 +84,16 @@ ahmc_isturn_generalised(h, z0, z1, rho, v = 1) =
8484
AdvancedHMC.isterminated(
8585
GeneralisedNoUTurn(),
8686
h,
87-
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0),
87+
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0),
8888
).dynamic
8989

9090
function ahmc_isturn_strictgeneralised(h, z0, z1, rho, v = 1)
9191
t = AdvancedHMC.isterminated(
9292
StrictGeneralisedNoUTurn(),
9393
h,
94-
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0),
95-
AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0, 0, 0.0),
96-
AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0, 0, 0.0),
94+
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0),
95+
AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0),
96+
AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0.0, 0, 0.0),
9797
)
9898
return t.dynamic
9999
end
@@ -102,13 +102,14 @@ end
102102
Check whether the subtree checks adequately detect U-turns.
103103
"""
104104
function check_subtree_u_turns(h, z0, z1, rho)
105-
t = AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0)
105+
t = AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0)
106106
# The left and right subtree are created in such a way that the
107107
# check_left_subtree and check_right_subtree checks should be equivalent
108108
# to the general no U-turn check.
109-
tleft = AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0, 0, 0.0)
109+
tleft =
110+
AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0)
110111
tright =
111-
AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0, 0, 0.0)
112+
AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0.0, 0, 0.0)
112113

113114
s1 = AdvancedHMC.isterminated(GeneralisedNoUTurn(), h, t)
114115
s2 = AdvancedHMC.check_left_subtree(h, t, tleft, tright)

0 commit comments

Comments
 (0)