diff --git a/.gitignore b/.gitignore index 5e873b1f..e41b2127 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,7 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +tests/lightning_logs/ # Translations *.mo diff --git a/drevalpy/datasets/featurizer/__init__.py b/drevalpy/datasets/featurizer/__init__.py new file mode 100644 index 00000000..5ba97ed4 --- /dev/null +++ b/drevalpy/datasets/featurizer/__init__.py @@ -0,0 +1,77 @@ +"""Featurizers for converting drug and cell line data to embeddings. + +This module provides abstract base classes and concrete implementations for +featurizing drugs and cell lines for drug response prediction models. + +Drug Featurizers: + - DrugFeaturizer: Abstract base class for drug featurizers + - ChemBERTaFeaturizer: ChemBERTa transformer embeddings from SMILES + - DrugGraphFeaturizer: Molecular graph representations + - MolGNetFeaturizer: MolGNet graph neural network embeddings + +Cell Line Featurizers: + - CellLineFeaturizer: Abstract base class for cell line featurizers + - PCAFeaturizer: PCA dimensionality reduction for omics data + +Mixins for DRP Models: + - ChemBERTaMixin: Provides load_drug_features using ChemBERTa + - DrugGraphMixin: Provides load_drug_features using DrugGraphFeaturizer + - MolGNetMixin: Provides load_drug_features using MolGNet + - PCAMixin: Provides load_cell_line_features using PCA + +Example usage:: + + from drevalpy.datasets.featurizer import ChemBERTaFeaturizer, PCAFeaturizer + + # Drug features + drug_featurizer = ChemBERTaFeaturizer(device="cuda") + drug_features = drug_featurizer.load_or_generate("data", "GDSC1") + + # Cell line features + cell_featurizer = PCAFeaturizer(n_components=100) + cell_features = cell_featurizer.load_or_generate("data", "GDSC1") + +Example using mixins in a model:: + + from drevalpy.models.drp_model import DRPModel + from drevalpy.datasets.featurizer import ChemBERTaMixin, PCAMixin + + class MyModel(ChemBERTaMixin, PCAMixin, DRPModel): + # ChemBERTaMixin provides 
load_drug_features + # PCAMixin provides load_cell_line_features + ... +""" + +# Cell line featurizers +from .cell_line import ( + CellLineFeaturizer, + PCAFeaturizer, + PCAMixin, +) + +# Drug featurizers +from .drug import ( + ChemBERTaFeaturizer, + ChemBERTaMixin, + DrugFeaturizer, + DrugGraphFeaturizer, + DrugGraphMixin, + MolGNetFeaturizer, + MolGNetMixin, +) + +__all__ = [ + # Drug featurizers + "DrugFeaturizer", + "ChemBERTaFeaturizer", + "DrugGraphFeaturizer", + "MolGNetFeaturizer", + # Cell line featurizers + "CellLineFeaturizer", + "PCAFeaturizer", + # Mixins + "ChemBERTaMixin", + "DrugGraphMixin", + "MolGNetMixin", + "PCAMixin", +] diff --git a/drevalpy/datasets/featurizer/cell_line/__init__.py b/drevalpy/datasets/featurizer/cell_line/__init__.py new file mode 100644 index 00000000..8c11b2ba --- /dev/null +++ b/drevalpy/datasets/featurizer/cell_line/__init__.py @@ -0,0 +1,10 @@ +"""Cell line featurizers for converting omics data to embeddings.""" + +from .base import CellLineFeaturizer +from .pca import PCAFeaturizer, PCAMixin + +__all__ = [ + "CellLineFeaturizer", + "PCAFeaturizer", + "PCAMixin", +] diff --git a/drevalpy/datasets/featurizer/cell_line/base.py b/drevalpy/datasets/featurizer/cell_line/base.py new file mode 100644 index 00000000..499321f2 --- /dev/null +++ b/drevalpy/datasets/featurizer/cell_line/base.py @@ -0,0 +1,256 @@ +"""Abstract base class for cell line featurizers.""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd + +from drevalpy.datasets.dataset import FeatureDataset +from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER + + +class CellLineFeaturizer(ABC): + """Abstract base class for cell line featurizers. + + Cell line featurizers convert omics data (e.g., gene expression, methylation) + into numerical embeddings that can be used as input features for machine learning models. 
+ + Supports both single-omics and multi-omics featurization through the `omics_types` + parameter. + + Subclasses must implement: + - featurize(): Convert omics data for a single cell line to its embedding + - get_feature_name(): Return the name of the feature view + - get_output_filename(): Return the filename for cached embeddings + + The base class provides: + - load_or_generate(): Load cached embeddings or generate and cache them + - generate_embeddings(): Generate embeddings for all cell lines in a dataset + - load_embeddings(): Load pre-generated embeddings from disk + """ + + # Supported omics types and their corresponding file names + OMICS_FILE_MAPPING = { + "gene_expression": "gene_expression.csv", + "methylation": "methylation.csv", + "mutations": "mutations.csv", + "copy_number_variation": "copy_number_variation.csv", + } + + def __init__(self, omics_types: list[str] | str = "gene_expression"): + """Initialize the featurizer. + + :param omics_types: Single omics type or list of omics types to use. + Supported types: 'gene_expression', 'methylation', + 'mutations', 'copy_number_variation' + :raises ValueError: If an unsupported omics type is provided + """ + if isinstance(omics_types, str): + omics_types = [omics_types] + + for omics_type in omics_types: + if omics_type not in self.OMICS_FILE_MAPPING: + raise ValueError( + f"Unsupported omics type: {omics_type}. " f"Supported types: {list(self.OMICS_FILE_MAPPING.keys())}" + ) + + self.omics_types = omics_types + + @abstractmethod + def featurize(self, omics_data: dict[str, np.ndarray]) -> np.ndarray | Any: + """Convert omics data to a feature representation. + + :param omics_data: Dictionary mapping omics type to data array for a single cell line + :returns: Feature representation (numpy array or other format) + """ + + @classmethod + @abstractmethod + def get_feature_name(cls) -> str: + """Return the name of the feature view. + + This name is used as the key in the FeatureDataset. 
+ + :returns: Feature view name (e.g., 'gene_expression_pca') + """ + + @abstractmethod + def get_output_filename(self) -> str: + """Return the filename for cached embeddings. + + Note: This is an instance method (not classmethod) because the filename + may depend on featurizer parameters (e.g., n_components for PCA). + + :returns: Filename (e.g., 'cell_line_gene_expression_pca_100.csv') + """ + + def load_or_generate(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load cached embeddings or generate and cache them if not available. + + This is the main entry point for using a featurizer. It checks if + pre-generated embeddings exist and loads them, otherwise generates + new embeddings and saves them for future use. + + :param data_path: Path to the data directory (e.g., 'data/') + :param dataset_name: Name of the dataset (e.g., 'GDSC1') + :returns: FeatureDataset containing the cell line embeddings + """ + output_path = Path(data_path) / dataset_name / self.get_output_filename() + + if output_path.exists(): + return self.load_embeddings(data_path, dataset_name) + else: + print(f"Embeddings not found at {output_path}. Generating...") + return self.generate_embeddings(data_path, dataset_name) + + def _load_omics_data(self, data_path: str, dataset_name: str) -> dict[str, pd.DataFrame]: + """Load omics data files for the specified omics types. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: Dictionary mapping omics type to DataFrame + :raises FileNotFoundError: If any required omics file is not found + """ + data_dir = Path(data_path) / dataset_name + omics_data = {} + + for omics_type in self.omics_types: + filename = self.OMICS_FILE_MAPPING[omics_type] + filepath = data_dir / filename + + if not filepath.exists(): + raise FileNotFoundError( + f"Omics data file not found: {filepath}. " f"Please ensure the {omics_type} data is available." 
+ ) + + df = pd.read_csv(filepath, dtype={CELL_LINE_IDENTIFIER: str}) + df = df.set_index(CELL_LINE_IDENTIFIER) + omics_data[omics_type] = df + + return omics_data + + def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Generate embeddings for all cell lines in a dataset and save to disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the generated embeddings + """ + data_dir = Path(data_path).resolve() + output_file = data_dir / dataset_name / self.get_output_filename() + + # Load omics data + omics_data = self._load_omics_data(data_path, dataset_name) + + # Get common cell line IDs across all omics types + cell_line_ids_set: set[str] = set() + for _omics_type, df in omics_data.items(): + cell_line_ids_set = cell_line_ids_set.intersection(set(df.index)) + + cell_line_ids_list = sorted(list(cell_line_ids_set)) + print(f"Processing {len(cell_line_ids_list)} cell lines for dataset {dataset_name}...") + + # Generate embeddings + embeddings_list = [] + valid_cell_line_ids = [] + + for cell_line_id in cell_line_ids_list: + try: + # Prepare omics data for this cell line + cell_omics = {} + for omics_type, df in omics_data.items(): + cell_omics[omics_type] = df.loc[cell_line_id].to_numpy(dtype=np.float32) + + embedding = self.featurize(cell_omics) + embeddings_list.append(embedding) + valid_cell_line_ids.append(cell_line_id) + except Exception as e: + print(f"Failed to process cell line {cell_line_id}: {e}") + continue + + # Save embeddings + self._save_embeddings(embeddings_list, valid_cell_line_ids, output_file, omics_data) + + print(f"Embeddings saved to {output_file}") + + # Return as FeatureDataset + return self._create_feature_dataset(embeddings_list, valid_cell_line_ids) + + def _save_embeddings( + self, + embeddings: list, + cell_line_ids: list[str], + output_path: Path, + omics_data: dict[str, pd.DataFrame] | None = None, + ) -> None: + """Save 
embeddings to disk. + + Default implementation saves as CSV. Subclasses can override for other formats. + + :param embeddings: List of embedding arrays + :param cell_line_ids: List of cell line identifiers + :param output_path: Path to save the embeddings + :param omics_data: Optional omics data (may be used by subclasses for saving fitted models) + """ + embeddings_df = pd.DataFrame(embeddings) + embeddings_df.insert(0, CELL_LINE_IDENTIFIER, cell_line_ids) + embeddings_df.to_csv(output_path, index=False) + + def _create_feature_dataset(self, embeddings: list, cell_line_ids: list[str]) -> FeatureDataset: + """Create a FeatureDataset from embeddings. + + :param embeddings: List of embedding arrays + :param cell_line_ids: List of cell line identifiers + :returns: FeatureDataset containing the embeddings + """ + feature_name = self.get_feature_name() + features = {} + for cell_line_id, embedding in zip(cell_line_ids, embeddings, strict=True): + if isinstance(embedding, np.ndarray): + features[cell_line_id] = {feature_name: embedding.astype(np.float32)} + else: + features[cell_line_id] = {feature_name: embedding} + return FeatureDataset(features) + + def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load pre-generated embeddings from disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the embeddings + :raises FileNotFoundError: If the embeddings file is not found + """ + embeddings_file = Path(data_path) / dataset_name / self.get_output_filename() + + if not embeddings_file.exists(): + raise FileNotFoundError( + f"Embeddings file not found: {embeddings_file}. " + f"Use load_or_generate() to automatically generate embeddings." 
def main():
    """CLI entry point placeholder for the abstract base featurizer.

    Concrete featurizer modules that support command-line usage override
    this function with a real implementation.

    :raises NotImplementedError: Always, as subclasses should implement their own main()
    """
    raise NotImplementedError("Subclasses should implement their own main() function")
+ + :param n_components: Number of principal components to keep + """ + super().__init__(omics_types="gene_expression") + self.n_components = n_components + self._scaler: StandardScaler | None = None + self._pca: PCA | None = None + self._fitted = False + + def featurize(self, omics_data: dict[str, np.ndarray]) -> np.ndarray: + """Apply PCA transformation to gene expression data. + + :param omics_data: Dictionary with 'gene_expression' key containing the data + :returns: PCA-transformed features + :raises RuntimeError: If the PCA model is not fitted + :raises ValueError: If gene_expression data is not provided + """ + if not self._fitted or self._scaler is None or self._pca is None: + raise RuntimeError("PCA model is not fitted. Call generate_embeddings() or fit() first.") + + if "gene_expression" not in omics_data: + raise ValueError("gene_expression data is required for PCA featurizer") + + data = omics_data["gene_expression"].reshape(1, -1) + scaled = self._scaler.transform(data) + return self._pca.transform(scaled).flatten() + + def fit(self, gene_expression_df: pd.DataFrame) -> None: + """Fit the scaler and PCA model on gene expression data. + + :param gene_expression_df: DataFrame with cell lines as rows and genes as columns + """ + data = gene_expression_df.values + + self._scaler = StandardScaler() + scaled_data = self._scaler.fit_transform(data) + + n_components = min(self.n_components, min(scaled_data.shape)) + self._pca = PCA(n_components=n_components) + self._pca.fit(scaled_data) + + self._fitted = True + + @classmethod + def get_feature_name(cls) -> str: + """Return the feature view name. + + :returns: 'gene_expression_pca' + """ + return "gene_expression_pca" + + def get_output_filename(self) -> str: + """Return the output filename for cached embeddings. 
+ + :returns: Filename like 'cell_line_gene_expression_pca_100.csv' + """ + return f"cell_line_gene_expression_pca_{self.n_components}.csv" + + def _get_model_filename(self) -> str: + """Return the filename for the fitted model. + + :returns: Filename like 'cell_line_gene_expression_pca_100_models.pkl' + """ + return f"cell_line_gene_expression_pca_{self.n_components}_models.pkl" + + def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Generate PCA embeddings for all cell lines and save to disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the PCA embeddings + :raises FileNotFoundError: If the gene expression file is not found + :raises RuntimeError: If fitting fails + """ + data_dir = Path(data_path).resolve() + output_file = data_dir / dataset_name / self.get_output_filename() + model_file = data_dir / dataset_name / self._get_model_filename() + + # Load gene expression data + ge_file = data_dir / dataset_name / "gene_expression.csv" + if not ge_file.exists(): + raise FileNotFoundError(f"Gene expression file not found: {ge_file}") + + ge_df = pd.read_csv(ge_file, dtype={CELL_LINE_IDENTIFIER: str}) + ge_df = ge_df.set_index(CELL_LINE_IDENTIFIER) + + # Drop non-numeric columns (e.g., cellosaurus_id) + ge_df = ge_df.select_dtypes(include=[np.number]) + + cell_line_ids = list(ge_df.index) + print(f"Processing {len(cell_line_ids)} cell lines for dataset {dataset_name}...") + + # Fit the model + self.fit(ge_df) + + # Transform all cell lines (scaler and pca are guaranteed to be set after fit()) + if self._scaler is None or self._pca is None: + raise RuntimeError("Fitting failed: scaler or PCA model is None") + scaled_data = self._scaler.transform(ge_df.values) + embeddings = self._pca.transform(scaled_data) + + # Save embeddings + embeddings_df = pd.DataFrame(embeddings) + embeddings_df.insert(0, CELL_LINE_IDENTIFIER, cell_line_ids) + 
embeddings_df.to_csv(output_file, index=False) + + # Save fitted models + with open(model_file, "wb") as f: + pickle.dump({"scaler": self._scaler, "pca": self._pca}, f) + + print(f"Embeddings saved to {output_file}") + print(f"Fitted models saved to {model_file}") + + # Return as FeatureDataset + return self._create_feature_dataset(list(embeddings), cell_line_ids) + + def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load pre-generated PCA embeddings from disk. + + Also loads the fitted scaler and PCA model for future transformations. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the embeddings + :raises FileNotFoundError: If the embeddings file is not found + """ + embeddings_file = Path(data_path) / dataset_name / self.get_output_filename() + model_file = Path(data_path) / dataset_name / self._get_model_filename() + + if not embeddings_file.exists(): + raise FileNotFoundError( + f"Embeddings file not found: {embeddings_file}. " + f"Use load_or_generate() to automatically generate embeddings." + ) + + # Load fitted models if available (optional - only needed for transforming new data) + if model_file.exists(): + with open(model_file, "rb") as f: + models = pickle.load(f) # noqa: S301 + self._scaler = models["scaler"] + self._pca = models["pca"] + self._fitted = True + + # Load embeddings + embeddings_df = pd.read_csv(embeddings_file, dtype={CELL_LINE_IDENTIFIER: str}) + feature_name = self.get_feature_name() + features = {} + + for _, row in embeddings_df.iterrows(): + cell_line_id = row[CELL_LINE_IDENTIFIER] + embedding = row.drop(CELL_LINE_IDENTIFIER).to_numpy(dtype=np.float32) + features[cell_line_id] = {feature_name: embedding} + + return FeatureDataset(features) + + +class PCAMixin: + """Mixin that provides PCA-transformed gene expression loading for DRP models. + + This mixin implements load_cell_line_features using the PCAFeaturizer. 
+ It automatically generates embeddings if they don't exist. + + The number of PCA components can be configured via: + - hyperparameters['n_components'] (if the model has hyperparameters) + - pca_n_components class attribute (default: 100) + + Example usage:: + + from drevalpy.models.drp_model import DRPModel + from drevalpy.datasets.featurizer.cell_line.pca import PCAMixin + + class MyModel(PCAMixin, DRPModel): + cell_line_views = ["gene_expression_pca"] + ... + """ + + pca_n_components: int = 100 + + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load PCA-transformed gene expression features. + + Uses the PCAFeaturizer to load pre-generated embeddings or generate + them automatically if they don't exist. + + :param data_path: Path to the data directory, e.g., 'data/' + :param dataset_name: Name of the dataset, e.g., 'GDSC1' + :returns: FeatureDataset containing the PCA-transformed gene expression + """ + # Try to get n_components from hyperparameters if available + n_components = self.pca_n_components + if hasattr(self, "hyperparameters") and self.hyperparameters is not None: + n_components = self.hyperparameters.get("n_components", n_components) + + featurizer = PCAFeaturizer(n_components=n_components) + return featurizer.load_or_generate(data_path, dataset_name) + + +def main(): + """Generate PCA embeddings for cell line gene expression from command line.""" + parser = argparse.ArgumentParser(description="Generate PCA embeddings for cell line gene expression.") + parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") + parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") + parser.add_argument("--n_components", type=int, default=100, help="Number of PCA components") + args = parser.parse_args() + + featurizer = PCAFeaturizer(n_components=args.n_components) + featurizer.generate_embeddings(args.data_path, args.dataset_name) + + +if __name__ == 
"__main__": + main() diff --git a/drevalpy/datasets/featurizer/create_chemberta_drug_embeddings.py b/drevalpy/datasets/featurizer/create_chemberta_drug_embeddings.py deleted file mode 100644 index 3d19ff18..00000000 --- a/drevalpy/datasets/featurizer/create_chemberta_drug_embeddings.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Preprocesses drug SMILES strings into ChemBERTa embeddings.""" - -import argparse -from pathlib import Path - -import pandas as pd -import torch -from tqdm import tqdm - -try: - from transformers import AutoModel, AutoTokenizer -except ImportError: - raise ImportError( - "Please install transformers package for ChemBERTa embedding featurizer: pip install transformers" - ) -# Load ChemBERTa -tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1") -model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1") -model.eval() - - -def _smiles_to_chemberta(smiles: str, device="cpu"): - inputs = tokenizer(smiles, return_tensors="pt", truncation=True) - inputs = {k: v.to(device) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = model(**inputs) - hidden_states = outputs.last_hidden_state - - embedding = hidden_states.mean(dim=1).squeeze(0) - return embedding.cpu().numpy() - - -def main(): - """Process drug SMILES and save ChemBERTa embeddings. - - :raises Exception: If a drug fails to process. 
- """ - parser = argparse.ArgumentParser(description="Preprocess drug SMILES to ChemBERTa embeddings.") - parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") - parser.add_argument("--device", type=str, default="cpu", help="Torch device (cpu or cuda)") - parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") - args = parser.parse_args() - - dataset_name = args.dataset_name - device = args.device - data_dir = Path(args.data_path).resolve() - - smiles_file = data_dir / dataset_name / "drug_smiles.csv" - output_file = data_dir / dataset_name / "drug_chemberta_embeddings.csv" - - if not smiles_file.exists(): - print(f"Error: {smiles_file} not found.") - return - - smiles_df = pd.read_csv(smiles_file, dtype={"canonical_smiles": str, "pubchem_id": str}) - embeddings_list = [] - drug_ids = [] - - print(f"Processing {len(smiles_df)} drugs for dataset {dataset_name}...") - - for row in tqdm(smiles_df.itertuples(index=False), total=len(smiles_df)): - drug_id = row.pubchem_id - smiles = row.canonical_smiles - - try: - embedding = _smiles_to_chemberta(smiles, device=device) - embeddings_list.append(embedding) - drug_ids.append(drug_id) - except Exception as e: - print() - print(smiles) - print() - print(f"Failed to process {drug_id}") - raise e - - embeddings_array = pd.DataFrame(embeddings_list) - embeddings_array.insert(0, "pubchem_id", drug_ids) - embeddings_array.to_csv(output_file, index=False) - - print(f"Finished processing. Embeddings saved to {output_file}") - - -if __name__ == "__main__": - main() diff --git a/drevalpy/datasets/featurizer/create_drug_graphs.py b/drevalpy/datasets/featurizer/create_drug_graphs.py deleted file mode 100644 index c79e09ca..00000000 --- a/drevalpy/datasets/featurizer/create_drug_graphs.py +++ /dev/null @@ -1,145 +0,0 @@ -""" -Preprocesses drug SMILES strings into graph representations. 
- -This script takes a dataset name as input, reads the corresponding -drug_smiles.csv file, and converts each SMILES string into a -torch_geometric.data.Data object. The resulting graph objects are saved -to {data_path}/{dataset_name}/drug_graphs/{drug_name}.pt. -""" - -import argparse -import os -from pathlib import Path - -import pandas as pd -import torch -from torch_geometric.data import Data -from tqdm import tqdm - -try: - from rdkit import Chem -except ImportError: - raise ImportError("Please install rdkit package for drug graphs featurizer: pip install rdkit") - -# Atom feature configuration -ATOM_FEATURES = { - "atomic_num": list(range(1, 119)), - "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], - "num_hs": [0, 1, 2, 3, 4, 5, 6, 7, 8], - "hybridization": [ - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2, - ], -} - -# Bond feature configuration -BOND_FEATURES = { - "bond_type": [ - Chem.rdchem.BondType.SINGLE, - Chem.rdchem.BondType.DOUBLE, - Chem.rdchem.BondType.TRIPLE, - Chem.rdchem.BondType.AROMATIC, - ] -} - - -def _one_hot_encode(value, choices): - """Create a one-hot encoding for a value in a list of choices. - - :param value: The value to be one-hot encoded. - :param choices: A list of possible choices for the value. - :return: A list representing the one-hot encoding. - """ - encoding = [0] * (len(choices) + 1) - index = choices.index(value) if value in choices else -1 - encoding[index] = 1 - return encoding - - -def _smiles_to_graph(smiles: str): - """ - Converts a SMILES string to a torch_geometric.data.Data object. - - :param smiles: The SMILES string for the drug. - :return: A Data object representing the molecular graph, or None if conversion fails. 
- """ - mol = Chem.MolFromSmiles(smiles) - if mol is None: - return None - - # Atom features - atom_features_list = [] - for atom in mol.GetAtoms(): - features = [] - features.extend(_one_hot_encode(atom.GetAtomicNum(), ATOM_FEATURES["atomic_num"])) - features.extend(_one_hot_encode(atom.GetDegree(), ATOM_FEATURES["degree"])) - features.extend(_one_hot_encode(atom.GetFormalCharge(), ATOM_FEATURES["formal_charge"])) - features.extend(_one_hot_encode(atom.GetTotalNumHs(), ATOM_FEATURES["num_hs"])) - features.extend(_one_hot_encode(atom.GetHybridization(), ATOM_FEATURES["hybridization"])) - features.append(atom.GetIsAromatic()) - features.append(atom.IsInRing()) - atom_features_list.append(features) - x = torch.tensor(atom_features_list, dtype=torch.float) - - # Edge index and edge features - edge_indices = [] - edge_features_list = [] - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - - # Edge features - features = [] - features.extend(_one_hot_encode(bond.GetBondType(), BOND_FEATURES["bond_type"])) - features.append(bond.GetIsConjugated()) - features.append(bond.IsInRing()) - - edge_indices.extend([[i, j], [j, i]]) - edge_features_list.extend([features, features]) # Same features for both directions - - edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous() - edge_attr = torch.tensor(edge_features_list, dtype=torch.float) - - return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) - - -def main(): - """Main function to run the preprocessing.""" - parser = argparse.ArgumentParser(description="Preprocess drug SMILES to graphs.") - parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") - parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") - args = parser.parse_args() - - dataset_name = args.dataset_name - data_dir = Path(args.data_path).resolve() - smiles_file = data_dir / dataset_name / "drug_smiles.csv" - output_dir = data_dir / 
dataset_name / "drug_graphs" - - if not smiles_file.exists(): - print(f"Error: {smiles_file} not found.") - return - - os.makedirs(output_dir, exist_ok=True) - - smiles_df = pd.read_csv(smiles_file) - - print(f"Processing {len(smiles_df)} drugs for dataset {dataset_name}...") - - for _, row in tqdm(smiles_df.iterrows(), total=smiles_df.shape[0]): - drug_id = row["pubchem_id"] - smiles = row["canonical_smiles"] - - graph = _smiles_to_graph(smiles) - - if graph: - torch.save(graph, output_dir / f"{drug_id}.pt") - - print(f"Finished processing. Graphs saved to {output_dir}") - - -if __name__ == "__main__": - main() diff --git a/drevalpy/datasets/featurizer/create_molgnet_embeddings.py b/drevalpy/datasets/featurizer/create_molgnet_embeddings.py deleted file mode 100644 index 2d337e80..00000000 --- a/drevalpy/datasets/featurizer/create_molgnet_embeddings.py +++ /dev/null @@ -1,917 +0,0 @@ -#!/usr/bin/env python3 -"""MolGNet feature extraction utilities (needed for DIPK and adapted from the DIPK github). - -Creates MolGNet embeddings for molecules given their SMILES strings. This module needs torch_scatter. 
- python create_molgnet_embeddings.py dataset_name --checkpoint meta/MolGNet.pt --data_path data -""" - -import argparse -import math -import os -import pickle # noqa: S403 -from pathlib import Path -from typing import Any, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn.functional as torch_nn_f -from torch import nn -from torch.nn import Parameter -from torch_geometric.data import Data -from torch_geometric.utils import add_self_loops, softmax -from tqdm import tqdm - -try: - from rdkit import Chem - from rdkit.Chem.rdchem import Mol as RDMol -except ImportError: - raise ImportError("Please install rdkit package for MolGNet featurizer: pip install rdkit") - -# building graphs -allowable_features: dict[str, list[Any]] = { - "atomic_num": list(range(1, 122)), - "formal_charge": ["unk", -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], - "chirality": [ - "unk", - Chem.rdchem.ChiralType.CHI_UNSPECIFIED, - Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, - Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, - Chem.rdchem.ChiralType.CHI_OTHER, - ], - "hybridization": [ - "unk", - Chem.rdchem.HybridizationType.S, - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2, - Chem.rdchem.HybridizationType.UNSPECIFIED, - ], - "numH": ["unk", 0, 1, 2, 3, 4, 5, 6, 7, 8], - "implicit_valence": ["unk", 0, 1, 2, 3, 4, 5, 6], - "degree": ["unk", 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - "isaromatic": [False, True], - "bond_type": [ - "unk", - Chem.rdchem.BondType.SINGLE, - Chem.rdchem.BondType.DOUBLE, - Chem.rdchem.BondType.TRIPLE, - Chem.rdchem.BondType.AROMATIC, - ], - "bond_dirs": [ - Chem.rdchem.BondDir.NONE, - Chem.rdchem.BondDir.ENDUPRIGHT, - Chem.rdchem.BondDir.ENDDOWNRIGHT, - ], - "bond_isconjugated": [False, True], - "bond_inring": [False, True], - "bond_stereo": [ - "STEREONONE", - "STEREOANY", - "STEREOZ", - "STEREOE", - "STEREOCIS", - 
"STEREOTRANS", - ], -} - -atom_dic = [ - len(allowable_features["atomic_num"]), - len(allowable_features["formal_charge"]), - len(allowable_features["chirality"]), - len(allowable_features["hybridization"]), - len(allowable_features["numH"]), - len(allowable_features["implicit_valence"]), - len(allowable_features["degree"]), - len(allowable_features["isaromatic"]), -] -bond_dic = [ - len(allowable_features["bond_type"]), - len(allowable_features["bond_dirs"]), - len(allowable_features["bond_isconjugated"]), - len(allowable_features["bond_inring"]), - len(allowable_features["bond_stereo"]), -] -atom_cumsum = np.cumsum(atom_dic) -bond_cumsum = np.cumsum(bond_dic) - - -def mol_to_graph_data_obj_complex(mol: RDMol) -> Data: - """Convert an RDKit Mol into a torch_geometric ``Data`` object. - - The function encodes a fixed set of atom and bond categorical - features and returns a ``Data`` instance with ``x``, ``edge_index`` - and ``edge_attr`` fields. It mirrors the feature layout expected by - the MolGNet implementation used in this repository. - - :param mol: RDKit ``Mol`` instance. Must not be ``None``. - :return: A ``torch_geometric.data.Data`` object with node and edge fields. - :raises ValueError: If ``mol`` is ``None``. 
- """ - if mol is None: - raise ValueError("mol must not be None") - atom_features_list: list = [] - # Shortcuts for feature lists - fc_list = allowable_features["formal_charge"] - ch_list = allowable_features["chirality"] - hyb_list = allowable_features["hybridization"] - numh_list = allowable_features["numH"] - imp_list = allowable_features["implicit_valence"] - deg_list = allowable_features["degree"] - isa_list = allowable_features["isaromatic"] - bt_list = allowable_features["bond_type"] - bd_list = allowable_features["bond_dirs"] - bic_list = allowable_features["bond_isconjugated"] - bir_list = allowable_features["bond_inring"] - bs_list = allowable_features["bond_stereo"] - for atom in mol.GetAtoms(): - a_idx = allowable_features["atomic_num"].index(atom.GetAtomicNum()) - fc_idx = fc_list.index(atom.GetFormalCharge()) + atom_cumsum[0] - ch_idx = ch_list.index(atom.GetChiralTag()) + atom_cumsum[1] - hyb_idx = hyb_list.index(atom.GetHybridization()) + atom_cumsum[2] - numh_idx = numh_list.index(atom.GetTotalNumHs()) + atom_cumsum[3] - imp_idx = imp_list.index(atom.GetImplicitValence()) + atom_cumsum[4] - deg_idx = deg_list.index(atom.GetDegree()) + atom_cumsum[5] - isa_idx = isa_list.index(atom.GetIsAromatic()) + atom_cumsum[6] - - atom_feature = [ - a_idx, - fc_idx, - ch_idx, - hyb_idx, - numh_idx, - imp_idx, - deg_idx, - isa_idx, - ] - atom_features_list.append(atom_feature) - x = torch.tensor(np.array(atom_features_list), dtype=torch.long) - - # bonds - num_bond_features = 5 - if len(mol.GetBonds()) > 0: - edges_list = [] - edge_features_list = [] - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - bt = bt_list.index(bond.GetBondType()) - bd = bd_list.index(bond.GetBondDir()) + bond_cumsum[0] - bic = bic_list.index(bond.GetIsConjugated()) + bond_cumsum[1] - bir = bir_list.index(bond.IsInRing()) + bond_cumsum[2] - bs = bs_list.index(str(bond.GetStereo())) + bond_cumsum[3] - - edge_feature = [bt, bd, bic, bir, bs] - 
edges_list.append((i, j)) - edge_features_list.append(edge_feature) - edges_list.append((j, i)) - edge_features_list.append(edge_feature) - edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) - edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long) - else: - edge_index = torch.empty((2, 0), dtype=torch.long) - edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) - - data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) - return data - - -class SelfLoop: - """Callable that appends self-loops and matching edge attributes. - - This helper mutates the provided ``Data`` object by adding self-loop - entries to ``edge_index`` and a corresponding edge attribute row for - every node. - """ - - def __call__(self, data: Data) -> Data: - """Modify ``data`` in-place by adding self-loop indices and corresponding edge attributes. - - :param data: ``torch_geometric.data.Data`` to modify. - :return: The modified ``Data`` object (same instance). - """ - num_nodes = data.num_nodes - data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=num_nodes) - self_loop_attr = torch.LongTensor([0, 5, 8, 10, 12]).repeat(num_nodes, 1) - data.edge_attr = torch.cat((data.edge_attr, self_loop_attr), dim=0) - return data - - -class AddSegId: - """Attach zero-valued segment id tensors to nodes and edges. - - The created ``node_seg`` and ``edge_seg`` tensors are added to the - provided ``Data`` instance and used by the MolGNet embedding layers. - """ - - def __init__(self) -> None: - """Create an AddSegId callable (no parameters).""" - pass - - def __call__(self, data: Data) -> Data: - """Attach zero-filled ``node_seg`` and ``edge_seg`` tensors to ``data``. - - :param data: ``torch_geometric.data.Data`` to modify. - :return: The modified ``Data`` object (same instance). 
- """ - num_nodes = data.num_nodes - num_edges = data.num_edges - node_seg = [0 for _ in range(num_nodes)] - edge_seg = [0 for _ in range(num_edges)] - data.edge_seg = torch.LongTensor(edge_seg) - data.node_seg = torch.LongTensor(node_seg) - return data - - -# MolGNet model - - -class BertLayerNorm(nn.Module): - """Layer normalization compatible with BERT-style implementations. - - :param hidden_size: Dimension of the last axis to normalize. - :param eps: Small epsilon for numerical stability. - """ - - def __init__(self, hidden_size, eps=1e-12): - """Create a BertLayerNorm module. - - :param hidden_size: Dimension of the last axis to normalize. - :param eps: Small epsilon for numerical stability. - """ - super().__init__() - self.shape = torch.Size((hidden_size,)) - self.eps = eps - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply layer normalization to the last dimension of ``x``. - - :param x: Input tensor. - :return: Normalized tensor with same shape as ``x``. - """ - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight * x + self.bias - return x - - -def gelu(x: torch.Tensor) -> torch.Tensor: - """Gaussian Error Linear Unit activation (approximation). - - :param x: Input tensor. - :return: Activated tensor. - """ - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2))) - - -def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Apply GELU to ``bias + y``. - - :param bias: Bias tensor to add. - :param y: Linear output tensor. - :return: GELU applied to ``bias + y``. - """ - x = bias + y - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2))) - - -class LinearActivation(nn.Module): - """Linear layer with optional bias-aware GELU activation. - - :param in_features: Input feature dimension. - :param out_features: Output feature dimension. 
- :param bias: Whether to use a bias parameter and the biased GELU. - """ - - def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: - """ - Create a LinearActivation module. - - :param in_features: Input feature dimension. - :param out_features: Output feature dimension. - :param bias: Whether to use a bias parameter and the biased GELU. - """ - super().__init__() - self.in_features = in_features - self.out_features = out_features - if bias: - self.biased_act_fn = bias_gelu - else: - self.act_fn = gelu - self.weight = Parameter(torch.Tensor(out_features, in_features)) - if bias: - self.bias = Parameter(torch.Tensor(out_features)) - else: - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self) -> None: - """Initialize the layer parameters.""" - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) - nn.init.uniform_(self.bias, -bound, bound) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """Apply the linear transformation and activation. - - :param input: Input tensor of shape [N, in_features]. - :return: Transformed tensor of shape [N, out_features]. - """ - if self.bias is not None: - linear_out = torch_nn_f.linear(input, self.weight, None) - return self.biased_act_fn(self.bias, linear_out) - else: - return self.act_fn(torch_nn_f.linear(input, self.weight, self.bias)) - - -class Intermediate(nn.Module): - """Intermediate feed-forward block used inside GT layers. - - :param hidden: Hidden dimension size. - """ - - def __init__(self, hidden: int) -> None: - """Create the intermediate dense activation block. - - :param hidden: Hidden dimension size. - """ - super().__init__() - self.dense_act = LinearActivation(hidden, 4 * hidden) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Apply the dense activation to the hidden states. 
- - :param hidden_states: Input tensor of shape [N, hidden]. - :return: Transformed tensor of shape [N, 4*hidden]. - """ - hidden_states = self.dense_act(hidden_states) - return hidden_states - - -class AttentionOut(nn.Module): - """Post-attention output block: projection, dropout and residual norm. - - :param hidden: Hidden dimension used for the linear projection. - :param dropout: Dropout probability. - """ - - def __init__(self, hidden: int, dropout: float) -> None: - """Create an AttentionOut block. - - :param hidden: Hidden dimension used for projection. - :param dropout: Dropout probability. - """ - super().__init__() - self.dense = nn.Linear(hidden, hidden) - self.LayerNorm = BertLayerNorm(hidden, eps=1e-12) - self.dropout = nn.Dropout(dropout) - - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: - """Project attention outputs and apply layer norm with residual. - - :param hidden_states: Attention output tensor. - :param input_tensor: Residual tensor to add before normalization. - :return: Normalized tensor with the same shape as ``input_tensor``. - """ - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class GTOut(nn.Module): - """Output projection used in GT blocks. - - :param hidden: Hidden dimension. - :param dropout: Dropout probability. - """ - - def __init__(self, hidden: int, dropout: float) -> None: - """Create a GTOut projection block. - - :param hidden: Hidden dimension. - :param dropout: Dropout probability. - """ - super().__init__() - self.dense = nn.Linear(hidden * 4, hidden) - self.LayerNorm = BertLayerNorm(hidden, eps=1e-12) - self.dropout = nn.Dropout(dropout) - - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: - """Project intermediate states back to hidden dimension and normalize. 
- - :param hidden_states: Intermediate tensor of shape [N, 4*hidden]. - :param input_tensor: Residual tensor to add. - :return: Tensor of shape [N, hidden]. - """ - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class MessagePassing(nn.Module): - """Minimal MessagePassing base class used by the MolGNet layers. - - This class provides a lightweight implementation of propagate/ - message/aggregate/update used in graph convolutions. - - :param aggr: Aggregation method (e.g., 'add', 'mean'). - :param flow: Message flow direction. - :param node_dim: Node dimension index (unused in this minimal impl). - """ - - def __init__(self, aggr: str = "add", flow: str = "source_to_target", node_dim: int = 0) -> None: - """Create a MessagePassing helper. - - :param aggr: Aggregation method (e.g., 'add' or 'mean'). - :param flow: Message flow direction. - :param node_dim: Node dimension index. - """ - super().__init__() - self.aggr = aggr - self.flow = flow - self.node_dim = node_dim - - def propagate(self, edge_index: torch.Tensor, size: Optional[tuple[int, int]] = None, **kwargs) -> torch.Tensor: - """Run full message-passing: message -> aggregate -> update. - - :param edge_index: Edge indices tensor of shape [2, E]. - :param size: Optional pair describing (num_nodes_source, num_nodes_target). - :param kwargs: Additional data (e.g., node features) needed for message computation. - :raises ValueError: If required inputs (e.g., 'x') are missing or indexing fails. - :return: Updated node tensor after aggregation. 
- """ - i = 1 if self.flow == "source_to_target" else 0 - j = 0 if i == 1 else 1 - x = kwargs.get("x") - if x is None: - raise ValueError("propagate requires node features passed as keyword 'x'") - try: - x_i = x[edge_index[i]] - x_j = x[edge_index[j]] - except Exception as exc: # defensive - raise ValueError("failed to index node features with edge_index") from exc - msg = self.message( - edge_index_i=edge_index[i], - edge_index_j=edge_index[j], - x_i=x_i, - x_j=x_j, - **kwargs, - ) - # determine number of destination nodes for aggregation - if hasattr(x, "size"): - dim_size = x.size(0) - else: - dim_size = len(x) - out = self.aggregate(msg, index=edge_index[i], dim_size=dim_size) - out = self.update(out) - return out - - def message(self, *args: Any, **kwargs: Any) -> torch.Tensor: - """Default message function returning neighbor features. - - Subclasses may provide richer signatures; this generic form allows - subclass overrides while keeping the base class typed. - - :param args: Positional arguments forwarded by propagate. - :param kwargs: Keyword arguments forwarded by propagate. - :raises ValueError: If required node features are not present. - :return: Message tensor. - """ - x_j = kwargs.get("x_j") if "x_j" in kwargs else (args[1] if len(args) > 1 else None) - if x_j is None: - raise ValueError("message requires node features 'x_j'") - return x_j - - def aggregate(self, inputs: torch.Tensor, index: torch.Tensor, dim_size: Optional[int] = None) -> torch.Tensor: - """Aggregate messages using ``torch_scatter.scatter``. - - :param inputs: Message tensor of shape [E, hidden]. - :param index: Indices to aggregate into nodes. - :param dim_size: Optional target size for the aggregation dimension. - :return: Aggregated node tensor. 
- """ - from torch_scatter import scatter # local dependency - - return scatter( - inputs, - index, - dim=0, - dim_size=dim_size, - reduce=self.aggr, - ) - - def update(self, inputs: torch.Tensor) -> torch.Tensor: - """Identity update by default. - - Override to apply post-aggregation transformations. - - :param inputs: Aggregated node tensor. - :return: Updated tensor. - """ - return inputs - - -class GraphAttentionConv(MessagePassing): - """Graph attention convolution used by MolGNet. - - :param hidden: Hidden feature dimension. - :param heads: Number of attention heads. - :param dropout: Attention dropout probability. - """ - - def __init__(self, hidden: int, heads: int = 3, dropout: float = 0.0) -> None: - """Create a GraphAttentionConv. - - :param hidden: Hidden feature dimension. - :param heads: Number of attention heads. - :param dropout: Dropout probability. - :raises ValueError: If hidden is not divisible by heads. - """ - super().__init__() - self.hidden = hidden - self.heads = heads - if hidden % heads != 0: - raise ValueError("hidden must be divisible by heads") - self.query = nn.Linear(hidden, heads * int(hidden / heads)) - self.key = nn.Linear(hidden, heads * int(hidden / heads)) - self.value = nn.Linear(hidden, heads * int(hidden / heads)) - self.attn_drop = nn.Dropout(dropout) - - def forward( - self, - x: torch.Tensor, - edge_index: torch.Tensor, - edge_attr: torch.Tensor, - size: Optional[tuple[int, int]] = None, - ) -> torch.Tensor: - """Execute the graph attention conv over the provided inputs. - - :param x: Node feature tensor. - :param edge_index: Edge indices tensor. - :param edge_attr: Edge attribute tensor. - :param size: Optional size tuple. - :return: Updated node tensor after attention. 
- """ - pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr - return self.propagate(edge_index=edge_index, x=x, pseudo=pseudo) - - def message( - self, - edge_index_i: torch.Tensor, - x_i: torch.Tensor, - x_j: torch.Tensor, - pseudo: torch.Tensor, - size_i: Optional[int] = None, - **kwargs, - ) -> torch.Tensor: - """Compute messages using multi-head attention between nodes. - - :param edge_index_i: Source indices for edges. - :param x_i: Node features for source nodes. - :param x_j: Node features for target nodes. - :param pseudo: Edge pseudo-features (edge attributes). - :param size_i: Optional number of destination nodes. - :param kwargs: Additional keyword arguments (ignored). - :return: Message tensor shaped for aggregation. - """ - query = self.query(x_i).view( - -1, - self.heads, - int(self.hidden / self.heads), - ) - key = self.key(x_j + pseudo).view( - -1, - self.heads, - int(self.hidden / self.heads), - ) - value = self.value(x_j + pseudo).view( - -1, - self.heads, - int(self.hidden / self.heads), - ) - denom = math.sqrt(int(self.hidden / self.heads)) - alpha = (query * key).sum(dim=-1) / denom - alpha = softmax(src=alpha, index=edge_index_i, num_nodes=size_i) - alpha = self.attn_drop(alpha.view(-1, self.heads, 1)) - return alpha * value - - def update(self, aggr_out: torch.Tensor) -> torch.Tensor: - """Reshape aggregated outputs from multi-head to flat hidden dim. - - :param aggr_out: Aggregated output tensor of shape [N*heads, head_dim]. - :return: Reshaped tensor of shape [N, hidden]. - """ - aggr_out = aggr_out.view(-1, self.heads * int(self.hidden / self.heads)) - return aggr_out - - -class GTLayer(nn.Module): - """Graph Transformer layer composed from attention and feed-forward blocks. - - :param hidden: Hidden dimension size. - :param heads: Number of attention heads. - :param dropout: Dropout probability. - :param num_message_passing: Number of internal message passing steps. 
- """ - - def __init__(self, hidden: int, heads: int, dropout: float, num_message_passing: int) -> None: - """Create a GTLayer composed of attention and feed-forward blocks. - - :param hidden: Hidden dimension size. - :param heads: Number of attention heads. - :param dropout: Dropout probability. - :param num_message_passing: Number of internal message passing steps. - """ - super().__init__() - self.attention = GraphAttentionConv(hidden, heads, dropout) - self.att_out = AttentionOut(hidden, dropout) - self.intermediate = Intermediate(hidden) - self.output = GTOut(hidden, dropout) - self.gru = nn.GRU(hidden, hidden) - self.LayerNorm = BertLayerNorm(hidden, eps=1e-12) - self.time_step = num_message_passing - - def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor: - """Run the GT layer for the configured number of message-passing steps. - - :param x: Node feature tensor of shape [N, hidden]. - :param edge_index: Edge index tensor. - :param edge_attr: Edge attribute tensor. - :return: Updated node tensor of shape [N, hidden]. - """ - h = x.unsqueeze(0) - for _ in range(self.time_step): - attention_output = self.attention.forward(x, edge_index, edge_attr) - attention_output = self.att_out.forward(attention_output, x) - intermediate_output = self.intermediate.forward(attention_output) - m = self.output.forward(intermediate_output, attention_output) - x, h = self.gru(m.unsqueeze(0), h) - x = self.LayerNorm.forward(x.squeeze(0)) - return x - - -class MolGNet(torch.nn.Module): - """MolGNet model implementation used for node embeddings. - - This implementation is intentionally minimal and only includes the - components required to run a checkpoint and produce per-node - embeddings saved by the featurizer script. - - :param num_layer: Number of GT layers. - :param emb_dim: Embedding dimensionality per node. - :param heads: Number of attention heads. - :param num_message_passing: Message passing steps per layer. 
- :param drop_ratio: Dropout probability. - """ - - def __init__( - self, - num_layer: int, - emb_dim: int, - heads: int, - num_message_passing: int, - drop_ratio: float = 0, - ) -> None: - """Create a MolGNet instance. - - :param num_layer: Number of GT layers. - :param emb_dim: Embedding dimensionality per node. - :param heads: Number of attention heads. - :param num_message_passing: Message passing steps per layer. - :param drop_ratio: Dropout probability. - """ - super().__init__() - self.num_layer = num_layer - self.drop_ratio = drop_ratio - self.x_embedding = torch.nn.Embedding(178, emb_dim) - self.x_seg_embed = torch.nn.Embedding(3, emb_dim) - self.edge_embedding = torch.nn.Embedding(18, emb_dim) - self.edge_seg_embed = torch.nn.Embedding(3, emb_dim) - self.reset_parameters() - self.gnns = torch.nn.ModuleList( - [GTLayer(emb_dim, heads, drop_ratio, num_message_passing) for _ in range(num_layer)] - ) - - def reset_parameters(self) -> None: - """Re-initialize embedding parameters with Xavier uniform. - - This mirrors common initialization used for transformer-style - embeddings. - """ - torch.nn.init.xavier_uniform_(self.x_embedding.weight.data) - torch.nn.init.xavier_uniform_(self.x_seg_embed.weight.data) - torch.nn.init.xavier_uniform_(self.edge_embedding.weight.data) - torch.nn.init.xavier_uniform_(self.edge_seg_embed.weight.data) - - def forward(self, *argv: Any) -> torch.Tensor: - """Forward pass supporting two calling conventions. - - Accepts either explicit tensors (x, edge_index, edge_attr, node_seg, - edge_seg) or a single ``Data`` object containing those attributes. - - :param argv: Positional arguments as described above. - :raises ValueError: If an unsupported number of arguments is provided. - :return: Node embeddings tensor of shape [N, emb_dim]. 
- """ - if len(argv) == 5: - x, edge_index, edge_attr, node_seg, edge_seg = (argv[0], argv[1], argv[2], argv[3], argv[4]) - elif len(argv) == 1: - data = argv[0] - x, edge_index, edge_attr, node_seg, edge_seg = ( - data.x, - data.edge_index, - data.edge_attr, - data.node_seg, - data.edge_seg, - ) - else: - raise ValueError("unmatched number of arguments.") - x = self.x_embedding(x).sum(1) + self.x_seg_embed(node_seg) - edge_attr = self.edge_embedding(edge_attr).sum(1) - edge_attr = edge_attr + self.edge_seg_embed(edge_seg) - for gnn in self.gnns: - x = gnn(x, edge_index, edge_attr) - return x - - -def tensor_to_csv_friendly(tensor: Any) -> np.ndarray: - """Convert a tensor-like object into a NumPy array safe for CSV output. - - :param tensor: Input tensor or array-like object. - :return: NumPy array on CPU. - """ - if isinstance(tensor, torch.Tensor): - return tensor.cpu().detach().numpy() - return np.array(tensor) - - -def run(args: argparse.Namespace) -> None: - """Execute the featurization pipeline for a given dataset. - - The function builds graphs from SMILES, runs the MolGNet checkpoint - to extract node embeddings, and writes per-drug CSVs and pickles in - the dataset folder. - - :param args: Parsed CLI arguments. - :raises FileNotFoundError: If expected files or directories are missing. - :raises ValueError: If expected columns are missing in the input CSV. - :raises Exception: For various failures during graph building or inference. - """ - # Use dataset-oriented paths: {data_path}/{dataset_name}/... - # Expand user (~) and resolve to an absolute path. 
- data_dir = Path(args.data_path).expanduser().resolve() - dataset_dir = data_dir / args.dataset_name - if not dataset_dir.exists(): - raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}") - - out_graphs = str(dataset_dir / "GRAPH_dict.pkl") - out_molg = str(dataset_dir / "MolGNet_dict.pkl") - - # read input csv (expected at {data_path}/{dataset_name}/drug_smiles.csv) - smiles_csv = dataset_dir / "drug_smiles.csv" - if not smiles_csv.exists(): - raise FileNotFoundError(f"Expected SMILES CSV at: {smiles_csv}") - df = pd.read_csv(smiles_csv) - if args.smiles_col not in df.columns or args.id_col not in df.columns: - msg = f"Provided columns not in CSV: {args.smiles_col}, " f"{args.id_col}" - raise ValueError(msg) - df = df.dropna(subset=[args.smiles_col]) - smiles_map = dict(zip(df[args.id_col], df[args.smiles_col])) - - # Build graphs - graph_dict: dict[Any, Data] = {} - failed_conversions = [] - for idx, smi in tqdm(smiles_map.items(), desc="building graphs"): - mol = Chem.MolFromSmiles(smi) - if mol is None: - failed_conversions.append((idx, smi, "MolFromSmiles returned None")) - continue - try: - graph_dict[idx] = mol_to_graph_data_obj_complex(mol) - except Exception as e: - failed_conversions.append((idx, smi, str(e))) - if failed_conversions: - print(f"\n{len(failed_conversions)} molecules failed to convert to graphs.") - for idx, smi, err in failed_conversions: - print(f"Failed to convert {idx} (SMILES: {smi}): {err}") - else: - print("\nAll molecules converted to graphs successfully.") - # save graphs to dataset folder - with open(out_graphs, "wb") as f: - pickle.dump(graph_dict, f) - # load model - if args.device: - device = torch.device(args.device) - else: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - num_layer = 5 - emb_dim = 768 - heads = 12 - msg_pass = 3 - drop = 0.0 - model = MolGNet( - num_layer=num_layer, - emb_dim=emb_dim, - heads=heads, - num_message_passing=msg_pass, - drop_ratio=drop, - ) - # Prefer 
pathlib operations when working with Path objects - checkpoint_path = data_dir / args.checkpoint - ckpt = torch.load(checkpoint_path, map_location=device) # noqa S614 - try: - model.load_state_dict(ckpt) - except Exception: - if isinstance(ckpt, dict) and "state_dict" in ckpt: - model.load_state_dict(ckpt["state_dict"]) - else: - raise - model = model.to(device) - model.eval() - - self_loop = SelfLoop() - add_seg = AddSegId() - - molgnet_dict: dict[Any, torch.Tensor] = {} - with torch.no_grad(): - for idx, graph in tqdm(graph_dict.items(), desc="running model"): - try: - g = self_loop(graph) - g = add_seg(g) - g = g.to(device) - emb = model(g) - molgnet_dict[idx] = emb.cpu() - except Exception as e: - print(f"Inference failed for {idx}: {e}") - - with open(out_molg, "wb") as f: - pickle.dump(molgnet_dict, f) - - # write per-drug CSVs to {dataset_dir}/DIPK_features/Drugs - out_drugs_dir = dataset_dir / "DIPK_features/Drugs" - os.makedirs(out_drugs_dir, exist_ok=True) - for idx, emb in tqdm(molgnet_dict.items(), desc="writing csvs"): - arr = tensor_to_csv_friendly(emb) - df_emb = pd.DataFrame(arr) - out_path = out_drugs_dir / f"MolGNet_{idx}.csv" - df_emb.to_csv(out_path, sep="\t", index=False) - - print("Done.") - print("Graphs saved to:", out_graphs) - print("Node embeddings saved to:", out_molg) - print("Per-drug CSVs in:", out_drugs_dir) - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments. - - :return: Parsed arguments namespace. 
- """ - p = argparse.ArgumentParser(description=("Standalone MolGNet extractor " "(dataset-oriented)")) - p.add_argument( - "dataset_name", - help="Name of the dataset (folder under data_path)", - ) - p.add_argument( - "--data_path", - default="data", - help="Top-level data folder path", - ) - p.add_argument( - "--smiles-col", - dest="smiles_col", - default="canonical_smiles", - help="Column name for SMILES in input CSV", - ) - p.add_argument( - "--id-col", - dest="id_col", - default="pubchem_id", - help="Column name for unique ID in input CSV", - ) - p.add_argument( - "--checkpoint", - default="MolGNet.pt", - help="MolGNet checkpoint (state_dict), can be obtained from Zenodo: https://doi.org/10.5281/zenodo.12633909", - ) - p.add_argument( - "--device", - default=None, - help="torch device string, e.g. cpu or cuda:0", - ) - return p.parse_args() - - -if __name__ == "__main__": - args = parse_args() - run(args) diff --git a/drevalpy/datasets/featurizer/drug/__init__.py b/drevalpy/datasets/featurizer/drug/__init__.py new file mode 100644 index 00000000..c7c5fbd7 --- /dev/null +++ b/drevalpy/datasets/featurizer/drug/__init__.py @@ -0,0 +1,16 @@ +"""Drug featurizers for converting drug representations to embeddings.""" + +from .base import DrugFeaturizer +from .chemberta import ChemBERTaFeaturizer, ChemBERTaMixin +from .drug_graph import DrugGraphFeaturizer, DrugGraphMixin +from .molgnet import MolGNetFeaturizer, MolGNetMixin + +__all__ = [ + "DrugFeaturizer", + "ChemBERTaFeaturizer", + "ChemBERTaMixin", + "DrugGraphFeaturizer", + "DrugGraphMixin", + "MolGNetFeaturizer", + "MolGNetMixin", +] diff --git a/drevalpy/datasets/featurizer/drug/base.py b/drevalpy/datasets/featurizer/drug/base.py new file mode 100644 index 00000000..c03e334c --- /dev/null +++ b/drevalpy/datasets/featurizer/drug/base.py @@ -0,0 +1,193 @@ +"""Abstract base class for drug featurizers.""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import numpy as np 
+import pandas as pd + +from drevalpy.datasets.dataset import FeatureDataset +from drevalpy.datasets.utils import DRUG_IDENTIFIER + + +class DrugFeaturizer(ABC): + """Abstract base class for drug featurizers. + + Drug featurizers convert drug representations (e.g., SMILES strings) into + numerical embeddings that can be used as input features for machine learning models. + + Subclasses must implement: + - featurize(): Convert a single drug to its embedding + - get_feature_name(): Return the name of the feature view + - get_output_filename(): Return the filename for cached embeddings + + The base class provides: + - load_or_generate(): Load cached embeddings or generate and cache them + - generate_embeddings(): Generate embeddings for all drugs in a dataset + - load_embeddings(): Load pre-generated embeddings from disk + """ + + def __init__(self, device: str = "cpu"): + """Initialize the featurizer. + + :param device: Device to use for computation (e.g., 'cpu', 'cuda') + """ + self.device = device + + @abstractmethod + def featurize(self, smiles: str) -> np.ndarray | Any: + """Convert a SMILES string to a feature representation. + + :param smiles: SMILES string representing the drug + :returns: Feature representation (numpy array or other format like torch_geometric.Data) + """ + + @classmethod + @abstractmethod + def get_feature_name(cls) -> str: + """Return the name of the feature view. + + This name is used as the key in the FeatureDataset. + + :returns: Feature view name (e.g., 'chemberta_embeddings') + """ + + @classmethod + @abstractmethod + def get_output_filename(cls) -> str: + """Return the filename for cached embeddings. + + :returns: Filename (e.g., 'drug_chemberta_embeddings.csv') + """ + + def load_or_generate(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load cached embeddings or generate and cache them if not available. + + This is the main entry point for using a featurizer. 
It checks if
        pre-generated embeddings exist and loads them, otherwise generates
        new embeddings and saves them for future use.

        :param data_path: Path to the data directory (e.g., 'data/')
        :param dataset_name: Name of the dataset (e.g., 'GDSC1')
        :returns: FeatureDataset containing the drug embeddings
        """
        output_path = Path(data_path) / dataset_name / self.get_output_filename()

        # Cached embeddings take precedence; generation is the fallback path.
        if output_path.exists():
            return self.load_embeddings(data_path, dataset_name)
        else:
            print(f"Embeddings not found at {output_path}. Generating...")
            return self.generate_embeddings(data_path, dataset_name)

    def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset:
        """Generate embeddings for all drugs in a dataset and save to disk.

        :param data_path: Path to the data directory
        :param dataset_name: Name of the dataset
        :returns: FeatureDataset containing the generated embeddings
        :raises FileNotFoundError: If the drug_smiles.csv file is not found
        """
        data_dir = Path(data_path).resolve()
        smiles_file = data_dir / dataset_name / "drug_smiles.csv"
        output_file = data_dir / dataset_name / self.get_output_filename()

        if not smiles_file.exists():
            raise FileNotFoundError(f"SMILES file not found: {smiles_file}")

        # Force string dtypes so numeric-looking drug IDs are not coerced to int/float.
        smiles_df = pd.read_csv(smiles_file, dtype={"canonical_smiles": str, DRUG_IDENTIFIER: str})

        embeddings_list = []
        drug_ids = []

        print(f"Processing {len(smiles_df)} drugs for dataset {dataset_name}...")

        for row in smiles_df.itertuples(index=False):
            drug_id = getattr(row, DRUG_IDENTIFIER)
            smiles = row.canonical_smiles

            # Best-effort: a single bad SMILES is reported and skipped rather
            # than aborting the whole dataset run.
            try:
                embedding = self.featurize(smiles)
                embeddings_list.append(embedding)
                drug_ids.append(drug_id)
            except Exception as e:
                print(f"Failed to process drug {drug_id} (SMILES: {smiles}): {e}")
                continue

        # Save embeddings
        self._save_embeddings(embeddings_list, drug_ids, output_file)

        print(f"Embeddings saved to {output_file}")

        # Return as FeatureDataset
        return self._create_feature_dataset(embeddings_list, drug_ids)

    def _save_embeddings(self, embeddings: list, drug_ids: list[str], output_path: Path) -> None:
        """Save embeddings to disk.

        Default implementation saves as CSV. Subclasses can override for other formats.

        :param embeddings: List of embedding arrays
        :param drug_ids: List of drug identifiers
        :param output_path: Path to save the embeddings
        """
        # One row per drug: identifier column first, embedding components after.
        embeddings_df = pd.DataFrame(embeddings)
        embeddings_df.insert(0, DRUG_IDENTIFIER, drug_ids)
        embeddings_df.to_csv(output_path, index=False)

    def _create_feature_dataset(self, embeddings: list, drug_ids: list[str]) -> FeatureDataset:
        """Create a FeatureDataset from embeddings.

        :param embeddings: List of embedding arrays
        :param drug_ids: List of drug identifiers
        :returns: FeatureDataset containing the embeddings
        """
        feature_name = self.get_feature_name()
        features = {}
        # strict=True guards against silent truncation if the two lists diverge.
        for drug_id, embedding in zip(drug_ids, embeddings, strict=True):
            if isinstance(embedding, np.ndarray):
                features[drug_id] = {feature_name: embedding.astype(np.float32)}
            else:
                # Non-ndarray payloads (e.g. graph objects) are stored as-is.
                features[drug_id] = {feature_name: embedding}
        return FeatureDataset(features)

    def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset:
        """Load pre-generated embeddings from disk.

        :param data_path: Path to the data directory
        :param dataset_name: Name of the dataset
        :returns: FeatureDataset containing the embeddings
        :raises FileNotFoundError: If the embeddings file is not found
        """
        embeddings_file = Path(data_path) / dataset_name / self.get_output_filename()

        if not embeddings_file.exists():
            raise FileNotFoundError(
                f"Embeddings file not found: {embeddings_file}. "
                f"Use load_or_generate() to automatically generate embeddings."
            )

        embeddings_df = pd.read_csv(embeddings_file, dtype={DRUG_IDENTIFIER: str})
        feature_name = self.get_feature_name()
        features = {}

        # Each row is one drug: first column is the identifier, the rest the embedding.
        for _, row in embeddings_df.iterrows():
            drug_id = row[DRUG_IDENTIFIER]
            embedding = row.drop(DRUG_IDENTIFIER).to_numpy(dtype=np.float32)
            features[drug_id] = {feature_name: embedding}

        return FeatureDataset(features)


