From 42763d4bbb2bc508e66f244d995c51a0e370304d Mon Sep 17 00:00:00 2001 From: Du Phan Date: Tue, 22 Jun 2021 10:19:15 -0500 Subject: [PATCH] Refactor TFP wrappers to avoid new metadata logic (#1064) * refactor TFP wrappers to avoid metadata logic * make TransformedDistribution work * raise attribute error for missing attributes --- numpyro/contrib/tfp/distributions.py | 147 +++++++++++++++++---------- test/contrib/test_tfp.py | 19 ++-- 2 files changed, 104 insertions(+), 62 deletions(-) diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index 30b24f532..b4ff3eb08 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -1,6 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import inspect + import numpy as np import jax.numpy as jnp @@ -103,42 +105,78 @@ def _transform_to_bijector_constraint(constraint): return BijectorTransform(constraint.bijector) -_TFPDistributionMeta = type(tfd.Distribution) +class _TFPDistributionMeta(type(NumPyroDistribution)): + def __getitem__(cls, tfd_class): + assert issubclass(tfd_class, tfd.Distribution) + def init(self, *args, **kwargs): + self.tfp_dist = tfd_class(*args, **kwargs) -# XXX: we create this mixin class to avoid metaclass conflict between TFP and NumPyro Ditribution -class _TFPMixinMeta(_TFPDistributionMeta, type(NumPyroDistribution)): - def __init__(cls, name, bases, dct): - # XXX: _TFPDistributionMeta.__init__ registers cls as a PyTree - # for some reasons, when defining metaclass of TFPDistributionMixin to be _TFPMixinMeta, - # TFPDistributionMixin will be registered as a PyTree 2 times, which is not allowed - # in JAX, so we skip registering TFPDistributionMixin as a PyTree. - if name == "TFPDistributionMixin": - super(_TFPDistributionMeta, cls).__init__(name, bases, dct) - else: - super(_TFPMixinMeta, cls).__init__(name, bases, dct) + init.__signature__ = inspect.signature(tfd_class.__init__) + + _PyroDist = type(tfd_class.__name__, (TFPDistribution,), {}) + _PyroDist.tfd_class = tfd_class + _PyroDist.__init__ = init + return _PyroDist -class TFPDistributionMixin(NumPyroDistribution, metaclass=_TFPMixinMeta): +class TFPDistribution(NumPyroDistribution, metaclass=_TFPDistributionMeta): """ - A mixin layer to make TensorFlow Probability (TFP) distribution compatible - with NumPyro internal. + A thin wrapper for TensorFlow Probability (TFP) distributions. The constructor + has the same signature as the corresponding TFP distribution. + + This class can be used to convert a TFP distribution to a NumPyro-compatible one + as follows:: + + d = TFPDistribution[tfd.Normal](0, 1) + """ - def __init_subclass__(cls, **kwargs): - # skip register pytree because TFP distributions are already pytrees - super(object, cls).__init_subclass__(**kwargs) + tfd_class = None - def __call__(self, *args, **kwargs): - key = kwargs.pop("rng_key") - sample_intermediates = kwargs.pop("sample_intermediates", False) - if sample_intermediates: - return self.sample(*args, seed=key, **kwargs), [] - return self.sample(*args, seed=key, **kwargs) + def __getattr__(self, name): + # return parameters from the constructor + if name in self.tfp_dist.parameters: + return self.tfp_dist.parameters[name] + elif name in ["dtype", "reparameterization_type"]: + return getattr(self.tfp_dist, name) + raise AttributeError(name) + + @property + def batch_shape(self): + return self.tfp_dist.batch_shape + + @property + def event_shape(self): + return self.tfp_dist.event_shape + + @property + def has_rsample(self): + return self.tfp_dist.reparameterization_type is tfd.FULLY_REPARAMETERIZED + + def sample(self, key, sample_shape=()): + return self.tfp_dist.sample(sample_shape=sample_shape, seed=key) + + def log_prob(self, value): + return self.tfp_dist.log_prob(value) + + @property + def mean(self): + return self.tfp_dist.mean() + + @property + def variance(self): + return self.tfp_dist.variance() + + def cdf(self, value): + return self.tfp_dist.cdf(value) + + def icdf(self, q): + return self.tfp_dist.quantile(q) @property def support(self): - bijector = self._default_event_space_bijector() + bijector = self.tfp_dist._default_event_space_bijector() if bijector is not None: return BijectorConstraint(bijector) else: @@ -150,40 +188,43 @@ def is_discrete(self): return self.support is None -class InverseGamma(tfd.InverseGamma, TFPDistributionMixin): - arg_constraints = { - "concentration": constraints.positive, - "scale": constraints.positive, - } - +InverseGamma = TFPDistribution[tfd.InverseGamma] +InverseGamma.arg_constraints = { + "concentration": constraints.positive, + "scale": constraints.positive, +} -class OneHotCategorical(tfd.OneHotCategorical, TFPDistributionMixin): - arg_constraints = {"logits": constraints.real_vector} - has_enumerate_support = True - support = constraints.simplex - is_discrete = True - def enumerate_support(self, expand=True): - n = self.event_shape[-1] - values = jnp.identity(n, dtype=jnp.result_type(self.dtype)) - values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,)) - if expand: - values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,)) - return values +def _onehot_enumerate_support(self, expand=True): + n = self.event_shape[-1] + values = jnp.identity(n, dtype=jnp.result_type(self.dtype)) + values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,)) + if expand: + values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,)) + return values -class OrderedLogistic(tfd.OrderedLogistic, TFPDistributionMixin): - arg_constraints = {"cutpoints": constraints.ordered_vector, "loc": constraints.real} +OneHotCategorical = TFPDistribution[tfd.OneHotCategorical] +OneHotCategorical.arg_constraints = {"logits": constraints.real_vector} +OneHotCategorical.has_enumerate_support = True +OneHotCategorical.support = constraints.simplex +OneHotCategorical.is_discrete = True +OneHotCategorical.enumerate_support = _onehot_enumerate_support +OrderedLogistic = TFPDistribution[tfd.OrderedLogistic] +OrderedLogistic.arg_constraints = { + "cutpoints": constraints.ordered_vector, + "loc": constraints.real, +} -class Pareto(tfd.Pareto, TFPDistributionMixin): - arg_constraints = { - "concentration": constraints.positive, - "scale": constraints.positive, - } +Pareto = TFPDistribution[tfd.Pareto] +Pareto.arg_constraints = { + "concentration": constraints.positive, + "scale": constraints.positive, +} -__all__ = ["BijectorConstraint", "BijectorTransform", "TFPDistributionMixin"] +__all__ = ["BijectorConstraint", "BijectorTransform", "TFPDistribution"] _len_all = len(__all__) for _name, _Dist in tfd.__dict__.items(): if not isinstance(_Dist, type): @@ -196,7 +237,7 @@ class Pareto(tfd.Pareto, TFPDistributionMixin): try: _PyroDist = locals()[_name] except KeyError: - _PyroDist = type(_name, (_Dist, TFPDistributionMixin), {}) + _PyroDist = TFPDistribution[_Dist] _PyroDist.__module__ = __name__ if hasattr(numpyro_dist, _name): numpyro_dist_class = getattr(numpyro_dist, _name) @@ -212,7 +253,7 @@ class Pareto(tfd.Pareto, TFPDistributionMixin): _PyroDist.__doc__ = """ Wraps `{}.{} `_ - with :class:`~numpyro.contrib.tfp.distributions.TFPDistributionMixin`. + with :class:`~numpyro.contrib.tfp.distributions.TFPDistribution`. """.format( _Dist.__module__, _Dist.__name__, _Dist.__name__ ) diff --git a/test/contrib/test_tfp.py b/test/contrib/test_tfp.py index eeb3ea1e2..aa6fc7dbe 100644 --- a/test/contrib/test_tfp.py +++ b/test/contrib/test_tfp.py @@ -55,11 +55,12 @@ def test_independent(): @pytest.mark.filterwarnings("ignore:can't resolve package") def test_transformed_distributions(): from tensorflow_probability.substrates.jax import bijectors as tfb + from tensorflow_probability.substrates.jax.distributions import Normal as TFPNormal from numpyro.contrib.tfp import distributions as tfd d = dist.TransformedDistribution(dist.Normal(0, 1), dist.transforms.ExpTransform()) - d1 = tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp()) + d1 = tfd.TransformedDistribution(TFPNormal(0, 1), tfb.Exp()) d2 = dist.TransformedDistribution( dist.Normal(0, 1), tfd.BijectorTransform(tfb.Exp()) ) @@ -73,19 +74,19 @@ def test_transformed_distributions(): @pytest.mark.filterwarnings("ignore:can't resolve package") def test_logistic_regression(): - from numpyro.contrib.tfp import distributions as dist + from numpyro.contrib.tfp import distributions as tfd N, dim = 3000, 3 num_warmup, num_samples = (1000, 1000) data = random.normal(random.PRNGKey(0), (N, dim)) true_coefs = jnp.arange(1.0, dim + 1.0) logits = jnp.sum(true_coefs * data, axis=-1) - labels = dist.Bernoulli(logits=logits)(rng_key=random.PRNGKey(1)) + labels = tfd.Bernoulli(logits=logits)(rng_key=random.PRNGKey(1)) def model(labels): - coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) + coefs = numpyro.sample("coefs", tfd.Normal(jnp.zeros(dim), jnp.ones(dim))) logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1)) - return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) + return numpyro.sample("obs", tfd.Bernoulli(logits=logits), obs=labels) kernel = NUTS(model) mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples) @@ -101,19 +102,19 @@ def model(labels): # TODO: remove after https://github.com/tensorflow/probability/issues/1072 is resolved @pytest.mark.filterwarnings("ignore:Explicitly requested dtype") def test_beta_bernoulli(): - from numpyro.contrib.tfp import distributions as dist + from numpyro.contrib.tfp import distributions as tfd num_warmup, num_samples = (500, 2000) def model(data): alpha = jnp.array([1.1, 1.1]) beta = jnp.array([1.1, 1.1]) - p_latent = numpyro.sample("p_latent", dist.Beta(alpha, beta)) - numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data) + p_latent = numpyro.sample("p_latent", tfd.Beta(alpha, beta)) + numpyro.sample("obs", tfd.Bernoulli(p_latent), obs=data) return p_latent true_probs = jnp.array([0.9, 0.1]) - data = dist.Bernoulli(true_probs)(rng_key=random.PRNGKey(1), sample_shape=(1000, 2)) + data = tfd.Bernoulli(true_probs)(rng_key=random.PRNGKey(1), sample_shape=(1000, 2)) kernel = NUTS(model=model, trajectory_length=0.1) mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples) mcmc.run(random.PRNGKey(2), data)