diff --git a/finetuning_glrb/README.md b/finetuning_glrb/README.md new file mode 100644 index 0000000..30915b1 --- /dev/null +++ b/finetuning_glrb/README.md @@ -0,0 +1,54 @@ +# Fine-Tuning DNA Models on the Genomics Long Range Benchmark 🧬 + +This folder contains the necessary scripts and configurations to fine-tune DNA models on the [Genomics Long Range Benchmark](https://huggingface.co/datasets/InstaDeepAI/genomics-long-range-benchmark). + +DNA Models are loaded from the Hugging-Face Hub 🤗. + +## Getting Started + +To fine-tune a model, execute the `finetune.sh` script. The script runs the `main.py` script with various command-line arguments that configure the fine-tuning process. Below is a description of each argument used. + +**`--task`**: Choose one of the predefined variant effects. Options include: `"variant_effect_causal_eqtl"`, `"variant_effect_pathogenic_clinvar"`, `"variant_effect_pathogenic_omim"`, `"cage_prediction"`, `"bulk_rna_expression"`, `"chromatin_features_histone_marks"`, `"chromatin_features_dna_accessibility"`, `"regulatory_element_promoter"`, `"regulatory_element_enhancer"`. + +**`--seq_len`**: Specifies the sequence length in base pairs (bp). + +**`--model_name`**: Name of the pre-trained model to fine-tune (on the HF hub). + +**`--bp_per_token`**: Defines the number of base pairs per token used in the tokenization process of the model. + +**`--save_dir`**: Directory where the checkpoints and logs will be saved. + +**`--wandb_api_key`**: API key for Weights & Biases logging. + +**`--name_wb`**: Name for the Weights & Biases run. + +**`--train_batch_size`**: Defines the batch size for training. + +**`--test_batch_size`**: Defines the batch size for testing/validation. + +**`--num_workers`**: Number of workers to use for data loading. + +**`--rcps`**: Indicates whether to use RCPS when extracting embeddings. + +**`--num_epochs`**: Specifies the number of epochs to train. + +**`--precision`**: Choose the precision. 
Options include: `"transformer-engine"`, `"transformer-engine-float16"`, `"16-true"`, `"16-mixed"`, `"bf16-true"`, `"bf16-mixed"`, `"32-true"`, `"64-true"`. + +**`--accumulate_grad_batches`**: Number of batches for which to accumulate gradients accross devices. + +**`--learning_rate`**: Specifies the learning rate for the optimizer. + +**`--log_interval`**: Interval (in steps) at which to log training metrics and run a validation step. + +**`--train_ratio`**: Specifies the ratio of the dataset to use for training. + +**`--eval_ratio`**: Specifies the ratio of the dataset to use for evaluation. + + + +### Running the Script + +To start finetuning, first make sure that you have modified the `finetune.sh` script with the correct parameters for your task. Then, simply run: + +```bash +bash finetune.sh diff --git a/finetuning_glrb/__init__.py b/finetuning_glrb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/finetuning_glrb/finetune.sh b/finetuning_glrb/finetune.sh new file mode 100644 index 0000000..71a37c5 --- /dev/null +++ b/finetuning_glrb/finetune.sh @@ -0,0 +1,35 @@ +python finetuning_glrb/main.py \ + --task "bulk_rna_expression" \ + --seq_len 12000 \ + --model_name "your_model_name_from_the_hub" \ + --bp_per_token TBD \ + --save_dir "output/" \ + --wandb_api_key "your_wandb_api_key" \ + --name_wb "your_wandb_run_name" \ + --train_batch_size 4 \ + --test_batch_size 4 \ + --rcps true \ + --num_workers 6 \ + --num_epochs 1 \ + --precision "16-mixed" \ + --learning_rate "3e-5" \ + --log_interval 512 \ + --accumulate_grad_batches 128 \ + --train_ratio 1.0 \ + --eval_ratio 1.0 + +##Examples + +## Caduceus-PS +#task=bulk_rna_expression +#seq_len=131000 +#bp_per_token=1 +#model_name="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16" +#rcps=true + +## NTv2 +#task=regulatory_element_promoter +#seq_len=12288 # 2048 (seq len) * 6 (kmers) +#bp_per_token=6 +#model_name="InstaDeepAI/nucleotide-transformer-v2-50m-multi-species" +#delete the rcps 
def tokenize_variants(examples, tokenizer, max_length: int):
    """Tokenize the reference sequences of a batch of dataset items.

    Args:
        examples: Batch of items exposing a "sequence" field.
        tokenizer: Tokenizer exposing ``batch_encode_plus``.
        max_length: Truncation length, in tokens.

    Returns:
        Dict mapping "ref_input_ids" to the per-sequence token-id lists.
    """
    encoded = tokenizer.batch_encode_plus(
        examples["sequence"],
        add_special_tokens=False,
        return_attention_mask=False,
        truncation=True,
        max_length=max_length,
    )
    return {"ref_input_ids": encoded["input_ids"]}
+ """ + return { + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) + } + +class MLP_BulkRNA(nn.Module): + """ + Regression head for Bulk RNA prediction task. + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. + """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_BulkRNA, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc2(self.sp1(self.fc1(x))) + +class DNAModelForBulkRNA(nn.Module): + """ + DNA Model for Bulk RNA prediction. + + Args: + args: Arguments containing model configurations. + """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + # Load the appropriate backbone model based on the model name + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim founded for the Foundation Model: {self.inner_dim}") + self.head = MLP_BulkRNA(input_size=self.inner_dim, hidden_size=2*self.inner_dim, output_size=218) + + def forward(self, input_ids): + # Get embeddings from the backbone + embeds_out = self.backbone(input_ids)[0] + num_channels = embeds_out.size(-1) + + # Calculate the window size and indexes + window_size_upstream = WINDOW_SIZE_BP_UPSTREAM // self.bp_per_token + window_size_downstream = WINDOW_SIZE_BP_DOWNSTREAM // self.bp_per_token + start, end = -window_size_downstream, window_size_upstream + + if 
class DNAModelForBulkRNA(nn.Module):
    """
    DNA Model for Bulk RNA prediction.

    Wraps a pre-trained DNA foundation model (loaded from the HF hub) and an
    MLP regression head predicting 218 expression targets from embeddings
    pooled over a window around the sequence centre.

    Args:
        args: Arguments containing model configurations (``model_name``,
            ``rcps``, ``bp_per_token``).
    """
    def __init__(self, args):
        super().__init__()
        self.rcps = args.rcps
        self.bp_per_token = args.bp_per_token
        self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True)

        # Load the appropriate backbone model based on the model name.
        # NT checkpoints expose the encoder under `.esm`; others load directly.
        if "nucleotide-transformer" in args.model_name.lower():
            self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm
        else:
            self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True)

        print(f"MODEL LOADED: {self.backbone}")
        # Probe the backbone once to discover its embedding width.
        self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps)
        print(f"Inner dim founded for the Foundation Model: {self.inner_dim}")
        self.head = MLP_BulkRNA(input_size=self.inner_dim, hidden_size=2*self.inner_dim, output_size=218)

    def forward(self, input_ids):
        # Get embeddings from the backbone: (batch, seq_len, channels).
        embeds_out = self.backbone(input_ids)[0]
        num_channels = embeds_out.size(-1)

        # Window, in tokens, around the sequence centre.
        # NOTE(review): `start` uses the DOWNSTREAM width and `end` the
        # UPSTREAM width; assumes the gene/TSS sits at the centre of the
        # input sequence -- TODO confirm against the dataset layout.
        window_size_upstream = WINDOW_SIZE_BP_UPSTREAM // self.bp_per_token
        window_size_downstream = WINDOW_SIZE_BP_DOWNSTREAM // self.bp_per_token
        start, end = -window_size_downstream, window_size_upstream

        if self.rcps:
            # RC-equivariant model: first half of the channels is the forward
            # strand, second half the reverse-complement strand.
            embeds = embeds_out[..., :num_channels // 2]
            # Window indices are offsets from the sequence midpoint, clamped to
            # valid positions.
            expanded_indices = torch.arange(start, end, device=embeds.device).unsqueeze(0).expand(embeds.size(0), -1) + embeds.size(1) // 2
            expanded_indices = torch.clamp(expanded_indices, 0, embeds.size(1) - 1)

            # Extract the relevant window from the embeddings and average it
            # over the sequence-length dimension.
            tokens_window_ref = torch.gather(
                embeds, 1,
                expanded_indices.unsqueeze(-1).expand(-1, -1, embeds.size(2))
            )
            tokens_window_ref = tokens_window_ref.mean(dim=1)

            # Same pooling for the RC half, after flipping sequence and
            # channel axes to align it with the forward strand.
            rc_embeds = embeds_out[..., num_channels // 2:].contiguous().flip(dims=[1, 2])
            expanded_indices = torch.arange(start, end, device=rc_embeds.device).unsqueeze(0).expand(rc_embeds.size(0), -1) + rc_embeds.size(1) // 2
            expanded_indices = torch.clamp(expanded_indices, 0, rc_embeds.size(1) - 1)

            tokens_window_rc = torch.gather(
                rc_embeds, 1,
                expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds.size(2))
            )
            tokens_window_rc = tokens_window_rc.mean(dim=1)

            # Combine the forward and RC-strand pooled embeddings by summation.
            aggregated_embeds = tokens_window_rc + tokens_window_ref
            return self.head(aggregated_embeds)

        else:
            # No reverse-complement processing: pool the full-width embeddings
            # over the same centred window.
            expanded_indices = torch.arange(start, end, device=embeds_out.device).unsqueeze(0).expand(embeds_out.size(0), -1) + embeds_out.size(1) // 2
            expanded_indices = torch.clamp(expanded_indices, 0, embeds_out.size(1) - 1)

            # Extract the relevant window and average over sequence length.
            tokens_window_ref = torch.gather(
                embeds_out, 1,
                expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_out.size(2))
            )
            tokens_window_ref = tokens_window_ref.mean(dim=1)
            return self.head(tokens_window_ref)
+ """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.setup() + + def setup(self,stage=None): + self.rcps = self.hparams.rcps + self.model = DNAModelForBulkRNA(self.hparams) + self.criterion = nn.MSELoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + + def forward(self, ref_input_ids): + return self.model(ref_input_ids) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"] + + logits = self.model(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for R² score"chromosome": -1 if examples["chromosome"] == "X" or "Y" else int(examples["chromosome"]) + self.training_step_preds.extend(logits.detach().cpu().numpy()) + self.training_step_labels.extend(labels.detach().cpu().numpy()) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"] + + logits = self.model(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for R² score + self.validation_step_preds.extend(logits.detach().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().cpu().numpy()) + + return loss + + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"] + + logits = self.model(ref_input_ids) + + # Track predictions and labels for R² score + self.validation_step_preds.extend(logits.detach().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().cpu().numpy()) + + def on_validation_epoch_end(self): + # Calculate R² score for validation + val_r2 = r2_score(self.validation_step_labels, self.validation_step_preds) + self.log("validation/R2", 
val_r2, on_epoch=True, prog_bar=True, logger=True) + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def on_test_epoch_end(self): + # Calculate R² score for validation + val_r2 = r2_score(self.validation_step_labels, self.validation_step_preds) + self.log("test/R2", val_r2, on_epoch=True, prog_bar=True, logger=True) + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def on_train_epoch_end(self): + # Calculate R² score for training + train_r2 = r2_score(self.training_step_labels, self.training_step_preds) + self.log("train/R2", train_r2, on_epoch=True, prog_bar=True, logger=True) + self.training_step_labels.clear() + self.training_step_preds.clear() + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class BulkRNADataModule(pl.LightningDataModule): + """ + Data module for Bulk RNA finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. 
+ """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.test_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME 
{selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "variant_effect_pathogenic_clinvar", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning("Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="bulk_rna_expression", + sequence_length=self.seq_len, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + # Process data: filter sequences with too many 'N's, recast chromosomes, and tokenize + dataset = dataset.filter( + lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), + batch_size=1000, + batched=True, + remove_columns=["sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + + # Save processed dataset to disk + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning("Data downloaded and preprocessed successfully.") + +def finetune(args): + """ + Main function to start the training process using PyTorch Lightning. 
def finetune(args):
    """
    Main function to start the training process using PyTorch Lightning.

    Runs 5-fold cross-validation where each fold holds out one randomly
    chosen chromosome as the validation set; each fold gets its own W&B run
    and checkpoint.

    Args:
        args: Command line arguments or configuration dictionary.
    """
    wandb.login(key=args.wandb_api_key)
    data_module = BulkRNADataModule(args)
    data_module.setup()

    # Fixed seed so the held-out chromosomes are reproducible across runs.
    # NOTE(review): chromosome -1 stands for both X and Y (see
    # recast_chromosome), so one "fold" may hold out both sex chromosomes.
    np.random.seed(0)
    candidates = np.unique(data_module.dataset["train"]["chromosome"])
    held_chromosomes = np.random.choice(candidates,5,replace = False)

    for idx,val_chromosome in enumerate(held_chromosomes):

        wandb_logger = WandbLogger(
            name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}",
            project="Bulk RNA Expression",
            log_model=True,
            save_dir="./wandb"
        )
        data_module._split_dataset(val_chromosome)
        # Fresh model per fold so folds do not share weights.
        model = Lit_BulkRNAFinetuning(args)

        checkpoint_callback = ModelCheckpoint(
            dirpath=f"{args.save_dir}/checkpoints",
            filename=f"best-checkpoint-on-chromosome{val_chromosome}",
            save_top_k=1,
            verbose=True,
            monitor="val_loss",
            mode="min"
        )

        # NOTE(review): NT models are pinned to a single device here,
        # presumably to avoid a multi-device issue -- confirm why.
        nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto"

        # Set up the PyTorch Lightning Trainer
        trainer = pl.Trainer(
            max_epochs=args.num_epochs,
            devices=nb_device,
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            log_every_n_steps=1,
            limit_train_batches=args.train_ratio,
            limit_val_batches=args.eval_ratio,
            val_check_interval=args.log_interval,
            gradient_clip_val=1.0,
            precision=args.precision,
            accumulate_grad_batches=args.accumulate_grad_batches,
            num_sanity_val_steps=0
        )

        # Train, then evaluate the best checkpoint of this fold on the test set.
        # NOTE(review): the "./" prefix assumes save_dir is a relative path;
        # an absolute save_dir would produce a wrong ckpt_path -- TODO confirm.
        trainer.fit(model, data_module)
        trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt")

        # Finish the current WandB run before starting the next fold.
        wandb.finish()
