Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,41 @@ def sample(

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
return self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1))
# Use double-where trick to avoid NaN gradients at boundary conditions
# when concentration parameters equal 1 (following TF Probability approach).
# Reference: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
#
# The key insight is to mask extreme values BEFORE computation, so gradients
# flow through the safe path. The forward pass automatically gets the right
# answer because xlogy(0, 0) = 0.

# Step 1: Identify boundary values (0 or 1)
is_boundary = (value == 0.0) | (value == 1.0)

# Step 2: Inner where - mask boundary values to safe canonical value (0.5)
# This ensures log(0) never appears in the gradient computation path
safe_value = jnp.where(is_boundary, 0.5, value)

# Step 3: Compute log_prob with safe values (gradients flow through here)
safe_log_prob = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we use self.dirichlet.log_prob here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tried it in c6f113b

xlogy(self.concentration1 - 1.0, safe_value)
+ xlogy(self.concentration0 - 1.0, 1.0 - safe_value)
- betaln(self.concentration1, self.concentration0)
)

# Step 4: Compute correct forward-pass value at boundaries
# Use stop_gradient to prevent gradients from flowing through this branch
# xlogy(0, 0) = 0 gives the correct value when concentration=1 at boundaries
boundary_log_prob = jax.lax.stop_gradient(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about setting this to 0.0 instead of using stop_gradient?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not make it work without stop_gradient :(

xlogy(self.concentration1 - 1.0, value)
+ xlogy(self.concentration0 - 1.0, 1.0 - value)
- betaln(self.concentration1, self.concentration0)
)

# Step 5: Outer where - select boundary value at boundaries, safe value elsewhere
# Forward pass: uses boundary_log_prob at boundaries (correct value)
# Gradients: come from safe_log_prob (finite, since safe_value avoids log(0))
return jnp.where(is_boundary, boundary_log_prob, safe_log_prob)

@property
def mean(self) -> ArrayLike:
Expand Down
128 changes: 128 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4486,3 +4486,131 @@ def test_interval_censored_validate_sample(
censored_dist.log_prob(value)
else:
censored_dist.log_prob(value) # Should not raise


@pytest.mark.parametrize(
argnames="concentration1,concentration0,value",
argvalues=[
(1.0, 8.0, 0.0),
(8.0, 1.0, 1.0),
],
ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
)
def test_beta_logprob_edge_cases(concentration1, concentration0, value):
"""Test Beta distribution with concentration=1 gives finite log probability at boundary."""
beta_dist = dist.Beta(concentration1, concentration0)
log_prob = beta_dist.log_prob(value)

assert not jnp.isnan(log_prob), (
f"Beta({concentration1},{concentration0}).log_prob({value}) should not be NaN"
)
assert jnp.isfinite(log_prob), (
f"Beta({concentration1},{concentration0}).log_prob({value}) should be finite"
)


def test_beta_logprob_edge_case_consistency_small_values():
"""Test that edge case values are consistent with small deviation values."""
beta_dist = dist.Beta(1.0, 8.0)
beta_dist2 = dist.Beta(8.0, 1.0)

# At boundary
log_prob_at_zero = beta_dist.log_prob(0.0)
log_prob_at_one = beta_dist2.log_prob(1.0)

# Very close to boundary
small_value = 1e-10
log_prob_small = beta_dist.log_prob(small_value)
log_prob_close_to_one = beta_dist2.log_prob(1.0 - small_value)

# Edge case values should be close to small deviation values
assert jnp.abs(log_prob_at_zero - log_prob_small) < 1e-5
assert jnp.abs(log_prob_at_one - log_prob_close_to_one) < 1e-5


def test_beta_logprob_edge_case_non_boundary_values():
"""Test that Beta with concentration=1 still works for non-boundary values."""
beta_dist = dist.Beta(1.0, 8.0)
beta_dist2 = dist.Beta(8.0, 1.0)

assert jnp.isfinite(beta_dist.log_prob(0.5))
assert jnp.isfinite(beta_dist2.log_prob(0.5))


def test_beta_logprob_boundary_non_edge_cases():
"""Test that non-edge cases (concentration > 1) still give -inf at boundaries."""
beta_dist3 = dist.Beta(2.0, 8.0)
beta_dist4 = dist.Beta(8.0, 2.0)

assert jnp.isneginf(beta_dist3.log_prob(0.0))
assert jnp.isneginf(beta_dist4.log_prob(1.0))


@pytest.mark.parametrize(
argnames="concentration1,concentration0,value,grad_param,grad_value",
argvalues=[
(1.0, 8.0, 0.0, "value", 0.0),
(8.0, 1.0, 1.0, "value", 1.0),
(1.0, 8.0, 0.0, "concentration1", 1.0),
(1.0, 8.0, 0.0, "concentration0", 8.0),
(8.0, 1.0, 1.0, "concentration1", 8.0),
(8.0, 1.0, 1.0, "concentration0", 1.0),
],
ids=[
"Beta(1,8) at x=0",
"Beta(8,1) at x=1",
"Beta(1,8) at concentration1=1",
"Beta(1,8) at concentration0=8",
"Beta(8,1) at concentration1=8",
"Beta(8,1) at concentration0=1",
],
)
def test_beta_gradient_edge_cases_single_param(
concentration1, concentration0, value, grad_param, grad_value
):
"""Test that gradients w.r.t. individual parameters are finite at edge cases."""
if grad_param == "value":

def log_prob_fn(x):
return dist.Beta(concentration1, concentration0).log_prob(x)

grad = jax.grad(log_prob_fn)(value)
elif grad_param == "concentration1":

def log_prob_fn(c1):
return dist.Beta(c1, concentration0).log_prob(value)

grad = jax.grad(log_prob_fn)(grad_value)
else: # concentration0

def log_prob_fn(c0):
return dist.Beta(concentration1, c0).log_prob(value)

grad = jax.grad(log_prob_fn)(grad_value)

assert jnp.isfinite(grad), (
f"Gradient w.r.t. {grad_param} for Beta({concentration1},{concentration0}) "
f"at x={value} should be finite"
)


@pytest.mark.parametrize(
argnames="concentration1,concentration0,value",
argvalues=[
(1.0, 8.0, 0.0),
(8.0, 1.0, 1.0),
],
ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
)
def test_beta_gradient_edge_cases_all_params(concentration1, concentration0, value):
"""Test that all gradients are finite when computed simultaneously at edge cases."""

def log_prob_fn(params):
c1, c0, v = params
return dist.Beta(c1, c0).log_prob(v)

grads = jax.grad(log_prob_fn)(jnp.array([concentration1, concentration0, value]))
assert jnp.all(jnp.isfinite(grads)), (
f"All gradients for Beta({concentration1},{concentration0}) at x={value} "
f"should be finite"
)