|
3 | 3 |
|
4 | 4 |
|
5 | 5 | from typing import Optional |
| 6 | +import warnings |
| 7 | + |
| 8 | +import numpy as np |
6 | 9 |
|
7 | 10 | import jax |
8 | 11 | from jax import lax |
|
12 | 15 | from numpyro._typing import ConstraintT, DistributionT |
13 | 16 | from numpyro.distributions import constraints |
14 | 17 | from numpyro.distributions.distribution import Distribution |
15 | | -from numpyro.distributions.util import ( |
16 | | - promote_shapes, |
17 | | - validate_sample, |
18 | | -) |
| 18 | +from numpyro.distributions.util import log1mexp, promote_shapes, validate_sample |
| 19 | +from numpyro.util import find_stack_level, not_jax_tracer |
19 | 20 |
|
20 | 21 |
|
21 | 22 | class LeftCensoredDistribution(Distribution): |
@@ -249,3 +250,246 @@ def log_survival_censored(x): |
249 | 250 | log_survival_censored(value), # censored observations: log S(t) |
250 | 251 | self.base_dist.log_prob(value), # observed values: log f(t) |
251 | 252 | ) |
| 253 | + |
| 254 | + |
class IntervalCensoredDistribution(Distribution):
    r"""
    Distribution wrapper for interval-censored outcomes.

    This distribution augments a base distribution with interval censoring,
    so that the likelihood contribution depends on whether the observation is
    exactly observed, left-censored, right-censored, interval-censored, or
    doubly-censored (i.e., known to lie outside the observed interval).

    :param base_dist: Parametric distribution for the *uncensored* values
        (e.g., Exponential, Weibull, LogNormal, Normal, etc.).
        This distribution must implement a ``cdf`` method.
    :type base_dist: numpyro.distributions.Distribution
    :param left_censored: Indicator per observation:
        1 → observation is left-censored at the reported upper bound
        0 → not left-censored
    :type left_censored: array-like of {0,1}
    :param right_censored: Indicator per observation:
        1 → observation is right-censored at the reported lower bound
        0 → not right-censored
    :type right_censored: array-like of {0,1}

    .. note::
        The ``log_prob(value)`` method expects the last axis of ``value`` to
        have size 2, e.g. an array of shape ``(batch_size, 2)`` where each row
        is ``(lower, upper)``.
        The contribution to the log-likelihood is determined as follows::

            log F(upper)                    if left_censored == 1 and right_censored == 0
            log (1 - F(lower))              if right_censored == 1 and left_censored == 0
            log (F(upper) - F(lower))       if both == 0 and lower < upper (interval-censored)
            log (1 - (F(upper) - F(lower))) if both == 1 (doubly-censored)
            log f(lower)                    if both == 0 and lower ≈ upper (point interval)

        where f is the density and F the cumulative distribution function of ``base_dist``.

    This is commonly used in survival analysis, where event times are positive,
    but the approach is general and can be applied to any distribution
    with a cumulative distribution function, regardless of support.

    In R's ``survival`` package notation, this corresponds to
    ``Surv(l, r, type = 'interval2')``.

    Example::

        Surv(l = c(2, 4, 6), r = c(5, Inf, 9), type = 'interval2')

    means::

        subject 1 had an event in (2, 5]
        subject 2 was right-censored at 4
        subject 3 had an event in (6, 9]

    **Example:**

    .. doctest::

        >>> from jax import numpy as jnp
        >>> from numpyro import distributions as dist
        >>> base = dist.Weibull(concentration=2.0, scale=3.0)
        >>> left_censored = jnp.array([0, 0, 0])
        >>> right_censored = jnp.array([0, 1, 0])
        >>> surv_dist = dist.IntervalCensoredDistribution(base, left_censored, right_censored)
        >>> values = jnp.array([
        ...     [2.0, 5.0],
        ...     [4.0, jnp.inf],
        ...     [6.0, 9.0],
        ... ])
        >>> loglik = surv_dist.log_prob(values)
    """

    arg_constraints = {
        "left_censored": constraints.boolean,
        "right_censored": constraints.boolean,
    }
    pytree_data_fields = ("base_dist", "left_censored", "right_censored", "_support")

    def __init__(
        self,
        base_dist: DistributionT,
        left_censored: ArrayLike,
        right_censored: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        """
        Broadcast the base distribution and both censoring indicators to a
        common batch shape and register an event shape of ``(2,)``.

        :param base_dist: distribution of the uncensored values.
        :param left_censored: left-censoring indicators (cast to boolean).
        :param right_censored: right-censoring indicators (cast to boolean).
        :param validate_args: when truthy, ``base_dist.cdf`` is probed once at
            a feasible point before anything else is set up.
        :raises TypeError: if that probe raises ``NotImplementedError`` or
            ``AttributeError``.
        """
        # Optionally test that cdf actually works (in validate_args mode)
        if validate_args:
            try:
                test_val = base_dist.support.feasible_like(jnp.array(0.0))
                _ = base_dist.cdf(test_val)
            except (NotImplementedError, AttributeError) as e:
                raise TypeError(
                    f"{type(base_dist).__name__}.cdf() is not properly implemented."
                ) from e
        batch_shape = lax.broadcast_shapes(
            base_dist.batch_shape, jnp.shape(left_censored), jnp.shape(right_censored)
        )
        # Promote every leaf of the base distribution to the joint batch shape.
        self.base_dist = jax.tree.map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        self.left_censored = jnp.array(
            promote_shapes(left_censored, shape=batch_shape)[0], dtype=jnp.bool
        )
        self.right_censored = jnp.array(
            promote_shapes(right_censored, shape=batch_shape)[0], dtype=jnp.bool
        )
        self._support = base_dist.support
        super().__init__(batch_shape, event_shape=(2,), validate_args=validate_args)

    def sample(
        self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
        """
        Draw uncensored values from the base distribution expanded to this
        distribution's batch shape.

        NOTE(review): the draws come straight from ``base_dist``; they are not
        ``(lower, upper)`` pairs, so they presumably lack the trailing axis of
        size 2 that ``log_prob`` expects. Confirm this is intended.
        """
        return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)

    @constraints.dependent_property(is_discrete=False, event_dim=1)
    def support(self) -> ConstraintT:
        """Support of the base distribution, captured at construction time."""
        return self._support

    def _get_censoring_masks(self, value):
        """
        Split the observations into five mutually exclusive categories.

        :param value: array whose last axis holds ``(lower, upper)``.
        :return: boolean masks ``(m_left, m_right, m_int, m_double, m_point)``.
            ``m_int`` excludes the rows flagged by ``m_point``, so each row
            belongs to exactly one mask.
        """
        x1 = jnp.take(value, 0, axis=-1)  # left bound
        x2 = jnp.take(value, 1, axis=-1)  # right bound

        m_left = self.left_censored & (~self.right_censored)  # left-censored only
        m_right = self.right_censored & (~self.left_censored)  # right-censored only
        m_int = (~self.left_censored) & (~self.right_censored)  # interval censored
        m_double = self.left_censored & self.right_censored  # doubly censored
        m_point = jnp.isclose(x1, x2) & m_int  # point observation
        m_int = m_int & (~m_point)  # update interval mask to exclude point obs
        return m_left, m_right, m_int, m_double, m_point

    @validate_sample
    def log_prob(self, value):
        """
        Log-likelihood contribution of each ``(lower, upper)`` pair.

        :param value: array whose last axis holds ``(lower, upper)``.
        :return: per-observation log-likelihood: ``log F(upper)`` for
            left-censored rows, ``log(1 - F(lower))`` for right-censored rows,
            ``log(F(upper) - F(lower))`` for interval-censored rows, the log of
            the complement of that probability for doubly-censored rows, and
            the base log-density for point observations.
        """
        dtype = jnp.result_type(value, float)
        minval = 100.0 * jnp.finfo(dtype).tiny  # for values close to 0
        eps = jnp.finfo(dtype).eps  # otherwise

        x1 = jnp.take(value, 0, axis=-1)  # left bound
        x2 = jnp.take(value, 1, axis=-1)  # right bound

        # make masks based on censoring indicators
        m_left, m_right, m_int, m_double, m_point = self._get_censoring_masks(value)

        # Replace potential out-of-support values with finite placeholder BEFORE cdf
        # (value doesn't matter; it will be overwritten)
        feasible_value = self.support.feasible_like(x1)
        x1_finite = jnp.where(m_left, feasible_value, x1)
        x2_finite = jnp.where(m_right, feasible_value, x2)

        # Evaluate the CDF on the safe values, then substitute the limiting
        # values on censored rows: F1 := 0 (left-censored), F2 := 1 (right-censored).
        F1 = jnp.where(m_left, 0.0, self.base_dist.cdf(x1_finite))
        F2 = jnp.where(m_right, 1.0, self.base_dist.cdf(x2_finite))

        # Stabilize against log(0) and tiny intervals
        F1 = jnp.clip(F1, minval, 1.0 - eps)
        F2 = jnp.clip(F2, minval, 1.0 - eps)

        # Use a stable log-diff for intervals:
        # log(F2 - F1) = logF2 + log1p(-exp(logF1 - logF2))
        logF1 = jnp.log(F1)
        logF2 = jnp.log(F2)
        # NOTE(review): exp(-minval) rounds to 1.0, so rows with F1 >= F2 still
        # evaluate to -inf here; confirm whether a larger margin was intended.
        lp_interval = logF2 + jnp.log1p(-jnp.exp(jnp.clip(logF1 - logF2, max=-minval)))

        # Doubly censored: the value lies outside the interval, so the
        # probability is 1 - exp(lp_interval). This is derived from the interval
        # probability only, so a point-row log-density (which may exceed 0)
        # is never passed to log1mexp.
        lp_double = log1mexp(lp_interval)

        # Point intervals (x1 ≈ x2) contribute the log density rather than a
        # log probability. x1_finite equals x1 on every point row and avoids
        # evaluating the density at placeholder bounds of left-censored rows.
        lp_point = self.base_dist.log_prob(x1_finite)

        # left: log F(x2)
        lp_left = logF2
        # right: log (1 - F(x1)) = log1p(-F1)
        lp_right = jnp.log1p(-F1)

        # Select the right expression per row. The five masks are mutually
        # exclusive, and m_int excludes point rows, so point observations need
        # their own selection step.
        logp = jnp.zeros_like(logF1)
        logp = jnp.where(m_left, lp_left, logp)
        logp = jnp.where(m_right, lp_right, logp)
        logp = jnp.where(m_int, lp_interval, logp)
        logp = jnp.where(m_double, lp_double, logp)
        logp = jnp.where(m_point, lp_point, logp)
        return logp

    def _validate_sample(self, value: ArrayLike) -> ArrayLike:
        """
        Check ``(lower, upper)`` pairs against the base support and ordering.

        Each category of observation is checked independently, so a warning is
        emitted only for the rule that is actually violated.

        :param value: array whose last axis holds ``(lower, upper)``.
        :return: boolean mask that is ``True`` where the observation is valid.
        :raises ValueError: if the last axis of ``value`` does not have size 2.
        """
        shape = jnp.shape(value)
        if not shape or shape[-1] != 2:
            raise ValueError(
                f"Expected last dimension of `value` to be 2 (lower, upper), but got shape {shape}"
            )
        x1 = jnp.take(value, 0, axis=-1)  # left bound
        x2 = jnp.take(value, 1, axis=-1)  # right bound
        m_left, m_right, m_int, m_double, m_point = self._get_censoring_masks(value)

        # Support membership of each bound under base_dist. The base
        # distribution's own warnings are silenced because a bound that is
        # irrelevant for a given row (e.g. the lower bound of a left-censored
        # row) is allowed to fall outside the support.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            x1_mask = self.base_dist._validate_sample(x1)
            x2_mask = self.base_dist._validate_sample(x2)

        needs_both = m_int | m_double | m_point
        needs_order = m_int | m_double

        # Rows outside a rule's category pass that rule trivially.
        checks = (
            (
                # left-censored: the upper bound must be in the support
                jnp.where(m_left, x2_mask, True),
                "For left-censored observations, upper bound should be within the support of base_dist. ",
            ),
            (
                # right-censored: the lower bound must be in the support
                jnp.where(m_right, x1_mask, True),
                "For right-censored observations, lower bound should be within the support of base_dist. ",
            ),
            (
                # interval, doubly-censored and point: both bounds in the support
                jnp.where(needs_both, x1_mask & x2_mask, True),
                "For interval-censored, doubly-censored, or exact observations, "
                "both bounds should be within the support of base_dist. ",
            ),
            (
                # interval and doubly-censored: upper bound must exceed lower bound
                jnp.where(needs_order, x2 > x1, True),
                "For interval-censored and doubly-censored observations, "
                "upper bound should be greater than lower bound. ",
            ),
        )

        mask = jnp.ones_like(x1, dtype=jnp.bool)
        for ok, message in checks:
            # Concrete values can be inspected eagerly; tracers cannot.
            if not_jax_tracer(ok):
                if not np.all(ok):
                    warnings.warn(message, stacklevel=find_stack_level())
            mask = mask & ok
        return mask
0 commit comments