import DataLoader +from transformers import ( + AutoModelForMaskedLM, + AutoModel, + AutoTokenizer, + AutoConfig, + DefaultDataCollator, +) +import numpy as np +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from datasets import load_dataset, load_from_disk +from sklearn.metrics import r2_score +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension + +# Logger setup +log = get_logger(__name__) + +# Constants for the upstream and downstream window sizes +BIN_SIZE = 128 + +def _closest_multiple_after(after, multiple_of): + """ + Computes the closest multiple of a certain int after a certain int. E.g the closest multiple of 2 after 9 is 10. + + Args: + after: int. + multiple_of: int + + Returns: + the closest multiple found. + """ + remainder = after % multiple_of + if remainder == 0: + return after + else: + return after + (multiple_of - remainder) + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize sequence. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs. + """ + ref_tokenized = tokenizer.batch_encode_plus( + examples["sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + padding = "max_length" + ) + + return { + "ref_input_ids": ref_tokenized["input_ids"], + } + +def recast_chromosome(examples): + """ + Recast chromosome to integer format. + + Returns: + dict with chromosome recast as integers. 
+ """ + return { + #"chromosome": -1 if examples["chromosome"] == "X" else -2 if examples["chromosome"] == "Y" else int(examples["chromosome"]) + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) + } + +class MLP_CAGE(nn.Module): + """ + Regression head for Bulk RNA prediction task. + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. + """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_CAGE, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc2(self.sp1(self.fc1(x))) + +class DNAModelForCAGE(nn.Module): + """ + DNA Model for CAGE prediction. + + Args: + args: Arguments containing model configurations. + """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.initial_seq_len = args.seq_len + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + self.num_bins= self.initial_seq_len // BIN_SIZE + + # Load the appropriate backbone model based on the model name + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim founded for the Foundation Model: {self.inner_dim}") + self.head = MLP_CAGE(input_size=self.inner_dim, hidden_size=2*self.inner_dim, output_size=50) + #self.heads = nn.ModuleList([ + #MLP_CAGE(input_size=self.inner_dim, hidden_size=2*self.inner_dim, output_size=50) + #for _ in range(self.num_bins) + #]) + + def forward(self, input_ids): + # Get embeddings from 
class DNAModelForCAGE(nn.Module):
    """
    DNA Model for CAGE prediction.

    Wraps a pre-trained DNA foundation model and applies an MLP head to the
    mean embedding of each 128-bp bin, predicting 50 CAGE tracks per bin.

    Args:
        args: Arguments containing model configurations (``model_name``,
            ``rcps``, ``seq_len``, ``bp_per_token``).
    """
    def __init__(self, args):
        super().__init__()
        self.rcps = args.rcps
        self.initial_seq_len = args.seq_len
        self.bp_per_token = args.bp_per_token
        self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True)
        self.num_bins = self.initial_seq_len // BIN_SIZE

        # Load the appropriate backbone model based on the model name.
        # NT checkpoints expose the encoder under `.esm`; others load directly.
        if "nucleotide-transformer" in args.model_name.lower():
            self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm
        else:
            self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True)

        print(f"MODEL LOADED: {self.backbone}")
        self.inner_dim = get_last_embedding_dimension(self.backbone, self.rcps)
        print(f"Inner dim found for the Foundation Model: {self.inner_dim}")
        self.head = MLP_CAGE(input_size=self.inner_dim, hidden_size=2 * self.inner_dim, output_size=50)

    def forward(self, input_ids):
        # Get embeddings from the backbone: (batch, seq_len, channels).
        embeds_out = self.backbone(input_ids)[0]
        batch_size, seq_len, num_channels = embeds_out.size()
        bin_size = seq_len // self.num_bins
        # Group token embeddings into fixed-size bins along the sequence axis.
        bins = embeds_out.view(batch_size, self.num_bins, bin_size, num_channels)

        if self.rcps:
            # RC-equivariant model: first half of the channels is the forward
            # strand; the second half is flipped (bins, within-bin positions,
            # channels) to align it with the forward strand, then summed in.
            fwd = bins[..., :num_channels // 2]
            rc = bins[..., num_channels // 2:].contiguous().flip(dims=[1, 2, 3])
            bins = fwd + rc

        # Mean-pool tokens within each bin and apply the head to all bins at
        # once: nn.Linear broadcasts over leading dimensions, so this replaces
        # the previous per-bin Python loop with identical results.
        return self.head(bins.mean(dim=2))
class Lit_DNAModelForCAGE(pl.LightningModule):
    """
    PyTorch Lightning module for fine-tuning a DNA model on CAGE prediction.

    Logs MSE loss per step and an epoch-level R^2 score (computed on
    log1p-normalized values) for train/val/test.

    Args:
        args: Arguments containing model and training configurations.
    """
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(args)
        self.setup()

    def setup(self, stage=None):
        # Build the model and the per-epoch buffers used for the R^2 metrics.
        self.rcps = self.hparams.rcps
        self.model = DNAModelForCAGE(self.hparams)
        self.criterion = nn.MSELoss()
        self.validation_step_preds = []
        self.validation_step_labels = []
        self.training_step_preds = []
        self.training_step_labels = []

    def forward(self, ref_input_ids):
        return self.model(ref_input_ids)

    def training_step(self, batch, batch_idx):
        ref_input_ids = batch["ref_input_ids"]
        labels = batch["labels"]

        logits = self.model(ref_input_ids)
        loss = self.criterion(logits, labels)
        self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True)

        # Flatten (batch, bins, tracks) -> (batch*bins, tracks) and buffer
        # predictions/labels for the epoch-level R^2 score.
        self.training_step_preds.extend(logits.view(-1, logits.size(-1)).detach().cpu().numpy())
        self.training_step_labels.extend(labels.view(-1, labels.size(-1)).detach().cpu().numpy())

        return loss

    def validation_step(self, batch, batch_idx):
        ref_input_ids = batch["ref_input_ids"]
        labels = batch["labels"]

        logits = self.model(ref_input_ids)
        loss = self.criterion(logits, labels)
        self.log('val_loss', loss, on_epoch=True, on_step=True, sync_dist=True)

        # Buffer flattened predictions/labels for the epoch-level R^2 score.
        self.validation_step_preds.extend(logits.view(-1, logits.size(-1)).detach().cpu().numpy())
        self.validation_step_labels.extend(labels.view(-1, labels.size(-1)).detach().cpu().numpy())

        return loss

    def test_step(self, batch, batch_idx):
        ref_input_ids = batch["ref_input_ids"]
        labels = batch["labels"]

        logits = self.model(ref_input_ids)

        # Test predictions deliberately reuse the validation buffers;
        # on_test_epoch_end drains them.
        self.validation_step_preds.extend(logits.view(-1, logits.size(-1)).detach().cpu().numpy())
        self.validation_step_labels.extend(labels.view(-1, labels.size(-1)).detach().cpu().numpy())

    def on_validation_epoch_end(self):
        # log1p-normalize predictions and labels before scoring.
        # NOTE(review): log1p is NaN for values < -1; assumes the head outputs
        # non-negative CAGE signals -- TODO confirm.
        self.validation_step_labels = list(np.log1p(self.validation_step_labels))
        self.validation_step_preds = list(np.log1p(self.validation_step_preds))

        # Calculate R^2 score over the whole validation epoch.
        val_r2 = r2_score(self.validation_step_labels, self.validation_step_preds)
        self.log("validation/R2", val_r2, on_epoch=True, prog_bar=True, logger=True)
        self.validation_step_labels.clear()
        self.validation_step_preds.clear()

    def on_test_epoch_end(self):
        # log1p-normalize predictions and labels before scoring.
        self.validation_step_labels = list(np.log1p(self.validation_step_labels))
        self.validation_step_preds = list(np.log1p(self.validation_step_preds))

        # Calculate R^2 score for the test split (buffers are shared with
        # validation; the old comment wrongly said "validation").
        test_r2 = r2_score(self.validation_step_labels, self.validation_step_preds)
        self.log("test/R2", test_r2, on_epoch=True, prog_bar=True, logger=True)
        self.validation_step_labels.clear()
        self.validation_step_preds.clear()

    def on_train_epoch_end(self):
        # log1p-normalize predictions and labels before scoring.
        self.training_step_labels = list(np.log1p(self.training_step_labels))
        self.training_step_preds = list(np.log1p(self.training_step_preds))

        # Calculate R^2 score over the whole training epoch.
        train_r2 = r2_score(self.training_step_labels, self.training_step_preds)
        self.log("train/R2", train_r2, on_epoch=True, prog_bar=True, logger=True)
        self.training_step_labels.clear()
        self.training_step_preds.clear()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
class CAGEDataModule(pl.LightningDataModule):
    """
    Data module for CAGE prediction fine-tuning with PyTorch Lightning.
    (Docstring previously said "Bulk RNA" -- copy-paste leftover.)

    Downloads the GLRB ``cage_prediction`` task, filters/tokenizes it once,
    caches the result on disk, and serves chromosome-held-out train/val splits.

    Args:
        config: Configuration object with data-related parameters.
    """
    def __init__(self, config):
        super().__init__()
        self.seq_len = config.seq_len
        self.bp_per_token = config.bp_per_token
        self.tokens_per_seq = self.seq_len // self.bp_per_token
        # Use the module-level BIN_SIZE constant instead of a hard-coded 128
        # so the bin count stays consistent with DNAModelForCAGE.
        self.bins_per_seq = self.seq_len // BIN_SIZE
        self.model_name = config.model_name
        self.train_batch_size = config.train_batch_size
        self.test_batch_size = config.test_batch_size
        self.num_workers = config.num_workers
        self.train_ratio = config.train_ratio
        self.eval_ratio = config.eval_ratio
        self.cache_dir = "./"
        self.dataset = None

        # Caduceus ships its own character-level tokenizer; every other model
        # gets its tokenizer from the hub.
        if "caduceus" in self.model_name:
            self.tokenizer = CaduceusTokenizer(
                model_max_length=self.seq_len,
                add_special_tokens=False
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

    def prepare_data(self):
        # Download and preprocess data if not already cached on disk.
        if not fsspec_exists(self._get_preprocessed_cache_file()):
            self._download_and_preprocess_data()

    def setup(self, stage=None):
        # Load the preprocessed dataset from the on-disk cache.
        self.prepare_data()
        self.dataset = load_from_disk(self._get_preprocessed_cache_file())

        # The test split is fixed; train/val are carved out per fold by
        # _split_dataset().
        self.test_dataset = self.dataset["test"]

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False
        )

    def _split_dataset(self, selected_validation_chromosome):
        # Hold out one chromosome as the validation set; train on the rest.
        log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}")
        self.train_dataset = self.dataset["train"].filter(
            lambda example: example["chromosome"]
            != selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.val_dataset = self.dataset["train"].filter(
            lambda example: example["chromosome"]
            == selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.validation_chromosome = selected_validation_chromosome

    def _get_preprocessed_cache_file(self):
        # Cache directory is keyed by task name and sequence length.
        self.cache_dir = osp.join(
            "./", "data", "InstaDeepAI___genomics-long-range-benchmark",
            "cage_prediction", f"seqlen{self.seq_len}"
        )
        cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed")
        return re.sub(r"=", "_", cache_file)

    def _download_and_preprocess_data(self):
        log.warning("Downloading and preprocessing data...")
        dataset = load_dataset(
            "InstaDeepAI/genomics-long-range-benchmark",
            task_name="cage_prediction",
            sequence_length=self.seq_len,
            load_from_cache=False,
            trust_remote_code=True
        )
        try:
            del dataset["validation"]  # Remove empty validation split if it exists
        except KeyError:
            pass

        # Process data: drop sequences with >0.5% 'N's, recast chromosomes to
        # ints, and tokenize.
        dataset = dataset.filter(
            lambda example: example["sequence"].count('N') < 0.005 * self.seq_len,
            desc="Filter N's"
        )

        dataset = dataset.map(
            recast_chromosome,
            remove_columns=["chromosome"],
            desc="Recast chromosome"
        )

        # Pad/truncate to a multiple of bins_per_seq so the model's
        # view(batch, num_bins, bin_size, channels) reshape always works.
        dataset = dataset.map(
            partial(tokenize_variants, tokenizer=self.tokenizer, max_length=_closest_multiple_after(self.tokens_per_seq, self.bins_per_seq)),
            batch_size=1000,
            batched=True,
            remove_columns=["sequence"],
            desc="Tokenize",
            num_proc=self.num_workers
        )

        # Save processed dataset to disk so later runs skip the download.
        dataset.save_to_disk(self._get_preprocessed_cache_file())
        log.warning("Data downloaded and preprocessed successfully.")
gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") + wandb.finish() diff --git a/finetuning_glrb/finetune_chromatin.py b/finetuning_glrb/finetune_chromatin.py new file mode 100644 index 0000000..c803262 --- /dev/null +++ b/finetuning_glrb/finetune_chromatin.py @@ -0,0 +1,642 @@ +import os +from functools import partial +from os import path as osp +import re +import torch +import torch.nn as nn +import numpy as np +from torch.utils.data import DataLoader +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig +from datasets import load_dataset, load_from_disk +from sklearn.metrics import accuracy_score, precision_recall_curve, auc, roc_auc_score +from transformers import DefaultDataCollator +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import ( + fsspec_exists, + get_last_embedding_dimension +) + +# Logger setup +log = get_logger(__name__) + +# Constants for the window size in base pairs +WINDOW_SIZE_BP = 200 + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize sequence. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs. 
+ """ + seq_tokenized = tokenizer.batch_encode_plus( + examples["sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + return { + "ref_input_ids": seq_tokenized["input_ids"] + } + +def recast_chromosome(examples): + """ + Recast chromosome to integer format. + + Returns: + dict with chromosome recast as integers. + """ + return { + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) + } + +class MLP_ChromatineFeatures(nn.Module): + """ + MLP model for predicting chromatin features. + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. + """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_ChromatineFeatures, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc2(self.sp1(self.fc1(x))) + +class DNAModelForChromatineFeatures(nn.Module): + """ + DNA Model for Chromatin Features prediction. + + Args: + args: Arguments containing model configurations. 
+ """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + # Load the appropriate backbone model based on the model name + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim found for the Foundation Model: {self.inner_dim}") + self.head = MLP_ChromatineFeatures(input_size=self.inner_dim, hidden_size=2 * self.inner_dim, output_size=20) + + def forward(self, input_ids): + embeds_out = self.backbone(input_ids)[0] + num_channels = embeds_out.size(-1) + window_size = WINDOW_SIZE_BP // self.bp_per_token // 2 + batch_size, seq_len, embedding_dim = embeds_out.shape + + if self.rcps: + embeds = embeds_out[..., :num_channels // 2] + expanded_indices = ( + torch.arange(-window_size, window_size + 1, device=embeds.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + ) + expanded_indices = torch.clamp(expanded_indices, 0, embeds.size(1) - 1) + + # Extract windowed embeddings for reference sequence + tokens_window_ref = torch.gather(embeds, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds.size(2))) + + # Extract windowed embeddings for reverse complement sequence + rc_embeds = embeds_out[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + expanded_indices = torch.arange(-window_size, window_size + 1, device=rc_embeds.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, rc_embeds.size(1) - 1) + + tokens_window_rc = torch.gather(rc_embeds, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds.size(2))) + + # Average the 
embeddings over the sequence length dimension + tokens_window_rc = tokens_window_rc.mean(dim=1) + tokens_window_ref = tokens_window_ref.mean(dim=1) + + # Combine the Reference and RC-equivalent embeddings + aggregated_embeds = tokens_window_rc + tokens_window_ref + return self.head(aggregated_embeds) + + else: + expanded_indices = torch.arange(-window_size, window_size + 1, device=embeds_out.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, embeds_out.size(1) - 1) + + # Extract windowed embeddings for reference sequence + tokens_window_ref = torch.gather(embeds_out, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_out.size(2))).mean(dim=1) + return self.head(tokens_window_ref) + +class Lit_ChromatinFeatures(pl.LightningModule): + """ + PyTorch Lightning model for predicting chromatin features. + + Args: + args: Arguments containing model and training configurations. + """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.setup() + + def setup(self,stage=None): + self.model = DNAModelForChromatineFeatures(self.hparams) + self.criterion = nn.BCEWithLogitsLoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + + def forward(self, ref_input_ids): + return self.model(ref_input_ids) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.training_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.training_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + return loss + + def validation_step(self, 
batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.validation_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + return loss + + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.validation_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + def on_validation_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for validation + val_accuracy = accuracy_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + + # Log validation metrics + self.log("validation/accuracy", val_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("validation/AUPRC", val_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("validation/AUROC", val_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def on_test_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for validation + val_accuracy = accuracy_score(self.validation_step_labels, self.validation_step_preds) + 
precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + + # Log validation metrics + self.log("test/accuracy", val_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("test/AUPRC", val_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("test/AUROC", val_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def on_train_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for training + train_accuracy = accuracy_score(self.training_step_labels, self.training_step_preds) + precision, recall, _ = precision_recall_curve(self.training_step_labels, self.training_step_preds) + train_auprc = auc(recall, precision) + train_auroc = roc_auc_score(self.training_step_labels, self.training_step_preds) + + # Log training metrics + self.log("train/accuracy", train_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("train/AUPRC", train_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("train/AUROC", train_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.training_step_labels.clear() + self.training_step_preds.clear() + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class HistoneMarksDataModule(pl.LightningDataModule): + """ + Data module for Histone Marks finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. 
+ """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.test_dataset = self.dataset["test"] + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): + log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME 
{selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "chromatin_features_histone_marks", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning("Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="chromatin_features_histone_marks", + sequence_length=self.seq_len, + subset=True, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + # Process data: filter sequences with too many 'N's, recast chromosomes, and tokenize + dataset = dataset.filter( + lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), + batch_size=1000, + batched=True, + remove_columns=["sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + + # Save processed dataset to disk + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning("Data downloaded and preprocessed successfully.") + +class DNAAccessibilityDataModule(pl.LightningDataModule): + """ + Data module for DNA Accessibility finetuning 
with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. + """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # Split the dataset into train and validation sets + self.test_dataset = self.dataset["test"] + + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=DefaultDataCollator(return_tensors="pt"), + num_workers=self.num_workers, + pin_memory=True, + shuffle=False + ) + + def _split_dataset(self,selected_validation_chromosome): 
+ log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}") + self.train_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + != selected_validation_chromosome, + keep_in_memory=True, + ) + self.val_dataset = self.dataset["train"].filter( + lambda example: example["chromosome"] + == selected_validation_chromosome, + keep_in_memory=True, + ) + self.validation_chromosome = selected_validation_chromosome + + def _get_preprocessed_cache_file(self): + self.cache_dir = osp.join( + "./", "data", "InstaDeepAI___genomics-long-range-benchmark", + "chromatin_features_dna_accessibility", f"seqlen{self.seq_len}" + ) + cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed") + return re.sub(r"=", "_", cache_file) + + def _download_and_preprocess_data(self): + log.warning("Downloading and preprocessing data...") + dataset = load_dataset( + "InstaDeepAI/genomics-long-range-benchmark", + task_name="chromatin_features_dna_accessibility", + sequence_length=self.seq_len, + subset=True, + load_from_cache=False, + trust_remote_code=True + ) + try: + del dataset["validation"] # Remove empty validation split if it exists + except KeyError: + pass + + # Process data: filter sequences with too many 'N's, recast chromosomes, and tokenize + dataset = dataset.filter( + lambda example: example["sequence"].count('N') < 0.005 * self.seq_len, + desc="Filter N's" + ) + dataset = dataset.map( + recast_chromosome, + remove_columns=["chromosome"], + desc="Recast chromosome" + ) + dataset = dataset.map( + partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len//self.bp_per_token), + batch_size=1000, + batched=True, + remove_columns=["sequence"], + desc="Tokenize", + num_proc=self.num_workers + ) + + # Save processed dataset to disk + dataset.save_to_disk(self._get_preprocessed_cache_file()) + log.warning("Data downloaded and preprocessed successfully.") + + +def 
finetune_histone_marks(args): + """ + Main function to start finetuning on Histone Marks with PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. + """ + wandb.login(key=args.wandb_api_key) + data_module = HistoneMarksDataModule(args) + data_module.setup() + + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) + + for idx,val_chromosome in enumerate(held_chromosomes): + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="Histone Marks", + log_model=True # Automatically log model checkpoints + ) + data_module._split_dataset(val_chromosome) + + model = Lit_ChromatinFeatures(args) + + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") + + # Finish the current WandB run + wandb.finish() + +def finetune_dna_accessibility(args): + """ + Main function to start finetunning on DNA Accessibility with PyTorch Lightning. + + Args: + args: Command line arguments or configuration dictionary. 
+ """ + wandb.login(key=args.wandb_api_key) + data_module = DNAAccessibilityDataModule(args) + data_module.setup() + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) + + for idx,val_chromosome in enumerate(held_chromosomes): + + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="DNA Accessibility", + log_model=True # Automatically log model checkpoints + ) + + data_module._split_dataset(val_chromosome) + + + model = Lit_ChromatinFeatures(args) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") + + # Finish the current WandB run + wandb.finish() diff --git a/finetuning_glrb/finetune_regulatory_elements.py b/finetuning_glrb/finetune_regulatory_elements.py new file mode 100644 index 0000000..6bc493d --- /dev/null +++ b/finetuning_glrb/finetune_regulatory_elements.py @@ -0,0 +1,641 @@ +import os +from functools import partial +from os import path as osp +import re +import torch +import torch.nn as nn +import numpy as np +from 
torch.utils.data import DataLoader +from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig +import wandb +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from datasets import load_dataset, load_from_disk +from sklearn.metrics import accuracy_score, precision_recall_curve, auc, roc_auc_score +from transformers import DefaultDataCollator +from src.utils.train import get_logger +from caduceus.tokenization_caduceus import CaduceusTokenizer +from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension + +# Logger setup +log = get_logger(__name__) + +# Constants for the window size in base pairs +WINDOW_SIZE_BP = 200 + +def tokenize_variants(examples, tokenizer, max_length: int): + """ + Tokenize sequence. + + Args: + examples: A batch of items from the dataset. + tokenizer: AutoTokenizer instance. + max_length: Maximum length for tokenization. + + Returns: + dict with tokenized input IDs. + """ + seq_tokenized = tokenizer.batch_encode_plus( + examples["sequence"], + add_special_tokens=False, + return_attention_mask=False, + max_length=max_length, + truncation=True, + ) + return { + "ref_input_ids": seq_tokenized["input_ids"] + } + +def recast_chromosome(examples): + """ + Recast chromosome to integer format. + + Returns: + dict with chromosome recast as integers. + """ + return { + "chromosome": -1 if examples["chromosome"] in ["X","Y"] else int(examples["chromosome"]) + } + +class MLP_RegulatoryElements(nn.Module): + """ + MLP model for predicting regulatory elements. + + Args: + input_size: Input size for the linear layer. + hidden_size: Hidden layer size. + output_size: Output size for the linear layer. 
+ """ + def __init__(self, input_size, hidden_size, output_size): + super(MLP_RegulatoryElements, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.sp1 = nn.Softplus() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + return self.fc2(self.sp1(self.fc1(x))) + +class DNAModelForRegulatoryElements(nn.Module): + """ + DNA Model for Regulatory Elements prediction. + + Args: + args: Arguments containing model configurations. + """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + # Load the appropriate backbone model based on the model name + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim found for the Foundation Model: {self.inner_dim}") + self.head = MLP_RegulatoryElements(input_size=self.inner_dim, hidden_size=2 * self.inner_dim, output_size=1) + + def forward(self, input_ids): + # Get embeddings for the alternate and reference sequences + embeds_out = self.backbone(input_ids)[0] + num_channels = embeds_out.size(-1) + window_size = WINDOW_SIZE_BP // self.bp_per_token // 2 + batch_size, seq_len, embedding_dim = embeds_out.shape + + if self.rcps: + # Handle reverse complement processing + ref_embeds = embeds_out[..., :num_channels // 2] + rc_embeds = embeds_out[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + + expanded_indices = torch.arange(-window_size, window_size + 1, device=ref_embeds.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, ref_embeds.size(1) - 1) + + # 
Extract windowed embeddings for the reference sequence + tokens_window_ref = torch.gather(ref_embeds, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, ref_embeds.size(2))).mean(dim=1) + + expanded_indices = torch.arange(-window_size, window_size + 1, device=rc_embeds.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, rc_embeds.size(1) - 1) + + # Extract windowed embeddings for the reverse complement sequence + tokens_window_rc = torch.gather(rc_embeds, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds.size(2))).mean(dim=1) + + # Combine the reference and reverse complement embeddings + aggregated_embeds = tokens_window_rc + tokens_window_ref + return self.head(aggregated_embeds) + + else: + # Handle non-reverse complement processing + expanded_indices = torch.arange(-window_size, window_size + 1, device=embeds_out.device).unsqueeze(0).expand(batch_size, -1) + seq_len // 2 + expanded_indices = torch.clamp(expanded_indices, 0, embeds_out.size(1) - 1) + + tokens_window_ref = torch.gather(embeds_out, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_out.size(2))).mean(dim=1) + return self.head(tokens_window_ref) + +class Lit_RegulatoryElements(pl.LightningModule): + """ + PyTorch Lightning model for predicting regulatory elements. + + Args: + args: Arguments containing model and training configurations. 
+ """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.setup() + + def setup(self,stage=None): + self.model = DNAModelForRegulatoryElements(self.hparams) + self.task = self.hparams.task + self.criterion = nn.BCEWithLogitsLoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + + def forward(self, ref_input_ids): + return self.model(ref_input_ids) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids).squeeze(-1) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.training_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.training_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids).squeeze(-1) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + # Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + self.validation_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + return loss + + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + labels = batch["labels"].float() + + logits = self(ref_input_ids).squeeze(-1) + #Track predictions and labels for accuracy and F1 score + preds = (torch.sigmoid(logits) > 0.5).float() # Get predicted class labels + 
self.validation_step_preds.extend(preds.detach().flatten().cpu().numpy()) + self.validation_step_labels.extend(labels.detach().flatten().cpu().numpy()) + + def on_validation_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for validation + val_accuracy = accuracy_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + + if self.task == "regulatory_element_enhancer": + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + + # Log validation metrics + self.log("validation/accuracy", val_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("validation/AUPRC", val_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + if self.task == "regulatory_element_enhancer": + self.log("validation/AUROC", val_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def on_test_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for validation + val_accuracy = accuracy_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + + if self.task == "regulatory_element_enhancer": + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + + # Log validation metrics + self.log("test/accuracy", val_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("test/AUPRC", val_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + if self.task == "regulatory_element_enhancer": + self.log("test/AUROC", val_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + + def 
on_train_epoch_end(self): + # Calculate accuracy, AUPRC, and AUROC for training + train_accuracy = accuracy_score(self.training_step_labels, self.training_step_preds) + precision, recall, _ = precision_recall_curve(self.training_step_labels, self.training_step_preds) + train_auprc = auc(recall, precision) + + if self.task == "regulatory_element_enhancer": + train_auroc = roc_auc_score(self.training_step_labels, self.training_step_preds) + + # Log training metrics + self.log("train/accuracy", train_accuracy, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log("train/AUPRC", train_auprc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + if self.task == "regulatory_element_enhancer": + self.log("train/AUROC", train_auroc, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + self.training_step_labels.clear() + self.training_step_preds.clear() + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class PromoterDataModule(pl.LightningDataModule): + """ + Data module for Promoter regulatory element finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. 
class PromoterDataModule(pl.LightningDataModule):
    """
    Data module for Promoter regulatory element finetuning with PyTorch Lightning.

    Downloads the GLRB promoter task once, preprocesses it (N-filtering,
    chromosome recasting, tokenization) into an on-disk cache, and serves
    chromosome-held-out train/val splits plus the benchmark test split.

    Args:
        config: Configuration object with data-related parameters
            (seq_len, bp_per_token, model_name, batch sizes, num_workers, ratios).
    """
    def __init__(self, config):
        super().__init__()
        self.seq_len = config.seq_len
        self.bp_per_token = config.bp_per_token
        self.model_name = config.model_name
        self.train_batch_size = config.train_batch_size
        self.test_batch_size = config.test_batch_size
        self.num_workers = config.num_workers
        self.train_ratio = config.train_ratio
        self.eval_ratio = config.eval_ratio
        self.cache_dir = "./"
        self.dataset = None

        # Caduceus ships its own character tokenizer; everything else comes from the hub.
        if "caduceus" in self.model_name:
            self.tokenizer = CaduceusTokenizer(
                model_max_length=self.seq_len,
                add_special_tokens=False
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

    def prepare_data(self):
        # Download and preprocess only when the on-disk cache is missing.
        if not fsspec_exists(self._get_preprocessed_cache_file()):
            self._download_and_preprocess_data()

    def setup(self, stage=None):
        self.prepare_data()
        self.dataset = load_from_disk(self._get_preprocessed_cache_file())
        # Train/val splits are created later, per held-out chromosome, by _split_dataset.
        self.test_dataset = self.dataset["test"]

    def _make_loader(self, dataset, batch_size, shuffle):
        """Build a DataLoader with the module-wide collator/worker settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, self.train_batch_size, True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, self.test_batch_size, False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, self.test_batch_size, False)

    def _split_dataset(self, selected_validation_chromosome):
        """Hold out one chromosome as validation; everything else trains."""
        log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}")
        train_split = self.dataset["train"]
        self.train_dataset = train_split.filter(
            lambda example: example["chromosome"] != selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.val_dataset = train_split.filter(
            lambda example: example["chromosome"] == selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.validation_chromosome = selected_validation_chromosome

    def _get_preprocessed_cache_file(self):
        # Cache path encodes the task and sequence length; "=" is replaced to
        # keep the path filesystem-friendly.
        self.cache_dir = osp.join(
            "./", "data", "InstaDeepAI___genomics-long-range-benchmark",
            "regulatory_element_promoter", f"seqlen{self.seq_len}"
        )
        cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed")
        return re.sub(r"=", "_", cache_file)

    def _download_and_preprocess_data(self):
        log.warning("Downloading and preprocessing data...")
        dataset = load_dataset(
            "InstaDeepAI/genomics-long-range-benchmark",
            task_name="regulatory_element_promoter",
            sequence_length=self.seq_len,
            subset=True,
            load_from_cache=False,
            trust_remote_code=True
        )
        try:
            del dataset["validation"]  # remove the empty validation split if present
        except KeyError:
            pass

        # Drop sequences with too many ambiguous bases (>= 0.5% N's).
        dataset = dataset.filter(
            lambda example: example["sequence"].count('N') < 0.005 * self.seq_len,
            desc="Filter N's"
        )
        dataset = dataset.map(
            recast_chromosome,
            remove_columns=["chromosome"],
            desc="Recast chromosome"
        )
        dataset = dataset.map(
            partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len // self.bp_per_token),
            batch_size=1000,
            batched=True,
            remove_columns=["sequence"],
            desc="Tokenize",
            num_proc=self.num_workers
        )

        dataset.save_to_disk(self._get_preprocessed_cache_file())
        log.warning("Data downloaded and preprocessed successfully.")
class EnhancerDataModule(pl.LightningDataModule):
    """
    Data module for Enhancer regulatory element finetuning with PyTorch Lightning.

    Mirrors PromoterDataModule but targets the GLRB enhancer task: one-time
    download and preprocessing into an on-disk cache, then chromosome-held-out
    train/val splits plus the benchmark test split.

    Args:
        config: Configuration object with data-related parameters
            (seq_len, bp_per_token, model_name, batch sizes, num_workers, ratios).
    """
    def __init__(self, config):
        super().__init__()
        self.seq_len = config.seq_len
        self.bp_per_token = config.bp_per_token
        self.model_name = config.model_name
        self.train_batch_size = config.train_batch_size
        self.test_batch_size = config.test_batch_size
        self.num_workers = config.num_workers
        self.train_ratio = config.train_ratio
        self.eval_ratio = config.eval_ratio
        self.cache_dir = "./"
        self.dataset = None

        # Caduceus ships its own character tokenizer; everything else comes from the hub.
        if "caduceus" in self.model_name:
            self.tokenizer = CaduceusTokenizer(
                model_max_length=self.seq_len,
                add_special_tokens=False
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

    def prepare_data(self):
        # Download and preprocess only when the on-disk cache is missing.
        if not fsspec_exists(self._get_preprocessed_cache_file()):
            self._download_and_preprocess_data()

    def setup(self, stage=None):
        self.prepare_data()
        self.dataset = load_from_disk(self._get_preprocessed_cache_file())
        # Train/val splits are created later, per held-out chromosome, by _split_dataset.
        self.test_dataset = self.dataset["test"]

    def _make_loader(self, dataset, batch_size, shuffle):
        """Build a DataLoader with the module-wide collator/worker settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, self.train_batch_size, True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, self.test_batch_size, False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, self.test_batch_size, False)

    def _split_dataset(self, selected_validation_chromosome):
        """Hold out one chromosome as validation; everything else trains."""
        log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}")
        train_split = self.dataset["train"]
        self.train_dataset = train_split.filter(
            lambda example: example["chromosome"] != selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.val_dataset = train_split.filter(
            lambda example: example["chromosome"] == selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.validation_chromosome = selected_validation_chromosome

    def _get_preprocessed_cache_file(self):
        # Cache path encodes the task and sequence length; "=" is replaced to
        # keep the path filesystem-friendly.
        self.cache_dir = osp.join(
            "./", "data", "InstaDeepAI___genomics-long-range-benchmark",
            "regulatory_element_enhancer", f"seqlen{self.seq_len}"
        )
        cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed")
        return re.sub(r"=", "_", cache_file)

    def _download_and_preprocess_data(self):
        log.warning("Downloading and preprocessing data...")
        dataset = load_dataset(
            "InstaDeepAI/genomics-long-range-benchmark",
            task_name="regulatory_element_enhancer",
            sequence_length=self.seq_len,
            subset=True,
            load_from_cache=False,
            trust_remote_code=True
        )
        try:
            del dataset["validation"]  # remove the empty validation split if present
        except KeyError:
            pass

        # Drop sequences with too many ambiguous bases (>= 0.5% N's).
        dataset = dataset.filter(
            lambda example: example["sequence"].count('N') < 0.005 * self.seq_len,
            desc="Filter N's"
        )
        dataset = dataset.map(
            recast_chromosome,
            remove_columns=["chromosome"],
            desc="Recast chromosome"
        )
        dataset = dataset.map(
            partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len // self.bp_per_token),
            batch_size=1000,
            batched=True,
            remove_columns=["sequence"],
            desc="Tokenize",
            num_proc=self.num_workers
        )

        dataset.save_to_disk(self._get_preprocessed_cache_file())
        log.warning("Data downloaded and preprocessed successfully.")
def finetune_promoters(args):
    """
    Run chromosome-held-out cross-validated fine-tuning on Promoter regulatory elements.

    Five validation chromosomes are sampled with a fixed seed; for each fold a
    fresh model is trained, checkpointed on best val_loss, and evaluated on the
    benchmark test split.

    Args:
        args: Command line arguments or configuration dictionary.
    """
    wandb.login(key=args.wandb_api_key)
    data_module = PromoterDataModule(args)
    data_module.setup()

    # Deterministically pick 5 held-out validation chromosomes.
    np.random.seed(0)
    candidates = np.unique(data_module.dataset["train"]["chromosome"])
    held_chromosomes = np.random.choice(candidates, 5, replace=False)

    for fold, val_chromosome in enumerate(held_chromosomes):
        wandb_logger = WandbLogger(
            name=f"{args.name_wb}-{args.seq_len}-fold-{fold+1}",
            project="Regulatory Element Promoter",
            log_model=True,  # automatically log model checkpoints
        )
        data_module._split_dataset(val_chromosome)
        model = Lit_RegulatoryElements(args)

        checkpoint_callback = ModelCheckpoint(
            dirpath=f"{args.save_dir}/checkpoints",
            filename=f"best-checkpoint-on-chromosome{val_chromosome}",
            save_top_k=1,
            verbose=True,
            monitor="val_loss",
            mode="min",
        )

        # NOTE(review): nucleotide-transformer runs are pinned to a single
        # device here — presumably multi-device causes issues; confirm.
        nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto"

        trainer = pl.Trainer(
            max_epochs=args.num_epochs,
            devices=nb_device,
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            log_every_n_steps=1,
            limit_train_batches=args.train_ratio,
            limit_val_batches=args.eval_ratio,
            val_check_interval=args.log_interval,
            gradient_clip_val=1.0,
            precision=args.precision,
            accumulate_grad_batches=args.accumulate_grad_batches,
            num_sanity_val_steps=0,
        )

        trainer.fit(model, data_module)
        # Evaluate this fold's best checkpoint on the held-out test split.
        trainer.test(model, data_module, ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt")

        # Close this fold's WandB run before starting the next one.
        wandb.finish()
def finetune_enhancers(args):
    """
    Run chromosome-held-out cross-validated fine-tuning on Enhancer regulatory elements.

    Five validation chromosomes are sampled with a fixed seed; for each fold a
    fresh model is trained, checkpointed on best val_loss, and evaluated on the
    benchmark test split.

    Args:
        args: Command line arguments or configuration dictionary.
    """
    wandb.login(key=args.wandb_api_key)
    data_module = EnhancerDataModule(args)
    data_module.setup()

    # Deterministically pick 5 held-out validation chromosomes.
    np.random.seed(0)
    candidates = np.unique(data_module.dataset["train"]["chromosome"])
    held_chromosomes = np.random.choice(candidates, 5, replace=False)

    for fold, val_chromosome in enumerate(held_chromosomes):
        wandb_logger = WandbLogger(
            name=f"{args.name_wb}-{args.seq_len}-fold-{fold+1}",
            project="Regulatory Elements Enhancer",
            log_model=True,  # automatically log model checkpoints
        )
        data_module._split_dataset(val_chromosome)
        model = Lit_RegulatoryElements(args)

        checkpoint_callback = ModelCheckpoint(
            dirpath=f"{args.save_dir}/checkpoints",
            filename=f"best-checkpoint-on-chromosome{val_chromosome}",
            save_top_k=1,
            verbose=True,
            monitor="val_loss",
            mode="min",
        )

        # NOTE(review): nucleotide-transformer runs are pinned to a single
        # device here — presumably multi-device causes issues; confirm.
        nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto"

        trainer = pl.Trainer(
            max_epochs=args.num_epochs,
            devices=nb_device,
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            log_every_n_steps=1,
            limit_train_batches=args.train_ratio,
            limit_val_batches=args.eval_ratio,
            val_check_interval=args.log_interval,
            gradient_clip_val=1.0,
            precision=args.precision,
            accumulate_grad_batches=args.accumulate_grad_batches,
            num_sanity_val_steps=0,
        )

        trainer.fit(model, data_module)
        # Evaluate this fold's best checkpoint on the held-out test split.
        trainer.test(model, data_module, ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt")

        # Close this fold's WandB run before starting the next one.
        wandb.finish()
import os
import re
from functools import partial
from os import path as osp

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import lightning.pytorch as pl
import wandb
from datasets import load_dataset, load_from_disk
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DefaultDataCollator,
)

from caduceus.tokenization_caduceus import CaduceusTokenizer
from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension
from src.utils.train import get_logger

# Logger setup
log = get_logger(__name__)

# Width (in bp) of the window around the variant used to pool embeddings.
WINDOW_SIZE_BP = 1536

def tokenize_variants(examples, tokenizer, max_length: int):
    """
    Tokenize reference and alternate sequences.

    Args:
        examples: A batch of items from the dataset.
        tokenizer: AutoTokenizer instance.
        max_length: Maximum length for tokenization.

    Returns:
        dict with tokenized input IDs for reference and alternate sequences.
    """
    # Both sequences are encoded with identical settings; bind them once.
    encode = partial(
        tokenizer.batch_encode_plus,
        add_special_tokens=False,
        return_attention_mask=False,
        max_length=max_length,
        truncation=True,
    )
    return {
        "ref_input_ids": encode(examples["ref_forward_sequence"])["input_ids"],
        "alt_input_ids": encode(examples["alt_forward_sequence"])["input_ids"],
    }

def recast_chromosome(examples):
    """
    Recast chromosome to integer format; sex chromosomes X/Y collapse to -1.

    Returns:
        dict with the recast chromosome as an integer.
    """
    chromosome = examples["chromosome"]
    return {"chromosome": -1 if chromosome in ["X", "Y"] else int(chromosome)}
def find_variant_idx(examples):
    """
    Locate the token position at which the ref and alt sequences differ.

    The variant is assumed to sit at the sequence midpoint; if the midpoint
    tokens agree (e.g. the variant was shifted by truncation), the whole
    sequence is scanned and the last mismatching position is kept (-1 when
    the sequences are identical).

    Args:
        examples: Items from the dataset (not batched).

    Returns:
        dict with the index of the variant.
    """
    ref_ids = examples["ref_input_ids"]
    alt_ids = examples["alt_input_ids"]
    midpoint = len(ref_ids) // 2
    if ref_ids[midpoint] != alt_ids[midpoint]:
        return {"variant_idx": midpoint}

    # Fallback scan: keep the last mismatching position, if any.
    variant_idx = -1
    for position, (ref_tok, alt_tok) in enumerate(zip(ref_ids, alt_ids)):
        if ref_tok != alt_tok:
            variant_idx = position
    return {"variant_idx": variant_idx}

class MLP_VEP_OMIM(nn.Module):
    """
    Three-layer Softplus MLP head for Variant Effect Prediction (OMIM).

    Args:
        input_size: Input size for the first linear layer.
        hidden_size: Hidden layer size (used by both hidden layers).
        output_size: Output size for the final linear layer.
    """
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.sp1 = nn.Softplus()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.sp2 = nn.Softplus()
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.sp1(self.fc1(x))
        hidden = self.sp2(self.fc2(hidden))
        return self.fc3(hidden)
class DNAModelForOMIMFinetuning(nn.Module):
    """
    DNA backbone plus MLP head for OMIM Variant Effect Prediction fine-tuning.

    Embeddings are mean-pooled over a fixed window centred on the variant for
    both the reference and alternate sequences, concatenated, and scored with
    MLP_VEP_OMIM. With rcps enabled, the forward and reverse-complement
    channel halves are pooled separately and summed.

    Args:
        args: Arguments containing model configurations
            (model_name, rcps, bp_per_token).
    """
    def __init__(self, args):
        super().__init__()
        self.rcps = args.rcps
        self.bp_per_token = args.bp_per_token
        self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True)

        # Nucleotide-transformer checkpoints expose their encoder under `.esm`.
        if "nucleotide-transformer" in args.model_name.lower():
            self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm
        else:
            self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True)

        print(f"MODEL LOADED: {self.backbone}")
        self.inner_dim = get_last_embedding_dimension(self.backbone, self.rcps)
        print(f"Inner dim found for the Foundation Model: {self.inner_dim}")
        # ref and alt embeddings are concatenated, hence 2x the backbone width.
        self.head = MLP_VEP_OMIM(input_size=2 * self.inner_dim, hidden_size=2 * self.inner_dim, output_size=2)

    @staticmethod
    def _window_mean(embeds, indices):
        """Mean-pool token embeddings over the window positions given by `indices`."""
        gathered = torch.gather(
            embeds, 1, indices.unsqueeze(-1).expand(-1, -1, embeds.size(2))
        )
        return gathered.mean(dim=1)

    def forward(self, alt_input_ids, ref_input_ids, variant_idx):
        embeds_alternate = self.backbone(alt_input_ids)[0]
        embeds_reference = self.backbone(ref_input_ids)[0]
        window_size = WINDOW_SIZE_BP // self.bp_per_token // 2
        num_channels = embeds_alternate.size(-1)

        # Per-sample token indices covering [variant - w, variant + w].
        offsets = torch.arange(-window_size, window_size + 1, device=variant_idx.device)
        indices = offsets.unsqueeze(0) + variant_idx.unsqueeze(1)

        if self.rcps:
            # Forward-strand channel half.
            fwd_alt = embeds_alternate[..., :num_channels // 2]
            fwd_ref = embeds_reference[..., :num_channels // 2]
            indices = torch.clamp(indices, 0, fwd_alt.size(1) - 1)
            concat_embeds = torch.cat(
                [self._window_mean(fwd_alt, indices), self._window_mean(fwd_ref, indices)],
                dim=-1,
            )

            # Reverse-complement channel half: flip sequence and channel axes
            # so the window indices address the same genomic positions.
            rc_alt = embeds_alternate[..., num_channels // 2:].contiguous().flip(dims=[1, 2])
            rc_ref = embeds_reference[..., num_channels // 2:].contiguous().flip(dims=[1, 2])
            rc_concat_embeds = torch.cat(
                [self._window_mean(rc_alt, indices), self._window_mean(rc_ref, indices)],
                dim=-1,
            )

            # Sum forward and reverse-complement pooled features.
            return self.head(concat_embeds + rc_concat_embeds)

        indices = torch.clamp(indices, 0, embeds_alternate.size(1) - 1)
        concat_embeds = torch.cat(
            [self._window_mean(embeds_alternate, indices),
             self._window_mean(embeds_reference, indices)],
            dim=-1,
        )
        return self.head(concat_embeds)
class Lit_OMIMFinetuning(pl.LightningModule):
    """
    PyTorch Lightning model for fine-tuning on OMIM Variant Effect Prediction.

    Accumulates positive-class scores and hard-prediction counts per epoch and
    reports AUROC / accuracy / AUPRC directly to the WandB experiment. The test
    loop reuses the validation buffers.

    Args:
        args: Arguments containing model and training configurations.
    """
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(args)
        self.setup()

    def setup(self, stage=None):
        # (Re)build the model and reset all per-epoch metric accumulators.
        self.model = DNAModelForOMIMFinetuning(self.hparams)
        self.criterion = nn.CrossEntropyLoss()
        self.validation_step_preds = []
        self.validation_step_labels = []
        self.training_step_preds = []
        self.training_step_labels = []
        self.training_step_correct = 0
        self.training_step_total = 0
        self.validation_step_correct = 0
        self.validation_step_total = 0

    def forward(self, alt_input_ids, ref_input_ids, variant_idx):
        return self.model(alt_input_ids, ref_input_ids, variant_idx)

    def _run_batch(self, batch):
        """Forward one batch; returns (logits, labels)."""
        logits = self(batch["alt_input_ids"], batch["ref_input_ids"], batch["variant_idx"])
        return logits, batch["labels"]

    def _accumulate(self, logits, labels, preds_buf, labels_buf):
        """Track positive-class scores and hard-prediction correctness; returns (correct, total)."""
        hard_preds = torch.argmax(logits, dim=1)
        correct = (hard_preds == labels).sum().item()
        # Positive-class probability is what AUROC/AUPRC are computed from.
        scores = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
        preds_buf.extend(scores)
        labels_buf.extend(labels.cpu().numpy())
        return correct, len(labels)

    def training_step(self, batch, batch_idx):
        logits, labels = self._run_batch(batch)
        loss = self.criterion(logits, labels)
        self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True)
        correct, total = self._accumulate(logits, labels, self.training_step_preds, self.training_step_labels)
        self.training_step_correct += correct
        self.training_step_total += total
        return loss

    def validation_step(self, batch, batch_idx):
        logits, labels = self._run_batch(batch)
        loss = self.criterion(logits, labels)
        self.log('val_loss', loss, on_epoch=True, on_step=False, sync_dist=True)
        correct, total = self._accumulate(logits, labels, self.validation_step_preds, self.validation_step_labels)
        self.validation_step_correct += correct
        self.validation_step_total += total

    def test_step(self, batch, batch_idx):
        # Test reuses the validation buffers; no loss is logged here.
        logits, labels = self._run_batch(batch)
        correct, total = self._accumulate(logits, labels, self.validation_step_preds, self.validation_step_labels)
        self.validation_step_correct += correct
        self.validation_step_total += total

    def _log_eval_metrics(self, prefix):
        """Compute AUROC/accuracy/AUPRC from the eval buffers, log under `prefix`, then reset."""
        auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds)
        precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds)
        auprc = auc(recall, precision)
        accuracy = self.validation_step_correct / self.validation_step_total

        self.logger.experiment.log({
            f"{prefix}/AUROC": auroc,
            f"{prefix}/Accuracy": accuracy,
            f"{prefix}/AUPRC": auprc,
        })

        self.validation_step_labels.clear()
        self.validation_step_preds.clear()
        self.validation_step_correct = 0
        self.validation_step_total = 0

    def on_validation_epoch_end(self):
        self._log_eval_metrics("validation")

    def on_test_epoch_end(self):
        self._log_eval_metrics("test")

    def on_train_epoch_end(self):
        train_auroc = roc_auc_score(self.training_step_labels, self.training_step_preds)
        precision, recall, _ = precision_recall_curve(self.training_step_labels, self.training_step_preds)
        train_auprc = auc(recall, precision)
        train_accuracy = self.training_step_correct / self.training_step_total

        self.logger.experiment.log({
            "train/AUROC": train_auroc,
            "train/Accuracy": train_accuracy,
            "train/AUPRC": train_auprc
        })

        self.training_step_labels.clear()
        self.training_step_preds.clear()
        self.training_step_correct = 0
        self.training_step_total = 0

    def configure_optimizers(self):
        """AdamW over all parameters with a flat learning rate from the CLI args."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
class VariantEffectPredictionDataModule(pl.LightningDataModule):
    """
    Data module for OMIM Variant Effect Prediction fine-tuning with PyTorch Lightning.

    Downloads the benchmark task once, preprocesses it (N-filtering, chromosome
    recasting, tokenization, variant-index detection) into an on-disk cache,
    and serves chromosome-held-out train/val splits plus the benchmark test split.

    Args:
        config: Configuration dictionary with data-related parameters
            (seq_len, bp_per_token, model_name, batch sizes, num_workers, ratios).
    """
    def __init__(self, config):
        super().__init__()
        self.seq_len = config.seq_len
        self.bp_per_token = config.bp_per_token
        self.model_name = config.model_name
        self.train_batch_size = config.train_batch_size
        self.test_batch_size = config.test_batch_size
        self.num_workers = config.num_workers
        self.train_ratio = config.train_ratio
        self.eval_ratio = config.eval_ratio
        self.cache_dir = "./"
        self.dataset = None

        # Caduceus ships its own character tokenizer; everything else comes from the hub.
        if "caduceus" in self.model_name:
            self.tokenizer = CaduceusTokenizer(
                model_max_length=self.seq_len,
                add_special_tokens=False
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

    def prepare_data(self):
        # Download and preprocess only when the on-disk cache is missing.
        if not fsspec_exists(self._get_preprocessed_cache_file()):
            self._download_and_preprocess_data()

    def setup(self, stage=None):
        self.prepare_data()
        self.dataset = load_from_disk(self._get_preprocessed_cache_file())
        # Train/val splits are created later, per held-out chromosome, by _split_dataset.
        self.test_dataset = self.dataset["test"]

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True
        )

    def val_dataloader(self):
        # FIX: validation previously used shuffle=True. Evaluation loaders must
        # not shuffle — this matches the other GLRB data modules and keeps
        # validation batches deterministic across epochs.
        return DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False
        )

    def _split_dataset(self, selected_validation_chromosome):
        """Hold out one chromosome as validation; everything else trains."""
        log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}")
        self.train_dataset = self.dataset["train"].filter(
            lambda example: example["chromosome"] != selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.val_dataset = self.dataset["train"].filter(
            lambda example: example["chromosome"] == selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.validation_chromosome = selected_validation_chromosome

    def _get_preprocessed_cache_file(self):
        # Cache path encodes the task and sequence length; "=" is replaced to
        # keep the path filesystem-friendly.
        self.cache_dir = osp.join(
            "./", "data", "InstaDeepAI___genomics-long-range-benchmark",
            "variant_effect_pathogenic_omim", f"seqlen{self.seq_len}"
        )
        cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed")
        return re.sub(r"=", "_", cache_file)

    def _download_and_preprocess_data(self):
        log.warning("Downloading and preprocessing data...")
        dataset = load_dataset(
            "InstaDeepAI/genomics-long-range-benchmark",
            task_name="variant_effect_pathogenic_omim",
            sequence_length=self.seq_len,
            load_from_cache=False,
            trust_remote_code=True
        )
        try:
            del dataset["validation"]  # remove the empty validation split if present
        except KeyError:
            pass

        # Drop sequences with too many ambiguous bases (>= 0.5% N's).
        dataset = dataset.filter(
            lambda example: example["ref_forward_sequence"].count('N') < 0.005 * self.seq_len,
            desc="Filter N's"
        )
        dataset = dataset.map(
            recast_chromosome,
            remove_columns=["chromosome"],
            desc="Recast chromosome"
        )
        dataset = dataset.map(
            partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len // self.bp_per_token),
            batch_size=1000,
            batched=True,
            remove_columns=["ref_forward_sequence", "alt_forward_sequence"],
            desc="Tokenize",
            num_proc=self.num_workers
        )
        dataset = dataset.map(find_variant_idx, desc="Find variant idx")

        dataset.save_to_disk(self._get_preprocessed_cache_file())
        log.warning("Data downloaded and preprocessed successfully.")
def finetune(args):
    """
    Main entry point for OMIM Variant Effect Prediction fine-tuning.

    Samples five validation chromosomes with a fixed seed; for each fold a
    fresh model is trained, checkpointed on best val_loss, and evaluated on
    the benchmark test split.

    Args:
        args: Command line arguments or configuration dictionary.
    """
    wandb.login(key=args.wandb_api_key)
    data_module = VariantEffectPredictionDataModule(args)
    data_module.setup()

    # Deterministically pick 5 held-out validation chromosomes.
    np.random.seed(0)
    candidates = np.unique(data_module.dataset["train"]["chromosome"])
    held_chromosomes = np.random.choice(candidates, 5, replace=False)

    for fold, val_chromosome in enumerate(held_chromosomes):
        wandb_logger = WandbLogger(
            name=f"{args.name_wb}-{args.seq_len}-fold-{fold+1}",
            project="Variant Effect Prediction OMIM",
            log_model=True,
        )

        data_module._split_dataset(val_chromosome)
        model = Lit_OMIMFinetuning(args)

        checkpoint_callback = ModelCheckpoint(
            dirpath=f"{args.save_dir}/checkpoints",
            filename=f"best-checkpoint-on-chromosome{val_chromosome}",
            save_top_k=1,
            verbose=True,
            monitor="val_loss",
            mode="min",
        )

        # NOTE(review): nucleotide-transformer runs are pinned to a single
        # device here — presumably multi-device causes issues; confirm.
        nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto"

        trainer = pl.Trainer(
            max_epochs=args.num_epochs,
            devices=nb_device,
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            log_every_n_steps=1,
            limit_train_batches=args.train_ratio,
            limit_val_batches=args.eval_ratio,
            val_check_interval=args.log_interval,
            gradient_clip_val=1.0,
            precision=args.precision,
            accumulate_grad_batches=args.accumulate_grad_batches,
            num_sanity_val_steps=0,
        )

        trainer.fit(model, data_module)
        # Evaluate this fold's best checkpoint on the held-out test split.
        trainer.test(model, data_module, ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt")

        # Close this fold's WandB run before starting the next one.
        wandb.finish()
def recast_chromosome_tissue_dist2TSS(examples):
    """
    Recast chromosome to integer and retain tissue and distance to nearest TSS.

    Sex chromosomes X/Y collapse to -1.

    Returns:
        dict with recast chromosome, tissue, and distance to TSS.
    """
    chromosome = examples["chromosome"]
    return {
        "chromosome": -1 if chromosome in ["X", "Y"] else int(chromosome),
        "tissue": examples["tissue"],
        "distance_to_nearest_tss": examples["distance_to_nearest_tss"],
    }

def tokenize_variants(examples, tokenizer, max_length: int):
    """
    Tokenize reference and alternate sequences.

    Args:
        examples: A batch of items from the dataset.
        tokenizer: AutoTokenizer instance.
        max_length: Maximum length for tokenization.

    Returns:
        dict with tokenized input IDs for reference and alternate sequences.
    """
    # Both sequences are encoded with identical settings; bind them once.
    encode = partial(
        tokenizer.batch_encode_plus,
        add_special_tokens=False,
        return_attention_mask=False,
        max_length=max_length,
        truncation=True,
    )
    return {
        "ref_input_ids": encode(examples["ref_forward_sequence"])["input_ids"],
        "alt_input_ids": encode(examples["alt_forward_sequence"])["input_ids"],
    }

def find_variant_idx(examples):
    """
    Locate the token position at which the ref and alt sequences differ.

    The variant is assumed to sit at the sequence midpoint; if the midpoint
    tokens agree, the whole sequence is scanned and the last mismatching
    position is kept (-1 when the sequences are identical).

    Args:
        examples: Items from the dataset (not batched).

    Returns:
        dict with the index of the variant.
    """
    ref_ids = examples["ref_input_ids"]
    alt_ids = examples["alt_input_ids"]
    midpoint = len(ref_ids) // 2
    if ref_ids[midpoint] != alt_ids[midpoint]:
        return {"variant_idx": midpoint}

    # Fallback scan: keep the last mismatching position, if any.
    variant_idx = -1
    for position, (ref_tok, alt_tok) in enumerate(zip(ref_ids, alt_ids)):
        if ref_tok != alt_tok:
            variant_idx = position
    return {"variant_idx": variant_idx}

def dataset_tss_filter(data: Dataset, min_distance: int, max_distance: int):
    """
    Filter the data based on the distance to the nearest TSS (inclusive bounds).

    Args:
        data: mapping of column name -> array-like (boolean-mask indexable).
        min_distance: Minimum distance to the TSS.
        max_distance: Maximum distance to the TSS.

    Returns:
        Filtered data with every column masked consistently.
    """
    tss_distances = data["distance_to_nearest_tss"]
    distance_mask = (tss_distances >= min_distance) & (tss_distances <= max_distance)
    return {column: values[distance_mask] for column, values in data.items()}

class MLP_VEP(nn.Module):
    """
    Three-layer Softplus MLP head for Variant Effect Prediction (VEP).

    Args:
        input_size: Input size for the first linear layer.
        hidden_size: Hidden layer size (used by both hidden layers).
        output_size: Output size for the final linear layer.
    """
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.sp1 = nn.Softplus()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.sp2 = nn.Softplus()
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.sp1(self.fc1(x))
        hidden = self.sp2(self.fc2(hidden))
        return self.fc3(hidden)
+ """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim found for the Foundation Model: {self.inner_dim}") + self.head = MLP_VEP(input_size=2 * self.inner_dim + 1, hidden_size=2 * self.inner_dim, output_size=2) + + def forward(self, alt_input_ids, ref_input_ids, variant_idx, tissue_embed): + # Get embeddings for the alternate and reference sequences + embeds_alternate = self.backbone(alt_input_ids)[0] + embeds_reference = self.backbone(ref_input_ids)[0] + window_size = WINDOW_SIZE_BP // self.bp_per_token // 2 + num_channels = embeds_alternate.size(-1) + + if self.rcps: + # Reverse complement processing + embeds_alt = embeds_alternate[..., :num_channels // 2] + embeds_ref = embeds_reference[..., :num_channels // 2] + + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alt.size(1) - 1) + + # Extract windowed embeddings + windowed_embeds_alt = torch.gather(embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alt.size(2))) + windowed_embeds_ref = torch.gather(embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_ref.size(2))) + + mean_embeds_alt = windowed_embeds_alt.mean(dim=1) + mean_embeds_ref = windowed_embeds_ref.mean(dim=1) + + # Concatenate the embeddings + concat_embeds = torch.cat([mean_embeds_alt, mean_embeds_ref, tissue_embed[..., None]], dim=-1) + + #Same for the 
RC-Equivalent part + rc_embeds_alt = embeds_alternate[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + rc_embeds_ref = embeds_reference[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + + rc_windowed_embeds_alt = torch.gather(rc_embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_alt.size(2))) + rc_windowed_embeds_ref = torch.gather(rc_embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_ref.size(2))) + + rc_mean_embeds_alt = rc_windowed_embeds_alt.mean(dim=1) + rc_mean_embeds_ref = rc_windowed_embeds_ref.mean(dim=1) + + rc_concat_embeds = torch.cat([rc_mean_embeds_alt, rc_mean_embeds_ref, tissue_embed[..., None]], dim=-1) + + final_window = concat_embeds + rc_concat_embeds + return self.head(final_window) + + else: + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alternate.size(1) - 1) + + windowed_embeds_alt = torch.gather(embeds_alternate, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alternate.size(2))).mean(dim=1) + windowed_embeds_ref = torch.gather(embeds_reference, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_reference.size(2))).mean(dim=1) + + concat_embeds = torch.cat([windowed_embeds_alt, windowed_embeds_ref, tissue_embed[..., None]], dim=-1) + return self.head(concat_embeds) + +class LitVEPFinetuning(pl.LightningModule): + """ + PyTorch Lightning model for fine-tuning on Variant Effect Prediction. + + Args: + args: Arguments containing model and training configurations. 
+ """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.setup() + + def setup(self,stage=None): + self.model = DNAModelForVEPFinetuning(self.hparams) + self.criterion = nn.CrossEntropyLoss() + self.validation_step_preds = {i: [] for i in range(len(DIST_TO_TSS))} + self.validation_step_labels = {i: [] for i in range(len(DIST_TO_TSS))} + + def forward(self, alt_input_ids, ref_input_ids, variant_idx, tissue_embed): + return self.model(alt_input_ids, ref_input_ids, variant_idx, tissue_embed) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + tissue_embed = batch["tissue_embed"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index, tissue_embed) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + tissue_embed = batch["tissue_embed"] + labels = batch["labels"] + distance_to_nearest_tss = batch["distance_to_nearest_tss"] + + logits = self(alt_input_ids, ref_input_ids, variant_index, tissue_embed) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=False, sync_dist=True) + + # Predictions for AUROC + preds = torch.argmax(logits, dim=1).detach().cpu().numpy() + labels_np = labels.cpu().numpy() + + for i, (min_dist, max_dist) in enumerate(DIST_TO_TSS): + mask = ((distance_to_nearest_tss >= min_dist) & (distance_to_nearest_tss < max_dist)).cpu().numpy() + filtered_preds = preds[mask] + filtered_labels = labels_np[mask] + + if len(filtered_labels) > 0: + self.validation_step_labels[i].extend(filtered_labels) + self.validation_step_preds[i].extend(filtered_preds) + + def test_step(self, batch, 
batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + tissue_embed = batch["tissue_embed"] + labels = batch["labels"] + distance_to_nearest_tss = batch["distance_to_nearest_tss"] + + logits = self(alt_input_ids, ref_input_ids, variant_index, tissue_embed) + + # Predictions for AUROC + preds = torch.argmax(logits, dim=1).detach().cpu().numpy() + labels_np = labels.cpu().numpy() + + for i, (min_dist, max_dist) in enumerate(DIST_TO_TSS): + mask = ((distance_to_nearest_tss >= min_dist) & (distance_to_nearest_tss < max_dist)).cpu().numpy() + filtered_preds = preds[mask] + filtered_labels = labels_np[mask] + + if len(filtered_labels) > 0: + self.validation_step_labels[i].extend(filtered_labels) + self.validation_step_preds[i].extend(filtered_preds) + + def on_validation_epoch_end(self): + # Initialize lists to store all labels and predictions across all TSS distance buckets + all_labels = [] + all_preds = [] + + for i, (min_dist, max_dist) in enumerate(DIST_TO_TSS): + if len(self.validation_step_labels[i]) > 0: + val_auroc = roc_auc_score(self.validation_step_labels[i], self.validation_step_preds[i]) + precision, recall, _ = precision_recall_curve(self.validation_step_labels[i], self.validation_step_preds[i]) + val_auprc = auc(recall, precision) + val_accuracy = accuracy_score(self.validation_step_labels[i], self.validation_step_preds[i]) + + # Log metrics for each TSS distance bucket + self.log(f'validation/TSS({min_dist}-{max_dist})/AUROC', val_auroc, on_epoch=True, sync_dist=True) + self.log(f'validation/TSS({min_dist}-{max_dist})/Accuracy', val_accuracy, on_epoch=True, sync_dist=True) + self.log(f'validation/TSS({min_dist}-{max_dist})/AUPRC', val_auprc, on_epoch=True, sync_dist=True) + print(f'Bucket {i} [{min_dist}-{max_dist}] - AUROC: {val_auroc:.4f}') + + # Aggregate the labels and predictions + all_labels.extend(self.validation_step_labels[i]) + 
all_preds.extend(self.validation_step_preds[i]) + + self.validation_step_labels[i].clear() + self.validation_step_preds[i].clear() + + # Compute overall metrics if there are any labels + if all_labels: + overall_auroc = roc_auc_score(all_labels, all_preds) + precision, recall, _ = precision_recall_curve(all_labels, all_preds) + overall_auprc = auc(recall, precision) + overall_accuracy = accuracy_score(all_labels, all_preds) + + # Log overall metrics + self.log('validation/overall/AUROC', overall_auroc, on_epoch=True, sync_dist=True) + self.log('validation/overall/Accuracy', overall_accuracy, on_epoch=True, sync_dist=True) + self.log('validation/overall/AUPRC', overall_auprc, on_epoch=True, sync_dist=True) + print(f'Overall - AUROC: {overall_auroc:.4f}') + + + def on_test_epoch_end(self): + # Initialize lists to store all labels and predictions across all TSS distance buckets + all_labels = [] + all_preds = [] + + for i, (min_dist, max_dist) in enumerate(DIST_TO_TSS): + if len(self.validation_step_labels[i]) > 0: + val_auroc = roc_auc_score(self.validation_step_labels[i], self.validation_step_preds[i]) + precision, recall, _ = precision_recall_curve(self.validation_step_labels[i], self.validation_step_preds[i]) + val_auprc = auc(recall, precision) + val_accuracy = accuracy_score(self.validation_step_labels[i], self.validation_step_preds[i]) + + # Log metrics for each TSS distance bucket + self.log(f'test/TSS({min_dist}-{max_dist})/AUROC', val_auroc, on_epoch=True, sync_dist=True) + self.log(f'test/TSS({min_dist}-{max_dist})/Accuracy', val_accuracy, on_epoch=True, sync_dist=True) + self.log(f'test/TSS({min_dist}-{max_dist})/AUPRC', val_auprc, on_epoch=True, sync_dist=True) + print(f'Bucket {i} [{min_dist}-{max_dist}] - AUROC: {val_auroc:.4f}') + + # Aggregate the labels and predictions + all_labels.extend(self.validation_step_labels[i]) + all_preds.extend(self.validation_step_preds[i]) + + self.validation_step_labels[i].clear() + self.validation_step_preds[i].clear() 
+ + # Compute overall metrics if there are any labels + if all_labels: + overall_auroc = roc_auc_score(all_labels, all_preds) + precision, recall, _ = precision_recall_curve(all_labels, all_preds) + overall_auprc = auc(recall, precision) + overall_accuracy = accuracy_score(all_labels, all_preds) + + # Log overall metrics + self.log('test/overall/AUROC', overall_auroc, on_epoch=True, sync_dist=True) + self.log('test/overall/Accuracy', overall_accuracy, on_epoch=True, sync_dist=True) + self.log('test/overall/AUPRC', overall_auprc, on_epoch=True, sync_dist=True) + print(f'Overall - AUROC: {overall_auroc:.4f}') + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + +class VariantEffectPredictionDataModule(pl.LightningDataModule): + """ + Data module for Variant Effect Prediction finetuning with PyTorch Lightning. + + Args: + config: Configuration dictionary with data-related parameters. + """ + def __init__(self, config): + super().__init__() + self.seq_len = config.seq_len + self.bp_per_token = config.bp_per_token + self.model_name = config.model_name + self.train_batch_size = config.train_batch_size + self.test_batch_size = config.test_batch_size + self.num_workers = config.num_workers + self.train_ratio = config.train_ratio + self.eval_ratio = config.eval_ratio + self.cache_dir = "./" + self.dataset = None + + # Initialize the tokenizer + if "caduceus" in self.model_name: + self.tokenizer = CaduceusTokenizer( + model_max_length=self.seq_len, + add_special_tokens=False + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + + def prepare_data(self): + # Download and preprocess data if not already done + if not fsspec_exists(self._get_preprocessed_cache_file()): + self._download_and_preprocess_data() + + def setup(self, stage=None): + # Load the preprocessed dataset + self.prepare_data() + self.dataset = load_from_disk(self._get_preprocessed_cache_file()) + + # 
        # Split the dataset into train and validation sets
        self.test_dataset = self.dataset["test"]

    def train_dataloader(self):
        """Shuffled training loader over the non-held-out chromosomes."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True
        )

    def val_dataloader(self):
        """Deterministic loader over the held-out validation chromosome."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False
        )

    def test_dataloader(self):
        """Deterministic loader over the benchmark test split."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False
        )

    def _split_dataset(self, selected_validation_chromosome):
        """Hold out one chromosome for validation; train on all the others."""
        log.warning(f"SPLITTING THE DATASET INTO TRAIN AND VAL SET, VAL SET BEING CHROMOSOME {selected_validation_chromosome}")
        self.train_dataset = self.dataset["train"].filter(
            lambda example: example["chromosome"]
            != selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.val_dataset = self.dataset["train"].filter(
            lambda example: example["chromosome"]
            == selected_validation_chromosome,
            keep_in_memory=True,
        )
        self.validation_chromosome = selected_validation_chromosome

    def _get_preprocessed_cache_file(self):
        """Return the on-disk cache path for the preprocessed dataset ('=' sanitized to '_')."""
        cache_dir = osp.join(
            "./", "data", "InstaDeepAI___genomics-long-range-benchmark",
            "variant_effect_causal_eqtl", f"seqlen{self.seq_len}"
        )
        cache_file = os.path.join(cache_dir, "caduceus_char_token_preprocessed")
        return re.sub(r"=", "_", cache_file)

    def _download_and_preprocess_data(self):
        """Download the benchmark split, filter/tokenize it, and cache the result to disk."""
        log.warning(f"Downloading and preprocessing data...")
        dataset = load_dataset(
            "InstaDeepAI/genomics-long-range-benchmark",
            task_name="variant_effect_causal_eqtl",
            sequence_length=self.seq_len,
            load_from_cache=False,
            trust_remote_code=True
        )
        try:
            del dataset["validation"]  # Remove empty validation split if it exists
        except KeyError:
            pass

        # Drop sequences with more than 0.5% unknown bases.
        dataset = dataset.filter(
            lambda example: example["ref_forward_sequence"].count('N') < 0.005 * self.seq_len,
            desc="Filter N's"
        )
        dataset = dataset.map(
            recast_chromosome_tissue_dist2TSS,
            remove_columns=["chromosome", "tissue", "distance_to_nearest_tss"],
            desc="Recast chromosome"
        )
        dataset = dataset.map(
            partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len // self.bp_per_token),
            batch_size=1000,
            batched=True,
            remove_columns=["ref_forward_sequence", "alt_forward_sequence"],
            desc="Tokenize",
            num_proc=self.num_workers
        )
        dataset = dataset.map(find_variant_idx, desc="Find variant idx")

        # Integer-encode tissues; fitting on the test split assumes every train
        # tissue also appears in test -- TODO confirm against the benchmark data.
        label_encoder = preprocessing.LabelEncoder()
        label_encoder.fit(dataset["test"]["tissue"])
        train_tissue_embed = label_encoder.transform(dataset["train"]["tissue"])
        dataset["train"] = dataset["train"].add_column("tissue_embed", train_tissue_embed)
        test_tissue_embed = label_encoder.transform(dataset["test"]["tissue"])
        dataset["test"] = dataset["test"].add_column("tissue_embed", test_tissue_embed)

        # Save to disk if running locally
        dataset.save_to_disk(self._get_preprocessed_cache_file())

        log.warning(f"Data downloaded and preprocessed successfully.")

def finetune(args):
    """
    Main function to start process for Variant Effect Prediction Finetuning eQTL using PyTorch Lightning.

    Args:
        args: Command line arguments or configuration dictionary.
