Skip to content

Commit

Permalink
Support kl divergence for TFP distributions (pyro-ppl#1270)
Browse files Browse the repository at this point in the history
* Support kl divergence for TFP distributions

* make lint pass
  • Loading branch information
fehiepsi authored Jan 2, 2022
1 parent fd99ec5 commit 28fca48
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
12 changes: 11 additions & 1 deletion numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
from functools import lru_cache
import warnings

from multipledispatch import dispatch
import numpy as np

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates.jax import bijectors as tfb, distributions as tfd

import numpyro.distributions as numpyro_dist
from numpyro.distributions import Distribution as NumPyroDistribution, constraints
from numpyro.distributions import (
Distribution as NumPyroDistribution,
constraints,
kl_divergence,
)
from numpyro.distributions.transforms import Transform, biject_to
from numpyro.util import find_stack_level, not_jax_tracer

Expand Down Expand Up @@ -270,6 +275,11 @@ def tree_unflatten(cls, aux_data, params):
return TFPDistribution[fn.__class__](**fn.parameters)


@dispatch(TFPDistribution, TFPDistribution)
def kl_divergence(p, q): # noqa: F811
return tfd.kl_divergence(p.tfp_dist, q.tfp_dist)


__all__ = ["BijectorConstraint", "BijectorTransform", "TFPDistribution"]
_len_all = len(__all__)
for _name, _Dist in tfd.__dict__.items():
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ max-line-length = 120
exclude = docs/src, build, dist, .ipynb_checkpoints
ignore = W503,E203
per-file-ignores =
numpyro/contrib/tfp/distributions.py:F811
numpyro/distributions/kl.py:F811

[isort]
Expand Down
21 changes: 21 additions & 0 deletions test/contrib/test_tfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os

import numpy as np
from numpy.testing import assert_allclose
import pytest

Expand Down Expand Up @@ -300,3 +301,23 @@ def model(y):
samples = mcmc.get_samples()

assert_allclose(jnp.mean(samples["p"]), 4 / 7, atol=0.05)


@pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str)
@pytest.mark.filterwarnings("ignore:Importing distributions from numpyro.contrib")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
def test_kl_normal_normal(shape):
from tensorflow_probability.substrates.jax import distributions as tfd

from numpyro.contrib.tfp.distributions import TFPDistribution

p = TFPDistribution[tfd.Normal](
np.random.normal(size=shape), np.exp(np.random.normal(size=shape))
)
q = TFPDistribution[tfd.Normal](
np.random.normal(size=shape), np.exp(np.random.normal(size=shape))
)
actual = dist.kl_divergence(p, q)
x = p.sample(random.PRNGKey(0), (10000,)).copy()
expected = jnp.mean((p.log_prob(x) - q.log_prob(x)), 0)
assert_allclose(actual, expected, rtol=0.05)

0 comments on commit 28fca48

Please sign in to comment.