From 4e4fd86a13c1a6c7df38bd0a2764a1da7ba67d90 Mon Sep 17 00:00:00 2001
From: ikergarcia1996
Date: Wed, 11 Sep 2024 22:10:25 +0200
Subject: [PATCH] Fixes

---
 .gitignore                 |   6 +-
 scripts/finetune.sh        |  32 ++++++++++
 src/train.py               |  20 +++----
 src/train2.py              | 117 +++++++++++++++++++++++++++++++++++++
 train_configs/gemma2B.yaml |   2 +-
 train_configs/llama8b.yaml |  62 ++++++++++++++++++++
 6 files changed, 226 insertions(+), 13 deletions(-)
 create mode 100644 scripts/finetune.sh
 create mode 100644 src/train2.py
 create mode 100644 train_configs/llama8b.yaml

diff --git a/.gitignore b/.gitignore
index 55f2e4c..dd7a1a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -125,4 +125,8 @@ dmypy.json
 
 *.ipynb
 
-models/
\ No newline at end of file
+models/
+
+.slurm/
+
+wandb/
\ No newline at end of file
diff --git a/scripts/finetune.sh b/scripts/finetune.sh
new file mode 100644
index 0000000..2294d85
--- /dev/null
+++ b/scripts/finetune.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+#SBATCH --job-name=Odesia-finetune
+#SBATCH --cpus-per-task=16
+#SBATCH --nodes=1
+#SBATCH --gres=gpu:2
+#SBATCH --mem=64G
+#SBATCH --output=.slurm/Odesia-finetune.out.txt
+#SBATCH --error=.slurm/Odesia-finetune.err.txt
+
+
+source /ikerlariak/igarcia945/envs/pytorch2/bin/activate
+
+
+export LC_ALL=en_US.UTF-8
+export LANG=en_US.UTF-8
+export LANGUAGE=en_US.UTF-8
+export TOKENIZERS_PARALLELISM=true
+export TRANSFORMERS_NO_ADVISORY_WARNINGS=true
+export WANDB_ENTITY=igarciaf
+export WANDB_PROJECT=Odesia
+export OMP_NUM_THREADS=16
+export WANDB__SERVICE_WAIT=300
+
+echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}"
+
+
+export PYTHONPATH="$PYTHONPATH:$PWD"
+torchrun --standalone --master_port 37227 --nproc_per_node=2 src/train.py train_configs/gemma2B.yaml
+torchrun --standalone --master_port 37227 --nproc_per_node=2 src/train.py train_configs/llama8b.yaml
+
+torchrun --standalone --master_port 37227 --nproc_per_node=2 src/evaluate.py --tasks all --model_name models/gemma-2b-it --output_dir results/finetune/gemma-2b-it
+torchrun --standalone --master_port 37227 --nproc_per_node=2 src/evaluate.py --tasks all --model_name models/Llama-3.1-8B-Instruct --output_dir results/finetune/Llama-3.1-8B-Instruct
diff --git a/src/train.py b/src/train.py
index 613b4e1..6e9166f 100644
--- a/src/train.py
+++ b/src/train.py
@@ -6,10 +6,10 @@
 from datasets import load_dataset
 from tqdm import tqdm
 from transformers import HfArgumentParser
-from trl import SFTConfig, SFTTrainer, setup_chat_format
+from trl import SFTConfig, SFTTrainer
 
 from src.config.config import ModelArguments
-from src.model.load_model import load_model, merge_lora_model
+from src.model.load_model import load_model
 from src.tasks import get_tasks
 
 
@@ -34,13 +34,13 @@ def train(training_args: SFTConfig, model_args: ModelArguments):
         rope_scaling_factor=model_args.rope_scaling_factor,
     )
 
-    tasks = get_tasks(tokenizer=tokenizer, tasks="all")
+    tasks = get_tasks(tokenizer=tokenizer, tasks=["exist_2022_t1_es"])
 
     train_dataset = []
     validation_dataset = []
     for task_name, task in tqdm(tasks.items(), desc="Loading datasets"):
         print(f"Loading dataset for task {task_name}")
-        train = task.get_dataset_training(split="train")
+        train = task.get_dataset_training(split="train")[:100]
         train_dataset.extend(train)
         print(f"Train dataset size: {len(train)}")
         dev = task.get_dataset_training(split="dev")
@@ -60,7 +60,6 @@ def train(training_args: SFTConfig, model_args: ModelArguments):
         for example in validation_dataset:
             print(json.dumps(example, ensure_ascii=False), file=f)
 
-    model, tokenizer = setup_chat_format(model, tokenizer)
     dataset = load_dataset(
         "json",
         data_files={
@@ -71,20 +70,19 @@ def train(training_args: SFTConfig, model_args: ModelArguments):
         },
     )
 
+    print(dataset)
+
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
         args=training_args,
-        train_dataset=dataset,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["validation"],
     )
 
     trainer.train()
 
-    merge_lora_model(
-        weights_path=model_args.model_name_or_path,
-        lora_weights_name_or_path=training_args.output_dir,
-        output_path=training_args.output_dir,
-    )
+    trainer.save_model()
 
 
 if __name__ == "__main__":