+ """ + wandb.login(key=args.wandb_api_key) + data_module = VariantEffectPredictionDataModule(args) + data_module.setup() + + np.random.seed(0) + candidates = np.unique(data_module.dataset["train"]["chromosome"]) + held_chromosomes = np.random.choice(candidates,5,replace = False) + + for idx,val_chromosome in enumerate(held_chromosomes): + + wandb_logger = WandbLogger( + name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}", + project="Variant Effect Prediction Causal eQTL", + log_model=True # Automatically log model checkpoints + ) + + data_module._split_dataset(val_chromosome) + model = LitVEPFinetuning(args) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{args.save_dir}/checkpoints", + filename=f"best-checkpoint-on-chromosome{val_chromosome}", + save_top_k=1, + verbose=True, + monitor="val_loss", + mode="min" + ) + + nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto" + + # Set up the PyTorch Lightning Trainer + trainer = pl.Trainer( + max_epochs=args.num_epochs, + devices=nb_device, + logger=wandb_logger, + callbacks=[checkpoint_callback], + log_every_n_steps=1, + limit_train_batches=args.train_ratio, + limit_val_batches=args.eval_ratio, + val_check_interval=args.log_interval, + gradient_clip_val=1.0, + precision=args.precision, + accumulate_grad_batches=args.accumulate_grad_batches, + num_sanity_val_steps=0 + ) + + # Start the training process + trainer.fit(model, data_module) + trainer.test(model,data_module,ckpt_path=f"./{args.save_dir}/checkpoints/best-checkpoint-on-chromosome{val_chromosome}.ckpt") + + # Finish the current WandB run + wandb.finish() diff --git a/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py new file mode 100644 index 0000000..da83a2b --- /dev/null +++ b/finetuning_glrb/finetune_variant_effect_pathogenic_clinvar.py @@ -0,0 +1,481 @@ +import os +from functools import partial +from os import path as osp +import re +import torch 
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig
import wandb
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from datasets import load_dataset, load_from_disk
from sklearn.metrics import precision_recall_curve, auc, roc_auc_score
from transformers import DefaultDataCollator
from src.utils.train import get_logger
from caduceus.tokenization_caduceus import CaduceusTokenizer
from finetuning_glrb.utils import fsspec_exists, get_last_embedding_dimension

# Logger setup
log = get_logger(__name__)

# Constants
WINDOW_SIZE_BP = 1536  # width (bp) of the window pooled around the variant position

def tokenize_variants(examples, tokenizer, max_length: int):
    """
    Tokenize reference and alternate sequences.

    Args:
        examples: A batch of items from the dataset.
        tokenizer: AutoTokenizer instance.
        max_length: Maximum length for tokenization.

    Returns:
        dict with tokenized input IDs for reference and alternate sequences.
    """
    ref_tokenized = tokenizer.batch_encode_plus(
        examples["ref_forward_sequence"],
        add_special_tokens=False,
        return_attention_mask=False,
        max_length=max_length,
        truncation=True,
    )
    alt_tokenized = tokenizer.batch_encode_plus(
        examples["alt_forward_sequence"],
        add_special_tokens=False,
        return_attention_mask=False,
        max_length=max_length,
        truncation=True,
    )
    return {
        "ref_input_ids": ref_tokenized["input_ids"],
        "alt_input_ids": alt_tokenized["input_ids"]
    }

def recast_chromosome(examples):
    """
    Recast chromosome to integer format (X/Y mapped to -1).

    Returns:
        dict with recast chromosome as an integer.
    """
    return {
        "chromosome": -1 if examples["chromosome"] in ["X", "Y"] else int(examples["chromosome"])
    }

def find_variant_idx(examples):
    """
    Find the index of the variant in the sequence.
