From e417f1497bfea5fe7ba138af7bcaeda242a22ae4 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Fri, 16 Aug 2024 05:03:46 -0700 Subject: [PATCH 01/13] [Finetuning] Add scripts for finetuning on the GLRB --- finetuning_glrb/README.md | 16 + finetuning_glrb/__init__.py | 0 finetuning_glrb/finetune.sh | 35 ++ finetuning_glrb/finetune_bulk_rna.py | 387 ++++++++++++ finetuning_glrb/finetune_chromatin.py | 553 +++++++++++++++++ .../finetune_regulatory_elements.py | 554 ++++++++++++++++++ .../finetune_variant_effect_OMIM.py | 430 ++++++++++++++ .../finetune_variant_effect_causal_eqtl.py | 444 ++++++++++++++ ...etune_variant_effect_pathogenic_clinvar.py | 430 ++++++++++++++ finetuning_glrb/main.py | 99 ++++ finetuning_glrb/utils.py | 70 +++ 11 files changed, 3018 insertions(+) create mode 100644 finetuning_glrb/README.md create mode 100644 finetuning_glrb/__init__.py create mode 100644 finetuning_glrb/finetune.sh create mode 100644 finetuning_glrb/finetune_bulk_rna.py create mode 100644 finetuning_glrb/finetune_chromatin.py create mode 100644 finetuning_glrb/finetune_regulatory_elements.py create mode 100644 finetuning_glrb/finetune_variant_effect_OMIM.py create mode 100644 finetuning_glrb/finetune_variant_effect_causal_eqtl.py create mode 100644 finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py create mode 100644 finetuning_glrb/main.py create mode 100644 finetuning_glrb/utils.py diff --git a/finetuning_glrb/README.md b/finetuning_glrb/README.md new file mode 100644 index 0000000..86e85ad --- /dev/null +++ b/finetuning_glrb/README.md @@ -0,0 +1,16 @@ +# Fine-Tuning DNA Models on the Genomics Long Range Benchmark 🧬 + +This folder contains the necessary scripts and configurations to fine-tune DNA models on the [Genomics Long Range Benchmark](https://huggingface.co/datasets/InstaDeepAI/genomics-long-range-benchmark). + +DNA Models are loaded from the Hugging-Face Hub 🤗. + +## Getting Started + +To fine-tune a model, execute the `finetune.sh` script. The script runs the `main.py` script with various command-line arguments that configure the fine-tuning process. Below is a description of each argument used. + +### Running the Script + +To start finetuning, first make sure that you have modified the `finetune.sh` script with the correct parameters for your task. 
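For example, to fine-tune Caduceus-PS on bulk RNA expression you would set `task=bulk_rna_expression`, `seq_len=131000`, `bp_per_token=1` and keep `--rcps true`; for a k-mer model such as NTv2 you would set `bp_per_token=6` and drop the `--rcps` flag (both configurations are sketched in the comments at the bottom of the script).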
Then, simply run: + +```bash +bash finetune.sh diff --git a/finetuning_glrb/__init__.py b/finetuning_glrb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/finetuning_glrb/finetune.sh b/finetuning_glrb/finetune.sh new file mode 100644 index 0000000..d83fd7c --- /dev/null +++ b/finetuning_glrb/finetune.sh @@ -0,0 +1,35 @@ +python finetuning_glrb/main.py \ + --task "task_name" \ + --seq_len 1000 \ + --model_name "model_name_on_the_huggingface_hub" \ + --bp_per_token 1 \ + --save_dir "output/" \ + --wandb_api_key "your_wandb_api_key" \ + --name_wb "name_for_your_wandb_run" \ + --rcps true \ + --train_batch_size 16 \ + --test_batch_size 16 \ + --num_workers 6 \ + --num_epochs 10 \ + --learning_rate "3e-5" \ + --patience 3 \ + --log_interval 280 \ + --accumulate_grad_batches 4 \ + --train_ratio 1.0 \ + --eval_ratio 1.0 + +##Examples + +## Caduceus-PS +#task=bulk_rna_expression +#seq_len=131000 +#bp_per_token=1 +#model_name="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16" +#rcps=true + +## NTv2 +#task=regulatory_element_promoter +#seq_len=12288 # 2048 (seq len) * 6 (kmers) +#bp_per_token=6 +#model_name="InstaDeepAI/nucleotide-transformer-v2-50m-multi-species" +#delete the rcps flag (it is not a RC-equivarient model) diff --git a/finetuning_glrb/finetune_bulk_rna.py b/finetuning_glrb/finetune_bulk_rna.py new file mode 100644 index 0000000..3343786 --- /dev/null +++ b/finetuning_glrb/finetune_bulk_rna.py @@ -0,0 +1,387 @@ +import os +from functools import partial +from os import path as osp +import re +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForMaskedLM, + AutoModel, + AutoTokenizer, + AutoConfig, + DefaultDataCollator, +) +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from datasets import load_dataset, load_from_disk +from sklearn.metrics import r2_score +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension + +# Logger setup +log = get_logger(__name__) + +# Constants for the upstream and downstream window sizes +WINDOW_SIZE_BP_UPSTREAM = 384 +WINDOW_SIZE_BP_DOWNSTREAM = 256 + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize sequence. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs. + """ + ref_tokenized = tokenizer.batch_encode_plus( + examples["sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + return { + "ref_input_ids": ref_tokenized["input_ids"], + } + +def recast_chromosome(examples): + """ + Recast chromosome to integer format. + + Returns: + dict with chromosome recast as integers. + """ + return { + "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + } + +class MLP_BulkRNA(nn.Module): + """ + Regression head for Bulk RNA prediction task. + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. 
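+
+    Example (illustrative sizes; the bulk RNA model below instantiates this head
+    with hidden_size = 2 * inner_dim and 218 outputs):
+        >>> head = MLP_BulkRNA(input_size=256, hidden_size=512, output_size=218)
+        >>> head(torch.randn(4, 256)).shape
+        torch.Size([4, 218])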
+ """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_BulkRNA, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc2(self.sp1(self.fc1(x))) + +class DNAModelForBulkRNA(nn.Module): + """ + DNA Model for Bulk RNA prediction. + + Args: + args: Arguments containing model configurations. + """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + # Load the appropriate backbone model based on the model name + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim founded for the Foundation Model: {self.inner_dim}") + self.head = MLP_BulkRNA(input_size=self.inner_dim, hidden_size=2*self.inner_dim, output_size=218) + + def forward(self, input_ids): + # Get embeddings from the backbone + embeds_out = self.backbone(input_ids)[0] + num_channels = embeds_out.size(-1) + + # Calculate the window size and indexes + window_size_upstream = WINDOW_SIZE_BP_UPSTREAM // self.bp_per_token + window_size_downstream = WINDOW_SIZE_BP_DOWNSTREAM // self.bp_per_token + start, end = -window_size_downstream, window_size_upstream + + if self.rcps: + # If the model is RC-equivariant + embeds = embeds_out[..., :num_channels // 2] + expanded_indices = torch.arange(start, end, device=embeds.device).unsqueeze(0).expand(embeds.size(0), -1) + embeds.size(1) // 2 + expanded_indices = torch.clamp(expanded_indices, 0, embeds.size(1) - 1) + + # Extract the relevant window from the embeddings and average it through the sequence length dimension + tokens_window_ref = torch.gather( + embeds, 1, + expanded_indices.unsqueeze(-1).expand(-1, -1, embeds.size(2)) + ) + tokens_window_ref = tokens_window_ref.mean(dim=1) + + #Same for the RC-equivalent + rc_embeds = embeds_out[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + expanded_indices = torch.arange(start, end, device=rc_embeds.device).unsqueeze(0).expand(rc_embeds.size(0), -1) + rc_embeds.size(1) // 2 + expanded_indices = torch.clamp(expanded_indices, 0, rc_embeds.size(1) - 1) + + tokens_window_rc = torch.gather( + rc_embeds, 1, + expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds.size(2)) + ) + tokens_window_rc = tokens_window_rc.mean(dim=1) + + #Combine the Reference and RC-Equivariant resulting embeddings + aggregated_embeds = tokens_window_rc + tokens_window_ref + return self.head(aggregated_embeds) + + else: + # No reverse complement processing + expanded_indices = torch.arange(start, end, device=embeds_out.device).unsqueeze(0).expand(embeds_out.size(0), -1) + embeds_out.size(1) // 2 + expanded_indices = torch.clamp(expanded_indices, 0, embeds_out.size(1) - 1) + + # Extract the relevant window + tokens_window_ref = torch.gather( + embeds_out, 1, + expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_out.size(2)) + ) + tokens_window_ref = tokens_window_ref.mean(dim=1) + return self.head(tokens_window_ref) + +class Lit_BulkRNAFinetuning(pl.LightningModule): + """ + PyTorch Lightning model for fine-tuning on Bulk RNA prediction. 
+ + Args: + args: Arguments containing model and training configurations. + """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.rcps = args.rcps + self.model = DNAModelForBulkRNA(args) + self.criterion = nn.MSELoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + + def forward(self, ref_input_ids): + return self.model(ref_input_ids) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"] + + logits = self.model(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for R² score + self.training_step_preds.extend(logits.detach().cpu().numpy()) + self.training_step_labels.extend(labels.detach().cpu().numpy()) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"] + + logits = self.model(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for R² score + self.validation_step_preds.extend(logits.detach().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().cpu().numpy()) + + return loss + + def on_validation_epoch_end(self): + # Calculate R² score for validation + val_r2 = r2_score(self.validation_step_labels, self.validation_step_preds) + self.log("validation/R2", val_r2, on_epoch=True, prog_bar=True, logger=True) + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def on_train_epoch_end(self): + # Calculate R² score for training + train_r2 = r2_score(self.training_step_labels, self.training_step_preds) + self.log("train/R2", train_r2, on_epoch=True, prog_bar=True, logger=True) + self.training_step_labels.clear() + self.training_step_preds.clear() + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class BulkRNADataModule(pl.LightningDataModule): + """ + Data module for Bulk RNA finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. 
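+
+    Fields read from ``config`` (values in parentheses are those used in the
+    finetune.sh template): seq_len, model_name, train_batch_size (16),
+    test_batch_size (16), num_workers (6), train_ratio (1.0), eval_ratio (1.0).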
+ """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.train_dataset = self.dataset["train"] + self.val_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "variant_effect_pathogenic_clinvar", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning("Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="bulk_rna_expression", + sequence_length=self.seq_len, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + # Process data: filter sequences with too many 'N's, recast chromosomes, and tokenize + dataset = dataset.filter( + lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + batch_size=1000, + batched=True, + remove_columns=["sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + + # Save processed dataset to disk + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning("Data downloaded and preprocessed successfully.") + +def main_lit(args): + """ + Main function to start the training process using PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. 
+ """ + wandb.login(key=args.wandb_api_key) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}", + project="Bulk RNA Expression", + log_model=True, + save_dir=args.save_dir + ) + data_module = BulkRNADataModule(args) + data_module.setup() + + model = Lit_BulkRNAFinetuning(args) + + # Callbacks for early stopping and model checkpointing + early_stopping_callback = EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min" + ) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename="best-checkpoint", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[early_stopping_callback, checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision="16-mixed", + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) diff --git a/finetuning_glrb/finetune_chromatin.py b/finetuning_glrb/finetune_chromatin.py new file mode 100644 index 0000000..e814dec --- /dev/null +++ b/finetuning_glrb/finetune_chromatin.py @@ -0,0 +1,553 @@ +import os +from functools import partial +from os import path as osp +import re +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig +from datasets import load_dataset, load_from_disk +from sklearn.metrics import accuracy_score, precision_recall_curve, auc, roc_auc_score +from transformers import DefaultDataCollator +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import ( + fsspec_exists, + get_last_embedding_dimension +) + +# Logger setup +log = get_logger(__name__) + +# Constants for the window size in base pairs +WINDOW_SIZE_BP = 200 + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize sequence. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs. + """ + seq_tokenized = tokenizer.batch_encode_plus( + examples["sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + return { + "ref_input_ids": seq_tokenized["input_ids"] + } + +def recast_chromosome(examples): + """ + Recast chromosome to integer format. + + Returns: + dict with chromosome recast as integers. + """ + return { + "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + } + +class MLP_ChromatineFeatures(nn.Module): + """ + MLP model for predicting chromatin features. + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. 
+ """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_ChromatineFeatures, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc2(self.sp1(self.fc1(x))) + +class DNAModelForChromatineFeatures(nn.Module): + """ + DNA Model for Chromatin Features prediction. + + Args: + args: Arguments containing model configurations. + """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + # Load the appropriate backbone model based on the model name + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim found for the Foundation Model: {self.inner_dim}") + self.head = MLP_ChromatineFeatures(input_size=self.inner_dim, hidden_size=2 * self.inner_dim, output_size=20) + + def forward(self, input_ids): + embeds_out = self.backbone(input_ids)[0] + num_channels = embeds_out.size(-1) + window_size = WINDOW_SIZE_BP // self.bp_per_token // 2 + batch_size, seq_len, embedding_dim = embeds_out.shape + + if self.rcps: + embeds = embeds_out[..., :num_channels // 2] + expanded_indices = ( + torch.arange(-window_size, window_size + 1, device=embeds.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + ) + expanded_indices = torch.clamp(expanded_indices, 0, embeds.size(1) - 1) + + # Extract windowed embeddings for reference sequence + tokens_window_ref = torch.gather(embeds, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds.size(2))) + + # Extract windowed embeddings for reverse complement sequence + rc_embeds = embeds_out[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + expanded_indices = torch.arange(-window_size, window_size + 1, device=rc_embeds.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, rc_embeds.size(1) - 1) + + tokens_window_rc = torch.gather(rc_embeds, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds.size(2))) + + # Average the embeddings over the sequence length dimension + tokens_window_rc = tokens_window_rc.mean(dim=1) + tokens_window_ref = tokens_window_ref.mean(dim=1) + + # Combine the Reference and RC-equivalent embeddings + aggregated_embeds = tokens_window_rc + tokens_window_ref + return self.head(aggregated_embeds) + + else: + expanded_indices = torch.arange(-window_size, window_size + 1, device=embeds_out.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, embeds_out.size(1) - 1) + + # Extract windowed embeddings for reference sequence + tokens_window_ref = torch.gather(embeds_out, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_out.size(2))).mean(dim=1) + return self.head(tokens_window_ref) + +class Lit_ChromatinFeatures(pl.LightningModule): + """ + PyTorch Lightning model for predicting chromatin features. + + Args: + args: Arguments containing model and training configurations. 
+ """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.model = DNAModelForChromatineFeatures(args) + self.criterion = nn.BCEWithLogitsLoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + + def forward(self, ref_input_ids): + return self.model(ref_input_ids) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.training_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.training_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.validation_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + return loss + + def on_validation_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for validation + val_accuracy = accuracy_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + + # Log validation metrics + self.log("validation/accuracy", val_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("validation/AUPRC", val_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("validation/AUROC", val_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def on_train_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for training + train_accuracy = accuracy_score(self.training_step_labels, self.training_step_preds) + precision, recall, _ = precision_recall_curve(self.training_step_labels, self.training_step_preds) + train_auprc = auc(recall, precision) + train_auroc = roc_auc_score(self.training_step_labels, self.training_step_preds) + + # Log training metrics + self.log("train/accuracy", train_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("train/AUPRC", train_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("train/AUROC", train_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.training_step_labels.clear() + self.training_step_preds.clear() + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class HistoneMarksDataModule(pl.LightningDataModule): + """ + Data module for Histone Marks finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. 
+ """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.train_dataset = self.dataset["train"] + self.val_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "chromatin_features_histone_marks", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning("Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="chromatin_features_histone_marks", + sequence_length=self.seq_len, + subset=True, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + # Process data: filter sequences with too many 'N's, recast chromosomes, and tokenize + dataset = dataset.filter( + lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + batch_size=1000, + batched=True, + remove_columns=["sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + + # Save processed dataset to disk + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning("Data downloaded and preprocessed successfully.") + +class DNAAccessibilityDataModule(pl.LightningDataModule): + """ + Data module for DNA Accessibility finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. 
+ """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.train_dataset = self.dataset["train"] + self.val_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "chromatin_features_dna_accessibility", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning("Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="chromatin_features_dna_accessibility", + sequence_length=self.seq_len, + subset=True, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + # Process data: filter sequences with too many 'N's, recast chromosomes, and tokenize + dataset = dataset.filter( + lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + batch_size=1000, + batched=True, + remove_columns=["sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + + # Save processed dataset to disk + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning("Data downloaded and preprocessed successfully.") + +def main_histone_marks(args): + """ + Main function to start finetuning on Histone Marks with PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. 
+ """ + wandb.login(key=args.wandb_api_key) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}", + project="Histone Marks", + log_model=True # Automatically log model checkpoints + ) + data_module = HistoneMarksDataModule(args) + data_module.setup() + + model = Lit_ChromatinFeatures(args) + + # Callbacks for early stopping and model checkpointing + early_stopping_callback = EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min" + ) + + checkpoint_callback = ModelCheckpoint( + dirpath="./checkpoints", + filename="best-checkpoint", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[early_stopping_callback, checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision="16-mixed", + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + +def main_dna_accessibility(args): + """ + Main function to start finetunning on DNA Accessibility with PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. + """ + wandb.login(key=args.wandb_api_key) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}", + project="DNA Accessibility", + log_model=True # Automatically log model checkpoints + ) + data_module = DNAAccessibilityDataModule(args) + data_module.setup() + + model = Lit_ChromatinFeatures(args) + + # Callbacks for early stopping and model checkpointing + early_stopping_callback = EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min" + ) + + checkpoint_callback = ModelCheckpoint( + dirpath="./checkpoints", + filename="best-checkpoint", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[early_stopping_callback, checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision="16-mixed", + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) diff --git a/finetuning_glrb/finetune_regulatory_elements.py b/finetuning_glrb/finetune_regulatory_elements.py new file mode 100644 index 0000000..0062117 --- /dev/null +++ b/finetuning_glrb/finetune_regulatory_elements.py @@ -0,0 +1,554 @@ +import os +from functools import partial +from os import path as osp +import re +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from datasets import load_dataset, load_from_disk +from sklearn.metrics import accuracy_score, 
precision_recall_curve, auc, roc_auc_score +from transformers import DefaultDataCollator +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension + +# Logger setup +log = get_logger(__name__) + +# Constants for the window size in base pairs +WINDOW_SIZE_BP = 200 + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize sequence. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs. + """ + seq_tokenized = tokenizer.batch_encode_plus( + examples["sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + return { + "ref_input_ids": seq_tokenized["input_ids"] + } + +def recast_chromosome(examples): + """ + Recast chromosome to integer format. + + Returns: + dict with chromosome recast as integers. + """ + return { + "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + } + +class MLP_RegulatoryElements(nn.Module): + """ + MLP model for predicting regulatory elements. + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. + """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_RegulatoryElements, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc2(self.sp1(self.fc1(x))) + +class DNAModelForRegulatoryElements(nn.Module): + """ + DNA Model for Regulatory Elements prediction. + + Args: + args: Arguments containing model configurations. 
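+
+    Mean-pools the backbone embeddings over a WINDOW_SIZE_BP (200 bp) window
+    centred on the middle of the sequence and feeds the pooled vector to a
+    single-logit head; with bp_per_token=1 the window spans token indices
+    seq_len//2 - 100 ... seq_len//2 + 100 (window_size = 200 // 1 // 2 = 100).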
+ """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + # Load the appropriate backbone model based on the model name + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim found for the Foundation Model: {self.inner_dim}") + self.head = MLP_RegulatoryElements(input_size=self.inner_dim, hidden_size=2 * self.inner_dim, output_size=1) + + def forward(self, input_ids): + # Get embeddings for the alternate and reference sequences + embeds_out = self.backbone(input_ids)[0] + num_channels = embeds_out.size(-1) + window_size = WINDOW_SIZE_BP // self.bp_per_token // 2 + batch_size, seq_len, embedding_dim = embeds_out.shape + + if self.rcps: + # Handle reverse complement processing + ref_embeds = embeds_out[..., :num_channels // 2] + rc_embeds = embeds_out[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + + expanded_indices = torch.arange(-window_size, window_size + 1, device=ref_embeds.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, ref_embeds.size(1) - 1) + + # Extract windowed embeddings for the reference sequence + tokens_window_ref = torch.gather(ref_embeds, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, ref_embeds.size(2))).mean(dim=1) + + expanded_indices = torch.arange(-window_size, window_size + 1, device=rc_embeds.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, rc_embeds.size(1) - 1) + + # Extract windowed embeddings for the reverse complement sequence + tokens_window_rc = torch.gather(rc_embeds, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds.size(2))).mean(dim=1) + + # Combine the reference and reverse complement embeddings + aggregated_embeds = tokens_window_rc + tokens_window_ref + return self.head(aggregated_embeds) + + else: + # Handle non-reverse complement processing + expanded_indices = torch.arange(-window_size, window_size + 1, device=embeds_out.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, embeds_out.size(1) - 1) + + tokens_window_ref = torch.gather(embeds_out, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_out.size(2))).mean(dim=1) + return self.head(tokens_window_ref) + +class Lit_RegulatoryElements(pl.LightningModule): + """ + PyTorch Lightning model for predicting regulatory elements. + + Args: + args: Arguments containing model and training configurations. 
+ """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.model = DNAModelForRegulatoryElements(args) + self.task = args.task + self.criterion = nn.BCEWithLogitsLoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + + def forward(self, ref_input_ids): + return self.model(ref_input_ids) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids).squeeze(-1) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.training_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.training_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids).squeeze(-1) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.validation_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + return loss + + def on_validation_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for validation + val_accuracy = accuracy_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + + if self.task == "regulatory_element_enhancer": + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + + # Log validation metrics + self.log("validation/accuracy", val_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("validation/AUPRC", val_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + if self.task == "regulatory_element_enhancer": + self.log("validation/AUROC", val_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def on_train_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for training + train_accuracy = accuracy_score(self.training_step_labels, self.training_step_preds) + precision, recall, _ = precision_recall_curve(self.training_step_labels, self.training_step_preds) + train_auprc = auc(recall, precision) + + if self.task == "regulatory_element_enhancer": + train_auroc = roc_auc_score(self.training_step_labels, self.training_step_preds) + + # Log training metrics + self.log("train/accuracy", train_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("train/AUPRC", train_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + if self.task == "regulatory_element_enhancer": + self.log("train/AUROC", train_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.training_step_labels.clear() + self.training_step_preds.clear() + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class 
PromoterDataModule(pl.LightningDataModule): + """ + Data module for Promoter regulatory element finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. + """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.train_dataset = self.dataset["train"] + self.val_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "regulatory_element_promoter", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning("Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="regulatory_element_promoter", + sequence_length=self.seq_len, + subset=True, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + # Process data: filter sequences with too many 'N's, recast chromosomes, and tokenize + dataset = dataset.filter( + lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + batch_size=1000, + batched=True, + remove_columns=["sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + + # Save processed dataset to disk + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning("Data downloaded and preprocessed successfully.") + +class EnhancerDataModule(pl.LightningDataModule): + """ + Data module for Enhancer regulatory element finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. 
+ """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.train_dataset = self.dataset["train"] + self.val_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "regulatory_element_enhancer", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning("Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="regulatory_element_enhancer", + sequence_length=self.seq_len, + subset=True, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + # Process data: filter sequences with too many 'N's, recast chromosomes, and tokenize + dataset = dataset.filter( + lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + batch_size=1000, + batched=True, + remove_columns=["sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + + # Save processed dataset to disk + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning("Data downloaded and preprocessed successfully.") + +def main_promoter(args): + """ + Main function to start training on Promoter regulatory elements with PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. 
+ """ + wandb.login(key=args.wandb_api_key) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}", + project="Regulatory Element Promoter", + log_model=True # Automatically log model checkpoints + ) + data_module = PromoterDataModule(args) + data_module.setup() + + model = Lit_RegulatoryElements(args) + + # Callbacks for early stopping and model checkpointing + early_stopping_callback = EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min" + ) + + checkpoint_callback = ModelCheckpoint( + dirpath="./checkpoints", + filename="best-checkpoint", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[early_stopping_callback, checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision="16-mixed", + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + +def main_enhancer(args): + """ + Main function to start training on Enhancer regulatory elements with PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. + """ + wandb.login(key=args.wandb_api_key) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}", + project="Regulatory Elements Enhancer", + log_model=True # Automatically log model checkpoints + ) + data_module = EnhancerDataModule(args) + data_module.setup() + + model = Lit_RegulatoryElements(args) + + # Callbacks for early stopping and model checkpointing + early_stopping_callback = EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min" + ) + + checkpoint_callback = ModelCheckpoint( + dirpath="./checkpoints", + filename="best-checkpoint", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[early_stopping_callback, checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision="16-mixed", + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) diff --git a/finetuning_glrb/finetune_variant_effect_OMIM.py b/finetuning_glrb/finetune_variant_effect_OMIM.py new file mode 100644 index 0000000..18e611f --- /dev/null +++ b/finetuning_glrb/finetune_variant_effect_OMIM.py @@ -0,0 +1,430 @@ +import os +from functools import partial +from os import path as osp +import re +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from datasets import load_dataset, load_from_disk +from sklearn.metrics import roc_auc_score, 
precision_recall_curve, auc +from transformers import DefaultDataCollator +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension + +# Logger setup +log = get_logger(__name__) + +# Constants +WINDOW_SIZE_BP = 1536 + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize reference and alternate sequences. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs for reference and alternate sequences. + """ + ref_tokenized = tokenizer.batch_encode_plus( + examples["ref_forward_sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + alt_tokenized = tokenizer.batch_encode_plus( + examples["alt_forward_sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + return { + "ref_input_ids": ref_tokenized["input_ids"], + "alt_input_ids": alt_tokenized["input_ids"] + } + +def recast_chromosome(examples): + """ + Recast chromosome to integer format. + + Returns: + dict with recast chromosome as an integer. + """ + return { + "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + } + +def find_variant_idx(examples): + """ + Find the index of the variant in the sequence. + + Args: + examples: Items from the dataset (not batched). + + Returns: + dict with the index of the variant. + """ + idx = len(examples["ref_input_ids"]) // 2 # Assume variant is at the midpoint + if examples["ref_input_ids"][idx] == examples["alt_input_ids"][idx]: + idx = -1 + for i, (ref, alt) in enumerate(zip(examples["ref_input_ids"], examples["alt_input_ids"])): + if ref != alt: + idx = i + return {"variant_idx": idx} + +class MLP_VEP_OMIM(nn.Module): + """ + MLP head for Variant Effect Prediction (OMIM). + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. + """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_VEP_OMIM, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, hidden_size) + self.sp2 = nn.Softplus() + self.fc3 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc3(self.sp2(self.fc2(self.sp1(self.fc1(x))))) + +class DNAModelForOMIMFinetuning(nn.Module): + """ + DNA Model for OMIM Variant Effect Prediction fine-tuning. + + Args: + args: Arguments containing model configurations. 
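+
+    Embeds the reference and alternate sequences separately, mean-pools a
+    WINDOW_SIZE_BP (1536 bp) window centred on variant_idx in each, and
+    concatenates the two pooled vectors (2 * inner_dim) before the 2-class head.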
+ """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim found for the Foundation Model: {self.inner_dim}") + self.head = MLP_VEP_OMIM(input_size=2*self.inner_dim, hidden_size=2*self.inner_dim, output_size=2) + + def forward(self, alt_input_ids, ref_input_ids, variant_idx): + embeds_alternate = self.backbone(alt_input_ids)[0] + embeds_reference = self.backbone(ref_input_ids)[0] + window_size = WINDOW_SIZE_BP // self.bp_per_token // 2 + num_channels = embeds_alternate.size(-1) + + if self.rcps: + embeds_alt = embeds_alternate[..., :num_channels // 2] + embeds_ref = embeds_reference[..., :num_channels // 2] + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alt.size(1) - 1) + + windowed_embeds_alt = torch.gather(embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alt.size(2))) + windowed_embeds_ref = torch.gather(embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_ref.size(2))) + + mean_embeds_alt = windowed_embeds_alt.mean(dim=1) + mean_embeds_ref = windowed_embeds_ref.mean(dim=1) + + concat_embeds = torch.cat([mean_embeds_alt, mean_embeds_ref], dim=-1) + + rc_embeds_alt = embeds_alternate[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + rc_embeds_ref = embeds_reference[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + + rc_windowed_embeds_alt = torch.gather(rc_embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_alt.size(2))) + rc_windowed_embeds_ref = torch.gather(rc_embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_ref.size(2))) + + rc_mean_embeds_alt = rc_windowed_embeds_alt.mean(dim=1) + rc_mean_embeds_ref = rc_windowed_embeds_ref.mean(dim=1) + + rc_concat_embeds = torch.cat([rc_mean_embeds_alt, rc_mean_embeds_ref], dim=-1) + + final_window = concat_embeds + rc_concat_embeds + return self.head(final_window) + + else: + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alternate.size(1) - 1) + + windowed_embeds_alt = torch.gather(embeds_alternate, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alternate.size(2))).mean(dim=1) + windowed_embeds_ref = torch.gather(embeds_reference, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_reference.size(2))).mean(dim=1) + + concat_embeds = torch.cat([windowed_embeds_alt, windowed_embeds_ref], dim=-1) + return self.head(concat_embeds) + +class Lit_OMIMFinetuning(pl.LightningModule): + """ + PyTorch Lightning model for fine-tuning on OMIM Variant Effect Prediction. + + Args: + args: Arguments containing model and training configurations. 
+ """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.model = DNAModelForOMIMFinetuning(args) + self.criterion = nn.CrossEntropyLoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + self.training_step_correct = 0 + self.training_step_total = 0 + self.validation_step_correct = 0 + self.validation_step_total = 0 + + def forward(self, alt_input_ids, ref_input_ids, variant_idx): + return self.model(alt_input_ids, ref_input_ids, variant_idx) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + preds = torch.argmax(logits, dim=1) + correct = (preds == labels).sum().item() + self.training_step_correct += correct + self.training_step_total += len(labels) + + all_labels = labels.cpu().numpy() + all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() + self.training_step_preds.extend(all_predictions) + self.training_step_labels.extend(all_labels) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=False, sync_dist=True) + + preds = torch.argmax(logits, dim=1) + correct = (preds == labels).sum().item() + self.validation_step_correct += correct + self.validation_step_total += len(labels) + + all_labels = labels.cpu().numpy() + all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() + self.validation_step_preds.extend(all_predictions) + self.validation_step_labels.extend(all_labels) + + def on_validation_epoch_end(self): + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_accuracy = self.validation_step_correct / self.validation_step_total + + self.logger.experiment.log({ + "validation/AUROC": val_auroc, + "validation/Accuracy": val_accuracy, + "validation/AUPRC": val_auprc, + }) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + self.validation_step_correct = 0 + self.validation_step_total = 0 + + def on_train_epoch_end(self): + train_auroc = roc_auc_score(self.training_step_labels, self.training_step_preds) + precision, recall, _ = precision_recall_curve(self.training_step_labels, self.training_step_preds) + train_auprc = auc(recall, precision) + train_accuracy = self.training_step_correct / self.training_step_total + + self.logger.experiment.log({ + "train/AUROC": train_auroc, + "train/Accuracy": train_accuracy, + "train/AUPRC": train_auprc + }) + + self.training_step_labels.clear() + self.training_step_preds.clear() + self.training_step_correct = 0 + self.training_step_total = 0 + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class VariantEffectPredictionDataModule(pl.LightningDataModule): + """ + Data module for OMIM Variant 
Effect Prediction fine-tuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. + """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + self.train_dataset = self.dataset["train"] + self.val_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "variant_effect_pathogenic_omim", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning(f"Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="variant_effect_pathogenic_omim", + sequence_length=self.seq_len, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] + except KeyError: + pass + + dataset = dataset.filter( + lambda example: example["ref_forward_sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + batch_size=1000, + batched=True, + remove_columns=["ref_forward_sequence", "alt_forward_sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + dataset = dataset.map(find_variant_idx, desc="Find variant idx") + + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning(f"Data downloaded and preprocessed successfully.") + +def main_lit(args): + """ + Main function to start the training process for OMIM Variant Effect Prediction using PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. 
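+
+    Sets up W&B logging, early stopping and checkpointing on `val_loss`, and
+    runs the Lightning Trainer with 16-bit mixed precision, gradient clipping
+    at 1.0 and gradient accumulation.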
+ """ + wandb.login(key=args.wandb_api_key) + wandb_logger = WandbLogger( + name=f"{args.name}-{args.seq_len}", + project="Variant Effect Prediction OMIM", + log_model=True + ) + data_module = VariantEffectPredictionDataModule(args) + data_module.setup() + + model = Lit_OMIMFinetuning(args) + + early_stopping_callback = EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min" + ) + + checkpoint_callback = ModelCheckpoint( + dirpath="./checkpoints", + filename="best-checkpoint", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[early_stopping_callback, checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision="16-mixed", + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + trainer.fit(model, data_module) diff --git a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py new file mode 100644 index 0000000..ff55590 --- /dev/null +++ b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py @@ -0,0 +1,444 @@ +import os +from functools import partial +from os import path as osp +import re +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset +from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +import numpy as np +from datasets import load_dataset, load_from_disk +from sklearn import preprocessing +from sklearn.metrics import accuracy_score, precision_recall_curve, auc, roc_auc_score +from transformers import DefaultDataCollator +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension + +# Constants +WINDOW_SIZE_BP = 1536 +DIST_TO_TSS = [[0, 30_000], [30_000, 100_000], [100_000, np.inf]] + +# Logger setup +log = get_logger(__name__) + +def recast_chromosome_tissue_dist2TSS(examples): + """ + Recast chromosome to integer and retain tissue and distance to nearest TSS. + + Returns: + dict with recast chromosome, tissue, and distance to TSS. + """ + return { + "chromosome": -1 if examples["chromosome"] == "X" else int(examples["chromosome"]), + "tissue": examples["tissue"], + "distance_to_nearest_tss": examples["distance_to_nearest_tss"] + } + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize sequences. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs for reference and alternate sequences. 
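+
+    Note: sequences longer than `max_length` are truncated; no special tokens
+    or attention masks are added.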
+ """ + ref_tokenized = tokenizer.batch_encode_plus( + examples["ref_forward_sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + alt_tokenized = tokenizer.batch_encode_plus( + examples["alt_forward_sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + return { + "ref_input_ids": ref_tokenized["input_ids"], + "alt_input_ids": alt_tokenized["input_ids"] + } + +def find_variant_idx(examples): + """ + Find the index of the variant in the sequence. + + Args: + examples: Items from the dataset (not batched). + + Returns: + dict with the index of the variant. + """ + idx = len(examples["ref_input_ids"]) // 2 # Assume variant is at the midpoint + if examples["ref_input_ids"][idx] == examples["alt_input_ids"][idx]: + idx = -1 + for i, (ref, alt) in enumerate(zip(examples["ref_input_ids"], examples["alt_input_ids"])): + if ref != alt: + idx = i + return {"variant_idx": idx} + +def dataset_tss_filter(data: Dataset, min_distance: int, max_distance: int): + """ + Filter the data based on the distance to the nearest TSS. + + Args: + data: Dataset to be filtered. + min_distance: Minimum distance to the TSS. + max_distance: Maximum distance to the TSS. + + Returns: + Filtered dataset. + """ + distance_mask = (data["distance_to_nearest_tss"] >= min_distance) & (data["distance_to_nearest_tss"] <= max_distance) + filtered_data = {key: value[distance_mask] for key, value in data.items()} + return filtered_data + +class MLP_VEP(nn.Module): + """ + MLP head for Variant Effect Prediction (VEP). + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. + """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_VEP, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, hidden_size) + self.sp2 = nn.Softplus() + self.fc3 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc3(self.sp2(self.fc2(self.sp1(self.fc1(x))))) + +class DNAModelForVEPFinetuning(nn.Module): + """ + DNA Model for Variant Effect Prediction fine-tuning. + + Args: + args: Arguments containing model configurations. 
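+
+    Note: as in the other variant-effect heads, embeddings are mean-pooled over
+    a window of WINDOW_SIZE_BP // bp_per_token tokens around the variant for
+    the reference and alternate sequences. The integer-encoded tissue id is
+    appended as one extra scalar feature, which is why the head input size is
+    2 * inner_dim + 1.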
+ """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim found for the Foundation Model: {self.inner_dim}") + self.head = MLP_VEP(input_size=2 * self.inner_dim + 1, hidden_size=2 * self.inner_dim, output_size=2) + + def forward(self, alt_input_ids, ref_input_ids, variant_idx, tissue_embed): + # Get embeddings for the alternate and reference sequences + embeds_alternate = self.backbone(alt_input_ids)[0] + embeds_reference = self.backbone(ref_input_ids)[0] + window_size = WINDOW_SIZE_BP // self.bp_per_token // 2 + num_channels = embeds_alternate.size(-1) + + if self.rcps: + # Reverse complement processing + embeds_alt = embeds_alternate[..., :num_channels // 2] + embeds_ref = embeds_reference[..., :num_channels // 2] + + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alt.size(1) - 1) + + # Extract windowed embeddings + windowed_embeds_alt = torch.gather(embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alt.size(2))) + windowed_embeds_ref = torch.gather(embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_ref.size(2))) + + mean_embeds_alt = windowed_embeds_alt.mean(dim=1) + mean_embeds_ref = windowed_embeds_ref.mean(dim=1) + + # Concatenate the embeddings + concat_embeds = torch.cat([mean_embeds_alt, mean_embeds_ref, tissue_embed[..., None]], dim=-1) + + #Same for the RC-Equivalent part + rc_embeds_alt = embeds_alternate[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + rc_embeds_ref = embeds_reference[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + + rc_windowed_embeds_alt = torch.gather(rc_embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_alt.size(2))) + rc_windowed_embeds_ref = torch.gather(rc_embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_ref.size(2))) + + rc_mean_embeds_alt = rc_windowed_embeds_alt.mean(dim=1) + rc_mean_embeds_ref = rc_windowed_embeds_ref.mean(dim=1) + + rc_concat_embeds = torch.cat([rc_mean_embeds_alt, rc_mean_embeds_ref, tissue_embed[..., None]], dim=-1) + + final_window = concat_embeds + rc_concat_embeds + return self.head(final_window) + + else: + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alternate.size(1) - 1) + + windowed_embeds_alt = torch.gather(embeds_alternate, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alternate.size(2))).mean(dim=1) + windowed_embeds_ref = torch.gather(embeds_reference, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_reference.size(2))).mean(dim=1) + + concat_embeds = torch.cat([windowed_embeds_alt, windowed_embeds_ref, tissue_embed[..., None]], dim=-1) + return self.head(concat_embeds) + +class LitVEPFinetuning(pl.LightningModule): + """ + PyTorch Lightning model for fine-tuning on Variant Effect Prediction. 
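+    Validation AUROC, AUPRC and accuracy are reported separately for each
+    distance-to-TSS bucket defined in DIST_TO_TSS (0-30 kb, 30-100 kb, >100 kb).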
+ + Args: + args: Arguments containing model and training configurations. + """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.model = DNAModelForVEPFinetuning(args) + self.criterion = nn.CrossEntropyLoss() + self.validation_step_preds = {i: [] for i in range(len(DIST_TO_TSS))} + self.validation_step_labels = {i: [] for i in range(len(DIST_TO_TSS))} + + def forward(self, alt_input_ids, ref_input_ids, variant_idx, tissue_embed): + return self.model(alt_input_ids, ref_input_ids, variant_idx, tissue_embed) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + tissue_embed = batch["tissue_embed"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index, tissue_embed) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + tissue_embed = batch["tissue_embed"] + labels = batch["labels"] + distance_to_nearest_tss = batch["distance_to_nearest_tss"] + + logits = self(alt_input_ids, ref_input_ids, variant_index, tissue_embed) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=False, sync_dist=True) + + # Predictions for AUROC + preds = torch.argmax(logits, dim=1).detach().cpu().numpy() + labels_np = labels.cpu().numpy() + + for i, (min_dist, max_dist) in enumerate(DIST_TO_TSS): + mask = ((distance_to_nearest_tss >= min_dist) & (distance_to_nearest_tss < max_dist)).cpu().numpy() + filtered_preds = preds[mask] + filtered_labels = labels_np[mask] + + if len(filtered_labels) > 0: + self.validation_step_labels[i].extend(filtered_labels) + self.validation_step_preds[i].extend(filtered_preds) + + def on_validation_epoch_end(self): + for i, (min_dist, max_dist) in enumerate(DIST_TO_TSS): + if len(self.validation_step_labels[i]) > 0: + val_auroc = roc_auc_score(self.validation_step_labels[i], self.validation_step_preds[i]) + precision, recall, _ = precision_recall_curve(self.validation_step_labels[i], self.validation_step_preds[i]) + val_auprc = auc(recall, precision) + val_accuracy = accuracy_score(self.validation_step_labels[i], self.validation_step_preds[i]) + + # Log metrics for each TSS distance bucket + self.log(f'validation/TSS({min_dist}-{max_dist})/AUROC', val_auroc, on_epoch=True, sync_dist=True) + self.log(f'validation/TSS({min_dist}-{max_dist})/Accuracy', val_accuracy, on_epoch=True, sync_dist=True) + self.log(f'validation/TSS({min_dist}-{max_dist})/AUPRC', val_auprc, on_epoch=True, sync_dist=True) + print(f'Bucket {i} [{min_dist}-{max_dist}] - AUROC: {val_auroc:.4f}') + + self.validation_step_labels[i].clear() + self.validation_step_preds[i].clear() + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class VariantEffectPredictionDataModule(pl.LightningDataModule): + """ + Data module for Variant Effect Prediction finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. 
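+
+    Downloads the GLRB `variant_effect_causal_eqtl` task, filters out sequences
+    with too many N's, tokenizes the reference and alternate sequences,
+    integer-encodes the tissue label with sklearn's LabelEncoder, and caches
+    the preprocessed dataset under ./data/ for reuse.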
+ """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.train_dataset = self.dataset["train"] + self.val_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _get_preprocessed_cache_file(self): + cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "variant_effect_causal_eqtl", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning(f"Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="variant_effect_causal_eqtl", + sequence_length=self.seq_len, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + dataset = dataset.filter( + lambda example: example["ref_forward_sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome_tissue_dist2TSS, + remove_columns=["chromosome", "tissue", "distance_to_nearest_tss"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + batch_size=1000, + batched=True, + remove_columns=["ref_forward_sequence", "alt_forward_sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + dataset = dataset.map(find_variant_idx, desc="Find variant idx") + + label_encoder = preprocessing.LabelEncoder() + label_encoder.fit(dataset["test"]["tissue"]) + train_tissue_embed = label_encoder.transform(dataset["train"]["tissue"]) + dataset["train"] = dataset["train"].add_column("tissue_embed", train_tissue_embed) + test_tissue_embed = label_encoder.transform(dataset["test"]["tissue"]) + dataset["test"] = dataset["test"].add_column("tissue_embed", test_tissue_embed) + + # Save to disk if running locally + dataset.save_to_disk(self._get_preprocessed_cache_file()) + + log.warning(f"Data downloaded and preprocessed successfully.") + +def main_lit(args): + """ + 
Main function to start process for Variant Effect Prediction Finetuning eQTL using PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. + """ + wandb.login(key=args.wandb_api_key) + wandb_logger = WandbLogger( + name=f"{args.name}-{args.seq_len}", + project="Variant Effect Prediction Causal eQTL", + log_model=True # Automatically log model checkpoints + ) + data_module = VariantEffectPredictionDataModule(args) + data_module.setup() + + model = LitVEPFinetuning(args) + + early_stopping_callback = EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min" + ) + + checkpoint_callback = ModelCheckpoint( + dirpath="./checkpoints", + filename="best-checkpoint", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[early_stopping_callback, checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision="16-mixed", + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) diff --git a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py new file mode 100644 index 0000000..c5b1c72 --- /dev/null +++ b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py @@ -0,0 +1,430 @@ +import os +from functools import partial +from os import path as osp +import re +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +import numpy as np +from datasets import load_dataset, load_from_disk +from sklearn import preprocessing +from sklearn.metrics import precision_recall_curve, auc, roc_auc_score +from transformers import DefaultDataCollator +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension + +# Logger setup +log = get_logger(__name__) + +# Constants +WINDOW_SIZE_BP = 1536 + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize reference and alternate sequences. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs for reference and alternate sequences. + """ + ref_tokenized = tokenizer.batch_encode_plus( + examples["ref_forward_sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + alt_tokenized = tokenizer.batch_encode_plus( + examples["alt_forward_sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + return { + "ref_input_ids": ref_tokenized["input_ids"], + "alt_input_ids": alt_tokenized["input_ids"] + } + +def recast_chromosome(examples): + """ + Recast chromosome to integer format. 
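+    "X" is mapped to -1 and "Y" to -2; autosomes keep their numeric value.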
+ + Returns: + dict with recast chromosome as an integer. + """ + return { + "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + } + +def find_variant_idx(examples): + """ + Find the index of the variant in the sequence. + + Args: + examples: Items from the dataset (not batched). + + Returns: + dict with the index of the variant. + """ + idx = len(examples["ref_input_ids"]) // 2 # Assume variant is at the midpoint + if examples["ref_input_ids"][idx] == examples["alt_input_ids"][idx]: + idx = -1 + for i, (ref, alt) in enumerate(zip(examples["ref_input_ids"], examples["alt_input_ids"])): + if ref != alt: + idx = i + return {"variant_idx": idx} + +class MLP_VEP_ClinVar(nn.Module): + """ + MLP head for Variant Effect Prediction (ClinVar). + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. + """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_VEP_ClinVar, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc2(self.sp1(self.fc1(x))) + +class DNAModelForVEPFinetuning(nn.Module): + """ + DNA Model for Variant Effect Prediction (ClinVar) fine-tuning. + + Args: + args: Arguments containing model configurations. + """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim found for the Foundation Model: {self.inner_dim}") + self.head = MLP_VEP_ClinVar(input_size=2*self.inner_dim, hidden_size=2*self.inner_dim, output_size=2) + + def forward(self, alt_input_ids, ref_input_ids, variant_idx): + embeds_alternate = self.backbone(alt_input_ids)[0] + embeds_reference = self.backbone(ref_input_ids)[0] + window_size = WINDOW_SIZE_BP // self.bp_per_token // 2 + num_channels = embeds_alternate.size(-1) + + if self.rcps: + embeds_alt = embeds_alternate[..., :num_channels // 2] + embeds_ref = embeds_reference[..., :num_channels // 2] + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alt.size(1) - 1) + + windowed_embeds_alt = torch.gather(embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alt.size(2))) + windowed_embeds_ref = torch.gather(embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_ref.size(2))) + + mean_embeds_alt = windowed_embeds_alt.mean(dim=1) + mean_embeds_ref = windowed_embeds_ref.mean(dim=1) + + concat_embeds = torch.cat([mean_embeds_alt, mean_embeds_ref], dim=-1) + + rc_embeds_alt = embeds_alternate[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + rc_embeds_ref = embeds_reference[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + + rc_windowed_embeds_alt = torch.gather(rc_embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_alt.size(2))) + 
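+            # Same variant-centred window, gathered from the reverse-complement
+            # reference channels (flipped along sequence and channel dims above).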
rc_windowed_embeds_ref = torch.gather(rc_embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_ref.size(2))) + + rc_mean_embeds_alt = rc_windowed_embeds_alt.mean(dim=1) + rc_mean_embeds_ref = rc_windowed_embeds_ref.mean(dim=1) + + rc_concat_embeds = torch.cat([rc_mean_embeds_alt, rc_mean_embeds_ref], dim=-1) + + final_window = concat_embeds + rc_concat_embeds + return self.head(final_window) + + else: + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alternate.size(1) - 1) + + windowed_embeds_alt = torch.gather(embeds_alternate, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alternate.size(2))).mean(dim=1) + windowed_embeds_ref = torch.gather(embeds_reference, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_reference.size(2))).mean(dim=1) + + concat_embeds = torch.cat([windowed_embeds_alt, windowed_embeds_ref], dim=-1) + return self.head(concat_embeds) + +class Lit_ClinVarFinetuning(pl.LightningModule): + """ + PyTorch Lightning model for fine-tuning on ClinVar Variant Effect Prediction. + + Args: + args: Arguments containing model and training configurations. + """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.model = DNAModelForVEPFinetuning(args) + self.criterion = nn.CrossEntropyLoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + self.training_step_correct = 0 + self.training_step_total = 0 + self.validation_step_correct = 0 + self.validation_step_total = 0 + + def forward(self, alt_input_ids, ref_input_ids, variant_idx): + return self.model(alt_input_ids, ref_input_ids, variant_idx) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + preds = torch.argmax(logits, dim=1) + correct = (preds == labels).sum().item() + self.training_step_correct += correct + self.training_step_total += len(labels) + + all_labels = labels.cpu().numpy() + all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() + self.training_step_preds.extend(all_predictions) + self.training_step_labels.extend(all_labels) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=False, sync_dist=True) + + preds = torch.argmax(logits, dim=1) + correct = (preds == labels).sum().item() + self.validation_step_correct += correct + self.validation_step_total += len(labels) + + all_labels = labels.cpu().numpy() + all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() + self.validation_step_preds.extend(all_predictions) + self.validation_step_labels.extend(all_labels) + + def on_validation_epoch_end(self): + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = 
precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_accuracy = self.validation_step_correct / self.validation_step_total + + self.logger.experiment.log({ + "validation/AUROC": val_auroc, + "validation/Accuracy": val_accuracy, + "validation/AUPRC": val_auprc, + }) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + self.validation_step_correct = 0 + self.validation_step_total = 0 + + def on_train_epoch_end(self): + train_auroc = roc_auc_score(self.training_step_labels, self.training_step_preds) + precision, recall, _ = precision_recall_curve(self.training_step_labels, self.training_step_preds) + train_auprc = auc(recall, precision) + train_accuracy = self.training_step_correct / self.training_step_total + + self.logger.experiment.log({ + "train/AUROC": train_auroc, + "train/Accuracy": train_accuracy, + "train/AUPRC": train_auprc + }) + + self.training_step_labels.clear() + self.training_step_preds.clear() + self.training_step_correct = 0 + self.training_step_total = 0 + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class VariantEffectPredictionDataModule(pl.LightningDataModule): + """ + Data module for ClinVar Variant Effect Prediction fine-tuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. + """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + self.train_dataset = self.dataset["train"] + self.val_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "variant_effect_pathogenic_clinvar", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning(f"Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="variant_effect_pathogenic_clinvar", + sequence_length=self.seq_len, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] + except KeyError: + 
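+            # Some GLRB configurations do not ship a "validation" split, so
+            # there is nothing to remove in that case.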
pass + + dataset = dataset.filter( + lambda example: example["ref_forward_sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + batch_size=1000, + batched=True, + remove_columns=["ref_forward_sequence", "alt_forward_sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + dataset = dataset.map(find_variant_idx, desc="Find variant idx") + + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning(f"Data downloaded and preprocessed successfully.") + +def main_lit(args): + """ + Main function to start the training process for ClinVar Variant Effect Prediction using PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. + """ + wandb.login(key=args.wandb_api_key) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}", + project="Variant Effect Prediction ClinVar", + log_model=True + ) + data_module = VariantEffectPredictionDataModule(args) + data_module.setup() + + model = Lit_ClinVarFinetuning(args) + + early_stopping_callback = EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min" + ) + + checkpoint_callback = ModelCheckpoint( + dirpath="./checkpoints", + filename="best-checkpoint", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[early_stopping_callback, checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision="16-mixed", + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + trainer.fit(model, data_module) diff --git a/finetuning_glrb/main.py b/finetuning_glrb/main.py new file mode 100644 index 0000000..60ec486 --- /dev/null +++ b/finetuning_glrb/main.py @@ -0,0 +1,99 @@ + +import os +import sys +# Add the parent directory to sys.path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from os import path as osp +import argparse +import torch + + +from src.utils.train import get_logger +from finetuning_glrb.finetune_variant_effect_pathogenic_clinvar import main_lit as finetune_vep_clinvar +#from finetuning.finetune_caduceus_vep import main as finetune_vep_eqtl +from finetuning_glrb.finetune_variant_effect_OMIM import main_lit as main_omim +from finetuning_glrb.finetune_variant_effect_causal_eqtl import main_lit as finetune_vep_eqtl +from finetuning_glrb.finetune_bulk_rna import main_lit as finetune_bulk_rna_expression +from finetuning_glrb.finetune_chromatin import main_histone_marks,main_dna_accessibility +from finetuning_glrb.finetune_regulatory_elements import main_enhancer, main_promoter + +log = get_logger(__name__) + +def main(opts): + # Check if the value of args.task matches one of the predefined options + if opts.task == "variant_effect_causal_eqtl": + finetune_vep_eqtl(opts) + # Perform operations specific to this task + elif opts.task == "variant_effect_pathogenic_clinvar": + finetune_vep_clinvar(opts) + elif opts.task == "variant_effect_pathogenic_omim": + main_omim(opts) + elif 
opts.task == "bulk_rna_expression": + finetune_bulk_rna_expression(opts) + elif opts.task == "cage_prediction": + raise ValueError("Invalid task selected or not implemented yet.") + elif opts.task == "chromatin_features_histone_marks": + main_histone_marks(opts) + elif opts.task == "chromatin_features_dna_accessibility": + main_dna_accessibility(opts) + elif opts.task == "regulatory_element_promoter": + main_promoter(opts) + elif opts.task == "regulatory_element_enhancer": + main_enhancer(opts) + + + + +if __name__ == "__main__": + torch.multiprocessing.set_sharing_strategy('file_system') + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--task", + type=str, + choices=[ + "variant_effect_causal_eqtl", + "variant_effect_pathogenic_clinvar", + "variant_effect_pathogenic_omim", + "cage_prediction", + "bulk_rna_expression", + "chromatin_features_histone_marks", + "chromatin_features_dna_accessibility", + "regulatory_element_promoter", + "regulatory_element_enhancer" + ], + required=True, + help="Choose one of the predefined variant effects." + ) + parser.add_argument("--seq_len", type=int, default=131072, + help="Sequence length (in bp)..") + parser.add_argument("--model_name", type=str, required=True, help="Name of the pre-trained model to fine-tune") + parser.add_argument("--bp_per_token", type = int, default = 1, help = "Number of baise pairs per token.") + parser.add_argument("--save_dir", type=str, default="./outputs/downstream/vep_embeddings", + help="Directory to save downstream task.") + parser.add_argument("--wandb_api_key",type=str,default=None,help="Weights & Biases API key for logging.") + parser.add_argument("--name_wb", type=str, default=None, help="Embeddings model name.") + parser.add_argument("--train_batch_size", type=int, default=16, help="Batch size for training.") + parser.add_argument("--test_batch_size", type=int, default=16, help="Batch size for testing/validation.") + parser.add_argument("--num_workers", type=int, default=4, help="Number of workers.") + parser.add_argument("--preprocessed_dataset_path", type=str, default=None, help="Path to preprocessed dataset.") + parser.add_argument("--rcps", type=bool, default=False, help="Using rcps when extracting embeddings or not.") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train.") + parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Accumulate gradients") + parser.add_argument("--learning_rate", type=float, default=1e-4, + help="Learning rate for optimizer") + parser.add_argument("--patience", type=int, default=10, + help="Number of epochs with no improvement after which training will be stopped") + parser.add_argument("--log_interval", type=int, default=5, + help="Log interval") + parser.add_argument("--train_ratio", type=float, default=1.0, + help="Evaluation data ratio") + parser.add_argument("--eval_ratio", type=float, default=1.0, + help="Evaluation data ratio") + opts = parser.parse_args() + log.warning("*** Args ************************") + for k, v in vars(opts).items(): + log.warning(f" - {k}: {v}") + log.warning("******************************\n") + + main(opts) \ No newline at end of file diff --git a/finetuning_glrb/utils.py b/finetuning_glrb/utils.py new file mode 100644 index 0000000..4399b94 --- /dev/null +++ b/finetuning_glrb/utils.py @@ -0,0 +1,70 @@ +import os +import sys +# Add the parent directory to sys.path 
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +import fsspec +import torch +import torch.nn as nn +import torch.utils + +# Check if file exists using fsspec +def fsspec_exists(filename): + """Check if file exists in manner compatible with fsspec.""" + fs, _ = fsspec.core.url_to_fs(filename) + return fs.exists(filename) + +# List directory contents using fsspec +def fsspec_listdir(dirname): + """Listdir in manner compatible with fsspec.""" + fs, _ = fsspec.core.url_to_fs(dirname) + return fs.ls(dirname) + + +def get_last_embedding_dimension(model: nn.Module, rcps=False) -> int: + """ + Function to get the last embedding dimension of a PyTorch model by passing + a random tensor through the model and inspecting the output shape. + This is done with gradients disabled and always on GPU. + + Args: + model (nn.Module): The PyTorch model instance. + + Returns: + int: The last embedding dimension (i.e., the last dimension of the output tensor). + """ + # Move the model to GPU + model = model.to('cuda') + + # Try to determine the input shape based on the first layer of the model + for module in model.modules(): + if isinstance(module, nn.Conv2d): + # Assume a common image input size if it's a Conv2d layer + input_shape = (3, 224, 224) # RGB image of size 224x224 + break + elif isinstance(module, nn.Linear): + # Assume a 1D input size for a fully connected layer + input_shape = (module.in_features,) + break + elif isinstance(module, nn.Embedding): + # Assume a single index for an Embedding layer + input_shape = (1,) + break + else: + raise ValueError("Unable to determine the input shape automatically.") + + # Generate a random input tensor and move it to GPU + random_input = torch.randint(low=0, high=16, size=(1, *input_shape)).to('cuda') # Add batch size of 1 + + # Pass the tensor through the model with no gradients + with torch.no_grad(): + output = model(random_input)[0] + + # Get the shape of the output tensor + last_embedding_dimension = output.shape[-1] + + if rcps: + last_embedding_dimension //= 2 + + # Return the last dimension of the output tensor + return last_embedding_dimension From 758081981766b9685c6e730db9afcedda8b24cb1 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Fri, 16 Aug 2024 05:11:36 -0700 Subject: [PATCH 02/13] clean imports --- finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py | 2 -- finetuning_glrb/main.py | 1 - 2 files changed, 3 deletions(-) diff --git a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py index c5b1c72..d60a9a1 100644 --- a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py +++ b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py @@ -10,9 +10,7 @@ import lightning.pytorch as pl from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint -import numpy as np from datasets import load_dataset, load_from_disk -from sklearn import preprocessing from sklearn.metrics import precision_recall_curve, auc, roc_auc_score from transformers import DefaultDataCollator from src.utils.train import get_logger diff --git a/finetuning_glrb/main.py b/finetuning_glrb/main.py index 60ec486..5b65dbc 100644 --- a/finetuning_glrb/main.py +++ b/finetuning_glrb/main.py @@ -11,7 +11,6 @@ from src.utils.train import get_logger from finetuning_glrb.finetune_variant_effect_pathogenic_clinvar import main_lit as finetune_vep_clinvar -#from finetuning.finetune_caduceus_vep import 
main as finetune_vep_eqtl from finetuning_glrb.finetune_variant_effect_OMIM import main_lit as main_omim from finetuning_glrb.finetune_variant_effect_causal_eqtl import main_lit as finetune_vep_eqtl from finetuning_glrb.finetune_bulk_rna import main_lit as finetune_bulk_rna_expression From d43a6cb125625ee2bcabfed885d240994a1e2790 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Fri, 23 Aug 2024 11:56:58 -0700 Subject: [PATCH 03/13] [Finetuning] CAGE Expression task --- finetuning_glrb/finetune_cage.py | 381 +++++++++++++++++++++++++++++++ finetuning_glrb/main.py | 3 +- 2 files changed, 383 insertions(+), 1 deletion(-) create mode 100644 finetuning_glrb/finetune_cage.py diff --git a/finetuning_glrb/finetune_cage.py b/finetuning_glrb/finetune_cage.py new file mode 100644 index 0000000..b90dfd0 --- /dev/null +++ b/finetuning_glrb/finetune_cage.py @@ -0,0 +1,381 @@ +import os +from functools import partial +from os import path as osp +import re +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForMaskedLM, + AutoModel, + AutoTokenizer, + AutoConfig, + DefaultDataCollator, +) +import numpy as np +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from datasets import load_dataset, load_from_disk +from sklearn.metrics import r2_score +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension + +# Logger setup +log = get_logger(__name__) + +# Constants for the upstream and downstream window sizes +BIN_SIZE = 128 + +def _closest_multiple_after(after, multiple_of): + """ + Computes the closest multiple of a certain int after a certain int. E.g the closest multiple of 2 after 9 is 10. + + Args: + after: int. + multiple_of: int + + Returns: + the closest multiple found. + """ + remainder = after % multiple_of + if remainder == 0: + return after + else: + return after + (multiple_of - remainder) + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize sequence. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs. + """ + ref_tokenized = tokenizer.batch_encode_plus( + examples["sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + padding = "max_length" + ) + + return { + "ref_input_ids": ref_tokenized["input_ids"], + } + +class MLP_CAGE(nn.Module): + """ + Regression head for Bulk RNA prediction task. + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. + """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_CAGE, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc2(self.sp1(self.fc1(x))) + +class DNAModelForCAGE(nn.Module): + """ + DNA Model for CAGE prediction. + + Args: + args: Arguments containing model configurations. 
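+
+    Note: token embeddings are reshaped into seq_len // BIN_SIZE bins (one bin
+    per 128 bp), each bin is mean-pooled, and a shared MLP head predicts 50
+    CAGE tracks per bin, giving an output of shape (batch, num_bins, 50). When
+    `rcps` is set, the reverse-complement half of the channels is flipped and
+    added to the forward-strand bins before pooling.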
+ """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.initial_seq_len = args.seq_len + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + self.num_bins= self.initial_seq_len // BIN_SIZE + + # Load the appropriate backbone model based on the model name + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim founded for the Foundation Model: {self.inner_dim}") + self.head = MLP_CAGE(input_size=self.inner_dim, hidden_size=2*self.inner_dim, output_size=50) + #self.heads = nn.ModuleList([ + #MLP_CAGE(input_size=self.inner_dim, hidden_size=2*self.inner_dim, output_size=50) + #for _ in range(self.num_bins) + #]) + + def forward(self, input_ids): + # Get embeddings from the backbone + embeds_out = self.backbone(input_ids)[0] + batch_size,seq_len,num_channels =embeds_out.size() + bin_size = seq_len//self.num_bins + # Reshape embeddings into bins + bins = embeds_out.view(batch_size, self.num_bins, bin_size, num_channels) + + if self.rcps: + # If the model is RC-equivariant + embeds = bins[..., :num_channels // 2] + #Same for the RC-equivalent + rc_embeds = bins[..., num_channels // 2:].contiguous().flip(dims=[1,2,3]) + + #Combine the Reference and RC-Equivariant resulting embeddings + bins = embeds + rc_embeds + + outputs = torch.zeros(batch_size,self.num_bins,50).to(bins.device) + for i in range(self.num_bins): + bin_embedding = bins[:, i, :, :].mean(dim=1) + out = self.head(bin_embedding) + outputs[:,i,:] = out + return outputs + +class Lit_DNAModelForCAGE(pl.LightningModule): + """ + PyTorch Lightning model for fine-tuning on Bulk RNA prediction. + + Args: + args: Arguments containing model and training configurations. 
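+
+    Trains with MSE loss; predictions and targets are log1p-transformed before
+    the R² score is computed and logged at the end of each epoch.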
+ """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.rcps = args.rcps + self.model = DNAModelForCAGE(args) + self.criterion = nn.MSELoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + + def forward(self, ref_input_ids): + return self.model(ref_input_ids) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"] + + logits = self.model(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for R² score + self.training_step_preds.extend(logits.view(-1, logits.size(-1)).detach().cpu().numpy()) + self.training_step_labels.extend(labels.view(-1, labels.size(-1)).detach().cpu().numpy()) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"] + + logits = self.model(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for R² score + self.validation_step_preds.extend(logits.view(-1, logits.size(-1)).detach().cpu().numpy()) + self.validation_step_labels.extend(labels.view(-1, labels.size(-1)).detach().cpu().numpy()) + + return loss + + def on_validation_epoch_end(self): + #Log(1+x) normalize the preds and labels + self.validation_step_labels = list(np.log1p(self.validation_step_labels)) + self.validation_step_preds = list(np.log1p(self.validation_step_preds)) + + # Calculate R² score for validation + val_r2 = r2_score(self.validation_step_labels, self.validation_step_preds) + self.log("validation/R2", val_r2, on_epoch=True, prog_bar=True, logger=True) + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def on_train_epoch_end(self): + #Log(1+x) normalize the preds and labels + self.training_step_labels = list(np.log1p(self.training_step_labels)) + self.training_step_preds = list(np.log1p(self.training_step_preds)) + + # Calculate R² score for training + train_r2 = r2_score(self.training_step_labels, self.training_step_preds) + self.log("train/R2", train_r2, on_epoch=True, prog_bar=True, logger=True) + self.training_step_labels.clear() + self.training_step_preds.clear() + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class CAGEDataModule(pl.LightningDataModule): + """ + Data module for Bulk RNA finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. 
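+
+    Sequences are tokenized and padded to the closest multiple of the number of
+    128-bp bins (see `_closest_multiple_after`) so that token embeddings can be
+    reshaped evenly into bins downstream.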
+ """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token + self.tokens_per_seq = self.seq_len//self.bp_per_token + self.bins_per_seq = self.seq_len // 128 + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.train_dataset = self.dataset["train"] + self.val_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "cage_prediction", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning("Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="cage_prediction", + sequence_length=self.seq_len, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + # Process data: filter sequences with too many 'N's, recast chromosomes, and tokenize + dataset = dataset.filter( + lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=_closest_multiple_after(self.tokens_per_seq,self.bins_per_seq)), + batch_size=1000, + batched=True, + remove_columns=["sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + + # Save processed dataset to disk + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning("Data downloaded and preprocessed successfully.") + +def main_lit(args): + """ + Main function to start the training process using PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. 
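+
+    Logs to the "CAGE Predictions" W&B project and keeps the best checkpoint
+    (lowest `val_loss`) under {save_dir}/checkpoints, with early stopping on
+    the same metric.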
+ """ + + wandb.login(key=args.wandb_api_key) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}", + project="CAGE Predictions", + log_model=True, + save_dir=args.save_dir + ) + data_module = CAGEDataModule(args) + data_module.setup() + + model = Lit_DNAModelForCAGE(args) + + # Callbacks for early stopping and model checkpointing + early_stopping_callback = EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min" + ) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename="best-checkpoint", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[early_stopping_callback, checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision="16-mixed", + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) diff --git a/finetuning_glrb/main.py b/finetuning_glrb/main.py index 5b65dbc..5da6b0a 100644 --- a/finetuning_glrb/main.py +++ b/finetuning_glrb/main.py @@ -16,6 +16,7 @@ from finetuning_glrb.finetune_bulk_rna import main_lit as finetune_bulk_rna_expression from finetuning_glrb.finetune_chromatin import main_histone_marks,main_dna_accessibility from finetuning_glrb.finetune_regulatory_elements import main_enhancer, main_promoter +from finetuning_glrb.finetune_cage import main_lit as main_cage log = get_logger(__name__) @@ -31,7 +32,7 @@ def main(opts): elif opts.task == "bulk_rna_expression": finetune_bulk_rna_expression(opts) elif opts.task == "cage_prediction": - raise ValueError("Invalid task selected or not implemented yet.") + main_cage(opts) elif opts.task == "chromatin_features_histone_marks": main_histone_marks(opts) elif opts.task == "chromatin_features_dna_accessibility": From be079bcef1fdf4df2b82271dead0509151d16010 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Sat, 24 Aug 2024 13:14:13 -0700 Subject: [PATCH 04/13] [Refacto] - suitable for checkpoint reloading --- finetuning_glrb/finetune.sh | 25 +++++++++---------- finetuning_glrb/finetune_bulk_rna.py | 9 ++++--- finetuning_glrb/finetune_cage.py | 8 ++++-- finetuning_glrb/finetune_chromatin.py | 5 +++- .../finetune_regulatory_elements.py | 7 ++++-- .../finetune_variant_effect_OMIM.py | 5 +++- .../finetune_variant_effect_causal_eqtl.py | 5 +++- ...etune_variant_effect_pathogenic_clinvar.py | 5 +++- 8 files changed, 45 insertions(+), 24 deletions(-) diff --git a/finetuning_glrb/finetune.sh b/finetuning_glrb/finetune.sh index d83fd7c..40a3c58 100644 --- a/finetuning_glrb/finetune.sh +++ b/finetuning_glrb/finetune.sh @@ -1,20 +1,19 @@ python finetuning_glrb/main.py \ - --task "task_name" \ - --seq_len 1000 \ - --model_name "model_name_on_the_huggingface_hub" \ - --bp_per_token 1 \ + --task "cage_prediction" \ + --seq_len 12032 \ + --model_name "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species" \ + --bp_per_token 6 \ --save_dir "output/" \ - --wandb_api_key "your_wandb_api_key" \ - --name_wb "name_for_your_wandb_run" \ - --rcps true \ - --train_batch_size 16 \ - --test_batch_size 16 \ + --wandb_api_key "765bad652bcb6ce569641fc334bcf0f0eea5b1fb" \ + 
--name_wb "cage--ntv2-12k" \ + --train_batch_size 1 \ + --test_batch_size 2 \ --num_workers 6 \ - --num_epochs 10 \ + --num_epochs 100 \ --learning_rate "3e-5" \ - --patience 3 \ - --log_interval 280 \ - --accumulate_grad_batches 4 \ + --patience 30 \ + --log_interval 512 \ + --accumulate_grad_batches 128 \ --train_ratio 1.0 \ --eval_ratio 1.0 diff --git a/finetuning_glrb/finetune_bulk_rna.py b/finetuning_glrb/finetune_bulk_rna.py index 3343786..4a636a3 100644 --- a/finetuning_glrb/finetune_bulk_rna.py +++ b/finetuning_glrb/finetune_bulk_rna.py @@ -166,14 +166,17 @@ class Lit_BulkRNAFinetuning(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters(args) - self.rcps = args.rcps - self.model = DNAModelForBulkRNA(args) + self.setup() + + def setup(self,stage=None): + self.rcps = self.hparams.rcps + self.model = DNAModelForBulkRNA(self.hparams) self.criterion = nn.MSELoss() self.validation_step_preds = [] self.validation_step_labels = [] self.training_step_preds = [] self.training_step_labels = [] - + def forward(self, ref_input_ids): return self.model(ref_input_ids) diff --git a/finetuning_glrb/finetune_cage.py b/finetuning_glrb/finetune_cage.py index b90dfd0..b911c04 100644 --- a/finetuning_glrb/finetune_cage.py +++ b/finetuning_glrb/finetune_cage.py @@ -153,8 +153,12 @@ class Lit_DNAModelForCAGE(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters(args) - self.rcps = args.rcps - self.model = DNAModelForCAGE(args) + self.setup() + + def setup(self,stage=None): + + self.rcps = self.hparams.rcps + self.model = DNAModelForCAGE(self.hparams) self.criterion = nn.MSELoss() self.validation_step_preds = [] self.validation_step_labels = [] diff --git a/finetuning_glrb/finetune_chromatin.py b/finetuning_glrb/finetune_chromatin.py index e814dec..f8a6a0a 100644 --- a/finetuning_glrb/finetune_chromatin.py +++ b/finetuning_glrb/finetune_chromatin.py @@ -151,7 +151,10 @@ class Lit_ChromatinFeatures(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters(args) - self.model = DNAModelForChromatineFeatures(args) + self.setup() + + def setup(self,stage=None): + self.model = DNAModelForChromatineFeatures(self.hparams) self.criterion = nn.BCEWithLogitsLoss() self.validation_step_preds = [] self.validation_step_labels = [] diff --git a/finetuning_glrb/finetune_regulatory_elements.py b/finetuning_glrb/finetune_regulatory_elements.py index 0062117..bb390c3 100644 --- a/finetuning_glrb/finetune_regulatory_elements.py +++ b/finetuning_glrb/finetune_regulatory_elements.py @@ -145,8 +145,11 @@ class Lit_RegulatoryElements(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters(args) - self.model = DNAModelForRegulatoryElements(args) - self.task = args.task + self.setup() + + def setup(self,stage=None): + self.model = DNAModelForRegulatoryElements(self.hparams) + self.task = self.hparams.task self.criterion = nn.BCEWithLogitsLoss() self.validation_step_preds = [] self.validation_step_labels = [] diff --git a/finetuning_glrb/finetune_variant_effect_OMIM.py b/finetuning_glrb/finetune_variant_effect_OMIM.py index 18e611f..6baf96d 100644 --- a/finetuning_glrb/finetune_variant_effect_OMIM.py +++ b/finetuning_glrb/finetune_variant_effect_OMIM.py @@ -180,7 +180,10 @@ class Lit_OMIMFinetuning(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters(args) - self.model = DNAModelForOMIMFinetuning(args) + self.setup() + + def setup(self,stage=None): + 
self.model = DNAModelForOMIMFinetuning(self.hparams) self.criterion = nn.CrossEntropyLoss() self.validation_step_preds = [] self.validation_step_labels = [] diff --git a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py index ff55590..dad97b9 100644 --- a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py +++ b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py @@ -207,7 +207,10 @@ class LitVEPFinetuning(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters(args) - self.model = DNAModelForVEPFinetuning(args) + self.setup() + + def setup(self,stage=None): + self.model = DNAModelForVEPFinetuning(self.hparams) self.criterion = nn.CrossEntropyLoss() self.validation_step_preds = {i: [] for i in range(len(DIST_TO_TSS))} self.validation_step_labels = {i: [] for i in range(len(DIST_TO_TSS))} diff --git a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py index d60a9a1..b297ac4 100644 --- a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py +++ b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py @@ -178,7 +178,10 @@ class Lit_ClinVarFinetuning(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters(args) - self.model = DNAModelForVEPFinetuning(args) + self.setup() + + def setup(self,stage=None): + self.model = DNAModelForVEPFinetuning(self.hparams) self.criterion = nn.CrossEntropyLoss() self.validation_step_preds = [] self.validation_step_labels = [] From 71965f4a8667e6e1bbe1eb570f2bf75d6751b594 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Sat, 24 Aug 2024 20:53:35 -0700 Subject: [PATCH 05/13] Renaming of functions for clarity --- finetuning_glrb/finetune_bulk_rna.py | 2 +- finetuning_glrb/finetune_cage.py | 2 +- finetuning_glrb/finetune_chromatin.py | 4 +-- .../finetune_regulatory_elements.py | 4 +-- .../finetune_variant_effect_OMIM.py | 2 +- .../finetune_variant_effect_causal_eqtl.py | 2 +- ...etune_variant_effect_pathogenic_clinvar.py | 2 +- finetuning_glrb/main.py | 25 +++++++++---------- 8 files changed, 21 insertions(+), 22 deletions(-) diff --git a/finetuning_glrb/finetune_bulk_rna.py b/finetuning_glrb/finetune_bulk_rna.py index 4a636a3..0ebe94d 100644 --- a/finetuning_glrb/finetune_bulk_rna.py +++ b/finetuning_glrb/finetune_bulk_rna.py @@ -332,7 +332,7 @@ def _download_and_preprocess_data(self): dataset.save_to_disk(self._get_preprocessed_cache_file()) log.warning("Data downloaded and preprocessed successfully.") -def main_lit(args): +def finetune(args): """ Main function to start the training process using PyTorch Lightning. diff --git a/finetuning_glrb/finetune_cage.py b/finetuning_glrb/finetune_cage.py index b911c04..4f83844 100644 --- a/finetuning_glrb/finetune_cage.py +++ b/finetuning_glrb/finetune_cage.py @@ -326,7 +326,7 @@ def _download_and_preprocess_data(self): dataset.save_to_disk(self._get_preprocessed_cache_file()) log.warning("Data downloaded and preprocessed successfully.") -def main_lit(args): +def finetune(args): """ Main function to start the training process using PyTorch Lightning. 
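For reference, the CAGE data module above tokenizes with `max_length=_closest_multiple_after(self.tokens_per_seq, self.bins_per_seq)`, but the helper itself is not defined in these patches. A minimal sketch of what it presumably computes, inferred only from its name and call site (an assumption, not code taken from the repository):

```python
def _closest_multiple_after(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n (assumed behavior)."""
    if multiple <= 0:
        return n
    return ((n + multiple - 1) // multiple) * multiple

# With the NTv2 example settings from finetune.sh (seq_len=12032, bp_per_token=6):
# tokens_per_seq = 12032 // 6 = 2005 and bins_per_seq = 12032 // 128 = 94,
# so max_length = _closest_multiple_after(2005, 94) = 2068, i.e. the token
# count divides evenly across the 94 label bins.
```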
diff --git a/finetuning_glrb/finetune_chromatin.py b/finetuning_glrb/finetune_chromatin.py index f8a6a0a..b433187 100644 --- a/finetuning_glrb/finetune_chromatin.py +++ b/finetuning_glrb/finetune_chromatin.py @@ -443,7 +443,7 @@ def _download_and_preprocess_data(self): dataset.save_to_disk(self._get_preprocessed_cache_file()) log.warning("Data downloaded and preprocessed successfully.") -def main_histone_marks(args): +def finetune_histone_marks(args): """ Main function to start finetuning on Histone Marks with PyTorch Lightning. @@ -499,7 +499,7 @@ def main_histone_marks(args): # Start the training process trainer.fit(model, data_module) -def main_dna_accessibility(args): +def finetune_dna_accessibility(args): """ Main function to start finetunning on DNA Accessibility with PyTorch Lightning. diff --git a/finetuning_glrb/finetune_regulatory_elements.py b/finetuning_glrb/finetune_regulatory_elements.py index bb390c3..6e7d303 100644 --- a/finetuning_glrb/finetune_regulatory_elements.py +++ b/finetuning_glrb/finetune_regulatory_elements.py @@ -444,7 +444,7 @@ def _download_and_preprocess_data(self): dataset.save_to_disk(self._get_preprocessed_cache_file()) log.warning("Data downloaded and preprocessed successfully.") -def main_promoter(args): +def finetune_promoters(args): """ Main function to start training on Promoter regulatory elements with PyTorch Lightning. @@ -500,7 +500,7 @@ def main_promoter(args): # Start the training process trainer.fit(model, data_module) -def main_enhancer(args): +def finetune_enhancers(args): """ Main function to start training on Enhancer regulatory elements with PyTorch Lightning. diff --git a/finetuning_glrb/finetune_variant_effect_OMIM.py b/finetuning_glrb/finetune_variant_effect_OMIM.py index 6baf96d..09c0f08 100644 --- a/finetuning_glrb/finetune_variant_effect_OMIM.py +++ b/finetuning_glrb/finetune_variant_effect_OMIM.py @@ -378,7 +378,7 @@ def _download_and_preprocess_data(self): dataset.save_to_disk(self._get_preprocessed_cache_file()) log.warning(f"Data downloaded and preprocessed successfully.") -def main_lit(args): +def finetune(args): """ Main function to start the training process for OMIM Variant Effect Prediction using PyTorch Lightning. diff --git a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py index dad97b9..8aa2e27 100644 --- a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py +++ b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py @@ -391,7 +391,7 @@ def _download_and_preprocess_data(self): log.warning(f"Data downloaded and preprocessed successfully.") -def main_lit(args): +def finetune(args): """ Main function to start process for Variant Effect Prediction Finetuning eQTL using PyTorch Lightning. diff --git a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py index b297ac4..ec28348 100644 --- a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py +++ b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py @@ -376,7 +376,7 @@ def _download_and_preprocess_data(self): dataset.save_to_disk(self._get_preprocessed_cache_file()) log.warning(f"Data downloaded and preprocessed successfully.") -def main_lit(args): +def finetune(args): """ Main function to start the training process for ClinVar Variant Effect Prediction using PyTorch Lightning. 
diff --git a/finetuning_glrb/main.py b/finetuning_glrb/main.py index 5da6b0a..3ede9db 100644 --- a/finetuning_glrb/main.py +++ b/finetuning_glrb/main.py @@ -10,13 +10,13 @@ from src.utils.train import get_logger -from finetuning_glrb.finetune_variant_effect_pathogenic_clinvar import main_lit as finetune_vep_clinvar -from finetuning_glrb.finetune_variant_effect_OMIM import main_lit as main_omim -from finetuning_glrb.finetune_variant_effect_causal_eqtl import main_lit as finetune_vep_eqtl -from finetuning_glrb.finetune_bulk_rna import main_lit as finetune_bulk_rna_expression -from finetuning_glrb.finetune_chromatin import main_histone_marks,main_dna_accessibility -from finetuning_glrb.finetune_regulatory_elements import main_enhancer, main_promoter -from finetuning_glrb.finetune_cage import main_lit as main_cage +from finetuning_glrb.finetune_variant_effect_pathogenic_clinvar import finetune as finetune_vep_clinvar +from finetuning_glrb.finetune_variant_effect_OMIM import finetune as main_omim +from finetuning_glrb.finetune_variant_effect_causal_eqtl import finetune as finetune_vep_eqtl +from finetuning_glrb.finetune_bulk_rna import finetune as finetune_bulk_rna_expression +from finetuning_glrb.finetune_chromatin import finetune_histone_marks,finetune_dna_accessibility +from finetuning_glrb.finetune_regulatory_elements import finetune_enhancers, finetune_promoters +from finetuning_glrb.finetune_cage import finetune as finetune_cage log = get_logger(__name__) @@ -24,7 +24,6 @@ def main(opts): # Check if the value of args.task matches one of the predefined options if opts.task == "variant_effect_causal_eqtl": finetune_vep_eqtl(opts) - # Perform operations specific to this task elif opts.task == "variant_effect_pathogenic_clinvar": finetune_vep_clinvar(opts) elif opts.task == "variant_effect_pathogenic_omim": @@ -32,15 +31,15 @@ def main(opts): elif opts.task == "bulk_rna_expression": finetune_bulk_rna_expression(opts) elif opts.task == "cage_prediction": - main_cage(opts) + finetune_cage(opts) elif opts.task == "chromatin_features_histone_marks": - main_histone_marks(opts) + finetune_histone_marks(opts) elif opts.task == "chromatin_features_dna_accessibility": - main_dna_accessibility(opts) + finetune_dna_accessibility(opts) elif opts.task == "regulatory_element_promoter": - main_promoter(opts) + finetune_promoters(opts) elif opts.task == "regulatory_element_enhancer": - main_enhancer(opts) + finetune_enhancers(opts) From dbcd778cebcc91b5b2d5c7eb597074a0dcd7dd2d Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Sat, 24 Aug 2024 20:58:07 -0700 Subject: [PATCH 06/13] [Feature] - Can now choose the finetuning precision --- finetuning_glrb/finetune.sh | 1 + finetuning_glrb/finetune_bulk_rna.py | 2 +- finetuning_glrb/finetune_cage.py | 2 +- finetuning_glrb/finetune_chromatin.py | 4 ++-- finetuning_glrb/finetune_regulatory_elements.py | 4 ++-- finetuning_glrb/finetune_variant_effect_OMIM.py | 2 +- finetuning_glrb/finetune_variant_effect_causal_eqtl.py | 2 +- .../finetune_variant_effect_pathogenic_clinvar.py | 2 +- finetuning_glrb/main.py | 7 +++++++ 9 files changed, 17 insertions(+), 9 deletions(-) diff --git a/finetuning_glrb/finetune.sh b/finetuning_glrb/finetune.sh index 40a3c58..14ff135 100644 --- a/finetuning_glrb/finetune.sh +++ b/finetuning_glrb/finetune.sh @@ -10,6 +10,7 @@ python finetuning_glrb/main.py \ --test_batch_size 2 \ --num_workers 6 \ --num_epochs 100 \ + --precision "16-mixed" \ --learning_rate "3e-5" \ --patience 30 \ --log_interval 512 \ diff --git 
a/finetuning_glrb/finetune_bulk_rna.py b/finetuning_glrb/finetune_bulk_rna.py index 0ebe94d..19e5359 100644 --- a/finetuning_glrb/finetune_bulk_rna.py +++ b/finetuning_glrb/finetune_bulk_rna.py @@ -381,7 +381,7 @@ def finetune(args): limit_val_batches=args.eval_ratio, val_check_interval=args.log_interval, gradient_clip_val=1.0, - precision="16-mixed", + precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, num_sanity_val_steps=0 ) diff --git a/finetuning_glrb/finetune_cage.py b/finetuning_glrb/finetune_cage.py index 4f83844..2121b34 100644 --- a/finetuning_glrb/finetune_cage.py +++ b/finetuning_glrb/finetune_cage.py @@ -376,7 +376,7 @@ def finetune(args): limit_val_batches=args.eval_ratio, val_check_interval=args.log_interval, gradient_clip_val=1.0, - precision="16-mixed", + precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, num_sanity_val_steps=0 ) diff --git a/finetuning_glrb/finetune_chromatin.py b/finetuning_glrb/finetune_chromatin.py index b433187..66a366b 100644 --- a/finetuning_glrb/finetune_chromatin.py +++ b/finetuning_glrb/finetune_chromatin.py @@ -491,7 +491,7 @@ def finetune_histone_marks(args): limit_val_batches=args.eval_ratio, val_check_interval=args.log_interval, gradient_clip_val=1.0, - precision="16-mixed", + precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, num_sanity_val_steps=0 ) @@ -547,7 +547,7 @@ def finetune_dna_accessibility(args): limit_val_batches=args.eval_ratio, val_check_interval=args.log_interval, gradient_clip_val=1.0, - precision="16-mixed", + precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, num_sanity_val_steps=0 ) diff --git a/finetuning_glrb/finetune_regulatory_elements.py b/finetuning_glrb/finetune_regulatory_elements.py index 6e7d303..7b21e54 100644 --- a/finetuning_glrb/finetune_regulatory_elements.py +++ b/finetuning_glrb/finetune_regulatory_elements.py @@ -492,7 +492,7 @@ def finetune_promoters(args): limit_val_batches=args.eval_ratio, val_check_interval=args.log_interval, gradient_clip_val=1.0, - precision="16-mixed", + precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, num_sanity_val_steps=0 ) @@ -548,7 +548,7 @@ def finetune_enhancers(args): limit_val_batches=args.eval_ratio, val_check_interval=args.log_interval, gradient_clip_val=1.0, - precision="16-mixed", + precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, num_sanity_val_steps=0 ) diff --git a/finetuning_glrb/finetune_variant_effect_OMIM.py b/finetuning_glrb/finetune_variant_effect_OMIM.py index 09c0f08..f352ec3 100644 --- a/finetuning_glrb/finetune_variant_effect_OMIM.py +++ b/finetuning_glrb/finetune_variant_effect_OMIM.py @@ -425,7 +425,7 @@ def finetune(args): limit_val_batches=args.eval_ratio, val_check_interval=args.log_interval, gradient_clip_val=1.0, - precision="16-mixed", + precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, num_sanity_val_steps=0 ) diff --git a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py index 8aa2e27..a5bacf8 100644 --- a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py +++ b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py @@ -438,7 +438,7 @@ def finetune(args): limit_val_batches=args.eval_ratio, val_check_interval=args.log_interval, gradient_clip_val=1.0, - precision="16-mixed", + precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, num_sanity_val_steps=0 ) 
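Because `--precision` is declared with `required=True` in `main.py` (see the hunk below), every run now has to pass it explicitly; the value is forwarded unchanged to `pl.Trainer(precision=...)` in each task script. An illustrative invocation, assuming the remaining flags keep their defaults (placeholder values, not a tested configuration):

```bash
python finetuning_glrb/main.py \
  --task "bulk_rna_expression" \
  --seq_len 12000 \
  --model_name "your_model_name_from_the_hub" \
  --bp_per_token 1 \
  --save_dir "output/" \
  --wandb_api_key "your_wandb_api_key" \
  --name_wb "your_wandb_run_name" \
  --num_epochs 1 \
  --precision "bf16-mixed"
```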
diff --git a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py index ec28348..b79afd3 100644 --- a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py +++ b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py @@ -423,7 +423,7 @@ def finetune(args): limit_val_batches=args.eval_ratio, val_check_interval=args.log_interval, gradient_clip_val=1.0, - precision="16-mixed", + precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, num_sanity_val_steps=0 ) diff --git a/finetuning_glrb/main.py b/finetuning_glrb/main.py index 3ede9db..4066714 100644 --- a/finetuning_glrb/main.py +++ b/finetuning_glrb/main.py @@ -78,6 +78,13 @@ def main(opts): parser.add_argument("--preprocessed_dataset_path", type=str, default=None, help="Path to preprocessed dataset.") parser.add_argument("--rcps", type=bool, default=False, help="Using rcps when extracting embeddings or not.") parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train.") + parser.add_argument( + "--precision", + type=str, + choices=['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'], + required=True, + help="Choose one mode." + ) parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Accumulate gradients") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for optimizer") From c52688f7bc09940d66d55229bbb4c9262ca1e70e Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Sat, 24 Aug 2024 21:01:23 -0700 Subject: [PATCH 07/13] readme --- finetuning_glrb/README.md | 78 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/finetuning_glrb/README.md b/finetuning_glrb/README.md index 86e85ad..7b4f677 100644 --- a/finetuning_glrb/README.md +++ b/finetuning_glrb/README.md @@ -8,6 +8,84 @@ DNA Models are loaded from the Hugging-Face Hub 🤗. To fine-tune a model, execute the `finetune.sh` script. The script runs the `main.py` script with various command-line arguments that configure the fine-tuning process. Below is a description of each argument used. +#### `--task` +Choose one of the predefined variant effects. Options include: +- `"variant_effect_causal_eqtl"` +- `"variant_effect_pathogenic_clinvar"` +- `"variant_effect_pathogenic_omim"` +- `"cage_prediction"` +- `"bulk_rna_expression"` +- `"chromatin_features_histone_marks"` +- `"chromatin_features_dna_accessibility"` +- `"regulatory_element_promoter"` +- `"regulatory_element_enhancer"` + +#### `--seq_len` +Specifies the sequence length in base pairs (bp). + +#### `--model_name` +Name of the pre-trained model to fine-tune. + +#### `--bp_per_token` +Defines the number of base pairs per token. + +#### `--save_dir` +Directory where the outputs of the downstream task will be saved. + +#### `--wandb_api_key` +API key for Weights & Biases logging. + +#### `--name_wb` +Name for the Weights & Biases embeddings model. + +#### `--train_batch_size` +Defines the batch size for training. + +#### `--test_batch_size` +Defines the batch size for testing/validation. + +#### `--num_workers` +Number of workers to use for data loading. + +#### `--preprocessed_dataset_path` +Path to the preprocessed dataset. + +#### `--rcps` +Indicates whether to use RCPS when extracting embeddings. + +#### `--num_epochs` +Specifies the number of epochs to train. + +#### `--precision` +Choose the precision mode. 
Options include: +- `"transformer-engine"` +- `"transformer-engine-float16"` +- `"16-true"` +- `"16-mixed"` +- `"bf16-true"` +- `"bf16-mixed"` +- `"32-true"` +- `"64-true"` + +#### `--accumulate_grad_batches` +Number of batches for which to accumulate gradients. + +#### `--learning_rate` +Specifies the learning rate for the optimizer. + +#### `--patience` +Determines the number of epochs with no improvement after which training will be stopped. + +#### `--log_interval` +Interval (in steps) at which to log training metrics. + +#### `--train_ratio` +Specifies the ratio of the dataset to use for training. + +#### `--eval_ratio` +Specifies the ratio of the dataset to use for evaluation. + + ### Running the Script To start finetuning, first make sure that you have modified the `finetune.sh` script with the correct parameters for your task. Then, simply run: From a25cad2bb9dfc2ece97309793a7c9865a75ec708 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Sat, 24 Aug 2024 21:03:36 -0700 Subject: [PATCH 08/13] readme --- finetuning_glrb/README.md | 78 +++++++++++---------------------------- 1 file changed, 21 insertions(+), 57 deletions(-) diff --git a/finetuning_glrb/README.md b/finetuning_glrb/README.md index 7b4f677..f73f95d 100644 --- a/finetuning_glrb/README.md +++ b/finetuning_glrb/README.md @@ -8,82 +8,46 @@ DNA Models are loaded from the Hugging-Face Hub 🤗. To fine-tune a model, execute the `finetune.sh` script. The script runs the `main.py` script with various command-line arguments that configure the fine-tuning process. Below is a description of each argument used. -#### `--task` -Choose one of the predefined variant effects. Options include: -- `"variant_effect_causal_eqtl"` -- `"variant_effect_pathogenic_clinvar"` -- `"variant_effect_pathogenic_omim"` -- `"cage_prediction"` -- `"bulk_rna_expression"` -- `"chromatin_features_histone_marks"` -- `"chromatin_features_dna_accessibility"` -- `"regulatory_element_promoter"` -- `"regulatory_element_enhancer"` +**`--task`**: Choose one of the predefined variant effects. Options include: `"variant_effect_causal_eqtl"`, `"variant_effect_pathogenic_clinvar"`, `"variant_effect_pathogenic_omim"`, `"cage_prediction"`, `"bulk_rna_expression"`, `"chromatin_features_histone_marks"`, `"chromatin_features_dna_accessibility"`, `"regulatory_element_promoter"`, `"regulatory_element_enhancer"`. -#### `--seq_len` -Specifies the sequence length in base pairs (bp). +**`--seq_len`**: Specifies the sequence length in base pairs (bp). -#### `--model_name` -Name of the pre-trained model to fine-tune. +**`--model_name`**: Name of the pre-trained model to fine-tune. -#### `--bp_per_token` -Defines the number of base pairs per token. +**`--bp_per_token`**: Defines the number of base pairs per token. -#### `--save_dir` -Directory where the outputs of the downstream task will be saved. +**`--save_dir`**: Directory where the outputs of the downstream task will be saved. -#### `--wandb_api_key` -API key for Weights & Biases logging. +**`--wandb_api_key`**: API key for Weights & Biases logging. -#### `--name_wb` -Name for the Weights & Biases embeddings model. +**`--name_wb`**: Name for the Weights & Biases embeddings model. -#### `--train_batch_size` -Defines the batch size for training. +**`--train_batch_size`**: Defines the batch size for training. -#### `--test_batch_size` -Defines the batch size for testing/validation. +**`--test_batch_size`**: Defines the batch size for testing/validation. -#### `--num_workers` -Number of workers to use for data loading. 
+**`--num_workers`**: Number of workers to use for data loading. -#### `--preprocessed_dataset_path` -Path to the preprocessed dataset. +**`--preprocessed_dataset_path`**: Path to the preprocessed dataset. -#### `--rcps` -Indicates whether to use RCPS when extracting embeddings. +**`--rcps`**: Indicates whether to use RCPS when extracting embeddings. -#### `--num_epochs` -Specifies the number of epochs to train. +**`--num_epochs`**: Specifies the number of epochs to train. -#### `--precision` -Choose the precision mode. Options include: -- `"transformer-engine"` -- `"transformer-engine-float16"` -- `"16-true"` -- `"16-mixed"` -- `"bf16-true"` -- `"bf16-mixed"` -- `"32-true"` -- `"64-true"` +**`--precision`**: Choose the precision mode. Options include: `"transformer-engine"`, `"transformer-engine-float16"`, `"16-true"`, `"16-mixed"`, `"bf16-true"`, `"bf16-mixed"`, `"32-true"`, `"64-true"`. -#### `--accumulate_grad_batches` -Number of batches for which to accumulate gradients. +**`--accumulate_grad_batches`**: Number of batches for which to accumulate gradients. -#### `--learning_rate` -Specifies the learning rate for the optimizer. +**`--learning_rate`**: Specifies the learning rate for the optimizer. -#### `--patience` -Determines the number of epochs with no improvement after which training will be stopped. +**`--patience`**: Determines the number of epochs with no improvement after which training will be stopped. -#### `--log_interval` -Interval (in steps) at which to log training metrics. +**`--log_interval`**: Interval (in steps) at which to log training metrics. -#### `--train_ratio` -Specifies the ratio of the dataset to use for training. +**`--train_ratio`**: Specifies the ratio of the dataset to use for training. + +**`--eval_ratio`**: Specifies the ratio of the dataset to use for evaluation. -#### `--eval_ratio` -Specifies the ratio of the dataset to use for evaluation. ### Running the Script From e27fc9a87b50a1e351a4b7a065eefa1f4bd21d12 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Sat, 24 Aug 2024 21:07:06 -0700 Subject: [PATCH 09/13] readme + remove useless line --- finetuning_glrb/README.md | 16 +++++++--------- finetuning_glrb/main.py | 1 - 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/finetuning_glrb/README.md b/finetuning_glrb/README.md index f73f95d..8408034 100644 --- a/finetuning_glrb/README.md +++ b/finetuning_glrb/README.md @@ -12,15 +12,15 @@ To fine-tune a model, execute the `finetune.sh` script. The script runs the `mai **`--seq_len`**: Specifies the sequence length in base pairs (bp). -**`--model_name`**: Name of the pre-trained model to fine-tune. +**`--model_name`**: Name of the pre-trained model to fine-tune (on the HF hub). -**`--bp_per_token`**: Defines the number of base pairs per token. +**`--bp_per_token`**: Defines the number of base pairs per token used in the tokenization process of the model. -**`--save_dir`**: Directory where the outputs of the downstream task will be saved. +**`--save_dir`**: Directory where the checkpoints and logs will be saved. **`--wandb_api_key`**: API key for Weights & Biases logging. -**`--name_wb`**: Name for the Weights & Biases embeddings model. +**`--name_wb`**: Name for the Weights & Biases run. **`--train_batch_size`**: Defines the batch size for training. @@ -28,21 +28,19 @@ To fine-tune a model, execute the `finetune.sh` script. The script runs the `mai **`--num_workers`**: Number of workers to use for data loading. -**`--preprocessed_dataset_path`**: Path to the preprocessed dataset. 
- **`--rcps`**: Indicates whether to use RCPS when extracting embeddings. **`--num_epochs`**: Specifies the number of epochs to train. -**`--precision`**: Choose the precision mode. Options include: `"transformer-engine"`, `"transformer-engine-float16"`, `"16-true"`, `"16-mixed"`, `"bf16-true"`, `"bf16-mixed"`, `"32-true"`, `"64-true"`. +**`--precision`**: Choose the precision. Options include: `"transformer-engine"`, `"transformer-engine-float16"`, `"16-true"`, `"16-mixed"`, `"bf16-true"`, `"bf16-mixed"`, `"32-true"`, `"64-true"`. -**`--accumulate_grad_batches`**: Number of batches for which to accumulate gradients. +**`--accumulate_grad_batches`**: Number of batches for which to accumulate gradients accross devices. **`--learning_rate`**: Specifies the learning rate for the optimizer. **`--patience`**: Determines the number of epochs with no improvement after which training will be stopped. -**`--log_interval`**: Interval (in steps) at which to log training metrics. +**`--log_interval`**: Interval (in steps) at which to log training metrics and run a validation step. **`--train_ratio`**: Specifies the ratio of the dataset to use for training. diff --git a/finetuning_glrb/main.py b/finetuning_glrb/main.py index 4066714..8c1c8de 100644 --- a/finetuning_glrb/main.py +++ b/finetuning_glrb/main.py @@ -75,7 +75,6 @@ def main(opts): parser.add_argument("--train_batch_size", type=int, default=16, help="Batch size for training.") parser.add_argument("--test_batch_size", type=int, default=16, help="Batch size for testing/validation.") parser.add_argument("--num_workers", type=int, default=4, help="Number of workers.") - parser.add_argument("--preprocessed_dataset_path", type=str, default=None, help="Path to preprocessed dataset.") parser.add_argument("--rcps", type=bool, default=False, help="Using rcps when extracting embeddings or not.") parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train.") parser.add_argument( From a728f9087f7f0754c4dac915848ef2d270495438 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Tue, 27 Aug 2024 15:06:27 -0700 Subject: [PATCH 10/13] exp: fix bug in tokenization process for models using k-mer tokenization --- finetuning_glrb/finetune_bulk_rna.py | 3 ++- finetuning_glrb/finetune_chromatin.py | 6 ++++-- finetuning_glrb/finetune_regulatory_elements.py | 6 ++++-- finetuning_glrb/finetune_variant_effect_OMIM.py | 3 ++- finetuning_glrb/finetune_variant_effect_causal_eqtl.py | 3 ++- .../finetune_variant_effect_pathogenic_clinvar.py | 3 ++- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/finetuning_glrb/finetune_bulk_rna.py b/finetuning_glrb/finetune_bulk_rna.py index 19e5359..be4d685 100644 --- a/finetuning_glrb/finetune_bulk_rna.py +++ b/finetuning_glrb/finetune_bulk_rna.py @@ -235,6 +235,7 @@ class BulkRNADataModule(pl.LightningDataModule): def __init__(self, config): super().__init__() self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token self.model_name = config.model_name self.train_batch_size = config.train_batch_size self.test_batch_size = config.test_batch_size @@ -320,7 +321,7 @@ def _download_and_preprocess_data(self): desc="Recast chromosome" ) dataset = dataset.map( - partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), batch_size=1000, batched=True, remove_columns=["sequence"], diff --git a/finetuning_glrb/finetune_chromatin.py b/finetuning_glrb/finetune_chromatin.py 
index 66a366b..82b2b06 100644 --- a/finetuning_glrb/finetune_chromatin.py +++ b/finetuning_glrb/finetune_chromatin.py @@ -237,6 +237,7 @@ class HistoneMarksDataModule(pl.LightningDataModule): def __init__(self, config): super().__init__() self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token self.model_name = config.model_name self.train_batch_size = config.train_batch_size self.test_batch_size = config.test_batch_size @@ -323,7 +324,7 @@ def _download_and_preprocess_data(self): desc="Recast chromosome" ) dataset = dataset.map( - partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), batch_size=1000, batched=True, remove_columns=["sequence"], @@ -345,6 +346,7 @@ class DNAAccessibilityDataModule(pl.LightningDataModule): def __init__(self, config): super().__init__() self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token self.model_name = config.model_name self.train_batch_size = config.train_batch_size self.test_batch_size = config.test_batch_size @@ -431,7 +433,7 @@ def _download_and_preprocess_data(self): desc="Recast chromosome" ) dataset = dataset.map( - partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), batch_size=1000, batched=True, remove_columns=["sequence"], diff --git a/finetuning_glrb/finetune_regulatory_elements.py b/finetuning_glrb/finetune_regulatory_elements.py index 7b21e54..73b6fa9 100644 --- a/finetuning_glrb/finetune_regulatory_elements.py +++ b/finetuning_glrb/finetune_regulatory_elements.py @@ -238,6 +238,7 @@ class PromoterDataModule(pl.LightningDataModule): def __init__(self, config): super().__init__() self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token self.model_name = config.model_name self.train_batch_size = config.train_batch_size self.test_batch_size = config.test_batch_size @@ -324,7 +325,7 @@ def _download_and_preprocess_data(self): desc="Recast chromosome" ) dataset = dataset.map( - partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), batch_size=1000, batched=True, remove_columns=["sequence"], @@ -346,6 +347,7 @@ class EnhancerDataModule(pl.LightningDataModule): def __init__(self, config): super().__init__() self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token self.model_name = config.model_name self.train_batch_size = config.train_batch_size self.test_batch_size = config.test_batch_size @@ -432,7 +434,7 @@ def _download_and_preprocess_data(self): desc="Recast chromosome" ) dataset = dataset.map( - partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), batch_size=1000, batched=True, remove_columns=["sequence"], diff --git a/finetuning_glrb/finetune_variant_effect_OMIM.py b/finetuning_glrb/finetune_variant_effect_OMIM.py index f352ec3..68d820e 100644 --- a/finetuning_glrb/finetune_variant_effect_OMIM.py +++ b/finetuning_glrb/finetune_variant_effect_OMIM.py @@ -286,6 +286,7 @@ class VariantEffectPredictionDataModule(pl.LightningDataModule): def __init__(self, config): super().__init__() self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token self.model_name = config.model_name 
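+        # bp_per_token is used later to convert seq_len (given in base pairs) into a
+        # token-count max_length, which is what k-mer tokenizers (e.g. 6-mers) expect.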
self.train_batch_size = config.train_batch_size self.test_batch_size = config.test_batch_size @@ -366,7 +367,7 @@ def _download_and_preprocess_data(self): desc="Recast chromosome" ) dataset = dataset.map( - partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), batch_size=1000, batched=True, remove_columns=["ref_forward_sequence", "alt_forward_sequence"], diff --git a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py index a5bacf8..29ebd4c 100644 --- a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py +++ b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py @@ -286,6 +286,7 @@ class VariantEffectPredictionDataModule(pl.LightningDataModule): def __init__(self, config): super().__init__() self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token self.model_name = config.model_name self.train_batch_size = config.train_batch_size self.test_batch_size = config.test_batch_size @@ -370,7 +371,7 @@ def _download_and_preprocess_data(self): desc="Recast chromosome" ) dataset = dataset.map( - partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), batch_size=1000, batched=True, remove_columns=["ref_forward_sequence", "alt_forward_sequence"], diff --git a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py index b79afd3..542e6a0 100644 --- a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py +++ b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py @@ -284,6 +284,7 @@ class VariantEffectPredictionDataModule(pl.LightningDataModule): def __init__(self, config): super().__init__() self.seq_len = config.seq_len + self.bp_per_tokem = config.bp_per_token self.model_name = config.model_name self.train_batch_size = config.train_batch_size self.test_batch_size = config.test_batch_size @@ -364,7 +365,7 @@ def _download_and_preprocess_data(self): desc="Recast chromosome" ) dataset = dataset.map( - partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len), + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), batch_size=1000, batched=True, remove_columns=["ref_forward_sequence", "alt_forward_sequence"], From c387303a683662ee020a8212a3ef762b6f39f31e Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Wed, 28 Aug 2024 13:47:00 -0700 Subject: [PATCH 11/13] anonymyze scripts --- finetuning_glrb/finetune.sh | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/finetuning_glrb/finetune.sh b/finetuning_glrb/finetune.sh index 14ff135..b105d6c 100644 --- a/finetuning_glrb/finetune.sh +++ b/finetuning_glrb/finetune.sh @@ -1,18 +1,18 @@ python finetuning_glrb/main.py \ - --task "cage_prediction" \ - --seq_len 12032 \ - --model_name "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species" \ - --bp_per_token 6 \ + --task "bulk_rna_expression" \ + --seq_len 12000 \ + --model_name "your_model_name_from_the_hub" \ + --bp_per_token TBD \ --save_dir "output/" \ - --wandb_api_key "765bad652bcb6ce569641fc334bcf0f0eea5b1fb" \ - --name_wb "cage--ntv2-12k" \ - --train_batch_size 1 \ - --test_batch_size 2 \ + --wandb_api_key "your_wandb_api_key" \ + --name_wb "your_wandb_run_name" \ + --train_batch_size 4 \ + 
--test_batch_size 4 \ --num_workers 6 \ - --num_epochs 100 \ + --num_epochs 1 \ --precision "16-mixed" \ --learning_rate "3e-5" \ - --patience 30 \ + --patience 3 \ --log_interval 512 \ --accumulate_grad_batches 128 \ --train_ratio 1.0 \ From 2d4c56d646952d40cca2a73d54db2e214f9d62e9 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Wed, 28 Aug 2024 18:51:02 -0700 Subject: [PATCH 12/13] [Methodology] - 5-fold cross validation on randomly sampled chromosomes --- finetuning_glrb/finetune.sh | 2 +- finetuning_glrb/finetune_bulk_rna.py | 133 +++++---- finetuning_glrb/finetune_cage.py | 151 ++++++++--- finetuning_glrb/finetune_chromatin.py | 252 ++++++++++++------ .../finetune_regulatory_elements.py | 252 ++++++++++++------ .../finetune_variant_effect_OMIM.py | 147 +++++++--- .../finetune_variant_effect_causal_eqtl.py | 196 +++++++++++--- ...etune_variant_effect_pathogenic_clinvar.py | 131 ++++++--- finetuning_glrb/main.py | 2 - 9 files changed, 882 insertions(+), 384 deletions(-) diff --git a/finetuning_glrb/finetune.sh b/finetuning_glrb/finetune.sh index b105d6c..71a37c5 100644 --- a/finetuning_glrb/finetune.sh +++ b/finetuning_glrb/finetune.sh @@ -8,11 +8,11 @@ python finetuning_glrb/main.py \ --name_wb "your_wandb_run_name" \ --train_batch_size 4 \ --test_batch_size 4 \ + --rcps true \ --num_workers 6 \ --num_epochs 1 \ --precision "16-mixed" \ --learning_rate "3e-5" \ - --patience 3 \ --log_interval 512 \ --accumulate_grad_batches 128 \ --train_ratio 1.0 \ diff --git a/finetuning_glrb/finetune_bulk_rna.py b/finetuning_glrb/finetune_bulk_rna.py index be4d685..76345c2 100644 --- a/finetuning_glrb/finetune_bulk_rna.py +++ b/finetuning_glrb/finetune_bulk_rna.py @@ -4,6 +4,7 @@ import re import torch import torch.nn as nn +import numpy as np from torch.utils.data import DataLoader from transformers import ( AutoModelForMaskedLM, @@ -60,7 +61,7 @@ def recast_chromosome(examples): dict with chromosome recast as integers. 
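+        Both "X" and "Y" map to -1; all other chromosomes keep their integer
+        value, which the chromosome-level cross-validation split relies on.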
""" return { - "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) } class MLP_BulkRNA(nn.Module): @@ -188,7 +189,7 @@ def training_step(self, batch, batch_idx): loss = self.criterion(logits, labels) self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) - # Track predictions and labels for R² score + # Track predictions and labels for R² score"chromosome": -1 if examples["chromosome"] == "X" or "Y" else int(examples["chromosome"]) self.training_step_preds.extend(logits.detach().cpu().numpy()) self.training_step_labels.extend(labels.detach().cpu().numpy()) @@ -208,6 +209,16 @@ def validation_step(self, batch, batch_idx): return loss + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"] + + logits = self.model(ref_input_ids) + + # Track predictions and labels for R² score + self.validation_step_preds.extend(logits.detach().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().cpu().numpy()) + def on_validation_epoch_end(self): # Calculate R² score for validation val_r2 = r2_score(self.validation_step_labels, self.validation_step_preds) @@ -215,6 +226,13 @@ def on_validation_epoch_end(self): self.validation_step_labels.clear() self.validation_step_preds.clear() + def on_test_epoch_end(self): + # Calculate R² score for validation + val_r2 = r2_score(self.validation_step_labels, self.validation_step_preds) + self.log("test/R2", val_r2, on_epoch=True, prog_bar=True, logger=True) + self.validation_step_labels.clear() + self.validation_step_preds.clear() + def on_train_epoch_end(self): # Calculate R² score for training train_r2 = r2_score(self.training_step_labels, self.training_step_preds) @@ -265,8 +283,7 @@ def setup(self, stage=None): self.dataset = load_from_disk(self._get_preprocessed_cache_file()) # Split the dataset into train and validation sets - self.train_dataset = self.dataset["train"] - self.val_dataset = self.dataset["test"] + self.test_dataset = self.dataset["test"] def train_dataloader(self): return DataLoader( @@ -287,6 +304,29 @@ def val_dataloader(self): pin_memory=True, shuffle=False ) + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome def _get_preprocessed_cache_file(self): self.cache_dir = osp.join( @@ -341,51 +381,54 @@ def finetune(args): args: Command line arguments or configuration dictionary. 
""" wandb.login(key=args.wandb_api_key) - wandb_logger = WandbLogger( - name=f"{args.name_wb}-{args.seq_len}", - project="Bulk RNA Expression", - log_model=True, - save_dir=args.save_dir - ) data_module = BulkRNADataModule(args) data_module.setup() - model = Lit_BulkRNAFinetuning(args) + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) - # Callbacks for early stopping and model checkpointing - early_stopping_callback = EarlyStopping( - monitor="val_loss", - patience=args.patience, - verbose=True, - mode="min" - ) + for idx,val_chromosome in enumerate(held_chromosomes): - checkpoint_callback = ModelCheckpoint( - dirpath=f"{args.save_dir}/checkpoints", - filename="best-checkpoint", - save_top_k=1, - verbose=True, - monitor="val_loss", - mode="min" - ) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="Bulk RNA Expression", + log_model=True, + save_dir="./wandb" + ) + data_module._split_dataset(val_chromosome) + model = Lit_BulkRNAFinetuning(args) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) - nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" - - # Set up the PyTorch Lightning Trainer - trainer = pl.Trainer( - max_epochs=args.num_epochs, - devices=nb_device, - logger=wandb_logger, - callbacks=[early_stopping_callback, checkpoint_callback], - log_every_n_steps=1, - limit_train_batches=args.train_ratio, - limit_val_batches=args.eval_ratio, - val_check_interval=args.log_interval, - gradient_clip_val=1.0, - precision=args.precision, - accumulate_grad_batches=args.accumulate_grad_batches, - num_sanity_val_steps=0 - ) + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") - # Start the training process - trainer.fit(model, data_module) + # Finish the current WandB run + wandb.finish() diff --git a/finetuning_glrb/finetune_cage.py b/finetuning_glrb/finetune_cage.py index 2121b34..a825352 100644 --- a/finetuning_glrb/finetune_cage.py +++ b/finetuning_glrb/finetune_cage.py @@ -4,6 +4,7 @@ import re import torch import torch.nn as nn +import numpy from torch.utils.data import DataLoader from transformers import ( AutoModelForMaskedLM, @@ -71,6 +72,18 @@ def tokenize_variants(examples, tokenizer, max_length: int): "ref_input_ids": ref_tokenized["input_ids"], } +def recast_chromosome(examples): + """ + Recast chromosome to integer format. + + Returns: + dict with chromosome recast as integers. 
+ """ + return { + #"chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) + } + class MLP_CAGE(nn.Module): """ Regression head for Bulk RNA prediction task. @@ -196,6 +209,16 @@ def validation_step(self, batch, batch_idx): return loss + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"] + + logits = self.model(ref_input_ids) + + # Track predictions and labels for R² score + self.validation_step_preds.extend(logits.view(-1, logits.size(-1)).detach().cpu().numpy()) + self.validation_step_labels.extend(labels.view(-1, labels.size(-1)).detach().cpu().numpy()) + def on_validation_epoch_end(self): #Log(1+x) normalize the preds and labels self.validation_step_labels = list(np.log1p(self.validation_step_labels)) @@ -207,6 +230,17 @@ def on_validation_epoch_end(self): self.validation_step_labels.clear() self.validation_step_preds.clear() + def on_test_epoch_end(self): + #Log(1+x) normalize the preds and labels + self.validation_step_labels = list(np.log1p(self.validation_step_labels)) + self.validation_step_preds = list(np.log1p(self.validation_step_preds)) + + # Calculate R² score for validation + val_r2 = r2_score(self.validation_step_labels, self.validation_step_preds) + self.log("test/R2", val_r2, on_epoch=True, prog_bar=True, logger=True) + self.validation_step_labels.clear() + self.validation_step_preds.clear() + def on_train_epoch_end(self): #Log(1+x) normalize the preds and labels self.training_step_labels = list(np.log1p(self.training_step_labels)) @@ -263,8 +297,7 @@ def setup(self, stage=None): self.dataset = load_from_disk(self._get_preprocessed_cache_file()) # Split the dataset into train and validation sets - self.train_dataset = self.dataset["train"] - self.val_dataset = self.dataset["test"] + self.test_dataset = self.dataset["test"] def train_dataloader(self): return DataLoader( @@ -285,6 +318,29 @@ def val_dataloader(self): pin_memory=True, shuffle=False ) + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome def _get_preprocessed_cache_file(self): self.cache_dir = osp.join( @@ -313,6 +369,13 @@ def _download_and_preprocess_data(self): lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, desc="Filter N's" ) + + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( partial(tokenize_variants, tokenizer=self.tokenizer, max_length=_closest_multiple_after(self.tokens_per_seq,self.bins_per_seq)), batch_size=1000, @@ -335,51 +398,53 @@ def finetune(args): """ wandb.login(key=args.wandb_api_key) - wandb_logger = WandbLogger( - 
name=f"{args.name_wb}-{args.seq_len}", - project="CAGE Predictions", - log_model=True, - save_dir=args.save_dir - ) data_module = CAGEDataModule(args) data_module.setup() - model = Lit_DNAModelForCAGE(args) + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) - # Callbacks for early stopping and model checkpointing - early_stopping_callback = EarlyStopping( - monitor="val_loss", - patience=args.patience, - verbose=True, - mode="min" - ) + for idx,val_chromosome in enumerate(held_chromosomes): - checkpoint_callback = ModelCheckpoint( - dirpath=f"{args.save_dir}/checkpoints", - filename="best-checkpoint", - save_top_k=1, - verbose=True, - monitor="val_loss", - mode="min" - ) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="CAGE Predictions", + log_model=True, + save_dir="./wandb" + ) - nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" - - # Set up the PyTorch Lightning Trainer - trainer = pl.Trainer( - max_epochs=args.num_epochs, - devices=nb_device, - logger=wandb_logger, - callbacks=[early_stopping_callback, checkpoint_callback], - log_every_n_steps=1, - limit_train_batches=args.train_ratio, - limit_val_batches=args.eval_ratio, - val_check_interval=args.log_interval, - gradient_clip_val=1.0, - precision=args.precision, - accumulate_grad_batches=args.accumulate_grad_batches, - num_sanity_val_steps=0 - ) + data_module._split_dataset(val_chromosome) + model = Lit_DNAModelForCAGE(args) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) - # Start the training process - trainer.fit(model, data_module) + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") + wandb.finish() diff --git a/finetuning_glrb/finetune_chromatin.py b/finetuning_glrb/finetune_chromatin.py index 82b2b06..c803262 100644 --- a/finetuning_glrb/finetune_chromatin.py +++ b/finetuning_glrb/finetune_chromatin.py @@ -4,6 +4,7 @@ import re import torch import torch.nn as nn +import numpy as np from torch.utils.data import DataLoader import wandb import lightning.pytorch as pl @@ -57,7 +58,7 @@ def recast_chromosome(examples): dict with chromosome recast as integers. 
""" return { - "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) } class MLP_ChromatineFeatures(nn.Module): @@ -193,6 +194,17 @@ def validation_step(self, batch, batch_idx): self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) return loss + + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.validation_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) def on_validation_epoch_end(self): # Calculate accuracy, AUPRC, and AUROC for validation @@ -208,6 +220,21 @@ def on_validation_epoch_end(self): self.validation_step_labels.clear() self.validation_step_preds.clear() + + def on_test_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for validation + val_accuracy = accuracy_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + + # Log validation metrics + self.log("test/accuracy", val_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("test/AUPRC", val_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("test/AUROC", val_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() def on_train_epoch_end(self): # Calculate accuracy, AUPRC, and AUROC for training @@ -267,8 +294,7 @@ def setup(self, stage=None): self.dataset = load_from_disk(self._get_preprocessed_cache_file()) # Split the dataset into train and validation sets - self.train_dataset = self.dataset["train"] - self.val_dataset = self.dataset["test"] + self.test_dataset = self.dataset["test"] def train_dataloader(self): return DataLoader( @@ -289,6 +315,30 @@ def val_dataloader(self): pin_memory=True, shuffle=False ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome def _get_preprocessed_cache_file(self): self.cache_dir = osp.join( @@ -376,8 +426,8 @@ def setup(self, stage=None): self.dataset = load_from_disk(self._get_preprocessed_cache_file()) # Split the dataset into train and validation sets - self.train_dataset = self.dataset["train"] - self.val_dataset = self.dataset["test"] + self.test_dataset = self.dataset["test"] + 
def train_dataloader(self): return DataLoader( @@ -398,6 +448,30 @@ def val_dataloader(self): pin_memory=True, shuffle=False ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome def _get_preprocessed_cache_file(self): self.cache_dir = osp.join( @@ -445,6 +519,7 @@ def _download_and_preprocess_data(self): dataset.save_to_disk(self._get_preprocessed_cache_file()) log.warning("Data downloaded and preprocessed successfully.") + def finetune_histone_marks(args): """ Main function to start finetuning on Histone Marks with PyTorch Lightning. @@ -453,53 +528,57 @@ def finetune_histone_marks(args): args: Command line arguments or configuration dictionary. """ wandb.login(key=args.wandb_api_key) - wandb_logger = WandbLogger( - name=f"{args.name_wb}-{args.seq_len}", - project="Histone Marks", - log_model=True # Automatically log model checkpoints - ) data_module = HistoneMarksDataModule(args) data_module.setup() - model = Lit_ChromatinFeatures(args) + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) - # Callbacks for early stopping and model checkpointing - early_stopping_callback = EarlyStopping( - monitor="val_loss", - patience=args.patience, - verbose=True, - mode="min" + for idx,val_chromosome in enumerate(held_chromosomes): + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="Histone Marks", + log_model=True # Automatically log model checkpoints ) + data_module._split_dataset(val_chromosome) - checkpoint_callback = ModelCheckpoint( - dirpath="./checkpoints", - filename="best-checkpoint", - save_top_k=1, - verbose=True, - monitor="val_loss", - mode="min" - ) + model = Lit_ChromatinFeatures(args) - nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" - - # Set up the PyTorch Lightning Trainer - trainer = pl.Trainer( - max_epochs=args.num_epochs, - devices=nb_device, - logger=wandb_logger, - callbacks=[early_stopping_callback, checkpoint_callback], - log_every_n_steps=1, - limit_train_batches=args.train_ratio, - limit_val_batches=args.eval_ratio, - val_check_interval=args.log_interval, - gradient_clip_val=1.0, - precision=args.precision, - accumulate_grad_batches=args.accumulate_grad_batches, - num_sanity_val_steps=0 - ) - # Start the training process - trainer.fit(model, data_module) + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + 
logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") + + # Finish the current WandB run + wandb.finish() def finetune_dna_accessibility(args): """ @@ -509,50 +588,55 @@ def finetune_dna_accessibility(args): args: Command line arguments or configuration dictionary. """ wandb.login(key=args.wandb_api_key) - wandb_logger = WandbLogger( - name=f"{args.name_wb}-{args.seq_len}", - project="DNA Accessibility", - log_model=True # Automatically log model checkpoints - ) data_module = DNAAccessibilityDataModule(args) data_module.setup() + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) - model = Lit_ChromatinFeatures(args) + for idx,val_chromosome in enumerate(held_chromosomes): - # Callbacks for early stopping and model checkpointing - early_stopping_callback = EarlyStopping( - monitor="val_loss", - patience=args.patience, - verbose=True, - mode="min" - ) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="DNA Accessibility", + log_model=True # Automatically log model checkpoints + ) - checkpoint_callback = ModelCheckpoint( - dirpath="./checkpoints", - filename="best-checkpoint", - save_top_k=1, - verbose=True, - monitor="val_loss", - mode="min" - ) + data_module._split_dataset(val_chromosome) - nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" - - # Set up the PyTorch Lightning Trainer - trainer = pl.Trainer( - max_epochs=args.num_epochs, - devices=nb_device, - logger=wandb_logger, - callbacks=[early_stopping_callback, checkpoint_callback], - log_every_n_steps=1, - limit_train_batches=args.train_ratio, - limit_val_batches=args.eval_ratio, - val_check_interval=args.log_interval, - gradient_clip_val=1.0, - precision=args.precision, - accumulate_grad_batches=args.accumulate_grad_batches, - num_sanity_val_steps=0 - ) - # Start the training process - trainer.fit(model, data_module) + model = Lit_ChromatinFeatures(args) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") + + # Finish the current WandB run + wandb.finish() diff --git 
a/finetuning_glrb/finetune_regulatory_elements.py b/finetuning_glrb/finetune_regulatory_elements.py index 73b6fa9..6bc493d 100644 --- a/finetuning_glrb/finetune_regulatory_elements.py +++ b/finetuning_glrb/finetune_regulatory_elements.py @@ -4,6 +4,7 @@ import re import torch import torch.nn as nn +import numpy as np from torch.utils.data import DataLoader from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig import wandb @@ -54,7 +55,7 @@ def recast_chromosome(examples): dict with chromosome recast as integers. """ return { - "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) } class MLP_RegulatoryElements(nn.Module): @@ -188,6 +189,16 @@ def validation_step(self, batch, batch_idx): self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) return loss + + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids).squeeze(-1) + #Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.validation_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) def on_validation_epoch_end(self): # Calculate accuracy, AUPRC, and AUROC for validation @@ -206,6 +217,24 @@ def on_validation_epoch_end(self): self.validation_step_labels.clear() self.validation_step_preds.clear() + + def on_test_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for validation + val_accuracy = accuracy_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + + if self.task == "regulatory_element_enhancer": + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + + # Log validation metrics + self.log("test/accuracy", val_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("test/AUPRC", val_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + if self.task == "regulatory_element_enhancer": + self.log("test/AUROC", val_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() def on_train_epoch_end(self): # Calculate accuracy, AUPRC, and AUROC for training @@ -268,8 +297,7 @@ def setup(self, stage=None): self.dataset = load_from_disk(self._get_preprocessed_cache_file()) # Split the dataset into train and validation sets - self.train_dataset = self.dataset["train"] - self.val_dataset = self.dataset["test"] + self.test_dataset = self.dataset["test"] def train_dataloader(self): return DataLoader( @@ -290,6 +318,30 @@ def val_dataloader(self): pin_memory=True, shuffle=False ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: 
example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome def _get_preprocessed_cache_file(self): self.cache_dir = osp.join( @@ -377,8 +429,7 @@ def setup(self, stage=None): self.dataset = load_from_disk(self._get_preprocessed_cache_file()) # Split the dataset into train and validation sets - self.train_dataset = self.dataset["train"] - self.val_dataset = self.dataset["test"] + self.test_dataset = self.dataset["test"] def train_dataloader(self): return DataLoader( @@ -400,6 +451,30 @@ def val_dataloader(self): shuffle=False ) + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome + def _get_preprocessed_cache_file(self): self.cache_dir = osp.join( "./", "data", "InstaDeepAI___genomics-long-range-benchmark", @@ -454,53 +529,56 @@ def finetune_promoters(args): args: Command line arguments or configuration dictionary. 
""" wandb.login(key=args.wandb_api_key) - wandb_logger = WandbLogger( - name=f"{args.name_wb}-{args.seq_len}", - project="Regulatory Element Promoter", - log_model=True # Automatically log model checkpoints - ) data_module = PromoterDataModule(args) data_module.setup() - model = Lit_RegulatoryElements(args) + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) - # Callbacks for early stopping and model checkpointing - early_stopping_callback = EarlyStopping( - monitor="val_loss", - patience=args.patience, - verbose=True, - mode="min" - ) + for idx,val_chromosome in enumerate(held_chromosomes): - checkpoint_callback = ModelCheckpoint( - dirpath="./checkpoints", - filename="best-checkpoint", - save_top_k=1, - verbose=True, - monitor="val_loss", - mode="min" - ) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="Regulatory Element Promoter", + log_model=True # Automatically log model checkpoints + ) + data_module._split_dataset(val_chromosome) + model = Lit_RegulatoryElements(args) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) - nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" - - # Set up the PyTorch Lightning Trainer - trainer = pl.Trainer( - max_epochs=args.num_epochs, - devices=nb_device, - logger=wandb_logger, - callbacks=[early_stopping_callback, checkpoint_callback], - log_every_n_steps=1, - limit_train_batches=args.train_ratio, - limit_val_batches=args.eval_ratio, - val_check_interval=args.log_interval, - gradient_clip_val=1.0, - precision=args.precision, - accumulate_grad_batches=args.accumulate_grad_batches, - num_sanity_val_steps=0 - ) + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") - # Start the training process - trainer.fit(model, data_module) + # Finish the current WandB run + wandb.finish() def finetune_enhancers(args): """ @@ -510,50 +588,54 @@ def finetune_enhancers(args): args: Command line arguments or configuration dictionary. 
""" wandb.login(key=args.wandb_api_key) - wandb_logger = WandbLogger( - name=f"{args.name_wb}-{args.seq_len}", - project="Regulatory Elements Enhancer", - log_model=True # Automatically log model checkpoints - ) data_module = EnhancerDataModule(args) data_module.setup() - model = Lit_RegulatoryElements(args) + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) - # Callbacks for early stopping and model checkpointing - early_stopping_callback = EarlyStopping( - monitor="val_loss", - patience=args.patience, - verbose=True, - mode="min" - ) + for idx,val_chromosome in enumerate(held_chromosomes): - checkpoint_callback = ModelCheckpoint( - dirpath="./checkpoints", - filename="best-checkpoint", - save_top_k=1, - verbose=True, - monitor="val_loss", - mode="min" - ) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="Regulatory Elements Enhancer", + log_model=True # Automatically log model checkpoints + ) + data_module._split_dataset(val_chromosome) + model = Lit_RegulatoryElements(args) - nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" - - # Set up the PyTorch Lightning Trainer - trainer = pl.Trainer( - max_epochs=args.num_epochs, - devices=nb_device, - logger=wandb_logger, - callbacks=[early_stopping_callback, checkpoint_callback], - log_every_n_steps=1, - limit_train_batches=args.train_ratio, - limit_val_batches=args.eval_ratio, - val_check_interval=args.log_interval, - gradient_clip_val=1.0, - precision=args.precision, - accumulate_grad_batches=args.accumulate_grad_batches, - num_sanity_val_steps=0 - ) - # Start the training process - trainer.fit(model, data_module) + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") + + # Finish the current WandB run + wandb.finish() diff --git a/finetuning_glrb/finetune_variant_effect_OMIM.py b/finetuning_glrb/finetune_variant_effect_OMIM.py index 68d820e..c145021 100644 --- a/finetuning_glrb/finetune_variant_effect_OMIM.py +++ b/finetuning_glrb/finetune_variant_effect_OMIM.py @@ -4,6 +4,7 @@ import re import torch import torch.nn as nn +import numpy as np from torch.utils.data import DataLoader from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig import wandb @@ -62,7 +63,7 @@ def recast_chromosome(examples): dict with recast chromosome as an integer. 
""" return { - "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) } def find_variant_idx(examples): @@ -238,6 +239,24 @@ def validation_step(self, batch, batch_idx): all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() self.validation_step_preds.extend(all_predictions) self.validation_step_labels.extend(all_labels) + + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index) + + preds = torch.argmax(logits, dim=1) + correct = (preds == labels).sum().item() + self.validation_step_correct += correct + self.validation_step_total += len(labels) + + all_labels = labels.cpu().numpy() + all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() + self.validation_step_preds.extend(all_predictions) + self.validation_step_labels.extend(all_labels) def on_validation_epoch_end(self): val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) @@ -255,6 +274,23 @@ def on_validation_epoch_end(self): self.validation_step_preds.clear() self.validation_step_correct = 0 self.validation_step_total = 0 + + def on_test_epoch_end(self): + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_accuracy = self.validation_step_correct / self.validation_step_total + + self.logger.experiment.log({ + "test/AUROC": val_auroc, + "test/Accuracy": val_accuracy, + "test/AUPRC": val_auprc, + }) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + self.validation_step_correct = 0 + self.validation_step_total = 0 def on_train_epoch_end(self): train_auroc = roc_auc_score(self.training_step_labels, self.training_step_preds) @@ -312,8 +348,8 @@ def setup(self, stage=None): self.prepare_data() self.dataset = load_from_disk(self._get_preprocessed_cache_file()) - self.train_dataset = self.dataset["train"] - self.val_dataset = self.dataset["test"] + # Split the dataset into train and validation sets + self.test_dataset = self.dataset["test"] def train_dataloader(self): return DataLoader( @@ -334,6 +370,30 @@ def val_dataloader(self): pin_memory=True, shuffle=True ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome def _get_preprocessed_cache_file(self): self.cache_dir = osp.join( @@ -387,48 +447,53 @@ def finetune(args): args: Command line arguments or configuration dictionary. 
""" wandb.login(key=args.wandb_api_key) - wandb_logger = WandbLogger( - name=f"{args.name}-{args.seq_len}", - project="Variant Effect Prediction OMIM", - log_model=True - ) data_module = VariantEffectPredictionDataModule(args) data_module.setup() - model = Lit_OMIMFinetuning(args) + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) - early_stopping_callback = EarlyStopping( - monitor="val_loss", - patience=args.patience, - verbose=True, - mode="min" - ) + for idx,val_chromosome in enumerate(held_chromosomes): - checkpoint_callback = ModelCheckpoint( - dirpath="./checkpoints", - filename="best-checkpoint", - save_top_k=1, - verbose=True, - monitor="val_loss", - mode="min" - ) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="Variant Effect Prediction OMIM", + log_model=True + ) - nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" - - # Set up the PyTorch Lightning Trainer - trainer = pl.Trainer( - max_epochs=args.num_epochs, - devices=nb_device, - logger=wandb_logger, - callbacks=[early_stopping_callback, checkpoint_callback], - log_every_n_steps=1, - limit_train_batches=args.train_ratio, - limit_val_batches=args.eval_ratio, - val_check_interval=args.log_interval, - gradient_clip_val=1.0, - precision=args.precision, - accumulate_grad_batches=args.accumulate_grad_batches, - num_sanity_val_steps=0 - ) + data_module._split_dataset(val_chromosome) + model = Lit_OMIMFinetuning(args) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") - trainer.fit(model, data_module) + # Finish the current WandB run + wandb.finish() diff --git a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py index 29ebd4c..bf325d5 100644 --- a/finetuning_glrb/finetune_variant_effect_causal_eqtl.py +++ b/finetuning_glrb/finetune_variant_effect_causal_eqtl.py @@ -4,6 +4,7 @@ import re import torch import torch.nn as nn +import numpy as np from torch.utils.data import DataLoader, Dataset from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig import wandb @@ -34,7 +35,7 @@ def recast_chromosome_tissue_dist2TSS(examples): dict with recast chromosome, tissue, and distance to TSS. 
""" return { - "chromosome": -1 if examples["chromosome"] == "X" else int(examples["chromosome"]), + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]), "tissue": examples["tissue"], "distance_to_nearest_tss": examples["distance_to_nearest_tss"] } @@ -256,7 +257,34 @@ def validation_step(self, batch, batch_idx): self.validation_step_labels[i].extend(filtered_labels) self.validation_step_preds[i].extend(filtered_preds) + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + tissue_embed = batch["tissue_embed"] + labels = batch["labels"] + distance_to_nearest_tss = batch["distance_to_nearest_tss"] + + logits = self(alt_input_ids, ref_input_ids, variant_index, tissue_embed) + + # Predictions for AUROC + preds = torch.argmax(logits, dim=1).detach().cpu().numpy() + labels_np = labels.cpu().numpy() + + for i, (min_dist, max_dist) in enumerate(DIST_TO_TSS): + mask = ((distance_to_nearest_tss >= min_dist) & (distance_to_nearest_tss < max_dist)).cpu().numpy() + filtered_preds = preds[mask] + filtered_labels = labels_np[mask] + + if len(filtered_labels) > 0: + self.validation_step_labels[i].extend(filtered_labels) + self.validation_step_preds[i].extend(filtered_preds) + def on_validation_epoch_end(self): + # Initialize lists to store all labels and predictions across all TSS distance buckets + all_labels = [] + all_preds = [] + for i, (min_dist, max_dist) in enumerate(DIST_TO_TSS): if len(self.validation_step_labels[i]) > 0: val_auroc = roc_auc_score(self.validation_step_labels[i], self.validation_step_preds[i]) @@ -270,8 +298,64 @@ def on_validation_epoch_end(self): self.log(f'validation/TSS({min_dist}-{max_dist})/AUPRC', val_auprc, on_epoch=True, sync_dist=True) print(f'Bucket {i} [{min_dist}-{max_dist}] - AUROC: {val_auroc:.4f}') + # Aggregate the labels and predictions + all_labels.extend(self.validation_step_labels[i]) + all_preds.extend(self.validation_step_preds[i]) + + self.validation_step_labels[i].clear() + self.validation_step_preds[i].clear() + + # Compute overall metrics if there are any labels + if all_labels: + overall_auroc = roc_auc_score(all_labels, all_preds) + precision, recall, _ = precision_recall_curve(all_labels, all_preds) + overall_auprc = auc(recall, precision) + overall_accuracy = accuracy_score(all_labels, all_preds) + + # Log overall metrics + self.log('validation/overall/AUROC', overall_auroc, on_epoch=True, sync_dist=True) + self.log('validation/overall/Accuracy', overall_accuracy, on_epoch=True, sync_dist=True) + self.log('validation/overall/AUPRC', overall_auprc, on_epoch=True, sync_dist=True) + print(f'Overall - AUROC: {overall_auroc:.4f}') + + + def on_test_epoch_end(self): + # Initialize lists to store all labels and predictions across all TSS distance buckets + all_labels = [] + all_preds = [] + + for i, (min_dist, max_dist) in enumerate(DIST_TO_TSS): + if len(self.validation_step_labels[i]) > 0: + val_auroc = roc_auc_score(self.validation_step_labels[i], self.validation_step_preds[i]) + precision, recall, _ = precision_recall_curve(self.validation_step_labels[i], self.validation_step_preds[i]) + val_auprc = auc(recall, precision) + val_accuracy = accuracy_score(self.validation_step_labels[i], self.validation_step_preds[i]) + + # Log metrics for each TSS distance bucket + self.log(f'test/TSS({min_dist}-{max_dist})/AUROC', val_auroc, on_epoch=True, sync_dist=True) + self.log(f'test/TSS({min_dist}-{max_dist})/Accuracy', 
val_accuracy, on_epoch=True, sync_dist=True) + self.log(f'test/TSS({min_dist}-{max_dist})/AUPRC', val_auprc, on_epoch=True, sync_dist=True) + print(f'Bucket {i} [{min_dist}-{max_dist}] - AUROC: {val_auroc:.4f}') + + # Aggregate the labels and predictions + all_labels.extend(self.validation_step_labels[i]) + all_preds.extend(self.validation_step_preds[i]) + self.validation_step_labels[i].clear() self.validation_step_preds[i].clear() + + # Compute overall metrics if there are any labels + if all_labels: + overall_auroc = roc_auc_score(all_labels, all_preds) + precision, recall, _ = precision_recall_curve(all_labels, all_preds) + overall_auprc = auc(recall, precision) + overall_accuracy = accuracy_score(all_labels, all_preds) + + # Log overall metrics + self.log('test/overall/AUROC', overall_auroc, on_epoch=True, sync_dist=True) + self.log('test/overall/Accuracy', overall_accuracy, on_epoch=True, sync_dist=True) + self.log('test/overall/AUPRC', overall_auprc, on_epoch=True, sync_dist=True) + print(f'Overall - AUROC: {overall_auroc:.4f}') def configure_optimizers(self): return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) @@ -316,8 +400,7 @@ def setup(self, stage=None): self.dataset = load_from_disk(self._get_preprocessed_cache_file()) # Split the dataset into train and validation sets - self.train_dataset = self.dataset["train"] - self.val_dataset = self.dataset["test"] + self.test_dataset = self.dataset["test"] def train_dataloader(self): return DataLoader( @@ -338,6 +421,30 @@ def val_dataloader(self): pin_memory=True, shuffle=False ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome def _get_preprocessed_cache_file(self): cache_dir = osp.join( @@ -400,49 +507,54 @@ def finetune(args): args: Command line arguments or configuration dictionary. 
""" wandb.login(key=args.wandb_api_key) - wandb_logger = WandbLogger( - name=f"{args.name}-{args.seq_len}", - project="Variant Effect Prediction Causal eQTL", - log_model=True # Automatically log model checkpoints - ) data_module = VariantEffectPredictionDataModule(args) data_module.setup() - model = LitVEPFinetuning(args) + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) - early_stopping_callback = EarlyStopping( - monitor="val_loss", - patience=args.patience, - verbose=True, - mode="min" - ) + for idx,val_chromosome in enumerate(held_chromosomes): - checkpoint_callback = ModelCheckpoint( - dirpath="./checkpoints", - filename="best-checkpoint", - save_top_k=1, - verbose=True, - monitor="val_loss", - mode="min" - ) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="Variant Effect Prediction Causal eQTL", + log_model=True # Automatically log model checkpoints + ) - nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" - - # Set up the PyTorch Lightning Trainer - trainer = pl.Trainer( - max_epochs=args.num_epochs, - devices=nb_device, - logger=wandb_logger, - callbacks=[early_stopping_callback, checkpoint_callback], - log_every_n_steps=1, - limit_train_batches=args.train_ratio, - limit_val_batches=args.eval_ratio, - val_check_interval=args.log_interval, - gradient_clip_val=1.0, - precision=args.precision, - accumulate_grad_batches=args.accumulate_grad_batches, - num_sanity_val_steps=0 - ) + data_module._split_dataset(val_chromosome) + model = LitVEPFinetuning(args) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") - # Start the training process - trainer.fit(model, data_module) + # Finish the current WandB run + wandb.finish() diff --git a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py index 542e6a0..da83a2b 100644 --- a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py +++ b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py @@ -4,6 +4,7 @@ import re import torch import torch.nn as nn +import numpy as np from torch.utils.data import DataLoader from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig import wandb @@ -62,7 +63,7 @@ def recast_chromosome(examples): dict with recast chromosome as an integer. 
""" return { - "chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) } def find_variant_idx(examples): @@ -236,6 +237,23 @@ def validation_step(self, batch, batch_idx): all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() self.validation_step_preds.extend(all_predictions) self.validation_step_labels.extend(all_labels) + + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index) + preds = torch.argmax(logits, dim=1) + correct = (preds == labels).sum().item() + self.validation_step_correct += correct + self.validation_step_total += len(labels) + + all_labels = labels.cpu().numpy() + all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() + self.validation_step_preds.extend(all_predictions) + self.validation_step_labels.extend(all_labels) def on_validation_epoch_end(self): val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) @@ -253,6 +271,23 @@ def on_validation_epoch_end(self): self.validation_step_preds.clear() self.validation_step_correct = 0 self.validation_step_total = 0 + + def on_test_epoch_end(self): + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_accuracy = self.validation_step_correct / self.validation_step_total + + self.logger.experiment.log({ + "test/AUROC": val_auroc, + "test/Accuracy": val_accuracy, + "test/AUPRC": val_auprc, + }) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + self.validation_step_correct = 0 + self.validation_step_total = 0 def on_train_epoch_end(self): train_auroc = roc_auc_score(self.training_step_labels, self.training_step_preds) @@ -310,8 +345,8 @@ def setup(self, stage=None): self.prepare_data() self.dataset = load_from_disk(self._get_preprocessed_cache_file()) - self.train_dataset = self.dataset["train"] - self.val_dataset = self.dataset["test"] + # Split the dataset into train and validation sets + self.test_dataset = self.dataset["test"] def train_dataloader(self): return DataLoader( @@ -333,6 +368,16 @@ def val_dataloader(self): shuffle=True ) + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + def _get_preprocessed_cache_file(self): self.cache_dir = osp.join( "./", "data", "InstaDeepAI___genomics-long-range-benchmark", @@ -385,48 +430,52 @@ def finetune(args): args: Command line arguments or configuration dictionary. 
""" wandb.login(key=args.wandb_api_key) - wandb_logger = WandbLogger( - name=f"{args.name_wb}-{args.seq_len}", - project="Variant Effect Prediction ClinVar", - log_model=True - ) data_module = VariantEffectPredictionDataModule(args) data_module.setup() - model = Lit_ClinVarFinetuning(args) + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) - early_stopping_callback = EarlyStopping( - monitor="val_loss", - patience=args.patience, - verbose=True, - mode="min" - ) + for idx,val_chromosome in enumerate(held_chromosomes): - checkpoint_callback = ModelCheckpoint( - dirpath="./checkpoints", - filename="best-checkpoint", - save_top_k=1, - verbose=True, - monitor="val_loss", - mode="min" - ) + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="Variant Effect Prediction ClinVar", + log_model=True + ) + data_module._split_dataset(val_chromosome) + model = Lit_ClinVarFinetuning(args) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) - nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" - - # Set up the PyTorch Lightning Trainer - trainer = pl.Trainer( - max_epochs=args.num_epochs, - devices=nb_device, - logger=wandb_logger, - callbacks=[early_stopping_callback, checkpoint_callback], - log_every_n_steps=1, - limit_train_batches=args.train_ratio, - limit_val_batches=args.eval_ratio, - val_check_interval=args.log_interval, - gradient_clip_val=1.0, - precision=args.precision, - accumulate_grad_batches=args.accumulate_grad_batches, - num_sanity_val_steps=0 - ) + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") - trainer.fit(model, data_module) + # Finish the current WandB run + wandb.finish() diff --git a/finetuning_glrb/main.py b/finetuning_glrb/main.py index 8c1c8de..2b61e46 100644 --- a/finetuning_glrb/main.py +++ b/finetuning_glrb/main.py @@ -87,8 +87,6 @@ def main(opts): parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Accumulate gradients") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for optimizer") - parser.add_argument("--patience", type=int, default=10, - help="Number of epochs with no improvement after which training will be stopped") parser.add_argument("--log_interval", type=int, default=5, help="Log interval") parser.add_argument("--train_ratio", type=float, default=1.0, From 2953fe8bdda79c18ecddb564b9afa062fb65d775 Mon Sep 17 00:00:00 2001 From: Aymen Kallala Date: Wed, 28 Aug 2024 18:51:56 -0700 Subject: [PATCH 13/13] readme --- finetuning_glrb/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/finetuning_glrb/README.md 
b/finetuning_glrb/README.md
index 8408034..30915b1 100644
--- a/finetuning_glrb/README.md
+++ b/finetuning_glrb/README.md
@@ -38,8 +38,6 @@ To fine-tune a model, execute the `finetune.sh` script. The script runs the `mai
 **`--learning_rate`**: Specifies the learning rate for the optimizer.
 
-**`--patience`**: Determines the number of epochs with no improvement after which training will be stopped.
-
 **`--log_interval`**: Interval (in steps) at which to log training metrics and run a validation step.
 
 **`--train_ratio`**: Specifies the ratio of the dataset to use for training.
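
The per-task `finetune*` functions in this patch all follow the same cross-validation recipe: five chromosomes are drawn at random (NumPy seed 0) from the training split, and each fold holds one of them out as the validation set while training on the rest. The sketch below isolates that split logic; the helper names (`select_held_out_chromosomes`, `split_by_chromosome`) and the toy dataset are illustrative placeholders, not identifiers from the patch.

```python
import numpy as np
from datasets import Dataset


def select_held_out_chromosomes(train_dataset, n_folds=5, seed=0):
    """Pick n_folds distinct chromosomes to serve as per-fold validation sets."""
    np.random.seed(seed)
    candidates = np.unique(train_dataset["chromosome"])
    return np.random.choice(candidates, n_folds, replace=False)


def split_by_chromosome(train_dataset, val_chromosome):
    """Rows on val_chromosome become validation; all other rows stay in train."""
    train = train_dataset.filter(lambda ex: ex["chromosome"] != val_chromosome)
    val = train_dataset.filter(lambda ex: ex["chromosome"] == val_chromosome)
    return train, val


# Toy example: five rows, one per chromosome.
toy = Dataset.from_dict({"chromosome": [1, 2, 3, 4, 5], "label": [0, 1, 0, 1, 0]})
for fold, chrom in enumerate(select_held_out_chromosomes(toy, n_folds=2), start=1):
    train_split, val_split = split_by_chromosome(toy, chrom)
    print(f"fold {fold}: held-out chromosome {chrom}, "
          f"{len(train_split)} train rows, {len(val_split)} val rows")
```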
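
Within each fold, training runs with a `ModelCheckpoint` that tracks the lowest `val_loss`, and the resulting best checkpoint is then evaluated on the held-out test dataloader with `trainer.test(..., ckpt_path=...)`. A compact sketch of that per-fold loop follows; `run_fold` is an illustrative wrapper, and it resolves the checkpoint via `ModelCheckpoint.best_model_path` instead of reassembling the file path by hand.

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint


def run_fold(model, data_module, save_dir, val_chromosome, max_epochs=10):
    """Train one fold to the best val_loss, then test that same checkpoint."""
    checkpoint = ModelCheckpoint(
        dirpath=f"{save_dir}/checkpoints",
        filename=f"best-checkpoint-on-chromosome{val_chromosome}",
        save_top_k=1,
        monitor="val_loss",
        mode="min",
    )
    trainer = pl.Trainer(max_epochs=max_epochs, callbacks=[checkpoint])
    trainer.fit(model, datamodule=data_module)
    # Evaluate the checkpoint that achieved the lowest validation loss.
    trainer.test(model, datamodule=data_module, ckpt_path=checkpoint.best_model_path)
```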
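
For the regression tasks (bulk RNA expression and CAGE), the new `on_test_epoch_end` hooks report an R² computed on log(1 + x) transformed predictions and labels. Stripped of the Lightning plumbing, the metric reduces to the following; the array values are made up for illustration.

```python
import numpy as np
from sklearn.metrics import r2_score

# Illustrative predictions and targets on the raw expression scale.
preds = np.array([12.0, 0.5, 300.0, 45.0])
labels = np.array([10.0, 1.0, 280.0, 50.0])

# Both arrays are log(1 + x) transformed before scoring, mirroring the patch.
test_r2 = r2_score(np.log1p(labels), np.log1p(preds))
print(f"test/R2 = {test_r2:.4f}")
```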