def main():
    """Entry point for running featurizer from command line.

    This function should be overridden by subclasses that support CLI usage.

    :raises NotImplementedError: Always, as subclasses should implement their own main()
    """
    raise NotImplementedError("Subclasses should implement their own main() function")


if __name__ == "__main__":
    main()
diff --git a/drevalpy/datasets/featurizer/drug/chemberta.py b/drevalpy/datasets/featurizer/drug/chemberta.py
new file mode 100644
index 00000000..9b642d77
--- /dev/null
+++ b/drevalpy/datasets/featurizer/drug/chemberta.py
@@ -0,0 +1,144 @@
"""ChemBERTa drug featurizer for generating embeddings from SMILES strings."""

import argparse

import numpy as np
import torch

from drevalpy.datasets.dataset import FeatureDataset

from .base import DrugFeaturizer


class ChemBERTaFeaturizer(DrugFeaturizer):
    """Featurizer that generates ChemBERTa embeddings from SMILES strings.

    ChemBERTa is a transformer model pre-trained on chemical SMILES strings.
    This featurizer uses the model to generate fixed-size embeddings for drugs.

    Example usage::

        featurizer = ChemBERTaFeaturizer(device="cuda")
        features = featurizer.load_or_generate("data", "GDSC1")
    """

    def __init__(self, device: str = "cpu"):
        """Initialize the ChemBERTa featurizer.

        :param device: Device to use for computation ('cpu' or 'cuda')
        """
        super().__init__(device=device)
        # Model and tokenizer are loaded lazily on first featurize() call.
        self._tokenizer = None
        self._model = None

    def _load_model(self):
        """Lazily load the ChemBERTa model and tokenizer.
+ + :raises ImportError: If transformers or torch packages are not installed + """ + if self._model is None: + try: + from transformers import AutoModel, AutoTokenizer + except ImportError: + raise ImportError( + "Please install transformers package for ChemBERTa featurizer: pip install transformers torch" + ) + + self._tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1") + self._model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1") + self._model.to(self.device) + self._model.eval() + + def featurize(self, smiles: str) -> np.ndarray: + """Convert a SMILES string to a ChemBERTa embedding. + + :param smiles: SMILES string representing the drug + :returns: ChemBERTa embedding as numpy array + :raises RuntimeError: If model is not loaded + """ + self._load_model() + + if self._tokenizer is None or self._model is None: + raise RuntimeError("Model not loaded. Call _load_model() first.") + + inputs = self._tokenizer(smiles, return_tensors="pt", truncation=True) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self._model(**inputs) + hidden_states = outputs.last_hidden_state + + # Mean pooling over sequence length + embedding = hidden_states.mean(dim=1).squeeze(0) + return embedding.cpu().numpy() + + @classmethod + def get_feature_name(cls) -> str: + """Return the feature view name. + + :returns: 'chemberta_embeddings' + """ + return "chemberta_embeddings" + + @classmethod + def get_output_filename(cls) -> str: + """Return the output filename for cached embeddings. + + :returns: 'drug_chemberta_embeddings.csv' + """ + return "drug_chemberta_embeddings.csv" + + +class ChemBERTaMixin: + """Mixin that provides ChemBERTa drug embeddings loading for DRP models. + + This mixin implements load_drug_features using the ChemBERTaFeaturizer. + It automatically generates embeddings if they don't exist. 
+ + Class attributes that can be overridden: + - chemberta_device: Device for ChemBERTa model ('cpu', 'cuda', or 'auto') + + Example usage:: + + from drevalpy.models.drp_model import DRPModel + from drevalpy.datasets.featurizer.drug.chemberta import ChemBERTaMixin + + class MyModel(ChemBERTaMixin, DRPModel): + drug_views = ["chemberta_embeddings"] + ... + """ + + chemberta_device: str = "auto" + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load ChemBERTa drug embeddings. + + Uses the ChemBERTaFeaturizer to load pre-generated embeddings or generate + them automatically if they don't exist. + + :param data_path: Path to the data directory, e.g., 'data/' + :param dataset_name: Name of the dataset, e.g., 'GDSC1' + :returns: FeatureDataset containing the ChemBERTa embeddings + """ + device = self.chemberta_device + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + + featurizer = ChemBERTaFeaturizer(device=device) + return featurizer.load_or_generate(data_path, dataset_name) + + +def main(): + """Process drug SMILES and save ChemBERTa embeddings from command line.""" + parser = argparse.ArgumentParser(description="Generate ChemBERTa embeddings for drugs.") + parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") + parser.add_argument("--device", type=str, default="cpu", help="Torch device (cpu or cuda)") + parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") + args = parser.parse_args() + + featurizer = ChemBERTaFeaturizer(device=args.device) + featurizer.generate_embeddings(args.data_path, args.dataset_name) + + +if __name__ == "__main__": + main() diff --git a/drevalpy/datasets/featurizer/drug/drug_graph.py b/drevalpy/datasets/featurizer/drug/drug_graph.py new file mode 100644 index 00000000..cdd0182b --- /dev/null +++ b/drevalpy/datasets/featurizer/drug/drug_graph.py @@ -0,0 +1,250 @@ +"""Drug graph featurizer for 
converting SMILES to molecular graphs."""

