Skip to content

Add support for user-supplied RNG state in all interfaces #520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: modular-rng
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/Gen.jl
Original file line number Diff line number Diff line change
@@ -2,6 +2,8 @@

module Gen

using Random: AbstractRNG, default_rng

"""
load_generated_functions(__module__=Main)

15 changes: 9 additions & 6 deletions src/dynamic/dynamic.jl
Original file line number Diff line number Diff line change
@@ -45,12 +45,15 @@ end

accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad

mutable struct GFUntracedState
mutable struct GFUntracedState{R<:AbstractRNG}
params::Dict{Symbol,Any}
rng::R
end

function (gen_fn::DynamicDSLFunction)(args...)
state = GFUntracedState(gen_fn.params)
(gen_fn::DynamicDSLFunction)(args...) = gen_fn(default_rng(), args...)

function (gen_fn::DynamicDSLFunction)(rng::AbstractRNG, args...)
state = GFUntracedState(gen_fn.params, rng)
gen_fn.julia_function(state, args...)
end

@@ -82,13 +85,13 @@ end

# Defaults for untraced execution
@inline traceat(state::GFUntracedState, gen_fn::GenerativeFunction, args, key) =
gen_fn(args...)
gen_fn(state.rng, args...)

@inline traceat(state::GFUntracedState, dist::Distribution, args, key) =
random(dist, args...)
random(state.rng, dist, args...)

@inline splice(state::GFUntracedState, gen_fn::DynamicDSLFunction, args::Tuple) =
gen_fn(args...)
gen_fn(state.rng, args...)

########################
# trainable parameters #
20 changes: 12 additions & 8 deletions src/dynamic/generate.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
mutable struct GFGenerateState
mutable struct GFGenerateState{R<:AbstractRNG}
trace::DynamicDSLTrace
constraints::ChoiceMap
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::R
end

function GFGenerateState(gen_fn, args, constraints, params)
function GFGenerateState(gen_fn, args, constraints, params, rng::AbstractRNG)
trace = DynamicDSLTrace(gen_fn, args)
GFGenerateState(trace, constraints, 0., AddressVisitor(), params)
GFGenerateState(trace, constraints, 0., AddressVisitor(), params, rng)
end

function traceat(state::GFGenerateState, dist::Distribution{T},
@@ -26,7 +27,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T},
if constrained
retval = get_value(state.constraints, key)
else
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
end

# compute logpdf
@@ -55,7 +56,7 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U},
constraints = get_submap(state.constraints, key)

# get subtrace
(subtrace, weight) = generate(gen_fn, args, constraints)
(subtrace, weight) = generate(state.rng, gen_fn, args, constraints)

# add to the trace
add_call!(state.trace, key, subtrace)
@@ -78,9 +79,12 @@ function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction,
retval
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 59, the recursive call to generate needs to pass state.rng to the callee function.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function generate(gen_fn::DynamicDSLFunction, args::Tuple,
constraints::ChoiceMap)
state = GFGenerateState(gen_fn, args, constraints, gen_fn.params)
generate(gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap) =
generate(default_rng(), gen_fn, args, constraints)

function generate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple,
constraints::ChoiceMap)
state = GFGenerateState(gen_fn, args, constraints, gen_fn.params, rng)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
(state.trace, state.weight)
15 changes: 8 additions & 7 deletions src/dynamic/propose.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
mutable struct GFProposeState
mutable struct GFProposeState{R<:AbstractRNG}
choices::DynamicChoiceMap
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::R
end

function GFProposeState(params::Dict{Symbol,Any})
GFProposeState(choicemap(), 0., AddressVisitor(), params)
function GFProposeState(params::Dict{Symbol,Any}, rng::AbstractRNG)
GFProposeState(choicemap(), 0., AddressVisitor(), params, rng)
end

function traceat(state::GFProposeState, dist::Distribution{T},
@@ -17,7 +18,7 @@ function traceat(state::GFProposeState, dist::Distribution{T},
visit!(state.visitor, key)

# sample return value
retval = random(dist, args...)
retval = random(state.rng, dist, args...)

# update assignment
set_value!(state.choices, key, retval)
@@ -36,7 +37,7 @@ function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U},
visit!(state.visitor, key)

# get subtrace
(submap, weight, retval) = propose(gen_fn, args)
(submap, weight, retval) = propose(state.rng, gen_fn, args)

# update assignment
set_submap!(state.choices, key, submap)
@@ -55,8 +56,8 @@ function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple)
retval
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 40, state.rng needs to be passed to the recursive call to propose.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function propose(gen_fn::DynamicDSLFunction, args::Tuple)
state = GFProposeState(gen_fn.params)
function propose(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple)
state = GFProposeState(gen_fn.params, rng)
retval = exec(gen_fn, state, args)
(state.choices, state.weight, retval)
end
21 changes: 11 additions & 10 deletions src/dynamic/regenerate.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
mutable struct GFRegenerateState
mutable struct GFRegenerateState{R<:AbstractRNG}
prev_trace::DynamicDSLTrace
trace::DynamicDSLTrace
selection::Selection
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::R
end

function GFRegenerateState(gen_fn, args, prev_trace,
selection, params)
selection, params, rng::AbstractRNG)
visitor = AddressVisitor()
GFRegenerateState(prev_trace, DynamicDSLTrace(gen_fn, args), selection,
0., visitor, params)
0., visitor, params, rng)
end

function traceat(state::GFRegenerateState, dist::Distribution{T},
@@ -35,11 +36,11 @@ function traceat(state::GFRegenerateState, dist::Distribution{T},

# get return value
if has_previous && in_selection
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
elseif has_previous
retval = prev_retval
else
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
end

# compute logpdf
@@ -75,9 +76,9 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U},
prev_subtrace = prev_call.subtrace
get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key)
(subtrace, weight, _) = regenerate(
prev_subtrace, args, map((_) -> UnknownChange(), args), subselection)
state.rng, prev_subtrace, args, map((_) -> UnknownChange(), args), subselection)
else
(subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap())
(subtrace, weight) = generate(state.rng, gen_fn, args, EmptyChoiceMap())
end

# update weight
@@ -130,10 +131,10 @@ function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord},
noise
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On lines 78 and 81, state.rng needs to be passed to the calls to regenerate and generate respectively.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function regenerate(trace::DynamicDSLTrace, args::Tuple, argdiffs::Tuple,
selection::Selection)
function regenerate(rng::AbstractRNG, trace::DynamicDSLTrace, args::Tuple,
argdiffs::Tuple, selection::Selection)
gen_fn = trace.gen_fn
state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params)
state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params, rng)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
visited = state.visitor.visited
15 changes: 8 additions & 7 deletions src/dynamic/simulate.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
mutable struct GFSimulateState
mutable struct GFSimulateState{R<:AbstractRNG}
trace::DynamicDSLTrace
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::R
end

function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params)
function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params, rng::AbstractRNG)
trace = DynamicDSLTrace(gen_fn, args)
GFSimulateState(trace, AddressVisitor(), params)
GFSimulateState(trace, AddressVisitor(), params, rng)
end

function traceat(state::GFSimulateState, dist::Distribution{T},
@@ -16,7 +17,7 @@ function traceat(state::GFSimulateState, dist::Distribution{T},
# check that key was not already visited, and mark it as visited
visit!(state.visitor, key)

retval = random(dist, args...)
retval = random(state.rng, dist, args...)

# compute logpdf
score = logpdf(dist, retval, args...)
@@ -36,7 +37,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U},
visit!(state.visitor, key)

# get subtrace
subtrace = simulate(gen_fn, args)
subtrace = simulate(state.rng, gen_fn, args)

# add to the trace
add_call!(state.trace, key, subtrace)
@@ -56,8 +57,8 @@ function splice(state::GFSimulateState, gen_fn::DynamicDSLFunction,
retval
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 40, state.rng needs to be passed to the call to simulate.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function simulate(gen_fn::DynamicDSLFunction, args::Tuple)
state = GFSimulateState(gen_fn, args, gen_fn.params)
function simulate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple)
state = GFSimulateState(gen_fn, args, gen_fn.params, rng)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
state.trace
17 changes: 9 additions & 8 deletions src/dynamic/update.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
mutable struct GFUpdateState
mutable struct GFUpdateState{R<:AbstractRNG}
prev_trace::DynamicDSLTrace
trace::DynamicDSLTrace
constraints::Any
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
discard::DynamicChoiceMap
rng::R
end

function GFUpdateState(gen_fn, args, prev_trace, constraints, params)
function GFUpdateState(gen_fn, args, prev_trace, constraints, params, rng::AbstractRNG)
visitor = AddressVisitor()
discard = choicemap()
trace = DynamicDSLTrace(gen_fn, args)
GFUpdateState(prev_trace, trace, constraints,
0., visitor, params, discard)
0., visitor, params, discard, rng)
end

function traceat(state::GFUpdateState, dist::Distribution{T},
@@ -48,7 +49,7 @@ function traceat(state::GFUpdateState, dist::Distribution{T},
elseif has_previous
retval = prev_retval
else
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
end

# compute logpdf
@@ -87,10 +88,10 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U},
prev_call = get_call(state.prev_trace, key)
prev_subtrace = prev_call.subtrace
get_gen_fn(prev_subtrace) == gen_fn || gen_fn_changed_error(key)
(subtrace, weight, _, discard) = update(prev_subtrace,
(subtrace, weight, _, discard) = update(state.rng, prev_subtrace,
args, map((_) -> UnknownChange(), args), constraints)
else
(subtrace, weight) = generate(gen_fn, args, constraints)
(subtrace, weight) = generate(state.rng, gen_fn, args, constraints)
end

# update the weight
@@ -184,10 +185,10 @@ function add_unvisited_to_discard!(discard::DynamicChoiceMap,
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On lines 91 and 94, state.rng needs to be passed to the recursive calls to update and generate respectively.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function update(trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple,
function update(rng::AbstractRNG, trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple,
constraints::ChoiceMap)
gen_fn = trace.gen_fn
state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params)
state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params, rng)
retval = exec(gen_fn, state, arg_values)
set_retval!(state.trace, retval)
visited = get_visited(state.visitor)
Loading