Refactor TFP wrappers to avoid new metadata logic (pyro-ppl#1064)
* refactor TFP wrappers to avoid metadata logic

* make TransformedDistribution work

* raise attribute error for missing attributes
fehiepsi authored Jun 22, 2021
1 parent ca6811b commit 42763d4
Showing 2 changed files with 104 additions and 62 deletions.
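
For orientation before reading the diff: the core of the refactor is replacing the old mixin/metaclass machinery with class subscription, where indexing a class mints a new wrapper subclass. A stripped-down sketch of that pattern, with illustrative names (Wrapper, Point, inner are not from this diff), independent of TFP and NumPyro:

    import inspect

    class _WrapperMeta(type):
        def __getitem__(cls, wrapped_class):
            # build a subclass whose __init__ stores an instance of wrapped_class
            def init(self, *args, **kwargs):
                self.inner = wrapped_class(*args, **kwargs)

            # advertise the wrapped constructor's signature to introspection tools
            init.__signature__ = inspect.signature(wrapped_class.__init__)
            return type(wrapped_class.__name__, (cls,), {"__init__": init})

    class Wrapper(metaclass=_WrapperMeta):
        def __getattr__(self, name):
            # delegate attributes that normal lookup misses to the wrapped instance
            return getattr(self.inner, name)

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    p = Wrapper[Point](1, 2)
    print(p.x, p.y)  # 1 2

The commit applies this same trick so that TFPDistribution[tfd.Normal] yields a NumPyro-compatible wrapper class, as the diff below shows.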
numpyro/contrib/tfp/distributions.py (94 additions, 53 deletions)
@@ -1,6 +1,8 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

+import inspect
+
 import numpy as np

 import jax.numpy as jnp
@@ -103,42 +105,78 @@ def _transform_to_bijector_constraint(constraint):
     return BijectorTransform(constraint.bijector)


-_TFPDistributionMeta = type(tfd.Distribution)
-
-
-# XXX: we create this mixin class to avoid a metaclass conflict between TFP and NumPyro Distribution
-class _TFPMixinMeta(_TFPDistributionMeta, type(NumPyroDistribution)):
-    def __init__(cls, name, bases, dct):
-        # XXX: _TFPDistributionMeta.__init__ registers cls as a PyTree;
-        # for some reason, when the metaclass of TFPDistributionMixin is _TFPMixinMeta,
-        # TFPDistributionMixin would be registered as a PyTree twice, which is not allowed
-        # in JAX, so we skip registering TFPDistributionMixin as a PyTree.
-        if name == "TFPDistributionMixin":
-            super(_TFPDistributionMeta, cls).__init__(name, bases, dct)
-        else:
-            super(_TFPMixinMeta, cls).__init__(name, bases, dct)
+class _TFPDistributionMeta(type(NumPyroDistribution)):
+    def __getitem__(cls, tfd_class):
+        assert issubclass(tfd_class, tfd.Distribution)
+
+        def init(self, *args, **kwargs):
+            self.tfp_dist = tfd_class(*args, **kwargs)
+
+        init.__signature__ = inspect.signature(tfd_class.__init__)
+
+        _PyroDist = type(tfd_class.__name__, (TFPDistribution,), {})
+        _PyroDist.tfd_class = tfd_class
+        _PyroDist.__init__ = init
+        return _PyroDist


-class TFPDistributionMixin(NumPyroDistribution, metaclass=_TFPMixinMeta):
+class TFPDistribution(NumPyroDistribution, metaclass=_TFPDistributionMeta):
     """
-    A mixin layer to make TensorFlow Probability (TFP) distributions compatible
-    with NumPyro internals.
+    A thin wrapper for TensorFlow Probability (TFP) distributions. The constructor
+    has the same signature as the corresponding TFP distribution.
+
+    This class can be used to convert a TFP distribution to a NumPyro-compatible one
+    as follows::
+
+        d = TFPDistribution[tfd.Normal](0, 1)
     """

-    def __init_subclass__(cls, **kwargs):
-        # skip registering as a PyTree because TFP distributions are already PyTrees
-        super(object, cls).__init_subclass__(**kwargs)
+    tfd_class = None

-    def __call__(self, *args, **kwargs):
-        key = kwargs.pop("rng_key")
-        sample_intermediates = kwargs.pop("sample_intermediates", False)
-        if sample_intermediates:
-            return self.sample(*args, seed=key, **kwargs), []
-        return self.sample(*args, seed=key, **kwargs)
+    def __getattr__(self, name):
+        # return parameters from the constructor
+        if name in self.tfp_dist.parameters:
+            return self.tfp_dist.parameters[name]
+        elif name in ["dtype", "reparameterization_type"]:
+            return getattr(self.tfp_dist, name)
+        raise AttributeError(name)
+
+    @property
+    def batch_shape(self):
+        return self.tfp_dist.batch_shape
+
+    @property
+    def event_shape(self):
+        return self.tfp_dist.event_shape
+
+    @property
+    def has_rsample(self):
+        return self.tfp_dist.reparameterization_type is tfd.FULLY_REPARAMETERIZED
+
+    def sample(self, key, sample_shape=()):
+        return self.tfp_dist.sample(sample_shape=sample_shape, seed=key)
+
+    def log_prob(self, value):
+        return self.tfp_dist.log_prob(value)
+
+    @property
+    def mean(self):
+        return self.tfp_dist.mean()
+
+    @property
+    def variance(self):
+        return self.tfp_dist.variance()
+
+    def cdf(self, value):
+        return self.tfp_dist.cdf(value)
+
+    def icdf(self, q):
+        return self.tfp_dist.quantile(q)

     @property
     def support(self):
-        bijector = self._default_event_space_bijector()
+        bijector = self.tfp_dist._default_event_space_bijector()
         if bijector is not None:
             return BijectorConstraint(bijector)
         else:
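
To see the wrapper above in action, a sketch assuming the TFP JAX substrate and numpyro are installed; note that `loc` is a constructor parameter of tfd.Normal, so `__getattr__` resolves it from `tfp_dist.parameters`, while unknown names now raise AttributeError per the third commit bullet:

    from jax import random
    from tensorflow_probability.substrates.jax import distributions as tfd
    from numpyro.contrib.tfp.distributions import TFPDistribution

    d = TFPDistribution[tfd.Normal](0.0, 1.0)
    x = d.sample(random.PRNGKey(0), sample_shape=(5,))
    print(d.log_prob(x).shape)  # (5,)
    print(d.loc)                # 0.0, pulled from the TFP constructor parameters
    # d.nonexistent             # would raise AttributeError under this commit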
@@ -150,40 +188,43 @@ def is_discrete(self):
         return self.support is None


-class InverseGamma(tfd.InverseGamma, TFPDistributionMixin):
-    arg_constraints = {
-        "concentration": constraints.positive,
-        "scale": constraints.positive,
-    }
+InverseGamma = TFPDistribution[tfd.InverseGamma]
+InverseGamma.arg_constraints = {
+    "concentration": constraints.positive,
+    "scale": constraints.positive,
+}


-class OneHotCategorical(tfd.OneHotCategorical, TFPDistributionMixin):
-    arg_constraints = {"logits": constraints.real_vector}
-    has_enumerate_support = True
-    support = constraints.simplex
-    is_discrete = True
-
-    def enumerate_support(self, expand=True):
-        n = self.event_shape[-1]
-        values = jnp.identity(n, dtype=jnp.result_type(self.dtype))
-        values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,))
-        if expand:
-            values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,))
-        return values
+def _onehot_enumerate_support(self, expand=True):
+    n = self.event_shape[-1]
+    values = jnp.identity(n, dtype=jnp.result_type(self.dtype))
+    values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,))
+    if expand:
+        values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,))
+    return values


-class OrderedLogistic(tfd.OrderedLogistic, TFPDistributionMixin):
-    arg_constraints = {"cutpoints": constraints.ordered_vector, "loc": constraints.real}
+OneHotCategorical = TFPDistribution[tfd.OneHotCategorical]
+OneHotCategorical.arg_constraints = {"logits": constraints.real_vector}
+OneHotCategorical.has_enumerate_support = True
+OneHotCategorical.support = constraints.simplex
+OneHotCategorical.is_discrete = True
+OneHotCategorical.enumerate_support = _onehot_enumerate_support

+OrderedLogistic = TFPDistribution[tfd.OrderedLogistic]
+OrderedLogistic.arg_constraints = {
+    "cutpoints": constraints.ordered_vector,
+    "loc": constraints.real,
+}

-class Pareto(tfd.Pareto, TFPDistributionMixin):
-    arg_constraints = {
-        "concentration": constraints.positive,
-        "scale": constraints.positive,
-    }
+
+Pareto = TFPDistribution[tfd.Pareto]
+Pareto.arg_constraints = {
+    "concentration": constraints.positive,
+    "scale": constraints.positive,
+}


-__all__ = ["BijectorConstraint", "BijectorTransform", "TFPDistributionMixin"]
+__all__ = ["BijectorConstraint", "BijectorTransform", "TFPDistribution"]
 _len_all = len(__all__)
 for _name, _Dist in tfd.__dict__.items():
     if not isinstance(_Dist, type):
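
A note on the hunk above: because the wrapper classes are generated dynamically, extra methods such as `enumerate_support` are attached by plain attribute assignment. A function assigned onto a class becomes an ordinary method on instances via Python's descriptor protocol. A minimal illustration with hypothetical names (Generated, _greet are not from this diff):

    class Generated:
        pass

    def _greet(self, name):
        return f"hello, {name}"

    Generated.greet = _greet           # function -> method on instances
    print(Generated().greet("world"))  # hello, world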
@@ -196,7 +237,7 @@ class Pareto(tfd.Pareto, TFPDistributionMixin):
     try:
         _PyroDist = locals()[_name]
     except KeyError:
-        _PyroDist = type(_name, (_Dist, TFPDistributionMixin), {})
+        _PyroDist = TFPDistribution[_Dist]
     _PyroDist.__module__ = __name__
     if hasattr(numpyro_dist, _name):
         numpyro_dist_class = getattr(numpyro_dist, _name)
@@ -212,7 +253,7 @@ class Pareto(tfd.Pareto, TFPDistributionMixin):

     _PyroDist.__doc__ = """
     Wraps `{}.{} <https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/distributions/{}>`_
-    with :class:`~numpyro.contrib.tfp.distributions.TFPDistributionMixin`.
+    with :class:`~numpyro.contrib.tfp.distributions.TFPDistribution`.
     """.format(
         _Dist.__module__, _Dist.__name__, _Dist.__name__
     )
test/contrib/test_tfp.py (10 additions, 9 deletions)
@@ -55,11 +55,12 @@ def test_independent():
 @pytest.mark.filterwarnings("ignore:can't resolve package")
 def test_transformed_distributions():
     from tensorflow_probability.substrates.jax import bijectors as tfb
+    from tensorflow_probability.substrates.jax.distributions import Normal as TFPNormal

     from numpyro.contrib.tfp import distributions as tfd

     d = dist.TransformedDistribution(dist.Normal(0, 1), dist.transforms.ExpTransform())
-    d1 = tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())
+    d1 = tfd.TransformedDistribution(TFPNormal(0, 1), tfb.Exp())
     d2 = dist.TransformedDistribution(
         dist.Normal(0, 1), tfd.BijectorTransform(tfb.Exp())
     )
@@ -73,19 +74,19 @@ def test_transformed_distributions():

 @pytest.mark.filterwarnings("ignore:can't resolve package")
 def test_logistic_regression():
-    from numpyro.contrib.tfp import distributions as dist
+    from numpyro.contrib.tfp import distributions as tfd

     N, dim = 3000, 3
     num_warmup, num_samples = (1000, 1000)
     data = random.normal(random.PRNGKey(0), (N, dim))
     true_coefs = jnp.arange(1.0, dim + 1.0)
     logits = jnp.sum(true_coefs * data, axis=-1)
-    labels = dist.Bernoulli(logits=logits)(rng_key=random.PRNGKey(1))
+    labels = tfd.Bernoulli(logits=logits)(rng_key=random.PRNGKey(1))

     def model(labels):
-        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
+        coefs = numpyro.sample("coefs", tfd.Normal(jnp.zeros(dim), jnp.ones(dim)))
         logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
-        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
+        return numpyro.sample("obs", tfd.Bernoulli(logits=logits), obs=labels)

     kernel = NUTS(model)
     mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
@@ -101,19 +102,19 @@ def model(labels):

 # TODO: remove after https://github.com/tensorflow/probability/issues/1072 is resolved
 @pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
 def test_beta_bernoulli():
-    from numpyro.contrib.tfp import distributions as dist
+    from numpyro.contrib.tfp import distributions as tfd

     num_warmup, num_samples = (500, 2000)

     def model(data):
         alpha = jnp.array([1.1, 1.1])
         beta = jnp.array([1.1, 1.1])
-        p_latent = numpyro.sample("p_latent", dist.Beta(alpha, beta))
-        numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
+        p_latent = numpyro.sample("p_latent", tfd.Beta(alpha, beta))
+        numpyro.sample("obs", tfd.Bernoulli(p_latent), obs=data)
         return p_latent

     true_probs = jnp.array([0.9, 0.1])
-    data = dist.Bernoulli(true_probs)(rng_key=random.PRNGKey(1), sample_shape=(1000, 2))
+    data = tfd.Bernoulli(true_probs)(rng_key=random.PRNGKey(1), sample_shape=(1000, 2))
     kernel = NUTS(model=model, trajectory_length=0.1)
     mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
     mcmc.run(random.PRNGKey(2), data)
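
The tests draw samples by calling a distribution directly, e.g. tfd.Bernoulli(true_probs)(rng_key=..., sample_shape=...). Since this diff deletes the wrapper's own __call__, that idiom presumably falls through to NumPyro's base Distribution.__call__, which pops rng_key and dispatches to sample. A minimal sketch of the assumed equivalence, with both packages installed:

    from jax import random
    from numpyro.contrib.tfp import distributions as tfd

    d = tfd.Bernoulli(probs=0.5)
    key = random.PRNGKey(0)
    # the two calls should agree under the base-class __call__ convention
    draw1 = d(rng_key=key, sample_shape=(3,))
    draw2 = d.sample(key, sample_shape=(3,))
    assert (draw1 == draw2).all()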