def find_variant_idx(examples):
    """
    Locate the variant position in a tokenized ref/alt pair.

    The variant is assumed to sit at the midpoint of the window; if ref and alt
    agree there (e.g. tokenization shifted the boundary), fall back to scanning
    for the first position where the two token sequences differ.

    Args:
        examples: A single (non-batched) dataset item.

    Returns:
        dict with "variant_idx" (-1 if the sequences are identical).
    """
    idx = len(examples["ref_input_ids"]) // 2  # Assume variant is at the midpoint
    if examples["ref_input_ids"][idx] == examples["alt_input_ids"][idx]:
        idx = -1
        for i, (ref, alt) in enumerate(zip(examples["ref_input_ids"], examples["alt_input_ids"])):
            if ref != alt:
                # BUGFIX: stop at the first mismatch. The original loop had no
                # break, so it scanned the whole sequence and returned the LAST
                # differing position instead of the variant's.
                idx = i
                break
    return {"variant_idx": idx}

class MLP_VEP_ClinVar(nn.Module):
    """
    One-hidden-layer Softplus MLP head for Variant Effect Prediction (ClinVar).

    Args:
        input_size: Input size for the linear layer.
        hidden_size: Hidden layer size.
        output_size: Output size for the linear layer.
    """
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.sp1 = nn.Softplus()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.fc2(self.sp1(self.fc1(x)))

