diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 7052740..2fb5427 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -39,4 +39,4 @@ python: install: - requirements: docs/requirements.txt - method: pip - path: . \ No newline at end of file + path: . diff --git a/README.md b/README.md index 9383258..2cc0637 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Code License: Apache 2.0](https://img.shields.io/badge/license-Apache%20License%202.0-blue)](https://www.apache.org/licenses/LICENSE-2.0) [![PyPI version](https://badge.fury.io/py/cell2sentence.svg)](https://badge.fury.io/py/cell2sentence) [![DOI:10.1101/2025.04.14.648850](http://img.shields.io/badge/DOI-10.1101/2025.04.14.648850-B31B1B.svg)](https://doi.org/10.1101/2025.04.14.648850) -[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/) +[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/release/python-31020/) ![cell2sentence workflow image](c2s_overview_figure.png) @@ -47,7 +47,7 @@ git clone https://github.com/vandijklab/cell2sentence.git Navigate a terminal into the root of the repository. Next, create an Anaconda environment using `python3` using [anaconda](https://docs.anaconda.com/anaconda/install/) with: ```bash -conda create -n cell2sentence python=3.8 +conda create -n cell2sentence python=3.10 ``` Next, activate the environment: @@ -62,7 +62,7 @@ make install This will install the latest development environment of cell2sentence, along with other pacakge dependendies. You can also install cell2sentence itself using `pip`: ```bash -pip install cell2sentence==1.1.0 +pip install cell2sentence==1.2.0 ``` The C2S package will allow usage of the core functionalities of C2S, including inference using existing C2S models and finetuning your own C2S models on your own datasets. @@ -86,6 +86,10 @@ The following notebooks provide guides on common workflows with C2S models. 
For | [c2s_tutorial_4_cell_type_prediction.ipynb](tutorials/c2s_tutorial_4_cell_type_prediction.ipynb) | Cell type prediction using C2S models | [c2s_tutorial_5_cell_generation.ipynb](tutorials/c2s_tutorial_5_cell_generation.ipynb) | Cell generation conditioned on cell type | [c2s_tutorial_6_cell_annotation_with_foundation_model.ipynb](tutorials/c2s_tutorial_6_cell_annotation_with_foundation_model.ipynb) | Cell type annotation with foundation model +| [c2s_tutorial_7_custom_prompt_templates.ipynb](tutorials/c2s_tutorial_7_custom_prompt_templates.ipynb) | Custom Prompt Templates with C2S PromptFormatter class +| [c2s_tutorial_8_multi_cell_tissue_prediction.ipynb](tutorials/c2s_tutorial_8_multi_cell_tissue_prediction.ipynb) | Classifying the Tissue based on Multiple cell sentences +| [c2s_tutorial_9_natural_language_interpretation.ipynb](tutorials/c2s_tutorial_9_natural_language_interpretation.ipynb) | Use the C2S model to generate insightful summaries for different sets of cells +| [c2s_tutorial_10_perturbation_response_prediction.ipynb](tutorials/c2s_tutorial_10_perturbation_response_prediction.ipynb)| Use the C2S model to perform perturbation based on pre-cell sentences and treatment ## Model Zoo @@ -111,7 +115,7 @@ each explain which model they use. 
- [x] Add tutorial notebooks for main C2S workflows: cell type prediction, cell generation - [x] Add multi-cell prompt formatting - [ ] Add support for legacy C2S-GPT-2 model prompts -- [ ] Add parameter-efficient finetuning methods (LoRA) +- [x] Add parameter-efficient finetuning methods (LoRA) ## License diff --git a/docs/Makefile b/docs/Makefile index 8b95755..19417c7 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -20,3 +20,4 @@ help: @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) # Example: sphinx-build -M html docs/source/ docs/build/ + diff --git a/docs/make.bat b/docs/make.bat index dc1312a..9d3a30f 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -33,3 +33,4 @@ goto end :end popd + diff --git a/docs/source/csmodel.rst b/docs/source/csmodel.rst index 6eff74d..ab75469 100644 --- a/docs/source/csmodel.rst +++ b/docs/source/csmodel.rst @@ -1,5 +1,3 @@ -CSModel -======= A CSModel object is a wrapper around a Cell2Sentence model, which tracks the path of the model saved on disk. When needed, the model is loaded from the path on disk for inference or finetuning. @@ -21,4 +19,4 @@ The class contains utilities for model generation and cell embedding with a Hugg .. autofunction:: csmodel.CSModel.embed_cells_batched -.. autofunction:: csmodel.CSModel.push_model_to_hub +.. 
autofunction:: csmodel.CSModel.push_model_to_hub \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 9775de0..1ce4c69 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,44 +1,49 @@ -[metadata] -name = cell2sentence -version = 1.2.0 -author = Syed Asad Rizvi -author_email = syed.rizvi@yale.edu -description = Cell2Sentence: Single-cell Analysis With LLMs -long_description = file: README.md -long_description_content_type = text/markdown -url = https://github.com/vandijklab/cell2sentence -license = 'BY-NC-ND' -project_urls = - Bug Tracker = https://github.com/vandijklab/cell2sentence/issues -classifiers = - Programming Language :: Python :: 3 - Development Status :: 2 - Pre-Alpha - Operating System :: OS Independent - -[options] -package_dir = - = src -packages = find: -python_requires = >=3.7 -install_requires = - torch - transformers - datasets - anndata - scanpy - numpy - pandas - scipy - tqdm - scikit-learn - jupyterlab - accelerate - plotnine - sphinx - sphinx-rtd-theme - -[options.packages.find] -where = src - -[options.package_data] +[metadata] +name = cell2sentence +version = 1.2.0 +author = Syed Asad Rizvi +author_email = syed.rizvi@yale.edu +description = Cell2Sentence: Single-cell Analysis With LLMs +long_description = file: README.md +long_description_content_type = text/markdown +url = https://github.com/vandijklab/cell2sentence +license = 'BY-NC-ND' +project_urls = + Bug Tracker = https://github.com/vandijklab/cell2sentence/issues +classifiers = + Programming Language :: Python :: 3 + Development Status :: 2 - Pre-Alpha + Operating System :: OS Independent + +[options] +package_dir = + = src +packages = find: +python_requires = >=3.10 +install_requires = + torch + transformers + peft + bitsandbytes + datasets + anndata + scanpy + numpy + pandas + scipy + tqdm + scikit-learn + jupyterlab + accelerate + plotnine + sphinx + sphinx-rtd-theme + tiktoken + sentencepiece + protobuf + +[options.packages.find] +where = src + +[options.package_data] * = 
*.json \ No newline at end of file diff --git a/src/cell2sentence/csmodel.py b/src/cell2sentence/csmodel.py index 7490619..38eb580 100644 --- a/src/cell2sentence/csmodel.py +++ b/src/cell2sentence/csmodel.py @@ -1,337 +1,399 @@ -""" -Main model wrapper class definition -""" - -# -# @authors: Rahul Dhodapkar, Syed Rizvi -# - -# Python built-in libraries -import os -import pickle -from random import sample -from typing import Optional - -# Third-party libraries -import numpy as np -from datasets import load_from_disk, DatasetDict, Dataset - -# Pytorch, Huggingface imports -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments - -# Local imports -from cell2sentence.prompt_formatter import PromptFormatter, C2SPromptFormatter -from cell2sentence.utils import train_test_split_arrow_ds, tokenize_all, tokenize_loss_on_response - - -class CSModel(): - """ - Wrapper class to abstract different types of input data that can be passed - in cell2sentence based workflows. - """ - - def __init__(self, model_name_or_path, save_dir, save_name): - """ - Core constructor, CSModel class contains a path to a model. 
- - Arguments: - model_name_or_path: either a string representing a Huggingface model if - want to start with a default LLM, or a path to an already-trained C2S - model on disk if want to do inference with/finetune starting from - an already-trained C2S model - save_dir: directory where model should be saved to - save_name: name to save model under (no file extension needed) - """ - self.model_name_or_path = model_name_or_path # path to model to load - self.save_dir = save_dir - self.device = "cuda" if torch.cuda.is_available() else "cpu" - print("Using device:", self.device) - - # Create save path - if not os.path.exists(save_dir): - os.mkdir(save_dir) - self.save_path = os.path.join(save_dir, save_name) - - # Load tokenizer - self.tokenizer = AutoTokenizer.from_pretrained( - model_name_or_path, padding_side='left' - ) - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - - # Load model - either a pretrained C2S model path, or a Huggingface LLM name (if want to train from scratch on your own dataset) - model = AutoModelForCausalLM.from_pretrained( - model_name_or_path, - cache_dir=os.path.join(save_dir, ".cache"), # model file takes up several GB if loading default Huggignface LLM models - trust_remote_code=True - ) - model.save_pretrained(self.save_path) - - def __str__(self): - """ - Summarize CSData object as string for debugging and logging. - """ - return f"CSModel Object; Path={self.save_path}" - - def fine_tune(self, - csdata, - task: str, - train_args: TrainingArguments, - loss_on_response_only: bool = True, - top_k_genes: int = 100, - max_eval_samples: int = 500, - data_split_indices_dict: Optional[dict] = None, - prompt_formatter: Optional[PromptFormatter] = None, - formatted_hf_ds: Optional[Dataset] = None, - num_proc: int = 3, - ): - """ - Fine tune a model using the provided CSData object data - - Arguments: - csdata: a CSData object to be used as input for finetuning. 
- alternatively, data can be any generator of sequential - text that satisfies the same functional contract as - a CSData object - task: name of finetuning task (see supported tasks in prompt_formatter.py). Ignored - if prompt_formatter is not None. - train_args: Huggingface Trainer arguments object - loss_on_response_only: whether to take loss only on model's answer - top_k_genes: number of genes to use for each cell sentence. Ignored if - prompt_formatter is not None. - max_eval_samples: number of samples to use for validation - data_split_indices_dict: dictionary of indices for train, val, and (optionally) - test set. Required keys are "train" and "val", value - should be a list of indices of samples in that data split. - prompt_formatter: optional custom PromptFormatter object. If None, a default one - will be created using task and top_k_genes parameters. - formatted_hf_ds: optional Huggingface Dataset object containing formatted data, - used in cases where custom formatting is desired (e.g. multicell - tasks where more complex formatting is needed). - num_proc: number of processes to use for tokenization. Defaults to 3. 
- Return: - None: an updated CSModel is generated in-place - """ - # Load data from csdata object - if csdata.dataset_backend == "arrow": - hf_ds = load_from_disk(csdata.data_path) - else: - raise NotImplementedError("Please use arrow backend implementation for training") - - # Define prompt formatter, format prompts - if prompt_formatter is None: - prompt_formatter = C2SPromptFormatter(task=task, top_k_genes=top_k_genes) - if formatted_hf_ds is None: - # If formatted dataset not supplied, format hf_ds using prompt_formatter - formatted_hf_ds = prompt_formatter.format_hf_ds(hf_ds) - - # Load model - print("Reloading model from path on disk:", self.save_path) - model = AutoModelForCausalLM.from_pretrained( - self.save_path, - cache_dir=os.path.join(self.save_dir, ".cache"), - trust_remote_code=True - ) - model = model.to(self.device) - - # Tokenize data using LLM tokenizer - # - this function applies a lambda function to tokenize each dataset split in the DatasetDict - if loss_on_response_only: - tokenization_function = tokenize_loss_on_response - else: - tokenization_function = tokenize_all - formatted_hf_ds = formatted_hf_ds.map( - lambda batch: tokenization_function(batch, self.tokenizer), - batched=True, - load_from_cache_file=False, - num_proc=num_proc, - batch_size=1000, - ) - - # Define parameters needed in data collator: - block_size = model.config.max_position_embeddings # maximum input sequence length possible - tokenizer = self.tokenizer # define tokenizer as variable here so it is accessible in dataloader - def data_collator(examples): - # Note: this data collator assumes we are not using flash attention, and pads samples - # to the max size in the batch. All sample lengths are capped at the size of the - # LLM's context). 
- max_length = max(list(map(lambda x: len(x["input_ids"]), examples))) - batch_input_ids, batch_attention_mask, batch_labels = [], [], [] - for i in range(len(examples)): - sample_input_ids = examples[i]["input_ids"] - label_input_ids = examples[i]["labels"] - attention_mask = examples[i]["attention_mask"] - assert len(sample_input_ids) == len(label_input_ids) == len(attention_mask) - - size_diff = max_length - len(sample_input_ids) - final_input_ids = [tokenizer.pad_token_id] * (size_diff) + sample_input_ids - final_attention_mask = [0] * (size_diff) + attention_mask - final_label_input_ids = [-100] * (size_diff) + label_input_ids - - batch_input_ids.append(final_input_ids[: block_size]) - batch_attention_mask.append(final_attention_mask[: block_size]) - batch_labels.append(final_label_input_ids[: block_size]) - - return { - "input_ids": torch.tensor(batch_input_ids), - "attention_mask": torch.tensor(batch_attention_mask), - "labels": torch.tensor(batch_labels), - } - - output_dir = train_args.output_dir - print(f"Starting training. 
Output directory: {output_dir}") - - # Perform dataset split - if data_split_indices_dict is None: - split_ds_dict, data_split_indices_dict = train_test_split_arrow_ds(formatted_hf_ds) - else: - # Dataset split indices supplied by user, split formatted arrow dataset accordingly - train_ds = formatted_hf_ds.select(data_split_indices_dict["train"]) - val_ds = formatted_hf_ds.select(data_split_indices_dict["val"]) - ds_dict_object = { - "train": train_ds, - "validation": val_ds, - } - split_ds_dict = DatasetDict(ds_dict_object) - with open(os.path.join(output_dir, 'data_split_indices_dict.pkl'), 'wb') as f: - pickle.dump(data_split_indices_dict, f) - - train_dataset = split_ds_dict["train"] - eval_dataset = split_ds_dict["validation"] - if (max_eval_samples is not None) and (max_eval_samples < eval_dataset.num_rows): - sampled_eval_indices = sample(list(range(eval_dataset.num_rows)), k=max_eval_samples) - sampled_eval_indices.sort() - np.save(os.path.join(output_dir, 'sampled_eval_indices.npy'), np.array(sampled_eval_indices, dtype=np.int64)) - print(f"Selecting {max_eval_samples} samples of eval dataset to shorten validation loop.") - eval_dataset = eval_dataset.select(sampled_eval_indices) - - # Define Trainer - trainer = Trainer( - model=model, - args=train_args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=self.tokenizer - ) - trainer.train() - print(f"Finetuning completed. Updated model saved to disk at: {output_dir}") - - def generate_from_prompt(self, model, prompt, max_num_tokens=1024, **kwargs): - """ - Generate new data using the model, starting with a given prompt. - - Arguments: - model: a C2S model - prompt: a textual prompt - max_num_tokens: the maximum number of tokens to generate given the model supplied - kwargs: arguments for model.generate() (for generation options, see Huggingface docs: - https://huggingface.co/docs/transformers/en/main_classes/text_generation). 
- Any kwargs are passed without input validation to the model.generate() function - Return: - Text corresponding to the number `n` of tokens requested - """ - return self.generate_from_prompt_batched( - model=model, - prompt_list=[prompt], - max_num_tokens=max_num_tokens, - **kwargs - )[0] - - def generate_from_prompt_batched(self, model, prompt_list, max_num_tokens=1024, **kwargs): - """ - Batched generation with C2S model. Takes as input a model and a list of prompts to - generate from. - - Arguments: - model: a C2S model - prompt: a textual prompt - max_num_tokens: the maximum number of tokens to generate given the model supplied - kwargs: arguments for model.generate() (for generation options, see Huggingface docs: - https://huggingface.co/docs/transformers/en/main_classes/text_generation) - Return: - Text corresponding to the number `n` of tokens requested - """ - tokens = self.tokenizer(prompt_list, padding=True, return_tensors='pt') - input_ids = tokens['input_ids'].to(self.device) - attention_mask = tokens['attention_mask'].to(self.device) - - outputs = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_new_tokens=max_num_tokens, - pad_token_id=self.tokenizer.pad_token_id, - **kwargs - ) - pred_list = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) - predictions_without_input_prompt = [] - for pred, prompt in zip(pred_list, prompt_list): - pred_cleaned = pred.replace(prompt, "") - pred_cleaned = pred_cleaned.replace("<|endoftext|>", "") # remove end of text string - pred_cleaned = pred_cleaned.lstrip() # remove any leading whitespace - predictions_without_input_prompt.append(pred_cleaned) - - return predictions_without_input_prompt - - def embed_cell(self, model, prompt, max_num_tokens=1024): - """ - Embed cell using the model, starting with a given prompt. 
- - Arguments: - model: a C2S model - prompt: a textual prompt - max_num_tokens: the maximum number of tokens to generate given the model supplied - Return: - Text corresponding to the number `n` of tokens requested - """ - embedding_list = self.embed_cells_batched( - model=model, - prompt_list=[prompt], - max_num_tokens=max_num_tokens) - return embedding_list[0] # return 1 cell embedding - - def embed_cells_batched(self, model, prompt_list, max_num_tokens=1024): - """ - Embed multiple cell in batched fashion using the model, starting with a given prompt. - - Arguments: - model: a C2S model for cell embedding - prompt_list: a list of textual prompts - max_num_tokens: the maximum number of tokens to generate given the model supplied - Return: - Text corresponding to the number `n` of tokens requested - """ - tokens = self.tokenizer(prompt_list, padding=True, return_tensors='pt') - input_ids = tokens['input_ids'].to(self.device) - attention_mask = tokens['attention_mask'].to(self.device) - - outputs = model( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True - ) - # Take last layer output, average over sequence dimension - all_embeddings = [] - for idx in range(len(prompt_list)): - embedding = outputs.hidden_states[-1][idx].mean(0).detach().cpu().numpy() - all_embeddings.append(embedding) - return all_embeddings - - def push_model_to_hub(self, model_id_or_name): - """ - Helper function to push the model to Huggingface. 
Note: need to be logged - into Huggingface, see: https://huggingface.co/docs/transformers/en/model_sharing - - Arguments: - model_id_or_name: name to push Huggingface model to - """ - # Reload model - model = AutoModelForCausalLM.from_pretrained( - self.save_path, - cache_dir=os.path.join(self.save_dir, ".cache"), - trust_remote_code=True - ) - - # Push to hub - model.push_to_hub(model_id_or_name, use_auth_token=True) +""" +Main model wrapper class definition +""" + +# +# @authors: Rahul Dhodapkar, Syed Rizvi +# + +# Python built-in libraries +import os +import pickle +from random import sample +from typing import Optional + +# Third-party libraries +import numpy as np +from datasets import load_from_disk, DatasetDict, Dataset + +# Pytorch, Huggingface imports +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments +from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM +from huggingface_hub import login + +# Local imports +from cell2sentence.prompt_formatter import PromptFormatter, C2SPromptFormatter +from cell2sentence.utils import train_test_split_arrow_ds, tokenize_all, tokenize_loss_on_response + + +class CSModel(): + """ + Wrapper class to abstract different types of input data that can be passed + in cell2sentence based workflows. + """ + + def __init__( + self, + model_name_or_path, + save_dir, save_name, + peft = False, + r = 16, + alpha = 32, + modules = None, + bias = "none", + task_type = "CAUSAL_LM", + huggingface_token: Optional[str] = None + ): + """ + Core constructor, CSModel class contains a path to a model. + + Arguments: + model_name_or_path (str): Huggingface model ID or local path to a pretrained model. + save_dir (str): Directory where model should be saved. + save_name (str): Name to save model under (no file extension needed). + peft (bool): Whether to enable parameter-efficient fine-tuning (PEFT/LoRA). + When True, LoRA adapters are configured and attached to the base model. 
+ r (int): LoRA rank (must be > 0). Controls the low-rank decomposition + dimension for the adapter. A larger `r` increases adapter capacity. + alpha (int): LoRA alpha (must be > 0). Scaling factor applied to LoRA + updates; typically used together with `r` to control update magnitude. + modules (list[str] or None): List of target module name substrings to + apply LoRA to (e.g. ["q_proj", "k_proj", "v_proj", "o_proj"]). + bias (str): Bias handling passed to `LoraConfig` (common values: 'none', + 'all', 'lora_only'). + task_type (str): PEFT task type (e.g. 'CAUSAL_LM'). This is unrelated to + the higher-level C2S `task` used for prompt formatting. + huggingface_token (str or None): Optional HF token used to log in to the + Hugging Face hub when pushing or fetching private models. + + Notes: + - When `peft=True`, `r` and `alpha` must be positive integers; passing + non-positive values should raise a `ValueError` so callers are not + silently corrected. + - `modules` defaults to common projection layer names but callers should + pass an explicit list if they plan to mutate it later (avoid + relying on mutable default arguments). 
+ """ + if huggingface_token: + login(huggingface_token) + + self.model_name_or_path = model_name_or_path # path to model to load + self.save_dir = save_dir + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.peft = peft + print("Using device:", self.device) + + # Create save path + os.makedirs(save_dir,exist_ok=True) + self.save_path = os.path.join(save_dir, save_name) + + # Avoid mutable default for modules + if modules is None: + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + + # Load tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, padding_side='left' + ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Load model - either a pretrained C2S model path, or a Huggingface LLM name (if want to train from scratch on your own dataset) + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + cache_dir=os.path.join(save_dir, ".cache"), # model file takes up several GB if loading default Huggignface LLM models + trust_remote_code=True + ) + + # Load parameter efficient model ready for training + if peft: + if r <= 0: + raise ValueError(f"The value of r <= 0; found r = {r}") + if alpha <= 0: + raise ValueError(f"The value of alpha <= 0; found alpha = {alpha}") + peft_config = LoraConfig( + r = r, + lora_alpha = alpha, + target_modules = modules, + bias = bias, + task_type = task_type + ) + model = get_peft_model(model, peft_config) + + model.save_pretrained(self.save_path) + + def __str__(self): + """ + Summarize CSData object as string for debugging and logging. 
+ """ + return f"CSModel Object; Path={self.save_path}" + + def fine_tune(self, + csdata, + task: str, + train_args: TrainingArguments, + loss_on_response_only: bool = True, + top_k_genes: int = 100, + max_eval_samples: int = 500, + data_split_indices_dict: Optional[dict] = None, + prompt_formatter: Optional[PromptFormatter] = None, + formatted_hf_ds: Optional[Dataset] = None, + num_proc: int = 3, + ): + """ + Fine tune a model using the provided CSData object data + + Arguments: + csdata: a CSData object to be used as input for finetuning. + alternatively, data can be any generator of sequential + text that satisfies the same functional contract as + a CSData object + task: name of finetuning task (see supported tasks in prompt_formatter.py). Ignored + if prompt_formatter is not None. + train_args: Huggingface Trainer arguments object + loss_on_response_only: whether to take loss only on model's answer + top_k_genes: number of genes to use for each cell sentence. Ignored if + prompt_formatter is not None. + max_eval_samples: number of samples to use for validation + data_split_indices_dict: dictionary of indices for train, val, and (optionally) + test set. Required keys are "train" and "val", value + should be a list of indices of samples in that data split. + prompt_formatter: optional custom PromptFormatter object. If None, a default one + will be created using task and top_k_genes parameters. + formatted_hf_ds: optional Huggingface Dataset object containing formatted data, + used in cases where custom formatting is desired (e.g. multicell + tasks where more complex formatting is needed). + num_proc: number of processes to use for tokenization. Defaults to 3. 
+ Return: + None: an updated CSModel is generated in-place + """ + # Load data from csdata object + if csdata.dataset_backend == "arrow": + hf_ds = load_from_disk(csdata.data_path) + else: + raise NotImplementedError("Please use arrow backend implementation for training") + + # Define prompt formatter, format prompts + if prompt_formatter is None: + prompt_formatter = C2SPromptFormatter(task=task, top_k_genes=top_k_genes) + if formatted_hf_ds is None: + # If formatted dataset not supplied, format hf_ds using prompt_formatter + formatted_hf_ds = prompt_formatter.format_hf_ds(hf_ds) + + # Load model + print("Reloading model from path on disk:", self.save_path) + if not self.peft: + model = AutoModelForCausalLM.from_pretrained( + self.save_path, + cache_dir=os.path.join(self.save_dir, ".cache"), + trust_remote_code=True + ) + else: + model = AutoPeftModelForCausalLM.from_pretrained( + self.save_path, + cache_dir=os.path.join(self.save_dir, ".cache"), + trust_remote_code=True + ) + model = model.to(self.device) + + # Tokenize data using LLM tokenizer + # - this function applies a lambda function to tokenize each dataset split in the DatasetDict + if loss_on_response_only: + tokenization_function = tokenize_loss_on_response + else: + tokenization_function = tokenize_all + formatted_hf_ds = formatted_hf_ds.map( + lambda batch: tokenization_function(batch, self.tokenizer), + batched=True, + load_from_cache_file=False, + num_proc=num_proc, + batch_size=1000, + ) + + # Define parameters needed in data collator: + block_size = model.config.max_position_embeddings # maximum input sequence length possible + tokenizer = self.tokenizer # define tokenizer as variable here so it is accessible in dataloader + def data_collator(examples): + # Note: this data collator assumes we are not using flash attention, and pads samples + # to the max size in the batch. All sample lengths are capped at the size of the + # LLM's context). 
+ max_length = max(list(map(lambda x: len(x["input_ids"]), examples))) + batch_input_ids, batch_attention_mask, batch_labels = [], [], [] + for i in range(len(examples)): + sample_input_ids = examples[i]["input_ids"] + label_input_ids = examples[i]["labels"] + attention_mask = examples[i]["attention_mask"] + assert len(sample_input_ids) == len(label_input_ids) == len(attention_mask) + + size_diff = max_length - len(sample_input_ids) + final_input_ids = [tokenizer.pad_token_id] * (size_diff) + sample_input_ids + final_attention_mask = [0] * (size_diff) + attention_mask + final_label_input_ids = [-100] * (size_diff) + label_input_ids + + batch_input_ids.append(final_input_ids[: block_size]) + batch_attention_mask.append(final_attention_mask[: block_size]) + batch_labels.append(final_label_input_ids[: block_size]) + + return { + "input_ids": torch.tensor(batch_input_ids), + "attention_mask": torch.tensor(batch_attention_mask), + "labels": torch.tensor(batch_labels), + } + + output_dir = train_args.output_dir + print(f"Starting training. 
Output directory: {output_dir}") + + # Perform dataset split + if data_split_indices_dict is None: + split_ds_dict, data_split_indices_dict = train_test_split_arrow_ds(formatted_hf_ds) + else: + # Dataset split indices supplied by user, split formatted arrow dataset accordingly + train_ds = formatted_hf_ds.select(data_split_indices_dict["train"]) + val_ds = formatted_hf_ds.select(data_split_indices_dict["val"]) + ds_dict_object = { + "train": train_ds, + "validation": val_ds, + } + split_ds_dict = DatasetDict(ds_dict_object) + with open(os.path.join(output_dir, 'data_split_indices_dict.pkl'), 'wb') as f: + pickle.dump(data_split_indices_dict, f) + + train_dataset = split_ds_dict["train"] + eval_dataset = split_ds_dict["validation"] + if (max_eval_samples is not None) and (max_eval_samples < eval_dataset.num_rows): + sampled_eval_indices = sample(list(range(eval_dataset.num_rows)), k=max_eval_samples) + sampled_eval_indices.sort() + np.save(os.path.join(output_dir, 'sampled_eval_indices.npy'), np.array(sampled_eval_indices, dtype=np.int64)) + print(f"Selecting {max_eval_samples} samples of eval dataset to shorten validation loop.") + eval_dataset = eval_dataset.select(sampled_eval_indices) + + # Define Trainer + trainer = Trainer( + model=model, + args=train_args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=self.tokenizer #changed argument from tokenizer to processing_class as per modern documentation + ) + trainer.train() + print(f"Finetuning completed. Updated model saved to disk at: {output_dir}") + + def generate_from_prompt(self, model, prompt, max_num_tokens=1024, **kwargs): + """ + Generate new data using the model, starting with a given prompt. 
+ + Arguments: + model: a C2S model + prompt: a textual prompt + max_num_tokens: the maximum number of tokens to generate given the model supplied + kwargs: arguments for model.generate() (for generation options, see Huggingface docs: + https://huggingface.co/docs/transformers/en/main_classes/text_generation). + Any kwargs are passed without input validation to the model.generate() function + Return: + Text corresponding to the number `n` of tokens requested + """ + return self.generate_from_prompt_batched( + model=model, + prompt_list=[prompt], + max_num_tokens=max_num_tokens, + **kwargs + )[0] + + def generate_from_prompt_batched(self, model, prompt_list, max_num_tokens=1024, **kwargs): + """ + Batched generation with C2S model. Takes as input a model and a list of prompts to + generate from. + + Arguments: + model: a C2S model + prompt: a textual prompt + max_num_tokens: the maximum number of tokens to generate given the model supplied + kwargs: arguments for model.generate() (for generation options, see Huggingface docs: + https://huggingface.co/docs/transformers/en/main_classes/text_generation) + Return: + Text corresponding to the number `n` of tokens requested + """ + tokens = self.tokenizer(prompt_list, padding=True, return_tensors='pt') + input_ids = tokens['input_ids'].to(self.device) + attention_mask = tokens['attention_mask'].to(self.device) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=max_num_tokens, + pad_token_id=self.tokenizer.pad_token_id, + **kwargs + ) + pred_list = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + predictions_without_input_prompt = [] + for pred, prompt in zip(pred_list, prompt_list): + pred_cleaned = pred.replace(prompt, "") + pred_cleaned = pred_cleaned.replace("<|endoftext|>", "") # remove end of text string + pred_cleaned = pred_cleaned.lstrip() # remove any leading whitespace + predictions_without_input_prompt.append(pred_cleaned) + + return 
predictions_without_input_prompt + + def embed_cell(self, model, prompt, max_num_tokens=1024): + """ + Embed cell using the model, starting with a given prompt. + + Arguments: + model: a C2S model + prompt: a textual prompt + max_num_tokens: the maximum number of tokens to generate given the model supplied + Return: + Text corresponding to the number `n` of tokens requested + """ + embedding_list = self.embed_cells_batched( + model=model, + prompt_list=[prompt], + max_num_tokens=max_num_tokens) + return embedding_list[0] # return 1 cell embedding + + def embed_cells_batched(self, model, prompt_list, max_num_tokens=1024): + """ + Embed multiple cell in batched fashion using the model, starting with a given prompt. + + Arguments: + model: a C2S model for cell embedding + prompt_list: a list of textual prompts + max_num_tokens: the maximum number of tokens to generate given the model supplied + Return: + Text corresponding to the number `n` of tokens requested + """ + tokens = self.tokenizer(prompt_list, padding=True, return_tensors='pt') + input_ids = tokens['input_ids'].to(self.device) + attention_mask = tokens['attention_mask'].to(self.device) + + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True + ) + # Take last layer output, average over sequence dimension + all_embeddings = [] + for idx in range(len(prompt_list)): + embedding = outputs.hidden_states[-1][idx].mean(0).detach().cpu().numpy() + all_embeddings.append(embedding) + return all_embeddings + + def push_model_to_hub(self, model_id_or_name): + """ + Helper function to push the model to Huggingface. 
Note: need to be logged + into Huggingface, see: https://huggingface.co/docs/transformers/en/model_sharing + + Arguments: + model_id_or_name: name to push Huggingface model to + """ + # Reload model + model = AutoModelForCausalLM.from_pretrained( + self.save_path, + cache_dir=os.path.join(self.save_dir, ".cache"), + trust_remote_code=True + ) + + # Push to hub + model.push_to_hub(model_id_or_name, use_auth_token=True) diff --git a/src/cell2sentence/tests/small_data.csv b/src/cell2sentence/tests/small_data.csv index cf2ee95..76f2bff 100644 --- a/src/cell2sentence/tests/small_data.csv +++ b/src/cell2sentence/tests/small_data.csv @@ -1,4 +1,5 @@ ,cell1,cell2,cell3,cell4,cell5 g1,0,3,0,1,3 g2,0,0,1,1,2 -g3,3,1,0,0,1 \ No newline at end of file +g3,3,1,0,0,1 +g4,1,0,0,0,0 \ No newline at end of file diff --git a/src/cell2sentence/tests/small_data_diffgenes.csv b/src/cell2sentence/tests/small_data_diffgenes.csv index cf48d65..76f2bff 100644 --- a/src/cell2sentence/tests/small_data_diffgenes.csv +++ b/src/cell2sentence/tests/small_data_diffgenes.csv @@ -2,4 +2,4 @@ g1,0,3,0,1,3 g2,0,0,1,1,2 g3,3,1,0,0,1 -g4,2,1,0,0,1 \ No newline at end of file +g4,1,0,0,0,0 \ No newline at end of file diff --git a/src/cell2sentence/tests/test_csmodel.py b/src/cell2sentence/tests/test_csmodel.py index 5615e92..b984459 100644 --- a/src/cell2sentence/tests/test_csmodel.py +++ b/src/cell2sentence/tests/test_csmodel.py @@ -1,50 +1,86 @@ -#!/usr/bin/env python -# -# Test model handling with CSModel wrapper -# - -# Python built-in libraries -import os -import random -from pathlib import Path - -# Third-party libraries -import pytest - -# Pytorch, Huggingface -from transformers import AutoModelForCausalLM -from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM - -# Local imports -import cell2sentence as cs -from cell2sentence.csmodel import CSModel - -HERE = Path(__file__).parent - - -class TestCSModelCellTypeConditionalGenerationWorkflow: - @classmethod - def 
class TestCSModelPeftModelLoadingAndErrorHandling:
    """Tests that a PEFT (LoRA) CSModel saves to disk and reloads with its
    adapter layers intact."""

    @classmethod
    def setup_class(cls):
        # Use a temporary directory instead of a hard-coded, machine-specific
        # path so the test runs on any machine / CI runner.
        import tempfile
        cls.save_dir = tempfile.mkdtemp(prefix="c2s_model_directory_")
        cls.save_name = "lora_gemma_model"
        hf_model_path = "vandijklab/C2S-Scale-Gemma-2-2B"
        cls.csmodel = CSModel(
            model_name_or_path=hf_model_path,
            save_dir=cls.save_dir,
            save_name=cls.save_name,
            peft=True,
        )

    def test_csmodel_created_correctly(self):
        # save_path must be derived from save_dir + save_name
        assert self.csmodel.save_path == os.path.join(self.save_dir, self.save_name)

    def test_layers_are_created_correctly(self):
        from peft import PeftModel, AutoPeftModelForCausalLM

        # Load the model back from disk using PEFT's loading method
        loaded_model = AutoPeftModelForCausalLM.from_pretrained(
            self.csmodel.save_path,
            trust_remote_code=True,
            is_trainable=False
        )

        # Verify it loaded as a PEFT model
        assert isinstance(loaded_model, PeftModel), "Model is not a PeftModel"

        # Verify that LoRA layers are present in the loaded model
        lora_modules = [
            name for name, _module in loaded_model.named_modules()
            if "lora" in name.lower()
        ]
        assert len(lora_modules) > 0, "No LoRA layers found in the reloaded model modules"

        # Ensure the active adapter is set (typical for LoRA)
        assert hasattr(loaded_model, "active_adapter"), "No active adapter found on the PEFT model"