Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ prediction. This is why we also offer the possibility to compare your model to a
the mean IC50 of all drugs in the training set. We also offer two more advanced naive predictors:
**NaiveCellLineMeanPredictor** and **NaiveDrugMeanPredictor**. The former predicts the mean IC50 of a cell line in
the training set and the latter predicts the mean IC50 of a drug in the training set.
Finally, as the strongest naive baseline we offer the **NaiveMeanEffectPredictor**
which combines the effects of cell lines and drugs.
It is equivalent to the **NaiveCellLineMeanPredictor** and **NaiveDrugMeanPredictor** for the LDO and LPO settings, respectively.

Available Models
------------------
Expand All @@ -119,6 +122,8 @@ For ``--models``, you can also perform randomization and robustness tests. The `
+----------------------------+----------------------------+--------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| NaiveDrugMeanPredictor | Baseline Method | Multi-Drug Model | Predicts the mean response of a drug in the training set. |
+----------------------------+----------------------------+--------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| NaiveMeanEffectPredictor | Baseline Method | Multi-Drug Model | Predicts using ANOVA-like mean effect model of cell lines and drugs |
+----------------------------+----------------------------+--------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ElasticNet | Baseline Method | Multi-Drug Model | Fits an `Sklearn Elastic Net <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html>`_, `Lasso <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html>`_, or `Ridge <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html>`_ model on gene expression data and drug fingerprints (concatenated input matrix). |
+----------------------------+----------------------------+--------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| GradientBoosting | Baseline Method | Multi-Drug Model | Fits an `Sklearn Gradient Boosting Regressor <https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html>`_ gene expression data and drug fingerprints. |
Expand Down
9 changes: 8 additions & 1 deletion drevalpy/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"NaivePredictor",
"NaiveDrugMeanPredictor",
"NaiveCellLineMeanPredictor",
"NaiveMeanEffectsPredictor",
"ElasticNetModel",
"RandomForest",
"SVMRegressor",
Expand All @@ -22,7 +23,12 @@
]

from .baselines.multi_omics_random_forest import MultiOmicsRandomForest
from .baselines.naive_pred import NaiveCellLineMeanPredictor, NaiveDrugMeanPredictor, NaivePredictor
from .baselines.naive_pred import (
NaiveCellLineMeanPredictor,
NaiveDrugMeanPredictor,
NaiveMeanEffectsPredictor,
NaivePredictor,
)
from .baselines.singledrug_random_forest import SingleDrugRandomForest
from .baselines.sklearn_models import ElasticNetModel, GradientBoosting, RandomForest, SVMRegressor
from .DIPK.dipk import DIPKModel
Expand All @@ -45,6 +51,7 @@
"NaivePredictor": NaivePredictor,
"NaiveDrugMeanPredictor": NaiveDrugMeanPredictor,
"NaiveCellLineMeanPredictor": NaiveCellLineMeanPredictor,
"NaiveMeanEffectsPredictor": NaiveMeanEffectsPredictor,
"ElasticNet": ElasticNetModel,
"RandomForest": RandomForest,
"SVR": SVMRegressor,
Expand Down
1 change: 1 addition & 0 deletions drevalpy/models/baselines/hyperparameters.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
NaivePredictor:
NaiveDrugMeanPredictor:
NaiveCellLineMeanPredictor:
NaiveANOVAPredictor:
ElasticNet:
l1_ratio:
- 0
Expand Down
145 changes: 145 additions & 0 deletions drevalpy/models/baselines/naive_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
The naive predictor models are simple models that predict the mean of the response values. The NaivePredictor
predicts the overall mean of the response, the NaiveCellLineMeanPredictor predicts the mean of the response per cell
line, and the NaiveDrugMeanPredictor predicts the mean of the response per drug.
The NaiveMeanEffectsPredictor predicts the response as the overall mean plus the cell line effect
plus the drug effect and should be the strongest naive baseline.

"""

import numpy as np
Expand Down Expand Up @@ -334,3 +337,145 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase
:returns: FeatureDataset containing the drug ids
"""
return load_drug_ids_from_csv(data_path, dataset_name)


class NaiveMeanEffectsPredictor(DRPModel):
"""
ANOVA-like predictor model.

Predicts the response as:
response = overall_mean + cell_line_effect + drug_effect.

Here:
- cell_line_effect = (cell line mean - overall_mean)
- drug_effect = (drug mean - overall_mean)

This formulation ensures that the overall mean is not counted twice.
"""

cell_line_views = ["cell_line_id"]
drug_views = ["drug_id"]

def __init__(self):
"""
Initializes the NaiveMeanEffectsPredictor model.

The overall dataset mean, cell line effects, and drug effects are initialized to None
and empty dictionaries, respectively.
"""
super().__init__()
self.dataset_mean = None
self.cell_line_effects = {}
self.drug_effects = {}

@classmethod
def get_model_name(cls) -> str:
"""
Returns the name of the model.

:return: The name of the model as a string.
"""
return "NaiveMeanEffectsPredictor"

def build_model(self, hyperparameters: dict):
"""
Builds the model.

This model does not require any hyperparameter tuning.

:param hyperparameters: Dictionary of hyperparameters (not used).
"""
pass

def train(
self,
output: DrugResponseDataset,
cell_line_input: FeatureDataset,
drug_input: FeatureDataset | None = None,
output_earlystopping: DrugResponseDataset | None = None,
model_checkpoint_dir: str = "checkpoints",
) -> None:
"""
Trains with overall mean, cell line effects, and drug effects.

:param output: Training dataset containing the response output.
:param cell_line_input: Feature dataset containing cell line IDs.
:param drug_input: Feature dataset containing drug IDs. Must not be None.
:param output_earlystopping: Not used.
:param model_checkpoint_dir: Not used.
:raises ValueError: If drug_input is None.
"""
if drug_input is None:
raise ValueError("drug_input (drug_id) is required for ANOVAPredictor.")

# Compute the overall mean response.
self.dataset_mean = np.mean(output.response)

# Obtain cell line features.
cell_line_ids = cell_line_input.get_feature_matrix(view="cell_line_id", identifiers=output.cell_line_ids)
cell_line_means = {}
for cl_output, cl_feature in zip(unique(output.cell_line_ids), unique(cell_line_ids), strict=True):
responses_cl = output.response[cl_feature == output.cell_line_ids]
if len(responses_cl) > 0:
cell_line_means[cl_output] = np.mean(responses_cl)

# Obtain drug features.
drug_ids = drug_input.get_feature_matrix(view="drug_id", identifiers=output.drug_ids)
drug_means = {}
for drug_output, drug_feature in zip(unique(output.drug_ids), unique(drug_ids), strict=True):
responses_drug = output.response[drug_feature == output.drug_ids]
if len(responses_drug) > 0:
drug_means[drug_output] = np.mean(responses_drug)

# Compute the effects as deviations from the overall mean.
self.cell_line_effects = {cl: (mean - self.dataset_mean) for cl, mean in cell_line_means.items()}
self.drug_effects = {drug: (mean - self.dataset_mean) for drug, mean in drug_means.items()}

def predict(
self,
cell_line_ids: np.ndarray,
drug_ids: np.ndarray,
cell_line_input: FeatureDataset,
drug_input: FeatureDataset | None = None,
) -> np.ndarray:
"""
Predicts responses for given cell line and drug pairs.

The prediction is computed as:
prediction = overall_mean + cell_line_effect + drug_effect

If a cell line or drug has not been seen during training, their effect is set to zero.

:param cell_line_ids: Array of cell line IDs.
:param drug_ids: Array of drug IDs.
:param cell_line_input: Not used.
:param drug_input: Not used.
:return: NumPy array of predicted responses.
"""
predictions = []
for cl, drug in zip(cell_line_ids, drug_ids):
effect_cl = self.cell_line_effects.get(cl, 0)
effect_drug = self.drug_effects.get(drug, 0)
# ANOVA-based prediction: overall mean + cell line effect + drug effect.
predictions.append(self.dataset_mean + effect_cl + effect_drug)
return np.array(predictions)

def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
"""
Loads the cell line features.

:param data_path: Path to the data.
:param dataset_name: Name of the dataset.
:return: FeatureDataset containing the cell line IDs.
"""
return load_cl_ids_from_csv(data_path, dataset_name)

def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
"""
Loads the drug features.

:param data_path: Path to the data.
:param dataset_name: Name of the dataset.
:return: FeatureDataset containing the drug IDs.
"""
return load_drug_ids_from_csv(data_path, dataset_name)
51 changes: 51 additions & 0 deletions tests/individual_models/test_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
MODEL_FACTORY,
NaiveCellLineMeanPredictor,
NaiveDrugMeanPredictor,
NaiveMeanEffectsPredictor,
NaivePredictor,
SingleDrugRandomForest,
)
Expand Down Expand Up @@ -299,3 +300,53 @@ def _call_other_baselines(
assert val_dataset.predictions is not None
metrics = evaluate(val_dataset, metric=["Pearson"])
assert metrics["Pearson"] >= -1


@pytest.mark.parametrize("test_mode", ["LPO", "LCO", "LDO"])
def test_naive_anova_predictor(
sample_dataset: tuple[DrugResponseDataset, FeatureDataset, FeatureDataset], test_mode: str
) -> None:
"""
Test the NaiveMeanEffectsPredictor model.

:param sample_dataset: from conftest.py
:param test_mode: either LPO, LCO, or LDO
"""
drug_response, cell_line_input, drug_input = sample_dataset
drug_response.split_dataset(n_cv_splits=5, mode=test_mode)

assert drug_response.cv_splits is not None
split = drug_response.cv_splits[0]
train_dataset = split["train"]
val_dataset = split["validation"]

cell_lines_to_keep = cell_line_input.identifiers
drugs_to_keep = drug_input.identifiers

len_train_before = len(train_dataset)
len_pred_before = len(val_dataset)
train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep)
val_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep)
print(f"Reduced training dataset from {len_train_before} to {len(train_dataset)}")
print(f"Reduced val dataset from {len_pred_before} to {len(val_dataset)}")

naive = NaiveMeanEffectsPredictor()
naive.train(output=train_dataset, cell_line_input=cell_line_input, drug_input=drug_input)
val_dataset._predictions = naive.predict(
cell_line_ids=val_dataset.cell_line_ids,
drug_ids=val_dataset.drug_ids,
cell_line_input=cell_line_input,
)

assert val_dataset.predictions is not None
train_mean = train_dataset.response.mean()
assert train_mean == naive.dataset_mean

# Check that predictions are within a reasonable range
assert np.all(np.isfinite(val_dataset.predictions))
assert np.all(val_dataset.predictions >= np.min(train_dataset.response))
assert np.all(val_dataset.predictions <= np.max(train_dataset.response))

metrics = evaluate(val_dataset, metric=["Pearson"])
print(f"{test_mode}: Performance of NaiveMeanEffectsPredictor: PCC = {metrics['Pearson']}")
assert metrics["Pearson"] >= -1 # Should be within valid Pearson range
2 changes: 1 addition & 1 deletion tests/test_drp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_factory() -> None:
assert "NaivePredictor" in MODEL_FACTORY
assert "NaiveDrugMeanPredictor" in MODEL_FACTORY
assert "NaiveCellLineMeanPredictor" in MODEL_FACTORY
assert "NaiveMeanEffectsPredictor" in MODEL_FACTORY
assert "ElasticNet" in MODEL_FACTORY
assert "RandomForest" in MODEL_FACTORY
assert "SVR" in MODEL_FACTORY
Expand All @@ -37,7 +38,6 @@ def test_factory() -> None:
assert "MOLIR" in MODEL_FACTORY
assert "SuperFELTR" in MODEL_FACTORY
assert "DIPK" in MODEL_FACTORY
assert len(MODEL_FACTORY) == 15


def test_load_cl_ids_from_csv() -> None:
Expand Down
Loading