class DNAModelForVEPFinetuning(nn.Module):
    """
    DNA Model for Variant Effect Prediction (ClinVar) fine-tuning.

    Args:
        args: Arguments containing model configurations.
    """
+ """ + def __init__(self, args): + super().__init__() + self.rcps = args.rcps + self.bp_per_token = args.bp_per_token + self.config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + if "nucleotide-transformer" in args.model_name.lower(): + self.backbone = AutoModelForMaskedLM.from_pretrained(args.model_name, trust_remote_code=True).esm + else: + self.backbone = AutoModel.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"MODEL LOADED: {self.backbone}") + self.inner_dim = get_last_embedding_dimension(self.backbone,self.rcps) + print(f"Inner dim found for the Foundation Model: {self.inner_dim}") + self.head = MLP_VEP_ClinVar(input_size=2*self.inner_dim, hidden_size=2*self.inner_dim, output_size=2) + + def forward(self, alt_input_ids, ref_input_ids, variant_idx): + embeds_alternate = self.backbone(alt_input_ids)[0] + embeds_reference = self.backbone(ref_input_ids)[0] + window_size = WINDOW_SIZE_BP // self.bp_per_token // 2 + num_channels = embeds_alternate.size(-1) + + if self.rcps: + embeds_alt = embeds_alternate[..., :num_channels // 2] + embeds_ref = embeds_reference[..., :num_channels // 2] + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alt.size(1) - 1) + + windowed_embeds_alt = torch.gather(embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alt.size(2))) + windowed_embeds_ref = torch.gather(embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_ref.size(2))) + + mean_embeds_alt = windowed_embeds_alt.mean(dim=1) + mean_embeds_ref = windowed_embeds_ref.mean(dim=1) + + concat_embeds = torch.cat([mean_embeds_alt, mean_embeds_ref], dim=-1) + + rc_embeds_alt = embeds_alternate[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + rc_embeds_ref = embeds_reference[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) + + rc_windowed_embeds_alt = 
torch.gather(rc_embeds_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_alt.size(2))) + rc_windowed_embeds_ref = torch.gather(rc_embeds_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, rc_embeds_ref.size(2))) + + rc_mean_embeds_alt = rc_windowed_embeds_alt.mean(dim=1) + rc_mean_embeds_ref = rc_windowed_embeds_ref.mean(dim=1) + + rc_concat_embeds = torch.cat([rc_mean_embeds_alt, rc_mean_embeds_ref], dim=-1) + + final_window = concat_embeds + rc_concat_embeds + return self.head(final_window) + + else: + expanded_indices = torch.arange(-window_size, window_size + 1, device=variant_idx.device).unsqueeze(0) + variant_idx.unsqueeze(1) + expanded_indices = torch.clamp(expanded_indices, 0, embeds_alternate.size(1) - 1) + + windowed_embeds_alt = torch.gather(embeds_alternate, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_alternate.size(2))).mean(dim=1) + windowed_embeds_ref = torch.gather(embeds_reference, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, embeds_reference.size(2))).mean(dim=1) + + concat_embeds = torch.cat([windowed_embeds_alt, windowed_embeds_ref], dim=-1) + return self.head(concat_embeds) + +class Lit_ClinVarFinetuning(pl.LightningModule): + """ + PyTorch Lightning model for fine-tuning on ClinVar Variant Effect Prediction. + + Args: + args: Arguments containing model and training configurations. 
+ """ + def __init__(self, args): + super().__init__() + self.save_hyperparameters(args) + self.setup() + + def setup(self,stage=None): + self.model = DNAModelForVEPFinetuning(self.hparams) + self.criterion = nn.CrossEntropyLoss() + self.validation_step_preds = [] + self.validation_step_labels = [] + self.training_step_preds = [] + self.training_step_labels = [] + self.training_step_correct = 0 + self.training_step_total = 0 + self.validation_step_correct = 0 + self.validation_step_total = 0 + + def forward(self, alt_input_ids, ref_input_ids, variant_idx): + return self.model(alt_input_ids, ref_input_ids, variant_idx) + + def training_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index) + loss = self.criterion(logits, labels) + self.log('train_loss', loss, on_epoch=True, on_step=True, sync_dist=True) + + preds = torch.argmax(logits, dim=1) + correct = (preds == labels).sum().item() + self.training_step_correct += correct + self.training_step_total += len(labels) + + all_labels = labels.cpu().numpy() + all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() + self.training_step_preds.extend(all_predictions) + self.training_step_labels.extend(all_labels) + + return loss + + def validation_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index) + loss = self.criterion(logits, labels) + self.log('val_loss', loss, on_epoch=True, on_step=False, sync_dist=True) + + preds = torch.argmax(logits, dim=1) + correct = (preds == labels).sum().item() + self.validation_step_correct += correct + self.validation_step_total += len(labels) + + all_labels = labels.cpu().numpy() + all_predictions = 
torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() + self.validation_step_preds.extend(all_predictions) + self.validation_step_labels.extend(all_labels) + + def test_step(self, batch, batch_idx): + ref_input_ids = batch["ref_input_ids"] + alt_input_ids = batch["alt_input_ids"] + variant_index = batch["variant_idx"] + labels = batch["labels"] + + logits = self(alt_input_ids, ref_input_ids, variant_index) + preds = torch.argmax(logits, dim=1) + correct = (preds == labels).sum().item() + self.validation_step_correct += correct + self.validation_step_total += len(labels) + + all_labels = labels.cpu().numpy() + all_predictions = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy() + self.validation_step_preds.extend(all_predictions) + self.validation_step_labels.extend(all_labels) + + def on_validation_epoch_end(self): + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_accuracy = self.validation_step_correct / self.validation_step_total + + self.logger.experiment.log({ + "validation/AUROC": val_auroc, + "validation/Accuracy": val_accuracy, + "validation/AUPRC": val_auprc, + }) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + self.validation_step_correct = 0 + self.validation_step_total = 0 + + def on_test_epoch_end(self): + val_auroc = roc_auc_score(self.validation_step_labels, self.validation_step_preds) + precision, recall, _ = precision_recall_curve(self.validation_step_labels, self.validation_step_preds) + val_auprc = auc(recall, precision) + val_accuracy = self.validation_step_correct / self.validation_step_total + + self.logger.experiment.log({ + "test/AUROC": val_auroc, + "test/Accuracy": val_accuracy, + "test/AUPRC": val_auprc, + }) + + self.validation_step_labels.clear() + self.validation_step_preds.clear() + self.validation_step_correct = 0 
class VariantEffectPredictionDataModule(pl.LightningDataModule):
    """Lightning data module for ClinVar variant-effect-prediction fine-tuning.

    Downloads the "variant_effect_pathogenic_clinvar" task of the Genomics
    Long Range Benchmark, filters out N-heavy sequences, tokenizes ref/alt
    sequence pairs, caches the result on disk, and serves train/val/test
    dataloaders. Train/validation folds are produced by holding out one
    chromosome at a time (see `_split_dataset`).

    Args:
        config: Argparse namespace with data-related parameters: seq_len,
            bp_per_token, model_name, train/test batch sizes, num_workers,
            train_ratio, eval_ratio.
    """
    def __init__(self, config):
        super().__init__()
        self.seq_len = config.seq_len
        # BUGFIX: this attribute was stored as `self.bp_per_tokem` (typo) but is
        # read back as `self.bp_per_token` in `_download_and_preprocess_data`,
        # which raised AttributeError the first time the data had to be prepared.
        self.bp_per_token = config.bp_per_token
        self.model_name = config.model_name
        self.train_batch_size = config.train_batch_size
        self.test_batch_size = config.test_batch_size
        self.num_workers = config.num_workers
        self.train_ratio = config.train_ratio
        self.eval_ratio = config.eval_ratio
        self.cache_dir = "./"  # refined by _get_preprocessed_cache_file()
        self.dataset = None

        if "caduceus" in self.model_name:
            # Caduceus ships its own character-level tokenizer.
            self.tokenizer = CaduceusTokenizer(
                model_max_length=self.seq_len,
                add_special_tokens=False
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

    def prepare_data(self):
        # Download/tokenize only once; subsequent runs reuse the on-disk cache.
        if not fsspec_exists(self._get_preprocessed_cache_file()):
            self._download_and_preprocess_data()

    def setup(self, stage=None):
        self.prepare_data()
        self.dataset = load_from_disk(self._get_preprocessed_cache_file())

        # The held-out test split is fixed; train/val folds are assigned per
        # chromosome by `_split_dataset` before each fit.
        self.test_dataset = self.dataset["test"]

    def _split_dataset(self, val_chromosome):
        """Build train/val folds by holding out one chromosome.

        BUGFIX: this method is invoked from `finetune` but was missing from
        the class (and `train_dataset`/`val_dataset` were never assigned), so
        training crashed with AttributeError before the first fold.
        NOTE(review): reconstructed from the caller's leave-one-chromosome-out
        cross-validation loop — confirm against the original intent.

        Args:
            val_chromosome: Chromosome identifier held out for validation.
        """
        full_train = self.dataset["train"]
        self.train_dataset = full_train.filter(
            lambda example: example["chromosome"] != val_chromosome,
            desc=f"Hold out chromosome {val_chromosome} (train)"
        )
        self.val_dataset = full_train.filter(
            lambda example: example["chromosome"] == val_chromosome,
            desc=f"Hold out chromosome {val_chromosome} (val)"
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.test_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            collate_fn=DefaultDataCollator(return_tensors="pt"),
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False
        )

    def _get_preprocessed_cache_file(self):
        """Return the on-disk cache path for the preprocessed dataset."""
        self.cache_dir = osp.join(
            "./", "data", "InstaDeepAI___genomics-long-range-benchmark",
            "variant_effect_pathogenic_clinvar", f"seqlen{self.seq_len}"
        )
        cache_file = os.path.join(self.cache_dir, "caduceus_char_token_preprocessed")
        # '=' is not filesystem-safe on every backend; normalize to '_'.
        return re.sub(r"=", "_", cache_file)

    def _download_and_preprocess_data(self):
        """Download the raw benchmark split, filter, tokenize, and cache it."""
        log.warning("Downloading and preprocessing data...")
        dataset = load_dataset(
            "InstaDeepAI/genomics-long-range-benchmark",
            task_name="variant_effect_pathogenic_clinvar",
            sequence_length=self.seq_len,
            load_from_cache=False,
            trust_remote_code=True
        )
        # The upstream validation split is unused: validation folds are built
        # from the train split per held-out chromosome.
        try:
            del dataset["validation"]
        except KeyError:
            pass

        # Drop sequences with more than 0.5% unknown bases.
        dataset = dataset.filter(
            lambda example: example["ref_forward_sequence"].count('N') < 0.005 * self.seq_len,
            desc="Filter N's"
        )
        dataset = dataset.map(
            recast_chromosome,
            remove_columns=["chromosome"],
            desc="Recast chromosome"
        )
        dataset = dataset.map(
            partial(tokenize_variants, tokenizer=self.tokenizer, max_length=self.seq_len // self.bp_per_token),
            batch_size=1000,
            batched=True,
            remove_columns=["ref_forward_sequence", "alt_forward_sequence"],
            desc="Tokenize",
            num_proc=self.num_workers
        )
        dataset = dataset.map(find_variant_idx, desc="Find variant idx")

        dataset.save_to_disk(self._get_preprocessed_cache_file())
        log.warning("Data downloaded and preprocessed successfully.")

def finetune(args):
    """Run leave-one-chromosome-out fine-tuning for ClinVar VEP.

    Holds out 5 randomly chosen chromosomes, one per fold; each fold trains a
    fresh model, checkpoints on best val_loss, and evaluates on the test split.

    Args:
        args: Command-line arguments (see main.py for the full list).
    """
    wandb.login(key=args.wandb_api_key)
    data_module = VariantEffectPredictionDataModule(args)
    data_module.setup()

    # Fixed seed so the held-out chromosomes are reproducible across runs.
    np.random.seed(0)
    candidates = np.unique(data_module.dataset["train"]["chromosome"])
    held_chromosomes = np.random.choice(candidates, 5, replace=False)

    for idx, val_chromosome in enumerate(held_chromosomes):

        wandb_logger = WandbLogger(
            name=f"{args.name_wb}-{args.seq_len}-fold-{idx+1}",
            project="Variant Effect Prediction ClinVar",
            log_model=True
        )
        data_module._split_dataset(val_chromosome)
        model = Lit_ClinVarFinetuning(args)

        checkpoint_callback = ModelCheckpoint(
            dirpath=f"{args.save_dir}/checkpoints",
            filename=f"best-checkpoint-on-chromosome{val_chromosome}",
            save_top_k=1,
            verbose=True,
            monitor="val_loss",
            mode="min"
        )

        # NTv2 models are restricted to a single device here; others let
        # Lightning pick automatically.
        nb_device = "1" if "nucleotide-transformer" in args.model_name.lower() else "auto"

        trainer = pl.Trainer(
            max_epochs=args.num_epochs,
            devices=nb_device,
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            log_every_n_steps=1,
            limit_train_batches=args.train_ratio,
            limit_val_batches=args.eval_ratio,
            val_check_interval=args.log_interval,
            gradient_clip_val=1.0,
            precision=args.precision,
            accumulate_grad_batches=args.accumulate_grad_batches,
            num_sanity_val_steps=0
        )

        trainer.fit(model, data_module)
        # BUGFIX: the path was rebuilt as f"./{args.save_dir}/..." which breaks
        # for absolute save_dir values; ask the callback for the actual best path.
        trainer.test(model, data_module, ckpt_path=checkpoint_callback.best_model_path)

        # Finish the current WandB run before starting the next fold.
        wandb.finish()
def str2bool(value):
    """Parse a command-line boolean ("true"/"false", "yes"/"no", "1"/"0").

    BUGFIX: the original used `type=bool` for `--rcps`, but `bool("false")`
    is True because any non-empty string is truthy, so `--rcps false`
    silently enabled RCPS. This parser accepts the usual spellings and
    rejects anything else with a proper argparse error.

    Args:
        value: Raw command-line string (or an actual bool, passed through).

    Returns:
        bool: the parsed value.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main(opts):
    """Dispatch `opts.task` to the matching fine-tuning entry point.

    Args:
        opts: Parsed command-line arguments (see the argparse block below).

    Raises:
        ValueError: if the task name is unknown (argparse `choices` should
            prevent this, but the original elif chain fell through silently).
    """
    # Imported lazily so that `--help` and argument validation do not require
    # the heavy training dependencies to be importable.
    from finetuning_glrb.finetune_variant_effect_pathogenic_clinvar import finetune as finetune_vep_clinvar
    from finetuning_glrb.finetune_variant_effect_OMIM import finetune as finetune_omim
    from finetuning_glrb.finetune_variant_effect_causal_eqtl import finetune as finetune_vep_eqtl
    from finetuning_glrb.finetune_bulk_rna import finetune as finetune_bulk_rna_expression
    from finetuning_glrb.finetune_chromatin import finetune_histone_marks, finetune_dna_accessibility
    from finetuning_glrb.finetune_regulatory_elements import finetune_enhancers, finetune_promoters
    from finetuning_glrb.finetune_cage import finetune as finetune_cage

    runners = {
        "variant_effect_causal_eqtl": finetune_vep_eqtl,
        "variant_effect_pathogenic_clinvar": finetune_vep_clinvar,
        "variant_effect_pathogenic_omim": finetune_omim,
        "bulk_rna_expression": finetune_bulk_rna_expression,
        "cage_prediction": finetune_cage,
        "chromatin_features_histone_marks": finetune_histone_marks,
        "chromatin_features_dna_accessibility": finetune_dna_accessibility,
        "regulatory_element_promoter": finetune_promoters,
        "regulatory_element_enhancer": finetune_enhancers,
    }
    try:
        runner = runners[opts.task]
    except KeyError:
        raise ValueError(f"Unknown task: {opts.task!r}") from None
    runner(opts)


