Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,29 @@ def sample(

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
    """Evaluate the log-density at ``value`` without NaN gradients at 0 and 1.

    The density is obtained from ``self._dirichlet.log_prob`` applied to the
    stacked pair ``[value, 1 - value]``. Entries of ``value`` that are exactly
    0 or 1 are replaced by 0.5 before that call (the "double-where" trick),
    and the forward result at those entries is then shifted to the closed-form
    value built from ``xlogy`` and ``betaln``.

    :param value: point(s) at which the log-density is evaluated.
    :return: the ``_dirichlet`` log-density for interior entries and the
        closed-form log-density for entries equal to 0 or 1.
    """
    # Double-where trick to avoid NaN gradients at boundary conditions.
    # Reference: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
    is_boundary = (value == 0.0) | (value == 1.0)

    # Swap boundary entries (0 or 1) for a harmless interior point so that the
    # differentiated path never evaluates log(0).
    safe_value = jnp.where(is_boundary, 0.5, value)
    safe_complement = jnp.where(is_boundary, 0.5, 1.0 - value)

    # Log-density on the sanitized inputs; this is the only term gradients
    # flow through.
    safe_dirichlet_value = jnp.stack([safe_value, safe_complement], axis=-1)
    safe_log_prob = self._dirichlet.log_prob(safe_dirichlet_value)

    # Closed-form forward value; xlogy evaluates 0 * log(0) as 0, so a
    # concentration of exactly 1 yields a finite result at the boundary.
    correct_value = (
        xlogy(self.concentration1 - 1.0, value)
        + xlogy(self.concentration0 - 1.0, 1.0 - value)
        - betaln(self.concentration1, self.concentration0)
    )
    # stop_gradient makes the correction a constant for differentiation.
    # NOTE: as a consequence, at boundary entries the gradient w.r.t. the
    # concentrations is the one taken at the surrogate point 0.5, and the
    # gradient w.r.t. ``value`` is zero.
    correction = jax.lax.stop_gradient(correct_value - safe_log_prob)

    # Shift the forward value at boundaries only; interior entries are
    # returned unchanged.
    return safe_log_prob + jnp.where(is_boundary, correction, 0.0)

@property
def mean(self) -> ArrayLike:
Expand Down
128 changes: 128 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4486,3 +4486,131 @@ def test_interval_censored_validate_sample(
censored_dist.log_prob(value)
else:
censored_dist.log_prob(value) # Should not raise


@pytest.mark.parametrize(
    "concentration1,concentration0,value",
    [
        (1.0, 8.0, 0.0),
        (8.0, 1.0, 1.0),
    ],
    ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
)
def test_beta_logprob_edge_cases(concentration1, concentration0, value):
    """Check that the log-density at the parametrized boundary point is a
    finite number (neither NaN nor infinite)."""
    boundary_lp = dist.Beta(concentration1, concentration0).log_prob(value)

    nan_message = (
        f"Beta({concentration1},{concentration0}).log_prob({value}) should not be NaN"
    )
    assert not jnp.isnan(boundary_lp), nan_message

    finite_message = (
        f"Beta({concentration1},{concentration0}).log_prob({value}) should be finite"
    )
    assert jnp.isfinite(boundary_lp), finite_message


def test_beta_logprob_edge_case_consistency_small_values():
    """Check that boundary log-densities agree with values just inside (0, 1).

    Beta(1, 8) is probed at 0 and Beta(8, 1) at 1; each boundary value is
    compared with the log-density at a point offset by ``small_value``.
    """
    beta_dist = dist.Beta(1.0, 8.0)
    beta_dist2 = dist.Beta(8.0, 1.0)

    # The offset has to survive rounding in the active float dtype. Under
    # JAX's default 32-bit floats the spacing just below 1.0 is about 6e-8, so
    # an offset such as 1e-10 would round ``1.0 - offset`` back to exactly 1.0
    # and the upper-boundary check would compare the boundary with itself.
    small_value = 1e-6
    near_zero = jnp.asarray(small_value)
    near_one = jnp.asarray(1.0 - small_value)
    assert near_zero > 0.0, "probe point near 0 collapsed onto the boundary"
    assert near_one < 1.0, "probe point near 1 collapsed onto the boundary"

    # At boundary
    log_prob_at_zero = beta_dist.log_prob(0.0)
    log_prob_at_one = beta_dist2.log_prob(1.0)

    # Very close to boundary
    log_prob_small = beta_dist.log_prob(near_zero)
    log_prob_close_to_one = beta_dist2.log_prob(near_one)

    # For these concentrations the 1e-6 offset moves the log-density by about
    # 7 * 1e-6, so the tolerance leaves room for that plus rounding error.
    assert jnp.abs(log_prob_at_zero - log_prob_small) < 1e-4
    assert jnp.abs(log_prob_at_one - log_prob_close_to_one) < 1e-4


def test_beta_logprob_edge_case_non_boundary_values():
    """Check that Beta(1, 8) and Beta(8, 1) give a finite log-density at the
    interior point 0.5."""
    interior_point = 0.5
    for conc1, conc0 in ((1.0, 8.0), (8.0, 1.0)):
        interior_lp = dist.Beta(conc1, conc0).log_prob(interior_point)
        assert jnp.isfinite(interior_lp)


def test_beta_logprob_boundary_non_edge_cases():
    """Check that the log-density is still -inf at a boundary when the
    concentration acting on that boundary is 2 rather than 1."""
    # Beta(2, 8) evaluated at the lower boundary.
    lower_lp = dist.Beta(2.0, 8.0).log_prob(0.0)
    assert jnp.isneginf(lower_lp)

    # Beta(8, 2) evaluated at the upper boundary.
    upper_lp = dist.Beta(8.0, 2.0).log_prob(1.0)
    assert jnp.isneginf(upper_lp)


@pytest.mark.parametrize(
    "concentration1,concentration0,value,grad_param,grad_value",
    [
        (1.0, 8.0, 0.0, "value", 0.0),
        (8.0, 1.0, 1.0, "value", 1.0),
        (1.0, 8.0, 0.0, "concentration1", 1.0),
        (1.0, 8.0, 0.0, "concentration0", 8.0),
        (8.0, 1.0, 1.0, "concentration1", 8.0),
        (8.0, 1.0, 1.0, "concentration0", 1.0),
    ],
    ids=[
        "Beta(1,8) at x=0",
        "Beta(8,1) at x=1",
        "Beta(1,8) at concentration1=1",
        "Beta(1,8) at concentration0=8",
        "Beta(8,1) at concentration1=8",
        "Beta(8,1) at concentration0=1",
    ],
)
def test_beta_gradient_edge_cases_single_param(
    concentration1, concentration0, value, grad_param, grad_value
):
    """Check that the derivative of the log-density with respect to one
    argument at a time is finite at the parametrized boundary cases.

    ``grad_param`` names the argument being differentiated; the remaining two
    are held at their parametrized values.
    """
    # Select the scalar function to differentiate and the point at which its
    # derivative is taken.
    if grad_param == "value":
        eval_point = value

        def target(x):
            return dist.Beta(concentration1, concentration0).log_prob(x)

    elif grad_param == "concentration1":
        eval_point = grad_value

        def target(c1):
            return dist.Beta(c1, concentration0).log_prob(value)

    else:  # concentration0
        eval_point = grad_value

        def target(c0):
            return dist.Beta(concentration1, c0).log_prob(value)

    grad = jax.grad(target)(eval_point)

    assert jnp.isfinite(grad), (
        f"Gradient w.r.t. {grad_param} for Beta({concentration1},{concentration0}) "
        f"at x={value} should be finite"
    )


@pytest.mark.parametrize(
    "concentration1,concentration0,value",
    [
        (1.0, 8.0, 0.0),
        (8.0, 1.0, 1.0),
    ],
    ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
)
def test_beta_gradient_edge_cases_all_params(concentration1, concentration0, value):
    """Check that the joint gradient with respect to both concentrations and
    the evaluation point contains only finite entries at the boundary cases."""

    def joint_log_prob(packed):
        # packed holds (concentration1, concentration0, value) in that order.
        return dist.Beta(packed[0], packed[1]).log_prob(packed[2])

    packed_inputs = jnp.array([concentration1, concentration0, value])
    grads = jax.grad(joint_log_prob)(packed_inputs)

    every_entry_finite = jnp.all(jnp.isfinite(grads))
    assert every_entry_finite, (
        f"All gradients for Beta({concentration1},{concentration0}) at x={value} "
        f"should be finite"
    )