Skip to content

Commit 9224ef0

Browse files
committed
Move number of samples argument(s) into algorithms
1 parent 7077399 commit 9224ef0

38 files changed

+173
-273
lines changed

docs/src/internal_api.md

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Order = [:macro, :function]
3333

3434
```@docs
3535
BAT.AbstractProposalDist
36+
BAT.AbstractSampleGenerator
3637
BAT.BasicMvStatistics
3738
BAT.DataSet
3839
BAT.HMIData

docs/src/stable_api.md

-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ AbstractDensity
5353
AbstractDensityTransformTarget
5454
AbstractMCMCWeightingScheme
5555
AbstractPosteriorDensity
56-
AbstractSampleGenerator
5756
AbstractTransformedDensity
5857
AbstractTransformToInfinite
5958
AbstractTransformToUnitspace

docs/src/tutorial_lit.jl

+5-18
Original file line numberDiff line numberDiff line change
@@ -244,16 +244,10 @@ posterior = PosteriorDensity(likelihood, prior)
244244
#nb ENV["JULIA_DEBUG"] = "BAT"
245245
#jl ENV["JULIA_DEBUG"] = "BAT"
246246

247-
# Let's use 4 MCMC chains and require 10^5 unique samples from each chain
248-
# (after tuning/burn-in):
247+
# Now we can generate a set of MCMC samples via [`bat_sample`](@ref). We'll
248+
# use 4 MCMC chains with 10^5 MC steps in each chain (after tuning/burn-in):
249249

250-
nsamples = 10^4
251-
#md nothing # hide
252-
253-
254-
# Now we can generate a set of MCMC samples via [`bat_sample`](@ref):
255-
256-
samples = bat_sample(posterior, nsamples, MCMCSampling(sampler = MetropolisHastings(), nchains = 4)).result
250+
samples = bat_sample(posterior, MCMCSampling(sampler = MetropolisHastings(), nsteps = 10^4, nchains = 4)).result
257251
#md nothing # hide
258252
#nb nothing # hide
259253

@@ -447,12 +441,13 @@ convergence = BrooksGelmanConvergence()
447441

448442
samples = bat_sample(
449443
rng, posterior,
450-
nsamples,
451444
MCMCSampling(
452445
sampler = MetropolisHastings(
453446
weighting = RepetitionWeighting(),
454447
tuning = tuning
455448
),
449+
nchains = 4,
450+
nsteps = 10^5,
456451
init = init,
457452
burnin = burnin,
458453
convergence = convergence,
@@ -464,11 +459,3 @@ samples = bat_sample(
464459
).result
465460
#md nothing # hide
466461
#nb nothing # hide
467-
468-
# However, in many use cases, simply using the default options via
469-
#
470-
# ```julia
471-
# samples = bat_sample(posterior, nsamples).result
472-
# ```
473-
#
474-
# will often be sufficient.

examples/benchmarks/benchmarks.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ function setup_benchmark()
1919
include("run_benchmark_ND.jl")
2020
end
2121

22-
function do_benchmarks(;algorithm=MetropolisHastings(), n_samples=10^5, n_chains=8)
23-
#run_1D_benchmark(algorithm=algorithm, n_samples=n_samples, n_chains=n_chains)
24-
run_2D_benchmark(algorithm=algorithm, n_samples=n_samples, n_chains=n_chains)
25-
run_ND_benchmark(n_dim=2:2:20,algorithm=MetropolisHastings(), n_samples=2*10^5, n_chains=4)
22+
function do_benchmarks(;algorithm=MetropolisHastings(), n_steps=10^5, n_chains=8)
23+
#run_1D_benchmark(algorithm=algorithm, n_steps=n_steps, n_chains=n_chains)
24+
run_2D_benchmark(algorithm=algorithm, n_steps=n_steps, n_chains=n_chains)
25+
run_ND_benchmark(n_dim=2:2:20,algorithm=MetropolisHastings(), n_steps=2*10^5, n_chains=4)
2626
run_ks_ahmc_vs_mh(n_dim=20:5:35)
2727
end
2828