if __name__ == "__main__":
    # Avoid "too many open files" with multi-worker dataloaders.
    torch.multiprocessing.set_sharing_strategy('file_system')

    from src.utils.train import get_logger
    log = get_logger(__name__)

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--task",
        type=str,
        choices=[
            "variant_effect_causal_eqtl",
            "variant_effect_pathogenic_clinvar",
            "variant_effect_pathogenic_omim",
            "cage_prediction",
            "bulk_rna_expression",
            "chromatin_features_histone_marks",
            "chromatin_features_dna_accessibility",
            "regulatory_element_promoter",
            "regulatory_element_enhancer"
        ],
        required=True,
        help="Choose one of the predefined variant effects."
    )
    parser.add_argument("--seq_len", type=int, default=131072,
                        help="Sequence length (in bp).")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Name of the pre-trained model to fine-tune (on the HF hub).")
    parser.add_argument("--bp_per_token", type=int, default=1,
                        help="Number of base pairs per token.")
    parser.add_argument("--save_dir", type=str, default="./outputs/downstream/vep_embeddings",
                        help="Directory to save downstream task.")
    parser.add_argument("--wandb_api_key", type=str, default=None,
                        help="Weights & Biases API key for logging.")
    parser.add_argument("--name_wb", type=str, default=None,
                        help="Name for the Weights & Biases run.")
    parser.add_argument("--train_batch_size", type=int, default=16, help="Batch size for training.")
    parser.add_argument("--test_batch_size", type=int, default=16, help="Batch size for testing/validation.")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers.")
    # BUGFIX: was type=bool, which turned any non-empty string (incl. "false")
    # into True; `--rcps true`/`--rcps false` now both behave as written.
    parser.add_argument("--rcps", type=str2bool, default=False,
                        help="Use RCPS when extracting embeddings or not.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train.")
    parser.add_argument(
        "--precision",
        type=str,
        choices=['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed',
                 'bf16-true', 'bf16-mixed', '32-true', '64-true'],
        required=True,
        help="Choose one mode."
    )
    parser.add_argument("--accumulate_grad_batches", type=int, default=1,
                        help="Accumulate gradients")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="Learning rate for optimizer")
    parser.add_argument("--log_interval", type=int, default=5,
                        help="Log interval")
    parser.add_argument("--train_ratio", type=float, default=1.0,
                        help="Training data ratio")
    parser.add_argument("--eval_ratio", type=float, default=1.0,
                        help="Evaluation data ratio")
    opts = parser.parse_args()
    log.warning("*** Args ************************")
    for k, v in vars(opts).items():
        log.warning(f"  - {k}: {v}")
    log.warning("******************************\n")

    main(opts)