diff --git a/src/train2.py b/src/train2.py
new file mode 100644
index 0000000..3a5a9f6
--- /dev/null
+++ b/src/train2.py
@@ -0,0 +1,117 @@
+import logging
+import os
+import sys
+
+from datasets import load_dataset
+from transformers import AutoTokenizer, HfArgumentParser
+from trl import SFTConfig, SFTTrainer
+from unsloth import FastLanguageModel  # Also requires pip install xformers
+
+from src.config.config import ModelArguments
+from src.model.model_utils import find_all_linear_names
+
+
+def train(training_args: SFTConfig, model_args: ModelArguments):
+    os.makedirs(training_args.output_dir, exist_ok=True)
+
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=model_args.model_name_or_path,
+        dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
+        load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path
+    )  # FastLanguageModel doesn't load the chat template
+
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=model_args.lora_r,
+        target_modules=find_all_linear_names(model),
+        lora_alpha=model_args.lora_alpha,
+        lora_dropout=model_args.lora_dropout,  # Dropout = 0 is currently optimized
+        bias="none",  # Bias = "none" is currently optimized
+        use_gradient_checkpointing=True,
+        random_state=3407,
+    )
+
+    """
+    tasks = get_tasks(tokenizer=tokenizer, tasks="all")
+
+    train_dataset = []
+    validation_dataset = []
+    for task_name, task in tqdm(tasks.items(), desc="Loading datasets"):
+        print(f"Loading dataset for task {task_name}")
+        train = task.get_dataset_training(split="train")
+        train_dataset.extend(train)
+        print(f"Train dataset size: {len(train)}")
+        dev = task.get_dataset_training(split="dev")
+        validation_dataset.extend(dev)
+        print(f"Validation dataset size: {len(dev)}")
+
+    with open(os.path.join(training_args.output_dir, "train_dataset.jsonl"), "w") as f:
+        for example in train_dataset:
+            print(json.dumps(example, ensure_ascii=False), file=f)
+
+    print(f"Full training dataset size: {len(train_dataset)}")
+    print(f"Full validation dataset size: {len(validation_dataset)}")
+
+
+    with open(
+        os.path.join(training_args.output_dir, "validation_dataset.jsonl"), "w"
+    ) as f:
+        for example in validation_dataset:
+            print(json.dumps(example, ensure_ascii=False), file=f)
+    """
+    print("Loading datasets")
+    dataset = load_dataset(
+        "json",
+        data_files={
+            "train": os.path.join(training_args.output_dir, "train_dataset.jsonl"),
+            "validation": os.path.join(
+                training_args.output_dir, "validation_dataset.jsonl"
+            ),
+        },
+    )
+
+    print(dataset)
+
+    trainer = SFTTrainer(
+        model=model,
+        tokenizer=tokenizer,
+        args=training_args,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["validation"],
+    )
+
+    trainer.train()
+
+    trainer.save_model()
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+
+    parser = HfArgumentParser((ModelArguments, SFTConfig))
+    logging.info(f"Sys args {sys.argv}")
+
+    if len(sys.argv) > 0 and sys.argv[-1].endswith(".json"):
+        # If we pass only one argument to the script, and it's the path to a json file,
+        # let's parse it to get our arguments.
+        logging.info(f"Loading json config {sys.argv[-1]}")
+        model_args, training_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[-1])
+        )
+
+    elif len(sys.argv) > 0 and sys.argv[-1].endswith(".yaml"):
+        # If we pass only one argument to the script, and it's the path to a yaml file,
+        # let's parse it to get our arguments.
+        logging.info(f"Loading yaml config {sys.argv[-1]}")
+        model_args, training_args = parser.parse_yaml_file(
+            yaml_file=os.path.abspath(sys.argv[-1])
+        )
+    else:
+        logging.info("No config file passed, using command line arguments.")
+        model_args, training_args = parser.parse_args_into_dataclasses()
+
+    train(training_args, model_args)
diff --git a/train_configs/gemma2B.yaml b/train_configs/gemma2B.yaml
index d006ebc..ea85699 100644
--- a/train_configs/gemma2B.yaml
+++ b/train_configs/gemma2B.yaml
@@ -6,7 +6,7 @@ quantization: null
 gradient_checkpointing: true
 force_auto_device_map: false
 use_flash_attention: true
-packing: true
+packing: false
 
 output_dir: models/gemma-2b-it
 overwrite_output_dir: true
diff --git a/train_configs/llama8b.yaml b/train_configs/llama8b.yaml
new file mode 100644
index 0000000..0b57711
--- /dev/null
+++ b/train_configs/llama8b.yaml
@@ -0,0 +1,62 @@
+#Training args
+model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
+torch_dtype: bfloat16
+use_lora: true
+quantization: null
+gradient_checkpointing: true
+force_auto_device_map: false
+use_flash_attention: true
+packing: false
+
+output_dir: models/Llama-3.1-8B-Instruct
+overwrite_output_dir: true
+load_best_model_at_end: false
+metric_for_best_model:
+greater_is_better: true
+save_strategy: "no"
+save_only_model: true
+save_total_limit: 1
+
+# evaluation
+do_train: true
+do_eval: true
+do_predict: false
+evaluation_strategy: "epoch"
+
+per_device_train_batch_size: 8
+per_device_eval_batch_size: 2
+gradient_accumulation_steps: 8
+
+# optimizer settings
+optim: adamw_torch_fused
+learning_rate: 0.0003
+weight_decay: 0.001
+num_train_epochs: 3
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+adam_epsilon: 0.0000001
+
+# lora settings
+lora_r: 128
+lora_alpha: 256
+lora_dropout: 0.05
+lora_target_modules:
+    - all
+
+# reporting
+logging_strategy: steps
+logging_first_step: true
+logging_steps: 5
+report_to: wandb
+run_name: "Llama-3.1-8B-Instruct"
+disable_tqdm: false
+
+# hub settings
+push_to_hub: false
+resume_from_checkpoint: false
+
+# performance
+bf16: true
+fp16: false
+torch_compile: false
+ddp_find_unused_parameters: false