import argparse
import os
from pathlib import Path

import torch
from torch_geometric.data import Data

from drevalpy.datasets.dataset import FeatureDataset

from .base import DrugFeaturizer

try:
    from rdkit import Chem
except ImportError:
    # rdkit is optional at import time; _init_rdkit_features() raises on first use.
    Chem = None


# Atom feature configuration
ATOM_FEATURES = {
    "atomic_num": list(range(1, 119)),
    "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    "num_hs": [0, 1, 2, 3, 4, 5, 6, 7, 8],
    "hybridization": [],  # Will be populated after rdkit import check
}

# Bond feature configuration
BOND_FEATURES: dict[str, list] = {
    "bond_type": [],  # Will be populated after rdkit import check
}


def _init_rdkit_features():
    """Initialize RDKit-dependent feature configurations.

    :raises ImportError: If rdkit package is not installed
    """
    if Chem is None:
        raise ImportError("Please install rdkit package for drug graphs featurizer: pip install rdkit")

    ATOM_FEATURES["hybridization"] = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
    ]
    BOND_FEATURES["bond_type"] = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ]


def _one_hot_encode(value, choices):
    """Create a one-hot encoding for a value in a list of choices.

    :param value: The value to be one-hot encoded.
    :param choices: A list of possible choices for the value.
    :return: A list representing the one-hot encoding.
    """
    # The extra trailing slot encodes "unknown": index -1 is set whenever the
    # value is not among the known choices, so the encoding never fails.
    encoding = [0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else -1
    encoding[index] = 1
    return encoding


class DrugGraphFeaturizer(DrugFeaturizer):
    """Featurizer that converts SMILES strings to molecular graphs.
+ + The graphs are stored as torch_geometric.data.Data objects with: + - x: Node features (atom features) + - edge_index: Edge connectivity + - edge_attr: Edge features (bond features) + + Example usage:: + + featurizer = DrugGraphFeaturizer() + features = featurizer.load_or_generate("data", "GDSC1") + """ + + def __init__(self, device: str = "cpu"): + """Initialize the drug graph featurizer. + + :param device: Device to use (not used for graph generation, but kept for API consistency) + """ + super().__init__(device=device) + _init_rdkit_features() + + def featurize(self, smiles: str) -> Data | None: + """Convert a SMILES string to a molecular graph. + + :param smiles: SMILES string representing the drug + :returns: torch_geometric.data.Data object or None if conversion fails + """ + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None + + # Atom features + atom_features_list = [] + for atom in mol.GetAtoms(): + features = [] + features.extend(_one_hot_encode(atom.GetAtomicNum(), ATOM_FEATURES["atomic_num"])) + features.extend(_one_hot_encode(atom.GetDegree(), ATOM_FEATURES["degree"])) + features.extend(_one_hot_encode(atom.GetFormalCharge(), ATOM_FEATURES["formal_charge"])) + features.extend(_one_hot_encode(atom.GetTotalNumHs(), ATOM_FEATURES["num_hs"])) + features.extend(_one_hot_encode(atom.GetHybridization(), ATOM_FEATURES["hybridization"])) + features.append(atom.GetIsAromatic()) + features.append(atom.IsInRing()) + atom_features_list.append(features) + x = torch.tensor(atom_features_list, dtype=torch.float) + + # Edge index and edge features + edge_indices = [] + edge_features_list = [] + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + + # Edge features + features = [] + features.extend(_one_hot_encode(bond.GetBondType(), BOND_FEATURES["bond_type"])) + features.append(bond.GetIsConjugated()) + features.append(bond.IsInRing()) + + edge_indices.extend([[i, j], [j, i]]) + edge_features_list.extend([features, 
features]) # Same features for both directions + + edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous() + edge_attr = torch.tensor(edge_features_list, dtype=torch.float) + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + @classmethod + def get_feature_name(cls) -> str: + """Return the feature view name. + + :returns: 'drug_graphs' + """ + return "drug_graphs" + + @classmethod + def get_output_filename(cls) -> str: + """Return the output directory name for cached graphs. + + :returns: 'drug_graphs' + """ + return "drug_graphs" + + def _save_embeddings(self, embeddings: list, drug_ids: list[str], output_path: Path) -> None: + """Save graph embeddings to disk as individual .pt files. + + :param embeddings: List of Data objects + :param drug_ids: List of drug identifiers + :param output_path: Directory path to save the graphs + """ + os.makedirs(output_path, exist_ok=True) + for drug_id, graph in zip(drug_ids, embeddings, strict=True): + if graph is not None: + torch.save(graph, output_path / f"{drug_id}.pt") + + def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load pre-generated graph embeddings from disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the graph embeddings + :raises FileNotFoundError: If the graphs directory is not found + """ + graphs_dir = Path(data_path) / dataset_name / self.get_output_filename() + + if not graphs_dir.exists(): + raise FileNotFoundError( + f"Graphs directory not found: {graphs_dir}. " + f"Use load_or_generate() to automatically generate graphs." 
            )

        feature_name = self.get_feature_name()
        features = {}

        # One graph per .pt file; the file stem is the drug identifier.
        for graph_file in graphs_dir.glob("*.pt"):
            drug_id = graph_file.stem
            graph = torch.load(graph_file)  # noqa: S614
            features[drug_id] = {feature_name: graph}

        return FeatureDataset(features)

    def load_or_generate(self, data_path: str, dataset_name: str) -> FeatureDataset:
        """Load cached graphs or generate and cache them if not available.

        :param data_path: Path to the data directory
        :param dataset_name: Name of the dataset
        :returns: FeatureDataset containing the drug graphs
        """
        output_path = Path(data_path) / dataset_name / self.get_output_filename()

        # The cache directory may exist but be empty; only reuse it when it
        # actually contains serialized graphs.
        if output_path.exists() and any(output_path.glob("*.pt")):
            return self.load_embeddings(data_path, dataset_name)
        else:
            print(f"Graphs not found at {output_path}. Generating...")
            return self.generate_embeddings(data_path, dataset_name)