def fsspec_exists(filename):
    """Check if *filename* exists, in a manner compatible with fsspec
    (works for local paths and remote URLs alike)."""
    # Imported lazily so this utils module stays importable (e.g. for the
    # embedding-dimension probe below) when fsspec is not installed.
    import fsspec
    fs, _ = fsspec.core.url_to_fs(filename)
    return fs.exists(filename)


def fsspec_listdir(dirname):
    """List directory contents in a manner compatible with fsspec."""
    import fsspec
    fs, _ = fsspec.core.url_to_fs(dirname)
    return fs.ls(dirname)


def get_last_embedding_dimension(model: nn.Module, rcps: bool = False) -> int:
    """Infer the last embedding dimension of *model* with a random forward pass.

    Passes a small random tensor through the model with gradients disabled
    and reads the last dimension of the output. Runs on GPU when available
    and falls back to CPU otherwise (the original hard-coded CUDA and
    crashed on CPU-only machines).

    Args:
        model: The PyTorch model instance to probe.
        rcps: If True, the model output concatenates forward and
            reverse-complement channels, so the result is halved.

    Returns:
        int: The last dimension of the output tensor (halved when rcps=True).

    Raises:
        ValueError: if no recognizable first layer (Conv2d/Linear/Embedding)
            is found from which to derive an input shape.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Derive a plausible probe input from the first recognizable layer.
    probe = None
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # BUGFIX: Conv2d/Linear require float inputs; the original fed
            # integer tensors (torch.randint), which raises a dtype error.
            probe = torch.randn(1, 3, 224, 224, device=device)  # assume RGB 224x224
            break
        if isinstance(module, nn.Linear):
            probe = torch.randn(1, module.in_features, device=device)
            break
        if isinstance(module, nn.Embedding):
            # BUGFIX: keep indices within the embedding table; the original's
            # fixed high=16 could exceed num_embeddings for small vocabularies.
            high = min(16, module.num_embeddings)
            probe = torch.randint(0, high, (1, 1), device=device)
            break
    if probe is None:
        raise ValueError("Unable to determine the input shape automatically.")

    # Forward pass with no gradients; [0] drops the batch dimension (or takes
    # the first element of a tuple output) — either way the last dim is the
    # embedding dimension.
    with torch.no_grad():
        output = model(probe)[0]

    last_embedding_dimension = output.shape[-1]

    # RCPS models double the channel dimension (forward + reverse complement).
    if rcps:
        last_embedding_dimension //= 2

    return last_embedding_dimension