Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split the train set into incidence and censoring #84

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 59 additions & 32 deletions hazardous/_survival_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.utils.validation import check_array, check_random_state
from sklearn.model_selection import train_test_split
from sklearn.utils.validation import check_random_state
from tqdm import tqdm

from ._ipcw import AlternatingCensoringEstimator, KaplanMeierIPCW
Expand All @@ -12,7 +13,7 @@
integrated_brier_score_incidence,
integrated_brier_score_survival,
)
from .utils import check_y_survival
from .utils import check_array, check_y_survival


class WeightedMultiClassTargetSampler(IncidenceScoreComputer):
Expand Down Expand Up @@ -72,18 +73,20 @@ def __init__(
)
# Precompute the censoring probabilities at the time of the events on the
# training set:
self.ipcw_train = self.ipcw_estimator.compute_ipcw_at(self.duration_train)
self.update_ipcw(y=y_train)

def draw(self, y, X=None, ipcw_training=False):
events, durations = check_y_survival(y)
any_events = events > 0

def draw(self, ipcw_training=False, X=None):
# Sample time horizons uniformly on the observed time range:
observation_durations = self.duration_train
n_samples = observation_durations.shape[0]
n_samples = durations.shape[0]

# Sample from t_min=0 even if never observed in the training set
# because we want to make sure that the model learns to predict a 0
# incidence at t=0.
t_min = 0.0
t_max = observation_durations.max()
t_max = durations.max()
sampled_time_horizons = self.rng.uniform(t_min, t_max, n_samples)

# Add some hard zeros to make sure that the model learns to
Expand All @@ -106,17 +109,15 @@ def draw(self, ipcw_training=False, X=None):
# * 0 when an event has happened before the sampled time horizon.
# The sample weight is zero in that case.

if not hasattr(self, "inv_any_survival_train"):
self.inv_any_survival_train = self.ipcw_estimator.compute_ipcw_at(
self.duration_train, ipcw_training=True, X=X
)
if not hasattr(self, "ipiw"):
self.update_ipiw(X=X, y=y)

censored_observations = self.any_event_train == 0
censored_observations = any_events == 0
y_targets, sample_weight = self._weighted_binary_targets(
censored_observations,
observation_durations,
durations,
sampled_time_horizons,
ipcw_y_duration=self.inv_any_survival_train,
ipcw_y_duration=self.ipiw_y_duration,
ipcw_training=True,
X=X,
)
Expand All @@ -133,26 +134,25 @@ def draw(self, ipcw_training=False, X=None):
# than the sampled time horizon. The sample weight is zero in
# that case.
y_binary, sample_weight = self._weighted_binary_targets(
self.any_event_train,
observation_durations,
any_events,
durations,
sampled_time_horizons,
ipcw_y_duration=self.ipcw_train,
ipcw_y_duration=self.ipcw_y_duration,
ipcw_training=False,
X=X,
)
y_targets = y_binary * self.event_train
y_targets = y_binary * events

return sampled_time_horizons.reshape(-1, 1), y_targets, sample_weight

def fit(self, X):
self.inv_any_survival_train = self.ipcw_estimator.compute_ipcw_at(
self.duration_train, ipcw_training=True, X=X
)
def fit(self, X, y):
self.update_ipiw(X=X, y=y)

for _ in range(self.n_iter_before_feedback):
sampled_time_horizons, y_targets, sample_weight = self.draw(
ipcw_training=True,
y,
X=X,
ipcw_training=True,
)
self.ipcw_estimator.fit_censoring_estimator(
X,
Expand All @@ -161,12 +161,22 @@ def fit(self, X):
sample_weight=sample_weight,
)

