Skip to content

Commit d49f718

Browse files
vanAmsterdamvanamsterdam
andauthored
Implement IntervalCensoredDistribution (#2090)
* start implementing intervalcensored * minor updates, passing non-interval specific tests * add logp test for intervalcensored * allow exact observations in intervalcensored; update tests for intervalcensored * update docs * implement interval censoring validate_sample * use log1mexp for numerical stability --------- Co-authored-by: vanamsterdam <[email protected]>
1 parent b5b68f9 commit d49f718

File tree

4 files changed

+625
-38
lines changed

4 files changed

+625
-38
lines changed

docs/source/distributions.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -766,21 +766,28 @@ Censored Distributions
766766
-----------------------
767767

768768
LeftCensoredDistribution
769-
^^^^^^^^^^^^^^^^^^^^^^^^^
769+
^^^^^^^^^^^^^^^^^^^^^^^^
770770
.. autoclass:: numpyro.distributions.censored.LeftCensoredDistribution
771771
:members:
772772
:undoc-members:
773773
:show-inheritance:
774774
:member-order: bysource
775775

776776
RightCensoredDistribution
777-
^^^^^^^^^^^^^^^^^^^^^^^^^^
777+
^^^^^^^^^^^^^^^^^^^^^^^^^
778778
.. autoclass:: numpyro.distributions.censored.RightCensoredDistribution
779779
:members:
780780
:undoc-members:
781781
:show-inheritance:
782782
:member-order: bysource
783783

784+
IntervalCensoredDistribution
785+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
786+
.. autoclass:: numpyro.distributions.censored.IntervalCensoredDistribution
787+
:members:
788+
:undoc-members:
789+
:show-inheritance:
790+
:member-order: bysource
784791

785792
TensorFlow Distributions
786793
------------------------

numpyro/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from numpyro.distributions.censored import (
5+
IntervalCensoredDistribution,
56
LeftCensoredDistribution,
67
RightCensoredDistribution,
78
)
@@ -200,6 +201,7 @@
200201
"RightTruncatedDistribution",
201202
"LeftCensoredDistribution",
202203
"RightCensoredDistribution",
204+
"IntervalCensoredDistribution",
203205
"SineBivariateVonMises",
204206
"SineSkewed",
205207
"SoftLaplace",

numpyro/distributions/censored.py

Lines changed: 248 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44

55
from typing import Optional
6+
import warnings
7+
8+
import numpy as np
69

710
import jax
811
from jax import lax
@@ -12,10 +15,8 @@
1215
from numpyro._typing import ConstraintT, DistributionT
1316
from numpyro.distributions import constraints
1417
from numpyro.distributions.distribution import Distribution
15-
from numpyro.distributions.util import (
16-
promote_shapes,
17-
validate_sample,
18-
)
18+
from numpyro.distributions.util import log1mexp, promote_shapes, validate_sample
19+
from numpyro.util import find_stack_level, not_jax_tracer
1920

2021

2122
class LeftCensoredDistribution(Distribution):
@@ -249,3 +250,246 @@ def log_survival_censored(x):
249250
log_survival_censored(value), # censored observations: log S(t)
250251
self.base_dist.log_prob(value), # observed values: log f(t)
251252
)
253+
254+
255+
class IntervalCensoredDistribution(Distribution):
256+
r"""
257+
Distribution wrapper for interval-censored outcomes.
258+
259+
This distribution augments a base distribution with interval censoring,
260+
so that the likelihood contribution depends on whether the observation is
261+
exactly observed,
262+
left-censored, right-censored, interval-censored, or doubly-censored
263+
(i.e., known to lie outside the observed interval).
264+
265+
:param base_dist: Parametric distribution for the *uncensored* values
266+
(e.g., Exponential, Weibull, LogNormal, Normal, etc.).
267+
This distribution must implement a ``cdf`` method.
268+
:type base_dist: numpyro.distributions.Distribution
269+
:param left_censored: Indicator per observation:
270+
1 → observation is left-censored at the reported upper bound
271+
0 → not left-censored
272+
:type left_censored: array-like of {0,1}
273+
:param right_censored: Indicator per observation:
274+
1 → observation is right-censored at the reported lower bound
275+
0 → not right-censored
276+
:type right_censored: array-like of {0,1}
277+
278+
.. note::
279+
The ``log_prob(value)`` method expects ``value`` to be a two-dimensional array
280+
of shape ``(batch_size, 2)``, where each row is ``(lower, upper)``.
281+
The contribution to the log-likelihood is determined as follows:
282+
283+
log F(upper) if left_censored == 1 and right_censored == 0
284+
log (1 - F(lower)) if right_censored == 1 and left_censored == 0
285+
log (F(upper) - F(lower)) if both == 0 (interval-censored)
286+
log (1 - (F(upper) - F(lower))) if both == 1 (doubly-censored)
287+
log f(value) if lower ≈ upper (point interval)
288+
289+
where f is the density and F the cumulative distribution function of ``base_dist``.
290+
291+
This is commonly used in survival analysis, where event times are positive,
292+
but the approach is general and can be applied to any distribution
293+
with a cumulative distribution function, regardless of support.
294+
295+
In R's ``survival`` package notation, this corresponds to
296+
``Surv(l, r, type = 'interval2')``.
297+
298+
Example:
299+
300+
Surv(l = c(2, 4, 6), r = c(5, Inf, 9), type = 'interval2')
301+
302+
means:
303+
304+
subject 1 had an event in (2, 5]
305+
subject 2 was right-censored at 4
306+
subject 3 had an event in (6, 9]
307+
308+
**Example:**
309+
310+
.. doctest::
311+
312+
>>> from jax import numpy as jnp
313+
>>> from numpyro import distributions as dist
314+
>>> base = dist.Weibull(concentration=2.0, scale=3.0)
315+
>>> left_censored = jnp.array([0, 0, 0])
316+
>>> right_censored = jnp.array([0, 1, 0])
317+
>>> surv_dist = dist.IntervalCensoredDistribution(base, left_censored, right_censored)
318+
>>> values = jnp.array([
319+
... [2.0, 5.0],
320+
... [4.0, jnp.inf],
321+
... [6.0, 9.0],
322+
... ])
323+
>>> loglik = surv_dist.log_prob(values)
324+
"""
325+
326+
arg_constraints = {
327+
"left_censored": constraints.boolean,
328+
"right_censored": constraints.boolean,
329+
}
330+
pytree_data_fields = ("base_dist", "left_censored", "right_censored", "_support")
331+
332+
def __init__(
333+
self,
334+
base_dist: DistributionT,
335+
left_censored: ArrayLike,
336+
right_censored: ArrayLike,
337+
*,
338+
validate_args: Optional[bool] = None,
339+
):
340+
# Optionally test that cdf actually works (in validate_args mode)
341+
if validate_args:
342+
try:
343+
test_val = base_dist.support.feasible_like(jnp.array(0.0))
344+
_ = base_dist.cdf(test_val)
345+
except (NotImplementedError, AttributeError) as e:
346+
raise TypeError(
347+
f"{type(base_dist).__name__}.cdf() is not properly implemented."
348+
) from e
349+
batch_shape = lax.broadcast_shapes(
350+
base_dist.batch_shape, jnp.shape(left_censored), jnp.shape(right_censored)
351+
)
352+
self.base_dist = jax.tree.map(
353+
lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
354+
)
355+
self.left_censored = jnp.array(
356+
promote_shapes(left_censored, shape=batch_shape)[0], dtype=jnp.bool
357+
)
358+
self.right_censored = jnp.array(
359+
promote_shapes(right_censored, shape=batch_shape)[0], dtype=jnp.bool
360+
)
361+
self._support = base_dist.support
362+
super().__init__(batch_shape, event_shape=(2,), validate_args=validate_args)
363+
364+
def sample(
365+
self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = ()
366+
) -> ArrayLike:
367+
return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)
368+
369+
@constraints.dependent_property(is_discrete=False, event_dim=1)
370+
def support(self) -> ConstraintT:
371+
return self._support
372+
373+
def _get_censoring_masks(self, value):
374+
"""Helper to get censoring masks."""
375+
376+
x1 = jnp.take(value, 0, axis=-1) # left bound
377+
x2 = jnp.take(value, 1, axis=-1) # right bound
378+
379+
m_left = self.left_censored & (~self.right_censored) # left-censored only
380+
m_right = self.right_censored & (~self.left_censored) # right-censored only
381+
m_int = (~self.left_censored) & (~self.right_censored) # interval censored
382+
m_double = self.left_censored & self.right_censored # doubly censored
383+
m_point = jnp.isclose(x1, x2) & m_int # point observation
384+
m_int = m_int & (~m_point) # update interval mask to exclude point obs
385+
return m_left, m_right, m_int, m_double, m_point
386+
387+
@validate_sample
388+
def log_prob(self, value):
389+
dtype = jnp.result_type(value, float)
390+
minval = 100.0 * jnp.finfo(dtype).tiny # for values close to 0
391+
eps = jnp.finfo(dtype).eps # otherwise
392+
393+
x1 = jnp.take(value, 0, axis=-1) # left bound
394+
x2 = jnp.take(value, 1, axis=-1) # right bound
395+
396+
# make masks based on censoring indicators
397+
m_left, m_right, m_int, m_double, m_point = self._get_censoring_masks(value)
398+
399+
# Replace potential out-of-support values with finite placeholder BEFORE cdf
400+
# (value doesn't matter; it will be overwritten)
401+
feasible_value = self.support.feasible_like(x1)
402+
x1_finite = jnp.where(m_left, feasible_value, x1)
403+
x2_finite = jnp.where(m_right, feasible_value, x2)
404+
405+
# Calculate CDF on safe values
406+
F1_tmp = self.base_dist.cdf(x1_finite)
407+
F2_tmp = self.base_dist.cdf(x2_finite)
408+
409+
# Overwrite with correct limit values on censored rows
410+
# Left-censored: F1 := 0
411+
F1 = jnp.where(m_left, 0.0, F1_tmp)
412+
# Right-censored: F2 := 1
413+
F2 = jnp.where(m_right, 1.0, F2_tmp)
414+
415+
# Stabilize against log(0) and tiny intervals
416+
F1 = jnp.clip(F1, minval, 1.0 - eps)
417+
F2 = jnp.clip(F2, minval, 1.0 - eps)
418+
419+
# Use a stable log-diff for intervals (also covers left/right cases)
420+
# log(F2 - F1) = logF2 + log1p(-exp(logF1 - logF2))
421+
logF1 = jnp.log(F1)
422+
logF2 = jnp.log(F2)
423+
424+
lp_interval = logF2 + jnp.log1p(-jnp.exp(jnp.clip(logF1 - logF2, max=-minval)))
425+
# handle point intervals (x1 == x2) by returning log density instead of log prob
426+
lp_interval = jnp.where(m_point, self.base_dist.log_prob(x1), lp_interval)
427+
428+
# for doubly censored data, the value is not in the interval, so computation is 1 - exp(lp_interval)
429+
lp_double = log1mexp(lp_interval)
430+
431+
# Select the right expression per row
432+
# left: log F(x2)
433+
lp_left = logF2
434+
# right: log (1 - F(x1)) = log1p(-F1)
435+
lp_right = jnp.log1p(-F1)
436+
437+
logp = jnp.zeros_like(logF1)
438+
logp = jnp.where(m_left, lp_left, logp)
439+
logp = jnp.where(m_right, lp_right, logp)
440+
logp = jnp.where(m_int, lp_interval, logp)
441+
logp = jnp.where(m_double, lp_double, logp)
442+
return logp
443+
444+
def _validate_sample(self, value: ArrayLike) -> None:
445+
if value.shape[-1] != 2:
446+
raise ValueError(
447+
f"Expected last dimension of `value` to be 2 (lower, upper), but got shape {value.shape}"
448+
)
449+
x1 = jnp.take(value, 0, axis=-1) # left bound
450+
x2 = jnp.take(value, 1, axis=-1) # right bound
451+
m_left, m_right, m_int, m_double, m_point = self._get_censoring_masks(value)
452+
453+
# check validity under base_dist of x1 and x2
454+
with warnings.catch_warnings():
455+
warnings.simplefilter("ignore")
456+
x1_mask = self.base_dist._validate_sample(x1)
457+
x2_mask = self.base_dist._validate_sample(x2)
458+
459+
mask = jnp.ones_like(x1, dtype=jnp.bool)
460+
# for left-censored, the upper bound must be in the support of base_dist
461+
mask = jnp.where(m_left, x2_mask, mask)
462+
if not_jax_tracer(mask):
463+
if not np.all(mask):
464+
warnings.warn(
465+
"For left-censored observations, upper bound should be within the support of base_dist. ",
466+
stacklevel=find_stack_level(),
467+
)
468+
469+
# for right-censored, the lower bound must be in the support of base_dist
470+
mask = jnp.where(m_right, x1_mask, mask)
471+
if not_jax_tracer(mask):
472+
if not np.all(mask):
473+
warnings.warn(
474+
"For right-censored observations, lower bound should be within the support of base_dist. ",
475+
stacklevel=find_stack_level(),
476+
)
477+
# for interval-censored, doubly censored and point, both bounds must be in the support of base_dist
478+
mask = jnp.where(m_int | m_double | m_point, x1_mask & x2_mask, mask)
479+
if not_jax_tracer(mask):
480+
if not np.all(mask):
481+
warnings.warn(
482+
"For interval-censored, doubly-censored, or exact observations,"
483+
"lower bound should be within the support of base_dist. ",
484+
stacklevel=find_stack_level(),
485+
)
486+
# for interval-censored and doubly-censored, upper bound must be > lower bound
487+
mask = jnp.where(m_int | m_double, mask & (x2 > x1), mask)
488+
if not_jax_tracer(mask):
489+
if not np.all(mask):
490+
warnings.warn(
491+
"For interval-censored and doubly-censored observations,"
492+
"upper bound should greater than lower bound. ",
493+
stacklevel=find_stack_level(),
494+
)
495+
return mask

0 commit comments

Comments
 (0)