Skip to content

Commit

Permalink
AsymmetricLaplace distributions (pyro-ppl#1332)
Browse files Browse the repository at this point in the history
* AsymmetricLaplace distributions, related to pyro-ppl#1319

* Adding ALD and ALDQ distributions to be importable from numpyro.distributions

* fixes typo in __all__ for ALDQ

* Updating tests, docs, and converting ALDQ to ALD under the hood

* fixing qscale typo in reparameterized_params of ALDQ

* Rewritten cdf, icdf, and fixing batching dims, updating tests

* Reordering args, removing gamma_params from testing
  • Loading branch information
d-diaz authored Feb 24, 2022
1 parent 38b0d48 commit f9c756c
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 2 deletions.
18 changes: 17 additions & 1 deletion docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,22 @@ Unit
Continuous Distributions
------------------------

AsymmetricLaplace
^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.AsymmetricLaplace
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

AsymmetricLaplaceQuantile
^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.AsymmetricLaplaceQuantile
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Beta
^^^^
.. autoclass:: numpyro.distributions.continuous.Beta
Expand Down Expand Up @@ -809,7 +825,7 @@ ExpTransform
:undoc-members:
:show-inheritance:
:member-order: bysource

IdentityTransform
^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.transforms.IdentityTransform
Expand Down
4 changes: 4 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)
from numpyro.distributions.continuous import (
LKJ,
AsymmetricLaplace,
AsymmetricLaplaceQuantile,
Beta,
BetaProportion,
Cauchy,
Expand Down Expand Up @@ -100,6 +102,8 @@
"constraints",
"kl_divergence",
"transforms",
"AsymmetricLaplace",
"AsymmetricLaplaceQuantile",
"Bernoulli",
"BernoulliLogits",
"BernoulliProbs",
Expand Down
138 changes: 137 additions & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.


from jax import lax
import jax.nn as nn
import jax.numpy as jnp
Expand Down Expand Up @@ -67,6 +66,82 @@
)


class AsymmetricLaplace(Distribution):
arg_constraints = {
"loc": constraints.real,
"scale": constraints.positive,
"asymmetry": constraints.positive,
}
reparametrized_params = ["loc", "scale", "asymmetry"]
support = constraints.real

def __init__(self, loc=0.0, scale=1.0, asymmetry=1.0, validate_args=None):
batch_shape = lax.broadcast_shapes(
jnp.shape(loc), jnp.shape(scale), jnp.shape(asymmetry)
)
self.loc, self.scale, self.asymmetry = promote_shapes(
loc, scale, asymmetry, shape=batch_shape
)
super(AsymmetricLaplace, self).__init__(
batch_shape=batch_shape, validate_args=validate_args
)

@lazy_property
def left_scale(self):
return self.scale * self.asymmetry

@lazy_property
def right_scale(self):
return self.scale / self.asymmetry

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
z = value - self.loc
z = -jnp.abs(z) / jnp.where(z < 0, self.left_scale, self.right_scale)
return z - jnp.log(self.left_scale + self.right_scale)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
shape = (2,) + sample_shape + self.batch_shape + self.event_shape
u, v = random.exponential(key, shape=shape)
return self.loc - self.left_scale * u + self.right_scale * v

@property
def mean(self):
total_scale = self.left_scale + self.right_scale
mean = self.loc + (self.right_scale**2 - self.left_scale**2) / total_scale
return jnp.broadcast_to(mean, self.batch_shape)

@property
def variance(self):
left = self.left_scale
right = self.right_scale
total = left + right
p = left / total
q = right / total
variance = p * left**2 + q * right**2 + p * q * total**2
return jnp.broadcast_to(variance, self.batch_shape)

def cdf(self, value):
z = value - self.loc
k = self.asymmetry
return jnp.where(
z >= 0,
1 - (1 / (1 + k**2)) * jnp.exp(-jnp.abs(z) / self.right_scale),
k**2 / (1 + k**2) * jnp.exp(-jnp.abs(z) / self.left_scale),
)