class DrugGraphMixin:
    """Mixin that provides drug graph loading for DRP models.

    This mixin implements load_drug_features using the DrugGraphFeaturizer.
    It automatically generates graphs if they don't exist.

    Example usage::

        from drevalpy.models.drp_model import DRPModel
        from drevalpy.datasets.featurizer.drug.drug_graph import DrugGraphMixin

        class MyModel(DrugGraphMixin, DRPModel):
            drug_views = ["drug_graphs"]
            ...
    """

    def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
        """Load drug graph features.

        Uses the DrugGraphFeaturizer to load pre-generated graphs or generate
        them automatically if they don't exist.

        :param data_path: Path to the data directory, e.g., 'data/'
        :param dataset_name: Name of the dataset, e.g., 'GDSC1'
        :returns: FeatureDataset containing the drug graphs
        """
        featurizer = DrugGraphFeaturizer()
        return featurizer.load_or_generate(data_path, dataset_name)


def main():
    """Process drug SMILES and save molecular graphs from command line."""
    parser = argparse.ArgumentParser(description="Generate molecular graphs for drugs.")
    parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.")
    parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder")
    args = parser.parse_args()

    featurizer = DrugGraphFeaturizer()
    featurizer.generate_embeddings(args.data_path, args.dataset_name)


if __name__ == "__main__":
    main()
diff --git a/drevalpy/datasets/featurizer/drug/molgnet.py b/drevalpy/datasets/featurizer/drug/molgnet.py
new file mode 100644
index 00000000..53e872ea
--- /dev/null
+++ b/drevalpy/datasets/featurizer/drug/molgnet.py
@@ -0,0 +1,862 @@
"""MolGNet drug featurizer for generating graph-based embeddings.

This module provides a featurizer that uses the MolGNet model to generate
node embeddings for molecules. It requires a pre-trained MolGNet checkpoint.
+""" + +import argparse +import math +import os +import pickle # noqa: S403 +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as torch_nn_f +from torch import nn +from torch.nn import Parameter +from torch_geometric.data import Data +from torch_geometric.utils import add_self_loops, softmax + +from drevalpy.datasets.dataset import FeatureDataset + +from .base import DrugFeaturizer + +try: + from rdkit import Chem + from rdkit.Chem.rdchem import Mol as RDMol +except ImportError: + Chem = None + RDMol = None + + +# Feature configuration for MolGNet graph building +allowable_features: dict[str, list[Any]] = { + "atomic_num": list(range(1, 122)), + "formal_charge": ["unk", -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], + "chirality": [], # Populated after rdkit import check + "hybridization": [], # Populated after rdkit import check + "numH": ["unk", 0, 1, 2, 3, 4, 5, 6, 7, 8], + "implicit_valence": ["unk", 0, 1, 2, 3, 4, 5, 6], + "degree": ["unk", 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "isaromatic": [False, True], + "bond_type": [], # Populated after rdkit import check + "bond_dirs": [], # Populated after rdkit import check + "bond_isconjugated": [False, True], + "bond_inring": [False, True], + "bond_stereo": [ + "STEREONONE", + "STEREOANY", + "STEREOZ", + "STEREOE", + "STEREOCIS", + "STEREOTRANS", + ], +} + + +def _init_rdkit_features(): + """Initialize RDKit-dependent feature configurations. 
+ + :raises ImportError: If rdkit package is not installed + """ + if Chem is None: + raise ImportError("Please install rdkit package for MolGNet featurizer: pip install rdkit") + + allowable_features["chirality"] = [ + "unk", + Chem.rdchem.ChiralType.CHI_UNSPECIFIED, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, + Chem.rdchem.ChiralType.CHI_OTHER, + ] + allowable_features["hybridization"] = [ + "unk", + Chem.rdchem.HybridizationType.S, + Chem.rdchem.HybridizationType.SP, + Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, + Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2, + Chem.rdchem.HybridizationType.UNSPECIFIED, + ] + allowable_features["bond_type"] = [ + "unk", + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC, + ] + allowable_features["bond_dirs"] = [ + Chem.rdchem.BondDir.NONE, + Chem.rdchem.BondDir.ENDUPRIGHT, + Chem.rdchem.BondDir.ENDDOWNRIGHT, + ] + + +# Compute cumulative sums for feature indexing +atom_dic = [ + len(allowable_features["atomic_num"]), + 12, # formal_charge + 5, # chirality + 8, # hybridization + 10, # numH + 7, # implicit_valence + 12, # degree + 2, # isaromatic +] +bond_dic = [ + 5, # bond_type + 3, # bond_dirs + 2, # bond_isconjugated + 2, # bond_inring + 6, # bond_stereo +] +atom_cumsum = np.cumsum(atom_dic) +bond_cumsum = np.cumsum(bond_dic) + + +def mol_to_graph_data_obj_complex(mol: "RDMol") -> Data: + """Convert an RDKit Mol into a torch_geometric Data object for MolGNet. 

    :param mol: RDKit Mol instance
    :returns: torch_geometric.data.Data object
    :raises ValueError: If mol is None
    """
    if mol is None:
        raise ValueError("mol must not be None")

    _init_rdkit_features()

    atom_features_list: list = []
    # Local aliases for the lookup tables used in the loops below.
    fc_list = allowable_features["formal_charge"]
    ch_list = allowable_features["chirality"]
    hyb_list = allowable_features["hybridization"]
    numh_list = allowable_features["numH"]
    imp_list = allowable_features["implicit_valence"]
    deg_list = allowable_features["degree"]
    isa_list = allowable_features["isaromatic"]
    bt_list = allowable_features["bond_type"]
    bd_list = allowable_features["bond_dirs"]
    bic_list = allowable_features["bond_isconjugated"]
    bir_list = allowable_features["bond_inring"]
    bs_list = allowable_features["bond_stereo"]

    for atom in mol.GetAtoms():
        # Each categorical feature is offset by the cumulative size of the
        # preceding categories so all indices address one shared embedding table.
        a_idx = allowable_features["atomic_num"].index(atom.GetAtomicNum())
        fc_idx = fc_list.index(atom.GetFormalCharge()) + atom_cumsum[0]
        ch_idx = ch_list.index(atom.GetChiralTag()) + atom_cumsum[1]
        hyb_idx = hyb_list.index(atom.GetHybridization()) + atom_cumsum[2]
        numh_idx = numh_list.index(atom.GetTotalNumHs()) + atom_cumsum[3]
        imp_idx = imp_list.index(atom.GetImplicitValence()) + atom_cumsum[4]
        deg_idx = deg_list.index(atom.GetDegree()) + atom_cumsum[5]
        isa_idx = isa_list.index(atom.GetIsAromatic()) + atom_cumsum[6]

        atom_feature = [a_idx, fc_idx, ch_idx, hyb_idx, numh_idx, imp_idx, deg_idx, isa_idx]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds
    num_bond_features = 5
    if len(mol.GetBonds()) > 0:
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            bt = bt_list.index(bond.GetBondType())
            bd = bd_list.index(bond.GetBondDir()) + bond_cumsum[0]
            bic = bic_list.index(bond.GetIsConjugated()) + bond_cumsum[1]
            bir = bir_list.index(bond.IsInRing()) + bond_cumsum[2]
            bs = bs_list.index(str(bond.GetStereo())) + bond_cumsum[3]

            edge_feature = [bt, bd, bic, bir, bs]
            # Undirected molecule: store both directions with identical features.
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)
    else:
        # Bond-less molecule: emit correctly shaped empty tensors.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


class SelfLoop:
    """Callable that appends self-loops and matching edge attributes."""

    def __call__(self, data: Data) -> Data:
        """Add self-loop indices and corresponding edge attributes.

        :param data: torch_geometric.data.Data to modify
        :returns: Modified Data object
        """
        num_nodes = data.num_nodes
        data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=num_nodes)
        # NOTE(review): [0, 5, 8, 10, 12] looks like a dedicated "self-loop"
        # code per bond-feature slot — confirm against the pre-trained
        # MolGNet checkpoint's edge vocabulary before changing.
        self_loop_attr = torch.LongTensor([0, 5, 8, 10, 12]).repeat(num_nodes, 1)
        data.edge_attr = torch.cat((data.edge_attr, self_loop_attr), dim=0)
        return data


class AddSegId:
    """Attach zero-valued segment id tensors to nodes and edges."""

    def __call__(self, data: Data) -> Data:
        """Attach zero-filled node_seg and edge_seg tensors.

        :param data: torch_geometric.data.Data to modify
        :returns: Modified Data object
        """
        num_nodes = data.num_nodes
        num_edges = data.num_edges
        data.edge_seg = torch.LongTensor([0] * num_edges)
        data.node_seg = torch.LongTensor([0] * num_nodes)
        return data


# MolGNet model components


class BertLayerNorm(nn.Module):
    """Layer normalization compatible with BERT-style implementations."""

    def __init__(self, hidden_size: int, eps: float = 1e-12) -> None:
        """Initialize the layer normalization.

        :param hidden_size: Size of the hidden dimension
        :param eps: Small constant for numerical stability
        """
        super().__init__()
        self.eps = eps
        # Learnable affine parameters (gamma/beta in the BERT formulation).
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply layer normalization.

        :param x: Input tensor
        :returns: Normalized tensor
        """
        # Normalize over the last dimension, then apply the affine transform.
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight * x + self.bias


def gelu(x: torch.Tensor) -> torch.Tensor:
    """Apply Gaussian Error Linear Unit activation.

    :param x: Input tensor
    :returns: Activated tensor
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2)))


def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Apply GELU activation to bias + y.

    :param bias: Bias tensor
    :param y: Input tensor
    :returns: Activated tensor
    """
    x = bias + y
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2)))


class LinearActivation(nn.Module):
    """Linear layer with optional bias-aware GELU activation."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        """Initialize the linear activation layer.

        :param in_features: Number of input features
        :param out_features: Number of output features
        :param bias: Whether to include a bias term
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # With a bias, the GELU is fused with the bias addition (bias_gelu);
        # without one, a plain GELU is applied to the linear output.
        if bias:
            self.biased_act_fn = bias_gelu
        else:
            self.act_fn = gelu
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset layer parameters using Kaiming initialization."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply linear transformation with GELU activation.

        :param input: Input tensor
        :returns: Transformed tensor
        """
        if self.bias is not None:
            # Bias is applied inside the fused activation, not the linear op.
            linear_out = torch_nn_f.linear(input, self.weight, None)
            return self.biased_act_fn(self.bias, linear_out)
        else:
            return self.act_fn(torch_nn_f.linear(input, self.weight, self.bias))


class Intermediate(nn.Module):
    """Intermediate feed-forward block used inside GT layers."""

    def __init__(self, hidden: int) -> None:
        """Initialize the intermediate layer.

        :param hidden: Hidden dimension size
        """
        super().__init__()
        # Standard transformer feed-forward expansion factor of 4.
        self.dense_act = LinearActivation(hidden, 4 * hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply feed-forward transformation.

        :param hidden_states: Input tensor
        :returns: Transformed tensor
        """
        return self.dense_act(hidden_states)


class AttentionOut(nn.Module):
    """Post-attention output block: projection, dropout and residual norm."""

    def __init__(self, hidden: int, dropout: float) -> None:
        """Initialize the attention output layer.

        :param hidden: Hidden dimension size
        :param dropout: Dropout probability
        """
        super().__init__()
        self.dense = nn.Linear(hidden, hidden)
        self.LayerNorm = BertLayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Apply output transformation with residual connection.

        :param hidden_states: Attention output tensor
        :param input_tensor: Original input for residual connection
        :returns: Transformed tensor
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)


class GTOut(nn.Module):
    """Output projection used in GT blocks."""

    def __init__(self, hidden: int, dropout: float) -> None:
        """Initialize the GT output layer.

        :param hidden: Hidden dimension size
        :param dropout: Dropout probability
        """
        super().__init__()
        # Projects the 4x-expanded intermediate representation back to hidden.
        self.dense = nn.Linear(hidden * 4, hidden)
        self.LayerNorm = BertLayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Apply output transformation with residual connection.

        :param hidden_states: Intermediate output tensor
        :param input_tensor: Original input for residual connection
        :returns: Transformed tensor
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)


