Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

from . import constraints

_VALIDATION_ENABLED = False
_VALIDATION_ENABLED = True


def enable_validation(is_validate: bool = True) -> None:
Expand Down
103 changes: 103 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4486,3 +4486,106 @@ def test_interval_censored_validate_sample(
censored_dist.log_prob(value)
else:
censored_dist.log_prob(value) # Should not raise


def test_uniform_log_prob_outside_support():
d = dist.Uniform(0, 1)
assert_allclose(d.log_prob(-0.5), -jnp.inf)
assert_allclose(d.log_prob(1.5), -jnp.inf)


@pytest.mark.parametrize(
"low, high", [(0.0, 1.0), (-2.0, 3.0), (1.0, 5.0), (-5.0, -1.0)]
)
def test_uniform_log_prob_boundaries(low, high):
"""Test that boundary values are handled correctly."""
d = dist.Uniform(low, high)
expected_log_prob = -jnp.log(high - low)

# Value at lower bound (included): should have finite log prob
assert_allclose(d.log_prob(low), expected_log_prob)

# Value just above lower bound: should have finite log prob
assert_allclose(d.log_prob(low + 1e-10), expected_log_prob)

# Value at upper bound (excluded): should be -inf
assert_allclose(d.log_prob(high), -jnp.inf)

# Value just below upper bound: should have finite log prob
assert_allclose(d.log_prob(high - 1e-10), expected_log_prob)

# Value inside support: should have finite log prob
mid = (low + high) / 2.0
assert_allclose(d.log_prob(mid), expected_log_prob)

# Value below lower bound: should be -inf
assert_allclose(d.log_prob(low - 1.0), -jnp.inf)

# Value above upper bound: should be -inf
assert_allclose(d.log_prob(high + 1.0), -jnp.inf)


@pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3), (4, 2, 3)])
def test_uniform_log_prob_broadcasting(batch_shape):
"""Test broadcasting with different batch shapes."""
if batch_shape == ():
low = 0.0
high = 1.0
else:
low = jnp.linspace(0.0, 1.0, np.prod(batch_shape)).reshape(batch_shape)
high = jnp.linspace(1.0, 2.0, np.prod(batch_shape)).reshape(batch_shape)

d = dist.Uniform(low, high)

# Test with scalar value
value = 0.5
log_probs = d.log_prob(value)
assert log_probs.shape == batch_shape

# Test with batched value
if batch_shape:
value_batched = jnp.linspace(-0.5, 1.5, np.prod(batch_shape)).reshape(
batch_shape
)
log_probs_batched = d.log_prob(value_batched)
assert log_probs_batched.shape == batch_shape

# Check that values outside support return -inf
# Values < low should be -inf
below_low = low - 0.1
assert_allclose(d.log_prob(below_low), -jnp.inf)

# Values >= high should be -inf
at_high = high
assert_allclose(d.log_prob(at_high), -jnp.inf)


@pytest.mark.parametrize("value_shape", [(), (5,), (3, 4), (2, 3, 4)])
def test_uniform_log_prob_value_broadcasting(value_shape):
"""Test broadcasting when value has different shapes."""
d = dist.Uniform(0.0, 1.0)

if value_shape == ():
values = 0.5
else:
values = jnp.linspace(-0.5, 1.5, np.prod(value_shape)).reshape(value_shape)

log_probs = d.log_prob(values)
assert log_probs.shape == value_shape

# Check that values inside support have finite log prob
inside_values = jnp.linspace(0.1, 0.9, np.prod(value_shape) if value_shape else 1)
if value_shape:
inside_values = inside_values.reshape(value_shape)
log_probs_inside = d.log_prob(inside_values)
assert jnp.all(jnp.isfinite(log_probs_inside))

# Check that values outside support have -inf
outside_values = jnp.linspace(-1.0, 2.0, np.prod(value_shape) if value_shape else 1)
if value_shape:
outside_values = outside_values.reshape(value_shape)
log_probs_outside = d.log_prob(outside_values)
# Values in [0, 1) should be finite, others should be -inf
mask_inside = (outside_values >= 0.0) & (outside_values < 1.0)
assert jnp.all(jnp.where(mask_inside, jnp.isfinite(log_probs_outside), True))
assert jnp.all(jnp.where(~mask_inside, log_probs_outside == -jnp.inf, True))
Loading