Renyi divergence #769

Open
wants to merge 40 commits into
base: master
Changes from all commits
Commits
40 commits
5a5f8a9
trying ab_divergence
jb-regli Sep 16, 2017
3d65841
adding renyi as special case
jb-regli Sep 16, 2017
523210a
trying ab_divergence
jb-regli Sep 16, 2017
4682fa7
trying ab_divergence
jb-regli Sep 16, 2017
0e30c18
sign error ?
jb-regli Sep 16, 2017
cc63477
ignore data + add renyi divergence
jb-regli Sep 17, 2017
c79dbb2
cleaning
jb-regli Sep 17, 2017
397eb71
docstring
jb-regli Sep 17, 2017
005dd03
renyi divergence
jb-regli Sep 20, 2017
a38fa82
renyi examples in notebook
jb-regli Sep 20, 2017
102453c
renyi examples in notebook
jb-regli Sep 20, 2017
8fdbe52
branch
jb-regli Sep 21, 2017
f85d97d
hard reset
jb-regli Sep 21, 2017
b146080
hard reset
jb-regli Sep 21, 2017
d618579
renyi divergence
jb-regli Sep 27, 2017
9f9a889
renyi divergence improvement
jb-regli Sep 27, 2017
66b8e87
Renyi divergence improvement
jb-regli Sep 27, 2017
17bdb8b
Error
jb-regli Sep 27, 2017
377ff9c
Moved build_loss and gradient in the class
jb-regli Sep 27, 2017
4c67eed
Renyi exqmple + docstring
jb-regli Sep 27, 2017
5aa9a25
Merge branch 'master' of https://github.com/blei-lab/edward
jb-regli Sep 27, 2017
d4f98b0
Merge remote-tracking branch 'origin/master' into renyi_divergence
jb-regli Sep 27, 2017
c340e46
testing
jb-regli Sep 27, 2017
18fd32f
test
jb-regli Sep 27, 2017
8623ebc
Pep8 correction
jb-regli Sep 27, 2017
57c5ba0
remove irrelevant file from PR
jb-regli Sep 27, 2017
6fc9b8a
2-space indent
jb-regli Sep 27, 2017
0df0215
Markdown formated docstring
jb-regli Sep 27, 2017
671541b
Edited docstring
jb-regli Sep 27, 2017
9e0a3b7
Correct order of call
jb-regli Sep 27, 2017
821d102
Updated docstrings
jb-regli Sep 27, 2017
6de5523
Testing for Renyi VI
jb-regli Sep 28, 2017
9d58f4f
Add Renyi_div to shortcut
jb-regli Sep 29, 2017
bdcaa8f
Correct init
jb-regli Sep 29, 2017
dfad744
Allow renyidivergence
jb-regli Sep 29, 2017
e5c4867
Allow quick call
jb-regli Sep 29, 2017
89fe5cd
Debug shortcut
jb-regli Sep 29, 2017
da83c97
restore from edward
jb-regli Sep 29, 2017
a9139ed
Call shortcut for Renyi divergence
jb-regli Sep 29, 2017
717d236
Correct style
jb-regli Sep 29, 2017
9 changes: 8 additions & 1 deletion docs/tex/bib.bib
@@ -712,6 +712,14 @@ @article{johnson2016composing
year = {2016},
}

@inproceedings{li2016renyi,
  author = {Li, Yingzhen and Turner, Richard E},
  title = {{R{\'e}nyi Divergence Variational Inference}},
  booktitle = {Advances in Neural Information Processing Systems},
  pages = {1073--1081},
  year = {2016}
}

@article{mohamed2016learning,
author = {Mohamed, Shakir and Lakshminarayanan, Balaji},
title = {{Learning in Implicit Generative Models}},
@@ -801,4 +809,3 @@ @inproceedings{tran2017deep
booktitle = {International Conference on Learning Representations},
year = {2017}
}

3 changes: 2 additions & 1 deletion edward/__init__.py
@@ -15,7 +15,7 @@
KLpq, KLqp, ReparameterizationKLqp, ReparameterizationKLKLqp, \
ReparameterizationEntropyKLqp, ScoreKLqp, ScoreKLKLqp, ScoreEntropyKLqp, \
ScoreRBKLqp, WakeSleep, GANInference, BiGANInference, WGANInference, \
ImplicitKLqp, MAP, Laplace, complete_conditional, Gibbs
ImplicitKLqp, MAP, Laplace, complete_conditional, Gibbs, RenyiDivergence
from edward.models import RandomVariable
from edward.util import check_data, check_latent_vars, copy, dot, \
get_ancestors, get_blanket, get_children, get_control_variate_coef, \
@@ -56,6 +56,7 @@
'BiGANInference',
'WGANInference',
'ImplicitKLqp',
'RenyiDivergence',
'MAP',
'Laplace',
'complete_conditional',
2 changes: 2 additions & 0 deletions edward/inferences/__init__.py
@@ -22,6 +22,7 @@
from edward.inferences.variational_inference import *
from edward.inferences.wake_sleep import *
from edward.inferences.wgan_inference import *
from edward.inferences.renyi_divergence import *

from tensorflow.python.util.all_util import remove_undocumented

@@ -51,6 +52,7 @@
'VariationalInference',
'WakeSleep',
'WGANInference',
'RenyiDivergence',
]

remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
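
A quick way to check the wiring added above, assuming this branch of Edward is installed, is to confirm that both import paths resolve to the same class; a minimal sketch:

```python
# Minimal sketch, assuming the renyi_divergence branch of Edward is installed.
import edward as ed
from edward.inferences import RenyiDivergence

# The wildcard import added above plus the top-level export in
# edward/__init__.py make the class reachable from both namespaces.
assert ed.RenyiDivergence is RenyiDivergence
print(ed.RenyiDivergence.__name__)  # -> "RenyiDivergence"
```
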
174 changes: 174 additions & 0 deletions edward/inferences/renyi_divergence.py
@@ -0,0 +1,174 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six
import tensorflow as tf

from edward.inferences.variational_inference import VariationalInference
from edward.models import RandomVariable
from edward.util import copy

try:
  from edward.models import Normal
  from tensorflow.contrib.distributions import kl_divergence
except Exception as e:
  raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))


class RenyiDivergence(VariationalInference):
  """Variational inference with the Renyi divergence [@li2016renyi].

  It minimizes the Renyi divergence

  $\text{D}_{R}^{(\alpha)}(q(z) || p(z \mid x)) =
      \frac{1}{\alpha - 1} \log \int q(z)^{\alpha} p(z \mid x)^{1 - \alpha} dz.$

  The optimization is performed using the gradient estimator defined in
  [@li2016renyi].

  #### Notes
  + The gradient estimator used here has no analytic version.
  + The gradient estimator used here has no version for
    non-reparameterizable models.
  + backward_pass = 'max' (extreme case $\alpha \rightarrow -\infty$):
    the algorithm uses the sample with the maximum unnormalized
    importance weight. This no longer minimizes the Renyi divergence.
  + backward_pass = 'min' (extreme case $\alpha \rightarrow +\infty$):
    the algorithm uses the sample with the minimum unnormalized
    importance weight. This no longer minimizes the Renyi divergence.
    This mode is not described in the paper but is implemented in the
    authors' publicly available code for the paper's experiments.
  """

  def __init__(self, *args, **kwargs):
    super(RenyiDivergence, self).__init__(*args, **kwargs)

    is_reparameterizable = all([
        rv.reparameterization_type ==
        tf.contrib.distributions.FULLY_REPARAMETERIZED
        for rv in six.itervalues(self.latent_vars)])

    if not is_reparameterizable:
      raise NotImplementedError(
          "Variational Renyi inference only works with reparameterizable"
          " models.")

  def initialize(self,
                 n_samples=32,
                 alpha=1.0,
                 backward_pass='full',
                 *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
    and builds ops for the algorithm's computation graph.

    Args:
      n_samples: int, optional.
        Number of samples from the variational model for calculating
        stochastic gradients.
      alpha: float, optional.
        Renyi divergence coefficient, $\alpha \in \mathbb{R}$.
        When $\alpha < 0$, the objective is still well defined but
        no longer corresponds to minimizing a Renyi divergence
        (see [@li2016renyi], Section 4.2).
      backward_pass: str, optional.
        Backward pass mode to be used.
        Options: 'min', 'max', 'full'
        (see [@li2016renyi], Section 4.2).
    """
    self.n_samples = n_samples
    self.alpha = alpha
    self.backward_pass = backward_pass

    return super(RenyiDivergence, self).initialize(*args, **kwargs)

  def build_loss_and_gradients(self, var_list):
    """Build the Renyi ELBO function.

    Its automatic differentiation is a stochastic gradient of

    $\mathcal{L}_{R}^{\alpha}(q; x) =
        \frac{1}{1 - \alpha} \log \mathbb{E}_{q} \left[
        \left( \frac{p(x, z)}{q(z)} \right)^{1 - \alpha} \right].$

    It uses:

    + Monte Carlo approximation of the Renyi ELBO [@li2016renyi].
    + Reparameterization gradients [@kingma2014auto].
    + Stochastic approximation of the joint distribution [@li2016renyi].

    #### Notes

    + If the model is not reparameterizable, a `NotImplementedError` is
      raised when the inference object is constructed.
    + See Renyi Divergence Variational Inference [@li2016renyi] for
      more details.
    """
    p_log_prob = [0.0] * self.n_samples
    q_log_prob = [0.0] * self.n_samples
    base_scope = tf.get_default_graph().unique_name("inference") + '/'
    for s in range(self.n_samples):
      # Form dictionary in order to replace conditioning on prior or
      # observed variable with conditioning on a specific value.
      scope = base_scope + tf.get_default_graph().unique_name("sample")
      dict_swap = {}
      for x, qx in six.iteritems(self.data):
        if isinstance(x, RandomVariable):
          if isinstance(qx, RandomVariable):
            qx_copy = copy(qx, scope=scope)
            dict_swap[x] = qx_copy.value()
          else:
            dict_swap[x] = qx

      for z, qz in six.iteritems(self.latent_vars):
        # Copy q(z) to obtain new set of posterior samples.
        qz_copy = copy(qz, scope=scope)
        dict_swap[z] = qz_copy.value()
        q_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z]))

      for z in six.iterkeys(self.latent_vars):
        z_copy = copy(z, dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

      for x in six.iterkeys(self.data):
        if isinstance(x, RandomVariable):
          x_copy = copy(x, dict_swap, scope=scope)
          p_log_prob[s] += tf.reduce_sum(
              self.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

    # Per-sample log importance weights log [p(x, z) / q(z)].
    log_ratios = [p - q for p, q in zip(p_log_prob, q_log_prob)]

    if self.backward_pass == 'max':
      loss = tf.reduce_max(log_ratios, 0)
    elif self.backward_pass == 'min':
      loss = tf.reduce_min(log_ratios, 0)
    elif np.abs(self.alpha - 1.0) < 10e-3:
      # alpha ~= 1: fall back to the standard ELBO (KL(q || p) objective).
      loss = tf.reduce_mean(log_ratios)
    else:
      # Log-mean-exp of (1 - alpha) * log_ratios, computed stably by
      # subtracting the per-batch maximum before exponentiating.
      log_ratios = tf.stack(log_ratios)
      log_ratios = log_ratios * (1 - self.alpha)
      log_ratios_max = tf.reduce_max(log_ratios, 0)
      log_ratios = tf.log(
          tf.maximum(1e-9,
                     tf.reduce_mean(tf.exp(log_ratios - log_ratios_max), 0)))
      log_ratios = (log_ratios + log_ratios_max) / (1 - self.alpha)
      loss = tf.reduce_mean(log_ratios)
    loss = -loss

    if self.logging:
      p_log_prob = tf.reduce_mean(p_log_prob)
      q_log_prob = tf.reduce_mean(q_log_prob)
      tf.summary.scalar("loss/p_log_prob", p_log_prob,
                        collections=[self._summary_key])
      tf.summary.scalar("loss/q_log_prob", q_log_prob,
                        collections=[self._summary_key])

    grads = tf.gradients(loss, var_list)
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars
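
For reference, here is a minimal usage sketch of the class defined above on a hypothetical conjugate Normal-Normal toy model; the model, the data array `x_data`, and the variable names `mu` and `qmu` are illustrative only and not part of this PR. It exercises the constructor plus the `n_samples`, `alpha`, and `backward_pass` arguments documented in `initialize`.

```python
# Minimal sketch, assuming this branch of Edward: fit q(mu) with the
# Renyi objective on a toy Normal-Normal model.
import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Normal

ed.set_seed(42)
x_data = np.random.randn(50).astype(np.float32)  # toy observations

# Model: mu ~ Normal(0, 1), x | mu ~ Normal(mu, 1).
mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=mu, scale=1.0, sample_shape=50)

# Fully reparameterizable variational family q(mu), as required by __init__.
qmu = Normal(loc=tf.Variable(0.0),
             scale=tf.nn.softplus(tf.Variable(0.5)))

inference = ed.RenyiDivergence({mu: qmu}, data={x: x_data})
# run() forwards these keyword arguments to initialize().
inference.run(n_samples=16, alpha=0.5, backward_pass='full', n_iter=1000)

sess = ed.get_session()
print("variational posterior mean:", sess.run(qmu.mean()))
```

Because `run` forwards its keyword arguments to `initialize`, the same sketch can switch to the VR-max behaviour described in the class docstring by passing `backward_pass='max'`.
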
76 changes: 39 additions & 37 deletions examples/vae.py
@@ -1,5 +1,13 @@
#!/usr/bin/env python
"""Variational auto-encoder for MNIST data.
"""Renyi Variational auto-encoder for MNIST data.

We here use the Renyi variational objective [@li2016renyi]. This objective
allows one to vary the divergence measure used for optimization by tuning
the hyperparameter $\alpha$.

#### Notes
For $\alpha = 1$, the Renyi variational objective reduces to the classic
KL divergence objective, and the "standard" VAE is recovered.

References
----------
@@ -18,39 +26,26 @@
from edward.models import Bernoulli, Normal
from edward.util import Progbar
from keras.layers import Dense
from observations import mnist
from tensorflow.examples.tutorials.mnist import input_data
from scipy.misc import imsave


def generator(array, batch_size):
  """Generate batch with respect to array's first axis."""
  start = 0  # pointer to where we are in iteration
  while True:
    stop = start + batch_size
    diff = stop - array.shape[0]
    if diff <= 0:
      batch = array[start:stop]
      start += batch_size
    else:
      batch = np.concatenate((array[start:], array[:diff]))
      start = diff
    batch = batch.astype(np.float32) / 255.0  # normalize pixel intensities
    batch = np.random.binomial(1, batch)  # binarize images
    yield batch


ed.set_seed(42)

data_dir = "/tmp/data"
out_dir = "/tmp/out"
if not os.path.exists(out_dir):
  os.makedirs(out_dir)
DATA_DIR = "data/mnist"
IMG_DIR = "img"
if not os.path.exists(DATA_DIR):
  os.makedirs(DATA_DIR)
if not os.path.exists(IMG_DIR):
  os.makedirs(IMG_DIR)

M = 100 # batch size during training
d = 2 # latent dimension
alpha = 0.5  # alpha value for the Renyi divergence
n_samples = 5  # number of samples used to estimate the Renyi ELBO
backward_pass = 'full'  # backward pass mode ('min', 'max' or 'full')

# DATA. MNIST batches are fed at training time.
(x_train, _), (x_test, _) = mnist(data_dir)
x_train_generator = generator(x_train, M)
mnist = input_data.read_data_sets(DATA_DIR)

# MODEL
# Define a subgraph of the full model, corresponding to a minibatch of
@@ -68,32 +63,39 @@ def generator(array, batch_size):
            scale=Dense(d, activation='softplus')(hidden))

# Bind p(x, z) and q(z | x) to the same TensorFlow placeholder for x.
inference = ed.KLqp({z: qz}, data={x: x_ph})
inference = ed.RenyiDivergence({z: qz}, data={x: x_ph})
optimizer = tf.train.RMSPropOptimizer(0.01, epsilon=1.0)
inference.initialize(optimizer=optimizer)

inference.initialize(optimizer=optimizer,
                     n_samples=n_samples,
                     alpha=alpha,
                     backward_pass=backward_pass)
# inference = ed.KLqp({z: qz}, data={x: x_ph})
# optimizer = tf.train.RMSPropOptimizer(0.01, epsilon=1.0)
# inference.initialize(optimizer=optimizer)

sess = ed.get_session()
tf.global_variables_initializer().run()

n_epoch = 100
n_iter_per_epoch = x_train.shape[0] // M
for epoch in range(1, n_epoch + 1):
  print("Epoch: {0}".format(epoch))
n_iter_per_epoch = 1000
for epoch in range(n_epoch):
  avg_loss = 0.0

  pbar = Progbar(n_iter_per_epoch)
  for t in range(1, n_iter_per_epoch + 1):
    pbar.update(t)
    x_batch = next(x_train_generator)
    info_dict = inference.update(feed_dict={x_ph: x_batch})
    x_train, _ = mnist.train.next_batch(M)
    x_train = np.random.binomial(1, x_train)
    info_dict = inference.update(feed_dict={x_ph: x_train})
    avg_loss += info_dict['loss']

  # Print a lower bound to the average marginal likelihood for an
  # image.
  avg_loss = avg_loss / n_iter_per_epoch
  avg_loss = avg_loss / M
  print("-log p(x) <= {:0.3f}".format(avg_loss))
  print("log p(x) >= {:0.3f}".format(-avg_loss))

# Prior predictive check.
images = x.eval()
imgs = sess.run(x)
for m in range(M):
  imsave(os.path.join(out_dir, '%d.png') % m, images[m].reshape(28, 28))
  imsave(os.path.join(IMG_DIR, '%d.png') % m, imgs[m].reshape(28, 28))
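
The docstring added above states that the Renyi objective reduces to the classic KL-based objective at $\alpha = 1$. In the notation of `build_loss_and_gradients`, this follows from a short limit argument:

$ \lim_{\alpha \to 1} \mathcal{L}_{R}^{\alpha}(q; x)
    = \lim_{\alpha \to 1} \frac{1}{1 - \alpha} \log \mathbb{E}_{q} \left[
      \left( \frac{p(x, z)}{q(z)} \right)^{1 - \alpha} \right]
    = \mathbb{E}_{q} \left[ \log \frac{p(x, z)}{q(z)} \right],$

which is the standard ELBO. A first-order expansion of the exponent in $1 - \alpha$ (or L'Hopital's rule) gives the limit, and this is exactly the quantity averaged in the `np.abs(self.alpha - 1.0) < 10e-3` branch of the implementation.
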