Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KM time sampler #81

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,9 @@ cython_debug/
doc/generated/


# Remove files auto-generated from mac
*.DS_Store

# This dataset should not be redistributed, because users have to sign an agreement.
hazardous/data/seer_cancer_cardio_raw_data.txt
hazardous/data/*.txt
5 changes: 3 additions & 2 deletions examples/plot_01_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@
# "duration". This allows SurvivalBoost to estimate the survival function :math:`S`.
from hazardous import SurvivalBoost

survival_boost = SurvivalBoost(show_progressbar=False).fit(X_train, y_train)

survival_boost = SurvivalBoost(show_progressbar=False, time_sampler="uniform").fit(
X_train, y_train
)
survival_boost

# %%
Expand Down
29 changes: 6 additions & 23 deletions hazardous/_ipcw.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
from lifelines import KaplanMeierFitter
from scipy.interpolate import interp1d
from sklearn.base import clone
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.utils.validation import check_is_fitted

from ._km_sampler import _KaplanMeierSampler
from .utils import check_y_survival


Expand Down Expand Up @@ -83,31 +82,15 @@ def fit(self, y, X=None):
event, duration = check_y_survival(y)
censoring = event == 0

km = KaplanMeierFitter()
km.fit(
durations=duration,
event_observed=censoring,
self.kaplan_meier_sampler_ = _KaplanMeierSampler().fit(
dict(event=censoring, duration=duration)
)

df = km.survival_function_
self.unique_times_ = df.index
self.censoring_survival_probs_ = df.values[:, 0]

min_censoring_prob = self.censoring_survival_probs_[
self.censoring_survival_probs_ > 0
].min()

self.min_censoring_prob_ = max(
min_censoring_prob,
self.kaplan_meier_sampler_.min_positive_survival_prob_,
self.epsilon_censoring_prob,
)
self.censoring_survival_func_ = interp1d(
self.unique_times_,
self.censoring_survival_probs_,
kind="previous",
bounds_error=False,
fill_value="extrapolate",
)

return self

def compute_ipcw_at(self, times, X=None, ipcw_training=False):
Expand Down Expand Up @@ -160,7 +143,7 @@ def compute_censoring_survival_proba(self, times, X=None, ipcw_training=False):
ipcw : np.ndarray of shape (n_times,)
The IPCW for times
"""
return self.censoring_survival_func_(times)
return self.kaplan_meier_sampler_.survival_func_(times)


class AlternatingCensoringEstimator(KaplanMeierIPCW):
Expand Down
107 changes: 107 additions & 0 deletions hazardous/_km_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from lifelines import KaplanMeierFitter
from scipy.interpolate import interp1d

from .utils import check_y_survival


class _KaplanMeierSampler:
    """Marginal Kaplan-Meier survival function and its inverse.

    This helper fits a Kaplan-Meier estimator on an ``(event, duration)``
    target and exposes two step functions built with
    ``scipy.interpolate.interp1d(kind="previous")``:

    - ``survival_func_`` maps a time to the estimated survival probability;
    - ``inverse_surv_func_`` maps a survival probability (a quantile) back to
      a time, which makes it possible to sample times by drawing quantiles.

    The estimator is marginal: it never looks at covariates.

    What "survival" means depends on the event indicator given to
    :meth:`fit`. Passing the negated event indicator (1 for censoring, 0 for
    any event) yields the survival function of the censoring distribution.

    Attributes
    ----------
    unique_times_ : array-like of shape (n_unique_times,)
        Index of the survival function table computed by lifelines, i.e. the
        times at which the survival function is tabulated.

    survival_probs_ : ndarray of shape (n_unique_times,)
        The estimated survival probabilities at ``unique_times_``.

    survival_func_ : callable
        Step function ("previous" interpolation) from times to survival
        probabilities, defined with ``unique_times_`` (x) and
        ``survival_probs_`` (y).

    inverse_surv_func_ : callable
        Step function ("previous" interpolation) from survival probabilities
        to times, defined with ``survival_probs_`` (x) and ``unique_times_``
        (y).

    min_survival_prob_ : float
        The smallest value of ``survival_probs_`` (may be 0).

    min_positive_survival_prob_ : float
        The smallest strictly positive value of ``survival_probs_``.
    """

    def fit(self, y):
        """Fit the Kaplan-Meier estimator and build the lookup functions.

        In addition to running the Kaplan-Meier estimator, this method fits
        the interpolation functions needed to evaluate the survival function
        at any time and to map a survival probability back to a time.

        Parameters
        ----------
        y : array-like of shape (n_samples, 2)
            The target data, in any format accepted by ``check_y_survival``
            (for instance a mapping with "event" and "duration" entries).

        Returns
        -------
        self : object
            Fitted estimator.
        """
        event, duration = check_y_survival(y)

        km = KaplanMeierFitter()
        km.fit(
            durations=duration,
            event_observed=event,
        )

        df = km.survival_function_
        self.unique_times_ = df.index
        self.survival_probs_ = df.values[:, 0]

        # time -> survival probability
        self.survival_func_ = interp1d(
            x=self.unique_times_,
            y=self.survival_probs_,
            kind="previous",
            bounds_error=False,
            fill_value="extrapolate",
        )

        # survival probability -> time (x and y swapped)
        self.inverse_surv_func_ = interp1d(
            x=self.survival_probs_,
            y=self.unique_times_,
            kind="previous",
            bounds_error=False,
            fill_value="extrapolate",
        )

        self.min_survival_prob_ = self.survival_probs_.min()
        # Smallest non-zero probability: a safe lower bound for callers that
        # need to divide by a survival probability.
        self.min_positive_survival_prob_ = self.survival_probs_[
            self.survival_probs_ > 0
        ].min()

        return self
69 changes: 60 additions & 9 deletions hazardous/_survival_boost.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from numbers import Real

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.base import BaseEstimator, ClassifierMixin, check_is_fitted
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.utils.validation import check_array, check_random_state
from tqdm import tqdm

from ._ipcw import AlternatingCensoringEstimator, KaplanMeierIPCW
from ._km_sampler import _KaplanMeierSampler
from .metrics._brier_score import (
IncidenceScoreComputer,
integrated_brier_score_incidence,
Expand All @@ -15,6 +16,43 @@
from .utils import check_y_survival


class _TimeSampler:
def __init__(self, rng):
self.rng = rng


class _TimeSamplerUniform(_TimeSampler):
    """Draw time horizons uniformly between ``t_min`` and the largest duration."""

    # The lower bound is pinned to 0 even if that duration is never observed
    # in the training set, because we want to make sure that the model learns
    # to predict a 0 incidence at t=0.
    t_min = 0.0

    def fit(self, y):
        """Record the largest observed duration of ``y`` as ``t_max_``."""
        _event, observed_durations = check_y_survival(y)
        self.t_max_ = observed_durations.max()
        return self

    def sample(self, size):
        """Return ``size`` draws from ``rng.uniform(t_min, t_max_)``."""
        check_is_fitted(self, "t_max_")
        lower, upper = self.t_min, self.t_max_
        return self.rng.uniform(lower, upper, size)


class _TimeSamplerKM(_TimeSampler):
    """Draw time horizons by inverting a Kaplan-Meier survival function.

    Quantiles are drawn uniformly between ``q_min_`` and ``q_max`` and mapped
    to times with the inverse survival function of the fitted
    ``_KaplanMeierSampler``.
    """

    q_max = 1.0

    def fit(self, y):
        """Fit the Kaplan-Meier sampler on ``y`` and set the quantile range."""
        fitted_km = _KaplanMeierSampler().fit(y)
        self.km_sampler_ = fitted_km
        # The estimated survival probabilities may not reach 0 (residual
        # survival mass), so the smallest quantile that can be drawn is the
        # minimum estimated probability.
        self.q_min_ = fitted_km.min_survival_prob_
        return self

    def sample(self, size):
        """Return ``size`` times obtained from uniformly drawn quantiles."""
        check_is_fitted(self, ["q_min_", "km_sampler_"])
        drawn_quantiles = self.rng.uniform(self.q_min_, self.q_max, size)
        to_time = self.km_sampler_.inverse_surv_func_
        return to_time(drawn_quantiles)


class WeightedMultiClassTargetSampler(IncidenceScoreComputer):
"""Weighted targets for censoring-adjusted incidence estimation.

Expand Down Expand Up @@ -57,8 +95,9 @@ class WeightedMultiClassTargetSampler(IncidenceScoreComputer):
def __init__(
self,
y_train,
hard_zero_fraction=0.01,
time_sampler="kaplan-meier",
ipcw_estimator=None,
hard_zero_fraction=0.1,
n_iter_before_feedback=20,
random_state=None,
):
Expand All @@ -73,18 +112,14 @@ def __init__(
# Precompute the censoring probabilities at the time of the events on the
# training set:
self.ipcw_train = self.ipcw_estimator.compute_ipcw_at(self.duration_train)
self._init_time_sampler(y_train, time_sampler)

def draw(self, ipcw_training=False, X=None):
# Sample time horizons uniformly on the observed time range:
observation_durations = self.duration_train
n_samples = observation_durations.shape[0]

# Sample from t_min=0 event if never observed in the training set
# because we want to make sure that the model learns to predict a 0
# incidence at t=0.
t_min = 0.0
t_max = observation_durations.max()
sampled_time_horizons = self.rng.uniform(t_min, t_max, n_samples)
sampled_time_horizons = self.time_sampler.sample(n_samples)

# Add some hard zeros to make sure that the model learns to
# predict 0 incidence at t=0.
Expand Down Expand Up @@ -145,6 +180,7 @@ def draw(self, ipcw_training=False, X=None):
return sampled_time_horizons.reshape(-1, 1), y_targets, sample_weight

def fit(self, X):
"""Fit the IPCW estimator."""
self.inv_any_survival_train = self.ipcw_estimator.compute_ipcw_at(
self.duration_train, ipcw_training=True, X=X
)
Expand All @@ -160,13 +196,24 @@ def fit(self, X):
times=sampled_time_horizons,
sample_weight=sample_weight,
)

self.ipcw_train = self.ipcw_estimator.compute_ipcw_at(
self.duration_train,
ipcw_training=False,
X=X,
)

def _init_time_sampler(self, y, time_sampler):
    """Create the time sampler named by ``time_sampler`` and fit it on ``y``.

    ``"uniform"`` selects ``_TimeSamplerUniform`` and ``"kaplan-meier"``
    selects ``_TimeSamplerKM``; any other value raises a ``ValueError``.
    The sampler is built with ``self.rng`` and stored as
    ``self.time_sampler`` before being fitted.
    """
    if time_sampler == "uniform":
        sampler_cls = _TimeSamplerUniform
    elif time_sampler == "kaplan-meier":
        sampler_cls = _TimeSamplerKM
    else:
        raise ValueError(
            "time_sampler options are 'uniform' and 'kaplan-meier', "
            f"but got {time_sampler}."
        )

    self.time_sampler = sampler_cls(self.rng)
    self.time_sampler.fit(y)


class SurvivalBoost(BaseEstimator, ClassifierMixin):
r"""Cause-specific Cumulative Incidence Function (CIF) with GBDT [1]_.
Expand Down Expand Up @@ -333,6 +380,7 @@ def __init__(
n_iter_before_feedback=20,
random_state=None,
n_horizons_per_observation=3,
time_sampler="kaplan-meier",
):
self.hard_zero_fraction = hard_zero_fraction
self.n_iter = n_iter
Expand All @@ -347,6 +395,7 @@ def __init__(
self.ipcw_strategy = ipcw_strategy
self.random_state = random_state
self.n_horizons_per_observation = n_horizons_per_observation
self.time_sampler = time_sampler

def fit(self, X, y, times=None):
"""Fit the model.
Expand Down Expand Up @@ -419,6 +468,7 @@ def fit(self, X, y, times=None):
random_state=self.random_state,
ipcw_estimator=ipcw_estimator,
n_iter_before_feedback=self.n_iter_before_feedback,
time_sampler=self.time_sampler,
)

iterator = range(self.n_iter)
Expand Down Expand Up @@ -582,6 +632,7 @@ def _build_base_estimator(self):
max_leaf_nodes=self.max_leaf_nodes,
max_depth=self.max_depth,
min_samples_leaf=self.min_samples_leaf,
random_state=self.random_state,
)

def score(self, X, y):
Expand Down
14 changes: 14 additions & 0 deletions hazardous/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,17 @@ def check_event_of_interest(k):
f"got: event_of_interest={k}"
)
return


def make_time_grid(duration, n_steps=20):
    """Return evenly spaced times spanning the range of ``duration``.

    Parameters
    ----------
    duration : array-like of shape (n_samples,)
        Durations whose minimum and maximum bound the grid. Any input
        accepted by ``np.asarray`` works (list, ndarray, pandas Series...).

    n_steps : int, default=20
        Number of points of the grid.

    Returns
    -------
    time_grid : ndarray of shape (n_steps,)
        ``n_steps`` values linearly spaced from ``duration.min()`` to
        ``duration.max()``, both included.
    """
    # Coerce first so that plain Python sequences, which have no
    # ``min``/``max`` methods, are accepted too.
    duration = np.asarray(duration)
    t_min, t_max = duration.min(), duration.max()
    return np.linspace(t_min, t_max, n_steps)


def make_recarray(y):
    """Convert a survival target dataframe into a structured array.

    Parameters
    ----------
    y : pandas.DataFrame of shape (n_samples, 2)
        Target with an "event" column and a "duration" column.

    Returns
    -------
    y_struct : ndarray of shape (n_samples,)
        Structured array with a boolean field "e" (the event column cast to
        bool, so any non-zero event id becomes True) and a float field "t"
        (the duration column).
    """
    event = y["event"].values
    duration = y["duration"].values

    # Fill the fields column-wise instead of building one Python tuple per
    # row: same result, without a Python-level loop over the samples.
    y_struct = np.empty(y.shape[0], dtype=[("e", bool), ("t", float)])
    y_struct["e"] = event
    y_struct["t"] = duration
    return y_struct
Loading