examples/benchmarks/functions_1D.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ testfunctions_1D = Dict(
1919
analytical_stats_name=["mode","mean","var"] #could be taken into NamedTuple for easier addtions but would be needed to implmented into calcs anyway
2020
sample_stats=[Vector{Float64}(undef,length(analytical_stats_name)) for i in 1:length(testfunctions_1D)]
2121

22-
run_stats_names = ["nsamples","nchains","Times"]
22+
run_stats_names = ["nsteps","nchains","Times"]
2323
run_stats=[Vector{Float64}(undef,length(run_stats_names)) for i in 1:length(testfunctions_1D)]

examples/benchmarks/functions_2D.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
stats_names2D = ["mode","mean","var"]
2-
run_stats_names2D = ["nsamples","nchains","Times"]
2+
run_stats_names2D = ["nsteps","nchains","Times"]
33
##########################################multi normal###########################################
44
sig = Matrix{Float64}([1.5^2 1.5*2.5*0.4 ; 1.5*2.5*0.4 2.5^2])
55
analytical_stats_gauss2D = Vector{Any}(undef,length(stats_names2D))

examples/benchmarks/run_benchmark_1D.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
function run_1D_benchmark(;algorithm=MetropolisHastings(), n_samples=10^5, n_chains=8)
1+
function run_1D_benchmark(;algorithm=MetropolisHastings(), n_steps=10^5, n_chains=8)
22
for i in 1:length(testfunctions_1D)
33
sample_stats_all = run1D(
44
collect(keys(testfunctions_1D))[i], #There might be a nicer way but I need the name to save the plots
55
testfunctions_1D,
66
sample_stats[i],
77
run_stats[i],
88
algorithm,
9-
n_samples,
9+
n_steps,
1010
n_chains
1111
)
1212
end

examples/benchmarks/run_benchmark_2D.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
function run_2D_benchmark(;algorithm = MetropolisHastings(),n_chains = 8,n_samples = 10^5)
1+
function run_2D_benchmark(;algorithm = MetropolisHastings(),n_chains = 8,n_steps = 10^5)
22
for i in 1:length(testfunctions_2D)
33
sample_stats_all = run2D(
44
collect(keys(testfunctions_2D))[i], #There might be a nicer way but I need the name to save the plots
55
testfunctions_2D,
66
sample_stats2D[i],
77
run_stats2D[i],
88
algorithm,
9-
n_samples,
9+
n_steps,
1010
n_chains
1111
)
1212
end

examples/benchmarks/run_benchmark_ND.jl

+13-11
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ function run_ND_benchmark(;
205205
n_dim = 2:2:20,
206206
algorithm = MetropolisHastings(),
207207
n_chains = 4,
208-
n_samples = 4*10^5,
208+
n_steps = 4*10^5,
209209
time_benchmark = true,
210210
ks_test_benchmark = true,
211211
ahmi_benchmark = true,
@@ -244,16 +244,17 @@ function run_ND_benchmark(;
244244
for j in 1:length(testfunctions)
245245

246246
dis = testfunctions[collect(keys(testfunctions))[j]]
247-
iid_sample = bat_sample(dis, n_samples*n_chains).result
247+
iid_sample = bat_sample(dis, MCMCSampling(sampler = MetropolisHastings(), nsteps = nsteps, nchains = nchains)).result
248248

249249
mcmc_sample = nothing
250250
tbf = time()
251251
if isa(algorithm,BAT.MetropolisHastings)
252252
mcmc_sample = bat_sample(
253-
dis, n_samples * n_chains,
253+
dis,
254254
MCMCSampling(
255255
sampler = algorithm,
256256
nchains = n_chains,
257+
nsteps = n_steps,
257258
init = init,
258259
burnin = burnin,
259260
convergence = convergence,
@@ -263,8 +264,8 @@ function run_ND_benchmark(;
263264
).result
264265
elseif isa(algorithm,BAT.HamiltonianMC)
265266
mcmc_sample = bat_sample(
266-
dis, n_samples*n_chains,
267-
MCMCSampling(sampler = algorithm)
267+
dis,
268+
MCMCSampling(sampler = algorithm, nchains = n_chains, nsteps = n_steps)
268269
).result
269270
end
270271
taf = time()
@@ -275,10 +276,11 @@ function run_ND_benchmark(;
275276
tbf = time()
276277
if isa(algorithm,BAT.MetropolisHastings)
277278
bat_sample(
278-
dis, n_samples * n_chains,
279+
dis,
279280
MCMCSampling(
280281
sampler = algorithm,
281282
nchains = n_chains,
283+
nsteps = n_steps,
282284
init = init,
283285
burnin = burnin,
284286
convergence = convergence,
@@ -288,8 +290,8 @@ function run_ND_benchmark(;
288290
).result
289291
elseif isa(algorithm,BAT.HamiltonianMC)
290292
bat_sample(
291-
dis, n_samples*n_chains,
292-
MCMCSampling(sampler = algorithm)
293+
dis,
294+
MCMCSampling(sampler = algorithm, nchains = n_chains, nsteps = n_steps)
293295
).result
294296
end
295297
taf = time()
@@ -357,8 +359,8 @@ function run_ND_benchmark(;
357359
return [ks_test,ahmi,times]
358360
end
359361

360-
function run_ks_ahmc_vs_mh(;n_dim=20:5:35,n_samples=2*10^5, n_chains=4)
361-
ks_res_ahmc = run_ND_benchmark(n_dim=n_dim,algorithm=HamiltonianMC(), n_samples=n_samples, n_chains=n_chains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1]
362-
ks_res_mh = run_ND_benchmark(n_dim=n_dim,algorithm=MetropolisHastings(), n_samples=n_samples, n_chains=n_chains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1]
362+
function run_ks_ahmc_vs_mh(;n_dim=20:5:35,n_steps=2*10^5, n_chains=4)
363+
ks_res_ahmc = run_ND_benchmark(n_dim=n_dim,algorithm=HamiltonianMC(), n_steps=n_steps, n_chains=n_chains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1]
364+
ks_res_mh = run_ND_benchmark(n_dim=n_dim,algorithm=MetropolisHastings(), n_steps=n_steps, n_chains=n_chains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1]
363365
plot_ks_values_ahmc_vs_mh(ks_res_ahmc,ks_res_mh,n_dim)
364366
end

examples/benchmarks/utils.jl

+11-11
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ function plot1D(
130130
end
131131

132132
nun = convert(Int64,floor(sum(hunnorm.weights)/10))
133-
unweighted_samples = bat_sample(samples, nun).result
133+
unweighted_samples = bat_sample(samples, OrderedResampling(nsamples = nun)).result
134134
hunnorm = fit(Histogram, [BAT.flatview(unweighted_samples.v)...],binning)
135135

136136
edges = hunnorm.edges[1]
@@ -183,7 +183,7 @@ function plot1D(
183183
end
184184
length(sample_stats) != 4 ? push!(sample_stats,chi2) : sample_stats[4] = chi2
185185

186-
iid_sample = bat_sample(testfunctions[name].posterior,length([BAT.flatview(samples.v)...])).result
186+
iid_sample = bat_sample(testfunctions[name].posterior, IIDSampling(nsamples = length([BAT.flatview(samples.v)...]))).result
187187
if(testfunctions[name].ks[1] > 999)
188188
testfunctions[name].ks[1]=bat_compare(samples,iid_sample).result.ks_p_values[1]
189189
end
@@ -201,24 +201,24 @@ function run1D(
201201
sample_stats::Vector{Float64},
202202
run_stats::Vector{Float64},
203203
algorithm::BAT.AbstractSamplingAlgorithm,
204-
n_samples::Integer,
204+
n_steps::Integer,
205205
n_chains::Integer,
206206
n_runs=1
207207
)
208208

209209
sample_stats_all = []
210-
samples, chains = bat_sample(testfunctions[key].posterior, n_samples * n_chains, MCMCSampling(sampler = algorithm, nchains = n_chains))
210+
samples, chains = bat_sample(testfunctions[key].posterior, MCMCSampling(sampler = algorithm, nchains = n_chains, nsteps = n_steps))
211211
for i in 1:n_runs
212212
time_before = time()
213-
samples, chains = bat_sample(testfunctions[key].posterior, n_samples * n_chains, MCMCSampling(sampler = algorithm, nchains = n_chains))
213+
samples, chains = bat_sample(testfunctions[key].posterior, MCMCSampling(sampler = algorithm, nchains = n_chains, nsteps = n_steps))
214214
time_after = time()
215215

216216
h = plot1D(samples,testfunctions,key,sample_stats)# posterior, key, analytical_stats,sample_stats)
217217

218218
sample_stats[1] = mode(samples)[1]
219219
sample_stats[2] = mean(samples)[1]
220220
sample_stats[3] = var(samples)[1]
221-
run_stats[1] = n_samples
221+
run_stats[1] = n_steps
222222
run_stats[2] = n_chains
223223
run_stats[3] = time_after-time_before
224224
push!(sample_stats_all,sample_stats)
@@ -432,16 +432,16 @@ function run2D(
432432
sample_stats::Vector{Any},
433433
run_stats::Vector{Any},
434434
algorithm::MCMCAlgorithm,
435-
n_samples::Integer,
435+
n_steps::Integer,
436436
n_chains::Integer,
437437
n_runs=1)
438438

439439
sample_stats_all = []
440440

441-
samples, stats = bat_sample(testfunctions[key].posterior, n_samples * n_chains, MCMCSampling(sampler = algorithm, nchains = n_chains))
441+
samples, stats = bat_sample(testfunctions[key].posterior, MCMCSampling(sampler = algorithm, nchains = n_chains, nsteps = n_steps))
442442
for i in 1:n_runs
443443
time_before = time()
444-
samples, stats = bat_sample(testfunctions[key].posterior, n_samples * n_chains, MCMCSampling(sampler = algorithm, nchains = n_chains))
444+
samples, stats = bat_sample(testfunctions[key].posterior, MCMCSampling(sampler = algorithm, nchains = n_chains, nsteps = n_steps))
445445
time_after = time()
446446

447447
h = plot2D(samples, testfunctions, key, sample_stats)
@@ -450,7 +450,7 @@ function run2D(
450450
sample_stats[2] = mean(samples).data
451451
sample_stats[3] = var(samples).data
452452

453-
run_stats[1] = n_samples
453+
run_stats[1] = n_steps
454454
run_stats[2] = n_chains
455455
run_stats[3] = time_after-time_before
456456
push!(sample_stats_all,sample_stats)
@@ -473,7 +473,7 @@ function make_2D_results(testfunctions::Dict,sample_stats2D::Vector{Vector{Any}}
473473
push!(ahmi_val,round.(v.ahmi,digits=3))
474474
end
475475

476-
run_stats_names2D = ["nsamples","nchains","Times"]
476+
run_stats_names2D = ["nsteps","nchains","Times"]
477477
stats_names2D = ["mode","mean","var"]
478478
comparison = ["target","test","diff (abs)","diff (rel)"]
479479
header = Vector{Any}(undef,length(stats_names2D)*length(comparison)+3)

examples/dev-internal/ahmi_example.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ bounds = BAT.HyperRectBounds(lo_bounds, hi_bounds, BAT.reflective_bounds)
4040

4141

4242
#BAT.jl samples
43-
bat_samples = bat_sample(PosteriorDensity(model, bounds), (10^5, 8), algorithm).result
43+
bat_samples = bat_sample(PosteriorDensity(model, bounds), algorithm).result
4444
data = BAT.HMIData(bat_samples)
4545
BAT.hm_integrate!(data)
4646

examples/dev-internal/hmc_with_trafo.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
include(joinpath(dirname(dirname(@__DIR__)), "docs", "src", "tutorial_lit.jl"))
22

3-
samples_mh = bat_sample(posterior, 10^5, MCMCSampling(sampler = MetropolisHastings())).result
3+
samples_mh = bat_sample(posterior, MCMCSampling(sampler = MetropolisHastings(), nsteps = 10^5)).result
44

55
posterior_is, trafo_is = bat_transform(PriorToGaussian(), posterior, PriorSubstitution())
66
posterior_is2, trafo_is2 = bat_transform(PriorToGaussian(), posterior, FullDensityTransform())
77

8-
samples_is = bat_sample(posterior_is, 10^5, MCMCSampling(sampler = HamiltonianMC())).result
9-
samples_is2 = bat_sample(posterior_is2, 10^5, MCMCSampling(sampler = HamiltonianMC())).result
8+
samples_is = bat_sample(posterior_is, MCMCSampling(sampler = HamiltonianMC(), nsteps = 10^5)).result
9+
samples_is2 = bat_sample(posterior_is2, MCMCSampling(sampler = HamiltonianMC(), nsteps = 10^5)).result
1010

1111
samples = inv(trafo_is).(samples_is)
1212
samples2 = inv(trafo_is2).(samples_is2)

examples/dev-internal/output_examples.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ prior = BAT.NamedTupleDist(
2323

2424
posterior = PosteriorDensity(likelihood, prior);
2525

26-
samples, chains = bat_sample(posterior, 10^5, MCMCSampling(sampler = MetropolisHastings()));
27-
#samples = bat_sample(posterior, 10^5, SobolSampler()).result;
26+
samples, chains = bat_sample(posterior, MCMCSampling(sampler = MetropolisHastings(), nsteps = 10^5));
27+
#samples = bat_sample(posterior, SobolSampler(nsamples = 10^5)).result;
2828

2929
sd = SampledDensity(posterior, samples, generator=BAT.MCMCSampleGenerator(chains))
3030
display(sd)

examples/dev-internal/plotting_examples.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ prior = BAT.NamedTupleDist(
3030

3131
posterior = PosteriorDensity(likelihood, prior);
3232

33-
samples, chains = bat_sample(posterior, 10^5, MCMCSampling(sampler = MetropolisHastings()));
33+
samples, chains = bat_sample(posterior, MCMCSampling(sampler = MetropolisHastings(), nsteps = 10^5));
3434

3535
# ## Set up plotting
3636
# Set up plotting using the [Plots.jl](https://github.com/JuliaPlots/Plots.jl) package:

src/algodefaults/default_sampling_algorithm.jl

-18
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,3 @@ bat_default(::typeof(bat_sample), ::Val{:algorithm}, ::AnyIIDSampleable) = IIDSa
99
bat_default(::typeof(bat_sample), ::Val{:algorithm}, ::DensitySampleVector) = OrderedResampling()
1010

1111
bat_default(::typeof(bat_sample), ::Val{:algorithm}, ::AbstractDensity) = MCMCSampling()
12-
13-
14-
#=
15-
For HamiltonianMC
16-
17-
#!!!!!!!!!!!!!!!! N samples steps evals
18-
19-
# MCMCBurninStrategy for HamiltonianMC
20-
function MCMCBurninStrategy(algorithm::HamiltonianMC, nsamples::Integer, max_nsteps::Integer, tuner_config::MCMCTuningAlgorithm)
21-
max_nsamples_per_cycle = nsamples
22-
max_nsteps_per_cycle = max_nsteps
23-
MCMCBurninStrategy(
24-
max_nsamples_per_cycle = max_nsamples_per_cycle,
25-
max_nsteps_per_cycle = max_nsteps_per_cycle,
26-
max_ncycles = 1
27-
)
28-
end
29-
=#

0 commit comments

Comments
 (0)