class MessagePassing(nn.Module):
    """Minimal MessagePassing base class used by the MolGNet layers."""

    def __init__(self, aggr: str = "add", flow: str = "source_to_target", node_dim: int = 0) -> None:
        """Initialize the message passing layer.

        :param aggr: Aggregation method ('add', 'mean', 'max')
        :param flow: Direction of message flow
        :param node_dim: Dimension along which to aggregate
        """
        super().__init__()
        self.aggr = aggr
        self.flow = flow
        self.node_dim = node_dim

    def propagate(self, edge_index: torch.Tensor, size: Optional[tuple[int, int]] = None, **kwargs) -> torch.Tensor:
        """Propagate messages along edges.

        :param edge_index: Edge connectivity tensor
        :param size: Optional size tuple for bipartite graphs
        :param kwargs: Additional arguments including node features 'x'
        :returns: Aggregated messages
        :raises ValueError: If node features 'x' are not provided
        """
        # Row i of edge_index receives messages; row j sends them.
        i = 1 if self.flow == "source_to_target" else 0
        j = 0 if i == 1 else 1
        x = kwargs.get("x")
        if x is None:
            raise ValueError("propagate requires node features passed as keyword 'x'")
        x_i = x[edge_index[i]]
        x_j = x[edge_index[j]]
        msg = self.message(
            edge_index_i=edge_index[i],
            edge_index_j=edge_index[j],
            x_i=x_i,
            x_j=x_j,
            **kwargs,
        )
        dim_size = x.size(0) if hasattr(x, "size") else len(x)
        out = self.aggregate(msg, index=edge_index[i], dim_size=dim_size)
        return self.update(out)

    def message(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Compute messages for each edge.

        :param args: Positional arguments
        :param kwargs: Keyword arguments including 'x_j' for source node features
        :returns: Message tensor
        :raises ValueError: If 'x_j' is not provided
        """
        # Default behavior: the message is simply the source node's features.
        x_j = kwargs.get("x_j") if "x_j" in kwargs else (args[1] if len(args) > 1 else None)
        if x_j is None:
            raise ValueError("message requires node features 'x_j'")
        return x_j

    def aggregate(self, inputs: torch.Tensor, index: torch.Tensor, dim_size: Optional[int] = None) -> torch.Tensor:
        """Aggregate messages at target nodes.

        :param inputs: Message tensor
        :param index: Target node indices
        :param dim_size: Number of target nodes
        :returns: Aggregated tensor
        """
        # Imported lazily so torch_scatter is only required when aggregation runs.
        from torch_scatter import scatter

        return scatter(inputs, index, dim=0, dim_size=dim_size, reduce=self.aggr)

    def update(self, inputs: torch.Tensor) -> torch.Tensor:
        """Update node representations after aggregation.

        :param inputs: Aggregated messages
        :returns: Updated node representations
        """
        return inputs


class GraphAttentionConv(MessagePassing):
    """Graph attention convolution used by MolGNet."""

    def __init__(self, hidden: int, heads: int = 3, dropout: float = 0.0) -> None:
        """Initialize the graph attention convolution.

        :param hidden: Hidden dimension size
        :param heads: Number of attention heads
        :param dropout: Dropout probability
        :raises ValueError: If hidden is not divisible by heads
        """
        super().__init__()
        self.hidden = hidden
        self.heads = heads
        if hidden % heads != 0:
            raise ValueError("hidden must be divisible by heads")
        self.query = nn.Linear(hidden, heads * int(hidden / heads))
        self.key = nn.Linear(hidden, heads * int(hidden / heads))
        self.value = nn.Linear(hidden, heads * int(hidden / heads))
        self.attn_drop = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        size: Optional[tuple[int, int]] = None,
    ) -> torch.Tensor:
        """Apply graph attention convolution.

        :param x: Node feature tensor
        :param edge_index: Edge connectivity tensor
        :param edge_attr: Edge attribute tensor
        :param size: Optional size tuple for bipartite graphs
        :returns: Updated node features
        """
        # Ensure edge features are 2-D before mixing them into the messages.
        pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
        return self.propagate(edge_index=edge_index, x=x, pseudo=pseudo)

    def message(
        self,
        edge_index_i: torch.Tensor,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        pseudo: torch.Tensor,
        size_i: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute attention-weighted messages.

        :param edge_index_i: Target node indices
        :param x_i: Target node features
        :param x_j: Source node features
        :param pseudo: Edge features
        :param size_i: Number of target nodes
        :param kwargs: Additional arguments
        :returns: Attention-weighted messages
        """
        # Edge features are added to the source node features before the
        # key/value projections (edge-aware attention).
        query = self.query(x_i).view(-1, self.heads, int(self.hidden / self.heads))
        key = self.key(x_j + pseudo).view(-1, self.heads, int(self.hidden / self.heads))
        value = self.value(x_j + pseudo).view(-1, self.heads, int(self.hidden / self.heads))
        denom = math.sqrt(int(self.hidden / self.heads))
        alpha = (query * key).sum(dim=-1) / denom
        # Softmax normalizes attention scores per target node.
        alpha = softmax(src=alpha, index=edge_index_i, num_nodes=size_i)
        alpha = self.attn_drop(alpha.view(-1, self.heads, 1))
        return alpha * value

    def update(self, aggr_out: torch.Tensor) -> torch.Tensor:
        """Reshape aggregated output.

        :param aggr_out: Aggregated attention output
        :returns: Reshaped tensor
        """
        return aggr_out.view(-1, self.heads * int(self.hidden / self.heads))


class GTLayer(nn.Module):
    """Graph Transformer layer composed from attention and feed-forward blocks."""

    def __init__(self, hidden: int, heads: int, dropout: float, num_message_passing: int) -> None:
        """Initialize the Graph Transformer layer.

        :param hidden: Hidden dimension size
        :param heads: Number of attention heads
        :param dropout: Dropout probability
        :param num_message_passing: Number of message passing iterations
        """
        super().__init__()
        self.attention = GraphAttentionConv(hidden, heads, dropout)
        self.att_out = AttentionOut(hidden, dropout)
        self.intermediate = Intermediate(hidden)
        self.output = GTOut(hidden, dropout)
        self.gru = nn.GRU(hidden, hidden)
        self.LayerNorm = BertLayerNorm(hidden, eps=1e-12)
        self.time_step = num_message_passing

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Apply Graph Transformer layer.

        :param x: Node feature tensor
        :param edge_index: Edge connectivity tensor
        :param edge_attr: Edge attribute tensor
        :returns: Updated node features
        """
        # The GRU hidden state carries node representations across the
        # repeated message-passing steps within this layer.
        h = x.unsqueeze(0)
        for _ in range(self.time_step):
            attention_output = self.attention.forward(x, edge_index, edge_attr)
            attention_output = self.att_out.forward(attention_output, x)
            intermediate_output = self.intermediate.forward(attention_output)
            m = self.output.forward(intermediate_output, attention_output)
            x, h = self.gru(m.unsqueeze(0), h)
            x = self.LayerNorm.forward(x.squeeze(0))
        return x


class MolGNet(torch.nn.Module):
    """MolGNet model implementation used for node embeddings."""

    def __init__(
        self,
        num_layer: int,
        emb_dim: int,
        heads: int,
        num_message_passing: int,
        drop_ratio: float = 0,
    ) -> None:
        """Initialize the MolGNet model.

        :param num_layer: Number of Graph Transformer layers
        :param emb_dim: Embedding dimension
        :param heads: Number of attention heads
        :param num_message_passing: Number of message passing iterations per layer
        :param drop_ratio: Dropout ratio
        """
        super().__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        # NOTE(review): table sizes 178/18 presumably span all atom/bond
        # feature indices from mol_to_graph_data_obj_complex plus self-loop
        # codes (bond category sizes sum to 18) — confirm against checkpoint.
        self.x_embedding = torch.nn.Embedding(178, emb_dim)
        self.x_seg_embed = torch.nn.Embedding(3, emb_dim)
        self.edge_embedding = torch.nn.Embedding(18, emb_dim)
        self.edge_seg_embed = torch.nn.Embedding(3, emb_dim)
        self.reset_parameters()
        self.gnns = torch.nn.ModuleList(
            [GTLayer(emb_dim, heads, drop_ratio, num_message_passing) for _ in range(num_layer)]
        )

    def reset_parameters(self) -> None:
        """Reset model parameters using Xavier initialization."""
        torch.nn.init.xavier_uniform_(self.x_embedding.weight.data)
        torch.nn.init.xavier_uniform_(self.x_seg_embed.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_seg_embed.weight.data)

    def forward(self, *argv: Any) -> torch.Tensor:
        """Forward pass through the MolGNet model.

        :param argv: Either 5 tensors (x, edge_index, edge_attr, node_seg, edge_seg)
            or a single Data object
        :returns: Node embeddings
        :raises ValueError: If incorrect number of arguments provided
        """
        if len(argv) == 5:
            x, edge_index, edge_attr, node_seg, edge_seg = argv
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, node_seg, edge_seg = (
                data.x,
                data.edge_index,
                data.edge_attr,
                data.node_seg,
                data.edge_seg,
            )
        else:
            raise ValueError("unmatched number of arguments.")
        # Sum the categorical feature embeddings per node/edge, then add the
        # segment embeddings before running the GT layers.
        x = self.x_embedding(x).sum(1) + self.x_seg_embed(node_seg)
        edge_attr = self.edge_embedding(edge_attr).sum(1)
        edge_attr = edge_attr + self.edge_seg_embed(edge_seg)
        for gnn in self.gnns:
            x = gnn(x, edge_index, edge_attr)
        return x


class MolGNetFeaturizer(DrugFeaturizer):
    """Featurizer that generates MolGNet node embeddings from SMILES strings.

    MolGNet is a graph neural network that produces per-node embeddings for
    molecules. This featurizer requires a pre-trained MolGNet checkpoint.

    Example usage::

        featurizer = MolGNetFeaturizer(checkpoint_path="data/MolGNet.pt", device="cuda")
        features = featurizer.load_or_generate("data", "GDSC1")
    """

    # Default model hyperparameters
    NUM_LAYER = 5
    EMB_DIM = 768
    HEADS = 12
    MSG_PASS = 3
    DROP = 0.0

    def __init__(self, checkpoint_path: str = "data/MolGNet.pt", device: str = "cpu"):
        """Initialize the MolGNet featurizer.

        :param checkpoint_path: Path to the MolGNet checkpoint file
        :param device: Device to use for computation ('cpu' or 'cuda')
        """
        super().__init__(device=device)
        self.checkpoint_path = checkpoint_path
        # Model is loaded lazily from the checkpoint on first use.
        self._model = None
        self._self_loop = SelfLoop()
        self._add_seg = AddSegId()

    def _load_model(self):
        """Lazily load the MolGNet model.
+ + :raises Exception: If checkpoint loading fails + """ + if self._model is None: + _init_rdkit_features() + + self._model = MolGNet( + num_layer=self.NUM_LAYER, + emb_dim=self.EMB_DIM, + heads=self.HEADS, + num_message_passing=self.MSG_PASS, + drop_ratio=self.DROP, + ) + + device = torch.device(self.device) + ckpt = torch.load(self.checkpoint_path, map_location=device) # noqa: S614 + try: + self._model.load_state_dict(ckpt) + except Exception: + if isinstance(ckpt, dict) and "state_dict" in ckpt: + self._model.load_state_dict(ckpt["state_dict"]) + else: + raise + + self._model = self._model.to(device) + self._model.eval() + + def featurize(self, smiles: str) -> torch.Tensor | None: + """Convert a SMILES string to MolGNet node embeddings. + + :param smiles: SMILES string representing the drug + :returns: Node embeddings tensor or None if conversion fails + :raises RuntimeError: If model is not loaded + """ + _init_rdkit_features() + self._load_model() + + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None + + graph = mol_to_graph_data_obj_complex(mol) + graph = self._self_loop(graph) + graph = self._add_seg(graph) + graph = graph.to(self.device) + + if self._model is None: + raise RuntimeError("Model not loaded. Call _load_model() first.") + + with torch.no_grad(): + embeddings = self._model(graph) + + return embeddings.cpu() + + @classmethod + def get_feature_name(cls) -> str: + """Return the feature view name. + + :returns: 'molgnet_embeddings' + """ + return "molgnet_embeddings" + + @classmethod + def get_output_filename(cls) -> str: + """Return the output filename for cached embeddings. + + :returns: 'MolGNet_dict.pkl' + """ + return "MolGNet_dict.pkl" + + def _save_embeddings(self, embeddings: list, drug_ids: list[str], output_path: Path) -> None: + """Save MolGNet embeddings to disk as a pickle file. 
+ + :param embeddings: List of embedding tensors + :param drug_ids: List of drug identifiers + :param output_path: Path to save the embeddings + """ + molgnet_dict = {} + for drug_id, emb in zip(drug_ids, embeddings, strict=True): + if emb is not None: + molgnet_dict[drug_id] = emb + + with open(output_path, "wb") as f: + pickle.dump(molgnet_dict, f) + + # Also save per-drug CSVs for DIPK compatibility + dataset_dir = output_path.parent + out_drugs_dir = dataset_dir / "DIPK_features" / "Drugs" + os.makedirs(out_drugs_dir, exist_ok=True) + + for drug_id, emb in molgnet_dict.items(): + arr = emb.cpu().detach().numpy() if isinstance(emb, torch.Tensor) else np.array(emb) + df_emb = pd.DataFrame(arr) + out_csv = out_drugs_dir / f"MolGNet_{drug_id}.csv" + df_emb.to_csv(out_csv, sep="\t", index=False) + + def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load pre-generated MolGNet embeddings from disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the embeddings + :raises FileNotFoundError: If the embeddings file is not found + """ + embeddings_file = Path(data_path) / dataset_name / self.get_output_filename() + + if not embeddings_file.exists(): + raise FileNotFoundError( + f"MolGNet embeddings file not found: {embeddings_file}. " + f"Use load_or_generate() to automatically generate embeddings." + ) + + with open(embeddings_file, "rb") as f: + molgnet_dict = pickle.load(f) # noqa: S301 + + feature_name = self.get_feature_name() + features = {} + + for drug_id, emb in molgnet_dict.items(): + features[str(drug_id)] = {feature_name: emb} + + return FeatureDataset(features) + + +class MolGNetMixin: + """Mixin that provides MolGNet drug embeddings loading for DRP models. + + This mixin implements load_drug_features using the MolGNetFeaturizer. + It automatically generates embeddings if they don't exist. 
+ + Class attributes that can be overridden: + - molgnet_checkpoint_path: Path to MolGNet checkpoint (default: 'data/MolGNet.pt') + - molgnet_device: Device for MolGNet model ('cpu', 'cuda', or 'auto') + + Example usage:: + + from drevalpy.models.drp_model import DRPModel + from drevalpy.datasets.featurizer.drug.molgnet import MolGNetMixin + + class MyModel(MolGNetMixin, DRPModel): + drug_views = ["molgnet_embeddings"] + ... + """ + + molgnet_checkpoint_path: str = "data/MolGNet.pt" + molgnet_device: str = "auto" + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load MolGNet drug embeddings. + + Uses the MolGNetFeaturizer to load pre-generated embeddings or generate + them automatically if they don't exist. + + :param data_path: Path to the data directory, e.g., 'data/' + :param dataset_name: Name of the dataset, e.g., 'GDSC1' + :returns: FeatureDataset containing the MolGNet embeddings + """ + device = self.molgnet_device + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + + featurizer = MolGNetFeaturizer(checkpoint_path=self.molgnet_checkpoint_path, device=device) + return featurizer.load_or_generate(data_path, dataset_name) + + +def main(): + """Process drug SMILES and save MolGNet embeddings from command line.""" + parser = argparse.ArgumentParser(description="Generate MolGNet embeddings for drugs.") + parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") + parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") + parser.add_argument( + "--checkpoint", + type=str, + default="data/MolGNet.pt", + help="Path to MolGNet checkpoint (can be obtained from Zenodo: https://doi.org/10.5281/zenodo.12633909)", + ) + parser.add_argument("--device", type=str, default="cpu", help="Torch device (cpu or cuda)") + args = parser.parse_args() + + featurizer = MolGNetFeaturizer(checkpoint_path=args.checkpoint, device=args.device) + 
featurizer.generate_embeddings(args.data_path, args.dataset_name) + + +if __name__ == "__main__": + main() diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 635c1ea3..e2d95f2b 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -49,6 +49,7 @@ def drug_response_experiment( path_data: str = "data", model_checkpoint_dir: str = "TEMPORARY", hyperparameter_tuning=True, + n_trials: int = 20, final_model_on_full_data: bool = False, wandb_project: str | None = None, ) -> None: @@ -99,7 +100,9 @@ def drug_response_experiment( :param overwrite: whether to overwrite existing results :param path_data: path to the data directory, usually data/ :param model_checkpoint_dir: directory to save model checkpoints. If "TEMPORARY", a temporary directory is created. - :param hyperparameter_tuning: whether to run in debug mode - if False, only select first hyperparameter set + :param hyperparameter_tuning: whether to perform hyperparameter tuning. If False, uses the first hyperparameter + configuration from the search space. + :param n_trials: number of Bayesian optimization trials for hyperparameter tuning. Default is 20. :param final_model_on_full_data: if True, a final/production model is saved in the results directory. If hyperparameter_tuning is true, the final model is produced according to the hyperparameter tuning procedure which was evaluated in the nested cross validation. 
@@ -173,8 +176,12 @@ def drug_response_experiment( ) parent_dir = os.path.dirname(predictions_path) - model_hpam_set = model_class.get_hyperparameter_set() - if not hyperparameter_tuning: + if hyperparameter_tuning: + # Use raw search space for Bayesian optimization + model_hpam_set = model_class.get_hyperparameter_search_space() + else: + # Use expanded grid and take first (default) configuration + model_hpam_set = model_class.get_hyperparameter_set() model_hpam_set = [model_hpam_set[0]] if response_data.cv_splits is None: @@ -222,6 +229,7 @@ def drug_response_experiment( "metric": hpam_optimization_metric, "path_data": path_data, "model_checkpoint_dir": model_checkpoint_dir, + "n_trials": n_trials, } # During hyperparameter tuning, create separate wandb runs per trial if enabled @@ -383,6 +391,7 @@ def drug_response_experiment( test_mode=test_mode, val_ratio=0.1, hyperparameter_tuning=hyperparameter_tuning, + n_trials=n_trials, ) consolidate_single_drug_model_predictions( @@ -1148,53 +1157,173 @@ def train_and_evaluate( return results +def _deep_equal(a: Any, b: Any) -> bool: + """ + Compare two values for equality, handling nested structures. + + :param a: first value + :param b: second value + :returns: True if values are equal (including nested structures) + """ + if isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + return False + return all(_deep_equal(ai, bi) for ai, bi in zip(a, b, strict=True)) + elif isinstance(a, dict) and isinstance(b, dict): + if set(a.keys()) != set(b.keys()): + return False + return all(_deep_equal(a[k], b[k]) for k in a.keys()) + else: + return a == b + + +def _sample_hyperparameters_from_search_space(trial, search_space: dict[str, Any]) -> dict[str, Any]: + """ + Sample hyperparameters from a search space definition using Optuna. 
+ + :param trial: Optuna trial object + :param search_space: dictionary mapping parameter names to their search space definitions + :returns: dictionary of sampled hyperparameters + :raises ValueError: if an unknown parameter type is encountered in the search space + """ + sampled = {} + for param_name, param_def in search_space.items(): + if isinstance(param_def, dict) and "type" in param_def: + # Structured search space definition for continuous ranges + if "default" not in param_def: + raise ValueError( + f"Hyperparameter '{param_name}' has continuous range definition " + f"but missing required 'default' field. " + f"Please add a 'default' value to use when hyperparameter_tuning=False." + ) + param_type = param_def["type"] + low = param_def["low"] + high = param_def["high"] + log_scale = param_def.get("log", False) + + if param_type == "int": + sampled[param_name] = trial.suggest_int(param_name, low, high, log=log_scale) + elif param_type == "float": + if log_scale: + sampled[param_name] = trial.suggest_float(param_name, low, high, log=True) + else: + sampled[param_name] = trial.suggest_float(param_name, low, high) + else: + raise ValueError(f"Unknown parameter type: {param_type}") + elif isinstance(param_def, list): + # Categorical choices + if len(param_def) == 1: + # Single value, no tuning needed + sampled[param_name] = param_def[0] + else: + sampled[param_name] = trial.suggest_categorical(param_name, param_def) + else: + # Single fixed value (not a list or dict) + sampled[param_name] = param_def + + return sampled + + def hpam_tune( model: DRPModel, train_dataset: DrugResponseDataset, validation_dataset: DrugResponseDataset, - hpam_set: list[dict], + hpam_set: list[dict] | dict[str, Any], early_stopping_dataset: DrugResponseDataset | None = None, response_transformation: TransformerMixin | None = None, metric: str = "RMSE", path_data: str = "data", model_checkpoint_dir: str = "TEMPORARY", + n_trials: int = 20, *, split_index: int | None = None, 
wandb_project: str | None = None, wandb_base_config: dict[str, Any] | None = None, ) -> dict: """ - Tune the hyperparameters for the given model in an iterative manner. + Tune hyperparameters using Bayesian optimization with Optuna. + + This function uses Optuna's TPE (Tree-structured Parzen Estimator) sampler + for efficient hyperparameter search. Trials are run sequentially. :param model: model to use :param train_dataset: training dataset :param validation_dataset: validation dataset - :param hpam_set: hyperparameters to tune + :param hpam_set: either a search space dictionary (for Bayesian optimization) or + a list of hyperparameter configurations (legacy grid search format) :param early_stopping_dataset: early stopping dataset :param response_transformation: normalizer to use for the response data :param metric: metric to evaluate which model is the best :param path_data: path to the data directory, e.g., data/ :param model_checkpoint_dir: directory to save model checkpoints + :param n_trials: number of Bayesian optimization trials to run :param split_index: optional CV split index, used for naming wandb runs :param wandb_project: optional wandb project name; if provided, enables per-trial wandb runs :param wandb_base_config: optional base config dict to include in each wandb run :returns: best hyperparameters :raises AssertionError: if hpam_set is empty """ - if len(hpam_set) == 0: - raise AssertionError("hpam_set must contain at least one hyperparameter configuration") - if len(hpam_set) == 1: - return hpam_set[0] + import optuna + from optuna.samplers import TPESampler + + # Handle legacy list format (grid search) - convert to search space + if isinstance(hpam_set, list): + if len(hpam_set) == 0: + raise AssertionError("hpam_set must contain at least one hyperparameter configuration") + if len(hpam_set) == 1: + return hpam_set[0] + + # Convert list of dicts to search space by extracting unique values per parameter + # Handle nested structures (like lists of 
lists) by using a list-based approach + search_space: dict[str, Any] = {} + all_keys: set[str] = set() + for config in hpam_set: + all_keys.update(config.keys()) + + for key in all_keys: + # Collect all values for this key, preserving order and handling unhashable types + values: list[Any] = [] + seen: list[Any] = [] + for config in hpam_set: + if key in config: + value = config[key] # Use direct access since we know key exists + # For unhashable types (lists, dicts), use deep comparison + if isinstance(value, (list, dict)): + # Check if we've seen an equivalent value + if not any(_deep_equal(value, v) for v in seen): + values.append(value) + seen.append(value) + else: + # For hashable types, use set for deduplication + if value not in values: + values.append(value) + if len(values) == 1: + search_space[key] = values[0] + else: + search_space[key] = values + else: + search_space = hpam_set + + # Check if there's anything to tune + tunable_params = [ + k for k, v in search_space.items() if isinstance(v, (list, dict)) and (not isinstance(v, list) or len(v) > 1) + ] + if not tunable_params: + # No tuning needed, return fixed values + return {k: (v[0] if isinstance(v, list) else v) for k, v in search_space.items()} # Mark that we're in hyperparameter tuning phase - # This prevents updating wandb.config during tuning - we'll only log final best hyperparameters model._in_hyperparameter_tuning = True - best_hyperparameters = None mode = get_mode(metric) - best_score = float("inf") if mode == "min" else float("-inf") - for trial_idx, hyperparameter in enumerate(hpam_set): - print(f"Training model with hyperparameters: {hyperparameter}") + direction = "minimize" if mode == "min" else "maximize" + + def objective(trial): + # Sample hyperparameters + hyperparameter = _sample_hyperparameters_from_search_space(trial, search_space) + trial_idx = trial.number + + print(f"Trial {trial_idx}: Training model with hyperparameters: {hyperparameter}") # Create a separate wandb run for 
each hyperparameter trial if enabled if wandb_project is not None: @@ -1222,40 +1351,52 @@ def hpam_tune( finish_previous=True, ) - # During hyperparameter tuning, don't update wandb config via log_hyperparameters - # Trial hyperparameters are stored in wandb.config for each run - score = train_and_evaluate( - model=model, - hpams=hyperparameter, - path_data=path_data, - train_dataset=train_dataset, - validation_dataset=validation_dataset, - early_stopping_dataset=early_stopping_dataset, - metric=metric, - response_transformation=response_transformation, - model_checkpoint_dir=model_checkpoint_dir, - )[metric] + try: + score = train_and_evaluate( + model=model, + hpams=hyperparameter, + path_data=path_data, + train_dataset=train_dataset, + validation_dataset=validation_dataset, + early_stopping_dataset=early_stopping_dataset, + metric=metric, + response_transformation=response_transformation, + model_checkpoint_dir=model_checkpoint_dir, + )[metric] + + if np.isnan(score): + # Return a bad score for NaN results + score = float("inf") if mode == "min" else float("-inf") + else: + print(f"Trial {trial_idx}: {metric} = {np.round(score, 4)}") - # Note: train_and_evaluate() already logs val_* metrics once via - # DRPModel.compute_and_log_final_metrics(..., prefix="val_"). - # Avoid logging val_{metric} again here (it would create duplicate points). 
- if np.isnan(score): + except Exception as e: + print(f"Trial {trial_idx} failed: {e}") + score = float("inf") if mode == "min" else float("-inf") + + finally: if model.is_wandb_enabled(): model.finish_wandb() - continue - if (mode == "min" and score < best_score) or (mode == "max" and score > best_score): - print(f"current best {metric} score: {np.round(score, 3)}") - best_score = score - best_hyperparameters = hyperparameter + return score + + # Create and run the Optuna study + study = optuna.create_study(direction=direction, sampler=TPESampler(seed=42)) + study.optimize(objective, n_trials=n_trials, show_progress_bar=True) + + # Get best hyperparameters + best_hyperparameters = study.best_params - # Close this trial's run after all logging is done - if model.is_wandb_enabled(): - model.finish_wandb() + # Fill in fixed parameters that weren't tuned + for key, value in search_space.items(): + if key not in best_hyperparameters: + if isinstance(value, list) and len(value) == 1: + best_hyperparameters[key] = value[0] + elif not isinstance(value, (list, dict)): + best_hyperparameters[key] = value - if best_hyperparameters is None: - warnings.warn("all hpams lead to NaN respone. using last hpam combination.", stacklevel=2) - best_hyperparameters = hyperparameter + print(f"\nBest trial: {study.best_trial.number}") + print(f"Best {metric}: {np.round(study.best_value, 4)}") return best_hyperparameters @@ -1265,50 +1406,137 @@ def hpam_tune_raytune( train_dataset: DrugResponseDataset, validation_dataset: DrugResponseDataset, early_stopping_dataset: DrugResponseDataset | None, - hpam_set: list[dict], + hpam_set: list[dict] | dict[str, Any], response_transformation: TransformerMixin | None = None, metric: str = "RMSE", ray_path: str = "raytune", path_data: str = "data", model_checkpoint_dir: str = "TEMPORARY", + n_trials: int = 20, ) -> dict: """ - Tune the hyperparameters for the given model using Ray Tune. Ray[tune] must be installed. 
+ Tune hyperparameters using Bayesian optimization with Ray Tune and Optuna. + + This function uses Ray Tune with OptunaSearch for parallel Bayesian optimization. + Ray[tune] and optuna must be installed. :param model: model to use :param train_dataset: training dataset :param validation_dataset: validation dataset :param early_stopping_dataset: early stopping dataset - :param hpam_set: hyperparameters to tune + :param hpam_set: either a search space dictionary (for Bayesian optimization) or + a list of hyperparameter configurations (legacy grid search format) :param response_transformation: normalizer for response data :param metric: evaluation metric :param ray_path: path to the raytune directory :param path_data: path to data directory, e.g., data/ :param model_checkpoint_dir: directory for model checkpoints + :param n_trials: number of Bayesian optimization trials to run :returns: best hyperparameters + :raises AssertionError: if hpam_set is empty :raises ValueError: if best_result is None """ - print("Starting hyperparameter tuning with Ray Tune ...") - print(f"Hyperparameter combinations to evaluate: {len(hpam_set)}") - print() - - if len(hpam_set) == 1: - return hpam_set[0] - import ray from ray import tune + from ray.tune.search.optuna import OptunaSearch + + print("Starting hyperparameter tuning with Ray Tune (Bayesian optimization) ...") + + # Handle legacy list format (grid search) - convert to search space + if isinstance(hpam_set, list): + if len(hpam_set) == 0: + raise AssertionError("hpam_set must contain at least one hyperparameter configuration") + if len(hpam_set) == 1: + return hpam_set[0] + + # Convert list of dicts to search space + search_space: dict[str, Any] = {} + all_keys: set[str] = set() + for config in hpam_set: + all_keys.update(config.keys()) + + for key in all_keys: + # Collect all values for this key, preserving order and handling unhashable types + values: list[Any] = [] + seen: list[Any] = [] + for config in hpam_set: + if key in 
config:
+                    value = config[key]  # Use direct access since we know key exists
+                    # For unhashable types (lists, dicts), use deep comparison
+                    if isinstance(value, (list, dict)):
+                        # Check if we've seen an equivalent value
+                        if not any(_deep_equal(value, v) for v in seen):
+                            values.append(value)
+                            seen.append(value)
+                    else:
+                        # For hashable types, use list membership for deduplication
+                        if value not in values:
+                            values.append(value)
+            if len(values) == 1:
+                search_space[key] = values[0]
+            else:
+                search_space[key] = values
+    else:
+        search_space = hpam_set
+
+    # Check if there's anything to tune
+    tunable_params = [
+        k for k, v in search_space.items() if isinstance(v, (list, dict)) and (not isinstance(v, list) or len(v) > 1)
+    ]
+    if not tunable_params:
+        return {k: (v[0] if isinstance(v, list) else v) for k, v in search_space.items()}
+
+    # Convert search space to Ray Tune format
+    ray_search_space = {}
+    fixed_params = {}
+    for param_name, param_def in search_space.items():
+        if isinstance(param_def, dict) and "type" in param_def:
+            if "default" not in param_def:
+                raise ValueError(
+                    f"Hyperparameter '{param_name}' has continuous range definition "
+                    f"but missing required 'default' field. "
+                    f"Please add a 'default' value to use when hyperparameter_tuning=False."
+ ) + param_type = param_def["type"] + low = param_def["low"] + high = param_def["high"] + log_scale = param_def.get("log", False) + + if param_type == "int": + if log_scale: + ray_search_space[param_name] = tune.lograndint(low, high) + else: + ray_search_space[param_name] = tune.randint(low, high + 1) + elif param_type == "float": + if log_scale: + ray_search_space[param_name] = tune.loguniform(low, high) + else: + ray_search_space[param_name] = tune.uniform(low, high) + elif isinstance(param_def, list): + if len(param_def) == 1: + fixed_params[param_name] = param_def[0] + else: + ray_search_space[param_name] = tune.choice(param_def) + else: + fixed_params[param_name] = param_def + + print(f"Tunable parameters: {list(ray_search_space.keys())}") + print(f"Fixed parameters: {list(fixed_params.keys())}") + print(f"Number of trials: {n_trials}") + print() path_data = os.path.abspath(path_data) if not ray.is_initialized(): ray.init(_temp_dir=os.path.join(os.path.expanduser("~"), "raytmp")) resources_per_trial = {"gpu": 1} if torch.cuda.is_available() else {"cpu": 1} - def trainable(hpams): + def trainable(config): try: - inner = hpams["hpams"] + # Merge sampled params with fixed params + hyperparameter = {**fixed_params, **config} result = train_and_evaluate( model=model, - hpams=inner, + hpams=hyperparameter, path_data=path_data, train_dataset=train_dataset, validation_dataset=validation_dataset, @@ -1317,35 +1545,50 @@ def trainable(hpams): response_transformation=response_transformation, model_checkpoint_dir=model_checkpoint_dir, ) - tune.report(metrics={metric: result[metric]}) + return {metric: result[metric]} except Exception as e: import traceback print("Trial failed:", e) traceback.print_exc() + # Return bad score on failure + mode = get_mode(metric) + return {metric: float("inf") if mode == "min" else float("-inf")} trainable = tune.with_resources(trainable, resources_per_trial) - param_space = {"hpams": tune.grid_search(hpam_set)} + + mode = get_mode(metric) 
+ optuna_search = OptunaSearch(metric=metric, mode=mode, seed=42) tuner = tune.Tuner( trainable, - param_space=param_space, + param_space=ray_search_space, run_config=tune.RunConfig( storage_path=ray_path, name="hpam_tuning", ), tune_config=tune.TuneConfig( metric=metric, - mode=get_mode(metric), + mode=mode, + search_alg=optuna_search, + num_samples=n_trials, + max_concurrent_trials=1, # Run one at a time for Bayesian optimization ), ) results = tuner.fit() - best_result = results.get_best_result(metric=metric, mode=get_mode(metric)) + best_result = results.get_best_result(metric=metric, mode=mode) ray.shutdown() + if best_result.config is None: raise ValueError("Ray failed; no best result.") - return best_result.config["hpams"] + + # Merge best config with fixed params + best_hyperparameters = {**fixed_params, **best_result.config} + + print(f"\nBest {metric}: {np.round(best_result.metrics[metric], 4)}") + + return best_hyperparameters @pipeline_function @@ -1470,6 +1713,7 @@ def train_final_model( test_mode: str = "LCO", val_ratio: float = 0.1, hyperparameter_tuning: bool = True, + n_trials: int = 20, ) -> None: """ Final Production Model Training. 
@@ -1493,6 +1737,7 @@ def train_final_model( :param test_mode: split logic for validation (LCO, LDO, LTO, LPO) :param val_ratio: validation size ratio :param hyperparameter_tuning: whether to perform hyperparameter tuning + :param n_trials: number of Bayesian optimization trials for hyperparameter tuning """ print("Training final model with application-specific validation strategy ...") @@ -1512,8 +1757,9 @@ def train_final_model( else: early_stopping_dataset = None - hpam_set = model.get_hyperparameter_set() if hyperparameter_tuning: + # Use raw search space for Bayesian optimization + hpam_set = model.get_hyperparameter_search_space() best_hpams = hpam_tune( model=model, train_dataset=train_dataset, @@ -1524,8 +1770,11 @@ def train_final_model( metric=metric, path_data=path_data, model_checkpoint_dir=model_checkpoint_dir, + n_trials=n_trials, ) else: + # Use expanded grid and take first (default) configuration + hpam_set = model.get_hyperparameter_set() best_hpams = hpam_set[0] print(f"Best hyperparameters for final model: {best_hpams}") diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index 89f0fb7d..ff8f3bb3 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -234,14 +234,6 @@ def __init__(self): self.model: DrugGNNModule | None = None self.hyperparameters = {} - @classmethod - def get_model_name(cls) -> str: - """Return the name of the model. - - :return: The name of the model. - """ - return "DrugGNN" - @property def cell_line_views(self) -> list[str]: """Return the sources the model needs as input for describing the cell line. @@ -441,7 +433,7 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase if not graph_path.exists(): raise FileNotFoundError( f"Drug graph directory not found at {graph_path}. " - f"Please run 'create_drug_graphs.py' for the {dataset_name} dataset." 
+ f"Please use DrugGraphFeaturizer to generate graphs for the {dataset_name} dataset." ) drug_graphs = {} diff --git a/drevalpy/models/MOLIR/molir.py b/drevalpy/models/MOLIR/molir.py index 4aaab346..4c456c59 100644 --- a/drevalpy/models/MOLIR/molir.py +++ b/drevalpy/models/MOLIR/molir.py @@ -48,15 +48,6 @@ def __init__(self) -> None: self.gene_expression_scaler = StandardScaler() self.selector: VarianceFeatureSelector | None = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: MOLIR - """ - return "MOLIR" - def build_model(self, hyperparameters: dict[str, Any]) -> None: """ Builds the model from hyperparameters. diff --git a/drevalpy/models/SRMF/srmf.py b/drevalpy/models/SRMF/srmf.py index 8e1aaeb0..1d1a5eda 100644 --- a/drevalpy/models/SRMF/srmf.py +++ b/drevalpy/models/SRMF/srmf.py @@ -52,15 +52,6 @@ def __init__(self) -> None: self.max_iter: int = 50 self.seed: int = 1 - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SRMF - """ - return "SRMF" - def build_model(self, hyperparameters: dict) -> None: """ Initializes hyperparameters for SRMF model. 
diff --git a/drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml b/drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml index 721991f2..47440d56 100644 --- a/drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml +++ b/drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml @@ -62,3 +62,22 @@ ChemBERTaNeuralNetwork: - 16 max_epochs: - 100 + +PCANeuralNetwork: + dropout_prob: + - 0.3 + units_per_layer: + - - 32 + - 16 + - 8 + - 4 + - - 128 + - 64 + - 32 + - - 64 + - 64 + - 32 + n_components: + - 100 + max_epochs: + - 100 diff --git a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py index 39b778a1..17e95f14 100644 --- a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py @@ -42,15 +42,6 @@ def __init__(self): self.pca_ncomp = 100 self.gene_expression_scaler = StandardScaler() - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: MultiOmicsNeuralNetwork - """ - return "MultiOmicsNeuralNetwork" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. 
diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 077e69ff..4df87970 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -1,4 +1,4 @@ -"""Contains the SimpleNeuralNetwork and the ChemBERTaNeuralNetwork model.""" +"""Contains the SimpleNeuralNetwork, ChemBERTaNeuralNetwork, and PCANeuralNetwork models.""" import json import os @@ -7,11 +7,11 @@ import joblib import numpy as np -import pandas as pd import torch from sklearn.preprocessing import StandardScaler from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset +from drevalpy.datasets.featurizer import ChemBERTaMixin, PCAMixin from ..drp_model import DRPModel from ..utils import load_and_select_gene_features, load_drug_fingerprint_features, scale_gene_expression @@ -35,15 +35,6 @@ def __init__(self): self.model = None self.gene_expression_scaler = StandardScaler() - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SimpleNeuralNetwork - """ - return "SimpleNeuralNetwork" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. 
@@ -90,7 +81,7 @@ def train( gene_expression_scaler=self.gene_expression_scaler, ) - dim_gex = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] + dim_gex = next(iter(cell_line_input.features.values()))[self.cell_line_views[0]].shape[0] dim_fingerprint = next(iter(drug_input.features.values()))[self.drug_views[0]].shape[0] self.hyperparameters["input_dim_gex"] = dim_gex self.hyperparameters["input_dim_fp"] = dim_fingerprint @@ -159,7 +150,7 @@ def predict( ) x = self.get_concatenated_features( - cell_line_view="gene_expression", + cell_line_view=self.cell_line_views[0], drug_view=self.drug_views[0], cell_line_ids_output=cell_line_ids, drug_ids_output=drug_ids, @@ -254,41 +245,13 @@ def load(cls, directory: str) -> "SimpleNeuralNetwork": return instance -class ChemBERTaNeuralNetwork(SimpleNeuralNetwork): +class ChemBERTaNeuralNetwork(ChemBERTaMixin, SimpleNeuralNetwork): """ChemBERTa Neural Network model using gene expression and ChemBERTa drug embeddings.""" drug_views = ["chemberta_embeddings"] - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: ChemBERTaNeuralNetwork - """ - return "ChemBERTaNeuralNetwork" - - def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - """ - Loads the ChemBERTa embeddings. - - :param data_path: Path to the ChemBERTa embeddings, e.g., data/ - :param dataset_name: name of the dataset, e.g., GDSC1 - :returns: FeatureDataset containing the ChemBERTa embeddings - :raises FileNotFoundError: if the ChemBERTa embeddings file is not found - """ - chemberta_file = os.path.join(data_path, dataset_name, "drug_chemberta_embeddings.csv") - if not os.path.exists(chemberta_file): - raise FileNotFoundError( - f"ChemBERTa embeddings file not found: {chemberta_file}. " - "Please create it first with the respective drug_featurizer." 
- ) - chemberta_df = pd.read_csv(chemberta_file, dtype={"pubchem_id": str}) - features = {} - for _, row in chemberta_df.iterrows(): - drug_id = row["pubchem_id"] - embedding = row.drop("pubchem_id").to_numpy(dtype=np.float32) - features[drug_id] = {"chemberta_embeddings": embedding} +class PCANeuralNetwork(PCAMixin, SimpleNeuralNetwork): + """Neural Network model using PCA-transformed gene expression and fingerprints.""" - return FeatureDataset(features) + cell_line_views = ["gene_expression_pca"] diff --git a/drevalpy/models/SuperFELTR/superfeltr.py b/drevalpy/models/SuperFELTR/superfeltr.py index 6334dfc5..a4058157 100644 --- a/drevalpy/models/SuperFELTR/superfeltr.py +++ b/drevalpy/models/SuperFELTR/superfeltr.py @@ -61,15 +61,6 @@ def __init__(self) -> None: self.copy_number_variation_features = None self.selectors: dict[str, VarianceFeatureSelector] = {} - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SuperFELTR - """ - return "SuperFELTR" - def build_model(self, hyperparameters) -> None: """ Builds the model from hyperparameters. 
diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py index 5ecf2e4f..f8b934b8 100644 --- a/drevalpy/models/__init__.py +++ b/drevalpy/models/__init__.py @@ -30,6 +30,7 @@ "DrugGNN", "ChemBERTaNeuralNetwork", "PharmaFormerModel", + "PCANeuralNetwork", ] from .baselines.multi_omics_random_forest import MultiOmicsRandomForest @@ -57,7 +58,7 @@ from .MOLIR.molir import MOLIR from .PharmaFormer.pharmaformer import PharmaFormerModel from .SimpleNeuralNetwork.multiomics_neural_network import MultiOmicsNeuralNetwork -from .SimpleNeuralNetwork.simple_neural_network import ChemBERTaNeuralNetwork, SimpleNeuralNetwork +from .SimpleNeuralNetwork.simple_neural_network import ChemBERTaNeuralNetwork, PCANeuralNetwork, SimpleNeuralNetwork from .SRMF.srmf import SRMF from .SuperFELTR.superfeltr import SuperFELTR @@ -93,6 +94,7 @@ "DrugGNN": DrugGNN, "ChemBERTaNeuralNetwork": ChemBERTaNeuralNetwork, "PharmaFormer": PharmaFormerModel, + "PCANeuralNetwork": PCANeuralNetwork, } # MODEL_FACTORY is used in the pipeline! diff --git a/drevalpy/models/baselines/multi_omics_random_forest.py b/drevalpy/models/baselines/multi_omics_random_forest.py index 005bb14e..afe3e5c1 100644 --- a/drevalpy/models/baselines/multi_omics_random_forest.py +++ b/drevalpy/models/baselines/multi_omics_random_forest.py @@ -32,15 +32,6 @@ def __init__(self): self.pca = None self.pca_ncomp = 100 - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: MultiOmicsRandomForest - """ - return "MultiOmicsRandomForest" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. diff --git a/drevalpy/models/baselines/naive_pred.py b/drevalpy/models/baselines/naive_pred.py index db1391c6..759bd499 100644 --- a/drevalpy/models/baselines/naive_pred.py +++ b/drevalpy/models/baselines/naive_pred.py @@ -109,15 +109,6 @@ def __init__(self): """ super().__init__() - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. 
- - :returns: NaivePredictor - """ - return "NaivePredictor" - def train( self, output: DrugResponseDataset, @@ -191,15 +182,6 @@ def __init__(self): super().__init__() self.drug_means = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: NaiveDrugMeanPredictor - """ - return "NaiveDrugMeanPredictor" - def train( self, output: DrugResponseDataset, @@ -299,15 +281,6 @@ def __init__(self): super().__init__() self.cell_line_means = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: NaiveCellLineMeanPredictor - """ - return "NaiveCellLineMeanPredictor" - def train( self, output: DrugResponseDataset, @@ -406,15 +379,6 @@ def __init__(self): super().__init__() self.tissue_means = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: NaiveTissueMeanPredictor - """ - return "NaiveTissueMeanPredictor" - def train( self, output: DrugResponseDataset, @@ -518,15 +482,6 @@ def __init__(self): self.cell_line_effects = {} self.drug_effects = {} - @classmethod - def get_model_name(cls) -> str: - """ - Returns the name of the model. - - :return: The name of the model as a string. - """ - return "NaiveMeanEffectsPredictor" - def train( self, output: DrugResponseDataset, @@ -643,15 +598,6 @@ def __init__(self): super().__init__() self.tissue_drug_means = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: NaiveTissueDrugMeanPredictor - """ - return "NaiveTissueDrugMeanPredictor" - def save(self, directory: str) -> None: """ Saves the model parameters to the given directory. 
diff --git a/drevalpy/models/baselines/singledrug_elastic_net.py b/drevalpy/models/baselines/singledrug_elastic_net.py index 13b27379..957b39fb 100644 --- a/drevalpy/models/baselines/singledrug_elastic_net.py +++ b/drevalpy/models/baselines/singledrug_elastic_net.py @@ -35,15 +35,6 @@ def build_model(self, hyperparameters): self.model = ElasticNet(**hyperparameters) self.gene_expression_scaler = StandardScaler() - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SingleDrugElasticNet - """ - return "SingleDrugElasticNet" - def train( self, output: DrugResponseDataset, @@ -173,15 +164,6 @@ def build_model(self, hyperparameters: dict): ) super().build_model(hyperparameters) - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SingleDrugProteomicsElasticNet - """ - return "SingleDrugProteomicsElasticNet" - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ Loads the proteomics data. diff --git a/drevalpy/models/baselines/singledrug_random_forest.py b/drevalpy/models/baselines/singledrug_random_forest.py index c09567d7..6749a754 100644 --- a/drevalpy/models/baselines/singledrug_random_forest.py +++ b/drevalpy/models/baselines/singledrug_random_forest.py @@ -19,15 +19,6 @@ class SingleDrugRandomForest(RandomForest): drug_views = [] early_stopping = False - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SingleDrugRandomForest - """ - return "SingleDrugRandomForest" - def train( self, output: DrugResponseDataset, @@ -148,15 +139,6 @@ def build_model(self, hyperparameters: dict): normalization_width=self.normalization_width, ) - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. 
- - :returns: SingleDrugProteomicsRandomForest - """ - return "SingleDrugProteomicsRandomForest" - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ Loads the proteomics features. diff --git a/drevalpy/models/baselines/sklearn_models.py b/drevalpy/models/baselines/sklearn_models.py index 425727b3..b70a6b12 100644 --- a/drevalpy/models/baselines/sklearn_models.py +++ b/drevalpy/models/baselines/sklearn_models.py @@ -39,15 +39,6 @@ def __init__(self): self.gene_expression_scaler = StandardScaler() self.proteomics_transformer = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :raises NotImplementedError: If the method is not implemented in the child class. - """ - raise NotImplementedError("get_model_name method has to be implemented in the child class.") - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. @@ -251,15 +242,6 @@ def build_model(self, hyperparameters: dict): class RandomForest(SklearnModel): """RandomForest model for drug response prediction.""" - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: RandomForest - """ - return "RandomForest" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. @@ -308,15 +290,6 @@ def build_model(self, hyperparameters: dict): class GradientBoosting(SklearnModel): """Gradient Boosting model for drug response prediction.""" - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: GradientBoosting - """ - return "GradientBoosting" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. @@ -375,15 +348,6 @@ def build_model(self, hyperparameters: dict): normalization_width=self.normalization_width, ) - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. 
- - :returns: ProteomicsRandomForest - """ - return "ProteomicsRandomForest" - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ Loads the cell line features. diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 7599be2e..b3736e2b 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -251,14 +251,14 @@ def finish_wandb(self) -> None: self.wandb_run = None @classmethod - @abstractmethod @pipeline_function def get_model_name(cls) -> str: """ Returns the name of the model. - :return: model name + :return: model name (the class name) """ + return cls.__name__ @classmethod @pipeline_function @@ -282,13 +282,81 @@ def get_hyperparameter_set(cls) -> list[dict[str, Any]]: if hpams is None: return [{}] - # each param should be a list + # Convert continuous ranges to default values for grid expansion + # This handles the case when hyperparameter_tuning=False + processed_hpams: dict[str, Any] = {} for hp in hpams: - if not isinstance(hpams[hp], list): - hpams[hp] = [hpams[hp]] - grid = list(ParameterGrid(hpams)) + value = hpams[hp] + # If it's a continuous range definition, require and use the default value + if isinstance(value, dict) and "type" in value: + if "default" not in value: + raise ValueError( + f"Hyperparameter '{hp}' has continuous range definition but missing required 'default' field. " + f"Please add a 'default' value to use when hyperparameter_tuning=False." 
+ ) + # Validate default is within range + low = value["low"] + high = value["high"] + default = value["default"] + param_type = value["type"] + + if param_type == "int": + if not isinstance(default, int): + raise ValueError( + f"Hyperparameter '{hp}': default must be an integer, got {type(default).__name__}" + ) + if default < low or default > high: + raise ValueError( + f"Hyperparameter '{hp}': default value {default} is outside range [{low}, {high}]" + ) + elif param_type == "float": + if not isinstance(default, (int, float)): + raise ValueError( + f"Hyperparameter '{hp}': default must be a float, got {type(default).__name__}" + ) + if default < low or default > high: + raise ValueError( + f"Hyperparameter '{hp}': default value {default} is outside range [{low}, {high}]" + ) + + processed_hpams[hp] = [default] + elif isinstance(value, list): + processed_hpams[hp] = value + else: + # Single value + processed_hpams[hp] = [value] + grid = list(ParameterGrid(processed_hpams)) return grid + @classmethod + @pipeline_function + def get_hyperparameter_search_space(cls) -> dict[str, Any]: + """ + Load the raw hyperparameter search space from a YAML file. + + This method returns the search space definition without expanding it into + all combinations. Useful for Bayesian optimization where we sample from + the space rather than enumerating all combinations. 
+ + :returns: dictionary mapping parameter names to their search space definitions + :raises ValueError: if the hyperparameters are not in the correct format + :raises KeyError: if the model is not found in the hyperparameters file + """ + hyperparameter_file = os.path.join(os.path.dirname(inspect.getfile(cls)), "hyperparameters.yaml") + + with open(hyperparameter_file, encoding="utf-8") as f: + try: + hpams = yaml.safe_load(f)[cls.get_model_name()] + except yaml.YAMLError as exc: + raise ValueError(f"Error in hyperparameters.yaml: {exc}") from exc + except KeyError as key_exc: + raise KeyError(f"Model {cls.get_model_name()} not found in hyperparameters.yaml") from key_exc + + if hpams is None: + return {} + + return hpams + @property @abstractmethod def cell_line_views(self) -> list[str]: diff --git a/poetry.lock b/poetry.lock index 6e8e7a5e..ec9d6845 100644 --- a/poetry.lock +++ b/poetry.lock @@ -182,6 +182,26 @@ files = [ {file = "alabaster-1.0.0.tar.gz", hash = "sha256:c00dca57bca26fa62a6d7d0a9fcce65f3e026e9bfe33e9c538fd3fbb2144fd9e"}, ] +[[package]] +name = "alembic" +version = "1.18.1" +description = "A database migration tool for SQLAlchemy." +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "alembic-1.18.1-py3-none-any.whl", hash = "sha256:f1c3b0920b87134e851c25f1f7f236d8a332c34b75416802d06971df5d1b7810"}, + {file = "alembic-1.18.1.tar.gz", hash = "sha256:83ac6b81359596816fb3b893099841a0862f2117b2963258e965d70dc62fb866"}, +] + +[package.dependencies] +Mako = "*" +SQLAlchemy = ">=1.4.0" +typing-extensions = ">=4.12" + +[package.extras] +tz = ["tzdata"] + [[package]] name = "annotated-types" version = "0.7.0" @@ -710,7 +730,7 @@ version = "6.10.1" description = "Add colours to the output of Python's logging module." 
optional = false python-versions = ">=3.6" -groups = ["development"] +groups = ["main", "development"] files = [ {file = "colorlog-6.10.1-py3-none-any.whl", hash = "sha256:2d7e8348291948af66122cff006c9f8da6255d224e7cf8e37d8de2df3bad8c9c"}, {file = "colorlog-6.10.1.tar.gz", hash = "sha256:eb4ae5cb65fe7fec7773c2306061a8e63e02efc2c72eba9d27b0fa23c94f1321"}, @@ -1533,6 +1553,69 @@ gitdb = ">=4.0.1,<5" doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"] test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy (==1.18.2) ; python_version >= \"3.9\"", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions ; python_version < \"3.11\""] +[[package]] +name = "greenlet" +version = "3.3.0" +description = "Lightweight in-process concurrent programming" +optional = false +python-versions = ">=3.10" +groups = ["main"] +markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\"" +files = [ + {file = "greenlet-3.3.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:6f8496d434d5cb2dce025773ba5597f71f5410ae499d5dd9533e0653258cdb3d"}, + {file = "greenlet-3.3.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b96dc7eef78fd404e022e165ec55327f935b9b52ff355b067eb4a0267fc1cffb"}, + {file = "greenlet-3.3.0-cp310-cp310-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:73631cd5cccbcfe63e3f9492aaa664d278fda0ce5c3d43aeda8e77317e38efbd"}, + {file = "greenlet-3.3.0-cp310-cp310-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b299a0cb979f5d7197442dccc3aee67fce53500cd88951b7e6c35575701c980b"}, + {file = "greenlet-3.3.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:7dee147740789a4632cace364816046e43310b59ff8fb79833ab043aefa72fd5"}, + {file = "greenlet-3.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:39b28e339fc3c348427560494e28d8a6f3561c8d2bcf7d706e1c624ed8d822b9"}, + {file = "greenlet-3.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b3c374782c2935cc63b2a27ba8708471de4ad1abaa862ffdb1ef45a643ddbb7d"}, + {file = "greenlet-3.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:b49e7ed51876b459bd645d83db257f0180e345d3f768a35a85437a24d5a49082"}, + {file = "greenlet-3.3.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e29f3018580e8412d6aaf5641bb7745d38c85228dacf51a73bd4e26ddf2a6a8e"}, + {file = "greenlet-3.3.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a687205fb22794e838f947e2194c0566d3812966b41c78709554aa883183fb62"}, + {file = "greenlet-3.3.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4243050a88ba61842186cb9e63c7dfa677ec146160b0efd73b855a3d9c7fcf32"}, + {file = "greenlet-3.3.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:670d0f94cd302d81796e37299bcd04b95d62403883b24225c6b5271466612f45"}, + {file = "greenlet-3.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cb3a8ec3db4a3b0eb8a3c25436c2d49e3505821802074969db017b87bc6a948"}, + {file = "greenlet-3.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2de5a0b09eab81fc6a382791b995b1ccf2b172a9fec934747a7a23d2ff291794"}, + {file = "greenlet-3.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4449a736606bd30f27f8e1ff4678ee193bc47f6ca810d705981cfffd6ce0d8c5"}, + {file = "greenlet-3.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:7652ee180d16d447a683c04e4c5f6441bae7ba7b17ffd9f6b3aff4605e9e6f71"}, + {file = "greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb"}, + {file = 
"greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3"}, + {file = "greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655"}, + {file = "greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7"}, + {file = "greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b"}, + {file = "greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53"}, + {file = "greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614"}, + {file = "greenlet-3.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:a7a34b13d43a6b78abf828a6d0e87d3385680eaf830cd60d20d52f249faabf39"}, + {file = "greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739"}, + {file = "greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808"}, + {file = "greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54"}, + {file = "greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492"}, + {file = "greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527"}, + {file = 
"greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39"}, + {file = "greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8"}, + {file = "greenlet-3.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:9ee1942ea19550094033c35d25d20726e4f1c40d59545815e1128ac58d416d38"}, + {file = "greenlet-3.3.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:60c2ef0f578afb3c8d92ea07ad327f9a062547137afe91f38408f08aacab667f"}, + {file = "greenlet-3.3.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a5d554d0712ba1de0a6c94c640f7aeba3f85b3a6e1f2899c11c2c0428da9365"}, + {file = "greenlet-3.3.0-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3a898b1e9c5f7307ebbde4102908e6cbfcb9ea16284a3abe15cab996bee8b9b3"}, + {file = "greenlet-3.3.0-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:dcd2bdbd444ff340e8d6bdf54d2f206ccddbb3ccfdcd3c25bf4afaa7b8f0cf45"}, + {file = "greenlet-3.3.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5773edda4dc00e173820722711d043799d3adb4f01731f40619e07ea2750b955"}, + {file = "greenlet-3.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ac0549373982b36d5fd5d30beb8a7a33ee541ff98d2b502714a09f1169f31b55"}, + {file = "greenlet-3.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d198d2d977460358c3b3a4dc844f875d1adb33817f0613f663a656f463764ccc"}, + {file = "greenlet-3.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:73f51dd0e0bdb596fb0417e475fa3c5e32d4c83638296e560086b8d7da7c4170"}, + {file = "greenlet-3.3.0-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:d6ed6f85fae6cdfdb9ce04c9bf7a08d666cfcfb914e7d006f44f840b46741931"}, + {file = "greenlet-3.3.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:d9125050fcf24554e69c4cacb086b87b3b55dc395a8b3ebe6487b045b2614388"}, + {file = "greenlet-3.3.0-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:87e63ccfa13c0a0f6234ed0add552af24cc67dd886731f2261e46e241608bee3"}, + {file = "greenlet-3.3.0-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2662433acbca297c9153a4023fe2161c8dcfdcc91f10433171cf7e7d94ba2221"}, + {file = "greenlet-3.3.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3c6e9b9c1527a78520357de498b0e709fb9e2f49c3a513afd5a249007261911b"}, + {file = "greenlet-3.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:286d093f95ec98fdd92fcb955003b8a3d054b4e2cab3e2707a5039e7b50520fd"}, + {file = "greenlet-3.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c10513330af5b8ae16f023e8ddbfb486ab355d04467c4679c5cfe4659975dd9"}, + {file = "greenlet-3.3.0.tar.gz", hash = "sha256:a82bb225a4e9e4d653dd2fb7b8b2d36e4fb25bc0165422a11e48b88e9e6f78fb"}, +] + +[package.extras] +docs = ["Sphinx", "furo"] +test = ["objgraph", "psutil", "setuptools"] + [[package]] name = "h11" version = "0.16.0" @@ -2048,6 +2131,26 @@ cli = ["jsonargparse[signatures] (>=4.38.0)", "tomlkit"] docs = ["requests (>=2.0.0)"] typing = ["mypy (>=1.0.0)", "types-setuptools"] +[[package]] +name = "mako" +version = "1.3.10" +description = "A super-fast templating language that borrows the best ideas from the existing templating languages." 
+optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59"}, + {file = "mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28"}, +] + +[package.dependencies] +MarkupSafe = ">=0.9.2" + +[package.extras] +babel = ["Babel"] +lingua = ["lingua"] +testing = ["pytest"] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -2968,6 +3071,33 @@ files = [ {file = "nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e"}, ] +[[package]] +name = "optuna" +version = "4.7.0" +description = "A hyperparameter optimization framework" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "optuna-4.7.0-py3-none-any.whl", hash = "sha256:e41ec84018cecc10eabf28143573b1f0bde0ba56dba8151631a590ecbebc1186"}, + {file = "optuna-4.7.0.tar.gz", hash = "sha256:d91817e2079825557bd2e97de2e8c9ae260bfc99b32712502aef8a5095b2d2c0"}, +] + +[package.dependencies] +alembic = ">=1.5.0" +colorlog = "*" +numpy = "*" +packaging = ">=20.0" +PyYAML = "*" +sqlalchemy = ">=1.4.2" +tqdm = "*" + +[package.extras] +checking = ["mypy", "mypy_boto3_s3", "ruff", "scipy-stubs ; python_version >= \"3.10\"", "types-PyYAML", "types-redis", "types-setuptools", "types-tqdm", "typing_extensions (>=3.10.0.0)"] +document = ["ase", "cmaes (>=0.12.0)", "fvcore", "kaleido (<0.4)", "lightgbm", "matplotlib (!=3.6.0)", "pandas", "pillow", "plotly (>=4.9.0)", "scikit-learn", "sphinx", "sphinx-copybutton", "sphinx-gallery", "sphinx-notfound-page", "sphinx_rtd_theme (>=1.2.0)", "torch", "torchvision"] +optional = ["boto3", "cmaes (>=0.12.0)", "google-cloud-storage", "greenlet", "grpcio", "matplotlib (!=3.6.0)", "pandas", "plotly (>=4.9.0)", "protobuf (>=5.28.1)", "redis", "scikit-learn (>=0.24.2)", "scipy", "torch"] +test = ["fakeredis[lua]", 
"greenlet", "grpcio", "kaleido (<0.4)", "moto", "protobuf (>=5.28.1)", "pytest", "pytest-xdist", "scipy (>=1.9.2)", "torch"] + [[package]] name = "packaging" version = "26.0" @@ -5145,6 +5275,104 @@ lint = ["mypy", "ruff (==0.5.5)", "types-docutils"] standalone = ["Sphinx (>=5)"] test = ["pytest"] +[[package]] +name = "sqlalchemy" +version = "2.0.46" +description = "Database Abstraction Library" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "sqlalchemy-2.0.46-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:895296687ad06dc9b11a024cf68e8d9d3943aa0b4964278d2553b86f1b267735"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab65cb2885a9f80f979b85aa4e9c9165a31381ca322cbde7c638fe6eefd1ec39"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:52fe29b3817bd191cc20bad564237c808967972c97fa683c04b28ec8979ae36f"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:09168817d6c19954d3b7655da6ba87fcb3a62bb575fb396a81a8b6a9fadfe8b5"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:be6c0466b4c25b44c5d82b0426b5501de3c424d7a3220e86cd32f319ba56798e"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-win32.whl", hash = "sha256:1bc3f601f0a818d27bfe139f6766487d9c88502062a2cd3a7ee6c342e81d5047"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-win_amd64.whl", hash = "sha256:e0c05aff5c6b1bb5fb46a87e0f9d2f733f83ef6cbbbcd5c642b6c01678268061"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:261c4b1f101b4a411154f1da2b76497d73abbfc42740029205d4d01fa1052684"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:181903fe8c1b9082995325f1b2e84ac078b1189e2819380c2303a5f90e114a62"}, + {file = 
"sqlalchemy-2.0.46-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:590be24e20e2424a4c3c1b0835e9405fa3d0af5823a1a9fc02e5dff56471515f"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7568fe771f974abadce52669ef3a03150ff03186d8eb82613bc8adc435a03f01"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf7e1e78af38047e08836d33502c7a278915698b7c2145d045f780201679999"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-win32.whl", hash = "sha256:9d80ea2ac519c364a7286e8d765d6cd08648f5b21ca855a8017d9871f075542d"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-win_amd64.whl", hash = "sha256:585af6afe518732d9ccd3aea33af2edaae4a7aa881af5d8f6f4fe3a368699597"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3a9a72b0da8387f15d5810f1facca8f879de9b85af8c645138cba61ea147968c"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2347c3f0efc4de367ba00218e0ae5c4ba2306e47216ef80d6e31761ac97cb0b9"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9094c8b3197db12aa6f05c51c05daaad0a92b8c9af5388569847b03b1007fb1b"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:37fee2164cf21417478b6a906adc1a91d69ae9aba8f9533e67ce882f4bb1de53"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b1e14b2f6965a685c7128bd315e27387205429c2e339eeec55cb75ca4ab0ea2e"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-win32.whl", hash = "sha256:412f26bb4ba942d52016edc8d12fb15d91d3cd46b0047ba46e424213ad407bcb"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-win_amd64.whl", hash = "sha256:ea3cd46b6713a10216323cda3333514944e510aa691c945334713fca6b5279ff"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-macosx_11_0_arm64.whl", hash = 
"sha256:93a12da97cca70cea10d4b4fc602589c4511f96c1f8f6c11817620c021d21d00"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af865c18752d416798dae13f83f38927c52f085c52e2f32b8ab0fef46fdd02c2"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8d679b5f318423eacb61f933a9a0f75535bfca7056daeadbf6bd5bcee6183aee"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64901e08c33462acc9ec3bad27fc7a5c2b6491665f2aa57564e57a4f5d7c52ad"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8ac45e8f4eaac0f9f8043ea0e224158855c6a4329fd4ee37c45c61e3beb518e"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-win32.whl", hash = "sha256:8d3b44b3d0ab2f1319d71d9863d76eeb46766f8cf9e921ac293511804d39813f"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-win_amd64.whl", hash = "sha256:77f8071d8fbcbb2dd11b7fd40dedd04e8ebe2eb80497916efedba844298065ef"}, + {file = "sqlalchemy-2.0.46-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a1e8cc6cc01da346dc92d9509a63033b9b1bda4fed7a7a7807ed385c7dccdc10"}, + {file = "sqlalchemy-2.0.46-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:96c7cca1a4babaaf3bfff3e4e606e38578856917e52f0384635a95b226c87764"}, + {file = "sqlalchemy-2.0.46-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b2a9f9aee38039cf4755891a1e50e1effcc42ea6ba053743f452c372c3152b1b"}, + {file = "sqlalchemy-2.0.46-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:db23b1bf8cfe1f7fda19018e7207b20cdb5168f83c437ff7e95d19e39289c447"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:56bdd261bfd0895452006d5316cbf35739c53b9bb71a170a331fa0ea560b2ada"}, + {file = 
"sqlalchemy-2.0.46-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:33e462154edb9493f6c3ad2125931e273bbd0be8ae53f3ecd1c161ea9a1dd366"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9bcdce05f056622a632f1d44bb47dbdb677f58cad393612280406ce37530eb6d"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e84b09a9b0f19accedcbeff5c2caf36e0dd537341a33aad8d680336152dc34e"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4f52f7291a92381e9b4de9050b0a65ce5d6a763333406861e33906b8aa4906bf"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-win32.whl", hash = "sha256:70ed2830b169a9960193f4d4322d22be5c0925357d82cbf485b3369893350908"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-win_amd64.whl", hash = "sha256:3c32e993bc57be6d177f7d5d31edb93f30726d798ad86ff9066d75d9bf2e0b6b"}, + {file = "sqlalchemy-2.0.46-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4dafb537740eef640c4d6a7c254611dca2df87eaf6d14d6a5fca9d1f4c3fc0fa"}, + {file = "sqlalchemy-2.0.46-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:42a1643dc5427b69aca967dae540a90b0fbf57eaf248f13a90ea5930e0966863"}, + {file = "sqlalchemy-2.0.46-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ff33c6e6ad006bbc0f34f5faf941cfc62c45841c64c0a058ac38c799f15b5ede"}, + {file = "sqlalchemy-2.0.46-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:82ec52100ec1e6ec671563bbd02d7c7c8d0b9e71a0723c72f22ecf52d1755330"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6ac245604295b521de49b465bab845e3afe6916bcb2147e5929c8041b4ec0545"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:1e6199143d51e3e1168bedd98cc698397404a8f7508831b81b6a29b18b051069"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:716be5bcabf327b6d5d265dbdc6213a01199be587224eb991ad0d37e83d728fd"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6f827fd687fa1ba7f51699e1132129eac8db8003695513fcf13fc587e1bd47a5"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c805fa6e5d461329fa02f53f88c914d189ea771b6821083937e79550bf31fc19"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-win32.whl", hash = "sha256:3aac08f7546179889c62b53b18ebf1148b10244b3405569c93984b0388d016a7"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-win_amd64.whl", hash = "sha256:0cc3117db526cad3e61074100bd2867b533e2c7dc1569e95c14089735d6fb4fe"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:90bde6c6b1827565a95fde597da001212ab436f1b2e0c2dcc7246e14db26e2a3"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:94b1e5f3a5f1ff4f42d5daab047428cd45a3380e51e191360a35cef71c9a7a2a"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:93bb0aae40b52c57fd74ef9c6933c08c040ba98daf23ad33c3f9893494b8d3ce"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c4e2cc868b7b5208aec6c960950b7bb821f82c2fe66446c92ee0a571765e91a5"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:965c62be8256d10c11f8907e7a8d3e18127a4c527a5919d85fa87fd9ecc2cfdc"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-win32.whl", hash = "sha256:9397b381dcee8a2d6b99447ae85ea2530dcac82ca494d1db877087a13e38926d"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-win_amd64.whl", hash = "sha256:4396c948d8217e83e2c202fbdcc0389cf8c93d2c1c5e60fa5c5a955eae0e64be"}, + {file = "sqlalchemy-2.0.46-py3-none-any.whl", hash 
= "sha256:f9c11766e7e7c0a2767dda5acb006a118640c9fc0a4104214b96269bfb78399e"}, + {file = "sqlalchemy-2.0.46.tar.gz", hash = "sha256:cf36851ee7219c170bb0793dbc3da3e80c582e04a5437bc601bfe8c85c9216d7"}, +] + +[package.dependencies] +greenlet = {version = ">=1", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} +typing-extensions = ">=4.6.0" + +[package.extras] +aiomysql = ["aiomysql (>=0.2.0)", "greenlet (>=1)"] +aioodbc = ["aioodbc", "greenlet (>=1)"] +aiosqlite = ["aiosqlite", "greenlet (>=1)", "typing_extensions (!=3.10.0.1)"] +asyncio = ["greenlet (>=1)"] +asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (>=1)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10)"] +mssql = ["pyodbc"] +mssql-pymssql = ["pymssql"] +mssql-pyodbc = ["pyodbc"] +mypy = ["mypy (>=0.910)"] +mysql = ["mysqlclient (>=1.4.0)"] +mysql-connector = ["mysql-connector-python"] +oracle = ["cx_oracle (>=8)"] +oracle-oracledb = ["oracledb (>=1.0.1)"] +postgresql = ["psycopg2 (>=2.7)"] +postgresql-asyncpg = ["asyncpg", "greenlet (>=1)"] +postgresql-pg8000 = ["pg8000 (>=1.29.1)"] +postgresql-psycopg = ["psycopg (>=3.0.7)"] +postgresql-psycopg2binary = ["psycopg2-binary"] +postgresql-psycopg2cffi = ["psycopg2cffi"] +postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] +pymysql = ["pymysql"] +sqlcipher = ["sqlcipher3_binary"] + [[package]] name = "starlette" version = "0.52.1" @@ -6402,4 +6630,4 @@ multiprocessing = ["pydantic", "ray"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.14" -content-hash = "b48d33e2c6e66c3fa8aa5b42db423e8577a44e19090b766f65051f4b9587dde4" +content-hash = "205db54ef1c6b5ee0b5d9821e6bc9f10d020e674c71fe62a03a477c226ab0b47" diff --git a/pyproject.toml b/pyproject.toml index d6056275..c8a0b56a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-50,6 +50,7 @@ poetry = "^2.0.1" starlette = ">=0.49.1" pydantic = { version = ">=2.5", optional = true } wandb = "^0.24.0" +optuna = "^4.7.0" [tool.poetry.requires-plugins] poetry-plugin-export = ">=1.8" diff --git a/tests/models/test_global_models.py b/tests/models/test_global_models.py index 69e20b2b..535b98e6 100644 --- a/tests/models/test_global_models.py +++ b/tests/models/test_global_models.py @@ -5,9 +5,13 @@ from typing import cast import numpy as np +import pandas as pd import pytest +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler from drevalpy.datasets.dataset import DrugResponseDataset +from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER from drevalpy.evaluation import evaluate from drevalpy.experiment import cross_study_prediction from drevalpy.models import MODEL_FACTORY @@ -156,3 +160,142 @@ def test_global_models( split_index=0, single_drug_id=None, ) + + +def test_pca_neural_network( + sample_dataset: DrugResponseDataset, + cross_study_dataset: DrugResponseDataset, +) -> None: + """ + Test PCANeuralNetwork model. + + This test creates PCA features from gene expression data, then trains and evaluates + the PCANeuralNetwork model. 
+ + :param sample_dataset: from conftest.py + :param cross_study_dataset: from conftest.py + :raises ValueError: if drug input is None + """ + test_mode = "LTO" + model_name = "PCANeuralNetwork" + n_components = 10 # Use small number for testing + + drug_response = sample_dataset + drug_response.split_dataset(n_cv_splits=2, mode=test_mode, validation_ratio=0.4) + assert drug_response.cv_splits is not None + split = drug_response.cv_splits[0] + train_dataset = split["train"] + val_es_dataset = split["validation_es"] + es_dataset = split["early_stopping"] + + path_data = os.path.join("..", "data") + + # Create PCA features from gene expression data + ge_file = os.path.join(path_data, "TOYv1", "gene_expression.csv") + ge_df = pd.read_csv(ge_file, index_col=CELL_LINE_IDENTIFIER) + ge_df.index = ge_df.index.astype(str) + if "cellosaurus_id" in ge_df.columns: + ge_df = ge_df.drop(columns=["cellosaurus_id"]) + + # Perform PCA + scaler = StandardScaler() + ge_scaled = scaler.fit_transform(ge_df.values) + pca = PCA(n_components=n_components) + pca_features = pca.fit_transform(ge_scaled) + + # Save PCA features + pca_df = pd.DataFrame( + pca_features, + index=ge_df.index, + columns=[f"PC{i + 1}" for i in range(n_components)], + ) + pca_df.index.name = CELL_LINE_IDENTIFIER + pca_df = pca_df.reset_index() + + pca_output_file = os.path.join(path_data, "TOYv1", f"cell_line_gene_expression_pca_{n_components}.csv") + pca_df.to_csv(pca_output_file, index=False) + + try: + # Load model and features - need to build_model first to set n_components in hyperparameters + model_class = cast(type[DRPModel], MODEL_FACTORY[model_name]) + model = model_class() + hpams = model.get_hyperparameter_set() + hpam_combi = hpams[0] + hpam_combi["units_per_layer"] = [2, 2] + hpam_combi["max_epochs"] = 1 + hpam_combi["n_components"] = n_components + model.build_model(hyperparameters=hpam_combi) + + cell_line_input = model.load_cell_line_features(data_path=path_data, dataset_name="TOYv1") + drug_input = 
model.load_drug_features(data_path=path_data, dataset_name="TOYv1") + if drug_input is None: + raise ValueError("Drug input is None") + + cell_lines_to_keep = cell_line_input.identifiers + drugs_to_keep = drug_input.identifiers + + train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + val_es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.train( + output=train_dataset, + cell_line_input=cell_line_input, + drug_input=drug_input, + output_earlystopping=es_dataset, + model_checkpoint_dir=tmpdirname, + ) + + prediction_dataset = val_es_dataset + prediction_dataset._predictions = model.predict( + drug_ids=prediction_dataset.drug_ids, + cell_line_ids=prediction_dataset.cell_line_ids, + drug_input=drug_input, + cell_line_input=cell_line_input, + ) + + # Save and load test + with tempfile.TemporaryDirectory() as model_dir: + model.save(model_dir) + loaded_model = model_class.load(model_dir) + assert isinstance(loaded_model, DRPModel) + + preds_before = model.predict( + drug_ids=prediction_dataset.drug_ids, + cell_line_ids=prediction_dataset.cell_line_ids, + drug_input=drug_input, + cell_line_input=cell_line_input, + ) + preds_after = loaded_model.predict( + drug_ids=prediction_dataset.drug_ids, + cell_line_ids=prediction_dataset.cell_line_ids, + drug_input=drug_input, + cell_line_input=cell_line_input, + ) + + assert preds_before.shape == preds_after.shape + assert isinstance(preds_after, np.ndarray) + + metrics = evaluate(prediction_dataset, metric=["Pearson"]) + print(f"Model: {model_name}, Pearson: {metrics['Pearson']}") + assert metrics["Pearson"] >= -1.0 + + with tempfile.TemporaryDirectory() as temp_dir: + print(f"Running cross-study prediction for {model_name}") + cross_study_prediction( + dataset=cross_study_dataset, + model=model, + test_mode=test_mode, + 
train_dataset=train_dataset, + path_data=path_data, + early_stopping_dataset=None, + response_transformation=None, + path_out=temp_dir, + split_index=0, + single_drug_id=None, + ) + finally: + # Clean up the generated PCA file + if os.path.exists(pca_output_file): + os.remove(pca_output_file) diff --git a/tests/test_featurizers.py b/tests/test_featurizers.py index a2db9b69..f27e178a 100644 --- a/tests/test_featurizers.py +++ b/tests/test_featurizers.py @@ -1,10 +1,11 @@ -"""Tests for drug featurizers.""" +"""Tests for drug and cell line featurizers.""" import sys from unittest.mock import patch +import numpy as np import pandas as pd -import torch +import pytest def test_chemberta_featurizer(tmp_path): @@ -14,10 +15,10 @@ def test_chemberta_featurizer(tmp_path): :param tmp_path: Temporary path provided by pytest. """ try: - import drevalpy.datasets.featurizer.create_chemberta_drug_embeddings as chemberta + from drevalpy.datasets.featurizer import ChemBERTaFeaturizer except ImportError: - print("transformers package not installed; skipping ChemBERTa featurizer test.") - return + pytest.skip("transformers package not installed; skipping ChemBERTa featurizer test.") + dataset = "testset" data_dir = tmp_path / dataset data_dir.mkdir(parents=True) @@ -26,20 +27,54 @@ def test_chemberta_featurizer(tmp_path): df = pd.DataFrame({"pubchem_id": ["X1"], "canonical_smiles": ["CCO"]}) (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) - fake_embedding = [1.0, 2.0, 3.0] + fake_embedding = np.array([1.0, 2.0, 3.0], dtype=np.float32) - with patch.object(chemberta, "_smiles_to_chemberta", return_value=fake_embedding), patch.object( - sys, "argv", ["prog", dataset, "--data_path", str(tmp_path)] - ): + featurizer = ChemBERTaFeaturizer(device="cpu") - chemberta.main() + with patch.object(featurizer, "featurize", return_value=fake_embedding): + result = featurizer.generate_embeddings(str(tmp_path), dataset) out_file = data_dir / "drug_chemberta_embeddings.csv" assert 
out_file.exists() df_out = pd.read_csv(out_file) assert df_out.pubchem_id.tolist() == ["X1"] - assert df_out.iloc[0, 1:].tolist() == fake_embedding + assert df_out.iloc[0, 1:].tolist() == fake_embedding.tolist() + + # Test that FeatureDataset is returned correctly + assert "X1" in result.features + assert "chemberta_embeddings" in result.features["X1"] + + +def test_chemberta_featurizer_cli(tmp_path): + """ + Test ChemBERTa featurizer CLI entry point. + + :param tmp_path: Temporary path provided by pytest. + """ + try: + from drevalpy.datasets.featurizer.drug import chemberta + except ImportError: + pytest.skip("transformers package not installed; skipping ChemBERTa featurizer CLI test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # fake input CSV + df = pd.DataFrame({"pubchem_id": ["X1"], "canonical_smiles": ["CCO"]}) + (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) + + fake_embedding = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + with ( + patch.object(chemberta.ChemBERTaFeaturizer, "featurize", return_value=fake_embedding), + patch.object(sys, "argv", ["prog", dataset, "--data_path", str(tmp_path)]), + ): + chemberta.main() + + out_file = data_dir / "drug_chemberta_embeddings.csv" + assert out_file.exists() def test_graph_featurizer(tmp_path): @@ -49,10 +84,10 @@ def test_graph_featurizer(tmp_path): :param tmp_path: Temporary path provided by pytest. 
""" try: - import drevalpy.datasets.featurizer.create_drug_graphs as graphs + from drevalpy.datasets.featurizer import DrugGraphFeaturizer except ImportError: - print("rdkit package not installed; skipping graph featurizer test.") - return + pytest.skip("rdkit package not installed; skipping graph featurizer test.") + dataset = "testset" data_dir = tmp_path / dataset data_dir.mkdir(parents=True) @@ -61,9 +96,39 @@ def test_graph_featurizer(tmp_path): df = pd.DataFrame({"pubchem_id": ["D1"], "canonical_smiles": ["CCO"]}) (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) - # run main exactly as the script would - sys.argv = ["prog", dataset, "--data_path", str(tmp_path)] - graphs.main() + featurizer = DrugGraphFeaturizer() + result = featurizer.generate_embeddings(str(tmp_path), dataset) + + # expected output file + out_file = data_dir / "drug_graphs" / "D1.pt" + assert out_file.exists() + + # Test that FeatureDataset is returned correctly + assert "D1" in result.features + assert "drug_graphs" in result.features["D1"] + + +def test_graph_featurizer_cli(tmp_path): + """ + Test graph featurizer CLI entry point. + + :param tmp_path: Temporary path provided by pytest. + """ + try: + from drevalpy.datasets.featurizer.drug import drug_graph + except ImportError: + pytest.skip("rdkit package not installed; skipping graph featurizer CLI test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # write minimal SMILES CSV + df = pd.DataFrame({"pubchem_id": ["D1"], "canonical_smiles": ["CCO"]}) + (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) + + with patch.object(sys, "argv", ["prog", dataset, "--data_path", str(tmp_path)]): + drug_graph.main() # expected output file out_file = data_dir / "drug_graphs" / "D1.pt" @@ -77,10 +142,54 @@ def test_molgnet_featurizer(tmp_path): :param tmp_path: Temporary path provided by pytest. 
""" try: - import drevalpy.datasets.featurizer.create_molgnet_embeddings as molg + from drevalpy.datasets.featurizer import MolGNetFeaturizer + from drevalpy.datasets.featurizer.drug import molgnet except ImportError: - print("rdkit package not installed; skipping molgnet featurizer test.") - return + pytest.skip("rdkit package not installed; skipping molgnet featurizer test.") + + ds = "testset" + ds_dir = tmp_path / ds + ds_dir.mkdir(parents=True) + + # minimal SMILES CSV + df = pd.DataFrame({"pubchem_id": ["D1"], "canonical_smiles": ["CCO"]}) + (ds_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) + + # Create a fake checkpoint file + checkpoint_path = str(tmp_path / "MolGNet.pt") + + featurizer = MolGNetFeaturizer(checkpoint_path=checkpoint_path, device="cpu") + + with ( + # we dont need real model weights for this test, takes too long to load + patch("drevalpy.datasets.featurizer.drug.molgnet.torch.load", return_value={}), + # prevent load_state_dict from complaining + patch.object(molgnet.MolGNet, "load_state_dict", return_value=None), + # cheap forward pass + patch.object(molgnet.MolGNet, "forward", return_value=torch.zeros((1, 768))), + ): + result = featurizer.generate_embeddings(str(tmp_path), ds) + + # verify outputs + assert (ds_dir / "DIPK_features/Drugs" / "MolGNet_D1.csv").exists() + assert (ds_dir / "MolGNet_dict.pkl").exists() + + # Test that FeatureDataset is returned correctly + assert "D1" in result.features + assert "molgnet_embeddings" in result.features["D1"] + + +def test_molgnet_featurizer_cli(tmp_path): + """ + Test MolGNet featurizer CLI entry point. + + :param tmp_path: Temporary path provided by pytest. 
+ """ + try: + from drevalpy.datasets.featurizer.drug import molgnet + except ImportError: + pytest.skip("rdkit package not installed; skipping molgnet featurizer CLI test.") + ds = "testset" ds_dir = tmp_path / ds ds_dir.mkdir(parents=True) @@ -91,22 +200,19 @@ def test_molgnet_featurizer(tmp_path): with ( # we dont need real model weights for this test, takes too long to load - patch("drevalpy.datasets.featurizer.create_molgnet_embeddings.torch.load", return_value={}), + patch("drevalpy.datasets.featurizer.drug.molgnet.torch.load", return_value={}), # prevent load_state_dict from complaining - patch.object(molg.MolGNet, "load_state_dict", return_value=None), + patch.object(molgnet.MolGNet, "load_state_dict", return_value=None), # cheap forward pass - patch.object(molg.MolGNet, "forward", return_value=torch.zeros((1, 768))), - # avoid writing pickles - patch.object(molg.pickle, "dump", return_value=None), + patch.object(molgnet.MolGNet, "forward", return_value=torch.zeros((1, 768))), # simulate CLI patch.object( sys, "argv", - ["prog", ds, "--data_path", str(tmp_path), "--checkpoint", "MolGNet.pt"], + ["prog", ds, "--data_path", str(tmp_path), "--checkpoint", str(tmp_path / "MolGNet.pt")], ), ): - args = molg.parse_args() - molg.run(args) + molgnet.main() # verify outputs assert (ds_dir / "DIPK_features/Drugs" / "MolGNet_D1.csv").exists() @@ -123,6 +229,7 @@ def test_bpe_smiles_featurizer(tmp_path): except ImportError: print("subword-nmt package not installed; skipping BPE SMILES featurizer test.") return + dataset = "testset" data_dir = tmp_path / dataset data_dir.mkdir(parents=True) @@ -155,3 +262,188 @@ def test_bpe_smiles_featurizer(tmp_path): assert len(feature_cols) == 128 # Values should be numeric (character ordinals, may be stored as float in CSV) assert pd.api.types.is_numeric_dtype(df_out[feature_cols[0]]) + + +def test_transcriptome_pca_featurizer(tmp_path): + """ + Test transcriptome PCA featurizer end-to-end. 
+ + :param tmp_path: Temporary path provided by pytest. + """ + try: + from drevalpy.datasets.featurizer import PCAFeaturizer + except ImportError: + pytest.skip("sklearn package not installed; skipping transcriptome PCA featurizer test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # Create fake gene expression CSV + # Format: rows are cell lines, columns are genes + n_cell_lines = 10 + n_genes = 100 + cell_line_names = [f"CL{i}" for i in range(n_cell_lines)] + gene_names = [f"GENE{i}" for i in range(n_genes)] + + # Generate some random gene expression data + np.random.seed(42) + ge_data = np.random.randn(n_cell_lines, n_genes).astype(np.float32) + + ge_df = pd.DataFrame(ge_data, index=cell_line_names, columns=gene_names) + ge_df.index.name = "cell_line_name" + ge_df = ge_df.reset_index() + + (data_dir / "gene_expression.csv").write_text(ge_df.to_csv(index=False)) + + # Run the featurizer + featurizer = PCAFeaturizer(n_components=10) + result = featurizer.generate_embeddings(str(tmp_path), dataset) + + # Check output files + output_file = data_dir / "cell_line_gene_expression_pca_10.csv" + model_file = data_dir / "cell_line_gene_expression_pca_10_models.pkl" + + assert output_file.exists() + assert model_file.exists() + + # Verify output CSV structure + df_out = pd.read_csv(output_file) + assert "cell_line_name" in df_out.columns + assert len(df_out.columns) == 11 # cell_line_name + 10 PC columns + assert len(df_out) == n_cell_lines + + # Test that FeatureDataset is returned correctly + assert "CL0" in result.features + assert "gene_expression_pca" in result.features["CL0"] + assert result.features["CL0"]["gene_expression_pca"].shape == (10,) + + +def test_pca_featurizer_cli(tmp_path): + """ + Test transcriptome PCA featurizer CLI entry point. + + :param tmp_path: Temporary path provided by pytest. 
+ """ + try: + from drevalpy.datasets.featurizer.cell_line import pca + except ImportError: + pytest.skip("sklearn package not installed; skipping transcriptome PCA featurizer CLI test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # Create fake gene expression CSV + n_cell_lines = 10 + n_genes = 100 + cell_line_names = [f"CL{i}" for i in range(n_cell_lines)] + gene_names = [f"GENE{i}" for i in range(n_genes)] + + np.random.seed(42) + ge_data = np.random.randn(n_cell_lines, n_genes).astype(np.float32) + + ge_df = pd.DataFrame(ge_data, index=cell_line_names, columns=gene_names) + ge_df.index.name = "cell_line_name" + ge_df = ge_df.reset_index() + + (data_dir / "gene_expression.csv").write_text(ge_df.to_csv(index=False)) + + # Run the featurizer via CLI + with patch.object( + sys, + "argv", + ["prog", dataset, "--data_path", str(tmp_path), "--n_components", "10"], + ): + pca.main() + + # Check output files + output_file = data_dir / "cell_line_gene_expression_pca_10.csv" + assert output_file.exists() + + +def test_pca_featurizer_load_or_generate(tmp_path): + """ + Test that load_or_generate loads existing embeddings or generates new ones. + + :param tmp_path: Temporary path provided by pytest. 
+ """ + try: + from drevalpy.datasets.featurizer import PCAFeaturizer + except ImportError: + pytest.skip("sklearn package not installed; skipping PCA featurizer test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # Create fake gene expression CSV + n_cell_lines = 10 + n_genes = 100 + cell_line_names = [f"CL{i}" for i in range(n_cell_lines)] + gene_names = [f"GENE{i}" for i in range(n_genes)] + + np.random.seed(42) + ge_data = np.random.randn(n_cell_lines, n_genes).astype(np.float32) + + ge_df = pd.DataFrame(ge_data, index=cell_line_names, columns=gene_names) + ge_df.index.name = "cell_line_name" + ge_df = ge_df.reset_index() + + (data_dir / "gene_expression.csv").write_text(ge_df.to_csv(index=False)) + + # First call should generate embeddings + featurizer1 = PCAFeaturizer(n_components=10) + result1 = featurizer1.load_or_generate(str(tmp_path), dataset) + + # Second call should load existing embeddings + featurizer2 = PCAFeaturizer(n_components=10) + result2 = featurizer2.load_or_generate(str(tmp_path), dataset) + + # Results should be the same + assert set(result1.features.keys()) == set(result2.features.keys()) + for cell_line_id in result1.features: + np.testing.assert_array_almost_equal( + result1.features[cell_line_id]["gene_expression_pca"], + result2.features[cell_line_id]["gene_expression_pca"], + ) + + +def test_chemberta_featurizer_load_or_generate(tmp_path): + """ + Test that load_or_generate loads existing embeddings or generates new ones. + + :param tmp_path: Temporary path provided by pytest. 
+ """ + try: + from drevalpy.datasets.featurizer import ChemBERTaFeaturizer + except ImportError: + pytest.skip("transformers package not installed; skipping ChemBERTa featurizer test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # fake input CSV + df = pd.DataFrame({"pubchem_id": ["X1", "X2"], "canonical_smiles": ["CCO", "CC"]}) + (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) + + fake_embedding = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + featurizer1 = ChemBERTaFeaturizer(device="cpu") + + # First call should generate embeddings + with patch.object(featurizer1, "featurize", return_value=fake_embedding): + result1 = featurizer1.load_or_generate(str(tmp_path), dataset) + + # Second call should load existing embeddings (no mock needed) + featurizer2 = ChemBERTaFeaturizer(device="cpu") + result2 = featurizer2.load_or_generate(str(tmp_path), dataset) + + # Results should be the same + assert set(result1.features.keys()) == set(result2.features.keys()) + for drug_id in result1.features: + np.testing.assert_array_almost_equal( + result1.features[drug_id]["chemberta_embeddings"], + result2.features[drug_id]["chemberta_embeddings"], + )