From f9c756cca232a5e858f9187e3b5f403646ffe05c Mon Sep 17 00:00:00 2001 From: David Diaz Date: Wed, 23 Feb 2022 18:46:19 -0800 Subject: [PATCH] AsymmetricLaplace distributions (#1332) * AsymmetricLaplace distributions, related to #1319 * Adding ALD and ALDQ distributions to be importable from numpyro.distributions * fixes typo in __all__ for ALDQ * Updating tests, docs, and converting ALDQ to ALD under the hood * fixing qscale typo in reparameterized_params of ALDQ * Rewritten cdf, icdf, and fixing batching dims, updating tests * Reordering args, removing gamma_params from testing --- docs/source/distributions.rst | 18 +++- numpyro/distributions/__init__.py | 4 + numpyro/distributions/continuous.py | 138 +++++++++++++++++++++++++++- test/test_distributions.py | 14 +++ 4 files changed, 172 insertions(+), 2 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 6ac587672..5350311e1 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -80,6 +80,22 @@ Unit Continuous Distributions ------------------------ +AsymmetricLaplace +^^^^^^^^^^^^^^^^^ +.. autoclass:: numpyro.distributions.continuous.AsymmetricLaplace + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +AsymmetricLaplaceQuantile +^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: numpyro.distributions.continuous.AsymmetricLaplaceQuantile + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + Beta ^^^^ .. autoclass:: numpyro.distributions.continuous.Beta @@ -809,7 +825,7 @@ ExpTransform :undoc-members: :show-inheritance: :member-order: bysource - + IdentityTransform ^^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.transforms.IdentityTransform diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index f5154b86e..df33f6454 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -12,6 +12,8 @@ ) from numpyro.distributions.continuous import ( LKJ, + AsymmetricLaplace, + AsymmetricLaplaceQuantile, Beta, BetaProportion, Cauchy, @@ -100,6 +102,8 @@ "constraints", "kl_divergence", "transforms", + "AsymmetricLaplace", + "AsymmetricLaplaceQuantile", "Bernoulli", "BernoulliLogits", "BernoulliProbs", diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index abe0c9363..cd38cf0ba 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -25,7 +25,6 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. - from jax import lax import jax.nn as nn import jax.numpy as jnp @@ -67,6 +66,82 @@ ) +class AsymmetricLaplace(Distribution): + arg_constraints = { + "loc": constraints.real, + "scale": constraints.positive, + "asymmetry": constraints.positive, + } + reparametrized_params = ["loc", "scale", "asymmetry"] + support = constraints.real + + def __init__(self, loc=0.0, scale=1.0, asymmetry=1.0, validate_args=None): + batch_shape = lax.broadcast_shapes( + jnp.shape(loc), jnp.shape(scale), jnp.shape(asymmetry) + ) + self.loc, self.scale, self.asymmetry = promote_shapes( + loc, scale, asymmetry, shape=batch_shape + ) + super(AsymmetricLaplace, self).__init__( + batch_shape=batch_shape, validate_args=validate_args + ) + + @lazy_property + def left_scale(self): + return self.scale * self.asymmetry + + @lazy_property + def right_scale(self): + return self.scale / self.asymmetry + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + z = value - self.loc + z = -jnp.abs(z) / jnp.where(z < 0, self.left_scale, self.right_scale) + return z - jnp.log(self.left_scale + self.right_scale) + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + shape = (2,) + sample_shape + self.batch_shape + self.event_shape + u, v = random.exponential(key, shape=shape) + return self.loc - self.left_scale * u + self.right_scale * v + + @property + def mean(self): + total_scale = self.left_scale + self.right_scale + mean = self.loc + (self.right_scale**2 - self.left_scale**2) / total_scale + return jnp.broadcast_to(mean, self.batch_shape) + + @property + def variance(self): + left = self.left_scale + right = self.right_scale + total = left + right + p = left / total + q = right / total + variance = p * left**2 + q * right**2 + p * q * total**2 + return jnp.broadcast_to(variance, self.batch_shape) + + def cdf(self, value): + z = value - self.loc + k = self.asymmetry + return jnp.where( + z >= 0, + 1 - (1 / (1 + k**2)) * jnp.exp(-jnp.abs(z) / self.right_scale), + k**2 / (1 + k**2) * jnp.exp(-jnp.abs(z) / self.left_scale), + ) + + def icdf(self, value): + k = self.asymmetry + temp = k**2 / (1 + k**2) + return jnp.where( + value <= temp, + self.loc + self.left_scale * jnp.log(value / temp), + self.loc - self.right_scale * jnp.log((1 + k**2) * (1 - value)), + ) + + class Beta(Distribution): arg_constraints = { "concentration1": constraints.positive, @@ -1777,3 +1852,64 @@ def __init__(self, mean, concentration, validate_args=None): (1.0 - mean) * concentration, validate_args=validate_args, ) + + +class AsymmetricLaplaceQuantile(Distribution): + """An alternative parameterization of AsymmetricLaplace commonly applied in + Bayesian quantile regression. + + Instead of the `asymmetry` parameter employed by AsymmetricLaplace, to + define the balance between left- versus right-hand sides of the + distribution, this class utilizes a `quantile` parameter, which describes + the proportion of probability density that falls to the left-hand side of + the distribution. + + The `scale` parameter is also interpreted slightly differently than in + AsymmetricLaplce. When `loc=0` and `scale=1`, AsymmetricLaplace(0,1,1) + is equivalent to Laplace(0,1), while AsymmetricLaplaceQuantile(0,1,0.5) is + equivalent to Laplace(0,2). + """ + + arg_constraints = { + "loc": constraints.real, + "scale": constraints.positive, + "quantile": constraints.open_interval(0.0, 1.0), + } + reparametrized_params = ["loc", "scale", "quantile"] + support = constraints.real + + def __init__(self, loc=0.0, scale=1.0, quantile=0.5, validate_args=None): + batch_shape = lax.broadcast_shapes( + jnp.shape(loc), jnp.shape(scale), jnp.shape(quantile) + ) + self.loc, self.scale, self.quantile = promote_shapes( + loc, scale, quantile, shape=batch_shape + ) + super(AsymmetricLaplaceQuantile, self).__init__( + batch_shape=batch_shape, validate_args=validate_args + ) + asymmetry = (1 / ((1 / quantile) - 1)) ** 0.5 + scale_classic = scale * asymmetry / quantile + self._ald = AsymmetricLaplace(loc=loc, scale=scale_classic, asymmetry=asymmetry) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return self._ald.log_prob(value) + + def sample(self, key, sample_shape=()): + return self._ald.sample(key, sample_shape=sample_shape) + + @property + def mean(self): + return self._ald.mean + + @property + def variance(self): + return self._ald.variance + + def cdf(self, value): + return self._ald.cdf(value) + + def icdf(self, value): + return self._ald.icdf(value) diff --git a/test/test_distributions.py b/test/test_distributions.py index b46c539c1..d3bf13878 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -196,6 +196,9 @@ def tree_unflatten(cls, aux_data, params): _DIST_MAP = { + dist.AsymmetricLaplace: lambda loc, scale, asymmetry: osp.laplace_asymmetric( + asymmetry, loc=loc, scale=scale + ), dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs), dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)), dist.Beta: lambda con1, con0: osp.beta(con1, con0), @@ -253,6 +256,17 @@ def get_sp_dist(jax_dist): CONTINUOUS = [ + T(dist.AsymmetricLaplace, 1.0, 0.5, 1.0), + T(dist.AsymmetricLaplace, np.array([1.0, 2.0]), 2.0, 2.0), + T(dist.AsymmetricLaplace, np.array([[1.0], [2.0]]), 2.0, np.array([3.0, 5.0])), + T(dist.AsymmetricLaplaceQuantile, 0.0, 1.0, 0.5), + T(dist.AsymmetricLaplaceQuantile, np.array([1.0, 2.0]), 2.0, 0.7), + T( + dist.AsymmetricLaplaceQuantile, + np.array([[1.0], [2.0]]), + 2.0, + np.array([0.2, 0.8]), + ), T(dist.Beta, 0.2, 1.1), T(dist.Beta, 1.0, np.array([2.0, 2.0])), T(dist.Beta, 1.0, np.array([[1.0, 1.0], [2.0, 2.0]])),