Force validate_args to be keyword argument (pyro-ppl#1358)
tcbegley authored Mar 10, 2022
1 parent 89e323e commit f333e91
Showing 7 changed files with 71 additions and 69 deletions.
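The change is mechanical but touches every constructor shown below: a bare `*` is inserted before `validate_args` in each signature, so the flag can no longer be filled in by position. A minimal sketch of the before/after behavior, using `Normal` (whose hunk appears further down) as the example:

import numpyro.distributions as dist

# Before this commit, the flag could be set positionally, often by accident:
#     dist.Normal(0.0, 1.0, True)   # third positional arg landed in validate_args
# After it, that call raises TypeError and the flag must be named:
d = dist.Normal(0.0, 1.0, validate_args=True)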
14 changes: 7 additions & 7 deletions numpyro/distributions/conjugate.py
@@ -44,7 +44,7 @@ class BetaBinomial(Distribution):
     enumerate_support = BinomialProbs.enumerate_support
 
     def __init__(
-        self, concentration1, concentration0, total_count=1, validate_args=None
+        self, concentration1, concentration0, total_count=1, *, validate_args=None
     ):
         self.concentration1, self.concentration0, self.total_count = promote_shapes(
             concentration1, concentration0, total_count
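The mechanism visible in this first hunk is Python's keyword-only marker: every parameter after a bare `*` in a signature must be passed by name. A standalone sketch (toy function, not numpyro code):

def make(total_count=1, *, validate_args=None):
    # parameters after the bare * are keyword-only
    return total_count, validate_args

make(5)                       # fine
make(5, validate_args=True)   # fine
# make(5, True)               # TypeError: takes from 0 to 1 positional arguments but 2 were given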
@@ -110,7 +110,7 @@ class DirichletMultinomial(Distribution):
         "total_count": constraints.nonnegative_integer,
     }
 
-    def __init__(self, concentration, total_count=1, validate_args=None):
+    def __init__(self, concentration, total_count=1, *, validate_args=None):
         if jnp.ndim(concentration) < 1:
             raise ValueError(
                 "`concentration` parameter must be at least one-dimensional."
@@ -184,7 +184,7 @@ class GammaPoisson(Distribution):
     }
     support = constraints.nonnegative_integer
 
-    def __init__(self, concentration, rate=1.0, validate_args=None):
+    def __init__(self, concentration, rate=1.0, *, validate_args=None):
         self.concentration, self.rate = promote_shapes(concentration, rate)
         self._gamma = Gamma(concentration, rate)
         super(GammaPoisson, self).__init__(
@@ -220,7 +220,7 @@ def cdf(self, value):
         return bt
 
 
-def NegativeBinomial(total_count, probs=None, logits=None, validate_args=None):
+def NegativeBinomial(total_count, probs=None, logits=None, *, validate_args=None):
    if probs is not None:
        return NegativeBinomialProbs(total_count, probs, validate_args=validate_args)
    elif logits is not None:
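`NegativeBinomial` is a factory function rather than a class: it dispatches to the probs or logits parameterization and forwards `validate_args` by keyword either way. A usage sketch (importing from the defining module, as shown in this file's diff):

from numpyro.distributions.conjugate import NegativeBinomial

d1 = NegativeBinomial(10, probs=0.25, validate_args=True)
d2 = NegativeBinomial(10, logits=-1.1, validate_args=True)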
@@ -236,7 +236,7 @@ class NegativeBinomialProbs(GammaPoisson):
     }
     support = constraints.nonnegative_integer
 
-    def __init__(self, total_count, probs, validate_args=None):
+    def __init__(self, total_count, probs, *, validate_args=None):
         self.total_count, self.probs = promote_shapes(total_count, probs)
         concentration = total_count
         rate = 1.0 / probs - 1.0
@@ -250,7 +250,7 @@ class NegativeBinomialLogits(GammaPoisson):
     }
     support = constraints.nonnegative_integer
 
-    def __init__(self, total_count, logits, validate_args=None):
+    def __init__(self, total_count, logits, *, validate_args=None):
         self.total_count, self.logits = promote_shapes(total_count, logits)
         concentration = total_count
         rate = jnp.exp(-logits)
@@ -276,7 +276,7 @@ class NegativeBinomial2(GammaPoisson):
     }
     support = constraints.nonnegative_integer
 
-    def __init__(self, mean, concentration, validate_args=None):
+    def __init__(self, mean, concentration, *, validate_args=None):
         rate = concentration / mean
         super().__init__(concentration, rate, validate_args=validate_args)
 
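`NegativeBinomial2` is the mean/concentration parameterization of `GammaPoisson`; as the hunk above shows, it converts via rate = concentration / mean. A quick sketch of the arithmetic (assuming the usual `NegativeBinomial2` export):

import numpyro.distributions as dist

# mean=4.0, concentration=2.0 -> GammaPoisson rate = 2.0 / 4.0 = 0.5
d = dist.NegativeBinomial2(mean=4.0, concentration=2.0, validate_args=True)
print(d.rate)  # 0.5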
56 changes: 28 additions & 28 deletions numpyro/distributions/continuous.py
@@ -75,7 +75,7 @@ class AsymmetricLaplace(Distribution):
     reparametrized_params = ["loc", "scale", "asymmetry"]
     support = constraints.real
 
-    def __init__(self, loc=0.0, scale=1.0, asymmetry=1.0, validate_args=None):
+    def __init__(self, loc=0.0, scale=1.0, asymmetry=1.0, *, validate_args=None):
         batch_shape = lax.broadcast_shapes(
             jnp.shape(loc), jnp.shape(scale), jnp.shape(asymmetry)
         )
@@ -150,7 +150,7 @@ class Beta(Distribution):
     reparametrized_params = ["concentration1", "concentration0"]
     support = constraints.unit_interval
 
-    def __init__(self, concentration1, concentration0, validate_args=None):
+    def __init__(self, concentration1, concentration0, *, validate_args=None):
         self.concentration1, self.concentration0 = promote_shapes(
             concentration1, concentration0
         )
@@ -190,7 +190,7 @@ class Cauchy(Distribution):
     support = constraints.real
     reparametrized_params = ["loc", "scale"]
 
-    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
+    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
         self.loc, self.scale = promote_shapes(loc, scale)
         batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
         super(Cauchy, self).__init__(
@@ -233,7 +233,7 @@ class Dirichlet(Distribution):
     reparametrized_params = ["concentration"]
     support = constraints.simplex
 
-    def __init__(self, concentration, validate_args=None):
+    def __init__(self, concentration, *, validate_args=None):
         if jnp.ndim(concentration) < 1:
             raise ValueError(
                 "`concentration` parameter must be at least one-dimensional."
@@ -299,7 +299,7 @@ class Exponential(Distribution):
     arg_constraints = {"rate": constraints.positive}
     support = constraints.positive
 
-    def __init__(self, rate=1.0, validate_args=None):
+    def __init__(self, rate=1.0, *, validate_args=None):
         self.rate = rate
         super(Exponential, self).__init__(
             batch_shape=jnp.shape(rate), validate_args=validate_args
@@ -338,7 +338,7 @@ class Gamma(Distribution):
     support = constraints.positive
     reparametrized_params = ["concentration", "rate"]
 
-    def __init__(self, concentration, rate=1.0, validate_args=None):
+    def __init__(self, concentration, rate=1.0, *, validate_args=None):
         self.concentration, self.rate = promote_shapes(concentration, rate)
         batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(rate))
         super(Gamma, self).__init__(
@@ -377,7 +377,7 @@ class Chi2(Gamma):
     arg_constraints = {"df": constraints.positive}
     reparametrized_params = ["df"]
 
-    def __init__(self, df, validate_args=None):
+    def __init__(self, df, *, validate_args=None):
         self.df = df
         super(Chi2, self).__init__(0.5 * df, 0.5, validate_args=validate_args)
 
@@ -387,7 +387,7 @@ class GaussianRandomWalk(Distribution):
     support = constraints.real_vector
     reparametrized_params = ["scale"]
 
-    def __init__(self, scale=1.0, num_steps=1, validate_args=None):
+    def __init__(self, scale=1.0, num_steps=1, *, validate_args=None):
         assert (
             isinstance(num_steps, int) and num_steps > 0
         ), "`num_steps` argument should be an positive integer."
@@ -435,7 +435,7 @@ class HalfCauchy(Distribution):
     support = constraints.positive
     arg_constraints = {"scale": constraints.positive}
 
-    def __init__(self, scale=1.0, validate_args=None):
+    def __init__(self, scale=1.0, *, validate_args=None):
         self._cauchy = Cauchy(0.0, scale)
         self.scale = scale
         super(HalfCauchy, self).__init__(
@@ -470,7 +470,7 @@ class HalfNormal(Distribution):
     support = constraints.positive
     arg_constraints = {"scale": constraints.positive}
 
-    def __init__(self, scale=1.0, validate_args=None):
+    def __init__(self, scale=1.0, *, validate_args=None):
         self._normal = Normal(0.0, scale)
         self.scale = scale
         super(HalfNormal, self).__init__(
@@ -514,7 +514,7 @@ class InverseGamma(TransformedDistribution):
     reparametrized_params = ["concentration", "rate"]
     support = constraints.positive
 
-    def __init__(self, concentration, rate=1.0, validate_args=None):
+    def __init__(self, concentration, rate=1.0, *, validate_args=None):
         base_dist = Gamma(concentration, rate)
         self.concentration = base_dist.concentration
         self.rate = base_dist.rate
@@ -546,7 +546,7 @@ class Gumbel(Distribution):
     support = constraints.real
     reparametrized_params = ["loc", "scale"]
 
-    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
+    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
         self.loc, self.scale = promote_shapes(loc, scale)
         batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
 
@@ -597,7 +597,7 @@ class Kumaraswamy(TransformedDistribution):
     # we can set this flag to 1000.
     KL_KUMARASWAMY_BETA_TAYLOR_ORDER = 10
 
-    def __init__(self, concentration1, concentration0, validate_args=None):
+    def __init__(self, concentration1, concentration0, *, validate_args=None):
         self.concentration1, self.concentration0 = promote_shapes(
             concentration1, concentration0
         )
@@ -648,7 +648,7 @@ class Laplace(Distribution):
     support = constraints.real
     reparametrized_params = ["loc", "scale"]
 
-    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
+    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
         self.loc, self.scale = promote_shapes(loc, scale)
         batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
         super(Laplace, self).__init__(
@@ -736,7 +736,7 @@ def model(y):  # y has dimension N x d
     support = constraints.corr_matrix
 
     def __init__(
-        self, dimension, concentration=1.0, sample_method="onion", validate_args=None
+        self, dimension, concentration=1.0, sample_method="onion", *, validate_args=None
     ):
         base_dist = LKJCholesky(dimension, concentration, sample_method)
         self.dimension, self.concentration = (
@@ -819,7 +819,7 @@ def model(y):  # y has dimension N x d
     support = constraints.corr_cholesky
 
     def __init__(
-        self, dimension, concentration=1.0, sample_method="onion", validate_args=None
+        self, dimension, concentration=1.0, sample_method="onion", *, validate_args=None
    ):
        if dimension < 2:
            raise ValueError("Dimension must be greater than or equal to 2.")
@@ -980,7 +980,7 @@ class LogNormal(TransformedDistribution):
     support = constraints.positive
     reparametrized_params = ["loc", "scale"]
 
-    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
+    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
         base_dist = Normal(loc, scale)
         self.loc, self.scale = base_dist.loc, base_dist.scale
         super(LogNormal, self).__init__(
@@ -1007,7 +1007,7 @@ class Logistic(Distribution):
     support = constraints.real
     reparametrized_params = ["loc", "scale"]
 
-    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
+    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
         self.loc, self.scale = promote_shapes(loc, scale)
         batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
         super(Logistic, self).__init__(batch_shape, validate_args=validate_args)
@@ -1354,7 +1354,7 @@ class LowRankMultivariateNormal(Distribution):
     support = constraints.real_vector
     reparametrized_params = ["loc", "cov_factor", "cov_diag"]
 
-    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
+    def __init__(self, loc, cov_factor, cov_diag, *, validate_args=None):
         if jnp.ndim(loc) < 1:
             raise ValueError("`loc` must be at least one-dimensional.")
         event_shape = jnp.shape(loc)[-1:]
@@ -1485,7 +1485,7 @@ class Normal(Distribution):
     support = constraints.real
     reparametrized_params = ["loc", "scale"]
 
-    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
+    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
         self.loc, self.scale = promote_shapes(loc, scale)
         batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
         super(Normal, self).__init__(
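For reference, the flag itself behaves as before: when set, parameters are checked against the distribution's `arg_constraints` at construction time. A hedged sketch (the exact error type and message may vary by numpyro version):

import numpyro.distributions as dist

try:
    dist.Normal(0.0, -1.0, validate_args=True)  # scale must be positive
except ValueError as err:
    print(err)  # e.g. "Normal distribution got invalid scale parameter."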
@@ -1525,7 +1525,7 @@ class Pareto(TransformedDistribution):
     arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive}
     reparametrized_params = ["scale", "alpha"]
 
-    def __init__(self, scale, alpha, validate_args=None):
+    def __init__(self, scale, alpha, *, validate_args=None):
         self.scale, self.alpha = promote_shapes(scale, alpha)
         batch_shape = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(alpha))
         scale, alpha = jnp.broadcast_to(scale, batch_shape), jnp.broadcast_to(
@@ -1568,7 +1568,7 @@ class RelaxedBernoulliLogits(TransformedDistribution):
     arg_constraints = {"temperature": constraints.positive, "logits": constraints.real}
     support = constraints.unit_interval
 
-    def __init__(self, temperature, logits, validate_args=None):
+    def __init__(self, temperature, logits, *, validate_args=None):
         self.temperature, self.logits = promote_shapes(temperature, logits)
         base_dist = Logistic(logits / temperature, 1 / temperature)
         transforms = [SigmoidTransform()]
@@ -1578,7 +1578,7 @@ def tree_flatten(self):
         return super(TransformedDistribution, self).tree_flatten()
 
 
-def RelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None):
+def RelaxedBernoulli(temperature, probs=None, logits=None, *, validate_args=None):
    if probs is None and logits is None:
        raise ValueError("One of `probs` or `logits` must be specified.")
    if probs is not None:
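Like `NegativeBinomial` above, `RelaxedBernoulli` is a factory: the `probs`/`logits` check shown in the hunk still applies, and `validate_args` now passes through by keyword only. A usage sketch:

from numpyro.distributions.continuous import RelaxedBernoulli

d = RelaxedBernoulli(temperature=0.5, logits=2.0, validate_args=True)
# RelaxedBernoulli(0.5)  ->  ValueError: One of `probs` or `logits` must be specified.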
@@ -1652,7 +1652,7 @@ class StudentT(Distribution):
     support = constraints.real
     reparametrized_params = ["df", "loc", "scale"]
 
-    def __init__(self, df, loc=0.0, scale=1.0, validate_args=None):
+    def __init__(self, df, loc=0.0, scale=1.0, *, validate_args=None):
         batch_shape = lax.broadcast_shapes(
             jnp.shape(df), jnp.shape(loc), jnp.shape(scale)
         )
@@ -1723,7 +1723,7 @@ class Uniform(Distribution):
     arg_constraints = {"low": constraints.dependent, "high": constraints.dependent}
     reparametrized_params = ["low", "high"]
 
-    def __init__(self, low=0.0, high=1.0, validate_args=None):
+    def __init__(self, low=0.0, high=1.0, *, validate_args=None):
         self.low, self.high = promote_shapes(low, high)
         batch_shape = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
         self._support = constraints.interval(low, high)
@@ -1788,7 +1788,7 @@ class Weibull(Distribution):
     support = constraints.positive
     reparametrized_params = ["scale", "concentration"]
 
-    def __init__(self, scale, concentration, validate_args=None):
+    def __init__(self, scale, concentration, *, validate_args=None):
         self.concentration, self.scale = promote_shapes(concentration, scale)
         batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(scale))
         super().__init__(batch_shape=batch_shape, validate_args=validate_args)
@@ -1843,7 +1843,7 @@ class BetaProportion(Beta):
     reparametrized_params = ["mean", "concentration"]
     support = constraints.unit_interval
 
-    def __init__(self, mean, concentration, validate_args=None):
+    def __init__(self, mean, concentration, *, validate_args=None):
         self.concentration = jnp.broadcast_to(
             concentration, lax.broadcast_shapes(jnp.shape(concentration))
         )
@@ -1878,7 +1878,7 @@ class AsymmetricLaplaceQuantile(Distribution):
     reparametrized_params = ["loc", "scale", "quantile"]
     support = constraints.real
 
-    def __init__(self, loc=0.0, scale=1.0, quantile=0.5, validate_args=None):
+    def __init__(self, loc=0.0, scale=1.0, quantile=0.5, *, validate_args=None):
         batch_shape = lax.broadcast_shapes(
             jnp.shape(loc), jnp.shape(scale), jnp.shape(quantile)
         )
6 changes: 3 additions & 3 deletions numpyro/distributions/directional.py
@@ -99,7 +99,7 @@ def model():
     reparametrized_params = ["loc"]
     support = constraints.circular
 
-    def __init__(self, loc, concentration, validate_args=None):
+    def __init__(self, loc, concentration, *, validate_args=None):
         """von Mises distribution for sampling directions.
 
         :param loc: center of distribution
@@ -217,7 +217,7 @@ def model(obs):
 
     support = constraints.independent(constraints.circular, 1)
 
-    def __init__(self, base_dist: Distribution, skewness, validate_args=None):
+    def __init__(self, base_dist: Distribution, skewness, *, validate_args=None):
         assert (
             base_dist.event_shape == skewness.shape[-1:]
         ), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."
@@ -373,7 +373,7 @@ def __init__(
             jnp.shape(psi_concentration),
             jnp.shape(correlation),
         )
-        super().__init__(batch_shape, (2,), validate_args)
+        super().__init__(batch_shape, (2,), validate_args=validate_args)
 
     @lazy_property
     def norm_const(self):
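This last hunk is a call-site fix rather than a signature change: once the base `Distribution.__init__` makes `validate_args` keyword-only (that change is in one of the four files not rendered above), forwarding it positionally no longer works. A self-contained sketch of the failure mode, assuming the base class was changed the same way as the constructors in this diff:

class Distribution:
    def __init__(self, batch_shape=(), event_shape=(), *, validate_args=None):
        self.batch_shape, self.event_shape = batch_shape, event_shape
        self.validate_args = validate_args

class Toy(Distribution):
    def __init__(self, batch_shape, validate_args=None):
        # super().__init__(batch_shape, (2,), validate_args)  # now TypeError:
        # three positional arguments where at most two are accepted
        super().__init__(batch_shape, (2,), validate_args=validate_args)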