def icdf(self, value):
k = self.asymmetry
temp = k**2 / (1 + k**2)
return jnp.where(
value <= temp,
self.loc + self.left_scale * jnp.log(value / temp),
self.loc - self.right_scale * jnp.log((1 + k**2) * (1 - value)),
)


class Beta(Distribution):
arg_constraints = {
"concentration1": constraints.positive,
Expand Down Expand Up @@ -1777,3 +1852,64 @@ def __init__(self, mean, concentration, validate_args=None):
(1.0 - mean) * concentration,
validate_args=validate_args,
)


class AsymmetricLaplaceQuantile(Distribution):
"""An alternative parameterization of AsymmetricLaplace commonly applied in
Bayesian quantile regression.
Instead of the `asymmetry` parameter employed by AsymmetricLaplace, to
define the balance between left- versus right-hand sides of the
distribution, this class utilizes a `quantile` parameter, which describes
the proportion of probability density that falls to the left-hand side of
the distribution.
The `scale` parameter is also interpreted slightly differently than in
AsymmetricLaplce. When `loc=0` and `scale=1`, AsymmetricLaplace(0,1,1)
is equivalent to Laplace(0,1), while AsymmetricLaplaceQuantile(0,1,0.5) is
equivalent to Laplace(0,2).
"""

arg_constraints = {
"loc": constraints.real,
"scale": constraints.positive,
"quantile": constraints.open_interval(0.0, 1.0),
}
reparametrized_params = ["loc", "scale", "quantile"]
support = constraints.real

def __init__(self, loc=0.0, scale=1.0, quantile=0.5, validate_args=None):
batch_shape = lax.broadcast_shapes(
jnp.shape(loc), jnp.shape(scale), jnp.shape(quantile)
)
self.loc, self.scale, self.quantile = promote_shapes(
loc, scale, quantile, shape=batch_shape
)
super(AsymmetricLaplaceQuantile, self).__init__(
batch_shape=batch_shape, validate_args=validate_args
)
asymmetry = (1 / ((1 / quantile) - 1)) ** 0.5
scale_classic = scale * asymmetry / quantile
self._ald = AsymmetricLaplace(loc=loc, scale=scale_classic, asymmetry=asymmetry)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
return self._ald.log_prob(value)

def sample(self, key, sample_shape=()):
return self._ald.sample(key, sample_shape=sample_shape)

@property
def mean(self):
return self._ald.mean

@property
def variance(self):
return self._ald.variance

def cdf(self, value):
return self._ald.cdf(value)

def icdf(self, value):
return self._ald.icdf(value)
14 changes: 14 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def tree_unflatten(cls, aux_data, params):


_DIST_MAP = {
dist.AsymmetricLaplace: lambda loc, scale, asymmetry: osp.laplace_asymmetric(
asymmetry, loc=loc, scale=scale
),
dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
dist.Beta: lambda con1, con0: osp.beta(con1, con0),
Expand Down Expand Up @@ -253,6 +256,17 @@ def get_sp_dist(jax_dist):


CONTINUOUS = [
T(dist.AsymmetricLaplace, 1.0, 0.5, 1.0),
T(dist.AsymmetricLaplace, np.array([1.0, 2.0]), 2.0, 2.0),
T(dist.AsymmetricLaplace, np.array([[1.0], [2.0]]), 2.0, np.array([3.0, 5.0])),
T(dist.AsymmetricLaplaceQuantile, 0.0, 1.0, 0.5),
T(dist.AsymmetricLaplaceQuantile, np.array([1.0, 2.0]), 2.0, 0.7),
T(
dist.AsymmetricLaplaceQuantile,
np.array([[1.0], [2.0]]),
2.0,
np.array([0.2, 0.8]),
),
T(dist.Beta, 0.2, 1.1),
T(dist.Beta, 1.0, np.array([2.0, 2.0])),
T(dist.Beta, 1.0, np.array([[1.0, 1.0], [2.0, 2.0]])),
Expand Down

0 comments on commit f9c756c

Please sign in to comment.