diff --git a/src/inference/ais.jl b/src/inference/ais.jl
new file mode 100644
index 00000000..e94b7794
--- /dev/null
+++ b/src/inference/ais.jl
@@ -0,0 +1,126 @@
+"""
+    (lml_est, trace, weights) = ais(
+        model::GenerativeFunction, constraints::ChoiceMap,
+        args_seq::Vector{<:Tuple}, argdiffs::Tuple,
+        mcmc_kernel::Function)
+
+    (lml_est, trace, weights) = ais(
+        trace::Trace, selection::Selection,
+        args_seq::Vector{<:Tuple}, argdiffs::Tuple,
+        mcmc_kernel::Function)
+
+Run annealed importance sampling, returning the log marginal likelihood estimate (`lml_est`),
+the final trace, and the vector of incremental log weights.
+
+The first form obtains the initial trace by calling `generate` on `model` with the first
+arguments in `args_seq` and `constraints`. The second form starts from an existing `trace`,
+which is updated to the first arguments in `args_seq`.
+
+The mcmc_kernel must satisfy detailed balance with respect to each step in the chain.
+"""
+function ais(
+        model::GenerativeFunction, constraints::ChoiceMap,
+        args_seq::Vector{<:Tuple}, argdiffs::Tuple, mcmc_kernel::Function)
+    init_trace, init_weight = generate(model, args_seq[1], constraints)
+    _ais(init_trace, init_weight, args_seq, argdiffs, mcmc_kernel)
+end
+
+function ais(
+        trace::Trace, selection::Selection,
+        args_seq::Vector{<:Tuple}, argdiffs::Tuple, mcmc_kernel::Function)
+    # move the given trace to the first arguments in the sequence
+    init_trace, = update(trace, args_seq[1], argdiffs, EmptyChoiceMap())
+    # NOTE(review): the weight is projected from the input `trace`, not `init_trace`;
+    # confirm this is intended when `args_seq[1]` differs from `get_args(trace)`
+    init_weight = project(trace, ComplementSelection(selection))
+    _ais(init_trace, init_weight, args_seq, argdiffs, mcmc_kernel)
+end
+
+# Shared AIS driver: alternate `mcmc_kernel` moves with argument updates along the rest
+# of `args_seq`, adding the log weight of each update to `lml_est` and to `weights`.
+function _ais(
+        trace::Trace, init_weight::Float64, args_seq::Vector{<:Tuple},
+        argdiffs::Tuple, mcmc_kernel::Function)
+    @assert get_args(trace) == args_seq[1]
+
+    # run forward AIS
+    weights = Float64[]
+    lml_est = init_weight
+    push!(weights, init_weight)
+    for intermediate_args in args_seq[2:end-1]
+        trace = mcmc_kernel(trace)
+        (trace, weight, _, discard) = update(trace, intermediate_args, argdiffs, EmptyChoiceMap())
+        if !isempty(discard)
+            error("Change to arguments cannot cause random choices to be removed from trace")
+        end
+        lml_est += weight
+        push!(weights, weight)
+    end
+    trace = mcmc_kernel(trace)
+    (trace, weight, _, discard) = update(
+        trace, args_seq[end], argdiffs, EmptyChoiceMap())
+    if !isempty(discard)
+        error("Change to arguments cannot cause random choices to be removed from trace")
+    end
+    lml_est += weight
+    push!(weights, weight)
+
+    # do MCMC at the very end
+    trace = mcmc_kernel(trace)
+
+    return (lml_est, trace, weights)
+end
+
+"""
+    (lml_est, weights) = reverse_ais(
+        model::GenerativeFunction, constraints::ChoiceMap,
+        args_seq::Vector, argdiffs::Tuple,
+        mh_rev::Function, output_addrs::Selection; safe=true)
+
+Run reverse annealed importance sampling, returning the log marginal likelihood estimate (`lml_est`).
+
+`constraints` must be a choice map that uniquely determines a trace of the model for the final arguments in the argument sequence.
+The `mh_rev` kernel must satisfy detailed balance with respect to each step in the chain.
+`output_addrs` selects the addresses whose projected score is added as the last term of the estimate.
+If `safe` is true, an error is raised when the generated weight and the trace score disagree, or when an update weight is NaN.
+"""
+function reverse_ais(
+        model::GenerativeFunction, constraints::ChoiceMap,
+        args_seq::Vector, argdiffs::Tuple,
+        mh_rev::Function, output_addrs::Selection; safe=true)
+
+    # construct final model trace from the inferred choices and all the fixed choices
+    (trace, should_be_score) = generate(model, args_seq[end], constraints)
+    init_score = get_score(trace)
+    if safe && !isapprox(should_be_score, init_score) # check it's deterministic
+        error("Some random choices may have been unconstrained")
+    end
+
+    # do mh at the very beginning
+    trace = mh_rev(trace)
+
+    # run backward AIS
+    lml_est = 0.
+    weights = Float64[]
+    for model_args in reverse(args_seq[1:end-1])
+        (trace, weight, _, _) = update(trace, model_args, argdiffs, EmptyChoiceMap())
+        safe && isnan(weight) && error("NaN weight")
+        lml_est -= weight
+        push!(weights, -weight)
+        trace = mh_rev(trace)
+    end
+
+    # get pi_1(z_0) / q(z_0) -- the weight that would be returned by the initial 'generate' call
+    # `output_addrs` selects the addresses that would be constrained by that call to generate
+    @assert get_args(trace) == args_seq[1]
+    score_from_project = project(trace, output_addrs)
+    lml_est += score_from_project
+    push!(weights, score_from_project)
+    if isnan(score_from_project)
+        error("NaN score_from_project")
+    end
+
+    return (lml_est, reverse(weights))
+end
+
+export ais, reverse_ais
diff --git a/src/inference/importance.jl b/src/inference/importance.jl
index fac3c0e4..4e9172bd 100644
--- a/src/inference/importance.jl
+++ b/src/inference/importance.jl
@@ -107,4 +107,26 @@ function importance_resampling(model::GenerativeFunction{T,U}, model_args::Tuple
     return (model_trace::U, log_ml_estimate::Float64)
 end
 
-export importance_sampling, importance_resampling
+"""
+    log_ml_estimate = conditional_is_estimator(
+        trace::Trace, observed::Selection, num_samples::Int)
+
+Given a trace sampled from the conditional distribution given observed choices,
+return an estimate of the log marginal likelihood of the observed choices that is a
+stochastic upper bound on the true log marginal likelihood.
+"""
+function conditional_is_estimator(trace::Trace, observed::Selection, num_samples::Int)
+    model = get_gen_fn(trace)
+    model_args = get_args(trace)
+    observations = get_selected(get_choices(trace), observed)
+    log_weights = Vector{Float64}(undef, num_samples)
+    log_weights[1] = project(trace, observed)
+    for i=2:num_samples
+        (_, log_weights[i]) = generate(model, model_args, observations)
+    end
+    log_total_weight = logsumexp(log_weights)
+    log_ml_estimate = log_total_weight - log(num_samples)
+    return log_ml_estimate
+end
+
+export importance_sampling, importance_resampling, conditional_is_estimator
diff --git a/src/inference/inference.jl b/src/inference/inference.jl
index bf4afb1b..70120ea5 100644
--- a/src/inference/inference.jl
+++ b/src/inference/inference.jl
@@ -25,3 +25,4 @@ include("particle_filter.jl")
 include("map_optimize.jl")
 include("train.jl")
 include("variational.jl")
+include("ais.jl")