Rejection sampling variational inference #819
base: master
Changes from 26 commits
@@ -123,7 +123,6 @@ def run(self, variables=None, use_coordinator=True, *args, **kwargs):
         Passed into `initialize`.
     """
     self.initialize(*args, **kwargs)
-
     if variables is None:
       init = tf.global_variables_initializer()
     else:
@@ -144,6 +143,7 @@ def run(self, variables=None, use_coordinator=True, *args, **kwargs):

     for _ in range(self.n_iter):
       info_dict = self.update()
+      print(info_dict)
       self.print_progress(info_dict)

     self.finalize()

Review comment on the added `print(info_dict)` line: rm?
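For orientation, the hunks above touch Edward's `Inference.run` driver loop. The following tiny stand-in (a hypothetical class, not Edward's actual implementation) mirrors the initialize / update / print_progress / finalize cycle visible in the context lines, including the `print(info_dict)` line this PR adds:

```python
class ToyInference(object):
  """Minimal stand-in mirroring the run() control flow shown in the hunks above."""

  def __init__(self, n_iter=3):
    self.n_iter = n_iter

  def initialize(self):
    self.step = 0

  def update(self):
    self.step += 1
    return {'t': self.step, 'loss': 1.0 / self.step}

  def print_progress(self, info_dict):
    print('iter {t}: loss {loss:.3f}'.format(**info_dict))

  def finalize(self):
    print('done')

  def run(self):
    self.initialize()
    for _ in range(self.n_iter):
      info_dict = self.update()
      print(info_dict)  # the raw-dict print added by this PR; the reviewer asks whether to drop it
      self.print_progress(info_dict)
    self.finalize()


ToyInference().run()
```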
@@ -32,7 +32,7 @@ class KLpq(VariationalInference):

   with respect to $\\theta$.

-  In conditional inference, we infer $z` in $p(z, \\beta
+  In conditional inference, we infer $z$ in $p(z, \\beta
   \mid x)$ while fixing inference over $\\beta$ using another
   distribution $q(\\beta)$. During gradient calculation, instead
   of using the model's density

Review comment on this docstring fix: This is unrelated to this PR. Can you make a new PR to fix this?
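For readers of the docstring fragment above: on the standard reading of `KLpq` (an assumption here, since the hunk only shows the docstring), the objective differentiated with respect to $\theta$ is the inclusive KL divergence,

```latex
\mathrm{KL}(p \,\|\, q) = \mathbb{E}_{p(z \mid x)}\!\left[\log p(z \mid x) - \log q(z; \theta)\right],
```

minimized over $\theta$, with the auxiliary distribution $q(\beta)$ held fixed during the gradient computation, as the surrounding text describes.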
@@ -0,0 +1,15 @@
"""
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from edward.optimizers.sgd import *

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
    'KucukelbirOptimizer',
]

remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
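Since the new `edward.optimizers` package only re-exports `KucukelbirOptimizer`, here is a minimal usage sketch. Everything in it is hypothetical: the hyperparameter values are made up, and it assumes `s_n` is a single vector variable with one accumulator slot per scalar parameter, which is one reading of the indexing in `apply_gradients` below. The list returned by `apply_gradients` is run directly as the train op.

```python
import tensorflow as tf
from edward.optimizers import KucukelbirOptimizer

# Toy problem: two scalar weights pushed toward fixed targets.
w0 = tf.Variable(5.0)
w1 = tf.Variable(-3.0)
loss = (w0 - 2.0) ** 2 + (w1 - 1.0) ** 2

params = [w0, w1]
s_n = tf.Variable(tf.zeros(len(params)))  # one accumulator slot per scalar parameter (assumption)
n = tf.Variable(0.0)                      # step counter

opt = KucukelbirOptimizer(t=0.1, delta=1e-3, eta=0.1, s_n=s_n, n=n)  # hypothetical values
grads_and_vars = list(zip(tf.gradients(loss, params), params))
train_ops = opt.apply_gradients(grads_and_vars)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(200):
    sess.run(train_ops)
  print(sess.run([w0, w1]))  # both weights should have moved toward 2.0 and 1.0
```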
@@ -0,0 +1,35 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class KucukelbirOptimizer:

  """
  Used for RSVI (Rejection-Sampling Variational Inference).

  # TODO: add me
  """

  def __init__(self, t, delta, eta, s_n, n):
    self.t = t
    self.delta = delta
    self.eta = eta
    self.s_n = s_n
    self.n = n

  def apply_gradients(self, grads_and_vars, global_step=None):
    self.n = tf.assign_add(self.n, 1.)
    ops = []
    for i, (grad, var) in enumerate(grads_and_vars):
      updated_s_n = self.s_n[i].assign((self.t * grad**2) + (1 - self.t) * self.s_n[i])

      p_n_first = self.eta * self.n**(-.5 + self.delta)
      p_n_second = (1 + tf.sqrt(updated_s_n[i]))**(-1)
      p_n = p_n_first * p_n_second

      updated_var = var.assign_add(-p_n * grad)
      ops.append(updated_var)
    return ops

Author comment on `apply_gradients`: I'd quite appreciate if you could glance at this method as well, as my integration test passes on some days and fails on others, with 0 changes to my code. Promise 🤞.
Review comment: add back newline? unrelated to PR
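On the author's question about the intermittently failing integration test: one candidate source of nondeterminism in `apply_gradients` is that the counter increment, the accumulator update, and the variable update are separate stateful ops whose relative order inside a single `Session.run` is constrained only by whatever data edges happen to exist. Below is a sketch of the same step-size rule with the ordering pinned down via `tf.control_dependencies`. It is an illustration, not a vetted replacement, and it assumes the same reading as above: `s_n` is a vector variable with one slot per scalar parameter.

```python
import tensorflow as tf


class KucukelbirOptimizerSketch(object):
  """Same step-size rule as the PR's optimizer, with explicit op ordering.

  p_n = eta * n**(-0.5 + delta) / (1 + sqrt(s_n[i])), where s_n[i] is an
  exponential moving average of the i-th parameter's squared gradient.
  """

  def __init__(self, t, delta, eta, s_n, n):
    self.t = t          # EMA weight for the squared-gradient accumulator
    self.delta = delta  # makes the n**(-0.5 + delta) decay slightly slower than 1/sqrt(n)
    self.eta = eta      # base step size
    self.s_n = s_n      # vector tf.Variable, one accumulator slot per scalar parameter (assumption)
    self.n = n          # scalar tf.Variable step counter

  def apply_gradients(self, grads_and_vars, global_step=None):
    # Increment the step counter once, and make every op built below run after it.
    n_new = tf.assign_add(self.n, 1.)
    updates = []
    with tf.control_dependencies([n_new]):
      for i, (grad, var) in enumerate(grads_and_vars):
        # Update the accumulator for slot i, then force the read of s_n[i] to follow it.
        s_update = self.s_n[i].assign(self.t * grad ** 2 + (1. - self.t) * self.s_n[i])
        with tf.control_dependencies([s_update]):
          p_n = self.eta * n_new ** (-0.5 + self.delta) / (1. + tf.sqrt(self.s_n[i]))
          updates.append(var.assign_add(-p_n * grad))
    # A single grouped op; the PR's version returns the list of update tensors instead.
    return tf.group(*updates)
```

Whether this removes the flakiness depends on what the integration test asserts; the sketch only makes explicit the read-after-write ordering that the original relies on implicitly.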