self.ipcw_train = self.ipcw_estimator.compute_ipcw_at(
self.duration_train,
def update_ipcw(self, y, X=None):
    """Refresh the IPCW values at the observed durations of ``y``.

    Evaluates the censoring estimator at each duration in ``y`` (with
    ``ipcw_training=False``) and stores the result on
    ``self.ipcw_y_duration`` for later reweighting of incidence targets.
    """
    durations = check_y_survival(y)[1]
    self.ipcw_y_duration = self.ipcw_estimator.compute_ipcw_at(
        durations,
        X=X,
        ipcw_training=False,
    )

def update_ipiw(self, y, X=None):
    """Refresh the IPIW values at the observed durations of ``y``.

    Evaluates the censoring estimator at each duration in ``y`` (with
    ``ipcw_training=True``) and stores the result on
    ``self.ipiw_y_duration``, used when training the censoring model.
    """
    durations = check_y_survival(y)[1]
    self.ipiw_y_duration = self.ipcw_estimator.compute_ipcw_at(
        durations,
        X=X,
        ipcw_training=True,
    )


class SurvivalBoost(BaseEstimator, ClassifierMixin):
r"""Cause-specific Cumulative Incidence Function (CIF) with GBDT [1]_.
Expand Down Expand Up @@ -333,6 +343,7 @@ def __init__(
n_iter_before_feedback=20,
random_state=None,
n_horizons_per_observation=3,
split_censor_incidence=False,
):
self.hard_zero_fraction = hard_zero_fraction
self.n_iter = n_iter
Expand All @@ -347,6 +358,7 @@ def __init__(
self.ipcw_strategy = ipcw_strategy
self.random_state = random_state
self.n_horizons_per_observation = n_horizons_per_observation
self.split_censor_incidence = split_censor_incidence

def fit(self, X, y, times=None):
"""Fit the model.
Expand Down Expand Up @@ -383,8 +395,6 @@ def fit(self, X, y, times=None):
# Add 0 as a special event id for the survival function.
self.event_ids_ = np.array(sorted(list(set([0]) | set(event))))

self.estimator_ = self._build_base_estimator()

# Compute the default time grid used at prediction time.
any_event_mask = event > 0
observed_times = duration[any_event_mask]
Expand All @@ -401,6 +411,8 @@ def fit(self, X, y, times=None):
self.time_grid_ = times.copy()
self.time_grid_.sort()

self.estimator_ = self._build_base_estimator()

if self.ipcw_strategy == "alternating":
ipcw_estimator = AlternatingCensoringEstimator(
incidence_estimator=self.estimator_
Expand All @@ -413,8 +425,20 @@ def fit(self, X, y, times=None):
"Valid values are 'alternating' and 'kaplan-meier'."
)

if self.split_censor_incidence:
X_incidence, X_censor, y_incidence, y_censor = train_test_split(
X,
y,
stratify=event,
test_size=0.1,
random_state=self.random_state,
)
else:
X_incidence = X_censor = X
y_incidence = y_censor = y

self.weighted_targets_ = WeightedMultiClassTargetSampler(
y,
y_incidence,
hard_zero_fraction=self.hard_zero_fraction,
random_state=self.random_state,
ipcw_estimator=ipcw_estimator,
Expand All @@ -426,17 +450,19 @@ def fit(self, X, y, times=None):
iterator = tqdm(iterator)

for idx_iter in iterator:
X_with_time = np.empty((0, X.shape[1] + 1))
X_with_time = np.empty((0, X_incidence.shape[1] + 1))
y_targets = np.empty((0,))
sample_weight = np.empty((0,))
for _ in range(self.n_horizons_per_observation):
(
sampled_times_,
y_targets_,
sample_weight_,
) = self.weighted_targets_.draw(X=X, ipcw_training=False)
) = self.weighted_targets_.draw(
y_incidence, X=X_incidence, ipcw_training=False
)

X_with_time_ = np.hstack([sampled_times_, X])
X_with_time_ = np.hstack([sampled_times_, X_incidence])
X_with_time = np.vstack([X_with_time, X_with_time_])
y_targets = np.hstack([y_targets, y_targets_])
sample_weight = np.hstack([sample_weight, sample_weight_])
Expand All @@ -455,7 +481,8 @@ def fit(self, X, y, times=None):
if (idx_iter % self.n_iter_before_feedback == 0) and isinstance(
ipcw_estimator, AlternatingCensoringEstimator
):
self.weighted_targets_.fit(X)
self.weighted_targets_.fit(X=X_censor, y=y_censor)
self.weighted_targets_.update_ipcw(X=X_incidence, y=y_incidence)

# XXX: implement verbose logging with a version of IBS that
# can handle competing risks.
Expand Down
15 changes: 6 additions & 9 deletions hazardous/metrics/_brier_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def __init__(
ipcw_estimator=None,
):
self.y_train = y_train
self.event_train, self.duration_train = check_y_survival(y_train)
self.event_ids_ = np.unique(self.event_train)
self.any_event_train = self.event_train > 0
event_train, duration_train = check_y_survival(y_train)
self.event_ids_ = np.unique(event_train)
any_event_train = event_train > 0
self.event_of_interest = event_of_interest

y = dict(
event=self.any_event_train,
duration=self.duration_train,
event=any_event_train,
duration=duration_train,
)
# Estimate the censoring distribution from the training set.
if ipcw_estimator is None:
Expand Down Expand Up @@ -130,10 +130,7 @@ def brier_score_incidence(self, y_true, y_pred, times):
check_event_of_interest(self.event_of_interest)

if self.event_of_interest == "any":
if y_true is self.y_train:
event_true = self.any_event_train
else:
event_true = event_true > 0
event_true = event_true > 0

if y_pred.ndim != 2:
raise ValueError(
Expand Down
19 changes: 19 additions & 0 deletions hazardous/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import numpy as np
import pandas as pd
import sklearn
from sklearn.utils.fixes import parse_version
from sklearn.utils.validation import check_array as check_array_sk
from sklearn.utils.validation import check_scalar


Expand Down Expand Up @@ -53,3 +56,19 @@ def check_event_of_interest(k):
f"got: event_of_interest={k}"
)
return


def check_array(X, **params):
    """Validate an array via :func:`sklearn.utils.validation.check_array`.

    Compatibility wrapper for the scikit-learn 1.6 rename of the
    ``force_all_finite`` parameter to ``ensure_all_finite``: callers may
    pass either spelling, and the correct one is forwarded to the
    installed scikit-learn version. All other keyword arguments are
    passed through unchanged.
    """
    sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

    # Default mirrors sklearn's own default for the finiteness check.
    x_all_finite = True
    for kwarg in ("force_all_finite", "ensure_all_finite"):
        # Membership test, NOT a truthiness test: falsy values such as
        # False are valid and must still be popped, otherwise the keyword
        # would be forwarded twice below (TypeError) and the caller's
        # explicit False would be silently replaced by True.
        if kwarg in params:
            x_all_finite = params.pop(kwarg)
            break

    if sklearn_version < parse_version("1.6.0"):
        return check_array_sk(X, force_all_finite=x_all_finite, **params)
    return check_array_sk(X, ensure_all_finite=x_all_finite, **params)
Loading